Make LibSignalChatConnection Only Use Each ChatService Once

This commit is contained in:
andrew-signal 2024-11-20 12:05:11 -05:00 committed by Greyson Parrelli
parent 040d05a0a6
commit 1401256ffd
3 changed files with 165 additions and 84 deletions

View file

@ -434,7 +434,9 @@ public class ApplicationDependencyProvider implements AppDependencies.Provider {
Network network = libSignalNetworkSupplier.get();
return new LibSignalChatConnection(
"libsignal-unauth",
LibSignalNetworkExtensions.createChatService(network, null, Stories.isFeatureEnabled()),
network,
null,
Stories.isFeatureEnabled(),
healthMonitor);
} else {
return new OkHttpWebSocketConnection("unidentified",

View file

@ -13,12 +13,17 @@ import io.reactivex.rxjava3.subjects.SingleSubject
import org.signal.core.util.logging.Log
import org.signal.libsignal.net.AuthenticatedChatService
import org.signal.libsignal.net.ChatService
import org.signal.libsignal.net.Network
import org.signal.libsignal.net.UnauthenticatedChatService
import org.whispersystems.signalservice.api.util.CredentialsProvider
import org.whispersystems.signalservice.api.websocket.HealthMonitor
import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState
import org.whispersystems.signalservice.internal.util.whenComplete
import java.io.IOException
import java.time.Instant
import java.util.Optional
import java.util.concurrent.locks.ReentrantLock
import kotlin.concurrent.withLock
import kotlin.time.Duration.Companion.seconds
import org.signal.libsignal.net.ChatService.Request as LibSignalRequest
import org.signal.libsignal.net.ChatService.Response as LibSignalResponse
@ -38,10 +43,15 @@ import org.signal.libsignal.net.ChatService.Response as LibSignalResponse
*/
class LibSignalChatConnection(
name: String,
private val chatService: ChatService,
private val network: Network,
private val credentialsProvider: CredentialsProvider?,
private val receiveStories: Boolean,
private val healthMonitor: HealthMonitor
) : WebSocketConnection {
private val CHAT_SERVICE_LOCK = ReentrantLock()
private var chatService: ChatService? = null
companion object {
private val TAG = Log.tag(LibSignalChatConnection::class.java)
private val SEND_TIMEOUT: Long = 10.seconds.inWholeMilliseconds
@ -85,95 +95,118 @@ class LibSignalChatConnection(
val state = BehaviorSubject.createDefault(WebSocketConnectionState.DISCONNECTED)
override fun connect(): Observable<WebSocketConnectionState> {
Log.i(TAG, "$name Connecting...")
state.onNext(WebSocketConnectionState.CONNECTING)
chatService.connect()
.whenComplete(
onSuccess = { debugInfo ->
Log.i(TAG, "$name Connected")
Log.d(TAG, "$name $debugInfo")
state.onNext(WebSocketConnectionState.CONNECTED)
},
onFailure = { throwable ->
// TODO: [libsignal-net] Report WebSocketConnectionState.AUTHENTICATION_FAILED for 401 and 403 errors
Log.d(TAG, "$name Connect failed", throwable)
state.onNext(WebSocketConnectionState.FAILED)
}
)
return state
CHAT_SERVICE_LOCK.withLock {
if (chatService != null) {
return state
}
Log.i(TAG, "$name Connecting...")
chatService = network.createChatService(credentialsProvider, receiveStories).apply {
state.onNext(WebSocketConnectionState.CONNECTING)
connect().whenComplete(
onSuccess = { debugInfo ->
Log.i(TAG, "$name Connected")
Log.d(TAG, "$name $debugInfo")
state.onNext(WebSocketConnectionState.CONNECTED)
},
onFailure = { throwable ->
// TODO[libsignal-net]: Report AUTHENTICATION_FAILED for 401 and 403 errors
Log.w(TAG, "$name Connect failed", throwable)
state.onNext(WebSocketConnectionState.FAILED)
}
)
}
return state
}
}
override fun isDead(): Boolean = false
override fun disconnect() {
Log.i(TAG, "$name Disconnecting...")
state.onNext(WebSocketConnectionState.DISCONNECTING)
chatService.disconnect()
.whenComplete(
onSuccess = {
Log.i(TAG, "$name Disconnected")
state.onNext(WebSocketConnectionState.DISCONNECTED)
},
onFailure = { throwable ->
Log.d(TAG, "$name Disconnect failed", throwable)
state.onNext(WebSocketConnectionState.DISCONNECTED)
}
)
CHAT_SERVICE_LOCK.withLock {
if (chatService == null) {
return
}
Log.i(TAG, "$name Disconnecting...")
state.onNext(WebSocketConnectionState.DISCONNECTING)
chatService!!.disconnect()
.whenComplete(
onSuccess = {
Log.i(TAG, "$name Disconnected")
state.onNext(WebSocketConnectionState.DISCONNECTED)
},
onFailure = { throwable ->
Log.w(TAG, "$name Disconnect failed", throwable)
state.onNext(WebSocketConnectionState.DISCONNECTED)
}
)
chatService = null
}
}
override fun sendRequest(request: WebSocketRequestMessage): Single<WebsocketResponse> {
val single = SingleSubject.create<WebsocketResponse>()
val internalRequest = request.toLibSignalRequest()
chatService.send(internalRequest)
.whenComplete(
onSuccess = { response ->
when (response!!.status) {
in 400..599 -> {
healthMonitor.onMessageError(response.status, false)
CHAT_SERVICE_LOCK.withLock {
if (chatService == null) {
return Single.error(IOException("[$name] is closed!"))
}
val single = SingleSubject.create<WebsocketResponse>()
val internalRequest = request.toLibSignalRequest()
chatService!!.send(internalRequest)
.whenComplete(
onSuccess = { response ->
when (response!!.status) {
in 400..599 -> {
healthMonitor.onMessageError(response.status, false)
}
}
// Here success means "we received the response" even if it is reporting an error.
// This is consistent with the behavior of the OkHttpWebSocketConnection.
single.onSuccess(response.toWebsocketResponse(isUnidentified = (chatService is UnauthenticatedChatService)))
},
onFailure = { throwable ->
Log.w(TAG, "$name sendRequest failed", throwable)
single.onError(throwable)
}
// Here success means "we received the response" even if it is reporting an error.
// This is consistent with the behavior of the OkHttpWebSocketConnection.
single.onSuccess(response.toWebsocketResponse(isUnidentified = (chatService is UnauthenticatedChatService)))
},
onFailure = { throwable ->
Log.i(TAG, "$name sendRequest failed", throwable)
single.onError(throwable)
}
)
return single.subscribeOn(Schedulers.io()).observeOn(Schedulers.io())
)
return single.subscribeOn(Schedulers.io()).observeOn(Schedulers.io())
}
}
override fun sendKeepAlive() {
Log.i(TAG, "$name Sending keep alive...")
chatService.sendAndDebug(KEEP_ALIVE_REQUEST)
.whenComplete(
onSuccess = { debugResponse ->
Log.i(TAG, "$name Keep alive - success")
Log.d(TAG, "$name $debugResponse")
when (debugResponse!!.response.status) {
in 200..299 -> {
healthMonitor.onKeepAliveResponse(
Instant.now().toEpochMilli(), // ignored. can be any value
false
)
}
CHAT_SERVICE_LOCK.withLock {
if (chatService == null) {
return
}
in 400..599 -> {
healthMonitor.onMessageError(debugResponse.response.status, (chatService is AuthenticatedChatService))
}
Log.i(TAG, "$name Sending keep alive...")
chatService!!.sendAndDebug(KEEP_ALIVE_REQUEST)
.whenComplete(
onSuccess = { debugResponse ->
Log.d(TAG, "$name Keep alive - success")
when (debugResponse!!.response.status) {
in 200..299 -> {
healthMonitor.onKeepAliveResponse(
Instant.now().toEpochMilli(), // ignored. can be any value
false
)
}
else -> {
Log.w(TAG, "$name Unsupported keep alive response status: ${debugResponse.response.status}")
in 400..599 -> {
healthMonitor.onMessageError(debugResponse.response.status, (chatService is AuthenticatedChatService))
}
else -> {
Log.w(TAG, "$name Unsupported keep alive response status: ${debugResponse.response.status}")
}
}
},
onFailure = { throwable ->
Log.w(TAG, "$name Keep alive - failed", throwable)
state.onNext(WebSocketConnectionState.DISCONNECTED)
}
},
onFailure = { throwable ->
Log.i(TAG, "$name Keep alive - failed")
Log.d(TAG, "$name $throwable")
state.onNext(WebSocketConnectionState.DISCONNECTED)
}
)
)
}
}
override fun readRequestIfAvailable(): Optional<WebSocketRequestMessage> {

View file

@ -3,6 +3,7 @@ package org.whispersystems.signalservice.internal.websocket
import io.mockk.clearAllMocks
import io.mockk.every
import io.mockk.mockk
import io.mockk.mockkStatic
import io.mockk.verify
import io.reactivex.rxjava3.observers.TestObserver
import org.junit.Before
@ -11,6 +12,7 @@ import org.signal.libsignal.internal.CompletableFuture
import org.signal.libsignal.net.ChatService
import org.signal.libsignal.net.ChatService.DebugInfo
import org.signal.libsignal.net.IpType
import org.signal.libsignal.net.Network
import org.whispersystems.signalservice.api.websocket.HealthMonitor
import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState
import java.util.concurrent.CountDownLatch
@ -25,13 +27,16 @@ class LibSignalChatConnectionTest {
private val executor: ExecutorService = Executors.newSingleThreadExecutor()
private val healthMonitor = mockk<HealthMonitor>()
private val chatService = mockk<ChatService>()
private val connection = LibSignalChatConnection("test", chatService, healthMonitor)
private val network = mockk<Network>()
private val connection = LibSignalChatConnection("test", network, null, false, healthMonitor)
@Before
fun before() {
clearAllMocks()
mockkStatic(Network::createChatService)
every { healthMonitor.onMessageError(any(), any()) }
every { healthMonitor.onKeepAliveResponse(any(), any()) }
every { network.createChatService(any(), any()) } answers { chatService }
}
@Test
@ -127,25 +132,37 @@ class LibSignalChatConnectionTest {
fun orderOfStatesOnDisconnectFailure() {
val disconnectException = RuntimeException("disconnect failed")
val latch = CountDownLatch(1)
val connectLatch = CountDownLatch(1)
val disconnectLatch = CountDownLatch(1)
every { chatService.disconnect() } answers {
delay {
it.completeExceptionally(disconnectException)
disconnectLatch.countDown()
}
}
val observer = TestObserver<WebSocketConnectionState>()
every { chatService.connect() } answers {
delay {
it.complete(DEBUG_INFO)
connectLatch.countDown()
}
}
connection.connect()
connectLatch.await(100, TimeUnit.MILLISECONDS)
val observer = TestObserver<WebSocketConnectionState>()
connection.state.subscribe(observer)
connection.disconnect()
latch.await(100, TimeUnit.MILLISECONDS)
disconnectLatch.await(100, TimeUnit.MILLISECONDS)
observer.assertNotComplete()
observer.assertValues(
WebSocketConnectionState.DISCONNECTED,
WebSocketConnectionState.CONNECTED,
WebSocketConnectionState.DISCONNECTING,
WebSocketConnectionState.DISCONNECTED
)
@ -162,6 +179,14 @@ class LibSignalChatConnectionTest {
}
}
every { chatService.connect() } answers {
delay {
it.complete(DEBUG_INFO)
}
}
connection.connect()
connection.sendKeepAlive()
latch.await(100, TimeUnit.MILLISECONDS)
@ -185,6 +210,14 @@ class LibSignalChatConnectionTest {
}
}
every { chatService.connect() } answers {
delay {
it.complete(DEBUG_INFO)
}
}
connection.connect()
connection.sendKeepAlive()
latch.await(100, TimeUnit.MILLISECONDS)
@ -200,28 +233,41 @@ class LibSignalChatConnectionTest {
@Test
fun keepAliveConnectionFailure() {
val connectionFailure = RuntimeException("Sending keep-alive failed")
val latch = CountDownLatch(1)
val connectLatch = CountDownLatch(1)
val keepAliveFailureLatch = CountDownLatch(1)
every {
chatService.sendAndDebug(any())
} answers {
delay {
it.completeExceptionally(connectionFailure)
keepAliveFailureLatch.countDown()
}
}
every { chatService.connect() } answers {
delay {
it.complete(DEBUG_INFO)
connectLatch.countDown()
}
}
connection.connect()
connectLatch.await(100, TimeUnit.MILLISECONDS)
val observer = TestObserver<WebSocketConnectionState>()
connection.state.subscribe(observer)
connection.sendKeepAlive()
latch.await(100, TimeUnit.MILLISECONDS)
keepAliveFailureLatch.await(100, TimeUnit.MILLISECONDS)
observer.assertNotComplete()
observer.assertValues(
// This is the starting state
WebSocketConnectionState.DISCONNECTED,
// This one is the result of a keep-alive failure
// We start in the connected state
WebSocketConnectionState.CONNECTED,
// Disconnects as a result of keep-alive failure
WebSocketConnectionState.DISCONNECTED
)
verify(exactly = 0) {