Revert "Don't recreate attachment InputStream if we don't have to."

This reverts commit 467dae8132.
This commit is contained in:
Nicholas Tinsley 2024-02-05 10:49:05 -05:00
parent adf1d8a43a
commit f18070b78c
3 changed files with 15 additions and 179 deletions

View file

@ -5,9 +5,9 @@
package org.signal.core.util
import java.io.EOFException
import java.io.IOException
import java.io.InputStream
import kotlin.jvm.Throws
/**
* Reads a 32-bit variable-length integer from the stream.
@ -80,26 +80,3 @@ fun InputStream.readLength(): Long {
return count
}
/**
* Backported from AOSP API 31 source code.
*
* @param count number of bytes to skip
*/
@Throws(IOException::class)
fun InputStream.skipNBytesCompat(count: Long) {
var n = count
while (n > 0) {
val ns = skip(n)
if (ns in 1..n) {
n -= ns
} else if (ns == 0L) {
if (read() == -1) {
throw EOFException()
}
n--
} else {
throw IOException("Unable to skip exactly")
}
}
}

View file

@ -7,8 +7,6 @@ package org.thoughtcrime.securesms.video.videoconverter.mediadatasource
import android.media.MediaDataSource
import androidx.annotation.RequiresApi
import org.signal.core.util.skipNBytesCompat
import java.io.EOFException
import java.io.IOException
import java.io.InputStream
@ -17,50 +15,30 @@ import java.io.InputStream
*/
@RequiresApi(23)
abstract class InputStreamMediaDataSource : MediaDataSource() {
private var lastPositionRead = -1L
private var lastUsedInputStream: InputStream? = null
private val sink = ByteArray(2048)
@Throws(IOException::class)
override fun readAt(position: Long, bytes: ByteArray?, offset: Int, length: Int): Int {
if (position >= size || position < 0) {
if (position >= size) {
return -1
}
val inputStream = if (lastPositionRead > position || lastUsedInputStream == null) {
lastUsedInputStream?.close()
lastPositionRead = position
createInputStream(position)
} else {
lastUsedInputStream!!
}
try {
inputStream.skipNBytesCompat(position - lastPositionRead)
} catch (e: EOFException) {
return -1
}
var totalRead = 0
while (totalRead < length) {
val read: Int = inputStream.read(bytes, offset + totalRead, length - totalRead)
if (read == -1) {
return if (totalRead == 0) {
-1
} else {
totalRead
createInputStream(position).use { inputStream ->
var totalRead = 0
while (totalRead < length) {
val read: Int = inputStream.read(bytes, offset + totalRead, length - totalRead)
if (read == -1) {
return if (totalRead == 0) {
-1
} else {
totalRead
}
}
totalRead += read
}
totalRead += read
return totalRead
}
lastPositionRead = totalRead + position
lastUsedInputStream = inputStream
return totalRead
}
override fun close() {
lastUsedInputStream?.close()
}
abstract override fun close()
abstract override fun getSize(): Long

View file

@ -1,119 +0,0 @@
package org.thoughtcrime.securesms.video.videoconverter.mediadatasource
import org.junit.Assert.assertArrayEquals
import org.junit.Assert.assertEquals
import org.junit.Before
import org.junit.Test
import org.signal.core.util.skipNBytesCompat
import java.io.ByteArrayInputStream
import java.io.InputStream
import java.util.Arrays
import kotlin.random.Random
@OptIn(ExperimentalStdlibApi::class)
class InputStreamMediaDataSourceTest {
companion object {
const val BUFFER_SIZE = 1024
const val DATA_SIZE = 8192
}
private lateinit var dataSource: TestInputStreamMediaDataSource
private val outputBuffer = ByteArray(BUFFER_SIZE)
@Before
fun setUp() {
dataSource = TestInputStreamMediaDataSource(Random.Default.nextBytes(DATA_SIZE))
Arrays.fill(outputBuffer, 0)
}
/**
* Happy path test for reading from the start of the stream.
*/
@Test
fun testStartRead() {
val readLength = BUFFER_SIZE
dataSource.readAt(0, outputBuffer, 0, readLength)
assertArrayEquals(dataSource.getSliceOfData(0..<readLength), outputBuffer)
}
/**
* Make sure that reading from a specified index works.
*/
@Test
fun testSkipForward() {
val readLength = BUFFER_SIZE
val skipOffset = BUFFER_SIZE
val endIndex = skipOffset + readLength
dataSource.readAt(skipOffset.toLong(), outputBuffer, 0, readLength)
assertArrayEquals(dataSource.getSliceOfData(skipOffset..<endIndex), outputBuffer)
}
/**
* "Skipping backwards" actually involves recreating the underlying stream and skipping forwards. This tests that.
*/
@Test
fun testSkipBackward() {
val readLength = BUFFER_SIZE
val skipOffset = BUFFER_SIZE
val skipAheadAmount = skipOffset * 2
val endIndex = skipOffset + readLength
dataSource.readAt(skipAheadAmount.toLong(), outputBuffer, 0, readLength)
dataSource.readAt(skipOffset.toLong(), outputBuffer, 0, readLength)
assertArrayEquals(dataSource.getSliceOfData(skipOffset..<endIndex), outputBuffer)
}
/**
* Successfully read the final n bytes of a stream, even though >n were requested
*/
@Test
fun testReadPastInputStreamSize() {
val readLength = 512
val distanceFromEnd = readLength / 2
val skipOffset = DATA_SIZE - distanceFromEnd
val readResult = dataSource.readAt(skipOffset.toLong(), outputBuffer, 0, readLength)
assertEquals(distanceFromEnd, readResult)
assertArrayEquals(dataSource.getSliceOfData(skipOffset..<DATA_SIZE), outputBuffer.sliceArray(0..<distanceFromEnd))
}
/**
* Successfully read the final n bytes of a stream, even though >n were requested
*/
@Test
fun testReadUpToEndAndThenKeepReading() {
val readLength = 512
val distanceFromEnd = readLength / 2
val skipOffset = DATA_SIZE - distanceFromEnd
val readResultLastOfStream = dataSource.readAt(skipOffset.toLong(), outputBuffer, 0, readLength)
val readResultAtEndOfStream = dataSource.readAt((skipOffset + readResultLastOfStream).toLong(), outputBuffer, 0, readLength)
assertEquals(-1, readResultAtEndOfStream)
assertArrayEquals(dataSource.getSliceOfData(skipOffset..<DATA_SIZE), outputBuffer.sliceArray(0..<distanceFromEnd))
}
/**
* A negative position is outside the stream, should return EOS.
*/
@Test
fun testReadNegativePosition() {
val readResult = dataSource.readAt(-128, outputBuffer, 0, BUFFER_SIZE)
assertEquals(-1, readResult)
assertArrayEquals(ByteArray(BUFFER_SIZE), outputBuffer)
}
private class TestInputStreamMediaDataSource(private val data: ByteArray) : InputStreamMediaDataSource() {
override fun getSize() = data.size.toLong()
override fun createInputStream(position: Long): InputStream {
val inputStream = ByteArrayInputStream(data)
inputStream.skipNBytesCompat(position)
return inputStream
}
fun getSliceOfData(indices: IntRange): ByteArray {
return data.sliceArray(indices)
}
}
}