/*
 * Decompiled with CFR 0.152.
 */
package io.confluent.kafka.security.authenticator;

import java.nio.ByteBuffer;
import java.security.Principal;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import javax.net.ssl.SSLSession;
import javax.security.sasl.SaslServer;
import org.apache.kafka.common.Configurable;
import org.apache.kafka.common.KafkaException;
import org.apache.kafka.common.config.SslClientAuth;
import org.apache.kafka.common.errors.SerializationException;
import org.apache.kafka.common.message.DefaultPrincipalData;
import org.apache.kafka.common.network.SaslChannelBuilder;
import org.apache.kafka.common.network.SslChannelBuilder;
import org.apache.kafka.common.protocol.ByteBufferAccessor;
import org.apache.kafka.common.protocol.Message;
import org.apache.kafka.common.protocol.MessageUtil;
import org.apache.kafka.common.protocol.Readable;
import org.apache.kafka.common.security.auth.AuthenticationContext;
import org.apache.kafka.common.security.auth.ConfluentPrincipal;
import org.apache.kafka.common.security.auth.KafkaPrincipal;
import org.apache.kafka.common.security.auth.KafkaPrincipalBuilder;
import org.apache.kafka.common.security.auth.KafkaPrincipalSerde;
import org.apache.kafka.common.security.auth.SaslAuthenticationContext;
import org.apache.kafka.common.security.authenticator.DefaultKafkaPrincipalBuilder;
import org.apache.kafka.common.security.kerberos.KerberosShortNamer;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;
import org.apache.kafka.common.security.ssl.SslPrincipalMapper;
import org.apache.kafka.common.utils.ConfigUtils;
import org.jose4j.jwt.JwtClaims;
import org.jose4j.jwt.MalformedClaimException;
import org.jose4j.jwt.consumer.JwtConsumer;
import org.jose4j.jwt.consumer.JwtConsumerBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Principal builder and principal serde that derives Kafka principals from SASL/OAUTHBEARER JWT claims.
 *
 * <p>For a {@link SaslAuthenticationContext} whose mechanism is {@code OAUTHBEARER}, the negotiated
 * token's claims are parsed in order to:
 * <ol>
 *   <li>optionally check the {@code cp_proxy} impersonation claim against the TLS peer principal name;</li>
 *   <li>attach group memberships, producing a {@link ConfluentPrincipal} when a groups claim is present
 *       (otherwise a plain {@code User} {@link KafkaPrincipal}).</li>
 * </ol>
 * Every other authentication context is delegated to a {@link DefaultKafkaPrincipalBuilder}.
 *
 * <p>Configuration keys read by {@link #configure(Map)}:
 * <ul>
 *   <li>{@code ssl.client.auth}: impersonation validation only runs when this resolves to {@code required};</li>
 *   <li>{@code confluent.oauth.groups.claim.name}: groups claim used for tokens not issued by
 *       {@code Confluent}; unset or blank disables group mapping for those tokens;</li>
 *   <li>{@code token.impersonation.validation}: enables impersonation validation (default {@code true}).</li>
 * </ul>
 */
public class OAuthKafkaPrincipalBuilder implements KafkaPrincipalBuilder, KafkaPrincipalSerde, Configurable {
    /** JWT claim whose value must equal the TLS peer principal name when impersonation validation is enforced. */
    public static final String CONFLUENT_IMPERSONATED_BY_CLAIM_NAME = "cp_proxy";
    private static final Logger log = LoggerFactory.getLogger(OAuthKafkaPrincipalBuilder.class);
    /** Issuer for which {@link #CONFLUENT_GROUPS_CLAIM_NAME} is used instead of the configured groups claim. */
    private static final String CONFLUENT_ISSUER = "Confluent";
    /** SASL negotiated-property key under which the OAUTHBEARER server exposes the token. */
    private static final String OAUTH_NEGOTIATED_TOKEN_PROPERTY_KEY = "OAUTHBEARER.token";
    private static final String CONFLUENT_GROUPS_CLAIM_NAME = "groups";
    private static final String SSL_CLIENT_AUTH_CONFIG = "ssl.client.auth";
    private static final String OAUTH_GROUPS_CLAIM_NAME_CONFIG = "confluent.oauth.groups.claim.name";
    private static final String IMPERSONATION_VALIDATION_CONFIG = "token.impersonation.validation";
    /** The only {@link DefaultPrincipalData} version this serde reads and writes. */
    private static final short PRINCIPAL_DATA_VERSION = 0;

    /** Groups claim for non-Confluent issuers; empty means "no group mapping". */
    private String oauthGroupsClaimName = "";
    /*
     * Claims parser only: signature verification, the signature requirement, and all claim
     * validators are disabled. NOTE(review): this assumes the token has already been validated by
     * the SASL/OAUTHBEARER server before build() is invoked; confirm with the server-side callback.
     */
    private final JwtConsumer jwtConsumer = new JwtConsumerBuilder()
            .setSkipSignatureVerification()
            .setDisableRequireSignature()
            .setSkipAllValidators()
            .build();
    private SslClientAuth clientAuth;
    private SslPrincipalMapper sslPrincipalMapper;
    private DefaultKafkaPrincipalBuilder defaultKafkaPrincipalBuilder;
    private boolean isImpersonationTokenValidationEnabled;

    /**
     * Reads SSL client-auth, groups-claim and impersonation settings and creates the fallback
     * {@link DefaultKafkaPrincipalBuilder} from the Kerberos/SSL principal-mapping configs.
     *
     * @param configs broker/channel configuration
     */
    public void configure(Map<String, ?> configs) {
        KerberosShortNamer kerberosShortNamer = SaslChannelBuilder.createKerberosShortNamerFromConfigs(configs);
        this.clientAuth = SslClientAuth.forConfig((String) configs.get(SSL_CLIENT_AUTH_CONFIG));
        this.sslPrincipalMapper = SslChannelBuilder.createSslPrincipalMapperFromConfigs(configs);
        this.defaultKafkaPrincipalBuilder = new DefaultKafkaPrincipalBuilder(kerberosShortNamer, this.sslPrincipalMapper);
        this.oauthGroupsClaimName = getConfiguredOauthGroupsClaimName(configs);
        // ConfigUtils.getBoolean takes a Map<String, Object>; copy to widen the wildcard value type.
        Map<String, Object> booleanConfigs = new HashMap<>(configs);
        this.isImpersonationTokenValidationEnabled =
                ConfigUtils.getBoolean(booleanConfigs, IMPERSONATION_VALIDATION_CONFIG, true);
    }

    /** Returns the trimmed configured groups claim name, or {@code ""} when it is not set. */
    private String getConfiguredOauthGroupsClaimName(Map<String, ?> configs) {
        String configuredGroupsClaimName = (String) configs.get(OAUTH_GROUPS_CLAIM_NAME_CONFIG);
        return configuredGroupsClaimName == null ? "" : configuredGroupsClaimName.trim();
    }

    /**
     * Builds the principal for an authenticated connection.
     *
     * @param context authentication context of the connection
     * @return an OAuth-derived principal for SASL/OAUTHBEARER, otherwise the default builder's principal
     * @throws KafkaException if the OAuth token is missing or unreadable, the impersonation check
     *         fails, or the token cannot be mapped to a principal
     * @throws NullPointerException if this builder has not been configured and the default builder is needed
     */
    public KafkaPrincipal build(AuthenticationContext context) {
        if (context instanceof SaslAuthenticationContext) {
            SaslAuthenticationContext saslContext = (SaslAuthenticationContext) context;
            SaslServer saslServer = saslContext.server();
            if ("OAUTHBEARER".equals(saslServer.getMechanismName())) {
                OAuthBearerToken token =
                        (OAuthBearerToken) saslServer.getNegotiatedProperty(OAUTH_NEGOTIATED_TOKEN_PROPERTY_KEY);
                if (token == null) {
                    // Without this guard the null would surface as an NPE from inside the error path.
                    throw new KafkaException("No OAuthBearer token was negotiated for the SASL/OAUTHBEARER session");
                }
                JwtClaims jwtClaims = getJwtClaims(token);
                validateImpersonationIdentity(token, jwtClaims, saslContext);
                return applyOAuthBearerPrincipalMapper(token, jwtClaims);
            }
        }
        return defaultBuilder().build(context);
    }

    /** Returns the configured fallback builder, failing fast if {@link #configure(Map)} was not called. */
    private DefaultKafkaPrincipalBuilder defaultBuilder() {
        return Objects.requireNonNull(this.defaultKafkaPrincipalBuilder,
                "Principal builder instance has not yet been configured");
    }

    /** Parses the token's compact JWT into claims, wrapping any parse failure in a {@link KafkaException}. */
    private JwtClaims getJwtClaims(OAuthBearerToken token) {
        try {
            return this.jwtConsumer.processToClaims(token.value());
        } catch (Exception e) {
            throw new KafkaException("Failed to read OAuthBearer token for '" + token.principalName() + "'", e);
        }
    }

    /**
     * Checks the {@code cp_proxy} claim against the TLS peer principal, when applicable.
     *
     * <p>The check is skipped when the token has no such claim, when the connection has no SSL
     * session, when {@code ssl.client.auth} is not {@code required}, or when validation is disabled.
     * NOTE(review): a token carrying the claim on a non-TLS connection is therefore accepted without
     * any check. Confirm that this is intended.
     *
     * @throws KafkaException if the claim does not match the peer principal name
     */
    private void validateImpersonationIdentity(OAuthBearerToken token, JwtClaims claims, SaslAuthenticationContext context) {
        if (!claims.hasClaim(CONFLUENT_IMPERSONATED_BY_CLAIM_NAME) || !context.sslSession().isPresent()) {
            return;
        }
        if (!SslClientAuth.REQUIRED.equals(this.clientAuth)) {
            log.debug("Skipping impersonation identity validation for:{} as client auth is not required.",
                    token.principalName());
            return;
        }
        if (!this.isImpersonationTokenValidationEnabled) {
            log.debug("Skipping impersonation identity validation for:{} as it's disabled.", token.principalName());
            return;
        }
        SSLSession sslSession = context.sslSession().get();
        if (!impersonationMatch(sslSession, token, claims)) {
            throw new KafkaException("Impersonation identity mismatch for '" + token.principalName() + "'");
        }
    }

    /**
     * Returns whether the TLS peer principal's name equals the {@code cp_proxy} claim value.
     *
     * @throws KafkaException if the peer principal or the claim value cannot be read
     */
    private boolean impersonationMatch(SSLSession sslSession, OAuthBearerToken token, JwtClaims claims) {
        try {
            Principal sslPrincipal = sslSession.getPeerPrincipal();
            String sslPrincipalName = sslPrincipal.getName();
            String impersonatedBy = claims.getStringClaimValue(CONFLUENT_IMPERSONATED_BY_CLAIM_NAME);
            if (!sslPrincipalName.equals(impersonatedBy)) {
                log.error("Impersonation identity mismatch for '{}' - expected '{}', but got '{}'",
                        token.principalName(), impersonatedBy, sslPrincipalName);
                return false;
            }
            return true;
        } catch (Exception e) {
            throw new KafkaException("Failed to validate impersonation identity for '" + token.principalName() + "'", e);
        }
    }

    /**
     * Maps the token to a principal.
     *
     * <p>The groups claim is {@code groups} for the {@code Confluent} issuer; for any other issuer it
     * is the configured claim name. If that claim is present, a {@link ConfluentPrincipal} carrying
     * the groups is returned. Otherwise the result is a {@code User} {@link KafkaPrincipal} named
     * after the token's principal.
     *
     * @throws KafkaException if the groups claim is malformed or mapping otherwise fails
     */
    private KafkaPrincipal applyOAuthBearerPrincipalMapper(OAuthBearerToken oauthBearerToken, JwtClaims claims) {
        try {
            String principalName = oauthBearerToken.principalName();
            // Constant-first comparison: a missing issuer must not abort the mapping.
            String groupsClaimName = CONFLUENT_ISSUER.equals(getJwtIssuerOrEmpty(claims))
                    ? CONFLUENT_GROUPS_CLAIM_NAME
                    : this.oauthGroupsClaimName;
            if (!groupsClaimName.isEmpty() && claims.hasClaim(groupsClaimName)) {
                List<String> groups = claims.getStringListClaimValue(groupsClaimName);
                log.debug("Creating ConfluentPrincipal for '{}' with groups: '{}'", principalName, groups);
                return new ConfluentPrincipal(KafkaPrincipal.USER_TYPE, principalName, principalName,
                        Optional.empty(), false, new HashSet<>(groups));
            }
            return new KafkaPrincipal(KafkaPrincipal.USER_TYPE, principalName);
        } catch (Exception e) {
            throw new KafkaException("Failed to map OAuthBearer token to ConfluentPrincipal for '"
                    + oauthBearerToken.principalName() + "'", e);
        }
    }

    /** Returns the JWT {@code iss} claim, or {@code ""} when it is absent or not a string. */
    private String getJwtIssuerOrEmpty(JwtClaims jwtClaims) {
        try {
            String issuer = jwtClaims.getIssuer();
            // getIssuer() yields null when the claim is absent; honour the "OrEmpty" contract.
            return issuer == null ? "" : issuer;
        } catch (MalformedClaimException ignored) {
            // A non-string issuer is treated the same as a missing one.
            return "";
        }
    }

    /**
     * Serializes a principal as version-prefixed {@link DefaultPrincipalData}.
     *
     * <p>{@link ConfluentPrincipal}s are written here, including any non-empty groups. All other
     * principals are handled by the default builder.
     */
    public byte[] serialize(KafkaPrincipal principal) throws SerializationException {
        if (!(principal instanceof ConfluentPrincipal)) {
            return defaultBuilder().serialize(principal);
        }
        ConfluentPrincipal confluentPrincipal = (ConfluentPrincipal) principal;
        DefaultPrincipalData data = new DefaultPrincipalData()
                .setType(principal.getPrincipalType())
                .setName(principal.getName())
                .setTokenAuthenticated(principal.tokenAuthenticated());
        if (confluentPrincipal.getGroups() != null && !confluentPrincipal.getGroups().isEmpty()) {
            data.setGroups(new ArrayList<>(confluentPrincipal.getGroups()));
        }
        return MessageUtil.toVersionPrefixedBytes(PRINCIPAL_DATA_VERSION, data);
    }

    /**
     * Reverses {@link #serialize(KafkaPrincipal)}.
     *
     * <p>Data that carries groups becomes a {@link ConfluentPrincipal}. Data without groups is
     * handed to the default builder.
     *
     * @throws SerializationException if the bytes are too short, have an unsupported version,
     *         cannot be parsed, or contain trailing bytes
     */
    public KafkaPrincipal deserialize(byte[] bytes) throws SerializationException {
        if (bytes == null || bytes.length < Short.BYTES) {
            throw new SerializationException("Failed to deserialize principal: expected at least "
                    + Short.BYTES + " bytes for the version prefix");
        }
        ByteBuffer buffer = ByteBuffer.wrap(bytes);
        short version = buffer.getShort();
        if (version != PRINCIPAL_DATA_VERSION) {
            throw new SerializationException("Invalid principal data version " + version);
        }
        DefaultPrincipalData data;
        try {
            data = new DefaultPrincipalData(new ByteBufferAccessor(buffer), version);
        } catch (RuntimeException e) {
            throw new SerializationException("Failed to deserialize principal", e);
        }
        if (buffer.hasRemaining()) {
            throw new SerializationException("Failed to deserialize principal: "
                    + buffer.remaining() + " bytes remaining after parsing");
        }
        if (data.groups() != null && !data.groups().isEmpty()) {
            return new ConfluentPrincipal(data.type(), data.name(), data.name(), Optional.empty(),
                    data.tokenAuthenticated(), new HashSet<>(data.groups()));
        }
        return defaultBuilder().deserialize(bytes);
    }
}

