/*
 * Decompiled with CFR 0.152.
 */
package io.confluent.kafka.schemaregistry.encryption;

import com.google.crypto.tink.Aead;
import com.google.crypto.tink.KmsClient;
import com.google.crypto.tink.proto.AesGcmKey;
import com.google.crypto.tink.proto.AesSivKey;
import com.google.protobuf.ByteString;
import io.confluent.dekregistry.client.CachedDekRegistryClient;
import io.confluent.dekregistry.client.DekRegistryClient;
import io.confluent.dekregistry.client.DekRegistryClientFactory;
import io.confluent.dekregistry.client.rest.entities.Dek;
import io.confluent.dekregistry.client.rest.entities.Kek;
import io.confluent.kafka.schemaregistry.client.rest.entities.RuleMode;
import io.confluent.kafka.schemaregistry.client.rest.exceptions.RestClientException;
import io.confluent.kafka.schemaregistry.encryption.tink.Cryptor;
import io.confluent.kafka.schemaregistry.encryption.tink.DekFormat;
import io.confluent.kafka.schemaregistry.encryption.tink.KmsDriverManager;
import io.confluent.kafka.schemaregistry.rules.FieldRuleExecutor;
import io.confluent.kafka.schemaregistry.rules.FieldTransform;
import io.confluent.kafka.schemaregistry.rules.RuleContext;
import io.confluent.kafka.schemaregistry.rules.RuleException;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.time.Clock;
import java.util.AbstractMap;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.kafka.common.config.ConfigException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Field-level rule executor (rule type {@value #TYPE}) that encrypts field values on write and
 * decrypts them on read.
 *
 * <p>Field values are encrypted with a data encryption key (DEK) that is stored in the DEK
 * registry in encrypted form. The DEK is wrapped and unwrapped with an {@link Aead} obtained from
 * a {@link KmsClient} for the key encryption key's (KEK's) KMS type and key id. When
 * {@value #ENCRYPT_DEK_EXPIRY_DAYS} is positive, DEKs are rotated and every ciphertext is prefixed
 * with {@link #MAGIC_BYTE} followed by the 4-byte DEK version that produced it.
 */
public class FieldEncryptionExecutor extends FieldRuleExecutor {
    private static final Logger log = LoggerFactory.getLogger(FieldEncryptionExecutor.class);

    public static final String TYPE = "ENCRYPT";
    public static final String ENCRYPT_KEK_NAME = "encrypt.kek.name";
    public static final String ENCRYPT_KMS_KEY_ID = "encrypt.kms.key.id";
    public static final String ENCRYPT_KMS_TYPE = "encrypt.kms.type";
    public static final String ENCRYPT_DEK_ALGORITHM = "encrypt.dek.algorithm";
    public static final String ENCRYPT_DEK_EXPIRY_DAYS = "encrypt.dek.expiry.days";
    /** Separator between the KMS type and the KMS key id when building a KEK URL. */
    public static final String KMS_TYPE_SUFFIX = "://";
    /** Associated data used for every AEAD operation in this executor (none). */
    public static final byte[] EMPTY_AAD = new byte[0];
    public static final String CACHE_EXPIRY_SECS = "cache.expiry.secs";
    public static final String CACHE_SIZE = "cache.size";
    public static final String CLOCK = "clock";

    /** DEK version sentinel meaning "the latest version". */
    protected static final int LATEST_VERSION = -1;
    /** Leading byte of a version-prefixed ciphertext. */
    protected static final byte MAGIC_BYTE = 0;
    protected static final int MILLIS_IN_DAY = 24 * 60 * 60 * 1000;
    /** Size in bytes of the DEK version written after {@link #MAGIC_BYTE}. */
    protected static final int VERSION_SIZE = 4;

    private Map<DekFormat, Cryptor> cryptors;
    private Map<String, ?> configs;
    private int cacheExpirySecs = -1;
    private int cacheSize = 10000;
    private Clock clock = Clock.systemUTC();
    private DekRegistryClient client;

    /**
     * Returns {@code true}.
     *
     * @return always {@code true}
     */
    public boolean addOriginalConfigs() {
        return true;
    }

    /**
     * Configures the executor and creates the DEK registry client.
     *
     * @param configs executor configuration. {@value #CACHE_EXPIRY_SECS} and {@value #CACHE_SIZE}
     *     are parsed as integers. {@value #CLOCK} may hold a {@link Clock} instance, which is
     *     ignored if it has any other type. {@code schema.registry.url} is required and may
     *     contain several comma-separated URLs.
     * @throws ConfigException if a numeric setting cannot be parsed or the registry URL is missing
     */
    public void configure(Map<String, ?> configs) {
        super.configure(configs);
        this.configs = configs;

        Object cacheExpirySecsConfig = configs.get(CACHE_EXPIRY_SECS);
        if (cacheExpirySecsConfig != null) {
            try {
                this.cacheExpirySecs = Integer.parseInt(cacheExpirySecsConfig.toString());
            } catch (NumberFormatException e) {
                throw new ConfigException("Cannot parse cache.expiry.secs");
            }
        }

        Object cacheSizeConfig = configs.get(CACHE_SIZE);
        if (cacheSizeConfig != null) {
            try {
                this.cacheSize = Integer.parseInt(cacheSizeConfig.toString());
            } catch (NumberFormatException e) {
                throw new ConfigException("Cannot parse cache.size");
            }
        }

        Object clockConfig = configs.get(CLOCK);
        if (clockConfig instanceof Clock) {
            this.clock = (Clock) clockConfig;
        }

        Object url = configs.get("schema.registry.url");
        if (url == null) {
            throw new ConfigException("Missing schema registry url!");
        }
        List<String> baseUrls = Arrays.asList(url.toString().split("\\s*,\\s*"));
        this.client = DekRegistryClientFactory.newClient(
            baseUrls, this.cacheSize, this.cacheExpirySecs, configs, Collections.emptyMap());
        this.cryptors = new ConcurrentHashMap<DekFormat, Cryptor>();
    }

    /**
     * Returns the rule type handled by this executor.
     *
     * @return {@value #TYPE}
     */
    public String type() {
        return TYPE;
    }

    /**
     * Creates and initializes a transform for the given rule context.
     *
     * @param ctx the rule context whose parameters select the KEK, algorithm and DEK expiry
     * @return an initialized transform
     * @throws RuleException if the KEK name is invalid, or the KEK cannot be found or created
     */
    public FieldEncryptionExecutorTransform newTransform(RuleContext ctx) throws RuleException {
        FieldEncryptionExecutorTransform transform = new FieldEncryptionExecutorTransform();
        transform.init(ctx);
        return transform;
    }

    /**
     * Resolves the cryptor for the algorithm named by {@value #ENCRYPT_DEK_ALGORITHM}. Defaults to
     * {@code AES256_GCM} when the parameter is absent or empty. An unknown name makes
     * {@code DekFormat.valueOf} throw {@link IllegalArgumentException}.
     */
    private Cryptor getCryptor(RuleContext ctx) {
        String algorithm = ctx.getParameter(ENCRYPT_DEK_ALGORITHM);
        DekFormat dekFormat = algorithm != null && !algorithm.isEmpty()
            ? DekFormat.valueOf(algorithm)
            : DekFormat.AES256_GCM;
        return getCryptor(dekFormat);
    }

    /** Returns the cached cryptor for {@code dekFormat}, creating it on first use. */
    private Cryptor getCryptor(DekFormat dekFormat) {
        return this.cryptors.computeIfAbsent(dekFormat, format -> {
            try {
                return new Cryptor(format);
            } catch (GeneralSecurityException e) {
                throw new IllegalArgumentException("Invalid format " + format, e);
            }
        });
    }

    /**
     * Returns the live, mutable cache of cryptors keyed by DEK format. The map is {@code null}
     * until {@link #configure(Map)} has been called.
     *
     * @return the internal cryptor cache
     */
    public Map<DekFormat, Cryptor> getCryptors() {
        return this.cryptors;
    }

    /**
     * Produces serialized key material for a new DEK. If {@link #generateDek(DekFormat)} returns
     * raw key bytes, they are wrapped in the Tink key proto matching the format. Otherwise the
     * cryptor generates the key.
     */
    private byte[] generateKey(DekFormat dekFormat) throws GeneralSecurityException {
        byte[] dek = generateDek(dekFormat);
        if (dek == null) {
            return getCryptor(dekFormat).generateKey();
        }
        switch (dekFormat) {
            case AES128_GCM:
            case AES256_GCM:
                return AesGcmKey.newBuilder().setKeyValue(ByteString.copyFrom(dek)).build().toByteArray();
            case AES256_SIV:
                return AesSivKey.newBuilder().setKeyValue(ByteString.copyFrom(dek)).build().toByteArray();
            default:
                throw new IllegalArgumentException("Invalid format " + dekFormat);
        }
    }

    /**
     * Hook that lets subclasses supply raw DEK bytes, for example a fixed key in tests.
     *
     * @param dekFormat the format of the DEK being created
     * @return raw key bytes, or {@code null} (the default) to let the cryptor generate a key
     * @throws GeneralSecurityException if key generation fails
     */
    protected byte[] generateDek(DekFormat dekFormat) throws GeneralSecurityException {
        return null;
    }

    /**
     * Converts a field value to bytes. Returns {@code null} for types other than BYTES and STRING.
     * A {@link ByteBuffer} contributes only its readable region.
     */
    private static byte[] toBytes(RuleContext.Type type, Object obj) {
        switch (type) {
            case BYTES:
                if (obj instanceof ByteBuffer) {
                    // Copy [position, limit) through a duplicate. array() would expose the whole
                    // backing array, ignoring position, limit and arrayOffset, and it throws for
                    // direct or read-only buffers. The duplicate leaves the caller's position
                    // untouched.
                    ByteBuffer buffer = ((ByteBuffer) obj).duplicate();
                    byte[] bytes = new byte[buffer.remaining()];
                    buffer.get(bytes);
                    return bytes;
                }
                if (obj instanceof ByteString) {
                    return ((ByteString) obj).toByteArray();
                }
                if (obj instanceof byte[]) {
                    return (byte[]) obj;
                }
                throw new IllegalArgumentException(
                    "Unrecognized bytes object of type: " + obj.getClass().getName());
            case STRING:
                return obj.toString().getBytes(StandardCharsets.UTF_8);
            default:
                return null;
        }
    }

    /**
     * Converts bytes back to a field value. BYTES yields a {@code byte[]}, STRING yields a UTF-8
     * decoded string, and any other type yields {@code null}.
     */
    private static Object toObject(RuleContext.Type type, byte[] bytes) {
        switch (type) {
            case BYTES:
                return bytes;
            case STRING:
                return new String(bytes, StandardCharsets.UTF_8);
            default:
                return null;
        }
    }

    /**
     * Closes the DEK registry client, if one was created.
     *
     * @throws RuleException if closing the client fails
     */
    public void close() throws RuleException {
        if (this.client != null) {
            try {
                this.client.close();
            } catch (IOException e) {
                throw new RuleException(e);
            }
        }
    }

    /** Obtains the AEAD used to wrap and unwrap DEKs for {@code kek}. */
    private static Aead getAead(Map<String, ?> configs, Kek kek)
        throws GeneralSecurityException, RuleException {
        String kekUrl = kek.getKmsType() + KMS_TYPE_SUFFIX + kek.getKmsKeyId();
        KmsClient kmsClient = getKmsClient(configs, kekUrl);
        if (kmsClient == null) {
            throw new RuleException("No kms client found for " + kekUrl);
        }
        return kmsClient.getAead(kekUrl);
    }

    /**
     * Looks up a KMS client for {@code kekUrl}. If the lookup throws, registers a new client
     * from {@code configs}.
     */
    private static KmsClient getKmsClient(Map<String, ?> configs, String kekUrl)
        throws GeneralSecurityException {
        try {
            return KmsDriverManager.getDriver(kekUrl).getKmsClient(kekUrl);
        } catch (GeneralSecurityException e) {
            return KmsDriverManager.getDriver(kekUrl).registerKmsClient(configs, Optional.of(kekUrl));
        }
    }

    /**
     * Per-rule transform that encrypts (WRITE) or decrypts (READ) individual field values with the
     * KEK and DEK selected by the rule parameters.
     */
    public class FieldEncryptionExecutorTransform implements FieldTransform {
        private Cryptor cryptor;
        private String kekName;
        private Kek kek;
        private int dekExpiryDays;

        /**
         * Resolves the cryptor, KEK name, KEK and DEK expiry from the rule parameters.
         *
         * @param ctx the rule context
         * @throws RuleException if a parameter is invalid or the KEK cannot be found or created
         */
        public void init(RuleContext ctx) throws RuleException {
            this.cryptor = FieldEncryptionExecutor.this.getCryptor(ctx);
            this.kekName = getKekName(ctx);
            this.kek = getOrCreateKek(ctx);
            this.dekExpiryDays = getDekExpiryDays(ctx);
        }

        /**
         * Reports whether DEK rotation is enabled. When it is, ciphertexts carry a version prefix.
         *
         * @return {@code true} if {@value FieldEncryptionExecutor#ENCRYPT_DEK_EXPIRY_DAYS} is
         *     positive
         */
        public boolean isDekRotated() {
            return this.dekExpiryDays > 0;
        }

        /**
         * Reads and validates the KEK name. It must be non-empty, start with a letter or
         * {@code _}, and otherwise contain only letters, digits, {@code _} or {@code -}.
         *
         * @param ctx the rule context
         * @return the validated KEK name
         * @throws RuleException if the name is missing, empty or contains an illegal character
         */
        protected String getKekName(RuleContext ctx) throws RuleException {
            String name = ctx.getParameter(ENCRYPT_KEK_NAME);
            if (name == null) {
                throw new RuleException("No kek name found");
            }
            if (name.isEmpty()) {
                throw new RuleException("Empty kek name");
            }
            char first = name.charAt(0);
            if (!Character.isLetter(first) && first != '_') {
                throw new RuleException("Illegal initial character in kek name: " + name);
            }
            for (int i = 1; i < name.length(); i++) {
                char c = name.charAt(i);
                if (!Character.isLetterOrDigit(c) && c != '_' && c != '-') {
                    throw new RuleException("Illegal character in kek name: " + name);
                }
            }
            return name;
        }

        /**
         * Fetches the KEK from the registry. In non-READ modes, creates it if it is missing.
         * Verifies that any KMS type or key id given in the rule matches the stored KEK.
         *
         * @param ctx the rule context
         * @return the KEK
         * @throws RuleException if the KEK is missing on read, required KMS parameters are missing
         *     on create, or the stored KEK conflicts with the rule parameters
         */
        protected Kek getOrCreateKek(RuleContext ctx) throws RuleException {
            boolean isRead = ctx.ruleMode() == RuleMode.READ;
            CachedDekRegistryClient.KekId kekId = new CachedDekRegistryClient.KekId(this.kekName, isRead);
            String kmsType = ctx.getParameter(ENCRYPT_KMS_TYPE);
            String kmsKeyId = ctx.getParameter(ENCRYPT_KMS_KEY_ID);

            Kek found = retrieveKekFromRegistry(kekId);
            if (found == null) {
                if (isRead) {
                    throw new RuleException("No kek found for " + this.kekName + " during consume");
                }
                if (kmsType == null || kmsType.isEmpty()) {
                    throw new RuleException("No kms type found for " + this.kekName + " during produce");
                }
                if (kmsKeyId == null || kmsKeyId.isEmpty()) {
                    throw new RuleException("No kms key id found for " + this.kekName + " during produce");
                }
                found = storeKekToRegistry(kekId, kmsType, kmsKeyId, false);
                if (found == null) {
                    // A 409 from the registry means another client registered it first; fetch theirs.
                    found = retrieveKekFromRegistry(kekId);
                }
                if (found == null) {
                    throw new RuleException("No kek found for " + this.kekName + " during produce");
                }
            }
            if (kmsType != null && !kmsType.isEmpty() && !kmsType.equals(found.getKmsType())) {
                throw new RuleException("Found " + this.kekName + " with kms type '" + found.getKmsType()
                    + "' which differs from rule kms type '" + kmsType + "'");
            }
            if (kmsKeyId != null && !kmsKeyId.isEmpty() && !kmsKeyId.equals(found.getKmsKeyId())) {
                throw new RuleException("Found " + this.kekName + " with kms key id '" + found.getKmsKeyId()
                    + "' which differs from rule kms key id '" + kmsKeyId + "'");
            }
            return found;
        }

        /** Parses the DEK expiry in days. Absent or empty means 0 (no rotation); negatives are rejected. */
        private int getDekExpiryDays(RuleContext ctx) throws RuleException {
            String expiryStr = ctx.getParameter(ENCRYPT_DEK_EXPIRY_DAYS);
            if (expiryStr == null || expiryStr.isEmpty()) {
                return 0;
            }
            int days;
            try {
                days = Integer.parseInt(expiryStr);
            } catch (NumberFormatException e) {
                throw new RuleException("Invalid value for encrypt.dek.expiry.days: " + expiryStr);
            }
            if (days < 0) {
                throw new RuleException("Invalid value for encrypt.dek.expiry.days: " + expiryStr);
            }
            return days;
        }

        /** Returns the KEK, or {@code null} if the registry answers 404. */
        private Kek retrieveKekFromRegistry(CachedDekRegistryClient.KekId key) throws RuleException {
            try {
                return FieldEncryptionExecutor.this.client.getKek(key.getName(), key.isLookupDeleted());
            } catch (RestClientException e) {
                if (e.getStatus() == 404) {
                    return null;
                }
                throw new RuleException("Could not get kek " + key.getName(), e);
            } catch (IOException e) {
                throw new RuleException("Could not get kek " + key.getName(), e);
            }
        }

        /** Registers a KEK. Returns {@code null} if the registry answers 409 (already exists). */
        private Kek storeKekToRegistry(CachedDekRegistryClient.KekId key, String kmsType, String kmsKeyId,
                                       boolean shared) throws RuleException {
            try {
                Kek created = FieldEncryptionExecutor.this.client.createKek(
                    key.getName(), kmsType, kmsKeyId, null, null, shared);
                log.info("Registered kek {}", key.getName());
                return created;
            } catch (RestClientException e) {
                if (e.getStatus() == 409) {
                    return null;
                }
                throw new RuleException("Could not register kek " + key.getName(), e);
            } catch (IOException e) {
                throw new RuleException("Could not register kek " + key.getName(), e);
            }
        }

        /**
         * Returns a DEK whose raw key material is populated, creating or rotating the DEK in non-READ
         * modes when needed.
         *
         * @param ctx the rule context
         * @param version the DEK version, {@link FieldEncryptionExecutor#LATEST_VERSION}, or
         *     {@code null} for an unversioned lookup
         * @return the DEK with key material set
         * @throws RuleException if the DEK is missing on read or cannot be registered or fetched
         * @throws GeneralSecurityException if wrapping or unwrapping the key material fails
         */
        public Dek getOrCreateDek(RuleContext ctx, Integer version)
            throws RuleException, GeneralSecurityException {
            boolean isRead = ctx.ruleMode() == RuleMode.READ;
            CachedDekRegistryClient.DekId dekId = new CachedDekRegistryClient.DekId(
                this.kekName, ctx.subject(), version, this.cryptor.getDekFormat(), isRead);
            Aead aead = null;
            Dek dek = retrieveDekFromRegistry(dekId);
            boolean isExpired = isExpired(ctx, dek);
            if (isExpired) {
                log.info("Dek with ts {} expired after {} day(s)", dek.getTimestamp(), this.dekExpiryDays);
            }
            if (dek == null || isExpired) {
                if (isRead) {
                    throw new RuleException("No dek found for " + this.kekName + " during consume");
                }
                byte[] encryptedDek = null;
                if (!this.kek.isShared()) {
                    // Generate the DEK locally and wrap it with the KEK before storing it.
                    // NOTE(review): for a shared KEK no material is sent; presumably the registry
                    // generates it server-side. TODO confirm.
                    aead = getAead(FieldEncryptionExecutor.this.configs, this.kek);
                    byte[] rawDek = generateKey(dekId.getDekFormat());
                    encryptedDek = aead.encrypt(rawDek, EMPTY_AAD);
                }
                // Rotation registers the next version; otherwise register an unversioned DEK.
                Integer newVersion = isExpired ? Integer.valueOf(dek.getVersion() + 1) : null;
                CachedDekRegistryClient.DekId newDekId = new CachedDekRegistryClient.DekId(
                    this.kekName, ctx.subject(), newVersion, this.cryptor.getDekFormat(), isRead);
                dek = storeDekToRegistry(newDekId, encryptedDek);
                if (dek == null) {
                    // 409 conflict: re-read using the original lookup (null or latest version).
                    dek = retrieveDekFromRegistry(dekId);
                }
                if (dek == null) {
                    throw new RuleException("No dek found for " + this.kekName + " during produce");
                }
            }
            if (dek.getKeyMaterialBytes() == null) {
                if (aead == null) {
                    aead = getAead(FieldEncryptionExecutor.this.configs, this.kek);
                }
                byte[] rawDek = aead.decrypt(dek.getEncryptedKeyMaterialBytes(), EMPTY_AAD);
                dek.setKeyMaterial(rawDek);
            }
            return dek;
        }

        /**
         * A DEK counts as expired only outside READ mode, when rotation is enabled, and when at least
         * {@code dekExpiryDays} whole days (per the configured clock) have passed since its timestamp.
         */
        private boolean isExpired(RuleContext ctx, Dek dek) {
            if (ctx.ruleMode() == RuleMode.READ || this.dekExpiryDays <= 0 || dek == null) {
                return false;
            }
            long ageInDays = (FieldEncryptionExecutor.this.clock.millis() - dek.getTimestamp()) / MILLIS_IN_DAY;
            return ageInDays >= this.dekExpiryDays;
        }

        /**
         * Returns the DEK, or {@code null} if the registry answers 404 or the DEK has no encrypted
         * key material.
         */
        private Dek retrieveDekFromRegistry(CachedDekRegistryClient.DekId key) throws RuleException {
            try {
                Dek found;
                if (key.getVersion() != null) {
                    found = FieldEncryptionExecutor.this.client.getDekVersion(key.getKekName(), key.getSubject(),
                        key.getVersion().intValue(), key.getDekFormat(), key.isLookupDeleted());
                } else {
                    found = FieldEncryptionExecutor.this.client.getDek(key.getKekName(), key.getSubject(),
                        key.getDekFormat(), key.isLookupDeleted());
                }
                return found != null && found.getEncryptedKeyMaterial() != null ? found : null;
            } catch (RestClientException e) {
                if (e.getStatus() == 404) {
                    return null;
                }
                throw new RuleException(
                    "Could not get dek for kek " + key.getKekName() + ", subject " + key.getSubject(), e);
            } catch (IOException e) {
                throw new RuleException(
                    "Could not get dek for kek " + key.getKekName() + ", subject " + key.getSubject(), e);
            }
        }

        /**
         * Registers a DEK. The encrypted material, if any, is sent base64-encoded. Returns
         * {@code null} if the registry answers 409 (already exists).
         */
        private Dek storeDekToRegistry(CachedDekRegistryClient.DekId key, byte[] encryptedDek)
            throws RuleException {
            try {
                String encryptedDekStr = encryptedDek != null
                    ? (String) toObject(RuleContext.Type.STRING, Base64.getEncoder().encode(encryptedDek))
                    : null;
                Dek created;
                if (key.getVersion() != null) {
                    created = FieldEncryptionExecutor.this.client.createDek(key.getKekName(), key.getSubject(),
                        key.getVersion().intValue(), key.getDekFormat(), encryptedDekStr);
                } else {
                    created = FieldEncryptionExecutor.this.client.createDek(key.getKekName(), key.getSubject(),
                        key.getDekFormat(), encryptedDekStr);
                }
                log.info("Registered dek for kek {}, subject {}", key.getKekName(), key.getSubject());
                return created;
            } catch (RestClientException e) {
                if (e.getStatus() == 409) {
                    return null;
                }
                throw new RuleException(
                    "Could not register dek for kek " + key.getKekName() + ", subject " + key.getSubject(), e);
            } catch (IOException e) {
                throw new RuleException(
                    "Could not register dek for kek " + key.getKekName() + ", subject " + key.getSubject(), e);
            }
        }

        /**
         * Encrypts (WRITE) or decrypts (READ) a single field value.
         *
         * <p>On WRITE, BYTES and STRING values are encrypted; STRING ciphertext is base64-encoded.
         * Any other type is rejected. On READ, values of unsupported types are returned unchanged.
         *
         * @param ctx the rule context
         * @param fieldCtx the field being transformed
         * @param fieldValue the field value; {@code null} is returned as {@code null}
         * @return the transformed value
         * @throws RuleException on any failure; {@code RuleException}s raised inside are rethrown as-is
         */
        public Object transform(RuleContext ctx, RuleContext.FieldContext fieldCtx, Object fieldValue)
            throws RuleException {
            if (fieldValue == null) {
                return null;
            }
            try {
                switch (ctx.ruleMode()) {
                    case WRITE: {
                        byte[] plaintext = toBytes(fieldCtx.getType(), fieldValue);
                        if (plaintext == null) {
                            throw new RuleException(
                                "Type '" + fieldCtx.getType() + "' not supported for encryption");
                        }
                        Dek dek = getOrCreateDek(ctx, isDekRotated() ? Integer.valueOf(LATEST_VERSION) : null);
                        byte[] ciphertext = this.cryptor.encrypt(dek.getKeyMaterialBytes(), plaintext, EMPTY_AAD);
                        if (isDekRotated()) {
                            ciphertext = prefixVersion(dek.getVersion(), ciphertext);
                        }
                        if (fieldCtx.getType() == RuleContext.Type.STRING) {
                            ciphertext = Base64.getEncoder().encode(ciphertext);
                        }
                        return toObject(fieldCtx.getType(), ciphertext);
                    }
                    case READ: {
                        byte[] ciphertext = toBytes(fieldCtx.getType(), fieldValue);
                        if (ciphertext == null) {
                            return fieldValue;
                        }
                        if (fieldCtx.getType() == RuleContext.Type.STRING) {
                            ciphertext = Base64.getDecoder().decode(ciphertext);
                        }
                        Integer version = null;
                        if (isDekRotated()) {
                            Map.Entry<Integer, byte[]> kv = extractVersion(ciphertext);
                            version = kv.getKey();
                            ciphertext = kv.getValue();
                        }
                        Dek dek = getOrCreateDek(ctx, version);
                        byte[] plaintext = this.cryptor.decrypt(dek.getKeyMaterialBytes(), ciphertext, EMPTY_AAD);
                        return toObject(fieldCtx.getType(), plaintext);
                    }
                    default:
                        throw new IllegalArgumentException("Unsupported rule mode " + ctx.ruleMode());
                }
            } catch (RuleException e) {
                // Already carries a meaningful message; don't bury it under another wrapper.
                throw e;
            } catch (Exception e) {
                throw new RuleException(e);
            }
        }

        /** Prepends {@link FieldEncryptionExecutor#MAGIC_BYTE} and the big-endian DEK version. */
        private byte[] prefixVersion(int version, byte[] ciphertext) {
            byte[] combined = new byte[1 + VERSION_SIZE + ciphertext.length];
            ByteBuffer.wrap(combined)
                .put(MAGIC_BYTE)
                .putInt(version)
                .put(ciphertext);
            return combined;
        }

        /**
         * Splits a version-prefixed ciphertext into (DEK version, remaining ciphertext).
         *
         * @throws RuleException if the input is too short or does not start with the magic byte
         */
        private Map.Entry<Integer, byte[]> extractVersion(byte[] ciphertext) throws RuleException {
            int headerSize = 1 + VERSION_SIZE;
            if (ciphertext.length < headerSize) {
                throw new RuleException("Invalid ciphertext: too short to contain a dek version prefix");
            }
            ByteBuffer buffer = ByteBuffer.wrap(ciphertext);
            if (buffer.get() != MAGIC_BYTE) {
                throw new RuleException("Unknown magic byte!");
            }
            int version = buffer.getInt();
            byte[] remaining = Arrays.copyOfRange(ciphertext, headerSize, ciphertext.length);
            return new AbstractMap.SimpleEntry<Integer, byte[]>(version, remaining);
        }

        /** No-op: the transform holds no resources of its own. */
        public void close() {
        }
    }
}

