/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kafka.common.network;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.NoSuchAlgorithmException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.kafka.common.MetricName;
import org.apache.kafka.common.config.AbstractConfig;
import org.apache.kafka.common.metrics.KafkaMetric;
import org.apache.kafka.common.metrics.Metrics;
import org.apache.kafka.common.network.CertStores;
import org.apache.kafka.common.network.ChannelBuilder;
import org.apache.kafka.common.network.ChannelBuilders;
import org.apache.kafka.common.network.ChannelState;
import org.apache.kafka.common.network.KafkaChannel;
import org.apache.kafka.common.network.ListenerName;
import org.apache.kafka.common.network.NetworkReceive;
import org.apache.kafka.common.network.NetworkTestUtils;
import org.apache.kafka.common.network.NioEchoServer;
import org.apache.kafka.common.network.Selector;
import org.apache.kafka.common.network.SelectorTest;
import org.apache.kafka.common.network.SslTransportLayer;
import org.apache.kafka.common.security.JaasContext;
import org.apache.kafka.common.security.TestSecurityConfig;
import org.apache.kafka.common.security.auth.KafkaPrincipal;
import org.apache.kafka.common.security.auth.SecurityProtocol;
import org.apache.kafka.common.security.authenticator.CredentialCache;
import org.apache.kafka.common.security.authenticator.LoginManager;
import org.apache.kafka.common.security.authenticator.TestJaasConfig;
import org.apache.kafka.common.security.scram.ScramCredential;
import org.apache.kafka.common.security.scram.internals.ScramFormatter;
import org.apache.kafka.common.security.scram.internals.ScramMechanism;
import org.apache.kafka.common.security.ssl.DefaultSslEngineFactory;
import org.apache.kafka.common.utils.LogContext;
import org.apache.kafka.common.utils.MockTime;
import org.apache.kafka.common.utils.Time;
import org.apache.kafka.common.utils.Utils;
import org.apache.kafka.test.TestUtils;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;

/**
 * Tests for reversing established connections between a client {@link Selector} and an
 * {@link NioEchoServer}: the server-side channel is turned into a client-side channel registered
 * with the client selector (and vice versa) via {@code KafkaChannel.reverse(...)} and
 * {@code Selector.addReverseChannel(...)}, after which the connection must still be usable.
 *
 * <p>The echo server's poll loop is gated by {@link #pollSemaphore}: the server thread holds the
 * single permit for the duration of each poll, so the test can acquire it to keep the server from
 * polling while channels are moved between selectors.
 */
@Tag("bazel:shard_count:2")
public class ReverseConnectionTest {
    private static final int BUFFER_SIZE = 4096;
    /** Iteration count passed to {@code ScramFormatter.generateCredential} for test users. */
    private static final int SCRAM_ITERATIONS = 4096;

    private final Time clientTime = new MockTime();
    private final Time serverTime = new MockTime();
    /** Single permit; held by the echo server thread around every poll (see createEchoServer). */
    private final Semaphore pollSemaphore = new Semaphore(1);
    private final Metrics clientMetrics = new Metrics();
    private NioEchoServer server;
    private Selector selector;
    private Map<String, Object> saslClientConfigs;
    private Map<String, Object> saslServerConfigs;
    private CredentialCache credentialCache;

    @BeforeEach
    public void setup() throws Exception {
        LoginManager.closeAll();
        CertStores serverCertStores = new CertStores(true, "localhost");
        CertStores clientCertStores = new CertStores(false, "localhost");
        saslServerConfigs = serverCertStores.getTrustingConfig(clientCertStores);
        saslClientConfigs = clientCertStores.getTrustingConfig(serverCertStores);
        saslServerConfigs.put("ssl.engine.factory.class", DefaultSslEngineFactory.class);
        saslClientConfigs.put("ssl.engine.factory.class", DefaultSslEngineFactory.class);
        credentialCache = new CredentialCache();
    }

    @AfterEach
    public void teardown() throws Exception {
        // Nested finally blocks: a failure closing one resource must not leak the others.
        try {
            if (server != null) {
                server.close();
            }
        } finally {
            try {
                if (selector != null) {
                    selector.close();
                }
            } finally {
                clientMetrics.close();
            }
        }
    }

    @Test
    public void testReverseConnectionSaslPlain() throws Exception {
        SecurityProtocol securityProtocol = SecurityProtocol.SASL_PLAINTEXT;
        configureMechanisms("PLAIN", Collections.singletonList("PLAIN"));
        server = createEchoServer(securityProtocol);
        String node = "0";
        createAndVerifyConnection(securityProtocol, node);
        reverseAndVerifyConnection(securityProtocol, node, "1", true);
    }

    @Test
    public void testReverseConnectionSaslScram() throws Exception {
        SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL;
        configureMechanisms("SCRAM-SHA-256", Collections.singletonList("SCRAM-SHA-256"));
        server = createEchoServer(securityProtocol);
        updateScramCredentialCache("SCRAM-SHA-256", "myuser", "mypassword");
        String node = "0";
        createAndVerifyConnection(securityProtocol, node);
        reverseAndVerifyConnection(securityProtocol, node, "1", true);
    }

    @Test
    public void testReverseConnectionSsl() throws Exception {
        SecurityProtocol securityProtocol = SecurityProtocol.SSL;
        server = createEchoServer(securityProtocol);
        String node = "0";
        createAndVerifyConnection(securityProtocol, node);
        reverseAndVerifyConnection(securityProtocol, node, "1", true);
    }

    /**
     * Verifies that data already sitting in the client-side {@code SslTransportLayer}'s
     * {@code appReadBuffer} when the channel is reversed (simulated by writing a size-prefixed
     * message and setting {@code hasBytesBuffered}) is delivered as a completed receive on the
     * client selector after the swap.
     */
    @Test
    public void testReverseConnectionSslWithBufferedRead() throws Exception {
        SecurityProtocol securityProtocol = SecurityProtocol.SSL;
        server = createEchoServer(securityProtocol);
        String oldNode = "0";
        String newNode = "1";
        createAndVerifyConnection(securityProtocol, oldNode);

        Selector clientSelector = selector;
        // Keep the server thread from polling while channels are swapped. Release in finally so a
        // failed assertion cannot leave the server thread blocked (and teardown hanging).
        pollSemaphore.acquire();
        try {
            Selector serverSelector = server.selector();
            KafkaChannel serverSideChannel = serverSelector.channels().get(0);
            serverSelector.removeChannelWithoutClosing(serverSideChannel);
            KafkaChannel clientSideChannel = clientSelector.channel(oldNode);
            clientSelector.removeChannelWithoutClosing(clientSideChannel);

            SslTransportLayer transportLayer = (SslTransportLayer) TestUtils.fieldValue(clientSideChannel, KafkaChannel.class, "transportLayer");
            ByteBuffer appReadBuffer = (ByteBuffer) TestUtils.fieldValue(transportLayer, SslTransportLayer.class, "appReadBuffer");
            Assertions.assertEquals(0, appReadBuffer.position());
            byte[] messageBytes = "testMessage".getBytes(StandardCharsets.UTF_8);
            appReadBuffer.putInt(messageBytes.length);
            appReadBuffer.put(messageBytes);
            TestUtils.setFieldValue(transportLayer, "hasBytesBuffered", true);

            KafkaPrincipal principal = serverSideChannel.principal();
            KafkaChannel reverseClientSideChannel = serverSideChannel.reverse(newNode, unused -> { });
            clientSelector.addReverseChannel(reverseClientSideChannel);
            KafkaChannel reverseServerSideChannel = clientSideChannel.reverse(serverSideChannel.id(), null, principal, Optional.empty(), null, unused -> { });
            serverSelector.addReverseChannel(reverseServerSideChannel);
        } finally {
            pollSemaphore.release();
        }

        TestUtils.waitForCondition(() -> {
            clientSelector.poll(1L);
            if (clientSelector.completedReceives().isEmpty()) {
                return false;
            }
            Assertions.assertEquals(1, clientSelector.completedReceives().size());
            NetworkReceive receive = clientSelector.completedReceives().iterator().next();
            Assertions.assertEquals("testMessage", new String(Utils.toArray(receive.payload()), StandardCharsets.UTF_8));
            return true;
        }, "Buffered receive not processed");
    }

    /** Reversing is rejected both before authentication completes and on the stale channel objects afterwards. */
    @Test
    public void testReverseBeforeAuthentication() throws Exception {
        String node = "0";
        SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL;
        configureMechanisms("SCRAM-SHA-256", Collections.singletonList("SCRAM-SHA-256"));
        server = createEchoServer(securityProtocol);
        updateScramCredentialCache("SCRAM-SHA-256", "myuser", "mypassword");
        connectAndAwaitServerChannel(securityProtocol, node);
        KafkaChannel serverSideChannel = server.selector().channels().get(0);
        KafkaChannel clientSideChannel = selector.channels().get(0);
        Assertions.assertFalse(serverSideChannel.ready());
        Assertions.assertFalse(clientSideChannel.ready());
        KafkaPrincipal principal = new KafkaPrincipal("User", "someuser");
        assertReverseRejected(serverSideChannel, clientSideChannel, principal);
        createAndVerifyConnection(securityProtocol, node);
        assertReverseRejected(serverSideChannel, clientSideChannel, principal);
    }

    /** No SCRAM credentials are cached, so authentication fails and reversing must be rejected. */
    @Test
    public void testReverseFailsIfAuthenticationFails() throws Exception {
        String node = "0";
        SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL;
        configureMechanisms("SCRAM-SHA-256", Collections.singletonList("SCRAM-SHA-256"));
        server = createEchoServer(securityProtocol);
        connectAndAwaitServerChannel(securityProtocol, node);
        KafkaChannel serverSideChannel = server.selector().channels().get(0);
        KafkaChannel clientSideChannel = selector.channels().get(0);
        NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.State.AUTHENTICATION_FAILED);
        KafkaPrincipal principal = new KafkaPrincipal("User", "someuser");
        assertReverseRejected(serverSideChannel, clientSideChannel, principal);
    }

    /** Advancing the client clock must expire the reversed channel on the client selector, and the server must follow. */
    @Test
    public void testClientIdleExpiry() throws Exception {
        SecurityProtocol securityProtocol = SecurityProtocol.SSL;
        server = createEchoServer(securityProtocol);
        String node = "0";
        String reverseNode = "10";
        createAndVerifyConnection(securityProtocol, node);
        reverseAndVerifyConnection(securityProtocol, node, reverseNode, false);
        Assertions.assertEquals(1, server.selector().channels().size());
        Assertions.assertEquals(1, selector.channels().size());

        clientTime.sleep(TimeUnit.MINUTES.toMillis(10L));
        selector.poll(1L);
        Assertions.assertEquals(Collections.emptyList(), selector.channels());
        Assertions.assertEquals(ChannelState.State.EXPIRED, selector.disconnected().get(reverseNode).state());
        TestUtils.waitForCondition(() -> server.selector().channels().isEmpty(), "Server channel not disconnected");
        SelectorTest.verifySelectorEmpty(selector);
        SelectorTest.verifySelectorEmpty(server.selector());
    }

    /** Advancing the server clock must expire the reversed channel on the server, closing the client side. */
    @Test
    public void testServerIdleExpiry() throws Exception {
        SecurityProtocol securityProtocol = SecurityProtocol.SSL;
        server = createEchoServer(securityProtocol);
        String node = "0";
        String reverseNode = "10";
        createAndVerifyConnection(securityProtocol, node);
        reverseAndVerifyConnection(securityProtocol, node, reverseNode, false);
        Assertions.assertEquals(1, server.selector().channels().size());
        Assertions.assertEquals(1, selector.channels().size());

        serverTime.sleep(TimeUnit.MINUTES.toMillis(10L));
        TestUtils.waitForCondition(() -> server.selector().channels().isEmpty(), "Server channel not expired");
        NetworkTestUtils.waitForChannelClose(selector, reverseNode, ChannelState.State.READY);
        SelectorTest.verifySelectorEmpty(selector);
        SelectorTest.verifySelectorEmpty(server.selector());
    }

    /** Configures the client SASL mechanism, the server's enabled mechanisms and the matching JAAS config. */
    private void configureMechanisms(String clientMechanism, List<String> serverMechanisms) {
        saslClientConfigs.put("sasl.mechanism", clientMechanism);
        saslServerConfigs.put("sasl.enabled.mechanisms", serverMechanisms);
        TestJaasConfig.createConfiguration(clientMechanism, serverMechanisms);
    }

    /**
     * Starts an echo server for {@code securityProtocol} whose {@code poll()} holds
     * {@link #pollSemaphore} for its whole duration, so the test can pause the server thread.
     */
    private NioEchoServer createEchoServer(SecurityProtocol securityProtocol) throws Exception {
        ListenerName listenerName = ListenerName.forSecurityProtocol(securityProtocol);
        NioEchoServer echoServer = new NioEchoServer(listenerName, securityProtocol, new TestSecurityConfig(saslServerConfigs), "localhost", null, credentialCache, 0, serverTime) {

            @Override
            protected void poll() throws IOException {
                try {
                    ReverseConnectionTest.this.pollSemaphore.acquire();
                } catch (InterruptedException e) {
                    // Preserve the interrupt status for whoever interrupted the server thread.
                    Thread.currentThread().interrupt();
                    throw new RuntimeException(e);
                }
                try {
                    super.poll();
                } finally {
                    // Always hand the permit back, otherwise the test thread blocks forever on acquire().
                    ReverseConnectionTest.this.pollSemaphore.release();
                }
            }
        };
        echoServer.start();
        return echoServer;
    }

    /** Generates a SCRAM credential for {@code username} and stores it in the server's credential cache. */
    private void updateScramCredentialCache(String mechanism, String username, String password) throws NoSuchAlgorithmException {
        ScramMechanism scramMechanism = ScramMechanism.forMechanismName(mechanism);
        ScramFormatter formatter = new ScramFormatter(scramMechanism);
        ScramCredential credential = formatter.generateCredential(password, SCRAM_ITERATIONS);
        credentialCache.cache(scramMechanism.mechanismName(), ScramCredential.class).put(username, credential);
    }

    /** Replaces {@link #selector} (closing any previous one) with a client selector built from {@code clientConfigs}. */
    private void createSelector(SecurityProtocol securityProtocol, Map<String, Object> clientConfigs) {
        if (selector != null) {
            selector.close();
            selector = null;
        }
        // Read the mechanism from the same config the channel builder is given, so the two agree.
        String saslMechanism = (String) clientConfigs.get("sasl.mechanism");
        ChannelBuilder channelBuilder = ChannelBuilders.clientChannelBuilder(securityProtocol, JaasContext.Type.CLIENT, new TestSecurityConfig(clientConfigs), null, saslMechanism, clientTime, true, new LogContext());
        selector = new Selector(5000L, clientMetrics, clientTime, "MetricGroup", channelBuilder, new LogContext());
    }

    /**
     * Creates a fresh client selector, starts connecting {@code node} to the echo server and polls the
     * client until the server selector has registered a channel. Authentication is not awaited.
     */
    private void connectAndAwaitServerChannel(SecurityProtocol securityProtocol, String node) throws Exception {
        createSelector(securityProtocol, saslClientConfigs);
        InetSocketAddress addr = new InetSocketAddress("localhost", server.port());
        selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
        // Check before polling (as a plain while-loop would) but bounded by the waitForCondition timeout.
        TestUtils.waitForCondition(() -> {
            if (!server.selector().channels().isEmpty()) {
                return true;
            }
            selector.poll(100L);
            return false;
        }, "Server did not register the incoming connection");
    }

    /** Creates a fresh client selector, connects {@code node} and checks the connection echoes data. */
    private void createAndVerifyConnection(SecurityProtocol securityProtocol, String node) throws Exception {
        createSelector(securityProtocol, saslClientConfigs);
        InetSocketAddress addr = new InetSocketAddress("localhost", server.port());
        selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
        NetworkTestUtils.checkClientConnection(selector, node, 100, 10);
    }

    /** Asserts that calling {@code reverse} on either channel throws {@link IllegalStateException}. */
    private static void assertReverseRejected(KafkaChannel serverSideChannel, KafkaChannel clientSideChannel, KafkaPrincipal principal) {
        Assertions.assertThrows(IllegalStateException.class, () -> serverSideChannel.reverse("1", null));
        Assertions.assertThrows(IllegalStateException.class, () -> clientSideChannel.reverse(serverSideChannel.id(), null, principal, Optional.empty(), null, null));
    }

    /**
     * Swaps the single established connection: the server-side channel becomes client channel
     * {@code newNode} and the client-side channel {@code oldNode} becomes a server channel. Checks
     * reverse-connection and cipher metrics, the transferred principal/authentication context, and
     * that {@code newNode} echoes data. If {@code closeChannels}, closes the reversed client channel
     * and checks both close listeners fire.
     */
    private void reverseAndVerifyConnection(SecurityProtocol securityProtocol, String oldNode, String newNode, boolean closeChannels) throws Exception {
        boolean usesSsl = securityProtocol == SecurityProtocol.SSL || securityProtocol == SecurityProtocol.SASL_SSL;
        verifyClientMetric("reverse-connection-added-total", Optional.empty(), 0.0);
        verifyClientMetric("reverse-connection-removed-total", Optional.empty(), 0.0);
        if (usesSsl) {
            verifyServerMetric("connections", Optional.of("cipher"), 1.0);
            verifyClientMetric("connections", Optional.of("cipher"), 1.0);
        }

        Selector clientSelector = selector;
        AtomicInteger closedServerChannels = new AtomicInteger();
        AtomicInteger closedClientChannels = new AtomicInteger();
        KafkaChannel reverseClientSideChannel;
        // Keep the server thread from polling during the swap; release in finally so a failed
        // assertion cannot leave the server thread blocked (and teardown hanging).
        pollSemaphore.acquire();
        try {
            Selector serverSelector = server.selector();
            Assertions.assertEquals(1, serverSelector.channels().size());
            Assertions.assertEquals(1, clientSelector.channels().size());

            KafkaChannel serverSideChannel = serverSelector.channels().get(0);
            serverSelector.removeChannelWithoutClosing(serverSideChannel);
            KafkaChannel clientSideChannel = clientSelector.channel(oldNode);
            clientSelector.removeChannelWithoutClosing(clientSideChannel);
            SelectorTest.verifySelectorEmpty(serverSelector);
            SelectorTest.verifySelectorEmpty(clientSelector);

            KafkaPrincipal principal = serverSideChannel.principal();
            reverseClientSideChannel = serverSideChannel.reverse(newNode, unused -> closedServerChannels.incrementAndGet());
            clientSelector.addReverseChannel(reverseClientSideChannel);
            if (usesSsl) {
                verifyServerMetric("connections", Optional.of("cipher"), 0.0);
                verifyClientMetric("connections", Optional.of("cipher"), 2.0);
            }
            KafkaChannel reverseServerSideChannel = clientSideChannel.reverse(serverSideChannel.id(), null, serverSideChannel.principal(), Optional.empty(), serverSideChannel.authenticationContext(), unused -> closedClientChannels.incrementAndGet());
            serverSelector.addReverseChannel(reverseServerSideChannel);
            if (usesSsl) {
                verifyServerMetric("connections", Optional.of("cipher"), 1.0);
                verifyClientMetric("connections", Optional.of("cipher"), 1.0);
            }

            Assertions.assertEquals(1, serverSelector.channels().size());
            Assertions.assertEquals(1, clientSelector.channels().size());
            KafkaChannel serverReverseChannel = serverSelector.channels().get(0);
            KafkaChannel clientReverseChannel = clientSelector.channels().get(0);
            Assertions.assertEquals(principal, serverReverseChannel.principal());
            Assertions.assertEquals(principal.getName(), serverReverseChannel.publicCredential().authenticationId());
            Assertions.assertEquals(ListenerName.forSecurityProtocol(securityProtocol).value(), serverReverseChannel.authenticationContext().listenerName());
            Assertions.assertEquals(serverReverseChannel.socketChannel().socket().getInetAddress(), serverReverseChannel.authenticationContext().clientAddress());
            Assertions.assertEquals(securityProtocol, clientReverseChannel.authenticationContext().securityProtocol());
        } finally {
            pollSemaphore.release();
        }

        NetworkTestUtils.checkClientConnection(selector, newNode, 100, 10);
        verifyClientMetric("reverse-connection-added-total", Optional.empty(), 1.0);
        verifyClientMetric("reverse-connection-removed-total", Optional.empty(), 1.0);
        if (closeChannels) {
            clientSelector.close(reverseClientSideChannel.id());
            Assertions.assertEquals(1, closedServerChannels.get());
            TestUtils.waitForCondition(() -> closedClientChannels.get() == 1, "Close listener not invoked");
        }
    }

    private void verifyClientMetric(String name, Optional<String> tagName, double value) {
        verifyMetric(clientMetrics, name, tagName, value);
    }

    private void verifyServerMetric(String name, Optional<String> tagName, double value) {
        verifyMetric(server.metrics(), name, tagName, value);
    }

    /**
     * Asserts that the first metric named {@code name} (and, if {@code tagName} is present, carrying
     * that tag key) exists in {@code metrics} and has numeric value {@code value} (within 0.001).
     */
    private void verifyMetric(Metrics metrics, String name, Optional<String> tagName, double value) {
        Optional<KafkaMetric> metric = metrics.metrics().entrySet().stream()
                .filter(e -> e.getKey().name().equals(name))
                .filter(e -> !tagName.isPresent() || e.getKey().tags().containsKey(tagName.get()))
                .map(Map.Entry::getValue)
                .findFirst();
        Assertions.assertTrue(metric.isPresent(), "Metric not found " + name);
        Assertions.assertEquals(value, ((Number) metric.get().metricValue()).doubleValue(), 0.001);
    }
}

