Make PreKeyWhisperMessage decrypt more reliably atomic.

This commit is contained in:
Moxie Marlinspike 2014-07-24 11:59:54 -07:00
parent 1eb3884b7a
commit c330eef7b9
2 changed files with 30 additions and 41 deletions

View file

@ -88,35 +88,32 @@ public class SessionBuilder {
* @throws org.whispersystems.libaxolotl.InvalidKeyException when the message is formatted incorrectly. * @throws org.whispersystems.libaxolotl.InvalidKeyException when the message is formatted incorrectly.
* @throws org.whispersystems.libaxolotl.UntrustedIdentityException when the {@link IdentityKey} of the sender is untrusted. * @throws org.whispersystems.libaxolotl.UntrustedIdentityException when the {@link IdentityKey} of the sender is untrusted.
*/ */
/*package*/ boolean process(PreKeyWhisperMessage message) /*package*/ void process(SessionRecord sessionRecord, PreKeyWhisperMessage message)
throws InvalidKeyIdException, InvalidKeyException, UntrustedIdentityException throws InvalidKeyIdException, InvalidKeyException, UntrustedIdentityException
{ {
int messageVersion = message.getMessageVersion(); int messageVersion = message.getMessageVersion();
IdentityKey theirIdentityKey = message.getIdentityKey(); IdentityKey theirIdentityKey = message.getIdentityKey();
boolean createdSession;
if (!identityKeyStore.isTrustedIdentity(recipientId, theirIdentityKey)) { if (!identityKeyStore.isTrustedIdentity(recipientId, theirIdentityKey)) {
throw new UntrustedIdentityException(); throw new UntrustedIdentityException();
} }
if (messageVersion == 2) createdSession = processV2(message); switch (messageVersion) {
else if (messageVersion == 3) createdSession = processV3(message); case 2: processV2(sessionRecord, message); break;
else throw new AssertionError("Unknown version: " + messageVersion); case 3: processV3(sessionRecord, message); break;
default: throw new AssertionError("Unknown version: " + messageVersion);
identityKeyStore.saveIdentity(recipientId, theirIdentityKey);
return createdSession;
} }
private boolean processV3(PreKeyWhisperMessage message) identityKeyStore.saveIdentity(recipientId, theirIdentityKey);
}
private void processV3(SessionRecord sessionRecord, PreKeyWhisperMessage message)
throws UntrustedIdentityException, InvalidKeyIdException, InvalidKeyException throws UntrustedIdentityException, InvalidKeyIdException, InvalidKeyException
{ {
SessionRecord sessionRecord = sessionStore.loadSession(recipientId, deviceId);
if (sessionRecord.hasSessionState(message.getMessageVersion(), message.getBaseKey().serialize())) { if (sessionRecord.hasSessionState(message.getMessageVersion(), message.getBaseKey().serialize())) {
Log.w(TAG, "We've already setup a session for this V3 message, letting bundled message fall through..."); Log.w(TAG, "We've already setup a session for this V3 message, letting bundled message fall through...");
return false; return;
} }
boolean simultaneousInitiate = sessionRecord.getSessionState().hasUnacknowledgedPreKeyMessage(); boolean simultaneousInitiate = sessionRecord.getSessionState().hasUnacknowledgedPreKeyMessage();
@ -147,16 +144,12 @@ public class SessionBuilder {
if (simultaneousInitiate) sessionRecord.getSessionState().setNeedsRefresh(true); if (simultaneousInitiate) sessionRecord.getSessionState().setNeedsRefresh(true);
sessionStore.storeSession(recipientId, deviceId, sessionRecord);
if (message.getPreKeyId() >= 0 && message.getPreKeyId() != Medium.MAX_VALUE) { if (message.getPreKeyId() >= 0 && message.getPreKeyId() != Medium.MAX_VALUE) {
preKeyStore.removePreKey(message.getPreKeyId()); preKeyStore.removePreKey(message.getPreKeyId());
} }
return true;
} }
private boolean processV2(PreKeyWhisperMessage message) private void processV2(SessionRecord sessionRecord, PreKeyWhisperMessage message)
throws UntrustedIdentityException, InvalidKeyIdException, InvalidKeyException throws UntrustedIdentityException, InvalidKeyIdException, InvalidKeyException
{ {
@ -164,10 +157,9 @@ public class SessionBuilder {
sessionStore.containsSession(recipientId, deviceId)) sessionStore.containsSession(recipientId, deviceId))
{ {
Log.w(TAG, "We've already processed the prekey part of this V2 session, letting bundled message fall through..."); Log.w(TAG, "We've already processed the prekey part of this V2 session, letting bundled message fall through...");
return false; return;
} }
SessionRecord sessionRecord = sessionStore.loadSession(recipientId, deviceId);
ECKeyPair ourPreKey = preKeyStore.loadPreKey(message.getPreKeyId()).getKeyPair(); ECKeyPair ourPreKey = preKeyStore.loadPreKey(message.getPreKeyId()).getKeyPair();
boolean simultaneousInitiate = sessionRecord.getSessionState().hasUnacknowledgedPreKeyMessage(); boolean simultaneousInitiate = sessionRecord.getSessionState().hasUnacknowledgedPreKeyMessage();
@ -193,10 +185,6 @@ public class SessionBuilder {
if (message.getPreKeyId() != Medium.MAX_VALUE) { if (message.getPreKeyId() != Medium.MAX_VALUE) {
preKeyStore.removePreKey(message.getPreKeyId()); preKeyStore.removePreKey(message.getPreKeyId());
} }
sessionStore.storeSession(recipientId, deviceId, sessionRecord);
return true;
} }
/** /**

View file

@ -145,17 +145,13 @@ public class SessionCipher {
InvalidKeyIdException, InvalidKeyException, UntrustedIdentityException, NoSessionException InvalidKeyIdException, InvalidKeyException, UntrustedIdentityException, NoSessionException
{ {
synchronized (SESSION_LOCK) { synchronized (SESSION_LOCK) {
boolean sessionCreated = sessionBuilder.process(ciphertext); SessionRecord sessionRecord = sessionStore.loadSession(recipientId, deviceId);
try { sessionBuilder.process(sessionRecord, ciphertext);
return decrypt(ciphertext.getWhisperMessage()); byte[] plaintext = decrypt(sessionRecord, ciphertext.getWhisperMessage());
} catch (InvalidMessageException | DuplicateMessageException | LegacyMessageException e) {
if (sessionCreated) {
sessionStore.deleteSession(recipientId, deviceId);
}
throw e; sessionStore.storeSession(recipientId, deviceId, sessionRecord);
} return plaintext;
} }
} }
@ -183,25 +179,31 @@ public class SessionCipher {
} }
SessionRecord sessionRecord = sessionStore.loadSession(recipientId, deviceId); SessionRecord sessionRecord = sessionStore.loadSession(recipientId, deviceId);
byte[] plaintext = decrypt(sessionRecord, ciphertext);
sessionStore.storeSession(recipientId, deviceId, sessionRecord);
return plaintext;
}
}
private byte[] decrypt(SessionRecord sessionRecord, WhisperMessage ciphertext)
throws DuplicateMessageException, LegacyMessageException, InvalidMessageException
{
synchronized (SESSION_LOCK) {
SessionState sessionState = sessionRecord.getSessionState(); SessionState sessionState = sessionRecord.getSessionState();
List<SessionState> previousStates = sessionRecord.getPreviousSessionStates(); List<SessionState> previousStates = sessionRecord.getPreviousSessionStates();
List<Exception> exceptions = new LinkedList<>(); List<Exception> exceptions = new LinkedList<>();
try { try {
byte[] plaintext = decrypt(sessionState, ciphertext); return decrypt(sessionState, ciphertext);
sessionStore.storeSession(recipientId, deviceId, sessionRecord);
return plaintext;
} catch (InvalidMessageException e) { } catch (InvalidMessageException e) {
exceptions.add(e); exceptions.add(e);
} }
for (SessionState previousState : previousStates) { for (SessionState previousState : previousStates) {
try { try {
byte[] plaintext = decrypt(previousState, ciphertext); return decrypt(previousState, ciphertext);
sessionStore.storeSession(recipientId, deviceId, sessionRecord);
return plaintext;
} catch (InvalidMessageException e) { } catch (InvalidMessageException e) {
exceptions.add(e); exceptions.add(e);
} }
@ -240,7 +242,6 @@ public class SessionCipher {
sessionState.clearUnacknowledgedPreKeyMessage(); sessionState.clearUnacknowledgedPreKeyMessage();
return plaintext; return plaintext;
} }
public int getRemoteRegistrationId() { public int getRemoteRegistrationId() {