
##########
## Version: Apr, 2017
## Author: Tsubasa Takahashi
## Note: This work was done during the author's visit to Carnegie Mellon University.
##########

import numpy as np
from numpy.linalg import svd
from scipy.sparse.linalg import svds
from scipy.sparse import csr_matrix

def ropca_solver_minrank(X, rank, max_iter=10, lamda_nu=1.0, lamda2=1.0):
    """
    Robust PCA via alternating ridge-regularised least squares.

    Each iteration re-estimates the mean, the loadings U, the scores S and
    the row-wise outlier matrix O in turn.  S is initialised with
    np.random.rand, so results depend on NumPy's global random state.

    input
    -----
    X: observed matrix (N x p)
    rank: upper bound of rank for low rank approximation
    max_iter: max number of iteration (must be >= 1)
    lamda_nu: ridge weight; (lamda_nu / 2) * I is added to the Gram
              matrices S^T S and U^T U before they are inverted
    lamda2: sparsity control parameter; a row of O is non-zero only when
            the l2 norm of its residual exceeds lamda2 / 2

    output
    ------
    mean: mean vector (p x 1)
    U: loadings (p x rank)
    S: scores / principal components (N x rank)
    O: outliers (N x p)

    raises
    ------
    ValueError: if max_iter < 1
    """
    if max_iter < 1:
        raise ValueError("max_iter must be at least 1")

    X = np.asarray(X, dtype=float)
    N, p = X.shape
    O = np.zeros((N, p))
    S = np.random.rand(N, rank)

    # The ridge term belongs on the diagonal only; adding a bare scalar
    # would shift every entry of the Gram matrix and can make it singular.
    ridge = (lamda_nu / 2.0) * np.eye(rank)
    threshold = lamda2 / 2.0

    for _ in range(max_iter):
        # update mean from the outlier-compensated data
        mean = (X - O).mean(axis=0).reshape(p, 1)

        # update Xo (centred and outlier-compensated data)
        centered = X - mean.T
        Xo = centered - O

        # update U = Xo^T S (S^T S + ridge)^-1, solved without forming
        # the inverse (the regularised Gram matrix is symmetric)
        U = np.linalg.solve(np.dot(S.T, S) + ridge, np.dot(S.T, Xo)).T

        # update S = Xo U (U^T U + ridge)^-1
        S = np.linalg.solve(np.dot(U.T, U) + ridge, np.dot(U.T, Xo.T)).T

        # update O: group soft-thresholding of each residual row.  The
        # residual is taken about the mean, matching ropca_solver.
        R = centered - np.dot(S, U.T)
        row_norm = np.linalg.norm(R, axis=1)
        shrunk = np.maximum(row_norm - threshold, 0)
        # Rows with zero residual get scale 0 instead of 0/0 = NaN.
        scale = np.divide(shrunk, row_norm,
                          out=np.zeros_like(row_norm),
                          where=row_norm > 0)
        O = R * scale[:, np.newaxis]

    return (mean, U, S, O)

def ropca_solver(X, rank, max_iter=10, lamda2=1.0):
    """
    Robust PCA with an orthonormal basis, solved by alternating updates.

    Each iteration re-estimates the mean, projects the centred and
    outlier-compensated data onto the current basis U to get the scores S,
    refreshes U from an SVD of Xo^T S, and finally updates the row-wise
    outlier matrix O by group soft-thresholding of the residual rows.
    The returned S is the one computed before U's last refresh.

    input
    -----
    X: observed matrix (N x p); NaN entries are replaced by zeros
    rank: rank for low rank approximation (ignored when X contains NaN,
          see the note in the code)
    max_iter: max number of iteration (must be >= 1)
    lamda2: sparsity control parameter; a row of O is non-zero only when
            the l2 norm of its residual exceeds lamda2 / 2

    output
    ------
    mean: mean vector (p x 1)
    U: basis (p x rank)
    S: scores / principal components (N x rank)
    O: outliers (N x p)

    raises
    ------
    ValueError: if max_iter < 1
    """
    if max_iter < 1:
        raise ValueError("max_iter must be at least 1")

    is_sparse = False
    XX = np.copy(X)
    if np.isnan(X).any():
        # Missing values are zero-filled and the data is held as a CSR matrix.
        XX = csr_matrix(np.nan_to_num(XX))
        is_sparse = True
        # NOTE(review): this discards the caller-supplied rank whenever the
        # input has NaNs -- kept for backward compatibility; confirm intent.
        rank = min(XX.shape)

    N, p = XX.shape
    U = np.ones((p, rank))
    O = np.zeros((N, p))
    one_vec1 = np.ones((N, 1)) / N
    one_vec2 = np.ones((N, 1))
    threshold = lamda2 / 2.0

    for _ in range(max_iter):
        # update mean from the outlier-compensated data
        mean = np.dot((XX - O).T, one_vec1)
        centered = XX - np.dot(one_vec2, mean.T)

        if is_sparse:
            Xo = csr_matrix(centered - O)
            S = Xo.dot(U)
            # NOTE(review): only rank-1 singular triplets are requested
            # here (svds needs k < min(shape)) -- confirm this is intended.
            L, sig, VT = svds(Xo.transpose().dot(S), k=rank - 1)
        else:
            Xo = centered - O
            S = np.dot(Xo, U)
            L, sig, VT = svd(np.dot(Xo.T, S), full_matrices=False)
        # update U from the singular vectors of Xo^T S
        U = np.dot(L, VT)

        # update O: group soft-thresholding of each residual row
        R = centered - np.dot(S, U.T)
        row_norm = np.linalg.norm(R, axis=1)
        shrunk = np.maximum(row_norm - threshold, 0)
        # Rows with zero residual get scale 0 instead of 0/0 = NaN.
        scale = np.divide(shrunk, row_norm,
                          out=np.zeros_like(row_norm),
                          where=row_norm > 0)
        # np.multiply stays elementwise even when R is an np.matrix
        # (which is what the NaN/sparse path produces).
        O = np.multiply(R, scale[:, np.newaxis])

    return (mean, U, S, O)
