Optimize runs of ICapabilityProvider calls into hash lookups
This commit is contained in:
parent
b9933b1158
commit
784b914a43
|
|
@ -18,7 +18,11 @@ import java.nio.file.Path;
|
|||
import java.nio.file.Paths;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashSet;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.stream.Collectors;
|
||||
|
|
@ -31,6 +35,23 @@ import static org.objectweb.asm.Opcodes.*;
|
|||
* and performs direct dispatch instead of megamorphic virtual calls.
|
||||
*/
|
||||
public class CapabilityProviderDispatcherGenerator {
|
||||
/**
|
||||
* Describes the dispatch strategy for a single capability provider in the generated class.
|
||||
*/
|
||||
sealed interface ProviderDispatch {
|
||||
/** Provider handles a known capability - emit an identity guard before dispatch. */
|
||||
record Guarded(int providerIndex, String fieldDesc, CapabilityRef capability) implements ProviderDispatch {}
|
||||
/** Provider capabilities are unknown - dispatch unconditionally. */
|
||||
record Unguarded(int providerIndex, String fieldDesc) implements ProviderDispatch {}
|
||||
/** Multiple guarded dispatches collapsed into a Map lookup. */
|
||||
record Hash(int mapIndex, List<Guarded> entries) implements ProviderDispatch {}
|
||||
}
|
||||
|
||||
/**
|
||||
* Number of consecutive equality checks that must be performed to switch to a hash map.
|
||||
*/
|
||||
private static final int HASH_DISPATCH_THRESHOLD = 3;
|
||||
|
||||
private static final String GENERATED_CLASSES_FOLDER = System.getProperty("modernfix.generatedCapabilityDispatcherClassDumpFolder", "");
|
||||
|
||||
private static final ConcurrentHashMap<List<Class<? extends ICapabilityProvider>>, MethodHandle> cache =
|
||||
|
|
@ -44,6 +65,7 @@ public class CapabilityProviderDispatcherGenerator {
|
|||
private static final String CAPABILITY_DESC = "Lnet/minecraftforge/common/capabilities/Capability;";
|
||||
private static final String LAZY_OPTIONAL_DESC = "Lnet/minecraftforge/common/util/LazyOptional;";
|
||||
private static final String DIRECTION_DESC = "Lnet/minecraft/core/Direction;";
|
||||
private static final String MAP_DESC = "Ljava/util/Map;";
|
||||
|
||||
/**
|
||||
* Gets or generates a constructor MethodHandle for the given capability provider types.
|
||||
|
|
@ -124,8 +146,122 @@ public class CapabilityProviderDispatcherGenerator {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Build the dispatch list describing how each provider should be handled.
|
||||
*/
|
||||
static List<ProviderDispatch> buildDispatchList(List<Class<? extends ICapabilityProvider>> providerTypes, List<CapabilityAnalysisResult> analysisResults) {
|
||||
List<ProviderDispatch> dispatches = new ArrayList<>(providerTypes.size());
|
||||
for (int i = 0; i < providerTypes.size(); i++) {
|
||||
Class<? extends ICapabilityProvider> type = providerTypes.get(i);
|
||||
String fieldDesc = (!type.isHidden() && Modifier.isPublic(type.getModifiers()))
|
||||
? Type.getDescriptor(type) : ICAP_PROVIDER_DESC;
|
||||
|
||||
CapabilityAnalysisResult analysis = analysisResults.get(i);
|
||||
if (analysis instanceof CapabilityAnalysisResult.AlwaysEmpty) {
|
||||
// No dispatch needed - provider never returns a capability
|
||||
} else if (analysis instanceof CapabilityAnalysisResult.KnownCapabilities known
|
||||
&& known.capabilities().size() <= 5) {
|
||||
for (CapabilityRef ref : known.capabilities()) {
|
||||
dispatches.add(new ProviderDispatch.Guarded(i, fieldDesc, ref));
|
||||
}
|
||||
} else {
|
||||
dispatches.add(new ProviderDispatch.Unguarded(i, fieldDesc));
|
||||
}
|
||||
}
|
||||
return dispatches;
|
||||
}
|
||||
|
||||
/**
|
||||
* Collapse runs of 3+ consecutive Guarded dispatches into Hash dispatches.
|
||||
* Duplicate CapabilityRefs within a run are kept as trailing Guarded entries
|
||||
* after the Hash to preserve sequential fallthrough semantics.
|
||||
*/
|
||||
static List<ProviderDispatch> optimizeDispatches(List<ProviderDispatch> dispatches) {
|
||||
List<ProviderDispatch> result = new ArrayList<>(dispatches.size());
|
||||
int mapIndex = 0;
|
||||
int i = 0;
|
||||
while (i < dispatches.size()) {
|
||||
// Collect a run of consecutive Guarded entries
|
||||
int runStart = i;
|
||||
while (i < dispatches.size() && dispatches.get(i) instanceof ProviderDispatch.Guarded) {
|
||||
i++;
|
||||
}
|
||||
|
||||
List<ProviderDispatch> run = dispatches.subList(runStart, i);
|
||||
if (run.isEmpty()) {
|
||||
// Not a Guarded entry, pass through
|
||||
result.add(dispatches.get(i));
|
||||
i++;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!tryCollapseToHash(run, mapIndex, result)) {
|
||||
result.addAll(run);
|
||||
} else {
|
||||
mapIndex++;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Attempt to collapse a run of Guarded dispatches into a Hash.
|
||||
* Returns true if a Hash was emitted, false if the run should be kept as-is.
|
||||
*/
|
||||
private static boolean tryCollapseToHash(List<ProviderDispatch> run, int mapIndex, List<ProviderDispatch> result) {
|
||||
if (run.size() < HASH_DISPATCH_THRESHOLD) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Deduplicate by CapabilityRef - first occurrence goes into the hash,
|
||||
// duplicates are kept as trailing Guarded entries for fallthrough
|
||||
Set<CapabilityRef> seen = new HashSet<>();
|
||||
List<ProviderDispatch.Guarded> hashEntries = new ArrayList<>();
|
||||
List<ProviderDispatch.Guarded> duplicates = new ArrayList<>();
|
||||
for (ProviderDispatch dispatch : run) {
|
||||
ProviderDispatch.Guarded g = (ProviderDispatch.Guarded) dispatch;
|
||||
if (seen.add(g.capability())) {
|
||||
hashEntries.add(g);
|
||||
} else {
|
||||
duplicates.add(g);
|
||||
}
|
||||
}
|
||||
|
||||
if (hashEntries.size() < HASH_DISPATCH_THRESHOLD) {
|
||||
return false;
|
||||
}
|
||||
|
||||
result.add(new ProviderDispatch.Hash(mapIndex, hashEntries));
|
||||
result.addAll(duplicates);
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Collect all unique provider fields (index → fieldDesc) referenced by a dispatch list,
|
||||
* including those inside Hash entries.
|
||||
*/
|
||||
private static LinkedHashMap<Integer, String> collectProviderFields(List<ProviderDispatch> dispatches) {
|
||||
LinkedHashMap<Integer, String> fields = new LinkedHashMap<>();
|
||||
for (ProviderDispatch dispatch : dispatches) {
|
||||
if (dispatch instanceof ProviderDispatch.Guarded g) {
|
||||
fields.putIfAbsent(g.providerIndex(), g.fieldDesc());
|
||||
} else if (dispatch instanceof ProviderDispatch.Unguarded u) {
|
||||
fields.putIfAbsent(u.providerIndex(), u.fieldDesc());
|
||||
}
|
||||
// Hash entries don't need provider fields - map reads from constructor array
|
||||
}
|
||||
return fields;
|
||||
}
|
||||
|
||||
private static byte[] generateClassBytes(String className, List<Class<? extends ICapabilityProvider>> providerTypes, List<CapabilityAnalysisResult> analysisResults) {
|
||||
ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS);
|
||||
List<ProviderDispatch> dispatches = optimizeDispatches(buildDispatchList(providerTypes, analysisResults));
|
||||
|
||||
ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS) {
|
||||
@Override
|
||||
protected ClassLoader getClassLoader() {
|
||||
return CapabilityProviderDispatcherGenerator.class.getClassLoader();
|
||||
}
|
||||
};
|
||||
|
||||
// Class declaration: implements ICapabilityProvider
|
||||
cw.visit(
|
||||
|
|
@ -137,67 +273,89 @@ public class CapabilityProviderDispatcherGenerator {
|
|||
new String[] { "net/minecraftforge/common/capabilities/ICapabilityProvider" }
|
||||
);
|
||||
|
||||
// Compute field descriptors: use concrete type when possible for JIT devirtualization
|
||||
String[] fieldDescs = new String[providerTypes.size()];
|
||||
for (int i = 0; i < providerTypes.size(); i++) {
|
||||
Class<? extends ICapabilityProvider> type = providerTypes.get(i);
|
||||
fieldDescs[i] = (!type.isHidden() && Modifier.isPublic(type.getModifiers()))
|
||||
? Type.getDescriptor(type) : ICAP_PROVIDER_DESC;
|
||||
// Generate final fields for each distinct provider
|
||||
LinkedHashMap<Integer, String> providerFields = collectProviderFields(dispatches);
|
||||
for (var entry : providerFields.entrySet()) {
|
||||
cw.visitField(ACC_PRIVATE | ACC_FINAL, "provider" + entry.getKey(), entry.getValue(), null, null).visitEnd();
|
||||
}
|
||||
|
||||
// Generate final fields for each provider
|
||||
for (int i = 0; i < providerTypes.size(); i++) {
|
||||
cw.visitField(
|
||||
ACC_PRIVATE | ACC_FINAL,
|
||||
"provider" + i,
|
||||
fieldDescs[i],
|
||||
null,
|
||||
null
|
||||
).visitEnd();
|
||||
// Generate map fields for Hash dispatches
|
||||
for (ProviderDispatch dispatch : dispatches) {
|
||||
if (dispatch instanceof ProviderDispatch.Hash hash) {
|
||||
cw.visitField(ACC_PRIVATE | ACC_FINAL, "capMap" + hash.mapIndex(), MAP_DESC, null, null).visitEnd();
|
||||
}
|
||||
}
|
||||
|
||||
// Generate constructor
|
||||
generateConstructor(cw, className, providerTypes.size(), fieldDescs);
|
||||
generateConstructor(cw, className, providerFields, dispatches);
|
||||
|
||||
// Generate getCapability method with sided parameter
|
||||
generateGetCapabilityMethod(cw, className, fieldDescs, analysisResults);
|
||||
generateGetCapabilityMethod(cw, className, dispatches);
|
||||
|
||||
cw.visitEnd();
|
||||
return cw.toByteArray();
|
||||
}
|
||||
|
||||
private static void generateConstructor(ClassWriter cw, String className, int providerCount, String[] fieldDescs) {
|
||||
private static void generateConstructor(ClassWriter cw, String className, Map<Integer, String> providerFields, List<ProviderDispatch> dispatches) {
|
||||
Method constructor = Method.getMethod("void <init>(net.minecraftforge.common.capabilities.ICapabilityProvider[])");
|
||||
GeneratorAdapter mg = new GeneratorAdapter(ACC_PUBLIC, constructor, null, null, cw);
|
||||
Type classType = Type.getObjectType(className.replace('.', '/'));
|
||||
|
||||
// Call super constructor
|
||||
mg.loadThis();
|
||||
mg.invokeConstructor(Type.getType(Object.class), Method.getMethod("void <init>()"));
|
||||
|
||||
// Unpack array into final fields
|
||||
for (int i = 0; i < providerCount; i++) {
|
||||
Type fieldType = Type.getType(fieldDescs[i]);
|
||||
mg.loadThis(); // this
|
||||
// Unpack array into provider fields
|
||||
for (var entry : providerFields.entrySet()) {
|
||||
int idx = entry.getKey();
|
||||
String desc = entry.getValue();
|
||||
Type fieldType = Type.getType(desc);
|
||||
mg.loadThis();
|
||||
mg.loadArg(0); // array
|
||||
mg.push(i); // index
|
||||
mg.arrayLoad(Type.getType(ICAP_PROVIDER_DESC)); // array[i]
|
||||
if (!fieldDescs[i].equals(ICAP_PROVIDER_DESC)) {
|
||||
mg.push(idx); // index
|
||||
mg.arrayLoad(Type.getType(ICAP_PROVIDER_DESC));
|
||||
if (!desc.equals(ICAP_PROVIDER_DESC)) {
|
||||
mg.checkCast(fieldType);
|
||||
}
|
||||
mg.putField(
|
||||
Type.getObjectType(className.replace('.', '/')),
|
||||
"provider" + i,
|
||||
fieldType
|
||||
);
|
||||
mg.putField(classType, "provider" + idx, fieldType);
|
||||
}
|
||||
|
||||
// Build hash maps
|
||||
for (ProviderDispatch dispatch : dispatches) {
|
||||
if (dispatch instanceof ProviderDispatch.Hash hash) {
|
||||
generateMapConstruction(mg, classType, hash);
|
||||
}
|
||||
}
|
||||
|
||||
mg.returnValue();
|
||||
mg.endMethod();
|
||||
}
|
||||
|
||||
private static void generateGetCapabilityMethod(ClassWriter cw, String className, String[] fieldDescs, List<CapabilityAnalysisResult> analysisResults) {
|
||||
int providerCount = fieldDescs.length;
|
||||
private static void generateMapConstruction(GeneratorAdapter mg, Type classType, ProviderDispatch.Hash hash) {
|
||||
List<ProviderDispatch.Guarded> entries = hash.entries();
|
||||
mg.loadThis(); // for PUTFIELD at the end
|
||||
|
||||
mg.push(entries.size());
|
||||
mg.visitTypeInsn(ANEWARRAY, "java/util/Map$Entry");
|
||||
for (int i = 0; i < entries.size(); i++) {
|
||||
ProviderDispatch.Guarded g = entries.get(i);
|
||||
mg.dup();
|
||||
mg.push(i);
|
||||
mg.visitFieldInsn(GETSTATIC, g.capability().owner(), g.capability().fieldName(), CAPABILITY_DESC);
|
||||
mg.loadArg(0);
|
||||
mg.push(g.providerIndex());
|
||||
mg.arrayLoad(Type.getType(ICAP_PROVIDER_DESC));
|
||||
mg.visitMethodInsn(INVOKESTATIC, "java/util/Map", "entry",
|
||||
"(Ljava/lang/Object;Ljava/lang/Object;)Ljava/util/Map$Entry;", true);
|
||||
mg.visitInsn(AASTORE);
|
||||
}
|
||||
mg.visitMethodInsn(INVOKESTATIC, "java/util/Map", "ofEntries",
|
||||
"([Ljava/util/Map$Entry;)Ljava/util/Map;", true);
|
||||
|
||||
mg.putField(classType, "capMap" + hash.mapIndex(), Type.getType(MAP_DESC));
|
||||
}
|
||||
|
||||
private static void generateGetCapabilityMethod(ClassWriter cw, String className, List<ProviderDispatch> dispatches) {
|
||||
// Method: <T> LazyOptional<T> getCapability(Capability<T>, Direction)
|
||||
MethodVisitor mv = cw.visitMethod(
|
||||
ACC_PUBLIC,
|
||||
|
|
@ -213,76 +371,73 @@ public class CapabilityProviderDispatcherGenerator {
|
|||
// For each provider, call getCapability and check if present
|
||||
Label endLabel = new Label();
|
||||
|
||||
for (int i = 0; i < providerCount; i++) {
|
||||
CapabilityAnalysisResult analysis = analysisResults.get(i);
|
||||
String internalName = className.replace('.', '/');
|
||||
String getCapDesc = "(" + CAPABILITY_DESC + DIRECTION_DESC + ")" + LAZY_OPTIONAL_DESC;
|
||||
|
||||
for (ProviderDispatch dispatch : dispatches) {
|
||||
Label nextLabel = new Label();
|
||||
|
||||
// AlwaysEmpty: skip code generation for this provider entirely
|
||||
if (analysis instanceof CapabilityAnalysisResult.AlwaysEmpty) {
|
||||
continue;
|
||||
}
|
||||
if (dispatch instanceof ProviderDispatch.Hash hash) {
|
||||
// ICapabilityProvider p = (ICapabilityProvider) this.capMapN.get(cap);
|
||||
mv.visitVarInsn(ALOAD, 0);
|
||||
mv.visitFieldInsn(GETFIELD, internalName, "capMap" + hash.mapIndex(), MAP_DESC);
|
||||
mv.visitVarInsn(ALOAD, 1);
|
||||
mv.visitMethodInsn(INVOKEINTERFACE, "java/util/Map", "get",
|
||||
"(Ljava/lang/Object;)Ljava/lang/Object;", true);
|
||||
mv.visitVarInsn(ASTORE, 3);
|
||||
|
||||
// KnownCapabilities: emit guard checks before dispatch
|
||||
if (analysis instanceof CapabilityAnalysisResult.KnownCapabilities known
|
||||
&& known.capabilities().size() <= 5) {
|
||||
if (known.capabilities().size() == 1) {
|
||||
// Single cap: if (cap != KNOWN_CAP) goto nextProvider
|
||||
CapabilityRef ref = known.capabilities().iterator().next();
|
||||
mv.visitVarInsn(ALOAD, 1); // cap parameter
|
||||
// if (p == null) goto next
|
||||
mv.visitVarInsn(ALOAD, 3);
|
||||
mv.visitJumpInsn(IFNULL, nextLabel);
|
||||
|
||||
// result = ((ICapabilityProvider) p).getCapability(cap, side)
|
||||
mv.visitVarInsn(ALOAD, 3);
|
||||
mv.visitTypeInsn(CHECKCAST, "net/minecraftforge/common/capabilities/ICapabilityProvider");
|
||||
mv.visitVarInsn(ALOAD, 1);
|
||||
mv.visitVarInsn(ALOAD, 2);
|
||||
mv.visitMethodInsn(INVOKEINTERFACE,
|
||||
"net/minecraftforge/common/capabilities/ICapabilityProvider",
|
||||
"getCapability", getCapDesc, true);
|
||||
mv.visitVarInsn(ASTORE, 3);
|
||||
} else {
|
||||
if (dispatch instanceof ProviderDispatch.Guarded guarded) {
|
||||
// if (cap != KNOWN_CAP) goto next
|
||||
CapabilityRef ref = guarded.capability();
|
||||
mv.visitVarInsn(ALOAD, 1);
|
||||
mv.visitFieldInsn(GETSTATIC, ref.owner(), ref.fieldName(), CAPABILITY_DESC);
|
||||
mv.visitJumpInsn(IF_ACMPNE, nextLabel);
|
||||
} else {
|
||||
// Multiple caps: check each, jump to callProvider on match
|
||||
Label callProvider = new Label();
|
||||
for (CapabilityRef ref : known.capabilities()) {
|
||||
mv.visitVarInsn(ALOAD, 1); // cap parameter
|
||||
mv.visitFieldInsn(GETSTATIC, ref.owner(), ref.fieldName(), CAPABILITY_DESC);
|
||||
mv.visitJumpInsn(IF_ACMPEQ, callProvider);
|
||||
}
|
||||
// No match, skip this provider
|
||||
mv.visitJumpInsn(GOTO, nextLabel);
|
||||
mv.visitLabel(callProvider);
|
||||
}
|
||||
|
||||
// LazyOptional<T> result = this.providerN.getCapability(cap, side);
|
||||
int provIdx;
|
||||
String fDesc;
|
||||
if (dispatch instanceof ProviderDispatch.Guarded g) {
|
||||
provIdx = g.providerIndex(); fDesc = g.fieldDesc();
|
||||
} else {
|
||||
var u = (ProviderDispatch.Unguarded) dispatch;
|
||||
provIdx = u.providerIndex(); fDesc = u.fieldDesc();
|
||||
}
|
||||
mv.visitVarInsn(ALOAD, 0);
|
||||
mv.visitFieldInsn(GETFIELD, internalName, "provider" + provIdx, fDesc);
|
||||
mv.visitVarInsn(ALOAD, 1);
|
||||
mv.visitVarInsn(ALOAD, 2);
|
||||
mv.visitMethodInsn(INVOKEINTERFACE,
|
||||
"net/minecraftforge/common/capabilities/ICapabilityProvider",
|
||||
"getCapability", getCapDesc, true);
|
||||
mv.visitVarInsn(ASTORE, 3);
|
||||
}
|
||||
// Indeterminate: no guard, fall through to dispatch
|
||||
|
||||
// LazyOptional<T> result = this.providerN.getCapability(cap, side);
|
||||
mv.visitVarInsn(ALOAD, 0); // this
|
||||
mv.visitFieldInsn(
|
||||
GETFIELD,
|
||||
className.replace('.', '/'),
|
||||
"provider" + i,
|
||||
fieldDescs[i]
|
||||
);
|
||||
mv.visitVarInsn(ALOAD, 1); // cap parameter
|
||||
mv.visitVarInsn(ALOAD, 2); // side parameter
|
||||
mv.visitMethodInsn(
|
||||
INVOKEINTERFACE,
|
||||
"net/minecraftforge/common/capabilities/ICapabilityProvider",
|
||||
"getCapability",
|
||||
"(" + CAPABILITY_DESC + DIRECTION_DESC + ")" + LAZY_OPTIONAL_DESC,
|
||||
true
|
||||
);
|
||||
|
||||
// Store result in local variable
|
||||
mv.visitVarInsn(ASTORE, 3);
|
||||
|
||||
// if (result == null) continue to next;
|
||||
// if (result == null) goto next
|
||||
mv.visitVarInsn(ALOAD, 3);
|
||||
mv.visitJumpInsn(IFNULL, nextLabel);
|
||||
|
||||
// if (result.isPresent()) return result;
|
||||
// if (result.isPresent()) return result
|
||||
mv.visitVarInsn(ALOAD, 3);
|
||||
mv.visitMethodInsn(
|
||||
INVOKEVIRTUAL,
|
||||
mv.visitMethodInsn(INVOKEVIRTUAL,
|
||||
"net/minecraftforge/common/util/LazyOptional",
|
||||
"isPresent",
|
||||
"()Z",
|
||||
false
|
||||
);
|
||||
"isPresent", "()Z", false);
|
||||
mv.visitJumpInsn(IFEQ, nextLabel);
|
||||
|
||||
// return result
|
||||
mv.visitVarInsn(ALOAD, 3);
|
||||
mv.visitInsn(ARETURN);
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user