Rotate quick restore QR code and web socket.
This commit is contained in:
parent
57502fb4ad
commit
2eabf03421
3 changed files with 131 additions and 33 deletions
|
@ -169,8 +169,12 @@ private fun RestoreViaQrScreen(
|
||||||
) {
|
) {
|
||||||
AnimatedContent(
|
AnimatedContent(
|
||||||
targetState = state.qrState,
|
targetState = state.qrState,
|
||||||
|
contentKey = { it::class },
|
||||||
contentAlignment = Alignment.Center,
|
contentAlignment = Alignment.Center,
|
||||||
label = "qr-code-progress"
|
label = "qr-code-progress",
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxWidth()
|
||||||
|
.fillMaxHeight()
|
||||||
) { qrState ->
|
) { qrState ->
|
||||||
when (qrState) {
|
when (qrState) {
|
||||||
is RestoreViaQrViewModel.QrState.Loaded -> {
|
is RestoreViaQrViewModel.QrState.Loaded -> {
|
||||||
|
@ -184,8 +188,10 @@ private fun RestoreViaQrScreen(
|
||||||
}
|
}
|
||||||
|
|
||||||
RestoreViaQrViewModel.QrState.Loading -> {
|
RestoreViaQrViewModel.QrState.Loading -> {
|
||||||
|
Box(contentAlignment = Alignment.Center) {
|
||||||
CircularProgressIndicator(modifier = Modifier.size(48.dp))
|
CircularProgressIndicator(modifier = Modifier.size(48.dp))
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
is RestoreViaQrViewModel.QrState.Scanned,
|
is RestoreViaQrViewModel.QrState.Scanned,
|
||||||
RestoreViaQrViewModel.QrState.Failed -> {
|
RestoreViaQrViewModel.QrState.Failed -> {
|
||||||
|
|
|
@ -6,10 +6,15 @@
|
||||||
package org.thoughtcrime.securesms.registrationv3.ui.restore
|
package org.thoughtcrime.securesms.registrationv3.ui.restore
|
||||||
|
|
||||||
import androidx.lifecycle.ViewModel
|
import androidx.lifecycle.ViewModel
|
||||||
import kotlinx.coroutines.CoroutineExceptionHandler
|
import androidx.lifecycle.viewModelScope
|
||||||
|
import kotlinx.coroutines.Dispatchers
|
||||||
|
import kotlinx.coroutines.Job
|
||||||
|
import kotlinx.coroutines.delay
|
||||||
import kotlinx.coroutines.flow.MutableStateFlow
|
import kotlinx.coroutines.flow.MutableStateFlow
|
||||||
import kotlinx.coroutines.flow.StateFlow
|
import kotlinx.coroutines.flow.StateFlow
|
||||||
import kotlinx.coroutines.flow.update
|
import kotlinx.coroutines.flow.update
|
||||||
|
import kotlinx.coroutines.isActive
|
||||||
|
import kotlinx.coroutines.launch
|
||||||
import org.signal.core.util.logging.Log
|
import org.signal.core.util.logging.Log
|
||||||
import org.signal.registration.proto.RegistrationProvisionMessage
|
import org.signal.registration.proto.RegistrationProvisionMessage
|
||||||
import org.thoughtcrime.securesms.backup.v2.MessageBackupTier
|
import org.thoughtcrime.securesms.backup.v2.MessageBackupTier
|
||||||
|
@ -31,15 +36,30 @@ class RestoreViaQrViewModel : ViewModel() {
|
||||||
|
|
||||||
val state: StateFlow<RestoreViaQrState> = store
|
val state: StateFlow<RestoreViaQrState> = store
|
||||||
|
|
||||||
private var socketHandle: Closeable
|
private var socketHandles: MutableList<Closeable> = mutableListOf()
|
||||||
|
private var startNewSocketJob: Job? = null
|
||||||
|
|
||||||
init {
|
init {
|
||||||
socketHandle = start()
|
restart()
|
||||||
}
|
}
|
||||||
|
|
||||||
fun restart() {
|
fun restart() {
|
||||||
socketHandle.close()
|
SignalStore.registration.restoreMethodToken = null
|
||||||
socketHandle = start()
|
shutdown()
|
||||||
|
|
||||||
|
startNewSocket()
|
||||||
|
|
||||||
|
startNewSocketJob = viewModelScope.launch(Dispatchers.IO) {
|
||||||
|
var count = 0
|
||||||
|
while (count < 5 && isActive) {
|
||||||
|
delay(ProvisioningSocket.LIFESPAN / 2)
|
||||||
|
if (isActive) {
|
||||||
|
startNewSocket()
|
||||||
|
count++
|
||||||
|
Log.d(TAG, "Started next websocket count: $count")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fun handleRegistrationFailure() {
|
fun handleRegistrationFailure() {
|
||||||
|
@ -61,20 +81,66 @@ class RestoreViaQrViewModel : ViewModel() {
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun onCleared() {
|
override fun onCleared() {
|
||||||
socketHandle.close()
|
shutdown()
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun startNewSocket() {
|
||||||
|
synchronized(socketHandles) {
|
||||||
|
socketHandles += start()
|
||||||
|
|
||||||
|
if (socketHandles.size > 2) {
|
||||||
|
socketHandles.removeAt(0).close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun shutdown() {
|
||||||
|
startNewSocketJob?.cancel()
|
||||||
|
synchronized(socketHandles) {
|
||||||
|
socketHandles.forEach { it.close() }
|
||||||
|
socketHandles.clear()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun start(): Closeable {
|
private fun start(): Closeable {
|
||||||
SignalStore.registration.restoreMethodToken = null
|
store.update {
|
||||||
store.update { it.copy(qrState = QrState.Loading) }
|
if (it.qrState !is QrState.Loaded) {
|
||||||
|
it.copy(qrState = QrState.Loading)
|
||||||
|
} else {
|
||||||
|
it
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return ProvisioningSocket.start(
|
return ProvisioningSocket.start(
|
||||||
identityKeyPair = IdentityKeyUtil.generateIdentityKeyPair(),
|
identityKeyPair = IdentityKeyUtil.generateIdentityKeyPair(),
|
||||||
configuration = AppDependencies.signalServiceNetworkAccess.getConfiguration(),
|
configuration = AppDependencies.signalServiceNetworkAccess.getConfiguration(),
|
||||||
handler = CoroutineExceptionHandler { _, _ -> store.update { it.copy(qrState = QrState.Failed) } }
|
handler = { id, t ->
|
||||||
|
store.update {
|
||||||
|
if (it.currentSocketId == null || it.currentSocketId == id) {
|
||||||
|
Log.w(TAG, "Current socket [$id] has failed, stopping automatic connects", t)
|
||||||
|
shutdown()
|
||||||
|
it.copy(currentSocketId = null, qrState = QrState.Failed)
|
||||||
|
} else {
|
||||||
|
Log.i(TAG, "Old socket [$id] failed, ignoring")
|
||||||
|
it
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
) { socket ->
|
) { socket ->
|
||||||
val url = socket.getProvisioningUrl()
|
val url = socket.getProvisioningUrl()
|
||||||
store.update { it.copy(qrState = QrState.Loaded(qrData = QrCodeData.forData(data = url, supportIconOverlay = false))) }
|
store.update {
|
||||||
|
Log.d(TAG, "Updating QR code with data from [${socket.id}]")
|
||||||
|
|
||||||
|
it.copy(
|
||||||
|
currentSocketId = socket.id,
|
||||||
|
qrState = QrState.Loaded(
|
||||||
|
qrData = QrCodeData.forData(
|
||||||
|
data = url,
|
||||||
|
supportIconOverlay = false
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
val result = socket.getRegistrationProvisioningMessage()
|
val result = socket.getRegistrationProvisioningMessage()
|
||||||
|
|
||||||
|
@ -94,8 +160,15 @@ class RestoreViaQrViewModel : ViewModel() {
|
||||||
SignalStore.backup.usedBackupMediaSpace = result.message.backupSizeBytes
|
SignalStore.backup.usedBackupMediaSpace = result.message.backupSizeBytes
|
||||||
}
|
}
|
||||||
store.update { it.copy(isRegistering = true, provisioningMessage = result.message, qrState = QrState.Scanned) }
|
store.update { it.copy(isRegistering = true, provisioningMessage = result.message, qrState = QrState.Scanned) }
|
||||||
|
shutdown()
|
||||||
} else {
|
} else {
|
||||||
store.update { it.copy(showProvisioningError = true, qrState = QrState.Scanned) }
|
store.update {
|
||||||
|
if (it.currentSocketId == socket.id) {
|
||||||
|
it.copy(showProvisioningError = true, qrState = QrState.Scanned)
|
||||||
|
} else {
|
||||||
|
it
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -105,7 +178,8 @@ class RestoreViaQrViewModel : ViewModel() {
|
||||||
val qrState: QrState = QrState.Loading,
|
val qrState: QrState = QrState.Loading,
|
||||||
val provisioningMessage: RegistrationProvisionMessage? = null,
|
val provisioningMessage: RegistrationProvisionMessage? = null,
|
||||||
val showProvisioningError: Boolean = false,
|
val showProvisioningError: Boolean = false,
|
||||||
val showRegistrationError: Boolean = false
|
val showRegistrationError: Boolean = false,
|
||||||
|
val currentSocketId: Int? = null
|
||||||
)
|
)
|
||||||
|
|
||||||
sealed interface QrState {
|
sealed interface QrState {
|
||||||
|
|
|
@ -44,6 +44,7 @@ import kotlin.time.Duration.Companion.seconds
|
||||||
* A provisional web socket for communicating with a primary device during registration.
|
* A provisional web socket for communicating with a primary device during registration.
|
||||||
*/
|
*/
|
||||||
class ProvisioningSocket private constructor(
|
class ProvisioningSocket private constructor(
|
||||||
|
val id: Int,
|
||||||
identityKeyPair: IdentityKeyPair,
|
identityKeyPair: IdentityKeyPair,
|
||||||
configuration: SignalServiceConfiguration,
|
configuration: SignalServiceConfiguration,
|
||||||
private val scope: CoroutineScope
|
private val scope: CoroutineScope
|
||||||
|
@ -51,31 +52,36 @@ class ProvisioningSocket private constructor(
|
||||||
companion object {
|
companion object {
|
||||||
private val TAG = Log.tag(ProvisioningSocket::class)
|
private val TAG = Log.tag(ProvisioningSocket::class)
|
||||||
|
|
||||||
|
@Volatile private var nextSocketId = 1000
|
||||||
|
|
||||||
|
val LIFESPAN = 90.seconds
|
||||||
|
|
||||||
fun start(
|
fun start(
|
||||||
identityKeyPair: IdentityKeyPair,
|
identityKeyPair: IdentityKeyPair,
|
||||||
configuration: SignalServiceConfiguration,
|
configuration: SignalServiceConfiguration,
|
||||||
handler: CoroutineExceptionHandler,
|
handler: ProvisioningSocketExceptionHandler,
|
||||||
block: suspend CoroutineScope.(ProvisioningSocket) -> Unit
|
block: suspend CoroutineScope.(ProvisioningSocket) -> Unit
|
||||||
): Closeable {
|
): Closeable {
|
||||||
val scope = CoroutineScope(Dispatchers.IO) + SupervisorJob() + handler
|
val socketId = nextSocketId++
|
||||||
|
val scope = CoroutineScope(Dispatchers.IO) + SupervisorJob() + CoroutineExceptionHandler { _, t -> handler.handleException(socketId, t) }
|
||||||
|
|
||||||
scope.launch {
|
scope.launch {
|
||||||
var socket: ProvisioningSocket? = null
|
var socket: ProvisioningSocket? = null
|
||||||
try {
|
try {
|
||||||
socket = ProvisioningSocket(identityKeyPair, configuration, scope)
|
socket = ProvisioningSocket(socketId, identityKeyPair, configuration, scope)
|
||||||
socket.connect()
|
socket.connect()
|
||||||
block(socket)
|
block(socket)
|
||||||
} catch (e: CancellationException) {
|
} catch (e: CancellationException) {
|
||||||
val rootCause = e.getRootCause()
|
val rootCause = e.getRootCause()
|
||||||
if (rootCause == null) {
|
if (rootCause == null) {
|
||||||
Log.i(TAG, "Scope canceled expectedly, fail silently, ${e.toMinimalString()}")
|
Log.i(TAG, "[$socketId] Scope canceled expectedly, fail silently, ${e.toMinimalString()}")
|
||||||
throw e
|
throw e
|
||||||
} else {
|
} else {
|
||||||
Log.w(TAG, "Unable to maintain web socket, ${rootCause.toMinimalString()}", rootCause)
|
Log.w(TAG, "[$socketId] Unable to maintain web socket, ${rootCause.toMinimalString()}", rootCause)
|
||||||
throw rootCause
|
throw rootCause
|
||||||
}
|
}
|
||||||
} finally {
|
} finally {
|
||||||
Log.d(TAG, "Closing web socket")
|
Log.d(TAG, "[$socketId] Closing web socket")
|
||||||
socket?.close()
|
socket?.close()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -144,7 +150,7 @@ class ProvisioningSocket private constructor(
|
||||||
private var lastKeepAliveId: Long = 0
|
private var lastKeepAliveId: Long = 0
|
||||||
|
|
||||||
override fun onOpen(webSocket: WebSocket, response: Response) {
|
override fun onOpen(webSocket: WebSocket, response: Response) {
|
||||||
Log.d(TAG, "[onOpen]")
|
Log.d(TAG, "[$id] [onOpen]")
|
||||||
keepAliveJob = scope.launch { keepAlive(webSocket) }
|
keepAliveJob = scope.launch { keepAlive(webSocket) }
|
||||||
|
|
||||||
val timeoutJob = scope.launch {
|
val timeoutJob = scope.launch {
|
||||||
|
@ -152,9 +158,17 @@ class ProvisioningSocket private constructor(
|
||||||
scope.cancel("Did not receive device id within 10 seconds", SocketTimeoutException("No device id received"))
|
scope.cancel("Did not receive device id within 10 seconds", SocketTimeoutException("No device id received"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
val webSocketExpireJob = scope.launch {
|
||||||
|
delay(LIFESPAN)
|
||||||
|
scope.cancel("Did not complete a registration within ${LIFESPAN.inWholeSeconds} seconds", SocketTimeoutException("No provisioning message received"))
|
||||||
|
}
|
||||||
|
|
||||||
scope.launch {
|
scope.launch {
|
||||||
provisioningUrlDeferral.await()
|
provisioningUrlDeferral.await()
|
||||||
timeoutJob.cancel()
|
timeoutJob.cancel()
|
||||||
|
|
||||||
|
provisioningMessageDeferral.await()
|
||||||
|
webSocketExpireJob.cancel()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -162,24 +176,24 @@ class ProvisioningSocket private constructor(
|
||||||
val message: WebSocketMessage = WebSocketMessage.ADAPTER.decode(bytes)
|
val message: WebSocketMessage = WebSocketMessage.ADAPTER.decode(bytes)
|
||||||
|
|
||||||
if (message.response != null && message.response.id == lastKeepAliveId) {
|
if (message.response != null && message.response.id == lastKeepAliveId) {
|
||||||
Log.d(TAG, "[onMessage] Keep alive received")
|
Log.d(TAG, "[$id] [onMessage] Keep alive received")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if (message.request == null) {
|
if (message.request == null) {
|
||||||
Log.w(TAG, "[onMessage] Received null request")
|
Log.w(TAG, "[$id] [onMessage] Received null request")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
val success = webSocket.send(message.request.toResponse().encode().toByteString())
|
val success = webSocket.send(message.request.toResponse().encode().toByteString())
|
||||||
|
|
||||||
if (!success) {
|
if (!success) {
|
||||||
Log.w(TAG, "[onMessage] Failed to send response")
|
Log.w(TAG, "[$id] [onMessage] Failed to send response")
|
||||||
webSocket.close(1000, "OK")
|
webSocket.close(1000, "OK")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
Log.d(TAG, "[onMessage] Processing request")
|
Log.d(TAG, "[$id] [onMessage] Processing request")
|
||||||
|
|
||||||
if (message.request.verb == "PUT" && message.request.body != null) {
|
if (message.request.verb == "PUT" && message.request.body != null) {
|
||||||
when (message.request.path) {
|
when (message.request.path) {
|
||||||
|
@ -197,19 +211,19 @@ class ProvisioningSocket private constructor(
|
||||||
provisioningMessageDeferral.complete(result)
|
provisioningMessageDeferral.complete(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
else -> Log.w(TAG, "Unknown path requested")
|
else -> Log.w(TAG, "[$id] Unknown path requested")
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
Log.w(TAG, "Invalid data")
|
Log.w(TAG, "[$id] Invalid data")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun onClosing(webSocket: WebSocket, code: Int, reason: String) {
|
override fun onClosing(webSocket: WebSocket, code: Int, reason: String) {
|
||||||
scope.launch {
|
scope.launch {
|
||||||
Log.i(TAG, "[onClosing] code: $code reason: $reason")
|
Log.i(TAG, "[$id] [onClosing] code: $code reason: $reason")
|
||||||
|
|
||||||
if (code != 1000) {
|
if (code != 1000) {
|
||||||
Log.w(TAG, "Remote side is closing with non-normal code $code")
|
Log.w(TAG, "[$id] Remote side is closing with non-normal code $code")
|
||||||
webSocket.close(1000, "Remote closed with code $code")
|
webSocket.close(1000, "Remote closed with code $code")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -219,7 +233,7 @@ class ProvisioningSocket private constructor(
|
||||||
|
|
||||||
override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) {
|
override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) {
|
||||||
scope.launch {
|
scope.launch {
|
||||||
Log.w(TAG, "[onFailure] Failed", t)
|
Log.w(TAG, "[$id] [onFailure] Failed", t)
|
||||||
webSocket.close(1000, "Failed ${t.message}")
|
webSocket.close(1000, "Failed ${t.message}")
|
||||||
|
|
||||||
scope.cancel(CancellationException("WebSocket Failure", t))
|
scope.cancel(CancellationException("WebSocket Failure", t))
|
||||||
|
@ -233,10 +247,10 @@ class ProvisioningSocket private constructor(
|
||||||
}
|
}
|
||||||
|
|
||||||
private suspend fun keepAlive(webSocket: WebSocket) {
|
private suspend fun keepAlive(webSocket: WebSocket) {
|
||||||
Log.i(TAG, "[keepAlive] Starting")
|
Log.i(TAG, "[$id] [keepAlive] Starting")
|
||||||
while (true) {
|
while (true) {
|
||||||
delay(30.seconds)
|
delay(30.seconds)
|
||||||
Log.i(TAG, "[keepAlive] Sending...")
|
Log.i(TAG, "[$id] [keepAlive] Sending...")
|
||||||
|
|
||||||
val id = System.currentTimeMillis()
|
val id = System.currentTimeMillis()
|
||||||
val message = WebSocketMessage(
|
val message = WebSocketMessage(
|
||||||
|
@ -249,7 +263,7 @@ class ProvisioningSocket private constructor(
|
||||||
)
|
)
|
||||||
|
|
||||||
if (!webSocket.send(message.encodeByteString())) {
|
if (!webSocket.send(message.encodeByteString())) {
|
||||||
Log.w(TAG, "[keepAlive] Send failed")
|
Log.w(TAG, "[${this@ProvisioningSocket.id}] [keepAlive] Send failed")
|
||||||
} else {
|
} else {
|
||||||
lastKeepAliveId = id
|
lastKeepAliveId = id
|
||||||
}
|
}
|
||||||
|
@ -267,4 +281,8 @@ class ProvisioningSocket private constructor(
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fun interface ProvisioningSocketExceptionHandler {
|
||||||
|
fun handleException(id: Int, exception: Throwable)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue