Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix flaky test #30322

Merged
merged 4 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,7 @@ public static StreamingEngineClient create(
getWorkBudgetDistributor,
dispatcherClient,
new Random().nextLong());
streamingEngineClient.startGetWorkerMetadataStream();
streamingEngineClient.startWorkerMetadataConsumer();
streamingEngineClient.getWorkBudgetRefresher.start();
streamingEngineClient.start();
return streamingEngineClient;
}

Expand All @@ -206,12 +204,16 @@ static StreamingEngineClient forTesting(
getWorkBudgetDistributor,
dispatcherClient,
clientId);
streamingEngineClient.startGetWorkerMetadataStream();
streamingEngineClient.startWorkerMetadataConsumer();
streamingEngineClient.getWorkBudgetRefresher.start();
streamingEngineClient.start();
return streamingEngineClient;
}

/**
 * Starts this client by invoking, in order, {@code startGetWorkerMetadataStream()}, {@code
 * startWorkerMetadataConsumer()} and {@code getWorkBudgetRefresher.start()}.
 *
 * <p>Both factory methods ({@code create} and {@code forTesting}) call this single method so
 * that the startup sequence is defined in exactly one place.
 */
private void start() {
startGetWorkerMetadataStream();
startWorkerMetadataConsumer();
getWorkBudgetRefresher.start();
}

@SuppressWarnings("FutureReturnValueIgnored")
private void startWorkerMetadataConsumer() {
newWorkerMetadataConsumer.submit(
Expand All @@ -223,11 +225,6 @@ private void startWorkerMetadataConsumer() {
});
}

/**
 * Reports whether the current value held by {@code connections} differs from {@link
 * StreamingEngineConnectionState#EMPTY}.
 *
 * @return {@code true} if {@code connections.get()} is not equal to {@code
 *     StreamingEngineConnectionState.EMPTY}; {@code false} otherwise.
 */
@VisibleForTesting
boolean isWorkerMetadataReady() {
return !connections.get().equals(StreamingEngineConnectionState.EMPTY);
}

@VisibleForTesting
void finish() {
if (!started.compareAndSet(true, false)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,14 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
Expand All @@ -50,7 +53,6 @@
import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemProcessor;
import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetDistributor;
import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetRefresher;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder;
Expand All @@ -62,10 +64,10 @@
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableCollection;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort;
import org.junit.After;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.Timeout;
Expand All @@ -92,7 +94,10 @@ public class StreamingEngineClientTest {
.setProjectId(PROJECT_ID)
.setWorkerId(WORKER_ID)
.build();

@Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
@Rule public transient Timeout globalTimeout = Timeout.seconds(600);

private final Set<ManagedChannel> channels = new HashSet<>();
private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry();
private final GrpcWindmillStreamFactory streamFactory =
Expand All @@ -109,11 +114,8 @@ public class StreamingEngineClientTest {
private final GrpcDispatcherClient dispatcherClient =
GrpcDispatcherClient.forTesting(
stubFactory, new ArrayList<>(), new ArrayList<>(), new HashSet<>());
private final GetWorkBudgetDistributor getWorkBudgetDistributor =
spy(new TestGetWorkBudgetDistributor());
private final AtomicReference<StreamingEngineConnectionState> connections =
new AtomicReference<>(StreamingEngineConnectionState.EMPTY);
@Rule public transient Timeout globalTimeout = Timeout.seconds(600);
private Server fakeStreamingEngineServer;
private CountDownLatch getWorkerMetadataReady;
private GetWorkerMetadataTestStub fakeGetWorkerMetadataStub;
Expand Down Expand Up @@ -167,14 +169,16 @@ public void setUp() throws IOException {

@After
public void cleanUp() {
Preconditions.checkNotNull(streamingEngineClient).finish();
fakeGetWorkerMetadataStub.close();
fakeStreamingEngineServer.shutdownNow();
channels.forEach(ManagedChannel::shutdownNow);
Preconditions.checkNotNull(streamingEngineClient).finish();
}

private StreamingEngineClient newStreamingEngineClient(
GetWorkBudget getWorkBudget, WorkItemProcessor workItemProcessor) {
GetWorkBudget getWorkBudget,
GetWorkBudgetDistributor getWorkBudgetDistributor,
WorkItemProcessor workItemProcessor) {
return StreamingEngineClient.forTesting(
JOB_HEADER,
getWorkBudget,
Expand All @@ -191,10 +195,15 @@ private StreamingEngineClient newStreamingEngineClient(
public void testStreamsStartCorrectly() throws InterruptedException {
long items = 10L;
long bytes = 10L;
int numBudgetDistributionsExpected = 1;

TestGetWorkBudgetDistributor getWorkBudgetDistributor =
spy(new TestGetWorkBudgetDistributor(numBudgetDistributionsExpected));

streamingEngineClient =
newStreamingEngineClient(
GetWorkBudget.builder().setItems(items).setBytes(bytes).build(),
getWorkBudgetDistributor,
noOpProcessWorkItemFn());

String workerToken = "workerToken1";
Expand All @@ -210,12 +219,14 @@ public void testStreamsStartCorrectly() throws InterruptedException {

getWorkerMetadataReady.await();
fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata);
StreamingEngineConnectionState currentConnections = waitForWorkerMetadataToBeConsumed(1);
waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor);

StreamingEngineConnectionState currentConnections = connections.get();

assertEquals(2, currentConnections.windmillConnections().size());
assertEquals(2, currentConnections.windmillStreams().size());
Set<String> workerTokens =
connections.get().windmillConnections().values().stream()
currentConnections.windmillConnections().values().stream()
.map(WindmillConnection::backendWorkerToken)
.filter(Optional::isPresent)
.map(Optional::get)
Expand All @@ -238,9 +249,13 @@ public void testStreamsStartCorrectly() throws InterruptedException {

@Test
public void testScheduledBudgetRefresh() throws InterruptedException {
TestGetWorkBudgetDistributor getWorkBudgetDistributor =
spy(new TestGetWorkBudgetDistributor(2));
streamingEngineClient =
newStreamingEngineClient(
GetWorkBudget.builder().setItems(1L).setBytes(1L).build(), noOpProcessWorkItemFn());
GetWorkBudget.builder().setItems(1L).setBytes(1L).build(),
getWorkBudgetDistributor,
noOpProcessWorkItemFn());

getWorkerMetadataReady.await();
fakeGetWorkerMetadataStub.injectWorkerMetadata(
Expand All @@ -249,18 +264,21 @@ public void testScheduledBudgetRefresh() throws InterruptedException {
.addWorkEndpoints(metadataResponseEndpoint("workerToken"))
.putAllGlobalDataEndpoints(DEFAULT)
.build());
waitForWorkerMetadataToBeConsumed(1);
Thread.sleep(GetWorkBudgetRefresher.SCHEDULED_BUDGET_REFRESH_MILLIS);
waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor);
verify(getWorkBudgetDistributor, atLeast(2)).distributeBudget(any(), any());
}

@Test
@Ignore("https://github.com/apache/beam/issues/28957") // stuck test
public void testOnNewWorkerMetadata_correctlyRemovesStaleWindmillServers()
throws InterruptedException {
int metadataCount = 2;
TestGetWorkBudgetDistributor getWorkBudgetDistributor =
spy(new TestGetWorkBudgetDistributor(metadataCount));
streamingEngineClient =
newStreamingEngineClient(
GetWorkBudget.builder().setItems(1).setBytes(1).build(), noOpProcessWorkItemFn());
GetWorkBudget.builder().setItems(1).setBytes(1).build(),
getWorkBudgetDistributor,
noOpProcessWorkItemFn());

String workerToken = "workerToken1";
String workerToken2 = "workerToken2";
Expand Down Expand Up @@ -292,9 +310,8 @@ public void testOnNewWorkerMetadata_correctlyRemovesStaleWindmillServers()
getWorkerMetadataReady.await();
fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata);
fakeGetWorkerMetadataStub.injectWorkerMetadata(secondWorkerMetadata);

StreamingEngineConnectionState currentConnections = waitForWorkerMetadataToBeConsumed(2);

waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor);
StreamingEngineConnectionState currentConnections = connections.get();
assertEquals(1, currentConnections.windmillConnections().size());
assertEquals(1, currentConnections.windmillStreams().size());
Set<String> workerTokens =
Expand All @@ -310,10 +327,6 @@ public void testOnNewWorkerMetadata_correctlyRemovesStaleWindmillServers()

@Test
public void testOnNewWorkerMetadata_redistributesBudget() throws InterruptedException {
streamingEngineClient =
newStreamingEngineClient(
GetWorkBudget.builder().setItems(1).setBytes(1).build(), noOpProcessWorkItemFn());

String workerToken = "workerToken1";
String workerToken2 = "workerToken2";
String workerToken3 = "workerToken3";
Expand Down Expand Up @@ -346,39 +359,39 @@ public void testOnNewWorkerMetadata_redistributesBudget() throws InterruptedExce
.putAllGlobalDataEndpoints(DEFAULT)
.build();

List<WorkerMetadataResponse> workerMetadataResponses =
Lists.newArrayList(firstWorkerMetadata, secondWorkerMetadata, thirdWorkerMetadata);

TestGetWorkBudgetDistributor getWorkBudgetDistributor =
spy(new TestGetWorkBudgetDistributor(workerMetadataResponses.size()));
streamingEngineClient =
newStreamingEngineClient(
GetWorkBudget.builder().setItems(1).setBytes(1).build(),
getWorkBudgetDistributor,
noOpProcessWorkItemFn());

getWorkerMetadataReady.await();
fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata);
Thread.sleep(50);
fakeGetWorkerMetadataStub.injectWorkerMetadata(secondWorkerMetadata);
Thread.sleep(50);
fakeGetWorkerMetadataStub.injectWorkerMetadata(thirdWorkerMetadata);
Thread.sleep(50);
verify(getWorkBudgetDistributor, atLeast(3)).distributeBudget(any(), any());

// Make sure we are injecting the metadata from smallest to largest.
workerMetadataResponses.stream()
.sorted(Comparator.comparingLong(WorkerMetadataResponse::getMetadataVersion))
.forEach(fakeGetWorkerMetadataStub::injectWorkerMetadata);

waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor);
verify(getWorkBudgetDistributor, atLeast(workerMetadataResponses.size()))
.distributeBudget(any(), any());
}

private StreamingEngineConnectionState waitForWorkerMetadataToBeConsumed(
int expectedMetadataConsumed) throws InterruptedException {
int currentMetadataConsumed = 0;
StreamingEngineConnectionState currentConsumedMetadata = StreamingEngineConnectionState.EMPTY;
while (true) {
if (!connections.get().equals(currentConsumedMetadata)) {
++currentMetadataConsumed;
if (currentMetadataConsumed == expectedMetadataConsumed) {
break;
}
currentConsumedMetadata = connections.get();
}
}
// Wait for metadata to be consumed and budgets to be redistributed.
Thread.sleep(GetWorkBudgetRefresher.SCHEDULED_BUDGET_REFRESH_MILLIS);
return connections.get();
private void waitForWorkerMetadataToBeConsumed(
TestGetWorkBudgetDistributor getWorkBudgetDistributor) throws InterruptedException {
getWorkBudgetDistributor.waitForBudgetDistribution();
}

private static class GetWorkerMetadataTestStub
extends CloudWindmillMetadataServiceV1Alpha1Grpc
.CloudWindmillMetadataServiceV1Alpha1ImplBase {
private static final WorkerMetadataResponse CLOSE_ALL_STREAMS =
WorkerMetadataResponse.newBuilder().setMetadataVersion(100).build();
WorkerMetadataResponse.newBuilder().setMetadataVersion(Long.MAX_VALUE).build();
private final CountDownLatch ready;
private @Nullable StreamObserver<WorkerMetadataResponse> responseObserver;

Expand Down Expand Up @@ -426,10 +439,22 @@ private void close() {
}

private static class TestGetWorkBudgetDistributor implements GetWorkBudgetDistributor {
private final CountDownLatch getWorkBudgetDistributorTriggered;

/**
 * Creates a test distributor backed by a {@link CountDownLatch} initialised to the given count.
 *
 * @param numBudgetDistributionsExpected initial count of the latch; {@code distributeBudget}
 *     counts the latch down once per invocation.
 */
private TestGetWorkBudgetDistributor(int numBudgetDistributionsExpected) {
this.getWorkBudgetDistributorTriggered = new CountDownLatch(numBudgetDistributionsExpected);
}

/**
 * Blocks until the latch reaches zero or 5 seconds elapse, whichever happens first.
 *
 * <p>The boolean returned by {@code await} is deliberately discarded (hence the suppressed
 * warning), so a timeout is not reported to the caller; callers see a timeout only through
 * their own subsequent assertions or verifications.
 *
 * @throws InterruptedException if the calling thread is interrupted while waiting.
 */
@SuppressWarnings("ReturnValueIgnored")
private void waitForBudgetDistribution() throws InterruptedException {
getWorkBudgetDistributorTriggered.await(5, TimeUnit.SECONDS);
}

@Override
public void distributeBudget(
ImmutableCollection<WindmillStreamSender> streams, GetWorkBudget getWorkBudget) {
streams.forEach(stream -> stream.adjustBudget(getWorkBudget.items(), getWorkBudget.bytes()));
getWorkBudgetDistributorTriggered.countDown();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this throw if we decrement too much? If so, with the scheduled test it seems racy if it happens too often. Maybe we could check before decrementing here, since the distribution is single-threaded.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these tests run on multiple threads in the GHA env?
Are the @Before and @After code blocks not being run?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't.
From the Javadoc:

countDown
public void countDown()
Decrements the count of the latch, releasing all waiting threads if the count reaches zero.
If the current count is greater than zero then it is decremented. If the new count is zero then all waiting threads are re-enabled for thread scheduling purposes.

If the current count equals zero then nothing happens.

https://docs.oracle.com/javase/8/docs/api/java/util/concurrent/CountDownLatch.html#countDown--

}
}
}
Loading