Add IV to the attachment table.

This commit is contained in:
Greyson Parrelli 2024-08-30 12:11:22 -04:00 committed by Cody Henthorne
parent 07289b417b
commit 4b47d38d78
26 changed files with 534 additions and 309 deletions

View file

@ -15,7 +15,6 @@ import org.signal.core.util.Base64
import org.signal.core.util.update import org.signal.core.util.update
import org.thoughtcrime.securesms.attachments.AttachmentId import org.thoughtcrime.securesms.attachments.AttachmentId
import org.thoughtcrime.securesms.attachments.Cdn import org.thoughtcrime.securesms.attachments.Cdn
import org.thoughtcrime.securesms.attachments.PointerAttachment
import org.thoughtcrime.securesms.backup.v2.BackupRepository.getMediaName import org.thoughtcrime.securesms.backup.v2.BackupRepository.getMediaName
import org.thoughtcrime.securesms.database.AttachmentTable.TransformProperties import org.thoughtcrime.securesms.database.AttachmentTable.TransformProperties
import org.thoughtcrime.securesms.keyvalue.SignalStore import org.thoughtcrime.securesms.keyvalue.SignalStore
@ -27,7 +26,9 @@ import org.thoughtcrime.securesms.providers.BlobProvider
import org.thoughtcrime.securesms.recipients.Recipient import org.thoughtcrime.securesms.recipients.Recipient
import org.thoughtcrime.securesms.util.MediaUtil import org.thoughtcrime.securesms.util.MediaUtil
import org.thoughtcrime.securesms.util.Util import org.thoughtcrime.securesms.util.Util
import org.whispersystems.signalservice.api.attachment.AttachmentUploadResult
import org.whispersystems.signalservice.api.backup.MediaId import org.whispersystems.signalservice.api.backup.MediaId
import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentRemoteId
import org.whispersystems.signalservice.api.push.ServiceId import org.whispersystems.signalservice.api.push.ServiceId
import java.io.File import java.io.File
import java.util.UUID import java.util.UUID
@ -661,7 +662,7 @@ class AttachmentTableTest_deduping {
} }
fun upload(attachmentId: AttachmentId, uploadTimestamp: Long = System.currentTimeMillis()) { fun upload(attachmentId: AttachmentId, uploadTimestamp: Long = System.currentTimeMillis()) {
SignalDatabase.attachments.finalizeAttachmentAfterUpload(attachmentId, createPointerAttachment(attachmentId, uploadTimestamp), uploadTimestamp) SignalDatabase.attachments.finalizeAttachmentAfterUpload(attachmentId, createUploadResult(attachmentId, uploadTimestamp))
val attachment = SignalDatabase.attachments.getAttachment(attachmentId)!! val attachment = SignalDatabase.attachments.getAttachment(attachmentId)!!
SignalDatabase.attachments.setArchiveData( SignalDatabase.attachments.setArchiveData(
@ -763,6 +764,7 @@ class AttachmentTableTest_deduping {
assertEquals(lhsAttachment.remoteLocation, rhsAttachment.remoteLocation) assertEquals(lhsAttachment.remoteLocation, rhsAttachment.remoteLocation)
assertEquals(lhsAttachment.remoteKey, rhsAttachment.remoteKey) assertEquals(lhsAttachment.remoteKey, rhsAttachment.remoteKey)
assertArrayEquals(lhsAttachment.remoteIv, rhsAttachment.remoteIv)
assertArrayEquals(lhsAttachment.remoteDigest, rhsAttachment.remoteDigest) assertArrayEquals(lhsAttachment.remoteDigest, rhsAttachment.remoteDigest)
assertArrayEquals(lhsAttachment.incrementalDigest, rhsAttachment.incrementalDigest) assertArrayEquals(lhsAttachment.incrementalDigest, rhsAttachment.incrementalDigest)
assertEquals(lhsAttachment.incrementalMacChunkSize, rhsAttachment.incrementalMacChunkSize) assertEquals(lhsAttachment.incrementalMacChunkSize, rhsAttachment.incrementalMacChunkSize)
@ -796,36 +798,19 @@ class AttachmentTableTest_deduping {
return MediaStream(this.inputStream(), MediaUtil.IMAGE_JPEG, 2, 2) return MediaStream(this.inputStream(), MediaUtil.IMAGE_JPEG, 2, 2)
} }
private fun createPointerAttachment(attachmentId: AttachmentId, uploadTimestamp: Long = System.currentTimeMillis()): PointerAttachment { private fun createUploadResult(attachmentId: AttachmentId, uploadTimestamp: Long = System.currentTimeMillis()): AttachmentUploadResult {
val location = "somewhere-${Random.nextLong()}"
val key = "somekey-${Random.nextLong()}"
val digest = Random.nextBytes(32)
val incrementalDigest = Random.nextBytes(16)
val databaseAttachment = SignalDatabase.attachments.getAttachment(attachmentId)!! val databaseAttachment = SignalDatabase.attachments.getAttachment(attachmentId)!!
return PointerAttachment( return AttachmentUploadResult(
"image/jpeg", remoteId = SignalServiceAttachmentRemoteId.V4("somewhere-${Random.nextLong()}"),
AttachmentTable.TRANSFER_PROGRESS_DONE, cdnNumber = Cdn.CDN_3.cdnNumber,
databaseAttachment.size, // size key = databaseAttachment.remoteKey?.let { Base64.decode(it) } ?: Util.getSecretBytes(64),
null, iv = databaseAttachment.remoteIv ?: Util.getSecretBytes(16),
Cdn.CDN_3, // cdnNumber digest = Random.nextBytes(32),
location, incrementalDigest = Random.nextBytes(16),
key, incrementalDigestChunkSize = 5,
digest, uploadTimestamp = uploadTimestamp,
incrementalDigest, dataSize = databaseAttachment.size
5, // incrementalMacChunkSize
null,
databaseAttachment.voiceNote,
databaseAttachment.borderless,
databaseAttachment.videoGif,
databaseAttachment.width,
databaseAttachment.height,
uploadTimestamp,
databaseAttachment.caption,
databaseAttachment.stickerLocator,
databaseAttachment.blurHash,
databaseAttachment.uuid
) )
} }
} }

View file

@ -13,6 +13,7 @@ import org.junit.Ignore
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
import org.junit.runner.RunWith import org.junit.runner.RunWith
import org.signal.core.util.Base64
import org.signal.core.util.logging.Log import org.signal.core.util.logging.Log
import org.signal.core.util.update import org.signal.core.util.update
import org.signal.core.util.withinTransaction import org.signal.core.util.withinTransaction
@ -33,6 +34,9 @@ import org.thoughtcrime.securesms.testing.assertIsNot
import org.thoughtcrime.securesms.testing.assertIsNotNull import org.thoughtcrime.securesms.testing.assertIsNotNull
import org.thoughtcrime.securesms.testing.assertIsSize import org.thoughtcrime.securesms.testing.assertIsSize
import org.thoughtcrime.securesms.util.IdentityUtil import org.thoughtcrime.securesms.util.IdentityUtil
import org.thoughtcrime.securesms.util.Util
import org.whispersystems.signalservice.api.attachment.AttachmentUploadResult
import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentRemoteId
import java.util.UUID import java.util.UUID
@Suppress("ClassName") @Suppress("ClassName")
@ -574,30 +578,35 @@ class SyncMessageProcessorTest_synchronizeDeleteForMe {
// Has all three // Has all three
SignalDatabase.attachments.finalizeAttachmentAfterUpload( SignalDatabase.attachments.finalizeAttachmentAfterUpload(
id = attachments[0].attachmentId, id = attachments[0].attachmentId,
attachment = attachments[0].copy(digest = byteArrayOf(attachments[0].attachmentId.id.toByte())), uploadResult = attachments[0].toUploadResult(
uploadTimestamp = message1.timestamp + 1 digest = byteArrayOf(attachments[0].attachmentId.id.toByte()),
uploadTimestamp = message1.timestamp + 1
)
) )
// Missing uuid and digest // Missing uuid and digest
SignalDatabase.attachments.finalizeAttachmentAfterUpload( SignalDatabase.attachments.finalizeAttachmentAfterUpload(
id = attachments[1].attachmentId, id = attachments[1].attachmentId,
attachment = attachments[1], uploadResult = attachments[1].toUploadResult(uploadTimestamp = message1.timestamp + 1)
uploadTimestamp = message1.timestamp + 1
) )
// Missing uuid and plain text // Missing uuid and plain text
SignalDatabase.attachments.finalizeAttachmentAfterUpload( SignalDatabase.attachments.finalizeAttachmentAfterUpload(
id = attachments[2].attachmentId, id = attachments[2].attachmentId,
attachment = attachments[2].copy(digest = byteArrayOf(attachments[2].attachmentId.id.toByte())), uploadResult = attachments[2].toUploadResult(
uploadTimestamp = message1.timestamp + 1 digest = byteArrayOf(attachments[2].attachmentId.id.toByte()),
uploadTimestamp = message1.timestamp + 1
)
) )
SignalDatabase.rawDatabase.update(AttachmentTable.TABLE_NAME).values(AttachmentTable.DATA_HASH_END to null).where("${AttachmentTable.ID} = ?", attachments[2].attachmentId).run() SignalDatabase.rawDatabase.update(AttachmentTable.TABLE_NAME).values(AttachmentTable.DATA_HASH_END to null).where("${AttachmentTable.ID} = ?", attachments[2].attachmentId).run()
// Different has all three // Different has all three
SignalDatabase.attachments.finalizeAttachmentAfterUpload( SignalDatabase.attachments.finalizeAttachmentAfterUpload(
id = attachments[3].attachmentId, id = attachments[3].attachmentId,
attachment = attachments[3].copy(digest = byteArrayOf(attachments[3].attachmentId.id.toByte())), uploadResult = attachments[3].toUploadResult(
uploadTimestamp = message1.timestamp + 1 digest = byteArrayOf(attachments[3].attachmentId.id.toByte()),
uploadTimestamp = message1.timestamp + 1
)
) )
attachments = SignalDatabase.attachments.getAttachmentsForMessage(message1.messageId) attachments = SignalDatabase.attachments.getAttachmentsForMessage(message1.messageId)
@ -674,6 +683,7 @@ class SyncMessageProcessorTest_synchronizeDeleteForMe {
cdn = this.cdn, cdn = this.cdn,
location = this.remoteLocation, location = this.remoteLocation,
key = this.remoteKey, key = this.remoteKey,
iv = this.remoteIv,
digest = digest, digest = digest,
incrementalDigest = this.incrementalDigest, incrementalDigest = this.incrementalDigest,
incrementalMacChunkSize = this.incrementalMacChunkSize, incrementalMacChunkSize = this.incrementalMacChunkSize,
@ -700,4 +710,21 @@ class SyncMessageProcessorTest_synchronizeDeleteForMe {
uuid = uuid uuid = uuid
) )
} }
private fun Attachment.toUploadResult(
digest: ByteArray = this.remoteDigest ?: byteArrayOf(),
uploadTimestamp: Long = this.uploadTimestamp
): AttachmentUploadResult {
return AttachmentUploadResult(
remoteId = SignalServiceAttachmentRemoteId.V4(this.remoteLocation ?: "some-location"),
cdnNumber = this.cdn.cdnNumber,
key = this.remoteKey?.let { Base64.decode(it) } ?: Util.getSecretBytes(64),
iv = this.remoteIv ?: Util.getSecretBytes(16),
digest = digest,
incrementalDigest = this.incrementalDigest,
incrementalDigestChunkSize = this.incrementalMacChunkSize,
dataSize = this.size,
uploadTimestamp = uploadTimestamp
)
}
} }

View file

@ -32,6 +32,7 @@ class ArchivedAttachment : Attachment {
size: Long, size: Long,
cdn: Int, cdn: Int,
key: ByteArray, key: ByteArray,
iv: ByteArray?,
cdnKey: String?, cdnKey: String?,
archiveCdn: Int?, archiveCdn: Int?,
archiveMediaName: String, archiveMediaName: String,
@ -60,6 +61,7 @@ class ArchivedAttachment : Attachment {
cdn = Cdn.fromCdnNumber(cdn), cdn = Cdn.fromCdnNumber(cdn),
remoteLocation = cdnKey, remoteLocation = cdnKey,
remoteKey = Base64.encodeWithoutPadding(key), remoteKey = Base64.encodeWithoutPadding(key),
remoteIv = iv,
remoteDigest = digest, remoteDigest = digest,
incrementalDigest = incrementalMac, incrementalDigest = incrementalMac,
fastPreflightId = null, fastPreflightId = null,

View file

@ -37,6 +37,8 @@ abstract class Attachment(
@JvmField @JvmField
val remoteKey: String?, val remoteKey: String?,
@JvmField @JvmField
val remoteIv: ByteArray?,
@JvmField
val remoteDigest: ByteArray?, val remoteDigest: ByteArray?,
@JvmField @JvmField
val incrementalDigest: ByteArray?, val incrementalDigest: ByteArray?,
@ -86,6 +88,7 @@ abstract class Attachment(
cdn = Cdn.deserialize(parcel.readInt()), cdn = Cdn.deserialize(parcel.readInt()),
remoteLocation = parcel.readString(), remoteLocation = parcel.readString(),
remoteKey = parcel.readString(), remoteKey = parcel.readString(),
remoteIv = ParcelUtil.readByteArray(parcel),
remoteDigest = ParcelUtil.readByteArray(parcel), remoteDigest = ParcelUtil.readByteArray(parcel),
incrementalDigest = ParcelUtil.readByteArray(parcel), incrementalDigest = ParcelUtil.readByteArray(parcel),
fastPreflightId = parcel.readString(), fastPreflightId = parcel.readString(),

View file

@ -58,6 +58,7 @@ class DatabaseAttachment : Attachment {
cdn: Cdn, cdn: Cdn,
location: String?, location: String?,
key: String?, key: String?,
iv: ByteArray?,
digest: ByteArray?, digest: ByteArray?,
incrementalDigest: ByteArray?, incrementalDigest: ByteArray?,
incrementalMacChunkSize: Int, incrementalMacChunkSize: Int,
@ -90,6 +91,7 @@ class DatabaseAttachment : Attachment {
cdn = cdn, cdn = cdn,
remoteLocation = location, remoteLocation = location,
remoteKey = key, remoteKey = key,
remoteIv = iv,
remoteDigest = digest, remoteDigest = digest,
incrementalDigest = incrementalDigest, incrementalDigest = incrementalDigest,
fastPreflightId = fastPreflightId, fastPreflightId = fastPreflightId,

View file

@ -3,7 +3,7 @@ package org.thoughtcrime.securesms.attachments
import android.net.Uri import android.net.Uri
import android.os.Parcel import android.os.Parcel
import androidx.annotation.VisibleForTesting import androidx.annotation.VisibleForTesting
import org.signal.core.util.Base64.encodeWithPadding import org.signal.core.util.Base64
import org.thoughtcrime.securesms.blurhash.BlurHash import org.thoughtcrime.securesms.blurhash.BlurHash
import org.thoughtcrime.securesms.database.AttachmentTable import org.thoughtcrime.securesms.database.AttachmentTable
import org.thoughtcrime.securesms.stickers.StickerLocator import org.thoughtcrime.securesms.stickers.StickerLocator
@ -24,6 +24,7 @@ class PointerAttachment : Attachment {
cdn: Cdn, cdn: Cdn,
location: String, location: String,
key: String?, key: String?,
iv: ByteArray?,
digest: ByteArray?, digest: ByteArray?,
incrementalDigest: ByteArray?, incrementalDigest: ByteArray?,
incrementalMacChunkSize: Int, incrementalMacChunkSize: Int,
@ -46,6 +47,7 @@ class PointerAttachment : Attachment {
cdn = cdn, cdn = cdn,
remoteLocation = location, remoteLocation = location,
remoteKey = key, remoteKey = key,
remoteIv = iv,
remoteDigest = digest, remoteDigest = digest,
incrementalDigest = incrementalDigest, incrementalDigest = incrementalDigest,
fastPreflightId = fastPreflightId, fastPreflightId = fastPreflightId,
@ -86,12 +88,17 @@ class PointerAttachment : Attachment {
@JvmStatic @JvmStatic
@JvmOverloads @JvmOverloads
fun forPointer(pointer: Optional<SignalServiceAttachment>, stickerLocator: StickerLocator? = null, fastPreflightId: String? = null, transferState: Int = AttachmentTable.TRANSFER_PROGRESS_PENDING): Optional<Attachment> { fun forPointer(
pointer: Optional<SignalServiceAttachment>,
stickerLocator: StickerLocator? = null,
fastPreflightId: String? = null,
transferState: Int = AttachmentTable.TRANSFER_PROGRESS_PENDING
): Optional<Attachment> {
if (!pointer.isPresent || !pointer.get().isPointer()) { if (!pointer.isPresent || !pointer.get().isPointer()) {
return Optional.empty() return Optional.empty()
} }
val encodedKey: String? = pointer.get().asPointer().key?.let { encodeWithPadding(it) } val encodedKey: String? = pointer.get().asPointer().key?.let { Base64.encodeWithPadding(it) }
return Optional.of( return Optional.of(
PointerAttachment( PointerAttachment(
@ -102,6 +109,7 @@ class PointerAttachment : Attachment {
cdn = Cdn.fromCdnNumber(pointer.get().asPointer().cdnNumber), cdn = Cdn.fromCdnNumber(pointer.get().asPointer().cdnNumber),
location = pointer.get().asPointer().remoteId.toString(), location = pointer.get().asPointer().remoteId.toString(),
key = encodedKey, key = encodedKey,
iv = null,
digest = pointer.get().asPointer().digest.orElse(null), digest = pointer.get().asPointer().digest.orElse(null),
incrementalDigest = pointer.get().asPointer().incrementalDigest.orElse(null), incrementalDigest = pointer.get().asPointer().incrementalDigest.orElse(null),
incrementalMacChunkSize = pointer.get().asPointer().incrementalMacChunkSize, incrementalMacChunkSize = pointer.get().asPointer().incrementalMacChunkSize,
@ -139,7 +147,8 @@ class PointerAttachment : Attachment {
fileName = quotedAttachment.fileName, fileName = quotedAttachment.fileName,
cdn = Cdn.fromCdnNumber(thumbnail?.asPointer()?.cdnNumber ?: 0), cdn = Cdn.fromCdnNumber(thumbnail?.asPointer()?.cdnNumber ?: 0),
location = thumbnail?.asPointer()?.remoteId?.toString() ?: "0", location = thumbnail?.asPointer()?.remoteId?.toString() ?: "0",
key = thumbnail?.asPointer()?.key?.let { encodeWithPadding(it) }, key = thumbnail?.asPointer()?.key?.let { Base64.encodeWithPadding(it) },
iv = null,
digest = thumbnail?.asPointer()?.digest?.orElse(null), digest = thumbnail?.asPointer()?.digest?.orElse(null),
incrementalDigest = thumbnail?.asPointer()?.incrementalDigest?.orElse(null), incrementalDigest = thumbnail?.asPointer()?.incrementalDigest?.orElse(null),
incrementalMacChunkSize = thumbnail?.asPointer()?.incrementalMacChunkSize ?: 0, incrementalMacChunkSize = thumbnail?.asPointer()?.incrementalMacChunkSize ?: 0,

View file

@ -22,6 +22,7 @@ class TombstoneAttachment : Attachment {
cdn = Cdn.CDN_0, cdn = Cdn.CDN_0,
remoteLocation = null, remoteLocation = null,
remoteKey = null, remoteKey = null,
remoteIv = null,
remoteDigest = null, remoteDigest = null,
incrementalDigest = null, incrementalDigest = null,
fastPreflightId = null, fastPreflightId = null,
@ -62,6 +63,7 @@ class TombstoneAttachment : Attachment {
cdn = Cdn.CDN_0, cdn = Cdn.CDN_0,
remoteLocation = null, remoteLocation = null,
remoteKey = null, remoteKey = null,
remoteIv = null,
remoteDigest = null, remoteDigest = null,
incrementalDigest = incrementalMac, incrementalDigest = incrementalMac,
fastPreflightId = null, fastPreflightId = null,

View file

@ -75,6 +75,7 @@ class UriAttachment : Attachment {
cdn = Cdn.CDN_0, cdn = Cdn.CDN_0,
remoteLocation = null, remoteLocation = null,
remoteKey = null, remoteKey = null,
remoteIv = null,
remoteDigest = null, remoteDigest = null,
incrementalDigest = null, incrementalDigest = null,
fastPreflightId = fastPreflightId, fastPreflightId = fastPreflightId,

View file

@ -1030,6 +1030,7 @@ class ChatItemImportInserter(
size = this.backupLocator.size.toLong(), size = this.backupLocator.size.toLong(),
cdn = this.backupLocator.transitCdnNumber ?: Cdn.CDN_0.cdnNumber, cdn = this.backupLocator.transitCdnNumber ?: Cdn.CDN_0.cdnNumber,
key = this.backupLocator.key.toByteArray(), key = this.backupLocator.key.toByteArray(),
iv = null,
cdnKey = this.backupLocator.transitCdnKey, cdnKey = this.backupLocator.transitCdnKey,
archiveCdn = this.backupLocator.cdnNumber, archiveCdn = this.backupLocator.cdnNumber,
archiveMediaName = this.backupLocator.mediaName, archiveMediaName = this.backupLocator.mediaName,

View file

@ -90,7 +90,9 @@ import org.thoughtcrime.securesms.util.FileUtils
import org.thoughtcrime.securesms.util.JsonUtils.SaneJSONObject import org.thoughtcrime.securesms.util.JsonUtils.SaneJSONObject
import org.thoughtcrime.securesms.util.MediaUtil import org.thoughtcrime.securesms.util.MediaUtil
import org.thoughtcrime.securesms.util.StorageUtil import org.thoughtcrime.securesms.util.StorageUtil
import org.thoughtcrime.securesms.util.Util
import org.thoughtcrime.securesms.video.EncryptedMediaDataSource import org.thoughtcrime.securesms.video.EncryptedMediaDataSource
import org.whispersystems.signalservice.api.attachment.AttachmentUploadResult
import org.whispersystems.signalservice.api.util.UuidUtil import org.whispersystems.signalservice.api.util.UuidUtil
import org.whispersystems.signalservice.internal.util.JsonUtil import org.whispersystems.signalservice.internal.util.JsonUtil
import java.io.File import java.io.File
@ -118,6 +120,7 @@ class AttachmentTable(
const val MESSAGE_ID = "message_id" const val MESSAGE_ID = "message_id"
const val CONTENT_TYPE = "content_type" const val CONTENT_TYPE = "content_type"
const val REMOTE_KEY = "remote_key" const val REMOTE_KEY = "remote_key"
const val REMOTE_IV = "remote_iv"
const val REMOTE_LOCATION = "remote_location" const val REMOTE_LOCATION = "remote_location"
const val REMOTE_DIGEST = "remote_digest" const val REMOTE_DIGEST = "remote_digest"
const val REMOTE_INCREMENTAL_DIGEST = "remote_incremental_digest" const val REMOTE_INCREMENTAL_DIGEST = "remote_incremental_digest"
@ -178,6 +181,7 @@ class AttachmentTable(
MESSAGE_ID, MESSAGE_ID,
CONTENT_TYPE, CONTENT_TYPE,
REMOTE_KEY, REMOTE_KEY,
REMOTE_IV,
REMOTE_LOCATION, REMOTE_LOCATION,
REMOTE_DIGEST, REMOTE_DIGEST,
REMOTE_INCREMENTAL_DIGEST, REMOTE_INCREMENTAL_DIGEST,
@ -263,7 +267,8 @@ class AttachmentTable(
$THUMBNAIL_FILE TEXT DEFAULT NULL, $THUMBNAIL_FILE TEXT DEFAULT NULL,
$THUMBNAIL_RANDOM BLOB DEFAULT NULL, $THUMBNAIL_RANDOM BLOB DEFAULT NULL,
$THUMBNAIL_RESTORE_STATE INTEGER DEFAULT ${ThumbnailRestoreState.NONE.value}, $THUMBNAIL_RESTORE_STATE INTEGER DEFAULT ${ThumbnailRestoreState.NONE.value},
$ATTACHMENT_UUID TEXT DEFAULT NULL $ATTACHMENT_UUID TEXT DEFAULT NULL,
$REMOTE_IV BLOB DEFAULT NULL
) )
""" """
@ -1026,7 +1031,7 @@ class AttachmentTable(
* it's ending hash, which is critical for backups. * it's ending hash, which is critical for backups.
*/ */
@Throws(IOException::class) @Throws(IOException::class)
fun finalizeAttachmentAfterUpload(id: AttachmentId, attachment: Attachment, uploadTimestamp: Long) { fun finalizeAttachmentAfterUpload(id: AttachmentId, uploadResult: AttachmentUploadResult) {
Log.i(TAG, "[finalizeAttachmentAfterUpload] Finalizing upload for $id.") Log.i(TAG, "[finalizeAttachmentAfterUpload] Finalizing upload for $id.")
val dataStream = getAttachmentStream(id, 0) val dataStream = getAttachmentStream(id, 0)
@ -1040,17 +1045,14 @@ class AttachmentTable(
val values = contentValuesOf( val values = contentValuesOf(
TRANSFER_STATE to TRANSFER_PROGRESS_DONE, TRANSFER_STATE to TRANSFER_PROGRESS_DONE,
CDN_NUMBER to attachment.cdn.serialize(), CDN_NUMBER to uploadResult.cdnNumber,
REMOTE_LOCATION to attachment.remoteLocation, REMOTE_LOCATION to uploadResult.remoteId.toString(),
REMOTE_DIGEST to attachment.remoteDigest, REMOTE_DIGEST to uploadResult.digest,
REMOTE_INCREMENTAL_DIGEST to attachment.incrementalDigest, REMOTE_INCREMENTAL_DIGEST to uploadResult.incrementalDigest,
REMOTE_INCREMENTAL_DIGEST_CHUNK_SIZE to attachment.incrementalMacChunkSize, REMOTE_INCREMENTAL_DIGEST_CHUNK_SIZE to uploadResult.incrementalDigestChunkSize,
REMOTE_KEY to attachment.remoteKey, DATA_SIZE to uploadResult.dataSize,
DATA_SIZE to attachment.size,
DATA_HASH_END to dataHashEnd, DATA_HASH_END to dataHashEnd,
FAST_PREFLIGHT_ID to attachment.fastPreflightId, UPLOAD_TIMESTAMP to uploadResult.uploadTimestamp
BLUR_HASH to attachment.getVisualHashStringOrNull(),
UPLOAD_TIMESTAMP to uploadTimestamp
) )
val dataFilePath = getDataFilePath(id) ?: throw IOException("No data file found for attachment!") val dataFilePath = getDataFilePath(id) ?: throw IOException("No data file found for attachment!")
@ -1153,6 +1155,25 @@ class AttachmentTable(
} }
} }
fun createKeyIvIfNecessary(attachmentId: AttachmentId) {
val key = Util.getSecretBytes(64)
val iv = Util.getSecretBytes(16)
writableDatabase.withinTransaction {
writableDatabase
.update(TABLE_NAME)
.values(REMOTE_KEY to Base64.encodeWithPadding(key))
.where("$ID = ? AND $REMOTE_KEY IS NULL", attachmentId.id)
.run()
writableDatabase
.update(TABLE_NAME)
.values(REMOTE_IV to iv)
.where("$ID = ? AND $REMOTE_IV IS NULL", attachmentId.id)
.run()
}
}
/** /**
* Inserts new attachments in the table. The [Attachment]s may or may not have data, depending on whether it's an attachment we created locally or some * Inserts new attachments in the table. The [Attachment]s may or may not have data, depending on whether it's an attachment we created locally or some
* inbound attachment that we haven't fetched yet. * inbound attachment that we haven't fetched yet.
@ -1507,6 +1528,7 @@ class AttachmentTable(
cdn = Cdn.deserialize(jsonObject.getInt(CDN_NUMBER)), cdn = Cdn.deserialize(jsonObject.getInt(CDN_NUMBER)),
location = jsonObject.getString(REMOTE_LOCATION), location = jsonObject.getString(REMOTE_LOCATION),
key = jsonObject.getString(REMOTE_KEY), key = jsonObject.getString(REMOTE_KEY),
iv = null,
digest = null, digest = null,
incrementalDigest = null, incrementalDigest = null,
incrementalMacChunkSize = 0, incrementalMacChunkSize = 0,
@ -2040,6 +2062,7 @@ class AttachmentTable(
contentValues.put(REMOTE_INCREMENTAL_DIGEST, uploadTemplate?.incrementalDigest) contentValues.put(REMOTE_INCREMENTAL_DIGEST, uploadTemplate?.incrementalDigest)
contentValues.put(REMOTE_INCREMENTAL_DIGEST_CHUNK_SIZE, uploadTemplate?.incrementalMacChunkSize ?: 0) contentValues.put(REMOTE_INCREMENTAL_DIGEST_CHUNK_SIZE, uploadTemplate?.incrementalMacChunkSize ?: 0)
contentValues.put(REMOTE_KEY, uploadTemplate?.remoteKey) contentValues.put(REMOTE_KEY, uploadTemplate?.remoteKey)
contentValues.put(REMOTE_IV, uploadTemplate?.remoteIv)
contentValues.put(FILE_NAME, StorageUtil.getCleanFileName(attachment.fileName)) contentValues.put(FILE_NAME, StorageUtil.getCleanFileName(attachment.fileName))
contentValues.put(FAST_PREFLIGHT_ID, attachment.fastPreflightId) contentValues.put(FAST_PREFLIGHT_ID, attachment.fastPreflightId)
contentValues.put(VOICE_NOTE, if (attachment.voiceNote) 1 else 0) contentValues.put(VOICE_NOTE, if (attachment.voiceNote) 1 else 0)
@ -2120,6 +2143,7 @@ class AttachmentTable(
cdn = cursor.requireObject(CDN_NUMBER, Cdn.Serializer), cdn = cursor.requireObject(CDN_NUMBER, Cdn.Serializer),
location = cursor.requireString(REMOTE_LOCATION), location = cursor.requireString(REMOTE_LOCATION),
key = cursor.requireString(REMOTE_KEY), key = cursor.requireString(REMOTE_KEY),
iv = cursor.requireBlob(REMOTE_IV),
digest = cursor.requireBlob(REMOTE_DIGEST), digest = cursor.requireBlob(REMOTE_DIGEST),
incrementalDigest = cursor.requireBlob(REMOTE_INCREMENTAL_DIGEST), incrementalDigest = cursor.requireBlob(REMOTE_INCREMENTAL_DIGEST),
incrementalMacChunkSize = cursor.requireInt(REMOTE_INCREMENTAL_DIGEST_CHUNK_SIZE), incrementalMacChunkSize = cursor.requireInt(REMOTE_INCREMENTAL_DIGEST_CHUNK_SIZE),

View file

@ -101,6 +101,7 @@ import org.thoughtcrime.securesms.database.helpers.migration.V240_MessageFullTex
import org.thoughtcrime.securesms.database.helpers.migration.V241_ExpireTimerVersion import org.thoughtcrime.securesms.database.helpers.migration.V241_ExpireTimerVersion
import org.thoughtcrime.securesms.database.helpers.migration.V242_MessageFullTextSearchEmojiSupportV2 import org.thoughtcrime.securesms.database.helpers.migration.V242_MessageFullTextSearchEmojiSupportV2
import org.thoughtcrime.securesms.database.helpers.migration.V243_MessageFullTextSearchDisableSecureDelete import org.thoughtcrime.securesms.database.helpers.migration.V243_MessageFullTextSearchDisableSecureDelete
import org.thoughtcrime.securesms.database.helpers.migration.V244_AttachmentRemoteIv
/** /**
* Contains all of the database migrations for [SignalDatabase]. Broken into a separate file for cleanliness. * Contains all of the database migrations for [SignalDatabase]. Broken into a separate file for cleanliness.
@ -204,10 +205,11 @@ object SignalDatabaseMigrations {
240 to V240_MessageFullTextSearchSecureDelete, 240 to V240_MessageFullTextSearchSecureDelete,
241 to V241_ExpireTimerVersion, 241 to V241_ExpireTimerVersion,
242 to V242_MessageFullTextSearchEmojiSupportV2, 242 to V242_MessageFullTextSearchEmojiSupportV2,
243 to V243_MessageFullTextSearchDisableSecureDelete 243 to V243_MessageFullTextSearchDisableSecureDelete,
244 to V244_AttachmentRemoteIv
) )
const val DATABASE_VERSION = 243 const val DATABASE_VERSION = 244
@JvmStatic @JvmStatic
fun migrate(context: Application, db: SQLiteDatabase, oldVersion: Int, newVersion: Int) { fun migrate(context: Application, db: SQLiteDatabase, oldVersion: Int, newVersion: Int) {

View file

@ -0,0 +1,18 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.thoughtcrime.securesms.database.helpers.migration
import android.app.Application
import net.zetetic.database.sqlcipher.SQLiteDatabase
/**
* Adds the remoteIv column to attachments.
*/
object V244_AttachmentRemoteIv : SignalDatabaseMigration {
override fun migrate(context: Application, db: SQLiteDatabase, oldVersion: Int, newVersion: Int) {
db.execSQL("ALTER TABLE attachment ADD COLUMN remote_iv BLOB DEFAULT NULL;")
}
}

View file

@ -8,11 +8,9 @@ package org.thoughtcrime.securesms.jobs
import org.greenrobot.eventbus.EventBus import org.greenrobot.eventbus.EventBus
import org.signal.core.util.logging.Log import org.signal.core.util.logging.Log
import org.signal.protos.resumableuploads.ResumableUpload import org.signal.protos.resumableuploads.ResumableUpload
import org.thoughtcrime.securesms.attachments.Attachment
import org.thoughtcrime.securesms.attachments.AttachmentId import org.thoughtcrime.securesms.attachments.AttachmentId
import org.thoughtcrime.securesms.attachments.AttachmentUploadUtil import org.thoughtcrime.securesms.attachments.AttachmentUploadUtil
import org.thoughtcrime.securesms.attachments.DatabaseAttachment import org.thoughtcrime.securesms.attachments.DatabaseAttachment
import org.thoughtcrime.securesms.attachments.PointerAttachment
import org.thoughtcrime.securesms.backup.v2.BackupRepository import org.thoughtcrime.securesms.backup.v2.BackupRepository
import org.thoughtcrime.securesms.backup.v2.BackupV2Event import org.thoughtcrime.securesms.backup.v2.BackupV2Event
import org.thoughtcrime.securesms.database.AttachmentTable import org.thoughtcrime.securesms.database.AttachmentTable
@ -24,9 +22,8 @@ import org.thoughtcrime.securesms.jobs.protos.ArchiveAttachmentBackfillJobData
import org.whispersystems.signalservice.api.NetworkResult import org.whispersystems.signalservice.api.NetworkResult
import org.whispersystems.signalservice.api.archive.ArchiveMediaResponse import org.whispersystems.signalservice.api.archive.ArchiveMediaResponse
import org.whispersystems.signalservice.api.archive.ArchiveMediaUploadFormStatusCodes import org.whispersystems.signalservice.api.archive.ArchiveMediaUploadFormStatusCodes
import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentPointer import org.whispersystems.signalservice.api.attachment.AttachmentUploadResult
import java.io.IOException import java.io.IOException
import java.util.Optional
import kotlin.time.Duration.Companion.days import kotlin.time.Duration.Companion.days
/** /**
@ -159,16 +156,16 @@ class ArchiveAttachmentBackfillJob private constructor(
} }
Log.d(TAG, "Beginning upload...") Log.d(TAG, "Beginning upload...")
val remoteAttachment: SignalServiceAttachmentPointer = try { val attachmentApi = AppDependencies.signalServiceMessageSender.attachmentApi
AppDependencies.signalServiceMessageSender.uploadAttachment(attachmentStream) val uploadResult: AttachmentUploadResult = when (val result = attachmentApi.uploadAttachmentV4(attachmentStream)) {
} catch (e: IOException) { is NetworkResult.Success -> result.result
Log.w(TAG, "Failed to upload $attachmentId", e) is NetworkResult.ApplicationError -> throw result.throwable
return Result.retry(defaultBackoff()) is NetworkResult.NetworkError -> return Result.retry(defaultBackoff())
is NetworkResult.StatusCodeError -> return Result.retry(defaultBackoff())
} }
Log.d(TAG, "Upload complete!") Log.d(TAG, "Upload complete!")
val pointerAttachment: Attachment = PointerAttachment.forPointer(Optional.of(remoteAttachment), null, attachmentRecord.fastPreflightId).get() SignalDatabase.attachments.finalizeAttachmentAfterUpload(attachmentRecord.attachmentId, uploadResult)
SignalDatabase.attachments.finalizeAttachmentAfterUpload(attachmentRecord.attachmentId, pointerAttachment, remoteAttachment.uploadTimestamp)
SignalDatabase.attachments.setArchiveTransferState(attachmentRecord.attachmentId, AttachmentTable.ArchiveTransferState.BACKFILL_UPLOADED) SignalDatabase.attachments.setArchiveTransferState(attachmentRecord.attachmentId, AttachmentTable.ArchiveTransferState.BACKFILL_UPLOADED)
attachmentRecord = SignalDatabase.attachments.getAttachment(attachmentRecord.attachmentId) attachmentRecord = SignalDatabase.attachments.getAttachment(attachmentRecord.attachmentId)

View file

@ -7,6 +7,7 @@ package org.thoughtcrime.securesms.jobs
import android.text.TextUtils import android.text.TextUtils
import okhttp3.internal.http2.StreamResetException import okhttp3.internal.http2.StreamResetException
import org.greenrobot.eventbus.EventBus import org.greenrobot.eventbus.EventBus
import org.signal.core.util.Base64
import org.signal.core.util.concurrent.SignalExecutors import org.signal.core.util.concurrent.SignalExecutors
import org.signal.core.util.inRoundedDays import org.signal.core.util.inRoundedDays
import org.signal.core.util.logging.Log import org.signal.core.util.logging.Log
@ -17,7 +18,6 @@ import org.thoughtcrime.securesms.attachments.Attachment
import org.thoughtcrime.securesms.attachments.AttachmentId import org.thoughtcrime.securesms.attachments.AttachmentId
import org.thoughtcrime.securesms.attachments.AttachmentUploadUtil import org.thoughtcrime.securesms.attachments.AttachmentUploadUtil
import org.thoughtcrime.securesms.attachments.DatabaseAttachment import org.thoughtcrime.securesms.attachments.DatabaseAttachment
import org.thoughtcrime.securesms.attachments.PointerAttachment
import org.thoughtcrime.securesms.database.AttachmentTable import org.thoughtcrime.securesms.database.AttachmentTable
import org.thoughtcrime.securesms.database.SignalDatabase import org.thoughtcrime.securesms.database.SignalDatabase
import org.thoughtcrime.securesms.dependencies.AppDependencies import org.thoughtcrime.securesms.dependencies.AppDependencies
@ -32,6 +32,7 @@ import org.thoughtcrime.securesms.net.NotPushRegisteredException
import org.thoughtcrime.securesms.recipients.Recipient import org.thoughtcrime.securesms.recipients.Recipient
import org.thoughtcrime.securesms.service.AttachmentProgressService import org.thoughtcrime.securesms.service.AttachmentProgressService
import org.thoughtcrime.securesms.util.RemoteConfig import org.thoughtcrime.securesms.util.RemoteConfig
import org.whispersystems.signalservice.api.attachment.AttachmentUploadResult
import org.whispersystems.signalservice.api.crypto.AttachmentCipherStreamUtil import org.whispersystems.signalservice.api.crypto.AttachmentCipherStreamUtil
import org.whispersystems.signalservice.api.messages.SignalServiceAttachment import org.whispersystems.signalservice.api.messages.SignalServiceAttachment
import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentStream import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentStream
@ -39,7 +40,6 @@ import org.whispersystems.signalservice.api.push.exceptions.NonSuccessfulResumab
import org.whispersystems.signalservice.api.push.exceptions.ResumeLocationInvalidException import org.whispersystems.signalservice.api.push.exceptions.ResumeLocationInvalidException
import org.whispersystems.signalservice.internal.crypto.PaddingInputStream import org.whispersystems.signalservice.internal.crypto.PaddingInputStream
import java.io.IOException import java.io.IOException
import java.util.Optional
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import kotlin.time.Duration.Companion.days import kotlin.time.Duration.Companion.days
import kotlin.time.Duration.Companion.milliseconds import kotlin.time.Duration.Companion.milliseconds
@ -136,7 +136,10 @@ class AttachmentUploadJob private constructor(
throw NotPushRegisteredException() throw NotPushRegisteredException()
} }
SignalDatabase.attachments.createKeyIvIfNecessary(attachmentId)
val messageSender = AppDependencies.signalServiceMessageSender val messageSender = AppDependencies.signalServiceMessageSender
val attachmentApi = messageSender.attachmentApi
val databaseAttachment = SignalDatabase.attachments.getAttachment(attachmentId) ?: throw InvalidAttachmentException("Cannot find the specified attachment.") val databaseAttachment = SignalDatabase.attachments.getAttachment(attachmentId) ?: throw InvalidAttachmentException("Cannot find the specified attachment.")
val timeSinceUpload = System.currentTimeMillis() - databaseAttachment.uploadTimestamp val timeSinceUpload = System.currentTimeMillis() - databaseAttachment.uploadTimestamp
@ -154,7 +157,17 @@ class AttachmentUploadJob private constructor(
if (uploadSpec == null) { if (uploadSpec == null) {
Log.d(TAG, "Need an upload spec. Fetching...") Log.d(TAG, "Need an upload spec. Fetching...")
uploadSpec = AppDependencies.signalServiceMessageSender.getResumableUploadSpec().toProto() uploadSpec = attachmentApi
.getAttachmentV4UploadForm()
.then { form ->
attachmentApi.getResumableUploadSpec(
key = Base64.decode(databaseAttachment.remoteKey!!),
iv = databaseAttachment.remoteIv!!,
uploadForm = form
)
}
.successOrThrow()
.toProto()
} else { } else {
Log.d(TAG, "Re-using existing upload spec.") Log.d(TAG, "Re-using existing upload spec.")
} }
@ -163,9 +176,8 @@ class AttachmentUploadJob private constructor(
try { try {
getAttachmentNotificationIfNeeded(databaseAttachment).use { notification -> getAttachmentNotificationIfNeeded(databaseAttachment).use { notification ->
buildAttachmentStream(databaseAttachment, notification, uploadSpec!!).use { localAttachment -> buildAttachmentStream(databaseAttachment, notification, uploadSpec!!).use { localAttachment ->
val remoteAttachment = messageSender.uploadAttachment(localAttachment) val uploadResult: AttachmentUploadResult = attachmentApi.uploadAttachmentV4(localAttachment).successOrThrow()
val attachment = PointerAttachment.forPointer(Optional.of(remoteAttachment), null, databaseAttachment.fastPreflightId).get() SignalDatabase.attachments.finalizeAttachmentAfterUpload(databaseAttachment.attachmentId, uploadResult)
SignalDatabase.attachments.finalizeAttachmentAfterUpload(databaseAttachment.attachmentId, attachment, remoteAttachment.uploadTimestamp)
ArchiveThumbnailUploadJob.enqueueIfNecessary(databaseAttachment.attachmentId) ArchiveThumbnailUploadJob.enqueueIfNecessary(databaseAttachment.attachmentId)
} }
} }

View file

@ -240,6 +240,7 @@ class UploadDependencyGraphTest {
cdn = attachment.cdn, cdn = attachment.cdn,
location = attachment.remoteLocation, location = attachment.remoteLocation,
key = attachment.remoteKey, key = attachment.remoteKey,
iv = attachment.remoteIv,
digest = attachment.remoteDigest, digest = attachment.remoteDigest,
incrementalDigest = attachment.incrementalDigest, incrementalDigest = attachment.incrementalDigest,
incrementalMacChunkSize = attachment.incrementalMacChunkSize, incrementalMacChunkSize = attachment.incrementalMacChunkSize,

View file

@ -40,6 +40,7 @@ object FakeMessageRecords {
cdnNumber: Int = 3, cdnNumber: Int = 3,
location: String = "", location: String = "",
key: String = "", key: String = "",
iv: ByteArray = byteArrayOf(),
relay: String = "", relay: String = "",
digest: ByteArray = byteArrayOf(), digest: ByteArray = byteArrayOf(),
incrementalDigest: ByteArray = byteArrayOf(), incrementalDigest: ByteArray = byteArrayOf(),
@ -67,42 +68,43 @@ object FakeMessageRecords {
thumbnailRestoreState: AttachmentTable.ThumbnailRestoreState = AttachmentTable.ThumbnailRestoreState.NONE thumbnailRestoreState: AttachmentTable.ThumbnailRestoreState = AttachmentTable.ThumbnailRestoreState.NONE
): DatabaseAttachment { ): DatabaseAttachment {
return DatabaseAttachment( return DatabaseAttachment(
attachmentId, attachmentId = attachmentId,
mmsId, mmsId = mmsId,
hasData, hasData = hasData,
hasThumbnail, hasThumbnail = hasThumbnail,
hasArchiveThumbnail, hasArchiveThumbnail = hasArchiveThumbnail,
contentType, contentType = contentType,
transferProgress, transferProgress = transferProgress,
size, size = size,
fileName, fileName = fileName,
Cdn.fromCdnNumber(cdnNumber), cdn = Cdn.fromCdnNumber(cdnNumber),
location, location = location,
key, key = key,
digest, iv = iv,
incrementalDigest, digest = digest,
incrementalMacChunkSize, incrementalDigest = incrementalDigest,
fastPreflightId, incrementalMacChunkSize = incrementalMacChunkSize,
voiceNote, fastPreflightId = fastPreflightId,
borderless, voiceNote = voiceNote,
videoGif, borderless = borderless,
width, videoGif = videoGif,
height, width = width,
quote, height = height,
caption, quote = quote,
stickerLocator, caption = caption,
blurHash, stickerLocator = stickerLocator,
audioHash, blurHash = blurHash,
transformProperties, audioHash = audioHash,
displayOrder, transformProperties = transformProperties,
uploadTimestamp, displayOrder = displayOrder,
dataHash, uploadTimestamp = uploadTimestamp,
archiveCdn, dataHash = dataHash,
archiveThumbnailCdn, archiveCdn = archiveCdn,
archiveMediaId, archiveThumbnailCdn = archiveThumbnailCdn,
archiveMediaName, archiveMediaName = archiveMediaId,
thumbnailRestoreState, archiveMediaId = archiveMediaName,
null thumbnailRestoreState = thumbnailRestoreState,
uuid = null
) )
} }

View file

@ -6,7 +6,11 @@
package org.whispersystems.signalservice.api package org.whispersystems.signalservice.api
import org.whispersystems.signalservice.api.push.exceptions.NonSuccessfulResponseCodeException import org.whispersystems.signalservice.api.push.exceptions.NonSuccessfulResponseCodeException
import org.whispersystems.signalservice.internal.util.JsonUtil
import org.whispersystems.signalservice.internal.websocket.WebSocketRequestMessage
import org.whispersystems.signalservice.internal.websocket.WebsocketResponse
import java.io.IOException import java.io.IOException
import kotlin.reflect.KClass
typealias StatusCodeErrorAction = (NetworkResult.StatusCodeError<*>) -> Unit typealias StatusCodeErrorAction = (NetworkResult.StatusCodeError<*>) -> Unit
@ -43,6 +47,50 @@ sealed class NetworkResult<T>(
} catch (e: Throwable) { } catch (e: Throwable) {
ApplicationError(e) ApplicationError(e)
} }
/**
* A convenience method to convert a websocket request into a network result with simple conversion of the response body to the desired class.
* Common exceptions will be caught and translated to errors.
*/
@JvmStatic
fun <T : Any> fromWebSocketRequest(
signalWebSocket: SignalWebSocket,
request: WebSocketRequestMessage,
clazz: KClass<T>
): NetworkResult<T> = try {
val result = signalWebSocket.request(request)
.map { response: WebsocketResponse -> JsonUtil.fromJson(response.body, clazz.java) }
.blockingGet()
Success(result)
} catch (e: NonSuccessfulResponseCodeException) {
StatusCodeError(e)
} catch (e: IOException) {
NetworkError(e)
} catch (e: Throwable) {
ApplicationError(e)
}
/**
* A convenience method to convert a websocket request into a network result with the ability to convert the response to your target class.
* Common exceptions will be caught and translated to errors.
*/
@JvmStatic
fun <T : Any> fromWebSocketRequest(
signalWebSocket: SignalWebSocket,
request: WebSocketRequestMessage,
webSocketResponseConverter: WebSocketResponseConverter<T>
): NetworkResult<T> = try {
val result = signalWebSocket.request(request)
.map { response: WebsocketResponse -> webSocketResponseConverter.convert(response) }
.blockingGet()
Success(result)
} catch (e: NonSuccessfulResponseCodeException) {
StatusCodeError(e)
} catch (e: IOException) {
NetworkError(e)
} catch (e: Throwable) {
ApplicationError(e)
}
} }
/** Indicates the request was successful */ /** Indicates the request was successful */
@ -105,6 +153,34 @@ sealed class NetworkResult<T>(
} }
} }
/**
* Provides the ability to fallback to [fromFetch] if the current [NetworkResult] is non-successful.
*
* The [fallback] will only be triggered on non-[Success] results. You can provide a [unless] to limit what kinds of errors you fallback on
* (the default is to fallback on every error).
*
* This primary usecase of this is to make a websocket request (see [fromWebSocketRequest]) and fallback to rest upon failure.
*
* ```kotlin
* val user: NetworkResult<LocalUserModel> = NetworkResult
* .fromWebSocketRequest(websocket, request, LocalUserMode.class.java)
* .fallbackTo { result -> NetworkResult.fromFetch { http.getUser() } }
* ```
*
* @param unless If this lamba returns true, the fallback will not be triggered.
*/
fun fallbackToFetch(unless: (NetworkResult<T>) -> Boolean = { false }, fallback: Fetcher<T>): NetworkResult<T> {
if (this is Success) {
return this
}
return if (unless(this)) {
fromFetch(fallback)
} else {
this
}
}
/** /**
* Takes the output of one [NetworkResult] and passes it as the input to another if the operation is successful. * Takes the output of one [NetworkResult] and passes it as the input to another if the operation is successful.
* If it's non-successful, the [result] lambda is not run, and instead the original failure will be propagated. * If it's non-successful, the [result] lambda is not run, and instead the original failure will be propagated.
@ -183,4 +259,9 @@ sealed class NetworkResult<T>(
@Throws(Exception::class) @Throws(Exception::class)
fun fetch(): T fun fetch(): T
} }
fun interface WebSocketResponseConverter<T> {
@Throws(Exception::class)
fun convert(response: WebsocketResponse): T
}
} }

View file

@ -23,6 +23,7 @@ import org.signal.libsignal.protocol.state.SessionRecord;
import org.signal.libsignal.protocol.util.Pair; import org.signal.libsignal.protocol.util.Pair;
import org.signal.libsignal.zkgroup.groupsend.GroupSendFullToken; import org.signal.libsignal.zkgroup.groupsend.GroupSendFullToken;
import org.signal.libsignal.zkgroup.profiles.ClientZkProfileOperations; import org.signal.libsignal.zkgroup.profiles.ClientZkProfileOperations;
import org.whispersystems.signalservice.api.attachment.AttachmentApi;
import org.whispersystems.signalservice.api.crypto.AttachmentCipherStreamUtil; import org.whispersystems.signalservice.api.crypto.AttachmentCipherStreamUtil;
import org.whispersystems.signalservice.api.crypto.ContentHint; import org.whispersystems.signalservice.api.crypto.ContentHint;
import org.whispersystems.signalservice.api.crypto.EnvelopeContent; import org.whispersystems.signalservice.api.crypto.EnvelopeContent;
@ -170,6 +171,7 @@ public class SignalServiceMessageSender {
private static final int RETRY_COUNT = 4; private static final int RETRY_COUNT = 4;
private final PushServiceSocket socket; private final PushServiceSocket socket;
private final SignalWebSocket webSocket;
private final SignalServiceAccountDataStore aciStore; private final SignalServiceAccountDataStore aciStore;
private final SignalSessionLock sessionLock; private final SignalSessionLock sessionLock;
private final SignalServiceAddress localAddress; private final SignalServiceAddress localAddress;
@ -198,6 +200,7 @@ public class SignalServiceMessageSender {
boolean automaticNetworkRetry) boolean automaticNetworkRetry)
{ {
this.socket = new PushServiceSocket(urls, credentialsProvider, signalAgent, clientZkProfileOperations, automaticNetworkRetry); this.socket = new PushServiceSocket(urls, credentialsProvider, signalAgent, clientZkProfileOperations, automaticNetworkRetry);
this.webSocket = signalWebSocket;
this.aciStore = store.aci(); this.aciStore = store.aci();
this.sessionLock = sessionLock; this.sessionLock = sessionLock;
this.localAddress = new SignalServiceAddress(credentialsProvider.getAci(), credentialsProvider.getE164()); this.localAddress = new SignalServiceAddress(credentialsProvider.getAci(), credentialsProvider.getE164());
@ -212,6 +215,10 @@ public class SignalServiceMessageSender {
this.scheduler = Schedulers.from(executor, false, false); this.scheduler = Schedulers.from(executor, false, false);
} }
public AttachmentApi getAttachmentApi() {
return AttachmentApi.create(webSocket, socket);
}
/** /**
* Send a read receipt for a received message. * Send a read receipt for a received message.
* *
@ -799,8 +806,8 @@ public class SignalServiceMessageSender {
} }
public SignalServiceAttachmentPointer uploadAttachment(SignalServiceAttachmentStream attachment) throws IOException { public SignalServiceAttachmentPointer uploadAttachment(SignalServiceAttachmentStream attachment) throws IOException {
byte[] attachmentKey = attachment.getResumableUploadSpec().map(ResumableUploadSpec::getSecretKey).orElseGet(() -> Util.getSecretBytes(64)); byte[] attachmentKey = attachment.getResumableUploadSpec().map(ResumableUploadSpec::getAttachmentKey).orElseGet(() -> Util.getSecretBytes(64));
byte[] attachmentIV = attachment.getResumableUploadSpec().map(ResumableUploadSpec::getIV).orElseGet(() -> Util.getSecretBytes(16)); byte[] attachmentIV = attachment.getResumableUploadSpec().map(ResumableUploadSpec::getAttachmentIv).orElseGet(() -> Util.getSecretBytes(16));
long paddedLength = PaddingInputStream.getPaddedSize(attachment.getLength()); long paddedLength = PaddingInputStream.getPaddedSize(attachment.getLength());
InputStream dataStream = new PaddingInputStream(attachment.getInputStream(), attachment.getLength()); InputStream dataStream = new PaddingInputStream(attachment.getInputStream(), attachment.getLength());
long ciphertextLength = AttachmentCipherStreamUtil.getCiphertextLength(paddedLength); long ciphertextLength = AttachmentCipherStreamUtil.getCiphertextLength(paddedLength);
@ -811,7 +818,7 @@ public class SignalServiceMessageSender {
new AttachmentCipherOutputStreamFactory(attachmentKey, attachmentIV), new AttachmentCipherOutputStreamFactory(attachmentKey, attachmentIV),
attachment.getListener(), attachment.getListener(),
attachment.getCancelationSignal(), attachment.getCancelationSignal(),
attachment.getResumableUploadSpec().orElse(null)); attachment.getResumableUploadSpec().get());
if (attachment.getResumableUploadSpec().isEmpty()) { if (attachment.getResumableUploadSpec().isEmpty()) {
throw new IllegalStateException("Attachment must have a resumable upload spec."); throw new IllegalStateException("Attachment must have a resumable upload spec.");

View file

@ -0,0 +1,120 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.signalservice.api.attachment
import org.whispersystems.signalservice.api.NetworkResult
import org.whispersystems.signalservice.api.SignalWebSocket
import org.whispersystems.signalservice.api.crypto.AttachmentCipherStreamUtil
import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentRemoteId
import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentStream
import org.whispersystems.signalservice.internal.crypto.PaddingInputStream
import org.whispersystems.signalservice.internal.push.AttachmentUploadForm
import org.whispersystems.signalservice.internal.push.PushAttachmentData
import org.whispersystems.signalservice.internal.push.PushServiceSocket
import org.whispersystems.signalservice.internal.push.http.AttachmentCipherOutputStreamFactory
import org.whispersystems.signalservice.internal.push.http.ResumableUploadSpec
import org.whispersystems.signalservice.internal.websocket.WebSocketRequestMessage
import java.io.InputStream
import java.security.SecureRandom
/**
* Class to interact with various attachment-related endpoints.
*/
class AttachmentApi(
private val signalWebSocket: SignalWebSocket,
private val pushServiceSocket: PushServiceSocket
) {
companion object {
@JvmStatic
fun create(signalWebSocket: SignalWebSocket, pushServiceSocket: PushServiceSocket): AttachmentApi {
return AttachmentApi(signalWebSocket, pushServiceSocket)
}
}
/**
* Gets a v4 attachment upload form, which provides the necessary information to upload an attachment.
*/
fun getAttachmentV4UploadForm(): NetworkResult<AttachmentUploadForm> {
val request = WebSocketRequestMessage(
id = SecureRandom().nextLong(),
verb = "GET",
path = "/v4/attachments/form/upload"
)
return NetworkResult
.fromWebSocketRequest(signalWebSocket, request, AttachmentUploadForm::class)
.fallbackToFetch(
unless = { it is NetworkResult.StatusCodeError && it.code == 209 },
fallback = { pushServiceSocket.attachmentV4UploadAttributes }
)
}
/**
* Gets a resumable upload spec, which can be saved and re-used across upload attempts to resume upload progress.
*/
fun getResumableUploadSpec(key: ByteArray, iv: ByteArray, uploadForm: AttachmentUploadForm): NetworkResult<ResumableUploadSpec> {
return getResumableUploadUrl(uploadForm)
.map { url ->
ResumableUploadSpec(
attachmentKey = key,
attachmentIv = iv,
cdnKey = uploadForm.key,
cdnNumber = uploadForm.cdn,
resumeLocation = url,
expirationTimestamp = System.currentTimeMillis() + PushServiceSocket.CDN2_RESUMABLE_LINK_LIFETIME_MILLIS,
headers = uploadForm.headers
)
}
}
/**
* Uploads an attachment using the v4 upload scheme.
*/
fun uploadAttachmentV4(attachmentStream: SignalServiceAttachmentStream): NetworkResult<AttachmentUploadResult> {
if (attachmentStream.resumableUploadSpec.isEmpty) {
throw IllegalStateException("Attachment must have a resumable upload spec!")
}
return NetworkResult.fromFetch {
val resumableUploadSpec = attachmentStream.resumableUploadSpec.get()
val paddedLength = PaddingInputStream.getPaddedSize(attachmentStream.length)
val dataStream: InputStream = PaddingInputStream(attachmentStream.inputStream, attachmentStream.length)
val ciphertextLength = AttachmentCipherStreamUtil.getCiphertextLength(paddedLength)
val attachmentData = PushAttachmentData(
contentType = attachmentStream.contentType,
data = dataStream,
dataSize = ciphertextLength,
incremental = attachmentStream.isFaststart,
outputStreamFactory = AttachmentCipherOutputStreamFactory(resumableUploadSpec.attachmentKey, resumableUploadSpec.attachmentIv),
listener = attachmentStream.listener,
cancelationSignal = attachmentStream.cancelationSignal,
resumableUploadSpec = attachmentStream.resumableUploadSpec.get()
)
val digestInfo = pushServiceSocket.uploadAttachment(attachmentData)
AttachmentUploadResult(
remoteId = SignalServiceAttachmentRemoteId.V4(attachmentData.resumableUploadSpec.cdnKey),
cdnNumber = attachmentData.resumableUploadSpec.cdnNumber,
key = resumableUploadSpec.attachmentKey,
iv = resumableUploadSpec.attachmentIv,
digest = digestInfo.digest,
incrementalDigest = digestInfo.incrementalDigest,
incrementalDigestChunkSize = digestInfo.incrementalMacChunkSize,
uploadTimestamp = attachmentStream.uploadTimestamp,
dataSize = attachmentData.dataSize
)
}
}
private fun getResumableUploadUrl(uploadForm: AttachmentUploadForm): NetworkResult<String> {
return NetworkResult.fromFetch {
pushServiceSocket.getResumableUploadUrl(uploadForm)
}
}
}

View file

@ -0,0 +1,23 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.signalservice.api.attachment
import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentRemoteId
/**
* The result of uploading an attachment. Just the additional metadata related to the upload itself.
*/
class AttachmentUploadResult(
val remoteId: SignalServiceAttachmentRemoteId,
val cdnNumber: Int,
val key: ByteArray,
val iv: ByteArray,
val digest: ByteArray,
val incrementalDigest: ByteArray?,
val incrementalDigestChunkSize: Int,
val dataSize: Long,
val uploadTimestamp: Long
)

View file

@ -5,4 +5,8 @@
package org.whispersystems.signalservice.internal.crypto package org.whispersystems.signalservice.internal.crypto
data class AttachmentDigest(val digest: ByteArray, val incrementalDigest: ByteArray?, val incrementalMacChunkSize: Int) data class AttachmentDigest(
val digest: ByteArray,
val incrementalDigest: ByteArray?,
val incrementalMacChunkSize: Int
)

View file

@ -1,74 +0,0 @@
/**
* Copyright (C) 2014-2016 Open Whisper Systems
*
* Licensed according to the LICENSE file in this repository.
*/
package org.whispersystems.signalservice.internal.push;
import org.whispersystems.signalservice.api.messages.SignalServiceAttachment.ProgressListener;
import org.whispersystems.signalservice.internal.push.http.CancelationSignal;
import org.whispersystems.signalservice.internal.push.http.OutputStreamFactory;
import org.whispersystems.signalservice.internal.push.http.ResumableUploadSpec;
import java.io.InputStream;
public class PushAttachmentData {
private final String contentType;
private final InputStream data;
private final long dataSize;
private final boolean incremental;
private final OutputStreamFactory outputStreamFactory;
private final ProgressListener listener;
private final CancelationSignal cancelationSignal;
private final ResumableUploadSpec resumableUploadSpec;
public PushAttachmentData(String contentType, InputStream data, long dataSize,
boolean incremental, OutputStreamFactory outputStreamFactory,
ProgressListener listener, CancelationSignal cancelationSignal,
ResumableUploadSpec resumableUploadSpec)
{
this.contentType = contentType;
this.data = data;
this.dataSize = dataSize;
this.incremental = incremental;
this.outputStreamFactory = outputStreamFactory;
this.resumableUploadSpec = resumableUploadSpec;
this.listener = listener;
this.cancelationSignal = cancelationSignal;
}
public String getContentType() {
return contentType;
}
public InputStream getData() {
return data;
}
public long getDataSize() {
return dataSize;
}
public boolean getIncremental() {
return incremental;
}
public OutputStreamFactory getOutputStreamFactory() {
return outputStreamFactory;
}
public ProgressListener getListener() {
return listener;
}
public CancelationSignal getCancelationSignal() {
return cancelationSignal;
}
public ResumableUploadSpec getResumableUploadSpec() {
return resumableUploadSpec;
}
}

View file

@ -0,0 +1,26 @@
/**
* Copyright (C) 2014-2016 Open Whisper Systems
*
* Licensed according to the LICENSE file in this repository.
*/
package org.whispersystems.signalservice.internal.push
import org.whispersystems.signalservice.api.messages.SignalServiceAttachment
import org.whispersystems.signalservice.internal.push.http.CancelationSignal
import org.whispersystems.signalservice.internal.push.http.OutputStreamFactory
import org.whispersystems.signalservice.internal.push.http.ResumableUploadSpec
import java.io.InputStream
/**
* A bundle of data needed to start an attachment upload.
*/
data class PushAttachmentData(
val contentType: String?,
val data: InputStream,
val dataSize: Long,
val incremental: Boolean,
val outputStreamFactory: OutputStreamFactory,
val listener: SignalServiceAttachment.ProgressListener?,
val cancelationSignal: CancelationSignal?,
val resumableUploadSpec: ResumableUploadSpec
)

View file

@ -343,7 +343,7 @@ public class PushServiceSocket {
private static final ResponseCodeHandler NO_HANDLER = new EmptyResponseCodeHandler(); private static final ResponseCodeHandler NO_HANDLER = new EmptyResponseCodeHandler();
private static final ResponseCodeHandler UNOPINIONATED_HANDLER = new UnopinionatedResponseCodeHandler(); private static final ResponseCodeHandler UNOPINIONATED_HANDLER = new UnopinionatedResponseCodeHandler();
private static final long CDN2_RESUMABLE_LINK_LIFETIME_MILLIS = TimeUnit.DAYS.toMillis(7); public static final long CDN2_RESUMABLE_LINK_LIFETIME_MILLIS = TimeUnit.DAYS.toMillis(7);
private static final int MAX_FOLLOW_UPS = 20; private static final int MAX_FOLLOW_UPS = 20;

View file

@ -1,121 +0,0 @@
package org.whispersystems.signalservice.internal.push.http;
import org.signal.protos.resumableuploads.ResumableUpload;
import org.whispersystems.signalservice.api.push.exceptions.ResumeLocationInvalidException;
import org.signal.core.util.Base64;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;
import okio.ByteString;
public final class ResumableUploadSpec {
private final byte[] secretKey;
private final byte[] iv;
private final String cdnKey;
private final Integer cdnNumber;
private final String resumeLocation;
private final Long expirationTimestamp;
private final Map<String, String> headers;
public ResumableUploadSpec(byte[] secretKey,
byte[] iv,
String cdnKey,
int cdnNumber,
String resumeLocation,
long expirationTimestamp,
Map<String, String> headers)
{
this.secretKey = secretKey;
this.iv = iv;
this.cdnKey = cdnKey;
this.cdnNumber = cdnNumber;
this.resumeLocation = resumeLocation;
this.expirationTimestamp = expirationTimestamp;
this.headers = headers;
}
public byte[] getSecretKey() {
return secretKey;
}
public byte[] getIV() {
return iv;
}
public String getCdnKey() {
return cdnKey;
}
public Integer getCdnNumber() {
return cdnNumber;
}
public String getResumeLocation() {
return resumeLocation;
}
public Long getExpirationTimestamp() {
return expirationTimestamp;
}
public Map<String, String> getHeaders() {
return headers;
}
public ResumableUpload toProto() {
ResumableUpload.Builder builder = new ResumableUpload.Builder()
.secretKey(ByteString.of(getSecretKey()))
.iv(ByteString.of(getIV()))
.timeout(getExpirationTimestamp())
.cdnNumber(getCdnNumber())
.cdnKey(getCdnKey())
.location(getResumeLocation())
.timeout(getExpirationTimestamp());
builder.headers(
headers.entrySet()
.stream()
.map(e -> new ResumableUpload.Header.Builder().key(e.getKey()).value_(e.getValue()).build())
.collect(Collectors.toList())
);
return builder.build();
}
public String serialize() {
return Base64.encodeWithPadding(toProto().encode());
}
public static ResumableUploadSpec deserialize(String serializedSpec) throws ResumeLocationInvalidException {
try {
ResumableUpload resumableUpload = ResumableUpload.ADAPTER.decode(Base64.decode(serializedSpec));
return from(resumableUpload);
} catch (IOException e) {
throw new ResumeLocationInvalidException();
}
}
public static ResumableUploadSpec from(ResumableUpload resumableUpload) throws ResumeLocationInvalidException {
if (resumableUpload == null) return null;
Map<String, String> headers = new HashMap<>();
for (ResumableUpload.Header header : resumableUpload.headers) {
headers.put(header.key, header.value_);
}
return new ResumableUploadSpec(
resumableUpload.secretKey.toByteArray(),
resumableUpload.iv.toByteArray(),
resumableUpload.cdnKey,
resumableUpload.cdnNumber,
resumableUpload.location,
resumableUpload.timeout,
headers
);
}
}

View file

@ -0,0 +1,71 @@
package org.whispersystems.signalservice.internal.push.http
import okio.ByteString.Companion.toByteString
import org.signal.core.util.Base64
import org.signal.protos.resumableuploads.ResumableUpload
import org.whispersystems.signalservice.api.push.exceptions.ResumeLocationInvalidException
import java.io.IOException
/**
* Contains data around how to begin or resume an upload.
* For given attachment, this data be saved and reused within the [expirationTimestamp] window.
*/
class ResumableUploadSpec(
val attachmentKey: ByteArray,
val attachmentIv: ByteArray,
val cdnKey: String,
val cdnNumber: Int,
val resumeLocation: String,
val expirationTimestamp: Long,
val headers: Map<String, String>
) {
fun toProto(): ResumableUpload {
return ResumableUpload(
secretKey = attachmentKey.toByteString(),
iv = attachmentIv.toByteString(),
timeout = expirationTimestamp,
cdnNumber = cdnNumber,
cdnKey = cdnKey,
location = resumeLocation,
headers = headers.entries.map { ResumableUpload.Header(key = it.key, value_ = it.value) }
)
}
fun serialize(): String {
return Base64.encodeWithPadding(toProto().encode())
}
companion object {
@Throws(ResumeLocationInvalidException::class)
fun deserialize(serializedSpec: String?): ResumableUploadSpec? {
try {
val resumableUpload = ResumableUpload.ADAPTER.decode(Base64.decode(serializedSpec!!))
return from(resumableUpload)
} catch (e: IOException) {
throw ResumeLocationInvalidException()
}
}
@Throws(ResumeLocationInvalidException::class)
fun from(resumableUpload: ResumableUpload?): ResumableUploadSpec? {
if (resumableUpload == null) {
return null
}
val headers: MutableMap<String, String> = HashMap()
for (header in resumableUpload.headers) {
headers[header.key] = header.value_
}
return ResumableUploadSpec(
attachmentKey = resumableUpload.secretKey.toByteArray(),
attachmentIv = resumableUpload.iv.toByteArray(),
cdnKey = resumableUpload.cdnKey,
cdnNumber = resumableUpload.cdnNumber,
resumeLocation = resumableUpload.location,
expirationTimestamp = resumableUpload.timeout,
headers = headers
)
}
}
}