#open "misc";;
#open "constants";;
#open "globals";;
#open "syntax";;
#open "locations";;
#open "k2";;

(*  match.ml : expansion du filtrage dans k2

    Bibliographie : S. L. Peyton Jones, The Implementation of Functional
    Programming Languages (1987), chap. 5 (P. Wadler, Efficient
    compilation of pattern-matching).    *)

(* Representation d'un pattern matching : une disjonction de conjonctions.

      pat & pat & ... & pat  ->  action
    | pat & pat & ... & pat  ->  action
    | ...
    | pat & pat & ... & pat  ->  action

      exp , exp , ... , exp 

  Un motif "pat" s'applique a (i.e. doit filtrer) l'expression qui est
  en-dessous de lui. Ces expressions sont des chemins d'acces a la valeur
  filtree : au depart une variable (Kvar), puis des acces a ses champs
  (Kprim avec Pfield ou Pshift_tag).
*)

(* A pattern-matching problem: a list of rows, each a pattern list with
   its action, together with the list of access expressions exp_1 ... exp_n
   that the pattern columns are matched against, one expression per column. *)
type pattern_matching = 
  Matching of (pattern list * k2exp) list (* rows: patterns and action *)
              * k2exp list (* exp_1 ... exp_n, one per pattern column *)
;;

(* Quelques manipulations triviales de matchings *)

(* [make_path_without_tag n pathl] replaces the first access expression
   of [pathl] by the [n] expressions Kprim(Pfield i, [path]) for
   i = 0 .. n-1, in that order, followed by the remaining expressions of
   [pathl]. Used for tuples and records, whose fields start at offset 0.
   Fails with fatal_error if [pathl] is empty, instead of a bare
   Match_failure. *)
let make_path_without_tag n = function
    path :: pathl ->
      let rec make i =
        if i >= n then pathl else Kprim(Pfield i, [path]) :: make (i+1)
      in
        make 0
  | [] ->
      fatal_error "make_path_without_tag"
;;

(*let make_path_with_tag n (path::pathl) =
  let rec make i =
    if i >= n then pathl else Kprim(Pfield (i+1), [path]) :: make (i+1)
  in
    make 0
;;*)

(* [add_to_match m cas] prepends the row [cas] (a pattern list and its
   action) to the rows of [m]; the access expressions are unchanged. *)
let add_to_match (Matching(casel,pathl)) cas =
  Matching(cas :: casel, pathl)

(* [make_constant_match pathl cas]: one-row matching used once a constant
   has been recognized at the first access expression. That expression is
   consumed, since a constant has nothing left to match below it. *)
and make_constant_match = fun
    (path :: pathl) cas -> Matching([cas], pathl)
  | _ _ -> fatal_error "make_constant_match"

(* [make_tuple_match arity pathl]: empty matching in which the first access
   expression is replaced by the [arity] tuple-component expressions. *)
and make_tuple_match arity pathl =
  Matching([], make_path_without_tag arity pathl)

(* [make_construct_match cstr pathl cas]: one-row matching used once the
   constructor [cstr] has been recognized at the first access expression.
   That expression is replaced by what the argument pattern of [cstr]
   must be matched against, which depends on the representation kind of
   [cstr] (its cs_kind).
   Fails with fatal_error if [pathl] is empty. *)
and make_construct_match cstr pathl0 cas =
  match pathl0 with
    [] ->
      fatal_error "make_construct_match"
  | path :: pathl ->
      match cstr.info.cs_kind with
        Constr_constant _ ->
          (* divide_construct_matching pushes no argument pattern for a
             constant constructor, so the path is dropped to keep the
             access expressions aligned with the pattern columns. *)
          Matching([cas], pathl)
      | Constr_superfluous ->
          (* The argument is matched against the same expression as the
             whole value, which is kept in place. *)
          Matching([cas], pathl0)
      | Constr_tagless(_,arity) ->
          (* One argument: field 0. Several arguments: the argument
             pattern (presumably a tuple) is matched against the value
             itself, whose fields 0 .. arity-1 are then read. *)
          if arity = 1 then Matching([cas], Kprim(Pfield 0,[path]) :: pathl)
          else Matching([cas], pathl0)
      | Constr_regular(_,_,arity) ->
          (* One argument: field 1 (presumably field 0 holds the tag).
             Several arguments: read through Pshift_tag, which presumably
             exposes the fields after the tag at offsets 0 .. arity-1;
             TODO confirm against the k2 backend. *)
          if arity = 1 then Matching([cas], Kprim(Pfield 1,[path]) :: pathl)
          else Matching([cas],Kprim(Pshift_tag,[path]) :: pathl)
;;

(* Auxiliaries for factoring common tests *)

(* [add_to_division make_match divlist kinds cas] files the row [cas]
   under the key [kinds] in the division list [divlist].
   If a division with that key exists (looked up with assoc), [cas] is
   prepended to its matching in place and [divlist] is returned as is.
   Otherwise a new division (kinds, ref (make_match cas)) is put in front. *)
let add_to_division make_match divlist kinds cas =
  try
    let division = assoc kinds divlist in
    division := add_to_match !division cas;
    divlist
  with Not_found ->
    let fresh = ref (make_match cas) in
    (kinds, fresh) :: divlist
;;

(* To skip type constraints and aliases, and flatten "or" patterns. *)

(* [simpl_casel cases] normalizes the first pattern of the first row.
   Aliases and type constraints are stripped. An or-pattern is split into
   two consecutive rows that share the same action. This repeats until
   the first row starts with some other pattern; later rows are left
   untouched. *)
let rec simpl_casel cases =
  match cases with
    (Pat(Zorpat(left, right),_) :: others, act) :: tail ->
      simpl_casel ((left :: others, act) :: (right :: others, act) :: tail)
  | (Pat(Zaliaspat(inner,_),_) :: others, act) :: tail ->
      simpl_casel ((inner :: others, act) :: tail)
  | (Pat(Zconstraintpat(inner,_),_) :: others, act) :: tail ->
      simpl_casel ((inner :: others, act) :: tail)
  | _ ->
      cases
;;

(* Factoring pattern-matchings. *)

(* Factorisation d'un matching dans le cas ou il "commence" (en haut a gauche)
   par un pattern constant *)

(* [divide_constant_matching m] splits off the leading rows of [m] whose
   first pattern is a constant or a constant range (Zdotpat). It returns
   (divisions, others):
   - divisions associates each constant list with the matching of its
     rows, with the first access expression consumed;
   - others is the matching of the remaining rows, with the access
     expressions unchanged. *)
let divide_constant_matching (Matching(casel, pathl)) =
  let file_under keys patl action (constants, others) =
    add_to_division
      (make_constant_match pathl) constants keys (patl, action),
    others in
  let rec divide_rec cases =
    match simpl_casel cases with
      (Pat(Zconstantpat(cst),_) :: patl, action) :: rest ->
        file_under [cst] patl action (divide_rec rest)
    | (Pat(Zdotpat(cstl),_) :: patl, action) :: rest ->
        file_under cstl patl action (divide_rec rest)
    | remaining ->
        [], Matching(remaining, pathl)
  in
    divide_rec casel
;;

(* Factorisation d'un matching dans le cas ou il "commence" (en haut a gauche)
   par un tuple *)

(* [divide_tuple_matching arity m]: the first column of [m] holds tuples
   of [arity] components, variables or wildcards. Each tuple pattern is
   replaced by its component patterns. A variable or wildcard is replaced
   by [arity] wildcards. The first access expression is replaced by the
   component expressions. Any other first pattern is a fatal error. *)
let divide_tuple_matching arity (Matching(casel, pathl)) =
  let rec wildcards k =
    if k <= 0 then [] else Pat(Zwildpat, no_location) :: wildcards (k-1) in
  let rec divide_rec cases =
    match simpl_casel cases with
      (Pat(Ztuplepat(args), _) :: patl, action) :: rest ->
        add_to_match (divide_rec rest) (args @ patl, action)
    | (Pat((Zwildpat | Zvarpat _), _) :: patl, action) :: rest ->
        add_to_match (divide_rec rest) (wildcards arity @ patl, action)
    | [] ->
        make_tuple_match arity pathl
    | _ ->
        fatal_error "divide_tuple_matching"
  in
    divide_rec casel
;;

(* Factorisation d'un matching dans le cas ou il "commence" (en haut a gauche)
   par un pattern construit *)

(* [divide_construct_matching m] splits off the leading rows of [m] whose
   first pattern is a constructor pattern. It returns (divisions, others):
   - divisions associates each constructor kind with the matching of its
     rows. The argument pattern, if any, is pushed in front of the row.
     The access expressions are adjusted by make_construct_match.
   - others is the matching of the remaining rows.
   A Zconstruct1pat whose constructor is a Constr_constant contributes
   no argument pattern. *)
let divide_construct_matching (Matching(casel, pathl)) =
  let file_under c patl action (constrs, others) =
    add_to_division
      (make_construct_match c pathl) constrs [c.info.cs_kind]
      (patl, action),
    others in
  let rec divide_rec cases =
    match simpl_casel cases with
      (Pat(Zconstruct0pat(c), _) :: patl, action) :: rest ->
        file_under c patl action (divide_rec rest)
    | (Pat(Zconstruct1pat(c,arg),_) :: patl, action) :: rest ->
        let patl' =
          match c.info.cs_kind with
            Constr_constant _ -> patl
          | _ -> arg :: patl in
        file_under c patl' action (divide_rec rest)
    | remaining ->
        [], Matching(remaining, pathl)
  in
    divide_rec casel
;;

(* Factorisation d'un matching dans le cas ou il "commence" (en haut a gauche)
   par une variable ou par _                                                *)

(* [divide_var_matching m] splits off the leading rows of [m] whose first
   pattern is a variable or a wildcard. It returns (vars, others):
   - vars: those rows without their first pattern, with the first access
     expression dropped;
   - others: the remaining rows with all access expressions. *)
let divide_var_matching (Matching(casel, (_ :: endpathl as pathl))) =
  let rec split cases =
    match simpl_casel cases with
      (Pat((Zwildpat | Zvarpat _),_) :: patl, action) :: rest ->
        let (vars, others) = split rest in
        (add_to_match vars (patl, action), others)
    | remaining ->
        (Matching([], endpathl), Matching(remaining, pathl))
  in
    split casel
;;

(* Factorisation d'un matching dans le cas ou il "commence" (en haut a gauche)
   par un record pattern : pas de tag *)

(* [divide_record_matching m]: the first column of [m] holds record
   patterns, variables or wildcards. Let max_pos be the largest label
   position mentioned by a record pattern in that column.
   Each row's first pattern is replaced by one pattern per field position
   0 .. max_pos. Fields the row does not mention become wildcards; a
   variable or wildcard row gets wildcards only.
   The first access expression is replaced by the field accesses 0 .. max_pos.
   Records carry no tag. Any other first pattern is a fatal error.
   Alias, constraint and or-patterns are flattened by simpl_casel, like in
   the other divide_* functions. *)
let divide_record_matching (Matching(casel, pathl)) =
  let max_pos = ref 0 in
  (* Raise max_pos to every label position found in a first-column
     pattern, looking through aliases, constraints and both or-branches. *)
  let rec scan = function
      Pat(Zaliaspat(pat,_),_) -> scan pat
    | Pat(Zconstraintpat(pat,_),_) -> scan pat
    | Pat(Zorpat(pat1,pat2),_) -> scan pat1; scan pat2
    | Pat(Zrecordpat pat_expr_list,_) ->
        do_list
          (fun (lbl,_) ->
            if lbl.info.lbl_pos > !max_pos then
              (max_pos := lbl.info.lbl_pos))
          pat_expr_list
    | _ -> () in
  do_list
    (function (pat::_, _) -> scan pat
            | ([],_) -> fatal_error "divide_record_matching_do_list")
    casel;
  let width = succ !max_pos in
  let rec divide_rec cases =
    match simpl_casel cases with
      (Pat(Zrecordpat pat_expr_list,_) :: patl, action) :: rest ->
        expand pat_expr_list patl action rest
    | (Pat((Zwildpat | Zvarpat _),_) :: patl, action) :: rest ->
        expand [] patl action rest
    | [] ->
        Matching([], make_path_without_tag width pathl)
    | _ ->
        fatal_error "divide_record_matching"
  (* Replace the first pattern by [width] field patterns: the given
     (label, pattern) pairs at their positions, wildcards elsewhere. *)
  and expand pat_expr_list patl action rest =
    let v = make_vect width (Pat(Zwildpat, no_location)) in
      do_list (fun (lbl, pat) -> v.(lbl.info.lbl_pos) <- pat) pat_expr_list;
      add_to_match (divide_rec rest) (list_of_vect v @ patl, action)
  in
    divide_rec casel
;;

(* Utilitaires quelconques. *)

(* Number of rows of a pattern-matching. *)
let length_of_matching = function
    Matching(rows, _) -> list_length rows
;;

(* [upper_left_pattern m] returns the first pattern of the first row of
   [m], with aliases and type constraints removed. For an or-pattern, its
   left alternative is used. Fatal error if [m] has no row or its first
   row has no pattern. *)
let upper_left_pattern matching =
  let rec strip = function
      Pat(Zaliaspat(inner,_),_) -> strip inner
    | Pat(Zconstraintpat(inner,_),_) -> strip inner
    | Pat(Zorpat(first,_),_) -> strip first
    | other -> other in
  match matching with
    Matching((pat :: _, _) :: _, _) -> strip pat
  | _ -> fatal_error "upper_left_pattern"
;;

(* Span recorded in a constructor kind: the span of a regular tag, the
   span of a tagless constructor, 1 for a superfluous constructor, and -1
   for a constructor of an extensible type. *)
let get_span_of_cs_kind = function
    Constr_constant(ConstrRegular(_,span)) -> span
  | Constr_regular(ConstrRegular(_,span),_,_) -> span
  | Constr_tagless(span,_) -> span
  | Constr_superfluous -> 1
  | Constr_constant(ConstrExtensible _)
  | Constr_regular(ConstrExtensible _,_,_) -> -1
;;

(* Span of the constructor [cstr], taken from its cs_kind. *)
let get_span_of_constr cstr =
  let kind = cstr.info.cs_kind in
  get_span_of_cs_kind kind
;;

(* Span of the constructor in the upper-left pattern of [matching].
   Fatal error if that pattern is not a constructor pattern. *)
let get_span_of_matching matching =
  let constr_of = function
      Pat(Zconstruct0pat(c), _) -> c
    | Pat(Zconstruct1pat(c,_), _) -> c
    | _ -> fatal_error "get_span_of_matching" in
  get_span_of_constr (constr_of (upper_left_pattern matching))
;;

(* La logique a trois etats (True, False, Maybe). *)

(* Disjunction in the three-valued logic: True if either side is True,
   False if both sides are False, Maybe otherwise. *)
let tristate_or (a, b) =
  match a, b with
    (True, _) | (_, True) -> True
  | (False, False) -> False
  | _ -> Maybe
;;

(* The main compilation function.
   Input: a pattern-matching,
   Output: a k2 expression, a "partial" flag, a list of used cases. *)

(* [conquer_matching m] compiles the pattern-matching [m] into a k2
   expression and returns a triple (code, partial, used):
   - code: the compiled expression. When no row applies, it evaluates a
     Kreturn to the label *static2* , which an enclosing Kblock catches.
   - partial: False if code cannot fail, True if some values are certainly
     left unmatched, Maybe when this is not known.
   - used: the actions that code can select.
   Backtracking relies on two fixed labels, *static1* and *static2* .
   The code Kblock *static1* [Kblock *static2* [Kreturn *static1* e1]; e2]
   runs e1 and returns its value. If e1 instead returns to *static2* ,
   e2 runs.
   This assumes Kreturn l e exits the innermost enclosing Kblock labelled
   l with the value of e, as defined in k2; not shown here, TODO confirm. *)
let rec conquer_matching =
  (* Compile every division (kinds, matching) produced by
     divide_constant_matching or divide_construct_matching. Each kinds is
     paired with the compiled code, and the partial flags are combined
     with tristate_or. *)
  let rec conquer_divided_matching = function
    [] ->
      [], False, []
  | (kinds, matchref) :: rest ->
      let k2exp1, partial1, used1 = conquer_matching !matchref
      and list2,   partial2, used2 = conquer_divided_matching rest in
        (kinds,k2exp1) :: list2,
        tristate_or(partial1,partial2),
        used1 @ used2
  in function
    Matching([], _) ->
      (* No row left: fail, so that an enclosing block tries the next
         alternative. *)
      Kreturn("*static2*",Kvoid), True, []
  | Matching(([], action) :: rest, _) ->
      (* The first row has no pattern left: its action is selected. *)
      action, False, [action]
  | Matching(_, (path :: _)) as matching ->
     (match upper_left_pattern matching with
        Pat((Zwildpat | Zvarpat _), _) ->
          (* Variable rule: compile the leading variable rows. If they can
             fail, fall back to the remaining rows. *)
          let vars, rest = divide_var_matching matching in
          let k2exp1, partial1, used1 = conquer_matching vars
          and k2exp2, partial2, used2 = conquer_matching rest in
            if partial1 == False then
              k2exp1, False, used1
            else
              (if k2exp2 = Kreturn("*static2*",Kvoid)
               then k2exp1
               else
                 Kblock("*static1*",
                   [Kblock("*static2*",[Kreturn("*static1*",k2exp1)]);
                    k2exp2])),
              (if partial2 == False then False else Maybe),
              used1 @ used2
      | Pat(Ztuplepat patl, _) ->
          (* Tuples are expanded in place into their components. *)
          conquer_matching(divide_tuple_matching(list_length patl) matching)
      | Pat((Zconstruct0pat(_) | Zconstruct1pat(_,_)),_) ->
          (* Constructor rule. span is presumably the number of
             constructors of the type; it is -1 for an extensible
             constructor, see get_span_of_cs_kind. *)
          let span = get_span_of_matching matching in
          let constrs, vars = divide_construct_matching matching in
          let condlist, partial1, used1 = conquer_divided_matching constrs
          and k2exp, partial2, used2 = conquer_matching vars
          and num_cstr = list_length constrs in
            if num_cstr == span & partial1 == False then
              (* Every constructor has a branch and no branch can fail.
                 The code for vars is not emitted, and its actions are
                 not reported as used. *)
              match condlist with
                [] -> fatal_error "condlist=[] in conquer_matching"
              | [(_,unique)] -> unique, False, used1
              | _ -> Kcase(path,condlist,Kvoid), False, used1
            else 
              (let taglessp = let rec teste = function
                                    [] -> false
                                  | ([Constr_tagless _],_)::_ -> true
                                  | _::list -> teste list
                              in teste condlist
               and superflup = let rec teste = function
                                    [] -> false
                                  | ([Constr_superfluous],_)::_ -> true
                                  | _::list -> teste list
                              in teste condlist
               in let condlist2 = if taglessp
                    then (* For a tagless constructor: if no branch tests
                            tag 0, add one that just fails. *)
                      let taglist = 
                       let l = ref []
                       in do_list
                           (fun ((Constr_constant _ :: _) as constrl,_) ->
                                  l := (map (function Constr_constant t -> 
                                          int_of_constr_tag t) 
                                            constrl) @ !l
                              | ((Constr_regular _ :: _) as constrl,_) ->
                                  l := (map (function Constr_regular(t,_,_) -> 
                                          int_of_constr_tag t) 
                                            constrl) @ !l
                              | _ -> () )
                           condlist; !l
                      and backt = Kreturn("*static2*",Kvoid)
                                  (* the returned value is irrelevant *)
                      in if not(memq 0 taglist) then 
                             ([Constr_constant(ConstrRegular(0,span))],backt)
                             :: condlist
                         else condlist
                    else condlist (* taglessp is always false when extensible *)
                in
                   if k2exp = Kreturn("*static2*",Kvoid)
                     then Kcase(path,condlist2,if superflup then Kvoid
                                               else Kreturn("*static2*",Kvoid))
                     else Kblock("*static1*",
                            [Kblock("*static2*",
                               [Kreturn("*static1*",Kcase(path,condlist2,
                                    if superflup then Kvoid (* default presumably never taken *)
                                    else Kreturn("*static2*",Kvoid)))]);
                             k2exp])),
              (if partial2 == False then False
               else if num_cstr < span & partial2 == True then True
               else Maybe),
              used1 @ used2
      | Pat(Zconstantpat _,_) | Pat(Zdotpat _,_) ->
          (* Constant rule: a switch on the value that falls back to the
             variable rows. The partial flags of the switch branches are
             ignored; the result is as partial as the fallback. *)
          let constants, vars = divide_constant_matching matching in
            let switchlist, _, used1 = conquer_divided_matching constants
            and k2exp2, partial2, used2 = conquer_matching vars
            in (if k2exp2 = Kreturn("*static2*",Kvoid)
                then Kswitch(path,switchlist,Kreturn("*static2*",Kvoid))
                else Kblock("*static1*",
                       [Kblock("*static2*",
                         [Kreturn("*static1*",
                            Kswitch(path,switchlist,
                               Kreturn("*static2*",Kvoid)))]);
                               (* the returned value is irrelevant *)
                         k2exp2])),
                partial2,
                used1 @ used2
      | Pat(Zrecordpat _,_) ->
          (* Records are expanded in place into their fields. *)
          conquer_matching (divide_record_matching matching)
      | _ ->
          fatal_error "conquer_matching 2")
  | _ -> fatal_error "conquer_matching 1"
;;

(* [make_initial_matching vars casel] builds the initial matching. Its
   rows are [casel], and its access expressions are one Kvar per variable
   of [vars]. Fatal error if [casel] is empty. *)
let make_initial_matching vars casel =
  match casel with
    [] ->
      fatal_error "make_initial_matching: empty"
  | _ :: _ ->
      Matching(casel, map (fun id -> Kvar id) vars)
;;

(* Entry point. *)

(* [translate_matching failure_code loc casel vars] compiles the rows
   [casel] matched against the variables [vars].
   If the action of some row cannot be selected (compared with memq), a
   warning is printed at [loc].
   If the compiled code may fail (partial flag not False), it is wrapped
   so that a failure runs [failure_code partial] instead. *)
let translate_matching failure_code loc casel vars =
  let (k2exp, partial, used) =
    conquer_matching (make_initial_matching vars casel) in
  let reached (_, act) = memq act used in
  if not (for_all reached casel) then begin
    prerr_location loc;
    prerr_begline " Warning: some cases are unused in this matching.";
    prerr_endline2 ""
  end;
  if partial = False then
    k2exp
  else
    Kblock("*static1*",
           [Kblock("*static2*",[Kreturn("*static1*",k2exp)]);
            failure_code partial])
;;
