/*
 * Decompiled with CFR 0.152.
 */
package io.confluent.security.auth.provider.oauth;

import io.confluent.kafka.common.multitenant.oauth.OAuthBearerJwsToken;
import io.confluent.kafka.multitenant.BasePhysicalClusterMetadata;
import io.confluent.kafka.multitenant.KafkaLogicalClusterMetadata;
import io.confluent.kafka.server.plugins.auth.SniValidationMode;
import io.confluent.kafka.server.plugins.auth.TrafficNetworkIdAuthenticator;
import io.confluent.kafka.server.plugins.auth.TrafficNetworkIdValidationMode;
import io.confluent.kafka.server.plugins.auth.oauth.JwtAuthenticatorConfig;
import io.confluent.kafka.traffic.TopicBasedTrafficNetworkIdRoutesStore;
import io.confluent.kafka.util.ClientContext;
import io.confluent.security.auth.metadata.AuthStore;
import io.confluent.security.authentication.AdmissionController;
import io.confluent.security.authentication.AuthenticationException;
import io.confluent.security.authentication.Authenticator;
import io.confluent.security.authentication.credential.BearerCredential;
import io.confluent.security.authentication.oauthbearer.Claims;
import io.confluent.security.authentication.oauthbearer.JwtAuthenticator;
import io.confluent.security.config.ConfigurationException;
import io.confluent.security.policyapi.engine.PolicyEngine;
import io.confluent.security.policyapi.engine.TrustPolicyEngine;
import io.confluent.security.trustservice.store.TrustCache;
import io.confluent.security.trustservice.store.data.IdentityPool;
import io.confluent.security.util.SecurityContext;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.auth.login.AppConfigurationEntry;
import org.apache.kafka.common.network.CCloudTrafficType;
import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
import org.apache.kafka.common.security.authenticator.PathAwareSniHostName;
import org.apache.kafka.common.security.oauthbearer.CommonExtensionsValidatorCallback;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerExtensionsValidatorCallback;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerValidatorCallback;
import org.apache.kafka.common.security.oauthbearer.PreTokenValidationExtensionsValidatorCallback;
import org.apache.kafka.server.traffic.TrafficNetworkIdRoutes;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Server-side SASL/OAUTHBEARER callback handler that validates a bearer token and the SASL
 * extensions accompanying it.
 *
 * <p>Three callback types are handled (see {@link #handle(Callback[])}):
 * <ul>
 *   <li>{@link PreTokenValidationExtensionsValidatorCallback} - checks the {@code logicalCluster}
 *       and {@code identityPoolId} extensions before the token is verified and copies the pool and
 *       provider ids into the callback context;</li>
 *   <li>{@link OAuthBearerValidatorCallback} - authenticates the token through the
 *       {@link AdmissionController};</li>
 *   <li>{@link OAuthBearerExtensionsValidatorCallback} - validates the extensions against the
 *       logical cluster metadata, network id routes, SNI host name and identity pool.</li>
 * </ul>
 *
 * <p>{@link #configure(Map, String, List)} must complete successfully before
 * {@link #handle(Callback[])} is called.
 */
public class EnhancedOAuthBearerValidatorCallbackHandler implements AuthenticateCallbackHandler {
    private static final Logger log = LoggerFactory.getLogger(EnhancedOAuthBearerValidatorCallbackHandler.class);

    /** Generic status reported to the client when no more specific reason code is available. */
    private static final String AUTH_ERROR_MESSAGE = "Authentication failed";
    /** Reason code that is never exposed to clients; it is replaced by {@link #AUTH_ERROR_MESSAGE}. */
    private static final String GENERIC_REASON_CODE = "AUTHENTICATION_EXCEPTION_OCCURRED";
    /** Issuer value (compared case-insensitively after trimming) that enables the extra checks. */
    private static final String CONFLUENT_ISSUER = "Confluent";

    private static final String LOGICAL_CLUSTER_EXTENSION = "logicalCluster";
    private static final String IDENTITY_POOL_EXTENSION = "identityPoolId";
    private static final String SNI_HOST_NAME_EXTENSION = "__confluent_sni_broker_host_name";
    private static final String NETWORK_ID_EXTENSION = "__confluent_traffic_network_id";
    private static final String TRAFFIC_TYPE_EXTENSION = "__confluent_ccloud_traffic_type";

    private static final String OAUTH_IDENTITY_PROVIDER_ID_PROPERTY_KEY = "providerId";
    private static final String OAUTH_IDENTITY_PROPERTY_KEY = "identity";
    private static final String OAUTH_ORGANIZATION_ID_PROPERTY_KEY = "organizationId";

    private AdmissionController admissionController;
    private BasePhysicalClusterMetadata<KafkaLogicalClusterMetadata> clusterMetadata;
    private SniValidationMode mode;
    private String networkIdValidationModeJaasConfigEntry;
    private String sessionUuid;
    /** Set at the very end of {@link #configure(Map, String, List)}; guards {@link #handle(Callback[])}. */
    private boolean configured = false;

    /**
     * Initialises the handler from the broker configuration and the single JAAS entry.
     *
     * @param configs broker configuration; must contain a non-empty {@code broker.session.uuid}
     * @param saslMechanism must be {@code OAUTHBEARER}
     * @param jaasConfigEntries must contain exactly one non-null entry
     * @throws IllegalArgumentException if the mechanism or the JAAS entries are not as expected
     * @throws ConfigurationException if the session UUID is missing or no cluster metadata is
     *         registered for it
     */
    // The module options map stays raw because the parameter type expected by
    // JwtAuthenticatorConfig.newInstance is not visible from this file.
    @SuppressWarnings({"rawtypes", "unchecked"})
    @Override
    public void configure(Map<String, ?> configs, String saslMechanism, List<AppConfigurationEntry> jaasConfigEntries) {
        if (!"OAUTHBEARER".equals(saslMechanism)) {
            throw new IllegalArgumentException(String.format("Unexpected SASL mechanism: %s", saslMechanism));
        }
        if (Objects.requireNonNull(jaasConfigEntries).size() != 1 || jaasConfigEntries.get(0) == null) {
            throw new IllegalArgumentException(String.format("Must supply exactly 1 non-null JAAS mechanism configuration (size was %d)", jaasConfigEntries.size()));
        }
        HashMap moduleOptions = new HashMap(jaasConfigEntries.get(0).getOptions());
        JwtAuthenticator jwtAuthenticator = JwtAuthenticatorConfig.newInstance(moduleOptions).generateConfig(configs);

        Object uuid = configs.get("broker.session.uuid");
        if (uuid == null || uuid.toString().isEmpty()) {
            throw new ConfigurationException("Broker session UUID must be set in the Kafka config!");
        }
        this.sessionUuid = uuid.toString();
        this.clusterMetadata = BasePhysicalClusterMetadata.getInstance(this.sessionUuid);
        if (this.clusterMetadata == null) {
            throw new ConfigurationException("Could not get a PhysicalClusterMetadata instance with broker session UUID " + uuid);
        }

        this.mode = SniValidationMode.fromString((String)moduleOptions.get("sni_host_name_validation_mode"));
        this.networkIdValidationModeJaasConfigEntry = (String)moduleOptions.get("traffic_network_id_validation_mode");
        // The trust cache is resolved lazily on each use rather than captured here.
        this.admissionController = new AdmissionController((Authenticator)jwtAuthenticator, () -> {
            AuthStore store = this.authStore();
            return store.trustCache();
        }, (PolicyEngine)new TrustPolicyEngine());
        this.configured = true;
    }

    /**
     * Dispatches each callback to its handler.
     *
     * <p>{@link PreTokenValidationExtensionsValidatorCallback} is tested before
     * {@link OAuthBearerExtensionsValidatorCallback}; the order of these checks is kept from the
     * original implementation.
     *
     * @throws IllegalStateException if the handler has not been configured
     * @throws UnsupportedCallbackException for any callback type other than the three supported ones
     */
    @Override
    public void handle(Callback[] callbacks) throws UnsupportedCallbackException {
        if (!this.configured) {
            throw new IllegalStateException("Callback handler not configured");
        }
        for (Callback callback : callbacks) {
            if (callback instanceof OAuthBearerValidatorCallback) {
                this.handleValidatorCallback((OAuthBearerValidatorCallback)callback);
            } else if (callback instanceof PreTokenValidationExtensionsValidatorCallback) {
                this.handlePreTokenValidationCallback((PreTokenValidationExtensionsValidatorCallback)callback);
            } else if (callback instanceof OAuthBearerExtensionsValidatorCallback) {
                this.handleExtensionsCallback((OAuthBearerExtensionsValidatorCallback)callback);
            } else {
                throw new UnsupportedCallbackException(callback);
            }
        }
    }

    /** No resources are held by this handler, so there is nothing to release. */
    @Override
    public void close() {
    }

    /**
     * Validates the extensions that can be checked before the token itself is verified.
     *
     * <p>Fails the callback if the logical cluster extension is missing/empty or if the supplied
     * identity pool id is unknown; otherwise records the pool id and provider id in the callback
     * context.
     */
    private void handlePreTokenValidationCallback(PreTokenValidationExtensionsValidatorCallback callback) {
        String logicalCluster = (String)callback.inputExtensions().map().get(LOGICAL_CLUSTER_EXTENSION);
        String identityPoolId = (String)callback.inputExtensions().map().get(IDENTITY_POOL_EXTENSION);
        if (!this.doesClusterExtensionExist(callback, logicalCluster)) {
            return;
        }
        if (identityPoolId == null) {
            return;
        }
        TrustCache cache = Objects.requireNonNull(this.authStore().trustCache(), "Trust cache is not available");
        IdentityPool pool = cache.identityPool(identityPoolId);
        if (pool == null) {
            AuthenticationException exe = new AuthenticationException(String.format("Token precheck failed - unknown Identity Pool %s.", identityPoolId), "IDENTITY_POOL_NOT_FOUND");
            this.handleExtensionError(callback, exe.getMessage(), IDENTITY_POOL_EXTENSION, exe.reasonCode());
            return;
        }
        callback.context().add(IDENTITY_POOL_EXTENSION, (Object)pool.poolId());
        callback.context().add(OAUTH_IDENTITY_PROVIDER_ID_PROPERTY_KEY, (Object)pool.providerId());
    }

    /**
     * Authenticates the bearer token and stores the resulting token on the callback.
     *
     * <p>Tokens whose issuer is {@code Confluent} must not carry an {@code aud} claim. Any
     * {@link AuthenticationException} is converted into a callback error whose status is the
     * normalised reason code.
     */
    private void handleValidatorCallback(OAuthBearerValidatorCallback callback) {
        try {
            String tokenValue = callback.tokenValue();
            if (tokenValue == null) {
                throw new AuthenticationException("Callback missing required token value", "TOKEN_VALUE_ABSENT");
            }
            OAuthBearerToken token = this.processToken(tokenValue, callback.context());
            if (token instanceof OAuthBearerJwsToken) {
                OAuthBearerJwsToken jwsToken = (OAuthBearerJwsToken)token;
                // A token without an issuer is treated as not Confluent-issued (same rule as in
                // handleExtensionsCallback) instead of failing with a NullPointerException.
                if (isConfluentIssued(jwsToken) && !this.checkAudClaim(jwsToken, callback)) {
                    return;
                }
            }
            callback.token(token);
            log.debug("Successfully validated token");
        }
        catch (AuthenticationException e) {
            log.info("Failed to verify OAuth JWT token", e);
            // Normalise so a null or blank reason code can never reach callback.error.
            callback.error(normalizeReasonCode(e.reasonCode()), null, null);
        }
    }

    /**
     * Validates the SASL extensions against the already-validated token.
     *
     * <p>Checks run in this order and the first failure stops processing: identity pool lookup,
     * presence of the logical cluster extension, logical cluster metadata, network id, (for
     * Confluent-issued tokens only) SNI host name and organisation ownership, and finally the
     * identity pool trust policy. On success the {@code logicalCluster} extension is marked valid.
     */
    private void handleExtensionsCallback(OAuthBearerExtensionsValidatorCallback callback) {
        OAuthBearerJwsToken token = (OAuthBearerJwsToken)callback.token();
        String logicalCluster = (String)callback.inputExtensions().map().get(LOGICAL_CLUSTER_EXTENSION);
        String identityPoolId = (String)callback.inputExtensions().map().get(IDENTITY_POOL_EXTENSION);
        String sniHostName = (String)callback.inputExtensions().map().get(SNI_HOST_NAME_EXTENSION);
        String networkId = (String)callback.inputExtensions().map().get(NETWORK_ID_EXTENSION);
        String trafficTypeName = (String)callback.inputExtensions().map().get(TRAFFIC_TYPE_EXTENSION);

        if (!this.addIdentityInformation(identityPoolId, token.jwtClaims(), callback)) {
            return;
        }
        if (!this.doesClusterExtensionExist(callback, logicalCluster)) {
            return;
        }

        KafkaLogicalClusterMetadata metadata;
        try {
            metadata = this.checkClusterMetadataMatched(callback, token, logicalCluster);
        }
        catch (IllegalStateException e) {
            this.reportErrorGettingMetadata(callback, e);
            return;
        }
        if (metadata == null) {
            return;
        }

        if (!this.networkIdMatches(callback, logicalCluster, networkId, trafficTypeName)) {
            return;
        }

        // SNI and organisation checks apply only to Confluent-issued tokens. The && keeps the
        // original short-circuit order: the org check runs only when the SNI check passed.
        if (isConfluentIssued(token)
                && !(this.checkSniHostNameMatched(callback, logicalCluster, sniHostName, this.mode)
                        && this.checkLogicalClusterBelongToOrg(callback, token, metadata))) {
            return;
        }

        if (identityPoolId != null) {
            log.debug("Start validate identity pool trust policy based on token claims: {}", token.jwtClaims());
            try {
                // NOTE(review): the second valid() call uses the same extension name and presumably
                // replaces the value recorded by the first one - confirm this is intended.
                callback.valid(IDENTITY_POOL_EXTENSION, this.admissionController.assumePrincipal(token.jwtClaims(), identityPoolId, metadata.organizationId()));
                callback.valid(IDENTITY_POOL_EXTENSION, identityPoolId);
            }
            catch (AuthenticationException e) {
                this.handleExtensionError(callback, e.getMessage(), IDENTITY_POOL_EXTENSION, e.reasonCode());
                return;
            }
            catch (IllegalArgumentException e) {
                this.handleExtensionError(callback, e.getMessage(), IDENTITY_POOL_EXTENSION, "FAILED_TO_READ_CLAIMS");
                return;
            }
        }
        callback.valid(LOGICAL_CLUSTER_EXTENSION, logicalCluster);
        log.debug("Successfully authenticated for user: {} (cluster: {})", token.principalName(), logicalCluster);
    }

    /**
     * Rejects tokens that carry an {@code aud} claim.
     *
     * @return {@code true} if the token has no {@code aud} claim; {@code false} after reporting
     *         {@code AUD_CLAIM_MISMATCH} on the callback
     */
    private boolean checkAudClaim(OAuthBearerJwsToken token, OAuthBearerValidatorCallback callback) {
        if (token.jwtClaims().containsKey("aud")) {
            log.info("Expecting no aud claim got: {}", token.jwtClaims().get("aud"));
            callback.error("AUD_CLAIM_MISMATCH", null, null);
            return false;
        }
        return true;
    }

    /**
     * Verifies that the token's {@code orgResourceId} claim equals the organisation id of the
     * logical cluster.
     *
     * @return {@code true} on match; {@code false} after reporting
     *         {@code ORG_ID_CLUSTER_ID_MISMATCH} on the callback
     */
    private boolean checkLogicalClusterBelongToOrg(OAuthBearerExtensionsValidatorCallback callback, OAuthBearerJwsToken token, KafkaLogicalClusterMetadata metadata) {
        String orgResourceId = (String)token.jwtClaims().get("orgResourceId");
        if (orgResourceId != null && orgResourceId.equals(metadata.organizationId())) {
            return true;
        }
        String errorMessage = String.format("The principal %s's logical cluster %s is not belong to the org in the token (%s).", token.principalName(), metadata.logicalClusterId(), orgResourceId);
        this.handleExtensionError(callback, errorMessage, LOGICAL_CLUSTER_EXTENSION, "ORG_ID_CLUSTER_ID_MISMATCH");
        return false;
    }

    /**
     * Publishes identity pool details (provider id, subject identity, organisation id) on the
     * callback when an identity pool id was supplied.
     *
     * @return {@code true} if there was no identity pool id or the pool was found; {@code false}
     *         after reporting {@code IDENTITY_POOL_NOT_FOUND} for an unknown pool
     */
    private boolean addIdentityInformation(String identityPoolId, Map<String, Object> claims, OAuthBearerExtensionsValidatorCallback callback) {
        if (identityPoolId == null) {
            return true;
        }
        IdentityPool pool = this.authStore().trustCache().identityPool(identityPoolId);
        if (pool == null) {
            // An unknown pool is reported as an extension error instead of being dereferenced.
            String errorMessage = String.format("Unknown Identity Pool %s.", identityPoolId);
            this.handleExtensionError(callback, errorMessage, IDENTITY_POOL_EXTENSION, "IDENTITY_POOL_NOT_FOUND");
            return false;
        }
        callback.valid(IDENTITY_POOL_EXTENSION, identityPoolId);
        String providerId = pool.providerId();
        // A blank provider id is published as null.
        boolean blankProviderId = providerId != null && providerId.trim().isEmpty();
        callback.data(OAUTH_IDENTITY_PROVIDER_ID_PROPERTY_KEY, blankProviderId ? null : providerId);
        callback.data(OAUTH_IDENTITY_PROPERTY_KEY, (String)claims.get(pool.subjectClaim()));
        callback.data(OAUTH_ORGANIZATION_ID_PROPERTY_KEY, pool.orgId());
        return true;
    }

    /** Logs the metadata lookup failure and fails the {@code logicalCluster} extension. */
    private void reportErrorGettingMetadata(OAuthBearerExtensionsValidatorCallback callback, IllegalStateException e) {
        log.error("Could not get physical cluster metadata to validate the token. ", e);
        callback.errorMessage("Could not get cluster metadata to validate the token");
        callback.error(LOGICAL_CLUSTER_EXTENSION, AUTH_ERROR_MESSAGE);
    }

    /**
     * Looks up the metadata of the requested logical cluster.
     *
     * @return the metadata, or {@code null} after reporting {@code CLUSTER_NOT_FOUND} when no
     *         metadata is returned for the cluster
     */
    private KafkaLogicalClusterMetadata checkClusterMetadataMatched(OAuthBearerExtensionsValidatorCallback callback, OAuthBearerJwsToken token, String logicalCluster) {
        KafkaLogicalClusterMetadata metadata = (KafkaLogicalClusterMetadata)this.clusterMetadata.metadata(logicalCluster);
        if (metadata != null) {
            return metadata;
        }
        if (this.clusterMetadata.logicalClusterIdsIncludingStale().contains(logicalCluster)) {
            log.info("Failing OAuth authentication because the metadata for the logical cluster {} is stale.", logicalCluster);
        }
        String errorMessage = String.format("The principal %s's logical cluster %s is not hosted on this broker.", token.principalName(), logicalCluster);
        this.handleExtensionError(callback, errorMessage, LOGICAL_CLUSTER_EXTENSION, "CLUSTER_NOT_FOUND");
        return null;
    }

    /**
     * Ensures the logical cluster extension is present and non-empty.
     *
     * @return {@code true} if present; {@code false} after reporting
     *         {@code CLUSTER_ID_MISSING_OR_EMPTY} on the callback
     */
    private boolean doesClusterExtensionExist(CommonExtensionsValidatorCallback callback, String logicalCluster) {
        if (logicalCluster == null || logicalCluster.isEmpty()) {
            String errorMessage = "The logical cluster extension is missing or is empty";
            this.handleExtensionError(callback, errorMessage, LOGICAL_CLUSTER_EXTENSION, "CLUSTER_ID_MISSING_OR_EMPTY");
            return false;
        }
        return true;
    }

    /**
     * Checks the SNI host name against the logical cluster id using the given validation mode.
     *
     * @param sniHostName SNI host name extension value; may be {@code null}
     * @return {@code true} if the validation mode accepts the combination; {@code false} after
     *         reporting {@code SNI_ID_CLUSTER_ID_MISMATCH} on the callback
     */
    protected boolean checkSniHostNameMatched(OAuthBearerExtensionsValidatorCallback callback, String logicalClusterId, String sniHostName, SniValidationMode sniValidationMode) {
        Optional<PathAwareSniHostName> sniHostNameOptional = Optional.ofNullable(sniHostName).map(PathAwareSniHostName::new);
        Optional<String> sniClusterId = sniHostNameOptional.map(PathAwareSniHostName::logicalClusterId);
        if (sniValidationMode.sniHostNameMatches(logicalClusterId, sniClusterId, sniHostNameOptional)) {
            return true;
        }
        String errorMessage = String.format("The SNI cluster Id: %s doesn't match with logical cluster extension: %s.", sniClusterId.orElse("<empty>"), logicalClusterId);
        this.handleExtensionError(callback, errorMessage, SNI_HOST_NAME_EXTENSION, "SNI_ID_CLUSTER_ID_MISMATCH");
        return false;
    }

    /**
     * Delegates network id validation to a {@link TrafficNetworkIdAuthenticator} built from the
     * current routes and the validation mode for the traffic type.
     *
     * @param networkId network id extension value; may be {@code null}
     * @param trafficTypeName name of a {@link CCloudTrafficType} constant, or {@code null}
     * @return the authenticator's verdict; failures are reported as {@code NETWORK_ID_DISALLOWED}
     * @throws IllegalArgumentException if {@code trafficTypeName} is not a known traffic type
     */
    private boolean networkIdMatches(OAuthBearerExtensionsValidatorCallback callback, String logicalClusterId, String networkId, String trafficTypeName) {
        TrafficNetworkIdRoutes networkIdRoutes = this.loadNetworkIdRoutes();
        CCloudTrafficType trafficType = trafficTypeName != null ? CCloudTrafficType.valueOf(trafficTypeName) : null;
        TrafficNetworkIdValidationMode networkIdValidationMode = TrafficNetworkIdValidationMode.fromConfigs(trafficType, () -> this.networkIdValidationModeJaasConfigEntry);
        TrafficNetworkIdAuthenticator networkIdAuthenticator = new TrafficNetworkIdAuthenticator(networkIdRoutes, networkIdValidationMode, errorMessage -> this.handleExtensionError(callback, (String)errorMessage, NETWORK_ID_EXTENSION, "NETWORK_ID_DISALLOWED"));
        return networkIdAuthenticator.authenticate(Optional.ofNullable(networkId), logicalClusterId);
    }

    /**
     * Logs the error and records it on the callback against the given extension.
     *
     * @param reasonCode reason reported for the extension; {@code null}, blank or the generic
     *        reason code is replaced by {@link #AUTH_ERROR_MESSAGE}
     */
    private void handleExtensionError(CommonExtensionsValidatorCallback callback, String errorMessage, String invalidExtensionName, String reasonCode) {
        log.info(errorMessage);
        callback.errorMessage(errorMessage);
        callback.error(invalidExtensionName, normalizeReasonCode(reasonCode));
    }

    /**
     * Authenticates the compact JWS and wraps the resulting claims in an {@link OAuthBearerJwsToken}.
     *
     * <p>The token scope is the {@code orgResourceId} claim when present, otherwise empty.
     *
     * @throws AuthenticationException if the admission controller rejects the credential
     */
    OAuthBearerToken processToken(String jws, ClientContext context) throws AuthenticationException {
        Claims claims = this.admissionController.authenticate(new BearerCredential(jws), SecurityContext.fromMap((Map)context.getContextMap()));
        String orgClaim = (String)claims.claimValue("orgResourceId", String.class);
        Set<Object> scope = orgClaim != null ? Collections.singleton(orgClaim) : Collections.emptySet();
        return new OAuthBearerJwsToken(jws, scope, claims.expiresOn(), claims.subject(), Long.valueOf(claims.issuedAt()), claims.asMap(), claims.issuer());
    }

    /** Loads the network id routes for this session, or {@code null} if no routes store exists. */
    private TrafficNetworkIdRoutes loadNetworkIdRoutes() {
        TopicBasedTrafficNetworkIdRoutesStore routesLoader = TopicBasedTrafficNetworkIdRoutesStore.getInstance(this.sessionUuid);
        return routesLoader != null ? routesLoader.load() : null;
    }

    /**
     * Returns the {@link AuthStore} registered for this broker session.
     *
     * @throws NullPointerException if no store is registered for the session UUID
     */
    private AuthStore authStore() {
        return Objects.requireNonNull(AuthStore.getInstance(this.sessionUuid), "No AuthStore registered for the broker session UUID");
    }

    /** Returns {@code true} if the token has an issuer equal (ignoring case and padding) to {@code Confluent}. */
    private static boolean isConfluentIssued(OAuthBearerJwsToken token) {
        String issuer = token.issuer();
        return issuer != null && CONFLUENT_ISSUER.equalsIgnoreCase(issuer.trim());
    }

    /** Maps a {@code null}, blank or generic reason code to {@link #AUTH_ERROR_MESSAGE}. */
    private static String normalizeReasonCode(String reasonCode) {
        if (reasonCode == null || reasonCode.trim().isEmpty() || GENERIC_REASON_CODE.equals(reasonCode)) {
            return AUTH_ERROR_MESSAGE;
        }
        return reasonCode;
    }
}

