from sequoia import *
import tree, generate, counting_matrix
import numpy as np

def _assert_sequoia_roundtrip(original):
    """Assert that ``sequoia`` reproduces ``original``'s leaf-to-leaf ``paths``.

    Takes the ``paths`` submatrix of ``original`` restricted to the rows and
    columns returned by ``findleaves()``, feeds it to ``sequoia`` wrapped in a
    ``CountingMatrix``, and checks that the reconstructed tree's ``paths``,
    indexed through the returned ``mapping``, equal that submatrix exactly.
    """
    leaves = original.findleaves()
    expected = original.paths[np.ix_(leaves, leaves)]
    reconstructed, mapping = sequoia(counting_matrix.CountingMatrix(expected))
    actual = reconstructed.paths[np.ix_(mapping, mapping)]
    # np.array_equal compares shapes before values, so a mapping of the wrong
    # length fails this assertion cleanly instead of raising inside the
    # elementwise comparison or broadcasting into a spurious pass.
    assert np.array_equal(actual, expected), (
        "sequoia round-trip mismatch:\nexpected\n%s\ngot\n%s" % (expected, actual)
    )


def test_sequoia():
    """Check sequoia round-trips a hand-built tree and a generated one."""
    # Symmetric 0/1 matrix with edges (0,5), (3,5), (1,4), (2,4), (4,5).
    # Presumably tree.Tree reads it as an adjacency matrix of a 6-node tree;
    # confirm against tree.Tree.
    _assert_sequoia_roundtrip(tree.Tree(np.array([[0, 0, 0, 0, 0, 1],
                                                  [0, 0, 0, 0, 1, 0],
                                                  [0, 0, 0, 0, 1, 0],
                                                  [0, 0, 0, 0, 0, 1],
                                                  [0, 1, 1, 0, 0, 1],
                                                  [1, 0, 0, 1, 1, 0]])))

    # Larger case produced by the project's balanced-binary generator.
    _assert_sequoia_roundtrip(tree.Tree(generate.generate_balanced_binary(3*8)))

