Skip to content

Commit

Permalink
Add support for sharing models across steps (#31665)
Browse files Browse the repository at this point in the history
* Add support for sharing models across steps

* Tests + CHANGES
  • Loading branch information
damccorm authored Jun 21, 2024
1 parent 3695d49 commit 14fd366
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 3 deletions.
2 changes: 1 addition & 1 deletion CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@

## New Features / Improvements

* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
* Multiple RunInference instances can now share the same model instance by setting the model_identifier parameter (Python) ([#31665](https://github.com/apache/beam/issues/31665)).

## Breaking Changes

Expand Down
15 changes: 13 additions & 2 deletions sdks/python/apache_beam/ml/inference/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1130,6 +1130,7 @@ def __init__(
*,
model_metadata_pcoll: beam.PCollection[ModelMetadata] = None,
watch_model_pattern: Optional[str] = None,
model_identifier: Optional[str] = None,
**kwargs):
"""
A transform that takes a PCollection of examples (or features) for use
Expand All @@ -1154,6 +1155,12 @@ def __init__(
to the _RunInferenceDoFn.
watch_model_pattern: A glob pattern used to watch a directory
for automatic model refresh.
model_identifier: A string used to identify the model being loaded. You
can set this if you want to reuse the same model across multiple
RunInference steps and don't want to reload it twice. Note that using
the same tag for different models will lead to non-deterministic
results, so exercise caution when using this parameter. This only
impacts models which are already being shared across processes.
"""
self._model_handler = model_handler
self._inference_args = inference_args
Expand All @@ -1164,8 +1171,12 @@ def __init__(
self._watch_model_pattern = watch_model_pattern
self._kwargs = kwargs
# Generate a random tag to use for shared.py and multi_process_shared.py to
# allow us to effectively disambiguate in multi-model settings.
self._model_tag = uuid.uuid4().hex
# allow us to effectively disambiguate in multi-model settings. Only use
# the same tag if the model being loaded across multiple steps is actually
# the same.
self._model_tag = model_identifier
if model_identifier is None:
self._model_tag = uuid.uuid4().hex

def annotations(self):
return {
Expand Down
39 changes: 39 additions & 0 deletions sdks/python/apache_beam/ml/inference/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1592,6 +1592,45 @@ def test_model_manager_evicts_correct_num_of_models_after_being_incremented(
mh3.load_model, tag=tag3).acquire()
self.assertEqual(8, model3.predict(10))

def test_run_inference_loads_different_models(self):
  """Verifies that two RunInference steps without a model_identifier each
  get their own model instance: the incrementing fake model yields the
  same outputs as a fresh load in both steps."""
  handler = FakeModelHandler(incrementing=True, min_batch_size=3)
  with TestPipeline() as pipeline:
    examples = pipeline | 'start' >> beam.Create([1, 2, 3])
    predictions = (
        examples
        | 'ri1' >> base.RunInference(handler)
        | 'ri2' >> base.RunInference(handler))
    assert_that(predictions, equal_to([1, 2, 3]), label='assert:inferences')

def test_run_inference_loads_different_models_multi_process_shared(self):
  """Same as the non-shared variant, but with multi_process_shared=True:
  absent a model_identifier, each RunInference step still loads its own
  model, so outputs match a per-step fresh load."""
  handler = FakeModelHandler(
      incrementing=True, min_batch_size=3, multi_process_shared=True)
  with TestPipeline() as pipeline:
    examples = pipeline | 'start' >> beam.Create([1, 2, 3])
    predictions = (
        examples
        | 'ri1' >> base.RunInference(handler)
        | 'ri2' >> base.RunInference(handler))
    assert_that(predictions, equal_to([1, 2, 3]), label='assert:inferences')

def test_runinference_loads_same_model_with_identifier_multi_process_shared(
    self):
  """Verifies that two RunInference steps given the SAME model_identifier
  share one model instance: the incrementing fake model's state carries
  over into the second step, shifting the expected outputs to [4, 5, 6]
  rather than repeating the fresh-load results."""
  handler = FakeModelHandler(
      incrementing=True, min_batch_size=3, multi_process_shared=True)
  shared_tag = 'same_model_with_identifier_multi_process_shared'
  with TestPipeline() as pipeline:
    examples = pipeline | 'start' >> beam.Create([1, 2, 3])
    predictions = (
        examples
        | 'ri1' >> base.RunInference(handler, model_identifier=shared_tag)
        | 'ri2' >> base.RunInference(handler, model_identifier=shared_tag))
    assert_that(predictions, equal_to([4, 5, 6]), label='assert:inferences')

def test_run_inference_watch_file_pattern_side_input_label(self):
pipeline = TestPipeline()
# label of the WatchPattern transform.
Expand Down

0 comments on commit 14fd366

Please sign in to comment.