/*********************************************************************
  RSAsol.java         CS100m-2005fa         Project No. 4

  -- A simplified RSA encryption implementation.
  
 *************
  THIS SOLUTION FILE IS RSAsol.java, SO EXECUTE THE PROGRAM USING THE
  NAME RSAsol, not just RSA.  For example, to generate keys, type
      java RSAsol genkeys
  *************

  About efficiency:
      Since this is an introductory level project and all inputs are small,
  maximum running time efficiency is not required from your algorithm.
  However, it should not be too inefficient, e.g. running for more
  than 2 seconds on any single genKey/encrypt/decrypt/crack operation.

  @ Yunpeng Li, Oct 19, 2005

***********************************************************************/


/* The public class - should always have the same name as the .java file.
 All methods are static so that they can be called without using objects. */
/**
 * A simplified, small-number RSA implementation: key generation, encryption,
 * decryption and a brute-force "crack" that factors the modulus.
 *
 * <p>All methods are static so that they can be called without creating an
 * object. The primes are deliberately tiny (see {@link #PRIME_LOW_BOUND} and
 * {@link #PRIME_HIGH_BOUND}), so this class is for teaching only and offers no
 * real security.
 */
public class RSAsol {

    /****** Constants ******/

    /*
     * Approximate bounds for randomly generated primes. Only very small primes
     * (not more than '7.5' bits = 2^7 + 2^6) are used so that the modulus and
     * the keys comfortably fit into a 32-bit signed int.
     */

    /** Inclusive lower bound for generated primes ('big' enough). */
    public static final int PRIME_LOW_BOUND = 61;

    /** Exclusive upper bound for generated primes (not too big). */
    public static final int PRIME_HIGH_BOUND = 192;

    /** Approximate lower bound for the encryption key, so it is not 'too' small. */
    public static final int KEY_LOW_BOUND = 31;

    /** The max allowed value of the message text (integer) to be encrypted. */
    public static final int MAX_INPUT_MESSAGE =
            PRIME_LOW_BOUND * PRIME_LOW_BOUND - 1;

    /** Command-line help text printed by {@link #main(String[])}. */
    private static final String USAGE_MSG =
            "Usage: java RSAsol <mode> [args ...]\n" +
            "mode = <genkeys | enc | dec | crack>\n" +
            "args:\n" +
            "  genkeys - (no args)\n" +
            "  enc - m (plaintext), e (encryption key), c (modulus cipher)\n" +
            "  dec - s (ciphertext), d (decryption key), c (modulus cipher)\n" +
            "  crack - e (encryption key), c (modulus cipher)\n" +
            "addition modes to help debugging: <testprime | testgcd>\n" +
            "  testprime - n  (call your testPrime to test whether n is prime\n" +
            "  testgcd - a, b (call your gcd to find the gcd of a and b)\n" +
            "All arguments must be integers.\n" +
            "Examples:\n" +
            "  java RSAsol genkeys\n" +
            "  java RSAsol enc 123 17 3233\n";


    /****** Methods ******/

    /**
     * Tests whether a number is prime by trial division.
     *
     * <p>Only odd candidates up to sqrt(n) are tried: if n has any non-trivial
     * divisor, it has one that is not larger than its square root.
     *
     * @param n number to be tested
     * @return {@code true} if n is prime, {@code false} otherwise (including
     *         every n smaller than 2)
     */
    public static boolean testPrime(int n) {
        if (n < 2) {
            return false; // the smallest prime is 2
        }
        if (n % 2 == 0) {
            return n == 2; // 2 is the only even prime
        }
        // (long) k * k avoids both floating point and int overflow near 2^31.
        for (int k = 3; (long) k * k <= n; k += 2) {
            if (n % k == 0) {
                return false; // found a divisor of n, so n is not prime
            }
        }
        return true;
    }

    /**
     * Randomly generates a prime number in the range
     * [PRIME_LOW_BOUND, PRIME_HIGH_BOUND).
     *
     * <p>A random odd starting point is chosen and then decreased in steps of
     * two until a prime is hit. The search cannot drop below PRIME_LOW_BOUND
     * because 61 is itself an odd prime.
     *
     * @return a prime in the range described above
     */
    public static int genPrime() {
        int n = (int) (Math.random() * (PRIME_HIGH_BOUND - PRIME_LOW_BOUND))
                + PRIME_LOW_BOUND;
        if (n % 2 == 0) { // an even number in this range can't be prime
            n--;
        }
        while (!testPrime(n)) { // walk down the odd numbers until a prime is found
            n -= 2;
        }
        return n;
    }

    /**
     * Finds the greatest common divisor (gcd) of two integers using Euclid's
     * algorithm.
     *
     * <p>The signs of the arguments are ignored, so {@code gcd(12, -8)} is 4.
     * {@code gcd(n, 0)} is {@code |n|}.
     *
     * @param n1 first number
     * @param n2 second number
     * @return the non-negative greatest common divisor of n1 and n2
     */
    public static int gcd(int n1, int n2) {
        // Java's % keeps the sign of the dividend, so the loop also converges
        // for negative operands; the sign is stripped once at the end.
        while (n2 != 0) {
            int remainder = n1 % n2;
            n1 = n2;
            n2 = remainder;
        }
        return Math.abs(n1);
    }

    /**
     * Randomly generates a positive integer that is coprime to the given
     * positive integer n and approximately in the range [KEY_LOW_BOUND, n).
     *
     * <p>The search starts at a random value and walks downwards, so the
     * result may end up below KEY_LOW_BOUND if no coprime value is met earlier.
     *
     * @param n the number the result must be coprime to
     * @return a number whose gcd with n is 1
     */
    public static int genCoPrime(int n) {
        int copn = (int) (Math.random() * (n - KEY_LOW_BOUND)) + KEY_LOW_BOUND;
        if (n % 2 == 0) {
            // n is even, and is hence only coprime to odd numbers
            if (copn % 2 == 0) {
                copn--;
            }
            while (gcd(n, copn) != 1) {
                copn -= 2;
            }
        } else {
            // n is odd, and can be coprime to both odd and even numbers
            while (gcd(n, copn) != 1) {
                copn--;
            }
        }
        return copn;
    }

    /**
     * Generates a public-private key pair and the corresponding modulus cipher.
     *
     * @return the encryption key e, the decryption key d and the cipher p*q
     */
    public static RSAKeys genKeys() {
        // Generate two distinct primes.
        int p = genPrime();
        int q = genPrime();
        while (q == p) {
            q = genPrime(); // p, q mustn't be equal
        }

        // Public key e is coprime to (p-1)*(q-1); private key d is its
        // multiplicative inverse modulo (p-1)*(q-1).
        int totient = (p - 1) * (q - 1);
        int e = genCoPrime(totient);
        int d = privateExponent(e, totient);

        return new RSAKeys(e, d, p * q);
    }

    /**
     * Finds the decryption key that belongs to an encryption key.
     *
     * <p>Returns the smallest positive d for which {@code d*e - 1} is a
     * positive multiple of {@code totient}, i.e. the smallest d with
     * {@code d*e = x*totient + 1} for some {@code x >= 1}. It is computed with
     * the iterative extended Euclidean algorithm, which needs only a handful
     * of scalar variables and runs in time logarithmic in the operands.
     *
     * @param e       the encryption key, must be positive and coprime to totient
     * @param totient the value (p-1)*(q-1), must be positive
     * @return the matching decryption key d
     * @throws IllegalArgumentException if e or totient is not positive, or if
     *         e has no inverse modulo totient
     */
    private static int privateExponent(int e, int totient) {
        if (e <= 0 || totient <= 0) {
            throw new IllegalArgumentException(
                    "key and totient must be positive (e = " + e
                    + ", totient = " + totient + ")");
        }

        // Invariant: prevCoef * e == prevRem (mod totient), and likewise for
        // coef/rem. long is used so the intermediate products cannot overflow.
        long prevRem = totient;
        long rem = e % totient;
        long prevCoef = 0;
        long coef = 1;
        while (rem != 0) {
            long quotient = prevRem / rem;

            long nextRem = prevRem - quotient * rem;
            prevRem = rem;
            rem = nextRem;

            long nextCoef = prevCoef - quotient * coef;
            prevCoef = coef;
            coef = nextCoef;
        }
        if (prevRem != 1) { // prevRem now holds gcd(e, totient)
            throw new IllegalArgumentException(
                    "encryption key " + e + " is not coprime to " + totient);
        }

        long d = prevCoef % totient;
        if (d < 0) {
            d += totient;
        }
        // Require d*e > 1 so that d*e - 1 is a positive multiple of the
        // totient; this only matters for the degenerate cases e == 1 and
        // totient == 1.
        while (d * e <= 1) {
            d += totient;
        }
        return (int) d;
    }

    /**
     * Finds the remainder of pow(m, k) (i.e. m to the power of k) divided by c.
     * This is the actual encryption/decryption process; it is called by both
     * {@link #encrypt(int, int, int)} and {@link #decrypt(int, int, int)}.
     *
     * <p>One cannot simply raise m to the power of k and then compute the
     * modulo, since the power overflows almost immediately. Instead the
     * identity (a*b) mod c = ((a mod c)(b mod c)) mod c is applied after every
     * multiplication, and the exponent is processed bit by bit
     * (square-and-multiply), so only about log2(k) steps are needed. The
     * products are formed in {@code long}, which cannot overflow because both
     * factors are smaller than c in magnitude.
     *
     * @param m the message text, a number in this project (&gt; 0)
     * @param k the encryption/decryption key (&gt; 0); for k &lt;= 0 the
     *          result is 1
     * @param c the cipher p*q (&gt; 0)
     * @return pow(m, k) mod c
     * @throws ArithmeticException if c is 0 and k is positive
     */
    public static int powMod(int m, int k, int c) {
        if (k <= 0) {
            return 1; // empty product
        }

        long modulus = c;
        long base = m % modulus;
        long result = 1;
        for (int exponent = k; exponent > 0; exponent >>= 1) {
            if ((exponent & 1) == 1) {
                result = (result * base) % modulus;
            }
            base = (base * base) % modulus;
        }
        return (int) result;
    }


    /**
     * Encryption function.
     *
     * <p>Note that the result is computed modulo c, so a message that is not
     * smaller than c cannot be recovered by {@link #decrypt(int, int, int)}.
     * Terminates the JVM with exit status 1 if m is negative.
     *
     * @param m the original message, or plaintext (int, &lt;= MAX_INPUT_MESSAGE)
     * @param e the public encryption key
     * @param c the modulus cipher p*q
     * @return the encrypted message (i.e. ciphertext, which is an int), or -1
     *         if m is larger than MAX_INPUT_MESSAGE
     */
    public static int encrypt(int m, int e, int c) {
        // Catch some invalid inputs
        if (m < 0) {
            System.out.println("STOP: plaintext is a negative number!");
            System.exit(1);
        }
        if (m > MAX_INPUT_MESSAGE) {
            System.out.println("STOP: plaintext is too big! (max = " +
                               MAX_INPUT_MESSAGE + ")");
            return -1; // Just to make it more "Dr. Java friendly" -
                       // -- Should really be System.exit(1);
        }
        // Compute the ciphertext from original message using power-mod
        return powMod(m, e, c);
    }

    /**
     * Decryption function.
     *
     * <p>Terminates the JVM with exit status 1 if s is negative.
     *
     * @param s the encrypted message (i.e. ciphertext)
     * @param d the private decryption key
     * @param c the modulus cipher p*q
     * @return the decrypted message, i.e. the original message (or plaintext,
     *         as an int), or -1 if s is not smaller than c
     */
    public static int decrypt(int s, int d, int c) {
        // Catch some invalid inputs
        if (s < 0) {
            System.out.println("STOP: ciphertext is a negative number!" +
                             " -- Shouldn't happen.");
            System.exit(1);
        }
        if (s >= c) {
            System.out.println("STOP: ciphertext is not smaller than the cipher!" +
                               "\n-- Something is wrong.");
            return -1; // Just to make it more "Dr. Java friendly" -
                       // -- Should really be System.exit(1);
        }
        // Decrypting ciphertext to original message using power-mod
        return powMod(s, d, c);
    }

    /**
     * Cracking function -- recovers the private decryption key d from the
     * public encryption key e and the modulus cipher c.
     *
     * <p>The modulus is factored by trial division to obtain the two primes p
     * and q that were used to generate the keys. Once p and q are known, d is
     * derived from e in the same fashion as in {@link #genKeys()}.
     *
     * @param e the public encryption key
     * @param c the modulus cipher
     * @return the private decryption key d
     * @throws IllegalArgumentException if c cannot be split into two factors
     *         greater than 1, or if e is not a valid key for c
     */
    public static int crack(int e, int c) {
        // Factor the modulus cipher c to obtain the original two primes p, q
        int p = smallestFactor(c);
        int q = (p > 1) ? c / p : 1;
        if (p <= 1 || q <= 1) {
            throw new IllegalArgumentException(
                    "modulus cipher " + c + " is not a product of two primes");
        }

        // Recover d from e, p, q
        return privateExponent(e, (p - 1) * (q - 1));
    }

    /**
     * Finds the smallest factor of c that is greater than 1 and not larger
     * than sqrt(c).
     *
     * @param c the number to factor
     * @return 2 if c is even, otherwise the smallest odd divisor i of c with
     *         3 &lt;= i and i*i &lt;= c, or 1 if there is none
     */
    private static int smallestFactor(int c) {
        if (c % 2 == 0) {
            // cipher is even - not possible with generated keys, but a
            // trivial possibility in general
            return 2;
        }
        // c is odd here, so only odd candidates can divide it.
        for (int i = 3; (long) i * i <= c; i += 2) {
            if (c % i == 0) {
                return i;
            }
        }
        return 1; // no factor found
    }


    /**
     * The main method -- takes command-line arguments as input, calls the
     * appropriate static method, and prints the output.
     *
     * @param args the mode name followed by that mode's integer arguments
     */
    public static void main(String[] args) {
        if (args.length < 1) {
            System.out.println(USAGE_MSG);
            return; // Just to make it more "Dr. Java friendly"
                    // -- Should really be System.exit(1);
        }

        String mode = args[0];
        if (mode.equalsIgnoreCase("genkeys")) {
            System.out.println("Generating keys...");
            RSAKeys keys = genKeys();
            System.out.println("--> Public encryption key e: " + keys.encKey +
                               "\n--> Private decryption key d: " + keys.decKey +
                               "\n--> Modulus cipher c: " + keys.cipher);
            return;
        }
        if (!isNumericMode(mode)) {
            System.out.println("***Unknown mode: " + mode);
            System.out.println(USAGE_MSG);
            return;
        }

        // Every remaining mode takes integer arguments only.
        int[] input = parseIntArgs(args);
        if (input == null) {
            return; // a parse error has already been reported
        }

        if (mode.equalsIgnoreCase("enc")) {
            if (tooFewArgs(input, 4)) {
                return;
            }
            System.out.println("Plaintext m: " + input[1] +
                               "\nEncryption key e: " + input[2] +
                               "\nModulus cipher c: " + input[3] +
                               "\nEncrypting...");
            int s = encrypt(input[1], input[2], input[3]);
            if (s != -1) { // -1 is the "error code"
                System.out.println("--> Ciphertext s: " + s);
            }
        } else if (mode.equalsIgnoreCase("dec")) {
            if (tooFewArgs(input, 4)) {
                return;
            }
            System.out.println("Ciphertext s: " + input[1] +
                               "\nDecryption key d: " + input[2] +
                               "\nModulus cipher c: " + input[3] +
                               "\nDecrypting...");
            int m = decrypt(input[1], input[2], input[3]);
            if (m != -1) { // -1 is the "error code"
                System.out.println("--> Plaintext m: " + m);
            }
        } else if (mode.equalsIgnoreCase("crack")) {
            if (tooFewArgs(input, 3)) {
                return;
            }
            System.out.println("Encryption key e: " + input[1] +
                               "\nModulus cipher c: " + input[2]);
            System.out.println("Cracking...");
            try {
                int d = crack(input[1], input[2]);
                System.out.println("--> Recovered decryption key d: " + d);
            } catch (IllegalArgumentException ex) {
                // Invalid key/modulus: report it instead of looping forever.
                System.out.println("***Cannot crack: " + ex.getMessage());
            }
        } else if (mode.equalsIgnoreCase("testprime")) {
            if (tooFewArgs(input, 2)) {
                return;
            }
            System.out.println("testPrime(" + input[1] + ") = " +
                               testPrime(input[1]));
        } else if (mode.equalsIgnoreCase("testgcd")) {
            if (tooFewArgs(input, 3)) {
                return;
            }
            System.out.println("gcd(" + input[1] + ", " + input[2] + ") = "
                               + gcd(input[1], input[2]));
        }
    }

    /** Tells whether the given mode name is one that takes integer arguments. */
    private static boolean isNumericMode(String mode) {
        return mode.equalsIgnoreCase("enc")
                || mode.equalsIgnoreCase("dec")
                || mode.equalsIgnoreCase("crack")
                || mode.equalsIgnoreCase("testprime")
                || mode.equalsIgnoreCase("testgcd");
    }

    /**
     * Parses args[1..] as integers into an array of the same length as args
     * (index 0 is left at 0). Prints a message and returns null on the first
     * argument that is not an integer.
     */
    private static int[] parseIntArgs(String[] args) {
        int[] input = new int[args.length];
        for (int i = 1; i < args.length; i++) {
            try {
                input[i] = Integer.parseInt(args[i]);
            } catch (NumberFormatException e) {
                System.out.println("***Argument " + i + ": " + args[i] +
                                   " is not an integer!");
                return null;
            }
        }
        return input;
    }

    /**
     * Prints the "too few arguments" message plus usage and returns true when
     * fewer than {@code needed} entries (mode name included) were supplied.
     */
    private static boolean tooFewArgs(int[] input, int needed) {
        if (input.length < needed) {
            System.out.println("***Too few arguments!\n" + USAGE_MSG);
            return true;
        }
        return false;
    }

} // class RSAsol


/**
 * A simple data structure storing a triplet: the public encryption key e,
 * the private decryption key d, and the cipher c = p*q.
 *
 * <p>NOTE: You DON'T have to use this part in the code you write, nor do you
 * need to understand class/objects at the moment. This topic will be covered
 * in later lectures.
 */
class RSAKeys {
    /** The public encryption key e. */
    final int encKey;
    /** The private decryption key d. */
    final int decKey;
    /** The modulus cipher c = p*q. */
    final int cipher;

    /**
     * Stores the given triplet.
     *
     * @param e the public encryption key
     * @param d the private decryption key
     * @param c the modulus cipher
     */
    public RSAKeys(int e, int d, int c) {
        encKey = e;
        decKey = d;
        cipher = c;
    }
}
