;;
;;; Equality solver for Zeno -- use incremental Gaussian elimination
;;

(in-package "ZENO")
(proclaim '(optimize (speed 3) (safety 2) (space 0) (compilation-speed 0)))

(defclass EQN-SOLVER ()
  ;; State of the incremental Gaussian-elimination equality solver.
  ((forms
    :accessor eqn-forms
    :initform nil)          ; list of ROWs in (parametric) solved form
   (consistent?
    :accessor eqn-consistent?
    :initform t)            ; set to NIL once an inconsistency is detected
   (implicits
    :accessor eqn-implicits
    :initform nil)          ; pending implicit equalities as (var . value)
   (csolver
    :accessor eqn-csolver
    :initarg :csolver))     ; supplies the simplex (CS-INEQ) and nonlinear (CS-NLIN) solvers
  (:documentation "Incremental Gaussian-elimination solver for linear equalities."))

(defmethod spawn ((self eqn-solver) old-eqn-solver)
  ;; Make SELF a duplicate of OLD-EQN-SOLVER's equation state, without
  ;; touching SELF's CSOLVER link.
  ;; - If OLD-EQN-SOLVER is inconsistent, only the consistency flag is
  ;;   copied (as NIL); FORMS and IMPLICITS of SELF are left as they were.
  ;; - Otherwise every row is copied with COPY-ROW, the IMPLICITS list is
  ;;   shared (not copied), and SELF is marked consistent.
  ;; Returns no values.
  (if (not (eqn-consistent? old-eqn-solver))
      (setf (slot-value self 'consistent?) nil)
      (setf (slot-value self 'forms)
            (mapcar #'copy-row (eqn-forms old-eqn-solver))
            (slot-value self 'implicits)
            (eqn-implicits old-eqn-solver)
            (slot-value self 'consistent?)
            t))
  (values))
      
(defmethod EQN-SIMPLEX ((self eqn-solver))
  ;; The simplex (inequality) solver, obtained from the CSOLVER via CS-INEQ.
  (cs-ineq (slot-value self 'csolver)))

(defmethod EQN-NONLIN ((self eqn-solver))
  ;; The nonlinear solver, obtained from the CSOLVER via CS-NLIN.
  (cs-nlin (slot-value self 'csolver)))

(defstruct (ROW                         ;a row in parametric solved form,
            (:type vector)
            ;; COPY-ROW is defined explicitly (it copies VARS with
            ;; COPY-ALIST); suppress the DEFSTRUCT-generated shallow copier
            ;; so the two definitions do not clash.
            (:copier nil))
  constant                              ;right-hand side constant
  eps                                   ;epsilon (only used for simplex)
  vars)                                 ;list of (coeff . var) terms

(defun COPY-ROW (row)
  ;; Return a new ROW with ROW's constant and epsilon and a COPY-ALIST of
  ;; its (coeff . var) terms, so the term conses are not shared.
  (let ((terms (copy-alist (row-vars row))))
    (make-row :vars terms
              :constant (row-constant row)
              :eps (row-eps row))))

(defun PRINT-ROW (self stream depth)
  ;; Print ROW SELF to STREAM as an equation wrapped in curly braces.
  ;; DEPTH is accepted for signature compatibility and ignored.
  (declare (ignore depth))
  (write-string "{" stream)
  (print-row-internal self stream '=)
  (write-string "}" stream))

(defun princ-var-name (thing stream)
  ;; Print variable name THING to STREAM.
  ;; Numeric names get a prefix: negative numbers print as "s<|n|>",
  ;; all other numbers as "v<n>".  Non-numbers are printed with PRINC.
  ;; Returns the last object passed to PRINC.
  (if (not (numberp thing))
      (princ thing stream)
      (let ((negative (minusp thing)))
        (princ (if negative "s" "v") stream)
        (princ (if negative (abs thing) thing) stream))))

(defun PRINT-ROW-INTERNAL (self stream inequality)
  ;; Print ROW SELF to STREAM as "<terms> <INEQUALITY> <constant>".
  ;; Terms are joined with " + ".  A coefficient that is == 1 is omitted;
  ;; other coefficients, the constant and the epsilon part are rounded to
  ;; three decimals.  When |eps| exceeds *TOLERANCE*, " + <eps> EPS" is
  ;; appended.  A row with no variables prints its left-hand side as 0.
  ;; Returns NIL.
  (labels ((round3 (x)
             ;; Round X to three decimals for display.
             (/ (round (* x 1000)) 1000.0))
           (princ-var (v stream)
             (cond ((== 1 (car v))
                    (princ-var-name (cdr v) stream))
                   (t
                    (format stream "~f" (round3 (car v)))
                    (princ " " stream)
                    (princ-var-name (cdr v) stream)))))
    (let ((vars (row-vars self))
          (const (row-constant self)))
      (if (null vars)
          ;; An empty left-hand side (e.g. a row reduced to 0 = c) would
          ;; otherwise hand NIL to PRINC-VAR and fail; show it as 0.
          (princ 0 stream)
          (loop for (v . more) on vars
                do (princ-var v stream)
                   (when more
                     (princ " + " stream))))
      (format stream " ~a ~f" inequality (round3 const))
      (when (> (abs (row-eps self)) *tolerance*)
        (format stream " + ~a EPS" (round3 (row-eps self))))
      nil)))

(defun VAR=ZERO? (var)
  ;; True when the coefficient of the (coeff . var) term VAR is zero
  ;; according to the project's == comparison (presumably tolerance-based).
  (let ((coeff (first var)))
    (== 0 coeff)))

(defun CONST=0? (row)
  ;; True when both the constant and the epsilon part of ROW are == 0.
  (and (== 0 (row-constant row))
       (== 0 (row-eps row))))

(defun CONST<0? (row)
  ;; True when ROW's constant, read lexicographically as
  ;; constant + eps*EPS, is negative.  The epsilon part decides only when
  ;; the constant is exactly zero (this uses =, not the tolerant ==).
  (let ((c (row-constant row)))
    (if (zerop c)
        (minusp (row-eps row))
        (minusp c))))

(defun SCALE-ROW (row scale)
  ;; Destructively scale ROW by SCALE.  Each (coeff . var) term is replaced
  ;; by the result of SCALE-SYM on it, and the constant and epsilon are
  ;; multiplied by SCALE.  Returns ROW.
  (setf (row-vars row)
        (loop for term in (row-vars row)
              collect (scale-sym term scale)))
  (setf (row-constant row) (* (row-constant row) scale)
        (row-eps row) (* (row-eps row) scale))
  row)

(defmethod RESET ((self eqn-solver))
  ;; Return the solver to its empty, consistent initial state:
  ;; no rows, no pending implicit equalities.  Returns no values.
  (setf (eqn-consistent? self) t
        (eqn-forms self) nil
        (eqn-implicits self) nil)
  (values))

(defun canonicalize-vars (vars)
  ;; VARS is a list of (coeff . var) terms representing a sum.  Return a
  ;; fresh list with one term per distinct var (matched with VASSOC) whose
  ;; coefficient is the sum of that var's coefficients.  Terms whose summed
  ;; coefficient is == 0 are dropped.  Survivors keep the order in which
  ;; their var first appears in VARS.  The input conses are not modified.
  (declare (optimize (speed 3) (safety 1)))
  (let ((merged '()))
    (dolist (term vars)
      (let ((hit (vassoc (cdr term) merged)))
        (if hit
            (incf (car hit) (car term))
            (push (cons (car term) (cdr term)) merged))))
    ;; MERGED is in reverse first-appearance order; pushing the non-zero
    ;; survivors onto a new list restores the original order.
    (let ((kept '()))
      (dolist (term merged kept)
        (unless (== 0 (car term))
          (push term kept))))))

(defmethod ADD-RAW-EQUATION ((self eqn-solver) form)
  ;; Add FORM = (:equation <constant> (coeff . var)*), read as
  ;; sum(coeff * var) = constant.  Like-variable terms are merged with
  ;; CANONICALIZE-VARS, and the resulting row (eps 0) is handed to
  ;; PROCESS-NEW-ROW, whose result is returned.  An already-inconsistent
  ;; solver ignores FORM and returns :INCONSISTENT.
  (if (not (eqn-consistent? self))
      :inconsistent
      (let ((constant (second form))
            (terms (cddr form)))
        (process-new-row self
                         (make-row :constant constant
                                   :eps 0
                                   :vars (canonicalize-vars terms))))))

(defmethod PROCESS-NEW-ROW ((self eqn-solver) row)
  ;; Reduce ROW against the existing subjects, then classify it:
  ;; - variables remain      -> delegate to ADD-NEW-ROW and return its result;
  ;; - no variables, 0 = 0   -> :REDUNDANT;
  ;; - no variables, 0 = c   -> mark the solver inconsistent, :INCONSISTENT.
  (substitute-out-subjects self row)
  (cond ((row-vars row)
         (add-new-row self row))
        ((const=0? row)
         :redundant)
        (t
         (setf (eqn-consistent? self) nil)
         :inconsistent)))

(defmethod ADD-NEW-ROW ((self eqn-solver) row)
  ;; Add ROW (already reduced against existing subjects) via ADD-NEW-ROW-1.
  ;; If that reports :IMPLICIT, drain every pending (var . value) pair from
  ;; EQN-IMPLICITS and forward it:
  ;; - when the simplex solver has seen the var (HAVE-SEEN-VAR-P), the
  ;;   equation (1 * var = value) goes to HANDLE-EQUALITY on the simplex;
  ;; - the pair is also passed to the nonlinear solver via WAKEUP (in the
  ;;   simplex case, only if HANDLE-EQUALITY did not report inconsistency).
  ;; Returns ADD-NEW-ROW-1's result, overridden with :INCONSISTENT if any
  ;; forwarded equality is reported inconsistent.
  (let ((result (add-new-row-1 self row)))
    (when (eq result :implicit)
      (let ((simplex (eqn-simplex self))
            (nonlin (eqn-nonlin self)))
        (loop for val-pair = (pop (eqn-implicits self))
              while val-pair
              do (let ((var (car val-pair))
                       (value (cdr val-pair)))
                   (count-stat .gauss->nonlin.)
                   (cond ((have-seen-var-p simplex var)
                          (count-stat .gauss->simplex.)
                          ;; Build a fresh equation each time instead of
                          ;; mutating one shared template, so nothing that
                          ;; EQUATION->ROW keeps from it can change later.
                          (when (or (eq :inconsistent
                                        (handle-equality
                                         simplex
                                         (equation->row
                                          simplex
                                          (list :inequality '= value
                                                (cons 1 var)))
                                         t))
                                    (eq :inconsistent
                                        (wakeup nonlin var value)))
                            (setf result :inconsistent)))
                         ((eq :inconsistent (wakeup nonlin var value))
                          (setf result :inconsistent)))))))
    result))

(defmethod ADD-NEW-ROW-1 ((self eqn-solver) row)
  ;; Choose a subject for ROW (which also normalizes it), push it onto
  ;; EQN-FORMS, and pivot it into the other rows with GAUSSIAN-PIVOT.
  ;; A single-variable row is an implicit equality: (var . constant/coeff)
  ;; is pushed onto EQN-IMPLICITS first, and the result is :IMPLICIT unless
  ;; the pivot reports :INCONSISTENT.  For a multi-variable row the pivot's
  ;; result is returned unchanged.
  (choose-new-subject self row)
  (push row (eqn-forms self))
  (let ((vars (row-vars row)))
    (cond ((null (cdr vars))
           ;; Implicit equality!  COERCE is the portable way to get a
           ;; double-float (DOUBLE-FLOAT is a type, not a standard function).
           (push (cons (cdar vars)
                       (/ (coerce (row-constant row) 'double-float)
                          (caar vars)))
                 (eqn-implicits self))
           (if (eq :inconsistent (gaussian-pivot self row))
               :inconsistent
               :implicit))
          (t
           (values (gaussian-pivot self row))))))

(defmethod CHOOSE-NEW-SUBJECT ((self eqn-solver) row)
  ;; Pick ROW's subject term and return it.  The chosen (coeff . var) term
  ;; is moved with MOVE-TO-FRONT, then ROW is scaled by 1/coeff of
  ;; (ROW-SUBJECT ROW) (presumably that front term, making its coefficient 1).
  ;; Preference, judged against the simplex tableau:
  ;;   1. the first var that is neither basic nor nonbasic there;
  ;;   2. otherwise the last basic var in ROW;
  ;;   3. otherwise the last nonbasic var in ROW.
  (let* ((tableau (eqn-simplex self))
         (last-basic nil)
         (last-nonbasic nil)
         (unknown (dolist (entry (row-vars row) nil)
                    (let ((var (cdr entry)))
                      (cond ((basic-var-p tableau var)
                             (setf last-basic entry))
                            ((nonbasic-var-p tableau var)
                             (setf last-nonbasic entry))
                            (t
                             (return entry))))))
         (subject (or unknown last-basic last-nonbasic)))
    (setf (row-vars row) (move-to-front subject (row-vars row)))
    (scale-row row (/ 1.0d0 (car (row-subject row))))
    subject))

(defun ROW-SUBSTITUTE-OUT (from to)
  ;; Eliminate FROM's subject (ROW-SUBJECT) from row TO, destructively
  ;; updating TO's vars, constant and epsilon.  With k the coefficient of the
  ;; subject's var in TO, TO becomes TO - k*FROM when the subject's
  ;; coefficient is positive, and TO + k*FROM otherwise.  The scaled FROM
  ;; terms are appended to TO's terms and merged by CANONICALIZE-VARS.
  ;; NOTE(review): the subject cancels exactly only when its coefficient in
  ;; FROM is +1 or -1 (CHOOSE-NEW-SUBJECT scales it to 1); confirm for other
  ;; callers such as the tableau rows in ELIMINATE-INTERNALS.
  ;; Returns T if a substitution was made; NIL if the var is absent from TO
  ;; or its coefficient there is == 0.
  (declare (optimize (speed 3) (safety 1)))
  (count-stat .ero.)
  (let* ((subj (row-subject from))
	 (vars (row-vars to))
	 (dest (vassoc (cdr subj) vars))
	 (coeff (car dest))
	 (newvars nil)
	 )
    (unless (or (null dest) (== 0 coeff)) ;; nothing to substitute
      (count-stat .valid-ero.)
      ;; Negate k for a positive subject coefficient so the subject cancels.
      (if (plusp (car subj)) (setf coeff (* -1.0d0 coeff)))
      ;; Collect k * FROM's terms (PUSH leaves them in reverse order).
      (dolist (fromv (row-vars from))
	(push 
	 (cons (* coeff (car fromv)) (cdr fromv)) newvars))
      (setf (row-vars to) (canonicalize-vars (nconc (row-vars to) newvars)))
      (incf (row-constant to) (* coeff (row-constant from)))
      (incf (row-eps to) (* coeff (row-eps from)))
      (values t)
      )))

(defmethod SUBSTITUTE-OUT-SUBJECTS ((self eqn-solver) row)
  ;; Reduce ROW against every other row in EQN-FORMS: whenever another
  ;; row's subject var occurs in ROW, substitute that row into ROW with
  ;; ROW-SUBSTITUTE-OUT.  ROW is modified in place and returned.
  ;; ROW's *current* vars are consulted on each step, because every
  ;; substitution replaces ROW's variable list with a new one.
  (dolist (from (eqn-forms self) row)
    (unless (eq from row)
      (when (vassoc (cdr (row-subject from)) (row-vars row))
        (row-substitute-out from row)))))

(defmethod GAUSSIAN-PIVOT ((self eqn-solver) row)
  ;; Eliminate ROW's subject var from every other row in EQN-FORMS.
  ;; - Rows left with no vars are removed (EQN-FORMS order is preserved).
  ;; - Rows left with exactly one var yield implicit equalities, pushed onto
  ;;   EQN-IMPLICITS as (var . constant/coeff).
  ;; If every var of ROW has been seen by the simplex solver
  ;; (HAVE-SEEN-VAR-P), a copy of ROW is also given to HANDLE-EQUALITY.
  ;; Returns :INCONSISTENT (and clears EQN-CONSISTENT?) if the simplex
  ;; solver or this solver is inconsistent afterwards.  Otherwise returns
  ;; :IMPLICIT when new implicit equalities were found (the keyword that
  ;; ADD-NEW-ROW tests for), else :OK.
  (let ((implicits? nil)
        (garbage nil)
        (pivot (cdr (row-subject row)))
        (simplex (eqn-simplex self)))
    (count-stat .gpivot.)
    (dolist (to (eqn-forms self))
      (when (and (not (eq to row))
                 (vassoc pivot (row-vars to))
                 (row-substitute-out row to))
        (let ((vars (row-vars to)))
          (cond ((null vars)
                 (push to garbage))
                ((null (cdr vars))      ;a new implicit equality!
                 (setf implicits? t)
                 (push (cons (cdar vars)
                             (/ (coerce (row-constant to) 'double-float)
                                (caar vars)))
                       (eqn-implicits self)))))))
    (when garbage
      (setf (eqn-forms self)
            (remove-if #'(lambda (r) (member r garbage :test #'eq))
                       (eqn-forms self))))
    (when (every #'(lambda (v) (have-seen-var-p simplex (cdr v)))
                 (row-vars row))
      (count-stat .gauss->simplex.)
      (handle-equality simplex (copy-row row) t))
    (cond ((and (ineq-consistent? simplex)
                (eqn-consistent? self))
           (if implicits? :implicit :ok))
          (t
           (setf (eqn-consistent? self) nil)
           :inconsistent))))

(defmethod GATHER-INTERNAL-VARS ((self eqn-solver))
  ;; Collect, without duplicates (EQL), every numerically named var that
  ;; occurs in any row of EQN-FORMS.
  ;; NOTE: assumes internal variables are exactly those named by numbers.
  (let ((found '()))
    (loop for row in (eqn-forms self)
          do (loop for (nil . name) in (row-vars row)
                   when (numberp name)
                     do (pushnew name found)))
    found))

(defmethod FIND-ROW-WITH-VAR ((self eqn-solver) name)
  ;; Return the first row in EQN-FORMS whose terms mention var NAME
  ;; (looked up with VASSOC), or NIL if there is none.
  (find-if #'(lambda (row) (vassoc name (row-vars row)))
           (eqn-forms self)))

(defmethod ELIMINATE-INTERNALS ((self eqn-solver))
  ;; Eliminate the internal (numerically named) vars collected by
  ;; GATHER-INTERNAL-VARS.  For each such var IV still present in some row R:
  ;; BASICFY-VAR-IN-ROW is applied to R and IV (presumably making IV R's
  ;; subject); R is then deleted from EQN-FORMS and substituted into the
  ;; remaining equation rows and the simplex tableau rows (INEQ-TABLEAU).
  ;; Finally RESTORE-TABLEAU is called on the simplex and its value returned.
  (let ((simplex (eqn-simplex self)))
    (dolist (iv (gather-internal-vars self))
      ;; The var list was gathered up front, and an earlier elimination can
      ;; delete the only row that mentioned IV.  Skip IV in that case rather
      ;; than pass NIL to BASICFY-VAR-IN-ROW.
      (let ((row (find-row-with-var self iv)))
        (when row
          (basicfy-var-in-row row iv)
          (setf (eqn-forms self) (delete row (eqn-forms self)))
          (dolist (other-linear (eqn-forms self))
            (row-substitute-out row other-linear))
          (dolist (other-ineq (ineq-tableau simplex))
            (row-substitute-out row other-ineq)))))
    (restore-tableau simplex)))

(defmethod SHOW-CONSTRAINTS ((self eqn-solver))
  ;; Debug aid: print each row of EQN-FORMS that still has variables as an
  ;; equation on *STANDARD-OUTPUT*, one per line.  Returns NIL.
  (loop for row in (eqn-forms self)
        when (row-vars row)
          do (print-row-internal row *standard-output* '=)
             (terpri)))

