Clean up old one-time prekeys.
This commit is contained in:
parent
389b439e9a
commit
d6adfea9b1
18 changed files with 572 additions and 11 deletions
|
@ -0,0 +1,176 @@
|
|||
/*
|
||||
* Copyright 2023 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.thoughtcrime.securesms.database
|
||||
|
||||
import junit.framework.TestCase.assertEquals
|
||||
import junit.framework.TestCase.assertNotNull
|
||||
import junit.framework.TestCase.assertNull
|
||||
import org.junit.Test
|
||||
import org.signal.core.util.readToSingleObject
|
||||
import org.signal.core.util.requireLongOrNull
|
||||
import org.signal.core.util.select
|
||||
import org.signal.core.util.update
|
||||
import org.signal.libsignal.protocol.ecc.Curve
|
||||
import org.signal.libsignal.protocol.kem.KEMKeyPair
|
||||
import org.signal.libsignal.protocol.kem.KEMKeyType
|
||||
import org.signal.libsignal.protocol.state.KyberPreKeyRecord
|
||||
import org.whispersystems.signalservice.api.push.ServiceId
|
||||
import org.whispersystems.signalservice.api.push.ServiceId.ACI
|
||||
import org.whispersystems.signalservice.api.push.ServiceId.PNI
|
||||
import java.util.UUID
|
||||
|
||||
class KyberPreKeyTableTest {
|
||||
|
||||
private val aci: ACI = ACI.from(UUID.randomUUID())
|
||||
private val pni: PNI = PNI.from(UUID.randomUUID())
|
||||
|
||||
@Test
|
||||
fun markAllStaleIfNecessary_onlyUpdatesMatchingAccountAndZeroValues() {
|
||||
insertTestRecord(aci, id = 1)
|
||||
insertTestRecord(aci, id = 2)
|
||||
insertTestRecord(aci, id = 3, staleTime = 42)
|
||||
insertTestRecord(pni, id = 4)
|
||||
|
||||
val now = System.currentTimeMillis()
|
||||
SignalDatabase.kyberPreKeys.markAllStaleIfNecessary(aci, now)
|
||||
|
||||
assertEquals(now, getStaleTime(aci, 1))
|
||||
assertEquals(now, getStaleTime(aci, 2))
|
||||
assertEquals(42L, getStaleTime(aci, 3))
|
||||
assertEquals(0L, getStaleTime(pni, 4))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun deleteAllStaleBefore_deleteOldBeforeThreshold() {
|
||||
insertTestRecord(aci, id = 1, staleTime = 10)
|
||||
insertTestRecord(aci, id = 2, staleTime = 10)
|
||||
insertTestRecord(aci, id = 3, staleTime = 10)
|
||||
insertTestRecord(aci, id = 4, staleTime = 15)
|
||||
insertTestRecord(aci, id = 5, staleTime = 0)
|
||||
|
||||
SignalDatabase.oneTimePreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 0)
|
||||
|
||||
assertNull(getStaleTime(aci, 1))
|
||||
assertNull(getStaleTime(aci, 2))
|
||||
assertNull(getStaleTime(aci, 3))
|
||||
assertNotNull(getStaleTime(aci, 4))
|
||||
assertNotNull(getStaleTime(aci, 5))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun deleteAllStaleBefore_neverDeleteStaleOfZero() {
|
||||
insertTestRecord(aci, id = 1, staleTime = 0)
|
||||
insertTestRecord(aci, id = 2, staleTime = 0)
|
||||
insertTestRecord(aci, id = 3, staleTime = 0)
|
||||
insertTestRecord(aci, id = 4, staleTime = 0)
|
||||
insertTestRecord(aci, id = 5, staleTime = 0)
|
||||
|
||||
SignalDatabase.kyberPreKeys.deleteAllStaleBefore(aci, threshold = 10, minCount = 1)
|
||||
|
||||
assertNotNull(getStaleTime(aci, 1))
|
||||
assertNotNull(getStaleTime(aci, 2))
|
||||
assertNotNull(getStaleTime(aci, 3))
|
||||
assertNotNull(getStaleTime(aci, 4))
|
||||
assertNotNull(getStaleTime(aci, 5))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun deleteAllStaleBefore_respectMinCount() {
|
||||
insertTestRecord(aci, id = 1, staleTime = 10)
|
||||
insertTestRecord(aci, id = 2, staleTime = 10)
|
||||
insertTestRecord(aci, id = 3, staleTime = 10)
|
||||
insertTestRecord(aci, id = 4, staleTime = 10)
|
||||
insertTestRecord(aci, id = 5, staleTime = 10)
|
||||
|
||||
SignalDatabase.kyberPreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 3)
|
||||
|
||||
assertNull(getStaleTime(aci, 1))
|
||||
assertNull(getStaleTime(aci, 2))
|
||||
assertNotNull(getStaleTime(aci, 3))
|
||||
assertNotNull(getStaleTime(aci, 4))
|
||||
assertNotNull(getStaleTime(aci, 5))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun deleteAllStaleBefore_respectAccount() {
|
||||
insertTestRecord(aci, id = 1, staleTime = 10)
|
||||
insertTestRecord(aci, id = 2, staleTime = 10)
|
||||
insertTestRecord(aci, id = 3, staleTime = 10)
|
||||
|
||||
insertTestRecord(pni, id = 4, staleTime = 10)
|
||||
insertTestRecord(pni, id = 5, staleTime = 10)
|
||||
|
||||
SignalDatabase.kyberPreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 2)
|
||||
|
||||
assertNull(getStaleTime(aci, 1))
|
||||
assertNotNull(getStaleTime(aci, 2))
|
||||
assertNotNull(getStaleTime(aci, 3))
|
||||
assertNotNull(getStaleTime(pni, 4))
|
||||
assertNotNull(getStaleTime(pni, 5))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun deleteAllStaleBefore_ignoreLastResortForMinCount() {
|
||||
insertTestRecord(aci, id = 1, staleTime = 10)
|
||||
insertTestRecord(aci, id = 2, staleTime = 10)
|
||||
insertTestRecord(aci, id = 3, staleTime = 10)
|
||||
insertTestRecord(aci, id = 4, staleTime = 10)
|
||||
insertTestRecord(aci, id = 5, staleTime = 10, lastResort = true)
|
||||
|
||||
SignalDatabase.kyberPreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 3)
|
||||
|
||||
assertNull(getStaleTime(aci, 1))
|
||||
assertNotNull(getStaleTime(aci, 2))
|
||||
assertNotNull(getStaleTime(aci, 3))
|
||||
assertNotNull(getStaleTime(aci, 4))
|
||||
assertNotNull(getStaleTime(aci, 5))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun deleteAllStaleBefore_neverDeleteLastResort() {
|
||||
insertTestRecord(aci, id = 1, staleTime = 10, lastResort = true)
|
||||
insertTestRecord(aci, id = 2, staleTime = 10, lastResort = true)
|
||||
insertTestRecord(aci, id = 3, staleTime = 10, lastResort = true)
|
||||
|
||||
SignalDatabase.oneTimePreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 0)
|
||||
|
||||
assertNotNull(getStaleTime(aci, 1))
|
||||
assertNotNull(getStaleTime(aci, 2))
|
||||
assertNotNull(getStaleTime(aci, 3))
|
||||
}
|
||||
|
||||
private fun insertTestRecord(account: ServiceId, id: Int, staleTime: Long = 0, lastResort: Boolean = false) {
|
||||
val kemKeyPair = KEMKeyPair.generate(KEMKeyType.KYBER_1024)
|
||||
SignalDatabase.kyberPreKeys.insert(
|
||||
serviceId = account,
|
||||
keyId = id,
|
||||
record = KyberPreKeyRecord(
|
||||
id,
|
||||
System.currentTimeMillis(),
|
||||
kemKeyPair,
|
||||
Curve.generateKeyPair().privateKey.calculateSignature(kemKeyPair.publicKey.serialize())
|
||||
),
|
||||
lastResort = lastResort
|
||||
)
|
||||
|
||||
val count = SignalDatabase.rawDatabase
|
||||
.update(KyberPreKeyTable.TABLE_NAME)
|
||||
.values(KyberPreKeyTable.STALE_TIMESTAMP to staleTime)
|
||||
.where("${KyberPreKeyTable.ACCOUNT_ID} = ? AND ${KyberPreKeyTable.KEY_ID} = $id", account)
|
||||
.run()
|
||||
|
||||
assertEquals(1, count)
|
||||
}
|
||||
|
||||
private fun getStaleTime(account: ServiceId, id: Int): Long? {
|
||||
return SignalDatabase.rawDatabase
|
||||
.select(KyberPreKeyTable.STALE_TIMESTAMP)
|
||||
.from(KyberPreKeyTable.TABLE_NAME)
|
||||
.where("${KyberPreKeyTable.ACCOUNT_ID} = ? AND ${KyberPreKeyTable.KEY_ID} = $id", account)
|
||||
.run()
|
||||
.readToSingleObject { it.requireLongOrNull(KyberPreKeyTable.STALE_TIMESTAMP) }
|
||||
}
|
||||
}
|
|
@ -0,0 +1,137 @@
|
|||
/*
|
||||
* Copyright 2023 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.thoughtcrime.securesms.database
|
||||
|
||||
import junit.framework.TestCase.assertEquals
|
||||
import junit.framework.TestCase.assertNotNull
|
||||
import junit.framework.TestCase.assertNull
|
||||
import org.junit.Test
|
||||
import org.signal.core.util.readToSingleObject
|
||||
import org.signal.core.util.requireLongOrNull
|
||||
import org.signal.core.util.select
|
||||
import org.signal.core.util.update
|
||||
import org.signal.libsignal.protocol.ecc.Curve
|
||||
import org.signal.libsignal.protocol.state.PreKeyRecord
|
||||
import org.whispersystems.signalservice.api.push.ServiceId
|
||||
import org.whispersystems.signalservice.api.push.ServiceId.ACI
|
||||
import org.whispersystems.signalservice.api.push.ServiceId.PNI
|
||||
import java.util.UUID
|
||||
|
||||
class OneTimePreKeyTableTest {
|
||||
|
||||
private val aci: ACI = ACI.from(UUID.randomUUID())
|
||||
private val pni: PNI = PNI.from(UUID.randomUUID())
|
||||
|
||||
@Test
|
||||
fun markAllStaleIfNecessary_onlyUpdatesMatchingAccountAndZeroValues() {
|
||||
insertTestRecord(aci, id = 1)
|
||||
insertTestRecord(aci, id = 2)
|
||||
insertTestRecord(aci, id = 3, staleTime = 42)
|
||||
insertTestRecord(pni, id = 4)
|
||||
|
||||
val now = System.currentTimeMillis()
|
||||
SignalDatabase.oneTimePreKeys.markAllStaleIfNecessary(aci, now)
|
||||
|
||||
assertEquals(now, getStaleTime(aci, 1))
|
||||
assertEquals(now, getStaleTime(aci, 2))
|
||||
assertEquals(42L, getStaleTime(aci, 3))
|
||||
assertEquals(0L, getStaleTime(pni, 4))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun deleteAllStaleBefore_deleteOldBeforeThreshold() {
|
||||
insertTestRecord(aci, id = 1, staleTime = 10)
|
||||
insertTestRecord(aci, id = 2, staleTime = 10)
|
||||
insertTestRecord(aci, id = 3, staleTime = 10)
|
||||
insertTestRecord(aci, id = 4, staleTime = 15)
|
||||
insertTestRecord(aci, id = 5, staleTime = 0)
|
||||
|
||||
SignalDatabase.oneTimePreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 0)
|
||||
|
||||
assertNull(getStaleTime(aci, 1))
|
||||
assertNull(getStaleTime(aci, 2))
|
||||
assertNull(getStaleTime(aci, 3))
|
||||
assertNotNull(getStaleTime(aci, 4))
|
||||
assertNotNull(getStaleTime(aci, 5))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun deleteAllStaleBefore_neverDeleteStaleOfZero() {
|
||||
insertTestRecord(aci, id = 1, staleTime = 0)
|
||||
insertTestRecord(aci, id = 2, staleTime = 0)
|
||||
insertTestRecord(aci, id = 3, staleTime = 0)
|
||||
insertTestRecord(aci, id = 4, staleTime = 0)
|
||||
insertTestRecord(aci, id = 5, staleTime = 0)
|
||||
|
||||
SignalDatabase.oneTimePreKeys.deleteAllStaleBefore(aci, threshold = 10, minCount = 0)
|
||||
|
||||
assertNotNull(getStaleTime(aci, 1))
|
||||
assertNotNull(getStaleTime(aci, 2))
|
||||
assertNotNull(getStaleTime(aci, 3))
|
||||
assertNotNull(getStaleTime(aci, 4))
|
||||
assertNotNull(getStaleTime(aci, 5))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun deleteAllStaleBefore_respectMinCount() {
|
||||
insertTestRecord(aci, id = 1, staleTime = 10)
|
||||
insertTestRecord(aci, id = 2, staleTime = 10)
|
||||
insertTestRecord(aci, id = 3, staleTime = 10)
|
||||
insertTestRecord(aci, id = 4, staleTime = 10)
|
||||
insertTestRecord(aci, id = 5, staleTime = 10)
|
||||
|
||||
SignalDatabase.oneTimePreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 3)
|
||||
|
||||
assertNull(getStaleTime(aci, 1))
|
||||
assertNull(getStaleTime(aci, 2))
|
||||
assertNotNull(getStaleTime(aci, 3))
|
||||
assertNotNull(getStaleTime(aci, 4))
|
||||
assertNotNull(getStaleTime(aci, 5))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun deleteAllStaleBefore_respectAccount() {
|
||||
insertTestRecord(aci, id = 1, staleTime = 10)
|
||||
insertTestRecord(aci, id = 2, staleTime = 10)
|
||||
insertTestRecord(aci, id = 3, staleTime = 10)
|
||||
|
||||
insertTestRecord(pni, id = 4, staleTime = 10)
|
||||
insertTestRecord(pni, id = 5, staleTime = 10)
|
||||
|
||||
SignalDatabase.oneTimePreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 2)
|
||||
|
||||
assertNull(getStaleTime(aci, 1))
|
||||
assertNotNull(getStaleTime(aci, 2))
|
||||
assertNotNull(getStaleTime(aci, 3))
|
||||
assertNotNull(getStaleTime(pni, 4))
|
||||
assertNotNull(getStaleTime(pni, 5))
|
||||
}
|
||||
|
||||
private fun insertTestRecord(account: ServiceId, id: Int, staleTime: Long = 0) {
|
||||
SignalDatabase.oneTimePreKeys.insert(
|
||||
serviceId = account,
|
||||
keyId = id,
|
||||
record = PreKeyRecord(id, Curve.generateKeyPair())
|
||||
)
|
||||
|
||||
val count = SignalDatabase.rawDatabase
|
||||
.update(OneTimePreKeyTable.TABLE_NAME)
|
||||
.values(OneTimePreKeyTable.STALE_TIMESTAMP to staleTime)
|
||||
.where("${OneTimePreKeyTable.ACCOUNT_ID} = ? AND ${OneTimePreKeyTable.KEY_ID} = $id", account)
|
||||
.run()
|
||||
|
||||
assertEquals(1, count)
|
||||
}
|
||||
|
||||
private fun getStaleTime(account: ServiceId, id: Int): Long? {
|
||||
return SignalDatabase.rawDatabase
|
||||
.select(OneTimePreKeyTable.STALE_TIMESTAMP)
|
||||
.from(OneTimePreKeyTable.TABLE_NAME)
|
||||
.where("${OneTimePreKeyTable.ACCOUNT_ID} = ? AND ${OneTimePreKeyTable.KEY_ID} = $id", account)
|
||||
.run()
|
||||
.readToSingleObject { it.requireLongOrNull(OneTimePreKeyTable.STALE_TIMESTAMP) }
|
||||
}
|
||||
}
|
|
@ -32,10 +32,10 @@ import org.whispersystems.signalservice.api.push.DistributionId
|
|||
import org.whispersystems.signalservice.api.push.ServiceId
|
||||
import org.whispersystems.signalservice.api.push.SignalServiceAddress
|
||||
import org.whispersystems.signalservice.internal.push.SignalServiceProtos
|
||||
import java.lang.UnsupportedOperationException
|
||||
import java.util.Optional
|
||||
import java.util.UUID
|
||||
import java.util.concurrent.locks.ReentrantLock
|
||||
import kotlin.UnsupportedOperationException
|
||||
|
||||
/**
|
||||
* Welcome to Bob's Client.
|
||||
|
@ -144,7 +144,6 @@ class BobClient(val serviceId: ServiceId, val e164: String, val identityKeyPair:
|
|||
override fun getSubDeviceSessions(name: String?): List<Int> = emptyList()
|
||||
override fun containsSession(address: SignalProtocolAddress?): Boolean = aliceSessionRecord != null
|
||||
override fun getIdentity(address: SignalProtocolAddress?): IdentityKey = SignalStore.account().aciIdentityKey.publicKey
|
||||
|
||||
override fun loadPreKey(preKeyId: Int): PreKeyRecord = throw UnsupportedOperationException()
|
||||
override fun storePreKey(preKeyId: Int, record: PreKeyRecord?) = throw UnsupportedOperationException()
|
||||
override fun containsPreKey(preKeyId: Int): Boolean = throw UnsupportedOperationException()
|
||||
|
@ -162,6 +161,8 @@ class BobClient(val serviceId: ServiceId, val e164: String, val identityKeyPair:
|
|||
override fun storeKyberPreKey(kyberPreKeyId: Int, record: KyberPreKeyRecord?) = throw UnsupportedOperationException()
|
||||
override fun containsKyberPreKey(kyberPreKeyId: Int): Boolean = throw UnsupportedOperationException()
|
||||
override fun markKyberPreKeyUsed(kyberPreKeyId: Int) = throw UnsupportedOperationException()
|
||||
override fun deleteAllStaleOneTimeEcPreKeys(threshold: Long, minCount: Int) = throw UnsupportedOperationException()
|
||||
override fun markAllOneTimeEcPreKeysStaleIfNecessary(staleTime: Long) = throw UnsupportedOperationException()
|
||||
override fun storeSenderKey(sender: SignalProtocolAddress?, distributionId: UUID?, record: SenderKeyRecord?) = throw UnsupportedOperationException()
|
||||
override fun loadSenderKey(sender: SignalProtocolAddress?, distributionId: UUID?): SenderKeyRecord = throw UnsupportedOperationException()
|
||||
override fun archiveSession(address: SignalProtocolAddress?) = throw UnsupportedOperationException()
|
||||
|
@ -171,8 +172,9 @@ class BobClient(val serviceId: ServiceId, val e164: String, val identityKeyPair:
|
|||
override fun clearSenderKeySharedWith(addresses: MutableCollection<SignalProtocolAddress>?) = throw UnsupportedOperationException()
|
||||
override fun storeLastResortKyberPreKey(kyberPreKeyId: Int, kyberPreKeyRecord: KyberPreKeyRecord) = throw UnsupportedOperationException()
|
||||
override fun removeKyberPreKey(kyberPreKeyId: Int) = throw UnsupportedOperationException()
|
||||
override fun markAllOneTimeKyberPreKeysStaleIfNecessary(staleTime: Long) = throw UnsupportedOperationException()
|
||||
override fun deleteAllStaleOneTimeKyberPreKeys(threshold: Long, minCount: Int) = throw UnsupportedOperationException()
|
||||
override fun loadLastResortKyberPreKeys(): List<KyberPreKeyRecord> = throw UnsupportedOperationException()
|
||||
|
||||
override fun isMultiDevice(): Boolean = throw UnsupportedOperationException()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -49,10 +49,11 @@ public class PreKeyUtil {
|
|||
private static final int BATCH_SIZE = 100;
|
||||
private static final long ARCHIVE_AGE = TimeUnit.DAYS.toMillis(30);
|
||||
|
||||
public synchronized static @NonNull List<PreKeyRecord> generateAndStoreOneTimeEcPreKeys(@NonNull SignalProtocolStore protocolStore, @NonNull PreKeyMetadataStore metadataStore) {
|
||||
public synchronized static @NonNull List<PreKeyRecord> generateAndStoreOneTimeEcPreKeys(@NonNull SignalServiceAccountDataStore protocolStore, @NonNull PreKeyMetadataStore metadataStore) {
|
||||
int startingId = metadataStore.getNextEcOneTimePreKeyId();
|
||||
final List<PreKeyRecord> records = generateOneTimeEcPreKeys(startingId);
|
||||
|
||||
protocolStore.markAllOneTimeEcPreKeysStaleIfNecessary(System.currentTimeMillis());
|
||||
storeOneTimeEcPreKeys(protocolStore, metadataStore, records);
|
||||
|
||||
return records;
|
||||
|
@ -92,10 +93,11 @@ public class PreKeyUtil {
|
|||
|
||||
}
|
||||
|
||||
public synchronized static @NonNull List<KyberPreKeyRecord> generateAndStoreOneTimeKyberPreKeys(@NonNull SignalProtocolStore protocolStore, @NonNull PreKeyMetadataStore metadataStore) {
|
||||
public synchronized static @NonNull List<KyberPreKeyRecord> generateAndStoreOneTimeKyberPreKeys(@NonNull SignalServiceAccountDataStore protocolStore, @NonNull PreKeyMetadataStore metadataStore) {
|
||||
int startingId = metadataStore.getNextKyberPreKeyId();
|
||||
List<KyberPreKeyRecord> records = generateOneTimeKyberPreKeyRecords(startingId, protocolStore.getIdentityKeyPair().getPrivateKey());
|
||||
|
||||
protocolStore.markAllOneTimeKyberPreKeysStaleIfNecessary(System.currentTimeMillis());
|
||||
storeOneTimeKyberPreKeys(protocolStore, metadataStore, records);
|
||||
|
||||
return records;
|
||||
|
@ -264,4 +266,12 @@ public class PreKeyUtil {
|
|||
Log.w(TAG, e);
|
||||
}
|
||||
}
|
||||
|
||||
public synchronized static void cleanOneTimePreKeys(@NonNull SignalServiceAccountDataStore protocolStore) {
|
||||
long threshold = System.currentTimeMillis() - TimeUnit.DAYS.toMillis(90);
|
||||
int minCount = 200;
|
||||
|
||||
protocolStore.deleteAllStaleOneTimeEcPreKeys(threshold, minCount);
|
||||
protocolStore.deleteAllStaleOneTimeKyberPreKeys(threshold, minCount);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -67,4 +67,16 @@ class SignalKyberPreKeyStore(private val selfServiceId: ServiceId) : SignalServi
|
|||
SignalDatabase.kyberPreKeys.delete(selfServiceId, kyberPreKeyId)
|
||||
}
|
||||
}
|
||||
|
||||
override fun markAllOneTimeKyberPreKeysStaleIfNecessary(staleTime: Long) {
|
||||
ReentrantSessionLock.INSTANCE.acquire().use {
|
||||
SignalDatabase.kyberPreKeys.markAllStaleIfNecessary(selfServiceId, staleTime)
|
||||
}
|
||||
}
|
||||
|
||||
override fun deleteAllStaleOneTimeKyberPreKeys(threshold: Long, minCount: Int) {
|
||||
ReentrantSessionLock.INSTANCE.acquire().use {
|
||||
SignalDatabase.kyberPreKeys.deleteAllStaleBefore(selfServiceId, threshold, minCount)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -100,6 +100,16 @@ public class SignalServiceAccountDataStoreImpl implements SignalServiceAccountDa
|
|||
preKeyStore.removePreKey(preKeyId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void markAllOneTimeEcPreKeysStaleIfNecessary(long staleTime) {
|
||||
preKeyStore.markAllOneTimeEcPreKeysStaleIfNecessary(staleTime);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deleteAllStaleOneTimeEcPreKeys(long threshold, int minCount) {
|
||||
preKeyStore.deleteAllStaleOneTimeEcPreKeys(threshold, minCount);
|
||||
}
|
||||
|
||||
@Override
|
||||
public SessionRecord loadSession(SignalProtocolAddress axolotlAddress) {
|
||||
return sessionStore.loadSession(axolotlAddress);
|
||||
|
@ -211,6 +221,16 @@ public class SignalServiceAccountDataStoreImpl implements SignalServiceAccountDa
|
|||
kyberPreKeyStore.removeKyberPreKey(kyberPreKeyId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void markAllOneTimeKyberPreKeysStaleIfNecessary(long staleTime) {
|
||||
kyberPreKeyStore.markAllOneTimeKyberPreKeysStaleIfNecessary(staleTime);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deleteAllStaleOneTimeKyberPreKeys(long threshold, int minCount) {
|
||||
kyberPreKeyStore.deleteAllStaleOneTimeKyberPreKeys(threshold, minCount);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void storeSenderKey(SignalProtocolAddress sender, UUID distributionId, SenderKeyRecord record) {
|
||||
senderKeyStore.storeSenderKey(sender, distributionId, record);
|
||||
|
@ -251,5 +271,4 @@ public class SignalServiceAccountDataStoreImpl implements SignalServiceAccountDa
|
|||
public @NonNull SignalSenderKeyStore senderKeys() {
|
||||
return senderKeyStore;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -10,12 +10,13 @@ import org.signal.libsignal.protocol.state.SignedPreKeyRecord;
|
|||
import org.signal.libsignal.protocol.state.SignedPreKeyStore;
|
||||
import org.thoughtcrime.securesms.crypto.ReentrantSessionLock;
|
||||
import org.thoughtcrime.securesms.database.SignalDatabase;
|
||||
import org.whispersystems.signalservice.api.SignalServicePreKeyStore;
|
||||
import org.whispersystems.signalservice.api.SignalSessionLock;
|
||||
import org.whispersystems.signalservice.api.push.ServiceId;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class TextSecurePreKeyStore implements PreKeyStore, SignedPreKeyStore {
|
||||
public class TextSecurePreKeyStore implements SignalServicePreKeyStore, SignedPreKeyStore {
|
||||
|
||||
@SuppressWarnings("unused")
|
||||
private static final String TAG = Log.tag(TextSecurePreKeyStore.class);
|
||||
|
@ -87,4 +88,14 @@ public class TextSecurePreKeyStore implements PreKeyStore, SignedPreKeyStore {
|
|||
public void removeSignedPreKey(int signedPreKeyId) {
|
||||
SignalDatabase.signedPreKeys().delete(accountId, signedPreKeyId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void markAllOneTimeEcPreKeysStaleIfNecessary(long staleTime) {
|
||||
SignalDatabase.oneTimePreKeys().markAllStaleIfNecessary(accountId, staleTime);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deleteAllStaleOneTimeEcPreKeys(long threshold, int minCount) {
|
||||
SignalDatabase.oneTimePreKeys().deleteAllStaleBefore(accountId, threshold, minCount);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,12 +4,14 @@ import android.content.Context
|
|||
import org.signal.core.util.delete
|
||||
import org.signal.core.util.exists
|
||||
import org.signal.core.util.insertInto
|
||||
import org.signal.core.util.logging.Log
|
||||
import org.signal.core.util.readToList
|
||||
import org.signal.core.util.readToSingleObject
|
||||
import org.signal.core.util.requireBoolean
|
||||
import org.signal.core.util.requireNonNullBlob
|
||||
import org.signal.core.util.select
|
||||
import org.signal.core.util.toInt
|
||||
import org.signal.core.util.update
|
||||
import org.signal.libsignal.protocol.state.KyberPreKeyRecord
|
||||
import org.whispersystems.signalservice.api.push.ServiceId
|
||||
|
||||
|
@ -18,6 +20,8 @@ import org.whispersystems.signalservice.api.push.ServiceId
|
|||
*/
|
||||
class KyberPreKeyTable(context: Context, databaseHelper: SignalDatabase) : DatabaseTable(context, databaseHelper) {
|
||||
companion object {
|
||||
private val TAG = Log.tag(KyberPreKeyTable::class.java)
|
||||
|
||||
const val TABLE_NAME = "kyber_prekey"
|
||||
const val ID = "_id"
|
||||
const val ACCOUNT_ID = "account_id"
|
||||
|
@ -25,6 +29,8 @@ class KyberPreKeyTable(context: Context, databaseHelper: SignalDatabase) : Datab
|
|||
const val TIMESTAMP = "timestamp"
|
||||
const val LAST_RESORT = "last_resort"
|
||||
const val SERIALIZED = "serialized"
|
||||
const val STALE_TIMESTAMP = "stale_timestamp"
|
||||
|
||||
const val CREATE_TABLE = """
|
||||
CREATE TABLE $TABLE_NAME (
|
||||
$ID INTEGER PRIMARY KEY,
|
||||
|
@ -33,6 +39,7 @@ class KyberPreKeyTable(context: Context, databaseHelper: SignalDatabase) : Datab
|
|||
$TIMESTAMP INTEGER NOT NULL,
|
||||
$LAST_RESORT INTEGER NOT NULL,
|
||||
$SERIALIZED BLOB NOT NULL,
|
||||
$STALE_TIMESTAMP INTEGER NOT NULL DEFAULT 0,
|
||||
UNIQUE($ACCOUNT_ID, $KEY_ID)
|
||||
)
|
||||
"""
|
||||
|
@ -120,6 +127,48 @@ class KyberPreKeyTable(context: Context, databaseHelper: SignalDatabase) : Datab
|
|||
.run()
|
||||
}
|
||||
|
||||
fun markAllStaleIfNecessary(serviceId: ServiceId, staleTime: Long) {
|
||||
writableDatabase
|
||||
.update(TABLE_NAME)
|
||||
.values(STALE_TIMESTAMP to staleTime)
|
||||
.where("$ACCOUNT_ID = ? AND $STALE_TIMESTAMP = 0 AND $LAST_RESORT = 0", serviceId)
|
||||
.run()
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes all keys that have been stale since before the specified threshold.
|
||||
* We will always keep at least [minCount] items, preferring more recent ones.
|
||||
*/
|
||||
fun deleteAllStaleBefore(serviceId: ServiceId, threshold: Long, minCount: Int) {
|
||||
val count = writableDatabase
|
||||
.delete(TABLE_NAME)
|
||||
.where(
|
||||
"""
|
||||
$ACCOUNT_ID = ?
|
||||
AND $LAST_RESORT = 0
|
||||
AND $STALE_TIMESTAMP > 0
|
||||
AND $STALE_TIMESTAMP < $threshold
|
||||
AND $ID NOT IN (
|
||||
SELECT $ID
|
||||
FROM $TABLE_NAME
|
||||
WHERE
|
||||
$ACCOUNT_ID = ?
|
||||
AND $LAST_RESORT = 0
|
||||
ORDER BY
|
||||
CASE $STALE_TIMESTAMP WHEN 0 THEN 1 ELSE 0 END DESC,
|
||||
$STALE_TIMESTAMP DESC,
|
||||
$ID DESC
|
||||
LIMIT $minCount
|
||||
)
|
||||
""",
|
||||
serviceId,
|
||||
serviceId
|
||||
)
|
||||
.run()
|
||||
|
||||
Log.i(TAG, "Deleted $count stale one-time EC prekeys.")
|
||||
}
|
||||
|
||||
data class KyberPreKey(
|
||||
val record: KyberPreKeyRecord,
|
||||
val lastResort: Boolean
|
||||
|
|
|
@ -3,8 +3,10 @@ package org.thoughtcrime.securesms.database
|
|||
import android.content.Context
|
||||
import androidx.core.content.contentValuesOf
|
||||
import org.signal.core.util.SqlUtil
|
||||
import org.signal.core.util.delete
|
||||
import org.signal.core.util.logging.Log
|
||||
import org.signal.core.util.requireNonNullString
|
||||
import org.signal.core.util.update
|
||||
import org.signal.libsignal.protocol.InvalidKeyException
|
||||
import org.signal.libsignal.protocol.ecc.Curve
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair
|
||||
|
@ -23,6 +25,8 @@ class OneTimePreKeyTable(context: Context, databaseHelper: SignalDatabase) : Dat
|
|||
const val KEY_ID = "key_id"
|
||||
const val PUBLIC_KEY = "public_key"
|
||||
const val PRIVATE_KEY = "private_key"
|
||||
const val STALE_TIMESTAMP = "stale_timestamp"
|
||||
|
||||
const val CREATE_TABLE = """
|
||||
CREATE TABLE $TABLE_NAME (
|
||||
$ID INTEGER PRIMARY KEY,
|
||||
|
@ -30,6 +34,7 @@ class OneTimePreKeyTable(context: Context, databaseHelper: SignalDatabase) : Dat
|
|||
$KEY_ID INTEGER UNIQUE,
|
||||
$PUBLIC_KEY TEXT NOT NULL,
|
||||
$PRIVATE_KEY TEXT NOT NULL,
|
||||
$STALE_TIMESTAMP INTEGER NOT NULL DEFAULT 0,
|
||||
UNIQUE($ACCOUNT_ID, $KEY_ID)
|
||||
)
|
||||
"""
|
||||
|
@ -68,4 +73,43 @@ class OneTimePreKeyTable(context: Context, databaseHelper: SignalDatabase) : Dat
|
|||
val database = databaseHelper.signalWritableDatabase
|
||||
database.delete(TABLE_NAME, "$ACCOUNT_ID = ? AND $KEY_ID = ?", SqlUtil.buildArgs(serviceId, keyId))
|
||||
}
|
||||
|
||||
fun markAllStaleIfNecessary(serviceId: ServiceId, staleTime: Long) {
|
||||
writableDatabase
|
||||
.update(TABLE_NAME)
|
||||
.values(STALE_TIMESTAMP to staleTime)
|
||||
.where("$ACCOUNT_ID = ? AND $STALE_TIMESTAMP = 0", serviceId)
|
||||
.run()
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes all keys that have been stale since before the specified threshold.
|
||||
* We will always keep at least [minCount] items, preferring more recent ones.
|
||||
*/
|
||||
fun deleteAllStaleBefore(serviceId: ServiceId, threshold: Long, minCount: Int) {
|
||||
val count = writableDatabase
|
||||
.delete(TABLE_NAME)
|
||||
.where(
|
||||
"""
|
||||
$ACCOUNT_ID = ?
|
||||
AND $STALE_TIMESTAMP > 0
|
||||
AND $STALE_TIMESTAMP < $threshold
|
||||
AND $ID NOT IN (
|
||||
SELECT $ID
|
||||
FROM $TABLE_NAME
|
||||
WHERE $ACCOUNT_ID = ?
|
||||
ORDER BY
|
||||
CASE $STALE_TIMESTAMP WHEN 0 THEN 1 ELSE 0 END DESC,
|
||||
$STALE_TIMESTAMP DESC,
|
||||
$ID DESC
|
||||
LIMIT $minCount
|
||||
)
|
||||
""",
|
||||
serviceId,
|
||||
serviceId
|
||||
)
|
||||
.run()
|
||||
|
||||
Log.i(TAG, "Deleted $count stale one-time EC prekeys.")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -58,6 +58,7 @@ import org.thoughtcrime.securesms.database.helpers.migration.V199_AddThreadActiv
|
|||
import org.thoughtcrime.securesms.database.helpers.migration.V200_ResetPniColumn
|
||||
import org.thoughtcrime.securesms.database.helpers.migration.V201_RecipientTableValidations
|
||||
import org.thoughtcrime.securesms.database.helpers.migration.V202_DropMessageTableThreadDateIndex
|
||||
import org.thoughtcrime.securesms.database.helpers.migration.V203_PreKeyStaleTimestamp
|
||||
|
||||
/**
|
||||
* Contains all of the database migrations for [SignalDatabase]. Broken into a separate file for cleanliness.
|
||||
|
@ -66,7 +67,7 @@ object SignalDatabaseMigrations {
|
|||
|
||||
val TAG: String = Log.tag(SignalDatabaseMigrations.javaClass)
|
||||
|
||||
const val DATABASE_VERSION = 202
|
||||
const val DATABASE_VERSION = 203
|
||||
|
||||
@JvmStatic
|
||||
fun migrate(context: Application, db: SQLiteDatabase, oldVersion: Int, newVersion: Int) {
|
||||
|
@ -285,6 +286,10 @@ object SignalDatabaseMigrations {
|
|||
if (oldVersion < 202) {
|
||||
V202_DropMessageTableThreadDateIndex.migrate(context, db, oldVersion, newVersion)
|
||||
}
|
||||
|
||||
if (oldVersion < 203) {
|
||||
V203_PreKeyStaleTimestamp.migrate(context, db, oldVersion, newVersion)
|
||||
}
|
||||
}
|
||||
|
||||
@JvmStatic
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
/*
|
||||
* Copyright 2023 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.thoughtcrime.securesms.database.helpers.migration
|
||||
|
||||
import android.app.Application
|
||||
import androidx.sqlite.db.SupportSQLiteDatabase
|
||||
import net.zetetic.database.sqlcipher.SQLiteDatabase
|
||||
|
||||
/**
|
||||
* Keep track of a "stale timestamp" for one-time prekeys so that we can know when it's safe to delete them.
|
||||
*/
|
||||
@Suppress("ClassName")
|
||||
object V203_PreKeyStaleTimestamp : SignalDatabaseMigration {
|
||||
override fun migrate(context: Application, db: SQLiteDatabase, oldVersion: Int, newVersion: Int) {
|
||||
// Note: Because of a sequencing issue between beta/nightly, we had two V202 migrations (of which this used to be one of them),
|
||||
// so we have to do some conditional migrating based on the user's current state.
|
||||
db.execSQL("DROP INDEX IF EXISTS message_thread_date_index")
|
||||
|
||||
if (!columnExists(db, "one_time_prekeys", "stale_timestamp")) {
|
||||
db.execSQL("ALTER TABLE one_time_prekeys ADD COLUMN stale_timestamp INTEGER NOT NULL DEFAULT 0")
|
||||
db.execSQL("ALTER TABLE kyber_prekey ADD COLUMN stale_timestamp INTEGER NOT NULL DEFAULT 0")
|
||||
}
|
||||
}
|
||||
|
||||
private fun columnExists(db: SupportSQLiteDatabase, table: String, column: String): Boolean {
|
||||
db.query("PRAGMA table_info($table)", null).use { cursor ->
|
||||
val nameColumnIndex = cursor.getColumnIndexOrThrow("name")
|
||||
while (cursor.moveToNext()) {
|
||||
val name = cursor.getString(nameColumnIndex)
|
||||
if (name == column) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
|
@ -181,6 +181,7 @@ class PreKeysSyncJob private constructor(parameters: Parameters) : BaseJob(param
|
|||
log(serviceIdType, "Cleaning prekeys...")
|
||||
PreKeyUtil.cleanSignedPreKeys(protocolStore, metadataStore)
|
||||
PreKeyUtil.cleanLastResortKyberPreKeys(protocolStore, metadataStore)
|
||||
PreKeyUtil.cleanOneTimePreKeys(protocolStore)
|
||||
}
|
||||
|
||||
private fun signedPreKeyUploadIfNeeded(serviceIdType: ServiceIdType, protocolStore: SignalProtocolStore, metadataStore: PreKeyMetadataStore): SignedPreKeyRecord? {
|
||||
|
|
|
@ -79,6 +79,14 @@ class BufferedKyberPreKeyStore(private val selfServiceId: ServiceId) : SignalSer
|
|||
error("Not expected in this flow")
|
||||
}
|
||||
|
||||
override fun markAllOneTimeKyberPreKeysStaleIfNecessary(staleTime: Long) {
|
||||
error("Not expected in this flow")
|
||||
}
|
||||
|
||||
override fun deleteAllStaleOneTimeKyberPreKeys(threshold: Long, minCount: Int) {
|
||||
error("Not expected in this flow")
|
||||
}
|
||||
|
||||
fun flushToDisk(persistentStore: SignalServiceAccountDataStore) {
|
||||
for (id in removedIfNotLastResort) {
|
||||
persistentStore.markKyberPreKeyUsed(id)
|
||||
|
|
|
@ -141,10 +141,26 @@ class BufferedSignalServiceAccountDataStore(selfServiceId: ServiceId) : SignalSe
|
|||
return kyberPreKeyStore.markKyberPreKeyUsed(kyberPreKeyId)
|
||||
}
|
||||
|
||||
override fun deleteAllStaleOneTimeEcPreKeys(threshold: Long, minCount: Int) {
|
||||
error("Should not happen during the intended usage pattern of this class")
|
||||
}
|
||||
|
||||
override fun markAllOneTimeEcPreKeysStaleIfNecessary(staleTime: Long) {
|
||||
error("Should not happen during the intended usage pattern of this class")
|
||||
}
|
||||
|
||||
override fun removeKyberPreKey(kyberPreKeyId: Int) {
|
||||
kyberPreKeyStore.removeKyberPreKey(kyberPreKeyId)
|
||||
}
|
||||
|
||||
override fun markAllOneTimeKyberPreKeysStaleIfNecessary(staleTime: Long) {
|
||||
kyberPreKeyStore.markAllOneTimeKyberPreKeysStaleIfNecessary(staleTime)
|
||||
}
|
||||
|
||||
override fun deleteAllStaleOneTimeKyberPreKeys(threshold: Long, minCount: Int) {
|
||||
kyberPreKeyStore.deleteAllStaleOneTimeKyberPreKeys(threshold, minCount)
|
||||
}
|
||||
|
||||
override fun loadLastResortKyberPreKeys(): List<KyberPreKeyRecord> {
|
||||
return kyberPreKeyStore.loadLastResortKyberPreKeys()
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@ import org.thoughtcrime.securesms.jobmanager.Job;
|
|||
import org.thoughtcrime.securesms.jobmanager.impl.NetworkConstraint;
|
||||
import org.thoughtcrime.securesms.keyvalue.SignalStore;
|
||||
import org.thoughtcrime.securesms.recipients.Recipient;
|
||||
import org.whispersystems.signalservice.api.SignalServiceAccountDataStore;
|
||||
import org.whispersystems.signalservice.api.SignalServiceAccountManager;
|
||||
import org.whispersystems.signalservice.api.account.PreKeyUpload;
|
||||
import org.whispersystems.signalservice.api.push.ServiceId.PNI;
|
||||
|
@ -70,9 +71,9 @@ public class PniAccountInitializationMigrationJob extends MigrationJob {
|
|||
Log.w(TAG, "Already generated the PNI identity. Skipping this step.");
|
||||
}
|
||||
|
||||
SignalServiceAccountManager accountManager = ApplicationDependencies.getSignalServiceAccountManager();
|
||||
SignalProtocolStore protocolStore = ApplicationDependencies.getProtocolStore().pni();
|
||||
PreKeyMetadataStore metadataStore = SignalStore.account().pniPreKeys();
|
||||
SignalServiceAccountManager accountManager = ApplicationDependencies.getSignalServiceAccountManager();
|
||||
SignalServiceAccountDataStore protocolStore = ApplicationDependencies.getProtocolStore().pni();
|
||||
PreKeyMetadataStore metadataStore = SignalStore.account().pniPreKeys();
|
||||
|
||||
if (!metadataStore.isSignedPreKeyRegistered()) {
|
||||
Log.i(TAG, "Uploading signed prekey for PNI.");
|
||||
|
|
|
@ -7,6 +7,7 @@ import org.signal.libsignal.protocol.state.SignalProtocolStore;
|
|||
* in the service layer, but not the protocol layer.
|
||||
*/
|
||||
public interface SignalServiceAccountDataStore extends SignalProtocolStore,
|
||||
SignalServicePreKeyStore,
|
||||
SignalServiceSessionStore,
|
||||
SignalServiceSenderKeyStore,
|
||||
SignalServiceKyberPreKeyStore {
|
||||
|
|
|
@ -23,4 +23,14 @@ interface SignalServiceKyberPreKeyStore : KyberPreKeyStore {
|
|||
* Unconditionally remove the specified key from the store.
|
||||
*/
|
||||
fun removeKyberPreKey(kyberPreKeyId: Int)
|
||||
|
||||
/**
|
||||
* Marks all prekeys stale if they haven't been marked already. "Stale" means the time that the keys have been replaced.
|
||||
*/
|
||||
fun markAllOneTimeKyberPreKeysStaleIfNecessary(staleTime: Long)
|
||||
|
||||
/**
|
||||
* Deletes all prekeys that have been stale since before the threshold. "Stale" means the time that the keys have been replaced.
|
||||
*/
|
||||
fun deleteAllStaleOneTimeKyberPreKeys(threshold: Long, minCount: Int)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
package org.whispersystems.signalservice.api
|
||||
|
||||
import org.signal.libsignal.protocol.state.PreKeyStore
|
||||
|
||||
/**
|
||||
* And extension of the normal protocol prekey store interface that has additional methods that are
|
||||
* needed in the service layer, but not the protocol layer.
|
||||
*/
|
||||
interface SignalServicePreKeyStore : PreKeyStore {
|
||||
/**
|
||||
* Marks all prekeys stale if they haven't been marked already. "Stale" means the time that the keys have been replaced.
|
||||
*/
|
||||
fun markAllOneTimeEcPreKeysStaleIfNecessary(staleTime: Long)
|
||||
|
||||
/**
|
||||
* Deletes all prekeys that have been stale since before the threshold. "Stale" means the time that the keys have been replaced.
|
||||
*/
|
||||
fun deleteAllStaleOneTimeEcPreKeys(threshold: Long, minCount: Int)
|
||||
}
|
Loading…
Add table
Reference in a new issue