# Wrapper for VTSpkg by Pedro Moreno
# Add utilities for raw audio input and output
# Author: Raymond Xia (yangyanx@andrew.cmu.edu)
from AudioTransformer import AudioTransformer
from spectrogram import *
import os
import subprocess
from audio_io import *

# Global configuration parameters
VTS_CONFIG = {
    # spectrogram parameters
    'SAMPLE_RATE': 16000,
    'WINDOW_LENGTH': 0.032,
    'FFT_SIZE': 512,
    'HOP_FRACTION': 0.5, # 0.25 too small?
    # Log mel-spectrogram parameters
    'NFILTS': 40,
    'MINFRQ': 0.,
    'MAXFRQ': 8000.,
    # VTS parameters
    'GMM_PATH': 'estimate_mixgau/linux/estim_mixgau',#path to `estim_mixgau` binary
    'VTS_PATH': 'vts_norm/linux/VTS', # path to `VTS` binary
    'MIXTURE_PATH': 'gmm/mixture.gmm',  # mixure saving path
    'FEATURE_PATH': 'gmm/feats/',# feature saving path
    'MIXTURE_MODE': 32,
    'INVERT_SPECTROGRAM': 'mask' # one of None, 'direct', or 'mask'
                                 # None - do nothing and leave logspec in clean/
                                 # 'direct' - invert using Dan Ellis' method
                                 # 'mask' - estimate T-F mask and apply to
                                 #          original complex spectrogram
}

class VTS(object):
    """
    Vector Taylor Series (VTS) compensation.

    Wraps the external `estim_mixgau` and `VTS` binaries (paths in
    VTS_CONFIG) with raw-audio feature extraction and reconstruction of
    audio from compensated log mel-spectrograms.

    Args:
        indir   - input audio file/directory
        outdir  - output directory. In mixture-estimation mode the GMM is
                  saved to outdir/VTS_CONFIG['MIXTURE_PATH']; in compensation
                  mode intermediate logspecs, logs and compensated audio are
                  written here (or next to it, if outdir is an existing file).
        [model] - path to clean feature GMM. default=None
    """
    def __init__(self,indir,outdir,model=None):
        self.indir = indir
        self.outdir = outdir
        self.model = model

    # ------------------------------------------------------------------ helpers
    @staticmethod
    def _stem(path):
        """Return the file name of `path` without its directory and final extension."""
        # splitext (not split('.')) so 'a.b.wav' -> 'a.b' and does not collide with 'a.c.wav'
        return os.path.splitext(os.path.basename(path))[0]

    @staticmethod
    def _write_control_file(ctr,names):
        """Write one entry per line to control file `ctr` (newline-terminated)."""
        with open(ctr,'w') as fp:
            fp.write('\n'.join(names))
            fp.write('\n')

    @staticmethod
    def _run_logged(cmd,log_path,err_path):
        """
        Run `cmd` (an argument list, no shell) with stdout/stderr redirected
        to `log_path`/`err_path`. Prints a warning on non-zero exit status.
        Returns the exit status.
        """
        print(' '.join(cmd))
        # Argument list + no shell: paths with spaces/metacharacters are passed verbatim.
        with open(log_path,'w') as fp, open(err_path,'w') as fep:
            ret = subprocess.call(cmd,stdout=fp,stderr=fep)
        if ret != 0:
            print("Command exited with status {}; see [{}]".format(ret,err_path))
        return ret

    def _load_dataset(self):
        """Create an AudioTransformer over indir (mono, at VTS_CONFIG['SAMPLE_RATE'])."""
        return AudioTransformer(self.indir,
                                VTS_CONFIG['SAMPLE_RATE'],
                                mono=True,
                                verbose=True)

    def _stft(self,x):
        """Complex STFT of raw audio `x` with the configured analysis parameters."""
        _,__,X = stft(x,VTS_CONFIG['SAMPLE_RATE'],
                      window_length=VTS_CONFIG['WINDOW_LENGTH'],
                      hop_fraction=VTS_CONFIG['HOP_FRACTION'],
                      nfft=VTS_CONFIG['FFT_SIZE'])
        return X

    # ------------------------------------------------------- features / inversion
    def extract_feature(self,x):
        """
        Extract features from raw audio x for VTS: the log of the mel
        (auditory) spectrogram computed from the STFT power spectrum.
        """
        X = self._stft(x)
        X,_ = audspec(magphase(X)[0]**2,
                      nfft=VTS_CONFIG['FFT_SIZE'],
                      sr=VTS_CONFIG['SAMPLE_RATE'],
                      nfilts=VTS_CONFIG['NFILTS'],
                      minfrq=VTS_CONFIG['MINFRQ'],
                      maxfrq=VTS_CONFIG['MAXFRQ'])
        return logspec(X)

    def invert_feature(self,X,X_orig,mask=True):
        """
        Invert a (compensated) log mel-spectrogram to raw audio.

        Args:
            X      - log mel-spectrogram, as produced by extract_feature/VTS
            X_orig - complex STFT of the original (noisy) audio
            [mask] - True: estimate a T-F mask from X and apply it to X_orig.
                     False: invert the mel spectrogram directly (Ellis' method)
                     and reuse the phase of X_orig.
        Returns:
            time-domain signal from istft
        """
        if not mask: # invert using Ellis' method
            mel_power = np.exp(X) # undo the log; assumes logspec() uses natural log -- TODO confirm
            phase = magphase(X_orig)[1] # original (noisy) phase
            power = invaudspec(mel_power, nfft=VTS_CONFIG['FFT_SIZE'],
                sr=VTS_CONFIG['SAMPLE_RATE'], nfilts=VTS_CONFIG['NFILTS'],
                minfrq=VTS_CONFIG['MINFRQ'], maxfrq=VTS_CONFIG['MAXFRQ'])
            # extract_feature fed |X|**2 into audspec, so sqrt gives back a magnitude
            x = istft(np.sqrt(power)*phase,VTS_CONFIG['SAMPLE_RATE'],
                      window_length=VTS_CONFIG['WINDOW_LENGTH'],
                      hop_fraction=VTS_CONFIG['HOP_FRACTION'])
        else: # estimate mask and apply to original complex spectrogram
            # compute original hz-mel weighting matrix first
            weights,_ = fft2mel(VTS_CONFIG['FFT_SIZE'],
                                sr=VTS_CONFIG['SAMPLE_RATE'],
                                nfilts=VTS_CONFIG['NFILTS'],
                                minfrq=VTS_CONFIG['MINFRQ'],
                                maxfrq=VTS_CONFIG['MAXFRQ'])
            # keep as many filterbank columns as X_orig has along axis 1
            # (presumably the STFT bins -- TODO confirm X_orig layout)
            tf_mask = invaudspec_mask(X,weights[:,:X_orig.shape[1]])
            x = istft(X_orig * tf_mask,VTS_CONFIG['SAMPLE_RATE'],
                      window_length=VTS_CONFIG['WINDOW_LENGTH'],
                      hop_fraction=VTS_CONFIG['HOP_FRACTION'])
        return x

    # ------------------------------------------------------------- entry points
    def estimate_mixture(self,mode=None):
        """
        Estimate a GMM from (clean) data in indir.
        This is equivalent to
        `estim_mixgau -c vts_train.fileid -e logspec -o mixture.gmm -m <mode>`
        except it uses my own log spectrogram.
        On return self.model points at the mixture file.
        Args:
            [mode] - number of Gaussians. default=VTS_CONFIG['MIXTURE_MODE'].
        """
        if mode is None:
            mode = VTS_CONFIG['MIXTURE_MODE']

        # Create directories
        mixture_path = os.path.abspath(
            os.path.join(self.outdir,VTS_CONFIG['MIXTURE_PATH']))
        feat_path = os.path.abspath(
            os.path.join(self.outdir,VTS_CONFIG['FEATURE_PATH']))
        for d in (os.path.dirname(mixture_path),feat_path):
            if not os.path.exists(d): os.makedirs(d)

        # Load data from indir and extract log mel-spectrogram
        outpaths = []
        for x,p in self._load_dataset():
            outpath = os.path.join(feat_path,self._stem(p))
            outpaths.append(outpath)
            write_sphinx(outpath+'.logspec',self.extract_feature(x))

        # Write paths to control file
        ctr = os.path.join(feat_path,'mixture.fileids') # control file name
        print("Writing control file [{}]".format(ctr))
        self._write_control_file(ctr,outpaths)

        # Finally estimate GMMs from VTS package
        binpath = os.path.abspath(VTS_CONFIG['GMM_PATH'])
        outdir = os.path.abspath(self.outdir)
        cmd = [binpath,'-c',ctr,'-e','logspec','-o',mixture_path,'-m',str(mode)]
        self._run_logged(cmd,
                         os.path.join(outdir,'GMM{}.log'.format(mode)),
                         os.path.join(outdir,'GMM{}.err'.format(mode)))
        self.model = mixture_path

    def vts_compensate(self,invert_spec=None):
        """
        Run VTS compensation on every file in indir using the GMM in self.model.

        Noisy logspecs go to <workdir>/noisy_logspec, compensated ones to
        <workdir>/clean_logspec, where <workdir> is outdir (or its parent
        directory if outdir is an existing file). Logs VTS.log/VTS.err and any
        reconstructed audio (<name>.wav) are written to <workdir> as well.

        Args:
            [invert_spec] - None (keep logspec only), 'direct' or 'mask';
                            see invert_feature.
        Raises:
            ValueError - if no model is set or invert_spec is not recognised.
        """
        if self.model is None: raise ValueError('Model path not specified!')
        # Validate before doing any (expensive) work.
        if invert_spec not in (None,'direct','mask'):
            raise ValueError("invert_spec has to be [direct]/[mask]")

        workdir = os.path.dirname(self.outdir) if os.path.isfile(self.outdir) \
            else self.outdir
        workdir = os.path.abspath(workdir)
        npath = os.path.join(workdir,'noisy_logspec') # temporary noisy logspec
        cpath = os.path.join(workdir,'clean_logspec') # compensated logspec
        for d in (npath,cpath):
            if not os.path.exists(d): os.makedirs(d)

        # Load data from indir and extract log mel-spectrogram
        names = []
        for x,p in self._load_dataset():
            fname = self._stem(p)
            names.append(fname)
            write_sphinx(os.path.join(npath,fname)+'.logspec',
                         self.extract_feature(x))

        # Write file names (relative to npath, see -i below) to control file
        ctr = os.path.join(npath,'noisy.fileids') # control file name
        self._write_control_file(ctr,names)

        # Run VTS compensation
        binpath = os.path.abspath(VTS_CONFIG['VTS_PATH'])
        cmd = [binpath,'-c',ctr,'-i',npath,'-o',cpath,'-x','logspec',
               '-y','logspec','-d',self.model]
        self._run_logged(cmd,
                         os.path.join(workdir,'VTS.log'),
                         os.path.join(workdir,'VTS.err'))

        if invert_spec is None: return # logspec is the final output

        # Then read in all cleaned spectrograms and invert back to audio
        for x,p in self._load_dataset(): # iterate again for the original STFT
            fname = self._stem(p)
            cleaned_path = os.path.join(cpath,fname+'.logspec')
            if not os.path.exists(cleaned_path):
                print("File [{}] cannot be compensated.".format(fname))
                continue
            X_logspec = read_sphinx(cleaned_path,VTS_CONFIG['NFILTS'])
            X_orig = self._stft(x)
            x_clean = self.invert_feature(X_logspec,X_orig,
                                          mask=(invert_spec == 'mask'))
            # finally write wave to file
            outpath = os.path.join(workdir,fname+'.wav')
            audiowrite(x_clean,VTS_CONFIG['SAMPLE_RATE'],outpath,verbose=True)


if __name__ == '__main__':
    # Use as standalone application. Two modes:
    #   learn clean GMM : -i <clean audio> -o <dir>            (no -m)
    #   compensate      : -i <noisy audio> -o <dir> -m <gmm>
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('-i',help='Input audio file/directory', required=True)
    parser.add_argument('-o',help='Output directory. Learning distribution mode: the GMM is saved to <o>/{}; Compensation mode: compensated audio is written here'.format(VTS_CONFIG['MIXTURE_PATH']),required=True)
    parser.add_argument('-m',help='Clean GMM mixture path [Enable for Compensation; Disable for learning distribution]',required=False)
    args = parser.parse_args()
    model = args.m or None # an empty -m is treated like omitting it
    vts = VTS(args.i,args.o,model=model)
    if model is None: # estimate mixture mode
        vts.estimate_mixture(VTS_CONFIG['MIXTURE_MODE'])
        print("Mixture saved to [{}]".format(vts.model))
    else: # compensation mode
        vts.vts_compensate(invert_spec=VTS_CONFIG['INVERT_SPECTROGRAM'])
