open Ast
open Ast_types
open Environment
open Equations_types

(* Type environment threaded through inference in this module. It is an
   alias for [Environment_types.t]: identifiers are bound to type schemas
   with [Environment_types.bind] / read with [Environment_types.lookup], and
   the same value also provides type definitions via
   [Environment_types.get_typedefs_list]. *)
type envType = Environment_types.t

(* Global mutable flag, initially [true]. It is not read anywhere in this
   part of the file; presumably it toggles whether type inference is
   performed at all — NOTE(review): confirm with the callers. *)
let type_inference = ref true

(* [getReturnType t n] strips the first [n] argument types off the curried
   function type [t] and returns whatever remains.
   For example, with t = a -> b -> c -> d:
     getReturnType t 1 = b -> c -> d
     getReturnType t 2 = c -> d
     getReturnType t 3 = d
   When [n <= 0], or when [t] has fewer than [n] arrows, the remaining type
   is returned unchanged. *)
let rec getReturnType (t : typ) (n : int) =
  match t with
  | TArrow (_, result) when n > 0 -> getReturnType result (n - 1)
  | _ -> t
  
(* Alpha-convert a schema: return the list of fresh variables and the type
   part of the schema in which every quantified variable has been replaced
   by its fresh counterpart (the i-th fresh variable stands for the i-th
   quantified one). For example, for the schema
   "for all a, b . a -> b -> TInt" this returns
   "([c; d], c -> d -> TInt)" where c and d are fresh. *)
let alpha (Schema (vars, typ) : schema) : id list * typ =
  let rename_one old_var (renamed, body) =
    let fresh = newId () in
    (fresh :: renamed, substituteExp (TVar old_var) (TVar fresh) body)
  in
  List.fold_right rename_one vars ([], typ)

(* Given a type constructor [c] and a list [td] of type definitions of the
   form [(type_id, (type_params, constructors))], find the first definition
   whose constructor list binds [c]. Returns:
   - the id of that type,
   - fresh type variables standing for the type's parameters,
   - the constructor's argument type (if it takes one), rewritten over
     those fresh variables.

   Only the [Not_found] raised by the constructor lookup itself moves the
   search on to the next definition; any exception escaping the
   alpha-conversion of the argument type propagates unchanged instead of
   being misreported as "Constructor ... undefined".

   @raise Type_error if no definition in [td] declares [c].

   TODO: should [td] be the entire list up the stack, or just the current
   map? *)
let rec type_constructor_helper (c : id) td : id * id list * typ option =
  match td with
  | [] -> raise (Type_error ("Constructor " ^ c ^ " undefined"))
  | (id, (typl, l)) :: tl ->
    (match List.assoc c l with
     | exception Not_found -> type_constructor_helper c tl
     | None -> (id, List.map (fun _ -> newId ()) typl, None)
     | Some argtyps ->
       let typl1, typ1 = alpha (Schema (typl, argtyps)) in
       (id, typl1, Some typ1))

(* Resolve constructor [c] against the type definitions visible in [env].
   [Environment_types.get_typedefs_list] yields a name (presumably the
   resolved constructor name — confirm) together with the list of type
   definitions to search; the search itself is delegated to
   [type_constructor_helper].
   @raise Type_error if the constructor is not defined. *)
let type_constructor (c : id) (env : envType) : id * id list * typ option =
  let ctor, typedefs = Environment_types.get_typedefs_list env c in
  type_constructor_helper ctor typedefs

(* Apply the substitution [a] (an association list from type-variable ids
   to types) to [t]: every [TVar i] bound in [a] is replaced by its image;
   unbound variables and base types are left as they are. The replacement
   types are not themselves substituted again (single pass). *)
let rec subst_type (t : typ) (a : (idtype * typ) list) =
  let recur inner = subst_type inner a in
  match t with
  | TVar i ->
    (match List.assoc i a with
     | image -> image
     | exception Not_found -> TVar i)
  | TInt | TFloat | TString | TBool | TUnit | TVoid
  | TDummy | TNull | TSymbol -> t
  | TArrow (dom, cod) -> TArrow (recur dom, recur cod)
  | TTuple parts -> TTuple (List.map recur parts)
  | TUser (name, params) -> TUser (name, List.map recur params)

(* [provided_solver e] is [true] exactly when [e] is one of the solvers
   built into the language: the variables [constructor] and [gaussian], or
   an application of [iterator], [constriterator] or [appears]. [gtsd]
   types corecursive definitions that use such a solver like plain
   recursive ones (or with dedicated rules), rather than as user solvers. *)
let provided_solver solver =
  match solver with
  | EVar ("constructor" | "gaussian") -> true
  | EApp (EVar ("iterator" | "constriterator" | "appears"), _) -> true
  | _ -> false

(* gts_pattern = Get Type and Substitution of a Pattern.
   Infers the type of pattern [p] in environment [o] and returns a triple:
   the pattern's type, the environment extended with every variable the
   pattern binds (each bound monomorphically, i.e. with an empty schema, to
   a fresh type variable), and the substitution produced while solving the
   pattern's type equations.
   Raises Type_error when a constructor is used with the wrong arity. *)
let rec gts_pattern (p : pattern) (o : envType)
  : typ * envType * substitution =
  match p with
  | PVar x -> (* it is binding a new variable x! *)
    let freshEVar = newEVar () in
    let typ = Schema ([], freshEVar) in
    freshEVar, (Environment_types.bind x typ o), []
  (* just slightly adapted for the gts calls *)
  | PInt _ -> (TInt, o, [])
  | PFloat _ -> (TFloat, o, [])
  | PString _ -> (TString, o, [])
  | PBool _ -> (TBool, o, [])
  | PUnit -> (TUnit, o, [])
  | PTuple l ->
    (* Components are typed left to right, threading the environment so
       that each component sees the bindings of the previous ones, and the
       solution found for later components is applied to earlier types. *)
    let rec iterate oi = function
      | [] -> [], oi, []
      | ei :: tl ->
        let (ti, oi1, substi) = gts_pattern ei oi in
        let ri = solve substi in
        let oj = applySubstEnv ri oi1 in
        let (lj, oi2, substj) = iterate oj tl in
        let rj = solve substj in
        (applySubst rj ti) :: lj, oi2, ri @ rj in
    let l2, o2, subst2 = iterate o l in
    TTuple l2, o2, subst2
  | PInj (c, eop) ->
    let (idtyp, typvars, typop) = type_constructor c o in
    (match eop, typop with
     | None, None ->
       TUser (idtyp, List.map (fun x -> TVar x) typvars), o, []
     | _, None ->
       raise (Type_error (
           "Constructor " ^ c ^ " does not expect an argument"))
     | None, _ ->
       raise (Type_error (
           "Constructor " ^ c ^ " expects an argument"))
     | Some e1, Some typop1 -> (* just check that e1 has type t *)
       let (typ1, o1, subst1) = gts_pattern e1 o in
       let ro = solve subst1 in
       let sigma = solve [(applySubst ro typ1, typop1)] in
       let typ0 = TUser (idtyp, List.map (fun x -> TVar x) typvars) in
       (applySubst sigma typ0, o1, ro @ sigma))
  | PUnknown p1 ->
    let (typ1, o1, subst1) = gts_pattern p1 o in
    let ro = solve subst1 in
    let sigma = solve [(applySubst ro typ1, TSymbol)] in
    (newEVar (), o1, ro @ sigma) (* the type it matches is not known *)
  | PUnderscore -> newEVar (), o, []

(* gts_patterns types the arms [(pattern, body)] of a match / function.
   Returns [(tp, tb, o, s)]: [tp] is the type shared by all the patterns,
   [tb] the type shared by all the bodies, [o] the (unchanged) input
   environment and [s] the substitution accumulated over ALL arms. Each
   body is typed in the environment extended by its own pattern's
   variables.
   Raises Type_error on an empty list of arms. *)
and gts_patterns (pl : (pattern * expr) list) (o : envType) :
  typ * typ * envType * substitution =
  match pl with
  | [] -> raise (Type_error "Unexpected empty list of patterns")
  | (pi, ei) :: tl ->
    let (typ1, o1, subst1) = gts_pattern pi o in
    let r1 = solve subst1 in
    let (typ2, _, subst2) = gts ei (applySubstEnv r1 o1) in
    let r2 = solve subst2 in
    let typ1' = applySubst r2 typ1 in
    (* The remaining arms must agree with this one on both the pattern
       type and the body type. Their own substitution [r3] must be kept as
       well: it carries the constraints those arms put on variables of the
       enclosing environment. *)
    let r3, sigma =
      (match tl with
       | [] -> [], []
       | _ :: _ ->
         let (typ3, typ4, _, subst3) = gts_patterns tl o in
         let r3 = solve subst3 in
         r3, solve [(typ1', typ3); (typ2, typ4)]) in
    (applySubst sigma typ1', applySubst sigma typ2, o, r1 @ r2 @ r3 @ sigma)

(* gts = Get Type and Substitution
   This function takes an expression t and a type environment o and
   returns a triple: the expression's type if it has one,
   the new type environment, and the substitution solution of the type
   equations generated by type inference.
   Raises Type_error if expression t is not well-typed. *)
and gts (t : expr) (o : envType) : typ * envType * substitution =
  match t with
  | EVar x -> snd (alpha (Environment_types.lookup x o)), o, []
  | EInt _ -> (TInt, o, [])
  | EFloat _ -> (TFloat, o, [])
  | EString _ -> (TString, o, [])
  | EBool _ -> (TBool, o, [])

  (* functions: each parameter is bound with an empty schema to a fresh
     type variable (or is unit for a () parameter); curried
     multi-parameter functions are typed one parameter at a time. *)
  | EFun ([EVar x], t) ->
    let freshEVar = newEVar () in
    let (typ, _, subst) =
      gts t (Environment_types.bind x (Schema ([], freshEVar)) o) in
    let ro = solve subst in
    (applySubst ro (TArrow (freshEVar, typ)), o, ro)
  | EFun (EVar x :: tl, t) ->
    let freshEVar = newEVar () in
    let (typ, _, subst) =
      gts (EFun (tl, t))
        (Environment_types.bind x (Schema ([], freshEVar)) o) in
    let ro = solve subst in
    (applySubst ro (TArrow (freshEVar, typ)), o, ro)
  | EFun ([EUnit], t) ->
    let (typ, _, subst) = gts t o in
    let ro = solve subst in
    (applySubst ro (TArrow (TUnit, typ)), o, ro)
  | EFun (EUnit :: tl, t) ->
    let (typ, _, subst) = gts (EFun (tl, t)) o in
    let ro = solve subst in
    (applySubst ro (TArrow (TUnit, typ)), o, ro)
  | EFun (_, _) -> type_error "Unknown function arguments"

  (* typed functions: an argument annotated with TNull is unannotated and
     gets a fresh type variable instead. *)
  | EFunType ([EVar x], [TNull], t) -> gts (EFun ([EVar x], t)) o
  | EFunType ([EVar x], [typex], t) ->
    let (typ, _, subst) =
      gts t (Environment_types.bind x (Schema ([], typex)) o) in
    let ro = solve subst in
    (applySubst ro (TArrow (typex, typ)), o, ro)
  | EFunType (EVar x :: tl, TNull :: ttl, t) ->
    let freshEVar = newEVar () in
    let (typ, _, subst) =
      gts (EFunType (tl, ttl, t))
        (Environment_types.bind x (Schema ([], freshEVar)) o) in
    let ro = solve subst in
    (applySubst ro (TArrow (freshEVar, typ)), o, ro)
  | EFunType (EVar x :: tl, typex :: ttl, t) ->
    let (typ, _, subst) =
      gts (EFunType (tl, ttl, t))
        (Environment_types.bind x (Schema ([], typex)) o) in
    let ro = solve subst in
    (applySubst ro (TArrow (typex, typ)), o, ro)
  | EFunType ([EUnit], [(TNull | TUnit)], t) -> gts (EFun ([EUnit], t)) o
  | EFunType (EUnit :: tl, (TNull | TUnit) :: ttl, t) ->
    let (typ, _, subst) = gts (EFunType (tl, ttl, t)) o in
    let ro = solve subst in
    (applySubst ro (TArrow (TUnit, typ)), o, ro)
  | EFunType (EUnit :: _, _, _) -> type_error "type of () must be unit"
  | EFunType _ -> type_error "Unknown function arguments"
  | EFunCorec _ ->
    type_error "Direct declaration of corecursive function not allowed"

  | EFunction pl -> (* cf. EMatch *)
    let (typ2, typ3, _, subst2) = gts_patterns pl o in
    let r2 = solve subst2 in
    (TArrow (applySubst r2 typ2, applySubst r2 typ3), o, r2)

  | ELetType (id, _, _, _, _, e2)
  | ELetrecType (id, _, _, _, _, e2)
  | ELetcorecType (_, id, _, _, _, _, e2) ->
    (* Re-express the binding as a declaration, type it with gtsd,
       generalize the result w.r.t. the environment (via [generalize]) and
       type the body e2 with [id] bound to that schema. *)
    let d =
      (match t with
       | ELetType (id, typeid, args, typeargs, e1, _) ->
         Def (id, typeid, args, typeargs, e1)
       | ELetrecType (id, typeid, args, typeargs, e1, _) ->
         Defrec (id, typeid, args, typeargs, e1)
       | ELetcorecType (s, id, typeid, args, typeargs, e1, _) ->
         Defcorec (s, id, typeid, args, typeargs, e1)
       | _ -> type_error "impossible case") in
    let (typeid, oP, s) = gtsd d o in
    let ro = solve s in
    let oS = applySubstEnv ro oP in
    let gen = generalize typeid oS in
    let oT = Environment_types.bind id gen oS in
    let (typ, _, subst) = gts e2 oT in
    let roP = solve subst in
    (applySubst roP typ, o, ro @ roP)

  | EApp (t, u) ->
    (* The argument is typed first, then the function; the function's type
       must unify with (argument type -> fresh result type). *)
    let (typ1, _, subst1) = gts u o in
    let ro = solve subst1 in
    let (typ2, _, subst2) = gts t (applySubstEnv ro o) in
    let roP = solve subst2 in
    let freshEVar = newEVar () in
    let sigma = solve [(typ2, TArrow (applySubst roP typ1, freshEVar))] in
    (applySubst sigma freshEVar, o, ro @ roP @ sigma)

  | EIf (t, u, v) ->
    let (typ1, _, subst1) = gts t o in
    let ro = solve subst1 in
    let sigma = solve [(typ1, TBool)] in
    let oP = applySubstEnv (ro @ sigma) o in
    let (typ2, _, subst2) = gts u oP in
    let roP = solve subst2 in
    let (typ3, _, subst3) = gts v (applySubstEnv roP oP) in
    let roS = solve subst3 in
    let sigmaP = solve [(applySubst roS typ2, typ3)] in
    (applySubst sigmaP typ3, o, ro @ sigma @ roP @ roS @ sigmaP)

  | EWhile (t, u) ->
    let (typ1, _, subst1) = gts t o in
    let ro = solve subst1 in
    let sigma = solve [(typ1, TBool)] in
    let oP = applySubstEnv (ro @ sigma) o in
    let (typ2, _, subst2) = gts u oP in
    let roP = solve subst2 in
    let sigmaP = solve [(typ2, TUnit)] in
    (TUnit, o, ro @ sigma @ roP @ sigmaP)

  | EFor (x, e1, e2, e3) ->
    let (typ1, _, subst1) = gts e1 o in
    let ro = solve subst1 in
    let sigma = solve [(typ1, TInt)] in
    let oP = applySubstEnv (ro @ sigma) o in
    let (typ2, _, subst2) = gts e2 oP in
    let roP = solve subst2 in
    let sigmaP = solve [(typ2, TInt)] in
    let oS = applySubstEnv (roP @ sigmaP) oP in
    let (typ3, _, subst3) =
      gts e3 (Environment_types.bind x (Schema ([], TInt)) oS) in
    let roS = solve subst3 in
    let sigmaS = solve [(typ3, TUnit)] in
    (TUnit, o, ro @ sigma @ roP @ sigmaP @ roS @ sigmaS)

  | EAssign (x, e) ->
    let (typ1, _, subst1) = gts (EVar x) o in
    let ro = solve subst1 in
    let oP = applySubstEnv ro o in
    let (typ2, _, subst2) = gts e oP in
    let roP = solve subst2 in
    let sigma = solve [(applySubst roP typ1, typ2)] in
    (TUnit, o, ro @ roP @ sigma)

  | ESeq (e1, e2) ->
    let (typ1, _, subst1) = gts e1 o in
    let ro = solve subst1 in
    let sigma = solve [(typ1, TUnit)] in
    let oP = applySubstEnv (ro @ sigma) o in
    let (typ2, _, subst2) = gts e2 oP in
    let roP = solve subst2 in
    (applySubst roP typ2, o, ro @ sigma @ roP)

  | EMatch (e1, pl) -> (* cf. EFunction *)
    let (typ1, _, subst1) = gts e1 o in
    let r1 = solve subst1 in
    let (typ2, typ3, _, subst2) = gts_patterns pl o in
    let r2 = solve subst2 in
    let sigma = solve [(applySubst r1 typ1, applySubst r2 typ2)] in
    (applySubst sigma typ3, o, r1 @ r2 @ sigma)

  | ETuple l ->
    (* Components are typed left to right; the solution found for later
       components is applied to the earlier components' types (as in the
       PTuple case of gts_pattern), so the tuple type reflects every
       constraint discovered while typing it. *)
    let rec iterate oi = function
      | [] -> [], oi, []
      | ei :: tl ->
        let (ti, _, substi) = gts ei oi in
        let ri = solve substi in
        let oj = applySubstEnv ri oi in
        let (lj, _, substj) = iterate oj tl in
        let rj = solve substj in
        applySubst rj ti :: lj, oi, ri @ rj in
    let l2, o2, subst2 = iterate o l in
    TTuple l2, o2, subst2

  | EInj (c, eop) ->
    let (idtyp, typvars, typop) = type_constructor c o in
    (match eop, typop with
     | None, None ->
       TUser (idtyp, List.map (fun x -> TVar x) typvars), o, []
     | _, None ->
       raise (Type_error (
           "Constructor " ^ c ^ " does not expect an argument"))
     | None, _ ->
       raise (Type_error (
           "Constructor " ^ c ^ " expects an argument"))
     | Some e1, Some typop1 ->
       let (typ1, _, subst1) = gts e1 o in
       let ro = solve subst1 in
       let sigma = solve [(typ1, typop1)] in
       let typ0 = TUser (idtyp, List.map (fun x -> TVar x) typvars) in
       (applySubst sigma typ0, o, ro @ sigma))

  | EBinop (b, e1, e2) ->
    let (typ1, _, subst1) = gts e1 o in
    let ro = solve subst1 in
    let oP = applySubstEnv ro o in
    let (typ2, _, subst2) = gts e2 oP in
    let roP = solve subst2 in
    (match b with
     | BEq | BNeq ->
       let sigma = solve [(typ2, applySubst roP typ1)] in
       (match typ2 with
        | TInt | TBool | TVar _ | TString | TTuple _ | TFloat | TUser _ ->
          (TBool, o, ro @ roP @ sigma)
        | _ -> type_error
                 "can only test booleans, lists, integers, or strings for equality")
     | BLt | BLe | BGt | BGe ->
       let sigma = solve [(typ2, applySubst roP typ1)] in
       (match typ2 with
        | TInt | TBool | TFloat | TVar _ | TString ->
          (TBool, o, ro @ roP @ sigma)
        | _ -> type_error "can only compare ints to ints and bools to bools")
     | BPlus | BMinus | BMul | BDiv | BMod ->
       let sigma = solve [(typ1, TInt); (typ2, TInt)] in
       (TInt, o, ro @ roP @ sigma)
     | BPlusF | BMinusF | BMulF | BDivF ->
       let sigma = solve [(typ1, TFloat); (typ2, TFloat)] in
       (TFloat, o, ro @ roP @ sigma)
     | BConcat ->
       let sigma = solve [(typ1, TString); (typ2, TString)] in
       (TString, o, ro @ roP @ sigma)
     | BAnd | BOr ->
       let sigma = solve [(typ1, TBool); (typ2, TBool)] in
       (TBool, o, ro @ roP @ sigma))

  | ENot e1 ->
    let (typ, _, subst) = gts e1 o in
    let ro = solve subst in
    let sigma = solve [(typ, TBool)] in
    (TBool, o, ro @ sigma)

  | EUnit -> (TUnit, o, [])
  | ENative (_, _, t) -> (t, o, [])
  | EDummy -> (TDummy, o, [])
  | EUnknown _ -> raise (Type_error "Unexpected unknown")
  | ESymbol _ -> raise (Type_error "Unexpected symbol")

(* gtsd = Get Type and Substitution in a Declaration
   This function takes a declaration d and a type schema environment o and
   returns a three-element tuple: the declaration's type if it has one,
   the new type schema environment and the substitution solution of the
   type equations generated by type inference.
   In Def/Defrec, [typef] annotates the result type (what remains after
   all the arguments are applied); TNull means "no annotation".
   Raises Type_error if d is not well-typed. *)
and gtsd (d : declaration) (o : envType) : typ * envType * substitution =
  match d with
  (* TODO: How to get inference types for modules? *)
  | Module (_, _) -> failwith "Idk what to do here"
  | Def (_, typef, [], _, e) ->
    let (t, _, subst) = gts e o in
    let ro = solve subst in
    if typef = TNull
    then (t, o, ro)
    else
      let sigma = solve [(typef, t)] in
      (typef, o, ro @ sigma)
  | Def (_, typef, args, typeargs, e) ->
    let (t, _, subst) = gts (EFunType (args, typeargs, e)) o in
    let ro = solve subst in
    if typef = TNull
    then (t, o, ro)
    else
      let sigma = solve [(typef, getReturnType t (List.length args))] in
      (applySubst sigma t, o, ro @ sigma)
  | Defrec (f, TNull, [], _, e) ->
    let freshEVar = newEVar () in
    let oP = Environment_types.bind f (Schema ([], freshEVar)) o in
    let (t, _, subst) = gts e oP in
    let ro = solve subst in
    let aux = applySubst ro freshEVar in
    let sigma = solve [(t, aux)] in
    (applySubst sigma t, o, ro @ sigma)
  | Defrec (f, typef, [], _, e) ->
    let oP = Environment_types.bind f (Schema ([], typef)) o in
    let (t, _, subst) = gts e oP in
    let ro = solve subst in
    let sigma = solve [(typef, t)] in
    (typef, o, ro @ sigma)
  | Defrec (f, typef, [x], typeargs, e) ->
    let freshEVar = newEVar () in
    let s = if typef = TNull then newEVar () else typef in
    let oP = Environment_types.bind f (Schema ([], TArrow (freshEVar, s))) o in
    let (t, _, subst) = gts (EFunType ([x], typeargs, e)) oP in
    let ro = solve subst in
    let aux = applySubst ro (TArrow (freshEVar, s)) in
    let sigma = solve [(t, aux)] in
    (applySubst sigma t, o, ro @ sigma)
  | Defrec (f, typef, a :: rest, typeargs, e) ->
    (* Curry: [f a b c = e] is typed as [f a = fun b c -> e]. The result
       type of that one-argument definition is the type of the remaining
       function: one fresh argument type per remaining parameter, followed
       by the declared result type. A single arrow is only right for two
       parameters; with more, an annotated result type could never unify. *)
    (match typeargs with
     | [] -> type_error "Unknown function arguments"
     | ta :: tas ->
       let arg_vars = List.map (fun _ -> newEVar ()) rest in
       let s = if typef = TNull then newEVar () else typef in
       let rest_typ =
         List.fold_right (fun v acc -> TArrow (v, acc)) arg_vars s in
       gtsd (Defrec (f, rest_typ, [a], [ta], EFunType (rest, tas, e))) o)

  | Defcorec (EApp (EVar "iterator", init), f, typef, args, typeargs, e)
  | Defcorec (EApp (EVar "constriterator", init), f, typef, args, typeargs, e)
  | Defcorec (EApp (EVar "appears", init), f, typef, args, typeargs, e) ->
    let (t1t2, o1, ro1) =
      gtsd (Defrec (f, typef, args, typeargs, e)) o in
    let (tsolver, _, subst) = gts init o in
    let t1 = newEVar () and t2 = newEVar () in
    let sigma = solve [(applySubst subst t1t2, TArrow (t1, t2));
                       (applySubst subst tsolver, t2)] in
    let typf = applySubst sigma (TArrow (t1, t2)) in
    let o2 = Environment_types.bind f (Schema ([], typf)) o1 in
    (typf, o2, ro1 @ subst @ sigma)
  | Defcorec (solver, f, typef, args, typeargs, e) when provided_solver solver ->
    gtsd (Defrec (f, typef, args, typeargs, e)) o
  | Defcorec (solver, f, typef, args, typeargs, e) (* user solver *) ->
    let (t1t2, o1, ro1) =
      gtsd (Defrec (f, typef, args, typeargs, e)) o in
    let (tsolver, _, subst) = gts solver o in
    let t1 = newEVar () and t2 = newEVar () and t3 = newEVar () in
    let sigma = solve (
        match tsolver with
        | TTuple [tunk; tfresh; tsolve] ->
          (* TODO: we don't check that t2 is of the form var expr *)
          let tvar = newEVar () in
          (* tvar == var, t2 == var expr, t3 == t *)
          [(applySubst subst t1t2, TArrow (t1, t2));
           (tunk, TArrow (tvar, t2));
           (tfresh, TArrow (TUnit, tvar));
           (tsolve, TArrow (tvar,
                            TArrow (TUser ("list", [TTuple [tvar; t2]]),
                                    t3)))]
        | _ ->
          [(applySubst subst t1t2, TArrow (t1, t2));
           (applySubst ro1 tsolver,
            TArrow (TSymbol,
                    TArrow (TUser ("list", [TTuple [TSymbol; t2]]),
                            t3)))]
      ) in
    let typf = applySubst sigma (TArrow (t1, t3)) in
    (* PROBLEM: now f has two possibly different types: the type it
       has before going through the solver and the type after: just weird *)
    (* Rebinding f *)
    let o2 = Environment_types.bind f (Schema ([], typf)) o1 in
    (typf, o2, ro1 @ subst @ sigma)
  | Typedef _ -> assert false (* never called on this; would not make sense *)

(* Infer the type of expression [t] in environment [o]. Returns the type
   and the environment computed by [gts]; the substitution is discarded.
   Raises Type_error if [t] is not well-typed. *)
let getType (t : expr) (o : envType) : typ * envType =
  match gts t o with
  | (ty, env, _) -> (ty, env)

(* Infer the type of declaration [d] in environment [o]. Returns the type
   and the (possibly extended) environment computed by [gtsd]; the
   substitution is discarded. Raises Type_error if [d] is not well-typed. *)
let getTypeDecl (d : declaration) (o : envType) : typ * envType =
  match gtsd d o with
  | (ty, env, _) -> (ty, env)
