diff --git a/examples/graph_visual_with_html.py b/examples/graph_visual_with_html.py
index 1a3ff144..56642185 100644
--- a/examples/graph_visual_with_html.py
+++ b/examples/graph_visual_with_html.py
@@ -11,6 +11,7 @@
 # Convert NetworkX graph to Pyvis network
 net.from_nx(G)
 
 
+# Add colors and title to nodes
 for node in net.nodes:
     node["color"] = "#{:06x}".format(random.randint(0, 0xFFFFFF))
diff --git a/examples/lightrag_api_ollama_demo.py b/examples/lightrag_api_ollama_demo.py
new file mode 100644
index 00000000..36df1262
--- /dev/null
+++ b/examples/lightrag_api_ollama_demo.py
@@ -0,0 +1,164 @@
+from fastapi import FastAPI, HTTPException, File, UploadFile
+from pydantic import BaseModel
+import os
+from lightrag import LightRAG, QueryParam
+from lightrag.llm import ollama_embedding, ollama_model_complete
+from lightrag.utils import EmbeddingFunc
+from typing import Optional
+import asyncio
+import nest_asyncio
+import aiofiles
+
+# Apply nest_asyncio to solve event loop issues
+nest_asyncio.apply()
+
+DEFAULT_RAG_DIR = "index_default"
+app = FastAPI(title="LightRAG API", description="API for RAG operations")
+
+DEFAULT_INPUT_FILE = "book.txt"
+INPUT_FILE = os.environ.get("INPUT_FILE", f"{DEFAULT_INPUT_FILE}")
+print(f"INPUT_FILE: {INPUT_FILE}")
+
+# Configure working directory
+WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}")
+print(f"WORKING_DIR: {WORKING_DIR}")
+
+
+if not os.path.exists(WORKING_DIR):
+    os.mkdir(WORKING_DIR)
+
+
+rag = LightRAG(
+    working_dir=WORKING_DIR,
+    llm_model_func=ollama_model_complete,
+    llm_model_name="gemma2:9b",
+    llm_model_max_async=4,
+    llm_model_max_token_size=8192,
+    llm_model_kwargs={"host": "http://localhost:11434", "options": {"num_ctx": 8192}},
+    embedding_func=EmbeddingFunc(
+        embedding_dim=768,
+        max_token_size=8192,
+        func=lambda texts: ollama_embedding(
+            texts, embed_model="nomic-embed-text", host="http://localhost:11434"
+        ),
+    ),
+)
+
+
+# Data models
+class QueryRequest(BaseModel):
+    query: str
+    mode: str = "hybrid"
+    only_need_context: bool = False
+
+
+class InsertRequest(BaseModel):
+    text: str
+
+
+class Response(BaseModel):
+    status: str
+    data: Optional[str] = None
+    message: Optional[str] = None
+
+
+# API routes
+@app.post("/query", response_model=Response)
+async def query_endpoint(request: QueryRequest):
+    try:
+        loop = asyncio.get_event_loop()
+        result = await loop.run_in_executor(
+            None,
+            lambda: rag.query(
+                request.query,
+                param=QueryParam(
+                    mode=request.mode, only_need_context=request.only_need_context
+                ),
+            ),
+        )
+        return Response(status="success", data=result)
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+# insert by text
+@app.post("/insert", response_model=Response)
+async def insert_endpoint(request: InsertRequest):
+    try:
+        loop = asyncio.get_event_loop()
+        await loop.run_in_executor(None, lambda: rag.insert(request.text))
+        return Response(status="success", message="Text inserted successfully")
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+# insert by file in payload
+@app.post("/insert_file", response_model=Response)
+async def insert_file(file: UploadFile = File(...)):
+    try:
+        file_content = await file.read()
+        # Read file content
+        try:
+            content = file_content.decode("utf-8")
+        except UnicodeDecodeError:
+            # If UTF-8 decoding fails, try other encodings
+            content = file_content.decode("gbk")
+        # Insert file content
+        loop = asyncio.get_event_loop()
+        await loop.run_in_executor(None, lambda: rag.insert(content))
+
+        return Response(
+            status="success",
+            message=f"File content from {file.filename} inserted successfully",
+        )
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+# insert by local default file
+@app.post("/insert_default_file", response_model=Response)
+@app.get("/insert_default_file", response_model=Response)
+async def insert_default_file():
+    try:
+        # Read file content from book.txt
+        async with aiofiles.open(INPUT_FILE, "r", encoding="utf-8") as file:
+            content = await file.read()
+        print(f"read input file {INPUT_FILE} successfully")
+        # Insert file content
+        loop = asyncio.get_event_loop()
+        await loop.run_in_executor(None, lambda: rag.insert(content))
+
+        return Response(
+            status="success",
+            message=f"File content from {INPUT_FILE} inserted successfully",
+        )
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.get("/health")
+async def health_check():
+    return {"status": "healthy"}
+
+
+if __name__ == "__main__":
+    import uvicorn
+
+    uvicorn.run(app, host="0.0.0.0", port=8020)
+
+# Usage example
+# To run the server, use the following command in your terminal:
+# python lightrag_api_ollama_demo.py
+
+# Example requests:
+# 1. Query:
+# curl -X POST "http://127.0.0.1:8020/query" -H "Content-Type: application/json" -d '{"query": "your query here", "mode": "hybrid"}'
+
+# 2. Insert text:
+# curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}'
+
+# 3. Insert file (multipart form upload):
+# curl -X POST "http://127.0.0.1:8020/insert_file" -F "file=@path/to/your/file.txt"
+
+# 4. Health check:
+# curl -X GET "http://127.0.0.1:8020/health"
diff --git a/examples/lightrag_openai_compatible_stream_demo.py b/examples/lightrag_openai_compatible_stream_demo.py
new file mode 100644
index 00000000..9345ada5
--- /dev/null
+++ b/examples/lightrag_openai_compatible_stream_demo.py
@@ -0,0 +1,55 @@
+import os
+import inspect
+from lightrag import LightRAG
+from lightrag.llm import openai_complete, openai_embedding
+from lightrag.utils import EmbeddingFunc
+from lightrag.lightrag import always_get_an_event_loop
+from lightrag import QueryParam
+
+# WorkingDir
+ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+WORKING_DIR = os.path.join(ROOT_DIR, "dickens")
+if not os.path.exists(WORKING_DIR):
+    os.mkdir(WORKING_DIR)
+print(f"WorkingDir: {WORKING_DIR}")
+
+api_key = "empty"
+rag = LightRAG(
+    working_dir=WORKING_DIR,
+    llm_model_func=openai_complete,
+    llm_model_name="qwen2.5-14b-instruct@4bit",
+    llm_model_max_async=4,
+    llm_model_max_token_size=32768,
+    llm_model_kwargs={"base_url": "http://127.0.0.1:1234/v1", "api_key": api_key},
+    embedding_func=EmbeddingFunc(
+        embedding_dim=1024,
+        max_token_size=8192,
+        func=lambda texts: openai_embedding(
+            texts=texts,
+            model="text-embedding-bge-m3",
+            base_url="http://127.0.0.1:1234/v1",
+            api_key=api_key,
+        ),
+    ),
+)
+
+with open("./book.txt", "r", encoding="utf-8") as f:
+    rag.insert(f.read())
+
+resp = rag.query(
+    "What are the top themes in this story?",
+    param=QueryParam(mode="hybrid", stream=True),
+)
+
+
+async def print_stream(stream):
+    async for chunk in stream:
+        if chunk:
+            print(chunk, end="", flush=True)
+
+
+loop = always_get_an_event_loop()
+if inspect.isasyncgen(resp):
+    loop.run_until_complete(print_stream(resp))
+else:
+    print(resp)
diff --git a/lightrag/__init__.py b/lightrag/__init__.py
index 1b713773..1c5cd617 100644
--- a/lightrag/__init__.py
+++ b/lightrag/__init__.py
@@ -1,5 +1,5 @@
 from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
 
-__version__ = "1.0.4"
+__version__ = "1.0.5"
 __author__ = "Zirui Guo"
 __url__ = "https://github.com/HKUDS/LightRAG"
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 0eb1b27e..833926e5 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -40,14 +40,6 @@
     NetworkXStorage,
 )
 
-from .kg.neo4j_impl import Neo4JStorage
-
-from .kg.oracle_impl import OracleKVStorage, OracleGraphStorage, OracleVectorDBStorage
-
-from .kg.milvus_impl import MilvusVectorDBStorge
-
-from .kg.mongo_impl import MongoKVStorage
-
 # future KG integrations
 
 # from .kg.ArangoDB_impl import (
@@ -55,6 +47,30 @@
 # )
 
 
+def lazy_external_import(module_name: str, class_name: str):
+    """Lazily import an external module and return a class from it."""
+
+    def import_class():
+        import importlib
+
+        # Import the module using importlib
+        module = importlib.import_module(module_name)
+
+        # Get the class from the module
+        return getattr(module, class_name)
+
+    # Return the import_class function itself, not its result
+    return import_class
+
+
+Neo4JStorage = lazy_external_import(".kg.neo4j_impl", "Neo4JStorage")
+OracleKVStorage = lazy_external_import(".kg.oracle_impl", "OracleKVStorage")
+OracleGraphStorage = lazy_external_import(".kg.oracle_impl", "OracleGraphStorage")
+OracleVectorDBStorage = lazy_external_import(".kg.oracle_impl", "OracleVectorDBStorage")
+MilvusVectorDBStorge = lazy_external_import(".kg.milvus_impl", "MilvusVectorDBStorge")
+MongoKVStorage = lazy_external_import(".kg.mongo_impl", "MongoKVStorage")
+
+
 def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
     """
     Ensure that there is always an event loop available.
@@ -68,7 +84,7 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
     try:
         # Try to get the current event loop
         current_loop = asyncio.get_event_loop()
-        if current_loop._closed:
+        if current_loop.is_closed():
            raise RuntimeError("Event loop is closed.")
 
         return current_loop
diff --git a/lightrag/llm.py b/lightrag/llm.py
index 72af880e..6a64244a 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -76,11 +76,24 @@ async def openai_complete_if_cache(
     response = await openai_async_client.chat.completions.create(
         model=model, messages=messages, **kwargs
     )
 
-    content = response.choices[0].message.content
-    if r"\u" in content:
-        content = content.encode("utf-8").decode("unicode_escape")
-    return content
+    if hasattr(response, "__aiter__"):
+
+        async def inner():
+            async for chunk in response:
+                content = chunk.choices[0].delta.content
+                if content is None:
+                    continue
+                if r"\u" in content:
+                    content = content.encode("utf-8").decode("unicode_escape")
+                yield content
+
+        return inner()
+    else:
+        content = response.choices[0].message.content
+        if r"\u" in content:
+            content = content.encode("utf-8").decode("unicode_escape")
+        return content
 
 
 @retry(
@@ -306,7 +319,7 @@ async def ollama_model_if_cache(
 
     response = await ollama_client.chat(model=model, messages=messages, **kwargs)
     if stream:
-        """ cannot cache stream response """
+        """cannot cache stream response"""
 
         async def inner():
             async for chunk in response:
@@ -447,6 +460,22 @@ class GPTKeywordExtractionFormat(BaseModel):
     low_level_keywords: List[str]
 
 
+async def openai_complete(
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+) -> Union[str, AsyncIterator[str]]:
+    keyword_extraction = kwargs.pop("keyword_extraction", None)
+    if keyword_extraction:
+        kwargs["response_format"] = "json"
+    model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
+    return await openai_complete_if_cache(
+        model_name,
+        prompt,
+        system_prompt=system_prompt,
+        history_messages=history_messages,
+        **kwargs,
+    )
+
+
 async def gpt_4o_complete(
     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
@@ -890,6 +919,8 @@ async def llm_model_func(
         self, prompt, system_prompt=None, history_messages=[], **kwargs
     ) -> str:
         kwargs.pop("model", None)  # stop from overwriting the custom model name
+        kwargs.pop("keyword_extraction", None)
+        kwargs.pop("mode", None)
         next_model = self._next_model()
         args = dict(
             prompt=prompt,
diff --git a/lightrag/operate.py b/lightrag/operate.py
index 45c9ef16..468f4b2f 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -222,7 +222,7 @@ async def _merge_edges_then_upsert(
         },
     )
     description = await _handle_entity_relation_summary(
-        (src_id, tgt_id), description, global_config
+        f"({src_id}, {tgt_id})", description, global_config
     )
     await knowledge_graph_inst.upsert_edge(
         src_id,
@@ -572,7 +572,6 @@ async def kg_query(
             mode=query_param.mode,
         ),
     )
-
     return response
 
 
diff --git a/lightrag/utils.py b/lightrag/utils.py
index 32d5c87f..d79cc1a2 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -488,7 +488,7 @@ class CacheData:
 
 
 async def save_to_cache(hashing_kv, cache_data: CacheData):
-    if hashing_kv is None:
+    if hashing_kv is None or hasattr(cache_data.content, "__aiter__"):
         return
 
     mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}