diff --git a/app/src/main/java/org/thoughtcrime/securesms/backup/v2/BackupRepository.kt b/app/src/main/java/org/thoughtcrime/securesms/backup/v2/BackupRepository.kt index fd83f9ede0..8193585d4b 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/backup/v2/BackupRepository.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/backup/v2/BackupRepository.kt @@ -37,6 +37,7 @@ import org.thoughtcrime.securesms.jobs.RequestGroupV2InfoJob import org.thoughtcrime.securesms.keyvalue.SignalStore import org.thoughtcrime.securesms.recipients.RecipientId import org.whispersystems.signalservice.api.NetworkResult +import org.whispersystems.signalservice.api.StatusCodeErrorAction import org.whispersystems.signalservice.api.archive.ArchiveGetMediaItemsResponse import org.whispersystems.signalservice.api.archive.ArchiveMediaRequest import org.whispersystems.signalservice.api.archive.ArchiveServiceCredential @@ -61,6 +62,13 @@ object BackupRepository { private val TAG = Log.tag(BackupRepository::class.java) private const val VERSION = 1L + private val resetInitializedStateErrorAction: StatusCodeErrorAction = { error -> + if (error.code == 401) { + Log.i(TAG, "Resetting initialized state due to 401.") + SignalStore.backup().backupsInitialized = false + } + } + fun export(outputStream: OutputStream, append: (ByteArray) -> Unit, plaintext: Boolean = false) { val eventTimer = EventTimer() val writer: BackupExportWriter = if (plaintext) { @@ -229,13 +237,7 @@ object BackupRepository { val api = ApplicationDependencies.getSignalServiceAccountManager().archiveApi val backupKey = SignalStore.svr().getOrCreateMasterKey().deriveBackupKey() - return api - .triggerBackupIdReservation(backupKey) - .then { getAuthCredential() } - .then { credential -> - api.setPublicKey(backupKey, credential) - .map { credential } - } + return initBackupAndFetchAuth(backupKey) .then { credential -> api.getArchiveMediaItemsPage(backupKey, credential, limit, cursor) } @@ -248,14 +250,7 @@ object BackupRepository { val api = ApplicationDependencies.getSignalServiceAccountManager().archiveApi val backupKey = SignalStore.svr().getOrCreateMasterKey().deriveBackupKey() - return api - .triggerBackupIdReservation(backupKey) - .then { getAuthCredential() } - .then { credential -> - api.setPublicKey(backupKey, credential) - .also { Log.i(TAG, "PublicKeyResult: $it") } - .map { credential } - } + return initBackupAndFetchAuth(backupKey) .then { credential -> api.getBackupInfo(backupKey, credential) .map { it to credential } @@ -282,14 +277,7 @@ object BackupRepository { val api = ApplicationDependencies.getSignalServiceAccountManager().archiveApi val backupKey = SignalStore.svr().getOrCreateMasterKey().deriveBackupKey() - return api - .triggerBackupIdReservation(backupKey) - .then { getAuthCredential() } - .then { credential -> - api.setPublicKey(backupKey, credential) - .also { Log.i(TAG, "PublicKeyResult: $it") } - .map { credential } - } + return initBackupAndFetchAuth(backupKey) .then { credential -> api.getMessageBackupUploadForm(backupKey, credential) .also { Log.i(TAG, "UploadFormResult: $it") } @@ -311,9 +299,7 @@ object BackupRepository { val api = ApplicationDependencies.getSignalServiceAccountManager().archiveApi val backupKey = SignalStore.svr().getOrCreateMasterKey().deriveBackupKey() - return api - .triggerBackupIdReservation(backupKey) - .then { getAuthCredential() } + return initBackupAndFetchAuth(backupKey) .then { credential -> api.getBackupInfo(backupKey, credential) } @@ -332,9 +318,7 @@ object BackupRepository { val api = ApplicationDependencies.getSignalServiceAccountManager().archiveApi val backupKey = SignalStore.svr().getOrCreateMasterKey().deriveBackupKey() - return api - .triggerBackupIdReservation(backupKey) - .then { getAuthCredential() } + return initBackupAndFetchAuth(backupKey) .then { credential -> api.debugGetUploadedMediaItemMetadata(backupKey, credential) } @@ -347,9 +331,7 @@ object BackupRepository { val api = ApplicationDependencies.getSignalServiceAccountManager().archiveApi val backupKey = SignalStore.svr().getOrCreateMasterKey().deriveBackupKey() - return api - .triggerBackupIdReservation(backupKey) - .then { getAuthCredential() } + return initBackupAndFetchAuth(backupKey) .then { credential -> api.getMediaUploadForm(backupKey, credential) } @@ -362,9 +344,7 @@ object BackupRepository { val api = ApplicationDependencies.getSignalServiceAccountManager().archiveApi val backupKey = SignalStore.svr().getOrCreateMasterKey().deriveBackupKey() - return api - .triggerBackupIdReservation(backupKey) - .then { getAuthCredential() } + return initBackupAndFetchAuth(backupKey) .then { credential -> api.setPublicKey(backupKey, credential) .map { credential } @@ -390,9 +370,7 @@ object BackupRepository { val api = ApplicationDependencies.getSignalServiceAccountManager().archiveApi val backupKey = SignalStore.svr().getOrCreateMasterKey().deriveBackupKey() - return api - .triggerBackupIdReservation(backupKey) - .then { getAuthCredential() } + return initBackupAndFetchAuth(backupKey) .then { credential -> val requests = mutableListOf() val mediaIdToAttachmentId = mutableMapOf() @@ -445,7 +423,7 @@ object BackupRepository { return NetworkResult.Success(Unit) } - return getAuthCredential() + return initBackupAndFetchAuth(backupKey) .then { credential -> api.deleteArchivedMedia( backupKey = backupKey, @@ -476,7 +454,7 @@ object BackupRepository { return NetworkResult.Success(Unit) } - return getAuthCredential() + return initBackupAndFetchAuth(backupKey) .then { credential -> api.deleteArchivedMedia( backupKey = backupKey, @@ -533,7 +511,7 @@ object BackupRepository { val api = ApplicationDependencies.getSignalServiceAccountManager().archiveApi val backupKey = SignalStore.svr().getOrCreateMasterKey().deriveBackupKey() - return getAuthCredential() + return initBackupAndFetchAuth(backupKey) .then { credential -> api.getCdnReadCredentials( cdnNumber = cdnNumber, @@ -570,7 +548,7 @@ object BackupRepository { val api = ApplicationDependencies.getSignalServiceAccountManager().archiveApi val backupKey = SignalStore.svr().getOrCreateMasterKey().deriveBackupKey() - return getAuthCredential() + return initBackupAndFetchAuth(backupKey) .then { credential -> api.getBackupInfo(backupKey, credential).map { BackupDirectories(it.backupDir!!, it.mediaDir!!) @@ -584,6 +562,25 @@ object BackupRepository { } } + /** + * Ensures that the backupId has been reserved and that your public key has been set, while also returning an auth credential. + * Should be the basis of all backup operations. + */ + private fun initBackupAndFetchAuth(backupKey: BackupKey): NetworkResult { + val api = ApplicationDependencies.getSignalServiceAccountManager().archiveApi + + return if (SignalStore.backup().backupsInitialized) { + getAuthCredential().runOnStatusCodeError(resetInitializedStateErrorAction) + } else { + return api + .triggerBackupIdReservation(backupKey) + .then { getAuthCredential() } + .then { credential -> api.setPublicKey(backupKey, credential).map { credential } } + .runIfSuccessful { SignalStore.backup().backupsInitialized = true } + .runOnStatusCodeError(resetInitializedStateErrorAction) + } + } + /** * Retrieves an auth credential, preferring a cached value if available. */ diff --git a/app/src/main/java/org/thoughtcrime/securesms/keyvalue/BackupValues.kt b/app/src/main/java/org/thoughtcrime/securesms/keyvalue/BackupValues.kt index f32f06a27d..3d73cedc34 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/keyvalue/BackupValues.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/keyvalue/BackupValues.kt @@ -26,6 +26,7 @@ internal class BackupValues(store: KeyValueStore) : SignalStoreValues(store) { private const val KEY_CDN_BACKUP_MEDIA_DIRECTORY = "backup.cdn.mediaDirectory" private const val KEY_OPTIMIZE_STORAGE = "backup.optimizeStorage" + private const val KEY_BACKUPS_INITIALIZED = "backup.initialized" /** * Specifies whether remote backups are enabled on this device. @@ -49,7 +50,20 @@ internal class BackupValues(store: KeyValueStore) : SignalStoreValues(store) { var nextBackupTime: Long by longValue(KEY_NEXT_BACKUP_TIME, -1) - var areBackupsEnabled: Boolean by booleanValue(KEY_BACKUPS_ENABLED, false) + var areBackupsEnabled: Boolean + get() { + return getBoolean(KEY_BACKUPS_ENABLED, false) + } + set(value) { + store + .beginWrite() + .putBoolean(KEY_BACKUPS_ENABLED, value) + .putLong(KEY_NEXT_BACKUP_TIME, -1) + .putBoolean(KEY_BACKUPS_INITIALIZED, false) + .apply() + } + + var backupsInitialized: Boolean by booleanValue(KEY_BACKUPS_INITIALIZED, false) /** * Retrieves the stored credentials, mapped by the day they're valid. The day is represented as diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/NetworkResult.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/NetworkResult.kt index d8fa217fbd..6465bd5f88 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/NetworkResult.kt +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/NetworkResult.kt @@ -8,6 +8,8 @@ package org.whispersystems.signalservice.api import org.whispersystems.signalservice.api.push.exceptions.NonSuccessfulResponseCodeException import java.io.IOException +typealias StatusCodeErrorAction = (NetworkResult.StatusCodeError<*>) -> Unit + /** * A helper class that wraps the result of a network request, turning common exceptions * into sealed classes, with optional request chaining. @@ -22,7 +24,9 @@ import java.io.IOException * sealed class. However, for the majority of requests which just require getting a model from * the success case and the status code of the error, this can be quite convenient. */ -sealed class NetworkResult { +sealed class NetworkResult( + private val statusCodeErrorActions: MutableSet = mutableSetOf() +) { companion object { /** * A convenience method to capture the common case of making a request. @@ -54,6 +58,8 @@ sealed class NetworkResult { /** * Returns the result if successful, otherwise turns the result back into an exception and throws it. + * + * Useful for bridging to Java, where you may want to use try-catch. */ fun successOrThrow(): T { when (this) { @@ -78,27 +84,95 @@ sealed class NetworkResult { /** * Takes the output of one [NetworkResult] and transforms it into another if the operation is successful. - * If it's a failure, the original failure will be propagated. Useful for changing the type of a result. + * If it's non-successful, [transform] lambda is not run, and instead the original failure will be propagated. + * Useful for changing the type of a result. + * + * ```kotlin + * val user: NetworkResult = NetworkResult + * .fromFetch { fetchRemoteUserModel() } + * .map { it.toLocalUserModel() } + * ``` */ fun map(transform: (T) -> R): NetworkResult { return when (this) { - is Success -> Success(transform(this.result)) - is NetworkError -> NetworkError(exception) - is StatusCodeError -> StatusCodeError(code, body, exception) - is ApplicationError -> ApplicationError(throwable) + is Success -> Success(transform(this.result)).runOnStatusCodeError(statusCodeErrorActions) + is NetworkError -> NetworkError(exception).runOnStatusCodeError(statusCodeErrorActions) + is ApplicationError -> ApplicationError(throwable).runOnStatusCodeError(statusCodeErrorActions) + is StatusCodeError -> StatusCodeError(code, body, exception).runOnStatusCodeError(statusCodeErrorActions) } } /** * Takes the output of one [NetworkResult] and passes it as the input to another if the operation is successful. - * If it's a failure, the original failure will be propagated. Useful for chaining operations together. + * If it's non-successful, the [result] lambda is not run, and instead the original failure will be propagated. + * Useful for chaining operations together. + * + * ```kotlin + * val networkResult: NetworkResult = NetworkResult + * .fromFetch { fetchAuthCredential() } + * .then { + * NetworkResult.fromFetch { credential -> fetchData(credential) } + * } + * ``` */ fun then(result: (T) -> NetworkResult): NetworkResult { return when (this) { - is Success -> result(this.result) - is NetworkError -> NetworkError(exception) - is StatusCodeError -> StatusCodeError(code, body, exception) - is ApplicationError -> ApplicationError(throwable) + is Success -> result(this.result).runOnStatusCodeError(statusCodeErrorActions) + is NetworkError -> NetworkError(exception).runOnStatusCodeError(statusCodeErrorActions) + is ApplicationError -> ApplicationError(throwable).runOnStatusCodeError(statusCodeErrorActions) + is StatusCodeError -> StatusCodeError(code, body, exception).runOnStatusCodeError(statusCodeErrorActions) } } + + /** + * Will perform an operation if the result at this point in the chain is successful. Note that it runs if the chain is _currently_ successful. It does not + * depend on anything futher down the chain. + * + * ```kotlin + * val networkResult: NetworkResult = NetworkResult + * .fromFetch { fetchAuthCredential() } + * .runIfSuccessful { storeMyCredential(it) } + * ``` + */ + fun runIfSuccessful(result: (T) -> Unit): NetworkResult { + if (this is Success) { + result(this.result) + } + return this + } + + /** + * Specify an action to be run when a status code error occurs. When a result is a [StatusCodeError] or is transformed into one further down the chain via + * a future [map] or [then], this code will be run. There can only ever be a single status code error in a chain, and therefore this lambda will only ever + * be run a single time. + * + * This is a low-visibility way of doing things, so use sparingly. + * + * ```kotlin + * val result = NetworkResult + * .fromFetch { getAuth() } + * .runOnStatusCodeError { error -> logError(error) } + * .then { credential -> + * NetworkResult.fromFetch { fetchUserDetails(credential) } + * } + * ``` + */ + fun runOnStatusCodeError(action: StatusCodeErrorAction): NetworkResult { + return runOnStatusCodeError(setOf(action)) + } + + internal fun runOnStatusCodeError(actions: Collection): NetworkResult { + if (actions.isEmpty()) { + return this + } + + statusCodeErrorActions += actions + + if (this is StatusCodeError) { + statusCodeErrorActions.forEach { it.invoke(this) } + statusCodeErrorActions.clear() + } + + return this + } } diff --git a/libsignal-service/src/test/java/org/whispersystems/signalservice/api/NetworkResultTest.kt b/libsignal-service/src/test/java/org/whispersystems/signalservice/api/NetworkResultTest.kt new file mode 100644 index 0000000000..591ee17d11 --- /dev/null +++ b/libsignal-service/src/test/java/org/whispersystems/signalservice/api/NetworkResultTest.kt @@ -0,0 +1,221 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.signalservice.api + +import org.junit.Assert.assertEquals +import org.junit.Assert.assertFalse +import org.junit.Assert.assertTrue +import org.junit.Test +import org.whispersystems.signalservice.api.push.exceptions.NonSuccessfulResponseCodeException +import org.whispersystems.signalservice.api.push.exceptions.PushNetworkException + +class NetworkResultTest { + @Test + fun `generic success`() { + val result = NetworkResult.fromFetch {} + + assertTrue(result is NetworkResult.Success) + } + + @Test + fun `generic non-successful status code`() { + val exception = NonSuccessfulResponseCodeException(404, "not found", "body") + + val result = NetworkResult.fromFetch { + throw exception + } + + check(result is NetworkResult.StatusCodeError) + assertEquals(exception, result.exception) + assertEquals(404, result.code) + assertEquals("body", result.body) + } + + @Test + fun `generic network error`() { + val exception = PushNetworkException("general exception") + + val result = NetworkResult.fromFetch { + throw exception + } + + assertTrue(result is NetworkResult.NetworkError) + assertEquals(exception, (result as NetworkResult.NetworkError).exception) + } + + @Test + fun `generic application error`() { + val throwable = RuntimeException("test") + + val result = NetworkResult.fromFetch { + throw throwable + } + + assertTrue(result is NetworkResult.ApplicationError) + assertEquals(throwable, (result as NetworkResult.ApplicationError).throwable) + } + + @Test + fun `then - generic`() { + val result = NetworkResult + .fromFetch { NetworkResult.Success(1) } + .then { NetworkResult.Success(2) } + + assertTrue(result is NetworkResult.Success) + assertEquals(2, (result as NetworkResult.Success).result) + } + + @Test + fun `then - doesn't run on error`() { + val throwable = RuntimeException("test") + var run = false + + val result = NetworkResult + .fromFetch { throw throwable } + .then { + run = true + NetworkResult.Success(1) + } + + assertTrue(result is NetworkResult.ApplicationError) + assertFalse(run) + } + + @Test + fun `map - generic`() { + val result = NetworkResult + .fromFetch { NetworkResult.Success(1) } + .map { 2 } + + assertTrue(result is NetworkResult.Success) + assertEquals(2, (result as NetworkResult.Success).result) + } + + @Test + fun `map - doesn't run on error`() { + val throwable = RuntimeException("test") + var run = false + + val result = NetworkResult + .fromFetch { throw throwable } + .map { + run = true + 1 + } + + assertTrue(result is NetworkResult.ApplicationError) + assertFalse(run) + } + + @Test + fun `runIfSuccessful - doesn't run on error`() { + val throwable = RuntimeException("test") + var run = false + + val result = NetworkResult + .fromFetch { throw throwable } + .runIfSuccessful { run = true } + + assertTrue(result is NetworkResult.ApplicationError) + assertFalse(run) + } + + @Test + fun `runIfSuccessful - runs on success`() { + var run = false + + NetworkResult + .fromFetch { NetworkResult.Success(1) } + .runIfSuccessful { run = true } + + assertTrue(run) + } + + @Test + fun `runIfSuccessful - runs before error`() { + val throwable = RuntimeException("test") + var run = false + + val result = NetworkResult + .fromFetch { NetworkResult.Success(Unit) } + .runIfSuccessful { run = true } + .then { NetworkResult.ApplicationError(throwable) } + + assertTrue(result is NetworkResult.ApplicationError) + assertTrue(run) + } + + @Test + fun `runOnStatusCodeError - simple call`() { + var handled = false + + NetworkResult + .fromFetch { throw NonSuccessfulResponseCodeException(404, "not found", "body") } + .runOnStatusCodeError { handled = true } + + assertTrue(handled) + } + + @Test + fun `runOnStatusCodeError - ensure only called once`() { + var handleCount = 0 + + NetworkResult + .fromFetch { throw NonSuccessfulResponseCodeException(404, "not found", "body") } + .runOnStatusCodeError { handleCount++ } + .map { 1 } + .then { NetworkResult.Success(2) } + .map { 3 } + + assertEquals(1, handleCount) + } + + @Test + fun `runOnStatusCodeError - called when placed before a failing then`() { + var handled = false + + val result = NetworkResult + .fromFetch { } + .runOnStatusCodeError { handled = true } + .then { NetworkResult.fromFetch { throw NonSuccessfulResponseCodeException(404, "not found", "body") } } + + assertTrue(handled) + assertTrue(result is NetworkResult.StatusCodeError) + } + + @Test + fun `runOnStatusCodeError - called when placed two spots before a failing then`() { + var handled = false + + val result = NetworkResult + .fromFetch { } + .runOnStatusCodeError { handled = true } + .then { NetworkResult.Success(Unit) } + .then { NetworkResult.fromFetch { throw NonSuccessfulResponseCodeException(404, "not found", "body") } } + + assertTrue(handled) + assertTrue(result is NetworkResult.StatusCodeError) + } + + @Test + fun `runOnStatusCodeError - should not be called for successful results`() { + var handled = false + + NetworkResult + .fromFetch {} + .runOnStatusCodeError { handled = true } + + NetworkResult + .fromFetch { throw RuntimeException("application error") } + .runOnStatusCodeError { handled = true } + + NetworkResult + .fromFetch { throw PushNetworkException("network error") } + .runOnStatusCodeError { handled = true } + + assertFalse(handled) + } +}