(**********************************************************************)
(* (c) Greg Morrisett, Steve Zdancewic                                *)
(*     June 1998, all rights reserved.                                *)
(**********************************************************************)

(* regrewrite.ml
 *
 * Utility functions for rewriting control flow graphs for register 
 * allocation.
 *)

open Tal
open Cfg
open Regifg
open Talutil

(* Compile-time switch for the tracing output of this module. *)
let debug = false

(* [debugdo s] prints [s] followed by a newline on standard output when
   [debug] is set, and does nothing otherwise. *)
let debugdo s =
  match debug with
    true -> print_string (s ^ "\n")
  | false -> ()

(* Raised by the per-instruction rewriters in this file once an instruction
   has been fully handled, to skip the rest of the processing for that
   instruction.  It is always caught by the enclosing [with Done -> ()]. *)
exception Done

(* [rsub_reg colors r] is the register that the dictionary [colors] binds
   [r] to, or [r] itself when [colors] has no binding for it. *)
let rsub_reg colors r =
  try
    let mapped = Dict.lookup colors r in
    mapped
  with Dict.Absent -> r

(* Register substitution over constructors.

   [rsub_rcon colors rc] rebuilds the raw constructor [rc], applying the
   substitution to every sub-constructor and to every register state found
   under a [Ccode].  Forms with no sub-constructor are returned as they are. *)
let rec rsub_rcon colors rc =
  let sub = rsub_con colors in
  let sub_all cons = List.map sub cons in
  match rc with
    (Cvar _ | Clab _ | Cprim _ | Cempty | Chptr (_, None)) -> rc
  | Chptr (il, Some body) -> Chptr (il, Some (sub body))
  | Clam (id, k, body) -> Clam (id, k, sub body)
  | Cforall (id, k, body) -> Cforall (id, k, sub body)
  | Cexist (id, k, body) -> Cexist (id, k, sub body)
  | Crec defs ->
      let sub_def (id, k, body) = (id, k, sub body) in
      Crec (List.map sub_def defs)
  | Capp (c1, c2) -> Capp (sub c1, sub c2)
  | Ctuple parts -> Ctuple (sub_all parts)
  | Cproj (n, body) -> Cproj (n, sub body)
  | Ccode rs -> Ccode (rsub_regstate colors rs)
  | Cfield (body, v) -> Cfield (sub body, v)
  | Cprod parts -> Cprod (sub_all parts)
  | Csum parts -> Csum (sub_all parts)
  | Carray (c1, c2) -> Carray (sub c1, sub c2)
  | Csing body -> Csing (sub body)
  | Csptr body -> Csptr (sub body)
  | Ccons (c1, c2) -> Ccons (sub c1, sub c2)
  | Cappend (c1, c2) -> Cappend (sub c1, sub c2)

(* [rsub_regstate colors rs] builds a register state, starting from
   [rs_empty], that binds the substituted image of every register folded
   over in [rs] to the substituted image of its constructor. *)
and rsub_regstate colors rs =
  let add_binding r c acc =
    rs_set_reg acc (rsub_reg colors r) (rsub_con colors c)
  in
  rs_fold_reg add_binding rs rs_empty

(* [rsub_con colors con] substitutes inside the [rcon] field of [con]; the
   [isnorm] and [freevars] fields are copied unchanged. *)
and rsub_con colors con =
  match con with
    {rcon = rc; isnorm = b; freevars = fv} ->
      {rcon = rsub_rcon colors rc; isnorm = b; freevars = fv}

(* [rsub_coercion colors c] applies the register substitution to every
   constructor carried by the coercion [c]; coercions that carry no
   constructor are returned unchanged. *)
let rsub_coercion colors c =
  let sub = rsub_con colors in
  match c with
    (Unroll | Fromsum | Toexn | Toarray _ | Slot _) -> c
  | Pack (c1, c2) -> Pack (sub c1, sub c2)
  | Tapp c1 -> Tapp (sub c1)
  | Roll c1 -> Roll (sub c1)
  | Tosum c1 -> Tosum (sub c1)
  | RollTosum c1 -> RollTosum (sub c1)
  | Subsume c1 -> Subsume (sub c1)

(* [rsub_colist colors] maps [rsub_coercion colors] over a coercion list. *)
let rsub_colist colors =
  let sub_one = rsub_coercion colors in
  List.map sub_one

(* [rsub_genop colors g] substitutes registers (and the registers mentioned
   in attached coercions) inside the operand [g].  Immediates, tags and
   addresses are returned unchanged. *)
let rsub_genop colors g =
  let coerced_reg (r, cl) = (rsub_reg colors r, rsub_colist colors cl) in
  match g with
    (Immed _ | Tag _ | Addr _) -> g
  | Reg r -> Reg (rsub_reg colors r)
  | Prjr (base, i) -> Prjr (coerced_reg base, i)
  | Prjl ((id, cl), i) -> Prjl ((id, rsub_colist colors cl), i)

(* [rsub_conv colors c] substitutes the destination register and the source
   operand of a [Movsx]/[Movzx] conversion; the operand-free conversions are
   returned unchanged. *)
let rsub_conv colors c =
  let dst r = rsub_reg colors r in
  let src g = rsub_genop colors g in
  match c with
    Movsx (r, s, g, sc) -> Movsx (dst r, s, src g, sc)
  | Movzx (r, s, g, sc) -> Movzx (dst r, s, src g, sc)
  | (Cbw | Cdq | Cwd | Cwde) -> c

(* [rsub_marg colors m] applies the register substitution to the
   constructors inside the malloc argument [m], recursing through [Mprod]. *)
let rec rsub_marg colors m =
  match m with
    Mbytearray _ -> m
  | Mfield c -> Mfield (rsub_con colors c)
  | Mexnname c -> Mexnname (rsub_con colors c)
  | Mprod mal -> Mprod (List.map (rsub_marg colors) mal)

(* [rsub_instr colors i] rebuilds instruction [i] with the substitution
   [colors] applied to each register, operand, coercion list, constructor
   and malloc argument it carries.  Instructions with nothing to substitute
   are returned unchanged. *)
let rsub_instr colors i =
  let genop = rsub_genop colors in
  let reg = rsub_reg colors in
  let coercions = rsub_colist colors in
  (* A coerced operand and a coerced label, respectively. *)
  let coerced_op (g, cl) = (genop g, coercions cl) in
  let coerced_lab (id, cl) = (id, coercions cl) in
  match i with
    (Clc | Cmc | Int _ | Into | Lahf | Nop | Popad | Popfd | Pushad | Pushfd
    | Retn _ | Sahf | Stc | Comment _) -> i
  (* Arithmetic, comparison and data movement *)
  | ArithBin (op, g1, g2) -> ArithBin (op, genop g1, genop g2)
  | ArithUn (op, g) -> ArithUn (op, genop g)
  | ArithMD (op, g) -> ArithMD (op, genop g)
  | ArithSR (op, g, io) -> ArithSR (op, genop g, io)
  | Bswap r -> Bswap (reg r)
  | Cmovcc (cond, r, src) -> Cmovcc (cond, reg r, coerced_op src)
  | Cmp (g1, g2) -> Cmp (genop g1, genop g2)
  | Conv c -> Conv (rsub_conv colors c)
  | Imul3 (r, g, imm) -> Imul3 (reg r, genop g, imm)
  | Lea (r, g) -> Lea (reg r, genop g)
  | Mov (g, src) -> Mov (genop g, coerced_op src)
  | Pop g -> Pop (genop g)
  | Push (g, cl) -> Push (genop g, coercions cl)
  | Setcc (cond, g) -> Setcc (cond, genop g)
  | Shld (g, r, io) -> Shld (genop g, reg r, io)
  | Shrd (g, r, io) -> Shrd (genop g, reg r, io)
  | Test (g1, g2) -> Test (genop g1, genop g2)
  | Xchg (g, r) -> Xchg (genop g, reg r)
  (* Control transfer *)
  | Call (g, cl) -> Call (genop g, coercions cl)
  | Jcc (cond, lab) -> Jcc (cond, coerced_lab lab)
  | Jecxz (id, cl) -> Jecxz (id, coercions cl)
  | Jmp (g, cl) -> Jmp (genop g, coercions cl)
  | Loopd (lab, bo) -> Loopd (coerced_lab lab, bo)
  (* The remaining constructors of the instruction type *)
  | Asub (r1, g1, sc, r2, g2) -> Asub (reg r1, genop g1, sc, reg r2, genop g2)
  | Aupd (g1, sc, r1, r2, g2) -> Aupd (genop g1, sc, reg r1, reg r2, genop g2)
  | Bexn (r, g, lab) -> Bexn (reg r, genop g, coerced_lab lab)
  | Btagi (r, n, lab, cond) -> Btagi (reg r, n, coerced_lab lab, cond)
  | Btagvar (r, n1, n2, lab, cond) ->
      Btagvar (reg r, n1, n2, coerced_lab lab, cond)
  | Coerce (r, cl) -> Coerce (reg r, coercions cl)
  | Fallthru cl -> Fallthru (List.map (rsub_con colors) cl)
  | Malloc (n, m) -> Malloc (n, rsub_marg colors m)
  | Unpack (id, r, src) -> Unpack (id, reg r, coerced_op src)

(* [rsub_regset colors rs] is the set (ordered by [compare_regs]) of the
   images under [colors] of the registers in [rs]. *)
let rsub_regset colors rs =
  let add_image r acc = Set.insert acc (rsub_reg colors r) in
  Set.fold add_image rs (Set.empty compare_regs)

(* [rewrite_regs_block colors b ()] applies the register substitution
   [colors] in place to block [b]: its type (when it has one), its code, and
   its use, def, live-in and live-out register sets.  The trailing unit
   argument is the accumulator expected by [Cfg.fold]. *)
let rewrite_regs_block colors b () =
  let recolor regs = rsub_regset colors regs in
  debugdo "rewrite_regs_block";
  (match b.con with
     Some c -> b.con <- Some (rsub_con colors c)
   | None -> ());
  b.code <- Array.map (rsub_instr colors) b.code;
  b.use <- recolor b.use;
  b.def <- recolor b.def;
  b.live_in <- recolor b.live_in;
  b.live_out <- recolor b.live_out

(* [rsub_id_reg_dict colors d] rebuilds the identifier-keyed dictionary [d]
   with the substitution applied to every register set it holds. *)
let rsub_id_reg_dict colors d =
  let add_entry id regs acc = Dict.insert acc id (rsub_regset colors regs) in
  Dict.fold_dict add_entry d (Dict.empty Identifier.id_compare)

(* Performs a register substitution on a control flow graph, given a mapping *)
(* from registers to registers.  Every block is rewritten in place, as are   *)
(* the graph's args, rets and regs dictionaries; the graph is returned.      *)
let rewrite_regs cfg colors =
  let recolor_dict d = rsub_id_reg_dict colors d in
  print_string "REWRITING REGISTERS\n";
  Format.print_flush ();
  Cfg.fold (rewrite_regs_block colors) cfg ();
  cfg.args <- recolor_dict cfg.args;
  cfg.rets <- recolor_dict cfg.rets;
  cfg.regs <- recolor_dict cfg.regs;
  cfg

(* [collect_rets cfg f d] extends [d] with a binding to [Eax] for every
   register in the return set of function [f]. *)
let collect_rets cfg f d =
  let bind_to_eax r acc = Dict.insert acc r Eax in
  Set.fold bind_to_eax (Cfg.get_rets cfg f) d

(* Given a control flow graph in which every function has at most one return *)
(* value passed in a (virtual) register, returns the control flow graph with *)
(* Eax used to return values instead.                                        *)
let rewrite_returns cfg =
  (* Map every register in some procedure's return set to Eax. *)
  let no_colors = Dict.empty compare_regs in
  let colors = Set.fold (collect_rets cfg) cfg.procs no_colors in
  debugdo "rewrite_returns";
  rewrite_regs cfg colors

(* [cleanup_block b ()] rebuilds [b.code] in place, instruction by
   instruction.  It is driven by marker comments found in the code
   ("callstart", "live alloc done", "live save begin", "callend",
   "live restore done", "arg alloc done", "ret alloc done" and
   "delete next instr"); those markers are consumed and not re-emitted.
   Visible effects:
     - an instruction preceded by a "delete next instr" marker is dropped
       where one of the flag-specific branches below handles it;
     - the Esp subtraction that follows "callstart" is shrunk by 4 bytes for
       each "delete next instr" marker remaining in the block;
     - register saves to Esp-relative slots seen after "live save begin" are
       re-emitted at recomputed offsets;
     - register-to-itself moves, Nops, and additions/subtractions of 0 are
       removed (a self-move carrying coercions becomes a Coerce);
     - the coercion list of a Call is passed through [fix_cl] together with
       the record of which save slots were dropped.
   The trailing unit argument is the accumulator expected by [Cfg.fold]. *)
let cleanup_block b () =
  let xa = Xarray.create (Array.length b.code) Nop in
  let emit i = Xarray.add xa i in
  let len = Array.length b.code in

  let found_callstart = ref false in
  let found_callend = ref false in
  let storing_live_regs = ref false in
  let restoring_live_regs = ref false in
  
  let delete_next_instr = ref false in
  let save_live_counter = ref 0 in
  let stack_slots_to_elim = ref [] in
  let stack_offs = ref 0 in
  let live_offs = ref 0 in

  (* Handles the instruction [i] found at position [index] of [b.code].
     The first conditional deals with the flag-specific cases and raises
     [Done] once it has fully handled [i]; otherwise control falls through
     to the general match below it. *)
  let cleanup_instr index i = 
    try begin
      (if !found_callstart then
      	(match i with
	  Comment s -> if s = "live alloc done" then
	    found_callstart := false
	  else failwith "rewrite.ml: cleanup bad comment after callstart"
      	| ArithBin (Sub, Reg Esp, Immed slots) -> 
	    let num_deleted_slots = ref 0 in
	  (* Walk over the rest of the code in this block and count the number *)
	  (* of "delete next instr" comments there are.  Subtract 4 * num from *)
	  (* slots.                                                            *)
	    begin
	      for j=index to (len - 1) do
	      	match b.code.(j) with
		  Comment s -> 
		    if s = "delete next instr" then 
		      incr num_deleted_slots
		    else ()
	      	| _ -> ()
	      done;
	      let new_slots = slots - (4*(!num_deleted_slots)) in
	      (if new_slots > 0 then
		emit (ArithBin (Sub, Reg Esp, (Immed new_slots)))
	      else ());
	      raise Done;
	    end
      	| _ -> failwith "rewrite.ml: cleanup bad instr after callstart")
      (* NOTE(review): storing_live_regs is set by "live save begin" and is  *)
      (* never cleared in this function, so once it is set the               *)
      (* restoring_live_regs branch below is not reached for the rest of     *)
      (* this block.  Presumably "callend" only occurs in a later block --   *)
      (* confirm.                                                            *)
      else if !storing_live_regs then
      	(match i with
	  Mov (Prjr((Esp, cl1), bytes), (Immed imm, cl2)) ->
	    begin
    	  (* A stack slot is being filled from an immediate: drop it if it   *)
	  (* was flagged for deletion, and record in stack_slots_to_elim     *)
	  (* (in order) whether the slot was dropped.                        *)
	      (if !delete_next_instr then
	      	begin
	      	  delete_next_instr := false;
	      	  stack_slots_to_elim := !stack_slots_to_elim@[true];
	      	end
	      else
		begin
		  emit i;
		  stack_slots_to_elim := !stack_slots_to_elim@[false];
		end);
	      raise Done;
	    end
	| Mov (Prjr((Esp, cl1), bytes), (Reg r, cl2)) ->
	    begin
	  (* A stack slot is being filled from a register.  A surviving save *)
	  (* is re-emitted at slot save_live_counter + live_offs, so the     *)
	  (* surviving saves occupy consecutive slots.                       *)
	      (if !delete_next_instr then
	      	begin
		  delete_next_instr := false;
		  stack_slots_to_elim := !stack_slots_to_elim@[true];
	      	end
	      else
	      	begin
	      	  emit (Mov (Prjr((Esp, cl1), (4*(!save_live_counter + 
						  !live_offs))), 
			     (Reg r, cl2)));
		  incr save_live_counter;
	      	  stack_slots_to_elim := !stack_slots_to_elim@[false];
	      	end);
	      raise Done;
	    end
      	| _ -> ())
      else if !restoring_live_regs then
      	(match i with
	  ArithBin (Add, Reg Esp, Immed imm) -> 
	    begin
	      (if !delete_next_instr then
	      	delete_next_instr := false
	      else
	      	emit i);
	      raise Done;
	    end
      	| _ -> ())
      else ());
      (* General handling: update the state from marker comments, track the  *)
      (* stack offset (in 4-byte slots), and drop redundant instructions.    *)
      (match i with
      	Comment s -> if s = "callstart" then
	  found_callstart := true
	else if s = "live save begin" then
	  begin
	    live_offs := !stack_offs;
	    storing_live_regs := true;
	    save_live_counter := 0;
	  end
	else if s = "callend" then
	  restoring_live_regs := true
	else if s = "live restore done" then
	  restoring_live_regs := false
	else if s = "arg alloc done" then
	  ()
	else if s = "ret alloc done" then
	  ()
	else if s = "live alloc done" then
	  stack_offs := 0
	else if s = "delete next instr" then
	  delete_next_instr := true
	else emit i
      |	Call (op, cl) -> 
	  emit (Call (op, (fix_cl Talctxt.empty_ctxt cl 
			     (elim_stack_slots !stack_slots_to_elim))))
      |	Mov (Reg r1, (Reg r2, cl)) ->
	  (* A move from a register to itself is dropped; if it carried      *)
	  (* coercions they are kept as a Coerce on that register.           *)
	  if (compare_regs r1 r2) = 0 then 
	    match cl with
	      [] -> ()
	    | _ -> emit (Coerce (r1, cl))
	  else emit i
      |	Push _ -> begin incr stack_offs; emit i end
      |	Pop _ -> begin decr stack_offs; emit i end
      |	ArithBin (Add, Reg Esp, Immed index) ->
	  begin
	    stack_offs := !stack_offs - (index/4);
	    if index > 0 then emit i else ();
	  end
      |	ArithBin (Sub, Reg Esp, Immed index) ->
	  begin
	    stack_offs := !stack_offs + (index/4);
	    if index > 0 then emit i else ();
	  end
      | ArithBin (Sub, Reg _, Immed 0) -> ()
      | ArithBin (Add, Reg _, Immed 0) -> ()
      | Nop -> ()
      | _ -> emit i);
  end with Done -> ()
  in begin
    Array.iteri cleanup_instr b.code;
    b.code <- Xarray.to_array xa;
  end

(* Removes redundant Mov instructions in the program, replacing them with *)
(* coercions if necessary.  Each block is cleaned in place by             *)
(* cleanup_block and the graph itself is returned.                        *)
let cleanup cfg =
  let () = print_string "CLEANUP\n" in
  let () = Format.print_flush () in
  let () = Cfg.fold cleanup_block cfg () in
  cfg

(* Information about the spills for a function in this iteration of register  *)
(* allocation.  Information about previous iterations is found in cfg.rewrite *)
(* and is copied into the old_* fields by get_spill_info.                     *)
type fun_spill_info = {
    spilled_regs: reg Set.set;     (* The set of simple registers spilled in f *)
    slots: (reg, int) Dict.dict;   (* A map from simple spilled regs to slots *)
    num_spills: int;               (* The number of new simple spills *)
    spilled_args: reg list;        (* The list of newly spilled args (in order) *)
    num_sp_args: int;              (* List.length spilled_args *)
    spilled_rets: reg list;        (* The list of newly spilled rets (in order) *)
    num_sp_rets: int;              (* List.length spilled_rets *)
    old_spills: int;               (* Number of previous simple spills *)
    old_sp_args: int;              (* Number of previously spilled args *)
    old_sp_rets: int               (* Number of previously spilled rets *)
  } 

(* The spill information with nothing spilled: empty register set and slot
   map, empty argument and return lists, and every counter at zero. *)
let empty_spill_info =
  let no_regs = Set.empty compare_regs in
  let no_slots = Dict.empty compare_regs in
  { spilled_regs = no_regs;
    slots = no_slots;
    num_spills = 0;
    spilled_args = [];
    num_sp_args = 0;
    spilled_rets = [];
    num_sp_rets = 0;
    old_spills = 0;
    old_sp_args = 0;
    old_sp_rets = 0 }

(* Formatter used by the printing helpers below. *)
let form = Format.std_formatter

(* [ppr reg] prints the register [reg] on [form] with the standard
   pretty-printing options. *)
let ppr reg =
  Talpp.print_reg form Talpp.std_options reg

(* [print_info info] prints every field of the spill record [info] on
   [form], inside a vertical box opened with indentation 5. *)
let print_info info =
  let text s = Format.pp_print_string form s in
  let cut () = Format.pp_print_cut form () in
  (* A string followed by a break hint. *)
  let heading s = begin text s; cut () end in
  (* A label with an integer appended, followed by a break hint. *)
  let counter label n = heading (label ^ (string_of_int n)) in
  (* One "reg --> slot" entry of the slot map. *)
  let slot_entry r slot () = begin
    ppr r;
    heading (" --> " ^ (string_of_int slot))
  end in
  Format.print_flush ();
  text " : Spill Info";
  Format.pp_print_newline form ();
  Format.pp_open_vbox form 5;
  cut ();
  heading "spilled regs: ";
  Set.app ppr info.spilled_regs;
  cut ();
  heading "slots: ";
  Dict.fold_dict slot_entry info.slots ();
  counter "num_spills: " info.num_spills;
  heading "spilled_args:";
  List.iter ppr info.spilled_args;
  cut ();
  heading "spilled_rets:";
  List.iter ppr info.spilled_rets;
  cut ();
  counter "old_spills: " info.old_spills;
  counter "old_sp_args: " info.old_sp_args;
  counter "old_sp_rets: " info.old_sp_rets;
  Format.pp_close_box form ()

(* [print_finfo rwi] prints, for each function identifier bound in [rwi],
   the identifier followed by its spill information, or " : No Spills" when
   the entry is [None]. *)
let print_finfo rwi =
  let print_entry f info () =
    Format.pp_print_string form (Identifier.id_to_string f);
    (match info with
       None -> Format.pp_print_string form " : No Spills"
     | Some inf -> print_info inf);
    Format.pp_print_newline form ();
    Format.print_flush ()
  in
  Dict.fold_dict print_entry rwi ()

(* [cc_spilled_reg ifg f r l] conses [r] onto [l] when
   [Regifg.reg_cc_spilled ifg r f] holds, and is [l] otherwise. *)
let cc_spilled_reg ifg f r l =
  match Regifg.reg_cc_spilled ifg r f with
    true -> r :: l
  | false -> l

(* Gets the information about spills for a given function and inserts it   *)
(* into a dictionary.  dict : (identifier, fun_spill_info option) Dict.dict *)
(* The entry is None when Regifg.NoSpills is raised while gathering it.    *)
let get_spill_info cfg ifg f dict =
  try
    (* Interference graph of the simple spilled registers mentioned in f,   *)
    (* together with the set of those registers.                            *)
    let (sp_ifg, spilled) = Regifg.spill_ifg ifg (Cfg.get_regs cfg f) in
    (* Coalesce the live ranges that don't conflict to save stack space,    *)
    (* preferring a coalescing step whenever one is available.              *)
    while Regifg.simplifiable sp_ifg || Regifg.coalescable sp_ifg do
      if Regifg.coalescable sp_ifg
      then Regifg.coalesce_spills sp_ifg
      else Regifg.simplify_spills sp_ifg
    done;
    (* Slot assignment for the simple spills and the number of slots used.  *)
    let (slot_map, slot_count) = Regifg.color_spills sp_ifg in
    let cc_spilled regs = Set.fold (cc_spilled_reg ifg f) regs [] in
    let args = cc_spilled (Cfg.get_args cfg f) in
    let rets = cc_spilled (Cfg.get_rets cfg f) in
    (* Counts carried over from earlier rewrites of this function, if any.  *)
    let (prev_spills, prev_args, prev_rets) =
      match Cfg.get_rewrite_info cfg f with
        None -> (0, 0, 0)
      | Some rwi -> (rwi.sp_slots, rwi.sp_args, rwi.sp_rets)
    in
    let fs =
      { spilled_regs = spilled;
        slots = slot_map;
        num_spills = slot_count;
        spilled_args = args;
        num_sp_args = List.length args;
        spilled_rets = rets;
        num_sp_rets = List.length rets;
        old_spills = prev_spills;
        old_sp_args = prev_args;
        old_sp_rets = prev_rets }
    in
    Dict.insert dict f (Some fs)
  with Regifg.NoSpills ->
    (* There were no spills in this function, so rewrite is a no-op *)
    debugdo ("No Spills in: " ^ (Identifier.id_to_string f));
    Dict.insert dict f None

(* A placeholder block with no type and no code, labelled with a fresh
   "EmptyBlock" identifier (passed both as the block label and as the third
   argument of [Cfg.make_block]).  It is the initial accumulator of the edge
   folds below and the initial value of the callee reference in
   [rewrite_spills_b].

   The empty code array is written as a literal rather than with
   [Array.create 0 Nop]: [Array.create] is deprecated in favour of
   [Array.make], and the literal has the same value. *)
let empty_block =
  let id = Identifier.id_new "EmptyBlock" in
  Cfg.make_block id None id [||]

(* [get_callee cfg edgeset] is the block targeted by a [CallEdge] in
   [edgeset] (the one visited last by the fold, should there be several),
   or [empty_block] when there is none. *)
let get_callee cfg edgeset =
  let pick (_, et, t) found =
    match et with
      CallEdge -> Cfg.get_block cfg t
    | _ -> found
  in
  Set.fold pick edgeset empty_block

(* [get_block_after_call cfg edgeset] is the block targeted by the
   [CallSequence] or [UnknownCall] edge visited last in [edgeset], or
   [empty_block] when [edgeset] is empty.  It fails on any edge of another
   kind. *)
let get_block_after_call cfg edgeset =
  let pick (_, et, t) _ =
    match et with
      (CallSequence _ | UnknownCall _) -> Cfg.get_block cfg t
    | _ -> failwith "rewrite.ml: get_block_after_call no such block"
  in
  Set.fold pick edgeset empty_block

(* [get_sequence cfg edgeset] is [Some t] where [t] is the target of a
   [Sequence] edge in [edgeset] (the one visited last by the fold), or
   [None] when there is no such edge. *)
let get_sequence cfg edgeset =
  let pick (_, et, t) found =
    match et with
      Sequence -> Some t
    | _ -> found
  in
  Set.fold pick edgeset None

(* [get_reg_index r list] is the index of the first occurrence of [r]
   (under [compare_regs]) in [list], counted from the end: the last element
   has index 0 and the first has index [List.length list - 1].  Fails when
   [r] does not occur in [list]. *)
let get_reg_index r list =
  let rec search remaining i =
    match remaining with
      r' :: rest ->
        if compare_regs r r' = 0 then i else search rest (i - 1)
    | [] -> failwith "rewrite.ml: get_reg_index reg not in list"
  in
  search list (List.length list - 1)

(* rewrite_spills_b : cf_graph ->                                              *)
(*                    (identifier, fun_spill_info option) Dict.dict ->         *)
(*                    fun_spill_info -> Talctxt.ctxt -> cf_block -> unit       *)
(*                                                                             *)
(* Rewrites the code of [block] for the spills described by [fs]:              *)
(*   - spilled registers used by an instruction are loaded into fresh virtual  *)
(*     registers before it, and spilled registers it defines are stored back   *)
(*     to their slots after it;                                                *)
(*   - Esp-relative offsets of register loads and stores are renumbered for    *)
(*     the slots added in this iteration (see rewrite_stack_index);            *)
(*   - call sequences, delimited by marker comments, get their return and      *)
(*     argument allocations resized from the callee's spill information;       *)
(*   - the types of jump/branch/fallthrough targets are rewritten, and a       *)
(*     Coerce on Esp is emitted for stack slots that had to become junk.       *)
(* Every emitted instruction is also passed through Talverify.verify_instr to  *)
(* keep a typing context up to date.  After the block is rewritten, the        *)
(* function recurses into the successors reached by Jump, Branch, Sequence,    *)
(* UnknownCall and CallSequence edges.  A block whose [visited] flag is        *)
(* already set is left untouched.  [info] gives the spill information of each  *)
(* function and is used to look up callees.                                    *)

let rec rewrite_spills_b cfg info fs ctxt block =
  (* Output buffer for the rewritten code of this block. *)
  let xa = Xarray.create (Array.length block.code) Nop in 
  let current_ctxt = ref (block_type_to_ctxt ctxt block.con) in
  (* Debugging aid: prints an instruction.  Only referenced from the         *)
  (* commented-out call inside [emit].                                       *)
  let ppi i = begin
    Talpp.print_instruction Format.std_formatter Talpp.std_options i;
    Format.pp_force_newline Format.std_formatter ();
    Format.pp_print_flush Format.std_formatter ();
  end
  in
  (* Appends [i] to the output after advancing the typing context over it.   *)
  (* Terminal_Jmp and Fall_Thru from the verifier leave the context as is.   *)
  let emit i = begin
    (try 
(*      ppi i; *)
      current_ctxt := Talverify.verify_instr !current_ctxt i
    with Talverify.Terminal_Jmp -> ()
    | Talverify.Fall_Thru _ -> ());
    Xarray.add xa i;
  end in
  (* Map from spilled registers to the fresh virtual register currently      *)
  (* standing in for them; applied to instructions by [subst_emit].          *)
  let reg_subst = ref (Dict.empty compare_regs) in
  let subst_emit i = emit (rsub_instr !reg_subst i) in
  let get_reg r = Dict.lookup !reg_subst r in
  let set_reg r1 r2 = reg_subst := Dict.insert !reg_subst r1 r2 in
  (* Creates a fresh virtual register with the same source name as [r],      *)
  (* records it as the substitute for [r] and returns it.                    *)
  let new_virt_reg r =
    match r with
      Virt id ->
	let v = Virt (Identifier.id_new (Identifier.id_to_source id)) in
	begin
	  set_reg r v;
	  v
	end
    | _ -> failwith "rewrite.ml: new_virt_reg called on non-virtual register"
  in
  (* State flags driven by the marker comments found in the code. *)
  let found_callstart = ref false in
  let found_callend = ref false in
  let allocating_callee_rets = ref false in
  let allocating_callee_args = ref false in

  let storing_live_regs = ref false in
  let restoring_live_regs = ref false in

  let stack_after_live_alloc = ref empty_stack in
  (* Slots pushed/allocated since the last reset, counted in 4-byte units.   *)
  let stack_offs = ref 0 in
  (* Slots allocated right after "callstart", counted in 4-byte units.       *)
  let saved_live_regs = ref 0 in

  let callee = ref empty_block in
  let callee_fs = ref empty_spill_info in
  let unknown_call = ref false in

  (* Loads the simple spill [r] from its slot into a fresh virtual register. *)
  let load_simple_spill r =
    let slot = Dict.lookup fs.slots r in
    let v = new_virt_reg r in
    emit (mov_reg_slot v (slot + !stack_offs))
  in
  (* Loads the spilled argument [r] into a fresh virtual register.           *)
  let load_spilled_arg r =
    let index = get_reg_index r fs.spilled_args in
    let v = new_virt_reg r in
    emit (mov_reg_slot v 
	    (index + 1 + fs.old_spills + fs.num_spills + !stack_offs))
  in
  (* Stores the current substitute of the simple spill [r] to its slot.      *)
  let store_simple_spill r = 
    let slot = Dict.lookup fs.slots r in
    let v = try get_reg r with Dict.Absent -> 
      failwith "rewrite.ml: store_simple_spill r not in substitution"
    in
    emit (mov_slot_reg (slot + !stack_offs) v)
  in
  (* Stores the current substitute of the spilled return [r] to its slot.    *)
  let store_spilled_ret r =
    let index = get_reg_index r fs.spilled_rets in
    let v = try get_reg r with Dict.Absent ->
      failwith "rewrite.ml: store_spilled_ret r not in substitution"
    in
    emit (mov_slot_reg (index + fs.old_sp_args + fs.num_sp_args + 1 + 
			fs.old_spills + fs.num_spills + !stack_offs) v)
  in 
  (* Stores the current substitute of [r], a spilled argument of the callee, *)
  (* to the callee argument slot given by its index.                         *)
  let store_callee_arg r =
    let index = get_reg_index r (!callee_fs).spilled_args in
    let v = try get_reg r with Dict.Absent ->
      failwith "rewrite.ml: store_spilled_arg r not in substitution"
    in
    emit (mov_slot_reg index v)
  in
  let doing_call = ref false in
  (* Renumbers a slot index of this function's own frame (no call sequence   *)
  (* in progress) to make room for the slots added in this iteration.        *)
  let rewrite_stack_index_before_callstart old_index = 
    (* Check to see if it was an old simple spill or the return address.       *)
    if old_index <= fs.old_spills then
      old_index + fs.num_spills
    (* Check to see if it was an old spilled arg.                              *)
    else if old_index < fs.old_spills + 1 + fs.old_sp_args then
      old_index + fs.num_spills + fs.num_sp_args
    (* Check to see if it was an old spilled ret or it was using fixed call    *)
    else 
      old_index + fs.num_spills + fs.num_sp_args + fs.num_sp_rets
  in
  (* Renumbers a slot index, taking into account the slots that a call       *)
  (* sequence in progress has placed on top of this function's frame.        *)
  let rewrite_stack_index old_index =
    if !doing_call then 
      if !unknown_call then
	if old_index < !stack_offs + !saved_live_regs then
	  old_index
	else 
	  let index_before_callstart = old_index 
	      - (!stack_offs + !saved_live_regs) 
	  in
	  (rewrite_stack_index_before_callstart index_before_callstart)
	    + !stack_offs + !saved_live_regs
      else
      let cfs = !callee_fs in
      (* In this case, the slots for saved live-registers and callee's old     *)
      (* spilled args and rets have been allocate already.                     *)
      (* The number of callee args and rets for unknown calls is stored in     *)
      (* stack_offs.                                                           *)

      (* Check to see if it was an old callee arg.                             *)
      if old_index < cfs.old_sp_args then
	old_index + cfs.num_sp_args
      (* Check to see if it was an old callee ret.                             *)
      else if old_index < cfs.old_sp_args + cfs.old_sp_rets then
	old_index + cfs.num_sp_args + cfs.num_sp_rets
      (* Check to see if it was a spilled live register.                       *)
      else if old_index < cfs.old_sp_args + cfs.old_sp_rets 
	                  + !saved_live_regs then
	old_index + cfs.num_sp_args + cfs.num_sp_rets
      else 
	(* The index refers to the caller's own frame.  Strip the old callee   *)
	(* args, old callee rets and saved live registers, renumber within the *)
	(* caller's frame, then add back the new sizes of those areas: old and *)
	(* new callee args, old and new callee rets, and the saved live        *)
	(* registers.  (This used to add num_sp_rets twice and omit            *)
	(* old_sp_rets, which is inconsistent with the amount subtracted.)     *)
      	let index_before_callstart = old_index - 
	    (cfs.old_sp_args + cfs.old_sp_rets + !saved_live_regs)
      	in
      	(rewrite_stack_index_before_callstart index_before_callstart)
	  + cfs.old_sp_args + cfs.old_sp_rets
	  + cfs.num_sp_args + cfs.num_sp_rets
	  + (!saved_live_regs)
    else
      rewrite_stack_index_before_callstart old_index
  in

  (* Rewrites the type of the block labelled [label], the target of a        *)
  (* control transfer carrying the coercions [cl].  Returns the list of      *)
  (* stack slots that had to be turned into junk (empty unless the target    *)
  (* had already been visited).                                              *)
  let rewrite_block_type cl label =
    (* If the label isn't in the control flow graph, it must be external, so   *)
    (* no rewriting is necessary.                                              *)
    try
      let block = Cfg.get_block cfg label in
      match block.con with
	None -> [] (* This block has no type. *)
      |	Some c ->
	  (* First build up a substitution of type variables instantiated by   *)
          (* the variables in cl.  tvar_subst maps id's to type variables.     *)
	  let tvar_subst = make_var_subst c cl in
	  let stack = Talverify.current_stack_con !current_ctxt in
	  let new_stack = Talcon.substs tvar_subst stack in
	  let new_rs = delete_regset_from_rs fs.spilled_regs (get_rs c) in
	  if block.visited then
	    (* This block's type has already been rewritten.  Compare the new  *)
	    (* version to the old, changing incompatible slots to junk.        *)
	    let old_stack = get_stack c in
	    let (junk_slots, new_stack) = 
	      compare_stacks !current_ctxt new_stack old_stack 
	    in 
	    let new_rs = set_stack_con new_rs new_stack in
	    let new_con = set_code c (code_con new_rs) in
	    begin
	      block.con <- Some new_con;
	      current_ctxt := Talctxt.set_val !current_ctxt label 
		   (snd(Talcon.check !current_ctxt new_con));
	      junk_slots
	    end
	  else
	    let new_rs = set_stack_con new_rs new_stack in
	    let new_con = set_code c (code_con new_rs) in
	    begin
	      block.con <- Some new_con;
	      current_ctxt := Talctxt.set_val !current_ctxt label 
		   (snd(Talcon.check !current_ctxt new_con));
	      []
	    end
    with Failure _ -> []
  in
  (* Emits the control-transfer instruction [i] to [target], preceded by a   *)
  (* Coerce on Esp for each stack slot the target's type now treats as junk. *)
  let rewrite_control_flow_i i cl target =
    let junk_slots = rewrite_block_type cl target in
    let coercion = List.map (fun i -> Slot (i*4,4)) junk_slots in
    begin
      if not ((List.length junk_slots) = 0) then
	emit (Coerce (Esp, coercion))
      else ();
      subst_emit i;
    end
  in
  (* Rewrites one instruction.  The first conditional handles the states set *)
  (* up by the call-sequence markers and raises Done when it has fully dealt *)
  (* with [i]; otherwise spilled uses are loaded, [i] is rewritten by the    *)
  (* main match, and spilled definitions are stored back.                    *)
  let rewrite_spills_i i =
    let (def, use) = Cfginstr.def_use i in
    try begin
      (if !found_callstart then
	(match i with
	  Comment s -> if s = "live alloc done" then
	    begin
	      (* Save the current stack so we can compute the number of args   *)
	      (* and rets for the function call.                               *)
	      stack_after_live_alloc := 
		 Talverify.current_stack_con !current_ctxt;
	      stack_offs := 0;
	      found_callstart := false;
	      allocating_callee_rets := true;
	    end
	  else failwith "rewrite.ml: Bad comment after callstart"
	| ArithBin (Sub, Reg Esp, Immed slots) -> 
	    begin
	      subst_emit i;
	      saved_live_regs := slots/4;
	      raise Done;
	    end
	| _ -> failwith "rewrite.ml: Bad instruction after callstart")
      else if !allocating_callee_rets then
	(match i with
	  Comment s -> if s = "ret alloc done" then
	    begin
	      allocating_callee_rets := false;
	      allocating_callee_args := true;
	    end
	  else failwith "rewrite.ml: Bad comment after live alloc done"
	| ArithBin (Sub, Reg Esp, Immed _) ->
	    begin
	      (* Replace the allocation by one sized for the callee's old and  *)
	      (* newly spilled rets.                                           *)
	      emit (stack_alloc 
		      (!callee_fs.old_sp_rets + !callee_fs.num_sp_rets));
	      raise Done;
	    end
	| _ -> 
	    begin
	      (* No ret allocation was present: insert one, with its marker,   *)
	      (* and let [i] be processed normally below.                      *)
	      allocating_callee_rets := false;
	      allocating_callee_args := true;
	      emit (stack_alloc 
		      (!callee_fs.old_sp_rets + !callee_fs.num_sp_rets));
	      emit (Comment "ret alloc done");
	    end)
      else if !allocating_callee_args then
	(match i with
	  Comment s -> if s = "arg alloc done" then
	    allocating_callee_args := false
	  else 
	    begin
	      allocating_callee_args := false;
	      emit (stack_alloc 
		      (!callee_fs.old_sp_args + !callee_fs.num_sp_args));
	      emit (Comment "arg alloc done");
	    end
	| ArithBin (Sub, Reg Esp, Immed _) ->
	    begin
	      (* Replace the allocation by one sized for the callee's old and  *)
	      (* newly spilled args.                                           *)
	      emit (stack_alloc 
		      (!callee_fs.old_sp_args + !callee_fs.num_sp_args));
	      raise Done;
	    end
	| _ -> 
	    begin
	      allocating_callee_args := false;
	      emit (stack_alloc 
		      (!callee_fs.old_sp_args + !callee_fs.num_sp_args));
	      emit (Comment "arg alloc done");
	    end)
      else if !storing_live_regs then
	(match i with
	  Mov (Prjr((Esp, cl1), bytes), (Reg r, cl2)) ->
	  (* A stack slot is being filled from a register, update its offset *)
	    if Set.member fs.spilled_regs r then
	      (* We don't need to save this live register, so 'comment it out' *)
	      let index = (rewrite_stack_index (bytes/4))*4 in
	      begin
		emit (Comment "delete next instr");
		subst_emit (Mov (Prjr((Esp, cl1), index), (Immed 0, [])));
		raise Done;
	      end
	    else ()
	| _ -> ())
      else if !restoring_live_regs then
	(match i with
	  Pop (Reg r) ->
	    if Set.member fs.spilled_regs r then
	      (* [junk_reg] is not used below; the binding is kept so that the *)
	      (* call to Identifier.id_new still takes place.                  *)
	      let junk_reg = Virt (Identifier.id_new "junk") in
	      begin
		emit (Comment "delete next instr");
		emit (stack_free 1);
		raise Done;
	      end
	    else ()
	| _ -> ())
      else ());
      (* If any of the registers used by the instruction have been spilled, we *)
      (* need to load them from their spill locations.                         *)
      (* First load the simple spills                                          *)
      Set.app load_simple_spill (Set.intersect use fs.spilled_regs);
      (* Then load any spilled args                                            *)
      Set.app load_spilled_arg 
	(Set.intersect use (Set.from_list compare_regs fs.spilled_args));
      (* Generate new temporaries for spilled virtual registers defined in i   *)
      Set.app new_virt_reg (Set.diff (Set.intersect fs.spilled_regs def) use); 
      Set.app new_virt_reg (Set.intersect 
			      (Set.from_list compare_regs fs.spilled_rets) def);
      Set.app new_virt_reg (Set.intersect
			      (Set.from_list compare_regs 
				 (!callee_fs).spilled_args)
			      def);
      (match i with
	Comment s -> if s = "callstart" then
	  begin
	    debugdo "found callstart";
	    found_callstart := true;
	    doing_call := true;
	    callee := get_callee cfg block.succ;
	    callee_fs := 
	       (try
		 (match Dict.lookup info (!callee).fun_lab with
		   None -> empty_spill_info
		 | Some fs' -> fs')
	       (* If it's an unknown call then we get empty_spill_info          *)
	       with Dict.Absent -> begin
		 unknown_call := true;
		 empty_spill_info
	       end);
	    subst_emit i;
	  end
	else if s = "live save begin" then
	  begin
	    storing_live_regs := true;
	    subst_emit i;
	  end
	else if s = "callend" then 
	  begin
	    restoring_live_regs := true;
	    subst_emit i;
	  end
	else if s = "live restore done" then
	  begin
	    stack_offs := 0;
	    restoring_live_regs := false;
	    subst_emit i;
	  end
	else subst_emit i
      (* The following instructions modify the stack offset                     *)
      |	Push _ -> begin incr stack_offs; subst_emit i end
      |	Pop _ -> begin decr stack_offs; subst_emit i end
      |	ArithBin (Add, Reg Esp, Immed index) ->
	  begin
	    stack_offs := !stack_offs - (index/4);
	    subst_emit i;
	  end
      |	ArithBin (Sub, Reg Esp, Immed index) ->
	  begin
	    stack_offs := !stack_offs + (index/4);
	    subst_emit i;
	  end
      |	Call (op, cl) -> 
	  (* For calls, we have to instantiate the type application hiding the *)
	  (* spilled live registers and the local spills.  To do this, we get  *)
	  (* the current stack and remove the argument and return slots.  The  *)
	  (* number of argument and return slots is taken from stack_offs,     *)
	  (* which was reset at "live alloc done".                             *)
	  let current_stack = Talverify.current_stack_con !current_ctxt in
(*	  let current_stack_size = stack_size current_stack in
	   let before_callee_stack_size = stack_size !stack_after_live_alloc in
	   let callee_slots = current_stack_size - before_callee_stack_size in *)
	  let callee_slots = !stack_offs in
	  if callee_slots < 0 then
	    failwith "rewrite.ml: rewrite_spills_i callee popped caller slots"
	  else
	    (* Pop the callee's args and rets off the current stack            *)
	    let instantiate_stack = pop_stack current_stack callee_slots in
	    subst_emit (Call (op, fix_cl !current_ctxt cl
				(fun _ -> instantiate_stack)))
      |	Mov (Reg r, (Prjr((Esp, cl1), bytes), cl2)) ->
	  (* A register is being loaded from the stack, we have to update its  *)
	  (* offset. *)
	  let index = (rewrite_stack_index (bytes/4))*4 in
	  subst_emit (Mov (Reg r, (Prjr((Esp, cl1), index), cl2)))
      |	Mov (Prjr((Esp, cl1), bytes), (Reg r, cl2)) ->
	  (* A stack slot is being filled from a register, update its offset *)
	  let index = (rewrite_stack_index (bytes/4))*4 in
	  subst_emit (Mov (Prjr((Esp, cl1), index), (Reg r, cl2)))
      |	Coerce (Esp, cl) -> () (* FIX THIS TOO TOO *)
      |	Loopd _ -> failwith "rewrite.ml: loopd not yet supported"
      |	Retn _ -> 
	  begin
	  (* First emit code to pop the simple spills from this iteration.      *)
	    subst_emit (stack_free fs.num_spills);
	  (* Now rewrite the return to pop the spilled arguments.               *)
	    let args = fs.num_sp_args + fs.old_sp_args in
	    subst_emit (if args = 0 then (Retn None) else (Retn (Some (args*4))))
	  end
      |	Jcc (_, (target, cl)) -> rewrite_control_flow_i i cl target
      |	Jecxz (target, cl) -> rewrite_control_flow_i i cl target
      |	Jmp (Addr target, cl) -> rewrite_control_flow_i i cl target
      |	Bexn (r, _, (target, cl)) -> rewrite_control_flow_i i cl target
      |	Btagi (_, _, (target, cl), _) -> rewrite_control_flow_i i cl target
      |	Btagvar (_, _, _, (target, cl), _) -> rewrite_control_flow_i i cl target
      |	Fallthru cl ->
	  let target = match get_sequence cfg block.succ with
	    Some t -> t
	  | None -> failwith "rewrite.ml: rewrite_spills_i Fallthru no sequence"
	  in
	  let coercions = List.map (fun con -> Tapp con) cl in
	  rewrite_control_flow_i i coercions target
      |	_ -> subst_emit i);

      (* Here we have to store any virtual registers defined by the instruction *)
      Set.app store_simple_spill (Set.intersect def fs.spilled_regs);
      Set.app store_spilled_ret (Set.intersect 
			  (Set.from_list compare_regs fs.spilled_rets) def);
      Set.app store_callee_arg (Set.intersect
			  (Set.from_list compare_regs (!callee_fs).spilled_args)
			  def);
    end with Done -> ()
  in
  (* Continues the rewrite in the successor reached by an intra-procedural   *)
  (* edge, starting from the typing context reached at the end of this       *)
  (* block.  Return and call edges are not followed.                         *)
  let call_rec (_, et, l) =
    let next_block = Cfg.get_block cfg l in
    match et with
      (Return | SelfTailCall | CallEdge | TailCall) -> ()
    | (Jump | Branch | Sequence | UnknownCall _ | CallSequence _) -> 
	rewrite_spills_b cfg info fs !current_ctxt next_block
  in
  begin
    debugdo ("rewriting spills in: "^
		  (Identifier.id_to_string block.lab));
    if block.visited then () else begin
      (* Mark the block before rewriting it so that cycles in the graph      *)
      (* terminate.                                                          *)
      block.visited <- true;
      Array.iter rewrite_spills_i block.code;
      block.code <- Xarray.to_array xa;
      Set.app call_rec block.succ;
    end
  end

(* add_header cfg f_entry n                                                     *)
(*                                                                              *)
(* Splits the entry block f_entry in two.  A freshly labelled block takes over  *)
(* f_entry's original code and all of its outgoing edges, while f_entry itself  *)
(* is reduced to a three-instruction header: a comment, stack_alloc n, and a    *)
(* Fallthru carrying the type arguments of f_entry.con.  A Sequence edge links  *)
(* the header to the new block, and f_entry.header is set so that a later pass  *)
(* can recognise the block (see update_header).                                 *)
let add_header cfg f_entry n =
  let header_code =
    [| Comment ("Header block for " ^ (Identifier.id_to_source f_entry.lab));
       stack_alloc n;
       Fallthru (get_type_arglist f_entry.con) |]
  in
  (* The new block gets a fresh identifier derived from the entry's source name *)
  let body_lab = Identifier.id_new (Identifier.id_to_source f_entry.lab) in
  let body_block =
    Cfg.make_block body_lab f_entry.con f_entry.lab f_entry.code
  in
  (* Re-root one outgoing edge of the entry block at the new block.             *)
  let move_edge (src, kind, dst) =
    Cfg.add_edge cfg (body_block.lab, kind, dst);
    Cfg.del_edge cfg (src, kind, dst)
  in
  Cfg.add_block cfg body_block;
  Set.app move_edge f_entry.succ;
  (* The header's only successor is the block now holding the original code.    *)
  Cfg.add_edge cfg (f_entry.lab, Sequence, body_block.lab);
  f_entry.code <- header_code;
  f_entry.header <- true

(* update_header f_entry n                                                      *)
(*                                                                              *)
(* Replaces the code of an existing header block with a new header that         *)
(* allocates n stack slots.  Fails with Failure if f_entry is not flagged as a  *)
(* header block.  The replacement code is built before the flag is checked.     *)
let update_header f_entry n =
  let refreshed =
    [| Comment ("Updated header block for " ^
                (Identifier.id_to_source f_entry.lab));
       stack_alloc n;
       Fallthru (get_type_arglist f_entry.con) |]
  in
  if f_entry.header then f_entry.code <- refreshed
  else failwith "rewrite.ml: update_header got non-header block"

(* rewrite_spills_f cfg info ctxt f                                             *)
(*                                                                              *)
(* Rewrites the function whose entry label is f.  Nothing happens when info     *)
(* maps f to None.  Otherwise the header block is created (no earlier spills)   *)
(* or updated (earlier spills exist) whenever this round introduced simple      *)
(* spills, and then the blocks are rewritten starting from the entry block.     *)
let rewrite_spills_f cfg info ctxt f =
  match Dict.lookup info f with
    None -> ()
  | Some fs ->
      let entry = Cfg.get_block cfg f in
      begin
        debugdo ("Rewriting spills in: " ^ (Identifier.id_to_string entry.lab));
        (* (new spills this round?, spills from an earlier round?)              *)
        (match (fs.num_spills > 0, fs.old_spills > 0) with
          (false, _) -> ()
        | (true, true) ->
            update_header entry (fs.num_spills + fs.old_spills)
        | (true, false) ->
            add_header cfg entry fs.num_spills);
        (* Now rewrite the code reachable from the entry block.                 *)
        rewrite_spills_b cfg info fs ctxt entry
      end

(* add_proc_to_context cfg proc ctxt                                            *)
(*                                                                              *)
(* Returns ctxt extended with a binding for the label proc.  The bound value    *)
(* is the second component of Talcon.check applied to the type of proc's        *)
(* block.  Fails with Failure when that block carries no type.                  *)
let add_proc_to_context cfg proc ctxt =
  match (Cfg.get_block cfg proc).con with
    Some c ->
      let (_, checked) = Talcon.check ctxt c in
      Talctxt.set_val ctxt proc checked
  | None -> failwith "rewrite.ml: add_proc_to_context proc has no type"

(* rewrite_function_type cfg info f                                             *)
(*                                                                              *)
(* Rebuilds the code type stored on f's entry block when f has spilled          *)
(* argument or return registers; otherwise (or when info maps f to None) the    *)
(* block is left untouched.  Fails with Failure if the block has no type.       *)
(*                                                                              *)
(* Return state: the spilled return registers are removed and their types are   *)
(* inserted into the return stack at position 0.                                *)
(* Entry state: the spilled argument registers are removed.  The new stack is   *)
(* the tail of the old entry stack with one junk slot per spilled return        *)
(* inserted at position fs.old_sp_args, the spilled argument types inserted at  *)
(* position 0, and the code type of the new return state consed on top.         *)
(* NOTE(review): the dropped head of the old entry stack is presumably the old  *)
(* return address -- confirm against the calling convention.                    *)
let rewrite_function_type cfg info f =
  (* Remove every register in the list regs from the register state rs.         *)
  let drop_regs regs rs =
    delete_regset_from_rs (Set.from_list compare_regs regs) rs
  in
  match Dict.lookup info f with
    None -> ()
  | Some fs ->
      if fs.num_sp_args <= 0 && fs.num_sp_rets <= 0 then ()
      else begin
        let entry = Cfg.get_block cfg f in
        let fn_type =
          match entry.con with
            None -> failwith "rewrite.ml: rewrite_function_type f has no type"
          | Some c -> c
        in
        (* Pieces of the existing type.                                         *)
        let entry_rs = get_rs fn_type in
        let return_rs = get_return_rs fn_type in
        let entry_stack = get_stack_con entry_rs in
        let return_stack = get_stack_con return_rs in
        (* Types of the spilled registers, plus one junk slot per spilled ret.  *)
        let ret_cons = List.map (rs_get_reg return_rs) fs.spilled_rets in
        let junk = List.map (fun _ -> junk_con) fs.spilled_rets in
        let arg_cons = List.map (rs_get_reg entry_rs) fs.spilled_args in
        (* New return state and the return address type derived from it.        *)
        let return_stack' = insert_list_stack return_stack 0 ret_cons in
        let return_rs' =
          set_stack_con (drop_regs fs.spilled_rets return_rs) return_stack'
        in
        let return_address = code_con return_rs' in
        (* New entry state.                                                     *)
        let padded_tail =
          insert_list_stack (tl_stack entry_stack) fs.old_sp_args junk
        in
        let entry_stack' =
          cons_stack return_address (insert_list_stack padded_tail 0 arg_cons)
        in
        let entry_rs' =
          set_stack_con (drop_regs fs.spilled_args entry_rs) entry_stack'
        in
        entry.con <- Some (set_code fn_type (code_con entry_rs'))
      end
      
(* update_spill_info cfg info f                                                 *)
(*                                                                              *)
(* Records on the cfg the cumulative spill counts for f: for slots, arguments   *)
(* and returns, the count from this round plus the count from earlier rounds.   *)
(* Does nothing when info maps f to None.                                       *)
let update_spill_info cfg info f =
  match Dict.lookup info f with
    None -> ()
  | Some fs ->
      let slots = fs.num_spills + fs.old_spills in
      let args = fs.num_sp_args + fs.old_sp_args in
      let rets = fs.num_sp_rets + fs.old_sp_rets in
      Cfg.set_rewrite_info cfg f {sp_slots=slots; sp_args=args; sp_rets=rets}

(* rewrite_spills cfg ifg                                                       *)
(*                                                                              *)
(* Adds spill code to the control flow graph cfg, driven by the information in  *)
(* the interference graph ifg, and returns cfg.  When ifg.spilled holds no      *)
(* live-ranges the graph is meant to come back unchanged.                       *)
(* A Dict.Absent escaping from the collection of spill information is turned    *)
(* into a Failure.                                                              *)
let rewrite_spills cfg ifg =
  let info =
    try
      Set.fold (get_spill_info cfg ifg) cfg.procs
        (Dict.empty Identifier.id_compare)
    with Dict.Absent -> failwith "rewrite.ml: rewrite_spills raised Absent"
  in
  (* Run action on every procedure of the graph.                                *)
  let each_proc action = Set.app action cfg.procs in
  print_string "REWRITE SPILLS\n";
  Format.print_flush ();
  if debug then print_finfo info;
  (* 1. Function types must reflect spilled calling-convention registers.       *)
  each_proc (rewrite_function_type cfg info);
  (* 2. Reset the visited flags before walking the blocks.                      *)
  Cfg.clear_visited_flags cfg;
  (* 3. Make the revised function types available in the context.              *)
  let ctxt = Set.fold (add_proc_to_context cfg) cfg.procs cfg.context in
  (* 4. Rewrite each function.                                                  *)
  each_proc (rewrite_spills_f cfg info ctxt);
  (* 5. Publish the updated spill information.                                  *)
  each_proc (update_spill_info cfg info);
  cfg

(* save_live_regs_b cfg block ctxt                                              *)
(*                                                                              *)
(* Rewrites the code of block so that registers live across a call are saved    *)
(* on the stack before the call and popped afterwards, then recurses into the   *)
(* successors reached by Jump, Branch, Sequence, UnknownCall and CallSequence   *)
(* edges.  Blocks whose visited flag is already set are skipped.                *)
(*                                                                              *)
(* Call sites are recognised by the marker instructions Comment "callstart"     *)
(* and Comment "callend".  Every emitted instruction except Fallthru and Retn   *)
(* is passed through Talverify.verify_instr so that the context handed to the   *)
(* successors reflects the rewritten code.                                      *)
let rec save_live_regs_b cfg block ctxt =
  (* Buffer collecting the rewritten instruction stream.                        *)
  let xa = Xarray.create (Array.length block.code) Nop in
  let current_ctxt = ref (block_type_to_ctxt ctxt block.con) in
  (* Debugging aid: prints an instruction.  Deliberately unused (hence the      *)
  (* leading underscore); enable it through the commented-out call in emit.     *)
  let _ppi i = begin
    Talpp.print_instruction Format.std_formatter Talpp.std_options i;
    Format.pp_force_newline Format.std_formatter ();
    Format.pp_print_flush Format.std_formatter ();
  end
  in
  (* Verify i against the current context, then append it to the output.  A     *)
  (* Terminal_Jmp from the verifier leaves the context unchanged.               *)
  let emit i = begin
(*    _ppi i; *)
    (try
      current_ctxt := Talverify.verify_instr !current_ctxt i
    with Talverify.Terminal_Jmp -> ());
    Xarray.add xa i;
  end in
  (* Stack words pushed since the last "callstart" marker.                      *)
  let num_args_on_stack = ref 0 in
  let regs_to_save = ref (Set.empty compare_regs) in
  (* Stack type recorded at the "callstart" marker, before arguments are pushed *)
  let stack_before_args = ref empty_stack in
  (* Recompute regs_to_save from the call edges entering block: the registers   *)
  (* live on entry to block minus those the call returns.  If no predecessor    *)
  (* edge is a call edge, the previous value is kept.                           *)
  (* NOTE(review): Esp is removed only in the CallSequence case; confirm that   *)
  (* the UnknownCall case can never leave Esp in the set.                       *)
  let set_regs_to_save block =
    let helper (_, et, _) =
      match et with
        CallSequence rets ->
          regs_to_save := Set.delete (Set.diff block.live_in rets) Esp
      | UnknownCall c ->
          regs_to_save := Set.diff block.live_in (get_ret_registers c)
      | _ -> ()
    in
    Set.app helper block.pred
  in
  (* Registers to save paired with their current types, as two parallel lists   *)
  (* built by consing during the fold.                                          *)
  let get_live_register_types () =
    let helper r (rlist, clist) =
      (r::rlist, (Talctxt.get_reg_con !current_ctxt r)::clist)
    in
    Set.fold helper !regs_to_save ([], [])
  in
  let save_live_regs_i i =
    match i with
      Comment s ->
        if s = "callend" then begin
          emit i;
          (* Restore the saved registers *)
          set_regs_to_save block;
          Set.app (fun r -> emit (Pop (Reg r))) !regs_to_save;
          emit (Comment "live restore done")
        end
        else if s = "callstart" then begin
          (* The block after the call is looked up through the successor edges  *)
          (* rather than by instruction index, so that instructions may be      *)
          (* deleted from the block.                                            *)
          let block_after_call = get_block_after_call cfg block.succ in
          stack_before_args := Talverify.current_stack_con !current_ctxt;
          set_regs_to_save block_after_call;
          emit i;
          (* Reserve one stack slot per register to be saved.                   *)
          emit (stack_alloc (Set.cardinality !regs_to_save));
          emit (Comment "live alloc done");
          num_args_on_stack := 0
        end
        else emit i
    | Call (op, cl) ->
        let (regs, cons) = get_live_register_types () in
        let (regs, cons) = (List.rev regs, List.rev cons) in
        let stack_with_live_regs =
          insert_list_stack !stack_before_args 0 cons
        in
        begin
          emit (Comment "live save begin");
          (* Emit one mov_slot_reg per live register, using consecutive slot    *)
          (* indices starting just past the words pushed since "callstart".     *)
          (* Only the emitted instructions matter; the final slot index is      *)
          (* discarded explicitly.                                              *)
          let (_ : int) =
            List.fold_left
              (fun slot r -> begin
                emit (mov_slot_reg slot r);
                slot + 1
              end)
              !num_args_on_stack regs
          in
          emit (Call (op, fix_cl !current_ctxt cl
                       (fun _ -> stack_with_live_regs)))
        end
    (* Track how many words sit on the stack above the reserved slots.          *)
    | Push _ -> begin incr num_args_on_stack; emit i end
    | Pop _ -> begin decr num_args_on_stack; emit i end
    | ArithBin (Add, Reg Esp, Immed index) ->
        begin
          num_args_on_stack := !num_args_on_stack - (index/4);
          emit i
        end
    | ArithBin (Sub, Reg Esp, Immed index) ->
        begin
          num_args_on_stack := !num_args_on_stack + (index/4);
          emit i
        end
    (* Fallthru and Retn are copied to the output without being verified.       *)
    | Fallthru _ -> Xarray.add xa i
    | Retn _ -> Xarray.add xa i
    | _ -> emit i
  in
  (* Continue into intra-procedural successors, carrying the context reached    *)
  (* at the end of this block.                                                  *)
  let call_rec (_, et, b) =
    let next_block = Cfg.get_block cfg b in
    match et with
      (Return | SelfTailCall | CallEdge | TailCall) -> ()
    | (Jump | Branch | Sequence | UnknownCall _ | CallSequence _) ->
        save_live_regs_b cfg next_block !current_ctxt
  in
  begin
    debugdo ("saving live regs in: "^
             (Identifier.id_to_string block.lab));
    if block.visited then () else begin
      block.visited <- true;
      Array.iter save_live_regs_i block.code;
      block.code <- Xarray.to_array xa;
      Set.app call_rec block.succ;
    end
  end
    
(* save_live_regs_f cfg ctxt f                                                  *)
(*                                                                              *)
(* Runs the live-register saving pass over the function with entry label f.     *)
(* The context is first extended with a binding for every typed block that      *)
(* Cfg.fold_intra visits for f; the visited flags are then cleared and the      *)
(* rewrite starts at the entry block.                                           *)
let save_live_regs_f cfg ctxt f =
  let entry = Cfg.get_block cfg f in
  (* Bind b's label to its checked type; untyped blocks leave acc unchanged.    *)
  let extend b acc =
    match b.con with
      Some c ->
        let (_, checked) = Talcon.check acc c in
        Talctxt.set_val acc b.lab checked
    | None -> acc
  in
  let fn_ctxt = Cfg.fold_intra extend cfg f ctxt in
  Cfg.clear_visited_flags cfg;
  save_live_regs_b cfg entry fn_ctxt

  
(* save_live_regs cfg                                                           *)
(*                                                                              *)
(* Entry point of the pass: prints a progress banner, builds a context that     *)
(* binds every procedure of cfg to its type, and applies save_live_regs_f to    *)
(* each procedure.  The graph is modified in place; the result is unit.         *)
let save_live_regs cfg =
  print_string "SAVING LIVE REGISTERS\n";
  Format.print_flush ();
  let base_ctxt = Set.fold (add_proc_to_context cfg) cfg.procs cfg.context in
  let process_proc = save_live_regs_f cfg base_ctxt in
  Set.app process_proc cfg.procs


(* EOF: regrewrite.ml *)
