Rotate quick restore QR code and web socket.

This commit is contained in:
Cody Henthorne 2024-12-10 16:32:00 -05:00 committed by Greyson Parrelli
parent 57502fb4ad
commit 2eabf03421
3 changed files with 131 additions and 33 deletions

View file

@ -169,8 +169,12 @@ private fun RestoreViaQrScreen(
) { ) {
AnimatedContent( AnimatedContent(
targetState = state.qrState, targetState = state.qrState,
contentKey = { it::class },
contentAlignment = Alignment.Center, contentAlignment = Alignment.Center,
label = "qr-code-progress" label = "qr-code-progress",
modifier = Modifier
.fillMaxWidth()
.fillMaxHeight()
) { qrState -> ) { qrState ->
when (qrState) { when (qrState) {
is RestoreViaQrViewModel.QrState.Loaded -> { is RestoreViaQrViewModel.QrState.Loaded -> {
@ -184,8 +188,10 @@ private fun RestoreViaQrScreen(
} }
RestoreViaQrViewModel.QrState.Loading -> { RestoreViaQrViewModel.QrState.Loading -> {
Box(contentAlignment = Alignment.Center) {
CircularProgressIndicator(modifier = Modifier.size(48.dp)) CircularProgressIndicator(modifier = Modifier.size(48.dp))
} }
}
is RestoreViaQrViewModel.QrState.Scanned, is RestoreViaQrViewModel.QrState.Scanned,
RestoreViaQrViewModel.QrState.Failed -> { RestoreViaQrViewModel.QrState.Failed -> {

View file

@ -6,10 +6,15 @@
package org.thoughtcrime.securesms.registrationv3.ui.restore package org.thoughtcrime.securesms.registrationv3.ui.restore
import androidx.lifecycle.ViewModel import androidx.lifecycle.ViewModel
import kotlinx.coroutines.CoroutineExceptionHandler import androidx.lifecycle.viewModelScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.update import kotlinx.coroutines.flow.update
import kotlinx.coroutines.isActive
import kotlinx.coroutines.launch
import org.signal.core.util.logging.Log import org.signal.core.util.logging.Log
import org.signal.registration.proto.RegistrationProvisionMessage import org.signal.registration.proto.RegistrationProvisionMessage
import org.thoughtcrime.securesms.backup.v2.MessageBackupTier import org.thoughtcrime.securesms.backup.v2.MessageBackupTier
@ -31,15 +36,30 @@ class RestoreViaQrViewModel : ViewModel() {
val state: StateFlow<RestoreViaQrState> = store val state: StateFlow<RestoreViaQrState> = store
private var socketHandle: Closeable private var socketHandles: MutableList<Closeable> = mutableListOf()
private var startNewSocketJob: Job? = null
init { init {
socketHandle = start() restart()
} }
fun restart() { fun restart() {
socketHandle.close() SignalStore.registration.restoreMethodToken = null
socketHandle = start() shutdown()
startNewSocket()
startNewSocketJob = viewModelScope.launch(Dispatchers.IO) {
var count = 0
while (count < 5 && isActive) {
delay(ProvisioningSocket.LIFESPAN / 2)
if (isActive) {
startNewSocket()
count++
Log.d(TAG, "Started next websocket count: $count")
}
}
}
} }
fun handleRegistrationFailure() { fun handleRegistrationFailure() {
@ -61,20 +81,66 @@ class RestoreViaQrViewModel : ViewModel() {
} }
override fun onCleared() { override fun onCleared() {
socketHandle.close() shutdown()
}
private fun startNewSocket() {
synchronized(socketHandles) {
socketHandles += start()
if (socketHandles.size > 2) {
socketHandles.removeAt(0).close()
}
}
}
private fun shutdown() {
startNewSocketJob?.cancel()
synchronized(socketHandles) {
socketHandles.forEach { it.close() }
socketHandles.clear()
}
} }
private fun start(): Closeable { private fun start(): Closeable {
SignalStore.registration.restoreMethodToken = null store.update {
store.update { it.copy(qrState = QrState.Loading) } if (it.qrState !is QrState.Loaded) {
it.copy(qrState = QrState.Loading)
} else {
it
}
}
return ProvisioningSocket.start( return ProvisioningSocket.start(
identityKeyPair = IdentityKeyUtil.generateIdentityKeyPair(), identityKeyPair = IdentityKeyUtil.generateIdentityKeyPair(),
configuration = AppDependencies.signalServiceNetworkAccess.getConfiguration(), configuration = AppDependencies.signalServiceNetworkAccess.getConfiguration(),
handler = CoroutineExceptionHandler { _, _ -> store.update { it.copy(qrState = QrState.Failed) } } handler = { id, t ->
store.update {
if (it.currentSocketId == null || it.currentSocketId == id) {
Log.w(TAG, "Current socket [$id] has failed, stopping automatic connects", t)
shutdown()
it.copy(currentSocketId = null, qrState = QrState.Failed)
} else {
Log.i(TAG, "Old socket [$id] failed, ignoring")
it
}
}
}
) { socket -> ) { socket ->
val url = socket.getProvisioningUrl() val url = socket.getProvisioningUrl()
store.update { it.copy(qrState = QrState.Loaded(qrData = QrCodeData.forData(data = url, supportIconOverlay = false))) } store.update {
Log.d(TAG, "Updating QR code with data from [${socket.id}]")
it.copy(
currentSocketId = socket.id,
qrState = QrState.Loaded(
qrData = QrCodeData.forData(
data = url,
supportIconOverlay = false
)
)
)
}
val result = socket.getRegistrationProvisioningMessage() val result = socket.getRegistrationProvisioningMessage()
@ -94,8 +160,15 @@ class RestoreViaQrViewModel : ViewModel() {
SignalStore.backup.usedBackupMediaSpace = result.message.backupSizeBytes SignalStore.backup.usedBackupMediaSpace = result.message.backupSizeBytes
} }
store.update { it.copy(isRegistering = true, provisioningMessage = result.message, qrState = QrState.Scanned) } store.update { it.copy(isRegistering = true, provisioningMessage = result.message, qrState = QrState.Scanned) }
shutdown()
} else { } else {
store.update { it.copy(showProvisioningError = true, qrState = QrState.Scanned) } store.update {
if (it.currentSocketId == socket.id) {
it.copy(showProvisioningError = true, qrState = QrState.Scanned)
} else {
it
}
}
} }
} }
} }
@ -105,7 +178,8 @@ class RestoreViaQrViewModel : ViewModel() {
val qrState: QrState = QrState.Loading, val qrState: QrState = QrState.Loading,
val provisioningMessage: RegistrationProvisionMessage? = null, val provisioningMessage: RegistrationProvisionMessage? = null,
val showProvisioningError: Boolean = false, val showProvisioningError: Boolean = false,
val showRegistrationError: Boolean = false val showRegistrationError: Boolean = false,
val currentSocketId: Int? = null
) )
sealed interface QrState { sealed interface QrState {

View file

@ -44,6 +44,7 @@ import kotlin.time.Duration.Companion.seconds
* A provisional web socket for communicating with a primary device during registration. * A provisional web socket for communicating with a primary device during registration.
*/ */
class ProvisioningSocket private constructor( class ProvisioningSocket private constructor(
val id: Int,
identityKeyPair: IdentityKeyPair, identityKeyPair: IdentityKeyPair,
configuration: SignalServiceConfiguration, configuration: SignalServiceConfiguration,
private val scope: CoroutineScope private val scope: CoroutineScope
@ -51,31 +52,36 @@ class ProvisioningSocket private constructor(
companion object { companion object {
private val TAG = Log.tag(ProvisioningSocket::class) private val TAG = Log.tag(ProvisioningSocket::class)
@Volatile private var nextSocketId = 1000
val LIFESPAN = 90.seconds
fun start( fun start(
identityKeyPair: IdentityKeyPair, identityKeyPair: IdentityKeyPair,
configuration: SignalServiceConfiguration, configuration: SignalServiceConfiguration,
handler: CoroutineExceptionHandler, handler: ProvisioningSocketExceptionHandler,
block: suspend CoroutineScope.(ProvisioningSocket) -> Unit block: suspend CoroutineScope.(ProvisioningSocket) -> Unit
): Closeable { ): Closeable {
val scope = CoroutineScope(Dispatchers.IO) + SupervisorJob() + handler val socketId = nextSocketId++
val scope = CoroutineScope(Dispatchers.IO) + SupervisorJob() + CoroutineExceptionHandler { _, t -> handler.handleException(socketId, t) }
scope.launch { scope.launch {
var socket: ProvisioningSocket? = null var socket: ProvisioningSocket? = null
try { try {
socket = ProvisioningSocket(identityKeyPair, configuration, scope) socket = ProvisioningSocket(socketId, identityKeyPair, configuration, scope)
socket.connect() socket.connect()
block(socket) block(socket)
} catch (e: CancellationException) { } catch (e: CancellationException) {
val rootCause = e.getRootCause() val rootCause = e.getRootCause()
if (rootCause == null) { if (rootCause == null) {
Log.i(TAG, "Scope canceled expectedly, fail silently, ${e.toMinimalString()}") Log.i(TAG, "[$socketId] Scope canceled expectedly, fail silently, ${e.toMinimalString()}")
throw e throw e
} else { } else {
Log.w(TAG, "Unable to maintain web socket, ${rootCause.toMinimalString()}", rootCause) Log.w(TAG, "[$socketId] Unable to maintain web socket, ${rootCause.toMinimalString()}", rootCause)
throw rootCause throw rootCause
} }
} finally { } finally {
Log.d(TAG, "Closing web socket") Log.d(TAG, "[$socketId] Closing web socket")
socket?.close() socket?.close()
} }
} }
@ -144,7 +150,7 @@ class ProvisioningSocket private constructor(
private var lastKeepAliveId: Long = 0 private var lastKeepAliveId: Long = 0
override fun onOpen(webSocket: WebSocket, response: Response) { override fun onOpen(webSocket: WebSocket, response: Response) {
Log.d(TAG, "[onOpen]") Log.d(TAG, "[$id] [onOpen]")
keepAliveJob = scope.launch { keepAlive(webSocket) } keepAliveJob = scope.launch { keepAlive(webSocket) }
val timeoutJob = scope.launch { val timeoutJob = scope.launch {
@ -152,9 +158,17 @@ class ProvisioningSocket private constructor(
scope.cancel("Did not receive device id within 10 seconds", SocketTimeoutException("No device id received")) scope.cancel("Did not receive device id within 10 seconds", SocketTimeoutException("No device id received"))
} }
val webSocketExpireJob = scope.launch {
delay(LIFESPAN)
scope.cancel("Did not complete a registration within ${LIFESPAN.inWholeSeconds} seconds", SocketTimeoutException("No provisioning message received"))
}
scope.launch { scope.launch {
provisioningUrlDeferral.await() provisioningUrlDeferral.await()
timeoutJob.cancel() timeoutJob.cancel()
provisioningMessageDeferral.await()
webSocketExpireJob.cancel()
} }
} }
@ -162,24 +176,24 @@ class ProvisioningSocket private constructor(
val message: WebSocketMessage = WebSocketMessage.ADAPTER.decode(bytes) val message: WebSocketMessage = WebSocketMessage.ADAPTER.decode(bytes)
if (message.response != null && message.response.id == lastKeepAliveId) { if (message.response != null && message.response.id == lastKeepAliveId) {
Log.d(TAG, "[onMessage] Keep alive received") Log.d(TAG, "[$id] [onMessage] Keep alive received")
return return
} }
if (message.request == null) { if (message.request == null) {
Log.w(TAG, "[onMessage] Received null request") Log.w(TAG, "[$id] [onMessage] Received null request")
return return
} }
val success = webSocket.send(message.request.toResponse().encode().toByteString()) val success = webSocket.send(message.request.toResponse().encode().toByteString())
if (!success) { if (!success) {
Log.w(TAG, "[onMessage] Failed to send response") Log.w(TAG, "[$id] [onMessage] Failed to send response")
webSocket.close(1000, "OK") webSocket.close(1000, "OK")
return return
} }
Log.d(TAG, "[onMessage] Processing request") Log.d(TAG, "[$id] [onMessage] Processing request")
if (message.request.verb == "PUT" && message.request.body != null) { if (message.request.verb == "PUT" && message.request.body != null) {
when (message.request.path) { when (message.request.path) {
@ -197,19 +211,19 @@ class ProvisioningSocket private constructor(
provisioningMessageDeferral.complete(result) provisioningMessageDeferral.complete(result)
} }
else -> Log.w(TAG, "Unknown path requested") else -> Log.w(TAG, "[$id] Unknown path requested")
} }
} else { } else {
Log.w(TAG, "Invalid data") Log.w(TAG, "[$id] Invalid data")
} }
} }
override fun onClosing(webSocket: WebSocket, code: Int, reason: String) { override fun onClosing(webSocket: WebSocket, code: Int, reason: String) {
scope.launch { scope.launch {
Log.i(TAG, "[onClosing] code: $code reason: $reason") Log.i(TAG, "[$id] [onClosing] code: $code reason: $reason")
if (code != 1000) { if (code != 1000) {
Log.w(TAG, "Remote side is closing with non-normal code $code") Log.w(TAG, "[$id] Remote side is closing with non-normal code $code")
webSocket.close(1000, "Remote closed with code $code") webSocket.close(1000, "Remote closed with code $code")
} }
@ -219,7 +233,7 @@ class ProvisioningSocket private constructor(
override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) { override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) {
scope.launch { scope.launch {
Log.w(TAG, "[onFailure] Failed", t) Log.w(TAG, "[$id] [onFailure] Failed", t)
webSocket.close(1000, "Failed ${t.message}") webSocket.close(1000, "Failed ${t.message}")
scope.cancel(CancellationException("WebSocket Failure", t)) scope.cancel(CancellationException("WebSocket Failure", t))
@ -233,10 +247,10 @@ class ProvisioningSocket private constructor(
} }
private suspend fun keepAlive(webSocket: WebSocket) { private suspend fun keepAlive(webSocket: WebSocket) {
Log.i(TAG, "[keepAlive] Starting") Log.i(TAG, "[$id] [keepAlive] Starting")
while (true) { while (true) {
delay(30.seconds) delay(30.seconds)
Log.i(TAG, "[keepAlive] Sending...") Log.i(TAG, "[$id] [keepAlive] Sending...")
val id = System.currentTimeMillis() val id = System.currentTimeMillis()
val message = WebSocketMessage( val message = WebSocketMessage(
@ -249,7 +263,7 @@ class ProvisioningSocket private constructor(
) )
if (!webSocket.send(message.encodeByteString())) { if (!webSocket.send(message.encodeByteString())) {
Log.w(TAG, "[keepAlive] Send failed") Log.w(TAG, "[${this@ProvisioningSocket.id}] [keepAlive] Send failed")
} else { } else {
lastKeepAliveId = id lastKeepAliveId = id
} }
@ -267,4 +281,8 @@ class ProvisioningSocket private constructor(
) )
} }
} }
fun interface ProvisioningSocketExceptionHandler {
fun handleException(id: Int, exception: Throwable)
}
} }