#include "Expr.h"
#include <string>
#include "Derivative.h"
#include "Lagrange.h"
#include "WeakForm.h"
#include "TestFunction.h"

#include "GaussLegendre.h"
#include "SerialMatrixBuilder.h"
#include "MatrixBase.h"
#include "SerialRowSparseMatrix.h"
#include "CellSetBase.h"

#include "SimpleMeshes.h"
#include "TSFUtils.h"
#include "BICGSTABSolver.h"
#include "DiscreteFunction.h"
#include <iostream.h>
#include <fstream.h>

bool onLeft(const Cell& cell)
{
  const Point& a = cell.facet(0,0).point(0);
  const Point& b = cell.facet(0,1).point(0);
	
  if (fabs(a[0]) < 1.0e-10 && fabs(b[0]) < 1.0e-10)
    {
      return true;
    }
  return false;
}

bool onRight(const Cell& cell)
{
  const Point& a = cell.facet(0,0).point(0);
  const Point& b = cell.facet(0,1).point(0);
	
  if (fabs(a[0]-1.0) < 1.0e-10 && fabs(b[0]-1.0)<1.0e-10)
    {
      return true;
    }
  return false;
}

bool onTop(const Cell& cell)
{
  const Point& a = cell.facet(0,0).point(0);
  const Point& b = cell.facet(0,1).point(0);
	
  if (fabs(a[1]-1.0) < 1.0e-10 && fabs(b[1]-1.0)<1.0e-10)
    {
      return true;
    }
  return false;
}

bool onBottom(const Cell& cell)
{
  const Point& a = cell.facet(0,0).point(0);
  const Point& b = cell.facet(0,1).point(0);
	
  if (fabs(a[1]) < 1.0e-10 && fabs(b[1])<1.0e-10)
    {
      return true;
    }
  return false;
}



int main()
{
  try
    {
      // create an expr
      Expr delU = TestFunction(Lagrange(2), "delU");
      Expr U(Lagrange(2), "U");
      Expr delSx = TestFunction(Lagrange(1), "delSx");
      Expr Sx(Lagrange(1), "Sx");
      Expr delSy = TestFunction(Lagrange(1), "delSy");
      Expr Sy(Lagrange(1), "Sy");
      Derivative dx(0,1);
      Derivative dy(1,1);
			
      double f = 8.0;

      Expr grad = List(dx, dy);
      Expr S = List(Sx, Sy);
      Expr delS = List(delSx, delSy);
      Expr divS = dx*Sx + dy*Sy;
      Expr divDelS = dx*delSx + dy*delSy;
			
      Expr e = divS*divDelS
        - f*(divDelS) + delS*S 
        + (grad*delU)*(grad*U)
        - delS*(grad*U) - (grad*delU)*S;

			
      CellSet left(1, "left");
      CellSet right(1, "right");
      CellSet top(1, "top");
      CellSet bottom(1, "bottom");

      int n = 100;
      double h = 1.0/((double) n);

      Expr eqn = Integral(e) + Integral(left, delU*(U)/h) 
        + Integral(right, delU*(U)/h) + Integral(top, delSy*Sy/h)
        + Integral(bottom, delSy*Sy/h);

      // create a mesh


      Mesh mesh = rectMesh(0.0, 1.0, n, 0.0, 1.0, n);
      mesh.labelCells(1, "left", onLeft);
      mesh.labelCells(1, "right", onRight);
      mesh.labelCells(1, "top", onTop);
      mesh.labelCells(1, "bottom", onBottom);
			
      BICGSTABSolver solver(1.e-10, 2000);

      MatrixBase* A = solver.matrix();

      builderPreproc.start();

      cerr << "building MatrixBuilder... " << endl;
      SerialMatrixBuilder builder(mesh, eqn, List(delSx, delSy, delU),
                                  List(Sx, Sy, U));

      builderPreproc.stop();

      cerr << "building matrix... " << endl;
      buildTime.start();
      builder.buildMatrix(*A);
      buildTime.stop();

      ofstream mf("matrix.dat");
      mf << A->nCols() << endl;
      A->print(mf);
			
      cerr << "matrix size " << A->nRows() << endl;
      cerr << "building vector..." << endl;
      Vector b;
      builder.buildVector(b);
      ofstream bf("vector.dat");
      bf << b.length() << endl;
      for (int i=0; i<b.length(); i++)
        {
          bf << i << " " << b[i] << endl;
        }

      exit(0);
      cerr << "starting solve..." << endl;
      TSFSmartPtr<DenseSerialVector> x = new DenseSerialVector(b.length());
      Stopwatch solveTime("solve time");
      solveTime.start();
      solver.solve(b, *x);
      solveTime.stop();

      int reducedShapeID;
      bool gotIt = builder.lookupReducedShapeID(Sx.funcID(), reducedShapeID);
      if (!gotIt) 
        {
          cerr << Sx.funcID() << " " << reducedShapeID << endl;
          TSFError::raise("could not find unk in table of shape functions");
        }
      Expr solnSx = DiscreteFunction(builder.shapeDOFMap(), x, 
                                     Lagrange(1), reducedShapeID, "solnSx");

      ofstream of("tmp.dat");
      solnSx.matlabDump(of);
			
    }
  catch(exception& e)
    {
      e.print();
      exit(1);
    }
				

}

