Skip to content

Commit

Permalink
use direct executor to deflake tests
Browse files Browse the repository at this point in the history
  • Loading branch information
m-trieu committed Nov 21, 2024
1 parent a06454a commit 6dc5803
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Streams;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.MoreExecutors;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -125,7 +126,8 @@ private FanOutStreamingEngineWorkerHarness(
GetWorkBudgetDistributor getWorkBudgetDistributor,
GrpcDispatcherClient dispatcherClient,
Function<WindmillStream.CommitWorkStream, WorkCommitter> workCommitterFactory,
ThrottlingGetDataMetricTracker getDataMetricTracker) {
ThrottlingGetDataMetricTracker getDataMetricTracker,
ExecutorService workerMetadataConsumer) {
this.jobHeader = jobHeader;
this.getDataMetricTracker = getDataMetricTracker;
this.started = false;
Expand All @@ -138,9 +140,7 @@ private FanOutStreamingEngineWorkerHarness(
this.windmillStreamManager =
Executors.newCachedThreadPool(
new ThreadFactoryBuilder().setNameFormat(STREAM_MANAGER_THREAD_NAME).build());
this.workerMetadataConsumer =
Executors.newSingleThreadExecutor(
new ThreadFactoryBuilder().setNameFormat(WORKER_METADATA_CONSUMER_THREAD_NAME).build());
this.workerMetadataConsumer = workerMetadataConsumer;
this.getWorkBudgetDistributor = getWorkBudgetDistributor;
this.totalGetWorkBudget = totalGetWorkBudget;
this.activeMetadataVersion = Long.MIN_VALUE;
Expand Down Expand Up @@ -171,7 +171,11 @@ public static FanOutStreamingEngineWorkerHarness create(
getWorkBudgetDistributor,
dispatcherClient,
workCommitterFactory,
getDataMetricTracker);
getDataMetricTracker,
Executors.newSingleThreadExecutor(
new ThreadFactoryBuilder()
.setNameFormat(WORKER_METADATA_CONSUMER_THREAD_NAME)
.build()));
}

@VisibleForTesting
Expand All @@ -195,7 +199,10 @@ static FanOutStreamingEngineWorkerHarness forTesting(
getWorkBudgetDistributor,
dispatcherClient,
workCommitterFactory,
getDataMetricTracker);
getDataMetricTracker,
// Run the workerMetadataConsumer on the direct calling thread to make testing more
// deterministic.
MoreExecutors.newDirectExecutorService());
fanOutStreamingEngineWorkProvider.start();
return fanOutStreamingEngineWorkProvider;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,17 +88,23 @@ private static Optional<WindmillServiceAddress> parseDirectEndpoint(
.map(address -> AuthenticatedGcpServiceAddress.create(authenticatingService, address))
.map(WindmillServiceAddress::create);

return directEndpointIpV6Address.isPresent()
? directEndpointIpV6Address
: tryParseEndpointIntoHostAndPort(endpointProto.getDirectEndpoint())
.map(WindmillServiceAddress::create);
Optional<WindmillServiceAddress> windmillServiceAddress =
directEndpointIpV6Address.isPresent()
? directEndpointIpV6Address
: tryParseEndpointIntoHostAndPort(endpointProto.getDirectEndpoint())
.map(WindmillServiceAddress::create);

if (!windmillServiceAddress.isPresent()) {
LOG.warn("Endpoint {} could not be parsed into a WindmillServiceAddress.", endpointProto);
}

return windmillServiceAddress;
}

private static Optional<HostAndPort> tryParseEndpointIntoHostAndPort(String directEndpoint) {
try {
return Optional.of(HostAndPort.fromString(directEndpoint));
} catch (IllegalArgumentException e) {
LOG.warn("{} cannot be parsed into a gcpServiceAddress", directEndpoint);
return Optional.empty();
}
}
Expand All @@ -113,19 +119,12 @@ private static Optional<HostAndPort> tryParseDirectEndpointIntoIpV6Address(
try {
directEndpointAddress = Inet6Address.getByName(endpointProto.getDirectEndpoint());
} catch (UnknownHostException e) {
LOG.warn(
"Error occurred trying to parse direct_endpoint={} into IPv6 address. Exception={}",
endpointProto.getDirectEndpoint(),
e.toString());
return Optional.empty();
}

// Inet6Address.getByAddress returns either an IPv4 or an IPv6 address depending on the format
// of the direct_endpoint string.
if (!(directEndpointAddress instanceof Inet6Address)) {
LOG.warn(
"{} is not an IPv6 address. Direct endpoints are expected to be in IPv6 format.",
endpointProto.getDirectEndpoint());
return Optional.empty();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ public CloudWindmillMetadataServiceV1Alpha1Stub getWindmillMetadataServiceStubBl
}
}

LOG.info("Windmill Service endpoint initialized after {} seconds.", secondsWaited);

ImmutableList<CloudWindmillMetadataServiceV1Alpha1Stub> windmillMetadataServiceStubs =
dispatcherStubs.get().windmillMetadataServiceStubs();

Expand Down Expand Up @@ -190,7 +192,7 @@ public void onJobConfig(StreamingGlobalConfig config) {

public synchronized void consumeWindmillDispatcherEndpoints(
ImmutableSet<HostAndPort> dispatcherEndpoints) {
consumeWindmillDispatcherEndpoints(dispatcherEndpoints, /*forceRecreateStubs=*/ false);
consumeWindmillDispatcherEndpoints(dispatcherEndpoints, /* forceRecreateStubs= */ false);
}

private synchronized void consumeWindmillDispatcherEndpoints(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs;

import java.io.PrintWriter;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
Expand All @@ -31,6 +32,7 @@
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.LoadingCache;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.RemovalListener;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.RemovalListeners;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.MoreExecutors;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -50,14 +52,11 @@ public final class ChannelCache implements StatusDataProvider {

private ChannelCache(
Function<WindmillServiceAddress, ManagedChannel> channelFactory,
RemovalListener<WindmillServiceAddress, ManagedChannel> onChannelRemoved) {
RemovalListener<WindmillServiceAddress, ManagedChannel> onChannelRemoved,
Executor channelCloser) {
this.channelCache =
CacheBuilder.newBuilder()
.removalListener(
RemovalListeners.asynchronous(
onChannelRemoved,
Executors.newCachedThreadPool(
new ThreadFactoryBuilder().setNameFormat("GrpcChannelCloser").build())))
.removalListener(RemovalListeners.asynchronous(onChannelRemoved, channelCloser))
.build(
new CacheLoader<WindmillServiceAddress, ManagedChannel>() {
@Override
Expand All @@ -72,11 +71,13 @@ public static ChannelCache create(
return new ChannelCache(
channelFactory,
// Shutdown the channels as they get removed from the cache, so they do not leak.
notification -> shutdownChannel(notification.getValue()));
notification -> shutdownChannel(notification.getValue()),
Executors.newCachedThreadPool(
new ThreadFactoryBuilder().setNameFormat("GrpcChannelCloser").build()));
}

@VisibleForTesting
static ChannelCache forTesting(
public static ChannelCache forTesting(
Function<WindmillServiceAddress, ManagedChannel> channelFactory, Runnable onChannelShutdown) {
return new ChannelCache(
channelFactory,
Expand All @@ -85,7 +86,9 @@ static ChannelCache forTesting(
notification -> {
shutdownChannel(notification.getValue());
onChannelShutdown.run();
});
},
// Run the removal on the calling thread for better determinism in tests.
MoreExecutors.directExecutor());
}

private static void shutdownChannel(ManagedChannel channel) {
Expand All @@ -108,6 +111,7 @@ public void remove(WindmillServiceAddress windmillServiceAddress) {

public void clear() {
channelCache.invalidateAll();
channelCache.cleanUp();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.beam.runners.dataflow.options.DataflowWorkerHarnessOptions;
Expand Down Expand Up @@ -191,10 +190,8 @@ private FanOutStreamingEngineWorkerHarness newFanOutStreamingEngineWorkerHarness
public void testStreamsStartCorrectly() throws InterruptedException {
long items = 10L;
long bytes = 10L;
int numBudgetDistributionsExpected = 1;

TestGetWorkBudgetDistributor getWorkBudgetDistributor =
spy(new TestGetWorkBudgetDistributor(numBudgetDistributionsExpected));
TestGetWorkBudgetDistributor getWorkBudgetDistributor = spy(new TestGetWorkBudgetDistributor());

fanOutStreamingEngineWorkProvider =
newFanOutStreamingEngineWorkerHarness(
Expand All @@ -215,7 +212,6 @@ public void testStreamsStartCorrectly() throws InterruptedException {

getWorkerMetadataReady.await();
fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata);
assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution());

StreamingEngineBackends currentBackends = fanOutStreamingEngineWorkProvider.currentBackends();

Expand Down Expand Up @@ -249,8 +245,7 @@ public void testStreamsStartCorrectly() throws InterruptedException {
@Test
public void testOnNewWorkerMetadata_correctlyRemovesStaleWindmillServers()
throws InterruptedException {
TestGetWorkBudgetDistributor getWorkBudgetDistributor =
spy(new TestGetWorkBudgetDistributor(1));
TestGetWorkBudgetDistributor getWorkBudgetDistributor = spy(new TestGetWorkBudgetDistributor());
fanOutStreamingEngineWorkProvider =
newFanOutStreamingEngineWorkerHarness(
GetWorkBudget.builder().setItems(1).setBytes(1).build(),
Expand Down Expand Up @@ -285,10 +280,7 @@ public void testOnNewWorkerMetadata_correctlyRemovesStaleWindmillServers()

getWorkerMetadataReady.await();
fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata);
assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution());
getWorkBudgetDistributor.expectNumDistributions(1);
fakeGetWorkerMetadataStub.injectWorkerMetadata(secondWorkerMetadata);
assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution());
StreamingEngineBackends currentBackends = fanOutStreamingEngineWorkProvider.currentBackends();
assertEquals(1, currentBackends.windmillStreams().size());
Set<String> workerTokens =
Expand Down Expand Up @@ -325,8 +317,7 @@ public void testOnNewWorkerMetadata_redistributesBudget() throws InterruptedExce
.putAllGlobalDataEndpoints(DEFAULT)
.build();

TestGetWorkBudgetDistributor getWorkBudgetDistributor =
spy(new TestGetWorkBudgetDistributor(1));
TestGetWorkBudgetDistributor getWorkBudgetDistributor = spy(new TestGetWorkBudgetDistributor());
fanOutStreamingEngineWorkProvider =
newFanOutStreamingEngineWorkerHarness(
GetWorkBudget.builder().setItems(1).setBytes(1).build(),
Expand All @@ -336,10 +327,7 @@ public void testOnNewWorkerMetadata_redistributesBudget() throws InterruptedExce
getWorkerMetadataReady.await();

fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata);
assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution());
getWorkBudgetDistributor.expectNumDistributions(1);
fakeGetWorkerMetadataStub.injectWorkerMetadata(secondWorkerMetadata);
assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution());

verify(getWorkBudgetDistributor, times(2)).distributeBudget(any(), any());
}
Expand All @@ -354,10 +342,14 @@ public StreamObserver<Windmill.StreamingGetDataRequest> getDataStream(
public void onNext(Windmill.StreamingGetDataRequest getDataRequest) {}

@Override
public void onError(Throwable throwable) {}
public void onError(Throwable throwable) {
responseObserver.onError(throwable);
}

@Override
public void onCompleted() {}
public void onCompleted() {
responseObserver.onCompleted();
}
};
}

Expand All @@ -369,10 +361,14 @@ public StreamObserver<Windmill.StreamingGetWorkRequest> getWorkStream(
public void onNext(Windmill.StreamingGetWorkRequest getWorkRequest) {}

@Override
public void onError(Throwable throwable) {}
public void onError(Throwable throwable) {
responseObserver.onError(throwable);
}

@Override
public void onCompleted() {}
public void onCompleted() {
responseObserver.onCompleted();
}
};
}

Expand All @@ -384,10 +380,14 @@ public StreamObserver<Windmill.StreamingCommitWorkRequest> commitWorkStream(
public void onNext(Windmill.StreamingCommitWorkRequest streamingCommitWorkRequest) {}

@Override
public void onError(Throwable throwable) {}
public void onError(Throwable throwable) {
responseObserver.onError(throwable);
}

@Override
public void onCompleted() {}
public void onCompleted() {
responseObserver.onCompleted();
}
};
}
}
Expand Down Expand Up @@ -422,7 +422,11 @@ public void onError(Throwable throwable) {
}

@Override
public void onCompleted() {}
public void onCompleted() {
if (responseObserver != null) {
responseObserver.onCompleted();
}
}
};
}

Expand All @@ -434,25 +438,10 @@ private void injectWorkerMetadata(WorkerMetadataResponse response) {
}

private static class TestGetWorkBudgetDistributor implements GetWorkBudgetDistributor {
private CountDownLatch getWorkBudgetDistributorTriggered;

private TestGetWorkBudgetDistributor(int numBudgetDistributionsExpected) {
this.getWorkBudgetDistributorTriggered = new CountDownLatch(numBudgetDistributionsExpected);
}

private boolean waitForBudgetDistribution() throws InterruptedException {
return getWorkBudgetDistributorTriggered.await(5, TimeUnit.SECONDS);
}

private void expectNumDistributions(int numBudgetDistributionsExpected) {
this.getWorkBudgetDistributorTriggered = new CountDownLatch(numBudgetDistributionsExpected);
}

@Override
public <T extends GetWorkBudgetSpender> void distributeBudget(
ImmutableCollection<T> streams, GetWorkBudget getWorkBudget) {
streams.forEach(stream -> stream.setBudget(getWorkBudget.items(), getWorkBudget.bytes()));
getWorkBudgetDistributorTriggered.countDown();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -105,27 +105,17 @@ public ManagedChannel apply(WindmillServiceAddress windmillServiceAddress) {
@Test
public void testRemoveAndClose() throws InterruptedException {
String channelName = "existingChannel";
CountDownLatch verifyRemovalListenerAsync = new CountDownLatch(1);
CountDownLatch notifyWhenChannelClosed = new CountDownLatch(1);
cache =
ChannelCache.forTesting(
ignored -> newChannel(channelName),
() -> {
try {
verifyRemovalListenerAsync.await();
notifyWhenChannelClosed.countDown();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
});
ignored -> newChannel(channelName), notifyWhenChannelClosed::countDown);

WindmillServiceAddress someAddress = mock(WindmillServiceAddress.class);
ManagedChannel cachedChannel = cache.get(someAddress);
cache.remove(someAddress);
// Assert that the removal happened before we check to see if the shutdowns happen to confirm
// that removals are async.
assertTrue(cache.isEmpty());
verifyRemovalListenerAsync.countDown();

// Assert that the channel gets shutdown.
notifyWhenChannelClosed.await();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public final class FakeWindmillStubFactory implements ChannelCachingStubFactory
private final ChannelCache channelCache;

public FakeWindmillStubFactory(Supplier<ManagedChannel> channelFactory) {
this.channelCache = ChannelCache.create(ignored -> channelFactory.get());
this.channelCache = ChannelCache.forTesting(ignored -> channelFactory.get(), () -> {});
}

@Override
Expand Down

0 comments on commit 6dc5803

Please sign in to comment.