/*
 * Decompiled with CFR 0.152.
 */
package kafka.tier.store.encryption;

import com.google.crypto.tink.Aead;
import com.google.crypto.tink.CleartextKeysetHandle;
import com.google.crypto.tink.JsonKeysetReader;
import com.google.crypto.tink.JsonKeysetWriter;
import com.google.crypto.tink.KeyTemplates;
import com.google.crypto.tink.KeysetHandle;
import com.google.crypto.tink.KeysetWriter;
import com.google.crypto.tink.proto.Keyset;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.time.Duration;
import java.time.Instant;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import kafka.tier.exceptions.TierObjectStoreFatalException;
import kafka.tier.exceptions.TierObjectStoreRetriableException;
import kafka.tier.store.OpaqueData;
import kafka.tier.store.TierObjectStore;
import kafka.tier.store.encryption.CleartextDataKey;
import kafka.tier.store.encryption.DataEncryptionKeyHolder;
import kafka.tier.store.encryption.EncryptedDataKey;
import kafka.tier.store.encryption.EncryptionKeyManagerMetrics;
import kafka.tier.store.encryption.KeyContext;
import kafka.tier.store.encryption.KeySha;
import kafka.tier.store.encryption.Util;
import org.apache.kafka.common.metrics.Metrics;
import org.apache.kafka.common.utils.ByteBufferInputStream;
import org.apache.kafka.common.utils.ByteBufferOutputStream;
import org.apache.kafka.common.utils.Time;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Manages the data encryption keys (DEKs) used by the tier object store.
 *
 * <p>New DEKs are generated with Tink (template {@code AES256_GCM_RAW}), wrapped with the
 * remote key-encryption key (KEK) supplied at construction, and cached in memory by their
 * {@link KeySha}. The active key is rotated once it is older than the configured maximum key age.
 * When a {@link WellKnownKeypathHook} is bound, rotation first tries to recover a previously
 * written, still-fresh key through the hook and otherwise writes the newly generated key's
 * metadata through it.
 *
 * <p>Thread-safety: the key cache is internally synchronized and rotation is serialized on this
 * instance's monitor.
 */
public class EncryptionKeyManager {
    private static final Logger log = LoggerFactory.getLogger(EncryptionKeyManager.class);
    /** Object-metadata field holding the base64-encoded {@link KeySha} of the data key. */
    static final String METADATA_SHA_KEY = "io.confluent/key-sha-256";
    /** Object-metadata field holding the base64-encoded, KEK-wrapped data key. */
    static final String METADATA_DATA_KEY = "io.confluent/base64-encrypted-data-key";
    /** Object-metadata field holding the key creation time in epoch milliseconds. */
    static final String METADATA_KEY_CREATION_TIME = "io.confluent/key-creation-time";
    private static final String DATA_KEY_TEMPLATE = "AES256_GCM_RAW";
    /** Metrics sink; {@code null} when no {@link Metrics} registry was supplied. */
    final EncryptionKeyManagerMetrics metrics;
    private final Time time;
    private final KeyCache cache = new KeyCache();
    /** Maximum age of the active key before {@link #activeKeySha()} rotates it. */
    private final Duration keyRefreshInterval;
    /** Key-encryption key used to wrap newly generated DEKs and unwrap stored ones. */
    private final Aead remoteKek;
    // volatile: bindHook() is not synchronized, while the hook is read under the rotation
    // monitor, potentially on a different thread.
    private volatile WellKnownKeypathHook wellKnownKeypathHook;

    /**
     * @param time          clock used for key creation times, expiry checks and call latencies
     * @param metrics       metrics registry, or {@code null} to disable metrics
     * @param masterKeyAead KEK used to wrap/unwrap data keys
     * @param maxKeyAge     age after which the active key is rotated
     */
    public EncryptionKeyManager(Time time, Metrics metrics, Aead masterKeyAead, Duration maxKeyAge) {
        if (metrics != null) {
            this.metrics = new EncryptionKeyManagerMetrics(metrics);
            this.metrics.updateMaxKeyAge(maxKeyAge);
        } else {
            this.metrics = null;
        }
        this.time = time;
        this.remoteKek = masterKeyAead;
        this.keyRefreshInterval = maxKeyAge;
    }

    /** Binds the hook used to recover/persist the active key's metadata during rotation. */
    public void bindHook(WellKnownKeypathHook hook) {
        this.wellKnownKeypathHook = hook;
    }

    /** Closes the metrics, if any were registered. */
    public void close() {
        if (this.metrics != null) {
            this.metrics.close();
        }
    }

    /**
     * Returns the cleartext key and its object-metadata encoding for {@code keySha}.
     *
     * @return the key context, or {@code null} if the key is not cached
     */
    public KeyContext keyContext(KeySha keySha) {
        DataEncryptionKeyHolder holder = this.cache.get(keySha);
        if (holder == null) {
            return null;
        }
        return new KeyContext(holder.cleartextDataKey, keyToObjectMetadata(holder), keySha);
    }

    /**
     * Decodes a key from object metadata and caches it, replacing any cached holder with the
     * same {@link KeySha}.
     *
     * @return the SHA of the registered key
     * @throws TierObjectStoreFatalException     if the metadata is missing fields or inconsistent
     * @throws TierObjectStoreRetriableException if the data key cannot be unwrapped
     */
    public KeySha registerKeyFromObjectMetadata(Map<String, String> metadata) {
        DataEncryptionKeyHolder holder = this.parseKeyFromObjectMetadata(metadata);
        log.info("Registering key {} decoded from metadata", holder.keySha);
        this.cache.add(holder);
        return holder.keySha;
    }

    /**
     * Decodes a key from object metadata and caches it only if no holder with the same
     * {@link KeySha} is cached yet. The metadata is always decoded (and therefore validated).
     *
     * @return the SHA of the decoded key
     * @throws TierObjectStoreFatalException     if the metadata is missing fields or inconsistent
     * @throws TierObjectStoreRetriableException if the data key cannot be unwrapped
     */
    public KeySha registerKeyIfAbsentFromObjectMetadata(Map<String, String> metadata) {
        DataEncryptionKeyHolder holder = this.parseKeyFromObjectMetadata(metadata);
        this.cache.addIfAbsent(holder);
        return holder.keySha;
    }

    /**
     * Returns the SHA of the active key, seeding the cache or rotating an expired key first.
     */
    public KeySha activeKeySha() {
        // Use the key validated by this call rather than re-reading the cache, which a
        // concurrent clear() could have emptied in the meantime.
        return this.maybeRotate().keySha;
    }

    /** Drops every cached key, including the active one. */
    public void clear() {
        this.cache.clear();
    }

    /** Encodes a key holder into the object-metadata fields parsed by {@link #parseKeyFromObjectMetadata}. */
    private static HashMap<String, String> keyToObjectMetadata(DataEncryptionKeyHolder holder) {
        HashMap<String, String> metadata = new HashMap<>();
        metadata.put(METADATA_SHA_KEY, holder.keySha.base64Encoded());
        holder.keyCreationTimeOpt.ifPresent(instant -> metadata.put(METADATA_KEY_CREATION_TIME, Long.toString(instant.toEpochMilli())));
        metadata.put(METADATA_DATA_KEY, holder.encryptedDataKey.base64Encoded());
        return metadata;
    }

    /** Returns the active key's SHA wrapped as a BYOK key holder (with a {@code null} second field). */
    public TierObjectStore.ByokKeyHolder getActiveKey() {
        KeySha active = this.activeKeySha();
        OpaqueData opaqueData = OpaqueData.fromByteArray(active.toRawBytes());
        return new TierObjectStore.ByokKeyHolder(opaqueData, null);
    }

    /** Returns the value of {@code field}, failing fatally if it is absent or empty. */
    private static String requireMetadataField(Map<String, String> metadata, String field) {
        String value = metadata.get(field);
        if (value == null || value.isEmpty()) {
            throw new TierObjectStoreFatalException(String.format("%s metadata field not present", field));
        }
        return value;
    }

    /**
     * Decodes and unwraps a data key from object metadata, verifying that the unwrapped key's
     * SHA matches the SHA recorded in the metadata.
     *
     * @throws TierObjectStoreFatalException     if a field is missing/malformed or the SHAs differ
     * @throws TierObjectStoreRetriableException if unwrapping fails with an I/O or security error
     */
    private DataEncryptionKeyHolder parseKeyFromObjectMetadata(Map<String, String> metadata) {
        String keySha = requireMetadataField(metadata, METADATA_SHA_KEY);
        String encryptedDataKey = requireMetadataField(metadata, METADATA_DATA_KEY);
        String keyCreationTime = requireMetadataField(metadata, METADATA_KEY_CREATION_TIME);

        KeySha parsedKeySha = KeySha.fromBase64Encoded(keySha);
        EncryptedDataKey parsedEncryptedDataKey = EncryptedDataKey.fromBase64Encoded(encryptedDataKey);
        Instant parsedKeyCreationTime;
        try {
            parsedKeyCreationTime = Instant.ofEpochMilli(Long.parseLong(keyCreationTime));
        } catch (NumberFormatException e) {
            // Report malformed metadata the same way as missing metadata, with context.
            throw new TierObjectStoreFatalException(String.format("%s metadata field '%s' is not a valid epoch-millisecond timestamp", METADATA_KEY_CREATION_TIME, keyCreationTime), e);
        }

        DataEncryptionKeyHolder decoded;
        try {
            decoded = this.decryptDek(parsedEncryptedDataKey, parsedKeyCreationTime);
        } catch (IOException | GeneralSecurityException e) {
            throw new TierObjectStoreRetriableException("Failed to decrypt data encryption key from object metadata", e);
        }
        if (!decoded.keySha.equals(parsedKeySha)) {
            throw new TierObjectStoreFatalException(String.format("KeySha parsed from object metadata '%s' does not match decoded KeySha '%s'", parsedKeySha, decoded.keySha));
        }
        return decoded;
    }

    /**
     * Ensures there is a non-expired active key, seeding the cache or rotating as needed.
     *
     * @return the holder of the active key established by this call
     * @throws TierObjectStoreFatalException if the active key has no creation time
     */
    private synchronized DataEncryptionKeyHolder maybeRotate() {
        // Fetch sha + holder in one cache operation so a concurrent clear() cannot leave us
        // holding a KeySha whose holder has already been evicted.
        DataEncryptionKeyHolder holder = this.cache.activeKeyHolder();
        if (holder == null) {
            log.info("No active key found, seeding key cache");
            holder = this.forceRotate();
        }
        KeySha active = holder.keySha;
        if (!holder.keyCreationTimeOpt.isPresent()) {
            throw new TierObjectStoreFatalException(String.format("Key corresponding to %s has not been checked for rotation since no corresponding creation time has been found.", active));
        }
        Instant creationTime = holder.keyCreationTimeOpt.get();
        Instant timeNow = Instant.ofEpochMilli(this.time.milliseconds());
        if (timeNow.isAfter(creationTime.plus(this.keyRefreshInterval))) {
            log.info("Key corresponding to {} created at {} has expired determined by the refresh interval {}, seeding key cache", active, creationTime, this.keyRefreshInterval);
            holder = this.forceRotate();
        }
        return holder;
    }

    /** Reads the well-known-path metadata via the hook, or returns {@code null} if no hook is bound. */
    private Map<String, String> fetchWellKnownPathMetadata() {
        WellKnownKeypathHook hook = this.wellKnownKeypathHook;
        return hook != null ? hook.fetchWellKnownPathMetadata() : null;
    }

    /** Writes metadata via the hook; a no-op if no hook is bound. */
    private void writeWellKnownPathMetadata(Map<String, String> metadata) {
        WellKnownKeypathHook hook = this.wellKnownKeypathHook;
        if (hook != null) {
            hook.writeWellKnownPathMetadata(metadata);
        }
    }

    /**
     * Installs a new active key: a still-fresh key recovered from the well-known keypath if one
     * exists, otherwise a freshly generated key whose metadata is then written to the keypath.
     *
     * @return the holder that was made active
     */
    private DataEncryptionKeyHolder forceRotate() {
        Map<String, String> metadata = this.fetchWellKnownPathMetadata();
        if (metadata != null && !metadata.isEmpty()) {
            DataEncryptionKeyHolder recovered = this.parseKeyFromObjectMetadata(metadata);
            if (!recovered.keyCreationTimeOpt.isPresent()) {
                throw new TierObjectStoreFatalException(String.format("Key %s has not been rotated since no corresponding creation time has been found.", recovered.keySha));
            }
            Instant creationTime = recovered.keyCreationTimeOpt.get();
            Instant timeNow = Instant.ofEpochMilli(this.time.milliseconds());
            Instant deadline = creationTime.plus(this.keyRefreshInterval);
            log.info("Recovered previously written key {} created at {} from the well-known keypath", recovered.keySha, creationTime);
            if (timeNow.isBefore(deadline)) {
                log.info("Using key {} as the active key", recovered.keySha);
                if (this.metrics != null) {
                    this.metrics.updateActiveKeyCreationTime(creationTime);
                }
                this.cache.replaceActiveKeySha(recovered);
                return recovered;
            }
            log.info("Key {} recovered from the well-known keypath is too old to use as the active key", recovered.keySha);
            // Still cache it (as a non-active key) so keyContext() can resolve it.
            this.cache.add(recovered);
        }

        log.info("Unable to restore a valid active key from the well-known keypath, generating a new one");
        DataEncryptionKeyHolder generated;
        try {
            generated = this.generateNewDek();
        } catch (IOException e) {
            throw new TierObjectStoreRetriableException("Failed to generate data encryption key for rotation", e);
        } catch (GeneralSecurityException e) {
            throw new TierObjectStoreFatalException("Failed to generate data encryption key for rotation", e);
        }
        log.info("Using key {} as the active key", generated.keySha);
        if (this.metrics != null && generated.keyCreationTimeOpt.isPresent()) {
            this.metrics.updateActiveKeyCreationTime(generated.keyCreationTimeOpt.get());
        }
        this.cache.replaceActiveKeySha(generated);
        log.info("Writing out newly generated key {} to the well-known key path", generated.keySha);
        this.writeWellKnownPathMetadata(keyToObjectMetadata(generated));
        return generated;
    }

    /**
     * Generates a new DEK, wraps it with the remote KEK (recording the call latency), and
     * returns it with the current time as its creation time.
     *
     * @throws TierObjectStoreRetriableException if any exception occurs while wrapping the key
     */
    private DataEncryptionKeyHolder generateNewDek() throws GeneralSecurityException, IOException {
        KeysetHandle dekKeysetHandle = KeysetHandle.generateNew(KeyTemplates.get(DATA_KEY_TEMPLATE));
        ByteBufferOutputStream out = new ByteBufferOutputStream(256);
        try {
            KeysetWriter writer = JsonKeysetWriter.withOutputStream(out);
            long encryptStart = this.time.hiResClockMs();
            dekKeysetHandle.write(writer, this.remoteKek);
            if (this.metrics != null) {
                this.metrics.recordEncryptCall(this.time.hiResClockMs() - encryptStart);
            }
        } catch (Exception e) {
            throw new TierObjectStoreRetriableException("Exception trying to encrypt key using master key", e);
        }
        // The stream's buffer position marks the end of the written bytes; flip to read them back.
        ByteBuffer edekBuf = out.buffer();
        edekBuf.flip();
        byte[] encryptedDataKeyArr = new byte[edekBuf.remaining()];
        edekBuf.get(encryptedDataKeyArr);
        EncryptedDataKey encryptedDataKey = new EncryptedDataKey(encryptedDataKeyArr);
        Keyset cleartext = CleartextKeysetHandle.getKeyset(dekKeysetHandle);
        CleartextDataKey cleartextDataKey = new CleartextDataKey(Util.extractRawAes256GCMKey(cleartext));
        return new DataEncryptionKeyHolder(encryptedDataKey, cleartextDataKey, Optional.of(Instant.ofEpochMilli(this.time.milliseconds())));
    }

    /**
     * Unwraps a KEK-wrapped DEK (recording the call latency) and returns it with the given
     * creation time.
     */
    private DataEncryptionKeyHolder decryptDek(EncryptedDataKey encryptedDataKey, Instant keyCreationTime) throws IOException, GeneralSecurityException {
        ByteBufferInputStream in = new ByteBufferInputStream(ByteBuffer.wrap(encryptedDataKey.keyMaterial()));
        JsonKeysetReader reader = JsonKeysetReader.withInputStream(in);
        long decryptStart = this.time.hiResClockMs();
        KeysetHandle keyset = KeysetHandle.read(reader, this.remoteKek);
        if (this.metrics != null) {
            this.metrics.recordDecryptCall(this.time.hiResClockMs() - decryptStart);
        }
        Keyset cleartext = CleartextKeysetHandle.getKeyset(keyset);
        CleartextDataKey cleartextDataKey = new CleartextDataKey(Util.extractRawAes256GCMKey(cleartext));
        return new DataEncryptionKeyHolder(encryptedDataKey, cleartextDataKey, Optional.of(keyCreationTime));
    }

    /** Synchronized map of key holders by SHA, plus the SHA of the active key. */
    private static class KeyCache {
        private KeySha active;
        private final HashMap<KeySha, DataEncryptionKeyHolder> cache = new HashMap<>();

        private KeyCache() {
        }

        /** Caches {@code keyHolder}, replacing any holder with the same SHA. */
        synchronized void add(DataEncryptionKeyHolder keyHolder) {
            this.cache.put(keyHolder.keySha, keyHolder);
        }

        /** Caches {@code keyHolder} only if no holder with the same SHA is cached. */
        synchronized void addIfAbsent(DataEncryptionKeyHolder keyHolder) {
            this.cache.putIfAbsent(keyHolder.keySha, keyHolder);
        }

        synchronized KeySha activeKeySha() {
            return this.active;
        }

        /** Returns the active key's holder, or {@code null} if there is no active key. */
        synchronized DataEncryptionKeyHolder activeKeyHolder() {
            return this.active == null ? null : this.cache.get(this.active);
        }

        /** Caches {@code keyHolder} and marks it as the active key. */
        synchronized void replaceActiveKeySha(DataEncryptionKeyHolder keyHolder) {
            this.active = keyHolder.keySha;
            this.cache.put(keyHolder.keySha, keyHolder);
        }

        synchronized DataEncryptionKeyHolder get(KeySha keySha) {
            return this.cache.get(keySha);
        }

        /** Removes all keys and unsets the active key. */
        synchronized void clear() {
            this.active = null;
            this.cache.clear();
        }
    }

    /**
     * Callbacks used during rotation to recover and persist the active key's object metadata at a
     * well-known location. The storage location is up to the implementation.
     */
    public interface WellKnownKeypathHook {
        /** Persists the metadata of a newly generated active key. */
        void writeWellKnownPathMetadata(Map<String, String> metadata);

        /** Returns previously persisted key metadata, or {@code null}/empty if there is none. */
        Map<String, String> fetchWellKnownPathMetadata();
    }
}

