Improve SenderKeyDistributionMessage envelope validation.

This commit is contained in:
Greyson Parrelli 2024-01-25 15:33:00 -05:00
parent 6fc9055221
commit 9ce021afa2

View file

@ -1,6 +1,7 @@
package org.whispersystems.signalservice.api.messages package org.whispersystems.signalservice.api.messages
import org.signal.libsignal.protocol.message.DecryptionErrorMessage import org.signal.libsignal.protocol.message.DecryptionErrorMessage
import org.signal.libsignal.protocol.message.SenderKeyDistributionMessage
import org.signal.libsignal.zkgroup.InvalidInputException import org.signal.libsignal.zkgroup.InvalidInputException
import org.signal.libsignal.zkgroup.groups.GroupMasterKey import org.signal.libsignal.zkgroup.groups.GroupMasterKey
import org.signal.libsignal.zkgroup.receipts.ReceiptCredentialPresentation import org.signal.libsignal.zkgroup.receipts.ReceiptCredentialPresentation
@ -38,6 +39,10 @@ object EnvelopeContentValidator {
return Result.Invalid("Envelope had an invalid sourceServiceId!") return Result.Invalid("Envelope had an invalid sourceServiceId!")
} }
if (content.senderKeyDistributionMessage != null) {
validateSenderKeyDistributionMessage(content.senderKeyDistributionMessage.toByteArray())?.let { return it }
}
// Reminder: envelope.destinationServiceId was already validated since we need that for decryption // Reminder: envelope.destinationServiceId was already validated since we need that for decryption
return when { return when {
@ -50,9 +55,9 @@ object EnvelopeContentValidator {
content.typingMessage != null -> validateTypingMessage(envelope, content.typingMessage) content.typingMessage != null -> validateTypingMessage(envelope, content.typingMessage)
content.decryptionErrorMessage != null -> validateDecryptionErrorMessage(content.decryptionErrorMessage.toByteArray()) content.decryptionErrorMessage != null -> validateDecryptionErrorMessage(content.decryptionErrorMessage.toByteArray())
content.storyMessage != null -> validateStoryMessage(content.storyMessage) content.storyMessage != null -> validateStoryMessage(content.storyMessage)
content.editMessage != null -> validateEditMessage(content.editMessage)
content.pniSignatureMessage != null -> Result.Valid content.pniSignatureMessage != null -> Result.Valid
content.senderKeyDistributionMessage != null -> Result.Valid content.senderKeyDistributionMessage != null -> Result.Valid
content.editMessage != null -> validateEditMessage(content.editMessage)
else -> Result.Invalid("Content is empty!") else -> Result.Invalid("Content is empty!")
} }
} }
@ -241,6 +246,15 @@ object EnvelopeContentValidator {
} }
} }
private fun validateSenderKeyDistributionMessage(serializedSenderKeyDistributionMessage: ByteArray): Result.Invalid? {
return try {
SenderKeyDistributionMessage(serializedSenderKeyDistributionMessage)
null
} catch (e: Exception) {
Result.Invalid("[SenderKeyDistributionMessage] Bad sender key distribution message!", e)
}
}
private fun validateStoryMessage(storyMessage: StoryMessage): Result { private fun validateStoryMessage(storyMessage: StoryMessage): Result {
if (storyMessage.group != null) { if (storyMessage.group != null) {
validateGroupContextV2(storyMessage.group, "[StoryMessage]")?.let { return it } validateGroupContextV2(storyMessage.group, "[StoryMessage]")?.let { return it }