#! /bin/python3

##########
## Version: Apr, 2017
## Author: Tsubasa Takahashi
## Note: This work was done during the author's visit to Carnegie Mellon University.
##########

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import *
from viz.viz_heatmap import viz_heatmap
import os
import math

# Module-wide line color palette; callers index it as colors[j % 10].
colors = [
    'm', 'c', 'orange', 'brown', 'b',
    'g', 'gray', 'r', 'purple', 'black',
]
# Legend switch read by viz_decomp_all and viz_signal_pair.
CaptionOn = True
    
def viz_decomp_all(models, fname, labels=None, horizontal_stack=True, order_label=True, period_type=None):
    """Plot each model in its own stacked subplot and save the figure to *fname*.

    Args:
        models: sequence of signal arrays. Each is either a single signal
            (1-D, or a single row) or 2-D with one signal per column; all are
            drawn against ``x = 0 .. len(models[0]) - 1``.
        fname: output path handed to ``plt.savefig``.
        labels: with ``order_label`` True, one label prefix per model,
            rendered as ``$<prefix>_<j>$`` for column ``j``; otherwise one
            label per column. Defaults to ``np.arange(models[0].shape[1])``.
        horizontal_stack: unused; kept for backward compatibility.
        order_label: selects between the two labelling schemes above.
        period_type: 0, 1 or 2 selects yearly tick labels and the x-axis
            caption; any other value keeps default ticks and "Time".

    When exactly two models are given they are titled "Regular Patterns"
    and "Outliers". If ``CaptionOn`` is set, a legend is drawn only on the
    second subplot, so a single-model call produces no legend.
    """
    # Per-model line style and opacity; indexed modulo their length so any
    # number of models can be drawn without an IndexError.
    markers = ['-', '-', '-']
    alphas = [1.0, 1, 1]
    if labels is None:
        labels = np.arange(models[0].shape[1])

    plt.figure(figsize=(18, 10), dpi=96)
    n_models = len(models)
    plt.rcParams['font.size'] = 20
    labs, mloc = _get_xaxis_labels(period_type)

    titles = ["Regular Patterns", "Outliers"] if n_models == 2 else None

    x = np.arange(0, len(models[0]))
    for ii, model in enumerate(models, 1):
        plt.subplot(n_models, 1, ii)
        ax = plt.gca()

        if titles is not None:
            plt.title(titles[ii - 1])

        # Normalize a single signal into one (n, 1) column so that
        # model.shape[1] counts signals and model.T yields exactly one series.
        if np.matrix(model).shape[0] == 1:
            model = np.asarray(model).reshape(-1, 1)

        if mloc is not None and labs is not None:
            ax.xaxis.set_major_locator(mloc)
            ax.set_xticklabels(labs)

        ncol = model.shape[1]
        marker = markers[(ii - 1) % len(markers)]
        alpha = alphas[(ii - 1) % len(alphas)]

        for jj, sig in enumerate(model.T):
            if order_label:
                label = r'$%s_%d$' % (labels[ii - 1], jj)
            else:
                label = labels[jj]
            plt.plot(x, sig, marker, color=colors[jj % 10],
                     label=label,
                     alpha=alpha, markersize=6)

        if CaptionOn and ii == 2:
            plt.legend(loc='best',
                       borderaxespad=0.,
                       ncol=_set_ncol_legend(ncol),
                       numpoints=1)
        plt.axis('tight')

    plt.xlabel(_get_xlabel_string(period_type))
    plt.savefig(fname, bbox_inches="tight", pad_inches=0.1)
    plt.close()


def viz_signal_pair(models, fname, labels=None, period_type=None, markers=None, alphas=None):
    """Overlay an original signal set and its estimate on one axes and save it.

    Args:
        models: pair ``(original, estimated)``; each is a single signal
            (1-D) or 2-D with one signal per column.
        fname: output path handed to ``plt.savefig``.
        labels: legend labels for the estimated columns (original columns
            are left out of the legend). Cycled if shorter than the number
            of columns. Defaults to ``np.arange(models[0].shape[1])``.
        period_type: 0, 1 or 2 selects yearly tick labels and the x-axis
            caption; any other value keeps default ticks and "Time".
        markers, alphas: two-element style overrides for (original,
            estimated); any other value falls back to ``['o', '-']`` and
            ``[0.2, 1.0]``.
    """
    if markers is None or len(markers) != 2:
        markers = ['o', '-']
    if alphas is None or len(alphas) != 2:
        alphas = [0.2, 1.0]
    if labels is None:
        labels = np.arange(models[0].shape[1])
    labels_oe = ['Original', 'Estimated']

    plt.rcParams['font.size'] = 18
    plt.figure(figsize=(18, 6), dpi=96)
    labs, mloc = _get_xaxis_labels(period_type)

    # Everything shares one axes, so configure ticks and the x label once.
    ax = plt.gca()
    if labs is not None and mloc is not None:
        ax.xaxis.set_major_locator(mloc)
        ax.set_xticklabels(labs)
    plt.xlabel(_get_xlabel_string(period_type))

    for ii, model in enumerate(models, 1):
        x = np.arange(0, len(model))
        ncol = 1 if np.matrix(model).shape[0] == 1 else model.shape[1]

        for jj in range(ncol):
            if ncol == 1:
                # A lone signal is named Original/Estimated and colored per model.
                sig = model
                lab = labels_oe[ii - 1]
                color = colors[ii - 1]
                markersize = 10
            else:
                sig = model[:, jj]
                color = colors[jj % 10]
                if ii == 1:
                    # Original columns stay out of the legend; shrink markers when crowded.
                    lab = None
                    markersize = 6 if ncol > 6 else 8
                else:
                    # Only the colors cycle every 10 columns; labels follow the column.
                    lab = labels[jj % len(labels)]
                    markersize = 10
            plt.plot(x, sig, markers[ii - 1], color=color,
                     label=lab,
                     alpha=alphas[ii - 1], markersize=markersize)

        if CaptionOn:
            plt.legend(loc='best',
                       borderaxespad=0.,
                       ncol=_set_ncol_legend(ncol),
                       numpoints=1)

    plt.axis('tight')
    plt.savefig(fname, bbox_inches="tight", pad_inches=0.1)
    plt.close()

def _nan_spans(sig):
    """Return ``(start, stop)`` index pairs, one per maximal run of NaNs in *sig*.

    ``stop`` is the index of the first non-NaN sample after the run, or
    ``len(sig)`` when the run extends to the end.
    """
    spans = []
    start = None
    for k, s in enumerate(sig):
        if math.isnan(s):
            if start is None:
                start = k
        elif start is not None:
            spans.append((start, k))
            start = None
    if start is not None:
        spans.append((start, len(sig)))
    return spans


def viz_signal_pair_missing(models, fname, labels=None, period_type=None):
    """Overlay original and estimated signals, shading missing (NaN) stretches gray.

    Args:
        models: pair ``(original, estimated)``; each is a single signal
            (1-D) or 2-D with one signal per column.
        fname: output path handed to ``plt.savefig``.
        labels: unused; kept for backward compatibility.
        period_type: 0, 1 or 2 selects yearly tick labels and the x-axis
            caption; any other value keeps default ticks and "Time".

    Each model contributes one legend entry ("Original" / "Estimated"),
    attached to its first column.
    """
    markers = ['o', '-']
    alphas = [0.2, 1.0]
    labels_oe = ['Original', 'Estimated']
    plt.rcParams['font.size'] = 18
    plt.figure(figsize=(18, 4))
    labs, mloc = _get_xaxis_labels(period_type)

    # Everything shares one axes, so configure ticks and axis labels once.
    ax = plt.gca()
    if labs is not None and mloc is not None:
        ax.xaxis.set_major_locator(mloc)
        ax.set_xticklabels(labs)
    plt.ylabel('Volume @ time')
    plt.xlabel(_get_xlabel_string(period_type))

    for ii, model in enumerate(models, 1):
        x = np.arange(0, len(model))
        ncol = 1 if np.matrix(model).shape[0] == 1 else model.shape[1]
        color = 'c' if ii == 1 else 'm'

        for jj in range(ncol):
            sig = model if ncol == 1 else model[:, jj]

            # Draw gray zones over missing values.
            for x1, x2 in _nan_spans(sig):
                plt.axvspan(x1, x2, facecolor='gray', alpha=0.3)

            # Label only the first column to avoid duplicate legend entries.
            lab = labels_oe[ii - 1] if jj == 0 else None
            plt.plot(x, sig, markers[ii - 1], color=color,
                     label=lab,
                     alpha=alphas[ii - 1], markersize=6)

    plt.legend(loc=1,
               borderaxespad=0.,
               prop={'size': 16},
               ncol=1,
               numpoints=1)

    plt.savefig(fname, bbox_inches="tight", pad_inches=0.1)
    plt.close()

def viz_errors(models, fname):
    """Scatter original vs. fitted values against the identity line and save it.

    Args:
        models: pair ``(original, fitted)`` of equally sized value arrays.
            Both axes are fixed to [0, 1], so values are presumably
            normalized -- NOTE(review): confirm with callers.
        fname: output path handed to ``plt.savefig``.
    """
    original, fitted = models[0], models[1]

    # Create exactly one figure so the plt.close() below releases everything
    # this function allocates.
    plt.figure(figsize=(5, 5), dpi=96)
    plt.xlabel("Original Data")
    plt.ylabel("Fitting Data")
    plt.ylim(0, 1)
    plt.xlim(0, 1)

    plt.plot(original, fitted, 'o', color='m', markersize=6, alpha=0.3)
    ideal = np.arange(0, 1.1, 0.1)
    plt.plot(ideal, ideal, '-', color='c', label='Ideal')

    plt.legend(loc='best',
               borderaxespad=0.,
               numpoints=1)

    plt.savefig(fname, bbox_inches="tight", pad_inches=0.1)
    plt.close()


def viz_seasonal(models, fname, labels=None):
    """Plot each seasonal component set in its own stacked subplot and save it.

    Draws into the current figure; no new figure is created here.

    Args:
        models: sequence of signal arrays; each is a single signal (1-D, or
            a single row) or 2-D with one signal per column, drawn against
            ``x = 0 .. len(models[0]) - 1``.
        fname: output path handed to ``plt.savefig``.
        labels: one label prefix per model, rendered as ``$<prefix>_<j>$``
            with 1-based ``j``. Defaults to ``np.arange(models[0].shape[1])``.
    """
    # Local palette (shadows the module-level ``colors``); cycled modulo its
    # own length so the 9th and 10th signals no longer raise IndexError.
    colors = ['r', 'g', 'b', 'black', 'orange', 'm', 'c', 'y']
    alpha = 1.0
    if labels is None:
        labels = np.arange(models[0].shape[1])

    n_models = len(models)
    plt.rcParams['font.size'] = 20

    x = np.arange(0, len(models[0]))
    for ii, model in enumerate(models, 1):
        plt.subplot(n_models, 1, ii)

        # Normalize a single signal into one (n, 1) column so that
        # model.shape[1] counts signals and model.T yields exactly one series.
        if np.matrix(model).shape[0] == 1:
            model = np.asarray(model).reshape(-1, 1)

        plt.tick_params(axis='both', which='major', labelsize=20)
        ncol = model.shape[1]

        for jj, sig in enumerate(model.T):
            label = r'$%s_%d$' % (labels[ii - 1], jj + 1)
            plt.plot(x, sig, '-', color=colors[jj % len(colors)],
                     label=label, alpha=alpha, markersize=6)

        plt.legend(loc='best',
                   borderaxespad=0.,
                   ncol=_set_ncol_legend(ncol),
                   numpoints=1)
        plt.axis('tight')

    plt.xlabel('Time')
    plt.savefig(fname, bbox_inches="tight", pad_inches=0.1)
    plt.close()

def _set_ncol_legend(ncol):
    if ncol == 1:
        legend_ncol = 1
    elif ncol <= 4:
        legend_ncol = ncol
    else:
        legend_ncol = int(np.ceil(ncol/2))
    return legend_ncol

def _get_xaxis_labels(period_type):
    labs = None
    mloc = None
    if period_type == 0:
        labs = np.arange(2005,2016)
        mloc = MultipleLocator(12)
    elif period_type == 1:
        labs = np.arange(2003,2015)
        mloc = MultipleLocator(52)
    elif period_type == 2:
        labs = np.arange(2000,2011)
        mloc = MultipleLocator(365)
    return (labs, mloc)

def _get_xlabel_string(period_type):
    if period_type == 0:
        xlabel_string = "Time (monthly)"
    elif period_type == 1:
        xlabel_string = "Time (weekly)"
    elif period_type == 2:
        xlabel_string = 'Time (daily)'
    else:
        xlabel_string = "Time"
    return xlabel_string
