(**************************************************************)
(*
 *  Ensemble, (Version 0.40)
 *  Copyright 1997 Cornell University
 *  All rights reserved.
 *
 *  See ensemble/doc/license.txt for further information.
 *)
(**************************************************************)
(**************************************************************)
(* SIGNED: MD5 signatures, 16-byte md5 connection ids. *)
(* Author: Mark Hayden, 3/97 *)
(**************************************************************)
open Trans
open Util
(**************************************************************)
(* Name of this module as produced by Trace.source_file.  It tags the
 * failure messages, the log output and the buffer/iovec debugging
 * arguments used throughout this file.
 *)
let name = Trace.source_file "SIGNED"

(* Shadow failwith so that every failure raised from this file has
 * its message run through Util.failmsg together with the module name.
 *)
let failwith msg =
  let tagged = Util.failmsg name msg in
  failwith tagged

(* Logging function for this module, obtained from Trace.log.
 *)
let log = Trace.log name ""
(**************************************************************)

(* Boolean security flag.
 * NOTE(review): the alias is not referenced elsewhere in this file;
 * presumably it names the boolean that the handlers in f are applied
 * to (true on the signature-verified path, false otherwise) -- confirm
 * against the interface file.
 *)
type secure = bool

(* The SIGNED router value: the result of applying Route.make to the
 * helper functions defined in this body.
 *)
let f =
  (* A connection id is packed by applying Conn.hash_of_id to it.
   *)
  let pack_of_conn = Conn.hash_of_id in
  (* Marshalling pair obtained from Mbuf for this module.  marshal is
   * used on the send side and unmarshal on the receive side for the
   * optional object carried by a message.
   *)
  let (marshal,unmarshal) = Mbuf.make_marsh name Mbuf.global in

  (* Wrap a handler, with its kind and rank, in the Route.Signed
   * constructor.
   *)
  let const hdlr kind rank = Route.Signed(kind,rank,hdlr) in

  (* [recv pack key hdlr kind rank] builds the packet handler for one
   * connection.  [pack] is the connection id expected at the start of
   * every packet, [key] is handed to Security.sign, and [hdlr] yields
   * the upcalls: it is applied once with [kind], [rank] and true for
   * the secure upcall, and once with Conn.Other, (-1) and false for
   * the insecure upcall.
   *
   * Packet layout as read below (offsets relative to [ofs]):
   *   0            connection id            (md5len bytes)
   *   md5len       integer mi
   *   md5len+4     integer molen, size of the marshalled object
   *   md5len+8     marshalled object (molen bytes), then the payload
   *   len-md5len   signature                (md5len bytes)
   * The signature is computed over everything that precedes it.
   *
   * Packets whose signature verifies are passed to the secure upcall
   * with mi, the unmarshalled object (if molen is not 0) and the
   * payload iovec.  Packets whose signature does not verify are
   * passed to the insecure upcall, which only gets the payload iovec.
   * Packets that are too short, carry another connection id, or have
   * an inconsistent molen are dropped through Route.drop.
   *)
  let recv pack key hdlr kind rank =
    let secure = hdlr kind rank true in
    let insecure = hdlr Conn.Other (-1) false in

    (* Scratch buffers: allocated once per handler and overwritten for
     * every packet.
     *)
    let signature = String.create md5len in
    let pack_s = String.copy pack in

    fun rbuf ofs len ->
      if len < 2 * md5len + 8 then (
        Route.drop (fun () -> sprintf "Signed:size below minimum:len=%d\n" len) ;
      ) else (
        (* Extract connection info.
         *)
        let buf = Refcnt.read name rbuf in
        String.blit buf ofs pack_s 0 md5len ;
        if pack_s <> pack then (
          Route.drop (fun () -> sprintf "Signed:rest of Conn.id did not match") ;
        ) else (
          (* Extract signature info.
           *)
          String.blit buf (ofs+len-md5len) signature 0 md5len ;

          (* Calculate the signature for the message
           *)
          let sign_cpt = Security.sign key buf ofs (len-md5len) in

          (* Get the integer value and the length of the marshalled
           * portion of the message.
           *)
          let mi    = Hsys.pop_int buf (ofs + md5len) in
          let molen = Hsys.pop_int buf (ofs + md5len + 4) in

          (* Advance past previous fields.
           * Also subtract off the signature space at the end.
           *)
          let ofs = ofs + md5len + 8 in
          let len = len - md5len - 8 - md5len in

          if molen < 0 then (
            (* molen comes straight off the wire.  A negative value
             * would pass the length comparisons below and make the
             * iovec start before the payload and extend past its end,
             * so such packets are rejected up front.
             *)
            Route.drop (fun () -> sprintf "Signed:negative marshalled length:len=%d:molen=%d\n" len molen) ;
          ) else if sign_cpt = signature then (
            if molen = 0 then (
              (* No marshalled object: everything left is payload.
               *)
              let mv = Iovec.alloc name rbuf (ofs+molen) (len-molen) in
              secure mi None [| mv |]
            ) else if len >= molen then (
              let mo = Some(unmarshal buf ofs molen) in
              let mv = Iovec.alloc name rbuf (ofs+molen) (len-molen) in
              secure mi mo [| mv |]
            ) else (
              Route.drop (fun () -> sprintf "Signed:short message:len=%d:molen=%d\n" len molen) ;
            )
          ) else (
            if len < molen then (
              Route.drop (fun () -> sprintf "Signed:(insecure) short message:len=%d:molen=%d\n" len molen) ;
            ) else (
              (* Only information passed up is the iovec.
               *)
              log (fun () -> sprintf "insecure message") ;
              let mv = Iovec.alloc name rbuf (ofs+molen) (len-molen) in
              insecure (-1) None [|mv|]
            )
          )
        )
      )
  in

  (* [merge upcalls] turns every entry of [upcalls] into a receive
   * handler with recv and combines the handlers with Route.merge3.
   * Every entry must carry a Route.Signed handler; anything else
   * raises a failure with the message sanity.
   *)
  let merge upcalls =
    let handler_of = function
      | (_,pack,key,Route.Signed(kind,rank,upcall)) ->
          recv pack key upcall kind rank
      | _ -> failwith "sanity"
    in
    Route.merge3 (Array.map handler_of upcalls)
  in

  (* [blast _ xmitv key pack conn] builds the send function for one
   * connection.  The returned function takes an ignored argument, the
   * integer [mi], an optional object [mo] and the payload iovec array
   * [mv].  It transmits through [xmitv], in this order: the header
   * (connection id plus the two integer fields), the marshalled
   * object, the payload, and a trailing signature that Security.sign
   * computes with [key] over all of the preceding parts.
   *)
  let blast _ xmitv key pack conn =
    (* Header iovec: the connection id followed by 8 bytes reserved
     * for the two integer fields.  It is created once here and
     * rewritten in place on every send.
     *)
    let header = Iovec.of_string name (pack ^ String.create 8) in

    fun _ mi mo mv ->
      (* Marshal the optional object; no object gives an empty iovec.
       *)
      let obj =
        match mo with
        | Some v -> marshal v
        | None -> Iovec.empty name
      in
      let obj_len = Iovec.len name obj in

      (* Fill in the integer fields that follow the connection id.
       *)
      Iovec.write_hack name header (fun buf ofs _ ->
        Hsys.push_int buf (ofs+md5len) mi ;
        Hsys.push_int buf (ofs+md5len+4) obj_len
      ) ;

      (* Fresh array of the parts covered by the signature:
       * header, marshalled object, payload.
       *)
      let signed_part () = Array.append [|header; obj|] mv in

      (* Calculate the signature over the flattened parts.
       *)
      let digest =
        let flat = Iovecl.flatten name (signed_part ()) in
        Iovec.read name flat (fun buf ofs len ->
          Security.sign key buf ofs len
        )
      in
      let trailer = Iovec.of_string name digest in

      xmitv (Array.append (signed_part ()) [|trailer|])
  in

  (* Assemble the router from the module name and the helpers defined
   * earlier in this body.
   * NOTE(review): ident is not defined in this file; presumably it
   * comes from one of the opened modules (Trans or Util) -- confirm.
   *)
  Route.make
    name
    ident
    const
    pack_of_conn
    merge
    blast

(**************************************************************)
