/*
 * Decompiled with CFR 0.152.
 */
package io.confluent.ksql.function.udaf.correlation;

import io.confluent.ksql.function.udaf.TableUdaf;
import io.confluent.ksql.function.udaf.UdafDescription;
import io.confluent.ksql.function.udaf.UdafFactory;
import io.confluent.ksql.util.Pair;
import java.util.Objects;
import java.util.function.Function;
import org.apache.kafka.connect.data.Schema;
import org.apache.kafka.connect.data.SchemaBuilder;
import org.apache.kafka.connect.data.Struct;

/**
 * Aggregate function computing the Pearson correlation coefficient between two columns.
 *
 * <p>The aggregate state is a {@link Struct} holding running sums of x, y, x*x, y*y and x*y
 * plus a row count; {@link #map(Struct)} derives the coefficient from those sums. Input pairs
 * in which either side is {@code null} are skipped by both {@link #aggregate} and
 * {@link #undo}.
 *
 * @param <T> input column type, converted to {@code double} by the function supplied to the
 *     constructor
 */
@UdafDescription(
    name="correlation",
    description="Computes the Pearson correlation coefficient between two columns of data.",
    author="Confluent")
public class CorrelationUdaf<T>
implements TableUdaf<Pair<T, T>, Struct, Double> {
    private static final String X_SUM = "X_SUM";
    private static final String Y_SUM = "Y_SUM";
    private static final String X_SQUARED_SUM = "X_SQUARED_SUM";
    private static final String Y_SQUARED_SUM = "Y_SQUARED_SUM";
    private static final String XY_SUM = "XY_SUM";
    private static final String COUNT = "COUNT";

    /** Declared KSQL type of the aggregate state, shared by every factory below. */
    private static final String AGGREGATE_SCHEMA =
        "STRUCT<X_SUM double, Y_SUM double, X_SQUARED_SUM double, Y_SQUARED_SUM double, XY_SUM double, COUNT bigint>";

    /** Connect schema used for every aggregate {@link Struct} this class creates. */
    private static final Schema STRUCT_SCHEMA = SchemaBuilder.struct()
        .optional()
        .field(X_SUM, Schema.OPTIONAL_FLOAT64_SCHEMA)
        .field(Y_SUM, Schema.OPTIONAL_FLOAT64_SCHEMA)
        .field(X_SQUARED_SUM, Schema.OPTIONAL_FLOAT64_SCHEMA)
        .field(Y_SQUARED_SUM, Schema.OPTIONAL_FLOAT64_SCHEMA)
        .field(XY_SUM, Schema.OPTIONAL_FLOAT64_SCHEMA)
        .field(COUNT, Schema.OPTIONAL_INT64_SCHEMA)
        .build();

    /** Converts a non-null input value to {@code double}. */
    private final Function<T, Double> toDouble;

    /** Creates a correlation aggregate over two DOUBLE columns. */
    @UdafFactory(
        description="Computes the Pearson correlation coefficient between two DOUBLE columns.",
        aggregateSchema=AGGREGATE_SCHEMA)
    public static TableUdaf<Pair<Double, Double>, Struct, Double> createCorrelationDouble() {
        return new CorrelationUdaf<Double>(Function.identity());
    }

    /** Creates a correlation aggregate over two INTEGER columns. */
    @UdafFactory(
        description="Computes the Pearson correlation coefficient between two INTEGER columns.",
        aggregateSchema=AGGREGATE_SCHEMA)
    public static TableUdaf<Pair<Integer, Integer>, Struct, Double> createCorrelationInteger() {
        return new CorrelationUdaf<Integer>(Integer::doubleValue);
    }

    /** Creates a correlation aggregate over two BIGINT columns. */
    @UdafFactory(
        description="Computes the Pearson correlation coefficient between two BIGINT columns.",
        aggregateSchema=AGGREGATE_SCHEMA)
    public static TableUdaf<Pair<Long, Long>, Struct, Double> createCorrelationLong() {
        return new CorrelationUdaf<Long>(Long::doubleValue);
    }

    /**
     * @param toDouble converts a non-null input value to {@code double}
     * @throws NullPointerException if {@code toDouble} is null
     */
    public CorrelationUdaf(final Function<T, Double> toDouble) {
        this.toDouble = Objects.requireNonNull(toDouble, "toDouble");
    }

    /** Returns the empty aggregate: every sum {@code 0.0} and a count of {@code 0}. */
    @Override
    public Struct initialize() {
        return new Struct(STRUCT_SCHEMA)
            .put(X_SUM, 0.0)
            .put(Y_SUM, 0.0)
            .put(X_SQUARED_SUM, 0.0)
            .put(Y_SQUARED_SUM, 0.0)
            .put(XY_SUM, 0.0)
            .put(COUNT, 0L);
    }

    /**
     * Adds one (x, y) pair to the running sums.
     *
     * @param current the pair to add; ignored if either side is {@code null}
     * @param aggregate the current aggregate (not modified)
     * @return a new aggregate including {@code current}, or {@code aggregate} itself when the
     *     pair is skipped
     */
    @Override
    public Struct aggregate(final Pair<T, T> current, final Struct aggregate) {
        if (current.getLeft() == null || current.getRight() == null) {
            return aggregate;
        }
        final double x = toDouble.apply(current.getLeft());
        final double y = toDouble.apply(current.getRight());
        return addTo(aggregate, x, y, x * x, y * y, x * y, 1L);
    }

    /** Combines two partial aggregates by adding their sums and counts field-wise. */
    @Override
    public Struct merge(final Struct aggOne, final Struct aggTwo) {
        return addTo(
            aggOne,
            aggTwo.getFloat64(X_SUM),
            aggTwo.getFloat64(Y_SUM),
            aggTwo.getFloat64(X_SQUARED_SUM),
            aggTwo.getFloat64(Y_SQUARED_SUM),
            aggTwo.getFloat64(XY_SUM),
            aggTwo.getInt64(COUNT));
    }

    /**
     * Computes the Pearson correlation coefficient from the running sums.
     *
     * @param agg the aggregate to evaluate
     * @return the coefficient in {@code [-1, 1]}, or {@code NaN} when it is undefined (fewer
     *     than two pairs, or either column has zero variance)
     */
    @Override
    public Double map(final Struct agg) {
        final long count = agg.getInt64(COUNT);
        if (count < 2) {
            // Undefined for fewer than two points. Checking explicitly also stops rounding
            // residue left by undo() from being reported as a perfect +/-1 correlation.
            return Double.NaN;
        }
        final double sumX = agg.getFloat64(X_SUM);
        final double sumY = agg.getFloat64(Y_SUM);
        final double squaredXSum = agg.getFloat64(X_SQUARED_SUM);
        final double squaredYSum = agg.getFloat64(Y_SQUARED_SUM);
        final double sumXY = agg.getFloat64(XY_SUM);

        final double numerator = count * sumXY - sumX * sumY;
        // Both terms are n^2 times a variance, so mathematically >= 0. Cancellation can round
        // a zero variance to 0 or slightly below, which would otherwise yield +/-Infinity or
        // NaN from sqrt; `!(v > 0)` also catches NaN.
        final double varianceTermX = count * squaredXSum - sumX * sumX;
        final double varianceTermY = count * squaredYSum - sumY * sumY;
        if (!(varianceTermX > 0.0) || !(varianceTermY > 0.0)) {
            return Double.NaN;
        }
        // sqrt(a) * sqrt(b) rather than sqrt(a * b): the product can overflow to Infinity for
        // large-magnitude inputs even when each factor is finite.
        final double r = numerator / (Math.sqrt(varianceTermX) * Math.sqrt(varianceTermY));
        // |r| <= 1 by Cauchy-Schwarz; clamp away rounding overshoot (NaN passes through).
        return Math.max(-1.0, Math.min(1.0, r));
    }

    /**
     * Removes one previously-aggregated (x, y) pair from the running sums.
     *
     * @param valueToUndo the pair to remove; ignored if either side is {@code null}
     * @param aggregateValue the current aggregate (not modified)
     * @return a new aggregate without {@code valueToUndo}, or {@code aggregateValue} itself
     *     when the pair is skipped
     */
    @Override
    public Struct undo(final Pair<T, T> valueToUndo, final Struct aggregateValue) {
        if (valueToUndo.getLeft() == null || valueToUndo.getRight() == null) {
            return aggregateValue;
        }
        if (aggregateValue.getInt64(COUNT) == 1L) {
            // Removing the last contributing pair: return exact zeros instead of subtracting,
            // which could leave floating-point residue in the sums.
            return initialize();
        }
        final double x = toDouble.apply(valueToUndo.getLeft());
        final double y = toDouble.apply(valueToUndo.getRight());
        // In IEEE 754, a + (-b) equals a - b exactly, so negated deltas subtract the pair.
        return addTo(aggregateValue, -x, -y, -(x * x), -(y * y), -(x * y), -1L);
    }

    /**
     * Returns a new aggregate whose fields are those of {@code base} plus the given deltas.
     * Shared by aggregate, merge and undo; {@code base} is not modified.
     */
    private static Struct addTo(
        final Struct base,
        final double deltaXSum,
        final double deltaYSum,
        final double deltaXSquaredSum,
        final double deltaYSquaredSum,
        final double deltaXySum,
        final long deltaCount
    ) {
        return new Struct(STRUCT_SCHEMA)
            .put(X_SUM, base.getFloat64(X_SUM) + deltaXSum)
            .put(Y_SUM, base.getFloat64(Y_SUM) + deltaYSum)
            .put(X_SQUARED_SUM, base.getFloat64(X_SQUARED_SUM) + deltaXSquaredSum)
            .put(Y_SQUARED_SUM, base.getFloat64(Y_SQUARED_SUM) + deltaYSquaredSum)
            .put(XY_SUM, base.getFloat64(XY_SUM) + deltaXySum)
            .put(COUNT, base.getInt64(COUNT) + deltaCount);
    }
}

