From 378c98d2de62f266e4f42f4239e6992affeb89a4 Mon Sep 17 00:00:00 2001
From: Estelle Scifo
Date: Tue, 3 Sep 2024 14:02:25 +0200
Subject: [PATCH] Save Document node in lexical graph (#116)

* WIP

* Add document info model - move chunk index to TextChunk model

* Update tests to test DocumentInfo

* Update examples

* Update e2e tests, documentation, CHANGELOG

* Remove print

* Add docstrings, remove another print

* Fix tests after merge

---
 CHANGELOG.md                                  |   3 +
 docs/source/user_guide_kg_builder.rst         |   2 +
 examples/pipeline/kg_builder_from_pdf.py      |  22 +-
 examples/pipeline/kg_builder_from_text.py     |   5 +
 .../experimental/components/embedder.py       |   4 +-
 .../components/entity_relation_extractor.py   | 220 ++++++++++++------
 .../experimental/components/pdf_loader.py     |  34 ++-
 .../components/text_splitters/langchain.py    |   3 +-
 .../components/text_splitters/llamaindex.py   |   3 +-
 .../experimental/components/types.py          |   1 +
 tests/e2e/test_kg_builder_pipeline_e2e.py     |  20 +-
 .../experimental/components/test_embedder.py  |   4 +-
 .../test_entity_relation_extractor.py         |  95 +++++---
 .../components/test_pdf_loader.py             |   5 +-
 14 files changed, 292 insertions(+), 129 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index e961b977..0be4ecce 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,9 @@
 
 ## Next
 
+### Changed
+- When saving the lexical graph in a KG creation pipeline, the document is also saved as a dedicated node, together with a relationship between each chunk and the document it was created from.
+
 ## 0.5.0
 
 ### Fixed
diff --git a/docs/source/user_guide_kg_builder.rst b/docs/source/user_guide_kg_builder.rst
index e3e69cf2..182f0323 100644
--- a/docs/source/user_guide_kg_builder.rst
+++ b/docs/source/user_guide_kg_builder.rst
@@ -289,9 +289,11 @@ Lexical Graph
 
 By default, the `LLMEntityRelationExtractor` adds some extra nodes and relationships to the extracted graph:
 
+- `Document` node: represents the processed document and has a `path` property.
 - `Chunk` nodes: represent the text chunks. They have a `text` property and, if computed, an `embedding` property.
 - `NEXT_CHUNK` relationships between one chunk node and the next one in the document. They can be used to enhance the context in a RAG application.
 - `FROM_CHUNK` relationship between any extracted entity and the chunk it was identified in.
+- `FROM_DOCUMENT` relationship between each chunk and the document it was built from.
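Once a pipeline run has written this lexical graph to Neo4j, it can be sanity-checked with a few lines of driver code. The sketch below only assumes a database already populated by the pipeline; the connection URI and credentials are placeholders:

```python
# Sketch: inspect the lexical graph produced by a KG builder pipeline run.
# URI and credentials are placeholders; adjust for your deployment.
import neo4j

driver = neo4j.GraphDatabase.driver(
    "neo4j://localhost:7687", auth=("neo4j", "password")
)
# Each Chunk points to its Document via FROM_DOCUMENT,
# and to the following chunk via NEXT_CHUNK.
records, _, _ = driver.execute_query(
    "MATCH (c:Chunk)-[:FROM_DOCUMENT]->(d:Document) "
    "OPTIONAL MATCH (c)-[:NEXT_CHUNK]->(next:Chunk) "
    "RETURN d.path AS document, c.index AS chunk, next.index AS next_chunk "
    "ORDER BY chunk"
)
for record in records:
    print(record["document"], record["chunk"], record["next_chunk"])
driver.close()
```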
 If this 'lexical graph' is not desired, set the `create_lexical_graph` parameter to `False` in the extractor constructor:
diff --git a/examples/pipeline/kg_builder_from_pdf.py b/examples/pipeline/kg_builder_from_pdf.py
index a5e29016..a8670b66 100644
--- a/examples/pipeline/kg_builder_from_pdf.py
+++ b/examples/pipeline/kg_builder_from_pdf.py
@@ -16,7 +16,7 @@
 import asyncio
 import logging
-from typing import Any
+from typing import Any, Dict, List
 
 import neo4j
 from langchain_text_splitters import CharacterTextSplitter
@@ -62,13 +62,13 @@ class Neo4jGraph(DataModel):
 
 class ERExtractor(Component):
-    async def _process_chunk(self, chunk: str, schema: str) -> dict[str, Any]:
+    async def _process_chunk(self, chunk: str, schema: str) -> Dict[str, Any]:
         return {
             "entities": [{"label": "Person", "properties": {"name": "John Doe"}}],
             "relations": [],
         }
 
-    async def run(self, chunks: list[str], schema: str) -> Neo4jGraph:
+    async def run(self, chunks: List[str], schema: str) -> Neo4jGraph:
         tasks = [self._process_chunk(chunk, schema) for chunk in chunks]
         result = await asyncio.gather(*tasks)
         merged_result: dict[str, Any] = {"entities": [], "relations": []}
@@ -141,10 +141,7 @@ async def main(neo4j_driver: neo4j.Driver) -> dict[str, Any]:
     pipe = Pipeline()
     pipe.add_component(PdfLoader(), "pdf_loader")
     pipe.add_component(
-        LangChainTextSplitterAdapter(
-            # chunk_size=50 for the sake of this demo
-            CharacterTextSplitter(chunk_size=50, chunk_overlap=10, separator=".")
-        ),
+        LangChainTextSplitterAdapter(CharacterTextSplitter(separator=". \n")),
         "splitter",
     )
     pipe.add_component(SchemaBuilder(), "schema")
@@ -153,7 +150,7 @@ async def main(neo4j_driver: neo4j.Driver) -> dict[str, Any]:
             llm=OpenAILLM(
                 model_name="gpt-4o",
                 model_params={
-                    "max_tokens": 1000,
+                    "max_tokens": 2000,
                     "response_format": {"type": "json_object"},
                 },
             ),
@@ -164,7 +161,14 @@ async def main(neo4j_driver: neo4j.Driver) -> dict[str, Any]:
     pipe.add_component(Neo4jWriter(neo4j_driver), "writer")
     pipe.connect("pdf_loader", "splitter", input_config={"text": "pdf_loader.text"})
     pipe.connect("splitter", "extractor", input_config={"chunks": "splitter"})
-    pipe.connect("schema", "extractor", input_config={"schema": "schema"})
+    pipe.connect(
+        "schema",
+        "extractor",
+        input_config={
+            "schema": "schema",
+            "document_info": "pdf_loader.document_info",
+        },
+    )
     pipe.connect(
         "extractor",
         "writer",
diff --git a/examples/pipeline/kg_builder_from_text.py b/examples/pipeline/kg_builder_from_text.py
index 5e717578..3e28e2d7 100644
--- a/examples/pipeline/kg_builder_from_text.py
+++ b/examples/pipeline/kg_builder_from_text.py
@@ -154,6 +154,11 @@ async def main(neo4j_driver: neo4j.Driver) -> dict[str, Any]:
                 ("Person", "WORKED_FOR", "Organization"),
             ],
         },
+        "extractor": {
+            "document_info": {
+                "path": "my text",
+            }
+        },
     }
     # run the pipeline
     return await pipe.run(pipe_inputs)
diff --git a/src/neo4j_genai/experimental/components/embedder.py b/src/neo4j_genai/experimental/components/embedder.py
index 0683bef1..3878f912 100644
--- a/src/neo4j_genai/experimental/components/embedder.py
+++ b/src/neo4j_genai/experimental/components/embedder.py
@@ -56,7 +56,9 @@ def _embed_chunk(self, text_chunk: TextChunk) -> TextChunk:
         embedding = self._embedder.embed_query(text_chunk.text)
         metadata = text_chunk.metadata if text_chunk.metadata else {}
         metadata["embedding"] = embedding
-        return TextChunk(text=text_chunk.text, metadata=metadata)
+        return TextChunk(
+            text=text_chunk.text, index=text_chunk.index, metadata=metadata
+        )
 
     @validate_call
     async def run(self, text_chunks: 
TextChunks) -> TextChunks:
diff --git a/src/neo4j_genai/experimental/components/entity_relation_extractor.py b/src/neo4j_genai/experimental/components/entity_relation_extractor.py
index 93beba84..242fd4df 100644
--- a/src/neo4j_genai/experimental/components/entity_relation_extractor.py
+++ b/src/neo4j_genai/experimental/components/entity_relation_extractor.py
@@ -20,12 +20,14 @@
 import json
 import logging
 import re
+import warnings
 from datetime import datetime
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, Optional, Union
 
 from pydantic import ValidationError, validate_call
 
 from neo4j_genai.exceptions import LLMGenerationError
+from neo4j_genai.experimental.components.pdf_loader import DocumentInfo
 from neo4j_genai.experimental.components.schema import SchemaConfig
 from neo4j_genai.experimental.components.types import (
     Neo4jGraph,
@@ -47,8 +49,10 @@ class OnError(enum.Enum):
 
 CHUNK_NODE_LABEL = "Chunk"
+DOCUMENT_NODE_LABEL = "Document"
 NEXT_CHUNK_RELATIONSHIP_TYPE = "NEXT_CHUNK"
 NODE_TO_CHUNK_RELATIONSHIP_TYPE = "FROM_CHUNK"
+CHUNK_TO_DOCUMENT_RELATIONSHIP_TYPE = "FROM_DOCUMENT"
 
 
 def balance_curly_braces(json_string: str) -> str:
@@ -124,45 +128,12 @@ def fix_invalid_json(invalid_json_string: str) -> str:
     return balance_curly_braces(invalid_json_string)
 
 
-class EntityRelationExtractor(Component, abc.ABC):
-    """Abstract class for entity relation extraction components.
-
-    Args:
-        on_error (OnError): What to do when an error occurs during extraction. Defaults to raising an error.
-        create_lexical_graph (bool): Whether to include the text chunks in the graph in addition to the extracted entities and relations. Defaults to True.
-    """
-
-    def __init__(
-        self,
-        *args: Any,
-        on_error: OnError = OnError.IGNORE,
-        create_lexical_graph: bool = True,
-        **kwargs: Any,
-    ) -> None:
-        self.create_lexical_graph = create_lexical_graph
-        self.on_error = on_error
-        self._id_prefix = ""
-
-    @abc.abstractmethod
-    async def run(self, chunks: TextChunks, **kwargs: Any) -> Neo4jGraph:
-        pass
-
-    def update_ids(self, graph: Neo4jGraph, chunk_index: int) -> Neo4jGraph:
-        """Make node IDs unique across chunks and pipeline runs by
-        prefixing them with a custom prefix (set in the run method)
-        and chunk index."""
-        for node in graph.nodes:
-            node.id = f"{self._id_prefix}:{chunk_index}:{node.id}"
-            if node.properties is None:
-                node.properties = {}
-            node.properties.update({"chunk_index": chunk_index})
-        for rel in graph.relationships:
-            rel.start_node_id = f"{self._id_prefix}:{chunk_index}:{rel.start_node_id}"
-            rel.end_node_id = f"{self._id_prefix}:{chunk_index}:{rel.end_node_id}"
-        return graph
+class LexicalGraphBuilder:
+    """Helper class grouping the methods used to build the lexical graph."""
 
+    @staticmethod
     def create_next_chunk_relationship(
-        self, previous_chunk_id: str, chunk_id: str
+        previous_chunk_id: str, chunk_id: str
     ) -> Neo4jRelationship:
         """Create relationship between a chunk and the next one"""
         return Neo4jRelationship(
             type=NEXT_CHUNK_RELATIONSHIP_TYPE,
             start_node_id=previous_chunk_id,
             end_node_id=chunk_id,
         )
 
-    def create_chunk_node(self, chunk: TextChunk, chunk_id: str) -> Neo4jNode:
-        """Create chunk node with properties 'text' and any 'metadata' added during
+    @staticmethod
+    def create_chunk_node(chunk: TextChunk, chunk_id: str) -> Neo4jNode:
+        """Create chunk node with properties 'text', 'index' and any 'metadata' added during
         the process.
         The chunk embedding, if present in the metadata, is stored separately as an embedding property."""
         chunk_properties: Dict[str, Any] = {
             "text": chunk.text,
+            "index": chunk.index,
         }
         embedding_properties = {}
         if chunk.metadata:
@@ -200,27 +173,105 @@ def create_node_to_chunk_rel(
             type=NODE_TO_CHUNK_RELATIONSHIP_TYPE,
         )
 
-    def build_lexical_graph(
-        self, chunk_graph: Neo4jGraph, chunk_index: int, chunk: TextChunk
-    ) -> Neo4jGraph:
+    @staticmethod
+    def create_document_node(document_info: DocumentInfo) -> Neo4jNode:
+        """Create a Document node with a 'path' property. Any document metadata is also
+        added as a node property.
+        """
+        document_metadata = document_info.metadata or {}
+        return Neo4jNode(
+            id=document_info.path,
+            label=DOCUMENT_NODE_LABEL,
+            properties={
+                "path": document_info.path,
+                **document_metadata,
+            },
+        )
+
+    @staticmethod
+    def create_chunk_to_document_rel(
+        chunk_id: str, document_id: str
+    ) -> Neo4jRelationship:
+        """Create the relationship between a chunk and the document it belongs to."""
+        return Neo4jRelationship(
+            start_node_id=chunk_id,
+            end_node_id=document_id,
+            type=CHUNK_TO_DOCUMENT_RELATIONSHIP_TYPE,
+        )
+
+    async def process_chunk(
+        self,
+        chunk_graph: Neo4jGraph,
+        chunk: TextChunk,
+        id_prefix: str,
+        document_id: Optional[str] = None,
+    ) -> None:
         """Add chunks and relationships between them (NEXT_CHUNK)
         and between chunks and extracted entities
         from that chunk.
+        Updates `chunk_graph` in place.
         """
-        chunk_id = f"{self._id_prefix}:{chunk_index}"
+        chunk_id = f"{id_prefix}:{chunk.index}"
+        if document_id:
+            chunk_to_doc_rel = self.create_chunk_to_document_rel(chunk_id, document_id)
+            chunk_graph.relationships.append(chunk_to_doc_rel)
         chunk_node = self.create_chunk_node(chunk, chunk_id)
         chunk_graph.nodes.append(chunk_node)
-        if chunk_index > 0:
-            previous_chunk_id = f"{self._id_prefix}:{chunk_index - 1}"
+        if chunk.index > 0:
+            previous_chunk_id = f"{id_prefix}:{chunk.index - 1}"
             next_chunk_rel = self.create_next_chunk_relationship(
                 previous_chunk_id, chunk_id
             )
             chunk_graph.relationships.append(next_chunk_rel)
         for node in chunk_graph.nodes:
-            if node.label == CHUNK_NODE_LABEL:
+            if node.label in (CHUNK_NODE_LABEL, DOCUMENT_NODE_LABEL):
                 continue
             node_to_chunk_rel = self.create_node_to_chunk_rel(node, chunk_id)
             chunk_graph.relationships.append(node_to_chunk_rel)
-        return chunk_graph
+
+
+class EntityRelationExtractor(Component, abc.ABC):
+    """Abstract class for entity relation extraction components.
+
+    Args:
+        on_error (OnError): What to do when an error occurs during extraction. Defaults to ignoring the error and continuing with the remaining chunks (OnError.IGNORE).
+        create_lexical_graph (bool): Whether to include the text chunks in the graph in addition to the extracted entities and relations. Defaults to True.
+ """ + + def __init__( + self, + *args: Any, + on_error: OnError = OnError.IGNORE, + create_lexical_graph: bool = True, + **kwargs: Any, + ) -> None: + self.on_error = on_error + self.create_lexical_graph = create_lexical_graph + + @abc.abstractmethod + async def run( + self, + chunks: TextChunks, + document_info: Optional[DocumentInfo] = None, + **kwargs: Any, + ) -> Neo4jGraph: + pass + + def update_ids( + self, graph: Neo4jGraph, chunk_index: int, run_id: str + ) -> Neo4jGraph: + """Make node IDs unique across chunks and pipeline runs by + prefixing them with a custom prefix (set in the run method) + and chunk index.""" + prefix = f"{run_id}:{chunk_index}" + for node in graph.nodes: + node.id = f"{prefix}:{node.id}" + if node.properties is None: + node.properties = {} + node.properties.update({"chunk_index": chunk_index}) + for rel in graph.relationships: + rel.start_node_id = f"{prefix}:{rel.start_node_id}" + rel.end_node_id = f"{prefix}:{rel.end_node_id}" + return graph class LLMEntityRelationExtractor(EntityRelationExtractor): @@ -268,7 +319,7 @@ def __init__( self.prompt_template = template async def extract_for_chunk( - self, schema: SchemaConfig, examples: str, chunk_index: int, chunk: TextChunk + self, schema: SchemaConfig, examples: str, chunk: TextChunk ) -> Neo4jGraph: """Run entity extraction for a given text chunk.""" prompt = self.prompt_template.format( @@ -278,6 +329,9 @@ async def extract_for_chunk( try: result = json.loads(llm_result.content) except json.JSONDecodeError: + logger.warning( + f"LLM response is not valid JSON {llm_result.content} for chunk_index={chunk.index}. Trying to fix it." + ) fixed_content = fix_invalid_json(llm_result.content) try: result = json.loads(fixed_content) @@ -288,7 +342,7 @@ async def extract_for_chunk( ) else: logger.error( - f"LLM response is not valid JSON {llm_result.content} for chunk_index={chunk_index}" + f"LLM response is not valid JSON {llm_result.content} for chunk_index={chunk.index}" ) result = {"nodes": [], "relationships": []} try: @@ -300,25 +354,34 @@ async def extract_for_chunk( ) else: logger.error( - f"LLM response has improper format {result} for chunk_index={chunk_index}" + f"LLM response has improper format {result} for chunk_index={chunk.index}" ) chunk_graph = Neo4jGraph() return chunk_graph async def post_process_chunk( - self, chunk_graph: Neo4jGraph, chunk_index: int, chunk: TextChunk + self, + chunk_graph: Neo4jGraph, + chunk: TextChunk, + run_id: str, + lexical_graph_builder: Optional[LexicalGraphBuilder] = None, + document_id: Optional[str] = None, ) -> None: """Perform post-processing after entity and relation extraction: - Update node IDs to make them unique across chunks - Build the lexical graph if requested """ - self.update_ids(chunk_graph, chunk_index) - if self.create_lexical_graph: - self.build_lexical_graph(chunk_graph, chunk_index, chunk) + self.update_ids(chunk_graph, chunk.index, run_id) + if lexical_graph_builder: + await lexical_graph_builder.process_chunk( + chunk_graph, chunk, run_id, document_id=document_id + ) - def combine_chunk_graphs(self, chunk_graphs: List[Neo4jGraph]) -> Neo4jGraph: + def combine_chunk_graphs( + self, lexical_graph: Neo4jGraph, chunk_graphs: List[Neo4jGraph] + ) -> Neo4jGraph: """Combine sub-graphs obtained for each chunk into a single Neo4jGraph object""" - graph = Neo4jGraph() + graph = lexical_graph.model_copy(deep=True) for chunk_graph in chunk_graphs: graph.nodes.extend(chunk_graph.nodes) graph.relationships.extend(chunk_graph.relationships) @@ -326,38 +389,59 
@@ def combine_chunk_graphs(self, chunk_graphs: List[Neo4jGraph]) -> Neo4jGraph: async def run_for_chunk( self, + sem: asyncio.Semaphore, + run_id: str, + chunk: TextChunk, schema: SchemaConfig, examples: str, - chunk_index: int, - chunk: TextChunk, - sem: asyncio.Semaphore, + lexical_graph_builder: Optional[LexicalGraphBuilder] = None, + document_id: Optional[str] = None, ) -> Neo4jGraph: """Run extraction and post processing for a single chunk""" async with sem: - chunk_graph = await self.extract_for_chunk( - schema, examples, chunk_index, chunk + chunk_graph = await self.extract_for_chunk(schema, examples, chunk) + await self.post_process_chunk( + chunk_graph, chunk, run_id, lexical_graph_builder, document_id ) - await self.post_process_chunk(chunk_graph, chunk_index, chunk) return chunk_graph @validate_call async def run( self, chunks: TextChunks, + document_info: Optional[DocumentInfo] = None, schema: Union[SchemaConfig, None] = None, examples: str = "", **kwargs: Any, ) -> Neo4jGraph: """Perform entity and relation extraction for all chunks in a list.""" + lexical_graph_builder = None + document_id = None + nodes = [] + if self.create_lexical_graph: + lexical_graph_builder = LexicalGraphBuilder() + if document_info is None: + warnings.warn( + "No document metadata provided, the document node won't be created in the lexical graph" + ) + else: + document_node = lexical_graph_builder.create_document_node( + document_info + ) + nodes.append(document_node) + document_id = document_node.id + lexical_graph = Neo4jGraph(nodes=nodes, relationships=[]) schema = schema or SchemaConfig(entities={}, relations={}, potential_schema=[]) examples = examples or "" - self._id_prefix = str(datetime.now().timestamp()) + run_id = str(datetime.now().timestamp()) sem = asyncio.Semaphore(self.max_concurrency) tasks = [ - self.run_for_chunk(schema, examples, chunk_index, chunk, sem) - for chunk_index, chunk in enumerate(chunks.chunks) + self.run_for_chunk( + sem, run_id, chunk, schema, examples, lexical_graph_builder, document_id + ) + for chunk in chunks.chunks ] - chunk_graphs = await asyncio.gather(*tasks) - graph = self.combine_chunk_graphs(chunk_graphs) + chunk_graphs: list[Neo4jGraph] = list(await asyncio.gather(*tasks)) + graph = self.combine_chunk_graphs(lexical_graph, chunk_graphs) logger.debug(f"{self.__class__.__name__}: {graph}") return graph diff --git a/src/neo4j_genai/experimental/components/pdf_loader.py b/src/neo4j_genai/experimental/components/pdf_loader.py index b6b77408..a8fa6e1c 100644 --- a/src/neo4j_genai/experimental/components/pdf_loader.py +++ b/src/neo4j_genai/experimental/components/pdf_loader.py @@ -12,10 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import io from abc import abstractmethod from pathlib import Path -from typing import Optional, Union +from typing import Dict, Optional, Union import fsspec import pypdf @@ -26,8 +28,14 @@ from neo4j_genai.experimental.pipeline import Component, DataModel +class DocumentInfo(DataModel): + path: str + metadata: Optional[Dict[str, str]] = None + + class PdfDocument(DataModel): text: str + document_info: DocumentInfo class DataLoader(Component): @@ -35,8 +43,15 @@ class DataLoader(Component): Interface for loading data of various input types. 
""" + def get_document_metadata( + self, text: str, metadata: Optional[Dict[str, str]] = None + ) -> Dict[str, str] | None: + return metadata + @abstractmethod - async def run(self, filepath: Path) -> PdfDocument: + async def run( + self, filepath: Path, metadata: Optional[Dict[str, str]] = None + ) -> PdfDocument: pass @@ -48,14 +63,12 @@ class PdfLoader(DataLoader): @staticmethod def load_file( file: Union[Path, str], - fs: Optional[AbstractFileSystem] = None, + fs: AbstractFileSystem, ) -> str: """Parse PDF file and return text.""" if not isinstance(file, Path): file = Path(file) - fs = fs or LocalFileSystem() - try: with fs.open(file, "rb") as fp: stream = fp if is_default_fs(fs) else io.BytesIO(fp.read()) @@ -73,6 +86,15 @@ def load_file( async def run( self, filepath: Path, + metadata: Optional[Dict[str, str]] = None, fs: Optional[AbstractFileSystem] = None, ) -> PdfDocument: - return PdfDocument(text=self.load_file(filepath, fs)) + fs = fs or LocalFileSystem() + text = self.load_file(filepath, fs) + return PdfDocument( + text=text, + document_info=DocumentInfo( + path=str(filepath), + metadata=self.get_document_metadata(text, metadata), + ), + ) diff --git a/src/neo4j_genai/experimental/components/text_splitters/langchain.py b/src/neo4j_genai/experimental/components/text_splitters/langchain.py index 077748f0..6788b392 100644 --- a/src/neo4j_genai/experimental/components/text_splitters/langchain.py +++ b/src/neo4j_genai/experimental/components/text_splitters/langchain.py @@ -54,8 +54,9 @@ async def run(self, text: str) -> TextChunks: Returns: TextChunks: The text split into chunks. """ + chunks = self.text_splitter.split_text(text) return TextChunks( chunks=[ - TextChunk(text=chunk) for chunk in self.text_splitter.split_text(text) + TextChunk(text=chunk, index=index) for index, chunk in enumerate(chunks) ] ) diff --git a/src/neo4j_genai/experimental/components/text_splitters/llamaindex.py b/src/neo4j_genai/experimental/components/text_splitters/llamaindex.py index 24289593..08cf45b8 100644 --- a/src/neo4j_genai/experimental/components/text_splitters/llamaindex.py +++ b/src/neo4j_genai/experimental/components/text_splitters/llamaindex.py @@ -54,8 +54,9 @@ async def run(self, text: str) -> TextChunks: Returns: TextChunks: The text split into chunks. 
""" + chunks = self.text_splitter.split_text(text) return TextChunks( chunks=[ - TextChunk(text=chunk) for chunk in self.text_splitter.split_text(text) + TextChunk(text=chunk, index=index) for index, chunk in enumerate(chunks) ] ) diff --git a/src/neo4j_genai/experimental/components/types.py b/src/neo4j_genai/experimental/components/types.py index a22443a1..cf6aadd7 100644 --- a/src/neo4j_genai/experimental/components/types.py +++ b/src/neo4j_genai/experimental/components/types.py @@ -30,6 +30,7 @@ class TextChunk(BaseModel): """ text: str + index: int metadata: Optional[dict[str, Any]] = None diff --git a/tests/e2e/test_kg_builder_pipeline_e2e.py b/tests/e2e/test_kg_builder_pipeline_e2e.py index 0ccc6fcb..9978a44b 100644 --- a/tests/e2e/test_kg_builder_pipeline_e2e.py +++ b/tests/e2e/test_kg_builder_pipeline_e2e.py @@ -250,6 +250,7 @@ async def test_pipeline_builder_happy_path( ("Organization", "LED_BY", "Person"), ], }, + "extractor": {"document_info": {"path": "my document path"}}, } res = await kg_builder_pipeline.run(pipe_inputs) # llm must have been called for each chunk @@ -260,12 +261,13 @@ async def test_pipeline_builder_happy_path( chunks = kg_builder_pipeline.get_results_for_component("splitter") assert len(chunks["chunks"]) == 3 graph = kg_builder_pipeline.get_results_for_component("extractor") - # 3 entities + 3 chunks + # 3 entities + 3 chunks + 1 document nodes = graph["nodes"] - assert len(nodes) == 6 + assert len(nodes) == 7 label_counts = dict(Counter([n["label"] for n in nodes])) assert label_counts == { "Chunk": 3, + "Document": 1, "Person": 2, "Organization": 1, } @@ -273,14 +275,20 @@ async def test_pipeline_builder_happy_path( # + 3 rels between entities and their chunk # + 2 "NEXT_CHUNK" rels relationships = graph["relationships"] - assert len(relationships) == 7 + assert len(relationships) == 10 type_counts = dict(Counter([r["type"] for r in relationships])) - assert type_counts == {"FROM_CHUNK": 3, "KNOWS": 1, "LED_BY": 1, "NEXT_CHUNK": 2} + assert type_counts == { + "FROM_CHUNK": 3, + "FROM_DOCUMENT": 3, + "KNOWS": 1, + "LED_BY": 1, + "NEXT_CHUNK": 2, + } # then check content of neo4j db created_nodes = driver.execute_query("MATCH (n) RETURN n") - assert len(created_nodes.records) == 6 + assert len(created_nodes.records) == 7 created_rels = driver.execute_query("MATCH ()-[r]->() RETURN r") - assert len(created_rels.records) == 7 + assert len(created_rels.records) == 10 created_chunks = driver.execute_query("MATCH (n:Chunk) RETURN n").records assert len(created_chunks) == 3 diff --git a/tests/unit/experimental/components/test_embedder.py b/tests/unit/experimental/components/test_embedder.py index 98064184..cbeb5d5d 100644 --- a/tests/unit/experimental/components/test_embedder.py +++ b/tests/unit/experimental/components/test_embedder.py @@ -23,7 +23,9 @@ async def test_text_chunk_embedder_run(embedder: MagicMock) -> None: embedder.embed_query.return_value = [1.0, 2.0, 3.0] text_chunk_embedder = TextChunkEmbedder(embedder=embedder) - text_chunks = TextChunks(chunks=[TextChunk(text="may thy knife chip and shatter")]) + text_chunks = TextChunks( + chunks=[TextChunk(text="may thy knife chip and shatter", index=0)] + ) embedded_chunks = await text_chunk_embedder.run(text_chunks) embedder.embed_query.assert_called_once_with("may thy knife chip and shatter") assert isinstance(embedded_chunks, TextChunks) diff --git a/tests/unit/experimental/components/test_entity_relation_extractor.py b/tests/unit/experimental/components/test_entity_relation_extractor.py index 
c6d16041..e919b647 100644 --- a/tests/unit/experimental/components/test_entity_relation_extractor.py +++ b/tests/unit/experimental/components/test_entity_relation_extractor.py @@ -20,12 +20,13 @@ import pytest from neo4j_genai.exceptions import LLMGenerationError from neo4j_genai.experimental.components.entity_relation_extractor import ( - EntityRelationExtractor, + LexicalGraphBuilder, LLMEntityRelationExtractor, OnError, balance_curly_braces, fix_invalid_json, ) +from neo4j_genai.experimental.components.pdf_loader import DocumentInfo from neo4j_genai.experimental.components.types import ( Neo4jGraph, Neo4jNode, @@ -36,61 +37,80 @@ def test_create_chunk_node_no_metadata() -> None: - # instantiating an abstract class to test common methods - extractor = EntityRelationExtractor() # type: ignore - node = extractor.create_chunk_node( - chunk=TextChunk(text="text chunk"), chunk_id="10" + builder = LexicalGraphBuilder() + node = builder.create_chunk_node( + chunk=TextChunk(text="text chunk", index=0), chunk_id="10" ) assert isinstance(node, Neo4jNode) assert node.id == "10" - assert node.properties == {"text": "text chunk"} + assert node.properties == {"index": 0, "text": "text chunk"} assert node.embedding_properties == {} def test_create_chunk_node_metadata_no_embedding() -> None: - # instantiating an abstract class to test common methods - extractor = EntityRelationExtractor() # type: ignore - node = extractor.create_chunk_node( - chunk=TextChunk(text="text chunk", metadata={"status": "ok"}), chunk_id="10" + builder = LexicalGraphBuilder() + node = builder.create_chunk_node( + chunk=TextChunk(text="text chunk", index=0, metadata={"status": "ok"}), + chunk_id="10", ) assert isinstance(node, Neo4jNode) assert node.id == "10" - assert node.properties == {"text": "text chunk", "status": "ok"} + assert node.properties == {"index": 0, "text": "text chunk", "status": "ok"} assert node.embedding_properties == {} def test_create_chunk_node_metadata_embedding() -> None: - # instantiating an abstract class to test common methods - extractor = EntityRelationExtractor() # type: ignore - node = extractor.create_chunk_node( + builder = LexicalGraphBuilder() + node = builder.create_chunk_node( chunk=TextChunk( - text="text chunk", metadata={"status": "ok", "embedding": [1, 2, 3]} + text="text chunk", + index=0, + metadata={"status": "ok", "embedding": [1, 2, 3]}, ), chunk_id="10", ) assert isinstance(node, Neo4jNode) assert node.id == "10" - assert node.properties == {"text": "text chunk", "status": "ok"} + assert node.properties == {"index": 0, "text": "text chunk", "status": "ok"} assert node.embedding_properties == {"embedding": [1, 2, 3]} @pytest.mark.asyncio -async def test_extractor_happy_path_no_entities() -> None: +async def test_extractor_happy_path_no_entities_no_document() -> None: llm = MagicMock(spec=LLMInterface) llm.ainvoke.return_value = LLMResponse(content='{"nodes": [], "relationships": []}') extractor = LLMEntityRelationExtractor( llm=llm, ) - chunks = TextChunks(chunks=[TextChunk(text="some text")]) + chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) result = await extractor.run(chunks=chunks) assert isinstance(result, Neo4jGraph) - # only one Chunk node + # only one Chunk node (no document info provided) assert len(result.nodes) == 1 assert result.nodes[0].label == "Chunk" assert result.relationships == [] +@pytest.mark.asyncio +async def test_extractor_happy_path_no_entities() -> None: + llm = MagicMock(spec=LLMInterface) + llm.ainvoke.return_value = 
LLMResponse(content='{"nodes": [], "relationships": []}') + + extractor = LLMEntityRelationExtractor( + llm=llm, + ) + chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) + document_info = DocumentInfo(path="path") + result = await extractor.run(chunks=chunks, document_info=document_info) + assert isinstance(result, Neo4jGraph) + # one Chunk node and one Document node + assert len(result.nodes) == 2 + assert set(n.label for n in result.nodes) == {"Chunk", "Document"} + assert len(result.relationships) == 1 + assert result.relationships[0].type == "FROM_DOCUMENT" + + @pytest.mark.asyncio async def test_extractor_happy_path_no_entities_no_lexical_graph() -> None: llm = MagicMock(spec=LLMInterface) @@ -100,8 +120,9 @@ async def test_extractor_happy_path_no_entities_no_lexical_graph() -> None: llm=llm, create_lexical_graph=False, ) - chunks = TextChunks(chunks=[TextChunk(text="some text")]) - result = await extractor.run(chunks=chunks) + chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) + document_info = DocumentInfo(path="path") + result = await extractor.run(chunks=chunks, document_info=document_info) assert result.nodes == [] assert result.relationships == [] @@ -116,18 +137,24 @@ async def test_extractor_happy_path_non_empty_result() -> None: extractor = LLMEntityRelationExtractor( llm=llm, ) - chunks = TextChunks(chunks=[TextChunk(text="some text")]) - result = await extractor.run(chunks=chunks) + chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) + document_info = DocumentInfo(path="path") + result = await extractor.run(chunks=chunks, document_info=document_info) assert isinstance(result, Neo4jGraph) - assert len(result.nodes) == 2 - entity = result.nodes[0] + assert len(result.nodes) == 3 + doc = result.nodes[0] + assert doc.label == "Document" + entity = result.nodes[1] assert entity.id.endswith("0:0") assert entity.label == "Person" assert entity.properties == {"chunk_index": 0} - chunk_entity = result.nodes[1] + chunk_entity = result.nodes[2] assert chunk_entity.label == "Chunk" - assert len(result.relationships) == 1 - assert result.relationships[0].type == "FROM_CHUNK" + assert len(result.relationships) == 2 + assert result.relationships[0].type == "FROM_DOCUMENT" + assert result.relationships[0].start_node_id.endswith(":0") + assert result.relationships[0].end_node_id == "path" + assert result.relationships[1].type == "FROM_CHUNK" @pytest.mark.asyncio @@ -139,7 +166,7 @@ async def test_extractor_missing_entity_id() -> None: extractor = LLMEntityRelationExtractor( llm=llm, ) - chunks = TextChunks(chunks=[TextChunk(text="some text")]) + chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) with pytest.raises(LLMGenerationError): await extractor.run(chunks=chunks) @@ -152,7 +179,7 @@ async def test_extractor_llm_ainvoke_failed() -> None: extractor = LLMEntityRelationExtractor( llm=llm, ) - chunks = TextChunks(chunks=[TextChunk(text="some text")]) + chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) with pytest.raises(LLMGenerationError): await extractor.run(chunks=chunks) @@ -167,7 +194,7 @@ async def test_extractor_llm_badly_formatted_json() -> None: extractor = LLMEntityRelationExtractor( llm=llm, ) - chunks = TextChunks(chunks=[TextChunk(text="some text")]) + chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) with pytest.raises(LLMGenerationError): await extractor.run(chunks=chunks) @@ -185,7 +212,7 @@ async def test_extractor_llm_invalid_json() -> None: extractor = 
LLMEntityRelationExtractor( llm=llm, ) - chunks = TextChunks(chunks=[TextChunk(text="some text")]) + chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) with pytest.raises(LLMGenerationError): await extractor.run(chunks=chunks) @@ -202,7 +229,7 @@ async def test_extractor_llm_badly_formatted_json_do_not_raise() -> None: on_error=OnError.IGNORE, create_lexical_graph=False, ) - chunks = TextChunks(chunks=[TextChunk(text="some text")]) + chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) res = await extractor.run(chunks=chunks) assert res.nodes == [] assert res.relationships == [] @@ -214,7 +241,7 @@ async def test_extractor_custom_prompt() -> None: llm.ainvoke.return_value = LLMResponse(content='{"nodes": [], "relationships": []}') extractor = LLMEntityRelationExtractor(llm=llm, prompt_template="this is my prompt") - chunks = TextChunks(chunks=[TextChunk(text="some text")]) + chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) await extractor.run(chunks=chunks) llm.ainvoke.assert_called_once_with("this is my prompt") diff --git a/tests/unit/experimental/components/test_pdf_loader.py b/tests/unit/experimental/components/test_pdf_loader.py index 44d80f21..28db7eb5 100644 --- a/tests/unit/experimental/components/test_pdf_loader.py +++ b/tests/unit/experimental/components/test_pdf_loader.py @@ -17,6 +17,7 @@ from unittest.mock import patch import pytest +from fsspec.implementations.local import LocalFileSystem from neo4j_genai.exceptions import PdfLoaderError from neo4j_genai.experimental.components.pdf_loader import PdfLoader @@ -35,7 +36,7 @@ def dummy_pdf_path() -> Path: def test_pdf_loading(pdf_loader: PdfLoader, dummy_pdf_path: Path) -> None: expected_content = "Lorem ipsum dolor sit amet." - actual_content = pdf_loader.load_file(dummy_pdf_path) + actual_content = pdf_loader.load_file(dummy_pdf_path, fs=LocalFileSystem()) assert actual_content == expected_content @@ -45,4 +46,4 @@ def test_pdf_processing_error(pdf_loader: PdfLoader, dummy_pdf_path: Path) -> No side_effect=Exception("Failed to open"), ): with pytest.raises(PdfLoaderError): - pdf_loader.load_file(dummy_pdf_path) + pdf_loader.load_file(dummy_pdf_path, fs=LocalFileSystem())
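To see the new behaviour end to end, here is a minimal sketch of calling the extractor directly with a `DocumentInfo`, mirroring the unit tests above. The mocked LLM and its empty-graph response are placeholders, and the `neo4j_genai.llm` import path is assumed to match the modules used by these tests:

```python
# Sketch: run LLMEntityRelationExtractor with document metadata so the lexical
# graph contains a Document node plus FROM_DOCUMENT relationships for each chunk.
# The MagicMock LLM is a placeholder standing in for a real LLM client.
import asyncio
from unittest.mock import MagicMock

from neo4j_genai.experimental.components.entity_relation_extractor import (
    LLMEntityRelationExtractor,
)
from neo4j_genai.experimental.components.pdf_loader import DocumentInfo
from neo4j_genai.experimental.components.types import TextChunk, TextChunks
from neo4j_genai.llm import LLMInterface, LLMResponse  # assumed import path


async def main() -> None:
    llm = MagicMock(spec=LLMInterface)
    # The mocked model extracts nothing, so only the lexical graph is produced.
    llm.ainvoke.return_value = LLMResponse(
        content='{"nodes": [], "relationships": []}'
    )
    extractor = LLMEntityRelationExtractor(llm=llm)
    chunks = TextChunks(
        chunks=[
            TextChunk(text="first chunk", index=0),
            TextChunk(text="second chunk", index=1),
        ]
    )
    document_info = DocumentInfo(path="my_document.pdf")
    graph = await extractor.run(chunks=chunks, document_info=document_info)
    # Expected: 1 Document node, 2 Chunk nodes,
    # 2 FROM_DOCUMENT relationships and 1 NEXT_CHUNK relationship.
    print(graph)


if __name__ == "__main__":
    asyncio.run(main())
```

With `create_lexical_graph=False`, the same call would return an empty graph here, since the mocked LLM extracts no entities and no Chunk or Document nodes are added.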