package ray2.solution;

import javax.swing.*;
import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.util.*;
import java.io.File;

/**
 * A graphical outer loop for a ray tracer that shows you the progress in a window in a
 * useful way.  It renders an undersampled image first so you can tell roughly what's in the
 * frame, and then it renders the whole image starting from the center, since that's usually
 * where the interesting part is.  When the image is finished it writes it to a PNG file whose
 * name is the input file name with ".png" appended.
 *
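 * In order for this to work you need classes called RayTracer, Scene, Color, and Ray2Reader,
 * with the following properties:
 *   * Ray2Reader.readScene(String filename) should return a Scene
 *   * Scene has a field camera whose int fields xPixelCount and yPixelCount give the image size
 *   * RayTracer.renderPixel(Scene s, int x, int y) should render a single pixel and return a Color;
 *     it is called from a background thread, and the preview pass modifies the returned Color
 *     in place, so it should return a fresh object each time
 *   * Color has floating-point fields r, g, and b that run from 0 to 1
 *   * Color.toInt() should return an integer suitable for BufferedImage.setRGB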
 * If you are using the classes from the Ray I assignment, you are likely to be set up this way
 * already, except for RayTracer.renderPixel.
 *
 * @author srm, 26 Nov 2004
 */

public class RayViewer extends JFrame {

    /** Edge length, in pixels, of the square tiles the image is rendered in. */
    static final int TILE_SIZE = 8;

    /** Minimum interval, in milliseconds, between window refreshes during the preview pass. */
    private static final long PREVIEW_REPAINT_INTERVAL_MS = 1000;

    /**
     * Off-screen image the worker thread renders into; displayed in the window through an
     * ImageIcon.  It is written without synchronization, so a repaint may show a partially
     * filled tile, which is acceptable for a progress display.
     */
    BufferedImage image;

    /** Background thread doing the rendering; null until startRendering() is called. */
    Thread worker;

    /** Scene being rendered. */
    Scene scene;

    /** Path of the PNG file written when rendering finishes; also shown in the window title. */
    String outputFilename;

    /** A TILE_SIZE x TILE_SIZE block of pixels, identified by its tile column and row. */
    private static final class Tile {
        final int itx;
        final int ity;

        Tile(int itx, int ity) {
            this.itx = itx;
            this.ity = ity;
        }

        /** Returns the squared distance, measured in tiles, between this tile and another. */
        int distanceSquared(Tile other) {
            int dx = other.itx - itx;
            int dy = other.ity - ity;
            return dx * dx + dy * dy;
        }
    }

    /**
     * Creates a window sized to the scene camera's pixel counts.  This constructor neither
     * starts rendering nor shows the window.  Like all Swing component construction, it
     * should be called on the event dispatch thread.
     *
     * @param scene          the scene to render
     * @param outputFilename path of the PNG file to write when rendering completes
     */
    RayViewer(Scene scene, String outputFilename) {
        super(outputFilename + ": Rendering");

        this.scene = scene;
        this.outputFilename = outputFilename;

        setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);

        int nx = scene.camera.xPixelCount;
        int ny = scene.camera.yPixelCount;

        JLabel imageLabel = new JLabel();
        imageLabel.setMinimumSize(new Dimension(nx, ny));
        imageLabel.setMaximumSize(new Dimension(nx, ny));
        getContentPane().add(imageLabel, BorderLayout.CENTER);

        image = new BufferedImage(nx, ny, BufferedImage.TYPE_INT_RGB);
        imageLabel.setIcon(new ImageIcon(image));

        pack();
    }

    /**
     * Starts a background thread that renders the scene into {@link #image}.  Tiles are
     * processed in order of distance from the center tile, in two passes: a lightened
     * preview with one sample per tile, then every pixel.  When rendering finishes, the
     * thread writes the image to {@link #outputFilename} and retitles the window
     * ": Done" or, if writing or rendering failed, ": Error".
     */
    public void startRendering() {
        worker = new Thread() {
            @Override
            public void run() {
                renderAndSave();
            }
        };
        worker.start();
    }

    /** Worker-thread body: renders both passes, writes the PNG, and reports the outcome. */
    private void renderAndSave() {
        boolean ok = false;
        try {
            java.util.List<Tile> tiles = tilesByDistanceFromCenter(image.getWidth(), image.getHeight());
            renderPreviewPass(tiles);
            renderFullPass(tiles);
            ok = writeOutput();
        } finally {
            // Runs even if renderPixel throws, so the title never stays stuck on "Rendering";
            // any throwable still propagates to the thread's uncaught-exception handler.
            setTitleLater(outputFilename + (ok ? ": Done" : ": Error"));
        }
    }

    /** Returns every tile covering an nx-by-ny image, sorted by distance from the center tile. */
    private static java.util.List<Tile> tilesByDistanceFromCenter(int nx, int ny) {
        int ntx = (nx + TILE_SIZE - 1) / TILE_SIZE;
        int nty = (ny + TILE_SIZE - 1) / TILE_SIZE;

        java.util.List<Tile> tiles = new ArrayList<Tile>(ntx * nty);
        for (int ity = 0; ity < nty; ity++) {
            for (int itx = 0; itx < ntx; itx++) {
                tiles.add(new Tile(itx, ity));
            }
        }

        final Tile center = new Tile(ntx / 2, nty / 2);
        // Collections.sort is stable, so equidistant tiles keep their row-major order.
        Collections.sort(tiles, new Comparator<Tile>() {
            @Override
            public int compare(Tile a, Tile b) {
                return Integer.compare(center.distanceSquared(a), center.distanceSquared(b));
            }
        });
        return tiles;
    }

    /**
     * First pass: renders one pixel per tile and fills the whole tile with a lightened
     * version of it, so the rough content of the frame appears quickly.
     */
    private void renderPreviewPass(java.util.List<Tile> tiles) {
        int nx = image.getWidth();
        int ny = image.getHeight();
        long lastRepaint = System.currentTimeMillis();

        for (Tile t : tiles) {
            int x0 = t.itx * TILE_SIZE;
            int y0 = t.ity * TILE_SIZE;
            int x1 = Math.min(x0 + TILE_SIZE, nx);
            int y1 = Math.min(y0 + TILE_SIZE, ny);

            // Sample the center of the tile's visible part.  For tiles clipped by the image
            // border, x0 + TILE_SIZE/2 would lie outside the image.
            Color c = RayTracer.renderPixel(scene, (x0 + x1) / 2, (y0 + y1) / 2);
            // Blend halfway toward white so preview tiles are distinguishable from final pixels.
            c.r = (1 + c.r) / 2;
            c.g = (1 + c.g) / 2;
            c.b = (1 + c.b) / 2;
            int rgb = c.toInt();

            for (int iy = y0; iy < y1; iy++) {
                for (int ix = x0; ix < x1; ix++) {
                    image.setRGB(ix, iy, rgb);
                }
            }

            long now = System.currentTimeMillis();
            if (now - lastRepaint > PREVIEW_REPAINT_INTERVAL_MS) {
                repaint();
                lastRepaint = now;
            }
        }
        repaint();  // show the complete preview before the slower full pass begins
    }

    /** Second pass: renders every pixel, one tile at a time, repainting after each tile. */
    private void renderFullPass(java.util.List<Tile> tiles) {
        int nx = image.getWidth();
        int ny = image.getHeight();

        for (Tile t : tiles) {
            int x0 = t.itx * TILE_SIZE;
            int y0 = t.ity * TILE_SIZE;
            int x1 = Math.min(x0 + TILE_SIZE, nx);
            int y1 = Math.min(y0 + TILE_SIZE, ny);

            for (int iy = y0; iy < y1; iy++) {
                for (int ix = x0; ix < x1; ix++) {
                    Color c = RayTracer.renderPixel(scene, ix, iy);
                    image.setRGB(ix, iy, c.toInt());
                }
            }
            repaint();
        }
    }

    /**
     * Writes {@link #image} to {@link #outputFilename} as a PNG.
     *
     * @return true if the file was written; false (after reporting to stderr) otherwise
     */
    private boolean writeOutput() {
        try {
            // ImageIO.write returns false, rather than throwing, when no suitable writer exists.
            if (!ImageIO.write(image, "PNG", new File(outputFilename))) {
                System.err.println("While writing output file:");
                System.err.println("    no PNG image writer is available");
                return false;
            }
            return true;
        } catch (Exception e) {
            System.err.println("While writing output file:");
            System.err.println("    " + e);
            return false;
        }
    }

    /** Sets the window title on the Swing event dispatch thread. */
    private void setTitleLater(final String title) {
        SwingUtilities.invokeLater(new Runnable() {
            @Override
            public void run() {
                setTitle(title);
            }
        });
    }

    /**
     * Reads the scene file named by args[0], opens a viewer window, and renders the scene,
     * writing the result to args[0] + ".png".
     *
     * @param args command-line arguments; args[0] is the scene file to render
     */
    public static void main(String[] args) {
        if (args.length < 1) {
            System.err.println("Usage: RayViewer <scene-file>");
            return;
        }

        String inputFilename = args[0];
        final String outputFilename = inputFilename + ".png";

        final Scene scene;
        try {
            scene = Ray2Reader.readScene(inputFilename);
        } catch (Exception e) {
            System.err.println("While reading input file '" + inputFilename + "':");
            System.err.println("    " + e);
            e.printStackTrace();
            return;
        }

        // Swing components must be created and shown on the event dispatch thread.
        SwingUtilities.invokeLater(new Runnable() {
            @Override
            public void run() {
                RayViewer rv = new RayViewer(scene, outputFilename);
                rv.startRendering();
                rv.setVisible(true);
            }
        });
    }

}