Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PoC: Feature/sync mode #189

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/pipeline/kg_builder_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ async def main(neo4j_driver: neo4j.Driver) -> None:

# Run the knowledge graph building process with text input
text_input = "John Doe lives in New York City."
text_result = await kg_builder_text.run_async(text=text_input)
# text_result = await kg_builder_text.run_async(text=text_input)
text_result = kg_builder_text.run(text=text_input)
print(f"Text Processing Result: {text_result}")

await llm.async_client.close()
Expand Down
4 changes: 3 additions & 1 deletion src/neo4j_graphrag/experimental/components/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from neo4j_graphrag.experimental.components.types import ResolutionStats
from neo4j_graphrag.experimental.pipeline import Component
from neo4j_graphrag.utils import execute_query
from neo4j_graphrag.utils import execute_query, async_to_sync


class EntityResolver(Component, abc.ABC):
Expand Down Expand Up @@ -140,3 +140,5 @@ async def run(self) -> ResolutionStats:
number_of_nodes_to_resolve=number_of_nodes_to_resolve,
number_of_created_nodes=number_of_created_nodes,
)

run_sync = async_to_sync(run)
3 changes: 3 additions & 0 deletions src/neo4j_graphrag/experimental/pipeline/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from pydantic import BaseModel

from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError
from neo4j_graphrag.utils import async_to_sync


class DataModel(BaseModel):
Expand Down Expand Up @@ -63,6 +64,8 @@ def __new__(
}
for f, field in return_model.model_fields.items()
}
# create sync method:
attrs["run_sync"] = async_to_sync(run_method)
return type.__new__(meta, name, bases, attrs)


Expand Down
5 changes: 5 additions & 0 deletions src/neo4j_graphrag/experimental/pipeline/kg_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from neo4j_graphrag.experimental.pipeline.pipeline import Pipeline, PipelineResult
from neo4j_graphrag.generation.prompts import ERExtractionTemplate
from neo4j_graphrag.llm.base import LLMInterface
from neo4j_graphrag.utils import run_sync


class SimpleKGPipelineConfig(BaseModel):
Expand Down Expand Up @@ -225,6 +226,10 @@ async def run_async(
pipe_inputs = self._prepare_inputs(file_path=file_path, text=text)
return await self.pipeline.run(pipe_inputs)

def run(self, file_path: Optional[str] = None, text: Optional[str] = None) -> PipelineResult:
    """Run the pipeline synchronously.

    Blocking counterpart of ``run_async``: the same keyword arguments are
    forwarded to it through ``run_sync`` and its result is returned.

    Args:
        file_path: Optional file path, forwarded unchanged to ``run_async``.
        text: Optional text input, forwarded unchanged to ``run_async``.

    Returns:
        The ``PipelineResult`` produced by ``run_async``.
    """
    # run_sync calls its first argument to obtain the coroutine to run, so it
    # must receive the coroutine function; passing ``self`` would attempt to
    # call the pipeline object itself.
    return run_sync(self.run_async, file_path=file_path, text=text)

def _prepare_inputs(
self, file_path: Optional[str], text: Optional[str]
) -> dict[str, Any]:
Expand Down
6 changes: 6 additions & 0 deletions src/neo4j_graphrag/experimental/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from timeit import default_timer
from typing import Any, AsyncGenerator, Optional

from neo4j_graphrag.utils import async_to_sync

try:
import pygraphviz as pgv
except ImportError:
Expand Down Expand Up @@ -105,6 +107,8 @@ async def run(self, inputs: dict[str, Any]) -> RunResult | None:
logger.debug(f"TASK RESULT {self.name=} {res=}")
return res

run_sync = async_to_sync(run)


class Orchestrator:
"""Orchestrate a pipeline.
Expand Down Expand Up @@ -618,3 +622,5 @@ async def run(self, data: dict[str, Any]) -> PipelineResult:
run_id=orchestrator.run_id,
result=await self.final_results.get(orchestrator.run_id),
)

run_sync = async_to_sync(run)
17 changes: 17 additions & 0 deletions src/neo4j_graphrag/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
from __future__ import annotations

import inspect
from functools import wraps
from typing import Any, Optional, Union
import asyncio
import concurrent.futures

import neo4j

Expand All @@ -37,3 +40,17 @@ async def execute_query(
# but we're sure at this stage we do not have a coroutine anymore
records, _, _ = driver.execute_query(query, **kwargs) # type: ignore[misc]
return records # type: ignore[no-any-return]


def run_sync(function: Any, *args: Any, **kwargs: Any) -> Any:
    """Run an async callable to completion and return its result, blocking.

    The coroutine is executed with ``asyncio.run`` inside a dedicated worker
    thread, so this also works when the calling thread already has a running
    event loop (where calling ``asyncio.run`` directly would raise).

    Args:
        function: Callable returning a coroutine (typically an ``async def``).
        *args: Positional arguments passed to ``function``.
        **kwargs: Keyword arguments passed to ``function``.

    Returns:
        Whatever the coroutine returns.

    Raises:
        Any exception raised by ``function`` or its coroutine is re-raised in
        the calling thread by ``Future.result()``.
    """
    # Local import keeps this block self-contained with the file's imports.
    import contextvars

    # A new thread starts with an empty context; copy the caller's so that
    # context variables set by the caller stay visible inside the coroutine.
    caller_context = contextvars.copy_context()

    def _run_in_new_loop() -> Any:
        # The coroutine is created in the worker thread, right before it runs.
        return asyncio.run(function(*args, **kwargs))

    # Only one job is ever submitted, so a single worker is enough.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        future = executor.submit(caller_context.run, _run_in_new_loop)
        return future.result()


def async_to_sync(func):
    """Build a blocking counterpart of the async callable ``func``.

    The returned callable accepts the same arguments as ``func``, hands them
    to ``run_sync`` together with ``func`` and returns what ``run_sync``
    returns. ``functools.wraps`` copies ``func``'s metadata onto it.
    """

    def blocking_call(*call_args, **call_kwargs):
        return run_sync(func, *call_args, **call_kwargs)

    return wraps(func)(blocking_call)
Loading