
open Ast
open Ast_types
open Environment
open Equations_types

(* Shorthand for the type-schema environment type used throughout this
   module. *)
type envType = Environment_types.t

(* Global flag: a mutable boolean cell, initially [true]. It is not read in
   this part of the file; presumably callers consult it to decide whether
   type inference should run at all (TODO confirm). *)
let type_inference : bool ref = ref true

(* Gets the result type of a function type once [n] arguments have been
   applied, i.e. strips up to [n] leading arrows. For example:
     getReturnType (a -> b -> c -> d) 1 = b -> c -> d
     getReturnType (a -> b -> c -> d) 2 = c -> d
     getReturnType (a -> b -> c -> d) 3 = d
   If [n <= 0] the type is returned unchanged. If a non-arrow type is
   reached before [n] arrows have been stripped, that type is returned. *)
let rec getReturnType (t : typeNC) (n : int) =
  match t with
  | Arrow (_, result) when n > 0 -> getReturnType result (n - 1)
  | _ -> t
  
(* alpha-convert a schema.
   Returns the type part of the schema with every quantified variable
   replaced by a fresh variable obtained from [newVar]. For example, if the
   schema is "for all a,b , a -> b -> Integer", the result is
   "c -> d -> Integer", where c and d are fresh variables.
   Quantified variables are renamed one at a time, in the order in which
   they are listed in the schema. *)
let alpha (Schema (vars, typ) : schema) : typeNC =
  let rec rename acc = function
    | [] -> acc
    | v :: rest ->
        let fresh = newVar () in
        rename (substituteExp (VarType v) fresh acc) rest
  in
  rename typ vars
                                                                      
(* gts = Get Type and Substitution.
   [gts t o] infers the type of expression [t] in the type environment [o]
   and returns a triple:
   - the type of [t];
   - a type environment (every case below returns [o] itself: bindings
     introduced inside [t] do not escape);
   - the substitution collected while solving the type equations generated
     by inference, built by concatenating the [solve] result of each step.
   Each intermediate substitution is applied to the environment before the
   next sub-expression is typed.
   Raises Type_error if expression [t] is not well-typed. *)
let rec gts (t : expr) (o : envType) : typeNC * envType * substitution =
  match t with
  | Var x -> (alpha (Environment_types.lookup x o), o, [])
  | Int _ -> (Integer, o, [])
  | Float _ -> (FloatType, o, [])
  | Strg _ -> (String, o, [])
  | Bool _ -> (Boolean, o, [])

  (* Untyped functions: a named parameter gets a fresh monomorphic type
     variable, a () parameter has type unit. Several parameters are handled
     by recursing on the remaining ones, producing a curried arrow type. *)
  | Fun ([Var x], body) ->
      let freshVar = newVar () in
      let (typ, _, subst) =
        gts body (Environment_types.bind x (Schema ([], freshVar)) o) in
      let ro = solve subst in
      (Arrow (applySubst ro freshVar, typ), o, ro)
  | Fun (Var x :: tl, body) ->
      let freshVar = newVar () in
      let (typ, _, subst) =
        gts (Fun (tl, body))
          (Environment_types.bind x (Schema ([], freshVar)) o) in
      let ro = solve subst in
      (Arrow (applySubst ro freshVar, typ), o, ro)
  | Fun ([Unit], body) ->
      let (typ, _, subst) = gts body o in
      let ro = solve subst in
      (Arrow (UnitType, typ), o, ro)
  | Fun (Unit :: tl, body) ->
      let (typ, _, subst) = gts (Fun (tl, body)) o in
      let ro = solve subst in
      (Arrow (UnitType, typ), o, ro)
  | Fun (_, _) -> type_error "Unknown function arguments"

  (* Typed functions: parameter [i] carries annotation [i] ([Null] means
     "no annotation", which falls back to a fresh variable). *)
  | FunType ([Var x], [Null], body) -> gts (Fun ([Var x], body)) o
  | FunType ([Var x], [typex], body) ->
      let (typ, _, subst) =
        gts body (Environment_types.bind x (Schema ([], typex)) o) in
      let ro = solve subst in
      (Arrow (typex, typ), o, ro)
  | FunType (Var x :: tl, Null :: ttl, body) ->
      let freshVar = newVar () in
      let (typ, _, subst) =
        gts (FunType (tl, ttl, body))
          (Environment_types.bind x (Schema ([], freshVar)) o) in
      let ro = solve subst in
      (Arrow (applySubst ro freshVar, typ), o, ro)
  | FunType (Var x :: tl, typex :: ttl, body) ->
      let (typ, _, subst) =
        gts (FunType (tl, ttl, body))
          (Environment_types.bind x (Schema ([], typex)) o) in
      let ro = solve subst in
      (Arrow (typex, typ), o, ro)
  | FunType ([Unit], [(Null | UnitType)], body) -> gts (Fun ([Unit], body)) o
  | FunType (Unit :: tl, (Null | UnitType) :: ttl, body) ->
      let (typ, _, subst) = gts (FunType (tl, ttl, body)) o in
      let ro = solve subst in
      (Arrow (UnitType, typ), o, ro)
  | FunType (Unit :: _, _, _) -> type_error "type of () must be unit"
  (* Also reached when the parameter and annotation lists differ in
     length. *)
  | FunType _ -> type_error "Unknown function arguments"

  (* Local definitions: type the definition via [gtsd], generalize its type
     (presumably let-polymorphism over variables not free in the
     environment; see [generalize]), then type the body with the name
     bound. *)
  | Let (id, args, e1, e2)
  | LetType (id, _, args, _, e1, e2)
  | Letrec (id, args, e1, e2)
  | LetrecType (id, _, args, _, e1, e2) ->
      let d =
        (match t with
         | Let _ -> Def (id, args, e1)
         | LetType (_, typeid, _, typeargs, _, _) ->
             DefType (id, typeid, args, typeargs, e1)
         | Letrec _ -> Defrec (id, args, e1)
         | LetrecType (_, typeid, _, typeargs, _, _) ->
             DefrecType (id, typeid, args, typeargs, e1)
         | _ -> type_error "impossible case")
      in
      let (typeid, oP, s) = gtsd d o in
      let ro = solve s in
      let oS = applySubstEnv ro oP in
      let gen = generalize typeid oS in
      let oT = Environment_types.bind id gen oS in
      let (typ, _, subst) = gts e2 oT in
      let roP = solve subst in
      (typ, o, ro @ roP)

  (* Application: the argument [u] is typed first, then the function [t];
     the function type is unified with (argument type -> fresh result). *)
  | App (t, u) ->
      let (typ1, _, subst1) = gts u o in
      let ro = solve subst1 in
      let (typ2, _, subst2) = gts t (applySubstEnv ro o) in
      let roP = solve subst2 in
      let freshVar = newVar () in
      let sigma = solve [(typ2, Arrow (applySubst roP typ1, freshVar))] in
      (applySubst sigma freshVar, o, ro @ roP @ sigma)

  (* Conditional: boolean condition, both branches must have one type. *)
  | If (t, u, v) ->
      let (typ1, _, subst1) = gts t o in
      let ro = solve subst1 in
      let sigma = solve [(typ1, Boolean)] in
      let oP = applySubstEnv (ro @ sigma) o in
      let (typ2, _, subst2) = gts u oP in
      let roP = solve subst2 in
      let (typ3, _, subst3) = gts v (applySubstEnv roP oP) in
      let roS = solve subst3 in
      let sigmaP = solve [(applySubst roS typ2, typ3)] in
      (applySubst sigmaP typ3, o, ro @ sigma @ roP @ roS @ sigmaP)

  (* Loop: boolean condition, unit body, unit result. *)
  | While (t, u) ->
      let (typ1, _, subst1) = gts t o in
      let ro = solve subst1 in
      let sigma = solve [(typ1, Boolean)] in
      let oP = applySubstEnv (ro @ sigma) o in
      let (typ2, _, subst2) = gts u oP in
      let roP = solve subst2 in
      let sigmaP = solve [(typ2, UnitType)] in
      (UnitType, o, ro @ sigma @ roP @ sigmaP)

  (* Counted loop: integer bounds, loop variable [x] bound to Integer in a
     unit body. *)
  | For (x, e1, e2, e3) ->
      let (typ1, _, subst1) = gts e1 o in
      let ro = solve subst1 in
      let sigma = solve [(typ1, Integer)] in
      let oP = applySubstEnv (ro @ sigma) o in
      let (typ2, _, subst2) = gts e2 oP in
      let roP = solve subst2 in
      let sigmaP = solve [(typ2, Integer)] in
      let oS = applySubstEnv (roP @ sigmaP) oP in
      let (typ3, _, subst3) =
        gts e3 (Environment_types.bind x (Schema ([], Integer)) oS) in
      let roS = solve subst3 in
      let sigmaS = solve [(typ3, UnitType)] in
      (UnitType, o, ro @ sigma @ roP @ sigmaP @ roS @ sigmaS)

  (* Assignment: the new value must have the variable's type. *)
  | Assg (x, e) ->
      let (typ1, _, subst1) = gts (Var x) o in
      let ro = solve subst1 in
      let oP = applySubstEnv ro o in
      let (typ2, _, subst2) = gts e oP in
      let roP = solve subst2 in
      let sigma = solve [(applySubst roP typ1, typ2)] in
      (UnitType, o, ro @ roP @ sigma)

  (* Sequence: the first expression must be unit, the result is the
     second's type. *)
  | Seq (e1, e2) ->
      let (typ1, _, subst1) = gts e1 o in
      let ro = solve subst1 in
      let sigma = solve [(typ1, UnitType)] in
      let oP = applySubstEnv (ro @ sigma) o in
      let (typ2, _, subst2) = gts e2 oP in
      let roP = solve subst2 in
      (typ2, o, ro @ sigma @ roP)

  (* List literals. Judging from this case, the elements of a literal are
     encoded as left-nested [Seq] nodes with [e3] the last element (TODO
     confirm against the parser). Adjacent elements are unified pairwise. *)
  | ListMake (Seq (e2, e3)) ->
      let (typ1, _, subst1) = gts e3 o in
      let ro = solve subst1 in
      let oP = applySubstEnv ro o in
      (match e2 with
       | Seq (_, e5) ->
           let (typ2, _, subst2) = gts e5 oP in
           let roP = solve subst2 in
           let sigma = solve [(applySubst roP typ1, typ2)] in
           let (typ3, _, subst3) = gts (ListMake e2) oP in
           let roS = solve subst3 in
           (typ3, o, ro @ roP @ sigma @ roS)
       | _ ->
           let (typ2, _, subst2) = gts e2 oP in
           let roP = solve subst2 in
           let sigma = solve [(applySubst roP typ1, typ2)] in
           (ListType (applySubst sigma typ2), o, ro @ roP @ sigma))
  | ListMake e1 ->
      let (typ, _, subst) = gts e1 o in
      let ro = solve subst in
      (ListType typ, o, ro)

  | Cons (e1, e2) ->
      let (typ1, _, subst1) = gts e1 o in
      let ro = solve subst1 in
      let oP = applySubstEnv ro o in
      let (typ2, _, subst2) = gts e2 oP in
      let roP = solve subst2 in
      let sigma = solve [(ListType (applySubst roP typ1), typ2)] in
      (applySubst sigma typ2, o, ro @ roP @ sigma)
  | List [] -> (ListType (newVar ()), o, [])
  | List (a :: t) -> gts (Cons (a, List t)) o

  (* Comparison operators. The operand type is checked AFTER unifying both
     operands: a type variable on the right may just have been bound to the
     left operand's type, which must be legal too. (Checking the raw right
     type accepted [1.0 = x] but rejected [x = 1.0].) *)
  | Eq (e1, e2) | Neq (e1, e2) ->
      let (typ1, _, subst1) = gts e1 o in
      let ro = solve subst1 in
      let oP = applySubstEnv ro o in
      let (typ2, _, subst2) = gts e2 oP in
      let roP = solve subst2 in
      let sigma = solve [(typ2, applySubst roP typ1)] in
      (match applySubst sigma typ2 with
       | Integer | Boolean | ListType _ | VarType _ | String ->
           (Boolean, o, ro @ roP @ sigma)
       | _ ->
           type_error
             "can only test booleans, lists, integers, or strings for equality")
  | Lt (e1, e2) | Le (e1, e2) | Gt (e1, e2) | Ge (e1, e2) ->
      let (typ1, _, subst1) = gts e1 o in
      let ro = solve subst1 in
      let oP = applySubstEnv ro o in
      let (typ2, _, subst2) = gts e2 oP in
      let roP = solve subst2 in
      let sigma = solve [(typ2, applySubst roP typ1)] in
      (match applySubst sigma typ2 with
       | Integer | Boolean | VarType _ -> (Boolean, o, ro @ roP @ sigma)
       | _ -> type_error "can only compare ints to ints and bools to bools")

  (* Arithmetic operators on integers (right operand typed first). *)
  | Plus (e1, e2) | Minus (e1, e2) | Mul (e1, e2) | Div (e1, e2) ->
      let (typ1, _, subst1) = gts e2 o in
      let ro = solve subst1 in
      let sigma = solve [(typ1, Integer)] in
      let (typ2, _, subst2) = gts e1 (applySubstEnv (ro @ sigma) o) in
      let roP = solve subst2 in
      let sigmaP = solve [(typ2, Integer)] in
      (Integer, o, ro @ sigma @ roP @ sigmaP)

  (* Arithmetic operators on floats (right operand typed first). *)
  | PlusF (e1, e2) | MinusF (e1, e2) | MulF (e1, e2) | DivF (e1, e2) ->
      let (typ1, _, subst1) = gts e2 o in
      let ro = solve subst1 in
      let sigma = solve [(typ1, FloatType)] in
      let (typ2, _, subst2) = gts e1 (applySubstEnv (ro @ sigma) o) in
      let roP = solve subst2 in
      let sigmaP = solve [(typ2, FloatType)] in
      (FloatType, o, ro @ sigma @ roP @ sigmaP)

  (* String concatenation. *)
  | Concat (e1, e2) ->
      let (typ1, _, subst1) = gts e1 o in
      let ro = solve subst1 in
      let sigma = solve [(typ1, String)] in
      let oP = applySubstEnv (ro @ sigma) o in
      let (typ2, _, subst2) = gts e2 oP in
      let roP = solve subst2 in
      let sigmaP = solve [(typ2, String)] in
      (String, o, ro @ sigma @ roP @ sigmaP)

  (* Boolean operators. *)
  | Not e1 ->
      let (typ, _, subst) = gts e1 o in
      let ro = solve subst in
      let sigma = solve [(typ, Boolean)] in
      (Boolean, o, ro @ sigma)
  | And (e1, e2) | Or (e1, e2) ->
      let (typ1, _, subst1) = gts e2 o in
      let ro = solve subst1 in
      let sigma = solve [(typ1, Boolean)] in
      let (typ2, _, subst2) = gts e1 (applySubstEnv (ro @ sigma) o) in
      let roP = solve subst2 in
      let sigmaP = solve [(typ2, Boolean)] in
      (Boolean, o, ro @ sigma @ roP @ sigmaP)

  | Unit -> (UnitType, o, [])
  | Native (_, _, t) -> (t, o, [])
  | Dummy -> (DummyType, o, [])

(* gtsd = Get Type and Substitution in a Declaration.
   [gtsd d o] infers the type of declaration [d] in the type schema
   environment [o] and returns a triple:
   - the declaration's type;
   - the type schema environment (every case returns [o] unchanged; the
     caller binds the declared name itself);
   - the substitution solving the type equations generated by inference.
   For declarations with parameters, a declared type [typef] other than
   [Null] is the type of the fully applied result, not of the function
   itself. [Def]/[Defrec] are handled as their typed forms with every
   annotation set to [Null].
   Raises Type_error if [d] is not well-typed. *)
and gtsd (d : declaration) (o : envType) =
  match d with
  | Def (f, args, e) ->
      gtsd (DefType (f, Null, args, List.map (fun _ -> Null) args, e)) o
  | Defrec (f, args, e) ->
      gtsd (DefrecType (f, Null, args, List.map (fun _ -> Null) args, e)) o
  | DefType (_, typef, [], _, e) ->
      let (t, _, subst) = gts e o in
      let ro = solve subst in
      if typef = Null then (t, o, ro)
      else
        let sigma = solve [(typef, t)] in
        (typef, o, ro @ sigma)
  | DefType (_, typef, args, typeargs, e) ->
      let (t, _, subst) = gts (FunType (args, typeargs, e)) o in
      let ro = solve subst in
      if typef = Null then (t, o, ro)
      else
        (* The annotation constrains the result once all [args] are
           applied. *)
        let sigma = solve [(typef, getReturnType t (List.length args))] in
        (applySubst sigma t, o, ro @ sigma)
  | DefrecType (f, Null, [], _, e) ->
      (* Unannotated recursive value: [f] is visible in [e] with a fresh
         monomorphic type, later unified with [e]'s type. *)
      let freshVar = newVar () in
      let oP = Environment_types.bind f (Schema ([], freshVar)) o in
      let (t, _, subst) = gts e oP in
      let ro = solve subst in
      let aux = applySubst ro freshVar in
      let sigma = solve [(t, aux)] in
      (applySubst sigma t, o, ro @ sigma)
  | DefrecType (f, typef, [], _, e) ->
      let oP = Environment_types.bind f (Schema ([], typef)) o in
      let (t, _, subst) = gts e oP in
      let ro = solve subst in
      let sigma = solve [(typef, t)] in
      (typef, o, ro @ sigma)
  | DefrecType (f, typef, [x], typeargs, e) ->
      (* Inside its own body [f] has type (fresh argument -> result), where
         the result is the annotation or a fresh variable. *)
      let freshVar = newVar () in
      let s = if typef = Null then newVar () else typef in
      let oP = Environment_types.bind f (Schema ([], Arrow (freshVar, s))) o in
      let (t, _, subst) = gts (FunType ([x], typeargs, e)) oP in
      let ro = solve subst in
      let aux = applySubst ro (Arrow (freshVar, s)) in
      let sigma = solve [(t, aux)] in
      (applySubst sigma t, o, ro @ sigma)
  | DefrecType (f, typef, a :: rest, typeargs, e) ->
      (* Several parameters: rewrite as a one-parameter definition whose
         body is the function of the remaining parameters. The result of
         that one-parameter function is a curried function taking EVERY
         remaining argument, so it gets one arrow per element of [rest]
         (a single arrow only fit two-parameter functions and broke
         annotated definitions with three or more parameters). *)
      (match typeargs with
       | ta :: trest ->
           let s = if typef = Null then newVar () else typef in
           let resultType =
             List.fold_right (fun _ acc -> Arrow (newVar (), acc)) rest s
           in
           gtsd
             (DefrecType (f, resultType, [a], [ta], FunType (rest, trest, e)))
             o
       | [] -> type_error "Unknown function arguments")

(* Infers the type of expression [t] in environment [o] and returns it
   together with the environment produced by [gts]; the substitution is
   discarded. Raises Type_error if [t] is not well-typed. *)
let getType (t : expr) (o : envType) : typeNC * envType =
  let typ, env, _subst = gts t o in
  (typ, env)

(* Infers the type of declaration [d] in environment [o] and returns it
   together with the environment produced by [gtsd]; the substitution is
   discarded. Raises Type_error if [d] is not well-typed. *)
let getTypeDecl (d : declaration) (o : envType) : typeNC * envType =
  let typ, env, _subst = gtsd d o in
  (typ, env)