;;; Load the file this code builds on before anything else is defined.
;;; NOTE(review): SYS:COMPILE-LOAD-IF and the "wimberly:mooney.ml;" pathname
;;; look implementation/site specific -- confirm before porting.
(sys:compile-load-if "wimberly:mooney.ml;id3-all")
;;; NOTE(review): these two variables are not defined in this file; presumably
;;; they are ID3 settings provided by id3-all -- verify against that file.
(setf *chi-square-cutoff* nil)
(setf *tie-breaker* 'ordered)

;;; FRINGE stops iterating once the length of *FEATURE-NAMES* exceeds this value.
(defparameter *feature-limit* 350)
;;; Flag handed to TRACE-PRINT in FRINGE; when non-NIL FRINGE also prints the
;;; definition and DNF of every newly created feature.
(defparameter *trace-fringe* nil)

;;; Names of all features created so far.  Reset to NIL at the start of FRINGE
;;; and pushed onto by FIND-NEW-FEATURES.
(defvar *new-features* nil)

(defun train-fringe (training-examples)
  "Training entry point intended for UNIVERSAL-TESTER.
Accepts examples in either the alist or the feature-ordered format (decided by
looking at the first example), hands FRINGE a private ordered copy of them and
returns whatever FRINGE returns."
  (let ((ordered-examples
          (if (alist-example-p (first training-examples))
              (loop for ex in training-examples
                    collect (convert-to-ordered-example ex nil))
              ;; FRINGE extends examples destructively, so work on a copy.
              (copy-tree training-examples))))
    (fringe ordered-examples)))

(defun fringe (examples)
  "Repeatedly run ID3 on EXAMPLES, adding features derived from each tree.
Each pass builds a tree with ID3 and then stops, returning the value of
FRINGE-OUTPUT, as soon as one of the following holds (checked in this order):
the tree is linear (see LINEAR-TREE), no new features could be defined, or the
number of feature names exceeds *FEATURE-LIMIT*.
*DOMAINS* and *FEATURE-NAMES* are rebound to copies for the duration of the
call; *NEW-FEATURES* is reset to NIL."
  (let ((*domains* (copy-list *domains*))
        (*feature-names* (copy-list *feature-names*)))
    (setf *new-features* nil)
    ;; ITERATION counts passes starting at 1; the loop only ends via RETURN.
    (do ((iteration 1 (1+ iteration)))
        (nil)
      (trace-print *trace-fringe* "~%~%Iteration ~D" iteration)
      (id3 examples)
      (trace-print *trace-fringe* "~%New Decision tree:~%~A" *decision-tree*)
      (when (linear-tree *decision-tree*)
        (trace-print *trace-fringe* "~%~%Linear tree so stopping.")
        (return (fringe-output *decision-tree* iteration "Linear tree")))
      (let ((new-features (define-new-features examples *decision-tree*)))
        (cond ((null new-features)
               (trace-print *trace-fringe* "~%~%No new features so stopping.")
               (return (fringe-output *decision-tree* iteration "No new features")))
              (*trace-fringe*
               (format t "~%~%New Features:")
               (dolist (new-feature new-features)
                 (format t "~%~A: ~A~%  DNF: ~A" new-feature
                         (get new-feature 'def) (get new-feature 'dnf)))))
        (when (> (length *feature-names*) *feature-limit*)
          (trace-print *trace-fringe* "~%~%Max features exceeded so stopping.")
          (return (fringe-output *decision-tree* iteration "Max features exceeded")))))))


(defun linear-tree (decision-tree)
  "True when DECISION-TREE is the leaf - or a chain of tests in which every
TRUE branch is the leaf + and every FALSE branch is again such a tree.
Any other symbol leaf, or a node lacking a TRUE or FALSE branch, yields NIL."
  (if (symbolp decision-tree)
      (if (eq decision-tree '-) t nil)
      (let* ((branches (decision-tree-branches decision-tree))
             (true-branch (find 'true branches :key #'branch-value))
             (false-branch (find 'false branches :key #'branch-value)))
        (and true-branch
             false-branch
             (eq (branch-subtree true-branch) '+)
             (linear-tree (branch-subtree false-branch))))))

;;;; ==========================================================================================
;;;; Making new features
;;;; ==========================================================================================

(defun define-new-features (examples decision-tree)
  "Create the features suggested by DECISION-TREE and install them.
Every example in EXAMPLES is destructively extended with the values of the new
features, the new names are appended to *FEATURE-NAMES* and a (TRUE FALSE)
domain is appended to *DOMAINS* for each of them.
Returns the list of new feature names, or NIL when none were found (in which
case nothing is modified)."
  (let ((new-features (find-new-features decision-tree)))
    (when new-features
      ;; Extend the examples first: ADD-FEATURES looks values up through
      ;; *FEATURE-NAMES*, which must still match the examples' current length.
      (dolist (example examples)
        (add-features example new-features))
      ;; NCONC only returns the combined list; it cannot update a variable
      ;; whose value is NIL, so the result has to be stored back.
      (setf *feature-names* (nconc *feature-names* new-features))
      (setf *domains*
            (nconc *domains*
                   (make-list (length new-features)
                              :initial-element '(true false)))))
    new-features))

(defun find-new-features (decision-tree)
  "Return fresh feature names for the candidate definitions in DECISION-TREE.
Candidates come from FIND-FEATURES with duplicate definitions removed.  A
candidate whose DNF equals that of a feature already on *NEW-FEATURES* is
skipped.  Each accepted candidate gets a gensym name whose DEF and DNF
properties hold its definition and DNF, and the name is pushed onto
*NEW-FEATURES*."
  (let ((candidate-defs
          (remove-duplicates (find-features decision-tree)
                             :test #'(lambda (a b) (set-equal a b :test #'equal))))
        (fresh-names '()))
    (dolist (feature-def candidate-defs (nreverse fresh-names))
      (let ((dnf (compute-dnf-form feature-def)))
        ;; *NEW-FEATURES* already includes names accepted earlier in this call.
        (unless (some #'(lambda (feature)
                          (dnf-equal dnf (get feature 'dnf)))
                      *new-features*)
          (let ((feature-name (gensym "F")))
            (setf (get feature-name 'def) feature-def
                  (get feature-name 'dnf) dnf)
            (push feature-name *new-features*)
            (push feature-name fresh-names)))))))


(defun find-features (decision-tree &optional last-two-features)
  "Collect candidate feature definitions from DECISION-TREE.
LAST-TWO-FEATURES holds the descriptions of at most the two most recent tests
on the path from the root (most recent first).  A definition is produced for
each symbol leaf other than - or NEGATIVE that is reached after at least two
tests; the definition is that pair of descriptions."
  (cond ((and (symbolp decision-tree)
              (not (find decision-tree '(- negative)))
              (rest last-two-features))
         (list last-two-features))
        ((decision-tree-p decision-tree)
         (loop for branch in (decision-tree-branches decision-tree)
               nconc (find-features
                      (branch-subtree branch)
                      ;; Keep only the newest earlier test next to this one.
                      (cons (feature-description decision-tree branch)
                            (when last-two-features
                              (list (first last-two-features)))))))))

(defun feature-description (decision-tree branch)
  "Describe the test made at DECISION-TREE along BRANCH.
Returns (name value) when the node has no threshold, otherwise
(name value threshold)."
  (let ((name (feature-name (decision-tree-feature decision-tree)))
        (threshold (decision-tree-threshold decision-tree))
        (value (branch-value branch)))
    (if threshold
        (list name value threshold)
        (list name value))))


(defun add-features (example new-features)
  "Destructively append to EXAMPLE's value list (its second element) the value
of each feature in NEW-FEATURES, computed from the features' definitions.
Returns EXAMPLE."
  (let ((extra-values (loop for new-feature in new-features
                            collect (new-feature-value example new-feature))))
    (nconc (second example) extra-values)
    example))

(defun new-feature-value (example new-feature)
  "Return TRUE when EXAMPLE satisfies every component of NEW-FEATURE's DEF
property, FALSE otherwise.
A two-element component (name value) holds when the example's value for NAME
is EQ to VALUE; a longer component (name fn threshold) holds when FN applied
to the example's value and THRESHOLD returns true."
  (flet ((holds-p (component)
           (let ((value (feature-name-value example (first component))))
             (if (= (length component) 2)
                 (eq value (second component))
                 (funcall (second component) value (third component))))))
    (if (every #'holds-p (get new-feature 'def))
        'true
        'false)))

(defun feature-name-value (example feature-name)
  "Return the value EXAMPLE holds for the feature called FEATURE-NAME.
The value list is the second element of EXAMPLE and is indexed in parallel
with *FEATURE-NAMES*.  Returns NIL when the value list is shorter than the
feature's position.  Signals an error when FEATURE-NAME is not in
*FEATURE-NAMES* (scanning past the end of the name list would otherwise never
terminate)."
  (let ((index (position feature-name *feature-names* :test #'eq)))
    (if index
        (nth index (second example))
        (error "Feature ~S is not in *FEATURE-NAMES*." feature-name))))

;;;; ==========================================================================================
;;;; DNF stuff (for checking uniqueness of created features)
;;;; ==========================================================================================

(defun dnf-equal (dnf1 dnf2)
  "True when DNF1 and DNF2 contain the same terms, where two terms match when
they contain the same literals (compared with EQUAL), in any order."
  (flet ((same-term-p (term1 term2)
           (set-equal term1 term2 :test #'equal)))
    (set-equal dnf1 dnf2 :test #'same-term-p)))

(defun compute-dnf-form (feature-def)
  "Return the simplified DNF of the conjunction of the first two components of
FEATURE-DEF."
  (let* ((left (get-dnf-form (first feature-def)))
         (right (get-dnf-form (second feature-def)))
         (product (multiply-dnf left right)))
    (simplify-dnf product)))

(defun get-dnf-form (feature-pair)
  "Return a DNF for FEATURE-PAIR, a list whose first element is a feature name.
If the name carries no DNF property the pair itself is the only literal of the
only term.  Otherwise the stored DNF is returned, negated when the pair's
second element is FALSE."
  (let ((stored-dnf (get (first feature-pair) 'dnf)))
    (if (null stored-dnf)
        (list (list feature-pair))
        (if (eq (second feature-pair) 'false)
            (negate-dnf stored-dnf)
            stored-dnf))))

(defun multiply-dnf (dnf1 dnf2)
  "Return the DNF of the conjunction of DNF1 and DNF2.
Every term of DNF1 is combined (by UNION under EQUAL) with every term of DNF2;
combinations in which some literal of the first term is the negative of a
literal of the second (per NEGATIVES-P) are left out."
  (loop for term1 in dnf1
        nconc (loop for term2 in dnf2
                    unless (loop for lit1 in term1
                                 thereis (member lit1 term2 :test #'negatives-p))
                      collect (union term1 term2 :test #'equal))))

(defun negatives-p (lit1 lit2)
  "True when literals LIT1 and LIT2 contradict each other: one is (NOT x) with
x EQUAL to the other, or both start with the same element and their second
elements are TRUE and FALSE (in either order)."
  (flet ((negation-of-p (wrapped plain)
           (and (eq (first wrapped) 'not)
                (equal (second wrapped) plain))))
    (cond ((negation-of-p lit1 lit2))
          ((negation-of-p lit2 lit1))
          ((eq (first lit1) (first lit2))
           (let ((value1 (second lit1))
                 (value2 (second lit2)))
             (or (and (eq value1 'true) (eq value2 'false))
                 (and (eq value1 'false) (eq value2 'true))))))))

(defun negate-dnf (dnf)
  "Return a DNF for the negation of DNF: each term is negated with NEGATE-TERM
and the resulting DNFs are multiplied together with MULTIPLY-DNF."
  (let ((negated-terms (loop for term in dnf
                             collect (negate-term term))))
    (reduce #'multiply-dnf negated-terms)))

(defun negate-term (term)
  "Return the negation of the conjunction TERM as a DNF: one single-literal
term per literal of TERM, holding that literal's negation."
  (loop for lit in term
        collect (list (negate-literal lit))))

(defun negate-literal (lit)
  "Return the negation of literal LIT.
(NOT x) becomes x.  For a feature accepted by BINARY-FEATURE-P the value is
flipped: TRUE becomes FALSE and anything else becomes TRUE.  Every other
literal is wrapped as (NOT lit)."
  (let ((head (first lit)))
    (cond ((eq head 'not)
           (second lit))
          ((not (binary-feature-p head))
           (list 'not lit))
          ((eq (second lit) 'true)
           (list head 'false))
          (t
           (list head 'true)))))

(defun simplify-dnf (dnf)
  "Return DNF with redundant terms removed; DNF itself is not modified.
A term is redundant when another term's literals are a proper subset of its
own (the smaller term already covers it).  Of several terms with the same
literals only the last is kept.  Literals are compared with EQUAL.  Surviving
terms are returned in the reverse of their original order."
  (flet ((covers-p (general specific)
           ;; GENERAL has no literal that SPECIFIC lacks.
           (subsetp general specific :test #'equal)))
    (let ((kept '())
          (earlier '()))
      (loop for (term . later) on dnf
            do (unless (or
                        ;; A later term that is equal or more general wins.
                        (some #'(lambda (other) (covers-p other term)) later)
                        ;; An earlier term wins only when strictly more general;
                        ;; an earlier equal term was itself dropped above.
                        (some #'(lambda (other)
                                  (and (covers-p other term)
                                       (not (covers-p term other))))
                              earlier))
                 (push term kept))
               (push term earlier))
      kept)))
  
(defun set-equal (set1 set2 &key (test #'eql))
  "Return T if two sets are equal"
  ;; The lists must have the same length and each must be contained in the
  ;; other.  Checking both directions matters when a list holds duplicates:
  ;; (a a) and (a b) have equal length and every element of the first occurs
  ;; in the second, yet they are different sets.
  ;; TEST is always called with the element of SET1 first.
  (and (= (length set1) (length set2))
       (every #'(lambda (elt1)
                  (member elt1 set2 :test test))
              set1)
       (every #'(lambda (elt2)
                  (some #'(lambda (elt1) (funcall test elt1 elt2))
                        set1))
              set2)))


;;;; ==========================================================================================
;;;; Assembling output
;;;; ==========================================================================================

(defun fringe-output (decision-tree iterations termination-reason)
  "Package the result of a FRINGE run.
Finds the created features that DECISION-TREE actually tests, renumbers them
in the tree (destructively) so that they directly follow the original
features, and returns the list
  (decision-tree names-of-used-new-features iterations termination-reason).
Each used feature is listed once, in order of first appearance in the tree."
  (let* ((num-old-features (- (length *feature-names*) (length *new-features*)))
         ;; A feature tested at several nodes is found once per node; keep
         ;; only the first occurrence so it is listed and numbered once.
         (used-new-features
           (remove-duplicates (used-new-features decision-tree num-old-features)
                              :from-end t))
         (new-feature-num (1- num-old-features))
         (renumber-alist (mapcar #'(lambda (old-num)
                                     (cons old-num (incf new-feature-num)))
                                 used-new-features)))
    (transform-tree-features decision-tree renumber-alist)
    ;; USED-NEW-FEATURES still holds the old numbers, which is what
    ;; FEATURE-NAME is given here.
    (list decision-tree
          (mapcar #'feature-name used-new-features)
          iterations
          termination-reason)))

(defun transform-tree-features (decision-tree renumber-alist)
  "Destructively renumber the features tested in DECISION-TREE.
RENUMBER-ALIST maps old feature numbers to new ones; nodes whose feature has
no entry are left alone.  Symbol leaves are ignored.  Returns NIL."
  (if (symbolp decision-tree)
      nil
      (let ((entry (assoc (decision-tree-feature decision-tree) renumber-alist)))
        (when entry
          (setf (decision-tree-feature decision-tree) (cdr entry)))
        (mapc #'(lambda (branch)
                  (transform-tree-features (branch-subtree branch) renumber-alist))
              (decision-tree-branches decision-tree))
        nil)))

(defun used-new-features (decision-tree num-old-features)
  "Return the feature numbers tested in DECISION-TREE that are at least
NUM-OLD-FEATURES, a node's own feature before those of its subtrees.
A feature tested at several nodes appears once per node."
  (unless (symbolp decision-tree)
    (let* ((feature (decision-tree-feature decision-tree))
           (own (when (>= feature num-old-features)
                  (list feature)))
           (below (loop for branch in (decision-tree-branches decision-tree)
                        nconc (used-new-features (branch-subtree branch)
                                                 num-old-features))))
      (append own below))))

(defun train-fringe-output (training-result training-examples)
  "Print a summary of TRAINING-RESULT to standard output: its third element as
the iteration count, its fourth as the reason for stopping, and every feature
in its second element together with that feature's DNF property.
TRAINING-EXAMPLES is ignored.  Returns NIL."
  (declare (ignore training-examples))
  (let ((used-features (second training-result)))
    (format t "~%Stopped after iteration ~D because ~A" (third training-result) (fourth training-result))
    (format t (if used-features
                  "~%New Features used:"
                  "~%No new features used."))
    (loop for new-feature in used-features
          do (format t "~%~A: ~A" new-feature (get new-feature 'dnf)))))


;;;; ==========================================================================================
;;;; Testing 
;;;; ==========================================================================================

(defun test-fringe (example fringe-result)
  "Classify EXAMPLE with the result of a FRINGE run.
EXAMPLE may be in alist or ordered format; a private ordered copy is extended
with the values of the features listed second in FRINGE-RESULT and passed,
with the tree stored first in FRINGE-RESULT, to TEST-ID3."
  (let ((ordered-example (if (alist-example-p example)
                             (convert-to-ordered-example example nil)
                             ;; Copy: the example is extended destructively.
                             (copy-tree example))))
    (test-id3 (add-features-dnf ordered-example (second fringe-result))
              (first fringe-result))))

(defun add-features-dnf (example new-features)
  "Destructively append to EXAMPLE's value list (its second element) the value
of each feature in NEW-FEATURES, computed from the features' DNF properties.
Returns EXAMPLE."
  (let ((extra-values (loop for new-feature in new-features
                            collect (new-feature-dnf-value example new-feature))))
    (nconc (second example) extra-values)
    example))

(defun new-feature-dnf-value (example new-feature)
  "Return TRUE when EXAMPLE satisfies at least one term of NEW-FEATURE's DNF
property, FALSE otherwise.  A term is satisfied when all its literals hold.
A literal (name value) holds when the example's value for NAME is EQ to VALUE;
a longer literal (name fn threshold) holds when FN applied to the example's
value and THRESHOLD returns true; (NOT literal) inverts the inner literal."
  (labels ((literal-holds-p (lit)
             (let* ((negated-p (eq (first lit) 'not))
                    ;; Only a single level of NOT is unwrapped.
                    (inner (if negated-p (second lit) lit))
                    (value (feature-name-value example (first inner)))
                    (truth-value (if (= (length inner) 2)
                                     (eq value (second inner))
                                     (funcall (second inner) value (third inner)))))
               (if negated-p
                   (not truth-value)
                   truth-value)))
           (term-holds-p (term)
             (every #'literal-holds-p term)))
    (if (some #'term-holds-p (get new-feature 'dnf))
        'true
        'false)))

