import tree, forest
import numpy as np
import random, itertools, copy
from util import *
import time
import hierarchical
from collections import Counter

def compute_S(D, cands, nodes):
    r"""
    Helper routine to compute similarity scores for a set cands of nodes
    (subset of nodes). S[i,j] = max_k | \{k': D[i,k] - D[j,k] = D[i,k'] - D[j,k']\}|

    D -- distance matrix on leaves; must expose ``data`` (``data.shape[0]`` is
    the total number of leaves) and ``get(a, b)``

    cands -- subset of nodes for which to compute similarity scores

    nodes -- if nodes is a subset of all the nodes in the tree, we'll choose
    k' from others. This gives stronger guarantees on the ability to find a
    partition

    returns a symmetric len(cands) x len(cands) numpy array of counts
    """
    n = len(cands)
    I = np.zeros((n, n))
    total = D.data.shape[0]

    if len(nodes) < total:
        # Reference leaves are sampled from outside this subtree. random.sample
        # needs a sequence (sampling a set is a TypeError on Python 3.11+), so
        # the complement is sorted into a list first.
        outside = sorted(set(range(total)).difference(set(nodes)))
        others = random.sample(outside, min(n, total - len(nodes)))
        for i in range(n):
            for j in range(i, n):
                # Count the (outside leaf o, candidate c) pairs for which
                # D[ci,o] - D[cj,o] == D[ci,c] - D[cj,c], tested in additive form.
                I[i, j] = sum(1 for o, c in zip(others, cands)
                              if D.get(cands[i], o) + D.get(cands[j], c)
                              == D.get(cands[i], c) + D.get(cands[j], o))
                I[j, i] = I[i, j]
        return I

    # Every leaf lies under this node: S[i,j] is the size of the largest set of
    # candidates k sharing the same (int-truncated) difference D[i,k] - D[j,k].
    for i in range(n):
        for j in range(i, n):
            scores = [int(D.get(cands[i], cands[k]) - D.get(cands[j], cands[k]))
                      for k in range(n)]
            I[i, j] = Counter(scores).most_common(1)[0][1]
            I[j, i] = I[i, j]
    return I

# Module-level counter used by cluster_vote (when measure=True) to number the
# EPS figures written through D.show; it is incremented after each figure.
figcount = 0

def cluster_vote(D, m = lambda x: np.log2(x)**2, active=False, rooted=False, verbose=False, perf=False, measure=False, recurse=True):
    """
    Clustering and voting recursive hierarchical algorithm for robust active tree
    reconstruction. This algorithm operates on the distance matrix.

    Arguments:
    D -- distance matrix on leaves of a tree. Must expose ``data`` (its
         ``shape[0]`` is the number of leaves) and ``get``; ``numUsed`` and
         ``show`` are also called when measure is True
    m -- function of the current subtree size n giving how many candidates
         to subsample when active is True (default: log2(n)**2)
    active -- if True: subsample nodes + cluster + vote. Otherwise just cluster
    rooted -- if False looks for at least 3 groups in the first split,
         otherwise just 2
    verbose -- print out debugging information
    perf -- print the time (measured with time.clock) spent in each phase
    measure -- print how many entries of D have been used and write an EPS
         figure of D after each split (file name numbered by figcount)
    recurse -- if False, stop after the first split and return
         (subtrees, None, None), where subtrees is a list of leaf-index lists

    Returns a 3-tuple (tree, mapping, root). When recurse is True,
    allpairsshortestpaths() is run on the tree and the mapping is passed
    through translate_mapping before being returned.
    """
    realN = D.data.shape[0]  # NOTE(review): realN is not used below

    def find_partition(I, cands, nodes, target):
        """
        Partition cands into groups based on similarity scores in I, using
        hierarchical.single_linkage with k=target. If that yields fewer than
        target active groups, fall back to a star: every node in nodes becomes
        its own leaf attached directly to a fresh root, with each edge length
        estimated by computeEdgeLength.

        I -- similarity scores (a |cands| x |cands| matrix)

        cands -- list of nodes to partition

        nodes -- all of the nodes in this subtree

        target -- require at least target sets in the partition

        returns a forest.Forest
        """
        F = hierarchical.single_linkage(I, cands, k=target)
        if len(F.active()) < target:
            ## then we couldn't find a good split so just make a star
            n = len(nodes)  # NOTE(review): n is not used below
            F = forest.Forest()
            rootidx = F.addTree(tree.Tree(np.array([0])), 0, [None])
            for i in nodes:
                Tsub = tree.Tree(np.array([0]))
                idx = F.addTree(Tsub, 0, [i])
                ## hang leaf i off the star root; every other node acts as a
                ## single-leaf sibling subtree for the edge-length estimate
                rootidx = F.merge(rootidx, idx, 0, computeEdgeLength(D, Tsub, 0, [i], [[j] for j in nodes if j != i], nodes))
        return F
            
    def cluster_vote_sub(nodes, active=False, rooted=True, verbose=False):
        """
        Recursive subroutine for clustering and voting algorithm. Algorithm is
        broken into several components:
        1. (Optional) Sample candidates to cluster
        2. Compute pairwise similarity scores between candidates
        3. Cluster candidates based on similarity scores
        4. (Optional) Vote to place remaining nodes in clusters
        5. Recurse on clusters
        6. Compute edge distances and construct tree.

        nodes -- list of leaf indices (rows of D) in the current subtree
        active, rooted, verbose -- as for cluster_vote. rooted defaults to
            True here, and the recursive calls below rely on that default.
        m, perf, measure and recurse are read from the enclosing cluster_vote.

        Returns (tree, mapping, root) for the subtree over nodes, or
        (subtrees, None, None) after the first split when recurse is False.
        NOTE(review): when find_partition returns a single active tree, this
        returns F.active()[0] directly; confirm that value has the same
        (tree, mapping, root) shape that callers unpack.
        """
        global figcount
        n = len(nodes)
        mapping = nodes
        if n <= 1:
            ## base case: a single-node tree
            return (tree.Tree(np.array([0])), nodes, 0)

        ## NOTE(review): time.clock was removed in Python 3.8; this module is
        ## Python 2 code (it uses print statements)
        start = time.clock()
        ## 1. (Optional) Subsampling
        cands = nodes
        ## only subsample when the subtree is larger than ceil(m(n))
        if active and n > np.ceil(m(n)):
            cands = random.sample(nodes, int(np.ceil(m(n))))
        stop = time.clock()
        if perf: print "Subsampling: %f" % (stop-start)

        start = time.clock()
        ## 2. Compute pairwise similarity
        if verbose: print "computing I"
        I = compute_S(D, cands, nodes)
        stop = time.clock()
        if perf: print "Similarity: %f" % (stop-start)

        start = time.clock()
        ## 3. Find a partition
        if verbose: print "finding partition"
        ## cluster the candidates with single linkage (see find_partition);
        ## require at least 2 groups when rooted, at least 3 otherwise
        F = find_partition(I, cands, nodes, 2 if rooted else 3)
        ## a single active tree presumably means find_partition built the star
        ## fallback; it is returned as-is (the perf timing below is skipped)
        if len(F.active()) == 1:
            return F.active()[0]
        stop = time.clock()
        if perf: print "Partitioning: %f" % (stop - start)

        start = time.clock()
        if measure: print "Used: %d out of %d" % (D.numUsed(), (n*(n+1))/2)
        ## 4. (Optional) Vote
        if verbose: print "voting"
        if not rooted: assert len([f for f in F.forest if f != None]) >= 3
        ## collect the leaves of each non-empty tree in the forest as a group
        inds = [i for i in range(len(F.forest)) if F.forest[i] != None]
        subtrees = []
        for i in inds:
            subtrees.append([x for x in F.mappings[i] if x != None])
        ## when rooted, every leaf outside this subtree forms one extra group
        ## that nodes of this subtree may be voted into
        if rooted: subtrees.append([x for x in range(D.data.shape[0]) if x not in nodes])
        subtrees = kwayvote(D, subtrees, nodes)
        if rooted: 
            st = subtrees.pop()
            ## for the nodes that were placed into the wrong group we don't
            ## really know what to do so just place them randomly. TODO: a better
            ## solution would be to place them in the group with second most
            ## votes
            for i in [x for x in st if x in nodes]:
                subtrees[random.sample(range(len(subtrees)), 1)[0]].append(i)
        stop = time.clock()
        if perf: print "Voting: %f" % (stop - start)

        ## every node of this subtree must have ended up in some group
        assert len([x for x in nodes if x not in flatten(subtrees)]) == 0

        start = time.clock()
        if measure: print "Used: %d out of %d" % (D.numUsed(), (n*(n+1))/2)
        if measure:
            D.show(file="./%d.eps" % (figcount))
            figcount += 1
        if not recurse:
            return (subtrees, None, None)
        ## 5. Recurse + compute edge weights
        if verbose: print "recursing"
        F = forest.Forest()
        rootidx = F.addTree(tree.Tree(np.array([0])), 0, [None])
        for i in range(len(subtrees)):
            ## recursive calls use the default rooted=True
            (Tsub, Msub, Rsub) = cluster_vote_sub(copy.copy(subtrees[i]), active=active)
            idx = F.addTree(Tsub, Rsub, Msub)
            ## 6. attach the subtree under the new root; its edge length is
            ## estimated against the other groups of this split
            rootidx = F.merge(rootidx, idx, 0, 
                              computeEdgeLength(D, Tsub, Rsub, Msub, [subtrees[j] for j in range(len(subtrees)) if j != i], nodes), 
                              computePaths = False)
        stop = time.clock()
        if perf: print "Recursing: %f" % (stop - start)
        return (F.forest[rootidx], F.mappings[rootidx], F.roots[rootidx])
    
    (F, mapping, root) = cluster_vote_sub(range(D.data.shape[0]), active=active, rooted=rooted, verbose=verbose)
    if recurse:
        ## path computation was skipped during merges (computePaths=False),
        ## so it is done once on the final tree
        F.allpairsshortestpaths()
        mapping = translate_mapping(mapping)
    return (F, mapping, root)


def flatten(lols):
    """
    Flatten a list of lists by exactly one level.

    Returns a lazy, single-use iterator over the inner elements in order;
    no intermediate list is built.
    """
    return (item for inner in lols for item in inner)

def kwayvote(D, groups, nodes, max_stalled_rounds=100):
    """
    Voting algorithm for kway. This is a subroutine of the cluster_voting
    algorithm.

    Every node in ``nodes`` that is not already in one of ``groups`` is placed
    by repeated rounds of voting. Each round samples three candidate groups
    and draws one random representative from each, as many times as the
    smallest group has members. It then tallies the results of the helpers
    ``is_star`` / ``pairs_with``, which are not defined in this file
    (presumably they come from util).
    - If a pairing wins, the node joins the corresponding chosen group.
    - If the star outcome wins, the chosen groups are discarded from
      consideration, but never below three candidates.

    D -- distance matrix on leaves
    groups -- clustering of a subset of the nodes in "nodes"; extended in place
    nodes -- all of the nodes we want to cluster
    max_stalled_rounds -- how many star outcomes to tolerate once only three
        candidate groups remain (so nothing more can be discarded). After that
        the node joins the chosen group with the most pairing votes in the last
        round instead of looping forever.

    With a single group, unplaced nodes go into it. With exactly two groups,
    random.sample(cands, 3) raises ValueError as soon as a node needs placing.

    returns groups (the same list object) with the placed nodes appended
    """
    newgroups = [[] for _ in groups]
    # Build the membership set once instead of re-flattening for every node.
    already = {x for group in groups for x in group}
    toplace = [x for x in nodes if x not in already]
    # Number of votes per round is limited by the smallest group (loop invariant).
    tvotes = min([len(g) for g in groups]) if groups else 0
    for i in toplace:
        cands = list(range(len(groups)))
        stalled = 0
        while len(cands) > 1:
            chosen = random.sample(cands, 3)
            votes = [0, 0, 0, 0]
            for _ in range(tvotes):
                voters = [random.sample(groups[c], 1) for c in chosen]
                if is_star(D, voters[0], voters[1], voters[2], i): votes[3] += 1
                elif pairs_with(D, voters[0], voters[1], voters[2], i): votes[2] += 1
                elif pairs_with(D, voters[0], voters[2], voters[1], i): votes[1] += 1
                elif pairs_with(D, voters[1], voters[2], voters[0], i): votes[0] += 1
            winner = np.argmax(votes)
            if winner != 3:
                newgroups[chosen[winner]].append(i)
                break
            if len(cands) > 3:
                # Star: discard the sampled groups, keeping at least three
                # candidates so random.sample(cands, 3) stays valid.
                for c in chosen:
                    if len(cands) > 3:
                        cands.remove(c)
            else:
                # Nothing left to discard; retry a bounded number of times,
                # then fall back to the best pairing from this round.
                stalled += 1
                if stalled >= max_stalled_rounds:
                    newgroups[chosen[int(np.argmax(votes[:3]))]].append(i)
                    break
        if len(cands) == 1:
            # Only one group exists, so the node can only belong there.
            newgroups[cands[0]].append(i)
    for group, extra in zip(groups, newgroups):
        group.extend(extra)
    return groups


def computeEdgeLength(D, Tsub, root, mapping, subtrees, nodes):
    """
    Compute the length of the edge connecting Tsub to the main tree. There are
    two ways to do this

    1. if Tsub is a leaf, then do triplet tests: take that node and take two
    nodes from different subtrees of the main tree and compute the distance to
    the shared common ancestor

    2. if Tsub is not a leaf, then do quartet tests: take two nodes from
    different subtrees of Tsub and two nodes from different subtrees of the main
    tree. The quartet between these four nodes has as its connecting edge the
    edge connecting Tsub to the main tree.

    Ten random tests are run and each estimate is clipped below at 0. The most
    frequent estimate is returned; on Python 3.7+, ties go to the estimate
    seen first.

    D -- distance matrix on leaves

    Tsub -- a rooted subtree (A Tree object)

    root -- (int) the index of the root node of Tsub

    mapping -- an array where mapping[i] is the index in D of the ith leaf of
    Tsub (this is what I sometimes call a reverse mapping)

    subtrees -- lists of leaves in the other subtrees of the parent node (not
    Tsub). The caller's list is not modified.

    nodes -- list of nodes underneath the parent node; when it does not cover
    every leaf of D, the leaves outside it are used as one more subtree

    returns a float, which is the estimated distance
    """
    n_leaves = D.data.shape[0]
    if len(nodes) != n_leaves:
        # Build a new list instead of appending, so the caller's list is untouched.
        inside = set(nodes)
        subtrees = list(subtrees) + [[x for x in range(n_leaves) if x not in inside]]
    subsubs = [Tsub.findLeavesBelow(i, root) for i in Tsub.neighbors(root)]
    dists = []
    if len(subsubs) == 0:
        ## then root is the only node in the tree: triplet tests
        leaf = mapping[root]
        for _ in range(10):
            chosen = random.sample(range(len(subtrees)), 2)
            # Unwrap the 1-element samples so D.get receives leaf indices,
            # as in the quartet branch below.
            va = random.sample(subtrees[chosen[0]], 1)[0]
            vb = random.sample(subtrees[chosen[1]], 1)[0]
            dists.append(max(0, float(0.5 * (D.get(leaf, va) + D.get(leaf, vb) - D.get(va, vb)))))
    else:
        ## quartet tests: two leaves below Tsub's root, two from other subtrees
        for _ in range(10):
            chosen1 = random.sample(range(len(subsubs)), 2)
            chosen2 = random.sample(range(len(subtrees)), 2)
            xi = mapping[random.sample(subsubs[chosen1[0]], 1)[0]]
            xj = mapping[random.sample(subsubs[chosen1[1]], 1)[0]]
            xk = random.sample(subtrees[chosen2[0]], 1)[0]
            xl = random.sample(subtrees[chosen2[1]], 1)[0]
            dists.append(max(0, float(D.get(xi, xk) + D.get(xj, xl) -
                                      D.get(xi, xj) - D.get(xk, xl)) / 2.0))
    return Counter(dists).most_common(1)[0][0]
