/*
 * Decompiled with CFR 0.152.
 */
package io.confluent.gateway.filter.authswap;

import io.confluent.gateway.filter.authswap.Session;
import io.confluent.gateway.filter.authswap.client.AuthResult;
import io.confluent.gateway.filter.authswap.client.ClientAuthProcessor;
import io.confluent.gateway.filter.authswap.client.MtlsClientAuthProcessor;
import io.confluent.gateway.filter.authswap.client.SaslClientAuthProcessor;
import io.confluent.gateway.filter.authswap.cluster.ClusterAuthProcessor;
import io.confluent.gateway.filter.authswap.cluster.SaslClusterAuthProcessor;
import io.confluent.gateway.filter.authswap.config.AuthSwapFilterConfig;
import io.confluent.gateway.filter.authswap.config.ClientAuth;
import io.confluent.gateway.filter.authswap.metrics.AuthMetricsRecorder;
import io.confluent.gateway.filter.authswap.metrics.AuthSwapMeter;
import io.confluent.gateway.filter.authswap.metrics.MicrometerAuthSwapMeter;
import io.confluent.gateway.filter.authswap.response.ErrorResponseBuilderRegistry;
import io.kroxylicious.proxy.filter.FilterContext;
import io.kroxylicious.proxy.filter.RequestFilter;
import io.kroxylicious.proxy.filter.RequestFilterResult;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.CompletionStage;
import org.apache.kafka.common.message.RequestHeaderData;
import org.apache.kafka.common.protocol.ApiKeys;
import org.apache.kafka.common.protocol.ApiMessage;
import org.apache.kafka.common.protocol.Errors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Kroxylicious {@link RequestFilter} that authenticates clients connecting to the gateway and
 * authenticates the gateway towards the backing cluster.
 *
 * <p>Client-side authentication is delegated to a {@link ClientAuthProcessor} chosen from the
 * configured {@link ClientAuth}: SASL when present, otherwise mTLS when SSL is present.
 * Cluster-side ("gateway") authentication is delegated to a {@link SaslClusterAuthProcessor}. Both
 * processors share the single {@link Session} owned by this filter instance.
 *
 * <p>Every authentication flow is timed through {@link AuthMetricsRecorder} under the
 * {@code gateway_authswap_latency} metric. The metric is tagged with the virtual cluster name as
 * {@code route}.
 */
public class AuthSwapFilter implements RequestFilter {
    private static final Logger LOGGER = LoggerFactory.getLogger(AuthSwapFilter.class);

    /** Metric name under which the latency of every authentication flow is recorded. */
    private static final String LATENCY_METRIC = "gateway_authswap_latency";

    // Kept package-private and non-final as in the original. Presumably this lets tests
    // substitute processors (TODO confirm before tightening visibility).
    ClientAuthProcessor clientAuthProcessor;
    ClusterAuthProcessor clusterAuthProcessor;
    private final Session session = new Session();
    private final AuthMetricsRecorder recorder;

    /**
     * Creates a filter that reports its metrics through a {@link MicrometerAuthSwapMeter}.
     *
     * @param authSwapFilterConfig filter configuration
     * @throws IllegalArgumentException if neither SASL nor SSL client authentication is configured
     */
    public AuthSwapFilter(AuthSwapFilterConfig authSwapFilterConfig) {
        this(authSwapFilterConfig, new MicrometerAuthSwapMeter());
    }

    /**
     * Creates a filter that reports its metrics through the given meter.
     *
     * @param authSwapFilterConfig filter configuration. Its client-auth section selects the client
     *     processor: SASL if present, otherwise mTLS if SSL is present.
     * @param meter meter backing this filter's {@link AuthMetricsRecorder}
     * @throws IllegalArgumentException if neither SASL nor SSL client authentication is configured
     */
    public AuthSwapFilter(AuthSwapFilterConfig authSwapFilterConfig, AuthSwapMeter meter) {
        this.recorder = new AuthMetricsRecorder(meter);
        ClientAuth clientAuth = authSwapFilterConfig.config().clientAuth();
        if (clientAuth.sasl() != null) {
            this.clientAuthProcessor = new SaslClientAuthProcessor(authSwapFilterConfig, this.session, this.recorder);
        } else if (clientAuth.ssl() != null) {
            this.clientAuthProcessor = new MtlsClientAuthProcessor(this.session, this.recorder, clientAuth.ssl());
        } else {
            throw new IllegalArgumentException("No valid auth swap strategy configured");
        }
        this.clusterAuthProcessor = new SaslClusterAuthProcessor(authSwapFilterConfig, this.session, this.recorder);
    }

    /**
     * Decides whether {@link #onRequest} should see a request.
     *
     * <p>ApiVersions requests are never intercepted. Any other request is intercepted when the
     * client processor wants it, or when the cluster processor wants it for gateway authentication
     * or re-authentication.
     *
     * @param apiKey request API key
     * @param apiVersion request API version
     * @return {@code true} if this filter must process the request
     */
    public boolean shouldHandleRequest(ApiKeys apiKey, short apiVersion) {
        if (apiKey == ApiKeys.API_VERSIONS) {
            return false;
        }
        // All three predicates are evaluated (no short-circuit), as in the original.
        boolean needsClientAuth = this.clientAuthProcessor.shouldHandleRequest(apiKey, apiVersion);
        boolean needsGatewayAuth = this.clusterAuthProcessor.shouldInterceptForGatewayAuth(apiKey, apiVersion);
        boolean needsGatewayReauth = this.clusterAuthProcessor.shouldInterceptForGatewayReauth(apiKey, apiVersion);
        return needsClientAuth || needsGatewayAuth || needsGatewayReauth;
    }

    /**
     * Routes a request to the appropriate authentication flow.
     *
     * <ul>
     *   <li>Client auth needed and the cluster connection is not yet established: first-time flow.
     *       This authenticates the client and then the cluster.</li>
     *   <li>Client auth needed and the connection is already established: re-authenticates the
     *       client, and also the cluster if the cluster processor requests it.</li>
     *   <li>Only cluster auth needed: authenticates the cluster, then forwards the request.</li>
     *   <li>Otherwise: forwards the request unchanged.</li>
     * </ul>
     *
     * <p>All flows except the plain forward are latency-recorded.
     *
     * @return stage completing with the filter result for this request
     */
    public CompletionStage<RequestFilterResult> onRequest(ApiKeys apiKey, RequestHeaderData header, ApiMessage request, FilterContext context) {
        short apiVersion = header.requestApiVersion();
        boolean needsClientAuth = this.clientAuthProcessor.shouldHandleRequest(apiKey, apiVersion);
        boolean needsClusterAuth = this.clusterAuthProcessor.shouldInterceptForGatewayAuth(apiKey, apiVersion)
                || this.clusterAuthProcessor.shouldInterceptForGatewayReauth(apiKey, apiVersion);
        if (!needsClientAuth && !needsClusterAuth) {
            // Nothing to authenticate: skip building metric tags altogether.
            return context.forwardRequest(header, request);
        }
        Map<String, String> authTags = Map.of("route", context.getVirtualClusterName());
        if (!needsClientAuth) {
            return this.executeWithLatencyRecording(() -> this.handleClusterOnlyAuthentication(header, request, context), authTags);
        }
        // Treat a null flag as "not yet established" instead of failing on unboxing.
        if (!Boolean.TRUE.equals(this.session.isClusterConnectionEstablished())) {
            return this.executeWithLatencyRecording(() -> this.handleFirstTimeAuthentication(apiKey, header, request, context), authTags);
        }
        return this.executeWithLatencyRecording(() -> this.handleClientReauthentication(apiKey, header, request, context, needsClusterAuth), authTags);
    }

    /**
     * Authenticates the client and, on success, always authenticates the cluster as well.
     *
     * <p>On client failure, the failure is turned into a response via
     * {@link #handleAuthenticationFailure}.
     */
    private CompletionStage<RequestFilterResult> handleFirstTimeAuthentication(ApiKeys apiKey, RequestHeaderData header, ApiMessage request, FilterContext context) {
        return this.handleClientAuthentication(apiKey, header, request, context).thenCompose(clientResult -> {
            if (clientResult.isSuccess()) {
                return this.handleClusterAuthenticationWithFallback(header, request, context, clientResult.getClientId(), clientResult);
            }
            return this.handleAuthenticationFailure(request, clientResult, context);
        });
    }

    /**
     * Re-authenticates the client on an already established connection.
     *
     * <p>On success, the cluster is re-authenticated only when {@code needsClusterAuth} is set.
     * Otherwise the response is derived from the client result alone.
     */
    private CompletionStage<RequestFilterResult> handleClientReauthentication(ApiKeys apiKey, RequestHeaderData header, ApiMessage request, FilterContext context, boolean needsClusterAuth) {
        return this.handleClientAuthentication(apiKey, header, request, context).thenCompose(clientResult -> {
            if (!clientResult.isSuccess()) {
                return this.handleAuthenticationFailure(request, clientResult, context);
            }
            if (needsClusterAuth) {
                return this.handleClusterAuthenticationWithFallback(header, request, context, clientResult.getClientId(), clientResult);
            }
            return this.createResponseFromAuthResult(clientResult, context, header, request);
        });
    }

    /**
     * Authenticates the cluster for the current incoming client id, then forwards the request.
     *
     * <p>Any failure in the chain becomes an {@code UNKNOWN_SERVER_ERROR} response that closes the
     * connection.
     */
    private CompletionStage<RequestFilterResult> handleClusterOnlyAuthentication(RequestHeaderData header, ApiMessage request, FilterContext context) {
        return this.handleClusterAuthentication(context, this.clientAuthProcessor.getIncomingClientId())
                .thenCompose(ignored -> context.forwardRequest(header, request))
                .exceptionallyCompose(throwable -> this.createClusterAuthFailureResponse(request, context, throwable, Errors.UNKNOWN_SERVER_ERROR.code()));
    }

    /**
     * Authenticates the cluster after a successful client authentication, then derives the
     * response from {@code clientResult}.
     *
     * <p>If anything in the chain fails, the fallbacks are tried in this order:
     * <ol>
     *   <li>The client result's failure response builder (response sent and connection closed).</li>
     *   <li>The client result's filter result.</li>
     *   <li>A generic {@code UNKNOWN_SERVER_ERROR} response.</li>
     * </ol>
     */
    private CompletionStage<RequestFilterResult> handleClusterAuthenticationWithFallback(RequestHeaderData header, ApiMessage request, FilterContext context, String clientId, AuthResult clientResult) {
        return this.handleClusterAuthentication(context, clientId)
                .thenCompose(ignored -> this.createResponseFromAuthResult(clientResult, context, header, request))
                .exceptionallyCompose(throwable -> {
                    if (clientResult.getFailureResponseBuilder() != null) {
                        // Unwrap so the client sees the real reason, not CompletionException noise.
                        Object responseData = clientResult.getFailureResponseBuilder().build(describe(throwable));
                        return context.requestFilterResultBuilder().shortCircuitResponse((ApiMessage) responseData).withCloseConnection().completed();
                    }
                    if (clientResult.getFilterResult() != null) {
                        return CompletableFuture.completedStage(clientResult.getFilterResult());
                    }
                    return this.createClusterAuthFailureResponse(request, context, throwable, Errors.UNKNOWN_SERVER_ERROR.code());
                });
    }

    /**
     * Runs cluster authentication for {@code clientId}.
     *
     * @return stage completing normally on success. On failure it completes exceptionally with a
     *     {@link RuntimeException} whose cause is the (unwrapped) underlying error.
     */
    private CompletionStage<Void> handleClusterAuthentication(FilterContext context, String clientId) {
        return this.clusterAuthProcessor.authenticate(context, clientId)
                .thenAccept(authData -> LOGGER.debug("Cluster authentication successful for client: {} for channel: {}", clientId, context.channelDescriptor()))
                .exceptionallyCompose(throwable -> {
                    Throwable cause = unwrap(throwable);
                    String reason = describe(cause);
                    // Trailing Throwable argument: SLF4J logs its stack trace.
                    LOGGER.error("Gateway authentication failed for channel: {} with error: {}", context.channelDescriptor(), reason, cause);
                    return CompletableFuture.failedStage(new RuntimeException("Gateway authentication failed: " + reason, cause));
                });
    }

    /** Delegates client authentication to the configured {@link ClientAuthProcessor}. */
    private CompletionStage<AuthResult> handleClientAuthentication(ApiKeys apiKey, RequestHeaderData header, ApiMessage request, FilterContext context) {
        return this.clientAuthProcessor.authenticate(apiKey, header, request, context);
    }

    /**
     * Turns a failed client authentication into a filter result.
     *
     * <p>The first available option is used:
     * <ol>
     *   <li>The result's own filter result.</li>
     *   <li>An error response built from its error, using its error code or
     *       {@code UNKNOWN_SERVER_ERROR} when the code is 0.</li>
     *   <li>Otherwise a plain connection close.</li>
     * </ol>
     */
    private CompletionStage<RequestFilterResult> handleAuthenticationFailure(ApiMessage request, AuthResult authResult, FilterContext context) {
        if (authResult.getFilterResult() != null) {
            return CompletableFuture.completedStage(authResult.getFilterResult());
        }
        Throwable error = authResult.getError();
        if (error != null) {
            String reason = describe(error);
            LOGGER.error("Client authentication failed for channel: {} with error:{}", context.channelDescriptor(), reason);
            short errorCode = authResult.getErrorCode() != 0 ? authResult.getErrorCode() : Errors.UNKNOWN_SERVER_ERROR.code();
            return this.createFailureResponse(request, context, errorCode, reason);
        }
        LOGGER.error("Client authentication failed with no error details for channel: {}", context.channelDescriptor());
        return context.requestFilterResultBuilder().withCloseConnection().completed();
    }

    /**
     * Short-circuits with an error response for {@code request} and closes the connection.
     *
     * <p>If the registry cannot build an error response for this API, the connection is closed
     * without sending a response.
     */
    private CompletionStage<RequestFilterResult> createFailureResponse(ApiMessage request, FilterContext context, short errorCode, String errorMessage) {
        ApiMessage response = ErrorResponseBuilderRegistry.createErrorResponse(request.apiKey(), request, errorCode, errorMessage).orElse(null);
        if (response == null) {
            LOGGER.warn("Directly closing the connection for channel: {} as no response was created for the error: {}", context.channelDescriptor(), errorMessage);
            return context.requestFilterResultBuilder().withCloseConnection().completed();
        }
        return context.requestFilterResultBuilder().shortCircuitResponse(response).withCloseConnection().completed();
    }

    /**
     * Converts a successful {@link AuthResult} into a filter result.
     *
     * <p>The first applicable option is used:
     * <ol>
     *   <li>Forward the original request.</li>
     *   <li>Return the result's filter result.</li>
     *   <li>Short-circuit with the built response; the connection stays open.</li>
     * </ol>
     *
     * @throws IllegalStateException if the result carries none of these
     */
    private CompletionStage<RequestFilterResult> createResponseFromAuthResult(AuthResult authResult, FilterContext context, RequestHeaderData header, ApiMessage request) {
        if (authResult.shouldForwardRequest()) {
            return context.forwardRequest(header, request);
        }
        if (authResult.getFilterResult() != null) {
            return CompletableFuture.completedStage(authResult.getFilterResult());
        }
        if (authResult.getResponseBuilder() != null) {
            Object responseData = authResult.getResponseBuilder().build();
            return context.requestFilterResultBuilder().shortCircuitResponse((ApiMessage) responseData).completed();
        }
        throw new IllegalStateException("AuthResult has no response data, response builder, or filter result");
    }

    /**
     * Logs a cluster authentication failure and answers with an error response that closes the
     * connection.
     */
    private CompletionStage<RequestFilterResult> createClusterAuthFailureResponse(ApiMessage request, FilterContext context, Throwable throwable, short errorCode) {
        String reason = describe(throwable);
        LOGGER.error("Cluster authentication failed for channel : {}, with error :{}", context.channelDescriptor(), reason);
        return this.createFailureResponse(request, context, errorCode, "Cluster authentication failed: " + reason);
    }

    /** Runs {@code authHandler} while recording its latency under {@link #LATENCY_METRIC}. */
    private CompletionStage<RequestFilterResult> executeWithLatencyRecording(Callable<CompletionStage<RequestFilterResult>> authHandler, Map<String, String> authTags) {
        return this.recorder.executeAsyncAndRecordLatency(LATENCY_METRIC, authTags, authHandler);
    }

    /**
     * Strips {@link CompletionException} wrappers from {@code throwable}.
     *
     * <p>{@code CompletableFuture} adds these wrappers when a failure propagates through dependent
     * stages.
     */
    private static Throwable unwrap(Throwable throwable) {
        Throwable current = throwable;
        while (current instanceof CompletionException && current.getCause() != null) {
            current = current.getCause();
        }
        return current;
    }

    /**
     * Returns the message of the unwrapped failure.
     *
     * <p>Falls back to its class name when the message is {@code null}, so the result is never
     * {@code null}.
     */
    private static String describe(Throwable throwable) {
        Throwable cause = unwrap(throwable);
        String message = cause.getMessage();
        return message != null ? message : cause.getClass().getName();
    }
}

