diff --git a/.dockerignore b/.dockerignore index 58cf1f0f8d..e28863bf60 100644 --- a/.dockerignore +++ b/.dockerignore @@ -7,7 +7,6 @@ node_modules /package .env .env.* -!.env.example vite.config.js.timestamp-* vite.config.ts.timestamp-* __pycache__ diff --git a/.env.example b/.env.example index de763f31c9..3d2aafc09e 100644 --- a/.env.example +++ b/.env.example @@ -1,6 +1,6 @@ # Ollama URL for the backend to connect -# The path '/ollama/api' will be redirected to the specified backend URL -OLLAMA_API_BASE_URL='http://localhost:11434/api' +# The path '/ollama' will be redirected to the specified backend URL +OLLAMA_BASE_URL='http://localhost:11434' OPENAI_API_BASE_URL='' OPENAI_API_KEY='' diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 5a85d08796..4386661335 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -32,7 +32,7 @@ assignees: '' **Confirmation:** - [ ] I have read and followed all the instructions provided in the README.md. -- [ ] I have reviewed the troubleshooting.md document. +- [ ] I am on the latest version of both Open WebUI and Ollama. - [ ] I have included the browser console logs. - [ ] I have included the Docker container logs. diff --git a/.github/workflows/build-release.yml b/.github/workflows/build-release.yml index fa3fa296d2..259f0c5ffa 100644 --- a/.github/workflows/build-release.yml +++ b/.github/workflows/build-release.yml @@ -19,24 +19,34 @@ jobs: echo "No changes to package.json" exit 1 } - + - name: Get version number from package.json id: get_version run: | VERSION=$(jq -r '.version' package.json) echo "::set-output name=version::$VERSION" + - name: Extract latest CHANGELOG entry + id: changelog + run: | + CHANGELOG_CONTENT=$(awk 'BEGIN {print_section=0;} /^## \[/ {if (print_section == 0) {print_section=1;} else {exit;}} print_section {print;}' CHANGELOG.md) + CHANGELOG_ESCAPED=$(echo "$CHANGELOG_CONTENT" | sed ':a;N;$!ba;s/\n/%0A/g') + echo "Extracted latest release notes from CHANGELOG.md:" + echo -e "$CHANGELOG_CONTENT" + echo "::set-output name=content::$CHANGELOG_ESCAPED" + - name: Create GitHub release uses: actions/github-script@v5 with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | + const changelog = `${{ steps.changelog.outputs.content }}`; const release = await github.rest.repos.createRelease({ owner: context.repo.owner, repo: context.repo.repo, tag_name: `v${{ steps.get_version.outputs.version }}`, name: `v${{ steps.get_version.outputs.version }}`, - body: 'Automatically created new release', + body: changelog, }) console.log(`Created release ${release.data.html_url}`) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7b150de25d..d57ba400c0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,106 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.1.111] - 2024-03-10 + +### Added + +- ๐ก๏ธ **Model Whitelisting**: Admins now have the ability to whitelist models for users with the 'user' role. +- ๐ **Update All Models**: Added a convenient button to update all models at once. +- ๐ **Toggle PDF OCR**: Users can now toggle PDF OCR option for improved parsing performance. +- ๐จ **DALL-E Integration**: Introduced DALL-E integration for image generation alongside automatic1111. +- ๐ ๏ธ **RAG API Refactoring**: Refactored RAG logic and exposed its API, with additional documentation to follow. + +### Fixed + +- ๐ **Max Token Settings**: Added max token settings for anthropic/claude-3-sonnet-20240229 (Issue #1094). +- ๐ง **Misalignment Issue**: Corrected misalignment of Edit and Delete Icons when Chat Title is Empty (Issue #1104). +- ๐ **Context Loss Fix**: Resolved RAG losing context on model response regeneration with Groq models via API key (Issue #1105). +- ๐ **File Handling Bug**: Addressed File Not Found Notification when Dropping a Conversation Element (Issue #1098). +- ๐ฑ๏ธ **Dragged File Styling**: Fixed dragged file layover styling issue. + +## [0.1.110] - 2024-03-06 + +### Added + +- **๐ Multiple OpenAI Servers Support**: Enjoy seamless integration with multiple OpenAI-compatible APIs, now supported natively. + +### Fixed + +- **๐ OCR Issue**: Resolved PDF parsing issue caused by OCR malfunction. +- **๐ซ RAG Issue**: Fixed the RAG functionality, ensuring it operates smoothly. +- **๐ "Add Docs" Model Button**: Addressed the non-functional behavior of the "Add Docs" model button. + +## [0.1.109] - 2024-03-06 + +### Added + +- **๐ Multiple Ollama Servers Support**: Enjoy enhanced scalability and performance with support for multiple Ollama servers in a single WebUI. Load balancing features are now available, providing improved efficiency (#788, #278). +- **๐ง Support for Claude 3 and Gemini**: Responding to user requests, we've expanded our toolset to include Claude 3 and Gemini, offering a wider range of functionalities within our platform (#1064). +- **๐ OCR Functionality for PDF Loader**: We've augmented our PDF loader with Optical Character Recognition (OCR) capabilities. Now, extract text from scanned documents and images within PDFs, broadening the scope of content processing (#1050). + +### Fixed + +- **๐ ๏ธ RAG Collection**: Implemented a dynamic mechanism to recreate RAG collections, ensuring users have up-to-date and accurate data (#1031). +- **๐ User Agent Headers**: Fixed issue of RAG web requests being sent with empty user_agent headers, reducing rejections from certain websites. Realistic headers are now utilized for these requests (#1024). +- **โน๏ธ Playground Cancel Functionality**: Introducing a new "Cancel" option for stopping Ollama generation in the Playground, enhancing user control and usability (#1006). +- **๐ค Typographical Error in 'ASSISTANT' Field**: Corrected a typographical error in the 'ASSISTANT' field within the GGUF model upload template for accuracy and consistency (#1061). + +### Changed + +- **๐ Refactored Message Deletion Logic**: Streamlined message deletion process for improved efficiency and user experience, simplifying interactions within the platform (#1004). +- **โ ๏ธ Deprecation of `OLLAMA_API_BASE_URL`**: Deprecated `OLLAMA_API_BASE_URL` environment variable; recommend using `OLLAMA_BASE_URL` instead. Refer to our documentation for further details. + +## [0.1.108] - 2024-03-02 + +### Added + +- **๐ฎ Playground Feature (Beta)**: Explore the full potential of the raw API through an intuitive UI with our new playground feature, accessible to admins. Simply click on the bottom name area of the sidebar to access it. The playground feature offers two modes text completion (notebook) and chat completion. As it's in beta, please report any issues you encounter. +- **๐ ๏ธ Direct Database Download for Admins**: Admins can now download the database directly from the WebUI via the admin settings. +- **๐จ Additional RAG Settings**: Customize your RAG process with the ability to edit the TOP K value. Navigate to Documents > Settings > General to make changes. +- **๐ฅ๏ธ UI Improvements**: Tooltips now available in the input area and sidebar handle. More tooltips will be added across other parts of the UI. + +### Fixed + +- Resolved input autofocus issue on mobile when the sidebar is open, making it easier to use. +- Corrected numbered list display issue in Safari (#963). +- Restricted user ability to delete chats without proper permissions (#993). + +### Changed + +- **Simplified Ollama Settings**: Ollama settings now don't require the `/api` suffix. You can now utilize the Ollama base URL directly, e.g., `http://localhost:11434`. Also, an `OLLAMA_BASE_URL` environment variable has been added. +- **Database Renaming**: Starting from this release, `ollama.db` will be automatically renamed to `webui.db`. + +## [0.1.107] - 2024-03-01 + +### Added + +- **๐ Makefile and LLM Update Script**: Included Makefile and a script for LLM updates in the repository. + +### Fixed + +- Corrected issue where links in the settings modal didn't appear clickable (#960). +- Fixed problem with web UI port not taking effect due to incorrect environment variable name in run-compose.sh (#996). +- Enhanced user experience by displaying chat in browser title and enabling automatic scrolling to the bottom (#992). + +### Changed + +- Upgraded toast library from `svelte-french-toast` to `svelte-sonner` for a more polished UI. +- Enhanced accessibility with the addition of dark mode on the authentication page. + +## [0.1.106] - 2024-02-27 + +### Added + +- **๐ฏ Auto-focus Feature**: The input area now automatically focuses when initiating or opening a chat conversation. + +### Fixed + +- Corrected typo from "HuggingFace" to "Hugging Face" (Issue #924). +- Resolved bug causing errors in chat completion API calls to OpenAI due to missing "num_ctx" parameter (Issue #927). +- Fixed issues preventing text editing, selection, and cursor retention in the input field (Issue #940). +- Fixed a bug where defining an OpenAI-compatible API server using 'OPENAI_API_BASE_URL' containing 'openai' string resulted in hiding models not containing 'gpt' string from the model menu. (Issue #930) + ## [0.1.105] - 2024-02-25 ### Added diff --git a/Dockerfile b/Dockerfile index 7ea416de38..b9f2961011 100644 --- a/Dockerfile +++ b/Dockerfile @@ -20,7 +20,7 @@ FROM python:3.11-slim-bookworm as base ENV ENV=prod ENV PORT "" -ENV OLLAMA_API_BASE_URL "/ollama/api" +ENV OLLAMA_BASE_URL "/ollama" ENV OPENAI_API_BASE_URL "" ENV OPENAI_API_KEY "" @@ -41,7 +41,7 @@ ENV WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models" # for better persormance and multilangauge support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB) # IMPORTANT: If you change the default model (all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them. ENV RAG_EMBEDDING_MODEL="all-MiniLM-L6-v2" -# device type for whisper tts and ebbeding models - "cpu" (default), "cuda" (nvidia gpu and CUDA required) or "mps" (apple silicon) - choosing this right can lead to better performance +# device type for whisper tts and embbeding models - "cpu" (default), "cuda" (nvidia gpu and CUDA required) or "mps" (apple silicon) - choosing this right can lead to better performance ENV RAG_EMBEDDING_MODEL_DEVICE_TYPE="cpu" ENV RAG_EMBEDDING_MODEL_DIR="/app/backend/data/cache/embedding/models" ENV SENTENCE_TRANSFORMERS_HOME $RAG_EMBEDDING_MODEL_DIR @@ -53,6 +53,8 @@ WORKDIR /app/backend # install python dependencies COPY ./backend/requirements.txt ./requirements.txt +RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y + RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir RUN pip3 install -r requirements.txt --no-cache-dir @@ -79,4 +81,4 @@ COPY --from=build /app/package.json /app/package.json # copy backend files COPY ./backend . -CMD [ "bash", "start.sh"] \ No newline at end of file +CMD [ "bash", "start.sh"] diff --git a/Makefile b/Makefile new file mode 100644 index 0000000000..cbcc41d92e --- /dev/null +++ b/Makefile @@ -0,0 +1,25 @@ +install: + @docker-compose up -d + +remove: + @chmod +x confirm_remove.sh + @./confirm_remove.sh + + +start: + @docker-compose start + +stop: + @docker-compose stop + +update: + # Calls the LLM update script + chmod +x update_ollama_models.sh + @./update_ollama_models.sh + @git pull + @docker-compose down + # Make sure the ollama-webui container is stopped before rebuilding + @docker stop open-webui || true + @docker-compose up --build -d + @docker-compose start + diff --git a/README.md b/README.md index 7c40239c7d..46777dbf25 100644 --- a/README.md +++ b/README.md @@ -53,8 +53,6 @@ User-friendly WebUI for LLMs, supported LLM runners include Ollama and OpenAI-co - ๐ฌ **Collaborative Chat**: Harness the collective intelligence of multiple models by seamlessly orchestrating group conversations. Use the `@` command to specify the model, enabling dynamic and diverse dialogues within your chat interface. Immerse yourself in the collective intelligence woven into your chat environment. -- ๐ค **OpenAI API Integration**: Effortlessly integrate OpenAI-compatible API for versatile conversations alongside Ollama models. Customize the API Base URL to link with **LMStudio, Mistral, OpenRouter, and more**. - - ๐ **Regeneration History Access**: Easily revisit and explore your entire regeneration history. - ๐ **Chat History**: Effortlessly access and manage your conversation history. @@ -65,8 +63,18 @@ User-friendly WebUI for LLMs, supported LLM runners include Ollama and OpenAI-co - โ๏ธ **Fine-Tuned Control with Advanced Parameters**: Gain a deeper level of control by adjusting parameters such as temperature and defining your system prompts to tailor the conversation to your specific preferences and needs. +- ๐จ๐ค **Image Generation Integration**: Seamlessly incorporate image generation capabilities using AUTOMATIC1111 API (local) and DALL-E, enriching your chat experience with dynamic visual content. + +- ๐ค **OpenAI API Integration**: Effortlessly integrate OpenAI-compatible API for versatile conversations alongside Ollama models. Customize the API Base URL to link with **LMStudio, Mistral, OpenRouter, and more**. + +- โจ **Multiple OpenAI-Compatible API Support**: Seamlessly integrate and customize various OpenAI-compatible APIs, enhancing the versatility of your chat interactions. + - ๐ **External Ollama Server Connection**: Seamlessly link to an external Ollama server hosted on a different address by configuring the environment variable. +- ๐ **Multiple Ollama Instance Load Balancing**: Effortlessly distribute chat requests across multiple Ollama instances for enhanced performance and reliability. + +- ๐ฅ **Multi-User Management**: Easily oversee and administer users via our intuitive admin panel, streamlining user management processes. + - ๐ **Role-Based Access Control (RBAC)**: Ensure secure access with restricted permissions; only authorized individuals can access your Ollama, and exclusive model creation/pulling rights are reserved for administrators. - ๐ **Backend Reverse Proxy Support**: Bolster security through direct communication between Open WebUI backend and Ollama. This key feature eliminates the need to expose Ollama over LAN. Requests made to the '/ollama/api' route from the web UI are seamlessly redirected to Ollama from the backend, enhancing overall system security. @@ -95,10 +103,10 @@ Don't forget to explore our sibling project, [Open WebUI Community](https://open - **If Ollama is on a Different Server**, use this command: -- To connect to Ollama on another server, change the `OLLAMA_API_BASE_URL` to the server's URL: +- To connect to Ollama on another server, change the `OLLAMA_BASE_URL` to the server's URL: ```bash - docker run -d -p 3000:8080 -e OLLAMA_API_BASE_URL=https://example.com/api -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:main + docker run -d -p 3000:8080 -e OLLAMA_BASE_URL=https://example.com -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:main ``` - After installation, you can access Open WebUI at [http://localhost:3000](http://localhost:3000). Enjoy! ๐ @@ -110,7 +118,7 @@ If you're experiencing connection issues, itโs often due to the WebUI docker c **Example Docker Command**: ```bash -docker run -d --network=host -v open-webui:/app/backend/data -e OLLAMA_API_BASE_URL=http://127.0.0.1:11434/api --name open-webui --restart always ghcr.io/open-webui/open-webui:main +docker run -d --network=host -v open-webui:/app/backend/data -e OLLAMA_BASE_URL=http://127.0.0.1:11434 --name open-webui --restart always ghcr.io/open-webui/open-webui:main ``` ### Other Installation Methods @@ -160,6 +168,16 @@ This project is licensed under the [MIT License](LICENSE) - see the [LICENSE](LI If you have any questions, suggestions, or need assistance, please open an issue or join our [Open WebUI Discord community](https://discord.gg/5rJgQTnV4s) to connect with us! ๐ค +## Star History + + + + + --- Created by [Timothy J. Baek](https://github.com/tjbck) - Let's make Open Web UI even more amazing together! ๐ช diff --git a/TROUBLESHOOTING.md b/TROUBLESHOOTING.md index d3163501a3..8e8f89da02 100644 --- a/TROUBLESHOOTING.md +++ b/TROUBLESHOOTING.md @@ -4,7 +4,7 @@ The Open WebUI system is designed to streamline interactions between the client (your browser) and the Ollama API. At the heart of this design is a backend reverse proxy, enhancing security and resolving CORS issues. -- **How it Works**: The Open WebUI is designed to interact with the Ollama API through a specific route. When a request is made from the WebUI to Ollama, it is not directly sent to the Ollama API. Initially, the request is sent to the Open WebUI backend via `/ollama/api` route. From there, the backend is responsible for forwarding the request to the Ollama API. This forwarding is accomplished by using the route specified in the `OLLAMA_API_BASE_URL` environment variable. Therefore, a request made to `/ollama/api` in the WebUI is effectively the same as making a request to `OLLAMA_API_BASE_URL` in the backend. For instance, a request to `/ollama/api/tags` in the WebUI is equivalent to `OLLAMA_API_BASE_URL/tags` in the backend. +- **How it Works**: The Open WebUI is designed to interact with the Ollama API through a specific route. When a request is made from the WebUI to Ollama, it is not directly sent to the Ollama API. Initially, the request is sent to the Open WebUI backend via `/ollama` route. From there, the backend is responsible for forwarding the request to the Ollama API. This forwarding is accomplished by using the route specified in the `OLLAMA_BASE_URL` environment variable. Therefore, a request made to `/ollama` in the WebUI is effectively the same as making a request to `OLLAMA_BASE_URL` in the backend. For instance, a request to `/ollama/api/tags` in the WebUI is equivalent to `OLLAMA_BASE_URL/api/tags` in the backend. - **Security Benefits**: This design prevents direct exposure of the Ollama API to the frontend, safeguarding against potential CORS (Cross-Origin Resource Sharing) issues and unauthorized access. Requiring authentication to access the Ollama API further enhances this security layer. @@ -15,7 +15,7 @@ If you're experiencing connection issues, itโs often due to the WebUI docker c **Example Docker Command**: ```bash -docker run -d --network=host -v open-webui:/app/backend/data -e OLLAMA_API_BASE_URL=http://127.0.0.1:11434/api --name open-webui --restart always ghcr.io/open-webui/open-webui:main +docker run -d --network=host -v open-webui:/app/backend/data -e OLLAMA_BASE_URL=http://127.0.0.1:11434 --name open-webui --restart always ghcr.io/open-webui/open-webui:main ``` ### General Connection Errors @@ -25,8 +25,8 @@ docker run -d --network=host -v open-webui:/app/backend/data -e OLLAMA_API_BASE_ **Troubleshooting Steps**: 1. **Verify Ollama URL Format**: - - When running the Web UI container, ensure the `OLLAMA_API_BASE_URL` is correctly set, including the `/api` suffix. (e.g., `http://192.168.1.1:11434/api` for different host setups). + - When running the Web UI container, ensure the `OLLAMA_BASE_URL` is correctly set. (e.g., `http://192.168.1.1:11434` for different host setups). - In the Open WebUI, navigate to "Settings" > "General". - - Confirm that the Ollama Server URL is correctly set to `[OLLAMA URL]/api` (e.g., `http://localhost:11434/api`), including the `/api` suffix. + - Confirm that the Ollama Server URL is correctly set to `[OLLAMA URL]` (e.g., `http://localhost:11434`). By following these enhanced troubleshooting steps, connection issues should be effectively resolved. For further assistance or queries, feel free to reach out to us on our community Discord. diff --git a/backend/apps/images/main.py b/backend/apps/images/main.py index dfa1f187a8..31bfc0f5d2 100644 --- a/backend/apps/images/main.py +++ b/backend/apps/images/main.py @@ -21,7 +21,16 @@ from utils.misc import calculate_sha256 from typing import Optional from pydantic import BaseModel -from config import AUTOMATIC1111_BASE_URL +from pathlib import Path +import uuid +import base64 +import json + +from config import CACHE_DIR, AUTOMATIC1111_BASE_URL + + +IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/") +IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True) app = FastAPI() app.add_middleware( @@ -32,25 +41,34 @@ allow_headers=["*"], ) +app.state.ENGINE = "" +app.state.ENABLED = False + +app.state.OPENAI_API_KEY = "" +app.state.MODEL = "" + + app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL -app.state.ENABLED = app.state.AUTOMATIC1111_BASE_URL != "" + app.state.IMAGE_SIZE = "512x512" app.state.IMAGE_STEPS = 50 -@app.get("/enabled", response_model=bool) -async def get_enable_status(request: Request, user=Depends(get_admin_user)): - return app.state.ENABLED +@app.get("/config") +async def get_config(request: Request, user=Depends(get_admin_user)): + return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED} -@app.get("/enabled/toggle", response_model=bool) -async def toggle_enabled(request: Request, user=Depends(get_admin_user)): - try: - r = requests.head(app.state.AUTOMATIC1111_BASE_URL) - app.state.ENABLED = not app.state.ENABLED - return app.state.ENABLED - except Exception as e: - raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) +class ConfigUpdateForm(BaseModel): + engine: str + enabled: bool + + +@app.post("/config/update") +async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)): + app.state.ENGINE = form_data.engine + app.state.ENABLED = form_data.enabled + return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED} class UrlUpdateForm(BaseModel): @@ -58,17 +76,24 @@ class UrlUpdateForm(BaseModel): @app.get("/url") -async def get_openai_url(user=Depends(get_admin_user)): +async def get_automatic1111_url(user=Depends(get_admin_user)): return {"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL} @app.post("/url/update") -async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)): +async def update_automatic1111_url( + form_data: UrlUpdateForm, user=Depends(get_admin_user) +): if form_data.url == "": app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL else: - app.state.AUTOMATIC1111_BASE_URL = form_data.url.strip("/") + url = form_data.url.strip("/") + try: + r = requests.head(url) + app.state.AUTOMATIC1111_BASE_URL = url + except Exception as e: + raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) return { "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL, @@ -76,6 +101,30 @@ async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_use } +class OpenAIKeyUpdateForm(BaseModel): + key: str + + +@app.get("/key") +async def get_openai_key(user=Depends(get_admin_user)): + return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY} + + +@app.post("/key/update") +async def update_openai_key( + form_data: OpenAIKeyUpdateForm, user=Depends(get_admin_user) +): + + if form_data.key == "": + raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) + + app.state.OPENAI_API_KEY = form_data.key + return { + "OPENAI_API_KEY": app.state.OPENAI_API_KEY, + "status": True, + } + + class ImageSizeUpdateForm(BaseModel): size: str @@ -132,9 +181,22 @@ async def update_image_size( @app.get("/models") def get_models(user=Depends(get_current_user)): try: - r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models") - models = r.json() - return models + if app.state.ENGINE == "openai": + return [ + {"id": "dall-e-2", "name": "DALLยทE 2"}, + {"id": "dall-e-3", "name": "DALLยทE 3"}, + ] + else: + r = requests.get( + url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models" + ) + models = r.json() + return list( + map( + lambda model: {"id": model["title"], "name": model["model_name"]}, + models, + ) + ) except Exception as e: app.state.ENABLED = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) @@ -143,10 +205,12 @@ def get_models(user=Depends(get_current_user)): @app.get("/models/default") async def get_default_model(user=Depends(get_admin_user)): try: - r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") - options = r.json() - - return {"model": options["sd_model_checkpoint"]} + if app.state.ENGINE == "openai": + return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"} + else: + r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") + options = r.json() + return {"model": options["sd_model_checkpoint"]} except Exception as e: app.state.ENABLED = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) @@ -157,16 +221,21 @@ class UpdateModelForm(BaseModel): def set_model_handler(model: str): - r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") - options = r.json() - if model != options["sd_model_checkpoint"]: - options["sd_model_checkpoint"] = model - r = requests.post( - url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options - ) + if app.state.ENGINE == "openai": + app.state.MODEL = model + return app.state.MODEL + else: + r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") + options = r.json() + + if model != options["sd_model_checkpoint"]: + options["sd_model_checkpoint"] = model + r = requests.post( + url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options + ) - return options + return options @app.post("/models/default/update") @@ -181,45 +250,113 @@ class GenerateImageForm(BaseModel): model: Optional[str] = None prompt: str n: int = 1 - size: str = "512x512" + size: Optional[str] = None negative_prompt: Optional[str] = None +def save_b64_image(b64_str): + image_id = str(uuid.uuid4()) + file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png") + + try: + # Split the base64 string to get the actual image data + img_data = base64.b64decode(b64_str) + + # Write the image data to a file + with open(file_path, "wb") as f: + f.write(img_data) + + return image_id + except Exception as e: + print(f"Error saving image: {e}") + return None + + @app.post("/generations") def generate_image( form_data: GenerateImageForm, user=Depends(get_current_user), ): - print(form_data) - + r = None try: - if form_data.model: - set_model_handler(form_data.model) + if app.state.ENGINE == "openai": - width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x"))) + headers = {} + headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}" + headers["Content-Type"] = "application/json" - data = { - "prompt": form_data.prompt, - "batch_size": form_data.n, - "width": width, - "height": height, - } + data = { + "model": app.state.MODEL if app.state.MODEL != "" else "dall-e-2", + "prompt": form_data.prompt, + "n": form_data.n, + "size": form_data.size if form_data.size else app.state.IMAGE_SIZE, + "response_format": "b64_json", + } + r = requests.post( + url=f"https://api.openai.com/v1/images/generations", + json=data, + headers=headers, + ) - if app.state.IMAGE_STEPS != None: - data["steps"] = app.state.IMAGE_STEPS + r.raise_for_status() - if form_data.negative_prompt != None: - data["negative_prompt"] = form_data.negative_prompt + res = r.json() - print(data) + images = [] - r = requests.post( - url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img", - json=data, - ) + for image in res["data"]: + image_id = save_b64_image(image["b64_json"]) + images.append({"url": f"/cache/image/generations/{image_id}.png"}) + file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json") + + with open(file_body_path, "w") as f: + json.dump(data, f) + + return images + + else: + if form_data.model: + set_model_handler(form_data.model) + + width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x"))) + + data = { + "prompt": form_data.prompt, + "batch_size": form_data.n, + "width": width, + "height": height, + } + + if app.state.IMAGE_STEPS != None: + data["steps"] = app.state.IMAGE_STEPS + + if form_data.negative_prompt != None: + data["negative_prompt"] = form_data.negative_prompt + + r = requests.post( + url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img", + json=data, + ) + + res = r.json() + + print(res) + + images = [] + + for image in res["images"]: + image_id = save_b64_image(image) + images.append({"url": f"/cache/image/generations/{image_id}.png"}) + file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json") + + with open(file_body_path, "w") as f: + json.dump({**data, "info": res["info"]}, f) + + return images - return r.json() except Exception as e: print(e) + if r: + print(r.json()) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) diff --git a/backend/apps/litellm/main.py b/backend/apps/litellm/main.py new file mode 100644 index 0000000000..21b9e58a73 --- /dev/null +++ b/backend/apps/litellm/main.py @@ -0,0 +1,41 @@ +from litellm.proxy.proxy_server import ProxyConfig, initialize +from litellm.proxy.proxy_server import app + +from fastapi import FastAPI, Request, Depends, status +from fastapi.responses import JSONResponse +from utils.utils import get_http_authorization_cred, get_current_user +from config import ENV + +proxy_config = ProxyConfig() + + +async def config(): + router, model_list, general_settings = await proxy_config.load_config( + router=None, config_file_path="./data/litellm/config.yaml" + ) + + await initialize(config="./data/litellm/config.yaml", telemetry=False) + + +async def startup(): + await config() + + +@app.on_event("startup") +async def on_startup(): + await startup() + + +@app.middleware("http") +async def auth_middleware(request: Request, call_next): + auth_header = request.headers.get("Authorization", "") + + if ENV != "dev": + try: + user = get_current_user(get_http_authorization_cred(auth_header)) + print(user) + except Exception as e: + return JSONResponse(status_code=400, content={"detail": str(e)}) + + response = await call_next(request) + return response diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index bc797f080e..5ecbaa2971 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -3,15 +3,22 @@ from fastapi.responses import StreamingResponse from fastapi.concurrency import run_in_threadpool +from pydantic import BaseModel, ConfigDict + +import random import requests import json import uuid -from pydantic import BaseModel +import aiohttp +import asyncio from apps.web.models.users import Users from constants import ERROR_MESSAGES from utils.utils import decode_token, get_current_user, get_admin_user -from config import OLLAMA_API_BASE_URL, WEBUI_AUTH +from config import OLLAMA_BASE_URLS, MODEL_FILTER_ENABLED, MODEL_FILTER_LIST + +from typing import Optional, List, Union + app = FastAPI() app.add_middleware( @@ -22,27 +29,48 @@ allow_headers=["*"], ) -app.state.OLLAMA_API_BASE_URL = OLLAMA_API_BASE_URL -# TARGET_SERVER_URL = OLLAMA_API_BASE_URL +app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED +app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST + +app.state.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS +app.state.MODELS = {} REQUEST_POOL = [] -@app.get("/url") -async def get_ollama_api_url(user=Depends(get_admin_user)): - return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL} +# TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances. +# Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin, +# least connections, or least response time for better resource utilization and performance optimization. + + +@app.middleware("http") +async def check_url(request: Request, call_next): + if len(app.state.MODELS) == 0: + await get_all_models() + else: + pass + + response = await call_next(request) + return response + + +@app.get("/urls") +async def get_ollama_api_urls(user=Depends(get_admin_user)): + return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS} class UrlUpdateForm(BaseModel): - url: str + urls: List[str] -@app.post("/url/update") +@app.post("/urls/update") async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)): - app.state.OLLAMA_API_BASE_URL = form_data.url - return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL} + app.state.OLLAMA_BASE_URLS = form_data.urls + + print(app.state.OLLAMA_BASE_URLS) + return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS} @app.get("/cancel/{request_id}") @@ -55,9 +83,817 @@ async def cancel_ollama_request(request_id: str, user=Depends(get_current_user)) raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) +async def fetch_url(url): + try: + async with aiohttp.ClientSession() as session: + async with session.get(url) as response: + return await response.json() + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + return None + + +def merge_models_lists(model_lists): + merged_models = {} + + for idx, model_list in enumerate(model_lists): + for model in model_list: + digest = model["digest"] + if digest not in merged_models: + model["urls"] = [idx] + merged_models[digest] = model + else: + merged_models[digest]["urls"].append(idx) + + return list(merged_models.values()) + + +# user=Depends(get_current_user) + + +async def get_all_models(): + print("get_all_models") + tasks = [fetch_url(f"{url}/api/tags") for url in app.state.OLLAMA_BASE_URLS] + responses = await asyncio.gather(*tasks) + responses = list(filter(lambda x: x is not None, responses)) + + models = { + "models": merge_models_lists( + map(lambda response: response["models"], responses) + ) + } + app.state.MODELS = {model["model"]: model for model in models["models"]} + + return models + + +@app.get("/api/tags") +@app.get("/api/tags/{url_idx}") +async def get_ollama_tags( + url_idx: Optional[int] = None, user=Depends(get_current_user) +): + if url_idx == None: + models = await get_all_models() + + if app.state.MODEL_FILTER_ENABLED: + if user.role == "user": + models["models"] = list( + filter( + lambda model: model["name"] in app.state.MODEL_FILTER_LIST, + models["models"], + ) + ) + return models + return models + else: + url = app.state.OLLAMA_BASE_URLS[url_idx] + try: + r = requests.request(method="GET", url=f"{url}/api/tags") + r.raise_for_status() + + return r.json() + except Exception as e: + print(e) + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except: + error_detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) + + +@app.get("/api/version") +@app.get("/api/version/{url_idx}") +async def get_ollama_versions(url_idx: Optional[int] = None): + + if url_idx == None: + + # returns lowest version + tasks = [fetch_url(f"{url}/api/version") for url in app.state.OLLAMA_BASE_URLS] + responses = await asyncio.gather(*tasks) + responses = list(filter(lambda x: x is not None, responses)) + + lowest_version = min( + responses, key=lambda x: tuple(map(int, x["version"].split("."))) + ) + + return {"version": lowest_version["version"]} + else: + url = app.state.OLLAMA_BASE_URLS[url_idx] + try: + r = requests.request(method="GET", url=f"{url}/api/version") + r.raise_for_status() + + return r.json() + except Exception as e: + print(e) + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except: + error_detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) + + +class ModelNameForm(BaseModel): + name: str + + +@app.post("/api/pull") +@app.post("/api/pull/{url_idx}") +async def pull_model( + form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user) +): + url = app.state.OLLAMA_BASE_URLS[url_idx] + print(url) + + r = None + + def get_request(): + nonlocal url + nonlocal r + try: + + def stream_content(): + for chunk in r.iter_content(chunk_size=8192): + yield chunk + + r = requests.request( + method="POST", + url=f"{url}/api/pull", + data=form_data.model_dump_json(exclude_none=True).encode(), + stream=True, + ) + + r.raise_for_status() + + return StreamingResponse( + stream_content(), + status_code=r.status_code, + headers=dict(r.headers), + ) + except Exception as e: + raise e + + try: + return await run_in_threadpool(get_request) + except Exception as e: + print(e) + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except: + error_detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) + + +class PushModelForm(BaseModel): + name: str + insecure: Optional[bool] = None + stream: Optional[bool] = None + + +@app.delete("/api/push") +@app.delete("/api/push/{url_idx}") +async def push_model( + form_data: PushModelForm, + url_idx: Optional[int] = None, + user=Depends(get_admin_user), +): + if url_idx == None: + if form_data.name in app.state.MODELS: + url_idx = app.state.MODELS[form_data.name]["urls"][0] + else: + raise HTTPException( + status_code=400, + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), + ) + + url = app.state.OLLAMA_BASE_URLS[url_idx] + print(url) + + r = None + + def get_request(): + nonlocal url + nonlocal r + try: + + def stream_content(): + for chunk in r.iter_content(chunk_size=8192): + yield chunk + + r = requests.request( + method="POST", + url=f"{url}/api/push", + data=form_data.model_dump_json(exclude_none=True).encode(), + ) + + r.raise_for_status() + + return StreamingResponse( + stream_content(), + status_code=r.status_code, + headers=dict(r.headers), + ) + except Exception as e: + raise e + + try: + return await run_in_threadpool(get_request) + except Exception as e: + print(e) + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except: + error_detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) + + +class CreateModelForm(BaseModel): + name: str + modelfile: Optional[str] = None + stream: Optional[bool] = None + path: Optional[str] = None + + +@app.post("/api/create") +@app.post("/api/create/{url_idx}") +async def create_model( + form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user) +): + print(form_data) + url = app.state.OLLAMA_BASE_URLS[url_idx] + print(url) + + r = None + + def get_request(): + nonlocal url + nonlocal r + try: + + def stream_content(): + for chunk in r.iter_content(chunk_size=8192): + yield chunk + + r = requests.request( + method="POST", + url=f"{url}/api/create", + data=form_data.model_dump_json(exclude_none=True).encode(), + stream=True, + ) + + r.raise_for_status() + + print(r) + + return StreamingResponse( + stream_content(), + status_code=r.status_code, + headers=dict(r.headers), + ) + except Exception as e: + raise e + + try: + return await run_in_threadpool(get_request) + except Exception as e: + print(e) + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except: + error_detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) + + +class CopyModelForm(BaseModel): + source: str + destination: str + + +@app.post("/api/copy") +@app.post("/api/copy/{url_idx}") +async def copy_model( + form_data: CopyModelForm, + url_idx: Optional[int] = None, + user=Depends(get_admin_user), +): + if url_idx == None: + if form_data.source in app.state.MODELS: + url_idx = app.state.MODELS[form_data.source]["urls"][0] + else: + raise HTTPException( + status_code=400, + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source), + ) + + url = app.state.OLLAMA_BASE_URLS[url_idx] + print(url) + + try: + r = requests.request( + method="POST", + url=f"{url}/api/copy", + data=form_data.model_dump_json(exclude_none=True).encode(), + ) + r.raise_for_status() + + print(r.text) + + return True + except Exception as e: + print(e) + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except: + error_detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) + + +@app.delete("/api/delete") +@app.delete("/api/delete/{url_idx}") +async def delete_model( + form_data: ModelNameForm, + url_idx: Optional[int] = None, + user=Depends(get_admin_user), +): + if url_idx == None: + if form_data.name in app.state.MODELS: + url_idx = app.state.MODELS[form_data.name]["urls"][0] + else: + raise HTTPException( + status_code=400, + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), + ) + + url = app.state.OLLAMA_BASE_URLS[url_idx] + print(url) + + try: + r = requests.request( + method="DELETE", + url=f"{url}/api/delete", + data=form_data.model_dump_json(exclude_none=True).encode(), + ) + r.raise_for_status() + + print(r.text) + + return True + except Exception as e: + print(e) + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except: + error_detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) + + +@app.post("/api/show") +async def show_model_info(form_data: ModelNameForm, user=Depends(get_current_user)): + if form_data.name not in app.state.MODELS: + raise HTTPException( + status_code=400, + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), + ) + + url_idx = random.choice(app.state.MODELS[form_data.name]["urls"]) + url = app.state.OLLAMA_BASE_URLS[url_idx] + print(url) + + try: + r = requests.request( + method="POST", + url=f"{url}/api/show", + data=form_data.model_dump_json(exclude_none=True).encode(), + ) + r.raise_for_status() + + return r.json() + except Exception as e: + print(e) + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except: + error_detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) + + +class GenerateEmbeddingsForm(BaseModel): + model: str + prompt: str + options: Optional[dict] = None + keep_alive: Optional[Union[int, str]] = None + + +@app.post("/api/embeddings") +@app.post("/api/embeddings/{url_idx}") +async def generate_embeddings( + form_data: GenerateEmbeddingsForm, + url_idx: Optional[int] = None, + user=Depends(get_current_user), +): + if url_idx == None: + if form_data.model in app.state.MODELS: + url_idx = random.choice(app.state.MODELS[form_data.model]["urls"]) + else: + raise HTTPException( + status_code=400, + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), + ) + + url = app.state.OLLAMA_BASE_URLS[url_idx] + print(url) + + try: + r = requests.request( + method="POST", + url=f"{url}/api/embeddings", + data=form_data.model_dump_json(exclude_none=True).encode(), + ) + r.raise_for_status() + + return r.json() + except Exception as e: + print(e) + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except: + error_detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) + + +class GenerateCompletionForm(BaseModel): + model: str + prompt: str + images: Optional[List[str]] = None + format: Optional[str] = None + options: Optional[dict] = None + system: Optional[str] = None + template: Optional[str] = None + context: Optional[str] = None + stream: Optional[bool] = True + raw: Optional[bool] = None + keep_alive: Optional[Union[int, str]] = None + + +@app.post("/api/generate") +@app.post("/api/generate/{url_idx}") +async def generate_completion( + form_data: GenerateCompletionForm, + url_idx: Optional[int] = None, + user=Depends(get_current_user), +): + + if url_idx == None: + if form_data.model in app.state.MODELS: + url_idx = random.choice(app.state.MODELS[form_data.model]["urls"]) + else: + raise HTTPException( + status_code=400, + detail="error_detail", + ) + + url = app.state.OLLAMA_BASE_URLS[url_idx] + print(url) + + r = None + + def get_request(): + nonlocal form_data + nonlocal r + + request_id = str(uuid.uuid4()) + try: + REQUEST_POOL.append(request_id) + + def stream_content(): + try: + if form_data.stream: + yield json.dumps({"id": request_id, "done": False}) + "\n" + + for chunk in r.iter_content(chunk_size=8192): + if request_id in REQUEST_POOL: + yield chunk + else: + print("User: canceled request") + break + finally: + if hasattr(r, "close"): + r.close() + if request_id in REQUEST_POOL: + REQUEST_POOL.remove(request_id) + + r = requests.request( + method="POST", + url=f"{url}/api/generate", + data=form_data.model_dump_json(exclude_none=True).encode(), + stream=True, + ) + + r.raise_for_status() + + return StreamingResponse( + stream_content(), + status_code=r.status_code, + headers=dict(r.headers), + ) + except Exception as e: + raise e + + try: + return await run_in_threadpool(get_request) + except Exception as e: + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except: + error_detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) + + +class ChatMessage(BaseModel): + role: str + content: str + images: Optional[List[str]] = None + + +class GenerateChatCompletionForm(BaseModel): + model: str + messages: List[ChatMessage] + format: Optional[str] = None + options: Optional[dict] = None + template: Optional[str] = None + stream: Optional[bool] = True + keep_alive: Optional[Union[int, str]] = None + + +@app.post("/api/chat") +@app.post("/api/chat/{url_idx}") +async def generate_chat_completion( + form_data: GenerateChatCompletionForm, + url_idx: Optional[int] = None, + user=Depends(get_current_user), +): + + if url_idx == None: + if form_data.model in app.state.MODELS: + url_idx = random.choice(app.state.MODELS[form_data.model]["urls"]) + else: + raise HTTPException( + status_code=400, + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), + ) + + url = app.state.OLLAMA_BASE_URLS[url_idx] + print(url) + + r = None + + print(form_data.model_dump_json(exclude_none=True).encode()) + + def get_request(): + nonlocal form_data + nonlocal r + + request_id = str(uuid.uuid4()) + try: + REQUEST_POOL.append(request_id) + + def stream_content(): + try: + if form_data.stream: + yield json.dumps({"id": request_id, "done": False}) + "\n" + + for chunk in r.iter_content(chunk_size=8192): + if request_id in REQUEST_POOL: + yield chunk + else: + print("User: canceled request") + break + finally: + if hasattr(r, "close"): + r.close() + if request_id in REQUEST_POOL: + REQUEST_POOL.remove(request_id) + + r = requests.request( + method="POST", + url=f"{url}/api/chat", + data=form_data.model_dump_json(exclude_none=True).encode(), + stream=True, + ) + + r.raise_for_status() + + return StreamingResponse( + stream_content(), + status_code=r.status_code, + headers=dict(r.headers), + ) + except Exception as e: + print(e) + raise e + + try: + return await run_in_threadpool(get_request) + except Exception as e: + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except: + error_detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) + + +# TODO: we should update this part once Ollama supports other types +class OpenAIChatMessage(BaseModel): + role: str + content: str + + model_config = ConfigDict(extra="allow") + + +class OpenAIChatCompletionForm(BaseModel): + model: str + messages: List[OpenAIChatMessage] + + model_config = ConfigDict(extra="allow") + + +@app.post("/v1/chat/completions") +@app.post("/v1/chat/completions/{url_idx}") +async def generate_openai_chat_completion( + form_data: OpenAIChatCompletionForm, + url_idx: Optional[int] = None, + user=Depends(get_current_user), +): + + if url_idx == None: + if form_data.model in app.state.MODELS: + url_idx = random.choice(app.state.MODELS[form_data.model]["urls"]) + else: + raise HTTPException( + status_code=400, + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), + ) + + url = app.state.OLLAMA_BASE_URLS[url_idx] + print(url) + + r = None + + def get_request(): + nonlocal form_data + nonlocal r + + request_id = str(uuid.uuid4()) + try: + REQUEST_POOL.append(request_id) + + def stream_content(): + try: + if form_data.stream: + yield json.dumps( + {"request_id": request_id, "done": False} + ) + "\n" + + for chunk in r.iter_content(chunk_size=8192): + if request_id in REQUEST_POOL: + yield chunk + else: + print("User: canceled request") + break + finally: + if hasattr(r, "close"): + r.close() + if request_id in REQUEST_POOL: + REQUEST_POOL.remove(request_id) + + r = requests.request( + method="POST", + url=f"{url}/v1/chat/completions", + data=form_data.model_dump_json(exclude_none=True).encode(), + stream=True, + ) + + r.raise_for_status() + + return StreamingResponse( + stream_content(), + status_code=r.status_code, + headers=dict(r.headers), + ) + except Exception as e: + raise e + + try: + return await run_in_threadpool(get_request) + except Exception as e: + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except: + error_detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) + + @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) -async def proxy(path: str, request: Request, user=Depends(get_current_user)): - target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}" +async def deprecated_proxy(path: str, request: Request, user=Depends(get_current_user)): + url = app.state.OLLAMA_BASE_URLS[0] + target_url = f"{url}/{path}" body = await request.body() headers = dict(request.headers) @@ -91,7 +927,13 @@ def get_request(): def stream_content(): try: - if path in ["chat"]: + if path == "generate": + data = json.loads(body.decode("utf-8")) + + if not ("stream" in data and data["stream"] == False): + yield json.dumps({"id": request_id, "done": False}) + "\n" + + elif path == "chat": yield json.dumps({"id": request_id, "done": False}) + "\n" for chunk in r.iter_content(chunk_size=8192): @@ -103,7 +945,8 @@ def stream_content(): finally: if hasattr(r, "close"): r.close() - REQUEST_POOL.remove(request_id) + if request_id in REQUEST_POOL: + REQUEST_POOL.remove(request_id) r = requests.request( method=request.method, diff --git a/backend/apps/ollama/old_main.py b/backend/apps/ollama/old_main.py deleted file mode 100644 index 5e5b881111..0000000000 --- a/backend/apps/ollama/old_main.py +++ /dev/null @@ -1,127 +0,0 @@ -from fastapi import FastAPI, Request, Response, HTTPException, Depends -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import StreamingResponse - -import requests -import json -from pydantic import BaseModel - -from apps.web.models.users import Users -from constants import ERROR_MESSAGES -from utils.utils import decode_token, get_current_user -from config import OLLAMA_API_BASE_URL, WEBUI_AUTH - -import aiohttp - -app = FastAPI() -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -app.state.OLLAMA_API_BASE_URL = OLLAMA_API_BASE_URL - -# TARGET_SERVER_URL = OLLAMA_API_BASE_URL - - -@app.get("/url") -async def get_ollama_api_url(user=Depends(get_current_user)): - if user and user.role == "admin": - return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL} - else: - raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) - - -class UrlUpdateForm(BaseModel): - url: str - - -@app.post("/url/update") -async def update_ollama_api_url( - form_data: UrlUpdateForm, user=Depends(get_current_user) -): - if user and user.role == "admin": - app.state.OLLAMA_API_BASE_URL = form_data.url - return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL} - else: - raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) - - -# async def fetch_sse(method, target_url, body, headers): -# async with aiohttp.ClientSession() as session: -# try: -# async with session.request( -# method, target_url, data=body, headers=headers -# ) as response: -# print(response.status) -# async for line in response.content: -# yield line -# except Exception as e: -# print(e) -# error_detail = "Open WebUI: Server Connection Error" -# yield json.dumps({"error": error_detail, "message": str(e)}).encode() - - -@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) -async def proxy(path: str, request: Request, user=Depends(get_current_user)): - target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}" - print(target_url) - - body = await request.body() - headers = dict(request.headers) - - if user.role in ["user", "admin"]: - if path in ["pull", "delete", "push", "copy", "create"]: - if user.role != "admin": - raise HTTPException( - status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED - ) - else: - raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) - - headers.pop("Host", None) - headers.pop("Authorization", None) - headers.pop("Origin", None) - headers.pop("Referer", None) - - session = aiohttp.ClientSession() - response = None - try: - response = await session.request( - request.method, target_url, data=body, headers=headers - ) - - print(response) - if not response.ok: - data = await response.json() - print(data) - response.raise_for_status() - - async def generate(): - async for line in response.content: - print(line) - yield line - await session.close() - - return StreamingResponse(generate(), response.status) - - except Exception as e: - print(e) - error_detail = "Open WebUI: Server Connection Error" - - if response is not None: - try: - res = await response.json() - if "error" in res: - error_detail = f"Ollama: {res['error']}" - except: - error_detail = f"Ollama: {e}" - - await session.close() - raise HTTPException( - status_code=response.status if response else 500, - detail=error_detail, - ) diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index 3632643013..375ed3f121 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -3,7 +3,10 @@ from fastapi.responses import StreamingResponse, JSONResponse, FileResponse import requests +import aiohttp +import asyncio import json + from pydantic import BaseModel @@ -15,7 +18,15 @@ get_verified_user, get_admin_user, ) -from config import OPENAI_API_BASE_URL, OPENAI_API_KEY, CACHE_DIR +from config import ( + OPENAI_API_BASE_URLS, + OPENAI_API_KEYS, + CACHE_DIR, + MODEL_FILTER_ENABLED, + MODEL_FILTER_LIST, +) +from typing import List, Optional + import hashlib from pathlib import Path @@ -29,116 +40,225 @@ allow_headers=["*"], ) -app.state.OPENAI_API_BASE_URL = OPENAI_API_BASE_URL -app.state.OPENAI_API_KEY = OPENAI_API_KEY +app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED +app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST + +app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS +app.state.OPENAI_API_KEYS = OPENAI_API_KEYS + +app.state.MODELS = {} + + +@app.middleware("http") +async def check_url(request: Request, call_next): + if len(app.state.MODELS) == 0: + await get_all_models() + else: + pass + + response = await call_next(request) + return response -class UrlUpdateForm(BaseModel): - url: str +class UrlsUpdateForm(BaseModel): + urls: List[str] -class KeyUpdateForm(BaseModel): - key: str +class KeysUpdateForm(BaseModel): + keys: List[str] -@app.get("/url") -async def get_openai_url(user=Depends(get_admin_user)): - return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL} +@app.get("/urls") +async def get_openai_urls(user=Depends(get_admin_user)): + return {"OPENAI_API_BASE_URLS": app.state.OPENAI_API_BASE_URLS} -@app.post("/url/update") -async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)): - app.state.OPENAI_API_BASE_URL = form_data.url - return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL} +@app.post("/urls/update") +async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)): + app.state.OPENAI_API_BASE_URLS = form_data.urls + return {"OPENAI_API_BASE_URLS": app.state.OPENAI_API_BASE_URLS} -@app.get("/key") -async def get_openai_key(user=Depends(get_admin_user)): - return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY} +@app.get("/keys") +async def get_openai_keys(user=Depends(get_admin_user)): + return {"OPENAI_API_KEYS": app.state.OPENAI_API_KEYS} -@app.post("/key/update") -async def update_openai_key(form_data: KeyUpdateForm, user=Depends(get_admin_user)): - app.state.OPENAI_API_KEY = form_data.key - return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY} +@app.post("/keys/update") +async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)): + app.state.OPENAI_API_KEYS = form_data.keys + return {"OPENAI_API_KEYS": app.state.OPENAI_API_KEYS} @app.post("/audio/speech") async def speech(request: Request, user=Depends(get_verified_user)): - target_url = f"{app.state.OPENAI_API_BASE_URL}/audio/speech" + idx = None + try: + idx = app.state.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1") + body = await request.body() + name = hashlib.sha256(body).hexdigest() + + SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/") + SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True) + file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3") + file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json") + + # Check if the file already exists in the cache + if file_path.is_file(): + return FileResponse(file_path) + + headers = {} + headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEYS[idx]}" + headers["Content-Type"] = "application/json" + + try: + r = requests.post( + url=f"{app.state.OPENAI_API_BASE_URLS[idx]}/audio/speech", + data=body, + headers=headers, + stream=True, + ) - if app.state.OPENAI_API_KEY == "": - raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) + r.raise_for_status() - body = await request.body() + # Save the streaming content to a file + with open(file_path, "wb") as f: + for chunk in r.iter_content(chunk_size=8192): + f.write(chunk) - name = hashlib.sha256(body).hexdigest() + with open(file_body_path, "w") as f: + json.dump(json.loads(body.decode("utf-8")), f) - SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/") - SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True) - file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3") - file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json") + # Return the saved file + return FileResponse(file_path) - # Check if the file already exists in the cache - if file_path.is_file(): - return FileResponse(file_path) + except Exception as e: + print(e) + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"External: {res['error']}" + except: + error_detail = f"External: {e}" - headers = {} - headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}" - headers["Content-Type"] = "application/json" + raise HTTPException(status_code=r.status_code, detail=error_detail) + + except ValueError: + raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND) + +async def fetch_url(url, key): try: - print("openai") - r = requests.post( - url=target_url, - data=body, - headers=headers, - stream=True, + headers = {"Authorization": f"Bearer {key}"} + async with aiohttp.ClientSession() as session: + async with session.get(url, headers=headers) as response: + return await response.json() + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + return None + + +def merge_models_lists(model_lists): + merged_list = [] + + for idx, models in enumerate(model_lists): + merged_list.extend( + [ + {**model, "urlIdx": idx} + for model in models + if "api.openai.com" not in app.state.OPENAI_API_BASE_URLS[idx] + or "gpt" in model["id"] + ] ) - r.raise_for_status() + return merged_list - # Save the streaming content to a file - with open(file_path, "wb") as f: - for chunk in r.iter_content(chunk_size=8192): - f.write(chunk) - with open(file_body_path, "w") as f: - json.dump(json.loads(body.decode("utf-8")), f) +async def get_all_models(): + print("get_all_models") - # Return the saved file - return FileResponse(file_path) + if len(app.state.OPENAI_API_KEYS) == 1 and app.state.OPENAI_API_KEYS[0] == "": + models = {"data": []} + else: + tasks = [ + fetch_url(f"{url}/models", app.state.OPENAI_API_KEYS[idx]) + for idx, url in enumerate(app.state.OPENAI_API_BASE_URLS) + ] + responses = await asyncio.gather(*tasks) + responses = list( + filter(lambda x: x is not None and "error" not in x, responses) + ) + models = { + "data": merge_models_lists( + list(map(lambda response: response["data"], responses)) + ) + } + app.state.MODELS = {model["id"]: model for model in models["data"]} + + return models + + +@app.get("/models") +@app.get("/models/{url_idx}") +async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)): + if url_idx == None: + models = await get_all_models() + if app.state.MODEL_FILTER_ENABLED: + if user.role == "user": + models["data"] = list( + filter( + lambda model: model["id"] in app.state.MODEL_FILTER_LIST, + models["data"], + ) + ) + return models + return models + else: + url = app.state.OPENAI_API_BASE_URLS[url_idx] + try: + r = requests.request(method="GET", url=f"{url}/models") + r.raise_for_status() - except Exception as e: - print(e) - error_detail = "Open WebUI: Server Connection Error" - if r is not None: - try: - res = r.json() - if "error" in res: - error_detail = f"External: {res['error']}" - except: - error_detail = f"External: {e}" + response_data = r.json() + if "api.openai.com" in url: + response_data["data"] = list( + filter(lambda model: "gpt" in model["id"], response_data["data"]) + ) - raise HTTPException(status_code=r.status_code, detail=error_detail) + return response_data + except Exception as e: + print(e) + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"External: {res['error']}" + except: + error_detail = f"External: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) async def proxy(path: str, request: Request, user=Depends(get_verified_user)): - target_url = f"{app.state.OPENAI_API_BASE_URL}/{path}" - print(target_url, app.state.OPENAI_API_KEY) - - if app.state.OPENAI_API_KEY == "": - raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) + idx = 0 body = await request.body() - # TODO: Remove below after gpt-4-vision fix from Open AI # Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision) try: body = body.decode("utf-8") body = json.loads(body) + idx = app.state.MODELS[body.get("model")]["urlIdx"] + # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000 # This is a workaround until OpenAI fixes the issue with this model if body.get("model") == "gpt-4-vision-preview": @@ -146,13 +266,28 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): body["max_tokens"] = 4000 print("Modified body_dict:", body) + # Fix for ChatGPT calls failing because the num_ctx key is in body + if "num_ctx" in body: + # If 'num_ctx' is in the dictionary, delete it + # Leaving it there generates an error with the + # OpenAI API (Feb 2024) + del body["num_ctx"] + # Convert the modified body back to JSON body = json.dumps(body) except json.JSONDecodeError as e: print("Error loading request body into a dictionary:", e) + url = app.state.OPENAI_API_BASE_URLS[idx] + key = app.state.OPENAI_API_KEYS[idx] + + target_url = f"{url}/{path}" + + if key == "": + raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) + headers = {} - headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}" + headers["Authorization"] = f"Bearer {key}" headers["Content-Type"] = "application/json" try: @@ -174,21 +309,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): headers=dict(r.headers), ) else: - # For non-SSE, read the response and return it - # response_data = ( - # r.json() - # if r.headers.get("Content-Type", "") - # == "application/json" - # else r.text - # ) - response_data = r.json() - - if "openai" in app.state.OPENAI_API_BASE_URL and path == "models": - response_data["data"] = list( - filter(lambda model: "gpt" in model["id"], response_data["data"]) - ) - return response_data except Exception as e: print(e) diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 83c10233e6..b21724cc9c 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -44,6 +44,8 @@ DocumentResponse, ) +from apps.rag.utils import query_doc, query_collection + from utils.misc import ( calculate_sha256, calculate_sha256_string, @@ -75,10 +77,13 @@ app = FastAPI() +app.state.PDF_EXTRACT_IMAGES = False app.state.CHUNK_SIZE = CHUNK_SIZE app.state.CHUNK_OVERLAP = CHUNK_OVERLAP app.state.RAG_TEMPLATE = RAG_TEMPLATE app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL +app.state.TOP_K = 4 + app.state.sentence_transformer_ef = ( embedding_functions.SentenceTransformerEmbeddingFunction( model_name=app.state.RAG_EMBEDDING_MODEL, @@ -106,7 +111,7 @@ class StoreWebForm(CollectionNameForm): url: str -def store_data_in_vector_db(data, collection_name) -> bool: +def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool: text_splitter = RecursiveCharacterTextSplitter( chunk_size=app.state.CHUNK_SIZE, chunk_overlap=app.state.CHUNK_OVERLAP ) @@ -116,6 +121,12 @@ def store_data_in_vector_db(data, collection_name) -> bool: metadatas = [doc.metadata for doc in docs] try: + if overwrite: + for collection in CHROMA_CLIENT.list_collections(): + if collection_name == collection.name: + print(f"deleting existing collection {collection_name}") + CHROMA_CLIENT.delete_collection(name=collection_name) + collection = CHROMA_CLIENT.create_collection( name=collection_name, embedding_function=app.state.sentence_transformer_ef, @@ -174,12 +185,15 @@ async def update_embedding_model( } -@app.get("/chunk") -async def get_chunk_params(user=Depends(get_admin_user)): +@app.get("/config") +async def get_rag_config(user=Depends(get_admin_user)): return { "status": True, - "chunk_size": app.state.CHUNK_SIZE, - "chunk_overlap": app.state.CHUNK_OVERLAP, + "pdf_extract_images": app.state.PDF_EXTRACT_IMAGES, + "chunk": { + "chunk_size": app.state.CHUNK_SIZE, + "chunk_overlap": app.state.CHUNK_OVERLAP, + }, } @@ -188,17 +202,24 @@ class ChunkParamUpdateForm(BaseModel): chunk_overlap: int -@app.post("/chunk/update") -async def update_chunk_params( - form_data: ChunkParamUpdateForm, user=Depends(get_admin_user) -): - app.state.CHUNK_SIZE = form_data.chunk_size - app.state.CHUNK_OVERLAP = form_data.chunk_overlap +class ConfigUpdateForm(BaseModel): + pdf_extract_images: bool + chunk: ChunkParamUpdateForm + + +@app.post("/config/update") +async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)): + app.state.PDF_EXTRACT_IMAGES = form_data.pdf_extract_images + app.state.CHUNK_SIZE = form_data.chunk.chunk_size + app.state.CHUNK_OVERLAP = form_data.chunk.chunk_overlap return { "status": True, - "chunk_size": app.state.CHUNK_SIZE, - "chunk_overlap": app.state.CHUNK_OVERLAP, + "pdf_extract_images": app.state.PDF_EXTRACT_IMAGES, + "chunk": { + "chunk_size": app.state.CHUNK_SIZE, + "chunk_overlap": app.state.CHUNK_OVERLAP, + }, } @@ -210,38 +231,48 @@ async def get_rag_template(user=Depends(get_current_user)): } -class RAGTemplateForm(BaseModel): - template: str +@app.get("/query/settings") +async def get_query_settings(user=Depends(get_admin_user)): + return { + "status": True, + "template": app.state.RAG_TEMPLATE, + "k": app.state.TOP_K, + } + +class QuerySettingsForm(BaseModel): + k: Optional[int] = None + template: Optional[str] = None -@app.post("/template/update") -async def update_rag_template(form_data: RAGTemplateForm, user=Depends(get_admin_user)): - # TODO: check template requirements - app.state.RAG_TEMPLATE = ( - form_data.template if form_data.template != "" else RAG_TEMPLATE - ) + +@app.post("/query/settings/update") +async def update_query_settings( + form_data: QuerySettingsForm, user=Depends(get_admin_user) +): + app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE + app.state.TOP_K = form_data.k if form_data.k else 4 return {"status": True, "template": app.state.RAG_TEMPLATE} class QueryDocForm(BaseModel): collection_name: str query: str - k: Optional[int] = 4 + k: Optional[int] = None @app.post("/query/doc") -def query_doc( +def query_doc_handler( form_data: QueryDocForm, user=Depends(get_current_user), ): + try: - # if you use docker use the model from the environment variable - collection = CHROMA_CLIENT.get_collection( - name=form_data.collection_name, + return query_doc( + collection_name=form_data.collection_name, + query=form_data.query, + k=form_data.k if form_data.k else app.state.TOP_K, embedding_function=app.state.sentence_transformer_ef, ) - result = collection.query(query_texts=[form_data.query], n_results=form_data.k) - return result except Exception as e: print(e) raise HTTPException( @@ -253,77 +284,20 @@ def query_doc( class QueryCollectionsForm(BaseModel): collection_names: List[str] query: str - k: Optional[int] = 4 - - -def merge_and_sort_query_results(query_results, k): - # Initialize lists to store combined data - combined_ids = [] - combined_distances = [] - combined_metadatas = [] - combined_documents = [] - - # Combine data from each dictionary - for data in query_results: - combined_ids.extend(data["ids"][0]) - combined_distances.extend(data["distances"][0]) - combined_metadatas.extend(data["metadatas"][0]) - combined_documents.extend(data["documents"][0]) - - # Create a list of tuples (distance, id, metadata, document) - combined = list( - zip(combined_distances, combined_ids, combined_metadatas, combined_documents) - ) - - # Sort the list based on distances - combined.sort(key=lambda x: x[0]) - - # Unzip the sorted list - sorted_distances, sorted_ids, sorted_metadatas, sorted_documents = zip(*combined) - - # Slicing the lists to include only k elements - sorted_distances = list(sorted_distances)[:k] - sorted_ids = list(sorted_ids)[:k] - sorted_metadatas = list(sorted_metadatas)[:k] - sorted_documents = list(sorted_documents)[:k] - - # Create the output dictionary - merged_query_results = { - "ids": [sorted_ids], - "distances": [sorted_distances], - "metadatas": [sorted_metadatas], - "documents": [sorted_documents], - "embeddings": None, - "uris": None, - "data": None, - } - - return merged_query_results + k: Optional[int] = None @app.post("/query/collection") -def query_collection( +def query_collection_handler( form_data: QueryCollectionsForm, user=Depends(get_current_user), ): - results = [] - - for collection_name in form_data.collection_names: - try: - # if you use docker use the model from the environment variable - collection = CHROMA_CLIENT.get_collection( - name=collection_name, - embedding_function=app.state.sentence_transformer_ef, - ) - - result = collection.query( - query_texts=[form_data.query], n_results=form_data.k - ) - results.append(result) - except: - pass - - return merge_and_sort_query_results(results, form_data.k) + return query_collection( + collection_names=form_data.collection_names, + query=form_data.query, + k=form_data.k if form_data.k else app.state.TOP_K, + embedding_function=app.state.sentence_transformer_ef, + ) @app.post("/web") @@ -337,7 +311,7 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)): if collection_name == "": collection_name = calculate_sha256_string(form_data.url)[:63] - store_data_in_vector_db(data, collection_name) + store_data_in_vector_db(data, collection_name, overwrite=True) return { "status": True, "collection_name": collection_name, @@ -401,7 +375,7 @@ def get_loader(filename: str, file_content_type: str, file_path: str): ] if file_ext == "pdf": - loader = PyPDFLoader(file_path) + loader = PyPDFLoader(file_path, extract_images=app.state.PDF_EXTRACT_IMAGES) elif file_ext == "csv": loader = CSVLoader(file_path) elif file_ext == "rst": @@ -423,7 +397,9 @@ def get_loader(filename: str, file_content_type: str, file_path: str): "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", ] or file_ext in ["xls", "xlsx"]: loader = UnstructuredExcelLoader(file_path) - elif file_ext in known_source_ext or (file_content_type and file_content_type.find("text/") >= 0): + elif file_ext in known_source_ext or ( + file_content_type and file_content_type.find("text/") >= 0 + ): loader = TextLoader(file_path) else: loader = TextLoader(file_path) diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py new file mode 100644 index 0000000000..b2da7d90c3 --- /dev/null +++ b/backend/apps/rag/utils.py @@ -0,0 +1,183 @@ +import re +from typing import List + +from config import CHROMA_CLIENT + + +def query_doc(collection_name: str, query: str, k: int, embedding_function): + try: + # if you use docker use the model from the environment variable + collection = CHROMA_CLIENT.get_collection( + name=collection_name, + embedding_function=embedding_function, + ) + result = collection.query( + query_texts=[query], + n_results=k, + ) + return result + except Exception as e: + raise e + + +def merge_and_sort_query_results(query_results, k): + # Initialize lists to store combined data + combined_ids = [] + combined_distances = [] + combined_metadatas = [] + combined_documents = [] + + # Combine data from each dictionary + for data in query_results: + combined_ids.extend(data["ids"][0]) + combined_distances.extend(data["distances"][0]) + combined_metadatas.extend(data["metadatas"][0]) + combined_documents.extend(data["documents"][0]) + + # Create a list of tuples (distance, id, metadata, document) + combined = list( + zip(combined_distances, combined_ids, combined_metadatas, combined_documents) + ) + + # Sort the list based on distances + combined.sort(key=lambda x: x[0]) + + # Unzip the sorted list + sorted_distances, sorted_ids, sorted_metadatas, sorted_documents = zip(*combined) + + # Slicing the lists to include only k elements + sorted_distances = list(sorted_distances)[:k] + sorted_ids = list(sorted_ids)[:k] + sorted_metadatas = list(sorted_metadatas)[:k] + sorted_documents = list(sorted_documents)[:k] + + # Create the output dictionary + merged_query_results = { + "ids": [sorted_ids], + "distances": [sorted_distances], + "metadatas": [sorted_metadatas], + "documents": [sorted_documents], + "embeddings": None, + "uris": None, + "data": None, + } + + return merged_query_results + + +def query_collection( + collection_names: List[str], query: str, k: int, embedding_function +): + + results = [] + + for collection_name in collection_names: + try: + # if you use docker use the model from the environment variable + collection = CHROMA_CLIENT.get_collection( + name=collection_name, + embedding_function=embedding_function, + ) + + result = collection.query( + query_texts=[query], + n_results=k, + ) + results.append(result) + except: + pass + + return merge_and_sort_query_results(results, k) + + +def rag_template(template: str, context: str, query: str): + template = re.sub(r"\[context\]", context, template) + template = re.sub(r"\[query\]", query, template) + + return template + + +def rag_messages(docs, messages, template, k, embedding_function): + print(docs) + + last_user_message_idx = None + for i in range(len(messages) - 1, -1, -1): + if messages[i]["role"] == "user": + last_user_message_idx = i + break + + user_message = messages[last_user_message_idx] + + if isinstance(user_message["content"], list): + # Handle list content input + content_type = "list" + query = "" + for content_item in user_message["content"]: + if content_item["type"] == "text": + query = content_item["text"] + break + elif isinstance(user_message["content"], str): + # Handle text content input + content_type = "text" + query = user_message["content"] + else: + # Fallback in case the input does not match expected types + content_type = None + query = "" + + relevant_contexts = [] + + for doc in docs: + context = None + + try: + if doc["type"] == "collection": + context = query_collection( + collection_names=doc["collection_names"], + query=query, + k=k, + embedding_function=embedding_function, + ) + else: + context = query_doc( + collection_name=doc["collection_name"], + query=query, + k=k, + embedding_function=embedding_function, + ) + except Exception as e: + print(e) + context = None + + relevant_contexts.append(context) + + context_string = "" + for context in relevant_contexts: + if context: + context_string += " ".join(context["documents"][0]) + "\n" + + ra_content = rag_template( + template=template, + context=context_string, + query=query, + ) + + if content_type == "list": + new_content = [] + for content_item in user_message["content"]: + if content_item["type"] == "text": + # Update the text item's content with ra_content + new_content.append({"type": "text", "text": ra_content}) + else: + # Keep other types of content as they are + new_content.append(content_item) + new_user_message = {**user_message, "content": new_content} + else: + new_user_message = { + **user_message, + "content": ra_content, + } + + messages[last_user_message_idx] = new_user_message + + return messages diff --git a/backend/apps/web/internal/db.py b/backend/apps/web/internal/db.py index 1f8c3bf7d2..d0aa996953 100644 --- a/backend/apps/web/internal/db.py +++ b/backend/apps/web/internal/db.py @@ -1,6 +1,16 @@ from peewee import * from config import DATA_DIR +import os -DB = SqliteDatabase(f"{DATA_DIR}/ollama.db") +# Check if the file exists +if os.path.exists(f"{DATA_DIR}/ollama.db"): + # Rename the file + os.rename(f"{DATA_DIR}/ollama.db", f"{DATA_DIR}/webui.db") + print("File renamed successfully.") +else: + pass + + +DB = SqliteDatabase(f"{DATA_DIR}/webui.db") DB.connect() diff --git a/backend/apps/web/routers/chats.py b/backend/apps/web/routers/chats.py index 1ce537ec61..0c0ac1ce89 100644 --- a/backend/apps/web/routers/chats.py +++ b/backend/apps/web/routers/chats.py @@ -271,6 +271,16 @@ async def delete_all_chat_tags_by_id(id: str, user=Depends(get_current_user)): @router.delete("/", response_model=bool) -async def delete_all_user_chats(user=Depends(get_current_user)): +async def delete_all_user_chats(request: Request, user=Depends(get_current_user)): + + if ( + user.role == "user" + and not request.app.state.USER_PERMISSIONS["chat"]["deletion"] + ): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + result = Chats.delete_chats_by_user_id(user.id) return result diff --git a/backend/apps/web/routers/utils.py b/backend/apps/web/routers/utils.py index 86e1a9e58c..fbb350cf29 100644 --- a/backend/apps/web/routers/utils.py +++ b/backend/apps/web/routers/utils.py @@ -1,6 +1,7 @@ from fastapi import APIRouter, UploadFile, File, BackgroundTasks from fastapi import Depends, HTTPException, status -from starlette.responses import StreamingResponse +from starlette.responses import StreamingResponse, FileResponse + from pydantic import BaseModel @@ -9,9 +10,11 @@ import aiohttp import json + +from utils.utils import get_admin_user from utils.misc import calculate_sha256, get_gravatar_url -from config import OLLAMA_API_BASE_URL, DATA_DIR, UPLOAD_DIR +from config import OLLAMA_BASE_URLS, DATA_DIR, UPLOAD_DIR from constants import ERROR_MESSAGES @@ -72,7 +75,7 @@ async def download_file_stream(url, file_path, file_name, chunk_size=1024 * 1024 hashed = calculate_sha256(file) file.seek(0) - url = f"{OLLAMA_API_BASE_URL}/blobs/sha256:{hashed}" + url = f"{OLLAMA_BASE_URLS[0]}/blobs/sha256:{hashed}" response = requests.post(url, data=file) if response.ok: @@ -144,7 +147,7 @@ def file_process_stream(): hashed = calculate_sha256(f) f.seek(0) - url = f"{OLLAMA_API_BASE_URL}/blobs/sha256:{hashed}" + url = f"{OLLAMA_BASE_URLS[0]}/blobs/sha256:{hashed}" response = requests.post(url, data=f) if response.ok: @@ -172,3 +175,13 @@ async def get_gravatar( email: str, ): return get_gravatar_url(email) + + +@router.get("/db/download") +async def download_db(user=Depends(get_admin_user)): + + return FileResponse( + f"{DATA_DIR}/webui.db", + media_type="application/octet-stream", + filename="webui.db", + ) diff --git a/backend/config.py b/backend/config.py index effcd24620..831371bb7e 100644 --- a/backend/config.py +++ b/backend/config.py @@ -200,16 +200,33 @@ def create_config_file(file_path): #################################### -# OLLAMA_API_BASE_URL +# OLLAMA_BASE_URL #################################### OLLAMA_API_BASE_URL = os.environ.get( "OLLAMA_API_BASE_URL", "http://localhost:11434/api" ) +OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "") + + +if OLLAMA_BASE_URL == "" and OLLAMA_API_BASE_URL != "": + OLLAMA_BASE_URL = ( + OLLAMA_API_BASE_URL[:-4] + if OLLAMA_API_BASE_URL.endswith("/api") + else OLLAMA_API_BASE_URL + ) + if ENV == "prod": - if OLLAMA_API_BASE_URL == "/ollama/api": - OLLAMA_API_BASE_URL = "http://host.docker.internal:11434/api" + if OLLAMA_BASE_URL == "/ollama": + OLLAMA_BASE_URL = "http://host.docker.internal:11434" + + +OLLAMA_BASE_URLS = os.environ.get("OLLAMA_BASE_URLS", "") +OLLAMA_BASE_URLS = OLLAMA_BASE_URLS if OLLAMA_BASE_URLS != "" else OLLAMA_BASE_URL + +OLLAMA_BASE_URLS = [url.strip() for url in OLLAMA_BASE_URLS.split(";")] + #################################### # OPENAI_API @@ -218,15 +235,29 @@ def create_config_file(file_path): OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "") OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "") + if OPENAI_API_BASE_URL == "": OPENAI_API_BASE_URL = "https://api.openai.com/v1" +OPENAI_API_KEYS = os.environ.get("OPENAI_API_KEYS", "") +OPENAI_API_KEYS = OPENAI_API_KEYS if OPENAI_API_KEYS != "" else OPENAI_API_KEY + +OPENAI_API_KEYS = [url.strip() for url in OPENAI_API_KEYS.split(";")] + + +OPENAI_API_BASE_URLS = os.environ.get("OPENAI_API_BASE_URLS", "") +OPENAI_API_BASE_URLS = ( + OPENAI_API_BASE_URLS if OPENAI_API_BASE_URLS != "" else OPENAI_API_BASE_URL +) + +OPENAI_API_BASE_URLS = [url.strip() for url in OPENAI_API_BASE_URLS.split(";")] + #################################### # WEBUI #################################### -ENABLE_SIGNUP = os.environ.get("ENABLE_SIGNUP", True) +ENABLE_SIGNUP = os.environ.get("ENABLE_SIGNUP", "True").lower() == "true" DEFAULT_MODELS = os.environ.get("DEFAULT_MODELS", None) @@ -260,6 +291,11 @@ def create_config_file(file_path): USER_PERMISSIONS = {"chat": {"deletion": True}} +MODEL_FILTER_ENABLED = os.environ.get("MODEL_FILTER_ENABLED", False) +MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "") +MODEL_FILTER_LIST = [model.strip() for model in MODEL_FILTER_LIST.split(";")] + + #################################### # WEBUI_VERSION #################################### diff --git a/backend/constants.py b/backend/constants.py index 006fa7bbe7..eacf8a20f6 100644 --- a/backend/constants.py +++ b/backend/constants.py @@ -41,6 +41,7 @@ def __str__(self) -> str: NOT_FOUND = "We could not find what you're looking for :/" USER_NOT_FOUND = "We could not find what you're looking for :/" API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature." + MALICIOUS = "Unusual activities detected, please try again in a few minutes." PANDOC_NOT_INSTALLED = "Pandoc is not installed on the server. Please contact your administrator for assistance." @@ -48,3 +49,6 @@ def __str__(self) -> str: lambda err="": f"Invalid format. Please use the correct format{err if err else ''}" ) RATE_LIMIT_EXCEEDED = "API rate limit exceeded" + + MODEL_NOT_FOUND = lambda name="": f"Model '{name}' was not found" + OPENAI_NOT_FOUND = lambda name="": f"OpenAI API was not found" diff --git a/backend/data/config.json b/backend/data/config.json index 1b5971005b..d3ada59c91 100644 --- a/backend/data/config.json +++ b/backend/data/config.json @@ -1,4 +1,5 @@ { + "version": "0.0.1", "ui": { "prompt_suggestions": [ { diff --git a/backend/main.py b/backend/main.py index 94938b2492..2532271824 100644 --- a/backend/main.py +++ b/backend/main.py @@ -9,27 +9,37 @@ from fastapi import FastAPI, Request, Depends, status from fastapi.staticfiles import StaticFiles from fastapi import HTTPException -from fastapi.responses import JSONResponse from fastapi.middleware.wsgi import WSGIMiddleware from fastapi.middleware.cors import CORSMiddleware from starlette.exceptions import HTTPException as StarletteHTTPException +from starlette.middleware.base import BaseHTTPMiddleware -from litellm.proxy.proxy_server import ProxyConfig, initialize -from litellm.proxy.proxy_server import app as litellm_app - from apps.ollama.main import app as ollama_app from apps.openai.main import app as openai_app +from apps.litellm.main import app as litellm_app, startup as litellm_app_startup from apps.audio.main import app as audio_app from apps.images.main import app as images_app from apps.rag.main import app as rag_app from apps.web.main import app as webui_app +from pydantic import BaseModel +from typing import List -from config import WEBUI_NAME, ENV, VERSION, CHANGELOG, FRONTEND_BUILD_DIR -from constants import ERROR_MESSAGES -from utils.utils import get_http_authorization_cred, get_current_user +from utils.utils import get_admin_user +from apps.rag.utils import rag_messages + +from config import ( + WEBUI_NAME, + ENV, + VERSION, + CHANGELOG, + FRONTEND_BUILD_DIR, + MODEL_FILTER_ENABLED, + MODEL_FILTER_LIST, +) +from constants import ERROR_MESSAGES class SPAStaticFiles(StaticFiles): @@ -43,24 +53,68 @@ async def get_response(self, path: str, scope): raise ex -proxy_config = ProxyConfig() +app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None) +app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED +app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST -async def config(): - router, model_list, general_settings = await proxy_config.load_config( - router=None, config_file_path="./data/litellm/config.yaml" - ) +origins = ["*"] - await initialize(config="./data/litellm/config.yaml", telemetry=False) +class RAGMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + if request.method == "POST" and ( + "/api/chat" in request.url.path or "/chat/completions" in request.url.path + ): + print(request.url.path) -async def startup(): - await config() + # Read the original request body + body = await request.body() + # Decode body to string + body_str = body.decode("utf-8") + # Parse string to JSON + data = json.loads(body_str) if body_str else {} + # Example: Add a new key-value pair or modify existing ones + # data["modified"] = True # Example modification + if "docs" in data: -app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None) + data = {**data} + data["messages"] = rag_messages( + data["docs"], + data["messages"], + rag_app.state.RAG_TEMPLATE, + rag_app.state.TOP_K, + rag_app.state.sentence_transformer_ef, + ) + del data["docs"] + + print(data["messages"]) + + modified_body_bytes = json.dumps(data).encode("utf-8") + + # Replace the request body with the modified one + request._body = modified_body_bytes + + # Set custom header to ensure content-length matches new body length + request.headers.__dict__["_list"] = [ + (b"content-length", str(len(modified_body_bytes)).encode("utf-8")), + *[ + (k, v) + for k, v in request.headers.raw + if k.lower() != b"content-length" + ], + ] + + response = await call_next(request) + return response + + async def _receive(self, body: bytes): + return {"type": "http.request", "body": body, "more_body": False} + + +app.add_middleware(RAGMiddleware) -origins = ["*"] app.add_middleware( CORSMiddleware, @@ -71,11 +125,6 @@ async def startup(): ) -@app.on_event("startup") -async def on_startup(): - await startup() - - @app.middleware("http") async def check_url(request: Request, call_next): start_time = int(time.time()) @@ -86,25 +135,15 @@ async def check_url(request: Request, call_next): return response -@litellm_app.middleware("http") -async def auth_middleware(request: Request, call_next): - auth_header = request.headers.get("Authorization", "") - - if ENV != "dev": - try: - user = get_current_user(get_http_authorization_cred(auth_header)) - print(user) - except Exception as e: - return JSONResponse(status_code=400, content={"detail": str(e)}) - - response = await call_next(request) - return response +@app.on_event("startup") +async def on_startup(): + await litellm_app_startup() app.mount("/api/v1", webui_app) app.mount("/litellm/api", litellm_app) -app.mount("/ollama/api", ollama_app) +app.mount("/ollama", ollama_app) app.mount("/openai/api", openai_app) app.mount("/images/api/v1", images_app) @@ -125,6 +164,47 @@ async def get_app_config(): } +@app.get("/api/config/model/filter") +async def get_model_filter_config(user=Depends(get_admin_user)): + return { + "enabled": app.state.MODEL_FILTER_ENABLED, + "models": app.state.MODEL_FILTER_LIST, + } + + +class ModelFilterConfigForm(BaseModel): + enabled: bool + models: List[str] + + +@app.post("/api/config/model/filter") +async def get_model_filter_config( + form_data: ModelFilterConfigForm, user=Depends(get_admin_user) +): + + app.state.MODEL_FILTER_ENABLED = form_data.enabled + app.state.MODEL_FILTER_LIST = form_data.models + + ollama_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED + ollama_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST + + openai_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED + openai_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST + + return { + "enabled": app.state.MODEL_FILTER_ENABLED, + "models": app.state.MODEL_FILTER_LIST, + } + + +@app.get("/api/version") +async def get_app_config(): + + return { + "version": VERSION, + } + + @app.get("/api/changelog") async def get_app_changelog(): return CHANGELOG @@ -148,6 +228,7 @@ async def get_app_latest_release_version(): app.mount("/static", StaticFiles(directory="static"), name="static") +app.mount("/cache", StaticFiles(directory="data/cache"), name="cache") app.mount( diff --git a/backend/requirements.txt b/backend/requirements.txt index 0cacacd800..29fb34925b 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -16,12 +16,14 @@ aiohttp peewee bcrypt -litellm +litellm==1.30.7 +argon2-cffi apscheduler google-generativeai langchain langchain-community +fake_useragent chromadb sentence_transformers pypdf @@ -34,6 +36,9 @@ openpyxl pyxlsb xlrd +opencv-python-headless +rapidocr-onnxruntime + faster-whisper PyJWT diff --git a/confirm_remove.sh b/confirm_remove.sh new file mode 100755 index 0000000000..729c25070d --- /dev/null +++ b/confirm_remove.sh @@ -0,0 +1,8 @@ +#!/bin/bash +echo "Warning: This will remove all containers and volumes, including persistent data. Do you want to continue? [Y/N]" +read ans +if [ "$ans" == "Y" ] || [ "$ans" == "y" ]; then + docker-compose down -v +else + echo "Operation cancelled." +fi diff --git a/docker-compose.yaml b/docker-compose.yaml index c41c56d8ea..f69084b8a5 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -14,7 +14,7 @@ services: build: context: . args: - OLLAMA_API_BASE_URL: '/ollama/api' + OLLAMA_BASE_URL: '/ollama' dockerfile: Dockerfile image: ghcr.io/open-webui/open-webui:main container_name: open-webui @@ -25,7 +25,7 @@ services: ports: - ${OPEN_WEBUI_PORT-3000}:8080 environment: - - 'OLLAMA_API_BASE_URL=http://ollama:11434/api' + - 'OLLAMA_BASE_URL=http://ollama:11434' - 'WEBUI_SECRET_KEY=' extra_hosts: - host.docker.internal:host-gateway diff --git a/kubernetes/helm/templates/webui-deployment.yaml b/kubernetes/helm/templates/webui-deployment.yaml index df13a14b63..bbd5706dea 100644 --- a/kubernetes/helm/templates/webui-deployment.yaml +++ b/kubernetes/helm/templates/webui-deployment.yaml @@ -40,7 +40,7 @@ spec: - name: data mountPath: /app/backend/data env: - - name: OLLAMA_API_BASE_URL + - name: OLLAMA_BASE_URL value: {{ include "ollama.url" . | quote }} tty: true {{- with .Values.webui.nodeSelector }} diff --git a/kubernetes/manifest/base/webui-deployment.yaml b/kubernetes/manifest/base/webui-deployment.yaml index 174025a94a..38efd55493 100644 --- a/kubernetes/manifest/base/webui-deployment.yaml +++ b/kubernetes/manifest/base/webui-deployment.yaml @@ -26,8 +26,8 @@ spec: cpu: "1000m" memory: "1Gi" env: - - name: OLLAMA_API_BASE_URL - value: "http://ollama-service.open-webui.svc.cluster.local:11434/api" + - name: OLLAMA_BASE_URL + value: "http://ollama-service.open-webui.svc.cluster.local:11434" tty: true volumeMounts: - name: webui-volume diff --git a/package-lock.json b/package-lock.json index 9fdfdb8a48..43deeace95 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "open-webui", - "version": "v1.0.0-alpha.101", + "version": "0.1.106", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "open-webui", - "version": "v1.0.0-alpha.101", + "version": "0.1.106", "dependencies": { "@sveltejs/adapter-node": "^1.3.1", "async": "^3.2.5", @@ -17,7 +17,7 @@ "js-sha256": "^0.10.1", "katex": "^0.16.9", "marked": "^9.1.0", - "svelte-french-toast": "^1.2.0", + "svelte-sonner": "^0.3.19", "tippy.js": "^6.3.7", "uuid": "^9.0.1" }, @@ -3211,17 +3211,6 @@ } } }, - "node_modules/svelte-french-toast": { - "version": "1.2.0", - "resolved": "https://registry.npmjs.org/svelte-french-toast/-/svelte-french-toast-1.2.0.tgz", - "integrity": "sha512-5PW+6RFX3xQPbR44CngYAP1Sd9oCq9P2FOox4FZffzJuZI2mHOB7q5gJBVnOiLF5y3moVGZ7u2bYt7+yPAgcEQ==", - "dependencies": { - "svelte-writable-derived": "^3.1.0" - }, - "peerDependencies": { - "svelte": "^3.57.0 || ^4.0.0" - } - }, "node_modules/svelte-hmr": { "version": "0.15.3", "resolved": "https://registry.npmjs.org/svelte-hmr/-/svelte-hmr-0.15.3.tgz", @@ -3307,15 +3296,12 @@ "node": ">=12" } }, - "node_modules/svelte-writable-derived": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/svelte-writable-derived/-/svelte-writable-derived-3.1.0.tgz", - "integrity": "sha512-cTvaVFNIJ036vSDIyPxJYivKC7ZLtcFOPm1Iq6qWBDo1fOHzfk6ZSbwaKrxhjgy52Rbl5IHzRcWgos6Zqn9/rg==", - "funding": { - "url": "https://ko-fi.com/pixievoltno1" - }, + "node_modules/svelte-sonner": { + "version": "0.3.19", + "resolved": "https://registry.npmjs.org/svelte-sonner/-/svelte-sonner-0.3.19.tgz", + "integrity": "sha512-jpPOgLtHwRaB6Vqo2dUQMv15/yUV/BQWTjKpEqQ11uqRSHKjAYUKZyGrHB2cQsGmyjR0JUzBD58btpgNqINQ/Q==", "peerDependencies": { - "svelte": "^3.2.1 || ^4.0.0-next.1" + "svelte": ">=3 <5" } }, "node_modules/tailwindcss": { @@ -5882,14 +5868,6 @@ "postcss-scss": "^4.0.8" } }, - "svelte-french-toast": { - "version": "1.2.0", - "resolved": "https://registry.npmjs.org/svelte-french-toast/-/svelte-french-toast-1.2.0.tgz", - "integrity": "sha512-5PW+6RFX3xQPbR44CngYAP1Sd9oCq9P2FOox4FZffzJuZI2mHOB7q5gJBVnOiLF5y3moVGZ7u2bYt7+yPAgcEQ==", - "requires": { - "svelte-writable-derived": "^3.1.0" - } - }, "svelte-hmr": { "version": "0.15.3", "resolved": "https://registry.npmjs.org/svelte-hmr/-/svelte-hmr-0.15.3.tgz", @@ -5920,10 +5898,10 @@ } } }, - "svelte-writable-derived": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/svelte-writable-derived/-/svelte-writable-derived-3.1.0.tgz", - "integrity": "sha512-cTvaVFNIJ036vSDIyPxJYivKC7ZLtcFOPm1Iq6qWBDo1fOHzfk6ZSbwaKrxhjgy52Rbl5IHzRcWgos6Zqn9/rg==", + "svelte-sonner": { + "version": "0.3.19", + "resolved": "https://registry.npmjs.org/svelte-sonner/-/svelte-sonner-0.3.19.tgz", + "integrity": "sha512-jpPOgLtHwRaB6Vqo2dUQMv15/yUV/BQWTjKpEqQ11uqRSHKjAYUKZyGrHB2cQsGmyjR0JUzBD58btpgNqINQ/Q==", "requires": {} }, "tailwindcss": { diff --git a/package.json b/package.json index dd212e7dd1..572443a542 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "open-webui", - "version": "0.1.105", + "version": "0.1.111", "private": true, "scripts": { "dev": "vite dev --host", @@ -49,7 +49,7 @@ "js-sha256": "^0.10.1", "katex": "^0.16.9", "marked": "^9.1.0", - "svelte-french-toast": "^1.2.0", + "svelte-sonner": "^0.3.19", "tippy.js": "^6.3.7", "uuid": "^9.0.1" } diff --git a/run-compose.sh b/run-compose.sh index 7b0f8d2baa..08fba272b5 100755 --- a/run-compose.sh +++ b/run-compose.sh @@ -182,7 +182,7 @@ else export OLLAMA_DATA_DIR=$data_dir # Set OLLAMA_DATA_DIR environment variable fi if [[ -n $webui_port ]]; then - export OLLAMA_WEBUI_PORT=$webui_port # Set OLLAMA_WEBUI_PORT environment variable + export OPEN_WEBUI_PORT=$webui_port # Set OPEN_WEBUI_PORT environment variable fi DEFAULT_COMPOSE_COMMAND+=" up -d" DEFAULT_COMPOSE_COMMAND+=" --remove-orphans" diff --git a/src/app.css b/src/app.css index 091db396bc..82b3caa373 100644 --- a/src/app.css +++ b/src/app.css @@ -28,6 +28,25 @@ math { @apply rounded-lg; } +ol > li { + counter-increment: list-number; + display: block; + margin-bottom: 0; + margin-top: 0; + min-height: 28px; +} + +.prose ol > li::before { + content: counters(list-number, '.') '.'; + padding-right: 0.5rem; + color: var(--tw-prose-counters); + font-weight: 400; +} + +li p { + display: inline; +} + ::-webkit-scrollbar-thumb { --tw-border-opacity: 1; background-color: rgba(217, 217, 227, 0.8); diff --git a/src/lib/apis/chats/index.ts b/src/lib/apis/chats/index.ts index aadf3769fa..35b259d561 100644 --- a/src/lib/apis/chats/index.ts +++ b/src/lib/apis/chats/index.ts @@ -439,7 +439,7 @@ export const deleteAllChats = async (token: string) => { return json; }) .catch((err) => { - error = err; + error = err.detail; console.log(err); return null; diff --git a/src/lib/apis/images/index.ts b/src/lib/apis/images/index.ts index f05ce0b763..1fb004a3c0 100644 --- a/src/lib/apis/images/index.ts +++ b/src/lib/apis/images/index.ts @@ -1,9 +1,9 @@ import { IMAGES_API_BASE_URL } from '$lib/constants'; -export const getImageGenerationEnabledStatus = async (token: string = '') => { +export const getImageGenerationConfig = async (token: string = '') => { let error = null; - const res = await fetch(`${IMAGES_API_BASE_URL}/enabled`, { + const res = await fetch(`${IMAGES_API_BASE_URL}/config`, { method: 'GET', headers: { Accept: 'application/json', @@ -32,10 +32,50 @@ export const getImageGenerationEnabledStatus = async (token: string = '') => { return res; }; -export const toggleImageGenerationEnabledStatus = async (token: string = '') => { +export const updateImageGenerationConfig = async ( + token: string = '', + engine: string, + enabled: boolean +) => { let error = null; - const res = await fetch(`${IMAGES_API_BASE_URL}/enabled/toggle`, { + const res = await fetch(`${IMAGES_API_BASE_URL}/config/update`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + }, + body: JSON.stringify({ + engine, + enabled + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } else { + error = 'Server connection failed'; + } + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getOpenAIKey = async (token: string = '') => { + let error = null; + + const res = await fetch(`${IMAGES_API_BASE_URL}/key`, { method: 'GET', headers: { Accept: 'application/json', @@ -61,7 +101,42 @@ export const toggleImageGenerationEnabledStatus = async (token: string = '') => throw error; } - return res; + return res.OPENAI_API_KEY; +}; + +export const updateOpenAIKey = async (token: string = '', key: string) => { + let error = null; + + const res = await fetch(`${IMAGES_API_BASE_URL}/key/update`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + }, + body: JSON.stringify({ + key: key + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } else { + error = 'Server connection failed'; + } + return null; + }); + + if (error) { + throw error; + } + + return res.OPENAI_API_KEY; }; export const getAUTOMATIC1111Url = async (token: string = '') => { @@ -263,7 +338,7 @@ export const updateImageSteps = async (token: string = '', steps: number) => { return res.IMAGE_STEPS; }; -export const getDiffusionModels = async (token: string = '') => { +export const getImageGenerationModels = async (token: string = '') => { let error = null; const res = await fetch(`${IMAGES_API_BASE_URL}/models`, { @@ -295,7 +370,7 @@ export const getDiffusionModels = async (token: string = '') => { return res; }; -export const getDefaultDiffusionModel = async (token: string = '') => { +export const getDefaultImageGenerationModel = async (token: string = '') => { let error = null; const res = await fetch(`${IMAGES_API_BASE_URL}/models/default`, { @@ -327,7 +402,7 @@ export const getDefaultDiffusionModel = async (token: string = '') => { return res.model; }; -export const updateDefaultDiffusionModel = async (token: string = '', model: string) => { +export const updateDefaultImageGenerationModel = async (token: string = '', model: string) => { let error = null; const res = await fetch(`${IMAGES_API_BASE_URL}/models/default/update`, { diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index b7b346c0d5..b33fb571b5 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -77,3 +77,65 @@ export const getVersionUpdates = async () => { return res; }; + +export const getModelFilterConfig = async (token: string) => { + let error = null; + + const res = await fetch(`${WEBUI_BASE_URL}/api/config/model/filter`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const updateModelFilterConfig = async ( + token: string, + enabled: boolean, + models: string[] +) => { + let error = null; + + const res = await fetch(`${WEBUI_BASE_URL}/api/config/model/filter`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + enabled: enabled, + models: models + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; diff --git a/src/lib/apis/litellm/index.ts b/src/lib/apis/litellm/index.ts index 6466ee35be..302e9c4a3f 100644 --- a/src/lib/apis/litellm/index.ts +++ b/src/lib/apis/litellm/index.ts @@ -77,6 +77,7 @@ type AddLiteLLMModelForm = { api_base: string; api_key: string; rpm: string; + max_tokens: string; }; export const addLiteLLMModel = async (token: string = '', payload: AddLiteLLMModelForm) => { @@ -95,7 +96,8 @@ export const addLiteLLMModel = async (token: string = '', payload: AddLiteLLMMod model: payload.model, ...(payload.api_base === '' ? {} : { api_base: payload.api_base }), ...(payload.api_key === '' ? {} : { api_key: payload.api_key }), - ...(isNaN(parseInt(payload.rpm)) ? {} : { rpm: parseInt(payload.rpm) }) + ...(isNaN(parseInt(payload.rpm)) ? {} : { rpm: parseInt(payload.rpm) }), + ...(payload.max_tokens === '' ? {} : { max_tokens: payload.max_tokens }) } }) }) diff --git a/src/lib/apis/ollama/index.ts b/src/lib/apis/ollama/index.ts index 5fc8a5fef4..2047fedef0 100644 --- a/src/lib/apis/ollama/index.ts +++ b/src/lib/apis/ollama/index.ts @@ -1,9 +1,9 @@ import { OLLAMA_API_BASE_URL } from '$lib/constants'; -export const getOllamaAPIUrl = async (token: string = '') => { +export const getOllamaUrls = async (token: string = '') => { let error = null; - const res = await fetch(`${OLLAMA_API_BASE_URL}/url`, { + const res = await fetch(`${OLLAMA_API_BASE_URL}/urls`, { method: 'GET', headers: { Accept: 'application/json', @@ -29,13 +29,13 @@ export const getOllamaAPIUrl = async (token: string = '') => { throw error; } - return res.OLLAMA_API_BASE_URL; + return res.OLLAMA_BASE_URLS; }; -export const updateOllamaAPIUrl = async (token: string = '', url: string) => { +export const updateOllamaUrls = async (token: string = '', urls: string[]) => { let error = null; - const res = await fetch(`${OLLAMA_API_BASE_URL}/url/update`, { + const res = await fetch(`${OLLAMA_API_BASE_URL}/urls/update`, { method: 'POST', headers: { Accept: 'application/json', @@ -43,7 +43,7 @@ export const updateOllamaAPIUrl = async (token: string = '', url: string) => { ...(token && { authorization: `Bearer ${token}` }) }, body: JSON.stringify({ - url: url + urls: urls }) }) .then(async (res) => { @@ -64,13 +64,13 @@ export const updateOllamaAPIUrl = async (token: string = '', url: string) => { throw error; } - return res.OLLAMA_API_BASE_URL; + return res.OLLAMA_BASE_URLS; }; export const getOllamaVersion = async (token: string = '') => { let error = null; - const res = await fetch(`${OLLAMA_API_BASE_URL}/version`, { + const res = await fetch(`${OLLAMA_API_BASE_URL}/api/version`, { method: 'GET', headers: { Accept: 'application/json', @@ -102,7 +102,7 @@ export const getOllamaVersion = async (token: string = '') => { export const getOllamaModels = async (token: string = '') => { let error = null; - const res = await fetch(`${OLLAMA_API_BASE_URL}/tags`, { + const res = await fetch(`${OLLAMA_API_BASE_URL}/api/tags`, { method: 'GET', headers: { Accept: 'application/json', @@ -148,10 +148,11 @@ export const generateTitle = async ( console.log(template); - const res = await fetch(`${OLLAMA_API_BASE_URL}/generate`, { + const res = await fetch(`${OLLAMA_API_BASE_URL}/api/generate`, { method: 'POST', headers: { - 'Content-Type': 'text/event-stream', + Accept: 'application/json', + 'Content-Type': 'application/json', Authorization: `Bearer ${token}` }, body: JSON.stringify({ @@ -186,10 +187,11 @@ export const generatePrompt = async (token: string = '', model: string, conversa conversation = '[no existing conversation]'; } - const res = await fetch(`${OLLAMA_API_BASE_URL}/generate`, { + const res = await fetch(`${OLLAMA_API_BASE_URL}/api/generate`, { method: 'POST', headers: { - 'Content-Type': 'text/event-stream', + Accept: 'application/json', + 'Content-Type': 'application/json', Authorization: `Bearer ${token}` }, body: JSON.stringify({ @@ -217,15 +219,43 @@ export const generatePrompt = async (token: string = '', model: string, conversa return res; }; +export const generateTextCompletion = async (token: string = '', model: string, text: string) => { + let error = null; + + const res = await fetch(`${OLLAMA_API_BASE_URL}/api/generate`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + model: model, + prompt: text, + stream: true + }) + }).catch((err) => { + error = err; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const generateChatCompletion = async (token: string = '', body: object) => { let controller = new AbortController(); let error = null; - const res = await fetch(`${OLLAMA_API_BASE_URL}/chat`, { + const res = await fetch(`${OLLAMA_API_BASE_URL}/api/chat`, { signal: controller.signal, method: 'POST', headers: { - 'Content-Type': 'text/event-stream', + Accept: 'application/json', + 'Content-Type': 'application/json', Authorization: `Bearer ${token}` }, body: JSON.stringify(body) @@ -265,10 +295,11 @@ export const cancelChatCompletion = async (token: string = '', requestId: string export const createModel = async (token: string, tagName: string, content: string) => { let error = null; - const res = await fetch(`${OLLAMA_API_BASE_URL}/create`, { + const res = await fetch(`${OLLAMA_API_BASE_URL}/api/create`, { method: 'POST', headers: { - 'Content-Type': 'text/event-stream', + Accept: 'application/json', + 'Content-Type': 'application/json', Authorization: `Bearer ${token}` }, body: JSON.stringify({ @@ -287,19 +318,23 @@ export const createModel = async (token: string, tagName: string, content: strin return res; }; -export const deleteModel = async (token: string, tagName: string) => { +export const deleteModel = async (token: string, tagName: string, urlIdx: string | null = null) => { let error = null; - const res = await fetch(`${OLLAMA_API_BASE_URL}/delete`, { - method: 'DELETE', - headers: { - 'Content-Type': 'text/event-stream', - Authorization: `Bearer ${token}` - }, - body: JSON.stringify({ - name: tagName - }) - }) + const res = await fetch( + `${OLLAMA_API_BASE_URL}/api/delete${urlIdx !== null ? `/${urlIdx}` : ''}`, + { + method: 'DELETE', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + name: tagName + }) + } + ) .then(async (res) => { if (!res.ok) throw await res.json(); return res.json(); @@ -310,7 +345,12 @@ export const deleteModel = async (token: string, tagName: string) => { }) .catch((err) => { console.log(err); - error = err.error; + error = err; + + if ('detail' in err) { + error = err.detail; + } + return null; }); @@ -321,13 +361,14 @@ export const deleteModel = async (token: string, tagName: string) => { return res; }; -export const pullModel = async (token: string, tagName: string) => { +export const pullModel = async (token: string, tagName: string, urlIdx: string | null = null) => { let error = null; - const res = await fetch(`${OLLAMA_API_BASE_URL}/pull`, { + const res = await fetch(`${OLLAMA_API_BASE_URL}/api/pull${urlIdx !== null ? `/${urlIdx}` : ''}`, { method: 'POST', headers: { - 'Content-Type': 'text/event-stream', + Accept: 'application/json', + 'Content-Type': 'application/json', Authorization: `Bearer ${token}` }, body: JSON.stringify({ diff --git a/src/lib/apis/openai/index.ts b/src/lib/apis/openai/index.ts index 3a629eb31c..e38314a550 100644 --- a/src/lib/apis/openai/index.ts +++ b/src/lib/apis/openai/index.ts @@ -1,9 +1,9 @@ import { OPENAI_API_BASE_URL } from '$lib/constants'; -export const getOpenAIUrl = async (token: string = '') => { +export const getOpenAIUrls = async (token: string = '') => { let error = null; - const res = await fetch(`${OPENAI_API_BASE_URL}/url`, { + const res = await fetch(`${OPENAI_API_BASE_URL}/urls`, { method: 'GET', headers: { Accept: 'application/json', @@ -29,13 +29,13 @@ export const getOpenAIUrl = async (token: string = '') => { throw error; } - return res.OPENAI_API_BASE_URL; + return res.OPENAI_API_BASE_URLS; }; -export const updateOpenAIUrl = async (token: string = '', url: string) => { +export const updateOpenAIUrls = async (token: string = '', urls: string[]) => { let error = null; - const res = await fetch(`${OPENAI_API_BASE_URL}/url/update`, { + const res = await fetch(`${OPENAI_API_BASE_URL}/urls/update`, { method: 'POST', headers: { Accept: 'application/json', @@ -43,7 +43,7 @@ export const updateOpenAIUrl = async (token: string = '', url: string) => { ...(token && { authorization: `Bearer ${token}` }) }, body: JSON.stringify({ - url: url + urls: urls }) }) .then(async (res) => { @@ -64,13 +64,13 @@ export const updateOpenAIUrl = async (token: string = '', url: string) => { throw error; } - return res.OPENAI_API_BASE_URL; + return res.OPENAI_API_BASE_URLS; }; -export const getOpenAIKey = async (token: string = '') => { +export const getOpenAIKeys = async (token: string = '') => { let error = null; - const res = await fetch(`${OPENAI_API_BASE_URL}/key`, { + const res = await fetch(`${OPENAI_API_BASE_URL}/keys`, { method: 'GET', headers: { Accept: 'application/json', @@ -96,13 +96,13 @@ export const getOpenAIKey = async (token: string = '') => { throw error; } - return res.OPENAI_API_KEY; + return res.OPENAI_API_KEYS; }; -export const updateOpenAIKey = async (token: string = '', key: string) => { +export const updateOpenAIKeys = async (token: string = '', keys: string[]) => { let error = null; - const res = await fetch(`${OPENAI_API_BASE_URL}/key/update`, { + const res = await fetch(`${OPENAI_API_BASE_URL}/keys/update`, { method: 'POST', headers: { Accept: 'application/json', @@ -110,7 +110,7 @@ export const updateOpenAIKey = async (token: string = '', key: string) => { ...(token && { authorization: `Bearer ${token}` }) }, body: JSON.stringify({ - key: key + keys: keys }) }) .then(async (res) => { @@ -131,7 +131,7 @@ export const updateOpenAIKey = async (token: string = '', key: string) => { throw error; } - return res.OPENAI_API_KEY; + return res.OPENAI_API_KEYS; }; export const getOpenAIModels = async (token: string = '') => { diff --git a/src/lib/apis/rag/index.ts b/src/lib/apis/rag/index.ts index ed36f0143c..668fe227be 100644 --- a/src/lib/apis/rag/index.ts +++ b/src/lib/apis/rag/index.ts @@ -1,9 +1,9 @@ import { RAG_API_BASE_URL } from '$lib/constants'; -export const getChunkParams = async (token: string) => { +export const getRAGConfig = async (token: string) => { let error = null; - const res = await fetch(`${RAG_API_BASE_URL}/chunk`, { + const res = await fetch(`${RAG_API_BASE_URL}/config`, { method: 'GET', headers: { 'Content-Type': 'application/json', @@ -27,18 +27,27 @@ export const getChunkParams = async (token: string) => { return res; }; -export const updateChunkParams = async (token: string, size: number, overlap: number) => { +type ChunkConfigForm = { + chunk_size: number; + chunk_overlap: number; +}; + +type RAGConfigForm = { + pdf_extract_images: boolean; + chunk: ChunkConfigForm; +}; + +export const updateRAGConfig = async (token: string, payload: RAGConfigForm) => { let error = null; - const res = await fetch(`${RAG_API_BASE_URL}/chunk/update`, { + const res = await fetch(`${RAG_API_BASE_URL}/config/update`, { method: 'POST', headers: { 'Content-Type': 'application/json', Authorization: `Bearer ${token}` }, body: JSON.stringify({ - chunk_size: size, - chunk_overlap: overlap + ...payload }) }) .then(async (res) => { @@ -85,17 +94,49 @@ export const getRAGTemplate = async (token: string) => { return res?.template ?? ''; }; -export const updateRAGTemplate = async (token: string, template: string) => { +export const getQuerySettings = async (token: string) => { + let error = null; + + const res = await fetch(`${RAG_API_BASE_URL}/query/settings`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +type QuerySettings = { + k: number | null; + template: string | null; +}; + +export const updateQuerySettings = async (token: string, settings: QuerySettings) => { let error = null; - const res = await fetch(`${RAG_API_BASE_URL}/template/update`, { + const res = await fetch(`${RAG_API_BASE_URL}/query/settings/update`, { method: 'POST', headers: { 'Content-Type': 'application/json', Authorization: `Bearer ${token}` }, body: JSON.stringify({ - template: template + ...settings }) }) .then(async (res) => { @@ -183,7 +224,7 @@ export const queryDoc = async ( token: string, collection_name: string, query: string, - k: number + k: number | null = null ) => { let error = null; @@ -220,7 +261,7 @@ export const queryCollection = async ( token: string, collection_names: string, query: string, - k: number + k: number | null = null ) => { let error = null; diff --git a/src/lib/apis/utils/index.ts b/src/lib/apis/utils/index.ts index ed4d4e0290..bcb554077c 100644 --- a/src/lib/apis/utils/index.ts +++ b/src/lib/apis/utils/index.ts @@ -21,3 +21,35 @@ export const getGravatarUrl = async (email: string) => { return res; }; + +export const downloadDatabase = async (token: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/utils/db/download`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + } + }) + .then((response) => { + if (!response.ok) { + throw new Error('Network response was not ok'); + } + return response.blob(); + }) + .then((blob) => { + const url = window.URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = 'webui.db'; + document.body.appendChild(a); + a.click(); + window.URL.revokeObjectURL(url); + }) + .catch((err) => { + console.log(err); + error = err; + return null; + }); +}; diff --git a/src/lib/components/admin/EditUserModal.svelte b/src/lib/components/admin/EditUserModal.svelte index 09005b30af..d8ceb1457e 100644 --- a/src/lib/components/admin/EditUserModal.svelte +++ b/src/lib/components/admin/EditUserModal.svelte @@ -1,5 +1,5 @@ + +
diff --git a/src/lib/components/admin/Settings/Users.svelte b/src/lib/components/admin/Settings/Users.svelte index 8a442c5199..9f2b5c40cc 100644 --- a/src/lib/components/admin/Settings/Users.svelte +++ b/src/lib/components/admin/Settings/Users.svelte @@ -1,10 +1,14 @@ @@ -21,6 +32,8 @@ on:submit|preventDefault={async () => { // console.log('submit'); await updateUserPermissions(localStorage.token, permissions); + + await updateModelFilterConfig(localStorage.token, whitelistEnabled, whitelistModels); saveHandler(); }} > @@ -69,6 +82,106 @@ + +