-
Notifications
You must be signed in to change notification settings - Fork 4.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement Hugging Face Image Embedding MLTransform #31536
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,6 +32,7 @@ | |
from apache_beam.ml.inference.base import ModelHandler | ||
from apache_beam.ml.inference.base import RunInference | ||
from apache_beam.ml.transforms.base import EmbeddingsManager | ||
from apache_beam.ml.transforms.base import _ImageEmbeddingHandler | ||
from apache_beam.ml.transforms.base import _TextEmbeddingHandler | ||
|
||
try: | ||
|
@@ -153,6 +154,45 @@ def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform: | |
)) | ||
|
||
|
||
class SentenceTransformerImageEmbeddings(EmbeddingsManager):
  def __init__(self, model_name: str, columns: List[str], **kwargs):
    """
    Embedding config for sentence-transformers. This config can be used with
    MLTransform to embed image data. Models are loaded using the RunInference
    PTransform with the help of ModelHandler.

    Args:
      model_name: Name of the model to use. The model should be hosted on
        HuggingFace Hub or compatible with sentence_transformers. See
        https://www.sbert.net/docs/sentence_transformer/pretrained_models.html#image-text-models # pylint: disable=line-too-long
        for a list of sentence_transformers models.
      columns: List of columns to be embedded.
      min_batch_size: The minimum batch size to be used for inference.
      max_batch_size: The maximum batch size to be used for inference.
      large_model: Whether to share the model across processes.
    """
    # Batch-size/large_model options are forwarded to EmbeddingsManager
    # via **kwargs, mirroring SentenceTransformerEmbeddings above.
    super().__init__(columns, **kwargs)
    self.model_name = model_name

  def get_model_handler(self):
    # Delegates model loading to the shared sentence-transformers handler;
    # batching and sharing options come from the EmbeddingsManager base.
    return _SentenceTransformerModelHandler(
        model_class=SentenceTransformer,
        model_name=self.model_name,
        load_model_args=self.load_model_args,
        min_batch_size=self.min_batch_size,
        max_batch_size=self.max_batch_size,
        large_model=self.large_model)

  def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform:
    # Wrap the model handler in an _ImageEmbeddingHandler since
    # SentenceTransformerImageEmbeddings works on image input data.
    # (The previous comment incorrectly referenced _TextEmbeddingHandler,
    # copied from the text-embedding config above.)
    return (
        RunInference(
            model_handler=_ImageEmbeddingHandler(self),
            inference_args=self.inference_args,
        ))
|
||
|
||
class _InferenceAPIHandler(ModelHandler): | ||
def __init__(self, config: 'InferenceAPIEmbeddings'): | ||
super().__init__() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shall we add all these options to
`__init__`? similarly to `SentenceTransformerEmbeddings`?

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is a bit of a weird case where those parameters are passed up as kwargs and handled by the
`EmbeddingsManager`. I'd be okay to explicitly have these in the constructor though.

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes this wound up looking cleaner, PTAL