import numpy as np
import tree
import subtree
import random

def _branch_length(d_ab, d_ac, d_bc):
    """Return the distance from leaf a to the median point of leaves a, b, c.

    This is the three-point (Gromov product) formula for tree metrics:
    (d(a,b) + d(a,c) - d(b,c)) / 2. Dividing by 2.0 keeps integer distances
    from being floor-divided under Python 2.
    """
    return (d_ab + d_ac - d_bc) / 2.0


def pearl_reconstruct(D, rootSelect = True, minEdgeLength=1.0):
    """
    Reconstruct a tree from a distance matrix by inserting leaves one at a time.

    The first three leaves of D form a star tree around a new internal vertex.
    Each remaining leaf xi is located by repeatedly shrinking a
    subtree.Subtree of candidate edges, using randomly chosen witness leaves
    xj and xk on either side of the current edge. Once at most one edge
    remains, xi is attached there: directly to one of that edge's endpoints,
    or to a new vertex inserted on the edge.

    Parameters
    ----------
    D : distance-matrix object
        Must expose ``D.data.shape[0]`` (the number of leaves n) and
        ``D.get(i, j)`` (distance between leaves i and j). The distances are
        assumed symmetric -- NOTE(review): confirm for all callers.
    rootSelect : bool
        If true, each search step takes its (root, parent) edge from
        ``Tc.getRoot()``. Otherwise the search continues from the
        (root, parent) pair left by the previous step.
    minEdgeLength : float
        Tolerance. Estimated positions within ``minEdgeLength / 2`` of an
        existing vertex are treated as landing on that vertex.

    Returns
    -------
    (T, mapping)
        T is the reconstructed tree.Tree. mapping[i] is the vertex index in T
        of leaf i of D.

    Raises
    ------
    ValueError
        If D has fewer than three leaves (the initial star needs three).
    """
    n = D.data.shape[0]
    if n < 3:
        raise ValueError("pearl_reconstruct needs at least 3 leaves, got %d" % n)
    tol = minEdgeLength / 2.0

    ## Choose the first three leaves of D arbitrarily and form a star tree
    ## whose centre is vertex 3.
    M = np.zeros([4, 4])
    for leaf, (a, b) in enumerate([(1, 2), (0, 2), (0, 1)]):
        M[leaf, 3] = M[3, leaf] = _branch_length(D.get(leaf, a), D.get(leaf, b),
                                                 D.get(a, b))

    T = tree.Tree(M)
    mapping = [0, 1, 2]

    ## Go through each remaining leaf and add it to the tree.
    for xi in range(3, n):
        Tc = subtree.Subtree(T)
        parent = 0
        root = Tc.neighbors(parent)[0]
        while len(Tc.existent) > 2:
            if rootSelect:
                [root, parent] = Tc.getRoot()
            children = np.setdiff1d(Tc.neighbors(root), [parent])
            xj = random.choice(T.findLeavesBelow(children[0], root))
            xk = random.choice(T.findLeavesBelow(parent, root))
            xjm = mapping.index(xj)  ## row/column of xj in D
            xkm = mapping.index(xk)  ## row/column of xk in D

            ## Estimated distances from xj / xk to where xi branches off the
            ## xj-xk path.
            dja = _branch_length(D.get(xi, xjm), D.get(xjm, xkm), D.get(xi, xkm))
            dka = _branch_length(D.get(xi, xkm), D.get(xjm, xkm), D.get(xi, xjm))

            if dja < T.paths[xj, root] - tol:
                ## The branch point lies on xj's side of root: descend into
                ## children[0].
                parent, root = root, children[0]
                Tc.deleteUnreachable(parent, root)
            elif dka < T.paths[xk, root] - tol:
                ## The branch point lies on xk's side of root: flip the edge
                ## towards parent.
                Tc.deleteUnreachable(root, parent)
                root, parent = parent, root
            else:
                Tc.deleteSubtrees(root, [parent, children[0]])
                other_child = children[1] if len(children) >= 2 else children[0]
                parent = root
                root = other_child
                if not rootSelect and len(Tc.existent) > 1:
                    if len(Tc.neighbors(parent)) > 1:
                        root = parent
                        parent = other_child

        assert Tc.numEdges() <= 1
        c1 = T.findLeavesBelow(root, parent)
        c2 = T.findLeavesBelow(parent, root)
        xj = random.choice(c1)   ## xj is below root (i.e. parent -> root -> xj)
        xk = random.choice(c2)   ## xk is below parent (i.e. root -> parent -> xk)
        xjm = mapping.index(xj)  ## for indexing the distance matrix
        xkm = mapping.index(xk)  ## for indexing the distance matrix
        dja = _branch_length(D.get(xi, xjm), D.get(xjm, xkm), D.get(xi, xkm))
        dka = _branch_length(D.get(xi, xkm), D.get(xjm, xkm), D.get(xi, xjm))
        ## Length of the pendant edge from the branch point to xi.
        pendant = _branch_length(D.get(xi, xjm), D.get(xi, xkm), D.get(xjm, xkm))

        if abs(dja - T.paths[xj, root]) < tol and abs(dka - T.paths[xk, root]) < tol:
            anchor = root
        elif abs(dja - T.paths[xj, parent]) < tol and abs(dka - T.paths[xk, parent]) < tol:
            anchor = parent
        else:
            ## The branch point is strictly inside the (root, parent) edge:
            ## split that edge at distance dja - d(xj, root) from root.
            anchor = T.addVertexBetween(root, parent, dja - T.paths[xj, root])
        mapping.append(T.addVertex({anchor: pendant}))

    return (T, mapping)
