(**************************************************************)
(*
 *  Ensemble, (Version 0.40)
 *  Copyright 1997 Cornell University
 *  All rights reserved.
 *
 *  See ensemble/doc/license.txt for further information.
 *)
(**************************************************************)
(**************************************************************)
(* SESSVR.ML: session server *)
(* The front end of the remote execution server.
 * If the terminal window from which sessvr was started is closed, sessvr
 * will die the next time it attempts output.  To avoid this, redirect
 * its output, e.g., to /dev/null.
 *)
(* TODO: handle execsvr or client failures *)
(* Author: Takako M. Hickey, 4/97 *)
(* Thanks to Mark Hayden and Robbert van Renesse for many useful
 * discussions.
 *)
(**************************************************************)
open Ensemble
open Util 
open View
open Appl_intf 
open Rpc
open Session
open Db
open Dbinput
open Env
open Dutil
open Clientreq
(**************************************************************)
(* Name of this source file, used to tag diagnostics. *)
let name = Trace.source_file "SESSVR"

(* Same as the standard failwith, except that the message is prefixed
 * with [name] and a colon. *)
let failwith msg = failwith (String.concat ":" [name; msg])
(**************************************************************)
(* Client RPC port used when ENS_EXECSVR_PORT is not set (see run). *)
let default_sessvr_port = 8123

(* Heartbeat period handed to Time.of_float (presumably seconds). *)
let heartbeat_rate = ref 1.0

(* Set by -shutdown: ask the other group members to shut down. *)
let shutdown = ref false

(* Set by -uptime: query the other group members for uptime and load. *)
let uptime = ref false

(* This server's own state record.  It is the value contributed at view
 * changes (block_view) and the source of the fields sent in uptime
 * replies. *)
let my_state = {
  (* identity *)
  sttype         = SvrSess ;
  stname         = Unix.gethostname () ;
  stendpt        = None ;          (* filled in by run once the endpt is known *)
  (* resources offered: none for a session server *)
  staticentry    = [] ;
  strestrictions = [] ;
  (* load and bookkeeping counters *)
  stload         = 0.0 ;
  stuptime       = 0.0 ;
  nrunning       = 0 ;
  nwaiting       = 0 ;
  ncompleted     = 0 ;
  gossipid       = 0
}

(**************************************************************)
(* Build the application interface for the session server.
 *
 * [(ls,vs)] is the initial local/view state (from Appl.default_info).
 * The returned callbacks, wrapped with [full], work as follows:
 *  - heartbeat replays messages saved while the group was blocked, then
 *    takes at most one pending client request (crq_assign_next) and
 *    dispatches it with handle_request;
 *  - recv_cast / recv_send hand group messages to recv_msg, which updates
 *    the session list and answers waiting clients through Sockio;
 *  - at a view change, block_view contributes my_state and unblock_view
 *    rebuilds the resource database from the process servers' states.
 * Handlers queue outgoing Ensemble actions with add_event; the callback
 * that ran them returns the queued actions and clears the queue.
 *)
let intf (ls,vs) =
  let vs               = ref vs
  and ls               = ref ls
  and curtime          = ref 0.0          (* current logical time *)
  and msgq             = Queue.create ()  (* (endpt, msg) saved while blocked *)
  and events_to_submit = ref []           (* actions to hand back to Ensemble *)
  and uptowait         = ref 0            (* DUpAnswer replies still expected *)
  in

  (* [rank_of_endpt] may raise Not_found (endpt not in the current view);
   * [endpt_of_rank] may raise Invalid_argument (rank out of range). *)
  let rank_of_endpt endpt = array_index endpt !vs.view
  and endpt_of_rank rank = !vs.view.(rank)
  and async = Appl.async (!vs.group, !ls.endpt) in
  let gettime () = Time.to_float (Alarm.gettime (Alarm.get ())) in
  let starttime = gettime ()                   (* start time *)
  and clear_events () =
    events_to_submit := []
  and add_event event =
    events_to_submit := !events_to_submit @ [event]
  in

  let slist = sesslist_create () in

  (* Remove database entries whose process server is no longer a member
   * of the current view.  Entries whose "endpt" value is not an Endpt are
   * dropped as well.  Only the rank lookup of the entry itself is guarded,
   * so a failure on a later entry is never blamed on an earlier one. *)
  let remove_dead_resource () =
    let rec loop = function
      | [] -> []
      | h :: t ->
          (match getdbval "endpt" h with
           | Endpt e ->
               let alive =
                 try ignore (rank_of_endpt e) ; true
                 with Not_found -> false
               in
               if alive then
                 h :: loop t
               else (
                 Util.printf "SESSVR: removing information for dead procsvr: %s\n" (Endpt.string_of_id e) ;
                 loop t
               )
           | _ -> loop t)
    in
    db := loop !db
  in

(**************************************************************)

  (* Choose machines for the [Array.length programs] commands.
   *
   * [machspec] is an array of alternative machine specifications
   * (disjuncts).  An empty array, or a single empty specification, means
   * "any machine satisfying [procspec]".  Otherwise disjuncts are tried in
   * order until enough machines are found.  [flag] decides what happens
   * when there are fewer machines than commands:
   *   AllOrNothing: fail;
   *   MultLimit k:  cap the number of commands at k per selected machine;
   *   Unlimited:    no cap.
   *
   * Returns (ok, ndisjunct, machines, nindiv, n): whether placement may
   * proceed, the number of disjuncts in use, the machines chosen for each
   * disjunct with their counts, and the number of commands to start.
   *)
  let select_machines programs machspec procspec flag =
    let nspec = Array.length machspec in
    let total = ref 0 in
    let ndisjunct = ref nspec in
    (* Always keep at least one slot: the "no specification" case fills
     * slot 0 even when [machspec] is empty. *)
    let nslots = if nspec > 0 then nspec else 1 in
    let machines = array_create name nslots [] in
    let nindiv = array_create name nslots 0 in
    let n = ref (Array.length programs) in
    if (nspec = 1 && machspec.(0) = []) || nspec = 0 then (
      (* If no machine specification, pick a machine at random *)
      machines.(0) <- db_restrict procspec !db ;
      nindiv.(0) <- List.length machines.(0) ;
      total := !total + nindiv.(0) ;
      ndisjunct := 1
    ) else (
      let i = ref 0 in
      while !total < !n && !i < nspec do
        machines.(!i) <- db_select_with_restrict machspec.(!i) procspec ;
        nindiv.(!i) <- List.length machines.(!i) ;
        total := !total + nindiv.(!i) ;
        incr i
      done
    ) ;
    let ok = ref true in
    (match flag with
     | AllOrNothing() -> if !total < !n then ok := false
     | MultLimit(k) -> if !total * k < !n then n := !total * k
     | Unlimited() -> ()) ;
    Util.printf "%s: selected %d machines for %d commands\n" my_state.stname !total !n ;
    if !total = 0 then
      ok := false
    else if !ndisjunct = 1 && !total > !n then (
      (* More candidates than commands: narrow to n entries with db_random. *)
      machines.(0) <- db_random !n machines.(0) true [] ;
      nindiv.(0) <- !n
    ) ;
    (!ok, !ndisjunct, machines, nindiv, !n)
  in

  (* Dispatch one client request.  Requests needing process-server work
   * are registered with sess_req_add and forwarded as Send events; the
   * client is answered later from recv_msg.  Other requests are answered
   * immediately through Sockio.response. *)
  let handle_request req =
    (match req.crreq with
     | SessCreate(namehint) ->
         let s = sess_create slist namehint in
         Sockio.response req.crrid (SessCreateSuccess(namehint, s.sessname))
     | SessDestroy(name) ->
         (try
            sess_destroy slist name ;
            Sockio.response req.crrid (SessDestroySuccess(name))
          with Failure(reason) ->
            Sockio.response req.crrid (SessOpFailure(name, SeDestroy, reason)))
     | SessWait(name) ->
         (try
            let s = sess_req_add slist name SeWait req.crrid (-1) in
            let send_wait p =
              let dest = rank_of_endpt p.seprocsvr in
              add_event (Send ([dest], DSessWait(name, s.req_ticket, p.seprocname)))
            in
            List.iter send_wait s.processes
          with Not_found ->
            Sockio.response req.crrid (SessOpFailure(name, SeWait, "Session not found")))
     | SessSig(name, sesssig) ->
         (try
            let s = sess_req_add slist name SeSig req.crrid (-1) in
            let send_sig p =
              let dest = rank_of_endpt p.seprocsvr in
              add_event (Send ([dest], DSessSig(name, s.req_ticket, p.seprocname, sesssig)))
            in
            List.iter send_sig s.processes
          with Not_found ->
            Sockio.response req.crrid (SessOpFailure(name, SeSig, "Session not found")))
     | ProcCreate(sessname, programs, env, machspec, procspec, flag) ->
         Util.printf "%s: Got ProcCreate for %s\n" my_state.stname sessname ;
         let nprog = Array.length programs in
         if nprog = 0 then
           Sockio.response req.crrid (ProcOpFailure(sessname, PrCreate, "no program specified"))
         else (try
           ignore (sess_lookup_incticket slist sessname nprog) ;
           let (ok, ndisjunct, machines, nindiv, n) =
             select_machines programs machspec procspec flag
           in
           if ok then (
             let s = sess_req_add slist sessname PrCreate req.crrid n in
             let total = ref 0 in
             let i = ref 0 in
             let t = ref (s.proc_ticket - n) in
             (* Round-robin over the disjuncts, one command per selected
              * machine, until exactly n commands are placed.  The inner loop
              * stops at n, so programs.(!total) stays in bounds. *)
             while !total < n do
               let candidates = Array.of_list machines.(!i) in
               let j = ref 0 in
               while !j < nindiv.(!i) && !total < n do
                 (match getdbval "endpt" candidates.(!j) with
                  | Endpt e ->
                      (* TODO: do more optimal assignment *)
                      (match programs.(!total).(!i) with
                       | Cmd cmd ->
                           let dest = rank_of_endpt e in
                           add_event (Send ([dest], DProcCreate(sessname, s.req_ticket, cmd, env, !t)))
                       | _ -> ())
                  | _ -> ()) ;
                 incr t ;
                 incr total ;
                 incr j
               done ;
               incr i ;
               if !i >= ndisjunct then i := 0
             done
           ) else
             Sockio.response req.crrid (ProcOpFailure(sessname, PrCreate, "not enough appropriate machines"))
         with exn ->
           (* Printexc.print reports the exception and re-raises it; the
            * client is then told the session was not found. *)
           (try Printexc.print (fun () -> raise exn) () with _ ->
              Sockio.response req.crrid (ProcOpFailure(sessname, PrCreate, ("session not found: " ^ sessname)))))
     | ProcWait(sessname, procname) ->
         Util.printf "%s: Got ProcWait for %s %s\n" my_state.stname sessname procname ;
         (try
            let p = sess_proc_lookup slist sessname procname in
            let s = sess_req_add slist sessname PrWait req.crrid 1 in
            let dest = rank_of_endpt p.seprocsvr in
            add_event (Send ([dest], DProcWait(sessname, s.req_ticket, procname)))
          with Not_found ->
            Sockio.response req.crrid (ProcOpFailure(sessname, PrWait, "Process not found")))
     | ProcSig(sessname, procname, sesssig) ->
         (try
            let p = sess_proc_lookup slist sessname procname in
            let s = sess_req_add slist sessname PrSig req.crrid 1 in
            let dest = rank_of_endpt p.seprocsvr in
            add_event (Send ([dest], DProcSig(sessname, s.req_ticket, procname, sesssig, req.crrid)))
          with Not_found ->
            Sockio.response req.crrid (ProcOpFailure(sessname, PrSig, "Process not found")))
     | DBGetEntry(key) ->
         (* Reply with the first matching entry, or an empty entry. *)
         (match db_select [(DBeq, key)] with
          | entry :: _ -> Sockio.response req.crrid (DBGetEntryReply(key, entry))
          | [] -> Sockio.response req.crrid (DBGetEntryReply(key, [])))
     | DBDeleteEntry(key) ->
         (* TODO: wait to send response back until op really done *)
         Sockio.response req.crrid (DBDeleteEntryOk(key)) ;
         add_event (Cast (DDBDeleteEntry(key)))
     | DBChangeAttributes(key, changes) ->
         (* TODO: wait to send response back until op really done *)
         Sockio.response req.crrid (DBChangeAttributesOk(key)) ;
         add_event (Cast (DDBChangeAttributes(key, changes)))
     | Shutdown() ->
         (* TODO: wait to send response back until op really done *)
         Sockio.response req.crrid (ShutdownOk()) ;
         add_event (Cast (DShutdown()))
     | _ -> Sockio.response req.crrid (OpFailure("unknown op")))
  in

  (* Handle one group message sent by the member at [endpt].  Returns the
   * actions the message produced and clears events_to_submit. *)
  let recv_msg endpt msg =
    (match msg with
     | DSessWaitSuccess(sessname, reqid, _procname) ->
         (try
            let r = sess_req_lookup_decntowait slist sessname reqid in
            if r.ntowait <= 0 then (
              Sockio.response r.clientrid (SessWaitSuccess(sessname)) ;
              sess_req_remove slist sessname reqid
            )
          with Not_found -> ())
     | DSessSigSuccess(sessname, reqid, _procname, sesssig) ->
         (try
            let r = sess_req_lookup_decntowait slist sessname reqid in
            if r.ntowait <= 0 then (
              Sockio.response r.clientrid (SessSigSuccess(sessname, sesssig)) ;
              sess_req_remove slist sessname reqid ;
              if sesssig = SessSigKill then
                sess_processes_clear slist sessname
            )
          with Not_found -> ())
     | DProcCreateSuccess(sessname, reqid, proc) ->
         ignore (sess_proc_add slist sessname proc) ;
         (try
            let r = sess_req_lookup_decntowait_addproc slist sessname reqid proc in
            if r.ntowait <= 0 then (
              Sockio.response r.clientrid (ProcCreateSuccess(sessname, Array.of_list r.proctoreturn)) ;
              sess_req_remove slist sessname reqid
            )
          with Not_found -> ())
     | DProcWaitSuccess(sessname, reqid, procname) ->
         Util.printf "%s: Got DProcWaitSuccess for %s %s (%d)\n" my_state.stname sessname procname reqid ;
         (try
            let r = sess_req_lookup slist sessname reqid in
            Sockio.response r.clientrid (ProcWaitSuccess(sessname, procname)) ;
            (* TODO: remove requests more efficiently *)
            sess_proc_remove slist sessname procname ;
            sess_req_remove slist sessname reqid ;
            my_state.ncompleted <- my_state.ncompleted + 1
          with Not_found -> ())
     | DProcSigSuccess(sessname, reqid, procname, sesssig, _rid) ->
         (try
            let r = sess_req_lookup slist sessname reqid in
            Sockio.response r.clientrid (ProcSigSuccess(sessname, procname, sesssig)) ;
            if sesssig = SessSigKill then
              sess_proc_remove slist sessname procname ;
            (* TODO: remove requests more efficiently *)
            sess_req_remove slist sessname reqid
          with Not_found -> ())
     | DDBDeleteEntry(key) ->
         db_delete (DBeq, key)
     | DDBChangeAttributes(key, changes) ->
         db_change key changes
     | DShutdown() ->
         Util.printf "SESSVR: got shutdown request\n" ;
         add_event Leave
     | DUpQuery() ->
         db_print () ;
         Util.printf "\n" ;
         let time = gettime () -. starttime in
         let load = get_load () in
         (* A query replayed from the blocked-message queue may come from a
          * member that has since left the view; skip the reply then. *)
         (try
            let dest = rank_of_endpt endpt in
            add_event (Send ([dest], DUpAnswer(my_state.sttype, my_state.stname, time, load, my_state.ncompleted)))
          with Not_found -> ())
     | DUpAnswer(svr, machine, time, load, nissued) ->
         (* Print every answer except our own, and leave once all the
          * expected answers have arrived. *)
         (match my_state.stendpt with
          | Some e -> if e <> endpt then print_uptime svr machine time load nissued
          | None -> ()) ;
         decr uptowait ;
         if !uptowait <= 0 then add_event Leave
     | DGossip(_fromhost, _time, changes) ->
         db_change ("endpt", Endpt endpt) changes
     | _ -> ()) ;
    let events = !events_to_submit in
    clear_events () ;
    events
  in

  (* Deliver a message from the member at rank [from].  Invalid_argument
   * (unknown rank) is reported and yields no actions. *)
  let recv_from_rank from msg =
    try
      let endpt = endpt_of_rank from in
      recv_msg endpt msg
    with Invalid_argument _ ->
      printf "SESSVR: msg from unknown member %d\n" from ;
      []
  in

  (* Save a message that arrived while blocked, for replay at the next
   * heartbeat.  The sender's endpt is stored instead of its rank because
   * the rank may change with the view change.  Messages from an unknown
   * rank are dropped.  The header is printed either way. *)
  let save_blocked_msg from msg =
    (try Queue.add (endpt_of_rank from, msg) msgq
     with Invalid_argument _ -> ()) ;
    print_msghdr msg
  in

  let recv_cast from msg = recv_from_rank from msg
  and recv_send from msg = recv_from_rank from msg

  and block () = []

  (* Periodic routine *)
  and heartbeat tick =
    curtime := Time.to_float tick ;
    (* Replay messages saved while blocked.  recv_msg returns and clears
     * the actions each message produced, so collect them here. *)
    let replayed = ref [] in
    let empty = ref false in
    while not !empty do
      match (try Some (Queue.take msgq) with Queue.Empty -> None) with
      | Some (endpt, msg) -> replayed := !replayed @ recv_msg endpt msg
      | None -> empty := true
    done ;
    (* Serve at most one pending client request per heartbeat. *)
    (try handle_request (crq_assign_next ())
     with Queue.Empty -> ()) ;
    let events = !events_to_submit in
    clear_events () ;
    !replayed @ events

  and block_recv_cast from msg = save_blocked_msg from msg
  and block_recv_send from msg = save_blocked_msg from msg

  and block_view (ls,vs) = [ls.rank, [my_state]]
  and block_install_view (ls,vs) s =
    List.fold_left Lset.union [] s
  and unblock_view (ls',vs') s =
    (* Rebuild the resource database from the process servers' states. *)
    db := [] ;
    restrictions := [] ;
    List.iter (fun st ->
      if st.sttype = SvrProc then (
        db_add (compose_dbentry st) ;
        db_restrictions_add st.stname st.strestrictions
      )) s ;
    vs := vs' ;
    ls := ls' ;
    remove_dead_resource () ;
    Random.init (truncate (gettime ())) ;

    Util.printf "SESSVR: got view, nmembers=%d " !ls.nmembers ;
    Util.printf "\n" ;
    db_print () ;

    if !shutdown then (
      if !ls.nmembers > 1 then (
        (* TODO: add in security check *)
        Util.printf "SESSVR: Requesting shutdown\n" ;
        [Cast (DShutdown()); Leave]
      ) else
        []
    ) else if !uptime then (
      if !ls.nmembers > 1 then (
        Util.printf "SESSVR: Requesting uptime\n" ;
        (* One answer is expected per member.  This presumably includes our
         * own, assuming the cast is delivered locally; our own answer is
         * counted but not printed (see DUpAnswer). *)
        uptowait := !ls.nmembers ;
        [Cast (DUpQuery())]
      ) else
        []
    ) else (
      clear_events () ;
      []
    )

  and exit () =
    exit 0

  in full (*(Appl_intf.debug*) {
    recv_cast           = recv_cast ;
    recv_send           = recv_send ;
    heartbeat           = heartbeat ;
    heartbeat_rate      = Time.of_float !heartbeat_rate ;
    block               = block ;
    block_recv_cast     = block_recv_cast ;
    block_recv_send     = block_recv_send ;
    block_view          = block_view ;
    block_install_view  = block_install_view ;
    unblock_view        = unblock_view ;
    exit                = exit
  }(*)*)

(**************************************************************)

(* Program entry point: set the group properties, parse the command
 * line, start the client RPC listener, join the group and run Ensemble's
 * main loop. *)
let run () =
  (* Ignore SIGPIPE so that writing to a client socket whose peer has
   * closed does not kill the process.  The previous handler is not needed. *)
  ignore (Sys.signal Sys.sigpipe Sys.Signal_ignore) ;

  let props = Property.Total :: Property.Scale :: Property.vsync in
  let props = List.map Property.string_of_id props in
  let props = String.concat ":" props in
  Arge.set_default Arge.properties props ;

  (*
   * Parse command line arguments.
   *)
  let undoc = "undocumented" in
  Arge.parse [
    ("-uptime",         Arg.Unit(fun () -> uptime := true),undoc) ;
    ("-shutdown",       Arg.Unit(fun () -> shutdown := true),undoc)
  ] (Arge.badarg name) "sessvr" ;

  (*
   * Get default transport and alarm info.
   *)
  let (ls,vs) = Appl.default_info "execsvr" in
  let alarm = Alarm.get () in
  my_state.stendpt <- Some ls.endpt ;

  (*
   * Initialize RPC.  The client port comes from ENS_EXECSVR_PORT when it
   * holds an integer; otherwise default_sessvr_port is used.
   *)
  let dp =
    try int_of_string (Sys.getenv "ENS_EXECSVR_PORT") with
    | Not_found -> default_sessvr_port
    | Failure _ ->
        Util.printf "SESSVR: ignoring non-numeric ENS_EXECSVR_PORT, using %d\n" default_sessvr_port ;
        default_sessvr_port
  in
  Sockio.ensemble_register (Alarm.add_sock alarm) (Alarm.rmv_sock alarm) ;
  Sockio.start dp crq_unassigned_add_new ;

  (*
   * Initialize the application interface.
   *)
  let interface = intf (ls,vs) in

  (*
   * Configure the Ensemble protocol stack with the interface and the
   * local/view state chosen above.
   *)
  Appl.config interface (ls,vs) ;

  (*
   * Enter Ensemble's main event loop.
   *)
  Appl.main_loop ()
 (* end of run function *)


(* Program start: hand [run] to Appl.exec under the application name
 * "sessvr".  Appl.exec is expected to catch and report exceptions that
 * escape [run].
 *)
let _ =
  let main = run in
  Appl.exec ["sessvr"] main

(**************************************************************)
