Fix thread merges where one thread is inactive.

This commit is contained in:
Greyson Parrelli 2024-02-14 11:15:49 -05:00
parent 0cc7178cdc
commit ba41df19bb
2 changed files with 47 additions and 32 deletions

View file

@ -893,8 +893,8 @@ class RecipientTableTest_getAndPossiblyMerge {
// Thread validation // Thread validation
assertEquals(threadIdAci, retrievedThreadId) assertEquals(threadIdAci, retrievedThreadId)
Assert.assertNull(SignalDatabase.threads.getThreadIdFor(recipientIdE164)) assertNull(SignalDatabase.threads.getThreadIdFor(recipientIdE164))
Assert.assertNull(SignalDatabase.threads.getThreadRecord(threadIdE164)) assertNull(SignalDatabase.threads.getThreadRecord(threadIdE164))
// SMS validation // SMS validation
val sms1: MessageRecord = SignalDatabase.messages.getMessageRecord(smsId1)!! val sms1: MessageRecord = SignalDatabase.messages.getMessageRecord(smsId1)!!
@ -938,10 +938,10 @@ class RecipientTableTest_getAndPossiblyMerge {
// Identity validation // Identity validation
assertEquals(identityKeyAci, SignalDatabase.identities.getIdentityStoreRecord(ACI_A.toString())!!.identityKey) assertEquals(identityKeyAci, SignalDatabase.identities.getIdentityStoreRecord(ACI_A.toString())!!.identityKey)
Assert.assertNull(SignalDatabase.identities.getIdentityStoreRecord(E164_A)) assertNull(SignalDatabase.identities.getIdentityStoreRecord(E164_A))
// Session validation // Session validation
Assert.assertNotNull(SignalDatabase.sessions.load(ACI_SELF, SignalProtocolAddress(ACI_A.toString(), 1))) assertNotNull(SignalDatabase.sessions.load(ACI_SELF, SignalProtocolAddress(ACI_A.toString(), 1)))
// Reaction validation // Reaction validation
val reactionsSms: List<ReactionRecord> = SignalDatabase.reactions.getReactions(MessageId(smsId1)) val reactionsSms: List<ReactionRecord> = SignalDatabase.reactions.getReactions(MessageId(smsId1))

View file

@ -18,6 +18,7 @@ import org.signal.core.util.exists
import org.signal.core.util.logging.Log import org.signal.core.util.logging.Log
import org.signal.core.util.or import org.signal.core.util.or
import org.signal.core.util.readToList import org.signal.core.util.readToList
import org.signal.core.util.readToSingleLong
import org.signal.core.util.requireBoolean import org.signal.core.util.requireBoolean
import org.signal.core.util.requireInt import org.signal.core.util.requireInt
import org.signal.core.util.requireLong import org.signal.core.util.requireLong
@ -1585,64 +1586,69 @@ class ThreadTable(context: Context, databaseHelper: SignalDatabase) : DatabaseTa
check(databaseHelper.signalWritableDatabase.inTransaction()) { "Must be in a transaction!" } check(databaseHelper.signalWritableDatabase.inTransaction()) { "Must be in a transaction!" }
Log.w(TAG, "Merging threads. Primary: $primaryRecipientId, Secondary: $secondaryRecipientId", true) Log.w(TAG, "Merging threads. Primary: $primaryRecipientId, Secondary: $secondaryRecipientId", true)
val primary: ThreadRecord? = getThreadRecord(getThreadIdFor(primaryRecipientId)) val primaryThreadId: Long? = getThreadIdFor(primaryRecipientId)
val secondary: ThreadRecord? = getThreadRecord(getThreadIdFor(secondaryRecipientId)) val secondaryThreadId: Long? = getThreadIdFor(secondaryRecipientId)
return if (primary != null && secondary == null) { return if (primaryThreadId != null && secondaryThreadId == null) {
Log.w(TAG, "[merge] Only had a thread for primary. Returning that.", true) Log.w(TAG, "[merge] Only had a thread for primary. Returning that.", true)
MergeResult(threadId = primary.threadId, previousThreadId = -1, neededMerge = false) MergeResult(threadId = primaryThreadId, previousThreadId = -1, neededMerge = false)
} else if (primary == null && secondary != null) { } else if (primaryThreadId == null && secondaryThreadId != null) {
Log.w(TAG, "[merge] Only had a thread for secondary. Updating it to have the recipientId of the primary.", true) Log.w(TAG, "[merge] Only had a thread for secondary. Updating it to have the recipientId of the primary.", true)
writableDatabase writableDatabase
.update(TABLE_NAME) .update(TABLE_NAME)
.values(RECIPIENT_ID to primaryRecipientId.serialize()) .values(RECIPIENT_ID to primaryRecipientId.serialize())
.where("$ID = ?", secondary.threadId) .where("$ID = ?", secondaryThreadId)
.run() .run()
synchronized(threadIdCache) { synchronized(threadIdCache) {
threadIdCache.remove(secondaryRecipientId) threadIdCache.remove(secondaryRecipientId)
} }
MergeResult(threadId = secondary.threadId, previousThreadId = -1, neededMerge = false) MergeResult(threadId = secondaryThreadId, previousThreadId = -1, neededMerge = false)
} else if (primary == null && secondary == null) { } else if (primaryThreadId == null && secondaryThreadId == null) {
Log.w(TAG, "[merge] No thread for either.") Log.w(TAG, "[merge] No thread for either.")
MergeResult(threadId = -1, previousThreadId = -1, neededMerge = false) MergeResult(threadId = -1, previousThreadId = -1, neededMerge = false)
} else { } else {
Log.w(TAG, "[merge] Had a thread for both. Deleting the secondary and merging the attributes together.", true) Log.w(TAG, "[merge] Had a thread for both. Deleting the secondary and merging the attributes together.", true)
check(primary != null) check(primaryThreadId != null)
check(secondary != null) check(secondaryThreadId != null)
for (table in threadIdDatabaseTables) { for (table in threadIdDatabaseTables) {
table.remapThread(secondary.threadId, primary.threadId) table.remapThread(secondaryThreadId, primaryThreadId)
} }
writableDatabase writableDatabase
.delete(TABLE_NAME) .delete(TABLE_NAME)
.where("$ID = ?", secondary.threadId) .where("$ID = ?", secondaryThreadId)
.run() .run()
synchronized(threadIdCache) { synchronized(threadIdCache) {
threadIdCache.remove(secondaryRecipientId) threadIdCache.remove(secondaryRecipientId)
} }
if (primary.expiresIn != secondary.expiresIn) { val primaryExpiresIn = getExpiresIn(primaryThreadId)
val secondaryExpiresIn = getExpiresIn(secondaryThreadId)
val values = ContentValues() val values = ContentValues()
if (primary.expiresIn == 0L) { values.put(ACTIVE, true)
values.put(EXPIRES_IN, secondary.expiresIn)
} else if (secondary.expiresIn == 0L) { if (primaryExpiresIn != secondaryExpiresIn) {
values.put(EXPIRES_IN, primary.expiresIn) if (primaryExpiresIn == 0L) {
values.put(EXPIRES_IN, secondaryExpiresIn)
} else if (secondaryExpiresIn == 0L) {
values.put(EXPIRES_IN, primaryExpiresIn)
} else { } else {
values.put(EXPIRES_IN, min(primary.expiresIn, secondary.expiresIn)) values.put(EXPIRES_IN, min(primaryExpiresIn, secondaryExpiresIn))
}
} }
writableDatabase writableDatabase
.update(TABLE_NAME) .update(TABLE_NAME)
.values(values) .values(values)
.where("$ID = ?", primary.threadId) .where("$ID = ?", primaryThreadId)
.run() .run()
}
RemappedRecords.getInstance().addThread(secondary.threadId, primary.threadId) RemappedRecords.getInstance().addThread(secondaryThreadId, primaryThreadId)
MergeResult(threadId = primary.threadId, previousThreadId = secondary.threadId, neededMerge = true) MergeResult(threadId = primaryThreadId, previousThreadId = secondaryThreadId, neededMerge = true)
} }
} }
@ -1662,6 +1668,15 @@ class ThreadTable(context: Context, databaseHelper: SignalDatabase) : DatabaseTa
} }
} }
private fun getExpiresIn(threadId: Long): Long {
return readableDatabase
.select(EXPIRES_IN)
.from(TABLE_NAME)
.where("$ID = $threadId")
.run()
.readToSingleLong()
}
private fun SQLiteDatabase.deactivateThreads() { private fun SQLiteDatabase.deactivateThreads() {
deactivateThread(query = null) deactivateThread(query = null)
} }