#include "TSFDefs.h"
#include "Sundance.h"
#include "LAPACKGeneralMatrix.h"
#include "DenseSerialVectorType.h"
#include "TSFTimer.h"
#include "PetraVector.h"
#include "PetraMatrix.h"
#include "IfpackOperator.h"

#include "StokesRightPreconditionerFactory.h"
#include "KayLoghinPreconditionerFactory.h"
#include "MaximalCellSet.h"
#include "RCMCellReorderer.h"
#include "FieldWriter.h"
#include "VTKWriter.h"

using namespace Sundance;

using namespace TSF;


int main(int argc, void** argv)
{
  try
    {
      Sundance::init(&argc, &argv);

      /* create a simple mesh on the rectangle */
      int nx = 32;
      int ny = 32;
      MeshGenerator mesher = new RectangleMesher(0.0, 1.0, nx, 0.0, 1.0, ny);
      Mesh mesh = mesher.getMesh();

      /* define coordinate functions for x and y coordinates */
      Expr x = new CoordExpr(0);
      Expr y = new CoordExpr(1);

      /* define cells sets for each of the four sides of the rectangle */
      CellSet boundary = new BoundaryCellSet();
      CellSet left = boundary.subset( x == 0.0 );
      CellSet right = boundary.subset( x == 1.0 );
      CellSet bottom = boundary.subset( y == 0.0 );
      CellSet top = boundary.subset( y == 1.0 );

      /* specify that we will use Petra vector and matrix objects */
      TSFVectorType petraType = new PetraVectorType();

      BasisFamily lagr2 = new Lagrange(2);
      TSFVectorSpace discreteSpace 
        = new SundanceVectorSpace(mesh, new Lagrange(1), petraType);

      TSFVectorSpace velSpace 
        = new SundanceVectorSpace(mesh, tuple(lagr2, lagr2), petraType);

      // define variations and unknowns for U, V, and P.
      /* use Taylor-Hood discretization */
      Expr delU = new TestFunction(new Lagrange(2), "delU");
      Expr U = new UnknownFunction(new Lagrange(2), "U");
      Expr delV = new TestFunction(new Lagrange(2), "delV");
      Expr V = new UnknownFunction(new Lagrange(2), "V");
      Expr delP = new TestFunction(new Lagrange(1), "delP");
      Expr P = new UnknownFunction(new Lagrange(1), "P");

      Expr delPsi = new TestFunction(new Lagrange(1), "delPsi");
      Expr psi = new UnknownFunction(new Lagrange(1), "psi");

      /* define block structure [ (u,v,psi), (p) ] */
      TSFArray<Block> unkBlocks = tuple(Block(List(U, V), petraType), 
                                        Block(P, petraType));

      TSFArray<Block> varBlocks = tuple(Block(List(delU, delV), petraType), 
                                        Block(delP, petraType));

      /* initial velocity */
      Expr zero = List(Expr(0.0), Expr(0.0));
      Expr vel0 = DiscreteFunction::discretize(velSpace, zero);
			
      Expr delVelocity = List(delU, delV);
      Expr velocity = List(U, V);

      // create differential operators for x and y directions, and
      // then form gradient operator.
      Expr dx = new Derivative(0);
      Expr dy = new Derivative(1);
      Expr grad = List(dx, dy);
			
      // momentum continuity 

      Expr reynolds = new ParameterExpr(100.0);
			
      Expr momentumEqn  = delU*(vel0[0]*dx + vel0[1]*dy)*U 
        + (grad*U)*(grad*delU)/reynolds 
        + delV*(vel0[0]*dx + vel0[1]*dy)*V + (grad*V)*(grad*delV)/reynolds
        + P*(dx*delU + dy*delV);

      // incompressibility

      Expr continuityEqn = delP*(dx*U + dy*V);

      /* extract streamfunction from velocity */
      Expr streamfunctionEqn = (grad*psi)*(grad*delPsi) 
        + delPsi*(dx*vel0[1] - dy*vel0[0]);

      Expr eqn = Integral(momentumEqn) + Integral(continuityEqn) ;
			
      Expr postprocEqn = Integral(streamfunctionEqn);
      EssentialBC postprocBC(boundary, delPsi*psi);

      /*
       * Boundary conditions: 
       * v=0 on all sides
       * u=1 on top 
       * u=0 elsewehere
       */

      EssentialBC bc = EssentialBC(top, delU*(U-1) + delV*V /*+ delPsi*psi*/)
        && EssentialBC(left, delU*(U) + delV*V /*+ delPsi*psi*/)
        && EssentialBC(right, delU*(U) + delV*V /*+ delPsi*psi*/)
        && EssentialBC(bottom, delU*(U) + delV*V /*+ delPsi*psi*/);


      /* Create solver objects. We'll use a Schur complement solver
       * that takes advantage of our ability to solve the diffusion
       * operator easily. We will use preconditioned BIGCSTAB for the
       * inner solve on the diffusion block and unpreconditioned
       * BICGSTAB for the outer solve. We'll ask for tighter tolerance
       * on the inner solve */
      TSFPreconditionerFactory innerPrecond = new ILUKPreconditionerFactory(2);
      TSFLinearSolver innerSolver = new BICGSTABSolver(innerPrecond, 1.0e-12, 2000);
      innerSolver.setVerbosityLevel(1);

      /*
        TSFPreconditionerFactory outerPrecond 
        = new StokesRightPreconditionerFactory(innerSolver);
        TSFLinearSolver outerSolver = new BICGSTABSolver(outerPrecond, 
        1.0e-10, 2000);
        TSFLinearSolver blockSolver = new SchurComplementSolver(innerSolver, outerSolver);
        outerSolver.setVerbosityLevel(2);
      */
      TSFPreconditionerFactory klPrecond 
        = new KayLoghinPreconditionerFactory(mesh, vel0, reynolds);

      TSFLinearSolver blockSolver = new BICGSTABSolver(klPrecond,
                                                       1.0e-10, 2000);

      StaticLinearProblem prob(mesh, eqn, bc, varBlocks, unkBlocks);

      StaticLinearProblem postProc(mesh, postprocEqn, postprocBC,
                                   delPsi, psi, petraType);
      Expr psi0;

      int i=0; 
      int maxiters = 100;
      bool converged = false;
      double picardTol = 1.0e-8;

      while (i < maxiters)
        {
          Expr soln = prob.solve(blockSolver);
          Expr step = (vel0-soln[0])*(vel0-soln[0]);
          double stepSize = step.integral(mesh);
          stepSize = sqrt(fabs(stepSize));
          vel0[0] = soln[0][0];
          vel0[1] = soln[0][1];
          Expr p0 = soln[1];
          TSFOut::println("starting postproc solve");
          psi0 = postProc.solve(innerSolver);
          FieldWriter writer = new MatlabWriter("psi." + TSF::toString(i) + ".vtk");
          writer.writeField(psi0);
          prob.flushMatrixValues();
          postProc.flushMatrixValues();
          TSFOut::printf("oseen iteration %d %g\n", i, stepSize);
          if (stepSize < picardTol) 
            {
              converged = true;
              break;
            }
          i++;
        }

      FieldWriter psiWriter = new MatlabWriter("psi.dat");
      psiWriter.writeField(psi0);

      Expr omega = dx*vel0[1] - dy*vel0[0];
      double totalVorticity = omega.integral(mesh);

      cerr << "total vorticity = " << totalVorticity << endl;
			

      double errorNorm = fabs(totalVorticity + 1.0);
      double tolerance = 1.0e-3;
      TSFOut::printf("error = %g\n", errorNorm);

      Testing::passFailCheck(__FILE__, errorNorm, tolerance);
    }
  catch(exception& e)
    {
      Sundance::handleError(e, __FILE__);
    }
  Sundance::finalize();
}
			







