From 90b1a6b7200aa5693c6dc3c2abbfa479ebff395f Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Thu, 2 Nov 2023 17:12:44 -0700 Subject: [PATCH 1/7] Add new AnyOf environment. This allows an SDK (including, importantly, an expansion service) to provide several alternative environments suitable for running a pipeline. --- .../model/pipeline/v1/beam_runner_api.proto | 6 ++ .../core/construction/Environments.java | 47 +++++++++++++++ .../core/construction/EnvironmentsTest.java | 41 +++++++++++++ .../apache_beam/transforms/environments.py | 58 +++++++++++++++++++ 4 files changed, 152 insertions(+) diff --git a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/beam_runner_api.proto b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/beam_runner_api.proto index db958f183c45..87af7c19dd79 100644 --- a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/beam_runner_api.proto +++ b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/beam_runner_api.proto @@ -1570,6 +1570,8 @@ message StandardEnvironments { EXTERNAL = 2 [(beam_urn) = "beam:env:external:v1"]; // An external non managed process to run user code. DEFAULT = 3 [(beam_urn) = "beam:env:default:v1"]; // Used as a stub when context is missing a runner-provided default environment. + + ANYOF = 4 [(beam_urn) = "beam:env:anyof:v1"]; // A selection of equivalent environments a runner may use. } } @@ -1590,6 +1592,10 @@ message ExternalPayload { map<string, bytes> params = 2; // Arbitrary extra parameters to pass } +message AnyOfEnvironmentPayload { + repeated Environment environments = 1; +} + // These URNs are used to indicate capabilities of environments that cannot // simply be expressed as a component (such as a Coder or PTransform) that this // environment understands. diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/Environments.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/Environments.java index f531b5be344d..6e05c006d283 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/Environments.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/Environments.java @@ -28,8 +28,11 @@ import java.util.Optional; import java.util.Set; import java.util.UUID; +import java.util.stream.Collectors; +import java.util.stream.Stream; import org.apache.beam.model.pipeline.v1.Endpoints.ApiServiceDescriptor; import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.model.pipeline.v1.RunnerApi.AnyOfEnvironmentPayload; import org.apache.beam.model.pipeline.v1.RunnerApi.ArtifactInformation; import org.apache.beam.model.pipeline.v1.RunnerApi.Components; import org.apache.beam.model.pipeline.v1.RunnerApi.DockerPayload; @@ -291,6 +294,50 @@ public static Environment createProcessEnvironment( .build(); } + public static Environment createAnyOfEnvironment(Environment...
environments) { + AnyOfEnvironmentPayload.Builder payload = AnyOfEnvironmentPayload.newBuilder(); + for (Environment environment : environments) { + payload.addEnvironments(environment); + } + return Environment.newBuilder() + .setUrn(BeamUrns.getUrn(StandardEnvironments.Environments.ANYOF)) + .setPayload(payload.build().toByteString()) + .build(); + } + + public static List<Environment> expandAnyOfEnvironments(Environment environment) { + return Stream.of(environment) + .flatMap( + env -> { + if (BeamUrns.getUrn(StandardEnvironments.Environments.ANYOF) + .equals(environment.getUrn())) { + try { + return AnyOfEnvironmentPayload.parseFrom(environment.getPayload()) + .getEnvironmentsList().stream() + .flatMap(subenv -> expandAnyOfEnvironments(subenv).stream()); + } catch (InvalidProtocolBufferException exn) { + throw new RuntimeException(exn); + } + } else { + return Stream.of(env); + } + }) + .collect(Collectors.toList()); + } + + public static Environment resolveAnyOfEnvironment( + Environment environment, String... preferredEnvironmentTypes) { + List<Environment> allEnvironments = expandAnyOfEnvironments(environment); + for (String urn : preferredEnvironmentTypes) { + for (Environment env : allEnvironments) { + if (urn.equals(env.getUrn())) { + return env; + } + } + } + return allEnvironments.iterator().next(); + } + public static Optional<Environment> getEnvironment(String ptransformId, Components components) { PTransform ptransform = components.getTransformsOrThrow(ptransformId); String envId = ptransform.getEnvironmentId(); diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/EnvironmentsTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/EnvironmentsTest.java index b71a654f1031..453f6ab6db88 100644 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/EnvironmentsTest.java +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/EnvironmentsTest.java @@ -19,15 +19,18 @@ import static org.apache.beam.runners.core.construction.Environments.JAVA_SDK_HARNESS_CONTAINER_URL; import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasItem; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; import static org.junit.Assert.assertEquals; import java.io.File; import java.io.IOException; import java.io.Serializable; +import java.util.HashMap; import java.util.List; import java.util.Optional; import org.apache.beam.model.pipeline.v1.Endpoints; @@ -353,4 +356,42 @@ public void testGetArtifactsBadNamedFileLogsWarn() throws Exception { assertThat(artifacts, hasSize(1)); expectedLogs.verifyWarn("name 'file_name' was not found"); } + + @Test + public void testExpandAnyOfEnvironmentsOnOrdinaryEnvironment() { + Environment env = Environments.createDockerEnvironment("java"); + assertThat(Environments.expandAnyOfEnvironments(env), contains(env)); + } + + @Test + public void testExpandAnyOfEnvironmentsOnNestedEnvironment() { + Environment envA = Environments.createDockerEnvironment("A"); + Environment envB = Environments.createDockerEnvironment("B"); + Environment envC = Environments.createDockerEnvironment("C"); + Environment env = + Environments.createAnyOfEnvironment(envA, Environments.createAnyOfEnvironment(envB, envC)); + 
assertThat(Environments.expandAnyOfEnvironments(env), contains(envA, envB, envC)); + } + + @Test + public void testResolveAnyOfEnvironment() { + Environment dockerEnv = Environments.createDockerEnvironment("A"); + Environment processEnv = + Environments.createProcessEnvironment("os", "arch", "cmd", new HashMap<>()); + Environment env = + Environments.createAnyOfEnvironment( + dockerEnv, Environments.createAnyOfEnvironment(processEnv)); + assertThat( + Environments.resolveAnyOfEnvironment( + env, BeamUrns.getUrn(StandardEnvironments.Environments.DOCKER)), + equalTo(dockerEnv)); + assertThat( + Environments.resolveAnyOfEnvironment( + env, BeamUrns.getUrn(StandardEnvironments.Environments.PROCESS)), + equalTo(processEnv)); + assertThat( + Environments.resolveAnyOfEnvironment( + env, BeamUrns.getUrn(StandardEnvironments.Environments.EXTERNAL)), + notNullValue()); + } } diff --git a/sdks/python/apache_beam/transforms/environments.py b/sdks/python/apache_beam/transforms/environments.py index 109fcb825347..b2fbe87a729a 100644 --- a/sdks/python/apache_beam/transforms/environments.py +++ b/sdks/python/apache_beam/transforms/environments.py @@ -61,6 +61,7 @@ __all__ = [ 'Environment', + 'AnyOfEnvironment', 'DefaultEnvironment', 'DockerEnvironment', 'ProcessEnvironment', @@ -584,6 +585,24 @@ def from_options(cls, options): resource_hints=resource_hints_from_options(options)) +def expand_anyof_environments(env_proto): + if env_proto.urn == common_urns.environments.ANYOF.urn: + for alt in beam_runner_api_pb2.AnyOfEnvironmentPayload.FromString( + env_proto.payload).environments: + yield from expand_anyof_environments(alt) + else: + yield env_proto + + +def resolve_anyof_environment(env_proto, *preferred_types): + all_environments = list(expand_anyof_environments(env_proto)) + for preferred_type in preferred_types: + for env in all_environments: + if env.urn == preferred_type: + return env + return all_environments[0] + + @Environment.register_urn(python_urns.EMBEDDED_PYTHON, None) class EmbeddedPythonEnvironment(Environment): def to_runner_api_parameter(self, context): @@ -796,6 +815,45 @@ def from_command_string(cls, command_string): command_string, capabilities=python_sdk_capabilities(), artifacts=()) +@Environment.register_urn( + common_urns.environments.ANYOF.urn, + beam_runner_api_pb2.AnyOfEnvironmentPayload) +class AnyOfEnvironment(Environment): + def __init__(self, environments): + self._environments = environments + + def to_runner_api_parameter(self, context): + # type: (PipelineContext) -> Tuple[str, beam_runner_api_pb2.AnyOfEnvironmentPayload] + return ( + common_urns.environments.ANYOF.urn, + beam_runner_api_pb2.AnyOfEnvironmentPayload( + environments=[ + env.to_runner_api(context) for env in self._environments + ])) + + @staticmethod + def from_runner_api_parameter(payload, # type: beam_runner_api_pb2.AnyOfEnvironmentPayload + capabilities, # type: Iterable[str] + artifacts, # type: Iterable[beam_runner_api_pb2.ArtifactInformation] + resource_hints, # type: Mapping[str, bytes] + context # type: PipelineContext + ): + # type: (...) 
-> AnyOfEnvironment + return AnyOfEnvironment([ + Environment.from_runner_api(env, context) + for env in payload.environments + ]) + + @staticmethod + def create_proto( + environments: Iterable[beam_runner_api_pb2.Environment] + ) -> beam_runner_api_pb2.Environment: + return beam_runner_api_pb2.Environment( + urn=common_urns.environments.ANYOF.urn, + payload=beam_runner_api_pb2.AnyOfEnvironmentPayload( + environments=environments).SerializeToString()) + + class PyPIArtifactRegistry(object): _registered_artifacts = set() # type: Set[Tuple[str, str]] From e5a9ab6233ec76b0aab17474c684b1fad7fe5082 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Thu, 2 Nov 2023 17:23:42 -0700 Subject: [PATCH 2/7] Add AnyOf environment handling to the various runners. --- .../beam/runners/dataflow/DataflowRunner.java | 16 ++++++++++++++ .../control/DefaultJobBundleFactory.java | 6 +++++ .../runners/dataflow/dataflow_runner.py | 8 ++++++- .../portability/fn_api_runner/fn_runner.py | 22 ++++++++++++------- .../fn_api_runner/worker_handlers.py | 12 ++++++++++ 5 files changed, 55 insertions(+), 9 deletions(-) diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java index d2b10f91c064..41d149b6bf58 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java @@ -877,6 +877,21 @@ public PTransformReplacement, PCollection> getRepla } } + private RunnerApi.Pipeline resolveAnyOfEnvironments(RunnerApi.Pipeline pipeline) { + RunnerApi.Pipeline.Builder pipelineBuilder = pipeline.toBuilder(); + RunnerApi.Components.Builder componentsBuilder = pipelineBuilder.getComponentsBuilder(); + componentsBuilder.clearEnvironments(); + for (Map.Entry<String, RunnerApi.Environment> entry : + pipeline.getComponents().getEnvironmentsMap().entrySet()) { + componentsBuilder.putEnvironments( + entry.getKey(), + Environments.resolveAnyOfEnvironment( + entry.getValue(), + BeamUrns.getUrn(RunnerApi.StandardEnvironments.Environments.DOCKER))); + } + return pipelineBuilder.build(); + } + protected RunnerApi.Pipeline applySdkEnvironmentOverrides( RunnerApi.Pipeline pipeline, DataflowPipelineOptions options) { String sdkHarnessContainerImageOverrides = options.getSdkHarnessContainerImageOverrides(); @@ -1173,6 +1188,7 @@ public DataflowPipelineJob run(Pipeline pipeline) { PipelineTranslation.toProto(pipeline, portableComponents, false); // Note that `stageArtifacts` has to be called before `resolveArtifact` because // `resolveArtifact` updates local paths to staged paths in pipeline proto. 
+ portablePipelineProto = resolveAnyOfEnvironments(portablePipelineProto); List<DataflowPackage> packages = stageArtifacts(portablePipelineProto); portablePipelineProto = resolveArtifacts(portablePipelineProto); portablePipelineProto = applySdkEnvironmentOverrides(portablePipelineProto, options); diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/DefaultJobBundleFactory.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/DefaultJobBundleFactory.java index 824c2c78cc50..019028f5b936 100644 --- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/DefaultJobBundleFactory.java +++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/DefaultJobBundleFactory.java @@ -232,6 +232,12 @@ private ImmutableList createEnvironmentCaches( new CacheLoader<Environment, WrappedSdkHarnessClient>() { @Override public WrappedSdkHarnessClient load(Environment environment) throws Exception { + // TODO(robertwb): Docker is the safest fallback (if we are distributed) + // but it would be good to have the ability to make a more intelligent choice + // (e.g. in-process or loopback workers, especially if running locally). + environment = + Environments.resolveAnyOfEnvironment( + environment, BeamUrns.getUrn(StandardEnvironments.Environments.DOCKER)); EnvironmentFactory.Provider environmentFactoryProvider = environmentFactoryProviderMap.get(environment.getUrn()); ServerFactory serverFactory = environmentFactoryProvider.getServerFactory(); diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py index 7ad6ab04be68..dc315119e480 100644 --- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py +++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py @@ -47,6 +47,7 @@ from apache_beam.runners.runner import PipelineResult from apache_beam.runners.runner import PipelineRunner from apache_beam.runners.runner import PipelineState +from apache_beam.transforms import environments from apache_beam.typehints import typehints from apache_beam.utils import processes from apache_beam.utils.interactive_utils import is_in_notebook @@ -380,7 +381,6 @@ def run_pipeline(self, pipeline, options, pipeline_proto=None): self.proto_pipeline = pipeline_proto else: - from apache_beam.transforms import environments if options.view_as(SetupOptions).prebuild_sdk_container_engine: # if prebuild_sdk_container_engine is specified we will build a new sdk # container image with dependencies pre-installed and use that image, @@ -414,6 +414,12 @@ def run_pipeline(self, pipeline, options, pipeline_proto=None): self.proto_pipeline, self.proto_context = pipeline.to_runner_api( return_context=True, default_environment=self._default_environment) + # Dataflow can only handle Docker environments. + for env_id, env in self.proto_pipeline.components.environments.items(): + self.proto_pipeline.components.environments[env_id].CopyFrom( + environments.resolve_anyof_environment( + env, common_urns.environments.DOCKER.urn)) + # Optimize the pipeline if it not streaming and the pre_optimize # experiment is set. 
if not options.view_as(StandardOptions).streaming: diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py index 098e16933b73..9abf1d9ab8b6 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py @@ -230,11 +230,16 @@ def embed_default_docker_image(self, pipeline_proto): docker_env = environments.DockerEnvironment.from_container_image( environments.DockerEnvironment.default_docker_image()).to_runner_api( None) # type: ignore[arg-type] - for env_id, env in pipeline_proto.components.environments.items(): - if env == docker_env: - docker_env_id = env_id - break - else: + + def is_python_docker_env(env): + return any( + e == docker_env for e in environments.expand_anyof_environments(env)) + + python_docker_environments = set( + env_id + for (env_id, env) in pipeline_proto.components.environments.items() + if is_python_docker_env(env)) + if not python_docker_environments: # No matching docker environments. return pipeline_proto @@ -244,12 +249,13 @@ def embed_default_docker_image(self, pipeline_proto): break else: # No existing embedded environment. - pipeline_proto.components.environments[docker_env_id].CopyFrom( - embedded_env) + for docker_env_id in python_docker_environments: + pipeline_proto.components.environments[docker_env_id].CopyFrom( + embedded_env) return pipeline_proto for transform in pipeline_proto.components.transforms.values(): - if transform.environment_id == docker_env_id: + if transform.environment_id in python_docker_environments: transform.environment_id = embedded_env_id return pipeline_proto diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py index b11c8349909c..de55235368e3 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py @@ -879,6 +879,18 @@ def get_worker_handlers( environment_id = next(iter(self._environments.keys())) environment = self._environments[environment_id] + if environment.urn == common_urns.environments.ANYOF.urn: + payload = beam_runner_api_pb2.AnyOfEnvironmentPayload.FromString( + environment.payload) + env_rankings = { + python_urns.EMBEDDED_PYTHON: 10, + common_urns.environments.EXTERNAL.urn: 5, + common_urns.environments.DOCKER.urn: 1, + } + environment = sorted( + payload.environments, + key=lambda env: env_rankings.get(env.urn, -1))[-1] + # assume all environments except EMBEDDED_PYTHON use gRPC. if environment.urn == python_urns.EMBEDDED_PYTHON: # special case for EmbeddedWorkerHandler: there's no need for a gRPC From b136f1d3dd6da8b7bea48e8587e0d2bb7f360483 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Thu, 2 Nov 2023 17:25:57 -0700 Subject: [PATCH 3/7] Add AnyOf environment handling to the artifact staging services. 
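
When an environment is an AnyOf, each alternative may carry its own artifact dependency list, so dependency sets can no longer be keyed by environment id alone. This commit keys them per alternative: the Java service below uses "index:env_id" strings, while the Python local job service uses (env_id, index) tuples. A minimal sketch of the Python round-trip using the private helpers added here, where `stage` is a hypothetical stand-in for whatever actually stages an artifact list and `envs` is a pipeline proto's id-to-Environment map:

    from apache_beam.runners.portability.local_job_service import (
        _extract_dependency_sets, _update_dependency_sets)

    def stage(artifacts):
      # Hypothetical: stage each artifact and return its staged description.
      return list(artifacts)

    dep_sets = _extract_dependency_sets(envs)  # {(env_id, ix): [artifacts]}
    resolved = {key: stage(deps) for key, deps in dep_sets.items()}
    _update_dependency_sets(envs, resolved)  # re-wraps AnyOf alternatives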
--- .../jobsubmission/InMemoryJobService.java | 58 ++++++++++++------ .../portability/fn_api_runner/fn_runner.py | 19 +++--- .../runners/portability/local_job_service.py | 59 +++++++++++++++---- .../python/apache_beam/transforms/external.py | 20 +++++-- 4 files changed, 115 insertions(+), 41 deletions(-) diff --git a/runners/java-job-service/src/main/java/org/apache/beam/runners/jobsubmission/InMemoryJobService.java b/runners/java-job-service/src/main/java/org/apache/beam/runners/jobsubmission/InMemoryJobService.java index 17efbf9a06ec..41e6135b9207 100644 --- a/runners/java-job-service/src/main/java/org/apache/beam/runners/jobsubmission/InMemoryJobService.java +++ b/runners/java-job-service/src/main/java/org/apache/beam/runners/jobsubmission/InMemoryJobService.java @@ -49,6 +49,7 @@ import org.apache.beam.model.jobmanagement.v1.JobServiceGrpc; import org.apache.beam.model.pipeline.v1.Endpoints; import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.runners.core.construction.Environments; import org.apache.beam.runners.core.construction.graph.PipelineValidator; import org.apache.beam.runners.fnexecution.artifact.ArtifactStagingService; import org.apache.beam.sdk.fn.server.FnService; @@ -62,7 +63,6 @@ import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.StatusRuntimeException; import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.stub.StreamObserver; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -198,11 +198,7 @@ public void prepare( stagingService .getService() - .registerJob( - stagingSessionToken, - Maps.transformValues( - request.getPipeline().getComponents().getEnvironmentsMap(), - RunnerApi.Environment::getDependenciesList)); + .registerJob(stagingSessionToken, extractDependencies(request.getPipeline())); // send response PrepareJobResponse response = @@ -287,26 +283,50 @@ public void run(RunJobRequest request, StreamObserver responseOb } } + private Map<String, List<RunnerApi.ArtifactInformation>> extractDependencies( + RunnerApi.Pipeline pipeline) { + Map<String, List<RunnerApi.ArtifactInformation>> dependencies = new HashMap<>(); + for (Map.Entry<String, RunnerApi.Environment> entry : + pipeline.getComponents().getEnvironmentsMap().entrySet()) { + List<RunnerApi.Environment> subEnvs = Environments.expandAnyOfEnvironments(entry.getValue()); + for (int i = 0; i < subEnvs.size(); i++) { + dependencies.put(i + ":" + entry.getKey(), subEnvs.get(i).getDependenciesList()); + } + } + return dependencies; + } + private RunnerApi.Pipeline resolveDependencies(RunnerApi.Pipeline pipeline, String stagingToken) { Map<String, List<RunnerApi.ArtifactInformation>> resolvedDependencies = stagingService.getService().getStagedArtifacts(stagingToken); Map<String, RunnerApi.Environment> newEnvironments = new HashMap<>(); for (Map.Entry<String, RunnerApi.Environment> entry : pipeline.getComponents().getEnvironmentsMap().entrySet()) { - if (entry.getValue().getDependenciesCount() > 0 && resolvedDependencies == null) { - throw new RuntimeException( - "Artifact dependencies provided but not staged for " + entry.getKey()); + List<RunnerApi.Environment> subEnvs = Environments.expandAnyOfEnvironments(entry.getValue()); + List<RunnerApi.Environment> newSubEnvs = new ArrayList<>(); + for (int i = 0; i < subEnvs.size(); i++) { + RunnerApi.Environment subEnv = subEnvs.get(i); + if (subEnv.getDependenciesCount() > 0 && resolvedDependencies == null) { + throw new RuntimeException( + "Artifact dependencies provided but not staged for " + entry.getKey()); + } + newSubEnvs.add( + subEnv.getDependenciesCount() == 0 + ? 
subEnv + : subEnv + .toBuilder() + .clearDependencies() + .addAllDependencies(resolvedDependencies.get(i + ":" + entry.getKey())) + .build()); + } + if (newSubEnvs.size() == 1) { + newEnvironments.put(entry.getKey(), newSubEnvs.get(0)); + } else { + newEnvironments.put( + entry.getKey(), + Environments.createAnyOfEnvironment( + newSubEnvs.toArray(new RunnerApi.Environment[newSubEnvs.size()]))); } - newEnvironments.put( - entry.getKey(), - entry.getValue().getDependenciesCount() == 0 - ? entry.getValue() - : entry - .getValue() - .toBuilder() - .clearDependencies() - .addAllDependencies(resolvedDependencies.get(entry.getKey())) - .build()); } RunnerApi.Pipeline.Builder builder = pipeline.toBuilder(); builder.getComponentsBuilder().clearEnvironments().putAllEnvironments(newEnvironments); diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py index 9abf1d9ab8b6..a736288dc62d 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py @@ -224,6 +224,10 @@ def run_via_runner_api(self, pipeline_proto, options): return self.run_stages(stage_context, stages) def embed_default_docker_image(self, pipeline_proto): + """Updates the pipeline proto so that transforms which would normally + be executed in this SDK's default docker image are instead executed + inline via the "embedded" environment. + """ # Context is unused for these types. embedded_env = environments.EmbeddedPythonEnvironment.default( ).to_runner_api(None) # type: ignore[arg-type] docker_env = environments.DockerEnvironment.from_container_image( environments.DockerEnvironment.default_docker_image()).to_runner_api( None) # type: ignore[arg-type] - def is_python_docker_env(env): + def is_this_python_docker_env(env): return any( e == docker_env for e in environments.expand_anyof_environments(env)) python_docker_environments = set( env_id for (env_id, env) in pipeline_proto.components.environments.items() - if is_python_docker_env(env)) + if is_this_python_docker_env(env)) if not python_docker_environments: # No matching docker environments. return pipeline_proto @@ -244,12 +249,13 @@ def is_python_docker_env(env): embedded_env_id = env_id break else: - # No existing embedded environment. - for docker_env_id in python_docker_environments: - pipeline_proto.components.environments[docker_env_id].CopyFrom( - embedded_env) - return pipeline_proto + # No existing embedded environment. Create one. 
+ embedded_env_id = "python_embedded_env" + while embedded_env_id in pipeline_proto.components.environments: + embedded_env_id += '_' + pipeline_proto.components.environments[embedded_env_id].CopyFrom( + embedded_env) for transform in pipeline_proto.components.transforms.values(): if transform.environment_id in python_docker_environments: diff --git a/sdks/python/apache_beam/runners/portability/local_job_service.py b/sdks/python/apache_beam/runners/portability/local_job_service.py index 91ddb3fced15..6966e66d2c64 100644 --- a/sdks/python/apache_beam/runners/portability/local_job_service.py +++ b/sdks/python/apache_beam/runners/portability/local_job_service.py @@ -28,7 +28,9 @@ import time import traceback from typing import TYPE_CHECKING +from typing import Any from typing import List +from typing import Mapping from typing import Optional import grpc @@ -43,6 +45,7 @@ from apache_beam.portability.api import beam_job_api_pb2 from apache_beam.portability.api import beam_job_api_pb2_grpc from apache_beam.portability.api import beam_provision_api_pb2 +from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.portability.api import endpoints_pb2 from apache_beam.runners.job import utils as job_utils from apache_beam.runners.portability import abstract_job_service @@ -51,11 +54,11 @@ from apache_beam.runners.portability.fn_api_runner import fn_runner from apache_beam.runners.portability.fn_api_runner import worker_handlers from apache_beam.runners.worker.log_handler import LOGENTRY_TO_LOG_LEVEL_MAP +from apache_beam.transforms import environments from apache_beam.utils import thread_pool_executor if TYPE_CHECKING: from google.protobuf import struct_pb2 # pylint: disable=ungrouped-imports - from apache_beam.portability.api import beam_runner_api_pb2 _LOGGER = logging.getLogger(__name__) @@ -96,10 +99,8 @@ def create_beam_job(self, # type: (...) -> BeamJob self._artifact_service.register_job( staging_token=preparation_id, - dependency_sets={ - id: env.dependencies - for (id, env) in pipeline.components.environments.items() - }) + dependency_sets=_extract_dependency_sets( + pipeline.components.environments)) provision_info = fn_runner.ExtendedProvisionInfo( beam_provision_api_pb2.ProvisionInfo(pipeline_options=options), self._staging_dir, @@ -321,12 +322,9 @@ def _invoke_runner(self): def _update_dependencies(self): try: - for env_id, deps in self._artifact_service.resolved_deps( - self._job_id, timeout=0).items(): - # Slice assignment not supported for repeated fields. - env = self._pipeline_proto.components.environments[env_id] - del env.dependencies[:] - env.dependencies.extend(deps) + _update_dependency_sets( + self._pipeline_proto.components.environments, + self._artifact_service.resolved_deps(self._job_id, timeout=0)) self._provision_info.provision_info.ClearField('retrieval_token') except concurrent.futures.TimeoutError: # TODO(https://github.com/apache/beam/issues/20267): Require this once @@ -457,3 +455,42 @@ def emit(self, record): # Inform all message consumers. self._log_queues.put(msg) + + +def _extract_dependency_sets( + envs: Mapping[str, beam_runner_api_pb2.Environment] +) -> Mapping[Any, List[beam_runner_api_pb2.ArtifactInformation]]: + """Expands the set of environments into a mapping of (opaque) keys to + dependency sets. This is not 1:1 in the case of AnyOf environments. + + The values can then be resolved and the mapping passed back to + _update_dependency_sets to update the dependencies in the original protos. 
+ """ + def dependencies_iter(): + for env_id, env in envs.items(): + for ix, sub_env in enumerate(environments.expand_anyof_environments(env)): + yield (env_id, ix), sub_env.dependencies + + return dict(dependencies_iter()) + + +def _update_dependency_sets( + envs: Mapping[str, beam_runner_api_pb2.Environment], + resolved_deps: Mapping[Any, List[beam_runner_api_pb2.ArtifactInformation]]): + """Takes the mapping of beam Environments (originally passed to + `_extract_dependency_sets`) and a set of (key-wise) updated dependencies, + and updates the original environment protos to contain the updated + dependencies. + """ + for env_id, env in envs.items(): + new_envs = [] + for ix, sub_env in enumerate(environments.expand_anyof_environments(env)): + # Slice assignment not supported for repeated fields. + del sub_env.dependencies[:] + sub_env.dependencies.extend(resolved_deps[env_id, ix]) + new_envs.append(sub_env) + if len(new_envs) == 1: + envs[env_id].CopyFrom(new_envs[0]) + else: + envs[env_id].CopyFrom( + environments.AnyOfEnvironment.create_proto(new_envs)) diff --git a/sdks/python/apache_beam/transforms/external.py b/sdks/python/apache_beam/transforms/external.py index 0d0b6f1e7be2..52e27ecc2e8b 100644 --- a/sdks/python/apache_beam/transforms/external.py +++ b/sdks/python/apache_beam/transforms/external.py @@ -43,6 +43,7 @@ from apache_beam.portability.api import external_transforms_pb2 from apache_beam.runners import pipeline_context from apache_beam.runners.portability import artifact_service +from apache_beam.transforms import environments from apache_beam.transforms import ptransform from apache_beam.typehints import WithTypeHints from apache_beam.typehints import native_type_compatibility @@ -731,8 +732,9 @@ def expand(self, pvalueish): if response.error: raise RuntimeError(response.error) self._expanded_components = response.components - if any(env.dependencies - for env in self._expanded_components.environments.values()): + if any(e.dependencies + for env in self._expanded_components.environments.values() + for e in environments.expand_anyof_environments(env)): self._expanded_components = self._resolve_artifacts( self._expanded_components, service.artifact_service(), @@ -785,12 +787,22 @@ def service(expansion_service): yield stub def _resolve_artifacts(self, components, service, dest): - for env in components.environments.values(): - if env.dependencies: + def _resolve_artifacts_for(env): + if env.urn == common_urns.environments.ANYOF.urn: + env.CopyFrom( + environments.AnyOfEnvironment.create_proto([ + _resolve_artifacts_for(e) + for e in environments.expand_anyof_environments(env) + ])) + elif env.dependencies: resolved = list( artifact_service.resolve_artifacts(env.dependencies, service, dest)) del env.dependencies[:] env.dependencies.extend(resolved) + return env + + for env in components.environments.values(): + _resolve_artifacts_for(env) return components def _output_to_pvalueish(self, output_dict): From 6550cec0e9cfcc35991d5d0170360eb92384a572 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Thu, 2 Nov 2023 17:28:06 -0700 Subject: [PATCH 4/7] Add the ability to cache subprocess services for the duration of a pipeline. This can greatly reduce startup time when many cross-langauge transforms are used, but more importantly by keeping these processes alive we open up the potential for using them as workers as well. These can be cached across longer durations as well, but this is the default. 
--- sdks/python/apache_beam/pipeline.py | 10 +- .../apache_beam/utils/subprocess_server.py | 192 +++++++++++++----- .../utils/subprocess_server_test.py | 79 +++++++ 3 files changed, 227 insertions(+), 54 deletions(-) diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py index ed0736250d1f..53044982a066 100644 --- a/sdks/python/apache_beam/pipeline.py +++ b/sdks/python/apache_beam/pipeline.py @@ -48,6 +48,7 @@ # mypy: disallow-untyped-defs import abc +import contextlib import logging import os import re @@ -590,9 +591,12 @@ def run(self, test_runner_api='AUTO'): def __enter__(self): # type: () -> Pipeline - self._extra_context = subprocess_server.JavaJarServer.beam_services( - self._options.view_as(CrossLanguageOptions).beam_services) - self._extra_context.__enter__() + self._extra_context = contextlib.ExitStack() + self._extra_context.enter_context( + subprocess_server.JavaJarServer.beam_services( + self._options.view_as(CrossLanguageOptions).beam_services)) + self._extra_context.enter_context( + subprocess_server.SubprocessServer.cache_subprocesses()) return self def __exit__( diff --git a/sdks/python/apache_beam/utils/subprocess_server.py b/sdks/python/apache_beam/utils/subprocess_server.py index 6ed1568b57a5..7d9cf50d8532 100644 --- a/sdks/python/apache_beam/utils/subprocess_server.py +++ b/sdks/python/apache_beam/utils/subprocess_server.py @@ -18,6 +18,7 @@ # pytype: skip-file import contextlib +import dataclasses import glob import hashlib import logging @@ -27,10 +28,11 @@ import signal import socket import subprocess -import tempfile import threading import time import zipfile +from typing import Any +from typing import Set from urllib.error import URLError from urllib.request import urlopen @@ -42,6 +44,75 @@ _LOGGER = logging.getLogger(__name__) +@dataclasses.dataclass +class _SharedCacheEntry: + obj: Any + owners: Set[str] + + +class _SharedCache: + """A cache that keeps objects alive (and repeatedly returns the same instance) + until the last user indicates that they're done. + + The typical usage is as follows:: + + try: + token = cache.register() + # All objects retrieved from the cache from this point on will be memoized + # and kept alive (including across other threads and callers) at least + # until the purge is called below (and possibly longer, if other calls + # to register were made). + obj = cache.get(...) + another_obj = cache.get(...) + ... + finally: + cache.purge(token) + """ + def __init__(self, constructor, destructor): + self._constructor = constructor + self._destructor = destructor + self._live_owners = set() + self._cache = {} + self._lock = threading.Lock() + self._counter = 0 + + def _next_id(self): + with self._lock: + self._counter += 1 + return self._counter + + def register(self): + owner = self._next_id() + self._live_owners.add(owner) + return owner + + def purge(self, owner): + if owner not in self._live_owners: + raise ValueError(f"{owner} not in {self._live_owners}") + self._live_owners.remove(owner) + to_delete = [] + with self._lock: + for key, entry in list(self._cache.items()): + if owner in entry.owners: + entry.owners.remove(owner) + if not entry.owners: + to_delete.append(entry.obj) + del self._cache[key] + # Actually call the destructors outside of the lock. 
+ for value in to_delete: + self._destructor(value) + + def get(self, *key): + if not self._live_owners: + raise RuntimeError("At least one owner must be registered.") + with self._lock: + if key not in self._cache: + self._cache[key] = _SharedCacheEntry(self._constructor(*key), set()) + for owner in self._live_owners: + self._cache[key].owners.add(owner) + return self._cache[key].obj + + class SubprocessServer(object): """An abstract base class for running GRPC Servers as an external process. @@ -63,13 +134,26 @@ def __init__(self, stub_class, cmd, port=None): string "{{PORT}}" will be substituted in the command line arguments with the chosen port. """ - self._process_lock = threading.RLock() - self._process = None + self._owner_id = None self._stub_class = stub_class self._cmd = [str(arg) for arg in cmd] self._port = port self._grpc_channel = None + @classmethod + @contextlib.contextmanager + def cache_subprocesses(cls): + """A context that ensures any subprocess created or used in its duration + stays alive for at least the duration of this context. + + These subprocesses may be shared with other contexts as well. + """ + try: + unique_id = cls._cache.register() + yield + finally: + cls._cache.purge(unique_id) + def __enter__(self): return self.start() @@ -78,7 +162,7 @@ def __exit__(self, *unused_args): def start(self): try: - endpoint = self.start_process() + process, endpoint = self.start_process() wait_secs = .1 channel_options = [("grpc.max_receive_message_length", -1), ("grpc.max_send_message_length", -1)] @@ -86,10 +170,10 @@ def start(self): endpoint, options=channel_options) channel_ready = grpc.channel_ready_future(self._grpc_channel) while True: - if self._process is not None and self._process.poll() is not None: - _LOGGER.error("Starting job service with %s", self._process.args) + if process is not None and process.poll() is not None: + _LOGGER.error("Started job service with %s", process.args) raise RuntimeError( - 'Service failed to start up with error %s' % self._process.poll()) + 'Service failed to start up with error %s' % process.poll()) try: channel_ready.result(timeout=wait_secs) break @@ -106,60 +190,66 @@ def start(self): raise def start_process(self): - with self._process_lock: - if self._process: - self.stop() - if self._port: - port = self._port - cmd = self._cmd - else: - port, = pick_port(None) - cmd = [arg.replace('{{PORT}}', str(port)) for arg in self._cmd] - endpoint = 'localhost:%s' % port - _LOGGER.info("Starting service with %s", str(cmd).replace("',", "'")) - self._process = subprocess.Popen( - cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - - # Emit the output of this command as info level logging. - def log_stdout(): - line = self._process.stdout.readline() - while line: - # The log obtained from stdout is bytes, decode it into string. - # Remove newline via rstrip() to not print an empty line. 
- _LOGGER.info(line.decode(errors='backslashreplace').rstrip()) - line = self._process.stdout.readline() - - t = threading.Thread(target=log_stdout) - t.daemon = True - t.start() - return endpoint + if self._owner_id is not None: + self._cache.purge(self._owner_id) + self._owner_id = self._cache.register() + return self._cache.get(tuple(self._cmd), self._port) + + def _really_start_process(cmd, port): + if not port: + port, = pick_port(None) + cmd = [arg.replace('{{PORT}}', str(port)) for arg in cmd] # pylint: disable=not-an-iterable + endpoint = 'localhost:%s' % port + _LOGGER.info("Starting service with %s", str(cmd).replace("',", "'")) + process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + + # Emit the output of this command as info level logging. + def log_stdout(): + line = process.stdout.readline() + while line: + # The log obtained from stdout is bytes, decode it into string. + # Remove newline via rstrip() to not print an empty line. + _LOGGER.info(line.decode(errors='backslashreplace').rstrip()) + line = process.stdout.readline() + + t = threading.Thread(target=log_stdout) + t.daemon = True + t.start() + return process, endpoint def stop(self): self.stop_process() def stop_process(self): - with self._process_lock: - if not self._process: - return - for _ in range(5): - if self._process.poll() is not None: - break - logging.debug("Sending SIGINT to job_server") - self._process.send_signal(signal.SIGINT) - time.sleep(1) - if self._process.poll() is None: - self._process.kill() - self._process = None + if self._owner_id is not None: + self._cache.purge(self._owner_id) + self._owner_id = None if self._grpc_channel: try: self._grpc_channel.close() except: # pylint: disable=bare-except _LOGGER.error( - "Could not close the gRPC channel started for the " + + "Could not close the gRPC channel started for the " "expansion service") - - def local_temp_dir(self, **kwargs): - return tempfile.mkdtemp(dir=self._local_temp_root, **kwargs) + finally: + self._grpc_channel = None + + def _really_stop_process(process_and_endpoint): + process, _ = process_and_endpoint # pylint: disable=unpacking-non-sequence + if not process: + return + for _ in range(5): + if process.poll() is not None: + break + logging.debug("Sending SIGINT to process") + process.send_signal(signal.SIGINT) + time.sleep(1) + if process.poll() is None: + process.kill() + + _cache = _SharedCache( + constructor=_really_start_process, destructor=_really_stop_process) class JavaJarServer(SubprocessServer): @@ -184,7 +274,7 @@ def __init__(self, stub_class, path_to_jar, java_arguments, classpath=None): def start_process(self): if self._existing_service: - return self._existing_service + return None, self._existing_service else: if not shutil.which('java'): raise RuntimeError( diff --git a/sdks/python/apache_beam/utils/subprocess_server_test.py b/sdks/python/apache_beam/utils/subprocess_server_test.py index e0d0892c8e68..c0c8e5694b86 100644 --- a/sdks/python/apache_beam/utils/subprocess_server_test.py +++ b/sdks/python/apache_beam/utils/subprocess_server_test.py @@ -21,6 +21,7 @@ import glob import os +import random import re import shutil import socketserver @@ -166,5 +167,83 @@ def test_classpath_jar(self): os.chdir(oldwd) +class CacheTest(unittest.TestCase): + @staticmethod + def with_prefix(prefix): + return '%s-%s' % (prefix, random.random()) + + def test_memoization(self): + cache = subprocess_server._SharedCache(self.with_prefix, lambda x: None) + try: + token = cache.register() + a = 
cache.get('a') + self.assertEqual(a[0], 'a') + self.assertEqual(cache.get('a'), a) + b = cache.get('b') + self.assertEqual(b[0], 'b') + self.assertEqual(cache.get('b'), b) + finally: + cache.purge(token) + + def test_purged(self): + cache = subprocess_server._SharedCache(self.with_prefix, lambda x: None) + try: + token = cache.register() + a = cache.get('a') + self.assertEqual(cache.get('a'), a) + finally: + cache.purge(token) + + try: + token = cache.register() + new_a = cache.get('a') + self.assertNotEqual(new_a, a) + finally: + cache.purge(token) + + def test_multiple_owners(self): + cache = subprocess_server._SharedCache(self.with_prefix, lambda x: None) + try: + owner1 = cache.register() + a = cache.get('a') + try: + self.assertEqual(cache.get('a'), a) + owner2 = cache.register() + b = cache.get('b') + self.assertEqual(cache.get('b'), b) + finally: + cache.purge(owner2) + self.assertEqual(cache.get('a'), a) + self.assertEqual(cache.get('b'), b) + finally: + cache.purge(owner1) + + try: + owner3 = cache.register() + self.assertNotEqual(cache.get('a'), a) + self.assertNotEqual(cache.get('b'), b) + finally: + cache.purge(owner3) + + def test_interleaved_owners(self): + cache = subprocess_server._SharedCache(self.with_prefix, lambda x: None) + owner1 = cache.register() + a = cache.get('a') + self.assertEqual(cache.get('a'), a) + + owner2 = cache.register() + b = cache.get('b') + self.assertEqual(cache.get('b'), b) + + cache.purge(owner1) + self.assertNotEqual(cache.get('a'), a) + self.assertEqual(cache.get('b'), b) + + cache.purge(owner2) + owner3 = cache.register() + self.assertNotEqual(cache.get('b'), b) + cache.purge(owner3) + + if __name__ == '__main__': unittest.main() From e8cebdf4fe7e920b6494156aa3127037f23b58fe Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Thu, 2 Nov 2023 17:33:13 -0700 Subject: [PATCH 5/7] Add an option to start loopback workers in expansion services. 
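
With this option enabled, the expansion service offers an AnyOf default environment whose alternatives include an EXTERNAL (loopback) environment pointing back at the service's own worker pool, so a runner may execute the expanded transforms in the expansion process itself. A minimal sketch of the Python wiring, mirroring expansion_service_main below (the port is arbitrary):

    import grpc
    from apache_beam.options.pipeline_options import PipelineOptions
    from apache_beam.portability.api import beam_expansion_api_pb2_grpc
    from apache_beam.portability.api import beam_fn_api_pb2_grpc
    from apache_beam.runners.portability import expansion_service
    from apache_beam.runners.worker import worker_pool_main
    from apache_beam.utils import thread_pool_executor

    address = '[::]:8097'
    server = grpc.server(thread_pool_executor.shared_unbounded_instance())
    # Host a worker pool on the same port so this process can itself serve
    # as the cross-language worker.
    beam_fn_api_pb2_grpc.add_BeamFnExternalWorkerPoolServicer_to_server(
        worker_pool_main.BeamFnExternalWorkerPoolServicer(), server)
    beam_expansion_api_pb2_grpc.add_ExpansionServiceServicer_to_server(
        expansion_service.ExpansionServiceServicer(
            PipelineOptions(), loopback_address=address), server)
    server.add_insecure_port(address)
    server.start()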
--- sdks/java/expansion-service/build.gradle | 3 ++ .../expansion/service/ExpansionService.java | 37 +++++++++++++++---- .../service/ExpansionServiceOptions.java | 5 +++ .../runners/portability/expansion_service.py | 14 +++++-- .../portability/expansion_service_main.py | 16 +++++++- 5 files changed, 62 insertions(+), 13 deletions(-) diff --git a/sdks/java/expansion-service/build.gradle b/sdks/java/expansion-service/build.gradle index 7bc77e3aea80..99c515cd4e63 100644 --- a/sdks/java/expansion-service/build.gradle +++ b/sdks/java/expansion-service/build.gradle @@ -36,11 +36,14 @@ test { dependencies { implementation project(path: ":model:pipeline", configuration: "shadow") + implementation project(path: ":model:fn-execution", configuration: "shadow") implementation project(path: ":model:job-management", configuration: "shadow") implementation project(path: ":sdks:java:core", configuration: "shadow") implementation project(path: ":runners:core-construction-java") implementation project(path: ":runners:java-fn-execution") implementation project(path: ":sdks:java:fn-execution") + implementation project(path: ":sdks:java:harness") + permitUnusedDeclared project(path: ":model:fn-execution") permitUnusedDeclared project(path: ":sdks:java:fn-execution") implementation library.java.jackson_annotations implementation library.java.jackson_databind diff --git a/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java index c6509a5bbb4a..8e57ad706fed 100644 --- a/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java +++ b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java @@ -36,6 +36,7 @@ import java.util.ServiceLoader; import java.util.Set; import java.util.stream.Collectors; +import org.apache.beam.fn.harness.ExternalWorkerService; import org.apache.beam.model.expansion.v1.ExpansionApi; import org.apache.beam.model.expansion.v1.ExpansionApi.DiscoverSchemaTransformRequest; import org.apache.beam.model.expansion.v1.ExpansionApi.DiscoverSchemaTransformResponse; @@ -514,6 +515,7 @@ default List getDependencies(RunnerApi.FunctionSpec spec, PipelineOption private @MonotonicNonNull Map registeredTransforms; private final PipelineOptions pipelineOptions; + private final @Nullable String loopbackAddress; public ExpansionService() { this(new String[] {}); @@ -524,7 +526,12 @@ public ExpansionService(String[] args) { } public ExpansionService(PipelineOptions opts) { + this(opts, null); + } + + public ExpansionService(PipelineOptions opts, @Nullable String loopbackAddress) { this.pipelineOptions = opts; + this.loopbackAddress = loopbackAddress; } private Map getRegisteredTransforms() { @@ -628,9 +635,19 @@ private Map loadRegisteredTransforms() { rehydratedComponents .getSdkComponents(request.getRequirementsList()) .withNewIdPrefix(request.getNamespace()); - sdkComponents.registerEnvironment( + RunnerApi.Environment defaultEnvironment = Environments.createOrGetDefaultEnvironment( - pipeline.getOptions().as(PortablePipelineOptions.class))); + pipeline.getOptions().as(PortablePipelineOptions.class)); + if (pipelineOptions.as(ExpansionServiceOptions.class).getAlsoStartLoopbackWorker()) { + PortablePipelineOptions externalOptions = + PipelineOptionsFactory.create().as(PortablePipelineOptions.class); + 
externalOptions.setDefaultEnvironmentType(Environments.ENVIRONMENT_EXTERNAL); + externalOptions.setDefaultEnvironmentConfig(loopbackAddress); + defaultEnvironment = + Environments.createAnyOfEnvironment( + defaultEnvironment, Environments.createOrGetDefaultEnvironment(externalOptions)); + } + sdkComponents.registerEnvironment(defaultEnvironment); Map outputMap = outputs.entrySet().stream() .collect( @@ -759,9 +776,12 @@ public static void main(String[] args) throws Exception { // Register the options class used by the expansion service. PipelineOptionsFactory.register(ExpansionServiceOptions.class); + @SuppressWarnings({"nullness"}) + PipelineOptions options = + PipelineOptionsFactory.fromArgs(Arrays.copyOfRange(args, 1, args.length)).create(); @SuppressWarnings("nullness") - ExpansionService service = new ExpansionService(Arrays.copyOfRange(args, 1, args.length)); + ExpansionService service = new ExpansionService(options, "localhost:" + port); StringBuilder registeredTransformsLog = new StringBuilder(); boolean registeredTransformsFound = false; @@ -794,11 +814,12 @@ public static void main(String[] args) throws Exception { System.out.println("\nDid not find any registered transforms or SchemaTransforms.\n"); } - Server server = - ServerBuilder.forPort(port) - .addService(service) - .addService(new ArtifactRetrievalService()) - .build(); + ServerBuilder serverBuilder = + ServerBuilder.forPort(port).addService(service).addService(new ArtifactRetrievalService()); + if (options.as(ExpansionServiceOptions.class).getAlsoStartLoopbackWorker()) { + serverBuilder.addService(new ExternalWorkerService(options)); + } + Server server = serverBuilder.build(); server.start(); server.awaitTermination(); } diff --git a/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionServiceOptions.java b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionServiceOptions.java index b3af511cd94f..6b569c7fce38 100644 --- a/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionServiceOptions.java +++ b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionServiceOptions.java @@ -42,6 +42,11 @@ public interface ExpansionServiceOptions extends PipelineOptions { void setJavaClassLookupAllowlistFile(String file); + @Description("Whether to also start a loopback worker as part of this service.") + boolean getAlsoStartLoopbackWorker(); + + void setAlsoStartLoopbackWorker(boolean value); + @Description("Expansion service configuration file.") String getExpansionServiceConfigFile(); diff --git a/sdks/python/apache_beam/runners/portability/expansion_service.py b/sdks/python/apache_beam/runners/portability/expansion_service.py index 8be9d98508ed..50c793a0e8bd 100644 --- a/sdks/python/apache_beam/runners/portability/expansion_service.py +++ b/sdks/python/apache_beam/runners/portability/expansion_service.py @@ -22,6 +22,7 @@ import traceback from apache_beam import pipeline as beam_pipeline +from apache_beam.portability import common_urns from apache_beam.portability import python_urns from apache_beam.portability.api import beam_expansion_api_pb2 from apache_beam.portability.api import beam_expansion_api_pb2_grpc @@ -33,11 +34,18 @@ class ExpansionServiceServicer( beam_expansion_api_pb2_grpc.ExpansionServiceServicer): - def __init__(self, options=None): + def __init__(self, options=None, loopback_address=None): self._options = options or beam_pipeline.PipelineOptions( 
environment_type=python_urns.EMBEDDED_PYTHON, sdk_location='container') - self._default_environment = ( - environments.Environment.from_options(self._options)) + default_environment = (environments.Environment.from_options(self._options)) + if loopback_address: + loopback_environment = environments.Environment.from_options( + beam_pipeline.PipelineOptions( + environment_type=common_urns.environments.EXTERNAL.urn, + environment_config=loopback_address)) + default_environment = environments.AnyOfEnvironment( + [default_environment, loopback_environment]) + self._default_environment = default_environment def Expand(self, request, context=None): try: diff --git a/sdks/python/apache_beam/runners/portability/expansion_service_main.py b/sdks/python/apache_beam/runners/portability/expansion_service_main.py index 30cbdccb596e..307f6bd54182 100644 --- a/sdks/python/apache_beam/runners/portability/expansion_service_main.py +++ b/sdks/python/apache_beam/runners/portability/expansion_service_main.py @@ -28,8 +28,10 @@ from apache_beam.options.pipeline_options import SetupOptions from apache_beam.portability.api import beam_artifact_api_pb2_grpc from apache_beam.portability.api import beam_expansion_api_pb2_grpc +from apache_beam.portability.api import beam_fn_api_pb2_grpc from apache_beam.runners.portability import artifact_service from apache_beam.runners.portability import expansion_service +from apache_beam.runners.worker import worker_pool_main from apache_beam.transforms import fully_qualified_named_transform from apache_beam.utils import thread_pool_executor @@ -41,6 +43,7 @@ def main(argv): parser.add_argument( '-p', '--port', type=int, help='port on which to serve the job api') parser.add_argument('--fully_qualified_name_glob', default=None) + parser.add_argument('--serve_loopback_worker', action='store_true') known_args, pipeline_args = parser.parse_known_args(argv) pipeline_options = PipelineOptions( pipeline_args + ["--experiments=beam_fn_api", "--sdk_location=container"]) @@ -52,14 +55,23 @@ def main(argv): with fully_qualified_named_transform.FullyQualifiedNamedTransform.with_filter( known_args.fully_qualified_name_glob): + address = '[::]:{}'.format(known_args.port) server = grpc.server(thread_pool_executor.shared_unbounded_instance()) + if known_args.serve_loopback_worker: + beam_fn_api_pb2_grpc.add_BeamFnExternalWorkerPoolServicer_to_server( + worker_pool_main.BeamFnExternalWorkerPoolServicer(), server) + loopback_address = address + else: + loopback_address = None beam_expansion_api_pb2_grpc.add_ExpansionServiceServicer_to_server( - expansion_service.ExpansionServiceServicer(pipeline_options), server) + expansion_service.ExpansionServiceServicer( + pipeline_options, loopback_address=loopback_address), + server) beam_artifact_api_pb2_grpc.add_ArtifactRetrievalServiceServicer_to_server( artifact_service.ArtifactRetrievalService( artifact_service.BeamFilesystemHandler(None).file_reader), server) - server.add_insecure_port('[::]:{}'.format(known_args.port)) + server.add_insecure_port(address) server.start() _LOGGER.info('Listening for expansion requests at %d', known_args.port) From e88fb5df88774ce570240413ffd09c51ea4bd97a Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Thu, 2 Nov 2023 17:30:57 -0700 Subject: [PATCH 6/7] Use Java expansion services as workers by default in Python. Due to the AnyOf environment, remote runners can choose more expensive but remote-friendly options such as docker. 
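
Concretely, entering a Pipeline context enables subprocess caching (previous commit), _default_args sees the live cache and passes --alsoStartLoopbackWorker to the auto-started Java expansion service, and the resulting AnyOf environment lets a local runner pick the loopback worker while a remote runner still picks Docker. A sketch of the user-visible side (the actual cross-language transform is elided):

    import apache_beam as beam

    with beam.Pipeline() as p:
      # Expansion services started for cross-language transforms built here
      # stay alive for the whole pipeline and double as loopback workers,
      # so local execution no longer needs a Docker image.
      ...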
--- CHANGES.md | 3 +++ sdks/python/apache_beam/transforms/external.py | 7 ++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/CHANGES.md b/CHANGES.md index 9318e85d477b..4cf4cd1fd497 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -71,6 +71,9 @@ * X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). * The Python SDK now type checks `collections.abc.Collections` types properly. Some type hints that were erroneously allowed by the SDK may now fail. ([#29272](https://github.com/apache/beam/pull/29272)) +* Running multi-language pipelines locally no longer requires Docker. + Instead, the same (generally auto-started) subprocess used to perform the + expansion can also be used as the cross-language worker. ## Breaking Changes diff --git a/sdks/python/apache_beam/transforms/external.py b/sdks/python/apache_beam/transforms/external.py index 52e27ecc2e8b..fc4ae3caa6df 100644 --- a/sdks/python/apache_beam/transforms/external.py +++ b/sdks/python/apache_beam/transforms/external.py @@ -980,7 +980,12 @@ def _default_args(self): to_stage = ','.join([self._path_to_jar] + sum(( JavaJarExpansionService._expand_jars(jar) for jar in self._classpath or []), [])) - return ['{{PORT}}', f'--filesToStage={to_stage}'] + args = ['{{PORT}}', f'--filesToStage={to_stage}'] + # TODO(robertwb): See if it's possible to scope this per pipeline. + # Checks to see if the cache is being used for this server. + if subprocess_server.SubprocessServer._cache._live_owners: + args.append('--alsoStartLoopbackWorker') + return args def __enter__(self): if self._service_count == 0: From d85ce3f38c531f1f880571b884c75cf958cbba60 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Fri, 3 Nov 2023 09:52:43 -0700 Subject: [PATCH 7/7] More complete documentation of environment proto definitions. --- .../model/pipeline/v1/beam_runner_api.proto | 32 +++++++++++++------ 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/beam_runner_api.proto b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/beam_runner_api.proto index 87af7c19dd79..48f057bbc1ea 100644 --- a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/beam_runner_api.proto +++ b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/beam_runner_api.proto @@ -1563,15 +1563,26 @@ message Environment { message StandardEnvironments { enum Environments { - DOCKER = 0 [(beam_urn) = "beam:env:docker:v1"]; // A managed docker container to run user code. - - PROCESS = 1 [(beam_urn) = "beam:env:process:v1"]; // A managed native process to run user code. - - EXTERNAL = 2 [(beam_urn) = "beam:env:external:v1"]; // An external non managed process to run user code. - - DEFAULT = 3 [(beam_urn) = "beam:env:default:v1"]; // Used as a stub when context is missing a runner-provided default environment. - - ANYOF = 4 [(beam_urn) = "beam:env:anyof:v1"]; // A selection of equivalent environments a runner may use. + // A managed docker container to run user code. + // Payload should be DockerPayload. + DOCKER = 0 [(beam_urn) = "beam:env:docker:v1"]; + + // A managed native process to run user code. + // Payload should be ProcessPayload. + PROCESS = 1 [(beam_urn) = "beam:env:process:v1"]; + + // An external non managed process to run user code. + // Payload should be ExternalPayload. + EXTERNAL = 2 [(beam_urn) = "beam:env:external:v1"]; + + // Used as a stub when context is missing a runner-provided default environment. 
+ DEFAULT = 3 [(beam_urn) = "beam:env:default:v1"]; + + // A selection of equivalent fully-specified environments a runner may use. + // Note that this environment itself does not declare any dependencies or capabilities, + // as those may differ among the several alternatives. + // Payload should be AnyOfEnvironmentPayload. + ANYOF = 4 [(beam_urn) = "beam:env:anyof:v1"]; } } @@ -1588,11 +1599,12 @@ message ProcessPayload { } message ExternalPayload { - ApiServiceDescriptor endpoint = 1; + ApiServiceDescriptor endpoint = 1; // Serving BeamFnExternalWorkerPool API. map<string, bytes> params = 2; // Arbitrary extra parameters to pass } message AnyOfEnvironmentPayload { + // Each is fully contained (with its own dependencies, capabilities, etc.) repeated Environment environments = 1; }
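
To make the documented semantics concrete, a small self-contained example using the helpers introduced earlier in this series (payload-less environments are used purely for illustration):

    from apache_beam.portability import common_urns
    from apache_beam.portability.api import beam_runner_api_pb2
    from apache_beam.transforms import environments

    docker = beam_runner_api_pb2.Environment(
        urn=common_urns.environments.DOCKER.urn)
    external = beam_runner_api_pb2.Environment(
        urn=common_urns.environments.EXTERNAL.urn)
    any_of = environments.AnyOfEnvironment.create_proto([docker, external])

    # AnyOf environments (nested or not) flatten to their leaf alternatives.
    assert list(environments.expand_anyof_environments(any_of)) == [
        docker, external]
    # A runner resolves to its most preferred available alternative.
    assert environments.resolve_anyof_environment(
        any_of, common_urns.environments.EXTERNAL.urn) == external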