(****************************************************************************)
(*              The Calculus of Inductive Constructions                     *)
(*                                                                          *)
(*                            Projet Coq                                    *)
(*                                                                          *)
(*                  INRIA                        ENS                        *)
(*           Rocquencourt                        Lyon                       *)
(*                                                                          *)
(*                              Coq V5.8                                    *)
(*                            Jan 1st 1993                                  *)
(****************************************************************************)
(*                            optimise.ml                                   *)
(****************************************************************************)

(*           performs partial evaluation on the extracted code              *)

#open "std";;
#open "fmlterm";;
#open "compile";;
#open "fmlenv";;

(* assoc3 name l: looks up the first triple (name, y, z) of l and returns
   the pair (y, z).  Fails with "find" when no triple starts with name. *)
let assoc3 name =
  let rec lookup = function
      [] -> failwith "find"
    | (key, y, z) :: rest ->
        if key = name then (y, z) else lookup rest
  in lookup
;;

(* free_var t: the variables occurring free in the term t.
   Lambda, let, let rec and match-branch binders remove the names they
   bind; every constructor not listed below (e.g. Fmconst) contributes no
   variable. *)
let rec free_var = function
      (Fmvar s) -> [s]
    | (Fmapp(t1,t2)) -> union (free_var t1) (free_var t2)
    | (Fmconstruct (n,s,l)) -> list_it union (map free_var l) []
    | (Fmlambda(s,t)) ->  (subtract (free_var t) [s])
    (* let rec s = t1 in t2: s is in scope in both the definition and the
       body (as in occur, fmsubstvar and strict_vars), so it is not free *)
    | (Fmlocal(s,FmRec t1,t2)) ->
          subtract (union (free_var t1) (free_var t2)) [s]
    (* non-recursive let: s is bound in the body t2 only *)
    | (Fmlocal(s,t1,t2))->union(free_var t1) (subtract (free_var t2)[s])
    | (Fmmatch(t,s,pl)) ->
                      (* each branch binds its pattern variables l *)
                      (list_it union
                           (map (fun (n,l,ti)->
                                      (subtract (free_var ti) l))
                                pl)
                           (free_var t))
    | Fmexcept(t1,t2) -> union (free_var t1) (free_var t2)
    | FmRec(t) -> free_var t
    | _ -> [];;

(* occur sc t: the number of free occurrences of the variable sc in t.
   A binder named sc (lambda, let rec, match pattern, or the body of a
   non-recursive let) hides the occurrences below it; the definition of a
   non-recursive let is still counted because sc is not bound there.
   Constructors not listed contribute 0. *)
let occur sc =
  let rec count term =
    match term with
      Fmvar v -> if v = sc then 1 else 0
    | Fmapp (f, a) -> count f + count a
    | Fmexcept (a, b) -> count a + count b
    | Fmconstruct (_, _, args) ->
        list_it (fun arg acc -> acc + count arg) args 0
    | Fmlambda (v, body) -> if v = sc then 0 else count body
    | Fmlocal (v, FmRec def, body) ->
        if v = sc then 0 else count def + count body
    | Fmlocal (v, def, body) ->
        if v = sc then count def else count def + count body
    | Fmmatch (scrut, _, branches) ->
        list_it
          (fun (_, vars, rhs) acc ->
             if mem sc vars then acc else acc + count rhs)
          branches
          (count scrut)
    | FmRec inner -> count inner
    | _ -> 0
  in count;;


(* new_name l s: the first of s', s'', s''', ... that is not a member of
   the list l (primes are appended until the name is unused). *)
let new_name l =
  let rec prime s =
    let candidate = s ^ "'" in
    if mem candidate l then prime candidate else candidate
  in prime;;

(* fmsubstvar s t u: capture-avoiding substitution of the term t for the
   free occurrences of the variable s in u.
   When a binder of u would capture a free variable of t (a member of V),
   the bound name is first renamed to a fresh primed name (new_name) and
   the renamed term is processed again.  A binder named s stops the
   substitution below it, except in the definition of a non-recursive
   let, where s is not bound. *)
let rec fmsubstvar s t =
 let V = free_var t in
 let rec subs1 = function
   Fmvar(v) as w -> if s=v then t else w
 | Fmconst(_) as c -> c
 | Fmapp(t1,t2) -> Fmapp(subs1 t1,subs1 t2)
 | Fmexcept(t1,t2) -> Fmexcept(subs1 t1,subs1 t2)
 (* rename v when it would capture a free variable of t; the fresh name
    also avoids the free variables of the body *)
 | Fmlambda(v,t1) as t2 ->
    if s = v then t2
             else 
            if (mem v V) then let nv = new_name (union V (free_var t1)) v in
                               subs1(Fmlambda(nv,fmsubstvar v (Fmvar nv) t1))
                         else   Fmlambda(v,subs1 t1)
 | Fmconstruct(n,T,l) -> Fmconstruct(n,T,map subs1 l)
 | Fmext(_) as t1 -> t1
 (* let rec: v is bound in both the definition and the body, so the
    renaming is applied to both *)
 | Fmlocal(v,FmRec(t1),t2) as t3 ->
      if s = v
        then t3
      else if mem v V 
        then let nv = new_name
                        (union V (union (free_var t1)(free_var t2)))
                        v
           in
               subs1(Fmlocal(nv,FmRec(fmsubstvar v (Fmvar nv) t1),
                              fmsubstvar v (Fmvar nv) t2))
        else Fmlocal(v,FmRec(subs1 t1),subs1 t2)
 (* non-recursive let: v is bound only in the body t2, so the definition
    is always substituted and only the body is renamed *)
 | Fmlocal(v,t1,t2) ->
  let t3 = subs1 t1 in
    if s = v
      then Fmlocal(v,t3,t2)
      else if mem v V
             then let nv = new_name (union (free_var t2) V) v in
                    subs1(Fmlocal(nv,t1,fmsubstvar v (Fmvar nv) t2))
             else Fmlocal(v,t3,subs1 t2)
 | FmRec(t1) -> FmRec(subs1 t1)
 (* a branch binding s is left alone; otherwise its pattern variables
    that clash with V are renamed (ren) before substituting *)
 | Fmmatch(t1,T,pl) -> 
      Fmmatch(subs1 t1,T,map
                            (fun (n,l,t2) ->
                                 if mem s l then (n,l,t2)
                                  else
                                   (fun(l0,t2')->(n,l0,subs1 t2'))
                                       (ren V t2 l))
                         pl
              )
 | t -> t
in subs1

(* ren V t l: renames the binders of l that belong to V to fresh names and
   applies the renaming inside t.  Returns the new binder list (same order
   as l) and the renamed term.  A fresh name avoids V, the binders still to
   be processed, the binders already produced, and the free variables of
   the current term. *)
and ren V t l =
  let rec ren1 l t = function
                    [] -> (rev l,t)
                  | a::l1 -> if mem a V
                                 then let nv =
                                       new_name
                                         (union l1
                                         (union V
                                         (union l (free_var t)))) a
                                      in
                                    ren1 (nv::l)
                                        (fmsubstvar a (Fmvar nv) t)
                                          l1
                                  else ren1 (a::l) t l1
in ren1 [] t l;;


(* fmsubstconst s t u: replaces every occurrence of the constant s in u by
   the term t.  No renaming is performed, so t is presumably closed (in
   optimal it is the normalised body of a global definition). *)
let fmsubstconst s t =
  let rec replace = function
      Fmconst c as leaf -> if c = s then t else leaf
    | Fmapp (f, a) -> Fmapp (replace f, replace a)
    | Fmexcept (a, b) -> Fmexcept (replace a, replace b)
    | Fmlambda (v, body) -> Fmlambda (v, replace body)
    | Fmconstruct (n, ty, args) -> Fmconstruct (n, ty, map replace args)
    | Fmlocal (v, def, body) -> Fmlocal (v, replace def, replace body)
    | FmRec body -> FmRec (replace body)
    | Fmmatch (scrut, ty, branches) ->
        Fmmatch (replace scrut, ty,
                 map (fun (n, vars, rhs) -> (n, vars, replace rhs)) branches)
    | other -> other
  in replace;;


(* double_list_it f [a1;...;an] [b1;...;bm] x folds f over the two lists
   in lockstep, starting from the first pair:
   f an bn (... (f a1 b1 x) ...).
   Fails with "double_list" when the second list is shorter than the
   first.
   NOTE(review): when the first list is exhausted the result is I applied
   to the accumulator (I is defined elsewhere, presumably the identity), so
   extra elements of the second list are silently ignored.
   NOTE(review): with f = fmsubstvar this is a sequential, not
   simultaneous, substitution; callers must make sure no b_i mentions a
   later a_j. *)
let double_list_it f = let rec a1 = fun [] -> (fun x -> I)
                                      | (a::l) -> (fun (b::m) x ->
                                            a1 l m (f a b x)
                                                     | _ _ ->
                                            failwith "double_list")
in a1;;



(* fm_beta_norm t: normalises t by beta reduction and let inlining.
   - An application whose head normalises to a lambda is reduced by
     substituting the (not yet normalised) argument, then normalised again.
   - A let rec whose normalised definition no longer mentions its own name
     is turned into a plain let.
   - A plain let whose variable occurs at most once in the normalised body
     (as counted by occur, which also counts occurrences under lambdas) is
     inlined; with no occurrence the binding simply disappears.
   Lambdas, exceptions, constructors and matches are traversed; every other
   term (including a bare FmRec) is returned unchanged. *)
let rec fm_beta_norm = function
    Fmapp(t1,t2) -> let t3 = fm_beta_norm t1 in
          (match t3 with
              Fmlambda(s,t4) -> fm_beta_norm(fmsubstvar s t2 t4)
            | _ -> Fmapp(t3,fm_beta_norm t2))
 | Fmlambda(s,t) -> Fmlambda(s,fm_beta_norm t)
 | Fmexcept(t1,t2) -> Fmexcept(fm_beta_norm t1,fm_beta_norm t2)
 | Fmconstruct(n,T,tl) -> Fmconstruct(n,T,map fm_beta_norm tl)
 | Fmmatch(t1,T,pl) ->Fmmatch(fm_beta_norm t1,T,
                              map (fun(n,sl,t)->(n,sl,fm_beta_norm t))pl)
 (* recursive let: keep it recursive only if the definition still refers
    to its own name after normalisation *)
 | Fmlocal(s,FmRec(t1),t2) -> let t3 = fm_beta_norm t1 in 
        if mem s (free_var t3) 
            then Fmlocal(s,FmRec t3,fm_beta_norm t2)
            else fm_beta_norm(Fmlocal(s,t3,t2))
 (* plain let: inline when the variable is used at most once *)
 | Fmlocal(s,t1,t2) -> let t3 = fm_beta_norm t2 in 
        if occur s t3 < 2
            then fm_beta_norm(fmsubstvar s t1 t3)
            else Fmlocal(s,fm_beta_norm t1,t3)
 | t -> t;;

(* superfl s: true when the entry of s in !Fmindenv is a one-element list
   whose single pair has a one-element second component (presumably a
   single constructor with a single argument, so that fm_norm can erase the
   wrapping).  The lookup (assoc) fails if s is not in !Fmindenv. *)
let superfl s =
  match assoc s !Fmindenv with
    [(_, args)] -> length args = 1
  | _ -> false;;

(* fm_norm t: the main simplifier for extracted terms.  Besides beta
   reduction and let inlining (as in fm_beta_norm) it
   - drops a let rec whose name is unused in the body, and turns it into a
     plain let when its definition does not refer to itself;
   - pushes an application into the branches of a match in head position;
   - erases the constructors and matches of types for which superfl holds;
   - reduces a match on a known constructor (iota reduction);
   - commutes a match whose scrutinee is itself a match.
   Fails with "fm_norm Fmmatch" on a match over a superfl type that does
   not have exactly one single-variable branch, and with "find" (assoc3)
   when a known constructor has no branch in the match. *)
let rec fm_norm = function
   Fmlocal(s,FmRec(t1),t2) -> let t3 = fm_norm t1
                              and t4 = fm_norm t2 in
        if mem s (free_var t4) then 
          (if mem s (free_var t3)
            then Fmlocal(s,FmRec t3,t4)
            else fm_norm(Fmlocal(s,t3,t4)))
         else
           t4
 | Fmlocal(s,t1,t2) -> let t3 = fm_norm t2 in 
        if occur s t3 < 2
            then fm_norm(fmsubstvar s t1 t3)
            else Fmlocal(s,fm_norm t1,t3)
 | Fmapp(t1,t2) -> let t3 = fm_norm t1 in
          (match t3 with
              Fmlambda(s,t4) -> fm_norm(fmsubstvar s t2 t4)
            (* (match t4 with p -> u ...) t2  ~>  match t4 with p -> u t2 ...
               the pattern variables are renamed away from the free
               variables of the argument so they cannot capture them *)
            | Fmmatch(t4,T,pl) -> let t5=fm_norm t2 in
                 let V = free_var t5 in
                   (fm_norm(Fmmatch(fm_norm t4,T,
                               map (fun(n,l,t)->
                                 let (l1,t')=ren V t l in
                                (n,l1,fm_norm(Fmapp(t',t5))))
                             pl)))
            | _ -> Fmapp(t3,fm_norm t2))
 | Fmlambda(s,t) -> Fmlambda(s,fm_norm t)
 | Fmconstruct(n,T,tl) -> 
        if superfl T
            then fm_norm (hd tl)
            else Fmconstruct(n,T,map fm_norm tl)
 | (Fmmatch(t,T,pl) as X) ->
   let VX = free_var X in
             if superfl T
               then (match pl with
                   [(_,[s],t1)]->fm_norm(fmsubstvar s (fm_norm t) t1)
                 | _ -> failwith "fm_norm Fmmatch")
               else (
                let t1 = t in
                let t2 = fm_norm t1 in
      (match t2 with
         (* iota reduction.  double_list_it substitutes the pattern
            variables one at a time, so they are first renamed away from
            the free variables of the constructor arguments; otherwise an
            argument mentioning a later pattern variable would be rewritten
            too (match C(y,x) with C(x,y) -> f x y would give f x x). *)
         Fmconstruct(n,_,tl) ->
           let (sl,t) = assoc3 n pl in
           let (sl',t') = ren (list_it union (map free_var tl) []) t sl in
           fm_norm (double_list_it fmsubstvar sl' tl t')

       (* match (match t with p' -> t0 ...) with pl
            ~>  match t with p' -> (match t0 with pl) ...
          the inner pattern variables are renamed away from the free
          variables of the outer match *)
       | Fmmatch(t,T',pl') ->
            fm_norm
              (Fmmatch(t,T',
                        map
                          (fun (n,sl,t0)->
                               let (sl',t0') = ren VX t0 sl in
                                   (n,sl',Fmmatch(t0',T,pl)))
                        pl'))
       | _ -> Fmmatch(t2,T,map (fun(n,sl,t)->(n,sl,fm_norm t)) pl)))
 | Fmexcept(t1,t2) -> Fmexcept(fm_norm t1,fm_norm t2)
 | FmRec(t) -> FmRec(fm_norm t)
 | t -> t;;

(* hd_var t: true when the head of the application spine of t (t itself
   if t is not an application) is a variable. *)
let rec hd_var term =
  match term with
    Fmvar _ -> true
  | Fmapp (f, _) -> hd_var f
  | _ -> false;;

(* strict_vars t: heuristic approximation of the variables whose value is
   needed to evaluate t (used by is_strict).
   - Lambda bodies are inspected without removing the bound variable, so
     the abstracted parameters of a definition can be reported.
   - A variable is strict in a match when it is strict in the scrutinee or
     in every branch (each branch's pattern variables excluded).
   - Let-bound names are hidden by substituting a constant for them in the
     body before recursing, and are removed from the result.
   - An application of a constant reports the variables of its argument;
     an application spine headed by a variable reports only that variable.
   Constructors not listed contribute nothing. *)
let rec strict_vars = function
   Fmconstruct(_,_,l)->list_it union (map strict_vars l) []
 | Fmlambda(s,t) -> strict_vars t
 | Fmvar(s) -> [s]
 | Fmconst(_) -> []
 | Fmmatch(t,T,pl) ->
     (* intersect the per-branch sets starting from the first one:
        folding intersect from [] would always yield [] *)
     let branch_sets =
       map (fun(_,l,ti)->subtract (strict_vars ti) l) pl in
     union (strict_vars t)
           (match branch_sets with
              [] -> []
            | first::others -> list_it intersect others first)
 | FmRec(t) -> strict_vars t
 | Fmlocal(s,FmRec(t1),t2) -> subtract(union(strict_vars t1)
                                            (strict_vars
                                              (fmsubstvar s (Fmconst s) t2)))
                                      [s]
 | Fmlocal(s,t1,t2) -> union (strict_vars t1)
                             (subtract
                                (strict_vars (fmsubstvar s (Fmconst s) t2))
                                 [s])
 | Fmapp(Fmconst(s),t) ->  (strict_vars t)
 | Fmapp(Fmvar(s),t) -> [s]
 | Fmapp(t1,t2) -> if hd_var t1 then strict_vars t1
                                else union(strict_vars t2)(strict_vars t1)
 | Fmexcept(t1,t2) -> union(strict_vars t1)(strict_vars t2)
 | _ -> [];;

(* fm_size t: a rough size measure of t (used by is_strict).  Applications,
   lambdas, lets and matches count 1 plus the sizes of their subterms;
   constructor applications and FmRec count only their subterms; every
   other term (variables, constants, ...) counts 0. *)
let rec fm_size term =
  let sizes ts = list_it (fun u acc -> fm_size u + acc) ts 0 in
  match term with
    Fmapp (f, a) -> 1 + fm_size f + fm_size a
  | Fmlambda (_, body) -> 1 + fm_size body
  | Fmlocal (_, def, body) -> 1 + fm_size def + fm_size body
  | Fmmatch (scrut, _, branches) ->
      1 + fm_size scrut + sizes (map (fun (_, _, rhs) -> rhs) branches)
  | Fmconstruct (_, _, args) -> sizes args
  | FmRec body -> fm_size body
  | _ -> 0;;

(* is_strict t: whether the definition t is considered strict; optimal keeps
   strict definitions out-of-line instead of inlining them.
   Since & binds tighter than or, this reads
     (fm_size t > 4 & every variable of fml_abs_var t is in strict_vars t)
     or t is an FmRec.
   fml_abs_var is defined elsewhere; presumably it lists the variables
   abstracted by the leading lambdas of t - TODO confirm. *)
let is_strict t =(fm_size t>4)&
                 ([] = subtract (fml_abs_var t)(strict_vars t))
                    or (match t with FmRec(_)->true
                                         | _ -> false);;


(* optimal (): normalises every definition of !Fmenv with fm_norm and
   returns the resulting environment, in the same order as !Fmenv.
   The definitions are processed in the order of rev !Fmenv; a definition
   may be substituted (fmsubstconst) into all those that follow it there.
   A definition s is NOT substituted when its normalised body t0 is not
   Fmerror and either it is the last one, or s is not
   "well_founded_recursion", is_constr holds for neither its original body
   (note: the unnormalised t is tested here) and is_strict t0 holds.
   Definitions whose normalised body is Fmerror are substituted into the
   following ones and then dropped; every other definition is kept with
   its normalised body, even when it was also inlined.
   is_constr, Fmenv and Fmerror are defined elsewhere. *)
let optimal () =
  let rec op = function
    [] -> []
 |  (s,t)::l -> let t0 = (fm_norm t) in
                  let l' = 
                     if ((((s<>"well_founded_recursion")
                              &(not (is_constr t))
                              &(is_strict t0))
                       or l=[])&(not (t0 = Fmerror)))
                    then (op l)
                    else (op (map(fun(s1,t)-> (s1,(fmsubstconst s
                                                               t0
                                                               t)))
                               l))
               in if t0 = Fmerror then l'
                                  else (s,t0)::l'
    in rev(op (rev !Fmenv));;
(* Register optimal in the mutable field v of compile__optimal.
   NOTE(review): presumably a hook that lets the compile module call this
   optimiser without depending on this file - confirm. *)
compile__optimal.v <- optimal;;
