(module mdp racket
  
  (provide (all-defined-out))
  
  ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
  ;; Util
  
  ;; print-hash : hash -> void
  ;; Prints the entries of TABLE, converted to a list of (key . value) pairs by
  ;; hash->list, on one line prefixed with "hash: ".
  (define (print-hash table)
    (let ([entries (hash->list table)])
      (printf "hash: ~a\n" entries)))
  
  ;; summation : (any -> num) (listof any) -> num
  ;; Applies EXP to every element of OVER-LIST (left to right) and adds the
  ;; results; the empty sum is 0.
  (define (summation exp over-list)
    (let ([terms (for/list ([item (in-list over-list)])
                   (exp item))])
      (apply + terms)))
  
  ;; cons-across : (listof any) (listof list) -> (listof list)
  ;; Conses every element of PREFIXES onto every list in LIST-LISTS.
  ;; Results are grouped by prefix, with the LAST prefix's group first and each
  ;; group following the order of LIST-LISTS, e.g.
  ;;   (cons-across '(a b) '((1) (2))) => '((b 1) (b 2) (a 1) (a 2))
  (define (cons-across prefixes list-lists)
    (for*/list ([prefix (in-list (reverse prefixes))]
                [tail (in-list list-lists)])
      (cons prefix tail)))
  
  ;; cross : (listof any) ... -> (listof list)
  ;; Cartesian product of the argument lists: every list whose i-th element is
  ;; drawn from the i-th argument.  (cross) is '(()).  The products come out in
  ;; reverse lexicographic order, e.g.
  ;;   (cross '(a b) '(1 2)) => '((b 2) (b 1) (a 2) (a 1))
  (define (cross . lst-lst)
    (let build ([lists lst-lst])
      (if (null? lists)
          '(())
          ;; Products of the remaining lists are built first (innermost).
          (let ([suffixes (build (cdr lists))])
            (for*/list ([head (in-list (reverse (car lists)))]
                        [suffix (in-list suffixes)])
              (cons head suffix))))))
  
  
  ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
  ;; Representation of MDPs
  ;;
  ;; mdp ::= (list states actions trans-probs rewards gamma)
  ;; states: [list state]
  ;;   a list of states where a state is something that can be hashed (e.g., numbers)
  ;; actions: [list action]
  ;;   a list of actions where an action is something that can be hashed (e.g., numbers)
  ;; trans-probs: [hash state * action * state -> num]
  ;;   a hash table mapping (list state action next-state) to a number in [0,1]:
  ;;   the probability that performing the action in the first state leads to the next state
  ;; rewards: [hash state * action -> num]
  ;;   a hash table mapping (list state action) to a number: the reward for doing
  ;;   that action in that state
  ;; gamma (the discount factor): num
  ;;   a number g s.t. 0 <= g < 1.  (g = 1 might result in non-termination).
  ;; pack-mdp : states actions trans-probs rewards gamma -> mdp
  ;; Bundles the five components, in this order, into a fresh five-element list.
  (define (pack-mdp states actions trans-probs rewards gamma)
    `(,states ,actions ,trans-probs ,rewards ,gamma))
  
  ;; states : mdp -> (listof state)
  ;; The state list: the first component of the packed MDP.
  (define (states the-mdp)
    (first the-mdp))

  ;; actions : mdp -> (listof action)
  ;; The action list: the second component of the packed MDP.
  (define (actions the-mdp)
    (second the-mdp))

  ;; trans-prob : mdp state action state -> num
  ;; Looks up (list FROM-STATE ACT TO-STATE) in the transition table (third
  ;; component).  hash-ref is used without a default, so a missing triple
  ;; raises an error.
  (define (trans-prob the-mdp from-state act to-state)
    (let ([table (third the-mdp)]
          [key (list from-state act to-state)])
      (hash-ref table key)))

  ;; reward : mdp state action -> num
  ;; Looks up (list STATE-KEY ACT) in the reward table (fourth component);
  ;; a missing pair raises an error.
  (define (reward the-mdp state-key act)
    (let ([table (fourth the-mdp)]
          [key (list state-key act)])
      (hash-ref table key)))

  ;; discount-factor : mdp -> num
  ;; The gamma value: the fifth component of the packed MDP.
  (define (discount-factor the-mdp)
    (fifth the-mdp))
  
  ;; print-mdp : mdp -> void
  ;; Writes MDP to the current output port: "mdp: <states> ; <actions> ;",
  ;; then the transition table and the reward table (each printed by
  ;; print-hash), separated by " ; ", and finally " ; <gamma>".
  (define (print-mdp mdp)
    (define (separator) (printf " ; "))
    (printf "mdp: ~a ; ~a ;" (states mdp) (actions mdp))
    (print-hash (third mdp))
    (separator)
    (print-hash (fourth mdp))
    (printf " ; ~a" (discount-factor mdp)))
  
  
  ;; make-trans-hash : (list state action state num) ... -> mutable hash
  ;; Builds a mutable equal?-keyed hash mapping (list cur-state act next-state)
  ;; to the fourth element of each quad.  When two quads share a key, the later
  ;; one wins (hash-set! overwrites).
  (define (make-trans-hash . quads)
    (define table (make-hash))
    (for ([quad (in-list quads)])
      (define key (list (first quad) (second quad) (third quad)))
      (hash-set! table key (fourth quad)))
    table)
  
  ;; make-reward-hash : (list state action num) ... -> mutable hash
  ;; Builds a mutable equal?-keyed hash mapping (list cur-state act) to the
  ;; third element of each triple.  When two triples share a key, the later one
  ;; wins (hash-set! overwrites).
  (define (make-reward-hash . trips)
    (define table (make-hash))
    (for ([trip-entry (in-list trips)])
      (define key (list (first trip-entry) (second trip-entry)))
      (hash-set! table key (third trip-entry)))
    table)
  
  
  ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
  ;; helpers
  
  ;; expand-trans-probs-list :
  ;;   (listof state) (listof action) (listof transition) -> (listof transition)
  ;; where transition = (list state action next-state num).
  ;; Produces one transition for every (state action next-state) triple in
  ;; STATES x ACTIONS x STATES:
  ;;  * if no transition at all is listed for (state action), it defaults to a
  ;;    self-loop: probability 1 back to state and 0 to every other state;
  ;;  * otherwise the listed transition is used and unlisted next states get 0.
  ;; Listed transitions whose state, action or next state does not appear in
  ;; STATES / ACTIONS are not included in the result.
  ;; The result is in reverse of the (state, action, next-state) traversal
  ;; order, i.e. the last state's last action's last next state comes first.
  ;; States and actions are compared with equal?, so symbols, numbers, or any
  ;; other equal?-comparable values work.
  ;; Raises an error if more than one probability is listed for the same triple.
  (define (expand-trans-probs-list states actions trans-prob-list)
    ;; Index the listed transitions by their (state action) prefix once, keeping
    ;; input order inside each group, instead of rescanning TRANS-PROB-LIST for
    ;; every state-action pair.
    (define by-state-act (make-hash))
    (for ([tp (in-list (reverse trans-prob-list))])
      (hash-update! by-state-act
                    (list (first tp) (second tp))
                    (lambda (group) (cons tp group))
                    '()))
    (for*/fold ([result '()])
               ([state (in-list states)]
                [act (in-list actions)]
                [part-matches (in-value (hash-ref by-state-act (list state act) '()))]
                [next-state (in-list states)])
      (cons
       (if (empty? part-matches)
           ;; no transitions given for this action: default to a self-loop
           (list state act next-state (if (equal? state next-state) 1 0))
           (let ([full-matches (filter (lambda (tp) (equal? next-state (third tp)))
                                       part-matches)])
             (cond [(empty? full-matches) (list state act next-state 0)]
                   [(empty? (rest full-matches)) (first full-matches)]
                   [else (error "More than one probablity given for the same transisition: "
                                full-matches)])))
       result)))
  
  ;; expand-reward-list :
  ;;   (listof state) (listof action) (listof (list state action num))
  ;;   -> (listof (list state action num))
  ;; Produces one reward entry for every (state action) pair in
  ;; STATES x ACTIONS, using the listed entry when there is one and a reward of
  ;; 0 otherwise.  Entries for states or actions not in STATES / ACTIONS are not
  ;; included in the result.
  ;; The result is in reverse of the (state, action) traversal order.
  ;; States and actions are compared with equal?, so symbols, numbers, or any
  ;; other equal?-comparable values work.
  ;; Raises an error if more than one reward is listed for the same pair.
  (define (expand-reward-list states actions reward-list)
    ;; Index the listed rewards by (state action) once, keeping input order
    ;; inside each group, instead of rescanning REWARD-LIST for every pair.
    (define by-state-act (make-hash))
    (for ([reward-line (in-list (reverse reward-list))])
      (hash-update! by-state-act
                    (list (first reward-line) (second reward-line))
                    (lambda (group) (cons reward-line group))
                    '()))
    (for*/fold ([result '()])
               ([state (in-list states)]
                [action (in-list actions)])
      (let ([matches (hash-ref by-state-act (list state action) '())])
        (cons (cond [(empty? matches) (list state action 0)] ; default to zero reward
                    [(empty? (rest matches)) (first matches)]
                    [else (error "More than one reward given for same state-action-pair: " matches)])
              result))))
  
  ;; trans-probs-list: list of transitions, each given as [list state action state num]
  ;;  if an action isn't listed at all for a state, it is made a self-loop (probability 1 of
  ;;  staying in that state); its reward still comes from reward-list (zero if not listed there)
  ;;  if an action is listed, only the transitions with non-zero probability need to be listed
  ;; reward-list: list of rewards, each given as [list state action num]
  ;;  if a state-action pair isn't listed, its reward is treated as zero
  ;; States and actions must be symbols.
  ;; It is easy to add a Nothing action 'N by simply listing it in the list of actions and not
  ;; adding it anywhere else.
  ;; The efficiency of this algorithm could be improved by first putting things into a hash table
  ;; and then filling in the missing values.
  (define (make-mdp states actions trans-probs-list reward-list gamma)
    ;; Expand the sparse transition and reward lists into complete tables
    ;; (transition table first, then reward table) and pack them into an mdp.
    (let* ([trans-table
            (apply make-trans-hash
                   (expand-trans-probs-list states actions trans-probs-list))]
           [reward-table
            (apply make-reward-hash
                   (expand-reward-list states actions reward-list))])
      (pack-mdp states actions trans-table reward-table gamma)))
  
  ;; make-nmdp : like make-mdp, but prepends the "nothing" action 'N to ACTIONS.
  ;; Since make-mdp turns an action with no listed transitions into a
  ;; self-loop with zero default reward, 'N ends up doing nothing in every state.
  ;; 'N must not already appear in actions, trans-probs-list, or reward-list.
  (define (make-nmdp states actions trans-probs-list reward-list gamma)
    (let ([actions-with-nothing (cons 'N actions)])
      (make-mdp states actions-with-nothing trans-probs-list reward-list gamma)))
  ;; valid-mdp? : mdp [real] -> boolean
  ;; Checks that the transition probabilities of MDP form proper distributions:
  ;;  * every probability P(s2 | s1, a) lies in [0, 1], and
  ;;  * for every state-action pair, the probabilities over all next states sum
  ;;    to 1 up to TOLERANCE (default 1e-9).  An exact comparison with 1 would
  ;;    reject valid MDPs with inexact probabilities: for example
  ;;    (+ 0.6 0.3 0.1) evaluates to 0.9999999999999999.  Pass 0 as TOLERANCE
  ;;    to require the sum to equal 1 exactly.
  ;; Returns #f if either check fails.
  ;; Raises an error (hash-ref inside trans-prob) if a transition is missing
  ;; from the table altogether.
  (define (valid-mdp? mdp [tolerance 1e-9])
    (let ([states (states mdp)]
          [actions (actions mdp)])
      (and
       (andmap
        (lambda (s1-a-s2-trip)
          (let* ([s1 (first s1-a-s2-trip)]
                 [a (second s1-a-s2-trip)]
                 [s2 (third s1-a-s2-trip)]
                 [prob (trans-prob mdp s1 a s2)])
            (and (>= prob 0)
                 (<= prob 1))))
        (cross states actions states))
       (andmap
        (lambda (state-act-pair)
          (let ([total (summation (lambda (next-state)
                                    (trans-prob mdp (first state-act-pair) (second state-act-pair) next-state))
                                  states)])
            ;; compare within a tolerance: floating-point sums rarely hit 1 exactly
            (<= (abs (- total 1)) tolerance)))
        (cross states actions)))))
  
  ;; find-invalid-mdp-problm : mdp [real] -> void
  ;; Prints (to the current output port) one line per problem found in MDP's
  ;; transition probabilities:
  ;;  * "prob too small" / "prob too big" for a probability outside [0, 1];
  ;;  * "prob not adding to 1" for a state-action pair whose probabilities over
  ;;    next states differ from 1 by more than TOLERANCE (default 1e-9; an exact
  ;;    comparison would flag inexact sums such as 0.9999999999999999).
  ;; Raises an error (hash-ref inside trans-prob) if a transition is missing
  ;; from the table altogether.
  (define (find-invalid-mdp-problm mdp [tolerance 1e-9])
    (let ([states (states mdp)]
          [actions (actions mdp)])
      ;; for-each rather than map: both passes run only for their printing.
      (for-each
       (lambda (s1-a-s2-trip)
         (let* ([s1 (first s1-a-s2-trip)]
                [a (second s1-a-s2-trip)]
                [s2 (third s1-a-s2-trip)]
                [prob (trans-prob mdp s1 a s2)])
           (unless (>= prob 0)
             (printf "prob too small: ~a at ~a ~a ~a\n" prob s1 a s2))
           (unless (<= prob 1)
             (printf "prob too big: ~a at ~a ~a ~a\n" prob s1 a s2))))
       (cross states actions states))
      (for-each
       (lambda (state-act-pair)
         (let* ([state (first state-act-pair)]
                [act (second state-act-pair)]
                [sum (summation (lambda (next-state)
                                  (trans-prob mdp state act next-state))
                                states)])
           (unless (<= (abs (- sum 1)) tolerance)
             (printf "prob not adding to 1: ~a at ~a ~a\n" sum state act))))
       (cross states actions))))
  
  ;; valid-nmdp? : mdp -> boolean
  ;; #t when MDP passes valid-mdp? and, in every state, the nothing action 'N
  ;; leads back to that same state with probability 1.
  (define (valid-nmdp? mdp)
    (define (nothing-self-loop? s)
      (= (trans-prob mdp s 'N s) 1))
    (and (valid-mdp? mdp)
         (for/and ([s (in-list (states mdp))])
           (nothing-self-loop? s))))
  
  )