Skip to content

Commit

Permalink
Add image embedding support to TFHub MLTransforms (#31564)
Browse files Browse the repository at this point in the history
* Add image embedding support to TFHub MLTransforms

* linting

* more linting

* formatting

* typo
  • Loading branch information
jrmccluskey authored Jun 12, 2024
1 parent c001c3a commit 7732e11
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@
from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor
from apache_beam.ml.inference.tensorflow_inference import default_tensor_inference_fn
from apache_beam.ml.transforms.base import EmbeddingsManager
from apache_beam.ml.transforms.base import _ImageEmbeddingHandler
from apache_beam.ml.transforms.base import _TextEmbeddingHandler

__all__ = ['TensorflowHubTextEmbeddings']
__all__ = ['TensorflowHubTextEmbeddings', 'TensorflowHubImageEmbeddings']


# TODO: https://github.com/apache/beam/issues/30288
Expand Down Expand Up @@ -132,3 +133,42 @@ def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform:
model_handler=_TextEmbeddingHandler(self),
inference_args=self.inference_args,
))


class TensorflowHubImageEmbeddings(EmbeddingsManager):
def __init__(self, columns: List[str], hub_url: str, **kwargs):
"""
Embedding config for tensorflow hub models. This config can be used with
MLTransform to embed image data. Models are loaded using the RunInference
PTransform with the help of a ModelHandler.
Args:
columns: The columns containing the images to be embedded.
hub_url: The url of the tensorflow hub model.
min_batch_size: The minimum batch size to be used for inference.
max_batch_size: The maximum batch size to be used for inference.
large_model: Whether to share the model across processes.
"""
super().__init__(columns=columns, **kwargs)
self.model_uri = hub_url

def get_model_handler(self) -> ModelHandler:
# override the default inference function
return _TensorflowHubModelHandler(
model_uri=self.model_uri,
preprocessing_url=None,
min_batch_size=self.min_batch_size,
max_batch_size=self.max_batch_size,
large_model=self.large_model,
)

def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform:
"""
Returns a RunInference object that is used to run inference on the text
input using _ImageEmbeddingHandler.
"""
return (
RunInference(
model_handler=_ImageEmbeddingHandler(self),
inference_args=self.inference_args,
))
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@
import unittest
import uuid

import numpy as np

import apache_beam as beam
from apache_beam.ml.transforms.base import MLTransform

hub_url = 'https://tfhub.dev/google/nnlm-en-dim128/2'
hub_img_url = 'https://www.kaggle.com/models/google/resnet-v2/TensorFlow2/101-feature-vector/2' # pylint: disable=line-too-long
test_query_column = 'test_query'
test_query = 'This is a test query'

Expand All @@ -32,6 +35,7 @@
from apache_beam.ml.transforms.embeddings.tensorflow_hub import TensorflowHubTextEmbeddings
except ImportError:
TensorflowHubTextEmbeddings = None # type: ignore
tf = None

# pylint: disable=ungrouped-imports
try:
Expand All @@ -40,6 +44,14 @@
except ImportError:
tft = None

# pylint: disable=ungrouped-imports
try:
from apache_beam.ml.transforms.embeddings.tensorflow_hub import TensorflowHubImageEmbeddings
from PIL import Image
except ImportError:
TensorflowHubImageEmbeddings = None # type: ignore
Image = None


@unittest.skipIf(
TensorflowHubTextEmbeddings is None, 'Tensorflow is not installed.')
Expand Down Expand Up @@ -161,6 +173,54 @@ def test_with_int_data_types(self):
embedding_config))


@unittest.skipIf(
TensorflowHubImageEmbeddings is None, 'Tensorflow is not installed.')
class TFHubImageEmbeddingsTest(unittest.TestCase):
def setUp(self) -> None:
self.artifact_location = tempfile.mkdtemp()

def tearDown(self) -> None:
shutil.rmtree(self.artifact_location)

def generateRandomImage(self, size: int):
imarray = np.random.rand(size, size, 3) * 255
return imarray / 255.0

@unittest.skipIf(Image is None, 'Pillow is not installed.')
def test_sentence_transformer_image_embeddings(self):
embedding_config = TensorflowHubImageEmbeddings(
hub_url=hub_img_url, columns=[test_query_column])
img = self.generateRandomImage(224)
with beam.Pipeline() as pipeline:
result_pcoll = (
pipeline
| "CreateData" >> beam.Create([{
test_query_column: img
}])
| "MLTransform" >> MLTransform(
write_artifact_location=self.artifact_location).with_transform(
embedding_config))

def assert_element(element):
assert len(element[test_query_column]) == 2048

_ = (result_pcoll | beam.Map(assert_element))

def test_with_str_data_types(self):
embedding_config = TensorflowHubImageEmbeddings(
hub_url=hub_img_url, columns=[test_query_column])
with self.assertRaises(TypeError):
with beam.Pipeline() as pipeline:
_ = (
pipeline
| "CreateData" >> beam.Create([{
test_query_column: "img.jpg"
}])
| "MLTransform" >> MLTransform(
write_artifact_location=self.artifact_location).with_transform(
embedding_config))


@unittest.skipIf(
TensorflowHubTextEmbeddings is None, 'Tensorflow is not installed.')
class TFHubEmbeddingsGCSArtifactLocationTest(TFHubEmbeddingsTest):
Expand Down

0 comments on commit 7732e11

Please sign in to comment.