#include "Sundance.h"


/*

TwoZoneNonlinearHeat2D solves a multimaterial heat transfer problem
on a rectangular domain. The left half of the domain is filled with a
simple conducting material, where the linear heat conduction equation
\nabla^2 T = 0 applies. In the right half of the domain, heat is transported
by radiation diffusion, modeled here as \nabla^2 (T^4) = 0.
The top and bottom boundaries are insulating; the left and right boundaries are
held at fixed temperatures 0 and 1, respectively. Flux is conserved across the
interface between the two subdomains (\nabla T on the left matches
\nabla (T^4) on the right). This is a natural boundary condition of the
combined weak form and is satisfied without being imposed explicitly.

The nonlinear equation is solved with backtracking Newton's method. 
The linear and nonlinear subdomain problems
are bundled together into a single problem.


 */


bool inLeftZone(const Cell& cell);
bool inRightZone(const Cell& cell);

ListBatch solnFunc(const PointBatch& pts);

int main(int argc, void** argv)
{
  // NOTE(review): argv is declared void** (non-standard for main) because it
  // is forwarded as &argv to PMachine::init(); presumably that API expects
  // void*** -- confirm against the PMachine header before switching to char**.
  try
    {
      PMachine::init(&argc, &argv);


      // -------- create a rectangular mesh and cell sets -------------
      int nx = 20;
      int ny = 5;
      Mesh mesh = rectMesh(-1.0, 1.0, nx, -1.0, 1.0, ny);

      // label the cells in the left and right zones.
      mesh.labelCells(2, "leftZone", inLeftZone);
      mesh.labelCells(2, "rightZone", inRightZone);

      // define cell sets: the left/right boundary edges (where the fixed
      // temperatures are imposed) and the two material zones.
      CellSet leftEdge(1, "left");
      CellSet rightEdge(1, "right");
      CellSet leftZone(2, "leftZone");
      CellSet rightZone(2, "rightZone");

      // --------- construct an expression for the initial guess ---------
      // The initial guess varies linearly from T = 0 at x = -1 to T = 1 at
      // x = 1, matching the boundary temperatures tLeft and tRight.
      Expr x = new CoordExpr(0, "x");
      Expr tInit = (x+1.0)/2.0;
      double tLeft = 0.0;
      double tRight = 1.0;

      // T0 holds the current Newton iterate; TStep0 holds the trial iterate
      // examined by the line search and used by the residual problem.
      Expr T0 = new DiscreteFunction(mesh, tInit, Lagrange(1), "T0");

      Expr TStep0 = new DiscreteFunction(mesh, T0, 
                                         Lagrange(1), "TStep0");


      // ------- define the symbolic linearized problem ----------------
      // Weak form for the Newton update dT: T^4 linearized about T0 in the
      // right zone, plain linear conduction in the left zone. The essential
      // BCs force T0 + dT to equal the boundary temperatures.
      Expr varT = new TestFunction(Lagrange(1), "varT");
      Expr dT = new UnknownFunction(Lagrange(1), "dT");

      Expr dx = new Derivative(0,1);
      Expr dy = new Derivative(1,1);
      Expr grad = List(dx, dy);

      Expr newton = (grad*varT)*(grad*(pow(T0, 4) + 4.0*pow(T0, 3)*dT));
      Expr linear = (grad*varT)*(grad*(T0 + dT));

      Expr eqn = Integral(rightZone, newton) + Integral(leftZone, linear);

      EssentialBC bc = 
        EssentialBC(leftEdge, (T0+dT-tLeft)*varT) 
        && EssentialBC(rightEdge, (T0+dT-tRight)*varT) ;

      LinearSolver solver = new BICGSTABSolver(1.e-12, 5000);
      StaticLinearProblem prob(mesh, eqn, bc, varT, dT, solver);


      // ----------------- define the symbolic representation -------------
      // --------------------- of the residual function -------------------
      // The nonlinear residual evaluated at TStep0 is represented as a
      // Lagrange(1) field (the varT*residExpr term); its norm is the merit
      // function for the line search below.
      Expr residExpr = new UnknownFunction(Lagrange(1), "resid");
      Expr rightResid 
        = varT*residExpr + (grad*varT)*(grad*(pow(TStep0,4)));
      Expr leftResid 
        = varT*residExpr + (grad*varT)*(grad*TStep0);
      Expr residEqn = Integral(leftZone, leftResid) 
        + Integral(rightZone, rightResid);

      EssentialBC residBC = 
        EssentialBC(leftEdge, varT*(residExpr + (tLeft - TStep0)))
        && EssentialBC(rightEdge, varT*(residExpr + ( tRight - TStep0)));

      StaticLinearProblem residProb(mesh, residEqn, residBC, varT, 
                                    residExpr, 
                                    new BICGSTABSolver(1.e-12, 5000));


      // ------- solve for the residual given the initial guess ----------
      Expr residSoln;
      residProb.solve(residSoln);

      double prevResid = residSoln.norm();


      // ---------------- backtracking Newton algorithm -----------------
      // Terminate when the norm of the accepted update, step*||dT||, drops
      // below tol, or after maxIters Newton steps.

      const int maxIters = 100;
      const double tol = 1.0e-6;
      double tResid = 1.0;
      int iter=0;

      while (tResid > tol && iter < maxIters)
        {
          TSFOut::printf("Newton step #%d", iter);
          Expr dTSoln;
          prob.solve(dTSoln);

          // Halve the step until the residual norm decreases; give up
          // after 10 trial steps.
          double step = 1.0;
          bool decrease = false;
          int b = 0;

          while (!decrease)
            {
              TSFOut::printf("fraction of full step %g", step);
              TStep0 = new DiscreteFunction(mesh, T0+step*dTSoln, 
                                            Lagrange(1), "TStep0");
              b++;
              if (b > 10) TSFError::raise("too many backtracks");
              residProb.solve(residSoln);
              double resid = residSoln.norm();
              TSFOut::printf("resid=%g prevResid=%g", resid, prevResid);

              if (resid < prevResid) 
                {
                  decrease = true;
                  prevResid = resid;
                }
              else
                {
                  decrease = false;
                  step = 0.5*step;
                }
            }

          iter++;

          // accept the (possibly damped) trial iterate
          T0 = TStep0;

          tResid = step*dTSoln.norm();
          TSFOut::printf("norm(step*dtSoln)=%g", tResid);

        }

      // The loop can stop with iter == maxIters on the very step that
      // converged, so judge success by the convergence test itself rather
      // than by the iteration count.
      if (tResid > tol) 
        {
          TSFOut::printf("BT Newton failed to converge after %d steps", 
                         iter);
        }
      else
        {
          TSFOut::printf("BT Newton converged after %d steps", 
                         iter);
        }

      // compute the error and represent as a discrete function
      Expr exactSoln(solnFunc, Scalar(), "soln");
      Expr ex = new DiscreteFunction(mesh, exactSoln,
                                     Lagrange(1), "exact");
      Expr errorExpr = new DiscreteFunction(mesh, exactSoln-T0, 
                                            Lagrange(1), "errorExpr");

      // dump the interpolated exact solution and the computed solution
      ofstream of1("exact.dat");
      ofstream of2("approx.dat");
      ex.matlabDump(of1);
      T0.matlabDump(of2);

      // compute the norm of the error
      double errorNorm = errorExpr.norm();
      const double tolerance = 2.5e-3;

      Testing::passFailCheck(__FILE__, errorNorm, tolerance);
      Testing::timeStamp(__FILE__, __DATE__, __TIME__);
    }
  catch(exception& e)
    {
      e.print();
      Testing::crash(__FILE__);
      Testing::timeStamp(__FILE__, __DATE__, __TIME__);
    }
}


bool inLeftZone(const Cell& cell)
{
  // x < 0.0
  const Point& a = cell.facet(0,0).point(0);
  const Point& b = cell.facet(0,1).point(0);
  const Point& c = cell.facet(0,2).point(0);
		
  if (a[0] <= 0.0 && b[0] <= 0.0 && c[0] <= 0.0)
    {
      return true;
    }
  return false;
}

bool inRightZone(const Cell& cell)
{
  // The right zone is the complement of the left zone, so every cell is
  // assigned to exactly one of the two.
  if (inLeftZone(cell)) return false;
  return true;
}


/*

 Evaluates the exact solution at a batch of points.

 The problem is multimaterial, so the solution takes a different form
 in each subdomain. All points in a batch lie in a single cell, so we
 first identify the subdomain containing that cell and branch accordingly.

*/

ListBatch solnFunc(const PointBatch& pts)
{
  // Exact solution of the two-zone problem; it depends only on x.
  //
  //   left zone  (linear conduction):   T(x)   = alpha*(x + 1),  T(-1) = 0
  //   right zone (radiation diffusion): T(x)^4 = alpha*x + beta, T(1)  = 1
  //
  // Using the same slope alpha in both zones makes the interface flux
  // (dT/dx on the left, d(T^4)/dx on the right) continuous. T(1) = 1 requires
  // alpha + beta = 1. Continuity of T at x = 0 requires alpha = beta^(1/4).
  // Together these give alpha^4 + alpha - 1 = 0. alpha below is that root,
  // accurate to about 1e-9. Deriving beta from alpha keeps T(1) = 1 exact.
  const double alpha = 0.724491959;
  const double beta = 1.0 - alpha;

  const TSFArray<Point>& p = pts.physPts();
  const Cell& cell = pts.cell();
  Vector soln(p.length());

  // Every point in the batch belongs to the same cell, so the subdomain
  // test is loop-invariant.
  const bool leftZone = inLeftZone(cell);

  // loop over points in batch, evaluating the zone's solution at each point.
  for (int i=0; i<p.length(); i++)
    {
      const Point& x = p[i];
      if (leftZone)
        {
          soln[i] = alpha*(x[0]+1.0);
        }
      else
        {
          soln[i] = pow(alpha*x[0] + beta, 0.25);
        }
    }
  return ListBatch(soln);
}

