Perform message decryptions in batches.
This commit is contained in:
parent
04baa7925f
commit
894095414a
17 changed files with 772 additions and 69 deletions
|
@ -6,6 +6,7 @@ import org.signal.libsignal.zkgroup.profiles.ProfileKey
|
|||
import org.thoughtcrime.securesms.crypto.ProfileKeyUtil
|
||||
import org.thoughtcrime.securesms.dependencies.ApplicationDependencies
|
||||
import org.thoughtcrime.securesms.keyvalue.SignalStore
|
||||
import org.thoughtcrime.securesms.messages.protocol.BufferedProtocolStore
|
||||
import org.thoughtcrime.securesms.recipients.Recipient
|
||||
import org.thoughtcrime.securesms.testing.FakeClientHelpers.toEnvelope
|
||||
import org.whispersystems.signalservice.api.push.ServiceId
|
||||
|
@ -35,7 +36,9 @@ class AliceClient(val serviceId: ServiceId, val e164: String, val trustRoot: ECK
|
|||
|
||||
fun process(envelope: Envelope, serverDeliveredTimestamp: Long) {
|
||||
val start = System.currentTimeMillis()
|
||||
ApplicationDependencies.getIncomingMessageObserver().processEnvelope(envelope, serverDeliveredTimestamp)
|
||||
val bufferedStore = BufferedProtocolStore.create()
|
||||
ApplicationDependencies.getIncomingMessageObserver().processEnvelope(bufferedStore, envelope, serverDeliveredTimestamp)
|
||||
bufferedStore.flushToDisk()
|
||||
val end = System.currentTimeMillis()
|
||||
Log.d(TAG, "${end - start}")
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@ import org.thoughtcrime.securesms.keyvalue.SignalStore
|
|||
import org.thoughtcrime.securesms.messages.MessageContentProcessor.ExceptionMetadata
|
||||
import org.thoughtcrime.securesms.messages.MessageContentProcessor.MessageState
|
||||
import org.thoughtcrime.securesms.messages.MessageDecryptor
|
||||
import org.thoughtcrime.securesms.messages.protocol.BufferedProtocolStore
|
||||
import org.thoughtcrime.securesms.notifications.NotificationChannels
|
||||
import org.thoughtcrime.securesms.notifications.NotificationIds
|
||||
import org.thoughtcrime.securesms.transport.RetryLaterException
|
||||
|
@ -77,7 +78,9 @@ class PushDecryptMessageJob private constructor(
|
|||
throw RetryLaterException()
|
||||
}
|
||||
|
||||
val result = MessageDecryptor.decrypt(context, envelope.proto, envelope.serverDeliveredTimestamp)
|
||||
val bufferedProtocolStore = BufferedProtocolStore.create()
|
||||
val result = MessageDecryptor.decrypt(context, bufferedProtocolStore, envelope.proto, envelope.serverDeliveredTimestamp)
|
||||
bufferedProtocolStore.flushToDisk()
|
||||
|
||||
when (result) {
|
||||
is MessageDecryptor.Result.Success -> {
|
||||
|
|
|
@ -11,6 +11,7 @@ import org.thoughtcrime.securesms.database.model.MessageId;
|
|||
import org.thoughtcrime.securesms.dependencies.ApplicationDependencies;
|
||||
import org.thoughtcrime.securesms.jobmanager.Data;
|
||||
import org.thoughtcrime.securesms.jobmanager.Job;
|
||||
import org.thoughtcrime.securesms.jobmanager.impl.DecryptionsDrainedConstraint;
|
||||
import org.thoughtcrime.securesms.jobmanager.impl.NetworkConstraint;
|
||||
import org.thoughtcrime.securesms.net.NotPushRegisteredException;
|
||||
import org.thoughtcrime.securesms.recipients.Recipient;
|
||||
|
@ -51,6 +52,7 @@ public class SendDeliveryReceiptJob extends BaseJob {
|
|||
public SendDeliveryReceiptJob(@NonNull RecipientId recipientId, long messageSentTimestamp, @NonNull MessageId messageId) {
|
||||
this(new Job.Parameters.Builder()
|
||||
.addConstraint(NetworkConstraint.KEY)
|
||||
.addConstraint(DecryptionsDrainedConstraint.KEY)
|
||||
.setLifespan(TimeUnit.DAYS.toMillis(1))
|
||||
.setMaxAttempts(Parameters.UNLIMITED)
|
||||
.setQueue(recipientId.toQueueKey())
|
||||
|
|
|
@ -16,6 +16,7 @@ import org.thoughtcrime.securesms.dependencies.ApplicationDependencies;
|
|||
import org.thoughtcrime.securesms.jobmanager.Data;
|
||||
import org.thoughtcrime.securesms.jobmanager.Job;
|
||||
import org.thoughtcrime.securesms.jobmanager.JobManager;
|
||||
import org.thoughtcrime.securesms.jobmanager.impl.DecryptionsDrainedConstraint;
|
||||
import org.thoughtcrime.securesms.jobmanager.impl.NetworkConstraint;
|
||||
import org.thoughtcrime.securesms.net.NotPushRegisteredException;
|
||||
import org.thoughtcrime.securesms.recipients.Recipient;
|
||||
|
@ -65,6 +66,7 @@ public class SendReadReceiptJob extends BaseJob {
|
|||
public SendReadReceiptJob(long threadId, @NonNull RecipientId recipientId, List<Long> messageSentTimestamps, List<MessageId> messageIds) {
|
||||
this(new Job.Parameters.Builder()
|
||||
.addConstraint(NetworkConstraint.KEY)
|
||||
.addConstraint(DecryptionsDrainedConstraint.KEY)
|
||||
.setLifespan(TimeUnit.DAYS.toMillis(1))
|
||||
.setMaxAttempts(Parameters.UNLIMITED)
|
||||
.setQueue(recipientId.toQueueKey())
|
||||
|
|
|
@ -72,7 +72,7 @@ class PersistentLogger(
|
|||
}
|
||||
|
||||
private fun write(level: String, tag: String?, message: String?, t: Throwable?, keepLonger: Boolean) {
|
||||
logEntries.add(LogRequest(level, tag ?: "null", message, Date(), getThreadString(), t, keepLonger))
|
||||
logEntries.add(LogRequest(level, tag ?: "null", message, System.currentTimeMillis(), getThreadString(), t, keepLonger))
|
||||
}
|
||||
|
||||
private fun getThreadString(): String {
|
||||
|
@ -95,7 +95,7 @@ class PersistentLogger(
|
|||
val level: String,
|
||||
val tag: String,
|
||||
val message: String?,
|
||||
val date: Date,
|
||||
val createTime: Long,
|
||||
val threadString: String,
|
||||
val throwable: Throwable?,
|
||||
val keepLonger: Boolean
|
||||
|
@ -121,11 +121,13 @@ class PersistentLogger(
|
|||
fun requestToEntries(request: LogRequest): List<LogEntry> {
|
||||
val out = mutableListOf<LogEntry>()
|
||||
|
||||
val createDate = Date(request.createTime)
|
||||
|
||||
out.add(
|
||||
LogEntry(
|
||||
createdAt = request.date.time,
|
||||
createdAt = request.createTime,
|
||||
keepLonger = request.keepLonger,
|
||||
body = formatBody(request.threadString, request.date, request.level, request.tag, request.message)
|
||||
body = formatBody(request.threadString, createDate, request.level, request.tag, request.message)
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -138,9 +140,9 @@ class PersistentLogger(
|
|||
|
||||
val entries = lines.map { line ->
|
||||
LogEntry(
|
||||
createdAt = request.date.time,
|
||||
createdAt = request.createTime,
|
||||
keepLonger = request.keepLonger,
|
||||
body = formatBody(request.threadString, request.date, request.level, request.tag, line)
|
||||
body = formatBody(request.threadString, createDate, request.level, request.tag, line)
|
||||
)
|
||||
}
|
||||
|
||||
|
|
|
@ -14,7 +14,9 @@ import androidx.core.app.NotificationCompat
|
|||
import org.signal.core.util.ThreadUtil
|
||||
import org.signal.core.util.concurrent.SignalExecutors
|
||||
import org.signal.core.util.logging.Log
|
||||
import org.signal.core.util.withinTransaction
|
||||
import org.thoughtcrime.securesms.R
|
||||
import org.thoughtcrime.securesms.crypto.ReentrantSessionLock
|
||||
import org.thoughtcrime.securesms.database.MessageTable
|
||||
import org.thoughtcrime.securesms.database.SignalDatabase
|
||||
import org.thoughtcrime.securesms.dependencies.ApplicationDependencies
|
||||
|
@ -28,6 +30,7 @@ import org.thoughtcrime.securesms.jobs.PushDecryptMessageJob
|
|||
import org.thoughtcrime.securesms.jobs.PushProcessMessageJob
|
||||
import org.thoughtcrime.securesms.jobs.UnableToStartException
|
||||
import org.thoughtcrime.securesms.keyvalue.SignalStore
|
||||
import org.thoughtcrime.securesms.messages.protocol.BufferedProtocolStore
|
||||
import org.thoughtcrime.securesms.notifications.NotificationChannels
|
||||
import org.thoughtcrime.securesms.recipients.RecipientId
|
||||
import org.thoughtcrime.securesms.util.AppForegroundObserver
|
||||
|
@ -200,7 +203,7 @@ class IncomingMessageObserver(private val context: Application) {
|
|||
|
||||
val needsConnectionString = if (conclusion) "Needs Connection" else "Does Not Need Connection"
|
||||
|
||||
Log.d(TAG, "[$needsConnectionString] Network: $hasNetwork, Foreground: $appVisible, Time Since Last Interaction: $lastInteractionString, FCM: $fcmEnabled, Stay open requests: [${keepAliveTokens.entries}], Registered: $registered, Proxy: $hasProxy, Force websocket: $forceWebsocket, Decrypt Queue Empty: $decryptQueueEmpty")
|
||||
Log.d(TAG, "[$needsConnectionString] Network: $hasNetwork, Foreground: $appVisible, Time Since Last Interaction: $lastInteractionString, FCM: $fcmEnabled, Stay open requests: ${keepAliveTokens.entries}, Registered: $registered, Proxy: $hasProxy, Force websocket: $forceWebsocket, Decrypt Queue Empty: $decryptQueueEmpty")
|
||||
return conclusion
|
||||
}
|
||||
}
|
||||
|
@ -249,19 +252,29 @@ class IncomingMessageObserver(private val context: Application) {
|
|||
}
|
||||
|
||||
@VisibleForTesting
|
||||
fun processEnvelope(envelope: SignalServiceProtos.Envelope, serverDeliveredTimestamp: Long) {
|
||||
when (envelope.type.number) {
|
||||
SignalServiceProtos.Envelope.Type.RECEIPT_VALUE -> processReceipt(envelope)
|
||||
fun processEnvelope(bufferedProtocolStore: BufferedProtocolStore, envelope: SignalServiceProtos.Envelope, serverDeliveredTimestamp: Long): List<Runnable>? {
|
||||
return when (envelope.type.number) {
|
||||
SignalServiceProtos.Envelope.Type.RECEIPT_VALUE -> {
|
||||
processReceipt(envelope)
|
||||
null
|
||||
}
|
||||
|
||||
SignalServiceProtos.Envelope.Type.PREKEY_BUNDLE_VALUE,
|
||||
SignalServiceProtos.Envelope.Type.CIPHERTEXT_VALUE,
|
||||
SignalServiceProtos.Envelope.Type.UNIDENTIFIED_SENDER_VALUE,
|
||||
SignalServiceProtos.Envelope.Type.PLAINTEXT_CONTENT_VALUE -> processMessage(envelope, serverDeliveredTimestamp)
|
||||
else -> Log.w(TAG, "Received envelope of unknown type: " + envelope.type)
|
||||
SignalServiceProtos.Envelope.Type.PLAINTEXT_CONTENT_VALUE -> {
|
||||
processMessage(bufferedProtocolStore, envelope, serverDeliveredTimestamp)
|
||||
}
|
||||
|
||||
else -> {
|
||||
Log.w(TAG, "Received envelope of unknown type: " + envelope.type)
|
||||
null
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun processMessage(envelope: SignalServiceProtos.Envelope, serverDeliveredTimestamp: Long) {
|
||||
val result = MessageDecryptor.decrypt(context, envelope, serverDeliveredTimestamp)
|
||||
private fun processMessage(bufferedProtocolStore: BufferedProtocolStore, envelope: SignalServiceProtos.Envelope, serverDeliveredTimestamp: Long): List<Runnable> {
|
||||
val result = MessageDecryptor.decrypt(context, bufferedProtocolStore, envelope, serverDeliveredTimestamp)
|
||||
|
||||
when (result) {
|
||||
is MessageDecryptor.Result.Success -> {
|
||||
|
@ -297,7 +310,7 @@ class IncomingMessageObserver(private val context: Application) {
|
|||
}
|
||||
}
|
||||
|
||||
result.followUpOperations.forEach { it.run() }
|
||||
return result.followUpOperations
|
||||
}
|
||||
|
||||
private fun processReceipt(envelope: SignalServiceProtos.Envelope) {
|
||||
|
@ -386,13 +399,31 @@ class IncomingMessageObserver(private val context: Application) {
|
|||
|
||||
signalWebSocket.connect()
|
||||
try {
|
||||
val bufferedStore = BufferedProtocolStore.create()
|
||||
|
||||
while (isConnectionNecessary()) {
|
||||
try {
|
||||
Log.d(TAG, "Reading message...")
|
||||
|
||||
val hasMore = signalWebSocket.readMessage(WEBSOCKET_READ_TIMEOUT) { envelope, serverDeliveredTimestamp ->
|
||||
Log.i(TAG, "Retrieved envelope! " + envelope.timestamp)
|
||||
processEnvelope(envelope, serverDeliveredTimestamp)
|
||||
val hasMore = signalWebSocket.readMessageBatch(WEBSOCKET_READ_TIMEOUT, 30) { batch ->
|
||||
Log.i(TAG, "Retrieved ${batch.size} envelopes!")
|
||||
|
||||
val startTime = System.currentTimeMillis()
|
||||
ReentrantSessionLock.INSTANCE.acquire().use {
|
||||
SignalDatabase.rawDatabase.withinTransaction {
|
||||
val followUpOperations = batch
|
||||
.mapNotNull { processEnvelope(bufferedStore, it.envelope, it.serverDeliveredTimestamp) }
|
||||
.flatten()
|
||||
|
||||
bufferedStore.flushToDisk()
|
||||
|
||||
followUpOperations.forEach { it.run() }
|
||||
}
|
||||
}
|
||||
|
||||
val duration = System.currentTimeMillis() - startTime
|
||||
Log.d(TAG, "Decrypted ${batch.size} envelopes in $duration ms (~${duration / batch.size} ms per message)")
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
|
|
|
@ -38,13 +38,13 @@ import org.thoughtcrime.securesms.jobs.PreKeysSyncJob
|
|||
import org.thoughtcrime.securesms.jobs.SendRetryReceiptJob
|
||||
import org.thoughtcrime.securesms.keyvalue.SignalStore
|
||||
import org.thoughtcrime.securesms.logsubmit.SubmitDebugLogActivity
|
||||
import org.thoughtcrime.securesms.messages.protocol.BufferedProtocolStore
|
||||
import org.thoughtcrime.securesms.notifications.NotificationChannels
|
||||
import org.thoughtcrime.securesms.notifications.NotificationIds
|
||||
import org.thoughtcrime.securesms.recipients.Recipient
|
||||
import org.thoughtcrime.securesms.recipients.RecipientId
|
||||
import org.thoughtcrime.securesms.util.FeatureFlags
|
||||
import org.whispersystems.signalservice.api.InvalidMessageStructureException
|
||||
import org.whispersystems.signalservice.api.SignalServiceAccountDataStore
|
||||
import org.whispersystems.signalservice.api.crypto.ContentHint
|
||||
import org.whispersystems.signalservice.api.crypto.EnvelopeMetadata
|
||||
import org.whispersystems.signalservice.api.crypto.SignalServiceCipher
|
||||
|
@ -74,7 +74,12 @@ object MessageDecryptor {
|
|||
* To keep that property, there may be [Result.followUpOperations] you have to perform after your transaction is committed.
|
||||
* These can vary from enqueueing jobs to inserting items into the [org.thoughtcrime.securesms.database.PendingRetryReceiptCache].
|
||||
*/
|
||||
fun decrypt(context: Context, envelope: Envelope, serverDeliveredTimestamp: Long): Result {
|
||||
fun decrypt(
|
||||
context: Context,
|
||||
bufferedProtocolStore: BufferedProtocolStore,
|
||||
envelope: Envelope,
|
||||
serverDeliveredTimestamp: Long
|
||||
): Result {
|
||||
val selfAci: ServiceId = SignalStore.account().requireAci()
|
||||
val selfPni: ServiceId = SignalStore.account().requirePni()
|
||||
|
||||
|
@ -106,9 +111,9 @@ object MessageDecryptor {
|
|||
}
|
||||
}
|
||||
|
||||
val protocolStore: SignalServiceAccountDataStore = ApplicationDependencies.getProtocolStore().get(destination)
|
||||
val bufferedStore = bufferedProtocolStore.get(destination)
|
||||
val localAddress = SignalServiceAddress(selfAci, SignalStore.account().e164)
|
||||
val cipher = SignalServiceCipher(localAddress, SignalStore.account().deviceId, protocolStore, ReentrantSessionLock.INSTANCE, UnidentifiedAccessUtil.getCertificateValidator())
|
||||
val cipher = SignalServiceCipher(localAddress, SignalStore.account().deviceId, bufferedStore, ReentrantSessionLock.INSTANCE, UnidentifiedAccessUtil.getCertificateValidator())
|
||||
|
||||
return try {
|
||||
val cipherResult: SignalServiceCipherResult? = cipher.decrypt(envelope, serverDeliveredTimestamp)
|
||||
|
|
|
@ -0,0 +1,80 @@
|
|||
package org.thoughtcrime.securesms.messages.protocol
|
||||
|
||||
import org.signal.libsignal.protocol.IdentityKey
|
||||
import org.signal.libsignal.protocol.IdentityKeyPair
|
||||
import org.signal.libsignal.protocol.SignalProtocolAddress
|
||||
import org.signal.libsignal.protocol.state.IdentityKeyStore
|
||||
import org.thoughtcrime.securesms.database.SignalDatabase
|
||||
import org.whispersystems.signalservice.api.SignalServiceAccountDataStore
|
||||
import org.whispersystems.signalservice.api.push.ServiceId
|
||||
|
||||
/**
|
||||
* An in-memory identity key store that is intended to be used temporarily while decrypting messages.
|
||||
*/
|
||||
class BufferedIdentityKeyStore(
|
||||
private val selfServiceId: ServiceId,
|
||||
private val selfIdentityKeyPair: IdentityKeyPair,
|
||||
private val selfRegistrationId: Int
|
||||
) : IdentityKeyStore {
|
||||
|
||||
private val store: MutableMap<SignalProtocolAddress, IdentityKey> = HashMap()
|
||||
|
||||
/** All of the keys that have been created or updated during operation. */
|
||||
private val updatedKeys: MutableMap<SignalProtocolAddress, IdentityKey> = mutableMapOf()
|
||||
|
||||
override fun getIdentityKeyPair(): IdentityKeyPair {
|
||||
return selfIdentityKeyPair
|
||||
}
|
||||
|
||||
override fun getLocalRegistrationId(): Int {
|
||||
return selfRegistrationId
|
||||
}
|
||||
|
||||
override fun saveIdentity(address: SignalProtocolAddress, identityKey: IdentityKey): Boolean {
|
||||
val existing: IdentityKey? = getIdentity(address)
|
||||
|
||||
store[address] = identityKey
|
||||
|
||||
return if (identityKey != existing) {
|
||||
updatedKeys[address] = identityKey
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
override fun isTrustedIdentity(address: SignalProtocolAddress, identityKey: IdentityKey, direction: IdentityKeyStore.Direction): Boolean {
|
||||
if (address.name == selfServiceId.toString()) {
|
||||
return identityKey == selfIdentityKeyPair.publicKey
|
||||
}
|
||||
|
||||
return when (direction) {
|
||||
IdentityKeyStore.Direction.RECEIVING -> true
|
||||
IdentityKeyStore.Direction.SENDING -> error("Should not happen during the intended usage pattern of this class")
|
||||
else -> error("Unknown direction: $direction")
|
||||
}
|
||||
}
|
||||
|
||||
override fun getIdentity(address: SignalProtocolAddress): IdentityKey? {
|
||||
val cached = store[address]
|
||||
|
||||
return if (cached != null) {
|
||||
cached
|
||||
} else {
|
||||
val fromDatabase = SignalDatabase.identities.getIdentityStoreRecord(address.name)
|
||||
if (fromDatabase != null) {
|
||||
store[address] = fromDatabase.identityKey
|
||||
}
|
||||
|
||||
fromDatabase?.identityKey
|
||||
}
|
||||
}
|
||||
|
||||
fun flushToDisk(persistentStore: SignalServiceAccountDataStore) {
|
||||
for ((address, identityKey) in updatedKeys) {
|
||||
persistentStore.saveIdentity(address, identityKey)
|
||||
}
|
||||
|
||||
updatedKeys.clear()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,49 @@
|
|||
package org.thoughtcrime.securesms.messages.protocol
|
||||
|
||||
import org.signal.libsignal.protocol.InvalidKeyIdException
|
||||
import org.signal.libsignal.protocol.state.PreKeyRecord
|
||||
import org.signal.libsignal.protocol.state.PreKeyStore
|
||||
import org.thoughtcrime.securesms.database.SignalDatabase
|
||||
import org.whispersystems.signalservice.api.SignalServiceAccountDataStore
|
||||
import org.whispersystems.signalservice.api.push.ServiceId
|
||||
|
||||
/**
|
||||
* An in-memory one-time prekey store that is intended to be used temporarily while decrypting messages.
|
||||
*/
|
||||
class BufferedOneTimePreKeyStore(private val selfServiceId: ServiceId) : PreKeyStore {
|
||||
|
||||
/** Our in-memory cache of one-time prekeys. */
|
||||
private val store: MutableMap<Int, PreKeyRecord> = HashMap()
|
||||
|
||||
/** The one-time prekeys that have been marked as removed */
|
||||
private val removed: MutableList<Int> = mutableListOf()
|
||||
|
||||
@kotlin.jvm.Throws(InvalidKeyIdException::class)
|
||||
override fun loadPreKey(id: Int): PreKeyRecord {
|
||||
return store.computeIfAbsent(id) {
|
||||
SignalDatabase.oneTimePreKeys.get(selfServiceId, id) ?: throw InvalidKeyIdException("Missing one-time prekey with ID: $id")
|
||||
}
|
||||
}
|
||||
|
||||
override fun storePreKey(id: Int, record: PreKeyRecord) {
|
||||
error("Should not happen during the intended usage pattern of this class")
|
||||
}
|
||||
|
||||
override fun containsPreKey(id: Int): Boolean {
|
||||
loadPreKey(id)
|
||||
return store.containsKey(id)
|
||||
}
|
||||
|
||||
override fun removePreKey(id: Int) {
|
||||
store.remove(id)
|
||||
removed += id
|
||||
}
|
||||
|
||||
fun flushToDisk(persistentStore: SignalServiceAccountDataStore) {
|
||||
for (id in removed) {
|
||||
persistentStore.removePreKey(id)
|
||||
}
|
||||
|
||||
removed.clear()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,46 @@
|
|||
package org.thoughtcrime.securesms.messages.protocol
|
||||
|
||||
import org.thoughtcrime.securesms.dependencies.ApplicationDependencies
|
||||
import org.thoughtcrime.securesms.keyvalue.SignalStore
|
||||
import org.whispersystems.signalservice.api.push.ServiceId
|
||||
|
||||
/**
|
||||
* The entry point for creating and retrieving buffered protocol stores.
|
||||
* These stores will read from disk, but never write, instead buffering the results in memory.
|
||||
* You can then call [flushToDisk] in order to write the buffered results to disk.
|
||||
*
|
||||
* This allows you to efficiently do batches of work and avoid unnecessary intermediate writes.
|
||||
*/
|
||||
class BufferedProtocolStore private constructor(
|
||||
private val aciStore: Pair<ServiceId, BufferedSignalServiceAccountDataStore>,
|
||||
private val pniStore: Pair<ServiceId, BufferedSignalServiceAccountDataStore>
|
||||
) {
|
||||
|
||||
fun get(serviceId: ServiceId): BufferedSignalServiceAccountDataStore {
|
||||
return when (serviceId) {
|
||||
aciStore.first -> aciStore.second
|
||||
pniStore.first -> pniStore.second
|
||||
else -> error("No store matching serviceId $serviceId")
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Writes any buffered data to disk. You can continue to use the same buffered store afterwards.
|
||||
*/
|
||||
fun flushToDisk() {
|
||||
aciStore.second.flushToDisk(ApplicationDependencies.getProtocolStore().aci())
|
||||
pniStore.second.flushToDisk(ApplicationDependencies.getProtocolStore().pni())
|
||||
}
|
||||
|
||||
companion object {
|
||||
fun create(): BufferedProtocolStore {
|
||||
val aci = SignalStore.account().requireAci()
|
||||
val pni = SignalStore.account().requirePni()
|
||||
|
||||
return BufferedProtocolStore(
|
||||
aciStore = aci to BufferedSignalServiceAccountDataStore(aci),
|
||||
pniStore = pni to BufferedSignalServiceAccountDataStore(pni)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,75 @@
|
|||
package org.thoughtcrime.securesms.messages.protocol
|
||||
|
||||
import org.signal.libsignal.protocol.SignalProtocolAddress
|
||||
import org.signal.libsignal.protocol.groups.state.SenderKeyRecord
|
||||
import org.thoughtcrime.securesms.database.SignalDatabase
|
||||
import org.whispersystems.signalservice.api.SignalServiceAccountDataStore
|
||||
import org.whispersystems.signalservice.api.SignalServiceSenderKeyStore
|
||||
import org.whispersystems.signalservice.api.push.DistributionId
|
||||
import java.util.UUID
|
||||
|
||||
/**
|
||||
* An in-memory sender key store that is intended to be used temporarily while decrypting messages.
|
||||
*/
|
||||
class BufferedSenderKeyStore : SignalServiceSenderKeyStore {
|
||||
|
||||
private val store: MutableMap<StoreKey, SenderKeyRecord> = HashMap()
|
||||
|
||||
/** All of the keys that have been created or updated during operation. */
|
||||
private val updatedKeys: MutableMap<StoreKey, SenderKeyRecord> = mutableMapOf()
|
||||
|
||||
/** All of the distributionId's whose sharing has been cleared during operation. */
|
||||
private val clearSharedWith: MutableSet<SignalProtocolAddress> = mutableSetOf()
|
||||
|
||||
override fun storeSenderKey(sender: SignalProtocolAddress, distributionId: UUID, record: SenderKeyRecord) {
|
||||
val key = StoreKey(sender, distributionId)
|
||||
store[key] = record
|
||||
updatedKeys[key] = record
|
||||
}
|
||||
|
||||
override fun loadSenderKey(sender: SignalProtocolAddress, distributionId: UUID): SenderKeyRecord? {
|
||||
val cached: SenderKeyRecord? = store[StoreKey(sender, distributionId)]
|
||||
|
||||
return if (cached != null) {
|
||||
cached
|
||||
} else {
|
||||
val fromDatabase: SenderKeyRecord? = SignalDatabase.senderKeys.load(sender, distributionId.toDistributionId())
|
||||
|
||||
if (fromDatabase != null) {
|
||||
store[StoreKey(sender, distributionId)] = fromDatabase
|
||||
}
|
||||
|
||||
return fromDatabase
|
||||
}
|
||||
}
|
||||
|
||||
override fun clearSenderKeySharedWith(addresses: MutableCollection<SignalProtocolAddress>) {
|
||||
clearSharedWith.addAll(addresses)
|
||||
}
|
||||
|
||||
override fun getSenderKeySharedWith(distributionId: DistributionId?): MutableSet<SignalProtocolAddress> {
|
||||
error("Should not happen during the intended usage pattern of this class")
|
||||
}
|
||||
|
||||
override fun markSenderKeySharedWith(distributionId: DistributionId?, addresses: MutableCollection<SignalProtocolAddress>?) {
|
||||
error("Should not happen during the intended usage pattern of this class")
|
||||
}
|
||||
|
||||
fun flushToDisk(persistentStore: SignalServiceAccountDataStore) {
|
||||
for ((key, record) in updatedKeys) {
|
||||
persistentStore.storeSenderKey(key.address, key.distributionId, record)
|
||||
}
|
||||
|
||||
persistentStore.clearSenderKeySharedWith(clearSharedWith)
|
||||
|
||||
updatedKeys.clear()
|
||||
clearSharedWith.clear()
|
||||
}
|
||||
|
||||
private fun UUID.toDistributionId() = DistributionId.from(this)
|
||||
|
||||
data class StoreKey(
|
||||
val address: SignalProtocolAddress,
|
||||
val distributionId: UUID
|
||||
)
|
||||
}
|
|
@ -0,0 +1,115 @@
|
|||
package org.thoughtcrime.securesms.messages.protocol
|
||||
|
||||
import org.signal.libsignal.protocol.NoSessionException
|
||||
import org.signal.libsignal.protocol.SignalProtocolAddress
|
||||
import org.signal.libsignal.protocol.message.CiphertextMessage
|
||||
import org.signal.libsignal.protocol.state.SessionRecord
|
||||
import org.thoughtcrime.securesms.database.SignalDatabase
|
||||
import org.whispersystems.signalservice.api.SignalServiceAccountDataStore
|
||||
import org.whispersystems.signalservice.api.SignalServiceSessionStore
|
||||
import org.whispersystems.signalservice.api.push.ServiceId
|
||||
import kotlin.jvm.Throws
|
||||
|
||||
/**
|
||||
* An in-memory session store that is intended to be used temporarily while decrypting messages.
|
||||
*/
|
||||
class BufferedSessionStore(private val selfServiceId: ServiceId) : SignalServiceSessionStore {
|
||||
|
||||
private val store: MutableMap<SignalProtocolAddress, SessionRecord> = HashMap()
|
||||
|
||||
/** All of the sessions that have been created or updated during operation. */
|
||||
private val updatedSessions: MutableMap<SignalProtocolAddress, SessionRecord> = mutableMapOf()
|
||||
|
||||
/** All of the sessions that have deleted during operation. */
|
||||
private val deletedSessions: MutableSet<SignalProtocolAddress> = mutableSetOf()
|
||||
|
||||
override fun loadSession(address: SignalProtocolAddress): SessionRecord {
|
||||
val session: SessionRecord = store[address]
|
||||
?: SignalDatabase.sessions.load(selfServiceId, address)
|
||||
?: SessionRecord()
|
||||
|
||||
store[address] = session
|
||||
|
||||
return session
|
||||
}
|
||||
|
||||
@Throws(NoSessionException::class)
|
||||
override fun loadExistingSessions(addresses: MutableList<SignalProtocolAddress>): List<SessionRecord> {
|
||||
val found: MutableList<SessionRecord> = mutableListOf()
|
||||
val needsDatabaseLookup: MutableList<SignalProtocolAddress> = mutableListOf()
|
||||
|
||||
for (address in addresses) {
|
||||
val cached: SessionRecord? = store[address]
|
||||
|
||||
if (cached != null) {
|
||||
found += cached
|
||||
} else {
|
||||
needsDatabaseLookup += address
|
||||
}
|
||||
}
|
||||
|
||||
if (needsDatabaseLookup.isNotEmpty()) {
|
||||
found += SignalDatabase.sessions.load(selfServiceId, needsDatabaseLookup).filterNotNull()
|
||||
}
|
||||
|
||||
if (found.size != addresses.size) {
|
||||
throw NoSessionException("Failed to find one or more sessions.")
|
||||
}
|
||||
|
||||
return found
|
||||
}
|
||||
|
||||
override fun storeSession(address: SignalProtocolAddress, record: SessionRecord) {
|
||||
store[address] = record
|
||||
updatedSessions[address] = record
|
||||
}
|
||||
|
||||
override fun containsSession(address: SignalProtocolAddress): Boolean {
|
||||
return if (store.containsKey(address)) {
|
||||
true
|
||||
} else {
|
||||
val fromDatabase: SessionRecord? = SignalDatabase.sessions.load(selfServiceId, address)
|
||||
|
||||
if (fromDatabase != null) {
|
||||
store[address] = fromDatabase
|
||||
return fromDatabase.hasSenderChain() && fromDatabase.sessionVersion == CiphertextMessage.CURRENT_VERSION
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun deleteSession(address: SignalProtocolAddress) {
|
||||
store.remove(address)
|
||||
deletedSessions += address
|
||||
}
|
||||
|
||||
override fun getSubDeviceSessions(name: String): MutableList<Int> {
|
||||
error("Should not happen during the intended usage pattern of this class")
|
||||
}
|
||||
|
||||
override fun deleteAllSessions(name: String) {
|
||||
error("Should not happen during the intended usage pattern of this class")
|
||||
}
|
||||
|
||||
override fun archiveSession(address: SignalProtocolAddress?) {
|
||||
error("Should not happen during the intended usage pattern of this class")
|
||||
}
|
||||
|
||||
override fun getAllAddressesWithActiveSessions(addressNames: MutableList<String>): Set<SignalProtocolAddress> {
|
||||
error("Should not happen during the intended usage pattern of this class")
|
||||
}
|
||||
|
||||
fun flushToDisk(persistentStore: SignalServiceAccountDataStore) {
|
||||
for ((address, record) in updatedSessions) {
|
||||
persistentStore.storeSession(address, record)
|
||||
}
|
||||
|
||||
for (address in deletedSessions) {
|
||||
persistentStore.deleteSession(address)
|
||||
}
|
||||
|
||||
updatedSessions.clear()
|
||||
deletedSessions.clear()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,157 @@
|
|||
package org.thoughtcrime.securesms.messages.protocol
|
||||
|
||||
import org.signal.libsignal.protocol.IdentityKey
|
||||
import org.signal.libsignal.protocol.IdentityKeyPair
|
||||
import org.signal.libsignal.protocol.SignalProtocolAddress
|
||||
import org.signal.libsignal.protocol.groups.state.SenderKeyRecord
|
||||
import org.signal.libsignal.protocol.state.IdentityKeyStore
|
||||
import org.signal.libsignal.protocol.state.PreKeyRecord
|
||||
import org.signal.libsignal.protocol.state.SessionRecord
|
||||
import org.signal.libsignal.protocol.state.SignedPreKeyRecord
|
||||
import org.thoughtcrime.securesms.keyvalue.SignalStore
|
||||
import org.whispersystems.signalservice.api.SignalServiceAccountDataStore
|
||||
import org.whispersystems.signalservice.api.push.DistributionId
|
||||
import org.whispersystems.signalservice.api.push.ServiceId
|
||||
import java.util.UUID
|
||||
|
||||
/**
|
||||
* The wrapper around all of the Buffered protocol stores. Designed to perform operations in memory,
|
||||
* then [flushToDisk] at set intervals.
|
||||
*/
|
||||
class BufferedSignalServiceAccountDataStore(selfServiceId: ServiceId) : SignalServiceAccountDataStore {
|
||||
|
||||
private val identityStore: BufferedIdentityKeyStore = if (selfServiceId == SignalStore.account().pni) {
|
||||
BufferedIdentityKeyStore(selfServiceId, SignalStore.account().pniIdentityKey, SignalStore.account().pniRegistrationId)
|
||||
} else {
|
||||
BufferedIdentityKeyStore(selfServiceId, SignalStore.account().aciIdentityKey, SignalStore.account().registrationId)
|
||||
}
|
||||
|
||||
private val oneTimePreKeyStore: BufferedOneTimePreKeyStore = BufferedOneTimePreKeyStore(selfServiceId)
|
||||
private val signedPreKeyStore: BufferedSignedPreKeyStore = BufferedSignedPreKeyStore(selfServiceId)
|
||||
private val sessionStore: BufferedSessionStore = BufferedSessionStore(selfServiceId)
|
||||
private val senderKeyStore: BufferedSenderKeyStore = BufferedSenderKeyStore()
|
||||
|
||||
override fun getIdentityKeyPair(): IdentityKeyPair {
|
||||
return identityStore.identityKeyPair
|
||||
}
|
||||
|
||||
override fun getLocalRegistrationId(): Int {
|
||||
return identityStore.localRegistrationId
|
||||
}
|
||||
|
||||
override fun saveIdentity(address: SignalProtocolAddress, identityKey: IdentityKey): Boolean {
|
||||
return identityStore.saveIdentity(address, identityKey)
|
||||
}
|
||||
|
||||
override fun isTrustedIdentity(address: SignalProtocolAddress, identityKey: IdentityKey, direction: IdentityKeyStore.Direction): Boolean {
|
||||
return identityStore.isTrustedIdentity(address, identityKey, direction)
|
||||
}
|
||||
|
||||
override fun getIdentity(address: SignalProtocolAddress): IdentityKey? {
|
||||
return identityStore.getIdentity(address)
|
||||
}
|
||||
|
||||
override fun loadPreKey(preKeyId: Int): PreKeyRecord {
|
||||
return oneTimePreKeyStore.loadPreKey(preKeyId)
|
||||
}
|
||||
|
||||
override fun storePreKey(preKeyId: Int, record: PreKeyRecord) {
|
||||
return oneTimePreKeyStore.storePreKey(preKeyId, record)
|
||||
}
|
||||
|
||||
override fun containsPreKey(preKeyId: Int): Boolean {
|
||||
return oneTimePreKeyStore.containsPreKey(preKeyId)
|
||||
}
|
||||
|
||||
override fun removePreKey(preKeyId: Int) {
|
||||
oneTimePreKeyStore.removePreKey(preKeyId)
|
||||
}
|
||||
|
||||
override fun loadSession(address: SignalProtocolAddress): SessionRecord {
|
||||
return sessionStore.loadSession(address)
|
||||
}
|
||||
|
||||
override fun loadExistingSessions(addresses: MutableList<SignalProtocolAddress>): List<SessionRecord> {
|
||||
return sessionStore.loadExistingSessions(addresses)
|
||||
}
|
||||
|
||||
override fun getSubDeviceSessions(name: String): MutableList<Int> {
|
||||
return sessionStore.getSubDeviceSessions(name)
|
||||
}
|
||||
|
||||
override fun storeSession(address: SignalProtocolAddress, record: SessionRecord) {
|
||||
sessionStore.storeSession(address, record)
|
||||
}
|
||||
|
||||
override fun containsSession(address: SignalProtocolAddress): Boolean {
|
||||
return sessionStore.containsSession(address)
|
||||
}
|
||||
|
||||
override fun deleteSession(address: SignalProtocolAddress) {
|
||||
return sessionStore.deleteSession(address)
|
||||
}
|
||||
|
||||
override fun deleteAllSessions(name: String) {
|
||||
sessionStore.deleteAllSessions(name)
|
||||
}
|
||||
|
||||
override fun loadSignedPreKey(signedPreKeyId: Int): SignedPreKeyRecord {
|
||||
return signedPreKeyStore.loadSignedPreKey(signedPreKeyId)
|
||||
}
|
||||
|
||||
override fun loadSignedPreKeys(): List<SignedPreKeyRecord> {
|
||||
return signedPreKeyStore.loadSignedPreKeys()
|
||||
}
|
||||
|
||||
override fun storeSignedPreKey(signedPreKeyId: Int, record: SignedPreKeyRecord) {
|
||||
signedPreKeyStore.storeSignedPreKey(signedPreKeyId, record)
|
||||
}
|
||||
|
||||
override fun containsSignedPreKey(signedPreKeyId: Int): Boolean {
|
||||
return signedPreKeyStore.containsSignedPreKey(signedPreKeyId)
|
||||
}
|
||||
|
||||
override fun removeSignedPreKey(signedPreKeyId: Int) {
|
||||
signedPreKeyStore.removeSignedPreKey(signedPreKeyId)
|
||||
}
|
||||
|
||||
override fun storeSenderKey(sender: SignalProtocolAddress, distributionId: UUID, record: SenderKeyRecord) {
|
||||
senderKeyStore.storeSenderKey(sender, distributionId, record)
|
||||
}
|
||||
|
||||
override fun loadSenderKey(sender: SignalProtocolAddress, distributionId: UUID): SenderKeyRecord? {
|
||||
return senderKeyStore.loadSenderKey(sender, distributionId)
|
||||
}
|
||||
|
||||
override fun archiveSession(address: SignalProtocolAddress?) {
|
||||
sessionStore.archiveSession(address)
|
||||
}
|
||||
|
||||
override fun getAllAddressesWithActiveSessions(addressNames: MutableList<String>): Set<SignalProtocolAddress> {
|
||||
return sessionStore.getAllAddressesWithActiveSessions(addressNames)
|
||||
}
|
||||
|
||||
override fun getSenderKeySharedWith(distributionId: DistributionId?): MutableSet<SignalProtocolAddress> {
|
||||
return senderKeyStore.getSenderKeySharedWith(distributionId)
|
||||
}
|
||||
|
||||
override fun markSenderKeySharedWith(distributionId: DistributionId, addresses: MutableCollection<SignalProtocolAddress>) {
|
||||
senderKeyStore.markSenderKeySharedWith(distributionId, addresses)
|
||||
}
|
||||
|
||||
override fun clearSenderKeySharedWith(addresses: MutableCollection<SignalProtocolAddress>) {
|
||||
senderKeyStore.clearSenderKeySharedWith(addresses)
|
||||
}
|
||||
|
||||
override fun isMultiDevice(): Boolean {
|
||||
error("Should not happen during the intended usage pattern of this class")
|
||||
}
|
||||
|
||||
fun flushToDisk(persistentStore: SignalServiceAccountDataStore) {
|
||||
identityStore.flushToDisk(persistentStore)
|
||||
oneTimePreKeyStore.flushToDisk(persistentStore)
|
||||
signedPreKeyStore.flushToDisk(persistentStore)
|
||||
sessionStore.flushToDisk(persistentStore)
|
||||
senderKeyStore.flushToDisk(persistentStore)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,64 @@
|
|||
package org.thoughtcrime.securesms.messages.protocol
|
||||
|
||||
import org.signal.libsignal.protocol.InvalidKeyIdException
|
||||
import org.signal.libsignal.protocol.state.SignedPreKeyRecord
|
||||
import org.signal.libsignal.protocol.state.SignedPreKeyStore
|
||||
import org.thoughtcrime.securesms.database.SignalDatabase
|
||||
import org.whispersystems.signalservice.api.SignalServiceAccountDataStore
|
||||
import org.whispersystems.signalservice.api.push.ServiceId
|
||||
|
||||
/**
|
||||
* An in-memory signed prekey store that is intended to be used temporarily while decrypting messages.
|
||||
*/
|
||||
class BufferedSignedPreKeyStore(private val selfServiceId: ServiceId) : SignedPreKeyStore {
|
||||
|
||||
/** Our in-memory cache of signed prekeys. */
|
||||
private val store: MutableMap<Int, SignedPreKeyRecord> = HashMap()
|
||||
|
||||
/** The signed prekeys that have been marked as removed */
|
||||
private val removed: MutableList<Int> = mutableListOf()
|
||||
|
||||
/** Whether or not we've done a loadAll operation. Let's us avoid doing it twice. */
|
||||
private var hasLoadedAll: Boolean = false
|
||||
|
||||
@kotlin.jvm.Throws(InvalidKeyIdException::class)
|
||||
override fun loadSignedPreKey(id: Int): SignedPreKeyRecord {
|
||||
return store.computeIfAbsent(id) {
|
||||
SignalDatabase.signedPreKeys.get(selfServiceId, id) ?: throw InvalidKeyIdException("Missing one-time prekey with ID: $id")
|
||||
}
|
||||
}
|
||||
|
||||
override fun loadSignedPreKeys(): List<SignedPreKeyRecord> {
|
||||
return if (hasLoadedAll) {
|
||||
store.values.toList()
|
||||
} else {
|
||||
val records = SignalDatabase.signedPreKeys.getAll(selfServiceId)
|
||||
records.forEach { store[it.id] = it }
|
||||
hasLoadedAll = true
|
||||
|
||||
records
|
||||
}
|
||||
}
|
||||
|
||||
override fun storeSignedPreKey(id: Int, record: SignedPreKeyRecord) {
|
||||
error("Should not happen during the intended usage pattern of this class")
|
||||
}
|
||||
|
||||
override fun containsSignedPreKey(id: Int): Boolean {
|
||||
loadSignedPreKey(id)
|
||||
return store.containsKey(id)
|
||||
}
|
||||
|
||||
override fun removeSignedPreKey(id: Int) {
|
||||
store.remove(id)
|
||||
removed += id
|
||||
}
|
||||
|
||||
fun flushToDisk(persistentStore: SignalServiceAccountDataStore) {
|
||||
for (id in removed) {
|
||||
persistentStore.removeSignedPreKey(id)
|
||||
}
|
||||
|
||||
removed.clear()
|
||||
}
|
||||
}
|
|
@ -1,7 +1,10 @@
|
|||
package org.whispersystems.signalservice.api;
|
||||
|
||||
import com.google.protobuf.InvalidProtocolBufferException;
|
||||
|
||||
import org.signal.libsignal.protocol.logging.Log;
|
||||
import org.whispersystems.signalservice.api.crypto.UnidentifiedAccess;
|
||||
import org.whispersystems.signalservice.api.messages.EnvelopeResponse;
|
||||
import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState;
|
||||
import org.whispersystems.signalservice.api.websocket.WebSocketFactory;
|
||||
import org.whispersystems.signalservice.api.websocket.WebSocketUnavailableException;
|
||||
|
@ -13,9 +16,10 @@ import org.whispersystems.signalservice.internal.websocket.WebsocketResponse;
|
|||
import org.whispersystems.util.Base64;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.TimeoutException;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
|
||||
import io.reactivex.rxjava3.core.Observable;
|
||||
import io.reactivex.rxjava3.core.Single;
|
||||
|
@ -218,61 +222,105 @@ public final class SignalWebSocket {
|
|||
}
|
||||
|
||||
/**
|
||||
* <p>
|
||||
* A blocking call that reads a message off the pipe. When this call returns, if the callback indicates the
|
||||
* message was successfully processed, then the message will be ack'ed on the serve and will not be retransmitted.
|
||||
* <p>
|
||||
* This will return true if there are more messages to be read from the websocket, or false if the websocket is empty.
|
||||
* <p>
|
||||
* You can specify a {@link MessageReceivedCallback} that will be called before the received message is acknowledged.
|
||||
* This allows you to write the received message to durable storage before acknowledging receipt of it to the
|
||||
* server.
|
||||
* <p>
|
||||
* Important: This will only return `false` once for each connection. That means if you get false call readMessage()
|
||||
* again on the same instance, you will not get an immediate `false` return value, and instead will block until
|
||||
* you get an actual message. This will, however, reset if connection breaks (if, for instance, you lose and regain network).
|
||||
* The reads a batch of messages off of the websocket.
|
||||
*
|
||||
* @param timeout The timeout to wait for.
|
||||
* @param callback A callback that will be called before the message receipt is acknowledged to the server.
|
||||
* @return The message read (same as the message sent through the callback).
|
||||
* Rather than just provide you the batch as a return value, it will invoke the provided callback with the
|
||||
* batch as an argument. If you are able to successfully process them, this method will then ack all of the
|
||||
* messages so that they won't be re-delivered in the future.
|
||||
*
|
||||
* The return value of this method is a boolean indicating whether or not there are more messages in the
|
||||
* queue to be read (true if there's still more, or false if you've drained everything).
|
||||
*
|
||||
* However, this return value is only really useful the first time you read from the websocket. That's because
|
||||
* the websocket will only ever let you know if it's drained *once* for any given connection. So if this method
|
||||
* returns false, a subsequent call while using the same websocket connection will simply block until we either
|
||||
* get a new message or hit the timeout.
|
||||
*
|
||||
* Concerning the requested batch size, it's worth noting that this is simply an upper bound. This method will
|
||||
* not wait extra time until the batch has "filled up". Instead, it will wait for a single message, and then
|
||||
* take any extra messages that are also available up until you've hit your batch size.
|
||||
*/
|
||||
@SuppressWarnings("DuplicateThrows")
|
||||
public boolean readMessage(long timeout, MessageReceivedCallback callback)
|
||||
public boolean readMessageBatch(long timeout, int batchSize, MessageReceivedCallback callback)
|
||||
throws TimeoutException, WebSocketUnavailableException, IOException
|
||||
{
|
||||
while (true) {
|
||||
WebSocketRequestMessage request = getWebSocket().readRequest(timeout);
|
||||
WebSocketResponseMessage response = createWebSocketResponse(request);
|
||||
List<EnvelopeResponse> responses = new ArrayList<>();
|
||||
boolean hitEndOfQueue = false;
|
||||
|
||||
AtomicBoolean successfullyProcessed = new AtomicBoolean(false);
|
||||
Optional<EnvelopeResponse> firstEnvelope = waitForSingleMessage(timeout);
|
||||
|
||||
try {
|
||||
if (isSignalServiceEnvelope(request)) {
|
||||
Optional<String> timestampHeader = findHeader(request);
|
||||
long timestamp = 0;
|
||||
if (firstEnvelope.isPresent()) {
|
||||
responses.add(firstEnvelope.get());
|
||||
} else {
|
||||
hitEndOfQueue = true;
|
||||
}
|
||||
|
||||
if (timestampHeader.isPresent()) {
|
||||
try {
|
||||
timestamp = Long.parseLong(timestampHeader.get());
|
||||
} catch (NumberFormatException e) {
|
||||
Log.w(TAG, "Failed to parse " + SERVER_DELIVERED_TIMESTAMP_HEADER);
|
||||
}
|
||||
if (!hitEndOfQueue) {
|
||||
for (int i = 1; i < batchSize; i++) {
|
||||
Optional<WebSocketRequestMessage> request = getWebSocket().readRequestIfAvailable();
|
||||
|
||||
if (request.isPresent()) {
|
||||
if (isSignalServiceEnvelope(request.get())) {
|
||||
responses.add(requestToEnvelopeResponse(request.get()));
|
||||
} else if (isSocketEmptyRequest(request.get())) {
|
||||
hitEndOfQueue = true;
|
||||
break;
|
||||
}
|
||||
|
||||
SignalServiceProtos.Envelope envelope = SignalServiceProtos.Envelope.parseFrom(request.getBody().toByteArray());
|
||||
|
||||
successfullyProcessed.set(callback.onMessage(envelope, timestamp));
|
||||
|
||||
return true;
|
||||
} else if (isSocketEmptyRequest(request)) {
|
||||
return false;
|
||||
}
|
||||
} finally {
|
||||
if (successfullyProcessed.get()) {
|
||||
getWebSocket().sendResponse(response);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (responses.size() > 0) {
|
||||
boolean successfullyProcessed = false;
|
||||
|
||||
try {
|
||||
successfullyProcessed = callback.onMessageBatch(responses);
|
||||
} finally {
|
||||
if (successfullyProcessed) {
|
||||
for (EnvelopeResponse response : responses) {
|
||||
getWebSocket().sendResponse(createWebSocketResponse(response.getWebsocketRequest()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return !hitEndOfQueue;
|
||||
}
|
||||
|
||||
@SuppressWarnings("DuplicateThrows")
|
||||
private Optional<EnvelopeResponse> waitForSingleMessage(long timeout)
|
||||
throws TimeoutException, WebSocketUnavailableException, IOException
|
||||
{
|
||||
while (true) {
|
||||
WebSocketRequestMessage request = getWebSocket().readRequest(timeout);
|
||||
|
||||
if (isSignalServiceEnvelope(request)) {
|
||||
return Optional.of(requestToEnvelopeResponse(request));
|
||||
} else if (isSocketEmptyRequest(request)) {
|
||||
return Optional.empty();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static EnvelopeResponse requestToEnvelopeResponse(WebSocketRequestMessage request)
|
||||
throws InvalidProtocolBufferException
|
||||
{
|
||||
Optional<String> timestampHeader = findHeader(request);
|
||||
long timestamp = 0;
|
||||
|
||||
if (timestampHeader.isPresent()) {
|
||||
try {
|
||||
timestamp = Long.parseLong(timestampHeader.get());
|
||||
} catch (NumberFormatException e) {
|
||||
Log.w(TAG, "Failed to parse " + SERVER_DELIVERED_TIMESTAMP_HEADER);
|
||||
}
|
||||
}
|
||||
|
||||
SignalServiceProtos.Envelope envelope = SignalServiceProtos.Envelope.parseFrom(request.getBody().toByteArray());
|
||||
|
||||
return new EnvelopeResponse(envelope, timestamp, request);
|
||||
}
|
||||
|
||||
private static boolean isSignalServiceEnvelope(WebSocketRequestMessage message) {
|
||||
|
@ -323,6 +371,6 @@ public final class SignalWebSocket {
|
|||
public interface MessageReceivedCallback {
|
||||
|
||||
/** True if you successfully processed the message, otherwise false. **/
|
||||
boolean onMessage(SignalServiceProtos.Envelope envelope, long serverDeliveredTimestamp);
|
||||
boolean onMessageBatch(List<EnvelopeResponse> envelopeResponses);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,13 @@
|
|||
package org.whispersystems.signalservice.api.messages
|
||||
|
||||
import org.whispersystems.signalservice.internal.push.SignalServiceProtos.Envelope
|
||||
import org.whispersystems.signalservice.internal.websocket.WebSocketProtos.WebSocketRequestMessage
|
||||
|
||||
/**
|
||||
* Represents an envelope off the wire, paired with the metadata needed to process it.
|
||||
*/
|
||||
class EnvelopeResponse(
|
||||
val envelope: Envelope,
|
||||
val serverDeliveredTimestamp: Long,
|
||||
val websocketRequest: WebSocketRequestMessage
|
||||
)
|
|
@ -193,6 +193,14 @@ public class WebSocketConnection extends WebSocketListener {
|
|||
notifyAll();
|
||||
}
|
||||
|
||||
public synchronized Optional<WebSocketRequestMessage> readRequestIfAvailable() {
|
||||
if (incomingRequests.size() > 0) {
|
||||
return Optional.of(incomingRequests.removeFirst());
|
||||
} else {
|
||||
return Optional.empty();
|
||||
}
|
||||
}
|
||||
|
||||
public synchronized WebSocketRequestMessage readRequest(long timeoutMillis)
|
||||
throws TimeoutException, IOException
|
||||
{
|
||||
|
|
Loading…
Add table
Reference in a new issue