From 711aa21c9e09c7787f7b707b05daf9fed85d4787 Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 15 Oct 2024 20:28:09 +0200 Subject: [PATCH 1/2] Implement a sync run method - experiment 1 --- examples/pipeline/kg_builder_example.py | 3 ++- .../experimental/pipeline/kg_builder.py | 5 ++++ src/neo4j_graphrag/utils.py | 26 +++++++++++++++++++ 3 files changed, 33 insertions(+), 1 deletion(-) diff --git a/examples/pipeline/kg_builder_example.py b/examples/pipeline/kg_builder_example.py index 03acb1ac..8dc28a2f 100644 --- a/examples/pipeline/kg_builder_example.py +++ b/examples/pipeline/kg_builder_example.py @@ -79,7 +79,8 @@ async def main(neo4j_driver: neo4j.Driver) -> None: # Run the knowledge graph building process with text input text_input = "John Doe lives in New York City." - text_result = await kg_builder_text.run_async(text=text_input) + # text_result = await kg_builder_text.run_async(text=text_input) + text_result = kg_builder_text.run(text=text_input) print(f"Text Processing Result: {text_result}") await llm.async_client.close() diff --git a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py index 9f7ff488..5d7d6bf0 100644 --- a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py @@ -43,6 +43,7 @@ from neo4j_graphrag.experimental.pipeline.pipeline import Pipeline, PipelineResult from neo4j_graphrag.generation.prompts import ERExtractionTemplate from neo4j_graphrag.llm.base import LLMInterface +from neo4j_graphrag.utils import run_sync class SimpleKGPipelineConfig(BaseModel): @@ -225,6 +226,10 @@ async def run_async( pipe_inputs = self._prepare_inputs(file_path=file_path, text=text) return await self.pipeline.run(pipe_inputs) + def run(self, file_path: Optional[str] = None, text: Optional[str] = None) -> PipelineResult: + """Run the pipeline synchronously (blocking counterpart of run_async).""" + return run_sync(self.run_async, file_path=file_path, text=text) + def _prepare_inputs( self, 
file_path: Optional[str], text: Optional[str] ) -> dict[str, Any]: diff --git a/src/neo4j_graphrag/utils.py b/src/neo4j_graphrag/utils.py index 5f1b322f..16a3e35f 100644 --- a/src/neo4j_graphrag/utils.py +++ b/src/neo4j_graphrag/utils.py @@ -16,6 +16,8 @@ import inspect from typing import Any, Optional, Union +import asyncio +import concurrent.futures import neo4j @@ -37,3 +39,27 @@ async def execute_query( # but we're sure at this stage we do not have a coroutine anymore records, _, _ = driver.execute_query(query, **kwargs) # type: ignore[misc] return records # type: ignore[no-any-return] + + +def run_sync(function, *args, **kwargs): + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(lambda: asyncio.run(function(*args, **kwargs))) + return_value = future.result() + return return_value + + + +if __name__ == "__main__": + async def async_run(char: str, repeat: int = 2) -> str: + await asyncio.sleep(5) + return char * repeat + + async def async_run_multiple(char, n=10): + return await asyncio.gather(*[ + async_run(char) + for _ in range(n) + ]) + + print( + run_sync(async_run_multiple, "abc") + ) From 4a37035a0e08591b954e468f2759806398358763 Mon Sep 17 00:00:00 2001 From: estelle Date: Wed, 16 Oct 2024 16:16:11 +0200 Subject: [PATCH 2/2] async_to_sync wrapper --- .../experimental/components/resolver.py | 4 +++- .../experimental/pipeline/component.py | 3 +++ .../experimental/pipeline/pipeline.py | 6 ++++++ src/neo4j_graphrag/utils.py | 21 ++++++------------- 4 files changed, 18 insertions(+), 16 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/resolver.py b/src/neo4j_graphrag/experimental/components/resolver.py index 4e07a1d6..999b3084 100644 --- a/src/neo4j_graphrag/experimental/components/resolver.py +++ b/src/neo4j_graphrag/experimental/components/resolver.py @@ -19,7 +19,7 @@ from neo4j_graphrag.experimental.components.types import ResolutionStats from neo4j_graphrag.experimental.pipeline import Component 
-from neo4j_graphrag.utils import execute_query +from neo4j_graphrag.utils import execute_query, async_to_sync class EntityResolver(Component, abc.ABC): @@ -140,3 +140,5 @@ async def run(self) -> ResolutionStats: number_of_nodes_to_resolve=number_of_nodes_to_resolve, number_of_created_nodes=number_of_created_nodes, ) + + run_sync = async_to_sync(run) diff --git a/src/neo4j_graphrag/experimental/pipeline/component.py b/src/neo4j_graphrag/experimental/pipeline/component.py index 84cd5bc0..efbd9bcb 100644 --- a/src/neo4j_graphrag/experimental/pipeline/component.py +++ b/src/neo4j_graphrag/experimental/pipeline/component.py @@ -21,6 +21,7 @@ from pydantic import BaseModel from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError +from neo4j_graphrag.utils import async_to_sync class DataModel(BaseModel): @@ -63,6 +64,8 @@ def __new__( } for f, field in return_model.model_fields.items() } + # create sync method: + attrs["run_sync"] = async_to_sync(run_method) return type.__new__(meta, name, bases, attrs) diff --git a/src/neo4j_graphrag/experimental/pipeline/pipeline.py b/src/neo4j_graphrag/experimental/pipeline/pipeline.py index 3d004eb8..a8a6efed 100644 --- a/src/neo4j_graphrag/experimental/pipeline/pipeline.py +++ b/src/neo4j_graphrag/experimental/pipeline/pipeline.py @@ -24,6 +24,8 @@ from timeit import default_timer from typing import Any, AsyncGenerator, Optional +from neo4j_graphrag.utils import async_to_sync + try: import pygraphviz as pgv except ImportError: @@ -105,6 +107,8 @@ async def run(self, inputs: dict[str, Any]) -> RunResult | None: logger.debug(f"TASK RESULT {self.name=} {res=}") return res + run_sync = async_to_sync(run) + class Orchestrator: """Orchestrate a pipeline. 
@@ -618,3 +622,5 @@ async def run(self, data: dict[str, Any]) -> PipelineResult: run_id=orchestrator.run_id, result=await self.final_results.get(orchestrator.run_id), ) + + run_sync = async_to_sync(run) diff --git a/src/neo4j_graphrag/utils.py b/src/neo4j_graphrag/utils.py index 16a3e35f..3a0f7419 100644 --- a/src/neo4j_graphrag/utils.py +++ b/src/neo4j_graphrag/utils.py @@ -15,6 +15,7 @@ from __future__ import annotations import inspect +from functools import wraps from typing import Any, Optional, Union import asyncio import concurrent.futures @@ -48,18 +49,8 @@ def run_sync(function, *args, **kwargs): return return_value - -if __name__ == "__main__": - async def async_run(char: str, repeat: int = 2) -> str: - await asyncio.sleep(5) - return char * repeat - - async def async_run_multiple(char, n=10): - return await asyncio.gather(*[ - async_run(char) - for _ in range(n) - ]) - - print( - run_sync(async_run_multiple, "abc") - ) +def async_to_sync(func): + @wraps(func) + def wrapper(*args, **kwargs): + return run_sync(func, *args, **kwargs) + return wrapper