   
(**********************************************************************)
(* (c) Greg Morrisett, Dan Grossman, Steve Zdancewic                  *)
(*     June 1998, all rights reserved.                                *)
(**********************************************************************)

(* cfgopt.ml *)

open Tal
open Cfg
open Utilities
open Talutil

(* Compile-time switch for the diagnostic output produced by debugdo. *)
let debug = false

(* debugdo s prints s followed by a newline when debug is set, and does
   nothing otherwise. *)
let debugdo s =
  match debug with
    true  -> print_string (s ^ "\n")
  | false -> ()


(* live_type_b live_in block rewrites block.con so that its register state   *)
(* keeps a virtual register only when that register is a member of live_in.  *)
(* Every non-virtual register is kept unconditionally.  A block whose con is *)
(* None is left as None.                                                     *)
let live_type_b live_in block =
  (* Fold step: copy the binding of reg into acc unless reg is a virtual
     register that is absent from live_in. *)
  let keep_if_live reg reg_con acc =
    let wanted =
      match reg with
        Virt _ -> Set.member live_in reg
      | _      -> true
    in
    if wanted then rs_set_reg acc reg reg_con else acc
  in
  (* Rebuild the constructor c around the pruned register state. *)
  let restrict c =
    let pruned = rs_fold_reg keep_if_live (get_rs c) rs_empty in
    set_code c (code_con pruned)
  in
  block.con <-
    (match block.con with
       None   -> None
     | Some c -> Some (restrict c))

(* remove_dead_types cfg prints a progress banner and then applies           *)
(* live_type_b to every block of cfg, using the block's own live_in set.     *)
let remove_dead_types cfg =
  let prune_block blk () = live_type_b blk.live_in blk in
  print_string "REMOVING DEAD TYPES\n";
  Cfg.fold prune_block cfg ()

(* remove_dead_code cfg walks the instructions of every block of cfg from    *)
(* last to first, tracking a set of live registers that starts at the        *)
(* block's live_out.  An instruction is dropped when it is not in the        *)
(* always-kept set, none of the registers it defines is live, and it does    *)
(* not define Esp.  Each block's code array is replaced by the surviving     *)
(* instructions (in their original order) and its code type is then          *)
(* restricted, via live_type_b, to the registers live at the block's entry.  *)
let remove_dead_code cfg =
  (* Instructions that are always emitted: they have "side effects" such as  *)
  (* setting condition codes, transferring control or touching memory.       *)
  let always_kept instr =
    match instr with
      (Call _ | Clc | Cmc | Cmp _ | Int _ | Into | Jcc _ | Jecxz _ |
       Jmp _ | Loopd _ |
       Mov (Addr _, _) | Mov (Prjr _, _) | Mov (Prjl _, _) |
       Nop | Pop _ | Popad | Popfd | Push _ | Pushad |
       Pushfd | Stc | Test _ | Aupd _ | Bexn _ | Btagi _ |
       Btagvar _ | Comment _ | Fallthru _ | Coerce _) -> true
    | _ -> false
  in
  let sweep_block blk _ =
    let n_instrs = Array.length blk.code in
    (* Surviving instructions, accumulated in reverse program order. *)
    let kept = Xarray.create n_instrs Nop in
    (* Registers live immediately after the instruction being visited. *)
    let live = ref blk.live_out in
    let visit instr =
      let (defs, uses) = Cfginstr.def_use instr in
      match instr with
        Retn _ ->
          begin
            Xarray.add kept instr;
            (* Above a return, the live set is the function's return set. *)
            live := Cfg.get_rets cfg blk.fun_lab
          end
      | _ ->
          if always_kept instr
            || not (Set.is_empty (Set.intersect defs !live))
            || Set.member defs Esp
          then
            begin
              Xarray.add kept instr;
              (* Standard backward step: (live - defs) + uses. *)
              live := Set.union (Set.diff !live defs) uses
            end
          else
            (* Dead instruction: not emitted, live set unchanged. *)
            ()
    in
    (* Copy kept back into program order and freeze it as an array. *)
    let in_program_order () =
      let out = Xarray.create n_instrs Nop in
      let last = Xarray.length kept - 1 in
      for k = 0 to last do
        Xarray.add out (Xarray.get kept (last - k))
      done;
      Xarray.to_array out
    in
    debugdo ("remove_dead_code_b :" ^ (Identifier.id_to_string blk.lab));
    for idx = n_instrs - 1 downto 0 do
      visit blk.code.(idx)
    done;
    blk.code <- in_program_order ();
    live_type_b !live blk
  in
  print_string "REMOVING DEAD CODE\n";
  Format.print_flush ();
  Cfg.fold sweep_block cfg ()

(* dereg g extracts the register from a Reg operand.  Any other operand is
   treated as an internal error and raises Failure. *)
let dereg = function
    Reg reg -> reg
  | _       -> failwith "bug: cfg optimize thought a genop would be Reg"

(* CFG block optimize:
 *  Assumes Esp is never used to define another register -- might be okay.
 *  Does not remove instructions itself, but rewrites operands so that some
 *  instructions become dead and the later liveness pass removes them.
 *  Replaces "setcc; btagi" with "setcc; jcc" (liveness then eliminates the
 *  mov and the setcc).
 *)

(* Facts accumulated while scanning a block from its first instruction
 * downwards.  Both dictionaries are keyed by register operands: the
 * comparison function they are created with calls dereg on each key.
 *)
type fromAbove = { 
    (* Maps a register operand to the operand last recorded as moved into it. *)
    regToVal  : (genop, genop)      Dict.dict;
    (* Maps a register operand to the registers whose regToVal entry is that
     * register, so those entries can be dropped when it is overwritten. *)
    regToUses : (genop, genop list) Dict.dict
  }

(* blockOptimize block performs a block-local copy/constant propagation over
 * block.code, rewriting instructions in place.  It scans the instructions
 * from first to last while maintaining a fromAbove record of what each
 * register is currently known to hold, and substitutes known values into
 * the source operands of Mov, Cmp, Push, ArithBin and ArithSR.  It also
 * turns a Setcc immediately followed by a matching Btagi into a Jcc on the
 * negated condition.  No instruction is removed here.  Returns unit.
 *
 * Raises Failure if a non-register operand reaches one of the dictionaries
 * (see dereg), or if the Setcc/Btagi rewrite meets a condition code it has
 * no negation for.
 *)
let blockOptimize block =
  let emptyAbove =
    let comp = (fun g1 g2 -> compare_regs (dereg g1) (dereg g2)) in
    { regToVal = Dict.empty comp; regToUses = Dict.empty comp } in
  (* newuse above g: the operand to use in place of g, following chains of
   * register-to-register copies recorded in above. *)
  let rec newuse above g =
    match g with
      Reg _ ->
        if Dict.member above.regToVal g
        then
          (match Dict.lookup above.regToVal g with
             (Reg _) as g' -> newuse above g'
           | g'            -> g')
        else g
    | _ -> g in
  (* Like newuse, but never replaces a register by an immediate. *)
  let rec newuseNoImmed above g =
    match g with
      Reg _ ->
        if Dict.member above.regToVal g
        then
          (match Dict.lookup above.regToVal g with
             (Reg _) as g' -> newuseNoImmed above g'
           | Immed _       -> g
           | g'            -> g')
        else g
    | _ -> g in
  (* kill above g: forget what g holds, and forget the value of every
   * register that was recorded as a copy of g. *)
  let kill above g =
    let uses      = try Dict.lookup above.regToUses g with Dict.Absent -> [] in
    let regToUses = Dict.delete above.regToUses g in
    let regToVal  = Dict.delete above.regToVal  g in
    let regToVal  = List.fold_left Dict.delete regToVal uses in
    { regToVal  = regToVal;
      regToUses = regToUses } in
  (* def above g g': register g is overwritten with g'.  Old facts about g
   * are killed first; the new fact is then recorded unless g' is Esp, a
   * Prjr operand, or g itself. *)
  let def above ((Reg r) as g) g' =
    let above = kill above g in
    match g' with
      Reg Esp -> above (* Because Esp gets killed a lot and we don't track it *)
    | Reg r' when compare_regs r r' = 0 ->
        (* Self-move: recording r -> r would make newuse loop forever. *)
        above
    | Prjr _  -> above
    | _ ->
        { regToVal  = Dict.insert above.regToVal g g';
          regToUses =
            (match g' with
               Reg _ ->
                 Dict.insert above.regToUses g'
                   (g ::
                    (if Dict.member  above.regToUses g'
                     then Dict.lookup above.regToUses g'
                     else []))
             | _ -> above.regToUses)
        } in

  let inst_arr = block.code in
  let n_insts  = Array.length inst_arr in
  (* The fold threads (facts, index); its final value is not needed. *)
  let (_, _) =
    Array.fold_left
      (fun (above, ind) inst ->
        let newAbove =
          match inst with
            Mov(g1, (g2, [])) ->
              let (above, newdest) =
                match g1 with
                  Reg _          -> (def above g1 g2, g1)
                | Prjr((r,_),_)  -> (kill above (Reg r), (* damn types! *)
                                     g1)
                | _              -> (above, g1) in
              (* The source is rewritten using the facts that hold after the
               * destination has been (re)defined, as in the original code. *)
              let newsrc =
                match g2 with
                  Reg _          -> newuse above g2
                | Prjr((r,c'),i) -> Prjr((dereg (newuse above (Reg r)),c'),i)
                | _              -> g2 in
              inst_arr.(ind) <- Mov(newdest, (newsrc, []));
              above
          | Cmp(g1, g2) ->
              (* First can't be immed *)
              inst_arr.(ind) <- Cmp(newuseNoImmed above g1, newuse above g2);
              above
          | Push((Reg _) as g1, []) ->
              inst_arr.(ind) <- Push(newuse above g1, []);
              above
          | ArithBin(op, g1, g2) ->
              inst_arr.(ind) <- ArithBin(op, g1, newuse above g2);
              kill above g1
          | ArithSR(op, g1, None) ->
              (* A shift with no explicit count: use the immediate currently
               * known to be in Ecx, if there is one. *)
              let g2 =
                match newuse above (Reg Ecx) with
                  Immed i -> Some i
                | _       -> None in
              inst_arr.(ind) <- ArithSR(op, g1, g2);
              kill above g1
          | Setcc(cond, ((Reg r1) as g1)) ->
              (* Only look at the next instruction when there is one. *)
              (if ind + 1 < n_insts then
                 match inst_arr.(ind+1) with
                   Btagi(r2, 0, idcoerce, Eq) -> (* Generalize this!!! *)
                     if compare_regs r1 r2 = 0
                     then
                       begin
                         let negated =
                           match cond with
                             Eq        -> NotEq
                           | NotEq     -> Eq
                           | GreaterEq -> Less
                           | Greater   -> LessEq
                           | LessEq    -> Greater
                           | Less      -> GreaterEq
                           | Below     -> AboveEq
                           | AboveEq   -> Below
                           | _ ->
                               failwith
                                 "optimizer found unexpected condition code"
                         in
                         inst_arr.(ind+1) <- (Jcc (negated, idcoerce))
                       end
                     else ()
                 | _ -> ());
              (* Setcc overwrites r1, so earlier facts about it are stale. *)
              kill above g1
          | inst ->
              let (defs, _) = Cfginstr.def_use inst in
              Set.fold (fun r above -> kill above (Reg r)) defs above
        in (newAbove, ind+1))
      (emptyAbove, 0) inst_arr in
  ()

(* peephole cfg runs blockOptimize on every block of cfg, recomputes
   liveness, removes the code that has become dead, and recomputes liveness
   once more so that it matches the final code. *)
let peephole cfg =
  let recompute_liveness () = Cfgliveness.liveness cfg in
  begin
    Cfg.fold (fun blk _ -> blockOptimize blk) cfg ();
    recompute_liveness ();
    remove_dead_code cfg;
    recompute_liveness ()
  end

(* Jump threading hoists a single instruction block of the jmp,
fallthru, or return variety up to all predecessors and deletes the
block.  The catch is mapping the quantifier variables in the coercions.

This is intra-procedural, so we do nothing if an edge besides jump, branch, or
sequence comes into the block.

Careful: don't delete blocks while iterating!
*)

(* Much to add -- right now only does Retn and does nothing unless every
 * incoming edge is Sequence or Jump.
 *
 * NOTE: the implementation that follows is commented out, so jumpthread is
 * currently a stub: it ignores cfg and returns unit.
 *)
let jumpthread cfg = ()
(* NOT DONE -- WORK ON THIS 
   let toDelete =
   Cfg.fold
   (fun block toDelete ->
   if Array.length block.code = 1
   then 
   match Array.get block.code 0 with
   ((*Fallthru _ | Jmp _ |*) Retn a ) ->
	      if Set.fold 
		  (fun (_,e,_) sofar -> sofar & 
		    (match e with
		      (Sequence | Jump) -> true
		    | _                 -> false))
		  block.pred
		  true
	      then
	      	(Set.app 
	      	   (fun (lab,_,_) ->
		   let pred = get_block cfg lab in
		   Array.set pred.code ((Array.length pred.code)-1) (Retn a);
		   (* Must also re-assign preds of succs *)
		   pred.succ <-
		      Set.fold (fun (_,e,l) s -> Set.insert s (lab, e, l)) 
			block.succ (Set.empty compare_edges))
	      	   block.pred;
		 block.lab::toDelete)
	      else
	      	toDelete
	  | _ -> toDelete
      	else
	  toDelete)
    cfg [] in
  List.iter (del_block cfg) toDelete
*)

