Add support for more methods in TruncatingInputStream.

This commit is contained in:
Greyson Parrelli 2024-09-05 16:40:18 -04:00 committed by Cody Henthorne
parent 929942de9d
commit b00855b097
4 changed files with 176 additions and 3 deletions

View file

@ -5,8 +5,10 @@
package org.signal.core.util
import java.io.ByteArrayOutputStream
import java.io.IOException
import java.io.InputStream
import kotlin.math.min
/**
* Reads a 32-bit variable-length integer from the stream.
@ -68,6 +70,29 @@ fun InputStream.readNBytesOrThrow(length: Int): ByteArray {
return buffer
}
/**
* Read at most [byteLimit] bytes from the stream.
*/
fun InputStream.readAtMostNBytes(byteLimit: Int): ByteArray {
val buffer = ByteArrayOutputStream()
val readBuffer = ByteArray(4096)
var remaining = byteLimit
while (remaining > 0) {
val bytesToRead = min(remaining, readBuffer.size)
val read = this.read(readBuffer, 0, bytesToRead)
if (read == -1) {
break
}
buffer.write(readBuffer, 0, read)
remaining -= read
}
return buffer.toByteArray()
}
@Throws(IOException::class)
fun InputStream.readLength(): Long {
val buffer = ByteArray(4096)

View file

@ -5,9 +5,12 @@
package org.signal.core.util.stream
import org.signal.core.util.readAtMostNBytes
import org.signal.core.util.readFully
import java.io.FilterInputStream
import java.io.InputStream
import java.lang.UnsupportedOperationException
import kotlin.math.min
/**
* An [InputStream] that will read from the target [InputStream] until it reaches the end, or until it has read [maxBytes] bytes.
@ -15,6 +18,7 @@ import java.lang.UnsupportedOperationException
class TruncatingInputStream(private val wrapped: InputStream, private val maxBytes: Long) : FilterInputStream(wrapped) {
private var bytesRead: Long = 0
private var lastMark = -1L
override fun read(): Int {
if (bytesRead >= maxBytes) {
@ -48,11 +52,58 @@ class TruncatingInputStream(private val wrapped: InputStream, private val maxByt
return bytesRead
}
override fun skip(n: Long): Long {
throw UnsupportedOperationException()
override fun skip(requestedSkipCount: Long): Long {
val bytesRemaining: Long = maxBytes - bytesRead
val bytesToSkip: Long = min(bytesRemaining, requestedSkipCount)
return super.skip(bytesToSkip).also { bytesSkipped ->
if (bytesSkipped > 0) {
this.bytesRead += bytesSkipped
}
}
}
override fun available(): Int {
val bytesRemaining = Math.toIntExact(maxBytes - bytesRead)
return min(bytesRemaining, wrapped.available())
}
override fun markSupported(): Boolean {
return wrapped.markSupported()
}
override fun mark(readlimit: Int) {
if (!markSupported()) {
throw UnsupportedOperationException("Mark not supported")
}
wrapped.mark(readlimit)
lastMark = bytesRead
}
override fun reset() {
throw UnsupportedOperationException()
if (!markSupported()) {
throw UnsupportedOperationException("Mark not supported")
}
if (lastMark == -1L) {
throw UnsupportedOperationException("Mark not set")
}
wrapped.reset()
bytesRead = lastMark
}
/**
* If the stream has been fully read, this will return all bytes that were truncated from the stream.
*
* @param byteLimit The maximum number of truncated bytes to read. Defaults to no limit.
*/
fun readTruncatedBytes(byteLimit: Int = -1): ByteArray {
return if (byteLimit < 0) {
wrapped.readFully()
} else {
wrapped.readAtMostNBytes(byteLimit)
}
}
}

View file

@ -8,6 +8,7 @@ package org.signal.core.util.stream
import org.junit.Assert.assertEquals
import org.junit.Test
import org.signal.core.util.readFully
import org.signal.core.util.readNBytesOrThrow
class TruncatingInputStreamTest {
@ -32,4 +33,84 @@ class TruncatingInputStreamTest {
assertEquals(75, count)
}
@Test
fun `when I skip past the maxBytes, I should get -1`() {
val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75)
val skipCount = inputStream.skip(100)
val read = inputStream.read()
assertEquals(75, skipCount)
assertEquals(-1, read)
}
@Test
fun `when I skip, I should still truncate correctly afterwards`() {
val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75)
val skipCount = inputStream.skip(50)
val data = inputStream.readFully()
assertEquals(50, skipCount)
assertEquals(25, data.size)
}
@Test
fun `when I skip more than maxBytes, I only skip maxBytes`() {
val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75)
val skipCount = inputStream.skip(100)
assertEquals(75, skipCount)
}
@Test
fun `when I finish reading the stream, getTruncatedBytes gives me the rest`() {
val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75)
inputStream.readFully()
val truncatedBytes = inputStream.readTruncatedBytes()
assertEquals(25, truncatedBytes.size)
}
@Test
fun `when I finish reading the stream, getTruncatedBytes gives me the rest, respecting the byte limit`() {
val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75)
inputStream.readFully()
val truncatedBytes = inputStream.readTruncatedBytes(byteLimit = 10)
assertEquals(10, truncatedBytes.size)
}
@Test
fun `when I call available, it should respect the maxBytes`() {
val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75)
val available = inputStream.available()
assertEquals(75, available)
}
@Test
fun `when I call available after reading some bytes, it should respect the maxBytes`() {
val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75)
inputStream.readNBytesOrThrow(50)
val available = inputStream.available()
assertEquals(25, available)
}
@Test
fun `when I mark and reset, it should jump back to the correct position`() {
val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75)
inputStream.mark(100)
inputStream.readNBytesOrThrow(10)
inputStream.reset()
val data = inputStream.readFully()
assertEquals(75, data.size)
}
}

View file

@ -19,4 +19,20 @@ class InputStreamExtensionTests {
assertEquals(bytes.size.toLong(), length)
}
}
@Test
fun `when I call readAtMostNBytes, I only read that many bytes`() {
val bytes = ByteArray(100)
val inputStream = bytes.inputStream()
val readBytes = inputStream.readAtMostNBytes(50)
assertEquals(50, readBytes.size)
}
@Test
fun `when I call readAtMostNBytes, it will return at most the length of the stream`() {
val bytes = ByteArray(100)
val inputStream = bytes.inputStream()
val readBytes = inputStream.readAtMostNBytes(200)
assertEquals(100, readBytes.size)
}
}