/*
 * Decompiled with CFR 0.152.
 */
package io.confluent.ksql.function.udaf.stddev;

import io.confluent.ksql.function.udaf.TableUdaf;
import io.confluent.ksql.function.udaf.UdafDescription;
import io.confluent.ksql.function.udaf.UdafFactory;
import java.util.function.BiFunction;
import org.apache.kafka.connect.data.Schema;
import org.apache.kafka.connect.data.SchemaBuilder;
import org.apache.kafka.connect.data.Struct;

@UdafDescription(name="STDDEV_SAMP", description="Returns the sample standard deviation of the column. Applicable only to numeric types.", author="Confluent")
/**
 * Sample standard deviation ({@code STDDEV_SAMP}) aggregate for BIGINT, INT and DOUBLE columns.
 *
 * <p>The intermediate aggregate is a struct holding {@code COUNT} (number of non-null values
 * aggregated), {@code SUM} (their sum, typed like the input column) and {@code M2} (the running
 * sum of squared deviations from the mean). Values are added and removed incrementally with
 * Welford-style updates, and partial aggregates are combined with the pairwise
 * {@code M2a + M2b + (meanA - meanB)^2 * nA * nB / n} formula. The result is
 * {@code sqrt(M2 / (COUNT - 1))}, or {@code 0.0} when fewer than two values were aggregated.
 */
public final class StandardDeviationSampUdaf {
    // Field names of the intermediate aggregate struct.
    private static final String COUNT = "COUNT";
    private static final String SUM = "SUM";
    private static final String M2 = "M2";

    // Intermediate schemas: SUM follows the input type; COUNT and M2 are the same for all inputs.
    private static final Schema STRUCT_LONG = SchemaBuilder.struct().optional()
        .field(SUM, Schema.OPTIONAL_INT64_SCHEMA)
        .field(COUNT, Schema.OPTIONAL_INT64_SCHEMA)
        .field(M2, Schema.OPTIONAL_FLOAT64_SCHEMA)
        .build();
    private static final Schema STRUCT_INT = SchemaBuilder.struct().optional()
        .field(SUM, Schema.OPTIONAL_INT32_SCHEMA)
        .field(COUNT, Schema.OPTIONAL_INT64_SCHEMA)
        .field(M2, Schema.OPTIONAL_FLOAT64_SCHEMA)
        .build();
    private static final Schema STRUCT_DOUBLE = SchemaBuilder.struct().optional()
        .field(SUM, Schema.OPTIONAL_FLOAT64_SCHEMA)
        .field(COUNT, Schema.OPTIONAL_INT64_SCHEMA)
        .field(M2, Schema.OPTIONAL_FLOAT64_SCHEMA)
        .build();

    private StandardDeviationSampUdaf() {
    }

    /**
     * Builds the BIGINT variant.
     *
     * @return a table UDAF over {@code Long} values
     */
    @UdafFactory(description="Compute sample standard deviation of column with type Long.", aggregateSchema="STRUCT<SUM bigint, COUNT bigint, M2 double>")
    public static TableUdaf<Long, Struct, Double> stdDevLong() {
        return getStdDevImplementation(
            0L,
            STRUCT_LONG,
            (agg, newValue) -> newValue + agg.getInt64(SUM),
            // n * x - (S + x), with n = count + 1; computed exactly in long, then widened.
            (agg, newValue) -> (double) (newValue * (agg.getInt64(COUNT) + 1L)
                - (agg.getInt64(SUM) + newValue)),
            // Difference of means, divided in floating point so neither mean is truncated.
            (agg1, agg2) -> mean(agg1.getInt64(SUM), agg1.getInt64(COUNT))
                - mean(agg2.getInt64(SUM), agg2.getInt64(COUNT)),
            (agg1, agg2) -> agg1.getInt64(SUM) + agg2.getInt64(SUM),
            (agg, valueToRemove) -> agg.getInt64(SUM) - valueToRemove);
    }

    /**
     * Builds the INT variant.
     *
     * <p>NOTE(review): SUM is stored as a 32-bit integer (fixed by the aggregate schema), so it
     * wraps around on overflow for large totals.
     *
     * @return a table UDAF over {@code Integer} values
     */
    @UdafFactory(description="Compute sample standard deviation of column with type Integer.", aggregateSchema="STRUCT<SUM integer, COUNT bigint, M2 double>")
    public static TableUdaf<Integer, Struct, Double> stdDevInt() {
        return getStdDevImplementation(
            0,
            STRUCT_INT,
            (agg, newValue) -> newValue + agg.getInt32(SUM),
            // Widen to long before adding so (S + x) cannot overflow int.
            (agg, newValue) -> (double) (newValue.longValue() * (agg.getInt64(COUNT) + 1L)
                - (agg.getInt32(SUM).longValue() + newValue.longValue())),
            // Difference of means, divided in floating point so neither mean is truncated.
            (agg1, agg2) -> mean(agg1.getInt32(SUM), agg1.getInt64(COUNT))
                - mean(agg2.getInt32(SUM), agg2.getInt64(COUNT)),
            (agg1, agg2) -> agg1.getInt32(SUM) + agg2.getInt32(SUM),
            (agg, valueToRemove) -> agg.getInt32(SUM) - valueToRemove);
    }

    /**
     * Builds the DOUBLE variant.
     *
     * @return a table UDAF over {@code Double} values
     */
    @UdafFactory(description="Compute sample standard deviation of column with type Double.", aggregateSchema="STRUCT<SUM double, COUNT bigint, M2 double>")
    public static TableUdaf<Double, Struct, Double> stdDevDouble() {
        return getStdDevImplementation(
            0.0,
            STRUCT_DOUBLE,
            (agg, newValue) -> newValue + agg.getFloat64(SUM),
            (agg, newValue) -> newValue * (double) (agg.getInt64(COUNT) + 1L)
                - (agg.getFloat64(SUM) + newValue),
            (agg1, agg2) -> mean(agg1.getFloat64(SUM), agg1.getInt64(COUNT))
                - mean(agg2.getFloat64(SUM), agg2.getInt64(COUNT)),
            (agg1, agg2) -> agg1.getFloat64(SUM) + agg2.getFloat64(SUM),
            (agg, valueToRemove) -> agg.getFloat64(SUM) - valueToRemove);
    }

    /** Floating-point mean; callers only invoke this when {@code count} is non-zero. */
    private static double mean(final double sum, final long count) {
        return sum / count;
    }

    /**
     * Builds the type-independent aggregation logic. Only the handling of SUM differs per input
     * type, so it is supplied through the function parameters.
     *
     * @param initialValue zero value of SUM for the input type
     * @param structSchema schema of the intermediate aggregate struct
     * @param add          returns the aggregate's SUM plus the new value
     * @param createDelta  returns {@code x * (COUNT + 1) - (SUM + x)} for the aggregate and value x
     * @param mergeInner   returns the difference between the two aggregates' means
     * @param mergeSum     returns the sum of both aggregates' SUM fields
     * @param undoSum      returns the aggregate's SUM minus the removed value
     * @param <I>          input value type
     * @return the table UDAF
     */
    private static <I> TableUdaf<I, Struct, Double> getStdDevImplementation(
            final I initialValue,
            final Schema structSchema,
            final BiFunction<Struct, I, I> add,
            final BiFunction<Struct, I, Double> createDelta,
            final BiFunction<Struct, Struct, Double> mergeInner,
            final BiFunction<Struct, Struct, I> mergeSum,
            final BiFunction<Struct, I, I> undoSum) {
        return new TableUdaf<I, Struct, Double>() {

            @Override
            public Struct initialize() {
                return new Struct(structSchema)
                    .put(SUM, initialValue)
                    .put(COUNT, 0L)
                    .put(M2, 0.0);
            }

            @Override
            public Struct aggregate(final I newValue, final Struct aggregate) {
                if (newValue == null) {
                    return aggregate;
                }
                final long oldCount = aggregate.getInt64(COUNT);
                final long newCount = oldCount + 1L;
                // Welford update: M2 grows by ((n-1)x - S)^2 / (n(n-1)), where S is the old sum.
                final double m2Increment;
                if (oldCount > 0L) {
                    final double delta = createDelta.apply(aggregate, newValue);
                    // Multiply counts in double to avoid long overflow.
                    m2Increment = delta * delta / ((double) newCount * oldCount);
                } else {
                    m2Increment = 0.0;
                }
                return new Struct(structSchema)
                    .put(COUNT, newCount)
                    .put(SUM, add.apply(aggregate, newValue))
                    .put(M2, m2Increment + aggregate.getFloat64(M2));
            }

            @Override
            public Struct merge(final Struct aggOne, final Struct aggTwo) {
                final long countOne = aggOne.getInt64(COUNT);
                final long countTwo = aggTwo.getInt64(COUNT);
                final long newCount = countOne + countTwo;
                double newM2 = aggOne.getFloat64(M2) + aggTwo.getFloat64(M2);
                // The cross term only applies when both sides are non-empty (means are defined).
                if (countOne != 0L && countTwo != 0L) {
                    final double meanDiff = mergeInner.apply(aggOne, aggTwo);
                    newM2 += (double) countOne * countTwo * meanDiff * meanDiff / newCount;
                }
                return new Struct(structSchema)
                    .put(COUNT, newCount)
                    .put(SUM, mergeSum.apply(aggOne, aggTwo))
                    .put(M2, newM2);
            }

            @Override
            public Double map(final Struct aggregate) {
                final long count = aggregate.getInt64(COUNT);
                if (count < 2L) {
                    return 0.0;
                }
                // M2 / (n - 1) is the sample variance; the standard deviation is its square root.
                // Clamp at zero so rounding residue (e.g. after undo) cannot yield NaN.
                final double variance = aggregate.getFloat64(M2) / (count - 1L);
                return Math.sqrt(Math.max(variance, 0.0));
            }

            @Override
            public Struct undo(final I valueToUndo, final Struct aggregate) {
                if (valueToUndo == null) {
                    return aggregate;
                }
                final long oldCount = aggregate.getInt64(COUNT);
                final long newCount = oldCount - 1L;
                final double newM2;
                if (newCount > 0L) {
                    // SUM already includes x here, so createDelta yields x(n+1) - (S + x) = n*x - S,
                    // and the reverse Welford step removes (n*x - S)^2 / (n(n-1)).
                    final double delta = createDelta.apply(aggregate, valueToUndo);
                    newM2 = aggregate.getFloat64(M2) - delta * delta / ((double) newCount * oldCount);
                } else {
                    // Nothing left: drop any floating-point residue rather than carrying it forward.
                    newM2 = 0.0;
                }
                return new Struct(structSchema)
                    .put(COUNT, newCount)
                    .put(SUM, undoSum.apply(aggregate, valueToUndo))
                    .put(M2, newM2);
            }
        };
    }
}

