From 4c1871d9a8f05a7157e7a45d7849e51b0e20ce86 Mon Sep 17 00:00:00 2001 From: Mohammad Mohtashim <45242107+keenborder786@users.noreply.github.com> Date: Mon, 16 Dec 2024 01:34:29 +0500 Subject: [PATCH] community: Passing the `model_kwargs` correctly while maintaing backward compatability (#28439) - **Description:** `Model_Kwargs` was not being passed correctly to `sentence_transformers.SentenceTransformer` which has been corrected while maintaing backward compatability - **Issue:** #28436 --------- Co-authored-by: MoosaTae Co-authored-by: Sadit Wongprayon <101176694+MoosaTae@users.noreply.github.com> Co-authored-by: Erick Friis --- .../embeddings/huggingface.py | 25 ++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/libs/community/langchain_community/embeddings/huggingface.py b/libs/community/langchain_community/embeddings/huggingface.py index cc2e073160b17..5810bbc920e07 100644 --- a/libs/community/langchain_community/embeddings/huggingface.py +++ b/libs/community/langchain_community/embeddings/huggingface.py @@ -244,6 +244,11 @@ def embed_query(self, text: str) -> List[float]: return embedding.tolist() +@deprecated( + since="0.2.2", + removal="1.0", + alternative_import="langchain_huggingface.HuggingFaceEmbeddings", +) class HuggingFaceBgeEmbeddings(BaseModel, Embeddings): """HuggingFace sentence_transformers embedding models. @@ -322,11 +327,25 @@ def __init__(self, **kwargs: Any): except ImportError as exc: raise ImportError( "Could not import sentence_transformers python package. " - "Please install it with `pip install sentence_transformers`." + "Please install it with `pip install sentence-transformers`." ) from exc - + extra_model_kwargs = [ + "torch_dtype", + "attn_implementation", + "provider", + "file_name", + "export", + ] + extra_model_kwargs_dict = { + k: self.model_kwargs.pop(k) + for k in extra_model_kwargs + if k in self.model_kwargs + } self.client = sentence_transformers.SentenceTransformer( - self.model_name, cache_folder=self.cache_folder, **self.model_kwargs + self.model_name, + cache_folder=self.cache_folder, + **self.model_kwargs, + model_kwargs=extra_model_kwargs_dict, ) if "-zh" in self.model_name: