(**************************************************************)
(*
 *  Ensemble, (Version 0.40)
 *  Copyright 1997 Cornell University
 *  All rights reserved.
 *
 *  See ensemble/doc/license.txt for further information.
 *)
(**************************************************************)
(**************************************************************)
(* REKEY.ML : switching to a new key *)
(* Authors: Zhen Xiao, Mark Hayden, 4/97 *)
(**************************************************************)
open Util
open Layer
open View
open Event
open Trans
(**************************************************************)
(* Name of this layer.  Used as the tag for tracing/logging, in
 * failure messages, when creating events, and when registering the
 * layer with Layer.install at the bottom of this file.
 *)
let name = Trace.source_file "REKEY"
(**************************************************************)

(* Headers of the local messages exchanged by this layer.
 *  Ack         : cast by a member after it has accepted the new key
 *                from the coordinator's ticket.
 *  Rekey       : sent by a non-coordinator to the coordinator to ask
 *                it to start rekeying the group.
 *  Ticket(r,t) : sent by the coordinator (rank r) to each other
 *                member; t is an Auth ticket wrapping the new key.
 *)
type header =
  | Ack
  | Rekey
  | Ticket of rank * Auth.ticket

(**************************************************************)

(* Per-member state of the REKEY layer for the current view.
 *)
type state = {
  (* Key suggested for the next view.  Starts as the current view key;
   * replaced by the coordinator's generated key (or the key received
   * in a ticket) and installed at EView if every rank has acked. *)
  mutable key_sug : Security.key ;
  (* Set once the coordinator has started a rekey in this view. *)
  rekeying : Once.t ;
  (* Set on EBlock; after that no rekey is started and incoming
   * tickets are ignored. *)
  blocking : Once.t ;
  (* ack.(r) is true when rank r is known to hold the new key, or has
   * been reported failed. *)
  ack : bool array
} 

(**************************************************************)

(* Debug dump of the layer state on stderr: the layer instance name,
 * the per-rank ack vector, and the blocking/rekeying flags.
 *)
let dump (ls,_) {ack=acks; blocking=blk; rekeying=rk} =
  let acked = string_of_bool_array acks in
  eprintf "REKEY:%s\n" ls.name ;
  eprintf "  ack=%s\n" acked ;
  eprintf "  blocking=%s, rekeying=%s\n"
    (Once.to_string blk) (Once.to_string rk)

(**************************************************************)

(* Fresh layer state for a view: the suggested key is the view's
 * current key, nobody has acked yet, and neither the rekeying nor the
 * blocking flag is set.
 *)
let init () (ls,vs) =
  let acks = array_create name ls.nmembers false in
  { ack = acks ;
    blocking = Once.create "blocking" ;
    rekeying = Once.create name ;
    key_sug = vs.key }

(**************************************************************)

(* Event handlers for one instance of the REKEY layer.
 *
 * Protocol as implemented here:
 *  - On ERekey the coordinator starts a rekey (try_rekey); any other
 *    member forwards a [Rekey] request to the coordinator.
 *  - The coordinator generates a new key and sends it to every other
 *    member wrapped in an Auth ticket ([Ticket]); members for which no
 *    ticket can be created are suspected.
 *  - A member that can open the ticket records the key and casts [Ack].
 *  - Failed members count as acked.  Once every rank is acked, the
 *    coordinator prompts a view change; at EView, a member whose ack
 *    vector is complete installs the suggested key in the new view.
 *)
let hdlrs s (ls,vs) {up_out=up;upnm_out=upnm;dn_out=dn;dnlm_out=dnlm;dnnm_out=dnnm} =
  let failwith m = dump (ls,vs) s ; failwith (name^":"^m) in
  let ack = make_acker name dnnm in
  let log = Trace.log name ls.name in

  (* True when every rank holds the new key or has failed.
   *)
  let all_acked () =
    array_for_all ident s.ack
  in

  (* Coordinator only: once every rank is acked, prompt a view change
   * so the new key gets installed at the next EView.
   *)
  let prompt_if_all_acked () =
    if vs.coord = ls.rank && all_acked () then
      dnnm (create name EPrompt[])
  in

  (* Coordinator: generate a new key and send it to every other member
   * inside an Auth ticket.  No-op if a rekey was already started in
   * this view or if we are blocking.
   *)
  let try_rekey () =
    if not (Once.isset s.rekeying)
    && not (Once.isset s.blocking)
    then (
      log (fun () -> sprintf "coord rekeying the group") ;
      Once.set s.rekeying ;

      (* Generate a new shared key.
       *)
      s.key_sug <- Security.create () ;
      s.ack.(ls.rank) <- true ;
      let key = Security.str_of_key s.key_sug in

      (* Encrypt and then send the key to all other members.
       *)
      log (fun () -> sprintf "generating new keys:begin:%s" (Security.string_of_key s.key_sug)) ;
      for rank = 0 to pred ls.nmembers do
        if rank <> ls.rank then (
          let dst = vs.address.(rank) in
          match Auth.ticket ls.addr dst key with
          | Some t ->
              (* Should we send to each member or broadcast
               * to them all?
               *)
              dnlm (sendRank name rank) (Ticket(ls.rank,t))
          | None ->
              (* No ticket could be created for this member:
               * suspect it.
               *)
              log (fun () -> sprintf "coord didn't authenticate member") ;
              dnnm (suspectReason name [rank] "REKEY:can't create ticket")
        )
      done ;
      log (fun () -> sprintf "generating new keys:end") ;

      (* With nobody left to wait for (singleton group, or all other
       * members already failed) no Ack or EFail will come to trigger
       * the prompt, so check right away.
       *)
      prompt_if_all_acked ()
    )
  in

  let up_hdlr ev abv () = up ev abv

  and uplm_hdlr ev hdr = match getType ev,hdr with

    (* Got a request to start rekeying the group.
     *)
  | (ECast|ESend), Rekey ->
      if vs.coord = ls.rank then
        try_rekey () ;
      ack ev ; free name ev

    (* Got a ticket from the coordinator.  Try to decrypt it.  On
     * success, remember the key, mark the coordinator and ourselves as
     * acked, and cast [Ack] to the group.  Tickets are ignored once we
     * are blocking.
     *)
  | (ECast|ESend), Ticket(orig,t) ->
      if (not ls.am_coord)
      && not (Once.isset s.blocking)
      then (
        match Auth.check ls.addr t with
        | Some key ->
            s.key_sug <- Security.Common(key) ;
            s.ack.(orig) <- true ;
            s.ack.(ls.rank) <- true ;
            dnlm (castEv name) Ack
        | None ->
            (* We suspect them if we can't auth ticket.
             *)
            log (fun () -> sprintf "member didn't authenticate coord") ;
            dnnm (suspectReason name [orig] "REKEY:can't auth ticket")
      ) ;
      ack ev ; free name ev

  | (ECast|ESend), Ack ->
      (* A member reports that it got the key.  Record it; if all
       * members now have the key, prompt for a view change.
       *)
      let origin = getOrigin ev in
      s.ack.(origin) <- true ;
      prompt_if_all_acked () ;
      ack ev ; free name ev

  | _ -> failwith "unknown local message"

  and upnm_hdlr ev = match getType ev with

    (* ERekey: switch to a new conversation key.  The coordinator
     * rekeys directly; other members ask the coordinator to.  Here we
     * assume the Auth module authenticated a member when it did
     * encryption for that member.
     *)
  | ERekey ->
      log (fun () -> sprintf "ERekey event") ;
      if ls.rank = vs.coord then (
        try_rekey ()
      ) else (
        dnlm (sendRank name vs.coord) Rekey
      ) ;
      upnm ev

    (* EBlock: Mark that we are now blocking.
     *)
  | EBlock ->
      Once.set s.blocking ;
      upnm ev

    (* EFail: Mark failed members as acked, since there is no need to
     * wait for them, then prompt a view change if all remaining
     * members are now ready.
     *)
  | EFail ->
      let failed = getFailures ev in
      List.iter (fun rank ->
        s.ack.(rank) <- true
      ) failed ;
      prompt_if_all_acked () ;
      upnm ev

    (* EView: If all members have acknowledged the new key, then
     * switch to using it in the new view state.
     *)
  | EView ->
      if all_acked () then (
        log (fun () -> sprintf "EView:switching keys:%s" (Security.string_of_key s.key_sug)) ;
        let vs = getViewState ev in
        let vs = View.set vs [Vs_key s.key_sug] in
        let ev = set name ev[ViewState vs] in
        upnm ev
      ) else (
        upnm ev
      )

  | EDump -> ( dump (ls,vs) s ; upnm ev )
  | _ -> upnm ev

  and dn_hdlr ev abv = dn ev abv ()

  and dnnm_hdlr = dnnm

in {up_in=up_hdlr;uplm_in=uplm_hdlr;upnm_in=upnm_hdlr;dn_in=dn_hdlr;dnnm_in=dnnm_hdlr}

(* The REKEY layer, assembled from [init] and [hdlrs] by
 * Layer.hdr_state.  Kept eta-expanded (args vf) rather than partially
 * applied.  NOTE(review): presumably this avoids a weakly polymorphic
 * type under the value restriction; confirm before simplifying.
 *)
let l args vf = Layer.hdr_state init hdlrs None NoOpt args vf

(* Register the layer under [name] at module initialization time. *)
let _ = Layer.install name (Layer.init l)

(**************************************************************)
