From a7306b7d270e1866bad294ccf7ec58db729db744 Mon Sep 17 00:00:00 2001
From: Parveen Kumar <89995648+parveen232@users.noreply.github.com>
Date: Thu, 30 May 2024 10:47:34 +0530
Subject: [PATCH] feat: text generation websocket functionality (#41)
* feat: websocket generate api
* Move text generation websocket functionality to webui.py
---
README.md | 62 ++++++++++++++++++++++-----------------
example/index.html | 34 ++++++++++++++++++++++
main.py | 72 ++++++++++++++++++++++++++++++++++++++++++++++
requirements.txt | 3 +-
webui.py | 25 ++++++++++++++--
5 files changed, 167 insertions(+), 29 deletions(-)
create mode 100644 example/index.html
create mode 100644 main.py
diff --git a/README.md b/README.md
index 85612e4..6ca8be2 100644
--- a/README.md
+++ b/README.md
@@ -17,48 +17,50 @@ An easy-to-use tool made with Gradio, LangChain, and Torch.
- Enable LLM inference with [llama.cpp](https://github.com/ggerganov/llama.cpp) using [llama-cpp-python](https://github.com/abetlen/llama-cpp-python)
- Convert models(Safetensors, pt to gguf etc)
- Customize LLM inference parameters(n_gpu_layers, temperature, max_tokens etc)
+- Real-time text generation over a WebSocket server, so streaming output can be consumed from any frontend framework.
## 🚀 Installation
To use LLMinator, follow these simple steps:
-#### Clone the LLMinator repository from GitHub & install requirements
+#### Clone the LLMinator repository from GitHub & install requirements
-```bash
-git clone https://github.com/Aesthisia/LLMinator.git
-cd LLMinator
-pip install -r requirements.txt
-```
+  ```bash
+  git clone https://github.com/Aesthisia/LLMinator.git
+  cd LLMinator
+  pip install -r requirements.txt
+  ```
#### Build LLMinator with [llama.cpp](https://github.com/ggerganov/llama.cpp):
- - Using `make`:
+- Using `make`:
- - On Linux or MacOS:
+  - On Linux or MacOS:
- ```bash
- make
- ```
+    ```bash
+    make
+    ```
- - On Windows:
+  - On Windows:
- 1. Download the latest fortran version of [w64devkit](https://github.com/skeeto/w64devkit/releases).
- 2. Extract `w64devkit` on your pc.
- 3. Run `w64devkit.exe`.
- 4. Use the `cd` command to reach the `LLMinator` folder.
- 5. From here you can run:
- ```bash
- make
- ```
+    1. Download the latest Fortran version of [w64devkit](https://github.com/skeeto/w64devkit/releases).
+    2. Extract `w64devkit` on your PC.
+    3. Run `w64devkit.exe`.
+    4. Use the `cd` command to reach the `LLMinator` folder.
+    5. From here you can run:
+       ```bash
+       make
+       ```
- - Using `CMake`:
- ```bash
- mkdir build
- cd build
- cmake ..
- ```
+- Using `CMake`:
+  ```bash
+  mkdir build
+  cd build
+  cmake ..
+  ```
#### Launch LLMinator on browser
+
- Run the LLMinator tool using the command `python webui.py`.
- Access the web interface by opening the [http://127.0.0.1:7860](http://127.0.0.1:7860) in your browser.
- Start interacting with the chatbot and experimenting with LLMs!
@@ -75,6 +77,14 @@ Checkout this youtube [video](https://www.youtube.com/watch?v=OL8wRYbdjLE) to fo
| --port | 7860 | Launch gradio with given server port |
| --share | False | This generates a public shareable link that you can send to anybody |
+### Connect to WebSocket for generation
+
+Connect to [ws://localhost:7861/](ws://localhost:7861/) for real-time text generation. The WebSocket server starts alongside the Gradio app when you run `python webui.py`; submit prompts and receive streamed responses over the same connection.
+
+**Integration with Frontends:**
+
+The provided `example/index.html` demonstrates basic text generation through a websocket connection. You can integrate the same approach with any frontend framework, such as React.js.
+
## Installation and Development Tips
#### Python Version
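The WebSocket usage documented in the README section above can also be exercised without a browser. Below is a minimal sketch of a Python client (not part of this patch) that connects to the server started by `python webui.py`, sends one prompt, and prints the streamed tokens. It assumes the `websockets` package is installed; the server sends no explicit end-of-stream marker, so the sketch simply stops after a short idle timeout.

```python
import asyncio
import websockets

async def generate(prompt: str) -> str:
    """Send one prompt to the LLMinator WebSocket server and collect the streamed reply."""
    response = ""
    async with websockets.connect("ws://localhost:7861/") as ws:
        await ws.send(prompt)
        while True:
            try:
                # Tokens arrive as separate messages; stop after a short idle period.
                token = await asyncio.wait_for(ws.recv(), timeout=5)
            except asyncio.TimeoutError:
                break
            print(token, end="", flush=True)
            response += token
    return response

if __name__ == "__main__":
    asyncio.run(generate("Explain what a GGUF model file is."))
```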
diff --git a/example/index.html b/example/index.html
new file mode 100644
index 0000000..109b583
--- /dev/null
+++ b/example/index.html
@@ -0,0 +1,34 @@
+<!DOCTYPE html>
+<html lang="en">
+  <head>
+    <meta charset="UTF-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1" />
+    <title>LLMinator</title>
+  </head>
+  <body>
+    <h1>LLMinator</h1>
+    <!-- Minimal example: sends a prompt to the LLMinator WebSocket server and streams the generated text -->
+    <input type="text" id="prompt" placeholder="Enter your prompt" />
+    <button id="send">Generate</button>
+    <pre id="output"></pre>
+
+    <script>
+      const ws = new WebSocket("ws://localhost:7861/");
+
+      // Append each streamed token to the output area as it arrives
+      ws.onmessage = (event) => {
+        document.getElementById("output").textContent += event.data;
+      };
+
+      ws.onerror = (error) => {
+        console.error("WebSocket error:", error);
+      };
+
+      // Send the prompt and clear the previous output
+      document.getElementById("send").addEventListener("click", () => {
+        document.getElementById("output").textContent = "";
+        ws.send(document.getElementById("prompt").value);
+      });
+    </script>
+  </body>
+</html>
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..66eb997
--- /dev/null
+++ b/main.py
@@ -0,0 +1,72 @@
+import torch
+from fastapi import FastAPI, WebSocket
+from src import quantize
+# LangChain / llama.cpp imports for streaming text generation
+from langchain_community.llms import LlamaCpp
+from langchain.callbacks.manager import CallbackManager
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+from langchain_core.prompts import PromptTemplate
+from core import default_repo_id
+
+app = FastAPI()
+
+# Callbacks support token-wise streaming
+callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
+
+# Check whether CUDA is available; offload all layers to the GPU if it is
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+if device == "cuda":
+    n_gpu_layers = -1
+else:
+    n_gpu_layers = 0
+# Default inference parameters
+n_ctx = 6000
+n_batch = 30
+n_parts = 1
+temperature = 0.9
+max_tokens = 500
+
+def snapshot_download_and_convert_to_gguf(repo_id):
+    gguf_model_path = quantize.quantize_model(repo_id)
+    return gguf_model_path
+
+def init_llm_chain(model_path):
+    llm = LlamaCpp(
+        model_path=model_path,
+        n_gpu_layers=n_gpu_layers,
+        n_ctx=n_ctx,
+        n_batch=n_batch,
+        temperature=temperature,
+        max_tokens=max_tokens,
+        n_parts=n_parts,
+        callback_manager=callback_manager,
+        verbose=True
+    )
+
+    template = """Question: {question}
+    Answer: Let's work this out in a step by step way to be sure we have the right answer."""
+
+    prompt = PromptTemplate.from_template(template)
+    llm_chain = prompt | llm
+    return llm_chain, llm
+
+model_path = snapshot_download_and_convert_to_gguf(default_repo_id)
+
+llm_chain, llm = init_llm_chain(model_path)
+
+@app.websocket("/generate")
+async def websocket_endpoint(websocket: WebSocket):
+    await websocket.accept()
+    # Treat every message received on the socket as a new prompt
+    while True:
+        prompt = await websocket.receive_text()
+
+        # Stream the generated tokens back to the client as they are produced
+        async def bot(prompt):
+            print("Question: ", prompt)
+            output = llm_chain.stream(prompt)
+            for character in output:
+                await websocket.send_text(character)
+
+        await bot(prompt)
\ No newline at end of file
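main.py exposes the generation logic as a FastAPI app with a WebSocket route at `/generate`, but the patch does not show how the app is served. A minimal sketch, assuming an ASGI server such as uvicorn is installed (it is not listed in requirements.txt); with these defaults the endpoint would be reachable at ws://127.0.0.1:8000/generate:

```python
# run_api.py -- minimal sketch for serving the FastAPI variant in main.py.
# Assumes uvicorn is available; adjust host/port as needed.
import uvicorn

from main import app

if __name__ == "__main__":
    # Exposes the WebSocket route at ws://127.0.0.1:8000/generate
    uvicorn.run(app, host="127.0.0.1", port=8000)
```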
diff --git a/requirements.txt b/requirements.txt
index 6e0c02d..4657cca 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,4 +2,5 @@ gradio==4.27.0
huggingface_hub==0.21.1
langchain==0.1.14
torch==2.1.2
-llama-cpp-python==0.1.9
\ No newline at end of file
+llama-cpp-python==0.1.9
+fastapi==0.110.1
\ No newline at end of file
diff --git a/webui.py b/webui.py
index b9209a4..aee588b 100644
--- a/webui.py
+++ b/webui.py
@@ -1,4 +1,4 @@
-import os, torch, argparse
+import os, torch, argparse, asyncio, websockets, threading
import gradio as gr
from src import quantize
from langchain import PromptTemplate
@@ -75,6 +75,21 @@ def parse_args():
model_path = snapshot_download_and_convert_to_gguf(default_repo_id)
+async def generate(websocket):
+    async for message in websocket:
+        output = llm_chain.stream(message)
+        for character in output:
+            await asyncio.sleep(0)  # yield to the event loop so each token is flushed
+            await websocket.send(character)
+
+async def start_websockets():
+    print("Starting WebSocket server on port 7861 ...")
+    async with websockets.serve(generate, "localhost", 7861):
+        await asyncio.Future()  # run until the process is stopped
+
+async def main():
+    await start_websockets()
+
with gr.Blocks(css='style.css') as demo:
with gr.Tabs(selected="chat") as tabs:
with gr.Tab("Chat", id="chat"):
@@ -280,4 +295,10 @@ def bot(history):
config_reset_btn.click(resetConfigs, None, [n_gpu_layers_input, n_ctx_input, n_batch_input, n_parts_input, temperature_input, max_tokens_input], queue=False, show_progress="full")
demo.queue()
-demo.launch(server_name=args.host, server_port=args.port, share=args.share)
\ No newline at end of file
+
+def launch_demo():
+    demo.launch(server_name=args.host, server_port=args.port, share=args.share)
+
+if __name__ == "__main__":
+    threading.Thread(target=launch_demo).start()
+    asyncio.run(main())
\ No newline at end of file
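The `generate` handler added to webui.py streams tokens through `websockets.serve` but does not guard against a client disconnecting mid-stream. The sketch below is illustrative only and not part of the patch: it isolates the same handler pattern with a hypothetical `stream_tokens()` stand-in for `llm_chain.stream()` and catches `ConnectionClosed` so a dropped connection only ends that client's stream.

```python
# sketch_ws_server.py -- illustrative sketch of the handler pattern used in webui.py.
import asyncio
import websockets

def stream_tokens(prompt):
    # Hypothetical stand-in for llm_chain.stream(prompt)
    for token in ["You", " said", ": ", prompt]:
        yield token

async def generate(websocket):
    try:
        async for message in websocket:
            for token in stream_tokens(message):
                await asyncio.sleep(0)   # let the event loop flush queued sends
                await websocket.send(token)
    except websockets.exceptions.ConnectionClosed:
        pass                             # client went away mid-generation

async def main():
    async with websockets.serve(generate, "localhost", 7861):
        await asyncio.Future()           # run until interrupted

if __name__ == "__main__":
    asyncio.run(main())
```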