#include "geppetto.h"
#include "proto.h"

/* Number of fitness cases: one per combination of the input bits D0, D1. */
const int fitnessCases = 4;

/* Current values of the D0/D1 terminals; neuralCaseInitialize() sets them
   from the fitness-case number before each case is evaluated. */
static char d0, d1;

/* Application datatypes: dtRoot is produced by the P operator (aliased to
   dtBoolean in appInitialize) and dtWeight by the W operator (aliased to
   dtFloat). */
#define dtRoot		dtUserDef0
#define dtWeight	dtUserDef1

/* Application error codes and their messages.  The expansions are
   parenthesized so each macro acts as a single operand in any expression. */
#define ErrorDivideByZero	(ErrorUserDefined+0)
const char *MsgDivideByZero =	"Divide by Zero";
#define ErrorMemAlloc		(ErrorUserDefined+1)
const char *MsgMemAlloc =	"Memory Allocation";

/* Arithmetic operators over dtMath values. */
static result *opAdd P((int, const object **, void *));
static result *opSubtract P((int, const object **, void *));
static result *opMultiply P((int, const object **, void *));
static result *opDivide P((int, const object **, void *));

/* Network-building operators: weighted input, threshold unit, output list. */
static result *opW P((int, const object **, void *));
static result *opP P((int, const object **, void *));
static result *opList P((int, const object **, void *));

/* Builders for the terminal and function sets. */
static objectList *neuralTerminals P((NOARGS));
static objectList *neuralFunctions P((NOARGS));

/* Per-case and per-run callbacks registered with the population. */
static void *neuralCaseInitialize P((int, int));
static void neuralCaseFitness P((result *, int, int *, double *, double *,
				 void *));
static int neuralTerminateRun P((int, int, double, double));

/*
 * List(a, b, ...): evaluates every argument in order and collects the
 * results in an objectList wrapped in a dtList result.
 * An evaluation error is returned as-is after the partial list is freed;
 * an allocation/insertion failure yields ErrorMemAlloc.
 */
static result *
opList(argc, argv, envp)
int argc;
const object **argv;
void *envp;
{
  objectList *items;
  result *item;
  int idx;

  /* allocate the list that will hold the evaluated arguments */
  if ((items = objectListCreate(2)) == 0)
    return(resultCreate(dtError, ErrorMemAlloc));

  for (idx = 0; idx < argc; idx++) {
    item = objectEval(argv[idx], envp);

    /* any evaluation error aborts the whole list */
    if (resultIsError(item)) {
      objectListFree(items);
      return(item);
    }

    /* a nonzero return from objectListAdd means the insertion failed */
    if (objectListAdd(items, item) != 0) {
      resultFree(item);
      objectListFree(items);
      return(resultCreate(dtError, ErrorMemAlloc));
    }
  }

  /* wrap the populated list in a dtList result */
  return(resultCreate(dtList, items));
}

/*
 * P(a, b): threshold unit.  Both arguments must be float results; their
 * sum is compared with 1.0 and the first result is turned into a dtRoot
 * boolean (true when the sum is >= 1.0).
 * Evaluation errors propagate unchanged; non-float arguments give
 * ErrorBadDataType.  argc is not used (the operator is registered with
 * exactly two arguments).
 */
static result *
opP(argc, argv, envp)
int argc;
const object **argv;
void *envp;
{
  result *r1, *r2;
  /* double, not float: rounding the sum to float before the threshold
     test could push a value just below 1.0 up to exactly 1.0 */
  double val;

  /* evaluate arguments */
  r1 = objectEval(argv[0], envp);
  if (resultIsError(r1))
    return(r1);
  r2 = objectEval(argv[1], envp);
  if (resultIsError(r2)) {
    resultFree(r1);
    return(r2);
  }

  /* make sure both results are valid */
  if (!resultIsFloat(r1) || !resultIsFloat(r2))
    resultSetError(r1, ErrorBadDataType);
  else {
    val = (double)resultFloat(r1) + (double)resultFloat(r2);
#ifdef DEBUG_NEURAL_NOT
    printf("P(%f + %f) = %s\n", resultFloat(r1), resultFloat(r2),
	   (val >= 1.0 ? "TRUE" : "FALSE"));
#endif /* DEBUG_NEURAL_NOT */
    resultSetBoolean(r1, (val >= 1.0));
    resultSetType(r1, dtRoot);
  }

  /* return the result */
  resultFree(r2);
  return(r1);
}

/*
 * W(weight, input): weighted connection.  The first argument must be a
 * float and the second a boolean.  The first result is overwritten with
 * the weight when the input is true and with 0.0 when it is false.
 * Evaluation errors propagate unchanged; wrong argument types give
 * ErrorBadDataType.
 */
static result *
opW(argc, argv, envp)
int argc;
const object **argv;
void *envp;
{
  result *weight;
  result *input;
  double out;

  weight = objectEval(argv[0], envp);
  if (resultIsError(weight))
    return(weight);

  input = objectEval(argv[1], envp);
  if (resultIsError(input)) {
    resultFree(weight);
    return(input);
  }

  if (resultIsFloat(weight) && resultIsBoolean(input)) {

#ifdef DEBUG_NEURAL_NOT
    printf("W(%f * %s) = %f\n", resultFloat(weight),
	   (resultBoolean(input) ? "TRUE" : "FALSE"),
	   (resultBoolean(input) ? resultFloat(weight) : 0.0));
#endif /* DEBUG_NEURAL_NOT */

    /* pass the weight through only when the input fires */
    if (resultBoolean(input))
      out = resultFloat(weight);
    else
      out = 0.0;
    resultSetFloat(weight, out);
  } else
    resultSetError(weight, ErrorBadDataType);

  resultFree(input);
  return(weight);
}

/*
 * a + b over float results.  The sum is stored in the first result,
 * which is returned; the second is freed.  Evaluation errors propagate
 * unchanged and non-float operands give ErrorBadDataType.
 */
static result *
opAdd(argc, argv, envp)
int argc;
const object **argv;
void *envp;
{
  result *lhs;
  result *rhs;

  /* left operand: an error is returned directly */
  lhs = objectEval(argv[0], envp);
  if (resultIsError(lhs))
    return(lhs);

  /* right operand: on error release the left one first */
  rhs = objectEval(argv[1], envp);
  if (resultIsError(rhs)) {
    resultFree(lhs);
    return(rhs);
  }

  if (resultIsFloat(lhs) && resultIsFloat(rhs))
    resultSetFloat(lhs, resultFloat(lhs) + resultFloat(rhs));
  else
    resultSetError(lhs, ErrorBadDataType);

  resultFree(rhs);
  return(lhs);
}

/*
 * a - b over float results.  The difference replaces the value of the
 * first result, which is returned; the second is freed.  Evaluation
 * errors propagate unchanged and non-float operands give ErrorBadDataType.
 */
static result *
opSubtract(argc, argv, envp)
int argc;
const object **argv;
void *envp;
{
  result *minuend;
  result *subtrahend;

  if (resultIsError(minuend = objectEval(argv[0], envp)))
    return(minuend);

  if (resultIsError(subtrahend = objectEval(argv[1], envp))) {
    resultFree(minuend);
    return(subtrahend);
  }

  /* only float operands are meaningful here */
  if (resultIsFloat(minuend) && resultIsFloat(subtrahend))
    resultSetFloat(minuend, resultFloat(minuend) - resultFloat(subtrahend));
  else
    resultSetError(minuend, ErrorBadDataType);

  resultFree(subtrahend);
  return(minuend);
}

/*
 * a * b over float results.  The product is written into the first
 * result, which is returned; the second is freed.  Evaluation errors
 * propagate unchanged and non-float operands give ErrorBadDataType.
 */
static result *
opMultiply(argc, argv, envp)
int argc;
const object **argv;
void *envp;
{
  result *left;
  result *right;
  int ok;

  left = objectEval(argv[0], envp);
  if (resultIsError(left))
    return(left);

  right = objectEval(argv[1], envp);
  if (resultIsError(right)) {
    resultFree(left);
    return(right);
  }

  ok = resultIsFloat(left) && resultIsFloat(right);
  if (!ok)
    resultSetError(left, ErrorBadDataType);
  else
    resultSetFloat(left, resultFloat(left) * resultFloat(right));

  resultFree(right);
  return(left);
}

/*
 * a / b over float results.  The quotient is stored in the first result,
 * which is returned; the second is freed.
 * Evaluation errors propagate unchanged; non-float operands give
 * ErrorBadDataType and a zero divisor gives ErrorDivideByZero.
 */
static result *
opDivide(argc, argv, envp)
int argc;
const object **argv;
void *envp;
{
  result *r1, *r2;

  /* evaluate arguments */
  r1 = objectEval(argv[0], envp);
  if (resultIsError(r1))
    return(r1);
  r2 = objectEval(argv[1], envp);
  if (resultIsError(r2)) {
    resultFree(r1);
    return(r2);
  }

  /* make sure both results are valid */
  if (!resultIsFloat(r1) || !resultIsFloat(r2))
    resultSetError(r1, ErrorBadDataType);
  else if (resultFloat(r2) == 0.0)
    resultSetError(r1, ErrorDivideByZero);
  else
    /* previously multiplied here by mistake, so "/" behaved like "*" */
    resultSetFloat(r1, resultFloat(r1) / resultFloat(r2));

  /* return the result */
  resultFree(r2);
  return(r1);
}

static objectList *
neuralTerminals()
{
  objectList *list;
  variable *vp;
  constantSrc *csp;

  list = objectListCreate(3);
  if (list) {

    vp = variableCreate(dtBoolean, "D0", &d0);
    if (!vp || objectListAdd(list, vp)) {
      variableFree(vp);
      objectListFree(list);
      return(0);
    }

    vp = variableCreate(dtBoolean, "D1", &d1);
    if (!vp || objectListAdd(list, vp)) {
      variableFree(vp);
      objectListFree(list);
      return(0);
    }

    csp = floatSrcCreate(-2.0, 2.0);
    if (!csp || objectListAdd(list, csp)) {
      constantSrcFree(csp);
      objectListFree(list);
      return(0);
    }
  }

  return(list);

}

/*
 * Build the function set: the arithmetic operators + - * / over dtMath,
 * W (dtFloat weight x dtRoot|dtBoolean input -> dtWeight),
 * P (dtWeight x dtWeight -> dtRoot) and List (dtRoot x dtRoot -> dtList),
 * each taking exactly two arguments.
 * Returns 0 if any operator cannot be created or added; the partially
 * built list is freed in that case.
 */
static objectList *
neuralFunctions()
{
  objectList *list;
  operatorSrc *osp;

  list = objectListCreate(7);
  if (!list)
    return(0);

  /* Each creation is checked for failure, the same way neuralTerminals()
     checks its items; previously a failed create went straight into
     objectListAdd. */
  osp = complexOperatorSrcCreate("+", opAdd, dtMath, 2, 2, dtMath, dtMath);
  if (!osp || objectListAdd(list, osp))
    goto failed;

  osp = complexOperatorSrcCreate("-", opSubtract, dtMath, 2, 2, dtMath,
				 dtMath);
  if (!osp || objectListAdd(list, osp))
    goto failed;

  osp = complexOperatorSrcCreate("*", opMultiply, dtMath, 2, 2, dtMath,
				 dtMath);
  if (!osp || objectListAdd(list, osp))
    goto failed;

  osp = complexOperatorSrcCreate("/", opDivide, dtMath, 2, 2, dtMath,
				 dtMath);
  if (!osp || objectListAdd(list, osp))
    goto failed;

  osp = complexOperatorSrcCreate("W", opW, dtWeight, 2, 2, dtFloat,
				 dtRoot|dtBoolean);
  if (!osp || objectListAdd(list, osp))
    goto failed;

  osp = complexOperatorSrcCreate("P", opP, dtRoot, 2, 2, dtWeight, dtWeight);
  if (!osp || objectListAdd(list, osp))
    goto failed;

  osp = complexOperatorSrcCreate("List", opList, dtList, 2, 2, dtRoot,
				 dtRoot);
  if (!osp || objectListAdd(list, osp))
    goto failed;

  return(list);

failed:
  /* osp is 0 when its creation failed; only free a real operator, since
     operatorSrcFree's NULL handling is not visible here */
  if (osp)
    operatorSrcFree(osp);
  objectListFree(list);
  return(0);
}

/*
 * Per-case setup: decode fitness-case number fc into the input terminals,
 * bit 0 -> D0 and bit 1 -> D1 (each stored as 0 or 1).
 * popNum is unused; no per-case environment is built, so 0 is returned.
 */
static void *
neuralCaseInitialize(popNum, fc)
int popNum;
int fc;
{
  d1 = (fc & 0x2) ? 1 : 0;
  d0 = (fc & 0x1) ? 1 : 0;
  return((void *)0);
}

/*
 * Score one fitness case of the half-adder problem.
 *
 * rp must be a dtList whose entry 0 and entry 1 are dtRoot results.
 * Entry 0 is checked against bit 0 of (D0 + D1) and entry 1 against
 * bit 1.
 *   - Wrong shape (not a list, missing/non-result/non-dtRoot entry):
 *     *stdp += 4.0, which is more than the worst wrong answer (1.0 + 2.0).
 *   - Wrong bit 0 costs 1.0 and wrong bit 1 costs 2.0, added to *stdp.
 *   - A fully correct case increments *hitp instead.
 * fc, rawp and envp are not used.
 */
static void
neuralCaseFitness(rp, fc, hitp, rawp, stdp, envp)
result *rp;
int fc;
int *hitp;
double *rawp;
double *stdp;
void *envp;
{
  objectList *olp;
  object *op;
  int v0, v1;
  int value;
  double std;

#ifdef DEBUG_NEURAL_NOT
  {
    charString *cstr;

    cstr = charStringCreate();
    charStringSet(cstr, "Result is");
    resultToString(rp, cstr);
    charStringPrint(cstr);
    charStringFree(cstr);
  }
#endif /* DEBUG_NEURAL_NOT */

  /* the result should be a list */
  if (!resultIsList(rp)) {
    *stdp += 4.0;
    return;
  }
  olp = resultListPtr(rp);

  /* first entry must be a Root value; normalize the boolean to 0/1 so
     the comparison below does not depend on resultBoolean's encoding */
  op = objectListEntry(olp, 0);
  if (!op || (objectType(op) != otResult) || (objectDataType(op) != dtRoot)) {
    *stdp += 4.0;
    return;
  }
  v0 = (resultBoolean((result *)op) != 0);

  /* second entry must be a Root value as well */
  op = objectListEntry(olp, 1);
  if (!op || (objectType(op) != otResult) || (objectDataType(op) != dtRoot)) {
    *stdp += 4.0;
    return;
  }
  v1 = (resultBoolean((result *)op) != 0);

  /* see if the adder worked */
  value = (d0 ? 1 : 0) + (d1 ? 1 : 0);
  std = 0.0;
  if (((value & 0x1) != 0) != v0)
    std += 1.0;
  if (((value & 0x2) != 0) != v1)
    std += 2.0;

#ifdef DEBUG_NEURAL
  printf("IN=%1d%1d  OUT=%1d%1d CORRECT=%1d%1d STD=%3.1f\n",
	 d1, d0, v1, v0, (value & 0x2 ? 1 : 0), value & 0x1, std);
#endif /* DEBUG_NEURAL */

  /* set hits and standardized fitness */
  if (std == 0.0)
    *hitp += 1;
  else
    *stdp += std;
}

/*
 * Early-termination test: returns 1 once every fitness case scored a
 * hit, otherwise 0.  popNum, raw and std are not consulted.
 */
static int
neuralTerminateRun(popNum, hits, raw, std)
int popNum;
int hits;
double raw;
double std;
{
  if (hits != fitnessCases)
    return(0);
  return(1);
}

void
appInitialize(gp, pop, popNum)
void *gp;
population *pop;
int popNum;
{
  datatypeMakeAlias(dtRoot, dtBoolean);
  datatypeMakeAlias(dtWeight, dtFloat);

  errorCodeSetMsgPtr(ErrorDivideByZero, &MsgDivideByZero);
  errorCodeSetMsgPtr(ErrorMemAlloc, &MsgMemAlloc);

  populationSetFitnessCases(pop, fitnessCases);
  populationSetTerminalList(pop, neuralTerminals());
  populationSetFunctionList(pop, neuralFunctions());
  populationSetReturnTypes(pop, dtList);
  populationSetCaseInitializeFunc(pop, neuralCaseInitialize);
  populationSetCaseFitnessFunc(pop, neuralCaseFitness);
  populationSetTerminateRunFunc(pop, neuralTerminateRun);
}

#ifdef DEBUG_NEURAL

/*
 * Hand-written reference program, terminated by a null entry.
 * The first List entry computes D0 XOR D1 (bit 0 of the sum) and the
 * second computes D0 AND D1 (bit 1), so neuralCaseFitness should score
 * it as a hit on every case.  Presumably consumed by the debug driver;
 * the consumer is not visible here.  The adjacent literals concatenate
 * into a single string.
 */
const char *proglist[] = {
  "(List "
    "(P (W 1.0 (P (W -1.0 D0) (W 1.0 D1))) (W 1.0 (P (W 1.0 D0) (W -1.0 D1)))) "
    "(P (W 0.5 D0) (W 0.5 D1)))",
  0
};

#endif /* DEBUG_NEURAL */
