diff --git a/app/src/androidTest/java/org/thoughtcrime/securesms/database/RecipientTableTest_getAndPossiblyMerge.kt b/app/src/androidTest/java/org/thoughtcrime/securesms/database/RecipientTableTest_getAndPossiblyMerge.kt index f921387642..9bf07f21e5 100644 --- a/app/src/androidTest/java/org/thoughtcrime/securesms/database/RecipientTableTest_getAndPossiblyMerge.kt +++ b/app/src/androidTest/java/org/thoughtcrime/securesms/database/RecipientTableTest_getAndPossiblyMerge.kt @@ -893,8 +893,8 @@ class RecipientTableTest_getAndPossiblyMerge { // Thread validation assertEquals(threadIdAci, retrievedThreadId) - Assert.assertNull(SignalDatabase.threads.getThreadIdFor(recipientIdE164)) - Assert.assertNull(SignalDatabase.threads.getThreadRecord(threadIdE164)) + assertNull(SignalDatabase.threads.getThreadIdFor(recipientIdE164)) + assertNull(SignalDatabase.threads.getThreadRecord(threadIdE164)) // SMS validation val sms1: MessageRecord = SignalDatabase.messages.getMessageRecord(smsId1)!! @@ -938,10 +938,10 @@ class RecipientTableTest_getAndPossiblyMerge { // Identity validation assertEquals(identityKeyAci, SignalDatabase.identities.getIdentityStoreRecord(ACI_A.toString())!!.identityKey) - Assert.assertNull(SignalDatabase.identities.getIdentityStoreRecord(E164_A)) + assertNull(SignalDatabase.identities.getIdentityStoreRecord(E164_A)) // Session validation - Assert.assertNotNull(SignalDatabase.sessions.load(ACI_SELF, SignalProtocolAddress(ACI_A.toString(), 1))) + assertNotNull(SignalDatabase.sessions.load(ACI_SELF, SignalProtocolAddress(ACI_A.toString(), 1))) // Reaction validation val reactionsSms: List = SignalDatabase.reactions.getReactions(MessageId(smsId1)) diff --git a/app/src/main/java/org/thoughtcrime/securesms/database/ThreadTable.kt b/app/src/main/java/org/thoughtcrime/securesms/database/ThreadTable.kt index 3c32157f76..95c03d96d8 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/database/ThreadTable.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/database/ThreadTable.kt @@ -18,6 +18,7 @@ import org.signal.core.util.exists import org.signal.core.util.logging.Log import org.signal.core.util.or import org.signal.core.util.readToList +import org.signal.core.util.readToSingleLong import org.signal.core.util.requireBoolean import org.signal.core.util.requireInt import org.signal.core.util.requireLong @@ -1585,64 +1586,69 @@ class ThreadTable(context: Context, databaseHelper: SignalDatabase) : DatabaseTa check(databaseHelper.signalWritableDatabase.inTransaction()) { "Must be in a transaction!" } Log.w(TAG, "Merging threads. Primary: $primaryRecipientId, Secondary: $secondaryRecipientId", true) - val primary: ThreadRecord? = getThreadRecord(getThreadIdFor(primaryRecipientId)) - val secondary: ThreadRecord? = getThreadRecord(getThreadIdFor(secondaryRecipientId)) + val primaryThreadId: Long? = getThreadIdFor(primaryRecipientId) + val secondaryThreadId: Long? = getThreadIdFor(secondaryRecipientId) - return if (primary != null && secondary == null) { + return if (primaryThreadId != null && secondaryThreadId == null) { Log.w(TAG, "[merge] Only had a thread for primary. Returning that.", true) - MergeResult(threadId = primary.threadId, previousThreadId = -1, neededMerge = false) - } else if (primary == null && secondary != null) { + MergeResult(threadId = primaryThreadId, previousThreadId = -1, neededMerge = false) + } else if (primaryThreadId == null && secondaryThreadId != null) { Log.w(TAG, "[merge] Only had a thread for secondary. Updating it to have the recipientId of the primary.", true) writableDatabase .update(TABLE_NAME) .values(RECIPIENT_ID to primaryRecipientId.serialize()) - .where("$ID = ?", secondary.threadId) + .where("$ID = ?", secondaryThreadId) .run() synchronized(threadIdCache) { threadIdCache.remove(secondaryRecipientId) } - MergeResult(threadId = secondary.threadId, previousThreadId = -1, neededMerge = false) - } else if (primary == null && secondary == null) { + MergeResult(threadId = secondaryThreadId, previousThreadId = -1, neededMerge = false) + } else if (primaryThreadId == null && secondaryThreadId == null) { Log.w(TAG, "[merge] No thread for either.") MergeResult(threadId = -1, previousThreadId = -1, neededMerge = false) } else { Log.w(TAG, "[merge] Had a thread for both. Deleting the secondary and merging the attributes together.", true) - check(primary != null) - check(secondary != null) + check(primaryThreadId != null) + check(secondaryThreadId != null) for (table in threadIdDatabaseTables) { - table.remapThread(secondary.threadId, primary.threadId) + table.remapThread(secondaryThreadId, primaryThreadId) } writableDatabase .delete(TABLE_NAME) - .where("$ID = ?", secondary.threadId) + .where("$ID = ?", secondaryThreadId) .run() synchronized(threadIdCache) { threadIdCache.remove(secondaryRecipientId) } - if (primary.expiresIn != secondary.expiresIn) { - val values = ContentValues() - if (primary.expiresIn == 0L) { - values.put(EXPIRES_IN, secondary.expiresIn) - } else if (secondary.expiresIn == 0L) { - values.put(EXPIRES_IN, primary.expiresIn) - } else { - values.put(EXPIRES_IN, min(primary.expiresIn, secondary.expiresIn)) - } + val primaryExpiresIn = getExpiresIn(primaryThreadId) + val secondaryExpiresIn = getExpiresIn(secondaryThreadId) - writableDatabase - .update(TABLE_NAME) - .values(values) - .where("$ID = ?", primary.threadId) - .run() + val values = ContentValues() + values.put(ACTIVE, true) + + if (primaryExpiresIn != secondaryExpiresIn) { + if (primaryExpiresIn == 0L) { + values.put(EXPIRES_IN, secondaryExpiresIn) + } else if (secondaryExpiresIn == 0L) { + values.put(EXPIRES_IN, primaryExpiresIn) + } else { + values.put(EXPIRES_IN, min(primaryExpiresIn, secondaryExpiresIn)) + } } - RemappedRecords.getInstance().addThread(secondary.threadId, primary.threadId) + writableDatabase + .update(TABLE_NAME) + .values(values) + .where("$ID = ?", primaryThreadId) + .run() - MergeResult(threadId = primary.threadId, previousThreadId = secondary.threadId, neededMerge = true) + RemappedRecords.getInstance().addThread(secondaryThreadId, primaryThreadId) + + MergeResult(threadId = primaryThreadId, previousThreadId = secondaryThreadId, neededMerge = true) } } @@ -1662,6 +1668,15 @@ class ThreadTable(context: Context, databaseHelper: SignalDatabase) : DatabaseTa } } + private fun getExpiresIn(threadId: Long): Long { + return readableDatabase + .select(EXPIRES_IN) + .from(TABLE_NAME) + .where("$ID = $threadId") + .run() + .readToSingleLong() + } + private fun SQLiteDatabase.deactivateThreads() { deactivateThread(query = null) }