Reimplement ProfileCipherInputStream using libsignal-client.

libsignal-client provides an AES-GCM streaming interface that can
replace the implementation in AES-GCM-Provider. Using it from
ProfileCipherInputStream requires some knowledge about the tag size of
AES-GCM, but frees it from the JCE interface.

Note that it remains a serious error to not read the *entire* stream,
since the authentication tag is at the end!
This commit is contained in:
Jordan Rose 2021-06-23 12:59:17 -07:00 committed by Greyson Parrelli
parent 35e9e31a7b
commit 68a2d5ed20
2 changed files with 118 additions and 54 deletions

View file

@ -1,42 +1,33 @@
package org.whispersystems.signalservice.api.crypto; package org.whispersystems.signalservice.api.crypto;
import org.signal.libsignal.crypto.Aes256GcmDecryption;
import org.signal.zkgroup.profiles.ProfileKey; import org.signal.zkgroup.profiles.ProfileKey;
import org.whispersystems.libsignal.InvalidKeyException;
import org.whispersystems.signalservice.internal.util.Util; import org.whispersystems.signalservice.internal.util.Util;
import java.io.FilterInputStream; import java.io.FilterInputStream;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException;
import javax.crypto.BadPaddingException; import static org.signal.libsignal.crypto.Aes256GcmDecryption.TAG_SIZE_IN_BYTES;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.NoSuchPaddingException;
import javax.crypto.ShortBufferException;
import javax.crypto.spec.GCMParameterSpec;
import javax.crypto.spec.SecretKeySpec;
public class ProfileCipherInputStream extends FilterInputStream { public class ProfileCipherInputStream extends FilterInputStream {
private final Cipher cipher; private Aes256GcmDecryption aes;
private boolean finished = false; // The buffer size must match the length of the authentication tag.
private byte[] buffer = new byte[TAG_SIZE_IN_BYTES];
private byte[] swapBuffer = new byte[TAG_SIZE_IN_BYTES];
public ProfileCipherInputStream(InputStream in, ProfileKey key) throws IOException { public ProfileCipherInputStream(InputStream in, ProfileKey key) throws IOException {
super(in); super(in);
try { try {
this.cipher = Cipher.getInstance("AES/GCM/NoPadding");
byte[] nonce = new byte[12]; byte[] nonce = new byte[12];
Util.readFully(in, nonce); Util.readFully(in, nonce);
Util.readFully(in, buffer);
this.cipher.init(Cipher.DECRYPT_MODE, new SecretKeySpec(key.serialize(), "AES"), new GCMParameterSpec(128, nonce)); this.aes = new Aes256GcmDecryption(key.serialize(), nonce, new byte[] {});
} catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidAlgorithmParameterException e) {
throw new AssertionError(e);
} catch (InvalidKeyException e) { } catch (InvalidKeyException e) {
throw new IOException(e); throw new IOException(e);
} }
@ -54,31 +45,47 @@ public class ProfileCipherInputStream extends FilterInputStream {
@Override @Override
public int read(byte[] output, int outputOffset, int outputLength) throws IOException { public int read(byte[] output, int outputOffset, int outputLength) throws IOException {
if (finished) return -1; if (aes == null) return -1;
try { int read = in.read(output, outputOffset, outputLength);
byte[] ciphertext = new byte[outputLength / 2];
int read = in.read(ciphertext, 0, ciphertext.length);
if (read == -1) { if (read == -1) {
if (cipher.getOutputSize(0) > outputLength) { // We're done. The buffer has the final tag for authentication.
throw new AssertionError("Need: " + cipher.getOutputSize(0) + " but only have: " + outputLength); Aes256GcmDecryption aes = this.aes;
} this.aes = null;
if (!aes.verifyTag(this.buffer)) {
finished = true; throw new IOException("authentication of decrypted data failed");
return cipher.doFinal(output, outputOffset);
} else {
if (cipher.getOutputSize(read) > outputLength) {
throw new AssertionError("Need: " + cipher.getOutputSize(read) + " but only have: " + outputLength);
}
return cipher.update(ciphertext, 0, read, output, outputOffset);
} }
} catch (IllegalBlockSizeException | ShortBufferException e) { return -1;
throw new AssertionError(e);
} catch (BadPaddingException e) {
throw new IOException(e);
} }
if (read < TAG_SIZE_IN_BYTES) {
// swapBuffer = buffer[read..] + output[offset..][..read]
// output[offset..][..read] = buffer[..read]
System.arraycopy(this.buffer, read, this.swapBuffer, 0, TAG_SIZE_IN_BYTES - read);
System.arraycopy(output, outputOffset, this.swapBuffer, TAG_SIZE_IN_BYTES - read, read);
System.arraycopy(this.buffer, 0, output, outputOffset, read);
} else if (read == TAG_SIZE_IN_BYTES) {
// swapBuffer = output[offset..][..read]
// output[offset..][..read] = buffer
System.arraycopy(output, outputOffset, this.swapBuffer, 0, read);
System.arraycopy(this.buffer, 0, output, outputOffset, read);
} else {
// swapBuffer = output[offset..][(read - TAG_SIZE)..read]
// output[offset..][TAG_SIZE..read] = output[offset..][..(read - TAG_SIZE)]
// output[offset..][..TAG_SIZE] = buffer
System.arraycopy(output, outputOffset + read - TAG_SIZE_IN_BYTES, this.swapBuffer, 0, TAG_SIZE_IN_BYTES);
System.arraycopy(output, outputOffset, output, outputOffset + TAG_SIZE_IN_BYTES, read - TAG_SIZE_IN_BYTES);
System.arraycopy(this.buffer, 0, output, outputOffset, TAG_SIZE_IN_BYTES);
}
// Now swapBuffer has the buffer for next time.
byte[] temp = this.buffer;
this.buffer = this.swapBuffer;
this.swapBuffer = temp;
aes.decrypt(output, outputOffset, read);
return read;
} }
} }

View file

@ -1,9 +1,8 @@
package org.whispersystems.signalservice.api.crypto; package org.whispersystems.signalservice.api.crypto;
import junit.framework.TestCase;
import org.conscrypt.Conscrypt; import org.conscrypt.Conscrypt;
import org.junit.Test;
import org.signal.zkgroup.InvalidInputException; import org.signal.zkgroup.InvalidInputException;
import org.signal.zkgroup.profiles.ProfileKey; import org.signal.zkgroup.profiles.ProfileKey;
import org.whispersystems.signalservice.internal.util.Util; import org.whispersystems.signalservice.internal.util.Util;
@ -11,14 +10,30 @@ import org.whispersystems.util.Base64;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.security.Security; import java.security.Security;
public class ProfileCipherTest extends TestCase { import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
import static org.whispersystems.signalservice.testutil.LibSignalLibraryUtil.assumeLibSignalSupportedOnOS;
public class ProfileCipherTest {
private class TestByteArrayInputStream extends ByteArrayInputStream {
TestByteArrayInputStream(byte[] buffer) {
super(buffer);
}
int getPos() {
return this.pos;
}
}
static { static {
Security.insertProviderAt(Conscrypt.newProvider(), 1); Security.insertProviderAt(Conscrypt.newProvider(), 1);
} }
@Test
public void testEncryptDecrypt() throws InvalidCiphertextException, InvalidInputException { public void testEncryptDecrypt() throws InvalidCiphertextException, InvalidInputException {
ProfileKey key = new ProfileKey(Util.getSecretBytes(32)); ProfileKey key = new ProfileKey(Util.getSecretBytes(32));
ProfileCipher cipher = new ProfileCipher(key); ProfileCipher cipher = new ProfileCipher(key);
@ -27,6 +42,7 @@ public class ProfileCipherTest extends TestCase {
assertEquals(plaintext, "Clement\0Duval"); assertEquals(plaintext, "Clement\0Duval");
} }
@Test
public void testEmpty() throws Exception { public void testEmpty() throws Exception {
ProfileKey key = new ProfileKey(Util.getSecretBytes(32)); ProfileKey key = new ProfileKey(Util.getSecretBytes(32));
ProfileCipher cipher = new ProfileCipher(key); ProfileCipher cipher = new ProfileCipher(key);
@ -36,7 +52,32 @@ public class ProfileCipherTest extends TestCase {
assertEquals(plaintext.length(), 0); assertEquals(plaintext.length(), 0);
} }
private byte[] readStream(byte[] input, ProfileKey key, int bufferSize) throws Exception {
TestByteArrayInputStream bais = new TestByteArrayInputStream(input);
assertEquals(0, bais.getPos());
ProfileCipherInputStream in = new ProfileCipherInputStream(bais, key);
assertEquals(12 + 16, bais.getPos()); // initial read of nonce + tag-sized buffer
ByteArrayOutputStream result = new ByteArrayOutputStream();
byte[] buffer = new byte[bufferSize];
int pos = bais.getPos();
int read;
while ((read = in.read(buffer)) != -1) {
assertEquals(pos + read, bais.getPos());
pos += read;
result.write(buffer, 0, read);
}
assertEquals(pos, input.length);
return result.toByteArray();
}
@Test
public void testStreams() throws Exception { public void testStreams() throws Exception {
assumeLibSignalSupportedOnOS();
ProfileKey key = new ProfileKey(Util.getSecretBytes(32)); ProfileKey key = new ProfileKey(Util.getSecretBytes(32));
ByteArrayOutputStream baos = new ByteArrayOutputStream(); ByteArrayOutputStream baos = new ByteArrayOutputStream();
ProfileCipherOutputStream out = new ProfileCipherOutputStream(baos, key); ProfileCipherOutputStream out = new ProfileCipherOutputStream(baos, key);
@ -45,21 +86,35 @@ public class ProfileCipherTest extends TestCase {
out.flush(); out.flush();
out.close(); out.close();
ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray()); byte[] encrypted = baos.toByteArray();
ProfileCipherInputStream in = new ProfileCipherInputStream(bais, key);
ByteArrayOutputStream result = new ByteArrayOutputStream(); assertEquals(new String(readStream(encrypted, key, 2048)), "This is an avatar");
byte[] buffer = new byte[2048]; assertEquals(new String(readStream(encrypted, key, 16 /* == block size */)), "This is an avatar");
assertEquals(new String(readStream(encrypted, key, 5)), "This is an avatar");
int read;
while ((read = in.read(buffer)) != -1) {
result.write(buffer, 0, read);
}
assertEquals(new String(result.toByteArray()), "This is an avatar");
} }
@Test
public void testStreamBadAuthentication() throws Exception {
assumeLibSignalSupportedOnOS();
ProfileKey key = new ProfileKey(Util.getSecretBytes(32));
ByteArrayOutputStream baos = new ByteArrayOutputStream();
ProfileCipherOutputStream out = new ProfileCipherOutputStream(baos, key);
out.write("This is an avatar".getBytes());
out.flush();
out.close();
byte[] encrypted = baos.toByteArray();
encrypted[encrypted.length - 1] ^= 1;
try {
readStream(encrypted, key, 2048);
fail("failed to verify authenticate tag");
} catch (IOException e) {
}
}
@Test
public void testEncryptLengthBucket1() throws InvalidInputException { public void testEncryptLengthBucket1() throws InvalidInputException {
ProfileKey key = new ProfileKey(Util.getSecretBytes(32)); ProfileKey key = new ProfileKey(Util.getSecretBytes(32));
ProfileCipher cipher = new ProfileCipher(key); ProfileCipher cipher = new ProfileCipher(key);
@ -70,6 +125,7 @@ public class ProfileCipherTest extends TestCase {
assertEquals(108, encoded.length()); assertEquals(108, encoded.length());
} }
@Test
public void testEncryptLengthBucket2() throws InvalidInputException { public void testEncryptLengthBucket2() throws InvalidInputException {
ProfileKey key = new ProfileKey(Util.getSecretBytes(32)); ProfileKey key = new ProfileKey(Util.getSecretBytes(32));
ProfileCipher cipher = new ProfileCipher(key); ProfileCipher cipher = new ProfileCipher(key);
@ -80,6 +136,7 @@ public class ProfileCipherTest extends TestCase {
assertEquals(380, encoded.length()); assertEquals(380, encoded.length());
} }
@Test
public void testTargetNameLength() { public void testTargetNameLength() {
assertEquals(53, ProfileCipher.getTargetNameLength("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")); assertEquals(53, ProfileCipher.getTargetNameLength("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"));
assertEquals(53, ProfileCipher.getTargetNameLength("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1")); assertEquals(53, ProfileCipher.getTargetNameLength("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1"));