# Some utility functions for sampling data in the RATS LDC2015S02 dataset
#
# Author: Raymond Xia (yangyanx@andrew.cmu.edu)
import numpy as np
import os
from audio_io import audioread
from config import *
from copy import copy
import itertools
import theano
from spectrogram import *
import subprocess
import random
import numpy.random as rand

# Settings for the RATS experiments, taken from the project-wide config.
cfg_rats = config['rats']
# Float dtype configured in theano (float32 when set up for GPU processing).
floatX = theano.config.floatX


# Spectrogram helpers bound to the current RATS settings.
def stft_tmp(x):
    """
    Run `stft` on signal `x` with the window, hop and FFT size from
    `cfg_rats`. The settings are looked up on every call, so later changes to
    `cfg_rats` take effect. Returns whatever `stft` returns.
    """
    sample_rate = cfg_rats['SAMPLE_RATE']
    window_length = cfg_rats['WINDOW_LENGTH']
    # The first window starts half a window before the signal.
    first_window = int(-(window_length*sample_rate)/2)
    return stft(x, sample_rate,
                w_start=first_window,
                window_length=window_length,
                hop_fraction=cfg_rats['HOP_FRACTION'],
                nfft=cfg_rats['FFT_SIZE'],
                stft_truncate=True,
                tf_unit='discrete')


def magspec_tmp(X):
    """
    Return the first output of `magphase(X)` (presumably the magnitude part;
    confirm against `spectrogram.magphase`) cast to `floatX`.
    """
    magnitude = magphase(X)[0]
    return magnitude.astype(floatX)

# An iterator grouper for chunk processing
def grouper(iterable,n):
    """
    Yield successive tuples of up to `n` items taken from `iterable`.

    The last tuple may be shorter than `n`. Nothing is yielded for an empty
    iterable, or when `n` is 0.
    """
    source = iter(iterable)
    take_chunk = lambda: tuple(itertools.islice(source, n))
    # iter(callable, sentinel) stops as soon as an empty chunk comes back.
    for piece in iter(take_chunk, ()):
        yield piece



class SampleGenerator(object):
    """
    Iterable that yields batches of features cut from one RATS LDC2015S02
    recording and from its clean source recording.

    Every iteration draws a fresh set of random segment start indices and
    yields the tuple returned by `compose_rats_features` for them. At most
    `maxbatch` batches are produced. When `type_ratio` is given, each batch
    holds a fixed number of segments starting in non-speech (pure noise) and
    in speech (noisy speech); otherwise start indices are drawn uniformly
    without looking at their type.
    """
    def __init__(self,x,xpath,batch_size,type_ratio=None,maxbatch=10,verbose=False):
        """
        x  - noisy audio as numpy array
        xpath - path to x. Its basename selects the source/SAD directories
            (channel letter A/H or D) and its first 5 characters are the file
            id used to locate the clean recording (.flac) and SAD (.tab) file
        batch_size - number of examples per yield when `type_ratio` is given
        type_ratio - optional sequence of fractions summing to 1. Entry 0 is
            the fraction of `batch_size` starting in non-speech (pure noise),
            entry 1 the fraction starting in speech (noisy speech). Further
            entries only take part in the sum check. If None, each batch has
            cfg_rats['EXAMPLE_SIZE'] uniformly drawn segments instead.
        maxbatch - maximum number of batches to yield
        verbose - print the name of the file being processed

        Raises ValueError if the file name has none of the A/D/H channel
        letters, and IOError if `find` locates no clean recording or SAD file.
        """
        if type_ratio is not None:
            # Tolerate floating-point round-off, e.g. for (0.1, 0.2, 0.7).
            assert np.isclose(np.sum(type_ratio), 1.)
        if verbose: print("Process [{}].".format(os.path.basename(xpath)))
        self.x = x
        self.xpath = xpath
        self.batch_size = batch_size # number of examples drawn per batch
        self.batch_togo = maxbatch # maximum number of batches to yield

        if type_ratio is not None:
            self.noise_size = int(self.batch_size*type_ratio[0])
            self.nspeech_size = int(self.batch_size*type_ratio[1])
        else:
            self.noise_size = None
            self.nspeech_size = None

        # Get clean speech from noisy speech path
        fname = os.path.basename(xpath)
        fid = fname[:5]
        if ('A' in fname) or ('H' in fname): # training data from dev-1
            srcdir = cfg_rats['SRC_AUD']
            saddir = cfg_rats['SRC_SAD']
        elif ('D' in fname): # channel D training data comes from dev-2
            srcdir = cfg_rats['SRC_AUD2']
            saddir = cfg_rats['SRC_SAD2']
        else:
            raise ValueError('File has to be in one of A/D/H channels.')
        self.rpath = self._find_first(srcdir,fid+'*.flac')
        self.r,_ = audioread(self.rpath,cfg_rats['SAMPLE_RATE'])

        # Now read its corresponding SAD file
        self.speech = [] # stores tuples of sample index [start,end)
        self.nonspeech = []
        self.seglength = int(cfg_rats['EXAMPLE_LENGTH']*cfg_rats['SAMPLE_RATE'])
        self.maxlength = min(len(self.x),len(self.r))
        assert self.seglength <= self.maxlength
        spath = self._find_first(saddir,fid+'*.tab')
        with open(spath,'r') as fp:
            for line in fp:
                fields = line.split()
                if not fields: # tolerate blank lines
                    continue
                # Fields 2 and 3 are start/end times, field 4 is the label.
                stype,tstart,tend = fields[4],float(fields[2]),float(fields[3])
                # store sample index rather than time
                tstart = int(tstart*cfg_rats['SAMPLE_RATE'])
                tend = int(tend*cfg_rats['SAMPLE_RATE'])
                if stype == 'S': # speech
                    self.speech.append((tstart,tend))
                elif stype == 'NS':
                    self.nonspeech.append((tstart,tend))
                else:
                    print('Incorrect label in {}! Ignoring time points.'.format(self.rpath))
        # All set for iteration

    @staticmethod
    def _find_first(root,pattern):
        """
        Return the first path under `root` whose name matches `pattern`, as
        reported by the `find` command. Raises IOError if nothing matches.
        """
        out = subprocess.check_output(['find',root,'-name',pattern])
        if not isinstance(out,str): # check_output gives bytes on Python 3
            out = out.decode()
        matches = [path for path in out.splitlines() if path.strip()]
        if not matches:
            raise IOError('No file matching {} found under {}.'.format(pattern,root))
        return matches[0]

    def is_speech(self,idx):
        """
        Given a sample index, check if it belongs to a speech segment or not.
        Does not depend on the order of the segments in `self.speech`.
        """
        return any(n_start <= idx < n_end for n_start,n_end in self.speech)

    # helper for feature extraction
    def compose_rats_features(self,idxl):
        """
        Compose noisy/clean features for a list of segment starting indices.

        Returns (MAG_CLEAN, MAG_NOISY, VAD_CLEAN): the `magspec_tmp` output
        for each clean and each noisy segment, and a floatX array of per-frame
        labels that is 1 where the frame index falls in a speech segment and
        0 elsewhere. `idxl` must not be empty.
        """
        starts = [int(i) for i in idxl]
        seglen = self.seglength

        # learning noisy -> clean: same sample range from both recordings
        clean_out = [stft_tmp(self.r[s:s+seglen]) for s in starts]
        TIME_CLEAN,_,STFT_CLEAN = zip(*clean_out)
        MAG_CLEAN = np.array([magspec_tmp(S) for S in STFT_CLEAN])
        MAG_NOISY = np.array(
            [magspec_tmp(stft_tmp(self.x[s:s+seglen])[-1]) for s in starts])

        # Get true VAD labels for each time frame. The first output of
        # stft_tmp is treated as per-frame sample offsets within the segment
        # (assumed from tf_unit='discrete' -- TODO confirm), so adding the
        # segment start gives indices into the whole recording.
        idx_offset = np.array(idxl,dtype='int')
        IDX_CLEAN = np.array(TIME_CLEAN) + idx_offset[:,np.newaxis]
        VAD_CLEAN = np.zeros_like(IDX_CLEAN,dtype=floatX)
        # One vectorised mask per speech segment instead of a Python-level
        # scan of every segment for every frame.
        for n_start,n_end in self.speech:
            VAD_CLEAN[(IDX_CLEAN >= n_start) & (IDX_CLEAN < n_end)] = 1.
        return MAG_CLEAN,MAG_NOISY,VAD_CLEAN

    def __iter__(self):
        """
        Yield up to `maxbatch` feature tuples (see `compose_rats_features`).

        Raises RuntimeError if speech segments are requested through
        `type_ratio` but the SAD file listed none.
        """
        start_idx_max = self.maxlength-self.seglength
        mixed = (self.noise_size is not None) and (self.nspeech_size is not None)
        for _ in range(self.batch_togo):
            if not mixed:
                # Generate randomly without considering proportion
                # of the type of the starting audio frame (speech/nonspeech)
                start_idx = np.around(rand.rand(
                                cfg_rats['EXAMPLE_SIZE']) * start_idx_max)
            else:
                # Need to mix with specified proportion
                if self.nspeech_size > 0 and not self.speech:
                    # The rejection loop below could never finish. This only
                    # covers the no-speech-at-all case.
                    raise RuntimeError(
                        'No speech segment available in {}.'.format(self.xpath))
                no_idx = [] # noise starting index
                ns_idx = [] # noisy speech starting index
                while (len(no_idx) < self.noise_size) or (len(ns_idx) < self.nspeech_size):
                    candidates = np.around(rand.rand(
                                    cfg_rats['EXAMPLE_SIZE']) * start_idx_max)
                    # Classify each candidate once, then route it.
                    flags = [self.is_speech(c) for c in candidates]
                    if len(ns_idx) < self.nspeech_size:
                        ns_idx.extend(c for c,f in zip(candidates,flags) if f)
                    if len(no_idx) < self.noise_size:
                        no_idx.extend(c for c,f in zip(candidates,flags) if not f)
                start_idx = no_idx[:self.noise_size]+ns_idx[:self.nspeech_size]
                # randomize order
                start_idx = sorted(start_idx,key=lambda k: random.random())

            yield self.compose_rats_features(start_idx)

class MRNNLogger(object):
    """
    Logging utility for MRNN. Writes the training meta-parameters from
    `cfg_rats` as the header of a fresh `mrnn.log`, then appends the loss
    records of each epoch on request.
    """
    def __init__(self,outdir):
        """
        Create (or overwrite) `mrnn.log` under `outdir` and write the header.

        outdir - directory for the log file; created if it does not exist
        """
        self.filename = 'mrnn.log'
        self.filepath = os.path.abspath(os.path.join(outdir,self.filename))
        self.epoch_count = 0 # counter for epoch
        logdir = os.path.dirname(self.filepath)
        if not os.path.exists(logdir):
            os.makedirs(logdir)
        time_dim = int(cfg_rats['EXAMPLE_LENGTH']/cfg_rats['HOP_FRACTION']/
                       cfg_rats['WINDOW_LENGTH'])
        # Create fresh log file and write headers
        with open(self.filepath,'w') as fp:
            fp.write("MRNN training parameters:\n")
            fp.write("\tTime dimension:{}\n".format(time_dim))
            fp.write("\tFrequency dimension:{}\n".format(cfg_rats['FREQ_DIM']))
            fp.write("\tHidden dimension:{}\n".format(cfg_rats['HIDDEN_DIM']))
            fp.write("\tLearning rate:{}\n".format(cfg_rats['LEARNING_RATE']))
            fp.write("\tRMSProp:{}\n".format(cfg_rats['RMSPROP_DECAY']))
            fp.write("\tBPTT:{}\n".format(cfg_rats['BPTT_TRUNCATE']))
            fp.write("\tRandom samples/audio:{}\n".format(cfg_rats['EXAMPLE_SIZE']))
            fp.write("\tMinibatch size:{}\n".format(cfg_rats['MINIBATCH_SIZE']))
            # Was "\Speech...": an invalid escape that wrote a literal
            # backslash where the tab was meant.
            fp.write("\tSpeech/nonspeech proportion:{}\n".format(cfg_rats['TYPE_RATIO']))
            fp.write("\tMSE/VAD cost weighting factor:{}\n".format(cfg_rats['BETA']))
            fp.write("END OF specifications. Loss record starts next line.\n")

    def write_epoch_loss(self,ll,numseen,timestamp):
        """
        Append the loss records of a single epoch to the log file.

        ll - iterable of (E_recon, E_vad, E_tot, vad_err) records
        numseen - count written into the epoch header as "[<numseen>-seen]"
        timestamp - value written at the start of the epoch header

        Increments the epoch counter, writes the epoch header followed by the
        column names on the same line, then one tab-separated line per record.
        """
        self.epoch_count += 1
        with open(self.filepath,'a+') as fp:
            fp.write("{}-Epoch{}[{}-seen]:".format(timestamp,self.epoch_count,numseen))
            fp.write("E_recon\tE_vad\tE_tot\tVAD_err%\n")
            for E_recon,E_vad,E_tot,vad_err in ll:
                record = (E_recon,E_vad,E_tot,vad_err)
                fp.write('\t'.join(str(v) for v in record)+'\n')
