#include "net.h"
#include "bp.h"


 // Rescorla-Wagner model of classical conditioning.
 // Able to model extinction and blocking.
 // Based on: Rescorla, R.A. and Wagner, A.R. (1972). A theory of Pavlovian conditioning:
 // variations in the effectiveness of reinforcement and non-reinforcement.
 // In: Black, A.H. and Prokasy, W.F. (Eds.), Classical Conditioning II: Current Research
 // and Theory (pp. 64-99). New York: Appleton-Century-Crofts.


/* Number of units in each unit set (passed to MakeSet in Build). */
enum {
  NoCS  = 5,   /* conditioned-stimulus units   */
  NoUCS = 1,   /* unconditioned-stimulus units */
  NoCR  = 1    /* conditioned-response units   */
};


/* Learning-rate factors multiplied together in the R_W weight update.
 * Presumably alpha (CS salience) and beta (UCS-dependent rate) of the
 * Rescorla-Wagner model. */
float a = 0.3;
float b = 0.3;

 /*********************  Set definitions *********************/

/* Handles to the three unit sets, created in Build():
 *   UCS -- unconditioned stimulus; its output is copied into the CR unit's
 *          Free2 by ComputeOutputCR and used as the R_W teaching signal
 *   CS  -- conditioned stimuli, driven from the pattern tables by SetPattern
 *   CR  -- conditioned response; its incoming links are trained by R_W */
SetPtr UCS, CS, CR;

   /*************** Input pattern definitions *****************/

/* Stimulus patterns, laid out as { UCS, CS1, CS2, CS3, CS4 }.
 * SetPattern copies element 0 to the UCS unit and element i (i >= 1)
 * to CS unit i. */
int P0[5] = { 1,  0, 1, 1, 0 };   /* UCS present; CS2 + CS3 */
int P1[5] = { 0,  0, 0, 1, 0 };   /* UCS absent;  CS3 alone */
int P2[5] = { 1,  0, 0, 1, 1 };   /* UCS present; CS3 + CS4 */
int P3[5] = { 0,  1, 1, 0, 0 };   /* UCS absent;  CS1 + CS2 */


/* Copy one stimulus pattern onto the input units.
 *
 * p[0] drives the single UCS unit; p[i] (i >= 1) drives CS unit i.
 * CS unit 0 is never written here, so it keeps its current output
 * (reset applies ZeroOutput to the CS set).
 * NOTE(review): confirm that leaving CS unit 0 unused is intentional and
 * not an off-by-one in the CS mapping.
 *
 * p must point at one of the pattern tables (P0..P3), which are all declared
 * with PATTERN_LENGTH entries. The loop is bounded by both the CS set size
 * and the pattern length. Without that bound, raising NoCS would make this
 * read past the end of p (undefined behaviour).
 */
void SetPattern (int p[])
 {
 enum { PATTERN_LENGTH = 5 };   /* must match the declared size of P0..P3 */
 int i;
 for (i = 1; i < CS->Length && i < PATTERN_LENGTH; i++)
   SetMember (CS, i)->Output = p[i];
 SetMember (UCS, 0)->Output = p[0];
 }

/* Apply input pattern number u (0..3 selects P0..P3) via SetPattern.
 * Any other value of u is ignored and leaves the inputs unchanged. */
void SetInputPattern (int u)
  {
  static int *const table[] = { P0, P1, P2, P3 };
  const int count = (int) (sizeof table / sizeof table[0]);

  if (u < 0 || u >= count)
    return;
  SetPattern (table[u]);
  }

   /*****************  'Traverse' functions ****************/

/* Forward pass for one CR unit (applied to the CR set by Run).
 *
 *   Netinput <- WeightedSum(unit), presumably the sum of weight * source
 *               output over incoming links: the summed associative strength
 *               V of the CSs presented on this trial.
 *   Free2    <- output of the UCS unit, cached as the teaching signal
 *               (lambda) that R_W compares against Netinput.
 *   Output   <- lambda + V.
 */
void ComputeOutputCR(UnitPtr unit)
  {
  unit->Netinput = WeightedSum (unit);           /* V      */
  unit->Free2    = SetMember (UCS, 0)->Output;   /* lambda */
  unit->Output   = unit->Netinput + unit->Free2;
  }

/* Rescorla-Wagner weight update for one link into a CR unit:
 *
 *     dW = a * b * (lambda - V) * x
 *
 * lambda is the UCS value cached in the target unit's Free2, V is the
 * target's Netinput (both written by ComputeOutputCR, which Run calls first),
 * and x is the source unit's output. A source with output 0 (an absent CS)
 * therefore leaves its weight unchanged.
 */
void R_W (LinkPtr link)
  {
  link->Weight = link->Weight
               + a * b * (link->To->Free2 - link->To->Netinput)
                       * link->From->Output;
  }

/****************** Commands ******************/

/* "reset" command (registered in Build via InstallCommand).
 * Applies ZeroOutput to every UCS and CS unit and zeroes the CR unit's
 * output and net input. It then applies ZeroWeight to every link traversed
 * from CR, presumably clearing all learned associative strength.
 * Note that this does not restore the initial weights set by ConnectSets. */
void reset ()
  {
  TraverseSet (UCS, ZeroOutput);
  TraverseSet (CS, ZeroOutput);

  SetMember (CR, 0)->Netinput = 0.0;
  SetMember (CR, 0)->Output = 0.0;

  TraverseSetLinks (CR, ZeroWeight);
  }


/****************** 'Build' & 'Run' ******************/

/* Network construction hook for the simulator framework (presumably called
 * once at start-up; confirm in net.h).
 * Creates the UCS, CS and CR unit sets, registers the "reset" command, and
 * connects the CS set to the CR set. Statement order is kept as-is because
 * these framework calls may depend on each other.
 */
void Build ()
  {
  Squaresize=60;   /* framework global; presumably the display cell size -- confirm */
  UCS = MakeSet ("UCS", 1, NoUCS);   /* NOTE(review): args look like (name, rows, columns) -- confirm */
  CS = MakeSet ("CS", 1, NoCS);
  CR = MakeSet ("CR", 1, NoCR);

  InstallCommand ("reset", reset);   /* exposes reset() as the user command "reset" */

  ConnectSets (CS, CR, 0.2, -0.2);   /* NOTE(review): 0.2 / -0.2 presumably bound the initial weights -- confirm */
  SetLinkParam (SetMember (UCS,0),SetMember (CR,0), FREE3, 0.0, 0.0);   /* NOTE(review): meaning of FREE3 on the UCS->CR pair is not visible here */
  }

/* One simulation step, invoked by the framework:
 *   1. ComputeOutputCR on each CR unit. It caches the UCS value (lambda) in
 *      Free2, stores the weighted input sum in Netinput, and sets Output.
 *   2. R_W on each link traversed from CR (presumably the CS -> CR links),
 *      applying the Rescorla-Wagner weight update.
 * The order matters: R_W reads the Free2 / Netinput values written in step 1.
 */
void Run ()
  {
  TraverseSet (CR, ComputeOutputCR);
  TraverseSetLinks (CR, R_W);
  }



