import tree
import numpy as np
import counting_matrix, pearl_reconstruct, dfs_ordering, cluster_vote, evaluate, data, sequoia
import generate
import random

def epsilon4pc_sampling(D, numSamples = 1000):
    """
    Approximate the distribution of 4PC-epsilons by sampling random quartets.

    D is a square distance matrix indexed as D[i, j]; numSamples is the number
    of quartets to draw. For each quartet the three pairwise-sum pairings are
    formed, the gap between the two larger sums is divided by twice the
    shorter distance of the smallest pairing, and the result is recorded.

    Returns the list of sampled epsilons, with any value outside [0, 1]
    replaced by the sentinel 2. The same values are written, space separated,
    to ./out/e4pc_sampled.out (the directory must already exist).
    """
    n = D.shape[0]
    epsilons = []
    for _ in range(numSamples):
        # Redraw until all three pairings have a strictly positive sum.
        while True:
            w, x, y, z = random.sample(range(n), 4)
            s1 = D[w,x] + D[y,z]
            s2 = D[w,y] + D[x,z]
            s3 = D[w,z] + D[x,y]
            if min(s1, s2, s3) > 0:
                break
        # The smallest pairing identifies the quartet split; the other two
        # sums would be equal for an exact tree metric.
        if s1 <= min(s2, s3):
            epsilons.append(float(abs(s3 - s2)/(2.0*min(D[w,x], D[y,z]))))
        elif s2 <= min(s1, s3):
            epsilons.append(float(abs(s1 - s3)/(2.0*min(D[w,y], D[x,z]))))
        elif s3 <= min(s1, s2):
            epsilons.append(float(abs(s1 - s2)/(2.0*min(D[w,z], D[x,y]))))
    # Collapse everything outside [0, 1] onto a single out-of-range marker.
    epsilons = [2 if e < 0 or e > 1 else e for e in epsilons]
    # Context manager guarantees the output is flushed and the handle closed.
    with open("./out/e4pc_sampled.out", "w") as f:
        f.write(" ".join([str(e) for e in epsilons]))
    return epsilons

def epsilon4pc(D):
    """
    Exhaustively compute the 4pc-epsilon score of every quartet in D.

    D is a square distance matrix indexed as D[i, j]. Every index quadruple
    a < b < c < d is visited in lexicographic order. For each one, the pairing
    with the smallest sum of distances is selected, and the absolute
    difference of the other two sums is divided by twice the shorter distance
    of the selected pairing.

    Returns the list of scores as floats, one per quartet that matched a
    pairing.
    """
    size = D.shape[0]
    scores = []
    for a in range(size):
        for b in range(a + 1, size):
            for c in range(b + 1, size):
                for d in range(c + 1, size):
                    ab_cd = D[a,b] + D[c,d]
                    ac_bd = D[a,c] + D[b,d]
                    ad_bc = D[a,d] + D[b,c]
                    # Ties are resolved in the order ab|cd, ac|bd, ad|bc.
                    if ab_cd <= min(ac_bd, ad_bc):
                        gap = abs(ad_bc - ac_bd)
                        shortest = min(D[a,b], D[c,d])
                    elif ac_bd <= min(ab_cd, ad_bc):
                        gap = abs(ab_cd - ad_bc)
                        shortest = min(D[a,c], D[b,d])
                    elif ad_bc <= min(ab_cd, ac_bd):
                        gap = abs(ab_cd - ac_bd)
                        shortest = min(D[a,d], D[b,c])
                    else:
                        # No pairing compared as smallest; record nothing.
                        continue
                    scores.append(float(gap/(2.0*shortest)))
    return scores

def epsilon4pc_experiment():
    """
    Generate approximations of the epsilon4pc CDF for the three datasets.

    The datasets are the first 500 rows/columns of data.load_king(), the
    matrix from data.iplane_dist_mat(), and the arc-cosine of the Gram matrix
    of 100 random unit vectors in three dimensions. Each is passed to
    epsilon4pc_sampling and the three result lists are written, one per line
    and space separated, to ./out/e4pc_sampled.out.
    """
    D = data.load_king()
    D = D[np.ix_(range(500), range(500))]
    D2 = data.iplane_dist_mat()
    # Normalise Gaussian draws to unit length, then take pairwise angles.
    X = np.random.normal(0, 1, [100, 3])
    X = np.matrix([x/np.linalg.norm(x) for x in X])
    D3 = np.arccos(X*X.T)

    eps_king = epsilon4pc_sampling(D)
    eps_iplane = epsilon4pc_sampling(D2)
    eps_sphere = epsilon4pc_sampling(D3)

    # Opened only after all sampling calls, since epsilon4pc_sampling writes
    # to the same path. The context manager closes the handle, which the
    # previous version never did.
    with open("./out/e4pc_sampled.out", "w") as f:
        for eps in (eps_king, eps_iplane, eps_sphere):
            f.write(" ".join([str(x) for x in eps]))
            f.write("\n")

def err_vs_probes(q = 0.2):
    """
    Error versus number of probes on balanced binary trees.

    For each tree size and each multiplier in 2, 4, ..., 10, five random
    masks are drawn with parameter q. Under each mask, cluster_vote is run
    with m(y) = multiplier * log2(y), and sequoia is run `multiplier` times
    on one shared counting matrix with the results combined by
    sequoia.getDistMat. Each run records (probes used, error).

    Returns (cv, seq): one list of (probes, error) tuples per tree size.
    The same tuples are written to out/probes_vs_error_{cv,seq}_relative_<p>.out.
    """
    def _write_rows(path, rows):
        # One space-separated record per line.
        with open(path, "w") as handle:
            for row in rows:
                handle.write(" ".join([str(v) for v in row]))
                handle.write("\n")

    p = [3*64, 3*128]
    mscale = range(2, 11, 2)
    cv = [[] for _ in p]
    seq = [[] for _ in p]
    for idx, size in enumerate(p):
        T = tree.Tree(generate.generate_balanced_binary(size))
        D = T.paths[np.ix_(T.findleaves(), T.findleaves())]
        for scale in mscale:
            for _ in range(5):
                # Symmetrised Bernoulli mask with entries in {0, 1, 2}.
                mask = np.random.binomial(1, 1-q/2, [size, size])
                mask = mask + mask.T

                C = counting_matrix.CountingMatrix(D*mask)
                (F, mapping, root) = cluster_vote.cluster_vote(C, m = lambda y: scale*np.log2(y), active=True)
                # NOTE(review): the error is formatted with %f below, so
                # eval_absolute_error is presumably scalar here - confirm.
                errs = evaluate.eval_absolute_error(D, F.paths[np.ix_(mapping, mapping)])
                cv[idx].append((C.numUsed(), errs))

                C2 = counting_matrix.CountingMatrix(D*mask)
                SeqMats = []
                for _rep in range(scale):
                    (F, mapping) = sequoia.sequoia(C2)
                    SeqMats.append(F.paths[np.ix_(mapping, mapping)])
                SOut = sequoia.getDistMat(SeqMats)
                errs2 = evaluate.eval_absolute_error(D, SOut)
                seq[idx].append((C2.numUsed(), errs2))

                print("%d %d: %d %f %d %f" % (size, scale, C.numUsed(), errs, C2.numUsed(), errs2))
        _write_rows("out/probes_vs_error_cv_relative_%d.out" % size, cv[idx])
        _write_rows("out/probes_vs_error_seq_relative_%d.out" % size, seq[idx])
    return (cv, seq)


def cluster_vote_threshold_curves(recurse=False):
    """
    Threshold curves for cluster_vote on balanced binary trees as a function
    of the masking parameter q, for n in 24, 48, 96 leaves.

    recurse=True:  for each (n, q) run the full algorithm numTrials times and
                   record 1 - evaluate.eval_quartets(...) per trial. out[j]
                   is a list (one entry per q) of per-trial error lists, and
                   is written to out/cv_threshold_curves_<n>.out.
    recurse=False: only the top split is examined. A trial counts as correct
                   when each third of the leaf indices (0..n/3, n/3..2n/3,
                   2n/3..n) lies entirely inside its own distinct returned
                   cluster. out[j] is a list of success fractions, one per q,
                   and all rows are written to
                   out/cluster_vote_threshold_curves.out.

    Returns out.
    """
    numTrials = 20
    ns = [3*8, 3*16, 3*32]
    qs = [0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5]
    out = []
    for j, n in enumerate(ns):
        # n is always a multiple of 3; floor division makes that explicit.
        third = n // 3
        T = tree.Tree(generate.generate_balanced_binary(n))
        D = T.paths[np.ix_(T.findleaves(), T.findleaves())]
        out.append([])
        for q in qs:
            errs = []
            correct = 0
            for _ in range(numTrials):
                # Symmetrised Bernoulli mask with entries in {0, 1, 2}.
                mask = np.random.binomial(1, 1-q/2, [n, n])
                mask = mask + mask.T
                C = counting_matrix.CountingMatrix(D*mask)
                if recurse:
                    (T2, mapping, root) = cluster_vote.cluster_vote(C, active=True)
                    err = 1-evaluate.eval_quartets(counting_matrix.Matrix(T2.paths[np.ix_(mapping, mapping)]), counting_matrix.Matrix(D), numSamples=1000)
                    errs.append(err)
                    continue

                # With recurse=False the first return value is iterated as a
                # collection of clusters of leaf indices.
                (clusters, mapping, root) = cluster_vote.cluster_vote(C, active=True, recurse=False)
                l1 = [len([x for x in l if x < third]) for l in clusters]
                l2 = [len([x for x in l if x >= third and x < 2*third]) for l in clusters]
                l3 = [len([x for x in l if x >= 2*third]) for l in clusters]
                try:
                    p1 = l1.index(third)
                    p2 = l2.index(third)
                    p3 = l3.index(third)
                except ValueError:
                    # Some third is split across clusters: not a correct trial.
                    continue
                if p1 != p2 and p2 != p3 and p1 != p3:
                    correct = correct + 1
            if recurse:
                out[j].append(errs)
                print("n = %d q = %f score = %f" % (n, q, np.mean(errs)))
            else:
                print("n = %d q = %f score = %f" % (n, q, float(correct)/numTrials))
                out[j].append(float(correct)/numTrials)

        if recurse:
            with open("out/cv_threshold_curves_%d.out" % n, "w") as f:
                for l in out[j]:
                    f.write(" ".join([str(x) for x in l]))
                    f.write("\n")
    if not recurse:
        with open("out/cluster_vote_threshold_curves.out", "w") as f:
            for l in out:
                f.write(" ".join([str(x) for x in l]))
                f.write("\n")
    return out

def pearl_reconstruct_threshold_curves():
    """
    Quartet error of pearl_reconstruct under additive Gaussian noise.

    For n in 20, 40, 60 and sigma in 1.0, 1.1, ..., 1.9, numTrials random
    trees are generated. Symmetrised Gaussian noise with standard deviation
    sigma / (2 * log2(n**2)) per draw is added to the leaf distance matrix,
    the tree is reconstructed, and 1 - evaluate.eval_quartets(...) is
    recorded.

    Returns out, where out[j] is a list (one entry per sigma) of per-trial
    error lists for ns[j]. Each out[j] is also written to
    out/pearl_reconstruct_rescaled_threshold_curves_<n>.out.
    """
    numTrials = 20
    ns = [20, 40, 60]
    out = []
    for j, n in enumerate(ns):
        out.append([])
        for sigma in np.arange(1, 2, 0.1):
            errs = []
            for _ in range(numTrials):
                T = tree.Tree(generate.generate(n))
                D = T.paths[np.ix_(T.findleaves(), T.findleaves())]
                if sigma == 0:
                    # NOTE(review): unreachable with the current sigma grid.
                    # As written it adds 1 to every entry rather than
                    # leaving D unchanged - confirm intent before relying
                    # on it.
                    mask = np.ones((n,n))
                else:
                    mask = np.random.normal(0, float(sigma)/(2*np.log2(n**2)), [n, n])
                    mask = mask + mask.T
                C = counting_matrix.CountingMatrix(D+mask)
                (T2, mapping) = pearl_reconstruct.pearl_reconstruct(C)
                recovered = T2.paths[np.ix_(mapping, mapping)]
                err = 1-evaluate.eval_quartets(counting_matrix.Matrix(recovered), counting_matrix.Matrix(D))
                errs.append(err)
            out[j].append(errs)
            print("n = %d sigma = %f score = %f" % (n, float(sigma), np.mean(errs)))
        # Context manager closes the handle, which was previously leaked.
        with open("out/pearl_reconstruct_rescaled_threshold_curves_%d.out" % n, "w") as f:
            for l in out[j]:
                f.write(" ".join([str(x) for x in l]))
                f.write("\n")
    return out

def probesUsed2():
    """
    Number of probes used by each algorithm on unbalanced binary trees.

    For x = 100, 200, ..., 1500 a tree is generated and, on a fresh counting
    matrix each time, recursive cluster_vote, non-recursive cluster_vote and
    sequoia are run. The number of probes each one used is recorded together
    with the bound x*(x+1)/2.

    Each series is written space separated to out/probes_used_<name>.out.

    Returns a dict with keys 'cv_recurse', 'cv', 'bound' and 'sequoia', each
    mapping to the list of recorded values (one per x).
    """
    def _write_series(path, values):
        # One space-separated line, no trailing newline.
        with open(path, 'w') as handle:
            handle.write(" ".join([str(v) for v in values]))

    cv = []
    cv_recurse = []
    bound = []
    seq = []
    for x in range(100, 1501, 100):
        T = generate.generate_unbalanced_binary(x, 1)[0]
        leaves = T.findleaves()
        D = T.paths[np.ix_(leaves, leaves)]

        C2 = counting_matrix.CountingMatrix(D)
        (F, mapping, root) = cluster_vote.cluster_vote(C2, active=True)
        cv_recurse.append(C2.numUsed())

        C3 = counting_matrix.CountingMatrix(D)
        (F, mapping, root) = cluster_vote.cluster_vote(C3, active=True, recurse=False)
        cv.append(C3.numUsed())

        C5 = counting_matrix.CountingMatrix(D)
        (F, mapping) = sequoia.sequoia(C5)
        seq.append(C5.numUsed())

        bound.append(float(x*(x+1))/2)
        print("%d cv_rec %d cv %d seq %d bound %d" % (x, C2.numUsed(),
                                                      C3.numUsed(),
                                                      C5.numUsed(),
                                                      float(x*(x+1))/2))

    _write_series('out/probes_used_cv_recurse.out', cv_recurse)
    _write_series('out/probes_used_cv.out', cv)
    _write_series('out/probes_used_seq.out', seq)
    _write_series('out/probes_used_bound.out', bound)
    # 'sequoia' previously mapped to the imported sequoia module rather than
    # the collected probe counts.
    return {'cv_recurse': cv_recurse,
            'cv': cv,
            'bound': bound,
            'sequoia': seq}

def err_and_measurement_vs_p(q=0.1):
    """
    Quartet error and fraction of measurements used as the number of leaves
    grows from 50 to 500 in steps of 10.

    For each size one unbalanced binary tree is generated and five random
    masks are drawn with parameter q. Under each mask, cluster_vote (with
    m = log2) and sequoia are run on separate counting matrices, recording
    1 - evaluate.eval_quartets(...) and 2 * probes / (x * (x + 1)). The five
    trials are averaged per size.

    Writes one line "<x> <cv_err> <cv_meas> <seq_err> <seq_meas>" per size
    to ./out/err_and_measurement.

    Returns (cv_measurements, cv_error, seq_measurements, seq_error).
    """
    xs = range(50, 501, 10)
    cv_measurements, cv_error = [], []
    seq_measurements, seq_error = [], []
    for x in xs:
        trial_cv_m, trial_cv_e = [], []
        trial_seq_m, trial_seq_e = [], []
        T = generate.generate_unbalanced_binary(x, 1)[0]
        leaves = T.findleaves()
        D = T.paths[np.ix_(leaves, leaves)]
        pairs = x*(x+1)
        for _ in range(5):
            # Symmetrised Bernoulli mask with entries in {0, 1, 2}.
            mask = np.random.binomial(1, 1-q/2, [x, x])
            mask = mask + mask.T

            C2 = counting_matrix.CountingMatrix(D*mask)
            (F, mapping, root) = cluster_vote.cluster_vote(C2, m=lambda y: np.log2(y), active=True)
            agreement = evaluate.eval_quartets(counting_matrix.Matrix(F.paths[np.ix_(mapping, mapping)]), counting_matrix.Matrix(D), numSamples=1000)
            trial_cv_e.append(1-agreement)
            trial_cv_m.append(float(C2.numUsed()*2)/pairs)

            C2 = counting_matrix.CountingMatrix(D*mask)
            (F, mapping) = sequoia.sequoia(C2)
            agreement = evaluate.eval_quartets(counting_matrix.Matrix(F.paths[np.ix_(mapping, mapping)]), counting_matrix.Matrix(D), numSamples=1000)
            trial_seq_e.append(1-agreement)
            trial_seq_m.append(float(C2.numUsed()*2)/pairs)
        cv_measurements.append(np.mean(trial_cv_m))
        cv_error.append(np.mean(trial_cv_e))
        seq_measurements.append(np.mean(trial_seq_m))
        seq_error.append(np.mean(trial_seq_e))

        # All four series have the same length, so [-1] is the current size.
        print("%d: %f %f %f %f" % (x, cv_error[-1],
                                   cv_measurements[-1],
                                   seq_error[-1],
                                   seq_measurements[-1]))
    with open("./out/err_and_measurement", "w") as f:
        for row in zip(xs, cv_error, cv_measurements, seq_error, seq_measurements):
            f.write("%d %f %f %f %f\n" % row)
    return (cv_measurements, cv_error, seq_measurements, seq_error)

def iplane_planet_lab_exp():
    """
    Run pearl_reconstruct, cluster_vote and sequoia on the iplane distances.

    Each algorithm gets a fresh counting matrix over data.iplane_dist_mat().
    Sequoia trees are added while its probe count plus 1000 is still below
    the number cluster_vote used, then combined with sequoia.getDistMat.
    The result of evaluate.eval_absolute_error for each algorithm is written
    to out/iplane_planet_lab_{pr,cv,sequoia}_absolute.out, and the three
    probe counts to out/iplane_planet_lab_probes.out.
    """
    def _save(path, values):
        # One space-separated line, no trailing newline.
        with open(path, "w") as handle:
            handle.write(" ".join([str(v) for v in values]))

    D = data.iplane_dist_mat()

    # --- pearl reconstruct ---
    C = counting_matrix.CountingMatrix(D)
    print("pearl reconstruct")
    (T, mapping) = pearl_reconstruct.pearl_reconstruct(C)
    errs = evaluate.eval_absolute_error(D, T.paths[np.ix_(mapping, mapping)])
    _save("out/iplane_planet_lab_pr_absolute.out", errs)
    prUsed = C.numUsed()

    # --- cluster vote ---
    C = counting_matrix.CountingMatrix(D)
    print("cluster vote")
    (T, mapping, root) = cluster_vote.cluster_vote(C, m=lambda y: np.log2(D.shape[0]), active=True)
    errs = evaluate.eval_absolute_error(D, T.paths[np.ix_(mapping, mapping)])
    _save("out/iplane_planet_lab_cv_absolute.out", errs)
    cvUsed = C.numUsed()
    print(cvUsed)

    # --- sequoia, given roughly the probe budget cluster vote consumed ---
    C = counting_matrix.CountingMatrix(D)
    print("sequoia")
    SeqMats = []
    iters = 0
    while C.numUsed() + 1000 < cvUsed:
        (T, mapping) = sequoia.sequoia(C)
        SeqMats.append(T.paths[np.ix_(mapping, mapping)])
        iters += 1
        print("trees: %d, used %d" % (iters, C.numUsed()))
    SOut = sequoia.getDistMat(SeqMats)
    errs = evaluate.eval_absolute_error(D, SOut)
    _save("out/iplane_planet_lab_sequoia_absolute.out", errs)
    seqUsed = C.numUsed()

    _save("out/iplane_planet_lab_probes.out", [prUsed, cvUsed, seqUsed])


def iplane_measurement_exp():
    """
    Compare cluster_vote and sequoia on the iplane data set.

    Fifty rounds (10 outer x 5 inner) are run. In each, cluster_vote runs on
    a fresh counting matrix; sequoia then adds trees on its own counting
    matrix while its probe count plus 1000 is below what cluster_vote used.
    For both, the fraction of evaluate.eval_relative_error entries below 1.0
    is recorded alongside the probe count.

    Writes the (probes, fraction) pairs to out/iplane_measurement_cv.out and
    out/iplane_measurement_seq.out, one pair per line.

    Returns (cv_out, seq_out).
    """
    def _write_pairs(path, pairs):
        with open(path, "w") as handle:
            for pair in pairs:
                handle.write(" ".join([str(v) for v in pair]))
                handle.write("\n")

    def _fraction_below_one(values):
        hits = sum(1 for v in values if v < 1.0)
        return float(hits)/len(values)

    D = data.iplane_dist_mat()
    cv_out = []
    seq_out = []
    for i in range(1, 11):
        for _ in range(5):
            C = counting_matrix.CountingMatrix(D)
            (T, mapping, root) = cluster_vote.cluster_vote(C, m=lambda y: np.log2(D.shape[0]), active=True)
            errs = evaluate.eval_relative_error(D, T.paths[np.ix_(mapping, mapping)])
            err_thres = _fraction_below_one(errs)
            cv_out.append((C.numUsed(), err_thres))

            C2 = counting_matrix.CountingMatrix(D)
            SeqMats = []
            while C2.numUsed() + 1000 < C.numUsed():
                (F, mapping) = sequoia.sequoia(C2)
                SeqMats.append(F.paths[np.ix_(mapping, mapping)])
            SOut = sequoia.getDistMat(SeqMats)
            errs = evaluate.eval_relative_error(D, SOut)
            err_thres2 = _fraction_below_one(errs)
            seq_out.append((C2.numUsed(), err_thres2))
            print("%d cv: %d %f sequoia: %d %f" % (i, C.numUsed(), err_thres, C2.numUsed(), err_thres2))
    _write_pairs("out/iplane_measurement_cv.out", cv_out)
    _write_pairs("out/iplane_measurement_seq.out", seq_out)
    return (cv_out, seq_out)

def king_exp():
    """
    Run pearl_reconstruct, cluster_vote and sequoia on the king dataset.

    Only the first 500 rows/columns of data.load_king() are used. Each
    algorithm gets a fresh counting matrix. Sequoia trees are added while
    its probe count plus 4000 is still below the number cluster_vote used,
    then combined with sequoia.getDistMat. The result of
    evaluate.eval_absolute_error for each algorithm is written to
    out/king_{pr,cv,sequoia}_absolute.out, and the three probe counts to
    out/king_probes.out.
    """
    def _save(path, values):
        # One space-separated line, no trailing newline.
        with open(path, "w") as handle:
            handle.write(" ".join([str(v) for v in values]))

    D = data.load_king()
    D = D[np.ix_(range(500), range(500))]

    # --- pearl reconstruct ---
    C = counting_matrix.CountingMatrix(D)
    print("pearl reconstruct")
    (T, mapping) = pearl_reconstruct.pearl_reconstruct(C)
    errs = evaluate.eval_absolute_error(D, T.paths[np.ix_(mapping, mapping)])
    _save("out/king_pr_absolute.out", errs)
    prUsed = C.numUsed()

    # --- cluster vote ---
    C = counting_matrix.CountingMatrix(D)
    print("cluster vote")
    (T, mapping, root) = cluster_vote.cluster_vote(C, m = lambda y: np.log2(D.shape[0]), active=True)
    errs = evaluate.eval_absolute_error(D, T.paths[np.ix_(mapping, mapping)])
    _save("out/king_cv_absolute.out", errs)
    cvUsed = C.numUsed()
    print(cvUsed)

    # --- sequoia, given roughly the probe budget cluster vote consumed ---
    C = counting_matrix.CountingMatrix(D)
    print("sequoia")
    SeqMats = []
    iters = 0
    while C.numUsed() + 4000 < cvUsed:
        (T, mapping) = sequoia.sequoia(C)
        SeqMats.append(T.paths[np.ix_(mapping, mapping)])
        iters += 1
        print("trees: %d, used %d" % (iters, C.numUsed()))
    SOut = sequoia.getDistMat(SeqMats)
    errs = evaluate.eval_absolute_error(D, SOut)
    _save("out/king_sequoia_absolute.out", errs)
    seqUsed = C.numUsed()

    _save("out/king_probes.out", [prUsed, cvUsed, seqUsed])
