;;; ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
;;;
;;;     			     ANSIL
;;; 		      Advanced Network Simulator In Lisp
;;; 		      ----------------------------------
;;;
;;;				  Written By
;;;			       Peter J. Angeline
;;;			      Gregory M. Saunders
;;;
;;;			   The Ohio State University
;;;		       Laboratory for AI Research (LAIR)
;;;			     Columbus, Ohio, 43210
;;;
;;;			      Copyright (c) 1991
;;;
;;; ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


;;; ---------------------------------------------------------------------- 
;;; File: raam1.lisp -  Setting up & training raams.
;;;
;;; License: This code may be distributed free of charge.  This code may not be
;;; incorporated into any production software of any type or any code which is
;;; sold without the express written consent of the authors.  Modifications and
;;; inclusions for other types of software are permitted, provided that this
;;; header remains unchanged and included in all subsequent versions of the
;;; code.
;;;
;;; Disclaimer: The authors of this code make no claims as to its
;;; validity, and are in no way responsible for any errors contained within.
;;;
;;; Creation date: Aug 5, 1990
;;; Version: 0.95
;;; ----------------------------------------------------------------------
;;;
;;;
;;;                         Change Log
;;;                         ----------
;;;
;;; 10/23/90 - (pja) changed set-raam to do the keys correctly.
;;;
;;; 10/26/90 - (pja) added John's suggestion to initialize the raam to have fft
;;; related coefficients for the encoder and decoder.  Added function
;;; raam-fft-init.  
;;;
;;; 11/21/90 - (pja) added a terminal test to RAAM.  The terminal test involves
;;; the addition of a bit for each distinct symbol which is decoded from the
;;; hidden rep.  Thus if valence is 3 then 3 extra bits will be added to the
;;; output of the raam.  The job of the decoder now becomes to not only
;;; reconstruct the n children of this non-terminal but to identify which are
;;; terminals and which are not.  All representations being stored in the tree
;;; will be without this bit, it will be a phenomena of the execution only.
;;;
;;; 11/23/90 - (pja) Made a reconstruction option for termination in the
;;; raam-train routine.  If the key RECONSTRUCT is non-null and non-zero then
;;; the termination condition of the raam will be that all of the trees can be
;;; reconstructed accurately.  Since we only are worried about reconstructing
;;; into  the correct tree of terminals, the reconstruction will be complete
;;; when each tree is of the correct shape and each terminal reconstructed is
;;; within the raam terminal tolerance of the terminal desired.
;;;
;;; 12/18/90 - (gms) Removed default representation width and default valence.
;;;
;;; 02/06/91 - (gms) Expanded code to include stack raams.  New function for
;;; user is "create-stack", which is very similar to "create-raam".  There was
;;; a small annoyance in generalizing the code to stacks, namely, that before
;;; we had a single representation width (rep-width).  For stacks, we need an
;;; additional slot in the raam structure (terminal-width) to allow terminals
;;; and nonterminals to have different widths.  This change propagated itself
;;; throughout the code, especially in raam-train-pattern, raam-count-errors,
;;; and raam-count-errors-for-subtree.  The decode function was updated in a
;;; similar manner.  See further comments in those functions.
;;;
;;; 02/18/91 - (pja) Fixed raam-train so when a tree is within tolerance but
;;; not reconstructed it keeps training until reconstructed.  Also, altered
;;; termination test to work with stack code.  Also, changed 'empty to '*empty*
;;; to avoid possible conflict with users.
;;;
;;; 02/26/91 - (gms) Added code to reconstruct trees from their reduced
;;; representations.  New functions:  reconstruct-tree, reconstruct-trees,
;;; reconstruct-stack-tree, reconstruct-stack-trees.  Also added function to
;;; write out the representations to a file (write-raam-reps), and updated
;;; symbol-name-by-value. 
;;;
;;; 02/28/91 - (gms) Moved show-vector, format-vector, and format-vectors from
;;; this file to matrixbp.lisp.  See comments there for reasons.  Also added
;;; encode function to encode a symbolic tree from scratch.  We really should
;;; split up such junk into another file, but I don't feel like updating the
;;; manual.
;;;
;;; 02/28/91 - (gms) Split raam into two files (main code & analysis).  Also
;;; changed nil to '*empty* in raam creation routines.  Then replaced '*empty*
;;; with the variable *nilvar*, w/ value '*empty*.
;;;
;;; 03/26/91 - (pja) Added the routine REVERT-RAAM-NETWORK which resets the
;;; raam to the state it was in prior to training.  Both the initial conditions
;;; of the weights of the network and the number of epochs is reverted.
;;;
;;; 04/02/91 - (pja) Altered the structure and code so the choice of whether to
;;; weight each tree and subtree in the training set the same or by the number
;;; of appearances in the full training set (i.e. before removing duplicates).
;;; Added the slots raam-pattern-copies and raam-pattern-repeat to the raam.
;;; Added :pattern-repeat to create-raam and set-raam as well.
;;;
;;; 04/09/91 - (pja) Added some logic so that the raam-compute-errors is only
;;; done when we want to write something out to the terminal.  This saves a
;;; little time. 
;;;
;;; 04/16/91 - (pja) Allowed the key NIL-VALUE to be nil in create-raam which
;;; prevents *empty* from being added to the symbol table.  
;;;
;;; 04/18/91 - (pja) Added ability to designate a specific value for a
;;; non-terminal of a raam.  raam-fix-tree-rep is a function which takes a
;;; raam, a tree and a list of values and stores that value for the rep of the
;;; tree in the symbol table.  The raam structure will has a new slot called
;;; static-table which stores the static rep of each tree for which the rep is
;;; to be fixed.  Training of a fixed tree is in two parts.  First, the static
;;; rep of the tree is placed at the middle layer and forward proped to the
;;; output layer.  The just the weights in the decoder are updated.  Next, the
;;; weights in the encoder ar update but forward proping the actual inputs to
;;; the middle layer and the fixed rep used as a target value.  
;;; ----------------------------------------------------------------------


;;; ----------------------------------------------------------------------
;;;
;;; Package info, exported functions, etc.
;;;
;;; ----------------------------------------------------------------------

(proclaim '(optimize speed))

(provide 'mraam)
(in-package 'mraam)

(export '(create-raam create-stack raam-training-set raam-network
	  raam-train raam-train-pattern raam-symbol-table 
	  symbol-name-by-index symbol-value-by-index 
	  symbol-index-by-name symbol-value-by-name
	  symbol-name-by-value list-to-vector
	  raam-add-tree raam-remove-tree set-rates
	  randomize-terminals set-terminal set-terminals 
	  set-raam raam-rep-width raam-fft-init coerce-vector
	  raam-symbol-table raam-total-epochs raam-name
	  randomize-raam-network revert-raam-network raam copy-raam copy-object
	  raam-network *empty* verify raam-pattern-copies raam-pattern-repeat
	  raam-fix-tree-rep raam-train-fixed-pattern raam-static-table))


(require 'loop)
(require 'linear)
(require 'mbp)

(use-package 'loop)
(use-package 'linear)
(use-package 'mbp)


;;; ----------------------------------------------------------------------
;;;
;;; Constants
;;;
;;; ----------------------------------------------------------------------

;; Numeric constants shared by the raam code.  They are wrapped in EVAL-WHEN
;; so they exist at compile time: the structure definitions below splice
;; ZERO and *ELEMENT-TYPE* into their slot declarations.  The ANSI situation
;; keywords replace the deprecated COMPILE/LOAD names, and :EXECUTE is added
;; so the constants are also defined when this file is LOADed as source
;; (without it the body is skipped and every later reference is unbound).
(eval-when (:compile-toplevel :load-toplevel :execute)
  (defconstant *element-type* mbp::*element-type*)
  (defconstant zero (coerce 0 *element-type*))
  (defconstant half (coerce 0.5 *element-type*))
  (defconstant  one (coerce 1 *element-type*))
  ;; Filler symbol used for missing children (see EXTEND) and the bottom
  ;; of stacks.
  (defconstant *nilvar* '*empty*))

;; Network level whose activations become a tree's new representation.
;; RAAM-TRAIN-FIXED-PATTERN also writes fixed reps into this level, which its
;; comments call the middle/center layer.
(defconstant *new-rep-level* 1
  "Raam level from which new reps will be read.")

;; Element value that marks a representation as not yet defined
;; (see DUMMY-VECTOR and DUMMY-VECTOR?).
(defconstant *dummy-value* (coerce -1 *element-type*)
  "Dummy value for initial representations.")

;;; ----------------------------------------------------------------------
;;;
;;; Default raam parameters
;;;
;;; ----------------------------------------------------------------------

(defvar *default-name* "GENERIC-RAAM"
  "Name given to a raam when no :NAME is supplied.")

(defvar *default-nil-value* (coerce 0.5 *element-type*)
  "Value of every element of the vector representing *EMPTY* (nil).")

(defvar *default-terminal-tolerance* (coerce 0.2 *element-type*)
  "Default error tolerance for terminals.")

(defvar *default-nonterminal-tolerance* (coerce 0.05 *element-type*)
  "Default error tolerance for non-terminals.")
					


;;; ----------------------------------------------------------------------
;;;
;;; Raam structure
;;;
;;; ----------------------------------------------------------------------

;; The structures are built with backquote + EVAL so that ZERO and
;; *ELEMENT-TYPE* (defined above) are spliced into slot initforms and type
;; declarations.  :EXECUTE is included so they are also defined when this
;; file is loaded as source.
(eval-when (:compile-toplevel :load-toplevel :execute)
 (eval
  `(defstruct raam
     "The \"then some\" part..."
     (valence        0     :type integer) ; trees in the training set require a
					  ; fixed valence
     (rep-width      0     :type integer) ; the number of nodes on the hidden
					  ; layer
     (terminal-width 0     :type integer) ; terminal rep width for stacks
     (raam-width     0     :type integer) ; set automatically to (* valence
					  ; rep-width)
     ;; MAKE-RAAM is called with no arguments and the network is filled in
     ;; afterwards, so the declared type must admit the NIL default.
     (network        nil   :type (or null network)) ; the network associated
					  ; with this raam
     (init-terminals nil   :type list)	  ; the initial representations for the
					  ; terminals
     (term-tol       ,zero :type ,*element-type*) ; tolerance for terminals
     (nonterm-tol    ,zero :type ,*element-type*) ; tolerance for non-terminals
     (name           ""    :type string) ; name for this raam
     (creation-date  ""    :type string) ; time of creation
     (total-epochs   0     :type integer) ; total number of passes through the
					  ; training set
     (org-trees      nil   :type list)	; OT = original trees
     (fixed-trees    nil   :type list)	; FT = OT converted to fixed-valence
     (all-trees      nil   :type list)	; AT = all subtrees of FT
     (symbol-table   nil   :type list)	; Sy = assoc list of all
					; terminals/trees
     (training-set   nil   :type list)	; Training set for the raam, w/ list of
					; indices into symbol table
     (pattern-copies nil   :type list)  ; # of times a pattern appears in
					; training set
     ;; A flag (CREATE-RAAM stores its REPEAT? argument, default T, here),
     ;; so it must not be declared LIST.  If true, the error for a pattern
     ;; is multiplied by the corresponding value in PATTERN-COPIES.
     (pattern-repeat nil   :type t)
     (static-table   nil   :type list)  ; indexes of static tree reps
     (history        nil   :type list))) ; History of changes made to raam

 (eval
  `(defstruct history-record
     (epoch     0     :type integer)
     (attribute nil   :type atom)
     ;; NOTE(review): ADD-TO-HISTORY (defined elsewhere) is called with the
     ;; value T by RANDOMIZE-RAAM-NETWORK; if it stores that here, this
     ;; type declaration would reject it -- confirm.
     (value     ,zero :type ,*element-type*))))

;;; ----------------------------------------------------------------------
;;;
;;; Support functions
;;;
;;; ----------------------------------------------------------------------

(defun extend (x v)
  "Return a copy of list X padded on the right with *NILVAR* up to length V.
X must not be longer than V."
  (let* ((missing (- v (length x)))
         (padding (make-list missing :initial-element *nilvar*)))
    (append x padding)))

(defun round-off (x &optional (p 3))
  "Round X to P decimal places (default 3), coercing the result back to
X's own type."
  (let* ((scale  (expt 10 p))
         (scaled (round (* x scale))))
    (coerce (/ scaled scale) (type-of x))))

(defun copy-object (x)
  "Return a deep copy of X made by printing it to a string and reading it
back.  Raams and networks are then passed to FIX-RAAM-ARRAYS /
MBP::FIX-NETWORK-ARRAYS (defined elsewhere), presumably to restore array
details lost in the print/read round trip."
  (let ((copy (read-from-string (write-to-string x))))
    (when (typep copy 'raam)
      (fix-raam-arrays copy))
    (when (typep copy 'network)
      (mbp::fix-network-arrays copy))
    copy))

(defun coerce-vector (r &optional (element-type *element-type*))
  "Return a fresh one-dimensional simple array of ELEMENT-TYPE containing
the elements of sequence R (a list or vector), each coerced to ELEMENT-TYPE."
  ;; The result type names rank 1 explicitly: (SIMPLE-ARRAY type) alone
  ;; means any rank, which is not a recognizable subtype of VECTOR and so is
  ;; not a portable result type for COERCE/MAP.
  (map `(simple-array ,element-type (*))
       (lambda (x) (coerce x element-type))
       r))

(defun date ()
  "Return the current local time as a string such as
\"Mon Jan 7 09:05:03 1991\" (weekday month day hh:mm:ss year)."
  (multiple-value-bind (sec min hr day-of-month month yr day-of-week)
      (get-decoded-time)
    ;; GET-DECODED-TIME numbers weekdays from 0 = Monday and months from 1.
    (let ((day-names   '("Mon" "Tues" "Wed" "Thurs" "Fri" "Sat" "Sun"))
          (month-names '("Jan" "Feb" "Mar" "Apr" "May" "Jun" "Jul" "Aug" "Sep"
                         "Oct" "Nov" "Dec")))
      (format nil "~a ~a ~a ~2,'0d:~2,'0d:~2,'0d ~a"
              (nth day-of-week day-names)
              (nth (1- month) month-names)
              day-of-month hr min sec yr))))

(defun vector-or-list? (x)
  "True when X is a non-empty list or any vector (strings included)."
  (if (null x)
      nil
      (typep x '(or vector list))))

;;; ----------------------------------------------------------------------
;;;
;;; Vector support
;;;
;;; ----------------------------------------------------------------------

(defun random-vector (n &optional (x one))
  "Return an N-element vector, specialized to the type of X, of values
(RANDOM X) rounded to 3 decimal places."
  (let ((vec (make-array n :element-type (type-of x))))
    ;; Fill in index order so the random numbers are drawn in the same order.
    (dotimes (k n vec)
      (setf (aref vec k) (round-off (random x))))))

(defun constant-vector (n val &optional (elt-type *element-type*))
  "Return an N-element vector of ELT-TYPE with every element VAL coerced to
ELT-TYPE."
  (let ((fill (coerce val elt-type)))
    (make-array n :element-type elt-type :initial-element fill)))

(defun zero-vector (n)
  "Return an N-element vector of *ELEMENT-TYPE* zeros."
  (constant-vector n 0 *element-type*))

(defun dummy-vector (n)
  "Return an N-element vector filled with *DUMMY-VALUE*, the marker for an
undefined representation."
  (constant-vector n *dummy-value* *element-type*))

(defun dummy-vector? (x)
  "True when X is EQUALP to a dummy vector of the same length, i.e. every
element equals *DUMMY-VALUE*."
  (let ((n (length x)))
    (equalp (dummy-vector n) x)))

;;; ----------------------------------------------------------------------
;;;
;;; Functions dealing with terminals
;;;
;;; ----------------------------------------------------------------------

(defun terminals (x)
  "Return the distinct terminals (atoms) of tree X, excluding NIL and any
parts equal to *NILVAR*."
  (let ((unique (remove-duplicates (terminals2 x))))
    (remove-if #'null unique)))

(defun terminals2 (x)
  "Helper for TERMINALS: list the terminals of X.  *NILVAR* contributes
nothing, an atom contributes itself, and a cons contributes the
(already de-duplicated) terminals of its CAR followed by those of its CDR."
  ;; Recursing through TERMINALS de-duplicates at every level; the final
  ;; result is the same as de-duplicating once, just with more work.
  (if (equal x *nilvar*)
      '()
      (if (terminal? x)
          (list x)
          (append (terminals (car x))
                  (terminals (cdr x))))))

(defun terminal? (x)
  "True when X is a terminal, i.e. any atom (NIL and *NILVAR* included)."
  (typep x 'atom))

(defun non-terminal? (x)
  "True when X is not a terminal (i.e. X is a cons / subtree)."
  (if (terminal? x) nil t))

(defun nonterm-or-empty? (x)
  "True when X is a non-terminal or is the empty filler *NILVAR*."
  (if (equal x *nilvar*)
      t
      (not (terminal? x))))

(defun terminals-of-raam (raam)
  "Return, in symbol-table order, the names of RAAM's symbol-table entries
that are terminals (atoms, which includes *NILVAR*)."
  ;; Walk the alist once instead of indexing it with NTH for every position,
  ;; which made the previous version quadratic in the table size.
  (loop for entry in (raam-symbol-table raam)
        for name = (car entry)
        when (terminal? name)
          collect name))

(defun undefined-terminals (raam)
  "Return the terminals of RAAM whose symbol-table vector is still a dummy
vector (i.e. whose representation has never been defined)."
  (let ((symtab (raam-symbol-table raam))
        (undefined '()))
    (dolist (term (terminals-of-raam raam) (nreverse undefined))
      (when (dummy-vector? (symbol-value-by-name symtab term))
        (push term undefined)))))

(defun verify-all-terminals-defined (raam)
  "Signal an error listing every terminal of RAAM that still has a dummy
representation; return NIL when all terminals are defined."
  (let ((undefined-terms (undefined-terminals raam)))
    (when undefined-terms
      ;; Pass the terms as a format argument: the previous code built the
      ;; message with FORMAT and then used it as ERROR's format control, so
      ;; a '~' in a terminal's printed name broke the error report.
      (error "The following terminals are undefined:  ~a"
             undefined-terms))))

;;; ----------------------------------------------------------------------
;;;
;;; Functions for picking apart a set of trees to obtain all the unique
;;; subtrees.  This is necessary because training a raam on an arbitrary
;;; tree entails having it learn all of the subtrees.
;;;
;;; ----------------------------------------------------------------------

;; All subtrees of a SINGLE tree (i.e., list), without duplicates.
(defun subtrees (x)
  "Return the distinct subtrees of tree X (compared with EQUAL), keeping the
first occurrence of each; children come before the trees that contain them."
  (let ((all (subtrees-with-duplicates x)))
    (remove-duplicates all :from-end t :test 'equal)))

(defun subtrees-with-duplicates (x)
  "Return every list-valued subtree of X in post-order (each child's
subtrees, left to right, then X itself); atoms yield NIL."
  (unless (atom x)
    (append (loop for child in x
                  when (listp child)
                    append (subtrees-with-duplicates child))
            (list x))))

;; Convert tree to fixed valence.  This is needed because a raam combines k
;; n-dimensional vectors into a single n-dimensional vector, where k is some
;; CONSTANT.  Hence the trees must have constant valence (branching factor).

(defun fix-valence (x v)
  "Return tree X converted so that every node has exactly V children.
Atoms are returned unchanged.  A node with fewer than V children is padded
with *NILVAR* (via EXTEND).  A node with more than V children is folded:
runs of V consecutive children are grouped into new sub-nodes, moving left
to right and wrapping to the front, until at most V remain.  Signals an
error if folding is needed and V is less than 2."
  (if (atom x)
      x
      (let ((res (loop for y in x collect (fix-valence y v)))
            (i -1))
        (cond ((<= (length res) v) (extend res v))
              ;; Grouping fewer than 2 children never shrinks the node, so
              ;; the folding loop below would never terminate.
              ((< v 2)
               (error "FIX-VALENCE: valence ~a is too small to fold a node with ~a children"
                      v (length res)))
              (t
               (loop while (> (length res) v) do
                 (incf i)
                 ;; Wrap to the front when fewer than V children remain from
                 ;; I (this also covers I reaching the end of RES).
                 (when (> (+ i v) (length res)) (setf i 0))
                 ;; Replace children I .. I+V-1 with a single group.  The
                 ;; tail is taken from RES itself; it was previously bounded
                 ;; by (LENGTH X), which overruns RES after the first fold.
                 (setf res (append (subseq res 0 i)
                                   (list (subseq res i (+ i v)))
                                   (subseq res (+ i v)))))
               ;; Re-normalize so the folded node is padded to V as well.
               (fix-valence res v))))))


;;; ----------------------------------------------------------------------
;;;
;;; Creating a raam
;;;
;;; This set of functions creates a raam.  All parameters to the main
;;; function (CREATE-RAAM) are passed as keywords; defaults are supplied for
;;; those omitted, except that VALENCE must always be given and REP-WIDTH
;;; must be given unless it can be taken from TERMINAL-CODES.
;;;
;;; ----------------------------------------------------------------------

(defun create-terminal-symbol-table (terminal-codes rep-width terminal-width)
  "Build the initial symbol table from TERMINAL-CODES, a list of
(NAME CODE) entries where CODE is a list or vector of numbers.  Returns an
alist of (NAME . vector), in the same order as TERMINAL-CODES, with each
CODE converted by COERCE-VECTOR.

A code must have length REP-WIDTH if its name is *NILVAR* or a non-terminal,
or TERMINAL-WIDTH if its name is a terminal (atom); *NILVAR*, being an atom,
may use either.  Signals an error naming the first entry that violates this."
  (dolist (entry terminal-codes)
    (let ((name (car entry))
          (len  (length (cadr entry))))
      (unless (or (and (nonterm-or-empty? name) (= rep-width len))
                  (and (terminal? name) (= terminal-width len)))
        ;; Report the offending entry and both widths; the old message
        ;; always quoted REP-WIDTH, which is wrong for stack terminals.
        (error "Code for ~s has length ~a; terminal codes must have length ~a (~a for *empty* and non-terminals)"
               name len terminal-width rep-width))))
  ;; PAIRLIS's result order is unspecified by the standard, so the old
  ;; REVERSE + PAIRLIS only preserved the caller's order on implementations
  ;; whose PAIRLIS happens to build the alist backwards.  Build it in
  ;; TERMINAL-CODES order explicitly.
  (mapcar (lambda (entry) (cons (car entry) (coerce-vector (cadr entry))))
          terminal-codes))

(defun create-raam (&key (trees          nil)
                         (valence        nil) ;*default-valence* gms12/18
                         (net            nil)
                         (terminal-codes nil)
                         (rep-width      (if terminal-codes
                                             (length (cadar terminal-codes))
                                             nil)) ;*default-rep-width* gms12/18
                         (terminal-width nil) ;gms 02/06
                         (term-tol       *default-terminal-tolerance*)
                         (nonterm-tol    *default-nonterminal-tolerance*)
                         (name           *default-name*)
                         (random-weights   t)
                         (random-terminals nil)
                         (rate (coerce 1.0 *element-type*))
                         (momentum (coerce 0.5 *element-type*))
                         (sp-add (coerce 0.1 *element-type*))
                         (adjust? nil)
                         (repeat? t)
                         (always-bp? t)
                         (nil-value        *default-nil-value*))
  "Create and return a raam for trees of fixed VALENCE and node width
REP-WIDTH.  VALENCE is required; REP-WIDTH defaults to the length of the
first code in TERMINAL-CODES (a list of (NAME CODE) entries).  Any
TERMINAL-WIDTH argument is ignored and replaced by REP-WIDTH.

Unless NET is supplied, a network is created with (* VALENCE REP-WIDTH)
inputs, REP-WIDTH middle units and VALENCE extra outputs for the
terminal-test bits (presumably input/hidden/output widths), and randomized
when RANDOM-WEIGHTS is true.  With RANDOM-TERMINALS and no TERMINAL-CODES,
each terminal of TREES gets a random code.  Unless NIL-VALUE is NIL or
*EMPTY* already has a code, *EMPTY* is coded as a vector of NIL-VALUEs.
Finally each of TREES is added with RAAM-ADD-TREE."
  (unless valence
    (error "Desired valence for trees must be supplied explicitly"))
  (unless rep-width
    (error "Representation width not specified"))
  (when nil-value
    (setf nil-value (coerce nil-value *element-type*)))
  ;; Raams use one width for terminals and non-terminals.
  (setf terminal-width rep-width)
  (let ((raam (make-raam))
        (nil-vector (when nil-value
                      (make-array rep-width :initial-element nil-value)))
        (raam-width (* valence rep-width)))
    (when (some #'atom trees)
      (error "Trees must all be lists, but an atom was detected."))
    ;; The output layer is longer than the input by VALENCE terminal-test
    ;; bits.  Weights are randomized before any random terminal codes are
    ;; drawn.
    (unless net
      (setf net (create-network raam-width rep-width (+ raam-width valence)))
      (when random-weights
        (randomize-network net)))
    (when (and random-terminals (null terminal-codes))
      (setf terminal-codes
            (mapcar (lambda (term) (list term (random-vector rep-width)))
                    (terminals trees))))
    (setf (raam-valence raam)        valence
          (raam-rep-width raam)      rep-width
          (raam-terminal-width raam) terminal-width
          (raam-raam-width raam)     raam-width
          (raam-network raam)        net
          (raam-name raam)           name
          (raam-creation-date raam)  (date))
    (unless (or (null nil-value)
                (member *nilvar* terminal-codes :key #'car))
      (push (list *nilvar* nil-vector) terminal-codes))
    (setf (raam-symbol-table raam)
          (create-terminal-symbol-table terminal-codes
                                        rep-width terminal-width))
    ;; Keep a separate copy of the starting symbol table.
    (setf (raam-init-terminals raam) (copy-object (raam-symbol-table raam))
          (raam-term-tol raam)       (coerce term-tol *element-type*)
          (raam-nonterm-tol raam)    (coerce nonterm-tol *element-type*))
    (let ((network (raam-network raam)))
      (setf (network-rate network)       (coerce rate *element-type*)
            (network-momentum network)   (coerce momentum *element-type*)
            (network-sp-add network)     (coerce sp-add *element-type*)
            (network-adjust? network)    adjust?
            (network-always-bp? network) always-bp?))
    (setf (raam-pattern-repeat raam) repeat?)
    (dolist (tree trees)
      (raam-add-tree raam tree))
    raam))


;; Give TREE a fixed (static) representation REP in RAAM.  The pair
;; (TREE . REP-as-vector) is pushed onto the raam's static table, so a later
;; call for the same tree shadows an earlier one on lookup.  REP must have
;; exactly RAAM-REP-WIDTH elements.  Returns the updated static table.
(defun raam-fix-tree-rep (raam tree rep)
  (when (/= (length rep) (raam-rep-width raam))
    (error "RAAM-FIX-TREE-REP:: Length of rep provided is incorrect."))
  (let ((entry (cons tree (coerce-vector rep))))
    (setf (raam-static-table raam)
          (cons entry (raam-static-table raam)))))

;;; ----------------------------------------------------------------------
;;;
;;; Initialization
;;;
;;; ----------------------------------------------------------------------
;; Re-randomize the weights of RAAM's network, record the event in the
;; raam's history, and reset its epoch counter.  Returns 0.
(defun randomize-raam-network (raam)
  (let ((net (raam-network raam)))
    (randomize-network net))
  (add-to-history raam 'randomize-weights t)
  (setf (raam-total-epochs raam) 0))

;; Restore RAAM's network to its pre-training weights (via REVERT-NETWORK),
;; record the event in the history, and reset the epoch counter.  Returns 0.
;; NOTE(review): the event is recorded as RANDOMIZE-WEIGHTS, the same label
;; RANDOMIZE-RAAM-NETWORK uses -- confirm whether a distinct label was meant.
(defun revert-raam-network (raam)
  (let ((net (raam-network raam)))
    (revert-network net))
  (add-to-history raam 'randomize-weights t)
  (setf (raam-total-epochs raam) 0))

(defun make-matrix-fft (mat range &optional inverse)
  "Fill every element of the 2-D array MAT in place with a cosine coefficient:
element (r, c) becomes RANGE * cos(+/-2 * pi * i * j / N), coerced to
*ELEMENT-TYPE*, where i = r+1, j = c+1 and N is the number of rows.  INVERSE
selects the negative sign.  Returns NIL."
  ;; The previous version used two parallel LOOP FOR clauses, which step i
  ;; and j together and therefore only filled the diagonal.  Visit every
  ;; element instead.
  ;; NOTE(review): cos is even, so INVERSE does not change the values.
  (let ((rows (array-dimension mat 0))
        (cols (array-dimension mat 1))
        (sign (if inverse -2.0 2.0)))
    (dotimes (r rows)
      (dotimes (c cols)
        (let ((i (1+ r))
              (j (1+ c)))
          (setf (aref mat r c)
                (coerce (* range (cos (* pi i j (/ 1.0 rows) sign)))
                        *element-type*)))))
    nil))

(defun raam-fft-init (raam &key (encode-range 20.0) (decode-range 3.0))
  "Initialize RAAM's weights with MAKE-MATRIX-FFT coefficients: the second
weight matrix (decoder) with DECODE-RANGE and INVERSE set, then the first
weight matrix (encoder) with ENCODE-RANGE."
  (let* ((weights (network-weights (raam-network raam)))
         (encoder (first weights))
         (decoder (second weights)))
    (make-matrix-fft decoder decode-range t)
    (make-matrix-fft encoder encode-range)))

;;; ----------------------------------------------------------------------
;;;
;;; Stacks - code similar to raam code, but for stacks
;;;
;;; rep-width      = RAAM hidden layer width
;;; terminal-width = width for symbols
;;;   e.g. 4-3-4 stackraam = :rep-width 3 :terminal-width 1
;;;
;;; ----------------------------------------------------------------------


(defun create-stack (&key (trees          nil)
                          (net            nil)
                          (terminal-codes nil)
                          (terminal-width (if terminal-codes
                                              (length (cadar terminal-codes))
                                              nil))
                          (rep-width      nil)
                          (term-tol       *default-terminal-tolerance*)
                          (nonterm-tol    *default-nonterminal-tolerance*)
                          (name           *default-name*)
                          (random-weights   t)
                          (random-terminals nil)
                          (rate (coerce 1.0 *element-type*))
                          (momentum (coerce 0.5 *element-type*))
                          (sp-add (coerce 0.1 *element-type*))
                          (repeat? t)
                          (adjust? nil)
                          (always-bp? t)
                          (nil-value        *default-nil-value*))
  "Create and return a stack raam: a raam of fixed valence 2 whose nodes
combine a REP-WIDTH stack representation with a TERMINAL-WIDTH symbol
(e.g. a 4-3-4 stack raam has :REP-WIDTH 3 :TERMINAL-WIDTH 1).  REP-WIDTH is
required; TERMINAL-WIDTH defaults to the length of the first code in
TERMINAL-CODES.  Unless NET is supplied, a network with
(+ REP-WIDTH TERMINAL-WIDTH) inputs, REP-WIDTH middle units and 2 extra
outputs for the terminal-test bits is created.  *EMPTY* (the stack bottom)
is coded as a REP-WIDTH vector of NIL-VALUEs unless TERMINAL-CODES already
gives it a code.  Each of TREES is added with RAAM-ADD-STACK."
  (unless rep-width                     ; gms12/18
    (error "Representation width not specified"))
  (unless terminal-width
    (error "Atom width not specified"))
  ;; NOTE(review): unlike CREATE-RAAM, a NIL NIL-VALUE is not special-cased
  ;; here; it only works if TERMINAL-CODES supplies a code for *EMPTY*.
  (let ((stack (make-raam))
        (nil-vector (make-array rep-width :initial-element nil-value))
        (input-width (+ rep-width terminal-width)))
    (when (some #'atom trees)
      (error "Trees must all be lists, but an atom was detected."))
    ;; The output layer is longer by the terminal-test bits (valence 2).
    (unless net
      (setf net (create-network input-width rep-width (+ input-width 2)))
      (when random-weights
        (randomize-network net)))
    (when (and random-terminals (null terminal-codes))
      (setf terminal-codes
            (mapcar (lambda (term) (list term (random-vector terminal-width)))
                    (terminals trees))))
    (setf (raam-valence stack)        2 ; valence 2 is assumed for stacks
          (raam-rep-width stack)      rep-width
          (raam-terminal-width stack) terminal-width
          (raam-raam-width stack)     input-width
          (raam-network stack)        net
          (raam-name stack)           name
          (raam-creation-date stack)  (date))
    (unless (member *nilvar* terminal-codes :key #'car)
      (push (list *nilvar* nil-vector) terminal-codes))
    (setf (raam-symbol-table stack)
          (create-terminal-symbol-table terminal-codes rep-width terminal-width))
    (setf (raam-init-terminals stack) (copy-object (raam-symbol-table stack))
          (raam-term-tol stack)       (coerce term-tol *element-type*)
          (raam-nonterm-tol stack)    (coerce nonterm-tol *element-type*))
    (let ((network (raam-network stack)))
      (setf (network-rate network)     (coerce rate *element-type*)
            (network-momentum network) (coerce momentum *element-type*)
            (network-sp-add network)   (coerce sp-add *element-type*)))
    (setf (raam-pattern-repeat stack) repeat?)
    (let ((network (raam-network stack)))
      (setf (network-adjust? network)    adjust?
            (network-always-bp? network) always-bp?))
    (dolist (tree trees)
      (raam-add-stack stack tree))
    stack))


;; Add a stack to the training set.  STACK, a list of symbols, is turned
;; into a nested binary tree ending in *NILVAR*:
;;   BRANCHING 'LEFT  (default): (a b c) => (((*empty* a) b) c)
;;   BRANCHING 'RIGHT          : (a b c) => (a (b (c *empty*)))
;; and added with RAAM-ADD-TREE using the raam's terminal width.
(defun raam-add-stack (raam stack &key (branching 'left))
  (let ((nested-stack
          (if (equal branching 'right)
              (reduce (lambda (item acc) (list item acc))
                      stack :from-end t :initial-value *nilvar*)
              (reduce (lambda (acc item) (list acc item))
                      stack :initial-value *nilvar*))))
    (raam-add-tree raam nested-stack :width (raam-terminal-width raam))))



;;; ----------------------------------------------------------------------
;;;
;;; Symbol table access
;;;
;;; The defsetf methods below set a value in the symbol table by COPYING
;;; a given vector.  Copying is necessary since the new vector is usually
;;; just a pointer to a set of nodes in a network.  For comparison, see
;;; the function (pointer-to-activation-of-level).
;;; ----------------------------------------------------------------------

(defun symbol-name-by-index (symtab i)
  "Name of the I-th (0-based) SYMTAB entry, or NIL past the end."
  (car (first (nthcdr i symtab))))

(defun symbol-value-by-index (symtab i)
  "Vector of the I-th (0-based) SYMTAB entry, or NIL past the end."
  (cdr (first (nthcdr i symtab))))

(defun symbol-index-by-name (symtab sym)
  "0-based position of the first SYMTAB entry whose name is EQUAL to SYM,
or NIL if there is none."
  (loop for entry in symtab
        for position from 0
        when (equal sym (car entry))
          return position))

(defun symbol-value-by-name (symtab sym)
  "Vector stored under the first SYMTAB entry whose name is EQUAL to SYM,
or NIL if there is none."
  (let ((entry (assoc sym symtab :test #'equal)))
    (when entry
      (cdr entry))))

;; (SETF (SYMBOL-VALUE-BY-NAME symtab sym) value) copies VALUE element by
;; element into the vector already stored under SYM (matched with EQUAL),
;; rather than storing VALUE itself, which is often a pointer into network
;; activations.  If SYM is absent nothing is stored.  As SETF requires, the
;; form now returns VALUE (it used to return NIL).
(defsetf symbol-value-by-name (symtab sym) (value)
  `(let ((ptr (cdr (assoc ,sym ,symtab :test #'equal))))
     (loop for j from 0 below (length ptr) do
       (setf (aref ptr j) (aref ,value j)))
     ,value))

;; (SETF (SYMBOL-VALUE-BY-INDEX symtab i) value) copies VALUE element by
;; element into the vector of the I-th (0-based) SYMTAB entry instead of
;; storing VALUE itself.  Past the end of SYMTAB nothing is stored.  As SETF
;; requires, the form now returns VALUE (it used to return NIL).
(defsetf symbol-value-by-index (symtab i) (value)
  `(let ((ptr (cdr (nth ,i ,symtab))))
     (loop for j from 0 below (length ptr) do
       (setf (aref ptr j) (aref ,value j)))
     ,value))

(defun symbol-name-by-value (symtab value &key (terminals-only nil))
  "Return the name of the SYMTAB entry whose vector is nearest to VALUE.
Only entries whose vector has the same length as VALUE are considered, and,
when TERMINALS-ONLY is true, only entries whose name is a terminal (atom).
Distance is measured with VECTOR-DISTANCE (table vector first); ties go to
the earliest entry.  Signals an error if no entry qualifies."
  (let ((width (length value))
        (found nil)
        (best-name nil)
        (best-distance nil))
    ;; One pass with a running minimum.  The previous (APPLY #'MIN ...)
    ;; fails once the number of candidates exceeds CALL-ARGUMENTS-LIMIT
    ;; (which may be as low as 50), gave an obscure error when nothing
    ;; qualified, and NTH-indexed the table quadratically.
    (dolist (entry symtab)
      (let ((name (car entry))
            (code (cdr entry)))
        (when (and (= width (length code))
                   (or (not terminals-only) (terminal? name)))
          (let ((distance (vector-distance code value)))
            ;; Strict < keeps the first of several equally near entries.
            (when (or (not found) (< distance best-distance))
              (setf found t
                    best-name name
                    best-distance distance))))))
    (unless found
      (error "SYMBOL-NAME-BY-VALUE: no symbol-table entry matches a vector of width ~a"
             width))
    best-name))



;;; ----------------------------------------------------------------------
;;;
;;; Training
;;;   Because the training set is ordered, we optimize the raam so that
;;; it just cycles through the training set.
;;;
;;; ----------------------------------------------------------------------


;;; TRAIN

(defun raam-train (raam &key (num 10000) (display 1) (reconstruct 0)
                             (terminal-test t))
  "Train RAAM for at most NUM epochs, stopping once an epoch ends with zero
errors.  Returns T if training stopped with zero errors, NIL otherwise.

DISPLAY       - print a progress line every DISPLAY epochs; 0 disables the
                per-epoch lines (a summary line is always printed at the end).
RECONSTRUCT   - when non-zero, reaching zero errors also requires VERIFY to
                succeed; VERIFY is run every RECONSTRUCT epochs.
TERMINAL-TEST - must be true when RECONSTRUCT is non-zero."
  (unless (typep raam 'raam)
    (error "First argument must be of type \"raam\""))
  (let ((errors 1))
    ;; Signals an error if any terminal still has its dummy representation.
    (verify-all-terminals-defined raam)
    (let* ((net     (raam-network raam))
           (t-set   (raam-training-set raam))
           (numpats (length t-set)))
      (when (and (not (zerop reconstruct)) (not terminal-test))
        (error "Terminal Test must be selected when Reconstruct is non-zero"))
      (unless t-set (error "No trees in training set!"))
      (zero-prev-deltas net)
      (loop for i from 1 to num
            ;; OUTPUT is true on epochs whose errors are reported; it is
            ;; always NIL when DISPLAY is 0, so MOD never sees a zero divisor.
            for output = (and (not (zerop display))
                              (zerop (mod i display)))
            until (zerop errors) do
              (zero-deltas net)
              (setf (network-total-error net) zero
                    (network-max-error net) zero
                    errors 0)
              ;; Accumulate deltas and error counts over the whole training
              ;; set; trees with a fixed representation use the fixed trainer.
              (loop for tree-index in t-set do
                (incf errors
                      (if (symbol-value-by-name (raam-static-table raam)
                                                (car tree-index))
                          (raam-train-fixed-pattern
                           raam tree-index numpats output terminal-test
                           reconstruct)
                          (raam-train-pattern raam tree-index numpats output
                                              terminal-test reconstruct))))
              (incf (raam-total-epochs raam))
              (update-weights net)
              ;; Was (ZEROP (MOD I DISPLAY)), which divides by zero when
              ;; DISPLAY is 0; OUTPUT already holds the guarded test.
              (when output
                (format t "~%Epoch ~d: Err=~d Tot=~,6f Max=~,6f~%"
                        (raam-total-epochs raam) errors
                        (network-total-error net) (network-max-error net)))
              (update-prev-deltas net)
              ;; In reconstruct mode being within tolerance is not enough:
              ;; keep training unless VERIFY succeeds below.
              (when (and (zerop errors) (not (zerop reconstruct)))
                (setf errors 1))
              ;; NOTE(review): ERRORS was just forced to 1 whenever
              ;; RECONSTRUCT is non-zero, so the (ZEROP ERRORS) disjunct can
              ;; never be true here and VERIFY only runs every RECONSTRUCT
              ;; epochs.  Left unchanged -- confirm intent.
              (when (and (not (zerop reconstruct))
                         (or (zerop errors) (zerop (mod i reconstruct))))
                (if (verify raam) (setf errors 0) (setf errors 1)))
            finally (format t "~%Epoch ~d: Err=~d Tot=~,6f Max=~,6f~%"
                            (raam-total-epochs raam) errors
                            (network-total-error net)
                            (network-max-error net))))
    (zerop errors)))

(defun pointer-to-activation-of-level (i net)
  "Return CDR-VECTOR of the I-th level's activations in NET.  The result is
presumably a view sharing storage with the network's nodes (see the symbol
table notes), so it is copied rather than stored when a rep is saved."
  (let ((level-activations (nth i (network-activations net))))
    (cdr-vector level-activations)))

;; Train RAAM on one training-set entry, TREE-INDEX = (tree . index), using
;; the ordinary autoassociative scheme.  The children's codes (looked up in
;; the symbol table) are concatenated and placed on the input layer; the
;; target is the per-child terminal-test bits followed by that same
;; concatenation.  Deltas are computed here, but the weights themselves are
;; updated by the caller (RAAM-TRAIN calls UPDATE-WEIGHTS once per epoch).
;; Afterwards the activations at level *NEW-REP-LEVEL* are copied into the
;; symbol table as the tree's new representation.
;;
;;   NUMPATS     - training-set size; the delta divisor when the network's
;;                 ADJUST? flag is set (ONE otherwise).
;;   OUTPUT      - true on epochs whose errors are being reported.
;;   TERMTEST    - passed through to RAAM-COUNT-ERRORS.
;;   RECONSTRUCT - when non-zero, errors are counted only on OUTPUT epochs.
;;
;; Returns the count from RAAM-COUNT-ERRORS, or 1 when counting is skipped.
(defun raam-train-pattern (raam tree-index numpats output termtest reconstruct)
  (let* ((tree (car tree-index))
	 (index (cdr tree-index))
	 (net    (raam-network raam))
	 (valence (raam-valence raam))
	 (symtab (raam-symbol-table raam))
	 ;; One bit per child: ONE for an atom (a terminal or *NILVAR*
	 ;; padding), ZERO for a subtree.
	 (terminal-bits (mapcar #'(lambda (x) (if (terminal? x) one zero))
				tree)) 
	 (terminal-test (coerce-vector terminal-bits))
	 ;; This is the place that could use some speedup: if the training
	 ;; set pointed directly at the symbol-table vectors, we would not
	 ;; need to search for each symbol here.
	 (codes (loop for x in tree collect (symbol-value-by-name symtab x)))
	 (incode (apply #'append-vector codes))
	 (outcode (append-vector terminal-test incode)))
    (set-nodes net incode)
    (forward-pass net)
    ;; With ALWAYS-BP? the whole output is compared with zero tolerance;
    ;; otherwise each output segment gets its own tolerance.
    (if (network-always-bp? net)
	(compute-output-errors net outcode zero)
      ;; Segment 0 is the VALENCE terminal-test bits (SUBTREE is NIL there,
      ;; an atom, so the terminal tolerance applies); each later segment is
      ;; one child's code, sized by that child's symbol-table vector.
      ;; NOTE(review): NODEEND looks like an inclusive index -- confirm
      ;; against COMPUTE-OUTPUT-ERRORS.
      (loop for subtree in (cons nil tree)
	    for len = (length (symbol-value-by-name symtab subtree))
	    for nodestart = 0 then (1+ nodeend)
	    for nodeend = (1- valence) then (1- (+ nodestart len))
	    do
	    (compute-output-errors net outcode
				   (if (terminal? subtree)
				       (raam-term-tol raam)
				     (raam-nonterm-tol raam))
				   :start nodestart
				   :end nodeend)))
    (compute-hidden-errors net)
    ;; :REPETITIONS weights this pattern by its number of copies in the
    ;; training set when the raam's PATTERN-REPEAT flag is set.
    (compute-deltas net
		    :divisor (if (network-adjust? net) numpats one)
                    :repetitions (if (raam-pattern-repeat raam)
				     (nth index (raam-pattern-copies raam))
				   one))
    ;; The SETF copies the activations element-wise into the tree's existing
    ;; symbol-table vector; the activation pointer itself is not stored.
    (setf (symbol-value-by-index symtab index)
	  (pointer-to-activation-of-level *new-rep-level* net))
    (if (or output (zerop reconstruct))
	(raam-count-errors raam outcode tree output termtest)
      1)))

;; Train a tree whose representation was fixed with RAAM-FIX-TREE-REP.
;; TREE-INDEX = (tree . index); FIXEDREP is the tree's static-table vector.
;; Training happens in two phases:
;;   1. Only when DECODER is true: FIXEDREP is placed at level
;;      *NEW-REP-LEVEL* and propagated forward from there.  Output errors
;;      against the usual autoassociative target (terminal-test bits + child
;;      codes) give deltas computed with :MIN-LEVEL *NEW-REP-LEVEL*
;;      (presumably the decoder only).
;;   2. The tree's real input (the concatenated child codes) is placed on
;;      the input layer and propagated through the whole network.  FIXEDREP
;;      is the target at level *NEW-REP-LEVEL*, and deltas are computed with
;;      :MAX-LEVEL *NEW-REP-LEVEL* (presumably the encoder only).
;; NUMPATS is the delta divisor when the network's ADJUST? flag is set; with
;; PATTERN-REPEAT the pattern is weighted by its number of copies.  The
;; level-*NEW-REP-LEVEL* activations from phase 2 are then copied into the
;; symbol table.  Returns the count from RAAM-COUNT-ERRORS (called when
;; OUTPUT is true or RECONSTRUCT is zero; TERMTEST is passed through), else 1.
;; RAAM-TRAIN calls this with the default DECODER = T.
(defun raam-train-fixed-pattern (raam tree-index numpats output termtest
                                      reconstruct
                                 &key (decoder t))
  (let* ((tree (car tree-index))
         (index (cdr tree-index))
         (net (raam-network raam))
         (valence (raam-valence raam))
         (symtab (raam-symbol-table raam))
         ;; One bit per child: ONE for an atom, ZERO for a subtree.
         (terminal-bits (mapcar #'(lambda (x) (if (terminal? x) one zero))
                                tree))
         (terminal-test (coerce-vector terminal-bits))
         (fixedrep (symbol-value-by-name (raam-static-table raam) tree))
         ;; Looking up every child by name is a candidate for speedup.
         (codes (loop for x in tree collect (symbol-value-by-name symtab x)))
         (incode (apply #'append-vector codes))
         (outcode (append-vector terminal-test incode))
         (divisor (if (network-adjust? net) numpats one))
         (repetitions (if (raam-pattern-repeat raam)
                          (nth index (raam-pattern-copies raam))
                          one)))
    ;; Phase 1: fixed rep at the middle layer, deltas for the decoder.
    (when decoder
      (set-nodes net fixedrep *new-rep-level*)
      (forward-pass net :min-level *new-rep-level*)
      (if (network-always-bp? net)
          (compute-output-errors net outcode zero)
          ;; Segment 0 = terminal-test bits; then one segment per child.
          (loop for subtree in (cons nil tree)
                for len = (length (symbol-value-by-name symtab subtree))
                for nodestart = 0 then (1+ nodeend)
                for nodeend = (1- valence) then (1- (+ nodestart len))
                do (compute-output-errors net outcode
                                          (if (terminal? subtree)
                                              (raam-term-tol raam)
                                              (raam-nonterm-tol raam))
                                          :start nodestart
                                          :end nodeend)))
      (compute-deltas net :min-level *new-rep-level*
                          :divisor divisor
                          :repetitions repetitions))
    ;; Phase 2: real inputs forward through the whole network (so errors
    ;; can be counted), fixed rep as the middle-layer target, deltas for
    ;; the encoder.  The inputs must be set first: previously they never
    ;; were, so this pass ran on whatever the previous pattern had left on
    ;; the input layer.
    (set-nodes net incode)
    (forward-pass net)
    (if (network-always-bp? net)
        (compute-output-errors net fixedrep zero :max-level *new-rep-level*)
        ;; NOTE(review): these segments follow the OUTPUT layout
        ;; (terminal bits + child codes) while the target here is the
        ;; REP-WIDTH FIXEDREP -- confirm this is intended.
        (loop for subtree in (cons nil tree)
              for len = (length (symbol-value-by-name symtab subtree))
              for nodestart = 0 then (1+ nodeend)
              for nodeend = (1- valence) then (1- (+ nodestart len))
              do (compute-output-errors net fixedrep
                                        (if (terminal? subtree)
                                            (raam-term-tol raam)
                                            (raam-nonterm-tol raam))
                                        :start nodestart :max-level *new-rep-level*
                                        :end nodeend)))
    (compute-deltas net :max-level *new-rep-level*
                        :divisor divisor
                        :repetitions repetitions)
    ;; Copy the phase-2 middle-layer activations into the symbol table and
    ;; count errors as in normal training.
    (setf (symbol-value-by-index symtab index)
          (pointer-to-activation-of-level *new-rep-level* net))
    (if (or output (zerop reconstruct))
        (raam-count-errors raam outcode tree output termtest)
        1)))

;;; ----------------------------------------------------------------------
;;;
;;; Compute output errors for raam
;;;
;;; ----------------------------------------------------------------------

;; Count the reconstruction errors across the children of TREE.
;; Child codes in PATTERN start at index VALENCE and are laid end to end, one
;; per element of TREE.  Each child's slice is scored by
;; COUNT-ERRORS-FOR-SUBTREE with the terminal or non-terminal tolerance.
;; The per-child error counts are summed and returned (0 for an empty TREE).
(defun raam-count-errors (raam pattern tree &optional display termtest)
  (let ((symtab (raam-symbol-table raam))
	(term-tol (raam-term-tol raam))
	(nonterm-tol (raam-nonterm-tol raam))
	(first-node (raam-valence raam))   ; child codes start after index VALENCE-1
	(errors 0))
    (loop for subtree in tree
	  for pos from 0
	  do (let ((last-node (+ first-node
				 (length (symbol-value-by-name symtab subtree))
				 -1)))
	       (incf errors
		     (count-errors-for-subtree
		      raam pattern pos
		      (if (terminal? subtree) term-tol nonterm-tol)
		      first-node last-node
		      display termtest))
	       (setf first-node (1+ last-node))))
    errors))

;; True when activation ACT lies on the same side of HALF as the binary
;; target PAT: above HALF for a target of ONE, below HALF for a target of
;; ZERO.  Any other target value is never counted as correct.
(defun correct (pat act)
  (cond ((= pat one)  (> act half))
	((= pat zero) (< act half))
	(t nil)))

;; Score one child slice of the output layer.  Output nodes NODESTART..NODEEND
;; (inclusive) of the last activation level are compared against PATTERN.
;; Each node whose absolute error exceeds TOL counts as one error.  When
;; TERMTEST is true, a wrong terminal-test bit at position POS (judged by
;; CORRECT) adds one more error.
;; Side effects:
;;   - adds the summed squared error to NETWORK-TOTAL-ERROR;
;;   - raises NETWORK-MAX-ERROR if needed;
;;   - when DISPLAY is true, writes "." for an error-free slice or "*" otherwise.
;; Returns the error count.
;; NOTE(review): LOOP's MAXIMIZE value is unspecified if the slice is empty
;; (NODESTART > NODEEND).
(defun count-errors-for-subtree (raam pattern pos tol
				 nodestart nodeend display termtest)
  (let* ((net (raam-network raam))
	 (last-level (1- (length (network-activations net))))
	 (outputs (pointer-to-activation-of-level last-level net)))
    (loop for node from nodestart to nodeend
	  for diff = (- (aref pattern node) (aref outputs node))
	  sum (* diff diff) into sq-sum
	  maximize (abs diff) into worst
	  count (> (abs diff) tol) into misses
	  finally
	    (when (and termtest
		       (not (correct (aref pattern pos) (aref outputs pos))))
	      (incf misses))
	    (when display
	      (write-char (if (plusp misses) #\* #\.)))
	    (incf (network-total-error net) sq-sum)
	    (setf (network-max-error net)
		  (max worst (network-max-error net)))
	    (return misses))))


;;; ----------------------------------------------------------------------
;;;
;;; Dynamically add trees to, or remove trees from, the training set
;;;
;;; ----------------------------------------------------------------------

;; Add a tree to the training set.
;; Every terminal of TREE is registered via SET-TERMINAL with a WIDTH-element
;; dummy vector.  SET-TERMINAL does not replace an existing terminal's code
;; with a dummy.
;; TREE is normalised with FIX-VALENCE.  Each subtree of the result then gets:
;;   - a symbol-table entry (a RAAM-REP-WIDTH dummy vector, unless an entry
;;     already exists);
;;   - a (subtree . symbol-index) training-set entry;
;;   - one more pattern copy.
;; Returns a descriptive string.
;; NOTE(review): WIDTH defaults to RAAM-REP-WIDTH while REP->VECTOR checks
;; terminal codes against RAAM-TERMINAL-WIDTH -- confirm the two agree.
(defun raam-add-tree (raam tree &key (width (raam-rep-width raam)))
  "Add a tree to the current training-set, e.g., (raam-add-tree raam '((a b) (c d)))"
  (flet ((merge-unique (old new &rest keys)
	   ;; OLD followed by NEW, keeping only the first EQUAL occurrence.
	   (apply #'remove-duplicates (append old new)
		  :test #'equal :from-end t keys)))
    (dolist (term (terminals tree))
      (set-terminal raam term (dummy-vector width)))
    (let* ((valence (raam-valence raam))
	   (fixed-tree (fix-valence tree valence))
	   (subs (subtrees fixed-tree)))
      (setf (raam-symbol-table raam)
	    (merge-unique (raam-symbol-table raam)
			  (mapcar #'(lambda (sub)
				      (cons sub (dummy-vector (raam-rep-width raam))))
				  subs)
			  :key #'car))
      ;; Indices must be looked up after the symbol table has been extended.
      (let ((indexed (mapcar #'(lambda (sub)
				 (cons sub (symbol-index-by-name
					    (raam-symbol-table raam) sub)))
			     subs)))
	(setf (raam-org-trees raam)
	      (merge-unique (raam-org-trees raam) (list tree)))
	(setf (raam-fixed-trees raam)
	      (merge-unique (raam-fixed-trees raam) (list fixed-tree)))
	(setf (raam-all-trees raam)
	      (merge-unique (raam-all-trees raam) subs))
	(setf (raam-training-set raam)
	      (merge-unique (raam-training-set raam) indexed :key #'car)))
      ;; Pad the copy counts with ZERO so there is one per symbol-table entry,
      ;; then count one more copy for each subtree of this tree.
      (let ((copies (raam-pattern-copies raam)))
	(setf (raam-pattern-copies raam)
	      (append copies
		      (make-list (- (length (raam-symbol-table raam))
				    (length copies))
				 :initial-element zero))))
      (dolist (sub subs)
	(incf (nth (symbol-index-by-name (raam-symbol-table raam) sub)
		   (raam-pattern-copies raam))))
      (format nil "New tree = ~a~%Valence ~a = ~a~%Subtrees = ~a~%"
	      tree valence fixed-tree subs))))
 


;; Remove TREE's entry (matched by EQUAL on the entry's car) from the
;; training set, signalling an error if there is none.  Only the training
;; set is changed.  The symbol table, pattern copies and tree lists are left
;; as they are.
(defun raam-remove-tree (raam tree)
  (let* ((t-set (raam-training-set raam))
	 (entry (find tree t-set :key #'car :test #'equal)))
    (unless entry
      (error "Tree not found"))
    (setf (raam-training-set raam) (remove entry t-set))))


;;; ----------------------------------------------------------------------
;;;
;;; Functions for setting terminals
;;;
;;; ----------------------------------------------------------------------

;; Give terminals a fresh RANDOM-VECTOR of RAAM-REP-WIDTH elements: only the
;; UNDEFINED-TERMINALS by default, or every non-nil terminal of the raam
;; when :ALL is true.  Returns NIL.
(defun randomize-terminals (raam &key (all nil))
  "Randomizes all undefined terminals.  With \":all t\", randomizes ALL
terminals" 
  (let ((symtab (raam-symbol-table raam))
	(repwidth (raam-rep-width raam)))
    (dolist (term (if all
		      (remove nil (terminals-of-raam raam))
		    (undefined-terminals raam)))
      (setf (symbol-value-by-name symtab term) (random-vector repwidth)))))

;; Validate REP as a terminal representation and return it converted by
;; COERCE-VECTOR.  REP must satisfy VECTOR-OR-LIST? and have exactly
;; RAAM-TERMINAL-WIDTH elements; otherwise an error is signalled.
;; NOTE(review): the message says "rep-width" but the width checked is
;; RAAM-TERMINAL-WIDTH.
(defun rep->vector (raam rep)
  (let ((w (raam-terminal-width raam)))
    (cond ((not (vector-or-list? rep))
	   (error "Not a valid representation"))
	  ((/= (length rep) w)
	   (error "All terminal codes must have length \"rep-width\" (~a)" w))
	  (t
	   (coerce-vector rep)))))

;; Append (TERM . REP) to the raam's symbol table, but only when REP is
;; non-nil and TERM has no entry yet.  Called by SET-TERMINAL.  Returns the
;; new symbol table, or NIL when nothing was added.
(defun add-terminal-to-raam (raam term rep)
  (let ((table (raam-symbol-table raam)))
    (unless (or (null rep)
		(symbol-index-by-name table term))
      (setf (raam-symbol-table raam)
	    (append table (list (cons term rep)))))))

;; Set TERM's representation, or add TERM if it is not present.
;; With a non-nil REP:
;;   - REP is validated and converted by REP->VECTOR;
;;   - an existing terminal is overwritten, unless the new value is a dummy
;;     vector;
;;   - a missing terminal is appended.
;; With REP nil, TERM is added with a dummy vector of RAAM-TERMINAL-WIDTH
;; elements if it is not already present.
(defun set-terminal (raam term rep)
  (let ((symtab (raam-symbol-table raam)))
    (cond ((null rep)
	   (add-terminal-to-raam raam term
				 (dummy-vector (raam-terminal-width raam))))
	  (t
	   (let ((vec (rep->vector raam rep)))
	     (if (and (symbol-index-by-name symtab term)
		      (not (dummy-vector? vec)))
		 (setf (symbol-value-by-name symtab term) vec)
	       (add-terminal-to-raam raam term vec)))))))

;; Set several terminals.  TERM-REPS is a list of (term rep) pairs; each
;; pair is handed to SET-TERMINAL in order.  Returns NIL.
(defun set-terminals (raam term-reps)
  "(set-terminals raam term-reps), e.g. (set-terminals raam '((a (.3 .4)) (b
(.1 .9))))" 
  (dolist (pair term-reps)
    (set-terminal raam (first pair) (second pair))))

;;; ----------------------------------------------------------------------
;;;
;;; Functions for setting raam parameters
;;;
;;; ----------------------------------------------------------------------

(defun set-raam (raam &key (rate (network-rate (raam-network raam)))
		      (momentum (network-momentum (raam-network raam)))
		      (term-tol (raam-term-tol raam))
		      (nonterm-tol (raam-nonterm-tol raam))
		      (name (raam-name raam))
		      (always-bp? (network-always-bp? (raam-network raam)))
		      (repeat? (raam-pattern-repeat raam))
		      (adjust? (network-adjust? (raam-network raam)))
		      (sp-add (network-sp-add (raam-network raam))))
  ;; Update the raam's training parameters.  Any key left out keeps its
  ;; current value.  Returns a string listing the resulting settings.
  ;; RATE, MOMENTUM and both tolerances are coerced to *ELEMENT-TYPE* before
  ;; any slot is written, so a bad value among them signals an error before
  ;; anything has changed.
  (let ((net (raam-network raam))
	(rate (coerce rate *element-type*))
	(momentum (coerce momentum *element-type*))
	(term-tol (coerce term-tol *element-type*))
	(nonterm-tol (coerce nonterm-tol *element-type*)))
    (setf (network-rate net)        rate
	  (network-momentum net)    momentum
	  (raam-term-tol raam)      term-tol
	  (raam-nonterm-tol raam)   nonterm-tol
	  (raam-name raam)          name
	  (network-sp-add net)      (coerce sp-add *element-type*)
	  (network-adjust? net)     adjust?
	  (raam-pattern-repeat raam) repeat?
	  (network-always-bp? net)  always-bp?)
    (format nil "Current values are~%~{  ~12a= ~a~%~}~%"
	    (list "Name"        (raam-name raam)
		  "Rate"        (network-rate net)
		  "Momentum"    (network-momentum net)
		  "Term-tol"    (raam-term-tol raam)
		  "Nonterm-tol" (raam-nonterm-tol raam)
		  "Repeat? (repeat?)" (raam-pattern-repeat raam)
		  "Adjust Rate? (adjust?)" (network-adjust? net)
		  "Always BP? (always-bp?)" (network-always-bp? net)
		  "Sigmoid Prime Add (sp-add)" (network-sp-add net)))))

;; Append a history record noting that ATTRIBUTE took VALUE at the raam's
;; current RAAM-TOTAL-EPOCHS.  Returns the new history list.
(defun add-to-history (raam attribute value)
  (let ((entry (make-history-record :attribute attribute
				    :value value
				    :epoch (raam-total-epochs raam))))
    (setf (raam-history raam)
	  (append (raam-history raam) (list entry)))))


;;; ----------------------------------------------------------------------
;;;
;;; This code verifies the trees in the training set.  A verification
;;; reconstructs each of the trees in the training set from the stored root
;;; representation.  A tree's encoding is considered verified when it decodes
;;; into the same shape tree and each of its terminals is within the terminal
;;; tolerance of the intended terminal.
;;;
;;; ----------------------------------------------------------------------

;; Check that every tree in RAAM-FIXED-TREES is reconstructed from its
;; compressed (symbol-table) representation.  Stops at the first tree that
;; does not decode correctly and returns NIL.  Returns T when all trees pass.
(defun verify (raam)
  (let ((symtab (raam-symbol-table raam)))
    (dolist (tree (raam-fixed-trees raam) t)
      (unless (verify-subtree raam tree (symbol-value-by-name symtab tree))
	(return nil)))))

;; Check that CODE decodes into the shape of SUBTREE, with every terminal
;; within tolerance.  CODE is placed on the *new-rep-level* layer and
;; forward propagated from there.  Then each child of SUBTREE is checked in
;; turn:
;;   - its terminal-test bit must be above HALF for a terminal and below
;;     HALF for a non-terminal;
;;   - a non-terminal child is verified recursively from its decoded code;
;;   - a terminal child must match its stored representation (VERIFY-TERMINAL).
;; Stops at the first child that fails.  Returns OK, which is true only if
;; every child passed.
(defun verify-subtree (raam subtree code)
  (let ((net (raam-network raam))
	(actv nil)
	(table (raam-symbol-table raam))
	(ok t))
    (set-nodes net code *new-rep-level*)
    (forward-pass net :min-level *new-rep-level*)
    ;; Copy the output activations: the recursive calls below re-run the
    ;; same network and may overwrite them.
    ;; NOTE(review): CDR-VECTOR presumably drops a leading (bias?) element
    ;; -- confirm.
    (setf actv
	  (copy-seq
	   (cdr-vector
	    (car (last (network-activations (raam-network raam)))))))
    ;; Child codes start at position VALENCE, one after another.  START/END
    ;; bound child I's code, and I also indexes its terminal-test bit.
    (loop for part in subtree
	  for len = (length (symbol-value-by-name table part))
	  for start = (raam-valence raam) then (1+ end)
	  for end = (1- (+ start len))
	  for i = 0 then (+ i 1)
	  while ok do
       (setf ok
	     (or (and (non-terminal? part)
		      (> half (extract-test actv i))
		      (verify-subtree raam part
				      (extract-symbol actv start end)))
		 (and (terminal? part)
		      (< half (extract-test actv i))
		      (verify-terminal raam part
				       (extract-symbol actv start end))))))
    ok))

;; Check that the decoded vector CODE matches TERM's stored representation.
;; Every element must differ from the stored value by strictly less than
;; RAAM-TERM-TOL.  Returns T on a match and NIL otherwise (T when the stored
;; representation is empty).
(defun verify-terminal (raam term code)
  (let ((termrep (symbol-value-by-name (raam-symbol-table raam) term))
	(tol (raam-term-tol raam)))
    ;; ALWAYS yields T/NIL portably.  The value of an iteration variable
    ;; inside a FINALLY clause after normal termination is not.
    (loop for i from 0 below (length termrep)
	  always (< (abs (- (aref termrep i) (aref code i))) tol))))

;; Return the terminal-test activation for the POS-th child, i.e. element
;; POS of the one-dimensional activation vector ACTV.
(defun extract-test (actv pos)
  (row-major-aref actv pos))

;; Return a fresh vector of element type *ELEMENT-TYPE* holding elements
;; START..END (inclusive) of ACTV, i.e. one child's code.
(defun extract-symbol (actv start end)
  (let ((piece (make-array (1+ (- end start))
			   :element-type *element-type*)))
    (replace piece actv :start2 start :end2 (1+ end))))

;;; Two possible additions for raams:
;;;
;;; 1) Allow extra hidden units which are not part of the representation
;;; level.  Thus we could have a 10-7-10 raam where only 5 units of the middle
;;; layer are used.  Reconstruction would be a problem.  We could simply place
;;; 0.5 in each unused position, or save the whole representation and use
;;; only rep-width units for encoding but all of them for decoding.
;;;
;;; 2) (greg and pete's) Allow non-terminal reps to be specified for the
;;; non-terminals.  When a non-terminal rep is specified, training is done as
;;; for two separate networks: the decoder and encoder are trained distinctly.
;;;
