#! /bin/python3

##########
## Version: Apr, 2017
## Author: Tsubasa Takahashi
## Note: This work was done during the author's visit to Carnegie Mellon University.
##########

import numpy as np
from sktensor import dtensor, ktensor, sptensor
import time
import math

from cyclone_fact import CycloneFact
from exp_tools import output_results
from tenfact_tools import sptensor_diff, load_matrix, normalize_matrix, periodic_folding
from detect_period import detect_periods
from ac_result import ACResult

# Automatic lam2 search (used by AutoCyclone.fit_inner_auto): the first lam2 is
# X.shape[-1] / 2 * base_init_lam2, every improving step multiplies lam2 by
# delta_lam2, and the search stops once lam2 falls below eps_lam2.
base_init_lam2, delta_lam2, eps_lam2 = 0.1, 0.1, 1e-4

# Passed to detect_periods() as its second argument when no explicit window
# list is supplied; presumably the number of candidate periods to return
# (NOTE(review): confirm against detect_period.detect_periods).
n_period_set = 3

class AutoCyclone:
    """Parameter-search driver around ``CycloneFact``.

    Searches over the folding period, the rank and the regularisation weights
    ``lamda2`` / ``lamda1``, keeping at every level the candidate ``ACResult``
    with the lowest ``coding_cost`` (printed as "MDL" in the logs).
    """

    def __init__(self):
        pass

    @staticmethod
    def _as_candidate_list(value):
        """Return ``value`` as a list of candidate parameter values.

        Lists, tuples and non-scalar numpy arrays are copied into a list;
        anything else (e.g. a single float) is wrapped in a one-element list.
        """
        if isinstance(value, (list, tuple)):
            return list(value)
        if isinstance(value, np.ndarray) and value.ndim > 0:
            return list(value)
        return [value]

    def fit_and_asses(self, X, rank, lamda2, lamda1, max_iter=100, tol=1e-3):
        """Fit a single CycloneFact model and wrap it in an ``ACResult``.

        Args:
            X: tensor to factorize; ``X.shape[1]`` is recorded as the period.
            rank, lamda2, lamda1, max_iter, tol: forwarded to ``CycloneFact.fit``.

        Returns:
            ACResult built from the fitted model, the iteration count reported
            by ``fit`` and the elapsed fitting time in seconds.
        """
        cyfact = CycloneFact()

        # perf_counter is monotonic, so the duration is unaffected by
        # system-clock adjustments (unlike time.time()).
        start = time.perf_counter()
        model, stats = cyfact.fit(X, rank=rank, lamda2=lamda2, lamda1=lamda1,
                                  max_iter=max_iter, tol=tol)
        elapsed_time = time.perf_counter() - start

        # ``fit`` returns (itr, fit); only the iteration count is recorded.
        itr, _ = stats
        return ACResult(X, rank, lamda2, lamda1, X.shape[1], model, itr, elapsed_time)

    def auto_cyclonefact_lam2(self, X, rank, lam2, lamda1=0, max_iter=100, tol=1e-3):
        """Run the rank search (``fit_inner_rank``) for one fixed ``lam2``.

        Returns:
            The resulting ACResult if it exists and has a finite coding cost,
            otherwise None.
        """
        ac_res = self.fit_inner_rank(X, rank, lam2, lamda1, max_iter, tol)
        if ac_res is not None and ac_res.coding_cost < np.inf:
            return ac_res
        return None

    def fit_inner_auto(self, X, rank, lamda1=0, max_iter=100, tol=1e-3):
        """Search ``lam2`` on a decreasing geometric grid.

        Starts at ``X.shape[-1] / 2 * base_init_lam2`` and multiplies by
        ``delta_lam2`` after every improvement. Stops at the first result that
        is None or does not lower the coding cost, or once ``lam2`` drops
        below ``eps_lam2``.

        Returns:
            The best ACResult found, or None.
        """
        best_res = None
        min_cost = np.inf

        d = X.shape[-1]
        lam2 = d / 2 * base_init_lam2

        while lam2 >= eps_lam2:
            res = self.auto_cyclonefact_lam2(X, rank, lam2, lamda1, max_iter, tol)
            if res is None or not res.coding_cost < min_cost:
                break
            best_res = res
            min_cost = res.coding_cost
            lam2 *= delta_lam2
        return best_res

    def fit_inner_list_search(self, X, rank, list_lam2, lamda1=0, max_iter=100, tol=1e-3):
        """Try each ``lam2`` in ``list_lam2`` in order.

        Stops at the first candidate that is None or does not lower the
        coding cost (early stopping).

        Returns:
            The best ACResult found, or None.
        """
        best_res = None
        min_cost = np.inf

        for lam2 in list_lam2:
            res = self.auto_cyclonefact_lam2(X, rank, lam2, lamda1, max_iter, tol)
            if res is None or not res.coding_cost < min_cost:
                break
            best_res = res
            min_cost = res.coding_cost
        return best_res

    def fit_inner(self, X, rank, lamda2, lamda1=0, max_iter=100, tol=1e-3):
        """Dispatch the ``lamda2`` search.

        ``lamda2`` may be None (automatic grid search via ``fit_inner_auto``),
        a single value, or a list/tuple/array of candidates
        (``fit_inner_list_search``).

        Returns:
            The best ACResult found, or None.
        """
        if lamda2 is None:
            return self.fit_inner_auto(X, rank, lamda1, max_iter, tol)
        list_lam2 = self._as_candidate_list(lamda2)
        return self.fit_inner_list_search(X, rank, list_lam2, lamda1, max_iter, tol)

    def fit_autoparam_lam1(self, X, rank, lamda2, lamda1=None, max_iter=100, tol=1e-3, prior_corc=0, prior_mdl=np.inf):
        """Search ``lamda1`` for a fixed ``rank`` and ``lamda2``.

        Args:
            lamda1: None (try ``2 * [1e-4, 1e-3, 1e-2, 1e-1, 1]``), a single
                value, or a list/tuple/array of candidates.
            prior_corc: CORC of the previous rank; a candidate's CORC must
                exceed ``prior_corc / rank * (rank - 1)`` (and 0) to count.
            prior_mdl: initial coding-cost threshold a candidate must beat.

        Returns:
            The best ACResult, or None. None is also returned as soon as any
            candidate's outlier tensor has fewer than 1% non-zero entries.
        """
        best_res = None
        min_cost = prior_mdl

        if lamda1 is None:
            list_lam1 = 2 * np.array([1e-4, 1e-3, 1e-2, 1e-1, 1])
        else:
            list_lam1 = self._as_candidate_list(lamda1)

        # CORC a candidate must exceed to be considered an improvement.
        min_corc = prior_corc / rank * (rank - 1)

        for lam1 in list_lam1:
            res = self.fit_and_asses(X, rank=rank,
                                     lamda2=lamda2,
                                     lamda1=lam1,
                                     max_iter=max_iter,
                                     tol=tol)

            tenO = res.model.tensorO()
            if np.count_nonzero(tenO) < np.prod(tenO.shape) * 0.01:
                print("INVALID RESULT: Outlier tensor is empty.")
                # NOTE(review): this discards any best_res found for earlier
                # lam1 values; confirm that is the intended behaviour.
                return None

            # Use separate names so the caller's rank/lamda2 are not shadowed.
            res_rank, res_lam2, _ = res.params()
            print("l:%d [k, lam2, lam1]: (MDL, CORCO, RMSE) = [%d, %f, %f]: (%f, %f, %f)"
                  % (res.period, res_rank, res_lam2, lam1, res.coding_cost, res.corc, res.rmse))

            if res.corc > 0 and res.corc > min_corc:
                if res.coding_cost < min_cost:
                    best_res = res
                    min_cost = res.coding_cost
                else:
                    break
            elif best_res is not None:
                break

        return best_res

    def fit_inner_rank(self, X, rank=None, lamda2=None, lamda1=None, max_iter=100, tol=1.e-5):
        """Search the rank, running the ``lamda1`` search for each candidate.

        With ``rank >= 1`` only that rank is tried. Otherwise ranks
        ``1 .. s0 * s1`` are tried, where ``s0 <= s1`` are the two smallest
        dimensions of ``X``. The loop stops when ``fit_autoparam_lam1``
        returns None, or (for rank > 1) when the coding cost increases.

        Returns:
            The ACResult of the last accepted rank, or None.
        """
        # NOTE(review): prior_mdl is never updated, so every rank starts its
        # lamda1 search from an infinite cost threshold; confirm intent.
        prior_mdl = np.inf
        prior_corc = 0
        best_res = None
        min_cost = np.inf

        if rank is not None and rank >= 1:
            # Fixed rank.
            max_rank = rank
        else:
            # Automatic rank.
            shape = np.sort(X.shape)
            max_rank = shape[0] * shape[1]
            rank = 1

        for k in range(rank, max_rank + 1):
            result = self.fit_autoparam_lam1(X, rank=k, max_iter=max_iter,
                                             lamda2=lamda2, lamda1=lamda1, tol=tol,
                                             prior_corc=prior_corc, prior_mdl=prior_mdl)

            if result is None:
                print("--- No valid result in k:%d ---\n" % k)
                break

            res_k, res_lam2, res_lam1 = result.params()
            print("*** k:%d lam2:%f lam1:%f ***\n" % (res_k, res_lam2, res_lam1))

            if res_k > 1 and result.coding_cost > min_cost:
                break

            best_res = result
            min_cost = result.coding_cost
            prior_corc = result.corc

        return best_res

    def fit_fullauto(self, fname, window_list, lam2, rank=0, lam1=0, max_iter=100, tol=1.e-5, normalize=True, max_period=None, period_type=None):
        """Load ``fname`` and search period, rank, ``lam2`` and ``lam1``.

        Candidate periods come from ``window_list`` or, when it is None, from
        ``detect_periods``. Every candidate folds the same loaded (and
        optionally normalised) matrix.

        Returns:
            ``(best_foldX, best_result)``: the first value returned by
            ``periodic_folding`` for the best period and that period's
            ACResult. Both are None if no period yields a result.
        """
        best_result = None
        min_cost = np.inf
        best_foldX = None

        matX = load_matrix(fname)
        if normalize:
            matX = normalize_matrix(matX)

        if window_list is None:
            window_list = detect_periods(matX, n_period_set, period_type)
            print(">> Candidate of Periodicity:", window_list)

        for w in window_list:
            # Keep matX intact: each candidate period must fold the original
            # input, not the previous candidate's folding output.
            foldX, X = periodic_folding(matX, w, max_period)
            print(">> Runnig Factorization... (Period:%d)" % (int(w)))
            result = self.fit_inner(X, rank, lam2, lam1, max_iter, tol)

            if result is None:
                print("Period: %d  No valid result\n" % w)
                continue

            k = result.params()[0]
            cost = result.coding_cost

            print("Period: %d  Rank: %d  Coding Cost: %f\n" % (w, k, cost))

            if cost < min_cost:
                min_cost = cost
                best_result = result
                best_foldX = foldX

        if best_result is not None:
            rank, lam2, lam1 = best_result.params()
            period = best_result.period
            print("====================")
            print("BEST RANK: %d  BEST PERIOD: %d (LAM2, LAM1)=(%f, %f)"
                  % (rank, period, lam2, lam1))
            print("====================")

        return (best_foldX, best_result)

    def fit_and_visualize(self, fname, rank, lam2, lam1, window_list,
                          max_iter=100, tol=1.e-5,
                          save_result=True, outpath="", normalize=False,
                          max_period=None, period_type=None):
        """Run ``fit_fullauto`` and optionally write results via ``output_results``.

        Results are only written when a valid result was found.

        Returns:
            The best ACResult, or None if the search found nothing.
        """
        foldX, ac_res = self.fit_fullauto(fname, window_list, lam2, rank, lam1,
                                          max_iter, tol, normalize, max_period, period_type)

        if save_result and ac_res is not None:
            output_results(foldX, ac_res.model, outpath=outpath, src=fname, period_type=period_type)

        return ac_res
