diff --git a/app/src/androidTest/java/org/thoughtcrime/securesms/testing/AliceClient.kt b/app/src/androidTest/java/org/thoughtcrime/securesms/testing/AliceClient.kt index 5847e42995..4ce8c37580 100644 --- a/app/src/androidTest/java/org/thoughtcrime/securesms/testing/AliceClient.kt +++ b/app/src/androidTest/java/org/thoughtcrime/securesms/testing/AliceClient.kt @@ -6,6 +6,7 @@ import org.signal.libsignal.zkgroup.profiles.ProfileKey import org.thoughtcrime.securesms.crypto.ProfileKeyUtil import org.thoughtcrime.securesms.dependencies.ApplicationDependencies import org.thoughtcrime.securesms.keyvalue.SignalStore +import org.thoughtcrime.securesms.messages.protocol.BufferedProtocolStore import org.thoughtcrime.securesms.recipients.Recipient import org.thoughtcrime.securesms.testing.FakeClientHelpers.toEnvelope import org.whispersystems.signalservice.api.push.ServiceId @@ -35,7 +36,9 @@ class AliceClient(val serviceId: ServiceId, val e164: String, val trustRoot: ECK fun process(envelope: Envelope, serverDeliveredTimestamp: Long) { val start = System.currentTimeMillis() - ApplicationDependencies.getIncomingMessageObserver().processEnvelope(envelope, serverDeliveredTimestamp) + val bufferedStore = BufferedProtocolStore.create() + ApplicationDependencies.getIncomingMessageObserver().processEnvelope(bufferedStore, envelope, serverDeliveredTimestamp) + bufferedStore.flushToDisk() val end = System.currentTimeMillis() Log.d(TAG, "${end - start}") } diff --git a/app/src/main/java/org/thoughtcrime/securesms/jobs/PushDecryptMessageJob.kt b/app/src/main/java/org/thoughtcrime/securesms/jobs/PushDecryptMessageJob.kt index b9d2826bcf..ebc77e6dd9 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/jobs/PushDecryptMessageJob.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/jobs/PushDecryptMessageJob.kt @@ -14,6 +14,7 @@ import org.thoughtcrime.securesms.keyvalue.SignalStore import org.thoughtcrime.securesms.messages.MessageContentProcessor.ExceptionMetadata import org.thoughtcrime.securesms.messages.MessageContentProcessor.MessageState import org.thoughtcrime.securesms.messages.MessageDecryptor +import org.thoughtcrime.securesms.messages.protocol.BufferedProtocolStore import org.thoughtcrime.securesms.notifications.NotificationChannels import org.thoughtcrime.securesms.notifications.NotificationIds import org.thoughtcrime.securesms.transport.RetryLaterException @@ -77,7 +78,9 @@ class PushDecryptMessageJob private constructor( throw RetryLaterException() } - val result = MessageDecryptor.decrypt(context, envelope.proto, envelope.serverDeliveredTimestamp) + val bufferedProtocolStore = BufferedProtocolStore.create() + val result = MessageDecryptor.decrypt(context, bufferedProtocolStore, envelope.proto, envelope.serverDeliveredTimestamp) + bufferedProtocolStore.flushToDisk() when (result) { is MessageDecryptor.Result.Success -> { diff --git a/app/src/main/java/org/thoughtcrime/securesms/jobs/SendDeliveryReceiptJob.java b/app/src/main/java/org/thoughtcrime/securesms/jobs/SendDeliveryReceiptJob.java index bdc9cf6b7a..868634a6c5 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/jobs/SendDeliveryReceiptJob.java +++ b/app/src/main/java/org/thoughtcrime/securesms/jobs/SendDeliveryReceiptJob.java @@ -11,6 +11,7 @@ import org.thoughtcrime.securesms.database.model.MessageId; import org.thoughtcrime.securesms.dependencies.ApplicationDependencies; import org.thoughtcrime.securesms.jobmanager.Data; import org.thoughtcrime.securesms.jobmanager.Job; +import org.thoughtcrime.securesms.jobmanager.impl.DecryptionsDrainedConstraint; import org.thoughtcrime.securesms.jobmanager.impl.NetworkConstraint; import org.thoughtcrime.securesms.net.NotPushRegisteredException; import org.thoughtcrime.securesms.recipients.Recipient; @@ -51,6 +52,7 @@ public class SendDeliveryReceiptJob extends BaseJob { public SendDeliveryReceiptJob(@NonNull RecipientId recipientId, long messageSentTimestamp, @NonNull MessageId messageId) { this(new Job.Parameters.Builder() .addConstraint(NetworkConstraint.KEY) + .addConstraint(DecryptionsDrainedConstraint.KEY) .setLifespan(TimeUnit.DAYS.toMillis(1)) .setMaxAttempts(Parameters.UNLIMITED) .setQueue(recipientId.toQueueKey()) diff --git a/app/src/main/java/org/thoughtcrime/securesms/jobs/SendReadReceiptJob.java b/app/src/main/java/org/thoughtcrime/securesms/jobs/SendReadReceiptJob.java index 1f28f16d6f..5561ffb4d4 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/jobs/SendReadReceiptJob.java +++ b/app/src/main/java/org/thoughtcrime/securesms/jobs/SendReadReceiptJob.java @@ -16,6 +16,7 @@ import org.thoughtcrime.securesms.dependencies.ApplicationDependencies; import org.thoughtcrime.securesms.jobmanager.Data; import org.thoughtcrime.securesms.jobmanager.Job; import org.thoughtcrime.securesms.jobmanager.JobManager; +import org.thoughtcrime.securesms.jobmanager.impl.DecryptionsDrainedConstraint; import org.thoughtcrime.securesms.jobmanager.impl.NetworkConstraint; import org.thoughtcrime.securesms.net.NotPushRegisteredException; import org.thoughtcrime.securesms.recipients.Recipient; @@ -65,6 +66,7 @@ public class SendReadReceiptJob extends BaseJob { public SendReadReceiptJob(long threadId, @NonNull RecipientId recipientId, List messageSentTimestamps, List messageIds) { this(new Job.Parameters.Builder() .addConstraint(NetworkConstraint.KEY) + .addConstraint(DecryptionsDrainedConstraint.KEY) .setLifespan(TimeUnit.DAYS.toMillis(1)) .setMaxAttempts(Parameters.UNLIMITED) .setQueue(recipientId.toQueueKey()) diff --git a/app/src/main/java/org/thoughtcrime/securesms/logging/PersistentLogger.kt b/app/src/main/java/org/thoughtcrime/securesms/logging/PersistentLogger.kt index 62a675efdf..c884de0543 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/logging/PersistentLogger.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/logging/PersistentLogger.kt @@ -72,7 +72,7 @@ class PersistentLogger( } private fun write(level: String, tag: String?, message: String?, t: Throwable?, keepLonger: Boolean) { - logEntries.add(LogRequest(level, tag ?: "null", message, Date(), getThreadString(), t, keepLonger)) + logEntries.add(LogRequest(level, tag ?: "null", message, System.currentTimeMillis(), getThreadString(), t, keepLonger)) } private fun getThreadString(): String { @@ -95,7 +95,7 @@ class PersistentLogger( val level: String, val tag: String, val message: String?, - val date: Date, + val createTime: Long, val threadString: String, val throwable: Throwable?, val keepLonger: Boolean @@ -121,11 +121,13 @@ class PersistentLogger( fun requestToEntries(request: LogRequest): List { val out = mutableListOf() + val createDate = Date(request.createTime) + out.add( LogEntry( - createdAt = request.date.time, + createdAt = request.createTime, keepLonger = request.keepLonger, - body = formatBody(request.threadString, request.date, request.level, request.tag, request.message) + body = formatBody(request.threadString, createDate, request.level, request.tag, request.message) ) ) @@ -138,9 +140,9 @@ class PersistentLogger( val entries = lines.map { line -> LogEntry( - createdAt = request.date.time, + createdAt = request.createTime, keepLonger = request.keepLonger, - body = formatBody(request.threadString, request.date, request.level, request.tag, line) + body = formatBody(request.threadString, createDate, request.level, request.tag, line) ) } diff --git a/app/src/main/java/org/thoughtcrime/securesms/messages/IncomingMessageObserver.kt b/app/src/main/java/org/thoughtcrime/securesms/messages/IncomingMessageObserver.kt index 0ec72b8181..9c3d100250 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/messages/IncomingMessageObserver.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/messages/IncomingMessageObserver.kt @@ -14,7 +14,9 @@ import androidx.core.app.NotificationCompat import org.signal.core.util.ThreadUtil import org.signal.core.util.concurrent.SignalExecutors import org.signal.core.util.logging.Log +import org.signal.core.util.withinTransaction import org.thoughtcrime.securesms.R +import org.thoughtcrime.securesms.crypto.ReentrantSessionLock import org.thoughtcrime.securesms.database.MessageTable import org.thoughtcrime.securesms.database.SignalDatabase import org.thoughtcrime.securesms.dependencies.ApplicationDependencies @@ -28,6 +30,7 @@ import org.thoughtcrime.securesms.jobs.PushDecryptMessageJob import org.thoughtcrime.securesms.jobs.PushProcessMessageJob import org.thoughtcrime.securesms.jobs.UnableToStartException import org.thoughtcrime.securesms.keyvalue.SignalStore +import org.thoughtcrime.securesms.messages.protocol.BufferedProtocolStore import org.thoughtcrime.securesms.notifications.NotificationChannels import org.thoughtcrime.securesms.recipients.RecipientId import org.thoughtcrime.securesms.util.AppForegroundObserver @@ -200,7 +203,7 @@ class IncomingMessageObserver(private val context: Application) { val needsConnectionString = if (conclusion) "Needs Connection" else "Does Not Need Connection" - Log.d(TAG, "[$needsConnectionString] Network: $hasNetwork, Foreground: $appVisible, Time Since Last Interaction: $lastInteractionString, FCM: $fcmEnabled, Stay open requests: [${keepAliveTokens.entries}], Registered: $registered, Proxy: $hasProxy, Force websocket: $forceWebsocket, Decrypt Queue Empty: $decryptQueueEmpty") + Log.d(TAG, "[$needsConnectionString] Network: $hasNetwork, Foreground: $appVisible, Time Since Last Interaction: $lastInteractionString, FCM: $fcmEnabled, Stay open requests: ${keepAliveTokens.entries}, Registered: $registered, Proxy: $hasProxy, Force websocket: $forceWebsocket, Decrypt Queue Empty: $decryptQueueEmpty") return conclusion } } @@ -249,19 +252,29 @@ class IncomingMessageObserver(private val context: Application) { } @VisibleForTesting - fun processEnvelope(envelope: SignalServiceProtos.Envelope, serverDeliveredTimestamp: Long) { - when (envelope.type.number) { - SignalServiceProtos.Envelope.Type.RECEIPT_VALUE -> processReceipt(envelope) + fun processEnvelope(bufferedProtocolStore: BufferedProtocolStore, envelope: SignalServiceProtos.Envelope, serverDeliveredTimestamp: Long): List? { + return when (envelope.type.number) { + SignalServiceProtos.Envelope.Type.RECEIPT_VALUE -> { + processReceipt(envelope) + null + } + SignalServiceProtos.Envelope.Type.PREKEY_BUNDLE_VALUE, SignalServiceProtos.Envelope.Type.CIPHERTEXT_VALUE, SignalServiceProtos.Envelope.Type.UNIDENTIFIED_SENDER_VALUE, - SignalServiceProtos.Envelope.Type.PLAINTEXT_CONTENT_VALUE -> processMessage(envelope, serverDeliveredTimestamp) - else -> Log.w(TAG, "Received envelope of unknown type: " + envelope.type) + SignalServiceProtos.Envelope.Type.PLAINTEXT_CONTENT_VALUE -> { + processMessage(bufferedProtocolStore, envelope, serverDeliveredTimestamp) + } + + else -> { + Log.w(TAG, "Received envelope of unknown type: " + envelope.type) + null + } } } - private fun processMessage(envelope: SignalServiceProtos.Envelope, serverDeliveredTimestamp: Long) { - val result = MessageDecryptor.decrypt(context, envelope, serverDeliveredTimestamp) + private fun processMessage(bufferedProtocolStore: BufferedProtocolStore, envelope: SignalServiceProtos.Envelope, serverDeliveredTimestamp: Long): List { + val result = MessageDecryptor.decrypt(context, bufferedProtocolStore, envelope, serverDeliveredTimestamp) when (result) { is MessageDecryptor.Result.Success -> { @@ -297,7 +310,7 @@ class IncomingMessageObserver(private val context: Application) { } } - result.followUpOperations.forEach { it.run() } + return result.followUpOperations } private fun processReceipt(envelope: SignalServiceProtos.Envelope) { @@ -386,13 +399,31 @@ class IncomingMessageObserver(private val context: Application) { signalWebSocket.connect() try { + val bufferedStore = BufferedProtocolStore.create() + while (isConnectionNecessary()) { try { Log.d(TAG, "Reading message...") - val hasMore = signalWebSocket.readMessage(WEBSOCKET_READ_TIMEOUT) { envelope, serverDeliveredTimestamp -> - Log.i(TAG, "Retrieved envelope! " + envelope.timestamp) - processEnvelope(envelope, serverDeliveredTimestamp) + val hasMore = signalWebSocket.readMessageBatch(WEBSOCKET_READ_TIMEOUT, 30) { batch -> + Log.i(TAG, "Retrieved ${batch.size} envelopes!") + + val startTime = System.currentTimeMillis() + ReentrantSessionLock.INSTANCE.acquire().use { + SignalDatabase.rawDatabase.withinTransaction { + val followUpOperations = batch + .mapNotNull { processEnvelope(bufferedStore, it.envelope, it.serverDeliveredTimestamp) } + .flatten() + + bufferedStore.flushToDisk() + + followUpOperations.forEach { it.run() } + } + } + + val duration = System.currentTimeMillis() - startTime + Log.d(TAG, "Decrypted ${batch.size} envelopes in $duration ms (~${duration / batch.size} ms per message)") + true } diff --git a/app/src/main/java/org/thoughtcrime/securesms/messages/MessageDecryptor.kt b/app/src/main/java/org/thoughtcrime/securesms/messages/MessageDecryptor.kt index 79b1b3c9ef..f2425dcd71 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/messages/MessageDecryptor.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/messages/MessageDecryptor.kt @@ -38,13 +38,13 @@ import org.thoughtcrime.securesms.jobs.PreKeysSyncJob import org.thoughtcrime.securesms.jobs.SendRetryReceiptJob import org.thoughtcrime.securesms.keyvalue.SignalStore import org.thoughtcrime.securesms.logsubmit.SubmitDebugLogActivity +import org.thoughtcrime.securesms.messages.protocol.BufferedProtocolStore import org.thoughtcrime.securesms.notifications.NotificationChannels import org.thoughtcrime.securesms.notifications.NotificationIds import org.thoughtcrime.securesms.recipients.Recipient import org.thoughtcrime.securesms.recipients.RecipientId import org.thoughtcrime.securesms.util.FeatureFlags import org.whispersystems.signalservice.api.InvalidMessageStructureException -import org.whispersystems.signalservice.api.SignalServiceAccountDataStore import org.whispersystems.signalservice.api.crypto.ContentHint import org.whispersystems.signalservice.api.crypto.EnvelopeMetadata import org.whispersystems.signalservice.api.crypto.SignalServiceCipher @@ -74,7 +74,12 @@ object MessageDecryptor { * To keep that property, there may be [Result.followUpOperations] you have to perform after your transaction is committed. * These can vary from enqueueing jobs to inserting items into the [org.thoughtcrime.securesms.database.PendingRetryReceiptCache]. */ - fun decrypt(context: Context, envelope: Envelope, serverDeliveredTimestamp: Long): Result { + fun decrypt( + context: Context, + bufferedProtocolStore: BufferedProtocolStore, + envelope: Envelope, + serverDeliveredTimestamp: Long + ): Result { val selfAci: ServiceId = SignalStore.account().requireAci() val selfPni: ServiceId = SignalStore.account().requirePni() @@ -106,9 +111,9 @@ object MessageDecryptor { } } - val protocolStore: SignalServiceAccountDataStore = ApplicationDependencies.getProtocolStore().get(destination) + val bufferedStore = bufferedProtocolStore.get(destination) val localAddress = SignalServiceAddress(selfAci, SignalStore.account().e164) - val cipher = SignalServiceCipher(localAddress, SignalStore.account().deviceId, protocolStore, ReentrantSessionLock.INSTANCE, UnidentifiedAccessUtil.getCertificateValidator()) + val cipher = SignalServiceCipher(localAddress, SignalStore.account().deviceId, bufferedStore, ReentrantSessionLock.INSTANCE, UnidentifiedAccessUtil.getCertificateValidator()) return try { val cipherResult: SignalServiceCipherResult? = cipher.decrypt(envelope, serverDeliveredTimestamp) diff --git a/app/src/main/java/org/thoughtcrime/securesms/messages/protocol/BufferedIdentityKeyStore.kt b/app/src/main/java/org/thoughtcrime/securesms/messages/protocol/BufferedIdentityKeyStore.kt new file mode 100644 index 0000000000..79914ed550 --- /dev/null +++ b/app/src/main/java/org/thoughtcrime/securesms/messages/protocol/BufferedIdentityKeyStore.kt @@ -0,0 +1,80 @@ +package org.thoughtcrime.securesms.messages.protocol + +import org.signal.libsignal.protocol.IdentityKey +import org.signal.libsignal.protocol.IdentityKeyPair +import org.signal.libsignal.protocol.SignalProtocolAddress +import org.signal.libsignal.protocol.state.IdentityKeyStore +import org.thoughtcrime.securesms.database.SignalDatabase +import org.whispersystems.signalservice.api.SignalServiceAccountDataStore +import org.whispersystems.signalservice.api.push.ServiceId + +/** + * An in-memory identity key store that is intended to be used temporarily while decrypting messages. + */ +class BufferedIdentityKeyStore( + private val selfServiceId: ServiceId, + private val selfIdentityKeyPair: IdentityKeyPair, + private val selfRegistrationId: Int +) : IdentityKeyStore { + + private val store: MutableMap = HashMap() + + /** All of the keys that have been created or updated during operation. */ + private val updatedKeys: MutableMap = mutableMapOf() + + override fun getIdentityKeyPair(): IdentityKeyPair { + return selfIdentityKeyPair + } + + override fun getLocalRegistrationId(): Int { + return selfRegistrationId + } + + override fun saveIdentity(address: SignalProtocolAddress, identityKey: IdentityKey): Boolean { + val existing: IdentityKey? = getIdentity(address) + + store[address] = identityKey + + return if (identityKey != existing) { + updatedKeys[address] = identityKey + true + } else { + false + } + } + + override fun isTrustedIdentity(address: SignalProtocolAddress, identityKey: IdentityKey, direction: IdentityKeyStore.Direction): Boolean { + if (address.name == selfServiceId.toString()) { + return identityKey == selfIdentityKeyPair.publicKey + } + + return when (direction) { + IdentityKeyStore.Direction.RECEIVING -> true + IdentityKeyStore.Direction.SENDING -> error("Should not happen during the intended usage pattern of this class") + else -> error("Unknown direction: $direction") + } + } + + override fun getIdentity(address: SignalProtocolAddress): IdentityKey? { + val cached = store[address] + + return if (cached != null) { + cached + } else { + val fromDatabase = SignalDatabase.identities.getIdentityStoreRecord(address.name) + if (fromDatabase != null) { + store[address] = fromDatabase.identityKey + } + + fromDatabase?.identityKey + } + } + + fun flushToDisk(persistentStore: SignalServiceAccountDataStore) { + for ((address, identityKey) in updatedKeys) { + persistentStore.saveIdentity(address, identityKey) + } + + updatedKeys.clear() + } +} diff --git a/app/src/main/java/org/thoughtcrime/securesms/messages/protocol/BufferedOneTimePreKeyStore.kt b/app/src/main/java/org/thoughtcrime/securesms/messages/protocol/BufferedOneTimePreKeyStore.kt new file mode 100644 index 0000000000..566969e675 --- /dev/null +++ b/app/src/main/java/org/thoughtcrime/securesms/messages/protocol/BufferedOneTimePreKeyStore.kt @@ -0,0 +1,49 @@ +package org.thoughtcrime.securesms.messages.protocol + +import org.signal.libsignal.protocol.InvalidKeyIdException +import org.signal.libsignal.protocol.state.PreKeyRecord +import org.signal.libsignal.protocol.state.PreKeyStore +import org.thoughtcrime.securesms.database.SignalDatabase +import org.whispersystems.signalservice.api.SignalServiceAccountDataStore +import org.whispersystems.signalservice.api.push.ServiceId + +/** + * An in-memory one-time prekey store that is intended to be used temporarily while decrypting messages. + */ +class BufferedOneTimePreKeyStore(private val selfServiceId: ServiceId) : PreKeyStore { + + /** Our in-memory cache of one-time prekeys. */ + private val store: MutableMap = HashMap() + + /** The one-time prekeys that have been marked as removed */ + private val removed: MutableList = mutableListOf() + + @kotlin.jvm.Throws(InvalidKeyIdException::class) + override fun loadPreKey(id: Int): PreKeyRecord { + return store.computeIfAbsent(id) { + SignalDatabase.oneTimePreKeys.get(selfServiceId, id) ?: throw InvalidKeyIdException("Missing one-time prekey with ID: $id") + } + } + + override fun storePreKey(id: Int, record: PreKeyRecord) { + error("Should not happen during the intended usage pattern of this class") + } + + override fun containsPreKey(id: Int): Boolean { + loadPreKey(id) + return store.containsKey(id) + } + + override fun removePreKey(id: Int) { + store.remove(id) + removed += id + } + + fun flushToDisk(persistentStore: SignalServiceAccountDataStore) { + for (id in removed) { + persistentStore.removePreKey(id) + } + + removed.clear() + } +} diff --git a/app/src/main/java/org/thoughtcrime/securesms/messages/protocol/BufferedProtocolStore.kt b/app/src/main/java/org/thoughtcrime/securesms/messages/protocol/BufferedProtocolStore.kt new file mode 100644 index 0000000000..11faf42f73 --- /dev/null +++ b/app/src/main/java/org/thoughtcrime/securesms/messages/protocol/BufferedProtocolStore.kt @@ -0,0 +1,46 @@ +package org.thoughtcrime.securesms.messages.protocol + +import org.thoughtcrime.securesms.dependencies.ApplicationDependencies +import org.thoughtcrime.securesms.keyvalue.SignalStore +import org.whispersystems.signalservice.api.push.ServiceId + +/** + * The entry point for creating and retrieving buffered protocol stores. + * These stores will read from disk, but never write, instead buffering the results in memory. + * You can then call [flushToDisk] in order to write the buffered results to disk. + * + * This allows you to efficiently do batches of work and avoid unnecessary intermediate writes. + */ +class BufferedProtocolStore private constructor( + private val aciStore: Pair, + private val pniStore: Pair +) { + + fun get(serviceId: ServiceId): BufferedSignalServiceAccountDataStore { + return when (serviceId) { + aciStore.first -> aciStore.second + pniStore.first -> pniStore.second + else -> error("No store matching serviceId $serviceId") + } + } + + /** + * Writes any buffered data to disk. You can continue to use the same buffered store afterwards. + */ + fun flushToDisk() { + aciStore.second.flushToDisk(ApplicationDependencies.getProtocolStore().aci()) + pniStore.second.flushToDisk(ApplicationDependencies.getProtocolStore().pni()) + } + + companion object { + fun create(): BufferedProtocolStore { + val aci = SignalStore.account().requireAci() + val pni = SignalStore.account().requirePni() + + return BufferedProtocolStore( + aciStore = aci to BufferedSignalServiceAccountDataStore(aci), + pniStore = pni to BufferedSignalServiceAccountDataStore(pni) + ) + } + } +} diff --git a/app/src/main/java/org/thoughtcrime/securesms/messages/protocol/BufferedSenderKeyStore.kt b/app/src/main/java/org/thoughtcrime/securesms/messages/protocol/BufferedSenderKeyStore.kt new file mode 100644 index 0000000000..c63bde627c --- /dev/null +++ b/app/src/main/java/org/thoughtcrime/securesms/messages/protocol/BufferedSenderKeyStore.kt @@ -0,0 +1,75 @@ +package org.thoughtcrime.securesms.messages.protocol + +import org.signal.libsignal.protocol.SignalProtocolAddress +import org.signal.libsignal.protocol.groups.state.SenderKeyRecord +import org.thoughtcrime.securesms.database.SignalDatabase +import org.whispersystems.signalservice.api.SignalServiceAccountDataStore +import org.whispersystems.signalservice.api.SignalServiceSenderKeyStore +import org.whispersystems.signalservice.api.push.DistributionId +import java.util.UUID + +/** + * An in-memory sender key store that is intended to be used temporarily while decrypting messages. + */ +class BufferedSenderKeyStore : SignalServiceSenderKeyStore { + + private val store: MutableMap = HashMap() + + /** All of the keys that have been created or updated during operation. */ + private val updatedKeys: MutableMap = mutableMapOf() + + /** All of the distributionId's whose sharing has been cleared during operation. */ + private val clearSharedWith: MutableSet = mutableSetOf() + + override fun storeSenderKey(sender: SignalProtocolAddress, distributionId: UUID, record: SenderKeyRecord) { + val key = StoreKey(sender, distributionId) + store[key] = record + updatedKeys[key] = record + } + + override fun loadSenderKey(sender: SignalProtocolAddress, distributionId: UUID): SenderKeyRecord? { + val cached: SenderKeyRecord? = store[StoreKey(sender, distributionId)] + + return if (cached != null) { + cached + } else { + val fromDatabase: SenderKeyRecord? = SignalDatabase.senderKeys.load(sender, distributionId.toDistributionId()) + + if (fromDatabase != null) { + store[StoreKey(sender, distributionId)] = fromDatabase + } + + return fromDatabase + } + } + + override fun clearSenderKeySharedWith(addresses: MutableCollection) { + clearSharedWith.addAll(addresses) + } + + override fun getSenderKeySharedWith(distributionId: DistributionId?): MutableSet { + error("Should not happen during the intended usage pattern of this class") + } + + override fun markSenderKeySharedWith(distributionId: DistributionId?, addresses: MutableCollection?) { + error("Should not happen during the intended usage pattern of this class") + } + + fun flushToDisk(persistentStore: SignalServiceAccountDataStore) { + for ((key, record) in updatedKeys) { + persistentStore.storeSenderKey(key.address, key.distributionId, record) + } + + persistentStore.clearSenderKeySharedWith(clearSharedWith) + + updatedKeys.clear() + clearSharedWith.clear() + } + + private fun UUID.toDistributionId() = DistributionId.from(this) + + data class StoreKey( + val address: SignalProtocolAddress, + val distributionId: UUID + ) +} diff --git a/app/src/main/java/org/thoughtcrime/securesms/messages/protocol/BufferedSessionStore.kt b/app/src/main/java/org/thoughtcrime/securesms/messages/protocol/BufferedSessionStore.kt new file mode 100644 index 0000000000..7f15a9b845 --- /dev/null +++ b/app/src/main/java/org/thoughtcrime/securesms/messages/protocol/BufferedSessionStore.kt @@ -0,0 +1,115 @@ +package org.thoughtcrime.securesms.messages.protocol + +import org.signal.libsignal.protocol.NoSessionException +import org.signal.libsignal.protocol.SignalProtocolAddress +import org.signal.libsignal.protocol.message.CiphertextMessage +import org.signal.libsignal.protocol.state.SessionRecord +import org.thoughtcrime.securesms.database.SignalDatabase +import org.whispersystems.signalservice.api.SignalServiceAccountDataStore +import org.whispersystems.signalservice.api.SignalServiceSessionStore +import org.whispersystems.signalservice.api.push.ServiceId +import kotlin.jvm.Throws + +/** + * An in-memory session store that is intended to be used temporarily while decrypting messages. + */ +class BufferedSessionStore(private val selfServiceId: ServiceId) : SignalServiceSessionStore { + + private val store: MutableMap = HashMap() + + /** All of the sessions that have been created or updated during operation. */ + private val updatedSessions: MutableMap = mutableMapOf() + + /** All of the sessions that have deleted during operation. */ + private val deletedSessions: MutableSet = mutableSetOf() + + override fun loadSession(address: SignalProtocolAddress): SessionRecord { + val session: SessionRecord = store[address] + ?: SignalDatabase.sessions.load(selfServiceId, address) + ?: SessionRecord() + + store[address] = session + + return session + } + + @Throws(NoSessionException::class) + override fun loadExistingSessions(addresses: MutableList): List { + val found: MutableList = mutableListOf() + val needsDatabaseLookup: MutableList = mutableListOf() + + for (address in addresses) { + val cached: SessionRecord? = store[address] + + if (cached != null) { + found += cached + } else { + needsDatabaseLookup += address + } + } + + if (needsDatabaseLookup.isNotEmpty()) { + found += SignalDatabase.sessions.load(selfServiceId, needsDatabaseLookup).filterNotNull() + } + + if (found.size != addresses.size) { + throw NoSessionException("Failed to find one or more sessions.") + } + + return found + } + + override fun storeSession(address: SignalProtocolAddress, record: SessionRecord) { + store[address] = record + updatedSessions[address] = record + } + + override fun containsSession(address: SignalProtocolAddress): Boolean { + return if (store.containsKey(address)) { + true + } else { + val fromDatabase: SessionRecord? = SignalDatabase.sessions.load(selfServiceId, address) + + if (fromDatabase != null) { + store[address] = fromDatabase + return fromDatabase.hasSenderChain() && fromDatabase.sessionVersion == CiphertextMessage.CURRENT_VERSION + } else { + false + } + } + } + + override fun deleteSession(address: SignalProtocolAddress) { + store.remove(address) + deletedSessions += address + } + + override fun getSubDeviceSessions(name: String): MutableList { + error("Should not happen during the intended usage pattern of this class") + } + + override fun deleteAllSessions(name: String) { + error("Should not happen during the intended usage pattern of this class") + } + + override fun archiveSession(address: SignalProtocolAddress?) { + error("Should not happen during the intended usage pattern of this class") + } + + override fun getAllAddressesWithActiveSessions(addressNames: MutableList): Set { + error("Should not happen during the intended usage pattern of this class") + } + + fun flushToDisk(persistentStore: SignalServiceAccountDataStore) { + for ((address, record) in updatedSessions) { + persistentStore.storeSession(address, record) + } + + for (address in deletedSessions) { + persistentStore.deleteSession(address) + } + + updatedSessions.clear() + deletedSessions.clear() + } +} diff --git a/app/src/main/java/org/thoughtcrime/securesms/messages/protocol/BufferedSignalServiceAccountDataStore.kt b/app/src/main/java/org/thoughtcrime/securesms/messages/protocol/BufferedSignalServiceAccountDataStore.kt new file mode 100644 index 0000000000..581837e025 --- /dev/null +++ b/app/src/main/java/org/thoughtcrime/securesms/messages/protocol/BufferedSignalServiceAccountDataStore.kt @@ -0,0 +1,157 @@ +package org.thoughtcrime.securesms.messages.protocol + +import org.signal.libsignal.protocol.IdentityKey +import org.signal.libsignal.protocol.IdentityKeyPair +import org.signal.libsignal.protocol.SignalProtocolAddress +import org.signal.libsignal.protocol.groups.state.SenderKeyRecord +import org.signal.libsignal.protocol.state.IdentityKeyStore +import org.signal.libsignal.protocol.state.PreKeyRecord +import org.signal.libsignal.protocol.state.SessionRecord +import org.signal.libsignal.protocol.state.SignedPreKeyRecord +import org.thoughtcrime.securesms.keyvalue.SignalStore +import org.whispersystems.signalservice.api.SignalServiceAccountDataStore +import org.whispersystems.signalservice.api.push.DistributionId +import org.whispersystems.signalservice.api.push.ServiceId +import java.util.UUID + +/** + * The wrapper around all of the Buffered protocol stores. Designed to perform operations in memory, + * then [flushToDisk] at set intervals. + */ +class BufferedSignalServiceAccountDataStore(selfServiceId: ServiceId) : SignalServiceAccountDataStore { + + private val identityStore: BufferedIdentityKeyStore = if (selfServiceId == SignalStore.account().pni) { + BufferedIdentityKeyStore(selfServiceId, SignalStore.account().pniIdentityKey, SignalStore.account().pniRegistrationId) + } else { + BufferedIdentityKeyStore(selfServiceId, SignalStore.account().aciIdentityKey, SignalStore.account().registrationId) + } + + private val oneTimePreKeyStore: BufferedOneTimePreKeyStore = BufferedOneTimePreKeyStore(selfServiceId) + private val signedPreKeyStore: BufferedSignedPreKeyStore = BufferedSignedPreKeyStore(selfServiceId) + private val sessionStore: BufferedSessionStore = BufferedSessionStore(selfServiceId) + private val senderKeyStore: BufferedSenderKeyStore = BufferedSenderKeyStore() + + override fun getIdentityKeyPair(): IdentityKeyPair { + return identityStore.identityKeyPair + } + + override fun getLocalRegistrationId(): Int { + return identityStore.localRegistrationId + } + + override fun saveIdentity(address: SignalProtocolAddress, identityKey: IdentityKey): Boolean { + return identityStore.saveIdentity(address, identityKey) + } + + override fun isTrustedIdentity(address: SignalProtocolAddress, identityKey: IdentityKey, direction: IdentityKeyStore.Direction): Boolean { + return identityStore.isTrustedIdentity(address, identityKey, direction) + } + + override fun getIdentity(address: SignalProtocolAddress): IdentityKey? { + return identityStore.getIdentity(address) + } + + override fun loadPreKey(preKeyId: Int): PreKeyRecord { + return oneTimePreKeyStore.loadPreKey(preKeyId) + } + + override fun storePreKey(preKeyId: Int, record: PreKeyRecord) { + return oneTimePreKeyStore.storePreKey(preKeyId, record) + } + + override fun containsPreKey(preKeyId: Int): Boolean { + return oneTimePreKeyStore.containsPreKey(preKeyId) + } + + override fun removePreKey(preKeyId: Int) { + oneTimePreKeyStore.removePreKey(preKeyId) + } + + override fun loadSession(address: SignalProtocolAddress): SessionRecord { + return sessionStore.loadSession(address) + } + + override fun loadExistingSessions(addresses: MutableList): List { + return sessionStore.loadExistingSessions(addresses) + } + + override fun getSubDeviceSessions(name: String): MutableList { + return sessionStore.getSubDeviceSessions(name) + } + + override fun storeSession(address: SignalProtocolAddress, record: SessionRecord) { + sessionStore.storeSession(address, record) + } + + override fun containsSession(address: SignalProtocolAddress): Boolean { + return sessionStore.containsSession(address) + } + + override fun deleteSession(address: SignalProtocolAddress) { + return sessionStore.deleteSession(address) + } + + override fun deleteAllSessions(name: String) { + sessionStore.deleteAllSessions(name) + } + + override fun loadSignedPreKey(signedPreKeyId: Int): SignedPreKeyRecord { + return signedPreKeyStore.loadSignedPreKey(signedPreKeyId) + } + + override fun loadSignedPreKeys(): List { + return signedPreKeyStore.loadSignedPreKeys() + } + + override fun storeSignedPreKey(signedPreKeyId: Int, record: SignedPreKeyRecord) { + signedPreKeyStore.storeSignedPreKey(signedPreKeyId, record) + } + + override fun containsSignedPreKey(signedPreKeyId: Int): Boolean { + return signedPreKeyStore.containsSignedPreKey(signedPreKeyId) + } + + override fun removeSignedPreKey(signedPreKeyId: Int) { + signedPreKeyStore.removeSignedPreKey(signedPreKeyId) + } + + override fun storeSenderKey(sender: SignalProtocolAddress, distributionId: UUID, record: SenderKeyRecord) { + senderKeyStore.storeSenderKey(sender, distributionId, record) + } + + override fun loadSenderKey(sender: SignalProtocolAddress, distributionId: UUID): SenderKeyRecord? { + return senderKeyStore.loadSenderKey(sender, distributionId) + } + + override fun archiveSession(address: SignalProtocolAddress?) { + sessionStore.archiveSession(address) + } + + override fun getAllAddressesWithActiveSessions(addressNames: MutableList): Set { + return sessionStore.getAllAddressesWithActiveSessions(addressNames) + } + + override fun getSenderKeySharedWith(distributionId: DistributionId?): MutableSet { + return senderKeyStore.getSenderKeySharedWith(distributionId) + } + + override fun markSenderKeySharedWith(distributionId: DistributionId, addresses: MutableCollection) { + senderKeyStore.markSenderKeySharedWith(distributionId, addresses) + } + + override fun clearSenderKeySharedWith(addresses: MutableCollection) { + senderKeyStore.clearSenderKeySharedWith(addresses) + } + + override fun isMultiDevice(): Boolean { + error("Should not happen during the intended usage pattern of this class") + } + + fun flushToDisk(persistentStore: SignalServiceAccountDataStore) { + identityStore.flushToDisk(persistentStore) + oneTimePreKeyStore.flushToDisk(persistentStore) + signedPreKeyStore.flushToDisk(persistentStore) + sessionStore.flushToDisk(persistentStore) + senderKeyStore.flushToDisk(persistentStore) + } +} diff --git a/app/src/main/java/org/thoughtcrime/securesms/messages/protocol/BufferedSignedPreKeyStore.kt b/app/src/main/java/org/thoughtcrime/securesms/messages/protocol/BufferedSignedPreKeyStore.kt new file mode 100644 index 0000000000..b42a52075e --- /dev/null +++ b/app/src/main/java/org/thoughtcrime/securesms/messages/protocol/BufferedSignedPreKeyStore.kt @@ -0,0 +1,64 @@ +package org.thoughtcrime.securesms.messages.protocol + +import org.signal.libsignal.protocol.InvalidKeyIdException +import org.signal.libsignal.protocol.state.SignedPreKeyRecord +import org.signal.libsignal.protocol.state.SignedPreKeyStore +import org.thoughtcrime.securesms.database.SignalDatabase +import org.whispersystems.signalservice.api.SignalServiceAccountDataStore +import org.whispersystems.signalservice.api.push.ServiceId + +/** + * An in-memory signed prekey store that is intended to be used temporarily while decrypting messages. + */ +class BufferedSignedPreKeyStore(private val selfServiceId: ServiceId) : SignedPreKeyStore { + + /** Our in-memory cache of signed prekeys. */ + private val store: MutableMap = HashMap() + + /** The signed prekeys that have been marked as removed */ + private val removed: MutableList = mutableListOf() + + /** Whether or not we've done a loadAll operation. Let's us avoid doing it twice. */ + private var hasLoadedAll: Boolean = false + + @kotlin.jvm.Throws(InvalidKeyIdException::class) + override fun loadSignedPreKey(id: Int): SignedPreKeyRecord { + return store.computeIfAbsent(id) { + SignalDatabase.signedPreKeys.get(selfServiceId, id) ?: throw InvalidKeyIdException("Missing one-time prekey with ID: $id") + } + } + + override fun loadSignedPreKeys(): List { + return if (hasLoadedAll) { + store.values.toList() + } else { + val records = SignalDatabase.signedPreKeys.getAll(selfServiceId) + records.forEach { store[it.id] = it } + hasLoadedAll = true + + records + } + } + + override fun storeSignedPreKey(id: Int, record: SignedPreKeyRecord) { + error("Should not happen during the intended usage pattern of this class") + } + + override fun containsSignedPreKey(id: Int): Boolean { + loadSignedPreKey(id) + return store.containsKey(id) + } + + override fun removeSignedPreKey(id: Int) { + store.remove(id) + removed += id + } + + fun flushToDisk(persistentStore: SignalServiceAccountDataStore) { + for (id in removed) { + persistentStore.removeSignedPreKey(id) + } + + removed.clear() + } +} diff --git a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/SignalWebSocket.java b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/SignalWebSocket.java index 1e778c66c0..a3ba6a24d6 100644 --- a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/SignalWebSocket.java +++ b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/SignalWebSocket.java @@ -1,7 +1,10 @@ package org.whispersystems.signalservice.api; +import com.google.protobuf.InvalidProtocolBufferException; + import org.signal.libsignal.protocol.logging.Log; import org.whispersystems.signalservice.api.crypto.UnidentifiedAccess; +import org.whispersystems.signalservice.api.messages.EnvelopeResponse; import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState; import org.whispersystems.signalservice.api.websocket.WebSocketFactory; import org.whispersystems.signalservice.api.websocket.WebSocketUnavailableException; @@ -13,9 +16,10 @@ import org.whispersystems.signalservice.internal.websocket.WebsocketResponse; import org.whispersystems.util.Base64; import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import java.util.Optional; import java.util.concurrent.TimeoutException; -import java.util.concurrent.atomic.AtomicBoolean; import io.reactivex.rxjava3.core.Observable; import io.reactivex.rxjava3.core.Single; @@ -218,61 +222,105 @@ public final class SignalWebSocket { } /** - *

- * A blocking call that reads a message off the pipe. When this call returns, if the callback indicates the - * message was successfully processed, then the message will be ack'ed on the serve and will not be retransmitted. - *

- * This will return true if there are more messages to be read from the websocket, or false if the websocket is empty. - *

- * You can specify a {@link MessageReceivedCallback} that will be called before the received message is acknowledged. - * This allows you to write the received message to durable storage before acknowledging receipt of it to the - * server. - *

- * Important: This will only return `false` once for each connection. That means if you get false call readMessage() - * again on the same instance, you will not get an immediate `false` return value, and instead will block until - * you get an actual message. This will, however, reset if connection breaks (if, for instance, you lose and regain network). + * The reads a batch of messages off of the websocket. * - * @param timeout The timeout to wait for. - * @param callback A callback that will be called before the message receipt is acknowledged to the server. - * @return The message read (same as the message sent through the callback). + * Rather than just provide you the batch as a return value, it will invoke the provided callback with the + * batch as an argument. If you are able to successfully process them, this method will then ack all of the + * messages so that they won't be re-delivered in the future. + * + * The return value of this method is a boolean indicating whether or not there are more messages in the + * queue to be read (true if there's still more, or false if you've drained everything). + * + * However, this return value is only really useful the first time you read from the websocket. That's because + * the websocket will only ever let you know if it's drained *once* for any given connection. So if this method + * returns false, a subsequent call while using the same websocket connection will simply block until we either + * get a new message or hit the timeout. + * + * Concerning the requested batch size, it's worth noting that this is simply an upper bound. This method will + * not wait extra time until the batch has "filled up". Instead, it will wait for a single message, and then + * take any extra messages that are also available up until you've hit your batch size. */ @SuppressWarnings("DuplicateThrows") - public boolean readMessage(long timeout, MessageReceivedCallback callback) + public boolean readMessageBatch(long timeout, int batchSize, MessageReceivedCallback callback) throws TimeoutException, WebSocketUnavailableException, IOException { - while (true) { - WebSocketRequestMessage request = getWebSocket().readRequest(timeout); - WebSocketResponseMessage response = createWebSocketResponse(request); + List responses = new ArrayList<>(); + boolean hitEndOfQueue = false; - AtomicBoolean successfullyProcessed = new AtomicBoolean(false); + Optional firstEnvelope = waitForSingleMessage(timeout); - try { - if (isSignalServiceEnvelope(request)) { - Optional timestampHeader = findHeader(request); - long timestamp = 0; + if (firstEnvelope.isPresent()) { + responses.add(firstEnvelope.get()); + } else { + hitEndOfQueue = true; + } - if (timestampHeader.isPresent()) { - try { - timestamp = Long.parseLong(timestampHeader.get()); - } catch (NumberFormatException e) { - Log.w(TAG, "Failed to parse " + SERVER_DELIVERED_TIMESTAMP_HEADER); - } + if (!hitEndOfQueue) { + for (int i = 1; i < batchSize; i++) { + Optional request = getWebSocket().readRequestIfAvailable(); + + if (request.isPresent()) { + if (isSignalServiceEnvelope(request.get())) { + responses.add(requestToEnvelopeResponse(request.get())); + } else if (isSocketEmptyRequest(request.get())) { + hitEndOfQueue = true; + break; } - - SignalServiceProtos.Envelope envelope = SignalServiceProtos.Envelope.parseFrom(request.getBody().toByteArray()); - - successfullyProcessed.set(callback.onMessage(envelope, timestamp)); - - return true; - } else if (isSocketEmptyRequest(request)) { - return false; - } - } finally { - if (successfullyProcessed.get()) { - getWebSocket().sendResponse(response); + } else { + break; } } } + + if (responses.size() > 0) { + boolean successfullyProcessed = false; + + try { + successfullyProcessed = callback.onMessageBatch(responses); + } finally { + if (successfullyProcessed) { + for (EnvelopeResponse response : responses) { + getWebSocket().sendResponse(createWebSocketResponse(response.getWebsocketRequest())); + } + } + } + } + + return !hitEndOfQueue; + } + + @SuppressWarnings("DuplicateThrows") + private Optional waitForSingleMessage(long timeout) + throws TimeoutException, WebSocketUnavailableException, IOException + { + while (true) { + WebSocketRequestMessage request = getWebSocket().readRequest(timeout); + + if (isSignalServiceEnvelope(request)) { + return Optional.of(requestToEnvelopeResponse(request)); + } else if (isSocketEmptyRequest(request)) { + return Optional.empty(); + } + } + } + + private static EnvelopeResponse requestToEnvelopeResponse(WebSocketRequestMessage request) + throws InvalidProtocolBufferException + { + Optional timestampHeader = findHeader(request); + long timestamp = 0; + + if (timestampHeader.isPresent()) { + try { + timestamp = Long.parseLong(timestampHeader.get()); + } catch (NumberFormatException e) { + Log.w(TAG, "Failed to parse " + SERVER_DELIVERED_TIMESTAMP_HEADER); + } + } + + SignalServiceProtos.Envelope envelope = SignalServiceProtos.Envelope.parseFrom(request.getBody().toByteArray()); + + return new EnvelopeResponse(envelope, timestamp, request); } private static boolean isSignalServiceEnvelope(WebSocketRequestMessage message) { @@ -323,6 +371,6 @@ public final class SignalWebSocket { public interface MessageReceivedCallback { /** True if you successfully processed the message, otherwise false. **/ - boolean onMessage(SignalServiceProtos.Envelope envelope, long serverDeliveredTimestamp); + boolean onMessageBatch(List envelopeResponses); } } diff --git a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/messages/EnvelopeResponse.kt b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/messages/EnvelopeResponse.kt new file mode 100644 index 0000000000..f086a3d477 --- /dev/null +++ b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/messages/EnvelopeResponse.kt @@ -0,0 +1,13 @@ +package org.whispersystems.signalservice.api.messages + +import org.whispersystems.signalservice.internal.push.SignalServiceProtos.Envelope +import org.whispersystems.signalservice.internal.websocket.WebSocketProtos.WebSocketRequestMessage + +/** + * Represents an envelope off the wire, paired with the metadata needed to process it. + */ +class EnvelopeResponse( + val envelope: Envelope, + val serverDeliveredTimestamp: Long, + val websocketRequest: WebSocketRequestMessage +) diff --git a/libsignal/service/src/main/java/org/whispersystems/signalservice/internal/websocket/WebSocketConnection.java b/libsignal/service/src/main/java/org/whispersystems/signalservice/internal/websocket/WebSocketConnection.java index 4c19b56357..dc6362d91a 100644 --- a/libsignal/service/src/main/java/org/whispersystems/signalservice/internal/websocket/WebSocketConnection.java +++ b/libsignal/service/src/main/java/org/whispersystems/signalservice/internal/websocket/WebSocketConnection.java @@ -193,6 +193,14 @@ public class WebSocketConnection extends WebSocketListener { notifyAll(); } + public synchronized Optional readRequestIfAvailable() { + if (incomingRequests.size() > 0) { + return Optional.of(incomingRequests.removeFirst()); + } else { + return Optional.empty(); + } + } + public synchronized WebSocketRequestMessage readRequest(long timeoutMillis) throws TimeoutException, IOException {