Add docs for per key inference (#28243)
* Update KeyMhMapping to KeyModelMapping

* Add docs for per key inference

* Add piece on memory thrashing

* Whitespace

* Update wording based on feedback

* Add references to website in pydoc

* Apply suggestions from code review

Co-authored-by: Rebecca Szper <[email protected]>

* Remove ordering implied by wording

* Lint fixes

---------

Co-authored-by: Rebecca Szper <[email protected]>
damccorm and rszper authored Oct 12, 2023
1 parent 4a7c484 commit 835bd65
Showing 2 changed files with 56 additions and 1 deletion.
9 changes: 8 additions & 1 deletion sdks/python/apache_beam/ml/inference/base.py
@@ -482,6 +482,12 @@ def __init__(
from the cohort. When model updates occur, the metrics will be reported in
the form `<cohort_key>-<model id>-<metric_name>`.
Loading multiple models at the same time can increase the risk of an out of
memory (OOM) exception. To avoid this issue, use the parameter
`max_models_per_worker_hint` to limit the number of models that are loaded
at the same time. For more information about memory management, see
`Use a keyed ModelHandler <https://beam.apache.org/documentation/sdks/python-machine-learning/#use-a-keyed-modelhandler>`_. # pylint: disable=line-too-long
Args:
unkeyed: Either (a) an implementation of ModelHandler that does not
@@ -491,7 +497,8 @@ def __init__(
models can be held in memory at one time per worker process. For
example, if your worker has 8 GB of memory provisioned and your workers
take up 1 GB each, you should set this to 7 to allow all models to sit
in memory with some buffer.
in memory with some buffer. For more information about memory management,
see `Use a keyed `ModelHandler <https://beam.apache.org/documentation/sdks/python-machine-learning/#use-a-keyed-modelhandler>_`. # pylint: disable=line-too-long
"""
self._metrics_collectors: Dict[str, _MetricsCollector] = {}
self._default_metrics_collector: _MetricsCollector = None
@@ -215,6 +215,54 @@ with pipeline as p:

If you are unsure if your data is keyed, you can also use `MaybeKeyedModelHandler`.
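
For instance, a minimal sketch in the style of the examples on this page (the `<config>` value is a placeholder and the variable names are illustrative) wraps a single unkeyed handler so that the same `RunInference` call works whether or not the input elements carry keys:

```
from apache_beam.ml.inference.base import MaybeKeyedModelHandler
# Wrap an unkeyed handler. The wrapped handler accepts either plain examples
# or (key, example) tuples, so you don't need to decide up front.
maybe_keyed_model_handler = MaybeKeyedModelHandler(PytorchModelHandlerTensor(<config>))
with pipeline as p:
  data = p | beam.Create([
    ('key1', torch.tensor([[1,2,3],[4,5,6],...])),
    ('key2', torch.tensor([[1,2,3],[4,5,6],...])),
  ])
  predictions = data | RunInference(maybe_keyed_model_handler)
```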

You can also use a `KeyedModelHandler` to load several different models based on their associated key:

```
from apache_beam.ml.inference.base import KeyedModelHandler, KeyModelMapping
keyed_model_handler = KeyedModelHandler([
  KeyModelMapping(['key1'], PytorchModelHandlerTensor(<config1>)),
  KeyModelMapping(['key2', 'key3'], PytorchModelHandlerTensor(<config2>))
])
with pipeline as p:
  data = p | beam.Create([
    ('key1', torch.tensor([[1,2,3],[4,5,6],...])),
    ('key2', torch.tensor([[1,2,3],[4,5,6],...])),
    ('key3', torch.tensor([[1,2,3],[4,5,6],...])),
  ])
  predictions = data | RunInference(keyed_model_handler)
```

The previous example loads a model by using `config1`. That model is then used for inference for all examples associated
with `key1`. It also loads a model by using `config2`. That model is used for all examples associated with `key2` and `key3`.

Loading multiple models at the same time increases the risk of out of memory (OOM) errors. By default, `KeyedModelHandler` doesn't
limit the number of models loaded into memory at the same time. If the models don't all fit into memory,
your pipeline will likely fail with an out of memory error. To avoid this issue, provide a hint about the
maximum number of models that can be loaded at the same time.

```
mhs = [
  KeyModelMapping(['key1'], PytorchModelHandlerTensor(<config1>)),
  KeyModelMapping(['key2', 'key3'], PytorchModelHandlerTensor(<config2>)),
  KeyModelMapping(['key4'], PytorchModelHandlerTensor(<config3>)),
  KeyModelMapping(['key5', 'key6', 'key7'], PytorchModelHandlerTensor(<config4>)),
]
keyed_model_handler = KeyedModelHandler(mhs, max_models_per_worker_hint=2)
```

The previous example loads at most two models per SDK worker process at any given time. It unloads models that aren't
currently being used. Runners that have multiple SDK worker processes on a given machine load at most
`max_models_per_worker_hint*<num worker processes>` models onto the machine. Leave enough space for the models
and any additional memory needs from other transforms. Because there might be a delay between when a model is offloaded and when the
memory is released, it is recommended that you leave additional buffer.
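
As a rough illustration with hypothetical numbers: if `max_models_per_worker_hint=2`, a machine runs two SDK worker processes, and each model needs about 2 GB, plan for roughly 2 * 2 * 2 GB = 8 GB of model memory on that machine, plus headroom for other transforms and for models that are still being unloaded.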

**Note**: Having many models but a small `max_models_per_worker_hint` can lead to _memory thrashing_, where
a large amount of execution time is wasted swapping models in and out of memory. To reduce the likelihood and impact
of memory thrashing, if you're using a distributed runner, insert a
[GroupByKey](https://beam.apache.org/documentation/transforms/python/aggregation/groupbykey/) transform before your
inference step. This step reduces thrashing by ensuring that elements with the same key and model are
collocated on the same worker.
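
One way this can look in practice is sketched below, reusing the `keyed_model_handler` from the earlier examples; the intermediate variable names and the `FlatMap` step that restores the `(key, example)` pairs expected by `RunInference` are illustrative, not prescribed by Beam:

```
with pipeline as p:
  data = p | beam.Create([
    ('key1', torch.tensor([[1,2,3],[4,5,6],...])),
    ('key2', torch.tensor([[1,2,3],[4,5,6],...])),
  ])
  # Colocate elements that share a key (and therefore a model) on the same worker.
  grouped = data | beam.GroupByKey()
  # Restore the (key, example) pairs that RunInference expects.
  regrouped = grouped | beam.FlatMap(lambda kv: [(kv[0], x) for x in kv[1]])
  predictions = regrouped | RunInference(keyed_model_handler)
```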

For more information, see [`KeyedModelHandler`](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.KeyedModelHandler).

### Use the `PredictionResult` object
