/* 

  ****************   NO WARRANTY  *****************

Since the Aspirin/MIGRAINES system is licensed free of charge,
the MITRE Corporation provides absolutely no warranty. Should
the Aspirin/MIGRAINES system prove defective, you must assume
the cost of all necessary servicing, repair or correction.
In no way will the MITRE Corporation be liable to you for
damages, including any lost profits, lost monies, or other
special, incidental or consequential damages arising out of
the use or inability to use the Aspirin/MIGRAINES system.

  *****************   COPYRIGHT  *******************

This software is the copyright of The MITRE Corporation. 
It may be freely used and modified for research and development
purposes. We require a brief acknowledgement in any research
paper or other publication where this software has made a significant
contribution. If you wish to use it for commercial gain you must contact 
The MITRE Corporation for conditions of use. The MITRE Corporation 
provides absolutely NO WARRANTY for this software.

   January, 1992 
   Russell Leighton
   The MITRE Corporation
   7525 Colshire Dr.
   McLean, Va. 22102-3481

*/
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "aspirin_bp.h"
#include "BpDatafile.h"
#include "BpConverge.h"

#ifdef SUNOS
#include <sys/time.h>
#include <sys/resource.h>
extern int      getrusage();
#endif

/* Boolean values used for the command line switches in main(). */
#define TRUE 1
#define FALSE 0

/*
 * Sentinel stored in main()'s "forward" count by the -E flag; main()
 * replaces it with epoch_length() before going forward.  Parenthesized
 * so the negative constant stays a single operand in any expression.
 */
#define EPOCH (-1)

/*
 * user_init hook: a build that defines USER_INIT supplies this function
 * itself; any other build gets the empty stand-in compiled here.
 */
#ifdef USER_INIT
extern void     user_init();
#else
void
user_init()
{
	/* nothing to initialize */
}
#endif

/*
 * Default prefix for the dump files written by this program; main()
 * copies it into dumpfile_name and the -F flag overwrites that copy.
 */
#define DUMPFILE "Network"


/*
 * Network Stuff from Aspirin file (must use -interface flag).
 *
 * The functions are declared without parameter lists, so the compiler
 * does not check the arguments used at the call sites in this file.
 * The notes below record how this file calls each one.
 */
extern float    BPlearning_rate;	/* accessed via SET/GET_LEARNING_RATE */
extern float    BPinertia;		/* accessed via SET/GET_INERTIA */
extern void     network_initialize();	/* start up: (dump file name, verbose flag) */
extern void     network_forward();	/* (iterations, generator) */
extern void     network_learn();	/* (iterations, generator) */
extern void     network_dump();		/* (file name to write) */
extern LB_PTR   network_query();	/* called as (0,0); result's network_info is read */
extern void     network_forward_print();	/* (iterations, generator) */
extern void     network_forward_pdpfa();	/* (iterations, generator, threshold) */
extern void     network_ascii_dump();	/* (formatting flag) */
extern void     network_load_ascii();	/* no arguments; used for the -L flag */

/*
 * Accessors for the network globals.  The setters parenthesize their
 * argument so that any expression can be passed safely.
 * GET_N_CONNECTIONS and IS_TEMPORAL read fields of the network_info
 * member of the record returned by network_query(0,0).
 */
#define SET_LEARNING_RATE(x)     (BPlearning_rate = (x))
#define GET_LEARNING_RATE()      (BPlearning_rate)
#define SET_INERTIA(x)           (BPinertia = (x))
#define GET_INERTIA()            (BPinertia)
#define GET_N_CONNECTIONS()      ((network_query(0,0))->network_info.n_connections)
#define IS_TEMPORAL()            ((network_query(0,0))->network_info.temporal)

extern void     nn_ui_loop();	 /* run the ui: called as (program name, generator) */

/*
 * Data File Stuff in libBpDatafile.a.
 * Declared without parameter lists; the notes record how this file
 * uses each one.
 */
extern int      loaddata();	 /* load data from .df; non-zero result is treated as failure */
extern int      epoch_length();	 /* # patterns in ALL data files */
extern DATA    *data;		 /* => the data */
extern void     generator1(),	 /* only files */
                generator2(),	 /* only user generators */
                generator3();	 /* both */
extern VFPTR    set_generator(); /* sets i/o patterns; a NULL result is treated as failure by main() */

/*
 * disable_exceptions: make the platform specific calls that relax
 * floating point exception handling.  Each section is compiled only
 * when its platform macro is defined; with none of them defined the
 * function does nothing.
 */
void
disable_exceptions()
{
#ifdef SUNOS
	{
		extern void     abrupt_underflow_();

		abrupt_underflow_();
	}
#endif

#ifdef MCOS
	{
		extern int      fpcntl_trap_disable();
		extern int      fpcntl_flush0_enable();

		(void) fpcntl_trap_disable();	/* turn off exceptions */
		(void) fpcntl_flush0_enable();	/* underflow -> 0 */
	}
#endif

#ifdef MEIKO860
	/*
	 * Nothing is executed for this platform: the calls to enableft()
	 * and _fp_underflow0() were left commented out.
	 */
#endif
}				/* end disable_exceptions */


/*
 * initialize: one-time start-up, run by main() after the command line
 * has been parsed.
 *
 * Steps, in order: relax floating point exceptions, run the user_init()
 * hook, build the network with network_initialize(), then either
 * convert ASCII weights and exit (ascii_load) or create the convergence
 * test list, load the data file and apply the command line overrides.
 *
 *   filename        passed to network_initialize()
 *   ui              not used here; kept so callers need not change
 *   inputfile       data file handed to loaddata(), or NULL to skip
 *   learning_rate   replaces BPlearning_rate when >= 0
 *   inertia         replaces BPinertia when >= 0
 *   noise_type      NO_NOISE, NORMAL_NOISE or UNIFORM_NOISE
 *   mean, variance  noise parameters handed to loaddata()
 *   ascii_load      non-zero => load weights, dump them and exit(0)
 *   dumpfile_name   prefix of the "<prefix>.save" file written when
 *                   ascii_load is set
 *
 * Exits with status 1 when loaddata() reports a problem.
 */
static void
initialize(filename, ui, inputfile, learning_rate, inertia,
	   noise_type, mean, variance,
	   ascii_load, dumpfile_name)
	char           *filename;
	int             ui;
	char           *inputfile;
	float           learning_rate, inertia;
	int             noise_type;
	float           mean, variance;
	int             ascii_load;
	char           *dumpfile_name;
{
	/* relax floating point exceptions before any work is done */
	disable_exceptions();

	/* the user's hook runs before the network is built */
	(void) user_init();

#define VERBOSE 1
	network_initialize(filename, VERBOSE);

	/* read in ascii weights?  convert them to a dump file and stop */
	if (ascii_load) {
		char            save_string[128];

		printf("\n\nLoading Ascii Weights from stdin...");
		network_load_ascii();
		printf("Done.");

		/*
		 * The prefix is cut at 122 characters so that prefix +
		 * ".save" + NUL can never exceed the 128 byte buffer.
		 */
		sprintf(save_string, "%.122s.save", dumpfile_name);
		printf("\nDumping %s.", dumpfile_name);
		network_dump(save_string);

		printf("\nDone Converting Ascii Weights.\n");
		exit(0);
	}

	/* create list of output vectors for convergence test */
	Create_Test_List(network_query);

	/* load data? */
	if (inputfile != (char *) NULL) {
		printf("\n\nLoading Data Files...");

		if (loaddata(inputfile, network_query, noise_type, mean, variance)) {
			fprintf(stderr, "\n\nProblem reading data file.\n");
			exit(1);
		}
		printf("Done.");

		/* report the noise settings that were handed to loaddata() */
		switch (noise_type) {
		case NORMAL_NOISE:
			printf("\nAdding normal noise to input data, N(%f,%f).",
			       mean, variance);
			break;
		case UNIFORM_NOISE:
			printf("\nAdding uniform noise to input data, U(%f,%f).",
			       mean, variance);
			break;
		default:	/* NO_NOISE or unknown: nothing to report */
			break;
		}
	}

	/* negative values mean "not given on the command line" */
	if (learning_rate >= 0.0)
		SET_LEARNING_RATE(learning_rate);
	if (inertia >= 0.0)
		SET_INERTIA(inertia);

}				/* end initialize */

/*
 * help: write the usage summary for the program named "prog" to stdout
 * and terminate the process with exit status 0.  Does not return.
 */
static void
help(prog)
	char           *prog;
{
	/* flag descriptions, written out verbatim in this order */
	static char    *usage_text[] = {
		"\n\nflags:",
		"\n\t[-d][-datafile <datafile>] read datafile (a .df file)",
		"\n\t[-a][-alpha <learning rate>] set learning rate",
		"\n\t[-i][-inertia <inertia>] set inertia",
		"\n\t[-F][-Filename <dump file name>] (\"Network\" default)",
		"\n\t[-l][-learn] learn without ui",
		"\n\t[-s][-save <iterations>] save to \"Network.save\" every <iterations> (5000 default)",
		"\n\t[-#] Append the current iteration number to the save file name",
		"\n\t[-t][-test <iterations> <passes> <bound>] test for convergence",
		"\n\t every <iterations> (5000 default) by going <passes>(default 100)",
		"\n\t through generators without an error exceeding +/- <bound>(default 0.1)",
		"\n\t[-N][-Notest ] never test for convergence",
		"\n\t[-I][-Iterations <max_iterations>] exit after max_iterations (default is unlimited)",
		"\n\t[-n <mean> <variance>] add normally distributed noise to inputs(very slow)",
		"\n\t[-u <mean> <variance>] add uniformly distributed noise to inputs(slow)",
		"\n\t[-f][-forward <iterations>] go forward <iterations> (used for stats and benchmarking)",
		"\n\t[-E][-Epoch] go forward one epoch (1 pass thru all data)\n\t(does not apply to user-defined generators!)",
		"\n\t[-p][-print] print outputs and targets (used with -f)",
		"\n\t[-P][-Pdpfa <threshold>] Calculate Pd and Pfa for -f <iterations>",
		"\n\t using <threshold> for detection threshold (L2 norm)",
		"\n\t[-A][-AsciiDump] print out all the weights and thresholds to stdout",
		"\n\t[-AsciiDumpNoFmt] print out all the weights and thresholds to stdout (no formatting, use with -L)",
		"\n\t[-L][-LoadAscii] read from stdin the results of -AsciiDumpNoFmt",
		"\n\t[-h][-help] this message",
		"\n"
	};
	int             n_entries = (int) (sizeof(usage_text) / sizeof(usage_text[0]));
	int             entry;

	/* the only line that needs formatting */
	printf("\n\nusage: %s [flags] [dump file]", prog);

	for (entry = 0; entry < n_entries; entry++)
		fputs(usage_text[entry], stdout);

	exit(0);

}				/* end help */


main(argc, argv)
	int             argc;
	char           *argv[];
{
	char           *prog = NULL;	/* simulation name */

	char           *filename = NULL;	/* dump file name */

	int             learn = FALSE;	/* learn without ui? */

	int             forward = FALSE;	/* just go forward */

	int             ui = TRUE;	/* using ui? */

	int             saverate = 5000;	/* how often to save */
	int             index_save = FALSE;	/* append iteration to save
						 * file name? */

	char            dumpfile_name[64];	/* string for dump file
						 * prefix */
	char            save_string[128];	/* string to use for filename */

	int             testrate = 5000;	/* how often to test */

	int             passes = 100;	/* this many times thru generators */

	float           bound = 0.1;	/* +/- bound for error */

	float           learning_rate = -1;	/* user set learning rate */

	float           inertia = -1;	/* user set inertia */

	char           *inputfile = NULL;	/* data file */

	int             noise_type = NO_NOISE;	/* kind of noise to add to
						 * inputs (default none) */
	float           mean, variance;	/* mean, variance of noise */

	FFPTR           error_test = (FFPTR) Max_Absolute_Error;	/* error test */

	VFPTR            network_generator;	/* input pattern generator */

	int             max_iterations = 0;	/* max number of iterations
						 * (0 => unlimited) */
	int             iteration = 0;	/* current iteration */

	int             verbose = 0;	/* print targets and outputs */

#ifdef SUNOS
	struct rusage   r_buffer;	/* for timing information */
#endif
	int             pdpfa = 0;	/* calc pd and pfa */
	float           dthreshold;	/* detection threshold */

	int             ascii_dump = FALSE, ascii_dump_fmt = TRUE;	/* print the state of the
						 * network? */
	int             ascii_load = FALSE;	/* read the state of the
						 * network from stdin ? */



	sprintf(dumpfile_name, "%s", DUMPFILE);

	/* parse argument string */
	argc--;
	prog = *argv++;
	while (argc) {
		if (*(*argv) == '-') {	/* flag */
			switch (*(*argv + 1)) {
			case 'a':{	/* override learning rate */
					argc--;
					sscanf(*(++argv), "%f", &learning_rate);
					break;
				}	/* end case */
			case 'i':{	/* override inertia */
					argc--;
					sscanf(*(++argv), "%f", &inertia);
					break;
				}	/* end case */
			case 'l':{	/* learn */
					learn = TRUE;
					ui = FALSE;
					break;
				}	/* end case */
			case 's':{	/* iterations between saves */
					argc--;
					sscanf(*(++argv), "%d", &saverate);
					break;
				}	/* end case */
			case 'F':{	/* name of file */
					argc--;
					sscanf(*(++argv), "%s", dumpfile_name);
					break;
				}	/* end case */
			case '#':{	/* append current iteration to save
					 * file name */
					index_save = TRUE;
					break;
				}	/* end case */
			case 'A':{	/* print the state of the network */
					ascii_dump = TRUE;
					ui = FALSE;
					if (! strcmp(*argv,"-AsciiDumpNoFmt"))
					  ascii_dump_fmt = FALSE;
					break;
				}	/* end case */
			case 'L':{	/* read the state of the network thru stdin */
					ascii_load = TRUE;
					break;
				}	/* end case */
			case 'f':{	/* iterations forward */
					learn = FALSE;
					ui = FALSE;
					argc--;
					sscanf(*(++argv), "%d", &forward);
					break;
				}	/* end case */
			case 'E':{	/* iterations forward */
					learn = FALSE;
					ui = FALSE;
					forward = EPOCH;
					break;
				}	/* end case */
			case 'p':{	/* print targets and outputs */
					verbose = TRUE;
					break;
				}	/* end case */
			case 'd':{	/* load datafiles */
					argc--;
					inputfile = *(++argv);
					break;
				}	/* end case */
			case 'P':{	/* iterations for pd pfa calc */
					pdpfa = TRUE;
					argc--;
					sscanf(*(++argv), "%f", &dthreshold);
					break;
				}	/* end case */
			case 'h':{	/* help */
					help(prog);
					break;
				}	/* end case */
			case 't':{	/* iterations btwn tests, time thru
					 * generators, bound */
					argc--;
					sscanf(*(++argv), "%d", &testrate);
					argc--;
					sscanf(*(++argv), "%d", &passes);
					argc--;
					sscanf(*(++argv), "%f", &bound);

					break;
				}	/* end case */
			case 'n':{	/* add normally distributed noise */
					noise_type = NORMAL_NOISE;
					argc--;
					sscanf(*(++argv), "%f", &mean);
					argc--;
					sscanf(*(++argv), "%f", &variance);
					break;
				}	/* end case */
			case 'u':{	/* add uniformly distributed noise */
					noise_type = UNIFORM_NOISE;
					argc--;
					sscanf(*(++argv), "%f", &mean);
					argc--;
					sscanf(*(++argv), "%f", &variance);
					break;
				}	/* end case */
			case 'I':{	/* max iterations */
					argc--;
					sscanf(*(++argv), "%d", &max_iterations);
					break;
				}	/* end case */
			case 'N':{	/* no error test */
					error_test = (FFPTR) NULL;
					break;
				}	/* end case */
			default:{	/* file name */
					fprintf(stderr, "Unknown flag: %s", *argv);
					help(prog);
				}	/* end default */
			}	/* end switch */
		} else {	/* file name */
			if (filename != (char *) NULL) {
				fprintf(stderr, "\nMulitple filenames:");
				fprintf(stderr, "\n\t%s and %s\n", filename, *argv);
				help(prog);
			} else
				filename = *argv;
		}		/* end if else */
		argc--;
		argv++;
	}			/* end while argc */


	/* init the system */
	initialize(filename, ui, inputfile,
		   learning_rate, inertia,
		   noise_type, mean, variance,
		   ascii_load, dumpfile_name);

	if (learn) {
		int             batch;

		printf("\n\nLearning...");


		/*
		 * if there is feedback you can't stop and test...messes up
		 * things
		 */
		if (IS_TEMPORAL() && error_test != (FFPTR) NULL) {
			fprintf(stderr, "\nSorry cannot test if this is temporal network! Use -N option.\n");
			exit(1);
		}		/* end if */
		/* keep stupid people from using too high a learning rate */
		printf("\n\nLearning Rate: %f", GET_LEARNING_RATE());
		printf("\nInertia: %f", GET_INERTIA());
		if (GET_LEARNING_RATE() >= 0.1) {
			printf("\nYour learning rate is a bit high,");
			printf("\nif this does not converge try lowering it.");
		}		/* end if */
		/* set the generator used */
		network_generator = set_generator(1);
		if (network_generator == NULL)
			exit(1);

		/*
		 * make both saverate and test rate divisible by
		 * min(saverate,testrate)
		 */
		if (error_test == (FFPTR) NULL) {
			printf("\nNot testing.");
			/* min */
			batch = saverate;
		} else {
			/* min */
			batch = (saverate > testrate) ? testrate : saverate;
			testrate = (testrate / batch) * batch;
			saverate = (saverate / batch) * batch;
			printf("\nTesting every %d iterations.", testrate);
			printf("\nMust pass %d patterns with %f error bound.",
			       passes, bound);
		}		/* end else */
		printf("\nSaving every %d iterations.", saverate);

		if (max_iterations) {
			if (batch > max_iterations)
				batch = max_iterations;
			printf("\nTruncating Learning after %d iterations.",
			       max_iterations);
		}		/* end if */
		/* learn... */
		while (TRUE) {
			network_learn(batch, network_generator);
			iteration += batch;

			/* save state? */
			if (!(iteration % saverate)) {
				if (index_save) {
					sprintf(save_string, "%s%d.save", dumpfile_name, iteration);
					printf("\nDumping %s at %d.", dumpfile_name, iteration);
					network_dump(save_string);
				} else {
					sprintf(save_string, "%s.save", dumpfile_name);
					printf("\nDumping %s at %d.", dumpfile_name, iteration);
					network_dump(save_string);
				}	/* end if */
				fflush(stdout);
			}	/* end if */
			/* done ? */
			if (max_iterations) {
				if (iteration >= max_iterations) {
					sprintf(save_string, "%s.save", dumpfile_name);
					printf("\nDumping %s at %d.", dumpfile_name, iteration);
					network_dump(save_string);
					goto done;
				}
				/* ? go over on next cyle ? */
				if (batch + iteration > max_iterations)
					batch = max_iterations - iteration;
			}	/* end if */
			/* ? convergence test ? */
			if (error_test != (FFPTR) NULL && !(iteration % testrate)) {
				int             counter = passes;

				while (counter--) {
					/* forward */
					network_forward(1, network_generator);
					if (error_test() > bound)
						break;
				}	/* end while */

				/* thru n passes */
				if (counter == -1) {
					sprintf(save_string, "%s.Finished", dumpfile_name);
					network_dump(save_string);
					printf("\nSuccess!! Dumping %s.\n", save_string);
					goto done;
				}	/* end if */
			}	/* end if */
		}		/* end while */

	} else if (forward) {	/* just go forward */

		printf("\n\nForward...");

		/* set the generator used */
		network_generator = set_generator(1);
		if (network_generator == NULL)
			exit(1);

		if (forward == EPOCH)
			forward = epoch_length();

		iteration = forward;

		if (pdpfa) {	/* calc pd and pfa */

			/* go forward, find relative frequencies */
			network_forward_pdpfa(forward, network_generator, dthreshold);

		} else if (verbose) {	/* if printing outputs/targets... */

			/* go forward, print each set of targets/outputs */
			network_forward_print(forward, network_generator);

		} else {
			/* forward */
			network_forward(forward, network_generator);
		}		/* end if else */

	} else if (ui) {	/* user interface */

		/* run */
	  network_generator = set_generator(0);
	  nn_ui_loop(prog, network_generator);

	}			/* end else */
done:

	/*--------------------  print out useful information ---------------------*/

	if (learn) {
		printf("\n\nLearning Rate: %f", GET_LEARNING_RATE());
		printf("\nInertia: %f", GET_INERTIA());
	}			/* end if */

	/* print out the state of the network? */
	if (ascii_dump)
		network_ascii_dump(ascii_dump_fmt);

#ifdef SUNOS
	/* time */
	(void) getrusage(RUSAGE_SELF, &r_buffer);
	printf("\n\nElapsed compute time: %d.%d seconds",
	       r_buffer.ru_utime.tv_sec,
	       r_buffer.ru_utime.tv_usec);
	printf("\nElapsed system time: %d.%d seconds",
	       r_buffer.ru_stime.tv_sec,
	       r_buffer.ru_stime.tv_usec);
	if (ui == FALSE) {
	  double time;
	  
	  time = ((double) r_buffer.ru_utime.tv_sec + (0.000001 * (double) r_buffer.ru_utime.tv_usec));	  if (time == 0.0)
	    printf("\nAverage performance on this task = 0.0 connections per second\n");
	  else
	    printf("\nAverage performance on this task = %g connections per second\n",
		   (((double) (GET_N_CONNECTIONS()) * iteration) / time));
	}                       /* end if ui */
#endif

	/* iterations */
	if (ui == FALSE) {
	  printf("\n\nTotal Iterations %d", iteration);
	  if ((VFPTR) network_generator == (VFPTR) generator1){	/* files only */
	    printf("\nTotal Epochs %f", (float)iteration / (float)epoch_length() );
	  } /* end if */
	} /* end if */
	printf("\nTotal number of connections: %d\n",
	       (GET_N_CONNECTIONS()));

	/* done */
	exit(0);
}				/* end main */
