Add the ability to set no limit on LimitedInputStream.

This commit is contained in:
Greyson Parrelli 2024-09-06 08:44:14 -04:00 committed by Cody Henthorne
parent a8fb4eb21a
commit 85d90aa121
2 changed files with 72 additions and 0 deletions

View file

@ -14,6 +14,8 @@ import kotlin.math.min
/**
* An [InputStream] that will read from the target [InputStream] until it reaches the end, or until it has read [maxBytes] bytes.
*
* @param maxBytes The maximum number of bytes to read from the stream. If set to -1, there will be no limit.
*/
class LimitedInputStream(private val wrapped: InputStream, private val maxBytes: Long) : FilterInputStream(wrapped) {
@ -21,6 +23,10 @@ class LimitedInputStream(private val wrapped: InputStream, private val maxBytes:
private var lastMark = -1L
override fun read(): Int {
if (maxBytes == -1L) {
return wrapped.read()
}
if (totalBytesRead >= maxBytes) {
return -1
}
@ -38,6 +44,10 @@ class LimitedInputStream(private val wrapped: InputStream, private val maxBytes:
}
override fun read(destination: ByteArray, offset: Int, length: Int): Int {
if (maxBytes == -1L) {
return wrapped.read(destination, offset, length)
}
if (totalBytesRead >= maxBytes) {
return -1
}
@ -54,6 +64,10 @@ class LimitedInputStream(private val wrapped: InputStream, private val maxBytes:
}
override fun skip(requestedSkipCount: Long): Long {
if (maxBytes == -1L) {
return wrapped.skip(requestedSkipCount)
}
val bytesRemaining: Long = maxBytes - totalBytesRead
val bytesToSkip: Long = min(bytesRemaining, requestedSkipCount)
val skipCount = super.skip(bytesToSkip)
@ -64,6 +78,10 @@ class LimitedInputStream(private val wrapped: InputStream, private val maxBytes:
}
override fun available(): Int {
if (maxBytes == -1L) {
return wrapped.available()
}
val bytesRemaining = Math.toIntExact(maxBytes - totalBytesRead)
return min(bytesRemaining, wrapped.available())
}
@ -78,6 +96,11 @@ class LimitedInputStream(private val wrapped: InputStream, private val maxBytes:
}
wrapped.mark(readlimit)
if (maxBytes == -1L) {
return
}
lastMark = totalBytesRead
}
@ -91,15 +114,25 @@ class LimitedInputStream(private val wrapped: InputStream, private val maxBytes:
}
wrapped.reset()
if (maxBytes == -1L) {
return
}
totalBytesRead = lastMark
}
/**
* If the stream has been fully read, this will return all bytes that were truncated from the stream.
* If the stream was setup with no limit, this will always return an empty array.
*
* @param byteLimit The maximum number of truncated bytes to read. Defaults to no limit.
*/
fun readTruncatedBytes(byteLimit: Int = -1): ByteArray {
if (maxBytes == -1L) {
return ByteArray(0)
}
if (totalBytesRead < maxBytes) {
throw IllegalStateException("Stream has not been fully read")
}

View file

@ -20,6 +20,14 @@ class LimitedInputStreamTest {
assertEquals(75, data.size)
}
@Test
fun `when I fully read the stream via a buffer with no limit, I should get all bytes`() {
val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = -1)
val data = inputStream.readFully()
assertEquals(100, data.size)
}
@Test
fun `when I fully read the stream one byte at a time, I should only get maxBytes`() {
val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75)
@ -34,6 +42,20 @@ class LimitedInputStreamTest {
assertEquals(75, count)
}
@Test
fun `when I fully read the stream one byte at a time with no limit, I should only get maxBytes`() {
val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = -1)
var count = 0
var lastRead = inputStream.read()
while (lastRead != -1) {
count++
lastRead = inputStream.read()
}
assertEquals(100, count)
}
@Test
fun `when I skip past the maxBytes, I should get -1`() {
val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75)
@ -89,6 +111,15 @@ class LimitedInputStreamTest {
inputStream.readTruncatedBytes()
}
@Test
fun `when call getTruncatedBytes on a stream with no limit, it returns an empty array`() {
val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = -1)
inputStream.readFully()
val truncatedBytes = inputStream.readTruncatedBytes()
assertEquals(0, truncatedBytes.size)
}
@Test
fun `when I call available, it should respect the maxBytes`() {
val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75)
@ -97,6 +128,14 @@ class LimitedInputStreamTest {
assertEquals(75, available)
}
@Test
fun `when I call available with no limit, it should return the full length`() {
val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = -1)
val available = inputStream.available()
assertEquals(100, available)
}
@Test
fun `when I call available after reading some bytes, it should respect the maxBytes`() {
val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75)