import numpy as np
import scipy.sparse
import apgl.graph
import subprocess

def load_graph_from_file(f):
    """
    Load a graph from a text file holding one adjacency-matrix row per line.

    Entries on a line are whitespace-separated numbers; blank lines are
    skipped. All-pairs shortest paths are not computed (``computePaths`` is
    False), so ``paths``-based methods need ``allpairsshortestpaths`` first.

    :param f: path of the file to read.
    :returns: a ``Graph`` built from the parsed matrix.
    """
    # split() (no argument) tolerates repeated spaces, tabs and the trailing
    # newline; split(" ") produced empty tokens that float() rejects.
    with open(f) as handle:
        rows = [[float(tok) for tok in line.split()]
                for line in handle if line.strip()]
    return Graph(adjMat=np.array(rows), computePaths=False)


class Graph(object):
    """
    An undirected, weighted graph stored as a sparse adjacency matrix.

    Attributes:
        adjMat: ``scipy.sparse.coo_matrix`` adjacency matrix. ``addEdge`` and
            ``delEdge`` update entries (i, j) and (j, i) together, so a matrix
            built through this API stays symmetric; a matrix passed to the
            constructor is used as given.
        paths: all-pairs shortest-path matrix, only present after
            ``allpairsshortestpaths`` has run; other methods treat ``np.inf``
            entries as unreachable pairs.
        degree: maximum vertex degree as of the last ``computeMaxDegree`` call.
    """

    def __init__(self, adjMat=None, computePaths=True):
        """
        Build a graph from a dense or sparse adjacency matrix.

        :param adjMat: square adjacency matrix; defaults to a fresh
            ``np.zeros([1, 1])`` (one vertex, no edges).
        :param computePaths: if True, fill ``self.paths`` right away via
            ``allpairsshortestpaths`` (which needs apgl).
        """
        if adjMat is None:
            adjMat = np.zeros([1, 1])
        self.adjMat = scipy.sparse.coo_matrix(adjMat)
        if computePaths:
            self.allpairsshortestpaths()
        self.computeMaxDegree()

    def _symmetricEntry(self, i, j, val):
        """
        Return an n x n COO matrix holding ``val`` at (i, j) and (j, i).

        A self-loop (i == j) is stored once: listing the same cell twice
        would make COO sum the duplicates into ``2 * val``.
        """
        n = self.adjMat.shape[0]
        if i == j:
            rows, cols, vals = [i], [j], [val]
        else:
            rows, cols, vals = [i, j], [j, i], [val, val]
        return scipy.sparse.coo_matrix((vals, (rows, cols)), shape=(n, n))

    def addEdge(self, i, j, val=1):
        """
        Add weight ``val`` to the undirected edge between vertices i and j.

        Weights accumulate: adding an existing edge increases its weight.
        """
        ## HACK: sparse storage cannot tell a zero-weight edge from a missing
        ## one, so nudge zero weights to a tiny positive value; this should
        ## not disturb results much.
        if val == 0:
            val = 10 ** -20
        self.adjMat = (self.adjMat + self._symmetricEntry(i, j, val)).tocoo()

    def addEdges(self, edges):
        """
        Add several edges.

        :param edges: iterable of ``(i, j)`` pairs (weight 1) and/or
            ``(i, j, weight)`` triples; items of any other length are ignored.
            Weighted triples are added first, then the unweighted pairs.
        """
        # Materialise first: the items are scanned twice below, which would
        # silently drop entries if ``edges`` were a one-shot generator.
        edges = list(edges)
        weighted = [e for e in edges if len(e) == 3]
        unweighted = [(e[0], e[1], 1) for e in edges if len(e) == 2]
        for i, j, val in weighted + unweighted:
            self.addEdge(i, j, val)

    def getEdge(self, i, j):
        """Return the stored weight at (i, j); 0 when there is no edge."""
        # A CSR conversion supports direct element lookup and avoids building
        # a full DOK dictionary for a single read.
        return self.adjMat.tocsr()[i, j]

    def addVertex(self, connections=None):
        """
        Append a new vertex and optionally connect it to existing ones.

        :param connections: mapping ``{neighbor: weight}`` of edges to add
            from the new vertex.
        :returns: index of the new vertex (the previous vertex count).
        """
        n = self.adjMat.shape[0]
        grown = scipy.sparse.csr_matrix(
            (self.adjMat.data, (self.adjMat.row, self.adjMat.col)),
            shape=(n + 1, n + 1))
        self.adjMat = grown.tocoo()
        for target, weight in (connections or {}).items():
            # Explicitly the base-class addEdge, as before.
            Graph.addEdge(self, n, target, weight)
        return n

    def neighbors(self, i):
        """Return the column indices of the nonzero entries in row i."""
        rows, cols = self.adjMat.nonzero()
        return cols[np.nonzero(rows == i)[0]]

    def numVertices(self):
        """Return the number of vertices (rows of the adjacency matrix)."""
        return self.adjMat.shape[0]

    def numEdges(self):
        """
        Return half the number of nonzero entries, as an int.

        For a symmetric matrix without self-loops this is the edge count.
        ``//`` keeps the result an int on Python 3.
        """
        return len(self.adjMat.nonzero()[0]) // 2

    def delEdge(self, i, j):
        """Remove the edge between i and j (a no-op if it does not exist)."""
        val = self.adjMat.tocsr()[i, j]
        self.adjMat = (self.adjMat - self._symmetricEntry(i, j, val)).tocoo()

    def computeMaxDegree(self):
        """Store the maximum vertex degree in ``self.degree`` (0 if no vertices)."""
        degrees = self.degree_dist()
        self.degree = max(degrees) if degrees else 0

    def findleaves(self):
        """Return the vertices of degree exactly 1 (isolated vertices excluded)."""
        return [v for v, d in enumerate(self.degree_dist()) if d == 1]

    def allpairsshortestpaths(self):
        """
        Fill ``self.paths`` using apgl's Floyd-Warshall implementation.

        The dense adjacency matrix is handed to apgl as its weight matrix.
        """
        vList = apgl.graph.GeneralVertexList(self.adjMat.shape[0])
        G = apgl.graph.SparseGraph(vList, W=scipy.sparse.dok_matrix(self.adjMat.shape))
        G.setWeightMatrix(self.adjMat.toarray())
        self.paths = G.floydWarshall()

    def connectedComponents(self):
        """
        Return a list of vertex sets, one per connected component.

        Uses ``self.paths``: vertices at finite distance from a seed share its
        component.
        """
        remaining = set(range(self.paths.shape[0]))
        comps = []
        while remaining:
            seed = remaining.pop()
            comp = set(np.nonzero(self.paths[seed, :] != np.inf)[0])
            remaining.difference_update(comp)
            comp.add(seed)
            comps.append(comp)
        return comps

    def edgeNodes(self):
        """Return degree-1 vertices among the first ``self.paths.shape[0]`` vertices."""
        return [x for x in range(self.paths.shape[0]) if len(self.neighbors(x)) == 1]

    def _levelSets(self, root):
        """
        Return ``(maxdist, levelsets)`` for BFS-style levels around ``root``.

        ``levelsets[k]`` is the set of vertices whose ``self.paths`` distance
        from ``root`` equals exactly k. Exact equality means this presumes
        integer path lengths (e.g. unit weights) -- NOTE(review): confirm.

        :raises Exception: "Graph is not connected" if some vertex is at
            infinite distance from ``root``.
        """
        dists = np.asarray(self.paths[root, :]).ravel()
        farthest = np.max(dists)
        # Check before int(): int(inf) raises OverflowError, which made the
        # intended "not connected" error unreachable.
        if np.isinf(farthest):
            raise Exception("Graph is not connected")
        maxdist = int(farthest)
        levelsets = [set(np.nonzero(dists == k)[0].tolist())
                     for k in range(maxdist + 1)]
        return maxdist, levelsets

    def _fillSharedPathMatrix(self, root, maxdist, descendents, childrenOf):
        """
        Second (fill) pass shared by ``sharedPathMatrix``/``sharedPathMatrix2``.

        Walking down from ``root`` with accumulated weight ``pl``, each visit
        sets M[start, start] = pl and, for every pair of distinct members of
        ``childrenOf(start, level) + [start]``, sets the cross blocks of their
        descendant lists to ``pl``; deeper visits then overwrite the blocks
        inside each subtree. ``start`` itself is also revisited one level
        deeper, with ``getEdge(start, start)`` added to ``pl``.
        """
        n = self.adjMat.shape[0]
        M = np.zeros((n, n))

        def dfsFill(start, level, pl):
            M[start, start] = pl
            if level == maxdist:
                return
            group = list(childrenOf(start, level)) + [start]
            # Every write here uses the same value, so each unordered pair
            # only needs to be visited once (both orientations written).
            for k, a in enumerate(group):
                for b in group[k + 1:]:
                    if a != b:
                        M[np.ix_(descendents[a], descendents[b])] = pl
                        M[np.ix_(descendents[b], descendents[a])] = pl
            for c in group:
                dfsFill(c, level + 1, self.getEdge(c, start) + pl)

        dfsFill(root, 0, 0)
        return M

    def sharedPathMatrix(self, root):
        """
        Compute the shared-path matrix of the DFS tree rooted at ``root``.

        For a tree without self-loops, M[a, b] is the accumulated
        ``getEdge`` weight from ``root`` to the deepest vertex whose subtree
        contains both a and b; M[a, a] is that weight down to a itself.
        Levels come from ``self.paths`` (see ``_levelSets``).

        :returns: ``(M, leaves)`` where ``leaves`` lists vertices with no
            neighbour one level further out.
        :raises Exception: if the graph is not connected.
        """
        maxdist, levelsets = self._levelSets(root)

        ## DFS stepping only to neighbours one level further out, recording
        ## each vertex's parent and the leaves of the search.
        parents = {root: None}
        leaves = set()

        def dfsDesc(start, level):
            if level == maxdist:
                leaves.add(start)
                return
            cands = [x for x in self.neighbors(start) if x in levelsets[level + 1]]
            if not cands:
                leaves.add(start)
            for c in cands:
                parents[c] = start
                dfsDesc(c, level + 1)

        dfsDesc(root, 0)

        ## Collect each vertex's descendants in the parent tree.
        descendents = {}

        def dfsDesc2(start):
            descs = set([start])
            for c in self.neighbors(start):
                if parents[c] == start:
                    descs |= dfsDesc2(c)
            descendents[start] = list(descs)
            return descs

        dfsDesc2(root)

        def childrenOf(start, level):
            return [x for x in sorted(levelsets[level + 1]) if parents[x] == start]

        M = self._fillSharedPathMatrix(root, maxdist, descendents, childrenOf)
        return (M, list(leaves))

    def sharedPathMatrix2(self, root):
        """
        Variant of ``sharedPathMatrix`` without an explicit parent map.

        Children of a vertex are its neighbours one level further out, so
        descendant sets are computed in the same DFS that finds the leaves.

        :returns: ``(M, leaves)`` as for ``sharedPathMatrix``.
        :raises Exception: if the graph is not connected.
        """
        maxdist, levelsets = self._levelSets(root)
        descendents = {}
        leaves = set()

        def nextLevel(start, level):
            return [x for x in self.neighbors(start) if x in levelsets[level + 1]]

        def dfsDesc(start, level):
            descs = set([start])
            if level == maxdist:
                leaves.add(start)
            else:
                cands = nextLevel(start, level)
                if not cands:
                    leaves.add(start)
                for c in cands:
                    descs |= dfsDesc(c, level + 1)
            descendents[start] = list(descs)
            return descs

        dfsDesc(root, 0)
        M = self._fillSharedPathMatrix(root, maxdist, descendents, nextLevel)
        return (M, list(leaves))

    def degree_dist(self):
        """Return a list with each vertex's degree (nonzero entries in its row)."""
        rows = self.adjMat.nonzero()[0]
        return np.bincount(rows, minlength=self.numVertices()).tolist()

    def save_as_dot(self, file, leaves=None):
        """
        Write the graph in Graphviz DOT format to ``file + ".dot"``.

        The graph is named after the last "/"-separated part of ``file``.
        Vertices in ``leaves`` (default: ``findleaves()``) are colored red.
        Each undirected edge is written once.
        """
        # Quote the graph ID: unquoted DOT IDs may not contain characters
        # such as '-' or '.', which are common in file names.
        name = file.rsplit("/", 1)[-1].replace('"', '\\"')
        if leaves is None:
            leaves = self.findleaves()
        lines = ['graph "%s"{\n' % name,
                 'size="7,6"\n',
                 'node[shape=circle,height=0.2,width=0.2,style=filled]\n']
        lines.extend("%d [color=red]\n" % l for l in leaves)
        # Symmetric storage holds (i, j) and (j, i); keep the upper copy.
        # keys() works on Python 3, unlike the Python 2-only iterkeys().
        for r, c in self.adjMat.todok().keys():
            if r < c:
                lines.append('"%d" -- "%d"\n' % (r, c))
        lines.append("}")
        with open(file + ".dot", "w") as f:
            f.writelines(lines)
