From cf6662ce57ff0318e5d39b6ed706e26192216fac Mon Sep 17 00:00:00 2001 From: estelle Date: Wed, 16 Oct 2024 16:16:11 +0200 Subject: [PATCH] async_to_sync wrapper --- .../experimental/components/resolver.py | 3 +++ .../experimental/pipeline/component.py | 3 +++ .../experimental/pipeline/pipeline.py | 6 ++++++ src/neo4j_graphrag/utils.py | 20 ++++++------------- 4 files changed, 18 insertions(+), 14 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/resolver.py b/src/neo4j_graphrag/experimental/components/resolver.py index f2da0bff..8d05d578 100644 --- a/src/neo4j_graphrag/experimental/components/resolver.py +++ b/src/neo4j_graphrag/experimental/components/resolver.py @@ -19,6 +19,7 @@ from neo4j_graphrag.experimental.components.types import ResolutionStats from neo4j_graphrag.experimental.pipeline import Component +from neo4j_graphrag.utils import async_to_sync class EntityResolver(Component, abc.ABC): @@ -136,3 +137,5 @@ async def run(self) -> ResolutionStats: number_of_nodes_to_resolve=number_of_nodes_to_resolve, number_of_created_nodes=number_of_created_nodes, ) + + run_sync = async_to_sync(run) diff --git a/src/neo4j_graphrag/experimental/pipeline/component.py b/src/neo4j_graphrag/experimental/pipeline/component.py index 84cd5bc0..efbd9bcb 100644 --- a/src/neo4j_graphrag/experimental/pipeline/component.py +++ b/src/neo4j_graphrag/experimental/pipeline/component.py @@ -21,6 +21,7 @@ from pydantic import BaseModel from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError +from neo4j_graphrag.utils import async_to_sync class DataModel(BaseModel): @@ -63,6 +64,8 @@ def __new__( } for f, field in return_model.model_fields.items() } + # create sync method: + attrs["run_sync"] = async_to_sync(run_method) return type.__new__(meta, name, bases, attrs) diff --git a/src/neo4j_graphrag/experimental/pipeline/pipeline.py b/src/neo4j_graphrag/experimental/pipeline/pipeline.py index e3ded494..3ab1d12e 100644 --- a/src/neo4j_graphrag/experimental/pipeline/pipeline.py +++ b/src/neo4j_graphrag/experimental/pipeline/pipeline.py @@ -24,6 +24,8 @@ from timeit import default_timer from typing import Any, AsyncGenerator, Optional +from neo4j_graphrag.utils import async_to_sync + try: import pygraphviz as pgv except ImportError: @@ -107,6 +109,8 @@ async def run(self, inputs: dict[str, Any]) -> RunResult | None: logger.debug(f"TASK RESULT {self.name=} {res=}") return res + run_sync = async_to_sync(run) + class Orchestrator: """Orchestrate a pipeline. @@ -638,3 +642,5 @@ async def run(self, data: dict[str, Any]) -> PipelineResult: run_id=orchestrator.run_id, result=await self.final_results.get(orchestrator.run_id), ) + + run_sync = async_to_sync(run) diff --git a/src/neo4j_graphrag/utils.py b/src/neo4j_graphrag/utils.py index 60e4130d..1962630f 100644 --- a/src/neo4j_graphrag/utils.py +++ b/src/neo4j_graphrag/utils.py @@ -14,6 +14,7 @@ # limitations under the License. from __future__ import annotations +from functools import wraps from typing import Optional import asyncio import concurrent.futures @@ -33,17 +34,8 @@ def run_sync(function, *args, **kwargs): return return_value -if __name__ == "__main__": - async def async_run(char: str, repeat: int = 2) -> str: - await asyncio.sleep(5) - return char * repeat - - async def async_run_multiple(char, n=10): - return await asyncio.gather(*[ - async_run(char) - for _ in range(n) - ]) - - print( - run_sync(async_run_multiple, "abc") - ) +def async_to_sync(func): + @wraps(func) + def wrapper(*args, **kwargs): + return run_sync(func, *args, **kwargs) + return wrapper