(in-package 'obvius)

(export '(make-diagonal-matrix make-identity-matrix))

;; added optional arguments for convenience
(defun weibull (x &optional (alpha 1.0) (beta 1.0) (gamma 0.0))
  (- 1.0 (* (- 1.0 gamma) (exp (- (expt (/ x alpha) beta))))))
(defun inverse-weibull (w &optional (alpha 1.0) (beta 1.0) (gamma 0.0))
  (* alpha (expt (- (log (/ (- 1.0 w) (- 1.0 gamma)))) (/ 1.0 beta))))
(defun detection-weibull (x &optional (alpha 1.0) (beta 1.0))
  (- 1.0 (exp (- (expt (/ x alpha) beta)))))
(defun inverse-detection-weibull (w &optional (alpha 1.0) (beta 1.0))
  (* alpha (expt (- (log (- 1.0 w))) (/ 1.0 beta))))
(defun 2afc-weibull (x &optional  (alpha 1.0) (beta 1.0))
  (- 1.0 (* 0.5 (exp (- (expt (/ x alpha) beta))))))
(defun inverse-2afc-weibull (w &optional  (alpha 1.0) (beta 1.0))
  (* alpha (expt (- (log (/ (- 1.0 w) 0.5))) (/ 1.0 beta))))
(defun detection-log-normal (x &optional (half 1.0) (slope 1.0))
  (cumulative-normal (* slope (log (/ x half) 2.0))))
(defun inverse-detection-log-normal (x &optional (half 1.0) (slope 1.0))
  (* half (expt 2.0 (/ (inverse-cumulative-normal x) slope))))


;; allow keyword to control allocation of copy.
(defmethod copy ((arr array) &key (static (allocated-array-p arr))
		 ((:-> result) (similar arr :static static)))
  (check-size arr result)
  (cond ((lucid-float-arrays-p arr result)
	 (internal-copy-array arr result (total-size arr)))
	((lucid-8bit-arrays-p arr result)
	 (memcpy result arr (total-size arr)))
	((lucid-fixnum-arrays-p arr result)
	 (internal32-copy-array arr result (total-size arr)))
	((lucid-1bit-arrays-p arr result)
	 (bit-not arr result)
	 (bit-not result result))	;faster than doing bit-and and bit-ior
	(t
	 (with-displaced-vectors ((displaced-arr arr)
				  (displaced-result result))
	   (loop for i from 0 below (total-size arr)
		 with elt-type = (array-element-type result) do
		 (setf (aref displaced-result i) (coerce (aref displaced-arr i) elt-type))))))
  result)


(defun make-identity-matrix (size &key ((:-> res)
				   (make-array (list size size) 
					       :element-type 'single-float)))
  (declare-matrices (res) ())
  (checktype-matrices (res))
  (zero! res)
  (dotimes (i (min (row-dim res) (col-dim res)))
    (declare (fixnum i))
    (setf (aref res i i) 1.0))
  res)

;; made these methods consistent with new names
(defmethod make-diagonal-matrix ((diagonals list)
			&key ((:-> result) (make-array (list (length diagonals)
							     (length diagonals))
						       :element-type 'single-float)))
  (fill! result 0.0)
  (dotimes (index (length diagonals))
    (setf (aref result index index) (float (elt diagonals index))))
  result)

(defmethod make-diagonal-matrix ((diagonals array)
			    &key ((:-> result) (make-array (list (total-size diagonals)
								 (total-size diagonals))
							   :element-type 'single-float)))
  (fill! result 0.0)
  (let ((vec (vectorize diagonals)))
    (dotimes (index (length vec))
      (setf (aref result index index) (float (aref vec index)))))
  result)

;; uses make-diagonal-matrix
(defmethod matrix-inverse ((matrix array)
			   &key dimension-limit condition-number-limit singular-value-limit
			   suppress-warning
			   ((:-> result) (make-array (reverse (dimensions matrix))
						     :element-type (array-element-type matrix))))
  (declare-matrices (matrix result) nil)
  (unless (>= (row-dim matrix) (col-dim matrix))
    (error "Left pseudo-inverse doesn't make sense for short, fat matrices"))
  (with-static-arrays (s u v)
    (multiple-value-setq (s u v) (svd matrix))
    (with-static-arrays
	((diagonal (similar matrix :dimensions (list (col-dim v) (col-dim u)) :initial-element 0.0))
	 (tmp-arr (similar matrix :dimensions (list (col-dim v) (row-dim u)))))
    
      ;; throw away smallest singular values
      (when dimension-limit
	(decf dimension-limit)		; Decrement since the user specifies dimensions from 1...n
	(dotimes (index (length s))
	  (when (> index dimension-limit)
	    (setf (aref s index) 0.0))))
      (when condition-number-limit
	(dotimes (index (length s))
	  (when (> (sqrt (/ (aref s 0) (aref s index))) condition-number-limit)
	    (setf (aref s index) 0.0))))
      (when singular-value-limit
	(dotimes (index (length s))
	  (when (< (aref s index) singular-value-limit)
	    (setf (aref s index) 0.0))))

      ;; Compute the inverse
      (div 1.0 s :zero-value 0.0 :suppress-warning suppress-warning :-> s)
      (make-diagonal-matrix s :-> diagonal)
      (matrix-mul v (matrix-mul-transpose diagonal u :-> tmp-arr) :-> result))))

;; uses make-diagonal-matrix
(defmethod quadratic-decomposition ((matrix array) &key ((:-> result) (similar matrix)))
  (declare-matrices (matrix) nil)
  (multiple-value-bind (s u v) (svd matrix)
    (declare-matrices (u v) (s))
    (unless (almost-equal u v)
      (error "Quadratic decomposition is impossible"))
    (square-root s :-> s)
    (matrix-mul v (make-diagonal-matrix s) :-> result)))

;; uses make-diagonal-matrix
(defun principal-components (matrix &key (dimension (min (row-dim matrix) (col-dim matrix)))
					scale)
  (with-svd (s u v) matrix
    (when (> dimension (length s))
      (error "Too many principal components requested"))

    (let ((principal-components (similar matrix :dimensions (list dimension (col-dim matrix))))
	  (principal-values (similar matrix :dimensions (list dimension (row-dim matrix))))
	  (singular-values (similar matrix :dimensions dimension))
	  (diagonal (when scale (make-diagonal-matrix (vectorize s :size dimension)))))

      ;; Calculate the principal components
      (crop (matrix-transpose v) :y-dim dimension :-> principal-components)
      (when scale (with-static-arrays ((product (matrix-mul diagonal principal-components)))
		    (copy product :-> principal-components)))

      ;; Calculate the principal values
      (crop u :y-dim dimension :-> principal-values)
      (when scale
	(with-static-arrays ((product (matrix-mul principal-values diagonal)))
	  (copy product :-> principal-values)))

      ;; Calculate the shortened singular values
      (crop s :x-dim dimension :-> singular-values)
      
      (values principal-components principal-values singular-values))))

;; uses make-diagonal-matrix
(defmethod diagonal-p ((matrix array) &key (tolerance *tolerance*))
  (with-static-arrays ((diagonal (diagonal matrix))
		       (diagonal-matrix (make-diagonal-matrix diagonal)))
    (when (almost-equal diagonal-matrix matrix :tolerance tolerance)
      matrix)))

;; uses make-diagonal-matrix
(defmethod regress ((observed array) (predictor array) &key diagonal affine transpose transform)
  (cond ((and affine diagonal) (error "Diagonal affine regression not implemented"))
	(transform (with-static-arrays ((t-observed (matrix-mul observed transform))
					(t-predictor (matrix-mul predictor transform))
					(inverse (matrix-inverse transform)))
		     (matrix-mul transform
				 (matrix-mul (regress t-observed t-predictor :diagonal diagonal
						      :affine affine) inverse))))
	(transpose (with-static-arrays ((t-observed (matrix-transpose observed))
					(t-predictor (matrix-transpose predictor)))
		     (regress t-observed t-predictor :diagonal diagonal :affine affine)))
		     
	(affine (let* ((dimensions (list (row-dim predictor) (1+ (col-dim predictor))))
		       (affine-predictor (similar predictor :dimensions dimensions)))
		  (fill! affine-predictor 1)
		  (paste predictor affine-predictor :-> affine-predictor)
		  (regress observed affine-predictor)))
	(diagonal (make-diagonal-matrix (loop for observed-col in (cols observed)
					      for predictor-col in (cols predictor)
					      collect (regress observed-col predictor-col))))
	(t (matrix-mul (matrix-inverse predictor) observed))))

;; replaced in array-ops to make compatible with standard args.
(defun displaced-matrix (arr &key (y-offset 0) (x-offset 0)
			     (y-size (y-dim arr)) (x-size (x-dim arr))
			     (y y-offset) (x x-offset) (y-dim y-size) (x-dim x-size))
  (unless (< (rank arr) 3)
    (error "Displaced-matrix implemented only for vectors or two-dimensional arrays."))
  (setq y (max (round y) 0)
	x (max (round x) 0)
	x-dim (min (round x-dim) (- (x-dim arr) x)) ;clip to image boundary
	y-dim (min (round y-dim) (- (y-dim arr) y)))
  (make-array (list y-dim x-dim) :displaced-to arr :element-type (array-element-type arr)
	      :displaced-index-offset (+ x (* y (x-dim arr)))))

;; Add -> argument, maintain name argument for back-compatibility
;; Also, add methods on arrays and lists. EJC 8.29.92
(defmethod make-histogram ((im image) 
			   &key
			   (range (list (minimum im) (maximum im)))
			   (binsize nil binsize-specified-p)
			   (bincenter (mean im))
			   (size (get-default (find-class 'discrete-function) :size)
				 size-specified-p)
			   (name (format nil "Histogram of ~A" (name im)))
			   (-> name))
  (let ((interval (- (apply '- range)))
	data origin)
    (cond ((and binsize-specified-p size-specified-p)
	   (error "Can't specify both binsize and size of histogram"))
	  (binsize-specified-p
	   (setq size (+ (/-0 interval binsize) 2)))
	  (t (setq binsize (if (zerop interval) 1.0 (/ interval (- size 2))))))
    ;; Origin is bincenter minus a multiple of binsize such that
    ;; origin is <= minimum of image.
    (setq origin (- bincenter (* binsize (round (- bincenter (car range)) binsize))))
    (setq data (compute-histogram im origin binsize (floor size)))
    (make-instance 'histogram 
		   :data data
		   :size (total-size data)
		   :origin origin 
		   :increment binsize
		   :image im
		   :name ->)))

(defmethod make-histogram ((list list) &rest args)
  (with-local-viewables ((im (make-image (make-matrix list))))
    (apply 'make-histogram im args)))

(defmethod make-histogram ((arr array) &rest args)
  (with-local-viewables ((im (make-image arr)))
    (apply 'make-histogram im args)))
