#ifdef INCLUDE_MALLOC_H
#include <malloc.h>
#endif
#include <float.h>
#include <math.h>
#include "geppetto.h"
#include "proto.h"

/*
 * Fallback for math headers that do not define HUGE_VAL (presumably
 * older ones that only provide HUGE -- confirm on the target platform).
 * HUGE_VAL is used below as the "worst possible" standardized fitness.
 */
#ifndef HUGE_VAL
#define HUGE_VAL	HUGE
#endif /* !HUGE_VAL */

/*
 * Forward declarations for the file-local callbacks.  appInitialize()
 * hands the terminal list, function list, case-initialize, case-fitness
 * and terminate-run callbacks to the population; the op* evaluators and
 * print* formatters are registered through symregFunctions().
 */
static objectList *symregTerminals P((NOARGS));
static result *opAdd P((const result **, void *));
static result *opSubtract P((const result **, void *));
static result *opMultiply P((const result **, void *));
static result *opDivide P((const result **, void *));
static result *opSine P((const result **, void *));
static result *opCosine P((const result **, void *));
static result *opExponential P((const result **, void *));
static result *opLogarithm P((const result **, void *));
static int printCInfix P((const operator *, charString *));
static int printCFunction P((const operator *, charString *));
static objectList *symregFunctions P((NOARGS));
static void *symregCaseInitialize P((int, int));
static void symregCaseFitness P((result *, int, int *, double *, double *,
				 void *));
static int symregTerminateRun P((int, int, double, double));

/*
 * Number of (x, answer) samples each program is scored against.  The
 * macro sizes the fitnessCase[] table so the two can never disagree.
 */
#define NUM_FITNESS_CASES	20

const int fitnessCases = NUM_FITNESS_CASES;

/*
 * Application-specific error codes, numbered after the library's
 * ErrorUserDefined base.  Fully parenthesized so they expand correctly
 * inside larger expressions.  The Msg* strings are registered (by
 * address) with errorCodeSetMsgPtr() in appInitialize().
 */
#define ErrorDivideByZero	(ErrorUserDefined+0)
const char *MsgDivideByZero =	"Divide by Zero";
#define ErrorMathFunction	(ErrorUserDefined+1)
const char *MsgMathFunction =	"Math Function";

/* Sampled inputs and the target polynomial's value at each one. */
static struct srvalues {
  float x;
  float answer;
} fitnessCase[NUM_FITNESS_CASES];

/*
 * Current value of the terminal "X": bound to the variable created in
 * symregTerminals() and reloaded per case by symregCaseInitialize().
 */
static float x;

/*
 * Set to 1 by matherr(); operators that call libm clear it before the
 * call and test it afterwards.
 */
static int mathError;

/*
 * symregTerminals -- build the terminal set: a single float variable "X"
 * bound to the file-level x.
 *
 * Returns the new list, or 0 if any allocation or insertion fails (in
 * which case nothing is leaked).
 */
static objectList *
symregTerminals()
{
  objectList *list;
  variable *vp;

  list = objectListCreate(1);
  if (!list)
    return(0);

  vp = variableCreate(dtFloat, "X", &x);
  if (!vp) {
    /* don't hand a null variable to objectListAdd() */
    objectListFree(list);
    return(0);
  }

  if (objectListAdd(list, vp)) {
    variableFree(vp);
    objectListFree(list);
    return(0);
  }

  return(list);
}

/*
 * matherr -- SVID-style math library error hook.
 *
 * Records that a libm call failed by raising mathError, which the
 * libm-based operators check after each call.  The exception details are
 * ignored.  Returning non-zero tells the math library the error has been
 * handled.
 *
 * NOTE(review): glibc 2.27+ no longer invokes matherr() (and may not
 * declare struct exception); verify this hook is live on the target
 * platform.
 */
int
matherr(exc)
struct exception *exc;
{
  (void) exc;			/* details unused */

  mathError = 1;
  return(1);
}

/*
 * opAdd -- evaluate the "+" primitive: argv[0] + argv[1].
 *
 * Returns a dtFloat result, or a dtError result carrying
 * ErrorMathFunction when the sum is not a finite float.  Plain arithmetic
 * never goes through libm, so matherr() cannot report the overflow here;
 * the range test below is what catches it.
 */
static result *
opAdd(argv, envp)
const result **argv;
void *envp;
{
  float fval;

  fval = resultFloat(argv[0]) + resultFloat(argv[1]);

  /* False for +/-Inf, and for NaN because NaN compares false. */
  if (!(fval >= -FLT_MAX && fval <= FLT_MAX))
    return(resultCreate(dtError, ErrorMathFunction));

  return(resultCreate(dtFloat, fval));
}

/*
 * opSubtract -- evaluate the "-" primitive: argv[0] - argv[1].
 *
 * Returns a dtFloat result, or a dtError result carrying
 * ErrorMathFunction when the difference is not a finite float (overflow,
 * or Inf - Inf = NaN).  matherr() is never consulted for plain
 * arithmetic, so the result itself is range-checked.
 */
static result *
opSubtract(argv, envp)
const result **argv;
void *envp;
{
  float fval;

  fval = resultFloat(argv[0]) - resultFloat(argv[1]);

  /* False for +/-Inf, and for NaN because NaN compares false. */
  if (!(fval >= -FLT_MAX && fval <= FLT_MAX))
    return(resultCreate(dtError, ErrorMathFunction));

  return(resultCreate(dtFloat, fval));
}

/*
 * opMultiply -- evaluate the "*" primitive: argv[0] * argv[1].
 *
 * Returns a dtFloat result, or a dtError result carrying
 * ErrorMathFunction when the product is not a finite float (overflow or
 * NaN).  matherr() is never consulted for plain arithmetic, so the result
 * itself is range-checked.
 */
static result *
opMultiply(argv, envp)
const result **argv;
void *envp;
{
  float fval;

  fval = resultFloat(argv[0]) * resultFloat(argv[1]);

  /* False for +/-Inf, and for NaN because NaN compares false. */
  if (!(fval >= -FLT_MAX && fval <= FLT_MAX))
    return(resultCreate(dtError, ErrorMathFunction));

  return(resultCreate(dtFloat, fval));
}

/*
 * opDivide -- evaluate the "/" primitive: argv[0] / argv[1].
 *
 * Returns a dtError result carrying ErrorDivideByZero when the divisor is
 * zero.  Returns ErrorMathFunction when the quotient is not a finite float
 * (e.g. dividing by a tiny value overflows).  Otherwise returns a dtFloat
 * result.
 */
static result *
opDivide(argv, envp)
const result **argv;
void *envp;
{
  float fval;

  if (resultFloat(argv[1]) == 0.0)
    return(resultCreate(dtError, ErrorDivideByZero));

  fval = resultFloat(argv[0]) / resultFloat(argv[1]);

  /* matherr() never sees plain division; check the result directly. */
  if (!(fval >= -FLT_MAX && fval <= FLT_MAX))
    return(resultCreate(dtError, ErrorMathFunction));

  return(resultCreate(dtFloat, fval));
}

/*
 * opSine -- evaluate the "sin" primitive on argv[0].
 *
 * Returns a dtFloat result, or a dtError result carrying
 * ErrorMathFunction if libm reported an error through matherr() or the
 * result is not finite (sin of NaN/Inf yields NaN, which matherr() may
 * never be told about on modern libm).
 */
static result *
opSine(argv, envp)
const result **argv;
void *envp;
{
  float fval;

  mathError = 0;
  fval = sin(resultFloat(argv[0]));
  if (mathError || !(fval >= -FLT_MAX && fval <= FLT_MAX))
    return(resultCreate(dtError, ErrorMathFunction));

  return(resultCreate(dtFloat, fval));
}

/*
 * opCosine -- evaluate the "cos" primitive on argv[0].
 *
 * Returns a dtFloat result, or a dtError result carrying
 * ErrorMathFunction if libm reported an error through matherr() or the
 * result is not finite (cos of NaN/Inf yields NaN, which matherr() may
 * never be told about on modern libm).
 */
static result *
opCosine(argv, envp)
const result **argv;
void *envp;
{
  float fval;

  mathError = 0;
  fval = cos(resultFloat(argv[0]));
  if (mathError || !(fval >= -FLT_MAX && fval <= FLT_MAX))
    return(resultCreate(dtError, ErrorMathFunction));

  return(resultCreate(dtFloat, fval));
}

/*
 * opExponential -- evaluate the "exp" primitive on argv[0].
 *
 * Inputs outside [-88, 88] are rejected up front with ErrorMathFunction.
 * The result is also rejected if libm flags an error through matherr()
 * or is not a finite float.  That covers a NaN input, which slips past
 * the range test because NaN comparisons are false.
 */
static result *
opExponential(argv, envp)
const result **argv;
void *envp;
{
  float fval;

  if (resultFloat(argv[0]) < -88.0 || resultFloat(argv[0]) > 88.0)
    return(resultCreate(dtError, ErrorMathFunction));

  mathError = 0;
  fval = exp(resultFloat(argv[0]));
  if (mathError || !(fval >= -FLT_MAX && fval <= FLT_MAX))
    return(resultCreate(dtError, ErrorMathFunction));

  return(resultCreate(dtFloat, fval));
}

/*
 * opLogarithm -- evaluate the "log" (natural log) primitive on argv[0].
 *
 * Non-positive inputs are rejected with ErrorMathFunction.  The result is
 * also rejected if libm flags an error through matherr() or is not a
 * finite float.  That covers log(Inf) and a NaN input, which passes the
 * "<= 0" test because NaN comparisons are false.
 */
static result *
opLogarithm(argv, envp)
const result **argv;
void *envp;
{
  float fval;

  if (resultFloat(argv[0]) <= 0.0)
    return(resultCreate(dtError, ErrorMathFunction));

  mathError = 0;
  fval = log(resultFloat(argv[0]));
  if (mathError || !(fval >= -FLT_MAX && fval <= FLT_MAX))
    return(resultCreate(dtError, ErrorMathFunction));

  return(resultCreate(dtFloat, fval));
}

/*
 * printCInfix -- append an arithmetic operator to cstr in C infix form,
 * i.e. "arg0 <name> arg1".
 *
 * Operator arguments are parenthesized when needed to keep C's
 * precedence and associativity from changing the meaning:
 *   - an argument whose precedence class (+- vs. * /) differs from the
 *     parent's is always wrapped;
 *   - the right-hand operand of '-' or '/' is also wrapped when it is in
 *     the same class, since a - (b + c) != a - b + c and
 *     a / (b * c) != a / b * c.
 *
 * Returns 0 on success, or the first non-zero status from
 * charStringCatenate()/objectToString().
 */
static int
printCInfix(op, cstr)
const operator *op;
charString *cstr;
{
  int rval = 0;
  int i;
  int addParens;
  const object *arg;
  const char *opName, *argName;

  opName = operatorName(op);

  for (i = 0; !rval && i < operatorNumArgs(op); i++) {
    if (i > 0) {
      rval = charStringCatenate(cstr, " ");
      if (!rval)
	rval = charStringCatenate(cstr, operatorSrcName(op->src));
      if (!rval)
	rval = charStringCatenate(cstr, " ");
    }
    if (!rval) {
      arg = operatorArg(op, i);
      addParens = 0;
      if (objectIsOperator(arg)) {
	argName = operatorName((const operator *)arg);
	if (*argName == '+' || *argName == '-') {
	  if (*opName != '+' && *opName != '-')
	    addParens = 1;
	  else if (i > 0 && *opName == '-')
	    addParens = 1;	/* a - (b +- c) */
	} else if (*argName == '*' || *argName == '/') {
	  if (*opName != '*' && *opName != '/')
	    addParens = 1;
	  else if (i > 0 && *opName == '/')
	    addParens = 1;	/* a / (b * / c) */
	}
      }
      if (addParens)
	rval = charStringCatenate(cstr, "(");
      if (!rval)
	rval = objectToString(arg, cstr);
      if (!rval && addParens)
	rval = charStringCatenate(cstr, ")");
    }
  }
  return(rval);
}

/*
 * printCFunction -- append an operator to cstr as a C function call,
 * "name(arg0, arg1, ...)".
 *
 * Stops at the first failing append and returns its non-zero status;
 * returns 0 when the whole call text was appended.
 */
static int
printCFunction(op, cstr)
const operator *op;
charString *cstr;
{
  int status;
  int argNum;

  status = charStringCatenate(cstr, operatorSrcName(op->src));
  if (status)
    return(status);

  status = charStringCatenate(cstr, "(");
  if (status)
    return(status);

  for (argNum = 0; argNum < operatorNumArgs(op); argNum++) {
    if (argNum > 0) {
      status = charStringCatenate(cstr, ", ");
      if (status)
	return(status);
    }

    status = objectToString(operatorArg(op, argNum), cstr);
    if (status)
      return(status);
  }

  return(charStringCatenate(cstr, ")"));
}

static objectList *
symregFunctions()
{
  objectList *list;
  operatorSrc *osp;

  list = objectListCreate(8);
  if (list) {

    osp = simpleOperatorSrcCreate("+", opAdd, 2);
    operatorSrcSetPrintProc(osp, printCInfix);
    if (objectListAdd(list, osp)) {
      operatorSrcFree(osp);
      objectListFree(list);
      return(0);
    }

    osp = simpleOperatorSrcCreate("-", opSubtract, 2);
    operatorSrcSetPrintProc(osp, printCInfix);
    if (objectListAdd(list, osp)) {
      operatorSrcFree(osp);
      objectListFree(list);
      return(0);
    }

    osp = simpleOperatorSrcCreate("*", opMultiply, 2);
    operatorSrcSetPrintProc(osp, printCInfix);
    if (objectListAdd(list, osp)) {
      operatorSrcFree(osp);
      objectListFree(list);
      return(0);
    }

    osp = simpleOperatorSrcCreate("/", opDivide, 2);
    operatorSrcSetPrintProc(osp, printCInfix);
    if (objectListAdd(list, osp)) {
      operatorSrcFree(osp);
      objectListFree(list);
      return(0);
    }

    osp = simpleOperatorSrcCreate("sin", opSine, 1);
    operatorSrcSetPrintProc(osp, printCFunction);
    if (objectListAdd(list, osp)) {
      operatorSrcFree(osp);
      objectListFree(list);
      return(0);
    }

    osp = simpleOperatorSrcCreate("cos", opCosine, 1);
    operatorSrcSetPrintProc(osp, printCFunction);
    if (objectListAdd(list, osp)) {
      operatorSrcFree(osp);
      objectListFree(list);
      return(0);
    }

    osp = simpleOperatorSrcCreate("exp", opExponential, 1);
    operatorSrcSetPrintProc(osp, printCFunction);
    if (objectListAdd(list, osp)) {
      operatorSrcFree(osp);
      objectListFree(list);
      return(0);
    }

    osp = simpleOperatorSrcCreate("log", opLogarithm, 1);
    operatorSrcSetPrintProc(osp, printCFunction);
    if (objectListAdd(list, osp)) {
      operatorSrcFree(osp);
      objectListFree(list);
      return(0);
    }
  }

  return(list);
}

static void *
symregCaseInitialize(popNum, fc)
int popNum;
int fc;
{
  x = fitnessCase[fc].x;
  return(0);
}

/*
 * symregCaseFitness -- score one program result against fitness case fc.
 *
 * An error result, or a result that is not a finite float (NaN/Inf), sets
 * *stdp to HUGE_VAL.  Without this, a NaN result would turn the
 * standardized fitness into NaN permanently.  Otherwise:
 *   - *hitp is incremented when the result is within +/- 0.01 of the
 *     expected answer;
 *   - the absolute error is added to *stdp, unless *stdp already holds
 *     HUGE_VAL from an earlier case.
 * rawp and envp are unused.
 */
static void
symregCaseFitness(rp, fc, hitp, rawp, stdp, envp)
result *rp;
int fc;
int *hitp;
double *rawp;
double *stdp;
void *envp;
{
  double value;
  float fit = 0.0;	/* initialized: also read by the debug print */

  if (resultIsError(rp)) {
    *stdp = HUGE_VAL;
    return;
  }

  value = resultFloat(rp);

  /* False for +/-Inf, and for NaN because NaN compares false. */
  if (!(value >= -FLT_MAX && value <= FLT_MAX)) {
    *stdp = HUGE_VAL;
    return;
  }

  if ((value > fitnessCase[fc].answer - .01) &&
      (value < fitnessCase[fc].answer + .01))
    *hitp += 1;

  if (*stdp != HUGE_VAL) {
    fit = value - fitnessCase[fc].answer;
    if (fit < 0)
      fit = -fit;
    *stdp += fit;
  }
#ifdef DEBUG_SYMREGRESS
  printf("Case %d, X=%f: Computed=%f, Actual=%f\n", fc, x, resultFloat(rp),
	 fitnessCase[fc].answer);
  printf("     Absolute Difference=%f, Std Fitness now %f\n", fit, *stdp);
#endif /* DEBUG_SYMREGRESS */
}

/*
 * symregTerminateRun -- end the run once a program scores a hit on every
 * fitness case.  Returns 1 to stop, 0 to continue.  popNum, raw and std
 * are not consulted.
 */
static int
symregTerminateRun(popNum, hits, raw, std)
int popNum;
int hits;
double raw;
double std;
{
  if (hits != fitnessCases)
    return(0);
  return(1);
}

void
appInitialize(gp, pop, popNum)
void *gp;
population *pop;
int popNum;
{
  int i;

  /* initialize fitness cases */
  for (i = 0; i < fitnessCases; i++) {
    x = fitnessCase[i].x = drnd();
    fitnessCase[i].answer = x*x*x*x + x*x*x + x*x + x;
  }

  errorCodeSetMsgPtr(ErrorDivideByZero, &MsgDivideByZero);
  errorCodeSetMsgPtr(ErrorMathFunction, &MsgMathFunction);

  populationSetFitnessCases(pop, fitnessCases);
  populationSetTerminalList(pop, symregTerminals());
  populationSetFunctionList(pop, symregFunctions());
  populationSetCaseInitializeFunc(pop, symregCaseInitialize);
  populationSetCaseFitnessFunc(pop, symregCaseFitness);
  populationSetTerminateRunFunc(pop, symregTerminateRun);
}

#ifdef DEBUG_SYMREGRESS

#include "program.h"

/*
 * Debug-only list of hand-written programs, terminated by a null pointer
 * (presumably consumed by the debug harness declared in program.h --
 * confirm).  The single entry expands to X + X^2 + X^3 + X^4.  That is
 * the same polynomial appInitialize() uses for the fitness-case answers,
 * so it should score a hit on every case.
 */
const char *proglist[] = {
  "(+ X (* X (+ X (* X (+ X (* X X))))))",
  0
};

#endif /* DEBUG_SYMREGRESS */
