/**
 * <p>Title: Shotgun Project</p>
 * <p>Description: Flexible multi metric implementation.</p>
 * @author M. Arthur Munson
 * @version 1.0
 * History:
 *  2005/02/23  Art Munson created.
 */

package shotgun.metrics;
import shotgun.Predictions;

/**
 * Wrapper metric that computes a weighted average of multiple other
 * metrics.  The average is reported as a loss, with low numbers
 * good and high numbers bad.  This allows combining metrics like ACC
 * and RMS that normally disagree on whether high or low numbers are better.
 *
 * Example usage with shotgun to optimize multi-metric combining ACC,
 * ROC, and RMS that gives RMS as much weight as the other two
 * combined:
 *  java <shotgun> -sfr 200 -custom my_multi \
 *     -addmetrics shotgun.metrics.MultiMetric \
 *     "my_multi acc:1 roc:1 rms:2" ...
 *
 * Weights can be any double number, although only positive numbers
 * have been tested.  Omitted weights default to 1; for example, the
 * following is identical to the example above:
 *     "my_multi acc roc rms:2"
 *
 * The weighted sum of component losses is divided by the total weight.
 * When all weights are non-negative and every component loss lies in
 * [0,1], the result therefore also falls in the range [0,1].
 */
public class MultiMetric implements MetricBundle
{
  /** performance() reports 0 when the total weight is below this value. */
  private static final double EPS = 1.0e-99;
  /** Predictions mode indices of the component metrics; null until init succeeds. */
  private int[] components = null;
  /** Weight of each entry in {@code components}; null until init succeeds. */
  private double[] weights = null;
  /** Sum of {@code weights}; the weighted loss sum is divided by this. */
  private double totalWeight = 1.0;
  /** Upper-cased name under which this multi-metric is reported. */
  private String label = null;

  // Cache of previously computed scores, keyed by Predictions object.
  private PerfCache cache;

  public MultiMetric()
  {
    cache = new PerfCache();
  }

  ///////////////////////////////////////////////////////////
  // MetricBundle implementation.
  ///////////////////////////////////////////////////////////

  /**
   * @return 1; this bundle exposes a single combined metric.
   */
  public int count()
  {
    return 1;
  }

  /**
   * Initialize a multi metric.
   * @param arg - Parameters for the multi-metric.  Should have the
   * format:
   *
   *   "label metric1:weight1 metric2:weight2 ... metricN:weightN"
   *
   * Note that the quotes are required.  Tokens are separated by
   * whitespace; leading, trailing and repeated spaces are ignored.  The
   * label is how shotgun will identify the multi-metric.  At least two
   * component metrics are required.  The metric names used should be
   * shotgun labels, e.g. acc, rms, roc, ...  Weights can be any
   * double number, preferably positive.  Omitted weights default to
   * 1. All of the component metrics should be added before this one.
   *
   * On failure an error is printed to stderr and the metric is left
   * uninitialized (no partially parsed configuration is kept).  Any
   * previously cached scores are discarded in either case.
   * @return True iff initialization succeeded.
   */
  public boolean init(String arg)
  {
    // Start from a clean, uninitialized state.  Fields are only
    // committed after the whole specification parses successfully, so a
    // failed init never leaves a half-configured metric behind.
    label = null;
    components = null;
    weights = null;
    totalWeight = 0.0;
    // Scores cached under a previous configuration are no longer valid.
    cache = new PerfCache();

    if (arg == null) {
      System.err.println("missing label or component metric specifications for MultiMetric");
      return false;
    }

    // Split on runs of whitespace so stray or doubled spaces do not
    // produce empty metric names.
    String[] parms = arg.trim().split("\\s+");
    if (parms.length < 3) {
      System.err.println("missing label or component metric specifications for MultiMetric");
      return false;
    }

    // parms[0] is the label; the remaining tokens are metric[:weight].
    int numMetrics = parms.length - 1;
    int[] newComponents = new int[numMetrics];
    double[] newWeights = new double[numMetrics];
    double newTotalWeight = 0.0;

    for (int i = 0; i < numMetrics; ++i) {
      String[] pair = parms[i + 1].split(":");

      // Ensure the metric name is capitalized before lookup.
      String name = pair[0].toUpperCase();

      int index = Predictions.lookupMode(name);
      if (index == -1) {
        System.err.println("metric "+name+" was not loaded before being included in a MultiMetric");
        return false;
      }

      double weight = 1.0;
      if (pair.length >= 2) {
        try {
          weight = Double.parseDouble(pair[1]);
        } catch (NumberFormatException e) {
          System.err.println("invalid weight '" + pair[1] + "' for metric " + name
                             + " in MultiMetric");
          return false;
        }
      }

      newComponents[i] = index;
      newWeights[i] = weight;
      newTotalWeight += weight;
    }

    // Everything parsed: commit the new configuration.
    label = parms[0].toUpperCase();
    components = newComponents;
    weights = newWeights;
    totalWeight = newTotalWeight;
    return true;
  }

  /**
   * Drop any cached score for the given predictions.
   * @param pred - Predictions whose cached score should be discarded.
   */
  public void invalidateCache(Predictions pred)
  {
    cache.remove(pred);
  }

  /**
   * @param i - Ignored; this bundle holds a single metric.
   * @return The upper-cased label given to init, or null before a
   * successful init.
   */
  public String name(int i)
  {
    return label;
  }

  /**
   * Compute the weighted average of the component metric losses.
   * @param i - Ignored; this bundle holds a single metric.
   * @param pred - Predictions to score.
   * @return Sum of weight * component loss, divided by the total weight;
   * 0 if the total weight is below EPS.
   * @throws IllegalStateException if init has not succeeded.
   */
  public double performance(int i, Predictions pred)
  {
    requireInitialized();
    // Guard against dividing by a (near-)zero total weight.
    if (totalWeight < EPS) {
      return 0.0;
    }

    // Some of the components may be expensive to compute and not
    // individually cached.  First check the cache to see if we can
    // skip the computation.
    double[] scores = cache.get(pred);

    if (scores == null) {
      double perf = 0.0;
      for (int c = 0; c < components.length; ++c) {
        int metric = components[c];
        // Convert each component's raw score to a loss so metrics with
        // opposite orientations can be summed.
        perf += weights[c] * pred.getLoss(metric, pred.compute(metric));
      }

      scores = new double[] { perf / totalWeight };
      cache.put(pred, scores);
    }

    return scores[0];
  }

  /**
   * performance() already reports a loss, so it is returned unchanged.
   */
  public double loss(int i, Predictions pred, double perf)
  {
    return perf;
  }

  /**
   * @return true; the combined score is a loss.
   */
  public boolean smallerIsBetter(int i)
  {
    return true;
  }

  /**
   * @return true iff any component metric disallows reordering.
   * @throws IllegalStateException if init has not succeeded.
   */
  public boolean requiresStrictOrder(int i)
  {
    requireInitialized();
    // If any component metric requires strict order, then the
    // multi-metric must require strict ordering.
    boolean reorderOkay = true;
    for (int c = 0; reorderOkay && c < components.length; ++c) {
      reorderOkay = Predictions.reorderingAllowed(components[c]);
    }
    return !reorderOkay;
  }

  /**
   * @return true iff any component metric is threshold sensitive.
   * @throws IllegalStateException if init has not succeeded.
   */
  public boolean thresholdSensitive(int i)
  {
    requireInitialized();
    // If any component metric is sensitive, then the multi-metric is.
    boolean sensitive = false;
    for (int c = 0; !sensitive && c < components.length; ++c) {
      sensitive = Predictions.thresholdSensitive(components[c]);
    }
    return sensitive;
  }

  // Fail with a clear message instead of a NullPointerException when the
  // metric is used before a successful init.
  private void requireInitialized()
  {
    if (components == null || weights == null) {
      throw new IllegalStateException("MultiMetric used before successful init");
    }
  }
}

