From ac0f3c3f480d04e8729e7a2ebadc33961f8ffc9f Mon Sep 17 00:00:00 2001 From: Jack McCluskey <34928439+jrmccluskey@users.noreply.github.com> Date: Tue, 17 Dec 2024 14:44:13 -0500 Subject: [PATCH] Add TF MNIST classification cost benchmark (#33391) * Add TF MNIST classification cost benchmark * linting * Generalize to single workflow file for cost benchmarks * fix incorrect UTC time in comment * move wordcount to same workflow * update workflow job name --- ...> beam_Python_CostBenchmarks_Dataflow.yml} | 28 +- .../python_tf_mnist_classification.txt | 29 ++ .../apache_beam/examples/inference/output.txt | 3 + ...low_mnist_classification_cost_benchmark.py | 41 ++ .../apache_beam/yaml/generate_yaml_docs.py | 1 + sdks/python/apache_beam/yaml/main.py | 6 +- sdks/python/apache_beam/yaml/ml.yaml | 46 ++ sdks/python/apache_beam/yaml/requirements.txt | 3 + .../apache_beam/yaml/standard_providers.yaml | 1 + .../apache_beam/yaml/tests/bigquery.yaml | 6 +- sdks/python/apache_beam/yaml/yaml_combine.py | 2 +- sdks/python/apache_beam/yaml/yaml_mapping.py | 9 +- sdks/python/apache_beam/yaml/yaml_ml.py | 421 +++++++++++++++++- .../yaml/yaml_provider_unit_test.py | 2 +- .../python/apache_beam/yaml/yaml_transform.py | 79 +--- .../yaml/yaml_transform_scope_test.py | 2 +- .../yaml/yaml_transform_unit_test.py | 53 +-- sdks/python/apache_beam/yaml/yaml_utils.py | 58 +++ .../apache_beam/yaml/yaml_utils_test.py | 56 +++ 19 files changed, 711 insertions(+), 135 deletions(-) rename .github/workflows/{beam_Wordcount_Python_Cost_Benchmark_Dataflow.yml => beam_Python_CostBenchmarks_Dataflow.yml} (69%) create mode 100644 .github/workflows/cost-benchmarks-pipeline-options/python_tf_mnist_classification.txt create mode 100644 sdks/python/apache_beam/examples/inference/output.txt create mode 100644 sdks/python/apache_beam/testing/benchmarks/inference/tensorflow_mnist_classification_cost_benchmark.py create mode 100644 sdks/python/apache_beam/yaml/ml.yaml create mode 100644 sdks/python/apache_beam/yaml/requirements.txt create mode 100644 sdks/python/apache_beam/yaml/yaml_utils.py create mode 100644 sdks/python/apache_beam/yaml/yaml_utils_test.py diff --git a/.github/workflows/beam_Wordcount_Python_Cost_Benchmark_Dataflow.yml b/.github/workflows/beam_Python_CostBenchmarks_Dataflow.yml similarity index 69% rename from .github/workflows/beam_Wordcount_Python_Cost_Benchmark_Dataflow.yml rename to .github/workflows/beam_Python_CostBenchmarks_Dataflow.yml index 51d1005affbc..18fe37e142ac 100644 --- a/.github/workflows/beam_Wordcount_Python_Cost_Benchmark_Dataflow.yml +++ b/.github/workflows/beam_Python_CostBenchmarks_Dataflow.yml @@ -13,9 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-name: Wordcount Python Cost Benchmarks Dataflow +name: Python Cost Benchmarks Dataflow on: + schedule: + - cron: '30 18 * * 6' # Run at 6:30 pm UTC on Saturdays workflow_dispatch: #Setting explicit permissions for the action to avoid the default permissions which are `write-all` in case of pull_request_target event @@ -47,16 +49,17 @@ env: INFLUXDB_USER_PASSWORD: ${{ secrets.INFLUXDB_USER_PASSWORD }} jobs: - beam_Inference_Python_Benchmarks_Dataflow: + beam_Python_Cost_Benchmarks_Dataflow: if: | - github.event_name == 'workflow_dispatch' + github.event_name == 'workflow_dispatch' || + (github.event_name == 'schedule' && github.repository == 'apache/beam') runs-on: [self-hosted, ubuntu-20.04, main] timeout-minutes: 900 name: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) strategy: matrix: - job_name: ["beam_Wordcount_Python_Cost_Benchmarks_Dataflow"] - job_phrase: ["Run Wordcount Cost Benchmark"] + job_name: ["beam_Python_CostBenchmark_Dataflow"] + job_phrase: ["Run Python Dataflow Cost Benchmarks"] steps: - uses: actions/checkout@v4 - name: Setup repository @@ -76,10 +79,11 @@ jobs: test-language: python argument-file-paths: | ${{ github.workspace }}/.github/workflows/cost-benchmarks-pipeline-options/python_wordcount.txt + ${{ github.workspace }}/.github/workflows/cost-benchmarks-pipeline-options/python_tf_mnist_classification.txt # The env variables are created and populated in the test-arguments-action as "_test_arguments_" - name: get current time run: echo "NOW_UTC=$(date '+%m%d%H%M%S' --utc)" >> $GITHUB_ENV - - name: run wordcount on Dataflow Python + - name: Run wordcount on Dataflow uses: ./.github/actions/gradle-command-self-hosted-action timeout-minutes: 30 with: @@ -88,4 +92,14 @@ jobs: -PloadTest.mainClass=apache_beam.testing.benchmarks.wordcount.wordcount \ -Prunner=DataflowRunner \ -PpythonVersion=3.10 \ - '-PloadTest.args=${{ env.beam_Inference_Python_Benchmarks_Dataflow_test_arguments_1 }} --job_name=benchmark-tests-wordcount-python-${{env.NOW_UTC}} --output=gs://temp-storage-for-end-to-end-tests/wordcount/result_wordcount-${{env.NOW_UTC}}.txt' \ \ No newline at end of file + '-PloadTest.args=${{ env.beam_Inference_Python_Benchmarks_Dataflow_test_arguments_1 }} --job_name=benchmark-tests-wordcount-python-${{env.NOW_UTC}} --output_file=gs://temp-storage-for-end-to-end-tests/wordcount/result_wordcount-${{env.NOW_UTC}}.txt' \ + - name: Run Tensorflow MNIST Image Classification on Dataflow + uses: ./.github/actions/gradle-command-self-hosted-action + timeout-minutes: 30 + with: + gradle-command: :sdks:python:apache_beam:testing:load_tests:run + arguments: | + -PloadTest.mainClass=apache_beam.testing.benchmarks.inference.tensorflow_mnist_classification_cost_benchmark \ + -Prunner=DataflowRunner \ + -PpythonVersion=3.10 \ + '-PloadTest.args=${{ env.beam_Inference_Python_Benchmarks_Dataflow_test_arguments_2 }} --job_name=benchmark-tests-tf-mnist-classification-python-${{env.NOW_UTC}} --input_file=gs://apache-beam-ml/testing/inputs/it_mnist_data.csv --output_file=gs://temp-storage-for-end-to-end-tests/wordcount/result_tf_mnist-${{env.NOW_UTC}}.txt --model=gs://apache-beam-ml/models/tensorflow/mnist/' \ \ No newline at end of file diff --git a/.github/workflows/cost-benchmarks-pipeline-options/python_tf_mnist_classification.txt b/.github/workflows/cost-benchmarks-pipeline-options/python_tf_mnist_classification.txt new file mode 100644 index 000000000000..01f4460b8c7e --- /dev/null +++ b/.github/workflows/cost-benchmarks-pipeline-options/python_tf_mnist_classification.txt @@ -0,0 
+1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +--region=us-central1 +--machine_type=n1-standard-2 +--num_workers=1 +--disk_size_gb=50 +--autoscaling_algorithm=NONE +--input_options={} +--staging_location=gs://temp-storage-for-perf-tests/loadtests +--temp_location=gs://temp-storage-for-perf-tests/loadtests +--requirements_file=apache_beam/ml/inference/tensorflow_tests_requirements.txt +--publish_to_big_query=true +--metrics_dataset=beam_run_inference +--metrics_table=tf_mnist_classification +--runner=DataflowRunner \ No newline at end of file diff --git a/sdks/python/apache_beam/examples/inference/output.txt b/sdks/python/apache_beam/examples/inference/output.txt new file mode 100644 index 000000000000..e2b89887782e --- /dev/null +++ b/sdks/python/apache_beam/examples/inference/output.txt @@ -0,0 +1,3 @@ +What does Apache Beam do?;enables batch and streaming data processing +What is the capital of France?;Paris +Where was beam summit?;NYC diff --git a/sdks/python/apache_beam/testing/benchmarks/inference/tensorflow_mnist_classification_cost_benchmark.py b/sdks/python/apache_beam/testing/benchmarks/inference/tensorflow_mnist_classification_cost_benchmark.py new file mode 100644 index 000000000000..f7e12dcead03 --- /dev/null +++ b/sdks/python/apache_beam/testing/benchmarks/inference/tensorflow_mnist_classification_cost_benchmark.py @@ -0,0 +1,41 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# pytype: skip-file + +import logging + +from apache_beam.examples.inference import tensorflow_mnist_classification +from apache_beam.testing.load_tests.dataflow_cost_benchmark import DataflowCostBenchmark + + +class TensorflowMNISTClassificationCostBenchmark(DataflowCostBenchmark): + def __init__(self): + super().__init__() + + def test(self): + extra_opts = {} + extra_opts['input'] = self.pipeline.get_option('input_file') + extra_opts['output'] = self.pipeline.get_option('output_file') + extra_opts['model_path'] = self.pipeline.get_option('model') + tensorflow_mnist_classification.run( + self.pipeline.get_full_options_as_args(**extra_opts), + save_main_session=False) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.INFO) + TensorflowMNISTClassificationCostBenchmark().run() diff --git a/sdks/python/apache_beam/yaml/generate_yaml_docs.py b/sdks/python/apache_beam/yaml/generate_yaml_docs.py index 27e17029f387..26040fcee3ba 100644 --- a/sdks/python/apache_beam/yaml/generate_yaml_docs.py +++ b/sdks/python/apache_beam/yaml/generate_yaml_docs.py @@ -241,6 +241,7 @@ def main(): json_config_schemas = [] markdown_out = io.StringIO() providers = yaml_provider.standard_providers() + providers = {'RunInference': providers['RunInference']} for transform_base, transforms in itertools.groupby( sorted(providers.keys(), key=io_grouping_key), key=lambda s: s.split('-')[0]): diff --git a/sdks/python/apache_beam/yaml/main.py b/sdks/python/apache_beam/yaml/main.py index e4f03ad300b3..475e5da3a9c6 100644 --- a/sdks/python/apache_beam/yaml/main.py +++ b/sdks/python/apache_beam/yaml/main.py @@ -26,6 +26,7 @@ from apache_beam.typehints.schemas import LogicalType from apache_beam.typehints.schemas import MillisInstant from apache_beam.yaml import yaml_transform +from apache_beam.yaml.yaml_utils import SafeLineLoader def _preparse_jinja_flags(argv): @@ -126,15 +127,14 @@ def run(argv=None): pipeline_template = _pipeline_spec_from_args(known_args) pipeline_yaml = yaml_transform.expand_jinja( pipeline_template, known_args.jinja_variables or {}) - pipeline_spec = yaml.load(pipeline_yaml, Loader=yaml_transform.SafeLineLoader) + pipeline_spec = yaml.load(pipeline_yaml, Loader=SafeLineLoader) with _fix_xlang_instant_coding(): with beam.Pipeline( # linebreak for better yapf formatting options=beam.options.pipeline_options.PipelineOptions( pipeline_args, pickle_library='cloudpickle', - **yaml_transform.SafeLineLoader.strip_metadata(pipeline_spec.get( - 'options', {}))), + **SafeLineLoader.strip_metadata(pipeline_spec.get('options', {}))), display_data={'yaml': pipeline_yaml, 'yaml_jinja_template': pipeline_template, 'yaml_jinja_variables': json.dumps( diff --git a/sdks/python/apache_beam/yaml/ml.yaml b/sdks/python/apache_beam/yaml/ml.yaml new file mode 100644 index 000000000000..b613532db5da --- /dev/null +++ b/sdks/python/apache_beam/yaml/ml.yaml @@ -0,0 +1,46 @@ +pipeline: + - type: RunInference + config: + model_handler: + type: Huggingface + config: + task: translation_en_to_fr + model: google-t5/t5-small + preprocess_fn: | + def preprocess_fn(_element): + ... + inference_fn: | + def inference_fn(batch, pipeline, inference_args): + ... + postprocess_fn: | + def postprocess_fn(result): + ... 
+ load_model_args: + framework: pt + revision: main + device: gpu + min_batch_size: 1 + max_batch_size: 2 + max_batch_duration: 60s + large_model: false + model_copies: 1 + env_vars: + SOME_ENV_VAR: val + using_key: key + element: element + inference_args: + arg1: val1 + arg2: val2 +transform_providers: + - ... + +# Would probably require callable syntax +model_handler_providers: + - type: python + config: + packages: + - 'some_pypi_package>=version' + model_handlers: + SomeName: 'pkg.module.MyModelHandler' + SomeOtherName: 'pkg.module.MyOtherModelHandler' + diff --git a/sdks/python/apache_beam/yaml/requirements.txt b/sdks/python/apache_beam/yaml/requirements.txt new file mode 100644 index 000000000000..cbc328f3d6fc --- /dev/null +++ b/sdks/python/apache_beam/yaml/requirements.txt @@ -0,0 +1,3 @@ +tensorflow +torch +transformers \ No newline at end of file diff --git a/sdks/python/apache_beam/yaml/standard_providers.yaml b/sdks/python/apache_beam/yaml/standard_providers.yaml index 242faaa9a77b..31eb5e1c6daa 100644 --- a/sdks/python/apache_beam/yaml/standard_providers.yaml +++ b/sdks/python/apache_beam/yaml/standard_providers.yaml @@ -56,6 +56,7 @@ config: {} transforms: MLTransform: 'apache_beam.yaml.yaml_ml.ml_transform' + RunInference: 'apache_beam.yaml.yaml_ml.run_inference' - type: renaming transforms: diff --git a/sdks/python/apache_beam/yaml/tests/bigquery.yaml b/sdks/python/apache_beam/yaml/tests/bigquery.yaml index f5ab31b3855b..1189fb6e68d1 100644 --- a/sdks/python/apache_beam/yaml/tests/bigquery.yaml +++ b/sdks/python/apache_beam/yaml/tests/bigquery.yaml @@ -38,7 +38,7 @@ pipelines: - {label: "389a", rank: 2} - type: WriteToBigQuery config: - table: "{BQ_TABLE}" + table: "{BQ_TABLE[0]}:{BQ_TABLE[1]}.{BQ_TABLE[2]}" options: project: "apache-beam-testing" temp_location: "{TEMP_DIR}" @@ -48,7 +48,7 @@ pipelines: transforms: - type: ReadFromBigQuery config: - table: "{BQ_TABLE}" + table: "{BQ_TABLE[0]}:{BQ_TABLE[1]}.{BQ_TABLE[2]}" - type: AssertEqual config: elements: @@ -64,7 +64,7 @@ pipelines: transforms: - type: ReadFromBigQuery config: - table: "{BQ_TABLE}" + table: "{BQ_TABLE[0]}:{BQ_TABLE[1]}.{BQ_TABLE[2]}" fields: ["label"] row_restriction: "rank > 0" - type: AssertEqual diff --git a/sdks/python/apache_beam/yaml/yaml_combine.py b/sdks/python/apache_beam/yaml/yaml_combine.py index b7499f3b0c7a..84963b150fb6 100644 --- a/sdks/python/apache_beam/yaml/yaml_combine.py +++ b/sdks/python/apache_beam/yaml/yaml_combine.py @@ -31,6 +31,7 @@ from apache_beam.utils import python_callable from apache_beam.yaml import yaml_mapping from apache_beam.yaml import yaml_provider +from apache_beam.yaml.yaml_utils import SafeLineLoader BUILTIN_COMBINE_FNS = { 'sum': sum, @@ -61,7 +62,6 @@ def normalize_combine(spec): fn: type: fn_type """ - from apache_beam.yaml.yaml_transform import SafeLineLoader if spec['type'] == 'Combine': config = spec.get('config') if isinstance(config.get('group_by'), str): diff --git a/sdks/python/apache_beam/yaml/yaml_mapping.py b/sdks/python/apache_beam/yaml/yaml_mapping.py index 8f4a2118c236..11cec247df55 100644 --- a/sdks/python/apache_beam/yaml/yaml_mapping.py +++ b/sdks/python/apache_beam/yaml/yaml_mapping.py @@ -297,10 +297,13 @@ def _expand_python_mapping_func( # TODO(robertwb): Consider constructing a single callable that takes # the row and returns the new row, rather than invoking (and unpacking) # for each field individually. 
- source = '\n'.join(['def fn(__row__):'] + [ - f' {name} = __row__.{name}' + source = '\n'.join(['def fn(__row__):'] + [' try:'] + [ + f' {name} = __row__.{name}' for name in original_fields if name in expression - ] + [' return (' + expression + ')']) + ] + [f' return ({expression})'] + [' except NameError as e:'] + [ + f' raise ValueError(f"{{e}}. Valid values include ' + f'{original_fields}")' + ]) else: source = callable diff --git a/sdks/python/apache_beam/yaml/yaml_ml.py b/sdks/python/apache_beam/yaml/yaml_ml.py index 33f2eeefd296..111618679f96 100644 --- a/sdks/python/apache_beam/yaml/yaml_ml.py +++ b/sdks/python/apache_beam/yaml/yaml_ml.py @@ -16,13 +16,20 @@ # """This module defines yaml wrappings for some ML transforms.""" - from typing import Any +from typing import Callable +from typing import Dict from typing import List from typing import Optional import apache_beam as beam +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.inference import RunInference +from apache_beam.ml.inference.base import KeyedModelHandler +from apache_beam.typehints.row_type import RowTypeConstraint +from apache_beam.utils import python_callable from apache_beam.yaml import options +from apache_beam.yaml.yaml_utils import SafeLineLoader try: from apache_beam.ml.transforms import tft @@ -33,11 +40,419 @@ tft = None # type: ignore +def normalize_ml(spec): + if spec['type'] == 'RunInference': + config = spec.get('config') + for required in ('model_handler', ): + if required not in config: + raise ValueError( + f'Missing {required} parameter in RunInference config ' + f'at line {SafeLineLoader.get_line(spec)}') + model_handler = config.get('model_handler') + if not isinstance(model_handler, dict): + raise ValueError( + 'Invalid model_handler specification at line ' + f'{SafeLineLoader.get_line(spec)}. 
Expected ' + f'dict but was {type(model_handler)}.') + for required in ('type', 'config'): + if required not in model_handler: + raise ValueError( + f'Missing {required} in model handler ' + f'at line {SafeLineLoader.get_line(model_handler)}') + typ = model_handler['type'] + extra_params = set(SafeLineLoader.strip_metadata(model_handler).keys()) - { + 'type', 'config' + } + if extra_params: + raise ValueError( + f'Unexpected parameters in model handler of type {typ} ' + f'at line {SafeLineLoader.get_line(spec)}: {extra_params}') + model_handler_provider = ModelHandlerProvider.handler_types.get(typ, None) + if model_handler_provider: + model_handler_provider.validate(model_handler['config']) + else: + raise NotImplementedError( + f'Unknown model handler type: {typ} ' + f'at line {SafeLineLoader.get_line(spec)}.') + + return spec + + +class ModelHandlerProvider: + handler_types: Dict[str, Callable[..., "ModelHandlerProvider"]] = {} + + def __init__( + self, handler, preprocess: Callable = None, postprocess: Callable = None): + self._handler = handler + self._preprocess = self.parse_processing_transform( + preprocess, 'preprocess') or self.preprocess_fn + self._postprocess = self.parse_processing_transform( + postprocess, 'postprocess') or self.postprocess_fn + + def get_output_schema(self): + return Any + + @staticmethod + def parse_processing_transform(processing_transform, typ): + def _parse_config(callable=None, path=None, name=None): + if callable and (path or name): + raise ValueError( + f"Cannot specify 'callable' with 'path' and 'name' for {typ} " + f"function.") + if path and name: + return python_callable.PythonCallableWithSource.load_from_script( + FileSystems.open(path).read().decode(), name) + elif callable: + return python_callable.PythonCallableWithSource(callable) + else: + raise ValueError( + f"Must specify one of 'callable' or 'path' and 'name' for {typ} " + f"function.") + + if processing_transform: + if isinstance(processing_transform, dict): + return _parse_config(**processing_transform) + else: + raise ValueError("Invalid model_handler specification.") + + def underlying_handler(self): + return self._handler + + def preprocess_fn(self, row): + raise ValueError( + 'Handler does not implement a default preprocess ' + 'method. Please define a preprocessing method using the ' + '\'preprocess\' tag.') + + def create_preprocess_fn(self): + return lambda row: (row, self._preprocess(row)) + + @staticmethod + def postprocess_fn(x): + return x + + def create_postprocess_fn(self): + return lambda result: (result[0], self._postprocess(result[1])) + + @staticmethod + def validate(model_handler_spec): + raise NotImplementedError(type(ModelHandlerProvider)) + + @classmethod + def register_handler_type(cls, type_name): + def apply(constructor): + cls.handler_types[type_name] = constructor + return constructor + + return apply + + @classmethod + def create_handler(cls, model_handler_spec) -> "ModelHandlerProvider": + typ = model_handler_spec['type'] + config = model_handler_spec['config'] + try: + result = cls.handler_types[typ](**config) + if not hasattr(result, 'to_json'): + result.to_json = lambda: model_handler_spec + return result + except Exception as exn: + raise ValueError( + f'Unable to instantiate model handler of type {typ}. 
{exn}') + + +@ModelHandlerProvider.register_handler_type('VertexAIModelHandlerJSON') +class VertexAIModelHandlerJSONProvider(ModelHandlerProvider): + def __init__( + self, + endpoint_id: str, + endpoint_project: str, + endpoint_region: str, + experiment: Optional[str] = None, + network: Optional[str] = None, + private: bool = False, + min_batch_size: Optional[int] = None, + max_batch_size: Optional[int] = None, + max_batch_duration_secs: Optional[int] = None, + env_vars=None, + preprocess: Callable = None, + postprocess: Callable = None): + """ModelHandler provider for Vertex AI. + + For example: :: + + - type: RunInference + config: + inference_tag: 'my_inference' + model_handler: + type: VertexAIModelHandlerJSON + config: + endpoint_id: 9876543210 + endpoint_project: 1234567890 + endpoint_region: us-east1 + preprocess: + callable: 'lambda x: {"prompt": x.prompt, "max_tokens": 50}' + + Args: + endpoint_id: the numerical ID of the Vertex AI endpoint to query. + endpoint_project: the GCP project name where the endpoint is deployed. + endpoint_region: the GCP location where the endpoint is deployed. + experiment: optional. experiment label to apply to the + queries. See + https://cloud.google.com/vertex-ai/docs/experiments/intro-vertex-ai-experiments + for more information. + network: optional. the full name of the Compute Engine + network the endpoint is deployed on; used for private + endpoints. The network or subnetwork Dataflow pipeline + option must be set and match this network for pipeline + execution. + Ex: "projects/12345/global/networks/myVPC" + private: optional. if the deployed Vertex AI endpoint is + private, set to true. Requires a network to be provided + as well. + min_batch_size: optional. the minimum batch size to use when batching + inputs. + max_batch_size: optional. the maximum batch size to use when batching + inputs. + max_batch_duration_secs: optional. the maximum amount of time to buffer + a batch before emitting; used in streaming contexts. + env_vars: Environment variables. + preprocess: optional. a preprocessing function, given as a callable or + as a path and name, used to convert each input row into the request + format the model expects. + postprocess: optional. a postprocessing function applied to each + PredictionResult returned by the endpoint. + """ + + try: + from apache_beam.ml.inference.vertex_ai_inference import VertexAIModelHandlerJSON + except ImportError: + raise ValueError( + 'Unable to import VertexAIModelHandlerJSON. 
Please ' + 'install gcp dependencies: `pip install apache_beam[gcp]`') + + _handler = VertexAIModelHandlerJSON( + endpoint_id=str(endpoint_id), + project=endpoint_project, + location=endpoint_region, + experiment=experiment, + network=network, + private=private, + min_batch_size=min_batch_size, + max_batch_size=max_batch_size, + max_batch_duration_secs=max_batch_duration_secs, + env_vars=env_vars or {}) + + super().__init__(_handler, preprocess, postprocess) + + @staticmethod + def validate(model_handler_spec): + for required in ('endpoint_id', 'endpoint_project', 'endpoint_region'): + if required not in model_handler_spec: + raise ValueError( + f'Missing {required} in model handler ' + f'at line {SafeLineLoader.get_line(model_handler_spec)}') + + def get_output_schema(self): + return RowTypeConstraint.from_fields([('example', Any), ('inference', Any), + ('model_id', Optional[str])]) + + def create_postprocess_fn(self): + return lambda x: (x[0], beam.Row(**self._postprocess(x[1])._asdict())) + + +@beam.ptransform.ptransform_fn +def run_inference( + pcoll, + model_handler: Dict[str, Any], + inference_tag: Optional[str] = 'inference', + inference_args: Optional[Dict[str, Any]] = None) -> beam.PCollection[beam.Row]: # pylint: disable=line-too-long + """ + A transform that takes input rows containing examples (or features) to be + passed to an ML model, then appends the inferences (or predictions) for + those examples to the input row. + + A ModelHandler must be passed to the `model_handler` parameter. The + ModelHandler is responsible for configuring how the ML model will be loaded + and how input data will be passed to it. Every ModelHandler has a config tag, + similar to how a transform is defined, under which its parameters are + specified. + + For example: :: + + - type: RunInference + config: + model_handler: + type: ModelHandler + config: + param_1: arg1 + param_2: arg2 + ... + + By default, the RunInference transform will return the + input row with a single field appended named by the `inference_tag` parameter + ("inference" by default) that contains the inference directly returned by the + underlying ModelHandler, after any optional postprocessing. + + For example, if the input had the following: :: + + Row(question="What is a car?") + + The output row would look like: :: + + Row(question="What is a car?", inference=...) + + where the `inference` tag can be overridden with the `inference_tag` + parameter. + + However, if one specified the following transform config: :: + + - type: RunInference + config: + inference_tag: my_inference + model_handler: ... + + The output row would look like: :: + + Row(question="What is a car?", my_inference=...) + + See the more complete documentation on the underlying + [RunInference](https://beam.apache.org/documentation/ml/inference-overview/) + transform. + + ### Preprocessing input data + + In most cases, the model will be expecting data in a particular format, + whether it be a Python Dict, a PyTorch tensor, etc. However, the outputs of + all built-in Beam YAML transforms are Beam Rows. To allow for transforming + the Beam Row into a data format the model recognizes, each ModelHandler is + equipped with a `preprocess` parameter for performing necessary data + preprocessing. It is possible for a ModelHandler to define a default + preprocessing function, but in most cases, one will need to be specified by + the caller. 
+ + For example, using `callable`: :: + + pipeline: + type: chain + + transforms: + - type: Create + config: + elements: + - question: "What is a car?" + - question: "Where is the Eiffel Tower located?" + + - type: RunInference + config: + model_handler: + type: ModelHandler + config: + param_1: arg1 + param_2: arg2 + preprocess: + callable: 'lambda row: {"prompt": row.question}' + ... + + In the above example, the Create transform generates a collection of two + elements, each with a single field, "question". The model, however, expects + a Python Dict with a single key, "prompt". In this case, we can specify a + simple lambda function (alternatively, a full function could be defined) to + map the data. + + ### Postprocessing predictions + + It is also possible to define a postprocessing function to transform the + data output by the ModelHandler. See the documentation for the ModelHandler + you intend to use (listed below under the `model_handler` parameter doc). + + In many cases, before postprocessing, the object + will be a + [PredictionResult](https://beam.apache.org/releases/pydoc/BEAM_VERSION/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.PredictionResult). # pylint: disable=line-too-long + This type behaves very similarly to a Beam Row and its fields can be accessed + using dot notation. However, make sure to check the docs for your ModelHandler + to see which fields its PredictionResult contains or if it returns a + different object altogether. + + For example: :: + + - type: RunInference + config: + model_handler: + type: ModelHandler + config: + param_1: arg1 + param_2: arg2 + postprocess: + callable: | + def fn(x: PredictionResult): + return beam.Row(x.example, x.inference, x.model_id) + ... + + The above example converts the original output data type (in this case, a + PredictionResult) to a Beam Row, which allows for easier mapping in a later + transform. + + ### File-based pre/postprocessing functions + + For both preprocessing and postprocessing, it is also possible to specify a + Python UDF (user-defined function) file that contains the function. This is + done by specifying the `path` to the file (a local file or GCS path) and + the `name` of the function in the file. + + For example: :: + + - type: RunInference + config: + model_handler: + type: ModelHandler + config: + param_1: arg1 + param_2: arg2 + preprocess: + path: gs://my-bucket/path/to/preprocess.py + name: my_preprocess_fn + postprocess: + path: gs://my-bucket/path/to/postprocess.py + name: my_postprocess_fn + ... + + Args: + model_handler: Specifies the parameters for the desired model handler in + a YAML/JSON format. To see the full set of model_handler config + parameters, see their corresponding doc pages: + + - [VertexAIModelHandlerJSON](https://beam.apache.org/releases/pydoc/current/apache_beam.yaml.yaml_ml.VertexAIModelHandlerJSONProvider) # pylint: disable=line-too-long + inference_tag: The tag to use for the returned inference. Default is + 'inference'. + inference_args: Extra arguments for models whose inference call requires + extra parameters. Make sure to check the underlying ModelHandler docs to + see which args are allowed. 
+ + """ + + options.YamlOptions.check_enabled(pcoll.pipeline, 'ML') + + model_handler_provider = ModelHandlerProvider.create_handler(model_handler) + schema = RowTypeConstraint.from_fields( + list( + RowTypeConstraint.from_user_type( + pcoll.element_type.user_type)._fields) + + [(inference_tag, model_handler_provider.get_output_schema())]) + + return ( + pcoll | RunInference( + model_handler=KeyedModelHandler( + model_handler_provider.underlying_handler()).with_preprocess_fn( + model_handler_provider.create_preprocess_fn()). + with_postprocess_fn(model_handler_provider.create_postprocess_fn()), + inference_args=inference_args) + | beam.Map( + lambda row: beam.Row(**{ + inference_tag: row[1], **row[0]._asdict() + })).with_output_types(schema)) + + def _config_to_obj(spec): if 'type' not in spec: - raise ValueError(r"Missing type in ML transform spec {spec}") + raise ValueError(f"Missing type in ML transform spec {spec}") if 'config' not in spec: - raise ValueError(r"Missing config in ML transform spec {spec}") + raise ValueError(f"Missing config in ML transform spec {spec}") constructor = _transform_constructors.get(spec['type']) if constructor is None: raise ValueError("Unknown ML transform type: %r" % spec['type']) diff --git a/sdks/python/apache_beam/yaml/yaml_provider_unit_test.py b/sdks/python/apache_beam/yaml/yaml_provider_unit_test.py index 175f9388a0c6..b340ff260f77 100644 --- a/sdks/python/apache_beam/yaml/yaml_provider_unit_test.py +++ b/sdks/python/apache_beam/yaml/yaml_provider_unit_test.py @@ -27,8 +27,8 @@ from apache_beam.testing.util import equal_to from apache_beam.yaml import yaml_provider from apache_beam.yaml.yaml_provider import YamlProviders -from apache_beam.yaml.yaml_transform import SafeLineLoader from apache_beam.yaml.yaml_transform import YamlTransform +from apache_beam.yaml.yaml_utils import SafeLineLoader class WindowIntoTest(unittest.TestCase): diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py b/sdks/python/apache_beam/yaml/yaml_transform.py index 7cb96a7efb32..033053b7d429 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform.py +++ b/sdks/python/apache_beam/yaml/yaml_transform.py @@ -23,7 +23,7 @@ import os import pprint import re -import uuid +from json import JSONDecodeError from typing import Any from typing import Iterable from typing import List @@ -32,7 +32,6 @@ import jinja2 import yaml -from yaml.loader import SafeLoader import apache_beam as beam from apache_beam.io.filesystems import FileSystems @@ -45,6 +44,10 @@ __all__ = ["YamlTransform"] +from apache_beam.yaml.yaml_ml import normalize_ml + +from apache_beam.yaml.yaml_utils import SafeLineLoader + _LOGGER = logging.getLogger(__name__) yaml_provider.fix_pycallable() @@ -130,59 +133,6 @@ def empty_if_explicitly_empty(io): return io -class SafeLineLoader(SafeLoader): - """A yaml loader that attaches line information to mappings and strings.""" - class TaggedString(str): - """A string class to which we can attach metadata. - - This is primarily used to trace a string's origin back to its place in a - yaml file. - """ - def __reduce__(self): - # Pickle as an ordinary string. 
- return str, (str(self), ) - - def construct_scalar(self, node): - value = super().construct_scalar(node) - if isinstance(value, str): - value = SafeLineLoader.TaggedString(value) - value._line_ = node.start_mark.line + 1 - return value - - def construct_mapping(self, node, deep=False): - mapping = super().construct_mapping(node, deep=deep) - mapping['__line__'] = node.start_mark.line + 1 - mapping['__uuid__'] = self.create_uuid() - return mapping - - @classmethod - def create_uuid(cls): - return str(uuid.uuid4()) - - @classmethod - def strip_metadata(cls, spec, tagged_str=True): - if isinstance(spec, Mapping): - return { - cls.strip_metadata(key, tagged_str): - cls.strip_metadata(value, tagged_str) - for (key, value) in spec.items() - if key not in ('__line__', '__uuid__') - } - elif isinstance(spec, Iterable) and not isinstance(spec, (str, bytes)): - return [cls.strip_metadata(value, tagged_str) for value in spec] - elif isinstance(spec, SafeLineLoader.TaggedString) and tagged_str: - return str(spec) - else: - return spec - - @staticmethod - def get_line(obj): - if isinstance(obj, dict): - return obj.get('__line__', 'unknown') - else: - return getattr(obj, '_line_', 'unknown') - - class LightweightScope(object): def __init__(self, transforms): self._transforms = transforms @@ -927,12 +877,14 @@ def apply(phase, spec): spec, transforms=[apply(phase, t) for t in spec['transforms']]) return spec + known_transform_keys = None if known_transforms: - known_transforms = set(known_transforms).union(['chain', 'composite']) + known_transform_keys = set(known_transforms.keys()).union( + ['chain', 'composite']) def ensure_transforms_have_providers(spec): - if known_transforms: - if spec['type'] not in known_transforms: + if known_transform_keys: + if spec['type'] not in known_transform_keys: raise ValueError( 'Unknown type or missing provider ' f'for type {spec["type"]} for {identify_object(spec)}') @@ -946,7 +898,7 @@ def preprocess_languages(spec): 'Partition'): language = spec.get('config', {}).get('language', 'generic') new_type = spec['type'] + '-' + language - if known_transforms and new_type not in known_transforms: + if known_transform_keys and new_type not in known_transform_keys: if language == 'generic': raise ValueError(f'Missing language for {identify_object(spec)}') else: @@ -960,6 +912,7 @@ def preprocess_languages(spec): ensure_transforms_have_types, normalize_mapping, normalize_combine, + normalize_ml, preprocess_languages, ensure_transforms_have_providers, preprocess_source_sink, @@ -999,7 +952,10 @@ def expand_jinja( class YamlTransform(beam.PTransform): def __init__(self, spec, providers={}): # pylint: disable=dangerous-default-value if isinstance(spec, str): - spec = yaml.load(spec, Loader=SafeLineLoader) + try: + spec = json.loads(spec) + except JSONDecodeError: + spec = yaml.load(spec, Loader=SafeLineLoader) if isinstance(providers, dict): providers = { key: yaml_provider.as_provider_list(key, value) @@ -1008,7 +964,7 @@ def __init__(self, spec, providers={}): # pylint: disable=dangerous-default-val # TODO(BEAM-26941): Validate as a transform. self._providers = yaml_provider.merge_providers( providers, yaml_provider.standard_providers()) - self._spec = preprocess(spec, known_transforms=self._providers.keys()) + self._spec = preprocess(spec, known_transforms=self._providers) self._was_chain = spec['type'] == 'chain' def expand(self, pcolls): @@ -1068,6 +1024,7 @@ def expand_pipeline( # this could certainly be handy as a first pass when Beam is not available. 
if validate_schema and validate_schema != 'none': validate_against_schema(pipeline_spec, validate_schema) + # Calling expand directly to avoid outer layer of nesting. return YamlTransform( pipeline_as_composite(pipeline_spec['pipeline']), diff --git a/sdks/python/apache_beam/yaml/yaml_transform_scope_test.py b/sdks/python/apache_beam/yaml/yaml_transform_scope_test.py index 2a5a96aa42df..3989abeebd66 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform_scope_test.py +++ b/sdks/python/apache_beam/yaml/yaml_transform_scope_test.py @@ -25,8 +25,8 @@ from apache_beam.yaml import yaml_provider from apache_beam.yaml import yaml_transform from apache_beam.yaml.yaml_transform import LightweightScope -from apache_beam.yaml.yaml_transform import SafeLineLoader from apache_beam.yaml.yaml_transform import Scope +from apache_beam.yaml.yaml_utils import SafeLineLoader class ScopeTest(unittest.TestCase): diff --git a/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py b/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py index 084e03cdb197..5bc9de24bb38 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py +++ b/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py @@ -23,7 +23,6 @@ from apache_beam.yaml import YamlTransform from apache_beam.yaml import yaml_provider from apache_beam.yaml.yaml_provider import InlineProvider -from apache_beam.yaml.yaml_transform import SafeLineLoader from apache_beam.yaml.yaml_transform import Scope from apache_beam.yaml.yaml_transform import chain_as_composite from apache_beam.yaml.yaml_transform import ensure_errors_consumed @@ -39,57 +38,7 @@ from apache_beam.yaml.yaml_transform import preprocess_flattened_inputs from apache_beam.yaml.yaml_transform import preprocess_windowing from apache_beam.yaml.yaml_transform import push_windowing_to_roots - - -class SafeLineLoaderTest(unittest.TestCase): - def test_get_line(self): - pipeline_yaml = ''' - type: composite - input: - elements: input - transforms: - - type: PyMap - name: Square - input: elements - config: - fn: "lambda x: x * x" - - type: PyMap - name: Cube - input: elements - config: - fn: "lambda x: x * x * x" - output: - Flatten - ''' - spec = yaml.load(pipeline_yaml, Loader=SafeLineLoader) - self.assertEqual(SafeLineLoader.get_line(spec['type']), 2) - self.assertEqual(SafeLineLoader.get_line(spec['input']), 4) - self.assertEqual(SafeLineLoader.get_line(spec['transforms'][0]), 6) - self.assertEqual(SafeLineLoader.get_line(spec['transforms'][0]['type']), 6) - self.assertEqual(SafeLineLoader.get_line(spec['transforms'][0]['name']), 7) - self.assertEqual(SafeLineLoader.get_line(spec['transforms'][1]), 11) - self.assertEqual(SafeLineLoader.get_line(spec['output']), 17) - self.assertEqual(SafeLineLoader.get_line(spec['transforms']), "unknown") - - def test_strip_metadata(self): - spec_yaml = ''' - transforms: - - type: PyMap - name: Square - ''' - spec = yaml.load(spec_yaml, Loader=SafeLineLoader) - stripped = SafeLineLoader.strip_metadata(spec['transforms']) - - self.assertFalse(hasattr(stripped[0], '__line__')) - self.assertFalse(hasattr(stripped[0], '__uuid__')) - - def test_strip_metadata_nothing_to_strip(self): - spec_yaml = 'prop: 123' - spec = yaml.load(spec_yaml, Loader=SafeLineLoader) - stripped = SafeLineLoader.strip_metadata(spec['prop']) - - self.assertFalse(hasattr(stripped, '__line__')) - self.assertFalse(hasattr(stripped, '__uuid__')) +from apache_beam.yaml.yaml_utils import SafeLineLoader def new_pipeline(): diff --git a/sdks/python/apache_beam/yaml/yaml_utils.py 
b/sdks/python/apache_beam/yaml/yaml_utils.py new file mode 100644 index 000000000000..91cad4175fdb --- /dev/null +++ b/sdks/python/apache_beam/yaml/yaml_utils.py @@ -0,0 +1,58 @@ +import uuid +from typing import Iterable +from typing import Mapping + +from yaml import SafeLoader + + +class SafeLineLoader(SafeLoader): + """A yaml loader that attaches line information to mappings and strings.""" + class TaggedString(str): + """A string class to which we can attach metadata. + + This is primarily used to trace a string's origin back to its place in a + yaml file. + """ + def __reduce__(self): + # Pickle as an ordinary string. + return str, (str(self), ) + + def construct_scalar(self, node): + value = super().construct_scalar(node) + if isinstance(value, str): + value = SafeLineLoader.TaggedString(value) + value._line_ = node.start_mark.line + 1 + return value + + def construct_mapping(self, node, deep=False): + mapping = super().construct_mapping(node, deep=deep) + mapping['__line__'] = node.start_mark.line + 1 + mapping['__uuid__'] = self.create_uuid() + return mapping + + @classmethod + def create_uuid(cls): + return str(uuid.uuid4()) + + @classmethod + def strip_metadata(cls, spec, tagged_str=True): + if isinstance(spec, Mapping): + return { + cls.strip_metadata(key, tagged_str): + cls.strip_metadata(value, tagged_str) + for (key, value) in spec.items() + if key not in ('__line__', '__uuid__') + } + elif isinstance(spec, Iterable) and not isinstance(spec, (str, bytes)): + return [cls.strip_metadata(value, tagged_str) for value in spec] + elif isinstance(spec, SafeLineLoader.TaggedString) and tagged_str: + return str(spec) + else: + return spec + + @staticmethod + def get_line(obj): + if isinstance(obj, dict): + return obj.get('__line__', 'unknown') + else: + return getattr(obj, '_line_', 'unknown') diff --git a/sdks/python/apache_beam/yaml/yaml_utils_test.py b/sdks/python/apache_beam/yaml/yaml_utils_test.py new file mode 100644 index 000000000000..23105a4e82d4 --- /dev/null +++ b/sdks/python/apache_beam/yaml/yaml_utils_test.py @@ -0,0 +1,56 @@ +import unittest + +import yaml + +from apache_beam.yaml.yaml_utils import SafeLineLoader + + +class SafeLineLoaderTest(unittest.TestCase): + def test_get_line(self): + pipeline_yaml = ''' + type: composite + input: + elements: input + transforms: + - type: PyMap + name: Square + input: elements + config: + fn: "lambda x: x * x" + - type: PyMap + name: Cube + input: elements + config: + fn: "lambda x: x * x * x" + output: + Flatten + ''' + spec = yaml.load(pipeline_yaml, Loader=SafeLineLoader) + self.assertEqual(SafeLineLoader.get_line(spec['type']), 2) + self.assertEqual(SafeLineLoader.get_line(spec['input']), 4) + self.assertEqual(SafeLineLoader.get_line(spec['transforms'][0]), 6) + self.assertEqual(SafeLineLoader.get_line(spec['transforms'][0]['type']), 6) + self.assertEqual(SafeLineLoader.get_line(spec['transforms'][0]['name']), 7) + self.assertEqual(SafeLineLoader.get_line(spec['transforms'][1]), 11) + self.assertEqual(SafeLineLoader.get_line(spec['output']), 17) + self.assertEqual(SafeLineLoader.get_line(spec['transforms']), "unknown") + + def test_strip_metadata(self): + spec_yaml = ''' + transforms: + - type: PyMap + name: Square + ''' + spec = yaml.load(spec_yaml, Loader=SafeLineLoader) + stripped = SafeLineLoader.strip_metadata(spec['transforms']) + + self.assertFalse(hasattr(stripped[0], '__line__')) + self.assertFalse(hasattr(stripped[0], '__uuid__')) + + def test_strip_metadata_nothing_to_strip(self): + spec_yaml = 'prop: 123' + spec = 
yaml.load(spec_yaml, Loader=SafeLineLoader) + stripped = SafeLineLoader.strip_metadata(spec['prop']) + + self.assertFalse(hasattr(stripped, '__line__')) + self.assertFalse(hasattr(stripped, '__uuid__'))
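
To make the yaml_mapping.py change above concrete: for a row with fields `col1` and `col2` and the expression `col1 + unknown_col`, the generated wrapper now looks roughly like the sketch below (indentation reconstructed from the join; `original_fields` is assumed to render as a plain list). A stray field name now surfaces as a ValueError naming the valid fields instead of a bare NameError: ::

    def fn(__row__):
      try:
        # Only fields whose names appear in the expression are bound.
        col1 = __row__.col1
        return (col1 + unknown_col)
      except NameError as e:
        raise ValueError(f"{e}. Valid values include ['col1', 'col2']")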
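The ModelHandlerProvider registry in yaml_ml.py above is the extension point for new handler types. Here is a minimal sketch of registering a third-party handler, assuming this patch is applied; `EchoModelHandler`, `_EchoModelHandler`, and the `model_uri` parameter are hypothetical names invented for illustration: ::

    from apache_beam.ml.inference.base import ModelHandler
    from apache_beam.ml.inference.base import PredictionResult
    from apache_beam.yaml.yaml_ml import ModelHandlerProvider


    class _EchoModelHandler(ModelHandler):
      """A toy ModelHandler whose 'inference' simply echoes each example."""
      def load_model(self):
        return None  # Nothing to load for this toy handler.

      def run_inference(self, batch, model, inference_args=None):
        return [PredictionResult(example, example) for example in batch]


    @ModelHandlerProvider.register_handler_type('EchoModelHandler')
    class EchoModelHandlerProvider(ModelHandlerProvider):
      def __init__(self, model_uri, preprocess=None, postprocess=None):
        # create_handler() passes the YAML config dict as keyword arguments,
        # and the base class wires in the preprocess/postprocess specs.
        self._model_uri = model_uri  # Illustrative only; the handler ignores it.
        super().__init__(_EchoModelHandler(), preprocess, postprocess)

      @staticmethod
      def validate(model_handler_spec):
        # Invoked by normalize_ml() while the YAML spec is being preprocessed.
        if 'model_uri' not in model_handler_spec:
          raise ValueError('Missing model_uri in EchoModelHandler config')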
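With that hypothetical handler registered (imported in the same process), the new RunInference YAML transform could be driven end to end roughly as follows; the experimental-features flag spelling is an assumption inferred from the YamlOptions.check_enabled call above: ::

    import apache_beam as beam
    from apache_beam.options.pipeline_options import PipelineOptions
    from apache_beam.yaml.yaml_transform import YamlTransform

    ECHO_PIPELINE = '''
    type: chain
    transforms:
      - type: Create
        config:
          elements:
            - question: "What is a car?"
            - question: "Where is the Eiffel Tower located?"
      - type: RunInference
        config:
          inference_tag: my_inference
          model_handler:
            type: EchoModelHandler
            config:
              model_uri: memory://unused
              preprocess:
                callable: 'lambda row: row.question'
    '''

    # run_inference() calls YamlOptions.check_enabled(pipeline, 'ML'), so the
    # ML experimental feature must be enabled (flag spelling assumed here).
    with beam.Pipeline(options=PipelineOptions(
        ['--yaml_experimental_features=ML'])) as p:
      # YamlTransform now also accepts a JSON string (see the JSONDecodeError
      # fallback added to yaml_transform.py); plain YAML still works.
      _ = p | YamlTransform(ECHO_PIPELINE)

Each output row would then carry the original `question` field plus a `my_inference` field holding the PredictionResult produced by the echo handler.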