import numpy as np
import generate, tree
from counting_matrix import *


class TestCountingMatrix(object):
    """Tests for ``Matrix`` and ``CountingMatrix`` from ``counting_matrix``.

    Each test builds a small matrix, calls one accessor and asserts on the
    returned values and, for ``CountingMatrix``, on the value reported by
    ``numUsed()`` afterwards.
    """

    def test_constructor(self):
        """A new CountingMatrix reports zero used entries; Matrix keeps its data."""
        arr = np.array([[0, 1, 1], [1, 0, 0], [1, 0, 0]])
        counting = CountingMatrix(arr)
        # No accessor has been called yet.
        assert counting.numUsed() == 0
        plain = Matrix(arr)
        assert (plain.data == arr).all()

        mat = np.matrix([[0, 1, 1], [1, 0, 0], [1, 0, 0]])
        counting = CountingMatrix(mat, type=np.matrix)
        # Exact type check on purpose: ``used`` must be an np.matrix itself.
        assert type(counting.used) is np.matrix

    def test_get(self):
        """``get`` returns the requested entry and records one use."""
        arr = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
        counting = CountingMatrix(arr)
        assert counting.numUsed() == 0
        assert counting.get(0, 2) == 3
        assert counting.numUsed() == 1
        # The used mask is expected to flag both (0, 2) and (2, 0) while the
        # count stays at 1 -- presumably the pair is treated as symmetric;
        # confirm against CountingMatrix.
        expected_mask = np.array([[0, 0, 1], [0, 0, 0], [1, 0, 0]])
        assert (counting.used.toarray() == expected_mask).all()

        plain = Matrix(arr)
        assert plain.get(0, 2) == 3

    def test_num_used(self):
        """Check the data of a Matrix submatrix over indices 0 and 1.

        NOTE(review): despite its name this test never calls ``numUsed()``;
        it repeats the Matrix half of ``test_submatrix``.
        """
        arr = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
        plain = Matrix(arr)
        sub = plain.submatrix([0, 1])
        assert (sub.data == np.array([[1, 2], [4, 5]])).all()

    def test_submatrix(self):
        """``submatrix`` returns the caller's own class with the selected block."""
        arr = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
        expected = np.array([[1, 2], [4, 5]])

        plain = Matrix(arr)
        plain_sub = plain.submatrix([0, 1])
        # Exact type checks: the result must be precisely the caller's class.
        assert type(plain_sub) is Matrix
        assert (plain_sub.data == expected).all()

        counting = CountingMatrix(arr)
        counting_sub = counting.submatrix([0, 1])
        assert type(counting_sub) is CountingMatrix
        assert (counting_sub.data == expected).all()

    def test_row(self):
        """``row([0])`` yields the values 1, 4, 7 as a 3x1 block."""
        arr = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
        expected = [[1], [4], [7]]

        plain = Matrix(arr)
        assert (plain.row([0]) == expected).all()

        counting = CountingMatrix(arr)
        assert (counting.row([0]) == expected).all()
        # Reading the three entries is expected to record three uses.
        assert counting.numUsed() == 3

    def test_all(self):
        """``all`` returns the full contents; CountingMatrix then reports 6 used."""
        arr = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

        plain = Matrix(arr)
        everything = plain.all()
        # markasused is called on the plain Matrix after the read; the data
        # obtained earlier must still equal the input.
        plain.markasused(1, 1)
        assert (everything == arr).all()

        counting = CountingMatrix(arr)
        everything = counting.all()
        assert (everything == arr).all()
        # 6 of the 9 entries are reported as used -- presumably the upper
        # triangle including the diagonal; confirm against CountingMatrix.
        assert counting.numUsed() == 6

    def test_type(self):
        """``markasused`` with index lists works when ``type=np.matrix``."""
        arr = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
        counting = CountingMatrix(arr, type=np.matrix)
        counting.markasused([0, 1], [1, 2])
        assert counting.numUsed() == 4

    def test_get_all(self):
        """``getall`` returns the data where the mask is 1 and 0 elsewhere."""
        data = np.matrix([[1, 2, 3, 4],
                          [5, 6, 7, 8],
                          [9, 10, 11, 12],
                          [13, 14, 15, 16]])
        mask = np.matrix([[0, 1, 0, 1],
                          [1, 0, 0, 0],
                          [0, 0, 0, 0],
                          [1, 0, 0, 0]])
        expected = np.matrix([[0, 2, 0, 4],
                              [5, 0, 0, 0],
                              [0, 0, 0, 0],
                              [13, 0, 0, 0]])
        counting = CountingMatrix(data)
        assert (counting.getall(mask) == expected).all()

    def test_ancestry(self):
        """A ``get`` on a submatrix is reflected in the parent's ``numUsed``."""
        T = tree.Tree(generate.generate_rooted_balanced_binary(8))
        # Restrict the tree's path matrix to leaf-by-leaf entries.
        D = T.paths[np.ix_(T.findleaves(), T.findleaves())]
        parent = CountingMatrix(D)
        sub = parent.submatrix([4, 5, 6, 7])
        assert (sub.data == np.array([[0, 2, 4, 4],
                                      [2, 0, 4, 4],
                                      [4, 4, 0, 2],
                                      [4, 4, 2, 0]])).all()

        # Called only for its side effect; the returned value is not needed.
        sub.get(0, 1)
        # The use is counted on the parent -- presumably the submatrix shares
        # its usage tracking with the matrix it came from; confirm.
        assert parent.numUsed() == 1
        
