diff --git a/src/main/java/org/embeddedt/modernfix/forge/capability/CapabilityProviderDispatcherGenerator.java b/src/main/java/org/embeddedt/modernfix/forge/capability/CapabilityProviderDispatcherGenerator.java index 05dc3c4d..2e460737 100644 --- a/src/main/java/org/embeddedt/modernfix/forge/capability/CapabilityProviderDispatcherGenerator.java +++ b/src/main/java/org/embeddedt/modernfix/forge/capability/CapabilityProviderDispatcherGenerator.java @@ -18,7 +18,11 @@ import java.nio.file.Path; import java.nio.file.Paths; import java.util.ArrayList; import java.util.Arrays; +import java.util.HashSet; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; @@ -31,6 +35,23 @@ import static org.objectweb.asm.Opcodes.*; * and performs direct dispatch instead of megamorphic virtual calls. */ public class CapabilityProviderDispatcherGenerator { + /** + * Describes the dispatch strategy for a single capability provider in the generated class. + */ + sealed interface ProviderDispatch { + /** Provider handles a known capability - emit an identity guard before dispatch. */ + record Guarded(int providerIndex, String fieldDesc, CapabilityRef capability) implements ProviderDispatch {} + /** Provider capabilities are unknown - dispatch unconditionally. */ + record Unguarded(int providerIndex, String fieldDesc) implements ProviderDispatch {} + /** Multiple guarded dispatches collapsed into a Map lookup. */ + record Hash(int mapIndex, List entries) implements ProviderDispatch {} + } + + /** + * Number of consecutive equality checks that must be performed to switch to a hash map. + */ + private static final int HASH_DISPATCH_THRESHOLD = 3; + private static final String GENERATED_CLASSES_FOLDER = System.getProperty("modernfix.generatedCapabilityDispatcherClassDumpFolder", ""); private static final ConcurrentHashMap>, MethodHandle> cache = @@ -44,6 +65,7 @@ public class CapabilityProviderDispatcherGenerator { private static final String CAPABILITY_DESC = "Lnet/minecraftforge/common/capabilities/Capability;"; private static final String LAZY_OPTIONAL_DESC = "Lnet/minecraftforge/common/util/LazyOptional;"; private static final String DIRECTION_DESC = "Lnet/minecraft/core/Direction;"; + private static final String MAP_DESC = "Ljava/util/Map;"; /** * Gets or generates a constructor MethodHandle for the given capability provider types. @@ -124,8 +146,122 @@ public class CapabilityProviderDispatcherGenerator { } } + /** + * Build the dispatch list describing how each provider should be handled. + */ + static List buildDispatchList(List> providerTypes, List analysisResults) { + List dispatches = new ArrayList<>(providerTypes.size()); + for (int i = 0; i < providerTypes.size(); i++) { + Class type = providerTypes.get(i); + String fieldDesc = (!type.isHidden() && Modifier.isPublic(type.getModifiers())) + ? Type.getDescriptor(type) : ICAP_PROVIDER_DESC; + + CapabilityAnalysisResult analysis = analysisResults.get(i); + if (analysis instanceof CapabilityAnalysisResult.AlwaysEmpty) { + // No dispatch needed - provider never returns a capability + } else if (analysis instanceof CapabilityAnalysisResult.KnownCapabilities known + && known.capabilities().size() <= 5) { + for (CapabilityRef ref : known.capabilities()) { + dispatches.add(new ProviderDispatch.Guarded(i, fieldDesc, ref)); + } + } else { + dispatches.add(new ProviderDispatch.Unguarded(i, fieldDesc)); + } + } + return dispatches; + } + + /** + * Collapse runs of 3+ consecutive Guarded dispatches into Hash dispatches. + * Duplicate CapabilityRefs within a run are kept as trailing Guarded entries + * after the Hash to preserve sequential fallthrough semantics. + */ + static List optimizeDispatches(List dispatches) { + List result = new ArrayList<>(dispatches.size()); + int mapIndex = 0; + int i = 0; + while (i < dispatches.size()) { + // Collect a run of consecutive Guarded entries + int runStart = i; + while (i < dispatches.size() && dispatches.get(i) instanceof ProviderDispatch.Guarded) { + i++; + } + + List run = dispatches.subList(runStart, i); + if (run.isEmpty()) { + // Not a Guarded entry, pass through + result.add(dispatches.get(i)); + i++; + continue; + } + + if (!tryCollapseToHash(run, mapIndex, result)) { + result.addAll(run); + } else { + mapIndex++; + } + } + return result; + } + + /** + * Attempt to collapse a run of Guarded dispatches into a Hash. + * Returns true if a Hash was emitted, false if the run should be kept as-is. + */ + private static boolean tryCollapseToHash(List run, int mapIndex, List result) { + if (run.size() < HASH_DISPATCH_THRESHOLD) { + return false; + } + + // Deduplicate by CapabilityRef - first occurrence goes into the hash, + // duplicates are kept as trailing Guarded entries for fallthrough + Set seen = new HashSet<>(); + List hashEntries = new ArrayList<>(); + List duplicates = new ArrayList<>(); + for (ProviderDispatch dispatch : run) { + ProviderDispatch.Guarded g = (ProviderDispatch.Guarded) dispatch; + if (seen.add(g.capability())) { + hashEntries.add(g); + } else { + duplicates.add(g); + } + } + + if (hashEntries.size() < HASH_DISPATCH_THRESHOLD) { + return false; + } + + result.add(new ProviderDispatch.Hash(mapIndex, hashEntries)); + result.addAll(duplicates); + return true; + } + + /** + * Collect all unique provider fields (index → fieldDesc) referenced by a dispatch list, + * including those inside Hash entries. + */ + private static LinkedHashMap collectProviderFields(List dispatches) { + LinkedHashMap fields = new LinkedHashMap<>(); + for (ProviderDispatch dispatch : dispatches) { + if (dispatch instanceof ProviderDispatch.Guarded g) { + fields.putIfAbsent(g.providerIndex(), g.fieldDesc()); + } else if (dispatch instanceof ProviderDispatch.Unguarded u) { + fields.putIfAbsent(u.providerIndex(), u.fieldDesc()); + } + // Hash entries don't need provider fields - map reads from constructor array + } + return fields; + } + private static byte[] generateClassBytes(String className, List> providerTypes, List analysisResults) { - ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS); + List dispatches = optimizeDispatches(buildDispatchList(providerTypes, analysisResults)); + + ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS) { + @Override + protected ClassLoader getClassLoader() { + return CapabilityProviderDispatcherGenerator.class.getClassLoader(); + } + }; // Class declaration: implements ICapabilityProvider cw.visit( @@ -137,67 +273,89 @@ public class CapabilityProviderDispatcherGenerator { new String[] { "net/minecraftforge/common/capabilities/ICapabilityProvider" } ); - // Compute field descriptors: use concrete type when possible for JIT devirtualization - String[] fieldDescs = new String[providerTypes.size()]; - for (int i = 0; i < providerTypes.size(); i++) { - Class type = providerTypes.get(i); - fieldDescs[i] = (!type.isHidden() && Modifier.isPublic(type.getModifiers())) - ? Type.getDescriptor(type) : ICAP_PROVIDER_DESC; + // Generate final fields for each distinct provider + LinkedHashMap providerFields = collectProviderFields(dispatches); + for (var entry : providerFields.entrySet()) { + cw.visitField(ACC_PRIVATE | ACC_FINAL, "provider" + entry.getKey(), entry.getValue(), null, null).visitEnd(); } - // Generate final fields for each provider - for (int i = 0; i < providerTypes.size(); i++) { - cw.visitField( - ACC_PRIVATE | ACC_FINAL, - "provider" + i, - fieldDescs[i], - null, - null - ).visitEnd(); + // Generate map fields for Hash dispatches + for (ProviderDispatch dispatch : dispatches) { + if (dispatch instanceof ProviderDispatch.Hash hash) { + cw.visitField(ACC_PRIVATE | ACC_FINAL, "capMap" + hash.mapIndex(), MAP_DESC, null, null).visitEnd(); + } } // Generate constructor - generateConstructor(cw, className, providerTypes.size(), fieldDescs); + generateConstructor(cw, className, providerFields, dispatches); // Generate getCapability method with sided parameter - generateGetCapabilityMethod(cw, className, fieldDescs, analysisResults); + generateGetCapabilityMethod(cw, className, dispatches); cw.visitEnd(); return cw.toByteArray(); } - private static void generateConstructor(ClassWriter cw, String className, int providerCount, String[] fieldDescs) { + private static void generateConstructor(ClassWriter cw, String className, Map providerFields, List dispatches) { Method constructor = Method.getMethod("void (net.minecraftforge.common.capabilities.ICapabilityProvider[])"); GeneratorAdapter mg = new GeneratorAdapter(ACC_PUBLIC, constructor, null, null, cw); + Type classType = Type.getObjectType(className.replace('.', '/')); // Call super constructor mg.loadThis(); mg.invokeConstructor(Type.getType(Object.class), Method.getMethod("void ()")); - // Unpack array into final fields - for (int i = 0; i < providerCount; i++) { - Type fieldType = Type.getType(fieldDescs[i]); - mg.loadThis(); // this + // Unpack array into provider fields + for (var entry : providerFields.entrySet()) { + int idx = entry.getKey(); + String desc = entry.getValue(); + Type fieldType = Type.getType(desc); + mg.loadThis(); mg.loadArg(0); // array - mg.push(i); // index - mg.arrayLoad(Type.getType(ICAP_PROVIDER_DESC)); // array[i] - if (!fieldDescs[i].equals(ICAP_PROVIDER_DESC)) { + mg.push(idx); // index + mg.arrayLoad(Type.getType(ICAP_PROVIDER_DESC)); + if (!desc.equals(ICAP_PROVIDER_DESC)) { mg.checkCast(fieldType); } - mg.putField( - Type.getObjectType(className.replace('.', '/')), - "provider" + i, - fieldType - ); + mg.putField(classType, "provider" + idx, fieldType); + } + + // Build hash maps + for (ProviderDispatch dispatch : dispatches) { + if (dispatch instanceof ProviderDispatch.Hash hash) { + generateMapConstruction(mg, classType, hash); + } } mg.returnValue(); mg.endMethod(); } - private static void generateGetCapabilityMethod(ClassWriter cw, String className, String[] fieldDescs, List analysisResults) { - int providerCount = fieldDescs.length; + private static void generateMapConstruction(GeneratorAdapter mg, Type classType, ProviderDispatch.Hash hash) { + List entries = hash.entries(); + mg.loadThis(); // for PUTFIELD at the end + mg.push(entries.size()); + mg.visitTypeInsn(ANEWARRAY, "java/util/Map$Entry"); + for (int i = 0; i < entries.size(); i++) { + ProviderDispatch.Guarded g = entries.get(i); + mg.dup(); + mg.push(i); + mg.visitFieldInsn(GETSTATIC, g.capability().owner(), g.capability().fieldName(), CAPABILITY_DESC); + mg.loadArg(0); + mg.push(g.providerIndex()); + mg.arrayLoad(Type.getType(ICAP_PROVIDER_DESC)); + mg.visitMethodInsn(INVOKESTATIC, "java/util/Map", "entry", + "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/util/Map$Entry;", true); + mg.visitInsn(AASTORE); + } + mg.visitMethodInsn(INVOKESTATIC, "java/util/Map", "ofEntries", + "([Ljava/util/Map$Entry;)Ljava/util/Map;", true); + + mg.putField(classType, "capMap" + hash.mapIndex(), Type.getType(MAP_DESC)); + } + + private static void generateGetCapabilityMethod(ClassWriter cw, String className, List dispatches) { // Method: LazyOptional getCapability(Capability, Direction) MethodVisitor mv = cw.visitMethod( ACC_PUBLIC, @@ -213,76 +371,73 @@ public class CapabilityProviderDispatcherGenerator { // For each provider, call getCapability and check if present Label endLabel = new Label(); - for (int i = 0; i < providerCount; i++) { - CapabilityAnalysisResult analysis = analysisResults.get(i); + String internalName = className.replace('.', '/'); + String getCapDesc = "(" + CAPABILITY_DESC + DIRECTION_DESC + ")" + LAZY_OPTIONAL_DESC; + + for (ProviderDispatch dispatch : dispatches) { Label nextLabel = new Label(); - // AlwaysEmpty: skip code generation for this provider entirely - if (analysis instanceof CapabilityAnalysisResult.AlwaysEmpty) { - continue; - } + if (dispatch instanceof ProviderDispatch.Hash hash) { + // ICapabilityProvider p = (ICapabilityProvider) this.capMapN.get(cap); + mv.visitVarInsn(ALOAD, 0); + mv.visitFieldInsn(GETFIELD, internalName, "capMap" + hash.mapIndex(), MAP_DESC); + mv.visitVarInsn(ALOAD, 1); + mv.visitMethodInsn(INVOKEINTERFACE, "java/util/Map", "get", + "(Ljava/lang/Object;)Ljava/lang/Object;", true); + mv.visitVarInsn(ASTORE, 3); - // KnownCapabilities: emit guard checks before dispatch - if (analysis instanceof CapabilityAnalysisResult.KnownCapabilities known - && known.capabilities().size() <= 5) { - if (known.capabilities().size() == 1) { - // Single cap: if (cap != KNOWN_CAP) goto nextProvider - CapabilityRef ref = known.capabilities().iterator().next(); - mv.visitVarInsn(ALOAD, 1); // cap parameter + // if (p == null) goto next + mv.visitVarInsn(ALOAD, 3); + mv.visitJumpInsn(IFNULL, nextLabel); + + // result = ((ICapabilityProvider) p).getCapability(cap, side) + mv.visitVarInsn(ALOAD, 3); + mv.visitTypeInsn(CHECKCAST, "net/minecraftforge/common/capabilities/ICapabilityProvider"); + mv.visitVarInsn(ALOAD, 1); + mv.visitVarInsn(ALOAD, 2); + mv.visitMethodInsn(INVOKEINTERFACE, + "net/minecraftforge/common/capabilities/ICapabilityProvider", + "getCapability", getCapDesc, true); + mv.visitVarInsn(ASTORE, 3); + } else { + if (dispatch instanceof ProviderDispatch.Guarded guarded) { + // if (cap != KNOWN_CAP) goto next + CapabilityRef ref = guarded.capability(); + mv.visitVarInsn(ALOAD, 1); mv.visitFieldInsn(GETSTATIC, ref.owner(), ref.fieldName(), CAPABILITY_DESC); mv.visitJumpInsn(IF_ACMPNE, nextLabel); - } else { - // Multiple caps: check each, jump to callProvider on match - Label callProvider = new Label(); - for (CapabilityRef ref : known.capabilities()) { - mv.visitVarInsn(ALOAD, 1); // cap parameter - mv.visitFieldInsn(GETSTATIC, ref.owner(), ref.fieldName(), CAPABILITY_DESC); - mv.visitJumpInsn(IF_ACMPEQ, callProvider); - } - // No match, skip this provider - mv.visitJumpInsn(GOTO, nextLabel); - mv.visitLabel(callProvider); } + + // LazyOptional result = this.providerN.getCapability(cap, side); + int provIdx; + String fDesc; + if (dispatch instanceof ProviderDispatch.Guarded g) { + provIdx = g.providerIndex(); fDesc = g.fieldDesc(); + } else { + var u = (ProviderDispatch.Unguarded) dispatch; + provIdx = u.providerIndex(); fDesc = u.fieldDesc(); + } + mv.visitVarInsn(ALOAD, 0); + mv.visitFieldInsn(GETFIELD, internalName, "provider" + provIdx, fDesc); + mv.visitVarInsn(ALOAD, 1); + mv.visitVarInsn(ALOAD, 2); + mv.visitMethodInsn(INVOKEINTERFACE, + "net/minecraftforge/common/capabilities/ICapabilityProvider", + "getCapability", getCapDesc, true); + mv.visitVarInsn(ASTORE, 3); } - // Indeterminate: no guard, fall through to dispatch - // LazyOptional result = this.providerN.getCapability(cap, side); - mv.visitVarInsn(ALOAD, 0); // this - mv.visitFieldInsn( - GETFIELD, - className.replace('.', '/'), - "provider" + i, - fieldDescs[i] - ); - mv.visitVarInsn(ALOAD, 1); // cap parameter - mv.visitVarInsn(ALOAD, 2); // side parameter - mv.visitMethodInsn( - INVOKEINTERFACE, - "net/minecraftforge/common/capabilities/ICapabilityProvider", - "getCapability", - "(" + CAPABILITY_DESC + DIRECTION_DESC + ")" + LAZY_OPTIONAL_DESC, - true - ); - - // Store result in local variable - mv.visitVarInsn(ASTORE, 3); - - // if (result == null) continue to next; + // if (result == null) goto next mv.visitVarInsn(ALOAD, 3); mv.visitJumpInsn(IFNULL, nextLabel); - // if (result.isPresent()) return result; + // if (result.isPresent()) return result mv.visitVarInsn(ALOAD, 3); - mv.visitMethodInsn( - INVOKEVIRTUAL, + mv.visitMethodInsn(INVOKEVIRTUAL, "net/minecraftforge/common/util/LazyOptional", - "isPresent", - "()Z", - false - ); + "isPresent", "()Z", false); mv.visitJumpInsn(IFEQ, nextLabel); - // return result mv.visitVarInsn(ALOAD, 3); mv.visitInsn(ARETURN);