(********************************************)
(* To interpret:                            *)
(*   ocaml ocaml_mockup.ml                  *)
(* Or:                                      *)
(*   ocaml -init ocaml_mockup.ml            *)
(********************************************)

module Symb = struct
  (* Symbols (names of unknowns) are plain strings. *)
  type t = string

  (* [fresh ()] returns a new symbol on each call: a1, a2, a3, ...
     The counter is private to [fresh] and shared by all of its callers. *)
  let fresh =
    let counter = ref 0 in
    fun () ->
      incr counter;
      "a" ^ string_of_int !counter

  (* Symbol equality (structural). *)
  let eq s1 s2 = s1 = s2
end
(********************************************)
(* From here: copied from Corec/gaussian.ml *)
(********************************************)

(* Interface of an equation solver.
   - [eval] computes the value of an expression whose unknowns already
     carry values of type [t].
   - [solve (eqs, x)] solves the system [eqs] and returns the value of the
     unknown [x]; presumably, as in Gaussian below, each pair [(i, e)] is
     the equation [i = e]. *)
module type Solver = sig
  type t
  type 'a expr

  val eval : t expr -> t
  val solve : (Symb.t * Symb.t expr) list * Symb.t -> t
end

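(* This type is declared at top level, outside Gaussian, so that its
   constructors stay usable outside the module even if Gaussian is sealed
   with the Solver signature (which makes its expr type abstract).
   [Another solution is not to declare Gaussian as implementing the
   Solver signature; the constructors of a type defined inside it are
   then accessible from outside as well.] *)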
(* Arithmetic expressions over unknowns of type ['a]: float constants,
   sums, products, and unknowns. *)
type 'a expression =
  | Val of float
  | Plus of 'a expression * 'a expression
  | Mul of 'a expression * 'a expression
  | EUnknown of 'a

module Gaussian (*: Solver*) = struct
  (* Gaussian-elimination solver for systems of linear equations built
     from [expression]s over symbolic unknowns of type [Symb.t]. *)

  (* Values computed by the solver. *)
  type t = float
  type 'a expr = 'a expression

  (* [eval e] computes the value of [e] when its unknowns are themselves
     floats ([EUnknown x] evaluates to [x]). *)
  let rec eval e =
    match e with
    | Val f -> f
    | Plus (a, b) -> eval a +. eval b
    | Mul (a, b) -> eval a *. eval b
    | EUnknown x -> x

  (********************************************)
  (* From here: copied from Corec/gaussian.ml *)
  (********************************************)

  (* A linear equation [(coeffs, c)] stands for
       a_1 * v_1 + ... + a_n * v_n + c = 0
     where [coeffs] is [[(v_1, a_1); ...; (v_n, a_n)]].  The solver keeps
     [coeffs] sorted by variable, with no repeated variable and no zero
     coefficient. *)
  type eq = (Symb.t * t) list * t

  (* [sum_aux c1 eq1 c2 eq2] computes the coefficient list of
     [c2 * eq2 - c1 * eq1] (constants excluded) by merging the two lists,
     which must be sorted by variable.  A variable present in both lists
     whose combined coefficient is exactly 0.0 is dropped. *)
  let rec sum_aux c1 eq1 c2 eq2 =
    match eq1, eq2 with
    | [], [] -> []
    | (v1, f1) :: t1, (v2, f2) :: t2 when v1 = v2 ->
      let res2 = sum_aux c1 t1 c2 t2 in
      let total = c2 *. f2 -. c1 *. f1 in
      if total = 0.0 then res2 else (v1, total) :: res2
    | (v1, f1) :: t1, [] ->
      let res2 = sum_aux c1 t1 c2 eq2 in
      (v1, -. c1 *. f1) :: res2
    | (v1, f1) :: t1, (v2, _) :: _ when v1 < v2 ->
      let res2 = sum_aux c1 t1 c2 eq2 in
      (v1, -. c1 *. f1) :: res2
    | [], (v2, f2) :: t2 | _ :: _, (v2, f2) :: t2 ->
      let res2 = sum_aux c1 eq1 c2 t2 in
      (v2, c2 *. f2) :: res2

  (* [sum c1 eq1 c2 eq2] is the whole equation [c2 * eq2 - c1 * eq1],
     constants included. *)
  let sum c1 eq1 c2 eq2 =
    let (e1, k1) = eq1 and (e2, k2) = eq2 in
    (sum_aux c1 e1 c2 e2, c2 *. k2 -. c1 *. k1)

  (* [eliminate v f (eq1, f1) (eq2, f2)] removes [v] from the equation
     [(eq2, f2)] using the pivot equation [((v, f) :: eq1, f1)], whose
     leading entry [(v, f)] has been split off by the caller ([solve_aux]
     only calls it with a non-zero [f]).  Returns [(eq2, f2)] unchanged
     when [v] does not occur in it.  If a coefficient of [v] survives the
     subtraction, it is printed as an epsilon and removed. *)
  let eliminate v f (eq1, f1) (eq2, f2) =
    try
      let c2 = List.assoc v eq2 in
      (* c2 * pivot - f * eq2: the two coefficients of v cancel out. *)
      let res = sum f (eq2, f2) c2 ((v, f) :: eq1, f1) in
      try
        let (eq, k) = res in
        let epsilon = List.assoc v eq in (* for floating point errors *)
        print_string ("epsilon = " ^ v ^ " " ^ (string_of_float epsilon) ^ "\n");
        (List.remove_assoc v eq, k)
      with Not_found -> res
    with Not_found -> (eq2, f2)

  (* [solve_aux x eqs isol] performs forward elimination: the leading
     variable of each equation in [eqs] is used as pivot and eliminated
     from all the equations after it.  Pivot equations are pushed onto
     [isol], so the result lists them last pivot first, ready for back
     substitution.  [x] is only threaded through the recursion.
     @raise Failure "Infinite number of solutions" if an equation reduces
     to [0 = 0], "No solution" if it reduces to [c = 0] with [c] non-zero,
     and "This should not be zero" on a zero pivot. *)
  let rec solve_aux (x:Symb.t) (eqs:eq list) (isol: eq list) =
    match eqs with
    | [] -> isol
    | eq :: tl ->
      (match eq with
       | [], 0.0 -> failwith "Infinite number of solutions" (* check *)
       | [], _ -> failwith "No solution"
       | (v, f) :: t, i ->
         if f = 0.0 then failwith "This should not be zero"
         else solve_aux x (List.map (eliminate v f (t, i)) tl) (eq :: isol))

  (* [compute_solutions sols isol] performs back substitution: each pivot
     equation [f * v + sum + c = 0] of [isol] (last pivot first) is solved
     for its pivot [v] using the values already accumulated in [sols].
     @raise Not_found if an equation mentions a variable with no value yet.
     @raise Failure "Impossible" on an equation without coefficients. *)
  let rec compute_solutions (sols : (Symb.t * t) list)
      (isol : eq list) = match isol with
    | [] -> sols
    | ((v, f) :: eq, c) :: tl ->
      let rest =
        List.fold_left
          (fun acc (w, coeff) -> acc +. coeff *. List.assoc w sols)
          c eq in
      let val_v = (-. 1. /. f) *. rest in
      compute_solutions ((v, val_v) :: sols) tl
    | _ -> failwith "Impossible"

  (* [eliminate_doubles l], for a list sorted by variable, merges adjacent
     entries of the same variable by adding their coefficients and drops
     every entry whose coefficient is 0.0.  This includes entries that were
     zero from the start (e.g. from [Mul (Val 0., EUnknown x)]), which
     would otherwise become a zero pivot in [solve_aux]. *)
  let rec eliminate_doubles = function
    | [] -> []
    | (v1, f1) :: (v2, f2) :: t when v1 = v2 ->
      eliminate_doubles ((v1, f1 +. f2) :: t)
    | (_, f) :: t when f = 0.0 -> eliminate_doubles t
    | p :: t -> p :: eliminate_doubles t
  (********************************************)
  (* To here **********************************)
  (********************************************)

  (* [convert_eq_aux e] turns [e] into a linear form [(coeffs, c)] meaning
     [sum of coeffs + c].  Only sums of constants, unknowns, and products
     of an unknown by a constant are accepted; the resulting [coeffs] may be
     unsorted and may repeat variables.
     @raise Failure on any other expression shape. *)
  let rec convert_eq_aux (e:Symb.t expr) : eq =
    match e with
    | Plus (e1, e2) ->
      let l1, f1 = convert_eq_aux e1 in
      let l2, f2 = convert_eq_aux e2 in
      l1 @ l2, f1 +. f2
    | Mul (EUnknown x, Val i) | Mul (Val i, EUnknown x) -> [ (x, i) ], 0.
    | Val f -> [], f
    | EUnknown x -> [ (x, 1.) ], 0.
    | _ -> failwith "Expression not suitable for gaussian elimination"

  (* [convert_eq (i, e)] encodes the definition [i = e] as the equation
     [-1 * i + e = 0]. *)
  let convert_eq (i, e: Symb.t * Symb.t expr) : eq =
    let l, f = convert_eq_aux e in
    (i, -1.) :: l, f

  (********************************************)
  (* And from here  ***************************)
  (********************************************)
  (* [solve (lst, x)] solves the system in which every [(i, e)] of [lst]
     defines [i = e], and returns the value found for [x].
     @raise Failure if an expression has an unsupported shape or the
     system is singular (see [solve_aux]).
     @raise Not_found if [x], or a variable needed during back
     substitution, receives no value. *)
  let solve (lst, x) =
    let eqs0 = List.map convert_eq lst in
    (* Sort each equation by variable.  The ordering must agree with the
       polymorphic [=] / [<] used by [sum_aux] when merging equations. *)
    let eqs1 =
      List.map
        (fun (eq, f) -> List.sort (fun (v1, _) (v2, _) -> compare v1 v2) eq, f)
        eqs0 in
    (* Merge repeated variables and drop zero coefficients. *)
    let eqs = List.map (fun (eq, f) -> eliminate_doubles eq, f) eqs1 in
    let isol = solve_aux x eqs [] in
    let sols = compute_solutions [] isol in
    List.assoc x sols
  (********************************************)
  (* To here **********************************)
  (********************************************)
end

(* A coin: [H] always lands heads, [T] always lands tails, and
   [Flip (p, v, w)] behaves like [v] with probability [p] and like [w]
   otherwise (this is how [pr_heads] below interprets it). *)
type coin =
  | H
  | T
  | Flip of float * coin * coin

(* Biases of the example coins. *)
let q = 0.7
and p = 0.3

(* Example coins: [t] is [T] with probability [q] and [H] otherwise;
   [s] is [H] with probability [p] and otherwise behaves like [t]. *)
let t = Flip (q, T, H)
let s = Flip (p, H, t)

(* Ideally *)
(* Probability that [coin] lands heads.  In plain OCaml this recursion
   only terminates on acyclic coins; the commented-out annotations mark
   where CoCaml would instead treat it corecursively and solve it with
   Gaussian. *)
let (*co*)rec(*[Gaussian]*) pr_heads coin =
  match coin with
  | H -> 1.
  | T -> 0.
  | Flip (bias, left, right) ->
    let from_left = bias *. pr_heads left in
    let from_right = (1. -. bias) *. pr_heads right in
    from_left +. from_right

let r1 = pr_heads H

(* But in CoCaml it would be *)
(* Symbolic version of [pr_heads]: instead of a float it builds an
   [expression] mirroring the structure of the coin.  Under the intended
   corecursive reading, the equations produced for a cyclic coin would be
   solved by Gaussian; plain OCaml just unfolds the recursion. *)
let (*co*)rec(*[Gaussian]*) pr_heads coin =
  match coin with
  | H -> Val 1.
  | T -> Val 0.
  | Flip (bias, left, right) ->
    let left_term = Mul (Val bias, pr_heads left) in
    let right_term = Mul (Val (1. -. bias), pr_heads right) in
    Plus (left_term, right_term)

let r2 = pr_heads H

(* r3 is thunked, but raises a stack overflow on evaluation: the coin is
   cyclic, so the plain recursive [pr_heads] never reaches a leaf on every
   branch. *)
let r3 () =
  let rec x = Flip (0.5, H, Flip (0.5, T, x)) in
  pr_heads x




      
