import numpy as np
import numpy.random
import random
import forest, tree, graph

def generate(n):
    """
    Generates a random tree with n leaves and returns it as an unweighted,
    undirected adjacency matrix.

    Leaves are nodes 0..n-1; internal nodes are numbered n, n+1, ... in the
    order they are created. While more than two parentless nodes remain, a
    new internal node adopts 2 + Poisson(1) of them (or all of them if fewer
    remain). If exactly two parentless nodes are left at the end they are
    joined by an edge. The result is therefore one connected tree rather than
    a forest, and for n >= 3 every internal node has degree >= 3.

    n == 1 yields a 1x1 zero matrix (a lone leaf).
    """
    # Each loop iteration removes >= 2 parentless nodes and adds one, so at
    # most n-2 internal nodes are created. max() keeps n == 1 from producing
    # an empty 0x0 matrix.
    size = max(2 * n - 2, n)
    tree = np.zeros([size, size])
    nodesleft = set(range(n))
    index = n
    while len(nodesleft) > 2:
        x = numpy.random.poisson(1) + 2
        # sorted(): random.sample no longer accepts a set (Python 3.11+).
        chosen = random.sample(sorted(nodesleft), min(int(x), len(nodesleft)))
        for i in chosen:
            tree[index, i] = 1
            tree[i, index] = 1
        nodesleft.difference_update(chosen)
        nodesleft.add(index)
        index += 1
    if len(nodesleft) == 2:
        x1, x2 = nodesleft
        tree[x1, x2] = 1
        tree[x2, x1] = 1
    return tree[0:index, 0:index]

def generate_binary(n):
    """
    Generates a random unrooted binary tree with n leaves. Returns an
    unweighted, undirected adjacency matrix.

    Leaves are nodes 0..n-1. Repeatedly, two random parentless nodes are
    given a new parent (numbered n, n+1, ...) until two parentless nodes
    remain, which are then joined by an edge. For n >= 2 the matrix is
    (2n-2) x (2n-2); n == 1 yields a 1x1 zero matrix (a lone leaf).
    """
    # max() keeps n == 1 from producing an empty 0x0 matrix.
    size = max(2 * n - 2, n)
    tree = np.zeros([size, size])
    nodesleft = set(range(n))
    index = n
    while len(nodesleft) > 2:
        # sorted(): random.sample no longer accepts a set (Python 3.11+).
        pair = random.sample(sorted(nodesleft), 2)
        for i in pair:
            tree[index, i] = 1
            tree[i, index] = 1
        nodesleft.difference_update(pair)
        nodesleft.add(index)
        index += 1
    if len(nodesleft) == 2:
        x1, x2 = nodesleft
        tree[x1, x2] = 1
        tree[x2, x1] = 1
    return tree[0:index, 0:index]

def generate_caterpillar(n):
    """
    Generates a caterpillar tree with n leaves. Returns an unweighted,
    undirected adjacency matrix.

    For n >= 3 the internal nodes n, n+1, ..., 2n-3 form a path (the spine).
    - Leaf 0 hangs off spine node n.
    - Leaf i (1 <= i <= n-3) hangs off spine node n+i-1.
    - Leaves n-2 and n-1 hang off the last spine node 2n-3.

    Every spine node therefore has degree 3. For n == 2 the two leaves are
    joined directly; n == 1 yields a 1x1 zero matrix.
    """
    if n < 3:
        # Too few leaves for any internal node (indexing below would overflow).
        tree = np.zeros([n, n])
        if n == 2:
            tree[0, 1] = 1
            tree[1, 0] = 1
        return tree
    tree = np.zeros([2 * n - 2, 2 * n - 2])
    last = 2 * n - 3  # final spine node
    # Spine path n - (n+1) - ... - last.
    for s in range(n, last):
        tree[s, s + 1] = 1
        tree[s + 1, s] = 1
    # Leaf 0 joins the first spine node.
    tree[0, n] = 1
    tree[n, 0] = 1
    # Leaves 1..n-3 attach one per spine node, starting at the first.
    for i in range(1, n - 2):
        tree[i, n + i - 1] = 1
        tree[n + i - 1, i] = 1
    # The two remaining leaves share the last spine node.
    for leaf in (n - 2, n - 1):
        tree[leaf, last] = 1
        tree[last, leaf] = 1
    return tree

def generate_balanced_binary_no_args():
    """
    Generates the balanced unrooted binary tree with 6 leaves (a 10x10
    unweighted, undirected adjacency matrix).

    Equivalent to generate_balanced_binary(6): node 0 is joined to nodes
    1, 2 and 3, which carry leaf pairs (4, 5), (6, 7) and (8, 9).
    """
    # Delegate instead of keeping a second copy of the construction.
    return generate_balanced_binary(6)

def generate_balanced_binary(n):
    """
    Generates a balanced unrooted binary tree with n leaves. Returns an
    unweighted, undirected (2n-2) x (2n-2) adjacency matrix.

    n must be of the form 3*2^k for some integer k >= 0.
    - Node 0 is the centre and is joined to nodes 1, 2 and 3.
    - Every other internal node i (1 <= i <= n-3) has children 2i+2 and 2i+3.
    - The leaves are the last n indices, n-2 .. 2n-3.

    Raises AssertionError if n is not of the form 3*2^k.
    """
    quotient, remainder = divmod(n, 3)
    # Exact integer test; the former float log2 comparison accepted n == 0.
    assert remainder == 0 and quotient > 0 and (quotient & (quotient - 1)) == 0, \
        "n is not of the form 3*2^k for some k"
    tree = np.zeros([2*n-2, 2*n-2])
    tree[0, 1:4] = [1, 1, 1]
    tree[1:4, 0] = [1, 1, 1]
    for i in range(1, n-2):
        tree[i, 2*(i+1)] = 1
        tree[i, 2*(i+1)+1] = 1
        tree[2*(i+1), i] = 1
        tree[2*(i+1)+1, i] = 1
    return tree

def generate_kway_balanced(n, k):
    """
    Generates a complete balanced rooted k-ary tree with n leaves. Requires
    k >= 2 and n = k^c for some integer c >= 0. Returns an unweighted,
    undirected adjacency matrix.

    Nodes are numbered in heap order: internal node i has children
    k*i+1 .. k*i+k.
    - The root (node 0) has degree k.
    - Every other internal node has degree k+1.
    - The leaves are the last n indices.
    - The matrix side is 1 + k + ... + k^c = (n*k - 1) / (k - 1).

    Raises AssertionError if k < 2 or n is not a power of k.
    """
    assert k >= 2, "k must be at least 2"
    # Exact integer power test: the former float log ratio could misjudge
    # exact powers, and a float geometric sum truncated by int() could
    # undercount the size.
    power = 1
    while power < n:
        power *= k
    assert power == n, "n is not of the form k^c for some integer c"
    size = (n * k - 1) // (k - 1)
    tree = np.zeros([size, size])
    for parent in range(size - n):
        for j in range(1, k + 1):
            child = k * parent + j
            tree[parent, child] = 1
            tree[child, parent] = 1
    return tree


def generate_rooted_balanced_binary_no_args():
    """
    Generates the rooted balanced binary tree with 4 leaves (a 7x7
    unweighted, undirected adjacency matrix).

    Equivalent to generate_rooted_balanced_binary(4): root 0 has children
    1 and 2, which carry leaf pairs (3, 4) and (5, 6).
    """
    # Delegate instead of keeping a second copy of the construction.
    return generate_rooted_balanced_binary(4)

def generate_rooted_balanced_binary(n):
    """
    Generates a rooted balanced binary tree with n leaves. Returns an
    unweighted, undirected (2n-1) x (2n-1) adjacency matrix.

    n must be 2^k for some integer k >= 0.
    - Nodes are in heap order: internal node i (0 <= i <= n-2) has children
      2i+1 and 2i+2, so node 0 is the root.
    - The leaves are the last n indices, n-1 .. 2n-2.
    - n == 1 yields a 1x1 zero matrix (a lone root/leaf).

    Raises AssertionError if n is not a power of two.
    """
    # Exact integer test; the former float log2 comparison accepted n == 0.
    assert n >= 1 and (n & (n - 1)) == 0, "n is not of the form 2^k"
    tree = np.zeros([2*n-1, 2*n-1])
    # The root is just the i == 0 case of the heap layout. Handling it in the
    # loop avoids the old fixed tree[0, 1:3] write that broke for n == 1.
    for i in range(n - 1):
        tree[i, 2*i+1] = 1
        tree[i, 2*i+2] = 1
        tree[2*i+1, i] = 1
        tree[2*i+2, i] = 1
    return tree


def generate_unbalanced_binary(n, eta):
    """
    Generates an unbalanced binary tree with n leaves using the forest/tree
    helpers. Returns the (tree, root, mapping) triple stored in the Forest
    for the final merged tree.

    The top-level root gets three subtrees: two with ceil(n/(2+eta)) leaves
    each, and a third with the remaining leaves. Below that, a subtree with
    k leaves is split into ceil(k/(1+eta)) and k - ceil(k/(1+eta)) leaves,
    so eta sets how lopsided each split is (eta == 1 splits evenly). Leaf
    labels 0..n-1 are handed down through the mapping lists.

    Split sizes are clamped so every subtree keeps at least one leaf. An
    empty subtree never reaches the n == 1 base case and previously
    recursed until RecursionError (e.g. n=4, eta=1 left the third top-level
    subtree empty).

    Raises ValueError for n < 1 or n == 2, which cannot be split into three
    non-empty top-level subtrees.
    """
    def generate_unbalanced_binary_sub(n, eta, mapping):
        # Builds the subtree over the n leaf labels listed in mapping.
        if n == 1:
            assert len(mapping) == 1
            return (tree.Tree(np.array([0])), 0, mapping)
        # 1.0 forces true division under Python 2 when eta is an int; the
        # clamp to [1, n-1] keeps both children non-empty.
        left = min(max(int(np.ceil(n / (1.0 + eta))), 1), n - 1)
        F = forest.Forest()
        rootidx = F.addTree(tree.Tree(np.array([0])), 0, [None])
        (T1, R1, M1) = generate_unbalanced_binary_sub(left, eta, [mapping[i] for i in range(left)])
        idx1 = F.addTree(T1, R1, M1)
        rootidx = F.merge(rootidx, idx1, 0, 1)
        (T2, R2, M2) = generate_unbalanced_binary_sub(n - left, eta, [mapping[i] for i in range(left, n)])
        idx2 = F.addTree(T2, R2, M2)
        rootidx = F.merge(rootidx, idx2, 0, 1)
        return (F.forest[rootidx], F.roots[rootidx], F.mappings[rootidx])
    if n == 1:
        return (tree.Tree(np.array([0])), 0, [])
    n = int(n)
    if n < 3:
        raise ValueError("generate_unbalanced_binary needs n == 1 or n >= 3, got %d" % n)
    # Leaves in each of the two equal outer subtrees, capped so the third
    # subtree keeps at least one leaf.
    side = min(max(int(np.ceil(n / (2.0 + eta))), 1), (n - 1) // 2)
    F = forest.Forest()
    rootidx = F.addTree(tree.Tree(np.array([0])), 0, [None])
    # list(range(...)) so mappings are lists on Python 3 as they were on Python 2.
    (T1, R1, M1) = generate_unbalanced_binary_sub(side, eta, list(range(side)))
    idx1 = F.addTree(T1, R1, M1)
    rootidx = F.merge(rootidx, idx1, 0, 1, computePaths = False)
    (T2, R2, M2) = generate_unbalanced_binary_sub(side, eta, list(range(side, 2 * side)))
    idx2 = F.addTree(T2, R2, M2)
    rootidx = F.merge(rootidx, idx2, 0, 1, computePaths = False)
    (T3, R3, M3) = generate_unbalanced_binary_sub(n - 2 * side, eta, list(range(2 * side, n)))
    idx3 = F.addTree(T3, R3, M3)
    rootidx = F.merge(rootidx, idx3, 0, 1)
    return (F.forest[rootidx], F.roots[rootidx], F.mappings[rootidx])



def generate_star(n):
    """
    Generates a star with n leaves. Returns an unweighted, undirected
    (n+1) x (n+1) adjacency matrix in which leaves 0..n-1 are each joined
    to the hub, node n.
    """
    tree = np.zeros([n + 1, n + 1])
    # Vectorized replacement for the Python-2-only xrange loop.
    tree[:n, n] = 1
    tree[n, :n] = 1
    return tree


def generate_erdos_renyi(n, p):
    """
    Samples an Erdos-Renyi G(n, p) graph and keeps its largest connected
    component.

    Each unordered pair of distinct nodes is joined independently with
    probability p: only the strict upper triangle of an n x n Bernoulli(p)
    draw is used, then mirrored. The components come from
    graph.Graph.connectedComponents(); the first one of maximal size is
    kept. Rows/columns of the returned adjacency follow that component's
    iteration order.

    Returns (graph.Graph, adjacency) for the kept component.
    """
    coin_flips = np.matrix(np.random.binomial(1, p, [n, n]))
    upper = np.triu(coin_flips, k=1)
    adjacency = upper + upper.T
    components = graph.Graph(adjacency).connectedComponents()
    sizes = [len(component) for component in components]
    keep = list(components[np.argmax(sizes)])
    adjacency = adjacency[np.ix_(keep, keep)]
    return (graph.Graph(adjacency), adjacency)
            

def generate_clique_fringe(n):
    """
    Builds a 2n-node "clique with fringe" graph.

    Nodes n..2n-1 form a clique with no self-loops. Each node i < n is a
    pendant attached only to clique node n+i.

    Returns (graph.Graph, adjacency) where adjacency is a float np.matrix.
    """
    pendants = np.arange(n)
    dense = np.zeros((2 * n, 2 * n))
    dense[n:, n:] = np.ones((n, n)) - np.eye(n)
    dense[pendants, pendants + n] = 1
    dense[pendants + n, pendants] = 1
    adjacency = np.matrix(dense)
    return (graph.Graph(adjacency), adjacency)

def generate_cross_graph():
    """
    Returns a fixed 10-node test graph as a graph.Graph.

    Nodes 0, 1, 2 and 3 are pendants hanging off nodes 4, 7, 6 and 8
    respectively. Nodes 4..9 are interconnected by the remaining edges
    listed below.
    """
    edges = [
        (0, 4), (1, 7), (2, 6), (3, 8),   # pendant attachments
        (4, 5), (4, 9), (5, 6), (5, 8),   # edges among nodes 4..9
        (6, 7), (6, 8), (7, 9),
    ]
    adjacency = np.zeros((10, 10), dtype=int)
    for u, v in edges:
        adjacency[u, v] = 1
        adjacency[v, u] = 1
    return graph.Graph(np.matrix(adjacency))

def generate_clique_minus():
    """
    Returns a fixed 8-node test graph as a graph.Graph.

    Nodes 4, 5, 7, 6 form a 4-cycle (K4 without the 4-7 and 5-6 diagonals),
    and nodes 0, 1, 2, 3 are pendants attached to 4, 5, 6, 7 respectively.
    """
    edges = [
        (0, 4), (1, 5), (2, 6), (3, 7),   # pendant attachments
        (4, 5), (4, 6), (5, 7), (6, 7),   # the 4-cycle
    ]
    adjacency = np.zeros((8, 8), dtype=int)
    for u, v in edges:
        adjacency[u, v] = 1
        adjacency[v, u] = 1
    return graph.Graph(np.matrix(adjacency))

def assign_edge_weights(tree):
    """
    Replaces every entry equal to 1 in a square adjacency matrix with a
    random weight drawn by numpy.random.uniform(0.1, 10). The same weight is
    written to the mirrored entry, so each undirected edge gets one weight.

    The matrix is modified in place and also returned.
    """
    size = tree.shape[0]
    for i in range(size):
        for j in range(size):
            # Once (i, j) is weighted, (j, i) normally no longer equals 1,
            # so the mirror is not re-drawn (unless the draw was exactly 1.0).
            if tree[i, j] == 1:
                weight = numpy.random.uniform(0.1, 10)
                tree[i, j] = weight
                tree[j, i] = weight
    return tree

