import sys
import importlib
# Name of the module under test: the first command-line argument when one is
# given, otherwise the default "cfr".
fname = sys.argv[1] if len(sys.argv) > 1 else "cfr"

# The implementation is imported by name at runtime; every check below reaches
# it through the module-level handle `p`.
p = importlib.import_module(fname)


from cfr_util import *
import random

import json


def load_game(game):
    """Load a game description from ``<game>.json`` and convert sequences to tuples.

    JSON has no tuple type, so every sequence is parsed as a list. This
    function rewrites those lists as tuples in place (tuples are hashable,
    lists are not).

    Args:
        game: Path of the game file without the ``.json`` extension.

    Returns:
        The parsed game dictionary in which
        * ``parent_edge`` (and ``parent_sequence``, when present) of every node
          in ``decision_problem_pl1`` / ``decision_problem_pl2`` is a tuple
          whenever the JSON value was a list (other values, e.g. ``null``, are
          left untouched), and
        * ``sequence_pl1`` / ``sequence_pl2`` of every ``utility_pl1`` entry
          is a tuple.

    Raises:
        AssertionError: If a ``utility_pl1`` entry holds a sequence that is
            not a list.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform's locale
    # encoding, so the file parses the same way on every system.
    with open(game + ".json", encoding="utf-8") as f:
        game_data = json.load(f)

    # Convert all sequences from lists to tuples
    for tfsdp in (game_data["decision_problem_pl1"], game_data["decision_problem_pl2"]):
        for node in tfsdp:
            if isinstance(node["parent_edge"], list):
                node["parent_edge"] = tuple(node["parent_edge"])
            # "parent_sequence" is optional, so it is only converted when present.
            if "parent_sequence" in node and isinstance(node["parent_sequence"], list):
                node["parent_sequence"] = tuple(node["parent_sequence"])

    for entry in game_data["utility_pl1"]:
        assert isinstance(entry["sequence_pl1"], list)
        assert isinstance(entry["sequence_pl2"], list)
        entry["sequence_pl1"] = tuple(entry["sequence_pl1"])
        entry["sequence_pl2"] = tuple(entry["sequence_pl2"])
    return game_data

def test_best_response_to_uniform_with_rm(game):
    """Check that CFR with regret matching best-responds to a uniform opponent.

    Player 2 is fixed to the uniform strategy returned by
    ``p.uniform_sf_strategy``. Player 1 runs 20 rounds of ``p.Cfr`` against
    the resulting utility vector, and the expected utility of the last
    strategy it produced is compared (via ``check_equal``) with
    ``p.best_response_value`` for the same utility vector.

    Args:
        game: A game dictionary as returned by ``load_game``.
    """
    tfsdp_pl1 = game["decision_problem_pl1"]

    # The result is not used below; the call is kept as-is.
    # NOTE(review): presumably safe to drop if get_sequence_set is pure -- confirm.
    get_sequence_set(tfsdp_pl1)

    opponent_strategy = p.uniform_sf_strategy(game["decision_problem_pl2"])
    payoff_vector = compute_utility_vector_pl1(game, opponent_strategy)

    learner = p.Cfr(tfsdp_pl1, rm_class=p.RegretMatching)
    for _round in range(20):
        # The opponent never changes, so the same utility vector is fed back
        # after every strategy request.
        strategy_pl1 = learner.next_strategy()
        learner.observe_utility(payoff_vector)

    achieved = p.expected_utility_pl1(game, strategy_pl1, opponent_strategy)
    optimum = p.best_response_value(tfsdp_pl1, payoff_vector)
    check_equal(achieved, optimum)

def dictionary_close(d1, d2, tol):
    """Return True when both dicts have the same keys and nearby values.

    Args:
        d1: First mapping from keys to numbers.
        d2: Second mapping from keys to numbers.
        tol: Largest absolute per-key difference that still counts as close
            (a difference exactly equal to ``tol`` is accepted).

    Returns:
        False if the key sets differ or any value pair differs by more than
        ``tol``; True otherwise.
    """
    same_keys = set(d1) == set(d2)
    # Values are only compared once the key sets are known to match, so every
    # lookup in d2 is guaranteed to succeed.
    return same_keys and not any(abs(d1[key] - d2[key]) > tol for key in d1)

def check_equal_dict(val, expected, tol=TOL):
    """Print a comparison of two dicts followed by a colored OK/MISMATCH line.

    The verdict comes from ``dictionary_close(expected, val, tol)``. Nothing
    is returned or raised on a mismatch; the outcome is only printed.

    Args:
        val: Dictionary produced by the code under test.
        expected: Reference dictionary.
        tol: Tolerance forwarded to ``dictionary_close``.
    """
    print("EXPECTED:", expected)
    print("GOT:     ", val)
    # ANSI escapes: green for a match, red for a mismatch.
    verdict = (
        "\033[0;32mOK\033[0m"
        if dictionary_close(expected, val, tol)
        else "\033[0;31mMISMATCH\033[0m"
    )
    print(verdict)
    print()

def check_equal(val, expected, tol=TOL):
    """Print a comparison of two numbers followed by a colored OK/MISMATCH line.

    The values match when their absolute difference is strictly below
    ``tol``. Nothing is returned or raised on a mismatch; the outcome is only
    printed.

    Args:
        val: Number produced by the code under test.
        expected: Reference number.
        tol: Strict upper bound on the allowed absolute difference.
    """
    print("EXPECTED:", expected)
    print("GOT:     ", val)
    # ANSI escapes: green for a match, red for a mismatch.
    verdict = (
        "\033[0;32mOK\033[0m"
        if abs(expected - val) < tol
        else "\033[0;31mMISMATCH\033[0m"
    )
    print(verdict)
    print()

def check_zero(val, tol=TOL):
    """Print whether a number is close to zero, with a colored OK/MISMATCH line.

    The value passes when its absolute value is strictly below ``tol``.
    Nothing is returned or raised on a mismatch; the outcome is only printed.

    Args:
        val: Number produced by the code under test.
        tol: Strict upper bound on the allowed absolute value.
    """
    print("EXPECTED: <", tol)
    print("GOT:     ", val)
    # ANSI escapes: green for a pass, red for a failure.
    verdict = (
        "\033[0;32mOK\033[0m"
        if abs(val) < tol
        else "\033[0;31mMISMATCH\033[0m"
    )
    print(verdict)
    print()
    
def test():
    """Run every check against the implementation module ``p``.

    The stages run in this order:
      1. ``uniform_sf_strategy`` / ``expected_utility_pl1`` on each game,
         compared with hard-coded reference utilities.
      2. ``RegretMatching`` on three actions, compared with hard-coded
         reference strategies.
      3. ``Cfr`` as a best-response finder against a uniform opponent.
      4. ``run_cfr``, ``run_dcfr`` and ``run_pcfrp`` for 1000 iterations per
         game, checking ``p.gap`` against a per-game threshold.

    Results are printed by the ``check_*`` helpers; nothing is returned.
    """
    game_ids = ["rock_paper_superscissors", "kuhn_poker", "leduc_poker"]
    # Every game file is loaded once, before any check runs.
    loaded_games = [(name, load_game(name)) for name in game_ids]

    # --- uniform_sf_strategy and expected_utility_pl1 ---
    expected_uniform_value = {
        "rock_paper_superscissors": 0.0,
        "kuhn_poker": 0.125,
        "leduc_poker": -0.078125,
    }
    for game_id, game in loaded_games:
        x_uniform = p.uniform_sf_strategy(game["decision_problem_pl1"])
        y_uniform = p.uniform_sf_strategy(game["decision_problem_pl2"])

        print("testing uniform_sf_strategy and expected_utility_pl1 in game", game_id)

        value = p.expected_utility_pl1(game, x_uniform, y_uniform)
        check_equal(value, expected_uniform_value[game_id])

    # --- RegretMatching ---
    print("checking regret matching...")
    rm = p.RegretMatching({0, 1, 2})
    # Fixed utility vectors fed to the regret minimizer, one per step.
    observed_utils = [
        {0: 0.6888437030500962, 1: 0.515908805880605, 2: -0.15885683833831},
        {0: -0.4821664994140733, 1: 0.02254944273721704, 2: -0.19013172509917142},
        {0: 0.5675971780695452, 1: -0.3933745478421451, 2: -0.04680609169528838},
    ]
    # Reference strategies: the first is expected before any utility has been
    # observed, each later one after the corresponding observation.
    expected_strategies = [
        {0: 0.3333333333333333, 1: 0.3333333333333333, 2: 0.3333333333333333},
        {0: 0.6703829932030705, 1: 0.32961700679692946, 2: 0.0},
        {0: 0.2558562038734756, 1: 0.7441437961265245, 2: 0.0},
        {0: 0.7738685354981276, 1: 0.22613146450187235, 2: 0.0},
    ]
    initial_strategy, *later_strategies = expected_strategies
    check_equal_dict(rm.next_strategy(), initial_strategy)
    for step, (utility, expected) in enumerate(zip(observed_utils, later_strategies)):
        print("after", step, "iterations:")
        rm.observe_utility(utility)
        check_equal_dict(rm.next_strategy(), expected)

    # --- Cfr, by finding best responses ---
    for game_id, game in loaded_games:
        print("testing best response against uniform in game", game_id)
        test_best_response_to_uniform_with_rm(game)

    # --- run_cfr / run_dcfr / run_pcfrp ---
    # Each suite: (banner, per-game banner, runner attribute on p, gap thresholds).
    # The runner is looked up by name at call time, so an implementation that
    # lacks a later runner still gets through the earlier suites.
    solver_suites = [
        (
            "testing cfr",
            "testing cfr in game",
            "run_cfr",
            {
                "rock_paper_superscissors": 0.1,
                "kuhn_poker": 0.05,
                "leduc_poker": 0.2,
            },
        ),
        (
            "testing dcfr",
            "testing dcfr in game",
            "run_dcfr",
            {
                "rock_paper_superscissors": 0.05,
                "kuhn_poker": 0.001,
                "leduc_poker": 0.001,
            },
        ),
        (
            "testing pcfr+",
            "testing pcfr+ in game",
            "run_pcfrp",
            {
                "rock_paper_superscissors": 2e-5,
                "kuhn_poker": 2e-5,
                "leduc_poker": 0.005,
            },
        ),
    ]
    for banner, game_banner, runner_name, target_gaps in solver_suites:
        print(banner)
        for game_id, game in loaded_games:
            print(game_banner, game_id)
            x, y = getattr(p, runner_name)(game, 1000)
            gap = p.gap(game, x, y)
            check_zero(gap, tol=target_gaps[game_id])
        


    


# Run the full check suite only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__": 
    test()
