Fix token mismatch issues when using CDSv2.

This commit is contained in:
Greyson Parrelli 2022-09-07 14:18:54 -04:00 committed by Cody Henthorne
parent f1bcc756d3
commit 658741be52
4 changed files with 96 additions and 55 deletions

View file

@ -37,7 +37,6 @@ import org.thoughtcrime.securesms.util.Util
import org.whispersystems.signalservice.api.push.SignalServiceAddress import org.whispersystems.signalservice.api.push.SignalServiceAddress
import org.whispersystems.signalservice.api.util.UuidUtil import org.whispersystems.signalservice.api.util.UuidUtil
import java.io.IOException import java.io.IOException
import java.lang.Exception
import java.util.Calendar import java.util.Calendar
import java.util.concurrent.Callable import java.util.concurrent.Callable
import java.util.concurrent.ExecutionException import java.util.concurrent.ExecutionException
@ -80,9 +79,9 @@ object ContactDiscovery {
descriptor = "refresh-all", descriptor = "refresh-all",
refresh = { refresh = {
if (FeatureFlags.phoneNumberPrivacy()) { if (FeatureFlags.phoneNumberPrivacy()) {
ContactDiscoveryRefreshV2.refreshAll(context, useCompat = false) ContactDiscoveryRefreshV2.refreshAll(context, useCompat = false, ignoreResults = false)
} else if (FeatureFlags.cdsV2Compat()) { } else if (FeatureFlags.cdsV2Compat()) {
ContactDiscoveryRefreshV2.refreshAll(context, useCompat = true) ContactDiscoveryRefreshV2.refreshAll(context, useCompat = true, ignoreResults = false)
} else if (FeatureFlags.cdsV2LoadTesting()) { } else if (FeatureFlags.cdsV2LoadTesting()) {
loadTestRefreshAll(context) loadTestRefreshAll(context)
} else { } else {
@ -105,9 +104,9 @@ object ContactDiscovery {
descriptor = "refresh-multiple", descriptor = "refresh-multiple",
refresh = { refresh = {
if (FeatureFlags.phoneNumberPrivacy()) { if (FeatureFlags.phoneNumberPrivacy()) {
ContactDiscoveryRefreshV2.refresh(context, recipients, useCompat = false) ContactDiscoveryRefreshV2.refresh(context, recipients, useCompat = false, ignoreResults = false)
} else if (FeatureFlags.cdsV2Compat()) { } else if (FeatureFlags.cdsV2Compat()) {
ContactDiscoveryRefreshV2.refresh(context, recipients, useCompat = true) ContactDiscoveryRefreshV2.refresh(context, recipients, useCompat = true, ignoreResults = false)
} else if (FeatureFlags.cdsV2LoadTesting()) { } else if (FeatureFlags.cdsV2LoadTesting()) {
loadTestRefresh(context, recipients) loadTestRefresh(context, recipients)
} else { } else {
@ -128,9 +127,9 @@ object ContactDiscovery {
descriptor = "refresh-single", descriptor = "refresh-single",
refresh = { refresh = {
if (FeatureFlags.phoneNumberPrivacy()) { if (FeatureFlags.phoneNumberPrivacy()) {
ContactDiscoveryRefreshV2.refresh(context, listOf(recipient), useCompat = false) ContactDiscoveryRefreshV2.refresh(context, listOf(recipient), useCompat = false, ignoreResults = false)
} else if (FeatureFlags.cdsV2Compat()) { } else if (FeatureFlags.cdsV2Compat()) {
ContactDiscoveryRefreshV2.refresh(context, listOf(recipient), useCompat = true) ContactDiscoveryRefreshV2.refresh(context, listOf(recipient), useCompat = true, ignoreResults = false)
} else if (FeatureFlags.cdsV2LoadTesting()) { } else if (FeatureFlags.cdsV2LoadTesting()) {
loadTestRefresh(context, listOf(recipient)) loadTestRefresh(context, listOf(recipient))
} else { } else {
@ -404,7 +403,7 @@ object ContactDiscovery {
try { try {
v2Future.get() v2Future.get()
} catch (e: Exception) { } catch (e: Throwable) {
Log.w(TAG, "Failed to complete the V2 fetch!", e) Log.w(TAG, "Failed to complete the V2 fetch!", e)
} }

View file

@ -18,6 +18,7 @@ import org.thoughtcrime.securesms.phonenumbers.PhoneNumberFormatter
import org.thoughtcrime.securesms.recipients.Recipient import org.thoughtcrime.securesms.recipients.Recipient
import org.thoughtcrime.securesms.recipients.RecipientId import org.thoughtcrime.securesms.recipients.RecipientId
import org.whispersystems.signalservice.api.push.ACI import org.whispersystems.signalservice.api.push.ACI
import org.whispersystems.signalservice.api.push.exceptions.NonSuccessfulResponseCodeException
import org.whispersystems.signalservice.api.services.CdsiV2Service import org.whispersystems.signalservice.api.services.CdsiV2Service
import java.io.IOException import java.io.IOException
import java.util.Optional import java.util.Optional
@ -44,7 +45,7 @@ object ContactDiscoveryRefreshV2 {
@WorkerThread @WorkerThread
@Synchronized @Synchronized
@JvmStatic @JvmStatic
fun refreshAll(context: Context, useCompat: Boolean, ignoreResults: Boolean = false): ContactDiscovery.RefreshResult { fun refreshAll(context: Context, useCompat: Boolean, ignoreResults: Boolean): ContactDiscovery.RefreshResult {
val recipientE164s: Set<String> = SignalDatabase.recipients.getAllE164s().sanitize() val recipientE164s: Set<String> = SignalDatabase.recipients.getAllE164s().sanitize()
val systemE164s: Set<String> = SystemContactsRepository.getAllDisplayNumbers(context).toE164s(context).sanitize() val systemE164s: Set<String> = SystemContactsRepository.getAllDisplayNumbers(context).toE164s(context).sanitize()
@ -52,7 +53,7 @@ object ContactDiscoveryRefreshV2 {
recipientE164s = recipientE164s, recipientE164s = recipientE164s,
systemE164s = systemE164s, systemE164s = systemE164s,
inputPreviousE164s = SignalDatabase.cds.getAllE164s(), inputPreviousE164s = SignalDatabase.cds.getAllE164s(),
saveToken = true, isPartialRefresh = false,
useCompat = useCompat, useCompat = useCompat,
ignoreResults = ignoreResults ignoreResults = ignoreResults
) )
@ -62,14 +63,14 @@ object ContactDiscoveryRefreshV2 {
@WorkerThread @WorkerThread
@Synchronized @Synchronized
@JvmStatic @JvmStatic
fun refresh(context: Context, inputRecipients: List<Recipient>, useCompat: Boolean, ignoreResults: Boolean = false): ContactDiscovery.RefreshResult { fun refresh(context: Context, inputRecipients: List<Recipient>, useCompat: Boolean, ignoreResults: Boolean): ContactDiscovery.RefreshResult {
val recipients: List<Recipient> = inputRecipients.map { it.resolve() } val recipients: List<Recipient> = inputRecipients.map { it.resolve() }
val inputE164s: Set<String> = recipients.mapNotNull { it.e164.orElse(null) }.toSet() val inputE164s: Set<String> = recipients.mapNotNull { it.e164.orElse(null) }.toSet()
return if (inputE164s.size > MAXIMUM_ONE_OFF_REQUEST_SIZE) { return if (inputE164s.size > MAXIMUM_ONE_OFF_REQUEST_SIZE) {
Log.i(TAG, "List of specific recipients to refresh is too large! (Size: ${recipients.size}). Doing a full refresh instead.") Log.i(TAG, "List of specific recipients to refresh is too large! (Size: ${recipients.size}). Doing a full refresh instead.")
val fullResult: ContactDiscovery.RefreshResult = refreshAll(context, ignoreResults) val fullResult: ContactDiscovery.RefreshResult = refreshAll(context, useCompat = useCompat, ignoreResults = ignoreResults)
val inputIds: Set<RecipientId> = recipients.map { it.id }.toSet() val inputIds: Set<RecipientId> = recipients.map { it.id }.toSet()
ContactDiscovery.RefreshResult( ContactDiscovery.RefreshResult(
@ -81,7 +82,7 @@ object ContactDiscoveryRefreshV2 {
recipientE164s = inputE164s, recipientE164s = inputE164s,
systemE164s = inputE164s, systemE164s = inputE164s,
inputPreviousE164s = emptySet(), inputPreviousE164s = emptySet(),
saveToken = false, isPartialRefresh = true,
useCompat = useCompat, useCompat = useCompat,
ignoreResults = ignoreResults ignoreResults = ignoreResults
) )
@ -93,13 +94,14 @@ object ContactDiscoveryRefreshV2 {
recipientE164s: Set<String>, recipientE164s: Set<String>,
systemE164s: Set<String>, systemE164s: Set<String>,
inputPreviousE164s: Set<String>, inputPreviousE164s: Set<String>,
saveToken: Boolean, isPartialRefresh: Boolean,
useCompat: Boolean, useCompat: Boolean,
ignoreResults: Boolean ignoreResults: Boolean
): ContactDiscovery.RefreshResult { ): ContactDiscovery.RefreshResult {
val stopwatch = Stopwatch("refreshInternal-${if (useCompat) "compat" else "v2"}") val tag = "refreshInternal-${if (useCompat) "compat" else "v2"}"
val stopwatch = Stopwatch(tag)
val previousE164s: Set<String> = if (SignalStore.misc().cdsToken != null) inputPreviousE164s else emptySet() val previousE164s: Set<String> = if (SignalStore.misc().cdsToken != null && !isPartialRefresh) inputPreviousE164s else emptySet()
val allE164s: Set<String> = recipientE164s + systemE164s val allE164s: Set<String> = recipientE164s + systemE164s
val newRawE164s: Set<String> = allE164s - previousE164s val newRawE164s: Set<String> = allE164s - previousE164s
@ -107,40 +109,50 @@ object ContactDiscoveryRefreshV2 {
val newE164s: Set<String> = fuzzyInput.numbers val newE164s: Set<String> = fuzzyInput.numbers
if (newE164s.isEmpty() && previousE164s.isEmpty()) { if (newE164s.isEmpty() && previousE164s.isEmpty()) {
Log.w(TAG, "[refreshInternal] No data to send! Ignoring.") Log.w(TAG, "[$tag] No data to send! Ignoring.")
return ContactDiscovery.RefreshResult(emptySet(), emptyMap()) return ContactDiscovery.RefreshResult(emptySet(), emptyMap())
} }
val token: ByteArray? = if (previousE164s.isNotEmpty()) SignalStore.misc().cdsToken else null val token: ByteArray? = if (previousE164s.isNotEmpty() && !isPartialRefresh) SignalStore.misc().cdsToken else null
stopwatch.split("preamble") stopwatch.split("preamble")
val response: CdsiV2Service.Response = ApplicationDependencies.getSignalServiceAccountManager().getRegisteredUsersWithCdsi( val response: CdsiV2Service.Response = try {
previousE164s, ApplicationDependencies.getSignalServiceAccountManager().getRegisteredUsersWithCdsi(
newE164s, previousE164s,
SignalDatabase.recipients.getAllServiceIdProfileKeyPairs(), newE164s,
useCompat, SignalDatabase.recipients.getAllServiceIdProfileKeyPairs(),
Optional.ofNullable(token), useCompat,
BuildConfig.CDSI_MRENCLAVE Optional.ofNullable(token),
) { tokenToSave -> BuildConfig.CDSI_MRENCLAVE
if (saveToken) { ) { tokenToSave ->
SignalStore.misc().cdsToken = tokenToSave stopwatch.split("network-pre-token")
Log.d(TAG, "Token saved!") if (!isPartialRefresh) {
} else { SignalStore.misc().cdsToken = tokenToSave
Log.d(TAG, "Ignoring token.") SignalDatabase.cds.updateAfterFullCdsQuery(previousE164s + newE164s, allE164s + newE164s)
Log.d(TAG, "Token saved!")
} else {
SignalDatabase.cds.updateAfterPartialCdsQuery(newE164s)
Log.d(TAG, "Ignoring token.")
}
stopwatch.split("cds-db")
} }
} catch (e: NonSuccessfulResponseCodeException) {
if (e.code == 4101) {
Log.w(TAG, "Our token was invalid! Only thing we can do now is clear our local state :(")
SignalStore.misc().cdsToken = null
SignalDatabase.cds.clearAll()
}
throw e
} }
Log.d(TAG, "[refreshInternal] Used ${response.quotaUsedDebugOnly} quota.") Log.d(TAG, "[$tag] Used ${response.quotaUsedDebugOnly} quota.")
stopwatch.split("network") stopwatch.split("network-post-token")
SignalDatabase.cds.updateAfterCdsQuery(newE164s, allE164s + newE164s)
stopwatch.split("cds-db")
val registeredIds: MutableSet<RecipientId> = mutableSetOf() val registeredIds: MutableSet<RecipientId> = mutableSetOf()
val rewrites: MutableMap<String, String> = mutableMapOf() val rewrites: MutableMap<String, String> = mutableMapOf()
if (ignoreResults) { if (ignoreResults) {
Log.w(TAG, "[refreshInternal] Ignoring CDSv2 results.") Log.w(TAG, "[$tag] Ignoring CDSv2 results.")
} else { } else {
if (useCompat) { if (useCompat) {
val transformed: Map<String, ACI?> = response.results.mapValues { entry -> entry.value.aci.orElse(null) } val transformed: Map<String, ACI?> = response.results.mapValues { entry -> entry.value.aci.orElse(null) }

View file

@ -8,7 +8,7 @@ import org.signal.core.util.delete
import org.signal.core.util.logging.Log import org.signal.core.util.logging.Log
import org.signal.core.util.requireNonNullString import org.signal.core.util.requireNonNullString
import org.signal.core.util.select import org.signal.core.util.select
import org.signal.core.util.update import org.signal.core.util.withinTransaction
/** /**
* Keeps track of the numbers we've previously queried CDS for. * Keeps track of the numbers we've previously queried CDS for.
@ -53,32 +53,58 @@ class CdsDatabase(context: Context, databaseHelper: SignalDatabase) : Database(c
} }
/** /**
* @param newE164s The newly-added E164s that we hadn't previously queried for. * Saves the set of e164s used after a full refresh.
* @param seenE164s The E164s that were seen in either the system contacts or recipients table. * @param fullE164s All of the e164s used in the last CDS query (previous and new).
* This should be a superset of [newE164s] * @param seenE164s The E164s that were seen in either the system contacts or recipients table. This is different from [fullE164s] in that [fullE164s]
* * includes every number we've ever seen, even if it's not in our contacts anymore.
*/ */
fun updateAfterCdsQuery(newE164s: Set<String>, seenE164s: Set<String>) { fun updateAfterFullCdsQuery(fullE164s: Set<String>, seenE164s: Set<String>) {
val lastSeen = System.currentTimeMillis() val lastSeen = System.currentTimeMillis()
writableDatabase.beginTransaction() writableDatabase.withinTransaction { db ->
try { val existingE164s: Set<String> = getAllE164s()
val insertValues: List<ContentValues> = newE164s.map { contentValuesOf(E164 to it) } val removedE164s: Set<String> = existingE164s - fullE164s
val addedE164s: Set<String> = fullE164s - existingE164s
SqlUtil.buildBulkInsert(TABLE_NAME, arrayOf(E164), insertValues) if (removedE164s.isNotEmpty()) {
.forEach { writableDatabase.execSQL(it.where, it.whereArgs) } SqlUtil.buildCollectionQuery(E164, removedE164s)
.forEach { db.delete(TABLE_NAME, it.where, it.whereArgs) }
}
val contentValues = contentValuesOf(LAST_SEEN_AT to lastSeen) if (addedE164s.isNotEmpty()) {
val insertValues: List<ContentValues> = addedE164s.map { contentValuesOf(E164 to it) }
SqlUtil.buildCollectionQuery(E164, seenE164s) SqlUtil.buildBulkInsert(TABLE_NAME, arrayOf(E164), insertValues)
.forEach { query -> writableDatabase.update(TABLE_NAME, contentValues, query.where, query.whereArgs) } .forEach { db.execSQL(it.where, it.whereArgs) }
}
writableDatabase.setTransactionSuccessful() if (seenE164s.isNotEmpty()) {
} finally { val contentValues = contentValuesOf(LAST_SEEN_AT to lastSeen)
writableDatabase.endTransaction()
SqlUtil.buildCollectionQuery(E164, seenE164s)
.forEach { query -> db.update(TABLE_NAME, contentValues, query.where, query.whereArgs) }
}
} }
} }
/**
* Updates after a partial CDS query. Will not insert new entries. Instead, this will simply update the lastSeen timestamp of any entry we already have.
* @param seenE164s The newly-added E164s that we hadn't previously queried for.
*/
fun updateAfterPartialCdsQuery(seenE164s: Set<String>) {
val lastSeen = System.currentTimeMillis()
writableDatabase.withinTransaction { db ->
val contentValues = contentValuesOf(LAST_SEEN_AT to lastSeen)
SqlUtil.buildCollectionQuery(E164, seenE164s)
.forEach { query -> db.update(TABLE_NAME, contentValues, query.where, query.whereArgs) }
}
}
/**
* Wipes the entire table.
*/
fun clearAll() { fun clearAll() {
writableDatabase writableDatabase
.delete(TABLE_NAME) .delete(TABLE_NAME)

View file

@ -549,7 +549,11 @@ public class SignalServiceAccountManager {
if (serviceResponse.getResult().isPresent()) { if (serviceResponse.getResult().isPresent()) {
return serviceResponse.getResult().get(); return serviceResponse.getResult().get();
} else if (serviceResponse.getApplicationError().isPresent()) { } else if (serviceResponse.getApplicationError().isPresent()) {
throw new IOException(serviceResponse.getApplicationError().get()); if (serviceResponse.getApplicationError().get() instanceof IOException) {
throw (IOException) serviceResponse.getApplicationError().get();
} else {
throw new IOException(serviceResponse.getApplicationError().get());
}
} else if (serviceResponse.getExecutionError().isPresent()) { } else if (serviceResponse.getExecutionError().isPresent()) {
throw new IOException(serviceResponse.getExecutionError().get()); throw new IOException(serviceResponse.getExecutionError().get());
} else { } else {