diff --git a/app/src/main/java/org/thoughtcrime/securesms/contacts/sync/ContactDiscovery.kt b/app/src/main/java/org/thoughtcrime/securesms/contacts/sync/ContactDiscovery.kt index a47734f3db..17d848ce53 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/contacts/sync/ContactDiscovery.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/contacts/sync/ContactDiscovery.kt @@ -37,7 +37,6 @@ import org.thoughtcrime.securesms.util.Util import org.whispersystems.signalservice.api.push.SignalServiceAddress import org.whispersystems.signalservice.api.util.UuidUtil import java.io.IOException -import java.lang.Exception import java.util.Calendar import java.util.concurrent.Callable import java.util.concurrent.ExecutionException @@ -80,9 +79,9 @@ object ContactDiscovery { descriptor = "refresh-all", refresh = { if (FeatureFlags.phoneNumberPrivacy()) { - ContactDiscoveryRefreshV2.refreshAll(context, useCompat = false) + ContactDiscoveryRefreshV2.refreshAll(context, useCompat = false, ignoreResults = false) } else if (FeatureFlags.cdsV2Compat()) { - ContactDiscoveryRefreshV2.refreshAll(context, useCompat = true) + ContactDiscoveryRefreshV2.refreshAll(context, useCompat = true, ignoreResults = false) } else if (FeatureFlags.cdsV2LoadTesting()) { loadTestRefreshAll(context) } else { @@ -105,9 +104,9 @@ object ContactDiscovery { descriptor = "refresh-multiple", refresh = { if (FeatureFlags.phoneNumberPrivacy()) { - ContactDiscoveryRefreshV2.refresh(context, recipients, useCompat = false) + ContactDiscoveryRefreshV2.refresh(context, recipients, useCompat = false, ignoreResults = false) } else if (FeatureFlags.cdsV2Compat()) { - ContactDiscoveryRefreshV2.refresh(context, recipients, useCompat = true) + ContactDiscoveryRefreshV2.refresh(context, recipients, useCompat = true, ignoreResults = false) } else if (FeatureFlags.cdsV2LoadTesting()) { loadTestRefresh(context, recipients) } else { @@ -128,9 +127,9 @@ object ContactDiscovery { descriptor = "refresh-single", refresh = { if (FeatureFlags.phoneNumberPrivacy()) { - ContactDiscoveryRefreshV2.refresh(context, listOf(recipient), useCompat = false) + ContactDiscoveryRefreshV2.refresh(context, listOf(recipient), useCompat = false, ignoreResults = false) } else if (FeatureFlags.cdsV2Compat()) { - ContactDiscoveryRefreshV2.refresh(context, listOf(recipient), useCompat = true) + ContactDiscoveryRefreshV2.refresh(context, listOf(recipient), useCompat = true, ignoreResults = false) } else if (FeatureFlags.cdsV2LoadTesting()) { loadTestRefresh(context, listOf(recipient)) } else { @@ -404,7 +403,7 @@ object ContactDiscovery { try { v2Future.get() - } catch (e: Exception) { + } catch (e: Throwable) { Log.w(TAG, "Failed to complete the V2 fetch!", e) } diff --git a/app/src/main/java/org/thoughtcrime/securesms/contacts/sync/ContactDiscoveryRefreshV2.kt b/app/src/main/java/org/thoughtcrime/securesms/contacts/sync/ContactDiscoveryRefreshV2.kt index ca62d450bd..753308a859 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/contacts/sync/ContactDiscoveryRefreshV2.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/contacts/sync/ContactDiscoveryRefreshV2.kt @@ -18,6 +18,7 @@ import org.thoughtcrime.securesms.phonenumbers.PhoneNumberFormatter import org.thoughtcrime.securesms.recipients.Recipient import org.thoughtcrime.securesms.recipients.RecipientId import org.whispersystems.signalservice.api.push.ACI +import org.whispersystems.signalservice.api.push.exceptions.NonSuccessfulResponseCodeException import org.whispersystems.signalservice.api.services.CdsiV2Service import java.io.IOException import java.util.Optional @@ -44,7 +45,7 @@ object ContactDiscoveryRefreshV2 { @WorkerThread @Synchronized @JvmStatic - fun refreshAll(context: Context, useCompat: Boolean, ignoreResults: Boolean = false): ContactDiscovery.RefreshResult { + fun refreshAll(context: Context, useCompat: Boolean, ignoreResults: Boolean): ContactDiscovery.RefreshResult { val recipientE164s: Set = SignalDatabase.recipients.getAllE164s().sanitize() val systemE164s: Set = SystemContactsRepository.getAllDisplayNumbers(context).toE164s(context).sanitize() @@ -52,7 +53,7 @@ object ContactDiscoveryRefreshV2 { recipientE164s = recipientE164s, systemE164s = systemE164s, inputPreviousE164s = SignalDatabase.cds.getAllE164s(), - saveToken = true, + isPartialRefresh = false, useCompat = useCompat, ignoreResults = ignoreResults ) @@ -62,14 +63,14 @@ object ContactDiscoveryRefreshV2 { @WorkerThread @Synchronized @JvmStatic - fun refresh(context: Context, inputRecipients: List, useCompat: Boolean, ignoreResults: Boolean = false): ContactDiscovery.RefreshResult { + fun refresh(context: Context, inputRecipients: List, useCompat: Boolean, ignoreResults: Boolean): ContactDiscovery.RefreshResult { val recipients: List = inputRecipients.map { it.resolve() } val inputE164s: Set = recipients.mapNotNull { it.e164.orElse(null) }.toSet() return if (inputE164s.size > MAXIMUM_ONE_OFF_REQUEST_SIZE) { Log.i(TAG, "List of specific recipients to refresh is too large! (Size: ${recipients.size}). Doing a full refresh instead.") - val fullResult: ContactDiscovery.RefreshResult = refreshAll(context, ignoreResults) + val fullResult: ContactDiscovery.RefreshResult = refreshAll(context, useCompat = useCompat, ignoreResults = ignoreResults) val inputIds: Set = recipients.map { it.id }.toSet() ContactDiscovery.RefreshResult( @@ -81,7 +82,7 @@ object ContactDiscoveryRefreshV2 { recipientE164s = inputE164s, systemE164s = inputE164s, inputPreviousE164s = emptySet(), - saveToken = false, + isPartialRefresh = true, useCompat = useCompat, ignoreResults = ignoreResults ) @@ -93,13 +94,14 @@ object ContactDiscoveryRefreshV2 { recipientE164s: Set, systemE164s: Set, inputPreviousE164s: Set, - saveToken: Boolean, + isPartialRefresh: Boolean, useCompat: Boolean, ignoreResults: Boolean ): ContactDiscovery.RefreshResult { - val stopwatch = Stopwatch("refreshInternal-${if (useCompat) "compat" else "v2"}") + val tag = "refreshInternal-${if (useCompat) "compat" else "v2"}" + val stopwatch = Stopwatch(tag) - val previousE164s: Set = if (SignalStore.misc().cdsToken != null) inputPreviousE164s else emptySet() + val previousE164s: Set = if (SignalStore.misc().cdsToken != null && !isPartialRefresh) inputPreviousE164s else emptySet() val allE164s: Set = recipientE164s + systemE164s val newRawE164s: Set = allE164s - previousE164s @@ -107,40 +109,50 @@ object ContactDiscoveryRefreshV2 { val newE164s: Set = fuzzyInput.numbers if (newE164s.isEmpty() && previousE164s.isEmpty()) { - Log.w(TAG, "[refreshInternal] No data to send! Ignoring.") + Log.w(TAG, "[$tag] No data to send! Ignoring.") return ContactDiscovery.RefreshResult(emptySet(), emptyMap()) } - val token: ByteArray? = if (previousE164s.isNotEmpty()) SignalStore.misc().cdsToken else null + val token: ByteArray? = if (previousE164s.isNotEmpty() && !isPartialRefresh) SignalStore.misc().cdsToken else null stopwatch.split("preamble") - val response: CdsiV2Service.Response = ApplicationDependencies.getSignalServiceAccountManager().getRegisteredUsersWithCdsi( - previousE164s, - newE164s, - SignalDatabase.recipients.getAllServiceIdProfileKeyPairs(), - useCompat, - Optional.ofNullable(token), - BuildConfig.CDSI_MRENCLAVE - ) { tokenToSave -> - if (saveToken) { - SignalStore.misc().cdsToken = tokenToSave - Log.d(TAG, "Token saved!") - } else { - Log.d(TAG, "Ignoring token.") + val response: CdsiV2Service.Response = try { + ApplicationDependencies.getSignalServiceAccountManager().getRegisteredUsersWithCdsi( + previousE164s, + newE164s, + SignalDatabase.recipients.getAllServiceIdProfileKeyPairs(), + useCompat, + Optional.ofNullable(token), + BuildConfig.CDSI_MRENCLAVE + ) { tokenToSave -> + stopwatch.split("network-pre-token") + if (!isPartialRefresh) { + SignalStore.misc().cdsToken = tokenToSave + SignalDatabase.cds.updateAfterFullCdsQuery(previousE164s + newE164s, allE164s + newE164s) + Log.d(TAG, "Token saved!") + } else { + SignalDatabase.cds.updateAfterPartialCdsQuery(newE164s) + Log.d(TAG, "Ignoring token.") + } + stopwatch.split("cds-db") } + } catch (e: NonSuccessfulResponseCodeException) { + if (e.code == 4101) { + Log.w(TAG, "Our token was invalid! Only thing we can do now is clear our local state :(") + SignalStore.misc().cdsToken = null + SignalDatabase.cds.clearAll() + } + throw e } - Log.d(TAG, "[refreshInternal] Used ${response.quotaUsedDebugOnly} quota.") - stopwatch.split("network") - - SignalDatabase.cds.updateAfterCdsQuery(newE164s, allE164s + newE164s) - stopwatch.split("cds-db") + Log.d(TAG, "[$tag] Used ${response.quotaUsedDebugOnly} quota.") + stopwatch.split("network-post-token") val registeredIds: MutableSet = mutableSetOf() val rewrites: MutableMap = mutableMapOf() if (ignoreResults) { - Log.w(TAG, "[refreshInternal] Ignoring CDSv2 results.") + Log.w(TAG, "[$tag] Ignoring CDSv2 results.") } else { if (useCompat) { val transformed: Map = response.results.mapValues { entry -> entry.value.aci.orElse(null) } diff --git a/app/src/main/java/org/thoughtcrime/securesms/database/CdsDatabase.kt b/app/src/main/java/org/thoughtcrime/securesms/database/CdsDatabase.kt index b91c12c795..e5ec71baeb 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/database/CdsDatabase.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/database/CdsDatabase.kt @@ -8,7 +8,7 @@ import org.signal.core.util.delete import org.signal.core.util.logging.Log import org.signal.core.util.requireNonNullString import org.signal.core.util.select -import org.signal.core.util.update +import org.signal.core.util.withinTransaction /** * Keeps track of the numbers we've previously queried CDS for. @@ -53,32 +53,58 @@ class CdsDatabase(context: Context, databaseHelper: SignalDatabase) : Database(c } /** - * @param newE164s The newly-added E164s that we hadn't previously queried for. - * @param seenE164s The E164s that were seen in either the system contacts or recipients table. - * This should be a superset of [newE164s] - * + * Saves the set of e164s used after a full refresh. + * @param fullE164s All of the e164s used in the last CDS query (previous and new). + * @param seenE164s The E164s that were seen in either the system contacts or recipients table. This is different from [fullE164s] in that [fullE164s] + * includes every number we've ever seen, even if it's not in our contacts anymore. */ - fun updateAfterCdsQuery(newE164s: Set, seenE164s: Set) { + fun updateAfterFullCdsQuery(fullE164s: Set, seenE164s: Set) { val lastSeen = System.currentTimeMillis() - writableDatabase.beginTransaction() - try { - val insertValues: List = newE164s.map { contentValuesOf(E164 to it) } + writableDatabase.withinTransaction { db -> + val existingE164s: Set = getAllE164s() + val removedE164s: Set = existingE164s - fullE164s + val addedE164s: Set = fullE164s - existingE164s - SqlUtil.buildBulkInsert(TABLE_NAME, arrayOf(E164), insertValues) - .forEach { writableDatabase.execSQL(it.where, it.whereArgs) } + if (removedE164s.isNotEmpty()) { + SqlUtil.buildCollectionQuery(E164, removedE164s) + .forEach { db.delete(TABLE_NAME, it.where, it.whereArgs) } + } - val contentValues = contentValuesOf(LAST_SEEN_AT to lastSeen) + if (addedE164s.isNotEmpty()) { + val insertValues: List = addedE164s.map { contentValuesOf(E164 to it) } - SqlUtil.buildCollectionQuery(E164, seenE164s) - .forEach { query -> writableDatabase.update(TABLE_NAME, contentValues, query.where, query.whereArgs) } + SqlUtil.buildBulkInsert(TABLE_NAME, arrayOf(E164), insertValues) + .forEach { db.execSQL(it.where, it.whereArgs) } + } - writableDatabase.setTransactionSuccessful() - } finally { - writableDatabase.endTransaction() + if (seenE164s.isNotEmpty()) { + val contentValues = contentValuesOf(LAST_SEEN_AT to lastSeen) + + SqlUtil.buildCollectionQuery(E164, seenE164s) + .forEach { query -> db.update(TABLE_NAME, contentValues, query.where, query.whereArgs) } + } } } + /** + * Updates after a partial CDS query. Will not insert new entries. Instead, this will simply update the lastSeen timestamp of any entry we already have. + * @param seenE164s The newly-added E164s that we hadn't previously queried for. + */ + fun updateAfterPartialCdsQuery(seenE164s: Set) { + val lastSeen = System.currentTimeMillis() + + writableDatabase.withinTransaction { db -> + val contentValues = contentValuesOf(LAST_SEEN_AT to lastSeen) + + SqlUtil.buildCollectionQuery(E164, seenE164s) + .forEach { query -> db.update(TABLE_NAME, contentValues, query.where, query.whereArgs) } + } + } + + /** + * Wipes the entire table. + */ fun clearAll() { writableDatabase .delete(TABLE_NAME) diff --git a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/SignalServiceAccountManager.java b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/SignalServiceAccountManager.java index 4fd3c7877a..f814dd2044 100644 --- a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/SignalServiceAccountManager.java +++ b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/SignalServiceAccountManager.java @@ -549,7 +549,11 @@ public class SignalServiceAccountManager { if (serviceResponse.getResult().isPresent()) { return serviceResponse.getResult().get(); } else if (serviceResponse.getApplicationError().isPresent()) { - throw new IOException(serviceResponse.getApplicationError().get()); + if (serviceResponse.getApplicationError().get() instanceof IOException) { + throw (IOException) serviceResponse.getApplicationError().get(); + } else { + throw new IOException(serviceResponse.getApplicationError().get()); + } } else if (serviceResponse.getExecutionError().isPresent()) { throw new IOException(serviceResponse.getExecutionError().get()); } else {