Skip to content

Commit

Permalink
address comments and fix tests
Browse files Browse the repository at this point in the history
Signed-off-by: Jeffrey Kinard <[email protected]>
  • Loading branch information
Polber committed Dec 18, 2024
1 parent c04712b commit 17a34c4
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 89 deletions.
166 changes: 79 additions & 87 deletions sdks/python/apache_beam/yaml/yaml_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,56 +40,21 @@
tft = None # type: ignore


def normalize_ml(spec):
if spec['type'] == 'RunInference':
config = spec.get('config')
for required in ('model_handler', ):
if required not in config:
raise ValueError(
f'Missing {required} parameter in RunInference config '
f'at line {SafeLineLoader.get_line(spec)}')
model_handler = config.get('model_handler')
if not isinstance(model_handler, dict):
raise ValueError(
'Invalid model_handler specification at line '
f'{SafeLineLoader.get_line(spec)}. Expected '
f'dict but was {type(model_handler)}.')
for required in ('type', 'config'):
if required not in model_handler:
raise ValueError(
f'Missing {required} in model handler '
f'at line {SafeLineLoader.get_line(model_handler)}')
typ = model_handler['type']
extra_params = set(SafeLineLoader.strip_metadata(model_handler).keys()) - {
'type', 'config'
}
if extra_params:
raise ValueError(
f'Unexpected parameters in model handler of type {typ} '
f'at line {SafeLineLoader.get_line(spec)}: {extra_params}')
model_handler_provider = ModelHandlerProvider.handler_types.get(typ, None)
if model_handler_provider:
model_handler_provider.validate(model_handler['config'])
else:
raise NotImplementedError(
f'Unknown model handler type: {typ} '
f'at line {SafeLineLoader.get_line(spec)}.')

return spec


class ModelHandlerProvider:
handler_types: Dict[str, Callable[..., "ModelHandlerProvider"]] = {}

def __init__(
self, handler, preprocess: Callable = None, postprocess: Callable = None):
self,
handler,
preprocess: Optional[Dict[str, str]] = None,
postprocess: Optional[Dict[str, str]] = None):
self._handler = handler
self._preprocess = self.parse_processing_transform(
preprocess, 'preprocess') or self.preprocess_fn
self._postprocess = self.parse_processing_transform(
postprocess, 'postprocess') or self.postprocess_fn
self._preprocess_fn = self.parse_processing_transform(
preprocess, 'preprocess') or self.default_preprocess_fn()
self._postprocess_fn = self.parse_processing_transform(
postprocess, 'postprocess') or self.default_postprocess_fn()

def get_output_schema(self):
def inference_output_type(self):
return Any

@staticmethod
Expand Down Expand Up @@ -118,21 +83,22 @@ def _parse_config(callable=None, path=None, name=None):
def underlying_handler(self):
return self._handler

def preprocess_fn(self, row):
@staticmethod
def default_preprocess_fn():
raise ValueError(
'Handler does not implement a default preprocess '
'method. Please define a preprocessing method using the '
'\'preprocess\' tag.')

def create_preprocess_fn(self):
return lambda row: (row, self._preprocess(row))
def _preprocess_fn_internal(self):
return lambda row: (row, self._preprocess_fn(row))

@staticmethod
def postprocess_fn(x):
return x
def default_postprocess_fn():
return lambda x: x

def create_postprocess_fn(self):
return lambda result: (result[0], self._postprocess(result[1]))
def _postprocess_fn_internal(self):
return lambda result: (result[0], self._postprocess_fn(result[1]))

@staticmethod
def validate(model_handler_spec):
Expand Down Expand Up @@ -165,18 +131,19 @@ class VertexAIModelHandlerJSONProvider(ModelHandlerProvider):
def __init__(
self,
endpoint_id: str,
endpoint_project: str,
endpoint_region: str,
project: str,
location: str,
preprocess: Dict[str, str],
experiment: Optional[str] = None,
network: Optional[str] = None,
private: bool = False,
min_batch_size: Optional[int] = None,
max_batch_size: Optional[int] = None,
max_batch_duration_secs: Optional[int] = None,
env_vars=None,
preprocess: Callable = None,
postprocess: Callable = None):
"""ModelHandler for Vertex AI.
env_vars: Optional[Dict[str, Any]] = None,
postprocess: Optional[Dict[str, str]] = None):
"""
ModelHandler for Vertex AI.
For example: ::
Expand All @@ -187,37 +154,43 @@ def __init__(
type: VertexAIModelHandlerJSON
config:
endpoint_id: 9876543210
endpoint_project: 1234567890
endpoint_region: us-east1
project: my-project
location: us-east1
preprocess:
callable: 'lambda x: {"prompt": x.prompt, "max_tokens": 50}'
Args:
endpoint_id: the numerical ID of the Vertex AI endpoint to query.
endpoint_project: the GCP project name where the endpoint is deployed.
endpoint_region: the GCP location where the endpoint is deployed.
experiment: optional. experiment label to apply to the
project: the GCP project name where the endpoint is deployed.
location: the GCP location where the endpoint is deployed.
experiment: Experiment label to apply to the
queries. See
https://cloud.google.com/vertex-ai/docs/experiments/intro-vertex-ai-experiments
for more information.
network: optional. the full name of the Compute Engine
network: The full name of the Compute Engine
network the endpoint is deployed on; used for private
endpoints. The network or subnetwork Dataflow pipeline
option must be set and match this network for pipeline
execution.
Ex: "projects/12345/global/networks/myVPC"
private: optional. if the deployed Vertex AI endpoint is
private: If the deployed Vertex AI endpoint is
private, set to true. Requires a network to be provided
as well.
min_batch_size: optional. the minimum batch size to use when batching
min_batch_size: The minimum batch size to use when batching
inputs.
max_batch_size: optional. the maximum batch size to use when batching
max_batch_size: The maximum batch size to use when batching
inputs.
max_batch_duration_secs: optional. the maximum amount of time to buffer
max_batch_duration_secs: The maximum amount of time to buffer
a batch before emitting; used in streaming contexts.
env_vars: Environment variables.
preprocess:
postprocess:
preprocess: A python callable, defined either inline, or using a file,
that is invoked on the input row before sending to the model to be
loaded by this ModelHandler. This parameter is required by the
`VertexAIModelHandlerJSON` ModelHandler.
postprocess: A python callable, defined either inline, or using a file,
that is invoked on the PredictionResult output by the ModelHandler
before parsing into the output Beam Row under the field name defined
by the inference_tag.
"""

try:
Expand All @@ -229,8 +202,8 @@ def __init__(

_handler = VertexAIModelHandlerJSON(
endpoint_id=str(endpoint_id),
project=endpoint_project,
location=endpoint_region,
project=project,
location=location,
experiment=experiment,
network=network,
private=private,
Expand All @@ -243,18 +216,15 @@ def __init__(

@staticmethod
def validate(model_handler_spec):
for required in ('endpoint_id', 'endpoint_project', 'endpoint_region'):
if required not in model_handler_spec:
raise ValueError(
f'Missing {required} in model handler '
f'at line {SafeLineLoader.get_line(model_handler_spec)}')
pass

def get_output_schema(self):
def inference_output_type(self):
return RowTypeConstraint.from_fields([('example', Any), ('inference', Any),
('model_id', Optional[str])])

def create_postprocess_fn(self):
return lambda x: (x[0], beam.Row(**self._postprocess(x[1])._asdict()))
@staticmethod
def default_postprocess_fn():
return lambda x: beam.Row(**x._asdict())


@beam.ptransform.ptransform_fn
Expand Down Expand Up @@ -349,11 +319,11 @@ def run_inference(
callable: 'lambda row: {"prompt": row.question}'
...
In the above example, the Create transform generates a collection of two
elements, each with a single field - "question". The model, however, expects
a Python Dict with a single key, "prompt". In this case, we can specify a
simple Lambda function (alternatively could define a full function), to map
the data.
In the above example, the Create transform generates a collection of two Beam
Row elements, each with a single field - "question". The model, however,
expects a Python Dict with a single key, "prompt". In this case, we can
specify a simple Lambda function (alternatively could define a full function),
to map the data.
### Postprocessing predictions
Expand Down Expand Up @@ -428,19 +398,41 @@ def fn(x: PredictionResult):

options.YamlOptions.check_enabled(pcoll.pipeline, 'ML')

if not isinstance(model_handler, dict):
raise ValueError(
'Invalid model_handler specification. Expected dict but was '
f'{type(model_handler)}.')
expected_model_handler_params = {'type', 'config'}
given_model_handler_params = set(
SafeLineLoader.strip_metadata(model_handler).keys())
extra_params = given_model_handler_params - expected_model_handler_params
if extra_params:
raise ValueError(f'Unexpected parameters in model_handler: {extra_params}')
missing_params = expected_model_handler_params - given_model_handler_params
if missing_params:
raise ValueError(f'Missing parameters in model_handler: {missing_params}')
typ = model_handler['type']
model_handler_provider = ModelHandlerProvider.handler_types.get(typ, None)
if model_handler_provider and issubclass(model_handler_provider,
ModelHandlerProvider):
model_handler_provider.validate(model_handler['config'])
else:
raise NotImplementedError(f'Unknown model handler type: {typ}.')

model_handler_provider = ModelHandlerProvider.create_handler(model_handler)
schema = RowTypeConstraint.from_fields(
list(
RowTypeConstraint.from_user_type(
pcoll.element_type.user_type)._fields) +
[(inference_tag, model_handler_provider.get_output_schema())])
[(inference_tag, model_handler_provider.inference_output_type())])

return (
pcoll | RunInference(
model_handler=KeyedModelHandler(
model_handler_provider.underlying_handler()).with_preprocess_fn(
model_handler_provider.create_preprocess_fn()).
with_postprocess_fn(model_handler_provider.create_postprocess_fn()),
model_handler_provider._preprocess_fn_internal()).
with_postprocess_fn(
model_handler_provider._postprocess_fn_internal()),
inference_args=inference_args)
| beam.Map(
lambda row: beam.Row(**{
Expand Down
2 changes: 0 additions & 2 deletions sdks/python/apache_beam/yaml/yaml_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
from apache_beam.yaml.yaml_combine import normalize_combine
from apache_beam.yaml.yaml_mapping import normalize_mapping
from apache_beam.yaml.yaml_mapping import validate_generic_expressions
from apache_beam.yaml.yaml_ml import normalize_ml
from apache_beam.yaml.yaml_utils import SafeLineLoader

__all__ = ["YamlTransform"]
Expand Down Expand Up @@ -907,7 +906,6 @@ def preprocess_languages(spec):
ensure_transforms_have_types,
normalize_mapping,
normalize_combine,
normalize_ml,
preprocess_languages,
ensure_transforms_have_providers,
preprocess_source_sink,
Expand Down
17 changes: 17 additions & 0 deletions sdks/python/apache_beam/yaml/yaml_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,20 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import uuid
from typing import Iterable
from typing import Mapping
Expand Down
22 changes: 22 additions & 0 deletions sdks/python/apache_beam/yaml/yaml_utils_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,20 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import unittest

import yaml
Expand Down Expand Up @@ -54,3 +71,8 @@ def test_strip_metadata_nothing_to_strip(self):

self.assertFalse(hasattr(stripped, '__line__'))
self.assertFalse(hasattr(stripped, '__uuid__'))


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
unittest.main()

0 comments on commit 17a34c4

Please sign in to comment.