# Training Procedure for MRNN
# Author: Raymond Xia (yangyanx@andrew.cmu.edu)

# MRNN core and configuration
from config import *
#import theano
#from mrnn_asnr import MRNNTheano, MRNNEngine
from mrnn_asnr_pytorch import *
from rats_util import SampleGenerator, MRNNLogger

# Audio utilities
from batch_io import AudioTransformer
from util import *
from spectrogram import *

# Other general utilities
import argparse
import subprocess
import sys
import os
import time
import numpy as np
import itertools
from datetime import datetime

# Configurations: per-experiment sections of the global `config` mapping
# (not defined in this file — presumably provided by `from config import *`).
# cfg_white holds the white-noise / TIMIT experiment settings; cfg_rats holds
# the settings read by the RATS and additive-noise training routines.
cfg_white = config['white_noise']
cfg_rats  = config['rats']
#floatX = theano.config.floatX # config to float32 for GPU processing
# Flatten one level of nesting, e.g. [[1, 2], [3]] -> [1, 2, 3].
flatten = lambda l: [item for sublist in l for item in sublist]
# Magnitude-spectrogram extractor for a time-domain signal `x`, configured
# from the STFT settings in cfg_rats. `magphase(...)[0]` is presumably the
# magnitude part of a (magnitude, phase) result — confirm in spectrogram.py.
# NOTE(review): `magspec` is not referenced anywhere in this file; the
# training routines call `mag_spec` (not defined here, presumably from one of
# the star imports). Confirm whether this helper is still needed.
if cfg_rats['ZERO_PHASE_STFT']:
    # Uses `stft`; the last element of its return value is fed to magphase.
    magspec = lambda x: magphase(stft(x, cfg_rats['SAMPLE_RATE'],
            window_length=cfg_rats['WINDOW_LENGTH'],
            hop_fraction=cfg_rats['HOP_FRACTION'],
            nfft=cfg_rats['FFT_SIZE'],
            stft_truncate=True)[-1])[0]
else:
    # Uses `stft_old`, which takes the window length as a sample count:
    # WINDOW_LENGTH * SAMPLE_RATE (presumably WINDOW_LENGTH is in seconds —
    # TODO confirm).
    magspec = lambda x: magphase(stft_old(x,
            int(cfg_rats['WINDOW_LENGTH']*cfg_rats['SAMPLE_RATE']),
            cfg_rats['HOP_FRACTION'],
            truncate=True))[0]

def train_white_noise(start_with=None):
    """
    Train an MRNN to suppress additive white noise on TIMIT speech.

    Fixed-length segments are cut from every file in cfg_white['TRAIN_DIR'],
    corrupted with `add_noise_rand`, turned into magnitude spectrograms with
    `mag_spec`, and passed (noisy input, clean target) to the engine's SGD
    trainer together with a validation set built the same way.

    Args:
        [start_with] - optional parameter path forwarded to MRNNEngine as
                       `param_path`; None builds a fresh model.
    """
    train_dir = cfg_white['TRAIN_DIR']
    print("Loading data from %s ..." % train_dir)
    corpus = AudioTransformer(train_dir, cfg_white['SAMPLE_RATE'],
                              mono=True, verbose=True)

    # Each example spans 0.32 s of audio (the original note: ~20 frames).
    seg_length = int(cfg_white['SAMPLE_RATE'] * .32)
    seg_num = 10  # training examples drawn per audio file
    clean_train = []
    clean_val = []
    for audio, _path in corpus:
        # NOTE(review): validation segments are drawn from the same audio as
        # the training segments, so the two sets may overlap.
        clean_train.extend(sample(audio, seg_length, seg_num))
        clean_val.extend(sample(audio, seg_length, 2))

    print("Total number of training examples: [{}]".format(len(clean_train)))
    print("Total number of validation examples: [{}]".format(len(clean_val)))

    # Noise is applied to the whole training set first, then validation.
    noisy_train = [add_noise_rand(seg) for seg in clean_train]
    noisy_val = [add_noise_rand(seg) for seg in clean_val]

    spec_train = np.array([mag_spec(seg) for seg in clean_train])
    spec_train_noisy = np.array([mag_spec(seg) for seg in noisy_train])
    spec_val = np.array([mag_spec(seg) for seg in clean_val])
    spec_val_noisy = np.array([mag_spec(seg) for seg in noisy_val])
    assert spec_train.shape == spec_train_noisy.shape
    assert spec_val.shape == spec_val_noisy.shape
    print("Training: {}, Validation: {}".format(spec_train.shape, spec_val.shape))

    print("Start building MRNN.")
    engine = MRNNEngine(param_path=start_with)
    print("MRNN successfully built.")
    engine.train_with_sgd(spec_train_noisy, spec_train,
                          spec_val_noisy, spec_val, 'model/',
                          learning_rate=cfg_white['LEARNING_RATE'],
                          nepoch=cfg_white['NEPOCH'],
                          decay=0.9,
                          callback_every=cfg_white['PRINT_EVERY'])


def train_additive_noise(outdir, start_with=None,
                         noise_path='/Users/xyy/local/afeka/deliverable1/restaurant_noise.wav',
                         train_dir='/Volumes/Data/LDC2015S02/RATS_SAD/data/dev-2/audio/src/',
                         save_every=10):
    """
    Train an MRNN on clean speech corrupted with synthetic additive noise.

    Every clean segment is used twice: once with random white noise added and
    once mixed with the recording at `noise_path`. Audio files from
    `train_dir` are streamed in random order in minibatches of
    cfg_rats['MINIBATCH_SIZE'] files. Parameters are saved to `outdir` every
    `save_every` minibatches.

    Args:
        outdir       - directory where parameter snapshots are written
        [start_with] - parameter path forwarded to MRNNEngine (None = fresh)
        [noise_path] - recording used as additive background noise
        [train_dir]  - directory of clean training audio
        [save_every] - number of minibatches between snapshots (must be > 0)
    """
    # AudioTransformer yields (signal, path) pairs; keep the first signal.
    # NOTE(review): loaded at a fixed 16000 Hz, independent of cfg_rats.
    rnoise = list(AudioTransformer(noise_path, 16000, mono=True))[0][0]

    def compose_features(batch):
        """
        Compose clean/noisy magnitude-spectrogram pairs for one minibatch.

        Each clean segment appears twice in the clean list, aligned with:
            * the segment plus random white noise
            * the segment plus the `noise_path` background recording
        Returns (clean_features, noisy_features) as numpy arrays.
        """
        seg_length = int(cfg_rats['SAMPLE_RATE'] * cfg_rats['EXAMPLE_LENGTH'])
        num_examples_per_segment = cfg_rats['EXAMPLE_SIZE']
        TRAIN_clean, TRAIN_noisy = [], []
        for x, p in batch:
            # take num_examples_per_segment `seg_length`-sample segments
            xref = sample(normalize(x), seg_length, num_examples_per_segment)
            # distinct names avoid rebinding `x` (comprehensions leak in py2)
            xref_white = [add_white_noise_rand(seg) for seg in xref]
            xref_rnoise = [add_noise(seg, rnoise) for seg in xref]
            # clean segments are listed twice to stay aligned with both
            # noisy variants below
            TRAIN_clean.extend(xref)
            TRAIN_clean.extend(xref)
            TRAIN_noisy.extend(xref_white)
            TRAIN_noisy.extend(xref_rnoise)
        TRAIN_M = np.array([mag_spec(seg) for seg in TRAIN_clean])
        TRAIN_M_noisy = np.array([mag_spec(seg) for seg in TRAIN_noisy])
        return TRAIN_M, TRAIN_M_noisy

    print("Start building MRNN.")
    engine = MRNNEngine(cfg_rats['FREQ_DIM'], cfg_rats['HIDDEN_DIM'],
                        cfg_rats['BPTT_TRUNCATE'], param_path=start_with)
    print("MRNN successfully built.")

    minibatch_size = cfg_rats['MINIBATCH_SIZE']  # audio files per minibatch
    # Report the directory that is actually read.
    print("Loading data from {}...".format(train_dir))
    # repeat=20: presumably 20 passes over the directory — confirm in batch_io.
    TRAIN = AudioTransformer(train_dir, cfg_white['SAMPLE_RATE'], mono=True,
                             random_order=True, repeat=20)

    batch_seen = 0
    for batch in grouper(TRAIN, minibatch_size):
        TRAIN_M, TRAIN_M_noisy = compose_features(batch)
        assert TRAIN_M.shape == TRAIN_M_noisy.shape
        train_loss = engine.train_with_sgd(TRAIN_M_noisy, TRAIN_M,
                                           learning_rate=cfg_rats['LEARNING_RATE'],
                                           decay=cfg_rats['RMSPROP_DECAY'])
        dt = datetime.now().isoformat()
        print("\n%s (Examples:%d)" % (dt, engine.num_seen))
        print("--------------------------------------------------")
        print("Loss: %f" % train_loss)
        sys.stdout.flush()
        batch_seen += 1
        if batch_seen % save_every == 0:
            ts = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
            MODEL_OUTPUT_FILE = "MRNN-%s-%s-%s-%s.dat" % (
                ts, engine.feat_dim, engine.hidden_dim, engine.num_seen)
            print("Saving parameters to %s..." % MODEL_OUTPUT_FILE)
            engine.save_params(os.path.join(outdir, MODEL_OUTPUT_FILE))

def train_rats_sre04(outdir, start_with=None,
                     filemap_path="/Volumes/Data/LDC2012E40/data/LDC2012E40/docs/filemap.tsv"):
    """
    Train an MRNN on SRE04 speech that was passed through RATS channels.

    Each noisy file in cfg_rats['TRAIN_DIR'] is paired with its clean source
    in cfg_rats['SRC_DIR'] through the mapping in `filemap_path`. Parameters
    are saved to `outdir` after every epoch.

    Args:
        outdir         - directory where parameter snapshots are written
        [start_with]   - parameter path forwarded to MRNNEngine (None = fresh)
        [filemap_path] - file whose non-blank lines are "<clean> <noisy>",
                         separated by whitespace (spaces or tabs)

    Raises:
        KeyError - a training file has no clean reference in the file map
    """
    # Map noisy file stem -> clean file stem.
    keys = {}
    with open(filemap_path) as fp:
        for line in fp:
            fields = line.split()
            if not fields:  # tolerate blank lines, e.g. a trailing newline
                continue
            clean, noisy = fields
            keys[noisy] = clean

    def compose_rats_features(batch):
        """
        Compose clean/noisy magnitude-spectrogram pairs for one minibatch,
        loading each file's clean reference through the `keys` lookup table.
        Returns (clean_features, noisy_features) as numpy arrays.
        """
        seg_length = int(cfg_rats['SAMPLE_RATE'] * cfg_rats['EXAMPLE_LENGTH'])
        num_examples_per_segment = cfg_rats['EXAMPLE_SIZE']
        TRAIN_clean, TRAIN_noisy = [], []
        for x, p in batch:
            k = os.path.basename(p)[:-7]  # trim the 7-char "_A.flac" suffix
            if k not in keys:
                raise KeyError("No clean reference for %r in %s" % (k, filemap_path))
            p_ref = os.path.join(cfg_rats['SRC_DIR'], keys[k] + '_src.flac')
            x_ref = list(AudioTransformer(p_ref, cfg_rats['SAMPLE_RATE'],
                                          mono=True, verbose=False))[0][0]
            # take num_examples_per_segment paired `seg_length` segments
            xsegs, xsegs_ref = sample_pair(x, x_ref, seg_length,
                                           num_examples_per_segment)
            TRAIN_clean.extend(xsegs_ref)
            TRAIN_noisy.extend(xsegs)
        TRAIN_M = np.array([mag_spec(seg) for seg in TRAIN_clean])
        TRAIN_M_noisy = np.array([mag_spec(seg) for seg in TRAIN_noisy])
        return TRAIN_M, TRAIN_M_noisy

    print("Start building MRNN.")
    engine = MRNNEngine(cfg_rats['FREQ_DIM'], cfg_rats['HIDDEN_DIM'],
                        cfg_rats['BPTT_TRUNCATE'], param_path=start_with)
    print("MRNN successfully built.")

    epoch_size = cfg_rats['NEPOCH']
    minibatch_size = cfg_rats['MINIBATCH_SIZE']  # audio files per minibatch

    for epoch in xrange(epoch_size):
        print("Epoch %d: Loading data from %s ..." % (epoch, cfg_rats['TRAIN_DIR']))
        # A new transformer per epoch gives a different random file order.
        RATS_A = AudioTransformer(cfg_rats['TRAIN_DIR'],
                                  cfg_rats['SAMPLE_RATE'], mono=True,
                                  random_order=True, verbose=False)
        ll = []  # per-minibatch losses for this epoch (was never initialized)
        for batch in grouper(RATS_A, minibatch_size):
            TRAIN_M, TRAIN_M_noisy = compose_rats_features(batch)
            assert TRAIN_M.shape == TRAIN_M_noisy.shape
            train_loss = engine.train_with_sgd(TRAIN_M_noisy, TRAIN_M,
                                               learning_rate=cfg_rats['LEARNING_RATE'],
                                               decay=cfg_rats['RMSPROP_DECAY'])
            ll.append(train_loss)
            dt = datetime.now().isoformat()
            print("\n%s (Epoch:%d, Examples:%d)" % (dt, epoch, engine.num_seen))
            print("--------------------------------------------------")
            print("Loss: %f" % train_loss)
            sys.stdout.flush()
        if ll:
            print("Epoch %d mean loss: %f" % (epoch, np.mean(ll)))
        # Save parameters after one entire pass of all batches.
        ts = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        MODEL_OUTPUT_FILE = "MRNN-%s-%s-%s.dat" % (ts, engine.feat_dim, engine.hidden_dim)
        print("Saving parameters to %s..." % MODEL_OUTPUT_FILE)
        engine.save_params(os.path.join(outdir, MODEL_OUTPUT_FILE))

def train_rats_2015(outdir, start_with=None, wiener=False, audio_batch_size=10):
    """
    Train an MRNN on LDC2015S02 (RATS) data.

    Audio from cfg_rats['RATS_A'], ['RATS_D'] and ['RATS_H'] is visited in a
    fresh random order each epoch, `audio_batch_size` files at a time.
    Training minibatches are drawn from those files with SampleGenerator.
    After each epoch the per-minibatch losses are logged and the parameters
    are saved to `outdir`.

    Args:
        outdir             - output directory for the logger and parameters
        [start_with]       - parameter path forwarded to MRNNEngine
        [wiener]           - forwarded to MRNNEngine; when true the phase
                             features are also passed to train_with_sgd as
                             `P_train`
        [audio_batch_size] - number of audio files grouped per step
    """
    logger = MRNNLogger(outdir)
    print("Logger created.")

    print("Start building MRNN.")
    engine = MRNNEngine(cfg_rats['FREQ_DIM'], cfg_rats['HIDDEN_DIM'],
                        cfg_rats['BPTT_TRUNCATE'], param_path=start_with,
                        wiener=wiener)
    print("MRNN successfully built.")

    epoch_size = cfg_rats['NEPOCH']
    example_size = cfg_rats['EXAMPLE_SIZE']
    type_ratio = cfg_rats['TYPE_RATIO']
    train_dirs = [cfg_rats['RATS_A'], cfg_rats['RATS_D'], cfg_rats['RATS_H']]
    sgd_opts = dict(learning_rate=cfg_rats['LEARNING_RATE'],
                    decay=cfg_rats['RMSPROP_DECAY'])

    for epoch in xrange(epoch_size):
        # A new transformer per epoch gives a different random file order.
        TRAIN = AudioTransformer(train_dirs, cfg_rats['SAMPLE_RATE'],
                                 mono=True, random_order=True, verbose=True)
        ll = []  # per-minibatch losses for this epoch
        for audio_batch in grouper(TRAIN, audio_batch_size):
            xSGs = [SampleGenerator(x, xpath, example_size,
                                    type_ratio=type_ratio, maxbatch=1)
                    for x, xpath in audio_batch]
            # izip takes one item from every generator per step and stops at
            # the shortest one.
            for data_batch in itertools.izip(*xSGs):
                MAG_Y, MAG_X, PHZ_X = map(np.concatenate, zip(*data_batch))
                assert MAG_Y.shape == MAG_X.shape == PHZ_X.shape
                kwargs = dict(sgd_opts)
                if wiener:
                    kwargs['P_train'] = PHZ_X
                train_loss = engine.train_with_sgd(MAG_X, MAG_Y, **kwargs)
                ll.append(train_loss)
                dt = datetime.now().isoformat()
                print("\n%s (Epoch:%d, Examples:%d)" % (dt, epoch, engine.num_seen))
                print("--------------------------------------------------")
                print("Loss: %f" % train_loss)
                sys.stdout.flush()

        # Save parameters after one entire pass of all batches.
        ts = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        logger.write_epoch_loss(ll, engine.num_seen, ts)
        MODEL_OUTPUT_FILE = "MRNN-%s-%s-%s-%s.dat" % (
            ts, engine.feat_dim, engine.hidden_dim, engine.num_seen)
        print("Saving parameters to %s..." % MODEL_OUTPUT_FILE)
        engine.save_params(os.path.join(outdir, MODEL_OUTPUT_FILE))

def vad_speech_percentage(vad):
    """
    Return the percentage (0-100) of entries in `vad` that are marked speech.

    Float arithmetic is used on purpose: with integer or boolean labels,
    `np.sum(vad) / vad.size` is floor division under Python 2 and yields 0
    for any batch that is not entirely speech.
    """
    vad = np.asarray(vad)
    return 100.0 * np.sum(vad) / vad.size


def train_asnr_rats_2015(outdir,start_with=None,mode='cpu'):
    """
    Train the a priori SNR estimation network using LDC2015S02 data
    Args:
        outdir       - output network parameters directory
        [start_with] - specifies pre-trained network parameters directory
        [mode]       - 'cpu' or 'gpu'; forwarded to MRNNEngine
    """
    logger = MRNNLogger(outdir)
    print("Logger created.")

    print("Start building MRNN.")
    engine = MRNNEngine(param_path=start_with,
                        learning_rate=CONFIG_MRNN['LEARNING_RATE'],
                        momentum=CONFIG_MRNN['MOMENTUM'],
                        mode=mode)
    print("MRNN successfully built.")

    epoch_size = cfg_rats['NEPOCH']
    minibatch_size = cfg_rats['MINIBATCH_SIZE']  # audio files grouped per step
    example_size = cfg_rats['EXAMPLE_SIZE']
    type_ratio = cfg_rats['TYPE_RATIO']
    train_dirs = [cfg_rats['RATS_A'], cfg_rats['RATS_D'], cfg_rats['RATS_H']]

    for epoch in xrange(epoch_size):
        # A new transformer per epoch gives a different random file order.
        TRAIN = AudioTransformer(train_dirs, cfg_rats['SAMPLE_RATE'],
                                 mono=True, random_order=True, verbose=True)
        ll = []  # (E_recon, E_vad, E_tot, vad_err) per minibatch this epoch
        for audio_batch in grouper(TRAIN, minibatch_size):
            xSGs = [SampleGenerator(x, xpath, example_size,
                                    type_ratio=type_ratio, maxbatch=1)
                    for x, xpath in audio_batch]
            # izip takes one item from every generator per step and stops at
            # the shortest one.
            for data_batch in itertools.izip(*xSGs):
                MAG_CLEAN, MAG_NOISY, VAD_CLEAN = map(np.concatenate, zip(*data_batch))
                assert MAG_CLEAN.shape == MAG_NOISY.shape
                # one VAD label per (example, frame) of the spectrograms
                assert VAD_CLEAN.shape == MAG_CLEAN.shape[:2]
                speech_pct = vad_speech_percentage(VAD_CLEAN)
                E_recon, E_vad, E_tot, vad_err = engine.train_with_sgd(
                    MAG_NOISY, MAG_CLEAN, VAD_CLEAN)
                ll.append((E_recon, E_vad, E_tot, vad_err))
                dt = datetime.now().isoformat()
                print("\n{} (Epoch:[{}], Examples:[{}], Speech percentage:[{}%])".format(dt,epoch,engine.num_seen,speech_pct))
                print("--------------------------------------------------")
                print("MSE Loss:[{}]; VAD Loss:[{}]; Total Loss:[{}]; VAD Error Percentage (threshold=0.15):[{}%]".format(E_recon,E_vad,E_tot,vad_err))
                sys.stdout.flush()

        # Save parameters after one entire pass of all batches.
        ts = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        logger.write_epoch_loss(ll, engine.num_seen, ts)
        MODEL_OUTPUT_FILE = "MRNN-ASNR-%s-%s-%s.dat" % (ts, engine.model.freq_dim, engine.num_seen)
        print("Saving parameters to %s..." % MODEL_OUTPUT_FILE)
        engine.save_params(os.path.join(outdir, MODEL_OUTPUT_FILE))


def grouper(iterable,n):
    """
    Lazily split `iterable` into consecutive tuples of `n` items.

    The last tuple may be shorter than `n`. Iteration ends at the first empty
    chunk, so an empty input (or n == 0) yields nothing.
    """
    source = iter(iterable)
    block = tuple(itertools.islice(source, n))
    while block:
        yield block
        block = tuple(itertools.islice(source, n))


if __name__ == '__main__':
    # Command-line entry point: trains the a priori SNR network. The other
    # training routines in this file can be swapped in (see the list below).
    parser = argparse.ArgumentParser()
    parser.add_argument('-m', required=False,
                        help='Specify initial model path')
    parser.add_argument('-o', required=True,
                        help='Specify model saving path')
    parser.add_argument('-gpu', action='store_true', default=False,
                        required=False, help='Enable CUDA.')
    args = parser.parse_args()

    # A missing (or empty) -m means training starts from scratch.
    start_with = args.m or None
    mode = 'gpu' if args.gpu else 'cpu'
    train_asnr_rats_2015(args.o, start_with=start_with, mode=mode)

    # Alternative entry points:
    #train_white_noise(start_with)
    #train_rats_sre04(args.o,start_with=start_with)
    #train_rats_2015(args.o,start_with=start_with,wiener=cfg_rats['WIENER'])
    #train_additive_noise(args.o,start_with=start_with)
