diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/messages/EnvelopeContentValidator.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/messages/EnvelopeContentValidator.kt index 8a44bc6949..db240092af 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/messages/EnvelopeContentValidator.kt +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/messages/EnvelopeContentValidator.kt @@ -1,6 +1,7 @@ package org.whispersystems.signalservice.api.messages import org.signal.libsignal.protocol.message.DecryptionErrorMessage +import org.signal.libsignal.protocol.message.SenderKeyDistributionMessage import org.signal.libsignal.zkgroup.InvalidInputException import org.signal.libsignal.zkgroup.groups.GroupMasterKey import org.signal.libsignal.zkgroup.receipts.ReceiptCredentialPresentation @@ -38,6 +39,10 @@ object EnvelopeContentValidator { return Result.Invalid("Envelope had an invalid sourceServiceId!") } + if (content.senderKeyDistributionMessage != null) { + validateSenderKeyDistributionMessage(content.senderKeyDistributionMessage.toByteArray())?.let { return it } + } + // Reminder: envelope.destinationServiceId was already validated since we need that for decryption return when { @@ -50,9 +55,9 @@ object EnvelopeContentValidator { content.typingMessage != null -> validateTypingMessage(envelope, content.typingMessage) content.decryptionErrorMessage != null -> validateDecryptionErrorMessage(content.decryptionErrorMessage.toByteArray()) content.storyMessage != null -> validateStoryMessage(content.storyMessage) + content.editMessage != null -> validateEditMessage(content.editMessage) content.pniSignatureMessage != null -> Result.Valid content.senderKeyDistributionMessage != null -> Result.Valid - content.editMessage != null -> validateEditMessage(content.editMessage) else -> Result.Invalid("Content is empty!") } } @@ -241,6 +246,15 @@ object EnvelopeContentValidator { } } + private fun validateSenderKeyDistributionMessage(serializedSenderKeyDistributionMessage: ByteArray): Result.Invalid? { + return try { + SenderKeyDistributionMessage(serializedSenderKeyDistributionMessage) + null + } catch (e: Exception) { + Result.Invalid("[SenderKeyDistributionMessage] Bad sender key distribution message!", e) + } + } + private fun validateStoryMessage(storyMessage: StoryMessage): Result { if (storyMessage.group != null) { validateGroupContextV2(storyMessage.group, "[StoryMessage]")?.let { return it }