/*
 * Decompiled with CFR 0.152.
 */
package io.confluent.ksql.function.udaf.topk;

import io.confluent.ksql.function.udaf.Udaf;
import io.confluent.ksql.function.udaf.UdafDescription;
import io.confluent.ksql.function.udaf.UdafFactory;
import io.confluent.ksql.function.udaf.VariadicArgs;
import io.confluent.ksql.schema.ksql.SchemaConverters;
import io.confluent.ksql.schema.ksql.SqlArgument;
import io.confluent.ksql.schema.ksql.types.SqlArray;
import io.confluent.ksql.schema.ksql.types.SqlType;
import io.confluent.ksql.util.Pair;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
import org.apache.kafka.connect.data.Schema;
import org.apache.kafka.connect.data.SchemaBuilder;
import org.apache.kafka.connect.data.Struct;

/**
 * TOPK aggregate: retains the {@code k} largest values of a sort column, optionally carrying
 * additional columns from the same records.
 *
 * <p>When only the sort column is supplied, the aggregate is a list of the sort values
 * themselves. When extra columns are supplied, each entry is a {@link Struct} whose field
 * {@code sort_col} holds the sort value and whose fields {@code col0}, {@code col1}, ... hold the
 * extra columns in argument order. Aggregate lists are kept in descending order of sort value.
 *
 * @param <T> type of the sort column
 * @param <S> element type of the aggregate list: {@code T} itself, or {@link Struct} when extra
 *            columns are carried
 */
@UdafDescription(name="TOPK", description="Returns the top k values for a column and other values in those records.", author="Confluent")
public class TopkKudaf<T extends Comparable<? super T>, S>
implements Udaf<Pair<T, VariadicArgs<Object>>, List<S>, List<S>> {
    private static final String SORT_FIELD = "sort_col";
    private static final Function<Integer, String> OTHER_COL_TO_FIELD = fieldNum -> "col" + fieldNum;

    /** Maximum number of entries kept in an aggregate; validated to be positive. */
    private final int topKSize;
    /** Extracts the sort value from an aggregate entry; set by {@link #initializeTypeArguments}. */
    private Function<S, T> structToVal;
    /** Builds an aggregate entry from an input row; set by {@link #initializeTypeArguments}. */
    private Function<Pair<T, VariadicArgs<Object>>, S> valToStruct;
    /** SQL type of a single aggregate entry; null until {@link #initializeTypeArguments} runs. */
    private SqlType aggregateSchema;

    @UdafFactory(description="Returns the top k values for an integer column and other values in those records.")
    public static <S> Udaf<Pair<Integer, VariadicArgs<Object>>, List<S>, List<S>> createTopKInt(int k) {
        return new TopkKudaf<>(k);
    }

    @UdafFactory(description="Returns the top k values for a bigint column and other values in those records.")
    public static <S> Udaf<Pair<Long, VariadicArgs<Object>>, List<S>, List<S>> createTopKLong(int k) {
        return new TopkKudaf<>(k);
    }

    @UdafFactory(description="Returns the top k values for a double column and other values in those records.")
    public static <S> Udaf<Pair<Double, VariadicArgs<Object>>, List<S>, List<S>> createTopKDouble(int k) {
        return new TopkKudaf<>(k);
    }

    @UdafFactory(description="Returns the top k values for a string column and other values in those records.")
    public static <S> Udaf<Pair<String, VariadicArgs<Object>>, List<S>, List<S>> createTopKString(int k) {
        return new TopkKudaf<>(k);
    }

    /**
     * @param topKSize number of entries to retain; must be positive
     * @throws IllegalArgumentException if {@code topKSize <= 0}
     */
    TopkKudaf(int topKSize) {
        // Without this check, k == 0 fails later in aggregate() with set(-1, ...), and a negative
        // k fails in merge() with a negative ArrayList capacity.
        if (topKSize <= 0) {
            throw new IllegalArgumentException(
                    "TOPK requires k to be a positive integer, but was: " + topKSize);
        }
        this.topKSize = topKSize;
    }

    /**
     * Configures the entry representation from the call's argument types.
     *
     * <p>More than two arguments means extra columns are carried alongside the sort column, so
     * entries become structs. Otherwise each entry is the bare sort value.
     *
     * @param argTypeList argument types in call order: the sort column first, then any extra
     *                    columns; the last entry is not part of the entry schema (assumed to be
     *                    the {@code k} literal)
     */
    @Override
    @SuppressWarnings("unchecked") // S is Struct or T, decided at runtime by argument count
    public void initializeTypeArguments(List<SqlArgument> argTypeList) {
        if (argTypeList.size() > 2) {
            final Schema structSchema = makeStructSchema(argTypeList);
            this.aggregateSchema = SchemaConverters.connectToSqlConverter().toSqlType(structSchema);
            this.structToVal = struct -> (T) ((Struct) struct).get(SORT_FIELD);
            this.valToStruct = pair -> (S) makeStruct(structSchema, pair.getLeft(), pair.getRight());
        } else {
            this.aggregateSchema = argTypeList.get(0).getSqlTypeOrThrow();
            this.structToVal = self -> (T) self;
            this.valToStruct = selfAndVarArgs -> (S) selfAndVarArgs.getLeft();
        }
    }

    /** Returns an ARRAY of the entry type configured by {@link #initializeTypeArguments}. */
    @Override
    public Optional<SqlType> getAggregateSqlType() {
        return Optional.of(SqlArray.of(this.aggregateSchema));
    }

    /** Returns an ARRAY of the entry type; the result has the same shape as the aggregate. */
    @Override
    public Optional<SqlType> getReturnSqlType() {
        return Optional.of(SqlArray.of(this.aggregateSchema));
    }

    /** Returns a new, empty, mutable aggregate list. */
    @Override
    public List<S> initialize() {
        return new ArrayList<>();
    }

    /**
     * Folds one row into the aggregate, keeping it sorted in descending order and at most
     * {@code k} entries long.
     *
     * <p>Rows with a null sort value are ignored. When the list is full, a value that is not
     * strictly greater than the smallest retained value is dropped, so ties keep the existing
     * entry.
     *
     * @param currentValue the sort value and any extra column values of the incoming row
     * @param aggregateValue the current aggregate; mutated in place
     * @return {@code aggregateValue}
     */
    @Override
    public List<S> aggregate(Pair<T, VariadicArgs<Object>> currentValue, List<S> aggregateValue) {
        final T candidate = currentValue.getLeft();
        if (candidate == null) {
            return aggregateValue;
        }
        final int currentSize = aggregateValue.size();
        if (currentSize >= this.topKSize) {
            // The list is descending, so its last entry is the smallest retained value.
            final T smallest = this.structToVal.apply(aggregateValue.get(currentSize - 1));
            if (candidate.compareTo(smallest) <= 0) {
                return aggregateValue;
            }
            aggregateValue.set(currentSize - 1, this.valToStruct.apply(currentValue));
        } else {
            aggregateValue.add(this.valToStruct.apply(currentValue));
        }
        aggregateValue.sort(Comparator.comparing(this.structToVal).reversed());
        return aggregateValue;
    }

    /**
     * Merges two descending aggregates into a new descending list of at most {@code k} entries.
     *
     * <p>On equal sort values the entry from {@code aggOne} is taken first. Neither input is
     * modified.
     */
    @Override
    public List<S> merge(List<S> aggOne, List<S> aggTwo) {
        final List<S> merged = new ArrayList<>(Math.min(this.topKSize, aggOne.size() + aggTwo.size()));
        int idx1 = 0;
        int idx2 = 0;
        for (int i = 0; i < this.topKSize; ++i) {
            final T v1 = sortValueAt(aggOne, idx1);
            final T v2 = sortValueAt(aggTwo, idx2);
            if (v1 != null && (v2 == null || v1.compareTo(v2) >= 0)) {
                merged.add(aggOne.get(idx1++));
            } else if (v2 != null) {
                merged.add(aggTwo.get(idx2++));
            } else {
                // Both inputs are exhausted.
                break;
            }
        }
        return merged;
    }

    /** Returns the aggregate unchanged; it is already in result form. */
    @Override
    public List<S> map(List<S> agg) {
        return agg;
    }

    /** Returns the sort value of {@code list.get(idx)}, or null once {@code idx} is past the end. */
    private T sortValueAt(List<S> list, int idx) {
        return idx < list.size() ? this.structToVal.apply(list.get(idx)) : null;
    }

    /** Builds a struct entry: {@code sort_col} plus {@code colN} for each extra argument. */
    private Struct makeStruct(Schema structSchema, T sortCol, VariadicArgs<Object> otherCols) {
        final Struct struct = new Struct(structSchema);
        struct.put(SORT_FIELD, sortCol);
        for (int argIndex = 0; argIndex < otherCols.size(); ++argIndex) {
            struct.put(OTHER_COL_TO_FIELD.apply(argIndex), otherCols.get(argIndex));
        }
        return struct;
    }

    /**
     * Builds the optional struct schema for entries: the first argument becomes {@code sort_col};
     * each following argument except the last becomes {@code col0}, {@code col1}, ...
     */
    private Schema makeStructSchema(List<SqlArgument> argTypeList) {
        final SchemaBuilder builder = SchemaBuilder.struct().optional();
        for (int argIndex = 0; argIndex < argTypeList.size() - 1; ++argIndex) {
            final SqlType argSchema = argTypeList.get(argIndex).getSqlTypeOrThrow();
            final String fieldName = argIndex == 0 ? SORT_FIELD : OTHER_COL_TO_FIELD.apply(argIndex - 1);
            builder.field(fieldName, SchemaConverters.sqlToConnectConverter().toConnectSchema(argSchema));
        }
        return builder.build();
    }
}

