structure CSEUtils :> sig

    val eligible: LambdaCSE.lexp_cse * LambdaCSE.lexp_cse -> bool

    val getNode:  LambdaCSE.lexp_cse_node * int -> LambdaCSE.lexp_cse option

    val changeIthChild:  (LambdaCSE.lexp_cse -> LambdaCSE.lexp_cse) -> int ->
                         LambdaCSE.lexp_cse_node -> LambdaCSE.lexp_cse_node

end = struct
    open Lambda
    open LambdaCSE

    (* Eligibility of sNode with respect to tNode.
       Intended contract: every variable bound in sNode is bound WITHIN
       sNode, and every variable free in sNode is also free in tNode.

       Concretely: gather every Var_c node in sNode's subtree and require
       that each one, with (depth - index) taken from its annotation and
       payload, is either > sNode's depth or <= tNode's depth.
       NOTE(review): (depth - index) is presumably the depth of the
       variable's binder (de Bruijn style) -- confirm against LambdaCSE.

       Raises Fail if a SetIth_c node occurs anywhere inside sNode. *)
    fun eligible(sNode as CSENode(_,(s_depth,_,_)):lexp_cse,
                 CSENode(_,(t_depth,_,_)):lexp_cse):bool =
      let
        (* Prepend to acc every Var_c node found in the subtree of node. *)
        fun gather(node as CSENode(body,_):lexp_cse,
                   acc: lexp_cse list):lexp_cse list =
          let
            (* two children: the second is gathered before the first *)
            fun pair(a, b) = gather(a, gather(b, acc))
          in
            case body of
              Var_c _            => node::acc
            | Int_c _            => acc
            | Real_c _           => acc
            | Error_c _          => acc
            | Fn_c inner         => gather(inner, acc)
            | Unop_c(_,inner)    => gather(inner, acc)
            | App_c(a,b)         => pair(a, b)
            | Binop_c(a,_,b)     => pair(a, b)
            | Ith_c(a,b)         => pair(a, b)
            | Letrec_c(a,b)      => pair(a, b)
            | If_c(a,b,c)        => gather(a, pair(b, c))
            | Tuple_c items      => foldl gather acc items
            | SetIth_c _         => raise Fail "eligible: shouldn't see setith"
          end

        (* Per-variable check; only ever applied to the Var_c nodes
           produced by gather, so the fallback clause is unreachable. *)
        fun acceptable(CSENode(Var_c idx,(depth,_,_)):lexp_cse):bool =
              let
                val binder = depth - idx
              in
                binder > s_depth orelse binder <= t_depth
              end
          | acceptable(_) = raise Fail("test:impossible")
      in
        List.all acceptable (gather(sNode, []))
      end


    (* Select the x-th child (numbered from 1) of node e.
       Returns SOME child when e has a child at that position and NONE
       otherwise -- in particular for every x on the forms handled by the
       final wildcard case (Int, Real, Var, Error). *)
    fun getNode (e: lexp_cse_node,
                 x: int):lexp_cse option =
      let
        (* selectors for nodes with one, two and three children *)
        fun unary a = if x = 1 then SOME a else NONE
        fun binary(a, b) =
          case x of 1 => SOME a | 2 => SOME b | _ => NONE
        fun ternary(a, b, c) =
          case x of 1 => SOME a | 2 => SOME b | 3 => SOME c | _ => NONE
      in
        case e of
          Fn_c body            => unary body
        | Unop_c(_, operand)   => unary operand
        | App_c(a, b)          => binary(a, b)
        | Binop_c(a, _, b)     => binary(a, b)
        | Ith_c(a, b)          => binary(a, b)
        | Letrec_c(a, b)       => binary(a, b)
        | SetIth_c(a, b, c)    => ternary(a, b, c)
        | If_c(a, b, c)        => ternary(a, b, c)
          (* List.nth is 0-based; an out-of-range index becomes NONE *)
        | Tuple_c elems        =>
            (SOME(List.nth(elems, x - 1)) handle Subscript => NONE)
        | _                    => NONE (* no children to descend into *)
      end


    (* changeIthChild f x e
       Rebuild node e with its x-th child (numbered from 1) replaced by
       f applied to that child; all other children are kept as they are.
       f is applied exactly once, and only when position x exists.

       Raises Fail when e has no child at position x.  This covers the
       childless forms (Int, Real, Var, Error) and, for Tuple_c, any x
       below 1 or above the number of elements. *)
    fun changeIthChild(f:lexp_cse->lexp_cse)(x:int)(e:lexp_cse_node) =
      let
        fun noSuchChild () = raise Fail("WTF")

        (* Replace the element at 1-based position i of a list with f of
           that element; running off the end means i was too large. *)
        fun replaceAt(_, []) = noSuchChild ()
          | replaceAt(1, y::ys) = f y :: ys
          | replaceAt(i, y::ys) = y :: replaceAt(i - 1, ys)
      in
        case e of
          Fn_c(e1) => if x = 1 then Fn_c(f e1) else noSuchChild ()
        | App_c(e1, e2) => (case x of
                              1 => App_c(f e1, e2)
                            | 2 => App_c(e1, f e2)
                            | _ => noSuchChild ())
        | Unop_c(u, e1) => if x = 1 then Unop_c(u, f e1) else noSuchChild ()
        | Binop_c(e1, b, e2) => (case x of
                                   1 => Binop_c(f e1, b, e2)
                                 | 2 => Binop_c(e1, b, f e2)
                                 | _ => noSuchChild ())
          (* Positions below 1 are rejected up front: previously they
             slipped past the bounds check and returned the tuple
             unchanged without ever calling f. *)
        | Tuple_c(es) => if x < 1 then noSuchChild ()
                         else Tuple_c(replaceAt(x, es))
        | Ith_c(e1, e2) => (case x of
                              1 => Ith_c(f e1, e2)
                            | 2 => Ith_c(e1, f e2)
                            | _ => noSuchChild ())
        | SetIth_c(e1, e2, e3) => (case x of
                                     1 => SetIth_c(f e1, e2, e3)
                                   | 2 => SetIth_c(e1, f e2, e3)
                                   | 3 => SetIth_c(e1, e2, f e3)
                                   | _ => noSuchChild ())
        | If_c(e1, e2, e3) => (case x of
                                 1 => If_c(f e1, e2, e3)
                               | 2 => If_c(e1, f e2, e3)
                               | 3 => If_c(e1, e2, f e3)
                               | _ => noSuchChild ())
        | Letrec_c(e1, e2) => (case x of
                                 1 => Letrec_c(f e1, e2)
                               | 2 => Letrec_c(e1, f e2)
                               | _ => noSuchChild ())
        | _ => noSuchChild () (* can't go deeper on Int, Real, Var, Error *)
      end

  end
