Use db as SSOT for unread counter.

This commit is contained in:
Alex Hart 2022-07-27 10:35:13 -03:00 committed by Cody Henthorne
parent a7a5f2e8c6
commit e3e9f90094
9 changed files with 52 additions and 215 deletions

View file

@ -986,7 +986,7 @@ public class ConversationFragment extends LoggingFragment implements Multiselect
list.addItemDecoration(lastSeenDecoration, 0);
if (lastSeen > 0) {
lastSeenDisposable.add(conversationViewModel.getThreadUnreadCount()
lastSeenDisposable.add(conversationViewModel.getThreadUnreadCount(lastSeen)
.distinctUntilChanged()
.observeOn(AndroidSchedulers.mainThread())
.subscribe(unreadCount -> {

View file

@ -9,14 +9,12 @@ import androidx.annotation.WorkerThread;
import org.signal.core.util.concurrent.SignalExecutors;
import org.signal.core.util.logging.Log;
import org.thoughtcrime.securesms.contacts.sync.ContactDiscovery;
import org.thoughtcrime.securesms.database.Database;
import org.thoughtcrime.securesms.database.DatabaseObserver;
import org.thoughtcrime.securesms.database.GroupDatabase;
import org.thoughtcrime.securesms.database.MessageDatabase;
import org.thoughtcrime.securesms.database.RecipientDatabase;
import org.thoughtcrime.securesms.database.SignalDatabase;
import org.thoughtcrime.securesms.database.ThreadDatabase;
import org.thoughtcrime.securesms.database.model.ThreadRecord;
import org.thoughtcrime.securesms.dependencies.ApplicationDependencies;
import org.thoughtcrime.securesms.jobs.MultiDeviceViewedUpdateJob;
import org.thoughtcrime.securesms.keyvalue.SignalStore;
@ -34,7 +32,6 @@ import java.util.Optional;
import java.util.stream.Collectors;
import io.reactivex.rxjava3.core.Observable;
import io.reactivex.rxjava3.core.Scheduler;
import io.reactivex.rxjava3.core.Single;
import io.reactivex.rxjava3.schedulers.Schedulers;
@ -180,16 +177,14 @@ class ConversationRepository {
}).subscribeOn(Schedulers.io());
}
Observable<Optional<ThreadRecord>> getThreadRecord(long threadId) {
if (threadId == -1L) {
return Observable.just(Optional.empty());
Observable<Integer> getUnreadCount(long threadId, long afterTime) {
if (threadId <= -1L || afterTime <= 0L) {
return Observable.just(0);
}
return Observable.<Optional<ThreadRecord>> create(emitter -> {
return Observable.<Integer> create(emitter -> {
DatabaseObserver.Observer listener = () -> {
emitter.onNext(Optional.ofNullable(SignalDatabase.threads().getThreadRecord(threadId)));
};
DatabaseObserver.Observer listener = () -> emitter.onNext(SignalDatabase.mmsSms().getIncomingMeaningfulMessageCountSince(threadId, afterTime));
ApplicationDependencies.getDatabaseObserver().registerConversationObserver(threadId, listener);
emitter.setCancellable(() -> ApplicationDependencies.getDatabaseObserver().unregisterObserver(listener));

View file

@ -95,9 +95,8 @@ public class ConversationViewModel extends ViewModel {
private final GroupAuthorNameColorHelper groupAuthorNameColorHelper;
private final RxStore<ConversationState> conversationStateStore;
private final CompositeDisposable disposables;
private final BehaviorSubject<Unit> conversationStateTick;
private final RxStore<ThreadCountAggregator> threadCountStore;
private final PublishProcessor<Long> markReadRequestPublisher;
private final BehaviorSubject<Unit> conversationStateTick;
private final PublishProcessor<Long> markReadRequestPublisher;
private ConversationIntents.Args args;
private int jumpToPosition;
@ -126,7 +125,6 @@ public class ConversationViewModel extends ViewModel {
this.conversationStateStore = new RxStore<>(ConversationState.create(), Schedulers.io());
this.disposables = new CompositeDisposable();
this.conversationStateTick = BehaviorSubject.createDefault(Unit.INSTANCE);
this.threadCountStore = new RxStore<>(ThreadCountAggregator.Init.INSTANCE, Schedulers.computation());
this.markReadRequestPublisher = PublishProcessor.create();
BehaviorSubject<Recipient> recipientCache = BehaviorSubject.create();
@ -137,11 +135,6 @@ public class ConversationViewModel extends ViewModel {
.map(Recipient::resolved)
.subscribe(recipientCache);
disposables.add(threadCountStore.update(
threadId.switchMap(conversationRepository::getThreadRecord).toFlowable(BackpressureStrategy.BUFFER),
(record, count) -> record.map(count::updateWith).orElse(count)
));
conversationStateStore.update(Observable.combineLatest(recipientId, conversationStateTick, (id, tick) -> id)
.distinctUntilChanged()
.switchMap(conversationRepository::getSecurityInfo)
@ -307,13 +300,11 @@ public class ConversationViewModel extends ViewModel {
}
@NonNull Flowable<Long> getMarkReadRequests() {
Flowable<ThreadCountAggregator> nonInitialThreadCount = threadCountStore.getStateFlowable().filter(count -> !(count instanceof ThreadCountAggregator.Init)).take(1);
return Flowable.combineLatest(markReadRequestPublisher.onBackpressureBuffer(), nonInitialThreadCount, (time, count) -> time);
return markReadRequestPublisher.onBackpressureBuffer();
}
@NonNull Flowable<Integer> getThreadUnreadCount() {
return threadCountStore.getStateFlowable().map(ThreadCountAggregator::getCount);
@NonNull Observable<Integer> getThreadUnreadCount(long afterTime) {
return threadId.switchMap(id -> conversationRepository.getUnreadCount(id, afterTime));
}
@NonNull Flowable<ConversationState> getConversationState() {

View file

@ -1,70 +0,0 @@
package org.thoughtcrime.securesms.conversation
import org.thoughtcrime.securesms.database.model.ThreadRecord
/**
* Describes and aggregates the thread count for a particular thread, for use in the "Last Seen" header.
*/
sealed class ThreadCountAggregator {
abstract val count: Int
abstract fun updateWith(record: ThreadRecord): ThreadCountAggregator
/**
* The Init object, used as an initial state and returned whenever the given record is an outgoing message.
* The conversation fragment already properly cleans up the header when an outgoing message is emitted, so
* there's no need to worry about seeing a "zero."
*/
object Init : ThreadCountAggregator() {
override val count: Int = 0
override fun updateWith(record: ThreadRecord): ThreadCountAggregator {
return when {
record.isOutgoing -> Outgoing
else -> Count(record.threadId, record.unreadCount, record.date)
}
}
}
/**
* The Outgoing object, returned whenever the given record is an outgoing message.
* The conversation fragment already properly cleans up the header when an outgoing message is emitted, so
* there's no need to worry about seeing a "zero."
*/
object Outgoing : ThreadCountAggregator() {
override val count: Int = 0
override fun updateWith(record: ThreadRecord): ThreadCountAggregator {
return when {
record.isOutgoing -> Outgoing
else -> Count(record.threadId, record.unreadCount, record.date)
}
}
}
/**
* Represents an actual count. We keep record of the id and date to use in comparisons with future
* ThreadRecord objects.
*/
class Count(val threadId: Long, val unreadCount: Int, val threadDate: Long) : ThreadCountAggregator() {
override val count: Int = unreadCount
/**
* "Ratchets" the count to the new state.
* * Outgoing records will always result in Empty
* * Mismatched threadIds will always create a new Count, initialized with the new thread
* * Matching dates will be ignored, as this means that there was no actual change.
* * Otherwise, we'll proceed with the new date and aggregate the count.
*/
override fun updateWith(record: ThreadRecord): ThreadCountAggregator {
return when {
record.isOutgoing -> Outgoing
threadId != record.threadId -> Init.updateWith(record)
threadDate >= record.date -> this
record.unreadCount > 1 -> Init.updateWith(record)
else -> Count(threadId, unreadCount + 1, record.date)
}
}
}
}

View file

@ -78,6 +78,7 @@ public abstract class MessageDatabase extends Database implements MmsSmsColumns
public abstract int getMessageCountForThread(long threadId);
public abstract int getMessageCountForThread(long threadId, long beforeTime);
public abstract boolean hasMeaningfulMessage(long threadId);
public abstract int getIncomingMeaningfulMessageCountSince(long threadId, long afterTime);
public abstract Optional<MmsNotificationInfo> getNotification(long messageId);
public abstract Cursor getExpirationStartedMessages();

View file

@ -1068,6 +1068,22 @@ public class MmsDatabase extends MessageDatabase {
}
}
@Override
public int getIncomingMeaningfulMessageCountSince(long threadId, long afterTime) {
SQLiteDatabase db = databaseHelper.getSignalReadableDatabase();
String[] projection = SqlUtil.COUNT;
String where = THREAD_ID + " = ? AND " + STORY_TYPE + " = ? AND " + PARENT_STORY_ID + " <= ? AND " + DATE_RECEIVED + " >= ?";
String[] whereArgs = SqlUtil.buildArgs(threadId, 0, 0, afterTime);
try (Cursor cursor = db.query(TABLE_NAME, projection, where, whereArgs, null, null, null, "1")) {
if (cursor != null && cursor.moveToFirst()) {
return cursor.getInt(0);
} else {
return 0;
}
}
}
@Override
public void addFailures(long messageId, List<NetworkFailure> failure) {
try {

View file

@ -398,6 +398,13 @@ public class MmsSmsDatabase extends Database {
return count;
}
public int getIncomingMeaningfulMessageCountSince(long threadId, long afterTime) {
int count = SignalDatabase.sms().getIncomingMeaningfulMessageCountSince(threadId, afterTime);
count += SignalDatabase.mms().getIncomingMeaningfulMessageCountSince(threadId, afterTime);
return count;
}
public int getMessageCountBeforeDate(long date) {
String selection = MmsSmsColumns.NORMALIZED_DATE_RECEIVED + " < " + date;

View file

@ -276,6 +276,23 @@ public class SmsDatabase extends MessageDatabase {
}
}
@Override
public int getIncomingMeaningfulMessageCountSince(long threadId, long afterTime) {
SQLiteDatabase db = databaseHelper.getSignalReadableDatabase();
String[] projection = SqlUtil.COUNT;
SqlUtil.Query meaningfulMessagesQuery = buildMeaningfulMessagesQuery(threadId);
String where = meaningfulMessagesQuery.getWhere() + " AND " + DATE_RECEIVED + " >= ?";
String[] whereArgs = SqlUtil.appendArg(meaningfulMessagesQuery.getWhereArgs(), String.valueOf(afterTime));
try (Cursor cursor = db.query(TABLE_NAME, projection, where, whereArgs, null, null, null, "1")) {
if (cursor != null && cursor.moveToFirst()) {
return cursor.getInt(0);
} else {
return 0;
}
}
}
private @NonNull SqlUtil.Query buildMeaningfulMessagesQuery(long threadId) {
String query = THREAD_ID + " = ? AND (NOT " + TYPE + " & ? AND " + TYPE + " != ? AND " + TYPE + " != ? AND " + TYPE + " != ? AND " + TYPE + " & " + GROUP_V2_LEAVE_BITS + " != " + GROUP_V2_LEAVE_BITS + ")";
return SqlUtil.buildQuery(query, threadId, IGNORABLE_TYPESMASK_WHEN_COUNTING, Types.PROFILE_CHANGE_TYPE, Types.CHANGE_NUMBER_TYPE, Types.BOOST_REQUEST_TYPE);

View file

@ -1,120 +0,0 @@
package org.thoughtcrime.securesms.conversation
import org.junit.Assert.assertEquals
import org.junit.Test
import org.thoughtcrime.securesms.database.MmsSmsColumns
import org.thoughtcrime.securesms.database.model.ThreadRecord
class ThreadCountTest {
@Test
fun `Given an Init, when I getCount, then I expect 0`() {
// GIVEN
val threadCount = ThreadCountAggregator.Init
// WHEN
val result = threadCount.count
// THEN
assertEquals(0, result)
}
@Test
fun `Given an Empty, when I updateWith an outgoing record, then I expect Outgoing`() {
// GIVEN
val threadRecord = createThreadRecord(isOutgoing = true)
// WHEN
val result = ThreadCountAggregator.Init.updateWith(threadRecord)
// THEN
assertEquals(result, ThreadCountAggregator.Outgoing)
}
@Test
fun `Given an Empty, when I updateWith an incoming record, then I expect 5`() {
// GIVEN
val threadRecord = createThreadRecord(unreadCount = 5)
// WHEN
val result = ThreadCountAggregator.Init.updateWith(threadRecord)
// THEN
assertEquals(5, result.count)
}
@Test
fun `Given a Count, when I updateWith an incoming record with the same date, then I expect 5`() {
// GIVEN
val threadRecord = createThreadRecord(unreadCount = 5)
val newThreadRecord = createThreadRecord(unreadCount = 1)
// WHEN
val result = ThreadCountAggregator.Init.updateWith(threadRecord).updateWith(newThreadRecord)
// THEN
assertEquals(5, result.count)
}
@Test
fun `Given a Count, when I updateWith an incoming record with an earlier date, then I expect 5`() {
// GIVEN
val threadRecord = createThreadRecord(unreadCount = 5)
val newThreadRecord = createThreadRecord(unreadCount = 1, date = 0L)
// WHEN
val result = ThreadCountAggregator.Init.updateWith(threadRecord).updateWith(newThreadRecord)
// THEN
assertEquals(5, result.count)
}
@Test
fun `Given a Count, when I updateWith an incoming record with a later date, then I expect 6`() {
// GIVEN
val threadRecord = createThreadRecord(unreadCount = 5)
val newThreadRecord = createThreadRecord(unreadCount = 1, date = 2L)
// WHEN
val result = ThreadCountAggregator.Init.updateWith(threadRecord).updateWith(newThreadRecord)
// THEN
assertEquals(6, result.count)
}
@Test
fun `Given a Count, when I updateWith an incoming record with a later date and unread count gt 1, then I expect new unread count`() {
// GIVEN
val threadRecord = createThreadRecord(unreadCount = 5)
val newThreadRecord = createThreadRecord(unreadCount = 3, date = 2L)
// WHEN
val result = ThreadCountAggregator.Init.updateWith(threadRecord).updateWith(newThreadRecord)
// THEN
assertEquals(3, result.count)
}
@Test
fun `Given a Count, when I updateWith an incoming record with a different id, then I expect 3`() {
// GIVEN
val threadRecord = createThreadRecord(threadId = 1L, unreadCount = 5)
val newThreadRecord = createThreadRecord(threadId = 2L, unreadCount = 3)
// WHEN
val result = ThreadCountAggregator.Init.updateWith(threadRecord).updateWith(newThreadRecord)
// THEN
assertEquals(3, result.count)
}
private fun createThreadRecord(threadId: Long = 1L, unreadCount: Int = 0, date: Long = 1L, isOutgoing: Boolean = false): ThreadRecord {
val outgoingMessageType = MmsSmsColumns.Types.getOutgoingEncryptedMessageType()
return ThreadRecord.Builder(threadId)
.setUnreadCount(unreadCount)
.setDate(date)
.setType(if (isOutgoing) outgoingMessageType else (outgoingMessageType.inv()))
.build()
}
}