Rename TruncatingInputStream -> LimitedInputStream.

This commit is contained in:
Greyson Parrelli 2024-09-06 08:13:50 -04:00 committed by Cody Henthorne
parent a6767e4f8a
commit a8fb4eb21a
7 changed files with 55 additions and 47 deletions

View file

@ -9,8 +9,8 @@ import com.google.common.io.CountingInputStream
import org.signal.core.util.readFully import org.signal.core.util.readFully
import org.signal.core.util.readNBytesOrThrow import org.signal.core.util.readNBytesOrThrow
import org.signal.core.util.readVarInt32 import org.signal.core.util.readVarInt32
import org.signal.core.util.stream.LimitedInputStream
import org.signal.core.util.stream.MacInputStream import org.signal.core.util.stream.MacInputStream
import org.signal.core.util.stream.TruncatingInputStream
import org.thoughtcrime.securesms.backup.v2.proto.BackupInfo import org.thoughtcrime.securesms.backup.v2.proto.BackupInfo
import org.thoughtcrime.securesms.backup.v2.proto.Frame import org.thoughtcrime.securesms.backup.v2.proto.Frame
import org.whispersystems.signalservice.api.backup.BackupKey import org.whispersystems.signalservice.api.backup.BackupKey
@ -56,7 +56,7 @@ class EncryptedBackupReader(
stream = GZIPInputStream( stream = GZIPInputStream(
CipherInputStream( CipherInputStream(
TruncatingInputStream( LimitedInputStream(
wrapped = countingStream, wrapped = countingStream,
maxBytes = length - MAC_SIZE maxBytes = length - MAC_SIZE
), ),
@ -121,7 +121,7 @@ class EncryptedBackupReader(
} }
val macStream = MacInputStream( val macStream = MacInputStream(
wrapped = TruncatingInputStream(dataStream, maxBytes = streamLength - MAC_SIZE), wrapped = LimitedInputStream(dataStream, maxBytes = streamLength - MAC_SIZE),
mac = mac mac = mac
) )

View file

@ -19,7 +19,7 @@ package org.thoughtcrime.securesms.crypto;
import androidx.annotation.NonNull; import androidx.annotation.NonNull;
import org.signal.core.util.logging.Log; import org.signal.core.util.logging.Log;
import org.signal.core.util.stream.TruncatingInputStream; import org.signal.core.util.stream.LimitedInputStream;
import org.thoughtcrime.securesms.util.Util; import org.thoughtcrime.securesms.util.Util;
import java.io.File; import java.io.File;
@ -63,7 +63,7 @@ public class ClassicDecryptingPartInputStream {
IvParameterSpec iv = new IvParameterSpec(ivBytes); IvParameterSpec iv = new IvParameterSpec(ivBytes);
cipher.init(Cipher.DECRYPT_MODE, new SecretKeySpec(attachmentSecret.getClassicCipherKey(), "AES"), iv); cipher.init(Cipher.DECRYPT_MODE, new SecretKeySpec(attachmentSecret.getClassicCipherKey(), "AES"), iv);
return new CipherInputStreamWrapper(new TruncatingInputStream(fileStream, file.length() - MAC_LENGTH - IV_LENGTH), cipher); return new CipherInputStreamWrapper(new LimitedInputStream(fileStream, file.length() - MAC_LENGTH - IV_LENGTH), cipher);
} catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException | InvalidAlgorithmParameterException e) { } catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException | InvalidAlgorithmParameterException e) {
throw new AssertionError(e); throw new AssertionError(e);
} }
@ -72,7 +72,7 @@ public class ClassicDecryptingPartInputStream {
private static void verifyMac(AttachmentSecret attachmentSecret, File file) throws IOException { private static void verifyMac(AttachmentSecret attachmentSecret, File file) throws IOException {
Mac mac = initializeMac(new SecretKeySpec(attachmentSecret.getClassicMacKey(), "HmacSHA1")); Mac mac = initializeMac(new SecretKeySpec(attachmentSecret.getClassicMacKey(), "HmacSHA1"));
FileInputStream macStream = new FileInputStream(file); FileInputStream macStream = new FileInputStream(file);
InputStream dataStream = new TruncatingInputStream(new FileInputStream(file), file.length() - MAC_LENGTH); InputStream dataStream = new LimitedInputStream(new FileInputStream(file), file.length() - MAC_LENGTH);
byte[] theirMac = new byte[MAC_LENGTH]; byte[] theirMac = new byte[MAC_LENGTH];
if (macStream.skip(file.length() - MAC_LENGTH) != file.length() - MAC_LENGTH) { if (macStream.skip(file.length() - MAC_LENGTH) != file.length() - MAC_LENGTH) {

View file

@ -15,21 +15,22 @@ import kotlin.math.min
/** /**
* An [InputStream] that will read from the target [InputStream] until it reaches the end, or until it has read [maxBytes] bytes. * An [InputStream] that will read from the target [InputStream] until it reaches the end, or until it has read [maxBytes] bytes.
*/ */
class TruncatingInputStream(private val wrapped: InputStream, private val maxBytes: Long) : FilterInputStream(wrapped) { class LimitedInputStream(private val wrapped: InputStream, private val maxBytes: Long) : FilterInputStream(wrapped) {
private var bytesRead: Long = 0 private var totalBytesRead: Long = 0
private var lastMark = -1L private var lastMark = -1L
override fun read(): Int { override fun read(): Int {
if (bytesRead >= maxBytes) { if (totalBytesRead >= maxBytes) {
return -1 return -1
} }
return wrapped.read().also { val read = wrapped.read()
if (it >= 0) { if (read >= 0) {
bytesRead++ totalBytesRead++
}
} }
return read
} }
override fun read(destination: ByteArray): Int { override fun read(destination: ByteArray): Int {
@ -37,34 +38,33 @@ class TruncatingInputStream(private val wrapped: InputStream, private val maxByt
} }
override fun read(destination: ByteArray, offset: Int, length: Int): Int { override fun read(destination: ByteArray, offset: Int, length: Int): Int {
if (bytesRead >= maxBytes) { if (totalBytesRead >= maxBytes) {
return -1 return -1
} }
val bytesRemaining: Long = maxBytes - bytesRead val bytesRemaining: Long = maxBytes - totalBytesRead
val bytesToRead: Int = if (bytesRemaining > length) length else Math.toIntExact(bytesRemaining) val bytesToRead: Int = min(length, Math.toIntExact(bytesRemaining))
val bytesRead = wrapped.read(destination, offset, bytesToRead) val bytesRead = wrapped.read(destination, offset, bytesToRead)
if (bytesRead > 0) { if (bytesRead > 0) {
this.bytesRead += bytesRead totalBytesRead += bytesRead
} }
return bytesRead return bytesRead
} }
override fun skip(requestedSkipCount: Long): Long { override fun skip(requestedSkipCount: Long): Long {
val bytesRemaining: Long = maxBytes - bytesRead val bytesRemaining: Long = maxBytes - totalBytesRead
val bytesToSkip: Long = min(bytesRemaining, requestedSkipCount) val bytesToSkip: Long = min(bytesRemaining, requestedSkipCount)
val skipCount = super.skip(bytesToSkip)
return super.skip(bytesToSkip).also { bytesSkipped -> totalBytesRead += skipCount
if (bytesSkipped > 0) {
this.bytesRead += bytesSkipped return skipCount
}
}
} }
override fun available(): Int { override fun available(): Int {
val bytesRemaining = Math.toIntExact(maxBytes - bytesRead) val bytesRemaining = Math.toIntExact(maxBytes - totalBytesRead)
return min(bytesRemaining, wrapped.available()) return min(bytesRemaining, wrapped.available())
} }
@ -78,7 +78,7 @@ class TruncatingInputStream(private val wrapped: InputStream, private val maxByt
} }
wrapped.mark(readlimit) wrapped.mark(readlimit)
lastMark = bytesRead lastMark = totalBytesRead
} }
override fun reset() { override fun reset() {
@ -91,7 +91,7 @@ class TruncatingInputStream(private val wrapped: InputStream, private val maxByt
} }
wrapped.reset() wrapped.reset()
bytesRead = lastMark totalBytesRead = lastMark
} }
/** /**
@ -100,6 +100,10 @@ class TruncatingInputStream(private val wrapped: InputStream, private val maxByt
* @param byteLimit The maximum number of truncated bytes to read. Defaults to no limit. * @param byteLimit The maximum number of truncated bytes to read. Defaults to no limit.
*/ */
fun readTruncatedBytes(byteLimit: Int = -1): ByteArray { fun readTruncatedBytes(byteLimit: Int = -1): ByteArray {
if (totalBytesRead < maxBytes) {
throw IllegalStateException("Stream has not been fully read")
}
return if (byteLimit < 0) { return if (byteLimit < 0) {
wrapped.readFully() wrapped.readFully()
} else { } else {

View file

@ -10,11 +10,11 @@ import org.junit.Test
import org.signal.core.util.readFully import org.signal.core.util.readFully
import org.signal.core.util.readNBytesOrThrow import org.signal.core.util.readNBytesOrThrow
class TruncatingInputStreamTest { class LimitedInputStreamTest {
@Test @Test
fun `when I fully read the stream via a buffer, I should only get maxBytes`() { fun `when I fully read the stream via a buffer, I should only get maxBytes`() {
val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75) val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75)
val data = inputStream.readFully() val data = inputStream.readFully()
assertEquals(75, data.size) assertEquals(75, data.size)
@ -22,7 +22,7 @@ class TruncatingInputStreamTest {
@Test @Test
fun `when I fully read the stream one byte at a time, I should only get maxBytes`() { fun `when I fully read the stream one byte at a time, I should only get maxBytes`() {
val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75) val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75)
var count = 0 var count = 0
var lastRead = inputStream.read() var lastRead = inputStream.read()
@ -36,7 +36,7 @@ class TruncatingInputStreamTest {
@Test @Test
fun `when I skip past the maxBytes, I should get -1`() { fun `when I skip past the maxBytes, I should get -1`() {
val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75) val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75)
val skipCount = inputStream.skip(100) val skipCount = inputStream.skip(100)
val read = inputStream.read() val read = inputStream.read()
@ -47,7 +47,7 @@ class TruncatingInputStreamTest {
@Test @Test
fun `when I skip, I should still truncate correctly afterwards`() { fun `when I skip, I should still truncate correctly afterwards`() {
val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75) val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75)
val skipCount = inputStream.skip(50) val skipCount = inputStream.skip(50)
val data = inputStream.readFully() val data = inputStream.readFully()
@ -58,7 +58,7 @@ class TruncatingInputStreamTest {
@Test @Test
fun `when I skip more than maxBytes, I only skip maxBytes`() { fun `when I skip more than maxBytes, I only skip maxBytes`() {
val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75) val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75)
val skipCount = inputStream.skip(100) val skipCount = inputStream.skip(100)
@ -67,7 +67,7 @@ class TruncatingInputStreamTest {
@Test @Test
fun `when I finish reading the stream, getTruncatedBytes gives me the rest`() { fun `when I finish reading the stream, getTruncatedBytes gives me the rest`() {
val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75) val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75)
inputStream.readFully() inputStream.readFully()
val truncatedBytes = inputStream.readTruncatedBytes() val truncatedBytes = inputStream.readTruncatedBytes()
@ -76,16 +76,22 @@ class TruncatingInputStreamTest {
@Test @Test
fun `when I finish reading the stream, getTruncatedBytes gives me the rest, respecting the byte limit`() { fun `when I finish reading the stream, getTruncatedBytes gives me the rest, respecting the byte limit`() {
val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75) val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75)
inputStream.readFully() inputStream.readFully()
val truncatedBytes = inputStream.readTruncatedBytes(byteLimit = 10) val truncatedBytes = inputStream.readTruncatedBytes(byteLimit = 10)
assertEquals(10, truncatedBytes.size) assertEquals(10, truncatedBytes.size)
} }
@Test(expected = IllegalStateException::class)
fun `if I have not finished reading the stream, getTruncatedBytes throws IllegalStateException`() {
val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75)
inputStream.readTruncatedBytes()
}
@Test @Test
fun `when I call available, it should respect the maxBytes`() { fun `when I call available, it should respect the maxBytes`() {
val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75) val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75)
val available = inputStream.available() val available = inputStream.available()
assertEquals(75, available) assertEquals(75, available)
@ -93,7 +99,7 @@ class TruncatingInputStreamTest {
@Test @Test
fun `when I call available after reading some bytes, it should respect the maxBytes`() { fun `when I call available after reading some bytes, it should respect the maxBytes`() {
val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75) val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75)
inputStream.readNBytesOrThrow(50) inputStream.readNBytesOrThrow(50)
val available = inputStream.available() val available = inputStream.available()
@ -103,7 +109,7 @@ class TruncatingInputStreamTest {
@Test @Test
fun `when I mark and reset, it should jump back to the correct position`() { fun `when I mark and reset, it should jump back to the correct position`() {
val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75) val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75)
inputStream.mark(100) inputStream.mark(100)
inputStream.readNBytesOrThrow(10) inputStream.readNBytesOrThrow(10)

View file

@ -6,7 +6,7 @@
package org.whispersystems.signalservice.api.crypto; package org.whispersystems.signalservice.api.crypto;
import org.signal.core.util.stream.TruncatingInputStream; import org.signal.core.util.stream.LimitedInputStream;
import org.signal.libsignal.protocol.InvalidMessageException; import org.signal.libsignal.protocol.InvalidMessageException;
import org.signal.libsignal.protocol.incrementalmac.ChunkSizeChoice; import org.signal.libsignal.protocol.incrementalmac.ChunkSizeChoice;
import org.signal.libsignal.protocol.incrementalmac.IncrementalMacInputStream; import org.signal.libsignal.protocol.incrementalmac.IncrementalMacInputStream;
@ -117,7 +117,7 @@ public class AttachmentCipherInputStream extends FilterInputStream {
InputStream inputStream = new AttachmentCipherInputStream(wrappedStream, parts[0], streamLength - BLOCK_SIZE - mac.getMacLength()); InputStream inputStream = new AttachmentCipherInputStream(wrappedStream, parts[0], streamLength - BLOCK_SIZE - mac.getMacLength());
if (plaintextLength != 0) { if (plaintextLength != 0) {
inputStream = new TruncatingInputStream(inputStream, plaintextLength); inputStream = new LimitedInputStream(inputStream, plaintextLength);
} }
return inputStream; return inputStream;
@ -142,7 +142,7 @@ public class AttachmentCipherInputStream extends FilterInputStream {
InputStream inputStream = new AttachmentCipherInputStream(new FileInputStream(file), archivedMediaKeyMaterial.getCipherKey(), file.length() - BLOCK_SIZE - mac.getMacLength()); InputStream inputStream = new AttachmentCipherInputStream(new FileInputStream(file), archivedMediaKeyMaterial.getCipherKey(), file.length() - BLOCK_SIZE - mac.getMacLength());
if (originalCipherTextLength != 0) { if (originalCipherTextLength != 0) {
inputStream = new TruncatingInputStream(inputStream, originalCipherTextLength); inputStream = new LimitedInputStream(inputStream, originalCipherTextLength);
} }
return inputStream; return inputStream;
@ -179,7 +179,7 @@ public class AttachmentCipherInputStream extends FilterInputStream {
InputStream inputStream = new AttachmentCipherInputStream(wrappedStream, parts[0], file.length() - BLOCK_SIZE - mac.getMacLength()); InputStream inputStream = new AttachmentCipherInputStream(wrappedStream, parts[0], file.length() - BLOCK_SIZE - mac.getMacLength());
if (plaintextLength != 0) { if (plaintextLength != 0) {
inputStream = new TruncatingInputStream(inputStream, plaintextLength); inputStream = new LimitedInputStream(inputStream, plaintextLength);
} }
return inputStream; return inputStream;

View file

@ -6,15 +6,13 @@
package org.whispersystems.signalservice.api.messages.multidevice; package org.whispersystems.signalservice.api.messages.multidevice;
import org.signal.core.util.stream.TruncatingInputStream; import org.signal.core.util.stream.LimitedInputStream;
import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.InvalidKeyException; import org.signal.libsignal.protocol.InvalidKeyException;
import org.signal.libsignal.protocol.InvalidMessageException; import org.signal.libsignal.protocol.InvalidMessageException;
import org.signal.libsignal.protocol.logging.Log; import org.signal.libsignal.protocol.logging.Log;
import org.signal.libsignal.zkgroup.InvalidInputException; import org.signal.libsignal.zkgroup.InvalidInputException;
import org.signal.libsignal.zkgroup.profiles.ProfileKey; import org.signal.libsignal.zkgroup.profiles.ProfileKey;
import org.whispersystems.signalservice.api.messages.SignalServiceAttachment;
import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentStream;
import org.whispersystems.signalservice.api.push.ServiceId; import org.whispersystems.signalservice.api.push.ServiceId;
import org.whispersystems.signalservice.api.push.ServiceId.ACI; import org.whispersystems.signalservice.api.push.ServiceId.ACI;
import org.whispersystems.signalservice.api.push.SignalServiceAddress; import org.whispersystems.signalservice.api.push.SignalServiceAddress;
@ -62,7 +60,7 @@ public class DeviceContactsInputStream extends ChunkedInputStream {
if (details.avatar != null && details.avatar.length != null) { if (details.avatar != null && details.avatar.length != null) {
long avatarLength = details.avatar.length; long avatarLength = details.avatar.length;
InputStream avatarStream = new TruncatingInputStream(in, avatarLength); InputStream avatarStream = new LimitedInputStream(in, avatarLength);
String avatarContentType = details.avatar.contentType != null ? details.avatar.contentType : "image/*"; String avatarContentType = details.avatar.contentType != null ? details.avatar.contentType : "image/*";
avatar = Optional.of(new DeviceContactAvatar(avatarStream, avatarLength, avatarContentType)); avatar = Optional.of(new DeviceContactAvatar(avatarStream, avatarLength, avatarContentType));

View file

@ -6,7 +6,7 @@
package org.thoughtcrime.securesms.video.postprocessing package org.thoughtcrime.securesms.video.postprocessing
import org.signal.core.util.readLength import org.signal.core.util.readLength
import org.signal.core.util.stream.TruncatingInputStream import org.signal.core.util.stream.LimitedInputStream
import org.signal.libsignal.media.Mp4Sanitizer import org.signal.libsignal.media.Mp4Sanitizer
import org.signal.libsignal.media.SanitizedMetadata import org.signal.libsignal.media.SanitizedMetadata
import org.thoughtcrime.securesms.video.exceptions.VideoPostProcessingException import org.thoughtcrime.securesms.video.exceptions.VideoPostProcessingException
@ -34,7 +34,7 @@ class Mp4FaststartPostProcessor(private val inputStreamFactory: InputStreamFacto
} }
val inputStream = inputStreamFactory.create() val inputStream = inputStreamFactory.create()
inputStream.skip(metadata.dataOffset) inputStream.skip(metadata.dataOffset)
return SequenceInputStream(ByteArrayInputStream(metadata.sanitizedMetadata), TruncatingInputStream(inputStream, metadata.dataLength)) return SequenceInputStream(ByteArrayInputStream(metadata.sanitizedMetadata), LimitedInputStream(inputStream, metadata.dataLength))
} }
fun processAndWriteTo(outputStream: OutputStream, inputLength: Long = calculateStreamLength(inputStreamFactory.create())): Long { fun processAndWriteTo(outputStream: OutputStream, inputLength: Long = calculateStreamLength(inputStreamFactory.create())): Long {