package com.engisis.sysphs.translation.modelica;

import java.util.ArrayList;
import java.util.Hashtable;
import java.util.List;

import com.engisis.sysphs.generation.modelica.ModelicaBaseListener;
import com.engisis.sysphs.generation.modelica.ModelicaParser;
import com.engisis.sysphs.generation.modelica.ModelicaParser.EquationContext;
import com.engisis.sysphs.generation.modelica.ModelicaParser.Add_opContext;
import com.engisis.sysphs.generation.modelica.ModelicaParser.Arithmetic_expressionContext;
import com.engisis.sysphs.generation.modelica.ModelicaParser.Array_subscriptsContext;
import com.engisis.sysphs.generation.modelica.ModelicaParser.Component_referenceContext;
import com.engisis.sysphs.generation.modelica.ModelicaParser.ExpressionContext;
import com.engisis.sysphs.generation.modelica.ModelicaParser.Expression_listContext;
import com.engisis.sysphs.generation.modelica.ModelicaParser.FactorContext;
import com.engisis.sysphs.generation.modelica.ModelicaParser.For_indexContext;
import com.engisis.sysphs.generation.modelica.ModelicaParser.For_indicesContext;
import com.engisis.sysphs.generation.modelica.ModelicaParser.For_statementContext;
import com.engisis.sysphs.generation.modelica.ModelicaParser.Function_argumentContext;
import com.engisis.sysphs.generation.modelica.ModelicaParser.Function_argumentsContext;
import com.engisis.sysphs.generation.modelica.ModelicaParser.Function_call_argsContext;
import com.engisis.sysphs.generation.modelica.ModelicaParser.If_equationContext;
import com.engisis.sysphs.generation.modelica.ModelicaParser.If_statementContext;
import com.engisis.sysphs.generation.modelica.ModelicaParser.Logical_expressionContext;
import com.engisis.sysphs.generation.modelica.ModelicaParser.Logical_factorContext;
import com.engisis.sysphs.generation.modelica.ModelicaParser.Logical_termContext;
import com.engisis.sysphs.generation.modelica.ModelicaParser.Mul_opContext;
import com.engisis.sysphs.generation.modelica.ModelicaParser.Name_pathContext;
import com.engisis.sysphs.generation.modelica.ModelicaParser.Named_argumentContext;
import com.engisis.sysphs.generation.modelica.ModelicaParser.Named_argumentsContext;
import com.engisis.sysphs.generation.modelica.ModelicaParser.Output_expression_listContext;
import com.engisis.sysphs.generation.modelica.ModelicaParser.PrimaryContext;
import com.engisis.sysphs.generation.modelica.ModelicaParser.Rel_opContext;
import com.engisis.sysphs.generation.modelica.ModelicaParser.RelationContext;
import com.engisis.sysphs.generation.modelica.ModelicaParser.Simple_expressionContext;
import com.engisis.sysphs.generation.modelica.ModelicaParser.StatementContext;
import com.engisis.sysphs.generation.modelica.ModelicaParser.SubscriptContext;
import com.engisis.sysphs.generation.modelica.ModelicaParser.TermContext;
import com.engisis.sysphs.generation.modelica.ModelicaParser.While_statementContext;

import org.antlr.v4.runtime.tree.ParseTree;
import org.antlr.v4.runtime.tree.ParseTreeProperty;
import org.antlr.v4.runtime.tree.TerminalNode;
import org.apache.log4j.Logger;

/**
 * Parse-tree listener that converts Modelica expressions, equations and
 * algorithm statements into the SysPhS expression language.
 * <p>
 * The translated text of each visited node is stored and can be read back with
 * {@link #getValue(ParseTree)}. The listener also tracks component-reference
 * paths so that a top-level equation whose two sides are plain component
 * references can be reported as a binding ({@link #getLBinding()},
 * {@link #getRBinding()}).
 * 
 * @author barbau
 *
 */
public class ModelicaExpressionExtractor extends ModelicaBaseListener
{
    private static final Logger log = Logger.getLogger(ModelicaExpressionExtractor.class);
    
    /**
     * Suffix of an if-condition that tests whether a state is active
     */
    private static final String ACTIVE_SUFFIX = ".active";
    
    /**
     * Translated text of each visited syntax tree node
     */
    private final ParseTreeProperty<String> ptp = new ParseTreeProperty<String>();
    /**
     * Map with property paths as keys, and constraint parameter names as
     * values; null until {@link #prepareForConstraint()} is called
     */
    private Hashtable<List<String>, String> parameters;
    /**
     * Component-reference path ({@code List<String>}) attached to a node when
     * that node reduces to a single, unmodified component reference
     */
    private final ParseTreeProperty<Object> values = new ParseTreeProperty<Object>();
    /**
     * LHS bound element
     */
    private List<String> lbind;
    /**
     * RHS bound element
     */
    private List<String> rbind;
    /**
     * Substitution table: component-reference text to replacement text
     */
    private Hashtable<String, String> substitutions;
    /**
     * Actions associated with each active state
     */
    private Hashtable<String, String> stateactions;
    
    /**
     * Returns the left-hand side of the equation if that equation is a binding
     * 
     * @return path to the LHS binding, or null if the last top-level equation
     *         was not a binding
     */
    public List<String> getLBinding()
    {
        return lbind;
    }
    
    /**
     * Returns the right-hand side of the equation if that equation is a binding
     * 
     * @return path to the RHS binding, or null if the last top-level equation
     *         was not a binding
     */
    public List<String> getRBinding()
    {
        return rbind;
    }
    
    /**
     * Get constraint parameter table (path -> constraint)
     * 
     * @return parameters table of parameters, with paths as keys, and parameter
     *         names as values; null if {@link #prepareForConstraint()} was never
     *         called
     */
    public Hashtable<List<String>, String> getParameters()
    {
        return parameters;
    }
    
    /**
     * Sets the map with state names as keys, and actions as values
     * 
     * @param stateactions map with state names as keys, and actions as values
     */
    public void setStateActions(Hashtable<String, String> stateactions)
    {
        this.stateactions = stateactions;
    }
    
    /**
     * Sets the substitution table
     * 
     * @param substitutions substitution table
     */
    public void setSubstitutions(Hashtable<String, String> substitutions)
    {
        this.substitutions = substitutions;
    }
    
    /**
     * Retrieves the translated text for a given syntax tree node
     * 
     * @param pt syntax tree node (may be null, in which case null is returned)
     * @return translated text, or null if none was recorded (an error is
     *         logged)
     */
    public String getValue(ParseTree pt)
    {
        if (pt == null)
        {
            log.error("Cannot get the value of a null syntax tree node");
            return null;
        }
        String v = ptp.get(pt);
        if (v == null)
            log.error("The value of " + pt.getClass() + " is null");
        return v;
    }
    
    /**
     * Sets the translated text for the given syntax tree node
     * 
     * @param pt syntax tree node
     * @param str translated text
     */
    public void setValue(ParseTree pt, String str)
    {
        ptp.put(pt, str);
    }
    
    /**
     * Starts a new constraint by creating an empty parameter table. From then
     * on, component references are translated into parameter names and
     * registered in the table returned by {@link #getParameters()}.
     */
    public void prepareForConstraint()
    {
        parameters = new Hashtable<List<String>, String>();
    }
    
    @Override
    public void exitEquation(EquationContext ctx)
    {
        processStraight(ctx);
        
        // nested equations (e.g. inside an if-equation) exit before the
        // enclosing one, so resetting here ensures only top-level equations are
        // bound
        lbind = null;
        rbind = null;
        
        if (ctx.simple_expression() == null || ctx.expression() == null)
            return;
        Object lb = values.get(ctx.simple_expression());
        Object rb = values.get(ctx.expression());
        if (lb instanceof List && rb instanceof List)
        {
            // the values property only ever holds List<String> paths
            @SuppressWarnings("unchecked")
            List<String> lpath = (List<String>) lb;
            @SuppressWarnings("unchecked")
            List<String> rpath = (List<String>) rb;
            lbind = lpath;
            rbind = rpath;
        }
    }
    
    @Override
    public void exitStatement(StatementContext ctx)
    {
        processStraight(ctx);
    }
    
    @Override
    public void exitIf_equation(If_equationContext ctx)
    {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < ctx.getChildCount(); i++)
        {
            ParseTree pt = ctx.getChild(i);
            if (pt instanceof TerminalNode)
            {
                String keyword = ifKeyword(((TerminalNode) pt).getSymbol().getType());
                if (keyword != null)
                    sb.append(keyword);
            }
            else if (pt instanceof ExpressionContext)
                sb.append(" " + getValue(pt) + " ");
            else if (pt instanceof EquationContext)
                sb.append(getValue(pt) + ";\n");
        }
        setValue(ctx, sb.toString());
    }
    
    /**
     * Translates an if-statement. When a branch condition has the form
     * {@code <state>.active}, the statements of that branch are also recorded
     * as the actions of {@code <state>} in the state-actions table (if one was
     * set).
     */
    @Override
    public void exitIf_statement(If_statementContext ctx)
    {
        StringBuilder sb = new StringBuilder();
        
        // state tested by the current branch condition, if any
        String currentstate = null;
        // actions collected for currentstate
        StringBuilder actions = new StringBuilder();
        
        for (int i = 0; i < ctx.getChildCount(); i++)
        {
            ParseTree pt = ctx.getChild(i);
            if (pt instanceof TerminalNode)
            {
                int t = ((TerminalNode) pt).getSymbol().getType();
                // a new branch (or the end) closes the current state's actions
                if (t == ModelicaParser.ELSEIF || t == ModelicaParser.ELSE || t == ModelicaParser.END)
                {
                    flushStateActions(currentstate, actions);
                    currentstate = null;
                }
                String keyword = ifKeyword(t);
                if (keyword != null)
                    sb.append(keyword);
            }
            else if (pt instanceof ExpressionContext)
            {
                String v = getValue(pt);
                if (v != null && v.endsWith(ACTIVE_SUFFIX))
                    currentstate = v.substring(0, v.length() - ACTIVE_SUFFIX.length());
                sb.append(" " + v + " ");
            }
            else if (pt instanceof StatementContext)
            {
                String v = getValue(pt);
                if (currentstate != null)
                    actions.append(v + ";");
                sb.append(v + ";\n");
            }
        }
        setValue(ctx, sb.toString());
    }
    
    /**
     * Maps the keyword tokens of if-equations and if-statements to their
     * translated text.
     * 
     * @param type token type
     * @return translated keyword, or null if the token is not emitted
     */
    private static String ifKeyword(int type)
    {
        if (type == ModelicaParser.IF)
            return "if";
        if (type == ModelicaParser.THEN)
            return "then\n";
        if (type == ModelicaParser.ELSEIF)
            return "elseif";
        if (type == ModelicaParser.ELSE)
            return "else\n";
        if (type == ModelicaParser.END)
            return "end ";
        return null;
    }
    
    /**
     * Stores the collected actions of a state in the state-actions table and
     * clears the buffer.
     * 
     * @param state state name, or null if the current branch is not a state
     *            branch (nothing is done)
     * @param actions buffer of collected actions
     */
    private void flushStateActions(String state, StringBuilder actions)
    {
        if (state == null)
            return;
        if (stateactions != null)
            stateactions.put(state, actions.toString());
        actions.setLength(0);
    }
    
    @Override
    public void exitFor_statement(For_statementContext ctx)
    {
        StringBuilder sb = new StringBuilder();
        sb.append("for " + getValue(ctx.for_indices()) + " loop\n");
        for (StatementContext sc : ctx.statement())
            sb.append(getValue(sc) + ";\n");
        sb.append("end for");
        setValue(ctx, sb.toString());
    }
    
    @Override
    public void exitFor_indices(For_indicesContext ctx)
    {
        processStraight(ctx);
    }
    
    @Override
    public void exitFor_index(For_indexContext ctx)
    {
        // keep the identifier, 'in' and the range apart ("i in 1:n", not "iin1:n")
        processWithSpaces(ctx);
    }
    
    @Override
    public void exitWhile_statement(While_statementContext ctx)
    {
        StringBuilder sb = new StringBuilder();
        // newline after 'loop' so the first statement is not glued to the keyword
        sb.append("while " + getValue(ctx.expression()) + " loop\n");
        for (StatementContext sc : ctx.statement())
            sb.append(getValue(sc) + ";\n");
        sb.append("end while");
        setValue(ctx, sb.toString());
    }
    
    @Override
    public void exitExpression(ExpressionContext ctx)
    {
        processWithSpaces(ctx);
        if (ctx.simple_expression() != null)
            values.put(ctx, values.get(ctx.simple_expression()));
    }
    
    @Override
    public void exitSimple_expression(Simple_expressionContext ctx)
    {
        processStraight(ctx);
        if (ctx.logical_expression().size() == 1)
            values.put(ctx, values.get(ctx.logical_expression(0)));
    }
    
    @Override
    public void exitLogical_expression(Logical_expressionContext ctx)
    {
        List<Logical_termContext> terms = ctx.logical_term();
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < terms.size(); i++)
        {
            if (i > 0)
                sb.append(" " + getValue(ctx.OR(i - 1)) + " ");
            sb.append(getValue(terms.get(i)));
        }
        setValue(ctx, sb.toString());
        if (terms.size() == 1)
            values.put(ctx, values.get(terms.get(0)));
    }
    
    @Override
    public void exitLogical_term(Logical_termContext ctx)
    {
        List<Logical_factorContext> factors = ctx.logical_factor();
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < factors.size(); i++)
        {
            if (i > 0)
                sb.append(" " + getValue(ctx.AND(i - 1)) + " ");
            sb.append(getValue(factors.get(i)));
        }
        setValue(ctx, sb.toString());
        if (factors.size() == 1)
            values.put(ctx, values.get(factors.get(0)));
    }
    
    @Override
    public void exitLogical_factor(Logical_factorContext ctx)
    {
        StringBuilder sb = new StringBuilder();
        if (ctx.NOT() != null)
            sb.append(getValue(ctx.NOT()) + " ");
        sb.append(getValue(ctx.relation()));
        setValue(ctx, sb.toString());
        // a negated reference is not a plain component reference
        if (ctx.NOT() == null)
            values.put(ctx, values.get(ctx.relation()));
    }
    
    @Override
    public void exitRelation(RelationContext ctx)
    {
        processStraight(ctx);
        if (ctx.arithmetic_expression().size() == 1)
            values.put(ctx, values.get(ctx.arithmetic_expression(0)));
    }
    
    @Override
    public void exitRel_op(Rel_opContext ctx)
    {
        processStraight(ctx);
    }
    
    @Override
    public void exitArithmetic_expression(Arithmetic_expressionContext ctx)
    {
        processStraight(ctx);
        if (ctx.term().size() == 1 && ctx.add_op().size() == 0)
            values.put(ctx, values.get(ctx.term(0)));
    }
    
    @Override
    public void exitAdd_op(Add_opContext ctx)
    {
        processStraight(ctx);
    }
    
    @Override
    public void exitTerm(TermContext ctx)
    {
        processStraight(ctx);
        if (ctx.factor().size() == 1)
            values.put(ctx, values.get(ctx.factor(0)));
    }
    
    @Override
    public void exitMul_op(Mul_opContext ctx)
    {
        processStraight(ctx);
    }
    
    @Override
    public void exitFactor(FactorContext ctx)
    {
        processStraight(ctx);
        // with an exponent (e.g. x^2) the factor is no longer a plain reference
        if (ctx.primary().size() == 1)
            values.put(ctx, values.get(ctx.primary(0)));
    }
    
    @Override
    public void exitPrimary(PrimaryContext ctx)
    {
        StringBuilder sb = new StringBuilder();
        if (ctx.UNSIGNED_NUMBER() != null)
            sb.append(ctx.getText());
        else if (ctx.STRING() != null)
            sb.append(ctx.STRING().getText());
        else if (ctx.TRUE() != null)
            sb.append("true");
        else if (ctx.FALSE() != null)
            sb.append("false");
        else if (ctx.function_call_args() != null)
        {
            if (ctx.name_path() != null)
                sb.append(getValue(ctx.name_path()));
            else if (ctx.DER() != null)
                sb.append("der");
            else
            {
                // other call heads: reuse the text of the leading child
                ParseTree head = ctx.getChild(0);
                if (head != ctx.function_call_args())
                    sb.append(getValue(head));
            }
            sb.append(getValue(ctx.function_call_args()));
        }
        else if (ctx.component_reference() != null)
        {
            sb.append(getValue(ctx.component_reference()));
            values.put(ctx, values.get(ctx.component_reference()));
        }
        else if (ctx.output_expression_list() != null)
            sb.append("(" + getValue(ctx.output_expression_list()) + ")");
        else if (ctx.expression_list().size() != 0)
        {
            List<Expression_listContext> rows = ctx.expression_list();
            sb.append("[");
            for (int i = 0; i < rows.size(); i++)
            {
                if (i > 0)
                    sb.append(";");
                sb.append(getValue(rows.get(i)));
            }
            sb.append("]");
        }
        else
        {
            // no dedicated translation: copy the children instead of emitting ""
            processStraight(ctx);
            return;
        }
        setValue(ctx, sb.toString());
    }
    
    @Override
    public void exitFunction_call_args(Function_call_argsContext ctx)
    {
        processStraight(ctx);
    }
    
    @Override
    public void exitName_path(Name_pathContext ctx)
    {
        processStraight(ctx);
    }
    
    /**
     * Translates a component reference. In order of precedence, it is:
     * <ul>
     * <li>{@code time}, kept as is;</li>
     * <li>replaced by its entry in the substitution table, if any;</li>
     * <li>otherwise its identifier path is recorded, and the reference is
     * either copied (no parameter table) or replaced by a parameter name,
     * which is registered in the parameter table when new.</li>
     * </ul>
     */
    @Override
    public void exitComponent_reference(Component_referenceContext ctx)
    {
        String text = ctx.getText();
        if (text.equals("time"))
        {
            setValue(ctx, "time");
            return;
        }
        if (substitutions != null)
        {
            String sub = substitutions.get(text);
            if (sub != null)
            {
                setValue(ctx, sub);
                return;
            }
        }
        // the path is built from the identifiers only
        List<String> path = new ArrayList<String>(ctx.IDENT().size());
        for (int i = 0; i < ctx.IDENT().size(); i++)
            path.add(getValue(ctx.IDENT(i)));
        values.put(ctx, path);
        
        if (parameters == null)
        {
            processStraight(ctx);
            return;
        }
        // List keys use element-wise equals/hashCode, so a direct lookup finds
        // any parameter already registered for the same path
        String name = parameters.get(path);
        if (name == null)
        {
            name = joinPath(path);
            // store a copy: 'path' is also reachable through values/bindings
            parameters.put(new ArrayList<String>(path), name);
        }
        setValue(ctx, name);
    }
    
    /**
     * Builds a parameter name by joining path segments with {@code "__"}.
     * 
     * @param path property path
     * @return joined name
     */
    private static String joinPath(List<String> path)
    {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < path.size(); i++)
        {
            if (i != 0)
                sb.append("__");
            sb.append(path.get(i));
        }
        return sb.toString();
    }
    
    @Override
    public void exitFunction_arguments(Function_argumentsContext ctx)
    {
        StringBuilder sb = new StringBuilder();
        if (ctx.named_arguments() != null)
            sb.append(getValue(ctx.named_arguments()));
        else
        {
            sb.append(getValue(ctx.function_argument()));
            if (ctx.function_arguments() != null)
                sb.append("," + getValue(ctx.function_arguments()));
        }
        setValue(ctx, sb.toString());
    }
    
    @Override
    public void exitNamed_arguments(Named_argumentsContext ctx)
    {
        processStraight(ctx);
    }
    
    @Override
    public void exitNamed_argument(Named_argumentContext ctx)
    {
        processStraight(ctx);
    }
    
    @Override
    public void exitFunction_argument(Function_argumentContext ctx)
    {
        StringBuilder sb = new StringBuilder();
        if (ctx.expression() != null)
            sb.append(getValue(ctx.expression()));
        else
        {
            sb.append("function " + getValue(ctx.IDENT()) + "(");
            // a function argument may have no named arguments, e.g. "function f()"
            if (ctx.named_arguments() != null)
                sb.append(getValue(ctx.named_arguments()));
            sb.append(")");
        }
        setValue(ctx, sb.toString());
    }
    
    @Override
    public void exitOutput_expression_list(Output_expression_listContext ctx)
    {
        processStraight(ctx);
    }
    
    @Override
    public void exitExpression_list(Expression_listContext ctx)
    {
        processStraight(ctx);
    }
    
    @Override
    public void exitArray_subscripts(Array_subscriptsContext ctx)
    {
        processStraight(ctx);
    }
    
    @Override
    public void exitSubscript(SubscriptContext ctx)
    {
        processStraight(ctx);
    }
    
    @Override
    public void visitTerminal(TerminalNode node)
    {
        setValue(node, node.getText());
    }
    
    /**
     * Concatenates the translated text of the children of the node, skipping
     * children without a value, and stores the result as the node's value.
     * 
     * @param ctx syntax tree node
     */
    protected void processStraight(ParseTree ctx)
    {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < ctx.getChildCount(); i++)
        {
            String v = getValue(ctx.getChild(i));
            if (v != null)
                sb.append(v);
        }
        setValue(ctx, sb.toString());
    }
    
    /**
     * Joins the translated text of the children of the node with single
     * spaces, and stores the result as the node's value.
     * 
     * @param ctx syntax tree node
     */
    protected void processWithSpaces(ParseTree ctx)
    {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < ctx.getChildCount(); i++)
        {
            if (i != 0)
                sb.append(" ");
            String v = getValue(ctx.getChild(i));
            if (v != null)
                sb.append(v);
        }
        setValue(ctx, sb.toString());
    }
    
}
