Improve handling of unknown IDs in storage service.

This commit is contained in:
Greyson Parrelli 2022-04-04 09:54:50 -04:00 committed by Cody Henthorne
parent e2c54eef77
commit b34ca8ca2f
7 changed files with 139 additions and 77 deletions

View file

@ -33,6 +33,7 @@ import org.thoughtcrime.securesms.jobs.StorageForcePushJob
import org.thoughtcrime.securesms.jobs.SubscriptionReceiptRequestResponseJob
import org.thoughtcrime.securesms.keyvalue.SignalStore
import org.thoughtcrime.securesms.payments.DataExportUtil
import org.thoughtcrime.securesms.storage.StorageSyncHelper
import org.thoughtcrime.securesms.util.ConversationUtil
import org.thoughtcrime.securesms.util.FeatureFlags
import java.util.Optional
@ -136,11 +137,19 @@ class InternalSettingsFragment : DSLSettingsFragment(R.string.preferences__inter
}
)
clickPref(
title = DSLSettingsText.from(R.string.preferences__internal_sync_now),
summary = DSLSettingsText.from(R.string.preferences__internal_sync_now_description),
onClick = {
enqueueStorageServiceSync()
}
)
clickPref(
title = DSLSettingsText.from(R.string.preferences__internal_force_storage_service_sync),
summary = DSLSettingsText.from(R.string.preferences__internal_force_storage_service_sync_description),
onClick = {
forceStorageServiceSync()
enqueueStorageServiceForcePush()
}
)
@ -475,7 +484,12 @@ class InternalSettingsFragment : DSLSettingsFragment(R.string.preferences__inter
}
}
private fun forceStorageServiceSync() {
private fun enqueueStorageServiceSync() {
StorageSyncHelper.scheduleSyncForDataChange()
Toast.makeText(context, "Scheduled routine storage sync", Toast.LENGTH_SHORT).show()
}
private fun enqueueStorageServiceForcePush() {
ApplicationDependencies.getJobManager().add(StorageForcePushJob())
Toast.makeText(context, "Scheduled storage force push", Toast.LENGTH_SHORT).show()
}

View file

@ -9,6 +9,7 @@ import androidx.annotation.Nullable;
import com.annimon.stream.Stream;
import org.signal.core.util.CursorUtil;
import org.thoughtcrime.securesms.util.Base64;
import org.signal.core.util.SqlUtil;
import org.whispersystems.signalservice.api.storage.SignalStorageRecord;
@ -46,13 +47,10 @@ public class UnknownStorageIdDatabase extends Database {
public List<StorageId> getAllUnknownIds() {
List<StorageId> keys = new ArrayList<>();
String query = TYPE + " > ?";
String[] args = SqlUtil.buildArgs(StorageId.largestKnownType());
try (Cursor cursor = databaseHelper.getSignalReadableDatabase().query(TABLE_NAME, null, query, args, null, null, null)) {
try (Cursor cursor = databaseHelper.getSignalReadableDatabase().query(TABLE_NAME, null, null, null, null, null, null)) {
while (cursor != null && cursor.moveToNext()) {
String keyEncoded = cursor.getString(cursor.getColumnIndexOrThrow(STORAGE_ID));
int type = cursor.getInt(cursor.getColumnIndexOrThrow(TYPE));
String keyEncoded = CursorUtil.requireString(cursor, STORAGE_ID);
int type = CursorUtil.requireInt(cursor, TYPE);
try {
keys.add(StorageId.forType(Base64.decode(keyEncoded), type));
} catch (IOException e) {
@ -64,13 +62,35 @@ public class UnknownStorageIdDatabase extends Database {
return keys;
}
/**
* Gets all StorageIds of items with the specified types.
*/
public List<StorageId> getAllWithTypes(List<Integer> types) {
List<StorageId> ids = new ArrayList<>();
SqlUtil.Query query = SqlUtil.buildCollectionQuery(TYPE, types);
try (Cursor cursor = databaseHelper.getSignalReadableDatabase().query(TABLE_NAME, null, query.getWhere(), query.getWhereArgs(), null, null, null)) {
while (cursor != null && cursor.moveToNext()) {
String keyEncoded = CursorUtil.requireString(cursor, STORAGE_ID);
int type = CursorUtil.requireInt(cursor, TYPE);
try {
ids.add(StorageId.forType(Base64.decode(keyEncoded), type));
} catch (IOException e) {
throw new AssertionError(e);
}
}
}
return ids;
}
public @Nullable SignalStorageRecord getById(@NonNull byte[] rawId) {
String query = STORAGE_ID + " = ?";
String[] args = new String[] { Base64.encodeBytes(rawId) };
try (Cursor cursor = databaseHelper.getSignalReadableDatabase().query(TABLE_NAME, null, query, args, null, null, null)) {
if (cursor != null && cursor.moveToFirst()) {
int type = cursor.getInt(cursor.getColumnIndexOrThrow(TYPE));
int type = CursorUtil.requireInt(cursor, TYPE);
return SignalStorageRecord.forUnknown(StorageId.forType(rawId, type));
} else {
return null;
@ -78,22 +98,6 @@ public class UnknownStorageIdDatabase extends Database {
}
}
public void applyStorageSyncUpdates(@NonNull Collection<SignalStorageRecord> inserts,
@NonNull Collection<SignalStorageRecord> deletes)
{
SQLiteDatabase db = databaseHelper.getSignalWritableDatabase();
db.beginTransaction();
try {
insert(inserts);
delete(Stream.of(deletes).map(SignalStorageRecord::getId).toList());
db.setTransactionSuccessful();
} finally {
db.endTransaction();
}
}
public void insert(@NonNull Collection<SignalStorageRecord> inserts) {
SQLiteDatabase db = databaseHelper.getSignalWritableDatabase();
@ -120,13 +124,6 @@ public class UnknownStorageIdDatabase extends Database {
}
}
public void deleteByType(int type) {
String query = TYPE + " = ?";
String[] args = new String[]{String.valueOf(type)};
databaseHelper.getSignalWritableDatabase().delete(TABLE_NAME, query, args);
}
public void deleteAll() {
databaseHelper.getSignalWritableDatabase().delete(TABLE_NAME, null, null);
}

View file

@ -197,8 +197,9 @@ object SignalDatabaseMigrations {
private const val STORY_SENDS = 136
private const val STORY_TYPE_AND_DISTRIBUTION = 137
private const val CLEAN_DELETED_DISTRIBUTION_LISTS = 138
private const val REMOVE_KNOWN_UNKNOWNS = 139
const val DATABASE_VERSION = 138
const val DATABASE_VERSION = 139
@JvmStatic
fun migrate(context: Application, db: SQLiteDatabase, oldVersion: Int, newVersion: Int) {
@ -2535,6 +2536,11 @@ object SignalDatabaseMigrations {
""".trimIndent()
)
}
if (oldVersion < REMOVE_KNOWN_UNKNOWNS) {
val count: Int = db.delete("storage_key", "type <= ?", SqlUtil.buildArgs(4))
Log.i(TAG, "Cleaned up $count invalid unknown records.")
}
}
@JvmStatic

View file

@ -54,7 +54,6 @@ import org.whispersystems.signalservice.api.storage.StorageId;
import org.whispersystems.signalservice.api.storage.StorageKey;
import org.whispersystems.signalservice.internal.push.SignalServiceProtos;
import org.whispersystems.signalservice.internal.storage.protos.ManifestRecord;
import org.whispersystems.signalservice.internal.storage.protos.StoryDistributionListRecord;
import java.io.IOException;
import java.util.ArrayList;
@ -65,6 +64,7 @@ import java.util.LinkedList;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
/**
* Does a full sync of our local storage state with the remote storage state. Will write any pending
@ -107,9 +107,10 @@ import java.util.concurrent.TimeUnit;
* the diff in IDs.
* - Then, we fetch the actual records that correspond to the remote-only IDs.
* - Afterwards, we take those records and merge them into our local data store.
* - Finally, we assume that our local state represents the most up-to-date information, and so we
* - Next, we assume that our local state represents the most up-to-date information, and so we
* calculate and write a change set that represents the diff between our state and the remote
* state.
* - Finally, handle any possible records in our "unknown ID store" that might have become known to us.
*
* Of course, you'll notice that there's a lot of code to support that goal. That's mostly because
* converting local data into a format that can be compared with, merged, and eventually written
@ -249,7 +250,7 @@ public class StorageSyncJob extends BaseJob {
if (remoteManifest.getVersion() > localManifest.getVersion()) {
Log.i(TAG, "[Remote Sync] Newer manifest version found!");
List<StorageId> localStorageIdsBeforeMerge = getAllLocalStorageIds(context, self);
List<StorageId> localStorageIdsBeforeMerge = getAllLocalStorageIds(self);
IdDifferenceResult idDifference = StorageSyncHelper.findIdDifference(remoteManifest.getStorageIds(), localStorageIdsBeforeMerge);
if (idDifference.hasTypeMismatches() && SignalStore.account().isPrimaryDevice()) {
@ -264,53 +265,23 @@ public class StorageSyncJob extends BaseJob {
if (!idDifference.isEmpty()) {
Log.i(TAG, "[Remote Sync] Retrieving records for key difference.");
List<SignalStorageRecord> remoteOnly = accountManager.readStorageRecords(storageServiceKey, idDifference.getRemoteOnlyIds());
List<SignalStorageRecord> remoteOnlyRecords = accountManager.readStorageRecords(storageServiceKey, idDifference.getRemoteOnlyIds());
stopwatch.split("remote-records");
if (remoteOnly.size() != idDifference.getRemoteOnlyIds().size()) {
Log.w(TAG, "[Remote Sync] Could not find all remote-only records! Requested: " + idDifference.getRemoteOnlyIds().size() + ", Found: " + remoteOnly.size() + ". These stragglers should naturally get deleted during the sync.");
if (remoteOnlyRecords.size() != idDifference.getRemoteOnlyIds().size()) {
Log.w(TAG, "[Remote Sync] Could not find all remote-only records! Requested: " + idDifference.getRemoteOnlyIds().size() + ", Found: " + remoteOnlyRecords.size() + ". These stragglers should naturally get deleted during the sync.");
}
List<SignalContactRecord> remoteContacts = new LinkedList<>();
List<SignalGroupV1Record> remoteGv1 = new LinkedList<>();
List<SignalGroupV2Record> remoteGv2 = new LinkedList<>();
List<SignalAccountRecord> remoteAccount = new LinkedList<>();
List<SignalStorageRecord> remoteUnknown = new LinkedList<>();
List<SignalStoryDistributionListRecord> remoteStoryDistributionLists = new LinkedList<>();
for (SignalStorageRecord remote : remoteOnly) {
if (remote.getContact().isPresent()) {
remoteContacts.add(remote.getContact().get());
} else if (remote.getGroupV1().isPresent()) {
remoteGv1.add(remote.getGroupV1().get());
} else if (remote.getGroupV2().isPresent()) {
remoteGv2.add(remote.getGroupV2().get());
} else if (remote.getAccount().isPresent()) {
remoteAccount.add(remote.getAccount().get());
} else if (remote.getStoryDistributionList().isPresent()) {
remoteStoryDistributionLists.add(remote.getStoryDistributionList().get());
} else if (remote.getId().isUnknown()) {
remoteUnknown.add(remote);
} else {
Log.w(TAG, "Bad record! Type is a known value (" + remote.getId().getType() + "), but doesn't have a matching inner record. Dropping it.");
}
}
StorageRecordCollection remoteOnly = new StorageRecordCollection(remoteOnlyRecords);
db.beginTransaction();
try {
self = freshSelf();
Log.i(TAG, "[Remote Sync] Remote-Only :: Contacts: " + remoteOnly.contacts.size() + ", GV1: " + remoteOnly.gv1.size() + ", GV2: " + remoteOnly.gv2.size() + ", Account: " + remoteOnly.account.size() + ", DLists: " + remoteOnly.storyDistributionLists.size());
Log.i(TAG, "[Remote Sync] Remote-Only :: Contacts: " + remoteContacts.size() + ", GV1: " + remoteGv1.size() + ", GV2: " + remoteGv2.size() + ", Account: " + remoteAccount.size());
processKnownRecords(context, remoteOnly);
new ContactRecordProcessor(context, self).process(remoteContacts, StorageSyncHelper.KEY_GENERATOR);
new GroupV1RecordProcessor(context).process(remoteGv1, StorageSyncHelper.KEY_GENERATOR);
new GroupV2RecordProcessor(context).process(remoteGv2, StorageSyncHelper.KEY_GENERATOR);
self = freshSelf();
new AccountRecordProcessor(context, self).process(remoteAccount, StorageSyncHelper.KEY_GENERATOR);
new StoryDistributionListRecordProcessor().process(remoteStoryDistributionLists, StorageSyncHelper.KEY_GENERATOR);
List<SignalStorageRecord> unknownInserts = remoteUnknown;
List<SignalStorageRecord> unknownInserts = remoteOnly.unknown;
List<StorageId> unknownDeletes = Stream.of(idDifference.getLocalOnlyIds()).filter(StorageId::isUnknown).toList();
Log.i(TAG, "[Remote Sync] Unknowns :: " + unknownInserts.size() + " inserts, " + unknownDeletes.size() + " deletes");
@ -344,7 +315,7 @@ public class StorageSyncJob extends BaseJob {
try {
self = freshSelf();
List<StorageId> localStorageIds = getAllLocalStorageIds(context, self);
List<StorageId> localStorageIds = getAllLocalStorageIds(self);
IdDifferenceResult idDifference = StorageSyncHelper.findIdDifference(remoteManifest.getStorageIds(), localStorageIds);
List<SignalStorageRecord> remoteInserts = buildLocalStorageRecords(context, self, idDifference.getLocalOnlyIds());
List<byte[]> remoteDeletes = Stream.of(idDifference.getRemoteOnlyIds()).map(StorageId::getRaw).toList();
@ -384,6 +355,32 @@ public class StorageSyncJob extends BaseJob {
Log.i(TAG, "No remote writes needed. Still at version: " + remoteManifest.getVersion());
}
List<Integer> knownTypes = getKnownTypes();
List<StorageId> knownUnknownIds = SignalDatabase.unknownStorageIds().getAllWithTypes(knownTypes);
if (knownUnknownIds.size() > 0) {
Log.i(TAG, "We have " + knownUnknownIds.size() + " unknown records that we can now process.");
List<SignalStorageRecord> remote = accountManager.readStorageRecords(storageServiceKey, knownUnknownIds);
StorageRecordCollection records = new StorageRecordCollection(remote);
Log.i(TAG, "Found " + remote.size() + " of the known-unknowns remotely.");
db.beginTransaction();
try {
processKnownRecords(context, records);
SignalDatabase.unknownStorageIds().getAllWithTypes(knownTypes);
db.setTransactionSuccessful();
} finally {
db.endTransaction();
}
Log.i(TAG, "Enqueueing a storage sync job to handle any possible merges after applying unknown records.");
ApplicationDependencies.getJobManager().add(new StorageSyncJob());
}
stopwatch.split("known-unknowns");
if (needsForcePush && SignalStore.account().isPrimaryDevice()) {
Log.w(TAG, "Scheduling a force push.");
ApplicationDependencies.getJobManager().add(new StorageForcePushJob());
@ -393,7 +390,17 @@ public class StorageSyncJob extends BaseJob {
return needsMultiDeviceSync;
}
private static @NonNull List<StorageId> getAllLocalStorageIds(@NonNull Context context, @NonNull Recipient self) {
private static void processKnownRecords(@NonNull Context context, @NonNull StorageRecordCollection records) throws IOException {
Recipient self = freshSelf();
new ContactRecordProcessor(context, self).process(records.contacts, StorageSyncHelper.KEY_GENERATOR);
new GroupV1RecordProcessor(context).process(records.gv1, StorageSyncHelper.KEY_GENERATOR);
new GroupV2RecordProcessor(context).process(records.gv2, StorageSyncHelper.KEY_GENERATOR);
self = freshSelf();
new AccountRecordProcessor(context, self).process(records.account, StorageSyncHelper.KEY_GENERATOR);
new StoryDistributionListRecordProcessor().process(records.storyDistributionLists, StorageSyncHelper.KEY_GENERATOR);
}
private static @NonNull List<StorageId> getAllLocalStorageIds(@NonNull Recipient self) {
return Util.concatenatedList(SignalDatabase.recipients().getContactStorageSyncIds(),
Collections.singletonList(StorageId.forAccount(self.getStorageServiceId())),
SignalDatabase.unknownStorageIds().getAllUnknownIds());
@ -460,6 +467,42 @@ public class StorageSyncJob extends BaseJob {
return Recipient.self();
}
private static List<Integer> getKnownTypes() {
return Arrays.stream(ManifestRecord.Identifier.Type.values())
.filter(it -> !it.equals(ManifestRecord.Identifier.Type.UNKNOWN) && !it.equals(ManifestRecord.Identifier.Type.UNRECOGNIZED))
.map(it -> it.getNumber())
.collect(Collectors.toList());
}
private static final class StorageRecordCollection {
final List<SignalContactRecord> contacts = new LinkedList<>();
final List<SignalGroupV1Record> gv1 = new LinkedList<>();
final List<SignalGroupV2Record> gv2 = new LinkedList<>();
final List<SignalAccountRecord> account = new LinkedList<>();
final List<SignalStorageRecord> unknown = new LinkedList<>();
final List<SignalStoryDistributionListRecord> storyDistributionLists = new LinkedList<>();
StorageRecordCollection(Collection<SignalStorageRecord> records) {
for (SignalStorageRecord record : records) {
if (record.getContact().isPresent()) {
contacts.add(record.getContact().get());
} else if (record.getGroupV1().isPresent()) {
gv1.add(record.getGroupV1().get());
} else if (record.getGroupV2().isPresent()) {
gv2.add(record.getGroupV2().get());
} else if (record.getAccount().isPresent()) {
account.add(record.getAccount().get());
} else if (record.getStoryDistributionList().isPresent()) {
storyDistributionLists.add(record.getStoryDistributionList().get());
} else if (record.getId().isUnknown()) {
unknown.add(record);
} else {
Log.w(TAG, "Bad record! Type is a known value (" + record.getId().getType() + "), but doesn't have a matching inner record. Dropping it.");
}
}
}
}
private static final class MissingGv2MasterKeyError extends Error {}
private static final class MissingRecipientModelError extends Error {

View file

@ -2665,6 +2665,8 @@
<string name="preferences__internal_disable_storage_service" translatable="false">Disable syncing</string>
<string name="preferences__internal_disable_storage_service_description" translatable="false">Prevent syncing any data to/from storage service.</string>
<string name="preferences__internal_force_storage_service_sync" translatable="false">Overwrite remote data</string>
<string name="preferences__internal_sync_now" translatable="false">Sync now</string>
<string name="preferences__internal_sync_now_description" translatable="false">Enqueue a normal storage service sync.</string>
<string name="preferences__internal_force_storage_service_sync_description" translatable="false">Forces remote storage to match the local device state.</string>
<string name="preferences__internal_network" translatable="false">Network</string>
<string name="preferences__internal_allow_censorship_toggle" translatable="false">Allow censorship circumvention toggle</string>

View file

@ -690,7 +690,7 @@ public class SignalServiceAccountManager {
for (StorageId id : manifest.getStorageIds()) {
ManifestRecord.Identifier idProto = ManifestRecord.Identifier.newBuilder()
.setRaw(ByteString.copyFrom(id.getRaw()))
.setType(ManifestRecord.Identifier.Type.forNumber(id.getType())).build();
.setTypeValue(id.getType()).build();
manifestRecordBuilder.addIdentifiers(idProto);
}

View file

@ -24,7 +24,7 @@ public final class SignalStorageModels {
List<StorageId> ids = new ArrayList<>(manifestRecord.getIdentifiersCount());
for (ManifestRecord.Identifier id : manifestRecord.getIdentifiersList()) {
ids.add(StorageId.forType(id.getRaw().toByteArray(), id.getType().getNumber()));
ids.add(StorageId.forType(id.getRaw().toByteArray(), id.getTypeValue()));
}
return new SignalStorageManifest(manifestRecord.getVersion(), ids);