;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;;  File: row-ops.lisp
;;;  Author: Chichilnisky
;;;  Description: Functions for row operations on matrices
;;;  Creation Date: March 1992
;;;  ----------------------------------------------------------------
;;;    Object-Based Vision and Image Understanding System (OBVIUS),
;;;      Copyright 1988, Vision Science Group,  Media Laboratory,  
;;;              Massachusetts Institute of Technology.
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(in-package 'obvius)
(export '(row rows displaced-rows col cols
	  normalize-rows normalize-cols
	  add-rows sub-rows mul-rows div-rows add-cols sub-cols
	  paste-rows paste-cols append-rows append-cols shuffle-rows swap-rows sort-rows
	  sum-rows mean-rows covariance-rows sum-cols mean-cols covariance-cols))

;; Functions for row operations on matrices
(defun displaced-rows (matrix &optional (start 0) (end (row-dim matrix)))
  "Return a list of the rows of MATRIX numbered START (inclusive) up to END
(exclusive), each obtained with DISPLACED-ROW, in increasing row order.
NOTE(review): DISPLACED-ROW presumably returns vectors that share storage with
MATRIX -- confirm before relying on writes propagating."
  (let ((collected '()))
    (do ((r start (1+ r)))
        ((>= r end) (nreverse collected))
      (push (displaced-row r matrix) collected))))

(defun row (row matrix)
  "Return a COPY of row number ROW of MATRIX, rather than the displaced row
itself."
  (let ((shared (displaced-row row matrix)))
    (copy shared)))

(defun rows (matrix &optional (start 0) (end (row-dim matrix)))
  "Return a list of rows START..END-1 of MATRIX.  The rows are displaced into
a COPY of MATRIX rather than into MATRIX itself."
  (let ((private-copy (copy matrix)))
    (displaced-rows private-copy start end)))

(defun col (col matrix)
  "Return column number COL of MATRIX, computed as row COL of the transpose
(via ROW, so the result is a copy)."
  (let ((transposed (matrix-transpose matrix)))
    (row col transposed)))

(defun cols (matrix &optional (start 0) (end (col-dim matrix)))
  "Return a list of columns START..END-1 of MATRIX, obtained as ROWS of the
transpose of MATRIX."
  (let ((transposed (matrix-transpose matrix)))
    (rows transposed start end)))

#|
(defun row (row matrix)
  (make-array (col-dim matrix)
	      :element-type (array-element-type matrix)
	      :displaced-to matrix
	      :displaced-index-offset (* row (col-dim matrix))))

(defun rows (matrix)
  (loop for row from 0 below (row-dim matrix)
	collect (row row matrix)))
|#

;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

;;; Row and column operations on matrices.

(defmethod check-size-rows ((a array) &rest a-list)
  "Signal an error unless every consecutive pair among A and A-LIST has EQUAL
COL-DIM (i.e. rows of the same length).  Recurses through the generic
function, so every argument must be an array.  On success returns the LAST
argument (A itself when A-LIST is empty)."
  (if (null a-list)
      a
      (let ((next (first a-list)))
        (unless (equal (col-dim a) (col-dim next))
          (error "Rows have different dimensions." ))
        (apply 'check-size-rows next (rest a-list)))))

(defmethod check-size-cols ((a array) &rest a-list)
  "Signal an error unless every consecutive pair among A and A-LIST has EQUAL
ROW-DIM (i.e. columns of the same length).  Recurses through the generic
function, so every argument must be an array.  On success returns the LAST
argument (A itself when A-LIST is empty)."
  (if (null a-list)
      a
      (let ((next (first a-list)))
        (unless (equal (row-dim a) (row-dim next))
          (error "Columns have different dimensions." ))
        (apply 'check-size-cols next (rest a-list)))))

(defmethod sub-rows ((arr array) (vector vector) &key ((:-> result) (similar arr)))
  "Subtract VECTOR from each row of ARR, writing into RESULT (which may be ARR).
Returns RESULT.  The work is done by INTERNAL-ROW-SUB; presumably
result[i,j] = arr[i,j] - vector[j] -- confirm against that routine."
  (check-size arr result)
  (check-size-rows arr vector)
  (checktype-matrices (arr result vector))
  (let ((nrows (row-dim arr))
        (ncols (col-dim arr)))
    (internal-row-sub arr vector result nrows ncols))
  result)

(defmethod sub-cols ((arr array) (arr2 array) &key ((:-> result) (similar arr)))
  "Subtract ARR2 (flattened with VECTORIZE) from each column of ARR, writing
into RESULT.  Returns RESULT.  Implemented by transposing ARR into a static
scratch array, applying SUB-ROWS there in place, and transposing back."
  (check-size arr result)
  (with-static-arrays ((scratch (matrix-transpose arr)))
    (let ((column (vectorize arr2)))
      (sub-rows scratch column :-> scratch))
    (matrix-transpose scratch :-> result))
  result)

(defmethod add-cols ((arr array) (arr2 array) &key ((:-> result) (similar arr)))
  "Add ARR2 (flattened with VECTORIZE) to each column of ARR, writing into
RESULT.  Returns RESULT.  Implemented by transposing ARR into a static
scratch array, applying ADD-ROWS there in place, and transposing back."
  (check-size arr result)
  (with-static-arrays ((scratch (matrix-transpose arr)))
    (let ((column (vectorize arr2)))
      (add-rows scratch column :-> scratch))
    (matrix-transpose scratch :-> result))
  result)

(defmethod sub-rows ((vector vector) (array array) &key ((:-> result) (similar array)))
  "VECTOR minus each row of ARRAY, written into RESULT: computed as the
negation of (each row of ARRAY minus VECTOR)."
  (let ((difference (sub-rows array vector :-> result)))
    (negate difference :-> result)))

(defmethod sub-rows ((vector1 vector) (vector2 vector) &key ((:-> result) (similar vector1)))
  "With two vectors there is a single row on each side, so this is plain
elementwise subtraction VECTOR1 - VECTOR2, written into RESULT."
  ;; Previously computed (sub vector1 vector1), which is always zero and
  ;; ignored VECTOR2 entirely.
  (sub vector1 vector2 :-> result))

(defmethod add-rows ((arr array) (vector vector) &key ((:-> result) (similar arr)))
  "Add VECTOR to each row of ARR, writing into RESULT (which may be ARR).
Returns RESULT.  The work is done by INTERNAL-ROW-ADD; presumably
result[i,j] = arr[i,j] + vector[j] -- confirm against that routine."
  (check-size arr result)
  (check-size-rows arr vector)
  (checktype-matrices (arr result vector))
  (let ((nrows (row-dim arr))
        (ncols (col-dim arr)))
    (internal-row-add arr vector result nrows ncols))
  result)

(defmethod add-rows ((vector vector) (array array) &key ((:-> result) (similar array)))
  "Addition is commutative, so delegate to the (ARRAY VECTOR) method with the
operands swapped."
  (let ((destination result))
    (add-rows array vector :-> destination)))

(defmethod add-rows ((vector1 vector) (vector2 vector) &key ((:-> result) (similar vector1)))
  "With two vectors there is a single row on each side, so this is plain
elementwise addition VECTOR1 + VECTOR2, written into RESULT."
  ;; Previously computed (add vector1 vector1), i.e. doubled VECTOR1 and
  ;; ignored VECTOR2.
  (add vector1 vector2 :-> result))


(defmethod div-rows ((arr array) (vector vector) &key ((:-> result) (similar arr)))
  "Divide each row of ARR elementwise by VECTOR, writing row I into row I of
RESULT.  Returns RESULT."
  (loop for i from 0 below (row-dim arr)
        do (div (displaced-row i arr) vector :-> (displaced-row i result)))
  result)

(defmethod div-rows ((vector vector) (arr array) &key ((:-> result) (similar arr)))
  "Divide VECTOR elementwise by each row of ARR, writing the quotient for row
I into row I of RESULT.  Returns RESULT."
  (loop for i from 0 below (row-dim arr)
        do (div vector (displaced-row i arr) :-> (displaced-row i result)))
  result)

(defmethod div-rows ((vector1 vector) (vector2 vector) &key ((:-> result) (similar vector1)))
  "With two vectors there is only one row: plain elementwise division of
VECTOR1 by VECTOR2 into RESULT."
  (let ((numerator vector1)
        (denominator vector2))
    (div numerator denominator :-> result)))

(defmethod mul-rows ((arr array) (vector vector) &key ((:-> result) (similar arr)))
  "Multiply each row of ARR elementwise by VECTOR, writing row I into row I
of RESULT.  Returns RESULT."
  (do ((r 0 (1+ r))
       (n (row-dim arr)))
      ((>= r n) result)
    (mul (displaced-row r arr) vector :-> (displaced-row r result))))

(defmethod mul-rows ((vector vector) (arr array) &key ((:-> result) (similar arr)))
  "Multiply VECTOR elementwise by each row of ARR (VECTOR as the first
operand), writing row I into row I of RESULT.  Returns RESULT."
  (do ((r 0 (1+ r))
       (n (row-dim arr)))
      ((>= r n) result)
    (mul vector (displaced-row r arr) :-> (displaced-row r result))))

(defmethod mul-rows ((vector1 vector) (vector2 vector) &key ((:-> result) (similar vector1)))
  "With two vectors there is only one row: plain elementwise product of
VECTOR1 and VECTOR2 into RESULT."
  (let ((left vector1)
        (right vector2))
    (mul left right :-> result)))


(defun sum-rows (matrix &key ((:-> result) (similar matrix :dimensions (col-dim matrix))))
  "Add the rows of MATRIX together into RESULT, which by default has
dimensions (COL-DIM MATRIX).  Returns RESULT.  The summation itself is done
by INTERNAL-ROW-SUM."
  (checktype-matrices (matrix))
  (check-size-rows matrix result)
  (let ((nrows (row-dim matrix))
        (ncols (col-dim matrix)))
    (internal-row-sum matrix result nrows ncols))
  result)

(defun mean-rows (matrix &key ((:-> result) (similar matrix :dimensions (col-dim matrix))))
  "Average of the rows of MATRIX: SUM-ROWS into RESULT, then divide by the
number of rows.  Returns whatever the final DIV returns."
  (let ((row-count (row-dim matrix)))
    (sum-rows matrix :-> result)
    (div result row-count :-> result)))

(defun sum-cols (matrix &key ((:-> result) (similar matrix :dimensions (list (row-dim matrix) 1))))
  "Add the columns of MATRIX together into RESULT, which by default is a
(ROW-DIM x 1) array.  Returns RESULT.  The summation itself is done by
INTERNAL-COL-SUM."
  (checktype-matrices (matrix))
  (check-size-cols matrix result)
  (let ((nrows (row-dim matrix))
        (ncols (col-dim matrix)))
    (internal-col-sum matrix result nrows ncols))
  result)

(defmethod normalize-rows ((arr array) &key ((:-> result) (similar arr)))
  "Apply NORMALIZE independently to each row of ARR, writing into the
corresponding row of RESULT (which may be ARR).  Returns RESULT."
  (check-size arr result)
  (loop for r from 0 below (row-dim arr)
        do (let ((src (displaced-row r arr))
                 (dst (displaced-row r result)))
             (normalize src :-> dst)))
  result)

(defmethod normalize-cols ((arr array) &key ((:-> result) (similar arr)))
  "Apply NORMALIZE independently to each column of ARR, writing into RESULT.
Returns RESULT.  Works on transposed static copies: the rows of the
transpose of ARR are normalized into the transpose of RESULT (used only as a
correctly shaped buffer), which is then transposed back into RESULT."
  (check-size arr result)
  (with-static-arrays ((arr-transposed (matrix-transpose arr))
                       (result-transposed (matrix-transpose result)))
    (normalize-rows arr-transposed :-> result-transposed)
    (matrix-transpose result-transposed :-> result))
  result)

(defun mean-cols (matrix &key ((:-> result)
                               (similar matrix :dimensions (list (row-dim matrix) 1))))
  "Average of the columns of MATRIX: SUM-COLS into RESULT (by default a
(ROW-DIM x 1) array), then divide by the number of columns.  Returns whatever
the final DIV returns."
  (let ((col-count (col-dim matrix)))
    (sum-cols matrix :-> result)
    (div result col-count :-> result)))

(defun covariance-rows (matrix &key ((:-> result) (similar matrix :dimensions
                                                          (list (col-dim matrix) (col-dim matrix))))
                               sample)
  "Covariance of MATRIX treating each row as one observation: subtract the
mean row from every row, form (offset^T offset) with MATRIX-TRANSPOSE-MUL into
RESULT (COL-DIM x COL-DIM), then divide by the row count.  When SAMPLE is
true the divisor is (row count - 1), but never less than 1."
  (checktype-matrices (matrix))
  (let ((n (row-dim matrix)))
    (with-static-arrays ((centered (similar matrix :static t))
                         (row-mean (mean-rows matrix)))
      (sub-rows matrix row-mean :-> centered)
      (matrix-transpose-mul centered centered :-> result)
      (div result (if sample (max 1 (1- n)) n) :-> result))))

(defun covariance-cols (matrix &key ((:-> result)
                                    (similar matrix :dimensions
                                             (list (row-dim matrix) (row-dim matrix))))
                               sample)
  "Covariance of MATRIX treating each column as one observation: subtract the
mean column from every column, form (offset offset^T) with
MATRIX-MUL-TRANSPOSE into RESULT (ROW-DIM x ROW-DIM), then divide by the
column count.  When SAMPLE is true the divisor is (column count - 1), but
never less than 1."
  (checktype-matrices (matrix))
  (let ((n (col-dim matrix)))
    (with-static-arrays ((centered (similar matrix :static t))
                         (col-mean (mean-cols matrix)))
      (sub-cols matrix col-mean :-> centered)
      (matrix-mul-transpose centered centered :-> result)
      (div result (if sample (max 1 (1- n)) n) :-> result))))

(defun shuffle-rows (arr &key ((:-> result) (similar arr)))
  "Randomly permute the rows of ARR, writing into RESULT (pass ARR itself as
RESULT for an in-place shuffle).  ARR is left untouched when RESULT is a
different array.  Returns RESULT.
Multi-column case: a Fisher-Yates shuffle -- at step I, row I of RESULT is
swapped with a row chosen by RANDOM from the range [I, ROWS)."
  (check-size arr result)
  (unless (eq arr result)
    (copy arr :-> result))
  (if (= (col-dim arr) 1)
      ;; A single-column matrix is just a vector: shuffle its elements
      ;; directly.  This must operate on RESULT (which now holds ARR's data);
      ;; it previously shuffled ARR, returning an unshuffled copy and
      ;; mutating the caller's input.
      (with-displaced-vectors ((vec result))
        (shuffle vec :-> vec))
      (loop with rows = (row-dim result)
            with cols = (col-dim result)
            for i from 0 below rows
            ;; ANSI LOOP requires a positive BY step; count down with
            ;; DOWNFROM instead of the non-conforming FROM ROWS BY -1.
            for remaining downfrom rows
            do (internal-row-swap result i (+ i (random remaining)) cols)))
  result)

(defun sort-rows (arr predicate &key (key #'vector-length)
                     ((:-> result) (similar arr)))
  "Reorder the rows of ARR so that (FUNCALL KEY row) is ordered by PREDICATE,
writing the reordered rows into RESULT (which may be ARR).  Returns RESULT.
KEY defaults to VECTOR-LENGTH.  Uses SORT, so rows with equal keys may end up
in any relative order."
  (with-static-arrays ((snapshot (copy arr)))
    ;; Rows are displaced into a private snapshot, so writing into RESULT
    ;; (possibly eq to ARR) cannot clobber rows that have not been copied yet.
    (let ((sorted (sort (coerce (displaced-rows snapshot) 'vector)
                        predicate :key key)))
      ;; Use SORT's return value: for a vector argument the result is not
      ;; guaranteed to be identical to the sequence passed in.
      (loop for source-row across sorted
            for i from 0
            do (copy source-row :-> (displaced-row i result)))
      result)))

#|
(dimensions (setf mat (make-array '(30 3) :element-type 'single-float)))
(dimensions (randomize mat 100.0 :-> mat))
(image-from-array (make-matrix (mapcar 'vector-length (displaced-rows mat))))
(dimensions (sort-rows mat '> :-> mat))
(image-from-array (make-matrix (mapcar 'vector-length (displaced-rows (sort-rows mat '>)))))
|#

(defun swap-rows (arr row-1 row-2 &key ((:-> result) (similar arr)))
  "Write ARR with rows ROW-1 and ROW-2 exchanged into RESULT (pass ARR itself
as RESULT to swap in place).  The row indices are passed through FLOOR; an
error is signalled unless both lie in [0, (ROW-DIM ARR)).  Returns RESULT."
  (check-size arr result)
  (let ((rows (row-dim arr))
        (r1 (floor row-1))
        (r2 (floor row-2)))
    (unless (and (< -1 r1 rows) (< -1 r2 rows))
      (error "Rows specified out of range"))
    ;; RESULT must start out holding ARR's data.  Previously a caller-supplied
    ;; RESULT was swapped without ever receiving ARR's contents.
    (unless (eq arr result)
      (copy arr :-> result))
    (internal-row-swap result r1 r2 (col-dim result)))
  result)

(defun paste-cols (arr-list &key ((:-> result)))
  "Paste the arrays of ARR-LIST side by side, left to right, into RESULT.
All arrays must have the same ROW-DIM (checked with CHECK-SIZE-COLS).  When
RESULT is not supplied, one is allocated with SIMILAR with dimensions
(rows, total column count).  Returns RESULT."
  (apply 'check-size-cols arr-list)
  (let ((nrows (row-dim (first arr-list)))
        (ncols (sum-of (mapcar 'col-dim arr-list))))
    (when (null result)
      (setf result (similar (first arr-list) :dimensions (list nrows ncols))))
    (let ((x-offset 0))
      (dolist (arr arr-list)
        (paste arr result :-> result :x x-offset)
        (incf x-offset (col-dim arr))))
    result))

(defun paste-rows (arr-list &key ((:-> result)))
  "Paste the arrays of ARR-LIST one below another, top to bottom, into RESULT.
All arrays must have the same COL-DIM (checked with CHECK-SIZE-ROWS).  When
RESULT is not supplied, one is allocated with SIMILAR with dimensions
(total row count, cols).  Returns RESULT."
  (apply 'check-size-rows arr-list)
  (let ((nrows (sum-of (mapcar 'row-dim arr-list)))
        (ncols (col-dim (first arr-list))))
    (when (null result)
      (setf result (similar (first arr-list) :dimensions (list nrows ncols))))
    (let ((y-offset 0))
      (dolist (arr arr-list)
        (paste arr result :-> result :y y-offset)
        (incf y-offset (row-dim arr))))
    result))

(defun append-cols (&rest arr-list)
  "Spread-argument form of PASTE-COLS: (APPEND-COLS a b c) pastes A, B and C
side by side into a newly allocated array."
  (let ((arrays arr-list))
    (paste-cols arrays)))

(defun append-rows (&rest arr-list)
  "Spread-argument form of PASTE-ROWS: (APPEND-ROWS a b c) stacks A, B and C
top to bottom into a newly allocated array."
  (let ((arrays arr-list))
    (paste-rows arrays)))


;;; Local Variables:
;;; buffer-read-only: t 
;;; fill-column: 79
;;; End:
