From d51ecd4192bce24fc4419772cc2d3ebdfb974963 Mon Sep 17 00:00:00 2001 From: killian <63927363+KillianLucas@users.noreply.github.com> Date: Wed, 23 Oct 2024 22:23:40 -0700 Subject: [PATCH] Prep for server, windows fix --- interpreter/computer_use/loop.py | 103 ++++++++++++++++++++++++++++++- 1 file changed, 101 insertions(+), 2 deletions(-) diff --git a/interpreter/computer_use/loop.py b/interpreter/computer_use/loop.py index 3f8f0c7f2..e26300392 100755 --- a/interpreter/computer_use/loop.py +++ b/interpreter/computer_use/loop.py @@ -33,10 +33,19 @@ BETA_FLAG = "computer-use-2024-10-22" +from typing import List, Optional + +import uvicorn +from fastapi import FastAPI +from fastapi.responses import StreamingResponse +from pydantic import BaseModel from rich import print as rich_print from rich.markdown import Markdown from rich.rule import Rule +# Add this near the top of the file, with other imports and global variables +messages: List[BetaMessageParam] = [] + def print_markdown(message): """ @@ -87,7 +96,7 @@ class APIProvider(StrEnum): * When using your bash tool with commands that are expected to output very large quantities of text, redirect into a tmp file and use str_replace_editor or `grep -n -B -A ` to confirm output. * When viewing a page it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available. * When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request. -* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}. +* The current date is {datetime.today().strftime('%A, %B %d, %Y')}. @@ -335,6 +344,81 @@ async def main(): provider = APIProvider.ANTHROPIC system_prompt_suffix = "" + # Check if running in server mode + if "--server" in sys.argv: + app = FastAPI() + + # Start the mouse position checking thread when in server mode + mouse_thread = threading.Thread(target=check_mouse_position) + mouse_thread.daemon = True + mouse_thread.start() + + # Get API key from environment variable + api_key = os.environ.get("ANTHROPIC_API_KEY") + if not api_key: + raise ValueError( + "ANTHROPIC_API_KEY environment variable must be set when running in server mode" + ) + + @app.post("/openai/chat/completions") + async def chat_completion(request: ChatCompletionRequest): + # Check exit flag before processing request + if exit_flag: + return {"error": "Server shutting down due to mouse in corner"} + + async def stream_response(): + print("is this even happening") + + # Instead of creating converted_messages, append the last message to global messages + global messages + messages.append( + { + "role": request.messages[-1].role, + "content": [ + {"type": "text", "text": request.messages[-1].content} + ], + } + ) + + response_chunks = [] + + async def output_callback(content_block: BetaContentBlock): + chunk = f"data: {json.dumps({'choices': [{'delta': {'content': content_block.text}}]})}\n\n" + response_chunks.append(chunk) + yield chunk + + async def tool_output_callback(result: ToolResult, tool_id: str): + if result.output or result.error: + content = result.output if result.output else result.error + chunk = f"data: {json.dumps({'choices': [{'delta': {'content': content}}]})}\n\n" + response_chunks.append(chunk) + yield chunk + + try: + result = await sampling_loop( + model=model, + provider=provider, + system_prompt_suffix=system_prompt_suffix, + messages=messages, # Now using global messages + output_callback=output_callback, + tool_output_callback=tool_output_callback, + api_key=api_key, + ) + + # # Yield all stored chunks + # for chunk in response_chunks: + # yield chunk + + except Exception as e: + print(f"Error: {e}") + yield f"data: {json.dumps({'error': str(e)})}\n\n" + + return StreamingResponse(stream_response(), media_type="text/event-stream") + + # Instead of running uvicorn here, we'll return the app + return app + + # Original CLI code continues here... print() print_markdown("Welcome to **Open Interpreter**.\n") print_markdown("---") @@ -428,7 +512,12 @@ def tool_output_callback(result: ToolResult, tool_id: str): def run_async_main(): - asyncio.run(main()) + if "--server" in sys.argv: + # Start uvicorn server directly without asyncio.run() + app = asyncio.run(main()) + uvicorn.run(app, host="0.0.0.0", port=8000) + else: + asyncio.run(main()) if __name__ == "__main__": @@ -464,3 +553,13 @@ def check_mouse_position(): print("\nMouse moved to corner. Exiting...") os._exit(0) threading.Event().wait(0.1) # Check every 100ms + + +class ChatMessage(BaseModel): + role: str + content: str + + +class ChatCompletionRequest(BaseModel): + messages: List[ChatMessage] + stream: Optional[bool] = False