Implement CdshV2Service.

This commit is contained in:
Greyson Parrelli 2022-02-28 11:22:58 -05:00 committed by Cody Henthorne
parent 7e063e8ad8
commit e552b5160f
6 changed files with 300 additions and 37 deletions

View file

@ -18,15 +18,14 @@ import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;
/**
* Uses CDS to map E164's to UUIDs.
* Uses CDSHv1 to map E164's to UUIDs.
*/
class ContactDiscoveryV3 {
class ContactDiscoveryHsmV1 {
private static final String TAG = Log.tag(ContactDiscoveryV3.class);
private static final String TAG = Log.tag(ContactDiscoveryHsmV1.class);
private static final int MAX_NUMBERS = 20_500;
@ -47,7 +46,7 @@ class ContactDiscoveryV3 {
SignalServiceAccountManager accountManager = ApplicationDependencies.getSignalServiceAccountManager();
try {
Map<String, ACI> results = accountManager.getRegisteredUsersWithCdsh(sanitizedNumbers, BuildConfig.CDSH_PUBLIC_KEY, BuildConfig.CDSH_CODE_HASH);
Map<String, ACI> results = accountManager.getRegisteredUsersWithCdshV1(sanitizedNumbers, BuildConfig.CDSH_PUBLIC_KEY, BuildConfig.CDSH_CODE_HASH);
FuzzyPhoneNumberHelper.OutputResult outputResult = FuzzyPhoneNumberHelper.generateOutput(results, inputResult);
return new DirectoryResult(outputResult.getNumbers(), outputResult.getRewrites(), ignoredNumbers);

View file

@ -233,7 +233,7 @@ public class DirectoryHelper {
DirectoryResult result;
if (FeatureFlags.cdsh()) {
result = ContactDiscoveryV3.getDirectoryResult(databaseNumbers, systemNumbers);
result = ContactDiscoveryHsmV1.getDirectoryResult(databaseNumbers, systemNumbers);
} else {
result = ContactDiscoveryV2.getDirectoryResult(context, databaseNumbers, systemNumbers);
}

View file

@ -44,6 +44,7 @@ import org.whispersystems.signalservice.api.push.exceptions.NonSuccessfulRespons
import org.whispersystems.signalservice.api.push.exceptions.NotFoundException;
import org.whispersystems.signalservice.api.push.exceptions.PushNetworkException;
import org.whispersystems.signalservice.api.services.CdshV1Service;
import org.whispersystems.signalservice.api.services.CdshV2Service;
import org.whispersystems.signalservice.api.storage.SignalStorageCipher;
import org.whispersystems.signalservice.api.storage.SignalStorageManifest;
import org.whispersystems.signalservice.api.storage.SignalStorageModels;
@ -505,7 +506,7 @@ public class SignalServiceAccountManager {
}
}
public Map<String, ACI> getRegisteredUsersWithCdsh(Set<String> e164numbers, String hexPublicKey, String hexCodeHash)
public Map<String, ACI> getRegisteredUsersWithCdshV1(Set<String> e164numbers, String hexPublicKey, String hexCodeHash)
throws IOException
{
CdshAuthResponse auth = pushServiceSocket.getCdshAuth();
@ -530,6 +531,32 @@ public class SignalServiceAccountManager {
}
}
public CdshV2Service.Response getRegisteredUsersWithCdshV2(Set<String> previousE164s, Set<String> newE164s, Map<ServiceId, ProfileKey> serviceIds, Optional<byte[]> token, String hexPublicKey, String hexCodeHash)
throws IOException
{
CdshAuthResponse auth = pushServiceSocket.getCdshAuth();
CdshV2Service service = new CdshV2Service(configuration, hexPublicKey, hexCodeHash);
CdshV2Service.Request request = new CdshV2Service.Request(previousE164s, newE164s, serviceIds, token);
Single<ServiceResponse<CdshV2Service.Response>> single = service.getRegisteredUsers(auth.getUsername(), auth.getPassword(), request);
ServiceResponse<CdshV2Service.Response> serviceResponse;
try {
serviceResponse = single.blockingGet();
} catch (Exception e) {
throw new RuntimeException("Unexpected exception when retrieving registered users!", e);
}
if (serviceResponse.getResult().isPresent()) {
return serviceResponse.getResult().get();
} else if (serviceResponse.getApplicationError().isPresent()) {
throw new IOException(serviceResponse.getApplicationError().get());
} else if (serviceResponse.getExecutionError().isPresent()) {
throw new IOException(serviceResponse.getExecutionError().get());
} else {
throw new IOException("Missing result!");
}
}
public Optional<SignalStorageManifest> getStorageManifest(StorageKey storageKey) throws IOException {
try {

View file

@ -1,31 +1,28 @@
package org.whispersystems.signalservice.api.services;
import org.signal.cds.ClientRequest;
import org.signal.cds.ClientResponse;
import org.signal.libsignal.hsmenclave.HsmEnclaveClient;
import org.whispersystems.libsignal.logging.Log;
import org.whispersystems.libsignal.util.Pair;
import org.whispersystems.signalservice.api.push.ACI;
import org.whispersystems.signalservice.api.push.TrustStore;
import org.whispersystems.signalservice.api.push.exceptions.NonSuccessfulResponseCodeException;
import org.whispersystems.signalservice.api.util.Tls12SocketFactory;
import org.whispersystems.signalservice.internal.ServiceResponse;
import org.whispersystems.signalservice.internal.configuration.SignalServiceConfiguration;
import org.whispersystems.signalservice.internal.util.BlacklistingTrustManager;
import org.whispersystems.signalservice.internal.util.Hex;
import org.whispersystems.signalservice.internal.util.Util;
import org.whispersystems.util.Base64;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;
@ -40,6 +37,9 @@ import okhttp3.Response;
import okhttp3.WebSocket;
import okhttp3.WebSocketListener;
/**
* Handles the websocket and general lifecycle of a CDSH request.
*/
final class CdshSocket {
private static final String TAG = CdshSocket.class.getSimpleName();
@ -49,11 +49,13 @@ final class CdshSocket {
private final String baseUrl;
private final String hexPublicKey;
private final String hexCodeHash;
private final Version version;
CdshSocket(SignalServiceConfiguration configuration, String hexPublicKey, String hexCodeHash) {
CdshSocket(SignalServiceConfiguration configuration, String hexPublicKey, String hexCodeHash, Version version) {
this.baseUrl = configuration.getSignalCdshUrls()[0].getUrl();
this.hexPublicKey = hexPublicKey;
this.hexCodeHash = hexCodeHash;
this.version = version;
Pair<SSLSocketFactory, X509TrustManager> socketFactory = createTlsSocketFactory(configuration.getSignalCdshUrls()[0].getTrustStore());
@ -73,11 +75,11 @@ final class CdshSocket {
}
}
Observable<ClientResponse> connect(String username, String password, List<byte[]> requests) {
Observable<ClientResponse> connect(String username, String password, List<ClientRequest> requests) {
return Observable.create(emitter -> {
AtomicReference<Stage> stage = new AtomicReference<>(Stage.WAITING_TO_INITIALIZE);
AtomicReference<Stage> stage = new AtomicReference<>(Stage.WAITING_TO_INITIALIZE);
String url = String.format("%s/discovery/%s/%s", baseUrl, hexPublicKey, hexCodeHash);
String url = String.format("%s/discovery/%s/%s", baseUrl, hexPublicKey, hexCodeHash);
Request request = new Request.Builder()
.url(url)
.addHeader("Authorization", basicAuth(username, password))
@ -91,8 +93,10 @@ final class CdshSocket {
enclave.completeHandshake(bytes.toByteArray());
stage.set(Stage.WAITING_FOR_RESPONSE);
for (byte[] request : requests) {
webSocket.send(okio.ByteString.of(enclave.establishedSend(request)));
for (ClientRequest request : requests) {
byte[] plaintextBytes = requestToBytes(request, version);
byte[] ciphertextBytes = enclave.establishedSend(plaintextBytes);
webSocket.send(okio.ByteString.of(ciphertextBytes));
}
break;
@ -139,6 +143,17 @@ final class CdshSocket {
});
}
private static byte[] requestToBytes(ClientRequest request, Version version) {
ByteArrayOutputStream requestStream = new ByteArrayOutputStream();
try {
requestStream.write(version.getValue());
requestStream.write(request.toByteArray());
} catch (IOException e) {
throw new AssertionError("Failed to write bytes!");
}
return requestStream.toByteArray();
}
private static String basicAuth(String username, String password) {
return "Basic " + Base64.encodeBytes((username + ":" + password).getBytes(StandardCharsets.UTF_8));
}
@ -158,4 +173,18 @@ final class CdshSocket {
private enum Stage {
WAITING_TO_INITIALIZE, WAITING_FOR_RESPONSE, FAILURE
}
enum Version {
V1(1), V2(2);
private final int value;
Version(int value) {
this.value = value;
}
public int getValue() {
return value;
}
}
}

View file

@ -30,7 +30,6 @@ public final class CdshV1Service {
private static final String TAG = CdshV1Service.class.getSimpleName();
private static final int VERSION = 1;
private static final int MAX_E164S_PER_REQUEST = 5000;
private static final UUID EMPTY_ACI = new UUID(0, 0);
private static final int RESPONSE_ITEM_SIZE = 8 + 16 + 16; // 1 uint64 + 2 UUIDs
@ -38,14 +37,14 @@ public final class CdshV1Service {
private final CdshSocket cdshSocket;
public CdshV1Service(SignalServiceConfiguration configuration, String hexPublicKey, String hexCodeHash) {
this.cdshSocket = new CdshSocket(configuration, hexPublicKey, hexCodeHash);
this.cdshSocket = new CdshSocket(configuration, hexPublicKey, hexCodeHash, CdshSocket.Version.V1);
}
public Single<ServiceResponse<Map<String, ACI>>> getRegisteredUsers(String username, String password, Set<String> e164Numbers) {
List<String> addressBook = e164Numbers.stream().map(e -> e.substring(1)).collect(Collectors.toList());
return cdshSocket
.connect(username, password, buildPlaintextRequests(addressBook))
.connect(username, password, buildClientRequests(addressBook))
.map(CdshV1Service::parseEntries)
.collect(Collectors.toList())
.flatMap(pages -> {
@ -83,10 +82,10 @@ public final class CdshV1Service {
return out;
}
private static List<byte[]> buildPlaintextRequests(List<String> addressBook) {
List<byte[]> out = new ArrayList<>((addressBook.size() / MAX_E164S_PER_REQUEST) + 1);
ByteString.Output e164Page = ByteString.newOutput();
int pageSize = 0;
private static List<ClientRequest> buildClientRequests(List<String> addressBook) {
List<ClientRequest> out = new ArrayList<>((addressBook.size() / MAX_E164S_PER_REQUEST) + 1);
ByteString.Output e164Page = ByteString.newOutput();
int pageSize = 0;
for (String address : addressBook) {
if (pageSize >= MAX_E164S_PER_REQUEST) {
@ -111,17 +110,10 @@ public final class CdshV1Service {
return out;
}
private static byte[] e164sToRequest(ByteString e164s, boolean more) {
try (ByteArrayOutputStream outputStream = new ByteArrayOutputStream()) {
outputStream.write(VERSION);
ClientRequest.newBuilder()
.setNewE164S(e164s)
.setHasMore(more)
.build()
.writeTo(outputStream);
return outputStream.toByteArray();
} catch (IOException e) {
throw new AssertionError("Failed to write protobuf to the output stream?");
}
private static ClientRequest e164sToRequest(ByteString e164s, boolean more) {
return ClientRequest.newBuilder()
.setNewE164S(e164s)
.setHasMore(more)
.build();
}
}

View file

@ -0,0 +1,216 @@
package org.whispersystems.signalservice.api.services;
import com.google.protobuf.ByteString;
import org.signal.cds.ClientRequest;
import org.signal.cds.ClientResponse;
import org.signal.zkgroup.profiles.ProfileKey;
import org.whispersystems.libsignal.util.ByteUtil;
import org.whispersystems.libsignal.util.guava.Optional;
import org.whispersystems.signalservice.api.crypto.UnidentifiedAccess;
import org.whispersystems.signalservice.api.push.ACI;
import org.whispersystems.signalservice.api.push.PNI;
import org.whispersystems.signalservice.api.push.ServiceId;
import org.whispersystems.signalservice.api.push.exceptions.NonSuccessfulResponseCodeException;
import org.whispersystems.signalservice.api.util.UuidUtil;
import org.whispersystems.signalservice.internal.ServiceResponse;
import org.whispersystems.signalservice.internal.configuration.SignalServiceConfiguration;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;
import io.reactivex.rxjava3.core.Single;
/**
* Handles network interactions with CDSHv2, the HSM-backed CDS service.
*/
public final class CdshV2Service {
private static final String TAG = CdshV2Service.class.getSimpleName();
private static final UUID EMPTY_UUID = new UUID(0, 0);
private static final int RESPONSE_ITEM_SIZE = 8 + 16 + 16; // 1 uint64 + 2 UUIDs
private final CdshSocket cdshSocket;
public CdshV2Service(SignalServiceConfiguration configuration, String hexPublicKey, String hexCodeHash) {
this.cdshSocket = new CdshSocket(configuration, hexPublicKey, hexCodeHash, CdshSocket.Version.V2);
}
public Single<ServiceResponse<Response>> getRegisteredUsers(String username, String password, Request request) {
return cdshSocket
.connect(username, password, buildClientRequests(request))
.map(CdshV2Service::parseEntries)
.collect(Collectors.toList())
.flatMap(pages -> {
byte[] token = null;
Map<String, ResponseItem> all = new HashMap<>();
for (Response page : pages) {
all.putAll(page.getResults());
token = token == null ? page.getToken() : token;
}
if (token == null) {
throw new IOException("No token found in response!");
}
return Single.just(new Response(all, token));
})
.map(result -> ServiceResponse.forResult(result, 200, null))
.onErrorReturn(error -> {
if (error instanceof NonSuccessfulResponseCodeException) {
int status = ((NonSuccessfulResponseCodeException) error).getCode();
return ServiceResponse.forApplicationError(error, status, null);
} else {
return ServiceResponse.forUnknownError(error);
}
});
}
private static Response parseEntries(ClientResponse clientResponse) {
byte[] token = !clientResponse.getToken().isEmpty() ? clientResponse.getToken().toByteArray() : null;
Map<String, ResponseItem> results = new HashMap<>();
ByteBuffer parser = clientResponse.getE164PniAciTriples().asReadOnlyByteBuffer();
while (parser.remaining() >= RESPONSE_ITEM_SIZE) {
String e164 = "+" + parser.getLong();
UUID pniUuid = new UUID(parser.getLong(), parser.getLong());
UUID aciUuid = new UUID(parser.getLong(), parser.getLong());
if (!pniUuid.equals(EMPTY_UUID)) {
PNI pni = PNI.from(pniUuid);
ACI aci = aciUuid.equals(EMPTY_UUID) ? null : ACI.from(aciUuid);
results.put(e164, new ResponseItem(pni, Optional.fromNullable(aci)));
}
}
return new Response(results, token);
}
private static List<ClientRequest> buildClientRequests(Request request) {
List<Long> previousE164s = parseAndSortE164Strings(request.previousE164s);
List<Long> newE164s = parseAndSortE164Strings(request.newE164s);
List<Long> removedE164s = parseAndSortE164Strings(request.removedE164s);
return Collections.singletonList(ClientRequest.newBuilder()
.setPrevE164S(toByteString(previousE164s))
.setNewE164S(toByteString(newE164s))
.setDiscardE164S(toByteString(removedE164s))
.setAciUakPairs(toByteString(request.serviceIds))
.setToken(ByteString.copyFrom(request.token))
.setHasMore(false)
.build());
}
private static ByteString toByteString(List<Long> numbers) {
ByteString.Output os = ByteString.newOutput();
for (long number : numbers) {
try {
os.write(ByteUtil.longToByteArray(number));
} catch (IOException e) {
throw new AssertionError("Failed to write long to ByteString", e);
}
}
return os.toByteString();
}
private static ByteString toByteString(Map<ServiceId, ProfileKey> serviceIds) {
ByteString.Output os = ByteString.newOutput();
for (Map.Entry<ServiceId, ProfileKey> entry : serviceIds.entrySet()) {
try {
os.write(UuidUtil.toByteArray(entry.getKey().uuid()));
os.write(UnidentifiedAccess.deriveAccessKeyFrom(entry.getValue()));
} catch (IOException e) {
throw new AssertionError("Failed to write long to ByteString", e);
}
}
return os.toByteString();
}
private static List<Long> parseAndSortE164Strings(Collection<String> e164s) {
return e164s.stream()
.map(Long::parseLong)
.sorted()
.collect(Collectors.toList());
}
public static final class Request {
private final Set<String> previousE164s;
private final Set<String> newE164s;
private final Set<String> removedE164s;
private final Map<ServiceId, ProfileKey> serviceIds;
private final byte[] token;
public Request(Set<String> previousE164s, Set<String> newE164s, Map<ServiceId, ProfileKey> serviceIds, Optional<byte[]> token) {
this.previousE164s = previousE164s;
this.newE164s = newE164s;
this.removedE164s = Collections.emptySet();
this.serviceIds = serviceIds;
this.token = token.isPresent() ? token.get() : new byte[32];
}
public int totalE164s() {
return previousE164s.size() + newE164s.size() - removedE164s.size();
}
public int serviceIdSize() {
return previousE164s.size() + newE164s.size() + removedE164s.size() + serviceIds.size();
}
}
public static final class Response {
private final Map<String, ResponseItem> results;
private final byte[] token;
public Response(Map<String, ResponseItem> results, byte[] token) {
this.results = results;
this.token = token;
}
public Map<String, ResponseItem> getResults() {
return results;
}
public byte[] getToken() {
return token;
}
}
public static final class ResponseItem {
private final PNI pni;
private final Optional<ACI> aci;
public ResponseItem(PNI pni, Optional<ACI> aci) {
this.pni = pni;
this.aci = aci;
}
public PNI getPni() {
return pni;
}
public Optional<ACI> getAci() {
return aci;
}
public boolean hasAci() {
return aci.isPresent();
}
}
}