Add incremental digests to attachment sending.

This commit is contained in:
Nicholas 2023-06-28 16:24:46 -04:00
parent 025411c9fb
commit 34d252a4bd
34 changed files with 397 additions and 240 deletions

View file

@ -142,6 +142,7 @@ class ConversationItemPreviewer {
1024, 1024,
1024, 1024,
Optional.empty(), Optional.empty(),
Optional.empty(),
Optional.of("/not-there.jpg"), Optional.of("/not-there.jpg"),
false, false,
false, false,

View file

@ -32,6 +32,7 @@ import org.thoughtcrime.securesms.registration.VerifyResponse
import org.thoughtcrime.securesms.util.Util import org.thoughtcrime.securesms.util.Util
import org.whispersystems.signalservice.api.profiles.SignalServiceProfile import org.whispersystems.signalservice.api.profiles.SignalServiceProfile
import org.whispersystems.signalservice.api.push.ACI import org.whispersystems.signalservice.api.push.ACI
import org.whispersystems.signalservice.api.push.ServiceIdType
import org.whispersystems.signalservice.api.push.SignalServiceAddress import org.whispersystems.signalservice.api.push.SignalServiceAddress
import org.whispersystems.signalservice.internal.ServiceResponse import org.whispersystems.signalservice.internal.ServiceResponse
import org.whispersystems.signalservice.internal.ServiceResponseProcessor import org.whispersystems.signalservice.internal.ServiceResponseProcessor
@ -87,7 +88,8 @@ class SignalActivityRule(private val othersCount: Int = 4) : ExternalResource()
password = Util.getSecret(18), password = Util.getSecret(18),
registrationId = registrationRepository.registrationId, registrationId = registrationRepository.registrationId,
profileKey = registrationRepository.getProfileKey("+15555550101"), profileKey = registrationRepository.getProfileKey("+15555550101"),
preKeyCollections = RegistrationRepository.generatePreKeys()!!, aciPreKeyCollection = RegistrationRepository.generatePreKeysForType(ServiceIdType.ACI),
pniPreKeyCollection = RegistrationRepository.generatePreKeysForType(ServiceIdType.PNI),
fcmToken = null, fcmToken = null,
pniRegistrationId = registrationRepository.pniRegistrationId, pniRegistrationId = registrationRepository.pniRegistrationId,
recoveryPassword = "asdfasdfasdfasdf" recoveryPassword = "asdfasdfasdfasdf"

View file

@ -35,6 +35,9 @@ public abstract class Attachment {
@Nullable @Nullable
private final byte[] digest; private final byte[] digest;
@Nullable
private final byte[] incrementalDigest;
@Nullable @Nullable
private final String fastPreflightId; private final String fastPreflightId;
@ -70,6 +73,7 @@ public abstract class Attachment {
@Nullable String key, @Nullable String key,
@Nullable String relay, @Nullable String relay,
@Nullable byte[] digest, @Nullable byte[] digest,
@Nullable byte[] incrementalDigest,
@Nullable String fastPreflightId, @Nullable String fastPreflightId,
boolean voiceNote, boolean voiceNote,
boolean borderless, boolean borderless,
@ -93,6 +97,7 @@ public abstract class Attachment {
this.key = key; this.key = key;
this.relay = relay; this.relay = relay;
this.digest = digest; this.digest = digest;
this.incrementalDigest = incrementalDigest;
this.fastPreflightId = fastPreflightId; this.fastPreflightId = fastPreflightId;
this.voiceNote = voiceNote; this.voiceNote = voiceNote;
this.borderless = borderless; this.borderless = borderless;
@ -165,6 +170,11 @@ public abstract class Attachment {
return digest; return digest;
} }
@Nullable
public byte[] getIncrementalDigest() {
return incrementalDigest;
}
@Nullable @Nullable
public String getFastPreflightId() { public String getFastPreflightId() {
return fastPreflightId; return fastPreflightId;

View file

@ -33,6 +33,7 @@ public class DatabaseAttachment extends Attachment {
String key, String key,
String relay, String relay,
byte[] digest, byte[] digest,
byte[] incrementalDigest,
String fastPreflightId, String fastPreflightId,
boolean voiceNote, boolean voiceNote,
boolean borderless, boolean borderless,
@ -48,7 +49,7 @@ public class DatabaseAttachment extends Attachment {
int displayOrder, int displayOrder,
long uploadTimestamp) long uploadTimestamp)
{ {
super(contentType, transferProgress, size, fileName, cdnNumber, location, key, relay, digest, fastPreflightId, voiceNote, borderless, videoGif, width, height, quote, uploadTimestamp, caption, stickerLocator, blurHash, audioHash, transformProperties); super(contentType, transferProgress, size, fileName, cdnNumber, location, key, relay, digest, incrementalDigest, fastPreflightId, voiceNote, borderless, videoGif, width, height, quote, uploadTimestamp, caption, stickerLocator, blurHash, audioHash, transformProperties);
this.attachmentId = attachmentId; this.attachmentId = attachmentId;
this.hasData = hasData; this.hasData = hasData;
this.hasThumbnail = hasThumbnail; this.hasThumbnail = hasThumbnail;

View file

@ -11,7 +11,7 @@ import org.thoughtcrime.securesms.database.MessageTable;
public class MmsNotificationAttachment extends Attachment { public class MmsNotificationAttachment extends Attachment {
public MmsNotificationAttachment(int status, long size) { public MmsNotificationAttachment(int status, long size) {
super("application/mms", getTransferStateFromStatus(status), size, null, 0, null, null, null, null, null, false, false, false, 0, 0, false, 0, null, null, null, null, null); super("application/mms", getTransferStateFromStatus(status), size, null, 0, null, null, null, null, null, null, false, false, false, 0, 0, false, 0, null, null, null, null, null);
} }
@Nullable @Nullable

View file

@ -30,6 +30,7 @@ public class PointerAttachment extends Attachment {
@Nullable String key, @Nullable String key,
@Nullable String relay, @Nullable String relay,
@Nullable byte[] digest, @Nullable byte[] digest,
@Nullable byte[] incrementalDigest,
@Nullable String fastPreflightId, @Nullable String fastPreflightId,
boolean voiceNote, boolean voiceNote,
boolean borderless, boolean borderless,
@ -41,7 +42,7 @@ public class PointerAttachment extends Attachment {
@Nullable StickerLocator stickerLocator, @Nullable StickerLocator stickerLocator,
@Nullable BlurHash blurHash) @Nullable BlurHash blurHash)
{ {
super(contentType, transferState, size, fileName, cdnNumber, location, key, relay, digest, fastPreflightId, voiceNote, borderless, videoGif, width, height, false, uploadTimestamp, caption, stickerLocator, blurHash, null, null); super(contentType, transferState, size, fileName, cdnNumber, location, key, relay, digest, incrementalDigest, fastPreflightId, voiceNote, borderless, videoGif, width, height, false, uploadTimestamp, caption, stickerLocator, blurHash, null, null);
} }
@Nullable @Nullable
@ -112,6 +113,7 @@ public class PointerAttachment extends Attachment {
pointer.get().asPointer().getRemoteId().toString(), pointer.get().asPointer().getRemoteId().toString(),
encodedKey, null, encodedKey, null,
pointer.get().asPointer().getDigest().orElse(null), pointer.get().asPointer().getDigest().orElse(null),
pointer.get().asPointer().getincrementalDigest().orElse(null),
fastPreflightId, fastPreflightId,
pointer.get().asPointer().getVoiceNote(), pointer.get().asPointer().getVoiceNote(),
pointer.get().asPointer().isBorderless(), pointer.get().asPointer().isBorderless(),
@ -137,6 +139,7 @@ public class PointerAttachment extends Attachment {
thumbnail != null && thumbnail.asPointer().getKey() != null ? Base64.encodeBytes(thumbnail.asPointer().getKey()) : null, thumbnail != null && thumbnail.asPointer().getKey() != null ? Base64.encodeBytes(thumbnail.asPointer().getKey()) : null,
null, null,
thumbnail != null ? thumbnail.asPointer().getDigest().orElse(null) : null, thumbnail != null ? thumbnail.asPointer().getDigest().orElse(null) : null,
thumbnail != null ? thumbnail.asPointer().getincrementalDigest().orElse(null) : null,
null, null,
false, false,
false, false,
@ -166,6 +169,7 @@ public class PointerAttachment extends Attachment {
thumbnail != null && thumbnail.asPointer().getKey() != null ? Base64.encodeBytes(thumbnail.asPointer().getKey()) : null, thumbnail != null && thumbnail.asPointer().getKey() != null ? Base64.encodeBytes(thumbnail.asPointer().getKey()) : null,
null, null,
thumbnail != null ? thumbnail.asPointer().getDigest().orElse(null) : null, thumbnail != null ? thumbnail.asPointer().getDigest().orElse(null) : null,
thumbnail != null ? thumbnail.asPointer().getincrementalDigest().orElse(null) : null,
null, null,
false, false,
false, false,

View file

@ -16,7 +16,7 @@ import org.thoughtcrime.securesms.database.AttachmentTable;
public class TombstoneAttachment extends Attachment { public class TombstoneAttachment extends Attachment {
public TombstoneAttachment(@NonNull String contentType, boolean quote) { public TombstoneAttachment(@NonNull String contentType, boolean quote) {
super(contentType, AttachmentTable.TRANSFER_PROGRESS_DONE, 0, null, 0, null, null, null, null, null, false, false, false, 0, 0, quote, 0, null, null, null, null, null); super(contentType, AttachmentTable.TRANSFER_PROGRESS_DONE, 0, null, 0, null, null, null, null, null, null, false, false, false, 0, 0, quote, 0, null, null, null, null, null);
} }
@Override @Override

View file

@ -52,7 +52,7 @@ public class UriAttachment extends Attachment {
@Nullable AudioHash audioHash, @Nullable AudioHash audioHash,
@Nullable TransformProperties transformProperties) @Nullable TransformProperties transformProperties)
{ {
super(contentType, transferState, size, fileName, 0, null, null, null, null, fastPreflightId, voiceNote, borderless, videoGif, width, height, quote, 0, caption, stickerLocator, blurHash, audioHash, transformProperties); super(contentType, transferState, size, fileName, 0, null, null, null, null, null, fastPreflightId, voiceNote, borderless, videoGif, width, height, quote, 0, caption, stickerLocator, blurHash, audioHash, transformProperties);
this.dataUri = Objects.requireNonNull(dataUri); this.dataUri = Objects.requireNonNull(dataUri);
} }

View file

@ -126,6 +126,7 @@ public class AttachmentTable extends DatabaseTable {
static final String DISPLAY_ORDER = "display_order"; static final String DISPLAY_ORDER = "display_order";
static final String UPLOAD_TIMESTAMP = "upload_timestamp"; static final String UPLOAD_TIMESTAMP = "upload_timestamp";
static final String CDN_NUMBER = "cdn_number"; static final String CDN_NUMBER = "cdn_number";
static final String MAC_DIGEST = "incremental_mac_digest";
private static final String DIRECTORY = "parts"; private static final String DIRECTORY = "parts";
@ -143,7 +144,7 @@ public class AttachmentTable extends DatabaseTable {
private static final String[] PROJECTION = new String[] {ROW_ID, private static final String[] PROJECTION = new String[] {ROW_ID,
MMS_ID, CONTENT_TYPE, NAME, CONTENT_DISPOSITION, MMS_ID, CONTENT_TYPE, NAME, CONTENT_DISPOSITION,
CDN_NUMBER, CONTENT_LOCATION, DATA, CDN_NUMBER, CONTENT_LOCATION, DATA,
TRANSFER_STATE, SIZE, FILE_NAME, UNIQUE_ID, DIGEST, TRANSFER_STATE, SIZE, FILE_NAME, UNIQUE_ID, DIGEST, MAC_DIGEST,
FAST_PREFLIGHT_ID, VOICE_NOTE, BORDERLESS, VIDEO_GIF, QUOTE, DATA_RANDOM, FAST_PREFLIGHT_ID, VOICE_NOTE, BORDERLESS, VIDEO_GIF, QUOTE, DATA_RANDOM,
WIDTH, HEIGHT, CAPTION, STICKER_PACK_ID, WIDTH, HEIGHT, CAPTION, STICKER_PACK_ID,
STICKER_PACK_KEY, STICKER_ID, STICKER_EMOJI, DATA_HASH, VISUAL_HASH, STICKER_PACK_KEY, STICKER_ID, STICKER_EMOJI, DATA_HASH, VISUAL_HASH,
@ -188,7 +189,8 @@ public class AttachmentTable extends DatabaseTable {
TRANSFER_FILE + " TEXT DEFAULT NULL, " + TRANSFER_FILE + " TEXT DEFAULT NULL, " +
DISPLAY_ORDER + " INTEGER DEFAULT 0, " + DISPLAY_ORDER + " INTEGER DEFAULT 0, " +
UPLOAD_TIMESTAMP + " INTEGER DEFAULT 0, " + UPLOAD_TIMESTAMP + " INTEGER DEFAULT 0, " +
CDN_NUMBER + " INTEGER DEFAULT 0);"; CDN_NUMBER + " INTEGER DEFAULT 0, " +
MAC_DIGEST + " BLOB);";
public static final String[] CREATE_INDEXS = { public static final String[] CREATE_INDEXS = {
"CREATE INDEX IF NOT EXISTS part_mms_id_index ON " + TABLE_NAME + " (" + MMS_ID + ");", "CREATE INDEX IF NOT EXISTS part_mms_id_index ON " + TABLE_NAME + " (" + MMS_ID + ");",
@ -698,6 +700,7 @@ public class AttachmentTable extends DatabaseTable {
contentValues.put(CDN_NUMBER, sourceAttachment.getCdnNumber()); contentValues.put(CDN_NUMBER, sourceAttachment.getCdnNumber());
contentValues.put(CONTENT_LOCATION, sourceAttachment.getLocation()); contentValues.put(CONTENT_LOCATION, sourceAttachment.getLocation());
contentValues.put(DIGEST, sourceAttachment.getDigest()); contentValues.put(DIGEST, sourceAttachment.getDigest());
contentValues.put(MAC_DIGEST, sourceAttachment.getIncrementalDigest());
contentValues.put(CONTENT_DISPOSITION, sourceAttachment.getKey()); contentValues.put(CONTENT_DISPOSITION, sourceAttachment.getKey());
contentValues.put(NAME, sourceAttachment.getRelay()); contentValues.put(NAME, sourceAttachment.getRelay());
contentValues.put(SIZE, sourceAttachment.getSize()); contentValues.put(SIZE, sourceAttachment.getSize());
@ -746,6 +749,7 @@ public class AttachmentTable extends DatabaseTable {
values.put(CDN_NUMBER, attachment.getCdnNumber()); values.put(CDN_NUMBER, attachment.getCdnNumber());
values.put(CONTENT_LOCATION, attachment.getLocation()); values.put(CONTENT_LOCATION, attachment.getLocation());
values.put(DIGEST, attachment.getDigest()); values.put(DIGEST, attachment.getDigest());
values.put(MAC_DIGEST, attachment.getIncrementalDigest());
values.put(CONTENT_DISPOSITION, attachment.getKey()); values.put(CONTENT_DISPOSITION, attachment.getKey());
values.put(NAME, attachment.getRelay()); values.put(NAME, attachment.getRelay());
values.put(SIZE, attachment.getSize()); values.put(SIZE, attachment.getSize());
@ -1272,6 +1276,7 @@ public class AttachmentTable extends DatabaseTable {
object.getString(CONTENT_DISPOSITION), object.getString(CONTENT_DISPOSITION),
object.getString(NAME), object.getString(NAME),
null, null,
null,
object.getString(FAST_PREFLIGHT_ID), object.getString(FAST_PREFLIGHT_ID),
object.getInt(VOICE_NOTE) == 1, object.getInt(VOICE_NOTE) == 1,
object.getInt(BORDERLESS) == 1, object.getInt(BORDERLESS) == 1,
@ -1319,6 +1324,7 @@ public class AttachmentTable extends DatabaseTable {
cursor.getString(cursor.getColumnIndexOrThrow(CONTENT_DISPOSITION)), cursor.getString(cursor.getColumnIndexOrThrow(CONTENT_DISPOSITION)),
cursor.getString(cursor.getColumnIndexOrThrow(NAME)), cursor.getString(cursor.getColumnIndexOrThrow(NAME)),
cursor.getBlob(cursor.getColumnIndexOrThrow(DIGEST)), cursor.getBlob(cursor.getColumnIndexOrThrow(DIGEST)),
cursor.getBlob(cursor.getColumnIndexOrThrow(MAC_DIGEST)),
cursor.getString(cursor.getColumnIndexOrThrow(FAST_PREFLIGHT_ID)), cursor.getString(cursor.getColumnIndexOrThrow(FAST_PREFLIGHT_ID)),
cursor.getInt(cursor.getColumnIndexOrThrow(VOICE_NOTE)) == 1, cursor.getInt(cursor.getColumnIndexOrThrow(VOICE_NOTE)) == 1,
cursor.getInt(cursor.getColumnIndexOrThrow(BORDERLESS)) == 1, cursor.getInt(cursor.getColumnIndexOrThrow(BORDERLESS)) == 1,
@ -1385,6 +1391,7 @@ public class AttachmentTable extends DatabaseTable {
contentValues.put(CDN_NUMBER, useTemplateUpload ? template.getCdnNumber() : attachment.getCdnNumber()); contentValues.put(CDN_NUMBER, useTemplateUpload ? template.getCdnNumber() : attachment.getCdnNumber());
contentValues.put(CONTENT_LOCATION, useTemplateUpload ? template.getLocation() : attachment.getLocation()); contentValues.put(CONTENT_LOCATION, useTemplateUpload ? template.getLocation() : attachment.getLocation());
contentValues.put(DIGEST, useTemplateUpload ? template.getDigest() : attachment.getDigest()); contentValues.put(DIGEST, useTemplateUpload ? template.getDigest() : attachment.getDigest());
contentValues.put(MAC_DIGEST, useTemplateUpload ? template.getIncrementalDigest() : attachment.getIncrementalDigest());
contentValues.put(CONTENT_DISPOSITION, useTemplateUpload ? template.getKey() : attachment.getKey()); contentValues.put(CONTENT_DISPOSITION, useTemplateUpload ? template.getKey() : attachment.getKey());
contentValues.put(NAME, useTemplateUpload ? template.getRelay() : attachment.getRelay()); contentValues.put(NAME, useTemplateUpload ? template.getRelay() : attachment.getRelay());
contentValues.put(FILE_NAME, StorageUtil.getCleanFileName(attachment.getFileName())); contentValues.put(FILE_NAME, StorageUtil.getCleanFileName(attachment.getFileName()));

View file

@ -48,6 +48,7 @@ class MediaTable internal constructor(context: Context?, databaseHelper: SignalD
${AttachmentTable.TABLE_NAME}.${AttachmentTable.CAPTION}, ${AttachmentTable.TABLE_NAME}.${AttachmentTable.CAPTION},
${AttachmentTable.TABLE_NAME}.${AttachmentTable.NAME}, ${AttachmentTable.TABLE_NAME}.${AttachmentTable.NAME},
${AttachmentTable.TABLE_NAME}.${AttachmentTable.UPLOAD_TIMESTAMP}, ${AttachmentTable.TABLE_NAME}.${AttachmentTable.UPLOAD_TIMESTAMP},
${AttachmentTable.TABLE_NAME}.${AttachmentTable.MAC_DIGEST},
${MessageTable.TABLE_NAME}.${MessageTable.TYPE}, ${MessageTable.TABLE_NAME}.${MessageTable.TYPE},
${MessageTable.TABLE_NAME}.${MessageTable.DATE_SENT}, ${MessageTable.TABLE_NAME}.${MessageTable.DATE_SENT},
${MessageTable.TABLE_NAME}.${MessageTable.DATE_RECEIVED}, ${MessageTable.TABLE_NAME}.${MessageTable.DATE_RECEIVED},
@ -55,7 +56,7 @@ class MediaTable internal constructor(context: Context?, databaseHelper: SignalD
${MessageTable.TABLE_NAME}.${MessageTable.THREAD_ID}, ${MessageTable.TABLE_NAME}.${MessageTable.THREAD_ID},
${MessageTable.TABLE_NAME}.${MessageTable.FROM_RECIPIENT_ID}, ${MessageTable.TABLE_NAME}.${MessageTable.FROM_RECIPIENT_ID},
${ThreadTable.TABLE_NAME}.${ThreadTable.RECIPIENT_ID} as $THREAD_RECIPIENT_ID ${ThreadTable.TABLE_NAME}.${ThreadTable.RECIPIENT_ID} as $THREAD_RECIPIENT_ID
FROM FROM
${AttachmentTable.TABLE_NAME} ${AttachmentTable.TABLE_NAME}
LEFT JOIN ${MessageTable.TABLE_NAME} ON ${AttachmentTable.TABLE_NAME}.${AttachmentTable.MMS_ID} = ${MessageTable.TABLE_NAME}.${MessageTable.ID} LEFT JOIN ${MessageTable.TABLE_NAME} ON ${AttachmentTable.TABLE_NAME}.${AttachmentTable.MMS_ID} = ${MessageTable.TABLE_NAME}.${MessageTable.ID}
LEFT JOIN ${ThreadTable.TABLE_NAME} ON ${ThreadTable.TABLE_NAME}.${ThreadTable.ID} = ${MessageTable.TABLE_NAME}.${MessageTable.THREAD_ID} LEFT JOIN ${ThreadTable.TABLE_NAME} ON ${ThreadTable.TABLE_NAME}.${ThreadTable.ID} = ${MessageTable.TABLE_NAME}.${MessageTable.THREAD_ID}

View file

@ -53,6 +53,7 @@ import org.thoughtcrime.securesms.database.helpers.migration.V194_KyberPreKeyMig
import org.thoughtcrime.securesms.database.helpers.migration.V195_GroupMemberForeignKeyMigration import org.thoughtcrime.securesms.database.helpers.migration.V195_GroupMemberForeignKeyMigration
import org.thoughtcrime.securesms.database.helpers.migration.V196_BackCallLinksWithRecipientV2 import org.thoughtcrime.securesms.database.helpers.migration.V196_BackCallLinksWithRecipientV2
import org.thoughtcrime.securesms.database.helpers.migration.V197_DropAvatarColorFromCallLinks import org.thoughtcrime.securesms.database.helpers.migration.V197_DropAvatarColorFromCallLinks
import org.thoughtcrime.securesms.database.helpers.migration.V198_AddMacDigestColumn
/** /**
* Contains all of the database migrations for [SignalDatabase]. Broken into a separate file for cleanliness. * Contains all of the database migrations for [SignalDatabase]. Broken into a separate file for cleanliness.
@ -61,7 +62,7 @@ object SignalDatabaseMigrations {
val TAG: String = Log.tag(SignalDatabaseMigrations.javaClass) val TAG: String = Log.tag(SignalDatabaseMigrations.javaClass)
const val DATABASE_VERSION = 197 const val DATABASE_VERSION = 198
@JvmStatic @JvmStatic
fun migrate(context: Application, db: SQLiteDatabase, oldVersion: Int, newVersion: Int) { fun migrate(context: Application, db: SQLiteDatabase, oldVersion: Int, newVersion: Int) {
@ -260,6 +261,10 @@ object SignalDatabaseMigrations {
if (oldVersion < 197) { if (oldVersion < 197) {
V197_DropAvatarColorFromCallLinks.migrate(context, db, oldVersion, newVersion) V197_DropAvatarColorFromCallLinks.migrate(context, db, oldVersion, newVersion)
} }
if (oldVersion < 198) {
V198_AddMacDigestColumn.migrate(context, db, oldVersion, newVersion)
}
} }
@JvmStatic @JvmStatic

View file

@ -0,0 +1,19 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.thoughtcrime.securesms.database.helpers.migration
import android.app.Application
import net.zetetic.database.sqlcipher.SQLiteDatabase
/**
* New field migration.
*/
@Suppress("ClassName")
object V198_AddMacDigestColumn : SignalDatabaseMigration {
override fun migrate(context: Application, db: SQLiteDatabase, oldVersion: Int, newVersion: Int) {
db.execSQL("ALTER TABLE part ADD COLUMN incremental_mac_digest BLOB")
}
}

View file

@ -247,6 +247,7 @@ public final class AttachmentDownloadJob extends BaseJob {
Optional.empty(), Optional.empty(),
0, 0, 0, 0,
Optional.ofNullable(attachment.getDigest()), Optional.ofNullable(attachment.getDigest()),
Optional.ofNullable(attachment.getIncrementalDigest()),
Optional.ofNullable(attachment.getFileName()), Optional.ofNullable(attachment.getFileName()),
attachment.isVoiceNote(), attachment.isVoiceNote(),
attachment.isBorderless(), attachment.isBorderless(),

View file

@ -86,7 +86,7 @@ public final class AvatarGroupsV1DownloadJob extends BaseJob {
attachment.deleteOnExit(); attachment.deleteOnExit();
SignalServiceMessageReceiver receiver = ApplicationDependencies.getSignalServiceMessageReceiver(); SignalServiceMessageReceiver receiver = ApplicationDependencies.getSignalServiceMessageReceiver();
SignalServiceAttachmentPointer pointer = new SignalServiceAttachmentPointer(0, new SignalServiceAttachmentRemoteId(avatarId), contentType, key, Optional.of(0), Optional.empty(), 0, 0, digest, fileName, false, false, false, Optional.empty(), Optional.empty(), System.currentTimeMillis()); SignalServiceAttachmentPointer pointer = new SignalServiceAttachmentPointer(0, new SignalServiceAttachmentRemoteId(avatarId), contentType, key, Optional.of(0), Optional.empty(), 0, 0, digest, Optional.empty(), fileName, false, false, false, Optional.empty(), Optional.empty(), System.currentTimeMillis());
InputStream inputStream = receiver.retrieveAttachment(pointer, attachment, AvatarHelper.AVATAR_DOWNLOAD_FAILSAFE_MAX_SIZE); InputStream inputStream = receiver.retrieveAttachment(pointer, attachment, AvatarHelper.AVATAR_DOWNLOAD_FAILSAFE_MAX_SIZE);
AvatarHelper.setAvatar(context, record.get().getRecipientId(), inputStream); AvatarHelper.setAvatar(context, record.get().getRecipientId(), inputStream);

View file

@ -283,6 +283,7 @@ public abstract class PushSendJob extends SendJob {
width, width,
height, height,
Optional.ofNullable(attachment.getDigest()), Optional.ofNullable(attachment.getDigest()),
Optional.ofNullable(attachment.getIncrementalDigest()),
Optional.ofNullable(attachment.getFileName()), Optional.ofNullable(attachment.getFileName()),
attachment.isVoiceNote(), attachment.isVoiceNote(),
attachment.isBorderless(), attachment.isBorderless(),

View file

@ -22,22 +22,24 @@ class AttachmentStreamLocalUriFetcher implements DataFetcher<InputStream> {
private final File attachment; private final File attachment;
private final byte[] key; private final byte[] key;
private final Optional<byte[]> digest; private final Optional<byte[]> digest;
private final Optional<byte[]> incrementalDigest;
private final long plaintextLength; private final long plaintextLength;
private InputStream is; private InputStream is;
AttachmentStreamLocalUriFetcher(File attachment, long plaintextLength, byte[] key, Optional<byte[]> digest) { AttachmentStreamLocalUriFetcher(File attachment, long plaintextLength, byte[] key, Optional<byte[]> digest, Optional<byte[]> incrementalDigest) {
this.attachment = attachment; this.attachment = attachment;
this.plaintextLength = plaintextLength; this.plaintextLength = plaintextLength;
this.digest = digest; this.digest = digest;
this.key = key; this.incrementalDigest = incrementalDigest;
this.key = key;
} }
@Override @Override
public void loadData(@NonNull Priority priority, @NonNull DataCallback<? super InputStream> callback) { public void loadData(@NonNull Priority priority, @NonNull DataCallback<? super InputStream> callback) {
try { try {
if (!digest.isPresent()) throw new InvalidMessageException("No attachment digest!"); if (!digest.isPresent()) throw new InvalidMessageException("No attachment digest!");
is = AttachmentCipherInputStream.createForAttachment(attachment, plaintextLength, key, digest.get()); is = AttachmentCipherInputStream.createForAttachment(attachment, plaintextLength, key, digest.get(), incrementalDigest.get());
callback.onDataReady(is); callback.onDataReady(is);
} catch (IOException | InvalidMessageException e) { } catch (IOException | InvalidMessageException e) {
callback.onLoadFailed(e); callback.onLoadFailed(e);

View file

@ -20,7 +20,7 @@ public class AttachmentStreamUriLoader implements ModelLoader<AttachmentModel, I
@Override @Override
public @Nullable LoadData<InputStream> buildLoadData(@NonNull AttachmentModel attachmentModel, int width, int height, @NonNull Options options) { public @Nullable LoadData<InputStream> buildLoadData(@NonNull AttachmentModel attachmentModel, int width, int height, @NonNull Options options) {
return new LoadData<>(attachmentModel, new AttachmentStreamLocalUriFetcher(attachmentModel.attachment, attachmentModel.plaintextLength, attachmentModel.key, attachmentModel.digest)); return new LoadData<>(attachmentModel, new AttachmentStreamLocalUriFetcher(attachmentModel.attachment, attachmentModel.plaintextLength, attachmentModel.key, attachmentModel.digest, attachmentModel.incrementalDigest));
} }
@Override @Override
@ -45,15 +45,20 @@ public class AttachmentStreamUriLoader implements ModelLoader<AttachmentModel, I
public @NonNull File attachment; public @NonNull File attachment;
public @NonNull byte[] key; public @NonNull byte[] key;
public @NonNull Optional<byte[]> digest; public @NonNull Optional<byte[]> digest;
public @NonNull Optional<byte[]> incrementalDigest;
public long plaintextLength; public long plaintextLength;
public AttachmentModel(@NonNull File attachment, @NonNull byte[] key, public AttachmentModel(@NonNull File attachment,
long plaintextLength, @NonNull Optional<byte[]> digest) @NonNull byte[] key,
long plaintextLength,
@NonNull Optional<byte[]> digest,
@NonNull Optional<byte[]> incrementalDigest)
{ {
this.attachment = attachment; this.attachment = attachment;
this.key = key; this.key = key;
this.digest = digest; this.digest = digest;
this.plaintextLength = plaintextLength; this.incrementalDigest = incrementalDigest;
this.plaintextLength = plaintextLength;
} }
@Override @Override

View file

@ -44,6 +44,7 @@ object ReleaseChannel {
mediaWidth, mediaWidth,
mediaHeight, mediaHeight,
Optional.empty(), Optional.empty(),
Optional.empty(),
Optional.of(media), Optional.of(media),
false, false,
false, false,

View file

@ -40,6 +40,7 @@ object FakeMessageRecords {
key: String = "", key: String = "",
relay: String = "", relay: String = "",
digest: ByteArray = byteArrayOf(), digest: ByteArray = byteArrayOf(),
incrementalDigest: ByteArray = byteArrayOf(),
fastPreflightId: String = "", fastPreflightId: String = "",
voiceNote: Boolean = false, voiceNote: Boolean = false,
borderless: Boolean = false, borderless: Boolean = false,
@ -69,6 +70,7 @@ object FakeMessageRecords {
key, key,
relay, relay,
digest, digest,
incrementalDigest,
fastPreflightId, fastPreflightId,
voiceNote, voiceNote,
borderless, borderless,

View file

@ -242,6 +242,7 @@ class UploadDependencyGraphTest {
attachment.key, attachment.key,
attachment.relay, attachment.relay,
attachment.digest, attachment.digest,
attachment.incrementalDigest,
attachment.fastPreflightId, attachment.fastPreflightId,
attachment.isVoiceNote, attachment.isVoiceNote,
attachment.isBorderless, attachment.isBorderless,

View file

@ -154,7 +154,7 @@ public class SignalServiceMessageReceiver {
if (!pointer.getDigest().isPresent()) throw new InvalidMessageException("No attachment digest!"); if (!pointer.getDigest().isPresent()) throw new InvalidMessageException("No attachment digest!");
socket.retrieveAttachment(pointer.getCdnNumber(), pointer.getRemoteId(), destination, maxSizeBytes, listener); socket.retrieveAttachment(pointer.getCdnNumber(), pointer.getRemoteId(), destination, maxSizeBytes, listener);
return AttachmentCipherInputStream.createForAttachment(destination, pointer.getSize().orElse(0), pointer.getKey(), pointer.getDigest().get()); return AttachmentCipherInputStream.createForAttachment(destination, pointer.getSize().orElse(0), pointer.getKey(), pointer.getDigest().get(), pointer.getincrementalDigest().orElse(new byte[0]));
} }
public InputStream retrieveSticker(byte[] packId, byte[] packKey, int stickerId) public InputStream retrieveSticker(byte[] packId, byte[] packKey, int stickerId)

View file

@ -87,6 +87,7 @@ import org.whispersystems.signalservice.api.util.Uint64Util;
import org.whispersystems.signalservice.api.util.UuidUtil; import org.whispersystems.signalservice.api.util.UuidUtil;
import org.whispersystems.signalservice.api.websocket.WebSocketUnavailableException; import org.whispersystems.signalservice.api.websocket.WebSocketUnavailableException;
import org.whispersystems.signalservice.internal.configuration.SignalServiceConfiguration; import org.whispersystems.signalservice.internal.configuration.SignalServiceConfiguration;
import org.whispersystems.signalservice.internal.crypto.AttachmentDigest;
import org.whispersystems.signalservice.internal.crypto.PaddingInputStream; import org.whispersystems.signalservice.internal.crypto.PaddingInputStream;
import org.whispersystems.signalservice.internal.push.AttachmentV2UploadAttributes; import org.whispersystems.signalservice.internal.push.AttachmentV2UploadAttributes;
import org.whispersystems.signalservice.internal.push.AttachmentV3UploadAttributes; import org.whispersystems.signalservice.internal.push.AttachmentV3UploadAttributes;
@ -762,7 +763,7 @@ public class SignalServiceMessageSender {
v2UploadAttributes = socket.getAttachmentV2UploadAttributes(); v2UploadAttributes = socket.getAttachmentV2UploadAttributes();
} }
Pair<Long, byte[]> attachmentIdAndDigest = socket.uploadAttachment(attachmentData, v2UploadAttributes); Pair<Long, AttachmentDigest> attachmentIdAndDigest = socket.uploadAttachment(attachmentData, v2UploadAttributes);
return new SignalServiceAttachmentPointer(0, return new SignalServiceAttachmentPointer(0,
new SignalServiceAttachmentRemoteId(attachmentIdAndDigest.first()), new SignalServiceAttachmentRemoteId(attachmentIdAndDigest.first()),
@ -771,7 +772,8 @@ public class SignalServiceMessageSender {
Optional.of(Util.toIntExact(attachment.getLength())), Optional.of(Util.toIntExact(attachment.getLength())),
attachment.getPreview(), attachment.getPreview(),
attachment.getWidth(), attachment.getHeight(), attachment.getWidth(), attachment.getHeight(),
Optional.of(attachmentIdAndDigest.second()), Optional.of(attachmentIdAndDigest.second().getDigest()),
Optional.of(attachmentIdAndDigest.second().getIncrementalDigest()),
attachment.getFileName(), attachment.getFileName(),
attachment.getVoiceNote(), attachment.getVoiceNote(),
attachment.isBorderless(), attachment.isBorderless(),
@ -811,7 +813,7 @@ public class SignalServiceMessageSender {
} }
private SignalServiceAttachmentPointer uploadAttachmentV3(SignalServiceAttachmentStream attachment, byte[] attachmentKey, PushAttachmentData attachmentData) throws IOException { private SignalServiceAttachmentPointer uploadAttachmentV3(SignalServiceAttachmentStream attachment, byte[] attachmentKey, PushAttachmentData attachmentData) throws IOException {
byte[] digest = socket.uploadAttachment(attachmentData); AttachmentDigest digest = socket.uploadAttachment(attachmentData);
return new SignalServiceAttachmentPointer(attachmentData.getResumableUploadSpec().getCdnNumber(), return new SignalServiceAttachmentPointer(attachmentData.getResumableUploadSpec().getCdnNumber(),
new SignalServiceAttachmentRemoteId(attachmentData.getResumableUploadSpec().getCdnKey()), new SignalServiceAttachmentRemoteId(attachmentData.getResumableUploadSpec().getCdnKey()),
attachment.getContentType(), attachment.getContentType(),
@ -820,7 +822,8 @@ public class SignalServiceMessageSender {
attachment.getPreview(), attachment.getPreview(),
attachment.getWidth(), attachment.getWidth(),
attachment.getHeight(), attachment.getHeight(),
Optional.of(digest), Optional.of(digest.getDigest()),
Optional.ofNullable(digest.getIncrementalDigest()),
attachment.getFileName(), attachment.getFileName(),
attachment.getVoiceNote(), attachment.getVoiceNote(),
attachment.isBorderless(), attachment.isBorderless(),

View file

@ -8,6 +8,8 @@ package org.whispersystems.signalservice.api.crypto;
import org.signal.libsignal.protocol.InvalidMacException; import org.signal.libsignal.protocol.InvalidMacException;
import org.signal.libsignal.protocol.InvalidMessageException; import org.signal.libsignal.protocol.InvalidMessageException;
import org.signal.libsignal.protocol.incrementalmac.ChunkSizeChoice;
import org.signal.libsignal.protocol.incrementalmac.IncrementalMacInputStream;
import org.signal.libsignal.protocol.kdf.HKDFv3; import org.signal.libsignal.protocol.kdf.HKDFv3;
import org.whispersystems.signalservice.internal.util.ContentLengthInputStream; import org.whispersystems.signalservice.internal.util.ContentLengthInputStream;
import org.whispersystems.signalservice.internal.util.Util; import org.whispersystems.signalservice.internal.util.Util;
@ -51,7 +53,7 @@ public class AttachmentCipherInputStream extends FilterInputStream {
private long totalRead; private long totalRead;
private byte[] overflowBuffer; private byte[] overflowBuffer;
public static InputStream createForAttachment(File file, long plaintextLength, byte[] combinedKeyMaterial, byte[] digest) public static InputStream createForAttachment(File file, long plaintextLength, byte[] combinedKeyMaterial, byte[] digest, byte[] incrementalDigest)
throws InvalidMessageException, IOException throws InvalidMessageException, IOException
{ {
try { try {
@ -71,7 +73,18 @@ public class AttachmentCipherInputStream extends FilterInputStream {
verifyMac(fin, file.length(), mac, digest); verifyMac(fin, file.length(), mac, digest);
} }
InputStream inputStream = new AttachmentCipherInputStream(new FileInputStream(file), parts[0], file.length() - BLOCK_SIZE - mac.getMacLength()); final FileInputStream innerStream = new FileInputStream(file);
boolean hasIncrementalMac = incrementalDigest != null && incrementalDigest.length > 0;
InputStream wrap = !hasIncrementalMac ? innerStream
: new IncrementalMacInputStream(
innerStream,
parts[1],
ChunkSizeChoice.inferChunkSize(Math.max(Math.toIntExact(file.length()), 1)),
incrementalDigest);
InputStream inputStream = new AttachmentCipherInputStream(wrap, parts[0], file.length() - BLOCK_SIZE - mac.getMacLength());
if (plaintextLength != 0) { if (plaintextLength != 0) {
inputStream = new ContentLengthInputStream(inputStream, plaintextLength); inputStream = new ContentLengthInputStream(inputStream, plaintextLength);

View file

@ -25,6 +25,7 @@ public class SignalServiceAttachmentPointer extends SignalServiceAttachment {
private final Optional<Integer> size; private final Optional<Integer> size;
private final Optional<byte[]> preview; private final Optional<byte[]> preview;
private final Optional<byte[]> digest; private final Optional<byte[]> digest;
private final Optional<byte[]> incrementalDigest;
private final Optional<String> fileName; private final Optional<String> fileName;
private final boolean voiceNote; private final boolean voiceNote;
private final boolean borderless; private final boolean borderless;
@ -44,6 +45,7 @@ public class SignalServiceAttachmentPointer extends SignalServiceAttachment {
int width, int width,
int height, int height,
Optional<byte[]> digest, Optional<byte[]> digest,
Optional<byte[]> incrementalDigest,
Optional<String> fileName, Optional<String> fileName,
boolean voiceNote, boolean voiceNote,
boolean borderless, boolean borderless,
@ -53,21 +55,22 @@ public class SignalServiceAttachmentPointer extends SignalServiceAttachment {
long uploadTimestamp) long uploadTimestamp)
{ {
super(contentType); super(contentType);
this.cdnNumber = cdnNumber; this.cdnNumber = cdnNumber;
this.remoteId = remoteId; this.remoteId = remoteId;
this.key = key; this.key = key;
this.size = size; this.size = size;
this.preview = preview; this.preview = preview;
this.width = width; this.width = width;
this.height = height; this.height = height;
this.digest = digest; this.digest = digest;
this.fileName = fileName; this.incrementalDigest = incrementalDigest;
this.voiceNote = voiceNote; this.fileName = fileName;
this.borderless = borderless; this.voiceNote = voiceNote;
this.caption = caption; this.borderless = borderless;
this.blurHash = blurHash; this.caption = caption;
this.uploadTimestamp = uploadTimestamp; this.blurHash = blurHash;
this.gif = gif; this.uploadTimestamp = uploadTimestamp;
this.gif = gif;
} }
public int getCdnNumber() { public int getCdnNumber() {
@ -108,6 +111,10 @@ public class SignalServiceAttachmentPointer extends SignalServiceAttachment {
return digest; return digest;
} }
public Optional<byte[]> getincrementalDigest() {
return incrementalDigest;
}
public boolean getVoiceNote() { public boolean getVoiceNote() {
return voiceNote; return voiceNote;
} }

View file

@ -25,6 +25,7 @@ public final class AttachmentPointerUtil {
pointer.hasThumbnail() ? Optional.of(pointer.getThumbnail().toByteArray()): Optional.empty(), pointer.hasThumbnail() ? Optional.of(pointer.getThumbnail().toByteArray()): Optional.empty(),
pointer.getWidth(), pointer.getHeight(), pointer.getWidth(), pointer.getHeight(),
pointer.hasDigest() ? Optional.of(pointer.getDigest().toByteArray()) : Optional.empty(), pointer.hasDigest() ? Optional.of(pointer.getDigest().toByteArray()) : Optional.empty(),
pointer.hasIncrementalDigest() ? Optional.of(pointer.getIncrementalDigest().toByteArray()) : Optional.empty(),
pointer.hasFileName() ? Optional.of(pointer.getFileName()) : Optional.empty(), pointer.hasFileName() ? Optional.of(pointer.getFileName()) : Optional.empty(),
(pointer.getFlags() & FlagUtil.toBinaryFlag(SignalServiceProtos.AttachmentPointer.Flags.VOICE_MESSAGE_VALUE)) != 0, (pointer.getFlags() & FlagUtil.toBinaryFlag(SignalServiceProtos.AttachmentPointer.Flags.VOICE_MESSAGE_VALUE)) != 0,
(pointer.getFlags() & FlagUtil.toBinaryFlag(SignalServiceProtos.AttachmentPointer.Flags.BORDERLESS_VALUE)) != 0, (pointer.getFlags() & FlagUtil.toBinaryFlag(SignalServiceProtos.AttachmentPointer.Flags.BORDERLESS_VALUE)) != 0,

View file

@ -0,0 +1,8 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.signalservice.internal.crypto
data class AttachmentDigest(val digest: ByteArray, val incrementalDigest: ByteArray?)

View file

@ -112,6 +112,7 @@ import org.whispersystems.signalservice.internal.configuration.SignalUrl;
import org.whispersystems.signalservice.internal.contacts.entities.KeyBackupRequest; import org.whispersystems.signalservice.internal.contacts.entities.KeyBackupRequest;
import org.whispersystems.signalservice.internal.contacts.entities.KeyBackupResponse; import org.whispersystems.signalservice.internal.contacts.entities.KeyBackupResponse;
import org.whispersystems.signalservice.internal.contacts.entities.TokenResponse; import org.whispersystems.signalservice.internal.contacts.entities.TokenResponse;
import org.whispersystems.signalservice.internal.crypto.AttachmentDigest;
import org.whispersystems.signalservice.internal.push.exceptions.ForbiddenException; import org.whispersystems.signalservice.internal.push.exceptions.ForbiddenException;
import org.whispersystems.signalservice.internal.push.exceptions.GroupExistsException; import org.whispersystems.signalservice.internal.push.exceptions.GroupExistsException;
import org.whispersystems.signalservice.internal.push.exceptions.GroupMismatchedDevicesException; import org.whispersystems.signalservice.internal.push.exceptions.GroupMismatchedDevicesException;
@ -1345,7 +1346,7 @@ public class PushServiceSocket {
} }
} }
public byte[] uploadGroupV2Avatar(byte[] avatarCipherText, AvatarUploadAttributes uploadAttributes) public AttachmentDigest uploadGroupV2Avatar(byte[] avatarCipherText, AvatarUploadAttributes uploadAttributes)
throws IOException throws IOException
{ {
return uploadToCdn0(AVATAR_UPLOAD_PATH, uploadAttributes.getAcl(), uploadAttributes.getKey(), return uploadToCdn0(AVATAR_UPLOAD_PATH, uploadAttributes.getAcl(), uploadAttributes.getKey(),
@ -1358,17 +1359,17 @@ public class PushServiceSocket {
null, null); null, null);
} }
public Pair<Long, byte[]> uploadAttachment(PushAttachmentData attachment, AttachmentV2UploadAttributes uploadAttributes) public Pair<Long, AttachmentDigest> uploadAttachment(PushAttachmentData attachment, AttachmentV2UploadAttributes uploadAttributes)
throws PushNetworkException, NonSuccessfulResponseCodeException throws PushNetworkException, NonSuccessfulResponseCodeException
{ {
long id = Long.parseLong(uploadAttributes.getAttachmentId()); long id = Long.parseLong(uploadAttributes.getAttachmentId());
byte[] digest = uploadToCdn0(ATTACHMENT_UPLOAD_PATH, uploadAttributes.getAcl(), uploadAttributes.getKey(), AttachmentDigest digest = uploadToCdn0(ATTACHMENT_UPLOAD_PATH, uploadAttributes.getAcl(), uploadAttributes.getKey(),
uploadAttributes.getPolicy(), uploadAttributes.getAlgorithm(), uploadAttributes.getPolicy(), uploadAttributes.getAlgorithm(),
uploadAttributes.getCredential(), uploadAttributes.getDate(), uploadAttributes.getCredential(), uploadAttributes.getDate(),
uploadAttributes.getSignature(), attachment.getData(), uploadAttributes.getSignature(), attachment.getData(),
"application/octet-stream", attachment.getDataSize(), "application/octet-stream", attachment.getDataSize(),
attachment.getOutputStreamFactory(), attachment.getListener(), attachment.getOutputStreamFactory(), attachment.getListener(),
attachment.getCancelationSignal()); attachment.getCancelationSignal());
return new Pair<>(id, digest); return new Pair<>(id, digest);
} }
@ -1382,7 +1383,7 @@ public class PushServiceSocket {
System.currentTimeMillis() + CDN2_RESUMABLE_LINK_LIFETIME_MILLIS); System.currentTimeMillis() + CDN2_RESUMABLE_LINK_LIFETIME_MILLIS);
} }
public byte[] uploadAttachment(PushAttachmentData attachment) throws IOException { public AttachmentDigest uploadAttachment(PushAttachmentData attachment) throws IOException {
if (attachment.getResumableUploadSpec() == null || attachment.getResumableUploadSpec().getExpirationTimestamp() < System.currentTimeMillis()) { if (attachment.getResumableUploadSpec() == null || attachment.getResumableUploadSpec().getExpirationTimestamp() < System.currentTimeMillis()) {
throw new ResumeLocationInvalidException(); throw new ResumeLocationInvalidException();
@ -1472,11 +1473,11 @@ public class PushServiceSocket {
} }
} }
private byte[] uploadToCdn0(String path, String acl, String key, String policy, String algorithm, private AttachmentDigest uploadToCdn0(String path, String acl, String key, String policy, String algorithm,
String credential, String date, String signature, String credential, String date, String signature,
InputStream data, String contentType, long length, InputStream data, String contentType, long length,
OutputStreamFactory outputStreamFactory, ProgressListener progressListener, OutputStreamFactory outputStreamFactory, ProgressListener progressListener,
CancelationSignal cancelationSignal) CancelationSignal cancelationSignal)
throws PushNetworkException, NonSuccessfulResponseCodeException throws PushNetworkException, NonSuccessfulResponseCodeException
{ {
ConnectionHolder connectionHolder = getRandom(cdnClientsMap.get(0), random); ConnectionHolder connectionHolder = getRandom(cdnClientsMap.get(0), random);
@ -1516,7 +1517,7 @@ public class PushServiceSocket {
} }
try (Response response = call.execute()) { try (Response response = call.execute()) {
if (response.isSuccessful()) return file.getTransmittedDigest(); if (response.isSuccessful()) return file.getAttachmentDigest();
else throw new NonSuccessfulResponseCodeException(response.code(), "Response: " + response); else throw new NonSuccessfulResponseCodeException(response.code(), "Response: " + response);
} catch (PushNetworkException | NonSuccessfulResponseCodeException e) { } catch (PushNetworkException | NonSuccessfulResponseCodeException e) {
throw e; throw e;
@ -1577,7 +1578,7 @@ public class PushServiceSocket {
} }
} }
private byte[] uploadToCdn2(String resumableUrl, InputStream data, String contentType, long length, OutputStreamFactory outputStreamFactory, ProgressListener progressListener, CancelationSignal cancelationSignal) throws IOException { private AttachmentDigest uploadToCdn2(String resumableUrl, InputStream data, String contentType, long length, OutputStreamFactory outputStreamFactory, ProgressListener progressListener, CancelationSignal cancelationSignal) throws IOException {
ConnectionHolder connectionHolder = getRandom(cdnClientsMap.get(2), random); ConnectionHolder connectionHolder = getRandom(cdnClientsMap.get(2), random);
OkHttpClient okHttpClient = connectionHolder.getClient() OkHttpClient okHttpClient = connectionHolder.getClient()
.newBuilder() .newBuilder()
@ -1593,7 +1594,7 @@ public class PushServiceSocket {
try (NowhereBufferedSink buffer = new NowhereBufferedSink()) { try (NowhereBufferedSink buffer = new NowhereBufferedSink()) {
file.writeTo(buffer); file.writeTo(buffer);
} }
return file.getTransmittedDigest(); return file.getAttachmentDigest();
} }
Request.Builder request = new Request.Builder().url(buildConfiguredUrl(connectionHolder, resumableUrl)) Request.Builder request = new Request.Builder().url(buildConfiguredUrl(connectionHolder, resumableUrl))
@ -1611,7 +1612,7 @@ public class PushServiceSocket {
} }
try (Response response = call.execute()) { try (Response response = call.execute()) {
if (response.isSuccessful()) return file.getTransmittedDigest(); if (response.isSuccessful()) return file.getAttachmentDigest();
else throw new NonSuccessfulResponseCodeException(response.code(), "Response: " + response); else throw new NonSuccessfulResponseCodeException(response.code(), "Response: " + response);
} catch (PushNetworkException | NonSuccessfulResponseCodeException e) { } catch (PushNetworkException | NonSuccessfulResponseCodeException e) {
throw e; throw e;

View file

@ -1,25 +0,0 @@
package org.whispersystems.signalservice.internal.push.http;
import org.whispersystems.signalservice.api.crypto.AttachmentCipherOutputStream;
import org.whispersystems.signalservice.api.crypto.DigestingOutputStream;
import java.io.IOException;
import java.io.OutputStream;
public class AttachmentCipherOutputStreamFactory implements OutputStreamFactory {
private final byte[] key;
private final byte[] iv;
public AttachmentCipherOutputStreamFactory(byte[] key, byte[] iv) {
this.key = key;
this.iv = iv;
}
@Override
public DigestingOutputStream createFor(OutputStream wrap) throws IOException {
return new AttachmentCipherOutputStream(key, iv, wrap);
}
}

View file

@ -0,0 +1,40 @@
package org.whispersystems.signalservice.internal.push.http
import org.signal.libsignal.protocol.incrementalmac.ChunkSizeChoice
import org.signal.libsignal.protocol.incrementalmac.IncrementalMacOutputStream
import org.whispersystems.signalservice.api.crypto.AttachmentCipherOutputStream
import org.whispersystems.signalservice.api.crypto.DigestingOutputStream
import java.io.IOException
import java.io.OutputStream
/**
* Creates [AttachmentCipherOutputStream] using the provided [key] and [iv].
*
* [createFor] is straightforward, and is the legacy behavior.
* [createIncrementalFor] first wraps the stream in an [IncrementalMacOutputStream] to calculate MAC digests on chunks as the stream is written to.
*
* @property key
* @property iv
*/
class AttachmentCipherOutputStreamFactory(private val key: ByteArray, private val iv: ByteArray) : OutputStreamFactory {
companion object {
private const val AES_KEY_LENGTH = 32
}
@Throws(IOException::class)
override fun createFor(wrap: OutputStream): DigestingOutputStream {
return AttachmentCipherOutputStream(key, iv, wrap)
}
@Throws(IOException::class)
fun createIncrementalFor(wrap: OutputStream?, length: Long, incrementalDigestOut: OutputStream?): DigestingOutputStream {
if (length > Int.MAX_VALUE) {
throw IllegalArgumentException("Attachment length overflows int!")
}
val privateKey = key.sliceArray(AES_KEY_LENGTH until key.size)
val chunkSizeChoice = ChunkSizeChoice.inferChunkSize(length.toInt().coerceAtLeast(1))
val incrementalStream = IncrementalMacOutputStream(wrap, privateKey, chunkSizeChoice, incrementalDigestOut)
return createFor(incrementalStream)
}
}

View file

@ -1,87 +0,0 @@
package org.whispersystems.signalservice.internal.push.http;
import org.whispersystems.signalservice.api.crypto.DigestingOutputStream;
import org.whispersystems.signalservice.api.crypto.SkippingOutputStream;
import org.whispersystems.signalservice.api.messages.SignalServiceAttachment.ProgressListener;
import org.whispersystems.signalservice.api.util.Preconditions;
import java.io.IOException;
import java.io.InputStream;
import okhttp3.MediaType;
import okhttp3.RequestBody;
import okio.BufferedSink;
public class DigestingRequestBody extends RequestBody {
private final InputStream inputStream;
private final OutputStreamFactory outputStreamFactory;
private final String contentType;
private final long contentLength;
private final ProgressListener progressListener;
private final CancelationSignal cancelationSignal;
private final long contentStart;
private byte[] digest;
public DigestingRequestBody(InputStream inputStream,
OutputStreamFactory outputStreamFactory,
String contentType, long contentLength,
ProgressListener progressListener,
CancelationSignal cancelationSignal,
long contentStart)
{
Preconditions.checkArgument(contentLength >= contentStart);
Preconditions.checkArgument(contentStart >= 0);
this.inputStream = inputStream;
this.outputStreamFactory = outputStreamFactory;
this.contentType = contentType;
this.contentLength = contentLength;
this.progressListener = progressListener;
this.cancelationSignal = cancelationSignal;
this.contentStart = contentStart;
}
@Override
public MediaType contentType() {
return MediaType.parse(contentType);
}
@Override
public void writeTo(BufferedSink sink) throws IOException {
DigestingOutputStream outputStream = outputStreamFactory.createFor(new SkippingOutputStream(contentStart, sink.outputStream()));
byte[] buffer = new byte[8192];
int read;
long total = 0;
while ((read = inputStream.read(buffer, 0, buffer.length)) != -1) {
if (cancelationSignal != null && cancelationSignal.isCanceled()) {
throw new IOException("Canceled!");
}
outputStream.write(buffer, 0, read);
total += read;
if (progressListener != null) {
progressListener.onAttachmentProgress(contentLength, total);
}
}
outputStream.flush();
digest = outputStream.getTransmittedDigest();
}
@Override
public long contentLength() {
if (contentLength > 0) return contentLength - contentStart;
else return -1;
}
public byte[] getTransmittedDigest() {
return digest;
}
}

View file

@ -0,0 +1,80 @@
package org.whispersystems.signalservice.internal.push.http
import okhttp3.MediaType
import okhttp3.RequestBody
import okio.BufferedSink
import org.whispersystems.signalservice.api.crypto.DigestingOutputStream
import org.whispersystems.signalservice.api.crypto.SkippingOutputStream
import org.whispersystems.signalservice.api.messages.SignalServiceAttachment
import org.whispersystems.signalservice.internal.crypto.AttachmentDigest
import java.io.ByteArrayOutputStream
import java.io.IOException
import java.io.InputStream
/**
* This [RequestBody] encrypts the data written to it before it is sent.
*/
class DigestingRequestBody(
private val inputStream: InputStream,
private val outputStreamFactory: OutputStreamFactory,
private val contentType: String,
private val contentLength: Long,
private val progressListener: SignalServiceAttachment.ProgressListener?,
private val cancelationSignal: CancelationSignal?,
private val contentStart: Long
) : RequestBody() {
lateinit var transmittedDigest: ByteArray
private set
var incrementalDigest: ByteArray? = null
private set
init {
require(contentLength >= contentStart)
require(contentStart >= 0)
}
override fun contentType(): MediaType? {
return MediaType.parse(contentType)
}
@Throws(IOException::class)
override fun writeTo(sink: BufferedSink) {
val digestStream = ByteArrayOutputStream()
val inner = SkippingOutputStream(contentStart, sink.outputStream())
val outputStream: DigestingOutputStream = if (outputStreamFactory is AttachmentCipherOutputStreamFactory) {
outputStreamFactory.createIncrementalFor(inner, contentLength, digestStream)
} else {
outputStreamFactory.createFor(inner)
}
val buffer = ByteArray(8192)
var read: Int
var total: Long = 0
while (inputStream.read(buffer, 0, buffer.size).also { read = it } != -1) {
if (cancelationSignal?.isCanceled == true) {
throw IOException("Canceled!")
}
outputStream.write(buffer, 0, read)
total += read.toLong()
progressListener?.onAttachmentProgress(contentLength, total)
}
outputStream.flush()
outputStream.close()
digestStream.close()
incrementalDigest = digestStream.toByteArray()
transmittedDigest = outputStream.transmittedDigest
}
override fun contentLength(): Long {
return if (contentLength > 0) contentLength - contentStart else -1
}
fun getAttachmentDigest() = AttachmentDigest(transmittedDigest, incrementalDigest)
companion object {
const val TAG = "DigestingRequestBody"
}
}

View file

@ -667,20 +667,21 @@ message AttachmentPointer {
fixed64 cdnId = 1; fixed64 cdnId = 1;
string cdnKey = 15; string cdnKey = 15;
} }
optional string contentType = 2; optional string contentType = 2;
optional bytes key = 3; optional bytes key = 3;
optional uint32 size = 4; optional uint32 size = 4;
optional bytes thumbnail = 5; optional bytes thumbnail = 5;
optional bytes digest = 6; optional bytes digest = 6;
optional string fileName = 7; optional bytes incrementalDigest = 16;
optional uint32 flags = 8; optional string fileName = 7;
optional uint32 width = 9; optional uint32 flags = 8;
optional uint32 height = 10; optional uint32 width = 9;
optional string caption = 11; optional uint32 height = 10;
optional string blurHash = 12; optional string caption = 11;
optional uint64 uploadTimestamp = 13; optional string blurHash = 12;
optional uint32 cdnNumber = 14; optional uint64 uploadTimestamp = 13;
// Next ID: 16 optional uint32 cdnNumber = 14;
// Next ID: 17
} }
message GroupContext { message GroupContext {

View file

@ -3,6 +3,7 @@ package org.whispersystems.signalservice.api.crypto;
import org.conscrypt.Conscrypt; import org.conscrypt.Conscrypt;
import org.junit.Test; import org.junit.Test;
import org.signal.libsignal.protocol.InvalidMessageException; import org.signal.libsignal.protocol.InvalidMessageException;
import org.signal.libsignal.protocol.incrementalmac.InvalidMacException;
import org.signal.libsignal.protocol.kdf.HKDFv3; import org.signal.libsignal.protocol.kdf.HKDFv3;
import org.whispersystems.signalservice.internal.crypto.PaddingInputStream; import org.whispersystems.signalservice.internal.crypto.PaddingInputStream;
import org.whispersystems.signalservice.internal.push.http.AttachmentCipherOutputStreamFactory; import org.whispersystems.signalservice.internal.push.http.AttachmentCipherOutputStreamFactory;
@ -17,9 +18,11 @@ import java.io.InputStream;
import java.io.OutputStream; import java.io.OutputStream;
import java.security.Security; import java.security.Security;
import java.util.Arrays; import java.util.Arrays;
import java.util.Random;
import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.whispersystems.signalservice.testutil.LibSignalLibraryUtil.assumeLibSignalSupportedOnOS; import static org.whispersystems.signalservice.testutil.LibSignalLibraryUtil.assumeLibSignalSupportedOnOS;
public final class AttachmentCipherTest { public final class AttachmentCipherTest {
@ -32,9 +35,9 @@ public final class AttachmentCipherTest {
public void attachment_encryptDecrypt() throws IOException, InvalidMessageException { public void attachment_encryptDecrypt() throws IOException, InvalidMessageException {
byte[] key = Util.getSecretBytes(64); byte[] key = Util.getSecretBytes(64);
byte[] plaintextInput = "Peter Parker".getBytes(); byte[] plaintextInput = "Peter Parker".getBytes();
EncryptResult encryptResult = encryptData(plaintextInput, key); EncryptResult encryptResult = encryptData(plaintextInput, key, true);
File cipherFile = writeToFile(encryptResult.ciphertext); File cipherFile = writeToFile(encryptResult.ciphertext);
InputStream inputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.length, key, encryptResult.digest); InputStream inputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.length, key, encryptResult.digest, encryptResult.incrementalDigest);
byte[] plaintextOutput = readInputStreamFully(inputStream); byte[] plaintextOutput = readInputStreamFully(inputStream);
assertArrayEquals(plaintextInput, plaintextOutput); assertArrayEquals(plaintextInput, plaintextOutput);
@ -46,9 +49,9 @@ public final class AttachmentCipherTest {
public void attachment_encryptDecryptEmpty() throws IOException, InvalidMessageException { public void attachment_encryptDecryptEmpty() throws IOException, InvalidMessageException {
byte[] key = Util.getSecretBytes(64); byte[] key = Util.getSecretBytes(64);
byte[] plaintextInput = "".getBytes(); byte[] plaintextInput = "".getBytes();
EncryptResult encryptResult = encryptData(plaintextInput, key); EncryptResult encryptResult = encryptData(plaintextInput, key, true);
File cipherFile = writeToFile(encryptResult.ciphertext); File cipherFile = writeToFile(encryptResult.ciphertext);
InputStream inputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.length, key, encryptResult.digest); InputStream inputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.length, key, encryptResult.digest, encryptResult.incrementalDigest);
byte[] plaintextOutput = readInputStreamFully(inputStream); byte[] plaintextOutput = readInputStreamFully(inputStream);
assertArrayEquals(plaintextInput, plaintextOutput); assertArrayEquals(plaintextInput, plaintextOutput);
@ -57,19 +60,19 @@ public final class AttachmentCipherTest {
} }
@Test @Test
public void attachment_decryptFailOnBadKey() throws IOException{ public void attachment_decryptFailOnBadKey() throws IOException {
File cipherFile = null; File cipherFile = null;
boolean hitCorrectException = false; boolean hitCorrectException = false;
try { try {
byte[] key = Util.getSecretBytes(64); byte[] key = Util.getSecretBytes(64);
byte[] plaintextInput = "Gwen Stacy".getBytes(); byte[] plaintextInput = "Gwen Stacy".getBytes();
EncryptResult encryptResult = encryptData(plaintextInput, key); EncryptResult encryptResult = encryptData(plaintextInput, key, true);
byte[] badKey = new byte[64]; byte[] badKey = new byte[64];
cipherFile = writeToFile(encryptResult.ciphertext); cipherFile = writeToFile(encryptResult.ciphertext);
AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.length, badKey, encryptResult.digest); AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.length, badKey, encryptResult.digest, encryptResult.incrementalDigest);
} catch (InvalidMessageException e) { } catch (InvalidMessageException e) {
hitCorrectException = true; hitCorrectException = true;
} finally { } finally {
@ -82,19 +85,19 @@ public final class AttachmentCipherTest {
} }
@Test @Test
public void attachment_decryptFailOnBadDigest() throws IOException{ public void attachment_decryptFailOnBadDigest() throws IOException {
File cipherFile = null; File cipherFile = null;
boolean hitCorrectException = false; boolean hitCorrectException = false;
try { try {
byte[] key = Util.getSecretBytes(64); byte[] key = Util.getSecretBytes(64);
byte[] plaintextInput = "Mary Jane Watson".getBytes(); byte[] plaintextInput = "Mary Jane Watson".getBytes();
EncryptResult encryptResult = encryptData(plaintextInput, key); EncryptResult encryptResult = encryptData(plaintextInput, key, true);
byte[] badDigest = new byte[32]; byte[] badDigest = new byte[32];
cipherFile = writeToFile(encryptResult.ciphertext); cipherFile = writeToFile(encryptResult.ciphertext);
AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.length, key, badDigest); AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.length, key, badDigest, encryptResult.incrementalDigest);
} catch (InvalidMessageException e) { } catch (InvalidMessageException e) {
hitCorrectException = true; hitCorrectException = true;
} finally { } finally {
@ -106,9 +109,42 @@ public final class AttachmentCipherTest {
assertTrue(hitCorrectException); assertTrue(hitCorrectException);
} }
@Test
public void attachment_decryptFailOnBadIncrementalDigest() throws IOException {
File cipherFile = null;
boolean hitCorrectException = false;
try {
byte[] key = Util.getSecretBytes(64);
byte[] plaintextInput = new byte[1000000];
new Random().nextBytes(plaintextInput);
EncryptResult encryptResult = encryptData(plaintextInput, key, true);
byte[] badDigest = Util.getSecretBytes(encryptResult.incrementalDigest.length);
cipherFile = writeToFile(encryptResult.ciphertext);
InputStream decryptedStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.length, key, encryptResult.digest, badDigest);
byte[] plaintextOutput = readInputStreamFully(decryptedStream);
fail();
} catch (InvalidMacException e) {
hitCorrectException = true;
} catch (InvalidMessageException e) {
hitCorrectException = false;
} finally {
if (cipherFile != null) {
cipherFile.delete();
}
}
assertTrue(hitCorrectException);
}
@Test @Test
public void attachment_encryptDecryptPaddedContent() throws IOException, InvalidMessageException { public void attachment_encryptDecryptPaddedContent() throws IOException, InvalidMessageException {
int[] lengths = { 531, 600, 724, 1019, 1024 }; int[] lengths = { 531, 600, 724, 1019, 1024 };
for (int length : lengths) { for (int length : lengths) {
byte[] plaintextInput = new byte[length]; byte[] plaintextInput = new byte[length];
@ -117,24 +153,26 @@ public final class AttachmentCipherTest {
plaintextInput[i] = (byte) 0x97; plaintextInput[i] = (byte) 0x97;
} }
byte[] key = Util.getSecretBytes(64); byte[] key = Util.getSecretBytes(64);
ByteArrayInputStream inputStream = new ByteArrayInputStream(plaintextInput); byte[] iv = Util.getSecretBytes(16);
InputStream dataStream = new PaddingInputStream(inputStream, length); ByteArrayInputStream inputStream = new ByteArrayInputStream(plaintextInput);
ByteArrayOutputStream encryptedStream = new ByteArrayOutputStream(); InputStream paddedInputStream = new PaddingInputStream(inputStream, length);
DigestingOutputStream digestStream = new AttachmentCipherOutputStreamFactory(key, null).createFor(encryptedStream); ByteArrayOutputStream destinationOutputStream = new ByteArrayOutputStream();
ByteArrayOutputStream incrementalDigestOutputStream = new ByteArrayOutputStream();
DigestingOutputStream encryptingOutputStream = new AttachmentCipherOutputStreamFactory(key, iv).createIncrementalFor(destinationOutputStream, length, incrementalDigestOutputStream);
Util.copy(dataStream, digestStream); Util.copy(paddedInputStream, encryptingOutputStream);
digestStream.flush();
byte[] digest = digestStream.getTransmittedDigest(); encryptingOutputStream.flush();
byte[] encryptedData = encryptedStream.toByteArray(); encryptingOutputStream.close();
encryptedStream.close(); byte[] encryptedData = destinationOutputStream.toByteArray();
inputStream.close(); byte[] digest = encryptingOutputStream.getTransmittedDigest();
byte[] incrementalDigest = incrementalDigestOutputStream.toByteArray();
File cipherFile = writeToFile(encryptedData); File cipherFile = writeToFile(encryptedData);
InputStream decryptedStream = AttachmentCipherInputStream.createForAttachment(cipherFile, length, key, digest); InputStream decryptedStream = AttachmentCipherInputStream.createForAttachment(cipherFile, length, key, digest, incrementalDigest);
byte[] plaintextOutput = readInputStreamFully(decryptedStream); byte[] plaintextOutput = readInputStreamFully(decryptedStream);
assertArrayEquals(plaintextInput, plaintextOutput); assertArrayEquals(plaintextInput, plaintextOutput);
@ -149,13 +187,13 @@ public final class AttachmentCipherTest {
boolean hitCorrectException = false; boolean hitCorrectException = false;
try { try {
byte[] key = Util.getSecretBytes(64); byte[] key = Util.getSecretBytes(64);
byte[] plaintextInput = "Aunt May".getBytes(); byte[] plaintextInput = "Aunt May".getBytes();
EncryptResult encryptResult = encryptData(plaintextInput, key); EncryptResult encryptResult = encryptData(plaintextInput, key, true);
cipherFile = writeToFile(encryptResult.ciphertext); cipherFile = writeToFile(encryptResult.ciphertext);
AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.length, key, null); AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.length, key, null, encryptResult.incrementalDigest);
} catch (InvalidMessageException e) { } catch (InvalidMessageException e) {
hitCorrectException = true; hitCorrectException = true;
} finally { } finally {
@ -175,14 +213,14 @@ public final class AttachmentCipherTest {
try { try {
byte[] key = Util.getSecretBytes(64); byte[] key = Util.getSecretBytes(64);
byte[] plaintextInput = "Uncle Ben".getBytes(); byte[] plaintextInput = "Uncle Ben".getBytes();
EncryptResult encryptResult = encryptData(plaintextInput, key); EncryptResult encryptResult = encryptData(plaintextInput, key, true);
byte[] badMacCiphertext = Arrays.copyOf(encryptResult.ciphertext, encryptResult.ciphertext.length); byte[] badMacCiphertext = Arrays.copyOf(encryptResult.ciphertext, encryptResult.ciphertext.length);
badMacCiphertext[badMacCiphertext.length - 1] += 1; badMacCiphertext[badMacCiphertext.length - 1] += 1;
cipherFile = writeToFile(badMacCiphertext); cipherFile = writeToFile(badMacCiphertext);
AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.length, key, encryptResult.digest); AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.length, key, encryptResult.digest, encryptResult.incrementalDigest);
} catch (InvalidMessageException e) { } catch (InvalidMessageException e) {
hitCorrectException = true; hitCorrectException = true;
} finally { } finally {
@ -200,7 +238,7 @@ public final class AttachmentCipherTest {
byte[] packKey = Util.getSecretBytes(32); byte[] packKey = Util.getSecretBytes(32);
byte[] plaintextInput = "Peter Parker".getBytes(); byte[] plaintextInput = "Peter Parker".getBytes();
EncryptResult encryptResult = encryptData(plaintextInput, expandPackKey(packKey)); EncryptResult encryptResult = encryptData(plaintextInput, expandPackKey(packKey), true);
InputStream inputStream = AttachmentCipherInputStream.createForStickerData(encryptResult.ciphertext, packKey); InputStream inputStream = AttachmentCipherInputStream.createForStickerData(encryptResult.ciphertext, packKey);
byte[] plaintextOutput = readInputStreamFully(inputStream); byte[] plaintextOutput = readInputStreamFully(inputStream);
@ -213,7 +251,7 @@ public final class AttachmentCipherTest {
byte[] packKey = Util.getSecretBytes(32); byte[] packKey = Util.getSecretBytes(32);
byte[] plaintextInput = "".getBytes(); byte[] plaintextInput = "".getBytes();
EncryptResult encryptResult = encryptData(plaintextInput, expandPackKey(packKey)); EncryptResult encryptResult = encryptData(plaintextInput, expandPackKey(packKey), true);
InputStream inputStream = AttachmentCipherInputStream.createForStickerData(encryptResult.ciphertext, packKey); InputStream inputStream = AttachmentCipherInputStream.createForStickerData(encryptResult.ciphertext, packKey);
byte[] plaintextOutput = readInputStreamFully(inputStream); byte[] plaintextOutput = readInputStreamFully(inputStream);
@ -227,10 +265,10 @@ public final class AttachmentCipherTest {
boolean hitCorrectException = false; boolean hitCorrectException = false;
try { try {
byte[] packKey = Util.getSecretBytes(32); byte[] packKey = Util.getSecretBytes(32);
byte[] plaintextInput = "Gwen Stacy".getBytes(); byte[] plaintextInput = "Gwen Stacy".getBytes();
EncryptResult encryptResult = encryptData(plaintextInput, expandPackKey(packKey)); EncryptResult encryptResult = encryptData(plaintextInput, expandPackKey(packKey), true);
byte[] badPackKey = new byte[32]; byte[] badPackKey = new byte[32];
AttachmentCipherInputStream.createForStickerData(encryptResult.ciphertext, badPackKey); AttachmentCipherInputStream.createForStickerData(encryptResult.ciphertext, badPackKey);
} catch (InvalidMessageException e) { } catch (InvalidMessageException e) {
@ -249,7 +287,7 @@ public final class AttachmentCipherTest {
try { try {
byte[] packKey = Util.getSecretBytes(32); byte[] packKey = Util.getSecretBytes(32);
byte[] plaintextInput = "Uncle Ben".getBytes(); byte[] plaintextInput = "Uncle Ben".getBytes();
EncryptResult encryptResult = encryptData(plaintextInput, expandPackKey(packKey)); EncryptResult encryptResult = encryptData(plaintextInput, expandPackKey(packKey), true);
byte[] badMacCiphertext = Arrays.copyOf(encryptResult.ciphertext, encryptResult.ciphertext.length); byte[] badMacCiphertext = Arrays.copyOf(encryptResult.ciphertext, encryptResult.ciphertext.length);
badMacCiphertext[badMacCiphertext.length - 1] += 1; badMacCiphertext[badMacCiphertext.length - 1] += 1;
@ -262,15 +300,26 @@ public final class AttachmentCipherTest {
assertTrue(hitCorrectException); assertTrue(hitCorrectException);
} }
private static EncryptResult encryptData(byte[] data, byte[] keyMaterial) throws IOException { private static EncryptResult encryptData(byte[] data, byte[] keyMaterial, boolean withIncremental) throws IOException {
ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
AttachmentCipherOutputStream encryptStream = new AttachmentCipherOutputStream(keyMaterial, null, outputStream); ByteArrayOutputStream incrementalDigestOut = new ByteArrayOutputStream();
byte[] iv = Util.getSecretBytes(16);
AttachmentCipherOutputStreamFactory factory = new AttachmentCipherOutputStreamFactory(keyMaterial, iv);
DigestingOutputStream encryptStream;
if (withIncremental) {
encryptStream = factory.createIncrementalFor(outputStream, data.length, incrementalDigestOut);
} else {
encryptStream = factory.createFor(outputStream);
}
encryptStream.write(data); encryptStream.write(data);
encryptStream.flush(); encryptStream.flush();
encryptStream.close(); encryptStream.close();
incrementalDigestOut.close();
return new EncryptResult(outputStream.toByteArray(), encryptStream.getTransmittedDigest()); return new EncryptResult(outputStream.toByteArray(), encryptStream.getTransmittedDigest(), incrementalDigestOut.toByteArray());
} }
private static File writeToFile(byte[] data) throws IOException { private static File writeToFile(byte[] data) throws IOException {
@ -296,10 +345,12 @@ public final class AttachmentCipherTest {
private static class EncryptResult { private static class EncryptResult {
final byte[] ciphertext; final byte[] ciphertext;
final byte[] digest; final byte[] digest;
final byte[] incrementalDigest;
private EncryptResult(byte[] ciphertext, byte[] digest) { private EncryptResult(byte[] ciphertext, byte[] digest, byte[] incrementalDigest) {
this.ciphertext = ciphertext; this.ciphertext = ciphertext;
this.digest = digest; this.digest = digest;
this.incrementalDigest = incrementalDigest;
} }
} }
} }

View file

@ -23,7 +23,7 @@ public class DigestingRequestBodyTest {
private final OutputStreamFactory outputStreamFactory = new AttachmentCipherOutputStreamFactory(attachmentKey, attachmentIV); private final OutputStreamFactory outputStreamFactory = new AttachmentCipherOutputStreamFactory(attachmentKey, attachmentIV);
@Test @Test
public void givenSameKeyAndIV_whenIWriteToBuffer_thenIExpectSameTransmittedDigest() throws Exception { public void givenSameKeyAndIV_whenIWriteToBuffer_thenIExpectSameDigests() throws Exception {
DigestingRequestBody fromStart = getBody(0); DigestingRequestBody fromStart = getBody(0);
DigestingRequestBody fromMiddle = getBody(CONTENT_LENGTH / 2); DigestingRequestBody fromMiddle = getBody(CONTENT_LENGTH / 2);
@ -36,6 +36,7 @@ public class DigestingRequestBodyTest {
} }
assertArrayEquals(fromStart.getTransmittedDigest(), fromMiddle.getTransmittedDigest()); assertArrayEquals(fromStart.getTransmittedDigest(), fromMiddle.getTransmittedDigest());
assertArrayEquals(fromStart.getIncrementalDigest(), fromMiddle.getIncrementalDigest());
} }
@Test @Test