(**********************************************************************)
(* (c) Greg Morrisett, Steve Zdancewic                                *)
(*     June 1998, all rights reserved.                                *)
(**********************************************************************)

(* cfg.ml
 *
 * The datatype for TAL-level control flow graphs.
 *)

open Utilities
open Identifier
open Tal
open Set

(* Raised by the block/edge mutators below when the graph's structure is
   inconsistent, e.g. an edge is attached to a block at the wrong end, or a
   block to be deleted is not present. *)
exception CfgError of string

(*******************************************************************************)
(* Control flow edges *)

(* The kind of a control-flow edge.  NOTE: edges are compared with the
   polymorphic [compare], so changing the order of these constructors
   changes the ordering of edge sets. *)
type cf_edge_type = 
    CallEdge        (* A procedure call edge *)
  | TailCall        (* Call to a different procedure's entry block *)
  | SelfTailCall    (* Recursive (loop) tail call--bypass any header block *)
  | Return          (* Return from a procedure call *)
  | UnknownCall of con  (* Cross-module or unknown function call; the con *)
                        (* is a code type with no virtual registers *)
  | Jump            (* Unconditional jump, no live registers will be saved *)
  | Branch          (* Conditional jump, no live registers will be saved *)
  | Sequence        (* Link between blocks that must be laid out *)
                    (* sequentially, e.g. branches and fall-throughs *)
  | CallSequence of reg set  (* Links the call block to the return block, *)
                             (* with the return registers in the set *)

(* An edge (source label, kind, target label).  A block's [succ] set holds
   edges whose source is that block; its [pred] set holds edges whose target
   is that block. *)
type cf_edge = identifier * cf_edge_type * identifier

(* Ordering used for edge sets: plain structural comparison, which compares
   the (source, kind, target) triple component by component. *)
let compare_edges e1 e2 = compare e1 e2

(*******************************************************************************)
(* Control flow (basic) blocks *)

(* A basic block of the control flow graph.  Every field is mutable so that
   blocks can be updated in place. *)
type cf_block = {
    mutable lab: identifier;   (* The unique TAL identifier for the block *)
    mutable con: con option;    (* The type constructor for the entry point; *)
                                (* None means the block's code is appended to *)
                                (* the preceding TAL block when linearized *)
    mutable fun_lab: identifier; (* The function to which the block belongs *)
    mutable code: instruction vector;   (* The actual TAL code *)
    mutable call_sites: (int, cf_block) Dict.dict; (* map of callstart sites to *)
                                                (* the block following the call *)

    mutable pred: cf_edge set;  (* Incoming edges: this block is the target *)
    mutable succ: cf_edge set;  (* Outgoing edges: this block is the source *)

    mutable use: reg set;   (* Registers used before being defined *)
    mutable def: reg set;   (* Registers defined in the block *)
    mutable live_in: reg set;  (* presumably filled by a liveness pass *)
    mutable live_out: reg set; (* presumably filled by a liveness pass *)

    mutable header: bool;  (* NOTE(review): meaning not established in this *)
                           (* file; cf. SelfTailCall "bypass any header block" *)
    mutable visited: bool  (* Depth-first-search mark; reset by *)
                           (* clear_visited_flags *)
  } 
    
(* Blocks that end in unknown jumps (e.g. jumps for exception handlers) *)
(* should have empty successor sets. *)

(* Orders blocks by their TAL label alone; all other fields are ignored. *)
let compare_blocks { lab = l1 } { lab = l2 } = id_compare l1 l2

(* Build a fresh block labelled [l] with entry type [c], belonging to
   function [f] and holding the instructions [code].  The call-site map,
   the edge sets and the register sets all start empty; both flags start
   false. *)
let make_block l c f code =
  let no_edges () = Set.empty compare_edges in
  let no_regs () = Set.empty compare_regs in
  { lab = l; con = c; fun_lab = f; code = code;
    call_sites = Dict.empty compare;
    pred = no_edges (); succ = no_edges ();
    use = no_regs (); def = no_regs ();
    live_in = no_regs (); live_out = no_regs ();
    header = false; visited = false }

(* Overwrite the entry-point type constructor of a block. *)
let set_con blk con_opt = blk.con <- con_opt

(* Overwrite the instruction vector of a block. *)
let set_code blk instrs = blk.code <- instrs

(* Edge-set mutators for a single block.  Predecessor edges must have [b] as
   their target and successor edges must have [b] as their source; otherwise
   CfgError is raised and the block is left untouched. *)
let add_pred b ((_, _, dst) as e) =
  if b.lab <> dst then raise (CfgError "cfg:add_pred: labels don't match");
  b.pred <- Set.insert b.pred e

let add_succ b ((src, _, _) as e) =
  if b.lab <> src then raise (CfgError "cfg:add_succ: labels don't match");
  b.succ <- Set.insert b.succ e

let del_pred b ((_, _, dst) as e) =
  if b.lab <> dst then raise (CfgError "cfg:del_pred: labels don't match");
  b.pred <- Set.delete b.pred e

let del_succ b ((src, _, _) as e) =
  if b.lab <> src then raise (CfgError "cfg:del_succ: labels don't match");
  b.succ <- Set.delete b.succ e

(* Record that the call starting at instruction [index] of block [b] is
   followed by block [block]. *)
let add_call_site b index block =
  let sites = Dict.insert b.call_sites index block in
  b.call_sites <- sites

(* Look up the block recorded for the call starting at [index] in [b].
   Fails with a message naming the index if none was recorded. *)
let get_call_site b index =
  try Dict.lookup b.call_sites index with
    Dict.Absent ->
      let idx = string_of_int index in
      failwith ("cfg.ml: get_call_site bad index: " ^ idx)

(*******************************************************************************)
(* Individual instructions *)

(* An instruction position: the label of the containing block and the index
   of the instruction within that block's code (see get_instr). *)
type cf_instr = identifier * int
(* Structural ordering on instruction positions: block label first, then
   instruction index. *)
let compare_instrs p1 p2 = compare p1 p2

(*******************************************************************************)
(* Control flow graphs *)

(* Per-function rewrite information, stored in the graph's [rewrite] map
   keyed by function label. *)
type rewrite_info = {
    mutable sp_slots: int;  (* The number of spill slots used in this function *)
    mutable sp_args: int; (* The number of spilled arguments *)
    mutable sp_rets: int  (* Spilled returns (an int; presumably their count) *)
  } 

(* A control flow graph for a module.  The two cached traversal orders are
   reset to None by every structural mutation (adding/removing blocks, edges
   or roots) and are recomputed lazily by fold/app. *)
type cf_graph = {
    mutable context: Talctxt.ctxt;   (* The context of the cfg - contains *)
                                     (* imported types *)
    mutable blocks: (identifier, cf_block) Dict.dict;  (* Blocks in the graph *)
    mutable roots: identifier set;   (* The entry blocks of the module; *)
                                     (* starting points of fold/app *)
    mutable procs: identifier set;   (* The set of procedure labels *)
                                     (* defined in the control flow graph *)
    (* A map from function labels to their argument registers *)
    mutable args: (identifier, reg set) Dict.dict;
    (* A map from function labels to their return registers *)
    mutable rets: (identifier, reg set) Dict.dict;
    (* A map from function labels to the set of regs mentioned in their code *)
    mutable regs: (identifier, reg set) Dict.dict;
    (* A map from function labels to rewrite information about that function *)
    mutable rewrite: (identifier, rewrite_info) Dict.dict;

    mutable df_order: identifier list option; (* Cached depth-first (pre)order *)
                                              (* from the roots; None = stale *)
    mutable rev_order: identifier list option (* Cached reverse of df_order *)
  } 

(* A new graph with the empty context, no blocks, roots or procedures, empty
   per-function maps, and no cached traversal orders. *)
let empty_cfg () =
  let no_ids () = Set.empty id_compare in
  let no_map () = Dict.empty id_compare in
  { context = Talctxt.empty_ctxt;
    blocks = no_map ();
    roots = no_ids ();
    procs = no_ids ();
    args = no_map ();
    rets = no_map ();
    regs = no_map ();
    rewrite = no_map ();
    df_order = None;
    rev_order = None }

(* Accessors for the typing context carried by the graph. *)
let set_context g new_ctxt = g.context <- new_ctxt
let get_context { context = ctxt } = ctxt

(* Insert [block] into the graph under its own label (via Dict.insert_new)
   and invalidate both cached traversal orders. *)
let add_block cfg block =
  let blocks = Dict.insert_new cfg.blocks block.lab block in
  cfg.blocks <- blocks;
  cfg.df_order <- None;
  cfg.rev_order <- None

(* Remove edge [e] from its source's successors and its target's
   predecessors, then invalidate the cached orders.  Dict.Absent escapes if
   either endpoint is not a block of [cfg]. *)
let del_edge cfg ((src, _, dst) as e) =
  let src_blk = Dict.lookup cfg.blocks src in
  let dst_blk = Dict.lookup cfg.blocks dst in
  del_succ src_blk e;
  del_pred dst_blk e;
  cfg.df_order <- None;
  cfg.rev_order <- None

(* Remove block [id] from [cfg].  Every incoming and outgoing edge is
   detached from the neighbouring blocks, [id] is dropped from the roots,
   the block is deleted, and the cached orders are invalidated.  The label
   is NOT removed from cfg.procs.
   Raises CfgError "cfg: del_block - block not found" if [id] is not a block
   of [cfg], and CfgError "cfg: del_block - edge endpoint not found" if one
   of its edges refers to a block that is not in [cfg]. *)
let del_block cfg id =
  let block =
    try Dict.lookup cfg.blocks id
    with Dict.Absent -> raise (CfgError "cfg: del_block - block not found") in
  (* Handled separately so that a dangling neighbour is not reported as the
     block itself being missing. *)
  (try
     Set.app (del_edge cfg) block.pred;
     Set.app (del_edge cfg) block.succ
   with Dict.Absent ->
     raise (CfgError "cfg: del_block - edge endpoint not found"));
  cfg.roots <- Set.delete cfg.roots id;
  cfg.blocks <- Dict.delete cfg.blocks id;
  cfg.df_order <- None;
  cfg.rev_order <- None
    
(* Add an edge of kind [et] from block [b1] to block [b2], given the blocks
   themselves, and invalidate the cached orders.  The blocks are not checked
   for membership in [cfg]. *)
let add_edge_bb cfg b1 et b2 =
  let edge = (b1.lab, et, b2.lab) in
  add_succ b1 edge;
  add_pred b2 edge;
  cfg.df_order <- None;
  cfg.rev_order <- None

(* Add edge [e] between two blocks already in [cfg], looked up by label, and
   invalidate the cached orders.  Fails if either endpoint is missing. *)
let add_edge cfg ((src, _, dst) as e) =
  let find lab =
    try Dict.lookup cfg.blocks lab
    with Dict.Absent ->
      failwith "cfg.ml: add_edge source or target not in cfg" in
  let src_blk = find src in
  let dst_blk = find dst in
  add_succ src_blk e;
  add_pred dst_blk e;
  cfg.df_order <- None;
  cfg.rev_order <- None


(* Add [l] to the traversal roots; cached orders become stale. *)
let add_root cfg l =
  let roots = Set.insert cfg.roots l in
  cfg.roots <- roots;
  cfg.df_order <- None;
  cfg.rev_order <- None

(* Remove [l] from the traversal roots; cached orders become stale. *)
let del_root cfg l =
  let roots = Set.delete cfg.roots l in
  cfg.roots <- roots;
  cfg.df_order <- None;
  cfg.rev_order <- None

(* Register / unregister [lab] as a procedure label of the graph.  These do
   not affect the cached traversal orders. *)
let add_proc g lab = g.procs <- Set.insert g.procs lab

let del_proc g lab = g.procs <- Set.delete g.procs lab

(* Helpers over maps from function labels to register sets.  The add
   helpers treat a label missing from [map] as bound to the empty set; the
   delete helpers return [map] unchanged for a missing label. *)

(* Add register [r] to the set bound to [f]. *)
let add_reg map f r =
  let regs =
    try Set.insert (Dict.lookup map f) r
    with Dict.Absent -> Set.singleton compare_regs r in
  Dict.insert map f regs

(* Remove register [r] from the set bound to [f]. *)
let del_reg map f r =
  match (try Some (Dict.lookup map f) with Dict.Absent -> None) with
    Some regs -> Dict.insert map f (Set.delete regs r)
  | None -> map

(* Union the set [s] into the set bound to [f]. *)
let add_regset map f s =
  let merged = try Set.union s (Dict.lookup map f) with Dict.Absent -> s in
  Dict.insert map f merged

(* Subtract the set [s] from the set bound to [f]. *)
let del_regset map f s =
  match (try Some (Dict.lookup map f) with Dict.Absent -> None) with
    Some regs -> Dict.insert map f (Set.diff regs s)
  | None -> map

(* Per-function register bookkeeping on the graph: argument registers
   (args), return registers (rets) and all registers mentioned in the
   function's code (regs). *)
let add_arg cfg f r = cfg.args <- add_reg cfg.args f r
let del_arg cfg f r = cfg.args <- del_reg cfg.args f r
let add_ret cfg f r = cfg.rets <- add_reg cfg.rets f r
let del_ret cfg f r = cfg.rets <- del_reg cfg.rets f r
let add_regs cfg f s = cfg.regs <- add_regset cfg.regs f s
let del_regs cfg f s = cfg.regs <- del_regset cfg.regs f s

(* The register set bound to [f] in [map], or the empty set if [f] has no
   entry. *)
let get_regset map f =
  try Dict.lookup map f with Dict.Absent -> Set.empty compare_regs

(* Note: these are partial applications, so the map is read when the graph
   argument is supplied. *)
let get_args cfg = get_regset cfg.args
let get_rets cfg = get_regset cfg.rets
let get_regs cfg = get_regset cfg.regs

(* Bind (or rebind) the rewrite information for function [f]. *)
let set_rewrite_info cfg f rwi =
  let table = Dict.insert cfg.rewrite f rwi in
  cfg.rewrite <- table

(* The rewrite information for [f], or None if none has been recorded. *)
let get_rewrite_info cfg f =
  try Some (Dict.lookup cfg.rewrite f) with Dict.Absent -> None

(* The block labelled [l]; fails if the graph has no such block. *)
let get_block cfg l =
  try Dict.lookup cfg.blocks l
  with Dict.Absent -> failwith "cfg.ml: get_block block not found"

(* The instruction at position (block label, index). *)
let get_instr cfg (l, i) = (get_block cfg l).code.(i)

(* Install a traversal order explicitly.  Each setter updates only its own
   cache; the other one is left as it was. *)
let set_df_order g order = g.df_order <- Some order
let set_rev_order g order = g.rev_order <- Some order

(* Reset the depth-first-search mark on every block of the graph. *)
let clear_visited_flags cfg =
  let unmark _ blk = blk.visited <- false in
  Dict.app_dict unmark cfg.blocks

(* Fold [f] over every block reachable from the roots, in depth-first
   preorder ([f] sees a block before its successors); all edge kinds are
   followed.  If df_order is cached it is replayed directly.  Otherwise the
   visited flags are reset, the search is run, and both df_order and its
   reverse rev_order are cached. *)
let fold f cfg b =
  match cfg.df_order with
    Some order ->
      List.fold_left (fun acc lab -> f (get_block cfg lab) acc) b order
  | None ->
      let seen = ref [] in
      let rec visit blk acc =
        if blk.visited then acc
        else begin
          blk.visited <- true;
          seen := blk.lab :: !seen;
          let acc = f blk acc in
          Set.fold (fun (_, _, dst) a -> visit (get_block cfg dst) a)
            blk.succ acc
        end in
      clear_visited_flags cfg;
      let result =
        Set.fold (fun root a -> visit (get_block cfg root) a) cfg.roots b in
      cfg.df_order <- Some (List.rev !seen);
      cfg.rev_order <- Some !seen;
      result

(* Fold [f] over the blocks reachable from block [id] without leaving the
   procedure: call, tail-call, self-tail-call and return edges are not
   followed.  [f] is applied to each block before its successors (preorder).
   All visited flags are reset first; the cached traversal orders are
   neither used nor updated. *)
let fold_intra f cfg id b =
  let rec visit blk acc =
    if blk.visited then acc
    else begin
      blk.visited <- true;        (* Mark this block as seen *)
      (* Fold over the intra-procedural successors of this block.  Each
         constructor is listed exactly once, so the match is exhaustive and
         has no dead cases. *)
      Set.fold (fun (_, et, dst) acc' ->
        match et with
          (CallEdge | TailCall | Return | SelfTailCall) -> acc'
        | (Jump | Branch | Sequence | CallSequence _ | UnknownCall _) ->
            visit (get_block cfg dst) acc')
        blk.succ (f blk acc)
    end in
  begin
    (* Mark all of the blocks as unseen *)
    clear_visited_flags cfg;
    (* Fold the depth-first search over the root block *)
    visit (get_block cfg id) b
  end

(* Apply [f] to every block reachable from the roots in depth-first
   preorder, following all edge kinds.  If df_order is cached it is
   replayed directly.  Otherwise the visited flags are reset, the search is
   run, and both df_order and its reverse rev_order are cached. *)
let app f cfg =
  match cfg.df_order with
    Some order -> List.iter (fun lab -> f (get_block cfg lab)) order
  | None ->
      let seen = ref [] in
      let rec visit blk =
        if not blk.visited then begin
          blk.visited <- true;
          seen := blk.lab :: !seen;
          f blk;
          Set.app (fun (_, _, dst) -> visit (get_block cfg dst)) blk.succ
        end in
      clear_visited_flags cfg;
      Set.app (fun root -> visit (get_block cfg root)) cfg.roots;
      cfg.df_order <- Some (List.rev !seen);
      cfg.rev_order <- Some !seen

	      
(* The reverse depth-first order of [cfg].
   - A cached rev_order is returned as is.
   - If only df_order is cached (e.g. after set_df_order), its reverse is
     returned.  This matches how fold and app fill the two caches.  Without
     this case, app would replay df_order without filling rev_order, and the
     failure branch below would be hit.
   - Otherwise app is run to compute and cache both orders.
   [caller] names the public entry point for the (unreachable) failure
   message. *)
let rev_order_for caller cfg =
  match cfg.rev_order with
    Some l -> l
  | None ->
      match cfg.df_order with
        Some l -> List.rev l
      | None -> begin
          app (fun _ -> ()) cfg;
          match cfg.rev_order with
            Some l -> l
          | None -> failwith ("cfg:" ^ caller ^ " shouldn't get here")
        end

(* Fold [f] over the reachable blocks in reverse depth-first order. *)
let rev_fold f cfg b =
  let order = rev_order_for "rev_fold" cfg in
  List.fold_left (fun acc lab -> f (get_block cfg lab) acc) b order

(* Apply [f] to the reachable blocks in reverse depth-first order. *)
let rev_app f cfg =
  let order = rev_order_for "rev_app" cfg in
  List.iter (fun lab -> f (get_block cfg lab)) order

(* Targets of the edges in [eset] that the layout traversal follows:
   every Jump/Branch target plus the target of the sequence-like edge
   (Sequence, CallSequence or UnknownCall), if any.  Call, tail-call,
   self-tail-call and return edges are skipped.
   Fails if [eset] holds more than one sequence-like edge, since a block can
   fall through to only one successor. *)
let get_succ_ids eset =
  (* The accumulator pairs "a sequence edge has been seen" with the result. *)
  let step (_, et, target) (seen_seq, ids) =
    match et with
      (CallEdge | TailCall | SelfTailCall | Return) -> (seen_seq, ids)
    | (Jump | Branch) -> (seen_seq, target :: ids)
    | (Sequence | CallSequence _ | UnknownCall _) ->
        if seen_seq then failwith "cfg.ml: get_succ_ids: two sequence edges"
        else (true, target :: ids)
  in
  snd (Set.fold step eset (false, []))

(* The target of a sequence-like edge (Sequence, CallSequence or
   UnknownCall) in [eset], if any.  If several exist, the one folded last
   wins. *)
let get_seq eset =
  let pick (_, et, id) found =
    match et with
      Sequence | CallSequence _ | UnknownCall _ -> Some id
    | CallEdge | TailCall | SelfTailCall | Return | Jump | Branch -> found in
  Set.fold pick eset None

(* The chain of blocks starting at [block] and linked by sequence-like
   successor edges, in order.  NOTE(review): there is no visited check, so
   this does not terminate if the sequence edges form a cycle. *)
let rec get_seq_chain cfg block =
  block ::
    (match get_seq block.succ with
       None -> []
     | Some id -> get_seq_chain cfg (get_block cfg id))
			  
(* Linearize [cfg] back into an array of TAL code blocks
   (label, entry type, instructions).
   Starting from each label in cfg.procs, blocks are traversed depth-first
   along the edges selected by get_succ_ids.  A block with no sequence-like
   predecessor is emitted together with its sequence chain
   (get_seq_chain).  A block that does have one is emitted only as part of
   its predecessor's chain.
   A block with an entry type ([con = Some _]) starts a new TAL block.  A
   block without one has its code appended to the most recently started TAL
   block.
   Resets the visited flags of [cfg].  Fails if the first block emitted has
   no entry type, since there is then no TAL block to append it to. *)
let cfg_to_code_blocks cfg =
  (* NOTE(review): the tuple passed to Xarray.create is presumably the
     growable array's filler element. *)
  let xa = Xarray.create 10 (Identifier.id_new "bogus",
                             None,
                             ref (Xarray.create 1 Nop)) in
  let i = ref 0 in   (* Number of TAL blocks started so far *)
  let rec trav id =
    let block = get_block cfg id in
    if block.visited then [] else begin
      block.visited <- true;
      let rest = List.fold_left (fun list id -> (trav id)@list) []
          (get_succ_ids block.succ) in
      match get_seq block.pred with
        None -> (get_seq_chain cfg block)@rest
      | Some _ -> rest
    end
  in
  let emit_block b =
    match b.con with
      None -> (* Append the code in this block to the code for the last *)
        if !i = 0 then
          failwith "cfg.ml: cfg_to_code_blocks: first block has no entry type"
        else begin
          let (_, _, x) = Xarray.get xa (!i - 1) in
          x := Xarray.append !x (Xarray.from_array b.code)
        end
    | Some con -> (* This is the start of a new TAL block *)
        begin
          Xarray.add xa (b.lab, Some con, ref (Xarray.from_array b.code));
          incr i
        end
  in
  let fix_xblock (id, c, x) = (id, c, Xarray.to_array !x) in
  begin
    (* Mark all of the blocks as unseen *)
    clear_visited_flags cfg;
    List.iter emit_block
      (Set.fold (fun id idlist -> (trav id)@idlist) cfg.procs []);
    Xarray.to_array (Xarray.map fix_xblock xa)
  end
  

(* EOF: cfg.ml *)


