;;; -*- Mode: LISP; Package: NP; Syntax: Common-lisp;                    -*-
;;;
;;; ************************************************************************
;;;
;;; PORTABLE AI LAB - UNI ZH
;;;
;;; ************************************************************************
;;;
;;; Filename:   np-bp.cl
;;; Short Desc: Backprop Nets
;;; Version:    1.0
;;; Status:     Experimental (July 1990)
;;; Last Mod:   27.1.92 - TW
;;; Author:     Mathias Gutknecht & Thomas Wehrle
;;;
;;;
;;;

;;;
;;; Copyright (c) 1992 Istituto Dalle Molle (IDSIA), University of
;;; Zurich, Swiss Federal Institute of Technology Lausanne.
;;;
;;; Permission is granted to any individual or institution to use, copy,
;;; modify, and distribute this software, provided that this complete
;;; copyright and permission notice is maintained, intact, in all
;;; copies and supporting documentation.
;;;
;;; IDSIA provides this software "as is" without express or implied
;;; warranty.  
;;;





(in-package #:np)

;; This file currently exports no symbols; the EXPORT form is kept as the
;; place to add them.
(export nil)


;;; BP-NET-CLASS -- root class of back-propagation nets.
;;;   L-RATE    factor applied to a connection's accumulated
;;;             WEIGHT-ERROR-DERIVATIVE in CHANGE-WEIGHTS (default 0.5).
;;;   MOMENTUM  factor applied to a connection's previous WEIGHT-DELTA in
;;;             CHANGE-WEIGHTS (default 0.9).
(defclass bp-net-class (feed-forward-net-class)
  ((l-rate   :accessor l-rate
             :initarg  :l-rate
             :type     real
             :initform 0.5)
   (momentum :accessor momentum
             :initarg  :momentum
             :type     real
             :initform 0.9))
  (:documentation "Foundation of all BackPropagation nets"))


;;; BP-NODE-CLASS -- node carrying the per-pattern back-propagation state.
;;;   NODE-ERROR  error signal of the node.  For output nodes it is
;;;               target - activation; for other nodes it is the weighted sum
;;;               of the deltas of the nodes they feed (see COMPUTE-NET-ERROR).
;;;   NODE-DELTA  NODE-ERROR times the node's activation derivative.
(defclass bp-node-class (node-class)
  ((node-error :accessor node-error
               :initarg  :node-error
               :type     real
               :initform 0)
   (node-delta :accessor node-delta
               :initarg  :node-delta
               :type     real
               :initform 0))
  (:documentation "Foundation of all BackPropagation nodes"))


;;; BP-CONNECTION-CLASS -- connection with back-propagation bookkeeping.
;;;   WEIGHT-ERROR-DERIVATIVE  sum, over the patterns of one batch, of
;;;                            (receiving node's delta * sending node's
;;;                            activation); CHANGE-WEIGHTS resets it to 0.
;;;   WEIGHT-DELTA             last change applied to the weight, reused as
;;;                            the momentum term of the next change.
(defclass bp-connection-class (connection-class)
  ((weight-error-derivative :accessor weight-error-derivative
                            :initarg  :weight-error-derivative
                            :type     real
                            :initform 0)
   (weight-delta            :accessor weight-delta
                            :initarg  :weight-delta
                            :type     real
                            :initform 0))
  (:documentation "Foundation of all BackPropagation Connections"))


;;; BP-MAKE-NET -- build a layered back-propagation net named NET-NAME.
;;; OUTPUT-TO-INPUT-LAYER-DEFINITION-LIST holds one definition per layer,
;;; ordered from the output layer to the input layer.  A definition is a list
;;; whose FIRST element is the layer name and whose SECOND element is the
;;; number of ordinary nodes in that layer.
;;; The output layer gets exactly that many nodes.  Every other layer gets one
;;; extra node in front, the bias node.  It is never used as a connection
;;; target, so it has no input.
;;; Layers are wired pairwise with CONNECT-NODES-TO-NODES.  All nodes of the
;;; layer on the input side are passed as sources.  The targets are the
;;; non-bias nodes of the layer on the output side; the output layer has no
;;; bias, so all of its nodes are targets.
;;; Returns T (not the net).
(defun bp-make-net (net-name output-to-input-layer-definition-list)
  (let* ((layer-defs output-to-input-layer-definition-list)
         (net (make-net net-name 'bp-net-class (mapcar #'first layer-defs)))
         (output-def (first layer-defs))
         (input-def (first (last layer-defs))))

    ;; Output layer: no bias node.
    (make-node-seq net 'bp-node-class
                   (first output-def) (second output-def)
                   :activation 1)

    ;; Input layer: requested size plus one bias node.
    (make-node-seq net 'bp-node-class
                   (first input-def) (1+ (second input-def))
                   :activation 1 :output 1)

    ;; Hidden layers (everything strictly between output and input).
    (dolist (hidden-def (rest (butlast layer-defs)))
      (make-node-seq net 'bp-node-class
                     (first hidden-def) (1+ (second hidden-def))
                     :activation 1 :output 1))

    (let ((layer-names (get-node-seq-names net)))
      ;; The layer directly below the output feeds every output node.
      (connect-nodes-to-nodes 'bp-connection-class
                              (get-node-seq net (second layer-names))
                              (get-node-seq net (first layer-names)))
      ;; Each remaining adjacent pair: the lower layer feeds the upper layer,
      ;; skipping the upper layer's first node (bias node has no input).
      (loop for tail on (rest layer-names)
            for upper-name = (first tail)
            for lower-name = (second tail)
            while lower-name
            do (connect-nodes-to-nodes 'bp-connection-class
                                       (get-node-seq net lower-name)
                                       (rest (get-node-seq net upper-name)))))
    t))


;;; UPDATE-NET -- forward propagation of activation.
;;; OUTPUT-TO-INPUT-NODE-SEQ-NAME-LIST names the layers ordered from the output
;;; layer to the input layer (default: the net's own layer names).
;;; The input layer is not touched here; BP-FEED-IN loads it.  Each hidden
;;; layer is passed to SYNCH-UPDATE without its first (bias) node.  The output
;;; layer is passed whole and updated last.  Returns the value of that final
;;; SYNCH-UPDATE call.
(defmethod update-net ((net bp-net-class) &optional
			  (output-to-input-node-seq-name-list
			   (get-node-seq-names net)))
  ;; BP-MAKE-NET connects each hidden layer from the layer on its input side.
  ;; The name list runs output -> input, so the hidden layers must be visited
  ;; in REVERSE, starting next to the input.  Otherwise, with two or more
  ;; hidden layers, an upper layer would be recomputed from outputs of the
  ;; layer below that still belong to the previous pattern.
  (dolist (hidden-layer-name
	   (reverse (cdr (butlast output-to-input-node-seq-name-list))))
    (synch-update (cdr (get-node-seq net hidden-layer-name))))
  (synch-update (get-node-seq net (first output-to-input-node-seq-name-list))))
  

;;; BP-FEED-IN -- load an input pattern into a layer.
;;; The values of ACTIVATION-LIST are assigned, in order, as the ACTIVATION of
;;; the nodes of NODE-SEQ after its first node.  The first (bias) node is left
;;; untouched.  Each loaded node's OUTPUT is then set from COMPUTE-OUTPUT.
;;; Stops when either list runs out, so surplus values or nodes are ignored.
;;; Returns NIL.
(defun bp-feed-in (activation-list node-seq)
  (loop for value in activation-list
        for node in (rest node-seq)
        do (setf (activation node) value
                 (output node) (compute-output node))))


;;; COMPUTE-NET-ERROR -- back-propagate the error for one target pattern.
;;; TARGET-LIST holds one target value per output node, in node order.
;;; OUTPUT-TO-INPUT-NODE-SEQ-NAME-LIST orders the layers output -> input
;;; (default: the net's own layer names).  Because it is an &OPTIONAL argument
;;; placed before &KEY, callers must pass it positionally whenever they want
;;; to give :ERROR-TOLERANCE.
;;; Steps:
;;;   1. zero NODE-ERROR in the output and hidden layers;
;;;   2. set each output node's NODE-ERROR to target - activation;
;;;   3. walking from the output layer toward the input, set every node's
;;;      NODE-DELTA to NODE-ERROR * activation derivative, and add
;;;      delta * weight to the NODE-ERROR of each sending node.  The hidden
;;;      layer next to the input does not propagate, so input-layer nodes are
;;;      never written.
;;; Returns T if every output error has an absolute value <= ERROR-TOLERANCE,
;;; NIL otherwise.
(defmethod compute-net-error ((net bp-net-class) target-list
			     &optional
			     (output-to-input-node-seq-name-list
			      (get-node-seq-names net))
			     &key (error-tolerance 0))
  (let* ((output-to-hidden-node-seq-name-list
          (butlast output-to-input-node-seq-name-list))
         ;; LAST returns a cons; take its CAR to get the layer *name*.
         ;; Comparing a name with the cons itself never matched, so error
         ;; was always pushed into the input layer.  That layer is not in the
         ;; zeroing pass below, so its node-errors grew without bound.
         (last-hidden-layer-name (car (last output-to-hidden-node-seq-name-list)))
         (output-node-seq (get-node-seq net (first output-to-hidden-node-seq-name-list)))
         (errors-tolerable t))

    ;; 1. initialize error in output and hidden nodes to zero
    (dolist (node-seq-name output-to-hidden-node-seq-name-list)
      (dolist (node (get-node-seq net node-seq-name))
        (setf (node-error node) 0)))

    ;; 2. error of the output nodes; remember whether any exceeds tolerance
    (do* ((t-list target-list (cdr t-list))
          (t-value (first t-list) (first t-list))
          (o-node-list output-node-seq (cdr o-node-list))
          (o-node (first o-node-list) (first o-node-list)))
         ((null t-list))
      (when (> (abs (setf (node-error o-node)
                          (- t-value (activation o-node))))
               error-tolerance)
        (setf errors-tolerable nil)))

    ;; 3. deltas, and error propagation into the next layer toward the input
    (dolist (current-layer-name output-to-hidden-node-seq-name-list)
      (let ((propagate-p (not (equal current-layer-name last-hidden-layer-name))))
        (dolist (c-node (get-node-seq net current-layer-name))
          (setf (node-delta c-node) (* (node-error c-node)
                                       (compute-activation-derivative c-node)))
          (when propagate-p
            (dolist (c-conn (in-connections c-node))
              (incf (node-error (from-node c-conn))
                    (* (node-delta c-node) (weight c-conn))))))))
    errors-tolerable))
	

;;; COMPUTE-WEIGHT-ERROR-DERIV -- accumulate the weight error derivative.
;;; For every incoming connection of every node in the output and hidden
;;; layers, adds (receiving node's NODE-DELTA * sending node's ACTIVATION) to
;;; the connection's WEIGHT-ERROR-DERIVATIVE.  The sum builds up across
;;; patterns until CHANGE-WEIGHTS uses it and resets it to 0.
;;; The input layer (last name in the list) is skipped.  Returns NIL.
(defmethod compute-weight-error-deriv ((net bp-net-class)
                                      &optional
                                      (output-to-input-node-seq-name-list
                                       (get-node-seq-names net)))
  (loop for layer-name in (butlast output-to-input-node-seq-name-list)
        do (loop for node in (get-node-seq net layer-name)
                 do (loop for conn in (in-connections node)
                          do (incf (weight-error-derivative conn)
                                   (* (node-delta node)
                                      (activation (from-node conn))))))))


;;; CHANGE-WEIGHTS -- apply one weight update with momentum.
;;; For every incoming connection of the output and hidden layers:
;;;   delta  := L-RATE * weight-error-derivative + MOMENTUM * previous delta
;;;   weight := weight + delta
;;; The accumulated WEIGHT-ERROR-DERIVATIVE is then reset to 0 for the next
;;; batch.  Returns NIL.
(defmethod change-weights ((net bp-net-class)
                           &optional
                           (output-to-input-node-seq-name-list
                            (get-node-seq-names net)))
  (let ((rate (l-rate net))
        (mom (momentum net)))
    (dolist (layer-name (butlast output-to-input-node-seq-name-list))
      (dolist (node (get-node-seq net layer-name))
        (dolist (conn (in-connections node))
          (let ((delta (+ (* rate (weight-error-derivative conn))
                          (* mom (weight-delta conn)))))
            (setf (weight-delta conn) delta)
            (incf (weight conn) delta)
            (setf (weight-error-derivative conn) 0)))))))
	
 
;;; COMPUTE-NET-OUTPUT -- run one input pattern through the net.
;;; Loads INPUT-PATTERN into the input layer (last name in
;;; OUTPUT-TO-INPUT-NODE-SEQ-NAME-LIST), propagates forward with UPDATE-NET,
;;; and returns whatever FEED-OUT returns for the output layer (first name).
(defmethod compute-net-output ((net bp-net-class) input-pattern
			  &optional
			  (output-to-input-node-seq-name-list
			   (get-node-seq-names net)))
  (let* ((layer-names output-to-input-node-seq-name-list)
         (input-layer (get-node-seq net (car (last layer-names))))
         (output-layer (get-node-seq net (first layer-names))))
    (bp-feed-in input-pattern input-layer)
    (update-net net layer-names)
    (feed-out output-layer)))


;;; RESET-NET -- clear all training state but keep the weights.
;;; Zeroes NODE-ERROR and NODE-DELTA of every node in every layer, and
;;; WEIGHT-ERROR-DERIVATIVE and WEIGHT-DELTA of every incoming connection.
;;; Returns NIL.
(defmethod reset-net ((net bp-net-class))
  (dolist (layer-name (get-node-seq-names net))
    (dolist (node (get-node-seq net layer-name))
      (setf (node-error node) 0)
      (setf (node-delta node) 0)
      (dolist (conn (in-connections node))
        (setf (weight-error-derivative conn) 0)
        (setf (weight-delta conn) 0)))))


;;; REINIT-NET -- like RESET-NET, but also re-randomizes every weight.
;;; Each incoming connection gets a new weight from SET-RANDOM-WEIGHT with
;;; bounds -1 and 1 (presumably a uniform draw in [-1, 1]; verify against its
;;; definition).  Its WEIGHT-ERROR-DERIVATIVE and WEIGHT-DELTA are then zeroed.
;;; Returns NIL.
(defmethod reinit-net ((net bp-net-class))
  (loop for layer-name in (get-node-seq-names net)
        do (loop for node in (get-node-seq net layer-name)
                 do (setf (node-error node) 0
                          (node-delta node) 0)
                    (loop for conn in (in-connections node)
                          do (set-random-weight conn -1 1)
                             (setf (weight-error-derivative conn) 0
                                   (weight-delta conn) 0)))))

  
;;; LEARN-PATTERNS -- one batch (epoch) of back-propagation training.
;;; INPUT-TARGET-PATTERN-LIST is a list of (INPUT-PATTERN TARGET-PATTERN)
;;; pairs.  For each pair the input is fed in, propagated forward, its error
;;; back-propagated, and the weight error derivatives accumulated.  The
;;; weights are changed only once, after the whole batch.
;;; Returns T if every pattern's output errors were within ERROR-TOLERANCE
;;; (measured before the weight change), NIL otherwise.
(defmethod learn-patterns ((net bp-net-class) input-target-pattern-list
                           &optional
                           (output-to-input-node-seq-name-list
                            (get-node-seq-names net))
                           &key (error-tolerance 0))
  (let ((input-layer (get-node-seq net (car (last output-to-input-node-seq-name-list))))
        (all-tolerable t))
    (dolist (pair input-target-pattern-list)
      (bp-feed-in (first pair) input-layer)
      (update-net net output-to-input-node-seq-name-list)
      ;; Always run COMPUTE-NET-ERROR (it sets the deltas), then fold its verdict in.
      (setf all-tolerable
            (and (compute-net-error net (second pair)
                                    output-to-input-node-seq-name-list
                                    :error-tolerance error-tolerance)
                 all-tolerable))
      (compute-weight-error-deriv net output-to-input-node-seq-name-list))
    (change-weights net output-to-input-node-seq-name-list)
    all-tolerable))


;;; LEARN-TILL-TOLERABLE -- train until all patterns are within tolerance.
;;; Calls RESET-NET, prints a newline, then runs up to MAX-COUNT epochs of
;;; LEARN-PATTERNS, printing each 0-based epoch index to *STANDARD-OUTPUT*.
;;; Returns the number of epochs run (1-based) when an epoch reports every
;;; error within ERROR-TOLERANCE, or NIL if MAX-COUNT epochs pass without that.
(defmethod learn-till-tolerable ((net bp-net-class) input-target-pattern-list
                                 &optional
                                 (output-to-input-node-seq-name-list
                                  (get-node-seq-names net))
                                 &key
                                 (error-tolerance 0.1)
                                 (max-count 500))
  (reset-net net)
  (format t "~%")
  (loop for epoch from 0 below max-count
        do (format t "~a " epoch)
        when (learn-patterns net input-target-pattern-list
                             output-to-input-node-seq-name-list
                             :error-tolerance error-tolerance)
          return (1+ epoch)))


;;; PRINT-NET -- dump the net's layers to *STANDARD-OUTPUT*.
;;; NODE-SEQS of the net is walked as a name/node-list property list.  Each
;;; layer name is printed on its own line, followed by:
;;;   - with WITH-CONNECTIONS: every node with its activation, input, output,
;;;     error and delta, plus weight, source node, weight error derivative
;;;     and weight delta of each incoming connection;
;;;   - otherwise: the nodes inside parentheses, NUMBER-OF-NODES-PER-LINE
;;;     per line.
;;; A trailing name with no node list prints "<no entry>".
;;; NUMBER-OF-CONNECTIONS-PER-LINE is accepted but ignored.  Returns NIL.
(defmethod print-net ((net bp-net-class)
		      &key
                      (number-of-nodes-per-line 2)
		      (with-connections nil)
		      (number-of-connections-per-line 1))
  (declare (ignore number-of-connections-per-line))
  (flet ((print-nodes-verbose (nodes)
           (dolist (node nodes)
             (format t "    * ~A <~A>~%      <~A> <~A>~%      <~A> <~A>"
                     node (activation node)
                     (input node) (output node)
                     (node-error node) (node-delta node))
             (dolist (conn (in-connections node))
               (format t "~%           (~A ~A~%            ~A ~A) "
                       (weight conn) (from-node conn)
                       (weight-error-derivative conn) (weight-delta conn)))
             (format t "~%~%")))
         (print-nodes-compact (nodes)
           (let ((column 0))
             (format t "   (")
             (dolist (node nodes)
               ;; wrap before the node that would exceed the per-line limit
               (when (> (incf column) number-of-nodes-per-line)
                 (format t "~%    ")
                 (setf column 1))
               (format t "~A " node))
             (format t ")~%~%"))))
    (loop for tail on (node-seqs net) by #'cddr
          do (format t "~A~%" (first tail))
             (if (rest tail)
                 (if with-connections
                     (print-nodes-verbose (second tail))
                     (print-nodes-compact (second tail)))
                 (format t "<no entry>")))))

; ******************************************************************************
