feat: embedding instructions for dense retrieval (#6372)
* Embedding instructions in EmbeddingRetriever

Query and document embeddings are prefixed with instructions, which is useful
for retrievers fine-tuned on specific tasks, such as Q&A (see the usage sketch below).

* Tests

Check the 0th component of each vector against reference values, using different document stores.

* Normalizing vectors

* Release notes
danielfleischer authored Nov 21, 2023 · 1 parent 07cda09 · commit 0cef17a
Showing 4 changed files with 46 additions and 1 deletion.
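
Before the diffs, a minimal usage sketch of the new parameters (a hypothetical setup: the in-memory store is an arbitrary choice, and the model and prompt strings mirror the test fixture added below; only query_prompt and passage_prompt are new in this commit):

from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import EmbeddingRetriever

# Hypothetical setup; model and prompts mirror the test fixture in conftest.py below.
document_store = InMemoryDocumentStore(embedding_dim=768)
retriever = EmbeddingRetriever(
    document_store=document_store,
    embedding_model="sentence-transformers/msmarco-distilbert-dot-v5",
    model_format="sentence_transformers",
    query_prompt="Embed this query for retrieval:",      # prepended to each query
    passage_prompt="Embed this passage for retrieval:",  # prepended to each document
    use_gpu=False,
)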
12 changes: 11 additions & 1 deletion haystack/nodes/retriever/dense.py
@@ -1453,6 +1453,8 @@ def __init__(
         max_seq_len: int = 512,
         model_format: Optional[str] = None,
         pooling_strategy: str = "reduce_mean",
+        query_prompt: Optional[str] = None,
+        passage_prompt: Optional[str] = None,
         emb_extraction_layer: int = -1,
         top_k: int = 10,
         progress_bar: bool = True,
@@ -1495,7 +1497,8 @@ def __init__(
                                  2. `reduce_mean` (sentence vector)
                                  3. `reduce_max` (sentence vector)
                                  4. `per_token` (individual token vectors)
+        :param query_prompt: Model instruction for embedding texts to be used as queries.
+        :param passage_prompt: Model instruction for embedding texts to be retrieved.
         :param emb_extraction_layer: Number of layer from which the embeddings shall be extracted (for farm / transformers models only).
                                      Default: -1 (very last layer).
         :param top_k: How many documents to return per query.
@@ -1550,6 +1553,8 @@ def __init__(
         self.max_seq_len = max_seq_len
         self.pooling_strategy = pooling_strategy
         self.emb_extraction_layer = emb_extraction_layer
+        self.query_prompt = query_prompt
+        self.passage_prompt = passage_prompt
         self.top_k = top_k
         self.progress_bar = progress_bar
         self.use_auth_token = use_auth_token
@@ -1830,6 +1835,8 @@ def embed_queries(self, queries: List[str]) -> np.ndarray:
         if isinstance(queries, str):
             queries = [queries]
         assert isinstance(queries, list), "Expecting a list of texts, i.e. create_embeddings(texts=['text1',...])"
+        if self.query_prompt:
+            queries = [self.query_prompt + " " + q for q in queries]
         return self.embedding_encoder.embed_queries(queries)

     def embed_documents(self, documents: List[Document]) -> np.ndarray:
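
The new query branch is plain string concatenation; the same behavior in isolation (a standalone sketch, not the Haystack API):

# Mirrors the new branch in embed_queries above.
query_prompt = "Embed this query for retrieval:"
queries = ["who invented the telephone?"]
queries = [query_prompt + " " + q for q in queries]
# queries == ["Embed this query for retrieval: who invented the telephone?"]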
@@ -1840,6 +1847,9 @@ def embed_documents(self, documents: List[Document]) -> np.ndarray:
         :return: Embeddings, one per input document, shape: (docs, embedding_dim)
         """
         documents = self._preprocess_documents(documents)
+        if self.passage_prompt:
+            for doc in documents:
+                doc.content = self.passage_prompt + " " + doc.content
         return self.embedding_encoder.embed_documents(documents)

     def _preprocess_documents(self, docs: List[Document]) -> List[Document]:
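The document side works the same way, mutating each Document's content; note the prefix is applied after _preprocess_documents has run, so preprocessing operates on the raw text. In isolation (a sketch; haystack.schema.Document is assumed as the v1 schema class):

from haystack.schema import Document

# Mirrors the new loop in embed_documents above.
passage_prompt = "Embed this passage for retrieval:"
docs = [Document(content="Bell patented the telephone in 1876.")]
for doc in docs:
    doc.content = passage_prompt + " " + doc.content
# docs[0].content == "Embed this passage for retrieval: Bell patented the telephone in 1876."
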
4 changes: 4 additions & 0 deletions (release note file)
@@ -0,0 +1,4 @@
+---
+features:
+  - |
+    Support for dense embedding instructions, used in retrieval models such as BGE and LLM-Embedder.
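
For illustration, a BGE-style configuration might look as follows (a sketch: the instruction string is BGE's published retrieval prompt, and BGE recommends instructing queries only; neither detail is fixed by this commit):

retriever = EmbeddingRetriever(
    document_store=document_store,
    embedding_model="BAAI/bge-base-en-v1.5",
    model_format="sentence_transformers",
    # BGE's published instruction, recommended for queries only.
    query_prompt="Represent this sentence for searching relevant passages:",
)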
9 changes: 9 additions & 0 deletions test/conftest.py
@@ -647,6 +647,15 @@ def get_retriever(retriever_type, document_store):
             model_format="sentence_transformers",
             use_gpu=False,
         )
+    elif retriever_type == "embedding_sbert_instructions":
+        retriever = EmbeddingRetriever(
+            document_store=document_store,
+            embedding_model="sentence-transformers/msmarco-distilbert-dot-v5",
+            model_format="sentence_transformers",
+            query_prompt="Embed this query for retrieval:",
+            passage_prompt="Embed this passage for retrieval:",
+            use_gpu=False,
+        )
     elif retriever_type == "retribert":
         retriever = EmbeddingRetriever(
             document_store=document_store, embedding_model="yjernite/retribert-base-uncased", use_gpu=False
22 changes: 22 additions & 0 deletions test/nodes/test_retriever.py
@@ -311,6 +311,28 @@ def test_dpr_embedding(document_store: BaseDocumentStore, retriever, docs_with_ids):
     assert isclose(embedding[0], expected_value, rel_tol=0.01)


+@pytest.mark.integration
+@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "weaviate", "pinecone"], indirect=True)
+@pytest.mark.parametrize("retriever", ["embedding_sbert_instructions"], indirect=True)
+def test_embedding_with_instructions(document_store: BaseDocumentStore, retriever, docs_with_ids):
+    document_store.return_embedding = True
+    document_store.write_documents(docs_with_ids)
+    document_store.update_embeddings(retriever=retriever)
+
+    docs = document_store.get_all_documents()
+    docs.sort(key=lambda d: d.id)
+
+    print([doc.id for doc in docs])
+
+    expected_values = [0.00484978, 0.02258789, 0.03414359, -0.01461711, 0.01784192]
+    for doc, expected_value in zip(docs, expected_values):
+        embedding = doc.embedding
+        # always normalize vector as faiss returns normalized vectors and other document stores do not
+        embedding /= np.linalg.norm(embedding)
+        assert len(embedding) == 768
+        assert isclose(embedding[0], expected_value, rel_tol=0.01)
+
+
 @pytest.mark.integration
 @pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "weaviate", "pinecone"], indirect=True)
 @pytest.mark.parametrize("retriever", ["retribert"], indirect=True)
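On the normalization step in the test above: FAISS returns unit-length vectors while the other stores return raw ones, so every embedding is normalized before its first component is compared. The operation in isolation (NumPy sketch):

import numpy as np

embedding = np.array([3.0, 4.0])
embedding = embedding / np.linalg.norm(embedding)  # unit length: [0.6, 0.8]
assert np.isclose(np.linalg.norm(embedding), 1.0)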
