feat: SentenceTransformersTextEmbedder supports config_kwargs (#8432)
* add config_kwargs

* disable PLR0913 for a specific function

* add a release note

* refer to AutoConfig in config_kwargs docstring

---------

Co-authored-by: David S. Batista <[email protected]>
Co-authored-by: Julian Risch <[email protected]>
3 people authored Oct 14, 2024
1 parent b81abc0 commit b40f0c8
Showing 3 changed files with 17 additions and 1 deletion.
@@ -34,7 +34,7 @@ class SentenceTransformersTextEmbedder:
```
"""

-def __init__(
+def __init__( # noqa: PLR0913
self,
model: str = "sentence-transformers/all-mpnet-base-v2",
device: Optional[ComponentDevice] = None,
@@ -48,6 +48,7 @@ def __init__(
truncate_dim: Optional[int] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+config_kwargs: Optional[Dict[str, Any]] = None,
precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
):
"""
@@ -86,6 +87,8 @@ def __init__(
:param tokenizer_kwargs:
Additional keyword arguments for `AutoTokenizer.from_pretrained` when loading the tokenizer.
Refer to specific model documentation for available kwargs.
+:param config_kwargs:
+Additional keyword arguments for `AutoConfig.from_pretrained` when loading the model configuration.
:param precision:
The precision to use for the embeddings.
All non-float32 precisions are quantized embeddings.
@@ -105,6 +108,7 @@
self.truncate_dim = truncate_dim
self.model_kwargs = model_kwargs
self.tokenizer_kwargs = tokenizer_kwargs
+self.config_kwargs = config_kwargs
self.embedding_backend = None
self.precision = precision

@@ -135,6 +139,7 @@ def to_dict(self) -> Dict[str, Any]:
truncate_dim=self.truncate_dim,
model_kwargs=self.model_kwargs,
tokenizer_kwargs=self.tokenizer_kwargs,
+config_kwargs=self.config_kwargs,
precision=self.precision,
)
if serialization_dict["init_parameters"].get("model_kwargs") is not None:
@@ -172,6 +177,7 @@ def warm_up(self):
truncate_dim=self.truncate_dim,
model_kwargs=self.model_kwargs,
tokenizer_kwargs=self.tokenizer_kwargs,
+config_kwargs=self.config_kwargs,
)
if self.tokenizer_kwargs and self.tokenizer_kwargs.get("model_max_length"):
self.embedding_backend.model.max_seq_length = self.tokenizer_kwargs["model_max_length"]
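For context, a minimal usage sketch (not part of this commit's diff) of how the new `config_kwargs` argument can be passed through to `AutoConfig.from_pretrained`; the model name is the component's default and the `use_memory_efficient_attention` value mirrors the one used in the tests below.

```python
from haystack.components.embedders import SentenceTransformersTextEmbedder

# config_kwargs is forwarded to AutoConfig.from_pretrained when the model configuration is loaded;
# the value below is illustrative and taken from this commit's tests.
embedder = SentenceTransformersTextEmbedder(
    model="sentence-transformers/all-mpnet-base-v2",
    config_kwargs={"use_memory_efficient_attention": False},
)
embedder.warm_up()  # loads the embedding backend with model_kwargs, tokenizer_kwargs, and config_kwargs

result = embedder.run(text="Haystack embedders now accept config_kwargs.")
print(len(result["embedding"]))  # length of the embedding vector
```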
@@ -0,0 +1,4 @@
+---
+enhancements:
+  - |
+    SentenceTransformersTextEmbedder now supports config_kwargs for additional parameters when loading the model configuration
@@ -70,6 +70,7 @@ def test_to_dict(self):
"truncate_dim": None,
"model_kwargs": None,
"tokenizer_kwargs": None,
"config_kwargs": None,
"precision": "float32",
},
}
@@ -88,6 +89,7 @@ def test_to_dict_with_custom_init_parameters(self):
truncate_dim=256,
model_kwargs={"torch_dtype": torch.float32},
tokenizer_kwargs={"model_max_length": 512},
config_kwargs={"use_memory_efficient_attention": False},
precision="int8",
)
data = component.to_dict()
@@ -106,6 +108,7 @@ def test_to_dict_with_custom_init_parameters(self):
"truncate_dim": 256,
"model_kwargs": {"torch_dtype": "torch.float32"},
"tokenizer_kwargs": {"model_max_length": 512},
"config_kwargs": {"use_memory_efficient_attention": False},
"precision": "int8",
},
}
@@ -131,6 +134,7 @@ def test_from_dict(self):
"truncate_dim": None,
"model_kwargs": {"torch_dtype": "torch.float32"},
"tokenizer_kwargs": {"model_max_length": 512},
"config_kwargs": {"use_memory_efficient_attention": False},
"precision": "float32",
},
}
@@ -147,6 +151,7 @@ def test_from_dict(self):
assert component.truncate_dim is None
assert component.model_kwargs == {"torch_dtype": torch.float32}
assert component.tokenizer_kwargs == {"model_max_length": 512}
+assert component.config_kwargs == {"use_memory_efficient_attention": False}
assert component.precision == "float32"

def test_from_dict_no_default_parameters(self):
@@ -218,6 +223,7 @@ def test_warmup(self, mocked_factory):
truncate_dim=None,
model_kwargs=None,
tokenizer_kwargs={"model_max_length": 512},
+config_kwargs=None,
)

@patch(
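As a hedged sketch of the round trip the new tests exercise, `config_kwargs` is carried through `to_dict` and restored by `from_dict`; the dictionary value below is the one used in the tests above.

```python
from haystack.components.embedders import SentenceTransformersTextEmbedder

component = SentenceTransformersTextEmbedder(
    config_kwargs={"use_memory_efficient_attention": False},
)
data = component.to_dict()
assert data["init_parameters"]["config_kwargs"] == {"use_memory_efficient_attention": False}

restored = SentenceTransformersTextEmbedder.from_dict(data)
assert restored.config_kwargs == {"use_memory_efficient_attention": False}
```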
