 (*
  Part of the Mini-SML interpreter that actually performs computations.
*)


structure Evaluator :> sig
  (* Evaluates an expression in an environment; returns the resulting value
     together with the type determined for it at run time. *)
  val evaluate:
  AbstractSyntax.exp * Environment.env -> Environment.value * AbstractSyntax.typ
  (* Extends an environment with a list of declarations, processed in order. *)
  val evaluateDeclare:
  AbstractSyntax.decl list * Environment.env -> Environment.env
  (* Evaluates thunks until the value is no longer a thunk. *)
  val forceValue:
  Environment.value*AbstractSyntax.typ -> Environment.value*AbstractSyntax.typ
end
= struct

  open AbstractSyntax
  open Environment

  (* Run-time error reporter. It is used in positions of every result type,
     so it presumably never returns (raises). *)
  val  err = Error.runtime

  (* Variables that control the evaluator. *)
  val debug:  bool = false                    (* print a trace of every step *)
  datatype scoping_style = STATIC | DYNAMIC
  val scoping: scoping_style = STATIC         (* closure env vs. caller env  *)
  datatype evaluation_style = LAZY | EAGER
  val evaluation: evaluation_style = EAGER    (* thunked vs. evaluated args  *)


  (* Provides printValue, printExp and printEnv used by the debug traces. *)
  open PrintDebug


  (*
     predefined(name, (arg, argt)): applies the built-in function 'name' to an
     already evaluated argument 'arg' whose run-time type is 'argt'.
     A single argument is passed as such; zero or several arguments arrive
     as a tuple.
  *)
  fun predefined (name: string, (arg, argt): value * typ): value * typ =
  (
    if debug then
      print("(\n predefined: " ^ name  ^ "\n"
          ^ " argument = \n"
          ^ printValue((arg, argt), 4) ^ "\n"
          ^ ")\n\n")
    else ();

    case (name, arg, argt) of

      ("hd",  List_v (v::vs), List_t(base))  => (v, base)
    | ("hd",    List_v _,     _           )  => err "list has no head"
    | ("hd",    _,            _           )  => err "'hd' needs a list"

    | ("tl",  List_v (v::vs), List_t _    )  => (List_v vs, argt)
    | ("tl",    List_v _,     _           )  => err "list has no tail"
    | ("tl",    _,            _           )  => err "'tl' needs a list"

    | ("null",  List_v [],    List_t _    )  => (Bool_v true,  Bool_t)
    | ("null",  List_v(v::_), List_t _    )  => (Bool_v false, Bool_t)
    | ("null",  _,            _           )  => err("'null' needs a list")

    | ("print", String_v s ,  String_t    )  => (print s;
                                                 (Tuple_v [], Tuple_t []))
    | ("print", _,            _           )  => err "'print' needs a string"

    | ("fail",  String_v s ,  String_t    )  => Error.fail s
    | ("fail",  _,            _           )  => err "'fail' requires a string"

    | ("explode", String_v s, String_t    )  => (
        List_v (map (fn c =>Char_v c) (String.explode s)),
        List_t  Char_t)
    | ("explode", _,          _           )  => err "'explode' needs a string"

    | ("implode", List_v vs,  List_t Char_t) => (
        let
          val clist = map (fn v =>(case v of
                                     Char_v c => c
                                   | _        => err ("'implode' needs " ^
                                                      "a char list"))) vs
        in
          String_v (String.implode clist)
        end,
        String_t)
      (* The literal [] has element type Undef_t; imploding it yields "". *)
    | ("implode", List_v [],  List_t Undef_t) => (String_v "", String_t)
    | ("implode", _,          _)             => err "implode needs a char list"

    | ("ncat",  Tuple_v sl,   Tuple_t tl  )  =>
        if List.all (fn t => case t of String_t => true | _ => false) tl
        then
          (String_v (foldl (fn (sv, cs) => case sv of
                                             String_v s => cs ^ s
                                           | _ => err "internal error [10]")
                           ""
                           sl),
           String_t)
        else
          err "'ncat' needs only strings"
      (*
         If a function has a single argument, that is transmitted as such,
         not as a tuple. Zero or more than one argument results in tuples.
      *)
    | ("ncat", String_v _,    String_t   )  => (arg, argt)
    | ("ncat", _,             _          )  => err "'ncat' needs only strings"

      (*
         Stmtlst evaluates all its arguments in order, from left to right,
         and it returns the value of the last expression. Only meaningful
         when statements before the last are interesting only because of
         their side effects.
      *)
    | ("stmtlst", Tuple_v [], Tuple_t [])   => (Tuple_v [],  Tuple_t [])
    | ("stmtlst", Tuple_v vl, Tuple_t tl)   => (List.last vl, List.last tl)
    | ("stmtlst", _,          _         )   => (arg, argt)

    | _                                     => err "internal error [01]"
  )

  (*
     specialForm(name, expr, en): evaluates the special form 'name' applied
     to the UNEVALUATED argument expression 'expr' in environment 'en'.
     Each form decides itself which parts of 'expr' to evaluate, and when.
  *)
  fun specialForm (name: string, expr: exp, en: env): value * typ =
  (
    if debug then
      print("(\n special form: " ^ name ^ "\n"
          ^ " unevaluated argument = \n"
          ^ printExp(expr, 4)
          ^ " environment = \n"
          ^ printEnv(en, 4)
          ^ ")\n\n")
    else ();

    case name of

      (*
         if3(cond, thenB, elseB): if eval(cond) = true then eval(thenB)
                                                       else eval(elseB)

         Note that in the dynamic typechecking setting that we employ, it is
         not possible to determine whether the 'then' and 'else' branches
         return values of the same type without evaluating both. Thus in this
         version of Mini-SML we can legally write statements like this:

        >> if3(2 = 2, "true", 1/0)
        "true": string
      *)
      "if3" =>
         (case expr of
            Tuple_e([cond, thenE, elseE]) =>
             (* The condition is forced, as for the built-in 'if'. *)
             (case forceValue (evaluate(cond, en)) of
                (Bool_v true,  Bool_t) => (* evaluate 'then' branch *)
                                          evaluate(thenE, en)
              | (Bool_v false, Bool_t) => (* evaluate 'else' branch *)
                                          evaluate(elseE, en)
              | _ => err "first argument of if3 must be boolean")
          | _ => err "incorrect argument number for if3; should be 3")

      (*
        lazylet(var, e1, e2) is e2 with var replaced by e1

        >> lazylet("x", z + 3, let z:int = 3 in x end)
        6: int

        The first argument must be an expression that evaluates to a
        string (a variable name). Thus expressions like

        lazylet("x" ^ "y", z + 3, let z:int = 3 in xy end)

        are acceptable, but

        lazylet(x, z + 3, let z:int = 3 in xy end)

        is not, unless x evaluates to string "xy".

        The binding is stored as Dyn_v e1: e1 is evaluated each time the
        variable is looked up, in the environment of the lookup.
      *)
    | "lazylet" =>
         (case expr of
            Tuple_e([var, e1, e2]) =>
              (case forceValue (evaluate(var, en)) of
                 (String_v name, String_t) =>
                    evaluate(e2, insertBinding(
                                   name,
                                   (Dyn_v e1, Undef_t),
                                   en))
               | _ => err "first argument of 'lazylet' must evaluate to a name")
            | _ => err "incorrect argument number for lazylet; should be 3")

      (*
        This is a version of lazylet that does not evaluate its first
        argument, which must be a variable name. Thus

        lazylet2(x, z + 3, let z:int = 3 in x end)

        is legal.
      *)

    | "lazylet2" =>
         (case expr of
            Tuple_e([var, e1, e2]) =>
              (case var of
                 Id_e name => evaluate(e2, insertBinding(
                                             name,
                                             (Dyn_v e1, Undef_t),
                                             en))
               | _ => err "first argument of 'lazylet2' must be a variable name")
            | _ => err "incorrect argument number for lazylet2; should be 3")

      (*
        lazylet3(var, e1, e2) evaluates e2 with var bound to the value of e1

        >> lazylet3("temp", 3, temp + 4)
        7: int
        >> lazylet3("x", 1, lazylet("y", 2, x + y))
        3: int

        This special form is not truly 'lazy' (see lazylet above). This form's
        laziness consists only in delayed evaluation of its arguments. In fact,
        this is simple 'let' implemented as a special form.
      *)
    | "lazylet3" =>
         (case expr of
            Tuple_e([var, e1, e2]) =>
              (case forceValue (evaluate(var, en)) of
                 (String_v name, String_t) =>
                    evaluate(e2, insertBinding(name, evaluate(e1, en), en))
               | _ => err "first argument of 'lazylet3' must be a name")
            | _ => err "incorrect argument number for lazylet3; should be 3")

       (*
         ncat2: takes 0, 1, or more string arguments, and concatenates them

         Contrast the implementation of this special form with that of the
         predefined function ncat. Which is easier to implement? Why?
       *)
     | "ncat2" =>
         let
           (* A single argument is not wrapped in a tuple. *)
           val args = case expr of
                        Tuple_e elst => elst
                      | _            => [expr]
           (* Evaluates one argument and appends it to the accumulated text. *)
           fun append (e, cs) =
             case forceValue (evaluate(e, en)) of
               (String_v s, String_t) => cs ^ s
             | _ => err "'ncat2' takes only string arguments"
         in
           (String_v (foldl append "" args), String_t)
         end
      (*
         Lookup gets the value bound to the given identifier in the
         current environment. A delayed (lazylet) binding is evaluated in
         the current environment, exactly as for a plain identifier.
       *)
     | "lookup" =>
         (case forceValue (evaluate (expr, en)) of
            (String_v s, String_t) =>
              (case lookupBinding (s, en) of
                 SOME (Dyn_v delayed, _) => evaluate (delayed, en)
               | SOME v                  => v
               | NONE                    => err ("unbound id "^s^
                                                 " in lookup"))
          | _ => err "argument of 'lookup' should be a string")

     | _     => err "internal error [09]"
  )

  (*
     Computes a value from expressions that might contain thunks: the thunk's
     expression is evaluated in its captured environment (again and again
     while the result is itself a thunk), and the type recorded with the
     thunk is unified with the type actually produced.
  *)
  and forceValue (v1: value, t1: typ): value * typ =
    case v1 of
      Thunk_v(exp1, env1) =>
        let
          val (v2, t2) = forceValue(evaluate(exp1, env1))
        in
          (v2, unifyTypes(t1, t2))
          handle TypeUnification => err "actual type does not match declaration"
        end
    | _ => (v1, t1)

  (*
     evaluate(ex, en): evaluates expression 'ex' in environment 'en' and
     returns its value together with its run-time type.
  *)
  and evaluate (ex: exp, en: env): value * typ =
  (
    if debug then
      print("(\n evaluate\n"
          ^ " expression = \n"
          ^ printExp(ex, 4)
          ^ " environment = \n"
          ^ printEnv(en, 4)
          ^ ")\n\n")
    else ();

    case ex of
      Int_c i               => (Int_v i,    Int_t)
    | Real_c r              => (Real_v r,   Real_t)
    | Bool_c b              => (Bool_v b,   Bool_t)
    | Char_c c              => (Char_v c,   Char_t)
    | String_c s            => (String_v s, String_t)
    | Id_e id               => (case lookupBinding (id, en) of
                                 NONE     => err ("unbound variable " ^ id)
                                 (* lazylet binding: evaluate it here, now *)
                               | SOME((Dyn_v(delayed), _))
                                          => evaluate(delayed, en)
                               | SOME v   => v)
    | If_e (test, e1, e2)   =>
         (case forceValue (evaluate(test, en)) of
            (Bool_v b, Bool_t)=> evaluate(if b then e1 else e2, en)
          | _                 => err ("'if' condition must be boolean"))
    | Let_e   (dlist, ex)   => evaluate     (ex, evaluateDeclare (dlist, en))
    | Apply_e (e1, e2)      => evaluateApply(e1, e2, en)
    | Unop_e  (uop, ex)     => evaluateUnop (uop, forceValue(evaluate(ex, en)))
      (*
         'andalso' and 'orelse' short-circuit, as in SML: the right operand
         is evaluated only when the left one does not decide the result.
      *)
    | Binop_e (e1, AndAlso, e2) =>
         (case forceValue (evaluate(e1, en)) of
            (Bool_v false, _)     => (Bool_v false, Bool_t)
          | left as (Bool_v _, _) =>
              evaluateBinop(AndAlso, left, forceValue(evaluate(e2, en)))
          | _                     => err "type error (andalso)")
    | Binop_e (e1, OrElse, e2) =>
         (case forceValue (evaluate(e1, en)) of
            (Bool_v true, _)      => (Bool_v true, Bool_t)
          | left as (Bool_v _, _) =>
              evaluateBinop(OrElse, left, forceValue(evaluate(e2, en)))
          | _                     => err "type error (orelse)")
    | Binop_e (e1, bop, e2) => evaluateBinop(bop, forceValue(evaluate(e1, en)),
                                                  forceValue(evaluate(e2, en)))
    | Tuple_e elist         =>
         let
           (* Components are evaluated from left to right. *)
           val (vs, ts) = ListPair.unzip (map (fn e => evaluate (e, en)) elist)
         in
          (Tuple_v vs, Tuple_t ts)
         end

    | Ith_e (i, ex)         =>
         (case forceValue (evaluate (ex, en)) of
           (Tuple_v vs, Tuple_t tlst)  =>
              let
                val len = List.length(vs)
              in
                if i<1 orelse i>len then err "projection operator out of range"
                else (List.nth (vs, i - 1), List.nth(tlst, i - 1))
              end
         | _                           => err "projection from non-tuple")
    | List_e elist          =>
         let
           val (vs, ts) = ListPair.unzip (map (fn e => evaluate (e, en)) elist)
         in
           (*
             The empty list has a dummy type that can later be instantiated
             using context information.
           *)
           (List_v vs,
            List_t (foldl (fn (ta, tb) => unifyTypes (ta, tb)) Undef_t ts))
           handle TypeUnification => err "typewise inhomogenous list"
         end
    | Fn_e (args, t, body)  =>
           let
             val (al, atl) = ListPair.unzip args
           in
             (Fn_v(al, en, body, NONE),
              Fn_t(Tuple_t atl, t))
           end
  )

  (*
     evaluateApply(e1, e2, encrt): evaluates the application of e1 to e2 in
     the caller's environment 'encrt'. Special forms receive e2 unevaluated,
     predefined functions receive its forced value, and user functions are
     applied as follows:
       - more actual than formal arguments is an error;
       - fewer actual than formal arguments yields a closure over the
         remaining formals (currying);
       - exactly matching arguments evaluate the body (EAGER) or return a
         thunk for it (LAZY).
  *)
  and evaluateApply (e1: exp, e2: exp, encrt: env): value * typ =
  (
    if debug then
      print("(\n evaluate-apply\n"
          ^ " expression 1 (function)=\n"
          ^ printExp(e1, 4)
          ^ " expression 2 (argument)=\n"
          ^ printExp(e2, 4)
          ^ " environment =\n"
          ^ printEnv(encrt, 4)
          ^ ")\n\n")
    else ();

    let
      (* fc = function, fa = formal arg, a = actual arg, t = type, l = list *)
      val (fc, fct)  = forceValue (evaluate(e1, encrt))
    in
      case fc of
        SpecForm_v name             => specialForm(name, e2, encrt)
      | Predef_v   name             => predefined(
                                         name,
                                         forceValue(evaluate(e2, encrt)))
      | Fn_v(fal, env, body, name)  =>
         let
          (* Are we using static or dynamic scoping? *)
          (* NOTE(review): under DYNAMIC scoping the environment stored in a
             curried closure (holding the arguments supplied so far) is
             ignored on the next call. *)
          val en = case scoping of
                     STATIC  => env
                   | DYNAMIC => encrt
          (* "Evaluate" (maybe now, maybe later) the function's arguments. *)
          val (a,  at)   =
            (case evaluation of
               EAGER => evaluate (e2, encrt)
             | LAZY  =>
                (case e2 of
                  Tuple_e a2  => (Tuple_v(map (fn a3 => Thunk_v(a3, encrt)) a2),
                                  Tuple_t(map (fn a3 => Undef_t) a2))
                  | _         => (Thunk_v(e2, encrt), Undef_t)))
          (* Transfer values & types of actual parameters into lists. *)
          val (al, atl) = case (a, at) of
                (Tuple_v al, Tuple_t atl) => (al,  atl)
              | (_,           _         ) => ([a], [at])
          (* First, retrieve the types of the formal arguments.    *)
          val (fatl, frt)  = case fct of
                              Fn_t(Tuple_t fatl, frt) => (fatl, frt)
                            | _ => err "internal error, bad function type [06]"
          (*
             Are there too many, too few, or just enough args?
             Length(atl) = Length(al) & Length(fal) = Length(fatl)
             1 = set of matched args
             2 = supplementary or missing args
          *)
          val (fal1, fatl1, al1, atl1, fal2, fatl2, al2) =
            let
              val len = Int.min(List.length atl, List.length fal)
            in
              (List.take(fal, len), List.take(fatl, len),
               List.take(al,  len), List.take(atl,  len),
               List.drop(fal, len), List.drop(fatl, len),
               List.drop(al,  len))
            end

          (*
            Are there too many arguments? That would be bad...
          *)
          val _ = if List.length al2 > 0
                  then err "too many arguments provided in function call"
                  else ()
          (*
            Do types match for the available actual arguments?
            u = unified
          *)
          val utl = ListPair.map (fn (f, a) => unifyTypes(f, a)) (fatl1, atl1)
                    handle TypeUnification =>
                      err "argument types don't match in function call"
          (* Adds the matched formal/actual arguments, with their unified
             types, on top of environment 'base'. *)
          fun bindArgs base =
            ListPair.foldl (fn ((fa, av), ut, en') =>
                                insertBinding(fa, (av, ut), en'))
                           base
                           (ListPair.zip(fal1, al1), utl)
          (*
             The function's own name (present for 'fun' declarations) is
             bound first so that the body can call itself recursively; the
             arguments are inserted after it (presumably shadowing it on a
             name clash).
          *)
          val bodyEnv = bindArgs (case name of
                                    NONE    => en
                                  | SOME s  => insertBinding(s, (fc, fct), en))
         in
          (*
             Are there too few arguments?
          *)
          if List.length fal2 > 0
            then (* this is a curried function => return closure *)
             (* bodyEnv keeps the self binding, so a partially applied
                recursive function can still call itself. *)
             ( Fn_v(fal2, bodyEnv, body, NONE),
               Fn_t(Tuple_t fatl2, frt))
            else
              (case evaluation of
                EAGER =>
                  (* evaluate function, check returned type (r = returned) *)
                  let
                    val (rv, rt) = evaluate(body, bodyEnv)
                    val urt      = unifyTypes(frt, rt)
                                   handle TypeUnification =>
                                     err ("actual return type does not match " ^
                                          "declared type")
                  in
                    (rv, urt)
                  end
              | LAZY  =>
                 (* create thunk for function body and unevaluated arguments *)
                 (Thunk_v(body, bodyEnv), frt))
         end

      | _  => err "attempt to evaluate non-function"
    end
  )

  (*
     evaluateUnop(uop, (v, t)): applies a unary operator to an already
     evaluated (and forced) operand. 'ref' and '!' are not implemented.
  *)
  and evaluateUnop (uop: unop, (v, t): value * typ): value * typ =
  (
    if debug then
      print("(\n unary operator = "
           ^ (case uop of
                Neg   => "neg (~)"
              | Not   => "not"
              | Ref   => "ref"
              | Deref => "deref (!)")
           ^ "\n argument = \n"
           ^ printValue((v, t), 4)
           ^ ")\n\n")
    else ();

    case (uop, v) of
      (Neg, Int_v a)                      => (Int_v  (~a),   Int_t)
    | (Neg, Real_v a)                     => (Real_v (~a),   Real_t)
    | (Neg, _)                            => err "type error (~)"
    | (Not, Bool_v a)                     => (Bool_v (not a), Bool_t)
    | (Not, _)                            => err "type error (not)"
    | (Ref, _)                            => err "ref not implemented"
    | (Deref, _)                          => err "deref not implemented"
  )

  (*
     evaluateBinop(bop, (v1, t1), (v2, t2)): applies a binary operator to two
     already evaluated (and forced) operands. Type mismatches are reported
     through 'err'. Assignment is not implemented.
  *)
  and evaluateBinop (bop:binop, (v1,t1):value*typ,(v2,t2):value*typ):value*typ =
  (
    if debug then
      print("(\n binary operator = "
         ^ (case bop of
              Plus      => "plus (+)"
            | Minus     => "minus (-)"
            | Times     => "times (*)"
            | Mod       => "mod"
            | Div       => "div"
            | Slash     => "slash (/)"
            | Equal     => "equal (=)"
            | Less      => "less than (<)"
            | LessEq    => "less than or equal to (<=)"
            | Greater   => "greater than (>)"
            | GreaterEq => "greater than or equal to (>=)"
            | Cons      => "cons (::)"
            | Append    => "append (@)"
            | AndAlso   => "andalso"
            | OrElse    => "orelse"
            | Concat    => "concatenate (^)"
            | Assign    => "assign (:=)")
         ^ "\n argument1 =\n"
         ^ printValue((v1, t1), 4)
         ^ " argument2 =\n"
         ^ printValue((v2, t2), 4)
         ^ ")\n\n")
    else ();

    case (bop, v1, v2) of
      (Plus,  Int_v  a, Int_v  b)         => (Int_v  (a+b),     Int_t )
    | (Plus,  Real_v a, Real_v b)         => (Real_v (a+b),     Real_t)
    | (Plus,  _,        _       )         => err "type error (+)"
    | (Times, Int_v  a, Int_v b )         => (Int_v  (a*b),     Int_t)
    | (Times, Real_v a, Real_v b)         => (Real_v (a*b),     Real_t)
    | (Times, _,        _       )         => err "type error (*)"
    | (Minus, Int_v  a, Int_v b )         => (Int_v  (a-b),     Int_t)
    | (Minus, Real_v a, Real_v b)         => (Real_v (a-b),     Real_t)
    | (Minus, _,        _       )         => err "type error (-)"
      (* Zero divisors are reported as interpreter errors rather than
         escaping as the SML exception Div (shadowed here by the binop
         constructor of the same name). *)
    | (Mod,   Int_v  a, Int_v b )         => if b = 0
                                             then err "division by zero"
                                             else (Int_v (a mod b), Int_t)
    | (Mod,   _,        _       )         => err "type error (mod)"
    | (Div,   Int_v  a, Int_v b )         => if b = 0
                                             then err "division by zero"
                                             else (Int_v (a div b), Int_t)
    | (Div,   _,        _       )         => err "type error (div)"
    | (Slash, Real_v a, Real_v b)         => (Real_v (a/b),     Real_t)
    | (Slash, _,        _       )         => err "type error (/)"

    (* 'real' is not an equality type! *)
    | (Equal, Int_v  a, Int_v b)          => (Bool_v (a=b),     Bool_t)
    | (Equal, Bool_v a, Bool_v b)         => (Bool_v (a=b),     Bool_t)
    | (Equal, Char_v a, Char_v b)         => (Bool_v (a=b),     Bool_t)
    | (Equal, String_v a, String_v b)     => (Bool_v (a=b),     Bool_t)

      (* Tuples: same arity and componentwise equality, using the component
         types recorded in t1. *)
    | (Equal, Tuple_v aa, Tuple_v bb)     => (Bool_v (
        (length aa = length bb)
        andalso
        (case t1 of
           Tuple_t tlst  =>
             ListPair.all (fn((a,b), t) =>
                            case evaluateBinop (Equal, (a, t), (b, t)) of
                              (Bool_v true, Bool_t) => true
                            | _                     => false)
                          (ListPair.zip(aa, bb), tlst)
         | _           => err "Type error or internal error [03]")),
                                              Bool_t)

      (* Lists: same length and elementwise equality, using the element type
         recorded in t1. *)
    | (Equal, List_v aa, List_v bb)       => (Bool_v (
        (length aa = length bb)
        andalso
        (case t1 of
           List_t lt  =>
             ListPair.all (fn(a, b) =>
                            case evaluateBinop (Equal, (a, lt), (b, lt)) of
                              (Bool_v true, Bool_t) => true
                            | _                     => false)
                          (aa, bb)
         | _           => err "Type error or internal error [04]")),
                                              Bool_t)

    | (Equal, _, _)                       => err "type error (=)"

    | (Greater, Int_v  a, Int_v b)        => (Bool_v (a>b), Bool_t)
    | (Greater, Real_v a, Real_v b)       => (Bool_v (a>b), Bool_t)
    | (Greater, Char_v a, Char_v b)       => (Bool_v (a>b), Bool_t)
    | (Greater, String_v a, String_v b)   => (Bool_v (a>b), Bool_t)
    | (Greater, _, _)                     => err "type error (>)"
    | (Less, Int_v  a, Int_v b)           => (Bool_v (a<b), Bool_t)
    | (Less, Real_v a, Real_v b)          => (Bool_v (a<b), Bool_t)
    | (Less, Char_v a, Char_v b)          => (Bool_v (a<b), Bool_t)
    | (Less, String_v a, String_v b)      => (Bool_v (a<b), Bool_t)
    | (Less, _, _)                        => err "type error (<)"
    | (GreaterEq, Int_v  a, Int_v b)      => (Bool_v (a>=b), Bool_t)
    | (GreaterEq, Real_v a, Real_v b)     => (Bool_v (a>=b), Bool_t)
    | (GreaterEq, Char_v a, Char_v b)     => (Bool_v (a>=b), Bool_t)
    | (GreaterEq, String_v a, String_v b) => (Bool_v (a>=b), Bool_t)
    | (GreaterEq, _, _)                   => err "type error (>=)"
    | (LessEq, Int_v  a, Int_v b)         => (Bool_v (a<=b), Bool_t)
    | (LessEq, Real_v a, Real_v b)        => (Bool_v (a<=b), Bool_t)
    | (LessEq, Char_v a, Char_v b)        => (Bool_v (a<=b), Bool_t)
    | (LessEq, String_v a, String_v b)    => (Bool_v (a<=b), Bool_t)
    | (LessEq, _, _)                      =>  err "type error (<=)"

      (* An element/list type mismatch is reported through 'err' instead of
         escaping as TypeUnification. The handlers are parenthesized so they
         do not swallow the following match rules. *)
    | (Cons, _, List_v vs)                =>
        ((List_v(v1::vs), unifyTypes(List_t t1, t2))
         handle TypeUnification => err "type error (::)")
    | (Cons, _, _)                        => err "type error (::)"
    | (Append, List_v aa, List_v bb)      =>
        ((List_v(aa@bb), unifyTypes(t1, t2))
         handle TypeUnification => err "type error (@)")
    | (Append, _, _)                      => err "type error (@)"

    | (AndAlso, Bool_v a, Bool_v b)       => (Bool_v (a andalso b), Bool_t)
    | (AndAlso, _, _)                     => err "type error (andalso)"
    | (OrElse, Bool_v a, Bool_v b)        => (Bool_v (a orelse b),  Bool_t)
    | (OrElse, _, _)                      => err "type error (orelse)"

    | (Concat, String_v a, String_v b)    => (String_v (a^b),     String_t)
    | (Concat, _, _)                      => err "type error (^)"

    | (Assign, _, _)                      => err "assign not implemented"
  )

  (*
     evaluateDeclare(dlist, en): adds the declarations in 'dlist' to 'en',
     from left to right, so that each declaration sees the earlier ones.
     A 'val' is evaluated and checked against its declared type; a 'fun'
     becomes a closure that carries its own name (for recursion).
  *)
  and evaluateDeclare (dlist: decl list, en: env): env =
    let
      fun addOneDecl (d: decl, en: env): env =
        case d of
          Val_d (id, typ, ex)             =>
            let
              val (v, vt) = evaluate(ex, en)
            in
              insertBinding(id, (v, unifyTypes(vt, typ)), en)
              handle TypeUnification =>
                err "type of computed value does not match declared type"
            end
        | Fun_d ({name, args, ret_typ}, ex) =>
            let
              val (args_name, args_typ) = ListPair.unzip args
              val fcn_typ = Fn_t(Tuple_t(args_typ), ret_typ)
            in
              insertBinding (name,(
                             Fn_v(args_name, en, ex, SOME(name)), fcn_typ), en)
            end
    in
      foldl addOneDecl en dlist
    end

end
