#! /bin/python3

##########
## Version: Apr, 2017
## Author: Tsubasa Takahashi
## Note: This work was done during the author's visit to Carnegie Mellon University.
##########

import numpy as np
from sktensor import dtensor, sptensor
import os

from tenfact_tools import split_by_interval, split_by_segments
from viz.viz_heatmap import viz_heatmap
from viz.viz_result import viz_decomp_all, viz_signal_pair, viz_errors, viz_seasonal, viz_signal_pair_missing

# Image format (file extension) used for every figure written by
# output_results; 'pdf' is the alternative that was used previously.
fig_format = 'png'

# Series labels per dataset, keyed by the dataset path that callers pass to
# output_results as ``src``; presumably one label per series, in series order.
# NOTE(review): '_dataset/s_retails6_dat', '_dataset/s_games4_dat' and
# '_dataset/sst5.dat.dat' do not follow the '<name>.dat' pattern of the other
# keys -- confirm they match the paths callers actually pass.
_legend_pairs = [
    ('_dataset/fitness3.dat', ["Swimming", "Running", "Yoga"]),
    ('_dataset/s_retails6_dat', ["Amazon", "Walmart", "Home Depot", "Best buy", "Lowes", "Costco"]),
    ('_dataset/s_games4_dat', ["Xbox", "Play Station", "Wii", "Android"]),
    ('_dataset/sst5.dat.dat', ["0n125w", "0n140w", "0n155w", "0n170w", "0n180w"]),
    ('_dataset/synth2.dat', ['1', '2']),
    ('_dataset/energy7.dat', ['CA', 'DE', 'FR', 'GB', 'JP', 'KR', 'US']),
]
legend_map = dict(_legend_pairs)
del _legend_pairs

def _series(mat, i):
    """Return series ``i`` of a data/reconstruction matrix.

    A 2-D ``mat`` yields column ``i``. A 1-D ``mat`` (the single-series
    layout built by ``output_results``) is returned unchanged for ``i == 0``;
    any other index on a 1-D array still falls through to ``mat[:, i]`` and
    raises IndexError.
    """
    if np.ndim(mat) == 1 and i == 0:
        return mat
    return mat[:, i]


def output_results(matX, model, outpath='_out',src="", period_type=None):
    """Write all diagnostic figures for a fitted model into ``outpath``.

    Files written (extension ``fig_format``):
      * ``fit``            -- ``matX`` overlaid with the reconstruction B + C + O.
      * ``org``            -- ``matX`` on its own.
      * ``decomp``         -- ``viz_decomp_all`` of the B + C and O parts.
      * ``fit_error``      -- ``viz_errors`` of ``matX`` vs. B + C + O.
      * ``parafacV`` / ``parafacU`` -- heatmaps of ``parafacC().U[0]`` / ``U[2]``.
      * ``parafacW``       -- ``parafacC().U[1]`` drawn with ``viz_seasonal``.
      * ``single_fit_<i>`` -- per-series fit for each index ``i`` along the
        third mode of ``tensorB()``.

    Args:
        matX: observed data. When the model tensors have more than one series
            it is truncated to the first ``shape[0] * shape[1]`` rows
            (presumably shape (time, series) -- TODO confirm with callers).
        model: fitted model exposing ``tensorB()``, ``tensorC()``,
            ``parafacC()`` and ``tensorO()``; the three tensors are assumed to
            share one 3-D shape.
        outpath: output directory, created (including parents) if missing.
            ``None`` falls back to ``'_out'``.
        src: key into ``legend_map`` selecting the series labels.
        period_type: forwarded unchanged to the plotting helpers.

    Raises:
        KeyError: if ``src`` has no entry in ``legend_map``.
    """
    ## Resolve the directory first: os.path / os.makedirs reject None.
    if outpath is None:
        print("output results into _out")
        outpath = '_out'
    os.makedirs(outpath, exist_ok=True)
    print("... Writing Results in [Dir]%s" % outpath)

    try:
        labels = legend_map[src]
    except KeyError:
        raise KeyError("no legend labels registered for src=%r" % (src,)) from None

    tenB = model.tensorB()
    tenC = model.tensorC()
    paraC = model.parafacC()
    tenO = model.tensorO()

    ## Fold the first two tensor modes into rows, one column per series.
    ## A single series is flattened to a 1-D array instead.
    shape = tenO.shape
    if shape[2] == 1:
        matshape = shape[0] * shape[1]
    else:
        matshape = (shape[0] * shape[1], shape[2])
        # Keep only the rows covered by the model tensors.
        matX = matX[:matshape[0]]

    matB = tenB.reshape(matshape)
    matC = tenC.reshape(matshape)
    matO = tenO.reshape(matshape)
    matBC = matB + matC
    matBCO = matBC + matO

    ## fitting: observed data vs. full reconstruction
    fname = "%s/%s.%s" % (outpath, 'fit', fig_format)
    viz_signal_pair([matX, matBCO], fname, labels=labels, period_type=period_type, alphas=[0.5, 1])

    ## observed data only
    fname = "%s/%s.%s" % (outpath, 'org', fig_format)
    viz_signal_pair([matX, matX], fname, labels=labels, period_type=period_type, markers=['-', '-'], alphas=[0, 1])

    ## decomposition into B + C and O
    fname = "%s/decomp.%s" % (outpath, fig_format)
    viz_decomp_all([matBC, matO], fname, labels, order_label=False, period_type=period_type)

    ## fitting error
    fname = "%s/%s.%s" % (outpath, 'fit_error', fig_format)
    viz_errors([matX, matBCO], fname)

    ## factors returned by model.parafacC()
    viz_heatmap((paraC.U[0]), '%s/parafacV.%s' % (outpath, fig_format), xlabel='k', ylabel='m', centoring=True)
    viz_heatmap((paraC.U[2]), '%s/parafacU.%s' % (outpath, fig_format), xlabel='k', ylabel='d', centoring=True)
    viz_seasonal([paraC.U[1]], '%s/parafacW.%s' % (outpath, fig_format), labels=['W'])

    ## per-series fit; _series also handles the 1-D single-series layout
    for i in range(tenB.shape[2]):
        fname = "%s/%s_%d.%s" % (outpath, 'single_fit', i, fig_format)
        viz_signal_pair_missing([_series(matX, i), _series(matBCO, i)], fname,
                                labels=labels,
                                period_type=period_type)
