Skip to content

Commit

Permalink
Move RunInference GC to finalize_bundle (#32281)
Browse files Browse the repository at this point in the history
  • Loading branch information
damccorm authored Aug 22, 2024
1 parent 995724d commit 8488a0d
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions sdks/python/apache_beam/ml/inference/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1795,16 +1795,6 @@ def process(
side input is empty or the model has not been updated, the method
simply runs inference on the batch of data.
"""
# Do garbage collection of old models before starting inference.
tags_to_gc = self._model_metadata.get_tags_for_garbage_collection()
if len(tags_to_gc) > 0:
for unprefixed_tag in tags_to_gc:
for tag in _get_tags_for_copies(unprefixed_tag,
self._model_handler.model_copies()):
multi_process_shared.MultiProcessShared(lambda: None,
tag).unsafe_hard_delete()
self._model_metadata.mark_tags_deleted(tags_to_gc)

if not si_model_metadata:
if (not self._model_metadata.is_valid_tag(self._cur_tag) or
self._model is None):
Expand Down Expand Up @@ -1834,6 +1824,16 @@ def finish_bundle(self):
if self._metrics_collector:
self._metrics_collector.update_metrics_with_cache()

# Do garbage collection of old models
tags_to_gc = self._model_metadata.get_tags_for_garbage_collection()
if len(tags_to_gc) > 0:
for unprefixed_tag in tags_to_gc:
for tag in _get_tags_for_copies(unprefixed_tag,
self._model_handler.model_copies()):
multi_process_shared.MultiProcessShared(lambda: None,
tag).unsafe_hard_delete()
self._model_metadata.mark_tags_deleted(tags_to_gc)


def _is_darwin() -> bool:
return sys.platform == 'darwin'
Expand Down

0 comments on commit 8488a0d

Please sign in to comment.