Add support for more methods in TruncatingInputStream.
This commit is contained in:
parent
929942de9d
commit
b00855b097
4 changed files with 176 additions and 3 deletions
|
@ -5,8 +5,10 @@
|
||||||
|
|
||||||
package org.signal.core.util
|
package org.signal.core.util
|
||||||
|
|
||||||
|
import java.io.ByteArrayOutputStream
|
||||||
import java.io.IOException
|
import java.io.IOException
|
||||||
import java.io.InputStream
|
import java.io.InputStream
|
||||||
|
import kotlin.math.min
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Reads a 32-bit variable-length integer from the stream.
|
* Reads a 32-bit variable-length integer from the stream.
|
||||||
|
@ -68,6 +70,29 @@ fun InputStream.readNBytesOrThrow(length: Int): ByteArray {
|
||||||
return buffer
|
return buffer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Read at most [byteLimit] bytes from the stream.
|
||||||
|
*/
|
||||||
|
fun InputStream.readAtMostNBytes(byteLimit: Int): ByteArray {
|
||||||
|
val buffer = ByteArrayOutputStream()
|
||||||
|
val readBuffer = ByteArray(4096)
|
||||||
|
|
||||||
|
var remaining = byteLimit
|
||||||
|
while (remaining > 0) {
|
||||||
|
val bytesToRead = min(remaining, readBuffer.size)
|
||||||
|
val read = this.read(readBuffer, 0, bytesToRead)
|
||||||
|
|
||||||
|
if (read == -1) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
buffer.write(readBuffer, 0, read)
|
||||||
|
remaining -= read
|
||||||
|
}
|
||||||
|
|
||||||
|
return buffer.toByteArray()
|
||||||
|
}
|
||||||
|
|
||||||
@Throws(IOException::class)
|
@Throws(IOException::class)
|
||||||
fun InputStream.readLength(): Long {
|
fun InputStream.readLength(): Long {
|
||||||
val buffer = ByteArray(4096)
|
val buffer = ByteArray(4096)
|
||||||
|
|
|
@ -5,9 +5,12 @@
|
||||||
|
|
||||||
package org.signal.core.util.stream
|
package org.signal.core.util.stream
|
||||||
|
|
||||||
|
import org.signal.core.util.readAtMostNBytes
|
||||||
|
import org.signal.core.util.readFully
|
||||||
import java.io.FilterInputStream
|
import java.io.FilterInputStream
|
||||||
import java.io.InputStream
|
import java.io.InputStream
|
||||||
import java.lang.UnsupportedOperationException
|
import java.lang.UnsupportedOperationException
|
||||||
|
import kotlin.math.min
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An [InputStream] that will read from the target [InputStream] until it reaches the end, or until it has read [maxBytes] bytes.
|
* An [InputStream] that will read from the target [InputStream] until it reaches the end, or until it has read [maxBytes] bytes.
|
||||||
|
@ -15,6 +18,7 @@ import java.lang.UnsupportedOperationException
|
||||||
class TruncatingInputStream(private val wrapped: InputStream, private val maxBytes: Long) : FilterInputStream(wrapped) {
|
class TruncatingInputStream(private val wrapped: InputStream, private val maxBytes: Long) : FilterInputStream(wrapped) {
|
||||||
|
|
||||||
private var bytesRead: Long = 0
|
private var bytesRead: Long = 0
|
||||||
|
private var lastMark = -1L
|
||||||
|
|
||||||
override fun read(): Int {
|
override fun read(): Int {
|
||||||
if (bytesRead >= maxBytes) {
|
if (bytesRead >= maxBytes) {
|
||||||
|
@ -48,11 +52,58 @@ class TruncatingInputStream(private val wrapped: InputStream, private val maxByt
|
||||||
return bytesRead
|
return bytesRead
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun skip(n: Long): Long {
|
override fun skip(requestedSkipCount: Long): Long {
|
||||||
throw UnsupportedOperationException()
|
val bytesRemaining: Long = maxBytes - bytesRead
|
||||||
|
val bytesToSkip: Long = min(bytesRemaining, requestedSkipCount)
|
||||||
|
|
||||||
|
return super.skip(bytesToSkip).also { bytesSkipped ->
|
||||||
|
if (bytesSkipped > 0) {
|
||||||
|
this.bytesRead += bytesSkipped
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun available(): Int {
|
||||||
|
val bytesRemaining = Math.toIntExact(maxBytes - bytesRead)
|
||||||
|
return min(bytesRemaining, wrapped.available())
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun markSupported(): Boolean {
|
||||||
|
return wrapped.markSupported()
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun mark(readlimit: Int) {
|
||||||
|
if (!markSupported()) {
|
||||||
|
throw UnsupportedOperationException("Mark not supported")
|
||||||
|
}
|
||||||
|
|
||||||
|
wrapped.mark(readlimit)
|
||||||
|
lastMark = bytesRead
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun reset() {
|
override fun reset() {
|
||||||
throw UnsupportedOperationException()
|
if (!markSupported()) {
|
||||||
|
throw UnsupportedOperationException("Mark not supported")
|
||||||
|
}
|
||||||
|
|
||||||
|
if (lastMark == -1L) {
|
||||||
|
throw UnsupportedOperationException("Mark not set")
|
||||||
|
}
|
||||||
|
|
||||||
|
wrapped.reset()
|
||||||
|
bytesRead = lastMark
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* If the stream has been fully read, this will return all bytes that were truncated from the stream.
|
||||||
|
*
|
||||||
|
* @param byteLimit The maximum number of truncated bytes to read. Defaults to no limit.
|
||||||
|
*/
|
||||||
|
fun readTruncatedBytes(byteLimit: Int = -1): ByteArray {
|
||||||
|
return if (byteLimit < 0) {
|
||||||
|
wrapped.readFully()
|
||||||
|
} else {
|
||||||
|
wrapped.readAtMostNBytes(byteLimit)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,6 +8,7 @@ package org.signal.core.util.stream
|
||||||
import org.junit.Assert.assertEquals
|
import org.junit.Assert.assertEquals
|
||||||
import org.junit.Test
|
import org.junit.Test
|
||||||
import org.signal.core.util.readFully
|
import org.signal.core.util.readFully
|
||||||
|
import org.signal.core.util.readNBytesOrThrow
|
||||||
|
|
||||||
class TruncatingInputStreamTest {
|
class TruncatingInputStreamTest {
|
||||||
|
|
||||||
|
@ -32,4 +33,84 @@ class TruncatingInputStreamTest {
|
||||||
|
|
||||||
assertEquals(75, count)
|
assertEquals(75, count)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `when I skip past the maxBytes, I should get -1`() {
|
||||||
|
val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75)
|
||||||
|
|
||||||
|
val skipCount = inputStream.skip(100)
|
||||||
|
val read = inputStream.read()
|
||||||
|
|
||||||
|
assertEquals(75, skipCount)
|
||||||
|
assertEquals(-1, read)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `when I skip, I should still truncate correctly afterwards`() {
|
||||||
|
val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75)
|
||||||
|
|
||||||
|
val skipCount = inputStream.skip(50)
|
||||||
|
val data = inputStream.readFully()
|
||||||
|
|
||||||
|
assertEquals(50, skipCount)
|
||||||
|
assertEquals(25, data.size)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `when I skip more than maxBytes, I only skip maxBytes`() {
|
||||||
|
val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75)
|
||||||
|
|
||||||
|
val skipCount = inputStream.skip(100)
|
||||||
|
|
||||||
|
assertEquals(75, skipCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `when I finish reading the stream, getTruncatedBytes gives me the rest`() {
|
||||||
|
val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75)
|
||||||
|
inputStream.readFully()
|
||||||
|
|
||||||
|
val truncatedBytes = inputStream.readTruncatedBytes()
|
||||||
|
assertEquals(25, truncatedBytes.size)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `when I finish reading the stream, getTruncatedBytes gives me the rest, respecting the byte limit`() {
|
||||||
|
val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75)
|
||||||
|
inputStream.readFully()
|
||||||
|
|
||||||
|
val truncatedBytes = inputStream.readTruncatedBytes(byteLimit = 10)
|
||||||
|
assertEquals(10, truncatedBytes.size)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `when I call available, it should respect the maxBytes`() {
|
||||||
|
val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75)
|
||||||
|
val available = inputStream.available()
|
||||||
|
|
||||||
|
assertEquals(75, available)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `when I call available after reading some bytes, it should respect the maxBytes`() {
|
||||||
|
val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75)
|
||||||
|
inputStream.readNBytesOrThrow(50)
|
||||||
|
|
||||||
|
val available = inputStream.available()
|
||||||
|
|
||||||
|
assertEquals(25, available)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `when I mark and reset, it should jump back to the correct position`() {
|
||||||
|
val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75)
|
||||||
|
|
||||||
|
inputStream.mark(100)
|
||||||
|
inputStream.readNBytesOrThrow(10)
|
||||||
|
inputStream.reset()
|
||||||
|
|
||||||
|
val data = inputStream.readFully()
|
||||||
|
|
||||||
|
assertEquals(75, data.size)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,4 +19,20 @@ class InputStreamExtensionTests {
|
||||||
assertEquals(bytes.size.toLong(), length)
|
assertEquals(bytes.size.toLong(), length)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `when I call readAtMostNBytes, I only read that many bytes`() {
|
||||||
|
val bytes = ByteArray(100)
|
||||||
|
val inputStream = bytes.inputStream()
|
||||||
|
val readBytes = inputStream.readAtMostNBytes(50)
|
||||||
|
assertEquals(50, readBytes.size)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `when I call readAtMostNBytes, it will return at most the length of the stream`() {
|
||||||
|
val bytes = ByteArray(100)
|
||||||
|
val inputStream = bytes.inputStream()
|
||||||
|
val readBytes = inputStream.readAtMostNBytes(200)
|
||||||
|
assertEquals(100, readBytes.size)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue