Skip to content

Commit

Permalink
async_to_sync wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
stellasia committed Jan 2, 2025
1 parent a2aea05 commit cf6662c
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 14 deletions.
3 changes: 3 additions & 0 deletions src/neo4j_graphrag/experimental/components/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from neo4j_graphrag.experimental.components.types import ResolutionStats
from neo4j_graphrag.experimental.pipeline import Component
from neo4j_graphrag.utils import async_to_sync


class EntityResolver(Component, abc.ABC):
Expand Down Expand Up @@ -136,3 +137,5 @@ async def run(self) -> ResolutionStats:
number_of_nodes_to_resolve=number_of_nodes_to_resolve,
number_of_created_nodes=number_of_created_nodes,
)

run_sync = async_to_sync(run)
3 changes: 3 additions & 0 deletions src/neo4j_graphrag/experimental/pipeline/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from pydantic import BaseModel

from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError
from neo4j_graphrag.utils import async_to_sync


class DataModel(BaseModel):
Expand Down Expand Up @@ -63,6 +64,8 @@ def __new__(
}
for f, field in return_model.model_fields.items()
}
# create sync method:
attrs["run_sync"] = async_to_sync(run_method)
return type.__new__(meta, name, bases, attrs)


Expand Down
6 changes: 6 additions & 0 deletions src/neo4j_graphrag/experimental/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from timeit import default_timer
from typing import Any, AsyncGenerator, Optional

from neo4j_graphrag.utils import async_to_sync

try:
import pygraphviz as pgv
except ImportError:
Expand Down Expand Up @@ -107,6 +109,8 @@ async def run(self, inputs: dict[str, Any]) -> RunResult | None:
logger.debug(f"TASK RESULT {self.name=} {res=}")
return res

run_sync = async_to_sync(run)


class Orchestrator:
"""Orchestrate a pipeline.
Expand Down Expand Up @@ -638,3 +642,5 @@ async def run(self, data: dict[str, Any]) -> PipelineResult:
run_id=orchestrator.run_id,
result=await self.final_results.get(orchestrator.run_id),
)

run_sync = async_to_sync(run)
20 changes: 6 additions & 14 deletions src/neo4j_graphrag/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
from __future__ import annotations

from functools import wraps
from typing import Optional
import asyncio
import concurrent.futures
Expand All @@ -33,17 +34,8 @@ def run_sync(function, *args, **kwargs):
return return_value


if __name__ == "__main__":
async def async_run(char: str, repeat: int = 2) -> str:
await asyncio.sleep(5)
return char * repeat

async def async_run_multiple(char, n=10):
return await asyncio.gather(*[
async_run(char)
for _ in range(n)
])

print(
run_sync(async_run_multiple, "abc")
)
def async_to_sync(func):
@wraps(func)
def wrapper(*args, **kwargs):
return run_sync(func, *args, **kwargs)
return wrapper

0 comments on commit cf6662c

Please sign in to comment.