(**************************************************************)
(*
 *  Ensemble, (Version 0.40)
 *  Copyright 1997 Cornell University
 *  All rights reserved.
 *
 *  See ensemble/doc/license.txt for further information.
 *)
(**************************************************************)
(**************************************************************)
(* PT2PTW.ML : Point-to-point window-based flow control. *)
(* Author: Mark Hayden, 3/96 *)
(* Note that the window-cost of a message here is 100 more
 * than the number of bytes in the iov field of the message
 * (see msg_len). *)
(**************************************************************)
open Layer
open View
open Event
open Util
open Trans
(**************************************************************)
let name = Trace.source_file "PT2PTW"
(**************************************************************)

(* Headers used by this layer.  Constructor order is unchanged.
 *
 * NoHdr  : pushed on sends that are charged against the window
 *          and on every non-send event travelling down.
 * Ack n  : local message returning n units of credit to the
 *          member the data came from.
 * Unrel  : pushed on sends flagged unreliable; these bypass
 *          the credit accounting entirely.
 *)
type header =
    NoHdr
  | Ack of int
  | Unrel

(**************************************************************)

(* Per-layer state.  All arrays are indexed by member rank.
 *
 * window      : value of the pt2ptw_window parameter; also the
 *               initial send credit given to every member.
 * ack_thresh  : an Ack is sent once the received cost from a
 *               member exceeds this value.
 * send_buf    : sends waiting for credit, with the message from
 *               the layer above.
 * send_credit : credit left towards each member.  A send goes
 *               out whenever this is positive, so it can drop
 *               below zero.
 * recv_credit : cost received from each member and not yet
 *               acknowledged.
 * failed      : set when a member is reported in an EFail.
 *)
type 'abv state = {
  window      : int ;
  ack_thresh  : int ;
  send_buf    : (Event.dn * 'abv) Queue.t array ;
  send_credit : int array ;
  recv_credit : int array ;
  failed      : bool array
}

(**************************************************************)

(* Window cost of an event: the length reported by Iovecl.len
 * for the event's iov field, plus a fixed charge of 100.
 *)
let msg_len ev =
  let payload = Iovecl.len name (getIov ev) in
  payload + 100

(* Number of entries in a queue, rendered as a decimal string. *)
let string_of_queue_len q =
  let entries = Queue.length q in
  string_of_int entries
(* Sum of msg_len over every queued event, rendered as a decimal
 * string.  The second component of each entry is ignored.
 *)
let string_of_queue_bytes q =
  let total = ref 0 in
  let charge (ev,_) = total := !total + msg_len ev in
  Queue.iter charge q ;
  string_of_int !total

(* Print a summary of the layer state through eprintf: the
 * local name, the number of buffered sends per member, and the
 * send and receive credit per member.  The second component of
 * the view pair is not used.
 *)
let dump (ls,_) s =
  let queued = string_of_array string_of_queue_len s.send_buf in
  let sending = string_of_int_array s.send_credit in
  let receiving = string_of_int_array s.recv_credit in
  eprintf "PT2PTW:dump:%s\n" ls.name ;
  eprintf "  send_buf=%s\n" queued ;
  eprintf "  send_credit=%s\n" sending ;
  eprintf "  recv_credit=%s\n" receiving

(**************************************************************)

(* Build the initial state for a view.
 *
 * The window and acknowledgement threshold are read from the
 * view parameters.  Every member starts with an empty send
 * queue, a full window of send credit, no outstanding receive
 * credit, and is marked as not failed.
 *)
let init () (ls,vs) =
  let members = ls.nmembers in
  let int_param key = Param.int vs.params key in
  let window = int_param "pt2ptw_window" in
  let ack_thresh = int_param "pt2ptw_ack_thresh" in
  let fresh_queue _ = Queue.create () in
  { window = window ;
    ack_thresh = ack_thresh ;
    send_buf = array_createf members fresh_queue ;
    send_credit = array_create name members window ;
    recv_credit = array_create name members 0 ;
    failed = array_create name members false
  }

(**************************************************************)

(* Build the event handlers of the layer.
 *
 * s        : the layer state created by init.
 * (ls,vs)  : the view pair; used for the local name in traces
 *            and handed to dump.
 * The record argument carries the outgoing handlers: up and
 * upnm towards the layer above, dn, dnlm and dnnm towards the
 * layer below.
 *
 * The result is the record of incoming handlers.
 *)
let hdlrs s (ls,vs) {up_out=up;upnm_out=upnm;dn_out=dn;dnlm_out=dnlm;dnnm_out=dnnm} =
  let ack = make_acker name dnnm in

  (* Local failwith: dump the state first, then raise with the
   * layer name prefixed to the message.
   *)
  let failwith reason =
    dump (ls,vs) s ;
    failwith (name^":"^reason)
  in
  let log = Trace.log name ls.name in
  let logb = Trace.log "BUFFER" (name^":"^ls.name) in

  let up_hdlr ev abv hdr =
    match getType ev, hdr with

      (* A window-charged send arrived.  Add its cost to the
       * credit owed to the sender; once that exceeds the
       * threshold, return all of it in an Ack and start again
       * from zero.  The message is delivered in either case.
       *)
    | ESend, NoHdr ->
        let sender = getOrigin ev in
        array_add s.recv_credit sender (msg_len ev) ;
        let owed = s.recv_credit.(sender) in
        if owed > s.ack_thresh then begin
          dnlm (sendRank name sender) (Ack owed) ;
          s.recv_credit.(sender) <- 0
        end ;
        up ev abv

      (* Unreliable sends and all other events carrying NoHdr
       * are passed up without accounting.
       *)
    | (ESend, Unrel) | (_, NoHdr) -> up ev abv
    | _ -> failwith "bad header"

  and uplm_hdlr ev hdr =
    match getType ev, hdr with

      (* Credit came back from a peer.  Add it to the send
       * credit, then release buffered sends for that peer while
       * credit remains positive.  At most as many entries are
       * taken as were queued on entry, so Queue.take cannot
       * fail.  Finally the Ack event itself is acknowledged
       * and freed.
       *)
    | ESend, Ack granted ->
        let peer = getOrigin ev in
        array_add s.send_credit peer granted ;
        let backlog = s.send_buf.(peer) in
        let rec drain waiting =
          if s.send_credit.(peer) > 0 && waiting > 0 then begin
            let queued_ev, queued_abv = Queue.take backlog in
            dn queued_ev queued_abv NoHdr ;
            array_sub s.send_credit peer (msg_len queued_ev) ;
            drain (waiting - 1)
          end
        in
        drain (Queue.length backlog) ;
        ack ev ;
        free name ev
    | _ -> failwith "unknown local message"

  and upnm_hdlr ev =
    match getType ev with

      (* EFail: for every reported member, clean its send queue
       * (freeing each buffered event), mark it failed and zero
       * both credit counters.  Then pass the event on up.
       *)
    | EFail ->
        let mark_failed rank =
          queue_clean (fun (pending,_) -> free name pending) s.send_buf.(rank) ;
          s.failed.(rank) <- true ;
          s.send_credit.(rank) <- 0 ;
          s.recv_credit.(rank) <- 0
        in
        List.iter mark_failed (getFailures ev) ;
        upnm ev

      (* EAccount: log how much is blocked per member, first as
       * message counts and then as summed window cost.
       *)
    | EAccount ->
        logb (fun () ->
          let counts = string_of_array string_of_queue_len s.send_buf in
          sprintf "blocked(msgs):%s" counts) ;
        logb (fun () ->
          let costs = string_of_array string_of_queue_bytes s.send_buf in
          sprintf "blocked(byte):%s" costs) ;
        upnm ev

    | EDump ->
        dump (ls,vs) s ;
        upnm ev
    | _ -> upnm ev

  and dn_hdlr ev abv =
    match getType ev with

      (* Unreliable sends go straight down with the Unrel header.
       * Reliable sends are handled one destination at a time:
       * dropped if the destination has failed, sent at once if
       * there is positive credit, buffered otherwise.
       *)
    | ESend ->
        if getUnreliable ev then
          dn ev abv Unrel
        else begin
          let dests = getRanks ev in
          let single = (List.length dests = 1) in
          let send_to dest =
            (* With exactly one destination the event is used
             * unmodified; otherwise a per-destination event is
             * derived with set.
             *)
            let out_ev =
              if single then ev else set name ev [Ranks [dest]]
            in
            if s.failed.(dest) then begin
              log (fun () -> sprintf "send to failed member") ;
              (* NOTE(review): the original author marked this
               * free with a BUG comment; what is wrong is not
               * visible from this file -- confirm before
               * changing it.
               *)
              free name out_ev
            end else if s.send_credit.(dest) > 0 then begin
              array_sub s.send_credit dest (msg_len out_ev) ;
              dn out_ev abv NoHdr
            end else
              Queue.add (out_ev,abv) s.send_buf.(dest)
          in
          List.iter send_to dests
        end

    | _ -> dn ev abv NoHdr

  (* Events arriving on dnnm are forwarded untouched. *)
  and dnnm_hdlr = dnnm

in {up_in=up_hdlr;uplm_in=uplm_hdlr;upnm_in=upnm_hdlr;dn_in=dn_hdlr;dnnm_in=dnnm_hdlr}

(* Layer entry point: hands init and hdlrs to Layer.hdr_state
 * together with the FullNoHdr NoHdr header specification, then
 * applies the result to the layer arguments and view state.
 *)
let l args vs =
  let hdr_spec = FullNoHdr NoHdr in
  Layer.hdr_state init hdlrs None hdr_spec args vs

(* Module initialisation: register the default values of the two
 * parameters read by init, then install the layer under its
 * name.
 *)
let _ =
  let defaults = [
    "pt2ptw_window", Param.Int 50000 ;
    "pt2ptw_ack_thresh", Param.Int 25000
  ] in
  List.iter (fun (key,value) -> Param.default key value) defaults ;
  Layer.install name (Layer.init l)

(**************************************************************)
