Make LibSignalChatConnection Only Use Each ChatService Once
This commit is contained in:
parent
040d05a0a6
commit
1401256ffd
3 changed files with 165 additions and 84 deletions
|
@ -434,7 +434,9 @@ public class ApplicationDependencyProvider implements AppDependencies.Provider {
|
|||
Network network = libSignalNetworkSupplier.get();
|
||||
return new LibSignalChatConnection(
|
||||
"libsignal-unauth",
|
||||
LibSignalNetworkExtensions.createChatService(network, null, Stories.isFeatureEnabled()),
|
||||
network,
|
||||
null,
|
||||
Stories.isFeatureEnabled(),
|
||||
healthMonitor);
|
||||
} else {
|
||||
return new OkHttpWebSocketConnection("unidentified",
|
||||
|
|
|
@ -13,12 +13,17 @@ import io.reactivex.rxjava3.subjects.SingleSubject
|
|||
import org.signal.core.util.logging.Log
|
||||
import org.signal.libsignal.net.AuthenticatedChatService
|
||||
import org.signal.libsignal.net.ChatService
|
||||
import org.signal.libsignal.net.Network
|
||||
import org.signal.libsignal.net.UnauthenticatedChatService
|
||||
import org.whispersystems.signalservice.api.util.CredentialsProvider
|
||||
import org.whispersystems.signalservice.api.websocket.HealthMonitor
|
||||
import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState
|
||||
import org.whispersystems.signalservice.internal.util.whenComplete
|
||||
import java.io.IOException
|
||||
import java.time.Instant
|
||||
import java.util.Optional
|
||||
import java.util.concurrent.locks.ReentrantLock
|
||||
import kotlin.concurrent.withLock
|
||||
import kotlin.time.Duration.Companion.seconds
|
||||
import org.signal.libsignal.net.ChatService.Request as LibSignalRequest
|
||||
import org.signal.libsignal.net.ChatService.Response as LibSignalResponse
|
||||
|
@ -38,10 +43,15 @@ import org.signal.libsignal.net.ChatService.Response as LibSignalResponse
|
|||
*/
|
||||
class LibSignalChatConnection(
|
||||
name: String,
|
||||
private val chatService: ChatService,
|
||||
private val network: Network,
|
||||
private val credentialsProvider: CredentialsProvider?,
|
||||
private val receiveStories: Boolean,
|
||||
private val healthMonitor: HealthMonitor
|
||||
) : WebSocketConnection {
|
||||
|
||||
private val CHAT_SERVICE_LOCK = ReentrantLock()
|
||||
private var chatService: ChatService? = null
|
||||
|
||||
companion object {
|
||||
private val TAG = Log.tag(LibSignalChatConnection::class.java)
|
||||
private val SEND_TIMEOUT: Long = 10.seconds.inWholeMilliseconds
|
||||
|
@ -85,95 +95,118 @@ class LibSignalChatConnection(
|
|||
val state = BehaviorSubject.createDefault(WebSocketConnectionState.DISCONNECTED)
|
||||
|
||||
override fun connect(): Observable<WebSocketConnectionState> {
|
||||
Log.i(TAG, "$name Connecting...")
|
||||
state.onNext(WebSocketConnectionState.CONNECTING)
|
||||
chatService.connect()
|
||||
.whenComplete(
|
||||
onSuccess = { debugInfo ->
|
||||
Log.i(TAG, "$name Connected")
|
||||
Log.d(TAG, "$name $debugInfo")
|
||||
state.onNext(WebSocketConnectionState.CONNECTED)
|
||||
},
|
||||
onFailure = { throwable ->
|
||||
// TODO: [libsignal-net] Report WebSocketConnectionState.AUTHENTICATION_FAILED for 401 and 403 errors
|
||||
Log.d(TAG, "$name Connect failed", throwable)
|
||||
state.onNext(WebSocketConnectionState.FAILED)
|
||||
}
|
||||
)
|
||||
return state
|
||||
CHAT_SERVICE_LOCK.withLock {
|
||||
if (chatService != null) {
|
||||
return state
|
||||
}
|
||||
|
||||
Log.i(TAG, "$name Connecting...")
|
||||
chatService = network.createChatService(credentialsProvider, receiveStories).apply {
|
||||
state.onNext(WebSocketConnectionState.CONNECTING)
|
||||
connect().whenComplete(
|
||||
onSuccess = { debugInfo ->
|
||||
Log.i(TAG, "$name Connected")
|
||||
Log.d(TAG, "$name $debugInfo")
|
||||
state.onNext(WebSocketConnectionState.CONNECTED)
|
||||
},
|
||||
onFailure = { throwable ->
|
||||
// TODO[libsignal-net]: Report AUTHENTICATION_FAILED for 401 and 403 errors
|
||||
Log.w(TAG, "$name Connect failed", throwable)
|
||||
state.onNext(WebSocketConnectionState.FAILED)
|
||||
}
|
||||
)
|
||||
}
|
||||
return state
|
||||
}
|
||||
}
|
||||
|
||||
override fun isDead(): Boolean = false
|
||||
|
||||
override fun disconnect() {
|
||||
Log.i(TAG, "$name Disconnecting...")
|
||||
state.onNext(WebSocketConnectionState.DISCONNECTING)
|
||||
chatService.disconnect()
|
||||
.whenComplete(
|
||||
onSuccess = {
|
||||
Log.i(TAG, "$name Disconnected")
|
||||
state.onNext(WebSocketConnectionState.DISCONNECTED)
|
||||
},
|
||||
onFailure = { throwable ->
|
||||
Log.d(TAG, "$name Disconnect failed", throwable)
|
||||
state.onNext(WebSocketConnectionState.DISCONNECTED)
|
||||
}
|
||||
)
|
||||
CHAT_SERVICE_LOCK.withLock {
|
||||
if (chatService == null) {
|
||||
return
|
||||
}
|
||||
|
||||
Log.i(TAG, "$name Disconnecting...")
|
||||
state.onNext(WebSocketConnectionState.DISCONNECTING)
|
||||
chatService!!.disconnect()
|
||||
.whenComplete(
|
||||
onSuccess = {
|
||||
Log.i(TAG, "$name Disconnected")
|
||||
state.onNext(WebSocketConnectionState.DISCONNECTED)
|
||||
},
|
||||
onFailure = { throwable ->
|
||||
Log.w(TAG, "$name Disconnect failed", throwable)
|
||||
state.onNext(WebSocketConnectionState.DISCONNECTED)
|
||||
}
|
||||
)
|
||||
chatService = null
|
||||
}
|
||||
}
|
||||
|
||||
override fun sendRequest(request: WebSocketRequestMessage): Single<WebsocketResponse> {
|
||||
val single = SingleSubject.create<WebsocketResponse>()
|
||||
val internalRequest = request.toLibSignalRequest()
|
||||
chatService.send(internalRequest)
|
||||
.whenComplete(
|
||||
onSuccess = { response ->
|
||||
when (response!!.status) {
|
||||
in 400..599 -> {
|
||||
healthMonitor.onMessageError(response.status, false)
|
||||
CHAT_SERVICE_LOCK.withLock {
|
||||
if (chatService == null) {
|
||||
return Single.error(IOException("[$name] is closed!"))
|
||||
}
|
||||
val single = SingleSubject.create<WebsocketResponse>()
|
||||
val internalRequest = request.toLibSignalRequest()
|
||||
chatService!!.send(internalRequest)
|
||||
.whenComplete(
|
||||
onSuccess = { response ->
|
||||
when (response!!.status) {
|
||||
in 400..599 -> {
|
||||
healthMonitor.onMessageError(response.status, false)
|
||||
}
|
||||
}
|
||||
// Here success means "we received the response" even if it is reporting an error.
|
||||
// This is consistent with the behavior of the OkHttpWebSocketConnection.
|
||||
single.onSuccess(response.toWebsocketResponse(isUnidentified = (chatService is UnauthenticatedChatService)))
|
||||
},
|
||||
onFailure = { throwable ->
|
||||
Log.w(TAG, "$name sendRequest failed", throwable)
|
||||
single.onError(throwable)
|
||||
}
|
||||
// Here success means "we received the response" even if it is reporting an error.
|
||||
// This is consistent with the behavior of the OkHttpWebSocketConnection.
|
||||
single.onSuccess(response.toWebsocketResponse(isUnidentified = (chatService is UnauthenticatedChatService)))
|
||||
},
|
||||
onFailure = { throwable ->
|
||||
Log.i(TAG, "$name sendRequest failed", throwable)
|
||||
single.onError(throwable)
|
||||
}
|
||||
)
|
||||
return single.subscribeOn(Schedulers.io()).observeOn(Schedulers.io())
|
||||
)
|
||||
return single.subscribeOn(Schedulers.io()).observeOn(Schedulers.io())
|
||||
}
|
||||
}
|
||||
|
||||
override fun sendKeepAlive() {
|
||||
Log.i(TAG, "$name Sending keep alive...")
|
||||
chatService.sendAndDebug(KEEP_ALIVE_REQUEST)
|
||||
.whenComplete(
|
||||
onSuccess = { debugResponse ->
|
||||
Log.i(TAG, "$name Keep alive - success")
|
||||
Log.d(TAG, "$name $debugResponse")
|
||||
when (debugResponse!!.response.status) {
|
||||
in 200..299 -> {
|
||||
healthMonitor.onKeepAliveResponse(
|
||||
Instant.now().toEpochMilli(), // ignored. can be any value
|
||||
false
|
||||
)
|
||||
}
|
||||
CHAT_SERVICE_LOCK.withLock {
|
||||
if (chatService == null) {
|
||||
return
|
||||
}
|
||||
|
||||
in 400..599 -> {
|
||||
healthMonitor.onMessageError(debugResponse.response.status, (chatService is AuthenticatedChatService))
|
||||
}
|
||||
Log.i(TAG, "$name Sending keep alive...")
|
||||
chatService!!.sendAndDebug(KEEP_ALIVE_REQUEST)
|
||||
.whenComplete(
|
||||
onSuccess = { debugResponse ->
|
||||
Log.d(TAG, "$name Keep alive - success")
|
||||
when (debugResponse!!.response.status) {
|
||||
in 200..299 -> {
|
||||
healthMonitor.onKeepAliveResponse(
|
||||
Instant.now().toEpochMilli(), // ignored. can be any value
|
||||
false
|
||||
)
|
||||
}
|
||||
|
||||
else -> {
|
||||
Log.w(TAG, "$name Unsupported keep alive response status: ${debugResponse.response.status}")
|
||||
in 400..599 -> {
|
||||
healthMonitor.onMessageError(debugResponse.response.status, (chatService is AuthenticatedChatService))
|
||||
}
|
||||
|
||||
else -> {
|
||||
Log.w(TAG, "$name Unsupported keep alive response status: ${debugResponse.response.status}")
|
||||
}
|
||||
}
|
||||
},
|
||||
onFailure = { throwable ->
|
||||
Log.w(TAG, "$name Keep alive - failed", throwable)
|
||||
state.onNext(WebSocketConnectionState.DISCONNECTED)
|
||||
}
|
||||
},
|
||||
onFailure = { throwable ->
|
||||
Log.i(TAG, "$name Keep alive - failed")
|
||||
Log.d(TAG, "$name $throwable")
|
||||
state.onNext(WebSocketConnectionState.DISCONNECTED)
|
||||
}
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
override fun readRequestIfAvailable(): Optional<WebSocketRequestMessage> {
|
||||
|
|
|
@ -3,6 +3,7 @@ package org.whispersystems.signalservice.internal.websocket
|
|||
import io.mockk.clearAllMocks
|
||||
import io.mockk.every
|
||||
import io.mockk.mockk
|
||||
import io.mockk.mockkStatic
|
||||
import io.mockk.verify
|
||||
import io.reactivex.rxjava3.observers.TestObserver
|
||||
import org.junit.Before
|
||||
|
@ -11,6 +12,7 @@ import org.signal.libsignal.internal.CompletableFuture
|
|||
import org.signal.libsignal.net.ChatService
|
||||
import org.signal.libsignal.net.ChatService.DebugInfo
|
||||
import org.signal.libsignal.net.IpType
|
||||
import org.signal.libsignal.net.Network
|
||||
import org.whispersystems.signalservice.api.websocket.HealthMonitor
|
||||
import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState
|
||||
import java.util.concurrent.CountDownLatch
|
||||
|
@ -25,13 +27,16 @@ class LibSignalChatConnectionTest {
|
|||
private val executor: ExecutorService = Executors.newSingleThreadExecutor()
|
||||
private val healthMonitor = mockk<HealthMonitor>()
|
||||
private val chatService = mockk<ChatService>()
|
||||
private val connection = LibSignalChatConnection("test", chatService, healthMonitor)
|
||||
private val network = mockk<Network>()
|
||||
private val connection = LibSignalChatConnection("test", network, null, false, healthMonitor)
|
||||
|
||||
@Before
|
||||
fun before() {
|
||||
clearAllMocks()
|
||||
mockkStatic(Network::createChatService)
|
||||
every { healthMonitor.onMessageError(any(), any()) }
|
||||
every { healthMonitor.onKeepAliveResponse(any(), any()) }
|
||||
every { network.createChatService(any(), any()) } answers { chatService }
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -127,25 +132,37 @@ class LibSignalChatConnectionTest {
|
|||
fun orderOfStatesOnDisconnectFailure() {
|
||||
val disconnectException = RuntimeException("disconnect failed")
|
||||
|
||||
val latch = CountDownLatch(1)
|
||||
val connectLatch = CountDownLatch(1)
|
||||
val disconnectLatch = CountDownLatch(1)
|
||||
|
||||
every { chatService.disconnect() } answers {
|
||||
delay {
|
||||
it.completeExceptionally(disconnectException)
|
||||
disconnectLatch.countDown()
|
||||
}
|
||||
}
|
||||
|
||||
val observer = TestObserver<WebSocketConnectionState>()
|
||||
every { chatService.connect() } answers {
|
||||
delay {
|
||||
it.complete(DEBUG_INFO)
|
||||
connectLatch.countDown()
|
||||
}
|
||||
}
|
||||
|
||||
connection.connect()
|
||||
|
||||
connectLatch.await(100, TimeUnit.MILLISECONDS)
|
||||
|
||||
val observer = TestObserver<WebSocketConnectionState>()
|
||||
connection.state.subscribe(observer)
|
||||
|
||||
connection.disconnect()
|
||||
|
||||
latch.await(100, TimeUnit.MILLISECONDS)
|
||||
disconnectLatch.await(100, TimeUnit.MILLISECONDS)
|
||||
|
||||
observer.assertNotComplete()
|
||||
observer.assertValues(
|
||||
WebSocketConnectionState.DISCONNECTED,
|
||||
WebSocketConnectionState.CONNECTED,
|
||||
WebSocketConnectionState.DISCONNECTING,
|
||||
WebSocketConnectionState.DISCONNECTED
|
||||
)
|
||||
|
@ -162,6 +179,14 @@ class LibSignalChatConnectionTest {
|
|||
}
|
||||
}
|
||||
|
||||
every { chatService.connect() } answers {
|
||||
delay {
|
||||
it.complete(DEBUG_INFO)
|
||||
}
|
||||
}
|
||||
|
||||
connection.connect()
|
||||
|
||||
connection.sendKeepAlive()
|
||||
|
||||
latch.await(100, TimeUnit.MILLISECONDS)
|
||||
|
@ -185,6 +210,14 @@ class LibSignalChatConnectionTest {
|
|||
}
|
||||
}
|
||||
|
||||
every { chatService.connect() } answers {
|
||||
delay {
|
||||
it.complete(DEBUG_INFO)
|
||||
}
|
||||
}
|
||||
|
||||
connection.connect()
|
||||
|
||||
connection.sendKeepAlive()
|
||||
latch.await(100, TimeUnit.MILLISECONDS)
|
||||
|
||||
|
@ -200,28 +233,41 @@ class LibSignalChatConnectionTest {
|
|||
@Test
|
||||
fun keepAliveConnectionFailure() {
|
||||
val connectionFailure = RuntimeException("Sending keep-alive failed")
|
||||
val latch = CountDownLatch(1)
|
||||
|
||||
val connectLatch = CountDownLatch(1)
|
||||
val keepAliveFailureLatch = CountDownLatch(1)
|
||||
|
||||
every {
|
||||
chatService.sendAndDebug(any())
|
||||
} answers {
|
||||
delay {
|
||||
it.completeExceptionally(connectionFailure)
|
||||
keepAliveFailureLatch.countDown()
|
||||
}
|
||||
}
|
||||
|
||||
every { chatService.connect() } answers {
|
||||
delay {
|
||||
it.complete(DEBUG_INFO)
|
||||
connectLatch.countDown()
|
||||
}
|
||||
}
|
||||
|
||||
connection.connect()
|
||||
connectLatch.await(100, TimeUnit.MILLISECONDS)
|
||||
|
||||
val observer = TestObserver<WebSocketConnectionState>()
|
||||
connection.state.subscribe(observer)
|
||||
|
||||
connection.sendKeepAlive()
|
||||
|
||||
latch.await(100, TimeUnit.MILLISECONDS)
|
||||
keepAliveFailureLatch.await(100, TimeUnit.MILLISECONDS)
|
||||
|
||||
observer.assertNotComplete()
|
||||
observer.assertValues(
|
||||
// This is the starting state
|
||||
WebSocketConnectionState.DISCONNECTED,
|
||||
// This one is the result of a keep-alive failure
|
||||
// We start in the connected state
|
||||
WebSocketConnectionState.CONNECTED,
|
||||
// Disconnects as a result of keep-alive failure
|
||||
WebSocketConnectionState.DISCONNECTED
|
||||
)
|
||||
verify(exactly = 0) {
|
||||
|
|
Loading…
Add table
Reference in a new issue