#include "TSFDefs.h"
#include "Sundance.h"
#include "LAPACKGeneralMatrix.h"


using namespace Sundance;

using namespace TSF;


int main(int argc, void** argv)
{
  try
    {
      MPIComm::init(&argc, &argv);

      int nx = 4;
      int ny = 4;
      Mesh mesh = rectMesh(-1.0, 1.0, nx, -1.0, 1.0, ny);

      CellSet left(1, "left");
      CellSet right(1, "right");
      CellSet top(1, "top");
      CellSet bottom(1, "bottom");

      // define variations and unknowns for U, V, and P.
      /* use Taylor-Hood discretization */
      Expr delU = new TestFunction(new Lagrange(2), "delU");
      Expr U = new UnknownFunction(new Lagrange(2), "U");
      Expr delV = new TestFunction(new Lagrange(2), "delV");
      Expr V = new UnknownFunction(new Lagrange(2), "V");
      Expr delP = new TestFunction(new Lagrange(1), "delP");
      Expr P = new UnknownFunction(new Lagrange(1), "P");
			
      Expr delVelocity = List(delU, delV);
      Expr velocity = List(U, V);

      // create differential operators for x and y directions, and
      // then form gradient operator.
      Expr dx = new Derivative(0,1);
      Expr dy = new Derivative(1,1);
      Expr grad = List(dx, dy);
			
      // coordinate functions
      Expr x = new CoordExpr(0, "x");
      Expr y = new CoordExpr(1, "y");

      // momentum continuity 
			
      Expr eqn1  = (grad*U)*(grad*delU) - delU;// + (grad*V)*(grad*delV) - delU;
      Expr eqn2 = (grad*P)*(grad*delP);
			

      Expr eqn = Integral(eqn1);// + Integral(eqn2);

      /*
       * Boundary conditions: 
       * v=0 on top, bottom, and left.
       * u=0 on top and bottom.
       * u=1/2 (1-y^2) on left.
       * Natural BCs on right.
       */

      Expr ULeft = 0.5*(1.0-y*y);
			
			
      /*
        EssentialBC bc = EssentialBC(top, delU*U + delV*V)
        && EssentialBC(bottom, delU*U + delV*V)
        && EssentialBC(left, delU*(U-ULeft) + delV*V + delP*(P+1.0))
        && EssentialBC(right, delU*(U-ULeft) + delV*V + delP*(P-1.0));
      */
      EssentialBC bc = EssentialBC(top, delU*U)
        && EssentialBC(bottom, delU*U)
        && EssentialBC(left, delU*(U-ULeft))
        && EssentialBC(right, delU*(U-ULeft));
      /*
        Create a solver object: stablized biconjugate gradient solver
      */
      TSFPreconditionerFactory prec = new ILUKPreconditionerFactory(1);
      TSFLinearSolver solver = new BICGSTABSolver(prec, 1.0e-14, 300);

      /*
        StaticLinearProblem prob(mesh, eqn, bc, List(delU, delV, delP),
        List(U, V, P), new LAPACKGeneralMatrix());
      */
      StaticLinearProblem prob(mesh, eqn, bc, delU, U,
                               new LAPACKGeneralMatrix());
      Expr soln = prob.solve();

      Expr U0 = soln[0];
      //			Expr V0 = soln[1];
      //Expr P0 = soln[2];

      ofstream ofu("u.dat");

      U0.matlabDump(ofu);
      //ofstream ofv("v.dat");
      //V0.matlabDump(ofv);
      //ofstream ofp("p.dat");
      //P0.matlabDump(ofp);

      /*
       * Exact solution is 
       * u = 1/2 (1-y^2)
       * v = 0
       * P = -x
       */
      Expr exactSoln = ULeft; //List( 0.5*(1-y*y), 0.0, x );
      TSFVectorSpace discreteSpace 
        = new SundanceVectorSpace(mesh, new Lagrange(1),
                                  new PetraVectorType());

      Expr uErrorExpr = new DiscreteFunction(discreteSpace, exactSoln[0]-U0, 
                                             "uErrorExpr");
      /*
        Expr vErrorExpr = new DiscreteFunction(discreteSpace, exactSoln[1]-V0, 
        "vErrorExpr");
        Expr pErrorExpr = new DiscreteFunction(discreteSpace, exactSoln[2]-P0, 
        "pErrorExpr");
      */

      /*
        compute the norm of the error
      */
      double uErrorNorm = uErrorExpr.norm();
      //			double vErrorNorm = vErrorExpr.norm();
      //double pErrorNorm = pErrorExpr.norm();
      double errorNorm = uErrorNorm;// + vErrorNorm + pErrorNorm;
      cerr << "error in u = " << uErrorNorm << endl;
      //			cerr << "error in v = " << vErrorNorm << endl;
      //cerr << "error in p = " << pErrorNorm << endl;
      double tolerance = 1.0e-4;

      /*
        decide if the error is within tolerance
      */
      Testing::passFailCheck(__FILE__, errorNorm, tolerance);
      Testing::timeStamp(__FILE__, __DATE__, __TIME__);
    }
  catch(exception& e)
    {
      TSFOut::println(e.what());
      Testing::crash(__FILE__);
      Testing::timeStamp(__FILE__, __DATE__, __TIME__);
    }
  MPIComm::finalize();
}







