Add additional validations to incremental attachment streams.

This commit is contained in:
Greyson Parrelli 2024-04-26 12:07:54 -04:00
parent 18e6c57e75
commit 97c08f0d52
3 changed files with 301 additions and 129 deletions

View file

@ -12,7 +12,6 @@ import org.signal.libsignal.protocol.incrementalmac.ChunkSizeChoice;
import org.signal.libsignal.protocol.incrementalmac.IncrementalMacInputStream;
import org.signal.libsignal.protocol.kdf.HKDF;
import org.whispersystems.signalservice.api.backup.BackupKey;
import org.whispersystems.signalservice.api.backup.MediaId;
import org.whispersystems.signalservice.internal.util.ContentLengthInputStream;
import org.whispersystems.signalservice.internal.util.Util;
@ -88,7 +87,12 @@ public class AttachmentCipherInputStream extends FilterInputStream {
wrappedStream = new FileInputStream(file);
} else {
wrappedStream = new IncrementalMacInputStream(
new FileInputStream(file),
new IncrementalMacAdditionalValidationsInputStream(
new FileInputStream(file),
file.length(),
mac,
digest
),
parts[1],
ChunkSizeChoice.everyNthByte(incrementalMacChunkSize),
incrementalDigest);

View file

@ -0,0 +1,118 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.signalservice.api.crypto
import org.signal.libsignal.protocol.InvalidMessageException
import org.whispersystems.signalservice.internal.util.Util
import java.io.FilterInputStream
import java.io.InputStream
import java.security.MessageDigest
import javax.crypto.Mac
/**
* This is meant as a helper stream to go along with [org.signal.libsignal.protocol.incrementalmac.IncrementalMacInputStream].
* That class does not validate the overall digest, nor the overall MAC. This class does that for us.
*
* To use, wrap the IncremtalMacInputStream around this class, and then this class should wrap the lowest-level data stream.
*/
class IncrementalMacAdditionalValidationsInputStream(
wrapped: InputStream,
fileLength: Long,
private val mac: Mac,
private val theirDigest: ByteArray
) : FilterInputStream(wrapped) {
private val digest: MessageDigest = MessageDigest.getInstance("SHA256")
private val macLength: Int = mac.macLength
private val macBuffer: ByteArray = ByteArray(macLength)
private var validated = false
private var bytesRemaining: Int = fileLength.toInt()
private var macBufferPosition: Int = 0
override fun read(): Int {
throw UnsupportedOperationException()
}
/**
* We need to be very careful to keep track of what data is part of the MAC and what isn't, based on how far we've read into the file.
* As a recap, the digest needs to ingest the entire file, while the MAC needs to ingest everything except the last [macLength] bytes.
* (Because the last [macLength] bytes represents the MAC we're going to verify against.)
*
* The wrapping stream may request the full length of the file, so we need to do some bookkeeping to remember the last [macLength] bytes
* for comparison purposes during [validate] while not ingesting them into the MAC that we're calculating.
*/
override fun read(buffer: ByteArray, offset: Int, length: Int): Int {
val bytesRead = super.read(buffer, offset, length)
if (bytesRead == -1) {
validate()
return bytesRead
}
bytesRemaining -= bytesRead
// This indicates we've read into the last [macLength] bytes of the file, so we need to start our bookkeeping
if (bytesRemaining < macLength) {
val bytesOfMacRead = macLength - bytesRemaining
val newBytesOfMacRead = bytesOfMacRead - macBufferPosition
// There's a possibility that the reader has only partially read the last [macLength] bytes, so we need to keep track of a position in our
// MAC buffer and copy over just the new parts we've read
if (newBytesOfMacRead > 0) {
System.arraycopy(buffer, offset + bytesRead - newBytesOfMacRead, macBuffer, macBufferPosition, newBytesOfMacRead)
macBufferPosition += newBytesOfMacRead
}
// Even though we're reading into the MAC, many of the bytes read in this method call could be non-MAC bytes, so we need to copy
// those over, while excluding the bytes that are part of the MAC.
mac.update(buffer, offset, bytesRead - bytesOfMacRead)
} else {
mac.update(buffer, offset, bytesRead)
}
digest.update(buffer, offset, bytesRead)
if (bytesRemaining == 0) {
validate()
}
return bytesRead
}
override fun close() {
// We only want to validate the digest if we've otherwise read the entire stream.
// It's valid to close the stream early, and in this case, we don't want to force reading the whole rest of the stream.
if (bytesRemaining > macLength) {
super.close()
return
}
if (bytesRemaining > 0) {
Util.readFullyAsBytes(this)
}
super.close()
}
private fun validate() {
if (validated) {
return
}
validated = true
val ourMac = mac.doFinal()
val theirMac = macBuffer
if (!MessageDigest.isEqual(ourMac, theirMac)) {
throw InvalidMessageException("MAC doesn't match!")
}
val ourDigest = digest.digest()
if (!MessageDigest.isEqual(ourDigest, theirDigest)) {
throw InvalidMessageException("Digest doesn't match!")
}
}
}

View file

@ -2,12 +2,13 @@ package org.whispersystems.signalservice.api.crypto;
import org.conscrypt.Conscrypt;
import org.junit.Test;
import org.signal.core.util.StreamUtil;
import org.signal.libsignal.protocol.InvalidMessageException;
import org.signal.libsignal.protocol.incrementalmac.ChunkSizeChoice;
import org.signal.libsignal.protocol.incrementalmac.InvalidMacException;
import org.signal.libsignal.protocol.kdf.HKDFv3;
import org.signal.libsignal.protocol.util.ByteUtil;
import org.whispersystems.signalservice.api.backup.BackupKey;
import org.whispersystems.signalservice.api.backup.MediaId;
import org.whispersystems.signalservice.internal.crypto.PaddingInputStream;
import org.whispersystems.signalservice.internal.push.http.AttachmentCipherOutputStreamFactory;
import org.whispersystems.signalservice.internal.util.Util;
@ -15,6 +16,7 @@ import org.whispersystems.signalservice.internal.util.Util;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
@ -37,11 +39,30 @@ public final class AttachmentCipherTest {
}
}
private static int MEBIBYTE = 1024 * 1024;
@Test
public void attachment_encryptDecrypt() throws IOException, InvalidMessageException {
public void attachment_encryptDecrypt_nonIncremental() throws IOException, InvalidMessageException {
attachment_encryptDecrypt(false, MEBIBYTE);
}
@Test
public void attachment_encryptDecrypt_incremental() throws IOException, InvalidMessageException {
attachment_encryptDecrypt(true, MEBIBYTE);
}
@Test
public void attachment_encryptDecrypt_incremental_manyFileSizes() throws IOException, InvalidMessageException {
// Designed to stress the various boundary conditions of reading the final mac
for (int i = 0; i < 100; i++) {
attachment_encryptDecrypt(true, MEBIBYTE + new Random().nextInt(1, 64 * 1024));
}
}
private void attachment_encryptDecrypt(boolean incremental, int fileSize) throws IOException, InvalidMessageException {
byte[] key = Util.getSecretBytes(64);
byte[] plaintextInput = "Peter Parker".getBytes();
EncryptResult encryptResult = encryptData(plaintextInput, key, false);
byte[] plaintextInput = Util.getSecretBytes(fileSize);
EncryptResult encryptResult = encryptData(plaintextInput, key, incremental);
File cipherFile = writeToFile(encryptResult.ciphertext);
InputStream inputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.length, key, encryptResult.digest, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice);
byte[] plaintextOutput = readInputStreamFully(inputStream);
@ -52,10 +73,19 @@ public final class AttachmentCipherTest {
}
@Test
public void attachment_encryptDecryptEmpty() throws IOException, InvalidMessageException {
public void attachment_encryptDecryptEmpty_nonIncremental() throws IOException, InvalidMessageException {
attachment_encryptDecryptEmpty(false);
}
@Test
public void attachment_encryptDecryptEmpty_incremental() throws IOException, InvalidMessageException {
attachment_encryptDecryptEmpty(true);
}
private void attachment_encryptDecryptEmpty(boolean incremental) throws IOException, InvalidMessageException {
byte[] key = Util.getSecretBytes(64);
byte[] plaintextInput = "".getBytes();
EncryptResult encryptResult = encryptData(plaintextInput, key, true);
EncryptResult encryptResult = encryptData(plaintextInput, key, incremental);
File cipherFile = writeToFile(encryptResult.ciphertext);
InputStream inputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.length, key, encryptResult.digest, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice);
byte[] plaintextOutput = readInputStreamFully(inputStream);
@ -65,110 +95,130 @@ public final class AttachmentCipherTest {
cipherFile.delete();
}
@Test
public void attachment_decryptFailOnBadKey() throws IOException {
File cipherFile = null;
boolean hitCorrectException = false;
@Test(expected = InvalidMessageException.class)
public void attachment_decryptFailOnBadKey_nonIncremental() throws IOException, InvalidMessageException {
attachment_decryptFailOnBadKey(false);
}
@Test(expected = InvalidMessageException.class)
public void attachment_decryptFailOnBadKey_incremental() throws IOException, InvalidMessageException {
attachment_decryptFailOnBadKey(true);
}
private void attachment_decryptFailOnBadKey(boolean incremental) throws IOException, InvalidMessageException {
File cipherFile = null;
try {
byte[] key = Util.getSecretBytes(64);
byte[] plaintextInput = "Gwen Stacy".getBytes();
EncryptResult encryptResult = encryptData(plaintextInput, key, true);
byte[] plaintextInput = Util.getSecretBytes(MEBIBYTE);
EncryptResult encryptResult = encryptData(plaintextInput, key, incremental);
byte[] badKey = new byte[64];
cipherFile = writeToFile(encryptResult.ciphertext);
AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.length, badKey, encryptResult.digest, null, 0);
} catch (InvalidMessageException e) {
hitCorrectException = true;
} finally {
if (cipherFile != null) {
cipherFile.delete();
}
}
assertTrue(hitCorrectException);
}
@Test(expected = InvalidMessageException.class)
public void attachment_decryptFailOnBadMac_nonIncremental() throws IOException, InvalidMessageException {
attachment_decryptFailOnBadMac(false);
}
@Test
public void archive_encryptDecrypt() throws IOException, InvalidMessageException {
byte[] key = Util.getSecretBytes(64);
BackupKey.MediaKeyMaterial keyMaterial = BackupKey.MediaKeyMaterial.forMedia(Util.getSecretBytes(15), key, Util.getSecretBytes(16));
byte[] plaintextInput = "Peter Parker".getBytes();
EncryptResult encryptResult = encryptData(plaintextInput, key, false);
File cipherFile = writeToFile(encryptResult.ciphertext);
InputStream inputStream = AttachmentCipherInputStream.createForArchivedMedia(keyMaterial, cipherFile, plaintextInput.length);
byte[] plaintextOutput = readInputStreamFully(inputStream);
assertArrayEquals(plaintextInput, plaintextOutput);
cipherFile.delete();
@Test(expected = InvalidMessageException.class)
public void attachment_decryptFailOnBadMac_incremental() throws IOException, InvalidMessageException {
attachment_decryptFailOnBadMac(true);
}
@Test
public void archive_encryptDecryptEmpty() throws IOException, InvalidMessageException {
byte[] key = Util.getSecretBytes(64);
BackupKey.MediaKeyMaterial keyMaterial = BackupKey.MediaKeyMaterial.forMedia(Util.getSecretBytes(15), key, Util.getSecretBytes(16));
byte[] plaintextInput = "".getBytes();
EncryptResult encryptResult = encryptData(plaintextInput, key, false);
File cipherFile = writeToFile(encryptResult.ciphertext);
InputStream inputStream = AttachmentCipherInputStream.createForArchivedMedia(keyMaterial, cipherFile, plaintextInput.length);
byte[] plaintextOutput = readInputStreamFully(inputStream);
assertArrayEquals(plaintextInput, plaintextOutput);
cipherFile.delete();
}
@Test
public void archive_decryptFailOnBadKey() throws IOException {
File cipherFile = null;
boolean hitCorrectException = false;
private void attachment_decryptFailOnBadMac(boolean incremental) throws IOException, InvalidMessageException {
File cipherFile = null;
try {
byte[] key = Util.getSecretBytes(64);
byte[] badKey = Util.getSecretBytes(64);
BackupKey.MediaKeyMaterial keyMaterial = BackupKey.MediaKeyMaterial.forMedia(Util.getSecretBytes(15), badKey, Util.getSecretBytes(16));
byte[] plaintextInput = "Gwen Stacy".getBytes();
EncryptResult encryptResult = encryptData(plaintextInput, key, false);
byte[] key = Util.getSecretBytes(64);
byte[] plaintextInput = Util.getSecretBytes(MEBIBYTE);
EncryptResult encryptResult = encryptData(plaintextInput, key, incremental);
byte[] badMacCiphertext = Arrays.copyOf(encryptResult.ciphertext, encryptResult.ciphertext.length);
cipherFile = writeToFile(encryptResult.ciphertext);
badMacCiphertext[badMacCiphertext.length - 1] += 1;
AttachmentCipherInputStream.createForArchivedMedia(keyMaterial, cipherFile, plaintextInput.length);
} catch (InvalidMessageException e) {
hitCorrectException = true;
cipherFile = writeToFile(badMacCiphertext);
InputStream stream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.length, key, encryptResult.digest, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice);
// In incremental mode, we'll only check the digest after reading the whole thing
if (incremental) {
StreamUtil.readFully(stream);
}
} finally {
if (cipherFile != null) {
cipherFile.delete();
}
}
assertTrue(hitCorrectException);
}
@Test
public void attachment_decryptFailOnBadDigest() throws IOException {
File cipherFile = null;
boolean hitCorrectException = false;
@Test(expected = InvalidMessageException.class)
public void attachment_decryptFailOnNullDigest_nonIncremental() throws IOException, InvalidMessageException {
attachment_decryptFailOnNullDigest(false);
}
@Test(expected = InvalidMessageException.class)
public void attachment_decryptFailOnNullDigest_incremental() throws IOException, InvalidMessageException {
attachment_decryptFailOnNullDigest(true);
}
private void attachment_decryptFailOnNullDigest(boolean incremental) throws IOException, InvalidMessageException {
File cipherFile = null;
try {
byte[] key = Util.getSecretBytes(64);
byte[] plaintextInput = "Mary Jane Watson".getBytes();
EncryptResult encryptResult = encryptData(plaintextInput, key, true);
byte[] plaintextInput = Util.getSecretBytes(MEBIBYTE);
EncryptResult encryptResult = encryptData(plaintextInput, key, incremental);
cipherFile = writeToFile(encryptResult.ciphertext);
AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.length, key, null, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice);
} finally {
if (cipherFile != null) {
cipherFile.delete();
}
}
}
@Test(expected = InvalidMessageException.class)
public void attachment_decryptFailOnBadDigest_nonIncremental() throws IOException, InvalidMessageException {
attachment_decryptFailOnBadDigest(false);
}
@Test(expected = InvalidMessageException.class)
public void attachment_decryptFailOnBadDigest_incremental() throws IOException, InvalidMessageException {
attachment_decryptFailOnBadDigest(true);
}
private void attachment_decryptFailOnBadDigest(boolean incremental) throws IOException, InvalidMessageException {
File cipherFile = null;
try {
byte[] key = Util.getSecretBytes(64);
byte[] plaintextInput = Util.getSecretBytes(MEBIBYTE);
EncryptResult encryptResult = encryptData(plaintextInput, key, incremental);
byte[] badDigest = new byte[32];
cipherFile = writeToFile(encryptResult.ciphertext);
AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.length, key, badDigest, null, 0);
} catch (InvalidMessageException e) {
hitCorrectException = true;
InputStream stream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.length, key, badDigest, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice);
// In incremental mode, we'll only check the digest after reading the whole thing
if (incremental) {
StreamUtil.readFully(stream);
}
} finally {
if (cipherFile != null) {
cipherFile.delete();
}
}
assertTrue(hitCorrectException);
}
@Test
@ -178,9 +228,7 @@ public final class AttachmentCipherTest {
try {
byte[] key = Util.getSecretBytes(64);
byte[] plaintextInput = new byte[1000000];
new Random().nextBytes(plaintextInput);
byte[] plaintextInput = Util.getSecretBytes(MEBIBYTE);
EncryptResult encryptResult = encryptData(plaintextInput, key, true);
byte[] badDigest = Util.getSecretBytes(encryptResult.incrementalDigest.length);
@ -242,6 +290,62 @@ public final class AttachmentCipherTest {
}
}
@Test
public void archive_encryptDecrypt() throws IOException, InvalidMessageException {
byte[] key = Util.getSecretBytes(64);
BackupKey.MediaKeyMaterial keyMaterial = BackupKey.MediaKeyMaterial.forMedia(Util.getSecretBytes(15), key, Util.getSecretBytes(16));
byte[] plaintextInput = "Peter Parker".getBytes();
EncryptResult encryptResult = encryptData(plaintextInput, key, false);
File cipherFile = writeToFile(encryptResult.ciphertext);
InputStream inputStream = AttachmentCipherInputStream.createForArchivedMedia(keyMaterial, cipherFile, plaintextInput.length);
byte[] plaintextOutput = readInputStreamFully(inputStream);
assertArrayEquals(plaintextInput, plaintextOutput);
cipherFile.delete();
}
@Test
public void archive_encryptDecryptEmpty() throws IOException, InvalidMessageException {
byte[] key = Util.getSecretBytes(64);
BackupKey.MediaKeyMaterial keyMaterial = BackupKey.MediaKeyMaterial.forMedia(Util.getSecretBytes(15), key, Util.getSecretBytes(16));
byte[] plaintextInput = "".getBytes();
EncryptResult encryptResult = encryptData(plaintextInput, key, false);
File cipherFile = writeToFile(encryptResult.ciphertext);
InputStream inputStream = AttachmentCipherInputStream.createForArchivedMedia(keyMaterial, cipherFile, plaintextInput.length);
byte[] plaintextOutput = readInputStreamFully(inputStream);
assertArrayEquals(plaintextInput, plaintextOutput);
cipherFile.delete();
}
@Test
public void archive_decryptFailOnBadKey() throws IOException {
File cipherFile = null;
boolean hitCorrectException = false;
try {
byte[] key = Util.getSecretBytes(64);
byte[] badKey = Util.getSecretBytes(64);
BackupKey.MediaKeyMaterial keyMaterial = BackupKey.MediaKeyMaterial.forMedia(Util.getSecretBytes(15), badKey, Util.getSecretBytes(16));
byte[] plaintextInput = "Gwen Stacy".getBytes();
EncryptResult encryptResult = encryptData(plaintextInput, key, false);
cipherFile = writeToFile(encryptResult.ciphertext);
AttachmentCipherInputStream.createForArchivedMedia(keyMaterial, cipherFile, plaintextInput.length);
} catch (InvalidMessageException e) {
hitCorrectException = true;
} finally {
if (cipherFile != null) {
cipherFile.delete();
}
}
assertTrue(hitCorrectException);
}
@Test
public void archive_encryptDecryptPaddedContent() throws IOException, InvalidMessageException {
int[] lengths = { 531, 600, 724, 1019, 1024 };
@ -280,59 +384,6 @@ public final class AttachmentCipherTest {
}
}
@Test
public void attachment_decryptFailOnNullDigest() throws IOException {
File cipherFile = null;
boolean hitCorrectException = false;
try {
byte[] key = Util.getSecretBytes(64);
byte[] plaintextInput = "Aunt May".getBytes();
ChunkSizeChoice sizeChoice = ChunkSizeChoice.inferChunkSize(plaintextInput.length);
EncryptResult encryptResult = encryptData(plaintextInput, key, true);
cipherFile = writeToFile(encryptResult.ciphertext);
AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.length, key, null, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice);
} catch (InvalidMessageException e) {
hitCorrectException = true;
} finally {
if (cipherFile != null) {
cipherFile.delete();
}
}
assertTrue(hitCorrectException);
}
@Test
public void attachment_decryptFailOnBadMac() throws IOException {
File cipherFile = null;
boolean hitCorrectException = false;
try {
byte[] key = Util.getSecretBytes(64);
byte[] plaintextInput = "Uncle Ben".getBytes();
EncryptResult encryptResult = encryptData(plaintextInput, key, true);
byte[] badMacCiphertext = Arrays.copyOf(encryptResult.ciphertext, encryptResult.ciphertext.length);
badMacCiphertext[badMacCiphertext.length - 1] += 1;
cipherFile = writeToFile(badMacCiphertext);
AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.length, key, encryptResult.digest, null, encryptResult.chunkSizeChoice);
fail();
} catch (InvalidMessageException e) {
hitCorrectException = true;
} finally {
if (cipherFile != null) {
cipherFile.delete();
}
}
assertTrue(hitCorrectException);
}
@Test
public void archive_decryptFailOnBadMac() throws IOException {
File cipherFile = null;
@ -444,7 +495,6 @@ public final class AttachmentCipherTest {
encryptStream = factory.createFor(outputStream);
}
encryptStream.write(data);
encryptStream.flush();
encryptStream.close();