Skip to content

Commit

Permalink
feat!: Rename model_name or model_name_or_path to model in all Embedder classes (#6733)
Browse files Browse the repository at this point in the history

* rename model parameter in the openai doc embedder

* fix tests for openai doc embedder

* rename model parameter in the openai text embedder

* fix tests for openai text embedder

* rename model parameter in the st doc embedder

* fix tests for st doc embedder

* rename model parameter in the st backend

* fix tests for st backend

* rename model parameter in the st text embedder

* fix tests for st text embedder

* fix docstring

* fix pipeline utils

* fix e2e

* reno

* fix the indexing pipeline _create_embedder function

* fix e2e eval rag pipeline

* pytest
  • Loading branch information
ZanSara authored Jan 12, 2024
1 parent 3156343 commit 288ed15
Show file tree
Hide file tree
Showing 22 changed files with 98 additions and 141 deletions.
6 changes: 2 additions & 4 deletions e2e/pipelines/test_dense_doc_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ def test_dense_doc_search_pipeline(tmp_path, samples_path):
instance=DocumentSplitter(split_by="sentence", split_length=250, split_overlap=30), name="splitter"
)
indexing_pipeline.add_component(
instance=SentenceTransformersDocumentEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"),
name="embedder",
instance=SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"), name="embedder"
)
indexing_pipeline.add_component(instance=DocumentWriter(document_store=InMemoryDocumentStore()), name="writer")

Expand Down Expand Up @@ -60,8 +59,7 @@ def test_dense_doc_search_pipeline(tmp_path, samples_path):
# Create the querying pipeline
query_pipeline = Pipeline()
query_pipeline.add_component(
instance=SentenceTransformersTextEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"),
name="text_embedder",
instance=SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"), name="text_embedder"
)
query_pipeline.add_component(
instance=InMemoryEmbeddingRetriever(document_store=filled_document_store, top_k=20), name="embedding_retriever"
Expand Down
6 changes: 2 additions & 4 deletions e2e/pipelines/test_eval_dense_doc_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ def test_dense_doc_search_pipeline(samples_path):
instance=DocumentSplitter(split_by="sentence", split_length=250, split_overlap=30), name="splitter"
)
indexing_pipeline.add_component(
instance=SentenceTransformersDocumentEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"),
name="embedder",
instance=SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"), name="embedder"
)
indexing_pipeline.add_component(instance=DocumentWriter(document_store=InMemoryDocumentStore()), name="writer")

Expand All @@ -45,8 +44,7 @@ def test_dense_doc_search_pipeline(samples_path):
# Create the querying pipeline
query_pipeline = Pipeline()
query_pipeline.add_component(
instance=SentenceTransformersTextEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"),
name="text_embedder",
instance=SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"), name="text_embedder"
)
query_pipeline.add_component(
instance=InMemoryEmbeddingRetriever(document_store=filled_document_store, top_k=20), name="embedding_retriever"
Expand Down
3 changes: 1 addition & 2 deletions e2e/pipelines/test_eval_hybrid_doc_search_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ def test_hybrid_doc_search_pipeline():
hybrid_pipeline = Pipeline()
hybrid_pipeline.add_component(instance=InMemoryBM25Retriever(document_store=document_store), name="bm25_retriever")
hybrid_pipeline.add_component(
instance=SentenceTransformersTextEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"),
name="text_embedder",
instance=SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"), name="text_embedder"
)
hybrid_pipeline.add_component(
instance=InMemoryEmbeddingRetriever(document_store=document_store), name="embedding_retriever"
Expand Down
5 changes: 2 additions & 3 deletions e2e/pipelines/test_eval_rag_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,7 @@ def test_embedding_retrieval_rag_pipeline(tmp_path):
"""
rag_pipeline = Pipeline()
rag_pipeline.add_component(
instance=SentenceTransformersTextEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"),
name="text_embedder",
instance=SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"), name="text_embedder"
)
rag_pipeline.add_component(
instance=InMemoryEmbeddingRetriever(document_store=InMemoryDocumentStore()), name="retriever"
Expand Down Expand Up @@ -124,7 +123,7 @@ def test_embedding_retrieval_rag_pipeline(tmp_path):
document_store = rag_pipeline.get_component("retriever").document_store
indexing_pipeline = Pipeline()
indexing_pipeline.add_component(
instance=SentenceTransformersDocumentEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"),
instance=SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"),
name="document_embedder",
)
indexing_pipeline.add_component(instance=DocumentWriter(document_store=document_store), name="document_writer")
Expand Down
3 changes: 1 addition & 2 deletions e2e/pipelines/test_hybrid_doc_search_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ def test_hybrid_doc_search_pipeline(tmp_path):
hybrid_pipeline = Pipeline()
hybrid_pipeline.add_component(instance=InMemoryBM25Retriever(document_store=document_store), name="bm25_retriever")
hybrid_pipeline.add_component(
instance=SentenceTransformersTextEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"),
name="text_embedder",
instance=SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"), name="text_embedder"
)
hybrid_pipeline.add_component(
instance=InMemoryEmbeddingRetriever(document_store=document_store), name="embedding_retriever"
Expand Down
3 changes: 1 addition & 2 deletions e2e/pipelines/test_preprocessing_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ def test_preprocessing_pipeline(tmp_path):
instance=DocumentSplitter(split_by="sentence", split_length=1), name="splitter"
)
preprocessing_pipeline.add_component(
instance=SentenceTransformersDocumentEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"),
name="embedder",
instance=SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"), name="embedder"
)
preprocessing_pipeline.add_component(instance=DocumentWriter(document_store=document_store), name="writer")
preprocessing_pipeline.connect("file_type_router.text/plain", "text_file_converter.sources")
Expand Down
5 changes: 2 additions & 3 deletions e2e/pipelines/test_rag_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,7 @@ def test_embedding_retrieval_rag_pipeline(tmp_path):
"""
rag_pipeline = Pipeline()
rag_pipeline.add_component(
instance=SentenceTransformersTextEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"),
name="text_embedder",
instance=SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"), name="text_embedder"
)
rag_pipeline.add_component(
instance=InMemoryEmbeddingRetriever(document_store=InMemoryDocumentStore()), name="retriever"
Expand Down Expand Up @@ -131,7 +130,7 @@ def test_embedding_retrieval_rag_pipeline(tmp_path):
document_store = rag_pipeline.get_component("retriever").document_store
indexing_pipeline = Pipeline()
indexing_pipeline.add_component(
instance=SentenceTransformersDocumentEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"),
instance=SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"),
name="document_embedder",
)
indexing_pipeline.add_component(instance=DocumentWriter(document_store=document_store), name="document_writer")
Expand Down
3 changes: 1 addition & 2 deletions examples/pipelines/indexing_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
p.add_component(instance=DocumentCleaner(), name="cleaner")
p.add_component(instance=DocumentSplitter(split_by="sentence", split_length=250, split_overlap=30), name="splitter")
p.add_component(
instance=SentenceTransformersDocumentEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"),
name="embedder",
instance=SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"), name="embedder"
)
p.add_component(instance=DocumentWriter(document_store=InMemoryDocumentStore()), name="writer")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,13 @@ class _SentenceTransformersEmbeddingBackendFactory:
_instances: Dict[str, "_SentenceTransformersEmbeddingBackend"] = {}

@staticmethod
def get_embedding_backend(
model_name_or_path: str, device: Optional[str] = None, use_auth_token: Union[bool, str, None] = None
):
embedding_backend_id = f"{model_name_or_path}{device}{use_auth_token}"
def get_embedding_backend(model: str, device: Optional[str] = None, use_auth_token: Union[bool, str, None] = None):
embedding_backend_id = f"{model}{device}{use_auth_token}"

if embedding_backend_id in _SentenceTransformersEmbeddingBackendFactory._instances:
return _SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id]
embedding_backend = _SentenceTransformersEmbeddingBackend(
model_name_or_path=model_name_or_path, device=device, use_auth_token=use_auth_token
model=model, device=device, use_auth_token=use_auth_token
)
_SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend
return embedding_backend
Expand All @@ -33,13 +31,9 @@ class _SentenceTransformersEmbeddingBackend:
Class to manage Sentence Transformers embeddings.
"""

def __init__(
self, model_name_or_path: str, device: Optional[str] = None, use_auth_token: Union[bool, str, None] = None
):
def __init__(self, model: str, device: Optional[str] = None, use_auth_token: Union[bool, str, None] = None):
sentence_transformers_import.check()
self.model = SentenceTransformer(
model_name_or_path=model_name_or_path, device=device, use_auth_token=use_auth_token
)
self.model = SentenceTransformer(model_name_or_path=model, device=device, use_auth_token=use_auth_token)

def embed(self, data: List[str], **kwargs) -> List[List[float]]:
embeddings = self.model.encode(data, **kwargs).tolist()
Expand Down
12 changes: 6 additions & 6 deletions haystack/components/embedders/openai_document_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class OpenAIDocumentEmbedder:
def __init__(
self,
api_key: Optional[str] = None,
model_name: str = "text-embedding-ada-002",
model: str = "text-embedding-ada-002",
api_base_url: Optional[str] = None,
organization: Optional[str] = None,
prefix: str = "",
Expand All @@ -45,7 +45,7 @@ def __init__(
Create an OpenAIDocumentEmbedder component.
:param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the
environment variable OPENAI_API_KEY (recommended).
:param model_name: The name of the model to use.
:param model: The name of the model to use.
:param api_base_url: The OpenAI API Base url, defaults to None. For more details, see OpenAI [docs](https://platform.openai.com/docs/api-reference/audio).
:param organization: The Organization ID, defaults to `None`. See
[production best practices](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization).
Expand All @@ -57,7 +57,7 @@ def __init__(
:param meta_fields_to_embed: List of meta fields that should be embedded along with the Document text.
:param embedding_separator: Separator used to concatenate the meta fields to the Document text.
"""
self.model_name = model_name
self.model = model
self.api_base_url = api_base_url
self.organization = organization
self.prefix = prefix
Expand All @@ -73,7 +73,7 @@ def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Data that is sent to Posthog for usage analytics.
"""
return {"model": self.model_name}
return {"model": self.model}

def to_dict(self) -> Dict[str, Any]:
"""
Expand All @@ -82,7 +82,7 @@ def to_dict(self) -> Dict[str, Any]:
"""
return default_to_dict(
self,
model_name=self.model_name,
model=self.model,
organization=self.organization,
api_base_url=self.api_base_url,
prefix=self.prefix,
Expand Down Expand Up @@ -124,7 +124,7 @@ def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> Tuple[List
range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
):
batch = texts_to_embed[i : i + batch_size]
response = self.client.embeddings.create(model=self.model_name, input=batch)
response = self.client.embeddings.create(model=self.model, input=batch)
embeddings = [el.embedding for el in response.data]
all_embeddings.extend(embeddings)

Expand Down
12 changes: 6 additions & 6 deletions haystack/components/embedders/openai_text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class OpenAITextEmbedder:
def __init__(
self,
api_key: Optional[str] = None,
model_name: str = "text-embedding-ada-002",
model: str = "text-embedding-ada-002",
api_base_url: Optional[str] = None,
organization: Optional[str] = None,
prefix: str = "",
Expand All @@ -40,15 +40,15 @@ def __init__(
:param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the
environment variable OPENAI_API_KEY (recommended).
:param model_name: The name of the OpenAI model to use. For more details on the available models,
:param model: The name of the OpenAI model to use. For more details on the available models,
see [OpenAI documentation](https://platform.openai.com/docs/guides/embeddings/embedding-models).
:param organization: The Organization ID, defaults to `None`. See
[production best practices](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization).
:param api_base_url: The OpenAI API Base url, defaults to None. For more details, see OpenAI [docs](https://platform.openai.com/docs/api-reference/audio).
:param prefix: A string to add to the beginning of each text.
:param suffix: A string to add to the end of each text.
"""
self.model_name = model_name
self.model = model
self.organization = organization
self.prefix = prefix
self.suffix = suffix
Expand All @@ -59,7 +59,7 @@ def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Data that is sent to Posthog for usage analytics.
"""
return {"model": self.model_name}
return {"model": self.model}

def to_dict(self) -> Dict[str, Any]:
"""
Expand All @@ -68,7 +68,7 @@ def to_dict(self) -> Dict[str, Any]:
"""

return default_to_dict(
self, model_name=self.model_name, organization=self.organization, prefix=self.prefix, suffix=self.suffix
self, model=self.model, organization=self.organization, prefix=self.prefix, suffix=self.suffix
)

@component.output_types(embedding=List[float], meta=Dict[str, Any])
Expand All @@ -86,7 +86,7 @@ def run(self, text: str):
# replace newlines, which can negatively affect performance.
text_to_embed = text_to_embed.replace("\n", " ")

response = self.client.embeddings.create(model=self.model_name, input=text_to_embed)
response = self.client.embeddings.create(model=self.model, input=text_to_embed)
meta = {"model": response.model, "usage": dict(response.usage)}

return {"embedding": response.data[0].embedding, "meta": meta}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class SentenceTransformersDocumentEmbedder:

def __init__(
self,
model_name_or_path: str = "sentence-transformers/all-mpnet-base-v2",
model: str = "sentence-transformers/all-mpnet-base-v2",
device: Optional[str] = None,
token: Union[bool, str, None] = None,
prefix: str = "",
Expand All @@ -43,7 +43,7 @@ def __init__(
"""
Create a SentenceTransformersDocumentEmbedder component.
:param model_name_or_path: Local path or name of the model in Hugging Face's model hub,
:param model: Local path or name of the model in Hugging Face's model hub,
such as ``'sentence-transformers/all-mpnet-base-v2'``.
:param device: Device (like 'cuda' / 'cpu') that should be used for computation.
Defaults to CPU.
Expand All @@ -61,7 +61,7 @@ def __init__(
:param embedding_separator: Separator used to concatenate the meta fields to the Document content.
"""

self.model_name_or_path = model_name_or_path
self.model = model
# TODO: remove device parameter and use Haystack's device management once migrated
self.device = device or "cpu"
self.token = token
Expand All @@ -77,15 +77,15 @@ def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Data that is sent to Posthog for usage analytics.
"""
return {"model": self.model_name_or_path}
return {"model": self.model}

def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
"""
return default_to_dict(
self,
model_name_or_path=self.model_name_or_path,
model=self.model,
device=self.device,
token=self.token if not isinstance(self.token, str) else None, # don't serialize valid tokens
prefix=self.prefix,
Expand All @@ -103,7 +103,7 @@ def warm_up(self):
"""
if not hasattr(self, "embedding_backend"):
self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
model_name_or_path=self.model_name_or_path, device=self.device, use_auth_token=self.token
model=self.model, device=self.device, use_auth_token=self.token
)

@component.output_types(documents=List[Document])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class SentenceTransformersTextEmbedder:

def __init__(
self,
model_name_or_path: str = "sentence-transformers/all-mpnet-base-v2",
model: str = "sentence-transformers/all-mpnet-base-v2",
device: Optional[str] = None,
token: Union[bool, str, None] = None,
prefix: str = "",
Expand All @@ -40,7 +40,7 @@ def __init__(
"""
Create a SentenceTransformersTextEmbedder component.
:param model_name_or_path: Local path or name of the model in Hugging Face's model hub,
:param model: Local path or name of the model in Hugging Face's model hub,
such as ``'sentence-transformers/all-mpnet-base-v2'``.
:param device: Device (like 'cuda' / 'cpu') that should be used for computation.
Defaults to CPU.
Expand All @@ -56,7 +56,7 @@ def __init__(
:param normalize_embeddings: If set to true, returned vectors will have length 1.
"""

self.model_name_or_path = model_name_or_path
self.model = model
# TODO: remove device parameter and use Haystack's device management once migrated
self.device = device or "cpu"
self.token = token
Expand All @@ -70,15 +70,15 @@ def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Data that is sent to Posthog for usage analytics.
"""
return {"model": self.model_name_or_path}
return {"model": self.model}

def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
"""
return default_to_dict(
self,
model_name_or_path=self.model_name_or_path,
model=self.model,
device=self.device,
token=self.token if not isinstance(self.token, str) else None, # don't serialize valid tokens
prefix=self.prefix,
Expand All @@ -94,7 +94,7 @@ def warm_up(self):
"""
if not hasattr(self, "embedding_backend"):
self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
model_name_or_path=self.model_name_or_path, device=self.device, use_auth_token=self.token
model=self.model, device=self.device, use_auth_token=self.token
)

@component.output_types(embedding=List[float])
Expand Down
2 changes: 1 addition & 1 deletion haystack/components/joiners/document_joiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class DocumentJoiner:
p = Pipeline()
p.add_component(instance=InMemoryBM25Retriever(document_store=document_store), name="bm25_retriever")
p.add_component(
instance=SentenceTransformersTextEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"),
instance=SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"),
name="text_embedder",
)
p.add_component(instance=InMemoryEmbeddingRetriever(document_store=document_store), name="embedding_retriever")
Expand Down
Loading

0 comments on commit 288ed15

Please sign in to comment.