From ee685936c59445a4d9dd38419a97e99fc202206e Mon Sep 17 00:00:00 2001 From: Greyson Parrelli Date: Mon, 13 Mar 2023 13:50:32 -0400 Subject: [PATCH] Updated MessageProcessingPerformanceTest to use websocket injection. --- ...umentationApplicationDependencyProvider.kt | 46 ++++++- .../MessageProcessingPerformanceTest.kt | 126 ++++++++++++------ .../securesms/testing/ResponseMocking.kt | 15 ++- .../messages/IncomingMessageObserver.kt | 7 +- 4 files changed, 143 insertions(+), 51 deletions(-) diff --git a/app/src/androidTest/java/org/thoughtcrime/securesms/dependencies/InstrumentationApplicationDependencyProvider.kt b/app/src/androidTest/java/org/thoughtcrime/securesms/dependencies/InstrumentationApplicationDependencyProvider.kt index 374bb394a4..680e85add2 100644 --- a/app/src/androidTest/java/org/thoughtcrime/securesms/dependencies/InstrumentationApplicationDependencyProvider.kt +++ b/app/src/androidTest/java/org/thoughtcrime/securesms/dependencies/InstrumentationApplicationDependencyProvider.kt @@ -2,14 +2,18 @@ package org.thoughtcrime.securesms.dependencies import android.app.Application import okhttp3.ConnectionSpec +import okhttp3.Response +import okhttp3.WebSocket import okhttp3.WebSocketListener import okhttp3.mockwebserver.Dispatcher import okhttp3.mockwebserver.MockResponse import okhttp3.mockwebserver.MockWebServer import okhttp3.mockwebserver.RecordedRequest +import okio.ByteString import org.mockito.kotlin.any import org.mockito.kotlin.doReturn import org.mockito.kotlin.mock +import org.signal.core.util.logging.Log import org.thoughtcrime.securesms.BuildConfig import org.thoughtcrime.securesms.KbsEnclave import org.thoughtcrime.securesms.push.SignalServiceNetworkAccess @@ -52,7 +56,10 @@ class InstrumentationApplicationDependencyProvider(application: Application, def baseUrl = webServer.url("").toString() addMockWebRequestHandlers( - Get("/v1/websocket/") { + Get("/v1/websocket/?login=") { + MockResponse().success().withWebSocketUpgrade(mockIdentifiedWebSocket) + }, + Get("/v1/websocket", { !it.path.contains("login") }) { MockResponse().success().withWebSocketUpgrade(object : WebSocketListener() {}) } ) @@ -60,9 +67,7 @@ class InstrumentationApplicationDependencyProvider(application: Application, def webServer.setDispatcher(object : Dispatcher() { override fun dispatch(request: RecordedRequest): MockResponse { - val handler = handlers.firstOrNull { - request.method == it.verb && request.path.startsWith("/${it.path}") - } + val handler = handlers.firstOrNull { it.requestPredicate(request) } return handler?.responseFactory?.invoke(request) ?: MockResponse().setResponseCode(500) } }) @@ -106,18 +111,51 @@ class InstrumentationApplicationDependencyProvider(application: Application, def return recipientCache } + class MockWebSocket : WebSocketListener() { + private val TAG = "MockWebSocket" + + var webSocket: WebSocket? = null + private set + + override fun onOpen(webSocket: WebSocket, response: Response) { + Log.i(TAG, "onOpen(${webSocket.hashCode()})") + this.webSocket = webSocket + } + + override fun onClosing(webSocket: WebSocket, code: Int, reason: String) { + Log.i(TAG, "onClosing(${webSocket.hashCode()}): $code, $reason") + this.webSocket = null + } + + override fun onClosed(webSocket: WebSocket, code: Int, reason: String) { + Log.i(TAG, "onClosed(${webSocket.hashCode()}): $code, $reason") + this.webSocket = null + } + + override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) { + Log.w(TAG, "onFailure(${webSocket.hashCode()})", t) + this.webSocket = null + } + } + companion object { lateinit var webServer: MockWebServer private set lateinit var baseUrl: String private set + val mockIdentifiedWebSocket = MockWebSocket() + private val handlers: MutableList = mutableListOf() fun addMockWebRequestHandlers(vararg verbs: Verb) { handlers.addAll(verbs) } + fun injectWebSocketMessage(value: ByteString) { + mockIdentifiedWebSocket.webSocket!!.send(value) + } + fun clearHandlers() { handlers.clear() } diff --git a/app/src/androidTest/java/org/thoughtcrime/securesms/messages/MessageProcessingPerformanceTest.kt b/app/src/androidTest/java/org/thoughtcrime/securesms/messages/MessageProcessingPerformanceTest.kt index c18e368d49..85b40c368c 100644 --- a/app/src/androidTest/java/org/thoughtcrime/securesms/messages/MessageProcessingPerformanceTest.kt +++ b/app/src/androidTest/java/org/thoughtcrime/securesms/messages/MessageProcessingPerformanceTest.kt @@ -4,6 +4,8 @@ import androidx.test.ext.junit.runners.AndroidJUnit4 import io.mockk.every import io.mockk.mockkStatic import io.mockk.unmockkStatic +import okio.ByteString +import okio.ByteString.Companion.toByteString import org.junit.After import org.junit.Before import org.junit.Ignore @@ -15,6 +17,7 @@ import org.signal.libsignal.protocol.ecc.Curve import org.signal.libsignal.protocol.ecc.ECKeyPair import org.signal.libsignal.zkgroup.profiles.ProfileKey import org.thoughtcrime.securesms.crypto.UnidentifiedAccessUtil +import org.thoughtcrime.securesms.dependencies.InstrumentationApplicationDependencyProvider import org.thoughtcrime.securesms.recipients.Recipient import org.thoughtcrime.securesms.testing.AliceClient import org.thoughtcrime.securesms.testing.BobClient @@ -23,6 +26,10 @@ import org.thoughtcrime.securesms.testing.FakeClientHelpers import org.thoughtcrime.securesms.testing.SignalActivityRule import org.thoughtcrime.securesms.testing.awaitFor import org.whispersystems.signalservice.internal.push.SignalServiceProtos.Envelope +import org.whispersystems.signalservice.internal.websocket.WebSocketProtos.WebSocketMessage +import org.whispersystems.signalservice.internal.websocket.WebSocketProtos.WebSocketRequestMessage +import java.util.regex.Pattern +import kotlin.random.Random import kotlin.time.Duration.Companion.minutes import kotlin.time.Duration.Companion.seconds import android.util.Log as AndroidLog @@ -37,6 +44,8 @@ class MessageProcessingPerformanceTest { companion object { private val TAG = Log.tag(MessageProcessingPerformanceTest::class.java) private val TIMING_TAG = "TIMING_$TAG".substring(0..23) + + private val DECRYPTION_TIME_PATTERN = Pattern.compile("^Decrypted (?\\d+) envelopes in (?\\d+) ms.*$") } @get:Rule @@ -76,64 +85,43 @@ class MessageProcessingPerformanceTest { profileKey = ProfileKey(bob.profileKey) ) - // Send message from Bob to Alice (self) - - val firstPreKeyMessageTimestamp = System.currentTimeMillis() - val encryptedEnvelope = bobClient.encrypt(firstPreKeyMessageTimestamp) - - val aliceProcessFirstMessageLatch = harness - .inMemoryLogger - .getLockForUntil(TimingMessageContentProcessor.endTagPredicate(firstPreKeyMessageTimestamp)) - - Thread { aliceClient.process(encryptedEnvelope, System.currentTimeMillis()) }.start() - aliceProcessFirstMessageLatch.awaitFor(15.seconds) - - // Send message from Alice to Bob - val aliceNow = System.currentTimeMillis() - bobClient.decrypt(aliceClient.encrypt(aliceNow, bob), aliceNow) - - // Build N messages from Bob to Alice + // Send the initial messages to get past the prekey phase + establishSession(aliceClient, bobClient, bob) + // Have Bob generate N messages that will be received by Alice val messageCount = 100 - val envelopes = ArrayList(messageCount) - var now = System.currentTimeMillis() - for (i in 0..messageCount) { - envelopes += bobClient.encrypt(now) - now += 3 - } - + val envelopes = generateInboundEnvelopes(bobClient, messageCount) val firstTimestamp = envelopes.first().timestamp val lastTimestamp = envelopes.last().timestamp - // Alice processes N messages - - val aliceProcessLastMessageLatch = harness - .inMemoryLogger - .getLockForUntil(TimingMessageContentProcessor.endTagPredicate(lastTimestamp)) - + // Inject the envelopes into the websocket Thread { for (envelope in envelopes) { Log.i(TIMING_TAG, "Retrieved envelope! ${envelope.timestamp}") - aliceClient.process(envelope, envelope.timestamp) + InstrumentationApplicationDependencyProvider.injectWebSocketMessage(envelope.toWebSocketPayload()) } + InstrumentationApplicationDependencyProvider.injectWebSocketMessage(webSocketTombstone()) }.start() - // Wait for Alice to finish processing messages - aliceProcessLastMessageLatch.awaitFor(1.minutes) + // Wait until they've all been fully decrypted + processed + harness + .inMemoryLogger + .getLockForUntil(TimingMessageContentProcessor.endTagPredicate(lastTimestamp)) + .awaitFor(1.minutes) + harness.inMemoryLogger.flush() // Process logs for timing data val entries = harness.inMemoryLogger.entries() // Calculate decryption average + val totalDecryptDuration: Long = entries + .mapNotNull { entry -> entry.message?.let { DECRYPTION_TIME_PATTERN.matcher(it) } } + .filter { it.matches() } + .drop(1) // Ignore the first message, which represents the prekey exchange + .sumOf { it.group("duration")!!.toLong() } - val decrypts = entries - .filter { it.tag == AliceClient.TAG } - .drop(1) - - val totalDecryptDuration = decrypts.sumOf { it.message!!.toLong() } - - AndroidLog.w(TAG, "Decryption: Average runtime: ${totalDecryptDuration.toFloat() / decrypts.size.toFloat()}ms") + AndroidLog.w(TAG, "Decryption: Average runtime: ${totalDecryptDuration.toFloat() / messageCount.toFloat()}ms") // Calculate MessageContentProcessor @@ -160,4 +148,62 @@ class MessageProcessingPerformanceTest { AndroidLog.w(TAG, "Processing $messageCount messages took ${duration}s or ${messagePerSecond}m/s") } + + private fun establishSession(aliceClient: AliceClient, bobClient: BobClient, bob: Recipient) { + // Send message from Bob to Alice (self) + val firstPreKeyMessageTimestamp = System.currentTimeMillis() + val encryptedEnvelope = bobClient.encrypt(firstPreKeyMessageTimestamp) + + val aliceProcessFirstMessageLatch = harness + .inMemoryLogger + .getLockForUntil(TimingMessageContentProcessor.endTagPredicate(firstPreKeyMessageTimestamp)) + + Thread { aliceClient.process(encryptedEnvelope, System.currentTimeMillis()) }.start() + aliceProcessFirstMessageLatch.awaitFor(15.seconds) + + // Send message from Alice to Bob + val aliceNow = System.currentTimeMillis() + bobClient.decrypt(aliceClient.encrypt(aliceNow, bob), aliceNow) + } + + private fun generateInboundEnvelopes(bobClient: BobClient, count: Int): List { + val envelopes = ArrayList(count) + var now = System.currentTimeMillis() + for (i in 0..count) { + envelopes += bobClient.encrypt(now) + now += 3 + } + + return envelopes + } + + private fun webSocketTombstone(): ByteString { + return WebSocketMessage + .newBuilder() + .setRequest( + WebSocketRequestMessage.newBuilder() + .setVerb("PUT") + .setPath("/api/v1/queue/empty") + ) + .build() + .toByteArray() + .toByteString() + } + + private fun Envelope.toWebSocketPayload(): ByteString { + return WebSocketMessage + .newBuilder() + .setType(WebSocketMessage.Type.REQUEST) + .setRequest( + WebSocketRequestMessage.newBuilder() + .setVerb("PUT") + .setPath("/api/v1/message") + .setId(Random(System.currentTimeMillis()).nextLong()) + .addHeaders("X-Signal-Timestamp: ${this.timestamp}") + .setBody(this.toByteString()) + ) + .build() + .toByteArray() + .toByteString() + } } diff --git a/app/src/androidTest/java/org/thoughtcrime/securesms/testing/ResponseMocking.kt b/app/src/androidTest/java/org/thoughtcrime/securesms/testing/ResponseMocking.kt index dfe45effec..fd93a3ee10 100644 --- a/app/src/androidTest/java/org/thoughtcrime/securesms/testing/ResponseMocking.kt +++ b/app/src/androidTest/java/org/thoughtcrime/securesms/testing/ResponseMocking.kt @@ -7,17 +7,20 @@ import org.thoughtcrime.securesms.util.JsonUtils import java.util.concurrent.TimeUnit typealias ResponseFactory = (request: RecordedRequest) -> MockResponse +typealias RequestPredicate = (request: RecordedRequest) -> Boolean /** * Represent an HTTP verb for mocking web requests. */ -sealed class Verb(val verb: String, val path: String, val responseFactory: ResponseFactory) +sealed class Verb(val requestPredicate: RequestPredicate, val responseFactory: ResponseFactory) -class Get(path: String, responseFactory: ResponseFactory) : Verb("GET", path, responseFactory) +class Get(path: String, predicate: RequestPredicate, responseFactory: ResponseFactory) : Verb(defaultRequestPredicate("GET", path, predicate), responseFactory) { + constructor(path: String, responseFactory: ResponseFactory) : this(path, { true }, responseFactory) +} -class Put(path: String, responseFactory: ResponseFactory) : Verb("PUT", path, responseFactory) +class Put(path: String, responseFactory: ResponseFactory) : Verb(defaultRequestPredicate("PUT", path), responseFactory) -class Post(path: String, responseFactory: ResponseFactory) : Verb("POST", path, responseFactory) +class Post(path: String, responseFactory: ResponseFactory) : Verb(defaultRequestPredicate("POST", path), responseFactory) fun MockResponse.success(response: Any? = null): MockResponse { return setResponseCode(200).apply { @@ -48,3 +51,7 @@ inline fun RecordedRequest.parsedRequestBody(): T { val bodyString = String(body.readByteArray()) return JsonUtils.fromJson(bodyString, T::class.java) } + +private fun defaultRequestPredicate(verb: String, path: String, predicate: RequestPredicate = { true }): RequestPredicate = { request -> + request.method == verb && request.path.startsWith("/$path") && predicate(request) +} diff --git a/app/src/main/java/org/thoughtcrime/securesms/messages/IncomingMessageObserver.kt b/app/src/main/java/org/thoughtcrime/securesms/messages/IncomingMessageObserver.kt index 9c3d100250..aa4beccfe2 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/messages/IncomingMessageObserver.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/messages/IncomingMessageObserver.kt @@ -221,11 +221,12 @@ class IncomingMessageObserver(private val context: Application) { } fun terminateAsync() { + Log.w(TAG, "Termination Enqueued! ${this.hashCode()}", Throwable()) INSTANCE_COUNT.decrementAndGet() context.unregisterReceiver(connectionReceiver) SignalExecutors.BOUNDED.execute { - Log.w(TAG, "Beginning termination.") + Log.w(TAG, "Beginning termination. ${this.hashCode()}") terminated = true disconnect() } @@ -371,7 +372,7 @@ class IncomingMessageObserver(private val context: Application) { private inner class MessageRetrievalThread : Thread("MessageRetrievalService"), Thread.UncaughtExceptionHandler { init { - Log.i(TAG, "Initializing! (" + this.hashCode() + ")") + Log.i(TAG, "Initializing! (${this.hashCode()})") uncaughtExceptionHandler = this } @@ -461,7 +462,7 @@ class IncomingMessageObserver(private val context: Application) { } Log.i(TAG, "Looping...") } - Log.w(TAG, "Terminated! (" + this.hashCode() + ")") + Log.w(TAG, "Terminated! (${this.hashCode()})") } override fun uncaughtException(t: Thread, e: Throwable) {