package edu.cornell.cs.cs4120.eth.etac.tester;

import edu.cornell.cs.cs4120.util.CodeWriterSExpPrinter;
import edu.cornell.cs.cs4120.util.SExpPrinter;
import edu.cornell.cs.cs4120.xic.ir.IRCompUnit;
import edu.cornell.cs.cs4120.xic.ir.IRNode;
import edu.cornell.cs.cs4120.xic.ir.IRNodeFactory_c;
import edu.cornell.cs.cs4120.xic.ir.interpret.IRSimulator;
import edu.cornell.cs.cs4120.xic.ir.interpret.IRSimulator.OutOfBoundTrap;
import edu.cornell.cs.cs4120.xic.ir.parse.IRLexer;
import edu.cornell.cs.cs4120.xic.ir.parse.IRParser;
import edu.cornell.cs.cs4120.xic.ir.visit.CheckCanonicalIRVisitor;
import edu.cornell.cs.cs4120.xic.ir.visit.CheckConstFoldedIRVisitor;
import edu.cornell.cs.cs4120.eth.FormattedOutput;
import edu.cornell.cs.cs4120.eth.SourceFileTest;

import java.io.BufferedReader;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintStream;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

/**
 * Tester for the IR-generation phase of etac.
 *
 * <p>For each source file, the generated IR ({@code .ir}) and the reference IR ({@code .irsol})
 * are parsed and run in the {@link IRSimulator}. Whatever the simulation prints to standard output
 * is written to the IR file's {@code .nml} companion, and a file passes when the generated and
 * reference outputs match. Generated IR is additionally checked for canonicity and, when {@link
 * #optimized} is set, for constant folding.
 */
public class EtacIRGenTester extends AbstractEtacTester {

    /** Maximum time, in units of {@link #TIMEUNIT}, that one IR simulation may run. */
    protected static final int TIMEOUT = 30;
    /** Unit of {@link #TIMEOUT}. */
    protected static final TimeUnit TIMEUNIT = TimeUnit.SECONDS;

    /**
     * Package prefix of this test harness. When a simulator crash is reported, its stack trace is
     * cut at the first frame from this package; frames from there on are harness/executor plumbing.
     */
    private static final String HARNESS_PACKAGE_PREFIX = "edu.cornell.cs.cs4120.eth";

    /**
     * Creates daemon worker threads for the IR simulator. If a timed-out simulation does not
     * respond to interruption, a daemon thread at least cannot prevent the JVM from exiting.
     */
    private static final ThreadFactory SIMULATOR_THREADS =
            new ThreadFactory() {
                @Override
                public Thread newThread(Runnable r) {
                    Thread thread = new Thread(r, "etac-ir-simulator");
                    thread.setDaemon(true);
                    return thread;
                }
            };

    /** Whether the generated IR should be optimized (and is therefore checked for folding). */
    protected boolean optimized;

    /** Number of generated IR files that parsed and went through the IR checks. */
    protected int numTests = 0;
    /** Number of generated IR files that passed the canonicity check. */
    protected int numCanon = 0;
    /** Number of generated IR files that passed the constant-folding check (optimized only). */
    protected int numFolded = 0;
    /** Number of files whose generated-IR output matched the reference output. */
    protected int numCorrect = 0;

    /**
     * @param optimized whether generated IR is expected to be constant-folded
     */
    public EtacIRGenTester(boolean optimized) {
        this.optimized = optimized;
    }

    /**
     * Copies the reference outputs ({@code .irsol.nml}) of every source file from {@code testDir}
     * to {@code destDir}. Every file is attempted even if an earlier copy failed.
     *
     * @return whether all copies succeeded
     */
    @Override
    public boolean copyReferenceFiles(SourceFileTest t, File testDir, File destDir) {
        boolean okay = true;
        String testDirname = appendDirSep(testDir.getPath());
        String destDirname = appendDirSep(destDir.getPath());
        for (List<String> compilationUnit : t.getSourceFileNames()) {
            for (String filename : compilationUnit) {
                String irsolnmlFilename = normalizedFilename(irsolFilename(filename));
                // Non-short-circuit: keep copying the remaining files after a failure.
                okay &=
                        copyFile(
                                testDirname + irsolnmlFilename,
                                destDirname + irsolnmlFilename,
                                t);
            }
        }
        return okay;
    }

    /**
     * Turns the reference compiler's IR output into reference files and interprets them.
     *
     * @return whether every rename, existence check and interpretation succeeded
     */
    @Override
    public boolean renameReferenceFiles(SourceFileTest t, File destDir) {
        boolean okay = true;
        // File path/to/file.eta results in destDir/path/to/file.ir.
        // Rename destDir/path/to/file.ir to destDir/path/to/file.irsol,
        // run the IR code in destDir/path/to/file.irsol,
        // and store the result in destDir/path/to/file.irsol.nml.
        String destDirname = appendDirSep(destDir.getPath());
        for (List<String> compilationUnit : t.getSourceFileNames()) {
            for (String filename : compilationUnit) {
                String irFilename = irFilename(filename);
                String irsolFilename = irsolFilename(filename);
                okay &= renameFile(destDirname + irFilename, destDirname + irsolFilename, t);
                File irsolFile = new File(destDirname + irsolFilename);
                if (!ensureFileExists(irsolFile, t)) {
                    okay = false;
                    continue;
                }
                // Accumulate rather than assign, so an earlier failure is not forgotten.
                okay &= normalize(irsolFile, t, false);
            }
        }
        return okay;
    }

    @Override
    public boolean normalizeReferenceFiles(SourceFileTest t, File destDir) {
        // Already normalized.
        // Interpretation was done as part of renaming.
        return true;
    }

    /**
     * Checks and interprets every generated {@code .ir} file, updating the IR counters. Every file
     * is processed (and counted) even if an earlier one failed.
     *
     * @return whether every generated file exists, passes the IR checks and interprets cleanly
     */
    @Override
    public boolean normalizeGeneratedFiles(SourceFileTest t, File destDir) {
        boolean okay = true;
        String destDirname = appendDirSep(destDir.getPath());
        for (List<String> compilationUnit : t.getSourceFileNames()) {
            for (String filename : compilationUnit) {
                File irFile = new File(destDirname + irFilename(filename));
                if (!ensureFileExists(irFile, t)) {
                    okay = false;
                    continue;
                }
                okay &= normalize(irFile, t, true);
            }
        }
        return okay;
    }

    /**
     * Parses {@code irFile}, optionally checks it, and runs it in the IR simulator. Whatever the
     * simulation prints to {@code System.out} is written to {@code irFile}'s {@code .nml} companion.
     *
     * @param irFile IR file to process
     * @param t test that receives failure messages
     * @param adjustCounters whether this is generated IR: if so, it is checked for canonicity (and
     *     constant folding when {@link #optimized}) and the corresponding counters are updated
     * @return whether parsing, the checks (if any) and the simulation all succeeded; reaching
     *     {@code _eta_out_of_bounds} counts as a successful simulation
     */
    protected boolean normalize(File irFile, SourceFileTest t, boolean adjustCounters) {
        File nmlFile = new File(normalizedFilename(irFile.getPath()));
        try (FileReader irfr = new FileReader(irFile);
                BufferedReader br = new BufferedReader(irfr);
                PrintStream ps = new PrintStream(nmlFile)) {
            IRParser parser = new IRParser(new IRLexer(br), new IRNodeFactory_c());
            IRCompUnit compUnit = parseIR(parser, t);
            if (compUnit == null) {
                return false;
            }
            boolean okay = !adjustCounters || checkIR(compUnit, t);
            // The program is interpreted even if the IR checks failed, so its output is available.
            return interpretIR(compUnit, ps, t) && okay;
        } catch (IOException e) {
            t.appendFailureMessage("I/O error while processing " + irFile.getPath() + ": " + e);
            return false;
        }
    }

    /**
     * Runs {@code parser}, capturing anything CUP prints to standard output or standard error.
     *
     * @return the parsed compilation unit, or {@code null} on failure (in which case failure
     *     messages have been appended to {@code t})
     */
    private IRCompUnit parseIR(IRParser parser, SourceFileTest t) {
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        IRCompUnit compUnit;
        try (PrintStream captured = new PrintStream(baos)) {
            PrintStream savedOut = System.out;
            PrintStream savedErr = System.err;
            // Redirect standard output and standard error used by CUP to the capture buffer.
            System.setOut(captured);
            System.setErr(captured);
            try {
                compUnit = parser.parse().<IRCompUnit>value();
            } catch (RuntimeException e) {
                throw e;
            } catch (Exception e) {
                // Used by CUP to indicate an unrecoverable error.
                captured.flush();
                String msg = e.getMessage();
                if (baos.size() > 0) {
                    t.appendFailureMessage(baos.toString());
                }
                if (msg != null) {
                    t.appendFailureMessage("IR syntax error: " + msg);
                }
                return null;
            } finally {
                System.setOut(savedOut);
                System.setErr(savedErr);
            }
        }
        if (compUnit == null) {
            t.appendFailureMessage("IR parser produced no compilation unit.");
        }
        return compUnit;
    }

    /**
     * Checks generated IR for canonicity and, when {@link #optimized}, constant folding, updating
     * {@link #numTests}, {@link #numCanon} and {@link #numFolded}.
     *
     * @return whether all applicable checks passed
     */
    private boolean checkIR(IRCompUnit compUnit, SourceFileTest t) {
        numTests++;
        boolean okay = true;

        CheckCanonicalIRVisitor canonCheck = new CheckCanonicalIRVisitor();
        if (canonCheck.visit(compUnit)) {
            numCanon++;
        } else {
            t.appendFailureMessage("Noncanonical IR:");
            t.appendFailureMessage(prettyPrintIR(canonCheck.noncanonical()));
            okay = false;
        }

        if (optimized) {
            CheckConstFoldedIRVisitor foldCheck = new CheckConstFoldedIRVisitor();
            if (foldCheck.visit(compUnit)) {
                numFolded++;
            } else {
                t.appendFailureMessage("IR not constant-folded:");
                t.appendFailureMessage(prettyPrintIR(foldCheck.unfolded()));
                okay = false;
            }
        }
        return okay;
    }

    /**
     * Runs {@code _Imain_paai} of {@code compUnit} in the IR simulator on a worker thread, with
     * {@code System.out} redirected to {@code ps}, bounded by {@link #TIMEOUT}.
     *
     * @return {@code true} if the simulation finished normally or hit {@code _eta_out_of_bounds};
     *     {@code false} on timeout, interruption, or any other simulator exception
     */
    private boolean interpretIR(final IRCompUnit compUnit, PrintStream ps, SourceFileTest t) {
        Callable<Long> simulation =
                new Callable<Long>() {
                    @Override
                    public Long call() {
                        // Creating the IRSimulator may throw, so do it on the worker thread and
                        // let the failure surface through the Future.
                        // NOTE(review): 20480 is presumably the simulator's heap size — confirm.
                        IRSimulator sim = new IRSimulator(compUnit, 20480);
                        return sim.call("_Imain_paai", 0);
                    }
                };

        ExecutorService executor = Executors.newSingleThreadExecutor(SIMULATOR_THREADS);
        PrintStream savedOut = System.out;
        // Redirect standard output to the output file before the simulation can print anything.
        System.setOut(ps);
        try {
            Future<Long> task = executor.submit(simulation);
            try {
                task.get(TIMEOUT, TIMEUNIT);
                return true;
            } catch (TimeoutException e) {
                task.cancel(true);
                t.appendFailureMessage(
                        "IR simulation timeout after " + TIMEOUT + " " + TIMEUNIT + ".");
                return false;
            } catch (InterruptedException e) {
                task.cancel(true);
                // Preserve the interrupt for our caller; the simulation outcome is unknown.
                Thread.currentThread().interrupt();
                t.appendFailureMessage("IR simulation interrupted.");
                return false;
            } catch (ExecutionException e) {
                Throwable cause = e.getCause() != null ? e.getCause() : e;
                if (cause instanceof OutOfBoundTrap) {
                    // An out-of-bounds trap is an observable program outcome, not a tester error.
                    ps.println("_eta_out_of_bounds called");
                    return true;
                }
                t.appendFailureMessage(cause.toString());
                for (StackTraceElement ste : cause.getStackTrace()) {
                    if (ste.getClassName().startsWith(HARNESS_PACKAGE_PREFIX)) {
                        break;
                    }
                    t.appendFailureMessage("\tat " + ste);
                }
                return false;
            }
        } finally {
            executor.shutdownNow();
            // Restore standard output.
            System.setOut(savedOut);
        }
    }

    /** Renders {@code n} as an S-expression string. */
    protected String prettyPrintIR(IRNode n) {
        StringWriter sw = new StringWriter();
        try (PrintWriter pw = new PrintWriter(sw);
                SExpPrinter sp = new CodeWriterSExpPrinter(pw)) {
            n.printSExp(sp);
        }
        return sw.toString();
    }

    /** Returns the name of the file holding the interpreter output for {@code filename}. */
    protected String normalizedFilename(String filename) {
        return filename + ".nml";
    }

    /**
     * Compares each file's generated-IR output ({@code .ir.nml}) with the reference output
     * ({@code .irsol.nml}), counting matches in {@link #numCorrect}.
     *
     * @return whether every pair matched
     */
    @Override
    public boolean checkResult(SourceFileTest t, File destDir) {
        boolean okay = true;
        String destDirname = appendDirSep(destDir.getPath());
        for (List<String> compilationUnit : t.getSourceFileNames()) {
            for (String filename : compilationUnit) {
                File irnmlFile = new File(destDirname + normalizedFilename(irFilename(filename)));
                File irsolnmlFile =
                        new File(destDirname + normalizedFilename(irsolFilename(filename)));
                if (compareFiles(irsolnmlFile, irnmlFile, t)) {
                    numCorrect++;
                } else {
                    okay = false;
                }
            }
        }
        return okay;
    }

    /**
     * Moves each reference output ({@code .irsol.nml}) to {@code saveDir}, attempting every file.
     *
     * @return whether every move succeeded
     */
    @Override
    public boolean cleanupReferenceFiles(SourceFileTest t, File destDir, File saveDir) {
        boolean okay = true;
        String destDirname = appendDirSep(destDir.getPath());
        for (List<String> compilationUnit : t.getSourceFileNames()) {
            for (String filename : compilationUnit) {
                // .irsol.nml
                String irsolnmlFilename = normalizedFilename(irsolFilename(filename));
                okay &= moveFileIfExists(destDirname + irsolnmlFilename, saveDir, t);
            }
        }
        return okay;
    }

    /**
     * Moves each file's {@code .ir}, {@code .irsol} and {@code .ir.nml} to {@code saveDir},
     * attempting every file even after a failure.
     *
     * @return whether every move succeeded
     */
    @Override
    public boolean cleanupGeneratedFiles(SourceFileTest t, File destDir, File saveDir) {
        boolean okay = true;
        String destDirname = appendDirSep(destDir.getPath());
        for (List<String> compilationUnit : t.getSourceFileNames()) {
            for (String filename : compilationUnit) {
                // .ir
                String irFilename = irFilename(filename);
                okay &= moveFileIfExists(destDirname + irFilename, saveDir, t);
                // .irsol
                String irsolFilename = irsolFilename(filename);
                okay &= moveFileIfExists(destDirname + irsolFilename, saveDir, t);
                // .ir.nml
                String irnmlFilename = normalizedFilename(irFilename);
                okay &= moveFileIfExists(destDirname + irnmlFilename, saveDir, t);
            }
        }
        return okay;
    }

    /**
     * Prints, for each file, the generated and expected interpreter output followed by the
     * generated IR itself.
     */
    @Override
    public void printTestResult(SourceFileTest t, File destDir, FormattedOutput pr) {
        String destDirname = appendDirSep(destDir.getPath());
        for (List<String> compilationUnit : t.getSourceFileNames()) {
            for (String filename : compilationUnit) {
                // .ir
                String irFilename = irFilename(filename);
                // .ir.nml
                String irnmlFilename = normalizedFilename(irFilename);
                pr.printHeader("Generated result for --irrun:");
                pr.printCode(new File(destDirname + irnmlFilename));
                // .irsol.nml
                String irsolnmlFilename = normalizedFilename(irsolFilename(filename));
                pr.printHeader("Expected result for --irrun:");
                pr.printCode(new File(destDirname + irsolnmlFilename));

                pr.printHeader("Generated result for --irgen:");
                pr.printCode(new File(destDirname + irFilename));
            }
        }
    }

    /** Appends the IR counters to {@code sb}; the folding count only when {@link #optimized}. */
    @Override
    public void getSummary(StringBuffer sb) {
        sb.append("\nNumber of IRs: " + numTests);
        sb.append("\nNumber of canonical IRs: " + numCanon);
        if (optimized) {
            sb.append("\nNumber of constant-folded IRs: " + numFolded);
        }
        sb.append("\nNumber of correct IRs: " + numCorrect);
    }

    /** Returns {@code filename} with its extension replaced by {@code .ir}. */
    protected String irFilename(String filename) {
        return filenameNoExt(filename) + ".ir";
    }

    /** Returns {@code filename} with its extension replaced by {@code .irsol}. */
    protected String irsolFilename(String filename) {
        return filenameNoExt(filename) + ".irsol";
    }
}
