From 3152fac7bbcc0ad4a28d4547ca42eb1f294bb30e Mon Sep 17 00:00:00 2001 From: Martin Trieu Date: Tue, 25 Jul 2023 18:47:23 -0700 Subject: [PATCH] add getworkermetadata streaming rpc --- .../worker/build.gradle | 135 +++---- .../windmill/AbstractWindmillStream.java | 17 +- .../ForwardingClientResponseObserver.java | 14 +- .../windmill/StreamObserverFactory.java | 18 +- .../worker/windmill/WindmillEndpoints.java | 221 ++++++++++++ .../windmill/WindmillServiceAddress.java | 45 +++ .../worker/windmill/WindmillStream.java | 4 + .../grpcclient/GrpcCommitWorkStream.java | 31 +- .../grpcclient/GrpcGetDataStream.java | 41 +-- .../grpcclient/GrpcGetWorkStream.java | 25 +- .../GrpcGetWorkerMetadataStream.java | 170 +++++++++ .../grpcclient/GrpcWindmillServer.java | 40 ++- .../GrpcGetWorkerMetadataStreamTest.java | 328 ++++++++++++++++++ .../windmill/src/main/proto/windmill.proto | 11 +- .../src/main/proto/windmill_service.proto | 2 +- 15 files changed, 948 insertions(+), 154 deletions(-) create mode 100644 runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java create mode 100644 runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServiceAddress.java create mode 100644 runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetWorkerMetadataStream.java create mode 100644 runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetWorkerMetadataStreamTest.java diff --git a/runners/google-cloud-dataflow-java/worker/build.gradle b/runners/google-cloud-dataflow-java/worker/build.gradle index e1448e313c60..ce06063c9b52 100644 --- a/runners/google-cloud-dataflow-java/worker/build.gradle +++ b/runners/google-cloud-dataflow-java/worker/build.gradle @@ -67,90 +67,91 @@ def excluded_dependencies = [ library.java.error_prone_annotations, // Provided scope added in worker library.java.hamcrest, // Test only library.java.junit, // Test only - library.java.jsonassert // Test only + library.java.jsonassert, // Test only + library.java.truth // Test only ] applyJavaNature( automaticModuleName: 'org.apache.beam.runners.dataflow.worker', archivesBaseName: 'beam-runners-google-cloud-dataflow-java-legacy-worker', classesTriggerCheckerBugs: [ - 'BatchGroupAlsoByWindowAndCombineFn': 'TODO: file a bug report', - 'AssignWindowsParDoFnFactory': 'TODO: file a bug report', - 'FetchAndFilterStreamingSideInputsOperation': 'https://github.com/typetools/checker-framework/issues/5436', + 'BatchGroupAlsoByWindowAndCombineFn' : 'TODO: file a bug report', + 'AssignWindowsParDoFnFactory' : 'TODO: file a bug report', + 'FetchAndFilterStreamingSideInputsOperation': 'https://github.com/typetools/checker-framework/issues/5436', ], exportJavadoc: false, enableSpotbugs: false /* TODO(BEAM-5658): enable spotbugs */, shadowJarValidationExcludes: [ - "org/apache/beam/runners/dataflow/worker/**", - "org/apache/beam/repackaged/beam_runners_google_cloud_dataflow_java_legacy_worker/**", - // TODO(https://github.com/apache/beam/issues/19114): Move DataflowRunnerHarness class under org.apache.beam.runners.dataflow.worker namespace - "com/google/cloud/dataflow/worker/DataflowRunnerHarness.class", - // Allow slf4j implementation worker for logging during pipeline execution - "org/slf4j/impl/**" + "org/apache/beam/runners/dataflow/worker/**", + 
"org/apache/beam/repackaged/beam_runners_google_cloud_dataflow_java_legacy_worker/**", + // TODO(https://github.com/apache/beam/issues/19114): Move DataflowRunnerHarness class under org.apache.beam.runners.dataflow.worker namespace + "com/google/cloud/dataflow/worker/DataflowRunnerHarness.class", + // Allow slf4j implementation worker for logging during pipeline execution + "org/slf4j/impl/**" ], shadowClosure: { - // Each included dependency must also include all of its necessary transitive dependencies - // or have them provided by the users pipeline during job submission. Typically a users - // pipeline includes :runners:google-cloud-dataflow-java and its transitive dependencies - // so those dependencies don't need to be shaded (bundled and relocated) away. All other - // dependencies needed to run the worker must be shaded (bundled and relocated) to prevent - // ClassNotFound and/or MethodNotFound errors during pipeline execution. - // - // Each included dependency should have a matching relocation rule below that ensures - // that the shaded jar is correctly built. + // Each included dependency must also include all of its necessary transitive dependencies + // or have them provided by the users pipeline during job submission. Typically a users + // pipeline includes :runners:google-cloud-dataflow-java and its transitive dependencies + // so those dependencies don't need to be shaded (bundled and relocated) away. All other + // dependencies needed to run the worker must be shaded (bundled and relocated) to prevent + // ClassNotFound and/or MethodNotFound errors during pipeline execution. + // + // Each included dependency should have a matching relocation rule below that ensures + // that the shaded jar is correctly built. - dependencies { - include(dependency(library.java.slf4j_jdk14)) - } + dependencies { + include(dependency(library.java.slf4j_jdk14)) + } - dependencies { - include(project(path: ":model:fn-execution", configuration: "shadow")) - } - relocate("org.apache.beam.model.fnexecution.v1", getWorkerRelocatedPath("org.apache.beam.model.fnexecution.v1")) + dependencies { + include(project(path: ":model:fn-execution", configuration: "shadow")) + } + relocate("org.apache.beam.model.fnexecution.v1", getWorkerRelocatedPath("org.apache.beam.model.fnexecution.v1")) - dependencies { - include(project(":runners:core-construction-java")) - include(project(":runners:core-java")) - } - relocate("org.apache.beam.runners.core", getWorkerRelocatedPath("org.apache.beam.runners.core")) - relocate("org.apache.beam.repackaged.beam_runners_core_construction_java", getWorkerRelocatedPath("org.apache.beam.repackaged.beam_runners_core_construction_java")) - relocate("org.apache.beam.repackaged.beam_runners_core_java", getWorkerRelocatedPath("org.apache.beam.repackaged.beam_runners_core_java")) + dependencies { + include(project(":runners:core-construction-java")) + include(project(":runners:core-java")) + } + relocate("org.apache.beam.runners.core", getWorkerRelocatedPath("org.apache.beam.runners.core")) + relocate("org.apache.beam.repackaged.beam_runners_core_construction_java", getWorkerRelocatedPath("org.apache.beam.repackaged.beam_runners_core_construction_java")) + relocate("org.apache.beam.repackaged.beam_runners_core_java", getWorkerRelocatedPath("org.apache.beam.repackaged.beam_runners_core_java")) - dependencies { - include(project(":runners:java-fn-execution")) - } - relocate("org.apache.beam.runners.fnexecution", getWorkerRelocatedPath("org.apache.beam.runners.fnexecution")) - 
relocate("org.apache.beam.repackaged.beam_runners_java_fn_execution", getWorkerRelocatedPath("org.apache.beam.repackaged.beam_runners_java_fn_execution")) + dependencies { + include(project(":runners:java-fn-execution")) + } + relocate("org.apache.beam.runners.fnexecution", getWorkerRelocatedPath("org.apache.beam.runners.fnexecution")) + relocate("org.apache.beam.repackaged.beam_runners_java_fn_execution", getWorkerRelocatedPath("org.apache.beam.repackaged.beam_runners_java_fn_execution")) - dependencies { - include(project(":sdks:java:fn-execution")) - } - relocate("org.apache.beam.sdk.fn", getWorkerRelocatedPath("org.apache.beam.sdk.fn")) - relocate("org.apache.beam.repackaged.beam_sdks_java_fn_execution", getWorkerRelocatedPath("org.apache.beam.repackaged.beam_sdks_java_fn_execution")) + dependencies { + include(project(":sdks:java:fn-execution")) + } + relocate("org.apache.beam.sdk.fn", getWorkerRelocatedPath("org.apache.beam.sdk.fn")) + relocate("org.apache.beam.repackaged.beam_sdks_java_fn_execution", getWorkerRelocatedPath("org.apache.beam.repackaged.beam_sdks_java_fn_execution")) - dependencies { - // We have to include jetty-server/jetty-servlet and all of its transitive dependencies - // which includes several org.eclipse.jetty artifacts + servlet-api - include(dependency("org.eclipse.jetty:.*:9.2.10.v20150310")) - include(dependency("javax.servlet:javax.servlet-api:3.1.0")) - } - relocate("org.eclipse.jetty", getWorkerRelocatedPath("org.eclipse.jetty")) - relocate("javax.servlet", getWorkerRelocatedPath("javax.servlet")) + dependencies { + // We have to include jetty-server/jetty-servlet and all of its transitive dependencies + // which includes several org.eclipse.jetty artifacts + servlet-api + include(dependency("org.eclipse.jetty:.*:9.2.10.v20150310")) + include(dependency("javax.servlet:javax.servlet-api:3.1.0")) + } + relocate("org.eclipse.jetty", getWorkerRelocatedPath("org.eclipse.jetty")) + relocate("javax.servlet", getWorkerRelocatedPath("javax.servlet")) - // We don't relocate windmill since it is already underneath the org.apache.beam.runners.dataflow.worker namespace and never - // expect a user pipeline to include it. There is also a JNI component that windmill server relies on which makes - // arbitrary relocation more difficult. - dependencies { - include(project(path: ":runners:google-cloud-dataflow-java:worker:windmill", configuration: "shadow")) - } + // We don't relocate windmill since it is already underneath the org.apache.beam.runners.dataflow.worker namespace and never + // expect a user pipeline to include it. There is also a JNI component that windmill server relies on which makes + // arbitrary relocation more difficult. + dependencies { + include(project(path: ":runners:google-cloud-dataflow-java:worker:windmill", configuration: "shadow")) + } - // Include original source files extracted under - // '$buildDir/original_sources_to_package' to jar - from "$buildDir/original_sources_to_package" + // Include original source files extracted under + // '$buildDir/original_sources_to_package' to jar + from "$buildDir/original_sources_to_package" - exclude "META-INF/LICENSE.txt" - exclude "about.html" -}) + exclude "META-INF/LICENSE.txt" + exclude "about.html" + }) /******************************************************************************/ // Configure the worker root project @@ -219,6 +220,10 @@ dependencies { // as well and placed within the testImplementation configuration. Otherwise we can place it within // the shadowTest configuration. 
testImplementation project(path: ":runners:core-java", configuration: "testRuntimeMigration") + // TODO: excluding Guava until Truth updates it to >32.1.x + testImplementation(library.java.truth) { + exclude group: 'com.google.guava', module: 'guava' + } shadowTest project(path: ":sdks:java:extensions:google-cloud-platform-core", configuration: "testRuntimeMigration") shadowTest project(path: ":runners:direct-java", configuration: "shadow") shadowTest project(path: ":sdks:java:harness", configuration: "shadowTest") @@ -232,8 +237,8 @@ dependencies { project.task('validateShadedJarContainsSlf4jJdk14', dependsOn: 'shadowJar') { ext.outFile = project.file("${project.reportsDir}/${name}.out") inputs.files(project.configurations.shadow.artifacts.files) - .withPropertyName("shadowArtifactsFiles") - .withPathSensitivity(PathSensitivity.RELATIVE) + .withPropertyName("shadowArtifactsFiles") + .withPathSensitivity(PathSensitivity.RELATIVE) outputs.files outFile doLast { project.configurations.shadow.artifacts.files.each { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/AbstractWindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/AbstractWindmillStream.java index d3e7de58931f..ea7efff7a06d 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/AbstractWindmillStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/AbstractWindmillStream.java @@ -60,21 +60,16 @@ * synchronizing on this. */ public abstract class AbstractWindmillStream implements WindmillStream { - protected static final long DEFAULT_STREAM_RPC_DEADLINE_SECONDS = 300; + public static final long DEFAULT_STREAM_RPC_DEADLINE_SECONDS = 300; // Default gRPC streams to 2MB chunks, which has shown to be a large enough chunk size to reduce // per-chunk overhead, and small enough that we can still perform granular flow-control. protected static final int RPC_STREAM_CHUNK_SIZE = 2 << 20; - private static final Logger LOG = LoggerFactory.getLogger(AbstractWindmillStream.class); - protected final AtomicBoolean clientClosed; - + private final AtomicLong lastSendTimeMs; private final Executor executor; private final BackOff backoff; - // Indicates if the current stream in requestObserver is closed by calling close() method - private final AtomicBoolean streamClosed; private final AtomicLong startTimeMs; - private final AtomicLong lastSendTimeMs; private final AtomicLong lastResponseTimeMs; private final AtomicInteger errorCount; private final AtomicReference lastError; @@ -83,6 +78,8 @@ public abstract class AbstractWindmillStream implements Win private final Set> streamRegistry; private final int logEveryNStreamFailures; private final Supplier> requestObserverSupplier; + // Indicates if the current stream in requestObserver is closed by calling close() method + private final AtomicBoolean streamClosed; private @Nullable StreamObserver requestObserver; protected AbstractWindmillStream( @@ -132,9 +129,9 @@ private static long debugDuration(long nowMs, long startMs) { protected abstract boolean hasPendingRequests(); /** - * Called when the stream is throttled due to resource exhausted errors. Will be called for each - * resource exhausted error not just the first. onResponse() must stop throttling on receipt of - * the first good message. 
+ * Called when the client side stream is throttled due to resource exhausted errors. Will be + * called for each resource exhausted error not just the first. onResponse() must stop throttling + * on receipt of the first good message. */ protected abstract void startThrottleTimer(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/ForwardingClientResponseObserver.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/ForwardingClientResponseObserver.java index 3737e29efb13..a1f80598d89a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/ForwardingClientResponseObserver.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/ForwardingClientResponseObserver.java @@ -27,23 +27,23 @@ *
<p>
Used to wrap existing {@link StreamObserver}s to be able to install an {@link * ClientCallStreamObserver#setOnReadyHandler(Runnable) onReadyHandler}. * - *
<p>
This is as thread-safe as the undering stream observer that is being wrapped. + *
<p>
This is as thread-safe as the underlying stream observer that is being wrapped. */ -final class ForwardingClientResponseObserver - implements ClientResponseObserver { +final class ForwardingClientResponseObserver + implements ClientResponseObserver { private final Runnable onReadyHandler; private final Runnable onDoneHandler; - private final StreamObserver inboundObserver; + private final StreamObserver inboundObserver; ForwardingClientResponseObserver( - StreamObserver inboundObserver, Runnable onReadyHandler, Runnable onDoneHandler) { + StreamObserver inboundObserver, Runnable onReadyHandler, Runnable onDoneHandler) { this.inboundObserver = inboundObserver; this.onReadyHandler = onReadyHandler; this.onDoneHandler = onDoneHandler; } @Override - public void onNext(ReqT value) { + public void onNext(ResponseT value) { inboundObserver.onNext(value); } @@ -60,7 +60,7 @@ public void onCompleted() { } @Override - public void beforeStart(ClientCallStreamObserver stream) { + public void beforeStart(ClientCallStreamObserver stream) { stream.setOnReadyHandler(onReadyHandler); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/StreamObserverFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/StreamObserverFactory.java index a046f2fd46ac..e0878b7b0b91 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/StreamObserverFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/StreamObserverFactory.java @@ -33,9 +33,9 @@ public static StreamObserverFactory direct( return new Direct(deadlineSeconds, messagesBetweenIsReadyChecks); } - public abstract StreamObserver from( - Function, StreamObserver> clientFactory, - StreamObserver responseObserver); + public abstract StreamObserver from( + Function, StreamObserver> clientFactory, + StreamObserver responseObserver); private static class Direct extends StreamObserverFactory { private final long deadlineSeconds; @@ -47,14 +47,14 @@ private static class Direct extends StreamObserverFactory { } @Override - public StreamObserver from( - Function, StreamObserver> clientFactory, - StreamObserver inboundObserver) { + public StreamObserver from( + Function, StreamObserver> clientFactory, + StreamObserver inboundObserver) { AdvancingPhaser phaser = new AdvancingPhaser(1); - CallStreamObserver outboundObserver = - (CallStreamObserver) + CallStreamObserver outboundObserver = + (CallStreamObserver) clientFactory.apply( - new ForwardingClientResponseObserver( + new ForwardingClientResponseObserver( inboundObserver, phaser::arrive, phaser::forceTermination)); return new DirectStreamObserver<>( phaser, outboundObserver, deadlineSeconds, messagesBetweenIsReadyChecks); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java new file mode 100644 index 000000000000..64b6e675ef5f --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java @@ -0,0 +1,221 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill; + +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap; + +import com.google.auto.value.AutoValue; +import java.net.Inet6Address; +import java.net.InetAddress; +import java.net.UnknownHostException; +import java.util.Map; +import java.util.Optional; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Value class for holding endpoints used for communicating with Windmill service. Corresponds + * directly with {@link Windmill.WorkerMetadataResponse}. + */ +@AutoValue +public abstract class WindmillEndpoints { + private static final Logger LOG = LoggerFactory.getLogger(WindmillEndpoints.class); + + /** + * Used by GetData GlobalDataRequest(s) to support Beam side inputs. Returns a map where the key + * is a global data tag and the value is the endpoint where the data associated with the global + * data tag resides. + * + * @see Beam Side + * Inputs + */ + public abstract ImmutableMap globalDataEndpoints(); + + /** + * Used by GetWork/GetData/CommitWork calls to send, receive, and commit work directly to/from + * Windmill servers. Returns a list of endpoints used to communicate with the corresponding + * Windmill servers. + */ + public abstract ImmutableList windmillEndpoints(); + + public static WindmillEndpoints from( + Windmill.WorkerMetadataResponse workerMetadataResponseProto) { + ImmutableMap globalDataServers = + workerMetadataResponseProto.getGlobalDataEndpointsMap().entrySet().stream() + .collect( + toImmutableMap( + Map.Entry::getKey, // global data key + endpoint -> WindmillEndpoints.Endpoint.from(endpoint.getValue()))); + + ImmutableList windmillServers = + workerMetadataResponseProto.getWorkEndpointsList().stream() + .map(WindmillEndpoints.Endpoint::from) + .collect(toImmutableList()); + + return WindmillEndpoints.builder() + .setGlobalDataEndpoints(globalDataServers) + .setWindmillEndpoints(windmillServers) + .build(); + } + + public static WindmillEndpoints.Builder builder() { + return new AutoValue_WindmillEndpoints.Builder(); + } + + /** + * Representation of an endpoint in {@link Windmill.WorkerMetadataResponse.Endpoint} proto with + * the worker_token field, and direct_endpoint field parsed into a {@link WindmillServiceAddress} + * which holds either a {@link Inet6Address} or {@link HostAndPort} used to connect to Streaming + * Engine. 
{@link Inet6Address}(s) represent direct Windmill worker connections, and {@link + * HostAndPort}(s) represent connections to the Windmill Dispatcher. + */ + @AutoValue + public abstract static class Endpoint { + /** + * {@link WindmillServiceAddress} representation of {@link + * Windmill.WorkerMetadataResponse.Endpoint#getDirectEndpoint()}. The proto's direct_endpoint + * string can be converted to either {@link Inet6Address} or {@link HostAndPort}. + */ + public abstract Optional directEndpoint(); + + /** + * Corresponds to {@link Windmill.WorkerMetadataResponse.Endpoint#getWorkerToken()} in the + * windmill.proto file. + */ + public abstract Optional workerToken(); + + public static Endpoint.Builder builder() { + return new AutoValue_WindmillEndpoints_Endpoint.Builder(); + } + + public static Endpoint from(Windmill.WorkerMetadataResponse.Endpoint endpointProto) { + Endpoint.Builder endpointBuilder = Endpoint.builder(); + if (endpointProto.hasDirectEndpoint() && !endpointProto.getDirectEndpoint().isEmpty()) { + parseDirectEndpoint(endpointProto.getDirectEndpoint()) + .ifPresent(endpointBuilder::setDirectEndpoint); + } + if (endpointProto.hasWorkerToken() && !endpointProto.getWorkerToken().isEmpty()) { + endpointBuilder.setWorkerToken(endpointProto.getWorkerToken()); + } + + Endpoint endpoint = endpointBuilder.build(); + + if (!endpoint.directEndpoint().isPresent() && !endpoint.workerToken().isPresent()) { + throw new IllegalArgumentException( + String.format( + "direct_endpoint=[%s] not present or could not be parsed, and worker_token" + + " not present. At least one of these fields is required.", + endpointProto.getDirectEndpoint())); + } + + return endpoint; + } + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder setDirectEndpoint(WindmillServiceAddress directEndpoint); + + public abstract Builder setWorkerToken(String workerToken); + + public abstract Endpoint build(); + } + } + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder setGlobalDataEndpoints( + ImmutableMap globalDataServers); + + public abstract Builder setWindmillEndpoints( + ImmutableList windmillServers); + + abstract ImmutableList.Builder windmillEndpointsBuilder(); + + public final Builder addWindmillEndpoint(WindmillEndpoints.Endpoint endpoint) { + windmillEndpointsBuilder().add(endpoint); + return this; + } + + public final Builder addAllWindmillEndpoints(Iterable endpoints) { + windmillEndpointsBuilder().addAll(endpoints); + return this; + } + + abstract ImmutableMap.Builder globalDataEndpointsBuilder(); + + public final Builder addGlobalDataEndpoint( + String globalDataKey, WindmillEndpoints.Endpoint endpoint) { + globalDataEndpointsBuilder().put(globalDataKey, endpoint); + return this; + } + + public final Builder addAllGlobalDataEndpoints( + Map globalDataEndpoints) { + globalDataEndpointsBuilder().putAll(globalDataEndpoints); + return this; + } + + public abstract WindmillEndpoints build(); + } + + private static Optional parseDirectEndpoint(String directEndpoint) { + Optional directEndpointIpV6Address = + tryParseDirectEndpointIntoIpV6Address(directEndpoint).map(WindmillServiceAddress::create); + + return directEndpointIpV6Address.isPresent() + ? 
directEndpointIpV6Address + : tryParseEndpointIntoHostAndPort(directEndpoint).map(WindmillServiceAddress::create); + } + + private static Optional tryParseEndpointIntoHostAndPort(String directEndpoint) { + try { + return Optional.of(HostAndPort.fromString(directEndpoint)); + } catch (IllegalArgumentException e) { + LOG.warn("{} cannot be parsed into a gcpServiceAddress", directEndpoint); + return Optional.empty(); + } + } + + private static Optional tryParseDirectEndpointIntoIpV6Address( + String directEndpoint) { + InetAddress directEndpointAddress = null; + try { + directEndpointAddress = Inet6Address.getByName(directEndpoint); + } catch (UnknownHostException e) { + LOG.warn( + "Error occurred trying to parse direct_endpoint={} into IPv6 address. Exception={}", + directEndpoint, + e.toString()); + } + + // Inet6Address.getByAddress returns either an IPv4 or an IPv6 address depending on the format + // of the direct_endpoint string. + if (!(directEndpointAddress instanceof Inet6Address)) { + LOG.warn( + "{} is not an IPv6 address. Direct endpoints are expected to be in IPv6 format.", + directEndpoint); + return Optional.empty(); + } + + return Optional.ofNullable((Inet6Address) directEndpointAddress); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServiceAddress.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServiceAddress.java new file mode 100644 index 000000000000..3ebda8fab8ed --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServiceAddress.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill; + +import com.google.auto.value.AutoOneOf; +import java.net.Inet6Address; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; + +/** Used to create channels to communicate with Streaming Engine via gRpc. 
*/ +@AutoOneOf(WindmillServiceAddress.Kind.class) +public abstract class WindmillServiceAddress { + public static WindmillServiceAddress create(Inet6Address ipv6Address) { + return AutoOneOf_WindmillServiceAddress.ipv6(ipv6Address); + } + + public static WindmillServiceAddress create(HostAndPort gcpServiceAddress) { + return AutoOneOf_WindmillServiceAddress.gcpServiceAddress(gcpServiceAddress); + } + + public abstract Kind getKind(); + + public abstract Inet6Address ipv6(); + + public abstract HostAndPort gcpServiceAddress(); + + public enum Kind { + IPV6, + GCP_SERVICE_ADDRESS + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillStream.java index 70c7cc36ba31..4dd4164fc4ef 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillStream.java @@ -86,4 +86,8 @@ boolean commitWorkItem( /** Flushes any pending work items to the wire. */ void flush(); } + + /** Interface for streaming GetWorkerMetadata requests to Windmill. */ + @ThreadSafe + interface GetWorkerMetadataStream extends WindmillStream {} } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcCommitWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcCommitWorkStream.java index 74bd93a5474f..1bba40805dec 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcCommitWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcCommitWorkStream.java @@ -17,16 +17,17 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.grpcclient; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; + import java.io.PrintWriter; import java.util.HashMap; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; +import java.util.function.Function; import org.apache.beam.runners.dataflow.worker.windmill.AbstractWindmillStream; -import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; import org.apache.beam.runners.dataflow.worker.windmill.StreamObserverFactory; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; @@ -37,7 +38,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.CommitWorkStream; import org.apache.beam.sdk.util.BackOff; import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.stub.StreamObserver; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -56,7 +57,8 @@ final class GrpcCommitWorkStream private final int streamingRpcBatchLimit; private GrpcCommitWorkStream( - 
CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub stub, + Function, StreamObserver> + startCommitWorkRpcFn, BackOff backoff, StreamObserverFactory streamObserverFactory, Set> streamRegistry, @@ -66,10 +68,7 @@ private GrpcCommitWorkStream( AtomicLong idGenerator, int streamingRpcBatchLimit) { super( - responseObserver -> - stub.withDeadlineAfter( - AbstractWindmillStream.DEFAULT_STREAM_RPC_DEADLINE_SECONDS, TimeUnit.SECONDS) - .commitWorkStream(responseObserver), + startCommitWorkRpcFn, backoff, streamObserverFactory, streamRegistry, @@ -83,7 +82,8 @@ private GrpcCommitWorkStream( } static GrpcCommitWorkStream create( - CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub stub, + Function, StreamObserver> + startCommitWorkRpcFn, BackOff backoff, StreamObserverFactory streamObserverFactory, Set> streamRegistry, @@ -94,7 +94,7 @@ static GrpcCommitWorkStream create( int streamingRpcBatchLimit) { GrpcCommitWorkStream commitWorkStream = new GrpcCommitWorkStream( - stub, + startCommitWorkRpcFn, backoff, streamObserverFactory, streamRegistry, @@ -252,7 +252,7 @@ private void issueBatchedRequest(Map requests) { } private void issueMultiChunkRequest(final long id, PendingRequest pendingRequest) { - Preconditions.checkNotNull(pendingRequest.computation); + checkNotNull(pendingRequest.computation); final ByteString serializedCommit = pendingRequest.request.toByteString(); synchronized (this) { @@ -306,8 +306,13 @@ long getBytes() { private class Batcher { - final Map queue = new HashMap<>(); - long queuedBytes = 0; + private final Map queue; + private long queuedBytes; + + private Batcher() { + this.queuedBytes = 0; + this.queue = new HashMap<>(); + } boolean canAccept(PendingRequest request) { return queue.isEmpty() diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetDataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetDataStream.java index b51daabb1a2b..238cc771dce8 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetDataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetDataStream.java @@ -17,6 +17,9 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.grpcclient; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Verify.verify; + import java.io.IOException; import java.io.InputStream; import java.io.PrintWriter; @@ -28,10 +31,9 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedDeque; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Function; import org.apache.beam.runners.dataflow.worker.windmill.AbstractWindmillStream; -import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; import org.apache.beam.runners.dataflow.worker.windmill.StreamObserverFactory; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalData; @@ -45,8 +47,7 @@ import 
org.apache.beam.runners.dataflow.worker.windmill.grpcclient.GrpcGetDataStreamRequests.QueuedBatch; import org.apache.beam.runners.dataflow.worker.windmill.grpcclient.GrpcGetDataStreamRequests.QueuedRequest; import org.apache.beam.sdk.util.BackOff; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Verify; +import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.stub.StreamObserver; import org.joda.time.Instant; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -64,7 +65,8 @@ final class GrpcGetDataStream private final int streamingRpcBatchLimit; private GrpcGetDataStream( - CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub stub, + Function, StreamObserver> + startGetDataRpcFn, BackOff backoff, StreamObserverFactory streamObserverFactory, Set> streamRegistry, @@ -74,14 +76,7 @@ private GrpcGetDataStream( AtomicLong idGenerator, int streamingRpcBatchLimit) { super( - responseObserver -> - stub.withDeadlineAfter( - AbstractWindmillStream.DEFAULT_STREAM_RPC_DEADLINE_SECONDS, TimeUnit.SECONDS) - .getDataStream(responseObserver), - backoff, - streamObserverFactory, - streamRegistry, - logEveryNStreamFailures); + startGetDataRpcFn, backoff, streamObserverFactory, streamRegistry, logEveryNStreamFailures); this.idGenerator = idGenerator; this.getDataThrottleTimer = getDataThrottleTimer; this.jobHeader = jobHeader; @@ -91,7 +86,8 @@ private GrpcGetDataStream( } static GrpcGetDataStream create( - CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub stub, + Function, StreamObserver> + startGetDataRpcFn, BackOff backoff, StreamObserverFactory streamObserverFactory, Set> streamRegistry, @@ -102,7 +98,7 @@ static GrpcGetDataStream create( int streamingRpcBatchLimit) { GrpcGetDataStream getDataStream = new GrpcGetDataStream( - stub, + startGetDataRpcFn, backoff, streamObserverFactory, streamRegistry, @@ -122,7 +118,7 @@ protected synchronized void onNewStream() { // We rely on close only occurring after all methods on the stream have returned. // Since the requestKeyedData and requestGlobalData methods are blocking this // means there should be no pending requests. 
- Verify.verify(!hasPendingRequests()); + verify(!hasPendingRequests()); } else { for (AppendableInputStream responseStream : pending.values()) { responseStream.cancel(); @@ -138,14 +134,13 @@ protected boolean hasPendingRequests() { @Override @SuppressWarnings("dereference.of.nullable") protected void onResponse(StreamingGetDataResponse chunk) { - Preconditions.checkArgument(chunk.getRequestIdCount() == chunk.getSerializedResponseCount()); - Preconditions.checkArgument( - chunk.getRemainingBytesForResponse() == 0 || chunk.getRequestIdCount() == 1); + checkArgument(chunk.getRequestIdCount() == chunk.getSerializedResponseCount()); + checkArgument(chunk.getRemainingBytesForResponse() == 0 || chunk.getRequestIdCount() == 1); getDataThrottleTimer.stop(); for (int i = 0; i < chunk.getRequestIdCount(); ++i) { AppendableInputStream responseStream = pending.get(chunk.getRequestId(i)); - Verify.verify(responseStream != null, "No pending response stream"); + verify(responseStream != null, "No pending response stream"); responseStream.append(chunk.getSerializedResponse(i).newInput()); if (chunk.getRemainingBytesForResponse() == 0) { responseStream.complete(); @@ -283,12 +278,12 @@ private void queueRequestAndWait(QueuedRequest request) throws InterruptedExcept // Finalize the batch so that no additional requests will be added. Leave the batch in the // queue so that a subsequent batch will wait for it's completion. synchronized (batches) { - Verify.verify(batch == batches.peekFirst()); + verify(batch == batches.peekFirst()); batch.markFinalized(); } sendBatch(batch.requests()); synchronized (batches) { - Verify.verify(batch == batches.pollFirst()); + verify(batch == batches.pollFirst()); } // Notify all waiters with requests in this batch as well as the sender // of the next batch (if one exists). @@ -308,7 +303,7 @@ private void sendBatch(List requests) { for (QueuedRequest request : requests) { // Map#put returns null if there was no previous mapping for the key, meaning we have not // seen it before. 
- Verify.verify(pending.put(request.id(), request.getResponseStream()) == null); + verify(pending.put(request.id(), request.getResponseStream()) == null); } try { send(batchedRequest); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetWorkStream.java index 6e35beccdb6a..4660fe25b13b 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetWorkStream.java @@ -23,12 +23,11 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Function; import javax.annotation.Nullable; import org.apache.beam.runners.dataflow.worker.WindmillTimeUtils; import org.apache.beam.runners.dataflow.worker.windmill.AbstractWindmillStream; -import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; import org.apache.beam.runners.dataflow.worker.windmill.StreamObserverFactory; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest; @@ -40,6 +39,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetWorkStream.WorkItemReceiver; import org.apache.beam.sdk.util.BackOff; import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString; +import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.stub.StreamObserver; import org.joda.time.Instant; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -58,7 +58,10 @@ final class GrpcGetWorkStream private final AtomicLong inflightBytes; private GrpcGetWorkStream( - CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub stub, + Function< + StreamObserver, + StreamObserver> + startGetWorkRpcFn, GetWorkRequest request, BackOff backoff, StreamObserverFactory streamObserverFactory, @@ -67,14 +70,7 @@ private GrpcGetWorkStream( ThrottleTimer getWorkThrottleTimer, WorkItemReceiver receiver) { super( - responseObserver -> - stub.withDeadlineAfter( - AbstractWindmillStream.DEFAULT_STREAM_RPC_DEADLINE_SECONDS, TimeUnit.SECONDS) - .getWorkStream(responseObserver), - backoff, - streamObserverFactory, - streamRegistry, - logEveryNStreamFailures); + startGetWorkRpcFn, backoff, streamObserverFactory, streamRegistry, logEveryNStreamFailures); this.request = request; this.getWorkThrottleTimer = getWorkThrottleTimer; this.receiver = receiver; @@ -84,7 +80,10 @@ private GrpcGetWorkStream( } static GrpcGetWorkStream create( - CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub stub, + Function< + StreamObserver, + StreamObserver> + startGetWorkRpcFn, GetWorkRequest request, BackOff backoff, StreamObserverFactory streamObserverFactory, @@ -94,7 +93,7 @@ static GrpcGetWorkStream create( WorkItemReceiver receiver) { GrpcGetWorkStream getWorkStream = new GrpcGetWorkStream( - stub, + startGetWorkRpcFn, request, backoff, streamObserverFactory, diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetWorkerMetadataStream.java 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetWorkerMetadataStream.java new file mode 100644 index 000000000000..427fd412ec7f --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetWorkerMetadataStream.java @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.grpcclient; + +import com.google.errorprone.annotations.concurrent.GuardedBy; +import java.io.PrintWriter; +import java.util.Optional; +import java.util.Set; +import java.util.function.Consumer; +import java.util.function.Function; +import org.apache.beam.runners.dataflow.worker.windmill.AbstractWindmillStream; +import org.apache.beam.runners.dataflow.worker.windmill.StreamObserverFactory; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataRequest; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataResponse; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetWorkerMetadataStream; +import org.apache.beam.sdk.util.BackOff; +import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.stub.StreamObserver; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +final class GrpcGetWorkerMetadataStream + extends AbstractWindmillStream + implements GetWorkerMetadataStream { + private static final Logger LOG = LoggerFactory.getLogger(GrpcGetWorkerMetadataStream.class); + private static final WorkerMetadataRequest HEALTH_CHECK_REQUEST = + WorkerMetadataRequest.getDefaultInstance(); + private final WorkerMetadataRequest workerMetadataRequest; + private final ThrottleTimer getWorkerMetadataThrottleTimer; + private final Consumer serverMappingConsumer; + private final Object metadataLock; + + @GuardedBy("metadataLock") + private long metadataVersion; + + @GuardedBy("metadataLock") + private WorkerMetadataResponse latestResponse; + + private GrpcGetWorkerMetadataStream( + Function, StreamObserver> + startGetWorkerMetadataRpcFn, + BackOff backoff, + StreamObserverFactory streamObserverFactory, + Set> streamRegistry, + int logEveryNStreamFailures, + JobHeader jobHeader, + long metadataVersion, + ThrottleTimer getWorkerMetadataThrottleTimer, + Consumer serverMappingConsumer) { + super( + startGetWorkerMetadataRpcFn, + backoff, + streamObserverFactory, + streamRegistry, + logEveryNStreamFailures); + this.workerMetadataRequest = WorkerMetadataRequest.newBuilder().setHeader(jobHeader).build(); + this.metadataVersion = metadataVersion; + 
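+    // Responses whose metadata_version is less than or equal to this value are considered stale;
+    // extractWindmillEndpointsFrom() drops them without invoking serverMappingConsumer.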
this.getWorkerMetadataThrottleTimer = getWorkerMetadataThrottleTimer; + this.serverMappingConsumer = serverMappingConsumer; + this.latestResponse = WorkerMetadataResponse.getDefaultInstance(); + this.metadataLock = new Object(); + } + + public static GrpcGetWorkerMetadataStream create( + Function, StreamObserver> + startGetWorkerMetadataRpcFn, + BackOff backoff, + StreamObserverFactory streamObserverFactory, + Set> streamRegistry, + int logEveryNStreamFailures, + JobHeader jobHeader, + int metadataVersion, + ThrottleTimer getWorkerMetadataThrottleTimer, + Consumer serverMappingUpdater) { + GrpcGetWorkerMetadataStream getWorkerMetadataStream = + new GrpcGetWorkerMetadataStream( + startGetWorkerMetadataRpcFn, + backoff, + streamObserverFactory, + streamRegistry, + logEveryNStreamFailures, + jobHeader, + metadataVersion, + getWorkerMetadataThrottleTimer, + serverMappingUpdater); + getWorkerMetadataStream.startStream(); + return getWorkerMetadataStream; + } + + /** + * Each instance of {@link AbstractWindmillStream} owns its own responseObserver that calls + * onResponse(). + */ + @Override + protected void onResponse(WorkerMetadataResponse response) { + extractWindmillEndpointsFrom(response).ifPresent(serverMappingConsumer); + } + + /** + * Acquires the {@link #metadataLock} Returns {@link Optional} if the + * metadataVersion in the response is not stale (older or equal to {@link #metadataVersion}), else + * returns empty {@link Optional}. + */ + private Optional extractWindmillEndpointsFrom( + WorkerMetadataResponse response) { + synchronized (metadataLock) { + if (response.getMetadataVersion() > this.metadataVersion) { + this.metadataVersion = response.getMetadataVersion(); + this.latestResponse = response; + return Optional.of(WindmillEndpoints.from(response)); + } else { + // If the currentMetadataVersion is greater than or equal to one in the response, the + // response data is stale, and we do not want to do anything. + LOG.info( + "Received WorkerMetadataResponse={}; Received metadata version={}; Current metadata version={}. 
" + + "Skipping update because received stale metadata", + response, + response.getMetadataVersion(), + this.metadataVersion); + } + } + + return Optional.empty(); + } + + @Override + protected synchronized void onNewStream() { + send(workerMetadataRequest); + } + + @Override + protected boolean hasPendingRequests() { + return false; + } + + @Override + protected void startThrottleTimer() { + getWorkerMetadataThrottleTimer.start(); + } + + @Override + protected void sendHealthCheck() { + send(HEALTH_CHECK_REQUEST); + } + + @Override + protected void appendSpecificHtml(PrintWriter writer) { + synchronized (metadataLock) { + writer.format( + "GetWorkerMetadataStream: version=[%d] , job_header=[%s], latest_response=[%s]", + this.metadataVersion, workerMetadataRequest.getHeader(), this.latestResponse); + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcWindmillServer.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcWindmillServer.java index e8745e265eea..19cb90297df5 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcWindmillServer.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcWindmillServer.java @@ -107,7 +107,6 @@ public final class GrpcWindmillServer extends WindmillServerStub { private final ThrottleTimer commitWorkThrottleTimer; private final Random rand; private final Set> streamRegistry; - private ImmutableSet endpoints; private int logEveryNStreamFailures; private Duration maxBackoff = MAX_BACKOFF; @@ -301,14 +300,21 @@ private Channel remoteChannel(HostAndPort endpoint) throws IOException { .build(); } + /** + * Stubs returned from this method do not (and should not) have {@link + * org.apache.beam.vendor.grpc.v1p54p0.io.grpc.Deadline}(s) set since they represent an absolute + * point in time. {@link org.apache.beam.vendor.grpc.v1p54p0.io.grpc.Deadline}(s) should not be + * treated as a timeout which represents a relative point in time. + * + * @see Official gRPC deadline documentation for more + * details. + */ private synchronized CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub stub() { if (stubList.isEmpty()) { throw new RuntimeException("windmillServiceEndpoint has not been set"); } - if (stubList.size() == 1) { - return stubList.get(0); - } - return stubList.get(rand.nextInt(stubList.size())); + + return stubList.size() == 1 ? stubList.get(0) : stubList.get(rand.nextInt(stubList.size())); } @Override @@ -398,7 +404,13 @@ public GetWorkStream getWorkStream(GetWorkRequest request, WorkItemReceiver rece .build(); return GrpcGetWorkStream.create( - stub(), + responseObserver -> + stub() + // Deadlines are absolute points in time, so generate a new one everytime this + // function is called. + .withDeadlineAfter( + AbstractWindmillStream.DEFAULT_STREAM_RPC_DEADLINE_SECONDS, TimeUnit.SECONDS) + .getWorkStream(responseObserver), getWorkRequest, grpcBackoff(), newStreamObserverFactory(), @@ -411,7 +423,13 @@ public GetWorkStream getWorkStream(GetWorkRequest request, WorkItemReceiver rece @Override public GetDataStream getDataStream() { return GrpcGetDataStream.create( - stub(), + responseObserver -> + stub() + // Deadlines are absolute points in time, so generate a new one everytime this + // function is called. 
+ .withDeadlineAfter( + AbstractWindmillStream.DEFAULT_STREAM_RPC_DEADLINE_SECONDS, TimeUnit.SECONDS) + .getDataStream(responseObserver), grpcBackoff(), newStreamObserverFactory(), streamRegistry, @@ -425,7 +443,13 @@ public GetDataStream getDataStream() { @Override public CommitWorkStream commitWorkStream() { return GrpcCommitWorkStream.create( - stub(), + responseObserver -> + stub() + // Deadlines are absolute points in time, so generate a new one everytime this + // function is called. + .withDeadlineAfter( + AbstractWindmillStream.DEFAULT_STREAM_RPC_DEADLINE_SECONDS, TimeUnit.SECONDS) + .commitWorkStream(responseObserver), grpcBackoff(), newStreamObserverFactory(), streamRegistry, diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetWorkerMetadataStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetWorkerMetadataStreamTest.java new file mode 100644 index 000000000000..45ed3381a8bf --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetWorkerMetadataStreamTest.java @@ -0,0 +1,328 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.windmill.grpcclient; + +import static com.google.common.truth.Truth.assertThat; +import static org.apache.beam.runners.dataflow.worker.windmill.AbstractWindmillStream.DEFAULT_STREAM_RPC_DEADLINE_SECONDS; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +import java.io.IOException; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Consumer; +import java.util.stream.Collectors; +import javax.annotation.Nullable; +import org.apache.beam.runners.dataflow.worker.windmill.AbstractWindmillStream; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.StreamObserverFactory; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataRequest; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataResponse; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints; +import org.apache.beam.sdk.util.FluentBackoff; +import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.Server; +import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.inprocess.InProcessChannelBuilder; +import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.inprocess.InProcessServerBuilder; +import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.testing.GrpcCleanupRule; +import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.util.MutableHandlerRegistry; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mockito; + +@RunWith(JUnit4.class) +public class GrpcGetWorkerMetadataStreamTest { + private static final String IPV6_ADDRESS_1 = "2001:db8:0000:bac5:0000:0000:fed0:81a2"; + private static final String IPV6_ADDRESS_2 = "2001:db8:0000:bac5:0000:0000:fed0:82a3"; + private static final List DIRECT_PATH_ENDPOINTS = + Lists.newArrayList( + WorkerMetadataResponse.Endpoint.newBuilder() + .setDirectEndpoint(IPV6_ADDRESS_1) + .setWorkerToken("worker_token") + .build()); + private static final Map GLOBAL_DATA_ENDPOINTS = + Maps.newHashMap(); + private static final JobHeader TEST_JOB_HEADER = + JobHeader.newBuilder() + .setJobId("test_job") + .setWorkerId("test_worker") + .setProjectId("test_project") + .build(); + private static final String FAKE_SERVER_NAME = "Fake server for GrpcGetWorkerMetadataStreamTest"; + @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); + private final Set> streamRegistry = new HashSet<>(); + private ManagedChannel inProcessChannel; + private GrpcGetWorkerMetadataStream stream; + + private GrpcGetWorkerMetadataStream getWorkerMetadataTestStream( + GetWorkerMetadataTestStub getWorkerMetadataTestStub, + int metadataVersion, + Consumer endpointsConsumer) { + 
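+    // Registers the fake service with the in-process server before opening the stream under test.
+    // Stream creation also sends the initial WorkerMetadataRequest header via onNewStream().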
+    serviceRegistry.addService(getWorkerMetadataTestStub);
+    return GrpcGetWorkerMetadataStream.create(
+        responseObserver ->
+            CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel)
+                .getWorkerMetadataStream(responseObserver),
+        FluentBackoff.DEFAULT.backoff(),
+        StreamObserverFactory.direct(DEFAULT_STREAM_RPC_DEADLINE_SECONDS * 2, 1),
+        streamRegistry,
+        1, // logEveryNStreamFailures
+        TEST_JOB_HEADER,
+        metadataVersion,
+        new ThrottleTimer(),
+        endpointsConsumer);
+  }
+
+  @Before
+  public void setUp() throws IOException {
+    Server server =
+        InProcessServerBuilder.forName(FAKE_SERVER_NAME)
+            .fallbackHandlerRegistry(serviceRegistry)
+            .directExecutor()
+            .build()
+            .start();
+
+    inProcessChannel =
+        grpcCleanup.register(
+            InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build());
+    grpcCleanup.register(server);
+    grpcCleanup.register(inProcessChannel);
+    GLOBAL_DATA_ENDPOINTS.put(
+        "global_data",
+        WorkerMetadataResponse.Endpoint.newBuilder()
+            .setDirectEndpoint(IPV6_ADDRESS_1)
+            .setWorkerToken("worker_token")
+            .build());
+  }
+
+  @After
+  public void cleanUp() {
+    inProcessChannel.shutdownNow();
+  }
+
+  @Test
+  public void testGetWorkerMetadata() {
+    WorkerMetadataResponse mockResponse =
+        WorkerMetadataResponse.newBuilder()
+            .setMetadataVersion(1)
+            .addAllWorkEndpoints(DIRECT_PATH_ENDPOINTS)
+            .putAllGlobalDataEndpoints(GLOBAL_DATA_ENDPOINTS)
+            .build();
+    TestWindmillEndpointsConsumer testWindmillEndpointsConsumer =
+        new TestWindmillEndpointsConsumer();
+    GetWorkerMetadataTestStub testStub =
+        new GetWorkerMetadataTestStub(new TestGetWorkMetadataRequestObserver());
+    int metadataVersion = -1;
+    stream = getWorkerMetadataTestStream(testStub, metadataVersion, testWindmillEndpointsConsumer);
+    testStub.injectWorkerMetadata(mockResponse);
+
+    assertThat(testWindmillEndpointsConsumer.globalDataEndpoints.keySet())
+        .containsExactlyElementsIn(GLOBAL_DATA_ENDPOINTS.keySet());
+    assertThat(testWindmillEndpointsConsumer.windmillEndpoints)
+        .containsExactlyElementsIn(
+            DIRECT_PATH_ENDPOINTS.stream()
+                .map(WindmillEndpoints.Endpoint::from)
+                .collect(Collectors.toList()));
+  }
+
+  @Test
+  public void testGetWorkerMetadata_consumesSubsequentResponseMetadata() {
+    WorkerMetadataResponse initialResponse =
+        WorkerMetadataResponse.newBuilder()
+            .setMetadataVersion(1)
+            .addAllWorkEndpoints(DIRECT_PATH_ENDPOINTS)
+            .putAllGlobalDataEndpoints(GLOBAL_DATA_ENDPOINTS)
+            .build();
+    TestWindmillEndpointsConsumer testWindmillEndpointsConsumer =
+        Mockito.spy(new TestWindmillEndpointsConsumer());
+
+    GetWorkerMetadataTestStub testStub =
+        new GetWorkerMetadataTestStub(new TestGetWorkMetadataRequestObserver());
+    int metadataVersion = 0;
+    stream = getWorkerMetadataTestStream(testStub, metadataVersion, testWindmillEndpointsConsumer);
+    testStub.injectWorkerMetadata(initialResponse);
+
+    List<WorkerMetadataResponse.Endpoint> newDirectPathEndpoints =
+        Lists.newArrayList(
+            WorkerMetadataResponse.Endpoint.newBuilder().setDirectEndpoint(IPV6_ADDRESS_2).build());
+    Map<String, WorkerMetadataResponse.Endpoint> newGlobalDataEndpoints = new HashMap<>();
+    newGlobalDataEndpoints.put(
+        "new_global_data",
+        WorkerMetadataResponse.Endpoint.newBuilder().setDirectEndpoint(IPV6_ADDRESS_2).build());
+
+    WorkerMetadataResponse newWorkMetadataResponse =
+        WorkerMetadataResponse.newBuilder()
+            .setMetadataVersion(initialResponse.getMetadataVersion() + 1)
+            .addAllWorkEndpoints(newDirectPathEndpoints)
+            .putAllGlobalDataEndpoints(newGlobalDataEndpoints)
+            .build();
+
+    testStub.injectWorkerMetadata(newWorkMetadataResponse);
+
+    assertThat(newGlobalDataEndpoints.keySet())
+        .containsExactlyElementsIn(testWindmillEndpointsConsumer.globalDataEndpoints.keySet());
+    assertThat(testWindmillEndpointsConsumer.windmillEndpoints)
+        .containsExactlyElementsIn(
+            newDirectPathEndpoints.stream()
+                .map(WindmillEndpoints.Endpoint::from)
+                .collect(Collectors.toList()));
+  }
+
+  @Test
+  public void testGetWorkerMetadata_doesNotConsumeResponseIfMetadataStale() {
+    WorkerMetadataResponse freshEndpoints =
+        WorkerMetadataResponse.newBuilder()
+            .setMetadataVersion(2)
+            .addAllWorkEndpoints(DIRECT_PATH_ENDPOINTS)
+            .putAllGlobalDataEndpoints(GLOBAL_DATA_ENDPOINTS)
+            .build();
+
+    TestWindmillEndpointsConsumer testWindmillEndpointsConsumer =
+        Mockito.spy(new TestWindmillEndpointsConsumer());
+    GetWorkerMetadataTestStub testStub =
+        new GetWorkerMetadataTestStub(new TestGetWorkMetadataRequestObserver());
+    int metadataVersion = 0;
+    stream = getWorkerMetadataTestStream(testStub, metadataVersion, testWindmillEndpointsConsumer);
+    testStub.injectWorkerMetadata(freshEndpoints);
+
+    List<WorkerMetadataResponse.Endpoint> staleDirectPathEndpoints =
+        Lists.newArrayList(
+            WorkerMetadataResponse.Endpoint.newBuilder()
+                .setDirectEndpoint("staleWindmillEndpoint")
+                .build());
+    Map<String, WorkerMetadataResponse.Endpoint> staleGlobalDataEndpoints = new HashMap<>();
+    staleGlobalDataEndpoints.put(
+        "stale_global_data",
+        WorkerMetadataResponse.Endpoint.newBuilder().setDirectEndpoint("staleGlobalData").build());
+
+    testStub.injectWorkerMetadata(
+        WorkerMetadataResponse.newBuilder()
+            .setMetadataVersion(1)
+            .addAllWorkEndpoints(staleDirectPathEndpoints)
+            .putAllGlobalDataEndpoints(staleGlobalDataEndpoints)
+            .build());
+
+    // Should have ignored the stale update and consumed only the initial response.
+    verify(testWindmillEndpointsConsumer).accept(WindmillEndpoints.from(freshEndpoints));
+    verifyNoMoreInteractions(testWindmillEndpointsConsumer);
+  }
+
+  @Test
+  public void testGetWorkerMetadata_correctlyAddsAndRemovesStreamFromRegistry() {
+    GetWorkerMetadataTestStub testStub =
+        new GetWorkerMetadataTestStub(new TestGetWorkMetadataRequestObserver());
+    stream = getWorkerMetadataTestStream(testStub, 0, new TestWindmillEndpointsConsumer());
+    testStub.injectWorkerMetadata(
+        WorkerMetadataResponse.newBuilder()
+            .setMetadataVersion(1)
+            .addAllWorkEndpoints(DIRECT_PATH_ENDPOINTS)
+            .putAllGlobalDataEndpoints(GLOBAL_DATA_ENDPOINTS)
+            .build());
+
+    assertTrue(streamRegistry.contains(stream));
+    stream.close();
+    assertFalse(streamRegistry.contains(stream));
+  }
+
+  @Test
+  public void testSendHealthCheck() {
+    TestGetWorkMetadataRequestObserver requestObserver =
+        Mockito.spy(new TestGetWorkMetadataRequestObserver());
+    GetWorkerMetadataTestStub testStub = new GetWorkerMetadataTestStub(requestObserver);
+    stream = getWorkerMetadataTestStream(testStub, 0, new TestWindmillEndpointsConsumer());
+    stream.sendHealthCheck();
+
+    verify(requestObserver).onNext(WorkerMetadataRequest.getDefaultInstance());
+  }
+
+  private static class GetWorkerMetadataTestStub
+      extends CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase {
+    private final TestGetWorkMetadataRequestObserver requestObserver;
+    private @Nullable StreamObserver<WorkerMetadataResponse> responseObserver;
+
+    private GetWorkerMetadataTestStub(TestGetWorkMetadataRequestObserver requestObserver) {
+      this.requestObserver = requestObserver;
+    }
+
+    @Override
+    public StreamObserver<WorkerMetadataRequest> getWorkerMetadataStream(
+        StreamObserver<WorkerMetadataResponse> responseObserver) {
+      if (this.responseObserver == null) {
+        this.responseObserver = responseObserver;
+        requestObserver.responseObserver = this.responseObserver;
+      }
+
+      return requestObserver;
+    }
+
+    private void injectWorkerMetadata(WorkerMetadataResponse response) {
+      if (responseObserver != null) {
+        responseObserver.onNext(response);
+      }
+    }
+  }
+
+  @SuppressWarnings("UnusedVariable")
+  private static class TestGetWorkMetadataRequestObserver
+      implements StreamObserver<WorkerMetadataRequest> {
+    private @Nullable StreamObserver<WorkerMetadataResponse> responseObserver;
+
+    @Override
+    public void onNext(WorkerMetadataRequest workerMetadataRequest) {}
+
+    @Override
+    public void onError(Throwable throwable) {}
+
+    @Override
+    public void onCompleted() {
+      responseObserver.onCompleted();
+    }
+  }
+
+  private static class TestWindmillEndpointsConsumer implements Consumer<WindmillEndpoints> {
+    private final Map<String, WorkerMetadataResponse.Endpoint> globalDataEndpoints;
+    private final Set<WindmillEndpoints.Endpoint> windmillEndpoints;
+
+    private TestWindmillEndpointsConsumer() {
+      this.globalDataEndpoints = new HashMap<>();
+      this.windmillEndpoints = new HashSet<>();
+    }
+
+    @Override
+    public void accept(WindmillEndpoints windmillEndpoints) {
+      this.globalDataEndpoints.clear();
+      this.windmillEndpoints.clear();
+      this.globalDataEndpoints.putAll(windmillEndpoints.globalDataEndpoints());
+      this.windmillEndpoints.addAll(windmillEndpoints.windmillEndpoints());
+    }
+  }
+}
diff --git a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto
index f66b2bed48c6..1759185911d4 100644
--- a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto
+++ b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto
@@ -746,6 +746,8 @@ message WorkerMetadataRequest {
   optional JobHeader header = 1;
 }
 
+// Converted into org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints,
+// which is used to connect to Streaming Engine.
 message WorkerMetadataResponse {
   // The metadata version increases with every modification. Within a single
   // stream it will always be increasing. The version may be used across streams
@@ -758,7 +760,9 @@
   // CommitWorkStream. Each response on this stream replaces the previous, and
   // connections to endpoints that are no longer present should be closed.
   message Endpoint {
-    optional string endpoint = 1;
+    // IPv6 address of a streaming engine windmill worker.
+    optional string direct_endpoint = 1;
+    optional string worker_token = 2;
   }
   repeated Endpoint work_endpoints = 2;
 
@@ -766,10 +770,7 @@
   // calls to retrieve that global data.
   map<string, Endpoint> global_data_endpoints = 3;
 
-  // DirectPath endpoints to be used by user workers for streaming engine jobs.
-  // DirectPath endpoints here are virtual IPv6 addresses of the windmill
-  // workers.
-  repeated Endpoint direct_path_endpoints = 4;
+  reserved 4;
 }
 
 service WindmillAppliance {
diff --git a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill_service.proto b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill_service.proto
index 803766d1a464..d9183e54e0dd 100644
--- a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill_service.proto
+++ b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill_service.proto
@@ -34,7 +34,7 @@ service CloudWindmillServiceV1Alpha1 {
       returns (stream .windmill.StreamingGetWorkResponseChunk);
 
   // Gets worker metadata. Response is a stream.
-  rpc GetWorkerMetadataStream(.windmill.WorkerMetadataRequest)
+  rpc GetWorkerMetadataStream(stream .windmill.WorkerMetadataRequest)
       returns (stream .windmill.WorkerMetadataResponse);
 
   // Gets data from Windmill.
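
Note on the deadline change in GrpcWindmillServer above: gRPC deadlines are
absolute timestamps, so a stub configured once with withDeadlineAfter(...)
would apply the same expiry to every stream later created from it. Wrapping
the stub call in a factory lambda recomputes the deadline on each stream
creation. An illustrative sketch of the same pattern; the helper below is
hypothetical and not part of this patch:

    import java.util.concurrent.TimeUnit;
    import io.grpc.stub.AbstractStub;

    final class FreshDeadlineSketch {
      // Re-applies the deadline at call time so each new stream gets a full
      // deadline window, instead of inheriting an absolute deadline captured
      // when the stub was first configured.
      static <T extends AbstractStub<T>> T withFreshDeadline(T stub, long seconds) {
        return stub.withDeadlineAfter(seconds, TimeUnit.SECONDS);
      }
    }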
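Note on the staleness test: the stream is expected to apply a
WorkerMetadataResponse only when its metadata_version advances past the last
consumed version. A minimal sketch of such a guard, assuming the version gate
lives with the endpoints consumer; the actual check is inside
GrpcGetWorkerMetadataStream and may differ:

    import java.util.function.Consumer;
    import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataResponse;
    import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints;

    final class MetadataVersionGateSketch {
      private final Consumer<WindmillEndpoints> endpointsConsumer;
      private long latestMetadataVersion = Long.MIN_VALUE;

      MetadataVersionGateSketch(Consumer<WindmillEndpoints> endpointsConsumer) {
        this.endpointsConsumer = endpointsConsumer;
      }

      synchronized void onNewWorkerMetadata(WorkerMetadataResponse response) {
        // Within a stream the version is monotonically increasing; anything at
        // or below the last applied version is stale and must be dropped.
        if (response.getMetadataVersion() <= latestMetadataVersion) {
          return;
        }
        latestMetadataVersion = response.getMetadataVersion();
        endpointsConsumer.accept(WindmillEndpoints.from(response));
      }
    }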