diff --git a/app/src/main/java/org/thoughtcrime/securesms/registrationv3/ui/restore/RestoreViaQrFragment.kt b/app/src/main/java/org/thoughtcrime/securesms/registrationv3/ui/restore/RestoreViaQrFragment.kt index 58a689b8dd..cc221d7199 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/registrationv3/ui/restore/RestoreViaQrFragment.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/registrationv3/ui/restore/RestoreViaQrFragment.kt @@ -169,8 +169,12 @@ private fun RestoreViaQrScreen( ) { AnimatedContent( targetState = state.qrState, + contentKey = { it::class }, contentAlignment = Alignment.Center, - label = "qr-code-progress" + label = "qr-code-progress", + modifier = Modifier + .fillMaxWidth() + .fillMaxHeight() ) { qrState -> when (qrState) { is RestoreViaQrViewModel.QrState.Loaded -> { @@ -184,7 +188,9 @@ private fun RestoreViaQrScreen( } RestoreViaQrViewModel.QrState.Loading -> { - CircularProgressIndicator(modifier = Modifier.size(48.dp)) + Box(contentAlignment = Alignment.Center) { + CircularProgressIndicator(modifier = Modifier.size(48.dp)) + } } is RestoreViaQrViewModel.QrState.Scanned, diff --git a/app/src/main/java/org/thoughtcrime/securesms/registrationv3/ui/restore/RestoreViaQrViewModel.kt b/app/src/main/java/org/thoughtcrime/securesms/registrationv3/ui/restore/RestoreViaQrViewModel.kt index d28bd4f4a9..e96f7067b7 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/registrationv3/ui/restore/RestoreViaQrViewModel.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/registrationv3/ui/restore/RestoreViaQrViewModel.kt @@ -6,10 +6,15 @@ package org.thoughtcrime.securesms.registrationv3.ui.restore import androidx.lifecycle.ViewModel -import kotlinx.coroutines.CoroutineExceptionHandler +import androidx.lifecycle.viewModelScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.Job +import kotlinx.coroutines.delay import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.update +import kotlinx.coroutines.isActive +import kotlinx.coroutines.launch import org.signal.core.util.logging.Log import org.signal.registration.proto.RegistrationProvisionMessage import org.thoughtcrime.securesms.backup.v2.MessageBackupTier @@ -31,15 +36,30 @@ class RestoreViaQrViewModel : ViewModel() { val state: StateFlow = store - private var socketHandle: Closeable + private var socketHandles: MutableList = mutableListOf() + private var startNewSocketJob: Job? = null init { - socketHandle = start() + restart() } fun restart() { - socketHandle.close() - socketHandle = start() + SignalStore.registration.restoreMethodToken = null + shutdown() + + startNewSocket() + + startNewSocketJob = viewModelScope.launch(Dispatchers.IO) { + var count = 0 + while (count < 5 && isActive) { + delay(ProvisioningSocket.LIFESPAN / 2) + if (isActive) { + startNewSocket() + count++ + Log.d(TAG, "Started next websocket count: $count") + } + } + } } fun handleRegistrationFailure() { @@ -61,20 +81,66 @@ class RestoreViaQrViewModel : ViewModel() { } override fun onCleared() { - socketHandle.close() + shutdown() + } + + private fun startNewSocket() { + synchronized(socketHandles) { + socketHandles += start() + + if (socketHandles.size > 2) { + socketHandles.removeAt(0).close() + } + } + } + + private fun shutdown() { + startNewSocketJob?.cancel() + synchronized(socketHandles) { + socketHandles.forEach { it.close() } + socketHandles.clear() + } } private fun start(): Closeable { - SignalStore.registration.restoreMethodToken = null - store.update { it.copy(qrState = QrState.Loading) } + store.update { + if (it.qrState !is QrState.Loaded) { + it.copy(qrState = QrState.Loading) + } else { + it + } + } return ProvisioningSocket.start( identityKeyPair = IdentityKeyUtil.generateIdentityKeyPair(), configuration = AppDependencies.signalServiceNetworkAccess.getConfiguration(), - handler = CoroutineExceptionHandler { _, _ -> store.update { it.copy(qrState = QrState.Failed) } } + handler = { id, t -> + store.update { + if (it.currentSocketId == null || it.currentSocketId == id) { + Log.w(TAG, "Current socket [$id] has failed, stopping automatic connects", t) + shutdown() + it.copy(currentSocketId = null, qrState = QrState.Failed) + } else { + Log.i(TAG, "Old socket [$id] failed, ignoring") + it + } + } + } ) { socket -> val url = socket.getProvisioningUrl() - store.update { it.copy(qrState = QrState.Loaded(qrData = QrCodeData.forData(data = url, supportIconOverlay = false))) } + store.update { + Log.d(TAG, "Updating QR code with data from [${socket.id}]") + + it.copy( + currentSocketId = socket.id, + qrState = QrState.Loaded( + qrData = QrCodeData.forData( + data = url, + supportIconOverlay = false + ) + ) + ) + } val result = socket.getRegistrationProvisioningMessage() @@ -94,8 +160,15 @@ class RestoreViaQrViewModel : ViewModel() { SignalStore.backup.usedBackupMediaSpace = result.message.backupSizeBytes } store.update { it.copy(isRegistering = true, provisioningMessage = result.message, qrState = QrState.Scanned) } + shutdown() } else { - store.update { it.copy(showProvisioningError = true, qrState = QrState.Scanned) } + store.update { + if (it.currentSocketId == socket.id) { + it.copy(showProvisioningError = true, qrState = QrState.Scanned) + } else { + it + } + } } } } @@ -105,7 +178,8 @@ class RestoreViaQrViewModel : ViewModel() { val qrState: QrState = QrState.Loading, val provisioningMessage: RegistrationProvisionMessage? = null, val showProvisioningError: Boolean = false, - val showRegistrationError: Boolean = false + val showRegistrationError: Boolean = false, + val currentSocketId: Int? = null ) sealed interface QrState { diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/registration/ProvisioningSocket.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/registration/ProvisioningSocket.kt index 3256a401c3..86dc268aeb 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/registration/ProvisioningSocket.kt +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/registration/ProvisioningSocket.kt @@ -44,6 +44,7 @@ import kotlin.time.Duration.Companion.seconds * A provisional web socket for communicating with a primary device during registration. */ class ProvisioningSocket private constructor( + val id: Int, identityKeyPair: IdentityKeyPair, configuration: SignalServiceConfiguration, private val scope: CoroutineScope @@ -51,31 +52,36 @@ class ProvisioningSocket private constructor( companion object { private val TAG = Log.tag(ProvisioningSocket::class) + @Volatile private var nextSocketId = 1000 + + val LIFESPAN = 90.seconds + fun start( identityKeyPair: IdentityKeyPair, configuration: SignalServiceConfiguration, - handler: CoroutineExceptionHandler, + handler: ProvisioningSocketExceptionHandler, block: suspend CoroutineScope.(ProvisioningSocket) -> Unit ): Closeable { - val scope = CoroutineScope(Dispatchers.IO) + SupervisorJob() + handler + val socketId = nextSocketId++ + val scope = CoroutineScope(Dispatchers.IO) + SupervisorJob() + CoroutineExceptionHandler { _, t -> handler.handleException(socketId, t) } scope.launch { var socket: ProvisioningSocket? = null try { - socket = ProvisioningSocket(identityKeyPair, configuration, scope) + socket = ProvisioningSocket(socketId, identityKeyPair, configuration, scope) socket.connect() block(socket) } catch (e: CancellationException) { val rootCause = e.getRootCause() if (rootCause == null) { - Log.i(TAG, "Scope canceled expectedly, fail silently, ${e.toMinimalString()}") + Log.i(TAG, "[$socketId] Scope canceled expectedly, fail silently, ${e.toMinimalString()}") throw e } else { - Log.w(TAG, "Unable to maintain web socket, ${rootCause.toMinimalString()}", rootCause) + Log.w(TAG, "[$socketId] Unable to maintain web socket, ${rootCause.toMinimalString()}", rootCause) throw rootCause } } finally { - Log.d(TAG, "Closing web socket") + Log.d(TAG, "[$socketId] Closing web socket") socket?.close() } } @@ -144,7 +150,7 @@ class ProvisioningSocket private constructor( private var lastKeepAliveId: Long = 0 override fun onOpen(webSocket: WebSocket, response: Response) { - Log.d(TAG, "[onOpen]") + Log.d(TAG, "[$id] [onOpen]") keepAliveJob = scope.launch { keepAlive(webSocket) } val timeoutJob = scope.launch { @@ -152,9 +158,17 @@ class ProvisioningSocket private constructor( scope.cancel("Did not receive device id within 10 seconds", SocketTimeoutException("No device id received")) } + val webSocketExpireJob = scope.launch { + delay(LIFESPAN) + scope.cancel("Did not complete a registration within ${LIFESPAN.inWholeSeconds} seconds", SocketTimeoutException("No provisioning message received")) + } + scope.launch { provisioningUrlDeferral.await() timeoutJob.cancel() + + provisioningMessageDeferral.await() + webSocketExpireJob.cancel() } } @@ -162,24 +176,24 @@ class ProvisioningSocket private constructor( val message: WebSocketMessage = WebSocketMessage.ADAPTER.decode(bytes) if (message.response != null && message.response.id == lastKeepAliveId) { - Log.d(TAG, "[onMessage] Keep alive received") + Log.d(TAG, "[$id] [onMessage] Keep alive received") return } if (message.request == null) { - Log.w(TAG, "[onMessage] Received null request") + Log.w(TAG, "[$id] [onMessage] Received null request") return } val success = webSocket.send(message.request.toResponse().encode().toByteString()) if (!success) { - Log.w(TAG, "[onMessage] Failed to send response") + Log.w(TAG, "[$id] [onMessage] Failed to send response") webSocket.close(1000, "OK") return } - Log.d(TAG, "[onMessage] Processing request") + Log.d(TAG, "[$id] [onMessage] Processing request") if (message.request.verb == "PUT" && message.request.body != null) { when (message.request.path) { @@ -197,19 +211,19 @@ class ProvisioningSocket private constructor( provisioningMessageDeferral.complete(result) } - else -> Log.w(TAG, "Unknown path requested") + else -> Log.w(TAG, "[$id] Unknown path requested") } } else { - Log.w(TAG, "Invalid data") + Log.w(TAG, "[$id] Invalid data") } } override fun onClosing(webSocket: WebSocket, code: Int, reason: String) { scope.launch { - Log.i(TAG, "[onClosing] code: $code reason: $reason") + Log.i(TAG, "[$id] [onClosing] code: $code reason: $reason") if (code != 1000) { - Log.w(TAG, "Remote side is closing with non-normal code $code") + Log.w(TAG, "[$id] Remote side is closing with non-normal code $code") webSocket.close(1000, "Remote closed with code $code") } @@ -219,7 +233,7 @@ class ProvisioningSocket private constructor( override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) { scope.launch { - Log.w(TAG, "[onFailure] Failed", t) + Log.w(TAG, "[$id] [onFailure] Failed", t) webSocket.close(1000, "Failed ${t.message}") scope.cancel(CancellationException("WebSocket Failure", t)) @@ -233,10 +247,10 @@ class ProvisioningSocket private constructor( } private suspend fun keepAlive(webSocket: WebSocket) { - Log.i(TAG, "[keepAlive] Starting") + Log.i(TAG, "[$id] [keepAlive] Starting") while (true) { delay(30.seconds) - Log.i(TAG, "[keepAlive] Sending...") + Log.i(TAG, "[$id] [keepAlive] Sending...") val id = System.currentTimeMillis() val message = WebSocketMessage( @@ -249,7 +263,7 @@ class ProvisioningSocket private constructor( ) if (!webSocket.send(message.encodeByteString())) { - Log.w(TAG, "[keepAlive] Send failed") + Log.w(TAG, "[${this@ProvisioningSocket.id}] [keepAlive] Send failed") } else { lastKeepAliveId = id } @@ -267,4 +281,8 @@ class ProvisioningSocket private constructor( ) } } + + fun interface ProvisioningSocketExceptionHandler { + fun handleException(id: Int, exception: Throwable) + } }