Unify locks in protocol stores.

This commit is contained in:
Greyson Parrelli 2023-03-23 15:37:35 -04:00
parent 2763cfe6f4
commit d58c4ef439
4 changed files with 33 additions and 32 deletions

View file

@ -11,6 +11,7 @@ import org.signal.core.util.logging.Log;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.SignalProtocolAddress;
import org.signal.libsignal.protocol.state.IdentityKeyStore;
import org.thoughtcrime.securesms.crypto.ReentrantSessionLock;
import org.thoughtcrime.securesms.crypto.storage.SignalIdentityKeyStore.SaveResult;
import org.thoughtcrime.securesms.database.IdentityTable;
import org.thoughtcrime.securesms.database.IdentityTable.VerifiedStatus;
@ -24,6 +25,7 @@ import org.thoughtcrime.securesms.recipients.Recipient;
import org.thoughtcrime.securesms.recipients.RecipientId;
import org.thoughtcrime.securesms.util.IdentityUtil;
import org.thoughtcrime.securesms.util.LRUCache;
import org.whispersystems.signalservice.api.SignalSessionLock;
import org.whispersystems.signalservice.api.push.ServiceId;
import java.util.ArrayList;
@ -43,7 +45,6 @@ public class SignalBaseIdentityKeyStore {
private static final String TAG = Log.tag(SignalBaseIdentityKeyStore.class);
private static final Object LOCK = new Object();
private static final int TIMESTAMP_THRESHOLD_SECONDS = 5;
private final Context context;
@ -67,7 +68,7 @@ public class SignalBaseIdentityKeyStore {
}
public @NonNull SaveResult saveIdentity(SignalProtocolAddress address, IdentityKey identityKey, boolean nonBlockingApproval) {
synchronized (LOCK) {
try (SignalSessionLock.Lock unused = ReentrantSessionLock.INSTANCE.acquire()) {
IdentityStoreRecord identityRecord = cache.get(address.getName());
RecipientId recipientId = RecipientId.fromSidOrE164(address.getName());

View file

@ -6,9 +6,11 @@ import androidx.annotation.NonNull;
import org.signal.libsignal.protocol.SignalProtocolAddress;
import org.signal.libsignal.protocol.groups.state.SenderKeyRecord;
import org.thoughtcrime.securesms.crypto.ReentrantSessionLock;
import org.thoughtcrime.securesms.database.SenderKeyTable;
import org.thoughtcrime.securesms.database.SignalDatabase;
import org.whispersystems.signalservice.api.SignalServiceSenderKeyStore;
import org.whispersystems.signalservice.api.SignalSessionLock;
import org.whispersystems.signalservice.api.push.DistributionId;
import java.util.Collection;
@ -23,8 +25,6 @@ import javax.annotation.Nullable;
*/
public final class SignalSenderKeyStore implements SignalServiceSenderKeyStore {
private static final Object LOCK = new Object();
private final Context context;
public SignalSenderKeyStore(@NonNull Context context) {
@ -33,35 +33,35 @@ public final class SignalSenderKeyStore implements SignalServiceSenderKeyStore {
@Override
public void storeSenderKey(@NonNull SignalProtocolAddress sender, @NonNull UUID distributionId, @NonNull SenderKeyRecord record) {
synchronized (LOCK) {
try (SignalSessionLock.Lock unused = ReentrantSessionLock.INSTANCE.acquire()) {
SignalDatabase.senderKeys().store(sender, DistributionId.from(distributionId), record);
}
}
@Override
public @Nullable SenderKeyRecord loadSenderKey(@NonNull SignalProtocolAddress sender, @NonNull UUID distributionId) {
synchronized (LOCK) {
try (SignalSessionLock.Lock unused = ReentrantSessionLock.INSTANCE.acquire()) {
return SignalDatabase.senderKeys().load(sender, DistributionId.from(distributionId));
}
}
@Override
public Set<SignalProtocolAddress> getSenderKeySharedWith(DistributionId distributionId) {
synchronized (LOCK) {
try (SignalSessionLock.Lock unused = ReentrantSessionLock.INSTANCE.acquire()) {
return SignalDatabase.senderKeyShared().getSharedWith(distributionId);
}
}
@Override
public void markSenderKeySharedWith(DistributionId distributionId, Collection<SignalProtocolAddress> addresses) {
synchronized (LOCK) {
try (SignalSessionLock.Lock unused = ReentrantSessionLock.INSTANCE.acquire()) {
SignalDatabase.senderKeyShared().markAsShared(distributionId, addresses);
}
}
@Override
public void clearSenderKeySharedWith(Collection<SignalProtocolAddress> addresses) {
synchronized (LOCK) {
try (SignalSessionLock.Lock unused = ReentrantSessionLock.INSTANCE.acquire()) {
SignalDatabase.senderKeyShared().deleteAllFor(addresses);
}
}
@ -70,7 +70,7 @@ public final class SignalSenderKeyStore implements SignalServiceSenderKeyStore {
* Removes all sender key session state for all devices for the provided recipient-distributionId pair.
*/
public void deleteAllFor(@NonNull String addressName, @NonNull DistributionId distributionId) {
synchronized (LOCK) {
try (SignalSessionLock.Lock unused = ReentrantSessionLock.INSTANCE.acquire()) {
SignalDatabase.senderKeys().deleteAllFor(addressName, distributionId);
}
}
@ -79,7 +79,7 @@ public final class SignalSenderKeyStore implements SignalServiceSenderKeyStore {
* Deletes all sender key session state.
*/
public void deleteAll() {
synchronized (LOCK) {
try (SignalSessionLock.Lock unused = ReentrantSessionLock.INSTANCE.acquire()) {
SignalDatabase.senderKeys().deleteAll();
}
}

View file

@ -8,7 +8,9 @@ import org.signal.libsignal.protocol.state.PreKeyRecord;
import org.signal.libsignal.protocol.state.PreKeyStore;
import org.signal.libsignal.protocol.state.SignedPreKeyRecord;
import org.signal.libsignal.protocol.state.SignedPreKeyStore;
import org.thoughtcrime.securesms.crypto.ReentrantSessionLock;
import org.thoughtcrime.securesms.database.SignalDatabase;
import org.whispersystems.signalservice.api.SignalSessionLock;
import org.whispersystems.signalservice.api.push.ServiceId;
import java.util.List;
@ -18,8 +20,6 @@ public class TextSecurePreKeyStore implements PreKeyStore, SignedPreKeyStore {
@SuppressWarnings("unused")
private static final String TAG = Log.tag(TextSecurePreKeyStore.class);
private static final Object LOCK = new Object();
@NonNull
private final ServiceId accountId;
@ -29,7 +29,7 @@ public class TextSecurePreKeyStore implements PreKeyStore, SignedPreKeyStore {
@Override
public PreKeyRecord loadPreKey(int preKeyId) throws InvalidKeyIdException {
synchronized (LOCK) {
try (SignalSessionLock.Lock unused = ReentrantSessionLock.INSTANCE.acquire()) {
PreKeyRecord preKeyRecord = SignalDatabase.oneTimePreKeys().get(accountId, preKeyId);
if (preKeyRecord == null) throw new InvalidKeyIdException("No such key: " + preKeyId);
@ -39,7 +39,7 @@ public class TextSecurePreKeyStore implements PreKeyStore, SignedPreKeyStore {
@Override
public SignedPreKeyRecord loadSignedPreKey(int signedPreKeyId) throws InvalidKeyIdException {
synchronized (LOCK) {
try (SignalSessionLock.Lock unused = ReentrantSessionLock.INSTANCE.acquire()) {
SignedPreKeyRecord signedPreKeyRecord = SignalDatabase.signedPreKeys().get(accountId, signedPreKeyId);
if (signedPreKeyRecord == null) throw new InvalidKeyIdException("No such signed prekey: " + signedPreKeyId);
@ -49,21 +49,21 @@ public class TextSecurePreKeyStore implements PreKeyStore, SignedPreKeyStore {
@Override
public List<SignedPreKeyRecord> loadSignedPreKeys() {
synchronized (LOCK) {
try (SignalSessionLock.Lock unused = ReentrantSessionLock.INSTANCE.acquire()) {
return SignalDatabase.signedPreKeys().getAll(accountId);
}
}
@Override
public void storePreKey(int preKeyId, PreKeyRecord record) {
synchronized (LOCK) {
try (SignalSessionLock.Lock unused = ReentrantSessionLock.INSTANCE.acquire()) {
SignalDatabase.oneTimePreKeys().insert(accountId, preKeyId, record);
}
}
@Override
public void storeSignedPreKey(int signedPreKeyId, SignedPreKeyRecord record) {
synchronized (LOCK) {
try (SignalSessionLock.Lock unused = ReentrantSessionLock.INSTANCE.acquire()) {
SignalDatabase.signedPreKeys().insert(accountId, signedPreKeyId, record);
}
}

View file

@ -8,11 +8,13 @@ import org.signal.libsignal.protocol.NoSessionException;
import org.signal.libsignal.protocol.SignalProtocolAddress;
import org.signal.libsignal.protocol.message.CiphertextMessage;
import org.signal.libsignal.protocol.state.SessionRecord;
import org.thoughtcrime.securesms.crypto.ReentrantSessionLock;
import org.thoughtcrime.securesms.database.SessionTable;
import org.thoughtcrime.securesms.database.SignalDatabase;
import org.thoughtcrime.securesms.recipients.Recipient;
import org.thoughtcrime.securesms.recipients.RecipientId;
import org.whispersystems.signalservice.api.SignalServiceSessionStore;
import org.whispersystems.signalservice.api.SignalSessionLock;
import org.whispersystems.signalservice.api.push.ServiceId;
import java.util.List;
@ -24,8 +26,6 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
private static final String TAG = Log.tag(TextSecureSessionStore.class);
private static final Object LOCK = new Object();
private final ServiceId accountId;
public TextSecureSessionStore(@NonNull ServiceId accountId) {
@ -34,7 +34,7 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
@Override
public SessionRecord loadSession(@NonNull SignalProtocolAddress address) {
synchronized (LOCK) {
try (SignalSessionLock.Lock unused = ReentrantSessionLock.INSTANCE.acquire()) {
SessionRecord sessionRecord = SignalDatabase.sessions().load(accountId, address);
if (sessionRecord == null) {
@ -48,7 +48,7 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
@Override
public List<SessionRecord> loadExistingSessions(List<SignalProtocolAddress> addresses) throws NoSessionException {
synchronized (LOCK) {
try (SignalSessionLock.Lock unused = ReentrantSessionLock.INSTANCE.acquire()) {
List<SessionRecord> sessionRecords = SignalDatabase.sessions().load(accountId, addresses);
if (sessionRecords.size() != addresses.size()) {
@ -67,14 +67,14 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
@Override
public void storeSession(@NonNull SignalProtocolAddress address, @NonNull SessionRecord record) {
synchronized (LOCK) {
try (SignalSessionLock.Lock unused = ReentrantSessionLock.INSTANCE.acquire()) {
SignalDatabase.sessions().store(accountId, address, record);
}
}
@Override
public boolean containsSession(SignalProtocolAddress address) {
synchronized (LOCK) {
try (SignalSessionLock.Lock unused = ReentrantSessionLock.INSTANCE.acquire()) {
SessionRecord sessionRecord = SignalDatabase.sessions().load(accountId, address);
return sessionRecord != null &&
@ -85,7 +85,7 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
@Override
public void deleteSession(SignalProtocolAddress address) {
synchronized (LOCK) {
try (SignalSessionLock.Lock unused = ReentrantSessionLock.INSTANCE.acquire()) {
Log.w(TAG, "Deleting session for " + address);
SignalDatabase.sessions().delete(accountId, address);
}
@ -93,7 +93,7 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
@Override
public void deleteAllSessions(String name) {
synchronized (LOCK) {
try (SignalSessionLock.Lock unused = ReentrantSessionLock.INSTANCE.acquire()) {
Log.w(TAG, "Deleting all sessions for " + name);
SignalDatabase.sessions().deleteAllFor(accountId, name);
}
@ -101,14 +101,14 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
@Override
public List<Integer> getSubDeviceSessions(String name) {
synchronized (LOCK) {
try (SignalSessionLock.Lock unused = ReentrantSessionLock.INSTANCE.acquire()) {
return SignalDatabase.sessions().getSubDevices(accountId, name);
}
}
@Override
public Set<SignalProtocolAddress> getAllAddressesWithActiveSessions(List<String> addressNames) {
synchronized (LOCK) {
try (SignalSessionLock.Lock unused = ReentrantSessionLock.INSTANCE.acquire()) {
return SignalDatabase.sessions()
.getAllFor(accountId, addressNames)
.stream()
@ -120,7 +120,7 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
@Override
public void archiveSession(SignalProtocolAddress address) {
synchronized (LOCK) {
try (SignalSessionLock.Lock unused = ReentrantSessionLock.INSTANCE.acquire()) {
SessionRecord session = SignalDatabase.sessions().load(accountId, address);
if (session != null) {
session.archiveCurrentState();
@ -130,7 +130,7 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
}
public void archiveSession(@NonNull RecipientId recipientId, int deviceId) {
synchronized (LOCK) {
try (SignalSessionLock.Lock unused = ReentrantSessionLock.INSTANCE.acquire()) {
Recipient recipient = Recipient.resolved(recipientId);
if (recipient.hasServiceId()) {
@ -144,7 +144,7 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
}
public void archiveSiblingSessions(@NonNull SignalProtocolAddress address) {
synchronized (LOCK) {
try (SignalSessionLock.Lock unused = ReentrantSessionLock.INSTANCE.acquire()) {
List<SessionTable.SessionRow> sessions = SignalDatabase.sessions().getAllFor(accountId, address.getName());
for (SessionTable.SessionRow row : sessions) {
@ -157,7 +157,7 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
}
public void archiveAllSessions() {
synchronized (LOCK) {
try (SignalSessionLock.Lock unused = ReentrantSessionLock.INSTANCE.acquire()) {
List<SessionTable.SessionRow> sessions = SignalDatabase.sessions().getAll(accountId);
for (SessionTable.SessionRow row : sessions) {