/**
 * <p>Title: Shotgun Project</p>
 * <p>Description: </p>
 * <p>Copyright: </p>
 * <p>Company: </p>
 * @author Alex Ksikes
 * @version 2.1
**/

package shotgun;

import java.io.*;
import java.util.StringTokenizer;

/**
 * A model represents the predictions from one trained model or an ensemble of models.
 *
 * <p>Configuration shared by all models (targets, set names, output writers,
 * cross-validation fold and dynamic-threshold mode) lives in static fields, so
 * {@link #setTargets} is expected to be called before models are constructed.
 */
public class Model implements Comparable
{
  // Instance variables
  private String name;                 // name of the model
  private File[] files;                // files which hold predictions
  private Predictions[] predictions;   // all the predictions for the model
  private int numModels;               // number of models in the ensemble
  private double ensWeight;            // total weight (normalizing factor) of models in the ensemble
  private Threshold threshold;         // threshold specific to this model

  // Static variables
  private static FileWriter[] out;     // used to write into files
  private static FileWriter[] outPred; // used to record the best predictions
  private static int numSets;          // number of sets
  private static Targets[] targets;    // targets associated with each set
  private static String[] testName;    // names of each test set
  private static int trainIndex;       // indicates which set is used for training

  private static int cvFolds=1;         // cross validation folds
  private static int currentFold=0;     // the working fold
  private static Targets cvTargets;     // targets for each of the folds
  private static boolean cvTrain=false; // are we currently in training mode

  private static boolean dynamicThreshold=false; // set threshold dynamically?

  // Constants
  private static final double SCORE_FUDGE = 1.0e-5;      // fudge factor for avoiding 0's and 1's
  private static final int MAX_READ_RETRIES = 100;       // failed reads tolerated before giving up
  private static final long READ_RETRY_DELAY_MS = 10000; // pause between failed reads

  /**
   * Default constructor: constructs empty Predictions on each set.
   * In cross-validation mode the training slot holds the current fold instead.
   */
  public Model()
  {
    threshold = dynamicThreshold ? new Threshold() : null;

    this.predictions = new Predictions[numSets];
    for (int i = 0; i < numSets; i++) {
      this.predictions[i] = new Predictions(targets[i], testName[i]);
      assignThreshold(i);
    }
    numModels = 0;
    ensWeight = 0.0;

    if (cvFolds != 1) {
      this.predictions[trainIndex] = new Predictions(cvTargets, currentFoldLabel());
      assignThreshold(trainIndex);
    }
  }

  /**
   * Builds a new model from predictions in files.
   * Only the training set is read eagerly; other sets are read on demand.
   *
   * @param predictions The files containing the predictions for each set.
   */
  public Model(File[] predictions)
  {
    this.files = predictions;
    this.name = stripExtension(predictions[trainIndex].getName());

    threshold = dynamicThreshold ? new Threshold() : null;

    this.predictions = new Predictions[numSets];
    this.predictions[trainIndex] = readFile(trainIndex);
    assignThreshold(trainIndex);

    numModels = 1;
    ensWeight = 0.0;
  }

  /**
   * Returns the file name without its last extension, or the whole name when it
   * has no '.' (the previous substring call threw in that case).
   */
  private static String stripExtension(String fileName)
  {
    int dot = fileName.lastIndexOf('.');
    return dot < 0 ? fileName : fileName.substring(0, dot);
  }

  /**
   * Drops all cached predictions and re-reads the training set from file.
   * In cross-validation mode the result is restricted to the current fold
   * partition.
   */
  public void reload()
  {
    for (int i = 0; i < predictions.length; ++i) {
      this.predictions[i] = null;
    }

    threshold = dynamicThreshold ? new Threshold() : null;

    this.predictions[trainIndex] = readFile(trainIndex);

    if (cvFolds != 1) {
      double[] allPreds = this.predictions[trainIndex].getPredictions();
      int trueSize = allPreds.length;
      int cvSize = trueSize / cvFolds;
      int offset = currentFold * cvSize;

      double[] preds = new double[cvTrain ? trueSize - cvSize : cvSize];
      int count = 0;
      for (int i = 0; i < trueSize; i++) {
        if (inCurrentPartition(i, offset, cvSize)) {
          preds[count++] = allPreds[i];
        }
      }

      this.predictions[trainIndex] = new Predictions(preds, cvTargets, currentFoldLabel());
    }

    assignThreshold(trainIndex);
  }

  /**
   * Tells whether example {@code i} belongs to the partition currently selected
   * for cross validation.
   * In training mode that is everything outside
   * {@code [offset, offset + cvSize)}; otherwise it is exactly that range.
   */
  private static boolean inCurrentPartition(int i, int offset, int cvSize)
  {
    boolean inHeldOutFold = i >= offset && i < offset + cvSize;
    return inHeldOutFold != cvTrain;
  }

  /**
   * Sets the main static variables for the Model class.
   *
   * @param targets the targets for all the sets
   * @param testName the names of the sets
   * @param trainIndex the offset of the train set in the arrays
   * @param datathresh if true, set the Predictions threshold to the fraction of
   *        class-0 targets in the hillclimb (train) set
   */
  public static void setTargets(File[] targets,
                                String[] testName,
                                int trainIndex,
                                boolean datathresh)
  {
    Model.numSets = targets.length;
    Model.targets = new Targets[numSets];
    for (int i = 0; i < numSets; i++) {
      Model.targets[i] = new Targets(targets[i]);
    }
    Model.trainIndex = trainIndex;
    Model.testName = testName;

    if (datathresh) {
      double negCount = Model.targets[trainIndex].getTotal_true_0();
      double total = Model.targets[trainIndex].getSize();
      Predictions.setThreshold(negCount / total);
    }
  }

  /**
   * Return the name of this model.
   *
   * @return the name of this model
   */
  public String getName()
  {
    return name;
  }

  /**
   * Reads the predictions for set {@code index} from this model's file, retrying
   * on I/O failure (waiting between attempts).
   * The process exits if the file still cannot be read after
   * {@value #MAX_READ_RETRIES} retries.
   */
  private Predictions readFile(int index)
  {
    double[] preds = new double[targets[index].getSize()];
    int tries = 0;
    boolean fileRead = false;
    while (!fileRead) {
      try {
        readPredictionValues(files[index], preds);
        fileRead = true;
      }
      catch (IOException e) {
        if (tries > MAX_READ_RETRIES) {
          System.out.println("Error : Problem with file " + name + "." + testName[index]);
          System.exit(-1);
        }
        else {
          waitBeforeRetry();
          tries++;
        }
      }
    }
    return new Predictions(preds, targets[index], testName[index]);
  }

  /**
   * Fills {@code preds} with the first whitespace-separated token of each of the
   * first {@code preds.length} lines of {@code file}.
   * Values are validated to lie in [0,1] and then nudged away from exactly 0 and 1.
   *
   * @throws IOException if the file cannot be read, or has too few lines
   */
  private static void readPredictionValues(File file, double[] preds) throws IOException
  {
    BufferedReader reader = new BufferedReader(new FileReader(file));
    try {
      for (int i = 0; i < preds.length; i++) {
        String line = reader.readLine();
        if (line == null) {
          // Treat a truncated file like any other read failure, so the
          // caller's retry loop handles it instead of crashing with an NPE.
          throw new EOFException(file + " has fewer than " + preds.length + " predictions");
        }
        double val = Double.parseDouble(new StringTokenizer(line).nextToken());

        // Validate BEFORE nudging: nudging first used to turn negative values
        // into SCORE_FUDGE and let them through. The negated-range form also
        // rejects NaN.
        if (!(val >= 0.0 && val <= 1.0)) {
          System.err.println("error: " + file + " contains prediction(s) not in range [0,1]");
          System.exit(1);
        }

        // Predictions that are exactly 0 or 1 are problematic for
        // computing cross entropy.  Intelligent models should never
        // predict these because they can never be certain that they
        // have seen enough cases to generalize.  For our purposes,
        // it would be bad to choose an overly confident model that
        // fit the hillclimbing data perfectly, but then had
        // terrible MXE on a (much larger) test set.
        //
        // So, we'll nudge prediction values away from 0 and 1 by a
        // tiny amount.
        if (val >= 1.0) {
          val = 1.0 - SCORE_FUDGE;
        }
        else if (val <= 0.0) {
          val = SCORE_FUDGE;
        }
        preds[i] = val;
      }
    }
    finally {
      reader.close();
    }
  }

  /**
   * Sleeps between read attempts.
   * If interrupted, the interrupt status is restored rather than discarded; any
   * remaining retries then proceed without waiting.
   */
  private static void waitBeforeRetry()
  {
    try {
      Thread.sleep(READ_RETRY_DELAY_MS);
    }
    catch (InterruptedException e) {
      Thread.currentThread().interrupt();
    }
  }

  /**
   * Return the Loss of this model on the train set.
   */
  public double getLoss()
  {
    return this.predictions[trainIndex].getLoss();
  }

  /**
   * Return the performance of this model on the train set.
   */
  public double getPerformance()
  {
    return this.predictions[trainIndex].getPerformance();
  }

  /**
   * Set the name of this model.
   *
   * @param name the new name
   */
  public void setName(String name)
  {
    this.name = name;
  }

  /**
   * Compare this model with another model on the train set.
   * The comparison is delegated to the other model's train-set Predictions
   * (note the reversed argument order).
   * The original intent is that models with greater error compare greater.
   * NOTE(review): this depends on Predictions.compareTo — confirm.
   *
   * @param o The model to be compared (must be a Model).
   * @return {@code other.predictions.compareTo(this.predictions)} on the train set.
   */
  public int compareTo(Object o)
  {
    return ((Model)o).predictions[trainIndex].compareTo(predictions[trainIndex]);
  }

  /**
   * Get the cross-validation fold set label.
   * @return The cross-validation label, e.g. "fold2train" or "fold2test".
   */
  private static String currentFoldLabel()
  {
    String label = "fold" + Integer.toString(currentFold)
      + (cvTrain ? "train" : "test");
    return label;
  }

  /**
   * Sets the file writers for output.
   * If no output name is given (empty or null), files are named
   * "perf.&lt;set&gt;.1" and "preds.&lt;set&gt;.1".
   * Otherwise they are named "perf.&lt;output&gt;.&lt;set&gt;" and
   * "preds.&lt;output&gt;.&lt;set&gt;".
   * A set whose files cannot be opened is reported on stderr and left without
   * writers; the remaining sets are still set up.
   *
   * @param output A special name for output files
   * @param writePred Whether predictions will be written in the end
   * @param allSets Write files for all test sets, or just hillclimb set?
   */
  public static void setFileWriters(String output, boolean writePred, boolean allSets)
  {
    out = new FileWriter[numSets];
    outPred = new FileWriter[numSets];

    boolean defaultLabel = output == null || output.length() == 0;
    for (int i = 0; i < numSets; ++i) {
      if (!allSets && i != trainIndex) {
        continue;
      }
      String suffix =
        defaultLabel ? (testName[i] + ".1") : (output + "." + testName[i]);
      try {
        out[i] = new FileWriter(new File("perf." + suffix));
        if (writePred) {
          outPred[i] = new FileWriter(new File("preds." + suffix));
        }
      }
      catch (IOException e) {
        System.err.println("warning: could not open output files for " + suffix + ": " + e);
      }
    }
  }

  /**
   * Add a model or an ensemble to the current ensemble with a real-valued weight.
   * {@code numModels} grows by one; {@code ensWeight} grows by {@code newWeight}.
   *
   * @param model the model whose predictions we wish to add
   * @param allSets update all sets (ignored in cross-validation mode) or just the train set
   * @param newWeight weight of the added model
   */
  public void addModel(Model model, boolean allSets, double newWeight)
  {
    if (allSets && cvFolds == 1) {
      for (int i = 0; i < numSets; i++) {
        predictions[i].addWeighted(model.getPredictions(i), ensWeight, newWeight);
      }
    }
    else {
      predictions[trainIndex].addWeighted(model.predictions[trainIndex], ensWeight, newWeight);
    }
    ensWeight += newWeight;
    numModels++;
  }

  /**
   * Add a model or an ensemble counting as {@code newWeight} models.
   * Both {@code numModels} and {@code ensWeight} grow by {@code newWeight}.
   *
   * @param model the model whose predictions we wish to add
   * @param allSets update all sets (ignored in cross-validation mode) or just the train set
   * @param newWeight number of models the added predictions represent
   */
  public void addModel(Model model, boolean allSets, int newWeight)
  {
    if (allSets && cvFolds == 1) {
      for (int i = 0; i < numSets; i++) {
        predictions[i].addWeighted(model.getPredictions(i), numModels, newWeight);
      }
    }
    else {
      predictions[trainIndex].addWeighted(model.predictions[trainIndex], numModels, newWeight);
    }
    numModels += newWeight;
    ensWeight += newWeight;
  }

  /**
   * Add a model or an ensemble, weighted by the number of models it contains.
   * A single model added to the train set only takes a dedicated fast path.
   *
   * @param model the model whose predictions we wish to add
   * @param allSets update all sets or just the train set
   */
  public void addModel(Model model, boolean allSets)
  {
    if (!allSets && model.numModels == 1) {
      predictions[trainIndex].add(model.predictions[trainIndex], numModels);
      numModels++;
      ensWeight++;
    }
    else {
      addModel(model, allSets, model.numModels);
    }
  }

  /**
   * Subtract a model or an ensemble from the current ensemble.
   * This is the inverse of {@link #addModel(Model, boolean, int)}.
   * Both {@code numModels} and {@code ensWeight} shrink by {@code newWeight}; the
   * old code always decremented {@code numModels} by one and left
   * {@code ensWeight} untouched.
   *
   * @param model the model whose predictions we wish to subtract
   * @param allSets update all sets (ignored in cross-validation mode) or just the train set
   * @param newWeight number of models the subtracted predictions represent
   */
  public void subModel(Model model, boolean allSets, int newWeight)
  {
    if (allSets && cvFolds == 1) {
      for (int i = 0; i < numSets; i++) {
        predictions[i].subWeighted(model.getPredictions(i), numModels, newWeight);
      }
    }
    else {
      predictions[trainIndex].subWeighted(model.predictions[trainIndex], numModels, newWeight);
    }
    numModels -= newWeight;
    ensWeight -= newWeight;
  }

  /**
   * Subtract a model or an ensemble, weighted by the number of models it contains.
   *
   * @param model the model whose predictions we wish to subtract
   * @param allSets update all sets or just the train set
   */
  public void subModel(Model model, boolean allSets)
  {
    subModel(model, allSets, model.numModels);
  }

  /**
   * Returns the predictions on the given set.
   * If they are not held in memory, they are read from this model's file on every
   * call and deliberately not kept (the slot is reset to null after the
   * threshold is assigned).
   */
  private Predictions getPredictions(int index)
  {
    Predictions pred = predictions[index];
    if (pred == null) {
      // Predictions aren't cached yet: read them, temporarily store them so
      // assignThreshold can see them, then drop the reference again.
      pred = readFile(index);
      predictions[index] = pred;
      assignThreshold(index);
      predictions[index] = null;
      // TODO: rewrite this (and rest of class) to use SoftReferences
      // This will require some pretty major testing though.
      //predictions[index] = pred;
    }
    return pred;
  }

  /**
   * Make a copy of this model: predictions on every set, threshold, name, source
   * files, model count and ensemble weight.
   * The source-files array is shared, not cloned; it is never modified.
   *
   * @return a Model with the same predictions as the current model
   */
  public Model copy()
  {
    Model newModel = new Model();
    if (dynamicThreshold) {
      newModel.threshold = new Threshold(this.threshold);
    }

    for (int i = 0; i < numSets; i++) {
      newModel.predictions[i] = this.getPredictions(i).copy();
      newModel.assignThreshold(i);
    }

    newModel.name = name;
    newModel.files = files;
    newModel.numModels = numModels;
    newModel.ensWeight = ensWeight;

    return newModel;
  }

  /**
   * Get the log-likelihood of the train set given the model.
   * Used with Bayesian averaging.
   */
  public double getLogLikelihood()
  {
    return predictions[trainIndex].computeLogLikelihood();
  }

  /**
   * Report the performance of this model.
   *
   * @param name the name to be outputted at the end of the perf line
   * @param allSets Report on all test sets or just hillclimb set?
   */
  public void report(String name, boolean allSets)
  {
    // Implementation note: always report on the training set first.
    // This ensures that when dynamic thresholding is used, that the
    // threshold will be found and fixed to the same value for all the
    // test sets.  We need to handle this here instead of in
    // Predictions because of a leaky abstraction.  Essentially, the
    // Predictions class computes performance lazily (when asked for a
    // performance) instead of aggressively (whenever predictions
    // change via adding a model, at creation time, etc).  This saves
    // some needless computation, but means that a dynamic threshold
    // is not set until the training set predictions are asked for a
    // performance computation.

    report(name, trainIndex);

    if (allSets) {
      for (int i = 0; i < numSets; ++i) {
        if (i != trainIndex) {
          report(name, i);
        }
      }
    }
  }

  /** Reports the performance on set {@code i} to that set's perf writer. */
  private void report(String name, int i)
  {
    getPredictions(i).report(out[i], testName[i], name, numModels);
  }

  /**
   * Write the predictions for this ensemble to the prediction files.
   * Sets without a prediction writer (see {@link #setFileWriters}) are skipped.
   * Non-cached predictions are read from file first.
   */
  public void write()
  {
    if (outPred == null) {
      return;
    }
    for (int i = 0; i < numSets; i++) {
      if (outPred[i] != null) {
        getPredictions(i).write(outPred[i]);
      }
    }
  }

  /**
   * Set the current fold in cross validation.
   * For {@code folds != 1}, this rebuilds the fold targets from the train set.
   * The first {@code folds * (size / folds)} examples are split into folds; any
   * remainder always stays in the training partition.
   *
   * @param folds number of cross-validation folds (1 disables cross validation)
   * @param current index of the working fold
   * @param trainMode select the training partition (true) or the held-out fold
   */
  public static void setFolds(int folds, int current, boolean trainMode)
  {
    cvFolds = folds;

    if (folds != 1) {
      cvTrain = trainMode;
      currentFold = current;

      int[] allTargets = Model.targets[trainIndex].getTargets();
      int size = allTargets.length;
      int cvSize = size / cvFolds;
      int offset = currentFold * cvSize;

      int[] targs = new int[trainMode ? size - cvSize : cvSize];
      int count = 0;
      for (int i = 0; i < size; i++) {
        if (inCurrentPartition(i, offset, cvSize)) {
          targs[count++] = allTargets[i];
        }
      }
      cvTargets = new Targets(targs);
    }
  }

  /**
   * Change the number of folds, keeping the current fold and mode.
   *
   * @param folds number of cross-validation folds
   */
  public static void setFolds(int folds)
  {
    setFolds(folds, currentFold, cvTrain);
  }

  /**
   * Sets whether performance calculations on the training set should
   * determine the pos-neg threshold dynamically.
   * @param dynamic Dynamic threshold?
   */
  public static void setDynamicThreshold(boolean dynamic)
  {
    dynamicThreshold = dynamic;
  }

  /**
   * Create a baseline model that makes predictions based solely on
   * the ratio of positive/negative instances.
   * IMPORTANT: setTargets() should be called first
   * @return A baseline model.
   */
  public static Model getBaseline()
  {
    // Local constants
    final double MAX_MEAN = 0.99999;
    final double MIN_MEAN = 0.00001;

    Model baseline = new Model();
    baseline.name = "baseline";
    baseline.numModels = 1;
    baseline.ensWeight = 1.0; // REVIEW: correct?

    // Figure out what the mean prediction is for the train data set.
    int sum = targets[trainIndex].getTotal_true_1();
    int count = targets[trainIndex].getSize();

    double mean = ((double)sum) / ((double)count);

    // Trim extreme mean values.  Otherwise cross-entropy can go to infinity.
    if (mean > MAX_MEAN) {
      mean = MAX_MEAN;
    }
    else if (mean < MIN_MEAN) {
      mean = MIN_MEAN;
    }

    // Set the prediction for all the data points to the mean.
    for (int i = 0; i < numSets; ++i) {
      double[] preds = new double[targets[i].getSize()];
      java.util.Arrays.fill(preds, mean);
      baseline.predictions[i] = new Predictions(preds, targets[i], testName[i]);
      baseline.assignThreshold(i);
    }

    // REVIEW: does something special need to be done when crossvalidating?

    return baseline;
  }

  /**
   * Attaches this model's threshold to the predictions on set {@code testIndex}, if present.
   * Dynamic thresholding is enabled only for the train set.
   */
  private void assignThreshold(int testIndex)
  {
    Predictions p = predictions[testIndex];
    if (p != null) {
      p.setDynamicThreshold(dynamicThreshold && testIndex == trainIndex,
                            threshold);
    }
  }

  /**
   * Method getNumModels.
   *
   * @return Number of models in ensemble
   */
  public int getNumModels()
  {
    return numModels;
  }

  /**
   * @return Total weight (normalizing factor) of the models in the ensemble
   */
  public double getEnsWeight()
  {
    return ensWeight;
  }

  /**
   * Computes the train-set loss the ensemble would have after adding {@code add}.
   * Works on a copy; this model is not modified.
   */
  public double tryAdd(Model add)
  {
    Predictions temp = predictions[trainIndex].copy();
    Threshold t = dynamicThreshold ? new Threshold(threshold) : null;
    temp.setDynamicThreshold(dynamicThreshold, t);
    temp.addWeighted(add.predictions[trainIndex], this.numModels, add.numModels);
    return temp.getLoss();
  }

  /**
   * Computes the train-set loss the ensemble would have after removing {@code sub}.
   * Works on a copy; this model is not modified.
   */
  public double trySub(Model sub)
  {
    Predictions temp = predictions[trainIndex].copy();
    Threshold t = dynamicThreshold ? new Threshold(threshold) : null;
    temp.setDynamicThreshold(dynamicThreshold, t);
    temp.subWeighted(sub.predictions[trainIndex], this.numModels, sub.numModels);
    return temp.getLoss();
  }
}
