package com.engisis.sysphs.translation.simulink;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Hashtable;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;

import org.antlr.v4.runtime.tree.ParseTree;
import org.antlr.v4.runtime.tree.TerminalNode;
import org.apache.log4j.Logger;
import org.eclipse.emf.common.util.EList;

import com.engisis.sysphs.language.simscape.SComponent;
import com.engisis.sysphs.language.simscape.SComponentReference;
import com.engisis.sysphs.language.simscape.SDomain;
import com.engisis.sysphs.language.simscape.SInput;
import com.engisis.sysphs.language.simscape.SMember;
import com.engisis.sysphs.language.simscape.SNode;
import com.engisis.sysphs.language.simscape.SOutput;
import com.engisis.sysphs.language.simscape.SParameter;
import com.engisis.sysphs.language.simscape.SVariable;
import com.engisis.sysphs.language.simulink.SFContinuousStateVariable;
import com.engisis.sysphs.language.simulink.SFDWorkAssignment;
import com.engisis.sysphs.language.simulink.SFDWorkVariable;
import com.engisis.sysphs.language.simulink.SFDerivativeAssignment;
import com.engisis.sysphs.language.simulink.SFDiscreteStateVariable;
import com.engisis.sysphs.language.simulink.SFInputVariable;
import com.engisis.sysphs.language.simulink.SFOutputAssignment;
import com.engisis.sysphs.language.simulink.SFOutputVariable;
import com.engisis.sysphs.language.simulink.SFUpdateAssignment;
import com.engisis.sysphs.language.simulink.SFVariable;
import com.engisis.sysphs.language.simulink.SFVariableAssignment;
import com.engisis.sysphs.language.simulink.SFunction;
import com.engisis.sysphs.language.simulink.SFunction1;
import com.engisis.sysphs.language.simulink.SFunction2;
import com.engisis.sysphs.language.simulink.SimulinkFactory;
import com.engisis.sysphs.util.ExpressionLanguageToSimulation;
import com.engisis.sysphs.util.ExpressionLanguageParser;
import com.engisis.sysphs.util.ExpressionLanguageParser.Component_referenceContext;
import com.engisis.sysphs.util.ExpressionLanguageParser.EquationContext;
import com.engisis.sysphs.util.ExpressionLanguageParser.ExpressionContext;
import com.engisis.sysphs.util.ExpressionLanguageParser.For_indexContext;
import com.engisis.sysphs.util.ExpressionLanguageParser.For_indicesContext;
import com.engisis.sysphs.util.ExpressionLanguageParser.For_statementContext;
import com.engisis.sysphs.util.ExpressionLanguageParser.If_equationContext;
import com.engisis.sysphs.util.ExpressionLanguageParser.If_statementContext;
import com.engisis.sysphs.util.ExpressionLanguageParser.PrimaryContext;
import com.engisis.sysphs.util.ExpressionLanguageParser.Simple_expressionContext;
import com.engisis.sysphs.util.ExpressionLanguageParser.StatementContext;
import com.engisis.sysphs.util.ExpressionLanguageParser.While_statementContext;

/**
 * Translator from the SysPhS expression language to MATLAB
 * 
 * @author barbau
 *
 */
public class ExpressionLanguageToMatlabTranslator extends ExpressionLanguageToSimulation
{
    private static final Logger log = Logger.getLogger(ExpressionLanguageToMatlabTranslator.class);
    
    /**
     * Names of functions whose arguments receive the "unit fix" in Simscape
     * mode: references to members with a unit are divided by that unit inside
     * the call (presumably so the argument is dimensionless).
     */
    private static final Set<String> UNIT_FIX_FUNCTIONS = Collections.unmodifiableSet(new HashSet<String>(
            Arrays.asList("sin", "cos", "tan", "asin", "acos", "atan", "atan2", "sinh", "cosh", "tanh", "log",
                    "log10", "exp")));
    
    /**
     * current component (Simscape mode)
     */
    private SComponent scomponent;
    /**
     * current S-Function (S-Function mode)
     */
    private SFunction sfunction;
    /**
     * variable assignment found in the last parsed equation
     */
    private SFVariableAssignment sfassignment;
    /**
     * outermost unit-sensitive function call being visited; non-null if the
     * "unit fix" should be performed
     */
    private PrimaryContext convertContext;
    
    /**
     * non-null while visiting the right-hand side expression of an equation
     */
    private ExpressionContext ec = null;
    
    /**
     * substitution table used for Simscape equations
     */
    private Hashtable<String, String> ht = new Hashtable<String, String>();
    /**
     * substitution table used in the left-hand side of equations (for
     * S-Functions)
     */
    private Hashtable<String, String> htLHS = new Hashtable<String, String>();
    /**
     * substitution table used in the right-hand side of equations (for
     * S-Functions)
     */
    private Hashtable<String, String> htRHS = new Hashtable<String, String>();
    
    /**
     * list of input variables, used to identify inputs in output assignments
     * and set direct feedthrough
     */
    private List<SFInputVariable> lsfiv;
    
    /**
     * Identifiers (primary texts, before substitution) encountered in
     * right-hand sides since the last S-Function preparation
     */
    private List<String> lsp;
    
    /**
     * Resets the per-parse state (unit-fix context and found assignment).
     */
    @Override
    protected void beforeParsing()
    {
        super.beforeParsing();
        convertContext = null;
        sfassignment = null;
    }
    
    /**
     * Releases the mode-specific state (component, S-Function, substitution
     * tables, input list, identifier list).
     */
    @Override
    protected void afterParsing()
    {
        super.afterParsing();
        scomponent = null;
        ht = null;
        sfunction = null;
        htLHS = null;
        htRHS = null;
        lsfiv = null;
        lsp = null;
    }
    
    /**
     * Initiates the translator for Simscape component
     * 
     * @param scomponent
     *            Simscape component
     * @param ht
     *            substitution table
     */
    public void prepareNextEquationParsing(SComponent scomponent, Hashtable<String, String> ht)
    {
        this.scomponent = scomponent;
        this.ht = ht;
    }
    
    /**
     * Initiates the translator for S-Functions
     * 
     * @param sfunction
     *            S-Function
     * @param htLHS
     *            LHS substitution table
     * @param htRHS
     *            RHS substitution table
     * @param lsfiv
     *            input variables whose direct feedthrough flag may be set
     */
    public void prepareNextEquationParsing(SFunction sfunction, Hashtable<String, String> htLHS,
            Hashtable<String, String> htRHS, List<SFInputVariable> lsfiv)
    {
        this.sfunction = sfunction;
        this.htLHS = htLHS;
        this.htRHS = htRHS;
        this.lsfiv = lsfiv;
        this.lsp = new LinkedList<String>();
    }
    
    /**
     * Returns the found variable assignment
     * 
     * @return found variable assignment, or null if none was identified
     */
    public SFVariableAssignment getAssignment()
    {
        return sfassignment;
    }
    
    /**
     * For "lhs = rhs" equations, registers the RHS substitution of
     * {@code time} with the current-time expression of the S-Function kind.
     */
    @Override
    public void enterEquation(EquationContext ctx)
    {
        if (ctx.EQUALS() != null && htRHS != null)
        {
            if (sfunction instanceof SFunction1)
                htRHS.put("time", "t");
            else if (sfunction instanceof SFunction2)
                htRHS.put("time", "block.CurrentTime");
        }
        super.enterEquation(ctx);
    }
    
    /**
     * In S-Function mode, emits a MATLAB assignment for "lhs = rhs"
     * equations. The left-hand side is substituted using {@link #htLHS}, and
     * the assigned S-Function variable is recorded (see
     * {@link #getAssignment()}). A DWork variable with two dimensions d0, d1
     * is assigned {@code reshape(rhs,d0*d1,[])}. Other equations are delegated
     * to the superclass.
     */
    @Override
    public void exitEquation(EquationContext ctx)
    {
        if (sfunction == null || ctx.EQUALS() == null)
        {
            super.exitEquation(ctx);
            return;
        }
        
        if (htLHS != null)
        {
            String key = getValue(ctx.simple_expression());
            String val = htLHS.get(key);
            if (val != null)
            {
                setValue(ctx.simple_expression(), val);
                resolveAssignment(key);
            }
        }
        
        String lhs = getValue(ctx.simple_expression());
        String rhs = getValue(ctx.expression());
        if (sfassignment instanceof SFDWorkAssignment)
        {
            EList<Integer> dims = ((SFDWorkAssignment) sfassignment).getVariable().getDimensions();
            if (dims.size() == 2)
                rhs = "reshape(" + rhs + "," + (dims.get(0).intValue() * dims.get(1).intValue()) + ",[])";
        }
        setValue(ctx, lhs + "=" + rhs);
        // NOTE(review): this clears the S-Function state, so a later equation
        // in the same parse (e.g. inside an if-equation) is not substituted;
        // presumably callers parse one equation at a time - confirm.
        afterParsing();
    }
    
    /**
     * Records in {@link #sfassignment} the S-Function variable assigned by an
     * equation whose unsubstituted left-hand side is {@code key}. Every
     * variable is checked, so the last match wins. Assigning an output
     * variable flags the inputs seen in right-hand sides as direct
     * feedthrough.
     * 
     * @param key
     *            left-hand side before substitution
     */
    private void resolveAssignment(String key)
    {
        for (SFVariable sfv : sfunction.getVariables())
        {
            String name = sfv.getName();
            if (sfv instanceof SFContinuousStateVariable)
            {
                // both der(x) and x on the LHS yield a derivative assignment
                if (key.equals("der(" + name + ")") || key.equals(name))
                {
                    SFDerivativeAssignment assignment = SimulinkFactory.eINSTANCE.createSFDerivativeAssignment();
                    assignment.setVariable((SFContinuousStateVariable) sfv);
                    sfassignment = assignment;
                }
            }
            else if (sfv instanceof SFDiscreteStateVariable && key.equals("next(" + name + ")"))
            {
                SFUpdateAssignment assignment = SimulinkFactory.eINSTANCE.createSFUpdateAssignment();
                assignment.setVariable((SFDiscreteStateVariable) sfv);
                sfassignment = assignment;
            }
            else if (sfv instanceof SFOutputVariable && key.equals(name))
            {
                SFOutputAssignment assignment = SimulinkFactory.eINSTANCE.createSFOutputAssignment();
                assignment.setVariable((SFOutputVariable) sfv);
                sfassignment = assignment;
                markDirectFeedthrough();
            }
            else if (sfv instanceof SFDWorkVariable && key.equals(name))
            {
                SFDWorkAssignment assignment = SimulinkFactory.eINSTANCE.createSFDWorkAssignment();
                assignment.setVariable((SFDWorkVariable) sfv);
                sfassignment = assignment;
            }
        }
    }
    
    /**
     * Sets the direct feedthrough flag of every input variable whose name was
     * encountered in a right-hand side (see {@link #lsp}).
     */
    private void markDirectFeedthrough()
    {
        if (lsp == null || lsfiv == null)
            return;
        for (SFInputVariable sfiv : lsfiv)
            if (lsp.contains(sfiv.getName()))
                sfiv.setDirectFeedthrough(true);
    }
    
    /**
     * Marks the start of an equation's right-hand side expression.
     */
    @Override
    public void enterExpression(ExpressionContext ctx)
    {
        if (ctx.getParent() instanceof EquationContext)
            ec = ctx;
        super.enterExpression(ctx);
    }
    
    // Rewriting language
    
    /**
     * Translates an if-equation into a MATLAB if block.
     */
    @Override
    public void exitIf_equation(If_equationContext ctx)
    {
        setValue(ctx, translateIfBlock(ctx, EquationContext.class));
    }
    
    /**
     * Translates an if-statement into a MATLAB if block.
     */
    @Override
    public void exitIf_statement(If_statementContext ctx)
    {
        setValue(ctx, translateIfBlock(ctx, StatementContext.class));
    }
    
    /**
     * Builds a MATLAB if/elseif/else block from the children of an if-equation
     * or if-statement. Each condition follows its keyword on the same line.
     * Each body child of type {@code bodyType} is emitted on its own line,
     * terminated by ";". Children after the first END token (such as the
     * trailing "if") are ignored.
     * 
     * @param ctx
     *            if-equation or if-statement context
     * @param bodyType
     *            type of the body elements
     * @return MATLAB code
     */
    private String translateIfBlock(ParseTree ctx, Class<? extends ParseTree> bodyType)
    {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < ctx.getChildCount(); i++)
        {
            ParseTree pt = ctx.getChild(i);
            if (pt instanceof TerminalNode)
            {
                int t = ((TerminalNode) pt).getSymbol().getType();
                if (t == ExpressionLanguageParser.IF)
                    sb.append("if");
                else if (t == ExpressionLanguageParser.ELSEIF)
                    sb.append("elseif");
                else if (t == ExpressionLanguageParser.ELSE)
                    sb.append("else\n");
                else if (t == ExpressionLanguageParser.END)
                {
                    sb.append("end");
                    break;
                }
            }
            else if (pt instanceof ExpressionContext)
                sb.append(" ").append(getValue(pt)).append("\n");
            else if (bodyType.isInstance(pt))
                sb.append(getValue(pt)).append(";\n");
        }
        return sb.toString();
    }
    
    /**
     * Translates a for-statement. Several iterators produce nested MATLAB
     * loops, with the first iterator outermost. This matches the expression
     * language, where several iterators are shorthand for nested loops.
     */
    @Override
    public void exitFor_statement(For_statementContext ctx)
    {
        List<For_indexContext> indices = ctx.for_indices().for_index();
        StringBuilder sb = new StringBuilder();
        for (For_indexContext fic : indices)
            sb.append("for ").append(getValue(fic)).append("\n");
        for (StatementContext sc : ctx.statement())
            sb.append(getValue(sc)).append(";\n");
        for (int i = 0; i < indices.size(); i++)
        {
            if (i > 0)
                sb.append("\n");
            sb.append("end");
        }
        setValue(ctx, sb.toString());
    }
    
    /**
     * Processes for-indices with the superclass's default handling.
     */
    @Override
    public void exitFor_indices(For_indicesContext ctx)
    {
        processStraight(ctx);
    }
    
    /**
     * Translates a while-statement into a MATLAB while loop. The condition is
     * terminated by a newline so that the first statement does not run into
     * it.
     */
    @Override
    public void exitWhile_statement(While_statementContext ctx)
    {
        StringBuilder sb = new StringBuilder();
        sb.append("while ").append(getValue(ctx.expression())).append("\n");
        for (StatementContext sc : ctx.statement())
            sb.append(getValue(sc)).append(";\n");
        sb.append("end");
        setValue(ctx, sb.toString());
    }
    
    /**
     * Translates an expression: either its simple expression, or an
     * if-expression with its keywords spelled out.
     */
    @Override
    public void exitExpression(ExpressionContext ctx)
    {
        if (ctx.getParent() instanceof EquationContext)
            ec = null;
        StringBuilder sb = new StringBuilder();
        if (ctx.simple_expression() != null)
            sb.append(getValue(ctx.simple_expression()));
        else
        {
            // NOTE(review): emits "if c then a else b", which is not MATLAB
            // syntax; confirm how conditional expressions are consumed
            // downstream.
            for (int i = 0; i < ctx.getChildCount(); i++)
            {
                ParseTree pt = ctx.getChild(i);
                if (pt instanceof TerminalNode)
                {
                    int t = ((TerminalNode) pt).getSymbol().getType();
                    if (t == ExpressionLanguageParser.IF)
                        sb.append("if ");
                    else if (t == ExpressionLanguageParser.THEN)
                        sb.append(" then ");
                    else if (t == ExpressionLanguageParser.ELSEIF)
                        sb.append(" elseif ");
                    else if (t == ExpressionLanguageParser.ELSE)
                        sb.append(" else ");
                }
                else if (pt instanceof ExpressionContext)
                    sb.append(getValue(pt));
            }
        }
        setValue(ctx, sb.toString());
    }
    
    /**
     * In Simscape mode, starts the "unit fix" when entering a call to a
     * unit-sensitive function. Only the outermost such call is tracked, so
     * that exiting a nested call does not end the fix for the rest of the
     * outer call's arguments.
     */
    @Override
    public void enterPrimary(PrimaryContext ctx)
    {
        super.enterPrimary(ctx);
        if (convertContext == null && scomponent != null && ctx.function_call_args() != null
                && ctx.name_path() != null && UNIT_FIX_FUNCTIONS.contains(ctx.name_path().getText()))
        {
            convertContext = ctx;
        }
    }
    
    /**
     * Ends the "unit fix" when leaving the tracked call. While inside an
     * equation's right-hand side, it also applies the RHS substitution and
     * records the primary's text in {@link #lsp}.
     */
    @Override
    public void exitPrimary(PrimaryContext ctx)
    {
        super.exitPrimary(ctx);
        if (ctx == convertContext)
            convertContext = null;
        String key = getValue(ctx);
        // replace only for RHS.
        // LHS should be handled in equation, because of the possibility of both
        // x and der(x) being present
        if (ec != null && htRHS != null)
        {
            String val = htRHS.get(key);
            if (val != null)
                setValue(ctx, val);
            if (lsp != null)
                lsp.add(key);
        }
    }
    
    /**
     * Translates a component reference and applies the Simscape substitution
     * table. Inside a unit-sensitive call, the reference is divided by its
     * unit: {@code time} is divided by seconds, and other references by the
     * unit of the member they resolve to. A warning is logged when no unit is
     * found.
     */
    @Override
    public void exitComponent_reference(Component_referenceContext ctx)
    {
        // process as usual
        super.exitComponent_reference(ctx);
        String val = getValue(ctx);
        
        // perform substitution if needed
        if (ht != null)
        {
            String rep = ht.get(val);
            if (rep != null)
            {
                val = rep;
                setValue(ctx, val);
            }
        }
        
        if (convertContext == null)
            return;
        
        // take care of unit
        if (ctx.IDENT().size() == 1 && ctx.IDENT(0).getText().equals("time"))
        {
            setValue(ctx, "(" + val + "/{1,'s'})");
            return;
        }
        String unit = resolveUnit(ctx);
        if (unit != null)
            setValue(ctx, "(" + val + "/{1,'" + unit + "'})");
        else
            log.warn("Member without unit: " + ctx.getText() + " in " + scomponent.getName());
    }
    
    /**
     * Walks the identifiers of a component reference, starting from the
     * current component. It descends through component references and into
     * node domains, and returns the unit of the last unit-bearing member found.
     * 
     * @param ctx
     *            component reference
     * @return unit, or null if none was found
     */
    private String resolveUnit(Component_referenceContext ctx)
    {
        SComponent sc = scomponent;
        SDomain sd = null;
        String unit = null;
        for (TerminalNode ident : ctx.IDENT())
        {
            String name = ident.getText();
            String memberUnit = null;
            if (sc != null)
            {
                SMember sm = sc.getMember(name);
                if (sm instanceof SNode)
                {
                    sc = null;
                    sd = ((SNode) sm).getDomain();
                }
                else if (sm instanceof SComponentReference)
                {
                    sc = ((SComponentReference) sm).getComponent();
                    sd = null;
                }
                else
                    memberUnit = unitOf(sm);
            }
            else if (sd != null)
            {
                // domain members (node variables) carry units as well
                memberUnit = unitOf(sd.getMember(name));
            }
            if (memberUnit != null)
                unit = memberUnit;
        }
        return unit;
    }
    
    /**
     * Returns the unit of a parameter, variable, input or output member.
     * 
     * @param sm
     *            member, possibly null
     * @return unit, or null for other member kinds
     */
    private static String unitOf(SMember sm)
    {
        if (sm instanceof SParameter)
            return ((SParameter) sm).getUnit();
        if (sm instanceof SVariable)
            return ((SVariable) sm).getUnit();
        if (sm instanceof SInput)
            return ((SInput) sm).getUnit();
        if (sm instanceof SOutput)
            return ((SOutput) sm).getUnit();
        return null;
    }
    
    /**
     * Maps operators and keywords to their MATLAB spelling. Other terminals
     * get the superclass's default handling.
     */
    @Override
    public void visitTerminal(TerminalNode node)
    {
        switch (node.getSymbol().getType())
        {
        case ExpressionLanguageParser.EQUALS:
            setValue(node, "==");
            break;
        case ExpressionLanguageParser.LESSGT:
            setValue(node, "~=");
            break;
        case ExpressionLanguageParser.NOT:
            setValue(node, "~");
            break;
        case ExpressionLanguageParser.AND:
            setValue(node, "&&");
            break;
        case ExpressionLanguageParser.OR:
            setValue(node, "||");
            break;
        case ExpressionLanguageParser.IN:
        case ExpressionLanguageParser.ASSIGN:
            setValue(node, "=");
            break;
        default:
            super.visitTerminal(node);
            break;
        }
    }
    
}
