
(* Public interface of the type checker.
   tcheck    : the type of an expression in a given typing context.
   declcheck : checks a list of declarations in order and returns the
               context extended with the bindings they introduce. *)
signature TYPE_CHECK =
sig
  val tcheck :
    TypeContext.context * AbSyn.exp -> AbSyn.typ

  val declcheck :
    TypeContext.context * AbSyn.decl list -> TypeContext.context
end


structure TypeCheck :> TYPE_CHECK =
  struct

    open AbSyn
    open DataTypes
    open TypeContext

    (* Collect the identifiers bound by a pattern, in left-to-right order.
       Duplicates are kept so that patcheck can detect them. *)
    fun vars pat =
      let
        fun concatMap f xs = List.foldl (fn (x, acc) => acc @ f x) [] xs
      in
        case pat of
          Id_p id => [id]
        | DataCon_p (_, SOME inner) => vars inner
        | Tuple_p ps => concatMap vars ps
        | Record_p fs => concatMap (fn (_, q) => vars q) fs
        | (Wild_p | Const_p _ | DataCon_p (_, NONE)) => []
      end

    (* Error helpers built on Error.static. error is used both where a type
       and where a context is expected, so Error.static is presumably
       polymorphic in its result (i.e. it raises) -- confirm in Error. *)
    val error = Error.static
    fun mismatch what = error (String.concat ["Type mismatch (", what, ")"])
    fun unexpected what = error (String.concat ["Unexpected type (", what, ")"])


    (* Returns true iff no label occurs more than once in flds.
       Quadratic: every label is compared with each label after it. *)
    fun distinct_labels (flds : (string * 'a) list) =
      let
        fun go [] = true
          | go ((lbl, _) :: tl) =
              List.all (fn (other, _) => other <> lbl) tl andalso go tl
      in
        go flds
      end

    (* Returns true iff vs has no repeated element (quadratic check). *)
    fun distinct_vars vs =
      case vs of
        [] => true
      | hd :: tl =>
          List.all (fn other => other <> hd) tl andalso distinct_vars tl

    structure S = ListMergeSort

    (* Sort (label, value) pairs by label. S.sort is given a "greater than"
       predicate on labels, so the result is in ascending label order. *)
    fun sort_fields (l : (string * 'a) list) =
      S.sort (fn ((x, _), (y, _)) => String.> (x, y)) l

    (* Operand type expected by a unary operator, given a type hint t
       (presumably the operand's inferred type; the caller is not visible
       here). Neg stays real for real operands and otherwise defaults to
       int; Not expects bool. *)
    fun unopArgTyp (oper, t) =
      case oper of
        Not => DataTyp_t "bool"
      | Neg => (case t of Real_t => Real_t | _ => Int_t)

    (* Result type of a unary operator. For both Neg and Not it coincides
       with the operand type computed by unopArgTyp. *)
    fun unopResTyp (oper, t) = unopArgTyp (oper, t)

    (* Expected operand types, as a pair type, for a binary operator, given
       a type hint t (presumably an operand's type; the caller is not visible
       here). Arithmetic and ordering operators use reals when t is real and
       ints otherwise. Equality also accepts strings, chars and references,
       and falls back to ints. Concat (^) always takes strings. *)
    fun binopArgTyp (oper, t) =
      let
        fun both ty = Tuple_t [ty, ty]
        val numeric = case t of Real_t => Real_t | _ => Int_t
      in
        case oper of
          (Plus | Times | Minus
           | GreaterThan | GreaterThanEq | LessThan | LessThanEq) =>
            both numeric
        | Equal =>
            (case t of
               (String_t | Char_t | Ref_t _) => both t
             | _ => both Int_t)
        | Concat => both String_t
      end

    (* Result type of a binary operator, given the same type hint t that
       binopArgTyp receives. Arithmetic operators yield real when t is real
       and int otherwise. Comparisons and equality yield bool, and Concat (^)
       yields string. *)
    fun binopResTyp (Plus,Real_t) = Real_t
      | binopResTyp (Plus,_) = Int_t
      (* Real multiplication yields real; everything else defaults to int,
         mirroring binopArgTyp (Times,Real_t). *)
      | binopResTyp (Times,Real_t) = Real_t
      | binopResTyp (Times,_) = Int_t
      | binopResTyp (Minus,Real_t) = Real_t
      | binopResTyp (Minus,_) = Int_t
      | binopResTyp (Equal,_) = DataTyp_t ("bool")
      | binopResTyp (GreaterThan,_) = DataTyp_t ("bool")
      | binopResTyp (LessThan,_) = DataTyp_t ("bool")
      | binopResTyp (GreaterThanEq,_) = DataTyp_t ("bool")
      | binopResTyp (LessThanEq,_) = DataTyp_t ("bool")
      | binopResTyp (Concat,_) = String_t

    (* The parser emits Id_t for every type identifier. lookup_types walks a
       parsed type and replaces each Id_t with its definition in env.
       An unknown type name is reported via unexpected. Any constructor the
       parser is not expected to produce (anything not listed below) raises
       Fail. *)
    fun lookup_types env t =
      let
        val resolve = lookup_types env
      in
        case t of
          Id_t name =>
            (case lookup_type (env, name) of
               NONE => unexpected name
             | SOME def => def)
        | Tuple_t ts => Tuple_t (List.map resolve ts)
        | Record_t fields =>
            Record_t (List.map (fn (lbl, ft) => (lbl, resolve ft)) fields)
        | Fn_t (dom, rng) => Fn_t (resolve dom, resolve rng)
        | Ref_t inner => Ref_t (resolve inner)
        | (Int_t | Real_t | String_t | Char_t) => t
        | _ => raise Fail "lookup_types found a type not generated by the parser"
      end

    (* Structurally compare two record types' field lists, using f to compare
       field types. Both lists are sorted by label, so source field order is
       irrelevant. They match only when they have the same number of fields,
       corresponding fields carry the same label, and f holds for each pair
       of corresponding field types. *)
    fun compare_rec_types (f,t1,t2) =
      (length (t1) = length (t2)) andalso
      let
        val t1' = sort_fields (t1)
        val t2' = sort_fields (t2)
      in
        (* The label from t1' must equal the label from t2'. Comparing id with
           itself would accept records whose labels differ. *)
        ListPair.all (fn ((id,t),(id',t')) => (id=id') andalso f (t,t'))
                     (t1',t2')
      end

    (* Structural type equality. Id_t must already have been resolved by
       lookup_types, so meeting one here is an internal bug (Fail). Record
       types are compared label-insensitively to field order via
       compare_rec_types. *)
    fun same_types (a, b) =
      case (a, b) of
        ((Id_t _, _) | (_, Id_t _)) =>
          raise Fail "Bug: same_types should not see Id_t"
      | ((Int_t, Int_t) | (Real_t, Real_t)
         | (String_t, String_t) | (Char_t, Char_t)) => true
      | (Tuple_t xs, Tuple_t ys) =>
          length xs = length ys andalso ListPair.all same_types (xs, ys)
      | (Record_t xs, Record_t ys) => compare_rec_types (same_types, xs, ys)
      | (Fn_t (dom, rng), Fn_t (dom', rng')) =>
          same_types (dom, dom') andalso same_types (rng, rng')
      | (DataTyp_t n, DataTyp_t n') => n = n'
      | (Ref_t x, Ref_t y) => same_types (x, y)
      | _ => false

    (* tcheck (env, exp): compute the type of exp in context env, reporting a
       static error (via error / mismatch / unexpected) if exp is ill-typed.
       Mutually recursive with declcheck, patcheck and patenv.
       NOTE(review): the final catch-all rejects every expression form not
       matched below. The operator/tuple/record/function helpers defined
       above (unopArgTyp, binopArgTyp, Tuple_t, Record_t, Fn_t) suggest that
       those forms are still to be implemented. The AbSyn constructor names
       are not visible in this file -- TODO confirm and add the cases. *)
    fun tcheck (env,exp) =
      case exp of
        (* Literal constants have the corresponding base type. *)
        Const_e (Int_c _) => Int_t
      | Const_e (Real_c _) => Real_t
      | Const_e (String_c _) => String_t
      | Const_e (Char_c _) => Char_t
        (* Variables: look the identifier up in env; unbound is an error. *)
      | Id_e (id) => (case (lookup_var (env,id)) of
                          NONE => error ("Unbound identifier "^id)
                        | SOME (v) => v)
        (* Nullary data constructor: its type is given by consResTyp. *)
      | DataCon_e (id,NONE) => consResTyp (id)
        (* Constructor applied to an argument: the argument's type must be
           consArgTyp id; the whole expression then has type consResTyp id. *)
      | DataCon_e (id,SOME (e)) =>
        let val t = tcheck (env,e)
        in
          if same_types (t,consArgTyp (id)) then consResTyp (id)
          else mismatch "data constructor"
        end
        (* let d in e: check the declarations, then the body in the
           extended context. *)
      | Let_e (d,e) =>
        let val env' = declcheck (env,d)
        in  tcheck (env',e) end 

      (* You can assume the code matching  Case_e, Assign_e, Deref_e, and Ref_e
         expressions to be written correctly *)
        (* case e of p1 => e1 | ...: every pattern is checked against the
           scrutinee's type t, and every arm body must have the same type as
           the first arm's body, which is the type of the whole expression.
           Each arm's bindings extend the outer env only (loop is always
           given the outer env, not the previous arm's). *)
      | Case_e (e,(p1,e1)::rest) => let
          val t = tcheck (env,e)
          val env1 = patcheck (env,t,p1)
          val t1 = tcheck (env1,e1)
          fun loop (t1,env,[]) = t1
            | loop (t1,env,(p,e')::rest) = let
                val env' = patcheck (env,t,p)
                val t' = tcheck (env',e')
              in
		if same_types (t1,t') then loop (t1,env,rest)
                else mismatch "case"
              end
        in
          loop (t1,env,rest)
        end
      | Case_e (e,[]) => error "Case expression without rules..."
        (* !e: e must have a reference type t ref; the result is t. *)
      | Deref_e (e) => (case tcheck (env,e) of
                            Ref_t (t) => t
                          | _ => mismatch "!")
        (* e1 := e2: e1 must be t ref and e2 must be t; the result is the
           empty tuple type (unit). A non-reference e1 is "unexpected". *)
      | Assign_e (e1,e2) => (case (tcheck (env,e1),tcheck (env,e2))
                               of (Ref_t (t1),t2) =>
                                     if (same_types (t2,t1))
                                      then Tuple_t([])
                                    else mismatch ":="
                                | _ => unexpected ":=")
        (* ref e: a reference to a value of e's type. *)
      | Ref_e (e) => Ref_t (tcheck (env,e))

      (* You must remove the following line and match all unmatched cases *)
      | _ => error "Expression unrecognized by type checker"


    (* declcheck (env, decls): check a list of declarations in order, threading
       the context so that each declaration sees the bindings of the earlier
       ones. Returns env extended with all of them.
       - val p = e : compute e's type, then bind p's variables against it.
       - fun name (arg : arg_typ) : ret_typ = e : resolve the annotated types,
         bind name (so the body may recurse) and arg for checking the body, and
         require the body's type to equal ret_typ. Only name, not arg, is
         visible to later declarations. *)
    and declcheck (env,decl) =
      case decl of
        [] => env
      | Val_d (p,e)::decl2 => let
          val t = tcheck (env,e)
          val env2 = patcheck (env,t,p)
        in
          declcheck(env2,decl2)
        end
      | Fun_d ({name,arg,arg_typ,ret_typ},e)::decl2 => let
          val arg_typ = lookup_types env arg_typ
          val ret_typ = lookup_types env ret_typ
          (* The function has type arg_typ -> ret_typ: Fn_t takes the argument
             type first, as in lookup_types/same_types' Fn_t (t1,t2). *)
          val env' = add_var (env,name,Fn_t (arg_typ,ret_typ))
          (* The body additionally sees the parameter. *)
          val env'' = add_var (env',arg,arg_typ)
          val t3 = tcheck (env'',e)
        in
          if same_types (ret_typ,t3)
            then declcheck(env',decl2)
          else mismatch "fun decl"
        end

    (* patcheck (env, t, p): match pattern p against type t and return env
       extended with the variables p binds. A pattern may not bind the same
       identifier twice. (Assumed correct per the original assignment note.) *)
    and patcheck (env,t,p) =
      if not (distinct_vars (vars p))
        then error "Duplicate identifiers in pattern"
      else union_env (env, patenv (t,p))

    (* You can assume this function to be correctly written *)
    (* patenv (t, p): build a fresh context holding only the variables bound
       by pattern p when matched against a value of type t. Reports a
       mismatch error when the shape of p is incompatible with t. *)
    and patenv (t,p) =
      case p of
        (* A wildcard binds nothing. *)
        Wild_p => empty_env
        (* A variable binds to the whole type t. *)
      | Id_p (id) => add_var (empty_env,id,t)
        (* Constant pattern: its type must equal t; binds nothing.
           NOTE(review): the constant is wrapped as Int_c, so Const_p
           apparently carries only integer constants -- confirm in AbSyn. *)
      | Const_p (c) => let
          val t2 = tcheck (empty_env,Const_e (Int_c (c)))
        in
          if same_types (t2,t)
            then empty_env
          else mismatch "constant pattern"
        end
        (* Nullary constructor: its result type must equal t. *)
      | DataCon_p (id,NONE) => let
          val t2 = consResTyp (id)
        in
          if same_types (t2,t)
            then empty_env
          else mismatch "data constructor pattern"
        end
        (* Constructor with an argument pattern: the constructor's result type
           must equal t, and the sub-pattern is matched against its
           argument type. *)
      | DataCon_p (id,SOME (p)) => let
          val t2 = consResTyp (id)
          val t3 = consArgTyp (id)
        in
          if same_types (t,t2)
            then patenv (t3,p)
          else mismatch "data constructor pattern"
        end
        (* Tuple pattern: t must be a tuple of the same arity; components are
           matched pairwise and their bindings merged. *)
      | Tuple_p (ps) =>
          (case t of
             Tuple_t (ts) => if length (ts) = length (ps)
                                 then union_envs (ListPair.map patenv
                                                  (ts,ps))
                               else mismatch "tuple pattern"
           | _ => mismatch "tuple pattern")
        (* Record pattern: t must be a record with the same number of fields.
           Both field lists are sorted by label here, then matched pairwise;
           a label mismatch at any position is an error. *)
      | Record_p (idps) => (* note... type is sorted at this point *)
          (case t of
             Record_t (idts) =>
               if length (idts) = length (idps)
                 then let
                   val idps' = sort_fields (idps)
                   val idts' = sort_fields (idts)
                 in
                   union_envs
                     (ListPair.map (fn ((id,t),(id',p)) =>
                                    if (id=id') then patenv (t,p)
                                    else mismatch "record pattern")
                     (idts',idps'))
                 end
               else mismatch "record pattern"
           | _ => mismatch "record pattern")
  end









