#include <iostream>
#include <iomanip>

#include "Sundance.h"
#include "TSFMPI.h"

//#define SHOW_HOLD_ERR         // Uncomment to show this error
#define SHOW_OBJ_FUNC_ERR     // Not tested by any #ifdef below; currently has no effect

/*
 * Finite-difference check of Sundance functional sensitivities.
 *
 * Builds a partitioned rectangle mesh on [-L, L] x [-L, L] and forms the
 * objective
 *
 *   J(u, f) = Integral( 0.5*beta*(grad f)*(grad f) )
 *           + sum over sensors Integral_{sensor point}( 0.5*(u - uVal)^2 )
 *
 * It evaluates J at the constant guesses u0 = uGuess, f0 = fGuess and computes
 * the gradients with directSensitivity(). Each gradient component is compared
 * against a central finite difference with step fdStep. The combined error is
 * handed to Testing::passFailCheck() with tolerance 1.0e-4.
 *
 * Define SHOW_HOLD_ERR to build the hold()-based variant of the objective.
 */
int main( int argc, char* argv[] )
{

  using Sundance::List;

  // Arguments passed to Lagrange(...) for the state u and the control f
  // (presumably the polynomial order -- confirm against the Lagrange API).
  const int uBasisDim = 1;
  const int fBasisDim = 1;

  /* half-width of the square domain [-L, L] x [-L, L] */
  const double L = 2.0;

  /* number of sensors in x and y directions */
  const int nsx = 2;
  const int nsy = 2;

  // n cells per sensor spacing; the mesher below is given npx*nx cells in x
  // and npy*ny cells in y.
  const int n = 4;
  const int nx = nsx*n;
  const int ny = nsy*n;

  const double beta = 1.0;   // weight of the gradient regularization term
  const double uGuess = 0.1; // constant value of the state u0
  const double fGuess = 0.2; // constant value of the control f0

  try 
    {

      Sundance::init(&argc,(void***)&argv);

      // Processor grid npx x npy. For 1 or an even number of processors the
      // formula below gives an exact factorization ((num_proc/2) x 2). For an
      // odd count > 1 it does not (e.g. 3 -> 2 x 1), which would leave some
      // ranks outside the grid, so fall back to a num_proc x 1 strip.
      int num_proc = TSF::TSFMPI::getNProc();
      int npx = num_proc/2 + num_proc%2;
      int npy = num_proc / npx;
      if (npx*npy != num_proc)
        {
          npx = num_proc;
          npy = 1;
        }

      // Position of this rank in the processor grid (row-major in py).
      int rank = TSF::TSFMPI::getRank();
      int px = rank / npy;
      int py = rank % npy;

      // Sub-domain of [-L, L]^2 assigned to this rank.
      double xMin = -L + 2.0*L*px/((double) npx);
      double xMax = -L + 2.0*L*(px+1)/((double) npx);
      double yMin = -L + 2.0*L*(py)/((double) npy);
      double yMax = -L + 2.0*L*(py+1)/((double) npy);

      TSFOut::printf("rank=%d npx=%d px=%d npy=%d py=%d\n", rank, npx, px, npy, py);
      TSFOut::printf("bounds are (%g, %g), (%g, %g)\n",
                     xMin, xMax, yMin, yMax);

      MeshGenerator mesher = new PartitionedRectangleMesher(-L, L, npx*nx, npx, -L, L, npy*ny, npy);
      //		MeshGenerator mesher = new RectangleMesher(-L, L, nx, -L, L, ny);
      const Mesh mesh_ = mesher.getMesh();

      TSFVectorType vst = new PetraVectorType();
      //		TSFVectorType vst = new DenseSerialVectorType();

      const Expr x = new CoordExpr(0);
      const Expr y = new CoordExpr(1);

      // Boundary cell sets; not referenced by the objective below.
      const CellSet boundary_ = new BoundaryCellSet();
      const CellSet left_     = boundary_.subset( fabs(x + L) < 1.0e-10 );
      const CellSet right_    = boundary_.subset( fabs(x - L) < 1.0e-10 );
      const CellSet top_      = boundary_.subset( fabs(y - L) < 1.0e-10 );
      const CellSet bottom_   = boundary_.subset( fabs(y + L) < 1.0e-10 );

      //
      // Create the nonlinear variables
      //
		
      TSFVectorSpace
        discreteStateSpace
        = new SundanceVectorSpace(mesh_, new Lagrange(uBasisDim), vst);
      TSFVectorSpace
        discreteControlSpace
        = new SundanceVectorSpace(mesh_, new Lagrange(fBasisDim), vst);

      // Discrete functions at which the objective and gradients are evaluated.
      Expr u0 = new DiscreteFunction( discreteStateSpace,   uGuess, "u0" );
      Expr f0 = new DiscreteFunction( discreteControlSpace, fGuess, "f0" );

      // define derivatives and gradient operator
      Expr dx = new Derivative(0,1);
      Expr dy = new Derivative(1,1);
      Expr grad = List(dx, dy);

      //
      // Formulate the objective function
      //

      // Unknowns that are bound to (u0, f0) when the functional is evaluated.
      const Expr u0u = new UnknownFunction(new Lagrange(uBasisDim));
      const Expr f0u = new UnknownFunction(new Lagrange(fBasisDim));

      // Regularization term
#ifdef SHOW_HOLD_ERR
      const Expr gradf0u = hold(grad*f0u);
      Expr obj_func_expr = Integral( 0.5 * beta * gradf0u * gradf0u );
#else
      Expr obj_func_expr = Integral( 0.5 * beta * (grad*f0u) * (grad*f0u) );
#endif

      // Match to sensor points for states terms

		
      CellSet pointCells = new DimensionalCellSet(0);

      for (int i=0; i<nsx; i++) 
        {
          double sx = -L + 2.0*i*L/((double) nsx);
          // NOTE(review): this is a closed-interval test, so a sensor lying
          // exactly on a partition boundary is added on both neighbouring
          // ranks. Confirm whether that double counts in parallel runs.
          if (sx < xMin || sx > xMax) continue;
          for (int j=0; j<nsy; j++) 
            {
              double sy = -L + 2.0*j*L/((double) nsy);
              if (sy < yMin || sy > yMax) continue;
              // Target state value at the sensor.
              double uVal = exp(-sx*sx - sy*sy);
              Cell sensor = mesh_.cellNearestToPoint(Point(sx, sy));
              TSFOut::println("sensor cell is " + sensor.toString());
              // Label the nearest vertex cell so that a labeled subset of
              // the 0-dimensional cells selects exactly this sensor.
              string label = "sensor(" + TSF::toString(i) + ", " + TSF::toString(j) + ")";
              sensor.setLabel(label);
              CellSet sensorRegion = pointCells.labeledSubset(label);
#ifdef SHOW_HOLD_ERR
              const Expr u0u_uVal = hold(u0u-uVal);
              obj_func_expr += Integral( sensorRegion, 0.5 * u0u_uVal * u0u_uVal );
#else
              obj_func_expr += Integral( sensorRegion, 0.5 * (u0u-uVal) * (u0u-uVal));
#endif
            }
        }

      // Evaluate the objective
      const double obj_eval = obj_func_expr.evaluateFunctional(
                                                               mesh_, List(u0u,f0u),List(u0,f0)
                                                               );

      TSFOut::println( "obj_func_expr.evaluateFunctional(...) = " + TSF::toString(obj_eval));

      // Evaluate the gradient
      Expr gu = obj_func_expr.directSensitivity(u0u,u0);
      //		MatlabWriter("g_u_dump.dat").writeField(gu.name(),gu);
      Expr gf = obj_func_expr.directSensitivity(f0u,f0);
      //		MatlabWriter("g_f_dump.dat").writeField(gf.name(),gf);

      //
      // Finite difference testing
      //

      std::cerr << std::setprecision(18);

      // Do finite-difference derivatives of objective wrt each component of u0. 
      // Store the results in a discrete function guFD
      Expr guFD = new DiscreteFunction(discreteStateSpace, 0.0, "guFD");
      TSFVector u0Vec;
      TSFVector guFDVec;
      TSFVector guVec;
      u0.getVector(u0Vec);
      guFD.getVector(guFDVec);
      gu.getVector(guVec);
		
      double fdStep = 1.0e-6;

      // Largest per-component relative error over both gradients.
      double vectorError = 0.0;

      TSFOut::println("comparing sundance and FD gradients wrt u");

      for (int i=0; i<u0Vec.space().dim(); i++)
        {
          // Central difference: perturb component i of u0 in place (u0Vec is
          // obtained from u0 via getVector), re-evaluate, then restore it.
          double save = u0Vec[i];
          u0Vec[i] = save + fdStep;
          double objPlus = obj_func_expr.evaluateFunctional(mesh_, List(u0u,f0u), List(u0,f0));
          u0Vec[i] = save - fdStep;
          double objMinus = obj_func_expr.evaluateFunctional(mesh_, List(u0u,f0u), List(u0,f0));
          guFDVec[i] = (objPlus - objMinus)/2.0/fdStep;

          u0Vec[i] = save;
          const double localError = ::fabs( guVec[i] - guFDVec[i] ) / ( 1.0 + ::fabs(guVec[i]) + ::fabs(guFDVec[i]) );

          TSFOut::println(TSF::toString(i) + ") objPlus = " + TSF::toString(objPlus) + ", objMinus = " 
                          + TSF::toString( objMinus));
          TSFOut::println(TSF::toString(i) + ") guVec[" + TSF::toString(i) + "] = " + TSF::toString(guVec[i]) 
                          + ", guFDVec[" + TSF::toString(i) + "] = " 
                          + TSF::toString(guFDVec[i]) + ", relative error = "
                          + TSF::toString(localError));
          if (localError > vectorError) vectorError = localError;
        }

      // Do finite-difference derivatives of objective wrt each component of f0. 
      // Store the results in a discrete function gfFD.
      Expr gfFD = new DiscreteFunction(discreteControlSpace, 0.0, "gfFD");
      TSFVector f0Vec;
      TSFVector gfFDVec;
      TSFVector gfVec;
      f0.getVector(f0Vec);
      gfFD.getVector(gfFDVec);
      gf.getVector(gfVec);
		
      cerr << "comparing sundance and FD gradients wrt f" << endl;
      for (int i=0; i<f0Vec.space().dim(); i++)
        {
          // Same central-difference scheme, applied to component i of f0.
          double save = f0Vec[i];
          f0Vec[i] = save + fdStep;
          double objPlus = obj_func_expr.evaluateFunctional(mesh_, List(u0u,f0u), List(u0,f0));
          f0Vec[i] = save - fdStep;
          double objMinus = obj_func_expr.evaluateFunctional(mesh_, List(u0u,f0u), List(u0,f0));
          gfFDVec[i] = (objPlus - objMinus)/2.0/fdStep;
          f0Vec[i] = save;
          const double localError = ::fabs( gfVec[i] - gfFDVec[i] ) / ( 1.0 + ::fabs(gfVec[i]) + ::fabs(gfFDVec[i]) );
          cerr << i << ") objPlus = " << objPlus << ", objMinus = " << objMinus << endl;
          cerr << i << ") gfVec[" << i << "] = " << gfVec[i] << ", gfFDVec[" << i << "] = " << gfFDVec[i]
               << ", relative error = " << localError << endl;
          if (localError > vectorError) vectorError = localError;
        }
		
      // 2-norms of the differences between FD and Sundance gradients.
      // The "df/da" label below refers to the gradient wrt the control f.
      double gradUErrorNorm = (guFD - gu).norm(2);
      double gradAErrorNorm = (gfFD - gf).norm(2);

      double errorNorm = gradUErrorNorm + gradAErrorNorm + vectorError;
      TSFOut::printf("error norm for df/du: %g\n", gradUErrorNorm);
      TSFOut::printf("error norm for df/da: %g\n", gradAErrorNorm);
      TSFOut::printf("error norm for vector comparisons: %g\n", vectorError);

      double tolerance = 1.0e-4;

      //
      // decide if the error is within tolerance
      //
      Testing::passFailCheck(__FILE__, errorNorm, tolerance);
      Testing::timeStamp(__FILE__, __DATE__, __TIME__);

    }
  catch(exception& e) 
    {
      TSFOut::println(e.what());
      Testing::crash(__FILE__);
      Testing::timeStamp(__FILE__, __DATE__, __TIME__);
    }

  // Shut Sundance down on both the normal and the exception path; this must
  // come before the return statement or it is never executed.
  Sundance::finalize();
  return 0;
}

