Refactor CDSH to allow for code reuse.

This commit is contained in:
Greyson Parrelli 2022-02-25 15:16:15 -05:00 committed by Cody Henthorne
parent 88a34936cd
commit 7e063e8ad8
3 changed files with 147 additions and 96 deletions

View file

@ -43,7 +43,7 @@ import org.whispersystems.signalservice.api.push.exceptions.NoContentException;
import org.whispersystems.signalservice.api.push.exceptions.NonSuccessfulResponseCodeException;
import org.whispersystems.signalservice.api.push.exceptions.NotFoundException;
import org.whispersystems.signalservice.api.push.exceptions.PushNetworkException;
import org.whispersystems.signalservice.api.services.CdshService;
import org.whispersystems.signalservice.api.services.CdshV1Service;
import org.whispersystems.signalservice.api.storage.SignalStorageCipher;
import org.whispersystems.signalservice.api.storage.SignalStorageManifest;
import org.whispersystems.signalservice.api.storage.SignalStorageModels;
@ -509,7 +509,7 @@ public class SignalServiceAccountManager {
throws IOException
{
CdshAuthResponse auth = pushServiceSocket.getCdshAuth();
CdshService service = new CdshService(configuration, hexPublicKey, hexCodeHash);
CdshV1Service service = new CdshV1Service(configuration, hexPublicKey, hexCodeHash);
Single<ServiceResponse<Map<String, ACI>>> result = service.getRegisteredUsers(auth.getUsername(), auth.getPassword(), e164numbers);
ServiceResponse<Map<String, ACI>> response;

View file

@ -1,12 +1,8 @@
package org.whispersystems.signalservice.api.services;
import com.google.protobuf.ByteString;
import org.signal.cds.ClientRequest;
import org.signal.cds.ClientResponse;
import org.signal.libsignal.hsmenclave.HsmEnclaveClient;
import org.whispersystems.libsignal.logging.Log;
import org.whispersystems.libsignal.util.ByteUtil;
import org.whispersystems.libsignal.util.Pair;
import org.whispersystems.signalservice.api.push.ACI;
import org.whispersystems.signalservice.api.push.TrustStore;
@ -19,21 +15,14 @@ import org.whispersystems.signalservice.internal.util.Hex;
import org.whispersystems.signalservice.internal.util.Util;
import org.whispersystems.util.Base64;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
@ -43,7 +32,7 @@ import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
import io.reactivex.rxjava3.core.Single;
import io.reactivex.rxjava3.core.Observable;
import okhttp3.ConnectionSpec;
import okhttp3.OkHttpClient;
import okhttp3.Request;
@ -51,17 +40,9 @@ import okhttp3.Response;
import okhttp3.WebSocket;
import okhttp3.WebSocketListener;
/**
* Handles network interactions with CDSH, the HSM-backed CDS service.
*/
public final class CdshService {
final class CdshSocket {
private static final String TAG = CdshService.class.getSimpleName();
private static final int VERSION = 1;
private static final int MAX_E164S_PER_REQUEST = 5000;
private static final UUID EMPTY_ACI = new UUID(0, 0);
private static final int RESPONSE_ITEM_SIZE = 8 + 16 + 16; // 1 uint64 + 2 UUIDs
private static final String TAG = CdshSocket.class.getSimpleName();
private final OkHttpClient client;
private final HsmEnclaveClient enclave;
@ -69,7 +50,7 @@ public final class CdshService {
private final String hexPublicKey;
private final String hexCodeHash;
public CdshService(SignalServiceConfiguration configuration, String hexPublicKey, String hexCodeHash) {
CdshSocket(SignalServiceConfiguration configuration, String hexPublicKey, String hexCodeHash) {
this.baseUrl = configuration.getSignalCdshUrls()[0].getUrl();
this.hexPublicKey = hexPublicKey;
this.hexCodeHash = hexCodeHash;
@ -92,17 +73,15 @@ public final class CdshService {
}
}
public Single<ServiceResponse<Map<String, ACI>>> getRegisteredUsers(String username, String password, Set<String> e164Numbers) {
return Single.create(emitter -> {
Observable<ClientResponse> connect(String username, String password, List<byte[]> requests) {
return Observable.create(emitter -> {
AtomicReference<Stage> stage = new AtomicReference<>(Stage.WAITING_TO_INITIALIZE);
List<String> addressBook = e164Numbers.stream().map(e -> e.substring(1)).collect(Collectors.toList());
final Map<String, ACI> out = new HashMap<>();
String url = String.format("%s/discovery/%s/%s", baseUrl, hexPublicKey, hexCodeHash);
Request request = new Request.Builder()
.url(url)
.addHeader("Authorization", basicAuth(username, password))
.build();
String url = String.format("%s/discovery/%s/%s", baseUrl, hexPublicKey, hexCodeHash);
Request request = new Request.Builder()
.url(url)
.addHeader("Authorization", basicAuth(username, password))
.build();
WebSocket webSocket = client.newWebSocket(request, new WebSocketListener() {
@Override
@ -112,7 +91,7 @@ public final class CdshService {
enclave.completeHandshake(bytes.toByteArray());
stage.set(Stage.WAITING_FOR_RESPONSE);
for (byte[] request : buildPlaintextRequests(addressBook)) {
for (byte[] request : requests) {
webSocket.send(okio.ByteString.of(enclave.establishedSend(request)));
}
@ -122,9 +101,9 @@ public final class CdshService {
try {
ClientResponse clientResponse = ClientResponse.parseFrom(rawResponse);
addClientResponseToOutput(clientResponse, out);
emitter.onNext(clientResponse);
} catch (IOException e) {
emitter.onSuccess(ServiceResponse.forUnknownError(e));
emitter.onError(e);
}
break;
@ -138,18 +117,18 @@ public final class CdshService {
@Override
public void onClosing(WebSocket webSocket, int code, String reason) {
if (code == 1000) {
emitter.onSuccess(ServiceResponse.forResult(out, 200, null));
emitter.onComplete();
} else {
Log.w(TAG, "Remote side is closing with non-normal code " + code);
webSocket.close(1000, "Remote closed with code " + code);
stage.set(Stage.FAILURE);
emitter.onSuccess(ServiceResponse.forApplicationError(new NonSuccessfulResponseCodeException(code), code, null));
emitter.onError(new NonSuccessfulResponseCodeException(code));
}
}
@Override
public void onFailure(WebSocket webSocket, Throwable t, Response response) {
emitter.onSuccess(ServiceResponse.forApplicationError(t, response != null ? response.code() : 0, null));
emitter.onError(t);
stage.set(Stage.FAILURE);
webSocket.close(1000, "OK");
}
@ -160,65 +139,10 @@ public final class CdshService {
});
}
private static void addClientResponseToOutput(ClientResponse responsePB, Map<String, ACI> out) {
ByteBuffer parser = responsePB.getE164PniAciTriples().asReadOnlyByteBuffer();
while (parser.remaining() >= RESPONSE_ITEM_SIZE) {
String e164 = "+" + parser.getLong();
UUID unusedPni = new UUID(parser.getLong(), parser.getLong());
UUID aci = new UUID(parser.getLong(), parser.getLong());
if (!aci.equals(EMPTY_ACI)) {
out.put(e164, ACI.from(aci));
}
}
}
private String basicAuth(String username, String password) {
private static String basicAuth(String username, String password) {
return "Basic " + Base64.encodeBytes((username + ":" + password).getBytes(StandardCharsets.UTF_8));
}
private static byte[] e164sToRequest(ByteString e164s, boolean more) {
try (ByteArrayOutputStream outputStream = new ByteArrayOutputStream()) {
outputStream.write(VERSION);
ClientRequest.newBuilder()
.setNewE164S(e164s)
.setHasMore(more)
.build()
.writeTo(outputStream);
return outputStream.toByteArray();
} catch (IOException e) {
throw new AssertionError("Failed to write protobuf to the output stream?");
}
}
private static List<byte[]> buildPlaintextRequests(List<String> addressBook) {
List<byte[]> out = new ArrayList<>((addressBook.size() / MAX_E164S_PER_REQUEST) + 1);
ByteString.Output e164Page = ByteString.newOutput();
int pageSize = 0;
for (String address : addressBook) {
if (pageSize >= MAX_E164S_PER_REQUEST) {
pageSize = 0;
out.add(e164sToRequest(e164Page.toByteString(), true));
e164Page = ByteString.newOutput();
}
try {
e164Page.write(ByteUtil.longToByteArray(Long.parseLong(address)));
} catch (IOException e) {
throw new AssertionError("Failed to write long to ByteString", e);
}
pageSize++;
}
if (pageSize > 0) {
out.add(e164sToRequest(e164Page.toByteString(), false));
}
return out;
}
private static Pair<SSLSocketFactory, X509TrustManager> createTlsSocketFactory(TrustStore trustStore) {
try {
SSLContext context = SSLContext.getInstance("TLS");

View file

@ -0,0 +1,127 @@
package org.whispersystems.signalservice.api.services;
import com.google.protobuf.ByteString;
import org.signal.cds.ClientRequest;
import org.signal.cds.ClientResponse;
import org.whispersystems.libsignal.util.ByteUtil;
import org.whispersystems.signalservice.api.push.ACI;
import org.whispersystems.signalservice.api.push.exceptions.NonSuccessfulResponseCodeException;
import org.whispersystems.signalservice.internal.ServiceResponse;
import org.whispersystems.signalservice.internal.configuration.SignalServiceConfiguration;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;
import io.reactivex.rxjava3.core.Single;
/**
* Handles network interactions with CDSHv1, the HSM-backed CDS service.
*/
public final class CdshV1Service {
private static final String TAG = CdshV1Service.class.getSimpleName();
private static final int VERSION = 1;
private static final int MAX_E164S_PER_REQUEST = 5000;
private static final UUID EMPTY_ACI = new UUID(0, 0);
private static final int RESPONSE_ITEM_SIZE = 8 + 16 + 16; // 1 uint64 + 2 UUIDs
private final CdshSocket cdshSocket;
public CdshV1Service(SignalServiceConfiguration configuration, String hexPublicKey, String hexCodeHash) {
this.cdshSocket = new CdshSocket(configuration, hexPublicKey, hexCodeHash);
}
public Single<ServiceResponse<Map<String, ACI>>> getRegisteredUsers(String username, String password, Set<String> e164Numbers) {
List<String> addressBook = e164Numbers.stream().map(e -> e.substring(1)).collect(Collectors.toList());
return cdshSocket
.connect(username, password, buildPlaintextRequests(addressBook))
.map(CdshV1Service::parseEntries)
.collect(Collectors.toList())
.flatMap(pages -> {
Map<String, ACI> all = new HashMap<>();
for (Map<String, ACI> page : pages) {
all.putAll(page);
}
return Single.just(all);
})
.map(result -> ServiceResponse.forResult(result, 200, null))
.onErrorReturn(error -> {
if (error instanceof NonSuccessfulResponseCodeException) {
int status = ((NonSuccessfulResponseCodeException) error).getCode();
return ServiceResponse.forApplicationError(error, status, null);
} else {
return ServiceResponse.forUnknownError(error);
}
});
}
private static Map<String, ACI> parseEntries(ClientResponse clientResponse) {
Map<String, ACI> out = new HashMap<>();
ByteBuffer parser = clientResponse.getE164PniAciTriples().asReadOnlyByteBuffer();
while (parser.remaining() >= RESPONSE_ITEM_SIZE) {
String e164 = "+" + parser.getLong();
UUID unusedPni = new UUID(parser.getLong(), parser.getLong());
UUID aci = new UUID(parser.getLong(), parser.getLong());
if (!aci.equals(EMPTY_ACI)) {
out.put(e164, ACI.from(aci));
}
}
return out;
}
private static List<byte[]> buildPlaintextRequests(List<String> addressBook) {
List<byte[]> out = new ArrayList<>((addressBook.size() / MAX_E164S_PER_REQUEST) + 1);
ByteString.Output e164Page = ByteString.newOutput();
int pageSize = 0;
for (String address : addressBook) {
if (pageSize >= MAX_E164S_PER_REQUEST) {
pageSize = 0;
out.add(e164sToRequest(e164Page.toByteString(), true));
e164Page = ByteString.newOutput();
}
try {
e164Page.write(ByteUtil.longToByteArray(Long.parseLong(address)));
} catch (IOException e) {
throw new AssertionError("Failed to write long to ByteString", e);
}
pageSize++;
}
if (pageSize > 0) {
out.add(e164sToRequest(e164Page.toByteString(), false));
}
return out;
}
private static byte[] e164sToRequest(ByteString e164s, boolean more) {
try (ByteArrayOutputStream outputStream = new ByteArrayOutputStream()) {
outputStream.write(VERSION);
ClientRequest.newBuilder()
.setNewE164S(e164s)
.setHasMore(more)
.build()
.writeTo(outputStream);
return outputStream.toByteArray();
} catch (IOException e) {
throw new AssertionError("Failed to write protobuf to the output stream?");
}
}
}