Skip to content

Commit

Permalink
feat: Chroma - defer the DB connection (#1107)
Browse files Browse the repository at this point in the history
* defer DB from chroma

* added ensure_initialized

* addressed comments

* simplification and linting

* refinements to docstrings

---------

Co-authored-by: Stefano Fiorucci <[email protected]>
  • Loading branch information
alperkaya and anakin87 authored Sep 30, 2024
1 parent 907c10b commit 242e3c5
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 41 deletions.
2 changes: 2 additions & 0 deletions integrations/chroma/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ ignore = [
"PLR0915",
# Ignore unused params
"ARG002",
# Allow assertions
"S101",
]
unfixable = [
# Don't touch unused imports
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,8 @@ def __init__(
**embedding_function_params,
):
"""
Initializes the store. The __init__ constructor is not part of the Store Protocol
and the signature can be customized to your needs. For example, parameters needed
to set up a database client would be passed to this method.
Creates a new ChromaDocumentStore instance.
It is meant to be connected to a Chroma collection.
Note: for the component to be part of a serializable pipeline, the __init__
parameters must be serializable, reason why we use a registry to configure the
Expand All @@ -65,7 +64,6 @@ def __init__(
:param metadata: a dictionary of chromadb collection parameters passed directly to chromadb's client
method `create_collection`. If it contains the key `"hnsw:space"`, the value will take precedence over the
`distance_function` parameter above.
:param embedding_function_params: additional parameters to pass to the embedding function.
"""

Expand All @@ -79,60 +77,70 @@ def __init__(
# Store the params for marshalling
self._collection_name = collection_name
self._embedding_function = embedding_function
self._embedding_func = get_embedding_function(embedding_function, **embedding_function_params)
self._embedding_function_params = embedding_function_params
self._distance_function = distance_function
self._metadata = metadata
self._collection = None

self._persist_path = persist_path
self._host = host
self._port = port

# Create the client instance
if persist_path and (host or port is not None):
error_message = (
"You must specify `persist_path` for local persistent storage or, "
"alternatively, `host` and `port` for remote HTTP client connection. "
"You cannot specify both options."
)
raise ValueError(error_message)
if host and port is not None:
# Remote connection via HTTP client
self._chroma_client = chromadb.HttpClient(
host=host,
port=port,
)
elif persist_path is None:
# In-memory storage
self._chroma_client = chromadb.Client()
else:
# Local persistent storage
self._chroma_client = chromadb.PersistentClient(path=persist_path)
self._initialized = False

embedding_func = get_embedding_function(embedding_function, **embedding_function_params)
def _ensure_initialized(self):
if not self._initialized:
# Create the client instance
if self._persist_path and (self._host or self._port is not None):
error_message = (
"You must specify `persist_path` for local persistent storage or, "
"alternatively, `host` and `port` for remote HTTP client connection. "
"You cannot specify both options."
)
raise ValueError(error_message)
if self._host and self._port is not None:
# Remote connection via HTTP client
client = chromadb.HttpClient(
host=self._host,
port=self._port,
)
elif self._persist_path is None:
# In-memory storage
client = chromadb.Client()
else:
# Local persistent storage
client = chromadb.PersistentClient(path=self._persist_path)

metadata = metadata or {}
if "hnsw:space" not in metadata:
metadata["hnsw:space"] = distance_function
self._metadata = self._metadata or {}
if "hnsw:space" not in self._metadata:
self._metadata["hnsw:space"] = self._distance_function

if collection_name in [c.name for c in self._chroma_client.list_collections()]:
self._collection = self._chroma_client.get_collection(collection_name, embedding_function=embedding_func)
if self._collection_name in [c.name for c in client.list_collections()]:
self._collection = client.get_collection(self._collection_name, embedding_function=self._embedding_func)

if metadata != self._collection.metadata:
logger.warning(
"Collection already exists. The `distance_function` and `metadata` parameters will be ignored."
if self._metadata != self._collection.metadata:
logger.warning(
"Collection already exists. "
"The `distance_function` and `metadata` parameters will be ignored."
)
else:
self._collection = client.create_collection(
name=self._collection_name,
metadata=self._metadata,
embedding_function=self._embedding_func,
)
else:
self._collection = self._chroma_client.create_collection(
name=collection_name,
metadata=metadata,
embedding_function=embedding_func,
)

self._initialized = True

def count_documents(self) -> int:
    """
    Counts the documents currently held in the store's collection.

    The collection is set up on demand by `_ensure_initialized` before it is queried.

    :returns: the number of documents present in the document store.
    """
    self._ensure_initialized()
    collection = self._collection
    # Narrow the Optional for type checkers; initialization has just run.
    assert collection is not None
    return collection.count()

def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]:
Expand Down Expand Up @@ -197,6 +205,9 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc
:param filters: the filters to apply to the document list.
:returns: a list of Documents that match the given filters.
"""
self._ensure_initialized()
assert self._collection is not None

if filters:
chroma_filter = _convert_filters(filters)
kwargs: Dict[str, Any] = {"where": chroma_filter.where}
Expand Down Expand Up @@ -227,6 +238,9 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
:returns:
The number of documents written
"""
self._ensure_initialized()
assert self._collection is not None

for doc in documents:
if not isinstance(doc, Document):
msg = "param 'documents' must contain a list of objects of type Document"
Expand Down Expand Up @@ -280,8 +294,11 @@ def delete_documents(self, document_ids: List[str]) -> None:
"""
Deletes all documents with a matching document_ids from the document store.
:param document_ids: the object_ids to delete
:param document_ids: the document ids to delete
"""
self._ensure_initialized()
assert self._collection is not None

self._collection.delete(ids=document_ids)

def search(self, queries: List[str], top_k: int, filters: Optional[Dict[str, Any]] = None) -> List[List[Document]]:
Expand All @@ -292,6 +309,9 @@ def search(self, queries: List[str], top_k: int, filters: Optional[Dict[str, Any
:param filters: a dictionary of filters to apply to the search. Accepts filters in haystack format.
:returns: matching documents for each query.
"""
self._ensure_initialized()
assert self._collection is not None

if filters is None:
results = self._collection.query(
query_texts=queries,
Expand Down Expand Up @@ -323,6 +343,9 @@ def search_embeddings(
:returns: a list of lists of documents that match the given filters.
"""
self._ensure_initialized()
assert self._collection is not None

if filters is None:
results = self._collection.query(
query_embeddings=query_embeddings,
Expand Down
10 changes: 9 additions & 1 deletion integrations/chroma/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ def test_invalid_initialization_both_host_and_persist_path(self):
Test that providing both host and persist_path raises an error.
"""
with pytest.raises(ValueError):
ChromaDocumentStore(persist_path="./path/to/local/store", host="localhost")
store = ChromaDocumentStore(persist_path="./path/to/local/store", host="localhost")
store._ensure_initialized()

def test_delete_empty(self, document_store: ChromaDocumentStore):
"""
Expand Down Expand Up @@ -207,6 +208,7 @@ def test_same_collection_name_reinitialization(self):
@pytest.mark.integration
def test_distance_metric_initialization(self):
store = ChromaDocumentStore("test_2", distance_function="cosine")
store._ensure_initialized()
assert store._collection.metadata["hnsw:space"] == "cosine"

with pytest.raises(ValueError):
Expand All @@ -215,9 +217,11 @@ def test_distance_metric_initialization(self):
@pytest.mark.integration
def test_distance_metric_reinitialization(self, caplog):
store = ChromaDocumentStore("test_4", distance_function="cosine")
store._ensure_initialized()

with caplog.at_level(logging.WARNING):
new_store = ChromaDocumentStore("test_4", distance_function="ip")
new_store._ensure_initialized()

assert (
"Collection already exists. The `distance_function` and `metadata` parameters will be ignored."
Expand All @@ -238,6 +242,8 @@ def test_metadata_initialization(self, caplog):
"hnsw:M": 103,
},
)
store._ensure_initialized()

assert store._collection.metadata["hnsw:space"] == "ip"
assert store._collection.metadata["hnsw:search_ef"] == 101
assert store._collection.metadata["hnsw:construction_ef"] == 102
Expand All @@ -254,6 +260,8 @@ def test_metadata_initialization(self, caplog):
},
)

new_store._ensure_initialized()

assert (
"Collection already exists. The `distance_function` and `metadata` parameters will be ignored."
in caplog.text
Expand Down

0 comments on commit 242e3c5

Please sign in to comment.