;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;;  File: LLSE-filter-design.lisp
;;;  Author: Simoncelli
;;;  Description: Code to design symmetric/anti-symmetric FIR filters from their
;;;               DFT specification.  Error criterion is simple weighted least
;;;               squares in the frequency domain.
;;;  Creation Date: 3/94
;;;  ----------------------------------------------------------------
;;;    Object-Based Vision and Image Understanding System (OBVIUS),
;;;      Copyright 1988, Vision Science Group,  Media Laboratory,  
;;;              Massachusetts Institute of Technology.
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(in-package 'obvius)
(export '())

(obv-require :matrix)

;;; Weighted least squares (in the frequency domain) 2D filter design,
;;; for REAL (ANTI-)SYMMETRIC filters.  dfft is the desired DFT of the
;;; filter, and should be an even- or odd-symmetric image, with
;;; power-of-two dimensions.  Weight can be an image or a number which
;;; is used as an exponent for an exponentially decaying weight
;;; function.

;;; *** Even kernels not working yet.
(defun design-filter (dfft fdims &key weight dc-weight step-vector)
  "Design a real symmetric or anti-symmetric filter of dimensions FDIMS whose
DFT approximates DFFT in the weighted least-squares sense.

DFFT        desired DFT (an image); it must be even- or odd-symmetric about
            its center, which is checked numerically below.
FDIMS       list (y-size x-size) of the kernel; both sizes must be odd.
WEIGHT      either a number, used as the :exponent of a MAKE-1-OVER-R weight
            image centered on DFFT, or any other value, which is handed to
            LINEAR-LEAST-SQUARES unchanged (nil means unweighted).
DC-WEIGHT   value stored at the center of the generated weight image
            (default 1.0).  Only used when WEIGHT is a number.
STEP-VECTOR passed through to PARAMS-TO-FILTER.

Returns two values: the filter built by PARAMS-TO-FILTER and the error value
returned by LINEAR-LEAST-SQUARES.  Signals an error for even kernel sizes and
when DFFT is neither even- nor odd-symmetric."
  (when (some #'evenp fdims) (error "Sorry: not written for even size kernels"))
  (with-local-viewables
      ((ctr (mapcar #'(lambda (x) (floor x 2))  (dimensions dfft)))
       ;; INV is DFFT flipped in both directions and circularly shifted by
       ;; one sample in x and y; it is compared with DFFT to test symmetry.
       (inv (let ((im (flip-x dfft)))
              (flip-y im :-> im)
              (circular-shift im :x 1 :y 1 :-> im)))
       ;; T for even symmetry, NIL for odd symmetry.  The test is relative:
       ;; mean-square error over the variance of DFFT must be below 0.01.
       (symmetric (let ((pwr (variance dfft)))
                    (cond ((< (/ (mean-square-error dfft inv) pwr) 0.01) ;*** magic
                           t)
                          ;; INV is negated in place for the odd-symmetry test.
                          ((< (/ (mean-square-error (-. inv :-> inv) dfft) pwr) 0.01)
                           nil)
                          (t (error "desired fft must be even- or odd- symmetric.")))))
       (fourier (make-fourier-basis (dimensions dfft)
                                    :symmetric symmetric
                                    :fdims fdims))
       (wt (when (numberp weight)
             (let ((im (make-1-over-r (dimensions dfft) :exponent weight :origin ctr)))
               (setf (iref im (car ctr) (cadr ctr)) (or dc-weight 1.0))
               ;; Return the image itself.  SETF yields the stored number, so
               ;; without this WT would be a scalar and the 1/r weighting
               ;; would never reach the least-squares fit.
               im))))
    (multiple-value-bind (p-vect err)
        (linear-least-squares fourier dfft :return-mse t :weight (or wt weight))
      (values (params-to-filter p-vect fdims symmetric :step-vector step-vector) err))))

;;; Make sine or cosine fourier basis functions of given size (dims),
;;; covering given frequency range.
(defun make-fourier-basis (dims &key symmetric fdims)
  "Build the image sequence of basis functions used to fit a filter of
dimensions FDIMS to a DFT of dimensions DIMS.

One synthetic image of size DIMS is made per free kernel parameter: cosines
when SYMMETRIC is non-nil, sines otherwise, plus a constant image for the
zero-frequency term (symmetric case only).  Images are generated row by row
over the kernel offsets; PARAMS-TO-FILTER relies on this ordering."
  (let ((y-half (/ (- (car fdims) 1) 2))
        (x-half (/ (- (cadr fdims) 1) 2)))
    (flet ((first-wx (wy)
             ;; Smallest x offset enumerated on the row with y offset WY.
             ;; Only half of the offsets are listed; the other half is
             ;; implied by the (anti-)symmetry of the kernel.
             (cond ((minusp wy) 0)
                   ((plusp wy) 1)
                   (symmetric 0)
                   (t 1)))
           (basis-image (wy wx)
             ;; Basis image for kernel offset (WY, WX), sampled over the
             ;; ranges given by PI-RANGE.
             (make-synthetic-image
              dims
              (cond ((= wx wy 0)
                     #'(lambda (y x) (declare (ignore y x)) 1.0))
                    (symmetric
                     #'(lambda (y x)
                         (* 2.0 (cos (+ (* y wy) (* x wx))))))
                    (t
                     #'(lambda (y x)
                         (* 2.0 (sin (+ (* y wy) (* x wx)))))))
              :x-range (pi-range (cadr dims))
              :y-range (pi-range (car dims)))))
      (make-image-sequence
       (loop for row from 0 below (car fdims)
             for wy = (- row y-half)
             append (loop for wx from (first-wx wy) to x-half
                          collect (basis-image wy wx)))))))

;;; Must be matched to ordering in make-fourier-basis
(defun params-to-filter (p-vect fdims symmetric &key step-vector)
  "Expand the parameter vector P-VECT into a full kernel of dimensions FDIMS
and wrap it with MAKE-FILTER (STEP-VECTOR is passed through).

P-VECT holds one value per independent tap, in the order produced by
MAKE-FOURIER-BASIS.  Each value is stored at its own tap and at the tap
reflected through the kernel center, with the same sign when SYMMETRIC is
non-nil and the opposite sign otherwise."
  ;; Zero-fill the kernel: in the anti-symmetric case the center tap is never
  ;; written by the loop below, and MAKE-ARRAY without :INITIAL-ELEMENT leaves
  ;; the contents undefined.
  (let ((kernel (make-array fdims :element-type 'single-float
                                  :initial-element 0.0f0))
        (y-ctr (/ (- (car fdims) 1) 2))
        (x-ctr (/ (- (cadr fdims) 1) 2)))
    (loop with p-ind = 0
          for y from 0 below (car fdims)
          do
          ;; The starting column mirrors the WX start in MAKE-FOURIER-BASIS:
          ;; rows above the center start at the center column, rows below
          ;; start one past it, and the center row includes the center tap
          ;; only for symmetric kernels.
          (loop for x from (ceiling (cond ((< y y-ctr) x-ctr)
                                          ((> y y-ctr) (1+ x-ctr))
                                          (t (if symmetric x-ctr (1+ x-ctr)))))
                below (cadr fdims)
                do
                (setf (aref kernel y x) (aref p-vect p-ind))
                ;; Point reflection through the kernel center.
                (setf (aref kernel (- (* 2 y-ctr) y) (- (* 2 x-ctr) x))
                      (if symmetric (aref p-vect p-ind) (- (aref p-vect p-ind))))
                (incf p-ind)))
    (make-filter kernel :step-vector step-vector)))

;;; Returns vector of the values a_i minimizing [w (y - \sum_i a_i
;;; x_i)]^2 where x_i are the basis images, y is the desired model
;;; image, and w is a weighting image (can be nil).
(defun linear-least-squares (x-seq y &key (rank (length. x-seq))
                                   weight return-mse)
  "Solve the normal equations for the coefficients that best combine the
frames of X-SEQ to approximate Y.

WEIGHT, when non-nil, is squared and multiplied into every inner product.
RANK is passed to MATRIX-INVERSE as its :dimension-limit.  When RETURN-MSE is
non-nil a second value is returned: the mean of the squared difference
between the fitted combination and Y (computed without the weight)."
  (let* ((n-basis (length. x-seq))
         (gram (make-array (list n-basis n-basis) :element-type 'single-float))
         (rhs (make-array n-basis :element-type 'single-float)))
    (with-local-viewables ((product (similar y))
                           (weighted-basis (similar product))
                           (weight-squared (when weight (square weight))))
      (dotimes (row n-basis)
        (status-message "Computing row: ~A of ~A" row n-basis)
        ;; Basis image ROW, pre-multiplied by the squared weight if any.
        (if weight-squared
            (mul (frame row x-seq) weight-squared :-> weighted-basis)
            (copy (frame row x-seq) :-> weighted-basis))
        ;; Right-hand side: mean of Y times the weighted basis image.
        (mul y weighted-basis :-> product)
        (setf (aref rhs row) (mean product))
        ;; The matrix is symmetric, so only COL >= ROW is computed and each
        ;; entry is mirrored across the diagonal.
        (loop for col from row below n-basis
              do
              (mul (frame col x-seq) weighted-basis :-> product)
              (let ((entry (mean product)))
                (setf (aref gram col row) entry)
                (setf (aref gram row col) entry)))))
    (let* ((inverse (matrix-inverse gram :dimension-limit rank))
           (estimate (matrix-mul rhs inverse)))
      (if (not return-mse)
          estimate
          (with-local-viewables ((residual (dot-product x-seq estimate)))
            (sub residual y :-> residual)
            (square residual :-> residual)
            (values estimate (mean residual)))))))

;;; Figure out the range of values, in radians, for a DFT of size dim:
(defun pi-range (dim)
  "Return a two-element list (low high) where low is -pi and high is
pi * (DIM - 2) / DIM."
  (let ((low (- pi))
        (high (* pi (/ (- dim 2) dim))))
    (list low high)))

;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;;; Useful radial profile functions

;;; Bandwidth is in octaves.
(defun log-raised-cos-fn (&key (ctr-freq (/ pi 2)) (bandwidth 1))
  "Return a one-argument function of R.  The phase is pi/BANDWIDTH times
the value of (log-0 (/ r ctr-freq) 2 bandwidth); strictly inside (-pi, pi) the
result is sqrt((cos(phase) + 1) / 2), elsewhere it is 0.0."
  #'(lambda (r)
      (let ((phase (* (/ pi bandwidth)
                      (log-0 (/ r ctr-freq) 2 bandwidth))))
        (cond ((and (> phase (- pi)) (< phase pi))
               (sqrt (* 0.5 (+ (cos phase) 1.0))))
              (t 0.0)))))

;;; Asymmetric: right side is sqrt'ed twice!
(defun log-raised-cos-fn2 (&key (ctr-freq (/ pi 2)) (bandwidth 1))
  "Return a one-argument function of R with an asymmetric profile.  The
phase is pi/BANDWIDTH times (log-0 (/ r ctr-freq) 2 bandwidth).  For phase in
(-pi, 0) the result is sqrt((cos(phase) + 1) / 2); for phase in [0, pi] that
value is square-rooted a second time; elsewhere the result is 0.0."
  #'(lambda (r)
      (let ((phase (* (/ pi bandwidth)
                      (log-0 (/ r ctr-freq) 2 bandwidth))))
        (flet ((half-raised-cos ()
                 (* 0.5 (+ (cos phase) 1.0))))
          (cond ((and (> phase (- pi)) (< phase 0))
                 (sqrt (half-raised-cos)))
                ((and (>= phase 0) (<= phase pi))
                 (sqrt (sqrt (half-raised-cos))))
                (t 0.0))))))
 
;;; Asymmetric: left side is square-rooted twice.
(defun log-raised-cos-fn3 (&key (ctr-freq (/ pi 2)) (bandwidth 1))
  "Return a one-argument function of R with an asymmetric profile.  The
phase is pi/BANDWIDTH times (log-0 (/ r ctr-freq) 2 bandwidth).  For phase in
(-pi, 0) the value sqrt((cos(phase) + 1) / 2) is square-rooted a second time;
for phase in [0, pi] it is returned as is; elsewhere the result is 0.0."
  #'(lambda (r)
      (let ((phase (* (/ pi bandwidth)
                      (log-0 (/ r ctr-freq) 2 bandwidth))))
        (flet ((half-raised-cos ()
                 (* 0.5 (+ (cos phase) 1.0))))
          (cond ((and (> phase (- pi)) (< phase 0))
                 (sqrt (sqrt (half-raised-cos))))
                ((and (>= phase 0) (<= phase pi))
                 (sqrt (half-raised-cos)))
                (t 0.0))))))
 
(defun raised-cos-fn (&key (ctr-freq (/ pi 2)) (bandwidth 1))
  "Return a one-argument function of R.  The phase is
(R - CTR-FREQ) * pi / (CTR-FREQ * BANDWIDTH); strictly inside (-pi, pi) the
result is sqrt((cos(phase) + 1) / 2), elsewhere it is 0.0."
  #'(lambda (r)
      (let ((phase (* (- r ctr-freq) (/ pi (* ctr-freq bandwidth)))))
        (cond ((and (> phase (- pi)) (< phase pi))
               (sqrt (* 0.5 (+ (cos phase) 1.0))))
              (t 0.0)))))
 
(defun x^n (x n)
  "Raise X to the power N by repeated multiplication.

For N >= 0 the result is X multiplied into 1 (floor N) times, so exact
arguments give an exact result.  For negative N the reciprocal of
X^(-N) is returned; the previous version ended its loop before the
accumulator was initialized in that case and did not return a number."
  (if (minusp n)
      (/ (x^n x (- n)))
      (let ((res 1))
        (dotimes (k (floor n) res)
          (setf res (* res x))))))

