From d1685dc00f9f135870830b5ebfcedbc1427ef39e Mon Sep 17 00:00:00 2001 From: Will Tai Date: Tue, 12 Mar 2024 17:57:40 +0000 Subject: [PATCH] Changed test to mock neo4j driver's execute_driver instead of the client's _database_query() Changed test_client.py to have mocks at execute_query level Changed precommit config to include ruff linting and formatting Update pre-commit Corrected cypher query error test and dimensions constraint in CreateIndexModel data type ruff formatting --- examples/similarity_search_for_text.py | 3 +- examples/similarity_search_for_vector.py | 2 +- src/neo4j_genai/client.py | 23 +-- src/neo4j_genai/types.py | 4 +- tests/conftest.py | 14 +- tests/test_client.py | 182 +++++++++++------------ 6 files changed, 112 insertions(+), 116 deletions(-) diff --git a/examples/similarity_search_for_text.py b/examples/similarity_search_for_text.py index 6fc672784..20ea316b1 100644 --- a/examples/similarity_search_for_text.py +++ b/examples/similarity_search_for_text.py @@ -26,6 +26,7 @@ def embed_query(self, text: str) -> List[float]: # Initialize the client client = GenAIClient(driver, embedder) +client.drop_index(INDEX_NAME) # Creating the index client.create_index( INDEX_NAME, @@ -46,7 +47,7 @@ def embed_query(self, text: str) -> List[float]: parameters = { "vector": vector, } -client.database_query(insert_query, params=parameters) +client._database_query(insert_query, params=parameters) # Perform the similarity search for a text query query_text = "hello world" diff --git a/examples/similarity_search_for_vector.py b/examples/similarity_search_for_vector.py index 157dc29b4..159b81ae5 100644 --- a/examples/similarity_search_for_vector.py +++ b/examples/similarity_search_for_vector.py @@ -35,7 +35,7 @@ parameters = { "vector": vector, } -client.database_query(insert_query, params=parameters) +client._database_query(insert_query, params=parameters) # Perform the similarity search for a vector query query_vector = [random() for _ in range(DIMENSION)] diff --git a/src/neo4j_genai/client.py b/src/neo4j_genai/client.py index f1eb51021..7676c570c 100644 --- a/src/neo4j_genai/client.py +++ b/src/neo4j_genai/client.py @@ -29,7 +29,7 @@ def _verify_version(self) -> None: indexing. Raises a ValueError if the connected Neo4j version is not supported. """ - version = self.database_query("CALL dbms.components()")[0]["versions"][0] + version = self._database_query("CALL dbms.components()")[0]["versions"][0] if "aura" in version: version_tuple = ( *tuple(map(int, version.split("-")[0].split("."))), @@ -45,7 +45,9 @@ def _verify_version(self) -> None: "Version index is only supported in Neo4j version 5.11 or greater" ) - def database_query(self, query: str, params: Dict = {}) -> List[Dict[str, Any]]: + def _database_query( + self, query: str, params: Optional[Dict[str, Any]] = None + ) -> List[Dict[str, Any]]: """ This method sends a Cypher query to the connected Neo4j database and returns the results as a list of dictionaries. @@ -57,12 +59,11 @@ def database_query(self, query: str, params: Dict = {}) -> List[Dict[str, Any]]: Returns: List[Dict[str, Any]]: List of dictionaries containing the query results. """ - with self.driver.session() as session: - try: - data = session.run(query, params) - return [r.data() for r in data] - except CypherSyntaxError as e: - raise ValueError(f"Cypher Statement is not valid\n{e}") + try: + records, _, _ = self.driver.execute_query(query, params) + return records + except CypherSyntaxError as e: + raise ValueError(f"Cypher Statement is not valid\n{e}") def create_index( self, @@ -109,7 +110,7 @@ def create_index( "toInteger($dimensions)," "$similarity_fn )" ) - self.database_query(query, params=index_data.model_dump()) + self._database_query(query, params=index_data.model_dump()) def drop_index(self, name: str) -> None: """ @@ -124,7 +125,7 @@ def drop_index(self, name: str) -> None: parameters = { "name": name, } - self.database_query(query, params=parameters) + self._database_query(query, params=parameters) def similarity_search( self, @@ -176,7 +177,7 @@ def similarity_search( CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) YIELD node, score """ - records = self.database_query(db_query_string, params=parameters) + records = self._database_query(db_query_string, params=parameters) try: return [ diff --git a/src/neo4j_genai/types.py b/src/neo4j_genai/types.py index e56f2013f..3182a49c3 100644 --- a/src/neo4j_genai/types.py +++ b/src/neo4j_genai/types.py @@ -1,5 +1,5 @@ from typing import List, Any, Literal, Optional -from pydantic import BaseModel, PositiveInt, Field, model_validator +from pydantic import BaseModel, PositiveInt, model_validator class Neo4jRecord(BaseModel): @@ -15,7 +15,7 @@ class CreateIndexModel(BaseModel): name: str label: str property: str - dimensions: int = Field(ge=1, le=2048) + dimensions: int = PositiveInt similarity_fn: Literal["euclidean", "cosine"] diff --git a/tests/conftest.py b/tests/conftest.py index 729e3547c..a57c62521 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,13 +1,13 @@ import pytest from neo4j_genai import GenAIClient -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, patch from typing import List from neo4j_genai.embedder import Embedder @pytest.fixture def driver(): - return Mock() + return MagicMock() @pytest.fixture @@ -20,8 +20,14 @@ def client(_verify_version_mock, driver): @patch("neo4j_genai.GenAIClient._verify_version") def client_with_embedder(_verify_version_mock, driver): class CustomEmbedder(Embedder): + def __init__(self): + self.dimension = 1536 + def embed_query(self, text: str) -> List[float]: - return [1.0 for _ in range(1536)] + return [1.0 for _ in range(self.dimension)] + + def set_dimension(self, dimension: int): + self.dimension = dimension embedder = CustomEmbedder() - return GenAIClient(driver, embedder) + return GenAIClient(driver, embedder), embedder diff --git a/tests/test_client.py b/tests/test_client.py index a82e917f5..809a47ac2 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,49 +1,45 @@ import pytest from neo4j_genai import GenAIClient -from unittest.mock import Mock, patch from neo4j.exceptions import CypherSyntaxError -@patch( - "neo4j_genai.GenAIClient.database_query", - return_value=[{"versions": ["5.11-aura"]}], -) -def test_genai_client_supported_aura_version(mock_database_query, driver): - GenAIClient(driver) - mock_database_query.assert_called_once() +def test_genai_client_supported_aura_version(driver): + driver.execute_query.return_value = [[{"versions": ["5.11-aura"]}], None, None] + + GenAIClient(driver=driver) -@patch( - "neo4j_genai.GenAIClient.database_query", - return_value=[{"versions": ["5.3-aura"]}], -) def test_genai_client_no_supported_aura_version(driver): - with pytest.raises(ValueError): - GenAIClient(driver) + driver.execute_query.return_value = [[{"versions": ["5.3-aura"]}], None, None] + + with pytest.raises(ValueError) as excinfo: + GenAIClient(driver=driver) + + assert "Version index is only supported in Neo4j version 5.11 or greater" in str( + excinfo + ) -@patch( - "neo4j_genai.GenAIClient.database_query", - return_value=[{"versions": ["5.11.5"]}], -) -def test_genai_client_supported_version(mock_database_query, driver): - GenAIClient(driver) - mock_database_query.assert_called_once() +def test_genai_client_supported_version(driver): + driver.execute_query.return_value = [[{"versions": ["5.11.5"]}], None, None] + + GenAIClient(driver=driver) -@patch( - "neo4j_genai.GenAIClient.database_query", - return_value=[{"versions": ["4.3.5"]}], -) def test_genai_client_no_supported_version(driver): - with pytest.raises(ValueError): - GenAIClient(driver) + driver.execute_query.return_value = [[{"versions": ["4.3.5"]}], None, None] + with pytest.raises(ValueError) as excinfo: + GenAIClient(driver=driver) -@patch("neo4j_genai.GenAIClient.database_query") -def test_create_index_happy_path(mock_database_query, client): - client.create_index("my-index", "People", "name", 2048, "cosine") - query = ( + assert "Version index is only supported in Neo4j version 5.11 or greater" in str( + excinfo + ) + + +def test_create_index_happy_path(driver, client): + driver.execute_query.return_value = [None, None, None] + create_query = ( "CALL db.index.vector.createNodeIndex(" "$name," "$label," @@ -51,9 +47,12 @@ def test_create_index_happy_path(mock_database_query, client): "toInteger($dimensions)," "$similarity_fn )" ) - mock_database_query.assert_called_once_with( - query, - params={ + + client.create_index("my-index", "People", "name", 2048, "cosine") + + driver.execute_query.assert_called_once_with( + create_query, + { "name": "my-index", "label": "People", "property": "name", @@ -63,11 +62,6 @@ def test_create_index_happy_path(mock_database_query, client): ) -def test_create_index_too_big_dimension(client): - with pytest.raises(ValueError): - client.create_index("my-index", "People", "name", 5024, "cosine") - - def test_create_index_validation_error_dimensions(client): with pytest.raises(ValueError) as excinfo: client.create_index("my-index", "People", "name", "no-dim", "cosine") @@ -80,69 +74,56 @@ def test_create_index_validation_error_similarity_fn(client): assert "Error for inputs to create_index" in str(excinfo) -@patch("neo4j_genai.GenAIClient.database_query") -def test_drop_index(mock_database_query, client): - client.drop_index("my-index") +def test_drop_index(driver, client): + driver.execute_query.return_value = [None, None, None] + drop_query = "DROP INDEX $name" - query = "DROP INDEX $name" + client.drop_index("my-index") - mock_database_query.assert_called_with(query, params={"name": "my-index"}) + driver.execute_query.assert_called_once_with( + drop_query, + {"name": "my-index"}, + ) def test_database_query_happy(client, driver): - class Session: - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - pass - - def run(self, query, params): - m_list = [] - for i in range(3): - mock = Mock() - mock.data.return_value = i - m_list.append(mock) + expected_db_result = [0, 1, 2] + driver.execute_query.return_value = [expected_db_result, None, None] - return m_list + res = client._database_query("MATCH (p:$label) RETURN p", {"label": "People"}) - driver.session = Session - res = client.database_query("MATCH (p:$label) RETURN p", {"label": "People"}) - assert res == [0, 1, 2] + assert res == expected_db_result def test_database_query_cypher_error(client, driver): - class Session: - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - pass - - def run(self, query, params): - raise CypherSyntaxError + driver.execute_query.side_effect = CypherSyntaxError - driver.session = Session + with pytest.raises(ValueError) as excinfo: + client._database_query("MATCH (p:$label) RETURN p", {"label": "People"}) - with pytest.raises(ValueError): - client.database_query("MATCH (p:$label) RETURN p", {"label": "People"}) + assert "Cypher Statement is not valid" in str(excinfo) -@patch("neo4j_genai.GenAIClient.database_query") -def test_similarity_search_vector_happy_path(mock_database_query, client): +def test_similarity_search_vector_happy_path(driver, client): index_name = "my-index" - query_vector = [1.1, 2.2, 3.3] + dimensions = 1536 + query_vector = [1.0 for _ in range(dimensions)] top_k = 5 - - client.similarity_search(name=index_name, query_vector=query_vector, top_k=top_k) - - query = """ + driver.execute_query.return_value = [ + [{"node": "dummy-node", "score": 1.0}], + None, + None, + ] + search_query = """ CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) YIELD node, score """ - mock_database_query.assert_called_once_with( - query, - params={ + + client.similarity_search(name=index_name, query_vector=query_vector, top_k=top_k) + + driver.execute_query.assert_called_once_with( + search_query, + { "index_name": index_name, "top_k": top_k, "query_vector": query_vector, @@ -150,24 +131,29 @@ def test_similarity_search_vector_happy_path(mock_database_query, client): ) -@patch("neo4j_genai.GenAIClient.database_query") -def test_similarity_search_text_happy_path(mock_database_query, client_with_embedder): +def test_similarity_search_text_happy_path(driver, client_with_embedder): + client, embedder = client_with_embedder index_name = "my-index" query_text = "may thy knife chip and shatter" - query_vector = [1.0 for _ in range(1536)] + dimensions = 1536 + query_vector = [1.0 for _ in range(dimensions)] top_k = 5 - - client_with_embedder.similarity_search( - name=index_name, query_text=query_text, top_k=top_k - ) - - query = """ + driver.execute_query.return_value = [ + [{"node": "dummy-node", "score": 1.0}], + None, + None, + ] + embedder.set_dimension(dimensions) + search_query = """ CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) YIELD node, score """ - mock_database_query.assert_called_once_with( - query, - params={ + + client.similarity_search(name=index_name, query_text=query_text, top_k=top_k) + + driver.execute_query.assert_called_once_with( + search_query, + { "index_name": index_name, "top_k": top_k, "query_vector": query_vector, @@ -180,7 +166,7 @@ def test_similarity_search_missing_embedder_for_text(client): query_text = "may thy knife chip and shatter" top_k = 5 - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Embedding method required for text query"): client.similarity_search(name=index_name, query_text=query_text, top_k=top_k) @@ -190,7 +176,9 @@ def test_similarity_search_both_text_and_vector(client): query_vector = [1.1, 2.2, 3.3] top_k = 5 - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="You must provide exactly one of query_vector or query_text." + ): client.similarity_search( name=index_name, query_text=query_text,