/*
 * mex_logsum.cc, started 22 October 2009 by tss.
 *
 * Here is the MATLAB documentation for the MATLAB version of this code:
 *
% LOGSUM   Compute log(sum(exp(vals))) stably
%
%   log_sum = logsum(logvals,dim)
%
%   Uses Gaussian logarithms to compute log(sum(exp(logvals))), which, for
%   vals whose log representation is very low or very high, can underflow
%   to 0 or overflow to infinity if the standard approach is used. Works
%   kinda like the MATLAB sum function w.r.t. the dim argument, but does not
%   support 3-or-higher dimensional data.
 *
 * Works only with dense matrices full of doubles.
 */

#include <mex.h>
#include <cmath>
#include <cstddef>
#include <limits>

// Predeclaration of the MEX entry point, which is defined at the bottom of
// this file. The extern "C" block gives the function C linkage, i.e. an
// unmangled symbol name (presumably so the MATLAB loader can locate it).
extern "C" {
void mexFunction(int nlhs, mxArray **plhs, int nrhs, const mxArray **prhs);
}

// This helper does the actual heavy lifting of Gaussian logarithms: it
// returns log(exp(a) + exp(b)) computed as a + log1p(exp(b - a)), which
// never exponentiates a positive number and therefore cannot overflow.
//
// Precondition: a >= b. The ordering is the caller's job (see log_add).
//
// Special values:
//   * a == b == +inf returns +inf, and a == b == -inf returns -inf. These
//     are answered up front because b - a would be inf - inf == NaN.
//   * b == -inf returns a unchanged (adding exp(-inf) == 0 is a no-op);
//     this also lets a NaN passed as a propagate.
//   * Any other NaN operand yields NaN through the arithmetic below.
inline double log_add_core(const double &a, const double &b)
{
  // Equal infinities must not reach the subtraction below.
  if(std::isinf(a) && (a == b)) return a;
  // Nothing to add when the smaller term is log(0).
  if(std::isinf(b) && (b < 0.0)) return a;
  const double thediff = b - a; // <= 0 whenever the precondition holds
  return a + std::log1p(std::exp(thediff));
}

// Numerically stable log(exp(a) + exp(b)) using Gaussian logarithms.
// Either argument may be the larger one; this wrapper sorts them out.
inline double log_add(const double &a, const double &b)
{
  // log_add_core wants the larger value first. Ties and unordered (NaN)
  // comparisons fall through to the swapped call.
  return (a > b) ? log_add_core(a, b) : log_add_core(b, a);
}

// This function iterates through matrices and performs the actual summing.
// The _step variables allow this code to do column and row sums.
//
// For each of the iters_outer slices it folds iters_inner elements together
// with log_add and writes one value to result.
//
//   data        - first element of the matrix storage to read.
//   iters_outer - number of sums to produce (entries written to result).
//   iters_inner - number of elements folded into each sum.
//   step_outer  - element distance between the starts of successive slices.
//   step_inner  - element distance between successive elements of a slice.
//   result      - output buffer with room for iters_outer doubles.
//   offset      - starting value of every running sum; mexFunction passes
//                 -infinity, i.e. log(0).
//
// Elements are addressed by computed offset rather than by bumping a pointer
// one stride past the last element: for row sums the bumped pointer would
// land beyond one-past-the-end of the array, which is undefined pointer
// arithmetic even if it is never dereferenced. Offsets are computed in
// std::size_t so they cannot wrap at 32 bits.
inline void do_logsum(const double *data,
                      const unsigned int &iters_outer,
                      const unsigned int &iters_inner,
                      const unsigned int &step_outer,
                      const unsigned int &step_inner,
                      double *result,
                      const double &offset)
{
  for(unsigned int iter_outer=0; iter_outer<iters_outer; ++iter_outer) {
    const double *slice =
      data + static_cast<std::size_t>(iter_outer) * step_outer;

    double sum = offset;
    for(unsigned int iter_inner=0; iter_inner<iters_inner; ++iter_inner) {
      sum = log_add(sum,
                    slice[static_cast<std::size_t>(iter_inner) * step_inner]);

      // Once the running sum is NaN or +infinity the remaining elements are
      // not consulted, so short-circuit.
      if(std::isnan(sum)) break;
      if((sum > 0.0) && std::isinf(sum)) break;
    }

    result[iter_outer] = sum;
  }
}


void mexFunction(int nlhs, mxArray **plhs, int nrhs, const mxArray **prhs)
{
  // Declare some constants
  const double _NAN    =  std::numeric_limits<double>::quiet_NaN();
  const double _NEGINF = -std::numeric_limits<double>::infinity();

  // Complain if something is wrong.
  if(nrhs < 1) mexErrMsgTxt("logsum requires an argument");
  const mxArray *data = prhs[0];
  if(mxGetNumberOfDimensions(data) > 2)
    mexErrMsgTxt("logsum can't handle N-dimensional arrays with N > 2");
  if(mxIsComplex(data)) mexErrMsgTxt("logsum can only handle real arrays");
  if(!mxIsDouble(data)) mexErrMsgTxt("logsum can only handle double arrays");
  if( mxIsSparse(data)) mexErrMsgTxt("logsum can only handle dense arrays");

  // Get dimensions of argument array
  const unsigned int true_M = mxGetM(data);
  const unsigned int true_N = mxGetN(data);

  // Determine which dimension to sum
  enum { DOWN, RIGHT } sumdir = DOWN;
  bool direction_is_explicit  = false; // true iff user actually specified
  // First, see if the user has specified something
  if(nrhs > 1) {
    const unsigned char dir = (unsigned char) mxGetScalar(prhs[1]);
         if(dir == 1) sumdir = DOWN;
    else if(dir == 2) sumdir = RIGHT;
    else mexErrMsgTxt("invalid summation direction");
    direction_is_explicit = true;
  }
  // Otherwise, see if it's a vector; if it is, sum along the vector
  // dimension. It only makes sense to do this if both dims are nonzero.
  else if((true_M > 0) && (true_N > 0)) {
         if(true_N == 1) sumdir = DOWN;
    else if(true_M == 1) sumdir = RIGHT;
  }

  // Handle cases where one dim or another is zero
  if(true_M == 0) {
    // What follows is just copying MATLAB.
    if(true_N == 0) {
      if(direction_is_explicit) {
        if(sumdir == DOWN) plhs[0] = mxCreateDoubleMatrix(1,0, mxREAL);
                      else plhs[0] = mxCreateDoubleMatrix(0,1, mxREAL);
      }
      else plhs[0] = mxCreateDoubleScalar(_NEGINF);
      return;
    }
    // The next two behaviors make sense if you think hard about them. If
    // they still don't, well, we just do the same thing that sum() does.
    if(sumdir == DOWN) {
      plhs[0] = mxCreateDoubleMatrix(1,true_N, mxREAL);
      double *data = mxGetPr(plhs[0]);
      for(unsigned int i=0; i<true_N; ++i) *data++ = _NEGINF;
      return;
    }
    else { // sumdir == RIGHT
      plhs[0] = mxCreateDoubleMatrix(0,1, mxREAL);
      return;
    }
  }
  else if(true_N == 0) {
    if(sumdir == DOWN) {
      plhs[0] = mxCreateDoubleMatrix(1,0, mxREAL);
      return;
    }
    else { // sumdir == RIGHT
      plhs[0] = mxCreateDoubleMatrix(true_M,1, mxREAL);
      double *data = mxGetPr(plhs[0]);
      for(unsigned int i=0; i<true_M; ++i) *data++ = _NEGINF;
      return;
    }
  }

  // Well, looks like we'll actually be computing a sum today. Allocate space
  // for the sum and sum it up, then.
  else if(sumdir == DOWN) {
    plhs[0] = mxCreateDoubleMatrix(1,true_N, mxREAL);
    do_logsum(mxGetPr(data),
              true_N, true_M, true_M, 1,
              mxGetPr(plhs[0]), _NEGINF);
  }
  else {
    plhs[0] = mxCreateDoubleMatrix(true_M,1, mxREAL);
    do_logsum(mxGetPr(data),
              true_M, true_N, 1, true_M,
              mxGetPr(plhs[0]), _NEGINF);
  }
}

