diff --git a/app/src/main/java/org/thoughtcrime/securesms/storage/GroupV2ConflictMerger.java b/app/src/main/java/org/thoughtcrime/securesms/storage/GroupV2ConflictMerger.java new file mode 100644 index 0000000000..85c86c547e --- /dev/null +++ b/app/src/main/java/org/thoughtcrime/securesms/storage/GroupV2ConflictMerger.java @@ -0,0 +1,48 @@ +package org.thoughtcrime.securesms.storage; + +import androidx.annotation.NonNull; + +import com.annimon.stream.Collectors; +import com.annimon.stream.Stream; + +import org.signal.zkgroup.groups.GroupMasterKey; +import org.whispersystems.libsignal.util.guava.Optional; +import org.whispersystems.signalservice.api.storage.SignalGroupV2Record; + +import java.util.Collection; +import java.util.Map; + +class GroupV2ConflictMerger implements StorageSyncHelper.ConflictMerger { + + private final Map localByGroupId; + + GroupV2ConflictMerger(@NonNull Collection localOnly) { + localByGroupId = Stream.of(localOnly).collect(Collectors.toMap(SignalGroupV2Record::getMasterKey, g -> g)); + } + + @Override + public @NonNull Optional getMatching(@NonNull SignalGroupV2Record record) { + return Optional.fromNullable(localByGroupId.get(record.getMasterKey())); + } + + @Override + public @NonNull SignalGroupV2Record merge(@NonNull SignalGroupV2Record remote, @NonNull SignalGroupV2Record local, @NonNull StorageSyncHelper.KeyGenerator keyGenerator) { + boolean blocked = remote.isBlocked(); + boolean profileSharing = remote.isProfileSharingEnabled() || local.isProfileSharingEnabled(); + boolean archived = remote.isArchived(); + + boolean matchesRemote = blocked == remote.isBlocked() && profileSharing == remote.isProfileSharingEnabled() && archived == remote.isArchived(); + boolean matchesLocal = blocked == local.isBlocked() && profileSharing == local.isProfileSharingEnabled() && archived == local.isArchived(); + + if (matchesRemote) { + return remote; + } else if (matchesLocal) { + return local; + } else { + return new SignalGroupV2Record.Builder(keyGenerator.generate(), remote.getMasterKey()) + .setBlocked(blocked) + .setProfileSharingEnabled(blocked) + .build(); + } + } +} diff --git a/app/src/main/java/org/thoughtcrime/securesms/storage/StorageSyncHelper.java b/app/src/main/java/org/thoughtcrime/securesms/storage/StorageSyncHelper.java index 8c8b0a4a4a..dc7d1e5b93 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/storage/StorageSyncHelper.java +++ b/app/src/main/java/org/thoughtcrime/securesms/storage/StorageSyncHelper.java @@ -139,8 +139,9 @@ public final class StorageSyncHelper { List remoteOnlyGroupV1 = Stream.of(remoteOnlyRecords).filter(r -> r.getGroupV1().isPresent()).map(r -> r.getGroupV1().get()).toList(); List localOnlyGroupV1 = Stream.of(localOnlyRecords).filter(r -> r.getGroupV1().isPresent()).map(r -> r.getGroupV1().get()).toList(); - List remoteOnlyUnknowns = Stream.of(remoteOnlyRecords).filter(SignalStorageRecord::isUnknown).toList(); - List localOnlyUnknowns = Stream.of(localOnlyRecords).filter(SignalStorageRecord::isUnknown).toList(); + // TODO [storage] Handle groupV2 when appropriate + List remoteOnlyUnknowns = Stream.of(remoteOnlyRecords).filter(r -> r.isUnknown() || r.getGroupV2().isPresent()).toList(); + List localOnlyUnknowns = Stream.of(localOnlyRecords).filter(r -> r.isUnknown() || r.getGroupV2().isPresent()).toList(); RecordMergeResult> contactMergeResult = resolveRecordConflict(remoteOnlyContacts, localOnlyContacts, new ContactConflictMerger(localOnlyContacts)); RecordMergeResult> groupV1MergeResult = resolveRecordConflict(remoteOnlyGroupV1, localOnlyGroupV1, new GroupV1ConflictMerger(localOnlyGroupV1)); diff --git a/app/src/main/java/org/thoughtcrime/securesms/storage/StorageSyncModels.java b/app/src/main/java/org/thoughtcrime/securesms/storage/StorageSyncModels.java index 406d6105b7..711ddc43ad 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/storage/StorageSyncModels.java +++ b/app/src/main/java/org/thoughtcrime/securesms/storage/StorageSyncModels.java @@ -3,7 +3,6 @@ package org.thoughtcrime.securesms.storage; import androidx.annotation.NonNull; import org.thoughtcrime.securesms.database.IdentityDatabase; -import org.thoughtcrime.securesms.database.RecipientDatabase; import org.thoughtcrime.securesms.database.RecipientDatabase.RecipientSettings; import org.thoughtcrime.securesms.util.GroupUtil; import org.whispersystems.signalservice.api.push.SignalServiceAddress; diff --git a/app/src/test/java/org/thoughtcrime/securesms/storage/GroupV2ConflictMergerTest.java b/app/src/test/java/org/thoughtcrime/securesms/storage/GroupV2ConflictMergerTest.java new file mode 100644 index 0000000000..19ea87b226 --- /dev/null +++ b/app/src/test/java/org/thoughtcrime/securesms/storage/GroupV2ConflictMergerTest.java @@ -0,0 +1,90 @@ +package org.thoughtcrime.securesms.storage; + +import org.junit.Test; +import org.signal.zkgroup.InvalidInputException; +import org.signal.zkgroup.groups.GroupMasterKey; +import org.thoughtcrime.securesms.storage.StorageSyncHelper.KeyGenerator; +import org.whispersystems.signalservice.api.storage.SignalGroupV2Record; + +import java.util.Collections; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.powermock.api.mockito.PowerMockito.mock; +import static org.powermock.api.mockito.PowerMockito.when; +import static org.thoughtcrime.securesms.testutil.TestHelpers.byteArray; + +public class GroupV2ConflictMergerTest { + + private static byte[] GENERATED_KEY = byteArray(8675309); + private static KeyGenerator KEY_GENERATOR = mock(KeyGenerator.class); + static { + when(KEY_GENERATOR.generate()).thenReturn(GENERATED_KEY); + } + + @Test + public void merge_alwaysPreferRemote_exceptProfileSharingIsEitherOr() { + SignalGroupV2Record remote = new SignalGroupV2Record.Builder(byteArray(1), groupKey(100)) + .setBlocked(false) + .setProfileSharingEnabled(false) + .setArchived(false) + .build(); + SignalGroupV2Record local = new SignalGroupV2Record.Builder(byteArray(2), groupKey(100)) + .setBlocked(true) + .setProfileSharingEnabled(true) + .setArchived(true) + .build(); + + SignalGroupV2Record merged = new GroupV2ConflictMerger(Collections.singletonList(local)).merge(remote, local, KEY_GENERATOR); + + assertArrayEquals(GENERATED_KEY, merged.getId().getRaw()); + assertEquals(groupKey(100), merged.getMasterKey()); + assertFalse(merged.isBlocked()); + assertFalse(merged.isArchived()); + } + + @Test + public void merge_returnRemoteIfEndResultMatchesRemote() { + SignalGroupV2Record remote = new SignalGroupV2Record.Builder(byteArray(1), groupKey(100)) + .setBlocked(false) + .setProfileSharingEnabled(true) + .setArchived(true) + .build(); + SignalGroupV2Record local = new SignalGroupV2Record.Builder(byteArray(2), groupKey(100)) + .setBlocked(true) + .setProfileSharingEnabled(false) + .setArchived(false) + .build(); + + SignalGroupV2Record merged = new GroupV2ConflictMerger(Collections.singletonList(local)).merge(remote, local, mock(KeyGenerator.class)); + + assertEquals(remote, merged); + } + + @Test + public void merge_returnLocalIfEndResultMatchesLocal() { + SignalGroupV2Record remote = new SignalGroupV2Record.Builder(byteArray(1), groupKey(100)) + .setBlocked(false) + .setProfileSharingEnabled(false) + .setArchived(false) + .build(); + SignalGroupV2Record local = new SignalGroupV2Record.Builder(byteArray(2), groupKey(100)) + .setBlocked(false) + .setProfileSharingEnabled(true) + .setArchived(false) + .build(); + + SignalGroupV2Record merged = new GroupV2ConflictMerger(Collections.singletonList(local)).merge(remote, local, mock(KeyGenerator.class)); + + assertEquals(local, merged); + } + + private static GroupMasterKey groupKey(int value) { + try { + return new GroupMasterKey(byteArray(value, 32)); + } catch (InvalidInputException e) { + throw new AssertionError(e); + } + } +} diff --git a/app/src/test/java/org/thoughtcrime/securesms/storage/StorageSyncHelperTest.java b/app/src/test/java/org/thoughtcrime/securesms/storage/StorageSyncHelperTest.java index dd484d8110..7ce1ed31c3 100644 --- a/app/src/test/java/org/thoughtcrime/securesms/storage/StorageSyncHelperTest.java +++ b/app/src/test/java/org/thoughtcrime/securesms/storage/StorageSyncHelperTest.java @@ -6,12 +6,14 @@ import com.annimon.stream.Stream; import org.junit.Before; import org.junit.Test; -import org.thoughtcrime.securesms.storage.StorageSyncHelper; +import org.signal.zkgroup.InvalidInputException; +import org.signal.zkgroup.groups.GroupMasterKey; import org.thoughtcrime.securesms.storage.StorageSyncHelper.KeyDifferenceResult; import org.thoughtcrime.securesms.storage.StorageSyncHelper.MergeResult; import org.whispersystems.signalservice.api.push.SignalServiceAddress; import org.whispersystems.signalservice.api.storage.SignalContactRecord; import org.whispersystems.signalservice.api.storage.SignalGroupV1Record; +import org.whispersystems.signalservice.api.storage.SignalGroupV2Record; import org.whispersystems.signalservice.api.storage.SignalRecord; import org.whispersystems.signalservice.api.storage.SignalStorageRecord; import org.whispersystems.signalservice.api.storage.StorageId; @@ -151,15 +153,17 @@ public final class StorageSyncHelperTest { public void resolveConflict_unknowns() { SignalStorageRecord remote1 = unknown(3); SignalStorageRecord remote2 = unknown(4); + SignalStorageRecord remote3 = SignalStorageRecord.forGroupV2(groupV2(100, 200, true, true)); SignalStorageRecord local1 = unknown(1); SignalStorageRecord local2 = unknown(2); + SignalStorageRecord local3 = SignalStorageRecord.forGroupV2(groupV2(101, 201, true, true)); - MergeResult result = StorageSyncHelper.resolveConflict(setOf(remote1, remote2), setOf(local1, local2)); + MergeResult result = StorageSyncHelper.resolveConflict(setOf(remote1, remote2, remote3), setOf(local1, local2, local3)); assertTrue(result.getLocalContactInserts().isEmpty()); assertTrue(result.getLocalContactUpdates().isEmpty()); - assertEquals(setOf(remote1, remote2), result.getLocalUnknownInserts()); - assertEquals(setOf(local1, local2), result.getLocalUnknownDeletes()); + assertEquals(setOf(remote1, remote2, remote3), result.getLocalUnknownInserts()); + assertEquals(setOf(local1, local2, local3), result.getLocalUnknownDeletes()); } @Test @@ -269,6 +273,8 @@ public final class StorageSyncHelperTest { storageRecords.add(SignalStorageRecord.forContact(record.getId(), (SignalContactRecord) record)); } else if (record instanceof SignalGroupV1Record) { storageRecords.add(SignalStorageRecord.forGroupV1(record.getId(), (SignalGroupV1Record) record)); + } else if (record instanceof SignalGroupV2Record) { + storageRecords.add(SignalStorageRecord.forGroupV2(record.getId(), (SignalGroupV2Record) record)); } else { storageRecords.add(SignalStorageRecord.forUnknown(record.getId())); } @@ -312,6 +318,18 @@ public final class StorageSyncHelperTest { return new SignalGroupV1Record.Builder(byteArray(key), byteArray(groupId)).setBlocked(blocked).setProfileSharingEnabled(profileSharing).build(); } + private static SignalGroupV2Record groupV2(int key, + int groupId, + boolean blocked, + boolean profileSharing) + { + try { + return new SignalGroupV2Record.Builder(byteArray(key), new GroupMasterKey(byteArray(groupId, 32))).setBlocked(blocked).setProfileSharingEnabled(profileSharing).build(); + } catch (InvalidInputException e) { + throw new AssertionError(e); + } + } + private static StorageSyncHelper.RecordUpdate contactUpdate(SignalContactRecord oldContact, SignalContactRecord newContact) { return new StorageSyncHelper.RecordUpdate<>(oldContact, newContact); } diff --git a/app/src/test/java/org/thoughtcrime/securesms/testutil/TestHelpers.java b/app/src/test/java/org/thoughtcrime/securesms/testutil/TestHelpers.java index c5e599823f..5f318bbdf6 100644 --- a/app/src/test/java/org/thoughtcrime/securesms/testutil/TestHelpers.java +++ b/app/src/test/java/org/thoughtcrime/securesms/testutil/TestHelpers.java @@ -4,6 +4,7 @@ import com.annimon.stream.Stream; import com.google.common.collect.Sets; import org.thoughtcrime.securesms.util.Conversions; +import org.whispersystems.libsignal.util.ByteUtil; import java.nio.ByteBuffer; import java.util.ArrayList; @@ -23,6 +24,12 @@ public final class TestHelpers { return Conversions.intToByteArray(a); } + public static byte[] byteArray(int a, int totalLength) { + byte[] out = new byte[totalLength - 4]; + byte[] val = Conversions.intToByteArray(a); + return ByteUtil.combine(out, val); + } + public static List byteListOf(int... vals) { List list = new ArrayList<>(vals.length); diff --git a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/storage/SignalGroupV2Record.java b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/storage/SignalGroupV2Record.java new file mode 100644 index 0000000000..2436e151d0 --- /dev/null +++ b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/storage/SignalGroupV2Record.java @@ -0,0 +1,97 @@ +package org.whispersystems.signalservice.api.storage; + +import com.google.protobuf.ByteString; + +import org.signal.zkgroup.InvalidInputException; +import org.signal.zkgroup.groups.GroupMasterKey; +import org.whispersystems.signalservice.internal.storage.protos.GroupV1Record; +import org.whispersystems.signalservice.internal.storage.protos.GroupV2Record; + +import java.util.Objects; + +public final class SignalGroupV2Record implements SignalRecord { + + private final StorageId id; + private final GroupV2Record proto; + private final GroupMasterKey masterKey; + + private SignalGroupV2Record(StorageId id, GroupV2Record proto) { + this.id = id; + this.proto = proto; + try { + this.masterKey = new GroupMasterKey(proto.getMasterKey().toByteArray()); + } catch (InvalidInputException e) { + throw new AssertionError(e); + } + } + + @Override + public StorageId getId() { + return id; + } + + public GroupMasterKey getMasterKey() { + return masterKey; + } + + public boolean isBlocked() { + return proto.getBlocked(); + } + + public boolean isProfileSharingEnabled() { + return proto.getWhitelisted(); + } + + public boolean isArchived() { + return proto.getArchived(); + } + + GroupV2Record toProto() { + return proto; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + SignalGroupV2Record that = (SignalGroupV2Record) o; + return id.equals(that.id) && + proto.equals(that.proto); + } + + @Override + public int hashCode() { + return Objects.hash(id, proto); + } + + public static final class Builder { + private final StorageId id; + private final GroupV2Record.Builder builder; + + public Builder(byte[] rawId, GroupMasterKey masterKey) { + this.id = StorageId.forGroupV1(rawId); + this.builder = GroupV2Record.newBuilder(); + + builder.setMasterKey(ByteString.copyFrom(masterKey.serialize())); + } + + public Builder setBlocked(boolean blocked) { + builder.setBlocked(blocked); + return this; + } + + public Builder setProfileSharingEnabled(boolean profileSharingEnabled) { + builder.setWhitelisted(profileSharingEnabled); + return this; + } + + public Builder setArchived(boolean archived) { + builder.setArchived(archived); + return this; + } + + public SignalGroupV2Record build() { + return new SignalGroupV2Record(id, builder.build()); + } + } +} diff --git a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/storage/SignalStorageModels.java b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/storage/SignalStorageModels.java index 116059dfce..971c062451 100644 --- a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/storage/SignalStorageModels.java +++ b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/storage/SignalStorageModels.java @@ -2,11 +2,14 @@ package org.whispersystems.signalservice.api.storage; import com.google.protobuf.ByteString; +import org.signal.zkgroup.InvalidInputException; +import org.signal.zkgroup.groups.GroupMasterKey; import org.whispersystems.libsignal.InvalidKeyException; import org.whispersystems.signalservice.api.push.SignalServiceAddress; import org.whispersystems.signalservice.api.util.UuidUtil; import org.whispersystems.signalservice.internal.storage.protos.ContactRecord; import org.whispersystems.signalservice.internal.storage.protos.GroupV1Record; +import org.whispersystems.signalservice.internal.storage.protos.GroupV2Record; import org.whispersystems.signalservice.internal.storage.protos.ManifestRecord; import org.whispersystems.signalservice.internal.storage.protos.StorageItem; import org.whispersystems.signalservice.internal.storage.protos.StorageManifest; @@ -39,6 +42,8 @@ public final class SignalStorageModels { return SignalStorageRecord.forContact(StorageId.forContact(key), remoteToLocalContactRecord(key, record.getContact())); } else if (record.hasGroupV1() && type == ManifestRecord.Identifier.Type.GROUPV1_VALUE) { return SignalStorageRecord.forGroupV1(StorageId.forGroupV1(key), remoteToLocalGroupV1Record(key, record.getGroupV1())); + } else if (record.hasGroupV2() && type == ManifestRecord.Identifier.Type.GROUPV2_VALUE && record.getGroupV2().getMasterKey().size() == GroupMasterKey.SIZE) { + return SignalStorageRecord.forGroupV2(StorageId.forGroupV2(key), remoteToLocalGroupV2Record(key, record.getGroupV2())); } else { return SignalStorageRecord.forUnknown(StorageId.forType(key, type)); } @@ -51,6 +56,8 @@ public final class SignalStorageModels { builder.setContact(record.getContact().get().toProto()); } else if (record.getGroupV1().isPresent()) { builder.setGroupV1(record.getGroupV1().get().toProto()); + } else if (record.getGroupV2().isPresent()) { + builder.setGroupV2(record.getGroupV2().get().toProto()); } else { throw new InvalidStorageWriteError(); } @@ -77,6 +84,7 @@ public final class SignalStorageModels { .setUsername(contact.getUsername()) .setIdentityKey(contact.getIdentityKey().toByteArray()) .setIdentityState(contact.getIdentityState()) + .setArchived(contact.getArchived()) .build(); } @@ -84,9 +92,22 @@ public final class SignalStorageModels { return new SignalGroupV1Record.Builder(key, groupV1.getId().toByteArray()) .setBlocked(groupV1.getBlocked()) .setProfileSharingEnabled(groupV1.getWhitelisted()) + .setArchived(groupV1.getArchived()) .build(); } + private static SignalGroupV2Record remoteToLocalGroupV2Record(byte[] key, GroupV2Record groupV2) { + try { + return new SignalGroupV2Record.Builder(key, new GroupMasterKey(groupV2.getMasterKey().toByteArray())) + .setBlocked(groupV2.getBlocked()) + .setProfileSharingEnabled(groupV2.getWhitelisted()) + .setArchived(groupV2.getArchived()) + .build(); + } catch (InvalidInputException e) { + throw new AssertionError(); + } + } + private static class InvalidStorageWriteError extends Error { } } diff --git a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/storage/SignalStorageRecord.java b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/storage/SignalStorageRecord.java index 0d4b8ec8c7..564f4357ad 100644 --- a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/storage/SignalStorageRecord.java +++ b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/storage/SignalStorageRecord.java @@ -9,13 +9,14 @@ public class SignalStorageRecord implements SignalRecord { private final StorageId id; private final Optional contact; private final Optional groupV1; + private final Optional groupV2; public static SignalStorageRecord forContact(SignalContactRecord contact) { return forContact(contact.getId(), contact); } public static SignalStorageRecord forContact(StorageId key, SignalContactRecord contact) { - return new SignalStorageRecord(key, Optional.of(contact), Optional.absent()); + return new SignalStorageRecord(key, Optional.of(contact), Optional.absent(), Optional.absent()); } public static SignalStorageRecord forGroupV1(SignalGroupV1Record groupV1) { @@ -23,20 +24,30 @@ public class SignalStorageRecord implements SignalRecord { } public static SignalStorageRecord forGroupV1(StorageId key, SignalGroupV1Record groupV1) { - return new SignalStorageRecord(key, Optional.absent(), Optional.of(groupV1)); + return new SignalStorageRecord(key, Optional.absent(), Optional.of(groupV1), Optional.absent()); + } + + public static SignalStorageRecord forGroupV2(SignalGroupV2Record groupV2) { + return forGroupV2(groupV2.getId(), groupV2); + } + + public static SignalStorageRecord forGroupV2(StorageId key, SignalGroupV2Record groupV2) { + return new SignalStorageRecord(key, Optional.absent(), Optional.absent(), Optional.of(groupV2)); } public static SignalStorageRecord forUnknown(StorageId key) { - return new SignalStorageRecord(key,Optional.absent(), Optional.absent()); + return new SignalStorageRecord(key,Optional.absent(), Optional.absent(), Optional.absent()); } private SignalStorageRecord(StorageId id, Optional contact, - Optional groupV1) + Optional groupV1, + Optional groupV2) { this.id = id; this.contact = contact; this.groupV1 = groupV1; + this.groupV2 = groupV2; } @Override @@ -56,6 +67,10 @@ public class SignalStorageRecord implements SignalRecord { return groupV1; } + public Optional getGroupV2() { + return groupV2; + } + public boolean isUnknown() { return !contact.isPresent() && !groupV1.isPresent(); } diff --git a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/storage/StorageId.java b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/storage/StorageId.java index 3a42655040..c300612a16 100644 --- a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/storage/StorageId.java +++ b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/storage/StorageId.java @@ -17,6 +17,10 @@ public class StorageId { return new StorageId(ManifestRecord.Identifier.Type.GROUPV1_VALUE, raw); } + public static StorageId forGroupV2(byte[] raw) { + return new StorageId(ManifestRecord.Identifier.Type.GROUPV2_VALUE, raw); + } + public static StorageId forType(byte[] raw, int type) { return new StorageId(type, raw); }