(**************************************************************)
(*
 *  Ensemble, (Version 0.40)
 *  Copyright 1997 Cornell University
 *  All rights reserved.
 *
 *  See ensemble/doc/license.txt for further information.
 *)
(**************************************************************)
(**************************************************************)
(* TCP.ML *)
(* Author: Mark Hayden, 7/96 *)
(**************************************************************)
open Util
open Trans
(**************************************************************)
(* Tag identifying this module in traces, failure messages and logs. *)
let name = Trace.source_file "TCP"

(* Module-local [failwith]: the message is first passed through
 * [Util.failmsg] together with [name], then the standard [failwith]
 * raises [Failure] with the result. *)
let failwith s =
  let msg = Util.failmsg name s in
  failwith msg

(* Logger for this module; callers pass a thunk that builds the message. *)
let log = Trace.log name ""
(**************************************************************)
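(* Notes:

 * BUG: unused sockets are never collected.  A connection is only torn
 * down when a receive reports that it was lost or when a send returns 0;
 * idle connections are otherwise kept open.

 *)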
(**************************************************************)
(* Create a TCP stream socket, bind it to [host] and put it in listening
 * mode with a backlog of 5.  Ports are tried starting at [port] and
 * counting upward until [Arge.check_port], bind and listen all succeed;
 * every failed attempt is logged.  The same socket is reused across
 * attempts.  Returns the socket together with the port actually bound.
 *)

let init host port =
  let sock = Hsys.socket_stream () in

  (*set_close_on_exec sock ;*)
  (*set_nonblock sock ;*)

  (* Attempt to listen on [candidate]; on any error, log it and move on
   * to the next port. *)
  let rec first_free candidate =
    let bound =
      try
        Arge.check_port name candidate ;
        Hsys.bind sock host candidate ;
        Hsys.listen sock 5 ;
        true
      with e ->
        log (fun () -> Hsys.error e) ;
        false
    in
    if bound then candidate else first_free (succ candidate)
  in

  let bound_port = first_free port in
  (sock, bound_port)

(**************************************************************)

(* State of one TCP connection, either outgoing (known peer) or accepted
 * on the server socket (no peer). *)
type conn = {
  (* Stream socket of the connection; [None] while disconnected. *)
  mutable sock : Hsys.socket option ;
  (* Remote address: key in the connection table and target of outgoing
   * connects.  [None] for connections accepted on the server socket. *)
  mutable peer : (Hsys.inet * port) option ;
  (* Send function for the socket.  Only set for connections with a known
   * peer, and cleared when the connection breaks.  A result of 0 is
   * treated as a send error. *)
  mutable send : (buf -> ofs -> len -> int) option ;
  (* Time the record was created.  NOTE(review): not updated or read
   * anywhere in this file. *)
  mutable access : Time.t
}

(**************************************************************)

(* Build the TCP communication domain.
 *
 * A listening server socket is bound on the local host by [init], starting
 * at the port taken from [Arge.port].  Outgoing traffic is sent
 * point-to-point over lazily opened TCP connections, one per destination
 * (inet,port).  Connections accepted on the server socket have no peer
 * and are only used for receiving.  An integer is read from each received
 * message with [Iovec.read_int] and passed to [Route.deliver].
 *
 * Multicast and gossip destinations are rejected with [Failure].
 *)
let domain alarm =
  (* Outgoing connection records, keyed by peer address.  [install] keeps
   * at most one binding per peer. *)
  let conns =
    let table = Hashtbl.create 10 in
    Trace.install_root (fun () ->
      [sprintf "TCP:#conns=%d\n" (hashtbl_size table)]
    ) ;
    table
  in

  let host = Hsys.inet_of_string (Hsys.gethostname ()) in
  let port = Arge.check name Arge.port in
  let (server, port) = init host port in
  Unique.install_port port ;

  (* Tear down the socket of [c] (if any) and forget its send function.
   * The record stays in [conns], so the next transmission to the same
   * peer reconnects. *)
  let break c =
    log (fun () -> "breaking connection") ;
    if_some c.sock (fun sock ->
      Alarm.rmv_sock alarm sock ;
      Hsys.close sock ;
      c.sock <- None ;
    ) ;
    c.send <- None ;
  in

  (* Look up the record for [addr], creating a fresh unconnected one if
   * none exists.  Records with an address are registered in [conns];
   * [find None] returns an unregistered record (used when accepting). *)
  let find addr =
    let existing =
      match addr with
      | None -> None
      | Some addr ->
          try Some (Hashtbl.find conns addr)
          with Not_found -> None
    in
    match existing with
    | Some c -> c
    | None ->
        let conn = {
          peer = addr ;
          sock = None ;
          send = None ;
          access = Alarm.gettime alarm
        } in
        if_some addr (fun addr ->
          Hashtbl.add conns addr conn
        ) ;
        conn
  in

  let deliver = Route.deliver in

  (* Start receiving on [c]'s socket.  A lost connection breaks [c].  For
   * records with a known peer, the send function is recorded and [c] is
   * registered in [conns] under that peer.
   * Raises [Failure] if [c] has no socket. *)
  let install c =
    match c.sock with
    | None -> failwith "install_conn:sanity"
    | Some sock ->
        let (send,recv) = Hsyssupp.tcp sock in
        let recv () =
          match recv () with
          | Some msgs ->
              List.iter (fun s ->
                let iov = Iovec.of_string name s in
                let dint = Iovec.read_int name iov in
                Iovec.break name iov (deliver dint)
              ) msgs
          | None ->
              log (fun () ->
                sprintf "lost connection to %s"
                  (string_of_option (fun (inet,port) ->
                    sprintf "%s:%d" (Hsys.string_of_inet inet) port) c.peer)) ;
              break c
        in
        Alarm.add_sock alarm sock (Hsys.Handler0 recv) ;
        if_some c.peer (fun peer ->
          c.send <- Some send ;
          (* [find] normally registered [c] already.  Hashtbl.add shadows
           * rather than replaces, so drop the old binding first; otherwise
           * every reconnect would leave another duplicate in [conns]. *)
          Hashtbl.remove conns peer ;
          Hashtbl.add conns peer c
        )
  in

  (* Open an outgoing connection to [c]'s peer and install it.  Failures
   * are logged and leave [c] disconnected, so a later transmission retries.
   * Raises [Failure] if [c] already has a socket or has no peer. *)
  let connect c =
    match c.sock, c.peer with
    | None, Some(inet,port) ->
        log (fun () -> sprintf "connecting to {%s:%d}" (Hsys.string_of_inet inet) port) ;
        let sock = Hsys.socket_stream () in
        begin try
          Util.disable_sigpipe () ;
          Hsys.connect sock inet port ;
          c.sock <- Some sock ;
          install c
        with e ->
          log (fun () -> sprintf "error:%s" (Hsys.error e)) ;
          (* If [install] raised, [c.sock] would otherwise keep the socket
           * closed below while [c.send] stays [None], and [xmit] would
           * silently skip this peer forever instead of reconnecting. *)
          c.sock <- None ;
          Hsys.close sock
        end
    | _,_ -> failwith "connect:sanity"
  in

  (* Handler for the server socket: accept one incoming connection and
   * start receiving on it.  Accepted records have no peer, so they are
   * never put in [conns] or used for sending.  Any error is logged and
   * turned into [Failure]. *)
  let accept _ =
    try
      log (fun () -> sprintf "accepting connection") ;
      Util.disable_sigpipe () ;
      let conn = find None in
      let (sock,_,_) = Hsys.accept server in
      conn.sock <- Some sock ;
      install conn
    with e ->
      log (fun () -> sprintf "error:%s" (Hsys.error e)) ;
      failwith "error occurred while accepting"
  in

  let addr = Addr.TcpA(host,port) in
  let addr _ = addr in

  (* Send a buffer on [c] using [send]; a return value of 0 is treated as
   * a send error and breaks the connection. *)
  let send_or_break c send buf ofs len =
    let ret = send buf ofs len in
    if ret = 0 then (
      log (fun () -> sprintf "send error:len=%d" len) ;
      break c
    )
  in

  (* Build the transmit functions for [dest].  Only point-to-point
   * destinations are supported.  The destination records are resolved
   * once, here; peers without a socket are connected on first send. *)
  let xmit _ dest =
    let targets =
      match dest with
      | Domain.Mcast _ -> failwith "mcast not supported"
      | Domain.Gossip _ -> failwith "gossip not supported"
      | Domain.Pt2pt dests ->
          Array.map (fun dest ->
            match Addr.project dest Addr.Tcp with
            | Addr.TcpA(inet,port) -> find (Some(inet,port))
            | _ -> failwith "xmit:sanity"
          ) dests
    in
    log (fun () -> sprintf "TCP:selected") ;

    let xmit buf ofs len =
      Array.iter (function
        | {send=Some send} as c ->
            send_or_break c send buf ofs len
        | {sock=None;peer=Some _} as c ->
            (* Known peer but no socket: try connecting first. *)
            connect c ;
            if_some c.send (fun send ->
              send_or_break c send buf ofs len
            )
        | _ -> ()
      ) targets
    in
    let xmitv iov =
      let iov = Iovecl.flatten name iov in (*PERF*)
      Iovec.read name iov xmit
    in
    Some(xmit,xmitv)
  in

  (* Start / stop accepting incoming connections on the server socket. *)
  let enable  _ _ _ _ = Alarm.add_sock alarm server (Hsys.Handler0 accept) in
  let disable _ _ _ _ = Alarm.rmv_sock alarm server in

  Domain.create name addr enable disable xmit

(**************************************************************)

(* Hand [domain] to [Domain.install] for the [Addr.Tcp] address kind
 * (presumably registering this transport); any result is discarded. *)
let () = ignore (Domain.install Addr.Tcp domain)

(**************************************************************)
