# Sentinel handed back from the worker thread once the wrapped generator is
# exhausted; a raised StopIteration cannot cross the executor/await boundary
# (PEP 479), so exhaustion must be signalled by value instead.
END_SIGNAL = object()


def async_wrap(generator_func, *args, **kwargs):
    """Adapt a blocking generator into an async-generator factory.

    ``generator_func(*args, **kwargs)`` is invoked lazily and advanced one
    item at a time on a worker thread, so the event loop is never blocked
    while the synchronous generator computes its next value.

    Returns a zero-argument callable that produces the async generator.
    """

    def _advance(iterator):
        # Runs off-loop. Translate exhaustion into the sentinel value
        # because StopIteration would be swallowed by the future machinery.
        try:
            return next(iterator)
        except StopIteration:
            return END_SIGNAL

    async def _stream():
        loop = asyncio.get_running_loop()
        iterator = iter(generator_func(*args, **kwargs))
        with ThreadPoolExecutor() as pool:
            item = await loop.run_in_executor(pool, _advance, iterator)
            while item is not END_SIGNAL:
                yield item
                item = await loop.run_in_executor(pool, _advance, iterator)

    return _stream

WebSocket Client

+<select id="model">OPTIONS</select>


async def websocket_application(scope, receive, send):
    """ASGI websocket endpoint for streaming chat completions.

    Each received text frame is a JSON payload of the form
    ``{"message": "...", "model": "gpt-3.5-turbo"}``. The named model's
    response is streamed back over the socket chunk by chunk, terminated
    by a blank-line frame.

    NOTE(review): the receive loop and connect branch fall outside the
    visible hunk of this patch; reconstructed from the ``elif`` chain and
    the ``break`` on disconnect — confirm against the full file.
    """
    while True:
        event = await receive()
        if event["type"] == "websocket.connect":
            await send({"type": "websocket.accept"})
        elif event["type"] == "websocket.receive":
            # Frame format: {"message": "...", "model": "gpt-3.5-turbo"}
            decoded = json.loads(event["text"])
            model_id = decoded["model"]
            message = decoded["message"]
            model = llm.get_model(model_id)
            # model.prompt() yields synchronously; async_wrap drives it on a
            # worker thread so the event loop keeps servicing other clients.
            async for chunk in async_wrap(model.prompt, message)():
                await send({"type": "websocket.send", "text": chunk})
            # Blank line signals end-of-response to the client.
            await send({"type": "websocket.send", "text": "\n\n"})
        elif event["type"] == "websocket.disconnect":
            break


def chat():
    """Serve the chat page, expanding the OPTIONS placeholder in the CHAT
    template into one ``<option>`` element per installed model."""
    return Response.html(
        CHAT.replace(
            "OPTIONS",
            "\n".join(
                # NOTE(review): the option markup was garbled to '' in this
                # copy of the patch; reconstructed from the OPTIONS
                # placeholder and the page's model <select> — verify.
                '<option value="{model}">{model}</option>'.format(
                    model=model_with_alias.model.model_id
                )
                for model_with_alias in llm.get_models_with_aliases()
            ),
        )
    )