Perform message decryptions in batches.

This commit is contained in:
Greyson Parrelli 2023-03-09 17:05:00 -05:00
parent 04baa7925f
commit 894095414a
17 changed files with 772 additions and 69 deletions

View file

@ -6,6 +6,7 @@ import org.signal.libsignal.zkgroup.profiles.ProfileKey
import org.thoughtcrime.securesms.crypto.ProfileKeyUtil
import org.thoughtcrime.securesms.dependencies.ApplicationDependencies
import org.thoughtcrime.securesms.keyvalue.SignalStore
import org.thoughtcrime.securesms.messages.protocol.BufferedProtocolStore
import org.thoughtcrime.securesms.recipients.Recipient
import org.thoughtcrime.securesms.testing.FakeClientHelpers.toEnvelope
import org.whispersystems.signalservice.api.push.ServiceId
@ -35,7 +36,9 @@ class AliceClient(val serviceId: ServiceId, val e164: String, val trustRoot: ECK
fun process(envelope: Envelope, serverDeliveredTimestamp: Long) {
val start = System.currentTimeMillis()
ApplicationDependencies.getIncomingMessageObserver().processEnvelope(envelope, serverDeliveredTimestamp)
val bufferedStore = BufferedProtocolStore.create()
ApplicationDependencies.getIncomingMessageObserver().processEnvelope(bufferedStore, envelope, serverDeliveredTimestamp)
bufferedStore.flushToDisk()
val end = System.currentTimeMillis()
Log.d(TAG, "${end - start}")
}

View file

@ -14,6 +14,7 @@ import org.thoughtcrime.securesms.keyvalue.SignalStore
import org.thoughtcrime.securesms.messages.MessageContentProcessor.ExceptionMetadata
import org.thoughtcrime.securesms.messages.MessageContentProcessor.MessageState
import org.thoughtcrime.securesms.messages.MessageDecryptor
import org.thoughtcrime.securesms.messages.protocol.BufferedProtocolStore
import org.thoughtcrime.securesms.notifications.NotificationChannels
import org.thoughtcrime.securesms.notifications.NotificationIds
import org.thoughtcrime.securesms.transport.RetryLaterException
@ -77,7 +78,9 @@ class PushDecryptMessageJob private constructor(
throw RetryLaterException()
}
val result = MessageDecryptor.decrypt(context, envelope.proto, envelope.serverDeliveredTimestamp)
val bufferedProtocolStore = BufferedProtocolStore.create()
val result = MessageDecryptor.decrypt(context, bufferedProtocolStore, envelope.proto, envelope.serverDeliveredTimestamp)
bufferedProtocolStore.flushToDisk()
when (result) {
is MessageDecryptor.Result.Success -> {

View file

@ -11,6 +11,7 @@ import org.thoughtcrime.securesms.database.model.MessageId;
import org.thoughtcrime.securesms.dependencies.ApplicationDependencies;
import org.thoughtcrime.securesms.jobmanager.Data;
import org.thoughtcrime.securesms.jobmanager.Job;
import org.thoughtcrime.securesms.jobmanager.impl.DecryptionsDrainedConstraint;
import org.thoughtcrime.securesms.jobmanager.impl.NetworkConstraint;
import org.thoughtcrime.securesms.net.NotPushRegisteredException;
import org.thoughtcrime.securesms.recipients.Recipient;
@ -51,6 +52,7 @@ public class SendDeliveryReceiptJob extends BaseJob {
public SendDeliveryReceiptJob(@NonNull RecipientId recipientId, long messageSentTimestamp, @NonNull MessageId messageId) {
this(new Job.Parameters.Builder()
.addConstraint(NetworkConstraint.KEY)
.addConstraint(DecryptionsDrainedConstraint.KEY)
.setLifespan(TimeUnit.DAYS.toMillis(1))
.setMaxAttempts(Parameters.UNLIMITED)
.setQueue(recipientId.toQueueKey())

View file

@ -16,6 +16,7 @@ import org.thoughtcrime.securesms.dependencies.ApplicationDependencies;
import org.thoughtcrime.securesms.jobmanager.Data;
import org.thoughtcrime.securesms.jobmanager.Job;
import org.thoughtcrime.securesms.jobmanager.JobManager;
import org.thoughtcrime.securesms.jobmanager.impl.DecryptionsDrainedConstraint;
import org.thoughtcrime.securesms.jobmanager.impl.NetworkConstraint;
import org.thoughtcrime.securesms.net.NotPushRegisteredException;
import org.thoughtcrime.securesms.recipients.Recipient;
@ -65,6 +66,7 @@ public class SendReadReceiptJob extends BaseJob {
public SendReadReceiptJob(long threadId, @NonNull RecipientId recipientId, List<Long> messageSentTimestamps, List<MessageId> messageIds) {
this(new Job.Parameters.Builder()
.addConstraint(NetworkConstraint.KEY)
.addConstraint(DecryptionsDrainedConstraint.KEY)
.setLifespan(TimeUnit.DAYS.toMillis(1))
.setMaxAttempts(Parameters.UNLIMITED)
.setQueue(recipientId.toQueueKey())

View file

@ -72,7 +72,7 @@ class PersistentLogger(
}
private fun write(level: String, tag: String?, message: String?, t: Throwable?, keepLonger: Boolean) {
logEntries.add(LogRequest(level, tag ?: "null", message, Date(), getThreadString(), t, keepLonger))
logEntries.add(LogRequest(level, tag ?: "null", message, System.currentTimeMillis(), getThreadString(), t, keepLonger))
}
private fun getThreadString(): String {
@ -95,7 +95,7 @@ class PersistentLogger(
val level: String,
val tag: String,
val message: String?,
val date: Date,
val createTime: Long,
val threadString: String,
val throwable: Throwable?,
val keepLonger: Boolean
@ -121,11 +121,13 @@ class PersistentLogger(
fun requestToEntries(request: LogRequest): List<LogEntry> {
val out = mutableListOf<LogEntry>()
val createDate = Date(request.createTime)
out.add(
LogEntry(
createdAt = request.date.time,
createdAt = request.createTime,
keepLonger = request.keepLonger,
body = formatBody(request.threadString, request.date, request.level, request.tag, request.message)
body = formatBody(request.threadString, createDate, request.level, request.tag, request.message)
)
)
@ -138,9 +140,9 @@ class PersistentLogger(
val entries = lines.map { line ->
LogEntry(
createdAt = request.date.time,
createdAt = request.createTime,
keepLonger = request.keepLonger,
body = formatBody(request.threadString, request.date, request.level, request.tag, line)
body = formatBody(request.threadString, createDate, request.level, request.tag, line)
)
}

View file

@ -14,7 +14,9 @@ import androidx.core.app.NotificationCompat
import org.signal.core.util.ThreadUtil
import org.signal.core.util.concurrent.SignalExecutors
import org.signal.core.util.logging.Log
import org.signal.core.util.withinTransaction
import org.thoughtcrime.securesms.R
import org.thoughtcrime.securesms.crypto.ReentrantSessionLock
import org.thoughtcrime.securesms.database.MessageTable
import org.thoughtcrime.securesms.database.SignalDatabase
import org.thoughtcrime.securesms.dependencies.ApplicationDependencies
@ -28,6 +30,7 @@ import org.thoughtcrime.securesms.jobs.PushDecryptMessageJob
import org.thoughtcrime.securesms.jobs.PushProcessMessageJob
import org.thoughtcrime.securesms.jobs.UnableToStartException
import org.thoughtcrime.securesms.keyvalue.SignalStore
import org.thoughtcrime.securesms.messages.protocol.BufferedProtocolStore
import org.thoughtcrime.securesms.notifications.NotificationChannels
import org.thoughtcrime.securesms.recipients.RecipientId
import org.thoughtcrime.securesms.util.AppForegroundObserver
@ -200,7 +203,7 @@ class IncomingMessageObserver(private val context: Application) {
val needsConnectionString = if (conclusion) "Needs Connection" else "Does Not Need Connection"
Log.d(TAG, "[$needsConnectionString] Network: $hasNetwork, Foreground: $appVisible, Time Since Last Interaction: $lastInteractionString, FCM: $fcmEnabled, Stay open requests: [${keepAliveTokens.entries}], Registered: $registered, Proxy: $hasProxy, Force websocket: $forceWebsocket, Decrypt Queue Empty: $decryptQueueEmpty")
Log.d(TAG, "[$needsConnectionString] Network: $hasNetwork, Foreground: $appVisible, Time Since Last Interaction: $lastInteractionString, FCM: $fcmEnabled, Stay open requests: ${keepAliveTokens.entries}, Registered: $registered, Proxy: $hasProxy, Force websocket: $forceWebsocket, Decrypt Queue Empty: $decryptQueueEmpty")
return conclusion
}
}
@ -249,19 +252,29 @@ class IncomingMessageObserver(private val context: Application) {
}
@VisibleForTesting
fun processEnvelope(envelope: SignalServiceProtos.Envelope, serverDeliveredTimestamp: Long) {
when (envelope.type.number) {
SignalServiceProtos.Envelope.Type.RECEIPT_VALUE -> processReceipt(envelope)
fun processEnvelope(bufferedProtocolStore: BufferedProtocolStore, envelope: SignalServiceProtos.Envelope, serverDeliveredTimestamp: Long): List<Runnable>? {
return when (envelope.type.number) {
SignalServiceProtos.Envelope.Type.RECEIPT_VALUE -> {
processReceipt(envelope)
null
}
SignalServiceProtos.Envelope.Type.PREKEY_BUNDLE_VALUE,
SignalServiceProtos.Envelope.Type.CIPHERTEXT_VALUE,
SignalServiceProtos.Envelope.Type.UNIDENTIFIED_SENDER_VALUE,
SignalServiceProtos.Envelope.Type.PLAINTEXT_CONTENT_VALUE -> processMessage(envelope, serverDeliveredTimestamp)
else -> Log.w(TAG, "Received envelope of unknown type: " + envelope.type)
SignalServiceProtos.Envelope.Type.PLAINTEXT_CONTENT_VALUE -> {
processMessage(bufferedProtocolStore, envelope, serverDeliveredTimestamp)
}
else -> {
Log.w(TAG, "Received envelope of unknown type: " + envelope.type)
null
}
}
}
private fun processMessage(envelope: SignalServiceProtos.Envelope, serverDeliveredTimestamp: Long) {
val result = MessageDecryptor.decrypt(context, envelope, serverDeliveredTimestamp)
private fun processMessage(bufferedProtocolStore: BufferedProtocolStore, envelope: SignalServiceProtos.Envelope, serverDeliveredTimestamp: Long): List<Runnable> {
val result = MessageDecryptor.decrypt(context, bufferedProtocolStore, envelope, serverDeliveredTimestamp)
when (result) {
is MessageDecryptor.Result.Success -> {
@ -297,7 +310,7 @@ class IncomingMessageObserver(private val context: Application) {
}
}
result.followUpOperations.forEach { it.run() }
return result.followUpOperations
}
private fun processReceipt(envelope: SignalServiceProtos.Envelope) {
@ -386,13 +399,31 @@ class IncomingMessageObserver(private val context: Application) {
signalWebSocket.connect()
try {
val bufferedStore = BufferedProtocolStore.create()
while (isConnectionNecessary()) {
try {
Log.d(TAG, "Reading message...")
val hasMore = signalWebSocket.readMessage(WEBSOCKET_READ_TIMEOUT) { envelope, serverDeliveredTimestamp ->
Log.i(TAG, "Retrieved envelope! " + envelope.timestamp)
processEnvelope(envelope, serverDeliveredTimestamp)
val hasMore = signalWebSocket.readMessageBatch(WEBSOCKET_READ_TIMEOUT, 30) { batch ->
Log.i(TAG, "Retrieved ${batch.size} envelopes!")
val startTime = System.currentTimeMillis()
ReentrantSessionLock.INSTANCE.acquire().use {
SignalDatabase.rawDatabase.withinTransaction {
val followUpOperations = batch
.mapNotNull { processEnvelope(bufferedStore, it.envelope, it.serverDeliveredTimestamp) }
.flatten()
bufferedStore.flushToDisk()
followUpOperations.forEach { it.run() }
}
}
val duration = System.currentTimeMillis() - startTime
Log.d(TAG, "Decrypted ${batch.size} envelopes in $duration ms (~${duration / batch.size} ms per message)")
true
}

View file

@ -38,13 +38,13 @@ import org.thoughtcrime.securesms.jobs.PreKeysSyncJob
import org.thoughtcrime.securesms.jobs.SendRetryReceiptJob
import org.thoughtcrime.securesms.keyvalue.SignalStore
import org.thoughtcrime.securesms.logsubmit.SubmitDebugLogActivity
import org.thoughtcrime.securesms.messages.protocol.BufferedProtocolStore
import org.thoughtcrime.securesms.notifications.NotificationChannels
import org.thoughtcrime.securesms.notifications.NotificationIds
import org.thoughtcrime.securesms.recipients.Recipient
import org.thoughtcrime.securesms.recipients.RecipientId
import org.thoughtcrime.securesms.util.FeatureFlags
import org.whispersystems.signalservice.api.InvalidMessageStructureException
import org.whispersystems.signalservice.api.SignalServiceAccountDataStore
import org.whispersystems.signalservice.api.crypto.ContentHint
import org.whispersystems.signalservice.api.crypto.EnvelopeMetadata
import org.whispersystems.signalservice.api.crypto.SignalServiceCipher
@ -74,7 +74,12 @@ object MessageDecryptor {
* To keep that property, there may be [Result.followUpOperations] you have to perform after your transaction is committed.
* These can vary from enqueueing jobs to inserting items into the [org.thoughtcrime.securesms.database.PendingRetryReceiptCache].
*/
fun decrypt(context: Context, envelope: Envelope, serverDeliveredTimestamp: Long): Result {
fun decrypt(
context: Context,
bufferedProtocolStore: BufferedProtocolStore,
envelope: Envelope,
serverDeliveredTimestamp: Long
): Result {
val selfAci: ServiceId = SignalStore.account().requireAci()
val selfPni: ServiceId = SignalStore.account().requirePni()
@ -106,9 +111,9 @@ object MessageDecryptor {
}
}
val protocolStore: SignalServiceAccountDataStore = ApplicationDependencies.getProtocolStore().get(destination)
val bufferedStore = bufferedProtocolStore.get(destination)
val localAddress = SignalServiceAddress(selfAci, SignalStore.account().e164)
val cipher = SignalServiceCipher(localAddress, SignalStore.account().deviceId, protocolStore, ReentrantSessionLock.INSTANCE, UnidentifiedAccessUtil.getCertificateValidator())
val cipher = SignalServiceCipher(localAddress, SignalStore.account().deviceId, bufferedStore, ReentrantSessionLock.INSTANCE, UnidentifiedAccessUtil.getCertificateValidator())
return try {
val cipherResult: SignalServiceCipherResult? = cipher.decrypt(envelope, serverDeliveredTimestamp)

View file

@ -0,0 +1,80 @@
package org.thoughtcrime.securesms.messages.protocol
import org.signal.libsignal.protocol.IdentityKey
import org.signal.libsignal.protocol.IdentityKeyPair
import org.signal.libsignal.protocol.SignalProtocolAddress
import org.signal.libsignal.protocol.state.IdentityKeyStore
import org.thoughtcrime.securesms.database.SignalDatabase
import org.whispersystems.signalservice.api.SignalServiceAccountDataStore
import org.whispersystems.signalservice.api.push.ServiceId
/**
* An in-memory identity key store that is intended to be used temporarily while decrypting messages.
*/
class BufferedIdentityKeyStore(
private val selfServiceId: ServiceId,
private val selfIdentityKeyPair: IdentityKeyPair,
private val selfRegistrationId: Int
) : IdentityKeyStore {
private val store: MutableMap<SignalProtocolAddress, IdentityKey> = HashMap()
/** All of the keys that have been created or updated during operation. */
private val updatedKeys: MutableMap<SignalProtocolAddress, IdentityKey> = mutableMapOf()
override fun getIdentityKeyPair(): IdentityKeyPair {
return selfIdentityKeyPair
}
override fun getLocalRegistrationId(): Int {
return selfRegistrationId
}
override fun saveIdentity(address: SignalProtocolAddress, identityKey: IdentityKey): Boolean {
val existing: IdentityKey? = getIdentity(address)
store[address] = identityKey
return if (identityKey != existing) {
updatedKeys[address] = identityKey
true
} else {
false
}
}
override fun isTrustedIdentity(address: SignalProtocolAddress, identityKey: IdentityKey, direction: IdentityKeyStore.Direction): Boolean {
if (address.name == selfServiceId.toString()) {
return identityKey == selfIdentityKeyPair.publicKey
}
return when (direction) {
IdentityKeyStore.Direction.RECEIVING -> true
IdentityKeyStore.Direction.SENDING -> error("Should not happen during the intended usage pattern of this class")
else -> error("Unknown direction: $direction")
}
}
override fun getIdentity(address: SignalProtocolAddress): IdentityKey? {
val cached = store[address]
return if (cached != null) {
cached
} else {
val fromDatabase = SignalDatabase.identities.getIdentityStoreRecord(address.name)
if (fromDatabase != null) {
store[address] = fromDatabase.identityKey
}
fromDatabase?.identityKey
}
}
fun flushToDisk(persistentStore: SignalServiceAccountDataStore) {
for ((address, identityKey) in updatedKeys) {
persistentStore.saveIdentity(address, identityKey)
}
updatedKeys.clear()
}
}

View file

@ -0,0 +1,49 @@
package org.thoughtcrime.securesms.messages.protocol
import org.signal.libsignal.protocol.InvalidKeyIdException
import org.signal.libsignal.protocol.state.PreKeyRecord
import org.signal.libsignal.protocol.state.PreKeyStore
import org.thoughtcrime.securesms.database.SignalDatabase
import org.whispersystems.signalservice.api.SignalServiceAccountDataStore
import org.whispersystems.signalservice.api.push.ServiceId
/**
* An in-memory one-time prekey store that is intended to be used temporarily while decrypting messages.
*/
class BufferedOneTimePreKeyStore(private val selfServiceId: ServiceId) : PreKeyStore {
/** Our in-memory cache of one-time prekeys. */
private val store: MutableMap<Int, PreKeyRecord> = HashMap()
/** The one-time prekeys that have been marked as removed */
private val removed: MutableList<Int> = mutableListOf()
@kotlin.jvm.Throws(InvalidKeyIdException::class)
override fun loadPreKey(id: Int): PreKeyRecord {
return store.computeIfAbsent(id) {
SignalDatabase.oneTimePreKeys.get(selfServiceId, id) ?: throw InvalidKeyIdException("Missing one-time prekey with ID: $id")
}
}
override fun storePreKey(id: Int, record: PreKeyRecord) {
error("Should not happen during the intended usage pattern of this class")
}
override fun containsPreKey(id: Int): Boolean {
loadPreKey(id)
return store.containsKey(id)
}
override fun removePreKey(id: Int) {
store.remove(id)
removed += id
}
fun flushToDisk(persistentStore: SignalServiceAccountDataStore) {
for (id in removed) {
persistentStore.removePreKey(id)
}
removed.clear()
}
}

View file

@ -0,0 +1,46 @@
package org.thoughtcrime.securesms.messages.protocol
import org.thoughtcrime.securesms.dependencies.ApplicationDependencies
import org.thoughtcrime.securesms.keyvalue.SignalStore
import org.whispersystems.signalservice.api.push.ServiceId
/**
* The entry point for creating and retrieving buffered protocol stores.
* These stores will read from disk, but never write, instead buffering the results in memory.
* You can then call [flushToDisk] in order to write the buffered results to disk.
*
* This allows you to efficiently do batches of work and avoid unnecessary intermediate writes.
*/
class BufferedProtocolStore private constructor(
private val aciStore: Pair<ServiceId, BufferedSignalServiceAccountDataStore>,
private val pniStore: Pair<ServiceId, BufferedSignalServiceAccountDataStore>
) {
fun get(serviceId: ServiceId): BufferedSignalServiceAccountDataStore {
return when (serviceId) {
aciStore.first -> aciStore.second
pniStore.first -> pniStore.second
else -> error("No store matching serviceId $serviceId")
}
}
/**
* Writes any buffered data to disk. You can continue to use the same buffered store afterwards.
*/
fun flushToDisk() {
aciStore.second.flushToDisk(ApplicationDependencies.getProtocolStore().aci())
pniStore.second.flushToDisk(ApplicationDependencies.getProtocolStore().pni())
}
companion object {
fun create(): BufferedProtocolStore {
val aci = SignalStore.account().requireAci()
val pni = SignalStore.account().requirePni()
return BufferedProtocolStore(
aciStore = aci to BufferedSignalServiceAccountDataStore(aci),
pniStore = pni to BufferedSignalServiceAccountDataStore(pni)
)
}
}
}

View file

@ -0,0 +1,75 @@
package org.thoughtcrime.securesms.messages.protocol
import org.signal.libsignal.protocol.SignalProtocolAddress
import org.signal.libsignal.protocol.groups.state.SenderKeyRecord
import org.thoughtcrime.securesms.database.SignalDatabase
import org.whispersystems.signalservice.api.SignalServiceAccountDataStore
import org.whispersystems.signalservice.api.SignalServiceSenderKeyStore
import org.whispersystems.signalservice.api.push.DistributionId
import java.util.UUID
/**
* An in-memory sender key store that is intended to be used temporarily while decrypting messages.
*/
class BufferedSenderKeyStore : SignalServiceSenderKeyStore {
private val store: MutableMap<StoreKey, SenderKeyRecord> = HashMap()
/** All of the keys that have been created or updated during operation. */
private val updatedKeys: MutableMap<StoreKey, SenderKeyRecord> = mutableMapOf()
/** All of the distributionId's whose sharing has been cleared during operation. */
private val clearSharedWith: MutableSet<SignalProtocolAddress> = mutableSetOf()
override fun storeSenderKey(sender: SignalProtocolAddress, distributionId: UUID, record: SenderKeyRecord) {
val key = StoreKey(sender, distributionId)
store[key] = record
updatedKeys[key] = record
}
override fun loadSenderKey(sender: SignalProtocolAddress, distributionId: UUID): SenderKeyRecord? {
val cached: SenderKeyRecord? = store[StoreKey(sender, distributionId)]
return if (cached != null) {
cached
} else {
val fromDatabase: SenderKeyRecord? = SignalDatabase.senderKeys.load(sender, distributionId.toDistributionId())
if (fromDatabase != null) {
store[StoreKey(sender, distributionId)] = fromDatabase
}
return fromDatabase
}
}
override fun clearSenderKeySharedWith(addresses: MutableCollection<SignalProtocolAddress>) {
clearSharedWith.addAll(addresses)
}
override fun getSenderKeySharedWith(distributionId: DistributionId?): MutableSet<SignalProtocolAddress> {
error("Should not happen during the intended usage pattern of this class")
}
override fun markSenderKeySharedWith(distributionId: DistributionId?, addresses: MutableCollection<SignalProtocolAddress>?) {
error("Should not happen during the intended usage pattern of this class")
}
fun flushToDisk(persistentStore: SignalServiceAccountDataStore) {
for ((key, record) in updatedKeys) {
persistentStore.storeSenderKey(key.address, key.distributionId, record)
}
persistentStore.clearSenderKeySharedWith(clearSharedWith)
updatedKeys.clear()
clearSharedWith.clear()
}
private fun UUID.toDistributionId() = DistributionId.from(this)
data class StoreKey(
val address: SignalProtocolAddress,
val distributionId: UUID
)
}

View file

@ -0,0 +1,115 @@
package org.thoughtcrime.securesms.messages.protocol
import org.signal.libsignal.protocol.NoSessionException
import org.signal.libsignal.protocol.SignalProtocolAddress
import org.signal.libsignal.protocol.message.CiphertextMessage
import org.signal.libsignal.protocol.state.SessionRecord
import org.thoughtcrime.securesms.database.SignalDatabase
import org.whispersystems.signalservice.api.SignalServiceAccountDataStore
import org.whispersystems.signalservice.api.SignalServiceSessionStore
import org.whispersystems.signalservice.api.push.ServiceId
import kotlin.jvm.Throws
/**
* An in-memory session store that is intended to be used temporarily while decrypting messages.
*/
class BufferedSessionStore(private val selfServiceId: ServiceId) : SignalServiceSessionStore {
private val store: MutableMap<SignalProtocolAddress, SessionRecord> = HashMap()
/** All of the sessions that have been created or updated during operation. */
private val updatedSessions: MutableMap<SignalProtocolAddress, SessionRecord> = mutableMapOf()
/** All of the sessions that have deleted during operation. */
private val deletedSessions: MutableSet<SignalProtocolAddress> = mutableSetOf()
override fun loadSession(address: SignalProtocolAddress): SessionRecord {
val session: SessionRecord = store[address]
?: SignalDatabase.sessions.load(selfServiceId, address)
?: SessionRecord()
store[address] = session
return session
}
@Throws(NoSessionException::class)
override fun loadExistingSessions(addresses: MutableList<SignalProtocolAddress>): List<SessionRecord> {
val found: MutableList<SessionRecord> = mutableListOf()
val needsDatabaseLookup: MutableList<SignalProtocolAddress> = mutableListOf()
for (address in addresses) {
val cached: SessionRecord? = store[address]
if (cached != null) {
found += cached
} else {
needsDatabaseLookup += address
}
}
if (needsDatabaseLookup.isNotEmpty()) {
found += SignalDatabase.sessions.load(selfServiceId, needsDatabaseLookup).filterNotNull()
}
if (found.size != addresses.size) {
throw NoSessionException("Failed to find one or more sessions.")
}
return found
}
override fun storeSession(address: SignalProtocolAddress, record: SessionRecord) {
store[address] = record
updatedSessions[address] = record
}
override fun containsSession(address: SignalProtocolAddress): Boolean {
return if (store.containsKey(address)) {
true
} else {
val fromDatabase: SessionRecord? = SignalDatabase.sessions.load(selfServiceId, address)
if (fromDatabase != null) {
store[address] = fromDatabase
return fromDatabase.hasSenderChain() && fromDatabase.sessionVersion == CiphertextMessage.CURRENT_VERSION
} else {
false
}
}
}
override fun deleteSession(address: SignalProtocolAddress) {
store.remove(address)
deletedSessions += address
}
override fun getSubDeviceSessions(name: String): MutableList<Int> {
error("Should not happen during the intended usage pattern of this class")
}
override fun deleteAllSessions(name: String) {
error("Should not happen during the intended usage pattern of this class")
}
override fun archiveSession(address: SignalProtocolAddress?) {
error("Should not happen during the intended usage pattern of this class")
}
override fun getAllAddressesWithActiveSessions(addressNames: MutableList<String>): Set<SignalProtocolAddress> {
error("Should not happen during the intended usage pattern of this class")
}
fun flushToDisk(persistentStore: SignalServiceAccountDataStore) {
for ((address, record) in updatedSessions) {
persistentStore.storeSession(address, record)
}
for (address in deletedSessions) {
persistentStore.deleteSession(address)
}
updatedSessions.clear()
deletedSessions.clear()
}
}

View file

@ -0,0 +1,157 @@
package org.thoughtcrime.securesms.messages.protocol
import org.signal.libsignal.protocol.IdentityKey
import org.signal.libsignal.protocol.IdentityKeyPair
import org.signal.libsignal.protocol.SignalProtocolAddress
import org.signal.libsignal.protocol.groups.state.SenderKeyRecord
import org.signal.libsignal.protocol.state.IdentityKeyStore
import org.signal.libsignal.protocol.state.PreKeyRecord
import org.signal.libsignal.protocol.state.SessionRecord
import org.signal.libsignal.protocol.state.SignedPreKeyRecord
import org.thoughtcrime.securesms.keyvalue.SignalStore
import org.whispersystems.signalservice.api.SignalServiceAccountDataStore
import org.whispersystems.signalservice.api.push.DistributionId
import org.whispersystems.signalservice.api.push.ServiceId
import java.util.UUID
/**
* The wrapper around all of the Buffered protocol stores. Designed to perform operations in memory,
* then [flushToDisk] at set intervals.
*/
class BufferedSignalServiceAccountDataStore(selfServiceId: ServiceId) : SignalServiceAccountDataStore {
private val identityStore: BufferedIdentityKeyStore = if (selfServiceId == SignalStore.account().pni) {
BufferedIdentityKeyStore(selfServiceId, SignalStore.account().pniIdentityKey, SignalStore.account().pniRegistrationId)
} else {
BufferedIdentityKeyStore(selfServiceId, SignalStore.account().aciIdentityKey, SignalStore.account().registrationId)
}
private val oneTimePreKeyStore: BufferedOneTimePreKeyStore = BufferedOneTimePreKeyStore(selfServiceId)
private val signedPreKeyStore: BufferedSignedPreKeyStore = BufferedSignedPreKeyStore(selfServiceId)
private val sessionStore: BufferedSessionStore = BufferedSessionStore(selfServiceId)
private val senderKeyStore: BufferedSenderKeyStore = BufferedSenderKeyStore()
override fun getIdentityKeyPair(): IdentityKeyPair {
return identityStore.identityKeyPair
}
override fun getLocalRegistrationId(): Int {
return identityStore.localRegistrationId
}
override fun saveIdentity(address: SignalProtocolAddress, identityKey: IdentityKey): Boolean {
return identityStore.saveIdentity(address, identityKey)
}
override fun isTrustedIdentity(address: SignalProtocolAddress, identityKey: IdentityKey, direction: IdentityKeyStore.Direction): Boolean {
return identityStore.isTrustedIdentity(address, identityKey, direction)
}
override fun getIdentity(address: SignalProtocolAddress): IdentityKey? {
return identityStore.getIdentity(address)
}
override fun loadPreKey(preKeyId: Int): PreKeyRecord {
return oneTimePreKeyStore.loadPreKey(preKeyId)
}
override fun storePreKey(preKeyId: Int, record: PreKeyRecord) {
return oneTimePreKeyStore.storePreKey(preKeyId, record)
}
override fun containsPreKey(preKeyId: Int): Boolean {
return oneTimePreKeyStore.containsPreKey(preKeyId)
}
override fun removePreKey(preKeyId: Int) {
oneTimePreKeyStore.removePreKey(preKeyId)
}
override fun loadSession(address: SignalProtocolAddress): SessionRecord {
return sessionStore.loadSession(address)
}
override fun loadExistingSessions(addresses: MutableList<SignalProtocolAddress>): List<SessionRecord> {
return sessionStore.loadExistingSessions(addresses)
}
override fun getSubDeviceSessions(name: String): MutableList<Int> {
return sessionStore.getSubDeviceSessions(name)
}
override fun storeSession(address: SignalProtocolAddress, record: SessionRecord) {
sessionStore.storeSession(address, record)
}
override fun containsSession(address: SignalProtocolAddress): Boolean {
return sessionStore.containsSession(address)
}
override fun deleteSession(address: SignalProtocolAddress) {
return sessionStore.deleteSession(address)
}
override fun deleteAllSessions(name: String) {
sessionStore.deleteAllSessions(name)
}
override fun loadSignedPreKey(signedPreKeyId: Int): SignedPreKeyRecord {
return signedPreKeyStore.loadSignedPreKey(signedPreKeyId)
}
override fun loadSignedPreKeys(): List<SignedPreKeyRecord> {
return signedPreKeyStore.loadSignedPreKeys()
}
override fun storeSignedPreKey(signedPreKeyId: Int, record: SignedPreKeyRecord) {
signedPreKeyStore.storeSignedPreKey(signedPreKeyId, record)
}
override fun containsSignedPreKey(signedPreKeyId: Int): Boolean {
return signedPreKeyStore.containsSignedPreKey(signedPreKeyId)
}
override fun removeSignedPreKey(signedPreKeyId: Int) {
signedPreKeyStore.removeSignedPreKey(signedPreKeyId)
}
override fun storeSenderKey(sender: SignalProtocolAddress, distributionId: UUID, record: SenderKeyRecord) {
senderKeyStore.storeSenderKey(sender, distributionId, record)
}
override fun loadSenderKey(sender: SignalProtocolAddress, distributionId: UUID): SenderKeyRecord? {
return senderKeyStore.loadSenderKey(sender, distributionId)
}
override fun archiveSession(address: SignalProtocolAddress?) {
sessionStore.archiveSession(address)
}
override fun getAllAddressesWithActiveSessions(addressNames: MutableList<String>): Set<SignalProtocolAddress> {
return sessionStore.getAllAddressesWithActiveSessions(addressNames)
}
override fun getSenderKeySharedWith(distributionId: DistributionId?): MutableSet<SignalProtocolAddress> {
return senderKeyStore.getSenderKeySharedWith(distributionId)
}
override fun markSenderKeySharedWith(distributionId: DistributionId, addresses: MutableCollection<SignalProtocolAddress>) {
senderKeyStore.markSenderKeySharedWith(distributionId, addresses)
}
override fun clearSenderKeySharedWith(addresses: MutableCollection<SignalProtocolAddress>) {
senderKeyStore.clearSenderKeySharedWith(addresses)
}
override fun isMultiDevice(): Boolean {
error("Should not happen during the intended usage pattern of this class")
}
fun flushToDisk(persistentStore: SignalServiceAccountDataStore) {
identityStore.flushToDisk(persistentStore)
oneTimePreKeyStore.flushToDisk(persistentStore)
signedPreKeyStore.flushToDisk(persistentStore)
sessionStore.flushToDisk(persistentStore)
senderKeyStore.flushToDisk(persistentStore)
}
}

View file

@ -0,0 +1,64 @@
package org.thoughtcrime.securesms.messages.protocol
import org.signal.libsignal.protocol.InvalidKeyIdException
import org.signal.libsignal.protocol.state.SignedPreKeyRecord
import org.signal.libsignal.protocol.state.SignedPreKeyStore
import org.thoughtcrime.securesms.database.SignalDatabase
import org.whispersystems.signalservice.api.SignalServiceAccountDataStore
import org.whispersystems.signalservice.api.push.ServiceId
/**
* An in-memory signed prekey store that is intended to be used temporarily while decrypting messages.
*/
class BufferedSignedPreKeyStore(private val selfServiceId: ServiceId) : SignedPreKeyStore {
/** Our in-memory cache of signed prekeys. */
private val store: MutableMap<Int, SignedPreKeyRecord> = HashMap()
/** The signed prekeys that have been marked as removed */
private val removed: MutableList<Int> = mutableListOf()
/** Whether or not we've done a loadAll operation. Let's us avoid doing it twice. */
private var hasLoadedAll: Boolean = false
@kotlin.jvm.Throws(InvalidKeyIdException::class)
override fun loadSignedPreKey(id: Int): SignedPreKeyRecord {
return store.computeIfAbsent(id) {
SignalDatabase.signedPreKeys.get(selfServiceId, id) ?: throw InvalidKeyIdException("Missing one-time prekey with ID: $id")
}
}
override fun loadSignedPreKeys(): List<SignedPreKeyRecord> {
return if (hasLoadedAll) {
store.values.toList()
} else {
val records = SignalDatabase.signedPreKeys.getAll(selfServiceId)
records.forEach { store[it.id] = it }
hasLoadedAll = true
records
}
}
override fun storeSignedPreKey(id: Int, record: SignedPreKeyRecord) {
error("Should not happen during the intended usage pattern of this class")
}
override fun containsSignedPreKey(id: Int): Boolean {
loadSignedPreKey(id)
return store.containsKey(id)
}
override fun removeSignedPreKey(id: Int) {
store.remove(id)
removed += id
}
fun flushToDisk(persistentStore: SignalServiceAccountDataStore) {
for (id in removed) {
persistentStore.removeSignedPreKey(id)
}
removed.clear()
}
}

View file

@ -1,7 +1,10 @@
package org.whispersystems.signalservice.api;
import com.google.protobuf.InvalidProtocolBufferException;
import org.signal.libsignal.protocol.logging.Log;
import org.whispersystems.signalservice.api.crypto.UnidentifiedAccess;
import org.whispersystems.signalservice.api.messages.EnvelopeResponse;
import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState;
import org.whispersystems.signalservice.api.websocket.WebSocketFactory;
import org.whispersystems.signalservice.api.websocket.WebSocketUnavailableException;
@ -13,9 +16,10 @@ import org.whispersystems.signalservice.internal.websocket.WebsocketResponse;
import org.whispersystems.util.Base64;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import io.reactivex.rxjava3.core.Observable;
import io.reactivex.rxjava3.core.Single;
@ -218,61 +222,105 @@ public final class SignalWebSocket {
}
/**
* <p>
* A blocking call that reads a message off the pipe. When this call returns, if the callback indicates the
* message was successfully processed, then the message will be ack'ed on the serve and will not be retransmitted.
* <p>
* This will return true if there are more messages to be read from the websocket, or false if the websocket is empty.
* <p>
* You can specify a {@link MessageReceivedCallback} that will be called before the received message is acknowledged.
* This allows you to write the received message to durable storage before acknowledging receipt of it to the
* server.
* <p>
* Important: This will only return `false` once for each connection. That means if you get false call readMessage()
* again on the same instance, you will not get an immediate `false` return value, and instead will block until
* you get an actual message. This will, however, reset if connection breaks (if, for instance, you lose and regain network).
* The reads a batch of messages off of the websocket.
*
* @param timeout The timeout to wait for.
* @param callback A callback that will be called before the message receipt is acknowledged to the server.
* @return The message read (same as the message sent through the callback).
* Rather than just provide you the batch as a return value, it will invoke the provided callback with the
* batch as an argument. If you are able to successfully process them, this method will then ack all of the
* messages so that they won't be re-delivered in the future.
*
* The return value of this method is a boolean indicating whether or not there are more messages in the
* queue to be read (true if there's still more, or false if you've drained everything).
*
* However, this return value is only really useful the first time you read from the websocket. That's because
* the websocket will only ever let you know if it's drained *once* for any given connection. So if this method
* returns false, a subsequent call while using the same websocket connection will simply block until we either
* get a new message or hit the timeout.
*
* Concerning the requested batch size, it's worth noting that this is simply an upper bound. This method will
* not wait extra time until the batch has "filled up". Instead, it will wait for a single message, and then
* take any extra messages that are also available up until you've hit your batch size.
*/
@SuppressWarnings("DuplicateThrows")
public boolean readMessage(long timeout, MessageReceivedCallback callback)
public boolean readMessageBatch(long timeout, int batchSize, MessageReceivedCallback callback)
throws TimeoutException, WebSocketUnavailableException, IOException
{
while (true) {
WebSocketRequestMessage request = getWebSocket().readRequest(timeout);
WebSocketResponseMessage response = createWebSocketResponse(request);
List<EnvelopeResponse> responses = new ArrayList<>();
boolean hitEndOfQueue = false;
AtomicBoolean successfullyProcessed = new AtomicBoolean(false);
Optional<EnvelopeResponse> firstEnvelope = waitForSingleMessage(timeout);
try {
if (isSignalServiceEnvelope(request)) {
Optional<String> timestampHeader = findHeader(request);
long timestamp = 0;
if (firstEnvelope.isPresent()) {
responses.add(firstEnvelope.get());
} else {
hitEndOfQueue = true;
}
if (timestampHeader.isPresent()) {
try {
timestamp = Long.parseLong(timestampHeader.get());
} catch (NumberFormatException e) {
Log.w(TAG, "Failed to parse " + SERVER_DELIVERED_TIMESTAMP_HEADER);
}
if (!hitEndOfQueue) {
for (int i = 1; i < batchSize; i++) {
Optional<WebSocketRequestMessage> request = getWebSocket().readRequestIfAvailable();
if (request.isPresent()) {
if (isSignalServiceEnvelope(request.get())) {
responses.add(requestToEnvelopeResponse(request.get()));
} else if (isSocketEmptyRequest(request.get())) {
hitEndOfQueue = true;
break;
}
SignalServiceProtos.Envelope envelope = SignalServiceProtos.Envelope.parseFrom(request.getBody().toByteArray());
successfullyProcessed.set(callback.onMessage(envelope, timestamp));
return true;
} else if (isSocketEmptyRequest(request)) {
return false;
}
} finally {
if (successfullyProcessed.get()) {
getWebSocket().sendResponse(response);
} else {
break;
}
}
}
if (responses.size() > 0) {
boolean successfullyProcessed = false;
try {
successfullyProcessed = callback.onMessageBatch(responses);
} finally {
if (successfullyProcessed) {
for (EnvelopeResponse response : responses) {
getWebSocket().sendResponse(createWebSocketResponse(response.getWebsocketRequest()));
}
}
}
}
return !hitEndOfQueue;
}
@SuppressWarnings("DuplicateThrows")
private Optional<EnvelopeResponse> waitForSingleMessage(long timeout)
throws TimeoutException, WebSocketUnavailableException, IOException
{
while (true) {
WebSocketRequestMessage request = getWebSocket().readRequest(timeout);
if (isSignalServiceEnvelope(request)) {
return Optional.of(requestToEnvelopeResponse(request));
} else if (isSocketEmptyRequest(request)) {
return Optional.empty();
}
}
}
private static EnvelopeResponse requestToEnvelopeResponse(WebSocketRequestMessage request)
throws InvalidProtocolBufferException
{
Optional<String> timestampHeader = findHeader(request);
long timestamp = 0;
if (timestampHeader.isPresent()) {
try {
timestamp = Long.parseLong(timestampHeader.get());
} catch (NumberFormatException e) {
Log.w(TAG, "Failed to parse " + SERVER_DELIVERED_TIMESTAMP_HEADER);
}
}
SignalServiceProtos.Envelope envelope = SignalServiceProtos.Envelope.parseFrom(request.getBody().toByteArray());
return new EnvelopeResponse(envelope, timestamp, request);
}
private static boolean isSignalServiceEnvelope(WebSocketRequestMessage message) {
@ -323,6 +371,6 @@ public final class SignalWebSocket {
public interface MessageReceivedCallback {
/** True if you successfully processed the message, otherwise false. **/
boolean onMessage(SignalServiceProtos.Envelope envelope, long serverDeliveredTimestamp);
boolean onMessageBatch(List<EnvelopeResponse> envelopeResponses);
}
}

View file

@ -0,0 +1,13 @@
package org.whispersystems.signalservice.api.messages
import org.whispersystems.signalservice.internal.push.SignalServiceProtos.Envelope
import org.whispersystems.signalservice.internal.websocket.WebSocketProtos.WebSocketRequestMessage
/**
* Represents an envelope off the wire, paired with the metadata needed to process it.
*/
class EnvelopeResponse(
val envelope: Envelope,
val serverDeliveredTimestamp: Long,
val websocketRequest: WebSocketRequestMessage
)

View file

@ -193,6 +193,14 @@ public class WebSocketConnection extends WebSocketListener {
notifyAll();
}
public synchronized Optional<WebSocketRequestMessage> readRequestIfAvailable() {
if (incomingRequests.size() > 0) {
return Optional.of(incomingRequests.removeFirst());
} else {
return Optional.empty();
}
}
public synchronized WebSocketRequestMessage readRequest(long timeoutMillis)
throws TimeoutException, IOException
{