diff --git a/haystack/document_stores/faiss.py b/haystack/document_stores/faiss.py index 258b874553..a666439d07 100644 --- a/haystack/document_stores/faiss.py +++ b/haystack/document_stores/faiss.py @@ -324,6 +324,18 @@ def write_documents( def _create_document_field_map(self) -> Dict: return {self.index: self.embedding_field} + def _validate_embedding_dimension(self, retriever: DenseRetriever, index: Optional[str] = None): + """ + Verify if the embedding dimension set in the document store and embedding dimension of the retriever are the same. + This check is done before calculating embeddings for all documents. + :param retriever: Retriever to use to get embeddings for text + :param index: Index name for which embeddings are to be updated. If set to None, the default self.index is used. + :return: None + """ + first_document = self.get_all_documents(index=index)[0] + embeddings = retriever.embed_documents([first_document]) + self._validate_embeddings_shape(embeddings=embeddings, num_documents=1, embedding_dim=self.embedding_dim) + def update_embeddings( self, retriever: DenseRetriever, @@ -373,6 +385,8 @@ def update_embeddings( logger.warning("Calling DocumentStore.update_embeddings() on an empty index") return + self._validate_embedding_dimension(retriever, index) + logger.info("Updating embeddings for %s docs...", document_count) vector_id = self.faiss_indexes[index].ntotal diff --git a/releasenotes/notes/verify-embed-dim-docustore-retriever-9ac88d8f0adc8a32.yaml b/releasenotes/notes/verify-embed-dim-docustore-retriever-9ac88d8f0adc8a32.yaml new file mode 100644 index 0000000000..f8b802b379 --- /dev/null +++ b/releasenotes/notes/verify-embed-dim-docustore-retriever-9ac88d8f0adc8a32.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + Add a check to verify that the embedding dimension set in the FAISS Document Store and retriever are equal before running embedding calculations. diff --git a/test/document_stores/test_faiss.py b/test/document_stores/test_faiss.py index b2c0171ae0..0ed24bc487 100644 --- a/test/document_stores/test_faiss.py +++ b/test/document_stores/test_faiss.py @@ -42,6 +42,21 @@ def test_index_mutual_exclusive_args(self, tmp_path): isolation_level="AUTOCOMMIT", ) + @pytest.mark.unit + def test_validate_embedding_dimension_unequal_embedding_dim(self, ds, documents): + retriever = MockDenseRetriever(document_store=ds, embedding_dim=384) + ds.write_documents(documents) + assert ds.get_document_count() == len(documents) + with pytest.raises(RuntimeError): + ds._validate_embedding_dimension(retriever) + + @pytest.mark.unit + def test_validate_embedding_dimension_equal_embedding_dim(self, ds, documents): + retriever = MockDenseRetriever(document_store=ds, embedding_dim=768) + ds.write_documents(documents) + assert ds.get_document_count() == len(documents) + ds._validate_embedding_dimension(retriever) + @pytest.mark.integration def test_delete_index(self, ds, documents): """Contrary to other Document Stores, FAISSDocumentStore doesn't raise if the index is empty"""