Add counter for MLTransform and data processing transforms
AnandInguva committed Oct 10, 2023
1 parent 2bccee1 commit c064082
Showing 2 changed files with 63 additions and 1 deletion.
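The change adds Beam metrics counters that are reported when MLTransform runs in a pipeline: one named 'BeamML_MLTransform' for MLTransform itself and one named 'BeamML_<operation class>' for each data processing transform that is applied. Below is a minimal sketch of how these counters might be read back from a pipeline result; the input data, transform choice, and artifact location are illustrative (they mirror the new test) and the tft module requires the tensorflow_transform dependency.

import tempfile

import apache_beam as beam
from apache_beam.metrics.metric import MetricsFilter
from apache_beam.ml.transforms.base import MLTransform
from apache_beam.ml.transforms.tft import ScaleTo01

artifact_location = tempfile.mkdtemp()
with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([{'x': [1, 2, 3]}])
      | MLTransform(write_artifact_location=artifact_location).with_transform(
          ScaleTo01(columns=['x'])))

# After the `with` block the pipeline has run; p.result holds the
# PipelineResult used to query metrics.
result = p.result
for metrics_filter in (MetricsFilter().with_name('BeamML_MLTransform'),
                       MetricsFilter().with_name('BeamML_ScaleTo01')):
  counters = result.metrics().query(metrics_filter)['counters']
  print(counters[0].key.metric.name, counters[0].result)
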
39 changes: 38 additions & 1 deletion sdks/python/apache_beam/ml/transforms/base.py
@@ -25,6 +25,7 @@
from typing import TypeVar

import apache_beam as beam
from apache_beam.metrics.metric import Metrics

__all__ = ['MLTransform', 'ProcessHandler', 'BaseOperation']

@@ -88,6 +89,13 @@ def __call__(self, data: OperationInputT,
transformed_data = {**transformed_data, **artifacts}
return transformed_data

  def get_counter(self):
    """
    Returns a Beam metrics counter for this operation, named after the
    operation class.
    """
    counter_name = self.__class__.__name__
    return Metrics.counter(MLTransform, f'BeamML_{counter_name}')


class ProcessHandler(Generic[ExampleT, MLTransformOutputT], abc.ABC):
"""
@@ -194,6 +202,9 @@ def __init__(
transforms=transforms) # type: ignore[arg-type]

    self._process_handler = process_handler
    self.transforms = transforms
    self._counter = Metrics.counter(
        MLTransform, f'BeamML_{self.__class__.__name__}')

def expand(
self, pcoll: beam.PCollection[ExampleT]
@@ -209,8 +220,11 @@ def expand(
Args:
pcoll: A PCollection of ExampleT type.
Returns:
A PCollection of MLTransformOutputT type.
A PCollection of MLTransformOutputT type
"""
    _ = (
        pcoll.pipeline
        | "MLTransformMetricsUsage" >> MLTransformMetricsUsage(self))
    return self._process_handler.process_data(pcoll)

def with_transform(self, transform: BaseOperation):
@@ -230,3 +244,26 @@ def _validate_transform(self, transform):
raise TypeError(
'transform must be a subclass of BaseOperation. '
'Got: %s instead.' % type(transform))


class MLTransformMetricsUsage(beam.PTransform):
  def __init__(self, ml_transform: MLTransform):
    self._ml_transform = ml_transform
    self._ml_transform._counter.inc()

  def expand(self, pipeline):
    def _increment_counters():
      # increment for MLTransform.
      self._ml_transform._counter.inc()
      # increment if data processing transforms are passed.
      transforms = (
          self._ml_transform.transforms or
          self._ml_transform._process_handler.transforms)
      if transforms:
        for transform in transforms:
          transform.get_counter().inc()

    _ = (
        pipeline
        | beam.Create([None])
        | beam.Map(lambda _: _increment_counters()))
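
For reference, get_counter() keys each operation's counter on the operation's class name. A small sketch under the same assumptions as above (the transform and column name are illustrative):

from apache_beam.ml.transforms.tft import ComputeAndApplyVocabulary

op = ComputeAndApplyVocabulary(columns=['y'])
counter = op.get_counter()
# The counter lives in the MLTransform namespace with a name derived from
# the class, i.e. 'BeamML_ComputeAndApplyVocabulary'.
print(counter.metric_name)
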
25 changes: 25 additions & 0 deletions sdks/python/apache_beam/ml/transforms/base_test.py
@@ -27,6 +27,7 @@
from parameterized import parameterized

import apache_beam as beam
from apache_beam.metrics.metric import MetricsFilter
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to

@@ -244,6 +245,30 @@ def test_ml_transforms_on_multiple_columns_multiple_transforms(self):
equal_to(expected_output_y, equals_fn=np.array_equal),
label='actual_output_y')

  def test_mltransform_with_counter(self):
    transforms = [
        tft.ComputeAndApplyVocabulary(columns=['y']),
        tft.ScaleTo01(columns=['x'])
    ]
    data = [{'x': [1, 2, 3], 'y': ['a', 'b', 'c']}]
    with beam.Pipeline() as p:
      _ = (
          p | beam.Create(data)
          | base.MLTransform(
              transforms=transforms,
              write_artifact_location=self.artifact_location))
    scale_to_01_counter = MetricsFilter().with_name('BeamML_ScaleTo01')
    vocab_counter = MetricsFilter().with_name(
        'BeamML_ComputeAndApplyVocabulary')
    mltransform_counter = MetricsFilter().with_name('BeamML_MLTransform')
    result = p.result
    self.assertEqual(
        result.metrics().query(scale_to_01_counter)['counters'][0].result, 1)
    self.assertEqual(
        result.metrics().query(vocab_counter)['counters'][0].result, 1)
    self.assertEqual(
        result.metrics().query(mltransform_counter)['counters'][0].result, 1)


if __name__ == '__main__':
unittest.main()
