#include "Sundance.h"


/*
 * Driver: assembles and solves a coupled linear system made of three weak
 * equations (state, adjoint, inversion) on a rectangle mesh, then dumps the
 * three solution fields to "uInv.dat", "lambda.dat" and "p.dat".
 *
 * Command-line options read below:
 *   -nx, -ny   mesh resolution passed to the rectangle mesher
 *   -maxiter   iteration limit handed to the linear solver
 *   -tol       tolerance handed to the linear solver
 *   -fill      fill level for the ILU preconditioner / Aztec graph fill
 *   -overlap   overlap passed to the ILUK preconditioner (non-Aztec path only)
 *   -aztec     if present, use the Aztec solver instead of BICGSTAB
 *
 * Sensor data is read from "sensors.dat": a count followed by that many
 * (x, y, value) triples.
 */
int main(int argc, void** argv)
{
  try
    {
      Sundance::init(&argc, &argv);

      int nx = 20;
      int ny = 20;

      int fill=1;
      int overlap=1;
      double tol = 1.0e-10;
      int maxiter = 5000;

      TSFCommandLine::findInt("-nx", nx);
      TSFCommandLine::findInt("-ny", ny);
      TSFCommandLine::findInt("-maxiter", maxiter);
      TSFCommandLine::findInt("-fill", fill);
      TSFCommandLine::findInt("-overlap", overlap);
      TSFCommandLine::findDouble("-tol", tol);

      bool useAztec = TSFCommandLine::find("-aztec");

      /* create a preconditioner and solver */
      TSFHashtable<int, int> azOptions;
      TSFHashtable<int, double> azParams;
      TSFLinearSolver solver;

      if (useAztec)
        {
          azOptions.put(AZ_solver, AZ_gmres);
          azOptions.put(AZ_precond, AZ_dom_decomp);
          azOptions.put(AZ_subdomain_solve, AZ_ilu);
          azOptions.put(AZ_graph_fill, fill);
          azParams.put(AZ_tol, tol);
          azOptions.put(AZ_max_iter, maxiter);
          solver = new AZTECSolver(azOptions, azParams);
        }
      else
        {
          TSFPreconditionerFactory precond 
            = new ILUKPreconditionerFactory(fill, overlap);
          solver = new BICGSTABSolver(precond, tol, maxiter);
          solver.setVerbosityLevel(4);
        }

      /* create a simple mesh on the rectangle [-L,L] x [-L,L] */
      double L = 4.0;
      MeshGenerator mesher = new RectangleMesher(-L, L, nx, -L, L, ny);
      Mesh mesh = mesher.getMesh();

      /* define coordinate functions for x and y coordinates */
      Expr x = new CoordExpr(0);
      Expr y = new CoordExpr(1);

      /* define cell sets for each of the four sides of the rectangle */
      CellSet boundary = new BoundaryCellSet();
      CellSet left = boundary.subset( x == -L );
      CellSet right = boundary.subset( x == L );
      CellSet bottom = boundary.subset( y == -L );
      CellSet top = boundary.subset( y == L );

      /* known velocity field: Poiseuille flow (x-component only) */
      Expr cx = (L+y)*(L-y);

      /* diffusivity */
      Expr k = 1.0;

      /* define variations and unknowns for state, adjoint, inversion fields */
      Expr v = new TestFunction(new Lagrange(1), "v");
      Expr u = new UnknownFunction(new Lagrange(1), "u");
      Expr mu = new TestFunction(new Lagrange(1), "mu");
      Expr lambda = new UnknownFunction(new Lagrange(1), "lambda");
      Expr q = new TestFunction(new Lagrange(1), "q");
      Expr p = new UnknownFunction(new Lagrange(1), "p");

      /* define derivatives and gradient operator */
      Expr dx = new Derivative(0,1);
      Expr dy = new Derivative(1,1);
      Expr grad = List(dx, dy);

      /* set up weak form of state, adjoint, and inversion equations */

      /* The convection term is written with cx and dx, the same way the
       * adjoint equation below writes it, because cx is the only velocity
       * component defined. All three terms belong inside one Integral. */
      Expr stateEqn = Integral(mu*(cx*(dx*u)) 
                               + k*((grad*u)*(grad*mu)) 
                               - mu*p);

      Expr adjointEqn = Integral(lambda*(cx*(dx*v)) 
                                 + k*((grad*lambda)*(grad*v)));

      /* read sensor data from file and append to adjoint equation */
      ifstream is("sensors.dat");
      int nSensor = 0;
      if (!(is >> nSensor))
        {
          /* Missing or unreadable file: do not loop on an indeterminate
           * count; continue with no sensor terms. */
          TSFOut::println("warning: could not read sensor count from sensors.dat; no sensor terms added");
          nSensor = 0;
        }
      for (int i=0; i<nSensor; i++)
        {
          double sx = 0.0;
          double sy = 0.0;
          double uVal = 0.0;
          if (!(is >> sx >> sy >> uVal))
            {
              /* Fewer records than announced: stop instead of using
               * values that were never read. */
              TSFOut::println("warning: sensors.dat ended early; remaining sensors skipped");
              break;
            }
          Cell sensor = mesh.cellNearestToPoint(Point(sx, sy));
          string label = "sensor" + TSF::toString(i);
          sensor.setLabel(label);
          CellSet sensorRegion = new LabeledCellSet(label);
          adjointEqn = adjointEqn + Integral(sensorRegion, v*(u-uVal));
        }

      Expr inversionEqn = Integral((grad*q)*(grad*p) - q*lambda);
      Expr weakform = stateEqn + adjointEqn + inversionEqn;

      /*
       * Boundary conditions:
       * - p has Neumann on all surfaces.
       * - u is zero on left, Neumann everywhere else. 
       * - lambda is zero on left, bottom, top
       * - lambda has Neumann on right 
       *
       * NOTE(review): only the lambda conditions are actually imposed
       * below; uhat is defined but never used, so the "u is zero on left"
       * condition listed above is not applied. Confirm whether an
       * essential BC on u was intended.
       */
      Expr uhat = 0.0;
      Expr lambdahat = 0.0;

      EssentialBC bc = EssentialBC(left, v*(lambda-lambdahat))
        && EssentialBC(top, v*(lambda-lambdahat)) 
        && EssentialBC(bottom, v*(lambda-lambdahat));

      /* 
       * Assemble and solve problem
       */ 
      StaticLinearProblem prob(mesh, weakform, bc, 
                               List(mu,v,q), List(u,lambda,p), 
                               new PetraVectorType()); 

      Expr soln = prob.solve(solver);

      /* components follow the order of the unknown list (u, lambda, p) */
      Expr u_h = soln[0];
      Expr lambda_h = soln[1];
      Expr p_h = soln[2];

      ofstream ofu("uInv.dat");
      u_h.matlabDump(ofu);

      ofstream oflambda("lambda.dat");
      lambda_h.matlabDump(oflambda);

      ofstream ofp("p.dat");
      p_h.matlabDump(ofp);
    }
  catch(const exception& e)
    {
      /* report the failure through the testing hooks */
      TSFOut::println(e.what());
      Testing::crash(__FILE__);
      Testing::timeStamp(__FILE__, __DATE__, __TIME__);
    }
  Sundance::finalize();
}


