From 7e063e8ad8d4fd2cc5cab0aa7b6ee58d0f54edd6 Mon Sep 17 00:00:00 2001 From: Greyson Parrelli Date: Fri, 25 Feb 2022 15:16:15 -0500 Subject: [PATCH] Refactor CDSH to allow for code reuse. --- .../api/SignalServiceAccountManager.java | 4 +- .../{CdshService.java => CdshSocket.java} | 112 +++------------ .../api/services/CdshV1Service.java | 127 ++++++++++++++++++ 3 files changed, 147 insertions(+), 96 deletions(-) rename libsignal/service/src/main/java/org/whispersystems/signalservice/api/services/{CdshService.java => CdshSocket.java} (57%) create mode 100644 libsignal/service/src/main/java/org/whispersystems/signalservice/api/services/CdshV1Service.java diff --git a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/SignalServiceAccountManager.java b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/SignalServiceAccountManager.java index fb15e571d2..0d76484196 100644 --- a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/SignalServiceAccountManager.java +++ b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/SignalServiceAccountManager.java @@ -43,7 +43,7 @@ import org.whispersystems.signalservice.api.push.exceptions.NoContentException; import org.whispersystems.signalservice.api.push.exceptions.NonSuccessfulResponseCodeException; import org.whispersystems.signalservice.api.push.exceptions.NotFoundException; import org.whispersystems.signalservice.api.push.exceptions.PushNetworkException; -import org.whispersystems.signalservice.api.services.CdshService; +import org.whispersystems.signalservice.api.services.CdshV1Service; import org.whispersystems.signalservice.api.storage.SignalStorageCipher; import org.whispersystems.signalservice.api.storage.SignalStorageManifest; import org.whispersystems.signalservice.api.storage.SignalStorageModels; @@ -509,7 +509,7 @@ public class SignalServiceAccountManager { throws IOException { CdshAuthResponse auth = pushServiceSocket.getCdshAuth(); - CdshService service = new CdshService(configuration, hexPublicKey, hexCodeHash); + CdshV1Service service = new CdshV1Service(configuration, hexPublicKey, hexCodeHash); Single>> result = service.getRegisteredUsers(auth.getUsername(), auth.getPassword(), e164numbers); ServiceResponse> response; diff --git a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/services/CdshService.java b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/services/CdshSocket.java similarity index 57% rename from libsignal/service/src/main/java/org/whispersystems/signalservice/api/services/CdshService.java rename to libsignal/service/src/main/java/org/whispersystems/signalservice/api/services/CdshSocket.java index e8a6c6f7a2..536beb8d75 100644 --- a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/services/CdshService.java +++ b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/services/CdshSocket.java @@ -1,12 +1,8 @@ package org.whispersystems.signalservice.api.services; -import com.google.protobuf.ByteString; - -import org.signal.cds.ClientRequest; import org.signal.cds.ClientResponse; import org.signal.libsignal.hsmenclave.HsmEnclaveClient; import org.whispersystems.libsignal.logging.Log; -import org.whispersystems.libsignal.util.ByteUtil; import org.whispersystems.libsignal.util.Pair; import org.whispersystems.signalservice.api.push.ACI; import org.whispersystems.signalservice.api.push.TrustStore; @@ -19,21 +15,14 @@ import org.whispersystems.signalservice.internal.util.Hex; import org.whispersystems.signalservice.internal.util.Util; import org.whispersystems.util.Base64; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.DataInputStream; import java.io.IOException; -import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.security.KeyManagementException; import java.security.NoSuchAlgorithmException; -import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Set; -import java.util.UUID; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; @@ -43,7 +32,7 @@ import javax.net.ssl.SSLSocketFactory; import javax.net.ssl.TrustManager; import javax.net.ssl.X509TrustManager; -import io.reactivex.rxjava3.core.Single; +import io.reactivex.rxjava3.core.Observable; import okhttp3.ConnectionSpec; import okhttp3.OkHttpClient; import okhttp3.Request; @@ -51,17 +40,9 @@ import okhttp3.Response; import okhttp3.WebSocket; import okhttp3.WebSocketListener; -/** - * Handles network interactions with CDSH, the HSM-backed CDS service. - */ -public final class CdshService { +final class CdshSocket { - private static final String TAG = CdshService.class.getSimpleName(); - - private static final int VERSION = 1; - private static final int MAX_E164S_PER_REQUEST = 5000; - private static final UUID EMPTY_ACI = new UUID(0, 0); - private static final int RESPONSE_ITEM_SIZE = 8 + 16 + 16; // 1 uint64 + 2 UUIDs + private static final String TAG = CdshSocket.class.getSimpleName(); private final OkHttpClient client; private final HsmEnclaveClient enclave; @@ -69,7 +50,7 @@ public final class CdshService { private final String hexPublicKey; private final String hexCodeHash; - public CdshService(SignalServiceConfiguration configuration, String hexPublicKey, String hexCodeHash) { + CdshSocket(SignalServiceConfiguration configuration, String hexPublicKey, String hexCodeHash) { this.baseUrl = configuration.getSignalCdshUrls()[0].getUrl(); this.hexPublicKey = hexPublicKey; this.hexCodeHash = hexCodeHash; @@ -92,17 +73,15 @@ public final class CdshService { } } - public Single>> getRegisteredUsers(String username, String password, Set e164Numbers) { - return Single.create(emitter -> { + Observable connect(String username, String password, List requests) { + return Observable.create(emitter -> { AtomicReference stage = new AtomicReference<>(Stage.WAITING_TO_INITIALIZE); - List addressBook = e164Numbers.stream().map(e -> e.substring(1)).collect(Collectors.toList()); - final Map out = new HashMap<>(); - String url = String.format("%s/discovery/%s/%s", baseUrl, hexPublicKey, hexCodeHash); - Request request = new Request.Builder() - .url(url) - .addHeader("Authorization", basicAuth(username, password)) - .build(); + String url = String.format("%s/discovery/%s/%s", baseUrl, hexPublicKey, hexCodeHash); + Request request = new Request.Builder() + .url(url) + .addHeader("Authorization", basicAuth(username, password)) + .build(); WebSocket webSocket = client.newWebSocket(request, new WebSocketListener() { @Override @@ -112,7 +91,7 @@ public final class CdshService { enclave.completeHandshake(bytes.toByteArray()); stage.set(Stage.WAITING_FOR_RESPONSE); - for (byte[] request : buildPlaintextRequests(addressBook)) { + for (byte[] request : requests) { webSocket.send(okio.ByteString.of(enclave.establishedSend(request))); } @@ -122,9 +101,9 @@ public final class CdshService { try { ClientResponse clientResponse = ClientResponse.parseFrom(rawResponse); - addClientResponseToOutput(clientResponse, out); + emitter.onNext(clientResponse); } catch (IOException e) { - emitter.onSuccess(ServiceResponse.forUnknownError(e)); + emitter.onError(e); } break; @@ -138,18 +117,18 @@ public final class CdshService { @Override public void onClosing(WebSocket webSocket, int code, String reason) { if (code == 1000) { - emitter.onSuccess(ServiceResponse.forResult(out, 200, null)); + emitter.onComplete(); } else { Log.w(TAG, "Remote side is closing with non-normal code " + code); webSocket.close(1000, "Remote closed with code " + code); stage.set(Stage.FAILURE); - emitter.onSuccess(ServiceResponse.forApplicationError(new NonSuccessfulResponseCodeException(code), code, null)); + emitter.onError(new NonSuccessfulResponseCodeException(code)); } } @Override public void onFailure(WebSocket webSocket, Throwable t, Response response) { - emitter.onSuccess(ServiceResponse.forApplicationError(t, response != null ? response.code() : 0, null)); + emitter.onError(t); stage.set(Stage.FAILURE); webSocket.close(1000, "OK"); } @@ -160,65 +139,10 @@ public final class CdshService { }); } - private static void addClientResponseToOutput(ClientResponse responsePB, Map out) { - ByteBuffer parser = responsePB.getE164PniAciTriples().asReadOnlyByteBuffer(); - while (parser.remaining() >= RESPONSE_ITEM_SIZE) { - String e164 = "+" + parser.getLong(); - UUID unusedPni = new UUID(parser.getLong(), parser.getLong()); - UUID aci = new UUID(parser.getLong(), parser.getLong()); - - if (!aci.equals(EMPTY_ACI)) { - out.put(e164, ACI.from(aci)); - } - } - } - - private String basicAuth(String username, String password) { + private static String basicAuth(String username, String password) { return "Basic " + Base64.encodeBytes((username + ":" + password).getBytes(StandardCharsets.UTF_8)); } - private static byte[] e164sToRequest(ByteString e164s, boolean more) { - try (ByteArrayOutputStream outputStream = new ByteArrayOutputStream()) { - outputStream.write(VERSION); - ClientRequest.newBuilder() - .setNewE164S(e164s) - .setHasMore(more) - .build() - .writeTo(outputStream); - return outputStream.toByteArray(); - } catch (IOException e) { - throw new AssertionError("Failed to write protobuf to the output stream?"); - } - } - - private static List buildPlaintextRequests(List addressBook) { - List out = new ArrayList<>((addressBook.size() / MAX_E164S_PER_REQUEST) + 1); - ByteString.Output e164Page = ByteString.newOutput(); - int pageSize = 0; - - for (String address : addressBook) { - if (pageSize >= MAX_E164S_PER_REQUEST) { - pageSize = 0; - out.add(e164sToRequest(e164Page.toByteString(), true)); - e164Page = ByteString.newOutput(); - } - - try { - e164Page.write(ByteUtil.longToByteArray(Long.parseLong(address))); - } catch (IOException e) { - throw new AssertionError("Failed to write long to ByteString", e); - } - - pageSize++; - } - - if (pageSize > 0) { - out.add(e164sToRequest(e164Page.toByteString(), false)); - } - - return out; - } - private static Pair createTlsSocketFactory(TrustStore trustStore) { try { SSLContext context = SSLContext.getInstance("TLS"); diff --git a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/services/CdshV1Service.java b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/services/CdshV1Service.java new file mode 100644 index 0000000000..3f27b5b07e --- /dev/null +++ b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/services/CdshV1Service.java @@ -0,0 +1,127 @@ +package org.whispersystems.signalservice.api.services; + +import com.google.protobuf.ByteString; + +import org.signal.cds.ClientRequest; +import org.signal.cds.ClientResponse; +import org.whispersystems.libsignal.util.ByteUtil; +import org.whispersystems.signalservice.api.push.ACI; +import org.whispersystems.signalservice.api.push.exceptions.NonSuccessfulResponseCodeException; +import org.whispersystems.signalservice.internal.ServiceResponse; +import org.whispersystems.signalservice.internal.configuration.SignalServiceConfiguration; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.stream.Collectors; + +import io.reactivex.rxjava3.core.Single; + +/** + * Handles network interactions with CDSHv1, the HSM-backed CDS service. + */ +public final class CdshV1Service { + + private static final String TAG = CdshV1Service.class.getSimpleName(); + + private static final int VERSION = 1; + private static final int MAX_E164S_PER_REQUEST = 5000; + private static final UUID EMPTY_ACI = new UUID(0, 0); + private static final int RESPONSE_ITEM_SIZE = 8 + 16 + 16; // 1 uint64 + 2 UUIDs + + private final CdshSocket cdshSocket; + + public CdshV1Service(SignalServiceConfiguration configuration, String hexPublicKey, String hexCodeHash) { + this.cdshSocket = new CdshSocket(configuration, hexPublicKey, hexCodeHash); + } + + public Single>> getRegisteredUsers(String username, String password, Set e164Numbers) { + List addressBook = e164Numbers.stream().map(e -> e.substring(1)).collect(Collectors.toList()); + + return cdshSocket + .connect(username, password, buildPlaintextRequests(addressBook)) + .map(CdshV1Service::parseEntries) + .collect(Collectors.toList()) + .flatMap(pages -> { + Map all = new HashMap<>(); + for (Map page : pages) { + all.putAll(page); + } + return Single.just(all); + }) + .map(result -> ServiceResponse.forResult(result, 200, null)) + .onErrorReturn(error -> { + if (error instanceof NonSuccessfulResponseCodeException) { + int status = ((NonSuccessfulResponseCodeException) error).getCode(); + return ServiceResponse.forApplicationError(error, status, null); + } else { + return ServiceResponse.forUnknownError(error); + } + }); + } + + private static Map parseEntries(ClientResponse clientResponse) { + Map out = new HashMap<>(); + ByteBuffer parser = clientResponse.getE164PniAciTriples().asReadOnlyByteBuffer(); + + while (parser.remaining() >= RESPONSE_ITEM_SIZE) { + String e164 = "+" + parser.getLong(); + UUID unusedPni = new UUID(parser.getLong(), parser.getLong()); + UUID aci = new UUID(parser.getLong(), parser.getLong()); + + if (!aci.equals(EMPTY_ACI)) { + out.put(e164, ACI.from(aci)); + } + } + + return out; + } + + private static List buildPlaintextRequests(List addressBook) { + List out = new ArrayList<>((addressBook.size() / MAX_E164S_PER_REQUEST) + 1); + ByteString.Output e164Page = ByteString.newOutput(); + int pageSize = 0; + + for (String address : addressBook) { + if (pageSize >= MAX_E164S_PER_REQUEST) { + pageSize = 0; + out.add(e164sToRequest(e164Page.toByteString(), true)); + e164Page = ByteString.newOutput(); + } + + try { + e164Page.write(ByteUtil.longToByteArray(Long.parseLong(address))); + } catch (IOException e) { + throw new AssertionError("Failed to write long to ByteString", e); + } + + pageSize++; + } + + if (pageSize > 0) { + out.add(e164sToRequest(e164Page.toByteString(), false)); + } + + return out; + } + + private static byte[] e164sToRequest(ByteString e164s, boolean more) { + try (ByteArrayOutputStream outputStream = new ByteArrayOutputStream()) { + outputStream.write(VERSION); + ClientRequest.newBuilder() + .setNewE164S(e164s) + .setHasMore(more) + .build() + .writeTo(outputStream); + return outputStream.toByteArray(); + } catch (IOException e) { + throw new AssertionError("Failed to write protobuf to the output stream?"); + } + } +}