#! /bin/python3

##########
## Version: Apr, 2017
## Author: Tsubasa Takahashi
## Note: This work was done during the author's visit to Carnegie Mellon University.
##########

import numpy as np
from sktensor import dtensor, ktensor, sptensor
from sktensor.core import nvecs, khatrirao, norm
import scipy.stats as stats

from tenfact_tools import random_ten_init, normalize_mode, lasso_update, group_lasso_update, sptensor_diff
from cyclone_m import CycloneM, matB_totensor

class CycloneFact():
    """Alternating solver that splits an observed tensor into three parts.

    The parts are a base trend (tenB), a cyclic pattern held as a Kruskal
    tensor (paraC / tenC) and an outlier tensor (tenO).  Each outer iteration
    of `fit` updates them in that order and then scores the result.
    """

    def __init__(self):
        pass

    def fit(self, X, rank, max_iter=100, lamda2=0.1, lamda1=0, tol=1e-3):
        """
        input
        -----
        X: observed tensor (dense tensor or sptensor)
        rank: rank for low rank approximation
        max_iter: max number of iteration (must be >= 1)
        lamda2: sparsity control parameter for outlier
        lamda1: sparsity control parameter for rare patterns
        tol: stop once the objective improves by less than this amount

        output
        ------
        model: CycloneM built from (B, paraC, tenO, tenB, tenC)
        stats: (zero-based index of the last iteration run,
                objective value of the returned model)

        raises
        ------
        ValueError: if max_iter < 1 (no model could be produced)
        """
        if max_iter < 1:
            raise ValueError("max_iter must be at least 1")

        xdims = X.shape
        ndim = X.ndim
        tenO = dtensor(np.zeros(xdims))
        U = random_ten_init(xdims, rank)
        paraC = ktensor(U, lmbda=None)

        fit = 0
        diff = 0
        # Snapshot of the last accepted iterate, including its dense cyclic
        # tensor and score, so that a fallback restores a consistent model.
        prior_model = None
        itr = 0
        for itr in range(max_iter):
            fitold = fit
            if isinstance(X, sptensor):
                Xo = sptensor_diff(X, tenO, is_sptensor=True)
            else:
                Xo = X - tenO

            ## Step1 Base Trend Tensor (tenB) Decomp.
            B, tenB = self.solve_mean(Xo, itr)

            ## Step2 Cyclic Pattern Tensor (tenC) by Sparse PARAFAC
            paraC, _ = self.solve_regular(Xo, tenB, paraC, rank, ndim, itr, lamda1)
            tenC = paraC.toarray()

            ## Step3 Outlier Decomposition by Soft Thresholding with group lasso
            tenO = self.solve_outlier(X, tenB, tenC, lamda2, xdims)

            ## Step4 Measure Score and Convergence
            # Score before any early exit so `fit` and `diff` always describe
            # the iterate that was just computed.
            fit = self.eval_score(X, tenB, paraC, tenC, tenO, lamda2, lamda1)
            diff = fitold - fit

            if np.count_nonzero(tenO) == 0:
                break
            if prior_model is not None and diff < tol:
                break
            prior_model = (tenB, B, paraC, tenO, tenC, fit)

        # The last step made the objective worse: fall back to the previous
        # iterate.  On the very first iteration `diff` compares against the
        # dummy initial score 0, hence the `prior_model` guard.
        if prior_model is not None and diff < 0:
            tenB, B, paraC, tenO, tenC, fit = prior_model

        model = CycloneM(B, paraC, tenO, tenB, tenC)
        run_stats = (itr, fit)

        return (model, run_stats)

    def solve_mean(self, Xo, itr):
        """Compute the base trend of the outlier-free tensor.

        Xo: observed tensor minus the current outlier estimate
        itr: outer iteration index (unused here)

        Returns (B, tenB): B is the mean over axis 1 (for an sptensor it is
        computed by `_sptensor_mean_trend`), tenB is `matB_totensor(B, Xo.shape)`.
        """
        if isinstance(Xo, sptensor):
            B = _sptensor_mean_trend(Xo)
        else:
            B = np.mean(Xo, axis=1)
        tenB = matB_totensor(B, Xo.shape)

        return (B, tenB)

    def solve_regular(self, Xo, tenB, paraC, rank, ndim, itr, lamda1=0, max_inner=1000, tol=1.e-6):
        """Update the cyclic-pattern factors by alternating least squares.

        Xo: observed tensor minus the current outlier estimate
        tenB: base trend tensor
        paraC: current Kruskal tensor, used as the starting point
        rank: number of components
        ndim: number of tensor modes
        itr: outer iteration index, forwarded to `normalize_mode`
        lamda1: when > 0, `lasso_update` is applied to every mode except mode 1
        max_inner: max number of inner sweeps
        tol: stop when the squared residual norm changes by less than this

        Returns (newParaC, U): the updated Kruskal tensor and its factor list.
        The factor list of the input `paraC` is left untouched.
        """
        # Work on a copy of the list: ktensor keeps a reference to the list it
        # is given, so in-place assignment would also rewrite the caller's
        # previous model.
        U = list(paraC.U)
        lam = paraC.lmbda

        if isinstance(Xo, sptensor):
            Xom = sptensor_diff(Xo, tenB, is_sptensor=True)
        else:
            Xom = Xo - tenB

        fitold = 0
        newParaC = paraC  # returned unchanged if no sweep is executed
        for _ in range(max_inner):
            for n in range(ndim):
                Unew = Xom.uttkrp(U, n)
                # Hadamard product of the Gram matrices of all other modes.
                Y = np.ones((rank, rank))
                for m in range(ndim):
                    if m != n:
                        Y = Y * np.dot(U[m].T, U[m])
                Unew = Unew.dot(np.linalg.pinv(Y))
                U[n], lam = normalize_mode(Unew, itr)

                # U[0]: V, U[1]:W, U[2]: U in tenC
                if lamda1 > 0 and n != 1:
                    U[n] = lasso_update(U[n], lamda1)

            newParaC = ktensor(U, lam)
            tenC = newParaC.toarray()
            if isinstance(Xo, sptensor):
                residual = sptensor_diff(Xo, tenB + tenC)
            else:
                residual = Xo - tenB - tenC
            fit = np.linalg.norm(residual) ** 2
            if np.abs(fitold - fit) < tol:
                break
            fitold = fit

        return (newParaC, U)

    def solve_outlier(self, X, tenB, tenC, lamda2, xdims):
        """Estimate the outlier tensor from the residual X - tenB - tenC.

        X: observed tensor
        tenB, tenC: base trend and cyclic pattern tensors
        lamda2: sparsity control parameter; when <= 0 the raw residual is returned
        xdims: shape of X

        Returns the residual after `group_lasso_update` has been applied to
        every fiber R[j, i, :].
        """
        if isinstance(X, sptensor):
            R = sptensor_diff(X, tenB + tenC)
        else:
            R = X - tenB - tenC

        if lamda2 <= 0:
            return R

        for i in range(xdims[1]):
            for j in range(xdims[0]):
                R[j, i, :] = group_lasso_update(R[j, i, :], lamda2)
        return R

    def eval_score(self, X, tenB, paraC, tenC, tenO, lamda2, lamda1=0):
        """Evaluate the penalised objective of the current decomposition.

        The score is the squared norm of X - (tenB + tenC + tenO), plus lamda2
        times the sum of the l2 norms of the fibers tenO[j, i, :], plus (when
        lamda1 > 0) lamda1 times the entrywise l1 norm of paraC.U[0] and
        paraC.U[2].

        Returns the score as a scalar.
        """
        if isinstance(X, sptensor):
            residual = sptensor_diff(X, tenB + tenC + tenO)
        else:
            residual = X - (tenB + tenC + tenO)

        norm_residual = np.linalg.norm(residual) ** 2

        # Group-lasso term: one l2 norm per fiber along the last mode.
        score_o = 0
        for i in range(X.shape[1]):
            for j in range(X.shape[0]):
                score_o += np.linalg.norm(tenO[j, i, :])

        fit = norm_residual + lamda2 * score_o
        if lamda1 > 0:
            # U[0]: V, U[1]:W, U[2]: U in tenC
            # Entrywise l1 norm; np.linalg.norm(M, 1) on a matrix would give
            # the maximum absolute column sum instead.
            pena1 = np.abs(paraC.U[0]).sum() + np.abs(paraC.U[2]).sum()
            fit += lamda1 * pena1

        return fit


def _sptensor_mean_trend(Xo):
    """Average the stored entries of a sparse tensor over mode 1.

    Xo: object exposing `subs` (one index sequence per mode), `vals` (the
        stored values) and `shape`.

    Returns a dense float array whose shape is `Xo.shape` with mode 1
    removed.  Each cell holds the mean of the stored entries that share its
    remaining indices; only stored entries are counted, so implicit zeros do
    not lower the mean.  Cells with no stored entry are 0.

    The output is sized from `Xo.shape` rather than from the largest index
    that happens to be stored, so it always covers the full tensor.
    """
    subs = Xo.subs
    vals = np.asarray(Xo.vals)

    kept_modes = [m for m in range(len(subs)) if m != 1]
    out_shape = tuple(Xo.shape[m] for m in kept_modes)

    sums = np.zeros(out_shape)
    counts = np.zeros(out_shape)
    if vals.size == 0:
        return sums

    index = tuple(np.asarray(subs[m], dtype=np.intp) for m in kept_modes)
    # ufunc.at accumulates correctly when the same index occurs repeatedly.
    np.add.at(sums, index, vals)
    np.add.at(counts, index, 1)

    # Divide only where something was stored; the other cells stay 0.
    np.divide(sums, counts, out=sums, where=counts > 0)
    return sums
