/*
 * Decompiled with CFR 0.152.
 */
package me.modmuss50.optifabric.patcher;

import com.google.common.collect.Sets;
import com.google.common.util.concurrent.Runnables;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
import java.util.stream.Collector;
import java.util.stream.Collectors;
import me.modmuss50.optifabric.patcher.ASMUtils;
import me.modmuss50.optifabric.patcher.Lambda;
import me.modmuss50.optifabric.patcher.MethodComparison;
import net.fabricmc.tinyremapper.IMappingProvider;
import org.objectweb.asm.tree.ClassNode;
import org.objectweb.asm.tree.MethodNode;

/**
 * Pairs synthetic lambda methods that exist only in OptiFine's patched classes with lambda methods that
 * exist only in the vanilla client classes, and exposes each pairing as a method mapping.
 *
 * <p>For every class in the OptiFine jar (other than those under {@code net/}, {@code optifine/} and
 * {@code javax/}) the vanilla and patched method lists are diffed by name + descriptor. Patched-only
 * methods that are synthetic and named {@code lambda$...} are matched against vanilla-only methods by
 * walking the lambdas that {@link MethodComparison} reports for methods present in both versions. Each
 * match is recorded as "patched lambda name -> vanilla lambda name" and handed out through
 * {@link #load(IMappingProvider.MappingAcceptor)}.
 */
public class LambdaRebuiler implements IMappingProvider {
    /** ASM's {@code ACC_SYNTHETIC} access flag. */
    private static final int ACC_SYNTHETIC = 0x1000;

    private final File optifineFile;
    private final File minecraftClientFile;
    /** Patched lambda method (owner, name, desc) -> the vanilla name it is to be renamed to. */
    private final Map<IMappingProvider.Member, String> fixes = new HashMap<>();

    /**
     * @param optifineFile        jar containing OptiFine's patched classes
     * @param minecraftClientFile the vanilla client jar to compare against
     * @throws IOException never thrown here; kept in the signature so existing callers still compile
     */
    public LambdaRebuiler(File optifineFile, File minecraftClientFile) throws IOException {
        this.optifineFile = optifineFile;
        this.minecraftClientFile = minecraftClientFile;
    }

    /**
     * Diffs every relevant class in the OptiFine jar against its vanilla counterpart and records the
     * lambda renames found. {@link #load(IMappingProvider.MappingAcceptor)} only reports what this
     * method has collected.
     *
     * @throws IOException           if either jar cannot be read
     * @throws IllegalStateException if an OptiFine class has no vanilla counterpart, or the lambdas
     *                               cannot be paired consistently
     * @throws RuntimeException      if the vanilla and OptiFine class names disagree
     */
    public void buildLambadaMap() throws IOException {
        try (JarFile optifineJar = new JarFile(this.optifineFile);
             JarFile clientJar = new JarFile(this.minecraftClientFile)) {
            Enumeration<JarEntry> entries = optifineJar.entries();
            while (entries.hasMoreElements()) {
                JarEntry entry = entries.nextElement();
                String name = entry.getName();
                // Only compare classes; OptiFine's own packages have nothing to diff against
                if (!name.endsWith(".class") || name.startsWith("net/") || name.startsWith("optifine/") || name.startsWith("javax/")) {
                    continue;
                }
                ClassNode classNode = ASMUtils.asClassNode(entry, optifineJar);
                JarEntry minecraftEntry = clientJar.getJarEntry(name);
                if (minecraftEntry == null) {
                    throw new IllegalStateException("OptiFine class " + name + " has no counterpart in " + this.minecraftClientFile);
                }
                ClassNode minecraftClass = ASMUtils.asClassNode(minecraftEntry, clientJar);
                if (!minecraftClass.name.equals(classNode.name)) {
                    throw new RuntimeException("Something went wrong: vanilla entry " + name + " holds " + minecraftClass.name + " but OptiFine's holds " + classNode.name);
                }
                this.findLambdas(minecraftClass.name, minecraftClass.methods, classNode.methods);
            }
        }
    }

    /**
     * Diffs one class and pairs its lost (vanilla-only) methods with its gained (patched-only) lambdas.
     *
     * @param className internal name of the class being compared
     * @param original  the vanilla class's methods
     * @param patched   OptiFine's methods for the same class; paired lambdas are renamed in place
     * @return {@code true} when there is nothing to pair or a method's lambdas were all paired,
     *         otherwise whether every gained lambda was matched
     */
    private boolean findLambdas(String className, List<MethodNode> original, List<MethodNode> patched) {
        List<MethodComparison> commonMethods = new ArrayList<>();
        List<MethodNode> lostMethods = new ArrayList<>();
        List<MethodNode> gainedMethods = new ArrayList<>();

        Map<String, MethodNode> originalMethods = byNameAndDesc(original);
        Map<String, MethodNode> patchedMethods = byNameAndDesc(patched);
        for (String methodName : Sets.union(originalMethods.keySet(), patchedMethods.keySet())) {
            MethodNode originalMethod = originalMethods.get(methodName);
            MethodNode patchedMethod = patchedMethods.get(methodName);
            if (originalMethod != null) {
                if (patchedMethod != null) {
                    commonMethods.add(new MethodComparison(originalMethod, patchedMethod));
                } else {
                    lostMethods.add(originalMethod);
                }
            } else if (patchedMethod != null) {
                gainedMethods.add(patchedMethod);
            } else {
                throw new IllegalStateException("Unable to find " + methodName + " in either " + className + " versions");
            }
        }

        // Walk shared methods in patched-class order; <clinit> goes first, except in GLX where it goes last
        commonMethods.sort(Comparator.comparingInt(method -> {
            if (!"<clinit>".equals(method.node.name)) {
                return patched.indexOf(method.node);
            }
            return "com/mojang/blaze3d/platform/GLX".equals(className) ? patched.size() : -1;
        }));
        lostMethods.sort(Comparator.comparingInt(original::indexOf));
        gainedMethods.sort(Comparator.comparingInt(patched::indexOf));

        if (commonMethods.stream().noneMatch(method -> !method.equal && method.hasLambdas()) || lostMethods.isEmpty() || gainedMethods.isEmpty()) {
            return true;
        }

        List<MethodNode> gainedLambdas = gainedMethods.stream()
                .filter(method -> (method.access & ACC_SYNTHETIC) != 0 && method.name.startsWith("lambda$"))
                .collect(Collectors.toList());
        if (gainedLambdas.isEmpty()) {
            return true;
        }

        Map<String, MethodNode> possibleLambdas = byNameAndDesc(gainedLambdas);
        Map<String, MethodNode> nameToLosses = byNameAndDesc(lostMethods);
        // Run after every successful pairing so the comparison addFix appended is resolved immediately
        Runnable onPair = () -> this.resolveAddedCloseMethods(className, commonMethods, lostMethods, gainedMethods, nameToLosses, possibleLambdas);

        // Index loops on purpose: commonMethods grows while being walked (addFix appends to it)
        for (int i = 0; i < commonMethods.size(); i++) {
            MethodComparison method = commonMethods.get(i);
            if (method.effectivelyEqual) {
                this.resolveCloseMethod(className, commonMethods, lostMethods, gainedMethods, method, nameToLosses, possibleLambdas);
            }
        }

        for (int i = 0; i < commonMethods.size(); i++) {
            MethodComparison method = commonMethods.get(i);
            if (method.effectivelyEqual) {
                continue;
            }
            List<Lambda> originalLambdas = method.getOriginalLambads();
            List<Lambda> patchedLambdas = method.getPatchedLambads();

            // Same lambdas in the same order: pair them positionally
            if (sameLambdaTargets(originalLambdas, patchedLambdas)) {
                this.pairUp(className, commonMethods, lostMethods, gainedMethods, originalLambdas, patchedLambdas, nameToLosses, possibleLambdas, onPair);
                continue;
            }

            // Otherwise group by desc then method, and pair only groups whose sizes agree
            Map<String, Map<String, List<Lambda>>> descToOriginalLambda = categorise(originalLambdas);
            Map<String, Map<String, List<Lambda>>> descToPatchedLambda = categorise(patchedLambdas);
            Sets.SetView<String> commonDescs = Sets.intersection(descToOriginalLambda.keySet(), descToPatchedLambda.keySet());
            if (commonDescs.isEmpty()) {
                continue;
            }
            int fixedLambdas = 0;
            for (String desc : commonDescs) {
                Map<String, List<Lambda>> typeToOriginalLambda = descToOriginalLambda.get(desc);
                Map<String, List<Lambda>> typeToPatchedLambda = descToPatchedLambda.get(desc);
                for (String type : Sets.intersection(typeToOriginalLambda.keySet(), typeToPatchedLambda.keySet())) {
                    List<Lambda> matchedOriginalLambdas = typeToOriginalLambda.get(type);
                    List<Lambda> matchedPatchedLambdas = typeToPatchedLambda.get(type);
                    if (matchedOriginalLambdas.size() != matchedPatchedLambdas.size()) {
                        continue;
                    }
                    fixedLambdas += matchedOriginalLambdas.size();
                    this.pairUp(className, commonMethods, lostMethods, gainedMethods, matchedOriginalLambdas, matchedPatchedLambdas, nameToLosses, possibleLambdas, onPair);
                }
            }
            // NOTE(review): this stops processing the remaining shared methods of the class as soon as one
            // method has all of its lambdas matched — confirm this is intended rather than a `continue`
            if (fixedLambdas == originalLambdas.size()) {
                return true;
            }
        }
        return possibleLambdas.isEmpty();
    }

    /** Indexes methods by {@code name + desc}. */
    private static Map<String, MethodNode> byNameAndDesc(List<MethodNode> methods) {
        return methods.stream().collect(Collectors.toMap(method -> method.name.concat(method.desc), Function.identity()));
    }

    /** Whether both lists have the same length and the same {@code method} at every position. */
    private static boolean sameLambdaTargets(List<Lambda> originalLambdas, List<Lambda> patchedLambdas) {
        if (originalLambdas.size() != patchedLambdas.size()) {
            return false;
        }
        Iterator<Lambda> itOriginal = originalLambdas.iterator();
        Iterator<Lambda> itPatched = patchedLambdas.iterator();
        while (itOriginal.hasNext() && itPatched.hasNext()) {
            Lambda originalLambda = itOriginal.next();
            Lambda patchedLambda = itPatched.next();
            if (!Objects.equals(originalLambda.method, patchedLambda.method)) {
                return false;
            }
        }
        return true;
    }

    /** Groups lambdas by {@code desc}, then by {@code method}. */
    private static Map<String, Map<String, List<Lambda>>> categorise(List<Lambda> lambdas) {
        Collector<Lambda, ?, Map<String, Map<String, List<Lambda>>>> byDescThenMethod = Collectors.groupingBy(lambda -> lambda.desc, Collectors.groupingBy(lambda -> lambda.method));
        return lambdas.stream().collect(byDescThenMethod);
    }

    /**
     * Called after addFix has appended a comparison to {@code commonMethods}. Resolves it, and anything
     * resolving it appends, if the comparison is effectively equal. The loop bound is re-read every
     * iteration, so newly appended comparisons are visited too.
     */
    private void resolveAddedCloseMethods(String className, List<MethodComparison> commonMethods, List<MethodNode> lostMethods, List<MethodNode> gainedMethods, Map<String, MethodNode> nameToLosses, Map<String, MethodNode> possibleLambdas) {
        for (int j = commonMethods.size() - 1; j < commonMethods.size(); j++) {
            MethodComparison innerMethod = commonMethods.get(j);
            if (innerMethod.effectivelyEqual) {
                this.resolveCloseMethod(className, commonMethods, lostMethods, gainedMethods, innerMethod, nameToLosses, possibleLambdas);
            }
        }
    }

    /**
     * Pairs the lambdas of an effectively-equal (but not identical) method positionally.
     *
     * @throws IllegalStateException if the lambda counts differ despite the bytecode looking unchanged
     */
    private void resolveCloseMethod(String className, List<MethodComparison> commonMethods, List<MethodNode> lostMethods, List<MethodNode> gainedMethods, MethodComparison method, Map<String, MethodNode> nameToLosses, Map<String, MethodNode> possibleLambdas) {
        assert (method.effectivelyEqual);
        if (!method.equal) {
            if (method.getOriginalLambads().size() != method.getPatchedLambads().size()) {
                throw new IllegalStateException("Bytecode in " + className + '#' + method.node.name + method.node.desc + " appeared unchanged but lambda count changed?");
            }
            this.pairUp(className, commonMethods, lostMethods, gainedMethods, method.getOriginalLambads(), method.getPatchedLambads(), nameToLosses, possibleLambdas, Runnables.doNothing());
        }
    }

    /**
     * Walks both lambda lists in step, consuming matching entries from {@code nameToLosses} and
     * {@code possibleLambdas}. Each pair is recorded via addFix, and {@code onPair} runs after every
     * successful fix.
     *
     * @throws IllegalStateException if only one side of a pair can be found among the unmatched methods
     */
    private void pairUp(String className, List<MethodComparison> commonMethods, List<MethodNode> lostMethods, List<MethodNode> gainedMethods, List<Lambda> originalLambdas, List<Lambda> patchedLambdas, Map<String, MethodNode> nameToLosses, Map<String, MethodNode> possibleLambdas, Runnable onPair) {
        Iterator<Lambda> itOriginal = originalLambdas.iterator();
        Iterator<Lambda> itPatched = patchedLambdas.iterator();
        while (itOriginal.hasNext() && itPatched.hasNext()) {
            Lambda lost = itOriginal.next();
            Lambda gained = itPatched.next();
            if (!className.equals(lost.owner)) {
                // NOTE(review): abandons the rest of the list rather than skipping this entry — confirm intended
                return;
            }
            assert (className.equals(gained.owner));
            MethodNode lostMethod = nameToLosses.remove(lost.getName());
            MethodNode gainedMethod = possibleLambdas.remove(gained.getName());
            if (lostMethod == null) {
                if (gainedMethod == null) {
                    // Neither side is an unmatched method, so both are expected to name the same method
                    assert (Objects.equals(lost.getFullName(), gained.getFullName()));
                    return;
                }
                throw new IllegalStateException("Couldn't find original method for lambda: " + lost.getFullName());
            }
            if (gainedMethod == null) {
                throw new IllegalStateException("Couldn't find patched method for lambda: " + gained.getFullName());
            }
            if (this.addFix(className, commonMethods, gainedMethod, lostMethod)) {
                lostMethods.remove(lostMethod);
                gainedMethods.remove(gainedMethod);
                onPair.run();
            }
        }
    }

    /**
     * Records that patched method {@code from} should take the name of vanilla method {@code to}.
     * It also renames {@code from} in place and appends the new (vanilla, patched) comparison.
     *
     * @return {@code false} (and logs to stderr) if the descriptors differ, so no fix is recorded
     */
    private boolean addFix(String className, List<MethodComparison> commonMethods, MethodNode from, MethodNode to) {
        if (!from.desc.equals(to.desc)) {
            System.err.println("Description changed remapping lambda handle: " + className + '#' + from.name + from.desc + " => " + className + '#' + to.name + to.desc);
            return false;
        }
        this.fixes.put(new IMappingProvider.Member(className, from.name, from.desc), to.name);
        from.name = to.name;
        commonMethods.add(new MethodComparison(to, from));
        return true;
    }

    /** Emits every recorded lambda rename as a method mapping. */
    @Override
    public void load(IMappingProvider.MappingAcceptor out) {
        this.fixes.forEach((member, newName) -> out.acceptMethod(member, newName));
    }
}

