Skip to content

Commit

Permalink
feat: Chroma - allow remote HTTP connection (#1094)
Browse files Browse the repository at this point in the history
* add http client

* remove old code

* update order

* Apply suggestions from code review

Co-authored-by: Stefano Fiorucci <[email protected]>

* add testcases

* run chroma db in the bg

* fix line too long

* fix testcase

* support chroma bg on windows

* fix chroma on win

* chroma fix for powershell on win

* simplification

* fix wrong skipif

* linting

---------

Co-authored-by: Stefano Fiorucci <[email protected]>
  • Loading branch information
alperkaya and anakin87 authored Sep 25, 2024
1 parent bca32be commit 3be7882
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 6 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/chroma.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ jobs:
if: matrix.python-version == '3.9' && runner.os == 'Linux'
run: hatch run docs

- name: Run Chroma server on Linux/macOS
if: matrix.os != 'windows-latest'
run: hatch run chroma run &

- name: Run tests
run: hatch run cov-retry

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ def __init__(
collection_name: str = "documents",
embedding_function: str = "default",
persist_path: Optional[str] = None,
host: Optional[str] = None,
port: Optional[int] = None,
distance_function: Literal["l2", "cosine", "ip"] = "l2",
metadata: Optional[dict] = None,
**embedding_function_params,
Expand All @@ -48,7 +50,10 @@ def __init__(
:param collection_name: the name of the collection to use in the database.
:param embedding_function: the name of the embedding function to use to embed the query
:param persist_path: where to store the database. If None, the database will be `in-memory`.
:param persist_path: Path for local persistent storage. Cannot be used in combination with `host` and `port`.
If none of `persist_path`, `host`, and `port` is specified, the database will be `in-memory`.
:param host: The host address for the remote Chroma HTTP client connection. Cannot be used with `persist_path`.
:param port: The port number for the remote Chroma HTTP client connection. Cannot be used with `persist_path`.
:param distance_function: The distance metric for the embedding space.
- `"l2"` computes the Euclidean (straight-line) distance between vectors,
where smaller scores indicate more similarity.
Expand All @@ -75,12 +80,31 @@ def __init__(
self._collection_name = collection_name
self._embedding_function = embedding_function
self._embedding_function_params = embedding_function_params
self._persist_path = persist_path
self._distance_function = distance_function

self._persist_path = persist_path
self._host = host
self._port = port

# Create the client instance
if persist_path is None:
if persist_path and (host or port is not None):
error_message = (
"You must specify `persist_path` for local persistent storage or, "
"alternatively, `host` and `port` for remote HTTP client connection. "
"You cannot specify both options."
)
raise ValueError(error_message)
if host and port is not None:
# Remote connection via HTTP client
self._chroma_client = chromadb.HttpClient(
host=host,
port=port,
)
elif persist_path is None:
# In-memory storage
self._chroma_client = chromadb.Client()
else:
# Local persistent storage
self._chroma_client = chromadb.PersistentClient(path=persist_path)

embedding_func = get_embedding_function(embedding_function, **embedding_function_params)
Expand Down Expand Up @@ -341,6 +365,8 @@ def to_dict(self) -> Dict[str, Any]:
collection_name=self._collection_name,
embedding_function=self._embedding_function,
persist_path=self._persist_path,
host=self._host,
port=self._port,
distance_function=self._distance_function,
**self._embedding_function_params,
)
Expand Down
44 changes: 41 additions & 3 deletions integrations/chroma/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import uuid
from typing import List
from unittest import mock
import sys

import numpy as np
import pytest
Expand Down Expand Up @@ -66,6 +67,39 @@ def assert_documents_are_equal(self, received: List[Document], expected: List[Do
assert doc_received.content == doc_expected.content
assert doc_received.meta == doc_expected.meta

def test_init_in_memory(self):
store = ChromaDocumentStore()

assert store._persist_path is None
assert store._host is None
assert store._port is None

def test_init_persistent_storage(self):
store = ChromaDocumentStore(persist_path="./path/to/local/store")

assert store._persist_path == "./path/to/local/store"
assert store._host is None
assert store._port is None

@pytest.mark.integration
@pytest.mark.skipif(
sys.platform == "win32",
reason="This test requires running the Chroma server. For simplicity, we don't run it on Windows.",
)
def test_init_http_connection(self):
store = ChromaDocumentStore(host="localhost", port=8000)

assert store._persist_path is None
assert store._host == "localhost"
assert store._port == 8000

def test_invalid_initialization_both_host_and_persist_path(self):
"""
Test that providing both host and persist_path raises an error.
"""
with pytest.raises(ValueError):
ChromaDocumentStore(persist_path="./path/to/local/store", host="localhost")

def test_delete_empty(self, document_store: ChromaDocumentStore):
"""
Deleting a non-existing document should not raise with Chroma
Expand Down Expand Up @@ -125,24 +159,26 @@ def test_write_documents_unsupported_meta_values(self, document_store: ChromaDoc
assert written_docs[2].meta == {"ok": 123}

@pytest.mark.integration
def test_to_json(self, request):
def test_to_dict(self, request):
ds = ChromaDocumentStore(
collection_name=request.node.name, embedding_function="HuggingFaceEmbeddingFunction", api_key="1234567890"
)
ds_dict = ds.to_dict()
assert ds_dict == {
"type": "haystack_integrations.document_stores.chroma.document_store.ChromaDocumentStore",
"init_parameters": {
"collection_name": "test_to_json",
"collection_name": "test_to_dict",
"embedding_function": "HuggingFaceEmbeddingFunction",
"persist_path": None,
"host": None,
"port": None,
"api_key": "1234567890",
"distance_function": "l2",
},
}

@pytest.mark.integration
def test_from_json(self):
def test_from_dict(self):
collection_name = "test_collection"
function_name = "HuggingFaceEmbeddingFunction"
ds_dict = {
Expand All @@ -151,6 +187,8 @@ def test_from_json(self):
"collection_name": "test_collection",
"embedding_function": "HuggingFaceEmbeddingFunction",
"persist_path": None,
"host": None,
"port": None,
"api_key": "1234567890",
"distance_function": "l2",
},
Expand Down
2 changes: 2 additions & 0 deletions integrations/chroma/tests/test_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def test_retriever_to_json(request):
"collection_name": "test_retriever_to_json",
"embedding_function": "HuggingFaceEmbeddingFunction",
"persist_path": None,
"host": None,
"port": None,
"api_key": "1234567890",
"distance_function": "l2",
},
Expand Down

0 comments on commit 3be7882

Please sign in to comment.