(**************************************************************)
(*
 *  Ensemble, (Version 0.40)
 *  Copyright 1997 Cornell University
 *  All rights reserved.
 *
 *  See ensemble/doc/license.txt for further information.
 *)
(**************************************************************)
(**************************************************************)
(* ENCRYPT.ML *)
(* Author: Mark Hayden, 4/97 *)
(**************************************************************)
open Trans
open Layer
open View
open Event
open Util
(**************************************************************)
let name = Trace.source_file "ENCRYPT"

(* Local [failwith]: the message is first passed through
 * [Util.failmsg] together with this module's [name] before
 * being raised by the [failwith] already in scope.
 *)
let failwith msg =
  let msg = Util.failmsg name msg in
  failwith msg
(**************************************************************)

(* Cipher state for one traffic stream.  The layer state holds one
 * [chan] per direction (transmit/receive) and per peer stream.
 *)
type chan = {
  context : Shared.context ;    (* context created by Shared.init *)
  mutable seqno : seqno         (* next sequence number to assign (send) or expect (receive) *)
}

(* Header attached by this layer.
 * NoHdr: the event was passed through without encryption
 *   (unreliable or zero-length casts/sends, and all other events).
 * NOTE(review): constructor order determines the runtime tags, so do
 *   not reorder if headers are marshalled between peers.
 *)
type header = NoHdr 

(* Data(seqno,length,checksum): attached to encrypted casts/sends.
 * + seqno: per-channel sequence number, checked on receipt as a
 *   sanity check
 * + length: actual payload size, before padding to a multiple of 8
 * + checksum: Digest of the padded plaintext, re-checked after
 *   decryption as a further sanity check
 *
 * Note that we could use the checksum for authenticating
 * the sender.
 *)
  | Data of seqno * len * Digest.t

(* Layer state: the shared key object plus one cipher channel per
 * traffic stream.  x_* channels are used on the way down (encrypt),
 * r_* channels on the way up (decrypt).  The arrays have one entry
 * per member and are indexed by rank: destination for x_send, origin
 * for r_cast and r_send.
 *)
type state = {
  shared        : Shared.t ;
  x_cast	: chan ;
  x_send        : chan array ;
  r_cast	: chan array ;
  r_send	: chan array
}

(* Debug dump: prints only the layer's name via eprintf.  The view
 * state and the layer state are ignored.
 *)
let dump (ls,_) _ =
  let layer_name = ls.name in
  eprintf "ENCRYPT:dump:%s\n" layer_name

(* Create the layer state from the view's key.  Every channel starts
 * at sequence number 0 with its own context from [Shared.init].  The
 * boolean passed to [Shared.init] is true for transmit (x_*) channels
 * and false for receive (r_*) channels.  Per-member arrays have
 * [ls.nmembers] entries.
 *)
let init () (ls,vs) =
  let shared = Shared.of_key vs.key in
  let make_chan transmit =
    { context = Shared.init shared vs.key transmit ;
      seqno = 0 }
  in
  let per_member transmit =
    array_createf ls.nmembers (fun _ -> make_chan transmit)
  in
  (* Field order kept identical to preserve the order of the
   * Shared.init calls.
   *)
  { shared = shared ;
    x_cast = make_chan true ;
    r_cast = per_member false ;
    x_send = per_member true ;
    r_send = per_member false }

(* Build this layer's event handlers.
 *
 * Down: reliable, non-empty ECast/ESend events are padded to a
 * multiple of 8 bytes, encrypted with the sending channel's context
 * and tagged with Data(seqno,length,checksum).  Unreliable or empty
 * casts/sends, and every other event, go down with NoHdr.
 *
 * Up: Data-tagged casts/sends are checked against the channel's
 * expected sequence number, decrypted, checksum-verified and trimmed
 * back to their original length.  NoHdr events pass through.
 *
 * Every inconsistency is treated as fatal (failwith).
 *)
let hdlrs s (ls,vs) {up_out=up;upnm_out=upnm;dn_out=dn;dnlm_out=dnlm;dnnm_out=dnnm} =
  let log = Trace.log name ls.name in

  (* Encrypt the iovec.  Note that DES requires that the
   * payload size be a multiple of 8, so the plaintext is padded
   * up to the next multiple of 8 first.
   * Returns the encrypted iovec array, the unpadded length, and
   * the Digest of the padded plaintext.
   *)
  let encrypt chan iov =
    let len = Iovecl.len name iov in
    let ceil = ((len + 7) / 8) * 8 in
    let fill = ceil - len in
    let iov =
      if fill = 0 then
        iov
      else
        (* Pad with a freshly created iovec of [fill] bytes.
         * NOTE(review): its contents are whatever Iovec.create
         * leaves there -- confirm whether they are initialized.
         *)
        Array.append iov [| Iovec.create name fill |]
    in
    let iov = Iovecl.flatten name iov in
    let iov,checksum =
      Iovec.read name iov (fun buf ofs padded ->
        (* Note that the checksum covers the padding as well as
         * the application data.
         *)
        let checksum = Digest.substring buf ofs padded in
        let dst = String.create ceil in
        Shared.update s.shared chan.context buf ofs dst 0 padded ;
        let dst = Iovec.heap name dst in
        (Iovec.alloc name dst 0 padded), checksum
      )
    in
    ([| iov |],len,checksum)
  in

  (* Decrypt the iovector and chop off the padding added by
   * [encrypt].  [actual] is the unpadded payload length and
   * [checksum] the Digest of the padded plaintext, both taken
   * from the Data header.
   *)
  let decrypt chan iov actual checksum =
    let iov = Iovecl.flatten name iov in
    let len = Iovec.len name iov in
    if (len mod 8) <> 0 then
      failwith "iovec not multiple of 8 bytes" ;
    (* Sanity check the header-supplied length before using it to
     * slice the decrypted buffer.
     *)
    if actual < 0 || actual > len then
      failwith "payload length out of range" ;
    let iov =
      Iovec.read name iov (fun buf ofs len ->
        let dst = String.create len in
        Shared.update s.shared chan.context buf ofs dst 0 len ;
        let checksum_cpt = Digest.substring dst 0 len in
        if checksum_cpt <> checksum then
          failwith "mismatched checksum, exiting" ;
        let dst = Iovec.heap name dst in
        Iovec.alloc name dst 0 actual
      )
    in
    [| iov |]
  in

  let up_hdlr ev abv hdr = match getType ev, hdr with
  | (ECast|ESend), Data(seqno,actual,checksum) ->
      let origin = getOrigin ev in
      let chan =
        if getType ev = ECast then
          s.r_cast.(origin)
        else
          s.r_send.(origin)
      in
      (* Packets on a channel must be decrypted in the order they
       * were sent.  Presumably the cipher context is also stateful
       * (TODO confirm in Shared).
       *)
      if chan.seqno <> seqno then
        failwith "out-of-order packet" ;
      chan.seqno <- succ chan.seqno ;
      let iov = getIov ev in
      let iov = decrypt chan iov actual checksum in
      let ev = set name ev [Iov iov] in
      up ev abv

  | _, NoHdr -> up ev abv
  | _, _     -> failwith "bad up event"

  and uplm_hdlr ev () = failwith "got uplm event"
  and upnm_hdlr = upnm

  and dn_hdlr ev abv = match getType ev with
  | ECast | ESend ->
      let iov = getIov ev in
      if getUnreliable ev
      || Iovecl.len name iov = 0 then (
        (* Not encrypted.  Presumably unreliable messages are skipped
         * because they could be lost, which would break the strict
         * seqno check in up_hdlr.
         *)
        dn ev abv NoHdr
      ) else (
        let chan =
          if getType ev = ECast then
            s.x_cast
          else (
            let dest =
              match getRanks ev with
              | [dest] -> dest
              | _ -> failwith "ndests <> 1"
            in
            s.x_send.(dest)
          )
        in
        let iov,actual,checksum = encrypt chan iov in
        let ev = set name ev [Iov iov] in
        let seqno = chan.seqno in
        chan.seqno <- succ chan.seqno ;
        dn ev abv (Data(seqno,actual,checksum))
      )

  | _ -> dn ev abv NoHdr

  and dnnm_hdlr = dnnm

in {up_in=up_hdlr;uplm_in=uplm_hdlr;upnm_in=upnm_hdlr;dn_in=dn_hdlr;dnnm_in=dnnm_hdlr}

(* Layer constructor: passes [init], [hdlrs] and [FullNoHdr NoHdr]
 * to Layer.hdr.  Kept eta-expanded so the binding stays a function.
 *)
let l args vs =
  let no_hdr = FullNoHdr NoHdr in
  Layer.hdr init hdlrs None no_hdr args vs

(* Register this layer under [name] when the module is loaded. *)
let _ =
  let layer = Layer.init l in
  Layer.install name layer
    
(**************************************************************)
