Fix digests for non-zero padding.
This commit is contained in:
parent
a50f316659
commit
1e8626647e
10 changed files with 256 additions and 62 deletions
|
@ -1,15 +1,23 @@
|
|||
package org.thoughtcrime.securesms.database
|
||||
|
||||
import android.content.Context
|
||||
import android.net.Uri
|
||||
import androidx.test.ext.junit.runners.AndroidJUnit4
|
||||
import androidx.test.filters.FlakyTest
|
||||
import androidx.test.platform.app.InstrumentationRegistry
|
||||
import org.junit.Assert.assertArrayEquals
|
||||
import org.junit.Assert.assertEquals
|
||||
import org.junit.Assert.assertNotEquals
|
||||
import org.junit.Before
|
||||
import org.junit.Ignore
|
||||
import org.junit.Test
|
||||
import org.junit.runner.RunWith
|
||||
import org.signal.core.util.copyTo
|
||||
import org.signal.core.util.readFully
|
||||
import org.signal.core.util.stream.NullOutputStream
|
||||
import org.thoughtcrime.securesms.attachments.Attachment
|
||||
import org.thoughtcrime.securesms.attachments.AttachmentId
|
||||
import org.thoughtcrime.securesms.attachments.PointerAttachment
|
||||
import org.thoughtcrime.securesms.attachments.UriAttachment
|
||||
import org.thoughtcrime.securesms.mms.MediaStream
|
||||
import org.thoughtcrime.securesms.mms.SentMediaQuality
|
||||
|
@ -17,6 +25,15 @@ import org.thoughtcrime.securesms.providers.BlobProvider
|
|||
import org.thoughtcrime.securesms.testing.assertIs
|
||||
import org.thoughtcrime.securesms.testing.assertIsNot
|
||||
import org.thoughtcrime.securesms.util.MediaUtil
|
||||
import org.thoughtcrime.securesms.util.Util
|
||||
import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream
|
||||
import org.whispersystems.signalservice.api.crypto.AttachmentCipherOutputStream
|
||||
import org.whispersystems.signalservice.api.crypto.NoCipherOutputStream
|
||||
import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentPointer
|
||||
import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentRemoteId
|
||||
import org.whispersystems.signalservice.internal.crypto.PaddingInputStream
|
||||
import java.io.ByteArrayOutputStream
|
||||
import java.io.File
|
||||
import java.util.Optional
|
||||
|
||||
@RunWith(AndroidJUnit4::class)
|
||||
|
@ -163,6 +180,91 @@ class AttachmentTableTest {
|
|||
highInfo.file.exists() assertIs true
|
||||
}
|
||||
|
||||
@Test
|
||||
fun finalizeAttachmentAfterDownload_fixDigestOnNonZeroPadding() {
|
||||
// Insert attachment metadata for badly-padded attachment
|
||||
val plaintext = byteArrayOf(1, 2, 3, 4)
|
||||
val key = Util.getSecretBytes(64)
|
||||
val iv = Util.getSecretBytes(16)
|
||||
|
||||
val badlyPaddedPlaintext = PaddingInputStream(plaintext.inputStream(), plaintext.size.toLong()).readFully().also { it[it.size - 1] = 0x42 }
|
||||
val badlyPaddedCiphertext = encryptPrePaddedBytes(badlyPaddedPlaintext, key, iv)
|
||||
val badlyPaddedDigest = getDigest(badlyPaddedCiphertext)
|
||||
|
||||
val cipherFile = getTempFile()
|
||||
cipherFile.writeBytes(badlyPaddedCiphertext)
|
||||
|
||||
val mmsId = -1L
|
||||
val attachmentId = SignalDatabase.attachments.insertAttachmentsForMessage(mmsId, listOf(createAttachmentPointer(key, badlyPaddedDigest, plaintext.size)), emptyList()).values.first()
|
||||
|
||||
// Give data to attachment table
|
||||
val cipherInputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintext.size.toLong(), key, badlyPaddedDigest, null, 4, false)
|
||||
SignalDatabase.attachments.finalizeAttachmentAfterDownload(mmsId, attachmentId, cipherInputStream, iv)
|
||||
|
||||
// Verify the digest has been updated to the properly padded one
|
||||
val properlyPaddedPlaintext = PaddingInputStream(plaintext.inputStream(), plaintext.size.toLong()).readFully()
|
||||
val properlyPaddedCiphertext = encryptPrePaddedBytes(properlyPaddedPlaintext, key, iv)
|
||||
val properlyPaddedDigest = getDigest(properlyPaddedCiphertext)
|
||||
|
||||
val newDigest = SignalDatabase.attachments.getAttachment(attachmentId)!!.remoteDigest!!
|
||||
|
||||
assertArrayEquals(properlyPaddedDigest, newDigest)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun finalizeAttachmentAfterDownload_leaveDigestAloneForAllZeroPadding() {
|
||||
// Insert attachment metadata for properly-padded attachment
|
||||
val plaintext = byteArrayOf(1, 2, 3, 4)
|
||||
val key = Util.getSecretBytes(64)
|
||||
val iv = Util.getSecretBytes(16)
|
||||
|
||||
val paddedPlaintext = PaddingInputStream(plaintext.inputStream(), plaintext.size.toLong()).readFully()
|
||||
val ciphertext = encryptPrePaddedBytes(paddedPlaintext, key, iv)
|
||||
val digest = getDigest(ciphertext)
|
||||
|
||||
val cipherFile = getTempFile()
|
||||
cipherFile.writeBytes(ciphertext)
|
||||
|
||||
val mmsId = -1L
|
||||
val attachmentId = SignalDatabase.attachments.insertAttachmentsForMessage(mmsId, listOf(createAttachmentPointer(key, digest, plaintext.size)), emptyList()).values.first()
|
||||
|
||||
// Give data to attachment table
|
||||
val cipherInputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintext.size.toLong(), key, digest, null, 4, false)
|
||||
SignalDatabase.attachments.finalizeAttachmentAfterDownload(mmsId, attachmentId, cipherInputStream, iv)
|
||||
|
||||
// Verify the digest hasn't changed
|
||||
val newDigest = SignalDatabase.attachments.getAttachment(attachmentId)!!.remoteDigest!!
|
||||
assertArrayEquals(digest, newDigest)
|
||||
}
|
||||
|
||||
private fun createAttachmentPointer(key: ByteArray, digest: ByteArray, size: Int): Attachment {
|
||||
return PointerAttachment.forPointer(
|
||||
pointer = Optional.of(
|
||||
SignalServiceAttachmentPointer(
|
||||
cdnNumber = 3,
|
||||
remoteId = SignalServiceAttachmentRemoteId.V4("asdf"),
|
||||
contentType = MediaUtil.IMAGE_JPEG,
|
||||
key = key,
|
||||
size = Optional.of(size),
|
||||
preview = Optional.empty(),
|
||||
width = 2,
|
||||
height = 2,
|
||||
digest = Optional.of(digest),
|
||||
incrementalDigest = Optional.empty(),
|
||||
incrementalMacChunkSize = 0,
|
||||
fileName = Optional.of("file.jpg"),
|
||||
voiceNote = false,
|
||||
isBorderless = false,
|
||||
isGif = false,
|
||||
caption = Optional.empty(),
|
||||
blurHash = Optional.empty(),
|
||||
uploadTimestamp = 0,
|
||||
uuid = null
|
||||
)
|
||||
)
|
||||
).get()
|
||||
}
|
||||
|
||||
private fun createAttachment(id: Long, uri: Uri, transformProperties: AttachmentTable.TransformProperties): UriAttachment {
|
||||
return UriAttachmentBuilder.build(
|
||||
id,
|
||||
|
@ -179,4 +281,24 @@ class AttachmentTableTest {
|
|||
private fun createMediaStream(byteArray: ByteArray): MediaStream {
|
||||
return MediaStream(byteArray.inputStream(), MediaUtil.IMAGE_JPEG, 2, 2)
|
||||
}
|
||||
|
||||
private fun getDigest(ciphertext: ByteArray): ByteArray {
|
||||
val digestStream = NoCipherOutputStream(NullOutputStream)
|
||||
ciphertext.inputStream().copyTo(digestStream)
|
||||
return digestStream.transmittedDigest
|
||||
}
|
||||
|
||||
private fun encryptPrePaddedBytes(plaintext: ByteArray, key: ByteArray, iv: ByteArray): ByteArray {
|
||||
val outputStream = ByteArrayOutputStream()
|
||||
val cipherStream = AttachmentCipherOutputStream(key, iv, outputStream)
|
||||
plaintext.inputStream().copyTo(cipherStream)
|
||||
|
||||
return outputStream.toByteArray()
|
||||
}
|
||||
|
||||
private fun getTempFile(): File {
|
||||
val dir = InstrumentationRegistry.getInstrumentation().targetContext.getDir("temp", Context.MODE_PRIVATE)
|
||||
dir.mkdir()
|
||||
return File.createTempFile("transfer", ".mms", dir)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -36,6 +36,8 @@ import org.signal.core.util.Base64
|
|||
import org.signal.core.util.SqlUtil
|
||||
import org.signal.core.util.StreamUtil
|
||||
import org.signal.core.util.ThreadUtil
|
||||
import org.signal.core.util.allMatch
|
||||
import org.signal.core.util.copyTo
|
||||
import org.signal.core.util.count
|
||||
import org.signal.core.util.delete
|
||||
import org.signal.core.util.deleteAll
|
||||
|
@ -59,6 +61,8 @@ import org.signal.core.util.requireNonNullString
|
|||
import org.signal.core.util.requireObject
|
||||
import org.signal.core.util.requireString
|
||||
import org.signal.core.util.select
|
||||
import org.signal.core.util.stream.LimitedInputStream
|
||||
import org.signal.core.util.stream.NullOutputStream
|
||||
import org.signal.core.util.toInt
|
||||
import org.signal.core.util.update
|
||||
import org.signal.core.util.withinTransaction
|
||||
|
@ -94,7 +98,9 @@ import org.thoughtcrime.securesms.util.StorageUtil
|
|||
import org.thoughtcrime.securesms.util.Util
|
||||
import org.thoughtcrime.securesms.video.EncryptedMediaDataSource
|
||||
import org.whispersystems.signalservice.api.attachment.AttachmentUploadResult
|
||||
import org.whispersystems.signalservice.api.crypto.AttachmentCipherOutputStream
|
||||
import org.whispersystems.signalservice.api.util.UuidUtil
|
||||
import org.whispersystems.signalservice.internal.crypto.PaddingInputStream
|
||||
import org.whispersystems.signalservice.internal.util.JsonUtil
|
||||
import java.io.File
|
||||
import java.io.FileNotFoundException
|
||||
|
@ -963,14 +969,32 @@ class AttachmentTable(
|
|||
* that the content of the attachment will never change.
|
||||
*/
|
||||
@Throws(MmsException::class)
|
||||
fun finalizeAttachmentAfterDownload(mmsId: Long, attachmentId: AttachmentId, inputStream: InputStream, iv: ByteArray?) {
|
||||
fun finalizeAttachmentAfterDownload(mmsId: Long, attachmentId: AttachmentId, inputStream: LimitedInputStream, iv: ByteArray?) {
|
||||
Log.i(TAG, "[finalizeAttachmentAfterDownload] Finalizing downloaded data for $attachmentId. (MessageId: $mmsId, $attachmentId)")
|
||||
|
||||
val existingPlaceholder: DatabaseAttachment = getAttachment(attachmentId) ?: throw MmsException("No attachment found for id: $attachmentId")
|
||||
|
||||
val fileWriteResult: DataFileWriteResult = writeToDataFile(newDataFile(context), inputStream, TransformProperties.empty())
|
||||
val fileWriteResult: DataFileWriteResult = writeToDataFile(newDataFile(context), inputStream, TransformProperties.empty(), closeInputStream = false)
|
||||
val transferFile: File? = getTransferFile(databaseHelper.signalReadableDatabase, attachmentId)
|
||||
|
||||
val paddingAllZeroes = inputStream.use { limitStream ->
|
||||
limitStream.leftoverStream().allMatch { it == 0x00.toByte() }
|
||||
}
|
||||
|
||||
val digest = if (paddingAllZeroes) {
|
||||
Log.d(TAG, "[finalizeAttachmentAfterDownload] $attachmentId has all-zero padding. Digest is good.")
|
||||
existingPlaceholder.remoteDigest!!
|
||||
} else {
|
||||
Log.w(TAG, "[finalizeAttachmentAfterDownload] $attachmentId has non-zero padding bytes. Recomputing digest.")
|
||||
|
||||
val stream = PaddingInputStream(getDataStream(fileWriteResult.file, fileWriteResult.random, 0), fileWriteResult.length)
|
||||
val key = Base64.decode(existingPlaceholder.remoteKey!!)
|
||||
val cipherOutputStream = AttachmentCipherOutputStream(key, iv, NullOutputStream)
|
||||
|
||||
StreamUtil.copy(stream, cipherOutputStream)
|
||||
cipherOutputStream.transmittedDigest
|
||||
}
|
||||
|
||||
val foundDuplicate = writableDatabase.withinTransaction { db ->
|
||||
// We can look and see if we have any exact matches on hash_ends and dedupe the file if we see one.
|
||||
// We don't look at hash_start here because that could result in us matching on a file that got compressed down to something smaller, effectively lowering
|
||||
|
@ -1013,6 +1037,7 @@ class AttachmentTable(
|
|||
values.put(TRANSFORM_PROPERTIES, TransformProperties.forSkipTransform().serialize())
|
||||
values.put(ARCHIVE_TRANSFER_FILE, null as String?)
|
||||
values.put(REMOTE_IV, iv)
|
||||
values.put(REMOTE_DIGEST, digest)
|
||||
|
||||
db.update(TABLE_NAME)
|
||||
.values(values)
|
||||
|
@ -1878,7 +1903,7 @@ class AttachmentTable(
|
|||
* Reads the entire stream and saves to disk and returns a bunch of metadat about the write.
|
||||
*/
|
||||
@Throws(MmsException::class, IllegalStateException::class)
|
||||
private fun writeToDataFile(destination: File, inputStream: InputStream, transformProperties: TransformProperties): DataFileWriteResult {
|
||||
private fun writeToDataFile(destination: File, inputStream: InputStream, transformProperties: TransformProperties, closeInputStream: Boolean = true): DataFileWriteResult {
|
||||
return try {
|
||||
// Sometimes the destination is a file that's already in use, sometimes it's not.
|
||||
// To avoid writing to a file while it's in-use, we write to a temp file and then rename it to the destination file at the end.
|
||||
|
@ -1890,7 +1915,7 @@ class AttachmentTable(
|
|||
val random = encryptingStreamData.first
|
||||
val encryptingOutputStream = encryptingStreamData.second
|
||||
|
||||
val length = StreamUtil.copy(digestInputStream, encryptingOutputStream)
|
||||
val length = digestInputStream.copyTo(encryptingOutputStream, closeInputStream)
|
||||
val hash = Base64.encodeWithPadding(digestInputStream.messageDigest.digest())
|
||||
|
||||
if (!tempFile.renameTo(destination)) {
|
||||
|
|
|
@ -12,6 +12,7 @@ import org.greenrobot.eventbus.EventBus
|
|||
import org.signal.core.util.Base64
|
||||
import org.signal.core.util.Hex
|
||||
import org.signal.core.util.logging.Log
|
||||
import org.signal.core.util.stream.LimitedInputStream
|
||||
import org.signal.libsignal.protocol.InvalidMacException
|
||||
import org.signal.libsignal.protocol.InvalidMessageException
|
||||
import org.thoughtcrime.securesms.attachments.Attachment
|
||||
|
@ -415,7 +416,12 @@ class AttachmentDownloadJob private constructor(
|
|||
if (body.contentLength() > RemoteConfig.maxAttachmentReceiveSizeBytes) {
|
||||
throw MmsException("Attachment too large, failing download")
|
||||
}
|
||||
SignalDatabase.attachments.finalizeAttachmentAfterDownload(messageId, attachmentId, (body.source() as Source).buffer().inputStream(), iv = null)
|
||||
SignalDatabase.attachments.finalizeAttachmentAfterDownload(
|
||||
messageId,
|
||||
attachmentId,
|
||||
LimitedInputStream.withoutLimits((body.source() as Source).buffer().inputStream()),
|
||||
iv = null
|
||||
)
|
||||
}
|
||||
}
|
||||
} catch (e: MmsException) {
|
||||
|
|
|
@ -5,9 +5,11 @@
|
|||
|
||||
package org.signal.core.util
|
||||
|
||||
import org.signal.core.util.stream.LimitedInputStream
|
||||
import java.io.ByteArrayOutputStream
|
||||
import java.io.IOException
|
||||
import java.io.InputStream
|
||||
import java.io.OutputStream
|
||||
import kotlin.math.min
|
||||
|
||||
/**
|
||||
|
@ -112,3 +114,37 @@ fun InputStream.readLength(): Long {
|
|||
fun InputStream.drain() {
|
||||
this.readLength()
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a [LimitedInputStream] that will limit the number of bytes read from this stream to [limit].
|
||||
*/
|
||||
fun InputStream.limit(limit: Long): LimitedInputStream {
|
||||
return LimitedInputStream(this, limit)
|
||||
}
|
||||
|
||||
/**
|
||||
* Copies the contents of this stream to the [outputStream].
|
||||
*
|
||||
* @param closeInputStream If true, the input stream will be closed after the copy is complete.
|
||||
*/
|
||||
fun InputStream.copyTo(outputStream: OutputStream, closeInputStream: Boolean = true): Long {
|
||||
return StreamUtil.copy(this, outputStream, closeInputStream)
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if every byte in this stream matches the predicate, otherwise false.
|
||||
*/
|
||||
fun InputStream.allMatch(predicate: (Byte) -> Boolean): Boolean {
|
||||
val buffer = ByteArray(4096)
|
||||
|
||||
var readCount: Int
|
||||
while (this.read(buffer).also { readCount = it } != -1) {
|
||||
for (i in 0 until readCount) {
|
||||
if (!predicate(buffer[i])) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
|
|
@ -96,6 +96,10 @@ public final class StreamUtil {
|
|||
}
|
||||
|
||||
public static long copy(InputStream in, OutputStream out) throws IOException {
|
||||
return copy(in, out, true);
|
||||
}
|
||||
|
||||
public static long copy(InputStream in, OutputStream out, boolean closeInputStream) throws IOException {
|
||||
byte[] buffer = new byte[64 * 1024];
|
||||
int read;
|
||||
long total = 0;
|
||||
|
@ -105,7 +109,10 @@ public final class StreamUtil {
|
|||
total += read;
|
||||
}
|
||||
|
||||
in.close();
|
||||
if (closeInputStream) {
|
||||
in.close();
|
||||
}
|
||||
|
||||
out.flush();
|
||||
out.close();
|
||||
|
||||
|
|
|
@ -5,8 +5,6 @@
|
|||
|
||||
package org.signal.core.util.stream
|
||||
|
||||
import org.signal.core.util.readAtMostNBytes
|
||||
import org.signal.core.util.readFully
|
||||
import java.io.FilterInputStream
|
||||
import java.io.InputStream
|
||||
import java.lang.UnsupportedOperationException
|
||||
|
@ -22,8 +20,21 @@ class LimitedInputStream(private val wrapped: InputStream, private val maxBytes:
|
|||
private var totalBytesRead: Long = 0
|
||||
private var lastMark = -1L
|
||||
|
||||
companion object {
|
||||
|
||||
private const val UNLIMITED = -1L
|
||||
|
||||
/**
|
||||
* Returns a [LimitedInputStream] that doesn't limit the stream at all -- it'll allow reading the full thing.
|
||||
*/
|
||||
@JvmStatic
|
||||
fun withoutLimits(wrapped: InputStream): LimitedInputStream {
|
||||
return LimitedInputStream(wrapped = wrapped, maxBytes = UNLIMITED)
|
||||
}
|
||||
}
|
||||
|
||||
override fun read(): Int {
|
||||
if (maxBytes == -1L) {
|
||||
if (maxBytes == UNLIMITED) {
|
||||
return wrapped.read()
|
||||
}
|
||||
|
||||
|
@ -44,7 +55,7 @@ class LimitedInputStream(private val wrapped: InputStream, private val maxBytes:
|
|||
}
|
||||
|
||||
override fun read(destination: ByteArray, offset: Int, length: Int): Int {
|
||||
if (maxBytes == -1L) {
|
||||
if (maxBytes == UNLIMITED) {
|
||||
return wrapped.read(destination, offset, length)
|
||||
}
|
||||
|
||||
|
@ -64,7 +75,7 @@ class LimitedInputStream(private val wrapped: InputStream, private val maxBytes:
|
|||
}
|
||||
|
||||
override fun skip(requestedSkipCount: Long): Long {
|
||||
if (maxBytes == -1L) {
|
||||
if (maxBytes == UNLIMITED) {
|
||||
return wrapped.skip(requestedSkipCount)
|
||||
}
|
||||
|
||||
|
@ -78,7 +89,7 @@ class LimitedInputStream(private val wrapped: InputStream, private val maxBytes:
|
|||
}
|
||||
|
||||
override fun available(): Int {
|
||||
if (maxBytes == -1L) {
|
||||
if (maxBytes == UNLIMITED) {
|
||||
return wrapped.available()
|
||||
}
|
||||
|
||||
|
@ -97,7 +108,7 @@ class LimitedInputStream(private val wrapped: InputStream, private val maxBytes:
|
|||
|
||||
wrapped.mark(readlimit)
|
||||
|
||||
if (maxBytes == -1L) {
|
||||
if (maxBytes == UNLIMITED) {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -109,13 +120,13 @@ class LimitedInputStream(private val wrapped: InputStream, private val maxBytes:
|
|||
throw UnsupportedOperationException("Mark not supported")
|
||||
}
|
||||
|
||||
if (lastMark == -1L) {
|
||||
if (lastMark == UNLIMITED) {
|
||||
throw UnsupportedOperationException("Mark not set")
|
||||
}
|
||||
|
||||
wrapped.reset()
|
||||
|
||||
if (maxBytes == -1L) {
|
||||
if (maxBytes == UNLIMITED) {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -123,24 +134,18 @@ class LimitedInputStream(private val wrapped: InputStream, private val maxBytes:
|
|||
}
|
||||
|
||||
/**
|
||||
* If the stream has been fully read, this will return all bytes that were truncated from the stream.
|
||||
* If the stream was setup with no limit, this will always return an empty array.
|
||||
*
|
||||
* @param byteLimit The maximum number of truncated bytes to read. Defaults to no limit.
|
||||
* If the stream has been fully read, this will return a stream that contains the remaining bytes that were truncated.
|
||||
* If the stream was setup with no limit, this will always return an empty stream.
|
||||
*/
|
||||
fun readTruncatedBytes(byteLimit: Int = -1): ByteArray {
|
||||
if (maxBytes == -1L) {
|
||||
return ByteArray(0)
|
||||
fun leftoverStream(): InputStream {
|
||||
if (maxBytes == UNLIMITED) {
|
||||
return ByteArray(0).inputStream()
|
||||
}
|
||||
|
||||
if (totalBytesRead < maxBytes) {
|
||||
throw IllegalStateException("Stream has not been fully read")
|
||||
}
|
||||
|
||||
return if (byteLimit < 0) {
|
||||
wrapped.readFully()
|
||||
} else {
|
||||
wrapped.readAtMostNBytes(byteLimit)
|
||||
}
|
||||
return wrapped
|
||||
}
|
||||
}
|
||||
|
|
|
@ -22,7 +22,7 @@ class LimitedInputStreamTest {
|
|||
|
||||
@Test
|
||||
fun `when I fully read the stream via a buffer with no limit, I should get all bytes`() {
|
||||
val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = -1)
|
||||
val inputStream = LimitedInputStream.withoutLimits(ByteArray(100).inputStream())
|
||||
val data = inputStream.readFully()
|
||||
|
||||
assertEquals(100, data.size)
|
||||
|
@ -44,7 +44,7 @@ class LimitedInputStreamTest {
|
|||
|
||||
@Test
|
||||
fun `when I fully read the stream one byte at a time with no limit, I should only get maxBytes`() {
|
||||
val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = -1)
|
||||
val inputStream = LimitedInputStream.withoutLimits(ByteArray(100).inputStream())
|
||||
|
||||
var count = 0
|
||||
var lastRead = inputStream.read()
|
||||
|
@ -88,35 +88,26 @@ class LimitedInputStreamTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
fun `when I finish reading the stream, getTruncatedBytes gives me the rest`() {
|
||||
fun `when I finish reading the stream, leftoverStream gives me the rest`() {
|
||||
val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75)
|
||||
inputStream.readFully()
|
||||
|
||||
val truncatedBytes = inputStream.readTruncatedBytes()
|
||||
val truncatedBytes = inputStream.leftoverStream().readFully()
|
||||
assertEquals(25, truncatedBytes.size)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `when I finish reading the stream, getTruncatedBytes gives me the rest, respecting the byte limit`() {
|
||||
val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75)
|
||||
inputStream.readFully()
|
||||
|
||||
val truncatedBytes = inputStream.readTruncatedBytes(byteLimit = 10)
|
||||
assertEquals(10, truncatedBytes.size)
|
||||
}
|
||||
|
||||
@Test(expected = IllegalStateException::class)
|
||||
fun `if I have not finished reading the stream, getTruncatedBytes throws IllegalStateException`() {
|
||||
fun `if I have not finished reading the stream, leftoverStream throws IllegalStateException`() {
|
||||
val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75)
|
||||
inputStream.readTruncatedBytes()
|
||||
inputStream.leftoverStream()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `when call getTruncatedBytes on a stream with no limit, it returns an empty array`() {
|
||||
val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = -1)
|
||||
fun `when call leftoverStream on a stream with no limit, it returns an empty array`() {
|
||||
val inputStream = LimitedInputStream.withoutLimits(ByteArray(100).inputStream())
|
||||
inputStream.readFully()
|
||||
|
||||
val truncatedBytes = inputStream.readTruncatedBytes()
|
||||
val truncatedBytes = inputStream.leftoverStream().readFully()
|
||||
assertEquals(0, truncatedBytes.size)
|
||||
}
|
||||
|
||||
|
@ -130,7 +121,7 @@ class LimitedInputStreamTest {
|
|||
|
||||
@Test
|
||||
fun `when I call available with no limit, it should return the full length`() {
|
||||
val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = -1)
|
||||
val inputStream = LimitedInputStream.withoutLimits(ByteArray(100).inputStream())
|
||||
val available = inputStream.available()
|
||||
|
||||
assertEquals(100, available)
|
||||
|
|
|
@ -10,6 +10,7 @@ import org.signal.core.util.StreamUtil;
|
|||
import org.signal.core.util.concurrent.FutureTransformers;
|
||||
import org.signal.core.util.concurrent.ListenableFuture;
|
||||
import org.signal.core.util.concurrent.SettableFuture;
|
||||
import org.signal.core.util.stream.LimitedInputStream;
|
||||
import org.signal.libsignal.protocol.InvalidMessageException;
|
||||
import org.signal.libsignal.zkgroup.profiles.ClientZkProfileOperations;
|
||||
import org.signal.libsignal.zkgroup.profiles.ProfileKey;
|
||||
|
@ -220,7 +221,7 @@ public class SignalServiceMessageReceiver {
|
|||
StreamUtil.readFully(tempStream, iv);
|
||||
}
|
||||
|
||||
InputStream dataStream = AttachmentCipherInputStream.createForAttachment(
|
||||
LimitedInputStream dataStream = AttachmentCipherInputStream.createForAttachment(
|
||||
attachmentDestination,
|
||||
pointer.getSize().orElse(0),
|
||||
pointer.getKey(),
|
||||
|
|
|
@ -5,12 +5,12 @@
|
|||
|
||||
package org.whispersystems.signalservice.api.attachment
|
||||
|
||||
import java.io.InputStream
|
||||
import org.signal.core.util.stream.LimitedInputStream
|
||||
|
||||
/**
|
||||
* Holds the result of an attachment download.
|
||||
*/
|
||||
class AttachmentDownloadResult(
|
||||
val dataStream: InputStream,
|
||||
val dataStream: LimitedInputStream,
|
||||
val iv: ByteArray
|
||||
)
|
||||
|
|
|
@ -59,7 +59,7 @@ public class AttachmentCipherInputStream extends FilterInputStream {
|
|||
/**
|
||||
* Passing in a null incrementalDigest and/or 0 for the chunk size at the call site disables incremental mac validation.
|
||||
*/
|
||||
public static InputStream createForAttachment(File file, long plaintextLength, byte[] combinedKeyMaterial, byte[] digest, byte[] incrementalDigest, int incrementalMacChunkSize)
|
||||
public static LimitedInputStream createForAttachment(File file, long plaintextLength, byte[] combinedKeyMaterial, byte[] digest, byte[] incrementalDigest, int incrementalMacChunkSize)
|
||||
throws InvalidMessageException, IOException {
|
||||
return createForAttachment(file, plaintextLength, combinedKeyMaterial, digest, incrementalDigest, incrementalMacChunkSize, false);
|
||||
}
|
||||
|
@ -69,7 +69,7 @@ public class AttachmentCipherInputStream extends FilterInputStream {
|
|||
*
|
||||
* Passing in true for ignoreDigest DOES NOT VERIFY THE DIGEST
|
||||
*/
|
||||
public static InputStream createForAttachment(File file, long plaintextLength, byte[] combinedKeyMaterial, byte[] digest, byte[] incrementalDigest, int incrementalMacChunkSize, boolean ignoreDigest)
|
||||
public static LimitedInputStream createForAttachment(File file, long plaintextLength, byte[] combinedKeyMaterial, byte[] digest, byte[] incrementalDigest, int incrementalMacChunkSize, boolean ignoreDigest)
|
||||
throws InvalidMessageException, IOException
|
||||
{
|
||||
return createForAttachment(() -> new FileInputStream(file), file.length(), plaintextLength, combinedKeyMaterial, digest, incrementalDigest, incrementalMacChunkSize, ignoreDigest);
|
||||
|
@ -80,7 +80,7 @@ public class AttachmentCipherInputStream extends FilterInputStream {
|
|||
*
|
||||
* Passing in true for ignoreDigest DOES NOT VERIFY THE DIGEST
|
||||
*/
|
||||
public static InputStream createForAttachment(StreamSupplier streamSupplier, long streamLength, long plaintextLength, byte[] combinedKeyMaterial, byte[] digest, byte[] incrementalDigest, int incrementalMacChunkSize, boolean ignoreDigest)
|
||||
public static LimitedInputStream createForAttachment(StreamSupplier streamSupplier, long streamLength, long plaintextLength, byte[] combinedKeyMaterial, byte[] digest, byte[] incrementalDigest, int incrementalMacChunkSize, boolean ignoreDigest)
|
||||
throws InvalidMessageException, IOException
|
||||
{
|
||||
byte[][] parts = Util.split(combinedKeyMaterial, CIPHER_KEY_SIZE, MAC_KEY_SIZE);
|
||||
|
@ -117,16 +117,16 @@ public class AttachmentCipherInputStream extends FilterInputStream {
|
|||
InputStream inputStream = new AttachmentCipherInputStream(wrappedStream, parts[0], streamLength - BLOCK_SIZE - mac.getMacLength());
|
||||
|
||||
if (plaintextLength != 0) {
|
||||
inputStream = new LimitedInputStream(inputStream, plaintextLength);
|
||||
return new LimitedInputStream(inputStream, plaintextLength);
|
||||
} else {
|
||||
return LimitedInputStream.withoutLimits(inputStream);
|
||||
}
|
||||
|
||||
return inputStream;
|
||||
}
|
||||
|
||||
/**
|
||||
* Decrypt archived media to it's original attachment encrypted blob.
|
||||
*/
|
||||
public static InputStream createForArchivedMedia(BackupKey.MediaKeyMaterial archivedMediaKeyMaterial, File file, long originalCipherTextLength)
|
||||
public static LimitedInputStream createForArchivedMedia(BackupKey.MediaKeyMaterial archivedMediaKeyMaterial, File file, long originalCipherTextLength)
|
||||
throws InvalidMessageException, IOException
|
||||
{
|
||||
Mac mac = initMac(archivedMediaKeyMaterial.getMacKey());
|
||||
|
@ -142,13 +142,13 @@ public class AttachmentCipherInputStream extends FilterInputStream {
|
|||
InputStream inputStream = new AttachmentCipherInputStream(new FileInputStream(file), archivedMediaKeyMaterial.getCipherKey(), file.length() - BLOCK_SIZE - mac.getMacLength());
|
||||
|
||||
if (originalCipherTextLength != 0) {
|
||||
inputStream = new LimitedInputStream(inputStream, originalCipherTextLength);
|
||||
return new LimitedInputStream(inputStream, originalCipherTextLength);
|
||||
} else {
|
||||
return LimitedInputStream.withoutLimits(inputStream);
|
||||
}
|
||||
|
||||
return inputStream;
|
||||
}
|
||||
|
||||
public static InputStream createStreamingForArchivedAttachment(BackupKey.MediaKeyMaterial archivedMediaKeyMaterial, File file, long originalCipherTextLength, long plaintextLength, byte[] combinedKeyMaterial, byte[] digest, byte[] incrementalDigest, int incrementalMacChunkSize)
|
||||
public static LimitedInputStream createStreamingForArchivedAttachment(BackupKey.MediaKeyMaterial archivedMediaKeyMaterial, File file, long originalCipherTextLength, long plaintextLength, byte[] combinedKeyMaterial, byte[] digest, byte[] incrementalDigest, int incrementalMacChunkSize)
|
||||
throws InvalidMessageException, IOException
|
||||
{
|
||||
final InputStream archiveStream = createForArchivedMedia(archivedMediaKeyMaterial, file, originalCipherTextLength);
|
||||
|
@ -179,10 +179,11 @@ public class AttachmentCipherInputStream extends FilterInputStream {
|
|||
InputStream inputStream = new AttachmentCipherInputStream(wrappedStream, parts[0], file.length() - BLOCK_SIZE - mac.getMacLength());
|
||||
|
||||
if (plaintextLength != 0) {
|
||||
inputStream = new LimitedInputStream(inputStream, plaintextLength);
|
||||
return new LimitedInputStream(inputStream, plaintextLength);
|
||||
} else {
|
||||
return LimitedInputStream.withoutLimits(inputStream);
|
||||
}
|
||||
|
||||
return inputStream;
|
||||
}
|
||||
|
||||
public static InputStream createForStickerData(byte[] data, byte[] packKey)
|
||||
|
|
Loading…
Add table
Reference in a new issue