diff --git a/llm/web/datasette-plugins/llm_views.py b/llm/web/datasette-plugins/llm_views.py
index f8e18685..3d35984e 100644
--- a/llm/web/datasette-plugins/llm_views.py
+++ b/llm/web/datasette-plugins/llm_views.py
@@ -1,5 +1,31 @@
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
from datasette import hookimpl, Response
-import openai
+import json
+import llm
+
+END_SIGNAL = object()
+
+
+def async_wrap(generator_func, *args, **kwargs):
+ def next_item(gen):
+ try:
+ return next(gen)
+ except StopIteration:
+ return END_SIGNAL
+
+ async def async_generator():
+ loop = asyncio.get_running_loop()
+ generator = iter(generator_func(*args, **kwargs))
+ with ThreadPoolExecutor() as executor:
+ while True:
+ item = await loop.run_in_executor(executor, next_item, generator)
+ if item is END_SIGNAL:
+ break
+ yield item
+
+ return async_generator
+
CHAT = """
@@ -9,6 +35,7 @@
WebSocket Client
+
@@ -24,8 +51,9 @@
function sendMessage() {
const message = document.getElementById('message').value;
- console.log({message, ws});
- ws.send(message);
+ const model = document.getElementById('model').value;
+ console.log({message, model, ws});
+ ws.send(JSON.stringify({message, model}));
}
@@ -42,27 +70,31 @@ async def websocket_application(scope, receive, send):
await send({"type": "websocket.accept"})
elif event["type"] == "websocket.receive":
message = event["text"]
-
- async for chunk in await openai.ChatCompletion.acreate(
- model="gpt-3.5-turbo",
- messages=[
- {
- "role": "user",
- "content": message,
- }
- ],
- stream=True,
- ):
- content = chunk["choices"][0].get("delta", {}).get("content")
- if content is not None:
- await send({"type": "websocket.send", "text": content})
+ # {"message":"...","model":"gpt-3.5-turbo"}
+ decoded = json.loads(message)
+ model_id = decoded["model"]
+ message = decoded["message"]
+ model = llm.get_model(model_id)
+ async for chunk in async_wrap(model.prompt, message)():
+ await send({"type": "websocket.send", "text": chunk})
+ await send({"type": "websocket.send", "text": "\n\n"})
elif event["type"] == "websocket.disconnect":
break
def chat():
- return Response.html(CHAT)
+ return Response.html(
+ CHAT.replace(
+ "OPTIONS",
+ "\n".join(
+ ''.format(
+ model=model_with_alias.model.model_id
+ )
+ for model_with_alias in llm.get_models_with_aliases()
+ ),
+ )
+ )
@hookimpl