#include "Sundance.h"

/**
 *
 */

int main(int argc, void** argv)
{
  try
    {
			
      Sundance::init(&argc, &argv);

      /* create a simple mesh on the rectangle */
      int nx = 64;
      int ny = 64;
      MeshGenerator mesher = new RectangleMesher(0.0, 1.0, nx, 0.0, 1.0, ny);
      Mesh mesh = mesher.getMesh();

      /* define coordinate functions for x and y coordinates */
      Expr x = new CoordExpr(0);
      Expr y = new CoordExpr(1);

      /* define cells sets for each of the four sides of the rectangle */
      CellSet boundary = new BoundaryCellSet();
      CellSet left = boundary.subset( x == 0.0 );
      CellSet right = boundary.subset( x == 1.0 );
      CellSet bottom = boundary.subset( y == 0.0 );
      CellSet top = boundary.subset( y == 1.0 );

			

      /* define unknowns for vorticity and stream function */
      Expr delOmega = new TestFunction(new Lagrange(1));
      Expr omega = new UnknownFunction(new Lagrange(1));
      Expr delPsi = new TestFunction(new Lagrange(1));
      Expr psi = new UnknownFunction(new Lagrange(1));

      // create differential operators for x and y directions, and
      // then form gradient operator.
      Expr dx = new Derivative(0);
      Expr dy = new Derivative(1);
      Expr grad = List(dx, dy);

      Expr vorticityEqn = -(grad*delOmega)*(grad*omega);
      Expr streamfunctionEqn = -(grad*delPsi)*(grad*psi) - delPsi*omega;

      Expr eqn = Integral(vorticityEqn) + Integral(streamfunctionEqn)
        + Integral(top, -delPsi);

      /*
       * Boundary conditions: 
       * psi=0 on all surfaces
       * d(psi)/dn = 0 on left, right, bottom
       * d(psi)/dn = 1 on top
       */

      EssentialBC bc = EssentialBC(top, delOmega*psi)
        && EssentialBC(bottom, delOmega*psi)
        && EssentialBC(left, delOmega*psi)
        && EssentialBC(right, delOmega*psi);


      /*
        Create a solver object: stablized biconjugate gradient solver
      */
      TSFPreconditionerFactory prec = new ILUKPreconditionerFactory(2);
      TSFLinearSolver solver = new BICGSTABSolver(prec, 1.0e-12, 5000);

      StaticLinearProblem prob(mesh, eqn, bc, List(delPsi, delOmega),
                               List(psi, omega), 
                               new PetraVectorType());

      Expr soln = prob.solve(solver);

      Expr omega0 = soln[1];
      Expr psi0 = soln[0];

      FieldWriter psiWriter = new MatlabWriter("psi.dat");
      psiWriter.writeField(psi0);

      double totalVorticity = omega0.integral(mesh);
      cerr << "total vorticity = " << totalVorticity << endl;

      double errorNorm = fabs(totalVorticity + 1.0);
      double tolerance = 1.0e-8;
      TSFOut::printf("error = %g\n", errorNorm);

      Testing::passFailCheck(__FILE__, errorNorm, tolerance);
    }
  catch(exception& e)
    {
      TSFOut::println(e.what());
      Testing::crash(__FILE__);
      Testing::timeStamp(__FILE__, __DATE__, __TIME__);
    }
  MPIComm::finalize();
}







