Use ChatListener to get connection interrupted event from libsignal; clear connection state when received

This commit is contained in:
andrew-signal 2024-11-23 18:15:14 -05:00 committed by GitHub
parent 0356b01866
commit 9833101cd1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 67 additions and 6 deletions

View file

@ -427,7 +427,7 @@ public class ApplicationDependencyProvider implements AppDependencies.Provider {
BuildConfig.SIGNAL_AGENT,
healthMonitor,
Stories.isFeatureEnabled(),
LibSignalNetworkExtensions.createChatService(libSignalNetworkSupplier.get(), null, Stories.isFeatureEnabled()),
LibSignalNetworkExtensions.createChatService(libSignalNetworkSupplier.get(), null, Stories.isFeatureEnabled(), null),
shadowPercentage,
bridge
);

View file

@ -12,7 +12,9 @@ import io.reactivex.rxjava3.subjects.BehaviorSubject
import io.reactivex.rxjava3.subjects.SingleSubject
import org.signal.core.util.logging.Log
import org.signal.libsignal.net.AuthenticatedChatService
import org.signal.libsignal.net.ChatListener
import org.signal.libsignal.net.ChatService
import org.signal.libsignal.net.ChatServiceException
import org.signal.libsignal.net.Network
import org.signal.libsignal.net.UnauthenticatedChatService
import org.whispersystems.signalservice.api.util.CredentialsProvider
@ -101,7 +103,7 @@ class LibSignalChatConnection(
}
Log.i(TAG, "$name Connecting...")
chatService = network.createChatService(credentialsProvider, receiveStories).apply {
chatService = network.createChatService(credentialsProvider, receiveStories, listener).apply {
state.onNext(WebSocketConnectionState.CONNECTING)
connect().whenComplete(
onSuccess = { debugInfo ->
@ -220,4 +222,18 @@ class LibSignalChatConnection(
override fun sendResponse(response: WebSocketResponseMessage?) {
throw NotImplementedError()
}
private val listener = object : ChatListener {
override fun onIncomingMessage(chat: ChatService?, envelope: ByteArray?, serverDeliveryTimestamp: Long, sendAck: ChatListener.ServerMessageAck?) {
throw NotImplementedError()
}
override fun onConnectionInterrupted(chat: ChatService?, disconnectReason: ChatServiceException?) {
CHAT_SERVICE_LOCK.withLock {
Log.i(TAG, "connection interrupted", disconnectReason)
state.onNext(WebSocketConnectionState.DISCONNECTED)
chatService = null
}
}
}
}

View file

@ -7,6 +7,7 @@
package org.whispersystems.signalservice.internal.websocket
import org.signal.core.util.orNull
import org.signal.libsignal.net.ChatListener
import org.signal.libsignal.net.ChatService
import org.signal.libsignal.net.Network
import org.whispersystems.signalservice.api.util.CredentialsProvider
@ -17,14 +18,15 @@ import org.whispersystems.signalservice.internal.configuration.SignalServiceConf
*/
fun Network.createChatService(
credentialsProvider: CredentialsProvider? = null,
receiveStories: Boolean
receiveStories: Boolean,
listener: ChatListener? = null
): ChatService {
val username = credentialsProvider?.username ?: ""
val password = credentialsProvider?.password ?: ""
return if (username.isEmpty() && password.isEmpty()) {
this.createUnauthChatService(null)
this.createUnauthChatService(listener)
} else {
this.createAuthChatService(username, password, receiveStories, null)
this.createAuthChatService(username, password, receiveStories, listener)
}
}

View file

@ -9,8 +9,10 @@ import io.reactivex.rxjava3.observers.TestObserver
import org.junit.Before
import org.junit.Test
import org.signal.libsignal.internal.CompletableFuture
import org.signal.libsignal.net.ChatListener
import org.signal.libsignal.net.ChatService
import org.signal.libsignal.net.ChatService.DebugInfo
import org.signal.libsignal.net.ChatServiceException
import org.signal.libsignal.net.IpType
import org.signal.libsignal.net.Network
import org.whispersystems.signalservice.api.websocket.HealthMonitor
@ -29,6 +31,7 @@ class LibSignalChatConnectionTest {
private val chatService = mockk<ChatService>()
private val network = mockk<Network>()
private val connection = LibSignalChatConnection("test", network, null, false, healthMonitor)
private var chatListener: ChatListener? = null
@Before
fun before() {
@ -36,7 +39,14 @@ class LibSignalChatConnectionTest {
mockkStatic(Network::createChatService)
every { healthMonitor.onMessageError(any(), any()) }
every { healthMonitor.onKeepAliveResponse(any(), any()) }
every { network.createChatService(any(), any()) } answers { chatService }
every { network.createChatService(any(), any(), any()) } answers {
// When mocking static methods in mockk, the mock target is included as the first
// argument in the answers block. This results in the thirdArgument<T>() convenience method
// being off-by-one. Since we are interested in the last argument to createChatService, we need
// to manually fetch it from the args array and cast it ourselves.
chatListener = args[3] as ChatListener?
chatService
}
}
@Test
@ -276,6 +286,39 @@ class LibSignalChatConnectionTest {
}
}
@Test
fun connectionInterrupted() {
val disconnectReason = ChatServiceException("simulated interrupt")
val connectLatch = CountDownLatch(1)
every { chatService.connect() } answers {
delay {
it.complete(DEBUG_INFO)
connectLatch.countDown()
}
}
connection.connect()
connectLatch.await(100, TimeUnit.MILLISECONDS)
val observer = TestObserver<WebSocketConnectionState>()
connection.state.subscribe(observer)
chatListener!!.onConnectionInterrupted(chatService, disconnectReason)
observer.assertNotComplete()
observer.assertValues(
// We start in the connected state
WebSocketConnectionState.CONNECTED,
// Disconnects as a result of the connection interrupted event
WebSocketConnectionState.DISCONNECTED
)
verify(exactly = 0) {
healthMonitor.onKeepAliveResponse(any(), any())
healthMonitor.onMessageError(any(), any())
}
}
private fun <T> delay(action: ((CompletableFuture<T>) -> Unit)): CompletableFuture<T> {
val future = CompletableFuture<T>()
executor.submit {