/* 

  ****************   NO WARRANTY  *****************

Since the Aspirin/MIGRAINES system is licensed free of charge,
the MITRE Corporation provides absolutely no warranty. Should
the Aspirin/MIGRAINES system prove defective, you must assume
the cost of all necessary servicing, repair or correction.
In no way will the MITRE Corporation be liable to you for
damages, including any lost profits, lost monies, or other
special, incidental or consequential damages arising out of
the use or inability to use the Aspirin/MIGRAINES system.

  *****************   COPYRIGHT  *******************

This software is the copyright of The MITRE Corporation. 
It may be freely used and modified for research and development
purposes. We require a brief acknowledgement in any research
paper or other publication where this software has made a significant
contribution. If you wish to use it for commercial gain you must contact 
The MITRE Corporation for conditions of use. The MITRE Corporation 
provides absolutely NO WARRANTY for this software.

   January, 1992 
   Russell Leighton
   The MITRE Corporation
   7525 Colshire Dr.
   McLean, Va. 22102-3481

*/
#include "bp_generator.h"

extern void _mistake();
extern int _connection_size();

/***   Learning Functions ***/
/* One cell of the doubly linked list of black boxes that
   list_black_boxes() builds and reorder_black_boxes() sorts. */
struct bblist {
  BD_PTR bbox;               /* the black box this cell refers to */
  int number;                /* bbox->number as it was before sorting */
  struct bblist *next;       /* cell further from the head */
  struct bblist *previous;   /* cell closer to the head (NULL at head) */
};
typedef struct bblist BBL, *BBLPTR;

/* Head of the black box list; NULL until list_black_boxes() runs. */
static BBLPTR black_box_list = NULL;

/* Scratch sum filled by sum_indices() on behalf of propagate_labels(). */
static int index_sum;

/* list_black_boxes:  Create a double linked list. */
static int list_black_boxes(bb)
     BD_PTR bb;
{
  BBLPTR list_cell;
  
  /* if it is an output to the network... */
  if (black_box_list == (BBLPTR)NULL) {
    list_cell = (BBLPTR)am_alloc_mem(sizeof(BBL));
    list_cell->number = bb->number;
    bb->number = 1; /* reset for sort */        
    list_cell->bbox = bb;
    list_cell->next = (BBLPTR)NULL;
    list_cell->previous = (BBLPTR)NULL;
    black_box_list = list_cell;
  } else {
    list_cell = (BBLPTR)am_alloc_mem(sizeof(BBL));
    list_cell->number = bb->number;   
    bb->number = 1; /* reset for sort */
    list_cell->bbox = bb;
    list_cell->next = black_box_list;
    black_box_list->previous = list_cell;
    black_box_list = list_cell;
    black_box_list->previous = (BBLPTR)NULL;
  }/* end if else */
  
}/* end list_black_boxes */

/* sum_indices: map_dictionary() callback used by propagate_labels();
   adds the number of one input black box into index_sum.
   Returns 0. */
static int sum_indices(bbd)
    BD_PTR bbd;
{
  index_sum += bbd->number;
  return 0;
}/* end sum_indices */

/* propagate_labels: map_dictionary() callback that adds to bbd->number
   the sum of the numbers of all black boxes in bbd->input_bbds.
   reorder_black_boxes() applies it to the whole network repeatedly, so a
   black box's number grows with the numbers of the boxes feeding it; the
   subsequent sort keys on that number.
   Returns 0. */
static int propagate_labels(bbd)
     BD_PTR bbd;
{
  index_sum = 0;
  map_dictionary(sum_indices, bbd->input_bbds);
  bbd->number += index_sum;
  return 0;
}/* end propagate_labels */

/* reorder_black_boxes: Sort black boxes. */
static reorder_black_boxes()
{
  extern BBLPTR black_box_list;
  BBLPTR list, temp_list;
  ND_PTR network;
  int iterations;
  register int counter1, swapped;

  network = black_box_list->bbox->network;
  iterations = network->n_bbds;

  /* order indices */
  while(--iterations) {
    map_dictionary(propagate_labels, network->lookup);
  }/* end while */

  /* bubble sort in descending order */
  iterations = network->n_bbds - 1;
  swapped = 1;
  while(swapped) {
    swapped = 0;
    list = black_box_list;
    for(counter1=0; counter1<iterations; counter1++) {
      /* test order...swap */
      if (list->bbox->number < list->next->bbox->number) {
	/* reset head? */
	if (black_box_list == list)
	  black_box_list = temp_list = list->next;
	else
	  temp_list = list->next;
	/* link previous element */
	temp_list->previous = list->previous;
	if (list->previous != (BBLPTR)NULL)
	  list->previous->next = temp_list;
	/* link last element */
	list->next = temp_list->next;
	if (temp_list->next != (BBLPTR)NULL)
	  temp_list->next->previous = list;
	/* link temp_list */
	temp_list->next = list;
	list->previous = temp_list;
	/* not done */
	swapped = 1;
      }/* end if */
      else
	/* next */
	list = list->next;
    }/* end for counter1 */
  }/* end while swapped */

  /* put indices back */
  list = black_box_list;
  while(list != (BBLPTR)NULL) {
    list->bbox->number = list->number;
    list = list->next;
  }/*  end while */
    
}/* end reorder_black_boxes */

/* order_black_boxes: Build black_box_list from every black box in
   network->lookup (list_black_boxes) and sort it (reorder_black_boxes)
   into the order the generated forward pass uses.
   Returns 0. */
static int order_black_boxes(network)
     ND_PTR network;
{
  /* make an initial list */
  map_dictionary(list_black_boxes, network->lookup);

  /* reorder black boxes */
  reorder_black_boxes();

  return 0;
}/* end order_black_boxes */



/* declare_network_forward: Emit <file>_network_forward(), which calls
   <name>_propagate_forward() for every black box, walking black_box_list
   from its tail back to its head (the order set by order_black_boxes()).
   An empty list yields a function with an empty body instead of a crash.
   Returns 0. */
static int declare_network_forward(file)
     char *file;
{
  register BBLPTR list = black_box_list;

  GenCode("\n\n/* network_forward: Propagate all bb's forward.*/");
  GenCode("\nvoid %s_network_forward()\n{", file);

  GenCode("\n /* all black boxes forward! */");

  /* find the end (an empty list has none) */
  if (list != (BBLPTR)NULL) {
    while(list->next != (BBLPTR)NULL)
      list = list->next;
  }

  /* forward: tail to head */
  while(list != (BBLPTR)NULL) {
    GenCode("\n  %s_propagate_forward();",
	    list->bbox->name);
    list = list->previous;
  }/* end while */

  GenCode("\n}/* end %s_network_forward */", file);

  return 0;
}/* end declare_network_forward */

/* ls_make_buffers: Create scratch storage for line search */
static void ls_make_buffers(bbd)
     BD_PTR bbd;
{
  LD_PTR layer = bbd->layers;
  int bb_number = bbd->number;
  
  while(layer != (LD_PTR)NULL) {
    CD_PTR c = layer->inputs_from;
    int l_number = layer->number;
    int l_size = layer->n_nodes;
    
    /* biases */
    GenCode("\n\t   static float b%d_l%d_tb[%d], b%d_l%d_dtb[%d],b%d_l%d_acb[%d];",
	    bb_number, l_number, l_size,
	    bb_number, l_number, l_size,
	    bb_number, l_number, l_size);
    
    /* input connections */
    while (c != (CD_PTR)NULL) {
      int c_size;
      char *c_name;
      
      c_size = _connection_size(c,bbd);
      c_name = c->array_name;
      
      GenCode("\n\t   static float b%d_%sb[%d], b%d_%sdb[%d],b%d_%sacb[%d];",
	      bb_number, c_name, c_size,
	      bb_number, c_name, c_size,
	      bb_number, c_name, c_size);
      
      c = c->next;
    }/* end while c */
    
    layer = layer->next;
  } /* end while layer */
  
}/* end ls_make_buffers */

/* ls_copy_buffers: Create scratch storage for line search */
static void ls_copy_buffers(bbd)
     BD_PTR bbd;
{
  LD_PTR layer = bbd->layers;
  int bb_number = bbd->number;
  
  while(layer != (LD_PTR)NULL) {
    CD_PTR c = layer->inputs_from;
    int l_number = layer->number;
    int l_size = layer->n_nodes;
    
    /* biases */
    GenCode("\n\t bcopy((char *)b%d_l%d_t, (char *)b%d_l%d_tb, %d*sizeof(float));",
	    bb_number, l_number, 
	    bb_number, l_number, 
	    l_size);
    GenCode("\n\t bcopy((char *)b%d_l%d_dt, (char *)b%d_l%d_dtb, %d*sizeof(float));",
	    bb_number, l_number, 
	    bb_number, l_number, 
	    l_size);
    GenCode("\n\t bcopy((char *)b%d_l%d_ac, (char *)b%d_l%d_acb, %d*sizeof(float));",
	    bb_number, l_number, 
	    bb_number, l_number, 
	    l_size);
    
    /* input connections */
    while (c != (CD_PTR)NULL) {
      int c_size;
      char *c_name;
      
      c_size = _connection_size(c,bbd);
      c_name = c->array_name;
      
      GenCode("\n\t bcopy((char *)b%d_%s, (char *)b%d_%sb, %d*sizeof(float));",
	      bb_number, c_name, 
	      bb_number, c_name, 
	      c_size);
      GenCode("\n\t bcopy((char *)b%d_%sd, (char *)b%d_%sdb, %d*sizeof(float));",
	      bb_number, c_name, 
	      bb_number, c_name, 
	      c_size);
      GenCode("\n\t bcopy((char *)b%d_%sac, (char *)b%d_%sacb, %d*sizeof(float));",
	      bb_number, c_name, 
	      bb_number, c_name, 
	      c_size);

      c = c->next;
    }/* end while c */
    
    layer = layer->next;
  } /* end while layer */
  
}/* end ls_copy_buffers */


/* ls_calc_mu: Create scratch storage for line search */
static void ls_calc_mu(bbd)
     BD_PTR bbd;
{
  LD_PTR layer = bbd->layers;
  int bb_number = bbd->number;
  
  while(layer != (LD_PTR)NULL) {
    CD_PTR c = layer->inputs_from;
    int l_number = layer->number;
    int l_size = layer->n_nodes;
    
    /* biases */
    GenCode("\n\t mu += BPsum_squares(b%d_l%d_ac, %d);",
	    bb_number, l_number,
	    l_size);
    
    /* input connections */
    while (c != (CD_PTR)NULL) {
      int c_size;
      char *c_name;
      
      c_size = _connection_size(c,bbd);
      c_name = c->array_name;
      
      GenCode("\n\t mu += BPsum_squares(b%d_%sac, %d);",
	      bb_number, c_name,
	      c_size);
      
      c = c->next;
    }/* end while c */
    
    layer = layer->next;
  } /* end while layer */
  
}/* end ls_calc_mu */


/* ls_scale_grad: Create scratch storage for line search */
static void ls_scale_grad(bbd)
     BD_PTR bbd;
{
  LD_PTR layer = bbd->layers;
  int bb_number = bbd->number;


  while(layer != (LD_PTR)NULL) {
    CD_PTR c = layer->inputs_from;
    int l_number = layer->number;
    int l_size = layer->n_nodes;
    
    GenCode("\n\t /* biases */");
    GenCode("\n\t BPvsmul(b%d_l%d_ac, b%d_l%d_ac, alpha, %d);",
	    bb_number, l_number, 
	    bb_number, l_number, 
	    l_size);
    
    /* ar? */
    { int counter = layer->layer_order;
      if (counter) {
	GenCode("\n\t /* self feed back connection weights  */");
	do {
	  GenCode("\n\t BPvsmul(b%d_l%d_r%dac, b%d_l%d_r%dac, alpha, %d);",
		  bb_number, l_number, counter,
		  bb_number, l_number, counter,
		  l_size);
	} while(--counter);
      }/* end if counter */
    }/* end block */
    
    /* input connections */
    GenCode("\n\t /* connection weights  */");
    while (c != (CD_PTR)NULL) {
      int c_size;
      char *c_name;
      
      c_size = _connection_size(c,bbd);
      c_name = c->array_name;
      
      GenCode("\n\t BPvsmul(b%d_%sac, b%d_%sac, alpha, %d);",
	      bb_number, c_name, 
	      bb_number, c_name, 
	      c_size);
      c = c->next;
    }/* end while c */
    
    layer = layer->next;
  } /* end while layer */


      
}/* end ls_scale_grad */


/* ls_mag_grad: Calc magnitude squared of gradient */
static void ls_mag_grad(bbd)
     BD_PTR bbd;
{
  LD_PTR layer = bbd->layers;
  int bb_number = bbd->number;


  while(layer != (LD_PTR)NULL) {
    CD_PTR c = layer->inputs_from;
    int l_number = layer->number;
    int l_size = layer->n_nodes;
    
    GenCode("\n\t /* biases */");
    GenCode("\n\t mag_grad += BPsum_squares(b%d_l%d_ac, %d);",
	    bb_number, l_number, 
	    l_size);
    
    /* ar? */
    { int counter = layer->layer_order;
      if (counter) {
	GenCode("\n\t /* self feed back connection weights  */");
	do {
	  GenCode("\n\t mag_grad += BPsum_squares(b%d_l%d_r%dac,%d);",
		  bb_number, l_number, counter,
		  l_size);
	} while(--counter);
      }/* end if counter */
    }/* end block */
    
    /* input connections */
    GenCode("\n\t /* connection weights  */");
    while (c != (CD_PTR)NULL) {
      int c_size;
      char *c_name;
      
      c_size = _connection_size(c,bbd);
      c_name = c->array_name;
      
      GenCode("\n\t mag_grad += BPsum_squares(b%d_%sac, %d);",
	      bb_number, c_name, 
	      c_size);
      c = c->next;
    }/* end while c */
    
    layer = layer->next;
  } /* end while layer */


      
}/* end ls_mag_grad */

/* ls_dot_grads:  calc old_grad dot new_grad (using cpp trick) */
static void ls_dot_grads(bbd)
     BD_PTR bbd;
{
  LD_PTR layer = bbd->layers;
  int bb_number = bbd->number;


  while(layer != (LD_PTR)NULL) {
    CD_PTR c = layer->inputs_from;
    int l_number = layer->number;
    int l_size = layer->n_nodes;
    
    GenCode("\n\t /* biases */");
    GenCode("\n\t dot_grads += BPvdot(b%d_l%d_ac,",
	    bb_number, l_number);
    GenCode("\n#define network_data network_data_buffer");
    GenCode("\n\t                     b%d_l%d_ac,",
	    bb_number, l_number);
    GenCode("\n#undef network_data");
    GenCode("\n\t                     %d);",
	    l_size);
    
    /* ar? */
    { int counter = layer->layer_order;
      if (counter) {
	GenCode("\n\t /* self feed back connection weights  */");
	do {
	  GenCode("\n\t dot_grads += BPvdot(b%d_l%d_r%dac,",
		  bb_number, l_number, counter);
	  GenCode("\n#define network_data network_data_buffer");
	  GenCode("\n\t                     b%d_l%d_r%dac,",
		  bb_number, l_number, counter);
	  GenCode("\n#undef network_data");
	  GenCode("\n\t                     %d);",
		  l_size);
	} while(--counter);
      }/* end if counter */
    }/* end block */
    
    /* input connections */
    GenCode("\n\t /* connection weights  */");
    while (c != (CD_PTR)NULL) {
      int c_size;
      char *c_name;
      
      c_size = _connection_size(c,bbd);
      c_name = c->array_name;
      
      GenCode("\n\t dot_grads += BPvdot(b%d_%sac,",
	      bb_number, c_name);
      GenCode("\n#define network_data network_data_buffer");
      GenCode("\n\t                     b%d_%sac,",
	      bb_number, c_name);
      GenCode("\n#undef network_data");
      GenCode("\n\t                     %d);",
	      c_size);
      c = c->next;
    }/* end while c */
    
    layer = layer->next;
  } /* end while layer */


      
}/* end ls_dot_grads */

/*  declare_line_search: Emit the line search that <file>_network_learn()
    runs once every network->line_search_update calls (whenever
    line_search_counter wraps to 0).

    The emitted code:
      1. With conjugate gradient enabled, picks the inertia as
         (|g|^2 - g.g_old) / |g_old|^2, where g is the current set of _ac
         arrays and g_old is read from network_data_buffer (the state saved
         by the previous line search).  The inertia is reset to 0 every
         network->n_connections line searches.  |g|^2 is recorded in
         last_mag_grad after every step (resets included), so the next
         step divides by the previous gradient's magnitude.  The emitted
         code used to leave last_mag_grad at its initial 1.0 forever.
      2. Saves the whole network state in network_data_buffer.
      3. Makes up to line_search_timeout attempts.  Each attempt:
         - restores the saved state;
         - scales the gradients by alpha and updates the weights;
         - replays the recorded patterns and sums their error.
         The search stops as soon as that error is below total_error;
         otherwise alpha is halved.  On the last attempt, an error above
         1.5 * total_error halves BPlearning_rate and restores the saved
         state.
      4. Adapts BPlearning_rate:
         - doubles it after 5 consecutive line searches accepted on the
           first attempt;
         - halves it after 5 consecutive timeouts.

    NOTE(review): this reads a file-scope `network` (presumably declared in
    bp_generator.h) rather than taking it as a parameter, while the caller
    declare_network_learn() has its own `network` argument -- confirm both
    always refer to the same network. */
static void declare_line_search(file)
     char *file;
{
  
  GenCode("\n\tline_search_counter = ( (line_search_counter + 1) %% %d);",
	  network->line_search_update);

  GenCode("\n\tif ( ! line_search_counter ) { /* line search */");
  GenCode("\n\t  float alpha=BPlearning_rate,inertia=BPinertia;");
  GenCode("\n\t  static float network_data_buffer[%d+PADDING]; /* declare state buffer for network */",
	  network->size);


  GenCode("\n\t  static int timeout_counter = 0, accept_counter = 0;");

  /* for conjugate gradient...just an adaptive inertia */
  if (network->conjugate_gradient != NO_CONJ_GRAD) {
    GenCode("\n\t  static int conjugate_counter = 0;");

    GenCode("\n\t /* conjugate gradient */");
    GenCode("\n\t {");
    GenCode("\n\t  float mag_grad=0.0;");

    /* |g|^2 is needed on every step, resets included: it is either used
       right away or becomes the denominator of the next step */
    GenCode("\n\n\t  /* calc mag of new grad */");
    {
      BBLPTR list;
      for (list = black_box_list; list != (BBLPTR)NULL; list = list->next) {
	/* if learning is allowed...*/
	if ( BBCALC_CREDIT(list->bbox->dynamic) ) {
	  GenCode("\n\n\t/* mag grad of %s */", list->bbox->name);
	  ls_mag_grad(list->bbox);
	}/* end if */
      }/* end for */
    }

    GenCode("\n\t  if ( ! (conjugate_counter %% %d) ) { /* reset */", network->n_connections);
    GenCode("\n\t    inertia = 0.0;");
    GenCode("\n\t  } else { /* select inertia */");
    GenCode("\n\t   float dot_grads=0.0;");

    GenCode("\n\n\t  /* calc old_grad dot new_grad (using cpp trick) */");
    {
      BBLPTR list;
      for (list = black_box_list; list != (BBLPTR)NULL; list = list->next) {
	/* if learning is allowed...*/
	if ( BBCALC_CREDIT(list->bbox->dynamic) ) {
	  GenCode("\n\n\t/* dot old and new grads of %s */", list->bbox->name);
	  ls_dot_grads(list->bbox);
	}/* end if */
      }/* end for */
    }

    /* a vanishing previous gradient would otherwise divide by zero */
    GenCode("\n\n\t  /* (|new|^2 - new.old)/|old|^2, no inertia if old grad vanished */");
    GenCode("\n\t  inertia = (last_mag_grad > 0.0) ? (mag_grad - dot_grads)/last_mag_grad : 0.0;");
    GenCode("\n\n\t  if ( inertia > 10.0 ) {");
    GenCode("\n\n\t   inertia = 10.0; /* clip */");
    GenCode("\n\n\t  } else if (inertia < -10.0) {");
    GenCode("\n\n\t   inertia = -10.0; /* clip */");
    GenCode("\n\n\t  }");

    GenCode("\n\t  } /* end select inertia */");

    /* without this the denominator stayed at its initial 1.0 forever */
    GenCode("\n\t  last_mag_grad = mag_grad; /* denominator for the next step */");
    GenCode("\n\t } /* end conjugate gradient */");

    GenCode("\n\t conjugate_counter = (conjugate_counter + 1) %% %d;", network->n_connections);

  }/* end if conj_grad */


  GenCode("\n\n\n\t  /* copy current of network to  buffer */");
  GenCode("\n\t  bcopy((char *)network_data, (char *)network_data_buffer, (%d+PADDING)*sizeof(float));",
	  network->size);

  GenCode("\n\n\t/* line search */");
  GenCode("\n\t{\n\t\t int timeout=%d;", network->line_search_timeout);

  GenCode("\n\n\t\t do {");
  GenCode("\n\t\t float step_error=0.0;\n\t\t   int it=%d;", network->line_search_update);

  GenCode("\n\n\t\t  /* modify current state */");
  GenCode("\n\t  bcopy((char *)network_data_buffer, (char *)network_data, (%d+PADDING)*sizeof(float));",
	  network->size);
  {
    BBLPTR list;
    for (list = black_box_list; list != (BBLPTR)NULL; list = list->next) {
      /* if learning is allowed...*/
      if ( BBCALC_CREDIT(list->bbox->dynamic) ) {

	GenCode("\n\t/* scale grad of %s */", list->bbox->name);
	ls_scale_grad(list->bbox);
	GenCode("\n\n\t/* update %s */", list->bbox->name);
	GenCode("\n\t %s_update_weights( (alpha*inertia)/last_alpha );\n",
		list->bbox->name);

      }/* end if */
    }/* end for */
  }

  GenCode("\n\t\t  /* calc error for %d iterations */", network->line_search_update);
  GenCode("\n\t\t  while(it--) {\n\t\t float ls_error;\n");
  GenCode("\n\t\t     %s_ls_playback_io(it);", file);
  GenCode("\n\t\t     %s_network_forward();", file);

  {
    BBLPTR list;

    GenCode("\n\n /* Calc Error on each black box */");

    for (list = black_box_list; list != (BBLPTR)NULL; list = list->next) {
      /* if learning is allowed...*/
      if ( BBCALC_CREDIT(list->bbox->dynamic) ) {
	if (list->bbox->efferent) { /* gotta have outputs */

	  if (list->bbox->update_interval != network->line_search_update) {
	    fprintf(stderr, "\n\nWarning: The line search update does not equal the update");
	    fprintf(stderr, "\n         interval for black box %s, using line search update interval\n",
		    list->bbox->name);
	  }/* end if */

	  if (bpthreshold != 0.0) {
	    GenCode("\n\n\t\t /* Only calc on significant errors (heuristic to speed learning) */");
	    GenCode("\n\t\t     ls_error = %s_calc_error();",
		    list->bbox->name);
	    GenCode("\n\t\t  if ( ls_error > BPthreshold ) step_error += ls_error;");
	  } else {
	    GenCode("\n\t\t     step_error += %s_calc_error();",
		    list->bbox->name);
	  }/* end if */

	}/* end if efferent */
      }/* end if learning */
    }/* end for */
  }


  GenCode("\n\t\t  }");


  if (network->line_search == LINE_SEARCH_VERBOSE)
    GenCode("\n\t\t  fprintf(stderr, \"\\nLearning Rate: %%f Inertia: %%f Mean Error: %%f Last Mean Error: %%f\", alpha, inertia, step_error/%f, total_error/%f);",
	    (float)(network->line_search_update),
	    (float)(network->line_search_update) );

  GenCode("\n\t\t  /* test for doneness */");
  GenCode("\n\t\t  if ( step_error < total_error ) break;");

  GenCode("\n\t\t  /* test for badness */");
  GenCode("\n\t\t  if ( timeout == 1 && (step_error / total_error > 1.5) ) { /* reset direction, no change!, change lr */");
  GenCode("\n\t\t    BPlearning_rate *= 0.5;");
  if (network->conjugate_gradient != NO_CONJ_GRAD) 
    GenCode("\n\t\t    conjugate_counter = 0;");
  GenCode("\n\t\t    bcopy((char *)network_data_buffer, (char *)network_data, (%d+PADDING)*sizeof(float));",
	  network->size);
  GenCode("\n\t\t  }");

  GenCode("\n\t\t  alpha *= 0.5;");

  GenCode("\n\t\t } while (--timeout);");

  GenCode("\n\t\t  if ( alpha == 0.0 ) {");
  GenCode("\n\t\t    fprintf(stderr, \"\\nAll done!\\n\");");
  GenCode("\n\t\t    exit(0);");
  GenCode("\n\t\t  }");


  GenCode("\n\t\t  if ( timeout == %d ) accept_counter++; else accept_counter = 0;",
	  network->line_search_timeout);
  GenCode("\n\t\t  if ( accept_counter == 5 ) { /* be more aggressive */");
  GenCode("\n\t\t     accept_counter = 0; BPlearning_rate *= 2.0;");
  GenCode("\n\t\t  }");

  GenCode("\n\t\t  if ( ! timeout ) timeout_counter++; else timeout_counter = 0;");
  GenCode("\n\t\t  if ( timeout_counter == 5 ) { /* reset, back to gradient descent, change lr */");
  GenCode("\n\t\t     timeout_counter = 0; BPlearning_rate *= 0.5;");
  if (network->conjugate_gradient != NO_CONJ_GRAD) 
    GenCode("\n\t\t     conjugate_counter = 0;");
  GenCode("\n\t\t  }");


  if (network->line_search == LINE_SEARCH_VERBOSE)
    GenCode("\n\t fprintf(stderr, \"\\n\");");

  GenCode("\n\t }");

  GenCode("\n\t last_alpha = alpha; total_error = 0.0;");

  GenCode("\n\t} /* end line search */");
}/* end declare_line_search */

/* declare_network_learn: Forward then Backprop */
static declare_network_learn(network,file)
     ND_PTR network;
     char *file;
{

  /* support functions */


  /* line search support */
  if (network->line_search != NO_LINE_SEARCH) {
    
    /* storage to record the i/o */
    GenCode("\n");

    /* to keep track of count on line search */
    GenCode("\nstatic unsigned int line_search_counter=0;");

    GenCode("\nstatic float last_mag_grad=1.0, last_alpha=1.0, total_error;");

    {
      BBLPTR list = black_box_list;
      
      while(list != (BBLPTR)NULL) {
	
	if (list->bbox->efferent) /* targets */
	  GenCode("\nstatic float *ls_b%d_target_table[%d];",
		  list->bbox->number, network->line_search_update);
	
	if (list->bbox->n_inputs) /* inputs */
	  GenCode("\nstatic float *ls_b%d_input_table[%d];",
		  list->bbox->number, network->line_search_update);
	
	list = list->next;
      }/* end while */
    }
    
    GenCode("\n\n/* %s_ls_record_io: Remember the inputs/targets in table. */",file);
    GenCode("\nstatic void %s_ls_record_io(index)",file);
    GenCode("\n  unsigned int index; \n{\n");
    
    {
      BBLPTR list = black_box_list;
      
      while(list != (BBLPTR)NULL) {
	
	if (list->bbox->efferent) /* targets */
	  GenCode("\n\tls_b%d_target_table[index] = %s_get_target_output();",
		  list->bbox->number, list->bbox->name);
	
	if (list->bbox->n_inputs) /* inputs */
	  GenCode("\n\tls_b%d_input_table[index] = %s_get_input();",
		  list->bbox->number, list->bbox->name);
	
	list = list->next;
      }/* end while */
    }
    
    GenCode("\n}");


    GenCode("\n\n/* %s_ls_playback_io: Remember the inputs/targets in table. */", file);
    GenCode("\nstatic void %s_ls_playback_io(index)",file);
    GenCode("\n  unsigned int index; \n{\n");
    
    {
      BBLPTR list = black_box_list;
      
      while(list != (BBLPTR)NULL) {
	
	if (list->bbox->efferent) /* targets */
	  GenCode("\n\t%s_set_target_output( ls_b%d_target_table[index] );",
		  list->bbox->name, list->bbox->number);
	
	if (list->bbox->n_inputs) /* inputs */
	  GenCode("\n\t%s_set_input( ls_b%d_input_table[index] );",
		  list->bbox->name, list->bbox->number);
	
	list = list->next;
      }/* end while */
    }
    
    GenCode("\n}");
    
  }/* end line search support */


  /* the learning function */
  
  GenCode("\n\n/* %s_network_learn: Propagate all bb's forward then backward */",file);
  GenCode("\nvoid %s_network_learn(iterations, generator)",file);
  GenCode("\n  int iterations; \n   VFPTR generator; \n{");
  

  GenCode("\n float error=0.0;");

  GenCode("\n\n while(iterations--) {");

  GenCode("\n  /* execute the generator */");
  GenCode("\n  generator();");

  GenCode("\n  /* all black boxes forward! */");
  GenCode("\n  %s_network_forward();", file);
  

  /* error */
  {
    BBLPTR list = black_box_list;

    GenCode("\n\n /* Calc error on each output black box */");
    while(list != (BBLPTR)NULL) {
      
      if ( BBCALC_CREDIT(list->bbox->dynamic) ) { /* if learning is allowed...*/

	if (list->bbox->efferent) 
	  GenCode("\n   error += %s_calc_error();", list->bbox->name);

      }/* end if */
      list = list->next;
    }/* end while */
  }

  
  /* begin heuristic */
  if (bpthreshold != 0.0) {
    GenCode("\n\n /* Only calc on significant errors (heuristic to speed learning) */");
    GenCode("\n  if ( error > BPthreshold )  {");
  }  else
    GenCode("\n  {");

  if ( network->line_search != NO_LINE_SEARCH) {
    GenCode("\n  total_error += error;");
    GenCode("\n  %s_ls_record_io(line_search_counter);", file);
  }


  {
    BBLPTR list = black_box_list;

    GenCode("\n\n /* Calc grad ient and update on each black box */");
    while(list != (BBLPTR)NULL) {
      /* if learning is allowed...*/
      if ( BBCALC_CREDIT(list->bbox->dynamic) ) {

	GenCode("\n    %s_calc_grad();", list->bbox->name);

	if ( network->line_search == NO_LINE_SEARCH) /* line search does this */
	  GenCode("\n    %s_update_weights(BPinertia);", list->bbox->name);

      }/* end if */
      list = list->next;
    }/* end while */
  }

  /* line search update weights */
  switch(network->line_search) {
  case LINE_SEARCH :    
  case LINE_SEARCH_VERBOSE :
    declare_line_search(file);
    break;
  }

  /* end heuristic */
  if (bpthreshold != 0.0) {
    GenCode("\n  }/* end if error */");
  }  else {
    GenCode("\n  }");
  }

  GenCode("\n  error = 0.0;");

  GenCode("\n }/* end while iterations */");
  
  
  GenCode("\n}/* end %s_network_learn */", file);
  
}/* end declare_network_learn */



/* _declare_learning_fncts: Emit the network-level forward and learning
   code.  `file` is the prefix used for the generated function names.
   Orders the black boxes, then emits <file>_network_forward() and
   <file>_network_learn() (plus line search support when configured).
 */
void _declare_learning_fncts(network, file)
     ND_PTR network;
     char *file;
{
  /* create an ordered linked list of black boxes */
  order_black_boxes(network);

  /* create the forward pass through all the black boxes */
  declare_network_forward(file);

  /* create the learning function */
  declare_network_learn(network, file);

}/* end _declare_learning_fncts */



/* create_interface_init: Writes initialization for learning. */
static create_interface_init(network, file)
     ND_PTR network;
     char *file;
{
  extern FILE *stream;
  register BBLPTR list = black_box_list;
  char error_string[128];

  GenCode("\n\n/* network_initialize: Init simulation, optionaly load dump.*/");
  GenCode("\nvoid network_initialize(filename, verbose)\n  int verbose;\n  char *filename;\n{");

  GenCode("\n /* initialize the simulator */");
  GenCode("\n if (%s_init_network()) {",
	  file);
  GenCode("\n   fprintf(stderr, \"%%s\", %s_error_string());",
	  file);
  GenCode("\n   exit(1);");
  GenCode("\n }/* end if init */");

  GenCode("\n /* load dump file? */");
  GenCode("\n if (filename != (char *)NULL) {");
  GenCode("\n  if (verbose) printf(\"\\nLoading Saved Network: %%s\", filename);");
  GenCode("\n  if (%s_load_network(filename)) {",
	  file);
  GenCode("\n    fprintf(stderr, \"%%s\", %s_error_string());",
	  file);
  GenCode("\n    exit(1);");
  GenCode("\n  }/* end if load_network */");
  GenCode("\n }/* end if filename */");


  GenCode("\n\n if (verbose) {");
  GenCode("\n  printf(\"\\n\\nBackpropagation Learning\");");
  GenCode("\n  printf(\"\\n\\nBlack Boxes:\\n\");");
  while(list != (BBLPTR)NULL) {
      GenCode("\n  printf(\"\\n%s (Saved at Iteration %%d)\", %s_get_backprop_counter());",
	      list->bbox->name,
	      list->bbox->name);
    list = list->next;
  }/* end while */
  GenCode("\n }/* end if verbose */");

  GenCode("\n}/* end network_initialize */");

}/* end create_interface_init */

/* create_interface_forward: Emit network_forward(iterations, generator),
   which runs the generator followed by <file>_network_forward()
   `iterations` times.  Returns 0. */
static int create_interface_forward(file)
     char *file;
{
  GenCode("\n\n/* network_forward: Propagate all bb's forward.*/");
  GenCode("\nvoid network_forward(iterations, generator)");
  GenCode("\n  int iterations; \n  VFPTR generator;\n{");
  
  GenCode("\n  while (iterations--) {");
  GenCode("\n   generator();");
  GenCode("\n   %s_network_forward();", file);
  GenCode("\n  }/* end while */");
    
  GenCode("\n}/* end network_forward */");

  return 0;
}/* end create_interface_forward */

/* times_backward: Emit one statement that advances the generated counter
   b<bbd->number>bcounter by n_patterns * passes. */
static void times_backward(bbd)
     BD_PTR bbd;
{
  int bb_index;

  bb_index = bbd->number;
  GenCode("\n  b%dbcounter += n_patterns * passes;", bb_index);
}/* end times_backward */

/* create_interface_learn: Emit network_learn(iterations, generator), a
   public wrapper that forwards to <file>_network_learn().  Returns 0. */
static int create_interface_learn(file)
     char *file;
{
  GenCode("\n\n/* network_learn: Propagate all bb's forward then backward */");
  GenCode("\nvoid network_learn(iterations, generator)");
  GenCode("\n  int iterations; \n  VFPTR generator; \n{");

  GenCode("\n     %s_network_learn(iterations, generator);", file);

  GenCode("\n}/* end network_learn */");

  return 0;
}/* end create_interface_learn */

/* create_interface_clear_delays: Emit network_clear_delays(), a public
   wrapper whose whole body is a call to <file>_clear_delays(). */
static void create_interface_clear_delays(file)
     char *file;
{
  /* header, forwarding call and closing brace in one emission */
  GenCode("\n\nvoid network_clear_delays()\n{"
	  "\n  %s_clear_delays();"
	  "\n}/* end network_clear_delays */", file);
}/* end create_interface_clear_delays */

/* create_interface_io: Emit network_dump(filename) and
   network_load(filename).  These are public wrappers around
   <file>_dump_network() and <file>_load_network() that print
   <file>_error_string() to stderr when the call returns non-zero.
   Returns 0. */
static int create_interface_io(file)
     char *file;
{
  GenCode("\n\nvoid network_dump(filename)\n  char *filename;\n{");
  GenCode("\n if (%s_dump_network(filename)) {", file);
  GenCode("\n   fprintf(stderr, \"%%s\", %s_error_string());",
	  file);
  GenCode("\n}/* end if dump*/");
  GenCode("\n}/* end network_dump */");

  GenCode("\n\nvoid network_load(filename)\n  char *filename;\n{");
  GenCode("\n if (%s_load_network(filename)) {", file);
  GenCode("\n   fprintf(stderr, \"%%s\", %s_error_string());",
	  file);
  GenCode("\n}/* end if load*/");
  GenCode("\n}/* end network_load */");

  return 0;
}/* end create_interface_io */

/* create_interface_ascii_io: Emit network_ascii_dump(formatted) and
   network_load_ascii(), public wrappers around
   <file>_ascii_dump_network() and <file>_load_ascii_network().
   Returns 0. */
static int create_interface_ascii_io(file)
     char *file;
{
  GenCode("\n\nvoid network_ascii_dump(formatted)\n\tint formatted;\n{");
  GenCode("\n %s_ascii_dump_network(formatted);", file);
  GenCode("\n}/* end network_ascii_dump */");

  GenCode("\n\nvoid network_load_ascii()\n{");
  GenCode("\n %s_load_ascii_network();", file);
  GenCode("\n}/* end network_load_ascii */");

  return 0;
}/* end create_interface_ascii_io */

/* create_interface_query: Emit network_query(bbindex, layerindex), which
   returns the LB_PTR produced by <file>_query_network(bbindex,
   layerindex).  Returns 0. */
static int create_interface_query(file)
     char *file;
{
  GenCode("\n\nLB_PTR network_query(bbindex, layerindex)");
  GenCode("\n  int bbindex, layerindex;\n{");
  GenCode("\n return(%s_query_network(bbindex, layerindex));", file);
  GenCode("\n}/* end network_query */");

  return 0;
}/* end create_interface_query */



/* create_forward_print: print outputs and targets */
static create_forward_print(file)
     char *file;
{
  register BBLPTR list = black_box_list;
  
  GenCode("\n\n/* network_forward_print: print outputs/target for each bb */");
  GenCode("\nvoid network_forward_print(iterations, generator)");
  GenCode("\n  int iterations; \n  VFPTR generator;\n{");
  GenCode("\n  register int icounter,counter;");
  GenCode("\n  register float *ptr;\n");

  GenCode("\n  for(icounter=0;icounter<iterations;icounter++) {");
  GenCode("\n   generator();");
  GenCode("\n   %s_network_forward();", file);
  
  /* thru the black boxes */
  while(list != (BBLPTR)NULL) {

    if (list->bbox->efferent) { /* if there are connections to outside.. */

      GenCode("\n   printf(\"\\nIteration: %%d\", icounter);");
      GenCode("\n   printf(\"\\nblack box %s:\");", list->bbox->name);
      GenCode("\n   ptr = %s_get_output();", list->bbox->name);
      GenCode("\n   counter = %d;", list->bbox->n_outputs);
      GenCode("\n   printf(\"\\n\\tOutputs:\");");
      GenCode("\n   while(counter--)");
      GenCode("\n     printf(\" %%f\", *ptr++);");
      
      GenCode("\n   ptr = %s_get_target_output();", list->bbox->name);
      GenCode("\n   counter = %d;", list->bbox->n_outputs);
      GenCode("\n   printf(\"\\n\\tTargets:\");");
      GenCode("\n   while(counter--)");
      GenCode("\n     printf(\" %%f\", *ptr++);");
      
    }/* end if */

    list = list->next;
  }/* end while */
  
  GenCode("\n  }/* end while */");

  GenCode("\n}/* end network_forward_print */");
  
}/* end create_forward_print */

/* create_forward_pdpfa: calc pd, pfa and pfn */
static create_forward_pdpfa(file)
     char *file;
{
  register BBLPTR list;
  
  GenCode("\n\n/* network_forward_pdpfa: calc pd, pfa, pfn */");
  GenCode("\nvoid network_forward_pdpfa(iterations, generator, threshold)");
  GenCode("\n  int iterations; \n  VFPTR generator;\n  float threshold;\n{");
  GenCode("\n  register int counter;");
  GenCode("\n  register float *ptr1, *ptr2;");
  GenCode("\n  int total = iterations;");
  /* for acummulating stats */
  list = black_box_list;
  while(list != (BBLPTR)NULL) {
    if (list->bbox->efferent) { /* if there are connections to outside.. */
      GenCode("\n  float %s_sum;", list->bbox->name);
      GenCode("\n  int %s_detections = 0;",	list->bbox->name);
      GenCode("\n  double %sMin[%d];", list->bbox->name, list->bbox->n_outputs);
      GenCode("\n  double %sMax[%d];", list->bbox->name, list->bbox->n_outputs);
      GenCode("\n  float %sMean[%d];", list->bbox->name, list->bbox->n_outputs);
      GenCode("\n  float %sSquare[%d];", list->bbox->name, list->bbox->n_outputs);
    }/* end if */
    list = list->next;
  }/* end while */
  
  GenCode("\n\n  if (!iterations) return;");
  
  GenCode("\n  threshold *= threshold;");
  
  /* init */
  list = black_box_list;
  while(list != (BBLPTR)NULL) {
    if (list->bbox->efferent) { /* if there are connections to outside.. */
      GenCode("\n  for(counter=0;counter<%d;counter++) { /* init */",
	      list->bbox->n_outputs);
      GenCode("\n     %sMin[counter] = AM_HUGE_VAL;", list->bbox->name);
      GenCode("\n     %sMax[counter] = -AM_HUGE_VAL;", list->bbox->name);
      GenCode("\n     %sMean[counter] = 0.0;", list->bbox->name);
      GenCode("\n     %sSquare[counter] = 0.0;", list->bbox->name);
      GenCode("\n  }/* end for */");
    }/* end if */
    list = list->next;
  }/* end while */
  
  GenCode("\n  while (iterations--) {");
  GenCode("\n   generator();");
  GenCode("\n   %s_network_forward();", file);
  
  /* thru the black boxes */
  list = black_box_list;
  while(list != (BBLPTR)NULL) {
    
    if (list->bbox->efferent) { /* if there are connections to outside.. */

      GenCode("\n   %s_sum = 0.0;", list->bbox->name);
      
      GenCode("\n   ptr1 = %s_get_output();", list->bbox->name);
      GenCode("\n   ptr2 = %s_get_target_output();", list->bbox->name);
      
      GenCode("\n   for(counter=0;counter<%d;counter++) { /* calc sum of squared diffs */",
	      list->bbox->n_outputs);
      GenCode("\n     float val, target, diff;\n");
      
      GenCode("\n     val = *ptr1++;");
      GenCode("\n     target = *ptr2++;");
      GenCode("\n     %sMean[counter] += val;", list->bbox->name);
      GenCode("\n     %sSquare[counter] += val * val;", list->bbox->name);
      GenCode("\n     if (val < %sMin[counter]) %sMin[counter] = val;",
	      list->bbox->name,
	      list->bbox->name);
      GenCode("\n     if (val > %sMax[counter]) %sMax[counter] = val;",
	      list->bbox->name,
	      list->bbox->name);
      GenCode("\n     diff = val - target;");
      GenCode("\n     %s_sum += diff * diff;\n   }/* end for */",list->bbox->name);
      
      GenCode("\n   /* count */");
      GenCode("\n   if (%s_sum < threshold) %s_detections++;",
	      list->bbox->name,
	      list->bbox->name);
      
    }/* end if */
    
    list = list->next;
  }/* end while */
  
  GenCode("\n  }/* end while */");
  
  /* print out */
  list = black_box_list;
  while(list != (BBLPTR)NULL) {
    if (list->bbox->efferent) { /* if there are connections to outside.. */
      GenCode("\n\n  printf(\"\\n\\n%s:\");", list->bbox->name);
      
      GenCode("\n  printf(\"\\nMaxima: \");");
      GenCode("\n  for(counter=0;counter<%d;counter++) { /* calc max */",
	      list->bbox->n_outputs);
      GenCode("\n   printf(\"%%f \",%sMax[counter]);", list->bbox->name);
      GenCode("\n  }/* end for */");

      GenCode("\n  printf(\"\\nMinima: \");");
      GenCode("\n  for(counter=0;counter<%d;counter++) { /* calc min */",
	      list->bbox->n_outputs);
      GenCode("\n   printf(\"%%f \",%sMin[counter]);", list->bbox->name);
      GenCode("\n  }/* end for */");

      GenCode("\n  printf(\"\\nMeans: \");");
      GenCode("\n  for(counter=0;counter<%d;counter++) { /* calc mean */",
	      list->bbox->n_outputs);
      GenCode("\n     %sMean[counter] /= (float)total;", list->bbox->name);
      GenCode("\n     printf(\"%%f \",%sMean[counter]);", list->bbox->name);
      GenCode("\n  }/* end for */");
      
      GenCode("\n  printf(\"\\nVariances: \");");
      GenCode("\n  for(counter=0;counter<%d;counter++) { /* calc variance */",
	      list->bbox->n_outputs);
      GenCode("\n     %sSquare[counter] /= (float)total;", list->bbox->name);
      GenCode("\n     printf(\"%%f \",%sSquare[counter] - (%sMean[counter] * %sMean[counter]));",
	      list->bbox->name,
	      list->bbox->name,
	      list->bbox->name);
      GenCode("\n  }/* end for */");
      
      
      GenCode("\n  /* print stats */");
      GenCode("\n  printf(\"\\n\\nPd = %%d/%%d = %%f\",%s_detections, total, (float)%s_detections/total);",
	      list->bbox->name,
	      list->bbox->name);
      GenCode("\n  printf(\"\\nPfa = %%d/%%d = %%f\", total - %s_detections, total, (float)(total - %s_detections)/total);",
	      list->bbox->name,
	      list->bbox->name);
      
    }/* end if */
    list = list->next;
    
  }/* end while */
  
  GenCode("\n}/* end network_forward_pdpfa */");
  
}/* end create_forward_pdpfa */

/* _create_interface_fncts: Create the functions that control the learning
                            for the whole network.

   The functions are emitted in this order: initialization, forward pass,
   learning, delay-buffer clearing, dump I/O, ascii I/O, query, forward
   print, and pd/pfa statistics.

   network: network description, used only by the initialization emitter.
   file:    prefix of the generated per-file routines
            (e.g. <file>_network_forward); must not be "network".
 */
void _create_interface_fncts(network, file)
     ND_PTR network;
     char *file;
{
  /* The emitted interface routines are themselves named network_*;
     presumably a "network" prefix would cause name clashes in the
     generated code.  NOTE(review): confirm that _mistake() does not
     return; otherwise generation continues after this report. */
  if (strcmp(file,"network") == 0) _mistake("Don't name your file 'network', anything but that!");

  /* create the initialization */
  create_interface_init(network,file);

  /* create the forward pass through all the black boxes */
  create_interface_forward(file);

  /* create the learning function */
  create_interface_learn(file);

  /* create function to clear delay buffers */
  create_interface_clear_delays(file);

  /* create dump function */
  create_interface_io(file);

  /* create ascii io */
  create_interface_ascii_io(file);

  /* create query function */
  create_interface_query(file);

  /* create print output function */
  create_forward_print(file);

  /* create stat function */
  create_forward_pdpfa(file);

}/* end _create_interface_fncts */

