(**********************************************************************)
(* (c) Greg Morrisett, Dan Grossman, Steve Zdancewic                  *)
(*     June 1998, all rights reserved.                                *)
(**********************************************************************)

(* cfgopt.ml *)

open Utilities;;
open Numtypes;;
open Tal;;
open Cfg;;
open Talutil;;

(* Enables the debugging output produced by debugdo and remove_dead_blocks. *)
let debug = true

(* debugdo msg: print msg followed by a newline, but only when debug is set. *)
let debugdo msg =
  if not debug then () else print_string (msg ^ "\n")

(* live_type_b live_in block: restrict block's code type so that, among the
 * virtual registers, only those in the set live_in remain in its register
 * state.  Non-virtual register bindings are copied over unchanged.  A block
 * whose con is None is left with None. *)
let live_type_b live_in block =
  (* Only virtual registers are candidates for removal. *)
  let keep r =
    match r with
      Virt _ -> Set.member live_in r
    | _      -> true
  in
  let restrict con =
    let rs = get_rs con in
    let trimmed =
      rs_fold_reg
        (fun r rcon acc -> if keep r then rs_set_reg acc r rcon else acc)
        rs rs_empty
    in
    set_code con (code_con trimmed)
  in
  match block.con with
    None     -> block.con <- None
  | Some con -> block.con <- Some (restrict con)

(* remove_dead_types cfg: trim every block's code type to the registers in
 * that block's own live_in set, via live_type_b. *)
let remove_dead_types cfg =
  print_string "REMOVING DEAD TYPES\n";
  Cfg.fold (fun blk () -> live_type_b blk.live_in blk) cfg ()

(* remove_dead_code cfg: backward dead-instruction elimination over every
 * block of cfg.  For each block it starts from block.live_out (computed by an
 * earlier liveness pass) and walks block.code from last to first:
 *  - instructions in the explicit keep-list below are always kept;
 *  - Retn is kept, and the live set is reset to Cfg.get_rets for the
 *    block's function;
 *  - any other instruction is dropped when none of the registers it defines
 *    is live afterwards and it does not define Esp.
 * The surviving instructions replace block.code, and the block's code type is
 * then trimmed with live_type_b. *)
let remove_dead_code cfg =
  let remove_dead_code_b block _ =
    let xa = Xarray.create (Array.length block.code) Nop in
    (* Registers live just after the instruction currently being examined. *)
    let live_out = ref block.live_out in
    let live_in = ref (Set.empty compare_regs) in
    let emit i = Xarray.add xa i in
    (* Instructions are collected last-to-first; return them in program
       order. *)
    let rev_xa () =
      let code = Xarray.create (Array.length block.code) Nop in
      for i = Xarray.length xa - 1 downto 0 do
        Xarray.add code (Xarray.get xa i)
      done;
      Xarray.to_array code
    in
    let remove_dead_code_i i =
      let (def, use) = Cfginstr.def_use i in
      match i with
        (* Instructions with effects beyond the registers they define
           (e.g. control transfers, condition-code setters, stack
           operations, memory stores) can never be deleted. *)
        (Call _ | Clc | Cmc | Cmp _ | Int _ | Into | Jcc _ | Jecxz _ |
         Jmp _ | Loopd _ |
         Mov (Addr _, _) | Mov (Prjr _, _) | Mov (Prjl _, _) |
         Nop | Pop _ | Popad | Popfd | Push _ | Pushad |
         Pushfd | Stc | Test _ | Aupd _ | Bexn _ | Btagi _ |
         Btagvar _ | Comment _ | Fallthru _ | Coerce _) ->
          begin
            emit i;
            live_in := Set.diff !live_out def;
            live_out := Set.union !live_in use
          end
      | Retn _ ->
          begin
            emit i;
            live_in := Set.diff !live_out def;
            live_out := Cfg.get_rets cfg block.fun_lab
          end
      (* Otherwise keep the instruction only if a register it defines is
         used later, or if it defines Esp.
         NOTE(review): memory writes through instructions other than the Mov
         forms listed above (e.g. an ArithBin with a memory destination)
         survive only if Cfginstr.def_use reports a live definition for
         them -- confirm that is intended. *)
      | _ ->
          if Set.is_empty (Set.intersect def !live_out)
             && not (Set.member def Esp)
          then ()
          else begin
            emit i;
            live_in := Set.diff !live_out def;
            live_out := Set.union !live_in use
          end
    in
    for i = Array.length block.code - 1 downto 0 do
      remove_dead_code_i block.code.(i)
    done;
    block.code <- rev_xa ();
    (* After the backward walk, live_out is what this pass treats as the
       set live on entry to the block. *)
    live_type_b !live_out block
  in
  print_string "REMOVING DEAD CODE\n";
  Format.print_flush ();
  Cfg.fold remove_dead_code_b cfg ()

(* dereg g: the register carried by g, which must be of the form Reg r.
 * Any other operand means the optimizer's bookkeeping is wrong, so this
 * fails with an internal-bug message. *)
let dereg = function
    Reg reg -> reg
  | _       -> failwith "bug: cfg optimize thought a genop would be Reg"

(* CFG block optimize:
 *  Does not remove instructions itself, but rewrites some of them so that the
 *  later liveness / dead-code pass will remove them.
 *  Replaces "setcc; btagi" with "setcc; jcc" (liveness then eliminates the
 *  mov and setcc that become dead).
 *)

(* Currently O(n*m), where n is the number of local variables (tracked
 * locations) and m is the number of instructions, but could be made
 * O(m log n) without too much trouble.
 *)

(* Locations tracked by the block optimizer: S off is the esp-relative stack
 * slot at byte offset off, R r is register r, and Void stands for any operand
 * that is not tracked. *)
type trackedLoc = S of int32 | R of reg | Void
(* locToVal maps a tracked location to an operand believed to hold the same
 * value at the current point in the block. *)
type fromAbove = { locToVal  : (trackedLoc, genop) Dict.dict }

(* comp: comparison function for trackedLoc used by the locToVal dictionary.
 * It is a total order: constructors are ranked S < R < Void, and payloads
 * are compared only within the same constructor.  Every mixed pair is
 * ordered consistently in both directions, and Void equals Void, as an
 * ordered dictionary requires. *)
let comp tl1 tl2 =
  match (tl1, tl2) with
    (S a, S b)   -> compare a b
  | (R a, R b)   -> compare_regs a b
  | (Void, Void) -> 0
  | (S _, _)     -> -1
  | (_, S _)     -> 1
  | (R _, Void)  -> -1
  | (Void, R _)  -> 1
let emptyAbove = { locToVal  = Dict.empty comp }

(* genopToLoc: the tracked location named by an operand.  A register maps to
 * R r, an esp-based projection maps to S offset (any coercions on esp are
 * ignored), and every other operand maps to Void. *)
let genopToLoc = function
    Reg reg                 -> R reg
  | Prjr ((Esp, _), offset) -> S offset
  | _                       -> Void

(* printAbove: debugging dump of a fromAbove map, printing one
 * "key --> value" line per entry.  Stack slots print as their offset, Eax as
 * "EAX", any other register as "R", and untracked operands as "V". *)
let printAbove above =
  let print_loc loc =
    print_string
      (match loc with
         S off -> string_of_int32 off
       | R Eax -> "EAX"
       | R _   -> "R"
       | Void  -> "V")
  in
  Dict.app_dict
    (fun key value ->
      print_loc key;
      print_string " --> ";
      print_loc (genopToLoc value);
      print_newline ())
    above.locToVal

(* newuse filter above g: follow the chain of recorded equalities that
 * starts at g (g's location -> recorded operand -> that operand's location
 * -> ...).  Returns the last operand on the chain that satisfies filter, or
 * g itself if none does.  The walk stops at an untracked operand, at a
 * location with no recorded value, or when a location is recorded as holding
 * itself.  Without that last check, an entry such as R eax -> Reg eax
 * (produced by "mov eax, eax") would make the walk loop forever. *)
let newuse filter above g = (* filter : genop -> bool *)
  let rec aux bestsofar current =
    match genopToLoc current with
      Void -> bestsofar
    | loc ->
        if Dict.member above.locToVal loc then
          let next = Dict.lookup above.locToVal loc in
          let best = if filter next then next else bestsofar in
          if comp (genopToLoc next) loc = 0 then best else aux best next
        else
          bestsofar
  in
  aux g g
(* killstack above: forget every fact that depends on the value of esp.
 * This is called when esp changes (push, pop, or a write to esp).  After such
 * a change, an esp-relative slot S off names a different memory cell.  So
 * entries keyed by a stack slot or by esp, and entries whose value is a stack
 * slot or esp, are removed.  Facts that relate only other operands are kept.
 * (Previously the two branches were swapped: stale stack-slot facts survived
 * a push or pop and led to wrong operand substitutions.) *)
let killstack above =
  {
  locToVal =
  Dict.fold_dict
    (fun l g d ->
      match (l, genopToLoc g) with
        ((S _,_) | (R Esp,_) | (_, S _) | (_, R Esp)) -> Dict.delete d l
      | _ -> d)
    above.locToVal above.locToVal
  }

(* killloc above loc: remove every fact that mentions loc, i.e. entries
 * keyed by loc and entries whose value operand names loc (per genopToLoc). *)
let killloc above loc =
  let mentions_loc l g =
    comp l loc = 0 || comp (genopToLoc g) loc = 0 in
  {
  locToVal =
    Dict.fold_dict
      (fun l g d -> if mentions_loc l g then Dict.delete d l else d)
      above.locToVal above.locToVal
  }

(* insertloc above loc g: a fromAbove whose map is above's map with loc
   bound to g (via Dict.insert). *)
let insertloc above loc g =
  let updated = Dict.insert above.locToVal loc g in
  { locToVal = updated }

(* negate_setcc_cond cond: the condition code that holds exactly when cond
 * does not, for the condition codes handled here.  Returns None for any
 * other code, so that blockOptimize leaves the instructions alone instead of
 * aborting compilation.  (Other unsigned codes such as Above/BelowEq could
 * be added here if the Tal condition type provides them.) *)
let negate_setcc_cond cond =
  match cond with
    Eq        -> Some NotEq
  | NotEq     -> Some Eq
  | GreaterEq -> Some Less
  | Greater   -> Some LessEq
  | LessEq    -> Some Greater
  | Less      -> Some GreaterEq
  | Below     -> Some AboveEq
  | AboveEq   -> Some Below
  | _         -> None

(* blockOptimize block: forward copy/constant propagation within one block.
 *
 * Walks block.code once, in order.  It keeps a fromAbove map from tracked
 * locations (registers and esp-relative stack slots) to operands believed to
 * hold the same value at that point.  Operands of Mov, Cmp, Push, ArithBin
 * and ArithSR are rewritten in place (block.code is mutated) to use the
 * equivalent operand chosen by newuse.  Any instruction that writes a
 * tracked location drops the facts that mention it.  The pair
 * "setcc r; btagi r, 0, Eq" becomes "setcc r; jcc <negated cond>".
 * No instruction is deleted here; the liveness and dead-code passes remove
 * whatever becomes dead.  Returns unit. *)
let blockOptimize block =
  let newuseReg = newuse (fun x -> match x with Reg _ -> true | _ -> false) in
  let newuseNoprjr =
    newuse (fun x -> match x with Prjr _ -> false | _ -> true) in
  let newuseAny = newuse (fun _ -> true) in
  (* STUB: intended to be
     newuse (fun x -> match x with Immed _ -> false | _ -> true) *)
  let newuseNoImmed = newuseReg in
  (* Forget every fact invalidated by a write to g. *)
  let kill above g =
    match g with
      Reg Esp           -> killstack above (* Be Smarter Later *)
    | Reg r             -> killloc above (R r)
    | Prjr((Esp,_),off) -> killloc above (S off)
    | _                 -> above in
  (* Record that g now holds the value of g'.  A location is never recorded
     as holding itself (e.g. "mov eax, eax"): such a self-mapping would send
     newuse's chain walk around in circles. *)
  let def above g g' =
    let above = kill above g in
    if comp (genopToLoc g) (genopToLoc g') = 0 then above
    else
      match g with
        Reg Esp            -> above
      | Prjr((Esp,[]),off) -> insertloc above (S off) g'
      | Prjr _             -> above
      | Reg r              -> insertloc above (R r) g'
      | _                  -> above in
  let inst_arr = block.code in
  let ninst = Array.length inst_arr in
  let step (above, ind) inst =
    let newAbove =
      match inst with
        Mov(g1, (g2, [])) ->
          (* A memory destination only takes a non-memory replacement
             source (presumably because x86 has no memory-to-memory mov). *)
          let lookup =
            match g1 with
              Prjr _ -> newuseNoprjr
            | _      -> newuseAny in
          let newsrc =
            match g2 with
              Reg _            -> lookup above g2
            | Prjr((Esp,[]),_) -> lookup above g2
            | Prjr((r,c'),i)   -> Prjr((dereg (newuseReg above (Reg r)),c'),i)
            | _                -> g2 in
          let above =
            match g1 with
              Reg _            -> def above g1 g2
            | Prjr((Esp,[]),_) -> def above g1 g2
            | Prjr((r,_),_)    -> kill above (Reg r)
            | _                -> above in
          inst_arr.(ind) <- Mov(g1, (newsrc, []));
          above
      | Cmp(g1, g2) ->
          let new1 = newuseNoImmed above g1 in (* First can't be immed *)
          (* STUB: newuseAny would be acceptable when new1 is not a Prjr. *)
          let new2 = newuseNoprjr above g2 in
          inst_arr.(ind) <- Cmp(new1, new2);
          above
      | Push((Reg _) as g1, []) ->
          inst_arr.(ind) <- Push(newuseAny above g1, []);
          killstack above
      | (Push _ | Pushad | Pushfd) -> killstack above
      | Pop((Reg _) as g1)         -> killstack (kill above g1)
      | (Pop _ | Popad | Popfd)    -> killstack above
      | ArithBin(op, g1, g2) ->
          let lookup =
            match g1 with
              Prjr _ -> newuseNoprjr
            | _      -> newuseAny in
          inst_arr.(ind) <- ArithBin(op, g1, lookup above g2);
          kill above g1
      | ArithSR(op, g1, None) ->
          (* If ecx is known to hold an immediate, use it as the count. *)
          let g2 =
            match newuseAny above (Reg Ecx) with
              Immed i -> Some i
            | _       -> None in
          inst_arr.(ind) <- ArithSR(op, g1, g2);
          kill above g1
      | Setcc(cond, ((Reg r1) as g1)) ->
          (* setcc r; btagi r, 0, L, Eq  ==>  setcc r; jcc (not cond), L.
             The bounds check covers a setcc that ends the block. *)
          (if ind + 1 < ninst then
             match inst_arr.(ind + 1) with
               Btagi(r2, i, idcoerce, Eq)
               when i =$ i32_0 && compare_regs r1 r2 = 0 ->
                 (match negate_setcc_cond cond with
                    Some negated -> inst_arr.(ind + 1) <- Jcc(negated, idcoerce)
                  | None         -> ())
             | _ -> ());
          (* setcc writes its destination register, so facts about it are
             now stale. *)
          kill above g1
      | Mov(g1, _) ->
          (* Coerced move.  The generic case below would only forget the
             registers reported by def_use, and would leave a stale fact
             behind when the destination is a stack slot. *)
          let (defs, _) = Cfginstr.def_use inst in
          Set.fold (fun r above -> kill above (Reg r)) defs (kill above g1)
      | _ ->
          let (defs, _) = Cfginstr.def_use inst in
          Set.fold (fun r above -> kill above (Reg r)) defs above
    in
    (newAbove, ind + 1)
  in
  ignore (Array.fold_left step (emptyAbove, 0) inst_arr)

(* peephole cfg: run blockOptimize on every block, recompute liveness,
 * remove the instructions that are now dead, then recompute liveness again
 * for the rewritten code. *)
let peephole cfg =
  let optimize_block blk acc = blockOptimize blk; acc in
  Cfg.fold optimize_block cfg ();
  Cfgliveness.liveness cfg;
  remove_dead_code cfg;
  Cfgliveness.liveness cfg

(* remove_dead_blocks cfg: delete every block that has no predecessors and
 * is not one of cfg's roots, and repeat until a pass finds nothing to delete.
 * The repetition presumably exists because deleting a block can leave its
 * successors without predecessors.  When debug is set, prints the number of
 * blocks deleted in each pass. *)
let rec remove_dead_blocks cfg =
  let to_delete =
    Dict.fold_dict
      (fun id block to_delete ->
        if Set.is_empty block.pred && not (Set.member cfg.roots block.lab)
        then id :: to_delete
        else to_delete)
      cfg.blocks [] in
  if debug then begin
    print_string "DELETED ";
    print_int (List.length to_delete);
    print_string " BLOCKS\n"
  end;
  match to_delete with
    [] -> ()
  | _  ->
      List.iter (del_block cfg) to_delete;
      remove_dead_blocks cfg

(* Jump threading takes a block that consists of a single jmp, fallthru, or
return instruction, hoists that instruction into all of the block's
predecessors, and deletes the block.  The catch is mapping the quantifier
variables in the coercions.

This is intra-procedural, so we do nothing if an edge besides jump, branch, or
sequence comes into the block.

Careful: don't delete blocks while iterating!
*)

(* Much to add -- right now it only handles Retn, and it does nothing unless
 * every incoming edge is a Sequence or a Jump.
 *)
(* jumpthread: currently a stub that does nothing; the unfinished
   implementation is kept in the comment that follows. *)
let jumpthread _cfg = ()
(* NOT DONE -- WORK ON THIS 
   let toDelete =
   Cfg.fold
   (fun block toDelete ->
   if Array.length block.code = 1
   then 
   match Array.get block.code 0 with
   ((*Fallthru _ | Jmp _ |*) Retn a ) ->
	      if Set.fold 
		  (fun (_,e,_) sofar -> sofar & 
		    (match e with
		      (Sequence | Jump) -> true
		    | _                 -> false))
		  block.pred
		  true
	      then
	      	(Set.app 
	      	   (fun (lab,_,_) ->
		   let pred = get_block cfg lab in
		   Array.set pred.code ((Array.length pred.code)-1) (Retn a);
		   (* Must also re-assign preds of succs *)
		   pred.succ <-
		      Set.fold (fun (_,e,l) s -> Set.insert s (lab, e, l)) 
			block.succ (Set.empty compare_edges))
	      	   block.pred;
		 block.lab::toDelete)
	      else
	      	toDelete
	  | _ -> toDelete
      	else
	  toDelete)
    cfg [] in
  List.iter (del_block cfg) toDelete
*)

