(**************************************************************)
(*
 *  Ensemble, (Version 0.40)
 *  Copyright 1997 Cornell University
 *  All rights reserved.
 *
 *  See ensemble/doc/license.txt for further information.
 *)
(**************************************************************)
(**************************************************************)
(* PT2PT.ML : FIFO, reliable pt2pt sends *)
(* Author: Mark Hayden, 3/96 *)
(* Based on code by: Robbert vanRenesse *)
(**************************************************************)
open Layer
open View
open Event
open Util
open Trans
open Compresse
(**************************************************************)
(* Name of this layer, as returned by Trace.source_file.  It is used
 * below as the logging source, as the prefix of failure messages, as
 * the debugging name passed to the array/iovec/event helpers, and as
 * the name under which the layer is installed.
 *)
let name = Trace.source_file "PT2PT"
(**************************************************************)
(* Headers used by this layer.
 *
 * NoHdr: no pt2pt information.  Put on unreliable sends and on
 * every event that is not an ESend; passed straight up on receipt.
 *
 * Data(seqno,ackno): sent with message number 'seqno'.  'ackno',
 * when present, is a piggybacked acknowledgement: the receiver
 * advances the head of its send buffer for the sender to it.
 *
 * Ack(seqno): acks first seqno pt2pt messages.
 *
 * Nak(lo,hi): request for retransmission of messages number
 * 'lo' to 'hi'.
 * BUG: comment should say whether the bounds are inclusive or
 * exclusive.
 *)
type header = NoHdr
  | Data of seqno * (seqno option)
  | Ack of seqno
  | Nak of seqno * seqno

(**************************************************************)
(* These are for optimizations in the Layer module.
 *)
(* detector: return the sequence number of a Data header that carries
 * no piggybacked acknowledgement, and None for every other header.
 * Handed to the Layer module through LocalSeqno together with
 * constructor.
 *)
let detector hdr =
  match hdr with
  | Data(seqno,None) -> Some seqno
  | NoHdr | Data(_,Some _) | Ack _ | Nak _ -> None

(* constructor: build the Data header for message number seqno with
 * no piggybacked acknowledgement, i.e. exactly the header shape that
 * detector recognizes.
 *)
let constructor seqno = Data(seqno,None)

(**************************************************************)

(* State of the PT2PT layer.  All array fields are created by init
 * with one entry per member of the view and are indexed by rank.
 *)
type ('abv,'cps) state = {
  (* Interval between sweeps (parameter "pt2pt_sweep"); added to the
   * current time to compute next_sweep.
   *)
  sweep         : Time.t ;
  (* Parameter "pt2pt_sync".  When false, EBlockOk is passed up
   * immediately instead of being held until the send buffers are
   * empty.
   *)
  sync          : bool ;
  (* Parameter "pt2pt_ack_rate".  Only referenced from the
   * commented-out piggybacked-ack code in send.
   *)
  ack_rate      : seqno ;
  (* Time of the next sweep; ETimer events earlier than this do no
   * acknowledgement or retransmission work.
   *)
  mutable next_sweep : Time.t ;
  (* True while an EBlockOk event is being held back. *)
  mutable block_ok : bool ;

  (* Per destination: messages I have sent that are still buffered
   * for retransmission.
   *)
  sends		: ('abv,'cps) Compresse.t Iq.t array ;
  (* Per origin: buffer of received messages. *)
  recvs		: ('abv,'cps) Compresse.t Iq.t array ;
  (* Per origin: highest message number for which a Nak was sent. *)
  naked         : seqno array ;
  (* Per origin: value carried by the last Ack sent to that member;
   * moved back when an older sequence number is received.
   *)
  acked		: seqno array ;
  (* Per member: set to true when the member is reported in EFail. *)
  failed	: bool array ;

  (* Per origin: sum of Iovecl.len over the messages buffered on the
   * slow receive path.  Written only; not read by the code visible
   * here.
   *)
  acct_rsize    : int array ;
  (* Per destination: sum of Iovecl.len over the messages passed to
   * send.  Written only; not read by the code visible here.
   *)
  acct_ssize    : int array
}

(**************************************************************)

(* iq_size: sum of size over every entry currently held in the given
 * queue.  Used by dump and by the EAccount log messages.
 *)
let iq_size iq =
  let add_entry total (_,entry) = total + (size entry) in
  List.fold_left add_entry 0 (Iq.list_of_iq iq)

(**************************************************************)

(* dump: print a summary of the layer state with eprintf: identity of
 * the member, the failed flags, and for each member the head, tail
 * and iq_size of the send and receive buffers.
 *)
let dump (ls,vs) s =
  (* Render one integer per buffer, obtained with extract. *)
  let per_member extract iqs = string_of_int_array (Array.map extract iqs) in
  eprintf "PT2PT:dump:%s\n" ls.name ;
  eprintf "  rank=%d, nmembers=%d, sync=%b\n" ls.rank ls.nmembers s.sync ;
  eprintf "  failed =%s\n" (string_of_bool_array s.failed) ;
  eprintf "  send_lo=%s\n" (per_member Iq.head s.sends) ;
  eprintf "  send_hi=%s\n" (per_member Iq.tail s.sends) ;
  eprintf "  recv_lo=%s\n" (per_member Iq.head s.recvs) ;
  eprintf "  recv_hi=%s\n" (per_member Iq.tail s.recvs) ;
  eprintf "  recv_size=%s\n" (per_member iq_size s.recvs) ;
  eprintf "  send_size=%s\n" (per_member iq_size s.sends)

(**************************************************************)

(* init: build the initial layer state for a view.  The three
 * parameters are read from the view parameters; every per-member
 * array gets its own freshly created storage with one entry per
 * member: empty buffers, zero counters, and nobody marked failed.
 *)
let init () (ls,vs) =
  let n = ls.nmembers in
  (* Each call creates a new array, so no two fields share storage. *)
  let new_iqs () = array_createf n (fun _ -> Iq.empty Unset Reset) in
  let new_counters () = array_create name n 0 in
  {
    sweep      = Param.time vs.params "pt2pt_sweep" ;
    sync       = Param.bool vs.params "pt2pt_sync" ;
    ack_rate   = Param.int vs.params "pt2pt_ack_rate" ;
    block_ok   = false ;
    next_sweep = Time.zero ;
    sends      = new_iqs () ;
    recvs      = new_iqs () ;
    acked      = new_counters () ;
    naked      = new_counters () ;
    failed     = array_create name n false ;
    acct_rsize = new_counters () ;
    acct_ssize = new_counters ()
  }

(**************************************************************)

(* hdlrs: build the event handlers of the PT2PT layer.
 *
 *  s       : layer state, as built by init.
 *  (ls,vs) : state of the current view.
 *  {...}   : outgoing handlers.  up and upnm pass an event on
 *            upwards (with and without a message); dn, dnlm and
 *            dnnm pass an event on downwards with a header and a
 *            message (dn), with a header only (dnlm), or with
 *            neither (dnnm).
 *
 * Returns the record of incoming handlers (up_in, uplm_in, upnm_in,
 * dn_in, dnnm_in).
 *)
let hdlrs s (ls,vs) {up_out=up;upnm_out=upnm;dn_out=dn;dnlm_out=dnlm;dnnm_out=dnnm} =
  (* Shadow failwith so that every failure raised by this layer first
   * dumps the layer state and carries the layer name in its message.
   *)
  let failwith m = dump (ls,vs) s ; failwith (name^":"^m) in
  let log = Trace.log name ls.name in
  let logb = Trace.log "BUFFER" (name^":"^ls.name) in

  (* send: buffer message abv, together with the iovec of ev, in the
   * send buffer of dest and pass the event down with a Data header.
   * The sequence number is the tail of the buffer before the message
   * is added.  Sending to my own rank is a failure.
   *)
  let send dest ev abv =
    if dest =| ls.rank then (
      eprintf "PT2PT:%s\nPT2PT:%s\n" (Event.to_string ev) (View.string_of_full (ls,vs)) ;
      failwith "send to myself" ;
    ) ;
    let sends = s.sends.(dest) in
    let seqno = Iq.tail sends in
    Iq.add sends (Full(abv,(getIov ev))) ;
    array_add s.acct_ssize dest (Iovecl.len name (getIov ev)) ;

    (* Piggybacked acknowledgements are disabled: the code computing
     * them is commented out, so ack is always None.
     *)
    let ack =
(*
      let ackno = Iq.head s.recvs.(dest) in
(*
      if ackno >| s.acked.(dest) + s.ack_rate then (
	s.acked.(dest) <- ackno ;
*)
	Some ackno
(*
      ) else None
*)
*)None
    in

    dn ev abv (Data(seqno,ack))
  in

  (* up_hdlr: events arriving from below with a message and a header. *)
  let up_hdlr ev abv hdr = match getType ev, hdr with

    (* ESend:Data: Got a data message from other
     * member.  Check for fast path.
     *)
  | ESend, Data(seqno,ackno) ->
      let origin = getOrigin ev in
      let recvs = s.recvs.(origin) in

      (* Unbuffer acknowledged messages.
       * NOTE(review): this uses Iq.advance_head, whereas the Ack
       * handler uses Iq.advance_head_gc and frees the iovec of each
       * dropped entry -- confirm the difference is intended.
       *)
      begin match ackno with
      | None -> ()
      | Some ackno ->
          Iq.advance_head s.sends.(origin) ackno
      end ;

      (* Check for fast-path.
       *)
      if Iq.opt_check_update recvs seqno then (
        log (fun () -> sprintf "fast path") ;
        up ev abv
      ) else (
        log (fun () -> sprintf "slow path origin=%d seqno=%d" (getOrigin ev) seqno) ;
        let iov = getIov ev in
        (* Buffer the message.  Only when the assignment is accepted
         * is a reference taken on the iovec and the prefix of the
         * buffer delivered, each message in a freshly created ESend
         * event.
         *)
        if Iq.assign recvs seqno (Full(abv,iov)) then (
          Iovecl.ref name iov ;
          array_add s.acct_rsize origin (Iovecl.len name iov) ;

          Iq.get_prefix recvs (fun seqno cps ->
            let abv,iov = uncompress cps in
            up (create name ESend[
              Iov iov ;
              Origin origin
            ]) abv
          )
        ) ;

        (* If seqno seems old, then maybe a previous ack did
         * not get through.  Shift ack seqno back.  This will
         * cause an ack to be sent on next timeout.
         *)
        if seqno <| s.acked.(origin) then (
          log (fun () -> sprintf "setting acked=%d back to %d"
            s.acked.(origin) seqno) ;
          s.acked.(origin) <- seqno ;
        ) ;

        (* If there are holes out of order
         * then send a nak.
         *)
        if seqno >| s.naked.(origin) then (
          Iq.advance_tail recvs seqno ; (*BUG?*)
          if_some (Iq.hole recvs) (fun (lo,hi) ->
            if seqno >=| lo then (
              (* Keep track of highest msg # we've naked.
               *)
              s.naked.(origin) <- max s.naked.(origin) hi ;
              dnlm (sendRank name origin) (Nak(lo,hi)) ;
            )
          )
        ) ;

        (* The incoming event is not passed up on the slow path. *)
        free name ev
      )

    (* Anything without a pt2pt header is passed up untouched. *)
  | _, NoHdr -> up ev abv
  | _        -> failwith "bad header"

  (* uplm_hdlr: events arriving from below with a header but no
   * message: the Nak and Ack messages generated by this layer.
   *)
  and uplm_hdlr ev hdr = match getType ev,hdr with

    (* Nak: got a request for retransmission.  Send any
     * messages I have in the requested interval, lo..hi.
     *)
  | ESend, Nak(lo,hi) ->
      let o = (getOrigin ev) in

      (* Ack: Unbuffer any acked messages.
       *)
      Iq.advance_head s.sends.(o) lo ;

      (* Retransmit any of the messages asked for that I have.
       * A reference is taken on each iovec before it is sent down.
       *)
      let list = Iq.list_of_iq_interval s.sends.(o) (lo,hi) in
      List.iter (fun (seqno,cps) ->
        let (abv,iov) = uncompress cps in
        Iovecl.ref name iov ;
        dn (sendRanksIov name [o] iov) abv (Data(seqno,None))
      ) list ;

      free name ev

    (* Ack: Unbuffer any acked messages, freeing the iovec of each
     * entry that is dropped.
     *)
  | ESend, Ack(seqno) ->
      let o = (getOrigin ev) in
      log (fun () -> sprintf "got ack for %d from %d" seqno o) ;
      Iq.advance_head_gc s.sends.(o) seqno (fun c ->
        let iov = snd (uncompress c) in
        Iovecl.free name iov
      ) ;

      free name ev

  | _ -> failwith "unknown local message"

  (* upnm_hdlr: events arriving from below with neither header nor
   * message.
   *)
  and upnm_hdlr ev = match getType ev with

    (* EFail: Mark failed members, and pass on up.
     *)
  | EFail ->
      let failed = getFailures ev in
      List.iter (fun rank ->
        s.failed.(rank) <- true
      ) failed ;

      upnm ev

    (* EBlockOk: buffer the event until my messages are
     * stable.  With a single member, or when sync is off, the
     * event is passed up at once.  The held event itself is freed;
     * a new EBlockOk is created when it is released by ETimer.
     *)
  | EBlockOk ->
      if s.block_ok then
        failwith "2nd EBlockOk" ;
      if ls.nmembers = 1 || (not s.sync) then (
        upnm ev
      ) else (
        log (fun () -> sprintf "EBlockOk:capture") ;
        s.block_ok <- true ;
        free name ev ;
      )

    (* ETimer: check for any messages that need to be
     * retransmitted.  Acknowledge any unacknowledged
     * messages.
     *)
  | ETimer ->
      let time = getTime ev in
      if time >= s.next_sweep then (
        (* Schedule the next sweep and request the matching alarm. *)
        s.next_sweep <- Time.add time s.sweep ;
        dnnm (timerAlarm name s.next_sweep) ;

        (* Used to check if all my messages are stable.
         *)
        let stable = ref true in

        (* Go through all live members other than myself.
         *)
        for i = 0 to pred ls.nmembers do
          if i <>| ls.rank && not s.failed.(i) then (
            (* Send out acknowledgements.
             *)
            let head = Iq.head s.recvs.(i) in
            if head >| s.acked.(i) then (
              log (fun () -> sprintf "sending ack to %d for %d, acked was %d" i head s.acked.(i)) ;
              dnlm (sendRank name i) (Ack(head)) ;
              s.acked.(i) <- head
            ) ;

            (* Retransmit any unacknowledged messages I've sent.
             * Any message still buffered means I am not stable.
             * NOTE(review): the Nak handler calls Iovecl.ref on the
             * iovec before retransmitting a buffered message; this
             * path does not -- confirm which is intended.
             *)
            let list = Iq.list_of_iq (*_interval (lo,hi)*) s.sends.(i) in
            if list <> [] then
              stable := false ;
            List.iter (fun (seqno,cps) ->
              let abv,iov = uncompress cps in
              dn (sendRanksIov name [i] iov) abv (Data(seqno,None))
            ) list ;
          ) ;
        done ;

(*
	if s.block_ok && not !stable then (
	  for i = 0 to pred ls.nmembers do
	    log (fun () -> sprintf "my rank=%d" ls.rank) ;
	    let list = Iq.list_of_iq (*_interval (lo,hi)*) s.sends.(i) in
	    let list = List.map fst list in
	    log (fun () -> sprintf "rank=%d, head=%d, seqno=%s" i (Iq.head s.sends.(i)) (string_of_int_list list)) ;
	  done
	) ;
*)
        (* If an EBlockOk has been buffered and I'm stable
         * then send up the event.
         *)
        if s.block_ok && !stable then (
          log (fun () -> sprintf "EBlockOk:release") ;
          s.block_ok <- false ;
          upnm (create name EBlockOk[]) ;
        ) ;
      ) ;

      (* The timer event itself is always passed up. *)
      upnm ev

  | EAccount ->
      (* Dump buffering information if requested.
       *)
      logb (fun () -> sprintf "recv(bytes)=%s" (string_of_int_array (Array.map iq_size s.recvs))) ;
      logb (fun () -> sprintf "send(bytes)=%s" (string_of_int_array (Array.map iq_size s.sends))) ;
      upnm ev

  | EDump -> ( dump (ls,vs) s ; upnm ev )
  | _ -> upnm ev

  (* dn_hdlr: events arriving from above with a message. *)
  and dn_hdlr ev abv = match getType ev with

    (* ESend: for each destination, buffer a copy and send it on.
     * Unreliable sends bypass the buffering and go down with NoHdr.
     *)
  | ESend ->
      if getUnreliable ev then (
        dn ev abv NoHdr
      ) else (
        match getRanks ev with
          (* NOTE(review): with no destination the event is neither
           * passed on nor freed -- confirm whether it should be
           * freed here as in the multi-destination case.
           *)
        | [] -> ()
        | [dest] -> send dest ev abv
        | dests ->
            (* One copy of the event per destination, each addressed
             * to that destination only; the original event is freed
             * once all the copies have been sent.
             *)
            List.iter (fun dest ->
              let ev = copy name ev in
              let ev = set name ev [Ranks [dest]] in
              send dest ev abv
            ) dests ;
            free name ev
      )

    (* Every other event goes down untouched, with no pt2pt header. *)
  | _ -> dn ev abv NoHdr

  (* Events from above without a message are passed straight down. *)
  and dnnm_hdlr = dnnm

in {up_in=up_hdlr;uplm_in=uplm_hdlr;upnm_in=upnm_hdlr;dn_in=dn_hdlr;dnnm_in=dnnm_hdlr}

(* l: the layer itself, built with Layer.hdr_state from init and
 * hdlrs.  The LocalSeqno argument hands NoHdr, ESend, detector and
 * constructor to the Layer module (see the comment above detector).
 *)
let l args vs = Layer.hdr_state init hdlrs None (LocalSeqno(NoHdr,ESend,detector,constructor)) args vs

(* Module initialization: register the default value of each
 * parameter read by init, then install the layer under its name.
 *)
let _ = 
  (* Default sweep interval: Time.of_float 1.0. *)
  Param.default "pt2pt_sweep" ((Param.Time (Time.of_float 1.0))) ;
  Param.default "pt2pt_ack_rate" (Param.Int 10) ;
  (* By default EBlockOk is held until the send buffers are empty. *)
  Param.default "pt2pt_sync" (Param.Bool true) ;
  Layer.install name (Layer.init l)

(**************************************************************)
