
Commit

Merge branch 'main' into elasticsearch
anakin87 authored Nov 6, 2023
2 parents a052341 + a66add2 commit 805dcaa
Showing 5 changed files with 97 additions and 29 deletions.
@@ -10,7 +10,50 @@ class InstructorDocumentEmbedder:
"""
A component for computing Document embeddings using INSTRUCTOR embedding models.
The embedding of each Document is stored in the `embedding` field of the Document.
"""
Usage example:
```python
# To use this component, install the "instructor-embedders-haystack" package.
# pip install instructor-embedders-haystack
from instructor_embedders.instructor_document_embedder import InstructorDocumentEmbedder
from haystack.preview.dataclasses import Document
doc_embedding_instruction = "Represent the Medical Document for retrieval:"
doc_embedder = InstructorDocumentEmbedder(
model_name_or_path="hkunlp/instructor-base",
instruction=doc_embedding_instruction,
batch_size=32,
device="cpu",
)
doc_embedder.warm_up()
# Text taken from PubMed QA Dataset (https://huggingface.co/datasets/pubmed_qa)
document_list = [
Document(
content="Oxidative stress generated within inflammatory joints can produce autoimmune phenomena and joint destruction. Radical species with oxidative activity, including reactive nitrogen species, represent mediators of inflammation and cartilage damage.",
meta={
"pubid": "25,445,628",
"long_answer": "yes",
},
),
Document(
content="Plasma levels of pancreatic polypeptide (PP) rise upon food intake. Although other pancreatic islet hormones, such as insulin and glucagon, have been extensively investigated, PP secretion and actions are still poorly understood.",
meta={
"pubid": "25,445,712",
"long_answer": "yes",
},
),
]
result = doc_embedder.run(document_list)
print(f"Document Text: {result['documents'][0].content}")
print(f"Document Embedding: {result['documents'][0].embedding}")
print(f"Embedding Dimension: {len(result['documents'][0].embedding)}")
```
""" # noqa: E501

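The instruction above follows INSTRUCTOR's "Represent the (domain) (text type) for (task):" template, so retargeting the embedder is a one-line change. For instance, the Science-retrieval instruction used by the integration test later in this diff can be swapped in; a minimal sketch, assuming the same constructor arguments as the example above:

```python
from instructor_embedders.instructor_document_embedder import InstructorDocumentEmbedder

# Same "Represent the <domain> <text type> for <task>:" template,
# retargeted at science documents (instruction taken from the tests below).
science_embedder = InstructorDocumentEmbedder(
    model_name_or_path="hkunlp/instructor-base",
    instruction="Represent the Science document for retrieval",
    batch_size=32,
    device="cpu",
)
science_embedder.warm_up()
```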
def __init__(
self,
@@ -100,8 +143,10 @@ def run(self, documents: List[Document]):
The embedding of each Document is stored in the `embedding` field of the Document.
"""
if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
msg = ("InstructorDocumentEmbedder expects a list of Documents as input. "
"In case you want to embed a list of strings, please use the InstructorTextEmbedder.")
msg = (
"InstructorDocumentEmbedder expects a list of Documents as input. "
"In case you want to embed a list of strings, please use the InstructorTextEmbedder."
)
raise TypeError(msg)
if not hasattr(self, "embedding_backend"):
msg = "The embedding model has not been loaded. Please call warm_up() before running."
@@ -112,11 +157,14 @@ def run(self, documents: List[Document]):
texts_to_embed = []
for doc in documents:
meta_values_to_embed = [
str(doc.metadata[key])
str(doc.meta[key])
for key in self.metadata_fields_to_embed
if key in doc.metadata and doc.metadata[key] is not None
if key in doc.meta and doc.meta[key] is not None
]
text_to_embed = [
self.instruction,
self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]),
]
text_to_embed = [self.instruction, self.embedding_separator.join([*meta_values_to_embed, doc.text or ""])]
texts_to_embed.append(text_to_embed)

embeddings = self.embedding_backend.embed(
@@ -126,11 +174,7 @@
normalize_embeddings=self.normalize_embeddings,
)

documents_with_embeddings = []
for doc, emb in zip(documents, embeddings):
doc_as_dict = doc.to_dict()
doc_as_dict["embedding"] = emb
del doc_as_dict["id"]
documents_with_embeddings.append(Document.from_dict(doc_as_dict))
doc.embedding = emb

return {"documents": documents_with_embeddings}
return {"documents": documents}
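Taken together, the run() changes above do three things: read metadata through doc.meta instead of doc.metadata, pair self.instruction with each text in the [instruction, text] format that INSTRUCTOR models consume, and write embeddings onto the existing Documents instead of round-tripping through to_dict()/from_dict(). A minimal sketch of the resulting flow with a stubbed backend (StubBackend is hypothetical, and "\n" stands in for the assumed embedding_separator default):

```python
from haystack.preview.dataclasses import Document

class StubBackend:
    """Hypothetical stand-in for the INSTRUCTOR embedding backend."""
    def embed(self, data, **kwargs):
        # One dummy vector per [instruction, text] pair.
        return [[0.1, 0.2, 0.3] for _ in data]

instruction = "Represent the Medical Document for retrieval:"
docs = [Document(content="Oxidative stress ...", meta={"long_answer": "yes"})]

texts_to_embed = []
for doc in docs:
    meta_values = [str(doc.meta[k]) for k in ["long_answer"] if doc.meta.get(k) is not None]
    # INSTRUCTOR consumes [instruction, text] pairs.
    texts_to_embed.append([instruction, "\n".join([*meta_values, doc.content or ""])])

for doc, emb in zip(docs, StubBackend().embed(texts_to_embed)):
    doc.embedding = emb  # assign in place; no Document rebuild

print(docs[0].embedding)  # [0.1, 0.2, 0.3]
```

Assigning in place also keeps each Document's original id, which the old code deleted (del doc_as_dict["id"]) and left Document.from_dict to regenerate.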
@@ -8,8 +8,28 @@
@component
class InstructorTextEmbedder:
"""
A component for embedding strings using Sentence Transformers models.
"""
A component for embedding strings using INSTRUCTOR embedding models.
Usage example:
```python
# To use this component, install the "instructor-embedders-haystack" package.
# pip install instructor-embedders-haystack
from instructor_embedders.instructor_text_embedder import InstructorTextEmbedder
text = "It clearly says online this will work on a Mac OS system. The disk comes and it does not, only Windows. Do Not order this if you have a Mac!!"
instruction = (
"Represent the Amazon comment for classifying the sentence as positive or negative"
)
text_embedder = InstructorTextEmbedder(
model_name_or_path="hkunlp/instructor-base", instruction=instruction,
device="cpu"
)
text_embedder.warm_up()
result = text_embedder.run(text)
```
""" # noqa: E501

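The docstring example stops at run(); pulling the vector out of its return value presumably looks like the following. The "embedding" output key is an assumption, mirroring the document embedder's {"documents": ...} return value (the diff does not show the output-types declaration):

```python
from instructor_embedders.instructor_text_embedder import InstructorTextEmbedder

# Continuation of the docstring example above.
text_embedder = InstructorTextEmbedder(
    model_name_or_path="hkunlp/instructor-base",
    instruction="Represent the Amazon comment for classifying the sentence as positive or negative",
    device="cpu",
)
text_embedder.warm_up()

result = text_embedder.run("Do Not order this if you have a Mac!!")
embedding = result["embedding"]  # the "embedding" key is an assumption
print(f"Embedding Dimension: {len(embedding)}")
```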
def __init__(
self,
@@ -88,8 +108,10 @@ def warm_up(self):
def run(self, text: str):
"""Embed a string."""
if not isinstance(text, str):
msg = ("InstructorTextEmbedder expects a string as input. "
"In case you want to embed a list of Documents, please use the InstructorDocumentEmbedder.")
msg = (
"InstructorTextEmbedder expects a string as input. "
"In case you want to embed a list of Documents, please use the InstructorDocumentEmbedder."
)
raise TypeError(msg)
if not hasattr(self, "embedding_backend"):
msg = "The embedding model has not been loaded. Please call warm_up() before running."
@@ -7,7 +7,7 @@

@pytest.mark.unit
@patch("instructor_embedders.embedding_backend.instructor_backend.INSTRUCTOR")
def test_factory_behavior(mock_instructor): # noqa: ARG001
def test_factory_behavior(mock_instructor): # noqa: ARG001
embedding_backend = _InstructorEmbeddingBackendFactory.get_embedding_backend(
model_name_or_path="hkunlp/instructor-large", device="cpu"
)
@@ -33,7 +33,7 @@ def test_model_initialization(mock_instructor):

@pytest.mark.unit
@patch("instructor_embedders.embedding_backend.instructor_backend.INSTRUCTOR")
def test_embedding_function_with_kwargs(mock_instructor): # noqa: ARG001
def test_embedding_function_with_kwargs(mock_instructor): # noqa: ARG001
embedding_backend = _InstructorEmbeddingBackendFactory.get_embedding_backend(
model_name_or_path="hkunlp/instructor-base"
)
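These tests patch the INSTRUCTOR class and request backends from the factory by model name and device; the name test_factory_behavior suggests the factory hands back a shared instance per configuration. A hedged sketch of that pattern (the class internals, cache key, and _StubBackend are assumptions; the diff shows only the factory's call sites):

```python
from typing import Optional

# Hedged sketch of a caching backend factory -- internals are assumptions.
class _StubBackend:
    def __init__(self, model_name_or_path: str, device: Optional[str] = None):
        self.model_name_or_path = model_name_or_path
        self.device = device

class _StubEmbeddingBackendFactory:
    _instances: dict = {}

    @staticmethod
    def get_embedding_backend(model_name_or_path: str, device: Optional[str] = None):
        key = f"{model_name_or_path}{device}"
        if key not in _StubEmbeddingBackendFactory._instances:
            _StubEmbeddingBackendFactory._instances[key] = _StubBackend(model_name_or_path, device)
        return _StubEmbeddingBackendFactory._instances[key]

# Same configuration -> same backend object.
b1 = _StubEmbeddingBackendFactory.get_embedding_backend("hkunlp/instructor-large", "cpu")
b2 = _StubEmbeddingBackendFactory.get_embedding_backend("hkunlp/instructor-large", "cpu")
assert b1 is b2
```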
@@ -198,7 +198,7 @@ def test_embed(self):
embedder.embedding_backend = MagicMock()
embedder.embedding_backend.embed = lambda x, **kwargs: np.random.rand(len(x), 16).tolist() # noqa: ARG005

documents = [Document(text=f"Sample-document text {i}") for i in range(5)]
documents = [Document(content=f"Sample-document text {i}") for i in range(5)]

result = embedder.run(documents=documents)

@@ -239,9 +239,7 @@ def test_embed_metadata(self):
)
embedder.embedding_backend = MagicMock()

documents = [
Document(text=f"document-number {i}", metadata={"meta_field": f"meta_value {i}"}) for i in range(5)
]
documents = [Document(content=f"document-number {i}", meta={"meta_field": f"meta_value {i}"}) for i in range(5)]

embedder.run(documents=documents)

@@ -260,12 +258,14 @@ def test_embed_metadata(self):

@pytest.mark.integration
def test_run(self):
embedder = InstructorDocumentEmbedder(model_name_or_path="hkunlp/instructor-base",
device="cpu",
instruction="Represent the Science document for retrieval")
embedder = InstructorDocumentEmbedder(
model_name_or_path="hkunlp/instructor-base",
device="cpu",
instruction="Represent the Science document for retrieval",
)
embedder.warm_up()

doc = Document(text="Parton energy loss in QCD matter")
doc = Document(content="Parton energy loss in QCD matter")

result = embedder.run(documents=[doc])
embedding = result["documents"][0].embedding
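A natural extension of this integration test would assert the returned vector's type and size. Hypothetical follow-up assertions (the 768 dimensionality is an assumption about the T5/GTR encoder behind hkunlp/instructor-base, not something this diff states):

```python
# Hypothetical continuation of test_run; 768 is an assumed dimensionality
# for hkunlp/instructor-base and should be verified against the model.
assert isinstance(embedding, list)
assert all(isinstance(value, float) for value in embedding)
assert len(embedding) == 768
```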
@@ -200,9 +200,11 @@ def test_run_wrong_incorrect_format(self):

@pytest.mark.integration
def test_run(self):
embedder = InstructorTextEmbedder(model_name_or_path="hkunlp/instructor-base",
device="cpu",
instruction="Represent the Science sentence for retrieval")
embedder = InstructorTextEmbedder(
model_name_or_path="hkunlp/instructor-base",
device="cpu",
instruction="Represent the Science sentence for retrieval",
)
embedder.warm_up()

text = "Parton energy loss in QCD matter"
