From 1cbd85308b8daf0ec54cefeb15761394bfd93de0 Mon Sep 17 00:00:00 2001 From: Sanjiv Das Date: Mon, 9 Dec 2024 15:48:40 -0800 Subject: [PATCH] Add base API URL field for Ollama and OpenAI embedding models (#1136) * Base API URL added for embedding models Jupyter AI currently allows the user to call a model at a URL (location) different from the default one by specifying a selected Base API URL. This can be done for Ollama and OpenAI provider models. However, for these providers, there is no way to change the API URL for embedding models when using the `/learn` command in RAG mode. This PR adds an extra field to make this feasible. Tested as follows for Ollama: [1] Start the Ollama system on port 11435 instead of 11434 (the default): `OLLAMA_HOST=127.0.0.1:11435 ollama serve` [2] Set the Base API URL: [3] Check that the new API URL works: * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * allow embedding model fields to be saved * exclude empty str fields from config manager * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: David L. 
Qiu --- .../partner_providers/ollama.py | 8 ++- .../partner_providers/openai.py | 10 ++- .../jupyter-ai/jupyter_ai/config_manager.py | 7 +++ .../src/components/chat-settings.tsx | 63 +++++++++++++------ 4 files changed, 65 insertions(+), 23 deletions(-) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/ollama.py b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/ollama.py index 5babc5adb..bf7d8474a 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/ollama.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/ollama.py @@ -1,7 +1,7 @@ from langchain_ollama import ChatOllama, OllamaEmbeddings from ..embedding_providers import BaseEmbeddingsProvider -from ..providers import BaseProvider, EnvAuthStrategy, TextField +from ..providers import BaseProvider, TextField class OllamaProvider(BaseProvider, ChatOllama): @@ -23,10 +23,14 @@ class OllamaEmbeddingsProvider(BaseEmbeddingsProvider, OllamaEmbeddings): id = "ollama" name = "Ollama" # source: https://ollama.com/library + model_id_key = "model" models = [ "nomic-embed-text", "mxbai-embed-large", "all-minilm", "snowflake-arctic-embed", ] - model_id_key = "model" + registry = True + fields = [ + TextField(key="base_url", label="Base API URL (optional)", format="text"), + ] diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openai.py b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openai.py index 7e97995d4..34ca76a8e 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openai.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openai.py @@ -107,6 +107,12 @@ class OpenAIEmbeddingsProvider(BaseEmbeddingsProvider, OpenAIEmbeddings): model_id_key = "model" pypi_package_deps = ["langchain_openai"] auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY") + registry = True + fields = [ + TextField( + key="openai_api_base", label="Base API URL (optional)", 
format="text" + ), + ] class AzureOpenAIEmbeddingsProvider(BaseEmbeddingsProvider, AzureOpenAIEmbeddings): @@ -122,5 +128,7 @@ class AzureOpenAIEmbeddingsProvider(BaseEmbeddingsProvider, AzureOpenAIEmbedding auth_strategy = EnvAuthStrategy( name="AZURE_OPENAI_API_KEY", keyword_param="openai_api_key" ) - registry = True + fields = [ + TextField(key="azure_endpoint", label="Base API URL (optional)", format="text"), + ] diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index 7b309faae..4732e152c 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -462,6 +462,13 @@ def _provider_params(self, key, listing, completions: bool = False): else: fields = config.fields.get(model_uid, {}) + # exclude empty fields + # TODO: modify the config manager to never save empty fields in the + # first place. + for field_key in fields: + if isinstance(fields[field_key], str) and not len(fields[field_key]): + fields[field_key] = None + # get authn fields _, Provider = get_em_provider(model_uid, listing) authn_fields = {} diff --git a/packages/jupyter-ai/src/components/chat-settings.tsx b/packages/jupyter-ai/src/components/chat-settings.tsx index a1ad0a9b6..c32eb46fd 100644 --- a/packages/jupyter-ai/src/components/chat-settings.tsx +++ b/packages/jupyter-ai/src/components/chat-settings.tsx @@ -88,6 +88,9 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { const [apiKeys, setApiKeys] = useState>({}); const [sendWse, setSendWse] = useState(false); const [fields, setFields] = useState>({}); + const [embeddingModelFields, setEmbeddingModelFields] = useState< + Record + >({}); const [isCompleterEnabled, setIsCompleterEnabled] = useState( props.completionProvider && props.completionProvider.isEnabled() @@ -188,7 +191,15 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { const currFields: Record = 
server.config.fields?.[lmGlobalId] ?? {}; setFields(currFields); - }, [server, lmProvider]); + + if (!emGlobalId) { + return; + } + + const initEmbeddingModelFields: Record = + server.config.fields?.[emGlobalId] ?? {}; + setEmbeddingModelFields(initEmbeddingModelFields); + }, [server, lmGlobalId, emGlobalId]); const handleSave = async () => { // compress fields with JSON values @@ -222,6 +233,9 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { }), ...(clmGlobalId && { [clmGlobalId]: fields + }), + ...(emGlobalId && { + [emGlobalId]: embeddingModelFields }) } }), @@ -376,26 +390,35 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { {/* Embedding model section */}

Embedding model

{server.emProviders.providers.length > 0 ? ( - { + const emGid = e.target.value === 'null' ? null : e.target.value; + setEmGlobalId(emGid); + }} + MenuProps={{ sx: { maxHeight: '50%', minHeight: 400 } }} + > + None + {server.emProviders.providers.map(emp => + emp.models + .filter(em => em !== '*') // TODO: support registry providers + .map(em => ( + + {emp.name} :: {em} + + )) + )} + + {emGlobalId && ( + )} - + ) : (

No embedding models available.

)}