Updated MessageProcessingPerformanceTest to use websocket injection.
This commit is contained in:
parent
a7bca89889
commit
ee685936c5
4 changed files with 143 additions and 51 deletions
|
@ -2,14 +2,18 @@ package org.thoughtcrime.securesms.dependencies
|
||||||
|
|
||||||
import android.app.Application
|
import android.app.Application
|
||||||
import okhttp3.ConnectionSpec
|
import okhttp3.ConnectionSpec
|
||||||
|
import okhttp3.Response
|
||||||
|
import okhttp3.WebSocket
|
||||||
import okhttp3.WebSocketListener
|
import okhttp3.WebSocketListener
|
||||||
import okhttp3.mockwebserver.Dispatcher
|
import okhttp3.mockwebserver.Dispatcher
|
||||||
import okhttp3.mockwebserver.MockResponse
|
import okhttp3.mockwebserver.MockResponse
|
||||||
import okhttp3.mockwebserver.MockWebServer
|
import okhttp3.mockwebserver.MockWebServer
|
||||||
import okhttp3.mockwebserver.RecordedRequest
|
import okhttp3.mockwebserver.RecordedRequest
|
||||||
|
import okio.ByteString
|
||||||
import org.mockito.kotlin.any
|
import org.mockito.kotlin.any
|
||||||
import org.mockito.kotlin.doReturn
|
import org.mockito.kotlin.doReturn
|
||||||
import org.mockito.kotlin.mock
|
import org.mockito.kotlin.mock
|
||||||
|
import org.signal.core.util.logging.Log
|
||||||
import org.thoughtcrime.securesms.BuildConfig
|
import org.thoughtcrime.securesms.BuildConfig
|
||||||
import org.thoughtcrime.securesms.KbsEnclave
|
import org.thoughtcrime.securesms.KbsEnclave
|
||||||
import org.thoughtcrime.securesms.push.SignalServiceNetworkAccess
|
import org.thoughtcrime.securesms.push.SignalServiceNetworkAccess
|
||||||
|
@ -52,7 +56,10 @@ class InstrumentationApplicationDependencyProvider(application: Application, def
|
||||||
baseUrl = webServer.url("").toString()
|
baseUrl = webServer.url("").toString()
|
||||||
|
|
||||||
addMockWebRequestHandlers(
|
addMockWebRequestHandlers(
|
||||||
Get("/v1/websocket/") {
|
Get("/v1/websocket/?login=") {
|
||||||
|
MockResponse().success().withWebSocketUpgrade(mockIdentifiedWebSocket)
|
||||||
|
},
|
||||||
|
Get("/v1/websocket", { !it.path.contains("login") }) {
|
||||||
MockResponse().success().withWebSocketUpgrade(object : WebSocketListener() {})
|
MockResponse().success().withWebSocketUpgrade(object : WebSocketListener() {})
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
@ -60,9 +67,7 @@ class InstrumentationApplicationDependencyProvider(application: Application, def
|
||||||
|
|
||||||
webServer.setDispatcher(object : Dispatcher() {
|
webServer.setDispatcher(object : Dispatcher() {
|
||||||
override fun dispatch(request: RecordedRequest): MockResponse {
|
override fun dispatch(request: RecordedRequest): MockResponse {
|
||||||
val handler = handlers.firstOrNull {
|
val handler = handlers.firstOrNull { it.requestPredicate(request) }
|
||||||
request.method == it.verb && request.path.startsWith("/${it.path}")
|
|
||||||
}
|
|
||||||
return handler?.responseFactory?.invoke(request) ?: MockResponse().setResponseCode(500)
|
return handler?.responseFactory?.invoke(request) ?: MockResponse().setResponseCode(500)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -106,18 +111,51 @@ class InstrumentationApplicationDependencyProvider(application: Application, def
|
||||||
return recipientCache
|
return recipientCache
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class MockWebSocket : WebSocketListener() {
|
||||||
|
private val TAG = "MockWebSocket"
|
||||||
|
|
||||||
|
var webSocket: WebSocket? = null
|
||||||
|
private set
|
||||||
|
|
||||||
|
override fun onOpen(webSocket: WebSocket, response: Response) {
|
||||||
|
Log.i(TAG, "onOpen(${webSocket.hashCode()})")
|
||||||
|
this.webSocket = webSocket
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun onClosing(webSocket: WebSocket, code: Int, reason: String) {
|
||||||
|
Log.i(TAG, "onClosing(${webSocket.hashCode()}): $code, $reason")
|
||||||
|
this.webSocket = null
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun onClosed(webSocket: WebSocket, code: Int, reason: String) {
|
||||||
|
Log.i(TAG, "onClosed(${webSocket.hashCode()}): $code, $reason")
|
||||||
|
this.webSocket = null
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) {
|
||||||
|
Log.w(TAG, "onFailure(${webSocket.hashCode()})", t)
|
||||||
|
this.webSocket = null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
lateinit var webServer: MockWebServer
|
lateinit var webServer: MockWebServer
|
||||||
private set
|
private set
|
||||||
lateinit var baseUrl: String
|
lateinit var baseUrl: String
|
||||||
private set
|
private set
|
||||||
|
|
||||||
|
val mockIdentifiedWebSocket = MockWebSocket()
|
||||||
|
|
||||||
private val handlers: MutableList<Verb> = mutableListOf()
|
private val handlers: MutableList<Verb> = mutableListOf()
|
||||||
|
|
||||||
fun addMockWebRequestHandlers(vararg verbs: Verb) {
|
fun addMockWebRequestHandlers(vararg verbs: Verb) {
|
||||||
handlers.addAll(verbs)
|
handlers.addAll(verbs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fun injectWebSocketMessage(value: ByteString) {
|
||||||
|
mockIdentifiedWebSocket.webSocket!!.send(value)
|
||||||
|
}
|
||||||
|
|
||||||
fun clearHandlers() {
|
fun clearHandlers() {
|
||||||
handlers.clear()
|
handlers.clear()
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,6 +4,8 @@ import androidx.test.ext.junit.runners.AndroidJUnit4
|
||||||
import io.mockk.every
|
import io.mockk.every
|
||||||
import io.mockk.mockkStatic
|
import io.mockk.mockkStatic
|
||||||
import io.mockk.unmockkStatic
|
import io.mockk.unmockkStatic
|
||||||
|
import okio.ByteString
|
||||||
|
import okio.ByteString.Companion.toByteString
|
||||||
import org.junit.After
|
import org.junit.After
|
||||||
import org.junit.Before
|
import org.junit.Before
|
||||||
import org.junit.Ignore
|
import org.junit.Ignore
|
||||||
|
@ -15,6 +17,7 @@ import org.signal.libsignal.protocol.ecc.Curve
|
||||||
import org.signal.libsignal.protocol.ecc.ECKeyPair
|
import org.signal.libsignal.protocol.ecc.ECKeyPair
|
||||||
import org.signal.libsignal.zkgroup.profiles.ProfileKey
|
import org.signal.libsignal.zkgroup.profiles.ProfileKey
|
||||||
import org.thoughtcrime.securesms.crypto.UnidentifiedAccessUtil
|
import org.thoughtcrime.securesms.crypto.UnidentifiedAccessUtil
|
||||||
|
import org.thoughtcrime.securesms.dependencies.InstrumentationApplicationDependencyProvider
|
||||||
import org.thoughtcrime.securesms.recipients.Recipient
|
import org.thoughtcrime.securesms.recipients.Recipient
|
||||||
import org.thoughtcrime.securesms.testing.AliceClient
|
import org.thoughtcrime.securesms.testing.AliceClient
|
||||||
import org.thoughtcrime.securesms.testing.BobClient
|
import org.thoughtcrime.securesms.testing.BobClient
|
||||||
|
@ -23,6 +26,10 @@ import org.thoughtcrime.securesms.testing.FakeClientHelpers
|
||||||
import org.thoughtcrime.securesms.testing.SignalActivityRule
|
import org.thoughtcrime.securesms.testing.SignalActivityRule
|
||||||
import org.thoughtcrime.securesms.testing.awaitFor
|
import org.thoughtcrime.securesms.testing.awaitFor
|
||||||
import org.whispersystems.signalservice.internal.push.SignalServiceProtos.Envelope
|
import org.whispersystems.signalservice.internal.push.SignalServiceProtos.Envelope
|
||||||
|
import org.whispersystems.signalservice.internal.websocket.WebSocketProtos.WebSocketMessage
|
||||||
|
import org.whispersystems.signalservice.internal.websocket.WebSocketProtos.WebSocketRequestMessage
|
||||||
|
import java.util.regex.Pattern
|
||||||
|
import kotlin.random.Random
|
||||||
import kotlin.time.Duration.Companion.minutes
|
import kotlin.time.Duration.Companion.minutes
|
||||||
import kotlin.time.Duration.Companion.seconds
|
import kotlin.time.Duration.Companion.seconds
|
||||||
import android.util.Log as AndroidLog
|
import android.util.Log as AndroidLog
|
||||||
|
@ -37,6 +44,8 @@ class MessageProcessingPerformanceTest {
|
||||||
companion object {
|
companion object {
|
||||||
private val TAG = Log.tag(MessageProcessingPerformanceTest::class.java)
|
private val TAG = Log.tag(MessageProcessingPerformanceTest::class.java)
|
||||||
private val TIMING_TAG = "TIMING_$TAG".substring(0..23)
|
private val TIMING_TAG = "TIMING_$TAG".substring(0..23)
|
||||||
|
|
||||||
|
private val DECRYPTION_TIME_PATTERN = Pattern.compile("^Decrypted (?<count>\\d+) envelopes in (?<duration>\\d+) ms.*$")
|
||||||
}
|
}
|
||||||
|
|
||||||
@get:Rule
|
@get:Rule
|
||||||
|
@ -76,64 +85,43 @@ class MessageProcessingPerformanceTest {
|
||||||
profileKey = ProfileKey(bob.profileKey)
|
profileKey = ProfileKey(bob.profileKey)
|
||||||
)
|
)
|
||||||
|
|
||||||
// Send message from Bob to Alice (self)
|
// Send the initial messages to get past the prekey phase
|
||||||
|
establishSession(aliceClient, bobClient, bob)
|
||||||
val firstPreKeyMessageTimestamp = System.currentTimeMillis()
|
|
||||||
val encryptedEnvelope = bobClient.encrypt(firstPreKeyMessageTimestamp)
|
|
||||||
|
|
||||||
val aliceProcessFirstMessageLatch = harness
|
|
||||||
.inMemoryLogger
|
|
||||||
.getLockForUntil(TimingMessageContentProcessor.endTagPredicate(firstPreKeyMessageTimestamp))
|
|
||||||
|
|
||||||
Thread { aliceClient.process(encryptedEnvelope, System.currentTimeMillis()) }.start()
|
|
||||||
aliceProcessFirstMessageLatch.awaitFor(15.seconds)
|
|
||||||
|
|
||||||
// Send message from Alice to Bob
|
|
||||||
val aliceNow = System.currentTimeMillis()
|
|
||||||
bobClient.decrypt(aliceClient.encrypt(aliceNow, bob), aliceNow)
|
|
||||||
|
|
||||||
// Build N messages from Bob to Alice
|
|
||||||
|
|
||||||
|
// Have Bob generate N messages that will be received by Alice
|
||||||
val messageCount = 100
|
val messageCount = 100
|
||||||
val envelopes = ArrayList<Envelope>(messageCount)
|
val envelopes = generateInboundEnvelopes(bobClient, messageCount)
|
||||||
var now = System.currentTimeMillis()
|
|
||||||
for (i in 0..messageCount) {
|
|
||||||
envelopes += bobClient.encrypt(now)
|
|
||||||
now += 3
|
|
||||||
}
|
|
||||||
|
|
||||||
val firstTimestamp = envelopes.first().timestamp
|
val firstTimestamp = envelopes.first().timestamp
|
||||||
val lastTimestamp = envelopes.last().timestamp
|
val lastTimestamp = envelopes.last().timestamp
|
||||||
|
|
||||||
// Alice processes N messages
|
// Inject the envelopes into the websocket
|
||||||
|
|
||||||
val aliceProcessLastMessageLatch = harness
|
|
||||||
.inMemoryLogger
|
|
||||||
.getLockForUntil(TimingMessageContentProcessor.endTagPredicate(lastTimestamp))
|
|
||||||
|
|
||||||
Thread {
|
Thread {
|
||||||
for (envelope in envelopes) {
|
for (envelope in envelopes) {
|
||||||
Log.i(TIMING_TAG, "Retrieved envelope! ${envelope.timestamp}")
|
Log.i(TIMING_TAG, "Retrieved envelope! ${envelope.timestamp}")
|
||||||
aliceClient.process(envelope, envelope.timestamp)
|
InstrumentationApplicationDependencyProvider.injectWebSocketMessage(envelope.toWebSocketPayload())
|
||||||
}
|
}
|
||||||
|
InstrumentationApplicationDependencyProvider.injectWebSocketMessage(webSocketTombstone())
|
||||||
}.start()
|
}.start()
|
||||||
|
|
||||||
// Wait for Alice to finish processing messages
|
// Wait until they've all been fully decrypted + processed
|
||||||
aliceProcessLastMessageLatch.awaitFor(1.minutes)
|
harness
|
||||||
|
.inMemoryLogger
|
||||||
|
.getLockForUntil(TimingMessageContentProcessor.endTagPredicate(lastTimestamp))
|
||||||
|
.awaitFor(1.minutes)
|
||||||
|
|
||||||
harness.inMemoryLogger.flush()
|
harness.inMemoryLogger.flush()
|
||||||
|
|
||||||
// Process logs for timing data
|
// Process logs for timing data
|
||||||
val entries = harness.inMemoryLogger.entries()
|
val entries = harness.inMemoryLogger.entries()
|
||||||
|
|
||||||
// Calculate decryption average
|
// Calculate decryption average
|
||||||
|
val totalDecryptDuration: Long = entries
|
||||||
|
.mapNotNull { entry -> entry.message?.let { DECRYPTION_TIME_PATTERN.matcher(it) } }
|
||||||
|
.filter { it.matches() }
|
||||||
|
.drop(1) // Ignore the first message, which represents the prekey exchange
|
||||||
|
.sumOf { it.group("duration")!!.toLong() }
|
||||||
|
|
||||||
val decrypts = entries
|
AndroidLog.w(TAG, "Decryption: Average runtime: ${totalDecryptDuration.toFloat() / messageCount.toFloat()}ms")
|
||||||
.filter { it.tag == AliceClient.TAG }
|
|
||||||
.drop(1)
|
|
||||||
|
|
||||||
val totalDecryptDuration = decrypts.sumOf { it.message!!.toLong() }
|
|
||||||
|
|
||||||
AndroidLog.w(TAG, "Decryption: Average runtime: ${totalDecryptDuration.toFloat() / decrypts.size.toFloat()}ms")
|
|
||||||
|
|
||||||
// Calculate MessageContentProcessor
|
// Calculate MessageContentProcessor
|
||||||
|
|
||||||
|
@ -160,4 +148,62 @@ class MessageProcessingPerformanceTest {
|
||||||
|
|
||||||
AndroidLog.w(TAG, "Processing $messageCount messages took ${duration}s or ${messagePerSecond}m/s")
|
AndroidLog.w(TAG, "Processing $messageCount messages took ${duration}s or ${messagePerSecond}m/s")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private fun establishSession(aliceClient: AliceClient, bobClient: BobClient, bob: Recipient) {
|
||||||
|
// Send message from Bob to Alice (self)
|
||||||
|
val firstPreKeyMessageTimestamp = System.currentTimeMillis()
|
||||||
|
val encryptedEnvelope = bobClient.encrypt(firstPreKeyMessageTimestamp)
|
||||||
|
|
||||||
|
val aliceProcessFirstMessageLatch = harness
|
||||||
|
.inMemoryLogger
|
||||||
|
.getLockForUntil(TimingMessageContentProcessor.endTagPredicate(firstPreKeyMessageTimestamp))
|
||||||
|
|
||||||
|
Thread { aliceClient.process(encryptedEnvelope, System.currentTimeMillis()) }.start()
|
||||||
|
aliceProcessFirstMessageLatch.awaitFor(15.seconds)
|
||||||
|
|
||||||
|
// Send message from Alice to Bob
|
||||||
|
val aliceNow = System.currentTimeMillis()
|
||||||
|
bobClient.decrypt(aliceClient.encrypt(aliceNow, bob), aliceNow)
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun generateInboundEnvelopes(bobClient: BobClient, count: Int): List<Envelope> {
|
||||||
|
val envelopes = ArrayList<Envelope>(count)
|
||||||
|
var now = System.currentTimeMillis()
|
||||||
|
for (i in 0..count) {
|
||||||
|
envelopes += bobClient.encrypt(now)
|
||||||
|
now += 3
|
||||||
|
}
|
||||||
|
|
||||||
|
return envelopes
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun webSocketTombstone(): ByteString {
|
||||||
|
return WebSocketMessage
|
||||||
|
.newBuilder()
|
||||||
|
.setRequest(
|
||||||
|
WebSocketRequestMessage.newBuilder()
|
||||||
|
.setVerb("PUT")
|
||||||
|
.setPath("/api/v1/queue/empty")
|
||||||
|
)
|
||||||
|
.build()
|
||||||
|
.toByteArray()
|
||||||
|
.toByteString()
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun Envelope.toWebSocketPayload(): ByteString {
|
||||||
|
return WebSocketMessage
|
||||||
|
.newBuilder()
|
||||||
|
.setType(WebSocketMessage.Type.REQUEST)
|
||||||
|
.setRequest(
|
||||||
|
WebSocketRequestMessage.newBuilder()
|
||||||
|
.setVerb("PUT")
|
||||||
|
.setPath("/api/v1/message")
|
||||||
|
.setId(Random(System.currentTimeMillis()).nextLong())
|
||||||
|
.addHeaders("X-Signal-Timestamp: ${this.timestamp}")
|
||||||
|
.setBody(this.toByteString())
|
||||||
|
)
|
||||||
|
.build()
|
||||||
|
.toByteArray()
|
||||||
|
.toByteString()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,17 +7,20 @@ import org.thoughtcrime.securesms.util.JsonUtils
|
||||||
import java.util.concurrent.TimeUnit
|
import java.util.concurrent.TimeUnit
|
||||||
|
|
||||||
typealias ResponseFactory = (request: RecordedRequest) -> MockResponse
|
typealias ResponseFactory = (request: RecordedRequest) -> MockResponse
|
||||||
|
typealias RequestPredicate = (request: RecordedRequest) -> Boolean
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Represent an HTTP verb for mocking web requests.
|
* Represent an HTTP verb for mocking web requests.
|
||||||
*/
|
*/
|
||||||
sealed class Verb(val verb: String, val path: String, val responseFactory: ResponseFactory)
|
sealed class Verb(val requestPredicate: RequestPredicate, val responseFactory: ResponseFactory)
|
||||||
|
|
||||||
class Get(path: String, responseFactory: ResponseFactory) : Verb("GET", path, responseFactory)
|
class Get(path: String, predicate: RequestPredicate, responseFactory: ResponseFactory) : Verb(defaultRequestPredicate("GET", path, predicate), responseFactory) {
|
||||||
|
constructor(path: String, responseFactory: ResponseFactory) : this(path, { true }, responseFactory)
|
||||||
|
}
|
||||||
|
|
||||||
class Put(path: String, responseFactory: ResponseFactory) : Verb("PUT", path, responseFactory)
|
class Put(path: String, responseFactory: ResponseFactory) : Verb(defaultRequestPredicate("PUT", path), responseFactory)
|
||||||
|
|
||||||
class Post(path: String, responseFactory: ResponseFactory) : Verb("POST", path, responseFactory)
|
class Post(path: String, responseFactory: ResponseFactory) : Verb(defaultRequestPredicate("POST", path), responseFactory)
|
||||||
|
|
||||||
fun MockResponse.success(response: Any? = null): MockResponse {
|
fun MockResponse.success(response: Any? = null): MockResponse {
|
||||||
return setResponseCode(200).apply {
|
return setResponseCode(200).apply {
|
||||||
|
@ -48,3 +51,7 @@ inline fun <reified T> RecordedRequest.parsedRequestBody(): T {
|
||||||
val bodyString = String(body.readByteArray())
|
val bodyString = String(body.readByteArray())
|
||||||
return JsonUtils.fromJson(bodyString, T::class.java)
|
return JsonUtils.fromJson(bodyString, T::class.java)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private fun defaultRequestPredicate(verb: String, path: String, predicate: RequestPredicate = { true }): RequestPredicate = { request ->
|
||||||
|
request.method == verb && request.path.startsWith("/$path") && predicate(request)
|
||||||
|
}
|
||||||
|
|
|
@ -221,11 +221,12 @@ class IncomingMessageObserver(private val context: Application) {
|
||||||
}
|
}
|
||||||
|
|
||||||
fun terminateAsync() {
|
fun terminateAsync() {
|
||||||
|
Log.w(TAG, "Termination Enqueued! ${this.hashCode()}", Throwable())
|
||||||
INSTANCE_COUNT.decrementAndGet()
|
INSTANCE_COUNT.decrementAndGet()
|
||||||
context.unregisterReceiver(connectionReceiver)
|
context.unregisterReceiver(connectionReceiver)
|
||||||
|
|
||||||
SignalExecutors.BOUNDED.execute {
|
SignalExecutors.BOUNDED.execute {
|
||||||
Log.w(TAG, "Beginning termination.")
|
Log.w(TAG, "Beginning termination. ${this.hashCode()}")
|
||||||
terminated = true
|
terminated = true
|
||||||
disconnect()
|
disconnect()
|
||||||
}
|
}
|
||||||
|
@ -371,7 +372,7 @@ class IncomingMessageObserver(private val context: Application) {
|
||||||
private inner class MessageRetrievalThread : Thread("MessageRetrievalService"), Thread.UncaughtExceptionHandler {
|
private inner class MessageRetrievalThread : Thread("MessageRetrievalService"), Thread.UncaughtExceptionHandler {
|
||||||
|
|
||||||
init {
|
init {
|
||||||
Log.i(TAG, "Initializing! (" + this.hashCode() + ")")
|
Log.i(TAG, "Initializing! (${this.hashCode()})")
|
||||||
uncaughtExceptionHandler = this
|
uncaughtExceptionHandler = this
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -461,7 +462,7 @@ class IncomingMessageObserver(private val context: Application) {
|
||||||
}
|
}
|
||||||
Log.i(TAG, "Looping...")
|
Log.i(TAG, "Looping...")
|
||||||
}
|
}
|
||||||
Log.w(TAG, "Terminated! (" + this.hashCode() + ")")
|
Log.w(TAG, "Terminated! (${this.hashCode()})")
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun uncaughtException(t: Thread, e: Throwable) {
|
override fun uncaughtException(t: Thread, e: Throwable) {
|
||||||
|
|
Loading…
Add table
Reference in a new issue