#include "TSFDefs.h"
#include "Sundance.h"
#include "DirectSolver.h"
#include "LAPACKGeneralMatrix.h"
#include "DenseSerialVectorType.h"

using namespace Sundance;

using namespace TSF;

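/**
\example heat1D.cpp
Solve a steady-state heat-conduction problem with a constant source term on
the unit interval [0, 1], with essential boundary conditions on U at both ends.
The problem is posed as a two-field system in the unknowns U and W (the W
equation sets W = f). It is assembled as a 2x2 block operator and solved
block by block.
This example illustrates the creation of a simple mesh, the definition
of a symbolic representation of a finite-element problem, and the
block-wise solution of that problem.

 */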


int main(int argc, void** argv)
{
  try
    {
      /*
        Begin by initializing the parallel environment. This step is
        necessary even in serial problems.
      */
      MPIComm::init(&argc, &argv);


      /*
        Create a mesh object. In this example, we will use a built-in method
        to create a uniform mesh on the unit line. In more realistic problems
        we would use a mesher to create a mesh, and then read the mesh using
        a MeshReader object. 
      */
      int n = 10;
      Mesh mesh = lineMesh(0.0, 1.0, n);

			
      /*
        Define a CellSet representing the set of all cells having label "left".
        Certain cells in the mesh are assumed to have been given the label
        "left". The CellSet will be used in the specification of 
        boundary conditions; it is mesh-independent, so that if the mesh were
        to be refined, the same high-level 
        boundary condition specification could be used.
      */
      CellSet left(0, "left");

      /*
        Now define a CellSet that will represent all cells labeled "right".
      */
      CellSet right(0, "right");



      /*
        Define an unknown function and its variation. The first constructor
        argument is the basis family with which the function will be
        represented, in this case second-order Lagrange (nodal) polynomials.
        The second argument is a descriptive label that can be used in 
        human-readable printing.
      */
      Expr U = new UnknownFunction(new Lagrange(1), "U");
      Expr varU = new TestFunction(new Lagrange(1), "varU");
      Expr W = new UnknownFunction(new Lagrange(1), "W");
      Expr varW = new TestFunction(new Lagrange(1), "varW");

      /* define block structure */
      TSFVectorType vecType = new DenseSerialVectorType();
      TSFArray<Block> unkBlocks = tuple(Block(U, vecType), Block(W, vecType));
      TSFArray<Block> varBlocks = tuple(Block(varU, vecType), 
                                        Block(varW, vecType));


      /*
        Define the differentiation operator of order 1 in direction 0.
      */
      Expr dx = new Derivative(0,1);
			
			

      /*
        Create a discrete function with constant value 1.0, discretized
        on our mesh object with first-order Lagrange basis. We will use
        this discrete function as the source term. (We could simply use
        a constant expression 1.0 instead, but we wanted to illustrate
        how to create a discrete function). 
      */
      TSFVectorSpace discreteSpace 
        = new SundanceVectorSpace(mesh, new Lagrange(1), vecType);
																	

      Expr f = new DiscreteFunction(discreteSpace, 1.0, "f");
			


      /*
        Having constructed the unknown, variation, source, 
        and differentiation operator,
        we can now define the variational form of the problem on the interior
        of the domain. 
      */
      Expr eqn = Integral(-(dx*U)*(dx*varU) - varU*W + varW*(W-f), new GaussLegendre(4));

      /*
        Now specify the boundary conditions on the left and right CellSets.
      */
      Expr x = new CoordExpr(0, "x");
      EssentialBC bc = 
        EssentialBC(left, varU*U) && EssentialBC(right, varU*U);


      /*
        Create a solver object: stablized biconjugate gradient solver
      */
      TSFLinearSolver solver = new DirectSolver();



      /*
        Combine the geometry, the variational form, the BCs, and the solver
        to form a complete problem.
      */
      StaticLinearProblem prob(mesh, eqn, bc, varBlocks, unkBlocks);

      TSFLinearOperator op = prob.getOperator();

      TSFLinearOperator K = op.getBlock(0,0);
      TSFLinearOperator M01 = op.getBlock(0,1);			
      TSFLinearOperator M10 = op.getBlock(1,1);

      TSFVector b = prob.getRHS().getBlock(1);
			
      TSFVector wVec = M10.inverse()*b;
      TSFVector uVec = K.inverse()*M01*wVec;

      Expr soln = prob.formSolnExpr(tuple(uVec, wVec));
      Expr uSoln = soln[0];

      /*
        write the solution in a form readable by matlab
      */
      uSoln.matlabDump(cout);
      ofstream of("T.dat");
      uSoln.matlabDump(of);

      /*
        compute the error and represent as a discrete function
      */
      Expr exactSoln = 0.5*x*(1.0-x);
      Expr errorExpr = new DiscreteFunction(discreteSpace, exactSoln-uSoln, 
                                            "errorExpr");

      /*
        compute the norm of the error
      */
      double errorNorm = errorExpr.norm();
      double tolerance = 1.0e-10;

      /*
        decide if the error is within tolerance
      */
      Testing::passFailCheck(__FILE__, errorNorm, tolerance);
      Testing::timeStamp(__FILE__, __DATE__, __TIME__);
    }
  catch(exception& e)
    {
      TSFOut::println(e.what());
      Testing::crash(__FILE__);
      Testing::timeStamp(__FILE__, __DATE__, __TIME__);
    }
  MPIComm::finalize();
}

