/*
 * Decompiled with CFR 0.152.
 */
package io.kroxylicious.proxy.internal;

import edu.umd.cs.findbugs.annotations.Nullable;
import io.kroxylicious.proxy.filter.FilterAndInvoker;
import io.kroxylicious.proxy.filter.NetFilter;
import io.kroxylicious.proxy.frame.DecodedRequestFrame;
import io.kroxylicious.proxy.frame.RequestFrame;
import io.kroxylicious.proxy.internal.KafkaProxyBackendHandler;
import io.kroxylicious.proxy.internal.KafkaProxyFrontendHandler;
import io.kroxylicious.proxy.internal.ProxyChannelState;
import io.kroxylicious.proxy.internal.SaslDecodePredicate;
import io.kroxylicious.proxy.internal.codec.FrameOversizedException;
import io.kroxylicious.proxy.internal.util.Metrics;
import io.kroxylicious.proxy.internal.util.StableKroxyliciousLinkGenerator;
import io.kroxylicious.proxy.model.VirtualClusterModel;
import io.kroxylicious.proxy.service.HostPort;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Tag;
import io.micrometer.core.instrument.Timer;
import io.netty.handler.codec.DecoderException;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import java.util.List;
import java.util.Objects;
import java.util.function.Function;
import org.apache.kafka.common.errors.ApiException;
import org.apache.kafka.common.message.ApiVersionsRequestData;
import org.apache.kafka.common.protocol.ApiKeys;
import org.apache.kafka.common.protocol.Errors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * State machine for a single proxied connection.
 *
 * <p>It holds the current {@link ProxyChannelState}, reacts to events raised by the
 * {@link KafkaProxyFrontendHandler} (client side) and {@link KafkaProxyBackendHandler} (server side),
 * and notifies those handlers when it transitions. The transitions driven from this class are:</p>
 * <pre>
 * Startup                             --onClientActive--------------------------&gt; ClientActive
 * ClientActive                        --HAProxyMessage--------------------------&gt; HaProxy
 * ClientActive | HaProxy              --ApiVersions request, auth offload on----&gt; ApiVersions
 * ClientActive | HaProxy | ApiVersions --other request frame------------------&gt; SelectingServer
 * SelectingServer                     --onNetFilterInitiateConnect--------------&gt; Connecting
 * Connecting                          --onServerActive--------------------------&gt; Forwarding
 * any (not already Closed)            --inactive / exception / illegal event----&gt; Closed
 * </pre>
 *
 * <p>It also maintains connection, error and backpressure metrics. No synchronization is performed
 * here; presumably all calls arrive on the channel's event loop (NOTE(review): confirm).</p>
 */
public class ProxyChannelStateMachine {
    private static final String DUPLICATE_INITIATE_CONNECT_ERROR = "NetFilter called NetFilterContext.initiateConnect() more than once";
    private static final Logger LOGGER = LoggerFactory.getLogger(ProxyChannelStateMachine.class);

    // Legacy metrics tagged only by virtual cluster. Where a node-aware replacement exists,
    // it is incremented at the same point and named below.

    /** Legacy; incremented alongside {@link #clientToProxyConnectionCounter}. */
    @Deprecated(since = "0.13.0", forRemoval = true)
    private final Counter downstreamConnectionsCounter;
    /** Legacy; incremented when the server channel becomes active (transition to Forwarding). */
    @Deprecated(since = "0.13.0", forRemoval = true)
    private final Counter upstreamConnectionsCounter;
    /** Legacy; incremented alongside {@link #clientToProxyErrorCounter}. */
    @Deprecated(since = "0.13.0", forRemoval = true)
    private final Counter downstreamErrorCounter;
    /** Legacy; incremented alongside {@link #proxyToServerErrorCounter}. */
    @Deprecated(since = "0.13.0", forRemoval = true)
    private final Counter upstreamErrorCounter;
    /** Legacy; incremented alongside {@link #proxyToServerConnectionCounter}. */
    @Deprecated(since = "0.13.0", forRemoval = true)
    private final Counter connectionAttemptsCounter;
    /** Legacy; incremented when a server exception arrives while still Connecting. */
    @Deprecated(since = "0.13.0", forRemoval = true)
    private final Counter upstreamConnectionFailureCounter;

    private final Counter clientToProxyErrorCounter;
    private final Counter clientToProxyConnectionCounter;
    private final Counter proxyToServerConnectionCounter;
    private final Counter proxyToServerErrorCounter;
    /** Records how long server reads stay blocked because the client is unwritable. */
    private final Timer serverToProxyBackpressureMeter;
    /** Records how long client reads stay blocked because the server is unwritable. */
    private final Timer clientToProxyBackPressureMeter;

    /** Running sample while client reads are blocked; {@code null} otherwise. */
    @Nullable
    Timer.Sample clientToProxyBackpressureTimer;
    /** Running sample while server reads are blocked; {@code null} otherwise. */
    @Nullable
    Timer.Sample serverBackpressureTimer;

    private ProxyChannelState state = ProxyChannelState.Startup.STARTING_STATE;
    boolean serverReadsBlocked;
    boolean clientReadsBlocked;

    /** Set by {@link #onClientActive}; {@code null} before that. */
    @Nullable
    private KafkaProxyFrontendHandler frontendHandler;
    /** Created on the transition to Connecting; {@code null} before that. */
    @Nullable
    private KafkaProxyBackendHandler backendHandler;

    /**
     * Creates a state machine in the starting state and registers its metrics.
     *
     * @param clusterName name of the virtual cluster, used as a metric tag
     * @param nodeId      broker node id the connection is bound to, or {@code null} if none
     */
    public ProxyChannelStateMachine(String clusterName, @Nullable Integer nodeId) {
        clientToProxyConnectionCounter = Metrics.clientToProxyConnectionCounter(clusterName, nodeId).withTags();
        clientToProxyErrorCounter = Metrics.clientToProxyErrorCounter(clusterName, nodeId).withTags();
        proxyToServerConnectionCounter = Metrics.proxyToServerConnectionCounter(clusterName, nodeId).withTags();
        proxyToServerErrorCounter = Metrics.proxyToServerErrorCounter(clusterName, nodeId).withTags();
        serverToProxyBackpressureMeter = Metrics.serverToProxyBackpressureTimer(clusterName, nodeId).withTags();
        clientToProxyBackPressureMeter = Metrics.clientToProxyBackpressureTimer(clusterName, nodeId).withTags();

        List<Tag> tags = Metrics.tags("virtualCluster", clusterName);
        downstreamConnectionsCounter = Metrics.taggedCounter("kroxylicious_downstream_connections", tags);
        downstreamErrorCounter = Metrics.taggedCounter("kroxylicious_downstream_errors", tags);
        upstreamConnectionsCounter = Metrics.taggedCounter("kroxylicious_upstream_connections", tags);
        connectionAttemptsCounter = Metrics.taggedCounter("kroxylicious_upstream_connection_attempts", tags);
        upstreamErrorCounter = Metrics.taggedCounter("kroxylicious_upstream_errors", tags);
        upstreamConnectionFailureCounter = Metrics.taggedCounter("kroxylicious_upstream_connection_failures", tags);
    }

    /** Returns the current state. */
    ProxyChannelState state() {
        return state;
    }

    /**
     * Overwrites the state and both handlers without running any transition logic or metrics.
     * Presumably intended for tests (NOTE(review): confirm no production callers).
     */
    void forceState(ProxyChannelState state, KafkaProxyFrontendHandler frontendHandler, @Nullable KafkaProxyBackendHandler backendHandler) {
        LOGGER.info("Forcing state to {} with {} and {}", state, frontendHandler, backendHandler);
        this.state = state;
        this.frontendHandler = frontendHandler;
        this.backendHandler = backendHandler;
    }

    @Override
    public String toString() {
        return "StateHolder{state=" + state
                + ", serverReadsBlocked=" + serverReadsBlocked
                + ", clientReadsBlocked=" + clientReadsBlocked
                + ", frontendHandler=" + frontendHandler
                + ", backendHandler=" + backendHandler
                + "}";
    }

    /** Returns the simple class name of the current state, e.g. {@code "Forwarding"}. */
    public String currentState() {
        return state().getClass().getSimpleName();
    }

    /**
     * Called when the client channel becomes unwritable. Marks server reads as blocked, starts the
     * server-to-proxy backpressure timer and asks the backend handler to apply backpressure.
     * Does nothing if server reads are already blocked.
     *
     * @throws NullPointerException if there is no backend handler yet; no state is changed in that case
     */
    public void onClientUnwritable() {
        if (!serverReadsBlocked) {
            // Resolve the handler before mutating anything so a missing backend cannot leave the
            // flag set and a timer running with no backpressure actually applied.
            KafkaProxyBackendHandler backend = Objects.requireNonNull(backendHandler);
            serverReadsBlocked = true;
            serverBackpressureTimer = Timer.start();
            backend.applyBackpressure();
        }
    }

    /**
     * Called when the client channel becomes writable again. Clears the server-reads-blocked flag,
     * records the elapsed backpressure time and asks the backend handler to relieve backpressure.
     * Does nothing if server reads are not blocked.
     *
     * @throws NullPointerException if there is no backend handler; no state is changed in that case
     */
    public void onClientWritable() {
        if (serverReadsBlocked) {
            KafkaProxyBackendHandler backend = Objects.requireNonNull(backendHandler);
            serverReadsBlocked = false;
            if (serverBackpressureTimer != null) {
                serverBackpressureTimer.stop(serverToProxyBackpressureMeter);
                serverBackpressureTimer = null;
            }
            backend.relieveBackpressure();
        }
    }

    /**
     * Called when the server channel becomes unwritable. Marks client reads as blocked, starts the
     * client-to-proxy backpressure timer and asks the frontend handler to apply backpressure.
     * Does nothing if client reads are already blocked.
     *
     * @throws NullPointerException if there is no frontend handler; no state is changed in that case
     */
    public void onServerUnwritable() {
        if (!clientReadsBlocked) {
            KafkaProxyFrontendHandler frontend = Objects.requireNonNull(frontendHandler);
            clientReadsBlocked = true;
            clientToProxyBackpressureTimer = Timer.start();
            frontend.applyBackpressure();
        }
    }

    /**
     * Called when the server channel becomes writable again. Clears the client-reads-blocked flag,
     * records the elapsed backpressure time and asks the frontend handler to relieve backpressure.
     * Does nothing if client reads are not blocked.
     *
     * @throws NullPointerException if there is no frontend handler; no state is changed in that case
     */
    public void onServerWritable() {
        if (clientReadsBlocked) {
            KafkaProxyFrontendHandler frontend = Objects.requireNonNull(frontendHandler);
            clientReadsBlocked = false;
            if (clientToProxyBackpressureTimer != null) {
                clientToProxyBackpressureTimer.stop(clientToProxyBackPressureMeter);
                clientToProxyBackpressureTimer = null;
            }
            frontend.relieveBackpressure();
        }
    }

    /**
     * Called when the client channel becomes active. Only legal in the starting state, where it
     * records the frontend handler and moves to ClientActive; otherwise the connection is closed.
     */
    void onClientActive(KafkaProxyFrontendHandler frontendHandler) {
        if (ProxyChannelState.Startup.STARTING_STATE.equals(state)) {
            this.frontendHandler = frontendHandler;
            toClientActive(ProxyChannelState.Startup.STARTING_STATE.toClientActive(), frontendHandler);
        } else {
            illegalState("Client activation while not in the start state");
        }
    }

    /**
     * Called when a {@link NetFilter} asks to connect to {@code peer}. Only legal while
     * SelectingServer, where it moves to Connecting; otherwise the connection is closed.
     */
    void onNetFilterInitiateConnect(HostPort peer, List<FilterAndInvoker> filters, VirtualClusterModel virtualClusterModel, NetFilter netFilter) {
        if (state instanceof ProxyChannelState.SelectingServer selectingServer) {
            toConnecting(selectingServer.toConnecting(peer), filters, virtualClusterModel);
        } else {
            illegalState(DUPLICATE_INITIATE_CONNECT_ERROR + " : netFilter='" + netFilter + "'");
        }
    }

    /**
     * Called when the server channel becomes active. Only legal while Connecting, where it moves to
     * Forwarding; otherwise the connection is closed.
     */
    void onServerActive() {
        if (state() instanceof ProxyChannelState.Connecting connecting) {
            toForwarding(connecting.toForwarding());
        } else {
            illegalState("Server became active while not in the connecting state");
        }
    }

    /**
     * Logs an unexpected event and closes the connection without an error response to the client.
     * Ignored if the connection is already closed.
     */
    void illegalState(String msg) {
        if (!(state instanceof ProxyChannelState.Closed)) {
            LOGGER.error("Unexpected event while in {} message: {}, closing channels with no client response.", state, msg);
            toClosed(null);
        }
    }

    /** Forwards a message received from the server to the client via the frontend handler. */
    void messageFromServer(Object msg) {
        Objects.requireNonNull(frontendHandler).forwardToClient(msg);
    }

    /** Flushes pending client-bound writes once a server read batch completes. */
    void serverReadComplete() {
        Objects.requireNonNull(frontendHandler).flushToClient();
    }

    /** Forwards a message received from the client to the server via the backend handler. */
    void messageFromClient(Object msg) {
        Objects.requireNonNull(backendHandler).forwardToServer(msg);
    }

    /** Flushes pending server-bound writes once a client read batch completes, but only while Forwarding. */
    void clientReadComplete() {
        if (state instanceof ProxyChannelState.Forwarding) {
            Objects.requireNonNull(backendHandler).flushToServer();
        }
    }

    /**
     * Handles a message read from the client. While Forwarding it is sent straight to the server.
     * Before that it is buffered by the frontend handler and may drive a state transition. A message
     * that is not acceptable in the current state closes the connection.
     */
    void onClientRequest(SaslDecodePredicate dp, Object msg) {
        Objects.requireNonNull(frontendHandler);
        if (state() instanceof ProxyChannelState.Forwarding) {
            messageFromClient(msg);
        } else if (!onClientRequestBeforeForwarding(dp, msg)) {
            illegalState("Unexpected message received: " + (msg == null ? "null" : "message class=" + msg.getClass()));
        }
    }

    /** Closes the connection with {@code msg} as the reason unless the state is Connecting. */
    void assertIsConnecting(String msg) {
        if (!(state instanceof ProxyChannelState.Connecting)) {
            illegalState(msg);
        }
    }

    /**
     * Returns the current state if it is SelectingServer. Otherwise the connection is closed and an
     * exception is thrown.
     *
     * @throws IllegalStateException if the current state is not SelectingServer
     */
    ProxyChannelState.SelectingServer enforceInSelectingServer(String errorMessage) {
        if (state instanceof ProxyChannelState.SelectingServer selectingServer) {
            return selectingServer;
        }
        illegalState(errorMessage);
        throw new IllegalStateException("State required to be " + ProxyChannelState.SelectingServer.class.getSimpleName()
                + " but was " + currentState() + ":" + errorMessage);
    }

    /** Closes the connection (no client error) when the server channel goes inactive. */
    void onServerInactive() {
        toClosed(null);
    }

    /** Closes the connection (no client error) when the client channel goes inactive. */
    void onClientInactive() {
        toClosed(null);
    }

    /**
     * Handles an exception from the server channel: logs it, updates error metrics (counting a
     * connection failure too if still Connecting) and closes the connection with {@code cause}.
     */
    void onServerException(@Nullable Throwable cause) {
        warnWithDebugStacktrace("Exception from the server channel: {}. Increase log level to DEBUG for stacktrace", cause);
        if (state instanceof ProxyChannelState.Connecting) {
            upstreamConnectionFailureCounter.increment();
        }
        upstreamErrorCounter.increment();
        proxyToServerErrorCounter.increment();
        toClosed(cause);
    }

    /**
     * Handles an exception from the client channel and closes the connection.
     *
     * <p>An over-sized frame, seen as a {@link DecoderException} wrapping a
     * {@link FrameOversizedException}, is logged with hints and closes with
     * {@code INVALID_REQUEST}. Any other cause closes with {@code UNKNOWN_SERVER_ERROR}.</p>
     *
     * @param tlsEnabled whether TLS is configured; when it is not, the log hints that the
     *                   oversized "frame" may be a TLS handshake
     */
    void onClientException(@Nullable Throwable cause, boolean tlsEnabled) {
        ApiException errorCodeEx;
        if (cause instanceof DecoderException de && de.getCause() instanceof FrameOversizedException e) {
            String tlsHint = tlsEnabled
                    ? ""
                    : " Possible unexpected TLS handshake? When connecting via TLS from your client, make sure to enable TLS for the Kroxylicious gateway ("
                            + StableKroxyliciousLinkGenerator.INSTANCE.errorLink("clientTls") + ").";
            LOGGER.warn(
                    "Received over-sized frame from the client, max frame size bytes {}, received frame size bytes {} (hint: {} Other possible causes are: an oversized Kafka frame, or something unexpected like an HTTP request.)",
                    e.getMaxFrameSizeBytes(), e.getReceivedFrameSizeBytes(), tlsHint);
            errorCodeEx = Errors.INVALID_REQUEST.exception();
        } else {
            warnWithDebugStacktrace("Exception from the client channel: {}. Increase log level to DEBUG for stacktrace", cause);
            errorCodeEx = Errors.UNKNOWN_SERVER_ERROR.exception();
        }
        downstreamErrorCounter.increment();
        clientToProxyErrorCounter.increment();
        toClosed(errorCodeEx);
    }

    /**
     * Logs {@code message} at WARN with the cause's message (or {@code ""}) as its single argument.
     * The stack trace is attached only when DEBUG is enabled.
     */
    private static void warnWithDebugStacktrace(String message, @Nullable Throwable cause) {
        LOGGER.atWarn()
                .setCause(LOGGER.isDebugEnabled() ? cause : null)
                .addArgument(cause != null ? cause.getMessage() : "")
                .log(message);
    }

    /** Enters ClientActive, notifies the frontend handler and counts the client connection. */
    private void toClientActive(ProxyChannelState.ClientActive clientActive, KafkaProxyFrontendHandler frontendHandler) {
        setState(clientActive);
        frontendHandler.inClientActive();
        downstreamConnectionsCounter.increment();
        clientToProxyConnectionCounter.increment();
    }

    /**
     * Enters Connecting, creates the backend handler and hands it to the frontend handler together
     * with the remote address and filters. Counts the connection attempt.
     */
    private void toConnecting(ProxyChannelState.Connecting connecting, List<FilterAndInvoker> filters, VirtualClusterModel virtualClusterModel) {
        setState(connecting);
        backendHandler = new KafkaProxyBackendHandler(this, virtualClusterModel);
        Objects.requireNonNull(frontendHandler).inConnecting(connecting.remote(), filters, backendHandler);
        connectionAttemptsCounter.increment();
        proxyToServerConnectionCounter.increment();
    }

    /** Enters Forwarding and notifies the frontend handler. */
    private void toForwarding(ProxyChannelState.Forwarding forwarding) {
        setState(forwarding);
        Objects.requireNonNull(frontendHandler).inForwarding();
        upstreamConnectionsCounter.increment();
    }

    /**
     * Buffers a client message received before Forwarding and applies any transition it triggers.
     *
     * @return {@code true} if the message is acceptable in the current state
     */
    private boolean onClientRequestBeforeForwarding(SaslDecodePredicate dp, Object msg) {
        Objects.requireNonNull(frontendHandler).bufferMsg(msg);
        // Read the state once, after buffering; nothing between the checks below can change it.
        ProxyChannelState current = state();
        if (current instanceof ProxyChannelState.ClientActive clientActive) {
            return onClientRequestInClientActiveState(dp, msg, clientActive);
        }
        if (current instanceof ProxyChannelState.HaProxy haProxy) {
            return onClientRequestInHaProxyState(dp, msg, haProxy);
        }
        if (current instanceof ProxyChannelState.ApiVersions apiVersions) {
            return onClientRequestInApiVersionsState(dp, msg, apiVersions);
        }
        if (current instanceof ProxyChannelState.SelectingServer) {
            return msg instanceof RequestFrame;
        }
        return current instanceof ProxyChannelState.Connecting && msg instanceof RequestFrame;
    }

    /** In ApiVersions, any request frame moves to SelectingServer. */
    private boolean onClientRequestInApiVersionsState(SaslDecodePredicate dp, Object msg, ProxyChannelState.ApiVersions apiVersions) {
        if (msg instanceof RequestFrame) {
            toSelectingServer(apiVersions.toSelectingServer());
            return true;
        }
        return false;
    }

    /** In HaProxy, requests transition as in ClientActive (see {@link #transitionClientRequest}). */
    private boolean onClientRequestInHaProxyState(SaslDecodePredicate dp, Object msg, ProxyChannelState.HaProxy haProxy) {
        return transitionClientRequest(dp, msg, haProxy::toApiVersions, haProxy::toSelectingServer);
    }

    /**
     * Applies the transition for a request received in ClientActive or HaProxy.
     *
     * <p>An ApiVersions request moves to ApiVersions when authentication offload is enabled, and to
     * SelectingServer otherwise. Any other request frame moves to SelectingServer with no ApiVersions
     * frame.</p>
     *
     * @return {@code true} if {@code msg} was a request frame and a transition happened
     */
    private boolean transitionClientRequest(SaslDecodePredicate dp,
                                            Object msg,
                                            Function<DecodedRequestFrame<ApiVersionsRequestData>, ProxyChannelState.ApiVersions> apiVersionsFactory,
                                            Function<DecodedRequestFrame<ApiVersionsRequestData>, ProxyChannelState.SelectingServer> selectingServerFactory) {
        if (isMessageApiVersionsRequest(msg)) {
            // Assumes a decoded API_VERSIONS frame always carries ApiVersionsRequestData.
            @SuppressWarnings("unchecked")
            DecodedRequestFrame<ApiVersionsRequestData> apiVersionsFrame = (DecodedRequestFrame<ApiVersionsRequestData>) msg;
            if (dp.isAuthenticationOffloadEnabled()) {
                toApiVersions(apiVersionsFactory.apply(apiVersionsFrame), apiVersionsFrame);
            } else {
                toSelectingServer(selectingServerFactory.apply(apiVersionsFrame));
            }
            return true;
        }
        if (msg instanceof RequestFrame) {
            toSelectingServer(selectingServerFactory.apply(null));
            return true;
        }
        return false;
    }

    /** In ClientActive, a PROXY-protocol header moves to HaProxy; requests transition as in HaProxy. */
    private boolean onClientRequestInClientActiveState(SaslDecodePredicate dp, Object msg, ProxyChannelState.ClientActive clientActive) {
        if (msg instanceof HAProxyMessage haProxyMessage) {
            toHaProxy(clientActive.toHaProxy(haProxyMessage));
            return true;
        }
        return transitionClientRequest(dp, msg, clientActive::toApiVersions, clientActive::toSelectingServer);
    }

    private void toHaProxy(ProxyChannelState.HaProxy haProxy) {
        setState(haProxy);
    }

    private void toApiVersions(ProxyChannelState.ApiVersions apiVersions, DecodedRequestFrame<ApiVersionsRequestData> apiVersionsFrame) {
        setState(apiVersions);
        Objects.requireNonNull(frontendHandler).inApiVersions(apiVersionsFrame);
    }

    private void toSelectingServer(ProxyChannelState.SelectingServer selectingServer) {
        setState(selectingServer);
        Objects.requireNonNull(frontendHandler).inSelectingServer();
    }

    /**
     * Enters Closed and notifies whichever handlers exist. Idempotent.
     *
     * @param errorCodeEx error passed to the frontend handler, or {@code null} for no client error
     */
    private void toClosed(@Nullable Throwable errorCodeEx) {
        if (state instanceof ProxyChannelState.Closed) {
            return;
        }
        // Set Closed before notifying the handlers, so any close event raised during the
        // callbacks hits the early return above.
        setState(new ProxyChannelState.Closed());
        if (backendHandler != null) {
            backendHandler.inClosed();
        }
        if (frontendHandler != null) {
            frontendHandler.inClosed(errorCodeEx);
        }
    }

    private void setState(ProxyChannelState state) {
        LOGGER.trace("{} transitioning to {}", this, state);
        this.state = state;
    }

    /** Returns {@code true} if {@code msg} is a decoded frame whose API key is API_VERSIONS. */
    private static boolean isMessageApiVersionsRequest(Object msg) {
        return msg instanceof DecodedRequestFrame<?> frame && frame.apiKey() == ApiKeys.API_VERSIONS;
    }
}

