(module audit racket
  
  (provide (all-defined-out))
  
  (require "mdp.rkt")

  
  
  
  ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
  ;; Helper Functions
  
  ;; mdp: mdp
  ;; utilities1: [hash state num]
  ;; utilities2: [hash state num]
  ;; returns the largest absolute difference, taken over every state of mdp,
  ;; between the values the two tables assign to that state.
  ;; Both tables must contain an entry for every state of mdp (hash-ref is
  ;; called without a default, so a missing state raises an error).
  ;; Returns 0 when mdp has no states.
  (define (max-change mdp utilities1 utilities2)
    ;; max is seeded with 0: absolute differences are never negative, so the
    ;; seed cannot change the result for a non-empty state list, but it keeps
    ;; max from being applied to zero arguments when (states mdp) is empty.
    (apply max
           0
           (map (lambda (state)
                  (abs (- (hash-ref utilities2 state) (hash-ref utilities1 state))))
                (states mdp))))
  
  ;; Computes Q(s, a): the value of taking action act in state state, assuming
  ;; every state is worth exactly what the utilities table says.
  ;; mdp: mdp
  ;; utilities: [hash state num]
  ;; The result is the immediate reward plus the discount factor times the
  ;; transition-probability-weighted utilities of all states of mdp, combined
  ;; by summation (from mdp.rkt; presumably it adds up the lambda's results
  ;; over the list -- confirm against its definition).
  (define (value-of-action mdp utilities state act)
    (let* ([immediate-reward (reward mdp state act)]
           [gamma (discount-factor mdp)]
           [expected-next-utility
            (summation (lambda (successor)
                         (let ([probability (trans-prob mdp state act successor)])
                           (* probability (hash-ref utilities successor))))
                       (states mdp))])
      (+ immediate-reward (* gamma expected-next-utility))))
  
  ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
  ;; Basic MDP algorithms
  
  
  ;; mdp: mdp
  ;; policy: [hash state action]
  ;; utilities: [hash state num]
  ;;  an initial guess of the value of each state under the policy
  ;; epsilon: real
  ;;  a parameter shared by the algorithms below that determines how
  ;;  approximate they are.  Iteration stops as soon as one full sweep changes
  ;;  no state's value by epsilon or more, so a smaller epsilon gives a more
  ;;  exact answer.
  ;;  The returned values will be no more than 2*epsilon*gamma/(1-gamma)
  ;;  away from the true values.
  ;;
  ;; returns [hash state num]
  ;; Determines the value of following the policy from each state by
  ;; repeatedly re-evaluating every state against the previous sweep's table.
  (define (value-determination epsilon mdp policy utilities)
    (let sweep ([current utilities])
      ;(print-hash current)
      (let ([updated (make-hash)])
        ;; One sweep: re-evaluate each state using the action the policy picks.
        (for ([state (in-list (states mdp))])
          (hash-set! updated
                     state
                     (value-of-action mdp current state (hash-ref policy state))))
        (cond
          [(< (max-change mdp current updated) epsilon) updated]
          [else (sweep updated)]))))
  
  ;; Computes an optimal policy and its utilities by value iteration.
  ;; epsilon: real
  ;;  iteration stops once one sweep changes no state's value by epsilon or more
  ;; mdp: mdp
  ;; utilities: [hash state num]
  ;;  an initial guess of the value of each state
  ;; returns (list [hash state action] [hash state num])
  ;;  the greedy policy from the last sweep and the utilities it produced
  ;; When several actions tie for a state, the one appearing earliest in
  ;; (actions mdp) is kept, because a later action only wins on strictly
  ;; greater value.
  (define (value-iteration epsilon mdp utilities)
    ;(print-hash utilities)
    (define available-actions (actions mdp))
    (define improved-utilities (make-hash))
    (define greedy-policy (make-hash))
    (for ([state (in-list (states mdp))])
      (let*-values ([(first-act) (first available-actions)]
                    [(best-act best-value)
                     ;; Start from the first action and scan the remaining ones.
                     (for/fold ([best-act first-act]
                                [best-value (value-of-action mdp utilities state first-act)])
                               ([candidate (in-list (rest available-actions))])
                       (let ([candidate-value
                              (value-of-action mdp utilities state candidate)])
                         (if (> candidate-value best-value)
                             (values candidate candidate-value)
                             (values best-act best-value))))])
        (hash-set! greedy-policy state best-act)
        (hash-set! improved-utilities state best-value)))
    ;(printf ">> new: ~a ~a\n" greedy-policy improved-utilities)
    (if (< (max-change mdp utilities improved-utilities) epsilon)
        (list greedy-policy improved-utilities)
        (value-iteration epsilon mdp improved-utilities)))
  
  ;; Uses value iteration to find an optimal policy for an MDP, starting from
  ;; an initial guess that assigns utility 0 to every state.
  ;; returns the result of value-iteration: (list policy utilities)
  (define (solve-mdp epsilon mdp)
    (define zero-utilities (make-hash))
    (for ([state (in-list (states mdp))])
      (hash-set! zero-utilities state 0))
    (value-iteration epsilon mdp zero-utilities))
  
  ;; num * mdp -> (list (state -> num) (state -> num))
  ;; Solves mdp with solve-mdp and returns a list of two functions of a state:
  ;; the first gives a lower bound and the second an upper bound on the value
  ;; of that state, in that order.
  ;; Each bound is the approximate optimal utility of the state shifted by
  ;; 2*epsilon*gamma/(1-gamma), where gamma is the discount factor of mdp.
  (define (solve-mdp-bounds epsilon mdp)
    (define approx-utilities (second (solve-mdp epsilon mdp)))
    (define gamma (discount-factor mdp))
    ;; Computed on every lookup rather than once up front, so an arithmetic
    ;; error (e.g. an exact gamma of 1) surfaces only when a bound is queried.
    (define (error-margin)
      (/ (* 2 epsilon gamma) (- 1 gamma)))
    ;; Builds a bound function that shifts the approximation by the margin.
    (define (make-bound shift)
      (lambda (state)
        (shift (hash-ref approx-utilities state) (error-margin))))
    (list (make-bound -) (make-bound +)))
  
  )