import graph
import numpy as np
import scipy.sparse
import copy

class Tree(graph.Graph):
    """
    A Tree object. Maintains the invariant that it is a tree, i.e. a connected
    graph without any cycles: a vertex can only be added together with the
    single edge that connects it to the tree, and edges cannot be added on
    their own. Because of this structure the all-pairs shortest-path matrix
    (self.paths) is updated incrementally when vertices are added instead of
    being recomputed.

    Written to run on both Python 2 and Python 3.
    """
    def __init__(self, adjMat=None, computePaths=True):
        """
        Build a tree from an adjacency matrix.

        adjMat -- adjacency matrix forwarded to graph.Graph.__init__; when
                  omitted a fresh empty 0x0 matrix is used.
        computePaths -- forwarded unchanged to graph.Graph.__init__.

        Raises AssertionError if the graph is not connected or has a cycle.
        """
        if adjMat is None:
            # Fresh array per call instead of one ndarray shared by every
            # default-constructed Tree.
            adjMat = np.zeros([0, 0])
        graph.Graph.__init__(self, adjMat, computePaths)
        assert self.connected(), "tree is not connected"
        assert not self.cycleCheck(), "tree initialized with cycles"


    def addEdge(self, i, j, val=1):
        """Always raises: an edge between two vertices of a tree would create a cycle."""
        raise Exception('addEdge is an invalid operation on Trees')

    def addEdges(self, edges):
        """Always raises: extra edges would break the tree invariant."""
        raise Exception('addEdges is an invalid operation on Trees')

    def addVertex(self, connections=None):
        """
        Add a new leaf attached to exactly one existing vertex.

        connections -- dict with exactly one item, {parent: edgeLength}.

        The new leaf's distance to every old vertex is edgeLength plus the
        parent's distance to it, so self.paths is extended in place rather
        than recomputed. self.degree is raised if the parent's degree now
        exceeds it. Returns the index of the new vertex.

        Raises AssertionError unless exactly one connection is given.
        """
        if connections is None:
            connections = {}
        assert len(connections) == 1, "Attempt to add node with either too many or too few connections"
        node = graph.Graph.addVertex(self, connections)

        ## update paths: every route out of the new leaf goes through its parent.
        # next(iter(...)) works on Python 2 and 3; dict.items()[0] is Python 2 only.
        parent, dist = next(iter(connections.items()))
        paths = np.zeros(self.adjMat.shape)
        paths[0:node, 0:node] = self.paths
        paths[node, 0:node] = paths[parent, 0:node] + dist
        paths[0:node, node] = paths[0:node, parent] + dist
        self.paths = paths

        ## update degree: only the parent gained an edge (the new leaf has one)
        parentDegree = len(self.neighbors(parent))
        if parentDegree > self.degree:
            self.degree = parentDegree

        ## and return the new node index
        return node

    def addVertexBetween(self, n1, n2, d1):
        """
        Subdivide the existing edge (n1, n2) by inserting a new vertex on it.

        n1, n2 -- endpoints of an existing edge.
        d1 -- length of the new edge n1-new; the edge new-n2 gets the
              remainder of the length reported by getEdge(n1, n2).

        self.paths is extended incrementally and self.degree is raised to at
        least 2. Returns the index of the new vertex.

        Raises AssertionError if there is no edge between n1 and n2.
        """
        assert self.adjMat.tocsr()[n1, n2] != 0, 'Attempt to insert node where no edge exists'
        d2 = self.getEdge(n1, n2) - d1
        graph.Graph.delEdge(self, n1, n2)
        n3 = graph.Graph.addVertex(self)
        graph.Graph.addEdge(self, n1, n3, d1)
        graph.Graph.addEdge(self, n2, n3, d2)

        ## update paths. Cutting edge (n1, n2) splits the old vertices into the
        ## side reached through n1 and the side reached through n2; the new
        ## vertex reaches each side through the matching endpoint. Ties (only
        ## possible with a zero-length edge) go to n1 so that every old vertex
        ## gets a distance instead of being left at 0.
        n1s = [i for i in range(n3) if self.paths[n1, i] <= self.paths[n2, i]]
        n2s = [i for i in range(n3) if self.paths[n1, i] > self.paths[n2, i]]
        paths = np.zeros(self.adjMat.shape)
        paths[0:n3, 0:n3] = self.paths
        paths[n3, n1s] = paths[n1, n1s] + d1
        paths[n1s, n3] = paths[n1s, n1] + d1
        paths[n3, n2s] = paths[n2, n2s] + d2
        paths[n2s, n3] = paths[n2s, n2] + d2
        self.paths = paths

        ## update degree. n1 and n2 each lost one edge and gained one, so only
        ## the inserted vertex (degree 2) can change the maximum.
        if self.degree < 2:
            self.degree = 2

        return n3

    def findLeavesBelow(self, root, parent):
        """
        Return the leaves of the part of the tree reachable from `root`
        without passing through `parent`.

        root -- vertex to start from.
        parent -- neighbour of root that must not be crossed, or None to
                  search the whole tree.

        A vertex counts as a leaf when it has no unexplored neighbours left.
        If root has no neighbours at all, [root] is returned.
        """
        if len(self.neighbors(root)) == 0:
            ## this is a degenerate case
            return [root]
        visited = set()
        if parent is not None:
            visited.add(parent)
        stack = [root]
        leaves = []
        while stack:
            i = stack.pop()
            visited.add(i)
            # Filter in plain Python: np.setdiff1d treats a Python set as one
            # opaque object element, so it never removed visited vertices and
            # the search walked back and forth forever. Sorted to keep the
            # ascending order setdiff1d produced.
            unexplored = sorted(n for n in self.neighbors(i) if n not in visited)
            if unexplored:
                stack.extend(unexplored)
            else:
                leaves.append(i)
        return leaves

    def cycleCheck(self):
        """
        Return True if a cycle is reachable from vertex 0, else False.

        Walks self.adjMat's COO row/col entries depth-first; reaching an
        already-visited vertex along any edge other than the one just used
        means a cycle. An empty graph has no cycles. The result is also
        stored in self.found, as the previous implementation did.
        """
        self.found = False
        if self.adjMat.shape[0] == 0:
            return self.found
        # Adjacency lists are built once instead of rescanning every stored
        # entry for each visited vertex.
        adjacency = {}
        for r, c in zip(self.adjMat.row, self.adjMat.col):
            adjacency.setdefault(r, []).append(c)
        visited = set()
        # Explicit stack instead of recursion so deep trees do not hit the
        # interpreter's recursion limit.
        stack = [(0, None)]
        while stack and not self.found:
            node, prev = stack.pop()
            if node in visited:
                # reached a second time along a different edge
                self.found = True
                break
            visited.add(node)
            for c in adjacency.get(node, []):
                if c == prev:
                    continue
                if c in visited:
                    self.found = True
                    break
                stack.append((c, node))
        return self.found

    def connected(self):
        """
        Return True if every vertex is reachable from vertex 0.

        An empty graph counts as connected.
        """
        n = self.adjMat.shape[0]
        if n == 0:
            return True
        visited = set([0])
        stack = [0]
        # Iterative DFS so deep trees do not hit the recursion limit.
        while stack:
            node = stack.pop()
            for c in self.neighbors(node):
                if c not in visited:
                    visited.add(c)
                    stack.append(c)
        return len(visited) == n

    def getRoot(self):
        """
        Look for an edge that splits the leaves in a balanced way.

        Starting from the first leaf reported by self.findleaves(), walk the
        tree depth-first over directed edges (parent -> pointer). For each
        one, count the leaves on pointer's side with findLeavesBelow. Return
        the first (pointer, parent) pair whose count lies in
        [s/(k+1), s*k/(k+1)], where s is the number of leaves and k is
        self.degree. Returns None if no edge qualifies.
        """
        leaves = self.findleaves()
        s = float(len(leaves))
        k = self.degree
        lower = s / (k + 1)
        upper = s * k / (k + 1)
        stack = [(x, leaves[0]) for x in self.neighbors(leaves[0])]
        while stack:
            pointer, parent = stack.pop()
            below = len(self.findLeavesBelow(pointer, parent))
            if lower <= below <= upper:
                return (pointer, parent)
            stack.extend((x, pointer) for x in self.neighbors(pointer) if x != parent)
        return None

    def sharedPaths(self, root=0):
        """
        Return an n x n matrix of shared root-path lengths.

        With the tree rooted at `root`, M[i, j] is the length of the part that
        the paths root->i and root->j have in common, i.e. the distance from
        root to the lowest common ancestor of i and j (edge lengths come from
        getEdge). The diagonal therefore holds each vertex's distance from
        root, and an ancestor/descendant pair gets the ancestor's distance.

        root -- vertex to root the tree at (default 0).
        """
        n = self.adjMat.shape[0]
        M = np.zeros((n, n))
        descendants = {}

        def collectDescendants(start, parent):
            # every vertex in start's subtree, start included
            descs = [start]
            for c in self.neighbors(start):
                if c != parent:
                    descs.extend(collectDescendants(c, start))
            descendants[start] = descs
            return descs
        collectDescendants(root, None)

        def fill(start, parent, pl):
            # pl is the distance from root to start
            children = [x for x in self.neighbors(start) if x != parent]
            # start is the common ancestor of itself and everything below it
            # (this also sets the diagonal entry M[start, start])
            M[start, descendants[start]] = pl
            M[descendants[start], start] = pl
            # start is the lowest common ancestor of any two vertices that lie
            # in different child subtrees
            for c in children:
                for c2 in children:
                    if c != c2:
                        M[np.ix_(descendants[c], descendants[c2])] = pl
            for c in children:
                fill(c, start, self.getEdge(c, start) + pl)
        fill(root, None, 0)
        return M

    def extendTree(self, p, val=1):
        """
        Make each leaf of self a hub with p new vertices attached.

        p -- number of new vertices hung off every leaf.
        val -- weight of each new edge (default 1, the previous fixed value).

        The new vertices are numbered after the existing ones, leaf by leaf.
        Returns a new Tree object; self is not modified.
        """
        leaves = self.findleaves()
        n = self.adjMat.shape[0]
        newsize = n + len(leaves) * p
        M = scipy.sparse.csr_matrix(
            (self.adjMat.data, (self.adjMat.row, self.adjMat.col)),
            shape=(newsize, newsize))
        M = M.todok()
        index = n
        for leaf in leaves:
            for j in range(index, index + p):
                M[leaf, j] = val
                M[j, leaf] = val
            index += p
        return Tree(M)

    def path(self, i, j):
        """
        Compute the path from i to j in the tree.

        Returns a pair (vertices, edgeLengths). vertices lists the path from i
        to j inclusive. edgeLengths[k] is the weight of the edge entering
        vertices[k] (0 for i itself), so its cumulative sums give the
        distances from i. Returns ([], []) if j is not reachable from i.
        """
        T = self.adjMat.todok()
        # Record each vertex's predecessor instead of copying the whole path
        # prefix for every vertex reached.
        parentOf = {i: None}
        stack = [i]
        while stack and j not in parentOf:
            curr = stack.pop()
            for node in self.neighbors(curr):
                if node not in parentOf:
                    parentOf[node] = curr
                    stack.append(node)
        if j not in parentOf:
            return ([], [])
        vertices = []
        lengths = []
        node = j
        while node is not None:
            prev = parentOf[node]
            vertices.append(node)
            lengths.append(0 if prev is None else T[prev, node])
            node = prev
        vertices.reverse()
        lengths.reverse()
        return (vertices, lengths)
