Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Hugging Face Image Embedding MLTransform #31536

Merged
merged 3 commits into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions sdks/python/apache_beam/ml/transforms/embeddings/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.transforms.base import EmbeddingsManager
from apache_beam.ml.transforms.base import _ImageEmbeddingHandler
from apache_beam.ml.transforms.base import _TextEmbeddingHandler

try:
Expand Down Expand Up @@ -153,6 +154,45 @@ def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform:
))


class SentenceTransformerImageEmbeddings(EmbeddingsManager):
  def __init__(self, model_name: str, columns: List[str], **kwargs):
    """
    Embedding config for sentence-transformers. This config can be used with
    MLTransform to embed image data. Models are loaded using the RunInference
    PTransform with the help of ModelHandler.

    Args:
      model_name: Name of the model to use. The model should be hosted on
        HuggingFace Hub or compatible with sentence_transformers. See
        https://www.sbert.net/docs/sentence_transformer/pretrained_models.html#image-text-models # pylint: disable=line-too-long
        for a list of sentence_transformers models.
      columns: List of columns to be embedded.
      min_batch_size: The minimum batch size to be used for inference.
      max_batch_size: The maximum batch size to be used for inference.
      large_model: Whether to share the model across processes.
    """
    # Batching/sharing options (min_batch_size, max_batch_size, large_model)
    # are forwarded to EmbeddingsManager via **kwargs.
    super().__init__(columns, **kwargs)
    self.model_name = model_name

  def get_model_handler(self):
    # Load the sentence-transformers model through the shared handler,
    # forwarding the options captured by EmbeddingsManager.__init__.
    return _SentenceTransformerModelHandler(
        model_class=SentenceTransformer,
        model_name=self.model_name,
        load_model_args=self.load_model_args,
        min_batch_size=self.min_batch_size,
        max_batch_size=self.max_batch_size,
        large_model=self.large_model)

  def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform:
    # Wrap the model handler in an _ImageEmbeddingHandler since this
    # config embeds image input data (unlike SentenceTransformerEmbeddings,
    # which wraps in _TextEmbeddingHandler for text input).
    return (
        RunInference(
            model_handler=_ImageEmbeddingHandler(self),
            inference_args=self.inference_args,
        ))


class _InferenceAPIHandler(ModelHandler):
def __init__(self, config: 'InferenceAPIEmbeddings'):
super().__init__()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@
# pylint: disable=ungrouped-imports
try:
from apache_beam.ml.transforms.embeddings.huggingface import SentenceTransformerEmbeddings
from apache_beam.ml.transforms.embeddings.huggingface import SentenceTransformerImageEmbeddings
from apache_beam.ml.transforms.embeddings.huggingface import InferenceAPIEmbeddings
from PIL import Image
import torch
except ImportError:
SentenceTransformerEmbeddings = None # type: ignore
Expand All @@ -46,6 +48,12 @@
except ImportError:
tft = None

# pylint: disable=ungrouped-imports
try:
from PIL import Image
except ImportError:
Image = None

_HF_TOKEN = os.environ.get('HF_INFERENCE_TOKEN')
test_query = "This is a test"
test_query_column = "feature_1"
Expand Down Expand Up @@ -278,6 +286,61 @@ def test_mltransform_to_ptransform_with_sentence_transformer(self):
ptransform_list[i]._model_handler._underlying.model_name, model_name)


@pytest.mark.no_xdist
@unittest.skipIf(
    SentenceTransformerEmbeddings is None,
    'sentence-transformers is not installed.')
@unittest.skipIf(Image is None, 'Pillow is not installed.')
class SentenceTransformerImageEmbeddingsTest(unittest.TestCase):
  """Tests for SentenceTransformerImageEmbeddings with a CLIP image model."""
  def setUp(self) -> None:
    self.artifact_location = tempfile.mkdtemp(prefix='sentence_transformers_')
    # this bucket has TTL and will be deleted periodically
    self.gcs_artifact_location = os.path.join(
        'gs://temp-storage-for-perf-tests/sentence_transformers',
        uuid.uuid4().hex)
    self.model_name = "clip-ViT-B-32"

  def tearDown(self) -> None:
    shutil.rmtree(self.artifact_location)

  def generateRandomImage(self, size: int):
    # Random RGB noise converted to RGBA, to exercise multi-channel input.
    pixel_data = (np.random.rand(size, size, 3) * 255).astype('uint8')
    return Image.fromarray(pixel_data).convert('RGBA')

  def test_sentence_transformer_image_embeddings(self):
    config = SentenceTransformerImageEmbeddings(
        model_name=self.model_name, columns=[test_query_column])
    image = self.generateRandomImage(256)

    def assert_element(element):
      # clip-ViT-B-32 produces a 512-dimensional embedding vector.
      assert len(element[test_query_column]) == 512

    with beam.Pipeline() as pipeline:
      embedded = (
          pipeline
          | "CreateData" >> beam.Create([{test_query_column: image}])
          | "MLTransform" >> MLTransform(
              write_artifact_location=self.artifact_location).with_transform(
                  config))
      _ = embedded | beam.Map(assert_element)

  def test_sentence_transformer_images_with_str_data_types(self):
    config = SentenceTransformerImageEmbeddings(
        model_name=self.model_name, columns=[test_query_column])
    # A plain string is not a valid image input; the handler should reject it.
    with self.assertRaises(TypeError):
      with beam.Pipeline() as pipeline:
        _ = (
            pipeline
            | "CreateData" >> beam.Create([{
                test_query_column: "image.jpg"
            }])
            | "MLTransform" >> MLTransform(
                write_artifact_location=self.artifact_location).with_transform(
                    config))


@unittest.skipIf(_HF_TOKEN is None, 'HF_TOKEN environment variable not set.')
class HuggingfaceInferenceAPITest(unittest.TestCase):
def setUp(self):
Expand Down
Loading