Convert SignalStorageManifest to kotlin.

This commit is contained in:
Greyson Parrelli 2024-11-13 11:51:31 -05:00
parent 7dd1fc09c0
commit b6bb3928e7
6 changed files with 112 additions and 140 deletions

View file

@ -36,7 +36,7 @@ public final class StorageSyncValidations {
validateManifestAndInserts(result.manifest, result.inserts, self);
if (result.deletes.size() > 0) {
Set<String> allSetEncoded = Stream.of(result.manifest.getStorageIds()).map(StorageId::getRaw).map(Base64::encodeWithPadding).collect(Collectors.toSet());
Set<String> allSetEncoded = Stream.of(result.manifest.storageIds).map(StorageId::getRaw).map(Base64::encodeWithPadding).collect(Collectors.toSet());
for (byte[] delete : result.deletes) {
String encoded = Base64.encodeWithPadding(delete);
@ -46,12 +46,12 @@ public final class StorageSyncValidations {
}
}
if (previousManifest.getVersion() == 0) {
if (previousManifest.version == 0) {
Log.i(TAG, "Previous manifest is empty, not bothering with additional validations around the diffs between the two manifests.");
return;
}
if (result.manifest.getVersion() != previousManifest.getVersion() + 1) {
if (result.manifest.version != previousManifest.version + 1) {
throw new IncorrectManifestVersionError();
}
@ -60,8 +60,8 @@ public final class StorageSyncValidations {
return;
}
Set<ByteBuffer> previousIds = Stream.of(previousManifest.getStorageIds()).map(id -> ByteBuffer.wrap(id.getRaw())).collect(Collectors.toSet());
Set<ByteBuffer> newIds = Stream.of(result.manifest.getStorageIds()).map(id -> ByteBuffer.wrap(id.getRaw())).collect(Collectors.toSet());
Set<ByteBuffer> previousIds = Stream.of(previousManifest.storageIds).map(id -> ByteBuffer.wrap(id.getRaw())).collect(Collectors.toSet());
Set<ByteBuffer> newIds = Stream.of(result.manifest.storageIds).map(id -> ByteBuffer.wrap(id.getRaw())).collect(Collectors.toSet());
Set<ByteBuffer> manifestInserts = SetUtil.difference(newIds, previousIds);
Set<ByteBuffer> manifestDeletes = SetUtil.difference(previousIds, newIds);
@ -105,7 +105,7 @@ public final class StorageSyncValidations {
private static void validateManifestAndInserts(@NonNull SignalStorageManifest manifest, @NonNull List<SignalStorageRecord> inserts, @NonNull Recipient self) {
int accountCount = 0;
for (StorageId id : manifest.getStorageIds()) {
for (StorageId id : manifest.storageIds) {
accountCount += id.getType() == ManifestRecord.Identifier.Type.ACCOUNT.getValue() ? 1 : 0;
}
@ -117,11 +117,11 @@ public final class StorageSyncValidations {
throw new MissingAccountError();
}
Set<StorageId> allSet = new HashSet<>(manifest.getStorageIds());
Set<StorageId> allSet = new HashSet<>(manifest.storageIds);
Set<StorageId> insertSet = new HashSet<>(Stream.of(inserts).map(SignalStorageRecord::getId).toList());
Set<ByteBuffer> rawIdSet = Stream.of(allSet).map(id -> ByteBuffer.wrap(id.getRaw())).collect(Collectors.toSet());
if (allSet.size() != manifest.getStorageIds().size()) {
if (allSet.size() != manifest.storageIds.size()) {
throw new DuplicateStorageIdError();
}

View file

@ -428,13 +428,13 @@ public class SignalServiceAccountManager {
throws IOException, InvalidKeyException
{
ManifestRecord.Builder manifestRecordBuilder = new ManifestRecord.Builder()
.sourceDevice(manifest.getSourceDeviceId())
.version(manifest.getVersion());
.sourceDevice(manifest.sourceDeviceId)
.version(manifest.version);
manifestRecordBuilder.identifiers(
manifest.getStorageIds().stream()
.map(id -> {
manifest.storageIds.stream()
.map(id -> {
ManifestRecord.Identifier.Builder builder = new ManifestRecord.Identifier.Builder()
.raw(ByteString.of(id.getRaw()));
if (!id.isUnknown()) {
@ -445,14 +445,14 @@ public class SignalServiceAccountManager {
}
return builder.build();
})
.collect(Collectors.toList())
.collect(Collectors.toList())
);
String authToken = this.pushServiceSocket.getStorageAuth();
StorageManifestKey manifestKey = storageKey.deriveManifestKey(manifest.getVersion());
StorageManifestKey manifestKey = storageKey.deriveManifestKey(manifest.version);
byte[] encryptedRecord = SignalStorageCipher.encrypt(manifestKey, manifestRecordBuilder.build().encode());
StorageManifest storageManifest = new StorageManifest.Builder()
.version(manifest.getVersion())
.version(manifest.version)
.value_(ByteString.of(encryptedRecord))
.build();

View file

@ -0,0 +1,44 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.signalservice.api.storage
import com.squareup.wire.FieldEncoding
import okio.ByteString.Companion.toByteString
import org.signal.core.util.getUnknownEnumValue
import org.whispersystems.signalservice.internal.storage.protos.ManifestRecord
/**
* Wire makes it harder to write specific values to proto enums, since they use actual enums under the hood.
* This method handles creating an identifier from a possibly-unknown enum type, writing an unknown field if
* necessary to preserve the specific value.
*/
fun ManifestRecord.Identifier.Companion.fromPossiblyUnknownType(typeInt: Int, rawId: ByteArray): ManifestRecord.Identifier {
val builder = ManifestRecord.Identifier.Builder()
builder.raw = rawId.toByteString()
val type = ManifestRecord.Identifier.Type.fromValue(typeInt)
if (type != null) {
builder.type = type
} else {
builder.type = ManifestRecord.Identifier.Type.UNKNOWN
builder.addUnknownField(StorageRecordProtoUtil.STORAGE_ID_TYPE_TAG, FieldEncoding.VARINT, typeInt)
}
return builder.build()
}
/**
* Wire makes it harder to read the underlying int value of an unknown enum.
* This value represents the _true_ int value of the enum, even if it is not one of the known values.
*/
val ManifestRecord.Identifier.typeValue: Int
get() {
return if (this.type != ManifestRecord.Identifier.Type.UNKNOWN) {
this.type.value
} else {
this.getUnknownEnumValue(StorageRecordProtoUtil.STORAGE_ID_TYPE_TAG)
}
}

View file

@ -1,114 +0,0 @@
package org.whispersystems.signalservice.api.storage;
import org.signal.core.util.ProtoUtil;
import org.whispersystems.signalservice.internal.storage.protos.ManifestRecord;
import org.whispersystems.signalservice.internal.storage.protos.StorageManifest;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import okio.ByteString;
public class SignalStorageManifest {
public static final SignalStorageManifest EMPTY = new SignalStorageManifest(0, 1, Collections.emptyList());
private final long version;
private final int sourceDeviceId;
private final List<StorageId> storageIds;
private final Map<Integer, List<StorageId>> storageIdsByType;
public SignalStorageManifest(long version, int sourceDeviceId, List<StorageId> storageIds) {
this.version = version;
this.sourceDeviceId = sourceDeviceId;
this.storageIds = storageIds;
this.storageIdsByType = new HashMap<>();
for (StorageId id : storageIds) {
List<StorageId> list = storageIdsByType.get(id.getType());
if (list == null) {
list = new ArrayList<>();
}
list.add(id);
storageIdsByType.put(id.getType(), list);
}
}
public static SignalStorageManifest deserialize(byte[] serialized) {
try {
StorageManifest manifest = StorageManifest.ADAPTER.decode(serialized);
ManifestRecord manifestRecord = ManifestRecord.ADAPTER.decode(manifest.value_);
List<StorageId> ids = new ArrayList<>(manifestRecord.identifiers.size());
for (ManifestRecord.Identifier id : manifestRecord.identifiers) {
ids.add(StorageId.forType(id.raw.toByteArray(), id.type.getValue()));
}
return new SignalStorageManifest(manifest.version, manifestRecord.sourceDevice, ids);
} catch (IOException e) {
throw new AssertionError(e);
}
}
public long getVersion() {
return version;
}
public int getSourceDeviceId() {
return sourceDeviceId;
}
public String getVersionString() {
return version + "." + sourceDeviceId;
}
public List<StorageId> getStorageIds() {
return storageIds;
}
public Optional<StorageId> getAccountStorageId() {
List<StorageId> list = storageIdsByType.get(ManifestRecord.Identifier.Type.ACCOUNT.getValue());
if (list != null && list.size() > 0) {
return Optional.of(list.get(0));
} else {
return Optional.empty();
}
}
public Map<Integer, List<StorageId>> getStorageIdsByType() {
return storageIdsByType;
}
public byte[] serialize() {
List<ManifestRecord.Identifier> ids = new ArrayList<>(storageIds.size());
for (StorageId id : storageIds) {
ManifestRecord.Identifier.Type type = ManifestRecord.Identifier.Type.Companion.fromValue(id.getType());
if (type != null) {
ids.add(new ManifestRecord.Identifier.Builder()
.type(type)
.raw(ByteString.of(id.getRaw()))
.build());
} else {
ByteString unknownEnum = ProtoUtil.writeUnknownEnumValue(StorageRecordProtoUtil.STORAGE_ID_TYPE_TAG, id.getType());
ids.add(new ManifestRecord.Identifier(ByteString.of(id.getRaw()), ManifestRecord.Identifier.Type.UNKNOWN, unknownEnum));
}
}
ManifestRecord manifestRecord = new ManifestRecord.Builder()
.identifiers(ids)
.sourceDevice(sourceDeviceId)
.build();
return new StorageManifest.Builder()
.version(version)
.value_(manifestRecord.encodeByteString())
.build()
.encode();
}
}

View file

@ -0,0 +1,51 @@
package org.whispersystems.signalservice.api.storage
import org.signal.core.util.toOptional
import org.whispersystems.signalservice.internal.storage.protos.ManifestRecord
import org.whispersystems.signalservice.internal.storage.protos.StorageManifest
import java.util.Optional
class SignalStorageManifest(
@JvmField val version: Long,
@JvmField val sourceDeviceId: Int,
@JvmField val storageIds: List<StorageId>
) {
companion object {
val EMPTY: SignalStorageManifest = SignalStorageManifest(0, 1, emptyList())
fun deserialize(serialized: ByteArray): SignalStorageManifest {
val manifest = StorageManifest.ADAPTER.decode(serialized)
val manifestRecord = ManifestRecord.ADAPTER.decode(manifest.value_)
val ids: List<StorageId> = manifestRecord.identifiers.map { id ->
StorageId.forType(id.raw.toByteArray(), id.typeValue)
}
return SignalStorageManifest(manifest.version, manifestRecord.sourceDevice, ids)
}
}
val storageIdsByType: Map<Int, List<StorageId>> = storageIds.groupBy { it.type }
val versionString: String
get() = "$version.$sourceDeviceId"
val accountStorageId: Optional<StorageId>
get() = storageIdsByType[ManifestRecord.Identifier.Type.ACCOUNT.value]?.takeIf { it.isNotEmpty() }?.get(0).toOptional()
fun serialize(): ByteArray {
val ids: List<ManifestRecord.Identifier> = storageIds.map { id ->
ManifestRecord.Identifier.fromPossiblyUnknownType(id.type, id.raw)
}
val manifestRecord = ManifestRecord(
identifiers = ids,
sourceDevice = sourceDeviceId
)
return StorageManifest(
version = version,
value_ = manifestRecord.encodeByteString()
).encode()
}
}

View file

@ -1,7 +1,6 @@
package org.whispersystems.signalservice.api.storage
import okio.ByteString.Companion.toByteString
import org.signal.core.util.getUnknownEnumValue
import org.signal.libsignal.protocol.InvalidKeyException
import org.signal.libsignal.protocol.logging.Log
import org.signal.libsignal.zkgroup.groups.GroupMasterKey
@ -19,16 +18,8 @@ object SignalStorageModels {
fun remoteToLocalStorageManifest(manifest: StorageManifest, storageKey: StorageKey): SignalStorageManifest {
val rawRecord = SignalStorageCipher.decrypt(storageKey.deriveManifestKey(manifest.version), manifest.value_.toByteArray())
val manifestRecord = ManifestRecord.ADAPTER.decode(rawRecord)
val ids: MutableList<StorageId> = ArrayList(manifestRecord.identifiers.size)
for (id in manifestRecord.identifiers) {
val typeValue = if ((id.type != ManifestRecord.Identifier.Type.UNKNOWN)) {
id.type.value
} else {
id.getUnknownEnumValue(StorageRecordProtoUtil.STORAGE_ID_TYPE_TAG)
}
ids.add(StorageId.forType(id.raw.toByteArray(), typeValue))
val ids: List<StorageId> = manifestRecord.identifiers.map { id ->
StorageId.forType(id.raw.toByteArray(), id.typeValue)
}
return SignalStorageManifest(manifestRecord.version, manifestRecord.sourceDevice, ids)