import itertools
import numpy as np
import xml.etree.ElementTree as ET


def main(filename="alarm.xmlbif", n_samples=100):
    """Estimate posterior marginals on the ALARM network and print them next to reference answers.

    For each test case, ``gibbs`` is run ``n_samples`` independent times and the
    sampled value of the query variable is tallied; the normalised tallies are
    printed alongside the reference distribution.

    Args:
        filename: path to the XMLBIF file describing the Bayes net.
        n_samples: number of independent Gibbs runs per test case (must be >= 1).

    Raises:
        ValueError: if ``n_samples`` is less than 1 (the estimate would be 0/0).
    """
    if n_samples < 1:
        raise ValueError("n_samples must be at least 1")

    variables, cpts = load_xmlbif(filename)

    # Each test case is (evidence, query variable, reference distribution).
    # Your answers should be similar to the reference distributions
    # (each probability should probably be within 0.1 of the solution).
    long_evidence = {
        'ARTCO2': 'HIGH', 'BP': 'HIGH', 'CATECHOL': 'NORMAL',
        'DISCONNECT': 'TRUE', 'ERRCAUTER': 'FALSE', 'ERRLOWOUTPUT': 'TRUE',
        'EXPCO2': 'HIGH', 'FIO2': 'NORMAL', 'HISTORY': 'FALSE', 'HR': 'NORMAL',
        'HRBP': 'NORMAL', 'HREKG': 'LOW', 'HRSAT': 'LOW', 'HYPOVOLEMIA': 'FALSE',
        'INSUFFANESTH': 'TRUE', 'INTUBATION': 'ONESIDED', 'KINKEDTUBE': 'TRUE',
        'LVEDVOLUME': 'LOW', 'LVFAILURE': 'TRUE', 'MINVOL': 'LOW',
        'MINVOLSET': 'NORMAL', 'PAP': 'HIGH', 'PCWP': 'HIGH', 'PRESS': 'LOW',
        'PULMEMBOLUS': 'TRUE', 'SAO2': 'HIGH', 'SHUNT': 'HIGH',
        'STROKEVOLUME': 'LOW', 'TPR': 'LOW', 'VENTALV': 'NORMAL',
        'VENTLUNG': 'NORMAL', 'VENTMACH': 'NORMAL', 'VENTTUBE': 'ZERO',
    }
    test_cases = [
        ({}, 'CATECHOL', [0.298, 0.702]),
        ({'BP': 'HIGH'}, 'TPR', [0.012, 0.284, 0.704]),
        (long_evidence, 'PVSAT', [0.946, 0.042, 0.012]),
    ]

    for evidence, var, dist in test_cases:
        values = variables[var]
        # O(1) value -> slot lookup instead of list.index on every sample.
        index_of = {value: i for i, value in enumerate(values)}
        counts = np.zeros(len(values))
        for _ in range(n_samples):
            all_vars = gibbs(variables, cpts, evidence)
            counts[index_of[all_vars[var]]] += 1
        print('-' * 30)
        # Explicit 1-tuples so a tuple-valued argument can never be
        # misinterpreted as multiple format arguments.
        print("  + Evidence Variables: %s" % (evidence,))
        print("  + Sampled Variable: %s" % (var,))
        print("  + Solution: %s" % (dist,))
        print("  + Your Answer: %s" % (counts / np.sum(counts),))
    return

def load_xmlbif(filename):
    """Parse a Bayes net from an XMLBIF file.

    Args:
        filename: path to the XMLBIF document.

    Returns:
        ``(vars, factors)``:
          * ``vars`` maps each ``<VARIABLE>`` name to its list of ``<OUTCOME>``
            names, in file order.
          * ``factors`` is a list of ``CPT`` objects, one per ``<DEFINITION>``
            in file order. Each CPT's variables dict lists the child
            (``<FOR>``) first, followed by its parents (``<GIVEN>``).

    Raises:
        ValueError: if a probability list does not contain exactly one number
            per outcome of the child variable, or contains a non-number.
        KeyError: if a definition mentions an undeclared variable.
    """
    root = ET.parse(filename).getroot()
    variables = {
        var.find('NAME').text: [outcome.text for outcome in var.iter('OUTCOME')]
        for var in root.iter('VARIABLE')
    }

    def read_probs(text, child):
        # Returns the probabilities in `text` for each outcome of `child`.
        # split() (not split(" ")) tolerates newlines, tabs and repeated or
        # leading/trailing spaces inside <TABLE>/<LIST>.
        probs = [float(p) for p in text.split()]
        if len(probs) != len(variables[child]):
            raise ValueError(
                "expected %d probabilities for %s, got %d"
                % (len(variables[child]), child, len(probs))
            )
        return probs

    factors = []
    for definition in root.iter('DEFINITION'):
        child = definition.find('FOR').text
        parents = [given.text for given in definition.iter('GIVEN')]
        cpt = CPT({name: variables[name] for name in [child] + parents})
        if not parents:
            probs = read_probs(definition.find('TABLE').text, child)
            for value, prob in zip(variables[child], probs):
                cpt[{child: value}] = prob
        else:
            # One <ENTRY> per parent configuration: <CATEGORY> values line up
            # with `parents`, <LIST> holds P(child | that configuration).
            for entry in definition.iter('ENTRY'):
                parent_values = [cat.text for cat in entry.iter('CATEGORY')]
                assignment = dict(zip(parents, parent_values))
                probs = read_probs(entry.find('LIST').text, child)
                for value, prob in zip(variables[child], probs):
                    cpt[{**assignment, child: value}] = prob
        factors.append(cpt)
    return variables, factors


class CPT:
    """Table of numbers indexed by full joint assignments of a set of variables.

    ``variables`` maps each variable the table ranges over to its list of
    possible values. Entries live in ``self.f``, keyed by the tuple of
    ``(variable, value)`` pairs sorted by variable name. Assignments can
    therefore be passed as dicts in any key order.
    """

    def __init__(self, variables, default=0):
        self.variables = variables
        names = self.variables.keys()
        combos = itertools.product(*self.variables.values())
        # One slot per joint assignment, each initialised to `default`.
        self.f = dict.fromkeys(
            (tuple(sorted(zip(names, combo))) for combo in combos),
            default,
        )

    @staticmethod
    def _key(assignment):
        """Canonical key for an assignment dict: its items sorted by variable name."""
        return tuple(sorted(assignment.items()))

    def __call__(self, **kwargs):
        """Look up an entry by keyword; keywords that are not table variables are ignored."""
        relevant = {name: kwargs[name] for name in kwargs if name in self.variables}
        return self[relevant]

    def __getitem__(self, e):
        """Return the entry for assignment dict ``e``; KeyError if it is not a full assignment."""
        return self.f[self._key(e)]

    def __setitem__(self, e, x):
        """Store ``x`` for assignment dict ``e``; only existing entries may be set."""
        key = self._key(e)
        if key not in self.f:
            raise KeyError(e)
        self.f[key] = x
            
            
def gibbs(vars, cpts, evidence, n_steps=1000, rng=None):
    """
    Run Gibbs sampling on Bayes net to sample values of all variables
    given set of evidence variables.

    Evidence variables are clamped to their observed values. Every other
    variable starts at a uniformly random value. Then, ``n_steps`` times, one
    non-evidence variable is chosen uniformly at random and resampled from
    its distribution conditioned on the current values of all other
    variables. That conditional is proportional to the product of every CPT
    that mentions the variable (its own CPT and its children's CPTs, i.e.
    its Markov blanket).

    Inputs:
        vars: dictionary mapping each variable to list of values for that variable
        cpts: list of CPTs, one for each variable. Each must expose a
            ``variables`` dict (the variables it ranges over) and support
            ``cpt[assignment_dict]`` lookup.
        evidence: dictionary mapping each evidence variable to its observed
            value (not modified)
        n_steps: number of steps of Gibbs sampling (i.e., number of times to randomly sample a variable)
        rng: optional seed or ``numpy.random.Generator`` for reproducible
            runs; ``None`` draws fresh entropy from the OS.

    Output:
        all_vars: dictionary mapping each variable to its sampled value (include evidence variables)

    Raises:
        KeyError: if an evidence variable is not a key of ``vars``.
        ValueError: if an evidence value is not a legal value of its variable.
    """
    # default_rng returns a passed-in Generator unchanged, so callers can
    # share one stream across calls or pass an int seed.
    rng = np.random.default_rng(rng)

    for name, value in evidence.items():
        if value not in vars[name]:
            raise ValueError(
                "evidence %s=%r is not one of %s" % (name, value, vars[name])
            )

    # Only the factors that mention a variable change when its value
    # changes, so they are all that is needed for its conditional.
    blanket = {name: [cpt for cpt in cpts if name in cpt.variables] for name in vars}

    # Start from a random full assignment consistent with the evidence.
    all_vars = {}
    for name, values in vars.items():
        if name in evidence:
            all_vars[name] = evidence[name]
        else:
            all_vars[name] = values[rng.integers(len(values))]

    hidden = [name for name in vars if name not in evidence]
    if not hidden:
        # Everything is observed; there is nothing to sample.
        return all_vars

    for _ in range(n_steps):
        name = hidden[rng.integers(len(hidden))]
        values = vars[name]
        weights = np.ones(len(values))
        for i, value in enumerate(values):
            all_vars[name] = value
            for cpt in blanket[name]:
                weights[i] *= cpt[{v: all_vars[v] for v in cpt.variables}]
        total = weights.sum()
        if total > 0:
            choice = rng.choice(len(values), p=weights / total)
        else:
            # Every value has zero weight under the current state (possible
            # only when some CPT entries are exactly 0). Draw uniformly so
            # the chain can move on instead of dividing by zero.
            choice = rng.integers(len(values))
        all_vars[name] = values[choice]

    return all_vars



# Run the self-check only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()