package cs2110.assignment1.test;

import static org.testng.AssertJUnit.*;

import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.util.Random;
import javax.tools.JavaCompiler;
import javax.tools.ToolProvider;

import org.testng.annotations.Test;

/**
 * TestNG checks for a single implementation class.
 *
 * <p>The class under test and the listener are supplied through
 * {@link #configure(TestListener, Class)} (or {@link #singleClass(String, Class)}
 * for a stand-alone run). The checks run the configured
 * {@code TestImplementations.testClass} against it and verify that its source
 * is available from {@code JarTest.jar}. Recompiling that source is implemented
 * but its {@code @Test} annotation is currently commented out.
 */
public class TestImplementation {

	/** Number of random directory names tried before {@link #createTempDir()} gives up. */
	private static final int TEMP_DIR_ATTEMPTS = 100;

	/** Parent listener; sub-listeners for individual runs are created from it. */
	private static TestListener listener;
	/** Sub-listener used for the most recent run of the implementation tests. */
	public static TestListener testListener;
	
	// Used by subtests
	public static Class<?> implementationClass = null; 
	public static byte[] sourceCode = null;
	public static Class<?> recompiledClass = null;
	
	/**
	 * Stores the listener and the class under test for the test methods below.
	 *
	 * @param listener            parent listener from which sub-listeners are created
	 * @param implementationClass class whose implementation is being tested
	 */
	public static void configure(TestListener listener, Class<?> implementationClass) {
		TestImplementation.listener = listener;
		TestImplementation.implementationClass = implementationClass;
	}

	/**
	 * Runs {@code TestImplementations.testClass} through a fresh sub-listener
	 * and fails unless that listener reports that all tests passed.
	 */
	@Test(description="Test implementation correctness")
	public void testInterface() {
		testListener = listener.newSubListener();
		testListener.runTest(TestImplementations.testClass);
		assertTrue(testListener.passedAllTests());
	}
	
	/**
	 * Fetches the source of the implementation class from {@code JarTest.jar}
	 * and stores it in {@link #sourceCode}.
	 *
	 * @throws Exception if the jar yields no source for the class
	 */
	@Test(description="Source code included in Jar", dependsOnMethods={"testInterface"})
	public void hasSource() throws Exception {
		sourceCode = JarTest.jar.getSource(TestImplementation.implementationClass);
		if (sourceCode == null) {
			throw new Exception("No source code found for " + TestImplementation.implementationClass.getName());
		}
	}
	
	/**
	 * Recompiles {@link #sourceCode} and stores the result in
	 * {@link #recompiledClass}. Not currently registered as a test.
	 *
	 * @throws Exception if compilation yields no class
	 */
	//@Test(description="Source compiles", dependsOnMethods={"hasSource"})
	public void compileSource() throws Exception {
		Class<?> k = TestImplementation.implementationClass;		
		recompiledClass = compileSource(k.getName(), JarTest.jar.getSourceName(k), sourceCode);
		if (recompiledClass == null) {
			throw new Exception("Source does not compile");
		}
	} 
	
	//@Test(description="Test correctness of compiled source", dependsOnMethods={"compileSource"})
	public void recompiledIsCorrect() {
		// TODO
	}
	
	
	/**
	 * Compiles a single source file in a fresh temporary directory and loads
	 * the resulting class from its bytes.
	 *
	 * <p>Only the class file that shares the source file's base name is loaded;
	 * any other class files the compiler emits are discarded with the temporary
	 * directory. The whole temporary directory is removed before returning,
	 * whether or not compilation succeeds.
	 *
	 * @param classname name passed to the class loader for the compiled class
	 * @param filename  name of the source file, relative to the temporary directory
	 * @param code      raw bytes of the source file
	 * @return the class produced by {@code JarInspector.ByteClassLoader.classFromBytes}
	 * @throws IOException           if the source cannot be written or the class file cannot be read
	 * @throws IllegalStateException if no system Java compiler is available
	 * @throws NullPointerException  if {@code code} is null
	 */
	public Class<?> compileSource(String classname, String filename, byte[] code) throws IOException {
		if (code == null) {
			throw new NullPointerException("code must not be null");
		}
		// getSystemJavaCompiler() returns null when no compiler is provided (e.g. on a JRE).
		JavaCompiler javac = ToolProvider.getSystemJavaCompiler();
		if (javac == null) {
			throw new IllegalStateException("No system Java compiler available; run the tests on a JDK");
		}

		File tempDir = createTempDir();
		try {
			File sourceFile = new File(tempDir, filename);
			// filename may contain a sub-path; make sure its directory exists.
			File parent = sourceFile.getParentFile();
			if (parent != null && !parent.isDirectory() && !parent.mkdirs()) {
				throw new IOException("Could not create directory " + parent);
			}
			// Write the bytes untouched so the platform charset cannot alter the source.
			writeBytes(sourceFile, code);

			String sourcePath = sourceFile.getPath();
			int result = javac.run(null, null, null, sourcePath);

			// Non-zero indicates a compiler error.
			assertTrue("Compiler reported errors for " + filename, result == 0);

			File classFile = new File(sourcePath.substring(0, sourcePath.lastIndexOf('.')) + ".class");
			assertTrue("Compiler produced no class file " + classFile.getPath(), classFile.exists());

			byte[] data = readBytes(classFile);
			return (new JarInspector.ByteClassLoader()).classFromBytes(classname, data);
		} finally {
			// Removes the source, every emitted class file and the directory itself.
			deleteRecursively(tempDir);
		}
	}

	/** Writes {@code bytes} to {@code target}, closing the stream even on failure. */
	private static void writeBytes(File target, byte[] bytes) throws IOException {
		FileOutputStream out = new FileOutputStream(target);
		try {
			out.write(bytes);
		} finally {
			out.close();
		}
	}

	/** Reads the whole of {@code source}; fails with an EOFException if it is shorter than its reported length. */
	private static byte[] readBytes(File source) throws IOException {
		byte[] data = new byte[(int) source.length()];
		DataInputStream in = new DataInputStream(new FileInputStream(source));
		try {
			in.readFully(data);
		} finally {
			in.close();
		}
		return data;
	}

	/** Deletes {@code file} and, if it is a directory, everything below it (best effort). */
	private static void deleteRecursively(File file) {
		File[] children = file.listFiles(); // null when file is not a directory
		if (children != null) {
			for (File child : children) {
				deleteRecursively(child);
			}
		}
		file.delete();
	}
	
	/**
	 * Creates a new, empty directory named {@code tempDir<random int>} under
	 * the directory given by the {@code java.io.tmpdir} system property and
	 * registers it for deletion on JVM exit.
	 *
	 * @return the newly created directory
	 * @throws IllegalStateException if no directory could be created after
	 *                               several attempts
	 */
	public static File createTempDir() {
		// Based on a comment on stackoverflow.com
		// http://stackoverflow.com/questions/375910/creating-a-temp-dir-in-java
		final String baseTempPath = System.getProperty("java.io.tmpdir");
		Random rand = new Random();
		for (int attempt = 0; attempt < TEMP_DIR_ATTEMPTS; attempt++) {
			int randomInt = 1 + rand.nextInt();
			File tempDir = new File(baseTempPath + File.separator + "tempDir" + randomInt);
			// mkdir() only returns true if this call created the directory, so an
			// existing directory with the same name is never reused.
			if (tempDir.mkdir()) {
				tempDir.deleteOnExit();
				return tempDir;
			}
		}
		throw new IllegalStateException("Could not create a temporary directory under " + baseTempPath);
	}

	/**
	 * Stand-alone entry point: runs {@code testclass} against the class named
	 * {@code classname}, dumps the results to the listener's output and prints
	 * {@code SUCCESS} or {@code FAILURE}.
	 *
	 * <p>Terminates the JVM with exit status 1 if {@code classname} cannot be
	 * loaded.
	 *
	 * @param classname fully qualified name of the implementation class
	 * @param testclass test class to run
	 */
	public static void singleClass(String classname, Class<?> testclass) {
		listener = new TestListener();
		testListener = listener.newSubListener();
		TestImplementations.testClass = testclass;
		
		try {
			TestImplementation.implementationClass = Class.forName(classname);
		} catch (ClassNotFoundException e) {
			System.err.println("Class not found: " + classname);
			System.exit(1);
		}
		testListener.runTest(TestImplementations.testClass);
	
		testListener.dump(listener.out);
		
		if (testListener.passedAllTests()) {
			listener.println("SUCCESS");
		} else {
			listener.println("FAILURE");
		}
	}
	
	
}
