diff --git a/src/main/java/org/adde0109/ambassador/forge/ForgeHandshakeUtils.java b/src/main/java/org/adde0109/ambassador/forge/ForgeHandshakeUtils.java index 562e773..36cfa82 100644 --- a/src/main/java/org/adde0109/ambassador/forge/ForgeHandshakeUtils.java +++ b/src/main/java/org/adde0109/ambassador/forge/ForgeHandshakeUtils.java @@ -6,8 +6,8 @@ import com.google.common.io.ByteStreams; import com.velocitypowered.proxy.protocol.ProtocolUtils; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; -import io.netty.handler.codec.DecoderException; import org.adde0109.ambassador.forge.packet.Context; +import org.adde0109.ambassador.forge.packet.GenericForgeLoginWrapperPacket; import org.adde0109.ambassador.forge.packet.IForgeLoginWrapperPacket; import java.nio.charset.StandardCharsets; @@ -117,37 +117,59 @@ public class ForgeHandshakeUtils { return stream.toByteArray(); } - public static class SilentGearUtils { - public static boolean isSilentGearPacket(byte[] data) { - ByteBuf buf = Unpooled.wrappedBuffer(data); - String channel = null; - try { - channel = ProtocolUtils.readString(buf); - } catch (DecoderException e) { - } finally { - buf.release(); - } - return channel != null && channel.equals("silentgear:network"); + public static class ThirdPartyRegistryUtils { + + static enum ThirdPartyChannel { + SILENTGEAR_NETWORK { + @Override + public ThirdPartyRegistryUtils.ACKPacket generateResponsePacket(Context.ClientContext context) { + return new ACKPacket(context, 3); + } + }, + ZETA_MAIN { + @Override + public ThirdPartyRegistryUtils.ACKPacket generateResponsePacket(Context.ClientContext context) { + return new ACKPacket(context, 99); + } + }; + abstract public ACKPacket generateResponsePacket(Context.ClientContext context); } - public static class ACKPacket implements IForgeLoginWrapperPacket { - private final Context.ClientContext context; - - public ACKPacket(Context.ClientContext context) { - this.context = context; + static boolean isThirdPartyPacket(GenericForgeLoginWrapperPacket packet) { + try { + Enum.valueOf(ThirdPartyChannel.class, + packet.getContext().getChannelName().replace(':', '_').toUpperCase()); + return true; + } catch (IllegalArgumentException e) { + return false; } - @Override + } + + static ThirdPartyChannel getThirdPartyChannel(GenericForgeLoginWrapperPacket packet) { + return Enum.valueOf(ThirdPartyChannel.class, + packet.getContext().getChannelName().replace(':', '_').toUpperCase()); + } + + static class ACKPacket implements IForgeLoginWrapperPacket { + + private final Context.ClientContext context; + private final int packetID; + ACKPacket(Context.ClientContext context, int packetID) { + this.context = context; + this.packetID = packetID; + } + public ByteBuf encode() { ByteBuf buf = Unpooled.buffer(); - ProtocolUtils.writeVarInt(buf, 3); + ProtocolUtils.writeVarInt(buf, packetID); return buf; } @Override public Context.ClientContext getContext() { - return context; + return null; } } } diff --git a/src/main/java/org/adde0109/ambassador/forge/VelocityForgeBackendConnectionPhase.java b/src/main/java/org/adde0109/ambassador/forge/VelocityForgeBackendConnectionPhase.java index 4816f79..44c3ff8 100644 --- a/src/main/java/org/adde0109/ambassador/forge/VelocityForgeBackendConnectionPhase.java +++ b/src/main/java/org/adde0109/ambassador/forge/VelocityForgeBackendConnectionPhase.java @@ -62,7 +62,7 @@ public enum VelocityForgeBackendConnectionPhase implements BackendConnectionPhas VelocityForgeBackendConnectionPhase() { } - public void handle(VelocityServerConnection server, ConnectedPlayer player, IForgeLoginWrapperPacket message) { + public void handle(VelocityServerConnection server, ConnectedPlayer player, IForgeLoginWrapperPacket message) { VelocityForgeBackendConnectionPhase newPhase = getNewPhase(server,message); server.setConnectionPhase(newPhase); @@ -137,10 +137,9 @@ public enum VelocityForgeBackendConnectionPhase implements BackendConnectionPhas remainingRegistries.countDown(); } else if (message instanceof ConfigDataPacket) { server.getConnection().write(new ACKPacket(Context.fromContext(message.getContext(), true))); - } else if (message instanceof GenericForgeLoginWrapperPacket packet - && ForgeHandshakeUtils.SilentGearUtils.isSilentGearPacket(packet.getContent())) { - server.getConnection().write(new ForgeHandshakeUtils.SilentGearUtils.ACKPacket( - Context.fromContext(message.getContext(), true))); + } else if (message instanceof GenericForgeLoginWrapperPacket packet + && ForgeHandshakeUtils.ThirdPartyRegistryUtils.isThirdPartyPacket(packet)) { + server.getConnection().write(ForgeHandshakeUtils.ThirdPartyRegistryUtils.getThirdPartyChannel(packet)); } } //Forge server diff --git a/src/main/java/org/adde0109/ambassador/forge/packet/Context.java b/src/main/java/org/adde0109/ambassador/forge/packet/Context.java index 8f39814..22d1864 100644 --- a/src/main/java/org/adde0109/ambassador/forge/packet/Context.java +++ b/src/main/java/org/adde0109/ambassador/forge/packet/Context.java @@ -4,31 +4,38 @@ public class Context { private final int responseID; - private Context(int responseID) { + private final String channelName; + + private Context(int responseID, String channelName) { this.responseID = responseID; + this.channelName = channelName; } - public static Context createContext(int responseID) { - return new Context(responseID); + public static Context createContext(int responseID, String channelName) { + return new Context(responseID, channelName); } - public static ClientContext createClientContext(int responseID, boolean clientSuccess) { - return new ClientContext(responseID,clientSuccess); + public static ClientContext createClientContext(int responseID, boolean clientSuccess, String channelName) { + return new ClientContext(responseID, clientSuccess, channelName); } public static ClientContext fromContext(Context context, boolean clientSuccess) { - return new ClientContext(context.responseID,clientSuccess); + return new ClientContext(context.responseID, clientSuccess, context.channelName); } public int getResponseID() { return responseID; } + public String getChannelName() { + return channelName; + } + public static class ClientContext extends Context { private final boolean clientSuccess; - ClientContext(int responseID, boolean clientSuccess) { - super(responseID); + ClientContext(int responseID, boolean clientSuccess, String channelName) { + super(responseID, channelName); this.clientSuccess = clientSuccess; } diff --git a/src/main/java/org/adde0109/ambassador/forge/pipeline/ForgeLoginWrapperCodec.java b/src/main/java/org/adde0109/ambassador/forge/pipeline/ForgeLoginWrapperCodec.java index dd8eeb4..717487b 100644 --- a/src/main/java/org/adde0109/ambassador/forge/pipeline/ForgeLoginWrapperCodec.java +++ b/src/main/java/org/adde0109/ambassador/forge/pipeline/ForgeLoginWrapperCodec.java @@ -9,7 +9,7 @@ import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.DecoderException; import io.netty.handler.codec.MessageToMessageCodec; -import org.adde0109.ambassador.forge.ForgeHandshakeUtils; +import org.adde0109.ambassador.Ambassador; import org.adde0109.ambassador.forge.packet.*; import java.util.ArrayList; @@ -28,21 +28,27 @@ public class ForgeLoginWrapperCodec extends MessageToMessageCodec out) throws Exception { ByteBuf buf = in.content(); - Context context; - if (in instanceof LoginPluginMessagePacket msg && msg.getChannel().equals("fml:loginwrapper")) { - context = Context.createContext(msg.getId()); - } else if (in instanceof LoginPluginResponsePacket msg && loginWrapperIDs.remove(Integer.valueOf(msg.getId()))) { - context = Context.createClientContext(msg.getId(), msg.isSuccess()); - } else { - ctx.fireChannelRead(in.retain()); - return; - } - int originalReaderIndex = buf.readerIndex(); + + String channel; + try { - String channel = ProtocolUtils.readString(buf); + Context context; + if (in instanceof LoginPluginMessagePacket msg && msg.getChannel().equals("fml:loginwrapper")) { + channel = ProtocolUtils.readString(buf); + context = Context.createContext(msg.getId(), channel); + } else if (in instanceof LoginPluginResponsePacket msg && loginWrapperIDs.remove(Integer.valueOf(msg.getId()))) { + channel = ProtocolUtils.readString(buf); + context = Context.createClientContext(msg.getId(), msg.isSuccess(), channel); + } else { + //Not a loginWrapperPacket + buf.readerIndex(originalReaderIndex); + ctx.fireChannelRead(in.retain()); + return; + } + if (!channel.equals("fml:handshake")) { - throw new DecoderException(); + out.add(GenericForgeLoginWrapperPacket.read(buf, context)); } else { int length = ProtocolUtils.readVarInt(buf); int packetID = ProtocolUtils.readVarInt(buf); @@ -55,7 +61,7 @@ public class ForgeLoginWrapperCodec extends MessageToMessageCodec msg, List out) throws Exception { ByteBuf wrapped; - if (msg instanceof GenericForgeLoginWrapperPacket) { - wrapped = msg.encode(); - } else { - String channel = "fml:handshake"; - if (msg instanceof ForgeHandshakeUtils.SilentGearUtils.ACKPacket) { - channel = "silentgear:network"; - } - wrapped = Unpooled.buffer(); - ByteBuf encoded = msg.encode(); - ProtocolUtils.writeString(wrapped, channel); - ProtocolUtils.writeVarInt(wrapped, encoded.readableBytes()); - wrapped.writeBytes(encoded); - encoded.release(); - } + + String channel = msg.getContext().getChannelName(); + + wrapped = Unpooled.buffer(); + ByteBuf encoded = msg.encode(); + ProtocolUtils.writeString(wrapped, channel); + ProtocolUtils.writeVarInt(wrapped, encoded.readableBytes()); + wrapped.writeBytes(encoded); + encoded.release(); if (msg.getContext() instanceof Context.ClientContext clientContext) { out.add(new LoginPluginResponsePacket(clientContext.getResponseID(), clientContext.success(), wrapped)); } else { out.add(new LoginPluginMessagePacket(msg.getContext().getResponseID(), "fml:loginwrapper", wrapped)); - if (!(msg instanceof ModDataPacket)) { + if (!(msg instanceof ModDataPacket)) { //ModDataPacket doesn't require a response this.loginWrapperIDs.add(msg.getContext().getResponseID()); } }