feat: text generation websocket functionality (#41)
* feat: websocket generate api
* Move text generation websocket functionality to webui.py
1 parent a08a6f1 · commit a7306b7

Showing 5 changed files with 167 additions and 29 deletions.
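The protocol this commit adds is simple: the client sends a prompt as a single websocket text message, and the server streams the generated output back as one text message per token chunk. As a point of reference, the endpoint can be exercised outside the browser; here is a minimal sketch, assuming the server listens on localhost:7861 at the /generate route shown in the server file below, and that the third-party websockets package is installed:

# Hypothetical smoke test for this commit's websocket API; the host,
# port, and /generate path are assumptions taken from the diffs below.
import asyncio
import websockets

async def main():
    async with websockets.connect("ws://localhost:7861/generate") as ws:
        await ws.send("What is the capital of France?")  # one prompt per message
        try:
            while True:
                # The server sends one token chunk per message and no
                # end-of-stream marker, so stop after a quiet period.
                chunk = await asyncio.wait_for(ws.recv(), timeout=30)
                print(chunk, end="", flush=True)
        except asyncio.TimeoutError:
            pass

asyncio.run(main())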
@@ -0,0 +1,34 @@
<!DOCTYPE html>
<html lang="en">
  <head>
    <meta charset="UTF-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    <title>LLMinator</title>
  </head>
  <body>
    <h1>LLMinator</h1>
    <form action="" onsubmit="sendMessage(event)">
      <input type="text" id="promptText" autocomplete="off" />
      <button>Send</button>
    </form>
    <textarea id="response"></textarea>
    <script>
      var input = document.getElementById("promptText");
      var response = document.getElementById("response");
      input.style.minWidth = "600px";
      input.style.padding = "10px";
      response.style.width = "calc(100vw - 100px)";
      response.style.minHeight = "400px";
      // NOTE: the server file below registers its websocket at /generate;
      // this URL must match the route that is actually served.
      var ws = new WebSocket("ws://localhost:7861/");
      ws.onmessage = function (event) {
        // Append each streamed token to the output area as it arrives
        var content = document.createTextNode(event.data);
        response.appendChild(content);
      };
      function sendMessage(event) {
        ws.send(input.value);
        input.value = "";
        event.preventDefault();
      }
    </script>
  </body>
</html>
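This page is a bare-bones test client with no build step: opened in a browser, it connects to the websocket server, sends whatever is typed into the input box on submit, and appends each streamed token to the textarea as it arrives.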
@@ -0,0 +1,72 @@
import torch
from fastapi import FastAPI, WebSocket
from src import quantize
from langchain_community.llms import LlamaCpp
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain_core.prompts import PromptTemplate
from core import default_repo_id

app = FastAPI()

# Callbacks support token-wise streaming
callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])

# Check if CUDA is available
device = "cuda" if torch.cuda.is_available() else "cpu"

# Offload all layers to the GPU when CUDA is available, none otherwise
n_gpu_layers = -1 if device == "cuda" else 0
n_ctx = 6000
n_batch = 30
n_parts = 1
temperature = 0.9
max_tokens = 500

def snapshot_download_and_convert_to_gguf(repo_id):
    gguf_model_path = quantize.quantize_model(repo_id)
    return gguf_model_path

def init_llm_chain(model_path):
    llm = LlamaCpp(
        model_path=model_path,
        n_gpu_layers=n_gpu_layers,
        n_ctx=n_ctx,
        n_batch=n_batch,
        temperature=temperature,
        max_tokens=max_tokens,
        n_parts=n_parts,
        callback_manager=callback_manager,
        verbose=True,
    )

    template = """Question: {question}
    Answer: Let's work this out in a step by step way to be sure we have the right answer."""

    prompt = PromptTemplate.from_template(template)
    llm_chain = prompt | llm
    return llm_chain, llm

model_path = snapshot_download_and_convert_to_gguf(default_repo_id)

llm_chain, llm = init_llm_chain(model_path)

@app.websocket("/generate")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    while True:
        prompt = await websocket.receive_text()

        async def bot(prompt):
            print("Question: ", prompt)
            # stream() yields multi-character token chunks, not single characters
            for chunk in llm_chain.stream(prompt):
                await websocket.send_text(chunk)

        await bot(prompt)
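One caveat worth noting: llm_chain.stream() returns a synchronous generator, so iterating it inside the async handler blocks the event loop while llama.cpp computes each token, stalling any other websocket connections. A possible refinement, sketched here under the same setup and not part of this commit, is to run the blocking iteration in a worker thread and relay chunks back through an asyncio queue:

# Sketch only: non-blocking variant of the bot() coroutine above.
import asyncio

async def bot(prompt, websocket):
    # Run the blocking LlamaCpp stream in a worker thread so the event
    # loop stays responsive; a None sentinel signals completion.
    queue: asyncio.Queue = asyncio.Queue()
    loop = asyncio.get_running_loop()

    def produce():
        for chunk in llm_chain.stream(prompt):
            loop.call_soon_threadsafe(queue.put_nowait, chunk)
        loop.call_soon_threadsafe(queue.put_nowait, None)

    loop.run_in_executor(None, produce)
    while (chunk := await queue.get()) is not None:
        await websocket.send_text(chunk)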