/*
  title: fbank.c
  purpose: Frontend filter to generate Filter Bank feature sequences for use
    by the fview package.

  authors:  Gareth Lee and Edmund Lai.
  date:     28-10-93
  modified: 

  changes:
  23-06-93: Error in call to preemp_short fixed by passing &bdata[1] rather
  than the pointer itself. cf. The function source in file preemp.c.
  28-10-93: fbank.c based on FRONTEND.C by Edmund Lai and Gareth Lee.
*/
#define VERSION "1.0"

#include <stdio.h>
#include <stdlib.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <string.h>
#include <math.h>
#include <memory.h>
#include <unistd.h>

#include "sfheader.h"
#include "fviewio.h"

/*
  freq2fft: map a frequency x (Hz) onto an FFT bin position, scaling the range
    0..mfreq onto 0..bins-1.  The +0.5 rounds to the nearest bin once the
    result is truncated to int.
    NOTE: the expansion refers to the global `mfreq' and to a variable named
    `bins' that must be in scope wherever the macro is used.
*/
#define	freq2fft(x)	(((x) / mfreq) * (bins-1.0) + 0.5)

/* Mel = 1 - exp(-alpha * freq) */
#define	alpha	        (3.4657e-4)                     /* M=0.5 at f=2000Hz */
/* Mel: warp a frequency in Hz onto the scale above; the result is below 1.0 */
#define	Mel(x)		(1.0 - exp(-alpha*(x)))
/* IMel: inverse of Mel, mapping a warped value back to a frequency in Hz */
#define IMel(x)		(-log(1.0-(x))/alpha)

/*============= CONSTANTS ================*/
/* ON / OFF: values taken by the `preemphasis' flag */
#define  ON	       (1)
#define  OFF	       (0)
/* NITEMS and FBANDS are not referenced by the code visible in this file */
#define	 NITEMS	       (256)
#define  FBANDS        (20)
/* OUTPUT_SCALE: multiplier applied to log10(energy) by applylog() */
#define  OUTPUT_SCALE  (10.0)
/* HDR_SIZE: capacity in bytes of the textual header buffer `header' */
#define  HDR_SIZE      (10240)

/*
  Analysis parameters.  The initial values below are the defaults; most can
  be overridden by parse_cmd_line() (option letter noted in each comment).
*/
int    debug = 0;                                         /* debugging level */
int    nnrg = 12;                                  /* number of filter bands */
int    preemphasis = OFF;                                  /* no preemphasis */
double preemp_factor;         /* set by -p; only read when preemphasis == ON */
int    flength = 256;                                        /* frame length */
int    flength2 = 128;                                 /* always flength / 2 */
int    advs = 0;                          /* frame advance (marked as unset) */
/* -o: used as a band overlap by setup_square() and as a multiplier on the
   standard deviation by setup_gaussian(); setup_triangular() ignores it */
double overlap = 0.0;                    /* fractional overlap between bands */
double start = 60.0;                                  /* LF collection limit */
double stop = 5000.0;                                 /* HF collection limit */

/* allocated in main() by alloc_weights(nnrg, flength2) and filled in by one
   of the setup_XXXXX routines; read by collection() */
double **weight;            /* weightings for use in generating energy bands */
/* assigned in main() as fsh.sample_freq / 2 */
int mfreq;                                     /* half sample frequency (Hz) */

FviewSampleHeader fsh;                        /* input samples header record */
FviewFeatureHeader ffh;                     /* output features header record */
char header[HDR_SIZE];                         /* Textual header information */

/* default magnitude spectrum collection method */
/* -c: selected by the first letter of the option argument (s, t or g) */
enum {square, triangular, gaussian} collector = square;

extern char *optarg;                   /* for use with the getopt() function */

/*****************************************************************************/

/*
  fft: radix 2 FFT routine from "Numerical Recipes in C" performing an n
    point complex FFT but modified to accept source array in C format,
    ie. data = [0..2n-1].
*/
fft(double *data, int n)
{
  int    mmax, m, j,istep, i;
  double wtemp, wr, wpr, wpi, wi, theta;
  double tempr, tempi;

  n *= 2;
  
  /* Bit Reversal routine */
  j = 1;
  for (i = 1; i < n; i += 2 ) {
    if (j > i) {
      tempr = data[j-1];
      data[j-1] = data[i-1];
      data[i-1] = tempr;
      tempi = data[j];
      data[j] = data[i];
      data[i] = tempi;
    }
    m = n >> 1;
    while (m >= 2 && j > m) {
      j -= m;
      m >>= 1;
    }
    j += m;
  }

  mmax = 2;
  while (n > mmax ) {
    istep = 2*mmax;
    theta = 2*M_PI/mmax;
    wtemp = sin(0.5*theta);
    wpr = -2.0*wtemp*wtemp;
    wpi = sin(theta);
    wr = 1.0;
    wi = 0.0;
    for (m = 1; m < mmax; m += 2) {
      for (i = m; i <= n; i += istep) {
	j = i + mmax;
	tempr = wr*data[j-1] - wi*data[j];
	tempi = wr*data[j] + wi*data[j-1];
	data[j-1] = data[i-1] - tempr;
	data[j] = data[i] - tempi;
	data[i-1] += tempr;
	data[i] += tempi;
      }
      wtemp = wr;
      wr = wr*wpr - wi*wpi + wr;
      wi = wi*wpr + wtemp*wpi + wi;
    }
    mmax = istep;
  }
}

/*****************************************************************************/

/*
  complex_power: compute the squared magnitude of each complex value.

  len: number of complex values to process.
  ri:  input, stored as interleaved (real, imaginary) pairs, ie. 2*len doubles.
  pow: output, receives len values with pow[k] = real[k]^2 + imag[k]^2.
*/
void complex_power(int len, double* ri, double* pow)
{
  int k;

  for (k = 0; k < len; k++) {
    const double real_sq = ri[2 * k] * ri[2 * k];
    const double imag_sq = ri[2 * k + 1] * ri[2 * k + 1];

    pow[k] = real_sq + imag_sq;
  }
}

/*****************************************************************************/

/*
  applylog: replace every element of the vector by OUTPUT_SCALE times its
    base 10 logarithm.

  len:    number of elements in the vector.
  energy: vector to convert; it is overwritten in place.
*/
void applylog(int len, double* energy)
{
  double *value = energy;
  int remaining = len;

  while (remaining-- > 0) {
    *value = OUTPUT_SCALE * log10(*value);
    value++;
  }
}

/*****************************************************************************/

/*
  alloc_weights: dynamically allocate weights array such that it is indexed
    weights[0..rows-1][0..cols-1].  Every element starts at 0.0.

  rows, cols: dimensions of the array; both must be positive.

  Returns the array, or NULL when a dimension is not positive or memory runs
    out.  On failure nothing is left allocated.  The caller owns the result
    and releases it by freeing each row and then the row table.
*/
double **alloc_weights(int rows, int cols)
{
  int i;
  double **p;

  if (rows <= 0 || cols <= 0)
    return NULL;

  p = (double **) calloc((size_t) rows, sizeof(double *));
  if (p == NULL)
    return NULL;

  for (i = 0; i < rows; i++) {
    p[i] = (double *) calloc((size_t) cols, sizeof(double));
    if (p[i] == NULL) {
      /* undo the rows obtained so far so that failure does not leak */
      while (i-- > 0)
        free(p[i]);
      free(p);
      return NULL;
    }
  }
  return p;
}

/*****************************************************************************/

/*
  setup_square: construct square collecting function using Mel scale.
*/
void setup_square(int bins)
{
  int i, j;
  double bandwidth, incrband, freq;
  int low_edge, high_edge;

  if (debug > 0)
  {
    fprintf(stderr, "fbank: setup_square: nnrg = %d, start:stop = %f:%f, "
            "overlap = %f\n", nnrg, start, stop, overlap); /**/
    fprintf(stderr, "fbank: setup_square: mfreq = %d, nnrg = %d\n",
            mfreq, nnrg); /**/
  }
  
  bandwidth = (Mel(stop) - Mel(start)) / ((nnrg-1) * (1.0-overlap) + 1);
  incrband = (1.0 - overlap) * bandwidth;
  freq = Mel(start);

  /* foreach filter band */
  for (i = 0; i < nnrg; i++)
  {
    low_edge  = freq2fft(IMel(freq)); 
    high_edge = freq2fft(IMel(freq + bandwidth));

    for (j = 0; j < low_edge; j++)
      weight[i][j] = 0.0;
    for (j = low_edge; j <= high_edge; j++)
      weight[i][j] = 1.0 / (double) (high_edge - low_edge + 1);
    for (j = high_edge + 1; j < bins; j++)
      weight[i][j] = 0.0;

    if (debug > 0)
      fprintf(stderr, "[%d, %d], ", low_edge, high_edge);  /**/
    freq += incrband;
  }
  if (debug > 0)
  {
    fprintf(stderr, "\n"); /**/
    fflush(stderr);
  }
}


/*****************************************************************************/

/*
  setup_gaussian: construct gaussian collecting function using Mel scale.

  bins: number of FFT bins, ie. the number of columns in each weight[] row.

  Band centres are spaced evenly on the Mel scale between start and stop.
  The standard deviation of a band is `overlap' times the distance in Hz
  from its centre to the previous centre, and the weights of each band are
  normalized to sum to one.

  When the weights cannot be normalized (their sum is zero or NaN, which is
  what the default overlap of 0.0 produces) the band collapses onto the bin
  nearest its centre instead of being filled with NaN.
*/
void setup_gaussian(int bins)
{
  int i, j, nearest;
  double incrband, freq, f, df;
  double centre_freq, fdist;
  double scale, sd, pos;

  if (debug > 0)
  {
    /* a literal percent sign must be written as %% in a printf format */
    fprintf(stderr, "fbank: setup_gaussian: nnrg = %d, start:stop = %f:%f, "
            "overlap (sd%%) = %f\n", nnrg, start, stop, overlap);  /**/
    fprintf(stderr, "fbank: setup_gaussian: mfreq = %d, nnrg = %d\n",
            mfreq, nnrg);  /**/
  }

  incrband = (Mel(stop) - Mel(start)) / (nnrg + 1);
  df = (double) mfreq / (double) bins;               /* freq resolution (Hz) */

  /* foreach filter band */
  for (i = 0, freq = (Mel(start) + incrband); i < nnrg; i++, freq += incrband)
  {
    centre_freq = IMel(freq);
    sd = (centre_freq - IMel(freq - incrband)) * overlap;

    /* j = bin index, f = corresponding frequency (Hz) */
    scale = 0.0;
    for (j = 0, f = 0.0; j < bins; j++, f += df) {
      fdist = (f - centre_freq) * (f - centre_freq);
      weight[i][j] = exp(-0.5 * fdist / (sd * sd));
      scale += weight[i][j];
    }

    if (scale > 0.0)
    {
      /* normalize all band weightings to sum to unity */
      for (j = 0; j < bins; j++)
        weight[i][j] /= scale;
    }
    else if (bins > 0)
    {
      /* degenerate band: clamp in floating point (the negated test also
         catches NaN) so that the conversion to int is always in range */
      pos = centre_freq / df + 0.5;
      if (!(pos >= 0.0))
        pos = 0.0;
      if (pos > bins - 1.0)
        pos = bins - 1.0;
      nearest = (int) pos;

      for (j = 0; j < bins; j++)
        weight[i][j] = 0.0;
      weight[i][nearest] = 1.0;
    }

    if (debug > 0)
      fprintf(stderr, "fbank: Band %d: centre = %f, std. dev = %f (Hz)\n",
              i, centre_freq, sd);
  }
}

/*****************************************************************************/

/*
  setup_triangular: construct triangular collecting function using Mel scale.

  bins: number of FFT bins, ie. the number of columns in each weight[] row.

  Each band rises linearly from its low edge to its centre and falls linearly
  to its high edge.  Edges are spaced evenly on the Mel scale so that every
  band starts at the centre of the previous one, and the weights of each band
  are normalized to sum to one.

  When no bin falls inside a band (the band is narrower than the bin spacing
  or lies outside the spectrum) the sum of its weights is zero; rather than
  dividing by zero and producing NaN the band collapses onto the bin nearest
  its centre.
*/
void setup_triangular(int bins)
{
  int i, j, nearest;
  double halfband, incrband, freq, f, df;
  double low_freq, centre_freq, high_freq;
  double scale, pos;

  if (debug > 0)
  {
    fprintf(stderr, "fbank: setup_triangular: nnrg = %d, start:stop = %f:%f\n",
            nnrg, start, stop);   /**/
    fprintf(stderr, "fbank: setup_triangular: mfreq = %d, nnrg = %d\n",
            mfreq, nnrg);  /**/
  }

  halfband = (Mel(stop) - Mel(start)) / (nnrg + 1);
  incrband = halfband * 2.0;
  df = (double) mfreq / (double) bins;               /* freq resolution (Hz) */

  /* foreach filter band */
  for (i = 0, freq = Mel(start); i < nnrg; i++, freq += halfband)
  {
    /* band low edge, centre and high edge (Hz) */
    low_freq = IMel(freq);
    centre_freq = IMel(freq + halfband);
    high_freq = IMel(freq + incrband);

    /* j = bin index, f = corresponding frequency (Hz) */
    scale = 0.0;
    for (j = 0, f = 0.0; j < bins; j++, f += df) {
      if (f < low_freq)
        weight[i][j] = 0.0;
      else if (f < centre_freq) {
        weight[i][j] = 1.0 - ((centre_freq - f) / (centre_freq - low_freq));
        scale += weight[i][j];
      }
      else if (f < high_freq) {
        weight[i][j] = 1.0 - ((f - centre_freq) / (high_freq - centre_freq));
        scale += weight[i][j];
      }
      else
        weight[i][j] = 0.0;
    }

    if (scale > 0.0)
    {
      /* normalize all the bands weightings to sum to unity */
      for (j = 0; j < bins; j++)
        weight[i][j] /= scale;
    }
    else if (bins > 0)
    {
      /* degenerate band: clamp in floating point (the negated test also
         catches NaN) so that the conversion to int is always in range */
      pos = centre_freq / df + 0.5;
      if (!(pos >= 0.0))
        pos = 0.0;
      if (pos > bins - 1.0)
        pos = bins - 1.0;
      nearest = (int) pos;

      for (j = 0; j < bins; j++)
        weight[i][j] = 0.0;
      weight[i][nearest] = 1.0;
    }

    if (debug > 0)
      fprintf(stderr, "fbank: Band %d: low= %f, centre= %f, high= %f (Hz)\n",
              i, low_freq, centre_freq, high_freq);
  }
}

/*****************************************************************************/

/*
  collection: gather FFT power bins into filter band energies.  Each band is
    the weighted sum of the spectrum, using the weightings held in the global
    weight[0<=band<bands][0<=bin<samples].

  samples:  number of spectrum bins.
  spectrum: power value of each bin.
  bands:    number of filter bands to produce.
  fband:    output, receives one energy value per band.
*/
void collection(int samples, double* spectrum, int bands, double* fband)
{
  int band, bin;
  double acc;

  band = 0;
  while (band < bands) {
    acc = 0.0;
    for (bin = 0; bin != samples && bin < samples; bin++)
      acc += spectrum[bin] * weight[band][bin];
    fband[band] = acc;
    band++;
  }
}

/*****************************************************************************/

/*
  parse_cmd_line: set the global analysis parameters from the command line.

    -a n      frame advance in samples (advs)
    -b n      number of filter bands (nnrg)
    -c name   collection function, chosen by first letter: s, g or t
    -d n      debugging level (debug)
    -f n      frame length in samples (flength); also sets flength2
    -o x      overlap between bands (overlap)
    -p x      enable preemphasis using factor x
    -r lo:hi  low and high collection limits in Hz (start, stop)

  The program exits with a message on an unknown option, a malformed -r
  argument, a band count below one, a negative frame advance, or a frame
  length that is not a power of two (fft() is a radix 2 transform).  When no
  frame advance is given it defaults to 100/256 of the frame length, with a
  minimum of one sample.
*/
void parse_cmd_line(int argc, char* argv[])
{
  extern int optopt;              /* option letter that getopt() rejected */
  int opt;

  while ((opt = getopt(argc, argv, "a:b:c:d:f:o:p:r:")) != -1)
  {
    switch (opt)
    {
    case 'a':
      advs = atoi(optarg);
      break;
    case 'b':
      nnrg = atoi(optarg);
      break;
    case 'c':
    {
      switch (*optarg)
      {
      case 's': case 'S':
        collector = square;
        break;
      case 'g': case 'G':
        collector = gaussian;
        break;
      case 't': case 'T':
        collector = triangular;
        break;
      default:
        fprintf(stderr, "fbank: collection option -c %s not supported\n",
                optarg);
        exit(-1);
      }
      break;
    }
    case 'd':
      debug = atoi(optarg);
      break;
    case 'f':
      flength = atoi(optarg);
      flength2 = flength / 2;
      break;
    case 'o':
      overlap = atof(optarg);
      break;
    case 'p':
      preemphasis = ON;
      preemp_factor = atof(optarg);
      break;
    case 'r':
      /* start and stop are doubles, so the conversion must be %lf; %f would
         store floats into them */
      if (sscanf(optarg, "%lf:%lf", &start, &stop) != 2)
      {
        fprintf(stderr, "fbank: range option -r %s is not of the form "
                "start:stop\n", optarg);
        exit(-1);
      }
      break;
    case '?':
    default:
      /* optarg is not set for a rejected option; optopt holds its letter */
      fprintf(stderr, "fbank: command line error: '-%c'.\n", optopt);
      exit(-1);
    }
  }

  if (nnrg < 1)
  {
    fprintf(stderr, "fbank: number of bands (-b) must be at least 1\n");
    exit(-1);
  }
  if (flength < 2 || (flength & (flength - 1)) != 0)
  {
    fprintf(stderr, "fbank: frame length (-f) must be a power of two, "
            "2 or greater\n");
    exit(-1);
  }
  if (advs < 0)
  {
    fprintf(stderr, "fbank: frame advance (-a) must be positive\n");
    exit(-1);
  }

  /* Set default advs as a fraction of the analysis frame width */
  if (advs == 0)
  {
    advs = (int) ((100.0 / 256.0) * flength);
    if (advs < 1)
      advs = 1;              /* very short frames would otherwise give zero */
  }
}
  
/*****************************************************************************/

/*
  main: read an fview sample file from stdin, convert it into a sequence of
    log filter bank energy vectors and write an fview feature file to stdout.

  Input:  FviewSampleHeader record, fsh.info_size bytes of textual header,
          then fsh.number_samples samples read as shorts.
  Output: FviewFeatureHeader record, the textual header extended with the
          analysis settings, then one vector of nnrg doubles per frame.

  Every frame is detrended, Hamming windowed, Fourier transformed, converted
  to a power spectrum, collected into filter bands and logged.  A summary is
  written to stderr on completion.

  Returns 0 on success; exits with a non-zero status on any error.

  NOTE(review): preemp_short, detrend_short, Hamming and window are not
  declared in the code visible here; they are presumably provided by one of
  the project headers or another source file.
*/
int main(int argc, char* argv[])
{
  int     i, j, m, nframes;
  int     written;
  long    nsamples;
  size_t  used, info_len;
  short   *bdata, *bdtmp;
  double  *w, *ar, *td;
  double  *fbands;
  const char *cname;

  parse_cmd_line(argc, argv);

  /* the frame count below divides by advs and every working buffer is sized
     from flength and nnrg, so refuse values that cannot work */
  if (flength < 2 || advs < 1 || nnrg < 1)
  {
    fprintf(stderr, "fbank: invalid frame length (%d), frame advance (%d) "
            "or band count (%d)\n", flength, advs, nnrg);
    exit(1);
  }

  /* Read header information from data file and check file type */
  if (fread(&fsh, sizeof(FviewSampleHeader), 1, stdin) != 1)
  {
    fprintf(stderr, "fbank: error reading header record\n");
    exit(1);
  }
  if (memcmp(fsh.magic, FVIEW_SAMPLE_MAGIC, 8) != 0)
  {
    /* only 8 bytes of the magic field are compared, so print no more */
    fprintf(stderr, "fbank: magic number incorrect (%.8s)\n", fsh.magic);
    exit(1);
  }

  /* the size comes from the input file: it must fit in the buffer with room
     for a terminator (a negative value becomes huge once cast and so is
     rejected as well) */
  if ((size_t) fsh.info_size >= HDR_SIZE)
  {
    fprintf(stderr, "fbank: %d byte textual header is too large\n",
            fsh.info_size);
    exit(1);
  }
  if (fread(header, sizeof(char), fsh.info_size, stdin) != fsh.info_size)
  {
    fprintf(stderr, "fbank: error reading %d byte textual header\n",
            fsh.info_size);
    exit(1);
  }
  header[fsh.info_size] = '\0';

  /* Read data */
  nsamples = (long) fsh.number_samples;
  if (nsamples < 0)
  {
    fprintf(stderr, "fbank: invalid sample count in header record\n");
    exit(1);
  }
  /* always request at least one element: malloc(0) may return NULL */
  bdata = (short *) malloc((nsamples > 0 ? (size_t) nsamples : 1) *
                           sizeof(short));
  if (bdata == NULL)
  {
    fprintf(stderr, "fbank: out of memory\n");
    exit(1);
  }
  if (fread(bdata, sizeof(short), (size_t) nsamples, stdin) !=
      (size_t) nsamples)
  {
    fprintf(stderr, "fbank: data finished prematurely\n");
    exit(1);
  }

  /* number of complete frames; none when the input is shorter than a frame */
  if (nsamples >= flength)
    nframes = (int) ((nsamples - flength) / advs) + 1;
  else
    nframes = 0;
  mfreq = fsh.sample_freq / 2;

  /* argument form kept as is, see the 23-06-93 note in the file header */
  if (preemphasis == ON)
    preemp_short(fsh.number_samples, preemp_factor, &bdata[1]);

  switch(collector)
  {
  case gaussian:
    cname = "gaussian";
    break;
  case triangular:
    cname = "triangular";
    break;
  case square:
  default:
    cname = "square";
    break;
  }

  /* Append extra information to textual header; the write is bounded so an
     overflow is detected rather than written past the end of the buffer */
  used = strlen(header);
  written = snprintf(header + used, HDR_SIZE - used,
          "FBANK_1.0\nframes=%d\norder=%d\nframe_length=%d\n"
          "frame_advance=%d\ncollection=%s\n",
          nframes,   /* Number of observation vectors that will be generated */
          nnrg,                          /* Number of coefficients per frame */
          flength,                      /* Frame length in number of samples */
          advs,                   /* Advancement per frame in no. of samples */
          cname);                               /* Collection function name */
  if (written < 0 || (size_t) written >= HDR_SIZE - used)
  {
    fprintf(stderr, "fbank: textual header overflow\n");
    exit(-1);
  }

  memcpy(ffh.magic, FVIEW_FEATURE_MAGIC, 8);
  ffh.vector_dimension = nnrg;
  ffh.number_observations = nframes;
  info_len = strlen(header) + 1;
  ffh.info_size = info_len;
  if (fwrite(&ffh, sizeof(FviewFeatureHeader), 1, stdout) != 1 ||
      fwrite(header, sizeof(char), info_len, stdout) != info_len)
  {
    fprintf(stderr, "fbank: error writing output headers\n");
    exit(1);
  }

  /* allocate memory */
  w = (double *) calloc(flength, sizeof(double));
  ar = (double *) calloc(flength, sizeof(double));
  fbands = (double *) calloc(nnrg, sizeof(double));
  td = (double *) calloc(flength * 2, sizeof(double));
  if (w == NULL || ar == NULL || fbands == NULL || td == NULL)
  {
    fprintf(stderr, "fbank: out of memory\n");
    exit(1);
  }

  Hamming(flength, w);                          /* Generate a Hamming window */

  /* call appropriate setup routine */
  weight = alloc_weights(nnrg, flength2);
  if (weight == NULL)
  {
    fprintf(stderr, "fbank: out of memory\n");
    exit(1);
  }
  for (i = 0; i < nnrg; i++)
    if (weight[i] == NULL)
    {
      fprintf(stderr, "fbank: out of memory\n");
      exit(1);
    }
  switch(collector)
  {
  case square:
    setup_square(flength2);
    break;
  case gaussian:
    setup_gaussian(flength2);
    break;
  case triangular:
    setup_triangular(flength2);
    break;
  }

  /* print the weights array created by setup_XXXXX */
  if (debug > 1)
  {
    for (j = 0; j < flength2; j++)
    {
      for (i = 0; i < nnrg; i++)
        fprintf(stderr, "%f ", weight[i][j]);
      fprintf(stderr, "\n");
    }
  }

  /* Perform Fourier Analysis on each frame */
  if (debug > 0)
    fprintf(stderr, "fbank: Performing %d point FFT analysis ...\n", flength);
  bdtmp = bdata;
  for (m = 0; m < nframes; m++) {
    detrend_short(flength,bdtmp,ar);     /* ensure ar samples have zero mean */
    window(flength,w,ar);                       /* Perform Hamming windowing */
    for (i = 0, j = 0; i < flength; i++, j+=2) {
      td[j] = ar[i];    /* set real */
      td[j+1] = 0.0;    /* set imag */
    }
    fft(td, flength);                                   /* Fourier Transform */
    complex_power(flength2, td, ar);               /* form vector magnitudes */

    /* collect energy in filter bands then log energy values*/
    collection(flength2, ar, nnrg, fbands);
    applylog(nnrg, fbands);

    /* Store coefficients into file */
    if (fwrite(fbands, sizeof(double), nnrg, stdout) != (size_t) nnrg)
    {
      fprintf(stderr, "fbank: error writing feature vector %d\n", m);
      exit(1);
    }
    bdtmp += advs;
  }
  if (fflush(stdout) != 0)
  {
    fprintf(stderr, "fbank: error flushing output\n");
    exit(1);
  }

  /* free memory */
  for (i = 0; i < nnrg; i++)
    free(weight[i]);
  free(weight);
  weight = NULL;
  free(td);
  free(fbands);
  free(ar);
  free(w);
  free(bdata);

  /* Signal completion */
  fprintf(stderr, "fbank (%s): %s window: %d samps -> %d frames\n",
          VERSION, cname, fsh.number_samples, nframes);

  return 0;
}

