From 17a34c43768f2ab2bde2140d68e72525357f2bbe Mon Sep 17 00:00:00 2001
From: Jeffrey Kinard
Date: Wed, 18 Dec 2024 15:00:25 -0500
Subject: [PATCH] address comments and fix tests

Signed-off-by: Jeffrey Kinard
---
 sdks/python/apache_beam/yaml/yaml_ml.py       | 166 +++++++++---------
 .../python/apache_beam/yaml/yaml_transform.py |   2 -
 sdks/python/apache_beam/yaml/yaml_utils.py    |  17 ++
 .../apache_beam/yaml/yaml_utils_test.py       |  22 +++
 4 files changed, 118 insertions(+), 89 deletions(-)

diff --git a/sdks/python/apache_beam/yaml/yaml_ml.py b/sdks/python/apache_beam/yaml/yaml_ml.py
index 111618679f96..f60cc8f71adf 100644
--- a/sdks/python/apache_beam/yaml/yaml_ml.py
+++ b/sdks/python/apache_beam/yaml/yaml_ml.py
@@ -40,56 +40,21 @@
 tft = None  # type: ignore
 
 
-def normalize_ml(spec):
-  if spec['type'] == 'RunInference':
-    config = spec.get('config')
-    for required in ('model_handler', ):
-      if required not in config:
-        raise ValueError(
-            f'Missing {required} parameter in RunInference config '
-            f'at line {SafeLineLoader.get_line(spec)}')
-    model_handler = config.get('model_handler')
-    if not isinstance(model_handler, dict):
-      raise ValueError(
-          'Invalid model_handler specification at line '
-          f'{SafeLineLoader.get_line(spec)}. Expected '
-          f'dict but was {type(model_handler)}.')
-    for required in ('type', 'config'):
-      if required not in model_handler:
-        raise ValueError(
-            f'Missing {required} in model handler '
-            f'at line {SafeLineLoader.get_line(model_handler)}')
-    typ = model_handler['type']
-    extra_params = set(SafeLineLoader.strip_metadata(model_handler).keys()) - {
-        'type', 'config'
-    }
-    if extra_params:
-      raise ValueError(
-          f'Unexpected parameters in model handler of type {typ} '
-          f'at line {SafeLineLoader.get_line(spec)}: {extra_params}')
-    model_handler_provider = ModelHandlerProvider.handler_types.get(typ, None)
-    if model_handler_provider:
-      model_handler_provider.validate(model_handler['config'])
-    else:
-      raise NotImplementedError(
-          f'Unknown model handler type: {typ} '
-          f'at line {SafeLineLoader.get_line(spec)}.')
-
-  return spec
-
-
 class ModelHandlerProvider:
   handler_types: Dict[str, Callable[..., "ModelHandlerProvider"]] = {}
 
   def __init__(
-      self, handler, preprocess: Callable = None, postprocess: Callable = None):
+      self,
+      handler,
+      preprocess: Optional[Dict[str, str]] = None,
+      postprocess: Optional[Dict[str, str]] = None):
     self._handler = handler
-    self._preprocess = self.parse_processing_transform(
-        preprocess, 'preprocess') or self.preprocess_fn
-    self._postprocess = self.parse_processing_transform(
-        postprocess, 'postprocess') or self.postprocess_fn
+    self._preprocess_fn = self.parse_processing_transform(
+        preprocess, 'preprocess') or self.default_preprocess_fn()
+    self._postprocess_fn = self.parse_processing_transform(
+        postprocess, 'postprocess') or self.default_postprocess_fn()
 
-  def get_output_schema(self):
+  def inference_output_type(self):
     return Any
 
   @staticmethod
@@ -118,21 +83,22 @@ def _parse_config(callable=None, path=None, name=None):
   def underlying_handler(self):
     return self._handler
 
-  def preprocess_fn(self, row):
+  @staticmethod
+  def default_preprocess_fn():
     raise ValueError(
         'Handler does not implement a default preprocess '
         'method. Please define a preprocessing method using the '
         '\'preprocess\' tag.')
 
-  def create_preprocess_fn(self):
-    return lambda row: (row, self._preprocess(row))
+  def _preprocess_fn_internal(self):
+    return lambda row: (row, self._preprocess_fn(row))
 
   @staticmethod
-  def postprocess_fn(x):
-    return x
+  def default_postprocess_fn():
+    return lambda x: x
 
-  def create_postprocess_fn(self):
-    return lambda result: (result[0], self._postprocess(result[1]))
+  def _postprocess_fn_internal(self):
+    return lambda result: (result[0], self._postprocess_fn(result[1]))
 
   @staticmethod
   def validate(model_handler_spec):
@@ -165,18 +131,19 @@ class VertexAIModelHandlerJSONProvider(ModelHandlerProvider):
   def __init__(
       self,
       endpoint_id: str,
-      endpoint_project: str,
-      endpoint_region: str,
+      project: str,
+      location: str,
+      preprocess: Dict[str, str],
       experiment: Optional[str] = None,
       network: Optional[str] = None,
       private: bool = False,
       min_batch_size: Optional[int] = None,
       max_batch_size: Optional[int] = None,
       max_batch_duration_secs: Optional[int] = None,
-      env_vars=None,
-      preprocess: Callable = None,
-      postprocess: Callable = None):
-    """ModelHandler for Vertex AI.
+      env_vars: Optional[Dict[str, Any]] = None,
+      postprocess: Optional[Dict[str, str]] = None):
+    """
+    ModelHandler for Vertex AI.
 
     For example: ::
 
@@ -187,37 +154,43 @@ def __init__(
               type: VertexAIModelHandlerJSON
               config:
                 endpoint_id: 9876543210
-                endpoint_project: 1234567890
-                endpoint_region: us-east1
+                project: my-project
+                location: us-east1
                 preprocess:
                   callable: 'lambda x: {"prompt": x.prompt, "max_tokens": 50}'
 
     Args:
       endpoint_id: the numerical ID of the Vertex AI endpoint to query.
-      endpoint_project: the GCP project name where the endpoint is deployed.
-      endpoint_region: the GCP location where the endpoint is deployed.
-      experiment: optional. experiment label to apply to the
+      project: the GCP project name where the endpoint is deployed.
+      location: the GCP location where the endpoint is deployed.
+      experiment: Experiment label to apply to the
         queries. See
        https://cloud.google.com/vertex-ai/docs/experiments/intro-vertex-ai-experiments
        for more information.
-      network: optional. the full name of the Compute Engine
+      network: The full name of the Compute Engine
        network the endpoint is deployed on; used for private
        endpoints. The network or subnetwork Dataflow pipeline option must
        be set and match this network for pipeline execution.
        Ex: "projects/12345/global/networks/myVPC"
-      private: optional. if the deployed Vertex AI endpoint is
+      private: If the deployed Vertex AI endpoint is
        private, set to true. Requires a network to be provided
        as well.
-      min_batch_size: optional. the minimum batch size to use when batching
+      min_batch_size: The minimum batch size to use when batching
        inputs.
-      max_batch_size: optional. the maximum batch size to use when batching
+      max_batch_size: The maximum batch size to use when batching
        inputs.
-      max_batch_duration_secs: optional. the maximum amount of time to buffer
+      max_batch_duration_secs: The maximum amount of time to buffer
        a batch before emitting; used in streaming contexts.
      env_vars: Environment variables.
-      preprocess:
-      postprocess:
+      preprocess: A python callable, defined either inline, or using a file,
+        that is invoked on the input row before sending to the model to be
+        loaded by this ModelHandler. This parameter is required by the
+        `VertexAIModelHandlerJSON` ModelHandler.
+      postprocess: A python callable, defined either inline, or using a file,
+        that is invoked on the PredictionResult output by the ModelHandler
+        before parsing into the output Beam Row under the field name defined
+        by the inference_tag.
     """
 
     try:
@@ -229,8 +202,8 @@ def __init__(
 
     _handler = VertexAIModelHandlerJSON(
         endpoint_id=str(endpoint_id),
-        project=endpoint_project,
-        location=endpoint_region,
+        project=project,
+        location=location,
         experiment=experiment,
         network=network,
         private=private,
@@ -243,18 +216,15 @@ def __init__(
 
   @staticmethod
   def validate(model_handler_spec):
-    for required in ('endpoint_id', 'endpoint_project', 'endpoint_region'):
-      if required not in model_handler_spec:
-        raise ValueError(
-            f'Missing {required} in model handler '
-            f'at line {SafeLineLoader.get_line(model_handler_spec)}')
+    pass
 
-  def get_output_schema(self):
+  def inference_output_type(self):
     return RowTypeConstraint.from_fields([('example', Any), ('inference', Any),
                                           ('model_id', Optional[str])])
 
-  def create_postprocess_fn(self):
-    return lambda x: (x[0], beam.Row(**self._postprocess(x[1])._asdict()))
+  @staticmethod
+  def default_postprocess_fn():
+    return lambda x: beam.Row(**x._asdict())
 
 
 @beam.ptransform.ptransform_fn
@@ -349,11 +319,11 @@ def run_inference(
             callable: 'lambda row: {"prompt": row.question}'
     ...
 
-  In the above example, the Create transform generates a collection of two
-  elements, each with a single field - "question". The model, however, expects
-  a Python Dict with a single key, "prompt". In this case, we can specify a
-  simple Lambda function (alternatively could define a full function), to map
-  the data.
+  In the above example, the Create transform generates a collection of two Beam
+  Row elements, each with a single field - "question". The model, however,
+  expects a Python Dict with a single key, "prompt". In this case, we can
+  specify a simple Lambda function (alternatively could define a full function),
+  to map the data.
 
   ### Postprocessing predictions
 
@@ -428,19 +398,41 @@ def fn(x: PredictionResult):
 
   options.YamlOptions.check_enabled(pcoll.pipeline, 'ML')
 
+  if not isinstance(model_handler, dict):
+    raise ValueError(
+        'Invalid model_handler specification. Expected dict but was '
+        f'{type(model_handler)}.')
+  expected_model_handler_params = {'type', 'config'}
+  given_model_handler_params = set(
+      SafeLineLoader.strip_metadata(model_handler).keys())
+  extra_params = given_model_handler_params - expected_model_handler_params
+  if extra_params:
+    raise ValueError(f'Unexpected parameters in model_handler: {extra_params}')
+  missing_params = expected_model_handler_params - given_model_handler_params
+  if missing_params:
+    raise ValueError(f'Missing parameters in model_handler: {missing_params}')
+  typ = model_handler['type']
+  model_handler_provider = ModelHandlerProvider.handler_types.get(typ, None)
+  if model_handler_provider and issubclass(model_handler_provider,
+                                           ModelHandlerProvider):
+    model_handler_provider.validate(model_handler['config'])
+  else:
+    raise NotImplementedError(f'Unknown model handler type: {typ}.')
+
   model_handler_provider = ModelHandlerProvider.create_handler(model_handler)
   schema = RowTypeConstraint.from_fields(
       list(
           RowTypeConstraint.from_user_type(
               pcoll.element_type.user_type)._fields) +
-      [(inference_tag, model_handler_provider.get_output_schema())])
+      [(inference_tag, model_handler_provider.inference_output_type())])
   return (
       pcoll
       | RunInference(
           model_handler=KeyedModelHandler(
              model_handler_provider.underlying_handler()).with_preprocess_fn(
-                  model_handler_provider.create_preprocess_fn()).
-              with_postprocess_fn(model_handler_provider.create_postprocess_fn()),
+                  model_handler_provider._preprocess_fn_internal()).
+              with_postprocess_fn(
+                  model_handler_provider._postprocess_fn_internal()),
          inference_args=inference_args)
       | beam.Map(
          lambda row: beam.Row(**{
diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py b/sdks/python/apache_beam/yaml/yaml_transform.py
index c83cd53b5855..c6f5544c650a 100644
--- a/sdks/python/apache_beam/yaml/yaml_transform.py
+++ b/sdks/python/apache_beam/yaml/yaml_transform.py
@@ -40,7 +40,6 @@
 from apache_beam.yaml.yaml_combine import normalize_combine
 from apache_beam.yaml.yaml_mapping import normalize_mapping
 from apache_beam.yaml.yaml_mapping import validate_generic_expressions
-from apache_beam.yaml.yaml_ml import normalize_ml
 from apache_beam.yaml.yaml_utils import SafeLineLoader
 
 __all__ = ["YamlTransform"]
@@ -907,7 +906,6 @@ def preprocess_languages(spec):
       ensure_transforms_have_types,
       normalize_mapping,
       normalize_combine,
-      normalize_ml,
       preprocess_languages,
       ensure_transforms_have_providers,
       preprocess_source_sink,
diff --git a/sdks/python/apache_beam/yaml/yaml_utils.py b/sdks/python/apache_beam/yaml/yaml_utils.py
index 91cad4175fdb..63beb90f0711 100644
--- a/sdks/python/apache_beam/yaml/yaml_utils.py
+++ b/sdks/python/apache_beam/yaml/yaml_utils.py
@@ -1,3 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
 import uuid
 from typing import Iterable
 from typing import Mapping
diff --git a/sdks/python/apache_beam/yaml/yaml_utils_test.py b/sdks/python/apache_beam/yaml/yaml_utils_test.py
index 23105a4e82d4..70f6ba9b5198 100644
--- a/sdks/python/apache_beam/yaml/yaml_utils_test.py
+++ b/sdks/python/apache_beam/yaml/yaml_utils_test.py
@@ -1,3 +1,21 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
 import unittest
 
 import yaml
@@ -54,3 +72,8 @@ def test_strip_metadata_nothing_to_strip(self):
 
     self.assertFalse(hasattr(stripped, '__line__'))
     self.assertFalse(hasattr(stripped, '__uuid__'))
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)
+  unittest.main()
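
A minimal end-to-end Beam YAML snippet matching the renamed configuration above, adapted from the
docstring examples in this patch (the endpoint ID, project, and element values are illustrative
placeholders, not taken from a real deployment):

    pipeline:
      type: chain
      transforms:
        - type: Create
          config:
            elements:
              - question: "What is Apache Beam?"
              - question: "What is a Vertex AI endpoint?"
        - type: RunInference
          config:
            inference_tag: 'my_inference'
            model_handler:
              type: VertexAIModelHandlerJSON
              config:
                endpoint_id: 9876543210
                project: my-project
                location: us-east1
                preprocess:
                  callable: 'lambda row: {"prompt": row.question}'

When no postprocess is supplied, the provider's default_postprocess_fn attaches the whole
PredictionResult (example, inference, model_id) as a Row under the my_inference field, matching
inference_output_type(); a postprocess callable can be given to reshape that value before it is
placed on the output row.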
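
The preprocess and postprocess entries may also reference a function defined in a file rather than
an inline lambda, since the handler configuration is parsed via _parse_config(callable=None,
path=None, name=None) above. A sketch, assuming a hypothetical module my_inference_fns.py that
defines to_vertex_request and from_prediction_result, nested under the model_handler config in
place of the inline callable shown in the example above:

                preprocess:
                  path: my_inference_fns.py
                  name: to_vertex_request
                postprocess:
                  path: my_inference_fns.py
                  name: from_prediction_result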