Fix use of media credentials for fetching/restoring media related data.

This commit is contained in:
Cody Henthorne 2024-11-04 16:42:15 -05:00 committed by Greyson Parrelli
parent d7c08690ee
commit f848a78365
11 changed files with 65 additions and 54 deletions

View file

@ -23,6 +23,7 @@ import org.signal.core.util.getAllTriggerDefinitions
import org.signal.core.util.getForeignKeyViolations
import org.signal.core.util.logging.Log
import org.signal.core.util.stream.NonClosingOutputStream
import org.signal.core.util.urlEncode
import org.signal.core.util.withinTransaction
import org.signal.libsignal.messagebackup.MessageBackup
import org.signal.libsignal.messagebackup.MessageBackup.ValidationResult
@ -112,6 +113,7 @@ object BackupRepository {
SignalStore.backup.backupsInitialized = false
SignalStore.backup.messageCredentials.clearAll()
SignalStore.backup.mediaCredentials.clearAll()
SignalStore.backup.cachedMediaCdnPath = null
}
403 -> {
@ -716,7 +718,7 @@ object BackupRepository {
return initBackupAndFetchAuth(backupKey, mediaRootBackupKey)
.then { credential ->
SignalNetwork.archive.getBackupInfo(backupKey, SignalStore.account.requireAci(), credential.messageCredential)
SignalNetwork.archive.getBackupInfo(mediaRootBackupKey, SignalStore.account.requireAci(), credential.mediaCredential)
.map { it.usedSpace }
}
}
@ -744,7 +746,7 @@ object BackupRepository {
return initBackupAndFetchAuth(backupKey, mediaRootBackupKey)
.map { credential ->
val zkCredential = SignalNetwork.archive.getZkCredential(backupKey, aci, credential.mediaCredential)
val zkCredential = SignalNetwork.archive.getZkCredential(backupKey, aci, credential.messageCredential)
if (zkCredential.backupLevel == BackupLevel.PAID) {
MessageBackupTier.PAID
} else {
@ -762,16 +764,16 @@ object BackupRepository {
return initBackupAndFetchAuth(backupKey, mediaRootBackupKey)
.then { credential ->
SignalNetwork.archive.getBackupInfo(backupKey, SignalStore.account.requireAci(), credential.messageCredential)
SignalNetwork.archive.getBackupInfo(mediaRootBackupKey, SignalStore.account.requireAci(), credential.mediaCredential)
.map { it to credential }
}
.then { pair ->
val (info, credential) = pair
val (mediaBackupInfo, credential) = pair
SignalNetwork.archive.debugGetUploadedMediaItemMetadata(mediaRootBackupKey, SignalStore.account.requireAci(), credential.mediaCredential)
.also { Log.i(TAG, "MediaItemMetadataResult: $it") }
.map { mediaObjects ->
BackupMetadata(
usedSpace = info.usedSpace ?: 0,
usedSpace = mediaBackupInfo.usedSpace ?: 0,
mediaCount = mediaObjects.size.toLong()
)
}
@ -1119,21 +1121,15 @@ object BackupRepository {
}
/**
* Retrieves backupDir and mediaDir, preferring cached value if available.
* Retrieves media-specific cdn path, preferring cached value if available.
*
* These will only ever change if the backup expires.
* This will change if the backup expires, a new backup-id is set, or the delete all endpoint is called.
*/
fun getCdnBackupDirectories(): NetworkResult<BackupDirectories> {
val cachedBackupDirectory = SignalStore.backup.cachedBackupDirectory
val cachedBackupMediaDirectory = SignalStore.backup.cachedBackupMediaDirectory
fun getArchivedMediaCdnPath(): NetworkResult<String> {
val cachedMediaPath = SignalStore.backup.cachedMediaCdnPath
if (cachedBackupDirectory != null && cachedBackupMediaDirectory != null) {
return NetworkResult.Success(
BackupDirectories(
backupDir = cachedBackupDirectory,
mediaDir = cachedBackupMediaDirectory
)
)
if (cachedMediaPath != null) {
return NetworkResult.Success(cachedMediaPath)
}
val backupKey = SignalStore.backup.messageBackupKey
@ -1141,15 +1137,14 @@ object BackupRepository {
return initBackupAndFetchAuth(backupKey, mediaRootBackupKey)
.then { credential ->
SignalNetwork.archive.getBackupInfo(backupKey, SignalStore.account.requireAci(), credential.messageCredential).map {
SignalNetwork.archive.getBackupInfo(mediaRootBackupKey, SignalStore.account.requireAci(), credential.mediaCredential).map {
SignalStore.backup.usedBackupMediaSpace = it.usedSpace ?: 0L
BackupDirectories(it.backupDir!!, it.mediaDir!!)
"${it.backupDir!!.urlEncode()}/${it.mediaDir!!.urlEncode()}"
}
}
.also {
if (it is NetworkResult.Success) {
SignalStore.backup.cachedBackupDirectory = it.result.backupDir
SignalStore.backup.cachedBackupMediaDirectory = it.result.mediaDir
SignalStore.backup.cachedMediaCdnPath = it.result
}
}
}
@ -1303,8 +1298,6 @@ object BackupRepository {
data class ArchivedMediaObject(val mediaId: String, val cdn: Int)
data class BackupDirectories(val backupDir: String, val mediaDir: String)
class ExportState(val backupTime: Long, val mediaBackupEnabled: Boolean) {
val recipientIds: MutableSet<Long> = hashSetOf()
val threadIds: MutableSet<Long> = hashSetOf()

View file

@ -35,11 +35,10 @@ fun DatabaseAttachment.createArchiveAttachmentPointer(useArchiveCdn: Boolean): S
return try {
val (remoteId, cdnNumber) = if (useArchiveCdn) {
val mediaRootBackupKey = SignalStore.backup.mediaRootBackupKey
val backupDirectories = BackupRepository.getCdnBackupDirectories().successOrThrow()
val mediaCdnPath = BackupRepository.getArchivedMediaCdnPath().successOrThrow()
val id = SignalServiceAttachmentRemoteId.Backup(
backupDir = backupDirectories.backupDir,
mediaDir = backupDirectories.mediaDir,
mediaCdnPath = mediaCdnPath,
mediaId = mediaRootBackupKey.deriveMediaId(MediaName(archiveMediaName!!)).encode()
)
@ -92,15 +91,14 @@ fun DatabaseAttachment.createArchiveThumbnailPointer(): SignalServiceAttachmentP
}
val mediaRootBackupKey = SignalStore.backup.mediaRootBackupKey
val backupDirectories = BackupRepository.getCdnBackupDirectories().successOrThrow()
val mediaCdnPath = BackupRepository.getArchivedMediaCdnPath().successOrThrow()
return try {
val key = mediaRootBackupKey.deriveThumbnailTransitKey(getThumbnailMediaName())
val mediaId = mediaRootBackupKey.deriveMediaId(getThumbnailMediaName()).encode()
SignalServiceAttachmentPointer(
cdnNumber = archiveCdn,
remoteId = SignalServiceAttachmentRemoteId.Backup(
backupDir = backupDirectories.backupDir,
mediaDir = backupDirectories.mediaDir,
mediaCdnPath = mediaCdnPath,
mediaId = mediaId
),
contentType = null,

View file

@ -223,6 +223,7 @@ class InternalBackupPlaygroundViewModel : ViewModel() {
}
else -> {
Log.w(TAG, "Error checking remote backup state", result.getCause())
_state.value = _state.value.copy(remoteBackupState = RemoteBackupState.GeneralError)
}
}

View file

@ -1583,7 +1583,7 @@ class AttachmentTable(
SELECT
$mmsId,
$CONTENT_TYPE,
$TRANSFER_PROGRESS_PENDING,
$TRANSFER_NEEDS_RESTORE,
$CDN_NUMBER,
$REMOTE_LOCATION,
$REMOTE_DIGEST,

View file

@ -98,8 +98,14 @@ class BackupMessagesJob private constructor(parameters: Parameters) : Job(parame
Log.i(TAG, "Successfully uploaded backup file.")
SignalStore.backup.hasBackupBeenUploaded = true
}
is NetworkResult.NetworkError -> return Result.retry(defaultBackoff())
is NetworkResult.StatusCodeError -> return Result.retry(defaultBackoff())
is NetworkResult.NetworkError -> {
Log.i(TAG, "Network failure", result.getCause())
return Result.retry(defaultBackoff())
}
is NetworkResult.StatusCodeError -> {
Log.i(TAG, "Status code failure", result.getCause())
return Result.retry(defaultBackoff())
}
is NetworkResult.ApplicationError -> throw result.throwable
}
}

View file

@ -221,7 +221,7 @@ class RestoreAttachmentJob private constructor(
val downloadResult = if (useArchiveCdn) {
archiveFile = SignalDatabase.attachments.getOrCreateArchiveTransferFile(attachmentId)
val cdnCredentials = BackupRepository.getCdnReadCredentials(BackupRepository.CredentialType.MESSAGE, attachment.archiveCdn).successOrThrow().headers
val cdnCredentials = BackupRepository.getCdnReadCredentials(BackupRepository.CredentialType.MEDIA, attachment.archiveCdn).successOrThrow().headers
messageReceiver
.retrieveArchivedAttachment(
@ -265,6 +265,7 @@ class RestoreAttachmentJob private constructor(
return
} else if (e.code == 401 && useArchiveCdn) {
SignalStore.backup.mediaCredentials.cdnReadCredentials = null
SignalStore.backup.cachedMediaCdnPath = null
throw RetryLaterException(e)
}
}

View file

@ -42,8 +42,7 @@ class BackupValues(store: KeyValueStore) : SignalStoreValues(store) {
private const val KEY_TOTAL_RESTORABLE_ATTACHMENT_SIZE = "backup.totalRestorableAttachmentSize"
private const val KEY_BACKUP_FREQUENCY = "backup.backupFrequency"
private const val KEY_CDN_BACKUP_DIRECTORY = "backup.cdn.directory"
private const val KEY_CDN_BACKUP_MEDIA_DIRECTORY = "backup.cdn.mediaDirectory"
private const val KEY_CDN_MEDIA_PATH = "backup.cdn.mediaPath"
private const val KEY_BACKUP_OVER_CELLULAR = "backup.useCellular"
private const val KEY_OPTIMIZE_STORAGE = "backup.optimizeStorage"
@ -69,8 +68,7 @@ class BackupValues(store: KeyValueStore) : SignalStoreValues(store) {
override fun onFirstEverAppLaunch() = Unit
override fun getKeysToIncludeInBackup(): List<String> = emptyList()
var cachedBackupDirectory: String? by stringValue(KEY_CDN_BACKUP_DIRECTORY, null)
var cachedBackupMediaDirectory: String? by stringValue(KEY_CDN_BACKUP_MEDIA_DIRECTORY, null)
var cachedMediaCdnPath: String? by stringValue(KEY_CDN_MEDIA_PATH, null)
var usedBackupMediaSpace: Long by longValue(KEY_BACKUP_USED_MEDIA_SPACE, 0L)
var lastBackupProtoSize: Long by longValue(KEY_BACKUP_LAST_PROTO_SIZE, 0L)
@ -116,6 +114,8 @@ class BackupValues(store: KeyValueStore) : SignalStoreValues(store) {
lock.withLock {
Log.i(TAG, "Setting MediaRootBackupKey", Throwable())
putBlob(KEY_MEDIA_ROOT_BACKUP_KEY, value.value)
mediaCredentials.clearAll()
cachedMediaCdnPath = null
}
}
@ -240,6 +240,7 @@ class BackupValues(store: KeyValueStore) : SignalStoreValues(store) {
/** Clears all credentials. */
fun clearAll() {
putString(authKey, null)
cdnReadCredentials = null
}
/** Credentials to read from the CDN. */

View file

@ -5,6 +5,8 @@
package org.signal.core.util
import java.net.URLEncoder
import java.nio.charset.StandardCharsets
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract
@ -70,3 +72,10 @@ fun CharSequence?.isNotNullOrBlank(): Boolean {
}
return !this.isNullOrBlank()
}
/**
* Encode this string in a url-safe way with UTF-8 encoding.
*/
fun String.urlEncode(): String {
return URLEncoder.encode(this, StandardCharsets.UTF_8.name())
}

View file

@ -138,14 +138,15 @@ class ArchiveApi(private val pushServiceSocket: PushServiceSocket) {
}
/**
* Fetches metadata about your current backup.
* Will return a [NetworkResult.StatusCodeError] with status code 404 if you haven't uploaded a
* backup yet.
* Fetches metadata about your current backup. This will be different for different key/credential pairs. For example, message credentials will always
* return 0 for used space since that is stored under the media key/credential.
*
* Will return a [NetworkResult.StatusCodeError] with status code 404 if you haven't uploaded a backup yet.
*/
fun getBackupInfo(messageBackupKey: MessageBackupKey, aci: ACI, messageServiceCredential: ArchiveServiceCredential): NetworkResult<ArchiveGetBackupInfoResponse> {
fun getBackupInfo(backupKey: BackupKey, aci: ACI, messageServiceCredential: ArchiveServiceCredential): NetworkResult<ArchiveGetBackupInfoResponse> {
return NetworkResult.fromFetch {
val zkCredential = getZkCredential(messageBackupKey, aci, messageServiceCredential)
val presentationData = CredentialPresentationData.from(messageBackupKey, aci, zkCredential, backupServerPublicParams)
val zkCredential = getZkCredential(backupKey, aci, messageServiceCredential)
val presentationData = CredentialPresentationData.from(backupKey, aci, zkCredential, backupServerPublicParams)
pushServiceSocket.getArchiveBackupInfo(presentationData.toArchiveCredentialPresentation())
}
}

View file

@ -24,7 +24,7 @@ sealed interface SignalServiceAttachmentRemoteId {
override fun toString() = cdnKey
}
data class Backup(val backupDir: String, val mediaDir: String, val mediaId: String) : SignalServiceAttachmentRemoteId {
data class Backup(val mediaCdnPath: String, val mediaId: String) : SignalServiceAttachmentRemoteId {
override fun toString() = mediaId
}

View file

@ -337,7 +337,7 @@ public class PushServiceSocket {
private static final String ARCHIVE_MEDIA_LIST = "/v1/archives/media?limit=%d";
private static final String ARCHIVE_MEDIA_BATCH = "/v1/archives/media/batch";
private static final String ARCHIVE_MEDIA_DELETE = "/v1/archives/media/delete";
private static final String ARCHIVE_MEDIA_DOWNLOAD_PATH = "backups/%s/%s/%s";
private static final String ARCHIVE_MEDIA_DOWNLOAD_PATH = "backups/%s/%s";
private static final String SET_SHARE_SET_PATH = "/v3/backup/share-set";
@ -1033,20 +1033,21 @@ public class PushServiceSocket {
downloadFromCdn(destination, cdnNumber, headers, cdnPath, maxSizeBytes, listener);
}
public void retrieveAttachment(int cdnNumber, Map<String, String> headers, SignalServiceAttachmentRemoteId cdnPath, File destination, long maxSizeBytes, ProgressListener listener)
public void retrieveAttachment(int cdnNumber, Map<String, String> headers, SignalServiceAttachmentRemoteId remoteId, File destination, long maxSizeBytes, ProgressListener listener)
throws IOException, MissingConfigurationException
{
final String path;
if (cdnPath instanceof SignalServiceAttachmentRemoteId.V2) {
path = String.format(Locale.US, ATTACHMENT_ID_DOWNLOAD_PATH, ((SignalServiceAttachmentRemoteId.V2) cdnPath).getCdnId());
} else if (cdnPath instanceof SignalServiceAttachmentRemoteId.V4) {
String urlEncodedKey = urlEncode(((SignalServiceAttachmentRemoteId.V4) cdnPath).getCdnKey());
if (remoteId instanceof SignalServiceAttachmentRemoteId.V2) {
path = String.format(Locale.US, ATTACHMENT_ID_DOWNLOAD_PATH, ((SignalServiceAttachmentRemoteId.V2) remoteId).getCdnId());
} else if (remoteId instanceof SignalServiceAttachmentRemoteId.V4) {
String urlEncodedKey = urlEncode(((SignalServiceAttachmentRemoteId.V4) remoteId).getCdnKey());
path = String.format(Locale.US, ATTACHMENT_KEY_DOWNLOAD_PATH, urlEncodedKey);
} else if (cdnPath instanceof SignalServiceAttachmentRemoteId.Backup) {
SignalServiceAttachmentRemoteId.Backup backupCdnId = (SignalServiceAttachmentRemoteId.Backup) cdnPath;
path = String.format(Locale.US, ARCHIVE_MEDIA_DOWNLOAD_PATH, backupCdnId.getBackupDir(), backupCdnId.getMediaDir(), backupCdnId.getMediaId());
} else if (remoteId instanceof SignalServiceAttachmentRemoteId.Backup) {
//noinspection PatternVariableCanBeUsed
SignalServiceAttachmentRemoteId.Backup backupCdnRemoteId = (SignalServiceAttachmentRemoteId.Backup) remoteId;
path = String.format(Locale.US, ARCHIVE_MEDIA_DOWNLOAD_PATH, backupCdnRemoteId.getMediaCdnPath(), backupCdnRemoteId.getMediaId());
} else {
throw new IllegalArgumentException("Invalid cdnPath type: " + cdnPath.getClass().getSimpleName());
throw new IllegalArgumentException("Invalid cdnPath type: " + remoteId.getClass().getSimpleName());
}
downloadFromCdn(destination, cdnNumber, headers, path, maxSizeBytes, listener);
}