Skip to content

Commit

Permalink
Add dynamic configuration to optionally use shared MRM data
Browse files Browse the repository at this point in the history
  • Loading branch information
eager-signal committed Nov 7, 2024
1 parent 88a1f95 commit 5d9641a
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@

package org.whispersystems.textsecuregcm.configuration.dynamic;

public record DynamicMessagesConfiguration(boolean storeSharedMrmData, boolean mrmViewExperimentEnabled) {
public record DynamicMessagesConfiguration(boolean storeSharedMrmData, boolean fetchSharedMrmData,
boolean useSharedMrmData) {

public DynamicMessagesConfiguration() {
this(false, false);
this(false, false, false);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -957,7 +957,7 @@ private void sendCommonPayloadMessage(Account destinationAccount,
if (sharedMrmKey != null) {
messageBuilder.setSharedMrmKey(ByteString.copyFrom(sharedMrmKey));
}
// mrm views phase 1: always set content
// mrm views phase 2: always set content
messageBuilder.setContent(ByteString.copyFrom(payload));

messageSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build(), online);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@ public class MessagesCache {
private final Counter staleEphemeralMessagesCounter = Metrics.counter(
name(MessagesCache.class, "staleEphemeralMessages"));
private final Counter mrmContentRetrievedCounter = Metrics.counter(name(MessagesCache.class, "mrmViewRetrieved"));
private final Counter mrmRetrievalErrorCounter = Metrics.counter(name(MessagesCache.class, "mrmRetrievalError"));
private final Counter mrmPhaseTwoMissingContentCounter = Metrics.counter(
name(MessagesCache.class, "mrmPhaseTwoMissingContent"));
private final Counter sharedMrmDataKeyRemovedCounter = Metrics.counter(
name(MessagesCache.class, "sharedMrmKeyRemoved"));

Expand Down Expand Up @@ -349,19 +352,15 @@ Flux<MessageProtos.Envelope> getAllMessages(final UUID destinationUuid, final by
final Mono<MessageProtos.Envelope> messageMono;
if (message.hasSharedMrmKey()) {

final Mono<?> experimentMono;
if (isStaleEphemeralMessage(message, earliestAllowableEphemeralTimestamp)) {
// skip fetching content for message that will be discarded
experimentMono = Mono.empty();
messageMono = Mono.just(message.toBuilder().clearSharedMrmKey().build());
} else {
experimentMono = maybeRunMrmViewExperiment(message, destinationUuid, destinationDevice);
// mrm views phase 2: fetch shared MRM data -- internally depends on dynamic config that
// enables fetching and using it (the stored messages still always have `content` set upstream)
messageMono = getMessageWithSharedMrmData(message, destinationDevice);
}

// mrm views phase 1: messageMono for sharedMrmKey is always Mono.just(), because messages always have content
// To avoid races, wait for the experiment to run, but ignore any errors
messageMono = experimentMono
.onErrorComplete()
.then(Mono.just(message.toBuilder().clearSharedMrmKey().build()));
} else {
messageMono = Mono.just(message);
}
Expand All @@ -378,14 +377,23 @@ Flux<MessageProtos.Envelope> getAllMessages(final UUID destinationUuid, final by
}

/**
* Runs the fetch and compare logic for the MRM view experiment, if it is enabled.
* Returns the given message with its shared MRM data.
*
* @see DynamicMessagesConfiguration#mrmViewExperimentEnabled()
* @see DynamicMessagesConfiguration#fetchSharedMrmData()
* @see DynamicMessagesConfiguration#useSharedMrmData()
*/
private Mono<?> maybeRunMrmViewExperiment(final MessageProtos.Envelope mrmMessage, final UUID destinationUuid,
private Mono<MessageProtos.Envelope> getMessageWithSharedMrmData(final MessageProtos.Envelope mrmMessage,
final byte destinationDevice) {
if (dynamicConfigurationManager.getConfiguration().getMessagesConfiguration()
.mrmViewExperimentEnabled()) {

assert mrmMessage.hasSharedMrmKey();

// mrm views phase 2: messages have content
if (!mrmMessage.hasContent()) {
mrmPhaseTwoMissingContentCounter.increment();
}

if (dynamicConfigurationManager.getConfiguration().getMessagesConfiguration().fetchSharedMrmData()
|| !mrmMessage.hasContent()) {

final Experiment experiment = new Experiment(MRM_VIEWS_EXPERIMENT_NAME);

Expand All @@ -394,7 +402,7 @@ private Mono<?> maybeRunMrmViewExperiment(final MessageProtos.Envelope mrmMessag
// the message might be addressed to the account's PNI, so use the service ID from the envelope
ServiceIdentifier.valueOf(mrmMessage.getDestinationServiceId()), destinationDevice);

final Mono<MessageProtos.Envelope> mrmMessageMono = Mono.from(redisCluster.withBinaryClusterReactive(
final Mono<MessageProtos.Envelope> messageFromRedisMono = Mono.from(redisCluster.withBinaryClusterReactive(
conn -> conn.reactive().hmget(key, "data".getBytes(StandardCharsets.UTF_8), sharedMrmViewKey)
.collectList()
.publishOn(messageDeliveryScheduler)))
Expand All @@ -416,14 +424,25 @@ private Mono<?> maybeRunMrmViewExperiment(final MessageProtos.Envelope mrmMessag
sink.error(e);
}
})
.onErrorResume(throwable -> {
logger.warn("Failed to retrieve shared mrm data", throwable);
mrmRetrievalErrorCounter.increment();
return Mono.empty();
})
.share();

experiment.compareMonoResult(mrmMessage.toBuilder().clearSharedMrmKey().build(), mrmMessageMono);
if (mrmMessage.hasContent()) {
experiment.compareMonoResult(mrmMessage.toBuilder().clearSharedMrmKey().build(), messageFromRedisMono);
}

return mrmMessageMono;
} else {
return Mono.empty();
if (dynamicConfigurationManager.getConfiguration().getMessagesConfiguration().useSharedMrmData()
|| !mrmMessage.hasContent()) {
return messageFromRedisMono;
}
}

// if fetching or using shared data is disabled, fallback to just() with the existing message
return Mono.just(mrmMessage.toBuilder().clearSharedMrmKey().build());
}

/**
Expand Down Expand Up @@ -497,13 +516,9 @@ List<MessageProtos.Envelope> getMessagesToPersist(final UUID accountUuid, final
.concatMap(message -> {
final Mono<MessageProtos.Envelope> messageMono;
if (message.hasSharedMrmKey()) {
final Mono<?> experimentMono = maybeRunMrmViewExperiment(message, accountUuid, destinationDevice);

// mrm views phase 1: messageMono for sharedMrmKey is always Mono.just(), because messages always have content
// To avoid races, wait for the experiment to run, but ignore any errors
messageMono = experimentMono
.onErrorComplete()
.then(Mono.just(message.toBuilder().clearSharedMrmKey().build()));
// mrm views phase 2: fetch shared MRM data -- internally depends on dynamic config that
// enables fetching and using it (the stored messages still always have `content` set upstream)
messageMono = getMessageWithSharedMrmData(message, destinationDevice);
} else {
messageMono = Mono.just(message);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,8 @@ void setup() {

final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class);
when(dynamicConfiguration.getInboundMessageByteLimitConfiguration()).thenReturn(inboundMessageByteLimitConfiguration);
when(dynamicConfiguration.getMessagesConfiguration()).thenReturn(new DynamicMessagesConfiguration(true, true));
when(dynamicConfiguration.getMessagesConfiguration()).thenReturn(
new DynamicMessagesConfiguration(true, true, true));

when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.junitpioneer.jupiter.cartesian.CartesianTest;
import org.reactivestreams.Publisher;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.signal.libsignal.protocol.ServiceId;
Expand Down Expand Up @@ -96,15 +97,17 @@ class WithRealCluster {
private MessagesCache messagesCache;

private DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
private DynamicConfiguration dynamicConfiguration;

private static final UUID DESTINATION_UUID = UUID.randomUUID();

private static final byte DESTINATION_DEVICE_ID = 7;

@BeforeEach
void setUp() throws Exception {
final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class);
when(dynamicConfiguration.getMessagesConfiguration()).thenReturn(new DynamicMessagesConfiguration(true, true));
dynamicConfiguration = mock(DynamicConfiguration.class);
when(dynamicConfiguration.getMessagesConfiguration()).thenReturn(
new DynamicMessagesConfiguration(true, true, true));
dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);

Expand Down Expand Up @@ -399,9 +402,13 @@ public void testGetQueuesToPersist(final boolean sealedSender) {
assertEquals(DESTINATION_DEVICE_ID, MessagesCache.getDeviceIdFromQueueName(queues.getFirst()));
}

@ParameterizedTest
@ValueSource(booleans = {true, false})
void testMultiRecipientMessage(final boolean sharedMrmKeyPresent) throws Exception {
@CartesianTest
void testMultiRecipientMessage(@CartesianTest.Values(booleans = {true, false}) final boolean sharedMrmKeyPresent,
@CartesianTest.Values(booleans = {true, false}) final boolean useSharedMrmData) {

when(dynamicConfiguration.getMessagesConfiguration())
.thenReturn(new DynamicMessagesConfiguration(true, true, useSharedMrmData));

final ServiceIdentifier destinationServiceId = new AciServiceIdentifier(UUID.randomUUID());
final byte deviceId = 1;

Expand All @@ -419,7 +426,7 @@ void testMultiRecipientMessage(final boolean sharedMrmKeyPresent) throws Excepti
.toBuilder()
// clear some things added by the helper
.clearServerGuid()
// mrm views phase 1: messages have content
// mrm views phase 2: messages have content
.setContent(
ByteString.copyFrom(mrm.messageForRecipient(mrm.getRecipients().get(destinationServiceId.toLibsignal()))))
.setSharedMrmKey(ByteString.copyFrom(sharedMrmDataKey))
Expand All @@ -430,10 +437,70 @@ void testMultiRecipientMessage(final boolean sharedMrmKeyPresent) throws Excepti
.withBinaryCluster(conn -> conn.sync().exists(sharedMrmDataKey)));

final List<MessageProtos.Envelope> messages = get(destinationServiceId.uuid(), deviceId, 1);
if (useSharedMrmData && !sharedMrmKeyPresent) {
assertTrue(messages.isEmpty());
} else {

assertEquals(1, messages.size());
assertEquals(guid, UUID.fromString(messages.getFirst().getServerGuid()));
assertFalse(messages.getFirst().hasSharedMrmKey());
final SealedSenderMultiRecipientMessage.Recipient recipient = mrm.getRecipients()
.get(destinationServiceId.toLibsignal());
assertArrayEquals(mrm.messageForRecipient(recipient), messages.getFirst().getContent().toByteArray());
}

final Optional<RemovedMessage> removedMessage = messagesCache.remove(destinationServiceId.uuid(), deviceId, guid)
.join();

assertTrue(removedMessage.isPresent());
assertEquals(guid, UUID.fromString(removedMessage.get().serverGuid().toString()));
assertTrue(get(destinationServiceId.uuid(), deviceId, 1).isEmpty());

// updating the shared MRM data is purely async, so we just wait for it
assertTimeoutPreemptively(Duration.ofSeconds(1), () -> {
boolean exists;
do {
exists = 1 == REDIS_CLUSTER_EXTENSION.getRedisCluster()
.withBinaryCluster(conn -> conn.sync().exists(sharedMrmDataKey));
} while (exists);
}, "Shared MRM data should be deleted asynchronously");
}

@CartesianTest
void testMultiRecipientMessagePhase2MissingContentSafeguard(
@CartesianTest.Values(booleans = {true, false}) final boolean useSharedMrmData,
@CartesianTest.Values(booleans = {true, false}) final boolean fetchSharedMrmData) {

when(dynamicConfiguration.getMessagesConfiguration())
.thenReturn(new DynamicMessagesConfiguration(true, fetchSharedMrmData, useSharedMrmData));

final ServiceIdentifier destinationServiceId = new AciServiceIdentifier(UUID.randomUUID());
final byte deviceId = 1;

final SealedSenderMultiRecipientMessage mrm = generateRandomMrmMessage(destinationServiceId, deviceId);

final byte[] sharedMrmDataKey = messagesCache.insertSharedMultiRecipientMessagePayload(mrm);

final UUID guid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(guid, destinationServiceId, true)
.toBuilder()
// clear some things added by the helper
.clearServerGuid()
// mrm views phase 2: there is a safeguard against missing content, even if the dynamic configuration
// is to not fetch or use shared MRM data
.clearContent()
.setSharedMrmKey(ByteString.copyFrom(sharedMrmDataKey))
.build();
messagesCache.insert(guid, destinationServiceId.uuid(), deviceId, message);

assertEquals(1, (long) REDIS_CLUSTER_EXTENSION.getRedisCluster()
.withBinaryCluster(conn -> conn.sync().exists(sharedMrmDataKey)));

final List<MessageProtos.Envelope> messages = get(destinationServiceId.uuid(), deviceId, 1);

assertEquals(1, messages.size());
assertEquals(guid, UUID.fromString(messages.getFirst().getServerGuid()));
assertFalse(messages.getFirst().hasSharedMrmKey());

final SealedSenderMultiRecipientMessage.Recipient recipient = mrm.getRecipients()
.get(destinationServiceId.toLibsignal());
assertArrayEquals(mrm.messageForRecipient(recipient), messages.getFirst().getContent().toByteArray());
Expand All @@ -455,19 +522,24 @@ void testMultiRecipientMessage(final boolean sharedMrmKeyPresent) throws Excepti
}, "Shared MRM data should be deleted asynchronously");
}

@ParameterizedTest
@ValueSource(booleans = {true, false})
void testGetMessagesToPersist(final boolean sharedMrmKeyPresent) {
@CartesianTest
void testGetMessagesToPersist(@CartesianTest.Values(booleans = {true, false}) final boolean sharedMrmKeyPresent,
@CartesianTest.Values(booleans = {true, false}) final boolean useSharedMrmData) {

when(dynamicConfiguration.getMessagesConfiguration())
.thenReturn(new DynamicMessagesConfiguration(true, true, useSharedMrmData));

final UUID destinationUuid = UUID.randomUUID();
final ServiceIdentifier destinationServiceId = new AciServiceIdentifier(destinationUuid);
final byte deviceId = 1;

final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(destinationUuid, true);
final MessageProtos.Envelope message = generateRandomMessage(messageGuid,
new AciServiceIdentifier(destinationUuid), true);

messagesCache.insert(messageGuid, destinationUuid, deviceId, message);

final SealedSenderMultiRecipientMessage mrm = generateRandomMrmMessage(
new AciServiceIdentifier(destinationUuid), deviceId);
final SealedSenderMultiRecipientMessage mrm = generateRandomMrmMessage(destinationServiceId, deviceId);

final byte[] sharedMrmDataKey;
if (sharedMrmKeyPresent) {
Expand All @@ -477,31 +549,35 @@ void testGetMessagesToPersist(final boolean sharedMrmKeyPresent) {
}

final UUID mrmMessageGuid = UUID.randomUUID();
final MessageProtos.Envelope mrmMessage = generateRandomMessage(mrmMessageGuid, true)
final MessageProtos.Envelope mrmMessage = generateRandomMessage(mrmMessageGuid, destinationServiceId, true)
.toBuilder()
// clear some things added by the helper
.clearServerGuid()
// mrm views phase 1: messages have content
// mrm views phase 2: messages have content
.setContent(
ByteString.copyFrom(mrm.messageForRecipient(mrm.getRecipients().get(new ServiceId.Aci(destinationUuid)))))
.setSharedMrmKey(ByteString.copyFrom(sharedMrmDataKey))
.build();
messagesCache.insert(mrmMessageGuid, destinationUuid, deviceId, mrmMessage);

final List<MessageProtos.Envelope> messages = get(destinationUuid, deviceId, 100);
final List<MessageProtos.Envelope> messages = messagesCache.getMessagesToPersist(destinationUuid, deviceId, 100);

assertEquals(2, messages.size());
if (useSharedMrmData && !sharedMrmKeyPresent) {
assertEquals(1, messages.size());
} else {
assertEquals(2, messages.size());

assertEquals(mrmMessage.toBuilder().
clearSharedMrmKey().
setServerGuid(mrmMessageGuid.toString())
.build(),
messages.getLast());
}

assertEquals(message.toBuilder()
.setServerGuid(messageGuid.toString())
.build(),
messages.getFirst());

assertEquals(mrmMessage.toBuilder().
clearSharedMrmKey().
setServerGuid(mrmMessageGuid.toString())
.build(),
messages.getLast());
}

private List<MessageProtos.Envelope> get(final UUID destinationUuid, final byte destinationDeviceId,
Expand Down

0 comments on commit 5d9641a

Please sign in to comment.