(* Constant folding over the Lambda intermediate language.

   [const_fold] repeatedly runs a simplification pass ([opt]) until a whole
   pass performs no fold.  Only [const_fold] is exported; the opaque
   signature hides every helper below. *)
structure ConstFold :> sig

    val const_fold: Lambda.lexp -> Lambda.lexp

end = struct
    open Lambda

    (* Set to true by every successful fold.  const_fold resets it and keeps
       re-running [opt] until a full pass leaves it false.  This is a single
       module-level ref shared by all calls to const_fold. *)
    val changed = ref true

    (* Record that a fold happened and return the folded expression.
       Callers must evaluate the folded value BEFORE calling this (SML is
       call-by-value, so the argument is computed first); if that evaluation
       raises, [changed] stays untouched. *)
    fun folded (v : lexp) : lexp = (changed := true; v)

    (* Fold a unary operator applied to a constant operand.
       Integer negation can raise Overflow (e.g. on the most negative int);
       in that case the expression is left unfolded instead of aborting the
       compiler, and [changed] is not set, so the fixpoint loop terminates. *)
    fun do_unop(u:lunop,c:lexp):lexp =
      case (u,c) of
        (Neg_l, Int_l i) =>
          (folded (Int_l (~i)) handle Overflow => Unop_l(u,c))
      | (Neg_l, Real_l r) => folded (Real_l (~r))
      (* Not_l treats 0 as false and every other integer as true. *)
      | (Not_l, Int_l i) => folded (Int_l (if i = 0 then 1 else 0))
      | _ => Unop_l(u,c)

    (* Fold a binary operator whose operands are already simplified.
       Comparisons yield Int_l 1 for true and Int_l 0 for false.
       Integer arithmetic that raises Overflow is left unfolded (and does not
       set [changed]) rather than crashing the compiler. *)
    fun do_binop(c1:lexp,b:lbinop,c2:lexp):lexp =
      let
        val unfolded = Binop_l(c1,b,c2)
        fun truth t = folded (if t then Int_l 1 else Int_l 0)
        fun int_result compute =
          folded (Int_l (compute ())) handle Overflow => unfolded
      in
        case (c1,b,c2) of
          (Int_l i, Plus_l, Int_l j) => int_result (fn () => i + j)
        | (Real_l i, Plus_l, Real_l j) => folded (Real_l (i + j))
        | (Int_l i, Times_l, Int_l j) => int_result (fn () => i * j)
        | (Real_l i, Times_l, Real_l j) => folded (Real_l (i * j))
        | (Int_l i, Minus_l, Int_l j) => int_result (fn () => i - j)
        | (Real_l i, Minus_l, Real_l j) => folded (Real_l (i - j))
        | (Int_l i, GreaterThan_l, Int_l j) => truth (i > j)
        | (Real_l i, GreaterThan_l, Real_l j) => truth (i > j)
        | (Int_l i, LessThan_l, Int_l j) => truth (i < j)
        | (Real_l i, LessThan_l, Real_l j) => truth (i < j)
        | (Int_l i, Equal_l, Int_l j) => truth (i = j)
        (* Concatenation of two literal tuples.  The first two elements of
           the right operand are dropped; NOTE(review): presumably a header
           of the string representation -- confirm against the code that
           builds these tuples. *)
        | (Tuple_l strA, Concat_l, Tuple_l (_ :: _ :: strB)) =>
            folded (Tuple_l (strA @ strB))
        | _ => unfolded
      end

    (* One bottom-up simplification pass: folds constant subexpressions,
       applies a few algebraic identities, and resolves conditionals whose
       test is a known integer. *)
    fun opt(e:lexp):lexp =
      case e of
        Int_l _ => e
      | Real_l _ => e
      | Var_l _ => e
      | Fn_l body => Fn_l (opt body)
      | App_l(e1,e2) => App_l(opt e1, opt e2)
      (* NOTE(review): pointer-of a non-variable operand is replaced by
         Int_l 0 and the operand is discarded; confirm this is intended. *)
      | Unop_l(Ptr_l, Var_l _) => e
      | Unop_l(Ptr_l, _) => Int_l 0
      | Unop_l(u, c as Int_l _) => do_unop(u,c)
      | Unop_l(u, c as Real_l _) => do_unop(u,c)
      | Unop_l(u,e1) =>
          (* Simplify the operand; if it became a constant, try folding. *)
          (case opt e1 of
             x as Int_l _ => opt(Unop_l(u,x))
           | x as Real_l _ => opt(Unop_l(u,x))
           | x => Unop_l(u,x))
      | Binop_l(c1 as Real_l _, b, c2 as Real_l _) => do_binop(c1,b,c2)
      | Binop_l(c1 as Int_l _, b, c2 as Int_l _) => do_binop(c1,b,c2)
      (* Simplify both operands first so constants nested inside a
         concatenation are folded too, then try to fold the concatenation. *)
      | Binop_l(e1,Concat_l,e2) => do_binop(opt e1, Concat_l, opt e2)
      | Binop_l(e1,b,e2) =>
          let val (x,y) = (opt e1, opt e2) in
            case (x,b,y) of
              (Int_l _, _, Int_l _) => do_binop(x,b,y)
            | (Real_l _, _, Real_l _) => do_binop(x,b,y)
            (* e * 1 = e, 1 * e = e *)
            | (Int_l 1, Times_l, _) => y
            | (_, Times_l, Int_l 1) => x
            (* e + 0 = e, 0 + e = e *)
            | (Int_l 0, Plus_l, _) => y
            | (_, Plus_l, Int_l 0) => x
            (* e - 0 = e, 0 - e = ~e *)
            | (_, Minus_l, Int_l 0) => x
            | (Int_l 0, Minus_l, _) => Unop_l(Neg_l, y)
            | _ => Binop_l(x,b,y)
          end
      | Tuple_l es => Tuple_l (map opt es)
      | Ith_l(e1,e2) => Ith_l(opt e1, opt e2)
      | SetIth_l(e1,e2,e3) => SetIth_l(opt e1, opt e2, opt e3)
      (* Known test: keep only the taken branch.  Non-zero counts as true,
         matching how do_unop folds Not_l. *)
      | If_l(Int_l i, e1, e2) =>
          (changed := true; if i <> 0 then opt e1 else opt e2)
      | If_l(e1,e2,e3) =>
          (* If the test simplifies to a constant, resolve the branch. *)
          (case opt e1 of
             t as Int_l _ => opt(If_l(t,e2,e3))
           | t => If_l(t, opt e2, opt e3))
      | Letrec_l(e1,e2) => Letrec_l(opt e1, opt e2)
      | Error_l _ => e

    (* Entry point: print a progress message, then run [opt] to a fixpoint,
       i.e. until a full pass performs no fold. *)
    fun const_fold(e:lexp):lexp =
      let
        fun loop(cur:lexp):lexp =
          if !changed then (changed := false; loop (opt cur)) else cur
      in
        changed := false;
        print("Constant-Folding...\n");
        loop (opt e)
      end
end
