#! /bin/python3

##########
## Version: Apr, 2017
## Author: Tsubasa Takahashi
## Note: This work was done during the author's visit to Carnegie Mellon University.
##########

import numpy as np
import scipy.stats as stats
from sktensor import dtensor, sptensor

from tenfact_tools import sptensor_diff, bucketize
from corcondia import efficient_corcondia

# Per-value cost used throughout the MDL computation: every stored real value
# (base-trend entries, dense W entries, nonzero factor/outlier values) is
# charged Cf units — presumably bits; TODO confirm. _coding also multiplies the
# reconstruction errors by 2**Cf before fitting a Gaussian to them.
Cf = 8

class ACResult:
    """Record of one fitted model together with its evaluation scores.

    The three scores are computed once, eagerly, when the object is built:
    ``rmse`` (via ``_rmse``), ``coding_cost`` (via ``_eval_mdl``) and
    ``corc`` (via ``_eval_quality``). The remaining constructor arguments are
    stored unchanged as attributes of the same name.
    """

    def __init__(self, X, rank, lamda2, lamda1, period, model, itr, elapsed_time):
        # Scores first, in the same order as always (each helper may raise).
        self.rmse = _rmse(X, model)
        self.coding_cost = _eval_mdl(X, rank, model)
        self.corc = _eval_quality(X, model)
        # Hyper-parameters.
        self.rank, self.lamda2, self.lamda1 = rank, lamda2, lamda1
        # Run metadata (itr presumably an iteration count; TODO confirm).
        self.period = period
        self.model = model
        self.itr, self.elapsed_time = itr, elapsed_time

    def params(self):
        """Return the hyper-parameters as the tuple (rank, lamda2, lamda1)."""
        rank, l2, l1 = self.rank, self.lamda2, self.lamda1
        return rank, l2, l1

def _rmse(X, model):
    """Return the root-mean-square error of the model's full reconstruction.

    The reconstruction is ``model.allsum_tensor()``. For a sparse ``X`` the
    residual values come from ``sptensor_diff(..., is_sptensor=True).vals``;
    for a dense ``X`` the residual is the element-wise difference.
    """
    reconstruction = model.allsum_tensor()
    if not isinstance(X, sptensor):
        residual = X - reconstruction
    else:
        residual = sptensor_diff(X, reconstruction, is_sptensor=True).vals
    squared = residual ** 2
    return np.sqrt(np.mean(squared))


def _eval_mdl(X, rank, model):
    """Return the total description cost (model cost + coding cost).

    Model cost is the sum of:
      * base trend: every entry of ``model.matrixB()`` costs ``Cf``;
      * cyclic PARAFAC factors: factor 1 ("W") is charged densely,
        ``m * Cf`` per component; every other factor column is charged per
        nonzero (index cost ``log2(m)`` + value ``Cf``) plus
        ``log2(m + 1)`` for the nonzero count;
      * outlier tensor (only if ``model.tensorO()`` is not None): per-nonzero
        index + value cost plus ``log2(N + 1)`` for the nonzero count;
      * ``_log_s(rank)`` and ``log2`` of the product of the first two modes.
    Coding cost is ``_coding(X, B + C [+ O])``.

    Args:
        X: the observed tensor (dense or ``sptensor``).
        rank: number of cyclic components to charge for.
        model: fitted model exposing parafacC/tensorC/tensorO/matrixB/tensorB.

    Returns:
        The total cost as a float.
    """
    paraC = model.parafacC()
    tenC = model.tensorC()
    tenO = model.tensorO()
    matB = model.matrixB()
    tenB = model.tensorB()

    ## for base trend tensor: each entry of the base matrix costs Cf
    costB = matB.shape[0] * matB.shape[1] * Cf

    ## for cyclic pattern tensor
    ndim = len(paraC.U)
    costC = 0
    for k in range(rank):
        for n in range(ndim):
            m = paraC.U[n].shape[0]
            if n == 1:  # for W: charged densely
                costC += m * Cf
            else:
                # sparse column: index (log2 m) + value (Cf) per nonzero,
                # plus log2(m + 1) to state how many nonzeros there are
                nzmk = np.count_nonzero(paraC.U[n][:, k])
                posmk = np.log2(m)
                freedomNZ = np.log2(m + 1)
                costC += nzmk * (posmk + Cf) + freedomNZ

    ## for outlier tensor (optional)
    costO = 0
    if tenO is not None:
        valO = Cf
        posO = np.sum(np.log2(tenO.shape))
        N = np.prod(tenO.shape)
        nzO = np.count_nonzero(tenO)
        nzfreedom = np.log2(N + 1)
        costO = nzO * (posO + valO) + nzfreedom

    # Previously this read tenO.shape unconditionally and crashed when the
    # model has no outlier tensor. tenC is added element-wise to tenO, so it
    # is assumed to share its shape (TODO confirm for exotic models).
    ref_shape = tenO.shape if tenO is not None else tenC.shape
    period_freedom = np.log2(np.prod(ref_shape[0:2]))
    cost_m = costB + costC + costO + _log_s(rank) + period_freedom

    # Same left-to-right summation as before when tenO is present.
    Xnew = tenB + tenC
    if tenO is not None:
        Xnew = Xnew + tenO
    cost_c = _coding(X, Xnew)
    cost_t = cost_c + cost_m

    return cost_t

def _eval_quality(X, model):
    """Return ``efficient_corcondia`` of the cyclic factors on the residual.

    The base trend tensor (``model.tensorB()``) and, when present, the outlier
    tensor (``model.tensorO()``) are subtracted from ``X``. The remainder is
    scored against the PARAFAC model ``model.parafacC()``.

    Args:
        X: the observed tensor (dense or ``sptensor``).
        model: fitted model exposing tensorB/parafacC/tensorO.

    Returns:
        Whatever ``efficient_corcondia`` returns (a core-consistency score).
    """
    tenB = model.tensorB()
    paraC = model.parafacC()
    tenO = model.tensorO()

    if isinstance(X, sptensor):
        Z = sptensor_diff(X, tenB, is_sptensor=True)
        if tenO is not None:
            Z = sptensor_diff(Z, tenO, is_sptensor=True)
    else:
        # Reuse X - tenB instead of computing it twice.
        Z = X - tenB
        if tenO is not None:
            Z = Z - tenO

    return efficient_corcondia(Z, paraC)

def _coding(Xorg, Xnew):
    """Return the cost of encoding the reconstruction error under a Gaussian.

    The error between ``Xorg`` and ``Xnew`` is multiplied by ``2**Cf``
    (presumably a fixed-point quantisation; TODO confirm). A normal
    distribution is then fitted to the errors' mean and std. The cost is the
    summed negative log2-density, with NaN terms (e.g. when std == 0) skipped.

    Args:
        Xorg: the observed tensor (dense or ``sptensor``).
        Xnew: the model reconstruction.

    Returns:
        The coding cost as a float.
    """
    if isinstance(Xorg, sptensor):
        delta = sptensor_diff(Xorg, Xnew)
    else:
        delta = Xnew - Xorg

    delta = delta * pow(2, Cf)
    std = np.std(delta)
    mean = np.mean(delta)
    # logpdf stays finite for far-tail errors; -log2(pdf) became inf once the
    # pdf underflowed to 0, which made the whole MDL cost infinite.
    log_pdfs = stats.norm.logpdf(delta, mean, std)

    # Natural-log density -> bits; nansum keeps the old NaN-skipping behaviour.
    cost_c = np.nansum(-log_pdfs) / np.log(2)

    return cost_c

def _log_s(x):
    return 2.0*np.log2(x)+1.0
