from config import *
from util import *
from spectrogram import *
#import theano
from batch_io import AudioTransformer
import argparse
from train import *
from wiener import *

#theano.config.compute_test_value = 'warn'
#theano.config.optimizer='None'
# Per-dataset settings used by the denoisers below. `config` is brought into
# scope by one of the star imports above (presumably `from config import *`;
# confirm). Keys read from it in this file: SAMPLE_RATE, WINDOW_LENGTH,
# HOP_FRACTION, FFT_SIZE, ZERO_PHASE_STFT.
cfg_rats = config['rats']
#floatX = theano.config.floatX
def mrnn_denoise(engine, noisy_path, out_path=None):
    """
    Denoise all audio data in `noisy_path` with `engine` and save the
    denoised audio in `out_path`.

    For every signal `x` yielded by the dataset, the signal is STFT'd, split
    with `magphase` into a magnitude part `M` and a phase part `P`, and `M` is
    replaced by `engine.denoise(M)`. The product `M_de * P` is then inverted
    with `istft` back to a time-domain signal of length `len(x)`.

    Parameters
    ----------
    engine : object
        Object exposing `denoise(M)`. The command-line entry point passes an
        `MRNNEngine`; presumably `denoise` must return an array that can be
        multiplied elementwise with `P` (confirm).
    noisy_path : str
        Input path (file or directory) handed to `AudioTransformer`.
    out_path : str or None
        Destination handed to `AudioTransformer.write_to_file`. When None,
        the signals are still processed but nothing is written.
    """
    # These settings are the same for every file, so read them once.
    sample_rate = cfg_rats['SAMPLE_RATE']
    window_length = cfg_rats['WINDOW_LENGTH']
    hop_fraction = cfg_rats['HOP_FRACTION']
    zero_phase = cfg_rats['ZERO_PHASE_STFT']

    noisy_dataset = AudioTransformer(noisy_path, sample_rate, verbose=True)
    for x, _ in noisy_dataset:
        if zero_phase:
            # stft appears to return a sequence whose last element is the
            # transform itself -- confirm against spectrogram.stft.
            X = stft(x, sample_rate,
                     window_length=window_length,
                     hop_fraction=hop_fraction,
                     nfft=cfg_rats['FFT_SIZE'],
                     stft_truncate=True)[-1]
        else:
            # Legacy transform takes the window length in samples.
            X = stft_old(x,
                         int(window_length * sample_rate),
                         hop_fraction,
                         truncate=True)
        M, P = magphase(X)

        M_de = engine.denoise(M)
        # NOTE(review): istft is not passed nfft=FFT_SIZE even though the
        # forward stft is; confirm istft infers it (or defaults consistently).
        xr = istft(M_de * P, sample_rate, len(x),
                   window_length=window_length,
                   hop_fraction=hop_fraction,
                   stft_truncate=True)
        if out_path is not None:
            noisy_dataset.write_to_file(xr, out_path)

def wiener_denoise(noisy_path, out_path=None):
    """
    Perform classic a-priori SNR estimation and Wiener filtering on all audio
    data in `noisy_path`. Save the denoised audio in `out_path`.

    Parameters
    ----------
    noisy_path : str
        Input path (file or directory) handed to `AudioTransformer`.
    out_path : str or None
        Destination handed to `AudioTransformer.write_to_file`. When None,
        each signal is still filtered by `wiener_asnr` but nothing is written.
    """
    noisy_dataset = AudioTransformer(noisy_path, cfg_rats['SAMPLE_RATE'],
                                     verbose=True)
    for x, _ in noisy_dataset:
        xr = wiener_asnr(x)
        if out_path is not None:
            noisy_dataset.write_to_file(xr, out_path)


if __name__ == '__main__':
    # Command-line entry point: denoise the audio under -i and write the
    # result to -o, using classic Wiener filtering when -w is given and the
    # MRNN engine (model from -m, CPU or GPU) otherwise.
    parser = argparse.ArgumentParser()
    parser.add_argument('-m', help='Specify model path', required=False)
    parser.add_argument('-i', help='Input path (could be file/directory)',
                        required=True)
    parser.add_argument('-o', help='Output path', required=True)
    parser.add_argument('-w', help='Classic Wiener filtering.',
                        required=False, action='store_true', default=False)
    parser.add_argument('-gpu', help='Enable CUDA.', required=False,
                        action='store_true', default=False)
    cli = parser.parse_args()

    if cli.w:
        wiener_denoise(cli.i, cli.o)
    else:
        compute_mode = 'gpu' if cli.gpu else 'cpu'
        # NOTE(review): -m is optional at parse time, so cli.m may be None
        # here; confirm MRNNEngine handles a missing model path.
        mrnn_engine = MRNNEngine(param_path=cli.m, mode=compute_mode)
        mrnn_denoise(mrnn_engine, cli.i, out_path=cli.o)
