diff --git a/CHANGES.md b/CHANGES.md
index 0c2c2e3f79f4..d1a0a36b8188 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -79,6 +79,7 @@
 ## Bugfixes
 
 * Fixed "Desired bundle size 0 bytes must be greater than 0" in Java SDK's BigtableIO.BigtableSource when you have more cores than bytes to read (Java) [#28793](https://github.com/apache/beam/issues/28793).
+* The `watch_model_pattern` arg of the [RunInference](https://github.com/apache/beam/blob/104c10b3ee536a9a3ea52b4dbf62d86b669da5d9/sdks/python/apache_beam/ml/inference/base.py#L997) transform had no effect prior to 2.52.0. To get the intended behavior on versions prior to 2.52.0, follow the documentation at https://beam.apache.org/documentation/ml/side-input-updates/ and pass the `WatchFilePattern` PTransform as a side input. ([#28948](https://github.com/apache/beam/pulls/28948))
 
 ## Security Fixes
 * Fixed (CVE-YYYY-NNNN)[https://www.cve.org/CVERecord?id=CVE-YYYY-NNNN] (Java/Python/Go) ([#X](https://github.com/apache/beam/issues/X)).
diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 90d43cfddb94..ef47334c22ac 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -1025,7 +1025,6 @@ def __init__(
     self._clock = clock
     self._metrics_namespace = metrics_namespace
     self._model_metadata_pcoll = model_metadata_pcoll
-    self._enable_side_input_loading = self._model_metadata_pcoll is not None
     self._with_exception_handling = False
     self._watch_model_pattern = watch_model_pattern
     self._kwargs = kwargs
@@ -1126,12 +1125,12 @@ def expand(
             self._model_handler,
             self._clock,
             self._metrics_namespace,
-            self._enable_side_input_loading,
+            self._model_metadata_pcoll is not None,
             self._model_tag),
         self._inference_args,
         beam.pvalue.AsSingleton(
             self._model_metadata_pcoll,
-        ) if self._enable_side_input_loading else None).with_resource_hints(
+        ) if self._model_metadata_pcoll else None).with_resource_hints(
             **resource_hints)
 
     if self._with_exception_handling:
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index 1b1a7393872c..7075810ff0f0 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -1513,6 +1513,28 @@ def test_model_manager_evicts_correct_num_of_models_after_being_incremented(
         mh3.load_model, tag=tag3).acquire()
     self.assertEqual(8, model3.predict(10))
 
+  def test_run_inference_watch_file_pattern_side_input_label(self):
+    pipeline = TestPipeline()
+    # Label of the windowing step inside the WatchFilePattern transform.
+    side_input_str = 'WatchFilePattern/ApplyGlobalWindow'
+    from apache_beam.ml.inference.utils import WatchFilePattern
+    file_pattern_side_input = (
+        pipeline
+        | 'WatchFilePattern' >> WatchFilePattern(file_pattern='fake/path/*'))
+    pcoll = pipeline | 'start' >> beam.Create([1, 2, 3])
+    result_pcoll = pcoll | base.RunInference(
+        FakeModelHandler(), model_metadata_pcoll=file_pattern_side_input)
+    assert side_input_str in str(result_pcoll.producer.side_inputs[0])
+
+  def test_run_inference_watch_file_pattern_keyword_arg_side_input_label(self):
+    # Label of the windowing step inside the WatchFilePattern transform.
+    side_input_str = 'WatchFilePattern/ApplyGlobalWindow'
+    pipeline = TestPipeline()
+    pcoll = pipeline | 'start' >> beam.Create([1, 2, 3])
+    result_pcoll = pcoll | base.RunInference(
+        FakeModelHandler(), watch_model_pattern='fake/path/*')
+    assert side_input_str in str(result_pcoll.producer.side_inputs[0])
+
 
 if __name__ == '__main__':
   unittest.main()
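For context (illustration only, not part of the diff), below is a minimal sketch of the explicit `WatchFilePattern` side-input wiring the changelog entry recommends for versions prior to 2.52.0. The helper name `add_run_inference`, the bucket path, and `model_handler` are placeholders, and a streaming pipeline is assumed so `WatchFilePattern` can keep emitting model updates.

```python
import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.utils import WatchFilePattern


def add_run_inference(pipeline, examples, model_handler):
  # WatchFilePattern emits ModelMetadata whenever a new file matching the
  # pattern appears, so RunInference can hot-swap models via a side input.
  model_updates = (
      pipeline
      | 'WatchFilePattern' >> WatchFilePattern(
          file_pattern='gs://my-bucket/models/*'))
  # Passing the watch stream as model_metadata_pcoll enables side-input
  # model loading in RunInference.
  return (
      examples
      | 'RunInference' >> RunInference(
          model_handler, model_metadata_pcoll=model_updates))
```

Per the changelog entry above, this explicit `model_metadata_pcoll` wiring is the workaround for releases before 2.52.0; from 2.52.0 onward the `watch_model_pattern` arg sets up the same side input.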