Skip to content

Commit

Permalink
Send push notifications if we receive a "new message" notification, b…
Browse files Browse the repository at this point in the history
…ut no listener is present
  • Loading branch information
jon-signal committed Nov 12, 2024
1 parent 3fefb24 commit 2f890f7
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,6 @@ public void run(WhisperServerConfiguration config, Environment environment) thro
SecureStorageClient secureStorageClient = new SecureStorageClient(storageCredentialsGenerator,
storageServiceExecutor, storageServiceRetryExecutor, config.getSecureStorageServiceConfiguration());
DisconnectionRequestManager disconnectionRequestManager = new DisconnectionRequestManager(pubsubClient, disconnectionRequestListenerExecutor);
WebSocketConnectionEventManager webSocketConnectionEventManager = new WebSocketConnectionEventManager(messagesCluster, clientEventExecutor);
ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster);
MessagesCache messagesCache = new MessagesCache(messagesCluster, messageDeliveryScheduler,
messageDeletionAsyncExecutor, clock, dynamicConfigurationManager);
Expand Down Expand Up @@ -629,6 +628,8 @@ public void run(WhisperServerConfiguration config, Environment environment) thro
apnSender, fcmSender, accountsManager, 0, 0);
PushNotificationManager pushNotificationManager =
new PushNotificationManager(accountsManager, apnSender, fcmSender, pushNotificationScheduler);
WebSocketConnectionEventManager webSocketConnectionEventManager =
new WebSocketConnectionEventManager(accountsManager, pushNotificationManager, messagesCluster, clientEventExecutor);
RateLimiters rateLimiters = RateLimiters.createAndValidate(config.getLimitsConfiguration(),
dynamicConfigurationManager, rateLimitersCluster);
ProvisioningManager provisioningManager = new ProvisioningManager(pubsubClient);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubClusterConnection;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.util.RedisClusterUtil;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import org.whispersystems.textsecuregcm.util.Util;
Expand All @@ -58,6 +59,8 @@
public class WebSocketConnectionEventManager extends RedisClusterPubSubAdapter<byte[], byte[]> implements Managed,
DisconnectionRequestListener {

private final AccountsManager accountsManager;
private final PushNotificationManager pushNotificationManager;
private final FaultTolerantRedisClusterClient clusterClient;
private final Executor listenerEventExecutor;

Expand All @@ -81,8 +84,11 @@ public class WebSocketConnectionEventManager extends RedisClusterPubSubAdapter<b
private static final Counter UNSUBSCRIBE_ERROR_COUNTER =
Metrics.counter(MetricsUtil.name(WebSocketConnectionEventManager.class, "unsubscribeError"));

private static final Counter MESSAGE_WITHOUT_LISTENER_COUNTER =
Metrics.counter(MetricsUtil.name(WebSocketConnectionEventManager.class, "messageWithoutListener"));
private static final Counter PUB_SUB_EVENT_WITHOUT_LISTENER_COUNTER =
Metrics.counter(MetricsUtil.name(WebSocketConnectionEventManager.class, "pubSubEventWithoutListener"));

private static final Counter MESSAGE_AVAILABLE_WITHOUT_LISTENER_COUNTER =
Metrics.counter(MetricsUtil.name(WebSocketConnectionEventManager.class, "messageAvailableWithoutListener"));

private static final String LISTENER_GAUGE_NAME =
MetricsUtil.name(WebSocketConnectionEventManager.class, "listeners");
Expand All @@ -93,8 +99,13 @@ public class WebSocketConnectionEventManager extends RedisClusterPubSubAdapter<b
record AccountAndDeviceIdentifier(UUID accountIdentifier, byte deviceId) {
}

public WebSocketConnectionEventManager(final FaultTolerantRedisClusterClient clusterClient,
final Executor listenerEventExecutor) {
public WebSocketConnectionEventManager(final AccountsManager accountsManager,
final PushNotificationManager pushNotificationManager,
final FaultTolerantRedisClusterClient clusterClient,
final Executor listenerEventExecutor) {

this.accountsManager = accountsManager;
this.pushNotificationManager = pushNotificationManager;

this.clusterClient = clusterClient;
this.listenerEventExecutor = listenerEventExecutor;
Expand Down Expand Up @@ -331,8 +342,29 @@ public void smessage(final RedisClusterNode node, final byte[] shardChannel, fin
default -> logger.warn("Unexpected client event type: {}", clientEvent.getClass());
}
} else {
MESSAGE_WITHOUT_LISTENER_COUNTER.increment();
PUB_SUB_EVENT_WITHOUT_LISTENER_COUNTER.increment();

listenerEventExecutor.execute(() -> unsubscribeIfMissingListener(accountAndDeviceIdentifier));

if (clientEvent.getEventCase() == ClientEvent.EventCase.NEW_MESSAGE_AVAILABLE) {
MESSAGE_AVAILABLE_WITHOUT_LISTENER_COUNTER.increment();

// If we have an active subscription but no registered listener, it's likely that the publisher of this event
// believes that the receiving client was present when it really wasn't. Send a push notification as a
// just-in-case measure.
accountsManager.getByAccountIdentifierAsync(accountAndDeviceIdentifier.accountIdentifier())
.thenAccept(maybeAccount -> maybeAccount.ifPresent(account -> {
try {
pushNotificationManager.sendNewMessageNotification(account, accountAndDeviceIdentifier.deviceId(), true);
} catch (final NotPushRegisteredException ignored) {
}
}))
.whenComplete((ignored, throwable) -> {
if (throwable != null) {
logger.warn("Failed to send follow-up notification to {}:{}", accountAndDeviceIdentifier.accountIdentifier(), accountAndDeviceIdentifier.deviceId());
}
});
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,6 @@ static CommandDependencies build(
SecureStorageClient secureStorageClient = new SecureStorageClient(storageCredentialsGenerator,
storageServiceExecutor, storageServiceRetryExecutor, configuration.getSecureStorageServiceConfiguration());
DisconnectionRequestManager disconnectionRequestManager = new DisconnectionRequestManager(pubsubClient, disconnectionRequestListenerExecutor);
WebSocketConnectionEventManager webSocketConnectionEventManager = new WebSocketConnectionEventManager(messagesCluster, clientEventExecutor);
MessagesCache messagesCache = new MessagesCache(messagesCluster,
messageDeliveryScheduler, messageDeletionExecutor, Clock.systemUTC(), dynamicConfigurationManager);
ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster);
Expand Down Expand Up @@ -264,6 +263,9 @@ static CommandDependencies build(
configuration.getDynamoDbTables().getPushNotificationExperimentSamples().getTableName(),
Clock.systemUTC());

WebSocketConnectionEventManager webSocketConnectionEventManager =
new WebSocketConnectionEventManager(accountsManager, pushNotificationManager, messagesCluster, clientEventExecutor);

environment.lifecycle().manage(apnSender);
environment.lifecycle().manage(disconnectionRequestManager);
environment.lifecycle().manage(webSocketConnectionEventManager);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;

import io.lettuce.core.cluster.SlotHash;
Expand All @@ -19,7 +20,9 @@
import io.lettuce.core.cluster.pubsub.api.async.RedisClusterPubSubAsyncCommands;
import io.lettuce.core.cluster.pubsub.api.sync.RedisClusterPubSubCommands;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean;
Expand All @@ -35,6 +38,8 @@
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.tests.util.MockRedisFuture;
import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper;
Expand All @@ -45,7 +50,7 @@ class WebSocketConnectionEventManagerTest {
private WebSocketConnectionEventManager localEventManager;
private WebSocketConnectionEventManager remoteEventManager;

private static ExecutorService clientEventExecutor;
private static ExecutorService webSocketConnectionEventExecutor;

@RegisterExtension
static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
Expand All @@ -67,13 +72,20 @@ public void handleConnectionDisplaced(final boolean connectedElsewhere) {

@BeforeAll
static void setUpBeforeAll() {
clientEventExecutor = Executors.newVirtualThreadPerTaskExecutor();
webSocketConnectionEventExecutor = Executors.newVirtualThreadPerTaskExecutor();
}

@BeforeEach
void setUp() {
localEventManager = new WebSocketConnectionEventManager(REDIS_CLUSTER_EXTENSION.getRedisCluster(), clientEventExecutor);
remoteEventManager = new WebSocketConnectionEventManager(REDIS_CLUSTER_EXTENSION.getRedisCluster(), clientEventExecutor);
localEventManager = new WebSocketConnectionEventManager(mock(AccountsManager.class),
mock(PushNotificationManager.class),
REDIS_CLUSTER_EXTENSION.getRedisCluster(),
webSocketConnectionEventExecutor);

remoteEventManager = new WebSocketConnectionEventManager(mock(AccountsManager.class),
mock(PushNotificationManager.class),
REDIS_CLUSTER_EXTENSION.getRedisCluster(),
webSocketConnectionEventExecutor);

localEventManager.start();
remoteEventManager.start();
Expand All @@ -87,7 +99,7 @@ void tearDown() {

@AfterAll
static void tearDownAfterAll() {
clientEventExecutor.shutdown();
webSocketConnectionEventExecutor.shutdown();
}

@ParameterizedTest
Expand Down Expand Up @@ -226,7 +238,11 @@ void resubscribe() {
.binaryPubSubAsyncCommands(pubSubAsyncCommands)
.build();

final WebSocketConnectionEventManager eventManager = new WebSocketConnectionEventManager(clusterClient, Runnable::run);
final WebSocketConnectionEventManager eventManager = new WebSocketConnectionEventManager(
mock(AccountsManager.class),
mock(PushNotificationManager.class),
clusterClient,
Runnable::run);

eventManager.start();

Expand Down Expand Up @@ -279,7 +295,7 @@ void resubscribe() {
}

@Test
void smessageWithoutListener() {
void unsubscribeIfMissingListener() {
@SuppressWarnings("unchecked") final RedisClusterPubSubAsyncCommands<byte[], byte[]> pubSubAsyncCommands =
mock(RedisClusterPubSubAsyncCommands.class);

Expand All @@ -289,7 +305,11 @@ void smessageWithoutListener() {
.binaryPubSubAsyncCommands(pubSubAsyncCommands)
.build();

final WebSocketConnectionEventManager eventManager = new WebSocketConnectionEventManager(clusterClient, Runnable::run);
final WebSocketConnectionEventManager eventManager = new WebSocketConnectionEventManager(
mock(AccountsManager.class),
mock(PushNotificationManager.class),
clusterClient,
Runnable::run);

eventManager.start();

Expand All @@ -315,4 +335,59 @@ void smessageWithoutListener() {
verify(pubSubAsyncCommands)
.sunsubscribe(WebSocketConnectionEventManager.getClientEventChannel(noListenerAccountIdentifier, noListenerDeviceId));
}

@Test
void newMessageNotificationWithoutListener() throws NotPushRegisteredException {
final UUID listenerAccountIdentifier = UUID.randomUUID();
final byte listenerDeviceId = Device.PRIMARY_ID;

final UUID noListenerAccountIdentifier = UUID.randomUUID();
final byte noListenerDeviceId = listenerDeviceId + 1;

final Account noListenerAccount = mock(Account.class);

final AccountsManager accountsManager = mock(AccountsManager.class);

when(accountsManager.getByAccountIdentifierAsync(noListenerAccountIdentifier))
.thenReturn(CompletableFuture.completedFuture(Optional.of(noListenerAccount)));

final PushNotificationManager pushNotificationManager = mock(PushNotificationManager.class);

@SuppressWarnings("unchecked") final RedisClusterPubSubAsyncCommands<byte[], byte[]> pubSubAsyncCommands =
mock(RedisClusterPubSubAsyncCommands.class);

when(pubSubAsyncCommands.ssubscribe(any())).thenReturn(MockRedisFuture.completedFuture(null));

final FaultTolerantRedisClusterClient clusterClient = RedisClusterHelper.builder()
.binaryPubSubAsyncCommands(pubSubAsyncCommands)
.build();

final WebSocketConnectionEventManager eventManager = new WebSocketConnectionEventManager(
accountsManager,
pushNotificationManager,
clusterClient,
Runnable::run);

eventManager.start();

eventManager.handleClientConnected(listenerAccountIdentifier, listenerDeviceId, new WebSocketConnectionEventAdapter())
.toCompletableFuture()
.join();

final byte[] newMessagePayload = ClientEvent.newBuilder()
.setNewMessageAvailable(NewMessageAvailableEvent.getDefaultInstance())
.build()
.toByteArray();

eventManager.smessage(mock(RedisClusterNode.class),
WebSocketConnectionEventManager.getClientEventChannel(listenerAccountIdentifier, listenerDeviceId),
newMessagePayload);

eventManager.smessage(mock(RedisClusterNode.class),
WebSocketConnectionEventManager.getClientEventChannel(noListenerAccountIdentifier, noListenerDeviceId),
newMessagePayload);

verify(pushNotificationManager).sendNewMessageNotification(noListenerAccount, noListenerDeviceId, true);
verifyNoMoreInteractions(pushNotificationManager);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.WebSocketConnectionEventListener;
import org.whispersystems.textsecuregcm.push.WebSocketConnectionEventManager;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
Expand All @@ -52,7 +53,7 @@ class MessagePersisterIntegrationTest {

private Scheduler messageDeliveryScheduler;
private ExecutorService messageDeletionExecutorService;
private ExecutorService clientEventExecutorService;
private ExecutorService websocketConnectionEventExecutor;
private MessagesCache messagesCache;
private MessagesManager messagesManager;
private WebSocketConnectionEventManager webSocketConnectionEventManager;
Expand Down Expand Up @@ -84,8 +85,12 @@ void setUp() throws Exception {
messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, mock(ReportMessageManager.class),
messageDeletionExecutorService);

clientEventExecutorService = Executors.newVirtualThreadPerTaskExecutor();
webSocketConnectionEventManager = new WebSocketConnectionEventManager(REDIS_CLUSTER_EXTENSION.getRedisCluster(), clientEventExecutorService);
websocketConnectionEventExecutor = Executors.newVirtualThreadPerTaskExecutor();
webSocketConnectionEventManager = new WebSocketConnectionEventManager(mock(AccountsManager.class),
mock(PushNotificationManager.class),
REDIS_CLUSTER_EXTENSION.getRedisCluster(),
websocketConnectionEventExecutor);

webSocketConnectionEventManager.start();

messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager,
Expand All @@ -108,8 +113,8 @@ void tearDown() throws Exception {
messageDeletionExecutorService.shutdown();
messageDeletionExecutorService.awaitTermination(15, TimeUnit.SECONDS);

clientEventExecutorService.shutdown();
clientEventExecutorService.awaitTermination(15, TimeUnit.SECONDS);
websocketConnectionEventExecutor.shutdown();
websocketConnectionEventExecutor.awaitTermination(15, TimeUnit.SECONDS);

messageDeliveryScheduler.dispose();

Expand Down

0 comments on commit 2f890f7

Please sign in to comment.