From d6adfea9b1fb427ff3c087def99cf97f32197734 Mon Sep 17 00:00:00 2001 From: Greyson Parrelli Date: Fri, 11 Aug 2023 12:38:03 -0400 Subject: [PATCH] Clean up old one-time prekeys. --- .../database/KyberPreKeyTableTest.kt | 176 ++++++++++++++++++ .../database/OneTimePreKeyTableTest.kt | 137 ++++++++++++++ .../securesms/testing/BobClient.kt | 8 +- .../securesms/crypto/PreKeyUtil.java | 14 +- .../crypto/storage/SignalKyberPreKeyStore.kt | 12 ++ .../SignalServiceAccountDataStoreImpl.java | 21 ++- .../crypto/storage/TextSecurePreKeyStore.java | 13 +- .../securesms/database/KyberPreKeyTable.kt | 49 +++++ .../securesms/database/OneTimePreKeyTable.kt | 44 +++++ .../helpers/SignalDatabaseMigrations.kt | 7 +- .../migration/V203_PreKeyStaleTimestamp.kt | 40 ++++ .../securesms/jobs/PreKeysSyncJob.kt | 1 + .../protocol/BufferedKyberPreKeyStore.kt | 8 + .../BufferedSignalServiceAccountDataStore.kt | 16 ++ .../PniAccountInitializationMigrationJob.java | 7 +- .../api/SignalServiceAccountDataStore.java | 1 + .../api/SignalServiceKyberPreKeyStore.kt | 10 + .../api/SignalServicePreKeyStore.kt | 19 ++ 18 files changed, 572 insertions(+), 11 deletions(-) create mode 100644 app/src/androidTest/java/org/thoughtcrime/securesms/database/KyberPreKeyTableTest.kt create mode 100644 app/src/androidTest/java/org/thoughtcrime/securesms/database/OneTimePreKeyTableTest.kt create mode 100644 app/src/main/java/org/thoughtcrime/securesms/database/helpers/migration/V203_PreKeyStaleTimestamp.kt create mode 100644 libsignal/service/src/main/java/org/whispersystems/signalservice/api/SignalServicePreKeyStore.kt diff --git a/app/src/androidTest/java/org/thoughtcrime/securesms/database/KyberPreKeyTableTest.kt b/app/src/androidTest/java/org/thoughtcrime/securesms/database/KyberPreKeyTableTest.kt new file mode 100644 index 0000000000..f5582006db --- /dev/null +++ b/app/src/androidTest/java/org/thoughtcrime/securesms/database/KyberPreKeyTableTest.kt @@ -0,0 +1,176 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.thoughtcrime.securesms.database + +import junit.framework.TestCase.assertEquals +import junit.framework.TestCase.assertNotNull +import junit.framework.TestCase.assertNull +import org.junit.Test +import org.signal.core.util.readToSingleObject +import org.signal.core.util.requireLongOrNull +import org.signal.core.util.select +import org.signal.core.util.update +import org.signal.libsignal.protocol.ecc.Curve +import org.signal.libsignal.protocol.kem.KEMKeyPair +import org.signal.libsignal.protocol.kem.KEMKeyType +import org.signal.libsignal.protocol.state.KyberPreKeyRecord +import org.whispersystems.signalservice.api.push.ServiceId +import org.whispersystems.signalservice.api.push.ServiceId.ACI +import org.whispersystems.signalservice.api.push.ServiceId.PNI +import java.util.UUID + +class KyberPreKeyTableTest { + + private val aci: ACI = ACI.from(UUID.randomUUID()) + private val pni: PNI = PNI.from(UUID.randomUUID()) + + @Test + fun markAllStaleIfNecessary_onlyUpdatesMatchingAccountAndZeroValues() { + insertTestRecord(aci, id = 1) + insertTestRecord(aci, id = 2) + insertTestRecord(aci, id = 3, staleTime = 42) + insertTestRecord(pni, id = 4) + + val now = System.currentTimeMillis() + SignalDatabase.kyberPreKeys.markAllStaleIfNecessary(aci, now) + + assertEquals(now, getStaleTime(aci, 1)) + assertEquals(now, getStaleTime(aci, 2)) + assertEquals(42L, getStaleTime(aci, 3)) + assertEquals(0L, getStaleTime(pni, 4)) + } + + @Test + fun deleteAllStaleBefore_deleteOldBeforeThreshold() { + insertTestRecord(aci, id = 1, staleTime = 10) + insertTestRecord(aci, id = 2, staleTime = 10) + insertTestRecord(aci, id = 3, staleTime = 10) + insertTestRecord(aci, id = 4, staleTime = 15) + insertTestRecord(aci, id = 5, staleTime = 0) + + SignalDatabase.oneTimePreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 0) + + assertNull(getStaleTime(aci, 1)) + assertNull(getStaleTime(aci, 2)) + assertNull(getStaleTime(aci, 3)) + assertNotNull(getStaleTime(aci, 4)) + assertNotNull(getStaleTime(aci, 5)) + } + + @Test + fun deleteAllStaleBefore_neverDeleteStaleOfZero() { + insertTestRecord(aci, id = 1, staleTime = 0) + insertTestRecord(aci, id = 2, staleTime = 0) + insertTestRecord(aci, id = 3, staleTime = 0) + insertTestRecord(aci, id = 4, staleTime = 0) + insertTestRecord(aci, id = 5, staleTime = 0) + + SignalDatabase.kyberPreKeys.deleteAllStaleBefore(aci, threshold = 10, minCount = 1) + + assertNotNull(getStaleTime(aci, 1)) + assertNotNull(getStaleTime(aci, 2)) + assertNotNull(getStaleTime(aci, 3)) + assertNotNull(getStaleTime(aci, 4)) + assertNotNull(getStaleTime(aci, 5)) + } + + @Test + fun deleteAllStaleBefore_respectMinCount() { + insertTestRecord(aci, id = 1, staleTime = 10) + insertTestRecord(aci, id = 2, staleTime = 10) + insertTestRecord(aci, id = 3, staleTime = 10) + insertTestRecord(aci, id = 4, staleTime = 10) + insertTestRecord(aci, id = 5, staleTime = 10) + + SignalDatabase.kyberPreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 3) + + assertNull(getStaleTime(aci, 1)) + assertNull(getStaleTime(aci, 2)) + assertNotNull(getStaleTime(aci, 3)) + assertNotNull(getStaleTime(aci, 4)) + assertNotNull(getStaleTime(aci, 5)) + } + + @Test + fun deleteAllStaleBefore_respectAccount() { + insertTestRecord(aci, id = 1, staleTime = 10) + insertTestRecord(aci, id = 2, staleTime = 10) + insertTestRecord(aci, id = 3, staleTime = 10) + + insertTestRecord(pni, id = 4, staleTime = 10) + insertTestRecord(pni, id = 5, staleTime = 10) + + SignalDatabase.kyberPreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 2) + + assertNull(getStaleTime(aci, 1)) + assertNotNull(getStaleTime(aci, 2)) + assertNotNull(getStaleTime(aci, 3)) + assertNotNull(getStaleTime(pni, 4)) + assertNotNull(getStaleTime(pni, 5)) + } + + @Test + fun deleteAllStaleBefore_ignoreLastResortForMinCount() { + insertTestRecord(aci, id = 1, staleTime = 10) + insertTestRecord(aci, id = 2, staleTime = 10) + insertTestRecord(aci, id = 3, staleTime = 10) + insertTestRecord(aci, id = 4, staleTime = 10) + insertTestRecord(aci, id = 5, staleTime = 10, lastResort = true) + + SignalDatabase.kyberPreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 3) + + assertNull(getStaleTime(aci, 1)) + assertNotNull(getStaleTime(aci, 2)) + assertNotNull(getStaleTime(aci, 3)) + assertNotNull(getStaleTime(aci, 4)) + assertNotNull(getStaleTime(aci, 5)) + } + + @Test + fun deleteAllStaleBefore_neverDeleteLastResort() { + insertTestRecord(aci, id = 1, staleTime = 10, lastResort = true) + insertTestRecord(aci, id = 2, staleTime = 10, lastResort = true) + insertTestRecord(aci, id = 3, staleTime = 10, lastResort = true) + + SignalDatabase.oneTimePreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 0) + + assertNotNull(getStaleTime(aci, 1)) + assertNotNull(getStaleTime(aci, 2)) + assertNotNull(getStaleTime(aci, 3)) + } + + private fun insertTestRecord(account: ServiceId, id: Int, staleTime: Long = 0, lastResort: Boolean = false) { + val kemKeyPair = KEMKeyPair.generate(KEMKeyType.KYBER_1024) + SignalDatabase.kyberPreKeys.insert( + serviceId = account, + keyId = id, + record = KyberPreKeyRecord( + id, + System.currentTimeMillis(), + kemKeyPair, + Curve.generateKeyPair().privateKey.calculateSignature(kemKeyPair.publicKey.serialize()) + ), + lastResort = lastResort + ) + + val count = SignalDatabase.rawDatabase + .update(KyberPreKeyTable.TABLE_NAME) + .values(KyberPreKeyTable.STALE_TIMESTAMP to staleTime) + .where("${KyberPreKeyTable.ACCOUNT_ID} = ? AND ${KyberPreKeyTable.KEY_ID} = $id", account) + .run() + + assertEquals(1, count) + } + + private fun getStaleTime(account: ServiceId, id: Int): Long? { + return SignalDatabase.rawDatabase + .select(KyberPreKeyTable.STALE_TIMESTAMP) + .from(KyberPreKeyTable.TABLE_NAME) + .where("${KyberPreKeyTable.ACCOUNT_ID} = ? AND ${KyberPreKeyTable.KEY_ID} = $id", account) + .run() + .readToSingleObject { it.requireLongOrNull(KyberPreKeyTable.STALE_TIMESTAMP) } + } +} diff --git a/app/src/androidTest/java/org/thoughtcrime/securesms/database/OneTimePreKeyTableTest.kt b/app/src/androidTest/java/org/thoughtcrime/securesms/database/OneTimePreKeyTableTest.kt new file mode 100644 index 0000000000..beaacf0d34 --- /dev/null +++ b/app/src/androidTest/java/org/thoughtcrime/securesms/database/OneTimePreKeyTableTest.kt @@ -0,0 +1,137 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.thoughtcrime.securesms.database + +import junit.framework.TestCase.assertEquals +import junit.framework.TestCase.assertNotNull +import junit.framework.TestCase.assertNull +import org.junit.Test +import org.signal.core.util.readToSingleObject +import org.signal.core.util.requireLongOrNull +import org.signal.core.util.select +import org.signal.core.util.update +import org.signal.libsignal.protocol.ecc.Curve +import org.signal.libsignal.protocol.state.PreKeyRecord +import org.whispersystems.signalservice.api.push.ServiceId +import org.whispersystems.signalservice.api.push.ServiceId.ACI +import org.whispersystems.signalservice.api.push.ServiceId.PNI +import java.util.UUID + +class OneTimePreKeyTableTest { + + private val aci: ACI = ACI.from(UUID.randomUUID()) + private val pni: PNI = PNI.from(UUID.randomUUID()) + + @Test + fun markAllStaleIfNecessary_onlyUpdatesMatchingAccountAndZeroValues() { + insertTestRecord(aci, id = 1) + insertTestRecord(aci, id = 2) + insertTestRecord(aci, id = 3, staleTime = 42) + insertTestRecord(pni, id = 4) + + val now = System.currentTimeMillis() + SignalDatabase.oneTimePreKeys.markAllStaleIfNecessary(aci, now) + + assertEquals(now, getStaleTime(aci, 1)) + assertEquals(now, getStaleTime(aci, 2)) + assertEquals(42L, getStaleTime(aci, 3)) + assertEquals(0L, getStaleTime(pni, 4)) + } + + @Test + fun deleteAllStaleBefore_deleteOldBeforeThreshold() { + insertTestRecord(aci, id = 1, staleTime = 10) + insertTestRecord(aci, id = 2, staleTime = 10) + insertTestRecord(aci, id = 3, staleTime = 10) + insertTestRecord(aci, id = 4, staleTime = 15) + insertTestRecord(aci, id = 5, staleTime = 0) + + SignalDatabase.oneTimePreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 0) + + assertNull(getStaleTime(aci, 1)) + assertNull(getStaleTime(aci, 2)) + assertNull(getStaleTime(aci, 3)) + assertNotNull(getStaleTime(aci, 4)) + assertNotNull(getStaleTime(aci, 5)) + } + + @Test + fun deleteAllStaleBefore_neverDeleteStaleOfZero() { + insertTestRecord(aci, id = 1, staleTime = 0) + insertTestRecord(aci, id = 2, staleTime = 0) + insertTestRecord(aci, id = 3, staleTime = 0) + insertTestRecord(aci, id = 4, staleTime = 0) + insertTestRecord(aci, id = 5, staleTime = 0) + + SignalDatabase.oneTimePreKeys.deleteAllStaleBefore(aci, threshold = 10, minCount = 0) + + assertNotNull(getStaleTime(aci, 1)) + assertNotNull(getStaleTime(aci, 2)) + assertNotNull(getStaleTime(aci, 3)) + assertNotNull(getStaleTime(aci, 4)) + assertNotNull(getStaleTime(aci, 5)) + } + + @Test + fun deleteAllStaleBefore_respectMinCount() { + insertTestRecord(aci, id = 1, staleTime = 10) + insertTestRecord(aci, id = 2, staleTime = 10) + insertTestRecord(aci, id = 3, staleTime = 10) + insertTestRecord(aci, id = 4, staleTime = 10) + insertTestRecord(aci, id = 5, staleTime = 10) + + SignalDatabase.oneTimePreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 3) + + assertNull(getStaleTime(aci, 1)) + assertNull(getStaleTime(aci, 2)) + assertNotNull(getStaleTime(aci, 3)) + assertNotNull(getStaleTime(aci, 4)) + assertNotNull(getStaleTime(aci, 5)) + } + + @Test + fun deleteAllStaleBefore_respectAccount() { + insertTestRecord(aci, id = 1, staleTime = 10) + insertTestRecord(aci, id = 2, staleTime = 10) + insertTestRecord(aci, id = 3, staleTime = 10) + + insertTestRecord(pni, id = 4, staleTime = 10) + insertTestRecord(pni, id = 5, staleTime = 10) + + SignalDatabase.oneTimePreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 2) + + assertNull(getStaleTime(aci, 1)) + assertNotNull(getStaleTime(aci, 2)) + assertNotNull(getStaleTime(aci, 3)) + assertNotNull(getStaleTime(pni, 4)) + assertNotNull(getStaleTime(pni, 5)) + } + + private fun insertTestRecord(account: ServiceId, id: Int, staleTime: Long = 0) { + SignalDatabase.oneTimePreKeys.insert( + serviceId = account, + keyId = id, + record = PreKeyRecord(id, Curve.generateKeyPair()) + ) + + val count = SignalDatabase.rawDatabase + .update(OneTimePreKeyTable.TABLE_NAME) + .values(OneTimePreKeyTable.STALE_TIMESTAMP to staleTime) + .where("${OneTimePreKeyTable.ACCOUNT_ID} = ? AND ${OneTimePreKeyTable.KEY_ID} = $id", account) + .run() + + assertEquals(1, count) + } + + private fun getStaleTime(account: ServiceId, id: Int): Long? { + return SignalDatabase.rawDatabase + .select(OneTimePreKeyTable.STALE_TIMESTAMP) + .from(OneTimePreKeyTable.TABLE_NAME) + .where("${OneTimePreKeyTable.ACCOUNT_ID} = ? AND ${OneTimePreKeyTable.KEY_ID} = $id", account) + .run() + .readToSingleObject { it.requireLongOrNull(OneTimePreKeyTable.STALE_TIMESTAMP) } + } +} diff --git a/app/src/androidTest/java/org/thoughtcrime/securesms/testing/BobClient.kt b/app/src/androidTest/java/org/thoughtcrime/securesms/testing/BobClient.kt index 954a248302..b5e5ab133a 100644 --- a/app/src/androidTest/java/org/thoughtcrime/securesms/testing/BobClient.kt +++ b/app/src/androidTest/java/org/thoughtcrime/securesms/testing/BobClient.kt @@ -32,10 +32,10 @@ import org.whispersystems.signalservice.api.push.DistributionId import org.whispersystems.signalservice.api.push.ServiceId import org.whispersystems.signalservice.api.push.SignalServiceAddress import org.whispersystems.signalservice.internal.push.SignalServiceProtos -import java.lang.UnsupportedOperationException import java.util.Optional import java.util.UUID import java.util.concurrent.locks.ReentrantLock +import kotlin.UnsupportedOperationException /** * Welcome to Bob's Client. @@ -144,7 +144,6 @@ class BobClient(val serviceId: ServiceId, val e164: String, val identityKeyPair: override fun getSubDeviceSessions(name: String?): List = emptyList() override fun containsSession(address: SignalProtocolAddress?): Boolean = aliceSessionRecord != null override fun getIdentity(address: SignalProtocolAddress?): IdentityKey = SignalStore.account().aciIdentityKey.publicKey - override fun loadPreKey(preKeyId: Int): PreKeyRecord = throw UnsupportedOperationException() override fun storePreKey(preKeyId: Int, record: PreKeyRecord?) = throw UnsupportedOperationException() override fun containsPreKey(preKeyId: Int): Boolean = throw UnsupportedOperationException() @@ -162,6 +161,8 @@ class BobClient(val serviceId: ServiceId, val e164: String, val identityKeyPair: override fun storeKyberPreKey(kyberPreKeyId: Int, record: KyberPreKeyRecord?) = throw UnsupportedOperationException() override fun containsKyberPreKey(kyberPreKeyId: Int): Boolean = throw UnsupportedOperationException() override fun markKyberPreKeyUsed(kyberPreKeyId: Int) = throw UnsupportedOperationException() + override fun deleteAllStaleOneTimeEcPreKeys(threshold: Long, minCount: Int) = throw UnsupportedOperationException() + override fun markAllOneTimeEcPreKeysStaleIfNecessary(staleTime: Long) = throw UnsupportedOperationException() override fun storeSenderKey(sender: SignalProtocolAddress?, distributionId: UUID?, record: SenderKeyRecord?) = throw UnsupportedOperationException() override fun loadSenderKey(sender: SignalProtocolAddress?, distributionId: UUID?): SenderKeyRecord = throw UnsupportedOperationException() override fun archiveSession(address: SignalProtocolAddress?) = throw UnsupportedOperationException() @@ -171,8 +172,9 @@ class BobClient(val serviceId: ServiceId, val e164: String, val identityKeyPair: override fun clearSenderKeySharedWith(addresses: MutableCollection?) = throw UnsupportedOperationException() override fun storeLastResortKyberPreKey(kyberPreKeyId: Int, kyberPreKeyRecord: KyberPreKeyRecord) = throw UnsupportedOperationException() override fun removeKyberPreKey(kyberPreKeyId: Int) = throw UnsupportedOperationException() + override fun markAllOneTimeKyberPreKeysStaleIfNecessary(staleTime: Long) = throw UnsupportedOperationException() + override fun deleteAllStaleOneTimeKyberPreKeys(threshold: Long, minCount: Int) = throw UnsupportedOperationException() override fun loadLastResortKyberPreKeys(): List = throw UnsupportedOperationException() - override fun isMultiDevice(): Boolean = throw UnsupportedOperationException() } } diff --git a/app/src/main/java/org/thoughtcrime/securesms/crypto/PreKeyUtil.java b/app/src/main/java/org/thoughtcrime/securesms/crypto/PreKeyUtil.java index 989e529de6..490bbdf105 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/crypto/PreKeyUtil.java +++ b/app/src/main/java/org/thoughtcrime/securesms/crypto/PreKeyUtil.java @@ -49,10 +49,11 @@ public class PreKeyUtil { private static final int BATCH_SIZE = 100; private static final long ARCHIVE_AGE = TimeUnit.DAYS.toMillis(30); - public synchronized static @NonNull List generateAndStoreOneTimeEcPreKeys(@NonNull SignalProtocolStore protocolStore, @NonNull PreKeyMetadataStore metadataStore) { + public synchronized static @NonNull List generateAndStoreOneTimeEcPreKeys(@NonNull SignalServiceAccountDataStore protocolStore, @NonNull PreKeyMetadataStore metadataStore) { int startingId = metadataStore.getNextEcOneTimePreKeyId(); final List records = generateOneTimeEcPreKeys(startingId); + protocolStore.markAllOneTimeEcPreKeysStaleIfNecessary(System.currentTimeMillis()); storeOneTimeEcPreKeys(protocolStore, metadataStore, records); return records; @@ -92,10 +93,11 @@ public class PreKeyUtil { } - public synchronized static @NonNull List generateAndStoreOneTimeKyberPreKeys(@NonNull SignalProtocolStore protocolStore, @NonNull PreKeyMetadataStore metadataStore) { + public synchronized static @NonNull List generateAndStoreOneTimeKyberPreKeys(@NonNull SignalServiceAccountDataStore protocolStore, @NonNull PreKeyMetadataStore metadataStore) { int startingId = metadataStore.getNextKyberPreKeyId(); List records = generateOneTimeKyberPreKeyRecords(startingId, protocolStore.getIdentityKeyPair().getPrivateKey()); + protocolStore.markAllOneTimeKyberPreKeysStaleIfNecessary(System.currentTimeMillis()); storeOneTimeKyberPreKeys(protocolStore, metadataStore, records); return records; @@ -264,4 +266,12 @@ public class PreKeyUtil { Log.w(TAG, e); } } + + public synchronized static void cleanOneTimePreKeys(@NonNull SignalServiceAccountDataStore protocolStore) { + long threshold = System.currentTimeMillis() - TimeUnit.DAYS.toMillis(90); + int minCount = 200; + + protocolStore.deleteAllStaleOneTimeEcPreKeys(threshold, minCount); + protocolStore.deleteAllStaleOneTimeKyberPreKeys(threshold, minCount); + } } diff --git a/app/src/main/java/org/thoughtcrime/securesms/crypto/storage/SignalKyberPreKeyStore.kt b/app/src/main/java/org/thoughtcrime/securesms/crypto/storage/SignalKyberPreKeyStore.kt index ba1de8e17b..c00f6b5092 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/crypto/storage/SignalKyberPreKeyStore.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/crypto/storage/SignalKyberPreKeyStore.kt @@ -67,4 +67,16 @@ class SignalKyberPreKeyStore(private val selfServiceId: ServiceId) : SignalServi SignalDatabase.kyberPreKeys.delete(selfServiceId, kyberPreKeyId) } } + + override fun markAllOneTimeKyberPreKeysStaleIfNecessary(staleTime: Long) { + ReentrantSessionLock.INSTANCE.acquire().use { + SignalDatabase.kyberPreKeys.markAllStaleIfNecessary(selfServiceId, staleTime) + } + } + + override fun deleteAllStaleOneTimeKyberPreKeys(threshold: Long, minCount: Int) { + ReentrantSessionLock.INSTANCE.acquire().use { + SignalDatabase.kyberPreKeys.deleteAllStaleBefore(selfServiceId, threshold, minCount) + } + } } diff --git a/app/src/main/java/org/thoughtcrime/securesms/crypto/storage/SignalServiceAccountDataStoreImpl.java b/app/src/main/java/org/thoughtcrime/securesms/crypto/storage/SignalServiceAccountDataStoreImpl.java index 97ba3202f6..d00f4c4149 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/crypto/storage/SignalServiceAccountDataStoreImpl.java +++ b/app/src/main/java/org/thoughtcrime/securesms/crypto/storage/SignalServiceAccountDataStoreImpl.java @@ -100,6 +100,16 @@ public class SignalServiceAccountDataStoreImpl implements SignalServiceAccountDa preKeyStore.removePreKey(preKeyId); } + @Override + public void markAllOneTimeEcPreKeysStaleIfNecessary(long staleTime) { + preKeyStore.markAllOneTimeEcPreKeysStaleIfNecessary(staleTime); + } + + @Override + public void deleteAllStaleOneTimeEcPreKeys(long threshold, int minCount) { + preKeyStore.deleteAllStaleOneTimeEcPreKeys(threshold, minCount); + } + @Override public SessionRecord loadSession(SignalProtocolAddress axolotlAddress) { return sessionStore.loadSession(axolotlAddress); @@ -211,6 +221,16 @@ public class SignalServiceAccountDataStoreImpl implements SignalServiceAccountDa kyberPreKeyStore.removeKyberPreKey(kyberPreKeyId); } + @Override + public void markAllOneTimeKyberPreKeysStaleIfNecessary(long staleTime) { + kyberPreKeyStore.markAllOneTimeKyberPreKeysStaleIfNecessary(staleTime); + } + + @Override + public void deleteAllStaleOneTimeKyberPreKeys(long threshold, int minCount) { + kyberPreKeyStore.deleteAllStaleOneTimeKyberPreKeys(threshold, minCount); + } + @Override public void storeSenderKey(SignalProtocolAddress sender, UUID distributionId, SenderKeyRecord record) { senderKeyStore.storeSenderKey(sender, distributionId, record); @@ -251,5 +271,4 @@ public class SignalServiceAccountDataStoreImpl implements SignalServiceAccountDa public @NonNull SignalSenderKeyStore senderKeys() { return senderKeyStore; } - } diff --git a/app/src/main/java/org/thoughtcrime/securesms/crypto/storage/TextSecurePreKeyStore.java b/app/src/main/java/org/thoughtcrime/securesms/crypto/storage/TextSecurePreKeyStore.java index 95c209f76f..05c1f54f3a 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/crypto/storage/TextSecurePreKeyStore.java +++ b/app/src/main/java/org/thoughtcrime/securesms/crypto/storage/TextSecurePreKeyStore.java @@ -10,12 +10,13 @@ import org.signal.libsignal.protocol.state.SignedPreKeyRecord; import org.signal.libsignal.protocol.state.SignedPreKeyStore; import org.thoughtcrime.securesms.crypto.ReentrantSessionLock; import org.thoughtcrime.securesms.database.SignalDatabase; +import org.whispersystems.signalservice.api.SignalServicePreKeyStore; import org.whispersystems.signalservice.api.SignalSessionLock; import org.whispersystems.signalservice.api.push.ServiceId; import java.util.List; -public class TextSecurePreKeyStore implements PreKeyStore, SignedPreKeyStore { +public class TextSecurePreKeyStore implements SignalServicePreKeyStore, SignedPreKeyStore { @SuppressWarnings("unused") private static final String TAG = Log.tag(TextSecurePreKeyStore.class); @@ -87,4 +88,14 @@ public class TextSecurePreKeyStore implements PreKeyStore, SignedPreKeyStore { public void removeSignedPreKey(int signedPreKeyId) { SignalDatabase.signedPreKeys().delete(accountId, signedPreKeyId); } + + @Override + public void markAllOneTimeEcPreKeysStaleIfNecessary(long staleTime) { + SignalDatabase.oneTimePreKeys().markAllStaleIfNecessary(accountId, staleTime); + } + + @Override + public void deleteAllStaleOneTimeEcPreKeys(long threshold, int minCount) { + SignalDatabase.oneTimePreKeys().deleteAllStaleBefore(accountId, threshold, minCount); + } } diff --git a/app/src/main/java/org/thoughtcrime/securesms/database/KyberPreKeyTable.kt b/app/src/main/java/org/thoughtcrime/securesms/database/KyberPreKeyTable.kt index dad16ae710..22da29809d 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/database/KyberPreKeyTable.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/database/KyberPreKeyTable.kt @@ -4,12 +4,14 @@ import android.content.Context import org.signal.core.util.delete import org.signal.core.util.exists import org.signal.core.util.insertInto +import org.signal.core.util.logging.Log import org.signal.core.util.readToList import org.signal.core.util.readToSingleObject import org.signal.core.util.requireBoolean import org.signal.core.util.requireNonNullBlob import org.signal.core.util.select import org.signal.core.util.toInt +import org.signal.core.util.update import org.signal.libsignal.protocol.state.KyberPreKeyRecord import org.whispersystems.signalservice.api.push.ServiceId @@ -18,6 +20,8 @@ import org.whispersystems.signalservice.api.push.ServiceId */ class KyberPreKeyTable(context: Context, databaseHelper: SignalDatabase) : DatabaseTable(context, databaseHelper) { companion object { + private val TAG = Log.tag(KyberPreKeyTable::class.java) + const val TABLE_NAME = "kyber_prekey" const val ID = "_id" const val ACCOUNT_ID = "account_id" @@ -25,6 +29,8 @@ class KyberPreKeyTable(context: Context, databaseHelper: SignalDatabase) : Datab const val TIMESTAMP = "timestamp" const val LAST_RESORT = "last_resort" const val SERIALIZED = "serialized" + const val STALE_TIMESTAMP = "stale_timestamp" + const val CREATE_TABLE = """ CREATE TABLE $TABLE_NAME ( $ID INTEGER PRIMARY KEY, @@ -33,6 +39,7 @@ class KyberPreKeyTable(context: Context, databaseHelper: SignalDatabase) : Datab $TIMESTAMP INTEGER NOT NULL, $LAST_RESORT INTEGER NOT NULL, $SERIALIZED BLOB NOT NULL, + $STALE_TIMESTAMP INTEGER NOT NULL DEFAULT 0, UNIQUE($ACCOUNT_ID, $KEY_ID) ) """ @@ -120,6 +127,48 @@ class KyberPreKeyTable(context: Context, databaseHelper: SignalDatabase) : Datab .run() } + fun markAllStaleIfNecessary(serviceId: ServiceId, staleTime: Long) { + writableDatabase + .update(TABLE_NAME) + .values(STALE_TIMESTAMP to staleTime) + .where("$ACCOUNT_ID = ? AND $STALE_TIMESTAMP = 0 AND $LAST_RESORT = 0", serviceId) + .run() + } + + /** + * Deletes all keys that have been stale since before the specified threshold. + * We will always keep at least [minCount] items, preferring more recent ones. + */ + fun deleteAllStaleBefore(serviceId: ServiceId, threshold: Long, minCount: Int) { + val count = writableDatabase + .delete(TABLE_NAME) + .where( + """ + $ACCOUNT_ID = ? + AND $LAST_RESORT = 0 + AND $STALE_TIMESTAMP > 0 + AND $STALE_TIMESTAMP < $threshold + AND $ID NOT IN ( + SELECT $ID + FROM $TABLE_NAME + WHERE + $ACCOUNT_ID = ? + AND $LAST_RESORT = 0 + ORDER BY + CASE $STALE_TIMESTAMP WHEN 0 THEN 1 ELSE 0 END DESC, + $STALE_TIMESTAMP DESC, + $ID DESC + LIMIT $minCount + ) + """, + serviceId, + serviceId + ) + .run() + + Log.i(TAG, "Deleted $count stale one-time EC prekeys.") + } + data class KyberPreKey( val record: KyberPreKeyRecord, val lastResort: Boolean diff --git a/app/src/main/java/org/thoughtcrime/securesms/database/OneTimePreKeyTable.kt b/app/src/main/java/org/thoughtcrime/securesms/database/OneTimePreKeyTable.kt index 5989f1cf43..3c934d5c4a 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/database/OneTimePreKeyTable.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/database/OneTimePreKeyTable.kt @@ -3,8 +3,10 @@ package org.thoughtcrime.securesms.database import android.content.Context import androidx.core.content.contentValuesOf import org.signal.core.util.SqlUtil +import org.signal.core.util.delete import org.signal.core.util.logging.Log import org.signal.core.util.requireNonNullString +import org.signal.core.util.update import org.signal.libsignal.protocol.InvalidKeyException import org.signal.libsignal.protocol.ecc.Curve import org.signal.libsignal.protocol.ecc.ECKeyPair @@ -23,6 +25,8 @@ class OneTimePreKeyTable(context: Context, databaseHelper: SignalDatabase) : Dat const val KEY_ID = "key_id" const val PUBLIC_KEY = "public_key" const val PRIVATE_KEY = "private_key" + const val STALE_TIMESTAMP = "stale_timestamp" + const val CREATE_TABLE = """ CREATE TABLE $TABLE_NAME ( $ID INTEGER PRIMARY KEY, @@ -30,6 +34,7 @@ class OneTimePreKeyTable(context: Context, databaseHelper: SignalDatabase) : Dat $KEY_ID INTEGER UNIQUE, $PUBLIC_KEY TEXT NOT NULL, $PRIVATE_KEY TEXT NOT NULL, + $STALE_TIMESTAMP INTEGER NOT NULL DEFAULT 0, UNIQUE($ACCOUNT_ID, $KEY_ID) ) """ @@ -68,4 +73,43 @@ class OneTimePreKeyTable(context: Context, databaseHelper: SignalDatabase) : Dat val database = databaseHelper.signalWritableDatabase database.delete(TABLE_NAME, "$ACCOUNT_ID = ? AND $KEY_ID = ?", SqlUtil.buildArgs(serviceId, keyId)) } + + fun markAllStaleIfNecessary(serviceId: ServiceId, staleTime: Long) { + writableDatabase + .update(TABLE_NAME) + .values(STALE_TIMESTAMP to staleTime) + .where("$ACCOUNT_ID = ? AND $STALE_TIMESTAMP = 0", serviceId) + .run() + } + + /** + * Deletes all keys that have been stale since before the specified threshold. + * We will always keep at least [minCount] items, preferring more recent ones. + */ + fun deleteAllStaleBefore(serviceId: ServiceId, threshold: Long, minCount: Int) { + val count = writableDatabase + .delete(TABLE_NAME) + .where( + """ + $ACCOUNT_ID = ? + AND $STALE_TIMESTAMP > 0 + AND $STALE_TIMESTAMP < $threshold + AND $ID NOT IN ( + SELECT $ID + FROM $TABLE_NAME + WHERE $ACCOUNT_ID = ? + ORDER BY + CASE $STALE_TIMESTAMP WHEN 0 THEN 1 ELSE 0 END DESC, + $STALE_TIMESTAMP DESC, + $ID DESC + LIMIT $minCount + ) + """, + serviceId, + serviceId + ) + .run() + + Log.i(TAG, "Deleted $count stale one-time EC prekeys.") + } } diff --git a/app/src/main/java/org/thoughtcrime/securesms/database/helpers/SignalDatabaseMigrations.kt b/app/src/main/java/org/thoughtcrime/securesms/database/helpers/SignalDatabaseMigrations.kt index 3f2b6abcf4..ad72f536d3 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/database/helpers/SignalDatabaseMigrations.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/database/helpers/SignalDatabaseMigrations.kt @@ -58,6 +58,7 @@ import org.thoughtcrime.securesms.database.helpers.migration.V199_AddThreadActiv import org.thoughtcrime.securesms.database.helpers.migration.V200_ResetPniColumn import org.thoughtcrime.securesms.database.helpers.migration.V201_RecipientTableValidations import org.thoughtcrime.securesms.database.helpers.migration.V202_DropMessageTableThreadDateIndex +import org.thoughtcrime.securesms.database.helpers.migration.V203_PreKeyStaleTimestamp /** * Contains all of the database migrations for [SignalDatabase]. Broken into a separate file for cleanliness. @@ -66,7 +67,7 @@ object SignalDatabaseMigrations { val TAG: String = Log.tag(SignalDatabaseMigrations.javaClass) - const val DATABASE_VERSION = 202 + const val DATABASE_VERSION = 203 @JvmStatic fun migrate(context: Application, db: SQLiteDatabase, oldVersion: Int, newVersion: Int) { @@ -285,6 +286,10 @@ object SignalDatabaseMigrations { if (oldVersion < 202) { V202_DropMessageTableThreadDateIndex.migrate(context, db, oldVersion, newVersion) } + + if (oldVersion < 203) { + V203_PreKeyStaleTimestamp.migrate(context, db, oldVersion, newVersion) + } } @JvmStatic diff --git a/app/src/main/java/org/thoughtcrime/securesms/database/helpers/migration/V203_PreKeyStaleTimestamp.kt b/app/src/main/java/org/thoughtcrime/securesms/database/helpers/migration/V203_PreKeyStaleTimestamp.kt new file mode 100644 index 0000000000..f5d240a2e8 --- /dev/null +++ b/app/src/main/java/org/thoughtcrime/securesms/database/helpers/migration/V203_PreKeyStaleTimestamp.kt @@ -0,0 +1,40 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.thoughtcrime.securesms.database.helpers.migration + +import android.app.Application +import androidx.sqlite.db.SupportSQLiteDatabase +import net.zetetic.database.sqlcipher.SQLiteDatabase + +/** + * Keep track of a "stale timestamp" for one-time prekeys so that we can know when it's safe to delete them. + */ +@Suppress("ClassName") +object V203_PreKeyStaleTimestamp : SignalDatabaseMigration { + override fun migrate(context: Application, db: SQLiteDatabase, oldVersion: Int, newVersion: Int) { + // Note: Because of a sequencing issue between beta/nightly, we had two V202 migrations (of which this used to be one of them), + // so we have to do some conditional migrating based on the user's current state. + db.execSQL("DROP INDEX IF EXISTS message_thread_date_index") + + if (!columnExists(db, "one_time_prekeys", "stale_timestamp")) { + db.execSQL("ALTER TABLE one_time_prekeys ADD COLUMN stale_timestamp INTEGER NOT NULL DEFAULT 0") + db.execSQL("ALTER TABLE kyber_prekey ADD COLUMN stale_timestamp INTEGER NOT NULL DEFAULT 0") + } + } + + private fun columnExists(db: SupportSQLiteDatabase, table: String, column: String): Boolean { + db.query("PRAGMA table_info($table)", null).use { cursor -> + val nameColumnIndex = cursor.getColumnIndexOrThrow("name") + while (cursor.moveToNext()) { + val name = cursor.getString(nameColumnIndex) + if (name == column) { + return true + } + } + } + return false + } +} diff --git a/app/src/main/java/org/thoughtcrime/securesms/jobs/PreKeysSyncJob.kt b/app/src/main/java/org/thoughtcrime/securesms/jobs/PreKeysSyncJob.kt index 777681c85b..a80402eca3 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/jobs/PreKeysSyncJob.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/jobs/PreKeysSyncJob.kt @@ -181,6 +181,7 @@ class PreKeysSyncJob private constructor(parameters: Parameters) : BaseJob(param log(serviceIdType, "Cleaning prekeys...") PreKeyUtil.cleanSignedPreKeys(protocolStore, metadataStore) PreKeyUtil.cleanLastResortKyberPreKeys(protocolStore, metadataStore) + PreKeyUtil.cleanOneTimePreKeys(protocolStore) } private fun signedPreKeyUploadIfNeeded(serviceIdType: ServiceIdType, protocolStore: SignalProtocolStore, metadataStore: PreKeyMetadataStore): SignedPreKeyRecord? { diff --git a/app/src/main/java/org/thoughtcrime/securesms/messages/protocol/BufferedKyberPreKeyStore.kt b/app/src/main/java/org/thoughtcrime/securesms/messages/protocol/BufferedKyberPreKeyStore.kt index 9082680219..ea6617cd68 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/messages/protocol/BufferedKyberPreKeyStore.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/messages/protocol/BufferedKyberPreKeyStore.kt @@ -79,6 +79,14 @@ class BufferedKyberPreKeyStore(private val selfServiceId: ServiceId) : SignalSer error("Not expected in this flow") } + override fun markAllOneTimeKyberPreKeysStaleIfNecessary(staleTime: Long) { + error("Not expected in this flow") + } + + override fun deleteAllStaleOneTimeKyberPreKeys(threshold: Long, minCount: Int) { + error("Not expected in this flow") + } + fun flushToDisk(persistentStore: SignalServiceAccountDataStore) { for (id in removedIfNotLastResort) { persistentStore.markKyberPreKeyUsed(id) diff --git a/app/src/main/java/org/thoughtcrime/securesms/messages/protocol/BufferedSignalServiceAccountDataStore.kt b/app/src/main/java/org/thoughtcrime/securesms/messages/protocol/BufferedSignalServiceAccountDataStore.kt index b875ff88ff..b00bb91527 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/messages/protocol/BufferedSignalServiceAccountDataStore.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/messages/protocol/BufferedSignalServiceAccountDataStore.kt @@ -141,10 +141,26 @@ class BufferedSignalServiceAccountDataStore(selfServiceId: ServiceId) : SignalSe return kyberPreKeyStore.markKyberPreKeyUsed(kyberPreKeyId) } + override fun deleteAllStaleOneTimeEcPreKeys(threshold: Long, minCount: Int) { + error("Should not happen during the intended usage pattern of this class") + } + + override fun markAllOneTimeEcPreKeysStaleIfNecessary(staleTime: Long) { + error("Should not happen during the intended usage pattern of this class") + } + override fun removeKyberPreKey(kyberPreKeyId: Int) { kyberPreKeyStore.removeKyberPreKey(kyberPreKeyId) } + override fun markAllOneTimeKyberPreKeysStaleIfNecessary(staleTime: Long) { + kyberPreKeyStore.markAllOneTimeKyberPreKeysStaleIfNecessary(staleTime) + } + + override fun deleteAllStaleOneTimeKyberPreKeys(threshold: Long, minCount: Int) { + kyberPreKeyStore.deleteAllStaleOneTimeKyberPreKeys(threshold, minCount) + } + override fun loadLastResortKyberPreKeys(): List { return kyberPreKeyStore.loadLastResortKyberPreKeys() } diff --git a/app/src/main/java/org/thoughtcrime/securesms/migrations/PniAccountInitializationMigrationJob.java b/app/src/main/java/org/thoughtcrime/securesms/migrations/PniAccountInitializationMigrationJob.java index 2b790eca84..fa933b61c0 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/migrations/PniAccountInitializationMigrationJob.java +++ b/app/src/main/java/org/thoughtcrime/securesms/migrations/PniAccountInitializationMigrationJob.java @@ -14,6 +14,7 @@ import org.thoughtcrime.securesms.jobmanager.Job; import org.thoughtcrime.securesms.jobmanager.impl.NetworkConstraint; import org.thoughtcrime.securesms.keyvalue.SignalStore; import org.thoughtcrime.securesms.recipients.Recipient; +import org.whispersystems.signalservice.api.SignalServiceAccountDataStore; import org.whispersystems.signalservice.api.SignalServiceAccountManager; import org.whispersystems.signalservice.api.account.PreKeyUpload; import org.whispersystems.signalservice.api.push.ServiceId.PNI; @@ -70,9 +71,9 @@ public class PniAccountInitializationMigrationJob extends MigrationJob { Log.w(TAG, "Already generated the PNI identity. Skipping this step."); } - SignalServiceAccountManager accountManager = ApplicationDependencies.getSignalServiceAccountManager(); - SignalProtocolStore protocolStore = ApplicationDependencies.getProtocolStore().pni(); - PreKeyMetadataStore metadataStore = SignalStore.account().pniPreKeys(); + SignalServiceAccountManager accountManager = ApplicationDependencies.getSignalServiceAccountManager(); + SignalServiceAccountDataStore protocolStore = ApplicationDependencies.getProtocolStore().pni(); + PreKeyMetadataStore metadataStore = SignalStore.account().pniPreKeys(); if (!metadataStore.isSignedPreKeyRegistered()) { Log.i(TAG, "Uploading signed prekey for PNI."); diff --git a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/SignalServiceAccountDataStore.java b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/SignalServiceAccountDataStore.java index af5ce04932..7ec77b399b 100644 --- a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/SignalServiceAccountDataStore.java +++ b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/SignalServiceAccountDataStore.java @@ -7,6 +7,7 @@ import org.signal.libsignal.protocol.state.SignalProtocolStore; * in the service layer, but not the protocol layer. */ public interface SignalServiceAccountDataStore extends SignalProtocolStore, + SignalServicePreKeyStore, SignalServiceSessionStore, SignalServiceSenderKeyStore, SignalServiceKyberPreKeyStore { diff --git a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/SignalServiceKyberPreKeyStore.kt b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/SignalServiceKyberPreKeyStore.kt index 17aba70e5b..a720db1a90 100644 --- a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/SignalServiceKyberPreKeyStore.kt +++ b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/SignalServiceKyberPreKeyStore.kt @@ -23,4 +23,14 @@ interface SignalServiceKyberPreKeyStore : KyberPreKeyStore { * Unconditionally remove the specified key from the store. */ fun removeKyberPreKey(kyberPreKeyId: Int) + + /** + * Marks all prekeys stale if they haven't been marked already. "Stale" means the time that the keys have been replaced. + */ + fun markAllOneTimeKyberPreKeysStaleIfNecessary(staleTime: Long) + + /** + * Deletes all prekeys that have been stale since before the threshold. "Stale" means the time that the keys have been replaced. + */ + fun deleteAllStaleOneTimeKyberPreKeys(threshold: Long, minCount: Int) } diff --git a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/SignalServicePreKeyStore.kt b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/SignalServicePreKeyStore.kt new file mode 100644 index 0000000000..55f57dfff1 --- /dev/null +++ b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/SignalServicePreKeyStore.kt @@ -0,0 +1,19 @@ +package org.whispersystems.signalservice.api + +import org.signal.libsignal.protocol.state.PreKeyStore + +/** + * And extension of the normal protocol prekey store interface that has additional methods that are + * needed in the service layer, but not the protocol layer. + */ +interface SignalServicePreKeyStore : PreKeyStore { + /** + * Marks all prekeys stale if they haven't been marked already. "Stale" means the time that the keys have been replaced. + */ + fun markAllOneTimeEcPreKeysStaleIfNecessary(staleTime: Long) + + /** + * Deletes all prekeys that have been stale since before the threshold. "Stale" means the time that the keys have been replaced. + */ + fun deleteAllStaleOneTimeEcPreKeys(threshold: Long, minCount: Int) +}