/*
 * Decompiled with CFR 0.152.
 */
package io.kroxylicious.proxy.internal;

import io.kroxylicious.proxy.bootstrap.FilterChainFactory;
import io.kroxylicious.proxy.config.NamedFilterDefinition;
import io.kroxylicious.proxy.config.PluginFactoryRegistry;
import io.kroxylicious.proxy.filter.Filter;
import io.kroxylicious.proxy.filter.FilterAndInvoker;
import io.kroxylicious.proxy.filter.NetFilter;
import io.kroxylicious.proxy.internal.ApiVersionsServiceImpl;
import io.kroxylicious.proxy.internal.KafkaAuthnHandler;
import io.kroxylicious.proxy.internal.KafkaProxyFrontendHandler;
import io.kroxylicious.proxy.internal.ResponseOrderer;
import io.kroxylicious.proxy.internal.SaslDecodePredicate;
import io.kroxylicious.proxy.internal.codec.KafkaMessageListener;
import io.kroxylicious.proxy.internal.codec.KafkaRequestDecoder;
import io.kroxylicious.proxy.internal.codec.KafkaResponseEncoder;
import io.kroxylicious.proxy.internal.filter.ApiVersionsDowngradeFilter;
import io.kroxylicious.proxy.internal.filter.ApiVersionsIntersectFilter;
import io.kroxylicious.proxy.internal.filter.BrokerAddressFilter;
import io.kroxylicious.proxy.internal.filter.EagerMetadataLearner;
import io.kroxylicious.proxy.internal.filter.NettyFilterContext;
import io.kroxylicious.proxy.internal.metrics.DownstreamMessageCountingKafkaMessageListener;
import io.kroxylicious.proxy.internal.metrics.MetricEmittingKafkaMessageListener;
import io.kroxylicious.proxy.internal.net.Endpoint;
import io.kroxylicious.proxy.internal.net.EndpointBinding;
import io.kroxylicious.proxy.internal.net.EndpointBindingResolver;
import io.kroxylicious.proxy.internal.net.EndpointGateway;
import io.kroxylicious.proxy.internal.net.EndpointReconciler;
import io.kroxylicious.proxy.internal.util.Metrics;
import io.kroxylicious.proxy.model.VirtualClusterModel;
import io.kroxylicious.proxy.service.HostPort;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.DistributionSummary;
import io.micrometer.core.instrument.Meter;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.handler.codec.haproxy.HAProxyMessageDecoder;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.handler.ssl.SniHandler;
import io.netty.handler.ssl.SslContext;
import io.netty.util.concurrent.Future;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletionStage;
import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Netty {@link ChannelInitializer} for client-facing (downstream) channels of the proxy.
 *
 * <p>For each accepted channel it installs a temporary resolver handler (SNI based when TLS is
 * enabled, plain otherwise) that asks the {@link EndpointBindingResolver} for the channel's
 * {@link EndpointBinding}. Once the binding is known, {@code addHandlers} installs the Kafka
 * request decoder, response encoder, optional authentication handler and the frontend handler.
 */
public class KafkaProxyInitializer
extends ChannelInitializer<Channel> {
    private static final Logger LOGGER = LoggerFactory.getLogger(KafkaProxyInitializer.class);
    // Single shared instance (the handler class is @Sharable) added as the last handler of every pipeline.
    private static final ChannelInboundHandlerAdapter LOGGING_INBOUND_ERROR_HANDLER = new LoggingInboundErrorHandler();
    // Pipeline name of the error handler; addHandlers removes it by this name and then re-adds it at the tail.
    static final String LOGGING_INBOUND_ERROR_HANDLER_NAME = "loggingInboundErrorHandler";
    // When true, an HAProxyMessageDecoder is added to each pipeline.
    private final boolean haproxyProtocol;
    // SASL callback handlers by mechanism; never null (the constructor substitutes an empty map).
    // A non-empty map causes a KafkaAuthnHandler to be installed.
    private final Map<KafkaAuthnHandler.SaslMechanism, AuthenticateCallbackHandler> authnHandlers;
    // Selects between the TLS/SNI resolver and the plain resolver in initChannel.
    private final boolean tls;
    // Resolves the EndpointBinding for a channel's endpoint (and SNI hostname, when TLS is used).
    private final EndpointBindingResolver bindingResolver;
    // Handed to the internal BrokerAddressFilter created for each connection.
    private final EndpointReconciler endpointReconciler;
    // Plugin factory registry handed to the NettyFilterContext used when creating filters.
    private final PluginFactoryRegistry pfr;
    // Creates the configured filter chain for each connection.
    private final FilterChainFactory filterChainFactory;
    // Shared by the request decoder and the internal ApiVersions filters.
    private final ApiVersionsServiceImpl apiVersionsService;
    // Incremented when the SNI lookup for a TLS connection fails.
    private final Counter clientToProxyErrorCounter;

    /**
     * Creates an initializer for client-facing channels.
     *
     * @param filterChainFactory factory used to create the configured filters for each connection
     * @param pfr plugin factory registry handed to each connection's filter context
     * @param tls {@code true} to resolve bindings through the TLS/SNI handler, {@code false} for plain connections
     * @param bindingResolver resolves the endpoint binding for an accepted channel
     * @param endpointReconciler handed to the internal broker-address filter
     * @param haproxyProtocol {@code true} to install an HAProxy message decoder on each pipeline
     * @param authnMechanismHandlers SASL callback handlers by mechanism; {@code null} is treated as an empty map
     * @param apiVersionsService service shared by the request decoder and the internal ApiVersions filters
     */
    public KafkaProxyInitializer(FilterChainFactory filterChainFactory, PluginFactoryRegistry pfr, boolean tls, EndpointBindingResolver bindingResolver, EndpointReconciler endpointReconciler, boolean haproxyProtocol, Map<KafkaAuthnHandler.SaslMechanism, AuthenticateCallbackHandler> authnMechanismHandlers, ApiVersionsServiceImpl apiVersionsService) {
        // Collaborators used while building each connection's filter chain.
        this.filterChainFactory = filterChainFactory;
        this.pfr = pfr;
        this.endpointReconciler = endpointReconciler;
        this.apiVersionsService = apiVersionsService;

        // Connection-level behaviour.
        this.tls = tls;
        this.haproxyProtocol = haproxyProtocol;
        this.bindingResolver = bindingResolver;

        // Normalise an absent handler map to an empty one so later code never has to null-check it.
        if (authnMechanismHandlers == null) {
            this.authnHandlers = Map.of();
        }
        else {
            this.authnHandlers = authnMechanismHandlers;
        }

        this.clientToProxyErrorCounter = (Counter) Metrics.clientToProxyErrorCounter("", null).withTags(new String[0]);
    }

    /**
     * Installs the binding resolver appropriate for this initializer (TLS/SNI or plain) on a newly
     * accepted channel, followed by the shared logging error handler at the tail of the pipeline.
     *
     * @param ch the newly accepted channel
     */
    @Override
    public void initChannel(Channel ch) {
        LOGGER.trace("Connection from {} to my address {}", ch.remoteAddress(), ch.localAddress());
        if (!this.tls) {
            this.initPlainChannel(ch);
        }
        else {
            this.initTlsChannel(ch);
        }
        // Added after the resolver so it stays last in the pipeline.
        addLoggingErrorHandler(ch.pipeline());
    }

    /**
     * Adds a temporary {@code "plainResolver"} handler that, when the channel becomes active,
     * resolves the endpoint binding (with no SNI hostname) and then installs the real handlers.
     *
     * <p>If resolution fails, the failure is fired as an exception event and the resolver stays in
     * the pipeline. Otherwise the handlers are installed, {@code channelActive} is propagated, and
     * the resolver removes itself whether or not installation threw.
     *
     * @param ch the channel being initialised
     */
    private void initPlainChannel(final Channel ch) {
        ChannelInboundHandlerAdapter plainResolver = new ChannelInboundHandlerAdapter() {

            @Override
            public void channelActive(ChannelHandlerContext ctx) {
                // Captured so the callback below can remove this handler from the pipeline.
                final ChannelHandler self = this;
                KafkaProxyInitializer.this.bindingResolver.resolve(Endpoint.createEndpoint(ch, KafkaProxyInitializer.this.tls), null).handle((resolved, failure) -> {
                    if (failure == null) {
                        try {
                            KafkaProxyInitializer.this.addHandlers(ch, resolved);
                            ctx.fireChannelActive();
                        }
                        catch (Throwable installFailure) {
                            ctx.fireExceptionCaught(installFailure);
                        }
                        finally {
                            ch.pipeline().remove(self);
                        }
                    }
                    else {
                        ctx.fireExceptionCaught(failure);
                    }
                    return null;
                });
            }
        };
        ch.pipeline().addLast("plainResolver", plainResolver);
    }

    /**
     * Adds an {@code "sniResolver"} {@link SniHandler} whose asynchronous mapping resolves the
     * endpoint binding from the channel's endpoint and the client's SNI hostname.
     *
     * <p>The mapping's promise succeeds with the gateway's downstream {@link SslContext} after the
     * real handlers have been installed. It fails if resolution fails, if the gateway provides no
     * SSL context, or if anything throws along the way. On a successful lookup the handler
     * delegates to {@code SniHandler} and then fires {@code channelActive}; on a failed lookup it
     * increments the client-to-proxy error counter and closes the channel.
     *
     * @param ch the channel being initialised
     */
    private void initTlsChannel(Channel ch) {
        LOGGER.debug("Adding SSL/SNI handler");
        SniHandler sniResolver = new SniHandler((sniHostname, promise) -> {
            try {
                Endpoint endpoint = Endpoint.createEndpoint(ch, this.tls);
                CompletionStage<EndpointBinding> resolution = this.bindingResolver.resolve(endpoint, sniHostname);
                // Completes the Netty promise once the resolution finishes, successfully or not.
                resolution.handle((resolved, failure) -> {
                    try {
                        if (failure != null) {
                            LOGGER.warn("Exception resolving Virtual Cluster Binding for endpoint {} and sniHostname {}: {}", endpoint, sniHostname, failure.getMessage());
                            promise.setFailure(failure);
                        }
                        else {
                            EndpointGateway gateway = resolved.endpointGateway();
                            Optional<SslContext> downstreamSslContext = gateway.getDownstreamSslContext();
                            if (downstreamSslContext.isPresent()) {
                                this.addHandlers(ch, resolved);
                                promise.setSuccess(downstreamSslContext.get());
                            }
                            else {
                                promise.setFailure(new IllegalStateException("Virtual cluster %s does not provide SSL context".formatted(gateway)));
                            }
                        }
                    }
                    catch (Throwable unexpected) {
                        promise.setFailure(unexpected);
                    }
                    return null;
                });
                return promise;
            }
            catch (Throwable cause) {
                return promise.setFailure(cause);
            }
        }) {

            @Override
            protected void onLookupComplete(ChannelHandlerContext ctx, Future<SslContext> future) throws Exception {
                if (!future.isSuccess()) {
                    // Failed lookup: count the error and drop the connection.
                    KafkaProxyInitializer.this.clientToProxyErrorCounter.increment();
                    ctx.close();
                    return;
                }
                super.onLookupComplete(ctx, future);
                ctx.fireChannelActive();
            }
        };
        ch.pipeline().addLast("sniResolver", sniResolver);
    }

    /**
     * Installs the handlers for a channel whose endpoint binding has been resolved.
     *
     * <p>The logging error handler is removed first and re-added at the end so that it remains the
     * last handler. In between, handlers are appended in this order: optional network logger,
     * optional HAProxy decoder, request decoder, response encoder, response orderer, optional
     * frame logger, optional authentication handler, and the frontend ("netHandler") handler.
     *
     * @param ch the channel whose pipeline is populated
     * @param binding the resolved binding; supplies the virtual cluster and the metric tags
     */
    void addHandlers(Channel ch, EndpointBinding binding) {
        VirtualClusterModel cluster = binding.endpointGateway().virtualCluster();
        ChannelPipeline p = ch.pipeline();

        // Taken out here and appended again below so it ends up after everything added in between.
        p.remove(LOGGING_INBOUND_ERROR_HANDLER_NAME);

        if (cluster.isLogNetwork()) {
            p.addLast("networkLogger", new LoggingHandler("io.kroxylicious.proxy.internal.DownstreamNetworkLogger", LogLevel.INFO));
        }
        if (this.haproxyProtocol) {
            LOGGER.debug("Adding haproxy handler");
            p.addLast("HAProxyMessageDecoder", new HAProxyMessageDecoder());
        }

        SaslDecodePredicate decodePredicate = new SaslDecodePredicate(!this.authnHandlers.isEmpty());
        MetricEmittingKafkaMessageListener responseMetrics = KafkaProxyInitializer.buildMetricsMessageListenerForEncode(binding, cluster);
        KafkaMessageListener requestMetrics = this.buildMetricsMessageListenerForDecode(binding, cluster);
        KafkaRequestDecoder requestDecoder = new KafkaRequestDecoder(decodePredicate, cluster.socketFrameMaxSizeBytes(), this.apiVersionsService, requestMetrics);
        p.addLast("requestDecoder", requestDecoder);
        p.addLast("responseEncoder", new KafkaResponseEncoder(responseMetrics));
        p.addLast("responseOrderer", new ResponseOrderer());

        if (cluster.isLogFrames()) {
            p.addLast("frameLogger", new LoggingHandler("io.kroxylicious.proxy.internal.DownstreamFrameLogger", LogLevel.INFO));
        }
        if (!this.authnHandlers.isEmpty()) {
            LOGGER.debug("Adding authn handler for handlers {}", this.authnHandlers);
            p.addLast(new KafkaAuthnHandler(ch, this.authnHandlers));
        }

        InitalizerNetFilter netFilter = new InitalizerNetFilter(decodePredicate, ch, binding, this.pfr, this.filterChainFactory, cluster.getFilters(), this.endpointReconciler, new ApiVersionsIntersectFilter(this.apiVersionsService), new ApiVersionsDowngradeFilter(this.apiVersionsService));
        p.addLast("netHandler", new KafkaProxyFrontendHandler(netFilter, decodePredicate, binding, cluster.getClusterName()));

        addLoggingErrorHandler(p);
        LOGGER.debug("{}: Initial pipeline: {}", ch, p);
    }

    /**
     * Builds the listener given to the request decoder: a chain of the client-to-proxy message
     * count/size metrics (keyed by cluster name and node id) followed by the listener returned by
     * {@link #deprecatedMessageMetricHandler(String)}.
     *
     * @param binding supplies the node id used for the metrics
     * @param virtualCluster supplies the cluster name used for the metrics
     * @return the chained listener
     */
    private KafkaMessageListener buildMetricsMessageListenerForDecode(EndpointBinding binding, VirtualClusterModel virtualCluster) {
        final String cluster = virtualCluster.getClusterName();
        final Integer node = binding.nodeId();
        MetricEmittingKafkaMessageListener clientToProxyMetrics = new MetricEmittingKafkaMessageListener(
                Metrics.clientToProxyMessageCounterProvider(cluster, node),
                Metrics.clientToProxyMessageSizeDistributionProvider(cluster, node));
        return KafkaMessageListener.chainOf(clientToProxyMetrics, this.deprecatedMessageMetricHandler(cluster));
    }

    /**
     * Builds the listener given to the response encoder, emitting proxy-to-client message
     * count/size metrics keyed by cluster name and node id.
     *
     * @param binding supplies the node id used for the metrics
     * @param virtualCluster supplies the cluster name used for the metrics
     * @return the metric-emitting listener
     */
    private static MetricEmittingKafkaMessageListener buildMetricsMessageListenerForEncode(EndpointBinding binding, VirtualClusterModel virtualCluster) {
        final String cluster = virtualCluster.getClusterName();
        final Integer node = binding.nodeId();
        return new MetricEmittingKafkaMessageListener(
                Metrics.proxyToClientMessageCounterProvider(cluster, node),
                Metrics.proxyToClientMessageSizeDistributionProvider(cluster, node));
    }

    /**
     * Builds the listener that feeds the cluster-level downstream message counters and payload
     * size summary. The method name marks these metrics as deprecated.
     *
     * @param clusterName cluster name the metrics are registered against
     * @return the counting listener
     */
    private KafkaMessageListener deprecatedMessageMetricHandler(String clusterName) {
        return new DownstreamMessageCountingKafkaMessageListener(
                Metrics.inboundDownstreamMessageCounter(clusterName),
                Metrics.inboundDownstreamDecodedMessageCounter(clusterName),
                Metrics.payloadSizeBytesUpstreamSummary(clusterName));
    }

    /**
     * Appends the shared logging error handler to the end of the pipeline under
     * {@link #LOGGING_INBOUND_ERROR_HANDLER_NAME}.
     *
     * @param pipeline the pipeline to append to
     */
    private static void addLoggingErrorHandler(ChannelPipeline pipeline) {
        final ChannelHandler tailErrorHandler = LOGGING_INBOUND_ERROR_HANDLER;
        pipeline.addLast(LOGGING_INBOUND_ERROR_HANDLER_NAME, tailErrorHandler);
    }

    /**
     * {@link NetFilter} that assembles the per-connection filter chain and initiates the upstream
     * connection to the binding's target.
     */
    static class InitalizerNetFilter
    implements NetFilter {
        private final SaslDecodePredicate decodePredicate;
        private final Channel ch;
        // Cached from binding.endpointGateway(); handed to the broker-address filter.
        private final EndpointGateway gateway;
        private final EndpointBinding binding;
        private final PluginFactoryRegistry pfr;
        private final FilterChainFactory filterChainFactory;
        private final List<NamedFilterDefinition> filterDefinitions;
        private final EndpointReconciler endpointReconciler;
        private final ApiVersionsIntersectFilter apiVersionsIntersectFilter;
        private final ApiVersionsDowngradeFilter apiVersionsDowngradeFilter;

        /**
         * @param decodePredicate consulted to decide whether the ApiVersions intersect filter is included
         * @param ch the client channel; its event loop is handed to the filter context
         * @param binding the resolved binding; supplies the gateway and the upstream target
         * @param pfr plugin factory registry handed to the filter context
         * @param filterChainFactory creates the configured filters
         * @param filterDefinitions definitions of the configured filters
         * @param endpointReconciler handed to the broker-address filter
         * @param apiVersionsIntersectFilter internal filter, included unless authentication offload is enabled
         * @param apiVersionsDowngradeFilter internal filter, always included
         */
        InitalizerNetFilter(SaslDecodePredicate decodePredicate, Channel ch, EndpointBinding binding, PluginFactoryRegistry pfr, FilterChainFactory filterChainFactory, List<NamedFilterDefinition> filterDefinitions, EndpointReconciler endpointReconciler, ApiVersionsIntersectFilter apiVersionsIntersectFilter, ApiVersionsDowngradeFilter apiVersionsDowngradeFilter) {
            this.decodePredicate = decodePredicate;
            this.ch = ch;
            this.gateway = binding.endpointGateway();
            this.binding = binding;
            this.pfr = pfr;
            this.filterChainFactory = filterChainFactory;
            this.filterDefinitions = filterDefinitions;
            this.endpointReconciler = endpointReconciler;
            this.apiVersionsIntersectFilter = apiVersionsIntersectFilter;
            this.apiVersionsDowngradeFilter = apiVersionsDowngradeFilter;
        }

        /**
         * Builds the filter chain and asks the context to connect to the binding's upstream target.
         *
         * <p>Resulting filter order: ApiVersionsIntersect (omitted when authentication offload is
         * enabled), ApiVersionsDowngrade, the configured filters, EagerMetadataLearner (only when
         * the binding restricts upstream to metadata discovery), BrokerAddress.
         *
         * @param context used to initiate the upstream connection
         * @throws IllegalStateException if the binding has no upstream target
         */
        @Override
        public void selectServer(NetFilter.NetFilterContext context) {
            // Typed as List<FilterAndInvoker> rather than a raw List to avoid an unchecked conversion below.
            List<FilterAndInvoker> apiVersionFilters = this.decodePredicate.isAuthenticationOffloadEnabled()
                    ? List.of()
                    : FilterAndInvoker.build("ApiVersionsIntersect (internal)", this.apiVersionsIntersectFilter);
            NettyFilterContext filterContext = new NettyFilterContext(this.ch.eventLoop(), this.pfr);
            List<FilterAndInvoker> filterChain = this.filterChainFactory.createFilters(filterContext, this.filterDefinitions);
            List<FilterAndInvoker> brokerAddressFilters = FilterAndInvoker.build("BrokerAddress (internal)", new BrokerAddressFilter(this.gateway, this.endpointReconciler));

            List<FilterAndInvoker> filters = new ArrayList<>(apiVersionFilters);
            filters.addAll(FilterAndInvoker.build("ApiVersionsDowngrade (internal)", this.apiVersionsDowngradeFilter));
            filters.addAll(filterChain);
            if (this.binding.restrictUpstreamToMetadataDiscovery()) {
                filters.addAll(FilterAndInvoker.build("EagerMetadataLearner (internal)", new EagerMetadataLearner()));
            }
            filters.addAll(brokerAddressFilters);

            HostPort target = this.binding.upstreamTarget();
            if (target == null) {
                throw new IllegalStateException("A target address for binding %s is not known.".formatted(this.binding));
            }
            context.initiateConnect(target, filters);
        }
    }

    @ChannelHandler.Sharable
    static class LoggingInboundErrorHandler
    extends ChannelInboundHandlerAdapter {
        LoggingInboundErrorHandler() {
        }

        public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
            LOGGER.atWarn().setCause(LOGGER.isDebugEnabled() ? cause : null).log("An exceptionCaught() event was caught by the error handler {}: {}. Increase log level to DEBUG for stacktrace", (Object)cause.getClass().getSimpleName(), (Object)cause.getMessage());
        }
    }
}

