/*
 * Decompiled with CFR 0.152.
 */
package io.confluent.gateway.filter.authswap.client.sasl;

import io.confluent.gateway.filter.authswap.config.AuthSwapFilterConfig;
import io.confluent.gateway.filter.authswap.metrics.AuthMetricsRecorder;
import io.kroxylicious.proxy.filter.FilterContext;
import io.kroxylicious.proxy.filter.RequestFilterResult;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.concurrent.CompletionStage;
import javax.security.auth.callback.CallbackHandler;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;
import org.apache.commons.lang3.StringUtils;
import org.apache.kafka.common.errors.SaslAuthenticationException;
import org.apache.kafka.common.message.RequestHeaderData;
import org.apache.kafka.common.message.SaslAuthenticateRequestData;
import org.apache.kafka.common.message.SaslHandshakeRequestData;
import org.apache.kafka.common.message.SaslHandshakeResponseData;
import org.apache.kafka.common.protocol.ApiMessage;
import org.apache.kafka.common.protocol.Errors;
import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Server side of the client-facing SASL exchange performed by the auth-swap filter.
 *
 * <p>A {@code SaslHandshake} is accepted only for the single mechanism configured under
 * {@code clientAuth().sasl().mechanism()}. On success a {@link SaslServer} is created and held
 * until the next {@code SaslAuthenticate} request, which is evaluated in exactly one round; the
 * server is consumed (and disposed) by that request whatever its outcome.
 *
 * <p>NOTE(review): instances keep mutable state ({@code saslServer}) without synchronization;
 * presumably one instance is used per client channel from a single thread — confirm.
 */
public class SaslServerHandler {
    private static final Logger LOGGER = LoggerFactory.getLogger(SaslServerHandler.class);
    private static final String GATEWAY_SASL_PROTOCOL_NAME = "gateway-kafka";
    /**
     * Negotiated-property key read for the credential/session lifetime; same value as Kafka's
     * {@code SaslInternalConfigs.CREDENTIAL_LIFETIME_MS_SASL_NEGOTIATED_PROPERTY_KEY}.
     */
    private static final String CREDENTIAL_LIFETIME_MS_PROPERTY = "CREDENTIAL.LIFETIME.MS";
    private final AuthenticateCallbackHandler validatorCallbackHandler;
    private final String serverAuthMechanism;
    private final AuthMetricsRecorder recorder;
    /** Server created by a successful handshake; cleared when an authenticate request consumes it. */
    private SaslServer saslServer;

    /**
     * @param authSwapConfig source of the credential-validating callback handler and the single
     *                       SASL mechanism accepted from clients
     * @param recorder       metrics sink; handshake rejections increment its client-auth-failure counter
     */
    public SaslServerHandler(AuthSwapFilterConfig authSwapConfig, AuthMetricsRecorder recorder) {
        this.validatorCallbackHandler = authSwapConfig.validatorCallbackHandler();
        this.serverAuthMechanism = authSwapConfig.config().clientAuth().sasl().mechanism();
        this.recorder = recorder;
    }

    /**
     * Answers a client {@code SaslHandshake} request directly (the request is short-circuited).
     *
     * <p>Responses:
     * <ul>
     *   <li>{@code ILLEGAL_SASL_STATE} + close, if a handshake already created a server that has not
     *       yet been consumed by an authenticate request;</li>
     *   <li>{@code UNSUPPORTED_SASL_MECHANISM} + close, listing the configured mechanism, if the
     *       requested mechanism differs;</li>
     *   <li>{@code UNKNOWN_SERVER_ERROR} + close, if the {@link SaslServer} cannot be created
     *       (including when no factory supports the mechanism);</li>
     *   <li>{@code NONE}, echoing the requested mechanism, otherwise.</li>
     * </ul>
     * Every rejection increments the client-auth-failure metric for the virtual cluster.
     *
     * @param header  request header (unused)
     * @param request the client's handshake request
     * @param context filter context used to build the short-circuit response
     * @return the completed filter result carrying the handshake response
     */
    public CompletionStage<RequestFilterResult> handleSaslHandshakeRequest(RequestHeaderData header, SaslHandshakeRequestData request, FilterContext context) {
        if (this.saslServer != null) {
            LOGGER.error("SASL handshake request failed due to a not null SASL server (ILLEGAL_SASL_STATE) for channel: {} with mechanism: {}", context.channelDescriptor(), request.mechanism());
            return this.rejectHandshake(context, new SaslHandshakeResponseData().setErrorCode(Errors.ILLEGAL_SASL_STATE.code()));
        }
        try {
            if (!StringUtils.equals(this.serverAuthMechanism, request.mechanism())) {
                LOGGER.error("SASL handshake request failed for channel: {} with mechanism: {} as the supported SASL mechanism is: {}", context.channelDescriptor(), request.mechanism(), this.serverAuthMechanism);
                return this.rejectHandshake(context, new SaslHandshakeResponseData()
                        .setErrorCode(Errors.UNSUPPORTED_SASL_MECHANISM.code())
                        .setMechanisms(List.of(this.serverAuthMechanism)));
            }
            SaslServer server = Sasl.createSaslServer(this.serverAuthMechanism, GATEWAY_SASL_PROTOCOL_NAME, null, null, this.validatorCallbackHandler);
            // Sasl.createSaslServer returns null (rather than throwing) when no registered factory
            // supports the mechanism; accepting the handshake then would only defer the failure.
            if (server == null) {
                throw new SaslException("No SaslServerFactory available for mechanism: " + this.serverAuthMechanism);
            }
            this.saslServer = server;
        }
        catch (Exception e) {
            // Pass the throwable as well so the stack trace of an unexpected failure is not lost.
            LOGGER.error("SASL handshake request failed for channel: {} with mechanism: {} due to error: {}", context.channelDescriptor(), request.mechanism(), e.getMessage(), e);
            return this.rejectHandshake(context, new SaslHandshakeResponseData().setErrorCode(Errors.UNKNOWN_SERVER_ERROR.code()));
        }
        LOGGER.debug("SASL handshake request completed for channel: {} with mechanism: {}", context.channelDescriptor(), request.mechanism());
        LOGGER.debug("Setting SASL state to AUTHENTICATING for channel: {}", context.channelDescriptor());
        return context.requestFilterResultBuilder()
                .shortCircuitResponse(new SaslHandshakeResponseData()
                        .setErrorCode(Errors.NONE.code())
                        .setMechanisms(List.of(request.mechanism())))
                .completed();
    }

    /**
     * Evaluates the client's {@code SaslAuthenticate} bytes against the server created by the
     * preceding handshake, in a single round.
     *
     * <p>The held server is cleared before evaluation and disposed afterwards, so a second
     * authenticate request without a new handshake yields {@code ILLEGAL_SASL_STATE}.
     *
     * @param request the client's authenticate request
     * @param context filter context (used for log correlation)
     * @return {@code success(clientId, sessionLifetimeMs, responseBytes)} when the exchange completes
     *         in one round; {@code failed(ILLEGAL_SASL_STATE, ...)} when no handshake preceded this
     *         request; {@code failed(SASL_AUTHENTICATION_FAILED, e)} for any other error
     */
    public ClientAuthenticationResult validateClientCredentials(SaslAuthenticateRequestData request, FilterContext context) {
        LOGGER.debug("Validating client credentials for channel: {}", context.channelDescriptor());
        SaslServer server = this.saslServer;
        if (server == null) {
            LOGGER.error("SASL authenticate request failed for channel: {} as the SASL Server is null", context.channelDescriptor());
            return ClientAuthenticationResult.failed(Errors.ILLEGAL_SASL_STATE, new SaslAuthenticationException("Unexpected SASL request"));
        }
        // The server is consumed by this request regardless of outcome.
        this.saslServer = null;
        try {
            byte[] responseBytes = evaluateSingleRound(server, request.authBytes());
            // Read the outcome BEFORE dispose(): SaslServer implementations may release their state on
            // dispose (Kafka's OAuthBearerSaslServer clears its token, after which these getters throw).
            String clientId = server.getAuthorizationID();
            Long sessionLifetimeMs = (Long) server.getNegotiatedProperty(CREDENTIAL_LIFETIME_MS_PROPERTY);
            if (sessionLifetimeMs == null) {
                LOGGER.debug("Successfully authenticated incoming Client ID: {} for channel: {}", clientId, context.channelDescriptor());
            } else {
                LOGGER.debug("Successfully authenticated incoming Client ID: {} for channel: {} with session lifetime: {}", clientId, context.channelDescriptor(), sessionLifetimeMs);
            }
            return ClientAuthenticationResult.success(clientId, sessionLifetimeMs, responseBytes);
        }
        catch (Exception e) {
            LOGGER.error("SASL authenticate request failed for channel: {} due to error: {}", context.channelDescriptor(), e.getMessage());
            return ClientAuthenticationResult.failed(Errors.SASL_AUTHENTICATION_FAILED, e);
        }
        finally {
            disposeQuietly(server, context);
        }
    }

    /**
     * Evaluates {@code authBytes} in a single round and always disposes {@code server} afterwards.
     *
     * <p>Because the server is disposed before this method returns, callers must not query it
     * (e.g. {@code getAuthorizationID()}) afterwards. A failure to dispose is logged, not thrown,
     * so it can never mask the evaluation outcome.
     *
     * @param server    server to evaluate against; disposed on return
     * @param authBytes client response bytes
     * @param context   filter context (used for log correlation)
     * @return the server's final response bytes (possibly {@code null} if the server returns none)
     * @throws SaslException               if the server rejects the response
     * @throws SaslAuthenticationException if the exchange did not complete in one round
     */
    byte[] authenticate(SaslServer server, byte[] authBytes, FilterContext context) throws SaslException {
        try {
            return evaluateSingleRound(server, authBytes);
        }
        finally {
            disposeQuietly(server, context);
        }
    }

    /**
     * Runs one {@code evaluateResponse} round and requires the exchange to be complete afterwards.
     * Does not dispose the server.
     */
    private static byte[] evaluateSingleRound(SaslServer server, byte[] authBytes) throws SaslException {
        byte[] challenge = server.evaluateResponse(authBytes);
        if (!server.isComplete()) {
            // evaluateResponse may legitimately return null; don't let that turn into an NPE that
            // hides the real reason the exchange failed.
            String detail = challenge == null ? "" : new String(challenge, StandardCharsets.UTF_8);
            throw new SaslAuthenticationException("SASL failed : " + detail);
        }
        return challenge;
    }

    /** Disposes {@code server}, logging (not propagating) a failure so it cannot mask another error. */
    private static void disposeQuietly(SaslServer server, FilterContext context) {
        try {
            server.dispose();
        }
        catch (SaslException e) {
            LOGGER.warn("Failed to dispose SASL server for channel: {}", context.channelDescriptor(), e);
        }
    }

    /** Records a client-auth failure and short-circuits {@code response}, closing the connection. */
    private CompletionStage<RequestFilterResult> rejectHandshake(FilterContext context, SaslHandshakeResponseData response) {
        this.recorder.incrementClientAuthFailure(context.getVirtualClusterName());
        return context.requestFilterResultBuilder().shortCircuitResponse(response).withCloseConnection().completed();
    }

    /**
     * Outcome of {@link #validateClientCredentials}.
     *
     * @param success              whether authentication succeeded
     * @param alreadyAuthenticated always {@code false} from the factories in this record
     * @param authBytes            server response bytes on success, {@code null} on failure
     * @param clientId             authorization ID reported by the SASL server on success
     * @param sessionLifetimeMs    negotiated credential lifetime, or {@code null} if not provided
     * @param error                Kafka error code on failure, {@code null} on success
     * @param exception            cause of the failure, {@code null} on success
     */
    public record ClientAuthenticationResult(boolean success, boolean alreadyAuthenticated, byte[] authBytes, String clientId, Long sessionLifetimeMs, Errors error, Exception exception) {
        /** Builds a successful result. */
        public static ClientAuthenticationResult success(String clientId, Long sessionLifetimeMs, byte[] authBytes) {
            return new ClientAuthenticationResult(true, false, authBytes, clientId, sessionLifetimeMs, null, null);
        }

        /** Builds a failed result carrying the Kafka error code and its cause. */
        public static ClientAuthenticationResult failed(Errors error, Exception exception) {
            return new ClientAuthenticationResult(false, false, null, null, null, error, exception);
        }
    }
}

