open Pretty
open Cil
open Str
open List
open String
module E = Errormsg
module H = Hashtbl

(* [getFuncDecl fname] builds a fresh [fundec] for a logging hook with the
   C prototype
     void fname(void *ptr, const char *field, int isClientPtr);
   Only the prototype is created: the declaration is never added to the
   file's globals here (presumably the runtime library supplies it). *)
let getFuncDecl fname =
  let decl = emptyFunction fname in
  let formals =
    [ ("ptr", voidPtrType, []);
      ("field", charConstPtrType, []);
      ("isClientPtr", intType, []) ]
  in
  decl.svar.vtype <- TFun (voidType, Some formals, false, []);
  decl

(* Declaration of the hook
     void RecordFuncWrite(void *ptr, const char *field, const char *funcname);
   This is a plain value, not a function: it is built once when the module
   is initialised, so every emitted call refers to the same declaration.
   Like the other hooks it is not added to the file's globals. *)
let getRecFuncWriteDecl =
  let hookType =
    TFun (voidType,
          Some [ ("ptr", voidPtrType, []);
                 ("field", charConstPtrType, []);
                 ("funcname", charConstPtrType, []) ],
          false,
          [])
  in
  let decl = emptyFunction "RecordFuncWrite" in
  decl.svar.vtype <- hookType;
  decl


(* [getCompleteFieldName prefix offset] renders a dotted field path that
   starts with [prefix] and follows the [Field] components of [offset];
   e.g. prefix "pos" with offset ".x" gives "pos.x".

   Array indexing is deliberately not rendered: an [Index] stops the walk
   and yields [prefix] alone, so "a[i].b" is reported as "a".  Printing the
   index expression would only give its syntactic form, whereas what would
   matter is its run-time value, which is unknown here.  (The previous
   version pretty-printed the index and discarded the result.) *)
let rec getCompleteFieldName prefix offset =
  match offset with
  | NoOffset -> prefix
  | Field (f, rest) ->
      let suffix = getCompleteFieldName f.fname rest in
      if suffix = "" then prefix else prefix ^ "." ^ suffix
  | Index (_, _) -> prefix

(* [getRegexp words] builds a [Str] regexp intended for use as
   [Str.string_match r s 0]: thanks to the leading ".*" it succeeds when [s]
   contains one of [words] as a substring (Str's "." does not match a
   newline, so only text before the first newline is searched).
   Each word is passed through [Str.quote] so it is matched literally even
   if it contains regexp metacharacters.  An empty [words] list yields a
   regexp that matches every string. *)
let getRegexp tl =
  let alternative word = "\\(" ^ Str.quote word ^ "\\)" in
  let alternatives = String.concat "\\|" (List.map alternative tl) in
  Str.regexp (".*\\(" ^ alternatives ^ "\\)")
					       
(* [shouldRecordType t] is true when the pretty-printed form of [t]
   contains one of the game-entity type-name fragments listed below, i.e.
   accesses through values of this type should be logged.  It is a
   substring test (see [getRegexp]), so e.g. both "struct gentity_s *" and
   a typedef name such as "gentity_t" would qualify.
   The regexp is constant, so it is compiled once here rather than on every
   call (this predicate runs for every instrumented field access). *)
let shouldRecordType : typ -> bool =
  let recordedTypes =
    getRegexp [ "entityState_" ; "gentity_" ; "gclient_" ; "playerState_" ;
                "entityShared_" ]
  in
  fun t ->
    let printed = Pretty.sprint 80 (Pretty.dprintf "%a" d_type t) in
    Str.string_match recordedTypes printed 0
      
(* [shouldRecordExp e] applies [shouldRecordType] to the static type of [e]. *)
let shouldRecordExp e =
  let staticType = typeOf e in
  shouldRecordType staticType
    
(*
  let makeRecordStmt (s : string) : stmt =     
  let str = "recording a read to " ^ s in
  let i = Call((None), 
  (Lval(Var(printfFun.svar),NoOffset)),
	[ mkString str ],
  locUnknown) in 
  mkStmt (Instr [i])
*)

(* What do we want to record, really?
   We want to know which fields are written for which types of entities
   (classnames), and, for each write, which entity and which of its think
   functions performed it. *)
    
(* [isClientPointer exp] is 1 when the pretty-printed static type of [exp]
   contains "gclient_" or "playerState_" (substring test, see [getRegexp]),
   and 0 otherwise.  It returns an int rather than a bool because the value
   is passed straight through as the C [int isClientPtr] hook argument.
   The constant regexp is compiled once instead of on every call. *)
let isClientPointer =
  let clientTypes = getRegexp [ "gclient_" ; "playerState_" ] in
  fun exp ->
    let printed = Pretty.sprint 80 (Pretty.dprintf "%a" d_type (typeOf exp)) in
    if Str.string_match clientTypes printed 0 then 1 else 0

(*
let getFieldName e f o = 
    let ts = Pretty.sprint 80 (Pretty.dprintf "%a" d_type (typeOf e)) in 
    ts
*)

(* [getFuncLval fdec] is an expression naming the function [fdec], usable
   as the callee of a [Call] instruction. *)
let getFuncLval fdec =
  let callee = (Var fdec.svar, NoOffset) in
  Lval callee

(* [makeRecordInstr func e f o] builds the instruction
     func((void * ) e, "<dotted field name>", isClientPtr);
   where the field name is [getCompleteFieldName f.fname o] and
   isClientPtr is [isClientPointer e].  No location is attached. *)
let makeRecordInstr func e f o =
  let clientFlag = isClientPointer e in
  let args =
    [ CastE (voidPtrType, e);
      mkString (getCompleteFieldName f.fname o);
      Cil.integer clientFlag ]
  in
  Call (None, getFuncLval func, args, locUnknown)
    
(* One instruction builder per runtime hook.  Each hook declaration is
   created exactly once, here, and shared by every call emitted for it. *)
let wrRecInstr =
  let hook = getFuncDecl "RecordWrite" in
  makeRecordInstr hook
let rdRecInstr =
  let hook = getFuncDecl "RecordRead" in
  makeRecordInstr hook
let fnRecInstr_1 =
  let hook = getFuncDecl "RecordFuncallPre" in
  makeRecordInstr hook
let fnRecInstr_2 =
  let hook = getFuncDecl "RecordFuncallPost" in
  makeRecordInstr hook
		     
(* [getFieldFromLval lv] is the dotted field name of an [lv] of the form
   [*p .field ...], or the sentinel "dont_care" for any other lvalue. *)
let getFieldFromLval = function
  | (Mem _, Field (fi, rest)) -> getCompleteFieldName fi.fname rest
  | _ -> "dont_care"

(* [getExpFromLval lv] is [Some p] when [lv] dereferences the pointer
   expression [p], and [None] when it names a variable. *)
let getExpFromLval (host, _) =
  match host with
  | Mem ptr -> Some ptr
  | Var _ -> None
        
(* [recordFuncWrite lv rhs] instruments stores into entity callback fields.
   When [rhs] is [AddrOf rlval] (how CIL represents a function designator
   used as a pointer) and [lv] is [*e .field ...] whose dotted field name
   contains one of the callback names below (substring test via
   [getRegexp]), it returns the single instruction
     RecordFuncWrite((void * ) e, "<field>", "<printed rlval>");
   Otherwise it returns []. *)
let recordFuncWrite lv rhs =
  match rhs, lv with
  | AddrOf rlval, (Mem e, Field (f, o)) ->
      let name = getCompleteFieldName f.fname o in
      let callbacks =
        getRegexp [ "think"; "reached"; "blocked"; "touch"; "use"; "pain"; "die" ]
      in
      if not (Str.string_match callbacks name 0) then []
      else begin
        let target = Pretty.sprint 200 (Pretty.dprintf "%a" d_lval rlval) in
        let args = [ CastE (voidPtrType, e); mkString name; mkString target ] in
        [ Call (None, getFuncLval getRecFuncWriteDecl, args, locUnknown) ]
      end
  | _ -> []
        
(* [recordLval recFunc lv] returns [[recFunc p f o]] when [lv] has the
   shape [*p .f o] (a dereference whose first offset is a field), the
   static type of [p] passes [shouldRecordExp], and the dotted field name
   is not "inuse".  In every other case, including plain variables, it
   returns []. *)
let recordLval recFunc (host, offset) =
  match host, offset with
  | Var _, _ -> []   (* dont want to record writes here *)
  | Mem ptr, Field (fi, rest) ->
      let name = getCompleteFieldName fi.fname rest in
      let wanted = shouldRecordExp ptr && name <> "inuse" in
      if wanted then [ recFunc ptr fi rest ] else []
  | Mem _, _ -> []
	     
(* [recordWrites lv] is the (possibly empty) list of RecordWrite calls for
   a store to [lv]. *)
let recordWrites lv = recordLval wrRecInstr lv
		     
(* [recordReads e] returns RecordRead calls for the recordable field
   accesses that evaluating [e] performs, left operand first.
   Only lvalues appearing directly as [Lval]/[AddrOf]/[StartOf] nodes are
   inspected; expressions nested inside an lvalue (e.g. the inner pointer
   load in [p->q->f]) are not examined.

   The operands of [sizeof]/[__alignof__] are not evaluated in C, so they
   perform no reads and yield no RecordRead calls (previously they did).
   NOTE(review): [AddrOf]/[StartOf] only compute an address and do not read
   the field either; they are kept as "reads" to preserve existing logs. *)
let rec recordReads (e : exp) : instr list =
  match e with
  | Lval lv | AddrOf lv | StartOf lv -> recordLval rdRecInstr lv
  | SizeOfE _ | AlignOfE _ -> []
  | UnOp (_, e1, _) | CastE (_, e1) -> recordReads e1
  | BinOp (_, e1, e2, _) ->
      let l1 = recordReads e1 in
      let l2 = recordReads e2 in
      l1 @ l2
  | _ -> []
	
(* Debug helper: print [msg] plus a newline with CIL's [Pretty.printf]
   (presumably to stdout) and discard the resulting document. *)
let tp msg =
  let printed = Pretty.printf "%s\n" msg in
  ignore printed
	     
(* [recordFunCall f] handles indirect calls of the form [( *fptr)(...)]
   where [fptr] is itself an lvalue: it returns the RecordFuncallPre call
   followed by the RecordFuncallPost call for [fptr] (both or neither, since
   they share the same [recordLval] condition).  Direct calls give []. *)
let recordFunCall (f: exp) : instr list =
  match f with
  | Lval (Mem (Lval fptr), NoOffset) ->
      let before = recordLval fnRecInstr_1 fptr in
      let after = recordLval fnRecInstr_2 fptr in
      before @ after
  | _ -> []
	
(* The instrumentation visitor.  [vinst] inserts recording calls around
   individual instructions; [vstmt] covers expressions evaluated by control
   statements (return values, if/switch conditions), which are not
   instructions and so never reach [vinst]. *)
class visitor = object
  inherit nopCilVisitor

  (* For [lv = rhs]: record a function-pointer store, the field write to
     [lv], and the field reads in [rhs], all immediately before the
     assignment.  For calls: record the write to the result lvalue and the
     reads in the arguments before the call; an indirect call through a
     recorded field is additionally bracketed by its pre/post records. *)
  method vinst (i: instr) : instr list visitAction = begin
    match i with
    | Set (lv, rhs_exp, _) ->
        (* ignore (Pretty.printf "%a ======= %a\n" d_lval lv d_exp rhs_exp); *)
        let wr_list = recordWrites lv in
        let fwr_list = recordFuncWrite lv rhs_exp in
        let rd_list = recordReads rhs_exp in
        ChangeTo (fwr_list @ wr_list @ rd_list @ [ i ])
    | Call (lvopt, func, args, _) ->
        (* ignore (Pretty.printf "instr-> %a\n" dn_instr i); *)
        let wr_list =
          match lvopt with
          | Some lv -> recordWrites lv
          | None -> []
        in
        let fn_list = recordFunCall func in
        let rd_list = List.concat (List.map recordReads args) in
        let prefix = wr_list @ rd_list in
        (* [recordFunCall] yields either nothing or exactly a pre/post pair;
           matching on that shape is safer than indexing with List.nth. *)
        begin match fn_list with
        | [ pre; post ] -> ChangeTo (prefix @ [ pre; i; post ])
        | _ -> ChangeTo (prefix @ [ i ])
        end
    (* Debugging aid, kept for reference:
       match lv with (lh,off) -> begin
       begin match lh with
       Var(vinfo) -> ignore(Pretty.printf "lh=variable\n")
       | Mem(exp) ->
       ignore(Pretty.printf "lh=mem-exp=%a type=%a record=%b\n"
       d_exp exp d_type (typeOf exp) (shouldRecordType
       (typeOf exp)))
       end;
       match off with
       NoOffset -> ignore (Pretty.printf "off=nooffset\n")
       | Field(finfo,o) -> ignore (Pretty.printf "\tfield: fullname=%s\n" (getCompleteFieldName finfo.fname o));
       | Index(_,_) -> ignore (Pretty.printf "off=index\n");
       end;
    *)
    | _ -> DoChildren
  end

  (* Record the reads performed by a return value or an if/switch
     condition just before the statement executes.  Instructions inside
     basic blocks are handled by [vinst]. *)
  method vstmt (s : stmt) : stmt visitAction = begin
    match s.skind with
    | Return (Some exp, _)
    | If (exp, _, _, _)
    | Switch (exp, _, _, _) ->
        begin match recordReads exp with
        | [] -> DoChildren
        | rd_list ->
            ChangeDoChildrenPost (s, fun sn ->
              (* Rewrite [sn] in place instead of wrapping it in a fresh
                 statement: its labels (goto targets such as a merged
                 return label, case/default labels) and every reference to
                 it stay valid, so control that jumps to [sn] also runs the
                 recording calls.  CIL's own visitor prepends queued
                 instructions the same way. *)
              sn.skind <- Block (mkBlock [ mkStmt (Instr rd_list);
                                           mkStmt sn.skind ]);
              sn)
        end
    | _ -> DoChildren
  end

end
  
(* CIL feature descriptor for the "instrument" pass.  Running it visits
   the whole file with a fresh [visitor], rewriting instructions and
   statements in place without adding or removing globals. *)
let feature : featureDescr =
  let run (f : file) =
    let instrumenter = new visitor in
    visitCilFileSameGlobals instrumenter f
  in
  { fd_name = "instrument";
    fd_enabled = ref false;
    fd_description = "random instrumentation";
    fd_extraopt = [];
    fd_doit = run;
    fd_post_check = true;
  }
    
