Re-use session objects during multi-recipient encryption.
This commit is contained in:
parent
59401e18ed
commit
61810cc977
11 changed files with 80 additions and 23 deletions
|
@ -165,7 +165,7 @@ class BobClient(val serviceId: ServiceId, val e164: String, val identityKeyPair:
|
|||
override fun storeSenderKey(sender: SignalProtocolAddress?, distributionId: UUID?, record: SenderKeyRecord?) = throw UnsupportedOperationException()
|
||||
override fun loadSenderKey(sender: SignalProtocolAddress?, distributionId: UUID?): SenderKeyRecord = throw UnsupportedOperationException()
|
||||
override fun archiveSession(address: SignalProtocolAddress?) = throw UnsupportedOperationException()
|
||||
override fun getAllAddressesWithActiveSessions(addressNames: MutableList<String>?): MutableSet<SignalProtocolAddress> = throw UnsupportedOperationException()
|
||||
override fun getAllAddressesWithActiveSessions(addressNames: MutableList<String>?): MutableMap<SignalProtocolAddress, SessionRecord> = throw UnsupportedOperationException()
|
||||
override fun getSenderKeySharedWith(distributionId: DistributionId?): MutableSet<SignalProtocolAddress> = throw UnsupportedOperationException()
|
||||
override fun markSenderKeySharedWith(distributionId: DistributionId?, addresses: MutableCollection<SignalProtocolAddress>?) = throw UnsupportedOperationException()
|
||||
override fun clearSenderKeySharedWith(addresses: MutableCollection<SignalProtocolAddress>?) = throw UnsupportedOperationException()
|
||||
|
|
|
@ -21,6 +21,7 @@ import org.whispersystems.signalservice.api.push.DistributionId;
|
|||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.UUID;
|
||||
|
||||
|
@ -126,7 +127,7 @@ public class SignalServiceAccountDataStoreImpl implements SignalServiceAccountDa
|
|||
}
|
||||
|
||||
@Override
|
||||
public Set<SignalProtocolAddress> getAllAddressesWithActiveSessions(List<String> addressNames) {
|
||||
public Map<SignalProtocolAddress, SessionRecord> getAllAddressesWithActiveSessions(List<String> addressNames) {
|
||||
return sessionStore.getAllAddressesWithActiveSessions(addressNames);
|
||||
}
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@ import org.whispersystems.signalservice.api.SignalSessionLock;
|
|||
import org.whispersystems.signalservice.api.push.ServiceId;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
@ -104,14 +105,13 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
|
|||
}
|
||||
|
||||
@Override
|
||||
public Set<SignalProtocolAddress> getAllAddressesWithActiveSessions(List<String> addressNames) {
|
||||
public Map<SignalProtocolAddress, SessionRecord> getAllAddressesWithActiveSessions(List<String> addressNames) {
|
||||
try (SignalSessionLock.Lock unused = ReentrantSessionLock.INSTANCE.acquire()) {
|
||||
return SignalDatabase.sessions()
|
||||
.getAllFor(accountId, addressNames)
|
||||
.stream()
|
||||
.filter(row -> isActive(row.getRecord()))
|
||||
.map(row -> new SignalProtocolAddress(row.getAddress(), row.getDeviceId()))
|
||||
.collect(Collectors.toSet());
|
||||
.collect(Collectors.toMap(row -> new SignalProtocolAddress(row.getAddress(), row.getDeviceId()), SessionTable.SessionRow::getRecord));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -100,7 +100,7 @@ class BufferedSessionStore(private val selfServiceId: ServiceId) : SignalService
|
|||
error("Should not happen during the intended usage pattern of this class")
|
||||
}
|
||||
|
||||
override fun getAllAddressesWithActiveSessions(addressNames: MutableList<String>): Set<SignalProtocolAddress> {
|
||||
override fun getAllAddressesWithActiveSessions(addressNames: MutableList<String>): Map<SignalProtocolAddress, SessionRecord> {
|
||||
error("Should not happen during the intended usage pattern of this class")
|
||||
}
|
||||
|
||||
|
|
|
@ -177,7 +177,7 @@ class BufferedSignalServiceAccountDataStore(selfServiceId: ServiceId) : SignalSe
|
|||
sessionStore.archiveSession(address)
|
||||
}
|
||||
|
||||
override fun getAllAddressesWithActiveSessions(addressNames: MutableList<String>): Set<SignalProtocolAddress> {
|
||||
override fun getAllAddressesWithActiveSessions(addressNames: MutableList<String>): Map<SignalProtocolAddress, SessionRecord> {
|
||||
return sessionStore.getAllAddressesWithActiveSessions(addressNames)
|
||||
}
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@ import org.signal.libsignal.protocol.message.DecryptionErrorMessage;
|
|||
import org.signal.libsignal.protocol.message.PlaintextContent;
|
||||
import org.signal.libsignal.protocol.message.SenderKeyDistributionMessage;
|
||||
import org.signal.libsignal.protocol.state.PreKeyBundle;
|
||||
import org.signal.libsignal.protocol.state.SessionRecord;
|
||||
import org.signal.libsignal.protocol.util.Pair;
|
||||
import org.signal.libsignal.zkgroup.profiles.ClientZkProfileOperations;
|
||||
import org.whispersystems.signalservice.api.crypto.AttachmentCipherStreamUtil;
|
||||
|
@ -136,6 +137,7 @@ import java.security.SecureRandom;
|
|||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.Iterator;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
|
@ -2197,7 +2199,7 @@ public class SignalServiceMessageSender {
|
|||
|
||||
byte[] ciphertext;
|
||||
try {
|
||||
ciphertext = cipher.encryptForGroup(distributionId, targetInfo.destinations, senderCertificate, content.encode(), contentHint, groupId);
|
||||
ciphertext = cipher.encryptForGroup(distributionId, targetInfo.destinations, targetInfo.sessions, senderCertificate, content.encode(), contentHint, groupId);
|
||||
} catch (org.signal.libsignal.protocol.UntrustedIdentityException e) {
|
||||
throw new UntrustedIdentityException("Untrusted during group encrypt", e.getName(), e.getUntrustedIdentity());
|
||||
}
|
||||
|
@ -2245,9 +2247,11 @@ public class SignalServiceMessageSender {
|
|||
}
|
||||
|
||||
private GroupTargetInfo buildGroupTargetInfo(List<SignalServiceAddress> recipients) {
|
||||
List<String> addressNames = recipients.stream().map(SignalServiceAddress::getIdentifier).collect(Collectors.toList());
|
||||
Set<SignalProtocolAddress> destinations = aciStore.getAllAddressesWithActiveSessions(addressNames);
|
||||
Map<String, List<Integer>> devicesByAddressName = new HashMap<>();
|
||||
List<String> addressNames = recipients.stream().map(SignalServiceAddress::getIdentifier).collect(Collectors.toList());
|
||||
Map<SignalProtocolAddress, SessionRecord> sessionMap = aciStore.getAllAddressesWithActiveSessions(addressNames);
|
||||
Map<String, List<Integer>> devicesByAddressName = new HashMap<>();
|
||||
|
||||
Set<SignalProtocolAddress> destinations = new HashSet<>(sessionMap.keySet());
|
||||
|
||||
destinations.addAll(recipients.stream()
|
||||
.map(a -> new SignalProtocolAddress(a.getIdentifier(), SignalServiceAddress.DEFAULT_DEVICE_ID))
|
||||
|
@ -2267,17 +2271,22 @@ public class SignalServiceMessageSender {
|
|||
}
|
||||
}
|
||||
|
||||
return new GroupTargetInfo(new ArrayList<>(destinations), recipientDevices);
|
||||
return new GroupTargetInfo(new ArrayList<>(destinations), recipientDevices, sessionMap);
|
||||
}
|
||||
|
||||
|
||||
private static final class GroupTargetInfo {
|
||||
private final List<SignalProtocolAddress> destinations;
|
||||
private final Map<SignalServiceAddress, List<Integer>> devices;
|
||||
private final List<SignalProtocolAddress> destinations;
|
||||
private final Map<SignalServiceAddress, List<Integer>> devices;
|
||||
private final Map<SignalProtocolAddress, SessionRecord> sessions;
|
||||
|
||||
private GroupTargetInfo(List<SignalProtocolAddress> destinations, Map<SignalServiceAddress, List<Integer>> devices) {
|
||||
private GroupTargetInfo(
|
||||
List<SignalProtocolAddress> destinations,
|
||||
Map<SignalServiceAddress, List<Integer>> devices,
|
||||
Map<SignalProtocolAddress, SessionRecord> sessions) {
|
||||
this.destinations = destinations;
|
||||
this.devices = devices;
|
||||
this.sessions = sessions;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
package org.whispersystems.signalservice.api;
|
||||
|
||||
import org.signal.libsignal.protocol.SignalProtocolAddress;
|
||||
import org.signal.libsignal.protocol.state.SessionRecord;
|
||||
import org.signal.libsignal.protocol.state.SessionStore;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
|
@ -12,5 +14,5 @@ import java.util.Set;
|
|||
*/
|
||||
public interface SignalServiceSessionStore extends SessionStore {
|
||||
void archiveSession(SignalProtocolAddress address);
|
||||
Set<SignalProtocolAddress> getAllAddressesWithActiveSessions(List<String> addressNames);
|
||||
Map<SignalProtocolAddress, SessionRecord> getAllAddressesWithActiveSessions(List<String> addressNames);
|
||||
}
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
package org.whispersystems.signalservice.api.crypto;
|
||||
|
||||
import org.signal.libsignal.internal.Native;
|
||||
import org.signal.libsignal.internal.NativeHandleGuard;
|
||||
import org.signal.libsignal.metadata.InvalidMetadataMessageException;
|
||||
import org.signal.libsignal.metadata.InvalidMetadataVersionException;
|
||||
import org.signal.libsignal.metadata.ProtocolDuplicateMessageException;
|
||||
|
@ -19,9 +21,14 @@ import org.signal.libsignal.protocol.InvalidRegistrationIdException;
|
|||
import org.signal.libsignal.protocol.NoSessionException;
|
||||
import org.signal.libsignal.protocol.SignalProtocolAddress;
|
||||
import org.signal.libsignal.protocol.UntrustedIdentityException;
|
||||
import org.signal.libsignal.protocol.state.SessionRecord;
|
||||
import org.signal.libsignal.protocol.state.SignalProtocolStore;
|
||||
import org.whispersystems.signalservice.api.SignalSessionLock;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* A thread-safe wrapper around {@link SealedSessionCipher}.
|
||||
|
@ -44,11 +51,47 @@ public class SignalSealedSessionCipher {
|
|||
}
|
||||
}
|
||||
|
||||
public byte[] multiRecipientEncrypt(List<SignalProtocolAddress> recipients, UnidentifiedSenderMessageContent content)
|
||||
// TODO: Revert the change here to use the libsignal SealedSessionCipher when the API changes
|
||||
public byte[] multiRecipientEncrypt(SignalProtocolStore signalProtocolStore, List<SignalProtocolAddress> recipients, Map<SignalProtocolAddress, SessionRecord> sessionMap, UnidentifiedSenderMessageContent content)
|
||||
throws InvalidKeyException, UntrustedIdentityException, NoSessionException, InvalidRegistrationIdException
|
||||
{
|
||||
try (SignalSessionLock.Lock unused = lock.acquire()) {
|
||||
return cipher.multiRecipientEncrypt(recipients, content);
|
||||
if (sessionMap == null) {
|
||||
return cipher.multiRecipientEncrypt(recipients, content);
|
||||
}
|
||||
List<SessionRecord> recipientSessions = recipients.stream().map(sessionMap::get).collect(Collectors.toList());
|
||||
if (recipientSessions.stream().anyMatch(Objects::isNull)) {
|
||||
throw new NoSessionException("Failed to find one or more sessions.");
|
||||
}
|
||||
// Unsafely access the native handles for the recipients and sessions,
|
||||
// because try-with-resources syntax doesn't support a List of resources.
|
||||
long[] recipientHandles = new long[recipients.size()];
|
||||
int i = 0;
|
||||
for (SignalProtocolAddress nextRecipient : recipients) {
|
||||
recipientHandles[i] = nextRecipient.unsafeNativeHandleWithoutGuard();
|
||||
i++;
|
||||
}
|
||||
|
||||
long[] recipientSessionHandles = new long[recipientSessions.size()];
|
||||
i = 0;
|
||||
for (SessionRecord nextSession : recipientSessions) {
|
||||
recipientSessionHandles[i] = nextSession.unsafeNativeHandleWithoutGuard();
|
||||
i++;
|
||||
}
|
||||
|
||||
try (NativeHandleGuard contentGuard = new NativeHandleGuard(content)) {
|
||||
byte[] result =
|
||||
Native.SealedSessionCipher_MultiRecipientEncrypt(
|
||||
recipientHandles,
|
||||
recipientSessionHandles,
|
||||
contentGuard.nativeHandle(),
|
||||
signalProtocolStore);
|
||||
// Manually keep the lists of recipients and sessions from being garbage collected
|
||||
// while we're using their native handles.
|
||||
Native.keepAlive(recipients);
|
||||
Native.keepAlive(recipientSessions);
|
||||
return result;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -40,6 +40,7 @@ import org.signal.libsignal.protocol.message.CiphertextMessage;
|
|||
import org.signal.libsignal.protocol.message.PlaintextContent;
|
||||
import org.signal.libsignal.protocol.message.PreKeySignalMessage;
|
||||
import org.signal.libsignal.protocol.message.SignalMessage;
|
||||
import org.signal.libsignal.protocol.state.SessionRecord;
|
||||
import org.whispersystems.signalservice.api.InvalidMessageStructureException;
|
||||
import org.whispersystems.signalservice.api.SignalServiceAccountDataStore;
|
||||
import org.whispersystems.signalservice.api.SignalSessionLock;
|
||||
|
@ -56,6 +57,7 @@ import org.whispersystems.signalservice.internal.push.PushTransportDetails;
|
|||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
|
||||
/**
|
||||
|
@ -87,6 +89,7 @@ public class SignalServiceCipher {
|
|||
|
||||
public byte[] encryptForGroup(DistributionId distributionId,
|
||||
List<SignalProtocolAddress> destinations,
|
||||
Map<SignalProtocolAddress, SessionRecord> sessionMap,
|
||||
SenderCertificate senderCertificate,
|
||||
byte[] unpaddedMessage,
|
||||
ContentHint contentHint,
|
||||
|
@ -103,7 +106,7 @@ public class SignalServiceCipher {
|
|||
contentHint.getType(),
|
||||
groupId);
|
||||
|
||||
return sessionCipher.multiRecipientEncrypt(destinations, messageContent);
|
||||
return sessionCipher.multiRecipientEncrypt(signalProtocolStore, destinations, sessionMap, messageContent);
|
||||
}
|
||||
|
||||
public OutgoingPushMessage encrypt(SignalProtocolAddress destination,
|
||||
|
|
|
@ -174,12 +174,11 @@ class InMemorySignalServiceAccountDataStore : SignalServiceAccountDataStore {
|
|||
sessions[address]!!.archiveCurrentState()
|
||||
}
|
||||
|
||||
override fun getAllAddressesWithActiveSessions(addressNames: MutableList<String>): Set<SignalProtocolAddress> {
|
||||
override fun getAllAddressesWithActiveSessions(addressNames: MutableList<String>): MutableMap<SignalProtocolAddress, SessionRecord> {
|
||||
return sessions
|
||||
.filter { it.key.name in addressNames }
|
||||
.filter { it.value.isValid() }
|
||||
.map { it.key }
|
||||
.toSet()
|
||||
.toMutableMap()
|
||||
}
|
||||
|
||||
override fun getSenderKeySharedWith(distributionId: DistributionId): Set<SignalProtocolAddress> {
|
||||
|
|
|
@ -156,7 +156,7 @@ class SignalClient {
|
|||
SignalProtocolAddress(bob.aci.toString(), 1)
|
||||
}
|
||||
|
||||
return cipher.encryptForGroup(distributionId, destinations, senderCertificate, content.encode(), ContentHint.DEFAULT, groupId)
|
||||
return cipher.encryptForGroup(distributionId, destinations, null, senderCertificate, content.encode(), ContentHint.DEFAULT, groupId)
|
||||
}
|
||||
|
||||
fun decryptMessage(envelope: Envelope) {
|
||||
|
|
Loading…
Add table
Reference in a new issue