Skip to content

Commit

Permalink
feat: text generation websocket functionality (#41)
Browse files Browse the repository at this point in the history
* feat: websocket generate api

* Move text generation websocket functionality to webui.py
  • Loading branch information
parveen232 authored May 30, 2024
1 parent a08a6f1 commit a7306b7
Show file tree
Hide file tree
Showing 5 changed files with 167 additions and 29 deletions.
62 changes: 36 additions & 26 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,48 +17,50 @@ An easy-to-use tool made with Gradio, LangChain, and Torch.
- Enable LLM inference with [llama.cpp](https://github.com/ggerganov/llama.cpp) using [llama-cpp-python](https://github.com/abetlen/llama-cpp-python)
- Convert models(Safetensors, pt to gguf etc)
- Customize LLM inference parameters(n_gpu_layers, temperature, max_tokens etc)
- Real-time text generation via websockets, enabling seamless integration with different frontend frameworks.

## 🚀 Installation

To use LLMinator, follow these simple steps:

#### Clone the LLMinator repository from GitHub & install requirements
#### Clone the LLMinator repository from GitHub & install requirements

```bash
git clone https://github.com/Aesthisia/LLMinator.git
cd LLMinator
pip install -r requirements.txt
```
```
git clone https://github.com/Aesthisia/LLMinator.git
cd LLMinator
pip install -r requirements.txt
```

#### Build LLMinator with [llama.cpp](https://github.com/ggerganov/llama.cpp):

- Using `make`:
- Using `make`:

- On Linux or MacOS:
- On Linux or MacOS:

```bash
make
```
```bash
make
```

- On Windows:
- On Windows:

1. Download the latest fortran version of [w64devkit](https://github.com/skeeto/w64devkit/releases).
2. Extract `w64devkit` on your pc.
3. Run `w64devkit.exe`.
4. Use the `cd` command to reach the `LLMinator` folder.
5. From here you can run:
```bash
make
```
1. Download the latest fortran version of [w64devkit](https://github.com/skeeto/w64devkit/releases).
2. Extract `w64devkit` on your pc.
3. Run `w64devkit.exe`.
4. Use the `cd` command to reach the `LLMinator` folder.
5. From here you can run:
```bash
make
```

- Using `CMake`:
```bash
mkdir build
cd build
cmake ..
```
- Using `CMake`:
```bash
mkdir build
cd build
cmake ..
```

#### Launch LLMinator on browser

- Run the LLMinator tool using the command `python webui.py`.
- Access the web interface by opening the [http://127.0.0.1:7860](http://127.0.0.1:7860) in your browser.
- Start interacting with the chatbot and experimenting with LLMs!
Expand All @@ -75,6 +77,14 @@ Checkout this youtube [video](https://www.youtube.com/watch?v=OL8wRYbdjLE) to fo
| --port | 7860 | Launch gradio with given server port |
| --share | False | This generates a public shareable link that you can send to anybody |

### Connect to WebSocket for generation

Connect to [ws://localhost:7861/](ws://localhost:7861/) for real-time text generation. Submit prompts and receive responses through the websocket connection.

**Integration with Frontends:**

The provided `example/index.html` demonstrates basic usage of text generation through websocket connection. You can integrate it with any frontend framework like React.js

## Installation and Development Tips

#### Python Version
Expand Down
34 changes: 34 additions & 0 deletions example/index.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>LLMinator</title>
</head>
<body>
<h1>LLMinator</h1>
<form action="" onsubmit="sendMessage(event)">
<input type="text" id="promptText" autocomplete="off" />
<button>Send</button>
</form>
<textarea id="response"> </textarea>
<script>
var input = document.getElementById("promptText");
var response = document.getElementById("response");
input.style.minWidth = "600px";
input.style.padding = "10px";
response.style.width = "calc(100vw - 100px)";
response.style.minHeight = "400px";
var ws = new WebSocket("ws://localhost:7861/");
ws.onmessage = function (event) {
var content = document.createTextNode(event.data);
response.appendChild(content);
};
function sendMessage(event) {
ws.send(input.value);
input.value = "";
event.preventDefault();
}
</script>
</body>
</html>
72 changes: 72 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import torch
from fastapi import FastAPI, WebSocket
from src import quantize
from langchain import PromptTemplate
from langchain_community.llms import LlamaCpp
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain_core.prompts import PromptTemplate
from core import default_repo_id

app = FastAPI()

# Callbacks support token-wise streaming
callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])

#check if cuda is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'

n_gpu_layers = None
if device == "cuda":
n_gpu_layers = -1
else:
n_gpu_layers = 0
n_ctx = 6000
n_batch = 30
n_parts = 1
temperature = 0.9
max_tokens = 500

def snapshot_download_and_convert_to_gguf(repo_id):
gguf_model_path = quantize.quantize_model(repo_id)
return gguf_model_path

def init_llm_chain(model_path):
llm = LlamaCpp(
model_path=model_path,
n_gpu_layers=n_gpu_layers,
n_ctx=n_ctx,
n_batch=n_batch,
temperature=temperature,
max_tokens=max_tokens,
n_parts=n_parts,
callback_manager=callback_manager,
verbose=True
)

template = """Question: {question}
Answer: Let's work this out in a step by step way to be sure we have the right answer."""

prompt = PromptTemplate.from_template(template)
llm_chain = prompt | llm
return llm_chain, llm

model_path = snapshot_download_and_convert_to_gguf(default_repo_id)

llm_chain, llm = init_llm_chain(model_path)

@app.websocket("/generate")
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
while True:
prompt = await websocket.receive_text()

async def bot(prompt):
print("Question: ", prompt)
output = llm_chain.stream(prompt)
print("stream:", output)
for character in output:
print(character)
await websocket.send_text(character)

await bot(prompt)
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ gradio==4.27.0
huggingface_hub==0.21.1
langchain==0.1.14
torch==2.1.2
llama-cpp-python==0.1.9
llama-cpp-python==0.1.9
fastapi==0.110.1
25 changes: 23 additions & 2 deletions webui.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import os, torch, argparse
import os, torch, argparse, asyncio, websockets, threading
import gradio as gr
from src import quantize
from langchain import PromptTemplate
Expand Down Expand Up @@ -75,6 +75,21 @@ def parse_args():

model_path = snapshot_download_and_convert_to_gguf(default_repo_id)

async def generate(websocket):
async for message in websocket:
output = llm_chain.stream(message)
for character in output:
await asyncio.sleep(0)
await websocket.send(character)

async def start_websockets():
print(f"Starting WebSocket server on port 7861 ...")
async with websockets.serve(generate, "localhost", 7861):
await asyncio.Future()

async def main():
await start_websockets()

with gr.Blocks(css='style.css') as demo:
with gr.Tabs(selected="chat") as tabs:
with gr.Tab("Chat", id="chat"):
Expand Down Expand Up @@ -280,4 +295,10 @@ def bot(history):
config_reset_btn.click(resetConfigs, None, [n_gpu_layers_input, n_ctx_input, n_batch_input, n_parts_input, temperature_input, max_tokens_input], queue=False, show_progress="full")

demo.queue()
demo.launch(server_name=args.host, server_port=args.port, share=args.share)

def launch_demo():
demo.launch(server_name=args.host, server_port=args.port, share=args.share)

if __name__ == "__main__":
threading.Thread(target=launch_demo).start()
asyncio.run(main())

0 comments on commit a7306b7

Please sign in to comment.