/* 

  ****************   NO WARRANTY  *****************

Since the Aspirin/MIGRAINES system is licensed free of charge,
the MITRE Corporation provides absolutely no warranty. Should
the Aspirin/MIGRAINES system prove defective, you must assume
the cost of all necessary servicing, repair or correction.
In no way will the MITRE Corporation be liable to you for
damages, including any lost profits, lost monies, or other
special, incidental or consequential damages arising out of
the use or inability to use the Aspirin/MIGRAINES system.

  *****************   COPYRIGHT  *******************

This software is the copyright of The MITRE Corporation. 
It may be freely used and modified for research and development
purposes. We require a brief acknowledgement in any research
paper or other publication where this software has made a significant
contribution. If you wish to use it for commercial gain you must contact 
The MITRE Corporation for conditions of use. The MITRE Corporation 
provides absolutely NO WARRANTY for this software.

   January, 1992 
   Russell Leighton
   The MITRE Corporation
   7525 Colshire Dr.
   McLean, Va. 22102-3481

*/

#include "bp_generator.h"

extern int _connection_size();

/* declare_bb_sizes:   Declare the sizes of all bbs in the header */
/* declare_bb_sizes: Emit, into the generated dump function, the
   statements that write the total byte size of one black box's dump
   record (used by the loader to build its address table).

   The total covers: the NUL-terminated black box name, the iteration
   counter, the layer count and the connection-array count; then for
   every layer its NUL-terminated name, a length word, its thresholds
   and its reflection weights; and for every incoming connection the
   NUL-terminated matrix and to-layer names, a length word and the
   weight array.

   bbd: black box descriptor being measured. */
static declare_bb_sizes(bbd)
     BD_PTR bbd;
{
  extern FILE *stream;
  LD_PTR lp;
  CD_PTR cp;
  char *label;
  int total;

  /* preamble: name + iteration counter + layer count + connection count */
  total = strlen(bbd->name) + 1 + 3 * sizeof(int);

  for (lp = bbd->layers; lp != (LD_PTR)NULL; lp = lp->next) {
    label = sym_name(lp->name);
    total += strlen(label) + 1;                              /* layer name */
    total += sizeof(int);                                    /* length word */
    total += lp->n_nodes * sizeof(float);                    /* thresholds */
    total += lp->layer_order * lp->n_nodes * sizeof(float);  /* reflection weights */

    for (cp = lp->inputs_from; cp != (CD_PTR)NULL; cp = cp->next) {
      total += strlen(cp->array_name) + 1;                   /* matrix name */
      label = sym_name(cp->to);
      total += strlen(label) + 1;                            /* to-layer name */
      total += sizeof(int);                                  /* length word */
      total += _connection_size(cp, bbd) * sizeof(float);    /* weights */
    }
  }

  /* the emitted code stores the total in the dump file */
  GenCode( "\n number = %d; /* data size of black box */",
	  total);
  GenCode( "\n write(fd, &number, sizeof(int));");
}/* end declare_bb_sizes */

/* _declare_write_header: Write header on dump file. */
/* _declare_write_header: Emit the static write_header(fd) function.

   The emitted function writes the dump-file header: the compiler
   identifier (ASPIRIN_MAGIC), the major and minor version numbers, and
   a block of reserved bytes for future use (the format change between
   versions 2.0 and 3.0).  The reserved bytes are zeroed before being
   written, so the file never contains uninitialized stack data and
   identical networks produce identical dump files.

   network: unused; kept for interface compatibility. */
void _declare_write_header(network)
     ND_PTR network;
{
  int n_extra_bytes = 256; /* reserved bytes; read_header skips 256 as well */

  GenCode(
	  "\n\nstatic write_header(fd)\n  int fd;\n{");
  GenCode( "\n int number;");
  GenCode( "\n int n_extra_bytes = %d;", n_extra_bytes);
  GenCode( "\n char extra_bytes[%d];\n", n_extra_bytes);

  /* A plain loop (rather than an initializer or memset) keeps the
     emitted code free of extra headers and acceptable to pre-ANSI
     compilers, matching the K&R style of the rest of the output. */
  GenCode( "\n /* clear the reserved bytes before writing them */");
  GenCode( "\n for (number = 0; number < n_extra_bytes; number++)");
  GenCode( "\n  extra_bytes[number] = 0;\n");

  GenCode( "\n /* write header */");
  GenCode( "\n number = %d; /* compiler identifier */",
	  ASPIRIN_MAGIC);
  GenCode( "\n write(fd, &number, sizeof(int));");
  GenCode( "\n number = %d; /* major version */",
	  BP_MAJOR_VERSION);
  GenCode( "\n write(fd, &number, sizeof(int));");
  GenCode( "\n number = %d; /* minor version */",
	  BP_MINOR_VERSION);
  GenCode( "\n write(fd, &number, sizeof(int));");

  /*** this is the difference between 2.0 and 3.0 ***/
  GenCode( "\n /* extra bytes for future use */");
  GenCode( "\n write(fd, extra_bytes, n_extra_bytes);");

  GenCode( "\n}/* end write_header */\n");
}/* end _declare_write_header */

/* declare_dump_bb:  Declare code to dump this black box */
/* declare_dump_bb: Emit the statements, inside the generated
   <file>_dump_network function, that write one black box to the open
   descriptor `fd`.

   Emitted layout, in order:
     black box name (NUL-terminated), iteration counter, layer count;
     per layer: layer name, threshold count, thresholds, then the
       reflection weight arrays r<layer_order> down to r1;
     connection array count;
     per connection: to-layer name, matrix name, weight count, weights.
   declare_bb_sizes computes the byte total of this same layout for the
   header, so the two must be kept in step.

   bbd: black box descriptor to dump. */
static declare_dump_bb(bbd)
     BD_PTR bbd;
{
  extern FILE *stream;
  LD_PTR layer = bbd->layers;
  CD_PTR connection;
  int size;
  char *pname;

  /* strlen() yields size_t; every length printed with %d is cast to int,
     since passing a size_t for %d is undefined where the widths differ */
  GenCode( "\n /* dump %s */", bbd->name);
  GenCode( "\n write(fd, \"%s\", %d); /* name */",
	  bbd->name,
	  (int)(strlen(bbd->name) + 1));
  GenCode( "\n number = b%dbcounter; /* number of iterations */",
	  bbd->number);
  GenCode( "\n write(fd, &number, sizeof(int));");
  GenCode( "\n number = %d; /* number of layers */",
	  bbd->n_layers);
  GenCode( "\n write(fd, &number, sizeof(int));");
  while (layer != (LD_PTR)NULL)
    {
      /* dump thresholds */
      pname = sym_name(layer->name);
      GenCode("\n write(fd, \"%s\", %d); /* layer name */",
	      pname,
	      (int)(strlen(pname) + 1));
      GenCode("\n number = %d; /* size of threshold array */",
	      layer->n_nodes);
      GenCode("\n write(fd, &number, sizeof(int));");
      GenCode("\n write(fd, b%d_l%d_t, %d * sizeof(float)); /* thresholds */",
	      bbd->number,
	      layer->number,
	      layer->n_nodes);

      { /* dump reflection weights, highest order first */
	int counter = layer->layer_order;
	while(counter--)
	  GenCode("\n write(fd, b%d_l%d_r%d, %d * sizeof(float)); /* reflection weights */",
		  bbd->number, layer->number,
		  counter + 1,
		  layer->n_nodes);
      }

      /* next */
      layer = layer->next;
    }/* end while */

  /* dump connections into file */
  GenCode( "\n number = %d; /* number of connection arrays */",
	  bbd->n_cds);
  GenCode( "\n write(fd, &number, sizeof(int));");
  layer = bbd->layers;
  while (layer != (LD_PTR)NULL)
    {
      connection = layer->inputs_from;
      while(connection != (CD_PTR)NULL)
	{
	  pname = sym_name(connection->to);
	  GenCode( "\n write(fd, \"%s\", %d); /* to layer name */",
		  pname,
		  (int)(strlen(pname) + 1));
	  GenCode( "\n write(fd, \"%s\", %d); /* matrix name */",
		  connection->array_name,
		  (int)(strlen(connection->array_name) + 1));
	  size = _connection_size(connection, bbd);
	  GenCode( "\n number = %d; /* size of weight array */",
		  size);
	  GenCode( "\n write(fd, &number, sizeof(int));");
	  GenCode(
		  "\n write(fd, b%d_%s, %d * sizeof(float)); /* connection */",
		  bbd->number,
		  connection->array_name,
		  size);
	  /* next */
	  connection = connection->next;
	}/* end while */
      /* next */
      layer = layer->next;
    }/* end while */
}/* declare_dump_bb */

/* _declare_dump:   Declare the function that dumps the weights to disk. */
/* _declare_dump: Emit the public <file>_dump_network(filename) function.

   The emitted function creates `filename` with creat() and mode 0644.
   It then writes the header (write_header), the number of black boxes,
   one data-size word per black box (code from declare_bb_sizes) and the
   data of each black box (code from declare_dump_bb).  It returns 0, or
   FERROR with a message in error_string if the file cannot be created.

   NOTE(review): the size table and the data rely on map_dictionary
   visiting the black boxes in the same order on both passes; presumably
   it does -- confirm.

   network: network descriptor (n_bbds and the black box dictionary).
   file:    prefix used to name the emitted function. */
void _declare_dump(network, file)
     ND_PTR network;
     char *file;
{
  extern FILE *stream;

  GenCode(
	  "\n\nint %s_dump_network(filename)\n  char *filename;\n{",
	  file);
  /* only `number` and `fd` are used by the emitted body, including the
     code generated by declare_bb_sizes and declare_dump_bb */
  GenCode( "\n int number;");
  GenCode( "\n int fd;\n");
  GenCode( "\n /* open the file */");
  GenCode( "\n fd = creat(filename, 0644);");
  GenCode( "\n if (fd == -1) {");
  GenCode(
	  "\n  sprintf(error_string, \"\\nUnable to open %%s.\\n\", filename);");
  GenCode( "\n  return(FERROR);");
  GenCode( "\n }/* end if */");

  GenCode( "\n /* write header */");
  GenCode( "\n write_header(fd);");

  GenCode( "\n number = %d; /* number of black boxes */",
	  network->n_bbds);
  GenCode( "\n write(fd, &number, sizeof(int));");
  /* the size in bytes of each black box's dump, recorded in the header
     so the loader can seek directly to any black box */
  map_dictionary(declare_bb_sizes, network->lookup);

  /* the black box data itself */
  map_dictionary(declare_dump_bb, network->lookup);

  GenCode( "\n close(fd);");
  GenCode( "\n return(0);");
  GenCode( "\n}/* end %s_dump_network */", file);
}/* end declare_dump */

/* _declare_read_header: Read dump file header. */
/* _declare_read_header: Emit the static read_header(fd) function.

   The emitted function reads the header (compiler identifier, major and
   minor version, 256 reserved bytes), the number of black boxes and the
   per-black-box size table.  It returns a TOC_PTR whose address_table
   holds the file offset at which each black box's data starts.
   A mismatched identifier or version only leaves a warning in
   error_string.  A short read or a negative black box count sets
   error_string and returns (TOC_PTR)NULL.

   network: unused; kept for interface compatibility. */
void _declare_read_header(network)
     ND_PTR network;
{
  extern FILE *stream;
  int n_extra_bytes = 256; /* reserved bytes after the version words */
  int header_size;

  /* bytes before the first black box size word: identifier, major,
     minor and black box count (4 ints) plus the reserved bytes */
  header_size = (int)(4 * sizeof(int)) + n_extra_bytes;

  GenCode(
	  "\n\nstatic TOC_PTR read_header(fd)\n  int fd;\n{");
  GenCode( "\n int header_size = %d;", header_size);
  GenCode( "\n int number;");
  GenCode( "\n char extra_bytes[%d];", n_extra_bytes);
  GenCode( "\n int counter;");
  GenCode( "\n int n_bbs; /* number of black boxes */");
  GenCode( "\n TOC_PTR table_of_contents;");
  GenCode( "\n off_t *address_table; /* used to index bb in a file */");

  GenCode( "\n\n /* read header */");

/* word 1 */
  /* was this dumped from a backprop simulator ? */
  GenCode( "\n read(fd, &number, sizeof(int)); /* compiler type */");
  GenCode( "\n if (number != %d) {", ASPIRIN_MAGIC);
  GenCode(
  "\n  sprintf(error_string, \"\\nWarning: This dump file is from another compiler.\\n\");");
  GenCode( "\n }/* end if */");

/* word 2 */
  /* was the version of the compiler compatible?  A mismatch is
     deliberately only a warning, not a failure. */
  GenCode( "\n read(fd, &number, sizeof(int)); /* major version */");
  GenCode( "\n if (number != %d) {", BP_MAJOR_VERSION);
  GenCode(
  "\n  sprintf(error_string, \"\\nWarning: Dump file created with another version (v.%%d) of compiler.\\n\", number);");
  GenCode( "\n }/* end if */");

/* word 3 */
  GenCode( "\n read(fd, &number, sizeof(int)); /* minor version */");
  GenCode( "\n if (number != %d) {", BP_MINOR_VERSION);
  GenCode(
  "\n  sprintf(error_string, \"\\nWarning: Dump file created with another version of compiler.\\n\");");
  GenCode( "\n }/* end if */");

/* reserved bytes */
  /*** this is the difference between 2.0 and 3.0 ***/
  GenCode( "\n /* extra bytes for future use */");
  GenCode( "\n read(fd, extra_bytes, %d);", n_extra_bytes);

/* word 4 */
  /* Read the number of black boxes.  A short read or negative count means
     the file is not usable; stop before sizing a table from garbage. */
  GenCode( "\n if (read(fd, &n_bbs, sizeof(int)) != (int)sizeof(int) || n_bbs < 0) {");
  GenCode(
  "\n  sprintf(error_string, \"\\nUnable to read the number of black boxes from dump file.\\n\");");
  GenCode( "\n  return((TOC_PTR)NULL);");
  GenCode( "\n }/* end if */");
  GenCode(
	  "\n /* create a table of black boxes => location in file */");
  /* the entries are off_t, which may be wider than int */
  GenCode( "\n address_table = (off_t *)am_alloc_mem(n_bbs * sizeof(off_t));");
  /* Each entry gets the running byte offset: header + size table + the
     data sizes of all preceding black boxes. */
  GenCode(
    "\n header_size += n_bbs * sizeof(int); /* add the address_table size */");
  GenCode( "\n for (counter = 0; counter<n_bbs; counter++) {");
  GenCode( "\n  address_table[counter] =  header_size;");
  GenCode( "\n  read(fd, &number, sizeof(int));");
  GenCode( "\n  header_size += number;");
  GenCode( "\n }/* end for */");

  GenCode( "\n table_of_contents = (TOC_PTR)am_alloc_mem(sizeof(TOC_STRUCT));");
  GenCode( "\n table_of_contents->size = n_bbs;");
  GenCode( "\n table_of_contents->address_table = address_table;");
  GenCode( "\n return(table_of_contents);");
  GenCode( "\n\n}/* end read_header */");
}/* end _declare_read_header */


/* declare_load_bb: generate an if test for each bb name */
/* declare_load_bb: Emit one `if (strcmp("<bb name>", name) == 0) { ... }`
   clause of the generated load_black_box function.

   Starting at the current position of fd, the clause reads:
   - the iteration counter;
   - the layer count, which must equal bbd->n_layers;
   - for each stored layer, its name and size, which are offered to
     every layer's threshold and reflection-weight readers;
   - the connection count and, for each stored connection, its two
     names and size, which are offered to every connection's reader.
   A stored name that no reader accepts (not_found still set) yields
   DFERROR.  The clause returns 0 after reading everything.

   bbd: black box descriptor whose loader clause is emitted. */
static declare_load_bb(bbd)
     BD_PTR bbd;
{
  extern FILE *stream;
  char *bb_name = bbd->name;
  int bb_num = bbd->number;
  LD_PTR lp;
  CD_PTR cp;
  int order;

  GenCode( "\n if (strcmp(\"%s\", name) == 0) {",
	  bb_name);

  /* iteration counter */
  GenCode( "\n  read(fd, &b%dbcounter, sizeof(int)); /* iteration counter */",
	  bb_num);

  /* the stored layer count must agree with the compiled network */
  GenCode( "\n  read(fd, &n_layers, sizeof(int)); /* n layers of thresholds */");
  GenCode( "\n  if (n_layers != %d) {", bbd->n_layers);
  GenCode( "\n   sprintf(error_string, \"\\nError in reading %s\\n\");",
	  bb_name);
  GenCode( "\n   return(DFERROR);");
  GenCode( "\n  }/* end  if */");

  /* generated loop over the stored layers */
  GenCode( "\n  /* load all thresholds and reflection weights */");
  GenCode( "\n  for(counter = 0; counter<n_layers; counter++) {");
  GenCode( "\n   /* read the layer name and data size */");
  GenCode( "\n   BPread_string(fd, name_string);");
  GenCode( "\n   read(fd, &size, sizeof(int));");
  GenCode( "\n   not_found = 1; /* reset flag (0 if read the thresholds) */");

  /* one threshold reader and one reader per reflection order per layer */
  for (lp = bbd->layers; lp != (LD_PTR)NULL; lp = lp->next) {
    GenCode( "\n   if (error_code = BPread_thresholds(fd,name_string,\"%s\",size,%d,b%d_l%d_t,&not_found))",
	    sym_name(lp->name),
	    lp->n_nodes,
	    bb_num,
	    lp->number);
    GenCode( "\n      return(error_code);");

    /* highest order first, the same order the dump writes them */
    for (order = lp->layer_order; order != 0; order--) {
      GenCode( "\n   if (error_code = BPread_reflection_weights(fd,name_string,\"%s\",size,%d,b%d_l%d_r%d,&not_found))",
	      sym_name(lp->name),
	      lp->n_nodes,
	      bb_num, lp->number,
	      order);
      GenCode( "\n      return(error_code);");
    }
  }
  GenCode( "\n   if(not_found) {");
  GenCode( "\n     sprintf(error_string, \"\\nUnknown layer name in file: %%s\\n\", name_string);");
  GenCode( "\n     return(DFERROR);");
  GenCode( "\n   }/* end if not found */");
  GenCode( "\n  }/* end for */");

  /* generated loop over the stored connection arrays */
  GenCode( "\n  read(fd, &n_connections, sizeof(int));");
  GenCode( "\n  for(counter = 0; counter<n_connections; counter++) {");
  GenCode( "\n   /* to layer */");
  GenCode( "\n   BPread_string(fd, name_string);");
  GenCode( "\n   /* from layer */");
  GenCode( "\n   BPread_string(fd, name_string2);");
  GenCode( "\n   read(fd, &size, sizeof(int));");
  GenCode( "\n   not_found = 1; /* reset flag (0 if read the weights) */");

  /* one weight reader per connection of every layer */
  for (lp = bbd->layers; lp != (LD_PTR)NULL; lp = lp->next)
    for (cp = lp->inputs_from; cp != (CD_PTR)NULL; cp = cp->next) {
      GenCode( "\n   if (error_code = BPread_weights(fd,name_string,\"%s\",name_string2,\"%s\",size,%d,b%d_%s,&not_found))",
	      sym_name(cp->to),
	      cp->array_name,
	      _connection_size(cp,bbd),
	      bb_num, cp->array_name);
      GenCode( "\n      return(error_code);");
    }
  GenCode( "\n   if(not_found) {");
  GenCode( "\n     sprintf(error_string, \"\\nUnknown connection array in file\\n\");");
  GenCode( "\n     return(DFERROR);");
  GenCode( "\n   }/* end if */");
  GenCode( "\n  }/* end for */");

  GenCode( "\n return(0);");
  GenCode( "\n }/* end if */");
}/* end declare_load_bb */

/* _declare_load_black_box: Declare code to load a black box
                           from a file (random access) */
/* _declare_load_black_box: Emit the static
   load_black_box(fd, name, key, toc) function, which loads a black box
   from a file by random access.

   The emitted function walks toc->address_table. At each recorded
   offset it reads the black box name stored there, stopping when that
   name equals `key`; if no entry matches it returns DFERROR.  It then
   dispatches on `name` to the per-black-box clauses emitted by
   declare_load_bb, each of which returns when done.  If `name` matches
   none of them, the emitted function sets error_string and returns
   DFERROR, rather than reaching the end of a non-void function with no
   return value.

   network: network descriptor whose black boxes get loader clauses. */
void _declare_load_black_box(network)
     ND_PTR network;
{
  extern FILE *stream;

  GenCode(
	  "\n\nstatic load_black_box(fd, name, key, toc)");
  GenCode(
	  "\n int fd;\n char *name, *key;\n TOC_PTR toc;\n{");
  GenCode( "\n int table_size = toc->size;");
  GenCode( "\n off_t *address_table = toc->address_table;");
  GenCode( "\n off_t *end_of_table;");
  GenCode( "\n int n_layers, n_connections, size, counter, error_code, not_found;");
  /* NOTE(review): the 100-byte name buffers rely on BPread_string never
     storing more than that; its bound is not visible here -- confirm. */
  GenCode( "\n char name_string[100], name_string2[100]; /* for reading names */\n");
  GenCode( "\n /* record end of address table */");
  GenCode( "\n end_of_table = address_table + table_size;");

  GenCode( "\n  /* find key name in the file */");
  GenCode( "\n  do {");
  GenCode( "\n   if (address_table == end_of_table) {");
  GenCode( "\n    sprintf(error_string, \"\\nUnable to find %%s as %%s.\\n\", name, key);");
  GenCode( "\n    return(DFERROR);");
  GenCode( "\n   }/* end if */");
  GenCode( "\n   lseek(fd, *address_table++, 0);");
  GenCode( "\n   BPread_string(fd, name_string);");
  GenCode( "\n  }while (strcmp(key, name_string) != 0);\n");

  /* one `if (strcmp(...) == 0) { ... return ...; }` clause per black box */
  map_dictionary(declare_load_bb, network->lookup);

  /* reached only when `name` is not a black box of this network */
  GenCode( "\n sprintf(error_string, \"\\nUnknown black box %%s.\\n\", name);");
  GenCode( "\n return(DFERROR);");

  GenCode( "\n}/* end load_black_box */");
}/* end _declare_load_black_box  */

/* declare_call_load_bb: Calls to random access finds of black box info. */
/* declare_call_load_bb: Emit, inside the generated <file>_load_network
   function, a random-access load of one black box (stored under its own
   name).  On a nonzero error code the emitted code closes the open
   descriptor `fd` before returning the code, so an error does not leak
   the descriptor.

   bbd: black box descriptor to load. */
static declare_call_load_bb(bbd)
     BD_PTR bbd;
{
  extern FILE *stream;

  GenCode( "\n if ((error_code = load_black_box(fd, \"%s\", \"%s\", toc)) != 0) {",
	  bbd->name, bbd->name);
  GenCode( "\n\tclose(fd);");
  GenCode( "\n\treturn(error_code);");
  GenCode( "\n }/* end if */");
}/* end declare_call_load_bb */

/* _declare_load: Declare the function that loads the weights from disk. */
/* _declare_load: Emit the public <file>_load_network(filename) function.

   The emitted function opens `filename` read-only and reads the header
   and table of contents with read_header.  It then loads every black
   box of the network (code from declare_call_load_bb).  Return values:
   - FERROR if the file cannot be opened;
   - DFERROR if read_header returns NULL (the descriptor is closed first);
   - otherwise the first nonzero code from load_black_box, or 0.

   NOTE(review): toc and its address_table come from am_alloc_mem and
   are not released here; the matching free routine is not visible in
   this file.

   network: network descriptor (black box dictionary).
   file:    prefix used to name the emitted function. */
void _declare_load(network, file)
     ND_PTR network;
     char *file;
{
  extern FILE *stream;

  GenCode(
	  "\n\nint %s_load_network(filename)\n  char *filename;\n{",
	  file);
  /* error_code is used by the calls emitted via declare_call_load_bb */
  GenCode( "\n int fd, error_code;");
  GenCode( "\n TOC_PTR toc; /* table of contents */\n");

  GenCode( "\n\n /* open the file */");
  GenCode( "\n fd = open(filename, 0);");
  GenCode( "\n if (fd == -1) {");
  GenCode(
	  "\n  sprintf(error_string, \"\\nUnable to open %%s.\\n\", filename);");
  GenCode( "\n  return(FERROR);");
  GenCode( "\n }/* end if */");

  GenCode( "\n /* read header */");
  GenCode( "\n toc = read_header(fd);");
  /* close the descriptor when the header is rejected */
  GenCode( "\n if (toc == (TOC_PTR)NULL) {");
  GenCode( "\n  close(fd);");
  GenCode( "\n  return(DFERROR);");
  GenCode( "\n }/* end if */");

  /* read in each black box */
  map_dictionary(declare_call_load_bb, network->lookup);

  GenCode( "\n close(fd);");
  GenCode( "\n return(0);");
  GenCode( "\n}/* end %s_load_network */", file);
}/* end _declare_load */

/* _declare_load_data: Function is like load_network but just loads
                       the black box passed.
 */		      
/* _declare_load_data: Emit the static
   load_black_box_data(name, key, filename) function.  Like
   <file>_load_network, but it loads only the black box stored under
   `key` in the file into the network black box `name`.

   The emitted function returns:
   - FERROR if the file cannot be opened;
   - DFERROR if read_header returns NULL;
   - otherwise load_black_box's result.
   After a successful open the descriptor is closed on every path.

   network: unused; kept for interface compatibility. */
void _declare_load_data(network)
     ND_PTR network;
{
  extern FILE *stream;

  GenCode(
  "\n\nstatic int load_black_box_data(name, key, filename)\n  char *name, *key, *filename;\n{");

  GenCode( "\n int fd, error_code;");
  GenCode( "\n TOC_PTR toc;");

  GenCode( "\n\n /* open the file */");
  GenCode( "\n fd = open(filename, 0);");
  GenCode( "\n if (fd == -1) {");
  GenCode(
	  "\n  sprintf(error_string, \"\\nUnable to open %%s.\\n\", filename);");
  GenCode( "\n  return(FERROR);");
  GenCode( "\n }/* end if */");

  GenCode( "\n /* read header */");
  GenCode( "\n toc = read_header(fd);");
  GenCode( "\n if (toc == (TOC_PTR)NULL) {");
  GenCode( "\n  close(fd);");
  GenCode( "\n  return(DFERROR);");
  GenCode( "\n }/* end if */");

  /* keep the result so the descriptor can be closed before returning */
  GenCode( "\n error_code = load_black_box(fd, name, key, toc);");
  GenCode( "\n close(fd);");
  GenCode( "\n return(error_code);");

  GenCode( "\n\n}/* end load_black_box_data */");
}/* end _declare_load_data */

