diff --git a/sdks/python/apache_beam/yaml/standard_providers.yaml b/sdks/python/apache_beam/yaml/standard_providers.yaml index 242faaa9a77b..31eb5e1c6daa 100644 --- a/sdks/python/apache_beam/yaml/standard_providers.yaml +++ b/sdks/python/apache_beam/yaml/standard_providers.yaml @@ -56,6 +56,7 @@ config: {} transforms: MLTransform: 'apache_beam.yaml.yaml_ml.ml_transform' + RunInference: 'apache_beam.yaml.yaml_ml.run_inference' - type: renaming transforms: diff --git a/sdks/python/apache_beam/yaml/yaml_ml.py b/sdks/python/apache_beam/yaml/yaml_ml.py index 33f2eeefd296..111618679f96 100644 --- a/sdks/python/apache_beam/yaml/yaml_ml.py +++ b/sdks/python/apache_beam/yaml/yaml_ml.py @@ -16,13 +16,20 @@ # """This module defines yaml wrappings for some ML transforms.""" - from typing import Any +from typing import Callable +from typing import Dict from typing import List from typing import Optional import apache_beam as beam +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.inference import RunInference +from apache_beam.ml.inference.base import KeyedModelHandler +from apache_beam.typehints.row_type import RowTypeConstraint +from apache_beam.utils import python_callable from apache_beam.yaml import options +from apache_beam.yaml.yaml_utils import SafeLineLoader try: from apache_beam.ml.transforms import tft @@ -33,11 +40,419 @@ tft = None # type: ignore +def normalize_ml(spec): + if spec['type'] == 'RunInference': + config = spec.get('config') + for required in ('model_handler', ): + if required not in config: + raise ValueError( + f'Missing {required} parameter in RunInference config ' + f'at line {SafeLineLoader.get_line(spec)}') + model_handler = config.get('model_handler') + if not isinstance(model_handler, dict): + raise ValueError( + 'Invalid model_handler specification at line ' + f'{SafeLineLoader.get_line(spec)}. Expected ' + f'dict but was {type(model_handler)}.') + for required in ('type', 'config'): + if required not in model_handler: + raise ValueError( + f'Missing {required} in model handler ' + f'at line {SafeLineLoader.get_line(model_handler)}') + typ = model_handler['type'] + extra_params = set(SafeLineLoader.strip_metadata(model_handler).keys()) - { + 'type', 'config' + } + if extra_params: + raise ValueError( + f'Unexpected parameters in model handler of type {typ} ' + f'at line {SafeLineLoader.get_line(spec)}: {extra_params}') + model_handler_provider = ModelHandlerProvider.handler_types.get(typ, None) + if model_handler_provider: + model_handler_provider.validate(model_handler['config']) + else: + raise NotImplementedError( + f'Unknown model handler type: {typ} ' + f'at line {SafeLineLoader.get_line(spec)}.') + + return spec + + +class ModelHandlerProvider: + handler_types: Dict[str, Callable[..., "ModelHandlerProvider"]] = {} + + def __init__( + self, handler, preprocess: Callable = None, postprocess: Callable = None): + self._handler = handler + self._preprocess = self.parse_processing_transform( + preprocess, 'preprocess') or self.preprocess_fn + self._postprocess = self.parse_processing_transform( + postprocess, 'postprocess') or self.postprocess_fn + + def get_output_schema(self): + return Any + + @staticmethod + def parse_processing_transform(processing_transform, typ): + def _parse_config(callable=None, path=None, name=None): + if callable and (path or name): + raise ValueError( + f"Cannot specify 'callable' with 'path' and 'name' for {typ} " + f"function.") + if path and name: + return python_callable.PythonCallableWithSource.load_from_script( + FileSystems.open(path).read().decode(), name) + elif callable: + return python_callable.PythonCallableWithSource(callable) + else: + raise ValueError( + f"Must specify one of 'callable' or 'path' and 'name' for {typ} " + f"function.") + + if processing_transform: + if isinstance(processing_transform, dict): + return _parse_config(**processing_transform) + else: + raise ValueError("Invalid model_handler specification.") + + def underlying_handler(self): + return self._handler + + def preprocess_fn(self, row): + raise ValueError( + 'Handler does not implement a default preprocess ' + 'method. Please define a preprocessing method using the ' + '\'preprocess\' tag.') + + def create_preprocess_fn(self): + return lambda row: (row, self._preprocess(row)) + + @staticmethod + def postprocess_fn(x): + return x + + def create_postprocess_fn(self): + return lambda result: (result[0], self._postprocess(result[1])) + + @staticmethod + def validate(model_handler_spec): + raise NotImplementedError(type(ModelHandlerProvider)) + + @classmethod + def register_handler_type(cls, type_name): + def apply(constructor): + cls.handler_types[type_name] = constructor + return constructor + + return apply + + @classmethod + def create_handler(cls, model_handler_spec) -> "ModelHandlerProvider": + typ = model_handler_spec['type'] + config = model_handler_spec['config'] + try: + result = cls.handler_types[typ](**config) + if not hasattr(result, 'to_json'): + result.to_json = lambda: model_handler_spec + return result + except Exception as exn: + raise ValueError( + f'Unable to instantiate model handler of type {typ}. {exn}') + + +@ModelHandlerProvider.register_handler_type('VertexAIModelHandlerJSON') +class VertexAIModelHandlerJSONProvider(ModelHandlerProvider): + def __init__( + self, + endpoint_id: str, + endpoint_project: str, + endpoint_region: str, + experiment: Optional[str] = None, + network: Optional[str] = None, + private: bool = False, + min_batch_size: Optional[int] = None, + max_batch_size: Optional[int] = None, + max_batch_duration_secs: Optional[int] = None, + env_vars=None, + preprocess: Callable = None, + postprocess: Callable = None): + """ModelHandler for Vertex AI. + + For example: :: + + - type: RunInference + config: + inference_tag: 'my_inference' + model_handler: + type: VertexAIModelHandlerJSON + config: + endpoint_id: 9876543210 + endpoint_project: 1234567890 + endpoint_region: us-east1 + preprocess: + callable: 'lambda x: {"prompt": x.prompt, "max_tokens": 50}' + + Args: + endpoint_id: the numerical ID of the Vertex AI endpoint to query. + endpoint_project: the GCP project name where the endpoint is deployed. + endpoint_region: the GCP location where the endpoint is deployed. + experiment: optional. experiment label to apply to the + queries. See + https://cloud.google.com/vertex-ai/docs/experiments/intro-vertex-ai-experiments + for more information. + network: optional. the full name of the Compute Engine + network the endpoint is deployed on; used for private + endpoints. The network or subnetwork Dataflow pipeline + option must be set and match this network for pipeline + execution. + Ex: "projects/12345/global/networks/myVPC" + private: optional. if the deployed Vertex AI endpoint is + private, set to true. Requires a network to be provided + as well. + min_batch_size: optional. the minimum batch size to use when batching + inputs. + max_batch_size: optional. the maximum batch size to use when batching + inputs. + max_batch_duration_secs: optional. the maximum amount of time to buffer + a batch before emitting; used in streaming contexts. + env_vars: Environment variables. + preprocess: + postprocess: + """ + + try: + from apache_beam.ml.inference.vertex_ai_inference import VertexAIModelHandlerJSON + except ImportError: + raise ValueError( + 'Unable to import VertexAIModelHandlerJSON. Please ' + 'install gcp dependencies: `pip install apache_beam[gcp]`') + + _handler = VertexAIModelHandlerJSON( + endpoint_id=str(endpoint_id), + project=endpoint_project, + location=endpoint_region, + experiment=experiment, + network=network, + private=private, + min_batch_size=min_batch_size, + max_batch_size=max_batch_size, + max_batch_duration_secs=max_batch_duration_secs, + env_vars=env_vars or {}) + + super().__init__(_handler, preprocess, postprocess) + + @staticmethod + def validate(model_handler_spec): + for required in ('endpoint_id', 'endpoint_project', 'endpoint_region'): + if required not in model_handler_spec: + raise ValueError( + f'Missing {required} in model handler ' + f'at line {SafeLineLoader.get_line(model_handler_spec)}') + + def get_output_schema(self): + return RowTypeConstraint.from_fields([('example', Any), ('inference', Any), + ('model_id', Optional[str])]) + + def create_postprocess_fn(self): + return lambda x: (x[0], beam.Row(**self._postprocess(x[1])._asdict())) + + +@beam.ptransform.ptransform_fn +def run_inference( + pcoll, + model_handler: Dict[str, Any], + inference_tag: Optional[str] = 'inference', + inference_args: Optional[Dict[str, Any]] = None) -> beam.PCollection[beam.Row]: # pylint: disable=line-too-long + """ + A transform that takes the input rows, containing examples (or features), for + use on an ML model. The transform then appends the inferences + (or predictions) for those examples to the input row. + + A ModelHandler must be passed to the `model_handler` parameter. The + ModelHandler is responsible for configuring how the ML model will be loaded + and how input data will be passed to it. Every ModelHandler has a config tag, + similar to how a transform is defined, where the parameters are defined. + + For example: :: + + - type: RunInference + config: + model_handler: + type: ModelHandler + config: + param_1: arg1 + param_2: arg2 + ... + + By default, the RunInference transform will return the + input row with a single field appended named by the `inference_tag` parameter + ("inference" by default) that contains the inference directly returned by the + underlying ModelHandler, after any optional postprocessing. + + For example, if the input had the following: :: + + Row(question="What is a car?") + + The output row would look like: :: + + Row(question="What is a car?", inference=...) + + where the `inference` tag can be overridden with the `inference_tag` + parameter. + + However, if one specified the following transform config: :: + + - type: RunInference + config: + inference_tag: my_inference + model_handler: ... + + The output row would look like: :: + + Row(question="What is a car?", my_inference=...) + + See more complete documentation on the underlying + [RunInference](https://beam.apache.org/documentation/ml/inference-overview/) + transform. + + ### Preprocessing input data + + In most cases, the model will be expecting data in a particular data format, + whether it be a Python Dict, PyTorch tensor, etc. However, the outputs of all + built-in Beam YAML transforms are Beam Rows. To allow for transforming + the Beam Row into a data format the model recognizes, each ModelHandler is + equipped with a `preprocessing` parameter for performing necessary data + preprocessing. It is possible for a ModelHandler to define a default + preprocessing function, but in most cases, one will need to be specified by + the caller. + + For example, using `callable`: :: + + pipeline: + type: chain + + transforms: + - type: Create + config: + elements: + - question: "What is a car?" + - question: "Where is the Eiffel Tower located?" + + - type: RunInference + config: + model_handler: + type: ModelHandler + config: + param_1: arg1 + param_2: arg2 + preprocess: + callable: 'lambda row: {"prompt": row.question}' + ... + + In the above example, the Create transform generates a collection of two + elements, each with a single field - "question". The model, however, expects + a Python Dict with a single key, "prompt". In this case, we can specify a + simple Lambda function (alternatively could define a full function), to map + the data. + + ### Postprocessing predictions + + It is also possible to define a postprocessing function to postprocess the + data output by the ModelHandler. See the documentation for the ModelHandler + you intend to use (list defined below under `model_handler` parameter doc). + + In many cases, before postprocessing, the object + will be a + [PredictionResult](https://beam.apache.org/releases/pydoc/BEAM_VERSION/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.PredictionResult). # pylint: disable=line-too-long + This type behaves very similarly to a Beam Row and fields can be accessed + using dot notation. However, make sure to check the docs for your ModelHandler + to see which fields its PredictionResult contains or if it returns a + different object altogether. + + For example: :: + + - type: RunInference + config: + model_handler: + type: ModelHandler + config: + param_1: arg1 + param_2: arg2 + postprocess: + callable: | + def fn(x: PredictionResult): + return beam.Row(x.example, x.inference, x.model_id) + ... + + The above example demonstrates converting the original output data type (in + this case it is PredictionResult), and converts to a Beam Row, which allows + for easier mapping in a later transform. + + ### File-based pre/postprocessing functions + + For both preprocessing and postprocessing, it is also possible to specify a + Python UDF (User-defined function) file that contains the function. This is + possible by specifying the `path` to the file (local file or GCS path) and + the `name` of the function in the file. + + For example: :: + + - type: RunInference + config: + model_handler: + type: ModelHandler + config: + param_1: arg1 + param_2: arg2 + preprocess: + path: gs://my-bucket/path/to/preprocess.py + name: my_preprocess_fn + postprocess: + path: gs://my-bucket/path/to/postprocess.py + name: my_postprocess_fn + ... + + Args: + model_handler: Specifies the parameters for the respective + enrichment_handler in a YAML/JSON format. To see the full set of + handler_config parameters, see their corresponding doc pages: + + - [VertexAIModelHandlerJSON](https://beam.apache.org/releases/pydoc/current/apache_beam.yaml.yaml_ml.VertexAIModelHandlerJSONProvider) # pylint: disable=line-too-long + inference_tag: The tag to use for the returned inference. Default is + 'inference'. + inference_args: Extra arguments for models whose inference call requires + extra parameters. Make sure to check the underlying ModelHandler docs to + see which args are allowed. + + """ + + options.YamlOptions.check_enabled(pcoll.pipeline, 'ML') + + model_handler_provider = ModelHandlerProvider.create_handler(model_handler) + schema = RowTypeConstraint.from_fields( + list( + RowTypeConstraint.from_user_type( + pcoll.element_type.user_type)._fields) + + [(inference_tag, model_handler_provider.get_output_schema())]) + + return ( + pcoll | RunInference( + model_handler=KeyedModelHandler( + model_handler_provider.underlying_handler()).with_preprocess_fn( + model_handler_provider.create_preprocess_fn()). + with_postprocess_fn(model_handler_provider.create_postprocess_fn()), + inference_args=inference_args) + | beam.Map( + lambda row: beam.Row(**{ + inference_tag: row[1], **row[0]._asdict() + })).with_output_types(schema)) + + def _config_to_obj(spec): if 'type' not in spec: - raise ValueError(r"Missing type in ML transform spec {spec}") + raise ValueError(f"Missing type in ML transform spec {spec}") if 'config' not in spec: - raise ValueError(r"Missing config in ML transform spec {spec}") + raise ValueError(f"Missing config in ML transform spec {spec}") constructor = _transform_constructors.get(spec['type']) if constructor is None: raise ValueError("Unknown ML transform type: %r" % spec['type']) diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py b/sdks/python/apache_beam/yaml/yaml_transform.py index 7cb96a7efb32..c83cd53b5855 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform.py +++ b/sdks/python/apache_beam/yaml/yaml_transform.py @@ -23,7 +23,6 @@ import os import pprint import re -import uuid from typing import Any from typing import Iterable from typing import List @@ -32,7 +31,6 @@ import jinja2 import yaml -from yaml.loader import SafeLoader import apache_beam as beam from apache_beam.io.filesystems import FileSystems @@ -42,6 +40,8 @@ from apache_beam.yaml.yaml_combine import normalize_combine from apache_beam.yaml.yaml_mapping import normalize_mapping from apache_beam.yaml.yaml_mapping import validate_generic_expressions +from apache_beam.yaml.yaml_ml import normalize_ml +from apache_beam.yaml.yaml_utils import SafeLineLoader __all__ = ["YamlTransform"] @@ -130,59 +130,6 @@ def empty_if_explicitly_empty(io): return io -class SafeLineLoader(SafeLoader): - """A yaml loader that attaches line information to mappings and strings.""" - class TaggedString(str): - """A string class to which we can attach metadata. - - This is primarily used to trace a string's origin back to its place in a - yaml file. - """ - def __reduce__(self): - # Pickle as an ordinary string. - return str, (str(self), ) - - def construct_scalar(self, node): - value = super().construct_scalar(node) - if isinstance(value, str): - value = SafeLineLoader.TaggedString(value) - value._line_ = node.start_mark.line + 1 - return value - - def construct_mapping(self, node, deep=False): - mapping = super().construct_mapping(node, deep=deep) - mapping['__line__'] = node.start_mark.line + 1 - mapping['__uuid__'] = self.create_uuid() - return mapping - - @classmethod - def create_uuid(cls): - return str(uuid.uuid4()) - - @classmethod - def strip_metadata(cls, spec, tagged_str=True): - if isinstance(spec, Mapping): - return { - cls.strip_metadata(key, tagged_str): - cls.strip_metadata(value, tagged_str) - for (key, value) in spec.items() - if key not in ('__line__', '__uuid__') - } - elif isinstance(spec, Iterable) and not isinstance(spec, (str, bytes)): - return [cls.strip_metadata(value, tagged_str) for value in spec] - elif isinstance(spec, SafeLineLoader.TaggedString) and tagged_str: - return str(spec) - else: - return spec - - @staticmethod - def get_line(obj): - if isinstance(obj, dict): - return obj.get('__line__', 'unknown') - else: - return getattr(obj, '_line_', 'unknown') - - class LightweightScope(object): def __init__(self, transforms): self._transforms = transforms @@ -960,6 +907,7 @@ def preprocess_languages(spec): ensure_transforms_have_types, normalize_mapping, normalize_combine, + normalize_ml, preprocess_languages, ensure_transforms_have_providers, preprocess_source_sink, diff --git a/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py b/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py index 084e03cdb197..5bc9de24bb38 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py +++ b/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py @@ -23,7 +23,6 @@ from apache_beam.yaml import YamlTransform from apache_beam.yaml import yaml_provider from apache_beam.yaml.yaml_provider import InlineProvider -from apache_beam.yaml.yaml_transform import SafeLineLoader from apache_beam.yaml.yaml_transform import Scope from apache_beam.yaml.yaml_transform import chain_as_composite from apache_beam.yaml.yaml_transform import ensure_errors_consumed @@ -39,57 +38,7 @@ from apache_beam.yaml.yaml_transform import preprocess_flattened_inputs from apache_beam.yaml.yaml_transform import preprocess_windowing from apache_beam.yaml.yaml_transform import push_windowing_to_roots - - -class SafeLineLoaderTest(unittest.TestCase): - def test_get_line(self): - pipeline_yaml = ''' - type: composite - input: - elements: input - transforms: - - type: PyMap - name: Square - input: elements - config: - fn: "lambda x: x * x" - - type: PyMap - name: Cube - input: elements - config: - fn: "lambda x: x * x * x" - output: - Flatten - ''' - spec = yaml.load(pipeline_yaml, Loader=SafeLineLoader) - self.assertEqual(SafeLineLoader.get_line(spec['type']), 2) - self.assertEqual(SafeLineLoader.get_line(spec['input']), 4) - self.assertEqual(SafeLineLoader.get_line(spec['transforms'][0]), 6) - self.assertEqual(SafeLineLoader.get_line(spec['transforms'][0]['type']), 6) - self.assertEqual(SafeLineLoader.get_line(spec['transforms'][0]['name']), 7) - self.assertEqual(SafeLineLoader.get_line(spec['transforms'][1]), 11) - self.assertEqual(SafeLineLoader.get_line(spec['output']), 17) - self.assertEqual(SafeLineLoader.get_line(spec['transforms']), "unknown") - - def test_strip_metadata(self): - spec_yaml = ''' - transforms: - - type: PyMap - name: Square - ''' - spec = yaml.load(spec_yaml, Loader=SafeLineLoader) - stripped = SafeLineLoader.strip_metadata(spec['transforms']) - - self.assertFalse(hasattr(stripped[0], '__line__')) - self.assertFalse(hasattr(stripped[0], '__uuid__')) - - def test_strip_metadata_nothing_to_strip(self): - spec_yaml = 'prop: 123' - spec = yaml.load(spec_yaml, Loader=SafeLineLoader) - stripped = SafeLineLoader.strip_metadata(spec['prop']) - - self.assertFalse(hasattr(stripped, '__line__')) - self.assertFalse(hasattr(stripped, '__uuid__')) +from apache_beam.yaml.yaml_utils import SafeLineLoader def new_pipeline(): diff --git a/sdks/python/apache_beam/yaml/yaml_utils.py b/sdks/python/apache_beam/yaml/yaml_utils.py new file mode 100644 index 000000000000..91cad4175fdb --- /dev/null +++ b/sdks/python/apache_beam/yaml/yaml_utils.py @@ -0,0 +1,58 @@ +import uuid +from typing import Iterable +from typing import Mapping + +from yaml import SafeLoader + + +class SafeLineLoader(SafeLoader): + """A yaml loader that attaches line information to mappings and strings.""" + class TaggedString(str): + """A string class to which we can attach metadata. + + This is primarily used to trace a string's origin back to its place in a + yaml file. + """ + def __reduce__(self): + # Pickle as an ordinary string. + return str, (str(self), ) + + def construct_scalar(self, node): + value = super().construct_scalar(node) + if isinstance(value, str): + value = SafeLineLoader.TaggedString(value) + value._line_ = node.start_mark.line + 1 + return value + + def construct_mapping(self, node, deep=False): + mapping = super().construct_mapping(node, deep=deep) + mapping['__line__'] = node.start_mark.line + 1 + mapping['__uuid__'] = self.create_uuid() + return mapping + + @classmethod + def create_uuid(cls): + return str(uuid.uuid4()) + + @classmethod + def strip_metadata(cls, spec, tagged_str=True): + if isinstance(spec, Mapping): + return { + cls.strip_metadata(key, tagged_str): + cls.strip_metadata(value, tagged_str) + for (key, value) in spec.items() + if key not in ('__line__', '__uuid__') + } + elif isinstance(spec, Iterable) and not isinstance(spec, (str, bytes)): + return [cls.strip_metadata(value, tagged_str) for value in spec] + elif isinstance(spec, SafeLineLoader.TaggedString) and tagged_str: + return str(spec) + else: + return spec + + @staticmethod + def get_line(obj): + if isinstance(obj, dict): + return obj.get('__line__', 'unknown') + else: + return getattr(obj, '_line_', 'unknown') diff --git a/sdks/python/apache_beam/yaml/yaml_utils_test.py b/sdks/python/apache_beam/yaml/yaml_utils_test.py new file mode 100644 index 000000000000..23105a4e82d4 --- /dev/null +++ b/sdks/python/apache_beam/yaml/yaml_utils_test.py @@ -0,0 +1,56 @@ +import unittest + +import yaml + +from apache_beam.yaml.yaml_utils import SafeLineLoader + + +class SafeLineLoaderTest(unittest.TestCase): + def test_get_line(self): + pipeline_yaml = ''' + type: composite + input: + elements: input + transforms: + - type: PyMap + name: Square + input: elements + config: + fn: "lambda x: x * x" + - type: PyMap + name: Cube + input: elements + config: + fn: "lambda x: x * x * x" + output: + Flatten + ''' + spec = yaml.load(pipeline_yaml, Loader=SafeLineLoader) + self.assertEqual(SafeLineLoader.get_line(spec['type']), 2) + self.assertEqual(SafeLineLoader.get_line(spec['input']), 4) + self.assertEqual(SafeLineLoader.get_line(spec['transforms'][0]), 6) + self.assertEqual(SafeLineLoader.get_line(spec['transforms'][0]['type']), 6) + self.assertEqual(SafeLineLoader.get_line(spec['transforms'][0]['name']), 7) + self.assertEqual(SafeLineLoader.get_line(spec['transforms'][1]), 11) + self.assertEqual(SafeLineLoader.get_line(spec['output']), 17) + self.assertEqual(SafeLineLoader.get_line(spec['transforms']), "unknown") + + def test_strip_metadata(self): + spec_yaml = ''' + transforms: + - type: PyMap + name: Square + ''' + spec = yaml.load(spec_yaml, Loader=SafeLineLoader) + stripped = SafeLineLoader.strip_metadata(spec['transforms']) + + self.assertFalse(hasattr(stripped[0], '__line__')) + self.assertFalse(hasattr(stripped[0], '__uuid__')) + + def test_strip_metadata_nothing_to_strip(self): + spec_yaml = 'prop: 123' + spec = yaml.load(spec_yaml, Loader=SafeLineLoader) + stripped = SafeLineLoader.strip_metadata(spec['prop']) + + self.assertFalse(hasattr(stripped, '__line__')) + self.assertFalse(hasattr(stripped, '__uuid__'))