from tree import *
import generate
from nose.tools import raises

################################################################
## 
##    Test Suite
##
#################################################################
class TestTree():
    """Tests for the ``Tree`` class, run against a few fixed fixture trees.

    Array-valued results are checked with ``np.testing`` helpers rather than
    ``(a == b).all()``: the latter broadcasts (so an empty result compares
    "equal" to anything) and fails with an unrelated error when the shapes
    do not match.
    """

    def setUp(self):
        """Build fresh fixture trees before every test.

        ``tiny`` is built from an explicit 2x2 matrix; the others come from
        the ``generate`` helpers (the integer argument is presumably the
        number of leaves -- confirm against ``generate``).
        """
        self.tiny = Tree(np.array([[0, 1], [1, 0]]))
        self.small = Tree(generate.generate_rooted_balanced_binary(2))
        self.large = Tree(generate.generate_rooted_balanced_binary(8))
        self.unrooted = Tree(generate.generate_balanced_binary(12))

    def _assert_root_is_balanced(self, tree):
        """Assert that ``tree.getRoot()`` splits the leaves in a balanced way.

        With ``s`` leaves in total and ``k = tree.degree``, the number of
        leaves returned by ``findLeavesBelow(root, parent)`` must lie in
        ``[s / (k + 1), s * k / (k + 1)]``.
        """
        root, parent = tree.getRoot()
        # float() keeps the bounds fractional under Python 2 integer division.
        leaf_total = float(len(tree.findleaves()))
        k = tree.degree
        leaves_below = len(tree.findLeavesBelow(root, parent))
        assert leaves_below >= leaf_total / (k + 1)
        assert leaves_below <= leaf_total * k / (k + 1)

    @raises(AssertionError)
    def test_constructor(self):
        """The constructor must reject a 3x3 matrix with all off-diagonals 1.

        Read as an adjacency matrix this is a triangle, i.e. not a tree.
        """
        Tree(np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]]))

    @raises(AssertionError)
    def test_constructor2(self):
        """The constructor must reject this 4x4 matrix.

        Read as an adjacency matrix it is a 4-cycle, i.e. not a tree.
        """
        Tree(np.array([[0, 1, 0, 1],
                       [1, 0, 1, 0],
                       [0, 1, 0, 1],
                       [1, 0, 1, 0]]))

    @raises(Exception)
    def test_addEdge(self):
        """``addEdge(0, 1)`` on the large fixture is expected to raise."""
        # NOTE(review): the concrete exception type is not visible from here;
        # narrow ``Exception`` once it is confirmed against ``Tree.addEdge``.
        self.large.addEdge(0, 1)

    @raises(Exception)
    def test_addEdges(self):
        """``addEdges`` with an empty mapping is expected to raise."""
        # NOTE(review): same as above -- narrow the exception type if possible.
        self.large.addEdges({})

    def test_addVertex(self):
        """Adding a vertex via ``{0: 1}`` yields the expected paths and degree."""
        self.small.addVertex({0: 1})
        expected_paths = np.array([[0, 1, 1, 1],
                                   [1, 0, 2, 2],
                                   [1, 2, 0, 2],
                                   [1, 2, 2, 0]])
        np.testing.assert_array_equal(self.small.paths, expected_paths)
        assert self.small.degree == 3

    def test_addVertexBetween(self):
        """Inserting a vertex between 0 and 1 at 0.25 updates the path matrix.

        The degree of both fixtures is expected to be 2 afterwards.
        """
        self.small.addVertexBetween(0, 1, 0.25)
        expected_paths = np.array([[0, 1, 1, 0.25],
                                   [1, 0, 2, 0.75],
                                   [1, 2, 0, 1.25],
                                   [0.25, 0.75, 1.25, 0]])
        # Fractional lengths: compare with a tolerance instead of exact ``==``.
        np.testing.assert_allclose(self.small.paths, expected_paths)
        assert self.small.degree == 2

        self.tiny.addVertexBetween(0, 1, 0.25)
        assert self.tiny.degree == 2

    def test_findLeavesBelow(self):
        """``findLeavesBelow`` returns the expected leaf indices."""
        # Shape-checked comparison: an empty result must not pass vacuously.
        np.testing.assert_array_equal(self.small.findLeavesBelow(0, 1), [2])

        T = Tree(generate.generate_rooted_balanced_binary(4))
        # Order-insensitive checks: the symmetric difference must be empty.
        assert len(np.setxor1d(T.findLeavesBelow(0, 1), [5, 6])) == 0
        assert len(np.setxor1d(T.findLeavesBelow(1, 0), [3, 4])) == 0

    def test_getRoot(self):
        """``getRoot`` picks a balanced split on every fixture tree."""
        for tree in (self.small, self.large, self.unrooted):
            self._assert_root_is_balanced(tree)

    def test_sharedPaths(self):
        """``sharedPaths`` restricted to the leaves matches the expected matrix."""
        M = self.large.sharedPaths()
        leaves = self.large.findleaves()
        expected = np.array([[3, 2, 1, 1, 0, 0, 0, 0],
                             [2, 3, 1, 1, 0, 0, 0, 0],
                             [1, 1, 3, 2, 0, 0, 0, 0],
                             [1, 1, 2, 3, 0, 0, 0, 0],
                             [0, 0, 0, 0, 3, 2, 1, 1],
                             [0, 0, 0, 0, 2, 3, 1, 1],
                             [0, 0, 0, 0, 1, 1, 3, 2],
                             [0, 0, 0, 0, 1, 1, 2, 3]])
        np.testing.assert_array_equal(M[np.ix_(leaves, leaves)], expected)

    def test_extendTree(self):
        """``extendTree(2)`` yields the expected leaf-to-leaf path lengths."""
        T = Tree(generate.generate_rooted_balanced_binary(4))
        T2 = T.extendTree(2)
        leaves = T2.findleaves()
        leaf_paths = T2.paths[np.ix_(leaves, leaves)]
        expected = np.array([[0, 2, 4, 4, 6, 6, 6, 6],
                             [2, 0, 4, 4, 6, 6, 6, 6],
                             [4, 4, 0, 2, 6, 6, 6, 6],
                             [4, 4, 2, 0, 6, 6, 6, 6],
                             [6, 6, 6, 6, 0, 2, 4, 4],
                             [6, 6, 6, 6, 2, 0, 4, 4],
                             [6, 6, 6, 6, 4, 4, 0, 2],
                             [6, 6, 6, 6, 4, 4, 2, 0]])
        np.testing.assert_array_equal(leaf_paths, expected)
