import numpy as np
import tree
import random

def sequoia(D):
    """
    This is the Sequoia algorithm from the "treeness" paper. It resembles
    Pearl's algorithm in that it looks for shared common ancestors, but it is
    substantially different. This algorithm operates on the distance matrix.

    Outline: the points are randomly relabelled, an "anchor tree" over the
    relabelled points is grown greedily (each new point walks from point 0
    towards the neighbour with the largest Gromov product), and the anchor
    tree is then converted into a weighted topology tree.

    Parameters
    ----------
    D : distance-matrix object
        Must provide ``shape``, ``get(i, j)`` and ``submatrix(order)``
        (presumably the project's distance-matrix wrapper -- confirm).

    Returns
    -------
    (T, m3) : tuple
        ``T`` is the ``tree.Tree`` produced by the conversion step and
        ``m3[x]`` is the vertex of ``T`` assigned to original point ``x``.
    """

    def construct_anchor_tree(D):
        """Grow the anchor tree by inserting points 1..n-1 one at a time."""
        root = 0
        T = tree.Tree(np.array([[0]]))

        def insert_into_anchor_tree(i):
            # Walk down from the root. At each step consider the neighbours of
            # the current vertex (minus the one we came from) plus the current
            # vertex itself, and move to the one maximising the Gromov product
            # (x | i)_root. Stop when the current vertex is among the best.
            # Written as a loop rather than recursion so that deep anchor
            # trees cannot exceed Python's recursion limit.
            curr, parent = root, None
            while True:
                neighbors = list(T.neighbors(curr))
                if parent is not None:
                    neighbors.remove(parent)
                neighbors.append(curr)
                # d(x, root) - d(x, i) differs from 2 * (x | i)_root only by
                # d(i, root), which is constant in x, so maximising it selects
                # the same candidates.
                gprods = [D.get(x, root) - D.get(x, i) for x in neighbors]
                best_gprod = np.max(gprods)  # hoisted: computed once per step
                cands = [x for x, g in zip(neighbors, gprods) if g == best_gprod]
                if curr in cands:
                    # Attach i to curr with a unit-weight edge. The new vertex
                    # presumably receives index i because points are inserted
                    # in order; convert_anchor_tree relies on that (confirm in
                    # tree.Tree.addVertex).
                    T.addVertex({curr: 1})
                    return
                parent, curr = curr, random.choice(cands)

        for i in range(1, D.shape[0]):
            insert_into_anchor_tree(i)
        return T

    def getPath(T, root, j):
        """
        Find the path from ``root`` to ``j`` in ``T``; returns a list of
        vertices starting at ``root``.

        At each step the not-yet-used neighbour closest to ``j`` (according to
        ``T.paths``) is appended. If a vertex with no unused neighbour is
        reached before ``j``, the partial path is returned.
        """
        path = [root]
        while path[-1] != j:
            neighbors = np.setdiff1d(T.neighbors(path[-1]), path)
            if len(neighbors) == 0:
                return path
            nxt = neighbors[np.argmin([T.paths[j, x] for x in neighbors])]
            path.append(nxt)
        return path

    def convert_anchor_tree(D, A):
        """
        Take an anchor tree ``A`` and convert it into a weighted topology tree.

        Returns ``(T, mapping)`` where ``mapping[k]`` is the vertex of ``T``
        created for point ``k`` (``mapping[0]`` is the initial vertex 0).
        """
        T = tree.Tree(np.array([[0]]))
        n = D.shape[0]
        root = 0
        mapping = [None] * n
        mapping[root] = 0
        visited = {root}
        stack = [root]  # depth-first traversal of the anchor tree
        while stack:
            curr = stack.pop()
            neighbors = set(A.neighbors(curr))
            # Work on the elements themselves. numpy's set routines wrap a
            # Python set as one opaque object instead of iterating it.
            # Sorting keeps the sorted-unique order np.setdiff1d would give.
            stack.extend(sorted(neighbors - visited))
            parents = sorted(neighbors & visited)
            visited.add(curr)
            if not parents:
                # No already-placed neighbour to attach to (the root's case).
                continue
            assert len(parents) == 1
            parent = parents[0]
            # Gromov product (curr | parent)_root: the distance from the root
            # at which the root->curr and root->parent paths diverge.
            gromov = float(D.get(curr, root) + D.get(parent, root) - D.get(parent, curr)) / 2
            # Length of the new pendant edge from the branch point to curr.
            leaf_len = np.abs(D.get(root, curr) - gromov)
            root_v = mapping[root]
            path = getPath(T, root_v, mapping[parent])
            if len(path) == 1:
                idx = T.addVertex({path[0]: leaf_len})
            else:
                # Advance along root -> parent until reaching the first vertex
                # at distance >= gromov from the root (or the end of the path).
                ptr = 0
                while T.paths[path[ptr], root_v] < gromov:
                    if ptr == len(path) - 1:
                        break
                    ptr += 1
                # ptr == 0 means the loop never advanced: gromov is not above
                # T.paths[root_v, root_v] (presumably 0), which can happen if D
                # violates the triangle inequality. Attach at the root in that
                # case; otherwise path[ptr - 1] would wrap around to path[-1].
                if ptr == 0 or T.paths[path[ptr], root_v] <= gromov:
                    # Branch point coincides with (or lies past) path[ptr]:
                    # hang curr directly off that vertex.
                    idx = T.addVertex({path[ptr]: leaf_len})
                else:
                    # Branch point lies strictly inside edge
                    # (path[ptr-1], path[ptr]): split the edge, then hang curr
                    # off the new vertex.
                    idx1 = T.addVertexBetween(path[ptr], path[ptr - 1], T.paths[path[ptr], root_v] - gromov)
                    idx = T.addVertex({idx1: leaf_len})
            mapping[curr] = idx

        return (T, mapping)

    n = D.shape[0]
    # Random relabelling of the points. list() is required because shuffle
    # needs a mutable sequence, and a Python 3 range object is immutable.
    mapping = list(range(n))
    random.shuffle(mapping)
    # Presumably row k of C is original point mapping[k] (confirm submatrix).
    C = D.submatrix(mapping)
    A = construct_anchor_tree(C)
    (T, m2) = convert_anchor_tree(C, A)
    # Invert the permutation once rather than calling mapping.index() for
    # every point (O(n^2)): position[x] is the k with mapping[k] == x.
    position = [0] * n
    for k, x in enumerate(mapping):
        position[x] = k
    m3 = [m2[position[x]] for x in range(n)]
    return (T, m3)


def getDistMat(S):
    """
    Combine several matrices into one by taking the element-wise median.

    Parameters
    ----------
    S : sequence of 2-D numpy arrays
        Matrices of identical, square shape.

    Returns
    -------
    numpy.ndarray (float64), shape ``S[0].shape``
        Symmetric matrix whose entry (i, j) with i <= j is the median of
        ``x[i, j]`` over ``x`` in ``S``. The lower triangle mirrors the upper
        one, so only the upper triangle (including the diagonal) of the
        inputs is used.

    Raises
    ------
    IndexError
        If ``S`` is empty.
    """
    shape = S[0].shape  # raises IndexError when S is empty
    n = shape[0]
    # One vectorised median over the stacked inputs instead of a Python
    # double loop calling np.median once per entry.
    med = np.median(np.stack([np.asarray(x) for x in S]), axis=0)
    upper = np.triu(med[:n, :n])
    O = np.zeros(shape)  # float64 output regardless of input dtype
    O[:n, :n] = upper + np.triu(upper, 1).T
    return O
