From c04712bb7312509e5c2049f93437ad1595f85a9b Mon Sep 17 00:00:00 2001 From: Jeffrey Kinard Date: Tue, 17 Dec 2024 15:58:28 -0500 Subject: [PATCH 1/5] [yaml] add RunInference support with VertexAI Signed-off-by: Jeffrey Kinard --- .../apache_beam/yaml/standard_providers.yaml | 1 + sdks/python/apache_beam/yaml/yaml_ml.py | 421 +++++++++++++++++- .../python/apache_beam/yaml/yaml_transform.py | 58 +-- .../yaml/yaml_transform_unit_test.py | 53 +-- sdks/python/apache_beam/yaml/yaml_utils.py | 58 +++ .../apache_beam/yaml/yaml_utils_test.py | 56 +++ 6 files changed, 537 insertions(+), 110 deletions(-) create mode 100644 sdks/python/apache_beam/yaml/yaml_utils.py create mode 100644 sdks/python/apache_beam/yaml/yaml_utils_test.py diff --git a/sdks/python/apache_beam/yaml/standard_providers.yaml b/sdks/python/apache_beam/yaml/standard_providers.yaml index 242faaa9a77b..31eb5e1c6daa 100644 --- a/sdks/python/apache_beam/yaml/standard_providers.yaml +++ b/sdks/python/apache_beam/yaml/standard_providers.yaml @@ -56,6 +56,7 @@ config: {} transforms: MLTransform: 'apache_beam.yaml.yaml_ml.ml_transform' + RunInference: 'apache_beam.yaml.yaml_ml.run_inference' - type: renaming transforms: diff --git a/sdks/python/apache_beam/yaml/yaml_ml.py b/sdks/python/apache_beam/yaml/yaml_ml.py index 33f2eeefd296..111618679f96 100644 --- a/sdks/python/apache_beam/yaml/yaml_ml.py +++ b/sdks/python/apache_beam/yaml/yaml_ml.py @@ -16,13 +16,20 @@ # """This module defines yaml wrappings for some ML transforms.""" - from typing import Any +from typing import Callable +from typing import Dict from typing import List from typing import Optional import apache_beam as beam +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.inference import RunInference +from apache_beam.ml.inference.base import KeyedModelHandler +from apache_beam.typehints.row_type import RowTypeConstraint +from apache_beam.utils import python_callable from apache_beam.yaml import options +from apache_beam.yaml.yaml_utils import SafeLineLoader try: from apache_beam.ml.transforms import tft @@ -33,11 +40,419 @@ tft = None # type: ignore +def normalize_ml(spec): + if spec['type'] == 'RunInference': + config = spec.get('config') + for required in ('model_handler', ): + if required not in config: + raise ValueError( + f'Missing {required} parameter in RunInference config ' + f'at line {SafeLineLoader.get_line(spec)}') + model_handler = config.get('model_handler') + if not isinstance(model_handler, dict): + raise ValueError( + 'Invalid model_handler specification at line ' + f'{SafeLineLoader.get_line(spec)}. 
Expected ' + f'dict but was {type(model_handler)}.') + for required in ('type', 'config'): + if required not in model_handler: + raise ValueError( + f'Missing {required} in model handler ' + f'at line {SafeLineLoader.get_line(model_handler)}') + typ = model_handler['type'] + extra_params = set(SafeLineLoader.strip_metadata(model_handler).keys()) - { + 'type', 'config' + } + if extra_params: + raise ValueError( + f'Unexpected parameters in model handler of type {typ} ' + f'at line {SafeLineLoader.get_line(spec)}: {extra_params}') + model_handler_provider = ModelHandlerProvider.handler_types.get(typ, None) + if model_handler_provider: + model_handler_provider.validate(model_handler['config']) + else: + raise NotImplementedError( + f'Unknown model handler type: {typ} ' + f'at line {SafeLineLoader.get_line(spec)}.') + + return spec + + +class ModelHandlerProvider: + handler_types: Dict[str, Callable[..., "ModelHandlerProvider"]] = {} + + def __init__( + self, handler, preprocess: Callable = None, postprocess: Callable = None): + self._handler = handler + self._preprocess = self.parse_processing_transform( + preprocess, 'preprocess') or self.preprocess_fn + self._postprocess = self.parse_processing_transform( + postprocess, 'postprocess') or self.postprocess_fn + + def get_output_schema(self): + return Any + + @staticmethod + def parse_processing_transform(processing_transform, typ): + def _parse_config(callable=None, path=None, name=None): + if callable and (path or name): + raise ValueError( + f"Cannot specify 'callable' with 'path' and 'name' for {typ} " + f"function.") + if path and name: + return python_callable.PythonCallableWithSource.load_from_script( + FileSystems.open(path).read().decode(), name) + elif callable: + return python_callable.PythonCallableWithSource(callable) + else: + raise ValueError( + f"Must specify one of 'callable' or 'path' and 'name' for {typ} " + f"function.") + + if processing_transform: + if isinstance(processing_transform, dict): + return _parse_config(**processing_transform) + else: + raise ValueError("Invalid model_handler specification.") + + def underlying_handler(self): + return self._handler + + def preprocess_fn(self, row): + raise ValueError( + 'Handler does not implement a default preprocess ' + 'method. Please define a preprocessing method using the ' + '\'preprocess\' tag.') + + def create_preprocess_fn(self): + return lambda row: (row, self._preprocess(row)) + + @staticmethod + def postprocess_fn(x): + return x + + def create_postprocess_fn(self): + return lambda result: (result[0], self._postprocess(result[1])) + + @staticmethod + def validate(model_handler_spec): + raise NotImplementedError(type(ModelHandlerProvider)) + + @classmethod + def register_handler_type(cls, type_name): + def apply(constructor): + cls.handler_types[type_name] = constructor + return constructor + + return apply + + @classmethod + def create_handler(cls, model_handler_spec) -> "ModelHandlerProvider": + typ = model_handler_spec['type'] + config = model_handler_spec['config'] + try: + result = cls.handler_types[typ](**config) + if not hasattr(result, 'to_json'): + result.to_json = lambda: model_handler_spec + return result + except Exception as exn: + raise ValueError( + f'Unable to instantiate model handler of type {typ}. 
{exn}') + + +@ModelHandlerProvider.register_handler_type('VertexAIModelHandlerJSON') +class VertexAIModelHandlerJSONProvider(ModelHandlerProvider): + def __init__( + self, + endpoint_id: str, + endpoint_project: str, + endpoint_region: str, + experiment: Optional[str] = None, + network: Optional[str] = None, + private: bool = False, + min_batch_size: Optional[int] = None, + max_batch_size: Optional[int] = None, + max_batch_duration_secs: Optional[int] = None, + env_vars=None, + preprocess: Callable = None, + postprocess: Callable = None): + """ModelHandler for Vertex AI. + + For example: :: + + - type: RunInference + config: + inference_tag: 'my_inference' + model_handler: + type: VertexAIModelHandlerJSON + config: + endpoint_id: 9876543210 + endpoint_project: 1234567890 + endpoint_region: us-east1 + preprocess: + callable: 'lambda x: {"prompt": x.prompt, "max_tokens": 50}' + + Args: + endpoint_id: the numerical ID of the Vertex AI endpoint to query. + endpoint_project: the GCP project name where the endpoint is deployed. + endpoint_region: the GCP location where the endpoint is deployed. + experiment: optional. experiment label to apply to the + queries. See + https://cloud.google.com/vertex-ai/docs/experiments/intro-vertex-ai-experiments + for more information. + network: optional. the full name of the Compute Engine + network the endpoint is deployed on; used for private + endpoints. The network or subnetwork Dataflow pipeline + option must be set and match this network for pipeline + execution. + Ex: "projects/12345/global/networks/myVPC" + private: optional. if the deployed Vertex AI endpoint is + private, set to true. Requires a network to be provided + as well. + min_batch_size: optional. the minimum batch size to use when batching + inputs. + max_batch_size: optional. the maximum batch size to use when batching + inputs. + max_batch_duration_secs: optional. the maximum amount of time to buffer + a batch before emitting; used in streaming contexts. + env_vars: Environment variables. + preprocess: + postprocess: + """ + + try: + from apache_beam.ml.inference.vertex_ai_inference import VertexAIModelHandlerJSON + except ImportError: + raise ValueError( + 'Unable to import VertexAIModelHandlerJSON. 
Please ' + 'install gcp dependencies: `pip install apache_beam[gcp]`') + + _handler = VertexAIModelHandlerJSON( + endpoint_id=str(endpoint_id), + project=endpoint_project, + location=endpoint_region, + experiment=experiment, + network=network, + private=private, + min_batch_size=min_batch_size, + max_batch_size=max_batch_size, + max_batch_duration_secs=max_batch_duration_secs, + env_vars=env_vars or {}) + + super().__init__(_handler, preprocess, postprocess) + + @staticmethod + def validate(model_handler_spec): + for required in ('endpoint_id', 'endpoint_project', 'endpoint_region'): + if required not in model_handler_spec: + raise ValueError( + f'Missing {required} in model handler ' + f'at line {SafeLineLoader.get_line(model_handler_spec)}') + + def get_output_schema(self): + return RowTypeConstraint.from_fields([('example', Any), ('inference', Any), + ('model_id', Optional[str])]) + + def create_postprocess_fn(self): + return lambda x: (x[0], beam.Row(**self._postprocess(x[1])._asdict())) + + +@beam.ptransform.ptransform_fn +def run_inference( + pcoll, + model_handler: Dict[str, Any], + inference_tag: Optional[str] = 'inference', + inference_args: Optional[Dict[str, Any]] = None) -> beam.PCollection[beam.Row]: # pylint: disable=line-too-long + """ + A transform that takes the input rows, containing examples (or features), for + use on an ML model. The transform then appends the inferences + (or predictions) for those examples to the input row. + + A ModelHandler must be passed to the `model_handler` parameter. The + ModelHandler is responsible for configuring how the ML model will be loaded + and how input data will be passed to it. Every ModelHandler has a config tag, + similar to how a transform is defined, where the parameters are defined. + + For example: :: + + - type: RunInference + config: + model_handler: + type: ModelHandler + config: + param_1: arg1 + param_2: arg2 + ... + + By default, the RunInference transform will return the + input row with a single field appended named by the `inference_tag` parameter + ("inference" by default) that contains the inference directly returned by the + underlying ModelHandler, after any optional postprocessing. + + For example, if the input had the following: :: + + Row(question="What is a car?") + + The output row would look like: :: + + Row(question="What is a car?", inference=...) + + where the `inference` tag can be overridden with the `inference_tag` + parameter. + + However, if one specified the following transform config: :: + + - type: RunInference + config: + inference_tag: my_inference + model_handler: ... + + The output row would look like: :: + + Row(question="What is a car?", my_inference=...) + + See more complete documentation on the underlying + [RunInference](https://beam.apache.org/documentation/ml/inference-overview/) + transform. + + ### Preprocessing input data + + In most cases, the model will be expecting data in a particular data format, + whether it be a Python Dict, PyTorch tensor, etc. However, the outputs of all + built-in Beam YAML transforms are Beam Rows. To allow for transforming + the Beam Row into a data format the model recognizes, each ModelHandler is + equipped with a `preprocessing` parameter for performing necessary data + preprocessing. It is possible for a ModelHandler to define a default + preprocessing function, but in most cases, one will need to be specified by + the caller. 
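+
+  Concretely, a preprocessing function is any Python callable that takes the
+  input Beam Row and returns the object the underlying model expects. A
+  minimal sketch (the field name "question" is illustrative, not part of any
+  fixed schema): ::
+
+      def my_preprocess_fn(row):
+        # Map the Beam Row's fields to the dict shape the model understands.
+        return {"prompt": row.question}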
+ + For example, using `callable`: :: + + pipeline: + type: chain + + transforms: + - type: Create + config: + elements: + - question: "What is a car?" + - question: "Where is the Eiffel Tower located?" + + - type: RunInference + config: + model_handler: + type: ModelHandler + config: + param_1: arg1 + param_2: arg2 + preprocess: + callable: 'lambda row: {"prompt": row.question}' + ... + + In the above example, the Create transform generates a collection of two + elements, each with a single field - "question". The model, however, expects + a Python Dict with a single key, "prompt". In this case, we can specify a + simple Lambda function (alternatively could define a full function), to map + the data. + + ### Postprocessing predictions + + It is also possible to define a postprocessing function to postprocess the + data output by the ModelHandler. See the documentation for the ModelHandler + you intend to use (list defined below under `model_handler` parameter doc). + + In many cases, before postprocessing, the object + will be a + [PredictionResult](https://beam.apache.org/releases/pydoc/BEAM_VERSION/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.PredictionResult). # pylint: disable=line-too-long + This type behaves very similarly to a Beam Row and fields can be accessed + using dot notation. However, make sure to check the docs for your ModelHandler + to see which fields its PredictionResult contains or if it returns a + different object altogether. + + For example: :: + + - type: RunInference + config: + model_handler: + type: ModelHandler + config: + param_1: arg1 + param_2: arg2 + postprocess: + callable: | + def fn(x: PredictionResult): + return beam.Row(x.example, x.inference, x.model_id) + ... + + The above example demonstrates converting the original output data type (in + this case it is PredictionResult), and converts to a Beam Row, which allows + for easier mapping in a later transform. + + ### File-based pre/postprocessing functions + + For both preprocessing and postprocessing, it is also possible to specify a + Python UDF (User-defined function) file that contains the function. This is + possible by specifying the `path` to the file (local file or GCS path) and + the `name` of the function in the file. + + For example: :: + + - type: RunInference + config: + model_handler: + type: ModelHandler + config: + param_1: arg1 + param_2: arg2 + preprocess: + path: gs://my-bucket/path/to/preprocess.py + name: my_preprocess_fn + postprocess: + path: gs://my-bucket/path/to/postprocess.py + name: my_postprocess_fn + ... + + Args: + model_handler: Specifies the parameters for the respective + enrichment_handler in a YAML/JSON format. To see the full set of + handler_config parameters, see their corresponding doc pages: + + - [VertexAIModelHandlerJSON](https://beam.apache.org/releases/pydoc/current/apache_beam.yaml.yaml_ml.VertexAIModelHandlerJSONProvider) # pylint: disable=line-too-long + inference_tag: The tag to use for the returned inference. Default is + 'inference'. + inference_args: Extra arguments for models whose inference call requires + extra parameters. Make sure to check the underlying ModelHandler docs to + see which args are allowed. 
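+
+      For example (the parameter names here are illustrative; the set of
+      valid keys depends on the deployed model): ::
+
+        - type: RunInference
+          config:
+            model_handler: ...
+            inference_args:
+              temperature: 0.7
+              max_output_tokens: 128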
+ + """ + + options.YamlOptions.check_enabled(pcoll.pipeline, 'ML') + + model_handler_provider = ModelHandlerProvider.create_handler(model_handler) + schema = RowTypeConstraint.from_fields( + list( + RowTypeConstraint.from_user_type( + pcoll.element_type.user_type)._fields) + + [(inference_tag, model_handler_provider.get_output_schema())]) + + return ( + pcoll | RunInference( + model_handler=KeyedModelHandler( + model_handler_provider.underlying_handler()).with_preprocess_fn( + model_handler_provider.create_preprocess_fn()). + with_postprocess_fn(model_handler_provider.create_postprocess_fn()), + inference_args=inference_args) + | beam.Map( + lambda row: beam.Row(**{ + inference_tag: row[1], **row[0]._asdict() + })).with_output_types(schema)) + + def _config_to_obj(spec): if 'type' not in spec: - raise ValueError(r"Missing type in ML transform spec {spec}") + raise ValueError(f"Missing type in ML transform spec {spec}") if 'config' not in spec: - raise ValueError(r"Missing config in ML transform spec {spec}") + raise ValueError(f"Missing config in ML transform spec {spec}") constructor = _transform_constructors.get(spec['type']) if constructor is None: raise ValueError("Unknown ML transform type: %r" % spec['type']) diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py b/sdks/python/apache_beam/yaml/yaml_transform.py index 7cb96a7efb32..c83cd53b5855 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform.py +++ b/sdks/python/apache_beam/yaml/yaml_transform.py @@ -23,7 +23,6 @@ import os import pprint import re -import uuid from typing import Any from typing import Iterable from typing import List @@ -32,7 +31,6 @@ import jinja2 import yaml -from yaml.loader import SafeLoader import apache_beam as beam from apache_beam.io.filesystems import FileSystems @@ -42,6 +40,8 @@ from apache_beam.yaml.yaml_combine import normalize_combine from apache_beam.yaml.yaml_mapping import normalize_mapping from apache_beam.yaml.yaml_mapping import validate_generic_expressions +from apache_beam.yaml.yaml_ml import normalize_ml +from apache_beam.yaml.yaml_utils import SafeLineLoader __all__ = ["YamlTransform"] @@ -130,59 +130,6 @@ def empty_if_explicitly_empty(io): return io -class SafeLineLoader(SafeLoader): - """A yaml loader that attaches line information to mappings and strings.""" - class TaggedString(str): - """A string class to which we can attach metadata. - - This is primarily used to trace a string's origin back to its place in a - yaml file. - """ - def __reduce__(self): - # Pickle as an ordinary string. 
- return str, (str(self), ) - - def construct_scalar(self, node): - value = super().construct_scalar(node) - if isinstance(value, str): - value = SafeLineLoader.TaggedString(value) - value._line_ = node.start_mark.line + 1 - return value - - def construct_mapping(self, node, deep=False): - mapping = super().construct_mapping(node, deep=deep) - mapping['__line__'] = node.start_mark.line + 1 - mapping['__uuid__'] = self.create_uuid() - return mapping - - @classmethod - def create_uuid(cls): - return str(uuid.uuid4()) - - @classmethod - def strip_metadata(cls, spec, tagged_str=True): - if isinstance(spec, Mapping): - return { - cls.strip_metadata(key, tagged_str): - cls.strip_metadata(value, tagged_str) - for (key, value) in spec.items() - if key not in ('__line__', '__uuid__') - } - elif isinstance(spec, Iterable) and not isinstance(spec, (str, bytes)): - return [cls.strip_metadata(value, tagged_str) for value in spec] - elif isinstance(spec, SafeLineLoader.TaggedString) and tagged_str: - return str(spec) - else: - return spec - - @staticmethod - def get_line(obj): - if isinstance(obj, dict): - return obj.get('__line__', 'unknown') - else: - return getattr(obj, '_line_', 'unknown') - - class LightweightScope(object): def __init__(self, transforms): self._transforms = transforms @@ -960,6 +907,7 @@ def preprocess_languages(spec): ensure_transforms_have_types, normalize_mapping, normalize_combine, + normalize_ml, preprocess_languages, ensure_transforms_have_providers, preprocess_source_sink, diff --git a/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py b/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py index 084e03cdb197..5bc9de24bb38 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py +++ b/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py @@ -23,7 +23,6 @@ from apache_beam.yaml import YamlTransform from apache_beam.yaml import yaml_provider from apache_beam.yaml.yaml_provider import InlineProvider -from apache_beam.yaml.yaml_transform import SafeLineLoader from apache_beam.yaml.yaml_transform import Scope from apache_beam.yaml.yaml_transform import chain_as_composite from apache_beam.yaml.yaml_transform import ensure_errors_consumed @@ -39,57 +38,7 @@ from apache_beam.yaml.yaml_transform import preprocess_flattened_inputs from apache_beam.yaml.yaml_transform import preprocess_windowing from apache_beam.yaml.yaml_transform import push_windowing_to_roots - - -class SafeLineLoaderTest(unittest.TestCase): - def test_get_line(self): - pipeline_yaml = ''' - type: composite - input: - elements: input - transforms: - - type: PyMap - name: Square - input: elements - config: - fn: "lambda x: x * x" - - type: PyMap - name: Cube - input: elements - config: - fn: "lambda x: x * x * x" - output: - Flatten - ''' - spec = yaml.load(pipeline_yaml, Loader=SafeLineLoader) - self.assertEqual(SafeLineLoader.get_line(spec['type']), 2) - self.assertEqual(SafeLineLoader.get_line(spec['input']), 4) - self.assertEqual(SafeLineLoader.get_line(spec['transforms'][0]), 6) - self.assertEqual(SafeLineLoader.get_line(spec['transforms'][0]['type']), 6) - self.assertEqual(SafeLineLoader.get_line(spec['transforms'][0]['name']), 7) - self.assertEqual(SafeLineLoader.get_line(spec['transforms'][1]), 11) - self.assertEqual(SafeLineLoader.get_line(spec['output']), 17) - self.assertEqual(SafeLineLoader.get_line(spec['transforms']), "unknown") - - def test_strip_metadata(self): - spec_yaml = ''' - transforms: - - type: PyMap - name: Square - ''' - spec = yaml.load(spec_yaml, 
Loader=SafeLineLoader) - stripped = SafeLineLoader.strip_metadata(spec['transforms']) - - self.assertFalse(hasattr(stripped[0], '__line__')) - self.assertFalse(hasattr(stripped[0], '__uuid__')) - - def test_strip_metadata_nothing_to_strip(self): - spec_yaml = 'prop: 123' - spec = yaml.load(spec_yaml, Loader=SafeLineLoader) - stripped = SafeLineLoader.strip_metadata(spec['prop']) - - self.assertFalse(hasattr(stripped, '__line__')) - self.assertFalse(hasattr(stripped, '__uuid__')) +from apache_beam.yaml.yaml_utils import SafeLineLoader def new_pipeline(): diff --git a/sdks/python/apache_beam/yaml/yaml_utils.py b/sdks/python/apache_beam/yaml/yaml_utils.py new file mode 100644 index 000000000000..91cad4175fdb --- /dev/null +++ b/sdks/python/apache_beam/yaml/yaml_utils.py @@ -0,0 +1,58 @@ +import uuid +from typing import Iterable +from typing import Mapping + +from yaml import SafeLoader + + +class SafeLineLoader(SafeLoader): + """A yaml loader that attaches line information to mappings and strings.""" + class TaggedString(str): + """A string class to which we can attach metadata. + + This is primarily used to trace a string's origin back to its place in a + yaml file. + """ + def __reduce__(self): + # Pickle as an ordinary string. + return str, (str(self), ) + + def construct_scalar(self, node): + value = super().construct_scalar(node) + if isinstance(value, str): + value = SafeLineLoader.TaggedString(value) + value._line_ = node.start_mark.line + 1 + return value + + def construct_mapping(self, node, deep=False): + mapping = super().construct_mapping(node, deep=deep) + mapping['__line__'] = node.start_mark.line + 1 + mapping['__uuid__'] = self.create_uuid() + return mapping + + @classmethod + def create_uuid(cls): + return str(uuid.uuid4()) + + @classmethod + def strip_metadata(cls, spec, tagged_str=True): + if isinstance(spec, Mapping): + return { + cls.strip_metadata(key, tagged_str): + cls.strip_metadata(value, tagged_str) + for (key, value) in spec.items() + if key not in ('__line__', '__uuid__') + } + elif isinstance(spec, Iterable) and not isinstance(spec, (str, bytes)): + return [cls.strip_metadata(value, tagged_str) for value in spec] + elif isinstance(spec, SafeLineLoader.TaggedString) and tagged_str: + return str(spec) + else: + return spec + + @staticmethod + def get_line(obj): + if isinstance(obj, dict): + return obj.get('__line__', 'unknown') + else: + return getattr(obj, '_line_', 'unknown') diff --git a/sdks/python/apache_beam/yaml/yaml_utils_test.py b/sdks/python/apache_beam/yaml/yaml_utils_test.py new file mode 100644 index 000000000000..23105a4e82d4 --- /dev/null +++ b/sdks/python/apache_beam/yaml/yaml_utils_test.py @@ -0,0 +1,56 @@ +import unittest + +import yaml + +from apache_beam.yaml.yaml_utils import SafeLineLoader + + +class SafeLineLoaderTest(unittest.TestCase): + def test_get_line(self): + pipeline_yaml = ''' + type: composite + input: + elements: input + transforms: + - type: PyMap + name: Square + input: elements + config: + fn: "lambda x: x * x" + - type: PyMap + name: Cube + input: elements + config: + fn: "lambda x: x * x * x" + output: + Flatten + ''' + spec = yaml.load(pipeline_yaml, Loader=SafeLineLoader) + self.assertEqual(SafeLineLoader.get_line(spec['type']), 2) + self.assertEqual(SafeLineLoader.get_line(spec['input']), 4) + self.assertEqual(SafeLineLoader.get_line(spec['transforms'][0]), 6) + self.assertEqual(SafeLineLoader.get_line(spec['transforms'][0]['type']), 6) + self.assertEqual(SafeLineLoader.get_line(spec['transforms'][0]['name']), 7) + 
self.assertEqual(SafeLineLoader.get_line(spec['transforms'][1]), 11) + self.assertEqual(SafeLineLoader.get_line(spec['output']), 17) + self.assertEqual(SafeLineLoader.get_line(spec['transforms']), "unknown") + + def test_strip_metadata(self): + spec_yaml = ''' + transforms: + - type: PyMap + name: Square + ''' + spec = yaml.load(spec_yaml, Loader=SafeLineLoader) + stripped = SafeLineLoader.strip_metadata(spec['transforms']) + + self.assertFalse(hasattr(stripped[0], '__line__')) + self.assertFalse(hasattr(stripped[0], '__uuid__')) + + def test_strip_metadata_nothing_to_strip(self): + spec_yaml = 'prop: 123' + spec = yaml.load(spec_yaml, Loader=SafeLineLoader) + stripped = SafeLineLoader.strip_metadata(spec['prop']) + + self.assertFalse(hasattr(stripped, '__line__')) + self.assertFalse(hasattr(stripped, '__uuid__')) From 17a34c43768f2ab2bde2140d68e72525357f2bbe Mon Sep 17 00:00:00 2001 From: Jeffrey Kinard Date: Wed, 18 Dec 2024 15:00:25 -0500 Subject: [PATCH 2/5] address comments and fix tests Signed-off-by: Jeffrey Kinard --- sdks/python/apache_beam/yaml/yaml_ml.py | 166 +++++++++--------- .../python/apache_beam/yaml/yaml_transform.py | 2 - sdks/python/apache_beam/yaml/yaml_utils.py | 17 ++ .../apache_beam/yaml/yaml_utils_test.py | 22 +++ 4 files changed, 118 insertions(+), 89 deletions(-) diff --git a/sdks/python/apache_beam/yaml/yaml_ml.py b/sdks/python/apache_beam/yaml/yaml_ml.py index 111618679f96..f60cc8f71adf 100644 --- a/sdks/python/apache_beam/yaml/yaml_ml.py +++ b/sdks/python/apache_beam/yaml/yaml_ml.py @@ -40,56 +40,21 @@ tft = None # type: ignore -def normalize_ml(spec): - if spec['type'] == 'RunInference': - config = spec.get('config') - for required in ('model_handler', ): - if required not in config: - raise ValueError( - f'Missing {required} parameter in RunInference config ' - f'at line {SafeLineLoader.get_line(spec)}') - model_handler = config.get('model_handler') - if not isinstance(model_handler, dict): - raise ValueError( - 'Invalid model_handler specification at line ' - f'{SafeLineLoader.get_line(spec)}. 
Expected ' - f'dict but was {type(model_handler)}.') - for required in ('type', 'config'): - if required not in model_handler: - raise ValueError( - f'Missing {required} in model handler ' - f'at line {SafeLineLoader.get_line(model_handler)}') - typ = model_handler['type'] - extra_params = set(SafeLineLoader.strip_metadata(model_handler).keys()) - { - 'type', 'config' - } - if extra_params: - raise ValueError( - f'Unexpected parameters in model handler of type {typ} ' - f'at line {SafeLineLoader.get_line(spec)}: {extra_params}') - model_handler_provider = ModelHandlerProvider.handler_types.get(typ, None) - if model_handler_provider: - model_handler_provider.validate(model_handler['config']) - else: - raise NotImplementedError( - f'Unknown model handler type: {typ} ' - f'at line {SafeLineLoader.get_line(spec)}.') - - return spec - - class ModelHandlerProvider: handler_types: Dict[str, Callable[..., "ModelHandlerProvider"]] = {} def __init__( - self, handler, preprocess: Callable = None, postprocess: Callable = None): + self, + handler, + preprocess: Optional[Dict[str, str]] = None, + postprocess: Optional[Dict[str, str]] = None): self._handler = handler - self._preprocess = self.parse_processing_transform( - preprocess, 'preprocess') or self.preprocess_fn - self._postprocess = self.parse_processing_transform( - postprocess, 'postprocess') or self.postprocess_fn + self._preprocess_fn = self.parse_processing_transform( + preprocess, 'preprocess') or self.default_preprocess_fn() + self._postprocess_fn = self.parse_processing_transform( + postprocess, 'postprocess') or self.default_postprocess_fn() - def get_output_schema(self): + def inference_output_type(self): return Any @staticmethod @@ -118,21 +83,22 @@ def _parse_config(callable=None, path=None, name=None): def underlying_handler(self): return self._handler - def preprocess_fn(self, row): + @staticmethod + def default_preprocess_fn(): raise ValueError( 'Handler does not implement a default preprocess ' 'method. Please define a preprocessing method using the ' '\'preprocess\' tag.') - def create_preprocess_fn(self): - return lambda row: (row, self._preprocess(row)) + def _preprocess_fn_internal(self): + return lambda row: (row, self._preprocess_fn(row)) @staticmethod - def postprocess_fn(x): - return x + def default_postprocess_fn(): + return lambda x: x - def create_postprocess_fn(self): - return lambda result: (result[0], self._postprocess(result[1])) + def _postprocess_fn_internal(self): + return lambda result: (result[0], self._postprocess_fn(result[1])) @staticmethod def validate(model_handler_spec): @@ -165,18 +131,19 @@ class VertexAIModelHandlerJSONProvider(ModelHandlerProvider): def __init__( self, endpoint_id: str, - endpoint_project: str, - endpoint_region: str, + project: str, + location: str, + preprocess: Dict[str, str], experiment: Optional[str] = None, network: Optional[str] = None, private: bool = False, min_batch_size: Optional[int] = None, max_batch_size: Optional[int] = None, max_batch_duration_secs: Optional[int] = None, - env_vars=None, - preprocess: Callable = None, - postprocess: Callable = None): - """ModelHandler for Vertex AI. + env_vars: Optional[Dict[str, Any]] = None, + postprocess: Optional[Dict[str, str]] = None): + """ + ModelHandler for Vertex AI. 
For example: :: @@ -187,37 +154,43 @@ def __init__( type: VertexAIModelHandlerJSON config: endpoint_id: 9876543210 - endpoint_project: 1234567890 - endpoint_region: us-east1 + project: my-project + location: us-east1 preprocess: callable: 'lambda x: {"prompt": x.prompt, "max_tokens": 50}' Args: endpoint_id: the numerical ID of the Vertex AI endpoint to query. - endpoint_project: the GCP project name where the endpoint is deployed. - endpoint_region: the GCP location where the endpoint is deployed. - experiment: optional. experiment label to apply to the + project: the GCP project name where the endpoint is deployed. + location: the GCP location where the endpoint is deployed. + experiment: Experiment label to apply to the queries. See https://cloud.google.com/vertex-ai/docs/experiments/intro-vertex-ai-experiments for more information. - network: optional. the full name of the Compute Engine + network: The full name of the Compute Engine network the endpoint is deployed on; used for private endpoints. The network or subnetwork Dataflow pipeline option must be set and match this network for pipeline execution. Ex: "projects/12345/global/networks/myVPC" - private: optional. if the deployed Vertex AI endpoint is + private: If the deployed Vertex AI endpoint is private, set to true. Requires a network to be provided as well. - min_batch_size: optional. the minimum batch size to use when batching + min_batch_size: The minimum batch size to use when batching inputs. - max_batch_size: optional. the maximum batch size to use when batching + max_batch_size: The maximum batch size to use when batching inputs. - max_batch_duration_secs: optional. the maximum amount of time to buffer + max_batch_duration_secs: The maximum amount of time to buffer a batch before emitting; used in streaming contexts. env_vars: Environment variables. - preprocess: - postprocess: + preprocess: A python callable, defined either inline, or using a file, + that is invoked on the input row before sending to the model to be + loaded by this ModelHandler. This parameter is required by the + `VertexAIModelHandlerJSON` ModelHandler. + postprocess: A python callable, defined either inline, or using a file, + that is invoked on the PredictionResult output by the ModelHandler + before parsing into the output Beam Row under the field name defined + by the inference_tag. """ try: @@ -229,8 +202,8 @@ def __init__( _handler = VertexAIModelHandlerJSON( endpoint_id=str(endpoint_id), - project=endpoint_project, - location=endpoint_region, + project=project, + location=location, experiment=experiment, network=network, private=private, @@ -243,18 +216,15 @@ def __init__( @staticmethod def validate(model_handler_spec): - for required in ('endpoint_id', 'endpoint_project', 'endpoint_region'): - if required not in model_handler_spec: - raise ValueError( - f'Missing {required} in model handler ' - f'at line {SafeLineLoader.get_line(model_handler_spec)}') + pass - def get_output_schema(self): + def inference_output_type(self): return RowTypeConstraint.from_fields([('example', Any), ('inference', Any), ('model_id', Optional[str])]) - def create_postprocess_fn(self): - return lambda x: (x[0], beam.Row(**self._postprocess(x[1])._asdict())) + @staticmethod + def default_postprocess_fn(): + return lambda x: beam.Row(**x._asdict()) @beam.ptransform.ptransform_fn @@ -349,11 +319,11 @@ def run_inference( callable: 'lambda row: {"prompt": row.question}' ... 
- In the above example, the Create transform generates a collection of two - elements, each with a single field - "question". The model, however, expects - a Python Dict with a single key, "prompt". In this case, we can specify a - simple Lambda function (alternatively could define a full function), to map - the data. + In the above example, the Create transform generates a collection of two Beam + Row elements, each with a single field - "question". The model, however, + expects a Python Dict with a single key, "prompt". In this case, we can + specify a simple Lambda function (alternatively could define a full function), + to map the data. ### Postprocessing predictions @@ -428,19 +398,41 @@ def fn(x: PredictionResult): options.YamlOptions.check_enabled(pcoll.pipeline, 'ML') + if not isinstance(model_handler, dict): + raise ValueError( + 'Invalid model_handler specification. Expected dict but was ' + f'{type(model_handler)}.') + expected_model_handler_params = {'type', 'config'} + given_model_handler_params = set( + SafeLineLoader.strip_metadata(model_handler).keys()) + extra_params = given_model_handler_params - expected_model_handler_params + if extra_params: + raise ValueError(f'Unexpected parameters in model_handler: {extra_params}') + missing_params = expected_model_handler_params - given_model_handler_params + if missing_params: + raise ValueError(f'Missing parameters in model_handler: {missing_params}') + typ = model_handler['type'] + model_handler_provider = ModelHandlerProvider.handler_types.get(typ, None) + if model_handler_provider and issubclass(model_handler_provider, + ModelHandlerProvider): + model_handler_provider.validate(model_handler['config']) + else: + raise NotImplementedError(f'Unknown model handler type: {typ}.') + model_handler_provider = ModelHandlerProvider.create_handler(model_handler) schema = RowTypeConstraint.from_fields( list( RowTypeConstraint.from_user_type( pcoll.element_type.user_type)._fields) + - [(inference_tag, model_handler_provider.get_output_schema())]) + [(inference_tag, model_handler_provider.inference_output_type())]) return ( pcoll | RunInference( model_handler=KeyedModelHandler( model_handler_provider.underlying_handler()).with_preprocess_fn( - model_handler_provider.create_preprocess_fn()). - with_postprocess_fn(model_handler_provider.create_postprocess_fn()), + model_handler_provider._preprocess_fn_internal()). 
+ with_postprocess_fn( + model_handler_provider._postprocess_fn_internal()), inference_args=inference_args) | beam.Map( lambda row: beam.Row(**{ diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py b/sdks/python/apache_beam/yaml/yaml_transform.py index c83cd53b5855..c6f5544c650a 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform.py +++ b/sdks/python/apache_beam/yaml/yaml_transform.py @@ -40,7 +40,6 @@ from apache_beam.yaml.yaml_combine import normalize_combine from apache_beam.yaml.yaml_mapping import normalize_mapping from apache_beam.yaml.yaml_mapping import validate_generic_expressions -from apache_beam.yaml.yaml_ml import normalize_ml from apache_beam.yaml.yaml_utils import SafeLineLoader __all__ = ["YamlTransform"] @@ -907,7 +906,6 @@ def preprocess_languages(spec): ensure_transforms_have_types, normalize_mapping, normalize_combine, - normalize_ml, preprocess_languages, ensure_transforms_have_providers, preprocess_source_sink, diff --git a/sdks/python/apache_beam/yaml/yaml_utils.py b/sdks/python/apache_beam/yaml/yaml_utils.py index 91cad4175fdb..63beb90f0711 100644 --- a/sdks/python/apache_beam/yaml/yaml_utils.py +++ b/sdks/python/apache_beam/yaml/yaml_utils.py @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + import uuid from typing import Iterable from typing import Mapping diff --git a/sdks/python/apache_beam/yaml/yaml_utils_test.py b/sdks/python/apache_beam/yaml/yaml_utils_test.py index 23105a4e82d4..70f6ba9b5198 100644 --- a/sdks/python/apache_beam/yaml/yaml_utils_test.py +++ b/sdks/python/apache_beam/yaml/yaml_utils_test.py @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + import unittest import yaml @@ -54,3 +71,8 @@ def test_strip_metadata_nothing_to_strip(self): self.assertFalse(hasattr(stripped, '__line__')) self.assertFalse(hasattr(stripped, '__uuid__')) + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + unittest.main() From a8adde304a9f2b295ce6649e9c1545f4e8b57c37 Mon Sep 17 00:00:00 2001 From: Jeffrey Kinard Date: Fri, 20 Dec 2024 15:08:45 -0500 Subject: [PATCH 3/5] add more docs Signed-off-by: Jeffrey Kinard --- sdks/python/apache_beam/yaml/yaml_ml.py | 54 +++++++++++++++++-------- 1 file changed, 38 insertions(+), 16 deletions(-) diff --git a/sdks/python/apache_beam/yaml/yaml_ml.py b/sdks/python/apache_beam/yaml/yaml_ml.py index f60cc8f71adf..92b33dcb7db6 100644 --- a/sdks/python/apache_beam/yaml/yaml_ml.py +++ b/sdks/python/apache_beam/yaml/yaml_ml.py @@ -86,9 +86,12 @@ def underlying_handler(self): @staticmethod def default_preprocess_fn(): raise ValueError( - 'Handler does not implement a default preprocess ' + 'Model Handler does not implement a default preprocess ' 'method. Please define a preprocessing method using the ' - '\'preprocess\' tag.') + '\'preprocess\' tag. This is required in most cases because ' + 'most models will have a different input shape, so the model ' + 'cannot generalize how the input Row should be transformed. For ' + 'an example preprocess method, see VertexAIModelHandlerJSONProvider') def _preprocess_fn_internal(self): return lambda row: (row, self._preprocess_fn(row)) @@ -134,17 +137,34 @@ def __init__( project: str, location: str, preprocess: Dict[str, str], + postprocess: Optional[Dict[str, str]] = None, experiment: Optional[str] = None, network: Optional[str] = None, private: bool = False, min_batch_size: Optional[int] = None, max_batch_size: Optional[int] = None, max_batch_duration_secs: Optional[int] = None, - env_vars: Optional[Dict[str, Any]] = None, - postprocess: Optional[Dict[str, str]] = None): + env_vars: Optional[Dict[str, Any]] = None): """ ModelHandler for Vertex AI. + This Model Handler can be used with RunInference to load a model hosted + on VertexAI. Every model that is hosted on VertexAI should have three + distinct, required, parameters - `endpoint_id`, `project` and `location`. + These parameters tell the Model Handler how to access the model's endpoint + so that input data can be sent using an API request, and inferences can be + received as a response. + + This Model Handler also required a `preprocess` function to be defined. + Preprocessing and Postprocessing are described in more detail in the + RunInference docs: + https://beam.apache.org/releases/yamldoc/current/#runinference + + Every model will have a unique input, but all requests should be + JSON-formatted. For example, most language models such as Llama and Gemma + expect a JSON with the key "prompt" (among other optional keys). In Python, + JSON can be expressed as a dictionary. + For example: :: - type: RunInference @@ -159,10 +179,24 @@ def __init__( preprocess: callable: 'lambda x: {"prompt": x.prompt, "max_tokens": 50}' + In the above example, which mimics a call to a Llama 3 model hosted on + VertexAI, the preprocess function (in this case a lambda) takes in a Beam + Row with a single field, "prompt", and maps it to a dict with the same + field. It also specifies an optional parameter, "max_tokens", that tells the + model the allowed token size (in this case input + output token size). + Args: endpoint_id: the numerical ID of the Vertex AI endpoint to query. 
project: the GCP project name where the endpoint is deployed. location: the GCP location where the endpoint is deployed. + preprocess: A python callable, defined either inline, or using a file, + that is invoked on the input row before sending to the model to be + loaded by this ModelHandler. This parameter is required by the + `VertexAIModelHandlerJSON` ModelHandler. + postprocess: A python callable, defined either inline, or using a file, + that is invoked on the PredictionResult output by the ModelHandler + before parsing into the output Beam Row under the field name defined + by the inference_tag. experiment: Experiment label to apply to the queries. See https://cloud.google.com/vertex-ai/docs/experiments/intro-vertex-ai-experiments @@ -183,14 +217,6 @@ def __init__( max_batch_duration_secs: The maximum amount of time to buffer a batch before emitting; used in streaming contexts. env_vars: Environment variables. - preprocess: A python callable, defined either inline, or using a file, - that is invoked on the input row before sending to the model to be - loaded by this ModelHandler. This parameter is required by the - `VertexAIModelHandlerJSON` ModelHandler. - postprocess: A python callable, defined either inline, or using a file, - that is invoked on the PredictionResult output by the ModelHandler - before parsing into the output Beam Row under the field name defined - by the inference_tag. """ try: @@ -222,10 +248,6 @@ def inference_output_type(self): return RowTypeConstraint.from_fields([('example', Any), ('inference', Any), ('model_id', Optional[str])]) - @staticmethod - def default_postprocess_fn(): - return lambda x: beam.Row(**x._asdict()) - @beam.ptransform.ptransform_fn def run_inference( From cda2ee1992d86ca660d37f2238aadb8b043be165 Mon Sep 17 00:00:00 2001 From: Jeffrey Kinard Date: Fri, 20 Dec 2024 16:13:29 -0500 Subject: [PATCH 4/5] fix failing tests Signed-off-by: Jeffrey Kinard --- sdks/python/apache_beam/yaml/yaml_ml.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/sdks/python/apache_beam/yaml/yaml_ml.py b/sdks/python/apache_beam/yaml/yaml_ml.py index 92b33dcb7db6..061ca6fc6028 100644 --- a/sdks/python/apache_beam/yaml/yaml_ml.py +++ b/sdks/python/apache_beam/yaml/yaml_ml.py @@ -17,7 +17,6 @@ """This module defines yaml wrappings for some ML transforms.""" from typing import Any -from typing import Callable from typing import Dict from typing import List from typing import Optional @@ -41,7 +40,7 @@ class ModelHandlerProvider: - handler_types: Dict[str, Callable[..., "ModelHandlerProvider"]] = {} + handler_types: Dict[str, "ModelHandlerProvider"] = {} def __init__( self, @@ -158,7 +157,7 @@ def __init__( This Model Handler also required a `preprocess` function to be defined. Preprocessing and Postprocessing are described in more detail in the RunInference docs: - https://beam.apache.org/releases/yamldoc/current/#runinference + https://beam.apache.org/releases/yamldoc/current/#runinference Every model will have a unique input, but all requests should be JSON-formatted. For example, most language models such as Llama and Gemma @@ -414,7 +413,7 @@ def fn(x: PredictionResult): 'inference'. inference_args: Extra arguments for models whose inference call requires extra parameters. Make sure to check the underlying ModelHandler docs to - see which args are allowed. + see which args are allowed. 
""" From 063890ce9454f3fe8f344b072cc3335e48d624f0 Mon Sep 17 00:00:00 2001 From: Jeffrey Kinard Date: Sat, 21 Dec 2024 08:36:15 -0600 Subject: [PATCH 5/5] fix errors Signed-off-by: Jeffrey Kinard --- sdks/python/apache_beam/yaml/yaml_ml.py | 10 +++++----- sdks/python/apache_beam/yaml/yaml_utils_test.py | 1 + 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/sdks/python/apache_beam/yaml/yaml_ml.py b/sdks/python/apache_beam/yaml/yaml_ml.py index 061ca6fc6028..fb255c5b0b02 100644 --- a/sdks/python/apache_beam/yaml/yaml_ml.py +++ b/sdks/python/apache_beam/yaml/yaml_ml.py @@ -17,6 +17,7 @@ """This module defines yaml wrappings for some ML transforms.""" from typing import Any +from typing import Callable from typing import Dict from typing import List from typing import Optional @@ -40,7 +41,7 @@ class ModelHandlerProvider: - handler_types: Dict[str, "ModelHandlerProvider"] = {} + handler_types: Dict[str, Callable[..., "ModelHandlerProvider"]] = {} def __init__( self, @@ -435,16 +436,15 @@ def fn(x: PredictionResult): typ = model_handler['type'] model_handler_provider = ModelHandlerProvider.handler_types.get(typ, None) if model_handler_provider and issubclass(model_handler_provider, - ModelHandlerProvider): + type(ModelHandlerProvider)): model_handler_provider.validate(model_handler['config']) else: raise NotImplementedError(f'Unknown model handler type: {typ}.') model_handler_provider = ModelHandlerProvider.create_handler(model_handler) + user_type = RowTypeConstraint.from_user_type(pcoll.element_type.user_type) schema = RowTypeConstraint.from_fields( - list( - RowTypeConstraint.from_user_type( - pcoll.element_type.user_type)._fields) + + list(user_type._fields if user_type else []) + [(inference_tag, model_handler_provider.inference_output_type())]) return ( diff --git a/sdks/python/apache_beam/yaml/yaml_utils_test.py b/sdks/python/apache_beam/yaml/yaml_utils_test.py index 70f6ba9b5198..4fd2c793e57e 100644 --- a/sdks/python/apache_beam/yaml/yaml_utils_test.py +++ b/sdks/python/apache_beam/yaml/yaml_utils_test.py @@ -15,6 +15,7 @@ # limitations under the License. # +import logging import unittest import yaml