/*
 * Decompiled with CFR 0.152.
 */
package io.confluent.ksql.execution.util;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.confluent.ksql.execution.expression.formatter.ExpressionFormatter;
import io.confluent.ksql.execution.expression.tree.ArithmeticBinaryExpression;
import io.confluent.ksql.execution.expression.tree.ArithmeticUnaryExpression;
import io.confluent.ksql.execution.expression.tree.BetweenPredicate;
import io.confluent.ksql.execution.expression.tree.BooleanLiteral;
import io.confluent.ksql.execution.expression.tree.BytesLiteral;
import io.confluent.ksql.execution.expression.tree.Cast;
import io.confluent.ksql.execution.expression.tree.ComparisonExpression;
import io.confluent.ksql.execution.expression.tree.CreateArrayExpression;
import io.confluent.ksql.execution.expression.tree.CreateMapExpression;
import io.confluent.ksql.execution.expression.tree.CreateStructExpression;
import io.confluent.ksql.execution.expression.tree.DateLiteral;
import io.confluent.ksql.execution.expression.tree.DecimalLiteral;
import io.confluent.ksql.execution.expression.tree.DereferenceExpression;
import io.confluent.ksql.execution.expression.tree.DoubleLiteral;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.expression.tree.ExpressionVisitor;
import io.confluent.ksql.execution.expression.tree.FunctionCall;
import io.confluent.ksql.execution.expression.tree.InListExpression;
import io.confluent.ksql.execution.expression.tree.InPredicate;
import io.confluent.ksql.execution.expression.tree.IntegerLiteral;
import io.confluent.ksql.execution.expression.tree.IntervalUnit;
import io.confluent.ksql.execution.expression.tree.IsNotNullPredicate;
import io.confluent.ksql.execution.expression.tree.IsNullPredicate;
import io.confluent.ksql.execution.expression.tree.LambdaFunctionCall;
import io.confluent.ksql.execution.expression.tree.LambdaVariable;
import io.confluent.ksql.execution.expression.tree.LikePredicate;
import io.confluent.ksql.execution.expression.tree.LogicalBinaryExpression;
import io.confluent.ksql.execution.expression.tree.LongLiteral;
import io.confluent.ksql.execution.expression.tree.NotExpression;
import io.confluent.ksql.execution.expression.tree.NullLiteral;
import io.confluent.ksql.execution.expression.tree.QualifiedColumnReferenceExp;
import io.confluent.ksql.execution.expression.tree.SearchedCaseExpression;
import io.confluent.ksql.execution.expression.tree.SimpleCaseExpression;
import io.confluent.ksql.execution.expression.tree.StringLiteral;
import io.confluent.ksql.execution.expression.tree.SubscriptExpression;
import io.confluent.ksql.execution.expression.tree.TimeLiteral;
import io.confluent.ksql.execution.expression.tree.TimestampLiteral;
import io.confluent.ksql.execution.expression.tree.Type;
import io.confluent.ksql.execution.expression.tree.UnqualifiedColumnReferenceExp;
import io.confluent.ksql.execution.expression.tree.WhenClause;
import io.confluent.ksql.execution.function.UdafUtil;
import io.confluent.ksql.execution.util.CoercionUtil;
import io.confluent.ksql.execution.util.ComparisonUtil;
import io.confluent.ksql.execution.util.FunctionArgumentsUtil;
import io.confluent.ksql.function.AggregateFunctionFactory;
import io.confluent.ksql.function.AggregateFunctionInitArguments;
import io.confluent.ksql.function.FunctionRegistry;
import io.confluent.ksql.function.KsqlAggregateFunction;
import io.confluent.ksql.function.KsqlTableFunction;
import io.confluent.ksql.function.UdfFactory;
import io.confluent.ksql.schema.ksql.Column;
import io.confluent.ksql.schema.ksql.LogicalSchema;
import io.confluent.ksql.schema.ksql.SqlArgument;
import io.confluent.ksql.schema.ksql.types.SqlArray;
import io.confluent.ksql.schema.ksql.types.SqlBaseType;
import io.confluent.ksql.schema.ksql.types.SqlMap;
import io.confluent.ksql.schema.ksql.types.SqlStruct;
import io.confluent.ksql.schema.ksql.types.SqlType;
import io.confluent.ksql.schema.ksql.types.SqlTypes;
import io.confluent.ksql.util.DecimalUtil;
import io.confluent.ksql.util.KsqlException;
import io.confluent.ksql.util.KsqlStatementException;
import io.confluent.ksql.util.VisitorUtil;
import java.math.BigDecimal;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;

/**
 * Resolves the {@link SqlType} of an {@link Expression}.
 *
 * <p>Column references are resolved against the value columns of a {@link LogicalSchema}.
 * Function calls are resolved through a {@link FunctionRegistry} (aggregate, table, and scalar
 * functions).
 */
public class ExpressionTypeManager {
    private final LogicalSchema schema;
    private final FunctionRegistry functionRegistry;

    /**
     * @param schema schema whose value columns are used to resolve unqualified column references
     * @param functionRegistry registry used to resolve function-call return types
     * @throws NullPointerException if either argument is {@code null}
     */
    public ExpressionTypeManager(LogicalSchema schema, FunctionRegistry functionRegistry) {
        this.schema = Objects.requireNonNull(schema, "schema");
        this.functionRegistry = Objects.requireNonNull(functionRegistry, "functionRegistry");
    }

    /**
     * Resolves the type of {@code expression} with no lambda variables in scope.
     *
     * @param expression the expression to type
     * @return the resolved type; may be {@code null}, e.g. for a NULL literal
     */
    public SqlType getExpressionSqlType(Expression expression) {
        return this.getExpressionSqlType(expression, Collections.emptyMap());
    }

    /**
     * Resolves the type of {@code expression}.
     *
     * @param expression the expression to type
     * @param lambdaSqlTypeMapping types of the lambda variables in scope, keyed by the value of
     *     {@link LambdaVariable#getLambdaCharacter()}; copied into an {@link ImmutableMap}
     * @return the resolved type; may be {@code null}, e.g. for a NULL literal or a lambda
     *     variable missing from {@code lambdaSqlTypeMapping}
     * @throws KsqlException on type errors such as unknown columns, missing struct fields or
     *     inconsistent CASE branch types
     */
    public SqlType getExpressionSqlType(Expression expression, Map<String, SqlType> lambdaSqlTypeMapping) {
        Context context = new Context(lambdaSqlTypeMapping);
        new Visitor().process(expression, context);
        return context.getSqlType();
    }

    /**
     * Walks an expression tree and stores the type of the node it was last asked to process
     * in the supplied {@link Context}.
     *
     * <p>Non-static because it reads the enclosing instance's schema and function registry.
     */
    private class Visitor
    implements ExpressionVisitor<Void, Context> {

        @Override
        public Void visitArithmeticBinary(ArithmeticBinaryExpression node, Context context) throws KsqlException {
            this.process(node.getLeft(), context);
            SqlType leftType = context.getSqlType();
            this.process(node.getRight(), context);
            SqlType rightType = context.getSqlType();
            SqlType resultType;
            try {
                resultType = node.getOperator().resultType(leftType, rightType);
            }
            catch (KsqlException e) {
                // Re-throw with the offending expression attached so the user can locate it.
                throw new KsqlStatementException(
                    "Error processing expression.",
                    String.format("Error processing expression: %s. %s", node, e.getMessage()),
                    Objects.toString(node),
                    e);
            }
            context.setSqlType(resultType);
            return null;
        }

        /** Unary +/- keeps the type of its operand. */
        @Override
        public Void visitArithmeticUnary(ArithmeticUnaryExpression node, Context context) {
            this.process(node.getValue(), context);
            return null;
        }

        /** A lambda is typed by its body. */
        @Override
        public Void visitLambdaExpression(LambdaFunctionCall node, Context context) {
            this.process(node.getBody(), context);
            return null;
        }

        @Override
        public Void visitLambdaVariable(LambdaVariable node, Context context) {
            context.setSqlType(context.getLambdaSqlTypeMapping().get(node.getLambdaCharacter()));
            return null;
        }

        /** Interval units carry no type of their own; the context is left untouched. */
        @Override
        public Void visitIntervalUnit(IntervalUnit exp, Context context) {
            return null;
        }

        @Override
        public Void visitNotExpression(NotExpression node, Context context) {
            context.setSqlType(SqlTypes.BOOLEAN);
            return null;
        }

        @Override
        public Void visitCast(Cast node, Context context) {
            context.setSqlType(node.getType().getSqlType());
            return null;
        }

        @Override
        public Void visitComparisonExpression(ComparisonExpression node, Context context) {
            this.process(node.getLeft(), context);
            SqlType leftSchema = context.getSqlType();
            this.process(node.getRight(), context);
            SqlType rightSchema = context.getSqlType();
            if (!ComparisonUtil.isValidComparison(leftSchema, node.getType(), rightSchema)) {
                throw new KsqlStatementException(
                    "Cannot compare " + leftSchema + " to " + rightSchema + " with " + node.getType() + ".",
                    "Cannot compare " + node.getLeft().toString() + " (" + leftSchema + ") to "
                        + node.getRight().toString() + " (" + rightSchema + ") with " + node.getType() + ".",
                    node.toString());
            }
            context.setSqlType(SqlTypes.BOOLEAN);
            return null;
        }

        @Override
        public Void visitBetweenPredicate(BetweenPredicate node, Context context) {
            context.setSqlType(SqlTypes.BOOLEAN);
            return null;
        }

        @Override
        public Void visitUnqualifiedColumnReference(UnqualifiedColumnReferenceExp node, Context context) {
            Column schemaColumn = ExpressionTypeManager.this.schema.findValueColumn(node.getColumnName())
                .orElseThrow(() -> new KsqlException("Unknown column " + node + "."));
            context.setSqlType(schemaColumn.type());
            return null;
        }

        @Override
        public Void visitQualifiedColumnReference(QualifiedColumnReferenceExp node, Context context) {
            throw new IllegalStateException("Qualified column references must be resolved to unqualified reference before type can be resolved");
        }

        @Override
        public Void visitDereferenceExpression(DereferenceExpression node, Context context) {
            this.process(node.getBase(), context);
            SqlType sqlType = context.getSqlType();
            if (!(sqlType instanceof SqlStruct)) {
                throw new IllegalStateException("Expected STRUCT type, got: " + sqlType);
            }
            SqlStruct structType = (SqlStruct) sqlType;
            String fieldName = node.getFieldName();
            SqlStruct.Field structField = structType.field(fieldName)
                .orElseThrow(() -> new KsqlException(
                    "Could not find field '" + fieldName + "' in '" + node.getBase() + "'."));
            context.setSqlType(structField.type());
            return null;
        }

        @Override
        public Void visitStringLiteral(StringLiteral node, Context context) {
            context.setSqlType(SqlTypes.STRING);
            return null;
        }

        @Override
        public Void visitBytesLiteral(BytesLiteral node, Context context) {
            context.setSqlType(SqlTypes.BYTES);
            return null;
        }

        @Override
        public Void visitBooleanLiteral(BooleanLiteral node, Context context) {
            context.setSqlType(SqlTypes.BOOLEAN);
            return null;
        }

        @Override
        public Void visitLongLiteral(LongLiteral node, Context context) {
            context.setSqlType(SqlTypes.BIGINT);
            return null;
        }

        @Override
        public Void visitIntegerLiteral(IntegerLiteral node, Context context) {
            context.setSqlType(SqlTypes.INTEGER);
            return null;
        }

        @Override
        public Void visitDoubleLiteral(DoubleLiteral node, Context context) {
            context.setSqlType(SqlTypes.DOUBLE);
            return null;
        }

        /** NULL has no type of its own; callers must handle a {@code null} result. */
        @Override
        public Void visitNullLiteral(NullLiteral node, Context context) {
            context.setSqlType(null);
            return null;
        }

        @Override
        public Void visitLikePredicate(LikePredicate node, Context context) {
            context.setSqlType(SqlTypes.BOOLEAN);
            return null;
        }

        @Override
        public Void visitIsNotNullPredicate(IsNotNullPredicate node, Context context) {
            context.setSqlType(SqlTypes.BOOLEAN);
            return null;
        }

        @Override
        public Void visitIsNullPredicate(IsNullPredicate node, Context context) {
            context.setSqlType(SqlTypes.BOOLEAN);
            return null;
        }

        /**
         * Types a searched CASE.
         *
         * <p>All non-NULL THEN branches and the ELSE branch must share one type. At least one
         * branch must be non-NULL.
         */
        @Override
        public Void visitSearchedCaseExpression(SearchedCaseExpression node, Context context) {
            Optional<SqlType> whenType = this.validateWhenClauses(node.getWhenClauses(), context);
            Optional<SqlType> defaultType = node.getDefaultValue()
                .map(expression -> ExpressionTypeManager.this.getExpressionSqlType(
                    expression, context.getLambdaSqlTypeMapping()));
            if (whenType.isPresent() && defaultType.isPresent()) {
                if (!whenType.get().equals(defaultType.get())) {
                    throw new KsqlException("Invalid Case expression. Type for the default clause should be the same as for 'THEN' clauses." + System.lineSeparator() + "THEN type: " + whenType.get() + "." + System.lineSeparator() + "DEFAULT type: " + defaultType.get() + ".");
                }
                context.setSqlType(whenType.get());
            } else if (whenType.isPresent()) {
                context.setSqlType(whenType.get());
            } else if (defaultType.isPresent()) {
                context.setSqlType(defaultType.get());
            } else {
                throw new KsqlException("Invalid Case expression. All case branches have NULL type");
            }
            return null;
        }

        /** {@code map[key]} yields the map's value type; {@code array[i]} yields the item type. */
        @Override
        public Void visitSubscriptExpression(SubscriptExpression node, Context context) {
            this.process(node.getBase(), context);
            SqlType arrayMapType = context.getSqlType();
            SqlType valueType;
            if (arrayMapType instanceof SqlMap) {
                valueType = ((SqlMap) arrayMapType).getValueType();
            } else if (arrayMapType instanceof SqlArray) {
                valueType = ((SqlArray) arrayMapType).getItemType();
            } else {
                // Subscripting a struct is a common mistake: suggest the dereference form.
                String structMessage = arrayMapType instanceof SqlStruct
                    ? String.format(" Use the dereference operator for STRUCTS: %s",
                        new DereferenceExpression(Optional.empty(), node.getBase(),
                            ExpressionFormatter.formatExpression(node.getIndex())))
                    : "";
                throw new UnsupportedOperationException(String.format(
                    "Subscript expression (%s) do not apply to %s.%s", node, arrayMapType, structMessage));
            }
            context.setSqlType(valueType);
            return null;
        }

        /** An array's element type is the common (coerced) type of its non-NULL elements. */
        @Override
        public Void visitCreateArrayExpression(CreateArrayExpression exp, Context context) {
            if (exp.getValues().isEmpty()) {
                throw new KsqlException("Array constructor cannot be empty. Please supply at least one element (see https://github.com/confluentinc/ksql/issues/4239).");
            }
            SqlType elementType = CoercionUtil
                .coerceUserList(exp.getValues(), ExpressionTypeManager.this, context.getLambdaSqlTypeMapping())
                .commonType()
                .orElseThrow(() -> new KsqlException("Cannot construct an array with all NULL elements (see https://github.com/confluentinc/ksql/issues/4239). As a workaround, you may cast a NULL value to the desired type."));
            context.setSqlType(SqlArray.of(elementType));
            return null;
        }

        /** Key and value types are coerced independently across all entries. */
        @Override
        public Void visitCreateMapExpression(CreateMapExpression exp, Context context) {
            ImmutableMap<Expression, Expression> map = exp.getMap();
            if (map.isEmpty()) {
                throw new KsqlException("Map constructor cannot be empty. Please supply at least one key value pair (see https://github.com/confluentinc/ksql/issues/4239).");
            }
            SqlType keyType = CoercionUtil
                .coerceUserList(map.keySet(), ExpressionTypeManager.this, context.getLambdaSqlTypeMapping())
                .commonType()
                .orElseThrow(() -> new KsqlException("Cannot construct a map with all NULL keys (see https://github.com/confluentinc/ksql/issues/4239). As a workaround, you may cast a NULL key to the desired type."));
            SqlType valueType = CoercionUtil
                .coerceUserList(map.values(), ExpressionTypeManager.this, context.getLambdaSqlTypeMapping())
                .commonType()
                .orElseThrow(() -> new KsqlException("Cannot construct a map with all NULL values (see https://github.com/confluentinc/ksql/issues/4239). As a workaround, you may cast a NULL value to the desired type."));
            context.setSqlType(SqlMap.of(keyType, valueType));
            return null;
        }

        @Override
        public Void visitStructExpression(CreateStructExpression exp, Context context) {
            SqlStruct.Builder builder = SqlStruct.builder();
            for (CreateStructExpression.Field field : exp.getFields()) {
                this.process(field.getValue(), context);
                builder.field(field.getName(), context.getSqlType());
            }
            context.setSqlType(builder.build());
            return null;
        }

        /**
         * Types a function call.
         *
         * <p>Aggregate functions are checked first, then table functions. Anything else is
         * resolved as a scalar UDF.
         */
        @Override
        public Void visitFunctionCall(FunctionCall node, Context context) {
            FunctionRegistry registry = ExpressionTypeManager.this.functionRegistry;
            SqlType returnType;
            if (registry.isAggregate(node.getName())) {
                returnType = this.aggregateReturnType(node, context);
            } else if (registry.isTableFunction(node.getName())) {
                returnType = this.tableFunctionReturnType(node, context);
            } else {
                UdfFactory udfFactory = registry.getUdfFactory(node.getName());
                FunctionArgumentsUtil.FunctionTypeInfo argumentsAndContext = FunctionArgumentsUtil
                    .getFunctionTypeInfo(ExpressionTypeManager.this, node, udfFactory, context.getLambdaSqlTypeMapping());
                returnType = argumentsAndContext.getReturnType();
            }
            context.setSqlType(returnType);
            return null;
        }

        /** Resolves the type of every argument of {@code node}, in declaration order. */
        private List<SqlType> resolveArgumentTypes(FunctionCall node, Context context) {
            return node.getArguments().stream()
                .map(arg -> ExpressionTypeManager.this.getExpressionSqlType(arg, context.getLambdaSqlTypeMapping()))
                .collect(Collectors.toList());
        }

        /** Instantiates the matching aggregate function and returns its declared return type. */
        private SqlType aggregateReturnType(FunctionCall node, Context context) {
            List<SqlType> argTypes = this.resolveArgumentTypes(node, context);
            if (argTypes.isEmpty()) {
                // Zero-argument aggregates are looked up using the registry's default argument type.
                argTypes = Collections.singletonList(FunctionRegistry.DEFAULT_FUNCTION_ARG_SCHEMA);
            }
            AggregateFunctionFactory factory = ExpressionTypeManager.this.functionRegistry.getAggregateFactory(node.getName());
            AggregateFunctionFactory.FunctionSource functionSource = factory.getFunction(argTypes);
            AggregateFunctionInitArguments initArgs =
                UdafUtil.createAggregateFunctionInitArgs(functionSource.initArgs, node);
            KsqlAggregateFunction<?, ?, ?> function =
                (KsqlAggregateFunction<?, ?, ?>) functionSource.source.apply(initArgs);
            return function.returnType();
        }

        /** Looks up the matching table function and returns its return type for these arguments. */
        private SqlType tableFunctionReturnType(FunctionCall node, Context context) {
            List<SqlArgument> argumentTypes;
            if (node.getArguments().isEmpty()) {
                // Zero-argument table functions are looked up using the registry's default argument type.
                argumentTypes = ImmutableList.of(SqlArgument.of((SqlType) FunctionRegistry.DEFAULT_FUNCTION_ARG_SCHEMA));
            } else {
                argumentTypes = this.resolveArgumentTypes(node, context).stream()
                    .map(SqlArgument::of)
                    .collect(Collectors.toList());
            }
            KsqlTableFunction tableFunction =
                ExpressionTypeManager.this.functionRegistry.getTableFunction(node.getName(), argumentTypes);
            return tableFunction.getReturnType(argumentTypes);
        }

        /**
         * AND/OR are always BOOLEAN.
         *
         * <p>Both operands are still visited so that errors inside them surface. The result
         * must be set explicitly; otherwise the context would keep the right operand's type,
         * which is {@code null} for a NULL literal.
         */
        @Override
        public Void visitLogicalBinaryExpression(LogicalBinaryExpression node, Context context) {
            this.process(node.getLeft(), context);
            this.process(node.getRight(), context);
            context.setSqlType(SqlTypes.BOOLEAN);
            return null;
        }

        @Override
        public Void visitType(Type type, Context context) {
            throw VisitorUtil.illegalState(this, type);
        }

        @Override
        public Void visitTimeLiteral(TimeLiteral timeLiteral, Context context) {
            context.setSqlType(SqlTypes.TIME);
            return null;
        }

        @Override
        public Void visitDateLiteral(DateLiteral dateLiteral, Context context) {
            context.setSqlType(SqlTypes.DATE);
            return null;
        }

        @Override
        public Void visitTimestampLiteral(TimestampLiteral timestampLiteral, Context context) {
            context.setSqlType(SqlTypes.TIMESTAMP);
            return null;
        }

        /** The decimal type (precision/scale) is derived from the literal's value. */
        @Override
        public Void visitDecimalLiteral(DecimalLiteral decimalLiteral, Context context) {
            context.setSqlType(DecimalUtil.fromValue((BigDecimal) decimalLiteral.getValue()));
            return null;
        }

        @Override
        public Void visitSimpleCaseExpression(SimpleCaseExpression simpleCaseExpression, Context context) {
            throw VisitorUtil.unsupportedOperation(this, simpleCaseExpression);
        }

        @Override
        public Void visitInListExpression(InListExpression inListExpression, Context context) {
            throw VisitorUtil.unsupportedOperation(this, inListExpression);
        }

        @Override
        public Void visitInPredicate(InPredicate inPredicate, Context context) {
            context.setSqlType(SqlTypes.BOOLEAN);
            return null;
        }

        /** WHEN clauses are only typed through {@link #validateWhenClauses}. */
        @Override
        public Void visitWhenClause(WhenClause whenClause, Context context) {
            throw VisitorUtil.illegalState(this, whenClause);
        }

        /**
         * Checks that every WHEN operand is BOOLEAN and that all non-NULL THEN results share
         * one type.
         *
         * @return the shared THEN type, or empty if every THEN result is NULL
         * @throws KsqlException if an operand is not BOOLEAN or THEN types differ
         */
        private Optional<SqlType> validateWhenClauses(List<WhenClause> whenClauses, Context context) {
            Optional<SqlType> previousResult = Optional.empty();
            for (WhenClause whenClause : whenClauses) {
                this.process(whenClause.getOperand(), context);
                SqlType operandType = context.getSqlType();
                // An untyped (NULL) operand is reported like any other non-boolean operand
                // instead of failing with a NullPointerException.
                if (operandType == null || operandType.baseType() != SqlBaseType.BOOLEAN) {
                    throw new KsqlException("WHEN operand type should be boolean." + System.lineSeparator() + "Type for '" + whenClause.getOperand() + "' is " + operandType);
                }
                this.process(whenClause.getResult(), context);
                SqlType resultType = context.getSqlType();
                if (resultType == null) {
                    // NULL results are compatible with any branch type.
                    continue;
                }
                if (!previousResult.isPresent()) {
                    previousResult = Optional.of(resultType);
                    continue;
                }
                if (!previousResult.get().equals(resultType)) {
                    throw new KsqlException("Invalid Case expression. Type for all 'THEN' clauses should be the same." + System.lineSeparator() + "THEN expression '" + whenClause + "' has type: " + resultType + "." + System.lineSeparator() + "Previous THEN expression(s) type: " + previousResult.get() + ".");
                }
            }
            return previousResult;
        }
    }

    /**
     * Per-call state for the visitor.
     *
     * <p>Holds the lambda-variable types in scope (fixed at construction) and the type of the
     * most recently processed node (overwritten as the visitor walks the tree).
     */
    private static final class Context {
        private final ImmutableMap<String, SqlType> lambdaSqlTypeMapping;
        private SqlType sqlType;

        private Context(Map<String, SqlType> mapping) {
            this.lambdaSqlTypeMapping = ImmutableMap.copyOf(mapping);
        }

        Map<String, SqlType> getLambdaSqlTypeMapping() {
            return this.lambdaSqlTypeMapping;
        }

        SqlType getSqlType() {
            return this.sqlType;
        }

        void setSqlType(SqlType sqlType) {
            this.sqlType = sqlType;
        }
    }
}

