diff --git a/app/src/main/java/org/thoughtcrime/securesms/gcm/FcmFetchManager.kt b/app/src/main/java/org/thoughtcrime/securesms/gcm/FcmFetchManager.kt index 99addf20b8..08a07e23c3 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/gcm/FcmFetchManager.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/gcm/FcmFetchManager.kt @@ -14,7 +14,7 @@ import org.thoughtcrime.securesms.MainActivity import org.thoughtcrime.securesms.R import org.thoughtcrime.securesms.dependencies.ApplicationDependencies import org.thoughtcrime.securesms.jobs.PushNotificationReceiveJob -import org.thoughtcrime.securesms.messages.WebSocketStrategy +import org.thoughtcrime.securesms.messages.WebSocketDrainer import org.thoughtcrime.securesms.notifications.NotificationChannels import org.thoughtcrime.securesms.notifications.NotificationIds import org.thoughtcrime.securesms.util.FeatureFlags @@ -140,7 +140,7 @@ object FcmFetchManager { @JvmStatic fun retrieveMessages(context: Context): Boolean { - val success = WebSocketStrategy.execute(WEBSOCKET_DRAIN_TIMEOUT) + val success = WebSocketDrainer.blockUntilDrainedAndProcessed(WEBSOCKET_DRAIN_TIMEOUT) if (success) { Log.i(TAG, "Successfully retrieved messages.") diff --git a/app/src/main/java/org/thoughtcrime/securesms/gcm/FcmJobService.java b/app/src/main/java/org/thoughtcrime/securesms/gcm/FcmJobService.java index 4599d7b047..53fc6fd256 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/gcm/FcmJobService.java +++ b/app/src/main/java/org/thoughtcrime/securesms/gcm/FcmJobService.java @@ -12,10 +12,8 @@ import androidx.annotation.RequiresApi; import org.signal.core.util.concurrent.SignalExecutors; import org.signal.core.util.logging.Log; import org.thoughtcrime.securesms.dependencies.ApplicationDependencies; -import org.thoughtcrime.securesms.keyvalue.SignalStore; -import org.thoughtcrime.securesms.messages.WebSocketStrategy; +import org.thoughtcrime.securesms.messages.WebSocketDrainer; import org.thoughtcrime.securesms.util.ServiceUtil; -import org.thoughtcrime.securesms.util.TextSecurePreferences; /** * Pulls down messages. Used when we fail to pull down messages in {@link FcmReceiveService}. @@ -47,7 +45,7 @@ public class FcmJobService extends JobService { } SignalExecutors.UNBOUNDED.execute(() -> { - boolean success = WebSocketStrategy.execute(); + boolean success = WebSocketDrainer.blockUntilDrainedAndProcessed(); if (success) { Log.i(TAG, "Successfully retrieved messages."); diff --git a/app/src/main/java/org/thoughtcrime/securesms/jobs/PushNotificationReceiveJob.java b/app/src/main/java/org/thoughtcrime/securesms/jobs/PushNotificationReceiveJob.java index 39735ba051..b7efcac3d6 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/jobs/PushNotificationReceiveJob.java +++ b/app/src/main/java/org/thoughtcrime/securesms/jobs/PushNotificationReceiveJob.java @@ -7,7 +7,7 @@ import org.signal.core.util.logging.Log; import org.thoughtcrime.securesms.R; import org.thoughtcrime.securesms.jobmanager.Job; import org.thoughtcrime.securesms.jobmanager.impl.NetworkConstraint; -import org.thoughtcrime.securesms.messages.WebSocketStrategy; +import org.thoughtcrime.securesms.messages.WebSocketDrainer; import org.thoughtcrime.securesms.service.DelayedNotificationController; import org.thoughtcrime.securesms.service.GenericForegroundService; import org.whispersystems.signalservice.api.push.exceptions.PushNetworkException; @@ -48,7 +48,7 @@ public final class PushNotificationReceiveJob extends BaseJob { boolean success; try (DelayedNotificationController unused = GenericForegroundService.startForegroundTaskDelayed(context, context.getString(R.string.BackgroundMessageRetriever_checking_for_messages), 300, R.drawable.ic_signal_refresh)) { - success = WebSocketStrategy.execute(); + success = WebSocketDrainer.blockUntilDrainedAndProcessed(); } if (success) { diff --git a/app/src/main/java/org/thoughtcrime/securesms/messages/WebSocketDrainer.kt b/app/src/main/java/org/thoughtcrime/securesms/messages/WebSocketDrainer.kt new file mode 100644 index 0000000000..231eda8c9d --- /dev/null +++ b/app/src/main/java/org/thoughtcrime/securesms/messages/WebSocketDrainer.kt @@ -0,0 +1,168 @@ +package org.thoughtcrime.securesms.messages + +import android.os.PowerManager +import androidx.annotation.AnyThread +import androidx.annotation.WorkerThread +import org.signal.core.util.Stopwatch +import org.signal.core.util.logging.Log +import org.thoughtcrime.securesms.dependencies.ApplicationDependencies +import org.thoughtcrime.securesms.jobmanager.Job +import org.thoughtcrime.securesms.jobmanager.JobTracker +import org.thoughtcrime.securesms.jobmanager.JobTracker.JobListener +import org.thoughtcrime.securesms.jobs.MarkerJob +import org.thoughtcrime.securesms.jobs.PushProcessMessageJob +import org.thoughtcrime.securesms.util.NetworkUtil +import org.thoughtcrime.securesms.util.PowerManagerCompat +import org.thoughtcrime.securesms.util.ServiceUtil +import org.thoughtcrime.securesms.util.WakeLockUtil +import java.util.concurrent.CountDownLatch +import java.util.concurrent.TimeUnit +import kotlin.time.Duration.Companion.minutes +import kotlin.time.Duration.Companion.seconds + +/** + * Forces the websocket to stay open until all messages have been drained and processed or until a user-specified timeout has been hit. + */ +object WebSocketDrainer { + private val TAG = Log.tag(WebSocketDrainer::class.java) + + private const val KEEP_ALIVE_TOKEN = "WebsocketStrategy" + private const val WAKELOCK_PREFIX = "websocket-strategy-" + + private val QUEUE_TIMEOUT = 30.seconds.inWholeMilliseconds + + @JvmStatic + @WorkerThread + fun blockUntilDrainedAndProcessed(): Boolean { + return blockUntilDrainedAndProcessed(1.minutes.inWholeMilliseconds) + } + + /** + * Blocks until the websocket is drained and all resulting processing jobs have finished, or until the [websocketDrainTimeoutMs] has been hit. + * Note: the timeout specified here is only for draining the websocket. There is currently a non-configurable timeout for waiting for the job queues. + */ + @WorkerThread + fun blockUntilDrainedAndProcessed(websocketDrainTimeoutMs: Long): Boolean { + val context = ApplicationDependencies.getApplication() + val incomingMessageObserver = ApplicationDependencies.getIncomingMessageObserver() + val powerManager = ServiceUtil.getPowerManager(context) + + val doze = PowerManagerCompat.isDeviceIdleMode(powerManager) + val network = NetworkUtil.isConnected(context) + + if (doze || !network) { + Log.w(TAG, "We may be operating in a constrained environment. Doze: $doze Network: $network") + } + + incomingMessageObserver.registerKeepAliveToken(KEEP_ALIVE_TOKEN) + + val wakeLockTag = WAKELOCK_PREFIX + System.currentTimeMillis() + val wakeLock = WakeLockUtil.acquire(ApplicationDependencies.getApplication(), PowerManager.PARTIAL_WAKE_LOCK, websocketDrainTimeoutMs + QUEUE_TIMEOUT, wakeLockTag) + + return try { + drainAndProcess(websocketDrainTimeoutMs, incomingMessageObserver) + } finally { + WakeLockUtil.release(wakeLock, wakeLockTag) + incomingMessageObserver.removeKeepAliveToken(KEEP_ALIVE_TOKEN) + } + } + + /** + * This drains the socket and listens for any processing jobs that were enqueued during that time. + * + * For every job queue that got a processing job, we'll add a [MarkerJob] and wait for it to finish + * so that we know the queue has been drained. + */ + @WorkerThread + private fun drainAndProcess(timeout: Long, incomingMessageObserver: IncomingMessageObserver): Boolean { + val stopwatch = Stopwatch("websocket-strategy") + + val jobManager = ApplicationDependencies.getJobManager() + val queueListener = QueueFindingJobListener() + + jobManager.addListener( + { job: Job -> job.parameters.queue?.startsWith(PushProcessMessageJob.QUEUE_PREFIX) ?: false }, + queueListener + ) + + val successfullyDrained = blockUntilWebsocketDrained(incomingMessageObserver, timeout) + if (!successfullyDrained) { + return false + } + + stopwatch.split("decryptions-drained") + + val processQueues: Set = queueListener.getQueues() + Log.d(TAG, "Discovered " + processQueues.size + " queue(s): " + processQueues) + + for (queue in processQueues) { + val queueDrained = blockUntilJobQueueDrained(queue, QUEUE_TIMEOUT) + if (!queueDrained) { + return false + } + } + + stopwatch.split("process-drained") + stopwatch.stop(TAG) + return true + } + + private fun blockUntilWebsocketDrained(incomingMessageObserver: IncomingMessageObserver, timeoutMs: Long): Boolean { + val latch = CountDownLatch(1) + incomingMessageObserver.addDecryptionDrainedListener(object : Runnable { + override fun run() { + latch.countDown() + incomingMessageObserver.removeDecryptionDrainedListener(this) + } + }) + + return try { + if (latch.await(timeoutMs, TimeUnit.MILLISECONDS)) { + true + } else { + Log.w(TAG, "Hit timeout while waiting for decryptions to drain!") + false + } + } catch (e: InterruptedException) { + Log.w(TAG, "Interrupted!", e) + false + } + } + + private fun blockUntilJobQueueDrained(queue: String, timeoutMs: Long): Boolean { + val startTime = System.currentTimeMillis() + val jobManager = ApplicationDependencies.getJobManager() + val markerJob = MarkerJob(queue) + val jobState = jobManager.runSynchronously(markerJob, timeoutMs) + + if (!jobState.isPresent) { + Log.w(TAG, "Timed out waiting for $queue job(s) to finish!") + return false + } + + val endTime = System.currentTimeMillis() + val duration = endTime - startTime + + Log.d(TAG, "Waited $duration ms for the $queue job(s) to finish.") + return true + } + + private class QueueFindingJobListener : JobListener { + private val queues: MutableSet = HashSet() + + @AnyThread + override fun onStateChanged(job: Job, jobState: JobTracker.JobState) { + synchronized(queues) { + job.parameters.queue?.let { queue -> + queues += queue + } + } + } + + fun getQueues(): Set { + synchronized(queues) { + return HashSet(queues) + } + } + } +} diff --git a/app/src/main/java/org/thoughtcrime/securesms/messages/WebSocketStrategy.java b/app/src/main/java/org/thoughtcrime/securesms/messages/WebSocketStrategy.java deleted file mode 100644 index 4b6ebd1668..0000000000 --- a/app/src/main/java/org/thoughtcrime/securesms/messages/WebSocketStrategy.java +++ /dev/null @@ -1,171 +0,0 @@ -package org.thoughtcrime.securesms.messages; - -import android.app.Application; -import android.content.Context; -import android.os.PowerManager; -import android.os.PowerManager.WakeLock; - -import androidx.annotation.AnyThread; -import androidx.annotation.NonNull; -import androidx.annotation.WorkerThread; - -import org.signal.core.util.Stopwatch; -import org.signal.core.util.ThreadUtil; -import org.signal.core.util.logging.Log; -import org.thoughtcrime.securesms.ApplicationContext; -import org.thoughtcrime.securesms.dependencies.ApplicationDependencies; -import org.thoughtcrime.securesms.jobmanager.Job; -import org.thoughtcrime.securesms.jobmanager.JobManager; -import org.thoughtcrime.securesms.jobmanager.JobTracker; -import org.thoughtcrime.securesms.jobmanager.impl.NetworkConstraint; -import org.thoughtcrime.securesms.jobs.MarkerJob; -import org.thoughtcrime.securesms.jobs.PushProcessMessageJob; -import org.thoughtcrime.securesms.keyvalue.SignalStore; -import org.thoughtcrime.securesms.util.NetworkUtil; -import org.thoughtcrime.securesms.util.PowerManagerCompat; -import org.thoughtcrime.securesms.util.ServiceUtil; -import org.thoughtcrime.securesms.util.WakeLockUtil; - -import java.util.HashSet; -import java.util.Optional; -import java.util.Set; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; - -/** - * Retrieves messages over the websocket. - */ -public class WebSocketStrategy { - - private static final String TAG = Log.tag(WebSocketStrategy.class); - - private static final String KEEP_ALIVE_TOKEN = "WebsocketStrategy"; - private static final long QUEUE_TIMEOUT = TimeUnit.SECONDS.toMillis(30); - private static final String WAKELOCK_PREFIX = "websocket-strategy-"; - - @WorkerThread - public static boolean execute() { - return execute(TimeUnit.MINUTES.toMillis(1)); - } - - @WorkerThread - public static boolean execute(long websocketDrainTimeoutMs) { - Application context = ApplicationDependencies.getApplication(); - IncomingMessageObserver observer = ApplicationDependencies.getIncomingMessageObserver(); - - PowerManager powerManager = ServiceUtil.getPowerManager(context); - boolean doze = PowerManagerCompat.isDeviceIdleMode(powerManager); - boolean network = NetworkUtil.isConnected(context); - - if (doze || !network) { - Log.w(TAG, "We may be operating in a constrained environment. Doze: " + doze + " Network: " + network); - } - - observer.registerKeepAliveToken(KEEP_ALIVE_TOKEN); - - String wakeLockTag = WAKELOCK_PREFIX + System.currentTimeMillis(); - WakeLock wakeLock = WakeLockUtil.acquire(ApplicationDependencies.getApplication(), PowerManager.PARTIAL_WAKE_LOCK, websocketDrainTimeoutMs + QUEUE_TIMEOUT, wakeLockTag); - - try { - return drainAndProcess(websocketDrainTimeoutMs); - } finally { - WakeLockUtil.release(wakeLock, wakeLockTag); - } - } - - @WorkerThread - private static boolean drainAndProcess(long timeout) { - Stopwatch stopwatch = new Stopwatch("websocket-strategy"); - IncomingMessageObserver observer = ApplicationDependencies.getIncomingMessageObserver(); - - observer.registerKeepAliveToken(KEEP_ALIVE_TOKEN); - try { - JobManager jobManager = ApplicationDependencies.getJobManager(); - QueueFindingJobListener queueListener = new QueueFindingJobListener(); - - jobManager.addListener(job -> job.getParameters().getQueue() != null && job.getParameters().getQueue().startsWith(PushProcessMessageJob.QUEUE_PREFIX), queueListener); - - if (!blockUntilWebsocketDrained(observer, timeout)) { - return false; - } - - stopwatch.split("decryptions-drained"); - - Set processQueues = queueListener.getQueues(); - Log.d(TAG, "Discovered " + processQueues.size() + " queue(s): " + processQueues); - - for (String queue : processQueues) { - if (!blockUntilJobQueueDrained(queue, QUEUE_TIMEOUT)) { - return false; - } - } - - stopwatch.split("process-drained"); - stopwatch.stop(TAG); - - return true; - } finally { - ApplicationDependencies.getIncomingMessageObserver().removeKeepAliveToken(KEEP_ALIVE_TOKEN); - } - } - - private static boolean blockUntilWebsocketDrained(IncomingMessageObserver observer, long timeoutMs) { - CountDownLatch latch = new CountDownLatch(1); - - observer.addDecryptionDrainedListener(new Runnable() { - @Override public void run() { - latch.countDown(); - observer.removeDecryptionDrainedListener(this); - } - }); - - try { - if (latch.await(timeoutMs, TimeUnit.MILLISECONDS)) { - return true; - } else { - Log.w(TAG, "Hit timeout while waiting for decryptions to drain!"); - return false; - } - } catch (InterruptedException e) { - Log.w(TAG, "Interrupted!", e); - return false; - } - } - - private static boolean blockUntilJobQueueDrained(@NonNull String queue, long timeoutMs) { - long startTime = System.currentTimeMillis(); - final JobManager jobManager = ApplicationDependencies.getJobManager(); - final MarkerJob markerJob = new MarkerJob(queue); - - Optional jobState = jobManager.runSynchronously(markerJob, timeoutMs); - - if (!jobState.isPresent()) { - Log.w(TAG, "Timed out waiting for " + queue + " job(s) to finish!"); - return false; - } - - long endTime = System.currentTimeMillis(); - long duration = endTime - startTime; - - Log.d(TAG, "Waited " + duration + " ms for the " + queue + " job(s) to finish."); - return true; - } - - protected static class QueueFindingJobListener implements JobTracker.JobListener { - private final Set queues = new HashSet<>(); - - @Override - @AnyThread - public void onStateChanged(@NonNull Job job, @NonNull JobTracker.JobState jobState) { - synchronized (queues) { - queues.add(job.getParameters().getQueue()); - } - } - - @NonNull Set getQueues() { - synchronized (queues) { - return new HashSet<>(queues); - } - } - } -}