/**
 * 
 */
package edu.cmu.cs.ark.compuframes;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.commons.math3.analysis.MultivariateFunction;

/**
 * @author Brendan O'Connor (modified by Yanchuan Sim)
 * @version 0.1
 * @since 0.1
 */

public class SliceSampler
{
  /**
   * Draws samples by slice sampling, updating one coordinate at a time.
   * <p>
   * Each of the {@code niter} iterations draws one log-height threshold below the current log-density value, then
   * sweeps through every axis in index order. For an axis, a bracket of width {@code widths[dd]} is placed at a
   * random offset around the current coordinate and widened in steps of {@code widths[dd]} while its end points
   * still evaluate above the threshold. Candidates are then drawn uniformly from the bracket, and the bracket is
   * shrunk toward the current coordinate after every rejected candidate, until one evaluates above the threshold.
   * <p>
   * The {@code initial} array is copied on entry and is never modified by this method.
   *
   * @param logdist
   *          function returning the log of the (possibly unnormalized) target density at a point
   * @param initial
   *          starting point; its length fixes the dimensionality
   * @param widths
   *          per-axis initial bracket width and step-out increment; must have the same length as {@code initial}
   * @param niter
   *          number of full sweeps to run; a non-positive value yields an empty list
   * @return one independent copy of the state after each sweep, in iteration order
   * @throws AssertionError
   *           if a bracket shrinks onto the current coordinate without any candidate being accepted
   */
  public static List<double[]> slice_sample(MultivariateFunction logdist, double[] initial, double[] widths, int niter)
  {
    final boolean step_out = true;
    final int D = initial.length;
    assert widths.length == D;

    // Work on a private copy: the sweep below writes into state, and the caller's array must stay intact.
    double[] state = Arrays.copyOf(initial, D);
    double log_Px = logdist.value(state);

    List<double[]> history = new ArrayList<double[]>(Math.max(niter, 0));

    for (int itr = 0; itr < niter; itr++)
    {
      // Log-height of the slice; drawn once per sweep and shared by every axis update in it.
      double log_uprime = Math.log(Math.random()) + log_Px;

      // Sweep through axes
      for (int dd = 0; dd < D; dd++)
      {
        double[] x_l = Arrays.copyOf(state, D), x_r = Arrays.copyOf(state, D), xprime = Arrays.copyOf(state, D);

        // Create a horizontal interval (x_l, x_r) of width widths[dd], randomly positioned around state[dd]
        double r = Math.random();
        x_l[dd] = state[dd] - r * widths[dd];
        x_r[dd] = state[dd] + (1 - r) * widths[dd];
        if (step_out)
        {
          // Widen each side until its end point is no longer above the slice height
          while (logdist.value(x_l) > log_uprime)
            x_l[dd] -= widths[dd];
          while (logdist.value(x_r) > log_uprime)
            x_r[dd] += widths[dd];
        }

        // Inner loop:
        // Propose xprimes and shrink interval until good one is found.
        while (true)
        {
          xprime[dd] = Math.random() * (x_r[dd] - x_l[dd]) + x_l[dd];
          log_Px = logdist.value(xprime);
          if (log_Px > log_uprime)
          {
            // Accepted: log_Px now holds the log-density of the new state
            break;
          }

          // Rejected: pull the bracket end on the candidate's side in to the candidate
          if (xprime[dd] > state[dd])
          {
            x_r[dd] = xprime[dd];
          }
          else if (xprime[dd] < state[dd])
          {
            x_l[dd] = xprime[dd];
          }
          else
          {
            // Thrown unconditionally: a plain assert is skipped when assertions are disabled, which
            // would leave this loop spinning forever on a bracket that can no longer change.
            throw new AssertionError("BUG, shrunk to current position and still not acceptable");
          }
        }
        state[dd] = xprime[dd];
      }
      history.add(Arrays.copyOf(state, D));
    }
    return history;
  }
}
