#include "Sundance.h"

/*
 * Stokes-flow test driver.
 *
 * Solves the steady Stokes equations on the square [-1,1]x[-1,1] with
 * Poiseuille-flow boundary data, using a Taylor-Hood discretization
 * (quadratic Lagrange velocity, linear Lagrange pressure) and a Schur
 * complement block solver. The computed u, v, and p fields are dumped to
 * u.dat, v.dat, and p.dat in Matlab format, and the discrete error against
 * the exact solution is checked against a fixed tolerance with
 * Testing::passFailCheck. Any exception is reported via
 * Sundance::handleError; Sundance::finalize() is called in all cases.
 *
 * NOTE(review): argv is declared void** rather than the standard char**.
 * This appears to be dictated by the signature of Sundance::init used
 * below — confirm against Sundance.h before changing it.
 */
int main(int argc, void** argv)
{
  try
    {
      Sundance::init(&argc, &argv);

      /* create a simple mesh on the rectangle */
      int nx = 5;
      int ny = 5;
      MeshGenerator mesher = new RectangleMesher(-1.0, 1.0, nx, -1.0, 1.0, ny);
      Mesh mesh = mesher.getMesh();

      /* define coordinate functions for x and y coordinates */
      Expr x = new CoordExpr(0);
      Expr y = new CoordExpr(1);

      /* define cell sets for each of the four sides of the rectangle */
      CellSet boundary = new BoundaryCellSet();
      CellSet left = boundary.subset( x == -1.0 );
      CellSet right = boundary.subset( x == 1.0 );
      CellSet bottom = boundary.subset( y == -1.0 );
      CellSet top = boundary.subset( y == 1.0 );

      /* define variations and unknowns for U, V, and P. */
      /* use Taylor-Hood discretization: P2 velocity, P1 pressure */
      Expr delU = new TestFunction(new Lagrange(2));
      Expr U = new UnknownFunction(new Lagrange(2));
      Expr delV = new TestFunction(new Lagrange(2));
      Expr V = new UnknownFunction(new Lagrange(2));
      Expr delP = new TestFunction(new Lagrange(1));
      Expr P = new UnknownFunction(new Lagrange(1));

      /* specify that we will use Petra vector and matrix objects */
      TSFVectorType petraType = new PetraVectorType();

      /* define block structure [ (u,v), (p) ] */
      TSFArray<Block> unkBlocks = tuple(Block(List(U, V), petraType), 
                                        Block(P, petraType));

      TSFArray<Block> varBlocks = tuple(Block(List(delU, delV), petraType), 
                                        Block(delP, petraType));

      Expr delVelocity = List(delU, delV);
      Expr velocity = List(U, V);

      /* create differential operators for x and y directions, and
       * then form gradient operator. */
      Expr dx = new Derivative(0);
      Expr dy = new Derivative(1);
      Expr grad = List(dx, dy);

      /* weak momentum equation */
      Expr momentumEqn  = -(grad*U)*(grad*delU) - (grad*V)*(grad*delV)
        + P*(dx*delU + dy*delV);

      /* weak incompressibility constraint */
      Expr continuityEqn = -delP*(dx*U + dy*V);

      /* integrate */
      Expr eqn = Integral(momentumEqn) + Integral(continuityEqn);

      /*
       * Boundary conditions for Poiseuille flow: 
       * v=0 on top, bottom, and left.
       * u=0 on top and bottom.
       * u=1/2 (1-y^2) on left.
       * Natural BCs on right.
       */

      Expr ULeft = 0.5*(1.0-y*y);

      /* The left BC constrains both components, as documented above;
       * previously only u was imposed there and v was left free. The
       * corner values are consistent: ULeft vanishes at y = +/-1,
       * matching u=0 on top and bottom. */
      EssentialBC bc = EssentialBC(top, delU*U + delV*V)
        && EssentialBC(bottom, delU*U + delV*V)
        && EssentialBC(left, delU*(U-ULeft) + delV*V);


      /* form a problem object from which we can extract operators */
      StaticLinearProblem prob(mesh, eqn, bc, varBlocks, unkBlocks);

      /* Create solver objects. We'll use a Schur complement solver
       * that takes advantage of our ability to solve the diffusion
       * operator easily. We will use preconditioned BICGSTAB for the
       * inner solve on the diffusion block and unpreconditioned
       * BICGSTAB for the outer solve. We'll ask for tighter tolerance
       * on the inner solve */
      TSFPreconditionerFactory innerPrecond = new ILUKPreconditionerFactory(2);
      TSFLinearSolver innerSolver = new BICGSTABSolver(innerPrecond, 1.0e-12, 1000);
      innerSolver.setVerbosityLevel(1);

      //			TSFPreconditionerFactory outerPrecond 
      //				= new StokesRightPreconditionerFactory(innerSolver);
      TSFLinearSolver outerSolver = new BICGSTABSolver(1.0e-10, 2000);
      outerSolver.setVerbosityLevel(2);

      TSFLinearSolver blockSolver = new SchurComplementSolver(innerSolver, outerSolver);

      /* solve the thing */
      Expr soln = prob.solve(blockSolver);

      Expr U0 = soln[0][0];
      Expr V0 = soln[0][1];
      Expr P0 = soln[1];

      /* write to matlab */
      ofstream ofu("u.dat");
      U0.matlabDump(ofu);
      ofstream ofv("v.dat");
      V0.matlabDump(ofv);
      ofstream ofp("p.dat");
      P0.matlabDump(ofp);

      /*
       * Exact solution is 
       * u = 1/2 (1-y^2)
       * v = 0
       * P = 1 - x
       * (the pressure constant is fixed by the natural BC on the right,
       * which forces P = 0 at x = 1 since du/dx = 0 there)
       */
      Expr exactVelocity = List( 0.5*(1-y*y), 0.0 );
      Expr exactPressure = 1.0 - x;
      BasisFamily L1 = new Lagrange(1);
      BasisFamily L2 = new Lagrange(2);
      TSFVectorSpace discreteSpace1 
        = new SundanceVectorSpace(mesh, L1, new PetraVectorType());
      TSFVectorSpace discreteSpace2
        = new SundanceVectorSpace(mesh, tuple(L2, L2), new PetraVectorType());

      Expr pErrorExpr = new DiscreteFunction(discreteSpace1, exactPressure-soln[1]);
      Expr velErrorExpr = DiscreteFunction::discretize(discreteSpace2, 
                                                       exactVelocity-soln[0]);

      /*
        compute the norm of the error. Both velocity components are
        included; checking only u would let an error in v pass unnoticed.
      */
      double uErrorNorm = velErrorExpr[0].norm(2);
      double vErrorNorm = velErrorExpr[1].norm(2);
      double velErrorNorm = uErrorNorm + vErrorNorm;
      double pErrorNorm = pErrorExpr.norm(2);
      double errorNorm = velErrorNorm + pErrorNorm;
      cerr << "error in velocity = " << velErrorNorm << endl;
      cerr << "error in pressure = " << pErrorNorm << endl;
      double tolerance = 1.0e-4;

      /*
        decide if the error is within tolerance
      */
      Testing::passFailCheck(__FILE__, errorNorm, tolerance);
    }
  catch(exception& e)
    {
      Sundance::handleError(e, __FILE__);
    }
  Sundance::finalize();
}







