structure CSEReplace :> sig

    val replace: Lambda.lexp -> Lambda.lexp

end = struct
    open Lambda
    open LambdaCSE
    open CSEUtils

    (* Flag shared between replace' and replace: replace' sets it to true
       whenever it performs a substitution, and replace clears it before
       each pass and keeps iterating while it comes back true. The initial
       value is overwritten by replace before it is ever read. *)
    val replaced = ref true

    (* eq_exp(x, y): are two annotated lambda expressions structurally
       equivalent and safe to substitute for each other?

       The first component of each node's annotation is its depth. The
       depths of the two roots (xDepth, yDepth) stay fixed for the whole
       comparison, so every variable below them is classified relative to
       the roots being compared:
         - if the variable's index is smaller than its distance from the
           root, it is bound inside the expression and the two indices
           must be identical;
         - otherwise it is free in the expression and the two occurrences
           must agree once adjusted by their roots' depths.

       Returns false for anything with a side effect (pointer operations,
       concatenation, SetIth) and for any pair of differently shaped
       nodes. *)
    fun eq_exp(xNode as CSENode(_, (xDepth, _, _)):lexp_cse,
               yNode as CSENode(_, (yDepth, _, _)):lexp_cse):bool =
      let
        (* only the pure unary operators may be shared *)
        fun eq_unop(u1:lunop, u2:lunop):bool =
          case (u1,u2) of
            (Neg_l, Neg_l) => true
          | (Not_l, Not_l) => true
          | _ => false (* Ptr can be used to change memory *)

        (* only the pure binary operators may be shared *)
        fun eq_binop(b1:lbinop, b2:lbinop):bool =
          case (b1,b2) of
            (Plus_l,Plus_l) => true
          | (Times_l,Times_l) => true
          | (Minus_l,Minus_l) => true
          | (Equal_l,Equal_l) => true
          | (GreaterThan_l,GreaterThan_l) => true
          | (LessThan_l,LessThan_l) => true
          | _ => false (* concat changes memory *)

        (* structural comparison of two sub-nodes; always relative to the
           root depths captured above *)
        fun eql(CSENode(x,(d_x,_,_)):lexp_cse,
                CSENode(y,(d_y,_,_)):lexp_cse):bool =
          case (x,y) of
            (Var_c i, Var_c j) =>
              let
                (* bound inside the compared expression? *)
                val xNotFree = (d_x - xDepth) > i
                val yNotFree = (d_y - yDepth) > j
                val same = (xNotFree = yNotFree)
              in
                same andalso (if xNotFree then i = j else (xDepth-i)=(yDepth-j))
              end
          | (Int_c i,Int_c j) => i = j
          | (Real_c r, Real_c s) => Real.==(r,s)
          | (Fn_c e1, Fn_c e2) => eql(e1,e2)
          | (App_c (e1,e2), App_c (e3,e4)) => eql(e2,e4) andalso eql(e1,e3)
          | (Unop_c(u1,e1),Unop_c(u2,e2)) => eq_unop(u1,u2) andalso eql(e1,e2)
          | (Binop_c(e1,b1,e2), Binop_c(e3,b2,e4)) =>
              eq_binop(b1, b2) andalso eql(e1,e3) andalso eql(e2,e4)
          | (Tuple_c(es1),Tuple_c(es2)) =>
              (length es1=length es2) andalso ListPair.all eql (es1,es2)
          (* Both children go through eql, not eq_exp: restarting eq_exp on
             a child would re-root xDepth/yDepth at that child and classify
             the variables below it against the wrong root. *)
          | (Ith_c(e1,e2),Ith_c(e3,e4)) => eql(e1,e3) andalso eql(e2,e4)
          (* SetIth is ONLY created from ref assignment in compile.sml
             (ref 4) != (ref 4), since the memory allocated for each tuple
             cannot be proven equal. More complex analysis (aliasing) can
             take care of this, but for now we disallow entirely. *)
          | (SetIth_c _, SetIth_c _) => false
          | (If_c(e1,e2,e3),If_c(e4,e5,e6)) =>
              eql(e1,e4) andalso eql(e2,e5) andalso eql(e3,e6)
          | (Letrec_c(e1,e2),Letrec_c(e3,e4)) => eql(e1,e3) andalso eql(e2,e4)
          | (Error_c(s1), Error_c(s2)) => s1 = s2
          | _ => false
      in
        eql(xNode, yNode)
      end

    (* findParent(node, prefix): walk down from node along prefix and stop
       one step before its end.

       Result:
         SOME (node, ~1)  when prefix is empty -- the root has no parent,
                          so the root itself is returned with the marker ~1;
         SOME (p, i)      when p is the node reached by all but the last
                          step of prefix and i is that last step, i.e. the
                          child position below p;
         NONE             when getNode fails on one of the intermediate
                          steps (invalid prefix). *)
    fun findParent (here as CSENode _ : lexp_cse,
                    [] : int list): (lexp_cse * int) option =
          SOME (here, ~1)
        (* a single step left: the current node is the parent *)
      | findParent (here as CSENode _, [last]) =
          SOME (here, last)
        (* more than one step left: descend into the indicated child *)
      | findParent (CSENode(contents, _), step::remaining) =
          (case getNode(contents, step) of
               SOME child => findParent(child, remaining)
             | NONE       => NONE)

    (* Does the actual substitution: the shared expression is pulled out
       into an application of the form (fn => body) term, and every
       occurrence listed in eList is replaced inside body by a variable
       referring to that new binder.

       eList: list of things to substitute for, all are eligible and
       structurally equivalent in the context of the best_subtree.
       NOTE: the code below requires at least TWO elements and raises
       Fail "runSubst: input list too small" otherwise; only the first
       element is used as the term, the others contribute their prefixes.

       root: root node of CSE tree (this is for crawling down stuff)
       best_pfx: best-prefix for the subtree in which all CSE nodes
       in eList can be substituted.

       Raises Fail "runSubst: invalid prefix" when getNode cannot find the
       child named by the last step of best_pfx. *)
    fun runSubst(eList:lexp_cse list,root:lexp_cse,best_pfx:int list):lexp_cse =
      let

        (* representative occurrence; becomes the argument of the new
           application *)
        val first = case eList of
                       x::y::xs => x
                     | _ => raise Fail "runSubst: input list too small"
        (* Rebuilds a tree, keeping every annotation and rewriting only the
           variables. Function f is run on all variables, inputs are:
           (deBruijn index * variable depth), output is new
           deBruijn index *)
        fun convertVars (f: int * int -> int) (CSENode(n, (d,h,p))):lexp_cse =
          let val conv = convertVars f
          in case n of
            Int_c(i)             => CSENode(n, (d,h,p))
          | Real_c(r)            => CSENode(n, (d,h,p))
          | Fn_c(e1)             => CSENode(Fn_c(conv e1), (d,h,p))
          | App_c(e1, e2)        => CSENode(App_c(conv e1,conv e2),(d,h,p))
          | Unop_c(u, e1)        => CSENode(Unop_c(u, conv e1),(d,h,p))
          | Binop_c(e1, b, e2)   => CSENode(Binop_c(conv e1,b,conv e2),(d,h,p))
          | Tuple_c(es)          => CSENode(Tuple_c(map conv es), (d,h,p))
          | Ith_c(e1, e2)        => CSENode(Ith_c(conv e1, conv e2), (d,h,p))
          | SetIth_c(e1, e2, e3) =>
              CSENode(SetIth_c(conv e1,conv e2,conv e3),(d,h,p))
          | If_c(e1, e2, e3)     =>
              CSENode(If_c(conv e1,conv e2, conv e3),(d,h,p))
          | Letrec_c(e1, e2)     => CSENode(Letrec_c(conv e1,conv e2),(d,h,p))
          | Error_c(s)           => CSENode(n, (d,h,p))
          | Var_c(i)             => CSENode(Var_c(f(i,d)), (d,h,p))
          end

        (* Follows prefix down from cNode (one changeIthChild per step) and
           replaces the node found at its end by a variable whose index is
           that node's depth minus d; the new node is annotated
           (depth, 0, [0]).
           d = initial depth *)
        fun replacePrefix(d:int)(prefix:int list,
                                 cNode as CSENode(curr,data)):lexp_cse =
          case prefix of
              nil => CSENode(Var_c(#1(data) - d), (#1(data), 0, [0]))
            | x::xs => let
                val new = (changeIthChild (fn e=>replacePrefix d (xs,e)) x curr)
              in CSENode(new,data) end

        (* returns the given node with child x swapped for replNode; the
           node's own annotation is kept *)
        fun replChild(CSENode(e,data),x:int,replNode:lexp_cse):lexp_cse =
          CSENode((changeIthChild (fn e => replNode) x e),data)

        (* replacement function
           parentNode is the node before the best sub-tree
           (the one we keep the same). parentNode = NONE iff best prefix is []
           and that means we are doing substitution at root level.
           loc is the child position of the best sub-tree below parentNode
           (ignored in the NONE case). *)
        fun doSubst(parentNode:lexp_cse option, loc:int):lexp_cse =
          case parentNode of
              NONE =>
                  (* need to do substitution and changing variables
                     because of possible free environment variables *)
                  let
                    val substTerm as CSENode(_,(d,_,_)) = first
                    (* term: variables with (depth - index) <= 0 are
                       lowered by the term's own depth d *)
                    val convTermVars = (fn(v_i,v_d)=> if (v_d-v_i)<=0 then v_i-d
                                                      else v_i)
                    (* body: the same class of variables is raised by one
                       to make room for the new binder *)
                    val convBodyVars = (fn(v_i,v_d)=>(if (v_d-v_i)<=0
                                       then v_i + 1 else v_i))
                    val substTerm = convertVars convTermVars substTerm
                    val newBody = convertVars convBodyVars root
                    val prefixes = map (fn(CSENode(_,(_,_,p))) => p) eList
                    (* replace every occurrence, one prefix at a time *)
                    val newBody = foldl (replacePrefix 0) newBody prefixes
                  in
                    CSENode(App_c(CSENode(Fn_c(newBody),(0,0,[0])), substTerm),
                            (0,0,[0]))
                  end
            | SOME(pNode as CSENode(n, (d,h,p))) =>
                  let
                    (* root of the best sub-tree: child loc of the parent *)
                    val bodyNode as CSENode(n,(bodyDepth,_,_)) =
                            (case getNode(n, loc) of
                                NONE => raise Fail "runSubst: invalid prefix"
                              | SOME x => x)
                    val unFixedSubTerm as CSENode(_,(termDepth,_,_)) = first
                    (* term: variables with (depth - index) <= bodyDepth
                       are lowered by the distance between the term and the
                       body root *)
                    val convTermVars =
                      (fn(v_i,v_d)=>(if (v_d-v_i)<=bodyDepth
                                       then v_i-(termDepth-bodyDepth)
                                     else v_i))
                    (* body: the same class of variables is raised by one
                       to make room for the new binder *)
                    val convBodyVars =
                      (fn(v_i,v_d)=>(if (v_d-v_i)<=bodyDepth
                                       then v_i + 1
                                     else v_i))
                    val substTerm = convertVars convTermVars unFixedSubTerm
                    val newBody = convertVars convBodyVars bodyNode

                    val prefixes = map (fn(CSENode(_,(_,_,p))) => p) eList
                    (* the stored prefixes start at the tree root; drop the
                       first length(best_pfx) steps so that they start at
                       the body *)
                    val pLength = List.length(best_pfx)
                    val restOfpath = map (fn(pList)=>List.drop(pList, pLength))
                                         prefixes
                    val repBody = foldl (replacePrefix bodyDepth)
                                        newBody restOfpath
                  in
                    (* NOTE(review): the new application node carries the
                       parent's (d,h,p) annotation rather than the body's;
                       presumably harmless if annotations are rebuilt before
                       the next pass -- confirm *)
                    replChild(pNode, loc,
                              CSENode(App_c(CSENode(Fn_c(repBody),(0,0,[0])),
                                            substTerm),
                            (d,h,p)))
                  end

        (* crawling down the tree along prefix to find the exact
           place of substitution; nodes on the way keep their annotation *)
        fun findParentAndSubst(cNode as CSENode(currNode, data):lexp_cse,
                               prefix:int list):lexp_cse =
          case prefix of
              nil => doSubst(NONE, 0)
              (* found parent *)
            | x::nil => doSubst(SOME cNode, x)
            (* finding parent *)
            | x::xs =>
                CSENode((changeIthChild
                         (fn e=>findParentAndSubst(e,xs)) x currNode),
                        data)
      in
        findParentAndSubst(root,best_pfx)
      end

    (* replace'(aRoot): perform at most one common-subexpression
       substitution on the annotated tree aRoot.

       Collect almost every subexpression into a list, and get rid of
       everything whose hash only occurs once. What remains are suspected
       CSEs. Check them using structural equality (eq_exp) and drop
       anything that is not really a duplicate. For the remaining
       duplicates, find the greatest common prefix (also known as least
       common ancestor), check for eligibility, and do the replace.

       Returns the rewritten tree and sets the flag `replaced` when a
       substitution happened; returns aRoot unchanged otherwise.
       Raises Fail "eligibleList: bad prefix" if the common prefix cannot
       be followed in the tree. *)

    fun replace'(aRoot:lexp_cse):lexp_cse =
      let
        (* candidate subexpressions of e, e itself included where allowed *)
        fun sub_collect(e as CSENode(node, _):lexp_cse):lexp_cse list =
          case node of
            Fn_c(e1)            => e::(sub_collect e1)
          (* don't actually substitute a function application itself *)
          | App_c(e1,e2)        => sub_collect(e1)@sub_collect(e2)
          | Unop_c(_,e1)        => e::sub_collect(e1)
          | Binop_c(e1,_,e2)    => e::(sub_collect(e1)@sub_collect(e2))
          (* don't actually substitute a tuple itself *)
          | Tuple_c(es)         =>
              (foldl (fn (e,lst)=>(sub_collect e)@lst) nil es)
          | Ith_c(e1,e2)        => e::(sub_collect(e1)@sub_collect(e2))
          | If_c(e1,e2,e3)      => e::(sub_collect(e1)@sub_collect(e2)@
                                   sub_collect(e3))
          | Letrec_c(e1,e2)     => e::(sub_collect(e1)@sub_collect(e2))
          (* never substitute single ints, reals, variables.
             never ever EVER substitute SetIth's or anything under them *)
          | _                   => nil


        (* every candidate subexpression, unique ones included *)
        val all_sub_exps = sub_collect aRoot

        (* keeps only the entries whose hash value occurs more than once
           in l; this removes almost all unique subexpressions *)
        fun remove_uniques (l: lexp_cse list):lexp_cse list = let
          fun allEq(CSENode(_,(_,h,_))):lexp_cse list =
            List.filter (fn CSENode(_,(_,h',_)) => h'=h) l
        in List.filter (fn e => length(allEq e)>1) l end

        (* get rid of unique sub expressions *)
        val sub_exps = remove_uniques(all_sub_exps)

        (* given two int lists, return the longest prefix of both *)
        fun longestPrefix(x:int list,y:int list)=
          case (x,y) of
            (x::xs,y::ys) => if (x=y) then x::longestPrefix(xs,ys) else nil
          | _ => nil

        (* the elements of sList that `eligible` accepts for the subtree
           tNode *)
        fun eligibleList'(tNode:lexp_cse,
                          sList:lexp_cse list): lexp_cse list =
            List.filter (fn x=>eligible(x,tNode)) sList

        (* the eligible list with the best prefix for that list
           dList should always have at least two elements
           (Fail "eligibleList: input too small!" otherwise) *)
        fun eligibleList (root:lexp_cse,
                          dList: lexp_cse list):(lexp_cse list * int list) =
          let
            (* prefixes of every lexp in dList *)
            val pfxs = map (fn(CSENode(_,(_,_,p))) => p) dList
            val (first, rest) =
                   case pfxs of
                      x::y::xs => (x, y::xs)
                    | _ => raise Fail "eligibleList: input too small!"
            (* prefix of smallest subtree that contains all of dList *)
            val best_pfx = foldl longestPrefix first rest
          in
            (* if the parent is the root (code ~1), see who is eligible
             otherwise move down one step (getNode) and check in that tree *)
            case findParent(root, best_pfx) of
                NONE    => ([],[])
              | SOME(CSENode(n,_),i) =>
                  if i = ~1 then (eligibleList'(root, dList), best_pfx)
                  else (case getNode(n,i) of
                           NONE => raise Fail "eligibleList: bad prefix"
                         | SOME x => (eligibleList'(x,dList),best_pfx))
          end

        (* Tries checkNode against the remaining candidates toCheck. If it
           has real duplicates and at least two of the group are eligible,
           performs the substitution and stops; otherwise drops checkNode
           and continues with the next candidate. *)
        fun loopSubst(checkNode as CSENode(_, (_,hash,_)):lexp_cse,
                      toCheck: lexp_cse list, root:lexp_cse):lexp_cse =
          let
            (* making sure of strong equality with checkNode *)
            fun test(sNode as CSENode(_,(_,h,_)):lexp_cse) =
              ((h=hash) andalso eq_exp(sNode,checkNode))
            (* give up on checkNode and move on *)
            fun tryNext() =
              case toCheck of
                nil   => root (* nothing left to check *)
              | y::ys => loopSubst(y, ys, root)
          in
            case List.filter test toCheck of
              nil  => tryNext()
            | dups =>
                (case eligibleList(root, checkNode::dups) of
                   (* runSubst needs at least two occurrences: sharing a
                      single eligible node gains nothing and would make
                      runSubst raise Fail *)
                   (eList as _::_::_, bestPfx) =>
                     (replaced := true; runSubst(eList, root, bestPfx))
                 | _ => tryNext())
          end
      in
        case sub_exps of
          (* need at least two items to do any meaningful substitutions *)
          x::xs::xss => loopSubst(x,xs::xss,aRoot)
        | _ => aRoot (* no substitutions *)
      end

    (* replace e: run substitution passes over e until a pass leaves the
       flag `replaced` cleared, and return the final expression.
       Every pass converts the expression to its annotated form, hashes it,
       runs replace' and converts the result back. Prints "Replacing..."
       before the first pass and a newline after the last one. *)
    fun replace (e:Lambda.lexp):Lambda.lexp =
      let
        (* one complete pass: annotate, hash, substitute, strip *)
        fun onePass(ex:lexp):lexp =
          cseTolexp(replace'(Hash.hash(lexpToCSE(ex))))

        (* keep going for as long as the previous pass changed something;
           the flag is cleared before the next pass starts *)
        fun fixpoint(ex:lexp):lexp =
          if !replaced
            then (replaced := false; fixpoint(onePass ex))
          else ex

        val _ = print ("Replacing...")
        val _ = replaced := false
        val result = fixpoint(onePass e)
        val _ = print ("\n")
      in
        result
      end
  end
