/*
 * Decompiled with CFR 0.152.
 */
package io.confluent.gateway.filter.authswap.client;

import io.confluent.gateway.filter.authswap.Session;
import io.confluent.gateway.filter.authswap.client.AbstractClientAuthProcessor;
import io.confluent.gateway.filter.authswap.client.AuthResult;
import io.confluent.gateway.filter.authswap.client.FailureResponseBuilder;
import io.confluent.gateway.filter.authswap.client.ResponseBuilder;
import io.confluent.gateway.filter.authswap.client.sasl.SaslServerHandler;
import io.confluent.gateway.filter.authswap.config.AuthSwapFilterConfig;
import io.confluent.gateway.filter.authswap.metrics.AuthMetricsRecorder;
import io.kroxylicious.proxy.filter.FilterContext;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import org.apache.kafka.common.message.RequestHeaderData;
import org.apache.kafka.common.message.SaslAuthenticateRequestData;
import org.apache.kafka.common.message.SaslAuthenticateResponseData;
import org.apache.kafka.common.message.SaslHandshakeRequestData;
import org.apache.kafka.common.protocol.ApiKeys;
import org.apache.kafka.common.protocol.ApiMessage;
import org.apache.kafka.common.protocol.Errors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class SaslClientAuthProcessor
extends AbstractClientAuthProcessor {
    private static final Logger LOGGER = LoggerFactory.getLogger(SaslClientAuthProcessor.class);
    private final SaslServerHandler saslServerHandler;
    private final AuthSwapFilterConfig authSwapConfig;
    private final FailureResponseBuilder<SaslAuthenticateResponseData> failureResponseBuilder = errorMessage -> {
        SaslAuthenticateResponseData failureResponse = new SaslAuthenticateResponseData().setErrorCode(Errors.SASL_AUTHENTICATION_FAILED.code());
        if (errorMessage != null) {
            failureResponse.setErrorMessage(errorMessage);
        }
        return failureResponse;
    };
    private State saslState = State.PENDING;

    public SaslClientAuthProcessor(AuthSwapFilterConfig config, Session session, AuthMetricsRecorder recorder) {
        super(session, recorder);
        this.authSwapConfig = config;
        this.saslServerHandler = new SaslServerHandler(config, recorder);
    }

    @Override
    protected boolean shouldInterceptForAuthentication(ApiKeys apiKey, short apiVersion) {
        return apiKey == ApiKeys.SASL_HANDSHAKE || apiKey == ApiKeys.SASL_AUTHENTICATE || this.saslState != State.AUTHENTICATED;
    }

    @Override
    protected boolean shouldInterceptForReAuthentication(ApiKeys apiKey, short apiVersion) {
        return this.session.shouldClientReauthenticate();
    }

    private CompletionStage<Object> onSaslAuthenticateRequest(RequestHeaderData header, SaslAuthenticateRequestData request, FilterContext context) {
        SaslServerHandler.ClientAuthenticationResult authResult = this.saslServerHandler.validateClientCredentials(request, context);
        if (authResult.success()) {
            this.recorder.incrementClientAuthSuccess(context.getVirtualClusterName());
            this.setIncomingClientId(authResult.clientId());
            this.session.updateClientToGatewaySession(authResult.sessionLifetimeMs());
            LOGGER.debug("SASL authenticate request completed for channel: {} with session lifetime: {} ms", (Object)context.channelDescriptor(), (Object)authResult.sessionLifetimeMs());
            this.saslState = State.AUTHENTICATED;
            return CompletableFuture.completedStage(authResult);
        }
        this.recorder.incrementClientAuthFailure(context.getVirtualClusterName());
        return CompletableFuture.completedStage(authResult);
    }

    @Override
    public CompletionStage<AuthResult> authenticate(ApiKeys apiKey, RequestHeaderData header, ApiMessage request, FilterContext context) {
        return switch (this.saslState) {
            default -> throw new IncompatibleClassChangeError();
            case State.PENDING -> this.handleHandshakeStateForAuth(apiKey, header, request, context);
            case State.HANDSHAKE_DONE -> this.handleAuthenticateStateForAuth(apiKey, header, request, context);
            case State.AUTHENTICATED -> this.handleCompleteStateForAuth(apiKey, header, request, context);
        };
    }

    private CompletionStage<AuthResult> handleHandshakeStateForAuth(ApiKeys apiKey, RequestHeaderData header, ApiMessage request, FilterContext context) {
        if (apiKey != ApiKeys.SASL_HANDSHAKE) {
            this.recorder.incrementClientAuthFailure(context.getVirtualClusterName());
            return CompletableFuture.completedStage(AuthResult.failure(new IllegalStateException("Unexpected request: " + String.valueOf(apiKey) + " while expecting SASL_HANDSHAKE"), Errors.ILLEGAL_SASL_STATE.code()));
        }
        this.saslState = State.HANDSHAKE_DONE;
        LOGGER.debug("SASL_HANDSHAKE received in HANDSHAKE_REQUEST state, transitioning to AUTHENTICATE for channel: {}", (Object)context.channelDescriptor());
        try {
            SaslHandshakeRequestData handshakeRequest = (SaslHandshakeRequestData)request;
            return this.saslServerHandler.handleSaslHandshakeRequest(header, handshakeRequest, context).thenApply(filterResult -> AuthResult.continueWithResponse(filterResult)).exceptionally(throwable -> AuthResult.failure(throwable, Errors.SASL_AUTHENTICATION_FAILED.code()));
        }
        catch (Exception e) {
            this.recorder.incrementClientAuthFailure(context.getVirtualClusterName());
            return CompletableFuture.completedStage(AuthResult.failure(e, Errors.SASL_AUTHENTICATION_FAILED.code()));
        }
    }

    private CompletionStage<AuthResult> handleAuthenticateStateForAuth(ApiKeys apiKey, RequestHeaderData header, ApiMessage request, FilterContext context) {
        if (apiKey == ApiKeys.SASL_AUTHENTICATE) {
            try {
                SaslAuthenticateRequestData authRequest = (SaslAuthenticateRequestData)request;
                return this.onSaslAuthenticateRequest(header, authRequest, context).thenApply(authResultObj -> {
                    SaslServerHandler.ClientAuthenticationResult authResult = (SaslServerHandler.ClientAuthenticationResult)authResultObj;
                    if (authResult.success()) {
                        this.saslState = State.AUTHENTICATED;
                        ResponseBuilder<SaslAuthenticateResponseData> responseBuilder = () -> {
                            long currentSessionLifetime = this.session.calculateRemainingSessionTimeMs(this.authSwapConfig.config().clientAuth().connectionsMaxReAuthMs(), context);
                            return new SaslAuthenticateResponseData().setAuthBytes(authResult.authBytes()).setErrorCode(Errors.NONE.code()).setSessionLifetimeMs(currentSessionLifetime);
                        };
                        return AuthResult.successWithResponseBuilder(this.getIncomingClientId(), responseBuilder, this.failureResponseBuilder);
                    }
                    String errorMessage = null;
                    if (authResult.exception() != null && authResult.exception().getMessage() != null) {
                        errorMessage = authResult.exception().getMessage();
                    }
                    return AuthResult.failure(new Throwable(errorMessage), Errors.SASL_AUTHENTICATION_FAILED.code());
                });
            }
            catch (Exception e) {
                return CompletableFuture.completedStage(AuthResult.failure(e, Errors.SASL_AUTHENTICATION_FAILED.code()));
            }
        }
        this.recorder.incrementClientAuthFailure(context.getVirtualClusterName());
        return CompletableFuture.completedStage(AuthResult.failure(new IllegalStateException("Expected SASL_AUTHENTICATE but got " + String.valueOf(apiKey)), Errors.ILLEGAL_SASL_STATE.code()));
    }

    private CompletionStage<AuthResult> handleCompleteStateForAuth(ApiKeys apiKey, RequestHeaderData header, ApiMessage request, FilterContext context) {
        if (apiKey == ApiKeys.SASL_HANDSHAKE) {
            LOGGER.debug("SASL_HANDSHAKE received in COMPLETE state, starting client re-authentication, for channel: {}", (Object)context.channelDescriptor());
            this.saslState = State.PENDING;
            return this.handleHandshakeStateForAuth(apiKey, header, request, context);
        }
        LOGGER.error("Unexpected request {} in COMPLETE state for channel: {}, only SASL_HANDSHAKE is expected to trigger re-authentication, closing channel", (Object)apiKey, (Object)context.channelDescriptor());
        return CompletableFuture.completedStage(AuthResult.failure(new IllegalStateException("Unexpected request " + String.valueOf(apiKey) + " in COMPLETE state"), Errors.UNKNOWN_SERVER_ERROR.code()));
    }

    private static enum State {
        PENDING,
        HANDSHAKE_DONE,
        AUTHENTICATED;

    }
}

