From a2aea0504cf61f9e49d2d429977c54f646022240 Mon Sep 17 00:00:00 2001
From: estelle
Date: Tue, 15 Oct 2024 20:28:09 +0200
Subject: [PATCH] Implement a sync run method - experiment 1

---
 examples/pipeline/kg_builder_example.py       |  0
 .../experimental/pipeline/kg_builder.py       | 13 +++++
 src/neo4j_graphrag/utils.py                   | 46 ++++++++++++++++++-
 3 files changed, 58 insertions(+), 1 deletion(-)
 create mode 100644 examples/pipeline/kg_builder_example.py

diff --git a/examples/pipeline/kg_builder_example.py b/examples/pipeline/kg_builder_example.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py
index 3fca0215..e459fb25 100644
--- a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py
+++ b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py
@@ -34,6 +34,7 @@
 )
 from neo4j_graphrag.generation.prompts import ERExtractionTemplate
 from neo4j_graphrag.llm.base import LLMInterface
+from neo4j_graphrag.utils import run_sync
 
 
 class SimpleKGPipeline:
@@ -124,3 +125,15 @@ async def run_async(
             PipelineResult: The result of the pipeline execution.
         """
         return await self.runner.run({"file_path": file_path, "text": text})
+
+    def run(
+        self, file_path: Optional[str] = None, text: Optional[str] = None
+    ) -> PipelineResult:
+        """Run the pipeline synchronously.
+
+        Blocking wrapper that forwards `file_path` and `text` to `run_async`.
+
+        Returns:
+            PipelineResult: The result of the pipeline execution.
+        """
+        return run_sync(self.run_async, file_path=file_path, text=text)
diff --git a/src/neo4j_graphrag/utils.py b/src/neo4j_graphrag/utils.py
index e86f7588..60e4130d 100644
--- a/src/neo4j_graphrag/utils.py
+++ b/src/neo4j_graphrag/utils.py
@@ -15,6 +15,8 @@
 from __future__ import annotations
 
-from typing import Optional
+import asyncio
+import concurrent.futures
+from typing import Any, Callable, Coroutine, Optional, TypeVar
 
 
 def validate_search_query_input(
@@ -22,3 +24,45 @@ def validate_search_query_input(
 ) -> None:
     if not (bool(query_vector) ^ bool(query_text)):
         raise ValueError("You must provide exactly one of query_vector or query_text.")
+
+
+T = TypeVar("T")
+
+
+def run_sync(
+    function: Callable[..., Coroutine[Any, Any, T]], *args: Any, **kwargs: Any
+) -> T:
+    """Run a coroutine function to completion from synchronous code.
+
+    The coroutine is driven by `asyncio.run` in a worker thread, so this
+    also works when the calling thread already has a running event loop.
+    The call blocks until the coroutine finishes.
+
+    Args:
+        function: Coroutine function to call.
+        *args: Positional arguments forwarded to `function`.
+        **kwargs: Keyword arguments forwarded to `function`.
+
+    Returns:
+        Whatever the coroutine returns.
+
+    Raises:
+        Exception: Any exception raised by the coroutine is re-raised here.
+    """
+    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
+        future = executor.submit(lambda: asyncio.run(function(*args, **kwargs)))
+        return future.result()
+
+
+if __name__ == "__main__":
+    # Manual smoke test: the ten sleeps run concurrently, so this takes about
+    # 5 seconds rather than 50.
+
+    async def async_run(char: str, repeat: int = 2) -> str:
+        await asyncio.sleep(5)
+        return char * repeat
+
+    async def async_run_multiple(char: str, n: int = 10) -> list[str]:
+        return await asyncio.gather(*[async_run(char) for _ in range(n)])
+
+    print(run_sync(async_run_multiple, "abc"))