Don't recreate attachment InputStream if we don't have to.
This commit is contained in:
parent
d1ef9d5dcf
commit
467dae8132
3 changed files with 181 additions and 17 deletions
|
@ -5,9 +5,9 @@
|
|||
|
||||
package org.signal.core.util
|
||||
|
||||
import java.io.EOFException
|
||||
import java.io.IOException
|
||||
import java.io.InputStream
|
||||
import kotlin.jvm.Throws
|
||||
|
||||
/**
|
||||
* Reads a 32-bit variable-length integer from the stream.
|
||||
|
@ -80,3 +80,26 @@ fun InputStream.readLength(): Long {
|
|||
|
||||
return count
|
||||
}
|
||||
|
||||
/**
|
||||
* Backported from AOSP API 31 source code.
|
||||
*
|
||||
* @param count number of bytes to skip
|
||||
*/
|
||||
@Throws(IOException::class)
|
||||
fun InputStream.skipNBytesCompat(count: Long) {
|
||||
var n = count
|
||||
while (n > 0) {
|
||||
val ns = skip(n)
|
||||
if (ns in 1..n) {
|
||||
n -= ns
|
||||
} else if (ns == 0L) {
|
||||
if (read() == -1) {
|
||||
throw EOFException()
|
||||
}
|
||||
n--
|
||||
} else {
|
||||
throw IOException("Unable to skip exactly")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,6 +7,8 @@ package org.thoughtcrime.securesms.video.videoconverter.mediadatasource
|
|||
|
||||
import android.media.MediaDataSource
|
||||
import androidx.annotation.RequiresApi
|
||||
import org.signal.core.util.skipNBytesCompat
|
||||
import java.io.EOFException
|
||||
import java.io.IOException
|
||||
import java.io.InputStream
|
||||
|
||||
|
@ -15,30 +17,50 @@ import java.io.InputStream
|
|||
*/
|
||||
@RequiresApi(23)
|
||||
abstract class InputStreamMediaDataSource : MediaDataSource() {
|
||||
private var lastPositionRead = -1L
|
||||
private var lastUsedInputStream: InputStream? = null
|
||||
private val sink = ByteArray(2048)
|
||||
|
||||
@Throws(IOException::class)
|
||||
override fun readAt(position: Long, bytes: ByteArray?, offset: Int, length: Int): Int {
|
||||
if (position >= size) {
|
||||
if (position >= size || position < 0) {
|
||||
return -1
|
||||
}
|
||||
|
||||
createInputStream(position).use { inputStream ->
|
||||
var totalRead = 0
|
||||
while (totalRead < length) {
|
||||
val read: Int = inputStream.read(bytes, offset + totalRead, length - totalRead)
|
||||
if (read == -1) {
|
||||
return if (totalRead == 0) {
|
||||
-1
|
||||
} else {
|
||||
totalRead
|
||||
}
|
||||
}
|
||||
totalRead += read
|
||||
}
|
||||
return totalRead
|
||||
val inputStream = if (lastPositionRead > position || lastUsedInputStream == null) {
|
||||
lastUsedInputStream?.close()
|
||||
lastPositionRead = position
|
||||
createInputStream(position)
|
||||
} else {
|
||||
lastUsedInputStream!!
|
||||
}
|
||||
|
||||
try {
|
||||
inputStream.skipNBytesCompat(position - lastPositionRead)
|
||||
} catch (e: EOFException) {
|
||||
return -1
|
||||
}
|
||||
|
||||
var totalRead = 0
|
||||
while (totalRead < length) {
|
||||
val read: Int = inputStream.read(bytes, offset + totalRead, length - totalRead)
|
||||
if (read == -1) {
|
||||
return if (totalRead == 0) {
|
||||
-1
|
||||
} else {
|
||||
totalRead
|
||||
}
|
||||
}
|
||||
totalRead += read
|
||||
}
|
||||
lastPositionRead = totalRead + position
|
||||
lastUsedInputStream = inputStream
|
||||
return totalRead
|
||||
}
|
||||
|
||||
abstract override fun close()
|
||||
override fun close() {
|
||||
lastUsedInputStream?.close()
|
||||
}
|
||||
|
||||
abstract override fun getSize(): Long
|
||||
|
||||
|
|
|
@ -0,0 +1,119 @@
|
|||
package org.thoughtcrime.securesms.video.videoconverter.mediadatasource
|
||||
|
||||
import org.junit.Assert.assertArrayEquals
|
||||
import org.junit.Assert.assertEquals
|
||||
import org.junit.Before
|
||||
import org.junit.Test
|
||||
import org.signal.core.util.skipNBytesCompat
|
||||
import java.io.ByteArrayInputStream
|
||||
import java.io.InputStream
|
||||
import java.util.Arrays
|
||||
import kotlin.random.Random
|
||||
|
||||
@OptIn(ExperimentalStdlibApi::class)
|
||||
class InputStreamMediaDataSourceTest {
|
||||
companion object {
|
||||
const val BUFFER_SIZE = 1024
|
||||
const val DATA_SIZE = 8192
|
||||
}
|
||||
private lateinit var dataSource: TestInputStreamMediaDataSource
|
||||
private val outputBuffer = ByteArray(BUFFER_SIZE)
|
||||
|
||||
@Before
|
||||
fun setUp() {
|
||||
dataSource = TestInputStreamMediaDataSource(Random.Default.nextBytes(DATA_SIZE))
|
||||
Arrays.fill(outputBuffer, 0)
|
||||
}
|
||||
|
||||
/**
|
||||
* Happy path test for reading from the start of the stream.
|
||||
*/
|
||||
@Test
|
||||
fun testStartRead() {
|
||||
val readLength = BUFFER_SIZE
|
||||
dataSource.readAt(0, outputBuffer, 0, readLength)
|
||||
assertArrayEquals(dataSource.getSliceOfData(0..<readLength), outputBuffer)
|
||||
}
|
||||
|
||||
/**
|
||||
* Make sure that reading from a specified index works.
|
||||
*/
|
||||
@Test
|
||||
fun testSkipForward() {
|
||||
val readLength = BUFFER_SIZE
|
||||
val skipOffset = BUFFER_SIZE
|
||||
val endIndex = skipOffset + readLength
|
||||
dataSource.readAt(skipOffset.toLong(), outputBuffer, 0, readLength)
|
||||
assertArrayEquals(dataSource.getSliceOfData(skipOffset..<endIndex), outputBuffer)
|
||||
}
|
||||
|
||||
/**
|
||||
* "Skipping backwards" actually involves recreating the underlying stream and skipping forwards. This tests that.
|
||||
*/
|
||||
@Test
|
||||
fun testSkipBackward() {
|
||||
val readLength = BUFFER_SIZE
|
||||
val skipOffset = BUFFER_SIZE
|
||||
val skipAheadAmount = skipOffset * 2
|
||||
val endIndex = skipOffset + readLength
|
||||
dataSource.readAt(skipAheadAmount.toLong(), outputBuffer, 0, readLength)
|
||||
|
||||
dataSource.readAt(skipOffset.toLong(), outputBuffer, 0, readLength)
|
||||
assertArrayEquals(dataSource.getSliceOfData(skipOffset..<endIndex), outputBuffer)
|
||||
}
|
||||
|
||||
/**
|
||||
* Successfully read the final n bytes of a stream, even though >n were requested
|
||||
*/
|
||||
@Test
|
||||
fun testReadPastInputStreamSize() {
|
||||
val readLength = 512
|
||||
val distanceFromEnd = readLength / 2
|
||||
val skipOffset = DATA_SIZE - distanceFromEnd
|
||||
val readResult = dataSource.readAt(skipOffset.toLong(), outputBuffer, 0, readLength)
|
||||
|
||||
assertEquals(distanceFromEnd, readResult)
|
||||
assertArrayEquals(dataSource.getSliceOfData(skipOffset..<DATA_SIZE), outputBuffer.sliceArray(0..<distanceFromEnd))
|
||||
}
|
||||
|
||||
/**
|
||||
* Successfully read the final n bytes of a stream, even though >n were requested
|
||||
*/
|
||||
@Test
|
||||
fun testReadUpToEndAndThenKeepReading() {
|
||||
val readLength = 512
|
||||
val distanceFromEnd = readLength / 2
|
||||
val skipOffset = DATA_SIZE - distanceFromEnd
|
||||
|
||||
val readResultLastOfStream = dataSource.readAt(skipOffset.toLong(), outputBuffer, 0, readLength)
|
||||
val readResultAtEndOfStream = dataSource.readAt((skipOffset + readResultLastOfStream).toLong(), outputBuffer, 0, readLength)
|
||||
|
||||
assertEquals(-1, readResultAtEndOfStream)
|
||||
assertArrayEquals(dataSource.getSliceOfData(skipOffset..<DATA_SIZE), outputBuffer.sliceArray(0..<distanceFromEnd))
|
||||
}
|
||||
|
||||
/**
|
||||
* A negative position is outside the stream, should return EOS.
|
||||
*/
|
||||
@Test
|
||||
fun testReadNegativePosition() {
|
||||
val readResult = dataSource.readAt(-128, outputBuffer, 0, BUFFER_SIZE)
|
||||
|
||||
assertEquals(-1, readResult)
|
||||
assertArrayEquals(ByteArray(BUFFER_SIZE), outputBuffer)
|
||||
}
|
||||
|
||||
private class TestInputStreamMediaDataSource(private val data: ByteArray) : InputStreamMediaDataSource() {
|
||||
override fun getSize() = data.size.toLong()
|
||||
|
||||
override fun createInputStream(position: Long): InputStream {
|
||||
val inputStream = ByteArrayInputStream(data)
|
||||
inputStream.skipNBytesCompat(position)
|
||||
return inputStream
|
||||
}
|
||||
|
||||
fun getSliceOfData(indices: IntRange): ByteArray {
|
||||
return data.sliceArray(indices)
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Add table
Reference in a new issue