/**
 * Visitor pattern.
 *
 * <p>A Visitor performs a different action depending on the type of node it
 * is visiting. It packages this functionality as a collection of callback
 * methods (one {@code visit} overload per concrete node type) that it
 * provides to the Nodes. When a Node is visited, it calls the matching
 * callback on the Visitor to perform the operation.
 *
 * <p>The advantage of the visitor pattern is that the code dealing with the
 * different types of Nodes is kept together in one place, the Visitor,
 * instead of being spread out among the Nodes.
 *
 * <p>The disadvantage is that Java does not dispatch on the runtime type of
 * a method argument: the {@code visit} overload is chosen from the static
 * type of the expression passed in. The trivial
 * {@code accept(Visitor<T>)} method that calls the callback must therefore
 * be duplicated in every Node, so that {@code this} has the concrete node
 * type at the call site.
 *
 * @param <T> the type of the result produced by visiting a node
 */
interface Visitor<T> {
	/** Callback for a {@link MultiplyNode}. */
	T visit(MultiplyNode node);

	/** Callback for an {@link AddNode}. */
	T visit(AddNode node);

	/** Callback for a {@link FunctionApplicationNode}. */
	T visit(FunctionApplicationNode node);

	/** Callback for a {@link ParameterNode}. */
	T visit(ParameterNode node);

	/** Callback for a {@link ConstantNode}. */
	T visit(ConstantNode node);
}

/**
 * A tree node that can be processed by a {@link Visitor}.
 */
interface Node {
	/**
	 * Hands this node to the given visitor and returns whatever the visitor
	 * produces for it.
	 *
	 * @param visitor the visitor whose callback should be invoked
	 * @param <T>     the visitor's result type
	 * @return the value returned by the visitor's callback
	 */
	<T> T accept(Visitor<T> visitor);
}

class MultiplyNode implements Node {
	Node left, right;

	public MultiplyNode(Node left, Node right) {
		this.left = left;
		this.right = right;
	}
	
   public <T> T accept(Visitor<T> visitor) {
      return visitor.visit(this);
   }
}

class AddNode implements Node {
   Node left, right;

   public AddNode(Node left, Node right) {
		this.left = left;
		this.right = right;
	}

	public <T> T accept(Visitor<T> visitor) {
		return visitor.visit(this);
	}
}

class FunctionApplicationNode implements Node {
   String name;
   Node argument;

	public FunctionApplicationNode(String name, Node argument) {
		this.name = name;
		this.argument = argument;
	}

   public <T> T accept(Visitor<T> visitor) {
      return visitor.visit(this);
   }
}

class ParameterNode implements Node {
	public <T> T accept(Visitor<T> visitor) {
		return visitor.visit(this);
	}
}

class ConstantNode implements Node {
   int constant;

   public ConstantNode(int constant) {
      this.constant = constant;
   }

   public <T> T accept(Visitor<T> visitor) {
		return visitor.visit(this);
	}
}

/**
 * Example: Pretty printing using the visitor pattern.
 */
class PrettyPrinter implements Visitor<String> {
	public String visit(MultiplyNode node) {
		return "(" + node.left.accept(this) + ")*(" + node.right.accept(this) + ")";
	}

	public String visit(AddNode node) {
		return "(" + node.left.accept(this) + ")+(" + node.right.accept(this) + ")";

	}

	public String visit(FunctionApplicationNode node) {
		return node.name + "(" + node.argument.accept(this) + ")";
	}

	public String visit(ParameterNode node) {
		return "x";
	}

	public String visit(ConstantNode node) {
		return String.valueOf(node.constant);
	}
}

/**
 * Example: Differentiation using the visitor pattern.
 */
class Differentiator implements Visitor<Node> {
	// Sum rule for derivatives
	public Node visit(AddNode node) {
		return new AddNode(node.left.accept(this), node.right.accept(this));
	}

	// Product rule for derivatives
	public Node visit(MultiplyNode node) {
		Node left_deriv = node.left.accept(this);
		Node right_deriv = node.right.accept(this);
		return new AddNode(new MultiplyNode(left_deriv, node.right),
				new MultiplyNode(node.left, right_deriv));
	}

	// Chain rule for derivatives
	public Node visit(FunctionApplicationNode node) {
		return new MultiplyNode(new FunctionApplicationNode(node.name + "'", node.argument),
				node.argument.accept(this));
	}

	// Derivative of x is 1
	public Node visit(ParameterNode node) {
		return new ConstantNode(1);
	}

	// Derivative of a constant is 0
	public Node visit(ConstantNode node) {
		return new ConstantNode(0);
	}
}

/**
 * Demo program: builds a sample expression, prints it, differentiates it and
 * prints the derivative.
 */
public class VisitorExample {
	/**
	 * Runs the demo and writes two lines to standard output.
	 *
	 * @param args command-line arguments (not used)
	 */
	public static void main(String[] args) {
		PrettyPrinter printer = new PrettyPrinter();

		// Build f(x^2) + x + 17 piece by piece.
		Node xSquared = new MultiplyNode(new ParameterNode(), new ParameterNode());
		Node fOfXSquared = new FunctionApplicationNode("f", xSquared);
		Node xPlusSeventeen = new AddNode(new ParameterNode(), new ConstantNode(17));
		Node exampleFunction = new AddNode(fOfXSquared, xPlusSeventeen);
		System.out.println(exampleFunction.accept(printer));

		// Mathematically the derivative is f'(x^2)*2x + 1. The Differentiator
		// does not simplify, so what gets printed is the unsimplified
		// equivalent of that expression.
		Node derivativeFunction = exampleFunction.accept(new Differentiator());
		System.out.println(derivativeFunction.accept(printer));
	}
}
