Change the way hashes for packets are sent and received on the PacketLoader.

Add new utility functions for PacketLoader.
This commit is contained in:
Captain ALM 2023-06-11 01:20:01 +01:00
parent 590b7c5d5d
commit f4e3dc8f11
Signed by: alfred
GPG Key ID: 4E4ADD02609997B1
6 changed files with 155 additions and 41 deletions

View File

@ -18,7 +18,7 @@ public final class FragmentationOptions {
* See: * See:
* {@link FragmentSender#setSplitSize(int)} * {@link FragmentSender#setSplitSize(int)}
*/ */
public int fragmentationSplitSize = 496; public int fragmentationSplitSize = 448;
/** /**
* See: * See:
* {@link FragmentReceiver#setNumberOfEmptySendsTillForcedCompleteOrResend(int)} * {@link FragmentReceiver#setNumberOfEmptySendsTillForcedCompleteOrResend(int)}

View File

@ -101,17 +101,16 @@ public class NetMarshalClient implements Closeable {
fragmentMonitorThread = new Thread(() -> { fragmentMonitorThread = new Thread(() -> {
int ageCheckTime = this.fragmentationOptions.maximumFragmentAge - 1; int ageCheckTime = this.fragmentationOptions.maximumFragmentAge - 1;
while (running) { while (running) {
int id = -1;
synchronized (this.fragmentationOptions) { synchronized (this.fragmentationOptions) {
for (int c : fragmentRMM.keySet()) { for (int c : fragmentRMM.keySet()) {
if (!fragmentRMM.get(c).plusSeconds(ageCheckTime).isAfter(LocalDateTime.now())) { if (!fragmentRMM.get(c).plusSeconds(ageCheckTime).isAfter(LocalDateTime.now())) {
fragmentRMM.remove(id); fragmentRMM.remove(c);
fragmentReceiver.deletePacketFromRegistry(c); fragmentReceiver.deletePacketFromRegistry(c);
} }
} }
for (int c : fragmentSMM.keySet()) { for (int c : fragmentSMM.keySet()) {
if (!fragmentSMM.get(c).plusSeconds(ageCheckTime).isAfter(LocalDateTime.now())) { if (!fragmentSMM.get(c).plusSeconds(ageCheckTime).isAfter(LocalDateTime.now())) {
fragmentSMM.remove(id); fragmentSMM.remove(c);
fragmentSender.deletePacketFromRegistry(c); fragmentSender.deletePacketFromRegistry(c);
} }
} }
@ -127,8 +126,8 @@ public class NetMarshalClient implements Closeable {
fragmentSMM.clear(); fragmentSMM.clear();
}, "thread_frag_monitor_" + remoteAddress.getHostAddress() + ":" + remotePort); }, "thread_frag_monitor_" + remoteAddress.getHostAddress() + ":" + remotePort);
fragmentFinishReceiveMonitorThread = new Thread(() -> { fragmentFinishReceiveMonitorThread = new Thread(() -> {
while (running) {
int id = -1; int id = -1;
while (running) {
try { try {
while ((id = fragmentReceiver.getLastIDFinished()) != -1) synchronized (this.fragmentationOptions) { while ((id = fragmentReceiver.getLastIDFinished()) != -1) synchronized (this.fragmentationOptions) {
fragmentRMM.remove(id); fragmentRMM.remove(id);
@ -139,8 +138,8 @@ public class NetMarshalClient implements Closeable {
fragmentReceiver.clearLastIDFinished(); fragmentReceiver.clearLastIDFinished();
}, "thread_frag_fin_recv_monitor_" + remoteAddress.getHostAddress() + ":" + remotePort); }, "thread_frag_fin_recv_monitor_" + remoteAddress.getHostAddress() + ":" + remotePort);
fragmentFinishSendMonitorThread = new Thread(() -> { fragmentFinishSendMonitorThread = new Thread(() -> {
while (running) {
int id = -1; int id = -1;
while (running) {
try { try {
while ((id = fragmentSender.getLastIDFinished()) != -1) synchronized (this.fragmentationOptions) { while ((id = fragmentSender.getLastIDFinished()) != -1) synchronized (this.fragmentationOptions) {
fragmentSMM.remove(id); fragmentSMM.remove(id);

View File

@ -23,24 +23,41 @@ import static com.captainalm.lib.calmnet.packet.PacketProtocolInformation.savePa
public class PacketLoader { public class PacketLoader {
protected boolean allowInvalidPackets; protected boolean allowInvalidPackets;
protected boolean oldPacketFormat;
/** /**
* Constructs a new Packet loader instance. * Constructs a new Packet loader instance.
* If using a digest provider, use {@link #PacketLoader(DigestProvider)} * If using a digest provider, use {@link #PacketLoader(DigestProvider)}
*/ */
public PacketLoader() { public PacketLoader() {
this(null); this(null, false);
} }
/** /**
* Constructs a new Packet loader instance with the specified {@link DigestProvider}. * Constructs a new Packet loader instance with the specified {@link DigestProvider}.
* If using a digest provider, make sure all endpoints use the same algorithm; * If using a digest provider, make sure all endpoints use the same algorithm;
* if null, no trailer is created or expected; * if null, no trailer is created;
* this is ignored if saving / loading packets from byte arrays. * this is ignored if saving / loading packets from byte arrays.
* *
* @param provider The digest provider or null. * @param provider The digest provider or null.
*/ */
public PacketLoader(DigestProvider provider) { public PacketLoader(DigestProvider provider) {
this(provider, false);
}
/**
* Constructs a new Packet loader instance with the specified {@link DigestProvider}
* and if the old packet format should be used.
* If using a digest provider, make sure all endpoints use the same algorithm;
* if null, no trailer is created;
* this is ignored if saving / loading packets from byte arrays.
*
* @param provider The digest provider or null.
* @param oldPacketFormat If the old packet format should be used (No explicit hash indication nor length).
*/
public PacketLoader(DigestProvider provider, boolean oldPacketFormat) {
hashProvider = provider; hashProvider = provider;
this.oldPacketFormat = oldPacketFormat;
} }
protected DigestProvider hashProvider; protected DigestProvider hashProvider;
@ -72,10 +89,51 @@ public class PacketLoader {
this.allowInvalidPackets = allowInvalidPackets; this.allowInvalidPackets = allowInvalidPackets;
} }
/**
* Is the old packet format in use (No explicit hash indication nor length).
*
* @return If the old packet format is in use.
*/
public boolean isOldPacketFormatInUse() {
return oldPacketFormat;
}
/**
* Sets if the old packet format should be used (No explicit hash indication nor length).
* @param useOldFormat If the old packet format should be used.
*/
public void setOldPacketFormatUsage(boolean useOldFormat) {
oldPacketFormat = useOldFormat;
}
protected boolean isPacketInvalid(IPacket packetIn) { protected boolean isPacketInvalid(IPacket packetIn) {
return (packetIn == null || !packetIn.isValid()) && !allowInvalidPackets; return (packetIn == null || !packetIn.isValid()) && !allowInvalidPackets;
} }
/**
* Adds the most significant flag to the given integer.
*
* @param value The integer to add the flag to.
* @return The integer with the flag added.
*/
public static int addMostSignificantFlag(int value) {
value += 1;
value += Integer.MAX_VALUE;
return value;
}
/**
* Subtracts the most significant flag from the given integer.
*
* @param value The integer to subtract the flag from.
* @return The integer with the flag subtracted.
*/
public static int subtractMostSignificantFlag(int value) {
value -= 1;
value -= Integer.MAX_VALUE;
return value;
}
/** /**
* Reads a {@link IPacket} from a byte array (No digest support). * Reads a {@link IPacket} from a byte array (No digest support).
* If the information parameter is null, this is obtained as part of the reading. * If the information parameter is null, this is obtained as part of the reading.
@ -102,8 +160,9 @@ public class PacketLoader {
if (toret != null) { if (toret != null) {
if (arrayIn.length < 6) throw new PacketException("arrayIn does not have a length header."); if (arrayIn.length < 6) throw new PacketException("arrayIn does not have a length header.");
int length = (arrayIn[2] & 0xff) * 16777216 + (arrayIn[3] & 0xff) * 65536 + (arrayIn[4] & 0xff) * 256 + (arrayIn[5] & 0xff); int length = (arrayIn[2] & 0xff) * 16777216 + (arrayIn[3] & 0xff) * 65536 + (arrayIn[4] & 0xff) * 256 + (arrayIn[5] & 0xff);
if (length < 0) length = subtractMostSignificantFlag(length);
byte[] loadArray = new byte[length]; byte[] loadArray = new byte[length];
System.arraycopy(arrayIn, 6, loadArray, 0, arrayIn.length - 6); System.arraycopy(arrayIn, 6, loadArray, 0, Math.min(arrayIn.length - 6, length));
toret.loadPayload(loadArray); toret.loadPayload(loadArray);
if (isPacketInvalid(toret)) toret = null; if (isPacketInvalid(toret)) toret = null;
} }
@ -131,9 +190,23 @@ public class PacketLoader {
IPacket toret = factory.getPacket(information); IPacket toret = factory.getPacket(information);
if (toret != null) { if (toret != null) {
InputStream lIS = (hashProvider == null) ? inputStream : hashProvider.getDigestInputStream(inputStream); int length = readInteger(inputStream);
byte[] loadArray = readArrayFromInputStream(lIS, readInteger(inputStream)); boolean hasHash = length < 0;
if (hashProvider == null || DigestComparer.compareDigests(inputStream, ((DigestInputStream) lIS).getMessageDigest().digest())) toret.loadPayload(loadArray); if (hasHash) length = subtractMostSignificantFlag(length);
InputStream lIS = (hashProvider == null || !hasHash) ? inputStream : hashProvider.getDigestInputStream(inputStream);
byte[] loadArray = readArrayFromInputStream(lIS, length);
int hashLength;
if (hasHash) {
hashLength = readByteIntegerFromInputStream(inputStream);
if (hashProvider != null && hashProvider.getLength() != hashLength) {
readArrayFromInputStream(inputStream, hashLength);
return null;
}
} else hashLength = 0;
if ((!hasHash && !oldPacketFormat) || hashProvider == null) {
readArrayFromInputStream(inputStream, hashLength);
toret.loadPayload(loadArray);
} else if (DigestComparer.compareDigests(inputStream, ((DigestInputStream) lIS).getMessageDigest().digest())) toret.loadPayload(loadArray);
if (isPacketInvalid(toret)) toret = null; if (isPacketInvalid(toret)) toret = null;
} }
return toret; return toret;
@ -161,7 +234,9 @@ public class PacketLoader {
IPacket toret = factory.getPacket(information); IPacket toret = factory.getPacket(information);
if (toret != null) { if (toret != null) {
byte[] loadArray = readArrayFromInputStream(inputStream, readInteger(inputStream)); int length = readInteger(inputStream);
if (length < 0) length = subtractMostSignificantFlag(length);
byte[] loadArray = readArrayFromInputStream(inputStream, length);
toret.loadPayload(loadArray); toret.loadPayload(loadArray);
if (isPacketInvalid(toret)) toret = null; if (isPacketInvalid(toret)) toret = null;
} }
@ -191,9 +266,21 @@ public class PacketLoader {
if (toret instanceof IStreamedPacket) { if (toret instanceof IStreamedPacket) {
int length = readInteger(inputStream); int length = readInteger(inputStream);
InputStream lIS = (hashProvider == null) ? inputStream : hashProvider.getDigestInputStream(inputStream); boolean hasHash = length < 0;
if (hasHash) length = subtractMostSignificantFlag(length);
InputStream lIS = (hashProvider == null || !hasHash) ? inputStream : hashProvider.getDigestInputStream(inputStream);
((IStreamedPacket) toret).writeData(lIS, length); ((IStreamedPacket) toret).writeData(lIS, length);
if (hashProvider != null && !DigestComparer.compareDigests(inputStream, ((DigestInputStream) lIS).getMessageDigest().digest())) toret = null; int hashLength;
if (hasHash) {
hashLength = readByteIntegerFromInputStream(inputStream);
if (hashProvider != null && hashProvider.getLength() != hashLength) {
readArrayFromInputStream(inputStream, hashLength);
return null;
}
} else hashLength = 0;
if ((hasHash || oldPacketFormat) && hashProvider != null) {
if (!DigestComparer.compareDigests(inputStream, ((DigestInputStream) lIS).getMessageDigest().digest())) toret = null;
} else readArrayFromInputStream(inputStream, hashLength);
if (isPacketInvalid(toret)) toret = null; if (isPacketInvalid(toret)) toret = null;
} else if (toret != null) { } else if (toret != null) {
return readPacket(inputStream, factory, information); return readPacket(inputStream, factory, information);
@ -225,6 +312,7 @@ public class PacketLoader {
if (toret instanceof IStreamedPacket) { if (toret instanceof IStreamedPacket) {
int length = readInteger(inputStream); int length = readInteger(inputStream);
if (length < 0) length = subtractMostSignificantFlag(length);
((IStreamedPacket) toret).writeData(inputStream, length); ((IStreamedPacket) toret).writeData(inputStream, length);
if (isPacketInvalid(toret)) toret = null; if (isPacketInvalid(toret)) toret = null;
} else if (toret != null) { } else if (toret != null) {
@ -286,15 +374,25 @@ public class PacketLoader {
if (writeInformation) savePacketProtocolInformation(outputStream, packet.getProtocol()); if (writeInformation) savePacketProtocolInformation(outputStream, packet.getProtocol());
if (packet instanceof IStreamedPacket) { if (packet instanceof IStreamedPacket) {
writeInteger(outputStream, ((IStreamedPacket) packet).getSize()); int pLength = ((IStreamedPacket) packet).getSize();
if (hashProvider != null && !oldPacketFormat) pLength = addMostSignificantFlag(pLength);
writeInteger(outputStream, pLength);
OutputStream lOS = (hashProvider == null) ? outputStream : hashProvider.getDigestOutputStream(outputStream); OutputStream lOS = (hashProvider == null) ? outputStream : hashProvider.getDigestOutputStream(outputStream);
((IStreamedPacket) packet).readData(lOS); ((IStreamedPacket) packet).readData(lOS);
if (hashProvider != null) outputStream.write(((DigestOutputStream) lOS).getMessageDigest().digest()); if (hashProvider != null) {
if (!oldPacketFormat) outputStream.write(hashProvider.getLength());
outputStream.write(((DigestOutputStream) lOS).getMessageDigest().digest());
}
} else { } else {
byte[] saveArray = packet.savePayload(); byte[] saveArray = packet.savePayload();
writeInteger(outputStream, saveArray.length); int pLength = saveArray.length;
if (hashProvider != null && !oldPacketFormat) pLength = addMostSignificantFlag(pLength);
writeInteger(outputStream, pLength);
outputStream.write(saveArray); outputStream.write(saveArray);
if (hashProvider != null) outputStream.write(hashProvider.getDigestOf(saveArray)); if (hashProvider != null) {
if (!oldPacketFormat) outputStream.write(hashProvider.getLength());
outputStream.write(hashProvider.getDigestOf(saveArray));
}
} }
outputStream.flush(); outputStream.flush();
} }
@ -338,10 +436,10 @@ public class PacketLoader {
*/ */
public static int readInteger(InputStream inputStream) throws IOException { public static int readInteger(InputStream inputStream) throws IOException {
if (inputStream == null) throw new NullPointerException("inputStream is null"); if (inputStream == null) throw new NullPointerException("inputStream is null");
int length = (readByteFromInputStream(inputStream) & 0xff) * 16777216; int length = readByteIntegerFromInputStream(inputStream)* 16777216;
length += (readByteFromInputStream(inputStream) & 0xff) * 65536; length += readByteIntegerFromInputStream(inputStream) * 65536;
length += (readByteFromInputStream(inputStream) & 0xff) * 256; length += readByteIntegerFromInputStream(inputStream) * 256;
length += (readByteFromInputStream(inputStream) & 0xff); length += readByteIntegerFromInputStream(inputStream);
return length; return length;
} }
@ -368,6 +466,7 @@ public class PacketLoader {
/** /**
* Reads a byte from an {@link InputStream}. * Reads a byte from an {@link InputStream}.
* See also: {@link #readByteIntegerFromInputStream(InputStream)}.
* *
* @param inputStream The input stream to read from. * @param inputStream The input stream to read from.
* @return The byte read. * @return The byte read.
@ -381,6 +480,22 @@ public class PacketLoader {
return (byte) toret; return (byte) toret;
} }
/**
* Reads a byte (In int form) from an {@link InputStream}.
* See also: {@link #readByteFromInputStream(InputStream)}.
*
* @param inputStream The input stream to read from.
* @return The byte read (As an int).
* @throws NullPointerException inputStream is null.
* @throws IOException An I/O error has occurred or end of stream has been reached.
*/
public static int readByteIntegerFromInputStream(InputStream inputStream) throws IOException {
if (inputStream == null) throw new NullPointerException("inputStream is null");
int toret;
if ((toret = inputStream.read()) == -1) throw new IOException("inputStream end of stream");
return toret;
}
/** /**
* Reads in a byte array of a specified length from an {@link InputStream}. * Reads in a byte array of a specified length from an {@link InputStream}.
* *

View File

@ -12,7 +12,7 @@ import java.io.InputStream;
import java.io.OutputStream; import java.io.OutputStream;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import static com.captainalm.lib.calmnet.packet.PacketLoader.readByteFromInputStream; import static com.captainalm.lib.calmnet.packet.PacketLoader.readByteIntegerFromInputStream;
/** /**
* This class provides an encrypted packet that can hold an {@link IPacket}. * This class provides an encrypted packet that can hold an {@link IPacket}.
@ -358,7 +358,7 @@ public class EncryptedPacket implements IStreamedPacket, IInternalCache {
if (size < 0) throw new IllegalArgumentException("size is less than 0"); if (size < 0) throw new IllegalArgumentException("size is less than 0");
synchronized (slock) { synchronized (slock) {
if (size < 1) throw new IOException("inputStream end of stream"); if (size < 1) throw new IOException("inputStream end of stream");
int flag = readByteFromInputStream(inputStream) & 0xff; int flag = readByteIntegerFromInputStream(inputStream);
if (size < 5) throw new IOException("inputStream end of stream"); if (size < 5) throw new IOException("inputStream end of stream");
int cipherLenCache = PacketLoader.readInteger(inputStream); int cipherLenCache = PacketLoader.readInteger(inputStream);
@ -380,7 +380,7 @@ public class EncryptedPacket implements IStreamedPacket, IInternalCache {
trailingArrayLengthCache = 0; trailingArrayLengthCache = 0;
if ((flag & 1) == 1) { if ((flag & 1) == 1) {
if (size < 9 + cipherLenCache) throw new IOException("inputStream end of stream"); if (size < 9 + cipherLenCache) throw new IOException("inputStream end of stream");
trailingArrayLengthCache = PacketLoader.readByteFromInputStream(inputStream); trailingArrayLengthCache = PacketLoader.readByteIntegerFromInputStream(inputStream);
if (trailingArrayLengthCache < 1) throw new PacketException("trailer length less than 1"); if (trailingArrayLengthCache < 1) throw new PacketException("trailer length less than 1");
} }

View File

@ -283,19 +283,19 @@ public final class FragmentReceiver {
} }
/** /**
* Gets whether responses should be verified. * Gets whether responses should be verified by sending back the payload to be verified.
* *
* @return Should responses be verified. * @return Should responses be verified by sending back the payload.
*/ */
public boolean shouldVerifyResponses() { public boolean shouldVerifyResponses() {
return verifyResponses; return verifyResponses;
} }
/** /**
* Sets whether responses should be verified. * Sets whether responses should be verified by sending back the payload to be verified.
* If set to false, {@link #setSentDataWillBeAllVerified(boolean)} will be set to false too. * If set to false, {@link #setSentDataWillBeAllVerified(boolean)} will be set to false too.
* *
* @param state If responses should be verified. * @param state If responses should be verified by sending back the payload.
*/ */
public void setResponseVerification(boolean state) { public void setResponseVerification(boolean state) {
synchronized (slock) { synchronized (slock) {
@ -305,19 +305,19 @@ public final class FragmentReceiver {
} }
/** /**
* Gets whether all sent fragments are verified to be equal. * Gets whether all sent fragments are expected to be verified.
* *
* @return If all sent fragments will be verified to be equal. * @return If all sent fragments are expected to be verified.
*/ */
public boolean shouldSentDataBeAllVerified() { public boolean shouldSentDataBeAllVerified() {
return makeSureSendDataVerified; return makeSureSendDataVerified;
} }
/** /**
* Gets whether all sent fragments are verified to be equal. * Gets whether all sent fragments are expected to be verified.
* Requires {@link #setResponseVerification(boolean)} set to true. * Requires {@link #setResponseVerification(boolean)} set to true.
* *
* @param state If all sent fragments will be verified to be equal. * @param state If all sent fragments are expected to be verified.
*/ */
public void setSentDataWillBeAllVerified(boolean state) { public void setSentDataWillBeAllVerified(boolean state) {
synchronized (slock) { synchronized (slock) {

View File

@ -18,7 +18,7 @@ public final class FragmentSender {
private final HashMap<Integer, FragmentOutput> registry = new HashMap<>(); private final HashMap<Integer, FragmentOutput> registry = new HashMap<>();
private final Object slock = new Object(); private final Object slock = new Object();
private final Object slockfinish = new Object(); private final Object slockfinish = new Object();
private int splitSize = 496; private int splitSize = 448;
private PacketLoader packetLoader; private PacketLoader packetLoader;
private boolean verifyResponses = false; private boolean verifyResponses = false;
private boolean makeSureSendDataVerified = false; private boolean makeSureSendDataVerified = false;
@ -249,7 +249,7 @@ public final class FragmentSender {
} }
/** /**
* Gets whether responses should be verified. * Gets whether responses should be verified by checking if they are equal.
* *
* @return Should responses be verified. * @return Should responses be verified.
*/ */
@ -258,7 +258,7 @@ public final class FragmentSender {
} }
/** /**
* Sets whether responses should be verified. * Sets whether responses should be verified by checking if they are equal.
* If set to false, {@link #setSentDataWillBeAllVerified(boolean)} will be set to false too. * If set to false, {@link #setSentDataWillBeAllVerified(boolean)} will be set to false too.
* *
* @param state If responses should be verified. * @param state If responses should be verified.
@ -271,19 +271,19 @@ public final class FragmentSender {
} }
/** /**
* Gets whether all sent fragments are verified to be equal. * Gets whether all sent fragments are verified via resend checks for equality.
* *
* @return If all sent fragments will be verified to be equal. * @return If all sent fragments will be verified via resend checks for equality.
*/ */
public boolean shouldSentDataBeAllVerified() { public boolean shouldSentDataBeAllVerified() {
return makeSureSendDataVerified; return makeSureSendDataVerified;
} }
/** /**
* Gets whether all sent fragments are verified to be equal. * Gets whether all sent fragments are verified via resend checks for equality.
* Requires {@link #setResponseVerification(boolean)} set to true. * Requires {@link #setResponseVerification(boolean)} set to true.
* *
* @param state If all sent fragments will be verified to be equal. * @param state If all sent fragments will be verified via resend checks for equality.
*/ */
public void setSentDataWillBeAllVerified(boolean state) { public void setSentDataWillBeAllVerified(boolean state) {
synchronized (slock) { synchronized (slock) {