#include "Sundance.h"
#include <stdio.h>

using namespace Sundance;

using namespace TSF;

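/*

Solve the algebraic equation u^2 = 2.0 using Newton's method. The equation
is posed as a variational problem on a one-element line mesh, and the
converged solution is checked against the exact value sqrt(2).

 */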


int main(int argc, void** argv)
{
  try
    {
      MPIComm::init(&argc, &argv);

      // we need a mesh, even for this simple problem.
      int nx = 1;
      Mesh mesh = lineMesh(0.0, 1.0, nx);


      TSFVectorSpace discreteSpace 
        = new SundanceVectorSpace(mesh, new Lagrange(1),
                                  new DenseSerialVectorType());

      // initial guess
      Expr u0 = new DiscreteFunction(discreteSpace, 1.0, "u0");
			
      // define variation and unknown
      Expr varu = new TestFunction(new Lagrange(1), "varu");
      Expr du = new UnknownFunction(new Lagrange(1), "du");
			
      // linearized form of u^2 - 2.0 == 0.0
      Expr linearizedEqn = varu*(u0*u0 + 2.0*u0*du - 2.0);

      // variational statement of problem
      Expr eqn = Integral(linearizedEqn);
			
      // object for linearized problem
      StaticLinearProblem prob(mesh, eqn, varu, du, 
                               new DenseSerialVectorType());

      /*
        Create a solver object: stablized biconjugate gradient solver
      */
      TSFLinearSolver solver = new DirectSolver();

      int maxIters = 100;
      double tol = 1.0e-10;
      double resid = 1;
      int iter=0;

      while (resid > tol && iter < maxIters)
        {	
          cerr << "Newton step# " << iter;
          Expr duSoln = prob.solve(solver);
          cerr << "matrix: " << endl << prob.getOperator() << endl;
          cerr << "rhs: " << endl << prob.getRHS() << endl;
          resid = duSoln.norm();
          cerr << " resid=" << resid << endl;
          iter++;
          u0 = new DiscreteFunction(discreteSpace, u0+duSoln, "u0");
          prob.flushMatrixValues();
        }
			
      // compare to exact solution
      Expr exactSoln = sqrt(2.0);
      Expr errorExpr = new DiscreteFunction(discreteSpace, exactSoln-u0, 
                                            "errorExpr");
			
      // compute the norm of the error
      double errorNorm = errorExpr.norm();
      double tolerance = 1.0e-15;

      Testing::passFailCheck(__FILE__, errorNorm, tolerance);
      Testing::timeStamp(__FILE__, __DATE__, __TIME__);
    }
  catch(exception& e)
    {
      cerr << e.what() << endl;
      Testing::crash(__FILE__);
      Testing::timeStamp(__FILE__, __DATE__, __TIME__);
    }
			
				

}



	

