package il.ac.idc.storage;

import java.security.*;
import java.io.*;
import java.util.*;

/**
 * A bogus cryptography class, but good enough for the purposes of this
 * class.
 *
 * <p>Ciphertext layout produced by {@link #enc(byte[])} and consumed by
 * {@link #dec(byte[])}:
 * <pre>
 *   [ 8 bytes] nonce ("uniq"), big-endian long
 *   [16 bytes] MD5(salt || plaintext)
 *   [ 8 bytes] salt, XOR-masked
 *   [ n bytes] plaintext, XOR-masked
 * </pre>
 * The XOR mask comes from {@link java.util.Random} seeded with
 * {@code key ^ nonce}. This is NOT real encryption. It obscures the data and
 * lets {@code dec} notice (with high probability) a wrong key or altered
 * ciphertext. It is not a secure MAC.
 *
 * <p>Instances are not safe for concurrent use: {@code enc} and {@code dec}
 * share one {@link MessageDigest} instance.
 */
public class CryptKey {
    /** Bytes used by the nonce at the head of each ciphertext. */
    private static final int NONCE_LEN = 8;
    /** Bytes in an MD5 digest. */
    private static final int HASH_LEN = 16;
    /** Bytes of random salt prepended to the plaintext before masking. */
    private static final int SALT_LEN = 8;
    /** Offset of the masked body: nonce + hash. */
    private static final int HEADER_LEN = NONCE_LEN + HASH_LEN;
    /** Total size difference between ciphertext and plaintext (32 bytes). */
    private static final int OVERHEAD = HEADER_LEN + SALT_LEN;

    private final long key;
    private final Random uniqGen;
    private final MessageDigest md;

    /**
     * Creates a cipher bound to the given key.
     *
     * @param key secret key mixed into the mask seed
     * @throws NoSuchAlgorithmException if the platform has no MD5 provider
     */
    public CryptKey(long key) throws NoSuchAlgorithmException {
        this.key = key;
        // SecureRandom for the per-message nonce and salt; it extends Random,
        // so the field's declared type is unchanged.
        uniqGen = new SecureRandom();
        md = MessageDigest.getInstance("MD5");
    }

    /**
     * Takes a byte array, and returns an encrypted byte array which is
     * 32 bytes longer than the original.
     *
     * @param msg plaintext; must not be null
     * @return ciphertext of length {@code msg.length + 32}
     * @throws IOException declared for compatibility; writes go to an
     *         in-memory buffer and do not fail in practice
     */
    public byte[] enc(byte[] msg) throws IOException {
        long uniq = uniqGen.nextLong();
        byte[] salt = longToByteArray(uniqGen.nextLong());

        // body = salt || msg; it is hashed first, then masked in place.
        byte[] body = new byte[SALT_LEN + msg.length];
        System.arraycopy(salt, 0, body, 0, SALT_LEN);
        System.arraycopy(msg, 0, body, SALT_LEN, msg.length);

        md.reset();
        byte[] hash = md.digest(body);
        xorWithMask(body, uniq);

        ByteArrayOutputStream baos = new ByteArrayOutputStream(msg.length + OVERHEAD);
        DataOutputStream dos = new DataOutputStream(baos);
        dos.writeLong(uniq);  // NONCE_LEN bytes
        dos.write(hash);      // HASH_LEN bytes
        dos.write(body);      // masked salt + message
        return baos.toByteArray();
    }

    /**
     * Takes a byte array, and returns a decrypted byte array which is 32
     * bytes shorter than the original.
     *
     * @param msg ciphertext as produced by {@link #enc(byte[])}; must not be null
     * @return the plaintext, or {@code null} if {@code msg} is shorter than
     *         32 bytes or its embedded hash does not verify (wrong key or
     *         corrupted data)
     * @throws IOException declared for compatibility; reads come from an
     *         in-memory buffer whose length is checked first
     */
    public byte[] dec(byte[] msg) throws IOException {
        // Anything shorter cannot contain nonce + hash + salt.
        if (msg.length < OVERHEAD) {
            return null;
        }

        DataInputStream dis = new DataInputStream(new ByteArrayInputStream(msg));
        long uniq = dis.readLong();          // nonce at the head
        byte[] msgHash = new byte[HASH_LEN];
        dis.readFully(msgHash);              // then the 16-byte MD5 digest

        byte[] body = Arrays.copyOfRange(msg, HEADER_LEN, msg.length);
        xorWithMask(body, uniq);             // body is now salt || plaintext

        md.reset();
        // Constant-time comparison of the recomputed and stored digests.
        if (!MessageDigest.isEqual(md.digest(body), msgHash)) {
            return null;
        }
        return Arrays.copyOfRange(body, SALT_LEN, body.length);
    }

    /**
     * XORs {@code data} in place with mask bytes drawn from
     * {@code new Random(key ^ uniq)}. Applying it twice restores the input.
     */
    private void xorWithMask(byte[] data, long uniq) {
        byte[] mask = new byte[data.length];
        new Random(key ^ uniq).nextBytes(mask);
        for (int i = 0; i < data.length; i++) {
            data[i] ^= mask[i];
        }
    }

    /** Encodes {@code input} as 8 bytes, least-significant byte first. */
    byte[] longToByteArray(long input) {
        byte[] output = new byte[8];
        for (int i = 0; i < 8; i++) {
            output[i] = (byte) (input & 0xff);
            input = input >> 8;
        }
        return output;
    }
}
