package carray.ast;

import polyglot.ast.ArrayInit;
import polyglot.ast.Expr;
import polyglot.ast.NewArray;
import polyglot.ast.Node;
import polyglot.ast.NodeFactory;
import polyglot.ast.TypeNode;
import polyglot.ext.jl5.ast.AmbTypeInstantiation;
import polyglot.translate.ExtensionRewriter;
import polyglot.types.SemanticException;
import polyglot.types.Type;
import polyglot.util.SerialVersionUID;
import carray.translate.CArrayRewriter;
import carray.visit.ArrayInitRemover;

public class CArrayNewArrayExt extends CArrayExt {
    private static final long serialVersionUID = SerialVersionUID.generate();

    /**
     * Tells the remover that traversal is now inside a {@code NewArray} node.
     *
     * @param v the visitor about to descend into this node
     * @return the visitor returned by {@code v.inNewArray(true)}
     */
    @Override
    public ArrayInitRemover removeArrayInitEnter(ArrayInitRemover v) {
        return v.inNewArray(true);
    }

    /**
     * Rewrites this array-creation expression when the rewriter is translating
     * CArray code. The form {@code new X[]{ ... }} becomes
     * {@code Array.init(new X[]{ ... })}. When the element type is primitive,
     * the call target is {@code "Array_" + elementType} instead of {@code Array}.
     * When CArray translation is off, the superclass rewrite is used unchanged.
     *
     * @param rw the active rewriter; expected to be a {@link CArrayRewriter}
     * @return the rewritten node
     * @throws SemanticException if a delegated rewrite or type translation fails
     */
    @Override
    public Node extRewrite(ExtensionRewriter rw) throws SemanticException {
        CArrayRewriter carrayRw = (CArrayRewriter) rw;
        if (!carrayRw.translcateCArray()) {
            return super.extRewrite(rw);
        }

        // Read the array and element types from the original, type-checked node
        // before the superclass rewrite produces the translated copy.
        NewArray original = (NewArray) node();
        Type arrayType = original.type();
        Type elemType = arrayType.toArray().base();
        NodeFactory factory = rw.nodeFactory();

        NewArray rewritten = (NewArray) super.extRewrite(rw);
        ArrayInit elems = rewritten.init();

        TypeNode elemTypeNode = rw.typeToJava(elemType, elemType.position());
        Expr arrayExpr =
                buildArrayExpr(rw, factory, arrayType, rewritten, elemTypeNode, elems);

        // new X[]{ ... } --> Array.init(new X[]{ ... })
        String template = elemType.isPrimitive()
                ? "Array_" + elemType + ".init(%E)"
                : "Array.init(%E)";
        return rw.qq().parseExpr(template, arrayExpr);
    }

    /**
     * Builds the one-dimensional array-creation expression passed to
     * {@code init}. If the translated element type node is an
     * {@link AmbTypeInstantiation}, the array is created with its base type
     * and then cast to an array of the instantiated type. Presumably this
     * avoids a generic array creation, which Java rejects.
     *
     * @param rw the active rewriter, used for quasi-quoting the cast
     * @param nf node factory used to build the new nodes
     * @param arrayType the original array type (supplies the cast's position)
     * @param rewritten the already-rewritten {@code NewArray} (supplies the position)
     * @param elemTypeNode the element type translated to a Java type node
     * @param elems the initializer taken from {@code rewritten}
     * @return the array-creation expression, possibly wrapped in a cast
     * @throws SemanticException propagated from expression construction
     */
    private static Expr buildArrayExpr(ExtensionRewriter rw,
                                       NodeFactory nf,
                                       Type arrayType,
                                       NewArray rewritten,
                                       TypeNode elemTypeNode,
                                       ArrayInit elems) throws SemanticException {
        if (!(elemTypeNode instanceof AmbTypeInstantiation)) {
            return nf.NewArray(rewritten.position(), elemTypeNode, 1, elems);
        }
        return rw.qq().parseExpr("(%T) %E",
                                 nf.ArrayTypeNode(arrayType.position(), elemTypeNode),
                                 nf.NewArray(rewritten.position(),
                                             ((AmbTypeInstantiation) elemTypeNode).base(),
                                             1,
                                             elems));
    }
}
