Updated MessageProcessingPerformanceTest to use websocket injection.

This commit is contained in:
Greyson Parrelli 2023-03-13 13:50:32 -04:00
parent a7bca89889
commit ee685936c5
4 changed files with 143 additions and 51 deletions

View file

@ -2,14 +2,18 @@ package org.thoughtcrime.securesms.dependencies
import android.app.Application import android.app.Application
import okhttp3.ConnectionSpec import okhttp3.ConnectionSpec
import okhttp3.Response
import okhttp3.WebSocket
import okhttp3.WebSocketListener import okhttp3.WebSocketListener
import okhttp3.mockwebserver.Dispatcher import okhttp3.mockwebserver.Dispatcher
import okhttp3.mockwebserver.MockResponse import okhttp3.mockwebserver.MockResponse
import okhttp3.mockwebserver.MockWebServer import okhttp3.mockwebserver.MockWebServer
import okhttp3.mockwebserver.RecordedRequest import okhttp3.mockwebserver.RecordedRequest
import okio.ByteString
import org.mockito.kotlin.any import org.mockito.kotlin.any
import org.mockito.kotlin.doReturn import org.mockito.kotlin.doReturn
import org.mockito.kotlin.mock import org.mockito.kotlin.mock
import org.signal.core.util.logging.Log
import org.thoughtcrime.securesms.BuildConfig import org.thoughtcrime.securesms.BuildConfig
import org.thoughtcrime.securesms.KbsEnclave import org.thoughtcrime.securesms.KbsEnclave
import org.thoughtcrime.securesms.push.SignalServiceNetworkAccess import org.thoughtcrime.securesms.push.SignalServiceNetworkAccess
@ -52,7 +56,10 @@ class InstrumentationApplicationDependencyProvider(application: Application, def
baseUrl = webServer.url("").toString() baseUrl = webServer.url("").toString()
addMockWebRequestHandlers( addMockWebRequestHandlers(
Get("/v1/websocket/") { Get("/v1/websocket/?login=") {
MockResponse().success().withWebSocketUpgrade(mockIdentifiedWebSocket)
},
Get("/v1/websocket", { !it.path.contains("login") }) {
MockResponse().success().withWebSocketUpgrade(object : WebSocketListener() {}) MockResponse().success().withWebSocketUpgrade(object : WebSocketListener() {})
} }
) )
@ -60,9 +67,7 @@ class InstrumentationApplicationDependencyProvider(application: Application, def
webServer.setDispatcher(object : Dispatcher() { webServer.setDispatcher(object : Dispatcher() {
override fun dispatch(request: RecordedRequest): MockResponse { override fun dispatch(request: RecordedRequest): MockResponse {
val handler = handlers.firstOrNull { val handler = handlers.firstOrNull { it.requestPredicate(request) }
request.method == it.verb && request.path.startsWith("/${it.path}")
}
return handler?.responseFactory?.invoke(request) ?: MockResponse().setResponseCode(500) return handler?.responseFactory?.invoke(request) ?: MockResponse().setResponseCode(500)
} }
}) })
@ -106,18 +111,51 @@ class InstrumentationApplicationDependencyProvider(application: Application, def
return recipientCache return recipientCache
} }
class MockWebSocket : WebSocketListener() {
private val TAG = "MockWebSocket"
var webSocket: WebSocket? = null
private set
override fun onOpen(webSocket: WebSocket, response: Response) {
Log.i(TAG, "onOpen(${webSocket.hashCode()})")
this.webSocket = webSocket
}
override fun onClosing(webSocket: WebSocket, code: Int, reason: String) {
Log.i(TAG, "onClosing(${webSocket.hashCode()}): $code, $reason")
this.webSocket = null
}
override fun onClosed(webSocket: WebSocket, code: Int, reason: String) {
Log.i(TAG, "onClosed(${webSocket.hashCode()}): $code, $reason")
this.webSocket = null
}
override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) {
Log.w(TAG, "onFailure(${webSocket.hashCode()})", t)
this.webSocket = null
}
}
companion object { companion object {
lateinit var webServer: MockWebServer lateinit var webServer: MockWebServer
private set private set
lateinit var baseUrl: String lateinit var baseUrl: String
private set private set
val mockIdentifiedWebSocket = MockWebSocket()
private val handlers: MutableList<Verb> = mutableListOf() private val handlers: MutableList<Verb> = mutableListOf()
fun addMockWebRequestHandlers(vararg verbs: Verb) { fun addMockWebRequestHandlers(vararg verbs: Verb) {
handlers.addAll(verbs) handlers.addAll(verbs)
} }
fun injectWebSocketMessage(value: ByteString) {
mockIdentifiedWebSocket.webSocket!!.send(value)
}
fun clearHandlers() { fun clearHandlers() {
handlers.clear() handlers.clear()
} }

View file

@ -4,6 +4,8 @@ import androidx.test.ext.junit.runners.AndroidJUnit4
import io.mockk.every import io.mockk.every
import io.mockk.mockkStatic import io.mockk.mockkStatic
import io.mockk.unmockkStatic import io.mockk.unmockkStatic
import okio.ByteString
import okio.ByteString.Companion.toByteString
import org.junit.After import org.junit.After
import org.junit.Before import org.junit.Before
import org.junit.Ignore import org.junit.Ignore
@ -15,6 +17,7 @@ import org.signal.libsignal.protocol.ecc.Curve
import org.signal.libsignal.protocol.ecc.ECKeyPair import org.signal.libsignal.protocol.ecc.ECKeyPair
import org.signal.libsignal.zkgroup.profiles.ProfileKey import org.signal.libsignal.zkgroup.profiles.ProfileKey
import org.thoughtcrime.securesms.crypto.UnidentifiedAccessUtil import org.thoughtcrime.securesms.crypto.UnidentifiedAccessUtil
import org.thoughtcrime.securesms.dependencies.InstrumentationApplicationDependencyProvider
import org.thoughtcrime.securesms.recipients.Recipient import org.thoughtcrime.securesms.recipients.Recipient
import org.thoughtcrime.securesms.testing.AliceClient import org.thoughtcrime.securesms.testing.AliceClient
import org.thoughtcrime.securesms.testing.BobClient import org.thoughtcrime.securesms.testing.BobClient
@ -23,6 +26,10 @@ import org.thoughtcrime.securesms.testing.FakeClientHelpers
import org.thoughtcrime.securesms.testing.SignalActivityRule import org.thoughtcrime.securesms.testing.SignalActivityRule
import org.thoughtcrime.securesms.testing.awaitFor import org.thoughtcrime.securesms.testing.awaitFor
import org.whispersystems.signalservice.internal.push.SignalServiceProtos.Envelope import org.whispersystems.signalservice.internal.push.SignalServiceProtos.Envelope
import org.whispersystems.signalservice.internal.websocket.WebSocketProtos.WebSocketMessage
import org.whispersystems.signalservice.internal.websocket.WebSocketProtos.WebSocketRequestMessage
import java.util.regex.Pattern
import kotlin.random.Random
import kotlin.time.Duration.Companion.minutes import kotlin.time.Duration.Companion.minutes
import kotlin.time.Duration.Companion.seconds import kotlin.time.Duration.Companion.seconds
import android.util.Log as AndroidLog import android.util.Log as AndroidLog
@ -37,6 +44,8 @@ class MessageProcessingPerformanceTest {
companion object { companion object {
private val TAG = Log.tag(MessageProcessingPerformanceTest::class.java) private val TAG = Log.tag(MessageProcessingPerformanceTest::class.java)
private val TIMING_TAG = "TIMING_$TAG".substring(0..23) private val TIMING_TAG = "TIMING_$TAG".substring(0..23)
private val DECRYPTION_TIME_PATTERN = Pattern.compile("^Decrypted (?<count>\\d+) envelopes in (?<duration>\\d+) ms.*$")
} }
@get:Rule @get:Rule
@ -76,64 +85,43 @@ class MessageProcessingPerformanceTest {
profileKey = ProfileKey(bob.profileKey) profileKey = ProfileKey(bob.profileKey)
) )
// Send message from Bob to Alice (self) // Send the initial messages to get past the prekey phase
establishSession(aliceClient, bobClient, bob)
val firstPreKeyMessageTimestamp = System.currentTimeMillis()
val encryptedEnvelope = bobClient.encrypt(firstPreKeyMessageTimestamp)
val aliceProcessFirstMessageLatch = harness
.inMemoryLogger
.getLockForUntil(TimingMessageContentProcessor.endTagPredicate(firstPreKeyMessageTimestamp))
Thread { aliceClient.process(encryptedEnvelope, System.currentTimeMillis()) }.start()
aliceProcessFirstMessageLatch.awaitFor(15.seconds)
// Send message from Alice to Bob
val aliceNow = System.currentTimeMillis()
bobClient.decrypt(aliceClient.encrypt(aliceNow, bob), aliceNow)
// Build N messages from Bob to Alice
// Have Bob generate N messages that will be received by Alice
val messageCount = 100 val messageCount = 100
val envelopes = ArrayList<Envelope>(messageCount) val envelopes = generateInboundEnvelopes(bobClient, messageCount)
var now = System.currentTimeMillis()
for (i in 0..messageCount) {
envelopes += bobClient.encrypt(now)
now += 3
}
val firstTimestamp = envelopes.first().timestamp val firstTimestamp = envelopes.first().timestamp
val lastTimestamp = envelopes.last().timestamp val lastTimestamp = envelopes.last().timestamp
// Alice processes N messages // Inject the envelopes into the websocket
val aliceProcessLastMessageLatch = harness
.inMemoryLogger
.getLockForUntil(TimingMessageContentProcessor.endTagPredicate(lastTimestamp))
Thread { Thread {
for (envelope in envelopes) { for (envelope in envelopes) {
Log.i(TIMING_TAG, "Retrieved envelope! ${envelope.timestamp}") Log.i(TIMING_TAG, "Retrieved envelope! ${envelope.timestamp}")
aliceClient.process(envelope, envelope.timestamp) InstrumentationApplicationDependencyProvider.injectWebSocketMessage(envelope.toWebSocketPayload())
} }
InstrumentationApplicationDependencyProvider.injectWebSocketMessage(webSocketTombstone())
}.start() }.start()
// Wait for Alice to finish processing messages // Wait until they've all been fully decrypted + processed
aliceProcessLastMessageLatch.awaitFor(1.minutes) harness
.inMemoryLogger
.getLockForUntil(TimingMessageContentProcessor.endTagPredicate(lastTimestamp))
.awaitFor(1.minutes)
harness.inMemoryLogger.flush() harness.inMemoryLogger.flush()
// Process logs for timing data // Process logs for timing data
val entries = harness.inMemoryLogger.entries() val entries = harness.inMemoryLogger.entries()
// Calculate decryption average // Calculate decryption average
val totalDecryptDuration: Long = entries
.mapNotNull { entry -> entry.message?.let { DECRYPTION_TIME_PATTERN.matcher(it) } }
.filter { it.matches() }
.drop(1) // Ignore the first message, which represents the prekey exchange
.sumOf { it.group("duration")!!.toLong() }
val decrypts = entries AndroidLog.w(TAG, "Decryption: Average runtime: ${totalDecryptDuration.toFloat() / messageCount.toFloat()}ms")
.filter { it.tag == AliceClient.TAG }
.drop(1)
val totalDecryptDuration = decrypts.sumOf { it.message!!.toLong() }
AndroidLog.w(TAG, "Decryption: Average runtime: ${totalDecryptDuration.toFloat() / decrypts.size.toFloat()}ms")
// Calculate MessageContentProcessor // Calculate MessageContentProcessor
@ -160,4 +148,62 @@ class MessageProcessingPerformanceTest {
AndroidLog.w(TAG, "Processing $messageCount messages took ${duration}s or ${messagePerSecond}m/s") AndroidLog.w(TAG, "Processing $messageCount messages took ${duration}s or ${messagePerSecond}m/s")
} }
private fun establishSession(aliceClient: AliceClient, bobClient: BobClient, bob: Recipient) {
// Send message from Bob to Alice (self)
val firstPreKeyMessageTimestamp = System.currentTimeMillis()
val encryptedEnvelope = bobClient.encrypt(firstPreKeyMessageTimestamp)
val aliceProcessFirstMessageLatch = harness
.inMemoryLogger
.getLockForUntil(TimingMessageContentProcessor.endTagPredicate(firstPreKeyMessageTimestamp))
Thread { aliceClient.process(encryptedEnvelope, System.currentTimeMillis()) }.start()
aliceProcessFirstMessageLatch.awaitFor(15.seconds)
// Send message from Alice to Bob
val aliceNow = System.currentTimeMillis()
bobClient.decrypt(aliceClient.encrypt(aliceNow, bob), aliceNow)
}
private fun generateInboundEnvelopes(bobClient: BobClient, count: Int): List<Envelope> {
val envelopes = ArrayList<Envelope>(count)
var now = System.currentTimeMillis()
for (i in 0..count) {
envelopes += bobClient.encrypt(now)
now += 3
}
return envelopes
}
private fun webSocketTombstone(): ByteString {
return WebSocketMessage
.newBuilder()
.setRequest(
WebSocketRequestMessage.newBuilder()
.setVerb("PUT")
.setPath("/api/v1/queue/empty")
)
.build()
.toByteArray()
.toByteString()
}
private fun Envelope.toWebSocketPayload(): ByteString {
return WebSocketMessage
.newBuilder()
.setType(WebSocketMessage.Type.REQUEST)
.setRequest(
WebSocketRequestMessage.newBuilder()
.setVerb("PUT")
.setPath("/api/v1/message")
.setId(Random(System.currentTimeMillis()).nextLong())
.addHeaders("X-Signal-Timestamp: ${this.timestamp}")
.setBody(this.toByteString())
)
.build()
.toByteArray()
.toByteString()
}
} }

View file

@ -7,17 +7,20 @@ import org.thoughtcrime.securesms.util.JsonUtils
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
typealias ResponseFactory = (request: RecordedRequest) -> MockResponse typealias ResponseFactory = (request: RecordedRequest) -> MockResponse
typealias RequestPredicate = (request: RecordedRequest) -> Boolean
/** /**
* Represent an HTTP verb for mocking web requests. * Represent an HTTP verb for mocking web requests.
*/ */
sealed class Verb(val verb: String, val path: String, val responseFactory: ResponseFactory) sealed class Verb(val requestPredicate: RequestPredicate, val responseFactory: ResponseFactory)
class Get(path: String, responseFactory: ResponseFactory) : Verb("GET", path, responseFactory) class Get(path: String, predicate: RequestPredicate, responseFactory: ResponseFactory) : Verb(defaultRequestPredicate("GET", path, predicate), responseFactory) {
constructor(path: String, responseFactory: ResponseFactory) : this(path, { true }, responseFactory)
}
class Put(path: String, responseFactory: ResponseFactory) : Verb("PUT", path, responseFactory) class Put(path: String, responseFactory: ResponseFactory) : Verb(defaultRequestPredicate("PUT", path), responseFactory)
class Post(path: String, responseFactory: ResponseFactory) : Verb("POST", path, responseFactory) class Post(path: String, responseFactory: ResponseFactory) : Verb(defaultRequestPredicate("POST", path), responseFactory)
fun MockResponse.success(response: Any? = null): MockResponse { fun MockResponse.success(response: Any? = null): MockResponse {
return setResponseCode(200).apply { return setResponseCode(200).apply {
@ -48,3 +51,7 @@ inline fun <reified T> RecordedRequest.parsedRequestBody(): T {
val bodyString = String(body.readByteArray()) val bodyString = String(body.readByteArray())
return JsonUtils.fromJson(bodyString, T::class.java) return JsonUtils.fromJson(bodyString, T::class.java)
} }
private fun defaultRequestPredicate(verb: String, path: String, predicate: RequestPredicate = { true }): RequestPredicate = { request ->
request.method == verb && request.path.startsWith("/$path") && predicate(request)
}

View file

@ -221,11 +221,12 @@ class IncomingMessageObserver(private val context: Application) {
} }
fun terminateAsync() { fun terminateAsync() {
Log.w(TAG, "Termination Enqueued! ${this.hashCode()}", Throwable())
INSTANCE_COUNT.decrementAndGet() INSTANCE_COUNT.decrementAndGet()
context.unregisterReceiver(connectionReceiver) context.unregisterReceiver(connectionReceiver)
SignalExecutors.BOUNDED.execute { SignalExecutors.BOUNDED.execute {
Log.w(TAG, "Beginning termination.") Log.w(TAG, "Beginning termination. ${this.hashCode()}")
terminated = true terminated = true
disconnect() disconnect()
} }
@ -371,7 +372,7 @@ class IncomingMessageObserver(private val context: Application) {
private inner class MessageRetrievalThread : Thread("MessageRetrievalService"), Thread.UncaughtExceptionHandler { private inner class MessageRetrievalThread : Thread("MessageRetrievalService"), Thread.UncaughtExceptionHandler {
init { init {
Log.i(TAG, "Initializing! (" + this.hashCode() + ")") Log.i(TAG, "Initializing! (${this.hashCode()})")
uncaughtExceptionHandler = this uncaughtExceptionHandler = this
} }
@ -461,7 +462,7 @@ class IncomingMessageObserver(private val context: Application) {
} }
Log.i(TAG, "Looping...") Log.i(TAG, "Looping...")
} }
Log.w(TAG, "Terminated! (" + this.hashCode() + ")") Log.w(TAG, "Terminated! (${this.hashCode()})")
} }
override fun uncaughtException(t: Thread, e: Throwable) { override fun uncaughtException(t: Thread, e: Throwable) {