Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Hugging Face Image Embedding MLTransform #31536

Merged
merged 3 commits into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions sdks/python/apache_beam/ml/transforms/embeddings/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.transforms.base import EmbeddingsManager
from apache_beam.ml.transforms.base import _ImageEmbeddingHandler
from apache_beam.ml.transforms.base import _TextEmbeddingHandler

try:
Expand Down Expand Up @@ -114,6 +115,7 @@ def __init__(
model_name: str,
columns: List[str],
max_seq_length: Optional[int] = None,
image_model: bool = False,
**kwargs):
"""
Embedding config for sentence-transformers. This config can be used with
Expand All @@ -122,16 +124,21 @@ def __init__(

Args:
model_name: Name of the model to use. The model should be hosted on
HuggingFace Hub or compatible with sentence_transformers.
HuggingFace Hub or compatible with sentence_transformers. For image
embedding models, see
https://www.sbert.net/docs/sentence_transformer/pretrained_models.html#image-text-models # pylint: disable=line-too-long
for a list of available sentence_transformers models.
columns: List of columns to be embedded.
max_seq_length: Max sequence length to use for the model if applicable.
image_model: Whether the model is generating image embeddings.
min_batch_size: The minimum batch size to be used for inference.
max_batch_size: The maximum batch size to be used for inference.
large_model: Whether to share the model across processes.
"""
super().__init__(columns, **kwargs)
self.model_name = model_name
self.max_seq_length = max_seq_length
self.image_model = image_model

def get_model_handler(self):
return _SentenceTransformerModelHandler(
Expand All @@ -144,8 +151,14 @@ def get_model_handler(self):
large_model=self.large_model)

def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform:
# wrap the model handler in a _TextEmbeddingHandler since
# the SentenceTransformerEmbeddings works on text input data.
# wrap the model handler in an appropriate embedding handler to provide
# some type checking.
if self.image_model:
return (
RunInference(
model_handler=_ImageEmbeddingHandler(self),
inference_args=self.inference_args,
))
return (
RunInference(
model_handler=_TextEmbeddingHandler(self),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
try:
from apache_beam.ml.transforms.embeddings.huggingface import SentenceTransformerEmbeddings
from apache_beam.ml.transforms.embeddings.huggingface import InferenceAPIEmbeddings
from PIL import Image
import torch
except ImportError:
SentenceTransformerEmbeddings = None # type: ignore
Expand All @@ -46,10 +47,17 @@
except ImportError:
tft = None

# pylint: disable=ungrouped-imports
# Guard PIL separately so image-embedding tests can be skipped on their own
# when Pillow is not installed.
# NOTE(review): PIL is also imported inside the earlier try block that guards
# sentence-transformers, so an ImportError there would null out
# SentenceTransformerEmbeddings and skip these tests anyway — confirm the
# duplicate guard is intentional.
try:
  from PIL import Image
except ImportError:
  Image = None

# Hugging Face Inference API token; tests needing it are skipped when unset.
_HF_TOKEN = os.environ.get('HF_INFERENCE_TOKEN')
test_query = "This is a test"
test_query_column = "feature_1"
DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
IMAGE_MODEL_NAME = "clip-ViT-B-32"
_parameterized_inputs = [
([{
test_query_column: 'That is a happy person'
Expand Down Expand Up @@ -85,7 +93,7 @@
@unittest.skipIf(
SentenceTransformerEmbeddings is None,
'sentence-transformers is not installed.')
class SentenceTrasformerEmbeddingsTest(unittest.TestCase):
class SentenceTransformerEmbeddingsTest(unittest.TestCase):
def setUp(self) -> None:
self.artifact_location = tempfile.mkdtemp(prefix='sentence_transformers_')
# this bucket has TTL and will be deleted periodically
Expand Down Expand Up @@ -277,6 +285,48 @@ def test_mltransform_to_ptransform_with_sentence_transformer(self):
self.assertEqual(
ptransform_list[i]._model_handler._underlying.model_name, model_name)

def generateRandomImage(self, size: int):
  """Return a size x size RGBA PIL image filled with random pixel data."""
  pixels = (np.random.rand(size, size, 3) * 255).astype('uint8')
  rgb_image = Image.fromarray(pixels)
  return rgb_image.convert('RGBA')

@unittest.skipIf(Image is None, 'Pillow is not installed.')
def test_sentence_transformer_image_embeddings(self):
  """Runs a CLIP image-embedding MLTransform over one random image.

  Each output row should carry a 512-element embedding in the test column.
  """
  test_image = self.generateRandomImage(256)
  config = SentenceTransformerEmbeddings(
      model_name=IMAGE_MODEL_NAME,
      columns=[test_query_column],
      image_model=True)

  def assert_element(element):
    # clip-ViT-B-32 produces 512-dimensional embeddings.
    assert len(element[test_query_column]) == 512

  with beam.Pipeline() as pipeline:
    transformed = (
        pipeline
        | "CreateData" >> beam.Create([{
            test_query_column: test_image
        }])
        | "MLTransform" >> MLTransform(
            write_artifact_location=self.artifact_location).with_transform(
                config))
    _ = transformed | beam.Map(assert_element)

def test_sentence_transformer_images_with_str_data_types(self):
  """Feeding a plain string to an image-embedding transform raises TypeError.

  The image embedding handler type-checks its inputs, so a str element must
  fail the pipeline rather than being silently embedded as text.
  """
  config = SentenceTransformerEmbeddings(
      model_name=IMAGE_MODEL_NAME,
      columns=[test_query_column],
      image_model=True)
  with self.assertRaises(TypeError), beam.Pipeline() as pipeline:
    _ = (
        pipeline
        | "CreateData" >> beam.Create([{
            test_query_column: "image.jpg"
        }])
        | "MLTransform" >> MLTransform(
            write_artifact_location=self.artifact_location).with_transform(
                config))


@unittest.skipIf(_HF_TOKEN is None, 'HF_TOKEN environment variable not set.')
class HuggingfaceInferenceAPITest(unittest.TestCase):
Expand Down
Loading