Skip to content

Commit

Permalink
fix: serialize the path to the local db (#506)
Browse files Browse the repository at this point in the history
* serialize the path to the local db

* fix tests

* fix tests
  • Loading branch information
masci authored Feb 29, 2024
1 parent cc0f2bc commit 979a812
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 31 deletions.
4 changes: 2 additions & 2 deletions integrations/chroma/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ cover/
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
*.sqlite3
*.sqlite3-journal

# Flask stuff:
instance/
Expand Down
3 changes: 2 additions & 1 deletion integrations/chroma/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ module = [
"chromadb.*",
"haystack.*",
"haystack_integrations.*",
"pytest.*"
"pytest.*",
"numpy.*"
]
ignore_missing_imports = true
Original file line number Diff line number Diff line change
Expand Up @@ -73,18 +73,6 @@ def run(

return {"documents": self.document_store.search([query], top_k)[0]}

def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.
:returns:
Dictionary with serialized data.
"""
d = default_to_dict(self, filters=self.filters, top_k=self.top_k)
d["init_parameters"]["document_store"] = self.document_store.to_dict()

return d

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "ChromaQueryTextRetriever":
"""
Expand All @@ -99,6 +87,20 @@ def from_dict(cls, data: Dict[str, Any]) -> "ChromaQueryTextRetriever":
data["init_parameters"]["document_store"] = document_store
return default_from_dict(cls, data)

def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.
:returns:
Dictionary with serialized data.
"""
return default_to_dict(
self,
filters=self.filters,
top_k=self.top_k,
document_store=self.document_store.to_dict(),
)


@component
class ChromaEmbeddingRetriever(ChromaQueryTextRetriever):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import chromadb
import numpy as np
from chromadb.api.types import GetResult, QueryResult, validate_where, validate_where_document
from haystack import default_from_dict, default_to_dict
from haystack.dataclasses import Document
from haystack.document_stores.types import DuplicatePolicy

Expand Down Expand Up @@ -50,6 +51,7 @@ def __init__(
self._collection_name = collection_name
self._embedding_function = embedding_function
self._embedding_function_params = embedding_function_params
self._persist_path = persist_path
# Create the client instance
if persist_path is None:
self._chroma_client = chromadb.Client()
Expand Down Expand Up @@ -252,20 +254,22 @@ def from_dict(cls, data: Dict[str, Any]) -> "ChromaDocumentStore":
:returns:
Deserialized component.
"""
return ChromaDocumentStore(**data)
return default_from_dict(cls, data)

def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.
:returns:
Dictionary with serialized data.
Dictionary with serialized data.
"""
return {
"collection_name": self._collection_name,
"embedding_function": self._embedding_function,
return default_to_dict(
self,
collection_name=self._collection_name,
embedding_function=self._embedding_function,
persist_path=self._persist_path,
**self._embedding_function_params,
}
)

@staticmethod
def _normalize_filters(filters: Dict[str, Any]) -> Tuple[List[str], Dict[str, Any], Dict[str, Any]]:
Expand Down
20 changes: 16 additions & 4 deletions integrations/chroma/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,16 +97,28 @@ def test_to_json(self, request):
)
ds_dict = ds.to_dict()
assert ds_dict == {
"collection_name": request.node.name,
"embedding_function": "HuggingFaceEmbeddingFunction",
"api_key": "1234567890",
"type": "haystack_integrations.document_stores.chroma.document_store.ChromaDocumentStore",
"init_parameters": {
"collection_name": "test_to_json",
"embedding_function": "HuggingFaceEmbeddingFunction",
"persist_path": None,
"api_key": "1234567890",
},
}

@pytest.mark.integration
def test_from_json(self):
collection_name = "test_collection"
function_name = "HuggingFaceEmbeddingFunction"
ds_dict = {"collection_name": collection_name, "embedding_function": function_name, "api_key": "1234567890"}
ds_dict = {
"type": "haystack_integrations.document_stores.chroma.document_store.ChromaDocumentStore",
"init_parameters": {
"collection_name": "test_collection",
"embedding_function": "HuggingFaceEmbeddingFunction",
"persist_path": None,
"api_key": "1234567890",
},
}

ds = ChromaDocumentStore.from_dict(ds_dict)
assert ds._collection_name == collection_name
Expand Down
21 changes: 15 additions & 6 deletions integrations/chroma/tests/test_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,13 @@ def test_retriever_to_json(request):
"filters": {"foo": "bar"},
"top_k": 99,
"document_store": {
"collection_name": request.node.name,
"embedding_function": "HuggingFaceEmbeddingFunction",
"api_key": "1234567890",
"type": "haystack_integrations.document_stores.chroma.document_store.ChromaDocumentStore",
"init_parameters": {
"collection_name": "test_retriever_to_json",
"embedding_function": "HuggingFaceEmbeddingFunction",
"persist_path": None,
"api_key": "1234567890",
},
},
},
}
Expand All @@ -34,15 +38,20 @@ def test_retriever_from_json(request):
"filters": {"bar": "baz"},
"top_k": 42,
"document_store": {
"collection_name": request.node.name,
"embedding_function": "HuggingFaceEmbeddingFunction",
"api_key": "1234567890",
"type": "haystack_integrations.document_stores.chroma.document_store.ChromaDocumentStore",
"init_parameters": {
"collection_name": "test_retriever_from_json",
"embedding_function": "HuggingFaceEmbeddingFunction",
"persist_path": ".",
"api_key": "1234567890",
},
},
},
}
retriever = ChromaQueryTextRetriever.from_dict(data)
assert retriever.document_store._collection_name == request.node.name
assert retriever.document_store._embedding_function == "HuggingFaceEmbeddingFunction"
assert retriever.document_store._embedding_function_params == {"api_key": "1234567890"}
assert retriever.document_store._persist_path == "."
assert retriever.filters == {"bar": "baz"}
assert retriever.top_k == 42

0 comments on commit 979a812

Please sign in to comment.