(* Raised by the interpreters in this file for run-time errors: type errors
   (e.g. applying a non-function, dereferencing a non-location), unbound or
   free variables, and missing store locations.  interpret_by_value and
   interpret_by_need both catch it at top level, print the message and
   return [ast_error]. *)
exception RuntimeError of string

open Ast
open Translate

(* Substitute e for v in e'.

   Intended for an alpha-varied e' (every binder distinct), in which case
   no binder inside e' coincides with v.  As a safety net, a binder that
   does equal v (by the same [eqvar] test used for variable occurrences)
   shadows it: the expression under that binder is left untouched, which
   is the standard substitution rule.  Capture of free variables of e is
   NOT prevented; that still relies on e' being alpha-varied.
   Constructors without subexpressions (numbers, booleans, locations, ...)
   are returned unchanged. *)
let rec subst e v e' =
  match e' with
  | Var v' -> if eqvar v v' then e else e'
  | Op (op, e1, e2) -> Op (op, subst e v e1, subst e v e2)
  | Zerop x -> Zerop (subst e v x)
  | If (e0, e1, e2) -> If (subst e v e0, subst e v e1, subst e v e2)
  | Left x -> Left (subst e v x)
  | Right x -> Right (subst e v x)
  | Pair (e1, e2) -> Pair (subst e v e1, subst e v e2)
  | Fn (v', body) ->
      (* A parameter named v shadows it inside the body. *)
      if eqvar v v' then e' else Fn (v', subst e v body)
  | App (e1, e2) -> App (subst e v e1, subst e v e2)
  | Let ((v0, e0), body) ->
      (* Let is non-recursive (interpret_by_value treats it as
         App (Fn (v0, body), e0)), so e0 is outside v0's scope and is always
         substituted; only the body can be shadowed. *)
      let e0' = subst e v e0 in
      let body' = if eqvar v v0 then body else subst e v body in
      Let ((v0, e0'), body')
  | Ref x -> Ref (subst e v x)
  | Deref x -> Deref (subst e v x)
  | Assign (e1, e2) -> Assign (subst e v e1, subst e v e2)
  | Seq (e1, e2) -> Seq (subst e v e1, subst e v e2)
  | _ -> e'

(* Interpret as call-by-value.

   [interpret_by_value e] evaluates [e] eagerly: operands and function
   arguments are reduced to values before use, and an application
   substitutes the argument value for the parameter in an alpha-varied
   copy of the function body.  [Let] is evaluated as an immediate
   application.

   Mutable references are kept in a store represented as a chain of
   closures from location numbers to values: every [Ref] allocates a fresh
   location and every [Assign] wraps the previous store.

   Any RuntimeError is caught here.  Where the failing construct is known
   it is printed first; then the message is printed and [ast_error] is
   returned. *)
let interpret_by_value e =
  let store : (int -> expr) ref =
    ref (fun loc ->
        raise (RuntimeError ("cannot find variable " ^ (string_of_int loc) ^ " in store")))
  in
  let counter = ref 0 in
  let fresh_loc () = (incr counter; !counter) in

  (* Point location [l] at value [v]; older bindings remain underneath. *)
  let write l v =
    let older = !store in
    store := (fun loc -> if loc = l then v else older loc)
  in

  (* Show the expression that went wrong, then abort with [msg]. *)
  let fail shown msg =
    Pprint.pp_expr shown;
    print_newline ();
    raise (RuntimeError msg)
  in

  let prim op a b =
    match a, b with
    | Num x, Num y ->
        (match op with
         | Plus -> Num (x + y)
         | Minus -> Num (x - y)
         | Times -> Num (x * y)
         | Eq -> if x = y then True else False
         | Lt -> if x < y then True else False)
    | _ -> raise (RuntimeError "prim applied to non-number")
  in

  let rec eval expr =
    match expr with
    | Num _ | True | False | Loc _ | Fn _ -> expr
    | Var _ -> raise (RuntimeError "free variable")
    | Op (op, lhs, rhs) ->
        let lv = eval lhs in
        let rv = eval rhs in
        (try prim op lv rv
         with RuntimeError msg -> fail (Op (op, lv, rv)) msg)
    | Zerop arg ->
        (match eval arg with
         | Num 0 -> True
         | _ -> False)
    | If (cond, then_, else_) ->
        let c = eval cond in
        (match c with
         | True -> eval then_
         | False -> eval else_
         | _ -> fail (If (c, then_, else_)) "cond of if must be a bool")
    | Left arg ->
        let v = eval arg in
        (match v with
         | Pair (first, _) -> first
         | _ -> fail (Left v) "left applied to non-pair")
    | Right arg ->
        let v = eval arg in
        (match v with
         | Pair (_, second) -> second
         | _ -> fail (Right v) "right applied to non-pair")
    | Pair (fst_e, snd_e) ->
        let a = eval fst_e in
        let b = eval snd_e in
        Pair (a, b)
    | App (fn_e, arg_e) ->
        let fv = eval fn_e in
        let av = eval arg_e in
        (match fv with
         | Fn (param, body) -> eval (subst av param (alpha_vary body))
         | _ -> fail (App (fv, av)) "application to non-function")
    | Let ((name, bound), body) -> eval (App (Fn (name, body), bound))
    | Ref arg ->
        let v = eval arg in
        let l = fresh_loc () in
        write l v;
        Loc l
    | Deref arg ->
        let v = eval arg in
        (match v with
         | Loc l -> !store l
         | _ -> fail (Deref v) "deref of non-loc")
    | Assign (target, rhs) ->
        let t = eval target in
        let v = eval rhs in
        (match t with
         | Loc l -> write l v; v
         | _ -> fail (Assign (t, v)) "assign to non-loc")
    | Seq (first, rest) ->
        ignore (eval first);
        eval rest
  in

  try eval e
  with RuntimeError msg -> (print_string (msg ^ "\n"); ast_error)

(* Interpret as call-by-need.

   [interpret_by_need e] alpha-varies and normalizes [e]
   (Translate.normalize), then evaluates it lazily:

   - [Let] pushes an unevaluated (variable, expression) binding onto an
     association-list environment instead of evaluating the expression.
   - The first lookup of a variable evaluates its bound expression to a
     value and pushes (variable, value) onto the front of the environment.
     List.assoc returns the first match, so later lookups reuse that value.
     The environment is threaded through evaluation and returned with
     every result so these memoized bindings are not lost.
   - Application requires the argument to be a variable or a constant
     (anything else raises "application not normalized").  The argument is
     substituted unevaluated into an alpha-varied copy of the body.
   - [to_value] controls how far a result is evaluated: when false a
     [Pair] is returned without evaluating its components, and when true
     the components are evaluated as well.

   Mutable references live in a store represented as a chain of closures
   from location numbers to values.

   A RuntimeError aborts evaluation.  Where the failing construct is known
   it is printed first; then the message is printed and [ast_error] is
   returned. *)
let interpret_by_need e =
  let empty : int -> expr =
    fun loc ->
      raise (RuntimeError ("cannot find variable " ^ (string_of_int loc) ^ " in store"))
  in
  let store = ref empty in
  let loc = ref 0 in
  let newloc () = (loc := !loc + 1; !loc) in

  (* Make location [l] hold [v] by wrapping the current store. *)
  let update l v =
    let s = !store in
    store := (fun loc' -> if loc' = l then v else s loc')
  in

  (* Print the construct being evaluated, then abort with [msg]. *)
  let fail shown msg =
    Pprint.pp_expr shown;
    print_newline ();
    raise (RuntimeError msg)
  in

  let prim op e1 e2 =
    match (e1, e2) with
    | (Num n1, Num n2) ->
        (match op with
         | Plus -> Num (n1 + n2)
         | Minus -> Num (n1 - n2)
         | Times -> Num (n1 * n2)
         | Eq -> if n1 = n2 then True else False
         | Lt -> if n1 < n2 then True else False)
    | _ -> raise (RuntimeError "prim applied to non-number")
  in

  (* [interp e env to_value] evaluates [e] and returns the result together
     with the (possibly extended) environment. *)
  let rec interp e env to_value =
    match e with
    | Var (s, i) ->
        (* Guard only the lookup itself.  A Not_found escaping from the
           evaluation of the bound expression is not evidence that (s, i)
           is unbound, and must not be reported as such. *)
        let bound =
          try List.assoc (s, i) env
          with Not_found -> raise (RuntimeError (s ^ (string_of_int i) ^ " not found"))
        in
        let (z, env') = interp bound env true in
        (* Memoize: this binding now shadows the unevaluated one.  The
           returned copy is alpha-varied again, presumably so that each use
           of the value gets fresh binder names. *)
        (alpha_vary z, ((s, i), z) :: env')
    | Op (op, e1, e2) ->
        let (e1', env') = interp e1 env false in
        let (e2', env'') = interp e2 env' false in
        (try (prim op e1' e2', env'')
         with RuntimeError s -> fail e s)
    | Zerop expr ->
        let (expr', env') = interp expr env false in
        ((match expr' with Num 0 -> True | _ -> False), env')
    | If (e0, e1, e2) ->
        let (e0', env') = interp e0 env false in
        (match e0' with
         | True -> interp e1 env' to_value
         | False -> interp e2 env' to_value
         | _ -> fail e "cond of if must be a bool")
    | Left expr ->
        let (expr', env') = interp expr env false in
        (match expr' with
         | Pair (x, _) -> interp x env' to_value
         | _ -> fail e "left applied to non-pair")
    | Right expr ->
        let (expr', env') = interp expr env false in
        (match expr' with
         | Pair (_, y) -> interp y env' to_value
         | _ -> fail e "right applied to non-pair")
    | Pair (e1, e2) ->
        if to_value then begin
          let (e1', env') = interp e1 env true in
          let (e2', env'') = interp e2 env' true in
          (Pair (e1', e2'), env'')
        end else
          (* Lazy pair: components are left unevaluated here. *)
          (e, env)
    | App (e1, e2) ->
        (match e2 with
         | Var _ | Num _ | True | False ->
             let (e1', env') = interp e1 env false in
             (match e1' with
              | Fn (v, body) -> interp (subst e2 v (alpha_vary body)) env' to_value
              | _ -> fail e "application to non-function")
         | _ -> fail e "application not normalized")
    | Let ((v, expr), body) ->
        (* Bind lazily: [expr] is evaluated on the first lookup of [v]. *)
        interp body ((v, expr) :: env) to_value
    | Ref expr ->
        let (e', env') = interp expr env true in
        let l = newloc () in
        update l e';
        (Loc l, env')
    | Deref expr ->
        let (e', env') = interp expr env false in
        (match e' with
         | Loc l -> (!store l, env')
         | _ -> fail e "deref of non-loc")
    | Assign (e1, e2) ->
        let (e1', env') = interp e1 env false in
        let (e2', env'') = interp e2 env' true in
        (match e1' with
         | Loc l -> update l e2'; (e2', env'')
         | _ -> fail e "assign to non-loc")
    | Seq (e1, e2) ->
        let (_, env') = interp e1 env false in
        interp e2 env' to_value
    | _ -> (e, env)
  in

  try
    let e' = Translate.normalize (alpha_vary e) in
    let (v, _) = interp e' [] false in
    v
  with RuntimeError s -> (print_string (s ^ "\n"); ast_error)
