(**************************************************************)
(*
 *  Ensemble, (Version 0.40)
 *  Copyright 1997 Cornell University
 *  All rights reserved.
 *
 *  See ensemble/doc/license.txt for further information.
 *)
(**************************************************************)
(**************************************************************)
(* PROXY.ML *)
(* Author: Mark Hayden, 11/96 *)
(* Designed with Roy Friedman *)
(**************************************************************)
open Util
open Trans
(**************************************************************)
let name = Trace.source_file "PROXY"

(* Module-local [failwith]: every failure raised from this file carries
 * the module name as a prefix ("<name>:<msg>").
 *)
let failwith msg = failwith (String.concat ":" [name; msg])

(* Trace channels: [log] traces protocol messages, [logc] (the
 * name^"C" channel) traces the connection count.
 *)
let log = Trace.log name ""
let logc = Trace.log (name ^ "C") ""
(**************************************************************)

(* ID: opaque identifier, carried as a raw string in the proxy wire
 * protocol.
 *)
type id = string

(* ENDPT: the type of endpoints.  Just an [id]; [string_of_endpt]
 * tries to decode it as a marshalled Endpt.id when printing.
 *)
type endpt = id

(* GROUP: the type of groups, likewise an opaque [id].
 *)
type group = id

(**************************************************************)

(* MEMBER_MSG: messages a proxied member sends to the proxy server
 * (wrapped in [Member] on the wire).
 *   Join   -- registers (group,endpt) with the server, which hands it
 *             to its join callback.
 *   Synced -- NOTE(review): presumably the member's answer to a
 *             coordinator [Sync]; confirm against the server side.
 *   Fail   -- endpoints reported as failed.  The server also
 *             synthesizes [Fail [endpt]] for every member of a client
 *             whose connection drops.
 *)
type member_msg =
  | Join
  | Synced
  | Fail of endpt list
   
(**************************************************************)

(* COORD_MSG: messages the proxy server sends back to a member
 * (wrapped in [Coord] on the wire).
 *   View   -- an [ltime] together with the list of endpoints in the view.
 *   Sync   -- NOTE(review): presumably asks the member to sync
 *             (answered with [Synced]); confirm.
 *   Failed -- endpoints that have failed.  When the list names the
 *             recipient itself, the server drops that member's record.
 *)
type coord_msg =
  | View of ltime * (endpt list)
  | Sync
  | Failed of endpt list

(**************************************************************)

(* TCP_MESSAGE: everything that crosses a proxy TCP connection.  The
 * server only ever sends [Coord] and clients only ever send [Member];
 * each side fails with "sanity" if it receives the other kind.
 *)
type tcp_message = 
  | Coord of coord_msg
  | Member of member_msg
      
(* STRING_OF_MSG: short human-readable summary of a message for
 * tracing.  Views are shown as "View(<ltime>,<number of members>)";
 * failure lists are not expanded.
 *)
let string_of_msg msg =
  let describe_coord = function
    | View(ltime,view) -> sprintf "View(%d,%d)" ltime (List.length view)
    | Sync -> "Sync"
    | Failed _ -> "Failed"
  and describe_member = function
    | Join -> "Join"
    | Synced -> "Synced"
    | Fail _ -> "Fail"
  in
  match msg with
  | Coord c -> sprintf "Coord(%s)" (describe_coord c)
  | Member m -> sprintf "Member(%s)" (describe_member m)

(**************************************************************)

(* T: client-side proxy state, built by [create].
 *   sock    -- TCP connection to the proxy server.
 *   send    -- encodes a member message for (group,endpt) and writes
 *              it to the server.
 *   members -- handlers for incoming coordinator messages, keyed by
 *              (group,endpt) and registered by [join].
 *)
type t = {
  sock : Hsys.socket ;
  send : group -> endpt -> member_msg -> unit ;
  members : (group * endpt, coord_msg -> unit) Hashtbl.t
} 
    
(**************************************************************)

(* MAKE_MARSH_ID: build a (marshal, unmarshal) pair for identifiers,
 * registered with Util.make_marsh under "PROXY:noshare:<name>".
 * Values are passed through Deepcopy.f before marshalling; unmarshal
 * decodes a whole string.
 *)
let make_marsh_id name =
  let to_string, of_string =
    Util.make_marsh ("PROXY:noshare:" ^ name) false
  in
  let marsh obj = to_string (Deepcopy.f obj) in
  let unmarsh buf = of_string buf 0 (String.length buf) in
  (marsh, unmarsh)

(* Identifier marshallers for endpoints and groups.  [um_endpt] is used
 * by [string_of_endpt] to decode an endpoint for printing.
 *)
let ma_endpt, um_endpt = make_marsh_id "endpt"
let ma_group, um_group = make_marsh_id "group"

(**************************************************************)

(* STRING_OF_ENDPT: printable form of an endpoint.  Tries to decode it
 * as a marshalled Endpt.id and print that.  If decoding or printing
 * raises Failure, falls back to a hex dump of the raw bytes.
 *)
let string_of_endpt endpt =
  let as_hex () = sprintf "Endpt_nm{%s}" (hex_of_string endpt) in
  try Endpt.string_of_full (um_endpt endpt)
  with Failure _ -> as_hex ()

(**************************************************************)
(*
let marsh,unmarsh = Util.make_marsh name
*)
(**************************************************************)

open Marsh
(* Endpoints and groups are plain strings ([type id = string]), so on
 * the wire they are encoded directly with Marsh's string primitives.
 *)
let write_endpt = write_string
let write_group = write_string
let read_endpt = read_string
let read_group = read_string

(*
let write_endpt m e = 
  let e = Obj.marshal (Obj.repr e) in
  write_string m e
let write_group m e = 
  let e = Obj.marshal (Obj.repr e) in
  write_string m e

let read_endpt m = 
  let e = read_string m in
  let e,_ = Obj.unmarshal e 0 in
  Obj.magic e
let read_group m = 
  let e = read_string m in
  let e,_ = Obj.unmarshal e 0 in
  Obj.magic e
*)
(* MARSH / UNMARSH: wire encoding of (group, endpt, tcp_message).
 *
 * Layout, written with the Marsh primitives:
 *   int     tag (one of the constants below)
 *   string  group
 *   string  endpt
 *   payload, depending on the tag:
 *     coord_view   : int ltime, then a list of endpts
 *     coord_failed : list of endpts
 *     member_fail  : list of endpts
 *     other tags   : nothing
 *
 * [marsh (group,endpt,msg)] returns the encoded string.
 * [unmarsh buf ofs len] decodes the [len] bytes of [buf] starting at
 * [ofs].  A malformed frame is fatal: a Marsh.Error while reading, or
 * a tag missing from the decoder table, prints a diagnostic and exits
 * the process with status 1, like the file's other fatal exits.
 *)
let marsh,unmarsh =
  (* Message tags.  These values are the wire protocol; never renumber. *)
  let coord_view = 0
  and coord_sync = 1
  and coord_failed = 2
  and member_join = 3
  and member_sync = 4
  and member_fail = 5
  in

  (* Endpoint lists appear in three payloads; share their codec. *)
  let write_endpts m endpts =
    write_list m (fun endpt -> write_endpt m endpt) endpts
  and read_endpts m =
    read_list m (fun () -> read_endpt m)
  in

  let marsh (group,endpt,msg) =
    let m = init () in
    (* Every frame starts with tag, group, endpt. *)
    let header tag =
      write_int m tag ;
      write_group m group ;
      write_endpt m endpt
    in
    begin
      match msg with
      | Coord(View(ltime,view)) ->
          header coord_view ;
          write_int m ltime ;
          write_endpts m view
      | Coord(Sync) ->
          header coord_sync
      | Coord(Failed(failed)) ->
          header coord_failed ;
          write_endpts m failed
      | Member(Join) ->
          header member_join
      | Member(Synced) ->
          header member_sync
      | Member(Fail(failed)) ->
          header member_fail ;
          write_endpts m failed
    end ;
    (* Marsh.marsh: this inner [marsh] is not recursive. *)
    marsh m
  in

  (* Payload decoders, keyed by tag.  Reads happen in the same order
   * as the writes above (ltime before the view list).
   *)
  let map = [
    coord_view,(fun m ->
      let ltime = read_int m in
      let view = read_endpts m in
      Coord(View(ltime,view))
    ) ;
    coord_sync,(fun _ -> Coord(Sync)) ;
    coord_failed,(fun m -> Coord(Failed(read_endpts m))) ;
    member_join,(fun _ -> Member(Join)) ;
    member_sync,(fun _ -> Member(Synced)) ;
    member_fail,(fun m -> Member(Fail(read_endpts m)))
  ] in

  let unmarsh buf ofs len =
    (* Fatal: exit with a failure status, not 0. *)
    let bad_format () =
      eprintf "PROXY:unmarshalling error:bad format, exiting\n" ;
      exit 1
    in
    try
      let m = Marsh.unmarsh (String.sub buf ofs len) in
      let tag = read_int m in
      let group = read_group m in
      let endpt = read_endpt m in
      (* An unknown tag used to escape as Not_found; treat it as a
       * malformed frame instead.
       *)
      let decode =
        try List.assoc tag map with Not_found -> bad_format ()
      in
      (group,endpt,decode m)
    with Marsh.Error _ -> bad_format ()
  in

  (marsh,unmarsh)

(**************************************************************)

(* CREATE: client side of the proxy.
 *
 * Wraps [sock], a TCP connection to the proxy server.  Registers a
 * receive handler for it with the Alarm and returns the proxy state.
 * Each coordinator message from the server goes to the handler that
 * [join] registered for its (group,endpt).  A message for an unknown
 * member is reported on stderr and dropped.  A [Member] message from
 * the server fails with "sanity".  Losing the connection closes the
 * socket and exits the process with status 1.
 *)
let create sock =
  let alarm = Alarm.get () in
  let verbose = ref true in
  let connected = ref true in

  let tcp_send,tcp_recv = Hsyssupp.tcp sock in

  (* Frame and transmit one member->server message. *)
  let send group endpt msg =
    let wrapped = Member msg in
    log (fun () -> string_of_msg wrapped) ;
    let buf = marsh (group,endpt,wrapped) in
    tcp_send buf 0 (String.length buf) ;
    ()
  in

  let s = {
    sock = sock ;
    send = send ;
    members = Hashtbl.create 100
  } in

  (* Handler registered for (group,endpt), or a no-op if none is. *)
  let handler_for group endpt =
    try Hashtbl.find s.members (group,endpt) with Not_found ->
      if !verbose then
        eprintf "PROXY:recv'd message, but no member\n" ;
      (fun _ -> ())
  in

  (* Decode one frame from the server and hand it to its member. *)
  let deliver frame =
    let (group,endpt,msg) = unmarsh frame 0 (String.length frame) in
    log (fun () -> string_of_msg msg) ;
    let coord =
      match msg with
      | Coord c -> c
      | Member _ -> failwith "sanity"
    in
    (handler_for group endpt) coord
  in

  let recv () =
    match tcp_recv () with
    | None ->
        connected := false ;
        Alarm.rmv_sock alarm sock ;
        Hsys.close sock ;
        eprintf "PROXY:lost connection to domain server, exiting\n" ;
        exit 1
    | Some msgs ->
        List.iter deliver msgs
  in
  Alarm.add_sock alarm sock (Hsys.Handler0 recv) ;
  s
  
(**************************************************************)

(* JOIN: on proxy [s], register [to_client] as the receiver of
 * coordinator messages addressed to (group,endpt).  Returns the
 * function the member uses to send its own messages to the server.
 *)
let join s group endpt to_client =
  let key = (group,endpt) in
  Hashtbl.add s.members key to_client ;
  (fun msg -> s.send group endpt msg)
  
(**************************************************************)

(* CONNS: number of client connections currently open on the proxy
 * server.  Incremented on accept and decremented when a client's
 * connection closes.  Only used for tracing on the [logc] channel.
 *)
let conns = ref 0

(* SERVER: server side of the proxy.
 *
 * Binds a TCP listening socket to [port] on the address of the local
 * host name and registers it with the Alarm.  Each accepted client
 * connection gets its own receive handler, built by [client_init].
 *
 * [join group endpt to_client] is called when a client sends [Join]
 * for (group,endpt).  [to_client] forwards coordinator messages back
 * over that client's connection.  The function [join] returns receives
 * every member message the client then sends for (group,endpt),
 * starting with the [Join] itself.
 *
 * If binding fails, the process exits with status 1.  Diagnostics are
 * suppressed when Arge.quiet is set.
 *)
let server port join =
  let host = Hsys.inet_of_string (Hsys.gethostname ()) in
  let sock = Hsys.socket_stream () in
  Hsys.setsockopt sock Hsys.Reuse ;

  if not (Arge.get Arge.quiet) then
    eprintf "PROXY:server binding to port %d\n" port ;
  begin
    try Hsys.bind sock host port with e ->
      if not (Arge.get Arge.quiet) then
	eprintf "PROXY:error:%s, exiting \n" (Hsys.error e) ;
      exit 1
  end ;
  Hsys.listen sock 5 ;
  if not (Arge.get Arge.quiet) then
    eprintf "PROXY:server installed\n" ;
  let alarm = Alarm.get () in
  
  (* CLIENT_INIT: per-connection state and receive handler for one
   * accepted client socket.
   *   clients   -- (group,endpt) -> function returned by [join].
   *   connected -- cleared when the connection closes; afterwards
   *                [send] silently drops messages for this client.
   *)
  let client_init sock =
    let clients = Hashtbl.create 100 in
    let connected = ref true in
    
    let send,recv = Hsyssupp.tcp sock in
    (* Send a coordinator message to this client.  A [Failed] list that
     * names the recipient itself removes its record, so later messages
     * from the client for that (group,endpt) are ignored by [recv].
     *)
    let send group endpt msg =
      if !connected then (
	begin match msg with
	| Failed(failed) when List.mem endpt failed ->
	    Hashtbl.remove clients (group,endpt)
	| _ -> ()
	end ;
	log (fun () -> string_of_msg (Coord msg)) ;
	let msg = group,endpt,(Coord msg) in
	let msg = marsh msg in
	send msg 0 (String.length msg) ;
	()
      )
    in
    (* NOTE(review): [unmarsh] terminates the whole process on a
     * malformed frame, and a [Coord] message from a client raises
     * Failure "sanity" out of this handler.  A single misbehaving
     * client can therefore take the server down.
     *)
    let recv () =
      match recv () with
      |	Some msgs ->
	  List.iter (fun msg ->
	    let msg = unmarsh msg 0 (String.length msg) in
	    let (group,endpt,msg) = msg in
	    log (fun () -> string_of_msg msg) ;
	    let msg = match msg with Member m -> m | _ -> failwith "sanity" in

	    (* For joins, we first have to add a connection record
	     * to our table.
	     * NOTE(review): a second Join for the same (group,endpt) adds
	     * another binding (Hashtbl.add), and Hashtbl.remove in [send]
	     * only removes the most recent one.
	     *)
	    begin match msg with
	    | Join ->
		let to_client = send group endpt in
		let to_server = join group endpt to_client in
		Hashtbl.add clients (group,endpt) to_server
	    | _ -> ()
	    end ;

	    (* If record of the client is not here, then we assume
	     * the connection was previously broken and do nothing.
	     *)
	    begin 
	      try 
	      	let to_server = Hashtbl.find clients (group,endpt)  in
	      	to_server msg
	      with Not_found -> ()
	    end
	 ) msgs ;
      |	None ->
	  (* Connection closed.  [connected] is cleared before the Fail
	   * notifications below, so [send] ignores anything addressed
	   * to this client from here on.  Every member this client
	   * still has a record for is then reported as failed.
	   *)
	  decr conns ;
	  logc (fun () -> sprintf "#connections=%d" !conns) ;
	  connected := false ;
	  Alarm.rmv_sock alarm sock ;
	  Hsys.close sock ;
	  Hashtbl.iter (fun (_,endpt) to_server ->
	    to_server (Fail([endpt]))	(*PERF*)
	  ) clients
    in
    recv
  in
  
  (* Accept one pending connection and register its receive handler. *)
  let svr_handler () =
    let client,_,_ = Hsys.accept sock in
    incr conns ;
    logc (fun () -> sprintf "#connections=%d" !conns) ;
    Alarm.add_sock alarm client (Hsys.Handler0 (client_init client))
  in
  Alarm.add_sock alarm sock (Hsys.Handler0 svr_handler)

(**************************************************************)
