from __future__ import annotations

from z3 import *
from dataclasses import dataclass
from typing import Union

# Bit width used by term_enc for every bit-vector constant and variable it builds.
WIDTH = 32

@dataclass
class Const:
	"""Term node for an integer literal.

	term_enc turns it into a WIDTH-bit bit-vector value.
	"""
	# The literal's integer value.
	value: int

@dataclass
class Var:
	"""Term node for a program variable.

	term_enc turns it into a WIDTH-bit symbolic bit-vector named `name`.
	"""
	# Variable name; also used as the name of the Z3 constant.
	name: str

@dataclass
class Sum:
	"""Term node for the sum of two sub-terms.

	term_enc encodes it as bit-vector addition of the two encoded operands.
	"""
	# Left operand.
	left: Term
	# Right operand.
	right: Term

@dataclass
class Difference:
	"""Binary term node for a difference of two sub-terms.

	Going by the name, `right` is subtracted from `left`; the actual
	bit-vector encoding is whatever term_enc produces for this node.
	"""
	# Left operand.
	left: Term
	# Right operand.
	right: Term

@dataclass
class TrueC:
	"""Formula node for the constant `true`.

	fmla_enc turns it into the Z3 boolean value True. The single field is a
	placeholder that carries no information; it defaults to None so the
	node can be written as `TrueC()`, while `TrueC(None)` keeps working.
	"""
	# Placeholder only; always None.
	_: None = None

@dataclass
class FalseC:
	"""Formula node for the constant `false`.

	fmla_enc turns it into the Z3 boolean value False. The single field is
	a placeholder that carries no information; it defaults to None so the
	node can be written as `FalseC()`, while `FalseC(None)` keeps working.
	"""
	# Placeholder only; always None.
	_: None = None

@dataclass
class LtF:
	"""Formula node for the comparison `left < right` between two terms.

	fmla_enc encodes it by applying `<` to the two encoded operands.
	"""
	# Left-hand term.
	left: Term
	# Right-hand term.
	right: Term

@dataclass
class EqF:
	"""Formula node for the equality `left == right` between two terms.

	fmla_enc encodes it by applying `==` to the two encoded operands.
	"""
	# Left-hand term.
	left: Term
	# Right-hand term.
	right: Term

@dataclass
class NotF:
	"""Formula node for the negation of a formula.

	fmla_enc encodes it as Z3 `Not` of the encoded operand.
	"""
	# The formula being negated.
	q: Formula

@dataclass
class AndF:
	"""Formula node for the conjunction of two formulas.

	fmla_enc encodes it as Z3 `And` of the two encoded operands.
	"""
	# First conjunct.
	p: Formula
	# Second conjunct.
	q: Formula

@dataclass
class OrF:
	"""Formula node for the disjunction of two formulas.

	fmla_enc encodes it as Z3 `Or` of the two encoded operands.
	"""
	# First disjunct.
	p: Formula
	# Second disjunct.
	q: Formula

@dataclass
class ImpliesF:
	"""Formula node for the implication `p -> q`.

	fmla_enc encodes it as Z3 `Implies(p, q)`.
	"""
	# Antecedent.
	p: Formula
	# Consequent.
	q: Formula

@dataclass
class Asgn:
	"""Program node for an assignment of a term to a variable.

	post handles it by renaming the old value of `left` to a fresh
	variable and equating `left` with the renamed `right`.
	"""
	# Variable being assigned.
	left: Var
	# Term whose value is assigned.
	right: Term

@dataclass
class Seq:
	"""Program node for sequential composition.

	post first processes `alpha` and feeds the resulting condition into
	the processing of `beta`.
	"""
	# Program executed first.
	alpha: Prog
	# Program executed second.
	beta: Prog

@dataclass
class Test:
	"""Program node for a test on a formula.

	post handles it by conjoining the encoded formula with the incoming
	condition.
	"""
	# Formula being tested.
	q: Formula

@dataclass
class Choice:
	"""Program node for a choice between two programs.

	post handles it as the disjunction of the results for both branches.
	"""
	# First alternative.
	alpha: Prog
	# Second alternative.
	beta: Prog

@dataclass
class Iter:
	"""Program node for repetition of a program.

	post unrolls it one step at a time, decreasing its depth budget on
	each unrolling.
	"""
	# Loop body.
	alpha: Prog

# Union aliases for the three syntactic categories of the AST above.
# Terms are the nodes accepted by term_enc.
Term = Union[Const, Var, Sum, Difference]
# Formulas are the nodes accepted by fmla_enc.
Formula = Union[TrueC, FalseC, LtF, EqF, NotF, AndF, OrF, ImpliesF]
# Programs are the nodes accepted by post.
Prog = Union[Asgn, Seq, Test, Choice, Iter]

def term_enc(e: Term) -> BitVecRef:
	"""Translate a term AST node into a Z3 bit-vector expression.

	Const becomes a WIDTH-bit value, Var becomes a WIDTH-bit symbolic
	constant named after the variable, and Sum / Difference are encoded
	recursively as bit-vector addition / subtraction of their operands.

	Args:
		e: the term to encode.

	Returns:
		The Z3 bit-vector expression for `e`.

	Raises:
		TypeError: if `e` is not a Const, Var, Sum or Difference.
	"""
	if isinstance(e, Const):
		return BitVecVal(e.value, WIDTH)
	if isinstance(e, Var):
		return BitVec(e.name, WIDTH)
	if isinstance(e, Sum):
		return term_enc(e.left) + term_enc(e.right)
	if isinstance(e, Difference):
		# A difference is a subtraction; this used to multiply the operands.
		return term_enc(e.left) - term_enc(e.right)
	# Fail here rather than handing None to Z3 and failing somewhere else.
	raise TypeError('term_enc: unsupported term {!r}'.format(e))

def fmla_enc(e: Formula) -> BoolRef:
	"""Translate a formula AST node into a Z3 boolean expression.

	TrueC / FalseC become boolean constants, LtF / EqF compare the
	term_enc encodings of their operands with `<` / `==`, and NotF, AndF,
	OrF and ImpliesF map to the Z3 connectives of the same name applied
	to the recursively encoded sub-formulas.

	NOTE(review): on z3py bit-vectors `<` is the signed comparison;
	confirm that this is the intended meaning of LtF.

	Args:
		e: the formula to encode.

	Returns:
		The Z3 boolean expression for `e`.

	Raises:
		TypeError: if `e` is not one of the Formula node types.
	"""
	if isinstance(e, TrueC):
		return BoolVal(True)
	if isinstance(e, FalseC):
		return BoolVal(False)
	if isinstance(e, LtF):
		return term_enc(e.left) < term_enc(e.right)
	if isinstance(e, EqF):
		return term_enc(e.left) == term_enc(e.right)
	if isinstance(e, NotF):
		return Not(fmla_enc(e.q))
	if isinstance(e, AndF):
		return And(fmla_enc(e.p), fmla_enc(e.q))
	if isinstance(e, OrF):
		return Or(fmla_enc(e.p), fmla_enc(e.q))
	if isinstance(e, ImpliesF):
		return Implies(fmla_enc(e.p), fmla_enc(e.q))
	# Fail here rather than handing None to Z3 and failing somewhere else.
	raise TypeError('fmla_enc: unsupported formula {!r}'.format(e))

# Module-wide counter behind next(); the index of the most recently issued
# fresh variable. It only ever grows.
inc = 0
def next(x: Var) -> Var:
	"""Return a fresh variable derived from `x`.

	The result is named `<base>_<k>`. If `x.name` already ends in
	`_<digits>`, `<base>` is the name without that suffix; otherwise
	`<base>` is the whole name. `k` is taken from the global counter
	`inc`, which is advanced past both its previous value and any
	numeric suffix of `x`. The result therefore never equals `x` and
	never repeats a name issued earlier by this function.

	NOTE: this function shadows the builtin `next` inside this module;
	the name is kept because callers use it.

	Args:
		x: the variable to derive a fresh name from.

	Returns:
		A new Var; `x` itself is not modified.
	"""
	global inc
	base, sep, suffix = x.name.rpartition('_')
	if sep and suffix.isdecimal():
		# Skip past the existing index so the new name differs from x.
		inc = max(inc, int(suffix)) + 1
	else:
		# No numeric suffix (e.g. 'x' or 'my_var'): keep the full name.
		base = x.name
		inc += 1
	return Var('{}_{}'.format(base, inc))

	
def post(alpha: Prog, P: BoolRef, R: Union[BoolRef, None] = None, max_depth=10) -> BoolRef:
	"""Build a bounded postcondition of program `alpha` from condition `P`.

	Per node type:
	- Asgn: the old value of the assigned variable is renamed to a fresh
	  variable (via next) in both the right-hand side and `P`; the
	  result is `left == renamed_right AND renamed_P`.
	- Seq: the result for `alpha.alpha` is used as the condition for
	  `alpha.beta`.
	- Test: the encoded test formula is conjoined with `P`.
	- Choice: the disjunction of the results for both branches.
	- Iter: `P OR post(body; Iter(body))`, with `max_depth` reduced by
	  one on each unrolling.

	Once the depth budget is used up the function returns False, so
	executions needing more unrollings contribute nothing to the result.

	Args:
		alpha: the program to process.
		P: condition assumed before `alpha` runs.
		R: optional condition to be checked along the way. Handling of
			a non-None `R` at assignments is still a TODO, so such a
			call raises NotImplementedError when it reaches an Asgn.
			Test ignores `R`; the other nodes only pass it on.
		max_depth: remaining number of Iter unrollings.

	Returns:
		A Z3 boolean expression.

	Raises:
		NotImplementedError: if `R` is given and an Asgn is reached.
		TypeError: if `alpha` is not one of the Prog node types.
	"""
	# `<=` rather than `==` so a negative budget cannot recurse forever.
	if max_depth <= 0:
		return BoolVal(False)

	if isinstance(alpha, Asgn):
		left = alpha.left
		right = alpha.right

		# Rename the pre-assignment value of `left` to a fresh variable.
		next_var = next(left)
		renaming = [(term_enc(left), term_enc(next_var))]
		right_sub = substitute(term_enc(right), renaming)
		P_sub = substitute(P, renaming)

		if R is not None:
			# TODO: Add code here to "simulate" having rewritten the program as
			# being followed by a check on `R`, as described in the handout
			raise NotImplementedError(
				'post: handling of R at assignments is not implemented yet')

		# If R is not provided, then return the "normal"
		# strongest postcondition
		return And(term_enc(left) == right_sub, P_sub)

	if isinstance(alpha, Seq):
		return post(alpha.beta, post(alpha.alpha, P, R, max_depth), R, max_depth)

	if isinstance(alpha, Test):
		return And(fmla_enc(alpha.q), P)

	if isinstance(alpha, Choice):
		return Or(post(alpha.alpha, P, R, max_depth), post(alpha.beta, P, R, max_depth))

	if isinstance(alpha, Iter):
		# Zero iterations (P) or one more pass of the body, then the loop again.
		return Or(P, post(Seq(alpha.alpha, Iter(alpha.alpha)), P, R, max_depth=max_depth-1))

	# Fail here rather than handing None to Z3 and failing somewhere else.
	raise TypeError('post: unsupported program {!r}'.format(alpha))

def check_invariant(alpha: Prog, P: BoolRef, R: BoolRef, max_depth=10) -> bool:

	'''
		Intended to report whether every state that alpha enters, up to
		the given execution depth, satisfies R.

		TODO: implement this procedure so that it returns True
			  if and only if all states that alpha enters
			  up to the given execution depth satisfy R

		NOTE: this is currently an unimplemented stub. It ignores all of
		its arguments and always returns False.
		NOTE(review): P is presumably the condition assumed on the
		initial states -- confirm against the handout.
	'''

	return False
