forked from NirDiamant/RAG_Techniques
-
Notifications
You must be signed in to change notification settings - Fork 0
/
graph_rag.py
816 lines (657 loc) · 33.4 KB
/
graph_rag.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
import networkx as nx
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import PromptTemplate
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain.callbacks import get_openai_callback
from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import os
import sys
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from typing import List, Tuple, Dict
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import word_tokenize
import nltk
import spacy
import heapq
import argparse
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import numpy as np
from spacy.cli import download
from spacy.lang.en import English
# Add the parent directory of the current working directory (not of this file)
# to sys.path so that `helper_functions` and `evaluation` below can be imported.
# This follows the notebook workflow; presumably the script must be launched
# from its own directory for this to resolve — confirm before relying on it.
sys.path.append(os.path.abspath(
    os.path.join(os.getcwd(), '..')))
from helper_functions import *
from evaluation.evalute_rag import *
# Load environment variables from a .env file
load_dotenv()
# Re-export the OpenAI API key only when it is actually set. Assigning None to
# os.environ raises TypeError, which previously crashed the module at import
# time with an unhelpful traceback whenever the key was missing.
_openai_api_key = os.getenv('OPENAI_API_KEY')
if _openai_api_key is not None:
    os.environ["OPENAI_API_KEY"] = _openai_api_key
# Allow more than one copy of the Intel OpenMP runtime to be loaded in-process.
# NOTE(review): presumably needed because several native dependencies (e.g.
# FAISS) bundle their own OpenMP — confirm it is still required.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# NLTK data: 'wordnet' backs the WordNetLemmatizer used by KnowledgeGraph,
# 'punkt' is the tokenizer model used by nltk's word_tokenize.
nltk.download('punkt', quiet=True)
nltk.download('wordnet', quiet=True)
# Define the DocumentProcessor class
class DocumentProcessor:
    """Split documents into chunks, embed the chunks, and compare embeddings."""

    def __init__(self):
        """
        Set up the chunker and the embedding client.
        Attributes:
        - text_splitter: RecursiveCharacterTextSplitter configured with chunk_size=1000 and chunk_overlap=200.
        - embeddings: OpenAIEmbeddings client used to embed chunks and texts.
        """
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        self.embeddings = OpenAIEmbeddings()

    def process_documents(self, documents):
        """
        Split documents into chunks and index those chunks in a FAISS vector store.
        Args:
        - documents (list of Document): LangChain documents, e.g. as returned by a document loader.
        Returns:
        - tuple: (splits, vector_store), where splits is the list of chunk Documents and
          vector_store is a FAISS index built from those chunks using self.embeddings.
        """
        chunks = self.text_splitter.split_documents(documents)
        return chunks, FAISS.from_documents(chunks, self.embeddings)

    def create_embeddings_batch(self, texts, batch_size=32):
        """
        Embed texts in consecutive batches of at most `batch_size` items.
        Args:
        - texts (list of str): Texts to embed.
        - batch_size (int, optional): Maximum number of texts sent per embedding call. Default is 32.
        Returns:
        - numpy.ndarray: One row per input text, in input order.
        """
        vectors = []
        for start in range(0, len(texts), batch_size):
            vectors += self.embeddings.embed_documents(texts[start:start + batch_size])
        return np.array(vectors)

    def compute_similarity_matrix(self, embeddings):
        """
        Pairwise cosine similarity between all embeddings.
        Args:
        - embeddings (array-like): One embedding per row.
        Returns:
        - numpy.ndarray: Square matrix where entry [i][j] is the cosine similarity of rows i and j.
        """
        return cosine_similarity(embeddings)
# Structured-output schema for concept extraction: KnowledgeGraph passes this
# class to `llm.with_structured_output(...)`, so the LLM must return an object
# with a single `concepts_list` field.
# No class docstring on purpose: for pydantic models used as structured-output
# schemas the docstring may be forwarded to the model as the schema description,
# which would silently change the prompt.
class Concepts(BaseModel):
    # Concepts extracted by the LLM; the field description is part of the schema.
    concepts_list: List[str] = Field(description="List of concepts")
# Define the KnowledgeGraph class
class KnowledgeGraph:
    def __init__(self):
        """
        Initializes the KnowledgeGraph with a graph, lemmatizer, and NLP model.
        Attributes:
        - graph: An instance of a networkx Graph; node ids are chunk indices.
        - lemmatizer: An instance of WordNetLemmatizer.
        - concept_cache: A dictionary mapping chunk text to its extracted concepts.
        - nlp: An instance of a spaCy NLP model.
        - edges_threshold: Cosine similarity that two chunks must strictly exceed to be connected.
        """
        self.graph = nx.Graph()
        self.lemmatizer = WordNetLemmatizer()
        self.concept_cache = {}
        self.nlp = self._load_spacy_model()
        self.edges_threshold = 0.8

    def build_graph(self, splits, llm, embedding_model):
        """
        Builds the knowledge graph by adding nodes, creating embeddings, extracting concepts, and adding edges.
        Any graph from a previous call is discarded first. Node ids are chunk indices starting at 0,
        so merging a second build into the old graph would mix stale nodes and edges with new ones.
        The concept cache is kept, so unchanged chunks do not trigger new LLM calls.
        Args:
        - splits (list): A list of document splits (objects with a `page_content` attribute).
        - llm: An instance of a large language model.
        - embedding_model: An instance of an embedding model.
        Returns:
        - None
        """
        self.graph.clear()
        self._add_nodes(splits)
        embeddings = self._create_embeddings(splits, embedding_model)
        self._extract_concepts(splits, llm)
        self._add_edges(embeddings)

    def _add_nodes(self, splits):
        """
        Adds one node per document split, keyed by the split's index and storing its text as 'content'.
        Args:
        - splits (list): A list of document splits.
        Returns:
        - None
        """
        for i, split in enumerate(splits):
            self.graph.add_node(i, content=split.page_content)

    def _create_embeddings(self, splits, embedding_model):
        """
        Creates embeddings for the document splits using the embedding model.
        Args:
        - splits (list): A list of document splits.
        - embedding_model: An instance of an embedding model.
        Returns:
        - list: One embedding vector per split, as returned by `embedding_model.embed_documents`.
        """
        texts = [split.page_content for split in splits]
        return embedding_model.embed_documents(texts)

    def _compute_similarities(self, embeddings):
        """
        Computes the cosine similarity matrix for the embeddings.
        Args:
        - embeddings (array-like): One embedding per row.
        Returns:
        - numpy.ndarray: A cosine similarity matrix for the embeddings.
        """
        return cosine_similarity(embeddings)

    def _load_spacy_model(self):
        """
        Loads the spaCy "en_core_web_sm" model, downloading it first if it is not installed.
        Returns:
        - spacy.Language: An instance of a spaCy NLP model.
        """
        try:
            return spacy.load("en_core_web_sm")
        except OSError:
            print("Downloading spaCy model...")
            download("en_core_web_sm")
            return spacy.load("en_core_web_sm")

    def _extract_concepts_and_entities(self, content, llm):
        """
        Extracts concepts and named entities from the content using spaCy and a large language model.
        Results are cached per content string.
        Args:
        - content (str): The content from which to extract concepts and entities.
        - llm: An instance of a large language model.
        Returns:
        - list: Sorted, de-duplicated concepts and entities. Sorting makes the order deterministic;
          set iteration order varies between runs, and the visualizer labels nodes with the first concept.
        """
        if content in self.concept_cache:
            return self.concept_cache[content]
        # Extract named entities using spaCy
        doc = self.nlp(content)
        named_entities = [ent.text for ent in doc.ents if ent.label_ in ["PERSON", "ORG", "GPE", "WORK_OF_ART"]]
        # Extract general concepts using LLM
        concept_extraction_prompt = PromptTemplate(
            input_variables=["text"],
            template="Extract key concepts (excluding named entities) from the following text:\n\n{text}\n\nKey concepts:"
        )
        concept_chain = concept_extraction_prompt | llm.with_structured_output(Concepts)
        general_concepts = concept_chain.invoke({"text": content}).concepts_list
        # Combine named entities and general concepts
        all_concepts = sorted(set(named_entities + general_concepts))
        self.concept_cache[content] = all_concepts
        return all_concepts

    def _extract_concepts(self, splits, llm):
        """
        Extracts concepts for all document splits in a thread pool and stores them on the nodes as 'concepts'.
        Args:
        - splits (list): A list of document splits.
        - llm: An instance of a large language model.
        Returns:
        - None
        """
        with ThreadPoolExecutor() as executor:
            future_to_node = {executor.submit(self._extract_concepts_and_entities, split.page_content, llm): i
                              for i, split in enumerate(splits)}
            for future in tqdm(as_completed(future_to_node), total=len(splits),
                               desc="Extracting concepts and entities"):
                node = future_to_node[future]
                self.graph.nodes[node]['concepts'] = future.result()

    def _add_edges(self, embeddings):
        """
        Connects every pair of chunks whose cosine similarity exceeds `edges_threshold`.
        The number of nodes is taken from `embeddings`, which is the source of the similarity
        matrix being indexed. Fewer than two embeddings means there is nothing to connect; this
        also avoids calling cosine_similarity on empty input, which raises.
        Args:
        - embeddings (array-like): One embedding per document split, in node-id order.
        Returns:
        - None
        """
        num_nodes = len(embeddings)
        if num_nodes < 2:
            return
        similarity_matrix = self._compute_similarities(embeddings)
        for node1 in tqdm(range(num_nodes), desc="Adding edges"):
            # Hoisted: node1's concept set is reused for every node2.
            concepts1 = set(self.graph.nodes[node1]['concepts'])
            for node2 in range(node1 + 1, num_nodes):
                similarity_score = similarity_matrix[node1][node2]
                if similarity_score > self.edges_threshold:
                    shared_concepts = concepts1 & set(self.graph.nodes[node2]['concepts'])
                    edge_weight = self._calculate_edge_weight(node1, node2, similarity_score, shared_concepts)
                    self.graph.add_edge(node1, node2, weight=edge_weight,
                                        similarity=similarity_score,
                                        shared_concepts=list(shared_concepts))

    def _calculate_edge_weight(self, node1, node2, similarity_score, shared_concepts, alpha=0.7, beta=0.3):
        """
        Calculates the weight of an edge as alpha * similarity + beta * normalized shared-concept count.
        The shared-concept count is normalized by the smaller of the two nodes' concept lists
        (0 when either list is empty).
        Args:
        - node1 (int): The first node.
        - node2 (int): The second node.
        - similarity_score (float): The similarity score between the nodes.
        - shared_concepts (set): The set of shared concepts between the nodes.
        - alpha (float, optional): The weight of the similarity score. Default is 0.7.
        - beta (float, optional): The weight of the shared concepts. Default is 0.3.
        Returns:
        - float: The calculated weight of the edge.
        """
        max_possible_shared = min(len(self.graph.nodes[node1]['concepts']), len(self.graph.nodes[node2]['concepts']))
        normalized_shared_concepts = len(shared_concepts) / max_possible_shared if max_possible_shared > 0 else 0
        return alpha * similarity_score + beta * normalized_shared_concepts

    def _lemmatize_concept(self, concept):
        """
        Lower-cases a concept and lemmatizes each whitespace-separated word.
        Args:
        - concept (str): The concept to be lemmatized.
        Returns:
        - str: The lemmatized concept.
        """
        return ' '.join([self.lemmatizer.lemmatize(word) for word in concept.lower().split()])
# Structured-output schema for the answer check: QueryEngine builds its chain
# with `llm.with_structured_output(AnswerCheck)`.
# No class docstring on purpose: for pydantic models used as structured-output
# schemas the docstring may be forwarded to the model as the schema
# description, which would silently change the prompt.
class AnswerCheck(BaseModel):
    # The LLM's verdict on whether the supplied context fully answers the query;
    # traversal stops when this is True.
    is_complete: bool = Field(description="Whether the current context provides a complete answer to the query")
    # Answer text produced from the context; used as the final answer when is_complete is True.
    answer: str = Field(description="The current answer based on the context, if any")
# Define the QueryEngine class
class QueryEngine:
    """
    Answers queries by retrieving relevant chunks and expanding the context over the knowledge graph.
    Attributes:
    - vector_store: Vector store over the document chunks (a FAISS index when built by DocumentProcessor).
    - knowledge_graph: KnowledgeGraph whose nodes carry 'content' and 'concepts' attributes.
    - llm: Chat model used for answer checking, contextual compression and answer generation.
    - max_context_length (int): Set to 4000. NOTE(review): it is not read anywhere in this class,
      so the expanded context is currently unbounded.
    - answer_check_chain: Chain mapping {"query", "context"} to an AnswerCheck.
    """

    def __init__(self, vector_store, knowledge_graph, llm):
        self.vector_store = vector_store
        self.knowledge_graph = knowledge_graph
        self.llm = llm
        self.max_context_length = 4000
        self.answer_check_chain = self._create_answer_check_chain()

    def _create_answer_check_chain(self):
        """
        Creates a chain to check if the context provides a complete answer to the query.
        Returns:
        - Chain: prompt | llm.with_structured_output(AnswerCheck).
        """
        answer_check_prompt = PromptTemplate(
            input_variables=["query", "context"],
            template="Given the query: '{query}'\n\nAnd the current context:\n{context}\n\nDoes this context provide a complete answer to the query? If yes, provide the answer. If no, state that the answer is incomplete.\n\nIs complete answer (Yes/No):\nAnswer (if complete):"
        )
        return answer_check_prompt | self.llm.with_structured_output(AnswerCheck)

    def _check_answer(self, query: str, context: str) -> Tuple[bool, str]:
        """
        Checks if the current context provides a complete answer to the query.
        Args:
        - query (str): The query to be answered.
        - context (str): The current context.
        Returns:
        - tuple: (is_complete, answer) as reported by the answer-check chain.
        """
        response = self.answer_check_chain.invoke({"query": query, "context": context})
        return response.is_complete, response.answer

    def _generate_answer(self, query: str, context: str):
        """
        Generates an answer to the query from the given context with a single LLM call.
        Args:
        - query (str): The query to be answered.
        - context (str): The context to answer from.
        Returns:
        - The chat model's output for the prompt. For a chat model this is a message object, not a str.
        """
        response_prompt = PromptTemplate(
            input_variables=["query", "context"],
            template="Based on the following context, please answer the query.\n\nContext: {context}\n\nQuery: {query}\n\nAnswer:"
        )
        response_chain = response_prompt | self.llm
        return response_chain.invoke({"query": query, "context": context})

    @staticmethod
    def _push_if_closer(priority_queue, distances, node, distance) -> bool:
        """
        Records `distance` for `node` and pushes it on the min-heap if it beats the best known distance.
        Args:
        - priority_queue (list): heapq-managed list of (distance, node) tuples.
        - distances (dict): Best known distance per node; updated in place.
        - node: The node being reached.
        - distance (float): Candidate distance to the node (lower is better).
        Returns:
        - bool: True if the distance was an improvement and the node was queued.
        """
        if distance < distances.get(node, float('inf')):
            distances[node] = distance
            heapq.heappush(priority_queue, (distance, node))
            return True
        return False

    def _find_graph_node(self, content: str):
        """
        Returns the first knowledge-graph node whose 'content' equals `content`, or None if none does.
        """
        graph = self.knowledge_graph.graph
        return next((n for n in graph.nodes if graph.nodes[n]['content'] == content), None)

    def _expand_context(self, query: str, relevant_docs) -> Tuple[str, List[int], Dict[int, str], str]:
        """
        Expands the context by traversing the knowledge graph with a Dijkstra-like best-first search.
        1. Seed: each relevant document is mapped to its nearest chunk in the vector store. That chunk's
           graph node is queued with the vector-store score as its priority. For the default FAISS index
           the score is an L2 distance, so lower means more similar and the closest seeds are popped first.
           The old `1 / score` inverted this ordering and raised ZeroDivisionError on an exact match.
        2. Traverse: pop the lowest-distance node and skip stale heap entries. Append the node's content
           to the context and ask the LLM whether the context already answers the query.
        3. Concepts: a node's neighbours are considered only when the node contributes lemmatized
           concepts not seen before. A neighbour's distance is the current distance plus
           1 / edge weight, so stronger edges are cheaper. Each neighbour whose distance improves is
           queued and, if not yet visited, immediately added to the context and checked.
        4. Stop as soon as the answer check reports a complete answer, or when the queue is empty.
           Without a complete answer, one is generated from the accumulated context.
        NOTE(review): a node first reached as a neighbour is marked visited when discovered, so its own
        neighbours are never expanded. Traversal therefore reaches at most one hop beyond the seeds.
        This is kept as-is because changing it would change how many LLM calls a query makes; confirm
        whether it is intended.
        Args:
        - query (str): The query to be answered.
        - relevant_docs (List[Document]): A list of relevant documents to start the traversal.
        Returns:
        - tuple: A tuple containing:
        - expanded_context (str): The accumulated context from traversed nodes.
        - traversal_path (List[int]): The sequence of node indices visited.
        - filtered_content (Dict[int, str]): A mapping of node indices to their content.
        - final_answer: The answer from the answer check (str), or the generated chat-model output.
        """
        graph = self.knowledge_graph.graph
        lemmatize = self.knowledge_graph._lemmatize_concept
        expanded_context = ""
        traversal_path = []
        visited_nodes = set()  # O(1) membership mirror of traversal_path
        visited_concepts = set()
        filtered_content = {}
        final_answer = ""
        priority_queue = []
        distances = {}  # best known accumulated distance to each node
        step = 0

        def visit(node, parent=None):
            """Adds `node` to the traversal and context, logs it, and returns the answer-check result."""
            nonlocal step, expanded_context
            step += 1
            traversal_path.append(node)
            visited_nodes.add(node)
            content = graph.nodes[node]['content']
            concepts = graph.nodes[node]['concepts']
            filtered_content[node] = content
            expanded_context += "\n" + content if expanded_context else content
            suffix = f" (neighbor of {parent})" if parent is not None else ""
            print(f"\nStep {step} - Node {node}{suffix}:")
            print(f"Content: {content[:100]}...")
            print(f"Concepts: {', '.join(concepts)}")
            print("-" * 50)
            return self._check_answer(query, expanded_context)

        print("\nTraversing the knowledge graph:")

        # Seed the queue with the graph node closest to each relevant document.
        for doc in relevant_docs:
            closest_nodes = self.vector_store.similarity_search_with_score(doc.page_content, k=1)
            if not closest_nodes:
                continue
            closest_node_content, distance = closest_nodes[0]
            closest_node = self._find_graph_node(closest_node_content.page_content)
            if closest_node is None:
                print("Warning: no knowledge-graph node matches a retrieved chunk; skipping it.")
                continue
            # Keeps the smallest distance when several documents map to the same node.
            self._push_if_closer(priority_queue, distances, closest_node, float(distance))

        complete = False
        while priority_queue and not complete:
            current_priority, current_node = heapq.heappop(priority_queue)
            # Skip stale entries superseded by a shorter distance.
            if current_priority > distances.get(current_node, float('inf')):
                continue
            if current_node in visited_nodes:
                continue

            complete, answer = visit(current_node)
            if complete:
                final_answer = answer
                break

            node_concepts_set = {lemmatize(c) for c in graph.nodes[current_node]['concepts']}
            if node_concepts_set.issubset(visited_concepts):
                continue  # nothing new here; do not expand its neighbours
            visited_concepts.update(node_concepts_set)

            for neighbor in graph.neighbors(current_node):
                edge_weight = graph[current_node][neighbor]['weight']
                # Higher weight means a stronger connection, hence a shorter step.
                distance = current_priority + (1 / edge_weight)
                if not self._push_if_closer(priority_queue, distances, neighbor, distance):
                    continue
                if neighbor in visited_nodes:
                    continue
                complete, answer = visit(neighbor, parent=current_node)
                if complete:
                    final_answer = answer
                    break  # the while condition then ends the traversal
                visited_concepts.update(lemmatize(c) for c in graph.nodes[neighbor]['concepts'])

        # If we haven't found a complete answer, generate one using the LLM
        if not final_answer:
            print("\nGenerating final answer...")
            final_answer = self._generate_answer(query, expanded_context)

        return expanded_context, traversal_path, filtered_content, final_answer

    def query(self, query: str) -> Tuple[str, List[int], Dict[int, str]]:
        """
        Processes a query by retrieving relevant documents, expanding the context, and generating the
        final answer. Token usage and cost are printed from the OpenAI callback.
        Args:
        - query (str): The query to be answered.
        Returns:
        - tuple: A tuple containing:
        - final_answer: The answer (str from the answer check, or the chat model's output when generated).
        - traversal_path (list): The traversal path of nodes in the knowledge graph.
        - filtered_content (dict): The content of the visited nodes.
        """
        with get_openai_callback() as cb:
            print(f"\nProcessing query: {query}")
            relevant_docs = self._retrieve_relevant_documents(query)
            expanded_context, traversal_path, filtered_content, final_answer = self._expand_context(
                query, relevant_docs)
            if not final_answer:
                print("\nGenerating final answer...")
                final_answer = self._generate_answer(query, expanded_context)
            else:
                print("\nComplete answer found during traversal.")
            print(f"\nFinal Answer: {final_answer}")
            print(f"\nTotal Tokens: {cb.total_tokens}")
            print(f"Prompt Tokens: {cb.prompt_tokens}")
            print(f"Completion Tokens: {cb.completion_tokens}")
            print(f"Total Cost (USD): ${cb.total_cost}")
        return final_answer, traversal_path, filtered_content

    def _retrieve_relevant_documents(self, query: str):
        """
        Retrieves the top 5 similar chunks from the vector store and compresses them with the LLM.
        Args:
        - query (str): The query to be answered.
        Returns:
        - list: The compressed relevant documents.
        """
        print("\nRetrieving relevant documents...")
        retriever = self.vector_store.as_retriever(search_type="similarity", search_kwargs={"k": 5})
        compressor = LLMChainExtractor.from_llm(self.llm)
        compression_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
        return compression_retriever.invoke(query)
# Import necessary libraries
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib.patches as patches
# Define the Visualizer class
class Visualizer:
    """Plotting and printing helpers for knowledge-graph traversals."""

    @staticmethod
    def visualize_traversal(graph, traversal_path):
        """
        Visualizes the traversal path on the knowledge graph with nodes, edges, and traversal path highlighted.
        Prints a notice and returns when traversal_path is empty. The edge-weight colorbar is omitted
        when the graph has no edges, because min()/max() of an empty list would raise.
        Args:
        - graph (networkx.Graph): The knowledge graph containing nodes and edges.
        - traversal_path (list of int): The list of node indices representing the traversal path.
        Returns:
        - None
        """
        if not traversal_path:
            print("No traversal path to visualize.")
            return

        # Copy nodes and edges (with their attributes) into a directed graph for drawing.
        traversal_graph = nx.DiGraph()
        traversal_graph.add_nodes_from(graph.nodes())
        for u, v, data in graph.edges(data=True):
            traversal_graph.add_edge(u, v, **data)

        fig, ax = plt.subplots(figsize=(16, 12))
        # Generate positions for all nodes
        pos = nx.spring_layout(traversal_graph, k=1, iterations=50)

        # Draw regular edges with color based on weight
        edges = list(traversal_graph.edges())
        edge_weights = [traversal_graph[u][v].get('weight', 0.5) for u, v in edges]
        nx.draw_networkx_edges(traversal_graph, pos,
                               edgelist=edges,
                               edge_color=edge_weights,
                               edge_cmap=plt.cm.Blues,
                               width=2,
                               ax=ax)

        # Draw nodes
        nx.draw_networkx_nodes(traversal_graph, pos,
                               node_color='lightblue',
                               node_size=3000,
                               ax=ax)

        # Draw the traversal path as dashed, curved red arrows between consecutive steps.
        for start, end in zip(traversal_path, traversal_path[1:]):
            arrow = patches.FancyArrowPatch(pos[start], pos[end],
                                            connectionstyle="arc3,rad=0.3",
                                            color='red',
                                            arrowstyle="->",
                                            mutation_scale=20,
                                            linestyle='--',
                                            linewidth=2,
                                            zorder=4)
            ax.add_patch(arrow)

        # Label each node with its first concept; traversal nodes are also prefixed with their step number.
        labels = {}
        for node in traversal_graph.nodes():
            concepts = graph.nodes[node].get('concepts', [])
            labels[node] = concepts[0] if concepts else ''
        for i, node in enumerate(traversal_path):
            concepts = graph.nodes[node].get('concepts', [])
            labels[node] = f"{i + 1}. {concepts[0] if concepts else ''}"
        nx.draw_networkx_labels(traversal_graph, pos, labels, font_size=8, font_weight="bold", ax=ax)

        # Highlight start and end nodes
        nx.draw_networkx_nodes(traversal_graph, pos,
                               nodelist=[traversal_path[0]],
                               node_color='lightgreen',
                               node_size=3000,
                               ax=ax)
        nx.draw_networkx_nodes(traversal_graph, pos,
                               nodelist=[traversal_path[-1]],
                               node_color='lightcoral',
                               node_size=3000,
                               ax=ax)

        ax.set_title("Graph Traversal Flow")
        ax.axis('off')

        # Add colorbar for edge weights (only meaningful when there are edges)
        if edge_weights:
            sm = plt.cm.ScalarMappable(cmap=plt.cm.Blues,
                                       norm=plt.Normalize(vmin=min(edge_weights), vmax=max(edge_weights)))
            sm.set_array([])
            cbar = fig.colorbar(sm, ax=ax, orientation='vertical', fraction=0.046, pad=0.04)
            cbar.set_label('Edge Weight', rotation=270, labelpad=15)

        # Add legend
        regular_line = plt.Line2D([0], [0], color='blue', linewidth=2, label='Regular Edge')
        traversal_line = plt.Line2D([0], [0], color='red', linewidth=2, linestyle='--', label='Traversal Path')
        start_point = plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='lightgreen', markersize=15,
                                 label='Start Node')
        end_point = plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='lightcoral', markersize=15,
                               label='End Node')
        legend = plt.legend(handles=[regular_line, traversal_line, start_point, end_point], loc='upper left',
                            bbox_to_anchor=(0, 1), ncol=2)
        legend.get_frame().set_alpha(0.8)

        plt.tight_layout()
        plt.show()

    @staticmethod
    def print_filtered_content(traversal_path, filtered_content):
        """
        Prints the first 200 characters of each visited node's content, in traversal order.
        Args:
        - traversal_path (list of int): The list of node indices representing the traversal path.
        - filtered_content (dict of int: str): A dictionary mapping node indices to their filtered content.
        Returns:
        - None
        """
        print("\nFiltered content of visited nodes in order of traversal:")
        for i, node in enumerate(traversal_path):
            print(f"\nStep {i + 1} - Node {node}:")
            print(
                f"Filtered Content: {filtered_content.get(node, 'No filtered content available')[:200]}...")  # Print first 200 characters
            print("-" * 50)
# Define the GraphRAG class: wires together document processing, graph construction, querying and plotting.
class GraphRAG:
    def __init__(self, documents):
        """
        Builds all components and immediately processes `documents`.
        Args:
        - documents (list of Document): LangChain documents (e.g. pages from a PDF loader).
        Attributes:
        - llm: ChatOpenAI model ("gpt-4o-mini", temperature 0, max_tokens 4000).
        - embedding_model: OpenAIEmbeddings used to embed chunks for graph edges.
        - document_processor: DocumentProcessor that splits documents and builds the vector store.
        - knowledge_graph: KnowledgeGraph built from the chunks.
        - query_engine: QueryEngine; None until process_documents has run.
        - visualizer: Visualizer used to plot traversal paths.
        """
        self.llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", max_tokens=4000)
        self.embedding_model = OpenAIEmbeddings()
        self.document_processor = DocumentProcessor()
        self.knowledge_graph = KnowledgeGraph()
        self.query_engine = None
        self.visualizer = Visualizer()
        self.process_documents(documents)

    def process_documents(self, documents):
        """
        Chunks the documents, builds the knowledge graph from the chunks, and creates a new query engine.
        Args:
        - documents (list of Document): LangChain documents to process.
        Returns:
        - None
        """
        chunks, store = self.document_processor.process_documents(documents)
        self.knowledge_graph.build_graph(chunks, self.llm, self.embedding_model)
        self.query_engine = QueryEngine(store, self.knowledge_graph, self.llm)

    def query(self, query: str):
        """
        Answers a query and plots the knowledge-graph traversal that produced the answer.
        Args:
        - query (str): The query to be answered.
        Returns:
        - The answer from QueryEngine.query: a str when found during traversal, otherwise the
          chat model's generated output.
        """
        answer, path, _filtered = self.query_engine.query(query)
        if not path:
            print("No traversal path to visualize.")
        else:
            self.visualizer.visualize_traversal(self.knowledge_graph.graph, path)
        return answer
# Argument parsing
def parse_args():
    """
    Parses the command-line options for the GraphRAG demo script.
    Returns:
    - argparse.Namespace: `path` (PDF file to load) and `query` (question to ask), both strings.
    """
    cli = argparse.ArgumentParser(description="GraphRAG system")
    options = (
        ('--path', "../data/Understanding_Climate_Change.pdf", 'Path to the PDF file.'),
        ('--query', 'what is the main cause of climate change?', 'Query to retrieve documents.'),
    )
    for flag, default_value, help_text in options:
        cli.add_argument(flag, type=str, default=default_value, help=help_text)
    return cli.parse_args()
if __name__ == '__main__':
    args = parse_args()

    # Load the documents and keep only the first 10 loaded documents
    loader = PyPDFLoader(args.path)
    documents = loader.load()
    documents = documents[:10]

    # Create a graph RAG instance. Its constructor already processes the documents and
    # builds the graph, so they are not processed again here. The previous second call
    # repeated every embedding and concept-extraction LLM call.
    graph_rag = GraphRAG(documents)

    # Input a query and get the retrieved information from the graph RAG
    # (QueryEngine.query prints the final answer and token usage).
    response = graph_rag.query(args.query)