Skip to content

Commit

Permalink
Revert "fix: reconfigure MQTT client better when things change (#1197)"
Browse files Browse the repository at this point in the history
This reverts commit cddcc5c.
  • Loading branch information
junfuchen99 committed May 13, 2022
1 parent 89bac93 commit 00eaed4
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 123 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@

import java.io.IOException;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

Expand All @@ -57,7 +56,6 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.mockito.internal.verification.VerificationModeFactory.atLeast;
Expand Down Expand Up @@ -91,7 +89,6 @@ void before() {
// handlers here
TestFeatureParameters.clearHandlerCallbacks();
TestFeatureParameters.internalEnableTestingFeatureParameters(DEFAULT_HANDLER);
lenient().when(mqttClient.publish(any())).thenReturn(CompletableFuture.completedFuture(0));
}

@AfterEach
Expand Down
148 changes: 65 additions & 83 deletions src/main/java/com/aws/greengrass/mqttclient/MqttClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import com.aws.greengrass.mqttclient.spool.SpoolerStoreException;
import com.aws.greengrass.security.SecurityService;
import com.aws.greengrass.security.exceptions.MqttConnectionProviderException;
import com.aws.greengrass.util.BatchedSubscriber;
import com.aws.greengrass.util.Coerce;
import com.aws.greengrass.util.LockScope;
import com.aws.greengrass.util.ProxyUtils;
Expand All @@ -39,7 +38,6 @@
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
Expand Down Expand Up @@ -118,7 +116,6 @@ public class MqttClient implements Closeable {
private final AtomicInteger connectionRoundRobin = new AtomicInteger(0);
@Getter
private final AtomicBoolean mqttOnline = new AtomicBoolean(false);
private final Object httpProxyLock = new Object();

private final EventLoopGroup eventLoopGroup;
private final HostResolver hostResolver;
Expand All @@ -127,9 +124,8 @@ public class MqttClient implements Closeable {
private final Spool spool;
private final ExecutorService executorService;

private TlsContextOptions proxyTlsOptions;
private ClientTlsContext proxyTlsContext;
private String rootCaPath;
private final TlsContextOptions proxyTlsOptions;
private final ClientTlsContext proxyTlsContext;

private ScheduledExecutorService ses;
private final AtomicReference<Future<?>> spoolingFuture = new AtomicReference<>();
Expand Down Expand Up @@ -193,12 +189,9 @@ public MqttClient(DeviceConfiguration deviceConfiguration, ScheduledExecutorServ
Coerce.toInt(mqttTopics.findOrDefault(DEFAULT_MQTT_PING_TIMEOUT, MQTT_PING_TIMEOUT_KEY)))
.withSocketOptions(new SocketOptions()).withTimeoutMs(Coerce.toInt(
mqttTopics.findOrDefault(DEFAULT_MQTT_SOCKET_TIMEOUT, MQTT_SOCKET_TIMEOUT_KEY)));
synchronized (httpProxyLock) {
HttpProxyOptions httpProxyOptions =
ProxyUtils.getHttpProxyOptions(deviceConfiguration, proxyTlsContext);
if (httpProxyOptions != null) {
builder.withHttpProxyOptions(httpProxyOptions);
}
HttpProxyOptions httpProxyOptions = ProxyUtils.getHttpProxyOptions(deviceConfiguration, proxyTlsContext);
if (httpProxyOptions != null) {
builder.withHttpProxyOptions(httpProxyOptions);
}
return builder;
};
Expand All @@ -210,8 +203,8 @@ protected MqttClient(DeviceConfiguration deviceConfiguration,
this.deviceConfiguration = deviceConfiguration;
this.executorService = executorService;
this.ses = ses;
rootCaPath = Coerce.toString(deviceConfiguration.getRootCAFilePath());
this.proxyTlsOptions = getTlsContextOptions(rootCaPath);

this.proxyTlsOptions = getTlsContextOptions(deviceConfiguration);
this.proxyTlsContext = new ClientTlsContext(proxyTlsOptions);

mqttTopics = this.deviceConfiguration.getMQTTNamespace();
Expand All @@ -232,76 +225,67 @@ protected MqttClient(DeviceConfiguration deviceConfiguration,
deviceConfiguration.getSpoolerNamespace();
deviceConfiguration.getAWSRegion();

// Skip the reconnect logic below if device is running offline
if (!deviceConfiguration.isDeviceConfiguredToTalkToCloud()) {
return;
}

// If anything in the device configuration changes, then we will need to reconnect to the cloud
// using the new settings. We do this by calling reconnect() on all of our connections
this.deviceConfiguration.onAnyChange(new BatchedSubscriber((what, node) -> {
// Skip events that don't change anything
if (WhatHappened.timestampUpdated.equals(what) || WhatHappened.interiorAdded.equals(what) || node == null) {
return true;
}

// List of configuration nodes that we need to reconfigure for if they change
if (!(node.childOf(DEVICE_MQTT_NAMESPACE) || node.childOf(DEVICE_PARAM_THING_NAME) || node.childOf(
DEVICE_PARAM_IOT_DATA_ENDPOINT) || node.childOf(DEVICE_PARAM_PRIVATE_KEY_PATH) || node.childOf(
DEVICE_PARAM_CERTIFICATE_FILE_PATH) || node.childOf(DEVICE_PARAM_ROOT_CA_PATH) || node.childOf(
DEVICE_PARAM_AWS_REGION))) {
return true;
this.deviceConfiguration.onAnyChange((what, node) -> {
if (connections.isEmpty()) {
return;
}
if (WhatHappened.childChanged.equals(what) && node != null) {
// List of configuration nodes that we need to reconfigure for if they change
if (!(node.childOf(DEVICE_MQTT_NAMESPACE) || node.childOf(DEVICE_PARAM_THING_NAME) || node
.childOf(DEVICE_PARAM_IOT_DATA_ENDPOINT) || node.childOf(DEVICE_PARAM_PRIVATE_KEY_PATH) || node
.childOf(DEVICE_PARAM_CERTIFICATE_FILE_PATH) || node.childOf(DEVICE_PARAM_ROOT_CA_PATH) || node
.childOf(DEVICE_PARAM_AWS_REGION))) {
return;
}

// Only reconnect when the region changed if the proxy exists
if (node.childOf(DEVICE_PARAM_AWS_REGION) && !ProxyUtils.isProxyConfigured(deviceConfiguration)) {
return true;
}
if (node.childOf(DEVICE_MQTT_NAMESPACE)) {
validateAndSetMqttPublishConfiguration();
}

logger.atDebug().kv("modifiedNode", node.getFullName()).kv("changeType", what)
.log("Reconfiguring MQTT clients");
return false;
}, (what) -> {
validateAndSetMqttPublishConfiguration();

// Reconnect in separate thread to not block publish thread
// Schedule the reconnection for slightly in the future to de-dupe multiple changes
Future<?> oldFuture = reconfigureFuture.getAndSet(ses.schedule(() -> {
// If the rootCa path changed, then we need to update the TLS options
String newRootCaPath = Coerce.toString(deviceConfiguration.getRootCAFilePath());
synchronized (httpProxyLock) {
if (!Objects.equals(rootCaPath, newRootCaPath)) {
if (proxyTlsOptions != null) {
proxyTlsOptions.close();
}
if (proxyTlsContext != null) {
proxyTlsContext.close();
}
rootCaPath = newRootCaPath;
proxyTlsOptions = getTlsContextOptions(rootCaPath);
proxyTlsContext = new ClientTlsContext(proxyTlsOptions);
}
// Only reconnect when the region changed if the proxy exists
if (node.childOf(DEVICE_PARAM_AWS_REGION)
&& !ProxyUtils.isProxyConfigured(deviceConfiguration)) {
return;
}

// Continually try to reconnect until all the connections are reconnected
Set<AwsIotMqttClient> brokenConnections = new CopyOnWriteArraySet<>(connections);
do {
for (AwsIotMqttClient connection : brokenConnections) {
if (Thread.currentThread().isInterrupted()) {
return;
}
logger.atDebug().kv("modifiedNode", node.getFullName()).kv("changeType", what)
.log("Reconfiguring MQTT clients");

// Reconnect in separate thread to not block publish thread
// Schedule the reconnection for slightly in the future to de-dupe multiple changes
Future<?> oldFuture = reconfigureFuture.getAndSet(ses.schedule(() -> {
// Continually try to reconnect until all the connections are reconnected
Set<AwsIotMqttClient> brokenConnections = new CopyOnWriteArraySet<>(connections);
do {
for (AwsIotMqttClient connection : brokenConnections) {
if (Thread.currentThread().isInterrupted()) {
return;
}

try {
connection.reconnect();
brokenConnections.remove(connection);
} catch (InterruptedException | ExecutionException | TimeoutException e) {
logger.atError().setCause(e).kv(CLIENT_ID_KEY, connection.getClientId())
.log("Error while reconnecting MQTT client");
try {
connection.reconnect();
brokenConnections.remove(connection);
} catch (InterruptedException | ExecutionException | TimeoutException e) {
logger.atError().setCause(e).kv(CLIENT_ID_KEY, connection.getClientId())
.log("Error while reconnecting MQTT client");
}
}
}
} while (!brokenConnections.isEmpty());
}, 1, TimeUnit.SECONDS));
} while (!brokenConnections.isEmpty());
}, 1, TimeUnit.SECONDS));

// If a reconfiguration task already existed, then kill it and create a new one
if (oldFuture != null) {
oldFuture.cancel(true);
// If a reconfiguration task already existed, then kill it and create a new one
if (oldFuture != null) {
oldFuture.cancel(true);
}
}
}));
});
}

/**
Expand All @@ -325,14 +309,16 @@ public MqttClient(DeviceConfiguration deviceConfiguration, Spool spool, boolean
this.builderProvider = builderProvider;
this.spool = spool;
this.mqttOnline.set(mqttOnline);
this.builderProvider = builderProvider;
this.executorService = executorService;
rootCaPath = Coerce.toString(deviceConfiguration.getRootCAFilePath());
this.proxyTlsOptions = getTlsContextOptions(rootCaPath);
this.proxyTlsContext = new ClientTlsContext(proxyTlsOptions);
validateAndSetMqttPublishConfiguration();

this.proxyTlsOptions = getTlsContextOptions(deviceConfiguration);
this.proxyTlsContext = new ClientTlsContext(proxyTlsOptions);
}

private TlsContextOptions getTlsContextOptions(String rootCaPath) {
private TlsContextOptions getTlsContextOptions(DeviceConfiguration deviceConfiguration) {
String rootCaPath = Coerce.toString(deviceConfiguration.getRootCAFilePath());
return Utils.isNotEmpty(rootCaPath)
? TlsContextOptions.createDefaultClient().withCertificateAuthorityFromPath(null, rootCaPath)
: TlsContextOptions.createDefaultClient();
Expand Down Expand Up @@ -793,12 +779,8 @@ public synchronized void close() {
}

connections.forEach(AwsIotMqttClient::close);
if (proxyTlsOptions != null) {
proxyTlsOptions.close();
}
if (proxyTlsContext != null) {
proxyTlsContext.close();
}
proxyTlsOptions.close();
proxyTlsContext.close();
clientBootstrap.close();
hostResolver.close();
eventLoopGroup.close();
Expand Down
28 changes: 11 additions & 17 deletions src/main/java/com/aws/greengrass/telemetry/TelemetryAgent.java
Original file line number Diff line number Diff line change
Expand Up @@ -246,25 +246,19 @@ void aggregatePeriodicMetrics() {
/**
* Helper for metrics uploader. Also used in tests.
*/
@SuppressWarnings("PMD.AvoidCatchingThrowable")
void publishPeriodicMetrics() {
try {
if (!isConnected.get()) {
logger.atDebug().log("Cannot publish the metrics. MQTT connection interrupted.");
return;
}
long timestamp = Instant.now().toEpochMilli();
long lastPublish = Coerce.toLong(getPeriodicPublishTimeTopic());
Map<Long, List<AggregatedNamespaceData>> metricsToPublishMap =
metricsAggregator.getMetricsToPublish(lastPublish, timestamp);
getPeriodicPublishTimeTopic().withValue(timestamp);
if (metricsToPublishMap != null && metricsToPublishMap.containsKey(timestamp)) {
publisher.publish(MetricsPayload.builder().build(), metricsToPublishMap.get(timestamp));
logger.atInfo().event("telemetry-metrics-published").log("Telemetry metrics update published.");
}
} catch (Throwable t) {
logger.atWarn().log("Error collecting telemetry. Will retry.", t);
if (!isConnected.get()) {
logger.atDebug().log("Cannot publish the metrics. MQTT connection interrupted.");
return;
}
long timestamp = Instant.now().toEpochMilli();
long lastPublish = Coerce.toLong(getPeriodicPublishTimeTopic());
Map<Long, List<AggregatedNamespaceData>> metricsToPublishMap =
metricsAggregator.getMetricsToPublish(lastPublish, timestamp);
getPeriodicPublishTimeTopic().withValue(timestamp);
// TODO: [P41214679] Do not publish if the metrics are empty.
publisher.publish(MetricsPayload.builder().build(), metricsToPublishMap.get(timestamp));
logger.atInfo().event("telemetry-metrics-published").log("Telemetry metrics update published.");
}

private Topic getPeriodicPublishTimeTopic() {
Expand Down
17 changes: 2 additions & 15 deletions src/main/java/com/aws/greengrass/util/BatchedSubscriber.java
Original file line number Diff line number Diff line change
Expand Up @@ -96,25 +96,14 @@ public BatchedSubscriber(Topics topics,
this((Node) topics, exclusions, callback);
}

/**
* Constructs a new BatchedSubscriber.
*
* @param exclusions predicate for ignoring a subset topic(s) changes
* @param callback action to perform after a <i>batch</i> of changes and on initialization
*/
public BatchedSubscriber(BiPredicate<WhatHappened, Node> exclusions,
Consumer<WhatHappened> callback) {
this((Node) null, exclusions, callback);
}

/**
* Constructs a new BatchedSubscriber.
*
* @param node topic or topics to subscribe to
* @param exclusions predicate for ignoring a subset topic(s) changes
* @param callback action to perform after a <i>batch</i> of changes and on initialization
*/
private BatchedSubscriber(Node node,
private BatchedSubscriber(@NonNull Node node,
BiPredicate<WhatHappened, Node> exclusions,
@NonNull Consumer<WhatHappened> callback) {
this.node = node;
Expand All @@ -138,9 +127,7 @@ public void subscribe() {
* Unsubscribe from the topic(s).
*/
public void unsubscribe() {
if (node != null) {
node.remove(this);
}
node.remove(this);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -125,6 +125,7 @@ void setup(ExtensionContext ec) {
configurationTopics.createLeafChild("periodicPublishMetricsIntervalSeconds").withValue(300);
lenient().when(mockDeviceConfiguration.getTelemetryConfigurationTopics()).thenReturn(configurationTopics);
lenient().when(mockMqttClient.publish(any())).thenReturn(CompletableFuture.completedFuture(0));
lenient().doNothing().when(mockMqttClient).addToCallbackEvents(mqttClientConnectionEventsArgumentCaptor.capture());
telemetryAgent = new TelemetryAgent(config, mockMqttClient, mockDeviceConfiguration, ma, sme, kme, ses, executorService,
3, 1);
}
Expand All @@ -134,6 +135,7 @@ void cleanUp() throws IOException, InterruptedException {
TelemetryConfig.getInstance().closeContext();
telemetryAgent.shutdown();
context.waitForPublishQueueToClear();
Thread.sleep(1000);
ses.shutdownNow();
executorService.shutdownNow();
context.close();
Expand Down Expand Up @@ -234,17 +236,15 @@ void GIVEN_Telemetry_Agent_WHEN_mqtt_is_interrupted_THEN_aggregation_continues_b
});

telemetryAgent.postInject();
long timeoutMs = 10_000;
long timeoutMs = 10000;
verify(mockMqttClient, timeout(timeoutMs).atLeastOnce()).publish(publishRequestArgumentCaptor.capture());
PublishRequest request = publishRequestArgumentCaptor.getValue();
assertEquals(QualityOfService.AT_LEAST_ONCE, request.getQos());
assertEquals("$aws/things/testThing/greengrass/health/json", request.getTopic());
verify(mockMqttClient, timeout(timeoutMs).atLeastOnce())
.addToCallbackEvents(mqttClientConnectionEventsArgumentCaptor.capture());
reset(mockMqttClient);
mqttClientConnectionEventsArgumentCaptor.getValue().onConnectionInterrupted(500);
//verify that nothing is published when mqtt is interrupted
verify(mockMqttClient, never()).publish(publishRequestArgumentCaptor.capture());
verify(mockMqttClient, times(0)).publish(publishRequestArgumentCaptor.capture());
// aggregation is continued irrespective of the mqtt connection
verify(ma, timeout(timeoutMs).atLeastOnce()).aggregateMetrics(anyLong(), anyLong());
}
Expand Down

0 comments on commit 00eaed4

Please sign in to comment.