Optimize runs of ICapabilityProvider calls into hash lookups

This commit is contained in:
embeddedt 2026-02-26 22:26:57 -05:00
parent b9933b1158
commit 784b914a43
No known key found for this signature in database
GPG Key ID: A69433EC199B5613

View File

@ -18,7 +18,11 @@ import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
@ -31,6 +35,23 @@ import static org.objectweb.asm.Opcodes.*;
* and performs direct dispatch instead of megamorphic virtual calls.
*/
public class CapabilityProviderDispatcherGenerator {
/**
* Describes the dispatch strategy for a single capability provider in the generated class.
*/
sealed interface ProviderDispatch {
/** Provider handles a known capability - emit an identity guard before dispatch. */
record Guarded(int providerIndex, String fieldDesc, CapabilityRef capability) implements ProviderDispatch {}
/** Provider capabilities are unknown - dispatch unconditionally. */
record Unguarded(int providerIndex, String fieldDesc) implements ProviderDispatch {}
/** Multiple guarded dispatches collapsed into a Map lookup. */
record Hash(int mapIndex, List<Guarded> entries) implements ProviderDispatch {}
}
/**
* Number of consecutive equality checks that must be performed to switch to a hash map.
*/
private static final int HASH_DISPATCH_THRESHOLD = 3;
private static final String GENERATED_CLASSES_FOLDER = System.getProperty("modernfix.generatedCapabilityDispatcherClassDumpFolder", "");
private static final ConcurrentHashMap<List<Class<? extends ICapabilityProvider>>, MethodHandle> cache =
@ -44,6 +65,7 @@ public class CapabilityProviderDispatcherGenerator {
private static final String CAPABILITY_DESC = "Lnet/minecraftforge/common/capabilities/Capability;";
private static final String LAZY_OPTIONAL_DESC = "Lnet/minecraftforge/common/util/LazyOptional;";
private static final String DIRECTION_DESC = "Lnet/minecraft/core/Direction;";
private static final String MAP_DESC = "Ljava/util/Map;";
/**
* Gets or generates a constructor MethodHandle for the given capability provider types.
@ -124,8 +146,122 @@ public class CapabilityProviderDispatcherGenerator {
}
}
/**
* Build the dispatch list describing how each provider should be handled.
*/
static List<ProviderDispatch> buildDispatchList(List<Class<? extends ICapabilityProvider>> providerTypes, List<CapabilityAnalysisResult> analysisResults) {
List<ProviderDispatch> dispatches = new ArrayList<>(providerTypes.size());
for (int i = 0; i < providerTypes.size(); i++) {
Class<? extends ICapabilityProvider> type = providerTypes.get(i);
String fieldDesc = (!type.isHidden() && Modifier.isPublic(type.getModifiers()))
? Type.getDescriptor(type) : ICAP_PROVIDER_DESC;
CapabilityAnalysisResult analysis = analysisResults.get(i);
if (analysis instanceof CapabilityAnalysisResult.AlwaysEmpty) {
// No dispatch needed - provider never returns a capability
} else if (analysis instanceof CapabilityAnalysisResult.KnownCapabilities known
&& known.capabilities().size() <= 5) {
for (CapabilityRef ref : known.capabilities()) {
dispatches.add(new ProviderDispatch.Guarded(i, fieldDesc, ref));
}
} else {
dispatches.add(new ProviderDispatch.Unguarded(i, fieldDesc));
}
}
return dispatches;
}
/**
* Collapse runs of 3+ consecutive Guarded dispatches into Hash dispatches.
* Duplicate CapabilityRefs within a run are kept as trailing Guarded entries
* after the Hash to preserve sequential fallthrough semantics.
*/
static List<ProviderDispatch> optimizeDispatches(List<ProviderDispatch> dispatches) {
List<ProviderDispatch> result = new ArrayList<>(dispatches.size());
int mapIndex = 0;
int i = 0;
while (i < dispatches.size()) {
// Collect a run of consecutive Guarded entries
int runStart = i;
while (i < dispatches.size() && dispatches.get(i) instanceof ProviderDispatch.Guarded) {
i++;
}
List<ProviderDispatch> run = dispatches.subList(runStart, i);
if (run.isEmpty()) {
// Not a Guarded entry, pass through
result.add(dispatches.get(i));
i++;
continue;
}
if (!tryCollapseToHash(run, mapIndex, result)) {
result.addAll(run);
} else {
mapIndex++;
}
}
return result;
}
/**
* Attempt to collapse a run of Guarded dispatches into a Hash.
* Returns true if a Hash was emitted, false if the run should be kept as-is.
*/
private static boolean tryCollapseToHash(List<ProviderDispatch> run, int mapIndex, List<ProviderDispatch> result) {
if (run.size() < HASH_DISPATCH_THRESHOLD) {
return false;
}
// Deduplicate by CapabilityRef - first occurrence goes into the hash,
// duplicates are kept as trailing Guarded entries for fallthrough
Set<CapabilityRef> seen = new HashSet<>();
List<ProviderDispatch.Guarded> hashEntries = new ArrayList<>();
List<ProviderDispatch.Guarded> duplicates = new ArrayList<>();
for (ProviderDispatch dispatch : run) {
ProviderDispatch.Guarded g = (ProviderDispatch.Guarded) dispatch;
if (seen.add(g.capability())) {
hashEntries.add(g);
} else {
duplicates.add(g);
}
}
if (hashEntries.size() < HASH_DISPATCH_THRESHOLD) {
return false;
}
result.add(new ProviderDispatch.Hash(mapIndex, hashEntries));
result.addAll(duplicates);
return true;
}
/**
* Collect all unique provider fields (index fieldDesc) referenced by a dispatch list,
* including those inside Hash entries.
*/
private static LinkedHashMap<Integer, String> collectProviderFields(List<ProviderDispatch> dispatches) {
LinkedHashMap<Integer, String> fields = new LinkedHashMap<>();
for (ProviderDispatch dispatch : dispatches) {
if (dispatch instanceof ProviderDispatch.Guarded g) {
fields.putIfAbsent(g.providerIndex(), g.fieldDesc());
} else if (dispatch instanceof ProviderDispatch.Unguarded u) {
fields.putIfAbsent(u.providerIndex(), u.fieldDesc());
}
// Hash entries don't need provider fields - map reads from constructor array
}
return fields;
}
private static byte[] generateClassBytes(String className, List<Class<? extends ICapabilityProvider>> providerTypes, List<CapabilityAnalysisResult> analysisResults) {
ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS);
List<ProviderDispatch> dispatches = optimizeDispatches(buildDispatchList(providerTypes, analysisResults));
ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS) {
@Override
protected ClassLoader getClassLoader() {
return CapabilityProviderDispatcherGenerator.class.getClassLoader();
}
};
// Class declaration: implements ICapabilityProvider
cw.visit(
@ -137,67 +273,89 @@ public class CapabilityProviderDispatcherGenerator {
new String[] { "net/minecraftforge/common/capabilities/ICapabilityProvider" }
);
// Compute field descriptors: use concrete type when possible for JIT devirtualization
String[] fieldDescs = new String[providerTypes.size()];
for (int i = 0; i < providerTypes.size(); i++) {
Class<? extends ICapabilityProvider> type = providerTypes.get(i);
fieldDescs[i] = (!type.isHidden() && Modifier.isPublic(type.getModifiers()))
? Type.getDescriptor(type) : ICAP_PROVIDER_DESC;
// Generate final fields for each distinct provider
LinkedHashMap<Integer, String> providerFields = collectProviderFields(dispatches);
for (var entry : providerFields.entrySet()) {
cw.visitField(ACC_PRIVATE | ACC_FINAL, "provider" + entry.getKey(), entry.getValue(), null, null).visitEnd();
}
// Generate final fields for each provider
for (int i = 0; i < providerTypes.size(); i++) {
cw.visitField(
ACC_PRIVATE | ACC_FINAL,
"provider" + i,
fieldDescs[i],
null,
null
).visitEnd();
// Generate map fields for Hash dispatches
for (ProviderDispatch dispatch : dispatches) {
if (dispatch instanceof ProviderDispatch.Hash hash) {
cw.visitField(ACC_PRIVATE | ACC_FINAL, "capMap" + hash.mapIndex(), MAP_DESC, null, null).visitEnd();
}
}
// Generate constructor
generateConstructor(cw, className, providerTypes.size(), fieldDescs);
generateConstructor(cw, className, providerFields, dispatches);
// Generate getCapability method with sided parameter
generateGetCapabilityMethod(cw, className, fieldDescs, analysisResults);
generateGetCapabilityMethod(cw, className, dispatches);
cw.visitEnd();
return cw.toByteArray();
}
private static void generateConstructor(ClassWriter cw, String className, int providerCount, String[] fieldDescs) {
private static void generateConstructor(ClassWriter cw, String className, Map<Integer, String> providerFields, List<ProviderDispatch> dispatches) {
Method constructor = Method.getMethod("void <init>(net.minecraftforge.common.capabilities.ICapabilityProvider[])");
GeneratorAdapter mg = new GeneratorAdapter(ACC_PUBLIC, constructor, null, null, cw);
Type classType = Type.getObjectType(className.replace('.', '/'));
// Call super constructor
mg.loadThis();
mg.invokeConstructor(Type.getType(Object.class), Method.getMethod("void <init>()"));
// Unpack array into final fields
for (int i = 0; i < providerCount; i++) {
Type fieldType = Type.getType(fieldDescs[i]);
mg.loadThis(); // this
// Unpack array into provider fields
for (var entry : providerFields.entrySet()) {
int idx = entry.getKey();
String desc = entry.getValue();
Type fieldType = Type.getType(desc);
mg.loadThis();
mg.loadArg(0); // array
mg.push(i); // index
mg.arrayLoad(Type.getType(ICAP_PROVIDER_DESC)); // array[i]
if (!fieldDescs[i].equals(ICAP_PROVIDER_DESC)) {
mg.push(idx); // index
mg.arrayLoad(Type.getType(ICAP_PROVIDER_DESC));
if (!desc.equals(ICAP_PROVIDER_DESC)) {
mg.checkCast(fieldType);
}
mg.putField(
Type.getObjectType(className.replace('.', '/')),
"provider" + i,
fieldType
);
mg.putField(classType, "provider" + idx, fieldType);
}
// Build hash maps
for (ProviderDispatch dispatch : dispatches) {
if (dispatch instanceof ProviderDispatch.Hash hash) {
generateMapConstruction(mg, classType, hash);
}
}
mg.returnValue();
mg.endMethod();
}
private static void generateGetCapabilityMethod(ClassWriter cw, String className, String[] fieldDescs, List<CapabilityAnalysisResult> analysisResults) {
int providerCount = fieldDescs.length;
private static void generateMapConstruction(GeneratorAdapter mg, Type classType, ProviderDispatch.Hash hash) {
List<ProviderDispatch.Guarded> entries = hash.entries();
mg.loadThis(); // for PUTFIELD at the end
mg.push(entries.size());
mg.visitTypeInsn(ANEWARRAY, "java/util/Map$Entry");
for (int i = 0; i < entries.size(); i++) {
ProviderDispatch.Guarded g = entries.get(i);
mg.dup();
mg.push(i);
mg.visitFieldInsn(GETSTATIC, g.capability().owner(), g.capability().fieldName(), CAPABILITY_DESC);
mg.loadArg(0);
mg.push(g.providerIndex());
mg.arrayLoad(Type.getType(ICAP_PROVIDER_DESC));
mg.visitMethodInsn(INVOKESTATIC, "java/util/Map", "entry",
"(Ljava/lang/Object;Ljava/lang/Object;)Ljava/util/Map$Entry;", true);
mg.visitInsn(AASTORE);
}
mg.visitMethodInsn(INVOKESTATIC, "java/util/Map", "ofEntries",
"([Ljava/util/Map$Entry;)Ljava/util/Map;", true);
mg.putField(classType, "capMap" + hash.mapIndex(), Type.getType(MAP_DESC));
}
private static void generateGetCapabilityMethod(ClassWriter cw, String className, List<ProviderDispatch> dispatches) {
// Method: <T> LazyOptional<T> getCapability(Capability<T>, Direction)
MethodVisitor mv = cw.visitMethod(
ACC_PUBLIC,
@ -213,76 +371,73 @@ public class CapabilityProviderDispatcherGenerator {
// For each provider, call getCapability and check if present
Label endLabel = new Label();
for (int i = 0; i < providerCount; i++) {
CapabilityAnalysisResult analysis = analysisResults.get(i);
String internalName = className.replace('.', '/');
String getCapDesc = "(" + CAPABILITY_DESC + DIRECTION_DESC + ")" + LAZY_OPTIONAL_DESC;
for (ProviderDispatch dispatch : dispatches) {
Label nextLabel = new Label();
// AlwaysEmpty: skip code generation for this provider entirely
if (analysis instanceof CapabilityAnalysisResult.AlwaysEmpty) {
continue;
}
if (dispatch instanceof ProviderDispatch.Hash hash) {
// ICapabilityProvider p = (ICapabilityProvider) this.capMapN.get(cap);
mv.visitVarInsn(ALOAD, 0);
mv.visitFieldInsn(GETFIELD, internalName, "capMap" + hash.mapIndex(), MAP_DESC);
mv.visitVarInsn(ALOAD, 1);
mv.visitMethodInsn(INVOKEINTERFACE, "java/util/Map", "get",
"(Ljava/lang/Object;)Ljava/lang/Object;", true);
mv.visitVarInsn(ASTORE, 3);
// KnownCapabilities: emit guard checks before dispatch
if (analysis instanceof CapabilityAnalysisResult.KnownCapabilities known
&& known.capabilities().size() <= 5) {
if (known.capabilities().size() == 1) {
// Single cap: if (cap != KNOWN_CAP) goto nextProvider
CapabilityRef ref = known.capabilities().iterator().next();
mv.visitVarInsn(ALOAD, 1); // cap parameter
// if (p == null) goto next
mv.visitVarInsn(ALOAD, 3);
mv.visitJumpInsn(IFNULL, nextLabel);
// result = ((ICapabilityProvider) p).getCapability(cap, side)
mv.visitVarInsn(ALOAD, 3);
mv.visitTypeInsn(CHECKCAST, "net/minecraftforge/common/capabilities/ICapabilityProvider");
mv.visitVarInsn(ALOAD, 1);
mv.visitVarInsn(ALOAD, 2);
mv.visitMethodInsn(INVOKEINTERFACE,
"net/minecraftforge/common/capabilities/ICapabilityProvider",
"getCapability", getCapDesc, true);
mv.visitVarInsn(ASTORE, 3);
} else {
if (dispatch instanceof ProviderDispatch.Guarded guarded) {
// if (cap != KNOWN_CAP) goto next
CapabilityRef ref = guarded.capability();
mv.visitVarInsn(ALOAD, 1);
mv.visitFieldInsn(GETSTATIC, ref.owner(), ref.fieldName(), CAPABILITY_DESC);
mv.visitJumpInsn(IF_ACMPNE, nextLabel);
} else {
// Multiple caps: check each, jump to callProvider on match
Label callProvider = new Label();
for (CapabilityRef ref : known.capabilities()) {
mv.visitVarInsn(ALOAD, 1); // cap parameter
mv.visitFieldInsn(GETSTATIC, ref.owner(), ref.fieldName(), CAPABILITY_DESC);
mv.visitJumpInsn(IF_ACMPEQ, callProvider);
}
// No match, skip this provider
mv.visitJumpInsn(GOTO, nextLabel);
mv.visitLabel(callProvider);
}
// LazyOptional<T> result = this.providerN.getCapability(cap, side);
int provIdx;
String fDesc;
if (dispatch instanceof ProviderDispatch.Guarded g) {
provIdx = g.providerIndex(); fDesc = g.fieldDesc();
} else {
var u = (ProviderDispatch.Unguarded) dispatch;
provIdx = u.providerIndex(); fDesc = u.fieldDesc();
}
mv.visitVarInsn(ALOAD, 0);
mv.visitFieldInsn(GETFIELD, internalName, "provider" + provIdx, fDesc);
mv.visitVarInsn(ALOAD, 1);
mv.visitVarInsn(ALOAD, 2);
mv.visitMethodInsn(INVOKEINTERFACE,
"net/minecraftforge/common/capabilities/ICapabilityProvider",
"getCapability", getCapDesc, true);
mv.visitVarInsn(ASTORE, 3);
}
// Indeterminate: no guard, fall through to dispatch
// LazyOptional<T> result = this.providerN.getCapability(cap, side);
mv.visitVarInsn(ALOAD, 0); // this
mv.visitFieldInsn(
GETFIELD,
className.replace('.', '/'),
"provider" + i,
fieldDescs[i]
);
mv.visitVarInsn(ALOAD, 1); // cap parameter
mv.visitVarInsn(ALOAD, 2); // side parameter
mv.visitMethodInsn(
INVOKEINTERFACE,
"net/minecraftforge/common/capabilities/ICapabilityProvider",
"getCapability",
"(" + CAPABILITY_DESC + DIRECTION_DESC + ")" + LAZY_OPTIONAL_DESC,
true
);
// Store result in local variable
mv.visitVarInsn(ASTORE, 3);
// if (result == null) continue to next;
// if (result == null) goto next
mv.visitVarInsn(ALOAD, 3);
mv.visitJumpInsn(IFNULL, nextLabel);
// if (result.isPresent()) return result;
// if (result.isPresent()) return result
mv.visitVarInsn(ALOAD, 3);
mv.visitMethodInsn(
INVOKEVIRTUAL,
mv.visitMethodInsn(INVOKEVIRTUAL,
"net/minecraftforge/common/util/LazyOptional",
"isPresent",
"()Z",
false
);
"isPresent", "()Z", false);
mv.visitJumpInsn(IFEQ, nextLabel);
// return result
mv.visitVarInsn(ALOAD, 3);
mv.visitInsn(ARETURN);