From b00855b097b68f302a839b9c6b1d590d7db97110 Mon Sep 17 00:00:00 2001 From: Greyson Parrelli Date: Thu, 5 Sep 2024 16:40:18 -0400 Subject: [PATCH] Add support for more methods in TruncatingInputStream. --- .../signal/core/util/InputStreamExtensions.kt | 25 ++++++ .../core/util/stream/TruncatingInputStream.kt | 57 ++++++++++++- .../util/stream/TruncatingInputStreamTest.kt | 81 +++++++++++++++++++ .../core/util/InputStreamExtensionTests.kt | 16 ++++ 4 files changed, 176 insertions(+), 3 deletions(-) diff --git a/core-util-jvm/src/main/java/org/signal/core/util/InputStreamExtensions.kt b/core-util-jvm/src/main/java/org/signal/core/util/InputStreamExtensions.kt index c0dbdd2ef6..fc7a0eb7f2 100644 --- a/core-util-jvm/src/main/java/org/signal/core/util/InputStreamExtensions.kt +++ b/core-util-jvm/src/main/java/org/signal/core/util/InputStreamExtensions.kt @@ -5,8 +5,10 @@ package org.signal.core.util +import java.io.ByteArrayOutputStream import java.io.IOException import java.io.InputStream +import kotlin.math.min /** * Reads a 32-bit variable-length integer from the stream. @@ -68,6 +70,29 @@ fun InputStream.readNBytesOrThrow(length: Int): ByteArray { return buffer } +/** + * Read at most [byteLimit] bytes from the stream. + */ +fun InputStream.readAtMostNBytes(byteLimit: Int): ByteArray { + val buffer = ByteArrayOutputStream() + val readBuffer = ByteArray(4096) + + var remaining = byteLimit + while (remaining > 0) { + val bytesToRead = min(remaining, readBuffer.size) + val read = this.read(readBuffer, 0, bytesToRead) + + if (read == -1) { + break + } + + buffer.write(readBuffer, 0, read) + remaining -= read + } + + return buffer.toByteArray() +} + @Throws(IOException::class) fun InputStream.readLength(): Long { val buffer = ByteArray(4096) diff --git a/core-util-jvm/src/main/java/org/signal/core/util/stream/TruncatingInputStream.kt b/core-util-jvm/src/main/java/org/signal/core/util/stream/TruncatingInputStream.kt index 36c7dec3d1..5760e0c508 100644 --- a/core-util-jvm/src/main/java/org/signal/core/util/stream/TruncatingInputStream.kt +++ b/core-util-jvm/src/main/java/org/signal/core/util/stream/TruncatingInputStream.kt @@ -5,9 +5,12 @@ package org.signal.core.util.stream +import org.signal.core.util.readAtMostNBytes +import org.signal.core.util.readFully import java.io.FilterInputStream import java.io.InputStream import java.lang.UnsupportedOperationException +import kotlin.math.min /** * An [InputStream] that will read from the target [InputStream] until it reaches the end, or until it has read [maxBytes] bytes. @@ -15,6 +18,7 @@ import java.lang.UnsupportedOperationException class TruncatingInputStream(private val wrapped: InputStream, private val maxBytes: Long) : FilterInputStream(wrapped) { private var bytesRead: Long = 0 + private var lastMark = -1L override fun read(): Int { if (bytesRead >= maxBytes) { @@ -48,11 +52,58 @@ class TruncatingInputStream(private val wrapped: InputStream, private val maxByt return bytesRead } - override fun skip(n: Long): Long { - throw UnsupportedOperationException() + override fun skip(requestedSkipCount: Long): Long { + val bytesRemaining: Long = maxBytes - bytesRead + val bytesToSkip: Long = min(bytesRemaining, requestedSkipCount) + + return super.skip(bytesToSkip).also { bytesSkipped -> + if (bytesSkipped > 0) { + this.bytesRead += bytesSkipped + } + } + } + + override fun available(): Int { + val bytesRemaining = Math.toIntExact(maxBytes - bytesRead) + return min(bytesRemaining, wrapped.available()) + } + + override fun markSupported(): Boolean { + return wrapped.markSupported() + } + + override fun mark(readlimit: Int) { + if (!markSupported()) { + throw UnsupportedOperationException("Mark not supported") + } + + wrapped.mark(readlimit) + lastMark = bytesRead } override fun reset() { - throw UnsupportedOperationException() + if (!markSupported()) { + throw UnsupportedOperationException("Mark not supported") + } + + if (lastMark == -1L) { + throw UnsupportedOperationException("Mark not set") + } + + wrapped.reset() + bytesRead = lastMark + } + + /** + * If the stream has been fully read, this will return all bytes that were truncated from the stream. + * + * @param byteLimit The maximum number of truncated bytes to read. Defaults to no limit. + */ + fun readTruncatedBytes(byteLimit: Int = -1): ByteArray { + return if (byteLimit < 0) { + wrapped.readFully() + } else { + wrapped.readAtMostNBytes(byteLimit) + } } } diff --git a/core-util-jvm/src/test/java/org/signal/core/util/stream/TruncatingInputStreamTest.kt b/core-util-jvm/src/test/java/org/signal/core/util/stream/TruncatingInputStreamTest.kt index c62f239497..b51fea4c54 100644 --- a/core-util-jvm/src/test/java/org/signal/core/util/stream/TruncatingInputStreamTest.kt +++ b/core-util-jvm/src/test/java/org/signal/core/util/stream/TruncatingInputStreamTest.kt @@ -8,6 +8,7 @@ package org.signal.core.util.stream import org.junit.Assert.assertEquals import org.junit.Test import org.signal.core.util.readFully +import org.signal.core.util.readNBytesOrThrow class TruncatingInputStreamTest { @@ -32,4 +33,84 @@ class TruncatingInputStreamTest { assertEquals(75, count) } + + @Test + fun `when I skip past the maxBytes, I should get -1`() { + val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75) + + val skipCount = inputStream.skip(100) + val read = inputStream.read() + + assertEquals(75, skipCount) + assertEquals(-1, read) + } + + @Test + fun `when I skip, I should still truncate correctly afterwards`() { + val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75) + + val skipCount = inputStream.skip(50) + val data = inputStream.readFully() + + assertEquals(50, skipCount) + assertEquals(25, data.size) + } + + @Test + fun `when I skip more than maxBytes, I only skip maxBytes`() { + val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75) + + val skipCount = inputStream.skip(100) + + assertEquals(75, skipCount) + } + + @Test + fun `when I finish reading the stream, getTruncatedBytes gives me the rest`() { + val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75) + inputStream.readFully() + + val truncatedBytes = inputStream.readTruncatedBytes() + assertEquals(25, truncatedBytes.size) + } + + @Test + fun `when I finish reading the stream, getTruncatedBytes gives me the rest, respecting the byte limit`() { + val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75) + inputStream.readFully() + + val truncatedBytes = inputStream.readTruncatedBytes(byteLimit = 10) + assertEquals(10, truncatedBytes.size) + } + + @Test + fun `when I call available, it should respect the maxBytes`() { + val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75) + val available = inputStream.available() + + assertEquals(75, available) + } + + @Test + fun `when I call available after reading some bytes, it should respect the maxBytes`() { + val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75) + inputStream.readNBytesOrThrow(50) + + val available = inputStream.available() + + assertEquals(25, available) + } + + @Test + fun `when I mark and reset, it should jump back to the correct position`() { + val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75) + + inputStream.mark(100) + inputStream.readNBytesOrThrow(10) + inputStream.reset() + + val data = inputStream.readFully() + + assertEquals(75, data.size) + } } diff --git a/core-util/src/test/java/org/signal/core/util/InputStreamExtensionTests.kt b/core-util/src/test/java/org/signal/core/util/InputStreamExtensionTests.kt index 1aeaa1c05e..8c85bff813 100644 --- a/core-util/src/test/java/org/signal/core/util/InputStreamExtensionTests.kt +++ b/core-util/src/test/java/org/signal/core/util/InputStreamExtensionTests.kt @@ -19,4 +19,20 @@ class InputStreamExtensionTests { assertEquals(bytes.size.toLong(), length) } } + + @Test + fun `when I call readAtMostNBytes, I only read that many bytes`() { + val bytes = ByteArray(100) + val inputStream = bytes.inputStream() + val readBytes = inputStream.readAtMostNBytes(50) + assertEquals(50, readBytes.size) + } + + @Test + fun `when I call readAtMostNBytes, it will return at most the length of the stream`() { + val bytes = ByteArray(100) + val inputStream = bytes.inputStream() + val readBytes = inputStream.readAtMostNBytes(200) + assertEquals(100, readBytes.size) + } }