package JavaGroups.JavaStack.Protocols;

import java.io.*;
import java.util.*;
import java.security.*;
import JavaGroups.*;
import JavaGroups.JavaStack.*;

import java.math.BigInteger;
import java.security.spec.*;
import java.security.interfaces.*;
import javax.crypto.*;
import javax.crypto.spec.*;
import javax.crypto.interfaces.*;
import au.net.aba.crypto.provider.ABAProvider;
import au.net.aba.crypto.provider.*;

/**
 * Header attached by the ENCRYPT protocol to the messages it produces.
 * The {@link #type} field tells the receiving ENCRYPT layer how to treat
 * the message: either as an encrypted payload or as one step of the key
 * exchange between a member and the key server.
 */
class EncryptHeader implements Serializable {
	/** One of the type constants declared below. */
	int type;

	/** Payload was encrypted with the shared secret key. */
	static final int ENCRYPT         = 0;
	/** Member asks the key server for the shared key; payload is the member's encoded public key. */
	static final int KEY_REQUEST     = 1;
	/** Key server's reply carrying its own encoded public key. */
	static final int SERVER_PUBKEY   = 2;
	/** Key server's reply carrying the RSA-wrapped shared secret key. */
	static final int SECRETKEY       = 3;
	/** Member's acknowledgement that it has obtained the shared secret key. */
	static final int SECRETKEY_READY = 4;

	/**
	 * @param type one of the type constants of this class
	 */
	public EncryptHeader(int type) {
		this.type = type;
	}

	/**
	 * Returns a short description that names the header type, e.g.
	 * <code>[ENCRYPT: type=KEY_REQUEST]</code>. Unknown type values are
	 * reported numerically instead of being hidden.
	 */
	public String toString() {
		String name;
		switch(type) {
			case ENCRYPT:         name = "ENCRYPT";         break;
			case KEY_REQUEST:     name = "KEY_REQUEST";     break;
			case SERVER_PUBKEY:   name = "SERVER_PUBKEY";   break;
			case SECRETKEY:       name = "SECRETKEY";       break;
			case SECRETKEY_READY: name = "SECRETKEY_READY"; break;
			default:              name = "<unknown " + type + ">"; break;
		}
		return "[ENCRYPT: type=" + name + "]";
	}
}


/**
 * ENCRYPT layer. Encrypt and decrypt the group communication in JavaGroups
 */
public class ENCRYPT extends Protocol {
	// address of this member, set when SET_LOCAL_ADDRESS travels up the stack
	Address  local_addr=null;
	// address of the member currently acting as key server
	Address  keyServerAddr = null;
	// true when this member is the key server
    boolean  keyServer=false;
	// JCE transformation used for the key exchange (settable via property "asymAlgorithm")
	String   asymAlgorithm="RSA";
	// JCE transformation used for message payloads (settable via property "symAlgorithm")
	String   symAlgorithm="DES/ECB/PKCS5Padding";
	int      asymInit=512;					// size passed to KeyPairGenerator.initialize() (property "asymInit")
	int		 symInit=56;					// size passed to KeyGenerator.init() (property "symInit")
	// own public/private key pair, generated in SetProperties()
	KeyPair		Kpair;			// stores this member's public/private key
	// shared secret key; null while this member has not obtained one
	SecretKey	desKey=null;
	PublicKey   pubKey = null;               // meant for the server to hold a client's public key temporarily; never assigned by code in this class
	PublicKey   serverPubKey = null;         // for a client to store the key server's public key
	// cipher for the symmetric transformation (message payloads)
	Cipher		cipher;
	// cipher for the asymmetric transformation (wrapping the shared key)
	Cipher		rsa;
	// copy of the member list taken from the most recent VIEW_CHANGE
	Vector		members=new Vector();
	// key server only: addresses that requested the key but have not yet sent SECRETKEY_READY
	Vector      notReady = new Vector();

	/**
	 * Creates the protocol instance and registers an {@link ABAProvider}
	 * with the JVM's list of security providers.
	 */
	public ENCRYPT() {
		ABAProvider abaProvider = new ABAProvider();
		Security.addProvider(abaProvider);
	}


	/**
	 * @return the name of this protocol layer, always <code>"ENCRYPT"</code>
	 */
	public String GetName() {
		return "ENCRYPT";
	}


	/**
	 * Extracts the bare algorithm name from a JCE transformation string of
	 * the form "algorithm/mode/padding".
	 *
	 * @param s transformation string, e.g. "DES/ECB/PKCS5Padding"
	 * @return everything before the first '/', or <code>s</code> itself
	 *         when it contains no '/'
	 */
	private String GetAlgorithm(String s)
	{
		final int slash = s.indexOf("/");
		return slash < 0 ? s : s.substring(0, slash);
	}


    /**
     * Configures the layer from its protocol properties and creates the key
     * material and ciphers it needs.
     *
     * Recognized properties (each is removed from <code>props</code> once read):
     * <ul>
     * <li><code>asymInit</code> - integer size for the asymmetric key pair</li>
     * <li><code>symInit</code> - integer size for the shared symmetric key</li>
     * <li><code>asymAlgorithm</code> - asymmetric transformation name</li>
     * <li><code>symAlgorithm</code> - symmetric transformation name</li>
     * </ul>
     *
     * @param props the properties for this layer
     * @return <code>true</code> when all properties were understood and the
     *         keys and ciphers could be created; <code>false</code> when a
     *         property is unknown, a key size is not an integer, or the
     *         requested algorithms cannot be set up
     */
    public boolean SetProperties(Properties props) {
		String     str;
		this.props=props;

		try {
			// asymmetric key length
			str=props.getProperty("asymInit");
			if (str != null) {
				asymInit = Integer.parseInt(str);
				props.remove("asymInit");
				System.out.println("asymInit = "+asymInit);
			}

			// symmetric key length
			str=props.getProperty("symInit");
			if (str != null) {
				symInit = Integer.parseInt(str);
				props.remove("symInit");
				System.out.println("symInit = "+symInit);
			}
		}
		catch(NumberFormatException e) {
			// report a bad key size as a configuration failure instead of
			// letting an unchecked exception escape to the caller
			System.err.println("ENCRYPT.SetProperties(): key length is not a valid integer: "+e.getMessage());
			return false;
		}

		// asymmetric algorithm name
		str=props.getProperty("asymAlgorithm");
		if (str != null) {
			asymAlgorithm = str;
			props.remove("asymAlgorithm");
		}

		// symmetric algorithm name
		str=props.getProperty("symAlgorithm");
		if (str != null) {
			symAlgorithm = str;
			props.remove("symAlgorithm");
		}

		// anything still left is a property this layer does not know
		if (props.size() > 0) {
			System.err.println("ENCRYPT.SetProperties(): these properties are not recognized:");
			props.list(System.err);
			return false;
		}

		// generate keys according to the specified algorithms
		try{
			// own public/private key pair
			KeyPairGenerator KpairGen = KeyPairGenerator.getInstance(GetAlgorithm(asymAlgorithm));
			KpairGen.initialize(asymInit, new SecureRandom());
			Kpair = KpairGen.generateKeyPair();

			// shared secret key (only used if this member becomes key server)
			KeyGenerator keyGen = KeyGenerator.getInstance(GetAlgorithm(symAlgorithm));
			keyGen.init(symInit);
			desKey = keyGen.generateKey();

			// ciphers take the full transformation string, not just the algorithm name
			rsa = Cipher.getInstance(asymAlgorithm);
			cipher = Cipher.getInstance(symAlgorithm);
		}
		catch(Exception e){
			// without keys and ciphers the layer cannot work: Up()/Down()
			// would dereference null fields, so report the failure
			System.err.println("ENCRYPT.SetProperties(): key/cipher setup failed: "+e);
			return false;
		}
		return true;
    }


    /**
     * Resets the layer. ENCRYPT has nothing to reset here, so the body is
     * intentionally empty.
     */
    public void Reset() {
    }

	/**
	 * Handles an event travelling up the stack.
	 * <ul>
	 * <li><code>SET_LOCAL_ADDRESS</code>: remembers the local address.</li>
	 * <li><code>FIND_INITIAL_MBRS_OK</code>: decides whether this member is
	 *     the key server and, if not, requests the shared key.</li>
	 * <li><code>MSG</code> with an {@link EncryptHeader}: key-exchange
	 *     messages are consumed here; encrypted payloads are decrypted and
	 *     passed up. Encrypted payloads that cannot be decrypted (no shared
	 *     key yet, or the cipher fails) are dropped.</li>
	 * </ul>
	 * Everything else is passed up unchanged.
	 *
	 * NOTE(review): <code>cipher</code> is also used by Down(); if Up() and
	 * Down() can run on different threads the two uses are not coordinated -
	 * confirm against the Protocol threading model.
	 *
	 * @param evt the event received from the layer below
	 */
	public void Up(Event evt) {
		Message msg;
		EncryptHeader hdr;
		Object obj;

		switch(evt.GetType()) {

			case Event.SET_LOCAL_ADDRESS:
				local_addr=(Address)evt.GetArg();
				break;

			case Event.FIND_INITIAL_MBRS_OK:
				handleInitialMembers((Vector)evt.GetArg());
				break;

			case Event.MSG:
				msg=(Message)evt.GetArg();
				obj=msg.PeekHeader();

				// not produced by an ENCRYPT layer: pass up untouched
				if (!(obj instanceof EncryptHeader))
					break;
				hdr = (EncryptHeader)msg.RemoveHeader();

				switch(hdr.type){

					// key server side: a client asks for the shared key
					case EncryptHeader.KEY_REQUEST:
						handleKeyRequest(msg);
						return;

					// key server side: the client now holds the shared key
					case EncryptHeader.SECRETKEY_READY:
						notReady.removeElement(msg.GetSrc());
						return;

					// client side: remember the key server's public key
					case EncryptHeader.SERVER_PUBKEY:
						serverPubKey = generatePubKey(msg.GetBuffer());
						return;

					// client side: unwrap the shared key and acknowledge it
					case EncryptHeader.SECRETKEY:
						handleSecretKey(msg);
						return;

					default: break;
				}

				if (hdr.type != EncryptHeader.ENCRYPT) System.out.println("This is ERROR");

				// drop payloads we cannot turn back into plaintext
				if (!decryptPayload(msg))
					return;
				break;
		}
		PassUp(evt);            // Pass up to the layer above us
	}

	/**
	 * Processes the result of the initial member discovery: with no
	 * responses this member becomes the key server, otherwise the
	 * coordinator named in the first response is the key server and is
	 * asked for the shared key.
	 */
	private void handleInitialMembers(Vector initial_mbrs) {
		// check for null before touching the vector
		boolean found = initial_mbrs != null && initial_mbrs.size() > 0;

		keyServer = !found;
		if (found)
			keyServerAddr = (Address)((PingRsp)initial_mbrs.firstElement()).coord_addr;
		else
			keyServerAddr = local_addr;

		// plain concatenation prints "null" instead of throwing when the address is unknown
		System.out.println("keyServer = " + keyServer + "  keyServerAddr : "+keyServerAddr);

		if (!keyServer)
		{
			// the locally generated key is not the group's key; forget it
			desKey = null;

			if (keyServerAddr == null) {
				// a null destination would not address the key server
				System.err.println("ENCRYPT.Up(): key server address unknown, shared key not requested");
				return;
			}

			// send our public key to the key server and request the shared key
			Message newMsg = new Message(keyServerAddr,local_addr,Kpair.getPublic().getEncoded());
			newMsg.AddHeader(new EncryptHeader(EncryptHeader.KEY_REQUEST));
			PassDown(new Event(Event.MSG,newMsg));
		}
	}

	/**
	 * Key server side of the exchange: records the requester as not yet
	 * ready, sends it the server's public key, then sends the shared key
	 * wrapped first with the server's private key and then with the
	 * requester's public key.
	 */
	private void handleKeyRequest(Message msg) {
		Message newMsg;
		try{
			// remember the requester until it confirms the key; avoid duplicate
			// entries because SECRETKEY_READY removes only one occurrence
			if (!notReady.contains(msg.GetSrc()))
				notReady.addElement(msg.GetSrc());

			// the requester's public key travels as the message payload
			PublicKey clientPubKey = generatePubKey(msg.GetBuffer());

			// send server's public key
			newMsg = new Message(msg.GetSrc(), local_addr, Kpair.getPublic().getEncoded());
			newMsg.AddHeader(new EncryptHeader(EncryptHeader.SERVER_PUBKEY));
			PassDown(new Event(Event.MSG, newMsg));

			// inner layer: encrypt the encoded shared key with the server's private key
			rsa.init(Cipher.ENCRYPT_MODE, Kpair.getPrivate());
			byte [] signedKey = rsa.doFinal(desKey.getEncoded());

			// outer layer: encrypt the result with the requester's public key
			rsa.init(Cipher.ENCRYPT_MODE, clientPubKey);
			byte [] encryptedKey = rsa.doFinal(signedKey);

			// send the wrapped shared key to the requester
			newMsg = new Message(msg.GetSrc(), local_addr, encryptedKey);
			newMsg.AddHeader(new EncryptHeader(EncryptHeader.SECRETKEY));
			PassDown(new Event(Event.MSG, newMsg));
		}
		catch(Exception e){
			System.out.println(e+"0");
		}
	}

	/**
	 * Client side of the exchange: removes both RSA layers from the received
	 * shared key, installs it and confirms with SECRETKEY_READY. No
	 * confirmation is sent when the key cannot be recovered.
	 */
	private void handleSecretKey(Message msg) {
		try{
			// outer layer: decrypt using our private key
			rsa.init(Cipher.DECRYPT_MODE,Kpair.getPrivate());
			byte[] signedKey = rsa.doFinal(msg.GetBuffer());

			// inner layer: decrypt using the server's public key
			rsa.init(Cipher.DECRYPT_MODE,serverPubKey);
			byte[] encodedKey = rsa.doFinal(signedKey);

			// decode secretKey
			desKey = decodedKey(encodedKey);
			if (desKey == null) {
				// decodedKey() already reported the cause; do not claim readiness
				System.err.println("ENCRYPT.Up(): shared secret key could not be decoded");
				return;
			}
			System.out.println("Client generate shared secret key");

			// send ready message
			Message newMsg = new Message(msg.GetSrc(), local_addr, null);
			newMsg.AddHeader(new EncryptHeader(EncryptHeader.SECRETKEY_READY));
			PassDown(new Event(Event.MSG, newMsg));
		}
		catch(Exception e){
			System.out.println(e+"5");
		}
	}

	/**
	 * Decrypts the payload of <code>msg</code> in place with the shared key.
	 *
	 * @return <code>true</code> when the message may be passed up (payload
	 *         decrypted, or there was no payload); <code>false</code> when
	 *         it must be dropped because no shared key is available or
	 *         decryption failed
	 */
	private boolean decryptPayload(Message msg) {
		SecretKey key = desKey;

		// not have shared key yet: this encrypted message is of no use
		if (key == null) return false;

		if (msg.GetBuffer() == null) return true;

		try{
			cipher.init(Cipher.DECRYPT_MODE, key);
			msg.SetBuffer(cipher.doFinal(msg.GetBuffer()));
			return true;
		}
		catch(Exception e){
			// the payload is still ciphertext; handing it to the layers
			// above would only feed them garbage
			System.out.println(e+"6");
			return false;
		}
	}

    /**
     * Handles an event travelling down the stack.
     * <ul>
     * <li><code>VIEW_CHANGE</code>: copies the new member list. When the
     *     group shrank, the shared key is renewed: the first member of the
     *     view generates a new key, every other member requests it.</li>
     * <li><code>MSG</code>: members that have not confirmed the shared key
     *     yet (<code>notReady</code>) are served in plaintext - a multicast
     *     is additionally unicast to each of them, a unicast addressed to
     *     one of them is passed down unencrypted. All other payloads are
     *     encrypted with the shared key when one is available.</li>
     * </ul>
     * The event is finally passed to the layer below.
     *
     * NOTE(review): when encryption throws, the message still goes down in
     * plaintext (best effort, as before) - confirm this is acceptable.
     *
     * @param evt the event received from the layer above
     */
    public void Down(Event evt) {
		Message msg;
		Message newMsg;
		boolean leave = false;

		switch(evt.GetType()) {
			case Event.VIEW_CHANGE:
				Vector new_members=(Vector)((View)evt.GetArg()).GetMembers();
				// a missing member list is treated like an empty one
				int new_size = (new_members == null) ? 0 : new_members.size();
				Object coord = null;

				synchronized(members) {
					// member size decreases: member leaves, need a new key
					leave = members.size() > new_size;

					// copy member list
					members.removeAllElements();
					for (int i=0; i < new_size; i++)
						members.addElement(new_members.elementAt(i));

					// first member of the view acts as coordinator/key-server
					if (!members.isEmpty())
						coord = members.firstElement();
				}

				// redistribute/regain the new key because old member leaves
				if (leave && coord != null)
					renewSharedKey(coord);
				break;

			case Event.MSG:
				msg=(Message)evt.GetArg();

				// For Server:
				// if some members don't have the shared key yet
				if (!notReady.isEmpty()){
					System.out.println("not Ready list  :"+ notReady.toString());
					if (msg.GetDest() == null){
						// multicast: give every not-ready member a plaintext copy ...
						for (int i = 0; i < notReady.size(); i++){
							newMsg = new Message(notReady.elementAt(i), local_addr, msg.GetBuffer());
							PassDown(new Event(Event.MSG, newMsg));
						}
						// ... and continue below so the multicast itself is
						// still encrypted for the members that hold the key
					}
					else if (notReady.contains(msg.GetDest())){
						// unicast to a not-ready member (compared with equals(),
						// addresses need not be the same object): send as is
						PassDown(evt);
						return;
					}
				}

				// I already know the shared key and the message is not empty: encrypt it
				if (desKey != null && msg.GetBuffer() != null) {
					try {
						cipher.init(Cipher.ENCRYPT_MODE, desKey);
						msg.SetBuffer(cipher.doFinal(msg.GetBuffer()));
						msg.AddHeader(new EncryptHeader(EncryptHeader.ENCRYPT));
					}
					catch (Exception e) {
						System.out.println(e+"8");
					}
				}
				break;
		}
		//System.out.println("Pass Down: "+evt.toString());
		PassDown(evt);          // Pass on to the layer below us
    }

	/**
	 * Replaces the shared key after a member left the group. If
	 * <code>coord</code> is this member it becomes the key server and
	 * generates a fresh key; otherwise the current key is discarded and a
	 * new one is requested from <code>coord</code>.
	 */
	private void renewSharedKey(Object coord) {
		// if I'm the coordinator/key-server
		if (coord.equals(local_addr)){
			keyServer = true;
			keyServerAddr = local_addr;

			// reset shared key
			desKey=null;

			try {
				//generate new shared key
				KeyGenerator keyGen = KeyGenerator.getInstance(GetAlgorithm(symAlgorithm));
				keyGen.init(symInit);
				desKey = keyGen.generateKey();
			}
			catch (Exception e) {
				System.out.println(e+"7");
			}
		}
		// if I'm not the coordinator/key-server
		else {
			keyServer = false;
			keyServerAddr = (Address)coord;

			// reset shared key
			desKey = null;

			// send our public key to the key server and request the new shared key
			Message request = new Message(keyServerAddr, local_addr, Kpair.getPublic().getEncoded());
			request.AddHeader(new EncryptHeader(EncryptHeader.KEY_REQUEST));
			PassDown(new Event(Event.MSG, request));
			System.out.println("Request new key");
		}
	}

	/**
	 * Rebuilds a secret key for the configured symmetric algorithm from its
	 * encoded bytes.
	 *
	 * @param encodedKey the encoded key material
	 * @return the reconstructed key, or <code>null</code> when it could not
	 *         be built (the exception is printed to <code>System.err</code>)
	 */
	private SecretKey decodedKey(byte[] encodedKey){
		try{
			final String algorithm = GetAlgorithm(symAlgorithm);
			final SecretKeyFactory factory = SecretKeyFactory.getInstance(algorithm);
			final KeySpec spec = new SecretKeySpec(encodedKey, algorithm);
			return factory.generateSecret(spec);
		}
		catch(Exception e){
			System.err.println(e);
			return null;
		}
	}

	/**
	 * Rebuilds a public key for the configured asymmetric algorithm from its
	 * X.509 encoded form.
	 *
	 * @param encodedKey the X.509 encoded public key
	 * @return the reconstructed key, or <code>null</code> when it could not
	 *         be built (the exception is printed to <code>System.err</code>)
	 */
	private PublicKey generatePubKey(byte [] encodedKey){
		try{
			final KeyFactory factory = KeyFactory.getInstance(GetAlgorithm(asymAlgorithm));
			final KeySpec spec = new X509EncodedKeySpec(encodedKey);
			return factory.generatePublic(spec);
		}
		catch(Exception e){
			System.err.println(e);
			return null;
		}
	}
}