/* 

  ****************   NO WARRANTY  *****************

Since the Aspirin/MIGRAINES system is licensed free of charge,
the MITRE Corporation provides absolutely no warranty. Should
the Aspirin/MIGRAINES system prove defective, you must assume
the cost of all necessary servicing, repair or correction.
In no way will the MITRE Corporation be liable to you for
damages, including any lost profits, lost monies, or other
special, incidental or consequential damages arising out of
the use or inability to use the Aspirin/MIGRAINES system.

  *****************   COPYRIGHT  *******************

This software is the copyright of The MITRE Corporation. 
It may be freely used and modified for research and development
purposes. We require a brief acknowledgement in any research
paper or other publication where this software has made a significant
contribution. If you wish to use it for commercial gain you must contact 
The MITRE Corporation for conditions of use. The MITRE Corporation 
provides absolutely NO WARRANTY for this software.

   January, 1992 
   Russell Leighton
   The MITRE Corporation
   7525 Colshire Dr.
   McLean, Va. 22102-3481

*/

#include "bp_generator.h"

extern void _mistake();
extern int _connection_size();


/* declare_init_weights: Initialize the weights randomly unless
                         a data file has been specified or an init function.
			 Init Functions: 
			                 RANDOM_INIT - small random numbers
					 < User's C function > - function symbol
 */
static declare_init_weights(bbd)
     BD_PTR bbd;
{
  extern FILE *stream;
  DICTIONARYPTR d = bbd->lookup;
  int number = bbd->number;
  int size;
  CD_PTR connection;
  LD_PTR from_layer, to_layer = bbd->output_layer;
  
  /* if a datafile has been specified then load the black box */
  if (bbd->datafile != (char *)NULL)   {
    GenCode( "\n /* load state of %s */", bbd->name);
    GenCode( "\n error_code = load_black_box_data(\"%s\",\"%s\", \"%s\");",
	    bbd->name,
	    bbd->load_key,
	    bbd->datafile);
  } else  {
    /* randomize weights and set deltas to zero */
    GenCode( "\n\n /* initializing %s */", bbd->name);
    while (to_layer != (LD_PTR)NULL) {
      switch(to_layer->type) {
      case USER_LAYER_TYPE :
      case LINEAR_LAYER_TYPE :
      case QUAD_LAYER_TYPE :
      case PDP_LAYER_TYPE1 : 
      case PDP_LAYER_TYPE2 : 
      case PDP_LAYER_TYPE3 : {
	/* do thresholds */
	GenCode( "\n /* init for %s */", to_layer->name);
	GenCode("\n for(counter1 = 0; counter1<%d; counter1++) {",
		to_layer->n_nodes);
	
	/* init refection coeficients */
	{ int counter = to_layer->layer_order;
	  if (counter) {
	    GenCode("\n  /* self feed back connections */");
	    GenCode("\n  do { /* must choose stable set */");
	    do {
	      GenCode("\n    b%d_l%d_r%d[counter1] = BPfrandom(2.0 * init_range) - init_range;",
		      number, to_layer->number, counter);
	    } while(--counter);
	    switch (to_layer->layer_order) {
	    case 1 :
	      GenCode("\n  } while (BPunstable_1st(*(b%d_l%d_r1 + counter1)));",
		      number, to_layer->number);
	      break;
	    case 2 :
	      GenCode("\n  } while (BPunstable_2nd(*(b%d_l%d_r1 + counter1),*(b%d_l%d_r2 + counter1)));",
		      number, to_layer->number,
		      number, to_layer->number);
	      break;
	    case 3 :
	      GenCode("\n  } while (BPunstable_3rd(*(b%d_l%d_r1 + counter1),*(b%d_l%d_r2 + counter1),*(b%d_l%d_r3 + counter1)));",
		      number, to_layer->number,
		      number, to_layer->number,
		      number, to_layer->number);
	      break;
	    }/* end switch */
	  }/* end if counter */
	}/* end block */

	
	GenCode("\n  /* init thresholds (biases) */");
	
	/* how to init threholds? */
	switch(to_layer->init_function) {
	case RANDOM_INIT : {
	  GenCode("\n  b%d_l%d_t[counter1] = BPfrandom(2.0 * init_range) - init_range;",
		  number,
		  to_layer->number);
	  break;
	}/* end case */
	case CONSTANT_INIT : {
	  GenCode(
		  "\n  b%d_l%d_t[counter1] = %f;", 
		  number,
		  to_layer->number,
		  to_layer->bias);
	  break;
	}/* end case */
	case C_INIT : {
	  GenCode(
		  "\n  b%d_l%d_t[counter1] = %s(counter1);", 
		  number,
		  to_layer->number,
		  to_layer->C_bias_init);
	  break;
	}/* end case */
	  default :
	    {
	      _mistake("\nUnknown Init Function: declare_init_weights");
	    }/* end default */
	}/* end switch */
	
	GenCode( "\n }/* end for */");
	
	GenCode( "\n /* init weights for %s */",
		to_layer->name);
	connection = to_layer->inputs_from;
	while (connection != (CD_PTR)NULL)  {
	  size = _connection_size(connection, bbd);
	  
	  switch(connection->init_function) {
	  case RANDOM_INIT : {  /* initialize with small random numbers */
	    GenCode(
		    "\n /* initialize w/small random weights */");
	    GenCode(
		    "\n for(counter1 = 0; counter1<%d; counter1++) {",
		    size);
	    GenCode(
		    "\n  b%d_%s[counter1] = BPfrandom(2.0 * init_range) - init_range;",
		    number,
		    connection->array_name);
	    GenCode( "\n }/* end for */");
	    
	    break;
	  }/* end case */
	  case CONSTANT_INIT : { /* initialize with a constant number */
	    GenCode(
		    "\n for(counter1 = 0; counter1<%d; counter1++) {",
		    size);
	    GenCode(
		    "\n  b%d_%s[counter1] = %f;",
		    number,
		    connection->array_name,
		    connection->constant);
	    GenCode( "\n }/* end for */");
	    
	    break;
	  }/* end case */
	  case C_INIT : {  /* initialize with user's  filter */
	    GenCode(
		    "\n /* initialize weights with user's function */");
	    GenCode(
		    "\n bzero((char *)b%d_%s, %d * sizeof(float)); /* reset weights*/",
		    number,
		    connection->array_name,
		    size);
	    GenCode(
		    "\n weights = b%d_%s;",
		    number,
		    connection->array_name);
	    GenCode(
		    "\n for(counter4 = 0; counter4<%d; counter4++) {",
		    (connection->shared)?1:to_layer->ydim);
	    GenCode(
		    "\n  for(counter3 = 0; counter3<%d; counter3++) {",
		    (connection->shared)?1:to_layer->xdim);
	    GenCode(
		    "\n   weights = b%d_%s + (counter3 * %d) + (counter4 * %d);",
		    number,
		    connection->array_name,
		    connection->xrange,
		    to_layer->xdim * connection->xrange * connection->yrange);
	    GenCode(
		    "\n   for(counter1 = 0; counter1<%d; counter1++) {",
		    connection->yrange);
	    GenCode(
		    "\n    for(counter2 = 0; counter2<%d; counter2++) {",
		    connection->xrange);
	    /* call user's function with as many args as dimensions */
	    switch (connection->type) {
	    case NXM_CONNECTION_TYPE :
	    case TESS_1D_CONNECTION_TYPE : {
	      GenCode(
		      "\n     *weights++ = %s(counter2); /* pass x */",
		      connection->C_init);
	      break;
	    }/* end case */
	    case TESS_2D_CONNECTION_TYPE :  {
	      GenCode(
		      "\n     *weights++ = %s(counter2, counter1); /* pass x,y */",
		      connection->C_init);
	      break;
	    }/* end case */
	    }/* end switch */
	    GenCode(
		    "\n    }/* end for */");
	    /* don't need to jump when there is a shared tessellation
	       or only one tile in the x dimension
	       */
	    if ((connection->shared == 0) && (to_layer->xdim > 1))
	      GenCode(
		      "\n    weights += %d;",
		      ((to_layer->xdim - 1) *  connection->xrange));
	    GenCode(
		    "\n   }/* end for */");
	    GenCode(
		    "\n  }/* end for */");
	    GenCode(
		    "\n }/* end for */");
	    
	    break;
	  }/* end case */
	    default :
	      {
		_mistake("\nUnknown Init Function: declare_init_weights");
	      }/* end default */
	  }/* end switch */
	  /* next */
	  connection = connection->next;
	}/* end while */
	
	break;
      }/* end case */
	default :
	  {
	    _mistake("\nUnknown Layer Type: declare_init_weights");
	  }/* end default */
      }/* end switch */
      
      /* previous */
      to_layer = to_layer->previous;
    }/* end while */
  }/* end else */
  
} /* end declare_init_weights */

/* clear_each_bb: map_dictionary() callback that emits, into the generated
   <file>_clear_delays() routine, one call to the given black box's
   <name>_bb_clear_delays() routine. */
static void clear_each_bb(bbd)
     BD_PTR bbd;
{
  const char *bb_name = bbd->name;

  GenCode("\n\t%s_bb_clear_delays();", bb_name);
}

/* _declare_init: Emit the network-level initialization code.

   Generates, in order:
     <file>_clear_delays()          - calls each black box's
                                      <name>_bb_clear_delays()
     <file>_set_random_init_seed()  - assigns init_seed
     <file>_set_random_init_range() - assigns init_range
     <file>_init_network()          - seeds BPfrandom, zeroes network_data,
                                      builds the sigmoid table, fills in
                                      comms_buffer, initializes every black
                                      box (declare_init_weights) and returns
                                      error_code.

   network      - parsed network description (sizes, counts, dictionary)
   aspirin_file - recorded in comms_buffer.network_info.aspirin_file
   file         - prefix of every generated symbol; also recorded in
                  comms_buffer.network_info.file
*/
void _declare_init(network, aspirin_file, file)
     ND_PTR network;
     char *aspirin_file, *file;
{
  /* comms_buffer.network_info entry points: each field is bound to the
     generated function <file>_<field>, cast to the listed pointer type. */
  static const struct {
    const char *field;
    const char *cast;
  } entry_points[] = {
    { "set_learning_rate",     "VFPTR"  },
    { "set_inertia",           "VFPTR"  },
    { "get_learning_rate",     "FFPTR"  },
    { "get_inertia",           "FFPTR"  },
    { "set_random_init_seed",  "VFPTR"  },
    { "set_random_init_range", "VFPTR"  },
    { "init_network",          "IFPTR"  },
    { "network_forward",       "VFPTR"  },
    { "network_learn",         "VFPTR"  },
    { "dump_network",          "IFPTR"  },
    { "load_network",          "IFPTR"  },
    { "error_string",          "PCFPTR" }
  };
  int n_entry_points = (int)(sizeof entry_points / sizeof entry_points[0]);
  int i;

  /* <file>_clear_delays(): one call per black box */
  GenCode("\n\nvoid %s_clear_delays()\n{", file);
  map_dictionary(clear_each_bb, network->lookup);
  GenCode("\n}/* end %s_clear_delays */", file);

  /* setters for the random initialization seed and range */
  GenCode( "\n\nvoid %s_set_random_init_seed(x) long x; { init_seed = x; } \n", file);
  GenCode( "\n\nvoid %s_set_random_init_range(x) float x; { init_range = x; } \n", file);

  /* <file>_init_network(): header and locals used by the emitted code */
  GenCode( "\n\nint %s_init_network()\n{", file);
  GenCode( "\n int error_code = 0;");
  GenCode( "\n int counter1, counter2, counter3, counter4;");
  GenCode( "\n float *weights;\n");
  GenCode( "\n BPfrandom_init(init_seed); /* init random number generator */\n");
  GenCode( "\n /* clear all data */");
  GenCode( "\n bzero((char *)network_data, %d * sizeof(float));", network->size);
  GenCode( "\n /* init table lookup for sigmoid [-1,1] */");
  GenCode( "\n BPinit_sigmoid_table();");
  GenCode( "\n\n error_string[0] = '\\0'; /* empty string */");

  /* describe the network in comms_buffer */
  GenCode( "\n\n /* init comms_buffer with network information */");
  GenCode( "\n comms_buffer.network_info.aspirin_file = \"%s\";", aspirin_file);
  GenCode( "\n comms_buffer.network_info.file = \"%s\";", file);
  GenCode( "\n comms_buffer.network_info.temporal = %d;", network->temporal);
  GenCode( "\n comms_buffer.network_info.n_black_boxes = %d;", network->n_bbds);
  GenCode( "\n comms_buffer.network_info.n_nodes = %d;", network->n_nodes);
  GenCode( "\n comms_buffer.network_info.n_connections = %d;", network->n_connections);
  for (i = 0; i < n_entry_points; i++)
    GenCode( "\n comms_buffer.network_info.%s = (%s)%s_%s;",
             entry_points[i].field,
             entry_points[i].cast,
             file,
             entry_points[i].field);

  GenCode( "\n /* hook up connection buffer */");
  GenCode( "\n comms_buffer.connections = connection_buffer;");

  /* per black box initialization code */
  map_dictionary(declare_init_weights, network->lookup);

  GenCode( "\n return(error_code);");
  GenCode( "\n}/* end %s_init_network */", file);
}/* end _declare_init */
