Rotate quick restore QR code and web socket.
This commit is contained in:
parent
57502fb4ad
commit
2eabf03421
3 changed files with 131 additions and 33 deletions
|
@ -169,8 +169,12 @@ private fun RestoreViaQrScreen(
|
|||
) {
|
||||
AnimatedContent(
|
||||
targetState = state.qrState,
|
||||
contentKey = { it::class },
|
||||
contentAlignment = Alignment.Center,
|
||||
label = "qr-code-progress"
|
||||
label = "qr-code-progress",
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.fillMaxHeight()
|
||||
) { qrState ->
|
||||
when (qrState) {
|
||||
is RestoreViaQrViewModel.QrState.Loaded -> {
|
||||
|
@ -184,7 +188,9 @@ private fun RestoreViaQrScreen(
|
|||
}
|
||||
|
||||
RestoreViaQrViewModel.QrState.Loading -> {
|
||||
CircularProgressIndicator(modifier = Modifier.size(48.dp))
|
||||
Box(contentAlignment = Alignment.Center) {
|
||||
CircularProgressIndicator(modifier = Modifier.size(48.dp))
|
||||
}
|
||||
}
|
||||
|
||||
is RestoreViaQrViewModel.QrState.Scanned,
|
||||
|
|
|
@ -6,10 +6,15 @@
|
|||
package org.thoughtcrime.securesms.registrationv3.ui.restore
|
||||
|
||||
import androidx.lifecycle.ViewModel
|
||||
import kotlinx.coroutines.CoroutineExceptionHandler
|
||||
import androidx.lifecycle.viewModelScope
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.Job
|
||||
import kotlinx.coroutines.delay
|
||||
import kotlinx.coroutines.flow.MutableStateFlow
|
||||
import kotlinx.coroutines.flow.StateFlow
|
||||
import kotlinx.coroutines.flow.update
|
||||
import kotlinx.coroutines.isActive
|
||||
import kotlinx.coroutines.launch
|
||||
import org.signal.core.util.logging.Log
|
||||
import org.signal.registration.proto.RegistrationProvisionMessage
|
||||
import org.thoughtcrime.securesms.backup.v2.MessageBackupTier
|
||||
|
@ -31,15 +36,30 @@ class RestoreViaQrViewModel : ViewModel() {
|
|||
|
||||
val state: StateFlow<RestoreViaQrState> = store
|
||||
|
||||
private var socketHandle: Closeable
|
||||
private var socketHandles: MutableList<Closeable> = mutableListOf()
|
||||
private var startNewSocketJob: Job? = null
|
||||
|
||||
init {
|
||||
socketHandle = start()
|
||||
restart()
|
||||
}
|
||||
|
||||
fun restart() {
|
||||
socketHandle.close()
|
||||
socketHandle = start()
|
||||
SignalStore.registration.restoreMethodToken = null
|
||||
shutdown()
|
||||
|
||||
startNewSocket()
|
||||
|
||||
startNewSocketJob = viewModelScope.launch(Dispatchers.IO) {
|
||||
var count = 0
|
||||
while (count < 5 && isActive) {
|
||||
delay(ProvisioningSocket.LIFESPAN / 2)
|
||||
if (isActive) {
|
||||
startNewSocket()
|
||||
count++
|
||||
Log.d(TAG, "Started next websocket count: $count")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fun handleRegistrationFailure() {
|
||||
|
@ -61,20 +81,66 @@ class RestoreViaQrViewModel : ViewModel() {
|
|||
}
|
||||
|
||||
override fun onCleared() {
|
||||
socketHandle.close()
|
||||
shutdown()
|
||||
}
|
||||
|
||||
private fun startNewSocket() {
|
||||
synchronized(socketHandles) {
|
||||
socketHandles += start()
|
||||
|
||||
if (socketHandles.size > 2) {
|
||||
socketHandles.removeAt(0).close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun shutdown() {
|
||||
startNewSocketJob?.cancel()
|
||||
synchronized(socketHandles) {
|
||||
socketHandles.forEach { it.close() }
|
||||
socketHandles.clear()
|
||||
}
|
||||
}
|
||||
|
||||
private fun start(): Closeable {
|
||||
SignalStore.registration.restoreMethodToken = null
|
||||
store.update { it.copy(qrState = QrState.Loading) }
|
||||
store.update {
|
||||
if (it.qrState !is QrState.Loaded) {
|
||||
it.copy(qrState = QrState.Loading)
|
||||
} else {
|
||||
it
|
||||
}
|
||||
}
|
||||
|
||||
return ProvisioningSocket.start(
|
||||
identityKeyPair = IdentityKeyUtil.generateIdentityKeyPair(),
|
||||
configuration = AppDependencies.signalServiceNetworkAccess.getConfiguration(),
|
||||
handler = CoroutineExceptionHandler { _, _ -> store.update { it.copy(qrState = QrState.Failed) } }
|
||||
handler = { id, t ->
|
||||
store.update {
|
||||
if (it.currentSocketId == null || it.currentSocketId == id) {
|
||||
Log.w(TAG, "Current socket [$id] has failed, stopping automatic connects", t)
|
||||
shutdown()
|
||||
it.copy(currentSocketId = null, qrState = QrState.Failed)
|
||||
} else {
|
||||
Log.i(TAG, "Old socket [$id] failed, ignoring")
|
||||
it
|
||||
}
|
||||
}
|
||||
}
|
||||
) { socket ->
|
||||
val url = socket.getProvisioningUrl()
|
||||
store.update { it.copy(qrState = QrState.Loaded(qrData = QrCodeData.forData(data = url, supportIconOverlay = false))) }
|
||||
store.update {
|
||||
Log.d(TAG, "Updating QR code with data from [${socket.id}]")
|
||||
|
||||
it.copy(
|
||||
currentSocketId = socket.id,
|
||||
qrState = QrState.Loaded(
|
||||
qrData = QrCodeData.forData(
|
||||
data = url,
|
||||
supportIconOverlay = false
|
||||
)
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
val result = socket.getRegistrationProvisioningMessage()
|
||||
|
||||
|
@ -94,8 +160,15 @@ class RestoreViaQrViewModel : ViewModel() {
|
|||
SignalStore.backup.usedBackupMediaSpace = result.message.backupSizeBytes
|
||||
}
|
||||
store.update { it.copy(isRegistering = true, provisioningMessage = result.message, qrState = QrState.Scanned) }
|
||||
shutdown()
|
||||
} else {
|
||||
store.update { it.copy(showProvisioningError = true, qrState = QrState.Scanned) }
|
||||
store.update {
|
||||
if (it.currentSocketId == socket.id) {
|
||||
it.copy(showProvisioningError = true, qrState = QrState.Scanned)
|
||||
} else {
|
||||
it
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -105,7 +178,8 @@ class RestoreViaQrViewModel : ViewModel() {
|
|||
val qrState: QrState = QrState.Loading,
|
||||
val provisioningMessage: RegistrationProvisionMessage? = null,
|
||||
val showProvisioningError: Boolean = false,
|
||||
val showRegistrationError: Boolean = false
|
||||
val showRegistrationError: Boolean = false,
|
||||
val currentSocketId: Int? = null
|
||||
)
|
||||
|
||||
sealed interface QrState {
|
||||
|
|
|
@ -44,6 +44,7 @@ import kotlin.time.Duration.Companion.seconds
|
|||
* A provisional web socket for communicating with a primary device during registration.
|
||||
*/
|
||||
class ProvisioningSocket private constructor(
|
||||
val id: Int,
|
||||
identityKeyPair: IdentityKeyPair,
|
||||
configuration: SignalServiceConfiguration,
|
||||
private val scope: CoroutineScope
|
||||
|
@ -51,31 +52,36 @@ class ProvisioningSocket private constructor(
|
|||
companion object {
|
||||
private val TAG = Log.tag(ProvisioningSocket::class)
|
||||
|
||||
@Volatile private var nextSocketId = 1000
|
||||
|
||||
val LIFESPAN = 90.seconds
|
||||
|
||||
fun start(
|
||||
identityKeyPair: IdentityKeyPair,
|
||||
configuration: SignalServiceConfiguration,
|
||||
handler: CoroutineExceptionHandler,
|
||||
handler: ProvisioningSocketExceptionHandler,
|
||||
block: suspend CoroutineScope.(ProvisioningSocket) -> Unit
|
||||
): Closeable {
|
||||
val scope = CoroutineScope(Dispatchers.IO) + SupervisorJob() + handler
|
||||
val socketId = nextSocketId++
|
||||
val scope = CoroutineScope(Dispatchers.IO) + SupervisorJob() + CoroutineExceptionHandler { _, t -> handler.handleException(socketId, t) }
|
||||
|
||||
scope.launch {
|
||||
var socket: ProvisioningSocket? = null
|
||||
try {
|
||||
socket = ProvisioningSocket(identityKeyPair, configuration, scope)
|
||||
socket = ProvisioningSocket(socketId, identityKeyPair, configuration, scope)
|
||||
socket.connect()
|
||||
block(socket)
|
||||
} catch (e: CancellationException) {
|
||||
val rootCause = e.getRootCause()
|
||||
if (rootCause == null) {
|
||||
Log.i(TAG, "Scope canceled expectedly, fail silently, ${e.toMinimalString()}")
|
||||
Log.i(TAG, "[$socketId] Scope canceled expectedly, fail silently, ${e.toMinimalString()}")
|
||||
throw e
|
||||
} else {
|
||||
Log.w(TAG, "Unable to maintain web socket, ${rootCause.toMinimalString()}", rootCause)
|
||||
Log.w(TAG, "[$socketId] Unable to maintain web socket, ${rootCause.toMinimalString()}", rootCause)
|
||||
throw rootCause
|
||||
}
|
||||
} finally {
|
||||
Log.d(TAG, "Closing web socket")
|
||||
Log.d(TAG, "[$socketId] Closing web socket")
|
||||
socket?.close()
|
||||
}
|
||||
}
|
||||
|
@ -144,7 +150,7 @@ class ProvisioningSocket private constructor(
|
|||
private var lastKeepAliveId: Long = 0
|
||||
|
||||
override fun onOpen(webSocket: WebSocket, response: Response) {
|
||||
Log.d(TAG, "[onOpen]")
|
||||
Log.d(TAG, "[$id] [onOpen]")
|
||||
keepAliveJob = scope.launch { keepAlive(webSocket) }
|
||||
|
||||
val timeoutJob = scope.launch {
|
||||
|
@ -152,9 +158,17 @@ class ProvisioningSocket private constructor(
|
|||
scope.cancel("Did not receive device id within 10 seconds", SocketTimeoutException("No device id received"))
|
||||
}
|
||||
|
||||
val webSocketExpireJob = scope.launch {
|
||||
delay(LIFESPAN)
|
||||
scope.cancel("Did not complete a registration within ${LIFESPAN.inWholeSeconds} seconds", SocketTimeoutException("No provisioning message received"))
|
||||
}
|
||||
|
||||
scope.launch {
|
||||
provisioningUrlDeferral.await()
|
||||
timeoutJob.cancel()
|
||||
|
||||
provisioningMessageDeferral.await()
|
||||
webSocketExpireJob.cancel()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -162,24 +176,24 @@ class ProvisioningSocket private constructor(
|
|||
val message: WebSocketMessage = WebSocketMessage.ADAPTER.decode(bytes)
|
||||
|
||||
if (message.response != null && message.response.id == lastKeepAliveId) {
|
||||
Log.d(TAG, "[onMessage] Keep alive received")
|
||||
Log.d(TAG, "[$id] [onMessage] Keep alive received")
|
||||
return
|
||||
}
|
||||
|
||||
if (message.request == null) {
|
||||
Log.w(TAG, "[onMessage] Received null request")
|
||||
Log.w(TAG, "[$id] [onMessage] Received null request")
|
||||
return
|
||||
}
|
||||
|
||||
val success = webSocket.send(message.request.toResponse().encode().toByteString())
|
||||
|
||||
if (!success) {
|
||||
Log.w(TAG, "[onMessage] Failed to send response")
|
||||
Log.w(TAG, "[$id] [onMessage] Failed to send response")
|
||||
webSocket.close(1000, "OK")
|
||||
return
|
||||
}
|
||||
|
||||
Log.d(TAG, "[onMessage] Processing request")
|
||||
Log.d(TAG, "[$id] [onMessage] Processing request")
|
||||
|
||||
if (message.request.verb == "PUT" && message.request.body != null) {
|
||||
when (message.request.path) {
|
||||
|
@ -197,19 +211,19 @@ class ProvisioningSocket private constructor(
|
|||
provisioningMessageDeferral.complete(result)
|
||||
}
|
||||
|
||||
else -> Log.w(TAG, "Unknown path requested")
|
||||
else -> Log.w(TAG, "[$id] Unknown path requested")
|
||||
}
|
||||
} else {
|
||||
Log.w(TAG, "Invalid data")
|
||||
Log.w(TAG, "[$id] Invalid data")
|
||||
}
|
||||
}
|
||||
|
||||
override fun onClosing(webSocket: WebSocket, code: Int, reason: String) {
|
||||
scope.launch {
|
||||
Log.i(TAG, "[onClosing] code: $code reason: $reason")
|
||||
Log.i(TAG, "[$id] [onClosing] code: $code reason: $reason")
|
||||
|
||||
if (code != 1000) {
|
||||
Log.w(TAG, "Remote side is closing with non-normal code $code")
|
||||
Log.w(TAG, "[$id] Remote side is closing with non-normal code $code")
|
||||
webSocket.close(1000, "Remote closed with code $code")
|
||||
}
|
||||
|
||||
|
@ -219,7 +233,7 @@ class ProvisioningSocket private constructor(
|
|||
|
||||
override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) {
|
||||
scope.launch {
|
||||
Log.w(TAG, "[onFailure] Failed", t)
|
||||
Log.w(TAG, "[$id] [onFailure] Failed", t)
|
||||
webSocket.close(1000, "Failed ${t.message}")
|
||||
|
||||
scope.cancel(CancellationException("WebSocket Failure", t))
|
||||
|
@ -233,10 +247,10 @@ class ProvisioningSocket private constructor(
|
|||
}
|
||||
|
||||
private suspend fun keepAlive(webSocket: WebSocket) {
|
||||
Log.i(TAG, "[keepAlive] Starting")
|
||||
Log.i(TAG, "[$id] [keepAlive] Starting")
|
||||
while (true) {
|
||||
delay(30.seconds)
|
||||
Log.i(TAG, "[keepAlive] Sending...")
|
||||
Log.i(TAG, "[$id] [keepAlive] Sending...")
|
||||
|
||||
val id = System.currentTimeMillis()
|
||||
val message = WebSocketMessage(
|
||||
|
@ -249,7 +263,7 @@ class ProvisioningSocket private constructor(
|
|||
)
|
||||
|
||||
if (!webSocket.send(message.encodeByteString())) {
|
||||
Log.w(TAG, "[keepAlive] Send failed")
|
||||
Log.w(TAG, "[${this@ProvisioningSocket.id}] [keepAlive] Send failed")
|
||||
} else {
|
||||
lastKeepAliveId = id
|
||||
}
|
||||
|
@ -267,4 +281,8 @@ class ProvisioningSocket private constructor(
|
|||
)
|
||||
}
|
||||
}
|
||||
|
||||
fun interface ProvisioningSocketExceptionHandler {
|
||||
fun handleException(id: Int, exception: Throwable)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue