/*
 * Decompiled with CFR 0.152.
 */
package io.confluent.kafka.replication.push.buffer;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.RemovalNotification;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.LongAdder;
import java.util.function.Consumer;
import java.util.function.Function;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Tracks an estimated number of bytes held by reference-counted keys and
 * refuses to track new keys once a configured limit would be exceeded.
 *
 * <p>Each tracked key is associated with an {@link AtomicInteger} reference
 * count stored in a Guava {@link Cache} built with {@code weakKeys()}. Per the
 * Guava documentation, a weak-keyed cache compares keys by identity
 * ({@code ==}) rather than {@code equals}, and an entry may be removed once
 * its key has been garbage collected.
 *
 * <p>NOTE(review): this class only lowers {@link #totalBytes()} in
 * {@link #countDown(Object)} and {@link #close()}. When an entry disappears
 * because its key was garbage collected, nothing here subtracts its size;
 * confirm whether the supplied removal callback is expected to compensate.
 *
 * @param <K> type of the objects whose memory is being tracked
 */
public final class RefCountingMemoryTracker<K> {
    private static final Logger log = LoggerFactory.getLogger(RefCountingMemoryTracker.class);
    /** Estimates the size in bytes of a key; invoked both when tracking starts and when it ends. */
    private final Function<K, Integer> sizeEstimator;
    /** Upper bound on {@link #totalBytes} above which {@link #initCount} rejects new keys. */
    private final long maxTrackedMemory;
    /** Running sum of the estimated sizes of all currently tracked keys. */
    private final LongAdder totalBytes;
    /** Maps each tracked key (weakly referenced) to its remaining reference count. */
    private final Cache<K, AtomicInteger> tracker;

    /**
     * Creates a tracker.
     *
     * @param sizeEstimator      returns the estimated size in bytes of a key
     * @param maxTrackedMemory   maximum total of estimated bytes that may be tracked
     * @param expirationCallback registered as the cache's removal listener; it receives a
     *                           notification for every entry removed from the cache, which
     *                           includes the explicit invalidations issued by
     *                           {@link #countDown(Object)} and {@link #close()}
     */
    public RefCountingMemoryTracker(Function<K, Integer> sizeEstimator, long maxTrackedMemory, Consumer<RemovalNotification<K, AtomicInteger>> expirationCallback) {
        this.sizeEstimator = sizeEstimator;
        this.maxTrackedMemory = maxTrackedMemory;
        this.totalBytes = new LongAdder();
        // Fully parameterized builder (no raw type) so that build() is type-checked.
        CacheBuilder<K, AtomicInteger> trackerBuilder = CacheBuilder.newBuilder().weakKeys().removalListener(expirationCallback::accept);
        this.tracker = trackerBuilder.build();
    }

    /**
     * Starts tracking {@code memorySized} with the given initial reference count.
     *
     * <p>The limit check and the subsequent update of the byte total are separate
     * steps, so concurrent callers may collectively overshoot the limit.
     *
     * @param memorySized the object to track
     * @param refCount    initial reference count; must be positive
     * @return {@code true} if the key is now tracked; {@code false} if adding its
     *         estimated size would push the total above the configured maximum
     * @throws IllegalArgumentException if {@code refCount} is not positive or the key
     *                                  is already being tracked
     */
    public boolean initCount(K memorySized, int refCount) {
        if (refCount <= 0) {
            throw new IllegalArgumentException("Ref count must be positive (was " + refCount + ")");
        }
        int sizeInBytes = this.sizeEstimator.apply(memorySized);
        if (this.totalBytes.longValue() + (long) sizeInBytes > this.maxTrackedMemory) {
            return false;
        }
        AtomicInteger counter = new AtomicInteger(refCount);
        // putIfAbsent returns the previous mapping, so a non-null result means the key was
        // already tracked and the new counter was not installed.
        if (this.tracker.asMap().putIfAbsent(memorySized, counter) != null) {
            throw new IllegalArgumentException("The given key " + memorySized + " is already being tracked");
        }
        this.totalBytes.add(sizeInBytes);
        return true;
    }

    /**
     * Decrements the reference count of {@code memorySized}. When the count reaches
     * zero the key stops being tracked and its estimated size is subtracted from the
     * byte total.
     *
     * <p>The size is re-estimated at release time. NOTE(review): if the estimator can
     * return a different value than it did in {@link #initCount(Object, int)}, the
     * byte total will drift; confirm the estimator is stable for a given key.
     *
     * @param memorySized the tracked object
     * @return the remaining reference count (never negative), or {@code 0} if the key
     *         is not tracked
     * @throws NullPointerException if {@code memorySized} is {@code null}
     */
    public int countDown(K memorySized) {
        if (memorySized == null) {
            throw new NullPointerException("The given key must not be null");
        }
        AtomicInteger refCount = this.tracker.getIfPresent(memorySized);
        if (refCount == null) {
            return 0;
        }
        int count = refCount.decrementAndGet();
        // Release only on the exact transition to zero: a caller that raced in and drove the
        // shared counter negative must not subtract the key's size a second time.
        if (count == 0) {
            this.tracker.invalidate(memorySized);
            // Negate as a long so the sign flip cannot overflow int.
            this.totalBytes.add(-this.sizeEstimator.apply(memorySized).longValue());
        }
        return Math.max(count, 0);
    }

    /**
     * Stops tracking every key and resets the byte total to zero.
     */
    public void close() {
        this.tracker.invalidateAll();
        this.totalBytes.reset();
    }

    /**
     * Returns the current sum of estimated sizes of all tracked keys.
     *
     * @return the tracked byte total
     */
    public long totalBytes() {
        return this.totalBytes.longValue();
    }
}

