diff --git a/sdks/python/apache_beam/ml/transforms/base.py b/sdks/python/apache_beam/ml/transforms/base.py
index a45928f5c8bf..49ce6e9ec1e0 100644
--- a/sdks/python/apache_beam/ml/transforms/base.py
+++ b/sdks/python/apache_beam/ml/transforms/base.py
@@ -25,6 +25,7 @@
 from typing import TypeVar
 
 import apache_beam as beam
+from apache_beam.metrics.metric import Metrics
 
 __all__ = ['MLTransform', 'ProcessHandler', 'BaseOperation']
 
@@ -88,6 +89,13 @@ def __call__(self, data: OperationInputT,
       transformed_data = {**transformed_data, **artifacts}
     return transformed_data
 
+  def get_counter(self):
+    """
+    Returns a Metrics counter for the operation, named after its class.
+    """
+    counter_name = self.__class__.__name__
+    return Metrics.counter(MLTransform, f'BeamML_{counter_name}')
+
 
 class ProcessHandler(Generic[ExampleT, MLTransformOutputT], abc.ABC):
   """
@@ -194,6 +202,9 @@ def __init__(
         transforms=transforms)  # type: ignore[arg-type]
 
     self._process_handler = process_handler
+    self.transforms = transforms
+    self._counter = Metrics.counter(
+        MLTransform, f'BeamML_{self.__class__.__name__}')
 
   def expand(
       self, pcoll: beam.PCollection[ExampleT]
@@ -209,8 +220,11 @@ def expand(
     Args:
       pcoll: A PCollection of ExampleT type.
     Returns:
-      A PCollection of MLTransformOutputT type.
+      A PCollection of MLTransformOutputT type
     """
+    _ = (
+        pcoll.pipeline
+        | "MLTransformMetricsUsage" >> MLTransformMetricsUsage(self))
     return self._process_handler.process_data(pcoll)
 
   def with_transform(self, transform: BaseOperation):
@@ -230,3 +244,26 @@ def _validate_transform(self, transform):
       raise TypeError(
           'transform must be a subclass of BaseOperation. '
           'Got: %s instead.' % type(transform))
+
+
+class MLTransformMetricsUsage(beam.PTransform):
+  def __init__(self, ml_transform: MLTransform):
+    self._ml_transform = ml_transform
+    self._ml_transform._counter.inc()
+
+  def expand(self, pipeline):
+    def _increment_counters():
+      # increment for MLTransform.
+      self._ml_transform._counter.inc()
+      # increment if data processing transforms are passed.
+      transforms = (
+          self._ml_transform.transforms or
+          self._ml_transform._process_handler.transforms)
+      if transforms:
+        for transform in transforms:
+          transform.get_counter().inc()
+
+    _ = (
+        pipeline
+        | beam.Create([None])
+        | beam.Map(lambda _: _increment_counters()))
diff --git a/sdks/python/apache_beam/ml/transforms/base_test.py b/sdks/python/apache_beam/ml/transforms/base_test.py
index df7a6d26b47c..2e447964541b 100644
--- a/sdks/python/apache_beam/ml/transforms/base_test.py
+++ b/sdks/python/apache_beam/ml/transforms/base_test.py
@@ -27,6 +27,7 @@
 from parameterized import parameterized
 
 import apache_beam as beam
+from apache_beam.metrics.metric import MetricsFilter
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
 
@@ -244,6 +245,30 @@ def test_ml_transforms_on_multiple_columns_multiple_transforms(self):
         equal_to(expected_output_y, equals_fn=np.array_equal),
         label='actual_output_y')
 
+  def test_mltransform_with_counter(self):
+    transforms = [
+        tft.ComputeAndApplyVocabulary(columns=['y']),
+        tft.ScaleTo01(columns=['x'])
+    ]
+    data = [{'x': [1, 2, 3], 'y': ['a', 'b', 'c']}]
+    with beam.Pipeline() as p:
+      _ = (
+          p | beam.Create(data)
+          | base.MLTransform(
+              transforms=transforms,
+              write_artifact_location=self.artifact_location))
+    scale_to_01_counter = MetricsFilter().with_name('BeamML_ScaleTo01')
+    vocab_counter = MetricsFilter().with_name(
+        'BeamML_ComputeAndApplyVocabulary')
+    mltransform_counter = MetricsFilter().with_name('BeamML_MLTransform')
+    result = p.result
+    self.assertEqual(
+        result.metrics().query(scale_to_01_counter)['counters'][0].result, 1)
+    self.assertEqual(
+        result.metrics().query(vocab_counter)['counters'][0].result, 1)
+    self.assertEqual(
+        result.metrics().query(mltransform_counter)['counters'][0].result, 1)
+
 
 if __name__ == '__main__':
   unittest.main()
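For context, the mechanism the new `MLTransformMetricsUsage` transform relies on can be exercised on its own: increment a `Metrics` counter once per run via a single-element side branch (`beam.Create([None]) | beam.Map(...)`), then look the counter up from the pipeline result with a `MetricsFilter`, exactly as the new test does. The sketch below is not part of the PR; the class `MyOperation`, the counter name `BeamML_MyOperation`, and the use of the direct runner are assumptions for illustration only.

```python
import apache_beam as beam
from apache_beam.metrics.metric import Metrics
from apache_beam.metrics.metric import MetricsFilter


class MyOperation:
  """Hypothetical stand-in for an MLTransform data processing operation."""


def main():
  # Counter namespaced the same way as in the PR:
  # Metrics.counter(<class>, 'BeamML_<class name>').
  usage_counter = Metrics.counter(MyOperation, 'BeamML_MyOperation')

  with beam.Pipeline() as p:
    # Single-element branch so the counter is incremented once per run.
    _ = (
        p
        | beam.Create([None])
        | beam.Map(lambda _: usage_counter.inc()))

  # After the `with` block the pipeline has run; query the counter by name.
  result = p.result
  counters = result.metrics().query(
      MetricsFilter().with_name('BeamML_MyOperation'))['counters']
  print(counters[0].result)  # expected to print 1 on the direct runner


if __name__ == '__main__':
  main()
```

The query shape used here, `result.metrics().query(filter)['counters'][0].result`, is the same one `test_mltransform_with_counter` asserts on for the `BeamML_MLTransform`, `BeamML_ScaleTo01`, and `BeamML_ComputeAndApplyVocabulary` counters.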