;; Utilities for building multi-step NMDP models.
(module model racket
  (provide (all-defined-out))
  
  (require (prefix-in mdp: "../audit/mdp.rkt"))
    
  
  ;; A model has two parameters:
  ;;   m: number of steps until an event is forgotten (the history length)
  ;;   n: number of patients in the RHIO database
  ;; States encode a history. They have the form
  ;;   q.E.E.E ... .E:P
  ;; where the number of Es is equal to m.
  ;; E ranges over [1,n], s, and t:
  ;;   a number i in [1,n] denotes reading the record of patient i,
  ;;   s denotes studying,
  ;;   t denotes treating a patient.
  ;; P ranges over i in [1,n], o, and n:
  ;;   a number i in [1,n] means that patient i is seeking treatment,
  ;;   o means that someone not in the database is seeking treatment,
  ;;   n means that no one is seeking treatment.
  ;; A state q.E1.E2 ... .Em:P transitions as follows:
  ;;  -- read i --> q.E2 ... .Em.i:P'
  ;;  -- study  --> q.E2 ... .Em.s:P'
  ;;  -- treat  --> q.E2 ... .Em.t:P'
  ;; where P' is a random draw over its range.
  ;; Only the treat action produces a non-zero reward, and only when P is not n.
  ;; Let S be the number of s's in E1 ... Em and base the per-study multiplier.
  ;; The reward for P=o is r_o + base*S.
  ;; For P=i, it is r_i + base*S + bonus_i if i is in E1 ... Em, and r_i + base*S otherwise.
  
  ;; Lists every event E for n patients: the patient numbers n, n-1, ..., 1
  ;; (reading that patient's record), followed by s (study) and t (treat).
  ;; For n <= 0 only (s t) is returned.
  (define (make-events n)
    (append (range n 0 -1) (list 's 't)))
  
  ;; Lists every condition P for n patients: the patient numbers n, n-1, ..., 1
  ;; (that patient is seeking treatment), followed by o (someone outside the
  ;; database) and n (no one). For n <= 0 only (o n) is returned.
  (define (make-conditions n)
    (define patients
      (for/list ([i (in-range n 0 -1)])
        i))
    (append patients '(o n)))
  
  
  ;; Enumerates every history of length m over the given events; a history is
  ;; a list of m events. Starting from the single empty history, each round
  ;; extends every existing history by prepending each event, repeated m times.
  ;; For m <= 0 the result is '(()).
  (define (make-hists events m)
    (for/fold ([hists '(())])
              ([_ (in-range m)])
      ;; Consing onto the accumulator yields the histories in reverse
      ;; (history, event) order.
      (for*/fold ([grown '()])
                 ([h (in-list hists)]
                  [e (in-list events)])
        (cons (cons e h) grown))))
  
  ;; Builds the list form of every state for history length m and n patients.
  ;; A list state is (P Em ... E1): the condition first, then the history with
  ;; the newest event first (state-update pushes new events right after P).
  ;; list-state->state-symbol later prints it as q.E1 ... .Em:P.
  ;; There is one state per pairing of a history with a condition.
  (define (make-states m n)
    (define conditions (make-conditions n))
    (define histories (make-hists (make-events n) m))
    (for*/fold ([states '()])
               ([hist (in-list histories)]
                [condition (in-list conditions)])
      (cons (cons condition hist) states)))
  
  ;(make-states 3 2)
  
  ;; Advances a list state by one step.
  ;; A state has the form (P Em Em-1 ... E2 E1): newest event first, E1 the oldest.
  ;; The result is (P' E' Em Em-1 ... E2), where P' = outcome and E' = act.
  ;; The oldest event E1 is forgotten, so the history length m is preserved.
  ;; When m = 0 the state is just (P) and there is no history to record,
  ;; so the result is (P').
  (define (state-update state act outcome)
    (let ([history (rest state)])
      (if (null? history)
          (list outcome)
          (cons outcome
                (cons act (drop-right history 1))))))
  
  ;(state-update (list 'o 4 3 2 1) 't 'n)
  
  ;; Looks up the probability that the next condition is c.
  ;; probs has the form (p_n p_o p_1 p_2 ...):
  ;;   index 0 holds n (no one seeking treatment),
  ;;   index 1 holds o (someone not in the database),
  ;;   patient i sits at index i+1.
  (define (condition-probablity probs c)
    (list-ref probs
              (case c
                [(n) 0]
                [(o) 1]
                [else (+ c 1)])))
  
  ;; A state q.E1.E2 ... .Em:P transitions as follows:
  ;;  -- read i --> q.E2 ... .Em.i:P'
  ;;  -- study  --> q.E2 ... .Em.s:P'
  ;;  -- treat  --> q.E2 ... .Em.t:P'
  ;; where P' is a random draw over its range.
  ;; Returns one entry (from-state action to-state probability) for every
  ;; state, every event (used as the action), and every next condition P'.
  ;; probs gives the probability of each next condition and has the form
  ;; (p_n p_o p_1 p_2 p_3 ...). The probability of P' does not depend on the
  ;; current state or on the action taken.
  (define (make-trans m n probs)
    ;; These depend only on n and probs, so compute them once instead of
    ;; once per state and once per state/action inside the nested loops.
    (define events (make-events n))
    (define conditions (make-conditions n))
    (define outcome-probs
      (map (lambda (c) (condition-probablity probs c)) conditions))
    ;; Entries are consed onto the accumulator in state/action/outcome order,
    ;; so the result lists them in reverse of that order.
    (for/fold ([acc '()])
              ([state (in-list (make-states m n))])
      (let ([from (list-state->state-symbol state)])
        (for/fold ([acc acc])
                  ([act (in-list events)])
          (let ([action (event->action-symbol act)])
            (for/fold ([acc acc])
                      ([outcome (in-list conditions)]
                       [p (in-list outcome-probs)])
              (cons (list from
                          action
                          (list-state->state-symbol (state-update state act outcome))
                          p)
                    acc)))))))
  
  ;(make-trans 2 3 '(0.1 0.2 0.3 0.4 0.5))
  
  
  ;; Builds the reward table: one entry (state-symbol a.t reward) per state.
  ;; Only the treat action (a.t) is listed.
  ;;   normal-rewards:       (r_o r_1 ... r_n)
  ;;   read-reward-bounces:  (r'_1 ... r'_n)
  ;;   studied-reward-multi: reward added per s event in the history
  ;; Let S be the number of s events in the history, and R = 1 if patient c
  ;; appears in the history (their record was read), 0 otherwise:
  ;;   P = n:          reward = 0
  ;;   P = o:          reward = r_o + S*studied-reward-multi
  ;;   P = c in [1,n]: reward = r_c + S*studied-reward-multi + R*r'_c
  ;; Raises an error for any other condition.
  (define (make-rewards m n normal-rewards read-reward-bounces studied-reward-multi)
    (define treat (event->action-symbol 't))
    (foldl (lambda (state rst)
             (let* ([condition (first state)]
                    [history (rest state)]
                    [studied (count (lambda (e) (equal? e 's)) history)])
               (cons (list (list-state->state-symbol state)
                           treat
                           (cond [(equal? condition 'n)
                                  0]
                                 [(equal? condition 'o)
                                  (+ (list-ref normal-rewards 0)
                                     (* studied-reward-multi studied))]
                                 ;; The real? guard keeps the "unknown condition"
                                 ;; branch reachable. Without it, < raises its own
                                 ;; contract error on a non-numeric condition.
                                 [(and (real? condition) (< 0 condition) (<= condition n))
                                  (+ (list-ref normal-rewards condition)
                                     (* studied-reward-multi studied)
                                     (if (member condition history)
                                         (list-ref read-reward-bounces (sub1 condition))
                                         0))]
                                 [else (error "unknown condition" condition)]))
                     rst)))
           '()
           (make-states m n)))
  
  ;(make-rewards 2 2 '(100 200 300) '(40 50) 6)
  
  ;; Converts a list state (P Em ... E1) into its symbolic name q.E1.E2 ... .Em:P.
  ;; The history is stored newest-first, so it is reversed before printing.
  ;; A state with no history becomes q:P.
  (define (list-state->state-symbol list-state)
    (let* ([condition-part (format ":~a" (first list-state))]
           [event-parts (map (lambda (e) (format ".~a" e))
                             (reverse (rest list-state)))])
      (string->symbol
       (apply string-append "q" (append event-parts (list condition-part))))))
  
  ;; Names the action for event e as the symbol a.<e>.
  ;; Examples: a.t for treat, a.s for study, a.3 for reading patient 3's record.
  (define (event->action-symbol e)
    (let ([name (format "a.~a" e)])
      (string->symbol name)))
  
  
  ;; Builds the multi-step NMDP and checks it with mdp:valid-nmdp?.
  ;;   m: number of history steps; states look like q.E.E.E ... .E:P with m Es
  ;;   n: number of patients
  ;;   probs: (p_none p_other p_1 p_2 p_3 ... p_n), probability of each next condition
  ;;   normal-rewards: (r_other r_1 ... r_n)
  ;;   read-rewards: (r'_1 ... r'_n)
  ;;   studied-reward-multi: reward per study event in the history
  ;;   gamma: passed unchanged to mdp:make-nmdp (presumably the discount
  ;;          factor -- confirm in ../audit/mdp.rkt)
  ;; The treat reward is r_c + R*r'_c + S*studied-reward-multi, where S is the
  ;; number of times studied and R is 1 if c's record was read.
  ;; Raises an error if a parameter list is too short for n patients, or if
  ;; the constructed model fails mdp:valid-nmdp?.
  (define (make-multi-step-nmdp m n probs normal-rewards read-rewards studied-reward-multi gamma)
    ;; Fail early with a descriptive message instead of an anonymous
    ;; list-ref index error deep inside make-trans / make-rewards.
    (define (check-min-length what lst expected)
      (unless (and (list? lst) (>= (length lst) expected))
        (raise-arguments-error 'make-multi-step-nmdp
                               (format "~a has too few entries for n = ~a" what n)
                               "expected at least" expected
                               "given" lst)))
    (let ([k (max n 0)])
      (check-min-length "probs" probs (+ k 2))
      (check-min-length "normal-rewards" normal-rewards (+ k 1))
      (check-min-length "read-rewards" read-rewards k))
    (let ([model (mdp:make-nmdp (map list-state->state-symbol (make-states m n))
                                (map event->action-symbol (make-events n))
                                (make-trans m n probs)
                                (make-rewards m n normal-rewards read-rewards studied-reward-multi)
                                gamma)])
      ;; Debugging assertion: reject any model the mdp module considers invalid.
      (if (mdp:valid-nmdp? model)
          model
          (error "Invalid NMDP: " model))))
  
  )
