(********************************)
(* Code common to all solutions *)
(********************************)
      
(********************************************)
(* From here: copied from Corec/gaussian.ml *)
(* (and slightly adapted)                   *)
(********************************************)
let rec sum_aux c1 eq1 c2 eq2 =
  (* Merges two coefficient lists, both sorted by variable, into the
     coefficient list of c2 * eq2 - c1 * eq1 (the constants are combined
     by [sum], not here).  A variable present in both lists whose
     combined coefficient is exactly 0.0 is left out of the result. *)
  match eq1 with
    | [] ->
      (match eq2 with
        | [] -> []
        | (v2, f2) :: t2 -> (v2, c2 *. f2) :: sum_aux c1 [] c2 t2)
    | (v1, f1) :: t1 ->
      (match eq2 with
        | [] -> (v1, -. c1 *. f1) :: sum_aux c1 t1 c2 []
        | (v2, f2) :: t2 ->
          if v1 = v2 then
            (let combined = c2 *. f2 -. c1 *. f1 in
             let rest = sum_aux c1 t1 c2 t2 in
             if combined = 0.0 then rest else (v1, combined) :: rest)
          else if v1 < v2 then (v1, -. c1 *. f1) :: sum_aux c1 t1 c2 eq2
          else (v2, c2 *. f2) :: sum_aux c1 eq1 c2 t2)
	
let sum c1 eq1 c2 eq2 =
  (* Returns the equation c2 * eq2 - c1 * eq1.  Each equation is a pair
     (coefficient list sorted by variable, constant); the coefficient
     lists are merged by sum_aux and the constants combined directly. *)
  match eq1 with (terms1, k1) ->
    match eq2 with (terms2, k2) ->
      (sum_aux c1 terms1 c2 terms2, c2 *. k2 -. c1 *. k1)

let eliminate v f eq1f1 eq2f2 =
  (* Removes variable v from the row eq2f2 = (eq2, f2) using the pivot row
     ((v, f) :: eq1, f1); eq1f1 is that pivot row without its leading
     (v, f) entry.  A row that does not mention v is returned unchanged.
     Otherwise the result is c2 * pivot - f * row, where c2 is the
     coefficient of v in the row. *)
  match eq1f1, eq2f2 with (eq1, f1), (eq2, f2) ->
    match assoc_option v eq2 with
      | None -> (eq2, f2)
      | Some(c2) ->
        match sum f (eq2, f2) c2 ((v, f) :: eq1, f1) with (combined, k) ->
          (* v should cancel exactly; if floating-point rounding left a
             residual entry for v, drop it anyway *)
          match assoc_option v combined with
            | None -> (combined, k)
            | Some epsilon -> remove_assoc v combined, k

let rec solve_aux (x:'var) (eqs:'a list) (isol:'a list) =
  (* Forward elimination.  The head row's leading variable v (with
     coefficient f) is eliminated from every remaining row, and the head
     row is pushed onto isol.  Because rows are pushed, isol ends with the
     last pivot row first, which is the order compute_solutions consumes.
     Fails when a row loses all its variables: constant 0.0 means a
     redundant row, any other constant an inconsistent one. *)
  match eqs with
    | [] -> isol
    | ([], 0.0) :: ig -> failwith "Infinite number of solutions" (* check *)
    | ([], ig) :: tl -> failwith "No solution"
    | ((v, f) :: t, i) :: tl ->
      if f = 0.0 then failwith "This should not be zero"
      else
        let reduced = map (eliminate v f (t, i)) tl in
        solve_aux x reduced (((v, f) :: t, i) :: isol)

let rec compute_solutions (sols : ('var * float) list)
    (isol : 'eq list) =
  (* Back-substitution.  Each row ((v, f) :: rest, c) stands for
     f * v + sum(a * w for (w, a) in rest) + c = 0, where every w in rest
     already has a value in sols; so v = (-1 / f) * (c + sum ...).
     The new binding is prepended to sols, and the final list is returned. *)
  match isol with
    | [] -> sols
    | ((v, f) :: eq, c) :: tl ->
      let add_term acc term = acc +. (snd term) *. (assoc (fst term) sols) in
      let known_part = fold_left add_term c eq in
      let val_v = (-. 1. /. f) *. known_part in
      compute_solutions ((v, val_v) :: sols) tl
    | ig -> failwith "Impossible"

(* Normalizes a coefficient list sorted by variable.  Adjacent entries for
   the same variable are merged by adding their coefficients, and every
   entry whose (merged) coefficient is exactly 0.0 is dropped.
   Zeroes must go because solve_aux rejects a row whose leading
   coefficient is 0.0, and a zero can come straight from the input: for
   example, DMul(DVal 0., Unk x) converts to [(x, 0.)]. *)
let rec eliminate_doubles (* in a sorted list of pairs *) a = match a with
  | [] -> []
  | [ p ] -> if snd p = 0.0 then [] else [ p ]
  | (v1, f1) :: (v2, f2) :: t ->
    if v1 = v2 then
      (* merge, then re-examine: the next entry may be the same variable,
         and a zero total is dropped by the other branches *)
      eliminate_doubles ((v1, f1 +. f2) :: t)
    else if f1 = 0.0 then eliminate_doubles ((v2, f2) :: t)
    else (v1, f1) :: eliminate_doubles ((v2, f2) :: t)

let udgaussian convert_eq_aux (x:'var) (lst:('var * 'expr) list) =
  (* Solves the linear system { i = e | (i, e) in lst } by Gaussian
     elimination and returns the value of unknown x.
     'var is symbol in JB's solution and string in Dexter's.
     convert_eq_aux maps an expression e to (coefficients, constant). *)
  (* i = e becomes the row  -1 * i + coefficients(e) + constant(e) = 0 *)
  let to_row binding = match binding with (i, e) ->
    match convert_eq_aux e with coeffs, k -> (i, -.1.) :: coeffs, k in
  (* ascending order on the variable of a (variable, coefficient) pair *)
  let by_var p q = match p, q with (v1, ig1), (v2, ig2) ->
    if v1 > v2 then 1 else if v1 = v2 then 0 else -1 in
  (* sort each row by variable, then let eliminate_doubles merge repeated
     variables (and drop the zero coefficients it removes) *)
  let normalize row = match row with (coeffs, k) ->
    eliminate_doubles (sort by_var coeffs), k in
  let rows = map normalize (map to_row lst) in
  let triangular = solve_aux x rows [] in
  assoc x (compute_solutions [] triangular)
(********************************************)
(* To here **********************************)
(********************************************)

(* Also look at probability.ml for use of provided gaussian solver *)

(* A finite or cyclic tree of coin flips.  Flip(p, a, b) continues with a
   with probability p and with b with probability 1 - p (the reading used
   by probability and flips below). *)
type tree =
    Heads
  | Tails
  | Flip of float * tree * tree

(* heads w.p. 0.4; otherwise tails w.p. 0.3, else start over *)
let rec coin0 =
  Flip(0.4, Heads,
       Flip(0.3, Tails, coin0))

(* heads w.p. 0.5; otherwise tails w.p. 0.5, else start over *)
let rec coin1 =
  Flip(0.5, Heads,
       Flip(0.5, Tails, coin1))

(***************************************)
(* Using Dexter's user defined solvers *)
(***************************************)

(* Expressions over unknowns of type 'a, built by the corec functions
   below and turned into linear equations by dconvert_eq_aux. *)
type 'a dexpr =
    DVal of float                 (* constant *)
  | DPlus of 'a dexpr * 'a dexpr  (* sum *)
  | DMul of 'a dexpr * 'a dexpr   (* product *)
  | Unk of 'a                     (* unknown *)

(* Main difference with JB's solution: the Unk constructor in 'a dexpr above,
   and the five definitions t, var, unk, fresh, eq (in this section only
   unk and fresh are defined; t and var are left as comments) *)
(* type t = float *)
(* type var = string *)
(* Wraps a variable name as an unknown dexpr. *)
let unk (s:string) : string dexpr = Unk s
(* Produces a new variable name on each call by bumping the counter c
   captured in the closure and returning it as a string (presumably
   "1", "2", "3", ...).
   NOTE(review): in standard OCaml, [c := c+1] on [let c = 0] is a type
   error (it would need [let c = ref 0] and [!c]); this presumably relies
   on the CoCaml interpreter allowing assignment to a plain binding —
   confirm before porting. *)
let fresh : unit -> string = let c = 0 in 
			     (fun (x:unit) -> c := c+1; string_of_int c)

(* Multiplies every coefficient of a term list [(x, a); ...] by c. *)
let rec dscale c l = match l with
  | [] -> []
  | (x, a) :: t -> (x, c *. a) :: dscale c t

(* Converts a dexpr into linear form (coeffs, constant), meaning
   sum(a * x for (x, a) in coeffs) + constant.  A product is accepted
   whenever at least one factor reduces to a constant, e.g.
   DMul(DVal p, DVal q) or a constant times a DPlus; the common
   DMul(Unk x, DVal i) / DMul(DVal i, Unk x) shapes are matched directly.
   The coefficient list is unsorted and may repeat variables; udgaussian
   sorts and merges it afterwards.
   Raises Failure "Expression not suitable for gaussian elimination" for a
   product of two non-constant factors. *)
let rec dconvert_eq_aux e = match e with
  | DPlus(e1, e2) -> (match dconvert_eq_aux e1, dconvert_eq_aux e2
    with (l1, f1), (l2, f2) -> (append l1 l2), f1 +. f2)
  | DMul(Unk(x), DVal i) -> [ x, i ], 0.
  | DMul(DVal i, Unk(x)) -> [ x, i ], 0.
  | DMul(e1, e2) ->
    (* (sum a*x + f) * c = sum (c*a)*x + c*f, linear as long as one
       factor has no unknowns *)
    (match dconvert_eq_aux e1, dconvert_eq_aux e2 with
      | ([], c), (l, f) -> dscale c l, c *. f
      | (l, f), ([], c) -> dscale c l, c *. f
      | ig -> failwith "Expression not suitable for gaussian elimination")
  | DVal f -> [ ], f
  | Unk x -> [(x,1.)], 0.

(* Gaussian solver for string-named unknowns: converts each defining dexpr
   with dconvert_eq_aux and returns the value of unknown x. *)
let dgaussian x lst = udgaussian dconvert_eq_aux x lst

(**** Probability of Heads ****)
(* [probability t]: probability of reaching Heads from t, where
   Flip(p, v, w) moves to v with probability p and to w otherwise
   (Heads counts 1, Tails 0).  On cyclic trees such as coin0 the recursion
   cannot simply unfold.  corec[unk, fresh, dgaussian] presumably names
   revisited subtrees with fresh, represents them with unk, and hands the
   resulting linear equations to dgaussian — confirm against the CoCaml
   corec documentation. *)
let corec[unk, fresh, dgaussian] probability t = match t with
  | Heads -> DVal 1.
  | Tails -> DVal 0.
  | Flip(p, v, w) -> DPlus(
      DMul(DVal p, probability v),
      DMul(DVal (1. -. p), probability w))

(* Expected values, worked out by hand from the definitions:
   r13 = 1, r14 = 0, r15 = 0.3, r16 = 0.7;
   r17: P = 0.4 + 0.6 * 0.7 * P, so P = 20/29 (about 0.690);
   r18: P = 0.5 + 0.5 * 0.5 * P, so P = 2/3.
   NOTE(review): these calls presumably share fresh's counter, so their
   order may affect the generated variable names; keep it unchanged. *)
let r13 = probability Heads
let r14 = probability Tails
let r15 = probability (Flip(0.3, Heads, Tails))
let r16 = probability (Flip(0.3, Tails, Heads))
let r17 = probability coin0
let r18 = probability coin1

(**** Expected number of flips ****)
(* [flips t]: expected number of Flip nodes passed through before reaching
   Heads or Tails.  Leaves cost 0; Flip(p, t1, t2) costs 1 plus the
   p-weighted average of its two branches.  Cyclic trees go through
   corec with dgaussian, as for probability. *)
let corec[unk, fresh, dgaussian] flips t = match t with
    Heads -> DVal 0. | Tails -> DVal 0.
  | Flip(p, t1, t2) -> DPlus(
    DVal 1., 
    DPlus(DMul(DVal p, flips t1),
	  DMul(DVal (1. -. p), flips t2)))

(* Expected values, worked out by hand from the definitions:
   r19 = 0, r20 = 0, r21 = 1, r22 = 1;
   r23: E = 1 + 0.6 * (1 + 0.7 * E), so E = 80/29 (about 2.759);
   r24: E = 1 + 0.5 * (1 + 0.5 * E), so E = 2. *)
let r19 = flips Heads
let r20 = flips Tails
let r21 = flips (Flip(0.3, Heads, Tails))
let r22 = flips (Flip(0.3, Tails, Heads))

let r23 = flips coin0
let r24 = flips coin1


