
open Ast
open Ast_types
open Environment

(* Generalizes a type into a type schema: every type variable that is free in
   the type [a] but not free in the type environment [o] becomes universally
   quantified.
   For example, in the empty environment the type "t1 -> t2" is generalized to
   the schema "for all t1 and t2, t1 -> t2", whereas in the environment
   containing (x, t1) the same type is generalized to "for all t2, t1 -> t2". *)
let generalize (a : typeNC) (o : Environment_types.t) : schema =
  (* candidates for quantification: variables collected from [a] itself *)
  let candidates = Util.HashSet.make() in
  free_vars candidates (Schema ([], a));
  (* variables collected from every schema bound in the environment *)
  let env_vars = Util.HashSet.make() in
  Environment_types.iter (fun _ -> free_vars env_vars) o;
  (* drop the candidates that the environment also mentions *)
  Util.HashSet.remove_all candidates env_vars;
  Schema (Util.HashSet.to_list candidates, a)


(* EQUATIONS, SUBSTITUTIONS, SOLVER *)

(* The equation a = b is represented by a pair of type expressions (a,b) *)

(* An equation system is a list of equations *)

(* A substitution is a solved equation system, i.e. a list of
equations [(X1, A1) ; ... ; (Xn, An)] where the Xi are distinct variables,
the Ai are type expressions, and no Xi appears in any Aj *)

(* Substitution [(X1, A1); ...; (Xn, An)] means "substitute A1 for X1 and
 substitute A2 for X2 and ..." *)

(* Raised (through [type_error]) to report a type error, e.g. an equation
   that [analyze] cannot satisfy. The string is a human-readable message. *)
exception Type_error of string

(* A substitution: an equation system in solved form, as described above. *)
type substitution = (typeNC * typeNC) list
(* The equation a = b, represented as the pair (a, b). *)
type equation = typeNC * typeNC
(* An equation system: a list of equations. *)
type equations = equation list

(* Reports a type error by raising [Type_error s]. It never returns, so it can
   be used in any expression position. *)
let type_error (s : string) = Type_error s |> raise

(* [isVar x] is [true] exactly when [x] is a type variable ([VarType _]). *)
let isVar : typeNC -> bool = function
  | VarType _ -> true
  | _ -> false

(* Occurs check: tests whether the type variable [x] appears anywhere inside
   the type expression [a] (compared with structural equality).
   Recurses through arrows, lists and tuples. Base types and user types are
   treated as containing no variables, matching [substituteExp].
   Raises [Util.Fatal "improper occurs check"] if [x] is not a [VarType]. *)
let occurs_in (a : typeNC) (x : typeNC) : bool =
  (* sanity check - should only be called with variables x. [x] never changes
     during the traversal, so checking once up front is enough. *)
  if not (isVar x) then raise (Util.Fatal "improper occurs check");
  let rec occurs (t : typeNC) : bool =
    match t with
    | VarType _ -> t = x
    | Arrow (t1, t2) -> occurs t1 || occurs t2
    | ListType elt -> occurs elt
    (* Tuples must be searched too: [analyze] decomposes them, so missing them
       here would let 'a = ('a * int) through as a cyclic "solution". *)
    | TupleType ts -> List.exists occurs ts
    | Integer | FloatType | String | Boolean | UnitType | DummyType | Null
    | UserType _ -> false
  in
  occurs a
 
(* Tests whether the equation system [systEqu] is in solved form, i.e. whether
   it is a list [(X1, A1); ...; (Xn, An)] such that:
   - every left-hand side Xi is a type variable;
   - the Xi are pairwise distinct;
   - no Xi occurs in any right-hand side Aj;
   - no equation mentions the [Null] type on either side.
   The checks are combined with [&&] in this order. The occurs check, which
   raises [Util.Fatal] on non-variables, therefore only runs once every Xi is
   known to be a variable. *)
let isSolved (systEqu : equations) : bool =
  let lefts = List.map fst systEqu in
  let rights = List.map snd systEqu in
  let rec has_repeat = function
    | [] -> false
    | y :: later -> List.exists (fun x -> x = y) later || has_repeat later
  in
  let occurs_on_right xi = List.exists (fun ai -> occurs_in ai xi) rights in
  List.for_all isVar lefts
  && not (has_repeat lefts)
  && not (List.exists occurs_on_right lefts)
  && List.for_all (fun (l, r) -> l <> Null && r <> Null) systEqu

(* Substitutes type expression [a] for the type variable [x] in the type
   expression [exp]. Every subterm structurally equal to [x] is replaced.
   Base types and user types are returned unchanged.
   Raises [Util.Fatal "improper substitution"] if [x] is not a [VarType]. *)
let substituteExp (x : typeNC) (a : typeNC) (exp : typeNC) : typeNC =
  (* sanity check - should only be called with variables x. [x] is fixed for
     the whole traversal, so the check is done once rather than per node. *)
  if not (isVar x) then raise (Util.Fatal "improper substitution");
  let rec subst (t : typeNC) : typeNC =
    match t with
    | Integer | FloatType | String | Boolean | UnitType | DummyType | Null
    | UserType _ -> t
    | VarType _ -> if t = x then a else t
    | Arrow (b, c) -> Arrow (subst b, subst c)
    | ListType elt -> ListType (subst elt)
    | TupleType ts -> TupleType (List.map subst ts)
  in
  subst exp

(* Applies the substitution "x := a" to both sides of the equation [equ].
   [x] must be a type variable (checked by [substituteExp]). *)
let substituteEquation (x : typeNC) (a : typeNC) (equ : equation) : equation =
  let (lhs, rhs) = equ in
  (substituteExp x a lhs, substituteExp x a rhs)
  
(* Applies the substitution "x := a" to every equation of the system [sE],
   keeping the equations in their original order. *)
let substituteSystem (x : typeNC) (a : typeNC) (sE : equations) : equations =
  List.map (fun equ -> substituteEquation x a equ) sE
  
(* Applies substitution [s] to the type expression [exp]. The bindings
   (Xi, Ai) are applied one after another in list order, each to the result
   of the previous ones. *)
let applySubst (s : substitution) (exp : typeNC) : typeNC =
  let rec go (current : typeNC) (remaining : substitution) : typeNC =
    match remaining with
    | [] -> current
    | (x, a) :: rest -> go (substituteExp x a current) rest
  in
  go exp s
  
(* Applies substitution [subst] to the schema [Schema (l1, t)].
   What it actually does:
   - every variable in the domain of [subst] is dropped from the quantified
     list [l1], via [Util.list_difference] (presumably "l1 minus domain");
   - the whole of [subst] is then applied to the body [t], bound variables
     included.
   Raises [Util.Fatal "not a substitution"] if some left-hand side of [subst]
   is not a type variable.
   NOTE(review): the textbook (Damas-Milner) definition differs in two ways.
   It leaves bound variables untouched, applying only the bindings whose
   variable is not in [l1]. It also renames bound variables that would capture
   free variables of the substituted types. Neither is done here; confirm
   whether callers rely on the current behaviour before changing it. *)
let applySubstSch (subst : substitution) (Schema (l1, t) : schema) : schema =
  let var_name (v : typeNC) : id =
    match v with
    | VarType name -> name
    | _ -> raise (Util.Fatal "not a substitution")
  in
  (* keep this evaluation order: domain check first, then the body *)
  let domain = List.map (fun (v, _) -> var_name v) subst in
  let still_bound = Util.list_difference l1 domain in
  let body = applySubst subst t in
  Schema (still_bound, body)

(* Applies substitution [s] to each schema of the type environment [env]
   (see [applySubstSch]), using [Environment_types.map]. *)
let applySubstEnv (s : substitution) (env : Environment_types.t) : Environment_types.t =
  Environment_types.map (fun sch -> applySubstSch s sch) env
 
(* Solve an equation system E1, E2, E3, ..., En *)

(* Takes equation E1 (equ) and equation system E2, E3, ..., En (systEqu)
and returns the new equation system obtained after analyzing E1.
   The cases are tried in order:
   - an equation with [Null] on either side is dropped;
   - identical base types, identical variables and identical user types are
     trivially satisfied and dropped;
   - arrows and lists are decomposed into equations between their components,
     which are placed at the front of the system;
   - a variable X against a type A (in either orientation) is eliminated. If X
     occurs in A, a [Type_error] is raised. Otherwise A is substituted for X in
     the rest of the system and (X, A) is appended at the end;
   - tuples of equal length are decomposed componentwise (front of system);
   - anything else is a mismatch and raises [Type_error]. *)
let analyze (equ : equation) (systEqu : equations) : equations =
  match equ with
    (* must come first so that Null never reaches the variable case *)
    | ((Null, _) | (_, Null)) -> systEqu
    | ((Integer, Integer) | (FloatType, FloatType) | (String, String) |
      (Boolean, Boolean) | (UnitType, UnitType) | (DummyType, DummyType)) -> systEqu
    (* X = X is trivial; must precede the general variable case, whose occurs
       check would otherwise reject it *)
    | (VarType x, VarType y) when x = y -> systEqu
    | (Arrow (a,b), Arrow (c,d)) -> (a,c) :: (b,d) :: systEqu
    | (ListType t1, ListType t2) -> (t1, t2) :: systEqu
    (* Or-patterns bind left to right: for two distinct variables the
       left-hand one is the variable that gets eliminated. *)
    | ((VarType x, a) | (a, VarType x)) ->
        if (occurs_in a (VarType x))
        then type_error (Printf.sprintf "conflict involving %s and %s" x (to_string a))
        else (substituteSystem (VarType x) a systEqu) @ [(VarType x, a)]
    | TupleType l1, TupleType l2 when List.length l1 = List.length l2 ->
        (List.combine l1 l2) @ systEqu
    | UserType s1, UserType s2 when s1 = s2 -> systEqu
    | (t1, t2) ->
        type_error (Printf.sprintf "conflict involving %s and %s" (to_string t1) (to_string t2))

(* Returns a solution (a substitution) of the equation system [systEqu].
   While the system is not in solved form (see [isSolved]), its first equation
   is analyzed and the resulting system is solved recursively.
   Raises [Type_error] (via [analyze]) if the system has no solution. *)
let rec solve (systEqu : equations) : substitution =
  match systEqu with
  | [] -> []
  | _ when isSolved systEqu -> systEqu
  | first :: rest -> solve (analyze first rest)

(* Fresh type variable generator. Each call returns [VarType ("'" ^ s)], where
   [s] is the next string drawn from a single [Util.LexStream]. That stream is
   created once, when [newVar] is defined, so all calls share it.
   NOTE(review): freshness relies on the stream never repeating a name —
   presumably it enumerates names in order; confirm in Util. *)
let newVar : unit -> typeNC =
  let names = Util.LexStream.make() in
  fun () ->
    let next_name = Util.LexStream.next names in
    VarType ("'" ^ next_name)