(**************************************************************)
(*
 *  Ensemble, (Version 0.40)
 *  Copyright 1997 Cornell University
 *  All rights reserved.
 *
 *  See ensemble/doc/license.txt for further information.
 *)
(**************************************************************)
(**************************************************************)
(* GCAST.ML *)
(* Author: Mark Hayden, 12/96 *)
(**************************************************************)
open Trans
open Layer
open Event
open Util
open View
(**************************************************************)
(* Module name; tags this layer's trace output and error messages. *)
let name = Trace.source_file "GCAST"

(* Module-local [failwith]: decorates [msg] with this module's name via
 * [Util.failmsg] and then raises [Failure] through the standard
 * [failwith] (the binding below is not recursive, so the inner call is
 * the standard one).
 *)
let failwith msg =
  let full_msg = Util.failmsg name msg in
  failwith full_msg
(**************************************************************)

(* Header this layer attaches to every event it passes down.
 * NOTE(review): constructor order is kept as-is since it determines the
 * on-the-wire representation if headers are marshalled.
 *)
type header = NoHdr (* events this layer passes through untouched *)
  | Debug (* not constructed in this file; up_hdlr rejects it -- TODO confirm purpose *)
  | Cast of rank (* relayed multicast; carries the rank of the original sender *)

(* Per-member layer state. *)
type state = {
    fanout : int ; (* branching factor of each spanning tree ("gcast_fanout") *)
    failed : bool array ; (* failed.(r) becomes true once rank r is reported by EFail *)
    mutable forward : rank list array (* forward.(origin): ranks to relay a cast from origin to *)
}

(**************************************************************)

(* Shadow the built-in [mod] with a variant that never returns a
 * negative value for a positive divisor.  The standard [mod] takes the
 * sign of the dividend (e.g. [-1 mod 5] is [-1]); this one yields [4].
 * The inner [mod] below is the standard one, since the binding is not
 * recursive.
 *)
let (mod) a b =
  let r = a mod b in
  if r < 0 then r + b else r

(**************************************************************)

(* Calculate vector of children for forwarding messages.
 *
 * For each possible root, the members form a [fanout]-ary heap-shaped
 * tree over ranks rotated so that the root is at position 0: the member
 * at position [p] has children at positions [p*fanout+1 ..
 * p*fanout+fanout] that are below [nmembers].  This assumes
 * [sequence fanout] is [0; ...; fanout-1] (confirm in Util).  A child
 * marked in [failed] is replaced, recursively, by its own children, so
 * the subtree of a failed member is taken over by its parent.
 *
 * [nmembers]  number of members (ranks 0 .. nmembers-1)
 * [fanout]    branching factor of each tree
 * [rank]      member whose children are computed
 * [failed]    failure flags, indexed by rank
 *
 * Returns an array indexed by root rank: element [r] lists the ranks
 * that [rank] relays to for a multicast originating at [r]
 * (presumably [array_createf n f] is [[|f 0; ...; f (n-1)|]]).
 *)
let children nmembers fanout rank failed =
  let fan = sequence fanout in
  let rec children shift rank =
    (* Position of [rank] in the tree rooted at rank [shift]; relies on
     * the non-negative [mod] defined above. *)
    let c = (rank - shift) mod nmembers in
    (* Positions of its heap children, dropping those past the end. *)
    let c = succ (c * fanout) in
    let c = List.map ((+) c) fan in
    let c = list_filter (fun i -> i < nmembers) c in
    (* Map tree positions back to real ranks. *)
    let c = List.map (fun i -> (i + shift) mod nmembers) c in
    (* Splice in the children of any failed child instead of the child. *)
    let c =
      List.map (fun i ->
        if failed.(i) then
          children shift i
        else [i]
      ) c
    in
    List.flatten c
  in
  array_createf nmembers (fun i -> children i rank)

(* DEBUGGING VERSION
let children nmembers fanout root failed =
  if failed.(root) then failwith "root failed" ;
  let mark = Array.copy failed in
  let rec dft r = 
    if mark.(r) then failwith "marked twice" ;
    if failed.(r) then failwith "send to failed member" ;
    mark.(r) <- true ;
    let ch = children nmembers fanout r failed in
    let ch = ch.(root) in
    List.iter dft ch
  in
  dft root ;

  let nfailed = ref 0 in
  Array.iter (fun b -> if b then incr nfailed) failed ;
  eprintf "GCAST:checking tree: nmembers=%d, root=%d, nfailed=%d\n"
    nmembers root !nfailed ;
  if array_exists (fun _ b -> not b) mark then
    failwith "not all marked" ;
  
  children nmembers fanout root failed
*)
    
(**************************************************************)

(* Initial layer state: read the fanout from the view parameters, start
 * with no member marked failed, and precompute this member's forwarding
 * lists for every possible root.
 *)
let init () (ls,vs) =
  let nmembers = ls.nmembers in
  let fanout = Param.int vs.params "gcast_fanout" in
  let failed = array_create name nmembers false in
  let forward = children nmembers fanout ls.rank failed in
  { fanout = fanout ; failed = failed ; forward = forward }

(**************************************************************)

let hdlrs s (ls,vs) {up_out=up;upnm_out=upnm;dn_out=dn;dnlm_out=dnlm;dnnm_out=dnnm} =
  let log = Trace.log name ls.name in

  (* Events from below.  A send tagged [Cast origin] is a multicast from
   * [origin] travelling down origin's tree.  It is relayed to our
   * children in that tree (if any) and then delivered upward as a cast
   * from [origin].  Untagged events pass straight up; any other header
   * is an error.
   *
   * NOTE(review): unlike dn_hdlr, the incoming [ev] is not passed to
   * [free] after its iovec is re-wrapped -- confirm this is intended.
   *)
  let up_hdlr ev abv hdr =
    match getType ev, hdr with
    | ESend, Cast origin ->
        let relay_to = s.forward.(origin) in
        log (fun () ->
          sprintf "%d ... -> %d -> %d -> %s"
            origin (getOrigin ev) ls.rank (string_of_int_list relay_to)) ;
        begin match relay_to with
        | [] -> ()
        | _ -> dn (sendRanksIov name relay_to (getIov ev)) abv (Cast origin)
        end ;
        up (castOriginIov name origin (getIov ev)) abv
    | _, NoHdr -> up ev abv
    | _, _ -> failwith "bad up event"

  (* Local messages are not expected by this layer. *)
  and uplm_hdlr ev hdr = failwith "unknown local message"

  (* Non-message events from below.  EFail marks the reported ranks as
   * failed and rebuilds the forwarding lists so that failed members'
   * children are relayed to directly.  Every event then continues up.
   *)
  and upnm_hdlr ev =
    begin match getType ev with
    | EFail ->
        List.iter (fun r -> s.failed.(r) <- true) (getFailures ev) ;
        s.forward <- children ls.nmembers s.fanout ls.rank s.failed
    | _ -> ()
    end ;
    upnm ev

  (* Events from above.  A cast issued here becomes a send to our
   * children in our own tree, tagged [Cast] with our rank.  The original
   * event is always freed, even when there is no one to relay to.
   * Everything else goes down with [NoHdr].
   *)
  and dn_hdlr ev abv =
    match getType ev with
    | ECast ->
        let me = ls.rank in
        let relay_to = s.forward.(me) in
        log (fun () ->
          sprintf "root:%d -> %s" ls.rank (string_of_int_list relay_to)) ;
        begin match relay_to with
        | [] -> ()
        | _ -> dn (sendRanksIov name relay_to (getIov ev)) abv (Cast me)
        end ;
        free name ev
    | _ -> dn ev abv NoHdr

  and dnnm_hdlr = dnnm

  in {up_in=up_hdlr;uplm_in=uplm_hdlr;upnm_in=upnm_hdlr;dn_in=dn_hdlr;dnnm_in=dnnm_hdlr}

(**************************************************************)

(* Layer constructor: hands [init] and [hdlrs] to [Layer.hdr].  Kept
 * eta-expanded so [l] stays a fully generalised function. *)
let l args vf =
  let make = Layer.hdr init hdlrs None NoOpt in
  make args vf

(* Module initialisation: register the default fanout parameter, then
 * install this layer under [name]. *)
let _ =
  let default_fanout = Param.Int 2 in
  Param.default "gcast_fanout" default_fanout ;
  Layer.install name (Layer.init l)
    
(**************************************************************)
