import tree, dfs_ordering, counting_matrix, union_find, graph
import numpy as np
import cPickle as pickle
import data
import copy

def graphreconstruct(SPMs, mappings, gap, reuse=False):
    treeset = []
    mset = []
    if reuse != True:
        for i in range(len(SPMs)):
            print "Constructing tree for %d" % (i)
            root = mappings[i][0]
            (T, m, root) = dfs_ordering.dfs_ordering(SPMs[i], gap, verbose=False)
            n = T.numVertices()
            T = T.adjMat.toarray()
            Tnew = np.zeros([T.shape[0]+1, T.shape[1]+1])
            Tnew[0,root+1] = 1
            Tnew[root+1,0] = 1
            Tnew[np.ix_(range(1,n+1), range(1,n+1))] = T
            T = tree.Tree(Tnew)
            m = [x+1 for x in m]
            m = [0 if j == 0 else m[j-1] for j in range(len(mappings[i]))]
            mnew = [m[mappings[i].index(j)] for j in mappings[i]]
            treeset.append(T)
            mset.append(mnew)
        pickled_data = pickle.dumps((treeset, mset))
        f = open("reuse_treesets.pkl", "w")
        f.write(pickled_data)
        f.close()
    else:
        f = open("reuse_treesets.pkl", "r")
        data = pickle.load(f)
        (treeset, mset) = data
    
    ## SUMMARY OF THE MAPPINGS WE MAINTAIN
    ## mappings[i][j] = The id (in the original graph) of the jth leaf of the ith tree
    ## mset[i][j] = the index (in the ith tree) of the jth leaf (for that tree)
    ## mset2[i][j] = the index in the resulting graph of the jth node of the ith tree
    ## m[i] = the index in the resulting graph of the node with uniqueid i

    ## now we have to construct the graph
    ## first compute an upper bound on the size of the graphs
    ssize = sum([T.adjMat.shape[0] - len(mappings) for T in treeset]) + len(mappings)

    ## this is an overestimate of the number of nodes in our graph
    G = np.matrix(np.zeros((ssize, ssize)))

    ## the leaf set will be the first several nodes
    m = [-1 for i in range(ssize)]
    next = 0
    for i in range(len(mappings)):
        m[mappings[i][0]] = next
        next += 1

    ## now we have to reverse all the mappings in mset
    mset2 = []
    for i in range(len(mset)):
        mset2.append([m[mappings[i][mset[i].index(j)]] if j in mset[i] else -1 for j in range(ssize)])

        
    ## the next internal node will be identified by nextuid
    duplicates = union_find.UnionFind(range(next))
    nextuid = len(mappings)
    for i in range(len(mappings)):
        print "Working on node %d" % (i)
        for j in range(len(mappings)):
            if j != i:
                ## Look at the paths from one to the other
                iname = mappings[i][0]
                jname = mappings[j][0]
                iini = mset[i][0]
                jinj = mset[j][0]
                jini = mset[i][mappings[i].index(jname)]
                iinj = mset[j][mappings[j].index(iname)]
                (p1,d1) = treeset[i].path(iini,jini)
                (p2,d2) = treeset[j].path(iinj,jinj)
                diff = 0
                prev = m[iname]
                ipathsum = 0
                jpathsum = 0
                iind = 1
                jind = 1
                while iind < len(p1) and jind < len(p2):
#                     print "%d %d" % (iind, jind)
                    nextini = duplicates.find(mset2[i][p1[iind]]).data
                    nextinj = duplicates.find(mset2[j][p2[jind]]).data
#                     print "%f %f %f %f" % (ipathsum+d1[iind], jpathsum+d2[jind], np.abs(ipathsum+d1[iind] - jpathsum-d2[jind]), gap)
#                     print np.abs(ipathsum+d1[iind] - jpathsum-d2[jind]) <= gap
                    if np.abs(ipathsum+d1[iind] - jpathsum-d2[jind]) <= gap:
                        ## then these two nodes should be the same
#                         print "here"
                        if nextini != -1 and nextinj != -1 and nextini != nextinj:
#                             print [mset2[i][p1[iiind]] for iiind in range(len(p1))]
#                             print [mset2[j][p2[jjjnd]] for jjjnd in range(len(p2))]
                            ## Then we have to merge these two because they are supposed to be the same node but they have different names
#                             print "Performing union %d %d" % (i, j)
                            duplicates.union(nextini, nextinj)
#                             print G[nextini,range(len(mappings))]
#                             print G[nextinj,range(len(mappings))]
                            ## we also have to merge the rows in the matrix
                            for tmpind in range(G.shape[0]):
                                G[nextinj,tmpind] = max(G[nextini,tmpind], G[nextinj,tmpind])
                                G[tmpind,nextinj] = max(G[tmpind,nextini], G[tmpind,nextinj])
                            prev = nextinj
                        elif nextini != -1:
                            ## here make sure they have the same name, but they have already been assigned a name
                            mset2[j][p2[jind]] = nextini
                            G[prev, nextini] = min(d1[iind],d2[jind])
                            G[nextini, prev] = min(d1[iind],d2[jind])
#                             print "#1 Adding edge between %d and %d" % (prev, nextini)
                            prev = nextini
                        elif nextinj != -1:
                            ## here make sure they have the same name, but they have already been assigned a name
                            mset2[i][p1[iind]] = nextinj
                            G[prev, nextinj] = min(d1[iind],d2[jind])
                            G[nextinj, prev] = min(d1[iind],d2[jind])
#                             print "#2 Adding edge between %d and %d" % (prev, nextinj)
                            prev = nextinj
                        else:
                            ## Neither of them have a name yet, assign them to a new name
                            mset2[i][p1[iind]] = nextuid
                            mset2[j][p2[jind]] = nextuid
                            duplicates.add_object(nextuid)
                            G[prev, nextuid] = min(d1[iind],d2[jind])
                            G[nextuid, prev] = min(d1[iind],d2[jind])
#                             print "#3 Adding edge between %d and %d" % (prev, nextuid)
                            prev = nextuid
                            nextuid += 1
                        ipathsum += d1[iind]
                        jpathsum += d2[jind]
                        iind += 1
                        jind += 1
                    elif ipathsum+d1[iind] > jpathsum+d2[jind] + gap:
                        ## Then these two nodes shouldn't be the same, but we need to make sure that prev is connected to d2[jind]
                        if nextinj != -1:
                            G[prev, nextinj] = d2[jind] + jpathsum - max(ipathsum, jpathsum)
                            G[nextinj, prev] = d2[jind] + jpathsum - max(ipathsum, jpathsum)
#                             print "#4 Adding edge between %d and %d" % (prev, nextinj)
                            prev = nextinj
                        else:
                            mset2[j][p2[jind]] = nextuid
                            duplicates.add_object(nextuid)
                            G[prev, nextuid] = d2[jind] + jpathsum - max(ipathsum, jpathsum)
                            G[nextuid, prev] = d2[jind] + jpathsum - max(ipathsum, jpathsum)
#                             print "#5 Adding edge between %d and %d" % (prev, nextuid)
                            prev = nextuid
                            nextuid += 1
                        jpathsum += d2[jind]
                        jind += 1
                    elif ipathsum+d1[iind] < jpathsum+d2[jind] - gap:
                        ## then these two nodes shouldn't be the same, but we need to make sure that prev is connected to d1[iind]
                        if nextini != -1:
                            G[prev, nextini] = d1[iind] + ipathsum - max(ipathsum, jpathsum)
                            G[nextini, prev] = d1[iind] + ipathsum - max(ipathsum, jpathsum)
#                             print "#6 Adding edge between %d and %d" % (prev, nextini)
                            prev = nextini
                        else:
                            mset2[i][p1[iind]] = nextuid
                            duplicates.add_object(nextuid)
                            G[prev, nextuid] = d1[iind] + ipathsum - max(ipathsum, jpathsum)
                            G[nextuid, prev] = d1[iind] + ipathsum - max(ipathsum, jpathsum)
#                             print "#7 Adding edge between %d and %d" % (prev, nextuid)
                            prev = nextuid
                            nextuid += 1
                        ipathsum += d1[iind]
                        iind += 1
    active = duplicates.get_active()
    return (G[np.ix_(active, active)], m[0:nextuid])

def construct_inputs(G, leaves):
    """Build one counting matrix and one leaf mapping per entry of leaves.

    For each leaf, G.sharedPathMatrix(leaf) supplies a matrix and a vertex
    ordering.  The matrix is permuted by that ordering and then cut down to
    the positions whose vertex is a member of leaves.

    Returns:
        (SPMs, mappings): SPMs[k] is the counting_matrix.CountingMatrix of
        the reduced matrix for leaves[k]; mappings[k] has len(leaves)
        entries, the leaf itself followed by the first len(leaves)-1 kept
        vertices in ordering order.
    """
    SPMs = []
    mappings = []
    for leaf in leaves:
        shared, order = G.sharedPathMatrix(leaf)
        shared = shared[np.ix_(order, order)]

        ## positions (within the ordering) and ids of the vertices that are leaves
        kept = [(pos, vertex) for pos, vertex in enumerate(order) if vertex in leaves]
        positions = [pos for pos, _ in kept]
        kept_ids = [vertex for _, vertex in kept]
        shared = shared[np.ix_(positions, positions)]

        ## the leaf itself always comes first in its own mapping
        mapping = [leaf] + [kept_ids[k] for k in range(len(leaves) - 1)]
        SPMs.append(counting_matrix.CountingMatrix(shared))
        mappings.append(mapping)
    return (SPMs, mappings)


def are_same_node(duplicates, i, j):
    """Tell whether i and j denote the same node.

    Args:
        duplicates: iterable of containers (anything supporting ``in``),
            each grouping identifiers that name one node.
        i, j: the two identifiers to compare.

    Returns:
        True when i equals j or when some container in duplicates holds
        both of them, False otherwise.
    """
    if i == j:
        return True
    return any(i in group and j in group for group in duplicates)

def reconstruct_orbis_graph(gap=0.1, sigma = 0, reuse=False):
    M = data.load_orbis_topo("./test.graph")
    G = graph.Graph(M)
    print "Constructing Inputs"
    (SPMs, mappings, paths) = make_spms(G, sigma)
    print "Recovering"
    (M2, m2) = graphreconstruct(SPMs, mappings, 0.1, reuse=reuse)
    print "Done"
    G2 = graph.Graph(M2)
    return (G2.paths[np.ix_(range(len(G.findleaves())), range(len(G.findleaves())))], G.paths[np.ix_(G.findleaves(), G.findleaves())], G2, G, (G2.paths[np.ix_(range(len(G.findleaves())), range(len(G.findleaves())))] == G.paths[np.ix_(G.findleaves(), G.findleaves())]).all())

def error_vs_sigma(sigmas):
    """Return one reconstruction error per noise level in sigmas.

    Each level s runs reconstruct_orbis_graph(gap=s*10, sigma=s).  The error
    is the sum of absolute differences between the first two returned
    matrices, divided by the square of the first matrix's row count.
    """
    errors = []
    for noise in sigmas:
        result = reconstruct_orbis_graph(gap=noise*10, sigma=noise)
        recovered = result[0]
        reference = result[1]
        total = np.sum(np.abs(recovered - reference))
        errors.append(1.0/(recovered.shape[0])**2 * total)
    return errors

    
def make_spms(G, sigma=0):
    """Build a shared-path counting matrix for every leaf of G.

    For each leaf a parent table is grown outwards over the distance level
    sets of G.paths, and a path between every pair of leaves is read off it.
    Entry (a, b) of the matrix for a leaf is the number of vertices, not
    counting the leaf itself, shared by the beginnings of its paths to
    leaves a and b.

    Args:
        G: graph object providing findleaves(), neighbors(v) and a
            distance matrix G.paths.
        sigma: when non-zero, Gaussian noise with this standard deviation
            is added to every matrix.

    Returns:
        (SPMs, mappings, paths): SPMs[i] is the
        counting_matrix.CountingMatrix for leaves[i]; mappings[i] is
        leaves[i] followed by all other leaves in order; paths maps
        (from_leaf, to_leaf) to the list of vertices on the path.

    Raises:
        Exception: if some vertex is at infinite distance from a leaf.
    """
    def dfsDesc(start, level):
        ## Record start as the parent of each neighbour exactly one level
        ## further from the root, and recurse until the last level.
        if level == maxdist:
            return
        cands = [x for x in G.neighbors(start) if x in levelsets[level+1]]
        for c in cands:
            parents[c] = start
            dfsDesc(c, level+1)

    leaves = G.findleaves()
    paths = {}
    for root in leaves:
        ## Check for unreachable vertices before converting to int:
        ## int(inf) raises OverflowError, which would mask this error.
        farthest = np.max(G.paths[root,:])
        if np.isinf(farthest):
            raise Exception("Graph is not connected")
        maxdist = int(farthest)
        levelsets = [[x for x in range(G.paths.shape[0]) if G.paths[root,x] == j] for j in range(maxdist+1)]
        parents = {root: None}
        dfsDesc(root, 0)
        for i in leaves:
            if (i,root) not in paths:
                ## climb from leaf i up to the root through the parent table
                p = [i]
                while parents[p[-1]] is not None:
                    p.append(parents[p[-1]])
                paths[(i, root)] = p
                paths[(root,i)] = p[::-1]

    SPMs = []
    mappings = []
    for i, source in enumerate(leaves):
        ## rows and columns range over every leaf except the source itself
        others = [leaf for k, leaf in enumerate(leaves) if k != i]
        M = np.zeros((len(leaves)-1, len(leaves)-1))
        for j, target in enumerate(others):
            M[j,:] = [len(longest_prefix(paths[(source, target)], paths[(source, other)])) for other in others]
        if sigma != 0:
            R = np.matrix(np.random.normal(0, sigma, M.shape))
            M = M + R
        SPMs.append(counting_matrix.CountingMatrix(M))
        mappings.append([source] + others)
    return (SPMs, mappings, paths)

def longest_prefix(l1, l2):
    """Return the common leading run of l1 and l2 without its first element.

    The two sequences are compared position by position until they differ
    or one of them ends; the first matching element is then dropped.
    """
    common = []
    for left, right in zip(l1, l2):
        if not left == right:
            break
        common.append(left)
    ## drop the first shared element
    return common[1:]
