diff --git a/examples/build_graph/simple_kg_builder_from_text.py b/examples/build_graph/simple_kg_builder_from_text.py index 67ed6776..0295331f 100644 --- a/examples/build_graph/simple_kg_builder_from_text.py +++ b/examples/build_graph/simple_kg_builder_from_text.py @@ -8,6 +8,7 @@ """ import asyncio +import logging import neo4j from neo4j_graphrag.embeddings import OpenAIEmbeddings @@ -20,6 +21,10 @@ from neo4j_graphrag.llm import LLMInterface from neo4j_graphrag.llm.openai_llm import OpenAILLM +logging.basicConfig() +logging.getLogger("neo4j_graphrag").setLevel(logging.DEBUG) + + # Neo4j db infos URI = "neo4j://localhost:7687" AUTH = ("neo4j", "password") diff --git a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py index d4070aea..3dbeccb2 100644 --- a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py +++ b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py @@ -39,6 +39,7 @@ from neo4j_graphrag.experimental.pipeline.exceptions import InvalidJSONError from neo4j_graphrag.generation.prompts import ERExtractionTemplate, PromptTemplate from neo4j_graphrag.llm import LLMInterface +from neo4j_graphrag.utils import prettyfier logger = logging.getLogger(__name__) @@ -220,8 +221,9 @@ async def extract_for_chunk( ) else: logger.error( - f"LLM response is not valid JSON {llm_result.content} for chunk_index={chunk.index}" + f"LLM response is not valid JSON for chunk_index={chunk.index}" ) + logger.debug(f"Invalid JSON: {llm_result.content}") result = {"nodes": [], "relationships": []} try: chunk_graph = Neo4jGraph(**result) @@ -232,8 +234,9 @@ async def extract_for_chunk( ) else: logger.error( - f"LLM response has improper format {result} for chunk_index={chunk.index}" + f"LLM response has improper format for chunk_index={chunk.index}" ) + logger.debug(f"Invalid JSON format: {result}") chunk_graph = Neo4jGraph() return chunk_graph @@ -340,5 +343,5 @@ async def run( ] chunk_graphs: list[Neo4jGraph] = list(await asyncio.gather(*tasks)) graph = self.combine_chunk_graphs(lexical_graph, chunk_graphs) - logger.debug(f"{self.__class__.__name__}: {graph}") + logger.debug(f"Extracted graph: {prettyfier(graph)}") return graph diff --git a/src/neo4j_graphrag/experimental/components/types.py b/src/neo4j_graphrag/experimental/components/types.py index 093bb26d..cd4ebffa 100644 --- a/src/neo4j_graphrag/experimental/components/types.py +++ b/src/neo4j_graphrag/experimental/components/types.py @@ -14,13 +14,17 @@ # limitations under the License. from __future__ import annotations -from typing import Any, Optional +from typing import Any, Optional, TYPE_CHECKING -from pydantic import BaseModel, field_validator +from pydantic import BaseModel, field_validator, RootModel from neo4j_graphrag.experimental.pipeline.component import DataModel +if TYPE_CHECKING: + from pydantic._internal import _repr + + class TextChunk(BaseModel): """A chunk of text split from a document by a text splitter. @@ -45,6 +49,20 @@ class TextChunks(DataModel): chunks: list[TextChunk] +# class Embeddings(RootModel): +# """A wrapper around list[float] to represent embeddings. +# Used to improve logging of vectors by not showing the full vector. +# """ +# root: list[float] +# +# # def __rep_str__(self, sep: str = ", ") -> str: +# # return f"" +# +# def __repr_args__(self) -> _repr.ReprArgs: +# yield 'dimension', len(self.root) +# yield 'vector', self.root[:3] +# + class Neo4jNode(BaseModel): """Represents a Neo4j node. @@ -99,6 +117,9 @@ class Neo4jGraph(DataModel): nodes: list[Neo4jNode] = [] relationships: list[Neo4jRelationship] = [] + # def __str__(self) -> str: + # return f"" + class ResolutionStats(DataModel): number_of_nodes_to_resolve: int diff --git a/src/neo4j_graphrag/experimental/pipeline/config/runner.py b/src/neo4j_graphrag/experimental/pipeline/config/runner.py index a1a22585..c31f54a9 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/runner.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/runner.py @@ -70,6 +70,7 @@ class PipelineConfigWrapper(BaseModel): ] = Field(discriminator=Discriminator(_get_discriminator_value)) def parse(self, resolved_data: dict[str, Any] | None = None) -> PipelineDefinition: + logger.debug("PIPELINE_CONFIG: start parsing config...") return self.config.parse(resolved_data) def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]: @@ -101,10 +102,12 @@ def from_config( cls, config: AbstractPipelineConfig | dict[str, Any], do_cleaning: bool = False ) -> Self: wrapper = PipelineConfigWrapper.model_validate({"config": config}) + logger.debug(f"PIPELINE_RUNNER: instantiate Pipeline from config type: {wrapper.config.template_}") return cls(wrapper.parse(), config=wrapper.config, do_cleaning=do_cleaning) @classmethod def from_config_file(cls, file_path: Union[str, Path]) -> Self: + logger.info(f"PIPELINE_RUNNER: reading config file from {file_path}") if not isinstance(file_path, str): file_path = str(file_path) data = ConfigReader().read(file_path) diff --git a/src/neo4j_graphrag/experimental/pipeline/pipeline.py b/src/neo4j_graphrag/experimental/pipeline/pipeline.py index e3ded494..a5897667 100644 --- a/src/neo4j_graphrag/experimental/pipeline/pipeline.py +++ b/src/neo4j_graphrag/experimental/pipeline/pipeline.py @@ -24,6 +24,8 @@ from timeit import default_timer from typing import Any, AsyncGenerator, Optional +from neo4j_graphrag.utils import prettyfier + try: import pygraphviz as pgv except ImportError: @@ -90,21 +92,19 @@ async def execute(self, **kwargs: Any) -> RunResult | None: if the task run successfully, None if the status update was unsuccessful. """ - logger.debug(f"Running component {self.name} with {kwargs}") - start_time = default_timer() component_result = await self.component.run(**kwargs) run_result = RunResult( result=component_result, ) - end_time = default_timer() - logger.debug(f"Component {self.name} finished in {end_time - start_time}s") return run_result async def run(self, inputs: dict[str, Any]) -> RunResult | None: """Main method to execute the task.""" - logger.debug(f"TASK START {self.name=} {inputs=}") + logger.debug(f"TASK START {self.name=} input={prettyfier(inputs)}") + start_time = default_timer() res = await self.execute(**inputs) - logger.debug(f"TASK RESULT {self.name=} {res=}") + end_time = default_timer() + logger.debug(f"TASK FINISHED {self.name} in {end_time - start_time} res={prettyfier(res)}") return res @@ -141,7 +141,7 @@ async def run_task(self, task: TaskPipelineNode, data: dict[str, Any]) -> None: try: await self.set_task_status(task.name, RunStatus.RUNNING) except PipelineStatusUpdateError: - logger.info(f"Component {task.name} already running or done") + logger.debug(f"ORCHESTRATOR: TASK ABORTED: {task.name} is already running or done, aborting") return None res = await task.run(inputs) await self.set_task_status(task.name, RunStatus.DONE) @@ -198,7 +198,8 @@ async def check_dependencies_complete(self, task: TaskPipelineNode) -> None: d_status = await self.get_status_for_component(d.start) if d_status != RunStatus.DONE: logger.debug( - f"Missing dependency {d.start} for {task.name} (status: {d_status}). " + f"ORCHESTRATOR {self.run_id}: TASK DELAYED: Missing dependency {d.start} for {task.name} " + f"(status: {d_status}). " "Will try again when dependency is complete." ) raise PipelineMissingDependencyError() @@ -227,6 +228,7 @@ async def next( await self.check_dependencies_complete(next_node) except PipelineMissingDependencyError: continue + logger.debug(f"ORCHESTRATOR {self.run_id}: enqueuing next task: {next_node.name}") yield next_node return @@ -315,7 +317,6 @@ async def run(self, data: dict[str, Any]) -> None: (node without any parent). Then the callback on_task_complete will handle the task dependencies. """ - logger.debug(f"PIPELINE START {data=}") tasks = [self.run_task(root, data) for root in self.pipeline.roots()] await asyncio.gather(*tasks) @@ -624,15 +625,16 @@ def validate_parameter_mapping_for_task(self, task: TaskPipelineNode) -> bool: return True async def run(self, data: dict[str, Any]) -> PipelineResult: - logger.debug("Starting pipeline") + logger.debug("PIPELINE START") start_time = default_timer() self.invalidate() self.validate_input_data(data) orchestrator = Orchestrator(self) + logger.debug(f"PIPELINE ORCHESTRATOR: {orchestrator.run_id}") await orchestrator.run(data) end_time = default_timer() logger.debug( - f"Pipeline {orchestrator.run_id} finished in {end_time - start_time}s" + f"PIPELINE FINISHED {orchestrator.run_id} in {end_time - start_time}s" ) return PipelineResult( run_id=orchestrator.run_id, diff --git a/src/neo4j_graphrag/utils.py b/src/neo4j_graphrag/utils.py index e86f7588..b4331c88 100644 --- a/src/neo4j_graphrag/utils.py +++ b/src/neo4j_graphrag/utils.py @@ -14,7 +14,9 @@ # limitations under the License. from __future__ import annotations -from typing import Optional +from typing import Optional, Any + +from pydantic import BaseModel def validate_search_query_input( @@ -22,3 +24,41 @@ def validate_search_query_input( ) -> None: if not (bool(query_vector) ^ bool(query_text)): raise ValueError("You must provide exactly one of query_vector or query_text.") + + + +class Prettyfier: + """Prettyfy object for logging. + + I.e.: truncate long lists. + """ + def __init__(self, max_items_in_list: int = 5): + self.max_items_in_list = max_items_in_list + + def _prettyfy_dict(self, value: dict[Any, Any]) -> dict[Any, Any]: + return { + k: self(v) # prettyfy each value + for k, v in value.items() + } + + def _prettyfy_list(self, value: list[Any]) -> list[Any]: + items = [ + self(v) # prettify each item + for v in value[:self.max_items_in_list] + ] + remaining_items = len(value) - len(items) + if remaining_items > 0: + items.append(f"...truncated {remaining_items} items...") + return items + + def __call__(self, value: Any) -> Any: + if isinstance(value, dict): + return self._prettyfy_dict(value) + if isinstance(value, BaseModel): + return self(value.model_dump()) + if isinstance(value, list): + return self._prettyfy_list(value) + return value + + +prettyfier = Prettyfier() diff --git a/tests/e2e/test_kg_writer_component_e2e.py b/tests/e2e/test_kg_writer_component_e2e.py index 2fc0ab90..8774e30c 100644 --- a/tests/e2e/test_kg_writer_component_e2e.py +++ b/tests/e2e/test_kg_writer_component_e2e.py @@ -76,7 +76,7 @@ async def test_kg_writer(driver: neo4j.Driver) -> None: if start_node.embedding_properties: # for mypy for key, val in start_node.embedding_properties.items(): assert key in node_a.keys() - assert node_a.get(key) == [1.0, 2.0, 3.0] + assert val.root == node_a.get(key) node_b = record["b"] assert end_node.label in list(node_b.labels) @@ -100,7 +100,7 @@ async def test_kg_writer(driver: neo4j.Driver) -> None: if node_with_two_embeddings.embedding_properties: # for mypy for key, val in node_with_two_embeddings.embedding_properties.items(): assert key in node_c.keys() - assert val == node_c.get(key) + assert val.root == node_c.get(key) @pytest.mark.asyncio diff --git a/tests/unit/experimental/components/test_embedder.py b/tests/unit/experimental/components/test_embedder.py index 5c72e0d2..f7da60c8 100644 --- a/tests/unit/experimental/components/test_embedder.py +++ b/tests/unit/experimental/components/test_embedder.py @@ -16,7 +16,11 @@ import pytest from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder -from neo4j_graphrag.experimental.components.types import TextChunk, TextChunks +from neo4j_graphrag.experimental.components.types import ( + Embeddings, + TextChunk, + TextChunks, +) @pytest.mark.asyncio @@ -33,6 +37,4 @@ async def test_text_chunk_embedder_run(embedder: MagicMock) -> None: assert isinstance(chunk, TextChunk) assert chunk.metadata is not None assert "embedding" in chunk.metadata.keys() - assert isinstance(chunk.metadata["embedding"], list) - for i in chunk.metadata["embedding"]: - assert isinstance(i, float) + assert isinstance(chunk.metadata["embedding"], Embeddings) diff --git a/tests/unit/experimental/components/test_lexical_graph_builder.py b/tests/unit/experimental/components/test_lexical_graph_builder.py index 4cedae61..52bb272a 100644 --- a/tests/unit/experimental/components/test_lexical_graph_builder.py +++ b/tests/unit/experimental/components/test_lexical_graph_builder.py @@ -26,7 +26,7 @@ LexicalGraphConfig, Neo4jNode, TextChunk, - TextChunks, + TextChunks, Embeddings, ) @@ -68,7 +68,7 @@ def test_lexical_graph_builder_create_chunk_node_metadata_embedding() -> None: assert isinstance(node, Neo4jNode) assert node.id == "test_create_chunk_node_metadata_embedding:0" assert node.properties == {"index": 0, "text": "text chunk", "status": "ok"} - assert node.embedding_properties == {"embedding": [1, 2, 3]} + assert node.embedding_properties == {"embedding": Embeddings([1, 2, 3])} @pytest.mark.asyncio