from generate import *
import tree
import numpy as np
import graph
from util import *

################################################################
## 
##    Test Suite
##
#################################################################
def test_generate():
    """Check that generate(n) produces a tree adjacency matrix.

    For n=2 the result must be the single-edge adjacency matrix; for n=30 it
    must load into tree.Tree as a connected, acyclic graph.
    """
    # np.array_equal also compares shapes, so a wrongly sized result fails the
    # assertion cleanly instead of broadcasting or raising inside `==`.
    assert np.array_equal(generate(2), np.array([[0, 1], [1, 0]]))
    T = tree.Tree(generate(30))
    assert T.connected()
    assert not T.cycleCheck()

def test_generate_binary():
    """Check that generate_binary(20) loads as a connected, acyclic tree.Tree."""
    T = tree.Tree(generate_binary(20))
    assert T.connected()
    assert not T.cycleCheck()

def test_generate_caterpillar():
    """Check the exact adjacency matrix of generate_caterpillar(4).

    Vertices 0-3 are leaves; 0 and 1 hang off vertex 4, 2 and 3 hang off
    vertex 5, and vertices 4 and 5 are joined. The result must also load as a
    connected, acyclic tree.Tree.
    """
    T = generate_caterpillar(4)
    expected = np.array([[0, 0, 0, 0, 1, 0],
                         [0, 0, 0, 0, 1, 0],
                         [0, 0, 0, 0, 0, 1],
                         [0, 0, 0, 0, 0, 1],
                         [1, 1, 0, 0, 0, 1],
                         [0, 0, 1, 1, 1, 0]])
    # np.array_equal checks shape as well as values.
    assert np.array_equal(T, expected)
    T = tree.Tree(T)
    assert T.connected()
    assert not T.cycleCheck()

def test_generate_balanced_binary():
    """Check the exact adjacency matrix of generate_balanced_binary(6).

    Vertex 0 is adjacent to 1, 2 and 3; each of those carries two leaves
    (1: 4,5; 2: 6,7; 3: 8,9), giving six leaves in total.
    """
    T = generate_balanced_binary(6)
    expected = np.array([[0, 1, 1, 1, 0, 0, 0, 0, 0, 0],
                         [1, 0, 0, 0, 1, 1, 0, 0, 0, 0],
                         [1, 0, 0, 0, 0, 0, 1, 1, 0, 0],
                         [1, 0, 0, 0, 0, 0, 0, 0, 1, 1],
                         [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
                         [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
                         [0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
                         [0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
                         [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
                         [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]])
    # np.array_equal checks shape as well as values.
    assert np.array_equal(T, expected)

def test_generate_kway_balanced():
    """Check the exact adjacency matrix of generate_kway_balanced(9, 3).

    Vertex 0 is adjacent to 1, 2 and 3; each of those carries three leaves
    (1: 4-6; 2: 7-9; 3: 10-12), giving nine leaves in total.
    """
    T = generate_kway_balanced(9, 3)
    expected = np.array([[0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                         [1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0],
                         [1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0],
                         [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1],
                         [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                         [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                         [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                         [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                         [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                         [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                         [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                         [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                         [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
    # np.array_equal checks shape as well as values.
    assert np.array_equal(T, expected)

def test_generate_rooted_balanced_binary():
    """Check the exact adjacency matrix of generate_rooted_balanced_binary(8).

    The expected tree is a complete binary tree rooted at vertex 0 in which
    internal vertex i (0-6) is adjacent to children 2i+1 and 2i+2, so
    vertices 7-14 are the eight leaves.
    """
    T = generate_rooted_balanced_binary(8)
    expected = np.array([[0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                         [1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                         [1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
                         [0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0],
                         [0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
                         [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
                         [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1],
                         [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                         [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                         [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                         [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                         [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                         [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                         [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
                         [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]])
    # np.array_equal checks shape as well as values.
    assert np.array_equal(T, expected)

def test_generate_unbalanced_binary():
    """Check leaf-to-leaf path lengths of generate_unbalanced_binary trees.

    For generate_unbalanced_binary(6, 1) the six leaves form three pairs:
    leaves within a pair are 2 apart and all other leaf pairs are 4 apart.
    For generate_unbalanced_binary(1, 1) every entry of T.paths must be zero.
    """
    T, _root, _mapping = generate_unbalanced_binary(6, 1)
    # Compute the leaf list once and use it for both row and column selection.
    leaves = T.findleaves()
    leaf_paths = T.paths[np.ix_(leaves, leaves)]
    expected = np.array([[0, 2, 4, 4, 4, 4],
                         [2, 0, 4, 4, 4, 4],
                         [4, 4, 0, 2, 4, 4],
                         [4, 4, 2, 0, 4, 4],
                         [4, 4, 4, 4, 0, 2],
                         [4, 4, 4, 4, 2, 0]])
    # np.array_equal checks shape as well as values.
    assert np.array_equal(leaf_paths, expected)

    T, _root, _mapping = generate_unbalanced_binary(1, 1)
    # Only the values are pinned here, not the shape of T.paths, so check
    # that every entry is zero.
    assert np.all(np.asarray(T.paths) == 0)

def test_generate_star():
    """Check the exact adjacency matrix of generate_star(4).

    Vertices 0-3 are leaves, each adjacent only to the hub vertex 4.
    """
    T = generate_star(4)
    expected = np.array([[0, 0, 0, 0, 1],
                         [0, 0, 0, 0, 1],
                         [0, 0, 0, 0, 1],
                         [0, 0, 0, 0, 1],
                         [1, 1, 1, 1, 0]])
    # np.array_equal checks shape as well as values.
    assert np.array_equal(T, expected)

def test_generate_erdos_renyi():
    """Check generate_erdos_renyi(n, p), which returns a (model, adjacency) pair.

    With p=1 on 4 vertices the adjacency must be the complete graph K4 (all
    ones off the diagonal), and the model must be a graph.Graph. For a
    random draw with p=0.1 on 100 vertices the adjacency must be symmetric.
    """
    M, G = generate_erdos_renyi(4, 1)
    # A plain ndarray literal avoids the deprecated np.matrix class;
    # np.array_equal compares values and shape regardless of matrix/ndarray.
    assert np.array_equal(G, np.array([[0, 1, 1, 1],
                                       [1, 0, 1, 1],
                                       [1, 1, 0, 1],
                                       [1, 1, 1, 0]]))
    assert isinstance(M, graph.Graph)

    M, G = generate_erdos_renyi(100, 0.1)
    # An undirected graph's adjacency matrix must equal its own transpose.
    adj = np.asarray(G)
    assert np.array_equal(adj, adj.T)

def test_generate_clique_fringe():
    """Check the exact adjacency matrix of generate_clique_fringe(3).

    Vertices 3, 4 and 5 form a triangle (3-clique). Fringe vertices 0, 1 and
    2 are attached to 3, 4 and 5 respectively.
    """
    _M, G = generate_clique_fringe(3)
    expected = np.array([[0, 0, 0, 1, 0, 0],
                         [0, 0, 0, 0, 1, 0],
                         [0, 0, 0, 0, 0, 1],
                         [1, 0, 0, 0, 1, 1],
                         [0, 1, 0, 1, 0, 1],
                         [0, 0, 1, 1, 1, 0]])
    # np.array_equal handles an np.matrix G and checks shape as well as values.
    assert np.array_equal(G, expected)


def test_assign_edge_weights():
    """Check that assign_edge_weights keeps the edge set of a tree unchanged.

    Weighting the 4-leaf star from generate_star(4) must still give a
    connected, acyclic tree.Tree. Its nonzero adjacency entries must sit at
    exactly the same (row, col) positions as in the unweighted star.
    """
    T = generate_star(4)
    weighted = tree.Tree(assign_edge_weights(T))
    assert weighted.connected()
    assert not weighted.cycleCheck()
    # np.array_equal also compares lengths. An elementwise `==` would
    # broadcast a length-1 index array against the reference, or fail
    # outright, if an edge were dropped or added.
    rows_w, cols_w = np.nonzero(weighted.adjMat)
    rows, cols = np.nonzero(T)
    assert np.array_equal(rows_w, rows)
    assert np.array_equal(cols_w, cols)
