#include "Sundance.h"

int main(int argc, void** argv)
{
  try
    {
      Sundance::init(&argc, &argv);

      /* read mesh from Shewchuk file */
      MeshReader reader = new ShewchukMeshReader("inOut.1");
      Mesh mesh = mesher.getMesh();

      /* define coordinate functions for x and y coordinates */
      Expr x = new CoordExpr(0);
      Expr y = new CoordExpr(1);

      /* define cells sets for each of the four sides of the rectangle */
      CellSet boundary = new BoundaryCellSet();
      
      /* boundary segment definition from Paul's code */
      CellSet seg1 = boundary.subset(x == 0.0 && y >= 0.0 && y <= 2.0);
      CellSet seg2 = boundary.subset(y == 2.0 && x >= -2.0 && x <= 0.0);
      CellSet seg3 = boundary.subset(y == 0.0 && x >= 0.0 && x <= 20.0);
      CellSet seg4 = boundary.subset(x == 20.0 && y >= 0.0 && y <= 11.0);
      CellSet seg5 = boundary.subset(y == 11.0 && x >= 20.0 && x <= 22.0);

      CellSet seg6 = boundary.subset(y == 4.0 && x >= -2.0 && x <= 0.0);
      CellSet seg7 = boundary.subset(x == 0.0 && y >= 4.0 && y <= 15.0);
      CellSet seg8 = boundary.subset(y == 15.0 && x >= 0.0 && x <= 20.0);
      CellSet seg9 = boundary.subset(x == 20.0 && y <= 15.0 && y >= 13.0);
      CellSet seg10 = boundary.subset(y == 13.0 && x >= 20.0 && x <= 22.0);

      CellSet wall = seg1 + seg2 + seg3 + seg4 + seg5 + seg6 
        + seg7 + seg8 + seg9 + seg10;

      CellSet inlet  = boundary.subset(x == -2.0 && y >= 2.0 && y <= 4.0);
      CellSet outlet = boundary.subset(x == 22.0 && y >= 11.0 && y <= 13.0);


      /* define variations and unknowns for U, V, and P
       * using the same basis functions for velocity and pressure */
      Expr delU = new TestFunction(new Lagrange(1), "delU");
      Expr U = new UnknownFunction(new Lagrange(1), "U");
      Expr delV = new TestFunction(new Lagrange(1), "delV");
      Expr V = new UnknownFunction(new Lagrange(1), "V");
      Expr delP = new TestFunction(new Lagrange(1), "delP");
      Expr P = new UnknownFunction(new Lagrange(1), "P");

      Expr delVelocity = List(delU, delV);
      Expr velocity = List(U, V);

      /* create differential operators for x and y directions, and
       * then form gradient operator. */
      Expr dx = new Derivative(0);
      Expr dy = new Derivative(1);
      Expr grad = List(dx, dy);
			
      /* momentum continuity equation */
			
      Expr momentumEqn  = -(grad*U)*(grad*delU) - (grad*V)*(grad*delV)
        + P*(dx*delU + dy*delV);

      /* incompressibility constraint equation with stabilization */

      Expr beta = new ParameterExpr(0.02);
      Expr h = new CellDiameterExpr();

      Expr continuityEqn = -delP*(dx*U + dy*V) 
        - beta*h*h*(grad*P)*(grad*delP);

      Expr eqn = Integral(momentumEqn, new GaussianQuadrature(1)) 
        + Integral(continuityEqn, new GaussianQuadrature(1));

      /* boundary conditions:
       * No-slip on all walls.
       * Natural on outlet.
       * Poissiuelle on inlet.
       */
      
      Expr zero = 0.0;
      Expr yLow = 2.0;
      Expr yHigh = 4.0;
      Expr u0 = 1.0;

      Expr diff2 = (yHigh - yLow)*(yHigh - yLow); 
      Expr inletBC = List( -4.0*u0*(y - yLow)*(y - yHigh)/diff2, zero);
      
      EssentialBC bc = EssentialBC(wall, delU*U + delV*V)
        && EssentialBC(inlet, delU*(U-inletBC[0]) + delV*(V-inletBC[1]));

      /*
        Create a solver object: stablized biconjugate gradient solver
      */
      TSFPreconditionerFactory prec = new ILUKPreconditionerFactory(2);
      TSFLinearSolver solver = new BICGSTABSolver(prec, 1.0e-12, 400);

      StaticLinearProblem prob(mesh, eqn, bc, List(delU, delV, delP),
                               List(U, V, P), new PetraVectorType());

      Expr soln = prob.solve(solver);
      Expr u0 = soln[0];
      Expr v0 = soln[1];

      /* We've now solved the problem for the primitive variables.
       * For visualization we next compute the streamfunction */
      Expr delPsi = new TestFunction(new Lagrange(1));
      Expr psi = new UnknownFunction(new Lagrange(1));

      Expr vorticity = dx*v0 - dy*u0;
      Expr streamfunctionEqn = Integral(-(grad*delPsi)*(grad*psi) - delPsi*vorticity,
                                        new GaussianQuadrature(1));

      /* streamfunction is zero along entire boundary */
      EssentialBC streamfunctionBC(boundary, delPsi*psi);
			
      StaticLinearProblem streamfunctionProb(mesh, streamfunctionEqn, 
                                             streamfunctionBC,
                                             delPsi, psi, 
                                             new PetraVectorType());

      Expr psi0 = streamfunctionProb.solve(solver);

      /*
        write the streamfunction in a form readable by matlab
      */
      FieldWriter psiWriter = new MatlabWriter("psi.dat");
      psiWriter.writeField(psi0);
      FieldWriter vWriter = new MatlabVectorWriter("v.dat");
      vWriter.writeField("velocity", List(u0, v0));
			
    }
  catch(exception& e)
    {
      Sundance::handleError(e, __FILE__);
    }
  Sundance::finalize();

}





