Improve safety of update and delete database methods.

This commit is contained in:
Greyson Parrelli 2024-01-12 15:20:05 -05:00
parent e361795184
commit 750fd4efe1
22 changed files with 124 additions and 101 deletions

View file

@ -7,7 +7,7 @@ import org.junit.Assert.assertTrue
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.signal.core.util.delete
import org.signal.core.util.deleteAll
import org.signal.core.util.readToList
import org.signal.core.util.requireLong
import org.signal.core.util.withinTransaction
@ -33,8 +33,8 @@ class GroupTableTest {
fun setUp() {
groupTable = SignalDatabase.groups
groupTable.writableDatabase.delete(GroupTable.TABLE_NAME).run()
groupTable.writableDatabase.delete(GroupTable.MembershipTable.TABLE_NAME).run()
groupTable.writableDatabase.deleteAll(GroupTable.TABLE_NAME)
groupTable.writableDatabase.deleteAll(GroupTable.MembershipTable.TABLE_NAME)
}
@Test

View file

@ -10,7 +10,7 @@ import org.signal.core.util.forEach
import org.signal.core.util.requireLong
import org.signal.core.util.requireNonNullString
import org.signal.core.util.select
import org.signal.core.util.update
import org.signal.core.util.updateAll
import org.thoughtcrime.securesms.crash.CrashConfig
import org.thoughtcrime.securesms.dependencies.ApplicationDependencies
import org.thoughtcrime.securesms.testing.assertIs
@ -220,7 +220,7 @@ class LogDatabaseTest {
)
db.writableDatabase
.update(LogDatabase.CrashTable.TABLE_NAME)
.updateAll(LogDatabase.CrashTable.TABLE_NAME)
.values(LogDatabase.CrashTable.LAST_PROMPTED_AT to currentTime)
.run()

View file

@ -2,7 +2,9 @@ package org.thoughtcrime.securesms.testing
import org.junit.rules.TestWatcher
import org.junit.runner.Description
import org.signal.core.util.deleteAll
import org.thoughtcrime.securesms.database.SignalDatabase
import org.thoughtcrime.securesms.database.ThreadTable
import org.thoughtcrime.securesms.keyvalue.SignalStore
import org.whispersystems.signalservice.api.push.ServiceId.ACI
import org.whispersystems.signalservice.api.push.ServiceId.PNI
@ -34,7 +36,8 @@ class SignalDatabaseRule(
private fun deleteAllThreads() {
if (deleteAllThreadsOnEachRun) {
SignalDatabase.threads.clearForTests()
SignalDatabase.threads.deleteAllConversations()
SignalDatabase.rawDatabase.deleteAll(ThreadTable.TABLE_NAME)
}
}
}

View file

@ -5,9 +5,9 @@
package org.thoughtcrime.securesms.backup.v2.database
import org.signal.core.util.delete
import org.signal.core.util.deleteAll
import org.thoughtcrime.securesms.database.AttachmentTable
fun AttachmentTable.clearAllDataForBackupRestore() {
writableDatabase.delete(AttachmentTable.TABLE_NAME).run()
writableDatabase.deleteAll(AttachmentTable.TABLE_NAME)
}

View file

@ -7,7 +7,7 @@ package org.thoughtcrime.securesms.backup.v2.database
import okio.ByteString.Companion.toByteString
import org.signal.core.util.CursorUtil
import org.signal.core.util.delete
import org.signal.core.util.deleteAll
import org.signal.core.util.logging.Log
import org.signal.core.util.readToList
import org.signal.core.util.requireLong
@ -88,12 +88,10 @@ fun DistributionListTables.restoreFromBackup(dlist: BackupDistributionList, back
fun DistributionListTables.clearAllDataForBackupRestore() {
writableDatabase
.delete(DistributionListTables.ListTable.TABLE_NAME)
.run()
.deleteAll(DistributionListTables.ListTable.TABLE_NAME)
writableDatabase
.delete(DistributionListTables.MembershipTable.TABLE_NAME)
.run()
.deleteAll(DistributionListTables.MembershipTable.TABLE_NAME)
}
private fun DistributionListPrivacyMode.toBackupPrivacyMode(): BackupDistributionList.PrivacyMode {

View file

@ -10,7 +10,7 @@ import android.database.Cursor
import okio.ByteString.Companion.toByteString
import org.signal.core.util.Base64
import org.signal.core.util.SqlUtil
import org.signal.core.util.delete
import org.signal.core.util.deleteAll
import org.signal.core.util.logging.Log
import org.signal.core.util.nullIfBlank
import org.signal.core.util.requireBoolean
@ -155,7 +155,7 @@ fun RecipientTable.restoreSelfFromBackup(accountData: AccountData, selfId: Recip
}
fun RecipientTable.clearAllDataForBackupRestore() {
writableDatabase.delete(RecipientTable.TABLE_NAME).run()
writableDatabase.deleteAll(RecipientTable.TABLE_NAME)
SqlUtil.resetAutoIncrementValue(writableDatabase, RecipientTable.TABLE_NAME)
RecipientId.clearCache()

View file

@ -216,7 +216,7 @@ class ChangeNumberRepository(
@WorkerThread
fun changeLocalNumber(e164: String, pni: PNI): Single<Unit> {
val oldStorageId: ByteArray? = Recipient.self().storageServiceId
SignalDatabase.recipients.updateSelfPhone(e164, pni)
SignalDatabase.recipients.updateSelfE164(e164, pni)
val newStorageId: ByteArray? = Recipient.self().storageServiceId
if (e164 != SignalStore.account().requireE164() && MessageDigest.isEqual(oldStorageId, newStorageId)) {

View file

@ -39,6 +39,7 @@ import org.signal.core.util.SqlUtil.buildSingleCollectionQuery
import org.signal.core.util.StreamUtil
import org.signal.core.util.ThreadUtil
import org.signal.core.util.delete
import org.signal.core.util.deleteAll
import org.signal.core.util.exists
import org.signal.core.util.forEach
import org.signal.core.util.groupBy
@ -475,9 +476,7 @@ class AttachmentTable(
fun deleteAllAttachments() {
Log.d(TAG, "[deleteAllAttachments]")
writableDatabase
.delete(TABLE_NAME)
.run()
writableDatabase.deleteAll(TABLE_NAME)
FileUtils.deleteDirectoryContents(context.getDir(DIRECTORY, Context.MODE_PRIVATE))

View file

@ -8,6 +8,7 @@ import org.signal.core.util.IntSerializer
import org.signal.core.util.Serializer
import org.signal.core.util.SqlUtil
import org.signal.core.util.delete
import org.signal.core.util.deleteAll
import org.signal.core.util.flatten
import org.signal.core.util.insertInto
import org.signal.core.util.logging.Log
@ -1014,9 +1015,7 @@ class CallTable(context: Context, databaseHelper: SignalDatabase) : DatabaseTabl
@Discouraged("Using this method is generally considered an error. Utilize other deletion methods instead of this.")
fun deleteAllCalls() {
Log.w(TAG, "Deleting all calls from the local database.")
writableDatabase
.delete(TABLE_NAME)
.run()
writableDatabase.deleteAll(TABLE_NAME)
}
private fun getCallSelectionQuery(callId: Long, recipientId: RecipientId): SqlUtil.Query {

View file

@ -5,6 +5,7 @@ import android.content.Context
import androidx.core.content.contentValuesOf
import org.signal.core.util.SqlUtil
import org.signal.core.util.delete
import org.signal.core.util.deleteAll
import org.signal.core.util.logging.Log
import org.signal.core.util.requireNonNullString
import org.signal.core.util.select
@ -106,8 +107,6 @@ class CdsTable(context: Context, databaseHelper: SignalDatabase) : DatabaseTable
* Wipes the entire table.
*/
fun clearAll() {
writableDatabase
.delete(TABLE_NAME)
.run()
writableDatabase.deleteAll(TABLE_NAME)
}
}

View file

@ -5,6 +5,7 @@ import android.net.Uri
import androidx.core.content.contentValuesOf
import org.signal.core.util.SqlUtil
import org.signal.core.util.delete
import org.signal.core.util.deleteAll
import org.signal.core.util.logging.Log
import org.signal.core.util.readToList
import org.signal.core.util.requireNonNullString
@ -71,9 +72,7 @@ class DraftTable(context: Context?, databaseHelper: SignalDatabase?) : DatabaseT
}
fun clearAllDrafts() {
writableDatabase
.delete(TABLE_NAME)
.run()
writableDatabase.deleteAll(TABLE_NAME)
}
fun getDrafts(threadId: Long): Drafts {

View file

@ -6,6 +6,7 @@ import android.database.Cursor
import androidx.core.content.contentValuesOf
import org.signal.core.util.SqlUtil
import org.signal.core.util.delete
import org.signal.core.util.deleteAll
import org.signal.core.util.forEach
import org.signal.core.util.readToList
import org.signal.core.util.requireBoolean
@ -172,7 +173,7 @@ class GroupReceiptTable(context: Context?, databaseHelper: SignalDatabase?) : Da
}
fun deleteAllRows() {
writableDatabase.delete(TABLE_NAME).run()
writableDatabase.deleteAll(TABLE_NAME)
}
override fun remapRecipient(fromId: RecipientId, toId: RecipientId) {

View file

@ -21,6 +21,7 @@ import org.signal.core.util.requireNonNullString
import org.signal.core.util.requireString
import org.signal.core.util.select
import org.signal.core.util.update
import org.signal.core.util.updateAll
import org.signal.core.util.withinTransaction
import org.thoughtcrime.securesms.crypto.DatabaseSecret
import org.thoughtcrime.securesms.crypto.DatabaseSecretProvider
@ -235,7 +236,7 @@ class JobDatabase(
@Synchronized
fun updateAllJobsToBePending() {
writableDatabase
.update(Jobs.TABLE_NAME)
.updateAll(Jobs.TABLE_NAME)
.values(Jobs.IS_RUNNING to 0)
.run()
}

View file

@ -9,6 +9,7 @@ import org.signal.core.util.CursorUtil
import org.signal.core.util.SqlUtil
import org.signal.core.util.Stopwatch
import org.signal.core.util.delete
import org.signal.core.util.deleteAll
import org.signal.core.util.exists
import org.signal.core.util.getTableRowCount
import org.signal.core.util.insertInto
@ -268,9 +269,7 @@ class LogDatabase private constructor(
}
fun clearAll() {
writableDatabase
.delete(TABLE_NAME)
.run()
writableDatabase.deleteAll(TABLE_NAME)
}
private fun getSize(query: String?, args: Array<String>?): Long {
@ -403,9 +402,7 @@ class LogDatabase private constructor(
}
fun clear() {
writableDatabase
.delete(TABLE_NAME)
.run()
writableDatabase.deleteAll(TABLE_NAME)
}
private fun CrashConfig.CrashPattern.asLikeQuery(): Pair<String, Array<String>> {
@ -494,9 +491,7 @@ class LogDatabase private constructor(
}
fun clear() {
writableDatabase
.delete(TABLE_NAME)
.run()
writableDatabase.deleteAll(TABLE_NAME)
}
data class AnrRecord(

View file

@ -40,6 +40,7 @@ import org.signal.core.util.SqlUtil.getNextAutoIncrementId
import org.signal.core.util.Stopwatch
import org.signal.core.util.count
import org.signal.core.util.delete
import org.signal.core.util.deleteAll
import org.signal.core.util.exists
import org.signal.core.util.forEach
import org.signal.core.util.insertInto
@ -3390,7 +3391,7 @@ open class MessageTable(context: Context?, databaseHelper: SignalDatabase) : Dat
attachments.deleteAllAttachments()
groupReceipts.deleteAllRows()
mentions.deleteAllMentions()
writableDatabase.delete(TABLE_NAME).run()
writableDatabase.deleteAll(TABLE_NAME)
calls.updateCallEventDeletionTimestamps()
OptimizeMessageSearchIndexJob.enqueue()

View file

@ -3,6 +3,7 @@ package org.thoughtcrime.securesms.database
import android.content.Context
import androidx.core.content.contentValuesOf
import org.signal.core.util.delete
import org.signal.core.util.deleteAll
import org.signal.core.util.exists
import org.signal.core.util.logging.Log
import org.signal.core.util.update
@ -117,7 +118,7 @@ class PendingPniSignatureMessageTable(context: Context, databaseHelper: SignalDa
* Deletes all record of pending PNI verification messages. Should only be called after the user changes their number.
*/
fun deleteAll() {
writableDatabase.delete(TABLE_NAME).run()
writableDatabase.deleteAll(TABLE_NAME)
}
override fun remapRecipient(oldId: RecipientId, newId: RecipientId) {

View file

@ -33,6 +33,7 @@ import org.signal.core.util.requireNonNullString
import org.signal.core.util.requireString
import org.signal.core.util.select
import org.signal.core.util.update
import org.signal.core.util.updateAll
import org.signal.core.util.withinTransaction
import org.signal.libsignal.protocol.IdentityKey
import org.signal.libsignal.protocol.InvalidKeyException
@ -2011,7 +2012,7 @@ open class RecipientTable(context: Context, databaseHelper: SignalDatabase) : Da
/**
* Does *not* handle clearing the recipient cache. It is assumed the caller handles this.
*/
fun updateSelfPhone(e164: String, pni: PNI) {
fun updateSelfE164(e164: String, pni: PNI) {
val db = writableDatabase
db.beginTransaction()
@ -2022,11 +2023,10 @@ open class RecipientTable(context: Context, databaseHelper: SignalDatabase) : Da
if (id == newId) {
Log.i(TAG, "[updateSelfPhone] Phone updated for self")
} else {
throw AssertionError("[updateSelfPhone] Self recipient id changed when updating phone. old: $id new: $newId")
throw AssertionError("[updateSelfPhone] Self recipient id changed when updating e164. old: $id new: $newId")
}
db
.update(TABLE_NAME)
db.updateAll(TABLE_NAME)
.values(NEEDS_PNI_SIGNATURE to 0)
.run()

View file

@ -8,7 +8,7 @@ import androidx.core.content.contentValuesOf
import androidx.core.net.toUri
import org.json.JSONException
import org.json.JSONObject
import org.signal.core.util.delete
import org.signal.core.util.deleteAll
import org.signal.core.util.logging.Log
import org.signal.core.util.readToList
import org.signal.core.util.requireInt
@ -194,9 +194,7 @@ class RemoteMegaphoneTable(context: Context, databaseHelper: SignalDatabase) : D
/** Only call from internal settings */
fun debugRemoveAll() {
writableDatabase
.delete(TABLE_NAME)
.run()
writableDatabase.deleteAll(TABLE_NAME)
}
private fun RemoteMegaphoneRecord.toContentValues(): ContentValues {

View file

@ -4,6 +4,7 @@ import android.content.Context
import android.database.Cursor
import androidx.core.content.contentValuesOf
import org.signal.core.util.delete
import org.signal.core.util.deleteAll
import org.signal.core.util.logging.Log
import org.signal.core.util.readToSet
import org.signal.core.util.requireInt
@ -137,9 +138,7 @@ class SenderKeySharedTable internal constructor(context: Context?, databaseHelpe
* Clears all database content.
*/
fun deleteAll() {
writableDatabase
.delete(TABLE_NAME)
.run()
writableDatabase.deleteAll(TABLE_NAME)
}
/**

View file

@ -5,6 +5,7 @@ import android.database.Cursor
import androidx.core.content.contentValuesOf
import org.signal.core.util.CursorUtil
import org.signal.core.util.delete
import org.signal.core.util.deleteAll
import org.signal.core.util.firstOrNull
import org.signal.core.util.logging.Log
import org.signal.core.util.requireLong
@ -123,8 +124,6 @@ class SenderKeyTable internal constructor(context: Context?, databaseHelper: Sig
* Deletes all database state.
*/
fun deleteAll() {
writableDatabase
.delete(TABLE_NAME)
.run()
writableDatabase.deleteAll(TABLE_NAME)
}
}

View file

@ -6,7 +6,6 @@ import android.content.Context
import android.database.Cursor
import android.database.MergeCursor
import android.net.Uri
import androidx.annotation.VisibleForTesting
import androidx.core.content.contentValuesOf
import com.fasterxml.jackson.annotation.JsonProperty
import org.json.JSONObject
@ -26,6 +25,7 @@ import org.signal.core.util.requireString
import org.signal.core.util.select
import org.signal.core.util.toInt
import org.signal.core.util.update
import org.signal.core.util.updateAll
import org.signal.core.util.withinTransaction
import org.signal.libsignal.zkgroup.InvalidInputException
import org.signal.libsignal.zkgroup.groups.GroupMasterKey
@ -403,7 +403,7 @@ class ThreadTable(context: Context, databaseHelper: SignalDatabase) : DatabaseTa
fun setAllThreadsRead(): List<MarkedMessageInfo> {
writableDatabase
.update(TABLE_NAME)
.updateAll(TABLE_NAME)
.values(
READ to ReadStatus.READ.serialize(),
UNREAD_COUNT to 0,
@ -1107,14 +1107,6 @@ class ThreadTable(context: Context, databaseHelper: SignalDatabase) : DatabaseTa
ConversationUtil.clearShortcuts(context, recipientIds)
}
@VisibleForTesting
fun clearForTests() {
writableDatabase.withinTransaction {
deleteAllConversations()
it.delete(TABLE_NAME).run()
}
}
@SuppressLint("DiscouragedApi")
fun deleteAllConversations() {
writableDatabase.withinTransaction { db ->
@ -1294,7 +1286,7 @@ class ThreadTable(context: Context, databaseHelper: SignalDatabase) : DatabaseTa
writableDatabase.withinTransaction { db ->
applyStorageSyncUpdate(recipientId, record.isNoteToSelfArchived, record.isNoteToSelfForcedUnread)
db.update(TABLE_NAME)
db.updateAll(TABLE_NAME)
.values(PINNED to 0)
.run()
@ -1664,36 +1656,42 @@ class ThreadTable(context: Context, databaseHelper: SignalDatabase) : DatabaseTa
}
private fun SQLiteDatabase.deactivateThread(query: SqlUtil.Query?) {
val update = update(TABLE_NAME)
.values(
DATE to 0,
MEANINGFUL_MESSAGES to 0,
READ to ReadStatus.READ.serialize(),
TYPE to 0,
ERROR to 0,
SNIPPET to null,
SNIPPET_TYPE to 0,
SNIPPET_URI to null,
SNIPPET_CONTENT_TYPE to null,
SNIPPET_EXTRAS to null,
UNREAD_COUNT to 0,
ARCHIVED to 0,
STATUS to 0,
HAS_DELIVERY_RECEIPT to 0,
HAS_READ_RECEIPT to 0,
EXPIRES_IN to 0,
LAST_SEEN to 0,
HAS_SENT to 0,
LAST_SCROLLED to 0,
PINNED to 0,
UNREAD_SELF_MENTION_COUNT to 0,
ACTIVE to 0
)
val contentValues = contentValuesOf(
DATE to 0,
MEANINGFUL_MESSAGES to 0,
READ to ReadStatus.READ.serialize(),
TYPE to 0,
ERROR to 0,
SNIPPET to null,
SNIPPET_TYPE to 0,
SNIPPET_URI to null,
SNIPPET_CONTENT_TYPE to null,
SNIPPET_EXTRAS to null,
UNREAD_COUNT to 0,
ARCHIVED to 0,
STATUS to 0,
HAS_DELIVERY_RECEIPT to 0,
HAS_READ_RECEIPT to 0,
EXPIRES_IN to 0,
LAST_SEEN to 0,
HAS_SENT to 0,
LAST_SCROLLED to 0,
PINNED to 0,
UNREAD_SELF_MENTION_COUNT to 0,
ACTIVE to 0
)
if (query != null) {
update.where(query.where, query.whereArgs).run()
writableDatabase
.update(TABLE_NAME)
.values(contentValues)
.where(query.where, query.whereArgs)
.run()
} else {
update.run()
writableDatabase
.updateAll(TABLE_NAME)
.values(contentValues)
.run()
}
}

View file

@ -92,18 +92,31 @@ fun SupportSQLiteDatabase.count(): SelectBuilderPart1 {
/**
* Begins an UPDATE statement with a helpful builder pattern.
* Requires a WHERE clause as a way of mitigating mistakes. If you'd like to update all items in the table, use [updateAll].
*/
fun SupportSQLiteDatabase.update(tableName: String): UpdateBuilderPart1 {
return UpdateBuilderPart1(this, tableName)
}
fun SupportSQLiteDatabase.updateAll(tableName: String): UpdateAllBuilderPart1 {
return UpdateAllBuilderPart1(this, tableName)
}
/**
* Begins a DELETE statement with a helpful builder pattern.
* Requires a WHERE clause as a way of mitigating mistakes. If you'd like to delete all items in the table, use [deleteAll].
*/
fun SupportSQLiteDatabase.delete(tableName: String): DeleteBuilderPart1 {
return DeleteBuilderPart1(this, tableName)
}
/**
* Deletes all data in the table.
*/
fun SupportSQLiteDatabase.deleteAll(tableName: String): Int {
return this.delete(tableName, null, null)
}
fun SupportSQLiteDatabase.insertInto(tableName: String): InsertBuilderPart1 {
return InsertBuilderPart1(this, tableName)
}
@ -271,16 +284,14 @@ class UpdateBuilderPart2(
private val values: ContentValues
) {
fun where(@Language("sql") where: String, vararg whereArgs: Any): UpdateBuilderPart3 {
require(where.isNotBlank())
return UpdateBuilderPart3(db, tableName, values, where, SqlUtil.buildArgs(*whereArgs))
}
fun where(@Language("sql") where: String, whereArgs: Array<String>): UpdateBuilderPart3 {
require(where.isNotBlank())
return UpdateBuilderPart3(db, tableName, values, where, whereArgs)
}
fun run(conflictStrategy: Int = SQLiteDatabase.CONFLICT_NONE): Int {
return db.update(tableName, conflictStrategy, values, null, arrayOf<String>())
}
}
class UpdateBuilderPart3(
@ -296,21 +307,43 @@ class UpdateBuilderPart3(
}
}
class UpdateAllBuilderPart1(
private val db: SupportSQLiteDatabase,
private val tableName: String
) {
fun values(values: ContentValues): UpdateAllBuilderPart2 {
return UpdateAllBuilderPart2(db, tableName, values)
}
fun values(vararg values: Pair<String, Any?>): UpdateAllBuilderPart2 {
return UpdateAllBuilderPart2(db, tableName, contentValuesOf(*values))
}
}
class UpdateAllBuilderPart2(
private val db: SupportSQLiteDatabase,
private val tableName: String,
private val values: ContentValues
) {
@JvmOverloads
fun run(conflictStrategy: Int = SQLiteDatabase.CONFLICT_NONE): Int {
return db.update(tableName, conflictStrategy, values, null, emptyArray<String>())
}
}
class DeleteBuilderPart1(
private val db: SupportSQLiteDatabase,
private val tableName: String
) {
fun where(@Language("sql") where: String, vararg whereArgs: Any): DeleteBuilderPart2 {
require(where.isNotBlank())
return DeleteBuilderPart2(db, tableName, where, SqlUtil.buildArgs(*whereArgs))
}
fun where(@Language("sql") where: String, whereArgs: Array<String>): DeleteBuilderPart2 {
require(where.isNotBlank())
return DeleteBuilderPart2(db, tableName, where, whereArgs)
}
fun run(): Int {
return db.delete(tableName, null, emptyArray<String>())
}
}
class DeleteBuilderPart2(