import tree
import scipy.sparse

class Forest(object):
    """
    A Forest object is a collection of trees that allows merging.

    Trees live in three parallel lists (``forest``, ``mappings``, ``roots``)
    indexed by the id returned from ``addTree``/``merge``. Merging retires
    both input trees (their ``forest`` and ``mappings`` slots become ``None``
    and their ``roots`` slot becomes 0) and appends the merged tree as a new
    entry, so the ids of all other trees stay valid.
    """

    # Sparse matrices do not store explicit zeros, so a zero-length edge
    # between two roots would vanish and leave the merged tree disconnected.
    # Such an edge is stored with this tiny positive weight instead.
    _ZERO_EDGE_LENGTH = 1e-20

    def __init__(self):
        self.forest = []    # tree objects; None once merged into another tree
        self.mappings = []  # per-tree lists, presumably one entry per node in node-index order
        self.roots = []     # root node index of each tree

    def addTree(self, tr, root, mapping):
        """
        Append a tree to the forest and return its index.

        :param tr: tree object; ``merge`` reads its ``adjMat`` attribute.
        :param root: index of the tree's root node.
        :param mapping: per-node list stored alongside the tree. ``merge``
            never modifies it in place.
        :returns: the new tree's index in the forest.
        """
        idx = len(self.forest)
        self.forest.append(tr)
        self.roots.append(root)
        self.mappings.append(mapping)
        return idx

    def merge(self, i, j, l1, l2, computePaths = True):
        """
        Join trees ``i`` and ``j`` into a new tree and return its index.

        If both ``l1`` and ``l2`` are non-zero, a new root node (the last
        node of the merged tree, with mapping entry ``None``) is connected to
        the root of tree ``i`` with weight ``l1`` and to the root of tree
        ``j`` with weight ``l2``. If ``l1`` is zero, the root of tree ``i``
        becomes the merged root and is joined directly to the root of tree
        ``j`` with weight ``l2``. If only ``l2`` is zero, the root of tree
        ``j`` becomes the merged root and the roots are joined with weight
        ``l1``. If both are zero, the roots are joined by an edge of
        ``_ZERO_EDGE_LENGTH`` and the root of tree ``i`` is kept.

        Node indices of tree ``i`` are preserved. Those of tree ``j`` are
        shifted by the node count of tree ``i``. Both input trees are retired.

        :param i: index of the first tree (must be active).
        :param j: index of the second tree (must be active and differ from ``i``).
        :param l1: edge length from the merge point to the root of tree ``i``.
        :param l2: edge length from the merge point to the root of tree ``j``.
        :param computePaths: forwarded to ``tree.Tree``.
        :returns: index of the merged tree.
        :raises ValueError: if ``i == j`` or either tree was already merged.
        """
        for k in (i, j):
            if self.forest[k] is None:
                raise ValueError("tree %d has already been merged" % k)
        if i == j:
            raise ValueError("cannot merge tree %d with itself" % i)

        M, root, addedNode = self._joinAdjacency(
            self.forest[i].adjMat, self.forest[j].adjMat,
            self.roots[i], self.roots[j], l1, l2)

        # Build a fresh list rather than extending the stored ones, so mapping
        # lists handed to addTree (or returned by active) are never mutated.
        mapping = list(self.mappings[i]) + list(self.mappings[j])
        if addedNode:
            mapping.append(None)  # the new internal node has no mapping entry

        # Construct the merged tree before retiring its inputs, so a failure
        # here leaves the forest unchanged.
        merged = tree.Tree(M.tocoo(), computePaths)

        for k in (i, j):
            self.forest[k] = None
            self.mappings[k] = None
            self.roots[k] = 0
        return self.addTree(merged, root, mapping)

    @classmethod
    def _joinAdjacency(cls, A1, A2, r1, r2, l1, l2):
        """
        Build the symmetric adjacency matrix that joins two trees.

        Nodes of ``A1`` keep their indices. Nodes of ``A2`` are shifted by
        ``A1.shape[0]``. Returns ``(M, root, addedNode)``:

        - ``M`` is a ``dok_matrix``.
        - ``root`` is the merged tree's root index.
        - ``addedNode`` tells whether an extra last node (index ``n - 1``)
          was created to join the two roots.
        """
        off = A1.shape[0]
        s2 = r2 + off  # root of the second tree in merged numbering
        addedNode = l1 != 0 and l2 != 0
        n = off + A2.shape[0] + (1 if addedNode else 0)
        M = scipy.sparse.dok_matrix((n, n))
        # Copy every edge in both directions so M is symmetric.
        for (a, b), val in A1.todok().items():
            M[a, b] = val
            M[b, a] = val
        for (a, b), val in A2.todok().items():
            M[a + off, b + off] = val
            M[b + off, a + off] = val

        if addedNode:
            root = n - 1
            M[root, r1] = M[r1, root] = l1
            M[root, s2] = M[s2, root] = l2
        elif l1 == 0:
            # Tree 1's root is the merge point; tree 2 hangs off it at l2.
            w = l2 if l2 != 0 else cls._ZERO_EDGE_LENGTH
            M[r1, s2] = M[s2, r1] = w
            root = r1
        else:
            # Only l2 is zero: tree 2's root is the merge point.
            M[r1, s2] = M[s2, r1] = l1
            root = s2
        return M, root, addedNode

    def active(self):
        """Return ``(tree, mapping, root)`` for every tree not yet merged."""
        return [(t, m, r)
                for t, m, r in zip(self.forest, self.mappings, self.roots)
                if t is not None]
