diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index b075f22f15be..5cc38cd842c9 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -1795,16 +1795,6 @@ def process( side input is empty or the model has not been updated, the method simply runs inference on the batch of data. """ - # Do garbage collection of old models before starting inference. - tags_to_gc = self._model_metadata.get_tags_for_garbage_collection() - if len(tags_to_gc) > 0: - for unprefixed_tag in tags_to_gc: - for tag in _get_tags_for_copies(unprefixed_tag, - self._model_handler.model_copies()): - multi_process_shared.MultiProcessShared(lambda: None, - tag).unsafe_hard_delete() - self._model_metadata.mark_tags_deleted(tags_to_gc) - if not si_model_metadata: if (not self._model_metadata.is_valid_tag(self._cur_tag) or self._model is None): @@ -1834,6 +1824,16 @@ def finish_bundle(self): if self._metrics_collector: self._metrics_collector.update_metrics_with_cache() + # Do garbage collection of old models + tags_to_gc = self._model_metadata.get_tags_for_garbage_collection() + if len(tags_to_gc) > 0: + for unprefixed_tag in tags_to_gc: + for tag in _get_tags_for_copies(unprefixed_tag, + self._model_handler.model_copies()): + multi_process_shared.MultiProcessShared(lambda: None, + tag).unsafe_hard_delete() + self._model_metadata.mark_tags_deleted(tags_to_gc) + def _is_darwin() -> bool: return sys.platform == 'darwin'