(**************************************************************)
(*
 *  Ensemble, (Version 0.40)
 *  Copyright 1997 Cornell University
 *  All rights reserved.
 *
 *  See ensemble/doc/license.txt for further information.
 *)
(**************************************************************)
(**************************************************************)
(* TRANSPORT.ML *)
(* Author: Mark Hayden, 7/95 *)
(**************************************************************)
open Util
open Route
open Trans
open Domain
(**************************************************************)
(* Name under which this source file is registered with Trace.
 *)
let name = Trace.source_file "TRANSPORT"

(* Local failwith: the message is first passed through
 * Util.failmsg together with this module's name.
 *)
let failwith msg = failwith (Util.failmsg name msg)

(* Logger used by route for its per-route debugging output.
 *)
let log_route = Trace.logl "TRANSPORTR" ""
(**************************************************************)

(* Destination of a transmission.
 *   XSend rank  - point-to-point to the member at the given rank
 *                 (used as an index into the address array)
 *   XCast       - all members of the view except this endpoint
 *   XGossip     - the gossip destination of the group
 *   XMerge(v,e) - point-to-point to an endpoint given explicitly;
 *                 the first component of e selects the connection
 *                 and the second is used as the address
 *)
type dest =
  | XSend of rank
  | XCast
  | XGossip
  | XMerge of View.id option * Endpt.full

(* Short printable form of a destination, used in the route log.
 * The contents of a merge destination are not printed.
 *)
let string_of_dest dest =
  match dest with
  | XMerge(_,_) -> "Merge(...)"
  | XGossip -> "Gossip"
  | XCast -> "Cast"
  | XSend r -> sprintf "Send(%d)" r

(**************************************************************)

(* A transport instance.
 *   disable - shuts this instance down
 *   xmit    - given a destination, returns the value used to
 *             send to it (the result of the router's blast
 *             function; its type is the type parameter)
 *)
type 'msg t = {
  disable : unit -> unit ;
  xmit	  : dest -> 'msg
}

(**************************************************************)

(* Convenience wrappers around the two fields of a transport:
 * one to shut it down, and one per kind of destination.
 *)
let disable {disable = shutdown} = shutdown ()
let send {xmit = xmit} rank = xmit (XSend rank)
let cast {xmit = xmit} = xmit XCast
let gossip {xmit = xmit} = xmit XGossip
let merge {xmit = xmit} v e = xmit (XMerge (v, e))

(**************************************************************)

(* Build the transmit functions for one destination.
 *
 * Arguments (as used below):
 *   conns  - not referenced in this function
 *   endpt  - my endpoint; compared against the entries of view
 *   addr   - my address
 *   group  - group id, used for Gossip and Mcast destinations
 *   view   - endpoints of the view, parallel to addrs
 *   addrs  - addresses of the view members, indexed by rank
 *   dest   - the destination to route to
 *   info   - thunk returning a list of strings for the route log
 *
 * Returns (x,xv,info).  x takes three arguments (buf ofs len in
 * the local delivery case below) and xv takes one (an iovec
 * there).  Each runs every collected transmit function in turn
 * and does nothing when none were collected.  info is the thunk
 * extended with the strings added here.
 *
 * Fails (via failwith) when no usable mode is found for the
 * destination, and on a Cast when view and addrs differ in
 * length.  A Send with a rank outside addrs fails on the array
 * access.
 *
 * NOTE(review): deliver is not defined in the visible part of
 * this file; presumably it comes from one of the opened modules
 * (Route?) -- confirm.
 *)
let route conns endpt addr group view addrs dest info =
  (* Every rebinding of info below prepends one more string to
   * the list; the list is reversed just before it is logged.
   *)
  let info () = (sprintf "endpt=%s,group=%s,dest.0=%s"
    (Endpt.string_of_id endpt)
    (Group.string_of_id group)
    (string_of_dest dest)) :: info ()
  in

  (* Find modes to use.  Send, Cast and Merge keep the single
   * preferred point-to-point mode; Gossip keeps every multicast
   * mode of my own address.
   *)
  let modes = match dest with
  | XSend(dest) ->
      let dest = addrs.(dest) in
      let modes = Addr.modes_of_view [|addr;dest|] in
      let modes = array_filter Addr.has_pt2pt modes in
      if modes = [||] then
        failwith "route:Send:no pt2pt modes in endpt" ;
      [|Addr.prefer modes|]
  | XCast ->
      let modes = Addr.modes_of_view addrs in
      let modes = array_filter Addr.has_pt2pt modes in
      if modes = [||] then
        failwith "route:Cast:no pt2pt modes in endpt" ;
      [|Addr.prefer modes|]
  | XGossip ->                (* all available cast modes *)
      let modes = Addr.modes_of_view [|addr|] in
      let modes = array_filter Addr.has_mcast modes in
      if modes = [||] then
        failwith "route:Gossip:no mcast modes in endpt" ;
      modes
  | XMerge(_,dest) ->
      let modes = Addr.modes_of_view [|addr; snd dest|] in
      let modes = array_filter Addr.has_pt2pt modes in
      if modes = [||] then
        failwith "route:Merge:no pt2pt modes in endpt" ;
      [|Addr.prefer modes|]
  in

  let info () = (sprintf "modes=%s" (string_of_array Addr.string_of_id modes)) :: info () in

  (* Translate the destination into a Domain destination.
   *)
  let dests = match dest with
  | XSend(dest) ->
      let dest = addrs.(dest) in
      Pt2pt([|dest|])
  | XCast ->
      (* Send to all the addresses but mine.
       *)
      if Array.length addrs <> Array.length view then
	failwith "mismatched view and address" ;
      let addrs = array_combine view addrs in
      let addrs =
	Array.map (fun (e,a) ->
	  if e = endpt then None else Some a
        ) addrs
      in
      let addrs = array_filter_nones addrs in
      Pt2pt(addrs)
  | XGossip -> Gossip(group)
  | XMerge(_,dest) -> Pt2pt([|snd dest|])
  in

  let info () = (sprintf "dests.1=%s" (Domain.string_of_dest dests)) :: info () in

  (* For point-to-point:
   * (I've already been stripped)
   * 1. Strip local destinations
   * 2. Strip duplicate destinations
   * 3. Determine if any destinations are local.
   * All three are left to Addr.compress; other kinds of
   * destination pass through with local set to false.
   *)
  let local,dests = match dests with
  | Pt2pt(dests) ->
      let local,dests = Addr.compress addr dests in
      local, Pt2pt(dests)
  | _ -> (false,dests)
  in

  let info () = (sprintf "local=%b, dests.2=%s" local (Domain.string_of_dest dests)) :: info () in

  (* Use Mcast if we have it and there is more than 1
   * effective destination (but not if the mode is Udp).
   * NOTE(review): this branch discards the local flag computed
   * above and asks for loopback instead; presumably the loopback
   * takes care of local delivery, but the original author marked
   * the value as a possible bug -- confirm.
   *)
  let local,dests = match dests with
  | Pt2pt(dests) when (Array.length dests > 1)
    && modes <> [|Addr.Udp|]
    && array_for_all Addr.has_mcast modes ->
      let loopback = true in
      false(*BUG?*), Mcast(group,loopback)
  | _ -> local,dests
  in

  let info () = (sprintf "local=%b, dests.3=%s" local (Domain.string_of_dest dests)) :: info () in

  (* Collect all live xmits.  Modes for which Domain.xmit returns
   * None are skipped.  The fold conses onto the front, so the
   * resulting list is in reverse mode order.
   *)
  let xmits =
    array_fold_left (fun l mode ->
      let xmit = Domain.xmit (Domain.of_mode mode) mode dests in
      match xmit with
      | None -> l
      | Some(xs) -> xs :: l
    ) [] modes
  in

  let info () = (sprintf "nxmits=%d" (List.length xmits)) :: info () in

  (* When a destination is local, append one more pair of
   * functions that hands the message to deliver instead of a
   * domain.  Both variants first read an integer from the
   * message and pass it to deliver as its first argument.
   *)
  let xmits =
    if not local then xmits else (
      let flatten = Mbuf.flatten name Mbuf.global in
      let x buf ofs len =
	let dint = Hsys.pop_int buf ofs in
	deliver dint (Refcnt.void name buf) ofs len
      and xv iov =
	let iov = flatten iov in
	Iovec.break name iov (deliver (Iovec.read_int name iov)) ;
	Iovec.free name iov
      in
      xmits @ [x,xv]
    )
  in

  (* Debugging information.  The strings were accumulated newest
   * first, hence the reversal.
   *)
  log_route (fun () ->
    let info = info () in
    let info = List.rev info in
    let info = ["begin"] @ info @ ["end"] in
    info
  );

  (* Depending on number of xmits, do different things: no-ops
   * for none, the pair itself for exactly one, and a pair that
   * iterates over the list otherwise.
   *)
  let x,xv =
    match xmits with
    | [] -> ((fun _ _ _ -> ()),(fun _ -> ()))
    | [xmit] -> xmit
    | xmits -> (
        let x b o l = List.iter (fun (x,_) -> x b o l) xmits in
        let xv v = List.iter (fun (_,xv) -> xv v) xmits in
        (x,xv)
      )
  in

  (x,xv,info)

(**************************************************************)
(* Security stuff.
 *)

(* Set once security has been requested; read by security_check.
 *)
let secure = ref false

(* Only referenced from the disabled warning code that follows
 * security_check.
 *)
let insecure_warned = ref false

(* Turn the security requirement on, announcing it on stderr
 * first.
 *)
let set_secure () =
  let () = eprintf "TRANSPORT:security enabled\n" in
  secure := true

(* Refuse to continue when security has been requested but the
 * given router is not marked secure: report on stderr and exit
 * with status 1.  Does nothing otherwise.
 *)
let security_check router =
  let violation = !secure && not router.secure in
  if violation then begin
    eprintf "TRANSPORT:enabling transport with insecure router (%s)\n" router.name ;
    eprintf "  -secure flag is set, exiting\n" ;
    exit 1
  end
(* ;
  if router.secure & (not !secure) & (not !insecure_warned) then (
    insecure_warned := true ;
    eprintf "TRANSPORT:warning:enabling secure transport but -secure flag not set\n" ;
    eprintf "  (an insecure transport may be enabled)\n"
  )
*)

(**************************************************************)

(* Table of the connection ids of all currently enabled
 * transports.  Its size is reported through a root installed
 * with Trace.
 *)
let enabled =
  let active = Hashtbl.create 10 in
  let report () =
    [sprintf "TRANSPORT:#enabled=%d" (hashtbl_size active)]
  in
  Trace.install_root report ;
  active
(*
    Hashtbl.iter (fun id () ->
      eprintf "  %s\n" (Conn.string_of_id_field id)
    ) enabled
*)

(* Create and enable a transport instance for endpoint endpt.
 *
 * Creates the connection record, checks the router against the
 * security setting, enables every domain of my address that
 * supports multicast or point-to-point, installs the message
 * handler in the router and records the connection id in the
 * enabled table.
 *
 * The result has two fields: xmit routes to a destination and
 * hands the resulting functions to the router's blast function;
 * disable undoes the steps above.
 *)
let f version group endpt stack_id proto_id key view addrs view_id is_gossip router msg_hdlr =
  (* My address is the one stored at my rank in the view.
   *)
  let my_rank = array_index endpt view in
  let addr = addrs.(my_rank) in

  (* Create record of connections.
   *)
  let conns = Conn.create version endpt group view_id stack_id proto_id view is_gossip in

  (* Apply action (Domain.enable or Domain.disable) to the domain
   * of every mode that can multicast or send point-to-point.
   *)
  let on_each_domain action modes =
    Array.iter (fun mode ->
      if Addr.has_mcast mode || Addr.has_pt2pt mode then (
        let domain = Domain.of_mode mode in
        action domain mode group addr view
      )
    ) modes
  in

  (* Check for security problems before anything is enabled.
   *)
  security_check router ;

  (* Enable all potential domains, then register with the router.
   *)
  on_each_domain Domain.enable (Addr.ids_of_set addr) ;
  router.install conns key msg_hdlr ;
  Hashtbl.add enabled conns.Conn.id () ;

  (* Transmit a message.
   *)
  let xmit dest =
    (* Pick the connection matching the kind of destination.
     *)
    let conn = match dest with
    | XSend(rank) -> conns.Conn.pt2pt_send.(rank)
    | XCast -> conns.Conn.multi_send
    | XGossip -> conns.Conn.gossip
    | XMerge(view_id,full) -> conns.Conn.merge_send view_id (fst full)
    in

    (* First entry of the route log.
     *)
    let info () = [sprintf "conn=%s" (Conn.string_of_id conn)] in

    (* Determine routing information; the extended log thunk that
     * route returns is not needed here.
     *)
    let (x,xv,_) = route conns endpt addr group view addrs dest info in

    (* In the end, pass everything off to router's blast function.
     *)
    router.blast x xv key conn
  in

  (* Disable this transport instance.
   *)
  let disable () =
    let modes = Addr.ids_of_set addr in
    Hashtbl.remove enabled conns.Conn.id ;
    on_each_domain Domain.disable modes ;
    router.remove conns
  in

  { disable = disable ;
    xmit = xmit }

(**************************************************************)
