Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: add support for custom mapping in ElasticsearchDocumentStore #721

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def __init__(
self,
*,
hosts: Optional[Hosts] = None,
custom_mapping: Optional[Dict[str, Any]] = None,
index: str = "default",
embedding_similarity_function: Literal["cosine", "dot_product", "l2_norm", "max_inner_product"] = "cosine",
**kwargs,
Expand All @@ -82,6 +83,7 @@ def __init__(
[reference](https://elasticsearch-py.readthedocs.io/en/stable/api.html#module-elasticsearch)

:param hosts: List of hosts running the Elasticsearch client.
:param custom_mapping: Custom mapping for the index. If not provided, a default mapping will be used.
:param index: Name of index in Elasticsearch.
:param embedding_similarity_function: The similarity function used to compare Documents embeddings.
This parameter only takes effect if the index does not yet exist and is created.
Expand All @@ -98,29 +100,37 @@ def __init__(
)
self._index = index
self._embedding_similarity_function = embedding_similarity_function
self._custom_mapping = custom_mapping
self._kwargs = kwargs

# Check client connection, this will raise if not connected
self._client.info()

# configure mapping for the embedding field
mappings = {
"properties": {
"embedding": {"type": "dense_vector", "index": True, "similarity": embedding_similarity_function},
"content": {"type": "text"},
},
"dynamic_templates": [
{
"strings": {
"path_match": "*",
"match_mapping_type": "string",
"mapping": {
"type": "keyword",
},
if self._custom_mapping and not isinstance(self._custom_mapping, Dict):
msg = "custom_mapping must be a dictionary"
raise ValueError(msg)

if self._custom_mapping:
mappings = self._custom_mapping
else:
# Configure mapping for the embedding field if none is provided
mappings = {
"properties": {
"embedding": {"type": "dense_vector", "index": True, "similarity": embedding_similarity_function},
"content": {"type": "text"},
},
"dynamic_templates": [
{
"strings": {
"path_match": "*",
"match_mapping_type": "string",
"mapping": {
"type": "keyword",
},
}
}
}
],
}
],
}

# Create the index if it doesn't exist
if not self._client.indices.exists(index=index):
Expand All @@ -139,6 +149,7 @@ def to_dict(self) -> Dict[str, Any]:
return default_to_dict(
self,
hosts=self._hosts,
custom_mapping=self._custom_mapping,
index=self._index,
embedding_similarity_function=self._embedding_similarity_function,
**self._kwargs,
Expand Down
1 change: 1 addition & 0 deletions integrations/elasticsearch/tests/test_bm25_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def test_to_dict(_mock_elasticsearch_client):
"document_store": {
"init_parameters": {
"hosts": "some fake host",
"custom_mapping": None,
"index": "default",
"embedding_similarity_function": "cosine",
},
Expand Down
35 changes: 34 additions & 1 deletion integrations/elasticsearch/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import random
from typing import List
from unittest.mock import patch
from unittest.mock import Mock, patch

import pytest
from elasticsearch.exceptions import BadRequestError # type: ignore[import-not-found]
Expand All @@ -23,6 +23,7 @@ def test_to_dict(_mock_elasticsearch_client):
"type": "haystack_integrations.document_stores.elasticsearch.document_store.ElasticsearchDocumentStore",
"init_parameters": {
"hosts": "some hosts",
"custom_mapping": None,
"index": "default",
"embedding_similarity_function": "cosine",
},
Expand All @@ -35,13 +36,15 @@ def test_from_dict(_mock_elasticsearch_client):
"type": "haystack_integrations.document_stores.elasticsearch.document_store.ElasticsearchDocumentStore",
"init_parameters": {
"hosts": "some hosts",
"custom_mapping": None,
"index": "default",
"embedding_similarity_function": "cosine",
},
}
document_store = ElasticsearchDocumentStore.from_dict(data)
assert document_store._hosts == "some hosts"
assert document_store._index == "default"
assert document_store._custom_mapping is None
assert document_store._embedding_similarity_function == "cosine"


Expand Down Expand Up @@ -280,3 +283,33 @@ def test_write_documents_different_embedding_sizes_fail(self, document_store: El

with pytest.raises(DocumentStoreError):
document_store.write_documents(docs)

@patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch")
def test_init_with_custom_mapping(self, mock_elasticsearch):
custom_mapping = {
"properties": {
"embedding": {"type": "dense_vector", "index": True, "similarity": "dot_product"},
"content": {"type": "text"},
},
"dynamic_templates": [
{
"strings": {
"path_match": "*",
"match_mapping_type": "string",
"mapping": {
"type": "keyword",
},
}
}
],
}
mock_client = Mock(
indices=Mock(create=Mock(), exists=Mock(return_value=False)),
)
mock_elasticsearch.return_value = mock_client

ElasticsearchDocumentStore(hosts="some hosts", custom_mapping=custom_mapping)
mock_client.indices.create.assert_called_once_with(
index="default",
mappings=custom_mapping,
)
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def test_to_dict(_mock_elasticsearch_client):
"document_store": {
"init_parameters": {
"hosts": "some fake host",
"custom_mapping": None,
"index": "default",
"embedding_similarity_function": "cosine",
},
Expand Down