(**********************************************************************)
(* (c) Greg Morrisett, Neal Glew, David Walker,                       *)
(*     June 1998, all rights reserved.                                *)
(**********************************************************************)

(* talcon.ml
 * TAL kind & type constructor verifier and manipulation utilities
 *
 * Kinds: subkind & equal kinds.
 * Type constructors: well-formedness at a kind, normalise, alpha equality
 * Utilities: unroll, size, stack size, separate function type
 *)
 
open Utilities;;
open Identifier;;
open Tal;;
open Talctxt;;

(*** Statistics/Profiling Infrastructure ***)

(* Every registered counter paired with its label, most recent first. *)
let stats : (int ref * string) list ref = ref []

(* [click r] is a thunk that bumps counter [r] by one per call. *)
let click r () = incr r;;

(* [new_stat label] registers a fresh zero counter under [label] and
 * returns the thunk that increments it. *)
let new_stat s =
  let counter = ref 0 in
  stats := (counter, s) :: !stats;
  click counter;;

(* Print every counter as "label: value", one per line, most recently
 * registered first. *)
let print_stats () =
  let print_one (counter, label) =
    print_string label;
    print_string ": ";
    print_int !counter;
    print_string "\n" in
  List.iter print_one !stats

(* Profiling counters for the operations in this file.  Each binding is
 * the increment thunk returned by [new_stat] (invoked as [name ()]); the
 * string is the label that [print_stats] prints.  [new_stat] conses onto
 * [stats], so [print_stats] lists these in reverse of the order below. *)
let rename_count = new_stat "rename count";;
let fast_rename_count = new_stat "fast rename count";;
let sizeof_count = new_stat "sizeof"
let kindleq_count = new_stat "kindleq";;
let kindeq_count = new_stat "kindeq";;
let subst_one_count = new_stat "substitutions of one constructor";;
let subst_one_cutoff = new_stat "cutoff of substitutions of one constructor";;
let subst_count = new_stat "substitutions";;
let subst_cutoff = new_stat "substitution cutoffs";;
let check_count = new_stat "kind check";;
let whnorm_count = new_stat "weak-head normalization";; 
let fast_whead_count = new_stat "fast weak-head normalizations";; 
let normalize_count = new_stat "full normalizations";;
let norm_count = new_stat "norm calls";;
let check_whnorm_count = new_stat "kind check & weak-head normalize";;
let alphaeqcon_count = new_stat "alphaeqcon";;
let fast_alphaeq_count = new_stat "fast alphaeq";;
let alphaeqcon_nodes = new_stat "alphaeqcon nodes";;
let eqcon_count = new_stat "eqcon";;
let fast_eqcon_count = new_stat "fast eqcon";;
let unroll_count = new_stat "unroll count";;

(*************************************************************************)
(* k1 <= k2, k1 = k2, and kind meet                                      *)
(*************************************************************************)

(* kindleq k1 k2:
 *   if     k1 <= k2 then ()
 *   if not k1 <= k2 then generate error Kindleq (k1,k2)
 *)

let rec kindleq ctxt k1 k2 =
  kindleq_count();
  (* physically identical kinds are trivially related; otherwise compare *)
  if k1 != k2 then
    match (k1, k2) with
      (K4byte, (K4byte | Ktype)) | (Ktype, Ktype) | (Kstack, Kstack) -> ()
    | (Karrow (dom1, cod1), Karrow (dom2, cod2)) ->
        (* contravariant in the domain, covariant in the codomain *)
        kindleq ctxt dom2 dom1;
        kindleq ctxt cod1 cod2
    | (Kprod ks1, Kprod ks2) -> kindsleq ctxt (Kindleq (k1, k2)) ks1 ks2
    | _ -> generate_error ctxt (Kindleq (k1, k2))
(* Pointwise kindleq over two lists; [ve] is reported on a length mismatch. *)
and kindsleq ctxt ve k1s k2s =
  match (k1s, k2s) with
    ([], []) -> ()
  | (ka :: rest1, kb :: rest2) ->
      kindleq ctxt ka kb;
      kindsleq ctxt ve rest1 rest2
  | _ -> generate_error ctxt ve
;;

(* kindeq k1 k2:
 *   if k1  = k2 then ()
 *   if k1 != k2 then generate error Kindeq (k1,k2)
 *)

(* kindeq ctxt k1 k2: checks that k1 and k2 are the same kind, reporting
 * Kindeq (k1,k2) through [generate_error] when they are not.  Arrow kinds
 * are equal when both their domains and their codomains are equal;
 * product kinds are compared pointwise by [kindseq]. *)
let rec kindeq ctxt k1 k2 =
  kindeq_count(); 
  if k1 == k2 then () else
  match k1,k2 with
    (K4byte,K4byte) -> ()
  | (Ktype,Ktype) -> ()
  | (Kstack,Kstack) -> ()
  | (Karrow(k1a,k1b),Karrow(k2a,k2b)) ->
      (* compare domain with domain and codomain with codomain *)
      (kindeq ctxt k1a k2a; kindeq ctxt k1b k2b)
  | (Kprod k1s,Kprod k2s) -> kindseq ctxt (Kindeq (k1,k2)) k1s k2s
  | (_,_) -> generate_error ctxt (Kindeq (k1,k2))
(* Pointwise kindeq over two lists; [ve] is reported on a length mismatch. *)
and kindseq ctxt ve k1s k2s =
  match k1s,k2s with
    ([],[]) -> ()
  | (k1::k1s,k2::k2s) -> (kindeq ctxt k1 k2; kindseq ctxt ve k1s k2s)
  | (_,_) -> generate_error ctxt ve
;;

(* kindmeet k1 k2:
 *   if the meet of k1 & k2 exists return it
 *   otherwise generate error Kindmeet (k1,k2)
 * kindjoin k1 k2:
 *   if the join of k1 & k2 exists return it
 *   otherwise generate error Kindjoin (k1,k2)
 *)

let rec kindmeet ctxt k1 k2 =
  if k1 == k2 then k1 else
  match (k1, k2) with
    (K4byte, K4byte) | (K4byte, Ktype) | (Ktype, K4byte) -> K4byte
  | (Ktype, Ktype) -> Ktype
  | (Kstack, Kstack) -> Kstack
  | (Karrow (dom1, cod1), Karrow (dom2, cod2)) ->
      (* meet of arrows: join the domains, meet the codomains *)
      Karrow (kindjoin ctxt dom1 dom2, kindmeet ctxt cod1 cod2)
  | (Kprod ks1, Kprod ks2) ->
      Kprod (kindsmeet ctxt (Kindmeet (k1, k2)) ks1 ks2)
  | _ -> generate_error ctxt (Kindmeet (k1, k2)); raise Talfail
(* Pointwise meet of two kind lists; [ve] is reported on a length mismatch. *)
and kindsmeet ctxt ve k1s k2s =
  match (k1s, k2s) with
    ([], []) -> []
  | (ka :: rest1, kb :: rest2) ->
      kindmeet ctxt ka kb :: kindsmeet ctxt ve rest1 rest2
  | _ -> generate_error ctxt ve; raise Talfail
and kindjoin ctxt k1 k2 =
  if k1 == k2 then k1 else
  match (k1, k2) with
    (Ktype, K4byte) | (Ktype, Ktype) | (K4byte, Ktype) -> Ktype
  | (K4byte, K4byte) -> K4byte
  | (Kstack, Kstack) -> Kstack
  | (Karrow (dom1, cod1), Karrow (dom2, cod2)) ->
      (* join of arrows: meet the domains, join the codomains *)
      Karrow (kindmeet ctxt dom1 dom2, kindjoin ctxt cod1 cod2)
  | (Kprod ks1, Kprod ks2) ->
      Kprod (kindsjoin ctxt (Kindjoin (k1, k2)) ks1 ks2)
  | _ -> generate_error ctxt (Kindjoin (k1, k2)); raise Talfail
(* Pointwise join of two kind lists; [ve] is reported on a length mismatch. *)
and kindsjoin ctxt ve k1s k2s =
  match (k1s, k2s) with
    ([], []) -> []
  | (ka :: rest1, kb :: rest2) ->
      kindjoin ctxt ka kb :: kindsjoin ctxt ve rest1 rest2
  | _ -> generate_error ctxt ve; raise Talfail
;;

(*************************************************************************)
(* D |- c : k  where D is a kind assignment                              *)
(*************************************************************************)

(* Report that constructor [c] is ill-formed, with explanation [s]. *)
let error ctxt c s =
  let violation = Conwf (c, s) in
  generate_error ctxt violation;;

(* primcon_kind returns the least kind of a primitive constructor *)
(* Shared arrow kinds handed out by [primcon_kind]: [k2] is
 * 4byte -> 4byte and [ks] is stack -> 4byte. *)
let k2 = Karrow(K4byte,K4byte)
let ks = Karrow(Kstack,K4byte)

(* primcon_kind ctxt pc: the least kind of primitive constructor [pc]. *)
let primcon_kind ctxt pc =
  match pc with
    PCbytes Byte4 -> K4byte
  | PCbytes _ | PCreal -> Ktype
  | PCtag i ->
      (* tags are limited to bytes -- this assumes that pointers in the
       * range of 0..255 are not valid addresses. *)
      if i < 0 || i > 255 then
        error ctxt (defcon (Cprim pc)) "tag out of range";
      K4byte
  | PCjunk | PCexn | PCbytearray _ | PCbytevector _ -> K4byte
  | PCexnname | PCarray | PCvector -> k2
  | PCstackptr -> ks
;;

(* assumes the constructor has been checked already *)
(* con_kind ctxt con: reads the kind of an already-checked constructor off
 * its structure.  Only the cases needed to compute the result are
 * re-examined; an inconsistency is reported with [error] and raises
 * Talfail. *)
let rec con_kind ctxt con =
  let fail msg = error ctxt con msg; raise Talfail in
  match con.rcon with
    Cvar a -> get_variable_kind ctxt a
  | Clab l -> get_label_kind ctxt l
  | Cprim pc -> primcon_kind ctxt pc
  | Clam (a, k1, body) -> Karrow (k1, con_kind (add_var ctxt a k1) body)
  | Capp (fn, _) ->
      (match con_kind ctxt fn with
        Karrow (_, result) -> result
      | _ -> fail "Capp: not a Karrow")
  | Ctuple cs -> Kprod (List.map (con_kind ctxt) cs)
  | Cproj (i, tuple) ->
      (match con_kind ctxt tuple with
        Kprod ks ->
          (try List.nth ks i
           with Failure _ -> fail "Cproj: index out of range")
      | _ -> fail "Cproj: not a Kprod")
  | Crec fs -> Kprod (List.map (fun (_, k, _) -> k) fs)
  | Cempty | Ccons _ | Cappend _ -> Kstack
  | _ -> K4byte
;;

(* Identifier-set helpers specialised to [id_compare]. *)
let singleton x = Set.singleton id_compare x
let empty_set : identifier Set.set = Set.empty id_compare

(* rc_freevars c: the free type variables of the raw constructor [c].
 * [freevars] is the memoising front end over a full [con]. *)
let rec rc_freevars (c : rcon) : identifier Set.set =
  match c with
    Cvar a -> singleton a
  | Clam (v, _, body) | Cforall (v, _, body) | Cexist (v, _, body) ->
      (* the binder is not free in the construct *)
      Set.delete (freevars body) v
  | Capp (c1, c2) | Ccons (c1, c2) | Cappend (c1, c2) ->
      Set.union (freevars c1) (freevars c2)
  | Ctuple cs -> unions (List.map freevars cs)
  | Cproj (_, body) -> freevars body
  | Clab _ | Cprim _ | Cempty -> empty_set
  | Crec fs ->
      (* union the bodies, then remove every name the bundle binds *)
      let in_bodies =
        List.fold_left (fun acc (_, _, body) -> Set.union acc (freevars body))
          empty_set fs in
      List.fold_left (fun acc (x, _, _) -> Set.delete acc x) in_bodies fs
  | Cprod fs -> field_freevars fs
  | Csum {sum_tags = _; sum_vt = vt} ->
      (match vt with
        None -> empty_set
      | Some (Tuple fs) -> field_freevars fs
      | Some (Variants vs) ->
          unions (List.map (fun (_, fs) -> field_freevars fs) vs))
  | Ccode rs ->
      Dict.fold_dict (fun _ body acc -> Set.union acc (freevars body))
        rs empty_set
(* free variables of a list of (con, capability, init) fields *)
and field_freevars fs =
  List.fold_left (fun acc (body, _, _) -> Set.union acc (freevars body))
    empty_set fs
and unions (l : (identifier Set.set) list) : identifier Set.set =
  List.fold_left Set.union empty_set l
(* cached in the mutable [freevars] field of the con on first request *)
and freevars (c : con) : identifier Set.set =
  match c.freevars with
    Some s -> s
  | None ->
      let s = rc_freevars c.rcon in
      c.freevars <- Some s;
      s
;;

exception Check_Unique
(* check_unique l: raises Check_Unique if two ADJACENT elements of [l] are
 * equal (callers sort [l] first so that duplicates become adjacent). *)
let check_unique l =
  let rec scan prev rest =
    match rest with
      [] -> ()
    | next :: rest' ->
        if prev = next then raise Check_Unique else scan next rest'
  in
  match l with
    [] -> ()
  | first :: rest -> scan first rest
;;

(* [not_fun c]: false exactly when [c] is syntactically a Clam. *)
let not_fun c =
  let is_lam = function Clam (_, _, _) -> true | _ -> false in
  not (is_lam c.rcon) ;;
(* [not_tuple c]: false exactly when [c] is syntactically a Ctuple. *)
let not_tuple c =
  let is_tuple = function Ctuple _ -> true | _ -> false in
  not (is_tuple c.rcon) ;;

(* checks the kind of a constructor, performs the substitutions entered in
 * the context, and returns the kind and new constructor.
 *)
(* check ctxt c:
 *   kind-checks [c] in [ctxt] and returns (k, c') where [k] is the kind of
 *   [c] and [c'] is a rebuilt copy whose [isnorm] flag records whether it
 *   is in normal form and whose [freevars] cache is filled in.  Kind errors
 *   are reported through [generate_error]/[error]; those after which no
 *   kind can be produced also raise Talfail.
 *)
let check ctxt c =
  check_count(); 
  (* local builder: records the normal-form flag [b] and computes the
   * free-variable set eagerly *)
  let defcon b rc = {rcon=rc; isnorm=b; freevars=Some(rc_freevars rc)} in
  let rec ck (ctxt : ctxt) (con : con) =
    let c = con.rcon in
    match c with
      Cvar a -> 
        (get_variable_kind ctxt a,defcon true c)
    | Clam(a,k1,c) -> 
        let (k2,c') = ck (add_var ctxt a k1) c in
        (* A lambda is normal only if its body is.  Flagging it normal
         * unconditionally would make [normalize] skip the body, so lambdas
         * whose bodies are equal after reduction would compare unequal. *)
        let c = defcon c'.isnorm (Clam(a,k1,c')) in
        (Karrow(k1,k2), c)
    | Capp(c1,c2) ->
        let (k1,c1) = ck ctxt c1 in
        let (k2,c2) = ck ctxt c2 in
        (* applying a literal lambda is a redex, hence not normal *)
        let isnorm = c1.isnorm & c2.isnorm & (not_fun c1) in
        begin
          match k1 with
            Karrow(ka,kb) -> 
              kindleq ctxt k2 ka; (kb,defcon isnorm (Capp(c1,c2)))
          | _ -> error ctxt con "Capp: not a Karrow"; raise Talfail
        end
    | Ctuple cs -> 
        (* The fold accumulates in reverse; reverse the results so that
         * component i of the checked tuple and of its Kprod kind is
         * component i of the source, as Cproj and con_kind expect. *)
        let (ks,cs,isnorm) = 
          List.fold_left 
            (fun (ks,cs,isnorm) c ->
              let (k,c) = ck ctxt c in (k::ks,c::cs,c.isnorm & isnorm))
            ([],[],true) cs in
        (Kprod (List.rev ks),defcon isnorm (Ctuple (List.rev cs)))
    | Cproj(i,c') ->
        let (k,c') = ck ctxt c' in
        (* projecting from a literal tuple is a redex, hence not normal *)
        let isnorm = c'.isnorm & (not_tuple c') in
        begin
          match k with
            Kprod ks ->
              (try (List.nth ks i,defcon isnorm (Cproj(i,c'))) with
                Failure _ -> error ctxt con "Cproj: index out of range"; 
                  raise Talfail)
          | _ -> error ctxt con "Cproj: not a Kprod"; raise Talfail
        end
    | Clab l ->
        let k = get_label_kind ctxt l in
        (k,defcon true c)
    | Cprim pc -> (primcon_kind ctxt pc,defcon true c)
    | Crec fs ->
        (* every recursive variable is in scope in every body *)
        let g ctxt (a,k,_) = add_var ctxt a k in
        let ctxt' = List.fold_left g ctxt fs in
        let isnorm = ref true in
        let check_f (a,k,c) = 
          let (k',c') = ck ctxt' c in
          kindleq ctxt' k' k;
          isnorm := (!isnorm) & c'.isnorm;
          (a,k,c') in
        let fs = List.map check_f fs in
        let k = Kprod (List.map (fun (_,k,_) -> k) fs) in
        (k,defcon (!isnorm) (Crec fs))
    | Cforall(a,k,c) ->
        let (k',c') = ck (add_var ctxt a k) c in
        kindleq ctxt k' K4byte;
        (K4byte,defcon c'.isnorm (Cforall(a,k,c')))
    | Cexist(a,k,c) ->
        let (k',c') = ck (add_var ctxt a k) c in
        kindleq ctxt k' K4byte;
        (K4byte,defcon c'.isnorm (Cexist(a,k,c')))
    | Cprod cs ->
        let isnorm = ref true in
        let ck_fields (c,x,y) =
          let (k,c) = ck ctxt c in
          kindleq ctxt k Ktype;
          isnorm := (!isnorm) & c.isnorm;
          (c,x,y) in
        let fs = List.map ck_fields cs in
        (K4byte,defcon (!isnorm) (Cprod fs))
    | Csum {sum_tags=tags; sum_vt=svt} ->
        (* sorting makes duplicate tags adjacent for check_unique *)
        let tags = Sort.list (fun i j -> i <= j) tags in
        (try check_unique tags with Check_Unique ->
          error ctxt con "Csum: duplicated singleton tags");
        let (isnorm,svt) = 
          (match svt with
            None -> (true,None)
          | Some(Tuple cs) -> 
              let (_,cx) = ck ctxt (defcon false (Cprod cs)) in
              (match cx.rcon with
                Cprod cs -> (cx.isnorm,Some(Tuple cs))
              | _ -> failwith "Talcon.check - Csum: internal error 1")
          | Some(Variants vs) ->
              let isnorm = ref true in
              let vs = Sort.list (fun (i,fs1) (j,fs2) -> i <= j) vs in
              (try check_unique (List.map fst vs) with Check_Unique ->
                error ctxt con "Csum: duplicated variant tags");
              let aux (i,fs) =
                let (_,cx) = ck ctxt (defcon false (Cprod fs)) in
                match cx.rcon with
                  Cprod fs -> (isnorm := (!isnorm) & cx.isnorm; (i,fs))
                | _ -> failwith "Talcon.check - Csum: internal error 2" in
              let vs = List.map aux vs in
              (!isnorm,Some(Variants vs))) in
        (K4byte,defcon isnorm (Csum{sum_tags=tags;sum_vt=svt}))
    | Ccode rs ->
        let isnorm = ref true in
        let ck_dict c = 
          let (k,c) = ck ctxt c in
          begin
            if c.isnorm then () else isnorm := false;
            kindleq ctxt k K4byte; c
          end in
        let rs = Dict.map_dict ck_dict rs in
        (K4byte,defcon (!isnorm) (Ccode rs))
    | Cempty -> (Kstack,defcon true c)
    | Ccons(c1,c2) ->
        let (k1,c1) = ck ctxt c1 in
        let (k2,c2) = ck ctxt c2 in
        kindleq ctxt k1 K4byte; kindleq ctxt k2 Kstack; 
        (Kstack,defcon (c1.isnorm & c2.isnorm) (Ccons(c1,c2)))
    | Cappend(c1,c2) ->
        let (k1,c1) = ck ctxt c1 in
        let (k2,c2) = ck ctxt c2 in
        kindleq ctxt k1 Kstack; kindleq ctxt k2 Kstack;
        (* an append is always treated as reducible *)
        (Kstack,defcon false (Cappend(c1,c2)))
  in ck ctxt c
;;

(*************************************************************************)
(* [c'/a]c : capture-avoiding substitution for constructors              *)
(*************************************************************************)

(* defvarcon x: the variable constructor for [x], already flagged normal
 * and with its free-variable set {x} precomputed. *)
let defvarcon x =
  let fvs = Some (Set.singleton id_compare x) in
  { rcon = Cvar x; freevars = fvs; isnorm = true }

(* rename (x,k,c) (d,s):
 *   adjusts the substitution (d,s) for crossing a binder of [x] (of kind
 *   [k], scoping over [c]) and returns the adjusted substitution together
 *   with the possibly renamed binder triple.  [s] must contain the free
 *   variables of every constructor in the range of [d].
 *   - x is rebound here, so any mapping for x in d is dropped.
 *   - if x occurs free in a constructor being substituted in (x in s),
 *     the binder is renamed to a fresh x' and d maps x to x', so those
 *     free occurrences are not captured.
 *   The two steps are independent: a binder that was already in d's
 *   domain must still be renamed when another mapping mentions it free.
 *)
let rename ((x,k,c) as t) (d,s) =
  rename_count();
  let d = if Dict.member d x then Dict.delete d x else d in
  if Set.member s x then
    let x' = id_unique x in
    ((Dict.insert d x (defvarcon x'),s),(x',k,c))
  else (fast_rename_count(); ((d,s),t))

(* rename_then f d (x,k,c): pushes substitution [d] under the binder
 * (renaming it if needed via [rename]) and applies [f] to the body with
 * the adjusted substitution; returns the possibly renamed triple. *)
let rename_then f d t =
  match rename t d with
    (d', (x, k, c)) -> (x, k, f d' c)

(* applies the substitution d to c, alpha-converting c as necessary
 * in order to avoid capture. *)
(* rcsubsts d con: structural worker behind [substs].
 *   [d] is a pair (map, fvs): [map] sends variables to the constructors
 *   replacing them, and [fvs] holds (a superset of) the free variables of
 *   those constructors, which [rename] consults whenever a binder is
 *   crossed.  Leaves (labels, primitives, Cempty, sums without payload)
 *   and unmapped variables are returned physically unchanged; every other
 *   node is rebuilt with [defcon].
 *)
let rec rcsubsts d con = 
  let c = con.rcon in
  subst_count(); 
  match c with
    Cvar a ->
      (* variable a may be free, return it if so *)
      (try Dict.lookup (fst d) a with Dict.Absent -> con)
  | Clam(x,k,c) -> 
     let (x',k',c') = rename_then substs d (x,k,c) in defcon(Clam(x',k',c'))
  | Capp(c1,c2) -> defcon(Capp(substs d c1, substs d c2))
  | Ctuple cs -> defcon(Ctuple(List.map (substs d) cs))
  | Cproj(i,c) -> defcon(Cproj(i,substs d c))
  | Clab _ -> con
  | Cprim _ -> con
  | Crec fs ->
      (* every binder of the bundle is adjusted first, because each body
       * is in the scope of all of them *)
      let g f (d,fs) = let (d',f') = rename f d in (d',f'::fs) in
      let (d',fs') = List.fold_right g fs (d,[]) in
      defcon(Crec (List.map (fun (x',k,c) -> (x',k,substs d' c)) fs'))
  | Cforall (x,k,c) -> 
      let (x',k',c') = rename_then substs d (x,k,c) 
      in defcon(Cforall(x',k',c'))
  | Cexist (x,k,c) -> 
      let (x',k',c') = rename_then substs d (x,k,c) 
      in defcon(Cexist(x',k',c'))
  | Cprod fs -> defcon(Cprod (substs_fields d fs))
  | Csum {sum_tags=_;sum_vt=None} -> con
  | Csum {sum_tags=st;sum_vt=Some(Tuple fs)} -> 
      defcon(Csum{sum_tags=st;sum_vt=Some(Tuple(substs_fields d fs))})
  | Csum {sum_tags=st;sum_vt=Some(Variants vs)} ->
      defcon(Csum{ sum_tags=st;
		   sum_vt=
		   Some(Variants(
			List.map (fun (i,fs) -> (i,substs_fields d fs)) 
			  vs))})
  | Ccode rs -> defcon(Ccode(Dict.map_dict (substs d) rs))
  | Cempty -> con
  | Ccons (c1,c2) -> defcon(Ccons(substs d c1,substs d c2))
  | Cappend (c1,c2) -> defcon(Cappend(substs d c1, substs d c2))
(* substitute into the constructor of one (con, capability, init) field *)
and substs_field d (c,cap,init) = (substs d c,cap,init)
and substs_fields d = List.map (substs_field d)
(* substs (d,s) c: entry point.  When [c] has a cached free-variable set,
 * the substitution is first restricted to the variables actually free in
 * [c] (with the matching free-variable set rebuilt); if nothing remains,
 * [c] is returned unchanged.  Without a cache the full pair is used. *)
and substs ((d,s) as p) c = 
  match c.freevars with
    None -> rcsubsts p c
  | Some cfreevars ->
      let (d,s) = 
	Dict.fold_dict 
	  (fun x c ((d,s) as z) -> 
	    if Set.member cfreevars x then 
	      (Dict.insert d x c,Set.union s (freevars c))
	    else z)
	  d (Dict.empty id_compare,Set.empty id_compare)
      in if Dict.is_empty d then (subst_cutoff(); c) else 
      let c = rcsubsts (d,s) c in (* freevars c;*) c
      (* JGM: calculating the freevars here seems to actually slow things
       * down...*)
;;

(* subst c1 x c2: capture-avoiding substitution of [c1] for the free
 * occurrences of [x] in [c2]. *)
let subst c1 x c2 =
  let sigma = Dict.singleton id_compare x c1 in
  substs (sigma, freevars c1) c2;;

(* subst_one c1 x fv c2: like [subst], but the free variables [fv] of [c1]
 * are supplied by the caller instead of being recomputed. *)
let subst_one c1 x fv c2 =
  let sigma = Dict.singleton id_compare x c1 in
  substs (sigma, fv) c2;;

(* substs_list l c: simultaneous substitution of every (variable, con)
 * pair in [l] into [c]. *)
let substs_list l c =
  let add_binding (d, s) (x, cx) =
    (Dict.insert d x cx, Set.union s (freevars cx)) in
  let init = (Dict.empty id_compare, Set.empty id_compare) in
  substs (List.fold_left add_binding init l) c

(* rcsubst ca a fvs con: fast single-variable substitution [ca/a]con.
 *   [fvs] is the free-variable set of [ca].  A binder named [a] shadows
 *   it, so that node is returned as is.  A binder whose name is in [fvs]
 *   could capture, so that whole node is handed to the general renaming
 *   substitution [subst_one] instead.
 *)
let rec rcsubst ca a (fvs : identifier Set.set) con = 
  let c = con.rcon in
  subst_one_count(); 
  match c with
    Cvar x ->
      (* variable a may be free, return it if so *)
      if x = a then ca else con
  | Clam(x,k,c) -> 
      if x = a then (subst_one_cutoff(); con) else
      if Set.member fvs x then subst_one ca a fvs con 
      else defcon(Clam(x,k,subst ca a fvs c))
  | Capp(c1,c2) -> defcon(Capp(subst ca a fvs c1, subst ca a fvs c2))
  | Ctuple cs -> defcon(Ctuple(List.map (subst ca a fvs) cs))
  | Cproj(i,c) -> defcon(Cproj(i,subst ca a fvs c))
  | Clab _ -> con
  | Cprim _ -> con
  | Crec fs ->
      (* a is rebound by the bundle: nothing to do *)
      if List.exists (fun (x,_,_) -> x = a) fs then (subst_one_cutoff(); con)
      (* some bundle binder would capture a free variable of ca *)
      else if List.exists (fun (x,_,_) -> Set.member fvs x) fs then
	subst_one ca a fvs con
      else 
	defcon(Crec(List.map (fun (x,k,c) -> (x,k,subst ca a fvs c)) fs))
  | Cforall (x,k,c) -> 
      if x = a then (subst_one_cutoff(); con)
      else if Set.member fvs x then subst_one ca a fvs con
      else defcon(Cforall(x,k,subst ca a fvs c))
  | Cexist (x,k,c) -> 
      if x = a then (subst_one_cutoff(); con)
      else if Set.member fvs x then subst_one ca a fvs con
      else defcon(Cexist(x,k,subst ca a fvs c))
  | Cprod fs -> defcon(Cprod (substs_fields ca a fvs fs))
  | Csum {sum_tags=_;sum_vt=None} -> con
  | Csum {sum_tags=st;sum_vt=Some(Tuple fs)} -> 
      defcon(Csum{sum_tags=st;sum_vt=Some(Tuple(substs_fields ca a fvs fs))})
  | Csum {sum_tags=st;sum_vt=Some(Variants vs)} ->
      defcon(Csum{sum_tags=st;sum_vt=
		  Some(Variants(
		       List.map(fun (i,fs)->(i,substs_fields ca a fvs fs)) 
			 vs))})
  | Ccode rs -> defcon(Ccode(Dict.map_dict (subst ca a fvs) rs))
  | Cempty -> con
  | Ccons (c1,c2) -> defcon(Ccons(subst ca a fvs c1,subst ca a fvs c2))
  | Cappend (c1,c2) -> defcon(Cappend(subst ca a fvs c1, subst ca a fvs c2))
(* substitute into the constructor of one (con, capability, init) field *)
and substs_field ca a fvs (c,cap,init) = (subst ca a fvs c,cap,init)
and substs_fields ca a fvs = List.map (substs_field ca a fvs)
(* subst ca a fvs c: skips [c] entirely when its cached free variables
 * show that [a] does not occur in it. *)
and subst ca a fvs c =
  match c.freevars with
    None -> rcsubst ca a fvs c
  | Some cfreevars ->
      if Set.member cfreevars a then
	rcsubst ca a fvs c
      else (subst_one_cutoff(); c)
;;
(* subst ca a c: [ca/a]c, computing the free variables of [ca] itself. *)
let subst ca a c =
  let fvs = freevars ca in
  subst ca a fvs c

(* substs d c: simultaneous substitution of the mapping [d] into [c]; the
 * free variables of the range of [d] are gathered here. *)
let substs d c =
  let collect _ cx acc = Set.union acc (freevars cx) in
  let s = Dict.fold_dict collect d (Set.empty id_compare) in
  substs (d, s) c
;;

(*************************************************************************)
(* c ->* c' : normalization of constructors                              *)
(*************************************************************************)

(* weak-head normalization *)
(* whnorm ctxt c: weak-head normal form of [c].
 *   It reduces three kinds of head redex:
 *   - a Capp whose function is a Clam;
 *   - a Cproj of a literal Ctuple;
 *   - a Cappend whose left operand is Cempty, Ccons or another Cappend.
 *   Constructors flagged [isnorm] are returned immediately.  When a node is
 *   reduced, its [rcon] and [freevars] fields are overwritten in place with
 *   the result, so every other reference to that node sees the reduced
 *   form.  [ctxt] is not used by this code.
 *)
let whnorm ctxt c =
  let rec wh (c : con) : con =
    whnorm_count(); 
    if c.isnorm then (fast_whead_count(); c) else 
    match c.rcon with
      Capp(c1,c2) ->
	begin
	  (* result ignored: wh reduces c1 in place, so c1.rcon is now its
	   * weak-head form *)
	  wh c1;
	  match c1.rcon with
	    Clam(x,k,c3) -> 
	      (* beta: [c2/x]c3 *)
	      let c' = wh(subst c2 x c3) in
	      c.rcon <- c'.rcon; 
	      (*c.iswhead <- true;*)
	      c.freevars <- c'.freevars;
	      c'
	  | _ -> (*c.iswhead <- true;*) c
	end
    | Cproj(i,c1) ->
	begin
	  (* any Failure raised inside (e.g. List.nth with an index past the
	   * end of the tuple) is re-raised as "Talcon.whnorm Cproj" *)
	  try 
	    wh c1;
	    match c1.rcon with
	      Ctuple cs -> 
		let c' = wh (List.nth cs i) in
		c.rcon <- c'.rcon; 
		(*c.iswhead <- true;*)
		c.freevars <- c'.freevars;
		c'
	    | _ -> (*c.iswhead <- true;*) c
	  with Failure _ -> failwith "Talcon.whnorm Cproj"
	end
    | Cappend(c1,c2) ->
	begin
	  wh c1;
	  match c1.rcon with
	    Cempty -> 
	      (* [] @ c2 => c2 *)
	      let c' = wh c2 in
	      c.rcon <- c'.rcon; 
	      (*c.iswhead <- true;*)
	      c.freevars <- c'.freevars;
	      c'
	  | Ccons(f,c1) -> 
	      (* (f::c1) @ c2 => f::(c1 @ c2); only the head is exposed, the
	       * new inner append is left unreduced *)
	      c.rcon <- Ccons(f,defcon(Cappend(c1,c2)));
	      (*c.iswhead <- true;*)
	      c
	  | Cappend(ca,cb) -> 
	      (* (ca @ cb) @ c2 => ca @ (cb @ c2), then keep reducing *)
	      c.rcon <- Cappend(ca,defcon(Cappend(cb,c2)));
	      wh c
	  | c1 -> (*c.iswhead <- true;*) c
	end
    | _ -> (*c.iswhead <- true;*) c
  in
  wh c
;;

(* normalization: assumes c is well-formed *)
(* normalize ctxt c: full normalization of [c], performed in place.
 *   Each node not already flagged [isnorm] is weak-head normalized (which
 *   may overwrite the node, see [whnorm]).  Its immediate subterms are
 *   then normalized recursively and the node is flagged [isnorm].
 *   Returns [c] itself.
 *)
let normalize ctxt c =
  normalize_count(); 
  let rec norm (c : con) : unit =
    if c.isnorm then () else
    begin
    norm_count(); 
    (* normalize the immediate subterms of a weak-head normal node *)
    let rec aux (c:rcon) : unit =
      match c with
      	Cvar _ -> ()
      | Clam(x,k,c) -> norm c
      | Capp(c1,c2) -> norm c1; norm c2
      | Ctuple cs -> List.iter norm cs
      | Cproj(i,c) -> norm c
      | Clab _ -> ()
      | Cprim _ -> ()
      | Crec fs -> List.iter (fun (x,k,c) -> norm c) fs
      | Cforall(x,k,c) -> norm c
      | Cexist(x,k,c) -> norm c
      | Cprod fs -> norm_fields fs
      | Csum {sum_tags=st;sum_vt=vt} ->
      	  (match vt with
            None -> ()
      	  | Some(Tuple fs) -> norm_fields fs
      	  | Some(Variants vs) -> 
              List.iter (fun (i,fs) -> norm_fields fs) vs)
      | Ccode rs -> (Dict.app_dict (fun r c -> norm c) rs)
      | Cempty -> ()
      | Ccons (c1,c2) -> norm c1; norm c2
      | Cappend(c1,c2) -> norm c1; norm c2
    and norm_field (c,cap,init) = norm c
    and norm_fields fs = List.iter norm_field fs 
    in aux (whnorm ctxt c).rcon; c.isnorm <- true
    end
  in    
  norm c; c
;;

(* check_whnorm ctxt c: kind-check [c], then return its kind together with
 * the weak-head normal form of the checked constructor. *)
let check_whnorm ctxt c =
  check_whnorm_count();
  match check ctxt c with
    (k, checked) -> (k, whnorm ctxt checked)

(* check a register state *)
(* verify_gamma ctxt gamma: kind-check every register type in [gamma]
 * (each must be at most K4byte) and return the checked register state. *)
let verify_gamma ctxt gamma =
  let ctxt = set_verify_ctxt ctxt "checking register state" in
  Dict.map_dict
    (fun reg_con ->
      let (k, checked) = check ctxt reg_con in
      kindleq ctxt k K4byte;
      checked)
    gamma
;;

(*************************************************************************)
(* c ==alpha c'                                                          *)
(*************************************************************************)

(* Renaming context for alpha-equivalence: the first map sends each
 * right-hand bound variable to its left-hand partner, the second records
 * the kind of each left-hand bound variable. *)
type alphactxt =
    ((identifier,identifier) Dict.dict) * ((identifier,kind) Dict.dict)
;;

let empty_ctxt = (Dict.empty id_compare,Dict.empty id_compare);;
(* extend c x1 x2 k: enter a binder pair, left [x1] and right [x2], of kind [k] *)
let extend c x1 x2 k : alphactxt =
  let (varmap, kenv) = c in
  (Dict.insert varmap x2 x1, Dict.insert kenv x1 k)
;;

(* Raised by [eqerror] to abandon a speculative alpha-comparison. *)
exception NotEq
let eqerror = fun () -> raise NotEq

(* compare two constructors up to alpha-equivalence *)
(* alphaeqcon error ctxt c1 c2:
 *   compares c1 and c2 structurally, up to consistent renaming of bound
 *   variables; no normalization is performed.  Every mismatch is reported
 *   by calling [error ()].  Kinds at matching binders are compared with
 *   [kindeq], which reports through [ctxt].
 *)
let alphaeqcon error ctxt c1 c2 =
  alphaeqcon_count(); 
  (* Pairwise iteration that reports a length mismatch through [error]
   * rather than letting List.iter2 escape with Invalid_argument. *)
  let iter2 f l1 l2 =
    if List.length l1 = List.length l2 then List.iter2 f l1 l2 else error () in
  let rec aeq ctx c1 c2 = 
    (* physically equal constructors are trivially alpha-equivalent *)
    if c1 == c2 then fast_alphaeq_count() else raeq ctx c1.rcon c2.rcon
  and raeq ((varmap,kenv) as ctx) c1 c2 =
    alphaeqcon_nodes(); 
    match c1,c2 with
      (Cvar x,Cvar y) -> 
        (try if x = (Dict.lookup varmap y) then () else error ()
          with Dict.Absent ->
            (* y is free on the right, so x must be the same free variable.
             * A left-bound x that merely shares y's name is a different
             * variable (e.g. forall a.a vs forall b.a). *)
            if x = y && not (Dict.member kenv x) then () else error ())
    | (Clam(x1,k1,c1),Clam(x2,k2,c2)) ->
        (kindeq ctxt k1 k2; aeq (extend ctx x1 x2 k1) c1 c2)
    | (Capp(c1a,c1b),Capp(c2a,c2b)) -> (aeq ctx c1a c2a; aeq ctx c1b c2b)
    | (Ctuple cs1, Ctuple cs2) -> aeqs ctx cs1 cs2
    (* the empty tuple equals a bound variable of kind Kprod [] *)
    | (Ctuple [],Cvar y) -> 
        (try (kindeq ctxt (Kprod[]) (Dict.lookup kenv (Dict.lookup varmap y)))
          with Dict.Absent -> error ())
    | (Cvar x,Ctuple []) ->
        (try (kindeq ctxt (Kprod[]) (Dict.lookup kenv x))
            with Dict.Absent -> error ())
    | (Cproj(i1,c1),Cproj(i2,c2)) ->
        if i1 = i2 then aeq ctx c1 c2 else error ()
    | (Clab l1,Clab l2) -> if l1<>l2 then error ()
    | (Cprim pc1,Cprim pc2) -> if pc1 = pc2 then () else error ()
    | (Crec fs1,Crec fs2) -> 
        (* List.fold_right2 raises Invalid_argument on a length mismatch,
         * so the lengths are compared first *)
        if List.length fs1 <> List.length fs2 then error () else
        let ctx2 = 
          List.fold_right2 
            (fun (x1,k1,_) (x2,k2,_) ctx ->
              (kindeq ctxt k1 k2; extend ctx x1 x2 k1))
            fs1 fs2 ctx in
        List.iter2 (fun (_,_,c1) (_,_,c2) -> aeq ctx2 c1 c2) fs1 fs2
    | (Cforall(x1,k1,c1),Cforall(x2,k2,c2)) ->
        (kindeq ctxt k1 k2; aeq (extend ctx x1 x2 k1) c1 c2)
    | (Cexist(x1,k1,c1),Cexist(x2,k2,c2)) ->
        (kindeq ctxt k1 k2; aeq (extend ctx x1 x2 k1) c1 c2)
    | (Cprod fs1,Cprod fs2) -> aeq_fields ctx fs1 fs2
    | (Csum{sum_tags=st1;sum_vt=vt1},Csum{sum_tags=st2;sum_vt=vt2}) ->
        if st1 = st2 then
          (match (vt1,vt2) with
            (None,None) -> ()
          | (Some(Tuple fs1),Some(Tuple fs2)) -> aeq_fields ctx fs1 fs2
          | (Some(Variants vs1),Some(Variants vs2)) ->
              let aux (i1,fs1) (i2,fs2) =
                if i1 = i2 then aeq_fields ctx fs1 fs2 else error() in
              iter2 aux vs1 vs2
          | (_,_) -> error ())
        else error ()
    | (Ccode rs1,Ccode rs2) -> 
        begin
          try
            (* every register of rs1 must be in rs2 with an equal type, and
             * every register of rs2 must be in rs1 *)
            Dict.app_dict (fun r c1 -> aeq ctx c1 (Dict.lookup rs2 r)) rs1;
            Dict.app_dict (fun r _ -> Dict.lookup rs1 r) rs2
          with Dict.Absent -> error()
        end
    | (Cempty,Cempty) -> ()
    | (Ccons(hd1,c1),Ccons(hd2,c2)) -> (aeq ctx hd1 hd2; aeq ctx c1 c2)
    | (Cappend(c1a,c1b),Cappend(c2a,c2b)) -> (aeq ctx c1a c2a; aeq ctx c1b c2b)
    | (_,_) -> error ()
  and aeqs ctx cs1 cs2 = iter2 (aeq ctx) cs1 cs2
  and aeq_field ctx (c1,i1,r1) (c2,i2,r2) =
    if (i1 = i2) & (r1 = r2) then aeq ctx c1 c2 else error ()
  and aeq_fields ctx fs1 fs2 = iter2 (aeq_field ctx) fs1 fs2
in
  (aeq empty_ctxt c1 c2)  (* could set c1 == c2 and force more sharing *)
;;

(* dieerror ctxt c1 c2: an alphaeqcon error callback that reports c1 and
 * c2 as unequal (Neqcon) each time it is invoked. *)
let dieerror ctxt c1 c2 =
  fun () -> generate_error ctxt (Neqcon (c1, c2))

(* eqcon ctxt c1 c2: checks c1 = c2.  A cheap alpha-comparison of the
 * constructors as given is tried first.  Only if it fails are both sides
 * normalized and alpha-compared again, and a mismatch there is reported
 * as Neqcon on the normal forms. *)
let eqcon ctxt c1 c2 =
  eqcon_count();
  if c1 == c2 then fast_eqcon_count()
  else
    let quick_equal =
      try alphaeqcon eqerror ctxt c1 c2; true with NotEq -> false in
    if not quick_equal then begin
      let n1 = normalize ctxt c1 in
      let n2 = normalize ctxt c2 in
      alphaeqcon (dieerror ctxt n1 n2) ctxt n1 n2
    end;;

(* alphaeqcon ctxt c1 c2: alpha-equality check that reports Neqcon (c1,c2)
 * through [ctxt] on mismatch. *)
let alphaeqcon ctxt c1 c2 =
  let on_mismatch = dieerror ctxt c1 c2 in
  alphaeqcon on_mismatch ctxt c1 c2

(* ctxt |- gamma1 <= gamma2 *)
(* reg_state_leq ctxt gamma1 gamma2: every register typed in [gamma2] must
 * also be typed in [gamma1], with an equal type; a register missing from
 * [gamma1] is reported as Rsnleq (gamma1,gamma2). *)
let reg_state_leq ctxt gamma1 gamma2 =
  let ctxt = set_verify_ctxt ctxt "register state leq" in
  let check_reg r expected =
    try eqcon ctxt (Dict.lookup gamma1 r) expected
    with Dict.Absent -> generate_error ctxt (Rsnleq (gamma1, gamma2)) in
  Dict.app_dict check_reg gamma2
;;

(*************************************************************************)
(* Utilities                                                             *)
(*************************************************************************)

(* unroll a recursive type *)
(* unroll_rec ctxt c: performs one unrolling step of a recursive type.
 *   - Cproj(i, Crec fs): in body i, each recursive variable n is replaced
 *     by the projection n of the bundle; for n = i that is [c] itself.
 *   - Cproj(i, c1) / Capp(c1, c2) where c1 is not (yet) a Crec: c1 is
 *     unrolled and the projection/application is rebuilt in weak-head
 *     normal form.
 *   - Clab l: the concrete definition of label l.
 *   Anything else, an abstract label, or an out-of-range projection is
 *   reported as BadUnroll and raises Talfail.
 *)
let rec unroll_rec ctxt c =
  unroll_count(); 
  match c.rcon with
    Cproj (i,c1) ->
      (let c1 = whnorm ctxt c1 in
      match c1.rcon with
        Crec [(v,k,c2)] -> subst c v c2
      | Crec fs ->
          let aux (d,n) (v,_,_) =
            let uc = if n=i then c else defcon (Cproj (n,c1)) in
            (Dict.insert d v uc,n+1) in
          let (d,_) = List.fold_left aux (Dict.empty id_compare,0) fs in
          let (_,_,c2) =
            try List.nth fs i
            with Failure _ ->
              generate_error ctxt (BadUnroll c); raise Talfail in
          substs d c2
      | _ ->
          (* Unroll the projected constructor, not [c]: recursing on [c]
           * would re-enter this very case with the same c1 forever. *)
          whnorm ctxt (defcon (Cproj (i,unroll_rec ctxt c1))))
  | Capp (c1,c2) ->
      (let c1 = whnorm ctxt c1 in
      whnorm ctxt (defcon (Capp (unroll_rec ctxt c1,c2))))
  | Clab l ->
      (match get_label_def ctxt l with
        AbsCon -> generate_error ctxt (BadUnroll c); raise Talfail
      | ConcCon c -> c)
  | _ -> generate_error ctxt (BadUnroll c); raise Talfail
;;

(* calculates the size (in bytes) of values of type c *)
(* assumes c is normalized *)
(* sizeof ctxt c: size in bytes of a value of type [c].  Byte-scaled
 * primitives and reals have fixed sizes and any K4byte constructor is 4
 * bytes.  Otherwise [c] is weak-head normalized and retried; an unknown
 * size is reported as Unknown_size and 4 is returned. *)
let rec sizeof ctxt c =
  sizeof_count();
  let scale_bytes sc =
    match sc with Byte1 -> 1 | Byte2 -> 2 | Byte4 -> 4 | Byte8 -> 8 in
  match c.rcon with
    Cprim (PCbytes sc) -> scale_bytes sc
  | Cprim PCreal -> 8
  | _ ->
      (match con_kind ctxt c with
        K4byte -> 4
      | _ ->
          let whead = whnorm ctxt c in
          (match whead.rcon with
            Cprim (PCbytes _) -> sizeof ctxt whead
          | Cprim PCreal -> 8
          | _ -> generate_error ctxt (Unknown_size whead); 4))
;;

(* Calculate the size (in bytes) of a stack type *)
let rec sizeof_stack ctxt c =
  let whead = whnorm ctxt c in
  match whead.rcon with
    Cempty -> 0
  | Ccons (slot, rest) -> sizeof ctxt slot + sizeof_stack ctxt rest
  | Cappend (front, back) -> sizeof_stack ctxt front + sizeof_stack ctxt back
  | _ -> generate_error ctxt (Unknown_size c); 0
;;

(* From a tal function type c, separate abstracted type variables and value
   variables *)
let separate_fun_type ctxt c =
  (* peel leading Cforall binders, collecting (var, kind) in order *)
  let rec walk acc c =
    match c.rcon with
      Cforall (v, k, body) -> walk ((v, k) :: acc) body
    | Ccode regstate -> (List.rev acc, regstate)
    | _ -> generate_error ctxt (Conwf (c,"not a function type")); raise Talfail
  in
  walk [] c
;;

(*************************************************************************)
(* Field/Stack Slot Utilities                                            *)
(*************************************************************************)

(* -n means n bytes before a field/slot
 * +n means n bytes after last valid field/slot
 *)

(* get_field_offset ctxt i fs: the field starting at byte offset [i] *)
let rec get_field_offset ctxt i fs =
  match fs with
    [] -> generate_error ctxt (Bad_offset i); raise Talfail
  | f :: rest ->
      if i = 0 then f else
      let (c, _, _) = f in
      let remaining = i - sizeof ctxt c in
      if remaining < 0 then (generate_error ctxt (Bad_offset remaining); f)
      else get_field_offset ctxt remaining rest
;;

(* init_field_offset ctxt i fs: mark the field at byte offset [i] as
 * initialised; re-initialising a read-only field reports Readonly. *)
let rec init_field_offset ctxt i fs =
  match fs with
    [] -> generate_error ctxt (Bad_offset i); []
  | ((c, cap, init) as f) :: rest ->
      if i = 0 then
        (match (cap, init) with
          (_, Uninit) -> (c, cap, Init) :: rest
        | (ReadWrite, Init) -> fs
        | (Read, Init) -> generate_error ctxt Readonly; fs)
      else
        let remaining = i - sizeof ctxt c in
        if remaining >= 0 then f :: init_field_offset ctxt remaining rest
        else (generate_error ctxt (Bad_offset remaining); f :: rest)
;;

(* get_stack_offset ctxt i con: the stack slot starting at byte offset [i] *)
let rec get_stack_offset ctxt i con =
  match (whnorm ctxt con).rcon with
    Ccons (slot, rest) ->
      if i = 0 then slot else
      let remaining = i - sizeof ctxt slot in
      if remaining < 0 then (generate_error ctxt (Bad_offset remaining); slot)
      else get_stack_offset ctxt remaining rest
  | _ -> generate_error ctxt (Bad_offset i); raise Talfail
;;

(* init_stack_offset ctxt i con carg: replace the stack slot at byte offset
 * [i] with [carg], rebuilding the stack prefix above it. *)
let rec init_stack_offset ctxt i con carg =
  match (whnorm ctxt con).rcon with
    Ccons (slot, rest) ->
      if i = 0 then begin
        if sizeof ctxt slot <> sizeof ctxt carg then generate_error ctxt Readonly;
        defcon (Ccons (carg, rest))
      end else begin
        let remaining = i - sizeof ctxt slot in
        if remaining < 0 then begin
          generate_error ctxt (Bad_offset remaining); raise Talfail
        end;
        defcon (Ccons (slot, init_stack_offset ctxt remaining rest carg))
      end
  | _ -> generate_error ctxt (Bad_offset i); raise Talfail
;;

(* get_stack_tail ctxt i con: the stack remaining after dropping [i] bytes *)
let rec get_stack_tail ctxt i con =
  (* normalize first, even when i = 0: whnorm updates [con] in place *)
  let head = (whnorm ctxt con).rcon in
  if i = 0 then con
  else
    match head with
      Ccons (slot, rest) ->
        let remaining = i - sizeof ctxt slot in
        if remaining >= 0 then get_stack_tail ctxt remaining rest
        else (generate_error ctxt (Bad_offset remaining); rest)
    | _ -> generate_error ctxt (Bad_offset i); raise Talfail
;;

(* verify that the stack constructor c1 is a tail of stack constructor c2.
 * assumes c1,c2 normalized and returns a function, which when given a 
 * mutated version of c1, generates the corresponding mutated version of c2.
 * (see assign_simple_path below).  That is, we verify that there exists
 * a c3 such that Cappend(c1,c3) = c2 and return a function which maps
 * c1' to Cappend(c1',c3).
 *
 * JGM: this is a stupid algorithm for doing this, but will probably
 * work well in practice.
 *)
exception Tryrest
let verify_stack_tail ctxt c1 c2 =
  (* in the trial context any reported error means "not here, look deeper" *)
  let trial_ctxt = error_handler ctxt (fun _ _ -> raise Tryrest) in
  let rec search stack =
    let matches =
      try eqcon trial_ctxt c1 stack; true with Tryrest -> false in
    if matches then (fun c -> c)
    else
      match (whnorm ctxt stack).rcon with
        Ccons (hd, rest) ->
          let rebuild = search rest in
          (fun c -> defcon (Ccons (hd, rebuild c)))
      | Cappend (front, rest) ->
          let rebuild = search rest in
          (fun c -> defcon (Cappend (front, rebuild c)))
      | _ -> generate_error ctxt (Not_tail (c1,c2)); (fun c -> c)
  in
  search c2
;;

(* writeable_field ctxt field c2: a value of type [c2] may be stored in the
 * field unless it is read-only and already initialised; the types must
 * also be equal. *)
let writeable_field ctxt (c1,cap,init) c2 =
  (match (cap, init) with
    (Read, Init) -> generate_error ctxt Readonly
  | _ -> ());
  eqcon ctxt c1 c2
;;

(* For now only allow something to be written over a slot of the same size
   or an 8 byte object to be written over two 4s *)
let rec writeable_stack_offset ctxt i c1 c2 =
  match (whnorm ctxt c1).rcon with
    Ccons (slot, rest) ->
      if i = 0 then begin
        let old_size = sizeof ctxt slot in
        let new_size = sizeof ctxt c2 in
        (* the slot after this one is also a 4-byte slot *)
        let next_is_4byte () =
          match (whnorm ctxt rest).rcon with
            Ccons (next, _) -> sizeof ctxt next = 4
          | _ -> false in
        if old_size <> new_size
           && not (new_size = 8 && old_size = 4 && next_is_4byte ())
        then generate_error ctxt Stack_write_alignment
      end else begin
        let remaining = i - sizeof ctxt slot in
        if remaining >= 0 then writeable_stack_offset ctxt remaining rest c2
        else generate_error ctxt (Bad_offset remaining)
      end
  | _ -> generate_error ctxt (Bad_offset i)
;;

(* EOF: talcon.ml *)
