
##########
## Version: Apr, 2017
## Author: Tsubasa Takahashi
## Note: This work was done during the author's visit to Carnegie Mellon University.
##########

import numpy as np
from ropca import ropca_solver

# Number of leading (lowest-frequency) periodogram bins that
# periodogram_analysis discards before scoring.
L = 4
# Default number of periods returned by detect_periods / periodogram_analysis.
NP = 2

# Dataset files grouped by sampling granularity; get_period_type maps a file
# name to 0 (monthly), 1 (weekly) or 2 (daily) using these lists.
monthly_data = ['_dataset/energy5.dat']
# Weekly data (the variable keeps its historical spelling for compatibility).
weekely_data = [
    '_dataset/fitness3.dat',
    # NOTE(review): the next two paths end in "_dat" rather than ".dat";
    # confirm they match the actual file names.
    '_dataset/s_retails6_dat',
    '_dataset/s_games4_dat',
]
daily_data = ['_dataset/sst7.dat']

# Candidate periods per granularity; the second entry is (about) half the first.
monthly_periods = [12, 6]
weekly_periods = [52, 26]
daily_periods = [365, 182]
# Indexed by period type: 0 = monthly, 1 = weekly, 2 = daily.
univ_periods = [monthly_periods, weekly_periods, daily_periods]
# Shortest period periodogram_analysis keeps for each period type (same index).
min_periods = [3, 10, 100]

def get_period_type(fname):
    """Return the period-type index for a known dataset file name.

    The index is 0 for files listed in ``monthly_data``, 1 for
    ``weekely_data`` and 2 for ``daily_data``; it is the index used with
    ``univ_periods`` and ``min_periods``.

    Returns None when ``fname`` is in none of those lists.
    """
    # Order matters: the position in this tuple *is* the returned type code.
    # Checking each list exactly once also guarantees the daily case is reachable.
    for ptype, files in enumerate((monthly_data, weekely_data, daily_data)):
        if fname in files:
            return ptype
    return None

def detect_periods(matX, Np=NP, period_type=None):
    """Detect dominant periods in ``matX``.

    The series is first reduced with get_robust_principle_timeseries. The
    top ``Np`` periods are then taken from its periodogram.

    When ``period_type`` is 0, 1 or 2, the matching candidate periods from
    ``univ_periods`` are appended. The combined values are then
    deduplicated and sorted with ``np.unique``, so an ndarray is returned.
    Otherwise the list produced by periodogram_analysis is returned as-is.
    """
    principal = get_robust_principle_timeseries(matX)
    # NOTE(review): period_type is not forwarded here, so the min_periods
    # threshold inside periodogram_analysis is never applied on this path;
    # confirm this is intended.
    found = periodogram_analysis(principal, Np)

    if period_type not in (0, 1, 2):
        return found
    return np.unique(found + univ_periods[period_type])

def get_robust_principle_timeseries(matX):
    """Project ``matX`` through ropca_solver and return ``S @ U.T``.

    ``ropca_solver`` is called with rank ``max(d - 1, 1)`` and
    ``lamda2 = 0.1 * d / 2``, where ``d = matX.shape[-1]``. Only the second
    and third of its four outputs are used.

    NOTE(review): presumably U holds the principal axes and S the
    per-sample scores, so the result is a robust low-rank reconstruction.
    Confirm against the ropca module.
    """
    n_cols = matX.shape[-1]
    target_rank = max(n_cols - 1, 1)
    _, U, S, _ = ropca_solver(matX, rank=target_rank, lamda2=0.1 * n_cols / 2)
    return np.dot(S, U.T)

def _scale_log_power(log_power):
    """Min-max scale ``log_power`` to [0, 1], tolerating degenerate spectra.

    Non-finite entries (e.g. ``-inf`` from zero-power bins) are treated as
    the smallest finite value. If there is no finite value, or all finite
    values are equal, an all-zero array is returned. This avoids the NaNs
    that a plain ``(x - min) / (max - min)`` would produce.
    """
    finite = np.isfinite(log_power)
    if not finite.any():
        return np.zeros_like(log_power)
    lo = log_power[finite].min()
    hi = log_power[finite].max()
    if hi == lo:
        return np.zeros_like(log_power)
    return (np.where(finite, log_power, lo) - lo) / (hi - lo)


def periodogram_analysis(matX, Np=NP, period_type=None):
    """Return the ``Np`` highest-scoring periods across the columns of ``matX``.

    Each column of ``matX`` is treated as one series indexed along axis 0.
    For each column:

    1. The one-sided FFT power spectrum is computed.
    2. The first ``L`` bins are dropped.
    3. The log power is min-max scaled to [0, 1].

    Scores for the same period value are multiplied across columns.

    Parameters
    ----------
    matX : 2-D ndarray
        Series are the columns (``matX[:, j]``).
    Np : int
        Maximum number of periods to return.
    period_type : int or None
        Index into ``min_periods``; periods below that threshold are ignored.
        ``None`` disables the threshold.

    Returns
    -------
    list of int
        Up to ``Np`` periods rounded to integers, highest score first. The
        list is empty if no column is long enough to score.
    """
    if period_type is not None and min_periods[period_type] is not None:
        thre_period = min_periods[period_type]
    else:
        thre_period = 0

    nyquist = 1. / 2  # sampling frequency Fs = 1
    periods = {}

    for col in range(matX.shape[1]):
        y = matX[:, col].astype(np.float64)
        n = len(y)
        half = n // 2  # integer: np.linspace rejects a float `num`
        if half <= L:
            # Nothing survives the [L:] cut; skip instead of failing in max/min.
            continue

        power = np.abs(np.fft.fft(y)[:half]) ** 2
        # NOTE(review): FFT bin k corresponds to frequency k/n, but this grid
        # maps it to roughly k/(n-2) (for even n), slightly shifting the
        # reported periods. It is kept unchanged to preserve existing results.
        freq = np.linspace(0.0001, 1, half) * nyquist
        period = 1. / freq

        pr = period[L:]
        with np.errstate(divide='ignore'):  # zero-power bins give -inf
            pp = _scale_log_power(np.log(power[L:]))

        for per, score in zip(pr, pp):
            if per < thre_period:
                continue
            if per in periods:
                periods[per] *= score
            else:
                periods[per] = score

    # Sort ascending by score, then reverse, so ties keep their original order.
    ranked = [per for per, _ in sorted(periods.items(), key=lambda kv: kv[1])]
    return list(np.round(ranked[::-1][:Np]).astype(int))
