#include "Sundance.h"

namespace Sundance
{

  namespace
  {
    /*
     * Build the symbolic vector operator (d/dx, d/dy, d/dz) as a
     * three-entry List of Derivative operators. Shared by curl() and div().
     */
    Expr cartesianNabla()
    {
      Expr dx = new Derivative(0);
      Expr dy = new Derivative(1);
      Expr dz = new Derivative(2);
      return List(dx, dy, dz);
    }
  }

  /**
   * Cross product A x B of two three-entry vector expressions.
   *
   * Each component is formed from products A[i]*B[j] (A on the left), so
   * when A is a list of Derivative operators (as in curl()) each operator
   * is presumably applied to the corresponding entry of B.
   *
   * Assumes A and B each have exactly three entries, indexed 0..2.
   *
   * (A x B)_x = A_y B_z - A_z B_y
   * (A x B)_y = A_z B_x - A_x B_z
   * (A x B)_z = A_x B_y - A_y B_x
   */
  Expr cross(const Expr& A, const Expr& B)
  {
    return List(A[1]*B[2] - A[2]*B[1],
                A[2]*B[0] - A[0]*B[2],
                A[0]*B[1] - A[1]*B[0]);
  }

  /**
   * Curl of a three-entry vector expression, formed as nabla x v with
   * nabla = (d/dx, d/dy, d/dz).
   */
  Expr curl(const Expr& v)
  {
    return cross(cartesianNabla(), v);
  }

  /**
   * Divergence of a three-entry vector expression, formed as the product
   * nabla*v with nabla = (d/dx, d/dy, d/dz). Presumably List*List is a
   * component-wise dot product here; confirm against the Expr docs.
   */
  Expr div(const Expr& v)
  {
    Expr nabla = cartesianNabla();
    return nabla*v;
  }

  /**
   * Seven-argument List(): builds a list from a..d with the four-argument
   * List() and then appends e, f and g, in that order.
   */
  Expr List(const Expr& a, const Expr& b, const Expr& c, const Expr& d,
            const Expr& e, const Expr& f, const Expr& g)
  {
    Expr rtn = List(a,b,c,d);
    rtn.append(e);
    rtn.append(f);
    rtn.append(g);
    return rtn;
  }

}

/*
 * Driver: Newton solve of a coupled incompressible MHD problem on the box
 * [0,Lx]x[0,Ly]x[0,Lz]. The top face is driven with velocity (1,0,0), and
 * the magnetic Reynolds number Rm is ramped in a continuation loop.
 *
 * argv is declared void** because it is passed as &argv to Sundance::init.
 */
int main(int argc, void** argv)
{
  try
    {
      Sundance::init(&argc, &argv);

      /* box dimensions */
      double Lx = 1.0;
      double Ly = 1.0;
      double Lz = 1.0;

      /* number of cells in each direction */
      int nx = 8;
      int ny = 8;
      int nz = 8;

      /* mesh the x-y rectangle, then extrude it in z to get the 3D box */
      MeshGenerator mesher = new RectangleMesher(0.0, Lx, nx, 0.0, Ly, ny);
      Mesh mesh2 = mesher.getMesh();

      Mesh mesh = extrudeMesh(mesh2, 0.0, Lz, nz);

      /* define coordinate functions for the x, y, and z coordinates */
      Expr x = new CoordExpr(0);
      Expr y = new CoordExpr(1);
      Expr z = new CoordExpr(2);


      /* define cell sets for each of the six sides of the box */
      CellSet boundary = new BoundaryCellSet();
      CellSet left = boundary.subset( fabs(x) < 1.0e-9 );
      CellSet right = boundary.subset( fabs(x-Lx) < 1.0e-9 );
      CellSet front = boundary.subset( fabs(y) < 1.0e-9 );
      CellSet back = boundary.subset( fabs(y-Ly) < 1.0e-9 );
      CellSet bottom = boundary.subset( fabs(z) < 1.0e-9 );
      CellSet top = boundary.subset( fabs(z-Lz) < 1.0e-9 );

      /* define test and unknown functions for the velocity components.
       * The same linear Lagrange basis is used for velocity and pressure,
       * so the continuity equation below carries a stabilization term. */
      BasisFamily velocityBasis = new Lagrange(1);
      Expr vx = new TestFunction(velocityBasis);
      Expr vy = new TestFunction(velocityBasis);
      Expr vz = new TestFunction(velocityBasis);
      Expr dux = new UnknownFunction(velocityBasis);
      Expr duy = new UnknownFunction(velocityBasis);
      Expr duz = new UnknownFunction(velocityBasis);

      /* test and unknown functions for magnetic field */
      Expr wx = new TestFunction(velocityBasis);
      Expr wy = new TestFunction(velocityBasis);
      Expr wz = new TestFunction(velocityBasis);
      Expr dBx = new UnknownFunction(velocityBasis);
      Expr dBy = new UnknownFunction(velocityBasis);
      Expr dBz = new UnknownFunction(velocityBasis);

      /* test and unknown functions for pressure */
      Expr q = new TestFunction(new Lagrange(1));
      Expr dp = new UnknownFunction(new Lagrange(1));

      /* initial guess: a single discrete field holding
       * (u_x, u_y, u_z, p, B_x, B_y, B_z), all initialized to zero */
      BasisFamily lagr = new Lagrange(1);
      TSFVectorSpace discreteSpace
        = new SundanceVectorSpace(mesh,
                                  tuple(lagr, lagr, lagr, lagr,
                                        lagr, lagr, lagr),
                                  new PetraVectorType());

      Expr guess = List(Expr(0.0), Expr(0.0), Expr(0.0),
                        Expr(0.0), Expr(0.0), Expr(0.0), Expr(0.0));
      Expr phi0 = DiscreteFunction::discretize(discreteSpace, guess);
      Expr u0 = List(phi0[0], phi0[1], phi0[2]);
      Expr p0 = phi0[3];
      Expr B0 = List(phi0[4], phi0[5], phi0[6]);


      /* vectors for velocity test and unknown */
      Expr v = List(vx, vy, vz);
      Expr du = List(dux, duy, duz);

      /* vectors for B-field test and unknown */
      Expr w = List(wx, wy, wz);
      Expr dB = List(dBx, dBy, dBz);

      /* create differential operators for the x, y, and z directions,
       * then form the gradient operator */
      Expr dx = new Derivative(0);
      Expr dy = new Derivative(1);
      Expr dz = new Derivative(2);
      Expr grad = List(dx, dy, dz);

      /* Define expressions for the magnetic Reynolds number,
       * Hartmann number, and interaction parameter. Rm is overwritten
       * by the continuation loop below. */
      Expr Rm = new ParameterExpr(0.0001);
      Expr Ha = new ParameterExpr(1.0);
      Expr N = new ParameterExpr(1.0);


      /* linearized (Newton) momentum equation. The advection term is
       * linearized as (u0.grad)(u0+du) + (du.grad)u0.
       * NOTE(review): the magnetic terms use curl(B) x u rather than the
       * usual Lorentz force curl(B) x B, and only the second term is
       * scaled by 1/Rm; confirm both against the intended nondimensional
       * model. */
      Expr momentumEqn  = -(grad*v)*(grad*(u0+du))/Ha + (p0+dp)*div(v)
        - v*(u0*grad)*(u0+du)/N - v*(du*grad)*u0/N
        + v*cross(curl(B0 + dB), u0) + v*cross(curl(B0), du)/Rm;

      /* incompressibility constraint equation, with a beta*h^2 pressure
       * stabilization term */

      Expr beta = new ParameterExpr(0.02);
      Expr h = new CellDiameterExpr();

      Expr continuityEqn = q*div(du+u0) + beta*h*h*(grad*(p0+dp))*(grad*q);

      /* Ampere's and Faraday's law for magnetic field, dropping
       * displacement current. u x B is linearized as
       * u0 x (B0+dB) + du x B0. */

      Expr magneticEqn = -curl(B0+dB)*curl(w) - div(w)*div(B0+dB)
        + Rm * cross(u0, B0+dB)*curl(w) + Rm * cross(du,B0)*curl(w);

      /* NOTE(review): confirm that GaussianQuadrature(1) is accurate
       * enough for the nonlinear product terms above. */
      Expr eqn = Integral(momentumEqn, new GaussianQuadrature(1))
        + Integral(continuityEqn, new GaussianQuadrature(1))
        + Integral(magneticEqn, new GaussianQuadrature(1));



      /* boundary conditions:
       * velocity = (1,0,0) on the top face, (0,0,0) on all other faces
       * pressure floats everywhere
       * B is (0,0,1) on all boundary surfaces
       **/
      Expr Bbc = List(Expr(0.0), Expr(0.0), Expr(1.0));
      Expr driving = List(Expr(1.0), Expr(0.0), Expr(0.0));
      CellSet everythingButTop = boundary - top;
      EssentialBC bc = EssentialBC(top, v*(u0 + du - driving)
                                   + w*(B0 + dB - Bbc))
        && EssentialBC(everythingButTop, v*(u0+du) + w*(B0 + dB - Bbc));

      /*
        Create a solver object: stabilized biconjugate gradient solver
        with an ILU(k) preconditioner
      */
      TSFPreconditionerFactory prec = new ILUKPreconditionerFactory(2);
      TSFLinearSolver solver = new BICGSTABSolver(prec, 1.0e-6, 400);
      solver.setVerbosityLevel(4);

      /* tests and unknowns in matching order: velocity, pressure, B */
      Expr tests = List(vx, vy, vz, q, wx, wy, wz);
      Expr unks = List(dux, duy, duz, dp, dBx, dBy, dBz);
      StaticLinearProblem linearizedProb(mesh, eqn, bc, tests, unks,
                                         new PetraVectorType());


      /* Continuation loop: to improve convergence, we solve a sequence of
       * problems at increasing magnetic Reynolds number, using the
       * solution to each as an initial guess to the next. */
      int nr = 2;
      for (int r=0; r<nr; r++)
        {
          /* ramp Rm linearly from 1 to 2 over the nr steps; a single step
           * (nr==1) uses Rm=1 instead of dividing by zero */
          double rm = 1.0;
          if (nr > 1)
            {
              rm += static_cast<double>(r) / static_cast<double>(nr - 1);
            }
          TSFOut::printf("doing Rm=%g\n", rm);
          Rm.setParameterValue(rm);

          /* NOTE(review): continuation assumes newton.solve() leaves its
           * result in phi0 (and hence in u0, p0, B0); confirm against the
           * NewtonLinearization documentation. */
          NewtonLinearization newton(linearizedProb, phi0, solver);
          NewtonSolver newtonSolver(solver, 150, 1.0e-4, 1.0e-4);
          Expr soln = newton.solve(newtonSolver);

          /*
            write the first solution component (ordered as in unks, so
            presumably the x-velocity) in a form readable by matlab.
            The "psi" file prefix is kept as-is for existing scripts.
          */
          char fName[100];
          snprintf(fName, sizeof(fName), "psi-%g.dat", rm);
          FieldWriter psiWriter = new MatlabWriter(fName);
          psiWriter.writeField(soln[0]);
        }
    }
  catch(const exception& e)
    {
      TSFOut::println(e.what());
      Testing::crash(__FILE__);
      Testing::timeStamp(__FILE__, __DATE__, __TIME__);
    }
  Timer::report();
  MPIComm::finalize();
}







