diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillConnection.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillConnection.java new file mode 100644 index 000000000000..e49a04a7a543 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillConnection.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill; + +import com.google.auto.value.AutoValue; +import java.util.Optional; +import java.util.function.Function; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints.Endpoint; +import org.apache.beam.sdk.annotations.Internal; + +@AutoValue +@Internal +public abstract class WindmillConnection { + public static WindmillConnection from( + Endpoint windmillEndpoint, + Function endpointToStubFn) { + WindmillConnection.Builder windmillWorkerConnection = WindmillConnection.builder(); + + windmillEndpoint.workerToken().ifPresent(windmillWorkerConnection::setBackendWorkerToken); + windmillWorkerConnection.setStub(endpointToStubFn.apply(windmillEndpoint)); + + return windmillWorkerConnection.build(); + } + + public static Builder builder() { + return new AutoValue_WindmillConnection.Builder(); + } + + public abstract Optional backendWorkerToken(); + + public abstract CloudWindmillServiceV1Alpha1Stub stub(); + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setBackendWorkerToken(String backendWorkerToken); + + abstract Builder setStub(CloudWindmillServiceV1Alpha1Stub stub); + + abstract WindmillConnection build(); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GetWorkTimingInfosTracker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GetWorkTimingInfosTracker.java index 221b18be164c..dc3486d743ad 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GetWorkTimingInfosTracker.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GetWorkTimingInfosTracker.java @@ -35,8 +35,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -class GetWorkTimingInfosTracker { - +final class GetWorkTimingInfosTracker { private static final Logger LOG = 
LoggerFactory.getLogger(GetWorkTimingInfosTracker.class); private final Map aggregatedGetWorkStreamLatencies; @@ -53,7 +52,7 @@ class GetWorkTimingInfosTracker { workItemCreationLatency = null; } - public void addTimingInfo(Collection infos) { + void addTimingInfo(Collection infos) { // We want to record duration for each stage and also be reflective on total work item // processing time. It can be tricky because timings of different // StreamingGetWorkResponseChunks can be interleaved. Current strategy is to record the @@ -170,7 +169,7 @@ List getLatencyAttributions() { return latencyAttributions; } - public void reset() { + void reset() { this.aggregatedGetWorkStreamLatencies.clear(); this.workItemCreationEndTime = Instant.EPOCH; this.workItemLastChunkReceivedByWorkerTime = Instant.EPOCH; @@ -178,11 +177,10 @@ public void reset() { } private static class SumAndMaxDurations { - private Duration sum; private Duration max; - public SumAndMaxDurations(Duration sum, Duration max) { + private SumAndMaxDurations(Duration sum, Duration max) { this.sum = sum; this.max = max; } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java index 5d0a5085fe1b..9350b89f182a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java @@ -43,7 +43,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -final class GrpcCommitWorkStream +public final class GrpcCommitWorkStream extends AbstractWindmillStream implements CommitWorkStream { private static final Logger LOG = LoggerFactory.getLogger(GrpcCommitWorkStream.class); @@ -82,7 +82,7 @@ private GrpcCommitWorkStream( this.streamingRpcBatchLimit = streamingRpcBatchLimit; } - static GrpcCommitWorkStream create( + public static GrpcCommitWorkStream create( Function, StreamObserver> startCommitWorkRpcFn, BackOff backoff, diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java new file mode 100644 index 000000000000..683f94eb71ee --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java @@ -0,0 +1,320 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; + +import com.google.auto.value.AutoValue; +import java.io.IOException; +import java.io.PrintWriter; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; +import java.util.function.Supplier; +import javax.annotation.Nullable; +import org.apache.beam.runners.dataflow.worker.WindmillTimeUtils; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationWorkItemMetadata; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetWorkRequest; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetWorkResponseChunk; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem; +import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; +import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; +import org.apache.beam.runners.dataflow.worker.windmill.work.ProcessWorkItemClient; +import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemProcessor; +import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.util.BackOff; +import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString; +import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Implementation of {@link GetWorkStream} that passes along a specific {@link + * org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream} and {@link + * org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream} to the + * processing context {@link ProcessWorkItemClient}. During the work item processing lifecycle, + * these direct streams are used to facilitate these RPC calls to specific backend workers. 
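+ * + * <p>Budget for this stream is tracked locally: the in-flight budget is reduced as complete work + * items are received, and budget adjustments are forwarded to the backend as request extensions.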
+ */ +@Internal +public final class GrpcDirectGetWorkStream + extends AbstractWindmillStream + implements GetWorkStream { + private static final Logger LOG = LoggerFactory.getLogger(GrpcDirectGetWorkStream.class); + private static final StreamingGetWorkRequest HEALTH_CHECK_REQUEST = + StreamingGetWorkRequest.newBuilder() + .setRequestExtension( + Windmill.StreamingGetWorkRequestExtension.newBuilder() + .setMaxItems(0) + .setMaxBytes(0) + .build()) + .build(); + + private final AtomicReference inFlightBudget; + private final AtomicReference nextBudgetAdjustment; + private final AtomicReference pendingResponseBudget; + private final GetWorkRequest request; + private final WorkItemProcessor workItemProcessorFn; + private final ThrottleTimer getWorkThrottleTimer; + private final Supplier getDataStream; + private final Supplier commitWorkStream; + /** + * Map of stream IDs to their buffers. Used to aggregate streaming gRPC response chunks as they + * come in. Once all chunks for a work item have been received, the work item is processed and + * the buffer is cleared. + */ + private final ConcurrentMap workItemBuffers; + + private GrpcDirectGetWorkStream( + Function< + StreamObserver, + StreamObserver> + startGetWorkRpcFn, + GetWorkRequest request, + BackOff backoff, + StreamObserverFactory streamObserverFactory, + Set> streamRegistry, + int logEveryNStreamFailures, + ThrottleTimer getWorkThrottleTimer, + Supplier getDataStream, + Supplier commitWorkStream, + WorkItemProcessor workItemProcessorFn) { + super( + startGetWorkRpcFn, backoff, streamObserverFactory, streamRegistry, logEveryNStreamFailures); + this.request = request; + this.getWorkThrottleTimer = getWorkThrottleTimer; + this.workItemProcessorFn = workItemProcessorFn; + this.workItemBuffers = new ConcurrentHashMap<>(); + // Use the same GetDataStream and CommitWorkStream instances to process all the work in this + // stream. + this.getDataStream = Suppliers.memoize(getDataStream::get); + this.commitWorkStream = Suppliers.memoize(commitWorkStream::get); + this.inFlightBudget = new AtomicReference<>(GetWorkBudget.noBudget()); + this.nextBudgetAdjustment = new AtomicReference<>(GetWorkBudget.noBudget()); + this.pendingResponseBudget = new AtomicReference<>(GetWorkBudget.noBudget()); + } + + public static GrpcDirectGetWorkStream create( + Function< + StreamObserver, + StreamObserver> + startGetWorkRpcFn, + GetWorkRequest request, + BackOff backoff, + StreamObserverFactory streamObserverFactory, + Set> streamRegistry, + int logEveryNStreamFailures, + ThrottleTimer getWorkThrottleTimer, + Supplier getDataStream, + Supplier commitWorkStream, + WorkItemProcessor workItemProcessorFn) { + GrpcDirectGetWorkStream getWorkStream = + new GrpcDirectGetWorkStream( + startGetWorkRpcFn, + request, + backoff, + streamObserverFactory, + streamRegistry, + logEveryNStreamFailures, + getWorkThrottleTimer, + getDataStream, + commitWorkStream, + workItemProcessorFn); + getWorkStream.startStream(); + return getWorkStream; + } + + private synchronized GetWorkBudget getThenResetBudgetAdjustment() { + return nextBudgetAdjustment.getAndUpdate(unused -> GetWorkBudget.noBudget()); + } + + private void sendRequestExtension() { + // Consume the accumulated budget adjustment and send it to the backend as a request extension. + // nextBudgetAdjustment is reset here and will be set again the next time adjustBudget is called. 
+ GetWorkBudget adjustment = getThenResetBudgetAdjustment(); + StreamingGetWorkRequest extension = + StreamingGetWorkRequest.newBuilder() + .setRequestExtension( + Windmill.StreamingGetWorkRequestExtension.newBuilder() + .setMaxItems(adjustment.items()) + .setMaxBytes(adjustment.bytes())) + .build(); + + executor() + .execute( + () -> { + try { + send(extension); + } catch (IllegalStateException e) { + // Stream was closed. + } + }); + } + + @Override + protected synchronized void onNewStream() { + workItemBuffers.clear(); + // Add the current in-flight budget to the next adjustment. Only positive values are allowed + // here + // with negatives defaulting to 0, since GetWorkBudgets cannot be created with negative values. + GetWorkBudget budgetAdjustment = nextBudgetAdjustment.get().apply(inFlightBudget.get()); + inFlightBudget.set(budgetAdjustment); + send( + StreamingGetWorkRequest.newBuilder() + .setRequest( + request + .toBuilder() + .setMaxBytes(budgetAdjustment.bytes()) + .setMaxItems(budgetAdjustment.items())) + .build()); + + // We just sent the budget, reset it. + nextBudgetAdjustment.set(GetWorkBudget.noBudget()); + } + + @Override + protected boolean hasPendingRequests() { + return false; + } + + @Override + public void appendSpecificHtml(PrintWriter writer) { + // Number of buffers is same as distinct workers that sent work on this stream. + writer.format( + "GetWorkStream: %d buffers, %s inflight budget allowed.", + workItemBuffers.size(), inFlightBudget.get()); + } + + @Override + public void sendHealthCheck() { + send(HEALTH_CHECK_REQUEST); + } + + @Override + protected void onResponse(StreamingGetWorkResponseChunk chunk) { + getWorkThrottleTimer.stop(); + WorkItemBuffer workItemBuffer = + workItemBuffers.computeIfAbsent(chunk.getStreamId(), unused -> new WorkItemBuffer()); + workItemBuffer.append(chunk); + + // The entire WorkItem has been received, it is ready to be processed. + if (chunk.getRemainingBytesForWorkItem() == 0) { + workItemBuffer.runAndReset(); + // Record the fact that there are now fewer outstanding messages and bytes on the stream. + inFlightBudget.updateAndGet(budget -> budget.subtract(1, workItemBuffer.bufferedSize())); + } + } + + @Override + protected void startThrottleTimer() { + getWorkThrottleTimer.start(); + } + + @Override + public synchronized void adjustBudget(long itemsDelta, long bytesDelta) { + nextBudgetAdjustment.set(nextBudgetAdjustment.get().apply(itemsDelta, bytesDelta)); + sendRequestExtension(); + } + + @Override + public GetWorkBudget remainingBudget() { + // Snapshot the current budgets. 
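+ // The remaining budget is the sum of the budget for work the backend may still send on the + // stream (in-flight), any budget adjustment that has not yet been sent as a request extension, + // and work items that have been received but not yet queued for processing.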
+ GetWorkBudget currentPendingResponseBudget = pendingResponseBudget.get(); + GetWorkBudget currentNextBudgetAdjustment = nextBudgetAdjustment.get(); + GetWorkBudget currentInflightBudget = inFlightBudget.get(); + + return currentPendingResponseBudget + .apply(currentNextBudgetAdjustment) + .apply(currentInflightBudget); + } + + private synchronized void updatePendingResponseBudget(long itemsDelta, long bytesDelta) { + pendingResponseBudget.set(pendingResponseBudget.get().apply(itemsDelta, bytesDelta)); + } + + @AutoValue + abstract static class ComputationMetadata { + private static ComputationMetadata fromProto(ComputationWorkItemMetadata metadataProto) { + return new AutoValue_GrpcDirectGetWorkStream_ComputationMetadata( + metadataProto.getComputationId(), + WindmillTimeUtils.windmillToHarnessWatermark(metadataProto.getInputDataWatermark()), + WindmillTimeUtils.windmillToHarnessWatermark( + metadataProto.getDependentRealtimeInputWatermark())); + } + + abstract String computationId(); + + abstract Instant inputDataWatermark(); + + abstract Instant synchronizedProcessingTime(); + } + + private class WorkItemBuffer { + private final GetWorkTimingInfosTracker workTimingInfosTracker; + private ByteString data; + private @Nullable ComputationMetadata metadata; + + private WorkItemBuffer() { + workTimingInfosTracker = new GetWorkTimingInfosTracker(System::currentTimeMillis); + data = ByteString.EMPTY; + this.metadata = null; + } + + private void append(StreamingGetWorkResponseChunk chunk) { + if (chunk.hasComputationMetadata()) { + this.metadata = ComputationMetadata.fromProto(chunk.getComputationMetadata()); + } + + this.data = data.concat(chunk.getSerializedWorkItem()); + workTimingInfosTracker.addTimingInfo(chunk.getPerWorkItemTimingInfosList()); + } + + private long bufferedSize() { + return data.size(); + } + + private void runAndReset() { + try { + WorkItem workItem = WorkItem.parseFrom(data.newInput()); + updatePendingResponseBudget(1, workItem.getSerializedSize()); + Preconditions.checkNotNull(metadata); + workItemProcessorFn.processWork( + metadata.computationId(), + metadata.inputDataWatermark(), + metadata.synchronizedProcessingTime(), + ProcessWorkItemClient.create( + WorkItem.parseFrom(data.newInput()), getDataStream.get(), commitWorkStream.get()), + // After the work item is successfully queued or dropped by ActiveWorkState, remove it + // from the pendingResponseBudget. 
+ queuedWorkItem -> updatePendingResponseBudget(-1, -workItem.getSerializedSize()), + workTimingInfosTracker.getLatencyAttributions()); + } catch (IOException e) { + LOG.error("Failed to parse work item from stream: ", e); + } + workTimingInfosTracker.reset(); + data = ByteString.EMPTY; + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java index ea9cd7f0fa32..a04a961ca9c2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java @@ -53,7 +53,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -final class GrpcGetDataStream +public final class GrpcGetDataStream extends AbstractWindmillStream implements GetDataStream { private static final Logger LOG = LoggerFactory.getLogger(GrpcGetDataStream.class); @@ -86,7 +86,7 @@ private GrpcGetDataStream( this.pending = new ConcurrentHashMap<>(); } - static GrpcGetDataStream create( + public static GrpcGetDataStream create( Function, StreamObserver> startGetDataRpcFn, BackOff backoff, diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java index a403feddb450..35524dbd2eeb 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java @@ -36,7 +36,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -final class GrpcGetWorkerMetadataStream +public final class GrpcGetWorkerMetadataStream extends AbstractWindmillStream implements GetWorkerMetadataStream { private static final Logger LOG = LoggerFactory.getLogger(GrpcGetWorkerMetadataStream.class); @@ -100,6 +100,7 @@ public static GrpcGetWorkerMetadataStream create( metadataVersion, getWorkerMetadataThrottleTimer, serverMappingUpdater); + LOG.info("Started GetWorkerMetadataStream. 
{}", getWorkerMetadataStream); getWorkerMetadataStream.startStream(); return getWorkerMetadataStream; } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java index e474ebf18b29..099be8db0fda 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java @@ -42,7 +42,9 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkerMetadataStream; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; +import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemProcessor; import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemReceiver; +import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.util.BackOff; import org.apache.beam.sdk.util.FluentBackoff; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers; @@ -54,7 +56,8 @@ * RPC streams for health check/heartbeat requests to keep the streams alive. */ @ThreadSafe -public final class GrpcWindmillStreamFactory implements StatusDataProvider { +@Internal +public class GrpcWindmillStreamFactory implements StatusDataProvider { private static final Duration MIN_BACKOFF = Duration.millis(1); private static final Duration DEFAULT_MAX_BACKOFF = Duration.standardSeconds(30); private static final int DEFAULT_LOG_EVERY_N_STREAM_FAILURES = 1; @@ -128,6 +131,26 @@ public GetWorkStream createGetWorkStream( processWorkItem); } + public GetWorkStream createDirectGetWorkStream( + CloudWindmillServiceV1Alpha1Stub stub, + GetWorkRequest request, + ThrottleTimer getWorkThrottleTimer, + Supplier getDataStream, + Supplier commitWorkStream, + WorkItemProcessor workItemProcessor) { + return GrpcDirectGetWorkStream.create( + responseObserver -> withDeadline(stub).getWorkStream(responseObserver), + request, + grpcBackOff.get(), + newStreamObserverFactory(), + streamRegistry, + logEveryNStreamFailures, + getWorkThrottleTimer, + getDataStream, + commitWorkStream, + workItemProcessor); + } + public GetDataStream createGetDataStream( CloudWindmillServiceV1Alpha1Stub stub, ThrottleTimer getDataThrottleTimer) { return GrpcGetDataStream.create( @@ -210,8 +233,9 @@ public void appendSummaryHtml(PrintWriter writer) { } } + @Internal @AutoBuilder(ofClass = GrpcWindmillStreamFactory.class) - interface Builder { + public interface Builder { Builder setJobHeader(JobHeader jobHeader); Builder setLogEveryNStreamFailures(int logEveryNStreamFailures); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java new file mode 100644 index 000000000000..01783f6aa4d3 --- /dev/null +++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java @@ -0,0 +1,401 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; + +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap; + +import java.util.Collection; +import java.util.List; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.Queue; +import java.util.Random; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; +import java.util.function.Supplier; +import javax.annotation.CheckReturnValue; +import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints.Endpoint; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkerMetadataStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillStubFactory; +import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; +import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemProcessor; +import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; +import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetDistributor; +import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetRefresher; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.EvictingQueue; +import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Queues; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Client for StreamingEngine. Given a {@link GetWorkBudget}, divides the budget and starts the + * {@link WindmillStream.GetWorkStream}(s). + */ +@Internal +@CheckReturnValue +@ThreadSafe +public final class StreamingEngineClient { + private static final Logger LOG = LoggerFactory.getLogger(StreamingEngineClient.class); + private static final String PUBLISH_NEW_WORKER_METADATA_THREAD = "PublishNewWorkerMetadataThread"; + private static final String CONSUME_NEW_WORKER_METADATA_THREAD = "ConsumeNewWorkerMetadataThread"; + + private final AtomicBoolean started; + private final JobHeader jobHeader; + private final GrpcWindmillStreamFactory streamFactory; + private final WorkItemProcessor workItemProcessor; + private final WindmillStubFactory stubFactory; + private final GrpcDispatcherClient dispatcherClient; + private final AtomicBoolean isBudgetRefreshPaused; + private final GetWorkBudgetRefresher getWorkBudgetRefresher; + private final AtomicReference lastBudgetRefresh; + private final ThrottleTimer getWorkerMetadataThrottleTimer; + private final ExecutorService newWorkerMetadataPublisher; + private final ExecutorService newWorkerMetadataConsumer; + private final long clientId; + private final Supplier getWorkerMetadataStream; + private final Queue newWindmillEndpoints; + /** Writes are guarded by synchronization, reads are lock free. 
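Writes happen in methods synchronized on this StreamingEngineClient instance.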
*/ + private final AtomicReference connections; + + @SuppressWarnings("FutureReturnValueIgnored") + private StreamingEngineClient( + JobHeader jobHeader, + GetWorkBudget totalGetWorkBudget, + AtomicReference connections, + GrpcWindmillStreamFactory streamFactory, + WorkItemProcessor workItemProcessor, + WindmillStubFactory stubFactory, + GetWorkBudgetDistributor getWorkBudgetDistributor, + GrpcDispatcherClient dispatcherClient, + long clientId) { + this.jobHeader = jobHeader; + this.started = new AtomicBoolean(); + this.streamFactory = streamFactory; + this.workItemProcessor = workItemProcessor; + this.connections = connections; + this.stubFactory = stubFactory; + this.dispatcherClient = dispatcherClient; + this.isBudgetRefreshPaused = new AtomicBoolean(false); + this.getWorkerMetadataThrottleTimer = new ThrottleTimer(); + this.newWorkerMetadataPublisher = + singleThreadedExecutorServiceOf(PUBLISH_NEW_WORKER_METADATA_THREAD); + this.newWorkerMetadataConsumer = + singleThreadedExecutorServiceOf(CONSUME_NEW_WORKER_METADATA_THREAD); + this.clientId = clientId; + this.lastBudgetRefresh = new AtomicReference<>(Instant.EPOCH); + this.newWindmillEndpoints = Queues.synchronizedQueue(EvictingQueue.create(1)); + this.getWorkBudgetRefresher = + new GetWorkBudgetRefresher( + isBudgetRefreshPaused::get, + () -> { + getWorkBudgetDistributor.distributeBudget( + connections.get().windmillStreams().values(), totalGetWorkBudget); + lastBudgetRefresh.set(Instant.now()); + }); + this.getWorkerMetadataStream = + Suppliers.memoize( + () -> + streamFactory.createGetWorkerMetadataStream( + dispatcherClient.getDispatcherStub(), + getWorkerMetadataThrottleTimer, + endpoints -> + // Run this on a separate thread than the grpc stream thread. + newWorkerMetadataPublisher.submit( + () -> newWindmillEndpoints.add(endpoints)))); + } + + private static ExecutorService singleThreadedExecutorServiceOf(String threadName) { + return Executors.newSingleThreadScheduledExecutor( + new ThreadFactoryBuilder() + .setNameFormat(threadName) + .setUncaughtExceptionHandler( + (t, e) -> { + LOG.error( + "{} failed due to uncaught exception during execution. ", t.getName(), e); + throw new StreamingEngineClientException(e); + }) + .build()); + } + + /** + * Creates an instance of {@link StreamingEngineClient} and starts the {@link + * GetWorkerMetadataStream} with an RPC to the StreamingEngine backend. {@link + * GetWorkerMetadataStream} will populate {@link #connections} when a response is received. + * + * @implNote Does not block the calling thread. 
+ */ + public static StreamingEngineClient create( + JobHeader jobHeader, + GetWorkBudget totalGetWorkBudget, + GrpcWindmillStreamFactory streamingEngineStreamFactory, + WorkItemProcessor processWorkItem, + WindmillStubFactory windmillGrpcStubFactory, + GetWorkBudgetDistributor getWorkBudgetDistributor, + GrpcDispatcherClient dispatcherClient) { + StreamingEngineClient streamingEngineClient = + new StreamingEngineClient( + jobHeader, + totalGetWorkBudget, + new AtomicReference<>(StreamingEngineConnectionState.EMPTY), + streamingEngineStreamFactory, + processWorkItem, + windmillGrpcStubFactory, + getWorkBudgetDistributor, + dispatcherClient, + new Random().nextLong()); + streamingEngineClient.startGetWorkerMetadataStream(); + streamingEngineClient.startWorkerMetadataConsumer(); + streamingEngineClient.getWorkBudgetRefresher.start(); + return streamingEngineClient; + } + + @VisibleForTesting + static StreamingEngineClient forTesting( + JobHeader jobHeader, + GetWorkBudget totalGetWorkBudget, + AtomicReference connections, + GrpcWindmillStreamFactory streamFactory, + WorkItemProcessor processWorkItem, + WindmillStubFactory stubFactory, + GetWorkBudgetDistributor getWorkBudgetDistributor, + GrpcDispatcherClient dispatcherClient, + long clientId) { + StreamingEngineClient streamingEngineClient = + new StreamingEngineClient( + jobHeader, + totalGetWorkBudget, + connections, + streamFactory, + processWorkItem, + stubFactory, + getWorkBudgetDistributor, + dispatcherClient, + clientId); + streamingEngineClient.startGetWorkerMetadataStream(); + streamingEngineClient.startWorkerMetadataConsumer(); + streamingEngineClient.getWorkBudgetRefresher.start(); + return streamingEngineClient; + } + + @SuppressWarnings("FutureReturnValueIgnored") + private void startWorkerMetadataConsumer() { + newWorkerMetadataConsumer.submit( + () -> { + while (true) { + Optional.ofNullable(newWindmillEndpoints.poll()) + .ifPresent(this::consumeWindmillWorkerEndpoints); + } + }); + } + + @VisibleForTesting + boolean isWorkerMetadataReady() { + return !connections.get().equals(StreamingEngineConnectionState.EMPTY); + } + + @VisibleForTesting + void finish() { + if (!started.compareAndSet(true, false)) { + return; + } + getWorkerMetadataStream.get().close(); + getWorkBudgetRefresher.stop(); + newWorkerMetadataPublisher.shutdownNow(); + newWorkerMetadataConsumer.shutdownNow(); + } + + /** + * {@link java.util.function.Consumer} used to update {@link #connections} on + * new backend worker metadata. + */ + private synchronized void consumeWindmillWorkerEndpoints(WindmillEndpoints newWindmillEndpoints) { + isBudgetRefreshPaused.set(true); + LOG.info("Consuming new windmill endpoints: {}", newWindmillEndpoints); + ImmutableMap newWindmillConnections = + createNewWindmillConnections(newWindmillEndpoints.windmillEndpoints()); + + StreamingEngineConnectionState newConnectionsState = + StreamingEngineConnectionState.builder() + .setWindmillConnections(newWindmillConnections) + .setWindmillStreams( + closeStaleStreamsAndCreateNewStreams(newWindmillConnections.values())) + .setGlobalDataStreams( + createNewGlobalDataStreams(newWindmillEndpoints.globalDataEndpoints())) + .build(); + + LOG.info( + "Setting new connections: {}. 
Previous connections: {}.", + newConnectionsState, + connections.get()); + connections.set(newConnectionsState); + isBudgetRefreshPaused.set(false); + getWorkBudgetRefresher.requestBudgetRefresh(); + } + + public final ImmutableList getAndResetThrottleTimes() { + StreamingEngineConnectionState currentConnections = connections.get(); + + ImmutableList keyedWorkStreamThrottleTimes = + currentConnections.windmillStreams().values().stream() + .map(WindmillStreamSender::getAndResetThrottleTime) + .collect(toImmutableList()); + + return ImmutableList.builder() + .add(getWorkerMetadataThrottleTimer.getAndResetThrottleTime()) + .addAll(keyedWorkStreamThrottleTimes) + .build(); + } + + /** Starts {@link GetWorkerMetadataStream}. */ + @SuppressWarnings({ + "ReturnValueIgnored", // starts the stream, this value is memoized. + }) + private void startGetWorkerMetadataStream() { + started.set(true); + getWorkerMetadataStream.get(); + } + + private synchronized ImmutableMap createNewWindmillConnections( + List newWindmillEndpoints) { + ImmutableMap currentConnections = + connections.get().windmillConnections(); + return newWindmillEndpoints.stream() + .collect( + toImmutableMap( + Function.identity(), + // Reuse existing stubs if they exist. + endpoint -> + currentConnections.getOrDefault( + endpoint, WindmillConnection.from(endpoint, this::createWindmillStub)))); + } + + private synchronized ImmutableMap + closeStaleStreamsAndCreateNewStreams(Collection newWindmillConnections) { + ImmutableMap currentStreams = + connections.get().windmillStreams(); + + // Close the streams that are no longer valid. + currentStreams.entrySet().stream() + .filter( + connectionAndStream -> !newWindmillConnections.contains(connectionAndStream.getKey())) + .map(Entry::getValue) + .forEach(WindmillStreamSender::closeAllStreams); + + return newWindmillConnections.stream() + .collect( + toImmutableMap( + Function.identity(), + newConnection -> + Optional.ofNullable(currentStreams.get(newConnection)) + .orElseGet(() -> createAndStartWindmillStreamSenderFor(newConnection)))); + } + + private ImmutableMap> createNewGlobalDataStreams( + ImmutableMap newGlobalDataEndpoints) { + ImmutableMap> currentGlobalDataStreams = + connections.get().globalDataStreams(); + return newGlobalDataEndpoints.entrySet().stream() + .collect( + toImmutableMap( + Entry::getKey, + keyedEndpoint -> + existingOrNewGetDataStreamFor(keyedEndpoint, currentGlobalDataStreams))); + } + + private Supplier existingOrNewGetDataStreamFor( + Entry keyedEndpoint, + ImmutableMap> currentGlobalDataStreams) { + return Preconditions.checkNotNull( + currentGlobalDataStreams.getOrDefault( + keyedEndpoint.getKey(), + () -> + streamFactory.createGetDataStream( + newOrExistingStubFor(keyedEndpoint.getValue()), new ThrottleTimer()))); + } + + private CloudWindmillServiceV1Alpha1Stub newOrExistingStubFor(Endpoint endpoint) { + return Optional.ofNullable(connections.get().windmillConnections().get(endpoint)) + .map(WindmillConnection::stub) + .orElseGet(() -> createWindmillStub(endpoint)); + } + + private WindmillStreamSender createAndStartWindmillStreamSenderFor( + WindmillConnection connection) { + // Initially create each stream with no budget. The budget will be eventually assigned by the + // GetWorkBudgetDistributor. 
+ WindmillStreamSender windmillStreamSender = + WindmillStreamSender.create( + connection.stub(), + GetWorkRequest.newBuilder() + .setClientId(clientId) + .setJobId(jobHeader.getJobId()) + .setProjectId(jobHeader.getProjectId()) + .setWorkerId(jobHeader.getWorkerId()) + .build(), + GetWorkBudget.noBudget(), + streamFactory, + workItemProcessor); + windmillStreamSender.startStreams(); + return windmillStreamSender; + } + + private CloudWindmillServiceV1Alpha1Stub createWindmillStub(Endpoint endpoint) { + switch (stubFactory.getKind()) { + // This is only used in tests. + case IN_PROCESS: + return stubFactory.inProcess().get(); + // Create stub for direct_endpoint or just default to Dispatcher stub. + case REMOTE: + return endpoint + .directEndpoint() + .map(stubFactory.remote()) + .orElseGet(dispatcherClient::getDispatcherStub); + // Should never be called, this switch statement is exhaustive. + default: + throw new UnsupportedOperationException( + "Only remote or in-process stub factories are available."); + } + } + + private static class StreamingEngineClientException extends IllegalStateException { + + private StreamingEngineClientException(Throwable exception) { + super(exception); + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineConnectionState.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineConnectionState.java new file mode 100644 index 000000000000..8d784456d655 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineConnectionState.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; + +import com.google.auto.value.AutoValue; +import java.util.function.Supplier; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints.Endpoint; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; + +/** + * Represents the current state of connections to Streaming Engine. Connections are updated when + * backend workers assigned to the key ranges being processed by this user worker change during + * pipeline execution. For example, changes can happen via autoscaling, load-balancing, or other + * backend updates. 
+ */ +@AutoValue +abstract class StreamingEngineConnectionState { + static final StreamingEngineConnectionState EMPTY = builder().build(); + + static Builder builder() { + return new AutoValue_StreamingEngineConnectionState.Builder() + .setWindmillConnections(ImmutableMap.of()) + .setWindmillStreams(ImmutableMap.of()) + .setGlobalDataStreams(ImmutableMap.of()); + } + + abstract ImmutableMap windmillConnections(); + + abstract ImmutableMap windmillStreams(); + + /** Mapping of GlobalDataIds to the direct GetDataStreams used to fetch them. */ + abstract ImmutableMap> globalDataStreams(); + + @AutoValue.Builder + abstract static class Builder { + public abstract Builder setWindmillConnections( + ImmutableMap value); + + public abstract Builder setWindmillStreams( + ImmutableMap value); + + public abstract Builder setGlobalDataStreams( + ImmutableMap> value); + + public abstract StreamingEngineConnectionState build(); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamSender.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamSender.java new file mode 100644 index 000000000000..bef710329ffa --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamSender.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; + +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; +import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.StreamingEngineThrottleTimers; +import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemProcessor; +import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers; + +/** + * Owns and maintains a set of streams used to communicate with a specific Windmill worker. 
+ * Underlying streams are "cached" in a threadsafe manner so that once {@link Supplier#get} is + * called, a stream that is already started is returned. + * + *
<p>Holds references to {@link + * Supplier} because + * initializing the streams automatically starts them, and we want to do so lazily here once the + * {@link GetWorkBudget} is set. + * + *
<p>Once started, the underlying streams are "alive" until they are manually closed via {@link + * #closeAllStreams()}. + * + *
<p>
If closed, it means that the backend endpoint is no longer in the worker set. Once closed, + * these instances are not reused. + * + * @implNote Does not manage streams for fetching {@link + * org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalData} for side inputs. + */ +@Internal +@ThreadSafe +public class WindmillStreamSender { + private final AtomicBoolean started; + private final AtomicReference getWorkBudget; + private final Supplier getWorkStream; + private final Supplier getDataStream; + private final Supplier commitWorkStream; + private final StreamingEngineThrottleTimers streamingEngineThrottleTimers; + + private WindmillStreamSender( + CloudWindmillServiceV1Alpha1Stub stub, + GetWorkRequest getWorkRequest, + AtomicReference getWorkBudget, + GrpcWindmillStreamFactory streamingEngineStreamFactory, + WorkItemProcessor workItemProcessor) { + this.started = new AtomicBoolean(false); + this.getWorkBudget = getWorkBudget; + this.streamingEngineThrottleTimers = StreamingEngineThrottleTimers.create(); + + // All streams are memoized/cached since they are expensive to create and some implementations + // perform side effects on construction (i.e. sending initial requests to the stream server to + // initiate the streaming RPC connection). Stream instances connect/reconnect internally so we + // can reuse the same instance through the entire lifecycle of WindmillStreamSender. + this.getDataStream = + Suppliers.memoize( + () -> + streamingEngineStreamFactory.createGetDataStream( + stub, streamingEngineThrottleTimers.getDataThrottleTimer())); + this.commitWorkStream = + Suppliers.memoize( + () -> + streamingEngineStreamFactory.createCommitWorkStream( + stub, streamingEngineThrottleTimers.commitWorkThrottleTimer())); + this.getWorkStream = + Suppliers.memoize( + () -> + streamingEngineStreamFactory.createDirectGetWorkStream( + stub, + withRequestBudget(getWorkRequest, getWorkBudget.get()), + streamingEngineThrottleTimers.getWorkThrottleTimer(), + getDataStream, + commitWorkStream, + workItemProcessor)); + } + + public static WindmillStreamSender create( + CloudWindmillServiceV1Alpha1Stub stub, + GetWorkRequest getWorkRequest, + GetWorkBudget getWorkBudget, + GrpcWindmillStreamFactory streamingEngineStreamFactory, + WorkItemProcessor workItemProcessor) { + return new WindmillStreamSender( + stub, + getWorkRequest, + new AtomicReference<>(getWorkBudget), + streamingEngineStreamFactory, + workItemProcessor); + } + + private static GetWorkRequest withRequestBudget(GetWorkRequest request, GetWorkBudget budget) { + return request.toBuilder().setMaxItems(budget.items()).setMaxBytes(budget.bytes()).build(); + } + + @SuppressWarnings("ReturnValueIgnored") + void startStreams() { + getWorkStream.get(); + getDataStream.get(); + commitWorkStream.get(); + // *stream.get() is all memoized in a threadsafe manner. + started.set(true); + } + + void closeAllStreams() { + // Supplier.get() starts the stream which is an expensive operation as it initiates the + // streaming RPCs by possibly making calls over the network. Do not close the streams unless + // they have already been started. 
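+ // Closing streams that were never started would force the memoized suppliers to open new RPC + // streams only to immediately close them.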
+ if (started.get()) { + getWorkStream.get().close(); + getDataStream.get().close(); + commitWorkStream.get().close(); + } + } + + public void adjustBudget(long itemsDelta, long bytesDelta) { + getWorkBudget.set(getWorkBudget.get().apply(itemsDelta, bytesDelta)); + if (started.get()) { + getWorkStream.get().adjustBudget(itemsDelta, bytesDelta); + } + } + + public void adjustBudget(GetWorkBudget adjustment) { + adjustBudget(adjustment.items(), adjustment.bytes()); + } + + public GetWorkBudget remainingGetWorkBudget() { + return started.get() ? getWorkStream.get().remainingBudget() : getWorkBudget.get(); + } + + public long getAndResetThrottleTime() { + return streamingEngineThrottleTimers.getAndResetThrottleTime(); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/ProcessWorkItemClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/ProcessWorkItemClient.java new file mode 100644 index 000000000000..1adfe02f45fc --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/ProcessWorkItemClient.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.work; + +import com.google.auto.value.AutoValue; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; +import org.apache.beam.sdk.annotations.Internal; + +/** + * A client context to process {@link WorkItem} and route all subsequent Windmill WorkItem API calls + * to the same backend worker. Wraps the {@link WorkItem}. + */ +@AutoValue +@Internal +public abstract class ProcessWorkItemClient { + public static ProcessWorkItemClient create( + WorkItem workItem, GetDataStream getDataStream, CommitWorkStream commitWorkStream) { + return new AutoValue_ProcessWorkItemClient(workItem, getDataStream, commitWorkStream); + } + + /** {@link WorkItem} being processed. */ + public abstract WorkItem workItem(); + + /** + * {@link GetDataStream} that connects to the backend Windmill worker handling the {@link + * WorkItem}. + */ + public abstract GetDataStream getDataStream(); + + /** + * {@link CommitWorkStream} that connects to backend Windmill worker handling the {@link + * WorkItem}. 
+ */ + public abstract CommitWorkStream commitWorkStream(); +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/WorkItemProcessor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/WorkItemProcessor.java new file mode 100644 index 000000000000..4ebc77775fcd --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/WorkItemProcessor.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.work; + +import java.util.Collection; +import java.util.function.Consumer; +import javax.annotation.CheckReturnValue; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.LatencyAttribution; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem; +import org.apache.beam.sdk.annotations.Internal; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Instant; + +@FunctionalInterface +@CheckReturnValue +@Internal +public interface WorkItemProcessor { + /** + * Receives and processes {@link WorkItem}(s) wrapped in its {@link ProcessWorkItemClient} + * processing context. + * + * @param computation the Computation that the Work belongs to. + * @param inputDataWatermark Watermark of when the input data was received by the computation. + * @param synchronizedProcessingTime Aggregate system watermark that also depends on each + * computation's received dependent system watermark value to propagate the system watermark + * downstream. + * @param wrappedWorkItem A workItem and its processing context, used to route subsequent + * WorkItem API (GetData, CommitWork) RPC calls to the same backend worker, where the WorkItem + * was returned from GetWork. + * @param ackWorkItemQueued Called after an attempt to queue the work item for processing. Used to + * free up pending budget. + * @param getWorkStreamLatencies Latencies per processing stage for the WorkItem for reporting + * back to Streaming Engine backend. 
+ */ + void processWork( + String computation, + @Nullable Instant inputDataWatermark, + @Nullable Instant synchronizedProcessingTime, + ProcessWorkItemClient wrappedWorkItem, + Consumer ackWorkItemQueued, + Collection getWorkStreamLatencies); +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java new file mode 100644 index 000000000000..3a17222d3e6b --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.work.budget; + +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.DoubleMath.roundToLong; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.LongMath.divide; + +import java.math.RoundingMode; +import java.util.Map; +import java.util.Map.Entry; +import java.util.function.Function; +import java.util.function.Supplier; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.WindmillStreamSender; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableCollection; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Evenly distributes the provided budget across the available {@link WindmillStreamSender}(s). 
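Budget is only redistributed to a stream once that stream's remaining budget has dropped below half of its even share.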
*/ +@Internal +final class EvenGetWorkBudgetDistributor implements GetWorkBudgetDistributor { + private static final Logger LOG = LoggerFactory.getLogger(EvenGetWorkBudgetDistributor.class); + private final Supplier activeWorkBudgetSupplier; + + EvenGetWorkBudgetDistributor(Supplier activeWorkBudgetSupplier) { + this.activeWorkBudgetSupplier = activeWorkBudgetSupplier; + } + + private static boolean isBelowFiftyPercentOfTarget( + GetWorkBudget remaining, GetWorkBudget target) { + return remaining.items() < roundToLong(target.items() * 0.5, RoundingMode.CEILING) + || remaining.bytes() < roundToLong(target.bytes() * 0.5, RoundingMode.CEILING); + } + + @Override + public void distributeBudget( + ImmutableCollection streams, GetWorkBudget getWorkBudget) { + if (streams.isEmpty()) { + LOG.debug("Cannot distribute budget to no streams."); + return; + } + + if (getWorkBudget.equals(GetWorkBudget.noBudget())) { + LOG.debug("Cannot distribute 0 budget."); + return; + } + + Map desiredBudgets = + computeDesiredBudgets(streams, getWorkBudget); + + for (Entry streamAndDesiredBudget : + desiredBudgets.entrySet()) { + WindmillStreamSender stream = streamAndDesiredBudget.getKey(); + GetWorkBudget desired = streamAndDesiredBudget.getValue(); + GetWorkBudget remaining = stream.remainingGetWorkBudget(); + if (isBelowFiftyPercentOfTarget(remaining, desired)) { + GetWorkBudget adjustment = desired.subtract(remaining); + stream.adjustBudget(adjustment); + } + } + } + + private ImmutableMap computeDesiredBudgets( + ImmutableCollection streams, GetWorkBudget totalGetWorkBudget) { + GetWorkBudget activeWorkBudget = activeWorkBudgetSupplier.get(); + LOG.info("Current active work budget: {}", activeWorkBudget); + // TODO: Fix possibly non-deterministic handing out of budgets. + // Rounding up here will drift upwards over the lifetime of the streams. + GetWorkBudget budgetPerStream = + GetWorkBudget.builder() + .setItems( + divide( + totalGetWorkBudget.items() - activeWorkBudget.items(), + streams.size(), + RoundingMode.CEILING)) + .setBytes( + divide( + totalGetWorkBudget.bytes() - activeWorkBudget.bytes(), + streams.size(), + RoundingMode.CEILING)) + .build(); + return streams.stream().collect(toImmutableMap(Function.identity(), unused -> budgetPerStream)); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudget.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudget.java index 0038e3e9cc60..bc82b622ce64 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudget.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudget.java @@ -20,13 +20,14 @@ import com.google.auto.value.AutoValue; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.apache.beam.sdk.annotations.Internal; /** * Budget of items and bytes for fetching {@link * org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem}(s) via {@link * WindmillStream.GetWorkStream}. Used to control how "much" work is returned from Windmill. 
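+ *
+ * <p>For example (illustrative, matching the updated tests): applying a negative delta such as
+ * {@code apply(-10, -10)} to a budget of 1 item and 1 byte yields 0 items and 0 bytes, since a
+ * budget never drops below zero.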
*/ +@Internal @AutoValue public abstract class GetWorkBudget { public static GetWorkBudget.Builder builder() { @@ -46,29 +47,26 @@ public static GetWorkBudget from(GetWorkRequest getWorkRequest) { } /** - * Adds the given bytes and items or the current budget, returning a new {@link GetWorkBudget}. - * Does not drop below 0. + * Applies the given bytes and items delta to the current budget, returning a new {@link + * GetWorkBudget}. Does not drop below 0. */ - public GetWorkBudget add(long items, long bytes) { - Preconditions.checkArgument(items >= 0 && bytes >= 0); - return GetWorkBudget.builder().setBytes(bytes() + bytes).setItems(items() + items).build(); + public GetWorkBudget apply(long itemsDelta, long bytesDelta) { + return GetWorkBudget.builder() + .setBytes(bytes() + bytesDelta) + .setItems(items() + itemsDelta) + .build(); } - public GetWorkBudget add(GetWorkBudget other) { - return add(other.items(), other.bytes()); + public GetWorkBudget apply(GetWorkBudget other) { + return apply(other.items(), other.bytes()); } - /** - * Subtracts the given bytes and items or the current budget, returning a new {@link - * GetWorkBudget}. Does not drop below 0. - */ - public GetWorkBudget subtract(long items, long bytes) { - Preconditions.checkArgument(items >= 0 && bytes >= 0); - return GetWorkBudget.builder().setBytes(bytes() - bytes).setItems(items() - items).build(); + public GetWorkBudget subtract(GetWorkBudget other) { + return apply(-other.items(), -other.bytes()); } - public GetWorkBudget subtract(GetWorkBudget other) { - return subtract(other.items(), other.bytes()); + public GetWorkBudget subtract(long items, long bytes) { + return apply(-items, -bytes); } /** Budget of bytes for GetWork. Does not drop below 0. */ @@ -77,6 +75,9 @@ public GetWorkBudget subtract(GetWorkBudget other) { /** Budget of items for GetWork. Does not drop below 0. */ public abstract long items(); + public abstract GetWorkBudget.Builder toBuilder(); + + @Internal @AutoValue.Builder public abstract static class Builder { public abstract Builder setBytes(long bytes); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetDistributor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetDistributor.java new file mode 100644 index 000000000000..3ec9718e041e --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetDistributor.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.windmill.work.budget; + +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.WindmillStreamSender; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableCollection; + +/** + * Distributes the total {@link GetWorkBudget} amongst the {@link + * org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream}(s) to + * Windmill. + */ +@Internal +public interface GetWorkBudgetDistributor { + void distributeBudget( + ImmutableCollection streams, GetWorkBudget getWorkBudget); +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetDistributors.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetDistributors.java new file mode 100644 index 000000000000..43c0d46139da --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetDistributors.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.work.budget; + +import java.util.function.Supplier; +import org.apache.beam.sdk.annotations.Internal; + +@Internal +public final class GetWorkBudgetDistributors { + public static GetWorkBudgetDistributor distributeEvenly( + Supplier activeWorkBudgetSupplier) { + return new EvenGetWorkBudgetDistributor(activeWorkBudgetSupplier); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetRefresher.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetRefresher.java new file mode 100644 index 000000000000..e39aa8dbc8a5 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetRefresher.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.work.budget; + +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.function.Supplier; +import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.fn.stream.AdvancingPhaser; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Handles refreshing the budget either via triggered or scheduled execution using a {@link + * java.util.concurrent.Phaser} to emulate publish/subscribe pattern. + */ +@Internal +@ThreadSafe +public final class GetWorkBudgetRefresher { + @VisibleForTesting public static final int SCHEDULED_BUDGET_REFRESH_MILLIS = 100; + private static final int INITIAL_BUDGET_REFRESH_PHASE = 0; + private static final String BUDGET_REFRESH_THREAD = "GetWorkBudgetRefreshThread"; + private static final Logger LOG = LoggerFactory.getLogger(GetWorkBudgetRefresher.class); + + private final AdvancingPhaser budgetRefreshTrigger; + private final ExecutorService budgetRefreshExecutor; + private final Supplier isBudgetRefreshPaused; + private final Runnable redistributeBudget; + + public GetWorkBudgetRefresher( + Supplier isBudgetRefreshPaused, Runnable redistributeBudget) { + this.budgetRefreshTrigger = new AdvancingPhaser(1); + this.budgetRefreshExecutor = + Executors.newSingleThreadScheduledExecutor( + new ThreadFactoryBuilder() + .setNameFormat(BUDGET_REFRESH_THREAD) + .setUncaughtExceptionHandler( + (t, e) -> + LOG.error( + "{} failed due to uncaught exception during execution. ", + t.getName(), + e)) + .build()); + this.isBudgetRefreshPaused = isBudgetRefreshPaused; + this.redistributeBudget = redistributeBudget; + } + + @SuppressWarnings("FutureReturnValueIgnored") + public void start() { + budgetRefreshExecutor.submit(this::subscribeToRefreshBudget); + } + + /** Publishes an event to trigger a budget refresh. */ + public void requestBudgetRefresh() { + budgetRefreshTrigger.arrive(); + } + + public void stop() { + budgetRefreshTrigger.arriveAndDeregister(); + // Put the budgetRefreshTrigger in a terminated state, #waitForBudgetRefreshEventWithTimeout + // will subsequently return false, and #subscribeToRefreshBudget will return, completing the + // task. + budgetRefreshTrigger.forceTermination(); + budgetRefreshExecutor.shutdownNow(); + } + + private void subscribeToRefreshBudget() { + int currentBudgetRefreshPhase = INITIAL_BUDGET_REFRESH_PHASE; + // Runs forever until #stop is called. + while (true) { + currentBudgetRefreshPhase = waitForBudgetRefreshEventWithTimeout(currentBudgetRefreshPhase); + // Phaser.awaitAdvanceInterruptibly(...) 
returns a negative value if the phaser is + // terminated, else returns when either a budget refresh has been manually triggered or + // SCHEDULED_BUDGET_REFRESH_MILLIS have passed. + if (currentBudgetRefreshPhase < 0) { + return; + } + // Budget refreshes are paused during endpoint updates. + if (!isBudgetRefreshPaused.get()) { + redistributeBudget.run(); + } + } + } + + /** + * Waits for a budget refresh trigger event with a timeout. Returns the current phase of the + * {@link #budgetRefreshTrigger}, to be used for following waits for the {@link + * #budgetRefreshTrigger} to advance. + * + *
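+ * <p>On timeout the current phase is returned unchanged, so the caller treats the timeout itself
+ * as a scheduled budget refresh signal.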

Budget refresh event is triggered when {@link #budgetRefreshTrigger} moves on from the given + * currentBudgetRefreshPhase. + */ + private int waitForBudgetRefreshEventWithTimeout(int currentBudgetRefreshPhase) { + try { + // Wait for budgetRefreshTrigger to advance FROM the current phase. + return budgetRefreshTrigger.awaitAdvanceInterruptibly( + currentBudgetRefreshPhase, SCHEDULED_BUDGET_REFRESH_MILLIS, TimeUnit.MILLISECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new BudgetRefreshException("Error occurred waiting for budget refresh.", e); + } catch (TimeoutException ignored) { + // Intentionally do nothing since we trigger the budget refresh on the timeout. + } + + return currentBudgetRefreshPhase; + } + + private static class BudgetRefreshException extends RuntimeException { + private BudgetRefreshException(String msg, Throwable sourceException) { + super(msg, sourceException); + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClientTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClientTest.java new file mode 100644 index 000000000000..8a2c643a5b76 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClientTest.java @@ -0,0 +1,417 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; +import javax.annotation.Nullable; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataRequest; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataResponse; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillChannelFactory; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillStubFactory; +import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemProcessor; +import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; +import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetDistributor; +import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetRefresher; +import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.Server; +import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.inprocess.InProcessServerBuilder; +import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.inprocess.InProcessSocketAddress; +import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.testing.GrpcCleanupRule; +import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.util.MutableHandlerRegistry; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableCollection; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class StreamingEngineClientTest { + private static final WindmillServiceAddress DEFAULT_WINDMILL_SERVICE_ADDRESS = + WindmillServiceAddress.create(HostAndPort.fromParts(WindmillChannelFactory.LOCALHOST, 443)); + private static final ImmutableMap DEFAULT = + ImmutableMap.of( + "global_data", + WorkerMetadataResponse.Endpoint.newBuilder() + 
.setDirectEndpoint(DEFAULT_WINDMILL_SERVICE_ADDRESS.gcpServiceAddress().toString()) + .build()); + private static final long CLIENT_ID = 1L; + private static final String JOB_ID = "jobId"; + private static final String PROJECT_ID = "projectId"; + private static final String WORKER_ID = "workerId"; + private static final JobHeader JOB_HEADER = + JobHeader.newBuilder() + .setJobId(JOB_ID) + .setProjectId(PROJECT_ID) + .setWorkerId(WORKER_ID) + .build(); + @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + private final Set channels = new HashSet<>(); + private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); + + private final GrpcWindmillStreamFactory streamFactory = + spy(GrpcWindmillStreamFactory.of(JOB_HEADER).build()); + private final WindmillStubFactory stubFactory = + WindmillStubFactory.inProcessStubFactory( + "StreamingEngineClientTest", + name -> { + ManagedChannel channel = + grpcCleanup.register(WindmillChannelFactory.inProcessChannel(name)); + channels.add(channel); + return channel; + }); + private final GrpcDispatcherClient dispatcherClient = + GrpcDispatcherClient.forTesting(stubFactory, new ArrayList<>(), new HashSet<>()); + private final GetWorkBudgetDistributor getWorkBudgetDistributor = + spy(new TestGetWorkBudgetDistributor()); + private final AtomicReference connections = + new AtomicReference<>(StreamingEngineConnectionState.EMPTY); + private Server fakeStreamingEngineServer; + private CountDownLatch getWorkerMetadataReady; + private GetWorkerMetadataTestStub fakeGetWorkerMetadataStub; + + private StreamingEngineClient streamingEngineClient; + + private static WorkItemProcessor noOpProcessWorkItemFn() { + return (computation, + inputDataWatermark, + synchronizedProcessingTime, + workItem, + ackQueuedWorkItem, + getWorkStreamLatencies) -> {}; + } + + private static GetWorkRequest getWorkRequest(long items, long bytes) { + return GetWorkRequest.newBuilder() + .setJobId(JOB_ID) + .setProjectId(PROJECT_ID) + .setWorkerId(WORKER_ID) + .setClientId(CLIENT_ID) + .setMaxItems(items) + .setMaxBytes(bytes) + .build(); + } + + private static WorkerMetadataResponse.Endpoint metadataResponseEndpoint(String workerToken) { + return WorkerMetadataResponse.Endpoint.newBuilder().setWorkerToken(workerToken).build(); + } + + @Before + public void setUp() throws IOException { + channels.forEach(ManagedChannel::shutdownNow); + channels.clear(); + fakeStreamingEngineServer = + grpcCleanup.register( + InProcessServerBuilder.forName("StreamingEngineClientTest") + .fallbackHandlerRegistry(serviceRegistry) + .executor(Executors.newFixedThreadPool(1)) + .build()); + + fakeStreamingEngineServer.start(); + dispatcherClient.consumeWindmillDispatcherEndpoints( + ImmutableSet.of( + HostAndPort.fromString( + new InProcessSocketAddress("StreamingEngineClientTest").toString()))); + getWorkerMetadataReady = new CountDownLatch(1); + fakeGetWorkerMetadataStub = new GetWorkerMetadataTestStub(getWorkerMetadataReady); + serviceRegistry.addService(fakeGetWorkerMetadataStub); + } + + @After + public void cleanUp() { + fakeGetWorkerMetadataStub.close(); + fakeStreamingEngineServer.shutdownNow(); + channels.forEach(ManagedChannel::shutdownNow); + Preconditions.checkNotNull(streamingEngineClient).finish(); + } + + private StreamingEngineClient newStreamingEngineClient( + GetWorkBudget getWorkBudget, WorkItemProcessor workItemProcessor) { + return StreamingEngineClient.forTesting( + JOB_HEADER, + getWorkBudget, + connections, + streamFactory, + workItemProcessor, + 
stubFactory, + getWorkBudgetDistributor, + dispatcherClient, + CLIENT_ID); + } + + @Test + public void testStreamsStartCorrectly() throws InterruptedException { + long items = 10L; + long bytes = 10L; + + streamingEngineClient = + newStreamingEngineClient( + GetWorkBudget.builder().setItems(items).setBytes(bytes).build(), + noOpProcessWorkItemFn()); + + String workerToken = "workerToken1"; + String workerToken2 = "workerToken2"; + + WorkerMetadataResponse firstWorkerMetadata = + WorkerMetadataResponse.newBuilder() + .setMetadataVersion(1) + .addWorkEndpoints(metadataResponseEndpoint(workerToken)) + .addWorkEndpoints(metadataResponseEndpoint(workerToken2)) + .putAllGlobalDataEndpoints(DEFAULT) + .build(); + + getWorkerMetadataReady.await(); + fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata); + StreamingEngineConnectionState currentConnections = waitForWorkerMetadataToBeConsumed(1); + + assertEquals(2, currentConnections.windmillConnections().size()); + assertEquals(2, currentConnections.windmillStreams().size()); + Set workerTokens = + connections.get().windmillConnections().values().stream() + .map(WindmillConnection::backendWorkerToken) + .filter(Optional::isPresent) + .map(Optional::get) + .collect(Collectors.toSet()); + + assertTrue(workerTokens.contains(workerToken)); + assertTrue(workerTokens.contains(workerToken2)); + + verify(getWorkBudgetDistributor, atLeast(1)) + .distributeBudget( + any(), eq(GetWorkBudget.builder().setItems(items).setBytes(bytes).build())); + + verify(streamFactory, times(2)) + .createDirectGetWorkStream( + any(), eq(getWorkRequest(0, 0)), any(), any(), any(), eq(noOpProcessWorkItemFn())); + + verify(streamFactory, times(2)).createGetDataStream(any(), any()); + verify(streamFactory, times(2)).createCommitWorkStream(any(), any()); + } + + @Test + public void testScheduledBudgetRefresh() throws InterruptedException { + streamingEngineClient = + newStreamingEngineClient( + GetWorkBudget.builder().setItems(1L).setBytes(1L).build(), noOpProcessWorkItemFn()); + + getWorkerMetadataReady.await(); + fakeGetWorkerMetadataStub.injectWorkerMetadata( + WorkerMetadataResponse.newBuilder() + .setMetadataVersion(1) + .addWorkEndpoints(metadataResponseEndpoint("workerToken")) + .putAllGlobalDataEndpoints(DEFAULT) + .build()); + waitForWorkerMetadataToBeConsumed(1); + Thread.sleep(GetWorkBudgetRefresher.SCHEDULED_BUDGET_REFRESH_MILLIS); + verify(getWorkBudgetDistributor, atLeast(2)).distributeBudget(any(), any()); + } + + @Test + public void testOnNewWorkerMetadata_correctlyRemovesStaleWindmillServers() + throws InterruptedException { + streamingEngineClient = + newStreamingEngineClient( + GetWorkBudget.builder().setItems(1).setBytes(1).build(), noOpProcessWorkItemFn()); + + String workerToken = "workerToken1"; + String workerToken2 = "workerToken2"; + String workerToken3 = "workerToken3"; + + WorkerMetadataResponse firstWorkerMetadata = + WorkerMetadataResponse.newBuilder() + .setMetadataVersion(1) + .addWorkEndpoints( + WorkerMetadataResponse.Endpoint.newBuilder().setWorkerToken(workerToken).build()) + .addWorkEndpoints( + WorkerMetadataResponse.Endpoint.newBuilder().setWorkerToken(workerToken2).build()) + .putAllGlobalDataEndpoints(DEFAULT) + .build(); + WorkerMetadataResponse secondWorkerMetadata = + WorkerMetadataResponse.newBuilder() + .setMetadataVersion(2) + .addWorkEndpoints( + WorkerMetadataResponse.Endpoint.newBuilder().setWorkerToken(workerToken3).build()) + .putAllGlobalDataEndpoints(DEFAULT) + .build(); + + getWorkerMetadataReady.await(); + 
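+ // Inject a first metadata version containing both workers, then a second version that drops
+ // them; connections to the stale workers should be removed.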
fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata); + fakeGetWorkerMetadataStub.injectWorkerMetadata(secondWorkerMetadata); + + StreamingEngineConnectionState currentConnections = waitForWorkerMetadataToBeConsumed(2); + + assertEquals(1, currentConnections.windmillConnections().size()); + assertEquals(1, currentConnections.windmillStreams().size()); + Set workerTokens = + connections.get().windmillConnections().values().stream() + .map(WindmillConnection::backendWorkerToken) + .filter(Optional::isPresent) + .map(Optional::get) + .collect(Collectors.toSet()); + + assertFalse(workerTokens.contains(workerToken)); + assertFalse(workerTokens.contains(workerToken2)); + } + + @Test + public void testOnNewWorkerMetadata_redistributesBudget() throws InterruptedException { + streamingEngineClient = + newStreamingEngineClient( + GetWorkBudget.builder().setItems(1).setBytes(1).build(), noOpProcessWorkItemFn()); + + String workerToken = "workerToken1"; + String workerToken2 = "workerToken2"; + String workerToken3 = "workerToken3"; + + WorkerMetadataResponse firstWorkerMetadata = + WorkerMetadataResponse.newBuilder() + .setMetadataVersion(1) + .addWorkEndpoints( + WorkerMetadataResponse.Endpoint.newBuilder().setWorkerToken(workerToken).build()) + .putAllGlobalDataEndpoints(DEFAULT) + .build(); + WorkerMetadataResponse secondWorkerMetadata = + WorkerMetadataResponse.newBuilder() + .setMetadataVersion(2) + .addWorkEndpoints( + WorkerMetadataResponse.Endpoint.newBuilder().setWorkerToken(workerToken2).build()) + .putAllGlobalDataEndpoints(DEFAULT) + .build(); + WorkerMetadataResponse thirdWorkerMetadata = + WorkerMetadataResponse.newBuilder() + .setMetadataVersion(3) + .addWorkEndpoints( + WorkerMetadataResponse.Endpoint.newBuilder().setWorkerToken(workerToken3).build()) + .putAllGlobalDataEndpoints(DEFAULT) + .build(); + + getWorkerMetadataReady.await(); + fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata); + Thread.sleep(50); + fakeGetWorkerMetadataStub.injectWorkerMetadata(secondWorkerMetadata); + Thread.sleep(50); + fakeGetWorkerMetadataStub.injectWorkerMetadata(thirdWorkerMetadata); + Thread.sleep(50); + verify(getWorkBudgetDistributor, atLeast(3)).distributeBudget(any(), any()); + } + + private StreamingEngineConnectionState waitForWorkerMetadataToBeConsumed( + int expectedMetadataConsumed) throws InterruptedException { + int currentMetadataConsumed = 0; + StreamingEngineConnectionState currentConsumedMetadata = StreamingEngineConnectionState.EMPTY; + while (true) { + if (!connections.get().equals(currentConsumedMetadata)) { + ++currentMetadataConsumed; + if (currentMetadataConsumed == expectedMetadataConsumed) { + break; + } + currentConsumedMetadata = connections.get(); + } + } + // Wait for metadata to be consumed and budgets to be redistributed. 
+ Thread.sleep(GetWorkBudgetRefresher.SCHEDULED_BUDGET_REFRESH_MILLIS); + return connections.get(); + } + + private static class GetWorkerMetadataTestStub + extends CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase { + private static final WorkerMetadataResponse CLOSE_ALL_STREAMS = + WorkerMetadataResponse.newBuilder().setMetadataVersion(100).build(); + private final CountDownLatch ready; + private @Nullable StreamObserver responseObserver; + + private GetWorkerMetadataTestStub(CountDownLatch ready) { + this.ready = ready; + } + + @Override + public StreamObserver getWorkerMetadataStream( + StreamObserver responseObserver) { + if (this.responseObserver == null) { + ready.countDown(); + this.responseObserver = responseObserver; + } + + return new StreamObserver() { + @Override + public void onNext(WorkerMetadataRequest workerMetadataRequest) {} + + @Override + public void onError(Throwable throwable) { + if (responseObserver != null) { + responseObserver.onError(throwable); + } + } + + @Override + public void onCompleted() {} + }; + } + + private void injectWorkerMetadata(WorkerMetadataResponse response) { + if (responseObserver != null) { + responseObserver.onNext(response); + } + } + + private void close() { + if (responseObserver != null) { + // Send an empty response to close out all the streams and channels currently open in + // Streaming Engine Client. + responseObserver.onNext(CLOSE_ALL_STREAMS); + } + } + } + + private static class TestGetWorkBudgetDistributor implements GetWorkBudgetDistributor { + @Override + public void distributeBudget( + ImmutableCollection streams, GetWorkBudget getWorkBudget) { + streams.forEach(stream -> stream.adjustBudget(getWorkBudget.items(), getWorkBudget.bytes())); + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamSenderTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamSenderTest.java new file mode 100644 index 000000000000..c8d2974f923d --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamSenderTest.java @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; + +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; +import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemProcessor; +import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; +import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.inprocess.InProcessChannelBuilder; +import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.testing.GrpcCleanupRule; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class WindmillStreamSenderTest { + private static final GetWorkRequest GET_WORK_REQUEST = + GetWorkRequest.newBuilder().setClientId(1L).setJobId("job").setProjectId("project").build(); + @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + + private final GrpcWindmillStreamFactory streamFactory = + spy( + GrpcWindmillStreamFactory.of( + JobHeader.newBuilder() + .setJobId("job") + .setProjectId("project") + .setWorkerId("worker") + .build()) + .build()); + private final WorkItemProcessor workItemProcessor = + (computation, + inputDataWatermark, + synchronizedProcessingTime, + workItem, + ackQueuedWorkItem, + getWorkStreamLatencies) -> {}; + private ManagedChannel inProcessChannel; + private CloudWindmillServiceV1Alpha1Stub stub; + + @Before + public void setUp() { + inProcessChannel = + grpcCleanup.register( + InProcessChannelBuilder.forName("WindmillStreamSenderTest").directExecutor().build()); + grpcCleanup.register(inProcessChannel); + stub = CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel); + } + + @After + public void cleanUp() { + inProcessChannel.shutdownNow(); + } + + @Test + public void testStartStream_startsAllStreams() { + long itemBudget = 1L; + long byteBudget = 1L; + + WindmillStreamSender windmillStreamSender = + newWindmillStreamSender( + GetWorkBudget.builder().setBytes(byteBudget).setItems(itemBudget).build()); + + windmillStreamSender.startStreams(); + + verify(streamFactory) + .createDirectGetWorkStream( + eq(stub), + eq( + GET_WORK_REQUEST + .toBuilder() + .setMaxItems(itemBudget) + .setMaxBytes(byteBudget) + .build()), + any(ThrottleTimer.class), + any(), + any(), + eq(workItemProcessor)); + + verify(streamFactory).createGetDataStream(eq(stub), any(ThrottleTimer.class)); + 
verify(streamFactory).createCommitWorkStream(eq(stub), any(ThrottleTimer.class)); + } + + @Test + public void testStartStream_onlyStartsStreamsOnce() { + long itemBudget = 1L; + long byteBudget = 1L; + + WindmillStreamSender windmillStreamSender = + newWindmillStreamSender( + GetWorkBudget.builder().setBytes(byteBudget).setItems(itemBudget).build()); + + windmillStreamSender.startStreams(); + windmillStreamSender.startStreams(); + windmillStreamSender.startStreams(); + + verify(streamFactory, times(1)) + .createDirectGetWorkStream( + eq(stub), + eq( + GET_WORK_REQUEST + .toBuilder() + .setMaxItems(itemBudget) + .setMaxBytes(byteBudget) + .build()), + any(ThrottleTimer.class), + any(), + any(), + eq(workItemProcessor)); + + verify(streamFactory, times(1)).createGetDataStream(eq(stub), any(ThrottleTimer.class)); + verify(streamFactory, times(1)).createCommitWorkStream(eq(stub), any(ThrottleTimer.class)); + } + + @Test + public void testStartStream_onlyStartsStreamsOnceConcurrent() throws InterruptedException { + long itemBudget = 1L; + long byteBudget = 1L; + + WindmillStreamSender windmillStreamSender = + newWindmillStreamSender( + GetWorkBudget.builder().setBytes(byteBudget).setItems(itemBudget).build()); + + Thread startStreamThread = new Thread(windmillStreamSender::startStreams); + startStreamThread.start(); + + windmillStreamSender.startStreams(); + + startStreamThread.join(); + + verify(streamFactory, times(1)) + .createDirectGetWorkStream( + eq(stub), + eq( + GET_WORK_REQUEST + .toBuilder() + .setMaxItems(itemBudget) + .setMaxBytes(byteBudget) + .build()), + any(ThrottleTimer.class), + any(), + any(), + eq(workItemProcessor)); + + verify(streamFactory, times(1)).createGetDataStream(eq(stub), any(ThrottleTimer.class)); + verify(streamFactory, times(1)).createCommitWorkStream(eq(stub), any(ThrottleTimer.class)); + } + + @Test + public void testCloseAllStreams_doesNotCloseUnstartedStreams() { + WindmillStreamSender windmillStreamSender = + newWindmillStreamSender(GetWorkBudget.builder().setBytes(1L).setItems(1L).build()); + + windmillStreamSender.closeAllStreams(); + + verifyNoInteractions(streamFactory); + } + + @Test + public void testCloseAllStreams_closesAllStreams() { + long itemBudget = 1L; + long byteBudget = 1L; + GetWorkRequest getWorkRequestWithBudget = + GET_WORK_REQUEST.toBuilder().setMaxItems(itemBudget).setMaxBytes(byteBudget).build(); + GrpcWindmillStreamFactory mockStreamFactory = mock(GrpcWindmillStreamFactory.class); + GetWorkStream mockGetWorkStream = mock(GetWorkStream.class); + GetDataStream mockGetDataStream = mock(GetDataStream.class); + CommitWorkStream mockCommitWorkStream = mock(CommitWorkStream.class); + + when(mockStreamFactory.createDirectGetWorkStream( + eq(stub), + eq(getWorkRequestWithBudget), + any(ThrottleTimer.class), + any(), + any(), + eq(workItemProcessor))) + .thenReturn(mockGetWorkStream); + + when(mockStreamFactory.createGetDataStream(eq(stub), any(ThrottleTimer.class))) + .thenReturn(mockGetDataStream); + when(mockStreamFactory.createCommitWorkStream(eq(stub), any(ThrottleTimer.class))) + .thenReturn(mockCommitWorkStream); + + WindmillStreamSender windmillStreamSender = + newWindmillStreamSender( + GetWorkBudget.builder().setBytes(byteBudget).setItems(itemBudget).build(), + mockStreamFactory); + + windmillStreamSender.startStreams(); + windmillStreamSender.closeAllStreams(); + + verify(mockGetWorkStream).close(); + verify(mockGetDataStream).close(); + verify(mockCommitWorkStream).close(); + } + + private WindmillStreamSender 
newWindmillStreamSender(GetWorkBudget budget) { + return newWindmillStreamSender(budget, streamFactory); + } + + private WindmillStreamSender newWindmillStreamSender( + GetWorkBudget budget, GrpcWindmillStreamFactory streamFactory) { + return WindmillStreamSender.create( + stub, GET_WORK_REQUEST, budget, streamFactory, workItemProcessor); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java new file mode 100644 index 000000000000..14da55fe2389 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java @@ -0,0 +1,265 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.work.budget; + +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; + +import java.util.ArrayList; +import java.util.List; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcWindmillStreamFactory; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.WindmillStreamSender; +import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.inprocess.InProcessChannelBuilder; +import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.testing.GrpcCleanupRule; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class EvenGetWorkBudgetDistributorTest { + @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + + private ManagedChannel inProcessChannel; + private CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub stub; + + private static GetWorkBudgetDistributor createBudgetDistributor(GetWorkBudget activeWorkBudget) { + return GetWorkBudgetDistributors.distributeEvenly(() -> activeWorkBudget); 
+ } + + private static GetWorkBudgetDistributor createBudgetDistributor(long activeWorkItemsAndBytes) { + return createBudgetDistributor( + GetWorkBudget.builder() + .setItems(activeWorkItemsAndBytes) + .setBytes(activeWorkItemsAndBytes) + .build()); + } + + @Before + public void setUp() { + inProcessChannel = + grpcCleanup.register( + InProcessChannelBuilder.forName("WindmillStreamSenderTest").directExecutor().build()); + grpcCleanup.register(inProcessChannel); + stub = CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel); + } + + @After + public void cleanUp() { + inProcessChannel.shutdownNow(); + } + + @Test + public void testDistributeBudget_doesNothingWhenPassedInStreamsEmpty() { + createBudgetDistributor(1L) + .distributeBudget( + ImmutableList.of(), GetWorkBudget.builder().setItems(10L).setBytes(10L).build()); + } + + @Test + public void testDistributeBudget_doesNothingWithNoBudget() { + WindmillStreamSender windmillStreamSender = + spy(createWindmillStreamSender(GetWorkBudget.noBudget())); + createBudgetDistributor(1L) + .distributeBudget(ImmutableList.of(windmillStreamSender), GetWorkBudget.noBudget()); + verifyNoInteractions(windmillStreamSender); + } + + @Test + public void testDistributeBudget_doesNotAdjustStreamBudgetWhenRemainingBudgetHighNoActiveWork() { + WindmillStreamSender windmillStreamSender = + spy( + createWindmillStreamSender( + GetWorkBudget.builder().setItems(10L).setBytes(10L).build())); + createBudgetDistributor(0L) + .distributeBudget( + ImmutableList.of(windmillStreamSender), + GetWorkBudget.builder().setItems(10L).setBytes(10L).build()); + + verify(windmillStreamSender, never()).adjustBudget(anyLong(), anyLong()); + } + + @Test + public void + testDistributeBudget_doesNotAdjustStreamBudgetWhenRemainingBudgetHighWithActiveWork() { + WindmillStreamSender windmillStreamSender = + spy(createWindmillStreamSender(GetWorkBudget.builder().setItems(5L).setBytes(5L).build())); + createBudgetDistributor(10L) + .distributeBudget( + ImmutableList.of(windmillStreamSender), + GetWorkBudget.builder().setItems(20L).setBytes(20L).build()); + + verify(windmillStreamSender, never()).adjustBudget(anyLong(), anyLong()); + } + + @Test + public void + testDistributeBudget_adjustsStreamBudgetWhenRemainingItemBudgetTooLowWithNoActiveWork() { + GetWorkBudget streamRemainingBudget = + GetWorkBudget.builder().setItems(1L).setBytes(10L).build(); + GetWorkBudget totalGetWorkBudget = GetWorkBudget.builder().setItems(10L).setBytes(10L).build(); + WindmillStreamSender windmillStreamSender = + spy(createWindmillStreamSender(streamRemainingBudget)); + createBudgetDistributor(0L) + .distributeBudget(ImmutableList.of(windmillStreamSender), totalGetWorkBudget); + + verify(windmillStreamSender, times(1)) + .adjustBudget( + eq(totalGetWorkBudget.items() - streamRemainingBudget.items()), + eq(totalGetWorkBudget.bytes() - streamRemainingBudget.bytes())); + } + + @Test + public void + testDistributeBudget_adjustsStreamBudgetWhenRemainingItemBudgetTooLowWithActiveWork() { + GetWorkBudget streamRemainingBudget = + GetWorkBudget.builder().setItems(1L).setBytes(10L).build(); + GetWorkBudget totalGetWorkBudget = GetWorkBudget.builder().setItems(10L).setBytes(10L).build(); + long activeWorkItemsAndBytes = 2L; + WindmillStreamSender windmillStreamSender = + spy(createWindmillStreamSender(streamRemainingBudget)); + createBudgetDistributor(activeWorkItemsAndBytes) + .distributeBudget(ImmutableList.of(windmillStreamSender), totalGetWorkBudget); + + verify(windmillStreamSender, times(1)) + .adjustBudget( + 
eq( + totalGetWorkBudget.items() + - streamRemainingBudget.items() + - activeWorkItemsAndBytes), + eq(totalGetWorkBudget.bytes() - streamRemainingBudget.bytes())); + } + + @Test + public void testDistributeBudget_adjustsStreamBudgetWhenRemainingByteBudgetTooLowNoActiveWork() { + GetWorkBudget streamRemainingBudget = + GetWorkBudget.builder().setItems(10L).setBytes(1L).build(); + GetWorkBudget totalGetWorkBudget = GetWorkBudget.builder().setItems(10L).setBytes(10L).build(); + WindmillStreamSender windmillStreamSender = + spy(createWindmillStreamSender(streamRemainingBudget)); + createBudgetDistributor(0L) + .distributeBudget(ImmutableList.of(windmillStreamSender), totalGetWorkBudget); + + verify(windmillStreamSender, times(1)) + .adjustBudget( + eq(totalGetWorkBudget.items() - streamRemainingBudget.items()), + eq(totalGetWorkBudget.bytes() - streamRemainingBudget.bytes())); + } + + @Test + public void + testDistributeBudget_adjustsStreamBudgetWhenRemainingByteBudgetTooLowWithActiveWork() { + GetWorkBudget streamRemainingBudget = + GetWorkBudget.builder().setItems(10L).setBytes(1L).build(); + GetWorkBudget totalGetWorkBudget = GetWorkBudget.builder().setItems(10L).setBytes(10L).build(); + long activeWorkItemsAndBytes = 2L; + + WindmillStreamSender windmillStreamSender = + spy(createWindmillStreamSender(streamRemainingBudget)); + createBudgetDistributor(activeWorkItemsAndBytes) + .distributeBudget(ImmutableList.of(windmillStreamSender), totalGetWorkBudget); + + verify(windmillStreamSender, times(1)) + .adjustBudget( + eq(totalGetWorkBudget.items() - streamRemainingBudget.items()), + eq( + totalGetWorkBudget.bytes() + - streamRemainingBudget.bytes() + - activeWorkItemsAndBytes)); + } + + @Test + public void testDistributeBudget_distributesBudgetEvenlyIfPossible() { + long totalItemsAndBytes = 10L; + List streams = new ArrayList<>(); + for (int i = 0; i < totalItemsAndBytes; i++) { + streams.add(spy(createWindmillStreamSender(GetWorkBudget.noBudget()))); + } + createBudgetDistributor(0L) + .distributeBudget( + ImmutableList.copyOf(streams), + GetWorkBudget.builder() + .setItems(totalItemsAndBytes) + .setBytes(totalItemsAndBytes) + .build()); + + long itemsAndBytesPerStream = totalItemsAndBytes / streams.size(); + streams.forEach( + stream -> + verify(stream, times(1)) + .adjustBudget(eq(itemsAndBytesPerStream), eq(itemsAndBytesPerStream))); + } + + @Test + public void testDistributeBudget_distributesFairlyWhenNotEven() { + long totalItemsAndBytes = 10L; + List streams = new ArrayList<>(); + for (int i = 0; i < 3; i++) { + streams.add(spy(createWindmillStreamSender(GetWorkBudget.noBudget()))); + } + createBudgetDistributor(0L) + .distributeBudget( + ImmutableList.copyOf(streams), + GetWorkBudget.builder() + .setItems(totalItemsAndBytes) + .setBytes(totalItemsAndBytes) + .build()); + + long itemsAndBytesPerStream = (long) Math.ceil(totalItemsAndBytes / (streams.size() * 1.0)); + streams.forEach( + stream -> + verify(stream, times(1)) + .adjustBudget(eq(itemsAndBytesPerStream), eq(itemsAndBytesPerStream))); + } + + private WindmillStreamSender createWindmillStreamSender(GetWorkBudget getWorkBudget) { + return WindmillStreamSender.create( + stub, + Windmill.GetWorkRequest.newBuilder() + .setClientId(1L) + .setJobId("job") + .setProjectId("project") + .build(), + getWorkBudget, + GrpcWindmillStreamFactory.of( + JobHeader.newBuilder() + .setJobId("job") + .setProjectId("project") + .setWorkerId("worker") + .build()) + .build(), + (computation, + inputDataWatermark, + synchronizedProcessingTime, + 
workItem, + ackQueuedWorkItem, + getWorkStreamLatencies) -> {}); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetRefresherTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetRefresherTest.java new file mode 100644 index 000000000000..fd85410cc91d --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetRefresherTest.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.work.budget; + +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mockito; + +@RunWith(JUnit4.class) +public class GetWorkBudgetRefresherTest { + private static final int WAIT_BUFFER = 10; + private final Runnable redistributeBudget = Mockito.mock(Runnable.class); + + private GetWorkBudgetRefresher createBudgetRefresher() { + return createBudgetRefresher(false); + } + + private GetWorkBudgetRefresher createBudgetRefresher(Boolean isBudgetRefreshPaused) { + return new GetWorkBudgetRefresher(() -> isBudgetRefreshPaused, redistributeBudget); + } + + @Test + public void testStop_successfullyTerminates() throws InterruptedException { + GetWorkBudgetRefresher budgetRefresher = createBudgetRefresher(); + budgetRefresher.start(); + budgetRefresher.stop(); + budgetRefresher.requestBudgetRefresh(); + Thread.sleep(WAIT_BUFFER); + verifyNoInteractions(redistributeBudget); + } + + @Test + public void testRequestBudgetRefresh_triggersBudgetRefresh() throws InterruptedException { + GetWorkBudgetRefresher budgetRefresher = createBudgetRefresher(); + budgetRefresher.start(); + budgetRefresher.requestBudgetRefresh(); + // Wait a bit for redistribute budget to run. 
+ Thread.sleep(WAIT_BUFFER); + verify(redistributeBudget, times(1)).run(); + } + + @Test + public void testScheduledBudgetRefresh() throws InterruptedException { + GetWorkBudgetRefresher budgetRefresher = createBudgetRefresher(); + budgetRefresher.start(); + Thread.sleep(GetWorkBudgetRefresher.SCHEDULED_BUDGET_REFRESH_MILLIS + WAIT_BUFFER); + verify(redistributeBudget, times(1)).run(); + } + + @Test + public void testTriggeredAndScheduledBudgetRefresh_concurrent() throws InterruptedException { + GetWorkBudgetRefresher budgetRefresher = createBudgetRefresher(); + budgetRefresher.start(); + Thread budgetRefreshTriggerThread = new Thread(budgetRefresher::requestBudgetRefresh); + budgetRefreshTriggerThread.start(); + Thread.sleep(GetWorkBudgetRefresher.SCHEDULED_BUDGET_REFRESH_MILLIS + WAIT_BUFFER); + budgetRefreshTriggerThread.join(); + + // Wait a bit for redistribute budget to run. + Thread.sleep(WAIT_BUFFER); + verify(redistributeBudget, times(2)).run(); + } + + @Test + public void testTriggeredBudgetRefresh_doesNotRunWhenBudgetRefreshPaused() + throws InterruptedException { + GetWorkBudgetRefresher budgetRefresher = createBudgetRefresher(true); + budgetRefresher.start(); + budgetRefresher.requestBudgetRefresh(); + Thread.sleep(WAIT_BUFFER); + verifyNoInteractions(redistributeBudget); + } + + @Test + public void testScheduledBudgetRefresh_doesNotRunWhenBudgetRefreshPaused() + throws InterruptedException { + GetWorkBudgetRefresher budgetRefresher = createBudgetRefresher(true); + budgetRefresher.start(); + Thread.sleep(GetWorkBudgetRefresher.SCHEDULED_BUDGET_REFRESH_MILLIS + WAIT_BUFFER); + verifyNoInteractions(redistributeBudget); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetTest.java index 76d508397850..97789abaaa97 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetTest.java @@ -18,7 +18,6 @@ package org.apache.beam.runners.dataflow.worker.windmill.work.budget; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertThrows; import org.junit.Test; import org.junit.runner.RunWith; @@ -42,31 +41,10 @@ public void testBuild_itemsAndBytesNeverBelowZero() { } @Test - public void testAdd_doesNotAllowNegativeParameters() { + public void testApply_itemsAndBytesNeverBelowZero() { GetWorkBudget getWorkBudget = GetWorkBudget.builder().setItems(1).setBytes(1).build(); - assertThrows(IllegalArgumentException.class, () -> getWorkBudget.add(-1, -1)); - } - - @Test - public void testSubtract_itemsAndBytesNeverBelowZero() { - GetWorkBudget getWorkBudget = GetWorkBudget.builder().setItems(1).setBytes(1).build(); - GetWorkBudget subtracted = getWorkBudget.subtract(10, 10); - assertEquals(0, subtracted.items()); - assertEquals(0, subtracted.bytes()); - } - - @Test - public void testSubtractGetWorkBudget_itemsAndBytesNeverBelowZero() { - GetWorkBudget getWorkBudget = GetWorkBudget.builder().setItems(1).setBytes(1).build(); - GetWorkBudget subtracted = - getWorkBudget.subtract(GetWorkBudget.builder().setItems(10).setBytes(10).build()); + GetWorkBudget subtracted = getWorkBudget.apply(-10, -10); 
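+ // A negative delta clamps the resulting budget at zero rather than producing a negative value.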
assertEquals(0, subtracted.items()); assertEquals(0, subtracted.bytes()); } - - @Test - public void testSubtract_doesNotAllowNegativeParameters() { - GetWorkBudget getWorkBudget = GetWorkBudget.builder().setItems(1).setBytes(1).build(); - assertThrows(IllegalArgumentException.class, () -> getWorkBudget.subtract(-1, -1)); - } }