Re-use session objects during multi-recipient encryption.

This commit is contained in:
Clark 2023-11-27 15:20:18 -05:00 committed by Cody Henthorne
parent 59401e18ed
commit 61810cc977
11 changed files with 80 additions and 23 deletions

View file

@ -165,7 +165,7 @@ class BobClient(val serviceId: ServiceId, val e164: String, val identityKeyPair:
override fun storeSenderKey(sender: SignalProtocolAddress?, distributionId: UUID?, record: SenderKeyRecord?) = throw UnsupportedOperationException()
override fun loadSenderKey(sender: SignalProtocolAddress?, distributionId: UUID?): SenderKeyRecord = throw UnsupportedOperationException()
override fun archiveSession(address: SignalProtocolAddress?) = throw UnsupportedOperationException()
override fun getAllAddressesWithActiveSessions(addressNames: MutableList<String>?): MutableSet<SignalProtocolAddress> = throw UnsupportedOperationException()
override fun getAllAddressesWithActiveSessions(addressNames: MutableList<String>?): MutableMap<SignalProtocolAddress, SessionRecord> = throw UnsupportedOperationException()
override fun getSenderKeySharedWith(distributionId: DistributionId?): MutableSet<SignalProtocolAddress> = throw UnsupportedOperationException()
override fun markSenderKeySharedWith(distributionId: DistributionId?, addresses: MutableCollection<SignalProtocolAddress>?) = throw UnsupportedOperationException()
override fun clearSenderKeySharedWith(addresses: MutableCollection<SignalProtocolAddress>?) = throw UnsupportedOperationException()

View file

@ -21,6 +21,7 @@ import org.whispersystems.signalservice.api.push.DistributionId;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
@ -126,7 +127,7 @@ public class SignalServiceAccountDataStoreImpl implements SignalServiceAccountDa
}
@Override
public Set<SignalProtocolAddress> getAllAddressesWithActiveSessions(List<String> addressNames) {
public Map<SignalProtocolAddress, SessionRecord> getAllAddressesWithActiveSessions(List<String> addressNames) {
return sessionStore.getAllAddressesWithActiveSessions(addressNames);
}

View file

@ -17,6 +17,7 @@ import org.whispersystems.signalservice.api.SignalSessionLock;
import org.whispersystems.signalservice.api.push.ServiceId;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
@ -104,14 +105,13 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
}
@Override
public Set<SignalProtocolAddress> getAllAddressesWithActiveSessions(List<String> addressNames) {
public Map<SignalProtocolAddress, SessionRecord> getAllAddressesWithActiveSessions(List<String> addressNames) {
try (SignalSessionLock.Lock unused = ReentrantSessionLock.INSTANCE.acquire()) {
return SignalDatabase.sessions()
.getAllFor(accountId, addressNames)
.stream()
.filter(row -> isActive(row.getRecord()))
.map(row -> new SignalProtocolAddress(row.getAddress(), row.getDeviceId()))
.collect(Collectors.toSet());
.collect(Collectors.toMap(row -> new SignalProtocolAddress(row.getAddress(), row.getDeviceId()), SessionTable.SessionRow::getRecord));
}
}

View file

@ -100,7 +100,7 @@ class BufferedSessionStore(private val selfServiceId: ServiceId) : SignalService
error("Should not happen during the intended usage pattern of this class")
}
override fun getAllAddressesWithActiveSessions(addressNames: MutableList<String>): Set<SignalProtocolAddress> {
override fun getAllAddressesWithActiveSessions(addressNames: MutableList<String>): Map<SignalProtocolAddress, SessionRecord> {
error("Should not happen during the intended usage pattern of this class")
}

View file

@ -177,7 +177,7 @@ class BufferedSignalServiceAccountDataStore(selfServiceId: ServiceId) : SignalSe
sessionStore.archiveSession(address)
}
override fun getAllAddressesWithActiveSessions(addressNames: MutableList<String>): Set<SignalProtocolAddress> {
override fun getAllAddressesWithActiveSessions(addressNames: MutableList<String>): Map<SignalProtocolAddress, SessionRecord> {
return sessionStore.getAllAddressesWithActiveSessions(addressNames)
}

View file

@ -18,6 +18,7 @@ import org.signal.libsignal.protocol.message.DecryptionErrorMessage;
import org.signal.libsignal.protocol.message.PlaintextContent;
import org.signal.libsignal.protocol.message.SenderKeyDistributionMessage;
import org.signal.libsignal.protocol.state.PreKeyBundle;
import org.signal.libsignal.protocol.state.SessionRecord;
import org.signal.libsignal.protocol.util.Pair;
import org.signal.libsignal.zkgroup.profiles.ClientZkProfileOperations;
import org.whispersystems.signalservice.api.crypto.AttachmentCipherStreamUtil;
@ -136,6 +137,7 @@ import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
@ -2197,7 +2199,7 @@ public class SignalServiceMessageSender {
byte[] ciphertext;
try {
ciphertext = cipher.encryptForGroup(distributionId, targetInfo.destinations, senderCertificate, content.encode(), contentHint, groupId);
ciphertext = cipher.encryptForGroup(distributionId, targetInfo.destinations, targetInfo.sessions, senderCertificate, content.encode(), contentHint, groupId);
} catch (org.signal.libsignal.protocol.UntrustedIdentityException e) {
throw new UntrustedIdentityException("Untrusted during group encrypt", e.getName(), e.getUntrustedIdentity());
}
@ -2245,9 +2247,11 @@ public class SignalServiceMessageSender {
}
private GroupTargetInfo buildGroupTargetInfo(List<SignalServiceAddress> recipients) {
List<String> addressNames = recipients.stream().map(SignalServiceAddress::getIdentifier).collect(Collectors.toList());
Set<SignalProtocolAddress> destinations = aciStore.getAllAddressesWithActiveSessions(addressNames);
Map<String, List<Integer>> devicesByAddressName = new HashMap<>();
List<String> addressNames = recipients.stream().map(SignalServiceAddress::getIdentifier).collect(Collectors.toList());
Map<SignalProtocolAddress, SessionRecord> sessionMap = aciStore.getAllAddressesWithActiveSessions(addressNames);
Map<String, List<Integer>> devicesByAddressName = new HashMap<>();
Set<SignalProtocolAddress> destinations = new HashSet<>(sessionMap.keySet());
destinations.addAll(recipients.stream()
.map(a -> new SignalProtocolAddress(a.getIdentifier(), SignalServiceAddress.DEFAULT_DEVICE_ID))
@ -2267,17 +2271,22 @@ public class SignalServiceMessageSender {
}
}
return new GroupTargetInfo(new ArrayList<>(destinations), recipientDevices);
return new GroupTargetInfo(new ArrayList<>(destinations), recipientDevices, sessionMap);
}
private static final class GroupTargetInfo {
private final List<SignalProtocolAddress> destinations;
private final Map<SignalServiceAddress, List<Integer>> devices;
private final List<SignalProtocolAddress> destinations;
private final Map<SignalServiceAddress, List<Integer>> devices;
private final Map<SignalProtocolAddress, SessionRecord> sessions;
private GroupTargetInfo(List<SignalProtocolAddress> destinations, Map<SignalServiceAddress, List<Integer>> devices) {
private GroupTargetInfo(
List<SignalProtocolAddress> destinations,
Map<SignalServiceAddress, List<Integer>> devices,
Map<SignalProtocolAddress, SessionRecord> sessions) {
this.destinations = destinations;
this.devices = devices;
this.sessions = sessions;
}
}

View file

@ -1,9 +1,11 @@
package org.whispersystems.signalservice.api;
import org.signal.libsignal.protocol.SignalProtocolAddress;
import org.signal.libsignal.protocol.state.SessionRecord;
import org.signal.libsignal.protocol.state.SessionStore;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
@ -12,5 +14,5 @@ import java.util.Set;
*/
public interface SignalServiceSessionStore extends SessionStore {
void archiveSession(SignalProtocolAddress address);
Set<SignalProtocolAddress> getAllAddressesWithActiveSessions(List<String> addressNames);
Map<SignalProtocolAddress, SessionRecord> getAllAddressesWithActiveSessions(List<String> addressNames);
}

View file

@ -1,5 +1,7 @@
package org.whispersystems.signalservice.api.crypto;
import org.signal.libsignal.internal.Native;
import org.signal.libsignal.internal.NativeHandleGuard;
import org.signal.libsignal.metadata.InvalidMetadataMessageException;
import org.signal.libsignal.metadata.InvalidMetadataVersionException;
import org.signal.libsignal.metadata.ProtocolDuplicateMessageException;
@ -19,9 +21,14 @@ import org.signal.libsignal.protocol.InvalidRegistrationIdException;
import org.signal.libsignal.protocol.NoSessionException;
import org.signal.libsignal.protocol.SignalProtocolAddress;
import org.signal.libsignal.protocol.UntrustedIdentityException;
import org.signal.libsignal.protocol.state.SessionRecord;
import org.signal.libsignal.protocol.state.SignalProtocolStore;
import org.whispersystems.signalservice.api.SignalSessionLock;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
/**
* A thread-safe wrapper around {@link SealedSessionCipher}.
@ -44,11 +51,47 @@ public class SignalSealedSessionCipher {
}
}
public byte[] multiRecipientEncrypt(List<SignalProtocolAddress> recipients, UnidentifiedSenderMessageContent content)
// TODO: Revert the change here to use the libsignal SealedSessionCipher when the API changes
public byte[] multiRecipientEncrypt(SignalProtocolStore signalProtocolStore, List<SignalProtocolAddress> recipients, Map<SignalProtocolAddress, SessionRecord> sessionMap, UnidentifiedSenderMessageContent content)
throws InvalidKeyException, UntrustedIdentityException, NoSessionException, InvalidRegistrationIdException
{
try (SignalSessionLock.Lock unused = lock.acquire()) {
return cipher.multiRecipientEncrypt(recipients, content);
if (sessionMap == null) {
return cipher.multiRecipientEncrypt(recipients, content);
}
List<SessionRecord> recipientSessions = recipients.stream().map(sessionMap::get).collect(Collectors.toList());
if (recipientSessions.stream().anyMatch(Objects::isNull)) {
throw new NoSessionException("Failed to find one or more sessions.");
}
// Unsafely access the native handles for the recipients and sessions,
// because try-with-resources syntax doesn't support a List of resources.
long[] recipientHandles = new long[recipients.size()];
int i = 0;
for (SignalProtocolAddress nextRecipient : recipients) {
recipientHandles[i] = nextRecipient.unsafeNativeHandleWithoutGuard();
i++;
}
long[] recipientSessionHandles = new long[recipientSessions.size()];
i = 0;
for (SessionRecord nextSession : recipientSessions) {
recipientSessionHandles[i] = nextSession.unsafeNativeHandleWithoutGuard();
i++;
}
try (NativeHandleGuard contentGuard = new NativeHandleGuard(content)) {
byte[] result =
Native.SealedSessionCipher_MultiRecipientEncrypt(
recipientHandles,
recipientSessionHandles,
contentGuard.nativeHandle(),
signalProtocolStore);
// Manually keep the lists of recipients and sessions from being garbage collected
// while we're using their native handles.
Native.keepAlive(recipients);
Native.keepAlive(recipientSessions);
return result;
}
}
}

View file

@ -40,6 +40,7 @@ import org.signal.libsignal.protocol.message.CiphertextMessage;
import org.signal.libsignal.protocol.message.PlaintextContent;
import org.signal.libsignal.protocol.message.PreKeySignalMessage;
import org.signal.libsignal.protocol.message.SignalMessage;
import org.signal.libsignal.protocol.state.SessionRecord;
import org.whispersystems.signalservice.api.InvalidMessageStructureException;
import org.whispersystems.signalservice.api.SignalServiceAccountDataStore;
import org.whispersystems.signalservice.api.SignalSessionLock;
@ -56,6 +57,7 @@ import org.whispersystems.signalservice.internal.push.PushTransportDetails;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
/**
@ -87,6 +89,7 @@ public class SignalServiceCipher {
public byte[] encryptForGroup(DistributionId distributionId,
List<SignalProtocolAddress> destinations,
Map<SignalProtocolAddress, SessionRecord> sessionMap,
SenderCertificate senderCertificate,
byte[] unpaddedMessage,
ContentHint contentHint,
@ -103,7 +106,7 @@ public class SignalServiceCipher {
contentHint.getType(),
groupId);
return sessionCipher.multiRecipientEncrypt(destinations, messageContent);
return sessionCipher.multiRecipientEncrypt(signalProtocolStore, destinations, sessionMap, messageContent);
}
public OutgoingPushMessage encrypt(SignalProtocolAddress destination,

View file

@ -174,12 +174,11 @@ class InMemorySignalServiceAccountDataStore : SignalServiceAccountDataStore {
sessions[address]!!.archiveCurrentState()
}
override fun getAllAddressesWithActiveSessions(addressNames: MutableList<String>): Set<SignalProtocolAddress> {
override fun getAllAddressesWithActiveSessions(addressNames: MutableList<String>): MutableMap<SignalProtocolAddress, SessionRecord> {
return sessions
.filter { it.key.name in addressNames }
.filter { it.value.isValid() }
.map { it.key }
.toSet()
.toMutableMap()
}
override fun getSenderKeySharedWith(distributionId: DistributionId): Set<SignalProtocolAddress> {

View file

@ -156,7 +156,7 @@ class SignalClient {
SignalProtocolAddress(bob.aci.toString(), 1)
}
return cipher.encryptForGroup(distributionId, destinations, senderCertificate, content.encode(), ContentHint.DEFAULT, groupId)
return cipher.encryptForGroup(distributionId, destinations, null, senderCertificate, content.encode(), ContentHint.DEFAULT, groupId)
}
fun decryptMessage(envelope: Envelope) {