/*
 * Decompiled with CFR 0.152.
 */
package polyglot.visit;

import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import polyglot.ast.ArrayInit;
import polyglot.ast.Assign;
import polyglot.ast.Block;
import polyglot.ast.ConstructorCall;
import polyglot.ast.Do;
import polyglot.ast.Eval;
import polyglot.ast.Expr;
import polyglot.ast.FieldDecl;
import polyglot.ast.For;
import polyglot.ast.If;
import polyglot.ast.Lit;
import polyglot.ast.Local;
import polyglot.ast.LocalDecl;
import polyglot.ast.Node;
import polyglot.ast.NodeFactory;
import polyglot.ast.Special;
import polyglot.ast.Stmt;
import polyglot.ast.Switch;
import polyglot.ast.Term;
import polyglot.ast.TypeNode;
import polyglot.ast.Unary;
import polyglot.ast.While;
import polyglot.types.Flags;
import polyglot.types.TypeSystem;
import polyglot.util.Position;
import polyglot.visit.NodeVisitor;

/**
 * AST pass that flattens expressions: every compound expression (anything other than a literal,
 * a {@link Special}, a {@link Local} or an {@link Assign}) is evaluated into a fresh
 * compiler-generated {@code final} local variable. The declaration is appended to the innermost
 * enclosing {@link Block}'s statement list before the statement that uses it, and the expression
 * is replaced by a reference to that local.
 *
 * <p>The bodies of {@code if}, {@code do}, {@code while} and {@code for} are first wrapped in
 * blocks so that hoisted declarations have a place to go. Loop conditions and for-update
 * expressions are never flattened, because hoisting them in front of the loop would change how
 * often they are evaluated.
 */
public class FlattenVisitor extends NodeVisitor {
    /** Type system used to create the local instances of generated temporaries. */
    protected TypeSystem ts;
    /** Node factory used to build wrapper blocks, temporary declarations and references. */
    protected NodeFactory nf;
    /**
     * Statement lists being rebuilt, one per enclosing {@link Block} currently being visited.
     * The innermost block's list is at the head.
     */
    protected LinkedList<List<Stmt>> stack;
    /** Counter for fresh temporary names; static (shared by all instances) and unsynchronized. */
    protected static int count = 0;
    /** Terms that must not themselves be replaced by a temporary, although their subterms may be. */
    protected Set<Term> noFlatten = new HashSet<Term>();
    /** Terms (loop conditions and for-updates) that are returned untouched by {@link #override}. */
    protected Set<Term> neverFlatten = new HashSet<Term>();

    public FlattenVisitor(TypeSystem ts, NodeFactory nf) {
        super(nf.lang());
        this.ts = ts;
        this.nf = nf;
        this.stack = new LinkedList<List<Stmt>>();
    }

    /**
     * Pre-processes nodes before the normal traversal.
     *
     * <p>Control-flow statements get their bodies wrapped in blocks and are then visited
     * normally. Field declarations, constructor calls, switches, array initializers and
     * {@link #neverFlatten} terms are returned unvisited. Everything else yields {@code null},
     * so the usual enter/leave traversal applies.
     */
    @Override
    public Node override(Node parent, Node n) {
        if (n instanceof If) {
            If s = (If) n;
            Stmt s1 = s.consequent();
            Stmt s2 = s.alternative();
            if (!(s1 instanceof Block)) {
                s = s.consequent(this.nf.Block(s1.position(), s1));
            }
            if (s2 != null && !(s2 instanceof Block)) {
                s = s.alternative(this.nf.Block(s2.position(), s2));
            }
            return this.visitEdgeNoOverride(parent, s);
        }
        if (n instanceof Do) {
            Do s = (Do) n;
            Stmt body = s.body();
            if (!(body instanceof Block)) {
                s = s.body(this.nf.Block(body.position(), body));
            }
            return this.visitEdgeNoOverride(parent, s);
        }
        if (n instanceof While) {
            While s = (While) n;
            Stmt body = s.body();
            if (!(body instanceof Block)) {
                s = s.body(this.nf.Block(body.position(), body));
            }
            return this.visitEdgeNoOverride(parent, s);
        }
        if (n instanceof For) {
            For s = (For) n;
            Stmt body = s.body();
            if (!(body instanceof Block)) {
                s = s.body(this.nf.Block(body.position(), body));
            }
            return this.visitEdgeNoOverride(parent, s);
        }
        if (n instanceof FieldDecl || n instanceof ConstructorCall) {
            // These are kept verbatim. Because leave() will not run for them, a statement node
            // must be re-emitted into the enclosing block's list here. The instanceof check
            // guards the cast: NOTE(review): FieldDecl is presumably a class member rather than
            // a Stmt, and it can be reached with a non-empty stack inside a local class.
            if (n instanceof Stmt && !this.stack.isEmpty()) {
                this.stack.getFirst().add((Stmt) n);
            }
            return n;
        }
        if (n instanceof Switch) {
            return n;
        }
        if (this.neverFlatten.contains(n)) {
            return n;
        }
        if (n instanceof ArrayInit) {
            return n;
        }
        return null;
    }

    /** Returns a fresh name for a compiler-generated temporary. */
    protected static String newID() {
        return "flat$$$" + count++;
    }

    /**
     * Opens a new statement list for each {@link Block}. Also records which child terms must not
     * be flattened ({@link #noFlatten}) or must not be visited at all ({@link #neverFlatten}).
     */
    @Override
    public NodeVisitor enter(Node parent, Node n) {
        if (n instanceof Block) {
            this.stack.addFirst(new LinkedList<Stmt>());
        }
        if (n instanceof Eval) {
            // The statement's top-level expression stays in place; its subexpressions are hoisted.
            this.noFlatten.add(((Eval) n).expr());
        }
        if (n instanceof LocalDecl) {
            // Likewise, the initializer itself stays; its subexpressions are hoisted.
            this.noFlatten.add(((LocalDecl) n).init());
        }
        if (n instanceof For) {
            For f = (For) n;
            this.noFlatten.addAll(f.inits());
            // Conditions and updates run on every iteration and cannot be hoisted before the loop.
            this.neverFlatten.addAll(f.iters());
            this.neverFlatten.add(f.cond());
        }
        if (n instanceof While) {
            this.neverFlatten.add(((While) n).cond());
        }
        if (n instanceof Do) {
            this.neverFlatten.add(((Do) n).cond());
        }
        if (n instanceof Assign) {
            // Replacing the target with a temporary would assign to a copy instead.
            Assign a = (Assign) n;
            this.noFlatten.add(a.left());
            this.noFlatten.add(a.right());
        }
        if (n instanceof Unary) {
            // Keeps lvalue operands of ++/-- intact (applied to all unary operands).
            this.noFlatten.add(((Unary) n).expr());
        }
        return this;
    }

    /**
     * Rebuilds blocks from the collected statement lists. Appends finished statements to the
     * innermost list, and replaces flattenable expressions with references to new temporaries.
     */
    @Override
    public Node leave(Node parent, Node old, Node n, NodeVisitor v) {
        if (this.noFlatten.remove(old)) {
            return n;
        }
        if (n instanceof Block) {
            List<Stmt> stmts = this.stack.removeFirst();
            Block block = ((Block) n).statements(stmts);
            // Only a block nested directly in another block is itself a statement of that block;
            // bodies of if/loops are re-emitted by their owning statement instead.
            if (parent instanceof Block && !this.stack.isEmpty()) {
                this.stack.getFirst().add(block);
            }
            return block;
        }
        if (this.stack.isEmpty()) {
            // Outside any block there is nowhere to emit statements or temporaries.
            return n;
        }
        if (n instanceof Stmt) {
            this.stack.getFirst().add((Stmt) n);
            return n;
        }
        if (n instanceof Expr && !(n instanceof Lit) && !(n instanceof Special)
                && !(n instanceof Local) && !(n instanceof Assign)) {
            return this.hoist((Expr) n);
        }
        return n;
    }

    /**
     * Appends a {@code final} temporary initialized to {@code e} to the innermost statement list,
     * and returns a typed reference to it. Requires a non-empty {@link #stack}.
     */
    private Local hoist(Expr e) {
        Position pos = e.position();
        String name = FlattenVisitor.newID();
        LocalDecl def = this.nf.LocalDecl(pos, Flags.FINAL,
                (TypeNode) this.nf.CanonicalTypeNode(pos, e.type()),
                this.nf.Id(Position.compilerGenerated(), name), e);
        def = def.localInstance(this.ts.localInstance(pos, Flags.FINAL, e.type(), name));
        this.stack.getFirst().add(def);
        Local use = this.nf.Local(pos, this.nf.Id(Position.compilerGenerated(), name));
        use = (Local) use.type(e.type());
        use = use.localInstance(this.ts.localInstance(pos, Flags.FINAL, e.type(), name));
        return use;
    }
}

