Skip to content

Commit

Permalink
use direct executor to deflake tests (#33187)
Browse files Browse the repository at this point in the history
* use direct executor to deflake tests

* address PR comments
  • Loading branch information
m-trieu authored Nov 26, 2024
1 parent 292da72 commit 720b824
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 110 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Streams;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.MoreExecutors;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -125,7 +126,8 @@ private FanOutStreamingEngineWorkerHarness(
GetWorkBudgetDistributor getWorkBudgetDistributor,
GrpcDispatcherClient dispatcherClient,
Function<WindmillStream.CommitWorkStream, WorkCommitter> workCommitterFactory,
ThrottlingGetDataMetricTracker getDataMetricTracker) {
ThrottlingGetDataMetricTracker getDataMetricTracker,
ExecutorService workerMetadataConsumer) {
this.jobHeader = jobHeader;
this.getDataMetricTracker = getDataMetricTracker;
this.started = false;
Expand All @@ -138,9 +140,7 @@ private FanOutStreamingEngineWorkerHarness(
this.windmillStreamManager =
Executors.newCachedThreadPool(
new ThreadFactoryBuilder().setNameFormat(STREAM_MANAGER_THREAD_NAME).build());
this.workerMetadataConsumer =
Executors.newSingleThreadExecutor(
new ThreadFactoryBuilder().setNameFormat(WORKER_METADATA_CONSUMER_THREAD_NAME).build());
this.workerMetadataConsumer = workerMetadataConsumer;
this.getWorkBudgetDistributor = getWorkBudgetDistributor;
this.totalGetWorkBudget = totalGetWorkBudget;
this.activeMetadataVersion = Long.MIN_VALUE;
Expand Down Expand Up @@ -171,7 +171,11 @@ public static FanOutStreamingEngineWorkerHarness create(
getWorkBudgetDistributor,
dispatcherClient,
workCommitterFactory,
getDataMetricTracker);
getDataMetricTracker,
Executors.newSingleThreadExecutor(
new ThreadFactoryBuilder()
.setNameFormat(WORKER_METADATA_CONSUMER_THREAD_NAME)
.build()));
}

@VisibleForTesting
Expand All @@ -195,7 +199,13 @@ static FanOutStreamingEngineWorkerHarness forTesting(
getWorkBudgetDistributor,
dispatcherClient,
workCommitterFactory,
getDataMetricTracker);
getDataMetricTracker,
// Run the workerMetadataConsumer on the direct calling thread to remove waiting and
// make unit tests more deterministic as we do not have to worry about network IO being
// blocked by the consumeWorkerMetadata() task. Test suites run in different
// environments and non-determinism has led to past flakiness. See
// https://github.com/apache/beam/issues/28957.
MoreExecutors.newDirectExecutorService());
fanOutStreamingEngineWorkProvider.start();
return fanOutStreamingEngineWorkProvider;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,17 +88,23 @@ private static Optional<WindmillServiceAddress> parseDirectEndpoint(
.map(address -> AuthenticatedGcpServiceAddress.create(authenticatingService, address))
.map(WindmillServiceAddress::create);

return directEndpointIpV6Address.isPresent()
? directEndpointIpV6Address
: tryParseEndpointIntoHostAndPort(endpointProto.getDirectEndpoint())
.map(WindmillServiceAddress::create);
Optional<WindmillServiceAddress> windmillServiceAddress =
directEndpointIpV6Address.isPresent()
? directEndpointIpV6Address
: tryParseEndpointIntoHostAndPort(endpointProto.getDirectEndpoint())
.map(WindmillServiceAddress::create);

if (!windmillServiceAddress.isPresent()) {
LOG.warn("Endpoint {} could not be parsed into a WindmillServiceAddress.", endpointProto);
}

return windmillServiceAddress;
}

/**
 * Attempts to parse {@code directEndpoint} into a {@link HostAndPort}.
 *
 * @param directEndpoint the endpoint string handed to {@link HostAndPort#fromString(String)}.
 * @return the parsed {@link HostAndPort}, or {@link Optional#empty()} when parsing throws
 *     {@link IllegalArgumentException}.
 */
private static Optional<HostAndPort> tryParseEndpointIntoHostAndPort(String directEndpoint) {
try {
return Optional.of(HostAndPort.fromString(directEndpoint));
} catch (IllegalArgumentException e) {
// Unparseable input is reported to the caller as "absent" rather than rethrown.
LOG.warn("{} cannot be parsed into a gcpServiceAddress", directEndpoint);
return Optional.empty();
}
}
Expand All @@ -113,19 +119,12 @@ private static Optional<HostAndPort> tryParseDirectEndpointIntoIpV6Address(
try {
directEndpointAddress = Inet6Address.getByName(endpointProto.getDirectEndpoint());
} catch (UnknownHostException e) {
LOG.warn(
"Error occurred trying to parse direct_endpoint={} into IPv6 address. Exception={}",
endpointProto.getDirectEndpoint(),
e.toString());
return Optional.empty();
}

// Inet6Address.getByName returns either an IPv4 or an IPv6 address depending on the format
// of the direct_endpoint string.
if (!(directEndpointAddress instanceof Inet6Address)) {
LOG.warn(
"{} is not an IPv6 address. Direct endpoints are expected to be in IPv6 format.",
endpointProto.getDirectEndpoint());
return Optional.empty();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ public CloudWindmillMetadataServiceV1Alpha1Stub getWindmillMetadataServiceStubBl
}
}

LOG.info("Windmill Service endpoint initialized after {} seconds.", secondsWaited);

ImmutableList<CloudWindmillMetadataServiceV1Alpha1Stub> windmillMetadataServiceStubs =
dispatcherStubs.get().windmillMetadataServiceStubs();

Expand Down Expand Up @@ -190,7 +192,7 @@ public void onJobConfig(StreamingGlobalConfig config) {

public synchronized void consumeWindmillDispatcherEndpoints(
ImmutableSet<HostAndPort> dispatcherEndpoints) {
consumeWindmillDispatcherEndpoints(dispatcherEndpoints, /*forceRecreateStubs=*/ false);
consumeWindmillDispatcherEndpoints(dispatcherEndpoints, /* forceRecreateStubs= */ false);
}

private synchronized void consumeWindmillDispatcherEndpoints(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs;

import java.io.PrintWriter;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
Expand All @@ -31,6 +32,7 @@
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.LoadingCache;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.RemovalListener;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.RemovalListeners;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.MoreExecutors;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -50,14 +52,11 @@ public final class ChannelCache implements StatusDataProvider {

private ChannelCache(
Function<WindmillServiceAddress, ManagedChannel> channelFactory,
RemovalListener<WindmillServiceAddress, ManagedChannel> onChannelRemoved) {
RemovalListener<WindmillServiceAddress, ManagedChannel> onChannelRemoved,
Executor channelCloser) {
this.channelCache =
CacheBuilder.newBuilder()
.removalListener(
RemovalListeners.asynchronous(
onChannelRemoved,
Executors.newCachedThreadPool(
new ThreadFactoryBuilder().setNameFormat("GrpcChannelCloser").build())))
.removalListener(RemovalListeners.asynchronous(onChannelRemoved, channelCloser))
.build(
new CacheLoader<WindmillServiceAddress, ManagedChannel>() {
@Override
Expand All @@ -72,11 +71,13 @@ public static ChannelCache create(
return new ChannelCache(
channelFactory,
// Shutdown the channels as they get removed from the cache, so they do not leak.
notification -> shutdownChannel(notification.getValue()));
notification -> shutdownChannel(notification.getValue()),
Executors.newCachedThreadPool(
new ThreadFactoryBuilder().setNameFormat("GrpcChannelCloser").build()));
}

@VisibleForTesting
static ChannelCache forTesting(
public static ChannelCache forTesting(
Function<WindmillServiceAddress, ManagedChannel> channelFactory, Runnable onChannelShutdown) {
return new ChannelCache(
channelFactory,
Expand All @@ -85,7 +86,11 @@ static ChannelCache forTesting(
notification -> {
shutdownChannel(notification.getValue());
onChannelShutdown.run();
});
},
// Run the removal synchronously on the calling thread to prevent waiting on asynchronous
// tasks to run and make unit tests deterministic. In testing, we verify that things are
// removed from the cache.
MoreExecutors.directExecutor());
}

private static void shutdownChannel(ManagedChannel channel) {
Expand All @@ -108,6 +113,7 @@ public void remove(WindmillServiceAddress windmillServiceAddress) {

/**
 * Invalidates every entry in {@code channelCache} and then runs the cache's {@code cleanUp()}.
 */
public void clear() {
channelCache.invalidateAll();
// NOTE(review): cleanUp() presumably flushes pending cache maintenance so that removals are
// processed before this method returns -- confirm against the Guava Cache documentation.
channelCache.cleanUp();
}

/**
Expand Down
Loading

0 comments on commit 720b824

Please sign in to comment.