/*
 * Decompiled with CFR 0.152.
 */
package io.confluent.ksql.execution.interpreter;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.google.common.collect.Streams;
import com.google.errorprone.annotations.Immutable;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import io.confluent.ksql.execution.codegen.helpers.ArrayAccess;
import io.confluent.ksql.execution.codegen.helpers.InListEvaluator;
import io.confluent.ksql.execution.expression.tree.ArithmeticBinaryExpression;
import io.confluent.ksql.execution.expression.tree.ArithmeticUnaryExpression;
import io.confluent.ksql.execution.expression.tree.BetweenPredicate;
import io.confluent.ksql.execution.expression.tree.BooleanLiteral;
import io.confluent.ksql.execution.expression.tree.BytesLiteral;
import io.confluent.ksql.execution.expression.tree.Cast;
import io.confluent.ksql.execution.expression.tree.ComparisonExpression;
import io.confluent.ksql.execution.expression.tree.CreateArrayExpression;
import io.confluent.ksql.execution.expression.tree.CreateMapExpression;
import io.confluent.ksql.execution.expression.tree.CreateStructExpression;
import io.confluent.ksql.execution.expression.tree.DateLiteral;
import io.confluent.ksql.execution.expression.tree.DecimalLiteral;
import io.confluent.ksql.execution.expression.tree.DereferenceExpression;
import io.confluent.ksql.execution.expression.tree.DoubleLiteral;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.expression.tree.ExpressionVisitor;
import io.confluent.ksql.execution.expression.tree.FunctionCall;
import io.confluent.ksql.execution.expression.tree.InListExpression;
import io.confluent.ksql.execution.expression.tree.InPredicate;
import io.confluent.ksql.execution.expression.tree.IntegerLiteral;
import io.confluent.ksql.execution.expression.tree.IntervalUnit;
import io.confluent.ksql.execution.expression.tree.IsNotNullPredicate;
import io.confluent.ksql.execution.expression.tree.IsNullPredicate;
import io.confluent.ksql.execution.expression.tree.LambdaFunctionCall;
import io.confluent.ksql.execution.expression.tree.LambdaVariable;
import io.confluent.ksql.execution.expression.tree.LikePredicate;
import io.confluent.ksql.execution.expression.tree.LogicalBinaryExpression;
import io.confluent.ksql.execution.expression.tree.LongLiteral;
import io.confluent.ksql.execution.expression.tree.NotExpression;
import io.confluent.ksql.execution.expression.tree.NullLiteral;
import io.confluent.ksql.execution.expression.tree.QualifiedColumnReferenceExp;
import io.confluent.ksql.execution.expression.tree.SearchedCaseExpression;
import io.confluent.ksql.execution.expression.tree.SimpleCaseExpression;
import io.confluent.ksql.execution.expression.tree.StringLiteral;
import io.confluent.ksql.execution.expression.tree.SubscriptExpression;
import io.confluent.ksql.execution.expression.tree.TimeLiteral;
import io.confluent.ksql.execution.expression.tree.TimestampLiteral;
import io.confluent.ksql.execution.expression.tree.Type;
import io.confluent.ksql.execution.expression.tree.UnqualifiedColumnReferenceExp;
import io.confluent.ksql.execution.expression.tree.WhenClause;
import io.confluent.ksql.execution.interpreter.ArithmeticInterpreter;
import io.confluent.ksql.execution.interpreter.CastInterpreter;
import io.confluent.ksql.execution.interpreter.ComparisonInterpreter;
import io.confluent.ksql.execution.interpreter.terms.ColumnReferenceTerm;
import io.confluent.ksql.execution.interpreter.terms.CreateArrayTerm;
import io.confluent.ksql.execution.interpreter.terms.CreateMapTerm;
import io.confluent.ksql.execution.interpreter.terms.DereferenceTerm;
import io.confluent.ksql.execution.interpreter.terms.FunctionCallTerm;
import io.confluent.ksql.execution.interpreter.terms.InPredicateTerm;
import io.confluent.ksql.execution.interpreter.terms.IsNotNullTerm;
import io.confluent.ksql.execution.interpreter.terms.IsNullTerm;
import io.confluent.ksql.execution.interpreter.terms.LambdaFunctionTerms;
import io.confluent.ksql.execution.interpreter.terms.LambdaVariableTerm;
import io.confluent.ksql.execution.interpreter.terms.LikeTerm;
import io.confluent.ksql.execution.interpreter.terms.LiteralTerms;
import io.confluent.ksql.execution.interpreter.terms.LogicalBinaryTerms;
import io.confluent.ksql.execution.interpreter.terms.NotTerm;
import io.confluent.ksql.execution.interpreter.terms.SearchedCaseTerm;
import io.confluent.ksql.execution.interpreter.terms.StructTerm;
import io.confluent.ksql.execution.interpreter.terms.SubscriptTerm;
import io.confluent.ksql.execution.interpreter.terms.Term;
import io.confluent.ksql.execution.util.CoercionUtil;
import io.confluent.ksql.execution.util.ExpressionTypeManager;
import io.confluent.ksql.execution.util.FunctionArgumentsUtil;
import io.confluent.ksql.function.FunctionRegistry;
import io.confluent.ksql.function.GenericsUtil;
import io.confluent.ksql.function.KsqlScalarFunction;
import io.confluent.ksql.function.UdfFactory;
import io.confluent.ksql.function.types.ArrayType;
import io.confluent.ksql.function.types.ParamType;
import io.confluent.ksql.function.types.ParamTypes;
import io.confluent.ksql.function.udf.Kudf;
import io.confluent.ksql.schema.ksql.Column;
import io.confluent.ksql.schema.ksql.LogicalSchema;
import io.confluent.ksql.schema.ksql.SchemaConverters;
import io.confluent.ksql.schema.ksql.types.SqlArray;
import io.confluent.ksql.schema.ksql.types.SqlBaseType;
import io.confluent.ksql.schema.ksql.types.SqlDecimal;
import io.confluent.ksql.schema.ksql.types.SqlMap;
import io.confluent.ksql.schema.ksql.types.SqlType;
import io.confluent.ksql.util.DecimalUtil;
import io.confluent.ksql.util.KsqlConfig;
import io.confluent.ksql.util.KsqlException;
import io.confluent.ksql.util.Pair;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/**
 * Compiles an {@link Expression} tree into a tree of {@link Term}s that the interpreter evaluates.
 *
 * <p>Column references are resolved against the value columns of {@code schema}, expression
 * types are obtained from {@code expressionTypeManager}, and function calls are resolved through
 * {@code functionRegistry}. The {@link Context} threaded through the visit carries the SQL types
 * of the lambda variables that are in scope at the node being compiled.
 */
public class TermCompiler
implements ExpressionVisitor<Term, TermCompiler.Context> {
    private final FunctionRegistry functionRegistry;
    private final LogicalSchema schema;
    private final KsqlConfig ksqlConfig;
    private final ExpressionTypeManager expressionTypeManager;

    /**
     * Creates a compiler bound to the given registry, schema, config and type manager.
     *
     * <p>The arguments are stored by reference; they are not copied.
     *
     * @param functionRegistry      source of UDF factories for {@link FunctionCall} nodes
     * @param schema                schema whose value columns column references resolve against
     * @param ksqlConfig            config passed to casts, arithmetic and UDF instantiation
     * @param expressionTypeManager resolves the SQL type of sub-expressions
     */
    @SuppressFBWarnings(value = {"EI_EXPOSE_REP2"})
    public TermCompiler(
        FunctionRegistry functionRegistry,
        LogicalSchema schema,
        KsqlConfig ksqlConfig,
        ExpressionTypeManager expressionTypeManager
    ) {
        this.functionRegistry = functionRegistry;
        this.schema = schema;
        this.ksqlConfig = ksqlConfig;
        this.expressionTypeManager = expressionTypeManager;
    }

    /** Always throws: used for node types that must never reach this visitor. */
    private Term visitIllegalState(Expression expression) {
        throw new IllegalStateException(String.format(
            "Expression type %s should never be visited.%n"
                + "Check if there's an existing issue: https://github.com/confluentinc/ksql/issues %n"
                + "If not, please file a new one with your expression.",
            expression.getClass()));
    }

    /** Always throws: used for node types the interpreter does not implement yet. */
    private Term visitUnsupported(Expression expression) {
        throw new UnsupportedOperationException(String.format(
            "Not yet implemented: %s.visit%s.%n"
                + "Check if there's an existing issue: https://github.com/confluentinc/ksql/issues %n"
                + "If not, please file a new one with your expression.",
            this.getClass().getName(),
            expression.getClass().getSimpleName()));
    }

    @Override
    public Term visitType(Type node, Context context) {
        return visitIllegalState(node);
    }

    @Override
    public Term visitWhenClause(WhenClause whenClause, Context context) {
        return visitIllegalState(whenClause);
    }

    /**
     * Compiles {@code value IN (list)}. The predicate is first normalised by
     * {@link InListEvaluator#preprocess}; then the value and every list element are compiled.
     */
    @Override
    public Term visitInPredicate(InPredicate inPredicate, Context context) {
        InPredicate preprocessed = InListEvaluator.preprocess(
            inPredicate, this.expressionTypeManager, context.getLambdaSqlTypeMapping());
        Term value = process(preprocessed.getValue(), context);
        List<Term> valueList = preprocessed.getValueList().getValues().stream()
            .map(v -> process(v, context))
            .collect(ImmutableList.toImmutableList());
        return new InPredicateTerm(value, valueList);
    }

    @Override
    public Term visitInListExpression(InListExpression inListExpression, Context context) {
        return visitUnsupported(inListExpression);
    }

    @Override
    public Term visitTimestampLiteral(TimestampLiteral node, Context context) {
        return LiteralTerms.of(node.getValue());
    }

    @Override
    public Term visitTimeLiteral(TimeLiteral node, Context context) {
        return LiteralTerms.of(node.getValue());
    }

    @Override
    public Term visitDateLiteral(DateLiteral node, Context context) {
        return LiteralTerms.of(node.getValue());
    }

    @Override
    public Term visitSimpleCaseExpression(SimpleCaseExpression simpleCaseExpression, Context context) {
        return visitUnsupported(simpleCaseExpression);
    }

    @Override
    public Term visitBooleanLiteral(BooleanLiteral node, Context context) {
        return LiteralTerms.of(node.getValue());
    }

    @Override
    public Term visitStringLiteral(StringLiteral node, Context context) {
        return LiteralTerms.of(node.getValue());
    }

    @Override
    public Term visitDoubleLiteral(DoubleLiteral node, Context context) {
        return LiteralTerms.of(node.getValue());
    }

    /** Compiles a decimal literal, deriving its SQL decimal type from the literal's value. */
    @Override
    public Term visitDecimalLiteral(DecimalLiteral decimalLiteral, Context context) {
        SqlType sqlType = DecimalUtil.fromValue((BigDecimal) decimalLiteral.getValue());
        return LiteralTerms.of(decimalLiteral.getValue(), sqlType);
    }

    @Override
    public Term visitNullLiteral(NullLiteral node, Context context) {
        return LiteralTerms.ofNull();
    }

    /**
     * Compiles a lambda. The body is compiled first; each argument name is then paired with the
     * SQL type recorded for it in the context.
     *
     * @throws KsqlException if the lambda has other than one, two or three arguments
     */
    @Override
    public Term visitLambdaExpression(LambdaFunctionCall lambdaFunctionCall, Context context) {
        Term lambdaBody = process(lambdaFunctionCall.getBody(), context);
        ImmutableList.Builder<Pair<String, SqlType>> builder = ImmutableList.builder();
        for (String lambdaArg : lambdaFunctionCall.getArguments()) {
            builder.add(Pair.of(lambdaArg, context.getLambdaSqlTypeMapping().get(lambdaArg)));
        }
        List<Pair<String, SqlType>> nameToType = builder.build();
        switch (lambdaFunctionCall.getArguments().size()) {
            case 1:
                return new LambdaFunctionTerms.LambdaFunction1Term(nameToType, lambdaBody);
            case 2:
                return new LambdaFunctionTerms.LambdaFunction2Term(nameToType, lambdaBody);
            case 3:
                return new LambdaFunctionTerms.LambdaFunction3Term(nameToType, lambdaBody);
            default:
                throw new KsqlException("Interpreter only supports lambdas up to three arguments");
        }
    }

    /** Compiles a reference to a lambda variable, typed from the context's lambda mapping. */
    @Override
    public Term visitLambdaVariable(LambdaVariable lambdaVariable, Context context) {
        String name = lambdaVariable.getLambdaCharacter();
        return new LambdaVariableTerm(name, context.getLambdaSqlTypeMapping().get(name));
    }

    @Override
    public Term visitIntervalUnit(IntervalUnit exp, Context context) {
        return LiteralTerms.of(exp.getUnit());
    }

    /**
     * Resolves a column against the schema's value columns.
     *
     * @throws KsqlException if no value column with that name exists
     */
    @Override
    public Term visitUnqualifiedColumnReference(UnqualifiedColumnReferenceExp node, Context context) {
        Column schemaColumn = this.schema.findValueColumn(node.getColumnName())
            .orElseThrow(() -> new KsqlException("Field not found: " + node.getColumnName()));
        return new ColumnReferenceTerm(schemaColumn.index(), schemaColumn.type());
    }

    /** Always throws: qualified references must be rewritten before interpretation. */
    @Override
    public Term visitQualifiedColumnReference(QualifiedColumnReferenceExp node, Context context) {
        throw new IllegalStateException(
            "Qualified column reference must be resolved to unqualified reference before interpreter");
    }

    /**
     * Compiles {@code base->field}.
     *
     * @throws KsqlException if the compiled base is not of STRUCT type
     */
    @Override
    public Term visitDereferenceExpression(DereferenceExpression node, Context context) {
        SqlType fieldType =
            this.expressionTypeManager.getExpressionSqlType(node, context.getLambdaSqlTypeMapping());
        Term struct = process(node.getBase(), context);
        if (struct.getSqlType().baseType() != SqlBaseType.STRUCT) {
            throw new KsqlException("Can only dereference Struct type, instead got " + struct.getSqlType());
        }
        return new DereferenceTerm(struct, node.getFieldName(), fieldType);
    }

    @Override
    public Term visitLongLiteral(LongLiteral node, Context context) {
        return LiteralTerms.of(node.getValue());
    }

    @Override
    public Term visitIntegerLiteral(IntegerLiteral node, Context context) {
        return LiteralTerms.of(node.getValue());
    }

    @Override
    public Term visitBytesLiteral(BytesLiteral node, Context context) {
        return LiteralTerms.of(node.getValue());
    }

    /**
     * Compiles a UDF call.
     *
     * <p>The function is resolved through the registry. Each argument is compiled under the
     * lambda type mapping reported for it. Where needed, an argument is wrapped in a {@link Cast}
     * to the function's parameter type (see {@link #convertArgument}). For a variadic function,
     * arguments at or beyond the last parameter are matched against the element type of that
     * trailing array parameter.
     */
    @Override
    public Term visitFunctionCall(FunctionCall node, Context context) {
        UdfFactory udfFactory = this.functionRegistry.getUdfFactory(node.getName());
        FunctionArgumentsUtil.FunctionTypeInfo argumentsAndContext =
            FunctionArgumentsUtil.getFunctionTypeInfo(
                this.expressionTypeManager, node, udfFactory, context.getLambdaSqlTypeMapping());
        List<FunctionArgumentsUtil.ArgumentInfo> argumentInfos = argumentsAndContext.getArgumentInfos();
        KsqlScalarFunction function = argumentsAndContext.getFunction();
        SqlType functionReturnSchema = argumentsAndContext.getReturnType();
        Class<?> javaClass = SchemaConverters.sqlToJavaConverter().toJavaType(functionReturnSchema);

        List<Expression> arguments = node.getArguments();
        List<ParamType> parameters = function.parameters();
        List<Term> args = new ArrayList<>(arguments.size());
        for (int i = 0; i < arguments.size(); i++) {
            FunctionArgumentsUtil.ArgumentInfo argumentInfo = argumentInfos.get(i);
            SqlType sqlType = argumentInfo.getSqlArgument().getSqlType().orElse(null);
            ParamType paramType = i >= parameters.size() - 1 && function.isVariadic()
                ? ((ArrayType) Iterables.getLast(parameters)).element()
                : parameters.get(i);
            Expression converted = convertArgument(arguments.get(i), sqlType, paramType);
            args.add(process(converted, new Context(argumentInfo.getLambdaSqlTypeMapping())));
        }
        Kudf kudf = function.newInstance(this.ksqlConfig);
        return new FunctionCallTerm(kudf, args, javaClass, functionReturnSchema);
    }

    /**
     * Wraps {@code argument} in a {@link Cast} when its SQL type does not match the function
     * parameter type.
     *
     * <p>The argument is returned unchanged if its type is unknown ({@code null}), if the
     * parameter type contains generics, or if the argument's function type already equals the
     * parameter type. DECIMAL parameters are cast to the decimal type derived from the argument's
     * own type.
     */
    private Expression convertArgument(Expression argument, SqlType argType, ParamType funType) {
        if (argType == null
            || GenericsUtil.hasGenerics(funType)
            || SchemaConverters.sqlToFunctionConverter().toFunctionType(argType).equals(funType)) {
            return argument;
        }
        // Must be SqlType: the non-decimal branch yields a general SqlType, not a SqlDecimal.
        SqlType target = funType == ParamTypes.DECIMAL
            ? DecimalUtil.toSqlDecimal(argType)
            : SchemaConverters.functionToSqlConverter().toSqlType(funType);
        return new Cast(argument, new Type(target));
    }

    /**
     * Compiles AND/OR.
     *
     * @throws KsqlException if either operand is not BOOLEAN
     */
    @Override
    public Term visitLogicalBinaryExpression(LogicalBinaryExpression node, Context context) {
        Term left = process(node.getLeft(), context);
        Term right = process(node.getRight(), context);
        if (left.getSqlType().baseType() != SqlBaseType.BOOLEAN
            || right.getSqlType().baseType() != SqlBaseType.BOOLEAN) {
            throw new KsqlException(String.format(
                "Logical binary expects two boolean values.  Actual %s and %s",
                left.getSqlType(), right.getSqlType()));
        }
        return LogicalBinaryTerms.create(node.getType(), left, right);
    }

    /**
     * Compiles NOT.
     *
     * @throws IllegalStateException if the operand is not BOOLEAN
     */
    @Override
    public Term visitNotExpression(NotExpression node, Context context) {
        Term term = process(node.getValue(), context);
        if (term.getSqlType().baseType() != SqlBaseType.BOOLEAN) {
            throw new IllegalStateException(String.format(
                "Not expression expects a boolean value.  Actual %s", term.getSqlType()));
        }
        return new NotTerm(term);
    }

    /** Compiles both operands and delegates to {@link ComparisonInterpreter#doComparison}. */
    @Override
    public Term visitComparisonExpression(ComparisonExpression node, Context context) {
        Term left = process(node.getLeft(), context);
        Term right = process(node.getRight(), context);
        return ComparisonInterpreter.doComparison(node.getType(), left, right);
    }

    /** Compiles the operand and delegates to {@link CastInterpreter#cast}. */
    @Override
    public Term visitCast(Cast node, Context context) {
        Term term = process(node.getExpression(), context);
        SqlType from = term.getSqlType();
        SqlType to = node.getType().getSqlType();
        return CastInterpreter.cast(term, from, to, this.ksqlConfig);
    }

    @Override
    public Term visitIsNullPredicate(IsNullPredicate node, Context context) {
        return new IsNullTerm(process(node.getValue(), context));
    }

    @Override
    public Term visitIsNotNullPredicate(IsNotNullPredicate node, Context context) {
        return new IsNotNullTerm(process(node.getValue(), context));
    }

    @Override
    public Term visitArithmeticUnary(ArithmeticUnaryExpression node, Context context) {
        Term value = process(node.getValue(), context);
        return ArithmeticInterpreter.doUnaryArithmetic(node.getSign(), value);
    }

    /**
     * Compiles a binary arithmetic expression. The result type is taken from the type manager
     * for the whole expression.
     */
    @Override
    public Term visitArithmeticBinary(ArithmeticBinaryExpression node, Context context) {
        Term left = process(node.getLeft(), context);
        Term right = process(node.getRight(), context);
        SqlType resultType =
            this.expressionTypeManager.getExpressionSqlType(node, context.getLambdaSqlTypeMapping());
        return ArithmeticInterpreter.doBinaryArithmetic(
            node.getOperator(), left, right, resultType, this.ksqlConfig);
    }

    /**
     * Compiles {@code CASE WHEN ... THEN ... [ELSE ...] END}. Each WHEN clause is compiled into an
     * (operand, result) pair, and the optional ELSE into an optional default term.
     */
    @Override
    public Term visitSearchedCaseExpression(SearchedCaseExpression node, Context context) {
        SqlType resultSchema =
            this.expressionTypeManager.getExpressionSqlType(node, context.getLambdaSqlTypeMapping());
        List<Pair<Term, Term>> operandResultTerms = node.getWhenClauses().stream()
            .map(whenClause -> Pair.of(
                process(whenClause.getOperand(), context),
                process(whenClause.getResult(), context)))
            .collect(ImmutableList.toImmutableList());
        Optional<Term> defaultValueTerm = node.getDefaultValue().map(exp -> process(exp, context));
        return new SearchedCaseTerm(operandResultTerms, defaultValueTerm, resultSchema);
    }

    @Override
    public Term visitLikePredicate(LikePredicate node, Context context) {
        Term patternString = process(node.getPattern(), context);
        Term valueString = process(node.getValue(), context);
        return new LikeTerm(patternString, valueString, node.getEscape());
    }

    /**
     * Compiles {@code base[index]} for ARRAY bases (via {@link ArrayAccess#arrayAccess}) and MAP
     * bases (via {@link Map#get}).
     *
     * @throws UnsupportedOperationException if the base is neither ARRAY nor MAP
     */
    @Override
    public Term visitSubscriptExpression(SubscriptExpression node, Context context) {
        SqlType internalSchema = this.expressionTypeManager.getExpressionSqlType(
            node.getBase(), context.getLambdaSqlTypeMapping());
        switch (internalSchema.baseType()) {
            case ARRAY: {
                SqlArray array = (SqlArray) internalSchema;
                Term listTerm = process(node.getBase(), context);
                Term indexTerm = process(node.getIndex(), context);
                return new SubscriptTerm(
                    listTerm,
                    indexTerm,
                    (o, index) -> ArrayAccess.arrayAccess((List<?>) o, (Integer) index),
                    array.getItemType());
            }
            case MAP: {
                SqlMap mapSchema = (SqlMap) internalSchema;
                Term mapTerm = process(node.getBase(), context);
                Term keyTerm = process(node.getIndex(), context);
                return new SubscriptTerm(
                    mapTerm,
                    keyTerm,
                    (map, key) -> ((Map<?, ?>) map).get(key),
                    mapSchema.getValueType());
            }
            default:
                throw new UnsupportedOperationException(
                    "Subscript is only supported on ARRAY and MAP types, got " + internalSchema);
        }
    }

    /** Compiles an array constructor after coercing its elements to a common type. */
    @Override
    public Term visitCreateArrayExpression(CreateArrayExpression exp, Context context) {
        List<Expression> expressions = CoercionUtil
            .coerceUserList(exp.getValues(), this.expressionTypeManager, context.getLambdaSqlTypeMapping())
            .expressions();
        List<Term> arrayTerms = expressions.stream()
            .map(value -> process(value, context))
            .collect(ImmutableList.toImmutableList());
        SqlType sqlType =
            this.expressionTypeManager.getExpressionSqlType(exp, context.getLambdaSqlTypeMapping());
        return new CreateArrayTerm(arrayTerms, sqlType);
    }

    /**
     * Compiles a map constructor. Keys and values are coerced independently and then paired
     * positionally. Pairing stops at the shorter list, matching {@code Streams.zip}.
     */
    @Override
    public Term visitCreateMapExpression(CreateMapExpression exp, Context context) {
        ImmutableMap<Expression, Expression> map = exp.getMap();
        List<Expression> keys = CoercionUtil
            .coerceUserList(map.keySet(), this.expressionTypeManager, context.getLambdaSqlTypeMapping())
            .expressions();
        List<Expression> values = CoercionUtil
            .coerceUserList(map.values(), this.expressionTypeManager, context.getLambdaSqlTypeMapping())
            .expressions();
        // Typed Iterable: a raw Iterable cannot be iterated with a Pair loop variable.
        Iterable<Pair<Expression, Expression>> pairs =
            () -> Streams.zip(keys.stream(), values.stream(), Pair::of).iterator();
        ImmutableMap.Builder<Term, Term> mapTerms = ImmutableMap.builder();
        for (Pair<Expression, Expression> p : pairs) {
            mapTerms.put(process(p.getLeft(), context), process(p.getRight(), context));
        }
        SqlType resultType =
            this.expressionTypeManager.getExpressionSqlType(exp, context.getLambdaSqlTypeMapping());
        return new CreateMapTerm(mapTerms.build(), resultType);
    }

    /** Compiles a struct constructor, preserving the declared field order. */
    @Override
    public Term visitStructExpression(CreateStructExpression node, Context context) {
        ImmutableMap.Builder<String, Term> nameToTerm = ImmutableMap.builder();
        for (CreateStructExpression.Field field : node.getFields()) {
            nameToTerm.put(field.getName(), process(field.getValue(), context));
        }
        SqlType resultType =
            this.expressionTypeManager.getExpressionSqlType(node, context.getLambdaSqlTypeMapping());
        return new StructTerm(nameToTerm.build(), resultType);
    }

    /** Rewrites {@code value BETWEEN min AND max} as {@code value >= min AND value <= max}. */
    @Override
    public Term visitBetweenPredicate(BetweenPredicate node, Context context) {
        LogicalBinaryExpression and = new LogicalBinaryExpression(
            LogicalBinaryExpression.Type.AND,
            new ComparisonExpression(
                ComparisonExpression.Type.GREATER_THAN_OR_EQUAL, node.getValue(), node.getMin()),
            new ComparisonExpression(
                ComparisonExpression.Type.LESS_THAN_OR_EQUAL, node.getValue(), node.getMax()));
        return process(and, context);
    }

    /**
     * Compilation context for one node: the SQL types of the lambda variables in scope, keyed by
     * variable name. The mapping is copied into an {@link ImmutableMap} on construction.
     */
    @Immutable
    static final class Context {
        private final ImmutableMap<String, SqlType> lambdaSqlTypeMapping;

        /** Creates a context with no lambda variables in scope. */
        Context() {
            this(ImmutableMap.of());
        }

        /** Creates a context holding an immutable copy of {@code mapping}. */
        Context(Map<String, SqlType> mapping) {
            this.lambdaSqlTypeMapping = ImmutableMap.copyOf(mapping);
        }

        /** Returns the (immutable) lambda-variable-name to SQL-type mapping. */
        Map<String, SqlType> getLambdaSqlTypeMapping() {
            return this.lambdaSqlTypeMapping;
        }
    }
}

