import numpy as np
import scipy.sparse
import matplotlib.pyplot as plt

"""
Library for matrix objects that keep track of how many observations they use.
Allows for sub-indexing and many other features. 
"""

class Matrix(object):
    """
    Thin wrapper around a 2-D array-like with heatmap and submatrix capability.

    This base class does no bookkeeping: ``markasused`` is a no-op hook that
    subclasses can override to record which entries were read.

    Attributes:
        data: the wrapped array, stored by reference (not copied).
        shape: ``arr.shape`` as captured at construction time.
    """

    def __init__(self, arr):
        """Wrap ``arr``, which must expose ``.shape`` and ``[i, j]`` indexing."""
        self.data = arr
        self.shape = arr.shape

    def get(self, i, j):
        """Return the entry at position ``(i, j)``."""
        return self.data[i, j]

    def submatrix(self, indices):
        """
        Return a new ``Matrix`` restricted to ``indices`` on both axes.

        ``np.ix_`` builds an outer-product index, so the result holds every
        ``(r, c)`` pair with ``r`` and ``c`` drawn from ``indices``.
        """
        return Matrix(self.data[np.ix_(indices, indices)])

    def row(self, i):
        """
        Return ``data[:, i]``.

        NOTE: despite the name, this slices along the second axis (i.e. it
        is column ``i``). For a symmetric matrix that equals row ``i``.
        """
        return self.data[:, i]

    def all(self):
        """Return the whole underlying array (the same object, not a copy)."""
        return self.data

    def markasused(self, xs, ys):
        """No-op hook; the base class does not track which entries are used."""
        return None

    def show(self, file=None):
        """
        Render the matrix as a heatmap and optionally save it to ``file``.

        The figure is always cleared and closed before returning, so nothing
        stays open; when ``file`` is None the image is rendered and discarded.
        """
        fig = plt.figure()
        try:
            axes = fig.add_subplot(111)
            axes.imshow(self.data)
            if file is not None:
                plt.savefig(file)
        finally:
            # Release the figure even if rendering or saving raised.
            fig.clf()
            plt.close(fig)

class CountingMatrix(Matrix):
    """
    Counting matrices keep track of which entries have been queried by the algorithm.
    Allows for fetching a submatrix (without seeing the entries), querying for an entry,
    and other operations.

    Attributes:
        data: the wrapped array, stored by reference.
        used: tracker with the same shape as ``data``; an entry is set to 1
            once it has been marked. Marks are always mirrored, i.e. marking
            ``(x, y)`` also marks ``(y, x)``, so ``data`` is expected to be
            square.
        shape: ``data.shape``.
        parent: the ``CountingMatrix`` this one was cut from via
            ``submatrix``, or None.
        mapping: when ``parent`` is set, ``mapping[k]`` is the parent index
            corresponding to local index ``k``.
    """

    def __init__(self, arr, type=scipy.sparse.dok_matrix):
        """
        Wrap ``arr`` and create an all-zero usage tracker.

        ``type`` selects the tracker class: ``np.matrix`` gives a dense zero
        matrix, anything else is called as ``type(arr.shape)`` (the default is
        ``scipy.sparse.dok_matrix``).
        """
        self.data = arr
        self.shape = self.data.shape
        if type is np.matrix:
            self.used = np.matrix(np.zeros(arr.shape))
        else:
            self.used = type(arr.shape)
        self.parent = None
        self.mapping = None

    def get(self, i, j):
        """Mark entry ``(i, j)`` as used and return it."""
        self.markasused([i], [j])
        return self.data[i, j]

    def numUsed(self):
        """
        Return the number of marked entries on or above the diagonal.

        Marks are mirrored by ``markasused``, so restricting to the upper
        triangle counts each unordered pair once.
        """
        return scipy.sparse.triu(self.used).getnnz()

    def submatrix(self, indices):
        """
        Return a child ``CountingMatrix`` restricted to ``indices`` on both axes.

        Creating the child marks nothing. Entries later marked on the child
        are also marked on this matrix through ``parent`` / ``mapping``.
        """
        child = CountingMatrix(self.data[np.ix_(indices, indices)],
                               type=type(self.used))
        child.parent = self
        child.mapping = indices
        return child

    def row(self, i):
        """
        Mark every entry of ``data[:, i]`` as used and return that slice.

        ``i`` may be a single index or a sequence of indices.
        """
        # markasused needs sequences; a bare index used to fail on len().
        selected = [i] if np.isscalar(i) else i
        self.markasused(selected, range(self.data.shape[0]))
        return self.data[:, i]

    def all(self):
        """Mark every entry as used and return the underlying array."""
        self.markasused(range(self.data.shape[0]), range(self.data.shape[1]))
        return self.data

    def markasused(self, xs, ys):
        """
        Mark every pair ``(x, y)`` with ``x`` in ``xs`` and ``y`` in ``ys``,
        together with the mirrored pair ``(y, x)``, as used.

        The marks are propagated to ``parent`` (if any) after translating the
        indices through ``mapping``.
        """
        xs = list(xs)
        ys = list(ys)
        if not xs or not ys:
            return
        # np.ix_ selects the full xs-by-ys block for dense and sparse trackers
        # alike. Plain ``used[xs, ys]`` would pair the two sequences
        # element-wise instead of taking their outer product.
        self.used[np.ix_(xs, ys)] = 1
        self.used[np.ix_(ys, xs)] = 1
        if self.parent is not None:
            parent_xs = [self.mapping[x] for x in xs]
            parent_ys = [self.mapping[y] for y in ys]
            self.parent.markasused(parent_xs, parent_ys)

    def getall(self, W):
        """
        Given a bernoulli matrix, return a matrix with my entries where W is 1
        and zeros everywhere else

        Every position where ``W`` equals 1 is marked as used. The result is
        the element-wise product of ``data`` and ``W`` as an ``np.matrix``.
        """
        mask = np.array(W)
        out = np.matrix(np.array(self.data) * mask)
        # Visit only the selected cells; argwhere yields them in row-major order.
        for i, j in np.argwhere(mask == 1):
            self.markasused([int(i)], [int(j)])
        return out

    def show(self, file=None):
        """
        Render a heatmap of the entries marked as used (unmarked entries show
        as zero) and optionally save it to ``file`` as EPS at 200 dpi.

        NOTE(review): the figure is left open on return, as in the original
        code; callers that do not display it should close it themselves.
        """
        # A dense np.matrix tracker has no toarray(); only sparse trackers do.
        if scipy.sparse.issparse(self.used):
            used = self.used.toarray()
        else:
            used = np.asarray(self.used)
        fig = plt.figure()
        axes = fig.add_subplot(111)
        axes.imshow(np.array(self.data) * used)
        if file is not None:
            plt.savefig(file, dpi=200, format="eps")
