From 7732e11acb7e215649748f0cff32212a4c0d777d Mon Sep 17 00:00:00 2001 From: Jack McCluskey <34928439+jrmccluskey@users.noreply.github.com> Date: Wed, 12 Jun 2024 12:03:45 -0400 Subject: [PATCH] Add image embedding support to TFHub MLTransforms (#31564) * Add image embedding support to TFHub MLTransforms * linting * more linting * formatting * typo --- .../transforms/embeddings/tensorflow_hub.py | 42 ++++++++++++- .../embeddings/tensorflow_hub_test.py | 60 +++++++++++++++++++ 2 files changed, 101 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py b/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py index 9e4480788257..f78ddf3ff04a 100644 --- a/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py +++ b/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py @@ -29,9 +29,10 @@ from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor from apache_beam.ml.inference.tensorflow_inference import default_tensor_inference_fn from apache_beam.ml.transforms.base import EmbeddingsManager +from apache_beam.ml.transforms.base import _ImageEmbeddingHandler from apache_beam.ml.transforms.base import _TextEmbeddingHandler -__all__ = ['TensorflowHubTextEmbeddings'] +__all__ = ['TensorflowHubTextEmbeddings', 'TensorflowHubImageEmbeddings'] # TODO: https://github.com/apache/beam/issues/30288 @@ -132,3 +133,42 @@ def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform: model_handler=_TextEmbeddingHandler(self), inference_args=self.inference_args, )) + + +class TensorflowHubImageEmbeddings(EmbeddingsManager): + def __init__(self, columns: List[str], hub_url: str, **kwargs): + """ + Embedding config for tensorflow hub models. This config can be used with + MLTransform to embed image data. Models are loaded using the RunInference + PTransform with the help of a ModelHandler. + + Args: + columns: The columns containing the images to be embedded. + hub_url: The url of the tensorflow hub model. + min_batch_size: The minimum batch size to be used for inference. + max_batch_size: The maximum batch size to be used for inference. + large_model: Whether to share the model across processes. + """ + super().__init__(columns=columns, **kwargs) + self.model_uri = hub_url + + def get_model_handler(self) -> ModelHandler: + # override the default inference function + return _TensorflowHubModelHandler( + model_uri=self.model_uri, + preprocessing_url=None, + min_batch_size=self.min_batch_size, + max_batch_size=self.max_batch_size, + large_model=self.large_model, + ) + + def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform: + """ + Returns a RunInference object that is used to run inference on the text + input using _ImageEmbeddingHandler. + """ + return ( + RunInference( + model_handler=_ImageEmbeddingHandler(self), + inference_args=self.inference_args, + )) diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub_test.py b/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub_test.py index b08ca8e2d8ea..24bca5155fa7 100644 --- a/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub_test.py +++ b/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub_test.py @@ -20,10 +20,13 @@ import unittest import uuid +import numpy as np + import apache_beam as beam from apache_beam.ml.transforms.base import MLTransform hub_url = 'https://tfhub.dev/google/nnlm-en-dim128/2' +hub_img_url = 'https://www.kaggle.com/models/google/resnet-v2/TensorFlow2/101-feature-vector/2' # pylint: disable=line-too-long test_query_column = 'test_query' test_query = 'This is a test query' @@ -32,6 +35,7 @@ from apache_beam.ml.transforms.embeddings.tensorflow_hub import TensorflowHubTextEmbeddings except ImportError: TensorflowHubTextEmbeddings = None # type: ignore + tf = None # pylint: disable=ungrouped-imports try: @@ -40,6 +44,14 @@ except ImportError: tft = None +# pylint: disable=ungrouped-imports +try: + from apache_beam.ml.transforms.embeddings.tensorflow_hub import TensorflowHubImageEmbeddings + from PIL import Image +except ImportError: + TensorflowHubImageEmbeddings = None # type: ignore + Image = None + @unittest.skipIf( TensorflowHubTextEmbeddings is None, 'Tensorflow is not installed.') @@ -161,6 +173,54 @@ def test_with_int_data_types(self): embedding_config)) +@unittest.skipIf( + TensorflowHubImageEmbeddings is None, 'Tensorflow is not installed.') +class TFHubImageEmbeddingsTest(unittest.TestCase): + def setUp(self) -> None: + self.artifact_location = tempfile.mkdtemp() + + def tearDown(self) -> None: + shutil.rmtree(self.artifact_location) + + def generateRandomImage(self, size: int): + imarray = np.random.rand(size, size, 3) * 255 + return imarray / 255.0 + + @unittest.skipIf(Image is None, 'Pillow is not installed.') + def test_sentence_transformer_image_embeddings(self): + embedding_config = TensorflowHubImageEmbeddings( + hub_url=hub_img_url, columns=[test_query_column]) + img = self.generateRandomImage(224) + with beam.Pipeline() as pipeline: + result_pcoll = ( + pipeline + | "CreateData" >> beam.Create([{ + test_query_column: img + }]) + | "MLTransform" >> MLTransform( + write_artifact_location=self.artifact_location).with_transform( + embedding_config)) + + def assert_element(element): + assert len(element[test_query_column]) == 2048 + + _ = (result_pcoll | beam.Map(assert_element)) + + def test_with_str_data_types(self): + embedding_config = TensorflowHubImageEmbeddings( + hub_url=hub_img_url, columns=[test_query_column]) + with self.assertRaises(TypeError): + with beam.Pipeline() as pipeline: + _ = ( + pipeline + | "CreateData" >> beam.Create([{ + test_query_column: "img.jpg" + }]) + | "MLTransform" >> MLTransform( + write_artifact_location=self.artifact_location).with_transform( + embedding_config)) + + @unittest.skipIf( TensorflowHubTextEmbeddings is None, 'Tensorflow is not installed.') class TFHubEmbeddingsGCSArtifactLocationTest(TFHubEmbeddingsTest):