diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexer.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexer.java index e946022c4e36..aa0dea80b0a1 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexer.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexer.java @@ -17,11 +17,13 @@ */ package org.apache.beam.sdk.fn.data; +import java.time.Duration; import java.util.HashSet; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.function.Consumer; import org.apache.beam.model.fnexecution.v1.BeamFnApi; import org.apache.beam.model.pipeline.v1.Endpoints; @@ -30,6 +32,8 @@ import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.Cache; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheBuilder; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.checkerframework.checker.nullness.qual.Nullable; import org.slf4j.Logger; @@ -49,13 +53,20 @@ */ public class BeamFnDataGrpcMultiplexer implements AutoCloseable { private static final Logger LOG = LoggerFactory.getLogger(BeamFnDataGrpcMultiplexer.class); + private static final Duration POISONED_INSTRUCTION_ID_CACHE_TIMEOUT = Duration.ofMinutes(20); private final Endpoints.@Nullable ApiServiceDescriptor apiServiceDescriptor; private final StreamObserver inboundObserver; private final StreamObserver 
outboundObserver; - private final ConcurrentMap< + private final ConcurrentHashMap< /*instructionId=*/ String, CompletableFuture>> receivers; - private final ConcurrentMap erroredInstructionIds; + private final Cache poisonedInstructionIds; + + private static class PoisonedException extends RuntimeException { + public PoisonedException() { + super("Instruction poisoned"); + } + }; public BeamFnDataGrpcMultiplexer( Endpoints.@Nullable ApiServiceDescriptor apiServiceDescriptor, @@ -64,7 +75,8 @@ public BeamFnDataGrpcMultiplexer( baseOutboundObserverFactory) { this.apiServiceDescriptor = apiServiceDescriptor; this.receivers = new ConcurrentHashMap<>(); - this.erroredInstructionIds = new ConcurrentHashMap<>(); + this.poisonedInstructionIds = + CacheBuilder.newBuilder().expireAfterWrite(POISONED_INSTRUCTION_ID_CACHE_TIMEOUT).build(); this.inboundObserver = new InboundObserver(); this.outboundObserver = outboundObserverFactory.outboundObserverFor(baseOutboundObserverFactory, inboundObserver); @@ -87,11 +99,6 @@ public StreamObserver getOutboundObserver() { return outboundObserver; } - private CompletableFuture> receiverFuture( - String instructionId) { - return receivers.computeIfAbsent(instructionId, (unused) -> new CompletableFuture<>()); - } - /** * Registers a consumer for the specified instruction id. * @@ -99,17 +106,63 @@ private CompletableFuture> receiverF * instruction ids ensuring that the receiver will only see {@link BeamFnApi.Elements} with a * single instruction id. * - *

The caller must {@link #unregisterConsumer unregister the consumer} when they no longer wish - * to receive messages. + *

The caller must either {@link #unregisterConsumer unregister the consumer} when all messages + * have been processed or {@link #poisonInstructionId(String) poison the instruction} if messages + * for the instruction should be dropped. */ public void registerConsumer( String instructionId, CloseableFnDataReceiver receiver) { - receiverFuture(instructionId).complete(receiver); + receivers.compute( + instructionId, + (unused, existing) -> { + if (existing != null) { + if (!existing.complete(receiver)) { + throw new IllegalArgumentException("Instruction id was registered twice"); + } + return existing; + } + if (poisonedInstructionIds.getIfPresent(instructionId) != null) { + throw new IllegalArgumentException("Instruction id was poisoned"); + } + return CompletableFuture.completedFuture(receiver); + }); } - /** Unregisters a consumer. */ + /** Unregisters a previously registered consumer. */ public void unregisterConsumer(String instructionId) { - receivers.remove(instructionId); + @Nullable + CompletableFuture> receiverFuture = + receivers.remove(instructionId); + if (receiverFuture != null && !receiverFuture.isDone()) { + // The future must have been inserted by the inbound observer since registerConsumer completes + // the future. + throw new IllegalArgumentException("Unregistering consumer which was not registered."); + } + } + + /** + * Poisons an instruction id. + * + *

Any records for the instruction on the inbound observer will be dropped for the next {@link + * #POISONED_INSTRUCTION_ID_CACHE_TIMEOUT}. + */ + public void poisonInstructionId(String instructionId) { + poisonedInstructionIds.put(instructionId, Boolean.TRUE); + @Nullable + CompletableFuture> receiverFuture = + receivers.remove(instructionId); + if (receiverFuture != null) { + // Completing exceptionally has no effect if the future was already notified. In that case + // whatever registered the receiver needs to handle cancelling it. + receiverFuture.completeExceptionally(new PoisonedException()); + if (!receiverFuture.isCompletedExceptionally()) { + try { + receiverFuture.get().close(); + } catch (Exception e) { + LOG.warn("Unexpected error closing existing observer"); + } + } + } } @VisibleForTesting @@ -210,27 +263,42 @@ public void onNext(BeamFnApi.Elements value) { } private void forwardToConsumerForInstructionId(String instructionId, BeamFnApi.Elements value) { - if (erroredInstructionIds.containsKey(instructionId)) { - LOG.debug("Ignoring inbound data for failed instruction {}", instructionId); - return; - } - CompletableFuture> consumerFuture = - receiverFuture(instructionId); - if (!consumerFuture.isDone()) { - LOG.debug( - "Received data for instruction {} without consumer ready. " - + "Waiting for consumer to be registered.", - instructionId); - } CloseableFnDataReceiver consumer; try { - consumer = consumerFuture.get(); - + CompletableFuture> consumerFuture = + receivers.computeIfAbsent( + instructionId, + (unused) -> { + if (poisonedInstructionIds.getIfPresent(instructionId) != null) { + throw new PoisonedException(); + } + LOG.debug( + "Received data for instruction {} without consumer ready. " + + "Waiting for consumer to be registered.", + instructionId); + return new CompletableFuture<>(); + }); + // The consumer may not be registered until the bundle processor is fully constructed so we + // conservatively set + // a high timeout. 
Poisoning will prevent this from occurring for consumers that will not be + // registered. + consumer = consumerFuture.get(3, TimeUnit.HOURS); /* * TODO: On failure we should fail any bundles that were impacted eagerly * instead of relying on the Runner harness to do all the failure handling. */ - } catch (ExecutionException | InterruptedException e) { + } catch (TimeoutException e) { + LOG.error( + "Timed out waiting to observe consumer data stream for instruction {}", + instructionId, + e); + outboundObserver.onError(e); + return; + } catch (ExecutionException | InterruptedException | PoisonedException e) { + if (e instanceof PoisonedException || e.getCause() instanceof PoisonedException) { + LOG.debug("Received data for poisoned instruction {}. Dropping input.", instructionId); + return; + } LOG.error( "Client interrupted during handling of data for instruction {}", instructionId, e); outboundObserver.onError(e); @@ -240,10 +308,11 @@ private void forwardToConsumerForInstructionId(String instructionId, BeamFnApi.E outboundObserver.onError(e); return; } + try { consumer.accept(value); } catch (Exception e) { - erroredInstructionIds.put(instructionId, true); + poisonInstructionId(instructionId); } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexerTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexerTest.java index 3a7a0d5a8935..37580824b558 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexerTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexerTest.java @@ -280,6 +280,7 @@ public void testFailedProcessingCausesAdditionalInboundDataToBeIgnored() throws DESCRIPTOR, OutboundObserverFactory.clientDirect(), inboundObserver -> TestStreams.withOnNext(outboundValues::add).build()); + final AtomicBoolean closed = new AtomicBoolean(); multiplexer.registerConsumer( DATA_INSTRUCTION_ID, new 
CloseableFnDataReceiver() { @@ -290,7 +291,7 @@ public void flush() throws Exception { @Override public void close() throws Exception { - fail("Unexpected call"); + closed.set(true); } @Override @@ -320,6 +321,7 @@ public void accept(BeamFnApi.Elements input) throws Exception { dataInboundValues, Matchers.contains( BeamFnApi.Elements.newBuilder().addData(data.setTransformId("A").build()).build())); + assertTrue(closed.get()); } @Test diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java index c91d5ba71b89..0d517503b12d 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java @@ -64,7 +64,6 @@ import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse; import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest; import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateResponse; -import org.apache.beam.model.pipeline.v1.Endpoints; import org.apache.beam.model.pipeline.v1.Endpoints.ApiServiceDescriptor; import org.apache.beam.model.pipeline.v1.RunnerApi; import org.apache.beam.model.pipeline.v1.RunnerApi.Coder; @@ -93,6 +92,7 @@ import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.TextFormat; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheBuilder; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheLoader; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.LoadingCache; @@ -108,20 +108,19 @@ import 
org.slf4j.LoggerFactory; /** - * Processes {@link BeamFnApi.ProcessBundleRequest}s and {@link - * BeamFnApi.ProcessBundleSplitRequest}s. + * Processes {@link ProcessBundleRequest}s and {@link BeamFnApi.ProcessBundleSplitRequest}s. * *

{@link BeamFnApi.ProcessBundleSplitRequest}s use a {@link BundleProcessorCache cache} to * find/create a {@link BundleProcessor}. The creation of a {@link BundleProcessor} uses the - * associated {@link BeamFnApi.ProcessBundleDescriptor} definition; creating runners for each {@link + * associated {@link ProcessBundleDescriptor} definition; creating runners for each {@link * RunnerApi.FunctionSpec}; wiring them together based upon the {@code input} and {@code output} map * definitions. The {@link BundleProcessor} executes the DAG based graph by starting all runners in * reverse topological order, and finishing all runners in forward topological order. * *

{@link BeamFnApi.ProcessBundleSplitRequest}s finds an {@code active} {@link BundleProcessor} - * associated with a currently processing {@link BeamFnApi.ProcessBundleRequest} and uses it to - * perform a split request. See breaking the - * fusion barrier for further details. + * associated with a currently processing {@link ProcessBundleRequest} and uses it to perform a + * split request. See breaking the fusion + * barrier for further details. */ @SuppressWarnings({ "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) @@ -153,7 +152,7 @@ public class ProcessBundleHandler { } private final PipelineOptions options; - private final Function fnApiRegistry; + private final Function fnApiRegistry; private final BeamFnDataClient beamFnDataClient; private final BeamFnStateGrpcClientCache beamFnStateGrpcClientCache; private final FinalizeBundleHandler finalizeBundleHandler; @@ -170,7 +169,7 @@ public class ProcessBundleHandler { public ProcessBundleHandler( PipelineOptions options, Set runnerCapabilities, - Function fnApiRegistry, + Function fnApiRegistry, BeamFnDataClient beamFnDataClient, BeamFnStateGrpcClientCache beamFnStateGrpcClientCache, FinalizeBundleHandler finalizeBundleHandler, @@ -197,7 +196,7 @@ public ProcessBundleHandler( ProcessBundleHandler( PipelineOptions options, Set runnerCapabilities, - Function fnApiRegistry, + Function fnApiRegistry, BeamFnDataClient beamFnDataClient, BeamFnStateGrpcClientCache beamFnStateGrpcClientCache, FinalizeBundleHandler finalizeBundleHandler, @@ -216,7 +215,7 @@ public ProcessBundleHandler( this.runnerCapabilities = runnerCapabilities; this.runnerAcceptsShortIds = runnerCapabilities.contains( - BeamUrns.getUrn(RunnerApi.StandardRunnerProtocols.Enum.MONITORING_INFO_SHORT_IDS)); + BeamUrns.getUrn(StandardRunnerProtocols.Enum.MONITORING_INFO_SHORT_IDS)); this.executionStateSampler = executionStateSampler; this.urnToPTransformRunnerFactoryMap = urnToPTransformRunnerFactoryMap; this.defaultPTransformRunnerFactory = @@ 
-232,7 +231,7 @@ private void createRunnerAndConsumersForPTransformRecursively( String pTransformId, PTransform pTransform, Supplier processBundleInstructionId, - Supplier> cacheTokens, + Supplier> cacheTokens, Supplier> bundleCache, ProcessBundleDescriptor processBundleDescriptor, SetMultimap pCollectionIdsToConsumingPTransforms, @@ -242,7 +241,7 @@ private void createRunnerAndConsumersForPTransformRecursively( PTransformFunctionRegistry finishFunctionRegistry, Consumer addResetFunction, Consumer addTearDownFunction, - BiConsumer> addDataEndpoint, + BiConsumer> addDataEndpoint, Consumer> addTimerEndpoint, Consumer addBundleProgressReporter, BundleSplitListener splitListener, @@ -499,28 +498,29 @@ public BundleFinalizer getBundleFinalizer() { * Processes a bundle, running the start(), process(), and finish() functions. This function is * required to be reentrant. */ - public BeamFnApi.InstructionResponse.Builder processBundle(BeamFnApi.InstructionRequest request) + public BeamFnApi.InstructionResponse.Builder processBundle(InstructionRequest request) throws Exception { - BeamFnApi.ProcessBundleResponse.Builder response = BeamFnApi.ProcessBundleResponse.newBuilder(); - - BundleProcessor bundleProcessor = - bundleProcessorCache.get( - request, - () -> { - try { - return createBundleProcessor( - request.getProcessBundle().getProcessBundleDescriptorId(), - request.getProcessBundle()); - } catch (IOException e) { - throw new RuntimeException(e); - } - }); + @Nullable BundleProcessor bundleProcessor = null; try { + bundleProcessor = + Preconditions.checkNotNull( + bundleProcessorCache.get( + request, + () -> { + try { + return createBundleProcessor( + request.getProcessBundle().getProcessBundleDescriptorId(), + request.getProcessBundle()); + } catch (IOException e) { + throw new RuntimeException(e); + } + })); + PTransformFunctionRegistry startFunctionRegistry = bundleProcessor.getStartFunctionRegistry(); PTransformFunctionRegistry finishFunctionRegistry = 
bundleProcessor.getFinishFunctionRegistry(); ExecutionStateTracker stateTracker = bundleProcessor.getStateTracker(); - + ProcessBundleResponse.Builder response = ProcessBundleResponse.newBuilder(); try (HandleStateCallsForBundle beamFnStateClient = bundleProcessor.getBeamFnStateClient()) { stateTracker.start(request.getInstructionId()); try { @@ -596,12 +596,17 @@ public BeamFnApi.InstructionResponse.Builder processBundle(BeamFnApi.Instruction request.getProcessBundle().getProcessBundleDescriptorId(), bundleProcessor); return BeamFnApi.InstructionResponse.newBuilder().setProcessBundle(response); } catch (Exception e) { - // Make sure we clean up from the active set of bundle processors. LOG.debug( - "Discard bundleProcessor for {} after exception: {}", + "Error processing bundle {} with bundleProcessor for {} after exception: {}", + request.getInstructionId(), request.getProcessBundle().getProcessBundleDescriptorId(), e.getMessage()); - bundleProcessorCache.discard(bundleProcessor); + if (bundleProcessor != null) { + // Make sure we clean up from the active set of bundle processors. + bundleProcessorCache.discard(bundleProcessor); + } + // Ensure that if more data arrives for the instruction it is discarded. + beamFnDataClient.poisonInstructionId(request.getInstructionId()); throw e; } } @@ -643,7 +648,7 @@ private void embedOutboundElementsIfApplicable( } } - public BeamFnApi.InstructionResponse.Builder progress(BeamFnApi.InstructionRequest request) + public BeamFnApi.InstructionResponse.Builder progress(InstructionRequest request) throws Exception { BundleProcessor bundleProcessor = bundleProcessorCache.find(request.getProcessBundleProgress().getInstructionId()); @@ -727,7 +732,7 @@ private Map finalMonitoringData(BundleProcessor bundleProces } /** Splits an active bundle. 
*/ - public BeamFnApi.InstructionResponse.Builder trySplit(BeamFnApi.InstructionRequest request) { + public BeamFnApi.InstructionResponse.Builder trySplit(InstructionRequest request) { BundleProcessor bundleProcessor = bundleProcessorCache.find(request.getProcessBundleSplit().getInstructionId()); BeamFnApi.ProcessBundleSplitResponse.Builder response = @@ -772,8 +777,8 @@ public void discard() { } private BundleProcessor createBundleProcessor( - String bundleId, BeamFnApi.ProcessBundleRequest processBundleRequest) throws IOException { - BeamFnApi.ProcessBundleDescriptor bundleDescriptor = fnApiRegistry.apply(bundleId); + String bundleId, ProcessBundleRequest processBundleRequest) throws IOException { + ProcessBundleDescriptor bundleDescriptor = fnApiRegistry.apply(bundleId); SetMultimap pCollectionIdsToConsumingPTransforms = HashMultimap.create(); BundleProgressReporter.InMemory bundleProgressReporterAndRegistrar = @@ -799,8 +804,7 @@ private BundleProcessor createBundleProcessor( List tearDownFunctions = new ArrayList<>(); // Build a multimap of PCollection ids to PTransform ids which consume said PCollections - for (Map.Entry entry : - bundleDescriptor.getTransformsMap().entrySet()) { + for (Map.Entry entry : bundleDescriptor.getTransformsMap().entrySet()) { for (String pCollectionId : entry.getValue().getInputsMap().values()) { pCollectionIdsToConsumingPTransforms.put(pCollectionId, entry.getKey()); } @@ -848,8 +852,7 @@ public void afterBundleCommit(Instant callbackExpiry, Callback callback) { runnerCapabilities); // Create a BeamFnStateClient - for (Map.Entry entry : - bundleDescriptor.getTransformsMap().entrySet()) { + for (Map.Entry entry : bundleDescriptor.getTransformsMap().entrySet()) { // Skip anything which isn't a root. // Also force data output transforms to be unconditionally instantiated (see BEAM-10450). 
@@ -1090,7 +1093,7 @@ public static BundleProcessor create( abstract HandleStateCallsForBundle getBeamFnStateClient(); - abstract List getInboundEndpointApiServiceDescriptors(); + abstract List getInboundEndpointApiServiceDescriptors(); abstract List> getInboundDataEndpoints(); @@ -1117,7 +1120,7 @@ synchronized List getCacheTokens() { synchronized Cache getBundleCache() { if (this.bundleCache == null) { this.bundleCache = - new Caches.ClearableCache<>( + new ClearableCache<>( Caches.subCache(getProcessWideCache(), "Bundle", this.instructionId)); } return this.bundleCache; @@ -1264,7 +1267,7 @@ public void close() throws Exception { } @Override - public CompletableFuture handle(BeamFnApi.StateRequest.Builder requestBuilder) { + public CompletableFuture handle(StateRequest.Builder requestBuilder) { throw new IllegalStateException( String.format( "State API calls are unsupported because the " diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataClient.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataClient.java index 75f3a24301c9..94d59d0fcb62 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataClient.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataClient.java @@ -55,10 +55,19 @@ void registerReceiver( * successfully. * *

It is expected that if a bundle fails during processing then the failure will become visible - * to the {@link BeamFnDataClient} during a future {@link FnDataReceiver#accept} invocation. + * to the {@link BeamFnDataClient} during a future {@link FnDataReceiver#accept} invocation or via + * a call to {@link #poisonInstructionId}. */ void unregisterReceiver(String instructionId, List apiServiceDescriptors); + /** + * Poisons the instruction id, indicating that future data arriving for it should be discarded. + * Unregisters the receiver if it was registered. + * + * @param instructionId the id of the instruction whose future data should be discarded + */ + void poisonInstructionId(String instructionId); + /** * Creates a {@link BeamFnDataOutboundAggregator} for buffering and sending outbound data and * timers over the data plane. It is important that {@link diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClient.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClient.java index 981b115c58e7..cd1ac26e364d 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClient.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClient.java @@ -82,6 +82,14 @@ public void unregisterReceiver( } } + @Override + public void poisonInstructionId(String instructionId) { + LOG.debug("Poisoning instruction {}", instructionId); + for (BeamFnDataGrpcMultiplexer client : multiplexerCache.values()) { + client.poisonInstructionId(instructionId); + } + } + @Override public BeamFnDataOutboundAggregator createOutboundAggregator( ApiServiceDescriptor apiServiceDescriptor, diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/PTransformRunnerFactoryTestContext.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/PTransformRunnerFactoryTestContext.java index 9328dc86c009..acfd3bb70202 100644 --- 
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/PTransformRunnerFactoryTestContext.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/PTransformRunnerFactoryTestContext.java @@ -92,6 +92,11 @@ public BeamFnDataOutboundAggregator createOutboundAggregator( boolean collectElementsIfNoFlushes) { throw new UnsupportedOperationException("Unexpected call during test."); } + + @Override + public void poisonInstructionId(String instructionId) { + throw new UnsupportedOperationException("Unexpected call during test."); + } }) .beamFnStateClient( new BeamFnStateClient() { diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java index 2d1e323707f7..95b404aa6203 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java @@ -1516,6 +1516,7 @@ public void testDataProcessingExceptionsArePropagated() throws Exception { // Ensure that we unregister during successful processing verify(beamFnDataClient).registerReceiver(eq("instructionId"), any(), any()); + verify(beamFnDataClient).poisonInstructionId(eq("instructionId")); verifyNoMoreInteractions(beamFnDataClient); } diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClientTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClientTest.java index 3489fe766891..514cf61ded40 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClientTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClientTest.java @@ -23,14 +23,17 @@ import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.empty; import static org.junit.Assert.assertEquals; 
+import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.UUID; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import org.apache.beam.model.fnexecution.v1.BeamFnApi; @@ -281,6 +284,93 @@ public StreamObserver data( } } + @Test + public void testForInboundConsumerThatIsPoisoned() throws Exception { + CountDownLatch waitForClientToConnect = new CountDownLatch(1); + CountDownLatch receivedAElement = new CountDownLatch(1); + Collection> inboundValuesA = new ConcurrentLinkedQueue<>(); + Collection inboundServerValues = new ConcurrentLinkedQueue<>(); + AtomicReference> outboundServerObserver = + new AtomicReference<>(); + CallStreamObserver inboundServerObserver = + TestStreams.withOnNext(inboundServerValues::add).build(); + + Endpoints.ApiServiceDescriptor apiServiceDescriptor = + Endpoints.ApiServiceDescriptor.newBuilder() + .setUrl(this.getClass().getName() + "-" + UUID.randomUUID()) + .build(); + Server server = + InProcessServerBuilder.forName(apiServiceDescriptor.getUrl()) + .addService( + new BeamFnDataGrpc.BeamFnDataImplBase() { + @Override + public StreamObserver data( + StreamObserver outboundObserver) { + outboundServerObserver.set(outboundObserver); + waitForClientToConnect.countDown(); + return inboundServerObserver; + } + }) + .build(); + server.start(); + + try { + ManagedChannel channel = + InProcessChannelBuilder.forName(apiServiceDescriptor.getUrl()).build(); + + BeamFnDataGrpcClient clientFactory = + new BeamFnDataGrpcClient( + PipelineOptionsFactory.create(), + (Endpoints.ApiServiceDescriptor descriptor) -> channel, + OutboundObserverFactory.trivial()); + + BeamFnDataInboundObserver observerA = + 
BeamFnDataInboundObserver.forConsumers( + Arrays.asList( + DataEndpoint.create( + TRANSFORM_ID_A, + CODER, + (WindowedValue elem) -> { + receivedAElement.countDown(); + inboundValuesA.add(elem); + })), + Collections.emptyList()); + CompletableFuture future = + CompletableFuture.runAsync( + () -> { + try { + observerA.awaitCompletion(); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + + clientFactory.registerReceiver( + INSTRUCTION_ID_A, Arrays.asList(apiServiceDescriptor), observerA); + + waitForClientToConnect.await(); + outboundServerObserver.get().onNext(ELEMENTS_B_1); + clientFactory.poisonInstructionId(INSTRUCTION_ID_B); + + outboundServerObserver.get().onNext(ELEMENTS_B_1); + outboundServerObserver.get().onNext(ELEMENTS_A_1); + assertTrue(receivedAElement.await(5, TimeUnit.SECONDS)); + + clientFactory.poisonInstructionId(INSTRUCTION_ID_A); + try { + future.get(); + fail(); // We expect the awaitCompletion to fail due to closing. + } catch (Exception ignored) { + } + + outboundServerObserver.get().onNext(ELEMENTS_A_2); + + assertThat(inboundValuesA, contains(valueInGlobalWindow("ABC"), valueInGlobalWindow("DEF"))); + } finally { + server.shutdownNow(); + } + } + @Test public void testForOutboundConsumer() throws Exception { CountDownLatch waitForInboundServerValuesCompletion = new CountDownLatch(2);