(* Brittany Nkounkou *)
(* August 2020 *)
(* States *)

Require Export Basics.
Require FMaps FSets.

Set Implicit Arguments.

(* build a DecidableType.DecidableType from a DecidableType *)
(* DTDT is an adapter functor: it repackages a project-level DecidableType
   module DT (its signature is not declared in this file; presumably it comes
   from Basics) so that it satisfies the standard library signature
   DecidableType.DecidableType. The functors map' and set' below use it to
   feed their argument to FMapWeakList.Make and FSetWeakList.Make. *)
Module DTDT (Import DT : DecidableType) <: DecidableType.DecidableType.

(* type: the carrier is the t brought into scope by Import DT *)
Definition t : Type :=
  t.
(* equality: the eq that is in scope at this point. Since the three proofs
   below are eq_refl, eq_sym and eq_trans specialised to t, this is presumably
   Coq's Leibniz equality -- confirm that DT does not declare its own eq *)
Definition eq : t -> t -> Prop :=
  eq.

(* equality reflexivity *)
Definition eq_refl : forall x, eq x x :=
  @eq_refl t.
(* equality symmetry *)
Definition eq_sym : forall x y, eq x y -> eq y x :=
  @eq_sym t.
(* equality transitivity *)
Definition eq_trans : forall x y z, eq x y -> eq y z -> eq x z :=
  @eq_trans t.

(* decidable equality: provided by dec, which comes from the imported DT *)
Definition eq_dec : forall x y, { eq x y } + { ~ eq x y } :=
  dec.

End DTDT.

(* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ *)

(* map with added lemmas *)
(* map' is a functor over a project-level DecidableType K. It instantiates a
   finite map keyed by K and proves a few commutation and lookup lemmas about
   add, remove and find that the state module relies on. *)
Module map' (K : DecidableType).
(* K repackaged for the standard library signature *)
Module K' := DTDT K.
(* the underlying map; imported so that add, remove, find, Equal, ... can be
   used unqualified in the lemmas below *)
Module Import map := FMapWeakList.Make K'.
(* library facts about map, e.g. add_mapsto_iff and remove_mapsto_iff *)
Module Import facts := FMapFacts.Facts map.

(* for any unequal keys, add order does not matter *)
(* The conclusion is stated with the map library's Equal relation, not with
   Leibniz equality of the two maps.
   Proof outline: Equal_mapsto_iff turns the goal into an equivalence of
   MapsTo facts; add_mapsto_iff unfolds the outer add on each side; intuition
   splits the cases, which are then closed one by one.
   The script uses tactic-generated hypothesis names (H is k1 <> k2, H1 comes
   from intuition), so it is sensitive to any reordering of tactics. *)
Lemma add_add V k1 k2 (v1 v2 : V) m :
  k1 <> k2 -> Equal (add k1 v1 (add k2 v2 m)) (add k2 v2 (add k1 v1 m)).
Proof.
  intro. apply Equal_mapsto_iff. etransitivity. apply add_mapsto_iff. symmetry.
  etransitivity. apply add_mapsto_iff. intuition.
    right. split. intro. apply H. etransitivity; eauto.
      apply add_mapsto_iff. auto.
    edestruct add_mapsto_iff. destruct H1. eauto. auto.
      right. destruct H1. split. auto. apply add_mapsto_iff. right. auto.
    right. split. intro. apply H. etransitivity; eauto.
      apply add_mapsto_iff. auto.
    edestruct add_mapsto_iff. destruct H1. eauto. auto.
      right. destruct H1. split. auto. apply add_mapsto_iff. right. auto.
Qed.

(* for any two keys, remove order does not matter *)
(* No disequality hypothesis is needed here. The conclusion uses the map
   library's Equal relation.
   Proof outline: same shape as add_add, with remove_mapsto_iff used to unfold
   the outer remove on each side before intuition splits the cases. *)
Lemma remove_remove V k1 k2 (m : t V) :
  Equal (remove k1 (remove k2 m)) (remove k2 (remove k1 m)).
Proof.
  apply Equal_mapsto_iff. etransitivity. apply remove_mapsto_iff. symmetry.
  etransitivity. apply remove_mapsto_iff. intuition.
    eapply remove_mapsto_iff; eauto.
    apply remove_mapsto_iff. split. auto. eapply remove_mapsto_iff. eauto.
    eapply remove_mapsto_iff; eauto.
    apply remove_mapsto_iff. split. auto. eapply remove_mapsto_iff. eauto.
Qed.

(* for any unequal keys, add/remove order does not matter *)
(* Adding a binding for k1 and removing k2 commute when k1 <> k2; the
   conclusion uses the map library's Equal relation.
   Proof outline: the left side is unfolded with add_mapsto_iff and the right
   side with remove_mapsto_iff, then intuition splits the cases.
   H is the hypothesis k1 <> k2; H0 is a name generated by intuition. *)
Lemma add_remove V k1 k2 (v1 : V) m :
  k1 <> k2 -> Equal (add k1 v1 (remove k2 m)) (remove k2 (add k1 v1 m)).
Proof.
  intro. apply Equal_mapsto_iff. etransitivity. apply add_mapsto_iff. symmetry.
  etransitivity. apply remove_mapsto_iff. intuition.
    edestruct add_mapsto_iff. destruct H0. eauto. auto.
      right. destruct H0. split. auto. apply remove_mapsto_iff. auto.
    apply H. etransitivity; eauto.
    apply add_mapsto_iff. auto.
    eapply remove_mapsto_iff; eauto.
    apply add_mapsto_iff. right. split. auto. eapply remove_mapsto_iff. eauto.
Qed.

(* for any unequal keys k1 and k2, adding a binding for k1 does not change
   the result of looking up k2 *)
(* Unlike the commutation lemmas above, the conclusion is a Leibniz equality
   between two option values.
   Proof outline: case analysis on find k2 m. If it is Some, the binding is
   transported through add via find_mapsto_iff and add_mapsto_iff. If it is
   None, not_find_in_iff and add_neq_in_iff show k2 is still absent. *)
Lemma find_add V k1 k2 (v1 : V) m :
  k1 <> k2 -> find k2 (add k1 v1 m) = find k2 m.
Proof.
  intro. case_eq (find k2 m); intros.
    apply find_mapsto_iff. apply add_mapsto_iff. right. split. auto.
      apply find_mapsto_iff. auto.
    apply not_find_in_iff. intro. eapply not_find_in_iff. eauto.
      eapply add_neq_in_iff; eauto.
Qed.

(* for any unequal keys k1 and k2, removing k1 does not change the result of
   looking up k2 *)
(* The conclusion is a Leibniz equality between two option values.
   Proof outline: case analysis on find k2 m. If it is Some, the binding is
   transported through remove via find_mapsto_iff and remove_mapsto_iff. If it
   is None, not_find_in_iff and remove_neq_in_iff show k2 is still absent. *)
Lemma find_remove V k1 k2 (m : t V) :
  k1 <> k2 -> find k2 (remove k1 m) = find k2 m.
Proof.
  intro. case_eq (find k2 m); intros.
    apply find_mapsto_iff. apply remove_mapsto_iff. split. auto.
      apply find_mapsto_iff. auto.
    apply not_find_in_iff. intro. eapply not_find_in_iff. eauto.
      eapply remove_neq_in_iff; eauto.
Qed.

End map'.

(* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ *)

(* set with added lemmas *)
(* set' is a functor over a project-level DecidableType T. It instantiates a
   finite set of T elements and proves commutation lemmas about add and remove
   that the state module relies on. *)
Module set' (T : DecidableType).
(* T repackaged for the standard library signature *)
Module T' := DTDT T.
(* the underlying set; imported so that add, remove, Equal, ... can be used
   unqualified in the lemmas below *)
Module Import set := FSetWeakList.Make T'.
(* library facts about set, e.g. add_iff and remove_iff *)
Module Import facts := FSetFacts.Facts set.

(* for any two elements, add order does not matter *)
(* No disequality hypothesis is needed. The conclusion uses the set library's
   Equal relation.
   Proof outline: the leading intro introduces the element quantified by
   Equal; add_iff unfolds the outer add on each side; intuition splits the
   cases. H is a name generated by intuition. *)
Lemma add_add e1 e2 s :
  Equal (add e1 (add e2 s)) (add e2 (add e1 s)).
Proof.
  intro. etransitivity. apply add_iff. symmetry. etransitivity. apply add_iff.
  intuition.
    right. apply add_iff. auto.
    edestruct add_iff. destruct H. eauto. auto. right. apply add_iff. auto.
    right. apply add_iff. auto.
    edestruct add_iff. destruct H. eauto. auto. right. apply add_iff. auto.
Qed.

(* for any two elements, remove order does not matter *)
(* No disequality hypothesis is needed. The conclusion uses the set library's
   Equal relation.
   Proof outline: the leading intro introduces the element quantified by
   Equal; remove_iff unfolds the outer remove on each side; intuition splits
   the cases. *)
Lemma remove_remove e1 e2 s :
  Equal (remove e1 (remove e2 s)) (remove e2 (remove e1 s)).
Proof.
  intro. etransitivity. apply remove_iff. symmetry. etransitivity.
  apply remove_iff. intuition.
    apply remove_iff. split; auto. eapply remove_iff. eauto.
    eapply remove_iff; eauto.
    apply remove_iff. split; auto. eapply remove_iff. eauto.
    eapply remove_iff; eauto.
Qed.

(* for any unequal elements, add/remove order does not matter *)
(* Adding e1 and removing e2 commute when e1 <> e2; the conclusion uses the
   set library's Equal relation.
   Proof outline: do 2 intro introduces the hypothesis e1 <> e2 (named H) and
   the element quantified by Equal; the left side is unfolded with add_iff and
   the right side with remove_iff; intuition splits the cases.
   H0 is a name generated by intuition. *)
Lemma add_remove e1 e2 s :
  e1 <> e2 -> Equal (add e1 (remove e2 s)) (remove e2 (add e1 s)).
Proof.
  do 2 intro. etransitivity. apply add_iff. symmetry. etransitivity.
  apply remove_iff. intuition.
    edestruct add_iff. destruct H0. eauto. auto. right. apply remove_iff. auto.
    apply add_iff. auto.
    apply H. etransitivity; eauto.
    apply add_iff. right. eapply remove_iff. eauto.
    eapply remove_iff; eauto.
Qed.

End set'.

(* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ *)

(* MStates: functor over an Environment that defines program states.
   Environment and MBasics are not declared in this file; presumably they come
   from Basics. The names dvar, chan, val, guard, expr_eval, expr_In and
   expr_eval_eq used below are presumably provided by M -- confirm there. *)
Module MStates (env : Environment).
Module Export M := MBasics env.

(* map with data variables as keys *)
(* dvarmap' carries the extra lemmas; dvarmap is the map implementation *)
Module dvarmap' := map' dvar.
Module dvarmap := dvarmap'.map.

(* map with channels as keys *)
(* chanmap' carries the extra lemmas; chanmap is the map implementation *)
Module chanmap' := map' chan.
Module chanmap := chanmap'.map.

(* set of channels *)
(* chanset' carries the extra lemmas; chanset is the set implementation *)
Module chanset' := set' chan.
Module chanset := chanset'.set.

(* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ *)

(* state: the state record, its equality, setters, getters, evaluation of
   expressions and guards, and lemmas relating them *)
Module state.

(* state *)
(* Four components, each written by the matching set_* function and read by
   the matching get_* function below. *)
Record t : Type :=
  make {
    (* values of data variables; a data variable with no binding has no value *)
    dvars : dvarmap.t val.t;
    (* values keyed by channel; written A# in the comments of this file
       (presumably the data probe of channel A) *)
    dprbs : chanmap.t val.t;
    (* channels whose sprb flag is true; written A- in the comments of this
       file (presumably the send probe of channel A) *)
    sprbs : chanset.t;
    (* channels whose rprb flag is true; written A^ in the comments of this
       file (presumably the receive probe of channel A) *)
    rprbs : chanset.t;
  }.

(* state equality *)
(* Two states are eq when each of the four components is related by the Equal
   relation of its map or set library. This is component-wise Equal, not
   Leibniz equality of the two records. *)
Record eq s s' : Prop :=
  makeq {
    (* the data variable maps are Equal *)
    eq_dvars : dvarmap.Equal (dvars s) (dvars s');
    (* the dprb maps are Equal *)
    eq_dprbs : chanmap.Equal (dprbs s) (dprbs s');
    (* the sprb sets are Equal *)
    eq_sprbs : chanset.Equal (sprbs s) (sprbs s');
    (* the rprb sets are Equal *)
    eq_rprbs : chanset.Equal (rprbs s) (rprbs s');
  }.

(* the zero state: both maps and both sets are empty, so no data variable and
   no dprb has a value and no sprb or rprb is set *)
Definition zero : t :=
  {|
    dvars := dvarmap.empty val.t;
    dprbs := chanmap.empty val.t;
    sprbs := chanset.empty;
    rprbs := chanset.empty
  |}.

(* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ *)

(* set dvar option value *)
(* set_dvar sig x ok is sig with only the binding of data variable x changed:
   Some v binds x to v, None removes any binding of x. The dprbs, sprbs and
   rprbs components are copied unchanged. *)
Definition set_dvar (sig : t) x (ok : option val.t) : t :=
  {|
    dvars :=
      match ok with
      | Some v => dvarmap.add x v sig.(dvars)
      | None => dvarmap.remove x sig.(dvars)
      end;
    dprbs := sig.(dprbs);
    sprbs := sig.(sprbs);
    rprbs := sig.(rprbs)
  |}.

(* set dprb option value *)
(* set_dprb sig A ok is sig with only the dprb entry of channel A changed:
   Some v binds A to v, None removes any binding of A. The dvars, sprbs and
   rprbs components are copied unchanged. *)
Definition set_dprb (sig : t) A (ok : option val.t) : t :=
  {|
    dvars := sig.(dvars);
    dprbs :=
      match ok with
      | Some v => chanmap.add A v sig.(dprbs)
      | None => chanmap.remove A sig.(dprbs)
      end;
    sprbs := sig.(sprbs);
    rprbs := sig.(rprbs)
  |}.

(* set sprb boolean *)
(* set_sprb sig A b is sig with only the sprb flag of channel A changed:
   true adds A to the sprbs set, false removes it. The dvars, dprbs and rprbs
   components are copied unchanged. *)
Definition set_sprb (sig : t) A (b : bool) : t :=
  {|
    dvars := sig.(dvars);
    dprbs := sig.(dprbs);
    sprbs :=
      if b then chanset.add A sig.(sprbs) else chanset.remove A sig.(sprbs);
    rprbs := sig.(rprbs)
  |}.

(* set rprb boolean *)
(* set_rprb sig A b is sig with only the rprb flag of channel A changed:
   true adds A to the rprbs set, false removes it. The dvars, dprbs and sprbs
   components are copied unchanged. *)
Definition set_rprb (sig : t) A (b : bool) : t :=
  {|
    dvars := sig.(dvars);
    dprbs := sig.(dprbs);
    sprbs := sig.(sprbs);
    rprbs :=
      if b then chanset.add A sig.(rprbs) else chanset.remove A sig.(rprbs)
  |}.

(* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ *)

(* get dvar option value *)
(* looks x up in the dvars map of sig; the result is whatever dvarmap.find
   returns for x *)
Definition get_dvar (sig : t) x : option val.t :=
  dvarmap.find x sig.(dvars).

(* get dprb option value *)
(* looks channel A up in the dprbs map of sig; the result is whatever
   chanmap.find returns for A *)
Definition get_dprb (sig : t) A : option val.t :=
  chanmap.find A sig.(dprbs).

(* get sprb boolean *)
(* membership test of channel A in the sprbs set of sig *)
Definition get_sprb (sig : t) A : bool :=
  chanset.mem A sig.(sprbs).

(* get rprb boolean *)
(* membership test of channel A in the rprbs set of sig *)
Definition get_rprb (sig : t) A : bool :=
  chanset.mem A sig.(rprbs).

(* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ *)

(* expression evaluation *)
(* Evaluates e with expr_eval, using get_dvar sig as the lookup function for
   data variables. expr_eval is defined outside this file; the option result
   presumably is None when evaluation fails -- confirm at its definition.
   Lemmas below unfold this definition and then apply expr_eval_eq, so the
   body is kept in exactly this shape. *)
Definition eval_expr sig e : option val.t :=
  expr_eval (get_dvar sig) e.

(* guard evaluation *)
(* Evaluates guard G in state sig by structural recursion on G. The result is
   Some b when every expression and sub-guard involved evaluates, and None
   otherwise. *)
Fixpoint eval_guard sig G : option bool :=
  match G with
  (* boolean constant *)
  | guard.bool b => Some b
  (* probe flags are read directly from the state and always evaluate *)
  | guard.sprb A => Some (get_sprb sig A)
  | guard.rprb A => Some (get_rprb sig A)
  (* compare the dprb value of channel A with the value of e *)
  | guard.dprb A e =>
    match get_dprb sig A, eval_expr sig e with
    | Some k1, Some k2 => Some (if val.dec k1 k2 then true else false)
    (* A has no dprb value but e evaluates: the comparison is false *)
    | None, Some _ => Some false
    (* e does not evaluate: the guard does not evaluate either *)
    | _, None => None
    end
  (* compare the values of two expressions; both must evaluate *)
  | guard.expr e1 e2 =>
    match eval_expr sig e1, eval_expr sig e2 with
    | Some k1, Some k2 => Some (if val.dec k1 k2 then true else false)
    | _, _ => None
    end
  (* negation of a sub-guard *)
  | guard.neg G =>
    match eval_guard sig G with
    | Some b => Some (negb b)
    | _ => None
    end
  (* conjunction: both sub-guards must evaluate, even if one is false *)
  | guard.and G1 G2 =>
    match eval_guard sig G1, eval_guard sig G2 with
    | Some b1, Some b2 => Some (andb b1 b2)
    | _, _ => None
    end
  (* disjunction: both sub-guards must evaluate, even if one is true *)
  | guard.or G1 G2 =>
    match eval_guard sig G1, eval_guard sig G2 with
    | Some b1, Some b2 => Some (orb b1 b2)
    | _, _ => None
    end
  end.

(* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ *)

(* equality reflexivity *)
(* every state is eq to itself: split produces the four component goals and
   each closes by reflexivity of the corresponding Equal relation *)
Lemma eq_refl s : eq s s.
Proof.
  split; reflexivity.
Qed.

(* equality symmetry *)
(* state eq is symmetric: the hypothesis is destructed into its four component
   equalities and each goal is closed by symmetry of the corresponding Equal
   relation *)
Lemma eq_sym s s' : eq s s' -> eq s' s.
Proof.
  intro. destruct H. split; symmetry; auto.
Qed.

(* equality transitivity *)
(* state eq is transitive: both hypotheses are destructed into component
   equalities and each goal is closed by transitivity of the corresponding
   Equal relation, with eauto supplying the middle component *)
Lemma eq_trans s s' s'': eq s s' -> eq s' s'' -> eq s s''.
Proof.
  intros. destruct H, H0. split; etransitivity; eauto.
Qed.

(* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ *)

(* equal states produce equal set_dvar states *)
(* set_dvar respects state.eq in its state argument. After destruct H and
   split, the three untouched components are closed by auto; the dvars goal is
   handled per case of ok by rewriting with the dvars equality, which destruct
   names eq_dvars0. The script depends on that generated name. *)
Lemma eq_set_dvar_eq sig sig' x ok :
  state.eq sig sig' -> state.eq (set_dvar sig x ok) (set_dvar sig' x ok).
Proof.
  intro. destruct H. split; auto. simpl.
  destruct ok; rewrite eq_dvars0; reflexivity.
Qed.

(* equal states produce equal set_dprb states *)
(* set_dprb respects state.eq in its state argument. After destruct H and
   split, the three untouched components are closed by auto; the dprbs goal is
   handled per case of ok by rewriting with the dprbs equality, which destruct
   names eq_dprbs0. The script depends on that generated name. *)
Lemma eq_set_dprb_eq sig sig' A ok :
  state.eq sig sig' -> state.eq (set_dprb sig A ok) (set_dprb sig' A ok).
Proof.
  intro. destruct H. split; auto. simpl.
  destruct ok; rewrite eq_dprbs0; reflexivity.
Qed.

(* equal states produce equal set_sprb states *)
(* set_sprb respects state.eq in its state argument. After destruct H and
   split, the three untouched components are closed by auto; the sprbs goal is
   handled per case of b by rewriting with the sprbs equality, which destruct
   names eq_sprbs0. The script depends on that generated name. *)
Lemma eq_set_sprb_eq sig sig' A b :
  state.eq sig sig' -> state.eq (set_sprb sig A b) (set_sprb sig' A b).
Proof.
  intro. destruct H. split; auto. simpl.
  destruct b; rewrite eq_sprbs0; reflexivity.
Qed.

(* equal states produce equal set_rprb states *)
(* set_rprb respects state.eq in its state argument. After destruct H and
   split, the three untouched components are closed by auto; the rprbs goal is
   handled per case of b by rewriting with the rprbs equality, which destruct
   names eq_rprbs0. The script depends on that generated name. *)
Lemma eq_set_rprb_eq sig sig' A b :
  state.eq sig sig' -> state.eq (set_rprb sig A b) (set_rprb sig' A b).
Proof.
  intro. destruct H. split; auto. simpl.
  destruct b; rewrite eq_rprbs0; reflexivity.
Qed.

(* equal states produce equal get_dprbs *)
(* get_dprb respects state.eq: eq states give Leibniz-equal lookup results for
   every channel A. The proof unfolds get_dprb and rewrites with the dprbs
   equality, which destruct names eq_dprbs0. *)
Lemma eq_get_dprb_eq sig sig' A :
  state.eq sig sig' -> get_dprb sig A = get_dprb sig' A.
Proof.
  intro. destruct H. unfold get_dprb. rewrite eq_dprbs0. auto.
Qed.

(* equal states produce equal eval_exprs *)
(* eval_expr respects state.eq: eq states evaluate every expression to the
   same result. The proof reduces the goal with expr_eval_eq (defined outside
   this file) to agreement of get_dvar on the two states, which follows by
   rewriting with the dvars equality named eq_dvars0 by destruct. *)
Lemma eq_eval_expr_eq sig sig' e :
  state.eq sig sig' -> eval_expr sig e = eval_expr sig' e.
Proof.
  intro. destruct H. unfold eval_expr. apply expr_eval_eq. unfold get_dvar.
  intro. rewrite eq_dvars0. auto.
Qed.

(* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ *)

(* for any unequal data variables, set_dvar order does not matter *)
(* The two resulting states are related by state eq. After split, the three
   untouched components close by reflexivity and only the dvars goal remains.
   It is solved by case analysis on ok and ok', one line per case, using the
   commutation lemmas of dvarmap'. *)
Lemma set_dvar_set_dvar sig x x' ok ok' :
  x <> x' -> eq (set_dvar (set_dvar sig x ok) x' ok')
    (set_dvar (set_dvar sig x' ok') x ok).
Proof.
  intro. split; simpl; try intro; try reflexivity. destruct ok, ok'.
    (* ok = Some, ok' = Some *)
    symmetry. apply dvarmap'.add_add. auto.
    (* ok = Some, ok' = None *)
    symmetry. apply dvarmap'.add_remove. auto.
    (* ok = None, ok' = Some *)
    apply dvarmap'.add_remove. auto.
    (* ok = None, ok' = None *)
    apply dvarmap'.remove_remove.
Qed.

(* for any data variable and channel, set_dvar/set_dprb order does not matter *)
(* The two setters write different components, so after split and simpl every
   component goal closes by reflexivity. *)
Lemma set_dvar_set_dprb sig A x ok ok' :
  eq (set_dvar (set_dprb sig A ok) x ok')
    (set_dprb (set_dvar sig x ok') A ok).
Proof.
  split; simpl; try intro; try reflexivity.
Qed.

(* for any data variable and channel, set_dvar/set_sprb order does not matter *)
(* The two setters write different components, so after split and simpl every
   component goal closes by reflexivity. *)
Lemma set_dvar_set_sprb sig A x b ok :
  eq (set_dvar (set_sprb sig A b) x ok)
    (set_sprb (set_dvar sig x ok) A b).
Proof.
  split; simpl; try intro; try reflexivity.
Qed.

(* for any data variable and channel, set_dvar/set_rprb order does not matter *)
(* The two setters write different components, so after split and simpl every
   component goal closes by reflexivity. *)
Lemma set_dvar_set_rprb sig A x b ok :
  eq (set_dvar (set_rprb sig A b) x ok)
    (set_rprb (set_dvar sig x ok) A b).
Proof.
  split; simpl; try intro; try reflexivity.
Qed.

(* for any unequal channels, set_dprb order does not matter *)
(* The two resulting states are related by state eq. After split, the three
   untouched components close by reflexivity and only the dprbs goal remains.
   It is solved by case analysis on ok and ok', one line per case, using the
   commutation lemmas of chanmap'. *)
Lemma set_dprb_set_dprb sig A A' ok ok' :
  A <> A' -> eq (set_dprb (set_dprb sig A ok) A' ok')
    (set_dprb (set_dprb sig A' ok') A ok).
Proof.
  intro. split; simpl; try intro; try reflexivity. destruct ok, ok'.
    (* ok = Some, ok' = Some *)
    symmetry. apply chanmap'.add_add. auto.
    (* ok = Some, ok' = None *)
    symmetry. apply chanmap'.add_remove. auto.
    (* ok = None, ok' = Some *)
    apply chanmap'.add_remove. auto.
    (* ok = None, ok' = None *)
    apply chanmap'.remove_remove.
Qed.

(* for any two channels, set_sprb/set_dprb order does not matter *)
(* The two setters write different components, so no disequality of channels
   is needed: after split and simpl every component goal closes by
   reflexivity. *)
Lemma set_sprb_set_dprb sig A A' ok b :
  eq (set_sprb (set_dprb sig A ok) A' b)
    (set_dprb (set_sprb sig A' b) A ok).
Proof.
  split; simpl; try intro; try reflexivity.
Qed.

(* for any unequal channels, set_sprb order does not matter *)
(* The two resulting states are related by state eq. After split, the three
   untouched components close by reflexivity and only the sprbs goal remains.
   It is solved by case analysis on b and b', one line per case, using the
   commutation lemmas of chanset'. Only the two mixed cases use A <> A'. *)
Lemma set_sprb_set_sprb sig A A' b b' :
  A <> A' ->
  eq (set_sprb (set_sprb sig A b) A' b') (set_sprb (set_sprb sig A' b') A b).
Proof.
  intro. split; simpl; try intro; try reflexivity. destruct b, b'.
    (* b = true, b' = true *)
    apply chanset'.add_add.
    (* b = true, b' = false *)
    symmetry. apply chanset'.add_remove. auto.
    (* b = false, b' = true: needs A' <> A, obtained from H *)
    apply chanset'.add_remove. intro. apply H. auto.
    (* b = false, b' = false *)
    apply chanset'.remove_remove.
Qed.

(* for any two channels, set_rprb/set_dprb order does not matter *)
(* The two setters write different components, so no disequality of channels
   is needed: after split and simpl every component goal closes by
   reflexivity. *)
Lemma set_rprb_set_dprb sig A A' ok b :
  eq (set_rprb (set_dprb sig A ok) A' b)
    (set_dprb (set_rprb sig A' b) A ok).
Proof.
  split; simpl; try intro; try reflexivity.
Qed.

(* for any two channels, set_rprb/set_sprb order does not matter *)
(* The two setters write different components, so no disequality of channels
   is needed: after split and simpl every component goal closes by
   reflexivity. *)
Lemma set_rprb_set_sprb sig A A' b b' :
  eq (set_rprb (set_sprb sig A b) A' b')
    (set_sprb (set_rprb sig A' b') A b).
Proof.
  split; simpl; try intro; try reflexivity.
Qed.

(* for any unequal channels, set_rprb order does not matter *)
(* The two resulting states are related by state eq. After split, the three
   untouched components close by reflexivity and only the rprbs goal remains.
   It is solved by case analysis on b and b', one line per case, using the
   commutation lemmas of chanset'. Only the two mixed cases use A <> A'. *)
Lemma set_rprb_set_rprb sig A A' b b' :
  A <> A' ->
  eq (set_rprb (set_rprb sig A b) A' b') (set_rprb (set_rprb sig A' b') A b).
Proof.
  intro. split; simpl; try intro; try reflexivity. destruct b, b'.
    (* b = true, b' = true *)
    apply chanset'.add_add.
    (* b = true, b' = false *)
    symmetry. apply chanset'.add_remove. auto.
    (* b = false, b' = true: needs A' <> A, obtained from H *)
    apply chanset'.add_remove. intro. apply H. auto.
    (* b = false, b' = false *)
    apply chanset'.remove_remove.
Qed.

(* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ *)

(* for any unequal data variables x and x', setting x does not change x' *)
(* The conclusion is a Leibniz equality of option values. After unfolding, the
   case ok = Some is closed by dvarmap'.find_add and the case ok = None by
   dvarmap'.find_remove; the hypothesis x <> x' is left in the goal and
   consumed by those lemmas. *)
Lemma get_dvar_set_dvar sig x x' ok :
  x <> x' -> get_dvar (set_dvar sig x ok) x' = get_dvar sig x'.
Proof.
  unfold get_dvar, set_dvar; simpl. destruct ok.
  apply dvarmap'.find_add. apply dvarmap'.find_remove.
Qed.

(* for any channel A and data variable x, setting x does not change A# *)
(* set_dvar copies the dprbs component unchanged, so auto closes the goal
   without any lemma. *)
Lemma get_dprb_set_dvar sig x A ok :
  get_dprb (set_dvar sig x ok) A = get_dprb sig A.
Proof.
  auto.
Qed.

(* for any unequal channels A and A', setting A# does not change A'# *)
(* The conclusion is a Leibniz equality of option values. After unfolding, the
   case ok = Some is closed by chanmap'.find_add and the case ok = None by
   chanmap'.find_remove; the hypothesis A <> A' is left in the goal and
   consumed by those lemmas. *)
Lemma get_dprb_set_dprb sig A A' ok :
  A <> A' -> get_dprb (set_dprb sig A ok) A' = get_dprb sig A'.
Proof.
  unfold get_dprb, set_dprb; simpl. destruct ok.
  apply chanmap'.find_add. apply chanmap'.find_remove.
Qed.

(* for any two channels A and A', setting A- does not change A'# *)
(* set_sprb copies the dprbs component unchanged, so auto closes the goal
   without any lemma; A and A' may be equal. *)
Lemma get_dprb_set_sprb sig A A' b :
  get_dprb (set_sprb sig A b) A' = get_dprb sig A'.
Proof.
  auto.
Qed.

(* for any two channels A and A', setting A^ does not change A'# *)
(* set_rprb copies the dprbs component unchanged, so auto closes the goal
   without any lemma; A and A' may be equal. *)
Lemma get_dprb_set_rprb sig A A' b :
  get_dprb (set_rprb sig A b) A' = get_dprb sig A'.
Proof.
  auto.
Qed.

(* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ *)

(* if data variable x is not in expression e, sig[x |-> ok](e) = sig(e) *)
(* The proof reduces the goal with expr_eval_eq (defined outside this file) to
   agreement of get_dvar on the two states, and closes that with
   get_dvar_set_dvar. The side condition of get_dvar_set_dvar is discharged
   from the hypothesis ~ expr_In x e, named H. H1 is a tactic-generated name,
   so the script is sensitive to the exact statement of expr_eval_eq. *)
Lemma eval_expr_set_dvar sig x ok e :
  ~ expr_In x e -> eval_expr (set_dvar sig x ok) e = eval_expr sig e.
Proof.
  unfold eval_expr. intro. apply expr_eval_eq. intros. apply get_dvar_set_dvar.
  intro. apply H. rewrite H1. auto.
Qed.

(* for any channel A and expression e, sig[A# |-> ok](e) = sig(e) *)
(* eval_expr reads the state only through get_dvar, and set_dprb copies the
   dvars component unchanged, so auto closes the goal without any lemma. *)
Lemma eval_expr_set_dprb sig A ok e :
  eval_expr (set_dprb sig A ok) e = eval_expr sig e.
Proof.
  auto.
Qed.

(* for any channel A and expression e, sig[A- |-> b](e) = sig(e) *)
(* eval_expr reads the state only through get_dvar, and set_sprb copies the
   dvars component unchanged, so auto closes the goal without any lemma. *)
Lemma eval_expr_set_sprb sig A b e :
  eval_expr (set_sprb sig A b) e = eval_expr sig e.
Proof.
  auto.
Qed.

(* for any channel A and expression e, sig[A^ |-> b](e) = sig(e) *)
(* eval_expr reads the state only through get_dvar, and set_rprb copies the
   dvars component unchanged, so auto closes the goal without any lemma. *)
Lemma eval_expr_set_rprb sig A b e :
  eval_expr (set_rprb sig A b) e = eval_expr sig e.
Proof.
  auto.
Qed.

End state.

End MStates.

(* (c) 2020 Brittany Ro Nkounkou *)
