#! /bin/python3

##########
## Version: Apr, 2017
## Author: Tsubasa Takahashi
## Note: This work was done while the author was visiting Carnegie Mellon University.
##########

import numpy as np
import math
from sktensor import dtensor, ktensor, sptensor
from sktensor.core import nvecs, khatrirao, norm
from sklearn.preprocessing import normalize
from more_itertools import pairwise

def load_matrix(fname, delim=None, n_skiprows=0, usecols=None):
    """Read a delimited numeric text file into a NumPy array.

    Thin wrapper around ``np.loadtxt``.

    Args:
        fname: File name or file-like object accepted by ``np.loadtxt``.
        delim: Column delimiter, forwarded as ``delimiter``.
        n_skiprows: Number of leading lines to skip, forwarded as ``skiprows``.
        usecols: Column index or indices to read; ``None`` reads all columns.

    Returns:
        The array produced by ``np.loadtxt``.
    """
    return np.loadtxt(
        fname,
        delimiter=delim,
        skiprows=n_skiprows,
        usecols=usecols,
    )

def normalize_matrix(matX):
    """Min-max scale ``matX`` along axis 0 into the range [0, 1].

    Minima and maxima are computed with NaN-aware reductions, so NaN entries
    are ignored when finding the range and remain NaN in the result. A column
    whose non-NaN values are all equal has zero range; it is mapped to 0
    instead of being divided by zero.

    Args:
        matX: Numeric array; for a 2-D matrix each column is scaled
            independently.

    Returns:
        A new float array of the same shape with values in [0, 1] (NaN
        entries preserved). The input is not modified.
    """
    mins = np.nanmin(matX, axis=0)
    maxs = np.nanmax(matX, axis=0)
    spans = maxs - mins
    # A constant column has zero span; dividing by 1 instead maps it to 0
    # and avoids turning the whole column into 0/0 = NaN.
    spans = np.where(spans == 0, 1, spans)
    scaled = (matX - mins) / spans
    # Guard against floating-point overshoot; np.clip leaves NaN untouched.
    return np.clip(scaled, 0, 1)

def periodic_folding(matX, _window_size, max_period):
    """Fold a 2-D matrix into a 3-mode tensor of consecutive row windows.

    Only the rows that fill complete windows are kept. Those rows are
    reshaped to ``(n_windows, window_size, n_cols)``, so every run of
    ``window_size`` consecutive rows becomes one slice along the first mode.

    Args:
        matX: 2-D array of shape ``(n_rows, n_cols)``.
        _window_size: Number of rows per window; converted with ``int``.
        max_period: If not ``None``, only the first ``int(max_period)``
            windows of the tensor are kept. The returned matrix is not
            shortened accordingly.

    Returns:
        Tuple ``(matX, tenX)``: the matrix truncated to whole windows, and
        the folded tensor. The tensor is a ``dtensor`` when it contains no
        NaN; otherwise it is an ``sptensor`` that stores only the non-NaN
        entries and keeps the full tensor shape.
    """
    window_size = int(_window_size)
    n_rows, n_cols = matX.shape
    # Integer floor division is exact; int(n_rows / window_size) went
    # through a float first.
    n0 = n_rows // window_size
    N = n0 * window_size
    matX = matX[0:N, :]
    tenX = dtensor(matX.reshape((n0, window_size, n_cols)))

    if max_period is not None:
        mp = int(max_period)
        tenX = tenX[:mp, :, :]

    missing = np.isnan(tenX)
    if missing.any():  # if sparse
        subs = np.where(~missing)
        vals = tenX[subs]
        # Pass the shape explicitly: without it sptensor infers the shape
        # from the largest subscript, which shrinks the tensor whenever the
        # trailing slices of a mode are entirely NaN.
        tenX = sptensor(subs, vals, shape=tenX.shape)

    return (matX, tenX)

def random_ten_init(mode, rank):
    """Create one random, column-normalized factor matrix per tensor mode.

    Args:
        mode: Iterable of mode sizes; each entry ``m`` yields an
            ``m x rank`` matrix.
        rank: Number of columns in every factor matrix.

    Returns:
        List of matrices drawn with ``np.random.rand`` and then passed
        through ``normalize(..., axis=0, norm='l1')``.
    """
    return [
        normalize(np.random.rand(size, rank), axis=0, norm='l1')
        for size in mode
    ]

def normalize_mode(U, itr):
    """Scale every column of a factor matrix to unit L1 norm.

    Args:
        U: 2-D array whose columns are normalized.
        itr: Unused; kept in the signature so existing callers keep working.

    Returns:
        Tuple ``(Unew, lmbda)``. ``lmbda`` holds the L1 norm of each column
        of ``U`` with zero norms replaced by 1, and ``Unew`` is
        ``U / lmbda``. ``U`` itself is not modified.
    """
    # One vectorized call instead of a Python-level norm call per column.
    lmbda = np.linalg.norm(U, ord=1, axis=0)
    # All-zero columns are divided by 1 so they stay zero rather than 0/0.
    lmbda[lmbda == 0] = 1
    return (U / lmbda, lmbda)

def lasso_update(U, lam1):
    """Apply element-wise soft-thresholding to ``U``.

    Every entry is shrunk toward zero by ``lam1``; entries whose magnitude
    does not exceed ``lam1`` become zero.

    Args:
        U: Array of values to shrink.
        lam1: Threshold subtracted from each magnitude.

    Returns:
        New array ``sign(U) * max(|U| - lam1, 0)``.
    """
    shrunk_magnitude = np.maximum(np.abs(U) - lam1, 0)
    return np.sign(U) * shrunk_magnitude

def group_lasso_update(grouped_elems, lamda2):
    """Shrink a group of coefficients jointly by their L2 norm.

    The group's norm is reduced by ``lamda2 / 2`` (floored at zero) and the
    elements are rescaled to that reduced norm.

    Args:
        grouped_elems: Array of coefficients forming one group.
        lamda2: Group penalty; half of it is subtracted from the norm.

    Returns:
        The rescaled array, or the scalar ``0`` when the reduced norm is zero.
    """
    magnitude = np.linalg.norm(grouped_elems)
    shrunk = np.maximum(magnitude - lamda2 / 2, 0)
    if shrunk == 0:
        # The whole group is eliminated.
        return 0
    return grouped_elems * shrunk / magnitude

def identity_tensor(dim, order):
    """Build a dense identity (superdiagonal) tensor.

    Args:
        dim: Size of every mode.
        order: Number of modes.

    Returns:
        Float array of shape ``(dim,) * order`` with ones at the positions
        ``(i, i, ..., i)`` and zeros everywhere else.
    """
    I = np.zeros((dim,) * order)
    diag = np.arange(dim)
    # Write the superdiagonal directly instead of building a sparse tensor
    # and densifying it; this also works when dim is 0.
    I[(diag,) * order] = 1
    return I

def sptensor_diff(X, other, is_sptensor=False):
    """Subtract ``other`` from ``X`` at the stored entries of ``X`` only.

    Args:
        X: Sparse tensor exposing ``subs`` (one index sequence per mode),
            ``vals`` (the stored values) and ``shape``.
        other: Object that can be indexed with a tuple holding one index per
            mode, e.g. a dense array.
        is_sptensor: If truthy, return the residual as an ``sptensor``;
            otherwise return a dense array.

    Returns:
        The residual ``X - other`` evaluated at the stored positions of
        ``X``. The dense form is zero at every other position; the sparse
        form reuses the subscripts and shape of ``X``.
    """
    subs = X.subs
    vals = X.vals
    # zip(*subs) yields one full index tuple per stored entry.
    entries = zip(zip(*subs), vals)

    # The flag is tested by truthiness everywhere, so values such as 0 or
    # None select the dense path instead of leaving ``residual`` undefined.
    if is_sptensor:
        r_vals = [val - other[idx] for idx, val in entries]
        # Keep the shape of X; otherwise it would be inferred from the
        # largest subscript and could come out smaller.
        return sptensor(subs, r_vals, shape=X.shape)

    residual = np.zeros(X.shape)
    for idx, val in entries:
        residual[idx] = val - other[idx]
    return residual

def split_by_segments(X, segments):
    """Cut ``X`` into consecutive slices delimited by boundary positions.

    Args:
        X: Sliceable sequence or array.
        segments: Iterable of boundary indices; every adjacent pair
            ``(i, j)`` produces the slice ``X[i:j]``.

    Returns:
        List with one slice per adjacent pair of boundaries.
    """
    return [X[start:stop] for start, stop in pairwise(segments)]

def split_by_interval(X, interval):
    """Split ``X`` into consecutive chunks of exactly ``interval`` items.

    Trailing items that do not fill a whole chunk are dropped.

    Args:
        X: Sliceable sequence or array supporting ``len``.
        interval: Chunk length.

    Returns:
        List of slices ``X[k:k + interval]`` for ``k = 0, interval, ...``.
    """
    last_start = len(X) - interval
    return [
        X[start:start + interval]
        for start in range(0, last_start + 1, interval)
    ]

def bucketize(X, bit):
    """Quantize ``X`` down to multiples of a uniform bucket width.

    The width is ``(max|X| - min|X|) / 2 ** (bit - 1)``. Each value is
    replaced by the largest multiple of the width that does not exceed it
    (``np.floor``, so negative values round toward minus infinity).

    Args:
        X: Array of values to quantize.
        bit: Bit budget; one bit is reserved for the sign, leaving
            ``2 ** (bit - 1)`` buckets for the magnitude range.

    Returns:
        Float array of quantized values. When the bucket width is zero
        (all magnitudes equal) the values are returned unchanged as floats.
    """
    X = np.asarray(X)
    magnitudes = np.abs(X)
    diff = np.max(magnitudes) - np.min(magnitudes)
    ## sign uses 1 bit
    bunit = diff / (2 ** (bit - 1))

    if bunit == 0:
        # Nothing to quantize; dividing by a zero width would only yield
        # NaN/inf for every element.
        return np.array(X, dtype=float)

    return np.floor(X / bunit) * bunit
