Fix token mismatch issues when using CDSv2.
This commit is contained in:
parent
f1bcc756d3
commit
658741be52
4 changed files with 96 additions and 55 deletions
|
@ -37,7 +37,6 @@ import org.thoughtcrime.securesms.util.Util
|
|||
import org.whispersystems.signalservice.api.push.SignalServiceAddress
|
||||
import org.whispersystems.signalservice.api.util.UuidUtil
|
||||
import java.io.IOException
|
||||
import java.lang.Exception
|
||||
import java.util.Calendar
|
||||
import java.util.concurrent.Callable
|
||||
import java.util.concurrent.ExecutionException
|
||||
|
@ -80,9 +79,9 @@ object ContactDiscovery {
|
|||
descriptor = "refresh-all",
|
||||
refresh = {
|
||||
if (FeatureFlags.phoneNumberPrivacy()) {
|
||||
ContactDiscoveryRefreshV2.refreshAll(context, useCompat = false)
|
||||
ContactDiscoveryRefreshV2.refreshAll(context, useCompat = false, ignoreResults = false)
|
||||
} else if (FeatureFlags.cdsV2Compat()) {
|
||||
ContactDiscoveryRefreshV2.refreshAll(context, useCompat = true)
|
||||
ContactDiscoveryRefreshV2.refreshAll(context, useCompat = true, ignoreResults = false)
|
||||
} else if (FeatureFlags.cdsV2LoadTesting()) {
|
||||
loadTestRefreshAll(context)
|
||||
} else {
|
||||
|
@ -105,9 +104,9 @@ object ContactDiscovery {
|
|||
descriptor = "refresh-multiple",
|
||||
refresh = {
|
||||
if (FeatureFlags.phoneNumberPrivacy()) {
|
||||
ContactDiscoveryRefreshV2.refresh(context, recipients, useCompat = false)
|
||||
ContactDiscoveryRefreshV2.refresh(context, recipients, useCompat = false, ignoreResults = false)
|
||||
} else if (FeatureFlags.cdsV2Compat()) {
|
||||
ContactDiscoveryRefreshV2.refresh(context, recipients, useCompat = true)
|
||||
ContactDiscoveryRefreshV2.refresh(context, recipients, useCompat = true, ignoreResults = false)
|
||||
} else if (FeatureFlags.cdsV2LoadTesting()) {
|
||||
loadTestRefresh(context, recipients)
|
||||
} else {
|
||||
|
@ -128,9 +127,9 @@ object ContactDiscovery {
|
|||
descriptor = "refresh-single",
|
||||
refresh = {
|
||||
if (FeatureFlags.phoneNumberPrivacy()) {
|
||||
ContactDiscoveryRefreshV2.refresh(context, listOf(recipient), useCompat = false)
|
||||
ContactDiscoveryRefreshV2.refresh(context, listOf(recipient), useCompat = false, ignoreResults = false)
|
||||
} else if (FeatureFlags.cdsV2Compat()) {
|
||||
ContactDiscoveryRefreshV2.refresh(context, listOf(recipient), useCompat = true)
|
||||
ContactDiscoveryRefreshV2.refresh(context, listOf(recipient), useCompat = true, ignoreResults = false)
|
||||
} else if (FeatureFlags.cdsV2LoadTesting()) {
|
||||
loadTestRefresh(context, listOf(recipient))
|
||||
} else {
|
||||
|
@ -404,7 +403,7 @@ object ContactDiscovery {
|
|||
|
||||
try {
|
||||
v2Future.get()
|
||||
} catch (e: Exception) {
|
||||
} catch (e: Throwable) {
|
||||
Log.w(TAG, "Failed to complete the V2 fetch!", e)
|
||||
}
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@ import org.thoughtcrime.securesms.phonenumbers.PhoneNumberFormatter
|
|||
import org.thoughtcrime.securesms.recipients.Recipient
|
||||
import org.thoughtcrime.securesms.recipients.RecipientId
|
||||
import org.whispersystems.signalservice.api.push.ACI
|
||||
import org.whispersystems.signalservice.api.push.exceptions.NonSuccessfulResponseCodeException
|
||||
import org.whispersystems.signalservice.api.services.CdsiV2Service
|
||||
import java.io.IOException
|
||||
import java.util.Optional
|
||||
|
@ -44,7 +45,7 @@ object ContactDiscoveryRefreshV2 {
|
|||
@WorkerThread
|
||||
@Synchronized
|
||||
@JvmStatic
|
||||
fun refreshAll(context: Context, useCompat: Boolean, ignoreResults: Boolean = false): ContactDiscovery.RefreshResult {
|
||||
fun refreshAll(context: Context, useCompat: Boolean, ignoreResults: Boolean): ContactDiscovery.RefreshResult {
|
||||
val recipientE164s: Set<String> = SignalDatabase.recipients.getAllE164s().sanitize()
|
||||
val systemE164s: Set<String> = SystemContactsRepository.getAllDisplayNumbers(context).toE164s(context).sanitize()
|
||||
|
||||
|
@ -52,7 +53,7 @@ object ContactDiscoveryRefreshV2 {
|
|||
recipientE164s = recipientE164s,
|
||||
systemE164s = systemE164s,
|
||||
inputPreviousE164s = SignalDatabase.cds.getAllE164s(),
|
||||
saveToken = true,
|
||||
isPartialRefresh = false,
|
||||
useCompat = useCompat,
|
||||
ignoreResults = ignoreResults
|
||||
)
|
||||
|
@ -62,14 +63,14 @@ object ContactDiscoveryRefreshV2 {
|
|||
@WorkerThread
|
||||
@Synchronized
|
||||
@JvmStatic
|
||||
fun refresh(context: Context, inputRecipients: List<Recipient>, useCompat: Boolean, ignoreResults: Boolean = false): ContactDiscovery.RefreshResult {
|
||||
fun refresh(context: Context, inputRecipients: List<Recipient>, useCompat: Boolean, ignoreResults: Boolean): ContactDiscovery.RefreshResult {
|
||||
val recipients: List<Recipient> = inputRecipients.map { it.resolve() }
|
||||
val inputE164s: Set<String> = recipients.mapNotNull { it.e164.orElse(null) }.toSet()
|
||||
|
||||
return if (inputE164s.size > MAXIMUM_ONE_OFF_REQUEST_SIZE) {
|
||||
Log.i(TAG, "List of specific recipients to refresh is too large! (Size: ${recipients.size}). Doing a full refresh instead.")
|
||||
|
||||
val fullResult: ContactDiscovery.RefreshResult = refreshAll(context, ignoreResults)
|
||||
val fullResult: ContactDiscovery.RefreshResult = refreshAll(context, useCompat = useCompat, ignoreResults = ignoreResults)
|
||||
val inputIds: Set<RecipientId> = recipients.map { it.id }.toSet()
|
||||
|
||||
ContactDiscovery.RefreshResult(
|
||||
|
@ -81,7 +82,7 @@ object ContactDiscoveryRefreshV2 {
|
|||
recipientE164s = inputE164s,
|
||||
systemE164s = inputE164s,
|
||||
inputPreviousE164s = emptySet(),
|
||||
saveToken = false,
|
||||
isPartialRefresh = true,
|
||||
useCompat = useCompat,
|
||||
ignoreResults = ignoreResults
|
||||
)
|
||||
|
@ -93,13 +94,14 @@ object ContactDiscoveryRefreshV2 {
|
|||
recipientE164s: Set<String>,
|
||||
systemE164s: Set<String>,
|
||||
inputPreviousE164s: Set<String>,
|
||||
saveToken: Boolean,
|
||||
isPartialRefresh: Boolean,
|
||||
useCompat: Boolean,
|
||||
ignoreResults: Boolean
|
||||
): ContactDiscovery.RefreshResult {
|
||||
val stopwatch = Stopwatch("refreshInternal-${if (useCompat) "compat" else "v2"}")
|
||||
val tag = "refreshInternal-${if (useCompat) "compat" else "v2"}"
|
||||
val stopwatch = Stopwatch(tag)
|
||||
|
||||
val previousE164s: Set<String> = if (SignalStore.misc().cdsToken != null) inputPreviousE164s else emptySet()
|
||||
val previousE164s: Set<String> = if (SignalStore.misc().cdsToken != null && !isPartialRefresh) inputPreviousE164s else emptySet()
|
||||
|
||||
val allE164s: Set<String> = recipientE164s + systemE164s
|
||||
val newRawE164s: Set<String> = allE164s - previousE164s
|
||||
|
@ -107,40 +109,50 @@ object ContactDiscoveryRefreshV2 {
|
|||
val newE164s: Set<String> = fuzzyInput.numbers
|
||||
|
||||
if (newE164s.isEmpty() && previousE164s.isEmpty()) {
|
||||
Log.w(TAG, "[refreshInternal] No data to send! Ignoring.")
|
||||
Log.w(TAG, "[$tag] No data to send! Ignoring.")
|
||||
return ContactDiscovery.RefreshResult(emptySet(), emptyMap())
|
||||
}
|
||||
|
||||
val token: ByteArray? = if (previousE164s.isNotEmpty()) SignalStore.misc().cdsToken else null
|
||||
val token: ByteArray? = if (previousE164s.isNotEmpty() && !isPartialRefresh) SignalStore.misc().cdsToken else null
|
||||
|
||||
stopwatch.split("preamble")
|
||||
|
||||
val response: CdsiV2Service.Response = ApplicationDependencies.getSignalServiceAccountManager().getRegisteredUsersWithCdsi(
|
||||
previousE164s,
|
||||
newE164s,
|
||||
SignalDatabase.recipients.getAllServiceIdProfileKeyPairs(),
|
||||
useCompat,
|
||||
Optional.ofNullable(token),
|
||||
BuildConfig.CDSI_MRENCLAVE
|
||||
) { tokenToSave ->
|
||||
if (saveToken) {
|
||||
SignalStore.misc().cdsToken = tokenToSave
|
||||
Log.d(TAG, "Token saved!")
|
||||
} else {
|
||||
Log.d(TAG, "Ignoring token.")
|
||||
val response: CdsiV2Service.Response = try {
|
||||
ApplicationDependencies.getSignalServiceAccountManager().getRegisteredUsersWithCdsi(
|
||||
previousE164s,
|
||||
newE164s,
|
||||
SignalDatabase.recipients.getAllServiceIdProfileKeyPairs(),
|
||||
useCompat,
|
||||
Optional.ofNullable(token),
|
||||
BuildConfig.CDSI_MRENCLAVE
|
||||
) { tokenToSave ->
|
||||
stopwatch.split("network-pre-token")
|
||||
if (!isPartialRefresh) {
|
||||
SignalStore.misc().cdsToken = tokenToSave
|
||||
SignalDatabase.cds.updateAfterFullCdsQuery(previousE164s + newE164s, allE164s + newE164s)
|
||||
Log.d(TAG, "Token saved!")
|
||||
} else {
|
||||
SignalDatabase.cds.updateAfterPartialCdsQuery(newE164s)
|
||||
Log.d(TAG, "Ignoring token.")
|
||||
}
|
||||
stopwatch.split("cds-db")
|
||||
}
|
||||
} catch (e: NonSuccessfulResponseCodeException) {
|
||||
if (e.code == 4101) {
|
||||
Log.w(TAG, "Our token was invalid! Only thing we can do now is clear our local state :(")
|
||||
SignalStore.misc().cdsToken = null
|
||||
SignalDatabase.cds.clearAll()
|
||||
}
|
||||
throw e
|
||||
}
|
||||
Log.d(TAG, "[refreshInternal] Used ${response.quotaUsedDebugOnly} quota.")
|
||||
stopwatch.split("network")
|
||||
|
||||
SignalDatabase.cds.updateAfterCdsQuery(newE164s, allE164s + newE164s)
|
||||
stopwatch.split("cds-db")
|
||||
Log.d(TAG, "[$tag] Used ${response.quotaUsedDebugOnly} quota.")
|
||||
stopwatch.split("network-post-token")
|
||||
|
||||
val registeredIds: MutableSet<RecipientId> = mutableSetOf()
|
||||
val rewrites: MutableMap<String, String> = mutableMapOf()
|
||||
|
||||
if (ignoreResults) {
|
||||
Log.w(TAG, "[refreshInternal] Ignoring CDSv2 results.")
|
||||
Log.w(TAG, "[$tag] Ignoring CDSv2 results.")
|
||||
} else {
|
||||
if (useCompat) {
|
||||
val transformed: Map<String, ACI?> = response.results.mapValues { entry -> entry.value.aci.orElse(null) }
|
||||
|
|
|
@ -8,7 +8,7 @@ import org.signal.core.util.delete
|
|||
import org.signal.core.util.logging.Log
|
||||
import org.signal.core.util.requireNonNullString
|
||||
import org.signal.core.util.select
|
||||
import org.signal.core.util.update
|
||||
import org.signal.core.util.withinTransaction
|
||||
|
||||
/**
|
||||
* Keeps track of the numbers we've previously queried CDS for.
|
||||
|
@ -53,32 +53,58 @@ class CdsDatabase(context: Context, databaseHelper: SignalDatabase) : Database(c
|
|||
}
|
||||
|
||||
/**
|
||||
* @param newE164s The newly-added E164s that we hadn't previously queried for.
|
||||
* @param seenE164s The E164s that were seen in either the system contacts or recipients table.
|
||||
* This should be a superset of [newE164s]
|
||||
*
|
||||
* Saves the set of e164s used after a full refresh.
|
||||
* @param fullE164s All of the e164s used in the last CDS query (previous and new).
|
||||
* @param seenE164s The E164s that were seen in either the system contacts or recipients table. This is different from [fullE164s] in that [fullE164s]
|
||||
* includes every number we've ever seen, even if it's not in our contacts anymore.
|
||||
*/
|
||||
fun updateAfterCdsQuery(newE164s: Set<String>, seenE164s: Set<String>) {
|
||||
fun updateAfterFullCdsQuery(fullE164s: Set<String>, seenE164s: Set<String>) {
|
||||
val lastSeen = System.currentTimeMillis()
|
||||
|
||||
writableDatabase.beginTransaction()
|
||||
try {
|
||||
val insertValues: List<ContentValues> = newE164s.map { contentValuesOf(E164 to it) }
|
||||
writableDatabase.withinTransaction { db ->
|
||||
val existingE164s: Set<String> = getAllE164s()
|
||||
val removedE164s: Set<String> = existingE164s - fullE164s
|
||||
val addedE164s: Set<String> = fullE164s - existingE164s
|
||||
|
||||
SqlUtil.buildBulkInsert(TABLE_NAME, arrayOf(E164), insertValues)
|
||||
.forEach { writableDatabase.execSQL(it.where, it.whereArgs) }
|
||||
if (removedE164s.isNotEmpty()) {
|
||||
SqlUtil.buildCollectionQuery(E164, removedE164s)
|
||||
.forEach { db.delete(TABLE_NAME, it.where, it.whereArgs) }
|
||||
}
|
||||
|
||||
val contentValues = contentValuesOf(LAST_SEEN_AT to lastSeen)
|
||||
if (addedE164s.isNotEmpty()) {
|
||||
val insertValues: List<ContentValues> = addedE164s.map { contentValuesOf(E164 to it) }
|
||||
|
||||
SqlUtil.buildCollectionQuery(E164, seenE164s)
|
||||
.forEach { query -> writableDatabase.update(TABLE_NAME, contentValues, query.where, query.whereArgs) }
|
||||
SqlUtil.buildBulkInsert(TABLE_NAME, arrayOf(E164), insertValues)
|
||||
.forEach { db.execSQL(it.where, it.whereArgs) }
|
||||
}
|
||||
|
||||
writableDatabase.setTransactionSuccessful()
|
||||
} finally {
|
||||
writableDatabase.endTransaction()
|
||||
if (seenE164s.isNotEmpty()) {
|
||||
val contentValues = contentValuesOf(LAST_SEEN_AT to lastSeen)
|
||||
|
||||
SqlUtil.buildCollectionQuery(E164, seenE164s)
|
||||
.forEach { query -> db.update(TABLE_NAME, contentValues, query.where, query.whereArgs) }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates after a partial CDS query. Will not insert new entries. Instead, this will simply update the lastSeen timestamp of any entry we already have.
|
||||
* @param seenE164s The newly-added E164s that we hadn't previously queried for.
|
||||
*/
|
||||
fun updateAfterPartialCdsQuery(seenE164s: Set<String>) {
|
||||
val lastSeen = System.currentTimeMillis()
|
||||
|
||||
writableDatabase.withinTransaction { db ->
|
||||
val contentValues = contentValuesOf(LAST_SEEN_AT to lastSeen)
|
||||
|
||||
SqlUtil.buildCollectionQuery(E164, seenE164s)
|
||||
.forEach { query -> db.update(TABLE_NAME, contentValues, query.where, query.whereArgs) }
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Wipes the entire table.
|
||||
*/
|
||||
fun clearAll() {
|
||||
writableDatabase
|
||||
.delete(TABLE_NAME)
|
||||
|
|
|
@ -549,7 +549,11 @@ public class SignalServiceAccountManager {
|
|||
if (serviceResponse.getResult().isPresent()) {
|
||||
return serviceResponse.getResult().get();
|
||||
} else if (serviceResponse.getApplicationError().isPresent()) {
|
||||
throw new IOException(serviceResponse.getApplicationError().get());
|
||||
if (serviceResponse.getApplicationError().get() instanceof IOException) {
|
||||
throw (IOException) serviceResponse.getApplicationError().get();
|
||||
} else {
|
||||
throw new IOException(serviceResponse.getApplicationError().get());
|
||||
}
|
||||
} else if (serviceResponse.getExecutionError().isPresent()) {
|
||||
throw new IOException(serviceResponse.getExecutionError().get());
|
||||
} else {
|
||||
|
|
Loading…
Add table
Reference in a new issue