#lang racket

(require "mdp.rkt" "opt.rkt" "audit.rkt")

;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;; Tests

;; Tolerance handed to every audit call in this file (audit-policy and
;; auditMDPapprox-valueIter).  Kept exact: this is the rational 1/10000.
(define epsilon (/ 1 10000))

;; test1: builds a two-state, two-action MDP with make-mdp (last argument
;; 0.5, presumably the discount factor) and prints: whether it is valid,
;; the result of value-iteration started from an all-zero value table, and
;; the result of solve-mdp.  The label records the policy the author
;; expects to see.
(define (test1)
  (define transitions '([q0 act0 q0 1.0]
                        [q0 act0 q1 0.0]
                        [q0 act1 q0 0.5]
                        [q0 act1 q1 0.5]
                        [q1 act0 q0 0.0]
                        [q1 act0 q1 1.0]
                        [q1 act1 q0 0.5]
                        [q1 act1 q1 0.5]))
  (define rewards '([q0 act0 1.0]
                    [q0 act1 0.0]
                    [q1 act0 5.0]
                    [q1 act1 0.0]))
  (define mdp2 (make-mdp '(q0 q1) '(act0 act1) transitions rewards 0.5))
  (printf "test1.1: valid: ~a; q0 . act1, q1 . act0 = ~a ; ~a\n"
          (valid-mdp? mdp2)
          (value-iteration mdp2 (make-hash '((q0 . 0) (q1 . 0))))
          (solve-mdp mdp2)))

;; test2: single-state NMDP with two actions.  The reward entries (assumed
;; to be [state action reward]) give a1 -> 1 and a2 -> -1, and the label
;; expects solve-mdp to pick a1.  Then audits one-step logs against the
;; 'exclusivity property.
;; Entries inside the quoted reward list must not be quoted again: a nested
;; ' would read as (quote (q0 a2 -1)) and corrupt the a2 entry.
(define (test2)
  (let ([nmdp1 (make-nmdp '(q0) '(a1 a2) '() '([q0 a1 1] [q0 a2 -1]) 0.5)])
    (printf "test2.1: valid: ~a; q0 . a1 = ~a\n"
            (valid-nmdp? nmdp1)
            (solve-mdp nmdp1))
    (printf "audit2.1: no violation = ~a\n" (audit-policy epsilon 'exclusivity nmdp1 '([q0 a1])))
    (printf "audit2.2: violation = ~a\n" (audit-policy epsilon 'exclusivity nmdp1 '([q0 a2])))
    ;; NOTE(review): audit2.3 repeats audit2.2 exactly; confirm whether a
    ;; different log was intended.
    (printf "audit2.3: violation = ~a\n" (audit-policy epsilon 'exclusivity nmdp1 '([q0 a2])))))

;; test3: single-state NMDP like test2, but a2's reward entry is 0 instead
;; of -1 (entries assumed to be [state action reward]).  The label expects
;; solve-mdp to pick a1.  Then audits logs against the 'exclusivity property.
;; Entries inside the quoted reward list must not be quoted again: a nested
;; ' would read as (quote (q0 a2 0)) and corrupt the a2 entry.
(define (test3)
  (let ([nmdp1 (make-nmdp '(q0) '(a1 a2) '() '([q0 a1 1] [q0 a2 0]) 0.5)])
    (printf "test3.1: valid: ~a; q0 . a1 = ~a\n"
            (valid-nmdp? nmdp1)
            (solve-mdp nmdp1))
    (printf "audit3.1: n v = ~a\n" (audit-policy epsilon 'exclusivity nmdp1 '([q0 a1])))
    (printf "audit3.2: v = ~a\n" (audit-policy epsilon 'exclusivity nmdp1 '([q0 a2])))
    ;; NOTE(review): audit3.3 repeats audit3.2 exactly; confirm intent.
    (printf "audit3.3: v = ~a\n" (audit-policy epsilon 'exclusivity nmdp1 '([q0 a2])))
    ;; Open question: this log takes both a2 and a1 in q0, which no single
    ;; static (stationary) strategy can produce, so the expected outcome is unclear.
    (printf "audit3.4: v = ~a\n" (audit-policy epsilon 'exclusivity nmdp1 '([q0 a2] [q0 a1])))))


;; test4: two-state NMDP where q1 --a2--> q2 and q2 --a1--> q1, each with
;; probability 1.  The reward entries (assumed [state action reward]) are
;; 3 for [q1 a2] and -0.5 for [q2 a1].  The "n v" audits below expect
;; a2 in q1 and a1 in q2 to be non-violations, which the label mirrors.
;; Entries inside the quoted reward list must not be quoted again: a nested
;; ' would read as (quote (q2 a1 -0.5)) and corrupt that entry.
(define (test4)
  (let ([nmdp1 (make-nmdp '(q1 q2) '(a1 a2) '([q1 a2 q2 1] [q2 a1 q1 1]) '([q1 a2 3] [q2 a1 -0.5]) 0.5)])
    (printf "test4.1: valid: ~a; q1 . a2 ; q2 . a1 = ~a\n"
            (valid-nmdp? nmdp1)
            (solve-mdp nmdp1))
    (printf "audit4.1: v = ~a\n" (audit-policy epsilon 'exclusivity nmdp1 '([q1 a1])))
    (printf "audit4.2: n v = ~a\n" (audit-policy epsilon 'exclusivity nmdp1 '([q1 a2])))
    (printf "audit4.3: v = ~a\n" (audit-policy epsilon 'exclusivity nmdp1 '([q2 a2])))
    (printf "audit4.4: n v = ~a\n" (audit-policy epsilon 'exclusivity nmdp1 '([q2 a1])))
    ;; N appears to denote "no action"; assumption based on its use here and in test5.
    (printf "audit4.5: v = ~a\n" (audit-policy epsilon 'exclusivity nmdp1 '([q2 N])))))

;; test5: three-state NMDP in which a2 from q1 splits evenly between q2 and
;; q3, and a1 from q2 returns to q1.  Prints validity and the solve-mdp
;; result (the label records the expected policy), then prints one
;; 'exclusivity audit per single-step log.  "v" / "n v" in each label
;; record whether the author expects a violation.
(define (test5)
  (define nmdp1
    (make-nmdp '(q1 q2 q3)
               '(a1 a2)
               '([q1 a2 q2 0.5]
                 [q1 a2 q3 0.5]
                 [q2 a1 q1 1])
               '([q1 a2 3]
                 [q2 a1 -0.5])
               0.5))
  ;; Print one audit result using `fmt` as the printf format string.
  (define (report fmt log-entries)
    (printf fmt (audit-policy epsilon 'exclusivity nmdp1 log-entries)))
  (printf "test5.1: valid: ~a; q1 . a2 ; q2 . a1 ; q3 . N = ~a\n"
          (valid-nmdp? nmdp1)
          (solve-mdp nmdp1))
  (report "audit5.1: v = ~a\n" '([q1 a1]))
  (report "audit5.2: n v = ~a\n" '([q1 a2]))
  (report "audit5.3: v = ~a\n" '([q2 a2]))
  (report "audit5.4: n v = ~a\n" '([q2 a1]))
  (report "audit5.5: v = ~a\n" '([q2 N]))
  (report "audit5.6: n v = ~a\n" '([q3 N]))
  (report "audit5.7: v = ~a\n" '([q3 a1])))


;; Run a table of audit cases against one NMDP.
;;   tests-name : label printed in the "Tests ~a" header line
;;   nmdp       : the NMDP under test; checked with valid-nmdp? before any case runs
;;   tests      : list of rows (name log expected), where log is a list of
;;                [state action] entries and expected is the value that
;;                auditMDPapprox-valueIter should return for that log (the case
;;                tables in this file use #t to mark an exclusivity violation)
;; Prints one line for every failing row, followed by a pass count.  If the
;; NMDP is invalid, prints "Invalid NMDP" instead and runs no cases.  Returns
;; void, so nothing extra is echoed when called from module top level.
(define (test tests-name nmdp tests)
  (printf "Tests ~a\n" tests-name)
  (if (valid-nmdp? nmdp)
      (let* ([failures
              ;; Each row contributes 1 on failure, 0 on success.
              (for/sum ([test-row (in-list tests)])
                ;; `log-entries`, not `log`, to avoid shadowing racket's log.
                (match-let* ([(list* name log-entries expected-result _) test-row]
                             [actual-result
                              (auditMDPapprox-valueIter epsilon nmdp log-entries)])
                  (cond
                    [(equal? expected-result actual-result) 0]
                    [else
                     (printf "Failed test ~a on log ~a.  Expected ~a; got ~a.\n"
                             name log-entries expected-result actual-result)
                     1])))]
             [total (length tests)])
        (printf "~a of ~a passed\n" (- total failures) total))
      (printf "Invalid NMDP\n")))


;; Suite A: single-state NMDP whose reward entries (assumed
;; [state action reward]) give a1 -> 1 and a2 -> -1.  In the expected
;; column, #t marks an expected exclusivity violation.
;; Entries inside the quoted reward list must not be quoted again: a nested
;; ' would read as (quote (q0 a2 -1)) and corrupt the a2 entry.
(define (testA)
  (test "A"
        (make-nmdp '(q0) '(a1 a2) '() '([q0 a1 1] [q0 a2 -1]) 0.5)
        '(["1" ([q0 a1]) #f]
          ["2" ([q0 a2]) #t]
          ;; NOTE(review): identical to case "2"; confirm whether a different log was intended.
          ["3" ([q0 a2]) #t])))

;; Suite B: two-state NMDP with exact-rational probabilities.  a2 in q1
;; stays in q1 or moves to q2 with equal chance; a1 in q2 returns to q1.
;; Each case is (name log expected); #t means an exclusivity violation is expected.
(define (testB)
  (define nmdp
    (make-nmdp '(q1 q2)
               '(a1 a2)
               '([q1 a2 q2 1/2]
                 [q1 a2 q1 1/2]
                 [q2 a1 q1 1])
               '([q1 a2 3]
                 [q2 a1 -1/2])
               1/2))
  (define cases
    '(["1" ([q1 a1]) #t]
      ["2" ([q1 a2]) #f]
      ["3" ([q2 a2]) #t]
      ["4" ([q2 a1]) #f] ; a1 in q2 is optimal, so no violation is expected
      ["5" ([q2 N]) #t]))
  (test "B" nmdp cases))

;; Suite C: three-state NMDP.  a2 in q1 splits evenly between q2 and q3,
;; and a1 in q2 returns to q1.  Each case is (name log expected); #t means
;; an exclusivity violation is expected.  Case "7" is a known gap: the
;; audit cannot detect redundant actions, so that case is expected to fail.
(define (testC)
  (let* ([nmdp (make-nmdp '(q1 q2 q3)
                          '(a1 a2)
                          '([q1 a2 q2 0.5]
                            [q1 a2 q3 0.5]
                            [q2 a1 q1 1])
                          '([q1 a2 3]
                            [q2 a1 -0.5])
                          0.5)]
         [cases '(["1" ([q1 a1]) #t]
                  ["2" ([q1 a2]) #f]
                  ["3" ([q2 a2]) #t]
                  ["4" ([q2 a1]) #f]
                  ["5" ([q2 N]) #t]
                  ["6" ([q3 N]) #f]
                  ["7" ([q3 a1]) #t] ; redundant action: not caught by the implementation
                  ["8" ([q1 a2] [q2 a2]) #t])])
    (test "C" nmdp cases)))

;; Suite D: three-state NMDP in which both actions are available in q1
;; (a2 can self-loop or move to q2/q3; a1 moves to q1 or q3), and a1 in q2
;; returns to q1.  Each case is (name log expected); #t means an exclusivity
;; violation is expected.  Cases "7" and "11" involve redundant actions,
;; which the implementation cannot detect, so they are expected to fail.
(define (testD)
  (define transitions
    '([q1 a2 q1 0.4]
      [q1 a2 q2 0.3]
      [q1 a2 q3 0.3]
      [q1 a1 q1 0.5]
      [q1 a1 q3 0.5]
      [q2 a1 q1 1]))
  (define rewards
    '([q1 a2 3]
      [q2 a1 -0.5]))
  (define cases
    '(["1" ([q1 a1]) #t]
      ["2" ([q1 a2]) #f]
      ["3" ([q2 a2]) #t]
      ["4" ([q2 a1]) #f]
      ["5" ([q2 N]) #t]
      ["6" ([q3 N]) #f]
      ["7" ([q3 a1]) #t]                   ; redundant action: not caught
      ["8" ([q1 a2] [q2 a2]) #t]
      ["9" ([q1 a1] [q3 N]) #t]
      ["10" ([q1 a2] [q1 a2] [q3 N]) #f]
      ["11" ([q1 a2] [q1 a2] [q3 a1]) #t]  ; redundant action: not caught
      ["12" ([q1 a2] [q1 a2] [q2 a1]) #f]
      ["13" ([q1 a2] [q1 a2] [q2 a2]) #t]))
  (test "D"
        (make-nmdp '(q1 q2 q3) '(a1 a2) transitions rewards 0.5)
        cases))

;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;; Pick which test suites to run.
;; Each suite reports its failing cases (or an invalid NMDP) on standard output.

;; Run the table-driven suites in order; each returns void, so nothing
;; extra is echoed by the module.
(for ([run-suite (in-list (list testA testB testC testD))])
  (run-suite))
