/*
 * Decompiled with CFR 0.152.
 */
package io.confluent.oidc.services;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Strings;
import io.confluent.common.security.auth.JwtPrincipal;
import io.confluent.http.client.ConfigurableHttpClient;
import io.confluent.oidc.config.OidcConfig;
import io.confluent.oidc.encryption.EncryptionHandler;
import io.confluent.oidc.entities.CheckDeviceAuthIdpErrorResponse;
import io.confluent.oidc.entities.CheckDeviceAuthRequest;
import io.confluent.oidc.entities.CheckDeviceAuthResponse;
import io.confluent.oidc.entities.DeviceAuthIdpResponse;
import io.confluent.oidc.entities.ExtendAuthResponse;
import io.confluent.oidc.entities.InitDeviceAuthResponse;
import io.confluent.oidc.entities.OidcAuthResponse;
import io.confluent.oidc.exceptions.AuthorizationResponseException;
import io.confluent.oidc.exceptions.DeviceAuthResponseException;
import io.confluent.oidc.exceptions.EncryptionFailedException;
import io.confluent.oidc.exceptions.InvalidCodeException;
import io.confluent.oidc.exceptions.InvalidDeviceKeyException;
import io.confluent.oidc.exceptions.InvalidStateException;
import io.confluent.oidc.exceptions.TokenResponseException;
import io.confluent.rest.RestConfig;
import io.confluent.rest.SslConfig;
import io.confluent.rest.SslFactory;
import io.confluent.security.auth.metadata.AuthStore;
import io.confluent.security.auth.store.data.RefreshTokenInfoKey;
import io.confluent.security.authentication.oidc.RefreshTokenInfo;
import io.confluent.security.authentication.oidc.TokenResponse;
import io.confluent.security.trustservice.store.TrustCache;
import io.confluent.security.trustservice.store.TrustWriter;
import io.confluent.tokenapi.entities.RefreshTokenRequest;
import io.confluent.tokenapi.exceptions.InvalidTokenException;
import io.confluent.tokenapi.services.TokenService;
import java.net.HttpURLConnection;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.security.NoSuchAlgorithmException;
import java.security.Principal;
import java.time.Duration;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.CompletionStage;
import javax.net.ssl.HttpsURLConnection;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;
import javax.ws.rs.client.Entity;
import javax.ws.rs.core.Cookie;
import javax.ws.rs.core.Form;
import javax.ws.rs.core.Response;
import javax.ws.rs.core.SecurityContext;
import javax.ws.rs.core.UriBuilder;
import javax.ws.rs.core.UriInfo;
import org.apache.commons.lang3.StringUtils;
import org.apache.kafka.common.Configurable;
import org.apache.kafka.common.config.ConfigException;
import org.apache.kafka.common.security.auth.KafkaPrincipal;
import org.apache.kafka.common.security.ssl.HostSslSocketFactory;
import org.eclipse.jetty.http.DateGenerator;
import org.eclipse.jetty.util.ssl.SslContextFactory;
import org.glassfish.jersey.client.ClientConfig;
import org.glassfish.jersey.client.HttpUrlConnectorProvider;
import org.glassfish.jersey.client.spi.ConnectorProvider;
import org.jose4j.http.Get;
import org.jose4j.http.SimpleGet;
import org.jose4j.jwa.AlgorithmConstraints;
import org.jose4j.jwk.HttpsJwks;
import org.jose4j.jwt.JwtClaims;
import org.jose4j.jwt.MalformedClaimException;
import org.jose4j.jwt.consumer.JwtConsumer;
import org.jose4j.jwt.consumer.JwtConsumerBuilder;
import org.jose4j.keys.resolvers.HttpsJwksVerificationKeyResolver;
import org.jose4j.keys.resolvers.VerificationKeyResolver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class OidcTokenService
implements Configurable {
    private static final Logger log = LoggerFactory.getLogger(OidcTokenService.class);
    /** Name of the custom token claim holding the epoch-second limit ("max expiry") checked before a token is extended. */
    public static final String TOKEN_MAX_EXP_CLAIM_NAME = "mex";
    /** Name of the custom token claim holding the session id assigned when the token is issued. */
    public static final String TOKEN_SESSION_ID = "ssid";
    // Collaborators assigned once in the constructor.
    private final TokenService tokenService;
    private final TrustWriter trustWriter;
    private final TrustCache trustCache;
    private final EncryptionHandler encryptionHandler;
    // Everything below is populated by configure(Map).
    private ConfigurableHttpClient configurableHttpClient;
    private String clientIdIdp;
    private String clientSecretIdp;
    private String subClaimName;
    private String groupsClaimName;
    private String scopeForGroupsClaim;
    private String idpIssuer;
    private boolean idpRefreshTokenEnabled;
    private URI jwksUriIdp;
    // May be null: read with getUriFromConfigOrNull.
    private URI deviceAuthEndpointUri;
    private URI authorizeEndpointBaseUri;
    private URI tokenEndpointBaseUri;
    private long sessionTokenLifetimeMs;
    private long sessionTokenMaxExtendabilityMs;
    private JwtConsumer jwtConsumer;
    // Stays null when no truststore path is configured (see setSslContextFactory).
    private SslContextFactory sslContextFactory;

    public OidcTokenService(TokenService tokenService, AuthStore authStore, EncryptionHandler encryptionHandler) {
        this.tokenService = tokenService;
        this.trustWriter = (TrustWriter)authStore.writer();
        this.trustCache = authStore.trustCache();
        this.encryptionHandler = encryptionHandler;
    }

    /**
     * Loads the OIDC settings from {@code props} and builds the JWT consumer and
     * HTTP client from them.
     *
     * @param props raw configuration; parsed through {@link OidcConfig}, and also read
     *              directly for the SSL settings
     */
    public void configure(Map<String, ?> props) {
        // SSL must be set up first: the JWT consumer and HTTP client built below read sslContextFactory.
        this.setSslContextFactory(props);
        OidcConfig oidcConfig = new OidcConfig(props);

        // Session limits.
        this.sessionTokenLifetimeMs = oidcConfig.getLong("confluent.oidc.session.token.expiry.ms");
        this.sessionTokenMaxExtendabilityMs = oidcConfig.getLong("confluent.oidc.session.max.timeout.ms");

        // Claim names and requested scope.
        this.subClaimName = oidcConfig.getString("confluent.oidc.idp.sub.claim.name");
        this.groupsClaimName = oidcConfig.getString("confluent.oidc.idp.groups.claim.name");
        this.scopeForGroupsClaim = oidcConfig.getString("confluent.oidc.idp.groups.claim.scope");

        // Client registration at the IdP.
        this.clientIdIdp = oidcConfig.getString("confluent.oidc.idp.client.id");
        this.clientSecretIdp = oidcConfig.getPassword("confluent.oidc.idp.client.secret").value();
        this.idpIssuer = oidcConfig.getString("confluent.oidc.idp.issuer");
        this.idpRefreshTokenEnabled = oidcConfig.getBoolean("confluent.oidc.idp.refresh.token.enabled");

        // IdP endpoints; only the device authorization endpoint may be absent.
        this.jwksUriIdp = this.getUriFromConfig(oidcConfig, "confluent.oidc.idp.jwks.endpoint.uri");
        this.authorizeEndpointBaseUri = this.getUriFromConfig(oidcConfig, "confluent.oidc.idp.authorize.base.endpoint.uri");
        this.deviceAuthEndpointUri = this.getUriFromConfigOrNull(oidcConfig, "confluent.oidc.idp.device.authorization.endpoint.uri");
        this.tokenEndpointBaseUri = this.getUriFromConfig(oidcConfig, "confluent.oidc.idp.token.base.endpoint.uri");

        String endpointIdentificationAlgorithm = (String)props.getOrDefault("confluent.idp.ssl.endpoint.identification.algorithm", null);
        this.jwtConsumer = this.createJwtConsumer(endpointIdentificationAlgorithm);
        this.configurableHttpClient = this.createConfigurableHttpClient(endpointIdentificationAlgorithm);
    }

    /**
     * Creates and starts {@link #sslContextFactory} when a truststore path is configured;
     * otherwise leaves it unset so that the default SSLContext is used.
     *
     * @param props raw configuration, parsed with the base {@link RestConfig} definition
     * @throws IllegalStateException if the factory cannot be started
     */
    private void setSslContextFactory(Map<String, ?> props) {
        SslConfig restSslConfig = new RestConfig(RestConfig.baseConfigDef(), props).getBaseSslConfig();
        boolean truststoreMissing = StringUtils.isAllBlank((CharSequence[])new CharSequence[]{restSslConfig.getTrustStorePath()});
        if (!truststoreMissing) {
            this.sslContextFactory = SslFactory.createSslContextFactory((SslConfig)restSslConfig);
            if (this.sslContextFactory.isRunning()) {
                return;
            }
            try {
                this.sslContextFactory.start();
            }
            catch (Exception startFailure) {
                throw new IllegalStateException("Failed to start SSL Context Factory. One of the reason could be wrong truststore and keystore password.", startFailure);
            }
            return;
        }
        log.info("Using default SSLContext for OIDC service since truststore is not configured.");
    }

    /**
     * Parses the string value stored under {@code key} as a {@link URI}.
     *
     * @param config parsed OIDC configuration
     * @param key configuration key whose value is the URI text
     * @return the parsed URI
     * @throws ConfigException if the value is not a syntactically valid URI
     */
    private URI getUriFromConfig(OidcConfig config, String key) {
        String rawUri = config.getString(key);
        try {
            return new URI(rawUri);
        }
        catch (URISyntaxException syntaxError) {
            String message = String.format("Invalid %s.", key);
            throw new ConfigException(message, (Object)syntaxError);
        }
    }

    /**
     * Like {@link #getUriFromConfig(OidcConfig, String)}, but tolerates an unset value.
     *
     * @param config parsed OIDC configuration
     * @param key configuration key whose value is the URI text
     * @return the parsed URI, or {@code null} when the configured value is {@code null}
     */
    private URI getUriFromConfigOrNull(OidcConfig config, String key) {
        return config.getString(key) == null ? null : this.getUriFromConfig(config, key);
    }

    /**
     * Builds the {@link JwtConsumer} used to validate id tokens, resolving signing keys
     * from the configured JWKS endpoint.
     *
     * <p>With a custom SSL context, JWKS requests use a socket factory from that context;
     * hostname verification is skipped when {@code sslEndpointIdentificationAlgorithm} is
     * blank. Without one, an https JWKS endpoint uses a socket factory from the default
     * SSLContext.
     *
     * @param sslEndpointIdentificationAlgorithm configured endpoint identification algorithm; may be null or blank
     * @return a consumer that checks the issuer, requires an expiration time, allows 30s of clock
     *         skew and skips default audience validation
     * @throws NullPointerException if the JWKS endpoint URI has not been configured
     */
    private JwtConsumer createJwtConsumer(String sslEndpointIdentificationAlgorithm) {
        // Fail early with a clear message instead of null-checking here and dereferencing unconditionally below.
        Objects.requireNonNull(this.jwksUriIdp, "JWKS endpoint URI must be configured before creating the JWT consumer");
        Get get = new Get();
        String jwksUriHostname = this.jwksUriIdp.getHost();
        if (Objects.nonNull(this.sslContextFactory) && Objects.nonNull(this.sslContextFactory.getSslContext())) {
            log.info("Setting up custom SSLContext for OIDC JWT authenticator.");
            HostSslSocketFactory sslSocketFactory = new HostSslSocketFactory(this.sslContextFactory.getSslContext().getSocketFactory(), jwksUriHostname, false);
            get.setSslSocketFactory((SSLSocketFactory)sslSocketFactory);
            if (StringUtils.isAllBlank((CharSequence[])new CharSequence[]{sslEndpointIdentificationAlgorithm})) {
                log.info("Provided ssl.endpoint.identification.algorithm={}. Skipping hostname verification in OIDC JWT authenticator.", (Object)sslEndpointIdentificationAlgorithm);
                get.setHostnameVerifier((s, sslSession) -> true);
            }
        } else if ("https".equalsIgnoreCase(this.jwksUriIdp.getScheme())) {
            // URI schemes are case-insensitive, and getScheme() is null for a relative URI.
            try {
                HostSslSocketFactory sslSocketFactory = new HostSslSocketFactory(SSLContext.getDefault().getSocketFactory(), jwksUriHostname, false);
                get.setSslSocketFactory((SSLSocketFactory)sslSocketFactory);
            }
            catch (NoSuchAlgorithmException e) {
                log.error("Error while getting default SSLContext: ", (Throwable)e);
                throw new RuntimeException(e);
            }
        }
        HttpsJwks httpsJwks = new HttpsJwks(this.jwksUriIdp.toString());
        httpsJwks.setSimpleHttpGet((SimpleGet)get);
        HttpsJwksVerificationKeyResolver keyResolver = new HttpsJwksVerificationKeyResolver(httpsJwks);
        return new JwtConsumerBuilder().setExpectedIssuer(true, this.idpIssuer).setJwsAlgorithmConstraints(AlgorithmConstraints.DISALLOW_NONE).setJweAlgorithmConstraints(AlgorithmConstraints.DISALLOW_NONE).setJweContentEncryptionAlgorithmConstraints(AlgorithmConstraints.DISALLOW_NONE).setSkipDefaultAudienceValidation().setRequireExpirationTime().setAllowedClockSkewInSeconds(30).setVerificationKeyResolver((VerificationKeyResolver)keyResolver).build();
    }

    /**
     * Builds the HTTP client used to call the IdP.
     *
     * <p>A custom SSL context is applied when one was configured, and hostname verification
     * is skipped in that case if {@code sslEndpointIdentificationAlgorithm} is blank.
     *
     * @param sslEndpointIdentificationAlgorithm configured endpoint identification algorithm; may be null or blank
     * @return the configured client
     */
    private ConfigurableHttpClient createConfigurableHttpClient(String sslEndpointIdentificationAlgorithm) {
        ConfigurableHttpClient.Builder builder = ConfigurableHttpClient.builder();
        boolean customSslContextAvailable = this.sslContextFactory != null && this.sslContextFactory.getSslContext() != null;
        if (customSslContextAvailable) {
            log.info("Setting up custom SSLContext for OIDC http user agent.");
            builder.sslContext(this.sslContextFactory.getSslContext());
            boolean skipHostnameVerification = StringUtils.isAllBlank((CharSequence[])new CharSequence[]{sslEndpointIdentificationAlgorithm});
            if (skipHostnameVerification) {
                log.info("Provided ssl.endpoint.identification.algorithm={}. Skipping hostname verification in OIDC http user agent.", (Object)sslEndpointIdentificationAlgorithm);
                builder.hostNameVerifier((hostname, session) -> true);
            }
        }
        ClientConfig clientConfig = this.getClientConfigForCustomSslSocketFactory();
        builder.clientConfig(clientConfig);
        return builder.build();
    }

    /**
     * Creates a Jersey {@link ClientConfig} whose connection factory installs a
     * {@link HostSslSocketFactory} on every https connection, built from the custom SSL
     * context when one is configured and from the default SSLContext otherwise.
     *
     * @return the client configuration
     */
    private ClientConfig getClientConfigForCustomSslSocketFactory() {
        HttpUrlConnectorProvider.ConnectionFactory connectionFactory = url -> {
            HttpURLConnection connection = (HttpURLConnection)url.openConnection();
            if (!url.getProtocol().equals("https")) {
                // Plain http connections are returned as opened.
                return connection;
            }
            HttpsURLConnection secureConnection = (HttpsURLConnection)connection;
            SSLContext context;
            if (this.sslContextFactory != null && this.sslContextFactory.getSslContext() != null) {
                context = this.sslContextFactory.getSslContext();
            } else {
                try {
                    context = SSLContext.getDefault();
                }
                catch (NoSuchAlgorithmException noDefaultContext) {
                    log.error("Error while getting default SSLContext: ", (Throwable)noDefaultContext);
                    throw new RuntimeException(noDefaultContext);
                }
            }
            HostSslSocketFactory socketFactory = new HostSslSocketFactory(context.getSocketFactory(), url.getHost(), false);
            secureConnection.setSSLSocketFactory((SSLSocketFactory)socketFactory);
            return connection;
        };
        ClientConfig config = new ClientConfig();
        config.connectorProvider((ConnectorProvider)new HttpUrlConnectorProvider().connectionFactory(connectionFactory));
        return config;
    }

    /**
     * Builds the IdP authorization request URI for the authorization-code flow together
     * with the cookie that carries the encrypted {@code state} value.
     *
     * @param uriInfo the incoming authenticate request
     * @param caller externally visible URI the request was made to; the callback URI is derived from it
     * @return a response holding the authorization URI and the state cookie header
     */
    public OidcAuthResponse getIdpAuthUri(UriInfo uriInfo, URI caller) {
        String redirectUri = this.getCallbackUri(uriInfo, caller);
        // state = random UUID followed by the callback URI; handleCallback recovers the URI from it.
        String state = UUID.randomUUID() + redirectUri;
        UriBuilder authUriBuilder = UriBuilder.fromUri((URI)this.authorizeEndpointBaseUri);
        authUriBuilder.queryParam("response_type", new Object[]{"code"});
        authUriBuilder.queryParam("client_id", new Object[]{this.clientIdIdp});
        authUriBuilder.queryParam("scope", new Object[]{this.getScopeParam(this.idpRefreshTokenEnabled)});
        authUriBuilder.queryParam("redirect_uri", new Object[]{redirectUri});
        authUriBuilder.queryParam("state", new Object[]{state});
        URI authUri = authUriBuilder.build(new Object[0]);
        String stateCookieHeader = this.getStateCookie(state);
        log.info("Returning auth uri: {}", (Object)authUri);
        return new OidcAuthResponse(authUri, Collections.singletonList(stateCookieHeader));
    }

    /**
     * Derives the authorization-code callback URI from the caller URI.
     *
     * <p>The caller's path has {@code /security/1.0/oidc/authenticate} removed, and the
     * request path with {@code authenticate} replaced by {@code authorization-code/callback}
     * is appended. Any query string is dropped.
     *
     * @param uriInfo the incoming request; supplies the request path
     * @param caller externally visible URI the request was made to
     * @return the callback URI as a string
     */
    private String getCallbackUri(UriInfo uriInfo, URI caller) {
        String callbackPath = uriInfo.getRequestUri().getPath().replace("authenticate", "authorization-code/callback");
        String callerBasePath = caller.getPath().replace("/security/1.0/oidc/authenticate", "");
        UriBuilder callbackBuilder = UriBuilder.fromUri((URI)caller);
        callbackBuilder.replacePath(callerBasePath);
        callbackBuilder.path(callbackPath);
        callbackBuilder.replaceQuery(null);
        URI callback = callbackBuilder.build(new Object[0]);
        log.info("CallbackUri for inputUri: " + uriInfo.getRequestUri() + " is: " + callback);
        return callback.toString();
    }

    /**
     * Assembles the space-separated {@code scope} value sent to the IdP.
     *
     * @param requestRefreshToken whether to add {@code offline_access}
     * @return {@code openid}, optionally followed by {@code offline_access} and by the
     *         configured groups-claim scope when that is non-empty
     */
    private String getScopeParam(boolean requestRefreshToken) {
        String scope = requestRefreshToken ? "openid offline_access" : "openid";
        if (Strings.isNullOrEmpty((String)this.scopeForGroupsClaim)) {
            return scope;
        }
        return scope + " " + this.scopeForGroupsClaim;
    }

    /**
     * Builds the {@code o2state} cookie header carrying the encrypted state, valid for ten minutes.
     *
     * @param state plain state value to encrypt
     * @return the cookie header value
     */
    private String getStateCookie(String state) {
        String encryptedState = this.encryptionHandler.encrypt(state);
        long maxAgeSec = Duration.ofMinutes(10L).getSeconds();
        String expires = this.getCookieExpiry(maxAgeSec);
        return "o2state=" + encryptedState
                + "; HttpOnly; Secure; Path=/; Expires=" + expires
                + "; SameSite=Lax; Max-Age=" + maxAgeSec;
    }

    /**
     * Formats a cookie {@code Expires} date.
     *
     * @param expirySec seconds from now until expiry; zero or negative yields the epoch date
     * @return the formatted cookie date (trimmed only in the zero-or-negative case)
     */
    private String getCookieExpiry(long expirySec) {
        if (expirySec <= 0L) {
            return DateGenerator.formatCookieDate((long)0L).trim();
        }
        long expiresAtMs = System.currentTimeMillis() + 1000L * expirySec;
        return DateGenerator.formatCookieDate((long)expiresAtMs);
    }

    /**
     * Handles the IdP redirect of the authorization-code flow.
     *
     * <p>Steps, in order: verify {@code state} against the state cookie, fail on an IdP
     * error, exchange the code for tokens, validate the id token, store the refresh token
     * (when applicable) under a new session id, and issue a Confluent token.
     *
     * @param encryptedStateCk the {@code o2state} cookie from the request
     * @param state state value returned by the IdP
     * @param code authorization code returned by the IdP
     * @param error error code returned by the IdP, if any
     * @param errorDescription error description returned by the IdP, if any
     * @return a response with the home URI plus the auth-token cookie and a cookie deleting the state
     */
    public OidcAuthResponse handleCallback(Cookie encryptedStateCk, String state, String code, String error, String errorDescription) {
        this.verifyState(encryptedStateCk, state);
        this.handleError(error, errorDescription);

        String redirectUri = this.getCallbackUriFromState(state);
        TokenResponse idpTokens = this.exchangeAuthorizationCodeForTokens(code, redirectUri);
        JwtClaims idTokenClaims = this.validateIdTokenAndCollectClaims(idpTokens);

        String newSessionId = UUID.randomUUID().toString();
        this.encryptAndStoreRefreshToken(idpTokens, idTokenClaims, null, newSessionId);

        String subject = this.getSubjectFromJwtClaims(idTokenClaims);
        long remainingValiditySec = this.getIdTokenRemainingValidity(idTokenClaims);
        String sessionToken = this.issueToken(idTokenClaims, subject, newSessionId, remainingValiditySec);
        long lifetimeSec = this.getSessionTokenLifetime(remainingValiditySec);

        String authCookie = this.tokenService.getCookieHeader(sessionToken, lifetimeSec);
        String clearStateCookie = this.getDeleteStateCookieHeader();
        URI homeUri = this.getHomeUri(redirectUri);
        return new OidcAuthResponse(homeUri, Arrays.asList(authCookie, clearStateCookie));
    }

    /**
     * Checks that the {@code state} returned by the IdP matches the decrypted value of the state cookie.
     *
     * @param encryptedStateCk the {@code o2state} cookie; may be null
     * @param state state value received in the callback
     * @throws InvalidStateException if the cookie is missing or empty, or the values differ
     */
    private void verifyState(Cookie encryptedStateCk, String state) {
        String encryptedState = encryptedStateCk == null ? null : encryptedStateCk.getValue();
        if (Strings.isNullOrEmpty((String)encryptedState)) {
            throw new InvalidStateException("o2state missing or empty in cookies");
        }
        String expectedState = this.encryptionHandler.decrypt(encryptedState);
        if (expectedState.equals(state)) {
            return;
        }
        log.error("state stored in cookies: {}, state received in callback: {}.", (Object)expectedState, (Object)state);
        throw new InvalidStateException("Invalid state parameter");
    }

    /**
     * Fails the callback when the IdP reported an error.
     *
     * @param error error code from the IdP; null or empty means no error
     * @param errorDescription description from the IdP; only logged
     * @throws AuthorizationResponseException if {@code error} is non-empty
     */
    private void handleError(String error, String errorDescription) {
        boolean noError = Strings.isNullOrEmpty((String)error);
        if (noError) {
            return;
        }
        log.error("Request failed or denied by IdP. error:{}. description:{}", (Object)error, (Object)errorDescription);
        throw new AuthorizationResponseException(error);
    }

    /**
     * Recovers the callback URI from {@code state} by skipping its first 36 characters
     * (the UUID string that {@code getIdpAuthUri} puts in front of it).
     *
     * @param state state value received in the callback; may be null
     * @return the callback URI, or an empty string when {@code state} is null or shorter than 36 characters
     */
    private String getCallbackUriFromState(String state) {
        boolean hasUuidPrefix = state != null && state.length() >= 36;
        if (hasUuidPrefix) {
            String callbackUri = state.substring(36);
            log.info("CallbackUri from state is: {}", (Object)callbackUri);
            return callbackUri;
        }
        log.error("Length of state is less than 36. State={}", (Object)state);
        return "";
    }

    /**
     * Exchanges an authorization code for IdP tokens at the token endpoint.
     *
     * @param code authorization code received in the callback
     * @param callbackUri redirect URI sent with the token request
     * @return the token response from the IdP
     */
    private TokenResponse exchangeAuthorizationCodeForTokens(String code, String callbackUri) {
        Form tokenRequest = this.createTokenRequestForm(code, callbackUri);
        String basicCredentials = this.getEncodedCredentials(this.clientIdIdp, this.clientSecretIdp);
        // The request is started here; the log line below is written while it is in flight.
        CompletionStage<Response> pendingResponse = this.getTokensAsync(tokenRequest, basicCredentials);
        log.debug("Fetching idp tokens using authorization code. callbackUri={}", (Object)callbackUri);
        return this.fetchIdpTokens(pendingResponse);
    }

    /**
     * Builds the form body for an {@code authorization_code} token request.
     *
     * @param code authorization code; must be non-empty
     * @param callbackUri redirect URI to send with the request
     * @return the populated form
     * @throws InvalidCodeException if {@code code} is null or empty
     */
    private Form createTokenRequestForm(String code, String callbackUri) {
        if (Strings.isNullOrEmpty((String)code)) {
            throw new InvalidCodeException("authorization code is null or empty");
        }
        Form form = new Form();
        form.param("grant_type", "authorization_code");
        form.param("code", code);
        form.param("redirect_uri", callbackUri);
        return form;
    }

    /**
     * Base64-encodes {@code clientId:clientSecret} (UTF-8) for use in a Basic authorization header.
     *
     * @param clientId client id registered at the IdP
     * @param clientSecret client secret registered at the IdP
     * @return the Base64 text
     */
    private String getEncodedCredentials(String clientId, String clientSecret) {
        byte[] rawCredentials = (clientId + ":" + clientSecret).getBytes(StandardCharsets.UTF_8);
        return Base64.getEncoder().encodeToString(rawCredentials);
    }

    /**
     * Posts {@code form} to the IdP token endpoint using Basic authentication.
     *
     * @param form url-encoded form body of the token request
     * @param encodedCredentials Base64 credentials for the Basic authorization header
     * @return a stage that completes with the IdP response
     */
    private CompletionStage<Response> getTokensAsync(Form form, String encodedCredentials) {
        String authorizationHeader = "Basic " + encodedCredentials;
        return this.configurableHttpClient
                .target(this.tokenEndpointBaseUri)
                .request()
                .header("Authorization", (Object)authorizationHeader)
                .accept(new String[]{"application/json"})
                .rx()
                .post(Entity.entity((Object)form, (String)"application/x-www-form-urlencoded"));
    }

    /**
     * Waits for the IdP token response and converts it to a {@link TokenResponse}.
     *
     * @param responseAsync pending response of the token request
     * @return the parsed token response
     * @throws TokenResponseException if the call fails, is interrupted, or the response cannot be
     *         processed; the original exception is kept as the cause
     */
    private TokenResponse fetchIdpTokens(CompletionStage<Response> responseAsync) {
        try {
            return responseAsync.thenApply(this::processResponse).toCompletableFuture().get();
        }
        catch (Exception e) {
            if (e instanceof InterruptedException) {
                // Restore the interrupt flag so callers further up the stack can still observe it.
                Thread.currentThread().interrupt();
            }
            log.error("Exception in async http call of getting tokens: {}", e.getMessage(), e);
            throw new TokenResponseException(e.getMessage(), e);
        }
    }

    /**
     * Converts the raw IdP response into a {@link TokenResponse}.
     *
     * @param response response of the token request; may be null
     * @return the response entity when the status is 200
     * @throws TokenResponseException if {@code response} is null
     * @throws RuntimeException for any non-200 status; the message includes the response body
     */
    private TokenResponse processResponse(Response response) {
        if (response == null) {
            throw new TokenResponseException("Response is null from IdP when fetching tokens");
        }
        int status = response.getStatus();
        if (status == Response.Status.OK.getStatusCode()) {
            log.debug("Successful token response from IDP: {}", (Object)status);
            return (TokenResponse)response.readEntity(TokenResponse.class);
        }
        // Every failure path reports the body as text.
        String body = (String)response.readEntity(String.class);
        if (status == Response.Status.BAD_REQUEST.getStatusCode()) {
            throw new RuntimeException("Got bad request status from IdP: " + body);
        }
        StringBuilder message = new StringBuilder("Failed to retrieve tokens from IDP with status:").append(status);
        if (!Strings.isNullOrEmpty((String)body)) {
            message.append(". Error: ").append(body);
        }
        throw new RuntimeException(message.toString());
    }

    /**
     * Validates the id token of {@code idpTokensResponse} with the configured JWT consumer
     * and checks that the subject and groups claims are well formed.
     *
     * @param idpTokensResponse token response received from the IdP
     * @return the claims of the validated id token
     * @throws InvalidTokenException if the id token is missing or any validation step fails;
     *         the underlying exception is kept as the cause
     */
    private JwtClaims validateIdTokenAndCollectClaims(TokenResponse idpTokensResponse) {
        try {
            if (Strings.isNullOrEmpty((String)idpTokensResponse.getIdToken())) {
                throw new InvalidTokenException("id token is null or empty in token response");
            }
            JwtClaims jwtClaims = this.jwtConsumer.processToClaims(idpTokensResponse.getIdToken());
            log.debug("Claims Received from IDP are : {}", (Object)String.join((CharSequence)",", jwtClaims.getClaimNames()));
            // Called only for their validation side: both throw on a malformed claim.
            this.getSubjectFromJwtClaims(jwtClaims);
            this.getGroupsFromJwtClaims(jwtClaims);
            return jwtClaims;
        }
        catch (Exception e) {
            // Pass the exception as the last argument so the stack trace is logged too.
            log.error("Failed to validate id token from IdP. Error: {}", e.getMessage(), e);
            throw new InvalidTokenException((Throwable)e);
        }
    }

    /**
     * Reads the subject from the claim named by the configured sub claim name.
     *
     * @param jwtClaims claims of a validated id token
     * @return the non-empty subject
     * @throws InvalidTokenException if the claim is absent, is not a String, or is empty
     */
    private String getSubjectFromJwtClaims(JwtClaims jwtClaims) {
        if (!jwtClaims.hasClaim(this.subClaimName)) {
            throw new InvalidTokenException(this.subClaimName + "(sub claim) not present");
        }
        String subject;
        try {
            subject = (String)jwtClaims.getClaimValue(this.subClaimName, String.class);
        }
        catch (MalformedClaimException notAString) {
            throw new InvalidTokenException(this.subClaimName + "(sub claim) not a String");
        }
        if (Strings.isNullOrEmpty((String)subject)) {
            throw new InvalidTokenException(this.subClaimName + "(sub claim) is:" + subject);
        }
        return subject;
    }

    /**
     * Reads the groups from the claim named by the configured groups claim name.
     *
     * @param jwtClaims claims of a validated id token
     * @return the groups; empty when the claim is absent
     * @throws InvalidTokenException if the claim is not a list, or an element is null or not a String
     */
    private Set<String> getGroupsFromJwtClaims(JwtClaims jwtClaims) {
        if (!jwtClaims.hasClaim(this.groupsClaimName)) {
            return Collections.emptySet();
        }
        Object groupsObj = jwtClaims.getClaimValue(this.groupsClaimName);
        if (!(groupsObj instanceof List)) {
            // Guard the type lookup: a null value must produce InvalidTokenException, not an NPE.
            String actualType = groupsObj == null ? "null" : String.valueOf(groupsObj.getClass());
            throw new InvalidTokenException("groups is not a List. Actual type:" + actualType);
        }
        HashSet<String> groups = new HashSet<String>();
        for (Object group : (List<?>)groupsObj) {
            if (!(group instanceof String)) {
                // A JSON null inside the array arrives here as a null element.
                String actualType = group == null ? "null" : String.valueOf(group.getClass());
                throw new InvalidTokenException("group is not a String. Actual type:" + actualType);
            }
            groups.add((String)group);
        }
        return groups;
    }

    /**
     * Encrypts the IdP refresh token and stores it through the trust writer.
     *
     * <p>Nothing is stored when the response has no refresh token or refresh tokens are disabled.
     *
     * @param tokens token response from the IdP
     * @param jwtClaims claims of the validated id token
     * @param grantType optional grant type; when non-null it is appended to the issuer as {@code -grantType}
     * @param sessionId session id stored with the refresh token
     */
    private void encryptAndStoreRefreshToken(TokenResponse tokens, JwtClaims jwtClaims, String grantType, String sessionId) {
        String refreshToken = tokens.getRefreshToken();
        if (Strings.isNullOrEmpty((String)refreshToken)) {
            log.info("No refresh token for iss:{} sub:{}", (Object)this.idpIssuer, (Object)this.getSubjectFromJwtClaims(jwtClaims));
            return;
        }
        if (!this.idpRefreshTokenEnabled) {
            log.warn("Refresh token is not requested but issued for iss:{}", (Object)this.idpIssuer);
            return;
        }
        String encrypted = this.encryptionHandler.encrypt(refreshToken);
        String subject = this.getSubjectFromJwtClaims(jwtClaims);
        long issuedAtSec = this.getIssuedTimeEpoch(jwtClaims);
        String issuerSuffix = grantType != null ? "-" + grantType : "";
        String storeIssuer = this.idpIssuer + issuerSuffix;
        log.info("Storing refresh token - issuer:{}, issuedAt:{}, sub:{}, ssid:{}", new Object[]{storeIssuer, issuedAtSec, subject, sessionId});
        this.trustWriter.addRefreshTokenInfo(storeIssuer, encrypted, issuedAtSec, subject, sessionId);
    }

    /**
     * Returns the id token's issued-at time in epoch seconds.
     *
     * @param jwtClaims claims of the validated id token
     * @return the {@code iat} value, or the current time in epoch seconds when the claim is
     *         absent or malformed
     */
    private long getIssuedTimeEpoch(JwtClaims jwtClaims) {
        try {
            // getIssuedAt() yields null when the claim is absent; use the fallback instead of an NPE.
            if (jwtClaims.getIssuedAt() != null) {
                return jwtClaims.getIssuedAt().getValue();
            }
        }
        catch (MalformedClaimException ignored) {
            // Malformed iat is handled like a missing one: fall through to the current time.
        }
        return System.currentTimeMillis() / 1000L;
    }

    /**
     * Issues a Confluent token for {@code sub} carrying the groups, {@code mex} and {@code ssid} custom claims.
     *
     * @param claims id token claims; supplies the groups
     * @param sub subject the token is issued to
     * @param ssid session id to put in the token
     * @param extendabilitySec seconds from now until the session can no longer be extended
     * @return the issued token
     */
    private String issueToken(JwtClaims claims, String sub, String ssid, long extendabilitySec) {
        long nowSec = System.currentTimeMillis() / 1000L;
        long maxExpirySec = nowSec + extendabilitySec;

        HashMap<String, Object> extraClaims = new HashMap<String, Object>();
        extraClaims.put("groups", this.getGroupsFromJwtClaims(claims));
        extraClaims.put(TOKEN_MAX_EXP_CLAIM_NAME, maxExpirySec);
        extraClaims.put(TOKEN_SESSION_ID, ssid);

        long lifetimeSec = this.getSessionTokenLifetime(extendabilitySec);
        log.info("Issuing token to `{}` with validity:`{}` and mex:`{}`", new Object[]{sub, lifetimeSec, maxExpirySec});
        Principal userPrincipal = (Principal)new KafkaPrincipal("User", sub);
        return this.tokenService.issueToken(userPrincipal, extraClaims, lifetimeSec, Collections.emptyList());
    }

    /**
     * Computes the lifetime of a session token in seconds.
     *
     * @param extendabilitySec seconds of extendability remaining
     * @return the smallest of the configured token lifetime, the configured maximum
     *         extendability and {@code extendabilitySec}
     */
    private long getSessionTokenLifetime(long extendabilitySec) {
        long maxExtendabilitySec = this.sessionTokenMaxExtendabilityMs / 1000L;
        long configuredLifetimeSec = this.sessionTokenLifetimeMs / 1000L;
        long cappedByLifetime = Math.min(configuredLifetimeSec, extendabilitySec);
        return Math.min(maxExtendabilitySec, cappedByLifetime);
    }

    /**
     * Computes how many seconds the id token remains valid, capped at the configured
     * maximum extendability and never negative.
     *
     * @param jwtClaims claims of the validated id token
     * @return remaining validity in seconds; 0 when the {@code exp} claim is malformed
     */
    private long getIdTokenRemainingValidity(JwtClaims jwtClaims) {
        long remainingSec;
        try {
            remainingSec = jwtClaims.getExpirationTime().getValue() - System.currentTimeMillis() / 1000L;
        }
        catch (MalformedClaimException malformedExp) {
            log.error("exp claim not present in JwtClaims");
            return 0L;
        }
        long cappedSec = Math.min(this.sessionTokenMaxExtendabilityMs / 1000L, remainingSec);
        return Math.max(0L, cappedSec);
    }

    /**
     * Builds the cookie header that clears the {@code o2state} cookie (empty value, epoch expiry, Max-Age 0).
     *
     * @return the cookie header value
     */
    private String getDeleteStateCookieHeader() {
        String epochExpiry = this.getCookieExpiry(0L);
        return "o2state=; HttpOnly; Secure; Path=/; Expires=" + epochExpiry + "; SameSite=Strict; Max-Age=0";
    }

    /**
     * Strips path and query from the callback URI, leaving the URI of its origin.
     *
     * @param callbackUriStr callback URI text
     * @return the callback URI without path and query
     */
    private URI getHomeUri(String callbackUriStr) {
        UriBuilder homeBuilder = UriBuilder.fromUri((URI)URI.create(callbackUriStr));
        homeBuilder.replacePath(null);
        homeBuilder.replaceQuery(null);
        return homeBuilder.build(new Object[0]);
    }

    /**
     * Extends the caller's session token when that is still possible.
     *
     * @param securityContext security context of the request; its principal must be a {@link JwtPrincipal}
     * @return an empty response for an invalid principal; the current token and its {@code exp} when the
     *         token is expired or cannot be extended; otherwise the new token, its cookie header and its expiry
     */
    public OidcAuthResponse refreshConfluentTokenIfApplicable(SecurityContext securityContext) {
        Principal userPrincipal = securityContext.getUserPrincipal();
        if (this.isInvalidPrincipal(userPrincipal)) {
            return new OidcAuthResponse(null, Collections.emptyList());
        }
        JwtPrincipal jwtPrincipal = (JwtPrincipal)userPrincipal;
        RefreshTokenInfo cachedRefreshToken = this.getRefreshTokenInfoFromCache((Principal)jwtPrincipal, null);
        boolean notExtendable = this.authTokenExpiredOrCannotBeExtended(cachedRefreshToken, jwtPrincipal);
        if (notExtendable) {
            // Hand back the existing token unchanged.
            return new OidcAuthResponse(null, Collections.emptyList(), jwtPrincipal.getJwt(), this.getLongValue(jwtPrincipal, "exp"));
        }
        ExtendAuthResponse extended = this.getExtendAuthResponse(cachedRefreshToken, jwtPrincipal);
        return new OidcAuthResponse(
                null,
                Collections.singletonList(this.tokenService.getCookieHeader(extended.authToken(), extended.expiresIn().longValue())),
                extended.authToken(),
                extended.expiresIn() + System.currentTimeMillis() / 1000L);
    }

    /**
     * Extends the session through the IdP refresh token when one is usable, otherwise
     * through the token's {@code mex} claim.
     *
     * @param refreshTokenInfo cached refresh-token info; may be null
     * @param principal the authenticated principal
     * @return the new token and its lifetime
     */
    private ExtendAuthResponse getExtendAuthResponse(RefreshTokenInfo refreshTokenInfo, JwtPrincipal principal) {
        return this.useRefreshTokensForExtension(refreshTokenInfo, principal)
                ? this.extendSessionUsingRefreshToken(refreshTokenInfo, principal)
                : this.extendSessionUsingMexTokenClaim(principal);
    }

    /**
     * Tells whether {@code principal} cannot be used for token operations.
     *
     * @param principal principal taken from the security context; may be null
     * @return {@code true} when it is null or not a {@link JwtPrincipal}
     */
    private boolean isInvalidPrincipal(Principal principal) {
        // instanceof is false for null, so this covers the null case as well.
        if (principal instanceof JwtPrincipal) {
            return false;
        }
        log.error("context user principal=`{}` is not valid", (Object)principal);
        return true;
    }

    /**
     * Looks up the cached refresh-token info for {@code principal}.
     *
     * @param principal principal whose name is part of the cache key
     * @param grantType optional grant type; when non-null it is appended to the issuer as {@code -grantType}
     * @return whatever the trust cache returns for the key
     */
    private RefreshTokenInfo getRefreshTokenInfoFromCache(Principal principal, String grantType) {
        log.info("Fetching refresh token info from cache for principal : {}", (Object)principal.getName());
        String issuerSuffix = grantType != null ? "-" + grantType : "";
        String cacheIssuer = this.idpIssuer + issuerSuffix;
        Object cacheKey = RefreshTokenInfoKey.cacheKey((String)cacheIssuer, (String)principal.getName());
        return this.trustCache.refreshTokenInfo(RefreshTokenInfoKey.cacheKey((String)cacheIssuer, (String)principal.getName()));
    }

    /**
     * Decides whether the principal's token must be left as is.
     *
     * <p>That is the case when the token has already expired, when the refresh-token path
     * applies and the refresh token is older than the configured maximum extendability, or
     * when the {@code mex} path applies and {@code mex} lies in the past.
     *
     * @param refreshTokenInfo cached refresh-token info; may be null
     * @param principal the authenticated principal
     * @return {@code true} when the token is expired or cannot be extended
     */
    private boolean authTokenExpiredOrCannotBeExtended(RefreshTokenInfo refreshTokenInfo, JwtPrincipal principal) {
        long expirySec = this.getLongValue(principal, "exp");
        long nowSec = System.currentTimeMillis() / 1000L;
        if (expirySec < nowSec) {
            log.info("auth_token already expired for {}. expSec:{}, currentTimeSec:{}", new Object[]{principal.getName(), expirySec, nowSec});
            return true;
        }
        if (!this.useRefreshTokensForExtension(refreshTokenInfo, principal)) {
            long mex = this.getLongValue(principal, TOKEN_MAX_EXP_CLAIM_NAME);
            boolean pastMex = mex < nowSec;
            if (pastMex) {
                log.info("auth_token cannot be extended beyond mex for {}. mex:{}, currentTimeSec:{}", new Object[]{principal.getName(), mex, nowSec});
            }
            return pastMex;
        }
        boolean beyondAbsoluteTimeout = nowSec - refreshTokenInfo.issuedAt() > this.sessionTokenMaxExtendabilityMs / 1000L;
        if (beyondAbsoluteTimeout) {
            log.info("auth_token cannot be extended beyond abs timeout for {}. refreshTkIssuedAt:{}, currentTimeSec:{}, absoluteTimeoutSec:{}", new Object[]{principal.getName(), refreshTokenInfo.issuedAt(), nowSec, this.sessionTokenMaxExtendabilityMs / 1000L});
        }
        return beyondAbsoluteTimeout;
    }

    /**
     * Tells whether the session should be extended through the IdP refresh token.
     *
     * @param refreshTokenInfo cached refresh-token info; may be null
     * @param principal the authenticated principal
     * @return {@code true} only when refresh tokens are enabled, an encrypted refresh token is
     *         cached, and its session id equals the {@code ssid} claim of the principal's token
     */
    private boolean useRefreshTokensForExtension(RefreshTokenInfo refreshTokenInfo, JwtPrincipal principal) {
        if (refreshTokenInfo == null || !this.idpRefreshTokenEnabled) {
            return false;
        }
        if (Strings.isNullOrEmpty((String)refreshTokenInfo.encryptedRefreshToken())) {
            return false;
        }
        if (refreshTokenInfo.sessionId() == null) {
            return false;
        }
        // The cached refresh token must belong to the same session as the presented token.
        return refreshTokenInfo.sessionId().equals(principal.jwtClaims().getOrDefault(TOKEN_SESSION_ID, null));
    }

    /**
     * Extends the session by redeeming the cached refresh token at the IdP and issuing a
     * new Confluent token from the fresh id token.
     *
     * @param refreshTokenInfo cached refresh-token info
     * @param principal the authenticated principal; its {@code ssid} claim is carried over
     * @return the new token and its lifetime in seconds
     */
    private ExtendAuthResponse extendSessionUsingRefreshToken(RefreshTokenInfo refreshTokenInfo, JwtPrincipal principal) {
        log.info("Extending session using refresh token for iss:{}, principal:{}, ssid:{}", new Object[]{refreshTokenInfo.issuer(), principal.getName(), refreshTokenInfo.sessionId()});
        TokenResponse refreshedTokens = this.getTokensFromRefreshToken(refreshTokenInfo);
        JwtClaims refreshedClaims = this.validateIdTokenAndCollectClaims(refreshedTokens);
        String sessionId = principal.jwtClaims().getOrDefault(TOKEN_SESSION_ID, null);

        // Time left until the absolute session timeout, counted from when the refresh token was issued.
        long absoluteDeadlineMs = refreshTokenInfo.issuedAt() * 1000L + this.sessionTokenMaxExtendabilityMs;
        long remainingMs = absoluteDeadlineMs - System.currentTimeMillis();
        long extendabilitySec = Math.min(remainingMs / 1000L, this.getIdTokenRemainingValidity(refreshedClaims));

        String newToken = this.issueToken(refreshedClaims, principal.getName(), sessionId, extendabilitySec);
        long lifetimeSec = this.getSessionTokenLifetime(extendabilitySec);
        return new ExtendAuthResponse(newToken, lifetimeSec);
    }

    /**
     * Decrypts the cached refresh token and redeems it at the IdP token endpoint.
     *
     * @param refreshTokenInfo cached refresh-token info
     * @return the token response from the IdP
     */
    private TokenResponse getTokensFromRefreshToken(RefreshTokenInfo refreshTokenInfo) {
        String plainRefreshToken = this.encryptionHandler.decrypt(refreshTokenInfo.encryptedRefreshToken());
        Form refreshRequest = this.createTokenRequestFormUsingRefreshToken(plainRefreshToken);
        String basicCredentials = this.getEncodedCredentials(this.clientIdIdp, this.clientSecretIdp);
        // The request is started here; the log line below is written while it is in flight.
        CompletionStage<Response> pendingResponse = this.getTokensAsync(refreshRequest, basicCredentials);
        log.info("Fetching idp tokens using refresh token for `{}`", (Object)refreshTokenInfo.subClaim());
        return this.fetchIdpTokens(pendingResponse);
    }

    /**
     * Builds the form body for a {@code refresh_token} grant request.
     *
     * @param refreshToken plain-text refresh token to redeem
     * @return form carrying {@code grant_type}, {@code refresh_token} and {@code scope}
     */
    private Form createTokenRequestFormUsingRefreshToken(String refreshToken) {
        Form grantForm = new Form()
                .param("grant_type", "refresh_token")
                .param("refresh_token", refreshToken)
                .param("scope", this.getScopeParam(false));
        return grantForm;
    }

    /**
     * Extends the caller's session by asking the token service to refresh the caller's current
     * JWT, bounded by the lifetime derived from the token's max-expiry claim.
     *
     * @param principal caller whose JWT is refreshed
     * @return the refreshed token together with the lifetime it was issued for
     */
    private ExtendAuthResponse extendSessionUsingMexTokenClaim(JwtPrincipal principal) {
        RefreshTokenRequest request = new RefreshTokenRequest(principal.getJwt(), "");
        long lifetimeSec = this.getTokenLifetime(principal);
        log.info("Extending session using mex for sub:{} by:{}", principal.getName(), lifetimeSec);
        String extendedToken = this.tokenService.refreshToken((Principal)principal, request, Collections.emptyMap(), lifetimeSec);
        return new ExtendAuthResponse(extendedToken, lifetimeSec);
    }

    /**
     * Computes the session-token lifetime from the caller's max-expiry claim.
     *
     * <p>The claim is compared against the current time in seconds; a max-expiry that already
     * lies in the past counts as zero remaining extendability.
     *
     * @param principal caller whose {@code TOKEN_MAX_EXP_CLAIM_NAME} claim is read
     * @return the value of {@code getSessionTokenLifetime} for the remaining extendability
     */
    private long getTokenLifetime(JwtPrincipal principal) {
        long nowSec = System.currentTimeMillis() / 1000L;
        long maxExpirySec = this.getLongValue(principal, TOKEN_MAX_EXP_CLAIM_NAME);
        long remainingSec = maxExpirySec - nowSec;
        // Never hand a negative extendability to the lifetime calculation.
        return this.getSessionTokenLifetime(remainingSec > 0L ? remainingSec : 0L);
    }

    /**
     * Reads a numeric JWT claim as a {@code long}.
     *
     * <p>The claim may be stored as any {@link Number} (fractional values are truncated) or as
     * a decimal string. A missing claim yields {@code 0}. Any other value, including a string
     * that does not parse as a long, is logged as an error and also yields {@code 0}.
     *
     * @param jwtPrincipal principal whose claims are inspected
     * @param claimName name of the claim to read
     * @return the claim value as a long, or {@code 0} when absent or not numeric
     */
    private long getLongValue(JwtPrincipal jwtPrincipal, String claimName) {
        // Held as Object on purpose: a Long-typed local would make the String branch unreachable
        // and turn a string-valued claim into a ClassCastException.
        Object claimValue = jwtPrincipal.jwtClaims().getOrDefault(claimName, 0L);
        if (claimValue instanceof Number) {
            // JSON parsers may hand back Integer/Long/Double depending on magnitude and format.
            return ((Number)claimValue).longValue();
        }
        if (claimValue instanceof String) {
            try {
                return Long.parseLong((String)claimValue);
            } catch (NumberFormatException ignored) {
                // Not a parseable long: fall through to the shared error log and the 0 default.
            }
        }
        log.error("jwt claim=`{}` is not of type long", claimName);
        return 0L;
    }

    /**
     * Logs the caller out: drops the stored refresh-token record when it belongs to the caller's
     * current session, and returns a response that clears the auth cookie.
     *
     * <p>For a principal rejected by {@code isInvalidPrincipal} only the cookie-clearing
     * response is produced. The stored record is removed only if its session id is non-null and
     * equals the session-id claim of the caller's JWT.
     *
     * @param securityContext request security context supplying the caller principal
     * @return response carrying the cookie header that clears the auth token
     */
    public OidcAuthResponse clearToken(SecurityContext securityContext) {
        Principal caller = securityContext.getUserPrincipal();
        if (!this.isInvalidPrincipal(caller)) {
            JwtPrincipal principal = (JwtPrincipal)caller;
            RefreshTokenInfo stored = this.getRefreshTokenInfoFromCache(caller, null);
            boolean ownedByThisSession = stored != null
                    && stored.sessionId() != null
                    && stored.sessionId().equals(principal.jwtClaims().getOrDefault(TOKEN_SESSION_ID, null));
            if (ownedByThisSession) {
                this.trustWriter.removeRefreshTokenInfo(this.idpIssuer, caller.getName());
            }
        }
        return this.clearAuthTokenResponse();
    }

    /**
     * Builds the response used to clear the auth token on the client.
     *
     * <p>The cookie header is requested from the token service with an empty token value and a
     * lifetime of {@code 0}.
     *
     * @return response with a {@code null} first constructor argument and that single cookie header
     */
    private OidcAuthResponse clearAuthTokenResponse() {
        String clearingCookie = this.tokenService.getCookieHeader("", 0L);
        return new OidcAuthResponse(null, Collections.singletonList(clearingCookie));
    }

    /**
     * Starts a device-authorization flow at the IdP.
     *
     * <p>Posts the client id, client secret and scope as a URL-encoded form to the configured
     * device-authorization endpoint, blocks until the IdP answers, and converts the answer into
     * the response handed to the client.
     *
     * @return user code, verification URI, encrypted device key, polling interval and expiry
     * @throws ConfigException if no device-authorization endpoint is configured
     */
    public InitDeviceAuthResponse createDeviceAuth() {
        if (this.deviceAuthEndpointUri == null) {
            throw new ConfigException("Device auth endpoint is not configured");
        }
        Form body = new Form()
                .param("client_id", this.clientIdIdp)
                .param("client_secret", this.clientSecretIdp)
                .param("scope", this.getScopeParam(this.idpRefreshTokenEnabled));
        CompletionStage<Response> pendingResponse = this.configurableHttpClient
                .target(this.deviceAuthEndpointUri)
                .request()
                .accept("application/json")
                .rx()
                .post(Entity.entity(body, "application/x-www-form-urlencoded"));
        // join() blocks the calling thread until the IdP response has been processed.
        DeviceAuthIdpResponse idpResponse = pendingResponse
                .thenApply(this::processDeviceAuthResponse)
                .toCompletableFuture()
                .join();
        // NOTE(review): the whole IdP response is logged at INFO; confirm its toString() does not
        // expose the device code.
        log.info("Device authentication requested successfully: {}", idpResponse);
        return this.deviceAuthIdpResponseToInitDeviceAuthResponse(idpResponse);
    }

    /**
     * Converts the IdP's answer to a device-authorization request into a typed response.
     *
     * @param response raw HTTP response from the IdP; may be {@code null}
     * @return the parsed body when the IdP answered with status 200
     * @throws DeviceAuthResponseException if {@code response} is {@code null}
     * @throws RuntimeException for any non-200 status; the message includes the response body
     *     (for 400) or the status code plus the body when one is present (all other statuses)
     */
    private DeviceAuthIdpResponse processDeviceAuthResponse(Response response) {
        if (response == null) {
            throw new DeviceAuthResponseException("Response is null from IdP when starting device auth");
        }
        int status = response.getStatus();
        if (status == Response.Status.OK.getStatusCode()) {
            log.debug("Successful device auth response from IDP: {}", status);
            return response.readEntity(DeviceAuthIdpResponse.class);
        }
        // Every remaining path is an error and reports the raw body.
        String rawBody = response.readEntity(String.class);
        if (status == Response.Status.BAD_REQUEST.getStatusCode()) {
            throw new RuntimeException("Got bad request status from IdP: " + rawBody);
        }
        StringBuilder message = new StringBuilder("Got device auth response from IDP with status:").append(status);
        if (!Strings.isNullOrEmpty(rawBody)) {
            message.append(". Error: ").append(rawBody);
        }
        throw new RuntimeException(message.toString());
    }

    /**
     * Maps the IdP's device-authorization response to the response returned to the client.
     *
     * <p>The device code is never passed through in clear text: it is encrypted and handed out
     * as the opaque key. When the IdP did not supply a complete verification URI, one is built
     * from the plain verification URI with the user code as the {@code user_code} query
     * parameter.
     *
     * @param deviceAuthIdpResponse parsed IdP response
     * @return client-facing response with user code, verification URI, key, interval and expiry
     */
    private InitDeviceAuthResponse deviceAuthIdpResponseToInitDeviceAuthResponse(DeviceAuthIdpResponse deviceAuthIdpResponse) {
        String encryptedDeviceCode = this.encryptionHandler.encrypt(deviceAuthIdpResponse.deviceCode());
        String completeUri = deviceAuthIdpResponse.verificationUriComplete();
        if (Strings.isNullOrEmpty(completeUri)) {
            log.debug("Verification URI complete is not present in the response. Building it.");
            completeUri = UriBuilder.fromUri(deviceAuthIdpResponse.verificationUri())
                    .queryParam("user_code", deviceAuthIdpResponse.userCode())
                    .build()
                    .toString();
        }
        return new InitDeviceAuthResponse(
                deviceAuthIdpResponse.userCode(),
                completeUri,
                encryptedDeviceCode,
                deviceAuthIdpResponse.interval(),
                deviceAuthIdpResponse.getExpiresIn());
    }

    /**
     * Polls the IdP token endpoint for the outcome of a device-authorization flow.
     *
     * <p>The device code is recovered from the request's key, posted with the client
     * credentials using the device-code grant type, and the call blocks until the IdP's answer
     * has been processed.
     *
     * @param request client request carrying the encrypted device key and the user code
     * @return the result produced by {@code processCheckDeviceAuthResponse} for the IdP's answer
     */
    public CheckDeviceAuthResponse checkDeviceAuth(CheckDeviceAuthRequest request) {
        String deviceCode = this.deviceCodeFromKey(request);
        Form body = new Form()
                .param("client_id", this.clientIdIdp)
                .param("client_secret", this.clientSecretIdp)
                .param("grant_type", "urn:ietf:params:oauth:grant-type:device_code")
                .param("device_code", deviceCode);
        CompletionStage<Response> pendingResponse = this.configurableHttpClient
                .target(this.tokenEndpointBaseUri)
                .request()
                .accept("application/json")
                .rx()
                .post(Entity.entity(body, "application/x-www-form-urlencoded"));
        // join() blocks the calling thread until the IdP response has been processed.
        return pendingResponse
                .thenApply(idpResponse -> this.processCheckDeviceAuthResponse(idpResponse, request.userCode()))
                .toCompletableFuture()
                .join();
    }

    /**
     * Recovers the IdP device code from the opaque key supplied by the client.
     *
     * @param request client request; its key is the encrypted device code
     * @return the decrypted device code
     * @throws InvalidDeviceKeyException if the request or its key is {@code null}, or if the key
     *     cannot be decrypted
     */
    private String deviceCodeFromKey(CheckDeviceAuthRequest request) {
        // Validated before the try block so the catch handler can safely dereference the request.
        if (request == null || request.key() == null) {
            throw new InvalidDeviceKeyException("Invalid key!");
        }
        try {
            return this.encryptionHandler.decrypt(request.key());
        } catch (EncryptionFailedException e) {
            // Arguments follow the placeholder order (key first, then user code). The trailing
            // exception is passed so the logger can record the decryption failure's stack trace.
            log.error("Failed to decrypt device code from key: {} for user_code: {}", new Object[]{request.key(), request.userCode(), e});
            throw new InvalidDeviceKeyException("Invalid key!");
        }
    }

    /**
     * Dispatches the IdP's answer to a device-code token request by HTTP status.
     *
     * @param response raw HTTP response from the IdP; may be {@code null}
     * @param userCode user code of the flow, echoed into the result
     * @return the success result for status 200, or the pending/error result for status 400
     * @throws TokenResponseException if {@code response} is {@code null}
     * @throws RuntimeException for any other status; the message carries the status code and,
     *     when present, the response body
     */
    private CheckDeviceAuthResponse processCheckDeviceAuthResponse(Response response, String userCode) {
        if (response == null) {
            throw new TokenResponseException("Response is null from IdP when checking device auth");
        }
        int status = response.getStatus();
        if (status == Response.Status.OK.getStatusCode()) {
            return this.processSuccessfulDeviceAuthResponse(response, userCode);
        }
        if (status == Response.Status.BAD_REQUEST.getStatusCode()) {
            return this.processBadRequestDeviceAuthResponse(response, userCode);
        }
        StringBuilder message = new StringBuilder("Got device auth response from IDP with status:").append(status);
        String rawBody = response.readEntity(String.class);
        if (!Strings.isNullOrEmpty(rawBody)) {
            message.append(". Error: ").append(rawBody);
        }
        throw new RuntimeException(message.toString());
    }

    /**
     * Completes a device-authorization flow after the IdP returned tokens.
     *
     * <p>Steps, in order: parse the token response, validate the ID token and collect its
     * claims, create a fresh random session id, encrypt and store the refresh token under that
     * session with the {@code "device"} marker, then issue a token for the ID token's subject
     * whose extendability equals the ID token's remaining validity.
     *
     * @param response successful IdP token response
     * @param userCode user code of the flow, echoed into the result
     * @return completed-auth result carrying the issued token and its lifetime in seconds
     */
    private CheckDeviceAuthResponse processSuccessfulDeviceAuthResponse(Response response, String userCode) {
        log.debug("Successful device auth check response from IDP: {}", response.getStatus());
        TokenResponse idpTokens = response.readEntity(TokenResponse.class);
        JwtClaims idTokenClaims = this.validateIdTokenAndCollectClaims(idpTokens);
        String newSessionId = UUID.randomUUID().toString();
        // The refresh token is persisted before the session token is issued.
        this.encryptAndStoreRefreshToken(idpTokens, idTokenClaims, "device", newSessionId);
        String subject = this.getSubjectFromJwtClaims(idTokenClaims);
        long extendabilitySec = this.getIdTokenRemainingValidity(idTokenClaims);
        String issuedToken = this.issueToken(idTokenClaims, subject, newSessionId, extendabilitySec);
        long lifetimeSec = this.getSessionTokenLifetime(extendabilitySec);
        return CheckDeviceAuthResponse.createCompleteAuthResponse(userCode, issuedToken, lifetimeSec);
    }

    /**
     * Interprets a 400 answer from the IdP during device-authorization polling.
     *
     * <p>The error codes {@code authorization_pending} and {@code slow_down} (compared
     * case-insensitively) produce a pending result. Every other error code produces an error
     * result.
     *
     * @param response IdP response with status 400
     * @param userCode user code of the flow, echoed into the result
     * @return pending or error result carrying the IdP's error code and description
     */
    private CheckDeviceAuthResponse processBadRequestDeviceAuthResponse(Response response, String userCode) {
        log.debug("Bad request response from IDP for device auth check: {}", response.getStatus());
        CheckDeviceAuthIdpErrorResponse idpError = response.readEntity(CheckDeviceAuthIdpErrorResponse.class);
        boolean stillPending = "authorization_pending".equalsIgnoreCase(idpError.error())
                || "slow_down".equalsIgnoreCase(idpError.error());
        if (!stillPending) {
            return CheckDeviceAuthResponse.createErrorResponse(userCode, idpError.error(), idpError.errorDescription());
        }
        return CheckDeviceAuthResponse.createPendingAuthResponse(userCode, idpError.error(), idpError.errorDescription());
    }

    /**
     * Extends a device-flow session when the caller is eligible.
     *
     * <p>No extension is granted (a response with a {@code null} token and lifetime {@code 0})
     * when the principal is rejected by {@code isInvalidPrincipal}, or when
     * {@code authTokenExpiredOrCannotBeExtended} reports that the token cannot be extended.
     *
     * @param securityContext request security context supplying the caller principal
     * @return the extended token and its lifetime, or the empty response described above
     */
    public ExtendAuthResponse extendDeviceAuthIfApplicable(SecurityContext securityContext) {
        Principal caller = securityContext.getUserPrincipal();
        if (!this.isInvalidPrincipal(caller)) {
            JwtPrincipal principal = (JwtPrincipal)caller;
            RefreshTokenInfo stored = this.getRefreshTokenInfoFromCache(caller, "device");
            if (!this.authTokenExpiredOrCannotBeExtended(stored, principal)) {
                return this.getExtendAuthResponse(stored, principal);
            }
        }
        // Shared "nothing to extend" answer for both rejection paths.
        return new ExtendAuthResponse(null, 0L);
    }

    /**
     * Replaces the JWT consumer held by this service.
     *
     * <p>Marked {@code @VisibleForTesting}.
     *
     * @param jwtConsumer consumer to use from now on
     */
    @VisibleForTesting
    public void setJwtConsumer(JwtConsumer jwtConsumer) {
        this.jwtConsumer = jwtConsumer;
    }
}

