From 68a2d5ed20b86b54bcede31e273f281383156e4b Mon Sep 17 00:00:00 2001 From: Jordan Rose Date: Wed, 23 Jun 2021 12:59:17 -0700 Subject: [PATCH] Reimplement ProfileCipherInputStream using libsignal-client. libsignal-client provides an AES-GCM streaming interface that can replace the implementation in AES-GCM-Provider. Using it from ProfileCipherInputStream requires some knowledge about the tag size of AES-GCM, but frees it from the JCE interface. Note that it remains a serious error to not read the *entire* stream, since the authentication tag is at the end! --- .../api/crypto/ProfileCipherInputStream.java | 85 +++++++++--------- .../api/crypto/ProfileCipherTest.java | 87 +++++++++++++++---- 2 files changed, 118 insertions(+), 54 deletions(-) diff --git a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/crypto/ProfileCipherInputStream.java b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/crypto/ProfileCipherInputStream.java index 8798297f29..af409c6b9c 100644 --- a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/crypto/ProfileCipherInputStream.java +++ b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/crypto/ProfileCipherInputStream.java @@ -1,42 +1,33 @@ package org.whispersystems.signalservice.api.crypto; - +import org.signal.libsignal.crypto.Aes256GcmDecryption; import org.signal.zkgroup.profiles.ProfileKey; +import org.whispersystems.libsignal.InvalidKeyException; import org.whispersystems.signalservice.internal.util.Util; import java.io.FilterInputStream; import java.io.IOException; import java.io.InputStream; -import java.security.InvalidAlgorithmParameterException; -import java.security.InvalidKeyException; -import java.security.NoSuchAlgorithmException; -import javax.crypto.BadPaddingException; -import javax.crypto.Cipher; -import javax.crypto.IllegalBlockSizeException; -import javax.crypto.NoSuchPaddingException; -import javax.crypto.ShortBufferException; -import javax.crypto.spec.GCMParameterSpec; -import javax.crypto.spec.SecretKeySpec; +import static org.signal.libsignal.crypto.Aes256GcmDecryption.TAG_SIZE_IN_BYTES; public class ProfileCipherInputStream extends FilterInputStream { - private final Cipher cipher; + private Aes256GcmDecryption aes; - private boolean finished = false; + // The buffer size must match the length of the authentication tag. + private byte[] buffer = new byte[TAG_SIZE_IN_BYTES]; + private byte[] swapBuffer = new byte[TAG_SIZE_IN_BYTES]; public ProfileCipherInputStream(InputStream in, ProfileKey key) throws IOException { super(in); try { - this.cipher = Cipher.getInstance("AES/GCM/NoPadding"); - byte[] nonce = new byte[12]; Util.readFully(in, nonce); + Util.readFully(in, buffer); - this.cipher.init(Cipher.DECRYPT_MODE, new SecretKeySpec(key.serialize(), "AES"), new GCMParameterSpec(128, nonce)); - } catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidAlgorithmParameterException e) { - throw new AssertionError(e); + this.aes = new Aes256GcmDecryption(key.serialize(), nonce, new byte[] {}); } catch (InvalidKeyException e) { throw new IOException(e); } @@ -54,31 +45,47 @@ public class ProfileCipherInputStream extends FilterInputStream { @Override public int read(byte[] output, int outputOffset, int outputLength) throws IOException { - if (finished) return -1; + if (aes == null) return -1; - try { - byte[] ciphertext = new byte[outputLength / 2]; - int read = in.read(ciphertext, 0, ciphertext.length); + int read = in.read(output, outputOffset, outputLength); - if (read == -1) { - if (cipher.getOutputSize(0) > outputLength) { - throw new AssertionError("Need: " + cipher.getOutputSize(0) + " but only have: " + outputLength); - } - - finished = true; - return cipher.doFinal(output, outputOffset); - } else { - if (cipher.getOutputSize(read) > outputLength) { - throw new AssertionError("Need: " + cipher.getOutputSize(read) + " but only have: " + outputLength); - } - - return cipher.update(ciphertext, 0, read, output, outputOffset); + if (read == -1) { + // We're done. The buffer has the final tag for authentication. + Aes256GcmDecryption aes = this.aes; + this.aes = null; + if (!aes.verifyTag(this.buffer)) { + throw new IOException("authentication of decrypted data failed"); } - } catch (IllegalBlockSizeException | ShortBufferException e) { - throw new AssertionError(e); - } catch (BadPaddingException e) { - throw new IOException(e); + return -1; } + + if (read < TAG_SIZE_IN_BYTES) { + // swapBuffer = buffer[read..] + output[offset..][..read] + // output[offset..][..read] = buffer[..read] + System.arraycopy(this.buffer, read, this.swapBuffer, 0, TAG_SIZE_IN_BYTES - read); + System.arraycopy(output, outputOffset, this.swapBuffer, TAG_SIZE_IN_BYTES - read, read); + System.arraycopy(this.buffer, 0, output, outputOffset, read); + } else if (read == TAG_SIZE_IN_BYTES) { + // swapBuffer = output[offset..][..read] + // output[offset..][..read] = buffer + System.arraycopy(output, outputOffset, this.swapBuffer, 0, read); + System.arraycopy(this.buffer, 0, output, outputOffset, read); + } else { + // swapBuffer = output[offset..][(read - TAG_SIZE)..read] + // output[offset..][TAG_SIZE..read] = output[offset..][..(read - TAG_SIZE)] + // output[offset..][..TAG_SIZE] = buffer + System.arraycopy(output, outputOffset + read - TAG_SIZE_IN_BYTES, this.swapBuffer, 0, TAG_SIZE_IN_BYTES); + System.arraycopy(output, outputOffset, output, outputOffset + TAG_SIZE_IN_BYTES, read - TAG_SIZE_IN_BYTES); + System.arraycopy(this.buffer, 0, output, outputOffset, TAG_SIZE_IN_BYTES); + } + + // Now swapBuffer has the buffer for next time. + byte[] temp = this.buffer; + this.buffer = this.swapBuffer; + this.swapBuffer = temp; + + aes.decrypt(output, outputOffset, read); + return read; } } diff --git a/libsignal/service/src/test/java/org/whispersystems/signalservice/api/crypto/ProfileCipherTest.java b/libsignal/service/src/test/java/org/whispersystems/signalservice/api/crypto/ProfileCipherTest.java index 6ab55cac9d..05d8b17d84 100644 --- a/libsignal/service/src/test/java/org/whispersystems/signalservice/api/crypto/ProfileCipherTest.java +++ b/libsignal/service/src/test/java/org/whispersystems/signalservice/api/crypto/ProfileCipherTest.java @@ -1,9 +1,8 @@ package org.whispersystems.signalservice.api.crypto; -import junit.framework.TestCase; - import org.conscrypt.Conscrypt; +import org.junit.Test; import org.signal.zkgroup.InvalidInputException; import org.signal.zkgroup.profiles.ProfileKey; import org.whispersystems.signalservice.internal.util.Util; @@ -11,14 +10,30 @@ import org.whispersystems.util.Base64; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; +import java.io.IOException; import java.security.Security; -public class ProfileCipherTest extends TestCase { +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; +import static org.whispersystems.signalservice.testutil.LibSignalLibraryUtil.assumeLibSignalSupportedOnOS; + +public class ProfileCipherTest { + + private class TestByteArrayInputStream extends ByteArrayInputStream { + TestByteArrayInputStream(byte[] buffer) { + super(buffer); + } + + int getPos() { + return this.pos; + } + } static { Security.insertProviderAt(Conscrypt.newProvider(), 1); } + @Test public void testEncryptDecrypt() throws InvalidCiphertextException, InvalidInputException { ProfileKey key = new ProfileKey(Util.getSecretBytes(32)); ProfileCipher cipher = new ProfileCipher(key); @@ -27,6 +42,7 @@ public class ProfileCipherTest extends TestCase { assertEquals(plaintext, "Clement\0Duval"); } + @Test public void testEmpty() throws Exception { ProfileKey key = new ProfileKey(Util.getSecretBytes(32)); ProfileCipher cipher = new ProfileCipher(key); @@ -36,7 +52,32 @@ public class ProfileCipherTest extends TestCase { assertEquals(plaintext.length(), 0); } + private byte[] readStream(byte[] input, ProfileKey key, int bufferSize) throws Exception { + TestByteArrayInputStream bais = new TestByteArrayInputStream(input); + assertEquals(0, bais.getPos()); + + ProfileCipherInputStream in = new ProfileCipherInputStream(bais, key); + assertEquals(12 + 16, bais.getPos()); // initial read of nonce + tag-sized buffer + + ByteArrayOutputStream result = new ByteArrayOutputStream(); + byte[] buffer = new byte[bufferSize]; + + int pos = bais.getPos(); + int read; + while ((read = in.read(buffer)) != -1) { + assertEquals(pos + read, bais.getPos()); + pos += read; + result.write(buffer, 0, read); + } + + assertEquals(pos, input.length); + return result.toByteArray(); + } + + @Test public void testStreams() throws Exception { + assumeLibSignalSupportedOnOS(); + ProfileKey key = new ProfileKey(Util.getSecretBytes(32)); ByteArrayOutputStream baos = new ByteArrayOutputStream(); ProfileCipherOutputStream out = new ProfileCipherOutputStream(baos, key); @@ -45,21 +86,35 @@ public class ProfileCipherTest extends TestCase { out.flush(); out.close(); - ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray()); - ProfileCipherInputStream in = new ProfileCipherInputStream(bais, key); + byte[] encrypted = baos.toByteArray(); - ByteArrayOutputStream result = new ByteArrayOutputStream(); - byte[] buffer = new byte[2048]; - - int read; - - while ((read = in.read(buffer)) != -1) { - result.write(buffer, 0, read); - } - - assertEquals(new String(result.toByteArray()), "This is an avatar"); + assertEquals(new String(readStream(encrypted, key, 2048)), "This is an avatar"); + assertEquals(new String(readStream(encrypted, key, 16 /* == block size */)), "This is an avatar"); + assertEquals(new String(readStream(encrypted, key, 5)), "This is an avatar"); } + @Test + public void testStreamBadAuthentication() throws Exception { + assumeLibSignalSupportedOnOS(); + + ProfileKey key = new ProfileKey(Util.getSecretBytes(32)); + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + ProfileCipherOutputStream out = new ProfileCipherOutputStream(baos, key); + + out.write("This is an avatar".getBytes()); + out.flush(); + out.close(); + + byte[] encrypted = baos.toByteArray(); + encrypted[encrypted.length - 1] ^= 1; + try { + readStream(encrypted, key, 2048); + fail("failed to verify authenticate tag"); + } catch (IOException e) { + } + } + + @Test public void testEncryptLengthBucket1() throws InvalidInputException { ProfileKey key = new ProfileKey(Util.getSecretBytes(32)); ProfileCipher cipher = new ProfileCipher(key); @@ -70,6 +125,7 @@ public class ProfileCipherTest extends TestCase { assertEquals(108, encoded.length()); } + @Test public void testEncryptLengthBucket2() throws InvalidInputException { ProfileKey key = new ProfileKey(Util.getSecretBytes(32)); ProfileCipher cipher = new ProfileCipher(key); @@ -80,6 +136,7 @@ public class ProfileCipherTest extends TestCase { assertEquals(380, encoded.length()); } + @Test public void testTargetNameLength() { assertEquals(53, ProfileCipher.getTargetNameLength("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")); assertEquals(53, ProfileCipher.getTargetNameLength("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1"));