structure Term : TERM  = struct

  open Util

(***********************************************
 * syntax
 ***********************************************)

  (* Terms.  TST_* are tests and ACT_* are actions; *_VAR are variables
   * and *_CON are constants (see makeConstant / makeVariable).  PLUS and
   * TIMES are n-ary; flatten merges directly nested PLUS-in-PLUS and
   * TIMES-in-TIMES nodes.  ZERO and ONE are the units of PLUS and TIMES
   * (flatten maps an empty sum to ZERO and an empty product to ONE). *)
  datatype term =
    TST_VAR of string
  | ACT_VAR of string
  | TST_CON of string
  | ACT_CON of string
  | PLUS of term list
  | TIMES of term list
  | NOT of term
  | STAR of term
  | ZERO
  | ONE

  type id = string

  (* maps identifiers to terms *)
  type substitution = (id * term) list

  (* EQ(s,t) is s = t; LE(s,t) is s <= t (printed "<", LaTeX \leq) *)
  datatype eqn = EQ of term * term | LE of term * term

  (* premises and conclusion: ([e1,...,en], e) reads e1 -> ... -> en -> e *)
  type cond_eqn = eqn list * eqn

  (* the two sides of an equation *)
  fun args (EQ x) = x | args (LE x) = x

  (* the constructor (EQ or LE) of an equation, for rebuilding it *)
  fun equation (EQ _) = EQ | equation (LE _) = LE

  (* the n-ary constructor of a term; anything that is not PLUS maps to
   * TIMES, so this is only meaningful for PLUS and TIMES terms *)
  fun headsymbol (PLUS _) = PLUS | headsymbol _ = TIMES

  (* the unary constructor of a term; anything that is not NOT maps to
   * STAR, so this is only meaningful for NOT and STAR terms *)
  fun uheadsymbol (NOT _) = NOT | uheadsymbol _ = STAR

(***********************************************
 * output
 ***********************************************)

  (* Replace every '<' by '@', presumably so the text can sit inside the
   * XML elements built below without an unescaped '<'.
   * NOTE(review): '&' and '>' are passed through unchanged. *)
  val fixle : string -> string = String.map (fn c => if c = #"<" then #"@" else c)

  fun idToXML (x:id) : string = "<id>" ^ fixle x ^ "</id>"

  (* Printing precedence; higher precedence binds tighter.  A child is
   * parenthesized when its precedence is lower than its parent's. *)
  fun outPrecedence(t:term) : int =
    case t of
      PLUS _ => 0
    | TIMES _ => 1
    | NOT _ => 2
    | STAR _ => 3
    | _ => 4 (* variables and constants *)

  (* Join the strings in s with separator opr; an empty list prints as id,
   * the unit of the operator. *)
  fun assocToString (opr:string) (id:string) (s:string list) : string =
    case s of
      [] => id
    | [x] => x
    | x::t => x^opr^(assocToString opr id t)

  (* Concrete syntax of a term.  Sums are joined with " + ", products with
   * ";", NOT is a prefix "~", STAR is a postfix star, and the empty sum and
   * product print as "0" and "1".  These are the operator characters that
   * tokenize recognizes.  focusSubterm depends on this exact character
   * layout, so keep the two in sync. *)
  fun toString (t:term) : string = let
    (* parenthesize as dictated by surrounding precedence *)
    fun protect (x:term) : string = let
      val s = toString x
    in
      if outPrecedence t <= outPrecedence x then s else "("^s^")"
    end
  in
    case t of
      (TST_VAR x | ACT_VAR x) => x
    | (TST_CON x | ACT_CON x) => x
    | PLUS x => assocToString " + " "0" (map protect x)
    | TIMES x => assocToString ";" "1" (map protect x)
    | NOT x => "~" ^ (protect x)
    | STAR x => (protect x) ^ "*"
    | ZERO => "0"
    | ONE => "1"
  end

  (* Debugging dump that shows the constructor of every node.  Children of
   * PLUS/TIMES are separated by commas, assuming Util.mapAllButLast
   * applies its function to every element except the last. *)
  fun termStructure (t:term) : string =
    case t of
      TST_VAR x => "TV("^x^")"
    | ACT_VAR x => "AV("^x^")"
    | TST_CON x => "TC("^x^")"
    | ACT_CON x => "AC("^x^")"
    | PLUS x => "PLUS[" ^ concat(mapAllButLast (fn y => y^",") (map termStructure x)) ^ "]"
    | TIMES x => "TIMES[" ^ concat(mapAllButLast (fn y => y^",") (map termStructure x)) ^ "]"
    | NOT x => "NOT(" ^ termStructure x ^ ")"
    | STAR x => "STAR(" ^ termStructure x ^ ")"
    | ZERO => "0"
    | ONE => "1"

  fun termToXML (t:term) : string = "<term>" ^ fixle(toString t) ^ "</term>"

  (* one <sub> element per binding, in list order *)
  fun substToXML (s:substitution) : string =
    let fun toXML (x:id,t:term) = "<sub>" ^ idToXML x ^ termToXML t ^ "</sub>"
    in concat(map toXML s) end

  fun opToString (EQ _) = "=" | opToString (LE _) = "<"

  (* "lhs op rhs"; focusInEqnToString relies on the single spaces *)
  fun eqnToString(e:eqn) =
    let val (s,t) = args e
    in toString s^" "^opToString e^" "^toString t
    end

  fun eqnStructure(e:eqn) =
    let val (s,t) = args e
    in termStructure s ^ " " ^opToString e ^" " ^termStructure t
    end

  fun eqnToXML (e:eqn) : string =
    "<eqn>" ^ fixle (eqnToString e) ^ "</eqn>"

  (* premises, each followed by " -> ", then the conclusion *)
  fun condEqnToString((ce,e):cond_eqn) =
    concat (map (fn d => eqnToString d ^ " -> ") ce) ^ eqnToString e

  fun condEqnStructure((ce,e):cond_eqn) =
    concat (map (fn d => eqnStructure d ^ " -> ") ce) ^ eqnStructure e

  fun condEqnToXML (ce:cond_eqn) : string =
    "<condeqn>" ^ fixle (condEqnToString ce) ^ "</condeqn>"


  (* LaTeX rendering of a term: same structure and parenthesization as
   * toString, with \cdot for products, \overline for NOT and \star
   * appended for STAR. *)
  fun toLaTeX (t:term) : string = let
    (* parenthesize as dictated by surrounding precedence *)
    fun protect (x:term) : string = let
      val s = toLaTeX x
    in
      if outPrecedence t <= outPrecedence x then s else "("^s^")"
    end
  in
    case t of
      (TST_VAR x | ACT_VAR x) => x
    | (TST_CON x | ACT_CON x) => x
    | PLUS x => assocToString " + " "0" (map protect x)
    | TIMES x => assocToString " \\cdot " "1" (map protect x)
    | NOT x => "\\overline{" ^ (protect x) ^ "}"
    | STAR x => (protect x) ^ "\\star"
    | ZERO => "0"
    | ONE => "1"
  end

  fun opToLaTeX (EQ _) = "=" | opToLaTeX (LE _) = "\\leq"

  (* LaTeX for an equation as one row of a column-aligned display, in the
   * form lhs & op & rhs.  If the plain-text form (eqnToString) is 60 or
   * more characters, the two sides go on separate rows and the first row
   * is marked \nonumber, presumably for an eqnarray-style environment. *)
  fun eqnToLaTeX(e:eqn) =
    let val (s,t) = args e
        val maybeOut = "\\mathsf{"^toLaTeX s^"} & "^opToLaTeX e^"& \\mathsf{"^toLaTeX t^"}"
    in
      if (String.size(eqnToString(EQ(s,t))) < 60) then maybeOut
      else "&&\\mathsf{"^toLaTeX s^"}\\nonumber \n\\\\&&"^opToLaTeX e^" \\mathsf{"^toLaTeX t^"}"
    end

  (* Like eqnToLaTeX but without alignment marks: a single \mathsf group,
   * or a one-column array when the plain text is 60 or more characters. *)
  fun eqnToLaTeXNA(e:eqn) =
    let val (s,t) = args e
        val maybeOut ="\\mathsf{"^toLaTeX s^" "^opToLaTeX e^" "^toLaTeX t^"}"
    in
      if (String.size(eqnToString(EQ(s,t))) < 60) then maybeOut
      else "\\begin{array}{l}\n\\mathsf{"^toLaTeX s^"}\\\\\n"^opToLaTeX e^" \\mathsf{"^toLaTeX t^"}\n\\end{array}"
    end

(***********************************************
 * utilities
 ***********************************************)

  (* True when t contains no action variables or constants.  0 and 1 count
   * as tests, and the star of a test is a test.  Raises Fail "not a test"
   * when NOT is applied to something that is not a test. *)
  fun isTest (t:term) : bool =
    case t of
      (TST_VAR _ | TST_CON _) => true
    | (ACT_VAR _ | ACT_CON _) => false
    | PLUS x => List.all isTest x
    | TIMES x => List.all isTest x
    | NOT x => isTest x orelse raise Fail "not a test"
    | STAR x => isTest x
    | (ZERO | ONE) => true

  (* replace every variable by the constant of the same name and kind *)
  fun makeConstant (t:term) : term =
    case t of
      TST_VAR x => TST_CON x
    | ACT_VAR x => ACT_CON x
    | PLUS x => PLUS (map makeConstant x)
    | TIMES x => TIMES (map makeConstant x)
    | NOT x => NOT (makeConstant x)
    | STAR x => STAR (makeConstant x)
    | _ => t

  fun makeConstantEqn (e:eqn) : eqn =
    case e of
      EQ(x,y) => EQ(makeConstant x,makeConstant y)
    | LE(x,y) => LE(makeConstant x,makeConstant y)

  fun makeConstantCondEqn ((pre,con):cond_eqn) : cond_eqn =
    (map makeConstantEqn pre,makeConstantEqn con)

  (* replace every constant by the variable of the same name and kind *)
  fun makeVariable (t:term) : term =
    case t of
      TST_CON x => TST_VAR x
    | ACT_CON x => ACT_VAR x
    | PLUS x => PLUS (map makeVariable x)
    | TIMES x => TIMES (map makeVariable x)
    | NOT x => NOT (makeVariable x)
    | STAR x => STAR (makeVariable x)
    | _ => t

  fun makeVariableEqn (e:eqn) : eqn =
    case e of
      EQ(x,y) => EQ(makeVariable x,makeVariable y)
    | LE(x,y) => LE(makeVariable x,makeVariable y)

  fun makeVariableCondEqn ((pre,con):cond_eqn) : cond_eqn =
    (map makeVariableEqn pre,makeVariableEqn con)

  (* apply makeVariable to every term in the range of s *)
  fun makeVariableSubst (s:substitution) : substitution =
    map (fn (x,y) => (x,makeVariable y)) s

  (* Names of the (test and action) variables occurring in t.  Duplicates
   * are removed at PLUS/TIMES nodes via Util.removeDuplicates. *)
  fun variables (t:term) : id list =
    case t of
      (TST_VAR x | ACT_VAR x) => [x]
    | (PLUS x | TIMES x) => removeDuplicates(List.concat (map variables x))
    | (NOT x | STAR x) => variables x
    | _ => []

  (* the domain of a substitution *)
  fun substVariables (s:substitution) : id list = map #1 s

  fun eqnVariables (e:eqn) : id list =
    let val (s,t) = args e
    in removeDuplicates ((variables s) @ (variables t))
    end

  fun condEqnVariables ((pre,con):cond_eqn) : id list =
    removeDuplicates(List.concat(map eqnVariables (con::pre)))

  (* Flatten terms: merge a PLUS child into its PLUS parent and a TIMES
   * child into its TIMES parent, turn empty sums/products into ZERO/ONE,
   * and replace one-element sums/products by that element. *)
  fun flatten (t:term) : term =
    case t of
      PLUS x => let
        val y = map flatten x
        val z = List.concat(map (fn u => case u of PLUS v => v | _ => [u]) y)
      in
        case z of [] => ZERO | [x] => x | _ => PLUS z
      end
    | TIMES x => let
        val y = map flatten x
        val z = List.concat(map (fn u => case u of TIMES v => v | _ => [u]) y)
      in
        case z of [] => ONE | [x] => x | _ => TIMES z
      end
    | NOT x => NOT (flatten x)
    | STAR x => STAR (flatten x)
    | _ => t

(***********************************************
 * simplify
 ***********************************************)

  (* Convert empty sums and products to 0 and 1 respectively, replace
   * sums and products of one element by that element, and merge
   * adjacent sums and products. *)
  fun simplifyLite (t:term) : term =
    case t of
      PLUS [] => ZERO
    | PLUS [x] => simplifyLite x
    | PLUS y => let
        val y' = map simplifyLite y
        fun f(PLUS z,zz) = z@zz | f(x,zz) = x::zz
      in
        PLUS(foldr f [] y')
      end
    | TIMES [] => ONE
    | TIMES [t] => simplifyLite t
    | TIMES y => let
        val y' = map simplifyLite y
        fun f(TIMES z,zz) = z@zz | f(x,zz) = x::zz
      in
        TIMES (foldr f [] y')
      end
    | NOT x => NOT (simplifyLite x)
    | STAR x => STAR (simplifyLite x)
    | _ => t

  (* Bottom-up simplification, one rewrite per constructor:
   *   - sums drop 0 summands and collapse adjacent equal summands;
   *   - products drop 1 factors, and a product containing 0 is 0;
   *   - double negation is removed;
   *   - nested stars collapse, and the star of a test is 1.
   * Raises Fail "not a test" (from isTest) on the star of a term that
   * contains NOT applied to a non-test. *)
  fun simplify (t:term) : term =
    case t of
      PLUS x =>
        let
          val x1 = map simplify x
          val x2 = List.filter (fn ZERO => false | _ => true) x1
          val x3 = foldr (fn(h,s as h'::t) => if h = h' then s else h::s
                          | (h,s) => h::s) [] x2
        in simplifyLite(PLUS x3)
        end
    | TIMES x =>
        let
          val x1 = map simplify x
          val x2 = List.filter (fn ONE => false | _ => true) x1
        in
          (* 0 annihilates a product.  The product itself must be rebuilt
           * with TIMES: an empty list becomes 1 in simplifyLite. *)
          if member ZERO x2 then ZERO else simplifyLite(TIMES x2)
        end
    | NOT(NOT x) => simplify x
    | NOT x => NOT(simplify x)
    | STAR(STAR x) => simplify (STAR x)
    | STAR x =>
        let val x' = simplify x
        in if isTest x' then ONE else STAR x'
        end
    | _ => t



(***********************************************
 * parser
 ***********************************************)

  datatype ops = Plus | Times | Star | Not | Lparen | Rparen

  (* an element of the parse stack or of the input token stream *)
  datatype parsand = OP of ops | TERM of term

  (* An identifier whose first character is uppercase is a test variable.
   * Anything else is an action variable; if it starts with a double quote,
   * its first and last characters are dropped. *)
  fun parseId(s:string) =
    if Char.isUpper(String.sub(s,0)) then TST_VAR s else
    ACT_VAR(if Char.contains "\"" (String.sub(s,0))
       then (String.substring(s,1,String.size(s)-2)) else s)

  (* Split s into tokens.  Each operator character in the literal below
   * becomes its own token unless it occurs between double quotes.
   * Remaining runs of non-blank characters become identifiers via
   * parseId.  Only the space character separates tokens.
   * NOTE(review): because 0 and 1 are operator characters, an identifier
   * such as x1 is split into x followed by 1. *)
  fun tokenize (s:string) : parsand list = let
    val inquotes = ref false
    fun sep (c:char) =
      ((if Char.contains "\"" c then inquotes := not(!inquotes) else ());
      if Char.contains "()*;+~01" c andalso not(!inquotes) then " "^str c^" " else str c)
    val st = String.translate sep s
    val tokens = String.tokens (eq #" ") st
    fun xlate(s:string):parsand =
      case s of
        "(" => OP Lparen
      | ")" => OP Rparen
      | "+" => OP Plus
      | "*" => OP Star
      | ";" => OP Times
      | "~" => OP Not
      | "0" => TERM ZERO
      | "1" => TERM ONE
      | _ => TERM (parseId s)
  in
    map xlate tokens
  end

  (* One reduction on the top of the parse stack.  The stack is stored
   * reversed (top first), so operands are swapped back into source order. *)
  fun reduce (p:parsand list) : parsand list =
    case p of
      TERM x::OP Plus::TERM y::t => TERM(PLUS [y,x])::t
    | TERM x::OP Times::TERM y::t => TERM(TIMES [y,x])::t
    | TERM x::OP Not::t => TERM(NOT x)::t
    | _ => raise Fail "parse failed (reduce)"

  (* higher precedence binds tighter *)
  fun inPrecedence (p:ops) : int =
    case p of
      Star => 4
    | Not => 3
    | Times => 2
    | Plus => 1
    | Lparen => 0
    | Rparen => 0

  (* Shift-reduce parse of a term.  p is the parse stack (top first) and q
   * the remaining input.  Juxtaposition of a term with a term, "(" or "~"
   * is an implicit ";".  Postfix star applies at once to the term on top
   * of the stack.  The result is simplifyLite'd and flattened.  On Fail,
   * the message is printed and NONE returned. *)
  fun parseTerm(s:string) : term option = let
    fun parse'(p:parsand list, q:parsand list) : term =
      case (p,q) of
        ([TERM t], []) => t
      | (_,[]) => parse'(reduce p,q)
      | ([],x::t) => parse'([x],t)
      | (TERM _::_, TERM _::_) => parse'(p,OP Times::q)
      | (TERM _::_, OP Lparen::_) => parse'(p,OP Times::q)
      | (TERM _::_, OP Not::_) => parse'(p,OP Times::q)
      | (TERM t::xs, OP Star::ys) => parse'(TERM (STAR t)::xs,ys)
      | ([TERM _], OP y::ys) => parse'(OP y::p,ys)
      | (TERM x::OP Lparen::xs, OP Rparen::ys) => parse'(TERM x::xs,ys)
      | (TERM t::OP x::xs, OP y::ys) =>
          if inPrecedence x <= inPrecedence y
            then parse'(OP y::p,ys)
            else parse'(reduce p,q)
      | (OP _::_, y::ys) => parse'(y::p,ys)
      | _ => raise Fail "parse failed"
  in
    SOME (flatten(simplifyLite(parse'([],tokenize s))))
  end
  handle Fail x => (println x; NONE)

  (* Parse an equation given as whitespace-separated tokens.  Exactly one
   * token must be "=", "<" or "<=" (the latter two both mean LE).  On
   * failure the message is printed and NONE returned.
   * NOTE(review): the sides are re-split on '/', so a term containing '/'
   * is rejected. *)
  fun parseEqnTokenized(e:string list) : eqn option = let
    fun f(s:string) =
      case s of
        "<=" => "/</"
      | "<" => "/</"
      | "=" => "/=/"
      | _ => s^" "
    val t2 = map f e
    val t3 = String.tokens (eq #"/") (String.concat t2)
    val (e1,oper,e2) = case t3 of
      [es1,oper,es2] => (parseTerm es1,oper,parseTerm es2)
    | _ => raise Fail "invalid equation"
  in
    case (e1,oper,e2) of
      (SOME t1,"=",SOME t2) => SOME(EQ(t1,t2))
    | (SOME t1,"<",SOME t2) => SOME(LE(t1,t2))
    | _ => raise Fail "invalid equation"
  end
  handle Fail x => (println x; NONE)

  fun parseEqn(e:string) : eqn option =
    parseEqnTokenized (String.tokens Char.isSpace e)

  (* Parse "e1 -> ... -> en -> e", where each "->" is a separate
   * whitespace-delimited token.  The last equation is the conclusion.
   * Returns NONE (after printing a message) if any part fails to parse,
   * or if the input contains no equation. *)
  fun parseCondEqn(ce:string) : cond_eqn option = let
    val t1 = String.tokens Char.isSpace ce
    fun f(s:string) = case s of "->" => "/" | _ => s^" "
    val t2 = map f t1
    val t3 = String.tokens (eq #"/") (String.concat t2)
    val eqns = map (valOf o parseEqn) t3
      handle Option => raise Fail "parse failed"
  in
    case rev eqns of [] => NONE
    | x::t => SOME (rev t,x)
  end
  handle Fail x => (println x; NONE)


  (* CPS function that prompts the user for the terms needed to fill in
   * the identifiers in ses, one at a time.  The default implementation
   * is for the CLI; other interfaces can replace it by updating this ref.
   * Once every hole is filled in, the collected terms are passed to cont,
   * which performs the actual citation of a theorem.
   *   - Each answer is parsed with parseTerm and converted by makeConstant.
   *   - Answers are prepended to missedterms, so when starting from []
   *     cont receives them in the reverse order of ses.
   *   - A blank line raises Fail "operation canceled".
   *   - An unparsable answer prints a message and re-prompts the same id.
   * NOTE(review): String.tokens is applied directly to the result of
   * TextIO.inputLine, which assumes a Basis where inputLine returns a
   * string rather than a string option.  Confirm for the target compiler. *)
  val getMissing = ref (let
    fun getMiss(ses:id list) (missedterms:term list) (cont:term list -> unit) =
    (case ses of
      [] => cont (missedterms)
    | (s::rest) => let
        val inputLine = ((!print) (s ^ "=? "); TextIO.inputLine TextIO.stdIn)
        val tokenizedInput = String.tokens Char.isSpace inputLine
        val _ = if tokenizedInput = [] then raise Fail "operation canceled" else ()
        val term = parseTerm (concat (mapAllButLast (fn x => x^" ") tokenizedInput))
       in
        case term of
          SOME t => (getMiss rest ((makeConstant t)::missedterms) cont)
        | NONE => (println "invalid term, try again"; getMiss ses missedterms cont)
       end)
   in
     getMiss
   end)

(***********************************************
 * focus
 ***********************************************)

(* The focus subterm is specified by an int list, which gives a path in
 * the expression tree.  [] means no current focus, [0,...] gives a
 * subterm of the lhs, and [1,...] a subterm of the rhs.  Each element
 * except the last selects a child of the previous subterm.  That element
 * must be 0 for NOT and STAR, and an index into x for PLUS x and TIMES x.
 *
 * The last element of the focus is a range length and is always nonzero.
 * For a subterm of the form PLUS x or TIMES x, if the final two elements
 * of the focus are i,m, the focus is the subsequence of x of length m
 * starting at index i.  For the (unique) subterm x of NOT x or STAR x,
 * the final two elements are always 0,1. *)

  (* Character layout of the focus s within toString a.  Returns (j,k,m):
   * the number of characters before the focused subterm, its width, and
   * the number of characters after it.  The arithmetic mirrors toString:
   * separator widths, "~" and star widths, and the two characters added
   * by parenthesization.  Raises Fail for an invalid focus. *)
  fun focusSubterm (a:term,s:int list) = let
    (* split s into its first i elements, the next j, and the rest *)
    fun split i j s = let
      val u = List.take(s,i)
      val v = List.drop(s,i)
      val w = List.take(v,j)
      val z = List.drop(v,j)
    in (u,w,z) end
    handle Subscript => raise Fail "system error: no such subterm"

    (* Layout of focus t within child x of a.  Shifted by one character on
     * each side when toString puts x in parentheses. *)
    fun protect (t:int list) (x:term) : int * int * int =
      let val (j,k,m) = focusSubterm (x,t)
      in if outPrecedence a <= outPrecedence x then (j,k,m) else (j+1,k,m+1)
      end

    (* printed width of child x of a, parentheses included *)
    val fp = (fn(i,j,k) => i+j+k) o (protect [])
    val sum = foldr (op +)
    (* width of the separator toString uses: " + " for sums, ";" for products *)
    val opWidth = case a of PLUS _ => 3 | _ => 1
  in
    case (a,s) of
      (_,[]) => (0,size(toString a),0)
    | ((PLUS x,[i,n]) | (TIMES x,[i,n])) =>
        let
          val (u,v,w) = split i n x
          val (pre,width,post) = protect [] ((headsymbol a) v)
          (* i elements precede the range, each followed by a separator *)
          val pre' = i*opWidth + sum pre (map fp u)
          (* length x - i - n elements follow it, each preceded by a separator *)
          val post' = (length x - i - n)*opWidth + sum post (map fp w)
        in (pre',width,post')
        end
    | ((PLUS x,i::j::t) | (TIMES x,i::j::t)) =>
        let
          val (u,[v],w) = split i 1 x
          val (pre,width,post) = protect (j::t) v
          val pre' = i*opWidth + sum pre (map fp u)
          (* length x - i - 1 elements follow child i *)
          val post' = (length x - i - 1)*opWidth + sum post (map fp w)
        in (pre',width,post')
        end
    | (NOT x,0::t) =>
        let val t = if t = [1] then [] else t
            val (j,k,m) = protect t x
        in (j+1,k,m)   (* leading "~" *)
        end
    | (STAR x,0::t) =>
        let val t = if t = [1] then [] else t
            val (j,k,m) = protect t x
        in (j,k,m+1)   (* trailing star *)
        end
    | _ => raise Fail "no such subterm"
  end

  (* (toString a, marker line): the marker is spaces up to the focus
   * followed by one '-' under each character of the focused subterm *)
  fun focusToString (a:term,t:int list) : string * string = let
    val (i,j,_) = focusSubterm (a,t)
    val u = String.implode (List.tabulate (i,fn _ => #" "))
    val v = String.implode (List.tabulate (j,fn _ => #"-"))
    val s = toString a
  in
    (s,u^v)
  end

  (* eqnToString e, a newline, and a marker line under the focus cf.  An
   * empty focus underlines the whole equation. *)
  fun focusInEqnToString (e:eqn,cf:int list) = let
    val (lhs,rhs) = args e
    val s = eqnToString e
    val ind =
      case cf of
        [] => String.implode (List.tabulate (size s,fn _ => #"-"))
      | 0::t => #2(focusToString(lhs,if t = [1] then [] else t))
      | 1::t => let
          val t = if t = [1] then [] else t
          (* column where the rhs starts in eqnToString: lhs, " ", op, " " *)
          val rhsPos = size(toString lhs) + size(opToString e) + 2
          val spaces = String.implode (List.tabulate (rhsPos,fn _ => #" "))
        in
          spaces ^ #2(focusToString(rhs,t))
        end
      | _ => raise Fail "system error: focus does not exist"
  in
    s^"\n"^ind
  end

  (* Get the focused subterm with context.  If it is a subrange (or a
   * single element) of PLUS x or TIMES x, the second component is
   * length x.  It is 1 for the operand of NOT/STAR and 0 for the empty
   * focus.  A multi-element range is returned flattened.  Raises Fail for
   * an invalid focus, including out-of-range indices. *)
  fun getFocusWithContext(a:term,s:int list) : term * int =
    (case (a,s) of
       (_,[]) => (a,0)
     | ((NOT x,[0,1]) | (STAR x,[0,1])) => (x,1)
     | ((NOT x,0::t) | (STAR x,0::t)) => getFocusWithContext(x,t)
     | ((PLUS x,[i,1]) | (TIMES x,[i,1])) => (List.nth(x,i),length x)
     | ((PLUS x,[i,m]) | (TIMES x,[i,m])) =>
         (flatten((headsymbol a) (List.take(List.drop(x,i),m))),length x)
     | ((PLUS x,i::j::t) | (TIMES x,i::j::t)) => getFocusWithContext(List.nth(x,i),j::t)
     | _ => raise Fail "system error: focus does not exist")
    (* The parentheses make this handler cover the List.nth/take/drop
     * calls above.  Without them it would attach only to the final raise. *)
    handle Subscript => raise Fail "system error: focus does not exist"

  (* get the focused subterm *)
  fun getFocus(a:term,s:int list) : term = #1(getFocusWithContext(a,s))

  (* Get the focused subterm in an equation with context.  A whole side
   * ([0,1] or [1,1]) has context 2. *)
  fun getFocusInEqnWithContext(a:eqn,s:int list) : term * int =
    case (args a,s) of
      (((x,_),[0,1]) | ((_,x),[1,1])) => (x,2)
    | (((x,_),0::t) | ((_,x),1::t)) => getFocusWithContext(x,t)
    | _ => raise Fail ("system error: focus does not exist"
                       ^ concat (map (fn i => " " ^ Int.toString i) s))

  (* get the focused subterm in an equation *)
  fun getFocusInEqn(a:eqn,s:int list) : term = #1(getFocusInEqnWithContext(a,s))

  (* Substitute b in a at focus position s and return the new term and
   * focus.  When a range of PLUS x / TIMES x is replaced by a b with the
   * same head symbol, b's elements are spliced in.  The result is
   * flattened if b and a were.  Raises Fail for an invalid focus. *)
  fun focusedSubst (a:term,s:int list,b:term) : term * int list =
    (case (a,s) of
       (_,[]) => (b,[])
     | ((NOT _,[0,1]) | (STAR _,[0,1])) => ((uheadsymbol a) b,[0,1])
     | ((NOT x,0::t) | (STAR x,0::t)) =>
         let val (u,f) = focusedSubst(x,t,b)
         in ((uheadsymbol a) u,0::f)
         end
     | (PLUS x,[i,m]) =>
         let val u = List.take(x,i)
             val v = List.drop(x,i+m)
             val y = case b of PLUS z => z | _ => [b]
         in (PLUS(u@y@v),[i,length y])
         end
     | (TIMES x,[i,m]) =>
         let val u = List.take(x,i)
             val v = List.drop(x,i+m)
             val y = case b of TIMES z => z | _ => [b]
         in (TIMES(u@y@v),[i,length y])
         end
     | ((PLUS x,i::t)| (TIMES x,i::t)) =>
         let val (y,f) = focusedSubst(List.nth(x,i),t,b)
             val u = List.take(x,i)
             val v = List.drop(x,i+1)
         in ((headsymbol a) (u@(y::v)),i::f)
         end
     | _ => raise Fail "system error: focus does not exist")
    (* report bad indices the same way getFocusWithContext does *)
    handle Subscript => raise Fail "system error: focus does not exist"

  (* Substitute b in a at focus position s and return the equation and the
   * new focus.  An empty focus leaves the equation unchanged. *)
  fun focusedSubstInEqn (a:eqn,s:int list,b:term) : eqn * int list =
    case (args a,s) of
      (_,[]) => (a,s)
    | ((_,y),[0,1]) => ((equation a) (b,y),[0,1])
    | ((x,_),[1,1]) => ((equation a) (x,b),[1,1])
    | ((x,y),0::t) =>
        let val (x,f) = focusedSubst(x,t,b)
        in ((equation a) (x,y),0::f)
        end
    | ((x,y),1::t) =>
        let val (y,f) = focusedSubst(y,t,b)
        in ((equation a) (x,y),1::f)
        end
    | _ => raise Fail "system error: focus does not exist"

  (* Get the steps (as commands) needed to reach focus f in equation eq.
   * The result is a string of " d" and " r" commands, presumably the
   * CLI's down/right navigation commands.  Counts that are zero or
   * negative produce no commands. *)
  fun moveToFocus f eq = let
    fun moveToFocus' f f' = let
      fun moveToRight n = if n <= 0 then "" else " r"^(moveToRight (n-1))
      fun moveDown n = if n <= 0 then "" else " d"^(moveDown (n-1))
      fun moveToDown n focus = let
        (* a bare side [0] or [1] is addressed as [0,1] or [1,1] *)
        val mfocus = case focus of [_] => focus@[1] | _ => focus
        val term = getFocusInEqn (eq,mfocus)
      in
        case term of
          (PLUS x | TIMES x) => moveDown (List.length(x) - n)
        | _ => " d"
      end
    in
      case f of
        nil => ""
      | [_] => ""
      | [x,y] => (moveToDown y (f'@[1]))^(moveToRight x)
      | (x::xs) => (moveToDown 1 (f'@[1]))^
                   (moveToRight x)^(moveToFocus' xs (f'@[x]))
    end
  in
    case f of
      (0::xs) => " d"^(moveToFocus' xs [0])
    | (1::xs) => " d r"^(moveToFocus' xs [1])
    | _ => ""
  end

end

