diff --git a/app/src/main/java/org/thoughtcrime/securesms/backup/v2/stream/EncryptedBackupReader.kt b/app/src/main/java/org/thoughtcrime/securesms/backup/v2/stream/EncryptedBackupReader.kt index 1ce0702175..cf9c341eee 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/backup/v2/stream/EncryptedBackupReader.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/backup/v2/stream/EncryptedBackupReader.kt @@ -9,8 +9,8 @@ import com.google.common.io.CountingInputStream import org.signal.core.util.readFully import org.signal.core.util.readNBytesOrThrow import org.signal.core.util.readVarInt32 +import org.signal.core.util.stream.LimitedInputStream import org.signal.core.util.stream.MacInputStream -import org.signal.core.util.stream.TruncatingInputStream import org.thoughtcrime.securesms.backup.v2.proto.BackupInfo import org.thoughtcrime.securesms.backup.v2.proto.Frame import org.whispersystems.signalservice.api.backup.BackupKey @@ -56,7 +56,7 @@ class EncryptedBackupReader( stream = GZIPInputStream( CipherInputStream( - TruncatingInputStream( + LimitedInputStream( wrapped = countingStream, maxBytes = length - MAC_SIZE ), @@ -121,7 +121,7 @@ class EncryptedBackupReader( } val macStream = MacInputStream( - wrapped = TruncatingInputStream(dataStream, maxBytes = streamLength - MAC_SIZE), + wrapped = LimitedInputStream(dataStream, maxBytes = streamLength - MAC_SIZE), mac = mac ) diff --git a/app/src/main/java/org/thoughtcrime/securesms/crypto/ClassicDecryptingPartInputStream.java b/app/src/main/java/org/thoughtcrime/securesms/crypto/ClassicDecryptingPartInputStream.java index 061c844eb4..ea1c273e3a 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/crypto/ClassicDecryptingPartInputStream.java +++ b/app/src/main/java/org/thoughtcrime/securesms/crypto/ClassicDecryptingPartInputStream.java @@ -19,7 +19,7 @@ package org.thoughtcrime.securesms.crypto; import androidx.annotation.NonNull; import org.signal.core.util.logging.Log; -import org.signal.core.util.stream.TruncatingInputStream; +import org.signal.core.util.stream.LimitedInputStream; import org.thoughtcrime.securesms.util.Util; import java.io.File; @@ -63,7 +63,7 @@ public class ClassicDecryptingPartInputStream { IvParameterSpec iv = new IvParameterSpec(ivBytes); cipher.init(Cipher.DECRYPT_MODE, new SecretKeySpec(attachmentSecret.getClassicCipherKey(), "AES"), iv); - return new CipherInputStreamWrapper(new TruncatingInputStream(fileStream, file.length() - MAC_LENGTH - IV_LENGTH), cipher); + return new CipherInputStreamWrapper(new LimitedInputStream(fileStream, file.length() - MAC_LENGTH - IV_LENGTH), cipher); } catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException | InvalidAlgorithmParameterException e) { throw new AssertionError(e); } @@ -72,7 +72,7 @@ public class ClassicDecryptingPartInputStream { private static void verifyMac(AttachmentSecret attachmentSecret, File file) throws IOException { Mac mac = initializeMac(new SecretKeySpec(attachmentSecret.getClassicMacKey(), "HmacSHA1")); FileInputStream macStream = new FileInputStream(file); - InputStream dataStream = new TruncatingInputStream(new FileInputStream(file), file.length() - MAC_LENGTH); + InputStream dataStream = new LimitedInputStream(new FileInputStream(file), file.length() - MAC_LENGTH); byte[] theirMac = new byte[MAC_LENGTH]; if (macStream.skip(file.length() - MAC_LENGTH) != file.length() - MAC_LENGTH) { diff --git a/core-util-jvm/src/main/java/org/signal/core/util/stream/TruncatingInputStream.kt b/core-util-jvm/src/main/java/org/signal/core/util/stream/LimitedInputStream.kt similarity index 69% rename from core-util-jvm/src/main/java/org/signal/core/util/stream/TruncatingInputStream.kt rename to core-util-jvm/src/main/java/org/signal/core/util/stream/LimitedInputStream.kt index 5760e0c508..47b240a455 100644 --- a/core-util-jvm/src/main/java/org/signal/core/util/stream/TruncatingInputStream.kt +++ b/core-util-jvm/src/main/java/org/signal/core/util/stream/LimitedInputStream.kt @@ -15,21 +15,22 @@ import kotlin.math.min /** * An [InputStream] that will read from the target [InputStream] until it reaches the end, or until it has read [maxBytes] bytes. */ -class TruncatingInputStream(private val wrapped: InputStream, private val maxBytes: Long) : FilterInputStream(wrapped) { +class LimitedInputStream(private val wrapped: InputStream, private val maxBytes: Long) : FilterInputStream(wrapped) { - private var bytesRead: Long = 0 + private var totalBytesRead: Long = 0 private var lastMark = -1L override fun read(): Int { - if (bytesRead >= maxBytes) { + if (totalBytesRead >= maxBytes) { return -1 } - return wrapped.read().also { - if (it >= 0) { - bytesRead++ - } + val read = wrapped.read() + if (read >= 0) { + totalBytesRead++ } + + return read } override fun read(destination: ByteArray): Int { @@ -37,34 +38,33 @@ class TruncatingInputStream(private val wrapped: InputStream, private val maxByt } override fun read(destination: ByteArray, offset: Int, length: Int): Int { - if (bytesRead >= maxBytes) { + if (totalBytesRead >= maxBytes) { return -1 } - val bytesRemaining: Long = maxBytes - bytesRead - val bytesToRead: Int = if (bytesRemaining > length) length else Math.toIntExact(bytesRemaining) + val bytesRemaining: Long = maxBytes - totalBytesRead + val bytesToRead: Int = min(length, Math.toIntExact(bytesRemaining)) val bytesRead = wrapped.read(destination, offset, bytesToRead) if (bytesRead > 0) { - this.bytesRead += bytesRead + totalBytesRead += bytesRead } return bytesRead } override fun skip(requestedSkipCount: Long): Long { - val bytesRemaining: Long = maxBytes - bytesRead + val bytesRemaining: Long = maxBytes - totalBytesRead val bytesToSkip: Long = min(bytesRemaining, requestedSkipCount) + val skipCount = super.skip(bytesToSkip) - return super.skip(bytesToSkip).also { bytesSkipped -> - if (bytesSkipped > 0) { - this.bytesRead += bytesSkipped - } - } + totalBytesRead += skipCount + + return skipCount } override fun available(): Int { - val bytesRemaining = Math.toIntExact(maxBytes - bytesRead) + val bytesRemaining = Math.toIntExact(maxBytes - totalBytesRead) return min(bytesRemaining, wrapped.available()) } @@ -78,7 +78,7 @@ class TruncatingInputStream(private val wrapped: InputStream, private val maxByt } wrapped.mark(readlimit) - lastMark = bytesRead + lastMark = totalBytesRead } override fun reset() { @@ -91,7 +91,7 @@ class TruncatingInputStream(private val wrapped: InputStream, private val maxByt } wrapped.reset() - bytesRead = lastMark + totalBytesRead = lastMark } /** @@ -100,6 +100,10 @@ class TruncatingInputStream(private val wrapped: InputStream, private val maxByt * @param byteLimit The maximum number of truncated bytes to read. Defaults to no limit. */ fun readTruncatedBytes(byteLimit: Int = -1): ByteArray { + if (totalBytesRead < maxBytes) { + throw IllegalStateException("Stream has not been fully read") + } + return if (byteLimit < 0) { wrapped.readFully() } else { diff --git a/core-util-jvm/src/test/java/org/signal/core/util/stream/TruncatingInputStreamTest.kt b/core-util-jvm/src/test/java/org/signal/core/util/stream/LimitedInputStreamTest.kt similarity index 67% rename from core-util-jvm/src/test/java/org/signal/core/util/stream/TruncatingInputStreamTest.kt rename to core-util-jvm/src/test/java/org/signal/core/util/stream/LimitedInputStreamTest.kt index b51fea4c54..7879463749 100644 --- a/core-util-jvm/src/test/java/org/signal/core/util/stream/TruncatingInputStreamTest.kt +++ b/core-util-jvm/src/test/java/org/signal/core/util/stream/LimitedInputStreamTest.kt @@ -10,11 +10,11 @@ import org.junit.Test import org.signal.core.util.readFully import org.signal.core.util.readNBytesOrThrow -class TruncatingInputStreamTest { +class LimitedInputStreamTest { @Test fun `when I fully read the stream via a buffer, I should only get maxBytes`() { - val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75) + val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75) val data = inputStream.readFully() assertEquals(75, data.size) @@ -22,7 +22,7 @@ class TruncatingInputStreamTest { @Test fun `when I fully read the stream one byte at a time, I should only get maxBytes`() { - val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75) + val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75) var count = 0 var lastRead = inputStream.read() @@ -36,7 +36,7 @@ class TruncatingInputStreamTest { @Test fun `when I skip past the maxBytes, I should get -1`() { - val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75) + val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75) val skipCount = inputStream.skip(100) val read = inputStream.read() @@ -47,7 +47,7 @@ class TruncatingInputStreamTest { @Test fun `when I skip, I should still truncate correctly afterwards`() { - val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75) + val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75) val skipCount = inputStream.skip(50) val data = inputStream.readFully() @@ -58,7 +58,7 @@ class TruncatingInputStreamTest { @Test fun `when I skip more than maxBytes, I only skip maxBytes`() { - val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75) + val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75) val skipCount = inputStream.skip(100) @@ -67,7 +67,7 @@ class TruncatingInputStreamTest { @Test fun `when I finish reading the stream, getTruncatedBytes gives me the rest`() { - val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75) + val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75) inputStream.readFully() val truncatedBytes = inputStream.readTruncatedBytes() @@ -76,16 +76,22 @@ class TruncatingInputStreamTest { @Test fun `when I finish reading the stream, getTruncatedBytes gives me the rest, respecting the byte limit`() { - val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75) + val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75) inputStream.readFully() val truncatedBytes = inputStream.readTruncatedBytes(byteLimit = 10) assertEquals(10, truncatedBytes.size) } + @Test(expected = IllegalStateException::class) + fun `if I have not finished reading the stream, getTruncatedBytes throws IllegalStateException`() { + val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75) + inputStream.readTruncatedBytes() + } + @Test fun `when I call available, it should respect the maxBytes`() { - val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75) + val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75) val available = inputStream.available() assertEquals(75, available) @@ -93,7 +99,7 @@ class TruncatingInputStreamTest { @Test fun `when I call available after reading some bytes, it should respect the maxBytes`() { - val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75) + val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75) inputStream.readNBytesOrThrow(50) val available = inputStream.available() @@ -103,7 +109,7 @@ class TruncatingInputStreamTest { @Test fun `when I mark and reset, it should jump back to the correct position`() { - val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75) + val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75) inputStream.mark(100) inputStream.readNBytesOrThrow(10) diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherInputStream.java b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherInputStream.java index 2dae01b6d4..6c85a6abe5 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherInputStream.java +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherInputStream.java @@ -6,7 +6,7 @@ package org.whispersystems.signalservice.api.crypto; -import org.signal.core.util.stream.TruncatingInputStream; +import org.signal.core.util.stream.LimitedInputStream; import org.signal.libsignal.protocol.InvalidMessageException; import org.signal.libsignal.protocol.incrementalmac.ChunkSizeChoice; import org.signal.libsignal.protocol.incrementalmac.IncrementalMacInputStream; @@ -117,7 +117,7 @@ public class AttachmentCipherInputStream extends FilterInputStream { InputStream inputStream = new AttachmentCipherInputStream(wrappedStream, parts[0], streamLength - BLOCK_SIZE - mac.getMacLength()); if (plaintextLength != 0) { - inputStream = new TruncatingInputStream(inputStream, plaintextLength); + inputStream = new LimitedInputStream(inputStream, plaintextLength); } return inputStream; @@ -142,7 +142,7 @@ public class AttachmentCipherInputStream extends FilterInputStream { InputStream inputStream = new AttachmentCipherInputStream(new FileInputStream(file), archivedMediaKeyMaterial.getCipherKey(), file.length() - BLOCK_SIZE - mac.getMacLength()); if (originalCipherTextLength != 0) { - inputStream = new TruncatingInputStream(inputStream, originalCipherTextLength); + inputStream = new LimitedInputStream(inputStream, originalCipherTextLength); } return inputStream; @@ -179,7 +179,7 @@ public class AttachmentCipherInputStream extends FilterInputStream { InputStream inputStream = new AttachmentCipherInputStream(wrappedStream, parts[0], file.length() - BLOCK_SIZE - mac.getMacLength()); if (plaintextLength != 0) { - inputStream = new TruncatingInputStream(inputStream, plaintextLength); + inputStream = new LimitedInputStream(inputStream, plaintextLength); } return inputStream; diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/messages/multidevice/DeviceContactsInputStream.java b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/messages/multidevice/DeviceContactsInputStream.java index 12c11c0d29..f5df38d1b9 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/messages/multidevice/DeviceContactsInputStream.java +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/messages/multidevice/DeviceContactsInputStream.java @@ -6,15 +6,13 @@ package org.whispersystems.signalservice.api.messages.multidevice; -import org.signal.core.util.stream.TruncatingInputStream; +import org.signal.core.util.stream.LimitedInputStream; import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.InvalidKeyException; import org.signal.libsignal.protocol.InvalidMessageException; import org.signal.libsignal.protocol.logging.Log; import org.signal.libsignal.zkgroup.InvalidInputException; import org.signal.libsignal.zkgroup.profiles.ProfileKey; -import org.whispersystems.signalservice.api.messages.SignalServiceAttachment; -import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentStream; import org.whispersystems.signalservice.api.push.ServiceId; import org.whispersystems.signalservice.api.push.ServiceId.ACI; import org.whispersystems.signalservice.api.push.SignalServiceAddress; @@ -62,7 +60,7 @@ public class DeviceContactsInputStream extends ChunkedInputStream { if (details.avatar != null && details.avatar.length != null) { long avatarLength = details.avatar.length; - InputStream avatarStream = new TruncatingInputStream(in, avatarLength); + InputStream avatarStream = new LimitedInputStream(in, avatarLength); String avatarContentType = details.avatar.contentType != null ? details.avatar.contentType : "image/*"; avatar = Optional.of(new DeviceContactAvatar(avatarStream, avatarLength, avatarContentType)); diff --git a/video/lib/src/main/java/org/thoughtcrime/securesms/video/postprocessing/Mp4FaststartPostProcessor.kt b/video/lib/src/main/java/org/thoughtcrime/securesms/video/postprocessing/Mp4FaststartPostProcessor.kt index 7a4e5673bd..678f90e4d4 100644 --- a/video/lib/src/main/java/org/thoughtcrime/securesms/video/postprocessing/Mp4FaststartPostProcessor.kt +++ b/video/lib/src/main/java/org/thoughtcrime/securesms/video/postprocessing/Mp4FaststartPostProcessor.kt @@ -6,7 +6,7 @@ package org.thoughtcrime.securesms.video.postprocessing import org.signal.core.util.readLength -import org.signal.core.util.stream.TruncatingInputStream +import org.signal.core.util.stream.LimitedInputStream import org.signal.libsignal.media.Mp4Sanitizer import org.signal.libsignal.media.SanitizedMetadata import org.thoughtcrime.securesms.video.exceptions.VideoPostProcessingException @@ -34,7 +34,7 @@ class Mp4FaststartPostProcessor(private val inputStreamFactory: InputStreamFacto } val inputStream = inputStreamFactory.create() inputStream.skip(metadata.dataOffset) - return SequenceInputStream(ByteArrayInputStream(metadata.sanitizedMetadata), TruncatingInputStream(inputStream, metadata.dataLength)) + return SequenceInputStream(ByteArrayInputStream(metadata.sanitizedMetadata), LimitedInputStream(inputStream, metadata.dataLength)) } fun processAndWriteTo(outputStream: OutputStream, inputLength: Long = calculateStreamLength(inputStreamFactory.create())): Long {