
import copy
import random
# counter: module-level counter; NOTE(review): not referenced by the code in
# this section -- confirm before removing.
# max_size: not_equaliser draws random values from [-max_size, max_size].
counter, max_size = 0, 100


def is_pair(elem):
    """Return True when elem is a compound term: a tuple, list, set or dict."""
    compound_types = (tuple, list, set, dict)
    return isinstance(elem, compound_types)


######## IMPLEMENTATIONS OF VARIABLES, SUCCESS AND FAIL ##################

class Var: # implementation of variables
    """A logic variable.

    Two Vars are equal when they have the same type and the same ``id``;
    ``var(name)`` is the usual way to create one.
    """

    def __init__(self):
        self.id = ""  # identity of the variable; set by var()

    def __repr__(self):
        return 'var ' + str(self.id)

    def __eq__(self, other):
        return (type(self) == type(other)) and (self.id == other.id)

    def __hash__(self):
        # Must agree with __eq__ (equal ids -> equal hashes). The previous
        # ord(self.id) only accepted a one-character string, so variables
        # with longer names, integer ids, or the default "" id raised
        # TypeError as soon as they were used as substitution (dict) keys.
        return hash(self.id)

def var_huh(v):
    """Return True if v is a logic variable (an instance of Var or a subclass)."""
    is_variable = isinstance(v, Var)
    return is_variable


def var(name):
    """Create a fresh logic variable identified by name."""
    new_var = Var()
    new_var.id = name  # Vars compare by this id
    return new_var

def flatten(S):
    """Splice nested lists in S into a single flat sequence.

    - Anything that is not a list or tuple (atoms, Vars, strings, sets,
      dicts) is returned unchanged; sets and dicts have no positional order
      to flatten.
    - For a list or tuple, every element that is itself a list is flattened
      recursively and spliced in place; other elements, including nested
      tuples, are kept as they are.
    - The result is a list when S is a list and a tuple when S is a tuple.

    The previous head/tail recursion raised IndexError for every tuple (the
    empty tuple tail is not == [], so it indexed S[0] on it) and hit the
    recursion limit on lists of roughly a thousand elements.
    """
    if not isinstance(S, (list, tuple)):
        return S
    flat = []
    for elem in S:
        if isinstance(elem, list):
            flat.extend(flatten(elem))
        else:
            flat.append(elem)
    return flat if isinstance(S, list) else tuple(flat)

def no_vars(term):
    """Return True when term contains no logic variables at any depth.

    term may be an atom, a Var, or a (possibly nested) tuple/list/set/dict.
    Containers are searched recursively; a dict is searched by its keys,
    since that is what iterating a dict yields.

    Previously the term was passed through flatten(), which crashed on
    tuples and only splices nested *lists*, so a variable inside a tuple
    nested in a list (e.g. [(x, 1)]) went unnoticed.
    """
    if var_huh(term):
        return False
    if is_pair(term):
        return all(no_vars(elem) for elem in term)
    return True
    
def succeed(substn):
    """Goal that always succeeds.

    Returns a one-element stream holding substn itself (not a copy).
    """
    stream = list()
    stream.append(substn)
    return stream

def fail(subst):
    """Goal that always fails: an empty stream, whatever subst is."""
    no_solutions = list()
    return no_solutions

#################### END OF BASIC DEFINITIONS #####################


############### IMPLEMENTATION OF GOALS, (EQ, AND, OR) ####################

####### EQ_EQ GOAL ############

def eq_eq(u, v):
    """Goal that succeeds when u and v unify.

    Returns a goal: a function from a substitution (dict) to a stream of
    substitutions -- [extended substitution] on success, [] on failure.
    """
    def anon(s):
        # unify() writes bindings straight into the dict it is given (and may
        # leave partial bindings behind on failure), so hand it a copy: the
        # caller's substitution may be shared with other branches or captured
        # by a suspended stream and must not change underneath them.
        s = unify(u, v, dict(s))
        if s is not False:
            return succeed(s)
        else:
            return fail(s)
    return anon

def unify(u, v, s):
    """Extend substitution s so that u and v become equal, or return False.

    Both terms are walked to their roots first. A variable is bound to the
    other term unless the occurs check finds a cycle; tuples are unified
    element by element.

    NOTE: bindings are written into s in place and s is returned; on failure
    some bindings may already have been added, so pass a copy when the
    original must be preserved.
    """
    u = walk(u, s) # reduce u to root/base
    v = walk(v, s) # reduce v to root/base
    if equals(u, v):
        return s
    elif isinstance(u, Var) and not occurs(u, v, s): # assign v to u, check for cycles
        s[u] = v
        return s
    elif isinstance(v, Var) and not occurs(v, u, s): # assign u to v, check for cycles
        s[v] = u
        return s
    elif isinstance(u, tuple) and isinstance(v, tuple):
        # Tuples of different lengths can never unify. Without this check the
        # element-wise recursion ran past the shorter one (IndexError).
        if len(u) != len(v):
            return False
        for u_elem, v_elem in zip(u, v):
            s = unify(u_elem, v_elem, s)
            if s is False:
                return False
        return s
    else:
        return False

def walk(v, s):
    """Follow the bindings in s from v until reaching an unbound variable or a non-variable term."""
    if not isinstance(v, Var) or v not in s:
        return v  # base case: not a variable, or a variable with no binding
    return walk(s[v], s)

def occurs(x, y, s):
    """Return True if variable x occurs anywhere inside term y (under s).

    Used by unify as the occurs check, so a variable is never bound to a
    term that contains it.
    """
    y = walk(y, s) # reduce y to root/base
    if var_huh(y):
        return y == x
    if is_pair(y):
        # Check every element. The old head/tail recursion (y[0], y[1:])
        # eventually indexed the empty tail and raised IndexError, so binding
        # a variable to any tuple or list crashed.
        return any(occurs(x, elem, s) for elem in y)
    return False

def equals(u, v):
    """Return True when two walked (root) terms are the same atom.

    - ints (including bools) and strings are compared by value;
    - variables are compared by id;
    - two empty tuples are equal;
    - floats, complex numbers, bytes and None are compared by value when both
      sides have the same type (previously they never compared equal, so
      e.g. unify(1.5, 1.5, s) failed);
    - everything else (non-empty tuples, any list/set/dict, mixed types)
      compares unequal; unify decomposes tuples element by element instead.
    """
    if isinstance(u, int) and isinstance(v, int):
        return u == v
    elif isinstance(u, str) and isinstance(v, str):
        return u == v
    elif isinstance(u, Var) and isinstance(v, Var):
        return u.id == v.id
    elif isinstance(u, tuple) and isinstance(v, tuple):
        return len(u) == 0 and len(v) == 0
    elif type(u) is type(v) and isinstance(u, (float, complex, bytes, type(None))):
        return u == v
    else:
        return False

####### END EQ_EQ ##########


# logical AND    
def conj(*args):
    """Logical AND of any number of goals.

    conj() is the always-succeeding goal, conj(g) is g itself, and
    conj(g1, g2, ..., gn) nests to the right:
    conj_2(g1, conj_2(g2, ... conj_2(g(n-1), gn))).
    """
    if not args:
        return succeed
    combined = args[-1]
    # Fold from the right so the nesting matches conj_2(first, rest).
    for goal in reversed(args[:-1]):
        combined = conj_2(goal, combined)
    return combined

def conj_2(goal_1, goal_2):
    """Goal for goal_1 AND goal_2.

    Runs goal_1 on the substitution, then runs goal_2 on every result of
    goal_1 via append_map_inf.
    """
    def both(subst):
        first_stream = goal_1(subst)
        return append_map_inf(goal_2, first_stream, list())
    return both

def append_map_inf(goal, subst_stream, res=None):
    """Apply goal to every substitution in subst_stream and collect the results.

    Elements of subst_stream are handled in order:
    - a dict (substitution): goal(elem) is evaluated and its results are
      appended to res;
    - a callable (suspended stream): it is called right away, and a new
      zero-argument callable is appended which, when called, applies goal to
      everything that suspended stream produced;
    - anything else is skipped.

    Args:
        goal: function from a substitution to a stream (list).
        subst_stream: iterable of substitutions and/or suspended streams.
        res: list to append results to. A fresh list is used when omitted.
            (A ``res=[]`` default would be created once and shared by every
            call, so results would leak between calls -- hence ``None``.)

    Returns:
        res, with the results appended.

    Iterates instead of recursing on subst_stream[1:], which copied the tail
    at every step (O(n^2)) and overflowed the recursion limit on long streams.
    """
    if res is None:
        res = []
    for elem in subst_stream:
        if isinstance(elem, dict):
            res.extend(goal(elem))
        elif callable(elem):
            # Force one step of the suspended stream now and suspend the rest:
            # conjoining `goal` with what it produced. The default argument
            # binds this iteration's stream (avoids late-binding closures).
            forced = elem()
            res.append(lambda forced=forced: append_map_inf(goal, forced, list()))
    return res

###### END LOGICAL AND ###########

######## LOGICAL OR (DISJUNCTION) ###########

def disj(*args):
    """Logical OR of any number of goals.

    disj() is the always-failing goal, disj(g) is g itself, and
    disj(g1, g2, ..., gn) is disj_2(g1, disj(g2, ..., gn)).
    """
    if not args:
        return fail
    if len(args) == 1:
        return args[0]
    first, *rest = args
    return disj_2(first, disj(*rest))

def disj_2(goal_1, goal_2):
    """Goal for goal_1 OR goal_2.

    Each goal is applied independently: goal_2 runs first on a shallow copy
    of the substitution, then goal_1 on the original. The returned stream
    lists goal_2's results before goal_1's (see append_inf).
    """
    def either(subst):
        second_results = goal_2(dict(subst))
        first_results = goal_1(subst)
        return append_inf(second_results, first_results)
    return either

def append_inf(subst_stream1, subst_stream2):
    """Concatenate two substitution streams into a new list.

    The first stream is deep-copied, so its substitutions in the result are
    independent objects; elements of the second stream are appended as-is.
    Neither input list is modified.
    """
    merged = copy.deepcopy(subst_stream1)
    merged += subst_stream2
    return merged

########### END LOGICAL OR ##############

# A small exercise demonstrating the declarative style of logic programming.
# PROBLEM: find the even numbers in a given tuple. The relational solution follows:

def evmembero(mem, l):
    """Goal: mem is an even element of the tuple l.

    Relational reading: (mem equals the first element OR mem is an even
    member of the rest of the tuple) AND mem is even. An empty tuple has no
    members, so the goal fails. Enumerating all solutions for mem yields the
    even numbers of l.
    """
    def goal(subst):
        if len(l) > 0:
            head, tail = l[0], l[1:]
            is_member = disj(eq_eq(mem, head), evmembero(mem, tail))
            return conj(is_member, eveno(mem))(subst)
        return fail(subst)  # base case: nothing left to search
    return goal

def eveno(x):
    """Goal that succeeds when x resolves (under the substitution) to an even value.

    Values that do not support ``% 2`` (strings, tuples, ...) are simply not
    even, so the goal fails for them. Previously it raised TypeError, e.g.
    for evmembero over a tuple containing a string. An unbound variable still
    raises TypeError: this goal cannot enumerate even numbers, so x must be
    bound by an earlier goal in the conjunction (as evmembero arranges).
    """
    def anon(subst):
        value = walk(x, subst)  # x reduced to its root/base value
        if var_huh(value):
            raise TypeError("eveno: %r is unbound; bind it before testing evenness" % (value,))
        try:
            is_even = value % 2 == 0
        except TypeError:  # value has no numeric remainder -> not even
            is_even = False
        if is_even:
            return succeed(subst)
        else:
            return fail(subst)
    return anon

########## FUN EXAMPLE ENDS ###########

####### NEGATION ###########

def not_equaliser(t1,t2):
    """Return copies of t1 and t2 with their variables replaced so that they differ.

    Every Var in t1/t2, at the top level or nested inside lists/tuples, is
    replaced by a random int in [-max_size, max_size]. A variable that occurs
    several times gets the same value everywhere within one attempt. Attempts
    repeat until the two grounded terms are unequal.

    This replaces the old retry, which recursed on the already-replaced ints:
    on a collision it returned None and the caller's unpacking crashed. Pair
    arguments, previously unhandled, are now grounded element by element.

    Returns:
        (grounded_t1, grounded_t2); variable-free terms come back equal to the
        input.
    Raises:
        ValueError: if t1 == t2. Identical terms can never be made unequal,
        so retrying would loop forever.
    """
    if t1 == t2:
        raise ValueError("not_equaliser: identical terms cannot be made unequal: %r" % (t1,))
    # NOTE(review): with max_size == 0 two distinct variables always collide
    # and this loop cannot terminate.
    while True:
        chosen = {}  # one random value per variable for this attempt
        x = _ground_randomly(t1, chosen)
        y = _ground_randomly(t2, chosen)
        if x != y:
            return x, y


def _ground_randomly(term, chosen):
    """Copy term, replacing each Var by its value in chosen (drawn on first use)."""
    if var_huh(term):
        key = (type(term), term.id)  # Var equality is type + id
        if key not in chosen:
            chosen[key] = random.randint(-1 * max_size, max_size)
        return chosen[key]
    if isinstance(term, list):
        return [_ground_randomly(elem, chosen) for elem in term]
    if isinstance(term, tuple):
        return tuple(_ground_randomly(elem, chosen) for elem in term)
    return term

def not_eq(t1,t2):
    """Disequality goal: succeed when t1 and t2 can be kept different.

    Applied to a substitution, both terms are fully walked (walk_star) first:
    - structurally identical terms can never differ -> fail;
    - variable-free terms that differ -> succeed with the substitution as is;
    - otherwise not_equaliser is asked for values that make the terms differ.
      Those values are bound in a copy of the substitution, and the goal
      returns an infinite stream: that extended substitution followed by a
      thunk which, when called, repeats the process for another assignment.
    """
    def anon(subst):
        # Local names on purpose: the previous `nonlocal t1, t2` overwrote the
        # captured terms with their walked form under the *first* substitution,
        # so applying the same goal to another substitution used stale values.
        w1 = walk_star(t1, subst, list()) # get root/base value
        w2 = walk_star(t2, subst, list()) # get root/base value
        if _same_term(w1, w2):
            return fail(subst)
        if no_vars(w1) and no_vars(w2):
            return succeed(subst)
        x, y = not_equaliser(w1, w2)
        l_subst = dict(subst)
        _bind_ground(w1, x, l_subst)
        _bind_ground(w2, y, l_subst)
        return [l_subst, lambda: not_eq(w1, w2)(subst)] # thunk makes the stream infinite
    return anon


def _same_term(a, b):
    """Structural equality of walked terms: lists/tuples element-wise, leaves via equals().

    walk_star turns tuples into lists, and equals() treats every list as
    unequal, so comparing walked pairs with equals() alone let equal ground
    pairs (e.g. (1, 2) vs (1, 2), or () vs ()) pass as different.
    """
    if isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
        return len(a) == len(b) and all(_same_term(p, q) for p, q in zip(a, b))
    return equals(a, b)


def _bind_ground(term, value, subst):
    """Bind each variable of term in subst to the value at the same position in value."""
    if var_huh(term):
        if not var_huh(value):  # never bind a variable to a variable (or itself)
            subst[term] = value
    elif isinstance(term, (list, tuple)) and isinstance(value, (list, tuple)):
        for sub_term, sub_value in zip(term, value):
            _bind_ground(sub_term, sub_value, subst)

def walk_star(v, subst, res):
    """Fully resolve v under subst, descending into compound terms.

    If v walks to a non-compound root, that root is returned. Otherwise each
    element is resolved and appended to res, and res is returned. Nested
    compound elements become freshly built lists, so tuples come back as lists.
    """
    root = walk(v, subst)
    if not is_pair(root):
        return root
    for idx in range(len(root)):  # positional access, as root[idx]
        item = root[idx]
        walked_item = walk(item, subst)
        if is_pair(walked_item):
            res.append(walk_star(item, subst, list()))
        else:
            res.append(walked_item)
    return res

########### END NEGATION #########


# When experimenting with these goals, remember to pass an initial substitution
# (e.g. an empty dict, {}) as the argument when applying a goal.

###################### END OF GOAL DEFINITIONS #############################