;;; -*- Mode:Common-Lisp; Package:USER; Base:10 -*-


;;============================================================================
;; EXTENTRON ALGORITHM
;;
;; Copyright (c) 1992 by Paul Baffes and John Zelle. This program may be 
;; freely copied, used, or modified provided that this copyright notice is
;; included in each copy of this code and parts thereof.
;;============================================================================


;;----------------------------------------------------------------------------
;; TRACING PARAMETERS AND ROUTINES
;;----------------------------------------------------------------------------
;; for tracing

;; Request fast compiled code. DECLAIM is used instead of a top-level
;; PROCLAIM because a PROCLAIM form is only evaluated when the compiled file
;; is loaded, so it would not affect COMPILE-FILE's treatment of the
;; definitions below; DECLAIM is also processed at compile time.
(declaim (optimize (speed 3) (compilation-speed 0)))

;; Switch for the per-layer progress report printed by EXTENTRON.
(defvar *trace-extentron* nil)
;; Set by TRACE-EXTENTRON / UNTRACE-EXTENTRON; no routine visible in this
;; file reads it.
(defvar *trace-weights* nil)
;; When non-nil, TRAIN-LAYER-PERCEPTRON and TRAIN-NODE print per-node and
;; per-epoch progress markers.
(defvar *trace-percep* nil)

(defun trace-extentron (&optional (trace-percep nil) (trace-weights nil))
  ;; Turn on EXTENTRON tracing. TRACE-PERCEP also enables the perceptron
  ;; training trace; TRACE-WEIGHTS is stored in *trace-weights*.
  ;; Returns the value of TRACE-WEIGHTS.
  (setf *trace-extentron* t
        *trace-percep*    trace-percep
        *trace-weights*   trace-weights))

(defun untrace-extentron ()
  ;; Turn every tracing switch off again. Returns NIL.
  (setf *trace-extentron* nil
        *trace-percep*    nil
        *trace-weights*   nil))

(defmacro trace-format (test-var &rest format-form)
  ;; Expands into a FORMAT to standard output that runs only when TEST-VAR
  ;; (usually one of the *trace-...* switches) is non-nil. FORMAT-FORM is
  ;; the control string followed by its arguments; those arguments are not
  ;; evaluated when TEST-VAR is nil. The expansion evaluates to NIL.
  `(when ,test-var
     (format t ,@format-form)))


;;----------------------------------------------------------------------------
;; STATISTICS PARAMETERS
;;----------------------------------------------------------------------------
;; Running count of training epochs: reset to 0 by INIT-GLOBAL-VARIABLES,
;; incremented once per epoch in TRAIN-NODE, and reported by EXTENTRON.
(defvar *num-epochs* 0) 


;;----------------------------------------------------------------------------
;; UTILITIES
;;----------------------------------------------------------------------------

(defun square (n)
  ;; Return N multiplied by itself.
  (* n n))

(defun mix-up (list)
  ;;--------------------------------------------------------------------------
  ;; Cheaply randomize LIST in place and return it. We walk down the list once,
  ;; looking at each pair of neighbouring cells; when (random 1.0) comes out
  ;; >= 0.75 the two contents are exchanged. Because an element that has just
  ;; moved forward can be swapped again on the next step, it may drift several
  ;; positions toward the end in one pass. Only the CARs change; the list
  ;; structure itself is untouched.
  ;;--------------------------------------------------------------------------
  (loop for cell on list
        while (cdr cell)
        when (>= (random 1.0) 0.75)
          do (rotatef (car cell) (cadr cell)))
  list)


;;----------------------------------------------------------------------------
;; ACCESS MACROS
;;----------------------------------------------------------------------------

;;----------------------------------------------------------------------------
;; NODES
;;
;; Each node is an integer linear threshold unit (LTU) with a threshold and a
;; list of weights. A node is a LIST OF NUMBERS: first the threshold, then the
;; weights.
;;----------------------------------------------------------------------------
(defmacro node-threshold (node)
  ;; The threshold is the first element of a node list (SETF-able).
  `(first ,node))

(defmacro node-weights (node)
  ;; The weights are everything after the threshold (SETF-able).
  `(rest ,node))


;;----------------------------------------------------------------------------
;; IO-PAIRS
;;
;; An io-pairs value is a list of examples. Each example is a list of two
;; lists: one holding all of the example's outputs and one holding all of its
;; inputs. Specifically, the first list is the list of outputs and the second
;; is the list of inputs.
;;----------------------------------------------------------------------------
(defmacro example-input (example)
  ;; Inputs are the SECOND element of an example (SETF-able).
  ;; Formerly (car example); changed for Univ. Test.
  `(second ,example))

(defmacro example-output (example)
  ;; Outputs are the FIRST element of an example (SETF-able).
  ;; Formerly (cdr example); changed for Univ. Test.
  `(first ,example))


;;----------------------------------------------------------------------------
;; LAYERS
;;
;; Now that we only pass through particular node values when we propagate,
;; each layer has become more complicated (ptb 11/26). Now, a layer is a list
;; of two lists (rather than just one list as before): first is a list of
;; which nodes should pass through to the next layer, and second is a list of
;; the nodes for the layer.
;;
;; NOTE: by convention we use "layer" as a variable representing an entire
;; layer, "nodes" for just the nodes list, and "passes" for the pass-through
;; list.
;;----------------------------------------------------------------------------
(defmacro layer-pass-through (layer)
  ;; First element of a layer: one flag per node, non-nil when that node's
  ;; output is passed through to the next layer (SETF-able).
  `(first ,layer))

(defmacro layer-nodes (layer)
  ;; Second element of a layer: the list of its nodes (SETF-able).
  `(second ,layer))


;;----------------------------------------------------------------------------
;; GLOBAL VARIABLES FOR EXTENTRON
;;---------------------------------------------------------------------------- 

;; One count per output node: the best number of io-pairs that node has
;; classified correctly so far. Zeroed by INIT-GLOBAL-VARIABLES, raised in
;; place by TRAIN-NODE, and compared with *num-pairs* by CREATE-NEXT-LAYER
;; and TRAIN-LAYER-PERCEPTRON to spot nodes that are already perfect.
(defvar *node-accuracy-counts* nil) ;; keep track of # correct at each node

;; One count per output node: how many io-pairs have a desired 1 at that
;; output. Filled in by INIT-GLOBAL-VARIABLES; used by TEST-NODE.
(defvar *node-purity-counts* nil) ;; total "1's" for node in io-pairs

;; Scratch output list, one cell per output, allocated once by
;; INIT-GLOBAL-VARIABLES and overwritten by TEST-LAYER for every pair so
;; testing does not cons.
(defvar *temp-output* nil) ;; See the routine "test-layer"

(defvar *num-pairs* 0)     ;; total number of io-pairs being trained on

;; TRAIN-NODE stops once it has accepted a split and then gone more than
;; this many epochs without improving on it.
(defvar *magic-number* 30) ;; See the routine "train-node"


;;----------------------------------------------------------------------------
;; EXTENTRON Code starts here
;;----------------------------------------------------------------------------

(defun extentron (io-pairs bit-threshold pair-threshold)
  ;;--------------------------------------------------------------------------
  ;; Takes io-pairs to be learned, and two "stopping" criteria: the accuracy
  ;; for bits and the accuracy for pairs. To get "perfect" training, you set
  ;; the bit-threshold to 1 (ie, 100%) and the pair-threshold to 1.
  ;;
  ;; IO-PAIRS itself is not modified. Training works on a COPY-TREE of it,
  ;; whose inputs are extended after every layer with the passed-through
  ;; outputs of that layer. Returns the network: a list of layers in the order
  ;; they were created. Both accuracies computed by TEST-LAYER are fractions
  ;; of at most 1, so a threshold above 1 is never met and the loop does not
  ;; terminate.
  ;;
  ;; Modified 11/24/91 to set up a single global output for testing layers
  ;;--------------------------------------------------------------------------
  (let* ((pairs (copy-tree io-pairs))
         (layer-count 1)
         (num-outputs (if pairs (length (example-output (car pairs))) 1))
         network layer bit-results pair-results)

    ;; initialize global variables
    (init-global-variables io-pairs num-outputs)

    ;; create an initial layer and train it. Then test the accuracy
    (setf layer (create-zero-layer (car pairs)))
    (train-layer-perceptron layer pairs)
    (setf network (list layer))
    (multiple-value-setq (bit-results pair-results) (test-layer layer pairs))

    ;; repeat until we get the desired accuracy
    (do ()
        ((and (>= bit-results bit-threshold) (>= pair-results pair-threshold)))
      (incf layer-count)
      (trace-format *trace-extentron* "~%======================")
      (trace-format *trace-extentron*
                    "~%bit accuracy:     ~F~%pattern accuracy:~F"
                    bit-results pair-results)
      (trace-format *trace-extentron* "~%node accuracies ~A~%"
                    *node-accuracy-counts*)
      (trace-format *trace-extentron* "~%making layer ~3D" layer-count)

      ;; extend the io-pairs based upon the last layer. NCONC's result is
      ;; stored back into the example because NCONC cannot splice onto an
      ;; empty input list; in that case it just returns the new outputs.
      (dolist (p pairs)
        (setf (example-input p)
              (nconc (example-input p)
                     (propagate-and-filter layer (example-input p)))))

      ;; create a new layer from the last layer, train it, test it.
      (setf layer (create-next-layer layer))
      (train-layer-perceptron layer pairs)
      (setf network (nconc network (list layer)))
      (multiple-value-setq (bit-results pair-results)
        (test-layer layer pairs)))

    ;; print results and return network.
    (trace-format *trace-extentron*
                  "~%~%~%bit accuracy:  ~F~%pattern accuracy: ~F"
                  bit-results pair-results)
    (trace-format *trace-extentron* "~%Total Layers ~3D" layer-count)
    (trace-format *trace-extentron* "~%Total Epochs: ~6,2F"
                  (/ *num-epochs* num-outputs))
    network))
      

(defun init-global-variables (io-pairs num-outputs)
  ;; Reset the training globals for a new run over IO-PAIRS:
  ;;   *num-epochs*  -> 0
  ;;   *num-pairs*   -> number of io-pairs
  ;;   *temp-output*, *node-accuracy-counts*, *node-purity-counts*
  ;;                 -> fresh lists of NUM-OUTPUTS zeros
  ;; Then count into *node-purity-counts*, for each output position, the
  ;; io-pairs whose desired output there is 1. Returns NIL.
  (setf *num-epochs* 0
        *num-pairs* (length io-pairs)
        *temp-output* (make-list num-outputs :initial-element 0)
        *node-accuracy-counts* (make-list num-outputs :initial-element 0)
        *node-purity-counts* (make-list num-outputs :initial-element 0))
  (dolist (pair io-pairs)
    ;; walk the counts; the desired bits are stepped alongside them
    (loop for cell on *node-purity-counts*
          for bits = (example-output pair) then (cdr bits)
          when (= (car bits) 1)
            do (incf (car cell)))))


(defun propagate-and-filter (layer input-vector &optional (use-threshold t))
  ;;--------------------------------------------------------------------------
  ;; Propagate INPUT-VECTOR through LAYER, keeping only the outputs of nodes
  ;; whose pass-through flag is set (ptb 11/26). This is used only by
  ;; "extentron", to extend the io-pairs with the previous layer's outputs.
  ;; With USE-THRESHOLD (the default) each kept output is 1 or 0; otherwise
  ;; it is the raw weighted sum. Returns a fresh list in node order.
  ;;--------------------------------------------------------------------------
  (loop for node in (layer-nodes layer)
        for pass in (layer-pass-through layer)
        when pass
          collect (let ((activation (compute-sum node input-vector)))
                    (cond ((not use-threshold) activation)
                          ((>= activation (node-threshold node)) 1)
                          (t 0)))))



(defun propagate-into (layer input-vector output-vector
                       &optional (use-threshold t))
  ;;--------------------------------------------------------------------------
  ;; Propagate INPUT-VECTOR through every node of LAYER and destructively store
  ;; node i's output in the i-th cell of OUTPUT-VECTOR, which is returned.
  ;; With USE-THRESHOLD (the default) an output is 1 when the weighted sum
  ;; reaches the node's threshold and 0 otherwise; without it the raw sum is
  ;; stored (added 11/27 by ptb for use in "test-extentron").
  ;;--------------------------------------------------------------------------
  (loop for node in (layer-nodes layer)
        for cell = output-vector then (cdr cell)
        do (let ((activation (compute-sum node input-vector)))
             (setf (car cell)
                   (cond ((not use-threshold) activation)
                         ((>= activation (node-threshold node)) 1)
                         (t 0)))))
  output-vector)


(defun compute-sum (node input-vector)
  ;; Weighted sum of INPUT-VECTOR against NODE's weights. The threshold is not
  ;; included; pairing stops at the shorter of the two lists.
  (loop for weight in (node-weights node)
        for input in input-vector
        sum (* weight input)))


(defun create-zero-layer (io-pair)
  ;;--------------------------------------------------------------------------
  ;; Build the first layer for examples shaped like IO-PAIR (modified ptb 11/26
  ;; for the pass-through layer structure; see the layer macros above).
  ;; A layer is a list of two things: the pass-through vector (first) and
  ;; the list of actual nodes (second). This is layer 0 (you're welcome,
  ;; Dijkstra), so every output passes through and the pass-through vector
  ;; is all T. There is one node per output, each a fresh list holding a
  ;; threshold of 0 followed by one 0 weight per input.
  ;;--------------------------------------------------------------------------
  (let ((input-vector (example-input io-pair))
        (outputs (example-output io-pair)))
    (list
     ;; every output of the first layer is passed through
     (make-list (length outputs) :initial-element t)
     ;; one node per output. Each node gets its own freshly consed list,
     ;; because training later modifies node weights in place.
     (mapcar #'(lambda (output)
                 (declare (ignore output))
                 (cons 0 (make-list (length input-vector) :initial-element 0)))
             outputs))))


(defun create-next-layer (prev-layer)
  ;;--------------------------------------------------------------------------
  ;; Build a new layer as a copy of PREV-LAYER (PREV-LAYER is not modified).
  ;;
  ;; Modified 11/26/91 (ptb): this routine can no longer assume that all
  ;; outputs of the previous layer were passed through. The io-pairs may have
  ;; been extended with only some of them (see "propagate-and-filter" above),
  ;; so each node is extended by one zero weight per output whose flag in the
  ;; previous layer's pass-through section is set.
  ;;
  ;; The new layer's pass-through section is then updated: any node whose
  ;; count in *node-accuracy-counts* equals *num-pairs* (already perfect) is
  ;; no longer passed through. A flag that is already NIL stays NIL.
  ;;--------------------------------------------------------------------------
  (let ((layer (copy-tree prev-layer))
        (num-new-weights (count-if #'identity
                                   (layer-pass-through prev-layer))))

    ;; extend every node with fresh zero weights for the passed-through
    ;; outputs. NCONC's result is stored back because NCONC cannot splice
    ;; onto an empty weight list; the new weights would otherwise be lost.
    (dolist (node (layer-nodes layer))
      (setf (node-weights node)
            (nconc (node-weights node)
                   (make-list num-new-weights :initial-element 0))))

    ;; update the pass-through section of the new layer to reflect the latest
    ;; accuracy counts.
    (do ((counts *node-accuracy-counts* (cdr counts))
         (passes (layer-pass-through layer) (cdr passes)))
        ((null counts) t)
      (if (= (car counts) *num-pairs*)
          (setf (car passes) nil)))

    ;; return the new layer
    layer))


(defun train-layer-perceptron (layer io-pairs)
  ;;--------------------------------------------------------------------------
  ;; Train every node of LAYER in place on IO-PAIRS, skipping any node whose
  ;; count in *node-accuracy-counts* already equals *num-pairs*. The
  ;; "pass-through" section of the layer plays no part here (ptb 11/26).
  ;; Returns T.
  ;;--------------------------------------------------------------------------
  ;; TEMP-NODE is scratch space that TRAIN-NODE uses to remember the best
  ;; weights it has seen: one cell for the threshold plus one per input.
  ;; It is allocated once here and shared by all the nodes of the layer.
  (let ((temp-node (make-list (1+ (length (example-input (car io-pairs))))
                              :initial-element 0)))
    (loop for node in (layer-nodes layer)
          for counts = *node-accuracy-counts* then (cdr counts)
          for node-index from 0
          do (trace-format *trace-percep* "~%Training node: ~D~%" node-index)
             ;; COUNTS is handed to TRAIN-NODE so it can update this node's
             ;; count in place. Counts carry over from the layer this one
             ;; was copied from, and TRAIN-NODE only ever replaces a count
             ;; with a larger (or perfect) one.
             (unless (= (car counts) *num-pairs*)
               (train-node node counts io-pairs node-index temp-node)))
    t))

;(defvar *show-all* nil)

(defun train-node (node ptr-to-count io-pairs output-index temp-node)
  ;;--------------------------------------------------------------------------
  ;; Train one NODE in place on IO-PAIRS for output position OUTPUT-INDEX.
  ;; TEMP-NODE is a caller-supplied scratch list, expected to be as long as
  ;; NODE. It is passed in so as not to garbagify, and holds a copy of the
  ;; best weights seen so far.
  ;;
  ;; PTR-TO-COUNT is the cons whose CAR is the accuracy (number of io-pairs
  ;; classified correctly) that must be beaten for this node. That count is
  ;; passed on from one layer to the next, so each layer must do at least as
  ;; well on the node as the previous one. New weights are accepted only if
  ;; they classify every pair correctly, or if TEST-NODE reports a split
  ;; better than naive guessing AND the count beats the best so far. On exit
  ;; the CAR of PTR-TO-COUNT holds the best count found. Returns T.
  ;;
  ;; 11/25/91: revise to use apportionment. First, we set max-correct to be
  ;; the initial value of the last layer's value for the node (since we know
  ;; we must beat it to converge).
  ;;
  ;; NOTE: the loop ends only when every pair is correct, or when an accepted
  ;; split has gone more than *magic-number* epochs without being improved.
  ;; If no split is ever accepted, it keeps training.
  ;;--------------------------------------------------------------------------
  (let ((max-correct (car ptr-to-count)) ;; best count so far for this node
        (cycle-count 0)
        (node-purity (nth output-index *node-purity-counts*))
        better-split split-p num-correct)

    (loop (if (or (= max-correct *num-pairs*)
                  (and better-split (> cycle-count *magic-number*)))
              (return))
          (if *trace-percep*
              (if better-split
                  (format t "~D," cycle-count) (format t ".")))

          ;; train and test the node for one epoch (that is, one look at all
          ;; the IO pairs). We randomize here for EACH epoch.
          (train-node-epoch node (mix-up io-pairs) output-index)
          (incf *num-epochs*)

          ;; test the node, setting "split-p" if a split is returned that is
          ;; better than naive guessing.
          (multiple-value-setq (split-p num-correct)
            (test-node node io-pairs output-index node-purity))

          ;; on an acceptable result, record the new best count and snapshot
          ;; the weights into TEMP-NODE. Resetting "cycle-count" gives at
          ;; least "magic-number" more epochs to find something better still.
          (if (or (= num-correct *num-pairs*)
                  (and split-p (> num-correct max-correct)))
              (progn
                (setf max-correct num-correct)
                (setf better-split t)
                (replace temp-node node)
                (setf cycle-count 0))
              (incf cycle-count)))

    ;; update the counts value for this node
    (setf (car ptr-to-count) max-correct)

    ;; Finally, set the node to the best weights we found, but only if a
    ;; snapshot was actually taken. If the count was already perfect on entry,
    ;; the loop exits at once and TEMP-NODE still holds its initial zeros or
    ;; whatever an earlier call (possibly for another node) left there.
    (when better-split
      (replace node temp-node))
    t))


(defun train-node-epoch (node io-pairs output-index)
  ;; One epoch of perceptron training for output OUTPUT-INDEX. Each io-pair
  ;; in turn is propagated through NODE. When the result differs from the
  ;; desired bit, the threshold and every weight paired with a non-zero input
  ;; move one step toward the desired answer:
  ;;   - desired bit non-zero: threshold down, those weights up;
  ;;   - desired bit zero: the reverse.
  ;; NODE is modified in place. Returns NIL.
  (dolist (pair io-pairs)
    (let* ((inputs (example-input pair))
           (target (nth output-index (example-output pair))))
      (unless (= target (propagate-node node inputs))
        ;; -1 when a 0 was wanted, +1 otherwise
        (let ((delta (if (zerop target) -1 1)))
          (decf (node-threshold node) delta)
          (loop for cell on (node-weights node)
                for rest-inputs = inputs then (cdr rest-inputs)
                unless (zerop (car rest-inputs))
                  do (incf (car cell) delta)))))))

(defun propagate-node (node input-vector)
  ;; Threshold output of a single NODE on INPUT-VECTOR: 1 when the weighted
  ;; sum of the inputs reaches the node's threshold, otherwise 0.
  (let ((activation (loop for weight in (node-weights node)
                          for input in input-vector
                          sum (* weight input))))
    (if (>= activation (node-threshold node)) 1 0)))


(defun test-node (node io-pairs output-index output-purity)
  ;;---------------------------------------------------------------------------
  ;; Score NODE on IO-PAIRS for output position OUTPUT-INDEX. Returns two
  ;; values:
  ;;   1. T when the node's split is better than naive guessing (see below),
  ;;      NIL otherwise;
  ;;   2. the number of io-pairs whose desired bit the node reproduces.
  ;;
  ;; Modified 11/26 to add a notion of purity (see the *node-purity-counts*
  ;; global variable). OUTPUT-PURITY is that purity value for this node, i.e.
  ;; the number of io-pairs whose desired bit is 1. The caller passes it in
  ;; to save doing "nth" for the same node on every epoch.
  ;;
  ;; A split is "good" when two conditions hold:
  ;;   (1) at least one "1" is predicted correctly;
  ;;   (2) the node's predictive accuracy beats naive guessing.
  ;; Predictive accuracy is measured on the "1" outputs only; with just two
  ;; possible values, the "0s" mirror the "1s":
  ;;
  ;;         #-correctly-generated-1s
  ;;         ------------------------
  ;;         total-#-of-predicted-1s
  ;;
  ;; With no training at all, a node that always answers "1" achieves
  ;;
  ;;         #-of-expected-1s-for-node
  ;;         -------------------------
  ;;            total-#-of-io-pairs
  ;;
  ;; So we require the first ratio to be strictly greater than this naive
  ;; "always-generate-a-1" ratio. Cross-multiplying keeps the test in
  ;; integers:
  ;;
  ;; (#-correctly-generated-1s * total-#-of-io-pairs) >
  ;;      (total-#-of-predicted-1s * #-of-expected-1s-for-node)
  ;;
  ;; Condition (1) needs no separate test. With no correctly generated "1s",
  ;; the left-hand side is 0, which cannot exceed the non-negative right side.
  ;;---------------------------------------------------------------------------
  (let ((num-correct 0)
        (correct-1s 0)
        (predicted-1s 0))
    (dolist (pair io-pairs)
      (let* ((inputs (example-input pair))
             (target (nth output-index (example-output pair)))
             (out (propagate-node node inputs)))
        (when (= out 1)
          (incf predicted-1s)
          (when (= target 1)
            (incf correct-1s)))
        (when (= target out)
          (incf num-correct))))
    (values (> (* correct-1s *num-pairs*)
               (* predicted-1s output-purity))
            num-correct)))


(defun test-layer (layer pairs)
  ;;---------------------------------------------------------------------------
  ;; Takes a layer and pairs, and returns two fractions (0 to 1):
  ;;    the fraction of bits that are correctly output and the fraction of
  ;;    pairs that are correctly output.
  ;; Modified 11/18/91: had to change this to return 100 percent accuracy in
  ;;    the event that there are no pairs passed in for testing.
  ;; Modified 11/24/91: All calculated outputs are saved in the global list
  ;;    *temp-output*, which is assumed to be the appropriate length. This
  ;;    is a cons saving hack.
  ;; Checked 11/26/91: note that the addition of the "pass-through" information
  ;;    does not affect this routine (ptb).
  ;; Revised: both fractions are now taken over the pairs actually passed in.
  ;;    The old code used *num-pairs*, which is only correct for the training
  ;;    set. A pair now counts as correct exactly when its lengths agree and
  ;;    every bit passes the same numeric = test used for bit accuracy. The
  ;;    old code used EQUAL, which disagrees with = for e.g. a desired 1.0.
  ;;---------------------------------------------------------------------------
  (if (null pairs)
      (values 1.0 1.0)
      (let ((bit-count 0)
            (pattern-count 0)
            (num-pairs (length pairs))
            (output *temp-output*))
        (dolist (p pairs)
          (propagate-into layer (example-input p) output)
          (let* ((desired (example-output p))
                 (pattern-correct (= (length output) (length desired))))
            (mapc #'(lambda (output-bit desired-bit)
                      (if (= output-bit desired-bit)
                          (incf bit-count)
                          (setf pattern-correct nil)))
                  output
                  desired)
            (when pattern-correct (incf pattern-count))))
        (values (/ bit-count
                   (* (length (example-output (car pairs))) num-pairs))
                (/ pattern-count num-pairs)))))



;; Small demonstration problems in io-pairs format: each example is
;; (outputs inputs).
;;
;; DEFPARAMETER declares the variables, which a bare top-level SETQ does not.
;; Each list is built with COPY-TREE so its conses are freshly allocated
;; rather than quoted constants. TRAIN-NODE reorders the list it trains on in
;; place (via MIX-UP), and modifying a quoted literal is undefined behavior.
;;
;;   and-pairs:   two outputs, (AND OR) of the two inputs
;;   xor-pairs:   one output, XOR of the two inputs
;;   equiv-pairs: one output, EQUIV (XNOR) of the two inputs
;;   devil-pairs: two outputs, (XOR EQUIV) of the two inputs
(defparameter and-pairs
  (copy-tree '(((0 0) (0 0)) ((1 1) (1 1)) ((0 1) (1 0)) ((0 1) (0 1)))))

(defparameter xor-pairs
  (copy-tree '(((0) (0 0)) ((0) (1 1)) ((1) (0 1)) ((1) (1 0)))))

(defparameter equiv-pairs
  (copy-tree '(((1) (0 0)) ((1) (1 1)) ((0) (1 0)) ((0) (0 1)))))

(defparameter devil-pairs
  (copy-tree '(((0 1) (0 0)) ((0 1) (1 1))
               ((1 0) (0 1)) ((1 0) (1 0)))))
