From 29ab1f7b68cfb972a02c3a87692b9a4f73952cb1 Mon Sep 17 00:00:00 2001 From: Will Tai Date: Tue, 27 Feb 2024 12:01:45 +0000 Subject: [PATCH 01/23] first commit --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index e69de29b..306936a2 100644 --- a/README.md +++ b/README.md @@ -0,0 +1 @@ +# neo4j-genai-python From ff98901e2cf5de79dd792d2c2aae5db0c68cac42 Mon Sep 17 00:00:00 2001 From: Will Tai Date: Wed, 28 Feb 2024 17:48:50 +0000 Subject: [PATCH 02/23] Adds first GenAIClient --- .gitignore | 2 + neo4j_genai_python/src/__init__.py | 0 neo4j_genai_python/src/client.py | 138 ++++++++++ poetry.lock | 406 +++++++++++++++++++++++++++++ pyproject.toml | 36 ++- tests/__init__.py | 0 tests/test_client.py | 10 + 7 files changed, 591 insertions(+), 1 deletion(-) create mode 100644 neo4j_genai_python/src/__init__.py create mode 100644 neo4j_genai_python/src/client.py create mode 100644 tests/__init__.py create mode 100644 tests/test_client.py diff --git a/.gitignore b/.gitignore index 849ddff3..7264d9de 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,3 @@ dist/ +__pycache__/ +*.py[cod] diff --git a/neo4j_genai_python/src/__init__.py b/neo4j_genai_python/src/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/neo4j_genai_python/src/client.py b/neo4j_genai_python/src/client.py new file mode 100644 index 00000000..1e04c623 --- /dev/null +++ b/neo4j_genai_python/src/client.py @@ -0,0 +1,138 @@ +import neo4j + +from typing import List, Dict, Any, Optional +from neo4j import Driver, GraphDatabase +from neo4j.exceptions import CypherSyntaxError + +from abc import ABC, abstractmethod + + +class Embeddings(ABC): + """Interface for embedding models.""" + + @abstractmethod + def embed_query(self, text: str) -> List[float]: + """Embed query text.""" + + +class GenAIClient: + def __init__(self, driver: Driver, embeddings: Optional[Embeddings]) -> None: + # Verify if the version supports vector index + self._verify_version(driver) + self.embeddings = embeddings if embeddings else None + + def _verify_version(self, driver: Driver) -> None: + """ + Check if the connected Neo4j database version supports vector indexing. + + Queries the Neo4j database to retrieve its version and compares it + against a target version (5.11.0) that is known to support vector + indexing. Raises a ValueError if the connected Neo4j version is + not supported. + """ + version = self.database_query(driver, "CALL dbms.components()")[0]["versions"][ + 0 + ] + if "aura" in version: + version_tuple = (*tuple(map(int, version.split("-")[0].split("."))), 0) + else: + version_tuple = tuple(map(int, version.split("."))) + + target_version = (5, 11, 0) + + if version_tuple < target_version: + raise ValueError( + "Version index is only supported in Neo4j version 5.11 or greater" + ) + + def database_query( + self, driver: Driver, query: str, params: Optional[Dict] = None + ) -> List[Dict[str, Any]]: + """ + This method sends a Cypher query to the connected Neo4j database + and returns the results as a list of dictionaries. + + Args: + query (str): The Cypher query to execute. + params (Dict, optional): Dictionary of query parameters. Defaults to {}. + + Returns: + List[Dict[str, Any]]: List of dictionaries containing the query results. + """ + params = params or {} + # TODO: how do we pass this database variable + with driver.session(database="neo4j") as session: + try: + data = session.run(query, params) + return [r.data() for r in data] + except CypherSyntaxError as e: + raise ValueError(f"Cypher Statement is not valid\n{e}") + + def create_index( + self, + driver: Driver, + name: str, + label: str, + property: str, + dimensions: int, + similarity_fn: str, + ) -> None: + """ + This method constructs a Cypher query and executes it + to create a new vector index in Neo4j. + """ + index_query = ( + "CALL db.index.vector.createNodeIndex(" + "$name," + "$label," + "$property," + "toInteger($dimensions)," + "$similarity_fn )" + ) + + parameters = { + "name": name, + "node_label": label, + "property": property, + "dimensions": dimensions, + "similarity_fn": similarity_fn, + } + self.database_query(driver, index_query, params=parameters) + + def similarity_search( + self, + driver: Driver, + name: str, + query_vector: Optional[List[float]], + query_text: Optional[str], + top_k: int, + ) -> List[Dict[str, Any]]: + """ + Performs the similarity search + """ + if not ((query_vector is not None) ^ (query_text is not None)): + raise ValueError("You must provide one of query_vector or query_text.") + + if query_vector: + parameters = { + "index_name": name, + "top_k": top_k, + "vector": query_vector, + } + + if query_text: + # TODO: do we need to validate embeddings? Normalizing etc. + if self.embeddings: + vector_embedding = self.embeddings.embed_query(query_text) + parameters = { + "index_name": name, + "top_k": top_k, + "vector": vector_embedding, + } + else: + raise ValueError( + "Embeddings required in definition to perform search for query_text" + ) + + db_query_string = "CALL db.index.vector.queryNodes($index_name, $top_k, $vector) YIELD node, score" + return self.database_query(driver, db_query_string, params=parameters) diff --git a/poetry.lock b/poetry.lock index 90a60f18..218ecfbf 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,6 +1,217 @@ # This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] +<<<<<<< HEAD +======= +name = "astroid" +version = "3.1.0" +description = "An abstract syntax tree for Python with inference support." +optional = false +python-versions = ">=3.8.0" +files = [ + {file = "astroid-3.1.0-py3-none-any.whl", hash = "sha256:951798f922990137ac090c53af473db7ab4e70c770e6d7fae0cec59f74411819"}, + {file = "astroid-3.1.0.tar.gz", hash = "sha256:ac248253bfa4bd924a0de213707e7ebeeb3138abeb48d798784ead1e56d419d4"}, +] + +[package.dependencies] +typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""} + +[[package]] +name = "black" +version = "24.2.0" +description = "The uncompromising code formatter." +optional = false +python-versions = ">=3.8" +files = [ + {file = "black-24.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6981eae48b3b33399c8757036c7f5d48a535b962a7c2310d19361edeef64ce29"}, + {file = "black-24.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d533d5e3259720fdbc1b37444491b024003e012c5173f7d06825a77508085430"}, + {file = "black-24.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:61a0391772490ddfb8a693c067df1ef5227257e72b0e4108482b8d41b5aee13f"}, + {file = "black-24.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:992e451b04667116680cb88f63449267c13e1ad134f30087dec8527242e9862a"}, + {file = "black-24.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:163baf4ef40e6897a2a9b83890e59141cc8c2a98f2dda5080dc15c00ee1e62cd"}, + {file = "black-24.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e37c99f89929af50ffaf912454b3e3b47fd64109659026b678c091a4cd450fb2"}, + {file = "black-24.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f9de21bafcba9683853f6c96c2d515e364aee631b178eaa5145fc1c61a3cc92"}, + {file = "black-24.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:9db528bccb9e8e20c08e716b3b09c6bdd64da0dd129b11e160bf082d4642ac23"}, + {file = "black-24.2.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d84f29eb3ee44859052073b7636533ec995bd0f64e2fb43aeceefc70090e752b"}, + {file = "black-24.2.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1e08fb9a15c914b81dd734ddd7fb10513016e5ce7e6704bdd5e1251ceee51ac9"}, + {file = "black-24.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:810d445ae6069ce64030c78ff6127cd9cd178a9ac3361435708b907d8a04c693"}, + {file = "black-24.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:ba15742a13de85e9b8f3239c8f807723991fbfae24bad92d34a2b12e81904982"}, + {file = "black-24.2.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:7e53a8c630f71db01b28cd9602a1ada68c937cbf2c333e6ed041390d6968faf4"}, + {file = "black-24.2.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:93601c2deb321b4bad8f95df408e3fb3943d85012dddb6121336b8e24a0d1218"}, + {file = "black-24.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0057f800de6acc4407fe75bb147b0c2b5cbb7c3ed110d3e5999cd01184d53b0"}, + {file = "black-24.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:faf2ee02e6612577ba0181f4347bcbcf591eb122f7841ae5ba233d12c39dcb4d"}, + {file = "black-24.2.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:057c3dc602eaa6fdc451069bd027a1b2635028b575a6c3acfd63193ced20d9c8"}, + {file = "black-24.2.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:08654d0797e65f2423f850fc8e16a0ce50925f9337fb4a4a176a7aa4026e63f8"}, + {file = "black-24.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ca610d29415ee1a30a3f30fab7a8f4144e9d34c89a235d81292a1edb2b55f540"}, + {file = "black-24.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:4dd76e9468d5536abd40ffbc7a247f83b2324f0c050556d9c371c2b9a9a95e31"}, + {file = "black-24.2.0-py3-none-any.whl", hash = "sha256:e8a6ae970537e67830776488bca52000eaa37fa63b9988e8c487458d9cd5ace6"}, + {file = "black-24.2.0.tar.gz", hash = "sha256:bce4f25c27c3435e4dace4815bcb2008b87e167e3bf4ee47ccdc5ce906eb4894"}, +] + +[package.dependencies] +click = ">=8.0.0" +mypy-extensions = ">=0.4.3" +packaging = ">=22.0" +pathspec = ">=0.9.0" +platformdirs = ">=2" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +typing-extensions = {version = ">=4.0.1", markers = "python_version < \"3.11\""} + +[package.extras] +colorama = ["colorama (>=0.4.3)"] +d = ["aiohttp (>=3.7.4)", "aiohttp (>=3.7.4,!=3.9.0)"] +jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] +uvloop = ["uvloop (>=0.15.2)"] + +[[package]] +name = "click" +version = "8.1.7" +description = "Composable command line interface toolkit" +optional = false +python-versions = ">=3.7" +files = [ + {file = "click-8.1.7-py3-none-any.whl", hash = "sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28"}, + {file = "click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + +[[package]] +name = "colorama" +version = "0.4.6" +description = "Cross-platform colored terminal text." +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +files = [ + {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, + {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, +] + +[[package]] +name = "dill" +version = "0.3.8" +description = "serialize all of Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7"}, + {file = "dill-0.3.8.tar.gz", hash = "sha256:3ebe3c479ad625c4553aca177444d89b486b1d84982eeacded644afc0cf797ca"}, +] + +[package.extras] +graph = ["objgraph (>=1.7.2)"] +profile = ["gprof2dot (>=2022.7.29)"] + +[[package]] +name = "exceptiongroup" +version = "1.2.0" +description = "Backport of PEP 654 (exception groups)" +optional = false +python-versions = ">=3.7" +files = [ + {file = "exceptiongroup-1.2.0-py3-none-any.whl", hash = "sha256:4bfd3996ac73b41e9b9628b04e079f193850720ea5945fc96a08633c66912f14"}, + {file = "exceptiongroup-1.2.0.tar.gz", hash = "sha256:91f5c769735f051a4290d52edd0858999b57e5876e9f85937691bd4c9fa3ed68"}, +] + +[package.extras] +test = ["pytest (>=6)"] + +[[package]] +name = "iniconfig" +version = "2.0.0" +description = "brain-dead simple config-ini parsing" +optional = false +python-versions = ">=3.7" +files = [ + {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, + {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, +] + +[[package]] +name = "isort" +version = "5.13.2" +description = "A Python utility / library to sort Python imports." +optional = false +python-versions = ">=3.8.0" +files = [ + {file = "isort-5.13.2-py3-none-any.whl", hash = "sha256:8ca5e72a8d85860d5a3fa69b8745237f2939afe12dbf656afbcb47fe72d947a6"}, + {file = "isort-5.13.2.tar.gz", hash = "sha256:48fdfcb9face5d58a4f6dde2e72a1fb8dcaf8ab26f95ab49fab84c2ddefb0109"}, +] + +[package.extras] +colors = ["colorama (>=0.4.6)"] + +[[package]] +name = "mccabe" +version = "0.7.0" +description = "McCabe checker, plugin for flake8" +optional = false +python-versions = ">=3.6" +files = [ + {file = "mccabe-0.7.0-py2.py3-none-any.whl", hash = "sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e"}, + {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"}, +] + +[[package]] +name = "mypy" +version = "1.8.0" +description = "Optional static typing for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "mypy-1.8.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:485a8942f671120f76afffff70f259e1cd0f0cfe08f81c05d8816d958d4577d3"}, + {file = "mypy-1.8.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:df9824ac11deaf007443e7ed2a4a26bebff98d2bc43c6da21b2b64185da011c4"}, + {file = "mypy-1.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2afecd6354bbfb6e0160f4e4ad9ba6e4e003b767dd80d85516e71f2e955ab50d"}, + {file = "mypy-1.8.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8963b83d53ee733a6e4196954502b33567ad07dfd74851f32be18eb932fb1cb9"}, + {file = "mypy-1.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:e46f44b54ebddbeedbd3d5b289a893219065ef805d95094d16a0af6630f5d410"}, + {file = "mypy-1.8.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:855fe27b80375e5c5878492f0729540db47b186509c98dae341254c8f45f42ae"}, + {file = "mypy-1.8.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4c886c6cce2d070bd7df4ec4a05a13ee20c0aa60cb587e8d1265b6c03cf91da3"}, + {file = "mypy-1.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d19c413b3c07cbecf1f991e2221746b0d2a9410b59cb3f4fb9557f0365a1a817"}, + {file = "mypy-1.8.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9261ed810972061388918c83c3f5cd46079d875026ba97380f3e3978a72f503d"}, + {file = "mypy-1.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:51720c776d148bad2372ca21ca29256ed483aa9a4cdefefcef49006dff2a6835"}, + {file = "mypy-1.8.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:52825b01f5c4c1c4eb0db253ec09c7aa17e1a7304d247c48b6f3599ef40db8bd"}, + {file = "mypy-1.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f5ac9a4eeb1ec0f1ccdc6f326bcdb464de5f80eb07fb38b5ddd7b0de6bc61e55"}, + {file = "mypy-1.8.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afe3fe972c645b4632c563d3f3eff1cdca2fa058f730df2b93a35e3b0c538218"}, + {file = "mypy-1.8.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:42c6680d256ab35637ef88891c6bd02514ccb7e1122133ac96055ff458f93fc3"}, + {file = "mypy-1.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:720a5ca70e136b675af3af63db533c1c8c9181314d207568bbe79051f122669e"}, + {file = "mypy-1.8.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:028cf9f2cae89e202d7b6593cd98db6759379f17a319b5faf4f9978d7084cdc6"}, + {file = "mypy-1.8.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4e6d97288757e1ddba10dd9549ac27982e3e74a49d8d0179fc14d4365c7add66"}, + {file = "mypy-1.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f1478736fcebb90f97e40aff11a5f253af890c845ee0c850fe80aa060a267c6"}, + {file = "mypy-1.8.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:42419861b43e6962a649068a61f4a4839205a3ef525b858377a960b9e2de6e0d"}, + {file = "mypy-1.8.0-cp38-cp38-win_amd64.whl", hash = "sha256:2b5b6c721bd4aabaadead3a5e6fa85c11c6c795e0c81a7215776ef8afc66de02"}, + {file = "mypy-1.8.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5c1538c38584029352878a0466f03a8ee7547d7bd9f641f57a0f3017a7c905b8"}, + {file = "mypy-1.8.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4ef4be7baf08a203170f29e89d79064463b7fc7a0908b9d0d5114e8009c3a259"}, + {file = "mypy-1.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7178def594014aa6c35a8ff411cf37d682f428b3b5617ca79029d8ae72f5402b"}, + {file = "mypy-1.8.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ab3c84fa13c04aeeeabb2a7f67a25ef5d77ac9d6486ff33ded762ef353aa5592"}, + {file = "mypy-1.8.0-cp39-cp39-win_amd64.whl", hash = "sha256:99b00bc72855812a60d253420d8a2eae839b0afa4938f09f4d2aa9bb4654263a"}, + {file = "mypy-1.8.0-py3-none-any.whl", hash = "sha256:538fd81bb5e430cc1381a443971c0475582ff9f434c16cd46d2c66763ce85d9d"}, + {file = "mypy-1.8.0.tar.gz", hash = "sha256:6ff8b244d7085a0b425b56d327b480c3b29cafbd2eff27316a004f9a7391ae07"}, +] + +[package.dependencies] +mypy-extensions = ">=1.0.0" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +typing-extensions = ">=4.1.0" + +[package.extras] +dmypy = ["psutil (>=4.0)"] +install-types = ["pip"] +mypyc = ["setuptools (>=50)"] +reports = ["lxml"] + +[[package]] +name = "mypy-extensions" +version = "1.0.0" +description = "Type system extensions for programs checked with the mypy type checker." +optional = false +python-versions = ">=3.5" +files = [ + {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, + {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, +] + +[[package]] +>>>>>>> ca1e4dd (Adds first GenAIClient) name = "neo4j" version = "5.17.0" description = "Neo4j Bolt driver for Python" @@ -19,6 +230,130 @@ pandas = ["numpy (>=1.7.0,<2.0.0)", "pandas (>=1.1.0,<3.0.0)"] pyarrow = ["pyarrow (>=1.0.0)"] [[package]] +<<<<<<< HEAD +======= +name = "packaging" +version = "23.2" +description = "Core utilities for Python packages" +optional = false +python-versions = ">=3.7" +files = [ + {file = "packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7"}, + {file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"}, +] + +[[package]] +name = "pathspec" +version = "0.12.1" +description = "Utility library for gitignore style pattern matching of file paths." +optional = false +python-versions = ">=3.8" +files = [ + {file = "pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08"}, + {file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"}, +] + +[[package]] +name = "platformdirs" +version = "4.2.0" +description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." +optional = false +python-versions = ">=3.8" +files = [ + {file = "platformdirs-4.2.0-py3-none-any.whl", hash = "sha256:0614df2a2f37e1a662acbd8e2b25b92ccf8632929bc6d43467e17fe89c75e068"}, + {file = "platformdirs-4.2.0.tar.gz", hash = "sha256:ef0cc731df711022c174543cb70a9b5bd22e5a9337c8624ef2c2ceb8ddad8768"}, +] + +[package.extras] +docs = ["furo (>=2023.9.10)", "proselint (>=0.13)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] +test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)"] + +[[package]] +name = "pluggy" +version = "1.4.0" +description = "plugin and hook calling mechanisms for python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pluggy-1.4.0-py3-none-any.whl", hash = "sha256:7db9f7b503d67d1c5b95f59773ebb58a8c1c288129a88665838012cfb07b8981"}, + {file = "pluggy-1.4.0.tar.gz", hash = "sha256:8c85c2876142a764e5b7548e7d9a0e0ddb46f5185161049a79b7e974454223be"}, +] + +[package.extras] +dev = ["pre-commit", "tox"] +testing = ["pytest", "pytest-benchmark"] + +[[package]] +name = "pylint" +version = "3.1.0" +description = "python code static checker" +optional = false +python-versions = ">=3.8.0" +files = [ + {file = "pylint-3.1.0-py3-none-any.whl", hash = "sha256:507a5b60953874766d8a366e8e8c7af63e058b26345cfcb5f91f89d987fd6b74"}, + {file = "pylint-3.1.0.tar.gz", hash = "sha256:6a69beb4a6f63debebaab0a3477ecd0f559aa726af4954fc948c51f7a2549e23"}, +] + +[package.dependencies] +astroid = ">=3.1.0,<=3.2.0-dev0" +colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""} +dill = [ + {version = ">=0.2", markers = "python_version < \"3.11\""}, + {version = ">=0.3.7", markers = "python_version >= \"3.12\""}, + {version = ">=0.3.6", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, +] +isort = ">=4.2.5,<5.13.0 || >5.13.0,<6" +mccabe = ">=0.6,<0.8" +platformdirs = ">=2.2.0" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +tomlkit = ">=0.10.1" +typing-extensions = {version = ">=3.10.0", markers = "python_version < \"3.10\""} + +[package.extras] +spelling = ["pyenchant (>=3.2,<4.0)"] +testutils = ["gitpython (>3)"] + +[[package]] +name = "pytest" +version = "8.0.2" +description = "pytest: simple powerful testing with Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest-8.0.2-py3-none-any.whl", hash = "sha256:edfaaef32ce5172d5466b5127b42e0d6d35ebbe4453f0e3505d96afd93f6b096"}, + {file = "pytest-8.0.2.tar.gz", hash = "sha256:d4051d623a2e0b7e51960ba963193b09ce6daeb9759a451844a21e4ddedfc1bd"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "sys_platform == \"win32\""} +exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} +iniconfig = "*" +packaging = "*" +pluggy = ">=1.3.0,<2.0" +tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} + +[package.extras] +testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] + +[[package]] +name = "pytest-mock" +version = "3.12.0" +description = "Thin-wrapper around the mock package for easier use with pytest" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest-mock-3.12.0.tar.gz", hash = "sha256:31a40f038c22cad32287bb43932054451ff5583ff094bca6f675df2f8bc1a6e9"}, + {file = "pytest_mock-3.12.0-py3-none-any.whl", hash = "sha256:0972719a7263072da3a21c7f4773069bcc7486027d7e8e1f81d98a47e701bc4f"}, +] + +[package.dependencies] +pytest = ">=5.0" + +[package.extras] +dev = ["pre-commit", "pytest-asyncio", "tox"] + +[[package]] +>>>>>>> ca1e4dd (Adds first GenAIClient) name = "pytz" version = "2024.1" description = "World timezone definitions, modern and historical" @@ -29,7 +364,78 @@ files = [ {file = "pytz-2024.1.tar.gz", hash = "sha256:2a29735ea9c18baf14b448846bde5a48030ed267578472d8955cd0e7443a9812"}, ] +<<<<<<< HEAD [metadata] lock-version = "2.0" python-versions = ">=3.7" content-hash = "4638595b14c9dedbce42766fd4745866b8011e8cd03967f27d170a0200602437" +======= +[[package]] +name = "tomli" +version = "2.0.1" +description = "A lil' TOML parser" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, + {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, +] + +[[package]] +name = "tomlkit" +version = "0.12.4" +description = "Style preserving TOML library" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tomlkit-0.12.4-py3-none-any.whl", hash = "sha256:5cd82d48a3dd89dee1f9d64420aa20ae65cfbd00668d6f094d7578a78efbb77b"}, + {file = "tomlkit-0.12.4.tar.gz", hash = "sha256:7ca1cfc12232806517a8515047ba66a19369e71edf2439d0f5824f91032b6cc3"}, +] + +[[package]] +name = "types-requests" +version = "2.31.0.20240218" +description = "Typing stubs for requests" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-requests-2.31.0.20240218.tar.gz", hash = "sha256:f1721dba8385958f504a5386240b92de4734e047a08a40751c1654d1ac3349c5"}, + {file = "types_requests-2.31.0.20240218-py3-none-any.whl", hash = "sha256:a82807ec6ddce8f00fe0e949da6d6bc1fbf1715420218a9640d695f70a9e5a9b"}, +] + +[package.dependencies] +urllib3 = ">=2" + +[[package]] +name = "typing-extensions" +version = "4.10.0" +description = "Backported and Experimental Type Hints for Python 3.8+" +optional = false +python-versions = ">=3.8" +files = [ + {file = "typing_extensions-4.10.0-py3-none-any.whl", hash = "sha256:69b1a937c3a517342112fb4c6df7e72fc39a38e7891a5730ed4985b5214b5475"}, + {file = "typing_extensions-4.10.0.tar.gz", hash = "sha256:b0abd7c89e8fb96f98db18d86106ff1d90ab692004eb746cf6eda2682f91b3cb"}, +] + +[[package]] +name = "urllib3" +version = "2.2.1" +description = "HTTP library with thread-safe connection pooling, file post, and more." +optional = false +python-versions = ">=3.8" +files = [ + {file = "urllib3-2.2.1-py3-none-any.whl", hash = "sha256:450b20ec296a467077128bff42b73080516e71b56ff59a60a02bef2232c4fa9d"}, + {file = "urllib3-2.2.1.tar.gz", hash = "sha256:d0570876c61ab9e520d776c38acbbb5b05a776d3f9ff98a5c8fd5162a444cf19"}, +] + +[package.extras] +brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] +h2 = ["h2 (>=4,<5)"] +socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] +zstd = ["zstandard (>=0.18.0)"] + +[metadata] +lock-version = "2.0" +python-versions = "^3.8" +content-hash = "a2e97b40deca16a26490a2fce2e70c3742942ecbb026bca3e83219888530bfcf" +>>>>>>> ca1e4dd (Adds first GenAIClient) diff --git a/pyproject.toml b/pyproject.toml index f4af4c7f..cf9b7691 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,18 +1,52 @@ [tool.poetry] +<<<<<<< HEAD name = "neo4j-genai" version = "0.1.0" description = "" +======= +name = "neo4j-genai-python" +version = "0.1.0" +description = "Python package to allow easy integration to Neo4j's GenAI features" +>>>>>>> ca1e4dd (Adds first GenAIClient) authors = ["Neo4j, Inc "] license = "Apache License, Version 2.0" readme = "README.md" [tool.poetry.dependencies] +<<<<<<< HEAD python = ">=3.7" neo4j = "^5.17.0" +======= +python = "^3.8" +neo4j = "^5.17.0" +types-requests = "^2.31.0.20240218" +pytest = "^8.0.2" +pytest-mock = "^3.12.0" + +[tool.poetry.group.dev.dependencies] +pylint = "^3.1.0" +mypy = "^1.8.0" +black = "^24.2.0" +>>>>>>> ca1e4dd (Adds first GenAIClient) [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" -exclude = ["**/tests/"] \ No newline at end of file +<<<<<<< HEAD +exclude = ["**/tests/"] +======= +[tool.black] +line-length = 88 +target-version = ['py37'] +include = '\.pyi?$' +exclude = ''' +/( + .git + | .venv + | build + | dist +)/ +''' +>>>>>>> ca1e4dd (Adds first GenAIClient) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 00000000..6c74a29c --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,10 @@ +import pytest +from neo4j import GraphDatabase +from neo4j_genai_python.src import GenAIClient, Embeddings + +@pytest.fixture +def genai_client(mocker): + mock_driver = mocker.MagicMock(spec=GraphDatabase.driver) + mock_embeddings = mocker.MagicMock(spec=Embeddings) + client = GenAIClient(driver=mock_driver, embeddings=mock_embeddings) + return client \ No newline at end of file From 3c2fd0e4061d1d8b1eb938dd4fedf1e0ee779c84 Mon Sep 17 00:00:00 2001 From: Will Tai Date: Thu, 29 Feb 2024 10:01:16 +0000 Subject: [PATCH 03/23] push pre-commit-config --- .pre-commit-config.yaml | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..34f891f9 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,6 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 # Use the ref you want to point at + hooks: + - id: trailing-whitespace + # - id: ... \ No newline at end of file From 17692473af13316243b6a29abcae266988f4af98 Mon Sep 17 00:00:00 2001 From: Will Tai Date: Thu, 29 Feb 2024 10:22:04 +0000 Subject: [PATCH 04/23] fixed pyproject.toml and remove tests --- pyproject.toml | 17 +---------------- tests/test_client.py | 10 ---------- 2 files changed, 1 insertion(+), 26 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index cf9b7691..17bbd5a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,23 +1,12 @@ [tool.poetry] -<<<<<<< HEAD -name = "neo4j-genai" -version = "0.1.0" -description = "" -======= name = "neo4j-genai-python" version = "0.1.0" description = "Python package to allow easy integration to Neo4j's GenAI features" ->>>>>>> ca1e4dd (Adds first GenAIClient) authors = ["Neo4j, Inc "] license = "Apache License, Version 2.0" readme = "README.md" [tool.poetry.dependencies] -<<<<<<< HEAD -python = ">=3.7" -neo4j = "^5.17.0" - -======= python = "^3.8" neo4j = "^5.17.0" types-requests = "^2.31.0.20240218" @@ -28,18 +17,15 @@ pytest-mock = "^3.12.0" pylint = "^3.1.0" mypy = "^1.8.0" black = "^24.2.0" ->>>>>>> ca1e4dd (Adds first GenAIClient) [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" -<<<<<<< HEAD exclude = ["**/tests/"] -======= [tool.black] line-length = 88 -target-version = ['py37'] +target-version = ['py38'] include = '\.pyi?$' exclude = ''' /( @@ -49,4 +35,3 @@ exclude = ''' | dist )/ ''' ->>>>>>> ca1e4dd (Adds first GenAIClient) diff --git a/tests/test_client.py b/tests/test_client.py index 6c74a29c..e69de29b 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,10 +0,0 @@ -import pytest -from neo4j import GraphDatabase -from neo4j_genai_python.src import GenAIClient, Embeddings - -@pytest.fixture -def genai_client(mocker): - mock_driver = mocker.MagicMock(spec=GraphDatabase.driver) - mock_embeddings = mocker.MagicMock(spec=Embeddings) - client = GenAIClient(driver=mock_driver, embeddings=mock_embeddings) - return client \ No newline at end of file From 6b33cbbe91aea870b5e03dc4f54781d1dc6885db Mon Sep 17 00:00:00 2001 From: Will Tai Date: Thu, 29 Feb 2024 12:17:13 +0000 Subject: [PATCH 05/23] Similarity search example --- examples/similarity_search.py | 44 ++++++++++++++++++++++++++++++++ neo4j_genai_python/src/client.py | 23 ++++++++++++----- 2 files changed, 61 insertions(+), 6 deletions(-) create mode 100644 examples/similarity_search.py diff --git a/examples/similarity_search.py b/examples/similarity_search.py new file mode 100644 index 00000000..9812dc71 --- /dev/null +++ b/examples/similarity_search.py @@ -0,0 +1,44 @@ +from neo4j import GraphDatabase +from neo4j_genai_python.src.client import GenAIClient + +from random import random + +URI = "neo4j://localhost:7687" +AUTH = ("neo4j", "password") + +INDEX_NAME = "embedding-name" +DIMENSION = 1536 +driver = GraphDatabase.driver(URI, auth=AUTH) + +client = GenAIClient(driver) + +client.drop_index(driver, INDEX_NAME) + +# Creating the index +client.create_index( + driver, + INDEX_NAME, + label="label", + property="property", + dimensions=DIMENSION, + similarity_fn="euclidean", +) + +# Upsert the vector +vector = [random() for _ in range(DIMENSION)] +insert_query = ( + "MATCH (n:Node {id: $id})" + "CALL db.create.setNodeVectorProperty(n, 'propertyKey', $vector)" + "RETURN n" +) +parameters = { + "id": 1, + "vector": vector, +} +client.database_query(driver, insert_query, params=parameters) + +# Perform the similarity search +query_vector = [random() for _ in range(DIMENSION)] +print(client.similarity_search( + driver, INDEX_NAME, query_vector=query_vector, top_k=5 +)) diff --git a/neo4j_genai_python/src/client.py b/neo4j_genai_python/src/client.py index 1e04c623..7eac5b69 100644 --- a/neo4j_genai_python/src/client.py +++ b/neo4j_genai_python/src/client.py @@ -16,7 +16,7 @@ def embed_query(self, text: str) -> List[float]: class GenAIClient: - def __init__(self, driver: Driver, embeddings: Optional[Embeddings]) -> None: + def __init__(self, driver: Driver, embeddings: Optional[Embeddings] = None) -> None: # Verify if the version supports vector index self._verify_version(driver) self.embeddings = embeddings if embeddings else None @@ -92,20 +92,31 @@ def create_index( parameters = { "name": name, - "node_label": label, + "label": label, "property": property, "dimensions": dimensions, "similarity_fn": similarity_fn, } self.database_query(driver, index_query, params=parameters) + def drop_index(self, driver, name: str) -> None: + """ + This method constructs a Cypher query and executes it + to drop a vector index in Neo4j. + """ + index_query = "DROP INDEX $name" + parameters = { + "name": name, + } + self.database_query(driver, index_query, params=parameters) + def similarity_search( self, driver: Driver, name: str, - query_vector: Optional[List[float]], - query_text: Optional[str], - top_k: int, + query_vector: Optional[List[float]] = None, + query_text: Optional[str] = None, + top_k: int = 5, ) -> List[Dict[str, Any]]: """ Performs the similarity search @@ -131,7 +142,7 @@ def similarity_search( } else: raise ValueError( - "Embeddings required in definition to perform search for query_text" + "Embedding method required to perform search for query_text" ) db_query_string = "CALL db.index.vector.queryNodes($index_name, $top_k, $vector) YIELD node, score" From 67a634203f04a4d3348b34c358a401984a58bf82 Mon Sep 17 00:00:00 2001 From: Will Tai Date: Thu, 29 Feb 2024 13:56:02 +0000 Subject: [PATCH 06/23] Added embeddings type --- .gitignore | 2 +- ...earch.py => similarity_search_for_text.py} | 0 examples/similarity_search_for_vector.py | 44 ++++++++++++++++++ .../__pycache__/__init__.cpython-311.pyc | Bin 170 -> 0 bytes neo4j_genai/__pycache__/tools.cpython-311.pyc | Bin 327 -> 0 bytes neo4j_genai_python/src/client.py | 20 ++------ neo4j_genai_python/src/embeddings.py | 10 ++++ 7 files changed, 59 insertions(+), 17 deletions(-) rename examples/{similarity_search.py => similarity_search_for_text.py} (100%) create mode 100644 examples/similarity_search_for_vector.py delete mode 100644 neo4j_genai/__pycache__/__init__.cpython-311.pyc delete mode 100644 neo4j_genai/__pycache__/tools.cpython-311.pyc create mode 100644 neo4j_genai_python/src/embeddings.py diff --git a/.gitignore b/.gitignore index 7264d9de..77ed5a6c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,3 @@ dist/ -__pycache__/ +**/__pycache__/* *.py[cod] diff --git a/examples/similarity_search.py b/examples/similarity_search_for_text.py similarity index 100% rename from examples/similarity_search.py rename to examples/similarity_search_for_text.py diff --git a/examples/similarity_search_for_vector.py b/examples/similarity_search_for_vector.py new file mode 100644 index 00000000..9812dc71 --- /dev/null +++ b/examples/similarity_search_for_vector.py @@ -0,0 +1,44 @@ +from neo4j import GraphDatabase +from neo4j_genai_python.src.client import GenAIClient + +from random import random + +URI = "neo4j://localhost:7687" +AUTH = ("neo4j", "password") + +INDEX_NAME = "embedding-name" +DIMENSION = 1536 +driver = GraphDatabase.driver(URI, auth=AUTH) + +client = GenAIClient(driver) + +client.drop_index(driver, INDEX_NAME) + +# Creating the index +client.create_index( + driver, + INDEX_NAME, + label="label", + property="property", + dimensions=DIMENSION, + similarity_fn="euclidean", +) + +# Upsert the vector +vector = [random() for _ in range(DIMENSION)] +insert_query = ( + "MATCH (n:Node {id: $id})" + "CALL db.create.setNodeVectorProperty(n, 'propertyKey', $vector)" + "RETURN n" +) +parameters = { + "id": 1, + "vector": vector, +} +client.database_query(driver, insert_query, params=parameters) + +# Perform the similarity search +query_vector = [random() for _ in range(DIMENSION)] +print(client.similarity_search( + driver, INDEX_NAME, query_vector=query_vector, top_k=5 +)) diff --git a/neo4j_genai/__pycache__/__init__.cpython-311.pyc b/neo4j_genai/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index 8c801ec1b0a7531e389ce13b49ae76e19db3c031..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 170 zcmZ3^%ge<81TT#5r-JCmAOZ#$p^VRLK*n^26oz01O-8?!3`I;p{%4TnFI)YL{M=Oi zto%IvfTH}Y)Z~(4{k+tClPulz)V#z@-Ga)J44@oXC>|`NA0MBYmst`YuUAm{i^C>2 hKczG$)vkyYXd1}AVtyd;ftit!@dE>lC}IYR0RZTPDIfp< diff --git a/neo4j_genai/__pycache__/tools.cpython-311.pyc b/neo4j_genai/__pycache__/tools.cpython-311.pyc deleted file mode 100644 index 719223d01f0e739def98b2327aea748d3db2fc52..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 327 zcmZ3^%ge<81QDk9Q)PklV-N=h7@>^MJV3^Dh7^VthA4&<#$X0brev5J5X}t5pH+a2 z=?o<>eL!9c<1!#`HC%)dXbMQb-z}Dc{JeBc##^ifMVWaeD;YimRWtmu($C1xP1VoJ z&(jYm%FjwoE-BW}OU*aQ(oIjzOU%?Qs4U3<%7KOA!9w~a`T04;dIgn06EpMDi#UMB zfUGMP1QHDlcUd?))Ea$Td>g>9hz%&}r^yIbR>TVAfJKVffvjH~Ho5sJr8%i~MVvq` Z*bq4&;{!7zBjX1qMn<6z444F1KLCW*Lo)yX diff --git a/neo4j_genai_python/src/client.py b/neo4j_genai_python/src/client.py index 7eac5b69..c48b9252 100644 --- a/neo4j_genai_python/src/client.py +++ b/neo4j_genai_python/src/client.py @@ -1,18 +1,7 @@ -import neo4j - from typing import List, Dict, Any, Optional -from neo4j import Driver, GraphDatabase +from neo4j import Driver from neo4j.exceptions import CypherSyntaxError - -from abc import ABC, abstractmethod - - -class Embeddings(ABC): - """Interface for embedding models.""" - - @abstractmethod - def embed_query(self, text: str) -> List[float]: - """Embed query text.""" +from neo4j_genai_python.src.embeddings import Embeddings class GenAIClient: @@ -60,8 +49,7 @@ def database_query( List[Dict[str, Any]]: List of dictionaries containing the query results. """ params = params or {} - # TODO: how do we pass this database variable - with driver.session(database="neo4j") as session: + with driver.session() as session: try: data = session.run(query, params) return [r.data() for r in data] @@ -142,7 +130,7 @@ def similarity_search( } else: raise ValueError( - "Embedding method required to perform search for query_text" + "Embedding method required to perform search for text query." ) db_query_string = "CALL db.index.vector.queryNodes($index_name, $top_k, $vector) YIELD node, score" diff --git a/neo4j_genai_python/src/embeddings.py b/neo4j_genai_python/src/embeddings.py new file mode 100644 index 00000000..6f773d52 --- /dev/null +++ b/neo4j_genai_python/src/embeddings.py @@ -0,0 +1,10 @@ +from abc import ABC, abstractmethod +from typing import List + + +class Embeddings(ABC): + """Interface for embedding models.""" + + @abstractmethod + def embed_query(self, text: str) -> List[float]: + """Embed query text.""" From ad0163faa909167e27d36031eab29380b7cd1775 Mon Sep 17 00:00:00 2001 From: Will Tai Date: Thu, 29 Feb 2024 13:57:49 +0000 Subject: [PATCH 07/23] Adds TODO comment --- examples/similarity_search_for_text.py | 1 + neo4j_genai_python/src/client.py | 1 + 2 files changed, 2 insertions(+) diff --git a/examples/similarity_search_for_text.py b/examples/similarity_search_for_text.py index 9812dc71..38e0b4a1 100644 --- a/examples/similarity_search_for_text.py +++ b/examples/similarity_search_for_text.py @@ -1,3 +1,4 @@ +#TODO from neo4j import GraphDatabase from neo4j_genai_python.src.client import GenAIClient diff --git a/neo4j_genai_python/src/client.py b/neo4j_genai_python/src/client.py index c48b9252..5c8ef182 100644 --- a/neo4j_genai_python/src/client.py +++ b/neo4j_genai_python/src/client.py @@ -112,6 +112,7 @@ def similarity_search( if not ((query_vector is not None) ^ (query_text is not None)): raise ValueError("You must provide one of query_vector or query_text.") + # TODO: add query over vectors functionality if query_vector: parameters = { "index_name": name, From 9ed46a9a96d155ae743854091728a22e73c3b0b1 Mon Sep 17 00:00:00 2001 From: Will Tai Date: Thu, 29 Feb 2024 14:56:50 +0000 Subject: [PATCH 08/23] moved code to neo4j_genai/ --- examples/similarity_search_for_text.py | 23 +++++-- examples/similarity_search_for_vector.py | 5 +- neo4j_genai/{ => src}/__init__.py | 0 .../src/client.py | 3 +- .../src/embeddings.py | 0 neo4j_genai/tests/test.py | 35 ---------- neo4j_genai/tests/test2.py | 68 ------------------- neo4j_genai/tools.py | 2 - neo4j_genai_python/src/__init__.py | 0 pyproject.toml | 2 +- 10 files changed, 21 insertions(+), 117 deletions(-) rename neo4j_genai/{ => src}/__init__.py (100%) rename {neo4j_genai_python => neo4j_genai}/src/client.py (97%) rename {neo4j_genai_python => neo4j_genai}/src/embeddings.py (100%) delete mode 100644 neo4j_genai/tests/test.py delete mode 100644 neo4j_genai/tests/test2.py delete mode 100644 neo4j_genai/tools.py delete mode 100644 neo4j_genai_python/src/__init__.py diff --git a/examples/similarity_search_for_text.py b/examples/similarity_search_for_text.py index 38e0b4a1..986690fe 100644 --- a/examples/similarity_search_for_text.py +++ b/examples/similarity_search_for_text.py @@ -1,8 +1,9 @@ -#TODO +from typing import List from neo4j import GraphDatabase -from neo4j_genai_python.src.client import GenAIClient +from neo4j_genai.src.client import GenAIClient from random import random +from neo4j_genai.src.embeddings import Embeddings URI = "neo4j://localhost:7687" AUTH = ("neo4j", "password") @@ -11,7 +12,15 @@ DIMENSION = 1536 driver = GraphDatabase.driver(URI, auth=AUTH) -client = GenAIClient(driver) +# Create Embeddings object +class CustomEmbeddings(Embeddings): + def embed_query(self, text: str) -> List[float]: + return [float(ord(c)) for c in text] + +embeddings = CustomEmbeddings() + +# Initialize the client +client = GenAIClient(driver, embeddings) client.drop_index(driver, INDEX_NAME) @@ -25,7 +34,7 @@ similarity_fn="euclidean", ) -# Upsert the vector +# Upsert the query vector = [random() for _ in range(DIMENSION)] insert_query = ( "MATCH (n:Node {id: $id})" @@ -38,8 +47,8 @@ } client.database_query(driver, insert_query, params=parameters) -# Perform the similarity search -query_vector = [random() for _ in range(DIMENSION)] +# Perform the similarity search for a text query +query_text = "hello world" print(client.similarity_search( - driver, INDEX_NAME, query_vector=query_vector, top_k=5 + driver, INDEX_NAME, query_text=query_text, top_k=5 )) diff --git a/examples/similarity_search_for_vector.py b/examples/similarity_search_for_vector.py index 9812dc71..9d32a285 100644 --- a/examples/similarity_search_for_vector.py +++ b/examples/similarity_search_for_vector.py @@ -1,5 +1,5 @@ from neo4j import GraphDatabase -from neo4j_genai_python.src.client import GenAIClient +from neo4j_genai.src.client import GenAIClient from random import random @@ -10,6 +10,7 @@ DIMENSION = 1536 driver = GraphDatabase.driver(URI, auth=AUTH) +# Initialize the client client = GenAIClient(driver) client.drop_index(driver, INDEX_NAME) @@ -37,7 +38,7 @@ } client.database_query(driver, insert_query, params=parameters) -# Perform the similarity search +# Perform the similarity search for a vector query query_vector = [random() for _ in range(DIMENSION)] print(client.similarity_search( driver, INDEX_NAME, query_vector=query_vector, top_k=5 diff --git a/neo4j_genai/__init__.py b/neo4j_genai/src/__init__.py similarity index 100% rename from neo4j_genai/__init__.py rename to neo4j_genai/src/__init__.py diff --git a/neo4j_genai_python/src/client.py b/neo4j_genai/src/client.py similarity index 97% rename from neo4j_genai_python/src/client.py rename to neo4j_genai/src/client.py index 5c8ef182..46c63bd2 100644 --- a/neo4j_genai_python/src/client.py +++ b/neo4j_genai/src/client.py @@ -1,7 +1,7 @@ from typing import List, Dict, Any, Optional from neo4j import Driver from neo4j.exceptions import CypherSyntaxError -from neo4j_genai_python.src.embeddings import Embeddings +from neo4j_genai.src.embeddings import Embeddings class GenAIClient: @@ -112,7 +112,6 @@ def similarity_search( if not ((query_vector is not None) ^ (query_text is not None)): raise ValueError("You must provide one of query_vector or query_text.") - # TODO: add query over vectors functionality if query_vector: parameters = { "index_name": name, diff --git a/neo4j_genai_python/src/embeddings.py b/neo4j_genai/src/embeddings.py similarity index 100% rename from neo4j_genai_python/src/embeddings.py rename to neo4j_genai/src/embeddings.py diff --git a/neo4j_genai/tests/test.py b/neo4j_genai/tests/test.py deleted file mode 100644 index 16226231..00000000 --- a/neo4j_genai/tests/test.py +++ /dev/null @@ -1,35 +0,0 @@ -from neo4j_genai import Client -from neo4j_genai import GenAI -from neo4j_genai import GenAIClient -from neo4j_genai import VectorClient - -from neo4j import GraphDatabase, Driver -from typing import Optional -from langchain_core.embeddings import Embeddings - -from pydantic_v1 import BaseModel - -URI = "neo4j://localhost:7687" -AUTH = ("neo4j", "password") - -driver = GraphDatabase.driver(URI, auth=AUTH) - -class GenAIClient: - def __init__(self, driver: Driver, embeddings: Optional[]) -> None: - pass -client = GenAIClient(driver, embeddings=) - -client.create_vector_index() -client.drop_vector_index() - -"""**Embeddings** interface.""" -from abc import ABC, abstractmethod -from typing import List - - -class Embeddings(ABC): - """Interface for embedding models.""" - - @abstractmethod - def embed_query(self, text: str) -> List[float]: - """Embed query text.""" diff --git a/neo4j_genai/tests/test2.py b/neo4j_genai/tests/test2.py deleted file mode 100644 index c7103834..00000000 --- a/neo4j_genai/tests/test2.py +++ /dev/null @@ -1,68 +0,0 @@ -from neo4j_genai import GenAIClient -from neo4j import GraphDatabase - -URI = "neo4j://localhost:7687" -AUTH = ("neo4j", "password") - -driver = GraphDatabase.driver(URI, auth=AUTH) - -client = GenAIClient(driver) - - -client.create_vector_index("indexMovies", "Movie", "embedding", dimensions=666, similarity_function="cosine") - -node = driver.execute_query("MATCH (m:Movie {movieId: row.movieId}) RETURN m") - -######### EXPLICIT ENCODING #################### -try: - from langchain_community.embeddings import OllamaEmbeddings - embeddings = OllamaEmbeddings() - - embedded_vectors = embeddings.embed_query("This is the query") -except ImportError: - embedded_vectors = requests.post("ollama.com/api/embeddings", query="This is the query") - -######### EXPLICIT ENDODING #################### - - -client.setNodeVectorProperty(node, "embedding", embedded_vectors) - -client.similarity_search("indexMovies", vectors=embedded_vectors) - -############################# - -from langchain_community.embeddings import OllamaEmbeddings -from neo4j_genai import GenAIClient -from neo4j import GraphDatabase - -URI = "neo4j://localhost:7687" -AUTH = ("neo4j", "password") - -driver = GraphDatabase.driver(URI, auth=AUTH) - -embedding_model = "ollama7b" -ollama7b_embedding_size = 666 -embeddings = OllamaEmbeddings(embedding_model) -client = GenAIClient(driver, embeddings=embeddings) - - - -client.create_vector_index("indexMovies", "Movie", "embedding", dimensions=ollama7b_embedding_size, similarity_function="cosine") - - -node = driver.execute_query("MATCH (m:Movie {movieId: row.movieId}) RETURN m") - - - -# client.setNodeVectorProperty(node, "embedding", embedded_vectors) -client.similarity_search("indexMovies", text="Landing on the moon") -## - -# def similarity_search(self, index_name, query_text): -# embedded_vectors = self.embedddings.embed_query(query_text) - -# # similary() - - -# FUTURE -# client.generateEmbeddingsAndSetVectorProperty(node, text_property="plot", vector_property="embedding") \ No newline at end of file diff --git a/neo4j_genai/tools.py b/neo4j_genai/tools.py deleted file mode 100644 index 44d864dd..00000000 --- a/neo4j_genai/tools.py +++ /dev/null @@ -1,2 +0,0 @@ -def ping(): - print("pong") \ No newline at end of file diff --git a/neo4j_genai_python/src/__init__.py b/neo4j_genai_python/src/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/pyproject.toml b/pyproject.toml index 17bbd5a7..98491375 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [tool.poetry] -name = "neo4j-genai-python" +name = "neo4j-genai" version = "0.1.0" description = "Python package to allow easy integration to Neo4j's GenAI features" authors = ["Neo4j, Inc "] From 9326f412cd17308b0e48adc31472d3a7abb02130 Mon Sep 17 00:00:00 2001 From: Will Tai Date: Thu, 29 Feb 2024 16:34:10 +0000 Subject: [PATCH 09/23] restructured src folder again --- examples/similarity_search_for_text.py | 4 ++-- examples/similarity_search_for_vector.py | 2 +- neo4j_genai/src/client.py | 2 +- poetry.lock | 17 ++--------------- pyproject.toml | 4 ++++ 5 files changed, 10 insertions(+), 19 deletions(-) diff --git a/examples/similarity_search_for_text.py b/examples/similarity_search_for_text.py index 986690fe..ec199861 100644 --- a/examples/similarity_search_for_text.py +++ b/examples/similarity_search_for_text.py @@ -1,9 +1,9 @@ from typing import List from neo4j import GraphDatabase -from neo4j_genai.src.client import GenAIClient +from src.client import GenAIClient from random import random -from neo4j_genai.src.embeddings import Embeddings +from src.embeddings import Embeddings URI = "neo4j://localhost:7687" AUTH = ("neo4j", "password") diff --git a/examples/similarity_search_for_vector.py b/examples/similarity_search_for_vector.py index 9d32a285..d4668785 100644 --- a/examples/similarity_search_for_vector.py +++ b/examples/similarity_search_for_vector.py @@ -1,5 +1,5 @@ from neo4j import GraphDatabase -from neo4j_genai.src.client import GenAIClient +from src.client import GenAIClient from random import random diff --git a/neo4j_genai/src/client.py b/neo4j_genai/src/client.py index 46c63bd2..cc12700b 100644 --- a/neo4j_genai/src/client.py +++ b/neo4j_genai/src/client.py @@ -1,7 +1,7 @@ from typing import List, Dict, Any, Optional from neo4j import Driver from neo4j.exceptions import CypherSyntaxError -from neo4j_genai.src.embeddings import Embeddings +from src.embeddings import Embeddings class GenAIClient: diff --git a/poetry.lock b/poetry.lock index 218ecfbf..e0032a7d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,8 +1,6 @@ # This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] -<<<<<<< HEAD -======= name = "astroid" version = "3.1.0" description = "An abstract syntax tree for Python with inference support." @@ -211,14 +209,13 @@ files = [ ] [[package]] ->>>>>>> ca1e4dd (Adds first GenAIClient) name = "neo4j" -version = "5.17.0" +version = "5.18.0" description = "Neo4j Bolt driver for Python" optional = false python-versions = ">=3.7" files = [ - {file = "neo4j-5.17.0.tar.gz", hash = "sha256:dcd7150a0c3834a89a6e27505e614f340376f31c97c48ba60dc70a220ee85e3b"}, + {file = "neo4j-5.18.0.tar.gz", hash = "sha256:4014406ae5b8b485a8ba46c9f00b6f5b4aaf88e7c3a50603445030c2aab701c9"}, ] [package.dependencies] @@ -230,8 +227,6 @@ pandas = ["numpy (>=1.7.0,<2.0.0)", "pandas (>=1.1.0,<3.0.0)"] pyarrow = ["pyarrow (>=1.0.0)"] [[package]] -<<<<<<< HEAD -======= name = "packaging" version = "23.2" description = "Core utilities for Python packages" @@ -353,7 +348,6 @@ pytest = ">=5.0" dev = ["pre-commit", "pytest-asyncio", "tox"] [[package]] ->>>>>>> ca1e4dd (Adds first GenAIClient) name = "pytz" version = "2024.1" description = "World timezone definitions, modern and historical" @@ -364,12 +358,6 @@ files = [ {file = "pytz-2024.1.tar.gz", hash = "sha256:2a29735ea9c18baf14b448846bde5a48030ed267578472d8955cd0e7443a9812"}, ] -<<<<<<< HEAD -[metadata] -lock-version = "2.0" -python-versions = ">=3.7" -content-hash = "4638595b14c9dedbce42766fd4745866b8011e8cd03967f27d170a0200602437" -======= [[package]] name = "tomli" version = "2.0.1" @@ -438,4 +426,3 @@ zstd = ["zstandard (>=0.18.0)"] lock-version = "2.0" python-versions = "^3.8" content-hash = "a2e97b40deca16a26490a2fce2e70c3742942ecbb026bca3e83219888530bfcf" ->>>>>>> ca1e4dd (Adds first GenAIClient) diff --git a/pyproject.toml b/pyproject.toml index 98491375..41d075ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,6 +6,10 @@ authors = ["Neo4j, Inc "] license = "Apache License, Version 2.0" readme = "README.md" +[[tool.poetry.packages]] +include = "src" +from = "neo4j_genai" + [tool.poetry.dependencies] python = "^3.8" neo4j = "^5.17.0" From 773f3846c5cdb7cbf9e6255b45ef8d6a1827f0f9 Mon Sep 17 00:00:00 2001 From: Will Tai Date: Thu, 29 Feb 2024 16:37:28 +0000 Subject: [PATCH 10/23] Got working example for querying over text --- examples/similarity_search_for_text.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/similarity_search_for_text.py b/examples/similarity_search_for_text.py index ec199861..c640a7ef 100644 --- a/examples/similarity_search_for_text.py +++ b/examples/similarity_search_for_text.py @@ -15,7 +15,7 @@ # Create Embeddings object class CustomEmbeddings(Embeddings): def embed_query(self, text: str) -> List[float]: - return [float(ord(c)) for c in text] + return [random() for _ in range(1536)] embeddings = CustomEmbeddings() From b6b88c8b23c9a211be15abc694166d1b5c50b3e7 Mon Sep 17 00:00:00 2001 From: Will Tai Date: Thu, 29 Feb 2024 17:49:12 +0000 Subject: [PATCH 11/23] Introduce pydantic data models for client inputs and outputs --- examples/similarity_search_for_text.py | 2 + examples/similarity_search_for_vector.py | 2 + neo4j_genai/src/client.py | 61 ++++++----- neo4j_genai/src/data_validators.py | 24 +++++ neo4j_genai/src/embeddings.py | 6 +- poetry.lock | 126 ++++++++++++++++++++++- pyproject.toml | 1 + 7 files changed, 188 insertions(+), 34 deletions(-) create mode 100644 neo4j_genai/src/data_validators.py diff --git a/examples/similarity_search_for_text.py b/examples/similarity_search_for_text.py index c640a7ef..8dc5f1c3 100644 --- a/examples/similarity_search_for_text.py +++ b/examples/similarity_search_for_text.py @@ -10,6 +10,8 @@ INDEX_NAME = "embedding-name" DIMENSION = 1536 + +# Connect to Neo4j database driver = GraphDatabase.driver(URI, auth=AUTH) # Create Embeddings object diff --git a/examples/similarity_search_for_vector.py b/examples/similarity_search_for_vector.py index d4668785..3e2f1b11 100644 --- a/examples/similarity_search_for_vector.py +++ b/examples/similarity_search_for_vector.py @@ -8,6 +8,8 @@ INDEX_NAME = "embedding-name" DIMENSION = 1536 + +# Connect to Neo4j database driver = GraphDatabase.driver(URI, auth=AUTH) # Initialize the client diff --git a/neo4j_genai/src/client.py b/neo4j_genai/src/client.py index cc12700b..1be35f56 100644 --- a/neo4j_genai/src/client.py +++ b/neo4j_genai/src/client.py @@ -2,7 +2,8 @@ from neo4j import Driver from neo4j.exceptions import CypherSyntaxError from src.embeddings import Embeddings - +from src.data_validators import CreateIndexModel, SimilaritySearchModel +from pydantic import ValidationError class GenAIClient: def __init__(self, driver: Driver, embeddings: Optional[Embeddings] = None) -> None: @@ -69,6 +70,18 @@ def create_index( This method constructs a Cypher query and executes it to create a new vector index in Neo4j. """ + index_data = { + "name": name, + "label": label, + "property": property, + "dimensions": dimensions, + "similarity_fn": similarity_fn, + } + try: + index_data = CreateIndexModel(**index_data) + except ValidationError as e: + raise ValueError(f"Error for inputs to create_index {str(e)}") + index_query = ( "CALL db.index.vector.createNodeIndex(" "$name," @@ -77,15 +90,7 @@ def create_index( "toInteger($dimensions)," "$similarity_fn )" ) - - parameters = { - "name": name, - "label": label, - "property": property, - "dimensions": dimensions, - "similarity_fn": similarity_fn, - } - self.database_query(driver, index_query, params=parameters) + self.database_query(driver, index_query, params=index_data.dict()) def drop_index(self, driver, name: str) -> None: """ @@ -109,29 +114,21 @@ def similarity_search( """ Performs the similarity search """ - if not ((query_vector is not None) ^ (query_text is not None)): - raise ValueError("You must provide one of query_vector or query_text.") - - if query_vector: - parameters = { - "index_name": name, - "top_k": top_k, - "vector": query_vector, - } - - if query_text: - # TODO: do we need to validate embeddings? Normalizing etc. - if self.embeddings: + try: + if query_vector: + validated_data = SimilaritySearchModel(index_name=name, top_k=top_k, vector=query_vector) + elif query_text: + if not self.embeddings: + raise ValueError("Embedding method required for text query.") vector_embedding = self.embeddings.embed_query(query_text) - parameters = { - "index_name": name, - "top_k": top_k, - "vector": vector_embedding, - } + validated_data = SimilaritySearchModel(index_name=name, top_k=top_k, vector=vector_embedding) else: - raise ValueError( - "Embedding method required to perform search for text query." - ) - + raise ValueError("Either query_vector or query_text must be provided.") + + parameters = validated_data.dict(exclude_none=True) + except ValidationError as e: + error_details = e.errors() + raise ValueError(f"Validation failed: {error_details}") + db_query_string = "CALL db.index.vector.queryNodes($index_name, $top_k, $vector) YIELD node, score" return self.database_query(driver, db_query_string, params=parameters) diff --git a/neo4j_genai/src/data_validators.py b/neo4j_genai/src/data_validators.py new file mode 100644 index 00000000..7cad1faf --- /dev/null +++ b/neo4j_genai/src/data_validators.py @@ -0,0 +1,24 @@ +from pydantic import BaseModel, PositiveInt, root_validator +from src.embeddings import EmbeddingVector +from typing import List, Literal, Optional + + +class CreateIndexModel(BaseModel): + name: str + label: str + property: str + dimensions: PositiveInt + similarity_fn: Literal["euclidean", "cosine"] + +class SimilaritySearchModel(BaseModel): + index_name: str + top_k: PositiveInt = 5 + vector: Optional[EmbeddingVector] = None + query_text: Optional[str] = None + + @root_validator(pre=True) + def check_query(cls, values): + vector, query_text = values.get("vector"), values.get("query_text") + if bool(vector) == bool(query_text): + raise ValueError("You must provide exactly one of query_vector or query_text.") + return values diff --git a/neo4j_genai/src/embeddings.py b/neo4j_genai/src/embeddings.py index 6f773d52..1a60f9ac 100644 --- a/neo4j_genai/src/embeddings.py +++ b/neo4j_genai/src/embeddings.py @@ -1,10 +1,14 @@ from abc import ABC, abstractmethod from typing import List +from pydantic import BaseModel +class EmbeddingVector(BaseModel): + vector: List[float] class Embeddings(ABC): """Interface for embedding models.""" @abstractmethod - def embed_query(self, text: str) -> List[float]: + def embed_query(self, text: str) -> EmbeddingVector: """Embed query text.""" + pass diff --git a/poetry.lock b/poetry.lock index e0032a7d..5a61a6ea 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,19 @@ # This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +[[package]] +name = "annotated-types" +version = "0.6.0" +description = "Reusable constraint types to use with typing.Annotated" +optional = false +python-versions = ">=3.8" +files = [ + {file = "annotated_types-0.6.0-py3-none-any.whl", hash = "sha256:0641064de18ba7a25dee8f96403ebc39113d0cb953a01429249d5c7564666a43"}, + {file = "annotated_types-0.6.0.tar.gz", hash = "sha256:563339e807e53ffd9c267e99fc6d9ea23eb8443c08f112651963e24e22f84a5d"}, +] + +[package.dependencies] +typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.9\""} + [[package]] name = "astroid" version = "3.1.0" @@ -278,6 +292,116 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] +[[package]] +name = "pydantic" +version = "2.6.3" +description = "Data validation using Python type hints" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pydantic-2.6.3-py3-none-any.whl", hash = "sha256:72c6034df47f46ccdf81869fddb81aade68056003900a8724a4f160700016a2a"}, + {file = "pydantic-2.6.3.tar.gz", hash = "sha256:e07805c4c7f5c6826e33a1d4c9d47950d7eaf34868e2690f8594d2e30241f11f"}, +] + +[package.dependencies] +annotated-types = ">=0.4.0" +pydantic-core = "2.16.3" +typing-extensions = ">=4.6.1" + +[package.extras] +email = ["email-validator (>=2.0.0)"] + +[[package]] +name = "pydantic-core" +version = "2.16.3" +description = "" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pydantic_core-2.16.3-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:75b81e678d1c1ede0785c7f46690621e4c6e63ccd9192af1f0bd9d504bbb6bf4"}, + {file = "pydantic_core-2.16.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9c865a7ee6f93783bd5d781af5a4c43dadc37053a5b42f7d18dc019f8c9d2bd1"}, + {file = "pydantic_core-2.16.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:162e498303d2b1c036b957a1278fa0899d02b2842f1ff901b6395104c5554a45"}, + {file = "pydantic_core-2.16.3-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2f583bd01bbfbff4eaee0868e6fc607efdfcc2b03c1c766b06a707abbc856187"}, + {file = "pydantic_core-2.16.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b926dd38db1519ed3043a4de50214e0d600d404099c3392f098a7f9d75029ff8"}, + {file = "pydantic_core-2.16.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:716b542728d4c742353448765aa7cdaa519a7b82f9564130e2b3f6766018c9ec"}, + {file = "pydantic_core-2.16.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc4ad7f7ee1a13d9cb49d8198cd7d7e3aa93e425f371a68235f784e99741561f"}, + {file = "pydantic_core-2.16.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:bd87f48924f360e5d1c5f770d6155ce0e7d83f7b4e10c2f9ec001c73cf475c99"}, + {file = "pydantic_core-2.16.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0df446663464884297c793874573549229f9eca73b59360878f382a0fc085979"}, + {file = "pydantic_core-2.16.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4df8a199d9f6afc5ae9a65f8f95ee52cae389a8c6b20163762bde0426275b7db"}, + {file = "pydantic_core-2.16.3-cp310-none-win32.whl", hash = "sha256:456855f57b413f077dff513a5a28ed838dbbb15082ba00f80750377eed23d132"}, + {file = "pydantic_core-2.16.3-cp310-none-win_amd64.whl", hash = "sha256:732da3243e1b8d3eab8c6ae23ae6a58548849d2e4a4e03a1924c8ddf71a387cb"}, + {file = "pydantic_core-2.16.3-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:519ae0312616026bf4cedc0fe459e982734f3ca82ee8c7246c19b650b60a5ee4"}, + {file = "pydantic_core-2.16.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b3992a322a5617ded0a9f23fd06dbc1e4bd7cf39bc4ccf344b10f80af58beacd"}, + {file = "pydantic_core-2.16.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8d62da299c6ecb04df729e4b5c52dc0d53f4f8430b4492b93aa8de1f541c4aac"}, + {file = "pydantic_core-2.16.3-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2acca2be4bb2f2147ada8cac612f8a98fc09f41c89f87add7256ad27332c2fda"}, + {file = "pydantic_core-2.16.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1b662180108c55dfbf1280d865b2d116633d436cfc0bba82323554873967b340"}, + {file = "pydantic_core-2.16.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e7c6ed0dc9d8e65f24f5824291550139fe6f37fac03788d4580da0d33bc00c97"}, + {file = "pydantic_core-2.16.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a6b1bb0827f56654b4437955555dc3aeeebeddc47c2d7ed575477f082622c49e"}, + {file = "pydantic_core-2.16.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e56f8186d6210ac7ece503193ec84104da7ceb98f68ce18c07282fcc2452e76f"}, + {file = "pydantic_core-2.16.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:936e5db01dd49476fa8f4383c259b8b1303d5dd5fb34c97de194560698cc2c5e"}, + {file = "pydantic_core-2.16.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:33809aebac276089b78db106ee692bdc9044710e26f24a9a2eaa35a0f9fa70ba"}, + {file = "pydantic_core-2.16.3-cp311-none-win32.whl", hash = "sha256:ded1c35f15c9dea16ead9bffcde9bb5c7c031bff076355dc58dcb1cb436c4721"}, + {file = "pydantic_core-2.16.3-cp311-none-win_amd64.whl", hash = "sha256:d89ca19cdd0dd5f31606a9329e309d4fcbb3df860960acec32630297d61820df"}, + {file = "pydantic_core-2.16.3-cp311-none-win_arm64.whl", hash = "sha256:6162f8d2dc27ba21027f261e4fa26f8bcb3cf9784b7f9499466a311ac284b5b9"}, + {file = "pydantic_core-2.16.3-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:0f56ae86b60ea987ae8bcd6654a887238fd53d1384f9b222ac457070b7ac4cff"}, + {file = "pydantic_core-2.16.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c9bd22a2a639e26171068f8ebb5400ce2c1bc7d17959f60a3b753ae13c632975"}, + {file = "pydantic_core-2.16.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4204e773b4b408062960e65468d5346bdfe139247ee5f1ca2a378983e11388a2"}, + {file = "pydantic_core-2.16.3-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f651dd19363c632f4abe3480a7c87a9773be27cfe1341aef06e8759599454120"}, + {file = "pydantic_core-2.16.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aaf09e615a0bf98d406657e0008e4a8701b11481840be7d31755dc9f97c44053"}, + {file = "pydantic_core-2.16.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8e47755d8152c1ab5b55928ab422a76e2e7b22b5ed8e90a7d584268dd49e9c6b"}, + {file = "pydantic_core-2.16.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:500960cb3a0543a724a81ba859da816e8cf01b0e6aaeedf2c3775d12ee49cade"}, + {file = "pydantic_core-2.16.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:cf6204fe865da605285c34cf1172879d0314ff267b1c35ff59de7154f35fdc2e"}, + {file = "pydantic_core-2.16.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d33dd21f572545649f90c38c227cc8631268ba25c460b5569abebdd0ec5974ca"}, + {file = "pydantic_core-2.16.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:49d5d58abd4b83fb8ce763be7794d09b2f50f10aa65c0f0c1696c677edeb7cbf"}, + {file = "pydantic_core-2.16.3-cp312-none-win32.whl", hash = "sha256:f53aace168a2a10582e570b7736cc5bef12cae9cf21775e3eafac597e8551fbe"}, + {file = "pydantic_core-2.16.3-cp312-none-win_amd64.whl", hash = "sha256:0d32576b1de5a30d9a97f300cc6a3f4694c428d956adbc7e6e2f9cad279e45ed"}, + {file = "pydantic_core-2.16.3-cp312-none-win_arm64.whl", hash = "sha256:ec08be75bb268473677edb83ba71e7e74b43c008e4a7b1907c6d57e940bf34b6"}, + {file = "pydantic_core-2.16.3-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:b1f6f5938d63c6139860f044e2538baeee6f0b251a1816e7adb6cbce106a1f01"}, + {file = "pydantic_core-2.16.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2a1ef6a36fdbf71538142ed604ad19b82f67b05749512e47f247a6ddd06afdc7"}, + {file = "pydantic_core-2.16.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:704d35ecc7e9c31d48926150afada60401c55efa3b46cd1ded5a01bdffaf1d48"}, + {file = "pydantic_core-2.16.3-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d937653a696465677ed583124b94a4b2d79f5e30b2c46115a68e482c6a591c8a"}, + {file = "pydantic_core-2.16.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c9803edf8e29bd825f43481f19c37f50d2b01899448273b3a7758441b512acf8"}, + {file = "pydantic_core-2.16.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:72282ad4892a9fb2da25defeac8c2e84352c108705c972db82ab121d15f14e6d"}, + {file = "pydantic_core-2.16.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f752826b5b8361193df55afcdf8ca6a57d0232653494ba473630a83ba50d8c9"}, + {file = "pydantic_core-2.16.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4384a8f68ddb31a0b0c3deae88765f5868a1b9148939c3f4121233314ad5532c"}, + {file = "pydantic_core-2.16.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:a4b2bf78342c40b3dc830880106f54328928ff03e357935ad26c7128bbd66ce8"}, + {file = "pydantic_core-2.16.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:13dcc4802961b5f843a9385fc821a0b0135e8c07fc3d9949fd49627c1a5e6ae5"}, + {file = "pydantic_core-2.16.3-cp38-none-win32.whl", hash = "sha256:e3e70c94a0c3841e6aa831edab1619ad5c511199be94d0c11ba75fe06efe107a"}, + {file = "pydantic_core-2.16.3-cp38-none-win_amd64.whl", hash = "sha256:ecdf6bf5f578615f2e985a5e1f6572e23aa632c4bd1dc67f8f406d445ac115ed"}, + {file = "pydantic_core-2.16.3-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:bda1ee3e08252b8d41fa5537413ffdddd58fa73107171a126d3b9ff001b9b820"}, + {file = "pydantic_core-2.16.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:21b888c973e4f26b7a96491c0965a8a312e13be108022ee510248fe379a5fa23"}, + {file = "pydantic_core-2.16.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:be0ec334369316fa73448cc8c982c01e5d2a81c95969d58b8f6e272884df0074"}, + {file = "pydantic_core-2.16.3-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b5b6079cc452a7c53dd378c6f881ac528246b3ac9aae0f8eef98498a75657805"}, + {file = "pydantic_core-2.16.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7ee8d5f878dccb6d499ba4d30d757111847b6849ae07acdd1205fffa1fc1253c"}, + {file = "pydantic_core-2.16.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7233d65d9d651242a68801159763d09e9ec96e8a158dbf118dc090cd77a104c9"}, + {file = "pydantic_core-2.16.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c6119dc90483a5cb50a1306adb8d52c66e447da88ea44f323e0ae1a5fcb14256"}, + {file = "pydantic_core-2.16.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:578114bc803a4c1ff9946d977c221e4376620a46cf78da267d946397dc9514a8"}, + {file = "pydantic_core-2.16.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:d8f99b147ff3fcf6b3cc60cb0c39ea443884d5559a30b1481e92495f2310ff2b"}, + {file = "pydantic_core-2.16.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:4ac6b4ce1e7283d715c4b729d8f9dab9627586dafce81d9eaa009dd7f25dd972"}, + {file = "pydantic_core-2.16.3-cp39-none-win32.whl", hash = "sha256:e7774b570e61cb998490c5235740d475413a1f6de823169b4cf94e2fe9e9f6b2"}, + {file = "pydantic_core-2.16.3-cp39-none-win_amd64.whl", hash = "sha256:9091632a25b8b87b9a605ec0e61f241c456e9248bfdcf7abdf344fdb169c81cf"}, + {file = "pydantic_core-2.16.3-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:36fa178aacbc277bc6b62a2c3da95226520da4f4e9e206fdf076484363895d2c"}, + {file = "pydantic_core-2.16.3-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:dcca5d2bf65c6fb591fff92da03f94cd4f315972f97c21975398bd4bd046854a"}, + {file = "pydantic_core-2.16.3-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2a72fb9963cba4cd5793854fd12f4cfee731e86df140f59ff52a49b3552db241"}, + {file = "pydantic_core-2.16.3-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b60cc1a081f80a2105a59385b92d82278b15d80ebb3adb200542ae165cd7d183"}, + {file = "pydantic_core-2.16.3-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:cbcc558401de90a746d02ef330c528f2e668c83350f045833543cd57ecead1ad"}, + {file = "pydantic_core-2.16.3-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:fee427241c2d9fb7192b658190f9f5fd6dfe41e02f3c1489d2ec1e6a5ab1e04a"}, + {file = "pydantic_core-2.16.3-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f4cb85f693044e0f71f394ff76c98ddc1bc0953e48c061725e540396d5c8a2e1"}, + {file = "pydantic_core-2.16.3-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:b29eeb887aa931c2fcef5aa515d9d176d25006794610c264ddc114c053bf96fe"}, + {file = "pydantic_core-2.16.3-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:a425479ee40ff021f8216c9d07a6a3b54b31c8267c6e17aa88b70d7ebd0e5e5b"}, + {file = "pydantic_core-2.16.3-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:5c5cbc703168d1b7a838668998308018a2718c2130595e8e190220238addc96f"}, + {file = "pydantic_core-2.16.3-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:99b6add4c0b39a513d323d3b93bc173dac663c27b99860dd5bf491b240d26137"}, + {file = "pydantic_core-2.16.3-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75f76ee558751746d6a38f89d60b6228fa174e5172d143886af0f85aa306fd89"}, + {file = "pydantic_core-2.16.3-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:00ee1c97b5364b84cb0bd82e9bbf645d5e2871fb8c58059d158412fee2d33d8a"}, + {file = "pydantic_core-2.16.3-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:287073c66748f624be4cef893ef9174e3eb88fe0b8a78dc22e88eca4bc357ca6"}, + {file = "pydantic_core-2.16.3-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:ed25e1835c00a332cb10c683cd39da96a719ab1dfc08427d476bce41b92531fc"}, + {file = "pydantic_core-2.16.3-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:86b3d0033580bd6bbe07590152007275bd7af95f98eaa5bd36f3da219dcd93da"}, + {file = "pydantic_core-2.16.3.tar.gz", hash = "sha256:1cac689f80a3abab2d3c0048b29eea5751114054f032a941a32de4c852c59cad"}, +] + +[package.dependencies] +typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" + [[package]] name = "pylint" version = "3.1.0" @@ -425,4 +549,4 @@ zstd = ["zstandard (>=0.18.0)"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "a2e97b40deca16a26490a2fce2e70c3742942ecbb026bca3e83219888530bfcf" +content-hash = "e9571707d11b67cf7226bd97e391882098e9ca6065c92ce148da425824aa7718" diff --git a/pyproject.toml b/pyproject.toml index 41d075ed..96aeff6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ neo4j = "^5.17.0" types-requests = "^2.31.0.20240218" pytest = "^8.0.2" pytest-mock = "^3.12.0" +pydantic = "^2.6.3" [tool.poetry.group.dev.dependencies] pylint = "^3.1.0" From f2fcaf6c4b4ee999c4579eb58950641433438414 Mon Sep 17 00:00:00 2001 From: Will Tai Date: Thu, 29 Feb 2024 17:51:53 +0000 Subject: [PATCH 12/23] Reformatting with black --- .gitignore | 1 + examples/similarity_search_for_text.py | 6 +++--- examples/similarity_search_for_vector.py | 4 +--- neo4j_genai/src/client.py | 15 ++++++++++----- neo4j_genai/src/data_validators.py | 5 ++++- neo4j_genai/src/embeddings.py | 2 ++ 6 files changed, 21 insertions(+), 12 deletions(-) diff --git a/.gitignore b/.gitignore index 77ed5a6c..d16f774f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ dist/ **/__pycache__/* *.py[cod] +.mypy_cache/ \ No newline at end of file diff --git a/examples/similarity_search_for_text.py b/examples/similarity_search_for_text.py index 8dc5f1c3..81a7ac37 100644 --- a/examples/similarity_search_for_text.py +++ b/examples/similarity_search_for_text.py @@ -14,11 +14,13 @@ # Connect to Neo4j database driver = GraphDatabase.driver(URI, auth=AUTH) + # Create Embeddings object class CustomEmbeddings(Embeddings): def embed_query(self, text: str) -> List[float]: return [random() for _ in range(1536)] + embeddings = CustomEmbeddings() # Initialize the client @@ -51,6 +53,4 @@ def embed_query(self, text: str) -> List[float]: # Perform the similarity search for a text query query_text = "hello world" -print(client.similarity_search( - driver, INDEX_NAME, query_text=query_text, top_k=5 -)) +print(client.similarity_search(driver, INDEX_NAME, query_text=query_text, top_k=5)) diff --git a/examples/similarity_search_for_vector.py b/examples/similarity_search_for_vector.py index 3e2f1b11..12a59487 100644 --- a/examples/similarity_search_for_vector.py +++ b/examples/similarity_search_for_vector.py @@ -42,6 +42,4 @@ # Perform the similarity search for a vector query query_vector = [random() for _ in range(DIMENSION)] -print(client.similarity_search( - driver, INDEX_NAME, query_vector=query_vector, top_k=5 -)) +print(client.similarity_search(driver, INDEX_NAME, query_vector=query_vector, top_k=5)) diff --git a/neo4j_genai/src/client.py b/neo4j_genai/src/client.py index 1be35f56..9100a292 100644 --- a/neo4j_genai/src/client.py +++ b/neo4j_genai/src/client.py @@ -5,6 +5,7 @@ from src.data_validators import CreateIndexModel, SimilaritySearchModel from pydantic import ValidationError + class GenAIClient: def __init__(self, driver: Driver, embeddings: Optional[Embeddings] = None) -> None: # Verify if the version supports vector index @@ -81,7 +82,7 @@ def create_index( index_data = CreateIndexModel(**index_data) except ValidationError as e: raise ValueError(f"Error for inputs to create_index {str(e)}") - + index_query = ( "CALL db.index.vector.createNodeIndex(" "$name," @@ -116,19 +117,23 @@ def similarity_search( """ try: if query_vector: - validated_data = SimilaritySearchModel(index_name=name, top_k=top_k, vector=query_vector) + validated_data = SimilaritySearchModel( + index_name=name, top_k=top_k, vector=query_vector + ) elif query_text: if not self.embeddings: raise ValueError("Embedding method required for text query.") vector_embedding = self.embeddings.embed_query(query_text) - validated_data = SimilaritySearchModel(index_name=name, top_k=top_k, vector=vector_embedding) + validated_data = SimilaritySearchModel( + index_name=name, top_k=top_k, vector=vector_embedding + ) else: raise ValueError("Either query_vector or query_text must be provided.") - + parameters = validated_data.dict(exclude_none=True) except ValidationError as e: error_details = e.errors() raise ValueError(f"Validation failed: {error_details}") - + db_query_string = "CALL db.index.vector.queryNodes($index_name, $top_k, $vector) YIELD node, score" return self.database_query(driver, db_query_string, params=parameters) diff --git a/neo4j_genai/src/data_validators.py b/neo4j_genai/src/data_validators.py index 7cad1faf..8d95a4d5 100644 --- a/neo4j_genai/src/data_validators.py +++ b/neo4j_genai/src/data_validators.py @@ -10,6 +10,7 @@ class CreateIndexModel(BaseModel): dimensions: PositiveInt similarity_fn: Literal["euclidean", "cosine"] + class SimilaritySearchModel(BaseModel): index_name: str top_k: PositiveInt = 5 @@ -20,5 +21,7 @@ class SimilaritySearchModel(BaseModel): def check_query(cls, values): vector, query_text = values.get("vector"), values.get("query_text") if bool(vector) == bool(query_text): - raise ValueError("You must provide exactly one of query_vector or query_text.") + raise ValueError( + "You must provide exactly one of query_vector or query_text." + ) return values diff --git a/neo4j_genai/src/embeddings.py b/neo4j_genai/src/embeddings.py index 1a60f9ac..a48433a3 100644 --- a/neo4j_genai/src/embeddings.py +++ b/neo4j_genai/src/embeddings.py @@ -2,9 +2,11 @@ from typing import List from pydantic import BaseModel + class EmbeddingVector(BaseModel): vector: List[float] + class Embeddings(ABC): """Interface for embedding models.""" From cffb9b83d41554c8d451b0137be9a804ed2e8bdb Mon Sep 17 00:00:00 2001 From: Will Tai Date: Fri, 1 Mar 2024 13:18:24 +0000 Subject: [PATCH 13/23] Addressed PR comments --- examples/similarity_search_for_vector.py | 6 +++++- neo4j_genai/src/client.py | 5 ++--- neo4j_genai/src/data_validators.py | 7 +++++-- neo4j_genai/src/embeddings.py | 6 +----- 4 files changed, 13 insertions(+), 11 deletions(-) diff --git a/examples/similarity_search_for_vector.py b/examples/similarity_search_for_vector.py index 12a59487..0304e711 100644 --- a/examples/similarity_search_for_vector.py +++ b/examples/similarity_search_for_vector.py @@ -42,4 +42,8 @@ # Perform the similarity search for a vector query query_vector = [random() for _ in range(DIMENSION)] -print(client.similarity_search(driver, INDEX_NAME, query_vector=query_vector, top_k=5)) +# retriever_query +client.similarity_search(driver, INDEX_NAME, query_vector=query_vector, top_k=5) + +retriever_query = "WITH node, score MATCH (node)-->(parent) RETURN parent.text" +client.similarity_search(driver, retriever_query, INDEX_NAME, query_vector=query_vector, top_k=5) diff --git a/neo4j_genai/src/client.py b/neo4j_genai/src/client.py index 9100a292..d463c939 100644 --- a/neo4j_genai/src/client.py +++ b/neo4j_genai/src/client.py @@ -10,7 +10,7 @@ class GenAIClient: def __init__(self, driver: Driver, embeddings: Optional[Embeddings] = None) -> None: # Verify if the version supports vector index self._verify_version(driver) - self.embeddings = embeddings if embeddings else None + self.embeddings = embeddings def _verify_version(self, driver: Driver) -> None: """ @@ -37,7 +37,7 @@ def _verify_version(self, driver: Driver) -> None: ) def database_query( - self, driver: Driver, query: str, params: Optional[Dict] = None + self, driver: Driver, query: str, params: Dict ={} ) -> List[Dict[str, Any]]: """ This method sends a Cypher query to the connected Neo4j database @@ -50,7 +50,6 @@ def database_query( Returns: List[Dict[str, Any]]: List of dictionaries containing the query results. """ - params = params or {} with driver.session() as session: try: data = session.run(query, params) diff --git a/neo4j_genai/src/data_validators.py b/neo4j_genai/src/data_validators.py index 8d95a4d5..f281cc87 100644 --- a/neo4j_genai/src/data_validators.py +++ b/neo4j_genai/src/data_validators.py @@ -1,8 +1,11 @@ from pydantic import BaseModel, PositiveInt, root_validator -from src.embeddings import EmbeddingVector from typing import List, Literal, Optional +class EmbeddingVector(BaseModel): + vector: List[float] + + class CreateIndexModel(BaseModel): name: str label: str @@ -20,7 +23,7 @@ class SimilaritySearchModel(BaseModel): @root_validator(pre=True) def check_query(cls, values): vector, query_text = values.get("vector"), values.get("query_text") - if bool(vector) == bool(query_text): + if vector and query_text: raise ValueError( "You must provide exactly one of query_vector or query_text." ) diff --git a/neo4j_genai/src/embeddings.py b/neo4j_genai/src/embeddings.py index a48433a3..208041d1 100644 --- a/neo4j_genai/src/embeddings.py +++ b/neo4j_genai/src/embeddings.py @@ -1,10 +1,6 @@ from abc import ABC, abstractmethod from typing import List -from pydantic import BaseModel - - -class EmbeddingVector(BaseModel): - vector: List[float] +from src.data_validators import EmbeddingVector class Embeddings(ABC): From 6234831affe4879f8d7c4c58f9c47872151278b1 Mon Sep 17 00:00:00 2001 From: Will Tai Date: Fri, 1 Mar 2024 15:46:01 +0000 Subject: [PATCH 14/23] moved code into src/ --- examples/similarity_search_for_text.py | 4 ++-- examples/similarity_search_for_vector.py | 7 ++++--- pyproject.toml | 4 ++-- {neo4j_genai/src => src/neo4j_genai}/__init__.py | 0 {neo4j_genai/src => src/neo4j_genai}/client.py | 9 +++++---- {neo4j_genai/src => src/neo4j_genai}/embeddings.py | 2 +- .../src/data_validators.py => src/neo4j_genai/types.py | 6 ++++++ 7 files changed, 20 insertions(+), 12 deletions(-) rename {neo4j_genai/src => src/neo4j_genai}/__init__.py (100%) rename {neo4j_genai/src => src/neo4j_genai}/client.py (95%) rename {neo4j_genai/src => src/neo4j_genai}/embeddings.py (83%) rename neo4j_genai/src/data_validators.py => src/neo4j_genai/types.py (92%) diff --git a/examples/similarity_search_for_text.py b/examples/similarity_search_for_text.py index 81a7ac37..996403ce 100644 --- a/examples/similarity_search_for_text.py +++ b/examples/similarity_search_for_text.py @@ -1,9 +1,9 @@ from typing import List from neo4j import GraphDatabase -from src.client import GenAIClient +from neo4j_genai.client import GenAIClient from random import random -from src.embeddings import Embeddings +from neo4j_genai.embeddings import Embeddings URI = "neo4j://localhost:7687" AUTH = ("neo4j", "password") diff --git a/examples/similarity_search_for_vector.py b/examples/similarity_search_for_vector.py index 0304e711..ae8236cf 100644 --- a/examples/similarity_search_for_vector.py +++ b/examples/similarity_search_for_vector.py @@ -1,5 +1,5 @@ from neo4j import GraphDatabase -from src.client import GenAIClient +from neo4j_genai.client import GenAIClient from random import random @@ -45,5 +45,6 @@ # retriever_query client.similarity_search(driver, INDEX_NAME, query_vector=query_vector, top_k=5) -retriever_query = "WITH node, score MATCH (node)-->(parent) RETURN parent.text" -client.similarity_search(driver, retriever_query, INDEX_NAME, query_vector=query_vector, top_k=5) +client.similarity_search( + driver, retriever_query, INDEX_NAME, query_vector=query_vector, top_k=5 +) diff --git a/pyproject.toml b/pyproject.toml index 96aeff6d..675c1cee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,8 +7,8 @@ license = "Apache License, Version 2.0" readme = "README.md" [[tool.poetry.packages]] -include = "src" -from = "neo4j_genai" +include = "neo4j_genai" +from = "src" [tool.poetry.dependencies] python = "^3.8" diff --git a/neo4j_genai/src/__init__.py b/src/neo4j_genai/__init__.py similarity index 100% rename from neo4j_genai/src/__init__.py rename to src/neo4j_genai/__init__.py diff --git a/neo4j_genai/src/client.py b/src/neo4j_genai/client.py similarity index 95% rename from neo4j_genai/src/client.py rename to src/neo4j_genai/client.py index d463c939..2e884169 100644 --- a/neo4j_genai/src/client.py +++ b/src/neo4j_genai/client.py @@ -1,8 +1,8 @@ from typing import List, Dict, Any, Optional from neo4j import Driver from neo4j.exceptions import CypherSyntaxError -from src.embeddings import Embeddings -from src.data_validators import CreateIndexModel, SimilaritySearchModel +from neo4j_genai.embeddings import Embeddings +from neo4j_genai.types import CreateIndexModel, SimilaritySearchModel from pydantic import ValidationError @@ -37,7 +37,7 @@ def _verify_version(self, driver: Driver) -> None: ) def database_query( - self, driver: Driver, query: str, params: Dict ={} + self, driver: Driver, query: str, params: Dict = {} ) -> List[Dict[str, Any]]: """ This method sends a Cypher query to the connected Neo4j database @@ -130,9 +130,10 @@ def similarity_search( raise ValueError("Either query_vector or query_text must be provided.") parameters = validated_data.dict(exclude_none=True) + except ValidationError as e: error_details = e.errors() raise ValueError(f"Validation failed: {error_details}") - db_query_string = "CALL db.index.vector.queryNodes($index_name, $top_k, $vector) YIELD node, score" + db_query_string = "CALL db.index.vector.queryNodes($index_name, $top_k, $vector) YIELD node, score, node.id AS id" return self.database_query(driver, db_query_string, params=parameters) diff --git a/neo4j_genai/src/embeddings.py b/src/neo4j_genai/embeddings.py similarity index 83% rename from neo4j_genai/src/embeddings.py rename to src/neo4j_genai/embeddings.py index 208041d1..237443f0 100644 --- a/neo4j_genai/src/embeddings.py +++ b/src/neo4j_genai/embeddings.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from typing import List -from src.data_validators import EmbeddingVector +from neo4j_genai.types import EmbeddingVector class Embeddings(ABC): diff --git a/neo4j_genai/src/data_validators.py b/src/neo4j_genai/types.py similarity index 92% rename from neo4j_genai/src/data_validators.py rename to src/neo4j_genai/types.py index f281cc87..807fd2f8 100644 --- a/neo4j_genai/src/data_validators.py +++ b/src/neo4j_genai/types.py @@ -2,6 +2,12 @@ from typing import List, Literal, Optional +class DatabaseQueryResult: + node + score: float + id: str + + class EmbeddingVector(BaseModel): vector: List[float] From 749dc457e370706fc2319c4170db18b27c339d75 Mon Sep 17 00:00:00 2001 From: Will Tai Date: Fri, 1 Mar 2024 17:11:40 +0000 Subject: [PATCH 15/23] Stopped passing driver argument to every method in GenAIClient --- examples/similarity_search_for_text.py | 7 ++- examples/similarity_search_for_vector.py | 12 ++---- src/neo4j_genai/client.py | 54 ++++++++++-------------- src/neo4j_genai/types.py | 14 +++--- 4 files changed, 36 insertions(+), 51 deletions(-) diff --git a/examples/similarity_search_for_text.py b/examples/similarity_search_for_text.py index 996403ce..6a439e48 100644 --- a/examples/similarity_search_for_text.py +++ b/examples/similarity_search_for_text.py @@ -26,11 +26,10 @@ def embed_query(self, text: str) -> List[float]: # Initialize the client client = GenAIClient(driver, embeddings) -client.drop_index(driver, INDEX_NAME) +client.drop_index(INDEX_NAME) # Creating the index client.create_index( - driver, INDEX_NAME, label="label", property="property", @@ -49,8 +48,8 @@ def embed_query(self, text: str) -> List[float]: "id": 1, "vector": vector, } -client.database_query(driver, insert_query, params=parameters) +client.database_query(insert_query, params=parameters) # Perform the similarity search for a text query query_text = "hello world" -print(client.similarity_search(driver, INDEX_NAME, query_text=query_text, top_k=5)) +print(client.similarity_search(INDEX_NAME, query_text=query_text, top_k=5)) diff --git a/examples/similarity_search_for_vector.py b/examples/similarity_search_for_vector.py index ae8236cf..418635ee 100644 --- a/examples/similarity_search_for_vector.py +++ b/examples/similarity_search_for_vector.py @@ -15,11 +15,10 @@ # Initialize the client client = GenAIClient(driver) -client.drop_index(driver, INDEX_NAME) +client.drop_index(INDEX_NAME) # Creating the index client.create_index( - driver, INDEX_NAME, label="label", property="property", @@ -38,13 +37,8 @@ "id": 1, "vector": vector, } -client.database_query(driver, insert_query, params=parameters) +client.database_query(insert_query, params=parameters) # Perform the similarity search for a vector query query_vector = [random() for _ in range(DIMENSION)] -# retriever_query -client.similarity_search(driver, INDEX_NAME, query_vector=query_vector, top_k=5) - -client.similarity_search( - driver, retriever_query, INDEX_NAME, query_vector=query_vector, top_k=5 -) +client.similarity_search(INDEX_NAME, query_vector=query_vector, top_k=5) diff --git a/src/neo4j_genai/client.py b/src/neo4j_genai/client.py index 2e884169..cec838da 100644 --- a/src/neo4j_genai/client.py +++ b/src/neo4j_genai/client.py @@ -9,10 +9,11 @@ class GenAIClient: def __init__(self, driver: Driver, embeddings: Optional[Embeddings] = None) -> None: # Verify if the version supports vector index - self._verify_version(driver) + self.driver = driver + self._verify_version() self.embeddings = embeddings - def _verify_version(self, driver: Driver) -> None: + def _verify_version(self) -> None: """ Check if the connected Neo4j database version supports vector indexing. @@ -21,7 +22,7 @@ def _verify_version(self, driver: Driver) -> None: indexing. Raises a ValueError if the connected Neo4j version is not supported. """ - version = self.database_query(driver, "CALL dbms.components()")[0]["versions"][ + version = self.database_query("CALL dbms.components()")[0]["versions"][ 0 ] if "aura" in version: @@ -37,7 +38,7 @@ def _verify_version(self, driver: Driver) -> None: ) def database_query( - self, driver: Driver, query: str, params: Dict = {} + self, query: str, params: Dict = {} ) -> List[Dict[str, Any]]: """ This method sends a Cypher query to the connected Neo4j database @@ -50,7 +51,7 @@ def database_query( Returns: List[Dict[str, Any]]: List of dictionaries containing the query results. """ - with driver.session() as session: + with self.driver.session() as session: try: data = session.run(query, params) return [r.data() for r in data] @@ -59,7 +60,6 @@ def database_query( def create_index( self, - driver: Driver, name: str, label: str, property: str, @@ -82,7 +82,7 @@ def create_index( except ValidationError as e: raise ValueError(f"Error for inputs to create_index {str(e)}") - index_query = ( + query = ( "CALL db.index.vector.createNodeIndex(" "$name," "$label," @@ -90,22 +90,21 @@ def create_index( "toInteger($dimensions)," "$similarity_fn )" ) - self.database_query(driver, index_query, params=index_data.dict()) + self.database_query(query, params=index_data.dict()) - def drop_index(self, driver, name: str) -> None: + def drop_index(self, name: str) -> None: """ This method constructs a Cypher query and executes it to drop a vector index in Neo4j. """ - index_query = "DROP INDEX $name" + query = "DROP INDEX $name" parameters = { "name": name, } - self.database_query(driver, index_query, params=parameters) + self.database_query(query, params=parameters) def similarity_search( self, - driver: Driver, name: str, query_vector: Optional[List[float]] = None, query_text: Optional[str] = None, @@ -115,25 +114,18 @@ def similarity_search( Performs the similarity search """ try: - if query_vector: - validated_data = SimilaritySearchModel( - index_name=name, top_k=top_k, vector=query_vector - ) - elif query_text: - if not self.embeddings: - raise ValueError("Embedding method required for text query.") - vector_embedding = self.embeddings.embed_query(query_text) - validated_data = SimilaritySearchModel( - index_name=name, top_k=top_k, vector=vector_embedding - ) - else: - raise ValueError("Either query_vector or query_text must be provided.") - - parameters = validated_data.dict(exclude_none=True) - + validated_data = SimilaritySearchModel( + index_name=name, top_k=top_k, query_vector=query_vector, query_text=query_text + ) except ValidationError as e: error_details = e.errors() raise ValueError(f"Validation failed: {error_details}") - - db_query_string = "CALL db.index.vector.queryNodes($index_name, $top_k, $vector) YIELD node, score, node.id AS id" - return self.database_query(driver, db_query_string, params=parameters) + + if query_text: + if not self.embeddings: + raise ValueError("Embedding method required for text query.") + query_vector = self.embeddings.embed_query(query_text) + + parameters = validated_data.dict(exclude_none=True) + db_query_string = "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) YIELD node, score" + return self.database_query(db_query_string, params=parameters) diff --git a/src/neo4j_genai/types.py b/src/neo4j_genai/types.py index 807fd2f8..eba3eb04 100644 --- a/src/neo4j_genai/types.py +++ b/src/neo4j_genai/types.py @@ -2,10 +2,10 @@ from typing import List, Literal, Optional -class DatabaseQueryResult: - node - score: float - id: str +# class DatabaseQueryResult: +# node +# score: float +# id: str class EmbeddingVector(BaseModel): @@ -23,13 +23,13 @@ class CreateIndexModel(BaseModel): class SimilaritySearchModel(BaseModel): index_name: str top_k: PositiveInt = 5 - vector: Optional[EmbeddingVector] = None + query_vector: Optional[EmbeddingVector] = None query_text: Optional[str] = None @root_validator(pre=True) def check_query(cls, values): - vector, query_text = values.get("vector"), values.get("query_text") - if vector and query_text: + query_vector, query_text = values.get("query_vector"), values.get("query_text") + if bool(query_vector) ^ bool(query_text): raise ValueError( "You must provide exactly one of query_vector or query_text." ) From 7a0b30acf57f7e9e4639e53ff3a437a06ce16080 Mon Sep 17 00:00:00 2001 From: Will Tai Date: Fri, 1 Mar 2024 17:12:08 +0000 Subject: [PATCH 16/23] Run black on src/ --- src/neo4j_genai/client.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/neo4j_genai/client.py b/src/neo4j_genai/client.py index cec838da..9ccafa9e 100644 --- a/src/neo4j_genai/client.py +++ b/src/neo4j_genai/client.py @@ -22,9 +22,7 @@ def _verify_version(self) -> None: indexing. Raises a ValueError if the connected Neo4j version is not supported. """ - version = self.database_query("CALL dbms.components()")[0]["versions"][ - 0 - ] + version = self.database_query("CALL dbms.components()")[0]["versions"][0] if "aura" in version: version_tuple = (*tuple(map(int, version.split("-")[0].split("."))), 0) else: @@ -37,9 +35,7 @@ def _verify_version(self) -> None: "Version index is only supported in Neo4j version 5.11 or greater" ) - def database_query( - self, query: str, params: Dict = {} - ) -> List[Dict[str, Any]]: + def database_query(self, query: str, params: Dict = {}) -> List[Dict[str, Any]]: """ This method sends a Cypher query to the connected Neo4j database and returns the results as a list of dictionaries. @@ -115,17 +111,20 @@ def similarity_search( """ try: validated_data = SimilaritySearchModel( - index_name=name, top_k=top_k, query_vector=query_vector, query_text=query_text + index_name=name, + top_k=top_k, + query_vector=query_vector, + query_text=query_text, ) except ValidationError as e: error_details = e.errors() raise ValueError(f"Validation failed: {error_details}") - + if query_text: if not self.embeddings: raise ValueError("Embedding method required for text query.") query_vector = self.embeddings.embed_query(query_text) - + parameters = validated_data.dict(exclude_none=True) db_query_string = "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) YIELD node, score" return self.database_query(db_query_string, params=parameters) From 3854ca7f83c85e69e241a82ad4fbba126dbc13d0 Mon Sep 17 00:00:00 2001 From: Will Tai Date: Mon, 4 Mar 2024 10:13:46 +0000 Subject: [PATCH 17/23] Fixed similarity search code and type --- src/neo4j_genai/client.py | 4 +++- src/neo4j_genai/types.py | 9 ++------- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/neo4j_genai/client.py b/src/neo4j_genai/client.py index 9ccafa9e..0a29b713 100644 --- a/src/neo4j_genai/client.py +++ b/src/neo4j_genai/client.py @@ -120,11 +120,13 @@ def similarity_search( error_details = e.errors() raise ValueError(f"Validation failed: {error_details}") + parameters = validated_data.dict(exclude_none=True) + if query_text: if not self.embeddings: raise ValueError("Embedding method required for text query.") query_vector = self.embeddings.embed_query(query_text) + parameters["query_vector"] = query_vector - parameters = validated_data.dict(exclude_none=True) db_query_string = "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) YIELD node, score" return self.database_query(db_query_string, params=parameters) diff --git a/src/neo4j_genai/types.py b/src/neo4j_genai/types.py index eba3eb04..6048c5a3 100644 --- a/src/neo4j_genai/types.py +++ b/src/neo4j_genai/types.py @@ -2,11 +2,6 @@ from typing import List, Literal, Optional -# class DatabaseQueryResult: -# node -# score: float -# id: str - class EmbeddingVector(BaseModel): vector: List[float] @@ -23,13 +18,13 @@ class CreateIndexModel(BaseModel): class SimilaritySearchModel(BaseModel): index_name: str top_k: PositiveInt = 5 - query_vector: Optional[EmbeddingVector] = None + query_vector: Optional[List[float]] = None query_text: Optional[str] = None @root_validator(pre=True) def check_query(cls, values): query_vector, query_text = values.get("query_vector"), values.get("query_text") - if bool(query_vector) ^ bool(query_text): + if not (bool(query_vector) ^ bool(query_text)): raise ValueError( "You must provide exactly one of query_vector or query_text." ) From d728839a1c6d4562fa580a12513dbe9278f8fbd4 Mon Sep 17 00:00:00 2001 From: Will Tai Date: Mon, 4 Mar 2024 11:41:47 +0000 Subject: [PATCH 18/23] Added returned type Neo4jRecord --- examples/similarity_search_for_text.py | 4 ++-- examples/similarity_search_for_vector.py | 4 ++-- src/neo4j_genai/client.py | 22 ++++++++++++++++------ src/neo4j_genai/embeddings.py | 2 -- src/neo4j_genai/types.py | 6 +++++- 5 files changed, 25 insertions(+), 13 deletions(-) diff --git a/examples/similarity_search_for_text.py b/examples/similarity_search_for_text.py index 6a439e48..027e8ca0 100644 --- a/examples/similarity_search_for_text.py +++ b/examples/similarity_search_for_text.py @@ -18,7 +18,7 @@ # Create Embeddings object class CustomEmbeddings(Embeddings): def embed_query(self, text: str) -> List[float]: - return [random() for _ in range(1536)] + return [random() for _ in range(DIMENSION)] embeddings = CustomEmbeddings() @@ -32,7 +32,7 @@ def embed_query(self, text: str) -> List[float]: client.create_index( INDEX_NAME, label="label", - property="property", + property="propertyKey", dimensions=DIMENSION, similarity_fn="euclidean", ) diff --git a/examples/similarity_search_for_vector.py b/examples/similarity_search_for_vector.py index 418635ee..9a723978 100644 --- a/examples/similarity_search_for_vector.py +++ b/examples/similarity_search_for_vector.py @@ -21,7 +21,7 @@ client.create_index( INDEX_NAME, label="label", - property="property", + property="propertyKey", dimensions=DIMENSION, similarity_fn="euclidean", ) @@ -41,4 +41,4 @@ # Perform the similarity search for a vector query query_vector = [random() for _ in range(DIMENSION)] -client.similarity_search(INDEX_NAME, query_vector=query_vector, top_k=5) +print(client.similarity_search(INDEX_NAME, query_vector=query_vector, top_k=5)) diff --git a/src/neo4j_genai/client.py b/src/neo4j_genai/client.py index 0a29b713..2168c04f 100644 --- a/src/neo4j_genai/client.py +++ b/src/neo4j_genai/client.py @@ -1,17 +1,22 @@ from typing import List, Dict, Any, Optional +from pydantic import ValidationError from neo4j import Driver from neo4j.exceptions import CypherSyntaxError from neo4j_genai.embeddings import Embeddings -from neo4j_genai.types import CreateIndexModel, SimilaritySearchModel -from pydantic import ValidationError +from neo4j_genai.types import CreateIndexModel, SimilaritySearchModel, Neo4jRecord class GenAIClient: - def __init__(self, driver: Driver, embeddings: Optional[Embeddings] = None) -> None: + def __init__( + self, + driver: Driver, + embeddings: Optional[Embeddings] = None, + ) -> None: # Verify if the version supports vector index self.driver = driver self._verify_version() self.embeddings = embeddings + self.embeddings = embeddings def _verify_version(self) -> None: """ @@ -105,7 +110,7 @@ def similarity_search( query_vector: Optional[List[float]] = None, query_text: Optional[str] = None, top_k: int = 5, - ) -> List[Dict[str, Any]]: + ) -> List[Neo4jRecord]: """ Performs the similarity search """ @@ -128,5 +133,10 @@ def similarity_search( query_vector = self.embeddings.embed_query(query_text) parameters["query_vector"] = query_vector - db_query_string = "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) YIELD node, score" - return self.database_query(db_query_string, params=parameters) + db_query_string = """ + CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) + YIELD node, score + """ + records = self.database_query(db_query_string, params=parameters) + + return [Neo4jRecord(node=record.node, score=record.score) for record in records] diff --git a/src/neo4j_genai/embeddings.py b/src/neo4j_genai/embeddings.py index 237443f0..7c0d9eb5 100644 --- a/src/neo4j_genai/embeddings.py +++ b/src/neo4j_genai/embeddings.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from typing import List from neo4j_genai.types import EmbeddingVector @@ -9,4 +8,3 @@ class Embeddings(ABC): @abstractmethod def embed_query(self, text: str) -> EmbeddingVector: """Embed query text.""" - pass diff --git a/src/neo4j_genai/types.py b/src/neo4j_genai/types.py index 6048c5a3..21ab7b08 100644 --- a/src/neo4j_genai/types.py +++ b/src/neo4j_genai/types.py @@ -1,7 +1,11 @@ +from typing import List, Any, Literal, Optional from pydantic import BaseModel, PositiveInt, root_validator -from typing import List, Literal, Optional +class Neo4jRecord(BaseModel): + node: Any + score: float + class EmbeddingVector(BaseModel): vector: List[float] From 4a8c20b805d52ffdaf383c6f6c8d57f552b0409b Mon Sep 17 00:00:00 2001 From: Will Tai Date: Mon, 4 Mar 2024 11:43:25 +0000 Subject: [PATCH 19/23] Fix: removed redundant embeddings setting --- src/neo4j_genai/client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/neo4j_genai/client.py b/src/neo4j_genai/client.py index 2168c04f..67bb3829 100644 --- a/src/neo4j_genai/client.py +++ b/src/neo4j_genai/client.py @@ -16,7 +16,6 @@ def __init__( self.driver = driver self._verify_version() self.embeddings = embeddings - self.embeddings = embeddings def _verify_version(self) -> None: """ From 9be9e5ec7e86f4c030fae7fa0f7b1045ecf6f696 Mon Sep 17 00:00:00 2001 From: Will Tai Date: Mon, 4 Mar 2024 13:31:49 +0000 Subject: [PATCH 20/23] Added README and docstrings --- README.md | 4 +++- src/neo4j_genai/client.py | 42 ++++++++++++++++++++++++++++++++--- src/neo4j_genai/embeddings.py | 9 +++++++- src/neo4j_genai/types.py | 7 ++++-- 4 files changed, 55 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 306936a2..d021e195 100644 --- a/README.md +++ b/README.md @@ -1 +1,3 @@ -# neo4j-genai-python +# Neo4j GenAI package for Python + +This repository contains the official Neo4j GenAI features for Python. diff --git a/src/neo4j_genai/client.py b/src/neo4j_genai/client.py index 67bb3829..4dcfc829 100644 --- a/src/neo4j_genai/client.py +++ b/src/neo4j_genai/client.py @@ -7,12 +7,15 @@ class GenAIClient: + """ + Provides functionality to use Neo4j's GenAI features + """ + def __init__( self, driver: Driver, embeddings: Optional[Embeddings] = None, ) -> None: - # Verify if the version supports vector index self.driver = driver self._verify_version() self.embeddings = embeddings @@ -69,6 +72,19 @@ def create_index( """ This method constructs a Cypher query and executes it to create a new vector index in Neo4j. + + See Cypher manual on [Create node index](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_vector_createNodeIndex) + + Args: + name (str): The unique name of the index. + label (str): The node label to be indexed. + property (str): The property key of a node which contains embedding values. + dimensions (int): Vector embedding dimension + similarity_fn (str): case-insensitive values for the vector similarity function: + ``euclidean`` or ``cosine``. + + Raises: + ValueError: If validation of the input arguments fail. """ index_data = { "name": name, @@ -96,6 +112,10 @@ def drop_index(self, name: str) -> None: """ This method constructs a Cypher query and executes it to drop a vector index in Neo4j. + See Cypher manual on [Drop vector indexes](https://neo4j.com/docs/cypher-manual/current/indexes-for-vector-search/#indexes-vector-drop) + + Args: + name (str): The name of the index to delete. """ query = "DROP INDEX $name" parameters = { @@ -110,8 +130,24 @@ def similarity_search( query_text: Optional[str] = None, top_k: int = 5, ) -> List[Neo4jRecord]: - """ - Performs the similarity search + """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. + See the following documentation for more details: + + - [Query a vector index](https://neo4j.com/docs/cypher-manual/current/indexes-for-vector-search/#indexes-vector-query) + - [db.index.vector.queryNodes()](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_vector_queryNodes) + + Args: + name (str): Refers to the unique name of the vector index to query. + query_vector (Optional[List[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None. + query_text (Optional[str], optional): The text to get the closest neighbors of. Defaults to None. + top_k (int, optional): The number of neighbors to return. Defaults to 5. + + Raises: + ValueError: If validation of the input arguments fail. + ValueError: If no embeddings is provided. + + Returns: + List[Neo4jRecord]: The `top_k` neighbors found in vector search with their nodes and scores. """ try: validated_data = SimilaritySearchModel( diff --git a/src/neo4j_genai/embeddings.py b/src/neo4j_genai/embeddings.py index 7c0d9eb5..1fe801ff 100644 --- a/src/neo4j_genai/embeddings.py +++ b/src/neo4j_genai/embeddings.py @@ -7,4 +7,11 @@ class Embeddings(ABC): @abstractmethod def embed_query(self, text: str) -> EmbeddingVector: - """Embed query text.""" + """Embed query text. + + Args: + text (str): Text to convert to vector embedding + + Returns: + EmbeddingVector: A vector embedding. + """ diff --git a/src/neo4j_genai/types.py b/src/neo4j_genai/types.py index 21ab7b08..83ac97e0 100644 --- a/src/neo4j_genai/types.py +++ b/src/neo4j_genai/types.py @@ -1,5 +1,5 @@ from typing import List, Any, Literal, Optional -from pydantic import BaseModel, PositiveInt, root_validator +from pydantic import BaseModel, PositiveInt, Field, root_validator class Neo4jRecord(BaseModel): @@ -15,7 +15,7 @@ class CreateIndexModel(BaseModel): name: str label: str property: str - dimensions: PositiveInt + dimensions: int = Field(ge=1, le=20) similarity_fn: Literal["euclidean", "cosine"] @@ -27,6 +27,9 @@ class SimilaritySearchModel(BaseModel): @root_validator(pre=True) def check_query(cls, values): + """ + Validates that one of either query_vector or query_text is provided exclusively. + """ query_vector, query_text = values.get("query_vector"), values.get("query_text") if not (bool(query_vector) ^ bool(query_text)): raise ValueError( From 88562533d49e27147dd8b0eeec66e36fb6136dc9 Mon Sep 17 00:00:00 2001 From: Will Tai Date: Mon, 4 Mar 2024 15:01:44 +0000 Subject: [PATCH 21/23] Catches validationerror if fail to construct Neo4jRecord, fixed examples --- examples/similarity_search_for_text.py | 12 ++++++++---- examples/similarity_search_for_vector.py | 11 +++++++---- src/neo4j_genai/client.py | 6 +++++- src/neo4j_genai/types.py | 2 +- 4 files changed, 21 insertions(+), 10 deletions(-) diff --git a/examples/similarity_search_for_text.py b/examples/similarity_search_for_text.py index 027e8ca0..17be0536 100644 --- a/examples/similarity_search_for_text.py +++ b/examples/similarity_search_for_text.py @@ -1,5 +1,6 @@ from typing import List from neo4j import GraphDatabase +from neo4j.exceptions import DatabaseError from neo4j_genai.client import GenAIClient from random import random @@ -26,12 +27,15 @@ def embed_query(self, text: str) -> List[float]: # Initialize the client client = GenAIClient(driver, embeddings) -client.drop_index(INDEX_NAME) +try: + client.drop_index(INDEX_NAME) +except DatabaseError as e: + print(e) # Creating the index client.create_index( INDEX_NAME, - label="label", + label="Document", property="propertyKey", dimensions=DIMENSION, similarity_fn="euclidean", @@ -40,12 +44,12 @@ def embed_query(self, text: str) -> List[float]: # Upsert the query vector = [random() for _ in range(DIMENSION)] insert_query = ( - "MATCH (n:Node {id: $id})" + "MERGE (n:Document)" + "WITH n " "CALL db.create.setNodeVectorProperty(n, 'propertyKey', $vector)" "RETURN n" ) parameters = { - "id": 1, "vector": vector, } client.database_query(insert_query, params=parameters) diff --git a/examples/similarity_search_for_vector.py b/examples/similarity_search_for_vector.py index 9a723978..b4237376 100644 --- a/examples/similarity_search_for_vector.py +++ b/examples/similarity_search_for_vector.py @@ -15,12 +15,15 @@ # Initialize the client client = GenAIClient(driver) -client.drop_index(INDEX_NAME) +try: + client.drop_index(INDEX_NAME) +except DatabaseError as e: + print(e) # Creating the index client.create_index( INDEX_NAME, - label="label", + label="Document", property="propertyKey", dimensions=DIMENSION, similarity_fn="euclidean", @@ -29,12 +32,12 @@ # Upsert the vector vector = [random() for _ in range(DIMENSION)] insert_query = ( - "MATCH (n:Node {id: $id})" + "MERGE (n:Document)" + "WITH n " "CALL db.create.setNodeVectorProperty(n, 'propertyKey', $vector)" "RETURN n" ) parameters = { - "id": 1, "vector": vector, } client.database_query(insert_query, params=parameters) diff --git a/src/neo4j_genai/client.py b/src/neo4j_genai/client.py index 4dcfc829..eb9714e2 100644 --- a/src/neo4j_genai/client.py +++ b/src/neo4j_genai/client.py @@ -174,4 +174,8 @@ def similarity_search( """ records = self.database_query(db_query_string, params=parameters) - return [Neo4jRecord(node=record.node, score=record.score) for record in records] + try: + return [Neo4jRecord(node=record["node"], score=record["score"]) for record in records] + except ValidationError as e: + error_details = e.errors() + raise ValueError(f"Validation failed while constructing output: {error_details}") diff --git a/src/neo4j_genai/types.py b/src/neo4j_genai/types.py index 83ac97e0..82da6699 100644 --- a/src/neo4j_genai/types.py +++ b/src/neo4j_genai/types.py @@ -15,7 +15,7 @@ class CreateIndexModel(BaseModel): name: str label: str property: str - dimensions: int = Field(ge=1, le=20) + dimensions: int = Field(ge=1, le=2048) similarity_fn: Literal["euclidean", "cosine"] From 83725a260983a10d393e1adc5f3a5b6ca6f1aa7c Mon Sep 17 00:00:00 2001 From: Will Tai Date: Mon, 4 Mar 2024 15:40:39 +0000 Subject: [PATCH 22/23] removed drop_index from python examples --- examples/similarity_search_for_text.py | 5 ----- examples/similarity_search_for_vector.py | 7 +------ 2 files changed, 1 insertion(+), 11 deletions(-) diff --git a/examples/similarity_search_for_text.py b/examples/similarity_search_for_text.py index 17be0536..1b93b225 100644 --- a/examples/similarity_search_for_text.py +++ b/examples/similarity_search_for_text.py @@ -27,11 +27,6 @@ def embed_query(self, text: str) -> List[float]: # Initialize the client client = GenAIClient(driver, embeddings) -try: - client.drop_index(INDEX_NAME) -except DatabaseError as e: - print(e) - # Creating the index client.create_index( INDEX_NAME, diff --git a/examples/similarity_search_for_vector.py b/examples/similarity_search_for_vector.py index b4237376..345c9a77 100644 --- a/examples/similarity_search_for_vector.py +++ b/examples/similarity_search_for_vector.py @@ -7,7 +7,7 @@ AUTH = ("neo4j", "password") INDEX_NAME = "embedding-name" -DIMENSION = 1536 +DIMENSION = 2049 # Connect to Neo4j database driver = GraphDatabase.driver(URI, auth=AUTH) @@ -15,11 +15,6 @@ # Initialize the client client = GenAIClient(driver) -try: - client.drop_index(INDEX_NAME) -except DatabaseError as e: - print(e) - # Creating the index client.create_index( INDEX_NAME, From 6832d795e56c32576b8747b44402cdc2d794b11d Mon Sep 17 00:00:00 2001 From: Will Tai Date: Mon, 4 Mar 2024 15:51:05 +0000 Subject: [PATCH 23/23] Reverted dimensions in example to 1536 --- examples/similarity_search_for_vector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/similarity_search_for_vector.py b/examples/similarity_search_for_vector.py index 345c9a77..157dc29b 100644 --- a/examples/similarity_search_for_vector.py +++ b/examples/similarity_search_for_vector.py @@ -7,7 +7,7 @@ AUTH = ("neo4j", "password") INDEX_NAME = "embedding-name" -DIMENSION = 2049 +DIMENSION = 1536 # Connect to Neo4j database driver = GraphDatabase.driver(URI, auth=AUTH)