Fix thread merges where one thread is inactive.

This commit is contained in:
Greyson Parrelli 2024-02-14 11:15:49 -05:00
parent 0cc7178cdc
commit ba41df19bb
2 changed files with 47 additions and 32 deletions

View file

@ -893,8 +893,8 @@ class RecipientTableTest_getAndPossiblyMerge {
// Thread validation
assertEquals(threadIdAci, retrievedThreadId)
Assert.assertNull(SignalDatabase.threads.getThreadIdFor(recipientIdE164))
Assert.assertNull(SignalDatabase.threads.getThreadRecord(threadIdE164))
assertNull(SignalDatabase.threads.getThreadIdFor(recipientIdE164))
assertNull(SignalDatabase.threads.getThreadRecord(threadIdE164))
// SMS validation
val sms1: MessageRecord = SignalDatabase.messages.getMessageRecord(smsId1)!!
@ -938,10 +938,10 @@ class RecipientTableTest_getAndPossiblyMerge {
// Identity validation
assertEquals(identityKeyAci, SignalDatabase.identities.getIdentityStoreRecord(ACI_A.toString())!!.identityKey)
Assert.assertNull(SignalDatabase.identities.getIdentityStoreRecord(E164_A))
assertNull(SignalDatabase.identities.getIdentityStoreRecord(E164_A))
// Session validation
Assert.assertNotNull(SignalDatabase.sessions.load(ACI_SELF, SignalProtocolAddress(ACI_A.toString(), 1)))
assertNotNull(SignalDatabase.sessions.load(ACI_SELF, SignalProtocolAddress(ACI_A.toString(), 1)))
// Reaction validation
val reactionsSms: List<ReactionRecord> = SignalDatabase.reactions.getReactions(MessageId(smsId1))

View file

@ -18,6 +18,7 @@ import org.signal.core.util.exists
import org.signal.core.util.logging.Log
import org.signal.core.util.or
import org.signal.core.util.readToList
import org.signal.core.util.readToSingleLong
import org.signal.core.util.requireBoolean
import org.signal.core.util.requireInt
import org.signal.core.util.requireLong
@ -1585,64 +1586,69 @@ class ThreadTable(context: Context, databaseHelper: SignalDatabase) : DatabaseTa
check(databaseHelper.signalWritableDatabase.inTransaction()) { "Must be in a transaction!" }
Log.w(TAG, "Merging threads. Primary: $primaryRecipientId, Secondary: $secondaryRecipientId", true)
val primary: ThreadRecord? = getThreadRecord(getThreadIdFor(primaryRecipientId))
val secondary: ThreadRecord? = getThreadRecord(getThreadIdFor(secondaryRecipientId))
val primaryThreadId: Long? = getThreadIdFor(primaryRecipientId)
val secondaryThreadId: Long? = getThreadIdFor(secondaryRecipientId)
return if (primary != null && secondary == null) {
return if (primaryThreadId != null && secondaryThreadId == null) {
Log.w(TAG, "[merge] Only had a thread for primary. Returning that.", true)
MergeResult(threadId = primary.threadId, previousThreadId = -1, neededMerge = false)
} else if (primary == null && secondary != null) {
MergeResult(threadId = primaryThreadId, previousThreadId = -1, neededMerge = false)
} else if (primaryThreadId == null && secondaryThreadId != null) {
Log.w(TAG, "[merge] Only had a thread for secondary. Updating it to have the recipientId of the primary.", true)
writableDatabase
.update(TABLE_NAME)
.values(RECIPIENT_ID to primaryRecipientId.serialize())
.where("$ID = ?", secondary.threadId)
.where("$ID = ?", secondaryThreadId)
.run()
synchronized(threadIdCache) {
threadIdCache.remove(secondaryRecipientId)
}
MergeResult(threadId = secondary.threadId, previousThreadId = -1, neededMerge = false)
} else if (primary == null && secondary == null) {
MergeResult(threadId = secondaryThreadId, previousThreadId = -1, neededMerge = false)
} else if (primaryThreadId == null && secondaryThreadId == null) {
Log.w(TAG, "[merge] No thread for either.")
MergeResult(threadId = -1, previousThreadId = -1, neededMerge = false)
} else {
Log.w(TAG, "[merge] Had a thread for both. Deleting the secondary and merging the attributes together.", true)
check(primary != null)
check(secondary != null)
check(primaryThreadId != null)
check(secondaryThreadId != null)
for (table in threadIdDatabaseTables) {
table.remapThread(secondary.threadId, primary.threadId)
table.remapThread(secondaryThreadId, primaryThreadId)
}
writableDatabase
.delete(TABLE_NAME)
.where("$ID = ?", secondary.threadId)
.where("$ID = ?", secondaryThreadId)
.run()
synchronized(threadIdCache) {
threadIdCache.remove(secondaryRecipientId)
}
if (primary.expiresIn != secondary.expiresIn) {
val primaryExpiresIn = getExpiresIn(primaryThreadId)
val secondaryExpiresIn = getExpiresIn(secondaryThreadId)
val values = ContentValues()
if (primary.expiresIn == 0L) {
values.put(EXPIRES_IN, secondary.expiresIn)
} else if (secondary.expiresIn == 0L) {
values.put(EXPIRES_IN, primary.expiresIn)
values.put(ACTIVE, true)
if (primaryExpiresIn != secondaryExpiresIn) {
if (primaryExpiresIn == 0L) {
values.put(EXPIRES_IN, secondaryExpiresIn)
} else if (secondaryExpiresIn == 0L) {
values.put(EXPIRES_IN, primaryExpiresIn)
} else {
values.put(EXPIRES_IN, min(primary.expiresIn, secondary.expiresIn))
values.put(EXPIRES_IN, min(primaryExpiresIn, secondaryExpiresIn))
}
}
writableDatabase
.update(TABLE_NAME)
.values(values)
.where("$ID = ?", primary.threadId)
.where("$ID = ?", primaryThreadId)
.run()
}
RemappedRecords.getInstance().addThread(secondary.threadId, primary.threadId)
RemappedRecords.getInstance().addThread(secondaryThreadId, primaryThreadId)
MergeResult(threadId = primary.threadId, previousThreadId = secondary.threadId, neededMerge = true)
MergeResult(threadId = primaryThreadId, previousThreadId = secondaryThreadId, neededMerge = true)
}
}
@ -1662,6 +1668,15 @@ class ThreadTable(context: Context, databaseHelper: SignalDatabase) : DatabaseTa
}
}
private fun getExpiresIn(threadId: Long): Long {
return readableDatabase
.select(EXPIRES_IN)
.from(TABLE_NAME)
.where("$ID = $threadId")
.run()
.readToSingleLong()
}
private fun SQLiteDatabase.deactivateThreads() {
deactivateThread(query = null)
}