#! /bin/python3

##########
## Version: Apr, 2017
## Author: Tsubasa Takahashi
## Note: This work is done in author's visiting at Carnegie Mellon University.
##########

import math
import numpy as np
import matplotlib.pyplot as plt
import sys
import copy
from viz.mpnorm import MidpointNormalize as mpnorm
from more_itertools import pairwise

def viz_heatmap(fact_mat, fname=None, transpose=False, xlabel=None, ylabel=None, centoring=True):
    """Draw a 2-D matrix as a cell heatmap with 1-based row/column tick labels.

    Cells equal to 0 are masked, so they show the grey axes background instead
    of a colour. Rows are drawn top-to-bottom and the x-axis sits on top, so the
    picture reads like the printed matrix.

    Args:
        fact_mat: 2-D array-like to plot (anything ``np.asanyarray`` accepts).
        fname: If not None, the figure is also saved to this path.
        transpose: Plot the transpose of ``fact_mat`` instead.
        xlabel: Optional x-axis label (drawn on top).
        ylabel: Optional y-axis label.
        centoring: If True, use the diverging ``'seismic'`` colormap with a
            ``MidpointNormalize`` whose midpoint is 0 and whose range is
            ``[-absmax, absmax]`` (absmax = largest absolute entry). If False,
            use the ``'binary'`` colormap with matplotlib's default scaling.
            (Parameter name kept as-is for backward compatibility.)

    Returns:
        ``(fig, ax)`` of the created figure, so callers can adjust, show or
        close it.

    Raises:
        ValueError: If the input is not two-dimensional.
    """
    F = np.asanyarray(fact_mat)
    if transpose:
        F = F.T
    if F.ndim != 2:
        raise ValueError("viz_heatmap expects a 2-D matrix, got shape %r" % (F.shape,))
    n_rows, n_cols = F.shape

    # Create exactly one figure so that figsize/dpi apply to the figure that is
    # actually drawn on (and no empty figure is leaked).
    fig, ax = plt.subplots(figsize=(4, 4), dpi=96)

    # Zero entries are masked so they are left uncoloured.
    masked = np.ma.masked_equal(F, 0)
    if centoring:
        # Symmetric range around 0 so positive and negative values of equal
        # magnitude get equally intense colours.
        absmax = np.abs(F).max()
        norm = mpnorm(midpoint=0.0, vmin=-absmax, vmax=absmax)
        mesh = ax.pcolor(masked, cmap='seismic', norm=norm, edgecolor='k')
    else:
        mesh = ax.pcolor(masked, cmap='binary', edgecolor='k')

    # pcolor cell (i, j) spans [j, j+1] x [i, i+1]; put ticks at cell centres
    # and label them 1-based.
    ax.set_xticks(np.arange(n_cols) + 0.5, minor=False)
    ax.set_yticks(np.arange(n_rows) + 0.5, minor=False)
    ax.set_xticklabels(list(range(1, n_cols + 1)), minor=False, fontsize='large')
    ax.set_yticklabels(list(range(1, n_rows + 1)), minor=False, fontsize='large')

    if xlabel is not None:
        ax.set_xlabel(xlabel, fontsize='x-large')
    if ylabel is not None:
        ax.set_ylabel(ylabel, fontsize='x-large')
    ax.axis('tight')

    # Matrix orientation: first row at the top, column labels above the plot.
    ax.invert_yaxis()
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position('top')

    # Background colour visible through masked (zero) cells. set_axis_bgcolor
    # was removed in matplotlib 2.2; set_facecolor replaces it (added in 2.0).
    set_background = getattr(ax, 'set_facecolor', None) or ax.set_axis_bgcolor
    set_background('.7')

    fig.colorbar(mesh, ax=ax)

    if fname is not None:
        fig.savefig(fname, bbox_inches="tight", pad_inches=0.1)

    return fig, ax
