#include "Sundance.h"


/** \example optPoisson1D.cpp
 * Solves the following PDE-constrained optimization problem:
 * minimize 
 * 1/2 Integral[ (uHat - u)^2 ] + R/2 Sum[alpha^2]
 * subject to 
 * Laplacian[ u ] = Sum[ alpha_n sin(n x) ]
 * and BCs u(0)=u(pi)=0.
 * 
 * Take uHat=sin(x) for the target function. 
 * The domain is the 1D line [0, pi]. The solution is u=u0 sin(x), where
 * u0 = pi/2 / (R + pi/2). 
 */


int main(int argc, void** argv)
{
  try
    {
      Sundance::init(&argc, &argv);

      /*
        Create a mesh object. In this example, we will use a built-in method
        to create a uniform mesh on the unit line. In more realistic problems
        we would use a mesher to create a mesh, and then read the mesh using
        a MeshReader object. 
      */
      int n = 20;
      const double pi = 4.0*atan(1.0);
      MeshGenerator mesher = new LineMesher(0.0, pi, n);
      Mesh mesh = mesher.getMesh().getSubmesh();

      /* Define a symbolic object to represent the x coordinate function. */
      Expr x = new CoordExpr(0);

      Expr psi = List(sin(x), sin(2.0*x), sin(3.0*x));

      Expr target = sin(x);

      /*
       * Define a cell set that contains all boundary cells 
       */
      CellSet boundary = new BoundaryCellSet();
      /*
       *	Define a cell set that includes all cells at position x=0.
       */
      CellSet left = boundary.subset( fabs(x - 0.0) < 1.0e-10 );

      /*
       *	Define a cell set that includes all cells at position x=1.
       */
      CellSet right = boundary.subset( fabs(x - pi) < 1.0e-10 );



      /*
        Define an unknown function and its variation. The constructor
        argument is the basis family with which the function will be
        represented, in this case second-order Lagrange (nodal) polynomials.
      */
      Expr u = new UnknownFunction(new Lagrange(2));
      Expr v = u.variation();

      Expr lambda = new UnknownFunction(new Lagrange(2));
      Expr mu = lambda.variation();

      Expr alpha1 = new UnknownParameter();
      Expr alpha2 = new UnknownParameter();
      Expr alpha3 = new UnknownParameter();
      Expr alpha = List(alpha1, alpha2, alpha3);
      Expr beta = alpha.variation();
			
      TSFVectorType petra = new PetraVectorType();
      TSFVectorType dense = new DenseSerialVectorType();

      TSFArray<Block> unks = tuple(Block(alpha, dense), Block(u, lambda, petra));
			
      TSFArray<Block> vars = tuple(Block(beta, dense), Block(mu, v, petra));

      Expr forcing = alpha * psi;



      /*
        Define the differentiation operator of order 1 in direction 0.
      */
      Expr dx = new Derivative(0);

      /* write the objective function in symbolic form */
      double R = 0.1;

      Expr objectiveFunction = 0.5*Integral(pow(u-target, 2.0))
        + 0.5*R*alpha*alpha;

      /* Write the Lagrangian. The constraint is integrated by parts. 
       * The Lagrangian also contains an essential BC component. */
      Expr lagrangian = objectiveFunction - Integral((dx*u)*(dx*lambda))
        - Integral(lambda*forcing);

      EssentialBC lagrBC = EssentialBC(left, u*lambda) && EssentialBC(right, u*lambda);

      /* Take variations of the Lagrangian and Lagrangian BC to obtain first-order
       * necessary conditions. */
      Expr eqn = lagrangian.variation(List(u, lambda, alpha));
		
      EssentialBC bc = lagrBC.variation(List(u, lambda, alpha));

      /*
        Create a solver object: stablized biconjugate gradient solver
      */
      TSFPreconditionerFactory prec = new ILUKPreconditionerFactory(1);
      TSFLinearSolver innerSolver = new BICGSTABSolver(1.0e-14, 1000);
      TSFLinearSolver outerSolver = new BICGSTABSolver(1.0e-10, 1000);

			
      TSFLinearSolver solver = new SchurComplementSolver(innerSolver, outerSolver);
			

      /*
        Combine the geometry, the variational form, the BCs, and the solver
        to form a complete problem.
      */
      StaticLinearProblem prob(mesh, eqn, bc, vars, unks);
      prob.printRowMaps();
      mesh.printCells();

      /*
        solve the problem, obtaining the solution as a (discrete) Expr object
      */
      Expr soln = prob.solve(solver);

      /*
        write the solution in a form readable by matlab
      */
      FieldWriter writer = new MatlabWriter();
      cerr << "u" << endl;
      writer.writeField(soln[1][0]);
      cerr << "lambda" << endl;
      writer.writeField(soln[1][1]);

      cerr << soln[0] << endl;

      /*
        compute the error and represent as a discrete function
      */

      double u0 = pi/2.0/(pi/2.0 + R);
      Expr exactSoln = u0*sin(x);

      /*
        compute the norm of the error
      */
      double errorNorm = (soln[1][0] - exactSoln).norm(2);
      double tolerance = 1.0e-4;

      /*
        decide if the error is within tolerance
      */
      Testing::passFailCheck(__FILE__, errorNorm, tolerance);
      Testing::timeStamp(__FILE__, __DATE__, __TIME__);

    }
  catch(exception& e)
    {
      TSFOut::println(e.what());
      Testing::crash(__FILE__);
      Testing::timeStamp(__FILE__, __DATE__, __TIME__);
    }
  Sundance::finalize();
}

