#include <stdexcept>

#include "Sundance.h"
#include "OptQNewton.h"
#include "NLF.h"


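/** \example heat2D.cpp
 * NOTE(review): this header appears to have been copied from the heat2D
 * example and does not describe the code in this file. The program below
 * uses OPT++ (quasi-Newton) to optimize Fourier coefficients that perturb the
 * shape of a flow channel in the rectangle [0, L] x [-H, H]. Each objective
 * evaluation solves a potential-flow problem and a travel-time problem with
 * Sundance. It then measures the L2 deviation of the travel time from its
 * mean over the outlet.
 *
 * Original heat2D description, kept for reference:
 *
 * Solve Poisson's equation with a unit source term on the 
 * rectangle [0,1] x [0, 2] with the following boundary conditions:
 *
 * Left:   Natural, du/dx = 0
 * Bottom: Dirichlet, u= 0.5 x^2
 * Right:  Robin, u + du/dx = 3/2 + y/3
 * Top:    Neumann, du/dy = 1/3
 *
 * The solution is u(x,y) = 0.5*x^2 + y/3.
 *
 * This problem can be solved exactly in the space of second-order polynomials.
 */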


/* Forward model for the channel-shape optimization. Owns the mesh and the
 * linear solver, the current Fourier coefficients, and the potential and
 * travel-time fields from the most recent solve(). */
class FlowSimulator
{
public:
  /* Builds the mesh and solver and reads the initial coefficients
   * from "fourier.dat". */
  FlowSimulator();

  /* Solve the potential and travel-time problems for the given coefficients. */
  void solve(const TSFArray<TSFArray<double> >& fourier);

  /* Solve, then return the L2 norm over the outlet of the deviation of the
   * travel time from its outlet mean. */
  double evalObjective(const TSFArray<TSFArray<double> >& fourier);

  /* Read a coefficient table: a header "nx ny" followed by nx*ny values,
   * row by row. */
  void readFourier(const string& filename, TSFArray<TSFArray<double> >& fourier);

  /* Write the current potential and travel-time fields with suffix ".k"
   * when writing is enabled. */
  void print(int k);

  /* Number of coefficient rows / columns read at construction. */
  int nxParam() const {return nxParams_;}
  int nyParam() const {return nyParams_;}

  /* Mutable access to the current coefficient table; the Opt++ callbacks
   * write into it. */
  TSFArray<TSFArray<double> >& coeffs() {return coeffs_;}
private:
  Mesh mesh_;                 /* mesh of [0, L_] x [-H_, H_] */
  TSFLinearSolver solver_;    /* solver used for both PDE solves */
  CellSet outlet_;            /* outlet cells from the most recent solve() */
  double L_;                  /* domain extent in x */
  double H_;                  /* domain half-extent in y */
  Expr travelTime_;           /* travel time from the most recent solve() */
  Expr potential_;            /* flow potential from the most recent solve() */
  bool write_;                /* whether print() writes output files */
  int nxParams_;              /* rows of coeffs_ */
  int nyParams_;              /* columns of coeffs_ */
  TSFArray<TSFArray<double> > coeffs_;  /* current Fourier coefficients */
};

/* Single simulator instance shared by the Opt++ free-function callbacks.
 * It is a function-local static, so it is constructed on the first call. */
static FlowSimulator& globalSim()
{
  static FlowSimulator theSimulator;
  return theSimulator;
}


/* read initial coefficients and fill Opt++ initial guess */
void initFunction(int n, ColumnVector& x0)
{
  /* warning: opt++ vectors use unit-offset indices */
  int k = 1;
  for (int i=0; i<globalSim().nxParam(); i++)
    {
      for (int j=0; j<globalSim().nyParam(); j++)
        {
          x0(k++) = globalSim().coeffs()[i][j];
        }
    }
}

/* update function gets called by Opt++ at the end of each opt step */
void updateFunction(int k, int n, ColumnVector x)
{
  globalSim().print(k);
}


/* opt++ obj function */
void objFunction(int n, const ColumnVector& x, double& fx, int& result)
{
  /* get expansion coefficients out of Opt++'s vector */
  /* warning: opt++ vectors use unit-offset indices */
  int k = 1;
  TSFOut::println("writing vector");
  for (int i=0; i<globalSim().nxParam(); i++)
    {
      for (int j=0; j<globalSim().nyParam(); j++)
        {
          globalSim().coeffs()[i][j] = x(k);
          cerr << k << " " << x(k) << endl;
          k++;
        }
    }

  /* compute objective function */
  TSFOut::println("starting obj function");
  fx = globalSim().evalObjective(globalSim().coeffs());
  TSFOut::println("done obj function");
  result = NLPFunction;
}

/* opt++ obj function */
void objFunctionAndGradient(int mode, int n, const ColumnVector& x, double& fx, 
                            ColumnVector& g, int& result)
{
  TSFArray<double> f(n+1);
  int r;
  double h = 0.1;

  objFunction(n, x, f[0], r);
  for (int i=1; i<=n; i++)
    {
      ColumnVector xp = x;
      xp(i) += h;
      objFunction(n, xp, f[i], r);
      g(i) = (f[i]-f[0])/h;
    }
  fx = f[0];
  result = NLPGradient;
}


/* Driver: initialize Sundance, run an OPT++ quasi-Newton optimization over
 * the Fourier coefficients held by the global FlowSimulator, then finalize.
 *
 * Returns 0 on success and 1 if an exception reached the handler.
 * NOTE(review): the void** argv signature is non-standard for main; it
 * presumably matches the argument type Sundance::init expects. Confirm
 * before changing it.
 */
int main(int argc, void** argv)
{
  int status = 0;
  try
    {
      Sundance::init(&argc, &argv);
      /* The first use of globalSim() constructs the simulator (mesh, solver,
       * coefficients). It happens inside the try, so its errors are reported. */
      int n = globalSim().nxParam() * globalSim().nyParam();

      NLF1 nlp(n, objFunctionAndGradient, initFunction);
      nlp.initFcn();

      OptQNewton opt(&nlp, updateFunction);

      opt.setOutputFile("opt", 0);
      opt.optimize();
      opt.printStatus("solution from quasi-newton");
      opt.cleanup();
    }
  catch(exception& e)
    {
      Sundance::handleError(e, __FILE__);
      /* report the failure to the caller instead of exiting with 0 */
      status = 1;
    }
  Sundance::finalize();
  return status;
}

/* Construct the simulator:
 *  - build a mesh of the rectangle [0, L] x [-H, H];
 *  - configure the linear solver from command-line options;
 *  - read the initial Fourier coefficients from "fourier.dat".
 *
 * Command-line options:
 *  - -nx, -ny: mesh resolution (defaults 50 and 100);
 *  - -maxiter, -fill: solver / preconditioner settings;
 *  - -overlap: used only with -noaztec;
 *  - -noaztec: use BICGSTAB with an ILU(k) preconditioner instead of
 *    Aztec GMRES.
 */
FlowSimulator::FlowSimulator()
  : mesh_(), solver_(), outlet_(), L_(1.0), H_(1.0), travelTime_(), potential_(),
    write_(true), nxParams_(0), nyParams_(0), coeffs_()
{
  /* create a simple mesh on the rectangle */
  int nx = 50;
  int ny = 100;
  TSFCommandLine::findInt("-nx", nx);
  TSFCommandLine::findInt("-ny", ny);

  MeshGenerator mesher = new RectangleMesher(0.0, L_, nx, -H_, H_, ny);
  mesh_ = mesher.getMesh();

  /* create a preconditioner and solver */
  map<int, int> azOptions;
  map<int, double> azParams;

  int maxiter = 1000;
  int fill = 1;
  int overlap = 1;
  TSFCommandLine::findInt("-maxiter", maxiter);
  TSFCommandLine::findInt("-fill", fill);
  TSFCommandLine::findInt("-overlap", overlap);
  bool noAztec = TSFCommandLine::find("-noaztec");

  if (!noAztec)
    {
      /* Aztec GMRES with domain-decomposition ILU preconditioning */
      azOptions[AZ_output] = AZ_none;
      azOptions[AZ_solver] = AZ_gmres;
      azOptions[AZ_precond] = AZ_dom_decomp;
      azOptions[AZ_subdomain_solve] = AZ_ilu;
      azOptions[AZ_graph_fill] = fill;
      azParams[AZ_tol] = 1.0e-13;
      azOptions[AZ_max_iter] = maxiter;
      solver_ = new AZTECSolver(azOptions, azParams);
    }
  else
    {
      TSFPreconditionerFactory precond 
        = new ILUKPreconditionerFactory(fill, overlap);
      solver_ = new BICGSTABSolver(precond, 1.e-10, maxiter);
    }

  readFourier("fourier.dat", coeffs_);
  nxParams_ = coeffs_.length();
  /* guard: an empty coefficient table has no row 0 to query */
  nyParams_ = 0;
  if (nxParams_ > 0) nyParams_ = coeffs_[0].length();
}

void FlowSimulator::readFourier(const string& filename, 
                                TSFArray<TSFArray<double> >& fourier)
{
  ifstream is(filename.c_str());
  int nx, ny;

  is >> nx >> ny;

  fourier = TSFArray<TSFArray<double> >(nx, ny);

  for (int i=0; i<nx; i++)
    {
      for (int j=0; j<ny; j++)
        {
          is >> fourier[i][j];
        }
    }
	
}

double FlowSimulator::evalObjective(const TSFArray<TSFArray<double> >& fourier)
{
  TSFOut::printf("params = %s\n", fourier.toString().c_str());
  solve(fourier);

  Expr one = 1.0;
  double TMean = travelTime_.integral(mesh_, outlet_)/one.integral(mesh_, outlet_);
  double obj = (travelTime_-TMean).norm(2, mesh_, outlet_);
  TSFOut::printf("value = %g\n", obj);
  return obj;
}

void FlowSimulator::print(int k)
{

  if (write_)
    {
      TSFOut::println("writing plot");
      string suffix = "." + TSF::toString(k);
      FieldWriter phiWriter = new MatlabWriter("phi" + suffix);
      phiWriter.writeField(potential_);
			
      FieldWriter tWriter = new MatlabWriter("T" + suffix);
      tWriter.writeField(travelTime_);
    }
}

void FlowSimulator::solve(const TSFArray<TSFArray<double> >& fourier)
{
  const double pi = 4.0*atan(1.0);

  /* define coordinate functions for x and y coordinates */
  Expr x = new CoordExpr(0);
  Expr y = new CoordExpr(1);
	
  /* define cells sets for each of the four sides of the rectangle */
  CellSet boundary = new BoundaryCellSet();
  CellSet left = boundary.subset( x == 0.0 );
  CellSet right = boundary.subset( x == L_ );
  CellSet bottom = boundary.subset( y == -H_ );
  CellSet top = boundary.subset( y == H_ );
  CellSet maximal = CellSet::maximalCells();
	
  double wTop = 0.2*L_;
  double wBottom = 0.8*L_;
  double deltaY = 0.2;
  Expr w = wTop + 0.5*(wBottom-wTop)*(1.0 - tanh(y/deltaY));

  Expr f0 = 1.0 - x*x/w/w;

  for (int i=0; i<fourier.length(); i++)
    {
      for (int j=0; j<fourier[i].length(); j++)
        {
          if (fabs(fourier[i][j]) < 1.0e-10) continue;
          f0 = f0 + exp(-2.0*(H_+y))*exp(-3.0*x)*fourier[i][j]*sin(pi*(j+1)*y/H_)*cos(pi*x/L_ * (((double) i) + 0.5));
        }
    }
	
	
  /* Create a vector space factory, used to 
   * specify the low-level linear algebra representation */
  TSFVectorType petra = new PetraVectorType();
  /* create a discrete space on the mesh */
  TSFVectorSpace discreteSpace = new SundanceVectorSpace(mesh_, new Lagrange(1), petra);
	
  Expr f = new DiscreteFunction(discreteSpace, f0);

  CellSet channel = maximal.subset( f >= 0.0 );
  CellSet body = maximal - channel;
  CellSet inlet = top.subset( f >= 0.0 );
  outlet_ = bottom.subset( f >= 0.0 );

  /* create symbolic objects for test and unknown functions */
  Expr v = new TestFunction(new Lagrange(1));
  Expr u = new UnknownFunction(new Lagrange(1));
	
  Expr varT = new TestFunction(new Lagrange(1));
  Expr T = new UnknownFunction(new Lagrange(1));
	
  /* create symbolic differential operators */
  Expr dx = new Derivative(0,1);
  Expr dy = new Derivative(1,1);
  Expr grad = List(dx, dy);
	
  /* Write symbolic weak equation and Neumann and Robin BCs */
  Expr poisson = Integral(channel, -(grad*v)*(grad*u));
	
	
  /* Write essential BCs: 
   * Bottom: u=x^2
   */
  EssentialBC bc = EssentialBC(inlet, v*u)
    && EssentialBC(outlet_, v*(u-1.0));
	
  /* Assemble everything into a problem object, with a specification that
   * Petra be used as the low-level linear algebra representation */
  StaticLinearProblem prob(mesh_, poisson, bc, v, u, petra);
	
  /* solve the problem and return the solution as a symbolic object */
  Expr phi0 = prob.solve(solver_);
	
  double eps = 1.0e-4;
  Integral timeEqn(channel, eps*(grad*varT)*(grad*T) + varT*(grad*phi0)*(grad*T) - varT);
  EssentialBC startBC(inlet, varT*T);
  StaticLinearProblem travelTimeProb(mesh_, timeEqn, startBC, varT, T, petra);

  Expr T0 = travelTimeProb.solve(solver_);


  potential_ = phi0;
  travelTime_ = T0;
}

