Skip to content

Commit

Permalink
Retire the legacy disconnection request system
Browse files Browse the repository at this point in the history
  • Loading branch information
jon-signal committed Nov 12, 2024
1 parent d6f890c commit 3fefb24
Show file tree
Hide file tree
Showing 18 changed files with 20 additions and 136 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,7 @@ public void run(WhisperServerConfiguration config, Environment environment) thro
new ClientPublicKeysManager(clientPublicKeys, accountLockManager, accountLockExecutor);
AccountsManager accountsManager = new AccountsManager(accounts, phoneNumberIdentifiers, cacheCluster,
pubsubClient, accountLockManager, keysManager, messagesManager, profilesManager,
secureStorageClient, secureValueRecovery2Client, disconnectionRequestManager, webSocketConnectionEventManager,
secureStorageClient, secureValueRecovery2Client, disconnectionRequestManager,
registrationRecoveryPasswordsManager, clientPublicKeysManager, accountLockExecutor, messagePollExecutor,
clock, config.getLinkDeviceSecretConfiguration().secret().value(), dynamicConfigurationManager);
RemoteConfigsManager remoteConfigsManager = new RemoteConfigsManager(remoteConfigs);
Expand Down Expand Up @@ -651,7 +651,7 @@ public void run(WhisperServerConfiguration config, Environment environment) thro
disconnectionRequestManager.addListener(webSocketConnectionEventManager);

final RegistrationLockVerificationManager registrationLockVerificationManager = new RegistrationLockVerificationManager(
accountsManager, disconnectionRequestManager, webSocketConnectionEventManager, svr2CredentialsGenerator, svr3CredentialsGenerator,
accountsManager, disconnectionRequestManager, svr2CredentialsGenerator, svr3CredentialsGenerator,
registrationRecoveryPasswordsManager, pushNotificationManager, rateLimiters);

final ReportedMessageMetricsListener reportedMessageMetricsListener = new ReportedMessageMetricsListener(
Expand Down Expand Up @@ -982,7 +982,7 @@ protected void configureServer(final ServerBuilder<?> serverBuilder) {
environment.jersey().register(new AuthDynamicFeature(accountAuthFilter));
environment.jersey().register(new AuthValueFactoryProvider.Binder<>(AuthenticatedDevice.class));
environment.jersey().register(new WebsocketRefreshApplicationEventListener(accountsManager,
disconnectionRequestManager, webSocketConnectionEventManager));
disconnectionRequestManager));
environment.jersey().register(new TimestampResponseFilter());

///
Expand All @@ -995,7 +995,7 @@ protected void configureServer(final ServerBuilder<?> serverBuilder) {
pushNotificationScheduler, webSocketConnectionEventManager, websocketScheduledExecutor,
messageDeliveryScheduler, clientReleaseManager, messageDeliveryLoopMonitor));
webSocketEnvironment.jersey()
.register(new WebsocketRefreshApplicationEventListener(accountsManager, disconnectionRequestManager, webSocketConnectionEventManager));
.register(new WebsocketRefreshApplicationEventListener(accountsManager, disconnectionRequestManager));
webSocketEnvironment.jersey().register(new RateLimitByIpFilter(rateLimiters));
webSocketEnvironment.jersey().register(new RequestStatisticsFilter(TrafficSource.WEBSOCKET));
webSocketEnvironment.jersey().register(MultiRecipientMessageProvider.class);
Expand Down Expand Up @@ -1138,7 +1138,7 @@ protected void configureServer(final ServerBuilder<?> serverBuilder) {
WebSocketEnvironment<AuthenticatedDevice> provisioningEnvironment = new WebSocketEnvironment<>(environment,
webSocketEnvironment.getRequestLog(), Duration.ofMillis(60000));
provisioningEnvironment.jersey().register(new WebsocketRefreshApplicationEventListener(accountsManager,
disconnectionRequestManager, webSocketConnectionEventManager));
disconnectionRequestManager));
provisioningEnvironment.setConnectListener(new ProvisioningConnectListener(provisioningManager));
provisioningEnvironment.jersey().register(new MetricsApplicationEventListener(TrafficSource.WEBSOCKET, clientReleaseManager));
provisioningEnvironment.jersey().register(new KeepAliveController(webSocketConnectionEventManager));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ public enum Flow {

private final AccountsManager accounts;
private final DisconnectionRequestManager disconnectionRequestManager;
private final WebSocketConnectionEventManager webSocketConnectionEventManager;
private final ExternalServiceCredentialsGenerator svr2CredentialGenerator;
private final ExternalServiceCredentialsGenerator svr3CredentialGenerator;
private final RateLimiters rateLimiters;
Expand All @@ -65,15 +64,13 @@ public enum Flow {
public RegistrationLockVerificationManager(
final AccountsManager accounts,
final DisconnectionRequestManager disconnectionRequestManager,
final WebSocketConnectionEventManager webSocketConnectionEventManager,
final ExternalServiceCredentialsGenerator svr2CredentialGenerator,
final ExternalServiceCredentialsGenerator svr3CredentialGenerator,
final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager,
final PushNotificationManager pushNotificationManager,
final RateLimiters rateLimiters) {
this.accounts = accounts;
this.disconnectionRequestManager = disconnectionRequestManager;
this.webSocketConnectionEventManager = webSocketConnectionEventManager;
this.svr2CredentialGenerator = svr2CredentialGenerator;
this.svr3CredentialGenerator = svr3CredentialGenerator;
this.registrationRecoveryPasswordsManager = registrationRecoveryPasswordsManager;
Expand Down Expand Up @@ -164,7 +161,6 @@ public void verifyRegistrationLock(final Account account, @Nullable final String
}

final List<Byte> deviceIds = updatedAccount.getDevices().stream().map(Device::getId).toList();
webSocketConnectionEventManager.requestDisconnection(updatedAccount.getUuid(), deviceIds);
disconnectionRequestManager.requestDisconnection(updatedAccount.getUuid(), deviceIds);

try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import org.glassfish.jersey.server.monitoring.ApplicationEventListener;
import org.glassfish.jersey.server.monitoring.RequestEvent;
import org.glassfish.jersey.server.monitoring.RequestEventListener;
import org.whispersystems.textsecuregcm.push.WebSocketConnectionEventManager;
import org.whispersystems.textsecuregcm.storage.AccountsManager;

/**
Expand All @@ -20,12 +19,10 @@ public class WebsocketRefreshApplicationEventListener implements ApplicationEven
private final WebsocketRefreshRequestEventListener websocketRefreshRequestEventListener;

public WebsocketRefreshApplicationEventListener(final AccountsManager accountsManager,
final DisconnectionRequestManager disconnectionRequestManager,
final WebSocketConnectionEventManager webSocketConnectionEventManager) {
final DisconnectionRequestManager disconnectionRequestManager) {

this.websocketRefreshRequestEventListener = new WebsocketRefreshRequestEventListener(
disconnectionRequestManager,
webSocketConnectionEventManager,
new LinkedDeviceRefreshRequirementProvider(accountsManager),
new PhoneNumberChangeRefreshRequirementProvider(accountsManager));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,10 @@
import org.glassfish.jersey.server.monitoring.RequestEventListener;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.push.WebSocketConnectionEventManager;

public class WebsocketRefreshRequestEventListener implements RequestEventListener {

private final DisconnectionRequestManager disconnectionRequestManager;
private final WebSocketConnectionEventManager webSocketConnectionEventManager;
private final WebsocketRefreshRequirementProvider[] providers;

private static final Counter DISPLACED_ACCOUNTS = Metrics.counter(
Expand All @@ -37,11 +35,9 @@ public class WebsocketRefreshRequestEventListener implements RequestEventListene

public WebsocketRefreshRequestEventListener(
final DisconnectionRequestManager disconnectionRequestManager,
final WebSocketConnectionEventManager webSocketConnectionEventManager,
final WebsocketRefreshRequirementProvider... providers) {

this.disconnectionRequestManager = disconnectionRequestManager;
this.webSocketConnectionEventManager = webSocketConnectionEventManager;
this.providers = providers;
}

Expand All @@ -63,7 +59,6 @@ public void onEvent(final RequestEvent event) {
.forEach(pair -> {
try {
displacedDevices.incrementAndGet();
webSocketConnectionEventManager.requestDisconnection(pair.first(), List.of(pair.second()));
disconnectionRequestManager.requestDisconnection(pair.first(), List.of(pair.second()));
} catch (final Exception e) {
logger.error("Could not displace device presence", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
Expand All @@ -36,7 +35,6 @@
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubClusterConnection;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.RedisClusterUtil;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import org.whispersystems.textsecuregcm.util.Util;
Expand Down Expand Up @@ -77,11 +75,6 @@ public class WebSocketConnectionEventManager extends RedisClusterPubSubAdapter<b
.build()
.toByteArray();

private static final byte[] DISCONNECT_REQUESTED_EVENT_BYTES = ClientEvent.newBuilder()
.setDisconnectRequested(DisconnectRequested.getDefaultInstance())
.build()
.toByteArray();

private static final Counter PUBLISH_CLIENT_CONNECTION_EVENT_ERROR_COUNTER =
Metrics.counter(MetricsUtil.name(WebSocketConnectionEventManager.class, "publishClientConnectionEventError"));

Expand Down Expand Up @@ -246,35 +239,6 @@ public boolean isLocallyPresent(final UUID accountUuid, final byte deviceId) {
return listenersByAccountAndDeviceIdentifier.containsKey(new AccountAndDeviceIdentifier(accountUuid, deviceId));
}

/**
* Broadcasts a request that all devices associated with the identified account and connected to any event manager
* instance close their network connections.
*
* @param accountIdentifier the account identifier for which to request disconnection
*
* @return a future that completes when the request has been sent
*/
public CompletableFuture<Void> requestDisconnection(final UUID accountIdentifier) {
return requestDisconnection(accountIdentifier, Device.ALL_POSSIBLE_DEVICE_IDS);
}

/**
* Broadcasts a request that the specified devices associated with the identified account and connected to any event
* manager instance close their network connections.
*
* @param accountIdentifier the account identifier for which to request disconnection
* @param deviceIds the IDs of the devices for which to request disconnection
*
* @return a future that completes when the request has been sent
*/
public CompletableFuture<Void> requestDisconnection(final UUID accountIdentifier, final Collection<Byte> deviceIds) {
return CompletableFuture.allOf(deviceIds.stream()
.map(deviceId -> clusterClient.withBinaryCluster(connection -> connection.async()
.spublish(getClientEventChannel(accountIdentifier, deviceId), DISCONNECT_REQUESTED_EVENT_BYTES))
.toCompletableFuture())
.toArray(CompletableFuture[]::new));
}

@Override
public void handleDisconnectionRequest(final UUID accountIdentifier, final Collection<Byte> deviceIds) {
deviceIds.stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.WebSocketConnectionEventManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubConnection;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
Expand Down Expand Up @@ -127,7 +126,6 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
private final SecureStorageClient secureStorageClient;
private final SecureValueRecovery2Client secureValueRecovery2Client;
private final DisconnectionRequestManager disconnectionRequestManager;
private final WebSocketConnectionEventManager webSocketConnectionEventManager;
private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager;
private final ClientPublicKeysManager clientPublicKeysManager;
private final Executor accountLockExecutor;
Expand Down Expand Up @@ -210,7 +208,6 @@ public AccountsManager(final Accounts accounts,
final SecureStorageClient secureStorageClient,
final SecureValueRecovery2Client secureValueRecovery2Client,
final DisconnectionRequestManager disconnectionRequestManager,
final WebSocketConnectionEventManager webSocketConnectionEventManager,
final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager,
final ClientPublicKeysManager clientPublicKeysManager,
final Executor accountLockExecutor,
Expand All @@ -229,7 +226,6 @@ public AccountsManager(final Accounts accounts,
this.secureStorageClient = secureStorageClient;
this.secureValueRecovery2Client = secureValueRecovery2Client;
this.disconnectionRequestManager = disconnectionRequestManager;
this.webSocketConnectionEventManager = webSocketConnectionEventManager;
this.registrationRecoveryPasswordsManager = requireNonNull(registrationRecoveryPasswordsManager);
this.clientPublicKeysManager = clientPublicKeysManager;
this.accountLockExecutor = accountLockExecutor;
Expand Down Expand Up @@ -336,7 +332,6 @@ public Account create(final String number,
keysManager.deleteSingleUsePreKeys(pni),
messagesManager.clear(aci),
profilesManager.deleteAll(aci))
.thenCompose(ignored -> webSocketConnectionEventManager.requestDisconnection(aci))
.thenCompose(ignored -> disconnectionRequestManager.requestDisconnection(aci))
.thenCompose(ignored -> accounts.reclaimAccount(e.getExistingAccount(), account, additionalWriteItems))
.thenCompose(ignored -> {
Expand Down Expand Up @@ -601,7 +596,6 @@ private CompletableFuture<Account> removeDevice(final UUID accountIdentifier, fi
})
.whenComplete((ignored, throwable) -> {
if (throwable == null) {
webSocketConnectionEventManager.requestDisconnection(accountIdentifier, List.of(deviceId));
disconnectionRequestManager.requestDisconnection(accountIdentifier, List.of(deviceId));
}
});
Expand Down Expand Up @@ -1249,10 +1243,7 @@ private CompletableFuture<Void> delete(final Account account) {
registrationRecoveryPasswordsManager.removeForNumber(account.getNumber()))
.thenCompose(ignored -> accounts.delete(account.getUuid(), additionalWriteItems))
.thenCompose(ignored -> redisDeleteAsync(account))
.thenRun(() -> {
webSocketConnectionEventManager.requestDisconnection(account.getUuid());
disconnectionRequestManager.requestDisconnection(account.getUuid());
});
.thenRun(() -> disconnectionRequestManager.requestDisconnection(account.getUuid()));
}

private String getAccountMapKey(String key) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ static CommandDependencies build(
new ClientPublicKeysManager(clientPublicKeys, accountLockManager, accountLockExecutor);
AccountsManager accountsManager = new AccountsManager(accounts, phoneNumberIdentifiers, cacheCluster,
pubsubClient, accountLockManager, keys, messagesManager, profilesManager,
secureStorageClient, secureValueRecovery2Client, disconnectionRequestManager, webSocketConnectionEventManager,
secureStorageClient, secureValueRecovery2Client, disconnectionRequestManager,
registrationRecoveryPasswordsManager, clientPublicKeysManager, accountLockExecutor, messagePollExecutor,
clock, configuration.getLinkDeviceSecretConfiguration().secret().value(), dynamicConfigurationManager);
RateLimiters rateLimiters = RateLimiters.createAndValidate(configuration.getLimitsConfiguration(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.push.WebSocketConnectionEventManager;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
Expand Down Expand Up @@ -96,20 +95,18 @@ class LinkedDeviceRefreshRequirementProviderTest {

private AccountsManager accountsManager;
private DisconnectionRequestManager disconnectionRequestManager;
private WebSocketConnectionEventManager webSocketConnectionEventManager;

private LinkedDeviceRefreshRequirementProvider provider;

@BeforeEach
void setup() {
accountsManager = mock(AccountsManager.class);
disconnectionRequestManager = mock(DisconnectionRequestManager.class);
webSocketConnectionEventManager = mock(WebSocketConnectionEventManager.class);

provider = new LinkedDeviceRefreshRequirementProvider(accountsManager);

final WebsocketRefreshRequestEventListener listener =
new WebsocketRefreshRequestEventListener(disconnectionRequestManager, webSocketConnectionEventManager, provider);
new WebsocketRefreshRequestEventListener(disconnectionRequestManager, provider);

when(applicationEventListener.onRequest(any())).thenReturn(listener);

Expand Down Expand Up @@ -141,10 +138,6 @@ void testDeviceAdded() {

assertEquals(initialDeviceCount + addedDeviceNames.size(), account.getDevices().size());

verify(webSocketConnectionEventManager).requestDisconnection(account.getUuid(), List.of((byte) 1));
verify(webSocketConnectionEventManager).requestDisconnection(account.getUuid(), List.of((byte) 2));
verify(webSocketConnectionEventManager).requestDisconnection(account.getUuid(), List.of((byte) 3));

verify(disconnectionRequestManager).requestDisconnection(account.getUuid(), List.of((byte) 1));
verify(disconnectionRequestManager).requestDisconnection(account.getUuid(), List.of((byte) 2));
verify(disconnectionRequestManager).requestDisconnection(account.getUuid(), List.of((byte) 3));
Expand Down Expand Up @@ -175,12 +168,10 @@ void testDeviceRemoved(final int removedDeviceCount) {

assertEquals(200, response.getStatus());

initialDeviceIds.forEach(deviceId -> {
verify(disconnectionRequestManager).requestDisconnection(account.getUuid(), List.of(deviceId));
verify(webSocketConnectionEventManager).requestDisconnection(account.getUuid(), List.of(deviceId));
});
initialDeviceIds.forEach(deviceId ->
verify(disconnectionRequestManager).requestDisconnection(account.getUuid(), List.of(deviceId)));

verifyNoMoreInteractions(webSocketConnectionEventManager);
verifyNoMoreInteractions(disconnectionRequestManager);
}

@Test
Expand Down
Loading

0 comments on commit 3fefb24

Please sign in to comment.