#include "Sundance.h"

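/*

Solve the algebraic equation u^2 = 2 using Newton's method. Each
iteration solves the linearized equation u0^2 + 2*u0*du - 2 = 0 for the
correction du, then updates u0 <- u0 + du. The result is compared
against the exact root sqrt(2).

 */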


/*
 * Test driver: runs a hand-written Newton iteration for u^2 = 2 on a
 * one-cell rectangular mesh and checks the result against sqrt(2).
 *
 * NOTE(review): argv is declared void** rather than the standard char**
 * because it is passed as &argv to PMachine::init, which presumably takes
 * a void*** -- confirm before changing the signature.
 */
int main(int argc, void** argv)
{
  try
    {
      PMachine::init(&argc, &argv);

      // we need a mesh, even for this simple problem.
      int nx = 1;
      int ny = 1;
      Mesh mesh = rectMesh(0.0, 1.0, nx, 0.0, 1.0, ny);

      // initial guess u0 = 1
      Expr u0 = new DiscreteFunction(mesh, 1.0, Lagrange(1), "u0");

      // define variation and unknown (du is the Newton correction)
      Expr varu = new TestFunction(Lagrange(1), "varu");
      Expr du = new UnknownFunction(Lagrange(1), "du");

      // Newton linearization of f(u) = u^2 - 2 about u0:
      //   f(u0) + f'(u0)*du = u0^2 + 2*u0*du - 2 = 0
      Expr linearizedEqn = varu*(u0*u0 + 2.0*u0*du - 2.0);

      // variational statement of problem
      Expr eqn = Integral(linearizedEqn);

      // solver for linear system
      LinearSolver solver = new BICGSTABSolver(1.e-14, 5000);

      // object for linearized problem
      StaticLinearProblem prob(mesh, eqn, varu, du, solver);

      // The Newton iteration is driven by hand here; a NewtonSolver built
      // from (prob, u0) is the library alternative to this loop.
      const int maxiters = 20;
      const double tol = 1.0e-12;

      bool converged = false;

      for (int i = 0; i < maxiters; i++)
        {
          Expr step;
          prob.solve(step);
          double stepNorm = step.norm();
          // NOTE(review): this rebinds the local handle u0 to a new
          // DiscreteFunction, while eqn/prob were built from the previous
          // u0. Confirm that the linear problem sees the updated iterate;
          // otherwise every iteration re-solves about the initial guess.
          u0 = new DiscreteFunction(mesh, u0+step, Lagrange(1), "u0");
          TSFOut::printf("iter %d", i);
          u0.matlabDump(cout);
          // the update is applied before the test, so on convergence u0
          // already includes the final (tiny) step
          if (stepNorm < tol) {converged = true; break;}
        }

      if (converged) TSFOut::print("converged");
      else TSFOut::print("failed to converge");

      // compare to exact solution
      Expr exactSoln = sqrt(2.0);
      Expr errorExpr = new DiscreteFunction(mesh, exactSoln-u0, 
                                            Lagrange(1), "errorExpr");

      // compute the norm of the error
      double errorNorm = errorExpr.norm();
      double tolerance = 1.0e-15;

      Testing::passFailCheck(__FILE__, errorNorm, tolerance);
      Testing::timeStamp(__FILE__, __DATE__, __TIME__);
    }
  catch(exception& e)
    {
      e.print();
      Testing::crash(__FILE__);
      Testing::timeStamp(__FILE__, __DATE__, __TIME__);
    }

  // tear down the parallel environment started by PMachine::init,
  // on both the success and the crash paths
  PMachine::finalize();
}



	

