/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kafka.test;

import java.io.File;
import java.io.IOException;
import java.io.PrintStream;
import java.lang.management.ManagementFactory;
import java.lang.management.ThreadInfo;
import java.lang.management.ThreadMXBean;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.kafka.clients.admin.AdminClientUnitTestEnv;
import org.apache.kafka.test.GlobalLeakedThreads;
import org.apache.kafka.test.TestUtils;
import org.junit.platform.engine.TestExecutionResult;
import org.junit.platform.engine.TestSource;
import org.junit.platform.engine.support.descriptor.ClassSource;
import org.junit.platform.engine.support.descriptor.MethodSource;
import org.junit.platform.launcher.TestExecutionListener;
import org.junit.platform.launcher.TestIdentifier;

public class ThreadLeakListener
implements TestExecutionListener {
    public void executionFinished(TestIdentifier testIdentifier, TestExecutionResult testExecutionResult) {
        this.verifyNoThreadLeaks(testIdentifier);
    }

    public void executionSkipped(TestIdentifier testIdentifier, String reason) {
        this.verifyNoThreadLeaks(testIdentifier);
    }

    private void verifyNoThreadLeaks(TestIdentifier testIdentifier) {
        HashSet<String> unexpectedThreadNames = new HashSet<String>();
        unexpectedThreadNames.add("kafka-coordinator-heartbeat-thread");
        unexpectedThreadNames.add(AdminClientUnitTestEnv.kafkaAdminClientNetworkThreadPrefix());
        long waitTimeMs = 130000L;
        List<Thread> leakedThreads = TestUtils.getNewlyLeakedThreads(waitTimeMs, unexpectedThreadNames);
        if (!leakedThreads.isEmpty()) {
            List<Long> leakedThreadIds = leakedThreads.stream().map(Thread::getId).collect(Collectors.toList());
            this.writeThreadDumpToFile(testIdentifier, GlobalLeakedThreads.getLeakedThreads());
            GlobalLeakedThreads.add(leakedThreadIds);
            throw new RuntimeException("Found leaked threads while running " + testIdentifier.getDisplayName());
        }
    }

    private String getFileNameFromTestIdentifier(TestIdentifier testIdentifier) {
        String fileName = "leakedTestStackDump-";
        Optional testSourceOptional = testIdentifier.getSource();
        if (testSourceOptional.isPresent()) {
            TestSource testSource = (TestSource)testSourceOptional.get();
            if (testSource instanceof MethodSource) {
                MethodSource methodSource = (MethodSource)testSource;
                fileName = fileName + methodSource.getClassName() + "-" + methodSource.getMethodName();
                if (!methodSource.getMethodParameterTypes().isEmpty()) {
                    fileName = fileName + "-" + methodSource.getMethodParameterTypes();
                }
            } else if (testSource instanceof ClassSource) {
                ClassSource classSource = (ClassSource)testSource;
                fileName = fileName + classSource.getClassName();
            }
        } else {
            fileName = fileName + "UnknownSource";
        }
        return fileName + ".log";
    }

    private void writeThreadDumpToFile(TestIdentifier testIdentifier, Set<Long> ignored) {
        String fileName = this.getFileNameFromTestIdentifier(testIdentifier);
        File buildDirectory = new File("build/thread-reports");
        if (!buildDirectory.exists()) {
            buildDirectory.mkdirs();
        }
        String filePath = buildDirectory + "/" + fileName;
        try (PrintStream out = new PrintStream(Files.newOutputStream(Paths.get(filePath, new String[0]), new OpenOption[0]));){
            ThreadInfo[] threadInfos;
            ThreadMXBean threadMxBean = ManagementFactory.getThreadMXBean();
            for (ThreadInfo threadInfo : threadInfos = threadMxBean.dumpAllThreads(true, true)) {
                if (ignored.contains(threadInfo.getThreadId())) continue;
                out.println("Thread: " + threadInfo.getThreadId() + " - " + threadInfo.getThreadName());
                for (StackTraceElement stackTraceElement : threadInfo.getStackTrace()) {
                    out.println("\t" + stackTraceElement.toString());
                }
                out.println();
            }
        }
        catch (IOException e) {
            e.printStackTrace();
        }
    }
}

