from pearl_reconstruct import *
import generate, tree
import numpy as np

from counting_matrix import *

################################################################
## 
##    Test Suite
##
#################################################################

def test_pearl_reconstuct():
    """Check pearl_reconstruct against fixed and generated distance matrices.

    Two hand-written matrices are checked first, then leaf-to-leaf distance
    matrices taken from trees built with ``generate.generate`` for the
    arguments 5, 10 and 30.  In every generated case the reconstructed
    tree's ``paths``, indexed by the returned mapping on both axes, must
    equal the input matrix exactly.  The cases for 5 and 30 are run a second
    time with ``rootSelect=False``.

    NOTE(review): the function name is misspelled ("reconstuct"); it is kept
    unchanged so that anything selecting this test by name keeps working.
    """

    def leaf_distances_recovered(dist, **options):
        # A fresh CountingMatrix is wrapped around `dist` for every
        # reconstruction; the same `dist` array is then used as the reference.
        recon, leaf_index = pearl_reconstruct(CountingMatrix(dist), **options)
        return (recon.paths[np.ix_(leaf_index, leaf_index)] == dist).all()

    # Case 1: three leaves, all pairwise distances equal to 3.  The expected
    # result joins nodes 0, 1 and 2 to a fourth node (index 3) with edges of
    # weight 1.5; here the node order is compared directly, without mapping.
    star_dist = np.array([[0, 3, 3], [3, 0, 3], [3, 3, 0]], dtype=np.float32)
    expected_adj = np.array(
        [[0, 0, 0, 1.5], [0, 0, 0, 1.5], [0, 0, 0, 1.5], [1.5, 1.5, 1.5, 0]]
    )
    expected_paths = np.array(
        [[0, 3, 3, 1.5], [3, 0, 3, 1.5], [3, 3, 0, 1.5], [1.5, 1.5, 1.5, 0]]
    )
    star_tree, _ = pearl_reconstruct(CountingMatrix(star_dist))
    assert (star_tree.adjMat.toarray() == expected_adj).all()
    assert (star_tree.paths == expected_paths).all()

    # Case 2: four leaves in two close pairs (distance 2 inside a pair,
    # distance 3 across pairs).
    paired_dist = np.array(
        [[0, 2, 3, 3], [2, 0, 3, 3], [3, 3, 0, 2], [3, 3, 2, 0]]
    )
    assert leaf_distances_recovered(paired_dist)

    # Generated cases, in the original order.  An empty option dict means
    # "call with defaults"; the 10 case is only run with defaults.
    generated_cases = (
        (5, ({}, {"rootSelect": False})),
        (10, ({},)),
        (30, ({}, {"rootSelect": False})),
    )
    for size, option_sets in generated_cases:
        source_tree = tree.Tree(generate.generate(size))
        leaf_ids = source_tree.findleaves()
        leaf_dist = source_tree.paths[np.ix_(leaf_ids, leaf_ids)]
        for options in option_sets:
            assert leaf_distances_recovered(leaf_dist, **options)
