Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: Rename model_name or model_name_or_path to model in all Embedder classes #6733

Merged
merged 21 commits into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions e2e/pipelines/test_dense_doc_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ def test_dense_doc_search_pipeline(tmp_path, samples_path):
instance=DocumentSplitter(split_by="sentence", split_length=250, split_overlap=30), name="splitter"
)
indexing_pipeline.add_component(
instance=SentenceTransformersDocumentEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"),
name="embedder",
instance=SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"), name="embedder"
)
indexing_pipeline.add_component(instance=DocumentWriter(document_store=InMemoryDocumentStore()), name="writer")

Expand Down Expand Up @@ -60,8 +59,7 @@ def test_dense_doc_search_pipeline(tmp_path, samples_path):
# Create the querying pipeline
query_pipeline = Pipeline()
query_pipeline.add_component(
instance=SentenceTransformersTextEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"),
name="text_embedder",
instance=SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"), name="text_embedder"
)
query_pipeline.add_component(
instance=InMemoryEmbeddingRetriever(document_store=filled_document_store, top_k=20), name="embedding_retriever"
Expand Down
6 changes: 2 additions & 4 deletions e2e/pipelines/test_eval_dense_doc_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ def test_dense_doc_search_pipeline(samples_path):
instance=DocumentSplitter(split_by="sentence", split_length=250, split_overlap=30), name="splitter"
)
indexing_pipeline.add_component(
instance=SentenceTransformersDocumentEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"),
name="embedder",
instance=SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"), name="embedder"
)
indexing_pipeline.add_component(instance=DocumentWriter(document_store=InMemoryDocumentStore()), name="writer")

Expand All @@ -45,8 +44,7 @@ def test_dense_doc_search_pipeline(samples_path):
# Create the querying pipeline
query_pipeline = Pipeline()
query_pipeline.add_component(
instance=SentenceTransformersTextEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"),
name="text_embedder",
instance=SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"), name="text_embedder"
)
query_pipeline.add_component(
instance=InMemoryEmbeddingRetriever(document_store=filled_document_store, top_k=20), name="embedding_retriever"
Expand Down
3 changes: 1 addition & 2 deletions e2e/pipelines/test_eval_hybrid_doc_search_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ def test_hybrid_doc_search_pipeline():
hybrid_pipeline = Pipeline()
hybrid_pipeline.add_component(instance=InMemoryBM25Retriever(document_store=document_store), name="bm25_retriever")
hybrid_pipeline.add_component(
instance=SentenceTransformersTextEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"),
name="text_embedder",
instance=SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"), name="text_embedder"
)
hybrid_pipeline.add_component(
instance=InMemoryEmbeddingRetriever(document_store=document_store), name="embedding_retriever"
Expand Down
5 changes: 2 additions & 3 deletions e2e/pipelines/test_eval_rag_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,7 @@ def test_embedding_retrieval_rag_pipeline(tmp_path):
"""
rag_pipeline = Pipeline()
rag_pipeline.add_component(
instance=SentenceTransformersTextEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"),
name="text_embedder",
instance=SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"), name="text_embedder"
)
rag_pipeline.add_component(
instance=InMemoryEmbeddingRetriever(document_store=InMemoryDocumentStore()), name="retriever"
Expand Down Expand Up @@ -124,7 +123,7 @@ def test_embedding_retrieval_rag_pipeline(tmp_path):
document_store = rag_pipeline.get_component("retriever").document_store
indexing_pipeline = Pipeline()
indexing_pipeline.add_component(
instance=SentenceTransformersDocumentEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"),
instance=SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"),
name="document_embedder",
)
indexing_pipeline.add_component(instance=DocumentWriter(document_store=document_store), name="document_writer")
Expand Down
3 changes: 1 addition & 2 deletions e2e/pipelines/test_hybrid_doc_search_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ def test_hybrid_doc_search_pipeline(tmp_path):
hybrid_pipeline = Pipeline()
hybrid_pipeline.add_component(instance=InMemoryBM25Retriever(document_store=document_store), name="bm25_retriever")
hybrid_pipeline.add_component(
instance=SentenceTransformersTextEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"),
name="text_embedder",
instance=SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"), name="text_embedder"
)
hybrid_pipeline.add_component(
instance=InMemoryEmbeddingRetriever(document_store=document_store), name="embedding_retriever"
Expand Down
3 changes: 1 addition & 2 deletions e2e/pipelines/test_preprocessing_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ def test_preprocessing_pipeline(tmp_path):
instance=DocumentSplitter(split_by="sentence", split_length=1), name="splitter"
)
preprocessing_pipeline.add_component(
instance=SentenceTransformersDocumentEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"),
name="embedder",
instance=SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"), name="embedder"
)
preprocessing_pipeline.add_component(instance=DocumentWriter(document_store=document_store), name="writer")
preprocessing_pipeline.connect("file_type_router.text/plain", "text_file_converter.sources")
Expand Down
5 changes: 2 additions & 3 deletions e2e/pipelines/test_rag_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,7 @@ def test_embedding_retrieval_rag_pipeline(tmp_path):
"""
rag_pipeline = Pipeline()
rag_pipeline.add_component(
instance=SentenceTransformersTextEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"),
name="text_embedder",
instance=SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"), name="text_embedder"
)
rag_pipeline.add_component(
instance=InMemoryEmbeddingRetriever(document_store=InMemoryDocumentStore()), name="retriever"
Expand Down Expand Up @@ -131,7 +130,7 @@ def test_embedding_retrieval_rag_pipeline(tmp_path):
document_store = rag_pipeline.get_component("retriever").document_store
indexing_pipeline = Pipeline()
indexing_pipeline.add_component(
instance=SentenceTransformersDocumentEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"),
instance=SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"),
name="document_embedder",
)
indexing_pipeline.add_component(instance=DocumentWriter(document_store=document_store), name="document_writer")
Expand Down
3 changes: 1 addition & 2 deletions examples/pipelines/indexing_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
p.add_component(instance=DocumentCleaner(), name="cleaner")
p.add_component(instance=DocumentSplitter(split_by="sentence", split_length=250, split_overlap=30), name="splitter")
p.add_component(
instance=SentenceTransformersDocumentEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"),
name="embedder",
instance=SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"), name="embedder"
)
p.add_component(instance=DocumentWriter(document_store=InMemoryDocumentStore()), name="writer")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,13 @@ class _SentenceTransformersEmbeddingBackendFactory:
_instances: Dict[str, "_SentenceTransformersEmbeddingBackend"] = {}

@staticmethod
def get_embedding_backend(
model_name_or_path: str, device: Optional[str] = None, use_auth_token: Union[bool, str, None] = None
):
embedding_backend_id = f"{model_name_or_path}{device}{use_auth_token}"
def get_embedding_backend(model: str, device: Optional[str] = None, use_auth_token: Union[bool, str, None] = None):
embedding_backend_id = f"{model}{device}{use_auth_token}"

if embedding_backend_id in _SentenceTransformersEmbeddingBackendFactory._instances:
return _SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id]
embedding_backend = _SentenceTransformersEmbeddingBackend(
model_name_or_path=model_name_or_path, device=device, use_auth_token=use_auth_token
model=model, device=device, use_auth_token=use_auth_token
)
_SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend
return embedding_backend
Expand All @@ -33,13 +31,9 @@ class _SentenceTransformersEmbeddingBackend:
Class to manage Sentence Transformers embeddings.
"""

def __init__(
self, model_name_or_path: str, device: Optional[str] = None, use_auth_token: Union[bool, str, None] = None
):
def __init__(self, model: str, device: Optional[str] = None, use_auth_token: Union[bool, str, None] = None):
sentence_transformers_import.check()
self.model = SentenceTransformer(
model_name_or_path=model_name_or_path, device=device, use_auth_token=use_auth_token
)
self.model = SentenceTransformer(model_name_or_path=model, device=device, use_auth_token=use_auth_token)

def embed(self, data: List[str], **kwargs) -> List[List[float]]:
embeddings = self.model.encode(data, **kwargs).tolist()
Expand Down
12 changes: 6 additions & 6 deletions haystack/components/embedders/openai_document_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class OpenAIDocumentEmbedder:
def __init__(
self,
api_key: Optional[str] = None,
model_name: str = "text-embedding-ada-002",
model: str = "text-embedding-ada-002",
api_base_url: Optional[str] = None,
organization: Optional[str] = None,
prefix: str = "",
Expand All @@ -45,7 +45,7 @@ def __init__(
Create an OpenAIDocumentEmbedder component.
:param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the
environment variable OPENAI_API_KEY (recommended).
:param model_name: The name of the model to use.
:param model: The name of the model to use.
:param api_base_url: The OpenAI API base URL, defaults to `None`. For more details, see the OpenAI [docs](https://platform.openai.com/docs/api-reference/audio).
:param organization: The Organization ID, defaults to `None`. See
[production best practices](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization).
Expand All @@ -57,7 +57,7 @@ def __init__(
:param meta_fields_to_embed: List of meta fields that should be embedded along with the Document text.
:param embedding_separator: Separator used to concatenate the meta fields to the Document text.
"""
self.model_name = model_name
self.model = model
self.api_base_url = api_base_url
self.organization = organization
self.prefix = prefix
Expand All @@ -73,7 +73,7 @@ def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Data that is sent to Posthog for usage analytics.
"""
return {"model": self.model_name}
return {"model": self.model}

def to_dict(self) -> Dict[str, Any]:
"""
Expand All @@ -82,7 +82,7 @@ def to_dict(self) -> Dict[str, Any]:
"""
return default_to_dict(
self,
model_name=self.model_name,
model=self.model,
organization=self.organization,
api_base_url=self.api_base_url,
prefix=self.prefix,
Expand Down Expand Up @@ -124,7 +124,7 @@ def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> Tuple[List
range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
):
batch = texts_to_embed[i : i + batch_size]
response = self.client.embeddings.create(model=self.model_name, input=batch)
response = self.client.embeddings.create(model=self.model, input=batch)
embeddings = [el.embedding for el in response.data]
all_embeddings.extend(embeddings)

Expand Down
12 changes: 6 additions & 6 deletions haystack/components/embedders/openai_text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class OpenAITextEmbedder:
def __init__(
self,
api_key: Optional[str] = None,
model_name: str = "text-embedding-ada-002",
model: str = "text-embedding-ada-002",
api_base_url: Optional[str] = None,
organization: Optional[str] = None,
prefix: str = "",
Expand All @@ -40,15 +40,15 @@ def __init__(

:param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the
environment variable OPENAI_API_KEY (recommended).
:param model_name: The name of the OpenAI model to use. For more details on the available models,
:param model: The name of the OpenAI model to use. For more details on the available models,
see [OpenAI documentation](https://platform.openai.com/docs/guides/embeddings/embedding-models).
:param organization: The Organization ID, defaults to `None`. See
[production best practices](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization).
:param api_base_url: The OpenAI API base URL, defaults to `None`. For more details, see the OpenAI [docs](https://platform.openai.com/docs/api-reference/audio).
:param prefix: A string to add to the beginning of each text.
:param suffix: A string to add to the end of each text.
"""
self.model_name = model_name
self.model = model
self.organization = organization
self.prefix = prefix
self.suffix = suffix
Expand All @@ -59,7 +59,7 @@ def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Data that is sent to Posthog for usage analytics.
"""
return {"model": self.model_name}
return {"model": self.model}

def to_dict(self) -> Dict[str, Any]:
"""
Expand All @@ -68,7 +68,7 @@ def to_dict(self) -> Dict[str, Any]:
"""

return default_to_dict(
self, model_name=self.model_name, organization=self.organization, prefix=self.prefix, suffix=self.suffix
self, model=self.model, organization=self.organization, prefix=self.prefix, suffix=self.suffix
)

@component.output_types(embedding=List[float], meta=Dict[str, Any])
Expand All @@ -86,7 +86,7 @@ def run(self, text: str):
# replace newlines, which can negatively affect performance.
text_to_embed = text_to_embed.replace("\n", " ")

response = self.client.embeddings.create(model=self.model_name, input=text_to_embed)
response = self.client.embeddings.create(model=self.model, input=text_to_embed)
meta = {"model": response.model, "usage": dict(response.usage)}

return {"embedding": response.data[0].embedding, "meta": meta}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class SentenceTransformersDocumentEmbedder:

def __init__(
self,
model_name_or_path: str = "sentence-transformers/all-mpnet-base-v2",
model: str = "sentence-transformers/all-mpnet-base-v2",
device: Optional[str] = None,
token: Union[bool, str, None] = None,
prefix: str = "",
Expand All @@ -43,7 +43,7 @@ def __init__(
"""
Create a SentenceTransformersDocumentEmbedder component.

:param model_name_or_path: Local path or name of the model in Hugging Face's model hub,
:param model: Local path or name of the model in Hugging Face's model hub,
such as ``'sentence-transformers/all-mpnet-base-v2'``.
:param device: Device (like 'cuda' / 'cpu') that should be used for computation.
Defaults to CPU.
Expand All @@ -61,7 +61,7 @@ def __init__(
:param embedding_separator: Separator used to concatenate the meta fields to the Document content.
"""

self.model_name_or_path = model_name_or_path
self.model = model
# TODO: remove device parameter and use Haystack's device management once migrated
self.device = device or "cpu"
self.token = token
Expand All @@ -77,15 +77,15 @@ def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Data that is sent to Posthog for usage analytics.
"""
return {"model": self.model_name_or_path}
return {"model": self.model}

def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
"""
return default_to_dict(
self,
model_name_or_path=self.model_name_or_path,
model=self.model,
device=self.device,
token=self.token if not isinstance(self.token, str) else None, # don't serialize valid tokens
prefix=self.prefix,
Expand All @@ -103,7 +103,7 @@ def warm_up(self):
"""
if not hasattr(self, "embedding_backend"):
self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
model_name_or_path=self.model_name_or_path, device=self.device, use_auth_token=self.token
model=self.model, device=self.device, use_auth_token=self.token
)

@component.output_types(documents=List[Document])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class SentenceTransformersTextEmbedder:

def __init__(
self,
model_name_or_path: str = "sentence-transformers/all-mpnet-base-v2",
model: str = "sentence-transformers/all-mpnet-base-v2",
device: Optional[str] = None,
token: Union[bool, str, None] = None,
prefix: str = "",
Expand All @@ -40,7 +40,7 @@ def __init__(
"""
Create a SentenceTransformersTextEmbedder component.

:param model_name_or_path: Local path or name of the model in Hugging Face's model hub,
:param model: Local path or name of the model in Hugging Face's model hub,
such as ``'sentence-transformers/all-mpnet-base-v2'``.
:param device: Device (like 'cuda' / 'cpu') that should be used for computation.
Defaults to CPU.
Expand All @@ -56,7 +56,7 @@ def __init__(
:param normalize_embeddings: If set to true, returned vectors will have length 1.
"""

self.model_name_or_path = model_name_or_path
self.model = model
# TODO: remove device parameter and use Haystack's device management once migrated
self.device = device or "cpu"
self.token = token
Expand All @@ -70,15 +70,15 @@ def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Data that is sent to Posthog for usage analytics.
"""
return {"model": self.model_name_or_path}
return {"model": self.model}

def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
"""
return default_to_dict(
self,
model_name_or_path=self.model_name_or_path,
model=self.model,
device=self.device,
token=self.token if not isinstance(self.token, str) else None, # don't serialize valid tokens
prefix=self.prefix,
Expand All @@ -94,7 +94,7 @@ def warm_up(self):
"""
if not hasattr(self, "embedding_backend"):
self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
model_name_or_path=self.model_name_or_path, device=self.device, use_auth_token=self.token
model=self.model, device=self.device, use_auth_token=self.token
)

@component.output_types(embedding=List[float])
Expand Down
2 changes: 1 addition & 1 deletion haystack/components/joiners/document_joiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class DocumentJoiner:
p = Pipeline()
p.add_component(instance=InMemoryBM25Retriever(document_store=document_store), name="bm25_retriever")
p.add_component(
instance=SentenceTransformersTextEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"),
instance=SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"),
name="text_embedder",
)
p.add_component(instance=InMemoryEmbeddingRetriever(document_store=document_store), name="embedding_retriever")
Expand Down
Loading