/*
 * Decompiled with CFR 0.152.
 */
package io.kroxylicious.proxy.internal;

import edu.umd.cs.findbugs.annotations.Nullable;
import io.kroxylicious.proxy.frame.BareSaslRequest;
import io.kroxylicious.proxy.frame.BareSaslResponse;
import io.kroxylicious.proxy.frame.DecodedRequestFrame;
import io.kroxylicious.proxy.frame.DecodedResponseFrame;
import io.kroxylicious.proxy.internal.AuthenticationEvent;
import io.kroxylicious.proxy.internal.KafkaProxyExceptionMapper;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import javax.security.auth.callback.CallbackHandler;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;
import org.apache.kafka.common.errors.IllegalSaslStateException;
import org.apache.kafka.common.errors.InvalidRequestException;
import org.apache.kafka.common.errors.SaslAuthenticationException;
import org.apache.kafka.common.errors.UnsupportedSaslMechanismException;
import org.apache.kafka.common.message.ResponseHeaderData;
import org.apache.kafka.common.message.SaslAuthenticateRequestData;
import org.apache.kafka.common.message.SaslAuthenticateResponseData;
import org.apache.kafka.common.message.SaslHandshakeRequestData;
import org.apache.kafka.common.message.SaslHandshakeResponseData;
import org.apache.kafka.common.protocol.ApiMessage;
import org.apache.kafka.common.protocol.Errors;
import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
import org.apache.kafka.common.security.plain.internals.PlainSaslServerProvider;
import org.apache.kafka.common.security.scram.internals.ScramMechanism;
import org.apache.kafka.common.security.scram.internals.ScramSaslServerProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Netty inbound handler that drives Kafka SASL authentication for a single client connection
 * before any other request is allowed through.
 *
 * <p>The conversation is tracked with {@link State}:</p>
 * <ul>
 *     <li>{@code ApiVersions} requests are always forwarded downstream; before authentication has
 *     succeeded they also move the handler to {@link State#API_VERSIONS}.</li>
 *     <li>{@code SaslHandshake} v0 selects the unframed flow: SASL tokens then arrive as
 *     {@link BareSaslRequest}s and are answered with {@link BareSaslResponse}s.</li>
 *     <li>{@code SaslHandshake} v1+ selects the framed flow: SASL tokens then arrive inside
 *     {@code SaslAuthenticate} requests.</li>
 *     <li>When the {@link SaslServer} reports completion an {@link AuthenticationEvent} is fired as
 *     a user event and every later message is passed through unchanged.</li>
 *     <li>Any other request received before authentication has succeeded is answered with an
 *     error response derived from an {@link IllegalSaslStateException} ("Not authenticated").</li>
 * </ul>
 *
 * <p>Instances hold per-connection mutable state and perform no synchronization.</p>
 */
public class KafkaAuthnHandler extends ChannelInboundHandlerAdapter {

    private static final Logger LOG = LoggerFactory.getLogger(KafkaAuthnHandler.class);

    /** Shared exception used only to build the error response for requests sent before authentication. */
    private static final IllegalSaslStateException NOT_AUTHENTICATED_EXCEPTION = new IllegalSaslStateException("Not authenticated");

    /** Sent to the client in place of a {@code null} challenge from {@link SaslServer#evaluateResponse}. */
    private static final byte[] NO_AUTH_BYTES = new byte[0];

    static {
        // Kafka's SASL provider classes; initialising them is what makes the PLAIN and SCRAM
        // SaslServer implementations discoverable by Sasl.createSaslServer in
        // onSaslHandshakeRequest (NOTE(review): behaviour of Kafka internals, not visible here).
        PlainSaslServerProvider.initialize();
        ScramSaslServerProvider.initialize();
    }

    /** Mechanism names advertised in SaslHandshake responses (the keys of {@link #mechanismHandlers}). */
    private final List<String> enabledMechanisms;

    /** Server for the mechanism chosen by the last accepted SaslHandshake; {@code null} before that. */
    @Nullable
    SaslServer saslServer;

    /** Callback handler for each enabled mechanism, keyed by SASL mechanism name. */
    private final Map<String, AuthenticateCallbackHandler> mechanismHandlers;

    /** Most recent state of the authentication conversation. */
    State lastSeen;

    /**
     * Creates a handler in the {@link State#START} state.
     *
     * @param ch the channel served by this handler (used only for logging)
     * @param mechanismHandlers the callback handler for each mechanism the proxy offers to clients
     */
    public KafkaAuthnHandler(Channel ch, Map<SaslMechanism, AuthenticateCallbackHandler> mechanismHandlers) {
        this(ch, State.START, mechanismHandlers);
    }

    /**
     * Creates a handler starting in the given state.
     *
     * @param ch the channel served by this handler (used only for logging)
     * @param init the initial state
     * @param mechanismHandlers the callback handler for each mechanism the proxy offers to clients
     */
    KafkaAuthnHandler(Channel ch, State init, Map<SaslMechanism, AuthenticateCallbackHandler> mechanismHandlers) {
        this.lastSeen = init;
        LOG.debug("{}: Initial state {}", ch, this.lastSeen);
        this.mechanismHandlers = mechanismHandlers.entrySet().stream()
                .collect(Collectors.toMap(e -> e.getKey().mechanismName(), Map.Entry::getValue));
        this.enabledMechanisms = List.copyOf(this.mechanismHandlers.keySet());
    }

    /**
     * Builds the exception for a rejected transition and moves the handler to {@link State#FAILED}.
     *
     * @param next the state that was requested
     * @return the exception for the caller to throw
     */
    private InvalidRequestException illegalTransition(State next) {
        InvalidRequestException e = new InvalidRequestException("Illegal state transition from " + this.lastSeen + " to " + next);
        this.lastSeen = State.FAILED;
        return e;
    }

    /**
     * Moves to {@code next} if the transition from the current state is permitted.
     *
     * <p>{@link State#FAILED} is reachable from every state; every other state can only be
     * reached from the predecessors listed below, so FAILED is terminal.</p>
     *
     * @param channel the channel, for logging
     * @param next the requested state
     * @throws InvalidRequestException if the transition is not permitted (the handler is then FAILED)
     */
    private void doTransition(Channel channel, State next) {
        State previous = this.lastSeen;
        boolean allowed;
        switch (next) {
            case API_VERSIONS:
                allowed = previous == State.START;
                break;
            case SASL_HANDSHAKE_v0:
            case SASL_HANDSHAKE_v1_PLUS:
                allowed = previous == State.START || previous == State.API_VERSIONS;
                break;
            case UNFRAMED_SASL_AUTHENTICATE:
                allowed = previous == State.START
                        || previous == State.SASL_HANDSHAKE_v0
                        || previous == State.UNFRAMED_SASL_AUTHENTICATE;
                break;
            case FRAMED_SASL_AUTHENTICATE:
                allowed = previous == State.SASL_HANDSHAKE_v1_PLUS || previous == State.FRAMED_SASL_AUTHENTICATE;
                break;
            case AUTHN_SUCCESS:
                allowed = previous == State.FRAMED_SASL_AUTHENTICATE || previous == State.UNFRAMED_SASL_AUTHENTICATE;
                break;
            case FAILED:
                allowed = true;
                break;
            default:
                // START (or any future state) is never a valid target.
                allowed = false;
                break;
        }
        if (!allowed) {
            throw illegalTransition(next);
        }
        LOG.debug("{}: Transition from {} to {}", channel, previous, next);
        this.lastSeen = next;
    }

    /**
     * Dispatches bare SASL tokens and decoded request frames to the authentication logic.
     * After authentication has succeeded, any other message is forwarded unchanged.
     *
     * @throws IllegalStateException if an unrecognised message arrives before authentication succeeded
     */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        if (msg instanceof BareSaslRequest) {
            handleBareRequest(ctx, (BareSaslRequest) msg);
        }
        else if (msg instanceof DecodedRequestFrame) {
            handleFramedRequest(ctx, (DecodedRequestFrame<?>) msg);
        }
        else if (this.lastSeen == State.AUTHN_SUCCESS) {
            ctx.fireChannelRead(msg);
        }
        else {
            throw new IllegalStateException("Unexpected message " + msg.getClass());
        }
    }

    /**
     * Handles a decoded Kafka request.
     *
     * <p>ApiVersions requests are forwarded downstream. SaslHandshake and SaslAuthenticate
     * requests are answered here. Any other request is forwarded once authenticated; before
     * that it is answered with a "Not authenticated" error response.</p>
     */
    @SuppressWarnings("unchecked") // the API key identifies the body type of the frame
    private void handleFramedRequest(ChannelHandlerContext ctx, DecodedRequestFrame<?> frame) throws SaslException {
        switch (frame.apiKey()) {
            case API_VERSIONS:
                if (this.lastSeen != State.AUTHN_SUCCESS) {
                    doTransition(ctx.channel(), State.API_VERSIONS);
                }
                ctx.fireChannelRead(frame);
                return;
            case SASL_HANDSHAKE:
                // v0 means the tokens that follow are unframed; v1+ means they come in SaslAuthenticate.
                doTransition(ctx.channel(), frame.apiVersion() == 0 ? State.SASL_HANDSHAKE_v0 : State.SASL_HANDSHAKE_v1_PLUS);
                onSaslHandshakeRequest(ctx, (DecodedRequestFrame<SaslHandshakeRequestData>) frame);
                return;
            case SASL_AUTHENTICATE:
                doTransition(ctx.channel(), State.FRAMED_SASL_AUTHENTICATE);
                onSaslAuthenticateRequest(ctx, (DecodedRequestFrame<SaslAuthenticateRequestData>) frame);
                return;
            default:
                break;
        }
        if (this.lastSeen == State.AUTHN_SUCCESS) {
            ctx.fireChannelRead(frame);
        }
        else {
            writeFramedResponse(ctx, frame, KafkaProxyExceptionMapper.errorResponseMessage(frame, NOT_AUTHENTICATED_EXCEPTION));
        }
    }

    /**
     * Handles an unframed SASL token, which is only acceptable after a v0 SaslHandshake or a
     * previous unframed token.
     *
     * @throws InvalidRequestException if no v0 handshake preceded the token (the handler is then FAILED)
     * @throws SaslAuthenticationException if the SASL server rejects the token
     */
    private void handleBareRequest(ChannelHandlerContext ctx, BareSaslRequest msg) {
        if (this.lastSeen != State.SASL_HANDSHAKE_v0 && this.lastSeen != State.UNFRAMED_SASL_AUTHENTICATE) {
            this.lastSeen = State.FAILED;
            throw new InvalidRequestException("Bare SASL bytes without GSSAPI support or prior SaslHandshake");
        }
        doTransition(ctx.channel(), State.UNFRAMED_SASL_AUTHENTICATE);
        writeBareResponse(ctx, doEvaluateResponse(ctx, msg.bytes()));
    }

    /** Writes an unframed SASL challenge to the client. */
    private void writeBareResponse(ChannelHandlerContext ctx, byte[] bytes) {
        ctx.writeAndFlush(new BareSaslResponse(bytes));
    }

    /**
     * Answers a SaslHandshake request. For a supported mechanism a {@link SaslServer} is created;
     * for an unsupported one the handler moves to {@link State#FAILED}. The response always lists
     * the enabled mechanisms. Afterwards another read is requested from the channel.
     *
     * @throws SaslException if the SASL server cannot be created
     * @throws IllegalStateException if no provider implements an enabled mechanism
     */
    private void onSaslHandshakeRequest(ChannelHandlerContext ctx, DecodedRequestFrame<SaslHandshakeRequestData> data) throws SaslException {
        String mechanism = data.body().mechanism();
        Errors error;
        if (this.lastSeen == State.AUTHN_SUCCESS) {
            // Defensive: doTransition currently rejects a handshake after success before we get here.
            error = Errors.ILLEGAL_SASL_STATE;
        }
        else if (this.enabledMechanisms.contains(mechanism)) {
            AuthenticateCallbackHandler cbh = this.mechanismHandlers.get(mechanism);
            this.saslServer = Sasl.createSaslServer(mechanism, "kafka", null, null, cbh);
            if (this.saslServer == null) {
                throw new IllegalStateException("SASL mechanism had no providers: " + mechanism);
            }
            error = Errors.NONE;
        }
        else {
            error = Errors.UNSUPPORTED_SASL_MECHANISM;
            // A second handshake is not permitted, so the client cannot recover. Fail now rather than
            // leave saslServer null for a later SaslAuthenticate or bare token to trip over.
            doTransition(ctx.channel(), State.FAILED);
        }
        SaslHandshakeResponseData body = new SaslHandshakeResponseData()
                .setMechanisms(this.enabledMechanisms)
                .setErrorCode(error.code());
        writeFramedResponse(ctx, data, body);
        ctx.channel().read();
    }

    /**
     * Answers a SaslAuthenticate request with the server's challenge. If the token is rejected,
     * it answers with {@code SASL_AUTHENTICATION_FAILED} and the failure message. Afterwards
     * another read is requested from the channel.
     */
    private void onSaslAuthenticateRequest(ChannelHandlerContext ctx, DecodedRequestFrame<SaslAuthenticateRequestData> data) {
        byte[] bytes = NO_AUTH_BYTES;
        Errors error;
        String errorMessage;
        try {
            byte[] challenge = doEvaluateResponse(ctx, data.body().authBytes());
            // SaslServer.evaluateResponse may return null when there is no further challenge,
            // but the authBytes field of the response must not be null.
            bytes = challenge == null ? NO_AUTH_BYTES : challenge;
            error = Errors.NONE;
            errorMessage = null;
        }
        catch (SaslAuthenticationException e) {
            error = Errors.SASL_AUTHENTICATION_FAILED;
            errorMessage = e.getMessage();
        }
        SaslAuthenticateResponseData body = new SaslAuthenticateResponseData()
                .setErrorCode(error.code())
                .setErrorMessage(errorMessage)
                .setAuthBytes(bytes);
        writeFramedResponse(ctx, data, body);
        ctx.channel().read();
    }

    /** Writes a response frame that echoes the request's API version and correlation id. */
    private static void writeFramedResponse(ChannelHandlerContext ctx, DecodedRequestFrame<?> data, ApiMessage body) {
        ctx.writeAndFlush(new DecodedResponseFrame<ApiMessage>(data.apiVersion(), data.correlationId(),
                new ResponseHeaderData().setCorrelationId(data.correlationId()), body));
    }

    /**
     * Feeds a client SASL token to the current {@link SaslServer}.
     *
     * <p>On completion the handler moves to {@link State#AUTHN_SUCCESS}, fires an
     * {@link AuthenticationEvent} and disposes the server. On failure it moves to
     * {@link State#FAILED} and disposes the server. Failures of {@code dispose()} itself are
     * logged and do not change the outcome.</p>
     *
     * @return the server's challenge, possibly {@code null} (see {@link SaslServer#evaluateResponse})
     * @throws SaslAuthenticationException if the server rejects the token; any other failure is
     *         wrapped in one, with the original as its cause
     */
    private byte[] doEvaluateResponse(ChannelHandlerContext ctx, byte[] authBytes) {
        SaslServer server = Objects.requireNonNull(this.saslServer, "No SaslServer: SaslHandshake has not selected a mechanism");
        byte[] bytes;
        try {
            bytes = server.evaluateResponse(authBytes);
        }
        catch (SaslAuthenticationException e) {
            onAuthenticationFailure(ctx, server, e);
            throw e;
        }
        catch (Exception e) {
            onAuthenticationFailure(ctx, server, e);
            throw new SaslAuthenticationException(e.getMessage(), e);
        }
        if (server.isComplete()) {
            try {
                String authorizationId = server.getAuthorizationID();
                Map<String, Object> properties = SaslMechanism.fromMechanismName(server.getMechanismName()).negotiatedProperties(server);
                doTransition(ctx.channel(), State.AUTHN_SUCCESS);
                LOG.debug("{}: Authentication successful, authorizationId={}, negotiatedProperties={}", ctx.channel(), authorizationId, properties);
                ctx.fireUserEventTriggered(new AuthenticationEvent(authorizationId, properties));
            }
            finally {
                disposeQuietly(ctx, server);
            }
        }
        return bytes;
    }

    /** Records a failed evaluation: logs it, moves to FAILED and releases the server. */
    private void onAuthenticationFailure(ChannelHandlerContext ctx, SaslServer server, Exception cause) {
        LOG.debug("{}: Authentication failed", ctx.channel(), cause);
        doTransition(ctx.channel(), State.FAILED);
        disposeQuietly(ctx, server);
    }

    /**
     * Disposes the server, logging rather than propagating a {@link SaslException}. This prevents
     * a cleanup failure from masking, or contradicting, the authentication outcome already reached.
     */
    private static void disposeQuietly(ChannelHandlerContext ctx, SaslServer server) {
        try {
            server.dispose();
        }
        catch (SaslException e) {
            LOG.warn("{}: Failed to dispose SaslServer", ctx.channel(), e);
        }
    }

    /** States of the authentication conversation; see {@link #doTransition} for the permitted moves. */
    enum State {
        /** Nothing received yet. */
        START,
        /** An ApiVersions request arrived before authentication. */
        API_VERSIONS,
        /** A v0 SaslHandshake was accepted; tokens follow unframed. */
        SASL_HANDSHAKE_v0,
        /** A v1+ SaslHandshake was accepted; tokens follow in SaslAuthenticate requests. */
        SASL_HANDSHAKE_v1_PLUS,
        /** At least one unframed SASL token has been processed. */
        UNFRAMED_SASL_AUTHENTICATE,
        /** At least one SaslAuthenticate request has been processed. */
        FRAMED_SASL_AUTHENTICATE,
        /** Terminal failure; no further transition except to FAILED is permitted. */
        FAILED,
        /** Authentication completed; further messages are passed through. */
        AUTHN_SUCCESS
    }

    /** SASL mechanisms the proxy can offer to clients. */
    public enum SaslMechanism {
        /** SASL/PLAIN; negotiates no extra properties. */
        PLAIN("PLAIN", null) {
            @Override
            public Map<String, Object> negotiatedProperties(SaslServer saslServer) {
                return Map.of();
            }
        },
        /** SCRAM with SHA-256; exposes the negotiated credential lifetime, if any. */
        SCRAM_SHA_256("SCRAM-SHA-256", ScramMechanism.SCRAM_SHA_256) {
            @Override
            public Map<String, Object> negotiatedProperties(SaslServer saslServer) {
                return SaslMechanism.credentialLifetimeProperties(saslServer);
            }
        },
        /** SCRAM with SHA-512; exposes the negotiated credential lifetime, if any. */
        SCRAM_SHA_512("SCRAM-SHA-512", ScramMechanism.SCRAM_SHA_512) {
            @Override
            public Map<String, Object> negotiatedProperties(SaslServer saslServer) {
                return SaslMechanism.credentialLifetimeProperties(saslServer);
            }
        };

        /** Negotiated-property key queried from SCRAM servers. */
        private static final String CREDENTIAL_LIFETIME_MS = "CREDENTIAL.LIFETIME.MS";

        private final String saslName;
        private final ScramMechanism scramMechanism;

        SaslMechanism(String saslName, ScramMechanism scramMechanism) {
            this.saslName = saslName;
            this.scramMechanism = scramMechanism;
        }

        /** @return the IANA-style SASL mechanism name, e.g. {@code "SCRAM-SHA-256"} */
        public String mechanismName() {
            return this.saslName;
        }

        /**
         * Looks up a mechanism by its SASL name.
         *
         * @throws UnsupportedSaslMechanismException if the name is not one of the supported mechanisms
         */
        static SaslMechanism fromMechanismName(String mechanismName) {
            switch (mechanismName) {
                case "PLAIN":
                    return PLAIN;
                case "SCRAM-SHA-256":
                    return SCRAM_SHA_256;
                case "SCRAM-SHA-512":
                    return SCRAM_SHA_512;
                default:
                    throw new UnsupportedSaslMechanismException(mechanismName);
            }
        }

        /** @return the Kafka SCRAM mechanism, or {@code null} for non-SCRAM mechanisms ({@link #PLAIN}) */
        public ScramMechanism scramMechanism() {
            return this.scramMechanism;
        }

        /**
         * Extracts the properties to publish in the {@link AuthenticationEvent} from a completed server.
         *
         * @param saslServer a server that has completed authentication for this mechanism
         * @return an immutable, possibly empty, map of properties
         */
        public abstract Map<String, Object> negotiatedProperties(SaslServer saslServer);

        /** Shared SCRAM logic: the credential lifetime property, if the server negotiated one. */
        private static Map<String, Object> credentialLifetimeProperties(SaslServer saslServer) {
            Object lifetime = saslServer.getNegotiatedProperty(CREDENTIAL_LIFETIME_MS);
            return lifetime == null ? Map.of() : Map.of(CREDENTIAL_LIFETIME_MS, lifetime);
        }
    }
}

