diff --git a/app/src/androidTest/java/org/thoughtcrime/securesms/testing/BobClient.kt b/app/src/androidTest/java/org/thoughtcrime/securesms/testing/BobClient.kt index 49d7b78ba2..b79713ff46 100644 --- a/app/src/androidTest/java/org/thoughtcrime/securesms/testing/BobClient.kt +++ b/app/src/androidTest/java/org/thoughtcrime/securesms/testing/BobClient.kt @@ -165,7 +165,7 @@ class BobClient(val serviceId: ServiceId, val e164: String, val identityKeyPair: override fun storeSenderKey(sender: SignalProtocolAddress?, distributionId: UUID?, record: SenderKeyRecord?) = throw UnsupportedOperationException() override fun loadSenderKey(sender: SignalProtocolAddress?, distributionId: UUID?): SenderKeyRecord = throw UnsupportedOperationException() override fun archiveSession(address: SignalProtocolAddress?) = throw UnsupportedOperationException() - override fun getAllAddressesWithActiveSessions(addressNames: MutableList?): MutableSet = throw UnsupportedOperationException() + override fun getAllAddressesWithActiveSessions(addressNames: MutableList?): MutableMap = throw UnsupportedOperationException() override fun getSenderKeySharedWith(distributionId: DistributionId?): MutableSet = throw UnsupportedOperationException() override fun markSenderKeySharedWith(distributionId: DistributionId?, addresses: MutableCollection?) = throw UnsupportedOperationException() override fun clearSenderKeySharedWith(addresses: MutableCollection?) = throw UnsupportedOperationException() diff --git a/app/src/main/java/org/thoughtcrime/securesms/crypto/storage/SignalServiceAccountDataStoreImpl.java b/app/src/main/java/org/thoughtcrime/securesms/crypto/storage/SignalServiceAccountDataStoreImpl.java index d00f4c4149..dcb40d4c75 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/crypto/storage/SignalServiceAccountDataStoreImpl.java +++ b/app/src/main/java/org/thoughtcrime/securesms/crypto/storage/SignalServiceAccountDataStoreImpl.java @@ -21,6 +21,7 @@ import org.whispersystems.signalservice.api.push.DistributionId; import java.util.Collection; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.UUID; @@ -126,7 +127,7 @@ public class SignalServiceAccountDataStoreImpl implements SignalServiceAccountDa } @Override - public Set getAllAddressesWithActiveSessions(List addressNames) { + public Map getAllAddressesWithActiveSessions(List addressNames) { return sessionStore.getAllAddressesWithActiveSessions(addressNames); } diff --git a/app/src/main/java/org/thoughtcrime/securesms/crypto/storage/TextSecureSessionStore.java b/app/src/main/java/org/thoughtcrime/securesms/crypto/storage/TextSecureSessionStore.java index da84874432..6aa2b6d7c0 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/crypto/storage/TextSecureSessionStore.java +++ b/app/src/main/java/org/thoughtcrime/securesms/crypto/storage/TextSecureSessionStore.java @@ -17,6 +17,7 @@ import org.whispersystems.signalservice.api.SignalSessionLock; import org.whispersystems.signalservice.api.push.ServiceId; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; @@ -104,14 +105,13 @@ public class TextSecureSessionStore implements SignalServiceSessionStore { } @Override - public Set getAllAddressesWithActiveSessions(List addressNames) { + public Map getAllAddressesWithActiveSessions(List addressNames) { try (SignalSessionLock.Lock unused = ReentrantSessionLock.INSTANCE.acquire()) { return SignalDatabase.sessions() .getAllFor(accountId, addressNames) .stream() .filter(row -> isActive(row.getRecord())) - .map(row -> new SignalProtocolAddress(row.getAddress(), row.getDeviceId())) - .collect(Collectors.toSet()); + .collect(Collectors.toMap(row -> new SignalProtocolAddress(row.getAddress(), row.getDeviceId()), SessionTable.SessionRow::getRecord)); } } diff --git a/app/src/main/java/org/thoughtcrime/securesms/messages/protocol/BufferedSessionStore.kt b/app/src/main/java/org/thoughtcrime/securesms/messages/protocol/BufferedSessionStore.kt index 9a13ceaa96..377c63f973 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/messages/protocol/BufferedSessionStore.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/messages/protocol/BufferedSessionStore.kt @@ -100,7 +100,7 @@ class BufferedSessionStore(private val selfServiceId: ServiceId) : SignalService error("Should not happen during the intended usage pattern of this class") } - override fun getAllAddressesWithActiveSessions(addressNames: MutableList): Set { + override fun getAllAddressesWithActiveSessions(addressNames: MutableList): Map { error("Should not happen during the intended usage pattern of this class") } diff --git a/app/src/main/java/org/thoughtcrime/securesms/messages/protocol/BufferedSignalServiceAccountDataStore.kt b/app/src/main/java/org/thoughtcrime/securesms/messages/protocol/BufferedSignalServiceAccountDataStore.kt index b00bb91527..5da7f56e9a 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/messages/protocol/BufferedSignalServiceAccountDataStore.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/messages/protocol/BufferedSignalServiceAccountDataStore.kt @@ -177,7 +177,7 @@ class BufferedSignalServiceAccountDataStore(selfServiceId: ServiceId) : SignalSe sessionStore.archiveSession(address) } - override fun getAllAddressesWithActiveSessions(addressNames: MutableList): Set { + override fun getAllAddressesWithActiveSessions(addressNames: MutableList): Map { return sessionStore.getAllAddressesWithActiveSessions(addressNames) } diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageSender.java b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageSender.java index 39b854cf9a..11197781a5 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageSender.java +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageSender.java @@ -18,6 +18,7 @@ import org.signal.libsignal.protocol.message.DecryptionErrorMessage; import org.signal.libsignal.protocol.message.PlaintextContent; import org.signal.libsignal.protocol.message.SenderKeyDistributionMessage; import org.signal.libsignal.protocol.state.PreKeyBundle; +import org.signal.libsignal.protocol.state.SessionRecord; import org.signal.libsignal.protocol.util.Pair; import org.signal.libsignal.zkgroup.profiles.ClientZkProfileOperations; import org.whispersystems.signalservice.api.crypto.AttachmentCipherStreamUtil; @@ -136,6 +137,7 @@ import java.security.SecureRandom; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.Iterator; import java.util.LinkedList; import java.util.List; @@ -2197,7 +2199,7 @@ public class SignalServiceMessageSender { byte[] ciphertext; try { - ciphertext = cipher.encryptForGroup(distributionId, targetInfo.destinations, senderCertificate, content.encode(), contentHint, groupId); + ciphertext = cipher.encryptForGroup(distributionId, targetInfo.destinations, targetInfo.sessions, senderCertificate, content.encode(), contentHint, groupId); } catch (org.signal.libsignal.protocol.UntrustedIdentityException e) { throw new UntrustedIdentityException("Untrusted during group encrypt", e.getName(), e.getUntrustedIdentity()); } @@ -2245,9 +2247,11 @@ public class SignalServiceMessageSender { } private GroupTargetInfo buildGroupTargetInfo(List recipients) { - List addressNames = recipients.stream().map(SignalServiceAddress::getIdentifier).collect(Collectors.toList()); - Set destinations = aciStore.getAllAddressesWithActiveSessions(addressNames); - Map> devicesByAddressName = new HashMap<>(); + List addressNames = recipients.stream().map(SignalServiceAddress::getIdentifier).collect(Collectors.toList()); + Map sessionMap = aciStore.getAllAddressesWithActiveSessions(addressNames); + Map> devicesByAddressName = new HashMap<>(); + + Set destinations = new HashSet<>(sessionMap.keySet()); destinations.addAll(recipients.stream() .map(a -> new SignalProtocolAddress(a.getIdentifier(), SignalServiceAddress.DEFAULT_DEVICE_ID)) @@ -2267,17 +2271,22 @@ public class SignalServiceMessageSender { } } - return new GroupTargetInfo(new ArrayList<>(destinations), recipientDevices); + return new GroupTargetInfo(new ArrayList<>(destinations), recipientDevices, sessionMap); } private static final class GroupTargetInfo { - private final List destinations; - private final Map> devices; + private final List destinations; + private final Map> devices; + private final Map sessions; - private GroupTargetInfo(List destinations, Map> devices) { + private GroupTargetInfo( + List destinations, + Map> devices, + Map sessions) { this.destinations = destinations; this.devices = devices; + this.sessions = sessions; } } diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceSessionStore.java b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceSessionStore.java index 278180ab1f..bd11b5d2c8 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceSessionStore.java +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceSessionStore.java @@ -1,9 +1,11 @@ package org.whispersystems.signalservice.api; import org.signal.libsignal.protocol.SignalProtocolAddress; +import org.signal.libsignal.protocol.state.SessionRecord; import org.signal.libsignal.protocol.state.SessionStore; import java.util.List; +import java.util.Map; import java.util.Set; /** @@ -12,5 +14,5 @@ import java.util.Set; */ public interface SignalServiceSessionStore extends SessionStore { void archiveSession(SignalProtocolAddress address); - Set getAllAddressesWithActiveSessions(List addressNames); + Map getAllAddressesWithActiveSessions(List addressNames); } diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/SignalSealedSessionCipher.java b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/SignalSealedSessionCipher.java index 12b67c4a93..7b573025f7 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/SignalSealedSessionCipher.java +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/SignalSealedSessionCipher.java @@ -1,5 +1,7 @@ package org.whispersystems.signalservice.api.crypto; +import org.signal.libsignal.internal.Native; +import org.signal.libsignal.internal.NativeHandleGuard; import org.signal.libsignal.metadata.InvalidMetadataMessageException; import org.signal.libsignal.metadata.InvalidMetadataVersionException; import org.signal.libsignal.metadata.ProtocolDuplicateMessageException; @@ -19,9 +21,14 @@ import org.signal.libsignal.protocol.InvalidRegistrationIdException; import org.signal.libsignal.protocol.NoSessionException; import org.signal.libsignal.protocol.SignalProtocolAddress; import org.signal.libsignal.protocol.UntrustedIdentityException; +import org.signal.libsignal.protocol.state.SessionRecord; +import org.signal.libsignal.protocol.state.SignalProtocolStore; import org.whispersystems.signalservice.api.SignalSessionLock; import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; /** * A thread-safe wrapper around {@link SealedSessionCipher}. @@ -44,11 +51,47 @@ public class SignalSealedSessionCipher { } } - public byte[] multiRecipientEncrypt(List recipients, UnidentifiedSenderMessageContent content) + // TODO: Revert the change here to use the libsignal SealedSessionCipher when the API changes + public byte[] multiRecipientEncrypt(SignalProtocolStore signalProtocolStore, List recipients, Map sessionMap, UnidentifiedSenderMessageContent content) throws InvalidKeyException, UntrustedIdentityException, NoSessionException, InvalidRegistrationIdException { try (SignalSessionLock.Lock unused = lock.acquire()) { - return cipher.multiRecipientEncrypt(recipients, content); + if (sessionMap == null) { + return cipher.multiRecipientEncrypt(recipients, content); + } + List recipientSessions = recipients.stream().map(sessionMap::get).collect(Collectors.toList()); + if (recipientSessions.stream().anyMatch(Objects::isNull)) { + throw new NoSessionException("Failed to find one or more sessions."); + } + // Unsafely access the native handles for the recipients and sessions, + // because try-with-resources syntax doesn't support a List of resources. + long[] recipientHandles = new long[recipients.size()]; + int i = 0; + for (SignalProtocolAddress nextRecipient : recipients) { + recipientHandles[i] = nextRecipient.unsafeNativeHandleWithoutGuard(); + i++; + } + + long[] recipientSessionHandles = new long[recipientSessions.size()]; + i = 0; + for (SessionRecord nextSession : recipientSessions) { + recipientSessionHandles[i] = nextSession.unsafeNativeHandleWithoutGuard(); + i++; + } + + try (NativeHandleGuard contentGuard = new NativeHandleGuard(content)) { + byte[] result = + Native.SealedSessionCipher_MultiRecipientEncrypt( + recipientHandles, + recipientSessionHandles, + contentGuard.nativeHandle(), + signalProtocolStore); + // Manually keep the lists of recipients and sessions from being garbage collected + // while we're using their native handles. + Native.keepAlive(recipients); + Native.keepAlive(recipientSessions); + return result; + } } } diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/SignalServiceCipher.java b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/SignalServiceCipher.java index 97e85e96c2..4d22a1c622 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/SignalServiceCipher.java +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/SignalServiceCipher.java @@ -40,6 +40,7 @@ import org.signal.libsignal.protocol.message.CiphertextMessage; import org.signal.libsignal.protocol.message.PlaintextContent; import org.signal.libsignal.protocol.message.PreKeySignalMessage; import org.signal.libsignal.protocol.message.SignalMessage; +import org.signal.libsignal.protocol.state.SessionRecord; import org.whispersystems.signalservice.api.InvalidMessageStructureException; import org.whispersystems.signalservice.api.SignalServiceAccountDataStore; import org.whispersystems.signalservice.api.SignalSessionLock; @@ -56,6 +57,7 @@ import org.whispersystems.signalservice.internal.push.PushTransportDetails; import java.io.IOException; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Optional; /** @@ -87,6 +89,7 @@ public class SignalServiceCipher { public byte[] encryptForGroup(DistributionId distributionId, List destinations, + Map sessionMap, SenderCertificate senderCertificate, byte[] unpaddedMessage, ContentHint contentHint, @@ -103,7 +106,7 @@ public class SignalServiceCipher { contentHint.getType(), groupId); - return sessionCipher.multiRecipientEncrypt(destinations, messageContent); + return sessionCipher.multiRecipientEncrypt(signalProtocolStore, destinations, sessionMap, messageContent); } public OutgoingPushMessage encrypt(SignalProtocolAddress destination, diff --git a/microbenchmark/src/androidTest/java/org/signal/util/InMemorySignalServiceAccountDataStore.kt b/microbenchmark/src/androidTest/java/org/signal/util/InMemorySignalServiceAccountDataStore.kt index 28d7037b82..bbf2c96cbc 100644 --- a/microbenchmark/src/androidTest/java/org/signal/util/InMemorySignalServiceAccountDataStore.kt +++ b/microbenchmark/src/androidTest/java/org/signal/util/InMemorySignalServiceAccountDataStore.kt @@ -174,12 +174,11 @@ class InMemorySignalServiceAccountDataStore : SignalServiceAccountDataStore { sessions[address]!!.archiveCurrentState() } - override fun getAllAddressesWithActiveSessions(addressNames: MutableList): Set { + override fun getAllAddressesWithActiveSessions(addressNames: MutableList): MutableMap { return sessions .filter { it.key.name in addressNames } .filter { it.value.isValid() } - .map { it.key } - .toSet() + .toMutableMap() } override fun getSenderKeySharedWith(distributionId: DistributionId): Set { diff --git a/microbenchmark/src/androidTest/java/org/signal/util/SignalClient.kt b/microbenchmark/src/androidTest/java/org/signal/util/SignalClient.kt index 96e3ec1a51..646117afef 100644 --- a/microbenchmark/src/androidTest/java/org/signal/util/SignalClient.kt +++ b/microbenchmark/src/androidTest/java/org/signal/util/SignalClient.kt @@ -156,7 +156,7 @@ class SignalClient { SignalProtocolAddress(bob.aci.toString(), 1) } - return cipher.encryptForGroup(distributionId, destinations, senderCertificate, content.encode(), ContentHint.DEFAULT, groupId) + return cipher.encryptForGroup(distributionId, destinations, null, senderCertificate, content.encode(), ContentHint.DEFAULT, groupId) } fun decryptMessage(envelope: Envelope) {