#include "TSFDefs.h"
#include "Sundance.h"
#include "LAPACKGeneralMatrix.h"
#include "DenseSerialVectorType.h"
#include "TSFTimer.h"

using namespace Sundance;

using namespace TSF;

/*
 * Boundary-labeling predicates passed to Mesh::labelCells() in main().
 * Each returns true for a cell lying on the named part of the domain
 * boundary. Definitions follow main().
 */
bool onTop(const Cell& cell);
bool onLeft(const Cell& cell);
bool onRight(const Cell& cell);
bool onBottom(const Cell& cell);
bool onCyl(const Cell& cell);

int main(int argc, void** argv)
{
  try
    {
      MPIComm::init(&argc, &argv);
      Timer::start();
			
      MeshReader reader = new ShewchukMeshReader("halfCyl_911");
      Mesh mesh = reader.getMesh();

      mesh.labelCells(1, "top", onTop);
      mesh.labelCells(1, "bottom", onBottom);
      mesh.labelCells(1, "left", onLeft);
      mesh.labelCells(1, "right", onRight);
      mesh.labelCells(1, "cyl", onCyl);

      CellSet left(1, "left");
      CellSet right(1, "right");
      CellSet top(1, "top");
      CellSet bottom(1, "bottom");
      CellSet cyl(1, "cyl");

      double Re = 1000.0;

      TSFVectorSpace discreteSpace 
        = new SundanceVectorSpace(mesh, new Lagrange(1),
                                  new PetraVectorType());

      // define variations and unknowns for U, V, and P.
      /* use Taylor-Hood discretization */
      Expr delU = new TestFunction(new Lagrange(2), "delU");
      Expr U = new UnknownFunction(new Lagrange(2), "U");
      Expr delV = new TestFunction(new Lagrange(2), "delV");
      Expr V = new UnknownFunction(new Lagrange(2), "V");
      Expr delP = new TestFunction(new Lagrange(1), "delP");
      Expr P = new UnknownFunction(new Lagrange(1), "P");

      Expr delPsi = new TestFunction(new Lagrange(1), "delPsi");
      Expr psi = new UnknownFunction(new Lagrange(1), "psi");

      /* initial velocity */
      Expr u0 = new DiscreteFunction(discreteSpace, 0.0, "u0");
      Expr v0 = new DiscreteFunction(discreteSpace, 0.0, "v0");
			
      Expr delVelocity = List(delU, delV);
      Expr velocity = List(U, V);

      // create differential operators for x and y directions, and
      // then form gradient operator.
      Expr dx = new Derivative(0,1);
      Expr dy = new Derivative(1,1);
      Expr grad = List(dx, dy);
			
      // coordinate functions
      Expr x = new CoordExpr(0, "x");
      Expr y = new CoordExpr(1, "y");

      // momentum continuity 
			
      Expr momentumEqn  = delU*(u0*dx + v0*dy)*U + 1/Re*(grad*U)*(grad*delU) 
        + delV*(u0*dx + v0*dy)*V + 1.0/Re*(grad*V)*(grad*delV)
        + P*(dx*delU + dy*delV);

      // incompressibility

      Expr continuityEqn = -delP*(dx*U + dy*V);

      Expr eqn = Integral(momentumEqn) + Integral(continuityEqn) ;


      /*
       * Boundary conditions: 
       * inflow u=1.0 on left 
       * symmetry v=0.0, NBC for u, on top & bottom
       * outflow NBC on right
       * no-slip u=0, v=0 on cylinder
       */

      EssentialBC bc = EssentialBC(left, delU*(U-1))
        && EssentialBC(bottom, delV*V)
        && EssentialBC(top, delV*V)
        && EssentialBC(cyl, delV*V + delU*U);


      /*
        Create a solver object: stablized biconjugate gradient solver
      */
      TSFPreconditionerFactory prec = new ILUKPreconditionerFactory(4);
      TSFLinearSolver solver = new BICGSTABSolver(prec, 1.0e-12, 4000);

      StaticLinearProblem prob(mesh, eqn, bc, List(delU, delV, delP),
                               List(U, V, P), 
                               new PetraVectorType());

      /* extract streamfunction from velocity */
      Expr streamfunctionEqn = Integral((grad*psi)*(grad*delPsi) 
                                        + delPsi*(dx*v0 - dy*u0)) 
        + Integral(bottom, -delPsi*u0)
        + Integral(top, delPsi*u0)
        + Integral(right, -delPsi*v0);
      EssentialBC streamfunctionBC(left, delPsi*(psi-y));
      StaticLinearProblem postproc(mesh, streamfunctionEqn, streamfunctionBC,
                                   delPsi, psi, new PetraVectorType());
																	 

      int i=0; 
      int maxiters = 10;
      bool converged = false;
			
      while (i < maxiters)
        {
          Expr soln = prob.solve(solver);
          u0 = soln[0];
          v0 = soln[1];

          Expr psi0 = postproc.solve(solver);
          ofstream psiFile(string("psi." + TSF::toString(i) + ".dat").c_str());
          psi0.matlabDump(psiFile);
          prob.flushMatrixValues();
          i++;
        }

			

    }
  catch(exception& e)
    {
      TSFOut::println(e.what());
      Testing::crash(__FILE__);
      Testing::timeStamp(__FILE__, __DATE__, __TIME__);
    }
  Timer::report();
  MPIComm::finalize();
}






bool onBottom(const Cell& cell)
{
  // y == 0
  const Point& a = cell.facet(0,0).point(0);
  const Point& b = cell.facet(0,1).point(0);
		
  if (fabs(a[1]) < 1.0e-10 && fabs(b[1]) < 1.0e-10)
    {
      return true;
    }
  return false;
}

bool onTop(const Cell& cell)
{
  // y == 2
  const Point& a = cell.facet(0,0).point(0);
  const Point& b = cell.facet(0,1).point(0);
		
  if (fabs(a[1]-2.0) < 1.0e-10 && fabs(b[1]-2.0) < 1.0e-10)
    {
      return true;
    }
  return false;
}

bool onLeft(const Cell& cell)
{
  // x == -3
  const Point& a = cell.facet(0,0).point(0);
  const Point& b = cell.facet(0,1).point(0);
		
  if (fabs(a[0]+3.0) < 1.0e-10 && fabs(b[0]+3.0) < 1.0e-10)
    {
      return true;
    }
  return false;
}

bool onRight(const Cell& cell)
{
  // x == 3
  const Point& a = cell.facet(0,0).point(0);
  const Point& b = cell.facet(0,1).point(0);
		
  if (fabs(a[0]-3.0) < 1.0e-10 && fabs(b[0]-3.0) < 1.0e-10)
    {
      return true;
    }
  return false;
}


/*
 * Boundary predicate for Mesh::labelCells.
 * Returns true when both points of the cell lie on the unit circle
 * x^2 + y^2 == 1 (the cylinder surface), within an absolute radial
 * tolerance of 1e-5.
 *
 * The looser tolerance (vs. 1e-10 for the straight boundaries) presumably
 * allows for rounding of curved-boundary vertex coordinates in the mesh
 * file. NOTE(review): confirm.
 *
 * NOTE(review): assumes cell.facet(0,0) and cell.facet(0,1) are the two end
 * vertices of an edge cell, and that Point*Point is a dot product; confirm
 * against the Cell/Point API.
 */
bool onCyl(const Cell& cell)
{
  const double tol = 1.0e-5;

  const Point& a = cell.facet(0,0).point(0);
  const Point& b = cell.facet(0,1).point(0);

  const double ra = sqrt(a*a);
  const double rb = sqrt(b*b);

  // Both radii must be compared with 1. Comparing rb with 0 (as was done
  // before) only accepts edges ending at the origin, so the cylinder edges
  // were not labeled and the no-slip BC on "cyl" was never applied.
  return fabs(ra - 1.0) < tol && fabs(rb - 1.0) < tol;
}




