from cluster_vote import *
import generate, counting_matrix, tree
import numpy as np

################################################################
## 
##    Test Suite
##
#################################################################
def test_cluster_vote():
    """cluster_vote must reproduce the leaf-to-leaf path matrix of a balanced tree.

    The check is run twice, first with ``active=False`` and then with
    ``active=True``. Each run gets its own freshly built CountingMatrix.
    """
    source_tree = tree.Tree(generate.generate_balanced_binary(3*4))
    leaf_dist = source_tree.paths[np.ix_(source_tree.findleaves(), source_tree.findleaves())]
    for active_mode in (False, True):
        (rebuilt, mapping, _root) = cluster_vote(counting_matrix.CountingMatrix(leaf_dist), active=active_mode)
        # Index the rebuilt tree's path matrix by `mapping` so it can be
        # compared entry-by-entry with the original leaf distances.
        assert (rebuilt.paths[np.ix_(mapping, mapping)] == leaf_dist).all()

def test_cluster_vote_no_recurse():
    """Check that one non-recursive cluster_vote pass yields the expected split.

    The leaves of the 12-leaf balanced tree must fall into three blocks:
    indices 0-3, 4-7, and 8 and above. Each block must appear whole (all
    four of its leaves) inside a single returned group. The three blocks
    must land in three different groups.
    """
    T = tree.Tree(generate.generate_balanced_binary(3*4))
    D = T.paths[np.ix_(T.findleaves(), T.findleaves())]
    (groups, _mapping, _root) = cluster_vote(counting_matrix.CountingMatrix(D), active=False, recurse=False)
    # For every returned group, count how many of its members lie in each block.
    l1 = [len([x for x in g if x < 4]) for g in groups]
    l2 = [len([x for x in g if x >= 4 and x < 8]) for g in groups]
    l3 = [len([x for x in g if x >= 8]) for g in groups]

    positions = []
    for block, counts in enumerate((l1, l2, l3)):
        # Test membership explicitly rather than catching list.index's
        # ValueError with a bare ``except:``. That way every failure reports
        # its real cause, and unrelated exceptions are not masked.
        assert 4 in counts, "failure to find correct split: block %d not contained in a single group" % block
        positions.append(counts.index(4))
    # All three blocks must sit in pairwise-distinct groups.
    assert len(set(positions)) == 3, "failure to find correct split"


def test_flatten():
    """flatten must concatenate the sublists in order, keeping ``None`` items."""
    nested = [[1, 2, 3], [4, 5, 6], [None]]
    out = flatten(nested)
    # Build the space-separated form with str.join instead of ``reduce``.
    # ``reduce`` is not a builtin in Python 3 (it lives in functools), and it
    # was never imported here.
    assert " ".join(str(x) for x in out) == '1 2 3 4 5 6 None'

def test_kwayvote():
    """Exercise kwayvote on a balanced 12-leaf tree and a hand-built 8-leaf tree.

    Each case passes a distance Matrix, a list of seed groups and the leaf
    indices to place, then compares the groups kwayvote returns.
    """
    def leaf_distances(t):
        # Restrict the tree's path matrix to leaf-to-leaf entries.
        return t.paths[np.ix_(t.findleaves(), t.findleaves())]

    balanced = leaf_distances(tree.Tree(generate.generate_balanced_binary(3*4)))

    # Seeds hold three leaves of each block of four; the fourth leaf of every
    # block is expected to join its block's group.
    result = kwayvote(counting_matrix.Matrix(balanced), [[0,1,2], [4,5,6], [8,9,10]], range(12))
    assert result == [[0,1,2,3], [4,5,6,7], [8,9,10,11]]

    # Only leaves 0-3 are placed here. The last returned group is dropped
    # before comparing.
    result = kwayvote(counting_matrix.Matrix(balanced), [[0, 1], [2], [4,5,6,7,8,9,10,11]], [0,1,2,3])
    result.pop()
    assert result == [[0,1], [2,3]]

    # Adjacency matrix: node 0 is joined to nodes 1-4, and each of nodes 1-4
    # is joined to two degree-1 nodes (5-12).
    adjacency = np.array([[0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
                          [1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0],
                          [1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
                          [1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
                          [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1],
                          [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                          [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                          [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                          [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                          [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                          [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                          [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
                          [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]])
    star = leaf_distances(tree.Tree(adjacency))
    # One seed leaf per pair; each pair partner is expected to join its seed.
    result = kwayvote(counting_matrix.Matrix(star), [[0], [2], [4], [6]], range(8))
    assert result == [[0, 1], [2, 3], [4, 5], [6, 7]]


def test_computeEdgeLengths():
    """computeEdgeLength must return 1 for both small subtree configurations below.

    Both calls use the leaf distances of a 6-leaf balanced tree. Each call
    gets its own freshly built Matrix.
    """
    full_tree = tree.Tree(generate.generate_balanced_binary(3*2))
    leaf_dist = full_tree.paths[np.ix_(full_tree.findleaves(), full_tree.findleaves())]

    # Subtree produced by generate_rooted_balanced_binary(2).
    subtree = tree.Tree(generate.generate_rooted_balanced_binary(2))
    length = computeEdgeLength(counting_matrix.Matrix(leaf_dist), subtree, 0, [None, 0, 1], [[2, 3], [4, 5]], range(6))
    assert length == 1

    # Tree built from a one-element array (presumably a single node).
    subtree = tree.Tree(np.array([0]))
    length = computeEdgeLength(counting_matrix.Matrix(leaf_dist), subtree, 0, [0], [[1]], [0, 1])
    assert length == 1
