In [1]:
from __future__ import annotations
from dataclasses import dataclass

from z3 import *
In [2]:
@dataclass
class Const:
    value: int

@dataclass
class Var:
    name: str

@dataclass
class Sum:
    left: Term
    right: Term

@dataclass
class Difference:
    left: Term
    right: Term

Term = Const | Var | Sum | Difference
In [3]:
bit_width = 32

def term_enc(e):
    match e:
        case Const(value):
            return BitVecVal(value, bit_width)
        case Var(name):
            return BitVec(name, bit_width)
        case Sum(left, right):
            return term_enc(left) + term_enc(right)
In [4]:
# Sanity check: which z3 class does a constant term encode to?
type(term_enc(Const(value=1)))
Out[4]:
z3.z3.BitVecNumRef
In [5]:
@dataclass
class TrueC:
    _: None

@dataclass
class FalseC:
    _: None

@dataclass
class LtF:
    left: Term
    right: Term

@dataclass
class EqF:
    left: Term
    right: Term

@dataclass
class NotF:
    q: Formula

@dataclass
class AndF:
    p: Formula
    q: Formula

@dataclass
class OrF:
    p: Formula
    q: Formula

@dataclass
class ImpliesF:
    p: Formula
    q: Formula

Formula = TrueC | FalseC | LtF | EqF | NotF | AndF | OrF | ImpliesF
In [6]:
def fmla_enc(p: Formula):
    match p:
        case TrueC(_):
            return BoolVal(True)
        case FalseC(_):
            return BoolVal(False)
        case LtF(left, right):
            return term_enc(left) < term_enc(right)
        case EqF(left, right):
            return term_enc(left) == term_enc(right)
        case NotF(p):
            return Not(fmla_enc(p))
        case AndF(p, q):
            return And(fmla_enc(p), fmla_enc(q))
        case OrF(p, q):
            return Or(fmla_enc(p), fmla_enc(q))
        case ImpliesF(p, q):
            return Implies(fmla_enc(p), fmla_enc(q))
In [7]:
# Sanity check: which z3 class does an atomic equality encode to?
type(fmla_enc(EqF(left=Var(name='x'), right=Const(value=0))))
Out[7]:
z3.z3.BoolRef
In [8]:
@dataclass
class Asgn:
    left: Var
    right: Term

@dataclass
class Seq:
    alpha: Prog
    beta: Prog

@dataclass
class Test:
    q: Formula

@dataclass
class Choice:
    alpha: Prog
    beta: Prog

@dataclass
class Iter:
    alpha: Prog

Prog = Asgn | Seq | Test | Choice | Iter

Let's encode some simple programs.

First, the single assignment `x := x + 5`.

In [9]:
# alpha1:  x := x + 5
x = Var(name='x')
alpha1 = Asgn(left=x, right=Sum(left=x, right=Const(value=5)))

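Next, `(?(x < 5); x := x + 1)*; ?(x >= 5)`: repeat "test `x < 5`, then increment `x`" any number of times, then test `x >= 5`.

Note that this is equivalent to the while loop `while (x < 5) x := x + 1`. The test at the start of each iteration lets the body run only while the guard holds. The final test keeps only the runs that stop once the guard has become false.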

In [10]:
# alpha2:  (?(x < 5); x := x + 1)* ; ?(not x < 5)
alpha2 = Seq(
    alpha=Iter(
        alpha=Seq(
            alpha=Test(q=LtF(left=x, right=Const(value=5))),
            beta=Asgn(left=x, right=Sum(left=x, right=Const(value=1))),
        )
    ),
    beta=Test(q=NotF(q=LtF(left=x, right=Const(value=5)))),
)
In [11]:
# Global counter used to mint fresh variable names; set it back to 0 to
# restart numbering at 1.
inc = 0
def next(x: Var):
    """Return a fresh ``Var`` derived from ``x`` (used by ``post`` for renaming).

    NOTE: this shadows the builtin ``next`` in the notebook namespace; the
    name is kept because ``post`` calls it.

    The global counter ``inc`` is incremented and the result is named
    ``<base>_<inc>``. If ``x.name`` already ends in ``_<digits>``, that
    suffix is dropped to form ``<base>``; otherwise the whole name is the
    base. Because every call uses the freshly incremented counter, two calls
    never return the same name.

    NOTE(review): a fresh name can still clash with a user variable that is
    literally named e.g. ``x_1``; avoid such names in programs.
    """
    global inc
    inc += 1
    base, sep, suffix = x.name.rpartition('_')
    if not (sep and suffix.isdigit()):
        # Names like "x" or "my_var" have no numeric suffix; keep them whole
        # (int() on the suffix used to crash for "my_var").
        base = x.name
    # Use `inc` itself (the old branch used inc + 1, which a later call could
    # reproduce and so mint a duplicate "fresh" name).
    return Var('{}_{}'.format(base, inc))
    
def post(alpha: Prog, P: BoolRef, max_depth=10):
    
    if max_depth == 0:
        return BoolVal(False)
    
    match alpha:
        case Asgn(left, right):
            next_var = next(left)
            right_sub = substitute(term_enc(right), [(term_enc(left), term_enc(next_var))])
            P_sub = substitute(P, [(term_enc(left), term_enc(next_var))])
            
            return And(term_enc(left) == right_sub, P_sub)
        
        case Seq(alpha, beta):
            
            return post(beta, post(alpha, P, max_depth), max_depth)
        
        case Test(Q):
            
            return And(fmla_enc(Q), P)
        
        case Choice(alpha, beta):
            
            return Or(post(alpha, P, max_depth), post(beta, P, max_depth))
        
        case Iter(alpha):
            
            return Or(P, post(Seq(alpha, Iter(alpha)), P, max_depth=max_depth-1))
In [12]:
# Restart fresh-variable numbering so that re-running this cell reproduces
# x_1, x_2, ... instead of continuing from wherever the global counter was.
inc = 0
# Reachable states of alpha1* from x == 0, with loops unrolled up to max_depth.
P = fmla_enc(EqF(Var('x'), Const(0)))
Q = post(Iter(alpha1), P, max_depth=10)
simplify(Q)
Out[12]:
x = 0 ∨ x = 5 + x_1 ∧ x_1 = 0 ∨ x = 5 + x_2 ∧ x_2 = 5 + x_1 ∧ x_1 = 0 ∨ x = 5 + x_3 ∧ x_3 = 5 + x_2 ∧ x_2 = 5 + x_1 ∧ x_1 = 0 ∨ x = 5 + x_4 ∧ x_4 = 5 + x_3 ∧ x_3 = 5 + x_2 ∧ x_2 = 5 + x_1 ∧ x_1 = 0 ∨ x = 5 + x_5 ∧ x_5 = 5 + x_4 ∧ x_4 = 5 + x_3 ∧ x_3 = 5 + x_2 ∧ x_2 = 5 + x_1 ∧ x_1 = 0 ∨ x = 5 + x_6 ∧ x_6 = 5 + x_5 ∧ x_5 = 5 + x_4 ∧ x_4 = 5 + x_3 ∧ x_3 = 5 + x_2 ∧ x_2 = 5 + x_1 ∧ x_1 = 0 ∨ x = 5 + x_7 ∧ x_7 = 5 + x_6 ∧ x_6 = 5 + x_5 ∧ x_5 = 5 + x_4 ∧ x_4 = 5 + x_3 ∧ x_3 = 5 + x_2 ∧ x_2 = 5 + x_1 ∧ x_1 = 0 ∨ x = 5 + x_8 ∧ x_8 = 5 + x_7 ∧ x_7 = 5 + x_6 ∧ x_6 = 5 + x_5 ∧ x_5 = 5 + x_4 ∧ x_4 = 5 + x_3 ∧ x_3 = 5 + x_2 ∧ x_2 = 5 + x_1 ∧ x_1 = 0 ∨ x = 5 + x_9 ∧ x_9 = 5 + x_8 ∧ x_8 = 5 + x_7 ∧ x_7 = 5 + x_6 ∧ x_6 = 5 + x_5 ∧ x_5 = 5 + x_4 ∧ x_4 = 5 + x_3 ∧ x_3 = 5 + x_2 ∧ x_2 = 5 + x_1 ∧ x_1 = 0
In [13]:
# Is a final state with x == 15 among the (bounded) reachable states in Q?
s = Solver()
s.add(Q, fmla_enc(EqF(Var('x'), Const(15))))
s.check()
Out[13]:
sat
In [14]:
# Show the satisfying assignment found by the preceding check(): the fresh
# variables give the intermediate values of x along a run that ends at x == 15.
s.model()
Out[14]:
[x_3 = 10, x_2 = 5, x_9 = 10, x = 15, x_1 = 0]
In [15]:
inc = 0  # restart fresh-variable numbering so the model's names start at x_1
# NOTE(review): the exit test below is not(x < 5) while the loop guard is
# x < 10, so this is not the encoding of `while (x < 10) x := x + 1`;
# confirm that is intended.
alpha = Seq(Iter(Seq(Test(LtF(Var('x'), Const(10))),
                     Asgn(Var('x'), Sum(Var('x'), Const(1))))),
            Test(NotF(LtF(Var('x'), Const(5)))))
P = EqF(Var('x'), Const(0))
Q = EqF(Var('x'), Const(5))

# Look for final states of alpha from P that violate Q.
# NOTE: with the default max_depth=10 the loop is unrolled at most 9 times,
# so the reachable state x == 10 (10 iterations) is never reported.
s = Solver()
s.add(post(alpha, fmla_enc(P)))
s.add(fmla_enc(NotF(Q)))

# model() is only available after a sat result.
if s.check() == sat:
    print(s.model())

# Enumerate every distinct final value of x that violates Q, blocking each one
# after it is reported.
while s.check() == sat:
    x_val = s.model().evaluate(term_enc(Var('x')), model_completion=True)
    s.add(Not(term_enc(Var('x')) == x_val))
    print('counterexample: x = {}'.format(x_val))
[x_1 = 0,
 x_6 = 5,
 x_5 = 4,
 x = 9,
 x_8 = 7,
 x_4 = 3,
 x_9 = 8,
 x_3 = 2,
 x_7 = 6,
 x_2 = 1]
counterexample: x = 7
counterexample: x = 8
counterexample: x = 9
counterexample: x = 6
In [ ]: