/*
 * Decompiled with CFR 0.152.
 */
package io.confluent.gateway.filter.authswap.cluster;

import io.confluent.gateway.filter.authswap.Session;
import io.confluent.gateway.filter.authswap.cluster.ClusterAuthProcessor;
import io.confluent.gateway.filter.authswap.cluster.sasl.SaslClientHandler;
import io.confluent.gateway.filter.authswap.config.AuthSwapFilterConfig;
import io.confluent.gateway.filter.authswap.metrics.AuthMetricsRecorder;
import io.confluent.gateway.filter.authswap.secretstore.Credential;
import io.kroxylicious.proxy.filter.FilterContext;
import java.nio.ByteBuffer;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import org.apache.commons.lang3.StringUtils;
import org.apache.kafka.common.errors.IllegalSaslStateException;
import org.apache.kafka.common.errors.SaslAuthenticationException;
import org.apache.kafka.common.errors.UnsupportedSaslMechanismException;
import org.apache.kafka.common.message.RequestHeaderData;
import org.apache.kafka.common.message.SaslAuthenticateRequestData;
import org.apache.kafka.common.message.SaslAuthenticateResponseData;
import org.apache.kafka.common.message.SaslHandshakeRequestData;
import org.apache.kafka.common.message.SaslHandshakeResponseData;
import org.apache.kafka.common.protocol.ApiKeys;
import org.apache.kafka.common.protocol.ApiMessage;
import org.apache.kafka.common.protocol.Errors;
import org.apache.kafka.common.requests.RequestHeader;
import org.apache.kafka.common.requests.SaslAuthenticateRequest;
import org.apache.kafka.common.requests.SaslHandshakeRequest;
import org.apache.kafka.common.utils.Utils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link ClusterAuthProcessor} that authenticates the gateway against the cluster using SASL.
 *
 * <p>The flow driven by {@link #authenticate(FilterContext, String)} is:
 * <ol>
 *   <li>send a {@code SaslHandshake} request advertising the configured mechanism;</li>
 *   <li>exchange the incoming principal for a cluster credential through the configured secret
 *       store and initialise the {@link SaslClientHandler} with it;</li>
 *   <li>send {@code SaslAuthenticate} requests until the SASL client reports completion.</li>
 * </ol>
 *
 * <p>Every failure path that this class detects disposes the SASL client and increments the
 * cluster-auth failure metric via {@link #cleanupResourcesOnFailure(FilterContext)}; a successful
 * authentication disposes the client and increments the success metric.
 */
public class SaslClusterAuthProcessor implements ClusterAuthProcessor {
    private static final Logger LOGGER = LoggerFactory.getLogger(SaslClusterAuthProcessor.class);
    /** Client id placed in the header of every request the gateway itself sends to the cluster. */
    private static final String GATEWAY_CLIENT_ID = "gateway";
    private final AuthSwapFilterConfig authSwapConfig;
    private final SaslClientHandler saslClientHandler;
    private final Session session;
    private final String clientAuthMechanism;
    private final AuthMetricsRecorder recorder;
    /** Username of the first exchanged credential; later exchanges must resolve to the same one. */
    private String originalExchangedClientId;

    /**
     * Creates a processor bound to one session.
     *
     * @param authSwapConfig filter configuration; supplies the SASL mechanism and the secret store
     * @param session        session whose gateway-to-cluster lifetime is updated on broker responses
     * @param recorder       metrics sink for secret-store latency and auth success/failure counts
     */
    public SaslClusterAuthProcessor(AuthSwapFilterConfig authSwapConfig, Session session, AuthMetricsRecorder recorder) {
        this.authSwapConfig = authSwapConfig;
        this.saslClientHandler = new SaslClientHandler(authSwapConfig);
        this.session = session;
        this.recorder = recorder;
        this.clientAuthMechanism = authSwapConfig.config().clusterAuth().sasl().mechanism();
    }

    /**
     * Runs the handshake, credential exchange and first {@code SaslAuthenticate} round trip.
     *
     * @param context          filter context used to send requests to the cluster
     * @param incomingClientId principal of the incoming client, used as the secret-store lookup key
     * @return a stage that completes with the broker's auth bytes, or fails if any step fails
     */
    @Override
    public CompletionStage<ClusterAuthProcessor.AuthData> authenticate(FilterContext context, String incomingClientId) {
        return this.sendHandshakeRequest(context)
                .thenCompose(handshakeResponse -> this.processHandshakeResponse(handshakeResponse, incomingClientId, context))
                .thenCompose(saslRequest -> this.sendSaslAuthenticateRequest(saslRequest, context));
    }

    /**
     * Type-checks a cluster response and delegates to
     * {@link #handleBrokerSaslAuthenticateResponse(SaslAuthenticateResponseData, FilterContext)}.
     *
     * @param response the response object received from the cluster
     * @param context  filter context of the connection
     * @return the stage produced by the SASL response handler
     * @throws IllegalArgumentException if {@code response} is {@code null} or not a
     *                                  {@link SaslAuthenticateResponseData}
     */
    @Override
    public CompletionStage<ClusterAuthProcessor.AuthData> handleClusterAuthenticateResponse(Object response, FilterContext context) {
        if (response instanceof SaslAuthenticateResponseData saslResponse) {
            return this.handleBrokerSaslAuthenticateResponse(saslResponse, context);
        }
        // Resolve the type null-safely so a null response yields the documented
        // IllegalArgumentException rather than a NullPointerException.
        Class<?> actualType = response == null ? null : response.getClass();
        throw new IllegalArgumentException("Expected SaslAuthenticateResponseData but got: " + actualType);
    }

    /** Sends a {@code SaslHandshake} request for the configured mechanism using the gateway client id. */
    private CompletionStage<SaslHandshakeResponseData> sendHandshakeRequest(FilterContext context) {
        SaslHandshakeRequestData handshakeData = new SaslHandshakeRequestData().setMechanism(this.clientAuthMechanism);
        SaslHandshakeRequest handshakeRequest = new SaslHandshakeRequest.Builder(handshakeData).build(ApiKeys.SASL_HANDSHAKE.latestVersion());
        RequestHeader header = this.createRequestHeader(ApiKeys.SASL_HANDSHAKE, handshakeRequest.version(), GATEWAY_CLIENT_ID);
        return this.toCompletableFuture(context.sendRequest(header.data(), handshakeRequest.data()));
    }

    /**
     * Validates the handshake response, performs the credential exchange when the SASL client is
     * not yet initialised, and builds the initial {@code SaslAuthenticate} request.
     *
     * @return a stage holding the initial request, or a failed stage describing what went wrong
     */
    private CompletionStage<SaslRequest> processHandshakeResponse(SaslHandshakeResponseData saslHandshakeResponse, String incomingClientId, FilterContext context) {
        Errors error = Errors.forCode(saslHandshakeResponse.errorCode());
        if (error != Errors.NONE) {
            Exception handshakeFailure = this.createHandshakeException(error, saslHandshakeResponse);
            this.cleanupResourcesOnFailure(context);
            return CompletableFuture.failedStage(handshakeFailure);
        }
        if (!this.saslClientHandler.isInitialized()) {
            try {
                String storeType = this.authSwapConfig.secretStore().getType();
                Map<String, String> secretStoreTags = Map.of("type", storeType, "route", context.getVirtualClusterName());
                Credential exchangedCredential = this.recorder.executeAndRecordLatency(
                        "gateway_authswap_secret_store_latency",
                        secretStoreTags,
                        () -> this.authSwapConfig.secretStore().exchangeCredential(Credential.withUsername(incomingClientId)));
                try {
                    this.validatePrincipalConsistency(exchangedCredential.username(), context);
                }
                catch (SaslAuthenticationException e) {
                    // A principal mismatch is surfaced unwrapped, unlike other exchange failures below.
                    this.cleanupResourcesOnFailure(context);
                    return CompletableFuture.failedStage(e);
                }
                LOGGER.info("Successfully exchanged credential for incoming principal: {} with outgoing principal: {} on channel: {}",
                        incomingClientId, exchangedCredential.username(), context.channelDescriptor());
                this.saslClientHandler.initializeSaslClient(exchangedCredential, context);
            }
            catch (Exception e) {
                String errorMessage = String.format("Failed to exchange credential for incoming principal: %s on channel: %s due to error: %s", incomingClientId, context.channelDescriptor(), e.getMessage());
                LOGGER.error(errorMessage);
                this.cleanupResourcesOnFailure(context);
                return CompletableFuture.failedStage(new Exception(String.format("Failed to exchange credential for incoming principal: %s due to error: %s", incomingClientId, e.getMessage()), e));
            }
        }
        if (this.saslClientHandler.isComplete()) {
            // NOTE(review): unlike the other failure paths, this one neither disposes the SASL client
            // nor records a failure metric. Confirm whether that is intentional.
            return CompletableFuture.failedStage(new IllegalSaslStateException("SASL client is already complete, cannot create initial SASL request."));
        }
        try {
            return CompletableFuture.completedStage(this.createSaslRequest(new byte[0], true, context));
        }
        catch (Exception e) {
            this.cleanupResourcesOnFailure(context);
            return CompletableFuture.failedStage(e);
        }
    }

    /** Sends the given {@code SaslAuthenticate} request and feeds the broker's reply to the response handler. */
    private CompletionStage<ClusterAuthProcessor.AuthData> sendSaslAuthenticateRequest(SaslRequest saslRequest, FilterContext context) {
        CompletableFuture<SaslAuthenticateResponseData> brokerReply =
                this.toCompletableFuture(context.sendRequest(saslRequest.header().data(), saslRequest.request().data()));
        return brokerReply.thenCompose(reply -> this.handleBrokerSaslAuthenticateResponse(reply, context));
    }

    /**
     * Maps a non-{@code NONE} handshake error to an exception with a descriptive message.
     *
     * @param error    the decoded handshake error
     * @param response the handshake response, used for its error code and enabled mechanisms
     * @return {@link UnsupportedSaslMechanismException} for {@code UNSUPPORTED_SASL_MECHANISM},
     *         otherwise an {@link IllegalSaslStateException}
     */
    Exception createHandshakeException(Errors error, SaslHandshakeResponseData response) {
        if (error == Errors.UNSUPPORTED_SASL_MECHANISM) {
            return new UnsupportedSaslMechanismException(String.format("Gateway client SASL mechanism '%s' not enabled in the server, enabled mechanisms are %s", this.clientAuthMechanism, response.mechanisms()));
        }
        if (error == Errors.ILLEGAL_SASL_STATE) {
            return new IllegalSaslStateException(String.format("Unexpected handshake request with gateway client mechanism %s, enabled mechanisms are %s", this.clientAuthMechanism, response.mechanisms()));
        }
        return new IllegalSaslStateException(String.format("Unknown error code %s, gateway client mechanism is %s, enabled mechanisms are %s", response.errorCode(), this.clientAuthMechanism, response.mechanisms()));
    }

    /**
     * Processes a broker {@code SaslAuthenticate} response: checks its error code, updates the
     * session lifetime, and either finishes the authentication or continues the exchange.
     *
     * @param response the broker's response
     * @param context  filter context of the connection
     * @return a stage completing with the resulting {@link ClusterAuthProcessor.AuthData}
     * @throws RuntimeException wrapping any synchronous failure; the SASL client has been disposed
     *                          and one failure has been recorded by the time it is thrown
     */
    public CompletionStage<ClusterAuthProcessor.AuthData> handleBrokerSaslAuthenticateResponse(SaslAuthenticateResponseData response, FilterContext context) {
        // handlePendingResponse performs its own cleanup before throwing; this flag stops the catch
        // block below from disposing and counting the same failure a second time.
        boolean cleanupDelegated = false;
        try {
            byte[] authBytes = this.handleSaslAuthenticateResponse(response);
            LOGGER.debug("Handling broker SASL authentication for gateway auth, gatewayToBrokerSessionLifetimeMs is: {} ms for channel: {}", response.sessionLifetimeMs(), context.channelDescriptor());
            this.session.updateGatewayToClusterSession(response.sessionLifetimeMs());
            if (this.saslClientHandler.isComplete()) {
                return this.handleSuccessfulAuthentication(authBytes, context);
            }
            cleanupDelegated = true;
            return this.handlePendingResponse(authBytes, context);
        }
        catch (Exception e) {
            LOGGER.error("Error handling SaslAuthenticateResponse for channel: {}", context.channelDescriptor(), e);
            if (!cleanupDelegated) {
                this.cleanupResourcesOnFailure(context);
            }
            throw new RuntimeException("Broker SASL authentication failed due to an unexpected error", e);
        }
    }

    /**
     * Continues an incomplete SASL exchange by evaluating the broker's challenge and, when the
     * client produces a token, sending one more {@code SaslAuthenticate} request.
     *
     * @param authBytes challenge bytes from the previous broker response
     * @param context   filter context of the connection
     * @return a stage completing with the final {@link ClusterAuthProcessor.AuthData}
     * @throws RuntimeException on any synchronous failure; cleanup has been performed exactly once
     *                          before it propagates
     */
    CompletionStage<ClusterAuthProcessor.AuthData> handlePendingResponse(byte[] authBytes, FilterContext context) {
        try {
            CompletableFuture<SaslAuthenticateResponseData> brokerReply = null;
            try {
                SaslRequest followUp = this.createSaslRequest(authBytes, false, context);
                if (followUp.saslToken() != null) {
                    brokerReply = this.toCompletableFuture(context.sendRequest(followUp.header().data(), followUp.request().data()));
                }
            }
            catch (Exception e) {
                LOGGER.error("Error in handling authentication for channel: {}", context.channelDescriptor(), e);
                throw new RuntimeException("Error in handling authentication", e);
            }
            if (brokerReply != null) {
                return brokerReply.thenCompose(reply -> this.finishPendingExchange(reply, context));
            }
            // No token to send: the exchange is either finished or stuck.
            if (this.saslClientHandler.isComplete()) {
                return this.handleSuccessfulAuthentication(authBytes, context);
            }
            throw new RuntimeException("Broker SASL authentication failed due to an unexpected error while handling pending response.");
        }
        catch (RuntimeException e) {
            // Single cleanup point for every synchronous failure of this method.
            this.cleanupResourcesOnFailure(context);
            throw e;
        }
    }

    /**
     * Handles the broker's reply to the follow-up token sent by
     * {@link #handlePendingResponse(byte[], FilterContext)}. Runs inside the asynchronous chain, so
     * anything thrown here fails the returned stage; cleanup is performed before throwing.
     */
    private CompletionStage<ClusterAuthProcessor.AuthData> finishPendingExchange(SaslAuthenticateResponseData reply, FilterContext context) {
        byte[] responseAuthBytes = null;
        if (this.saslClientHandler.isComplete()) {
            try {
                responseAuthBytes = this.handleSaslAuthenticateResponse(reply);
            }
            catch (RuntimeException e) {
                // The broker rejected the follow-up token: release the client and record the
                // failure, as is done for a rejection on the first round.
                LOGGER.error("Error handling SaslAuthenticateResponse for channel: {}", context.channelDescriptor(), e);
                this.cleanupResourcesOnFailure(context);
                throw e;
            }
        } else {
            LOGGER.error("SaslClient is not complete after sending token for channel: {}", context.channelDescriptor());
        }
        if (responseAuthBytes != null) {
            return this.handleSuccessfulAuthentication(responseAuthBytes, context);
        }
        LOGGER.error("SASL authenticate request failed for channel: {} as response auth bytes was null", context.channelDescriptor());
        this.cleanupResourcesOnFailure(context);
        throw new RuntimeException("Broker SASL authentication failed due to null auth bytes response while handling pending response.");
    }

    /**
     * Disposes the SASL client, records a success metric and wraps the auth bytes.
     *
     * @param authBytes auth bytes from the final broker response
     * @param context   filter context of the connection
     * @return an already-completed stage holding the {@link ClusterAuthProcessor.AuthData}
     */
    CompletionStage<ClusterAuthProcessor.AuthData> handleSuccessfulAuthentication(byte[] authBytes, FilterContext context) {
        this.saslClientHandler.dispose(context);
        this.recorder.incrementClusterAuthSuccess(context.getVirtualClusterName());
        return CompletableFuture.completedStage(new ClusterAuthProcessor.AuthData(authBytes));
    }

    /** Disposes the SASL client and increments the cluster-auth failure metric for this virtual cluster. */
    private void cleanupResourcesOnFailure(FilterContext context) {
        this.saslClientHandler.dispose(context);
        this.recorder.incrementClusterAuthFailure(context.getVirtualClusterName());
    }

    /**
     * Remembers the first exchanged principal and rejects any later exchange that resolves to a
     * different one.
     *
     * @throws SaslAuthenticationException if {@code newPrincipal} differs from the remembered one
     */
    private void validatePrincipalConsistency(String newPrincipal, FilterContext context) {
        if (StringUtils.isBlank(this.originalExchangedClientId)) {
            this.originalExchangedClientId = newPrincipal;
            return;
        }
        if (StringUtils.equals(this.originalExchangedClientId, newPrincipal)) {
            return;
        }
        LOGGER.error("Principal mismatch during re-authentication for Original principal: {}, Current principal: {} on channel: {}",
                this.originalExchangedClientId, newPrincipal, context.channelDescriptor());
        String errorMessage = String.format("Principal mismatch during re-authentication for Original principal: %s, Current principal: %s", this.originalExchangedClientId, newPrincipal);
        throw new SaslAuthenticationException(errorMessage);
    }

    /**
     * Asks the SASL client for the next token and wraps it in a {@code SaslAuthenticate} request.
     *
     * @param authBytes challenge bytes to evaluate (empty for the initial request)
     * @param isInitial whether this is the first token of the exchange
     * @param context   filter context, used for logging only
     * @return the request to send, or a {@code SaslRequest} whose components are all {@code null}
     *         when the client produced no token
     * @throws IllegalStateException wrapping any failure while creating the token or the request
     */
    SaslRequest createSaslRequest(byte[] authBytes, boolean isInitial, FilterContext context) {
        try {
            byte[] saslToken = this.saslClientHandler.createSaslToken(authBytes, isInitial);
            if (saslToken != null) {
                short version = ApiKeys.SASL_AUTHENTICATE.latestVersion();
                SaslAuthenticateRequestData requestData = new SaslAuthenticateRequestData().setAuthBytes(saslToken);
                SaslAuthenticateRequest request = new SaslAuthenticateRequest.Builder(requestData).build(version);
                RequestHeader header = this.createRequestHeader(ApiKeys.SASL_AUTHENTICATE, version, GATEWAY_CLIENT_ID);
                return new SaslRequest(header, request, saslToken);
            }
        }
        catch (Exception e) {
            throw new IllegalStateException(e.getMessage(), e);
        }
        LOGGER.trace("Creating SASL Request was successful as no responses were pending for channel: {} ", context.channelDescriptor());
        return new SaslRequest(null, null, null);
    }

    /**
     * Extracts the auth bytes from a {@code SaslAuthenticate} response.
     *
     * @param responseEvent the broker response
     * @return a copy of the response's auth bytes
     * @throws org.apache.kafka.common.errors.ApiException the exception mapped from the response's
     *         error code (carrying the response's error message when present) if the code is not
     *         {@code NONE}
     */
    byte[] handleSaslAuthenticateResponse(SaslAuthenticateResponseData responseEvent) {
        Errors error = Errors.forCode(responseEvent.errorCode());
        if (error != Errors.NONE) {
            String errMsg = responseEvent.errorMessage();
            throw errMsg == null ? error.exception() : error.exception(errMsg);
        }
        return Utils.copyArray(responseEvent.authBytes());
    }

    /** Builds a request header for {@code apiKey} at {@code version} carrying {@code clientId}. */
    private RequestHeader createRequestHeader(ApiKeys apiKey, short version, String clientId) {
        RequestHeaderData headerData = new RequestHeaderData()
                .setRequestApiKey(apiKey.id)
                .setRequestApiVersion(version)
                .setClientId(clientId);
        return new RequestHeader(headerData, apiKey.requestHeaderVersion(version));
    }

    /**
     * @return {@code true} while the session reports that the cluster connection is not yet
     *         established; {@code apiKey} and {@code apiVersion} are not consulted
     */
    @Override
    public boolean shouldInterceptForGatewayAuth(ApiKeys apiKey, short apiVersion) {
        return !this.session.isClusterConnectionEstablished();
    }

    /**
     * @return whatever {@link Session#shouldGatewayReauthenticate()} reports; {@code apiKey} and
     *         {@code apiVersion} are not consulted
     */
    @Override
    public boolean shouldInterceptForGatewayReauth(ApiKeys apiKey, short apiVersion) {
        return this.session.shouldGatewayReauthenticate();
    }

    /**
     * Mirrors the outcome of an arbitrary {@link CompletionStage} into a new
     * {@link CompletableFuture}, without calling {@link CompletionStage#toCompletableFuture()}.
     */
    private <T> CompletableFuture<T> toCompletableFuture(CompletionStage<T> stage) {
        CompletableFuture<T> future = new CompletableFuture<>();
        stage.whenComplete((result, throwable) -> {
            if (throwable != null) {
                future.completeExceptionally(throwable);
            } else {
                future.complete(result);
            }
        });
        return future;
    }

    /**
     * A {@code SaslAuthenticate} request ready to send. All three components are {@code null}
     * when the SASL client produced no token.
     */
    private record SaslRequest(RequestHeader header, SaslAuthenticateRequest request, byte[] saslToken) {
    }
}

