import tree

import numpy as np
import scipy.sparse

class Subtree(tree.Tree):
    """
    A subtree is more or less a tree, except that it allows for prunings,
    which Trees do not. Subtrees do NOT keep track of paths (``__init__``
    does not call ``tree.Tree.__init__``).

    Node indices are shared with the tree the subtree was built from: pruning
    keeps the adjacency matrix at its original n x n shape and only drops
    edges, so node i here is node i in the source tree.

    local state:
    adjMat = sparse matrix of weighted edges (the source tree's matrix until
             the first pruning, a coo matrix afterwards)
    degree = the maximum degree in the subtree
    existent = the set of nodes that are existent in the tree

    Relies on the inherited ``neighbors(node)`` and ``computeMaxDegree()``;
    presumably both read ``self.adjMat`` -- TODO confirm in tree.Tree.
    """

    def __init__(self, tree):
        # Note: the parameter `tree` shadows the module of the same name inside
        # this method only; the module is needed solely for the base class.
        # The adjacency is shared, not copied. That is safe because the pruning
        # methods below rebind self.adjMat to a fresh matrix instead of
        # mutating it in place.
        self.adjMat = tree.adjMat
        self.degree = tree.degree
        # Initially every node index is considered existent.
        self.existent = set(range(self.adjMat.shape[0]))

    def _keepComponent(self, seeds, initialExistent):
        """Rebuild adjMat keeping only the edges reached from `seeds`.

        seeds: iterable of (node, parentNode) pairs. The edge between each
            pair is kept, and the walk continues from `node` into every
            neighbor other than `parentNode`.
        initialExistent: nodes marked existent before the walk starts.

        Afterwards, self.existent holds initialExistent plus every visited
        node. self.adjMat is replaced by a coo matrix of the kept edges, and
        the maximum degree is recomputed.
        """
        self.existent = set(initialExistent)
        n = self.adjMat.shape[0]
        original = self.adjMat.todok()
        # Keep the source dtype; dok_matrix would otherwise default to float64
        # and silently convert integer/bool edge weights.
        kept = scipy.sparse.dok_matrix((n, n), dtype=original.dtype)
        stack = list(seeds)
        # The walk queries self.neighbors() on the *current* adjacency, so
        # self.adjMat must not be replaced until the walk is finished.
        # No visited set is needed: in a tree, excluding the node we came from
        # is enough to avoid revisiting nodes.
        while stack:
            curr, currp = stack.pop()
            self.existent.add(curr)
            kept[curr, currp] = original[curr, currp]
            kept[currp, curr] = original[currp, curr]
            stack.extend((x, curr) for x in self.neighbors(curr) if x != currp)
        self.adjMat = kept.tocoo()
        self.computeMaxDegree()

    def deleteUnreachable(self, parent, root):
        """Keep only the branch that hangs off `root` away from `parent`.

        The root-parent edge and everything reachable from `root` without
        passing through `parent` are kept. `parent` stays existent; all of its
        other edges, and every other node, are dropped.
        """
        self._keepComponent([(root, parent)], [parent])

    def deleteSubtrees(self, root, strees):
        """Remove the branches of `root` that start at a node in `strees`.

        `root` stays existent. Every neighbor of `root` not in `strees` is
        kept, along with everything beyond it (walking away from `root`).
        Branches whose first node is in `strees` are dropped entirely.
        """
        seeds = [(x, root) for x in self.neighbors(root) if x not in strees]
        self._keepComponent(seeds, [root])

    def connected(self):
        """Return True if all nodes with an incident edge form one component.

        "Edges" are the nonzero entries of self.adjMat, so explicitly stored
        zeros are ignored. A matrix with no edges counts as connected. Works
        with any scipy sparse format, because only ``nonzero()`` is used.
        """
        rows, cols = self.adjMat.nonzero()
        # Adjacency lists keyed by the row index of each nonzero entry.
        adjacency = {}
        for r, c in zip(rows.tolist(), cols.tolist()):
            adjacency.setdefault(r, []).append(c)
        if not adjacency:
            return True
        # Start from the smallest node with an edge. Use an iterative DFS so
        # deep trees (long paths) cannot exceed Python's recursion limit.
        start = min(adjacency)
        visited = {start}
        stack = [start]
        while stack:
            node = stack.pop()
            for nxt in adjacency.get(node, ()):
                if nxt not in visited:
                    visited.add(nxt)
                    stack.append(nxt)
        return len(visited) == len(adjacency)
