(**************************************************************)
(*
 *  Ensemble, 1.10
 *  Copyright 2001 Cornell University, Hebrew University
 *  All rights reserved.
 *
 *  See ensemble/doc/license.txt for further information.
 *)
(**************************************************************)
(**************************************************************)
(* SOCKIO.ML *)
(* Author: Robbert vanRenesse *)
(**************************************************************)
open Ensemble
open Printf
open Util

exception Eof

(* TCP address of a server, written as (host name, port number).
 *)
type location = string * int

(* These are the white pages servers.  The list is hard-coded; every
 * entry is handed to add_to_service the first time wp_service() is
 * called.
 *
 * TODO.  This should be configurable, and depend on the service that
 *	  we are trying to look up.
 *)
let wpsvrs = [
  ("tumeric.cs.cornell.edu", 2222) ;
  ("snotra.cs.cornell.edu", 2222) 
]

(* The white pages service supports these two requests.
 *
 * WP_UPDATE (name, locations) announces the locations at which the
 * named service can be reached (see refresh_white_pages).
 * WP_LOOKUP name asks for the locations of the named service (see
 * wp_lookup); the reply is treated there as a location list.
 *)
type wp_req = 
  | WP_UPDATE of string * (location list)
  | WP_LOOKUP of string

(* Requests are uniquely identified by a client id and by a sequence
 * number.  Currently client_ids are chosen by taking the millisecond
 * at which the client started.  This may not be good enough...
 *
 * All of the values below are computed once, when this module is
 * initialised:
 *   time_in_ms  start time in milliseconds, as a float;
 *   chopped     number of whole multiples of verylarge in time_in_ms;
 *   client_id   time_in_ms reduced modulo verylarge, so that it fits
 *               in an int on platforms with 31-bit integers
 *               (NOTE(review): presumed reason for the 2^29 bound);
 *   seq_no      next sequence number to hand out (see rpc).
 *)
type client_id = int
let verylarge = 536870911.0 (* 2 ^ 29 - 1 *)
let time_in_ms = 
  let time = Time.gettimeofday() in
  let time = Time.to_float time in
  time *. 1000.0
let chopped = float(truncate(time_in_ms /. verylarge))
let client_id = truncate(time_in_ms -. (chopped *. verylarge))
let seq_no = ref 0

(* Each message (or message pair in the case of RPC) is uniquely
 * identified by the client identifier, and a sequence number
 * generated by the client.
 *)
type message_id = client_id * int

(* A message is its identifier paired with an untyped payload. *)
type message = message_id * Obj.t
type request = message

(* Reply to a request: either the result value or a description of
 * the failure, each tagged with the identifier of the request.
 *)
type response
  = RESPONSE of message
  | EXCEPTION of message
  (* | REDIRECT of (message_id * address) *)

(* Identifies a request on the server side: the socket it arrived on
 * plus its message identifier (see response and except).
 *)
type request_id = Hsys.socket * message_id

(* A queued RPC: (command, completion handler, abort handler). *)
type rpc_descr = (Obj.t * (Obj.t -> unit) * (Obj.t -> unit))

(* Connectivity of a service handle:
 *   Down           not connected and not trying to connect;
 *   Connect_First  a first connection attempt has been requested;
 *   Connect_Again  an earlier attempt failed or the connection
 *                  broke, and another attempt is wanted;
 *   Up skt         connected through the given socket.
 *)
type conn_state
  = Down
  | Connect_First
  | Connect_Again
  | Up of Hsys.socket

(* This is a handle on a service.  It contains the name of the service,
 * a function to be invoked when a connection is made, whether or not
 * the service may be looked up in the White Pages, a list of servers
 * (host, port), and the state of connectivity.
 *
 * newconn receives a function that sends a message on the new
 * connection and a function that closes it; it returns the function
 * that delivers incoming messages and the function that is called
 * when the connection fails (see client for the meaning of the
 * boolean that the latter returns).
 *)
type handle = {
  name : string;
  newconn : (Obj.t -> unit) -> (unit -> unit) -> ((Obj.t -> unit) * (unit -> bool));
  mutable wp : bool;   (* can be looked up in White Pages *)
  mutable servers : location list;
  mutable state : conn_state
}

(* This is a list of service handles that are currently disconnected
 * and wish to attempt to connect to one of their servers.  It is
 * processed by attempt_connections.
 *)
let conn_list = ref []

(* Outstanding RPCs of a service, oldest first (rpc appends at the
 * end), and an RPC handle: the service handle together with its queue.
 *)
type rpc_queue = (message_id * rpc_descr) list
type rpc_handle = handle * (rpc_queue ref)

(* This is a list of local service names and their ports.  Entries are
 * added by server and start_server and announced to the White Pages
 * by refresh_white_pages.
 *)
let services = ref []

(* The white pages are refreshed every minute.  This variable
 * keeps track of when the last time was that this happened.
 *)
let last_time = ref Time.zero

(* Debugging aid: print the byte values of a slice of buf on stdout,
 * each formatted as <code>.  The slice starts at offset and holds len
 * bytes, cut short at the end of buf.  Raises Failure when the
 * resulting slice is empty.
 *)
let dump buf offset len =
  let total = String.length buf in
  let count = if offset + len <= total then len else total - offset in
  if count > 0 then begin
    let last = offset + count - 1 in
    let rec emit i =
      if i <= last then begin
        Printf.printf "<%d>" (Char.code buf.[i]);
        emit (i + 1)
      end
    in
    emit offset
  end else
    failwith "dump: bad length"

(* Registration functions for socket management.
 *
 * The actual register/unregister implementations are installed at run
 * time through register_register (see sockio_register and
 * ensemble_register).  Until that has happened, the placeholders
 * below fail with a sanity-check message.
 *)
let register_ref = ref (fun _ _ _ -> failwith 
  "SOCKIO:register:sanity:register-register not called")
let unregister_ref = ref (fun _ -> failwith 
  "SOCKIO:unregister:sanity:unregister-register not called")

(* Install reg and unreg as the functions used by register and
 * unregister from now on.
 *)
let register_register reg unreg =
  register_ref := reg ; 
  unregister_ref := unreg

(* Register and unregister both look in the ref's above.  The refs are
 * dereferenced on every call, so a later register_register takes
 * effect immediately.
 *)
let register a b c = !register_ref a b c
let unregister a = !unregister_ref a

(* This is an association list that maps a socket descriptor to a pair
 * of functions: one that is called when there is input on the
 * descriptor, and one that is invoked when the first function fails
 * (presumably when there's no more input on the descriptor).
 *)
let socks = ref ([]: (Hsys.socket *
  ((Hsys.socket -> unit) * (Hsys.socket -> unit))) list)

(* Install the local socket management functions: registered sockets
 * are kept in the socks association list, which select consults.
 *)
let sockio_register () =
  register_register
    (* Remember the input and failure handlers of fd. *)
    (fun fd inputavail fail ->
      socks := Xlist.insert (fd, (inputavail, fail)) !socks)
    (* Forget fd together with its handlers. *)
    (fun fd ->
      socks := Xlist.except fd !socks)

(* Install socket management functions that hand the sockets to
 * Ensemble: add is used to register a socket with a handler, rmv to
 * unregister it again.
 *)
let ensemble_register add rmv =
  let register fd input fail =
    (* Called when the input handler raises: unregister the socket,
     * notify the owner, and close the socket on a best-effort basis.
     *)
    let on_error () =
      unregister fd;
      fail fd;
      (try Hsys.close fd; () with _ -> ())
    in
    let handler () =
      try input fd with _ -> on_error ()
    in
    add fd (Hsys.Handler0 handler)
  in
  register_register register rmv

(* Frame a payload for transmission over a TCP stream: the result is
 * a 4-byte length field followed by the payload.  The length field
 * holds the size of the whole frame (payload plus the 4 header
 * bytes), least significant byte first.
 *)
let make_frame buf =
  let payload_len = String.length buf in
  let frame_len = payload_len + 4 in
  let frame = String.create frame_len in
  for i = 0 to 3 do
    let byte = (frame_len lsr (8 * i)) land 255 in
    String.set frame i (Char.chr byte)
  done;
  String.blit buf 0 frame 4 payload_len;
  frame

(* Marshal a value of any type, frame it, and send it over the given
 * TCP socket.
 *
 * NOTE(review): the result of Hsys.send is discarded, so a send that
 * transmits only part of the frame would go unnoticed -- confirm
 * against the contract of Hsys.send.
 *)
let send skt o =
  let frame = make_frame (Obj.marshal (Obj.repr o)) in
  let frame_len = String.length frame in
  ignore (Hsys.send skt frame 0 frame_len)

(* Framed input handling.
 *
 * message_len is the initial size of a receive buffer.
 * reassembly_list associates a socket with its receive buffer and the
 * number of bytes in that buffer that have been received but not yet
 * delivered.
 *)
let message_len = 4096
let reassembly_list = ref []

(* There's input on this socket.  Read frames and deliver them until
 * there are no more left.  Each complete frame is unmarshalled and
 * passed to upcall together with the socket.
 *
 * Raises Eof when the peer has closed the connection, and Failure
 * when a frame header declares a length smaller than the header
 * itself.  In both cases the reassembly state of the socket is
 * dropped first.
 *)
let gotinput upcall fd =
  (* Decode the frame length: 4 bytes, least significant byte first.
   * The length covers the whole frame, header included (see
   * make_frame).
   *)
  let get_length hdr =
    let n = ref 0 in
    for i = 3 downto 1 do
      n := !n + Char.code (String.get hdr i);
      n := !n * 256
    done;
    !n + Char.code (String.get hdr 0)
  in

  (* This function deals with a received buffer of the given total length.
   * If there's more in that buffer, it recursively deals with that as well.
   *)
  let rec handle_chunk buf total rest =
    (* This function unmarshalls a frame and delivers it to the application.
     *)
    let handle_frame () =
      Printexc.catch (fun _ ->
        let (o,_) = Obj.unmarshal buf 4 in
        upcall fd (Obj.magic o)
      ) ()
    in

    (* If there are fewer than 4 bytes, we cannot figure out yet how large the
     * message is supposed to be.
     *)
    if total < 4 then
      reassembly_list := (fd, (buf, total)) :: rest
    else begin
      let size = get_length buf in
      if size < 4 then begin
        (* A valid frame is never shorter than its own header.  Without
         * this check a declared length of 0 would make us process the
         * same bytes forever, and other small or negative values would
         * make us read outside the frame.  Drop the state for this
         * socket and let the caller's failure handling take over.
         *)
        reassembly_list := rest;
        failwith "SOCKIO:gotinput:bad frame length"
      end else if size > total then begin
        (* We don't have the whole message yet, so stick it on the
         * reassembly list.  Grow the buffer if it cannot hold the frame.
         *)
        let buflen = String.length buf in
        if buflen < size then
          reassembly_list :=
            (fd, ((buf ^ (String.create (size - buflen))), total)) :: rest
        else
          reassembly_list := (fd, (buf, total)) :: rest
      end else begin
        handle_frame ();
        if size < total then begin
          (* There's more: copy it to the beginning of the buffer and
           * deal with that next.
           *)
          String.blit buf size buf 0 (total - size);
          handle_chunk buf (total - size) rest
        end else
          (* There's nothing left in this buffer.  We could throw the buffer
           * away, but we're keeping it because the garbage collector deals
           * badly with freeing and reallocating these large buffers.
           * Perhaps this is more efficient too...
           *)
          reassembly_list := (fd, (buf, 0)) :: rest
      end
    end
  in

  (* See if we have a buffer in the reassembly list.  If not, create one.
   *)
  let ((buf, total), rest) =
    try
      Xlist.take fd !reassembly_list
    with Not_found ->
      ((String.create message_len, 0), !reassembly_list)
  in

  (* We guarantee that there is some space left in the buffer, so we can
   * read at the end of the buffer now.
   *)
  let nbytes = Hsys.recv fd buf total ((String.length buf) - total) in
  if nbytes = 0 then
    begin
      reassembly_list := rest;
      raise Eof
    end
  else
    handle_chunk buf (total + nbytes) rest

(* Create a stream socket, bind it to the given port on the address of
 * the local host, and start listening for connections on it.  Returns
 * the listening socket.
 *
 * If binding or listening fails, the socket is closed before the
 * exception is passed on, so that a failed attempt does not leak a
 * descriptor.
 *)
let bind port =
  let hostname = Hsys.gethostname () in
  let addr = Hsys.inet_of_string hostname in

  (* Create stream socket for accepting connections on.
   *)
  let sock = Hsys.socket_stream () in

  begin
    try
      (* Bind it to the port.
       *)
      Hsys.bind sock addr port;
      printf "Server %s:%d starting\n" hostname port; flush stdout;

      (* Set up to accept connections on.
       *)
      Hsys.listen sock 5 (* max value *)
    with e ->
      (* Best-effort cleanup; the original error is the one to report. *)
      (try ignore (Hsys.close sock) with _ -> ());
      raise e
  end;
  sock

(* Extract the message identifier from a request identifier, which is
 * a (socket, message identifier) pair.
 *)
let mid_of_rid rid = snd rid

(* Render a message identifier as text: the client identifier and the
 * sequence number in decimal, separated by an exclamation mark.
 *)
let string_of_mid mid =
  let (client, seqno) = mid in
  Printf.sprintf "%d!%d" client seqno

(* Create a client handle for the named service.  The handle starts
 * out disconnected (Down), without servers, and eligible for lookup
 * in the White Pages.
 *
 * newconn is invoked when a connection has been made successfully.  It
 * takes two function arguments: the first sends a message on the
 * connection, the second closes the connection.  It must return two
 * functions as well: one to deliver incoming messages, and one to
 * report failure.  The failure routine should return true if the
 * connection should remain broken (until the next time a connect is
 * attempted), and false if the connection should be remade asap.
 *)
let client name newconn =
  (* The handle stores untyped functions; convert between the
   * caller's message types and Obj.t on the way in and out.
   *)
  let untyped_newconn raw_send close =
    let typed_send msg = raw_send (Obj.repr msg) in
    let (typed_deliv, fail) = newconn typed_send close in
    let raw_deliv msg = typed_deliv (Obj.magic msg) in
    (raw_deliv, fail)
  in
  { name = name;
    newconn = untyped_newconn;
    wp = true;
    servers = [];
    state = Down }

(* Create a handle for the named RPC service.  The result is the
 * client handle paired with the queue of outstanding requests.
 *
 * Whenever a connection is made, all queued requests are sent
 * (again).  A response removes the matching request from the queue
 * and invokes its completion handler; an exception response invokes
 * its abort handler instead.  On connection failure the handle asks
 * to stay disconnected only if no requests are outstanding.
 *)
let bind_to_service name =
  let queue = ref [] in
  let newconn xmit close =
    (* Remove the request with the given id from the queue and return
     * its descriptor.
     *)
    let take mid =
      let (request, rest) = Xlist.take mid !queue in
      queue := rest;
      request
    in
    let deliv reply =
      match reply with
      | RESPONSE (mid, value) ->
          (let (_, complete, _) = take mid in complete value)
      | EXCEPTION (mid, descr) ->
          (let (_, _, abort) = take mid in abort descr)
    in
    let fail () = (!queue = []) in
    (* Replay everything that has not been answered yet. *)
    List.iter (fun (mid, (cmd, _, _)) -> xmit (mid, cmd)) !queue;
    (deliv, fail)
  in
  (client name newconn, queue)

(* Add a server to the list of servers of an RPC handle.  The first
 * explicit server replaces whatever the list contained and turns off
 * White Pages lookup for the service; later ones are added to the
 * list.  In both cases the server is also made known to Eval.
 *)
let add_to_service (svc, _) host port =
  let location = (host, port) in
  (if svc.wp then begin
     svc.servers <- [location];
     svc.wp <- false
   end else
     svc.servers <- Xlist.add location svc.servers);
  Eval.add_server host port

(**********************************************************)

(* The handle on the White Pages service is created lazily: wpstate is
 * WP_UNINIT until wp_service is called for the first time, and holds
 * the handle afterwards.
 *)
type wpstate
  = WP_UNINIT
  | WP_CONN of rpc_handle

let wpstate = ref WP_UNINIT

(* Get a handle on the WP service.  The first call creates the handle
 * and adds all servers listed in wpsvrs to it.
 *)
let wp_service () =
  match !wpstate with
    | WP_CONN wp ->
        wp
    | WP_UNINIT ->
        let wp = bind_to_service "WhitePages" in
        List.iter (fun (host, port) -> add_to_service wp host port) wpsvrs;
        wpstate := WP_CONN wp;
        wp

(* Find the given service in the White Pages.  When the answer comes
 * in, the returned locations are added to the servers of the handle
 * (and made known to Eval), and the handle is put on the connection
 * list.  If the lookup fails, the failure is printed and the handle
 * is put on the connection list as well.
 *)
let rec wp_lookup svc =
  let schedule_connect () =
    conn_list := svc :: !conn_list
  in
  let complete locs =
    let add (host, port) =
      svc.servers <- Xlist.add (host, port) svc.servers;
      Eval.add_server host port
    in
    (* Only the side effects of add matter, so iterate rather than
     * build and discard a list of results.
     *)
    List.iter add (Obj.magic locs);
    schedule_connect ()
  in
  let abort descr =
    printf "White Pages for %s: %s\n" svc.name (Obj.magic descr);
    flush stdout;
    schedule_connect ()
  in
  let wp = wp_service () in
  rpc wp (Obj.magic (WP_LOOKUP svc.name)) complete abort

(* Do an RPC with the given service.  Invoke the complete function on
 * the result if successful, or the abort function on the description
 * of the problem if not.  We queue the rpc command, complete, and
 * abort functions in the server's handle, so that the handlers can be
 * found when the response comes in.  If the connection fails, we can
 * use the queue to replay the commands.
 *
 * The request is only transmitted right away if the connection is up.
 * If the service is down, a connection attempt is started; the queued
 * request is sent once the connection has been made.
 *)
and rpc (server, queue) cmd complete abort =
  let mid = (client_id, !seq_no) in
  incr seq_no;
  let data = Obj.repr cmd in
  (* Append, so that the queue stays in the order of issue. *)
  queue := !queue @ [mid, (data, complete, abort)];
  match server.state with
    | Down ->
        connect server
    | Connect_First | Connect_Again ->
        ()
    | Up skt ->
        send skt (mid, data)

(* Connect to the given service, which must be Down.  The handle is
 * either looked up in the White Pages first, or put on the connection
 * list directly.  Fails if a connection exists or is being made.
 *)
and connect server =
  match server.state with
    | Down ->
        server.state <- Connect_First;
        if server.wp then
          wp_lookup server
        else
          conn_list := server :: !conn_list
    | Connect_First | Connect_Again | Up _ ->
        failwith "connection not broken"
 
(* Send a response to the client that issued the given request.  A
 * failure to send is reported on stdout but never raised, for two
 * reasons.  First, it's none of the server's business that the reply
 * didn't get to the client.  More importantly, the server might react
 * to the exception by trying to send an error response to the same
 * client, leading to infinite recursion.
 *)
let response (skt, mid) v =
  try
    send skt (RESPONSE (mid, Obj.repr v))
  with _ ->
    print_string "Send failed -- response lost\n";
    flush stdout

(* Send an exception to the client that issued the given request.  As
 * with a normal response, a failure to send is reported on stdout but
 * never raised to the server.
 *)
let except (skt, mid) v =
  try
    send skt (EXCEPTION (mid, Obj.repr v))
  with _ ->
    print_string "Send failed -- exception lost\n";
    flush stdout

(**********************************************************)

(* Refresh the white pages: announce every local service, with the
 * name of this host and the port of the service, to the White Pages
 * service, and record the time of the refresh in last_time.
 *)
let refresh_white_pages () =
  let wp = wp_service () in
  let host = Hsys.gethostname () in
  let update (name, port) =
    let complete _ = () in
    let abort descr =
      printf "White Pages (%s): %s\n" name (Obj.magic descr);
      (* Flush like the other diagnostics in this file do, so that the
       * message does not sit in the output buffer.
       *)
      flush stdout
    in
    rpc wp (WP_UPDATE (name, [(host, port)])) complete abort
  in
  List.iter update !services;
  last_time := Time.gettimeofday ()

(* Try to connect to the given service, using a server picked by Eval
 * from the servers of the handle.  On success the socket is
 * registered for input and the handle becomes Up.  On failure an
 * exception is raised, and the socket created for the attempt is
 * closed so that repeated attempts do not leak descriptors.
 *)
let try_connect svc =
  (* Best-effort close, used on error paths only. *)
  let close_quietly skt =
    try ignore (Hsys.close skt) with _ -> ()
  in

  let (host, port) = Eval.pick_server svc.servers in
  (* Resolve the address before creating the socket, so that a failed
   * resolution leaves nothing behind to clean up.
   *)
  let addr = Hsys.inet_of_string host in
  let skt = Hsys.socket_stream () in
  begin
    try
      Hsys.connect skt addr port
    with e ->
      close_quietly skt;
      printf "Connect to %s server %s:%d failed\n" svc.name host port;
      flush stdout;
      Eval.bring_down (host, port);
      raise e
  end;

  (* Same stuff as in server().  If the connection handler itself
   * fails (it may send on the new connection), the socket has not
   * been registered anywhere yet, so close it here.
   *)
  let xmit msg = send skt msg in
  let close () = Hsys.close skt in
  let (deliv, fail) =
    try
      svc.newconn xmit close
    with e ->
      close_quietly skt;
      raise e
  in
  let upcall fd msg = deliv msg in

  (* This function is called when the connection fails.  Stick
   * the server handle on the connection list if the user's
   * fail routine returns false.
   *)
  let broken _ =
    Eval.bring_down (host, port);
    if fail () then
      begin
        printf "%s server %s:%d went down\n" svc.name host port;
        flush stdout;
        svc.state <- Down
      end
    else
      begin
        printf "%s server %s:%d went down; reconnecting\n"
          svc.name host port;
        flush stdout;
        svc.state <- Connect_Again;
        conn_list := svc :: !conn_list
      end
  in

  let _ = register skt (gotinput upcall) broken in
  (* Only reconnections are announced; a first connect stays silent. *)
  if svc.state = Connect_Again then
    begin
      printf "Connected to %s server %s:%d\n" svc.name host port;
      flush stdout
    end;
  Eval.bring_up (host, port);
  svc.state <- Up skt

(* Try to connect to the services in the conn_list.  Afterwards the
 * conn_list holds the services that we want to try to connect to
 * again, followed by any services that were put on the list while the
 * attempts were being made.
 *
 * The latter matters: a failed attempt may start a White Pages
 * lookup, which in turn may put the White Pages handle itself on the
 * conn_list.  Simply overwriting the conn_list with the retry list
 * would lose that entry and leave the handle waiting for a connection
 * attempt that is never made.
 *)
let attempt_connections _ =
  (* Returns the services that have to be retried.
   *
   * NOTE(review): the recursive call is inside the try, so an
   * exception raised while handling a later entry is treated as a
   * failure of the current one -- confirm that this is intended.
   *)
  let rec go_thru_list = function
    | [] -> []
    | hd :: tl ->
        begin
          try
            try_connect hd;
            go_thru_list tl
          with _ ->
            if hd.wp then
              begin
                wp_lookup hd;
                go_thru_list tl
              end
            else
              begin
                if hd.state = Connect_First then
                  hd.state <- Connect_Again;
                hd :: (go_thru_list tl)
              end
        end
  in
  Eval.recalculate ();
  (* Take the pending entries off the list, so that whatever is on the
   * list afterwards was added during the attempts.
   *)
  let pending = !conn_list in
  conn_list := [];
  let retry =
    try
      go_thru_list pending
    with e ->
      (* Don't lose the pending entries if we have to give up. *)
      conn_list := pending @ !conn_list;
      raise e
  in
  conn_list := retry @ !conn_list

(* Print the sizes of the internal tables, the buffer size and fill
 * level of every reassembly buffer, and the garbage collector
 * statistics on stdout.
 *)
let debug () =
  let print_count label count = printf "%s: %d\n" label count in
  print_count "services" (List.length !services);
  print_count "socks" (List.length !socks);
  printf "reassembly_list: ";
  List.iter
    (fun (_, (buf, fill)) -> printf "<%d,%d>" (String.length buf) fill)
    !reassembly_list;
  printf "\n";
  (* Eval.debug(); *)
  Gc.print_stat stdout;
  flush stdout

(* Periodic housekeeping; this function should be called on a regular
 * basis, say once a second.  It attempts the pending connections, if
 * there are any, and refreshes the white pages when more than 60 time
 * units (presumably seconds) have passed since the last refresh.
 *)
let sweep () =
  begin
    match !conn_list with
      | [] -> ()
      | _ :: _ -> attempt_connections ()
  end;
  let elapsed = Time.sub (Time.gettimeofday ()) !last_time in
  if elapsed > Time.of_int 60 then
    refresh_white_pages ()
    (* ; debug() *)

(* Run the housekeeping of sweep, then wait (up to the given timeout)
 * for input on the sockets registered in socks, and call the input
 * handler of every socket that has input available.
 *)
let select timeout =
  (* On socket failure, remove the socket from the registered sockets,
   * invoke the corresponding fail function, and close the socket on a
   * best-effort basis.
   *)
  let failure fd fail =
    unregister fd;
    fail fd;
    (try Hsys.close fd; () with _ -> ())
  in
  sweep ();
  let registered = List.map fst !socks in
  if registered = [] then begin
    printf "Sockio.select: no sockets\n";
    flush stdout
  end;
  let skts = Array.of_list registered in
  let read = Array.create (Array.length skts) false in
  let select_info = Hsys.select_info (Some (skts, read)) None None in
  ignore (Time.select select_info timeout);

  (* read.(i) tells whether skts.(i) has input. *)
  Array.iteri
    (fun i fd ->
      if read.(i) then begin
        let (inputavail, fail) = List.assoc fd !socks in
        try
          inputavail fd
        with _ ->
          failure fd fail
      end)
    skts

(**********************************************************************************)

(* Create a socket to receive incoming connections on, and announce
 * the service at the White Pages.  The routine newconn is invoked for
 * each new connection with a function that sends a message on the
 * connection, and another that closes the connection.  Newconn should
 * return a tuple with the message delivery and failure notification
 * routines.
 *)
let server name port newconn =
  (* Invoked when there is a new connection on the listening socket:
   * accept it, ask newconn for the handlers, and register the new
   * socket for input.
   *)
  let accept_connection listener =
    let (conn, _, _) = Hsys.accept listener in
    let (deliv, fail) =
      newconn (fun msg -> send conn msg) (fun () -> Hsys.close conn)
    in
    register conn (gotinput (fun _ msg -> deliv msg)) (fun _ -> fail ())
  in

  (* Create and register the listening socket.
   *)
  let listener = bind port in
  register listener accept_connection (fun _ -> ());
  services := (name, port) :: !services;
  refresh_white_pages ()

(* Create a socket to receive incoming RPC connections on.  Every
 * request that arrives on such a connection is passed to eval as the
 * request identifier (socket and message id) and the command.
 *)
let start port eval =
  let dispatch fd (mid, cmd) = eval (fd, mid) cmd in
  let ignore_failure _ = () in
  let accept_connection listener =
    let (conn, _, _) = Hsys.accept listener in
    register conn (gotinput dispatch) ignore_failure
  in
  let listener = bind port in
  register listener accept_connection ignore_failure

(* Start an RPC server on the given port, add it to the local
 * services, and install its location at the White Pages service.
 *)
let start_server name port eval =
  start port eval;
  let entry = (name, port) in
  services := entry :: !services;
  refresh_white_pages ()
