/* ======================================================================
   A simple implementation of a Hopfield Net.
   Essentially simplest-hop.c plus some error checking.
   ====================================================================== */
#include <stdio.h>
#include <math.h>
#include "sim-basics.h"

/* Unit threshold used by StepHop: a unit whose net input is above U_th
   turns on (1), below U_th turns off (0), and exactly at U_th keeps its
   current state. */
#define U_th 	0		/* Unit's Threshold */ 
/* Name of the command being executed, used as the prefix of error
   messages.  Expands to argv[0], so it is only usable inside command
   handlers that receive argv. */
#define SCOM 	argv[0]

/* Network state.  Arrays are statically sized; only the first Nunits
   entries (rows/columns) are in use.  StorePats accepts at most
   MAX_UNITS-1 units. */
#define MAX_UNITS 500		/* maximum no of units */
int 	Nunits=0;		/* current number of units in network */
int	output[MAX_UNITS];	/* output of units */
int	temp_output[MAX_UNITS];	/* next states, buffered by StepHop */
int	flag[MAX_UNITS];	/* Flip: TRUE while a unit is still unflipped */
int	T[MAX_UNITS][MAX_UNITS];/* connection strengths: T[i][j] weights unit j's output into unit i */

/* Stored patterns, loaded by StorePats; pat[p][u] is bit u of pattern p. */
#define MAX_PATS 50		/* maximum no of patterns  */
int	Npats=0;		/* current number of stored pats */
int	pat[MAX_PATS][MAX_UNITS]; /* stored pats,for comparisons */

/*======================================================================*/
#define HStepHop "step the network synchronously"
/*======================================================================*/
/*   calculate the activation for each unit, determine its future
     state and place it in temp_output array.
     Finally, copy the future states into the current states.
*/     
StepHop()
{
  int unit, from, act;
  
  for( unit=0; unit < Nunits; ++unit ) {
    for( from=act=0; from < Nunits; ++from ) 
      act += T[unit][from] * output[from];
    temp_output[unit] = (act > U_th ? 1 : act < U_th ? 0 : output[unit]);
  }

  for( unit=0; unit < Nunits; ++unit )
    output[unit] = temp_output[unit];
  
  return OK;
}

/*===========================================================================*/
#define HStorePats "MFILE\n\
  read the patterns in MFILE and set the T weights according to Hopfield's\n\
  storage algorithm"
/*===========================================================================*/
/* The pattern file is expected to be of the following form:
       Nunits
       b b b b b b ...b b b b b 
       b b b .........b b b b b
          ..............
       b b b .........b b b b b
                              ^Num zeros and ones
   Where Nunits is the dimensionality of the patterns and 
   each b is either zero or one.
*/
#define MFILE argv[1]
StorePats( argc, argv )
     int argc; char *argv[];
{
  int kf, i, j, unit, npat, Tij;
  FILE *file;

  /* try opening the indicated file */
  if( NULL == (file=fopen( MFILE, "r" )))
     nerror2( return OK, "%s: Can't find file '%s'\n", SCOM, MFILE );
     
  /* Read the Dimensionality of the patterns and check 0 <= Nunits < MAX_UNITS */
  if( 1 != fscanf( file," %d", &unit) 
      || unit < 0 
      || unit >= MAX_UNITS ){
    fclose(file);
    nerror3( return OK, "%s: expected int < %d. corrupt pattern file '%s'?\n", 
	     SCOM, MAX_UNITS, MFILE );
  }
  Nunits = unit;

  printf( "reading patterns\n" );
  npat = unit = 0;
  while( EOF != ( kf = fscanf( file, " %d ", &(pat[npat][unit])) )
	 && kf==1 				/* check if read was succesful */
	){
    printf( "%d", pat[npat][unit] );
    ++ unit;
    if( unit % Nunits == 0 ) { 			/* end of this pattern */
      printf( " : pat %d\n", npat );		/* say so */
      ++npat;					/* go for next pattern */
      unit = 0; 				
      if( npat >= MAX_PATS ) 			/* check max pattern limit */
	nerror1( break, "%s: Too many patterns. Rest ignored\n", SCOM );
    }
  }
  fclose( file );
  if( unit % Nunits != 0 || kf != EOF ) { 	/* Error while reading file */
    Nunits = Npats = 0;
    nerror2( return OK, "%s: corrupt pattern file '%s'?\n", SCOM, MFILE );
  }
  Npats = npat;
  
  /* Calculate the T weights using Hopfield's algorithm */
  for( i=0; i < Nunits; ++i ) {
    T[i][i] = 0;
    for( j=0; j < i; ++j ) {
      for( npat=Tij=0; npat < Npats; ++npat )
	Tij += (2 * pat[npat][i] - 1 ) * (2 * pat[npat][j] - 1 );
      T[i][j] = T[j][i] = Tij;
    }
  }
  return OK;
}

/*===========================================================================*/
#define HSetPat "PAT_NO\n\
    set the state of the network to that of pattern PAT_NO"
/*===========================================================================*/
#define  PAT_NO 1
SetPat( argc, argv )
     int argc; char *argv[];
{
  int PatNo, unit;

  SCAN_POSP_INT_ARGV(return OK, PAT_NO, PatNo );
  if( PatNo >= Npats )
    nerror2( return OK, "%s: '%d' invalid memory no\n", SCOM, PatNo );
  for( unit=0; unit < Nunits; ++unit )
    output[unit] = pat[PatNo][unit];

  return OK;
}

/*===========================================================================*/
#define HFlip "FLIP_NO [SEED]\n\
    select randomly FLIP_NO units, and flip their state. SEED is optional\n\
    and used to be able to reproduce random changes."
/*===========================================================================*/
#define FLIP_NO 1
#define SEED 2
Flip( argc, argv )
     int argc; char *argv[];
{ 
  int unit, FlipNo, fliped, Seed;

  SCAN_POSP_INT_ARGV( return OK, FLIP_NO, FlipNo );
  if( FlipNo > Nunits )
    nerror2( return OK,"%s: '%d', Can't flip so many units\n", SCOM, FlipNo );

  if( argc > SEED ) {
    SCAN_INT_ARGV( return OK, SEED, Seed );
    set_seed( Seed );
  }    

  for( unit=0; unit < Nunits; ++unit )
    flag[unit] = TRUE;
    
  fliped = 0;
  while( fliped < FlipNo ) {
    unit = int_uniform_dev( 0, Nunits-1 );
    if( flag[unit] ) {
      ++fliped;
      output[unit] = (1 - output[unit]);
      flag[unit] = FALSE;
    }
  }
  return OK;
}


/*===========================================================================*/
#define HShowState "print the state of the units in the network"
/*===========================================================================*/
ShowState()
{ 
  int unit;

  for( unit=0; unit < Nunits; ++unit ) 
    printf( "%d", output[unit] );
  printf( "\n" );

  return OK;
}

/*===========================================================================*/
#define HComparePat "PAT_NO\n\
    Compare the state of the network to that of pattern PAT_NO.\n\
    places where they disagree are shown as a '*'"
/*===========================================================================*/
#define PAT_NO 1
ComparePat( argc, argv )
     int argc; char *argv[];
{
  int PatNo, unit, count;

  SCAN_POSP_INT_ARGV(return OK, PAT_NO, PatNo );
  if( PatNo >= Npats )
    nerror2( return OK, "%s: '%d' invalid memory no\n", SCOM, PatNo );

  for( unit=count=0; unit < Nunits; ++unit ) 
    if( output[unit] == pat[PatNo][unit] ) {
      ++count;
      printf( "%d",output[unit] );
    }
    else printf( "*" );
  printf( " : %3d\n", count );

  return OK;
}

/*===========================================================================*/
#define HComparePats "Compare the state of the network to that of every\n\
    stored pattern."
/*===========================================================================*/
ComparePats()
{
  static char *iargv[] ={ "ComparePat", "number" };
  int 	pat;

  for( pat=0; pat < Npats; ++pat ) {
    sprintf( iargv[1],"%d", pat );
    ComparePat( 2, iargv );
  }
  return OK;
}

/* ======================================================================
   Install the above defined commands with their help descriptions 
   (and their abbreviations) and call the interpreter loop, which reads
   commands from stdin.
   ======================================================================*/
int main( void )
{
  ba_install_cmd( "StepHop", 	StepHop,	HStepHop );
  ba_install_cmd( "hop", 	StepHop,	"abrev for 'StepHop'" );
  ba_install_cmd( "StorePats", 	StorePats,	HStorePats );
  ba_install_cmd( "store",	StorePats,	"abrev for 'StorePats'" );
  ba_install_cmd( "SetPat", 	SetPat,		HSetPat );
  ba_install_cmd( "set", 	SetPat,		"abrev for 'SetPat'" );
  ba_install_cmd( "Flip", 	Flip,		HFlip );
  ba_install_cmd( "f", 		Flip,		"abrev for 'Flip'" );
  ba_install_cmd( "ShowState", 	ShowState,	HShowState );
  ba_install_cmd( "sh", 	ShowState,	"abrev for 'ShowState'" );
  ba_install_cmd( "ComparePat", ComparePat,	HComparePat );
  ba_install_cmd( "cp", 	ComparePat,	"abrev for 'ComparePat'" );
  ba_install_cmd( "ComparePats",ComparePats,	HComparePats );
  ba_install_cmd( "cps", 	ComparePats,	"abrev for 'ComparePats'" );

  ba_start_msg( "Simple Hopfield Net Simulator." );
  ba_prompt( "SHoP> " );
  ba_interpreter_loop( stdin );

  /* explicit status: falling off an implicit-int main is undefined in C89 */
  return 0;
}
