修复ScaledProcessingPattern构建过程中产出比的计算问题

This commit is contained in:
C-H716 2025-08-28 15:23:06 +08:00
parent 1d55920af9
commit ac522c7751
2 changed files with 138 additions and 25 deletions

View File

@ -0,0 +1,78 @@
package com.extendedae_plus.util;
import java.util.Arrays;
public class ArraySimplifier {
// 计算两个数的GCD using Euclidean algorithm (long版本)
public static long gcd(long a, long b) {
while (b != 0) {
long temp = b;
b = a % b;
a = temp;
}
return a;
}
// 计算整个数组的GCD
public static long findGCD(long[] arr) {
if (arr.length == 0) {
return 0;
}
long result = arr[0];
for (int i = 1; i < arr.length; i++) {
result = gcd(result, arr[i]);
// 如果已经找到GCD为1可以提前终止
if (result == 1) {
break;
}
}
return result;
}
// 简化数组每个元素除以数组的GCD
public static long[] simplifyFraction(long[] arr) {
if (arr.length == 0) {
return new long[0];
}
long gcd = findGCD(arr);
if (gcd == 0) {
// 如果GCD为0所有元素为0返回原数组的副本
return Arrays.copyOf(arr, arr.length);
}
long[] simplified = new long[arr.length];
for (int i = 0; i < arr.length; i++) {
simplified[i] = arr[i] / gcd;
}
return simplified;
}
// 将两个数组合并为一个新数组先放 a 后放 b
public static long[] combine(long[] a, long[] b) {
long[] out = new long[a.length + b.length];
System.arraycopy(a, 0, out, 0, a.length);
System.arraycopy(b, 0, out, a.length, b.length);
return out;
}
// 寻找数组的 GCD遇到 1 则立即返回 1早期退出优化
public static long findGCDWithEarlyExit(long[] arr) {
if (arr.length == 0) return 0;
long result = 0;
for (long v : arr) {
if (v == 1) return 1; // already irreducible
if (v == 0) continue;
if (result == 0) result = v; else result = gcd(result, v);
if (result == 1) return 1;
}
return result == 0 ? 0 : Math.abs(result);
}
// 根据给定的 gcd 返回一个已除以 gcd 的新数组如果 gcd==1 返回原数组避免不必要的分配
public static long[] simplifyByGcd(long[] arr, long gcd) {
if (gcd <= 1) return arr;
long[] out = new long[arr.length];
for (int i = 0; i < arr.length; i++) out[i] = arr[i] / gcd;
return out;
}
}

View File

@ -21,35 +21,56 @@ public final class PatternScaler {
IInput[] baseInputs = base.getInputs();
GenericStack[] baseOutputs = base.getOutputs();
/* 1. 构建缩放后的 sparseInputs */
GenericStack[] scaledSparseInputs = new GenericStack[baseSparseInputs.length];
for (int i = 0; i < baseSparseInputs.length; i++) {
GenericStack in = baseSparseInputs[i];
if (in != null) {
scaledSparseInputs[i] = new GenericStack(in.what(), requestedAmount);
}
// 计算每个压缩输入槽位的总量per operation: multiplier * template.amount
long[] inputsCounts = new long[baseInputs.length];
for (int i = 0; i < baseInputs.length; i++) {
var in = baseInputs[i];
var first = in.getPossibleInputs()[0];
inputsCounts[i] = in.getMultiplier() * first.amount();
}
/* 2. 构建缩放后的 sparseOutputs */
GenericStack[] scaledSparseOutputs = new GenericStack[baseSparseOutputs.length];
for (int i = 0; i < baseSparseOutputs.length; i++) {
GenericStack out = baseSparseOutputs[i];
if (out != null) {
scaledSparseOutputs[i] = new GenericStack(out.what(), requestedAmount);
}
// 计算每个输出的数量per operation
long[] outputsCounts = new long[baseOutputs.length];
for (int i = 0; i < baseOutputs.length; i++) {
var out = baseOutputs[i];
outputsCounts[i] = out == null ? 0L : out.amount();
}
/* 3. 构建压缩输入ScaledInput */
// 合并为一个数组并计算 gcd使用早期退出优化
long[] combined = ArraySimplifier.combine(inputsCounts, outputsCounts);
long gcd = ArraySimplifier.findGCDWithEarlyExit(combined);
if (gcd <= 0) gcd = 1;
// 如果 gcd == 1则无需分配新的数组直接使用 combined 作为 simplified 视图
long[] simplified = ArraySimplifier.simplifyByGcd(combined, gcd);
// 找到目标输出在 outputs 中的索引
int targetOutIndex = -1;
for (int i = 0; i < baseOutputs.length; i++) {
if (baseOutputs[i] != null) {
targetOutIndex = i;
break;
}
}
if (targetOutIndex == -1 && baseOutputs.length > 0) targetOutIndex = 0;
long simplifiedTargetPerUnit = simplified[inputsCounts.length + Math.max(0, targetOutIndex)];
if (simplifiedTargetPerUnit <= 0) simplifiedTargetPerUnit = 1;
// 单位数需要多少 "最简约单位" 才能满足 requestedAmount向上取整
long units = (requestedAmount + simplifiedTargetPerUnit - 1) / simplifiedTargetPerUnit;
// 构建压缩输入ScaledInput模板数量为 simplifiedInputs, multiplier units
IInput[] scaledInputs = new IInput[baseInputs.length];
for (int i = 0; i < baseInputs.length; i++) {
var in = baseInputs[i];
var template = in.getPossibleInputs();
GenericStack[] scaledTemplates = new GenericStack[template.length];
long simplifiedInputAmount = simplified[i];
for (int j = 0; j < template.length; j++) {
scaledTemplates[j] = new GenericStack(template[j].what(), 1);
scaledTemplates[j] = new GenericStack(template[j].what(), simplifiedInputAmount);
}
scaledInputs[i] = new ScaledProcessingPattern.Input(scaledTemplates, requestedAmount);
scaledInputs[i] = new ScaledProcessingPattern.Input(scaledTemplates, units);
}
/* 4. 构建压缩输出 */
@ -57,10 +78,30 @@ public final class PatternScaler {
for (int i = 0; i < baseOutputs.length; i++) {
GenericStack out = baseOutputs[i];
if (out != null) {
scaledCondensedOutputs[i] = new GenericStack(out.what(), requestedAmount);
long simplifiedOutAmount = simplified[inputsCounts.length + i];
scaledCondensedOutputs[i] = new GenericStack(out.what(), simplifiedOutAmount * units);
}
}
// 构建并打印稀疏表示 unit * simplified / gcd 映射回原稀疏槽
GenericStack[] scaledSparseInputs = new GenericStack[baseSparseInputs.length];
for (int i = 0; i < baseSparseInputs.length; i++) {
var in = baseSparseInputs[i];
if (in != null) {
long scaledAmount = in.amount() * units / gcd;
scaledSparseInputs[i] = new GenericStack(in.what(), scaledAmount);
}
}
GenericStack[] scaledSparseOutputs = new GenericStack[baseSparseOutputs.length];
for (int i = 0; i < baseSparseOutputs.length; i++) {
var out = baseSparseOutputs[i];
if (out != null) {
long scaledAmount = out.amount() * units / gcd;
scaledSparseOutputs[i] = new GenericStack(out.what(), scaledAmount);
}
}
/* Debug 输出 */
System.out.println("[extendedae_plus] 正在缩放样板:");
System.out.println(" 原始样板: " + base);
@ -79,10 +120,4 @@ public final class PatternScaler {
scaledInputs,
scaledCondensedOutputs);
}
private static long safeMul(long a, long b) {
if (a == 0 || b == 0) return 0;
if (a > Long.MAX_VALUE / b) return Long.MAX_VALUE;
return a * b;
}
}