# Magnitude Recurrent Neural Networks for Speech Denoising by A-Priori SNR
# Estimation and Wiener Filtering
# Author: Raymond Xia (yangyanx@andrew.cmu.edu)
#
# Implemented in PyTorch

import numpy as np
import torch
from torch.autograd import Variable
import torch.optim as optim
import torch.nn as nn
import time
import operator
import pdb

DEBUG = 0  # set non-zero to drop into pdb whenever a NaN shows up in the recursion

# Hyperparameters
CONFIG_MRNN = dict(
    FREQ_DIM=257,          # frequency dimension of the input spectrogram
    ASNR_MU=.98,           # default weighting factor for the noise PSD estimate
    ASNR_ETA=.15,          # default VAD threshold
    ASNR_DELTA=.05,        # delta for the (optional) piecewise-linear mu activation
    COST_WEIGHT=.2,        # weight of the reconstruction MSE vs. the VAD cost
    LEARNING_RATE=1e-4,    # SGD learning rate
    MOMENTUM=.9,           # SGD momentum
)

eps = 1e-16  # small constant that keeps divisions and logs away from zero

class MRNN_ASNR(torch.nn.Module):
    """Recurrent a priori SNR estimator for magnitude-spectrogram denoising.

    MRNN_ASNR implements a recurrent neural network version of a priori SNR
    estimation for signal denoising. Each call to ``forward`` advances one
    frame. It updates the a posteriori SNR, the a priori SNR and a
    speech-presence log-likelihood ratio. It then scales the noisy magnitude
    by the Wiener gain priori / (1 + priori).

    Learnable parameters. ``E*``/``V*`` are [F x F] matrices and ``b*`` are
    length-F biases, with F = CONFIG_MRNN['FREQ_DIM']:
        E11, V11, bE11, bV11 - two-layer ReLU map producing mu1(t)
        E12, V12, bE12, bV12 - two-layer ReLU map producing mu2(t)
        E21, V21, bE21, bV21 - two-layer ReLU map producing alpha1(t)
        E22, V22, bE22, bV22 - two-layer ReLU map producing alpha2(t)

    Args:
        params: optional mapping from each of the 16 names above to a numpy
            array of the listed shape (e.g. the output of ``get_params``).
            When None, matrices start as identity and biases as zero.
        dtype: 'gpu' selects torch.cuda.FloatTensor when CUDA is available.
            Any other value, or no CUDA, selects torch.FloatTensor.
    """

    # Names of all learnable parameters, in registration order. Names that
    # start with 'b' are length-F biases; the rest are [F x F] matrices.
    _PARAM_NAMES = ('E11', 'V11', 'bE11', 'bV11',
                    'E12', 'V12', 'bE12', 'bV12',
                    'E21', 'V21', 'bE21', 'bV21',
                    'E22', 'V22', 'bE22', 'bV22')

    def __init__(self, params=None, dtype=None):
        super(MRNN_ASNR, self).__init__()

        # Fixed (non-learnable) hyperparameters
        self.freq_dim = CONFIG_MRNN['FREQ_DIM']
        self.eta = CONFIG_MRNN['ASNR_ETA']
        self.mu = CONFIG_MRNN['ASNR_MU']
        self.alpha = CONFIG_MRNN['COST_WEIGHT']
        self.delta = CONFIG_MRNN['ASNR_DELTA']
        if dtype == 'gpu' and torch.cuda.is_available():
            self.dtype = torch.cuda.FloatTensor # gpu mode
            print("GPU mode.")
        elif dtype == 'gpu':
            self.dtype = torch.FloatTensor
            print("CUDA not available. Use CPU instead.")
        else:
            self.dtype = torch.FloatTensor # cpu mode
            print("CPU mode.")

        if params is not None:
            self.__check_params__(params)

        # Initialize learnable parameters. setattr() goes through
        # nn.Module.__setattr__. Each nn.Parameter is therefore registered
        # exactly as a direct ``self.E11 = nn.Parameter(...)`` would be.
        for name in self._PARAM_NAMES:
            if params is None:
                init = self._default_init(name)
            else:
                init = torch.from_numpy(params[name])
            setattr(self, name,
                    nn.Parameter(init.type(self.dtype), requires_grad=True))

    def _default_init(self, name):
        # Default starting value: zeros for biases, identity for matrices.
        if name.startswith('b'):
            return torch.zeros(self.freq_dim)
        return torch.eye(self.freq_dim)

    def __check_params__(self,params):
        """Assert every entry of `params` has the shape its name requires."""
        for name in self._PARAM_NAMES:
            if name.startswith('b'):
                expected = (self.freq_dim,)
            else:
                expected = (self.freq_dim, self.freq_dim)
            assert params[name].shape == expected, \
                'parameter {} has shape {}, expected {}'.format(
                    name, params[name].shape, expected)

    def get_params(self):
        """Return all learnable parameters as a dict of CPU numpy arrays."""
        return dict((name, getattr(self, name).cpu().data.numpy())
                    for name in self._PARAM_NAMES)

    def init_hidden(self,x):
        """Build the t=-1 hidden state from the noisy input x.

        x: [N x T x F] magnitude spectrogram (N examples, T frames, F bins).
        Returns (x_tm1, posteri_tm1, priori_tm1, llk_ratio_tm1), each [N x F]:
            x_tm1         - x(-1), taken to be the first frame x(0)
            posteri_tm1   - a posteriori SNR at t=-1. The noise PSD is taken
                            as the squared mean of the first 5 frames.
            priori_tm1    - a priori SNR at t=-1
            llk_ratio_tm1 - speech-presence log-likelihood ratio at t=-1
        """
        x_tm1 = x[:,0,:] # simply take x(-1) = x(0)
        Pxx_tm1 = x_tm1 **2
        Pnn_tm1 = torch.mean(x[:,:5,:],1)**2 #average first 5frames
        posteri_tm1 = Pxx_tm1/(Pnn_tm1+eps)
        priori_tm1 = self.mu+(1-self.mu)*torch.max(posteri_tm1-1,
                                                 torch.zeros_like(posteri_tm1))
        # NOTE(review): this initial value is 1 - log(1+priori). forward()
        # uses priori*posteri/(1+priori) - log(1+priori). Confirm this
        # difference is intended.
        llk_ratio_tm1 = 1 - torch.log1p(priori_tm1)
        return x_tm1,posteri_tm1,priori_tm1,llk_ratio_tm1

    def forward(self,x_tm1,x_t,posteri_tm1,priori_tm1,llk_ratio_tm1):
        """Advance the recursion by one time frame.

        All inputs and outputs are [N x F] tensors (N = batch size,
        F = frequency dimension):
            x_tm1, x_t    - previous / current noisy magnitude frame
            posteri_tm1   - a posteriori SNR at t-1
            priori_tm1    - a priori SNR at t-1
            llk_ratio_tm1 - speech-presence log-likelihood ratio at t-1
        Returns (y_t, x_t, posteri_t, priori_t, llk_ratio_t). y_t is the
        Wiener-filtered output frame. x_t is passed through so the caller
        can feed it back as the next step's x_tm1.
        """
        if DEBUG:
            if check_nan(llk_ratio_tm1):
                pdb.set_trace()

        if DEBUG:
            if check_nan(self.E11):
                pdb.set_trace()

        # G1: Estimate mu(t) from llk_ratio(t-1) via a sigmoid mapped to [mu, 1]
        mu_hat_t = self.mu+(1-self.mu)/(1+torch.exp(-(llk_ratio_tm1-self.eta)))
        # Alternatively, use linear activation below
        #mu_hat_t = torch.clamp((1-self.mu)/self.delta*(llk_ratio_tm1-(self.eta-self.delta))+self.mu,min=self.mu,max=1.)
        z11 = torch.clamp(torch.matmul(mu_hat_t,self.E11) + self.bE11,min=0.)
        mu1_t = torch.clamp(torch.matmul(z11,self.V11) + self.bV11,min=0.)
        z12 = torch.clamp(torch.matmul((1-mu_hat_t),self.E12) + self.bE12,min=0.)
        mu2_t = torch.clamp(torch.matmul(z12,self.V12) + self.bV12,min=0.)

        # a posteriori SNR Estimation.
        # mu1_t and mu2_t are ReLU outputs and can both be 0. eps keeps the
        # ratio finite instead of producing inf (and NaN downstream).
        Pnn_tm1_over_Pnn_t = 1/(mu1_t + mu2_t*posteri_tm1 + eps)
        Pxx_tm1 = x_tm1 ** 2 # previous noisy PSD
        Pxx_t = x_t ** 2 # current noisy PSD
        posteri_t = Pnn_tm1_over_Pnn_t * posteri_tm1 * (Pxx_t/(Pxx_tm1+eps))

        if DEBUG:
            if check_nan(posteri_t):
                pdb.set_trace()

        posteri_prime_t = torch.clamp(posteri_t - 1,min=0.) # prevent negative
        # Adaptive alternative (currently disabled):
        #alpha_hat_t = 1/(1+((posteri_prime_t-priori_tm1)/(posteri_prime_t+1))**2)
        alpha_hat_t = torch.zeros_like(posteri_prime_t)+.98 # use constant instead of adaptation method

        # G2: Estimate alpha(t) from alpha_hat(t); ReLU clipping keeps both >= 0
        z21 = torch.clamp(torch.matmul(alpha_hat_t,self.E21) + self.bE21,min=0.)
        alpha1_t = torch.clamp(torch.matmul(z21,self.V21) + self.bV21,min=0.)
        z22 = torch.clamp(torch.matmul((1-alpha_hat_t),self.E22) + self.bE22,min=0.)
        alpha2_t = torch.clamp(torch.matmul(z22,self.V22) + self.bV22,min=0.)
        # A Priori SNR Estimation
        priori_t = alpha1_t*(priori_tm1*posteri_tm1/(1+priori_tm1)) + \
               alpha2_t*posteri_prime_t

        if DEBUG:
            if check_nan(priori_t):
                pdb.set_trace()

        # Finally, calculate the likelihood ratio and the gain of the current
        # frame, and the output frame.
        #G_t = torch.sqrt(priori_t/(1+priori_t)) # spectral subtraction solution
        G_t = priori_t/(1+priori_t) # Wiener optimal solution
        llk_ratio_t = priori_t*posteri_t/(1+priori_t) - torch.log1p(priori_t)
        y_t = x_t * G_t

        return y_t, x_t, posteri_t, priori_t, llk_ratio_t

    def denoise(self,x):
        """Run the recursion over every frame of x ([N x T x F]).

        Returns (y, llk_ratio), both [N x T x F]. y holds the denoised
        magnitudes; llk_ratio holds the per-bin speech-presence
        log-likelihood ratio of each frame. Hidden states are detached after
        each step, so gradients do not propagate across time steps.
        """
        y = torch.zeros_like(x) # hold denoised output
        llk_ratio = torch.zeros_like(x) # hold llk_ratio for all cells
        x_tm1,posteri_tm1,priori_tm1,llk_ratio_tm1 = self.init_hidden(x)
        for t in range(x.size(1)): # loop through time dimension
            y_t, x_tm1, posteri_tm1, priori_tm1, llk_ratio_tm1 =\
                self.forward(x_tm1,x[:,t,:],posteri_tm1,priori_tm1,llk_ratio_tm1)
            y[:,t,:] = y_t
            llk_ratio[:,t,:] = llk_ratio_tm1
            # Detach node so that gradients do not backpropagate through time
            x_tm1, posteri_tm1, priori_tm1, llk_ratio_tm1 =\
                self.repackage_hidden((x_tm1, posteri_tm1, priori_tm1, llk_ratio_tm1))
        return y,llk_ratio

    def get_cost(self,y,target,llk_ratio,vad):
        """
        Calculate cost from a batch output.
        y        : denoised magnitude spectrogram
        target   : target magnitude spectrogram
        llk_ratio: speech presence log-likelihood ratio inferred from a priori SNR
        vad      : true speech/nonspeech label (0 for nonspeech; 1 for speech)
        Returns (E_recon, E_vad, cost), where
        cost = alpha * E_recon + (1 - alpha) * E_vad.
        """
        # Error 1: magnitude spectrogram MSE
        E_recon = self.recon_error(target,y,option='mse')
        # Error 2: logistic cross-entropy between the frame-level (mean over
        # frequency) log-likelihood ratio and the true speech/nonspeech label.
        # Clipping at 10 keeps exp() below from overflowing.
        llk_ratio_mean = torch.clamp(torch.mean(llk_ratio,dim=-1),max=10.)
        E_vad = torch.mean(-vad * llk_ratio_mean + \
                            torch.log1p(torch.exp(llk_ratio_mean)))

        # Total cost (could add regularization here)
        cost = self.alpha * E_recon + (1-self.alpha) * E_vad

        return E_recon,E_vad,cost

    def recon_error(self,y,yhat,option='mse'):
        """
        A few cost functions between the clean magnitude spectrogram y and
        its estimate yhat, averaged over all time-frequency bins.
        option: 'mse' (default), 'is' (Itakura-Saito on power), 'log'
                (log-magnitude MSE) or 'wmse' (MSE weighted by 1/y).
        Raises ValueError for any other option.
        """
        if option == 'mse':
            # Simple mean squared error
            E_recon = torch.mean((y-yhat)**2)
        elif option == 'is':
            # Itakura-Saito measure
            # UNSTABLE because of division!!!
            Py = y**2
            Pyhat = yhat**2
            E_recon = torch.mean((Py/(Pyhat+eps))-torch.log(Py+eps)+torch.log(Pyhat+eps)-1)
        elif option == 'log':
            # log magnitude MSE
            E_recon = torch.mean((torch.log(y+eps)-torch.log(yhat+eps))**2)
        elif option == 'wmse':
            # weighted MSE d(y,yhat) = (y-yhat)**2/y
            # puts more emphasis on small amplitude than big amplitude
            E_recon = torch.mean((y-yhat)**2/(y+eps))
        else:
            raise ValueError('Error type not supported!')
        return E_recon

    def repackage_hidden(self,h):
        """Detach hidden state(s) from their autograd history.

        h is a single Variable/tensor or an arbitrarily nested tuple/list of
        them. The same nesting is returned, with sequences as tuples.
        isinstance() is used because on PyTorch >= 0.4 hidden states are
        plain torch.Tensor objects, for which ``type(h) == Variable`` fails.
        """
        if isinstance(h, (Variable, torch.Tensor)):
            return h.detach()
        return tuple(self.repackage_hidden(v) for v in h)

class MRNNEngine(object):
    """Training / inference wrapper around an MRNN_ASNR model.

    Owns the model and an SGD optimizer. Converts numpy arrays to torch
    tensors on input and back to numpy on output.
    """

    def __init__(self,param_path=None,learning_rate=0.001,momentum=0.9,mode='cpu'):
        """
        param_path   : optional path to parameters written by `save_params`.
                       When None, a freshly initialized model is created.
        learning_rate: SGD learning rate.
        momentum     : SGD momentum.
        mode         : forwarded to MRNN_ASNR as `dtype` ('gpu' requests CUDA).
        """
        self.num_seen = 0 # remember number of examples passed
        if param_path is not None: # load from file
            params = self.load_params(param_path)
            print("Loading pre-trained parameters from [{}].".format(param_path))
            self.model = MRNN_ASNR(params=params,dtype=mode)
        else:
            print("Starting a blank model.")
            self.model = MRNN_ASNR(dtype=mode)
        # TODO: What kind of optimizer should I use?
        # TODO: Dropout?
        # TODO: Back propagation through time?
        self.optimizer = optim.SGD(self.model.parameters(),
                                    lr=learning_rate,
                                    momentum=momentum)

    @staticmethod
    def _to_scalar(v):
        """Return the single value held by tensor/Variable `v` as a numpy scalar.

        Depending on the PyTorch version, a full reduction yields either a
        1-element or a 0-dim array. reshape(-1)[0] handles both.
        """
        return v.cpu().data.numpy().reshape(-1)[0]

    def train_with_sgd(self,X_train,T_train,VAD):
        """Run one SGD step on a batch.

        X_train: [N x T x F] noisy magnitude spectrograms (numpy)
        T_train: [N x T x F] target magnitude spectrograms (numpy)
        VAD    : [N x T] speech (1) / nonspeech (0) labels (numpy)
        Returns (E_recon, E_vad, cost, vad_err). The first three are numpy
        scalars; vad_err is the VAD error percentage from `vad_error`.
        """
        assert X_train.shape == T_train.shape
        assert VAD.shape == X_train.shape[:2]
        idx = np.random.permutation(X_train.shape[0]) # randomize feature order
        X = Variable(torch.from_numpy(X_train[idx]).type(self.model.dtype),requires_grad=False)
        T = Variable(torch.from_numpy(T_train[idx]).type(self.model.dtype),requires_grad=False)
        V = Variable(torch.from_numpy(VAD[idx]).type(self.model.dtype),requires_grad=False)

        # Forward pass
        Y,LLK = self.model.denoise(X)
        # Calculate error
        E_recon,E_vad,cost = self.model.get_cost(Y,T,LLK,V)
        # Calculate gradients and update network weights
        self.optimizer.zero_grad()
        cost.backward()
        self.optimizer.step()
        self.num_seen += len(idx)
        vad_err = self.vad_error(LLK,V,threshold=.15)
        return (self._to_scalar(E_recon), self._to_scalar(E_vad),
                self._to_scalar(cost), vad_err)

    def vad_error(self,llk_ratio,vad,threshold=.15):
        """
        Calculate the percentage of frames with a wrong VAD decision.

        A frame is classified as speech when its llk ratio, averaged over the
        last (frequency) axis and capped at 10, exceeds `threshold`. Both
        llk_ratio and vad are torch tensors/Variables.
        """
        LLK = llk_ratio.cpu().data.numpy()
        VAD = vad.cpu().data.numpy()
        LLK_mean = np.minimum(np.mean(LLK,axis=-1),10.)
        decision_wrong = np.logical_xor(LLK_mean>threshold,VAD==1)
        return np.sum(decision_wrong) * 100. / VAD.size

    def denoise(self,x_noisy):
        """Denoise one [T x F] magnitude spectrogram; returns a [T x F] numpy array."""
        X = Variable(torch.from_numpy(np.array([x_noisy])).type(self.model.dtype),
                     requires_grad=False)
        y, _ = self.model.denoise(X) # takes in tensor only
        return y.cpu().data[0].numpy()

    def get_params(self):
        """Return the model parameters as a dict of numpy arrays."""
        return self.model.get_params()

    def load_params(self, path):
        """Load parameters saved by `save_params`.

        For an .npz archive, the arrays are copied into a plain dict and the
        underlying file handle is closed. Anything else np.load returns (e.g.
        a bare .npy array) is passed through unchanged.
        """
        data = np.load(path)
        if not hasattr(data, 'files'):
            return data
        try:
            return dict((name, data[name]) for name in data.files)
        finally:
            data.close()

    def save_params(self, path):
        """Save every model parameter into an .npz archive, keyed by name."""
        np.savez(path, **self.get_params())

def check_nan(m):
    """Return True when tensor/Variable `m` holds at least one NaN entry."""
    nan_mask = np.isnan(m.cpu().data.numpy())
    return bool(nan_mask.any())


if __name__ == '__main__':
    # Smoke test: one SGD step on a small synthetic batch whose target equals
    # its input (10 examples x 10 frames x 257 frequency bins).
    engine = MRNNEngine()
    X = np.ones((10,10,257))
    for i in range(10):
        X[i] += i  # give each example a different constant level
    T = X.copy()
    VAD = np.zeros((10,10))  # every frame labelled nonspeech
    # train_with_sgd returns (E_recon, E_vad, cost, vad_err)
    E_recon,E_vad,cost,vad_err = engine.train_with_sgd(X,T,VAD)
    print(' '.join(str(v) for v in (E_recon, E_vad, cost, vad_err)))
