'''
Note: this code was written hastily in lecture and is only for
instructional purposes.  It liberally cuts corners, has dubious
design decisions, and likely contains several dreadful bugs.
Use at your own risk.

Now with:
  * 'for' loops
  * local variables (lexical scoping), so recursive fib(n) works!
'''

EOF = '\x00'            # the NUL character, used to mark end of input
EOL = '\n'              # end-of-line character; ends a comment
COMMENT_CHAR = '%'      # starts a comment that runs to the end of the line

def tokenize(code):
    '''
    Split source text into a flat list of tokens.

    Tokens are:
      * single-character operators/parens: ( ) + - * / < >   (as str)
      * non-negative integer literals                         (as int)
      * names: a letter followed by letters, digits, or '_'   (as str)
    Whitespace is skipped, and COMMENT_CHAR starts a comment that runs to
    EOL (or to the end of the input).

    Raises Exception on any character that cannot start a token.
    '''
    # accept any iterable of characters, as the old list(code) did
    code = ''.join(code)
    tokens = [ ]
    i, n = 0, len(code)
    # Scan with an index: list.pop(0) is O(n) per call and made
    # tokenizing quadratic in the length of the program.
    while i < n:
        c = code[i]
        if c in '()+-*/<>':
            tokens.append(c)
            i += 1
        elif c.isspace():
            i += 1 # ignore whitespace
        elif c == COMMENT_CHAR:
            # we're in a comment, so skip to EOL or end of input...
            while i < n and code[i] != EOL:
                i += 1
            i += 1 # ...and consume the EOL itself (harmless at end of input)
        elif c.isdecimal():
            # an integer literal; isdecimal (not isdigit) so that characters
            # like '²', which int() rejects, are reported as bad tokens below
            start = i
            while i < n and code[i].isdecimal():
                i += 1
            tokens.append(int(code[start:i]))
        elif c.isalpha():
            # a variable/function name
            start = i
            while i < n and (code[i].isalpha() or code[i].isdigit() or code[i] == '_'):
                i += 1
            tokens.append(code[start:i])
        else:
            raise Exception(f'Unknown start of token: {c}')
    return tokens

def parse(tokens):
    '''
    Parse one expression from the front of `tokens` (as made by tokenize).

    Returns an int or name for an atom, or a (possibly nested) list for a
    parenthesized form, e.g. ['(', '+', 1, 2, ')'] -> ['+', 1, 2].
    The consumed tokens are removed from `tokens`; anything after the
    expression is left in the list for the caller to inspect.

    Raises Exception on empty input, an unexpected ')', or a '(' that is
    never closed.
    '''
    parseTree, end = _parseFrom(tokens, 0)
    # drop everything consumed in one slice instead of pop(0) per token,
    # which is O(n) each and made parsing quadratic
    del tokens[:end]
    return parseTree

def _parseFrom(tokens, pos):
    '''Parse one expression starting at tokens[pos]; return (tree, next pos).'''
    if pos >= len(tokens):
        raise Exception('Unexpected end of input')
    token = tokens[pos]
    if token == ')':
        # a ')' with no matching '(' must not be treated as a value
        raise Exception('Unexpected closing parenthesis')
    if token != '(':
        # we are parsing a single value
        return token, pos + 1
    # we are parsing a list: collect children until the matching ')'
    parseTree = [ ]
    pos += 1 # skip the left paren
    while True:
        if pos >= len(tokens):
            raise Exception('Missing closing parenthesis')
        if tokens[pos] == ')':
            return parseTree, pos + 1 # skip the right paren
        child, pos = _parseFrom(tokens, pos)
        parseTree.append(child)

# Top-level variable bindings: run() evaluates programs in this environment,
# and evaluate() falls back to it for names missing from a local environment.
globalEnvironment = {}
# User-defined functions registered by (funk ...): name -> (parms, body).
fnEnvironment = {}

def evaluate(parseTree, environment):
    '''
    Return the value of `parseTree` (as produced by parse).

      * list -> a form, handed to evaluateList
      * int  -> itself
      * str  -> a variable: looked up in `environment`, then in
                globalEnvironment; Exception if bound in neither
    Anything else raises Exception.
    '''
    if isinstance(parseTree, list):
        return evaluateList(parseTree, environment)
    if isinstance(parseTree, int):
        return parseTree
    if not isinstance(parseTree, str):
        raise Exception(f'Unknown type of object: {parseTree}')
    # a variable: local bindings shadow global ones
    name = parseTree
    if name in environment:
        return environment[name]
    if name in globalEnvironment:
        return globalEnvironment[name]
    raise Exception(f'Unbound variable: {name}')

def evaluateList(parseTree, environment):
    '''
    Evaluate a parenthesized form: a macro (special form) or a function call.

    Macros receive their arguments unevaluated:
      (set name value)             bind name in `environment`; returns value
      (if test then else)          evaluates only the chosen branch
      (for name lo hi body)        name runs over lo..hi inclusive; returns the
                                   last body value, or 0 if the body never ran
      (while test body)            repeat body while test is truthy; returns
                                   the last body value, or 0 if it never ran
      (funk name (parms...) body)  register a user function; returns name
    Anything else is a function call: every argument is evaluated first, then
    the operator is applied (* + - < lt > out block, or a user function).

    Raises Exception for an empty form () or an unknown function name.
    '''
    if not parseTree:
        raise Exception('Cannot evaluate an empty expression: ()')
    # first evaluate macros (their arguments are NOT pre-evaluated)
    macro = parseTree[0]
    args = parseTree[1:]
    if macro == 'set':
        # set varName varValue
        varName, varValue = args[0], evaluate(args[1], environment)
        environment[varName] = varValue
        return varValue
    elif macro == 'if':
        # if test trueResult falseResult
        test = evaluate(args[0], environment)
        if test: return evaluate(args[1], environment)
        else: return evaluate(args[2], environment)
    elif macro == 'for':
        # for varname loVal hiVal body
        varName = args[0]
        loVal = evaluate(args[1], environment)
        hiVal = evaluate(args[2], environment)
        result = 0
        # hiVal is inclusive; varName is rebound from the range on every
        # iteration, so a (set ...) of it inside the body cannot skip values
        for value in range(loVal, hiVal+1):
            environment[varName] = value
            result = evaluate(args[3], environment)
        return result
    elif macro == 'while':
        # while test body -- the test is re-evaluated before every iteration
        result = 0
        while evaluate(args[0], environment):
            result = evaluate(args[1], environment)
        return result
    elif macro == 'funk':
        # funk userFnName parms body
        userFnName, parms, body = args
        fnEnvironment[userFnName] = (parms, body)
        return userFnName
    # next evaluate functions
    fn = macro
    args = [evaluate(arg, environment) for arg in args]
    if fn == '*':
        return args[0] * args[1]
    elif fn == '+':
        return args[0] + args[1]
    elif fn == '-':
        return args[0] - args[1]
    elif fn in ['<', 'lt']:
        return int(args[0] < args[1])
    elif fn == '>':
        # the tokenizer already produces '>' tokens
        return int(args[0] > args[1])
    elif fn == 'out':
        print(args if len(args) != 1 else args[0])
        # yield a value like every other form, so (out ...) can be nested
        return args[-1] if args else 0
    elif fn == 'block':
        return args[-1] if args else 0
    elif isinstance(fn, str) and fn in fnEnvironment:
        # the isinstance guard turns e.g. ((f) 1) into a clear error below
        # instead of a TypeError for an unhashable list key
        return callUserDefinedFunction(fn, args, environment)
    else:
        raise Exception(f'Unknown function: {fn}')

def callUserDefinedFunction(userFnName, args, environment):
    '''
    Call the user function `userFnName` (registered by funk) on the
    already-evaluated `args`.

    The body runs in a fresh environment holding only the parameters;
    names not found there fall back to globalEnvironment (see evaluate).
    The caller's `environment` is accepted but not consulted.

    Raises Exception if the number of args does not match the parameters.
    '''
    parms, body = fnEnvironment[userFnName]
    if len(parms) != len(args):
        raise Exception(f'Wrong # of args to {userFnName}')
    localEnvironment = dict(zip(parms, args))
    return evaluate(body, localEnvironment)

def run(code):
    '''
    Tokenize, parse, and evaluate one program in the global environment.

    The program must be exactly one expression (wrap several in `block`).
    Leftover tokens after it raise an Exception instead of being silently
    ignored. Returns the expression's value.
    '''
    tokens = tokenize(code)
    parseTree = parse(tokens) # removes the tokens it consumes
    if tokens:
        raise Exception(f'Unexpected tokens after end of program: {tokens}')
    return evaluate(parseTree, globalEnvironment)

# Sample program: fact(3) + fact(4) == 6 + 24 == 30
code1 = '''
(block
 (funk fact (n)
     (block
         (set result 1)
         (while (< 1 n)
            (block
              (set result (* result n))
              (set n (- n 1))))
         result
      )
  )
  (+ (fact 3) (fact 4)) % 6 + 24 == 30
 )'''

# Sample program: max(5, 6) + max(20, 10) == 6 + 20 == 26
code2 = '''
(block
 (funk max (x y)
     (if (< x y) y x)
  )
  (+ (max 5 6) (max 20 10)) % 6 + 20 == 26
 )'''

# Sample program: recursive fib; prints fib(x) for x = 0..9 via a while loop
code3 = '''
(block
 (funk fib (n)
     (if (< n 2) 1
                 (+ (fib (- n 1))
                    (fib (- n 2))
                 )
     )
 )
 (set x 0)
 (while (< x 10)
    (block
        (out (fib x))
        (set x (+ x 1))
    )
 )
)'''

# Sample program: a function body reads a top-level variable;
# f(3) == 3 + y == 8 with y set to 5 at top level
code4 = '''
(block
 (set y 5)
 (funk f (x)
     (+ x y)
 )
 (f 3)
)
'''

# Sample program: for x from 4 through 10, prints [x, fib(x)].
# The (set x (+ x 1)) does not skip values: `for` rebinds x on each iteration.
code5 = '''
(block
 (funk fib (n)
     (if (< n 2) 1
                 (+ (fib (- n 1))
                    (fib (- n 2))
                 )
     )
 )
 (for x 4 (+ 5 5)
    (block
        (out x (fib x))
        (set x (+ x 1))
    )
 )
)'''

def test():
    '''
    Run the sample programs whose results are known and check them.

    Prints 'Running tests...' and then 'Passed!' on success. A mismatch
    raises AssertionError naming the program, the expected value, and the
    actual value. The checks are explicit rather than `assert` statements,
    because `python -O` would strip those (skipping the runs entirely) and
    still print 'Passed!'.
    '''
    print('Running tests...', end='')
    for name, code, expected in [('code1', code1, 30),
                                 ('code2', code2, 26),
                                 ('code4', code4, 8)]:
        actual = run(code)
        if actual != expected:
            raise AssertionError(f'{name}: expected {expected}, got {actual}')
    print('Passed!')

# Run the self-test and demo only when executed as a script, not on import.
if __name__ == '__main__':
    test()
    # prints [x, fib(x)] for x = 4..10
    run(code5)