diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py index c07061b1d..091b78fea 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py @@ -239,6 +239,16 @@ class Config: provider is selected. """ + @classmethod + def chat_models(self): + """Models which are suitable for chat.""" + return self.models + + @classmethod + def completion_models(self): + """Models which are suitable for completions.""" + return self.models + # # instance attrs # diff --git a/packages/jupyter-ai/jupyter_ai/completions/handlers/llm_mixin.py b/packages/jupyter-ai/jupyter_ai/completions/handlers/llm_mixin.py index 1371e3cbf..fa16920d2 100644 --- a/packages/jupyter-ai/jupyter_ai/completions/handlers/llm_mixin.py +++ b/packages/jupyter-ai/jupyter_ai/completions/handlers/llm_mixin.py @@ -27,8 +27,8 @@ def __init__(self, *args, **kwargs): self.llm_chain = None def get_llm_chain(self): - lm_provider = self.config_manager.lm_provider - lm_provider_params = self.config_manager.lm_provider_params + lm_provider = self.config_manager.completions_lm_provider + lm_provider_params = self.config_manager.completions_lm_provider_params if not lm_provider or not lm_provider_params: return None diff --git a/packages/jupyter-ai/jupyter_ai/config/config_schema.json b/packages/jupyter-ai/jupyter_ai/config/config_schema.json index ff7c717c4..0dd910ab9 100644 --- a/packages/jupyter-ai/jupyter_ai/config/config_schema.json +++ b/packages/jupyter-ai/jupyter_ai/config/config_schema.json @@ -16,6 +16,18 @@ "default": null, "readOnly": false }, + "completions_model_provider_id": { + "$comment": "Language model global ID for completions.", + "type": ["string", "null"], + "default": null, + "readOnly": false + }, + "completions_embeddings_provider_id": { + "$comment": "Embedding model global ID for completions.", + "type": ["string", "null"], + "default": null, + "readOnly": false + }, "api_keys": { "$comment": "Dictionary of API keys, mapping key names to key values.", "type": "object", @@ -37,6 +49,17 @@ } }, "additionalProperties": false + }, + "completions_fields": { + "$comment": "Dictionary of model-specific fields, mapping LM GIDs to sub-dictionaries of field key-value pairs for completions.", + "type": "object", + "default": {}, + "patternProperties": { + "^.*$": { + "anyOf": [{ "type": "object" }] + } + }, + "additionalProperties": false } } } diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index 392e44601..5969c9f90 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -164,15 +164,22 @@ def _process_existing_config(self, default_config): {k: v for k, v in existing_config.items() if v is not None}, ) config = GlobalConfig(**merged_config) - validated_config = self._validate_lm_em_id(config) + validated_config = self._validate_lm_em_id( + config, lm_key="model_provider_id", em_key="embeddings_provider_id" + ) + validated_config = self._validate_lm_em_id( + config, + lm_key="completions_model_provider_id", + em_key="completions_embeddings_provider_id", + ) # re-write to the file to validate the config and apply any # updates to the config file immediately self._write_config(validated_config) - def _validate_lm_em_id(self, config): - lm_id = config.model_provider_id - em_id = config.embeddings_provider_id + def _validate_lm_em_id(self, config, lm_key, em_key): + lm_id = getattr(config, lm_key) + em_id = getattr(config, em_key) # if the currently selected language or embedding model are # forbidden, set them to `None` and log a warning. @@ -180,12 +187,12 @@ def _validate_lm_em_id(self, config): self.log.warning( f"Language model {lm_id} is forbidden by current allow/blocklists. Setting to None." ) - config.model_provider_id = None + setattr(config, lm_key, None) if em_id is not None and not self._validate_model(em_id, raise_exc=False): self.log.warning( f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None." ) - config.embeddings_provider_id = None + setattr(config, em_key, None) # if the currently selected language or embedding model ids are # not associated with models, set them to `None` and log a warning. @@ -193,12 +200,12 @@ def _validate_lm_em_id(self, config): self.log.warning( f"No language model is associated with '{lm_id}'. Setting to None." ) - config.model_provider_id = None + setattr(config, lm_key, None) if em_id is not None and not get_em_provider(em_id, self._em_providers)[1]: self.log.warning( f"No embedding model is associated with '{em_id}'. Setting to None." ) - config.embeddings_provider_id = None + setattr(config, em_key, None) return config @@ -321,6 +328,9 @@ def _write_config(self, new_config: GlobalConfig): complete `GlobalConfig` object, and should not be called publicly.""" # remove any empty field dictionaries new_config.fields = {k: v for k, v in new_config.fields.items() if v} + new_config.completions_fields = { + k: v for k, v in new_config.completions_fields.items() if v + } self._validate_config(new_config) with open(self.config_path, "w") as f: @@ -328,21 +338,19 @@ def _write_config(self, new_config: GlobalConfig): def delete_api_key(self, key_name: str): config_dict = self._read_config().dict() - lm_provider = self.lm_provider - em_provider = self.em_provider required_keys = [] - if ( - lm_provider - and lm_provider.auth_strategy - and lm_provider.auth_strategy.type == "env" - ): - required_keys.append(lm_provider.auth_strategy.name) - if ( - em_provider - and em_provider.auth_strategy - and em_provider.auth_strategy.type == "env" - ): - required_keys.append(self.em_provider.auth_strategy.name) + for provider in [ + self.lm_provider, + self.em_provider, + self.completions_lm_provider, + self.completions_em_provider, + ]: + if ( + provider + and provider.auth_strategy + and provider.auth_strategy.type == "env" + ): + required_keys.append(provider.auth_strategy.name) if key_name in required_keys: raise KeyInUseError( @@ -390,67 +398,69 @@ def em_gid(self): @property def lm_provider(self): - config = self._read_config() - lm_gid = config.model_provider_id - if lm_gid is None: - return None - - _, Provider = get_lm_provider(config.model_provider_id, self._lm_providers) - return Provider + return self._get_provider("model_provider_id", self._lm_providers) @property def em_provider(self): + return self._get_provider("embeddings_provider_id", self._em_providers) + + @property + def completions_lm_provider(self): + return self._get_provider("completions_model_provider_id", self._lm_providers) + + @property + def completions_em_provider(self): + return self._get_provider( + "completions_embeddings_provider_id", self._em_providers + ) + + def _get_provider(self, key, listing): config = self._read_config() - em_gid = config.embeddings_provider_id - if em_gid is None: + gid = getattr(config, key) + if gid is None: return None - _, Provider = get_em_provider(em_gid, self._em_providers) + _, Provider = get_lm_provider(gid, listing) return Provider @property def lm_provider_params(self): - # get generic fields - config = self._read_config() - lm_gid = config.model_provider_id - if not lm_gid: - return None + return self._provider_params("model_provider_id", self._lm_providers) - lm_lid = lm_gid.split(":", 1)[1] - fields = config.fields.get(lm_gid, {}) - - # get authn fields - _, Provider = get_lm_provider(lm_gid, self._lm_providers) - authn_fields = {} - if Provider.auth_strategy and Provider.auth_strategy.type == "env": - key_name = Provider.auth_strategy.name - authn_fields[key_name.lower()] = config.api_keys[key_name] + @property + def em_provider_params(self): + return self._provider_params("embeddings_provider_id", self._em_providers) - return { - "model_id": lm_lid, - **fields, - **authn_fields, - } + @property + def completions_lm_provider_params(self): + return self._provider_params( + "completions_model_provider_id", self._lm_providers + ) @property - def em_provider_params(self): + def completions_em_provider_params(self): + return self._provider_params( + "completions_embeddings_provider_id", self._em_providers + ) + + def _provider_params(self, key, listing): # get generic fields config = self._read_config() - em_gid = config.embeddings_provider_id - if not em_gid: + gid = getattr(config, key) + if not gid: return None - em_lid = em_gid.split(":", 1)[1] + lid = gid.split(":", 1)[1] # get authn fields - _, Provider = get_em_provider(em_gid, self._em_providers) + _, Provider = get_em_provider(gid, listing) authn_fields = {} if Provider.auth_strategy and Provider.auth_strategy.type == "env": key_name = Provider.auth_strategy.name authn_fields[key_name.lower()] = config.api_keys[key_name] return { - "model_id": em_lid, + "model_id": lid, **authn_fields, } diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index 6002310ba..976b7df62 100644 --- a/packages/jupyter-ai/jupyter_ai/handlers.py +++ b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -290,6 +290,10 @@ def filter_predicate(local_model_id: str): # filter out every model w/ model ID according to allow/blocklist for provider in providers: provider.models = list(filter(filter_predicate, provider.models)) + provider.chat_models = list(filter(filter_predicate, provider.chat_models)) + provider.completion_models = list( + filter(filter_predicate, provider.completion_models) + ) # filter out every provider with no models which satisfy the allow/blocklist, then return return filter((lambda p: len(p.models) > 0), providers) @@ -311,6 +315,8 @@ def get(self): id=provider.id, name=provider.name, models=provider.models, + chat_models=provider.chat_models(), + completion_models=provider.completion_models(), help=provider.help, auth_strategy=provider.auth_strategy, registry=provider.registry, diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index 32353a694..5675c6366 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -94,6 +94,8 @@ class ListProvidersEntry(BaseModel): auth_strategy: AuthStrategy registry: bool fields: List[Field] + chat_models: Optional[List[str]] + completion_models: Optional[List[str]] class ListProvidersResponse(BaseModel): @@ -121,6 +123,9 @@ class DescribeConfigResponse(BaseModel): # timestamp indicating when the configuration file was last read. should be # passed to the subsequent UpdateConfig request. last_read: int + completions_model_provider_id: Optional[str] + completions_embeddings_provider_id: Optional[str] + completions_fields: Dict[str, Dict[str, Any]] def forbid_none(cls, v): @@ -137,6 +142,9 @@ class UpdateConfigRequest(BaseModel): # if passed, this will raise an Error if the config was written to after the # time specified by `last_read` to prevent write-write conflicts. last_read: Optional[int] + completions_model_provider_id: Optional[str] + completions_embeddings_provider_id: Optional[str] + completions_fields: Optional[Dict[str, Dict[str, Any]]] _validate_send_wse = validator("send_with_shift_enter", allow_reuse=True)( forbid_none @@ -154,3 +162,6 @@ class GlobalConfig(BaseModel): send_with_shift_enter: bool fields: Dict[str, Dict[str, Any]] api_keys: Dict[str, str] + completions_model_provider_id: Optional[str] + completions_embeddings_provider_id: Optional[str] + completions_fields: Dict[str, Dict[str, Any]] diff --git a/packages/jupyter-ai/src/completions/plugin.ts b/packages/jupyter-ai/src/completions/plugin.ts index 835561143..c2519fd9e 100644 --- a/packages/jupyter-ai/src/completions/plugin.ts +++ b/packages/jupyter-ai/src/completions/plugin.ts @@ -4,6 +4,8 @@ import { } from '@jupyterlab/application'; import { ICompletionProviderManager } from '@jupyterlab/completer'; import { ISettingRegistry } from '@jupyterlab/settingregistry'; +import { MainAreaWidget } from '@jupyterlab/apputils'; +import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; import { IEditorLanguageRegistry, IEditorLanguage @@ -12,6 +14,8 @@ import { getEditor } from '../selection-watcher'; import { IJaiStatusItem } from '../tokens'; import { displayName, JaiInlineProvider } from './provider'; import { CompletionWebsocketHandler } from './handler'; +import { jupyternautIcon } from '../icons'; +import { ModelSettingsWidget } from './settings'; export namespace CommandIDs { /** @@ -23,6 +27,10 @@ export namespace CommandIDs { */ export const toggleLanguageCompletions = 'jupyter-ai:toggle-language-completions'; + /** + * Command to open provider/model configuration. + */ + export const configureModel = 'jupyter-ai:configure-completions'; } const INLINE_COMPLETER_PLUGIN = @@ -52,7 +60,8 @@ export const completionPlugin: JupyterFrontEndPlugin = { requires: [ ICompletionProviderManager, IEditorLanguageRegistry, - ISettingRegistry + ISettingRegistry, + IRenderMimeRegistry ], optional: [IJaiStatusItem], activate: async ( @@ -60,6 +69,7 @@ export const completionPlugin: JupyterFrontEndPlugin = { completionManager: ICompletionProviderManager, languageRegistry: IEditorLanguageRegistry, settingRegistry: ISettingRegistry, + rmRegistry: IRenderMimeRegistry, statusItem: IJaiStatusItem | null ): Promise => { if (typeof completionManager.registerInlineProvider === 'undefined') { @@ -176,6 +186,37 @@ export const completionPlugin: JupyterFrontEndPlugin = { } }); + let settingsWidget: MainAreaWidget | null = null; + const newSettingsWidget = () => { + const content = new ModelSettingsWidget({ + rmRegistry, + isProviderEnabled: () => provider.isEnabled(), + openInlineCompleterSettings: () => { + app.commands.execute('settingeditor:open', { + query: 'Inline Completer' + }); + } + }); + const widget = new MainAreaWidget({ content }); + widget.id = 'jupyterlab-inline-completions-model'; + widget.title.label = 'AI Completions Model Settings'; + widget.title.closable = true; + widget.title.icon = jupyternautIcon; + return widget; + }; + app.commands.addCommand(CommandIDs.configureModel, { + execute: () => { + if (!settingsWidget || settingsWidget.isDisposed) { + settingsWidget = newSettingsWidget(); + } + if (!settingsWidget.isAttached) { + app.shell.add(settingsWidget, 'main'); + } + app.shell.activateById(settingsWidget.id); + }, + label: 'Configure Jupyternaut Completions Model' + }); + if (statusItem) { statusItem.addItem({ command: CommandIDs.toggleCompletions, @@ -185,6 +226,10 @@ export const completionPlugin: JupyterFrontEndPlugin = { command: CommandIDs.toggleLanguageCompletions, rank: 2 }); + statusItem.addItem({ + command: CommandIDs.configureModel, + rank: 3 + }); } } }; diff --git a/packages/jupyter-ai/src/completions/settings.tsx b/packages/jupyter-ai/src/completions/settings.tsx new file mode 100644 index 000000000..23e0efb24 --- /dev/null +++ b/packages/jupyter-ai/src/completions/settings.tsx @@ -0,0 +1,181 @@ +import { ReactWidget } from '@jupyterlab/ui-components'; +import React, { useState } from 'react'; + +import { Box } from '@mui/system'; +import { Alert, Button, CircularProgress } from '@mui/material'; + +import { AiService } from '../handler'; +import { + ServerInfoState, + useServerInfo +} from '../components/settings/use-server-info'; +import { ModelSettings, IModelSettings } from '../components/model-settings'; +import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; +import { minifyUpdate } from '../components/settings/minify'; +import { useStackingAlert } from '../components/mui-extras/stacking-alert'; + +type CompleterSettingsProps = { + rmRegistry: IRenderMimeRegistry; + isProviderEnabled: () => boolean; + openInlineCompleterSettings: () => void; +}; + +/** + * Component that returns the settings view. + */ +export function CompleterSettings(props: CompleterSettingsProps): JSX.Element { + // state fetched on initial render + const server = useServerInfo(); + + // initialize alert helper + const alert = useStackingAlert(); + + // whether the form is currently saving + const [saving, setSaving] = useState(false); + + // provider/model settings + const [modelSettings, setModelSettings] = useState({ + fields: {}, + apiKeys: {}, + emGlobalId: null, + lmGlobalId: null + }); + + const handleSave = async () => { + // compress fields with JSON values + if (server.state !== ServerInfoState.Ready) { + return; + } + + const { fields, lmGlobalId, emGlobalId, apiKeys } = modelSettings; + + for (const fieldKey in fields) { + const fieldVal = fields[fieldKey]; + if (typeof fieldVal !== 'string' || !fieldVal.trim().startsWith('{')) { + continue; + } + + try { + const parsedFieldVal = JSON.parse(fieldVal); + const compressedFieldVal = JSON.stringify(parsedFieldVal); + fields[fieldKey] = compressedFieldVal; + } catch (e) { + continue; + } + } + + let updateRequest: AiService.UpdateConfigRequest = { + completions_model_provider_id: lmGlobalId, + completions_embeddings_provider_id: emGlobalId, + api_keys: apiKeys, + ...(lmGlobalId && { + completions_fields: { + [lmGlobalId]: fields + } + }) + }; + updateRequest = minifyUpdate(server.config, updateRequest); + updateRequest.last_read = server.config.last_read; + + setSaving(true); + try { + await AiService.updateConfig(updateRequest); + } catch (e) { + console.error(e); + const msg = + e instanceof Error || typeof e === 'string' + ? e.toString() + : 'An unknown error occurred. Check the console for more details.'; + alert.show('error', msg); + return; + } finally { + setSaving(false); + } + await server.refetchAll(); + alert.show('success', 'Settings saved successfully.'); + }; + + if (server.state === ServerInfoState.Loading) { + return ( + + + + ); + } + + if (server.state === ServerInfoState.Error) { + return ( + + + {server.error || + 'An unknown error occurred. Check the console for more details.'} + + + ); + } + + return ( + + {props.isProviderEnabled() ? null : ( + + The jupyter-ai inline completion provider is not enabled in the Inline + Completer settings. + + + )} + + + + + + + {alert.jsx} + + ); +} + +export class ModelSettingsWidget extends ReactWidget { + constructor(protected options: CompleterSettingsProps) { + super(); + } + render(): JSX.Element { + return ; + } +} diff --git a/packages/jupyter-ai/src/components/chat-settings.tsx b/packages/jupyter-ai/src/components/chat-settings.tsx index 889342a24..2a2da034e 100644 --- a/packages/jupyter-ai/src/components/chat-settings.tsx +++ b/packages/jupyter-ai/src/components/chat-settings.tsx @@ -1,4 +1,4 @@ -import React, { useEffect, useState, useMemo } from 'react'; +import React, { useEffect, useState } from 'react'; import { Box } from '@mui/system'; import { @@ -7,23 +7,17 @@ import { FormControl, FormControlLabel, FormLabel, - MenuItem, Radio, RadioGroup, - TextField, CircularProgress } from '@mui/material'; -import { Select } from './select'; import { AiService } from '../handler'; -import { ModelFields } from './settings/model-fields'; import { ServerInfoState, useServerInfo } from './settings/use-server-info'; -import { ExistingApiKeys } from './settings/existing-api-keys'; +import { ModelSettings, IModelSettings } from './model-settings'; import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; import { minifyUpdate } from './settings/minify'; import { useStackingAlert } from './mui-extras/stacking-alert'; -import { RendermimeMarkdown } from './rendermime-markdown'; -import { getProviderId, getModelLocalId } from '../utils'; type ChatSettingsProps = { rmRegistry: IRenderMimeRegistry; @@ -38,38 +32,21 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { // initialize alert helper const alert = useStackingAlert(); - const apiKeysAlert = useStackingAlert(); // user inputs - const [lmProvider, setLmProvider] = - useState(null); - const [showLmLocalId, setShowLmLocalId] = useState(false); - const [helpMarkdown, setHelpMarkdown] = useState(null); - const [lmLocalId, setLmLocalId] = useState(''); - const lmGlobalId = useMemo(() => { - if (!lmProvider) { - return null; - } - - return lmProvider.id + ':' + lmLocalId; - }, [lmProvider, lmLocalId]); - - const [emGlobalId, setEmGlobalId] = useState(null); - const emProvider = useMemo(() => { - if (emGlobalId === null || server.state !== ServerInfoState.Ready) { - return null; - } - - return getProvider(emGlobalId, server.emProviders); - }, [emGlobalId, server]); - - const [apiKeys, setApiKeys] = useState>({}); const [sendWse, setSendWse] = useState(false); - const [fields, setFields] = useState>({}); // whether the form is currently saving const [saving, setSaving] = useState(false); + // provider/model settings + const [modelSettings, setModelSettings] = useState({ + fields: {}, + apiKeys: {}, + emGlobalId: null, + lmGlobalId: null + }); + /** * Effect: initialize inputs after fetching server info. */ @@ -77,79 +54,17 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { if (server.state !== ServerInfoState.Ready) { return; } - - setLmLocalId(server.lmLocalId); - setEmGlobalId(server.config.embeddings_provider_id); setSendWse(server.config.send_with_shift_enter); - setHelpMarkdown(server.lmProvider?.help ?? null); - if (server.lmProvider?.registry) { - setShowLmLocalId(true); - } - setLmProvider(server.lmProvider); }, [server]); - /** - * Effect: re-initialize apiKeys object whenever the selected LM/EM changes. - * Properties with a value of '' indicate necessary user input. - */ - useEffect(() => { - if (server.state !== ServerInfoState.Ready) { - return; - } - - const newApiKeys: Record = {}; - const lmAuth = lmProvider?.auth_strategy; - const emAuth = emProvider?.auth_strategy; - if ( - lmAuth?.type === 'env' && - !server.config.api_keys.includes(lmAuth.name) - ) { - newApiKeys[lmAuth.name] = ''; - } - if (lmAuth?.type === 'multienv') { - lmAuth.names.forEach(apiKey => { - if (!server.config.api_keys.includes(apiKey)) { - newApiKeys[apiKey] = ''; - } - }); - } - - if ( - emAuth?.type === 'env' && - !server.config.api_keys.includes(emAuth.name) - ) { - newApiKeys[emAuth.name] = ''; - } - if (emAuth?.type === 'multienv') { - emAuth.names.forEach(apiKey => { - if (!server.config.api_keys.includes(apiKey)) { - newApiKeys[apiKey] = ''; - } - }); - } - - setApiKeys(newApiKeys); - }, [lmProvider, emProvider, server]); - - /** - * Effect: re-initialize fields object whenever the selected LM changes. - */ - useEffect(() => { - if (server.state !== ServerInfoState.Ready || !lmGlobalId) { - return; - } - - const currFields: Record = - server.config.fields?.[lmGlobalId] ?? {}; - setFields(currFields); - }, [server, lmProvider]); - const handleSave = async () => { // compress fields with JSON values if (server.state !== ServerInfoState.Ready) { return; } + const { fields, lmGlobalId, emGlobalId, apiKeys } = modelSettings; + for (const fieldKey in fields) { const fieldVal = fields[fieldKey]; if (typeof fieldVal !== 'string' || !fieldVal.trim().startsWith('{')) { @@ -181,7 +96,6 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { setSaving(true); try { - await apiKeysAlert.clear(); await AiService.updateConfig(updateRequest); } catch (e) { console.error(e); @@ -244,112 +158,11 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { overflowY: 'auto' }} > - {/* Language model section */} -

Language model

- - {showLmLocalId && ( - setLmLocalId(e.target.value)} - fullWidth - /> - )} - {helpMarkdown && ( - - )} - {lmGlobalId && ( - - )} - - {/* Embedding model section */} -

Embedding model

- - - {/* API Keys section */} -

API Keys

- {/* API key inputs for newly-used providers */} - {Object.entries(apiKeys).map(([apiKeyName, apiKeyValue], idx) => ( - - setApiKeys(apiKeys => ({ - ...apiKeys, - [apiKeyName]: e.target.value - })) - } - /> - ))} - {/* Pre-existing API keys */} - {/* Input */} @@ -391,12 +204,3 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { ); } - -function getProvider( - globalModelId: string, - providers: AiService.ListProvidersResponse -): AiService.ListProvidersEntry | null { - const providerId = getProviderId(globalModelId); - const provider = providers.providers.find(p => p.id === providerId); - return provider ?? null; -} diff --git a/packages/jupyter-ai/src/components/model-settings.tsx b/packages/jupyter-ai/src/components/model-settings.tsx new file mode 100644 index 000000000..e99aa7480 --- /dev/null +++ b/packages/jupyter-ai/src/components/model-settings.tsx @@ -0,0 +1,308 @@ +import React, { useEffect, useState, useMemo } from 'react'; + +import { Box } from '@mui/system'; +import { Alert, MenuItem, TextField, CircularProgress } from '@mui/material'; + +import { Select } from './select'; +import { AiService } from '../handler'; +import { ModelFields } from './settings/model-fields'; +import { ServerInfoState, useServerInfo } from './settings/use-server-info'; +import { ExistingApiKeys } from './settings/existing-api-keys'; +import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; +import { useStackingAlert } from './mui-extras/stacking-alert'; +import { RendermimeMarkdown } from './rendermime-markdown'; +import { getProviderId, getModelLocalId } from '../utils'; + +type ModelSettingsProps = { + rmRegistry: IRenderMimeRegistry; + label: string; + onChange: (settings: IModelSettings) => void; + modelKind: 'chat' | 'completions'; +}; + +export interface IModelSettings { + fields: Record; + apiKeys: Record; + emGlobalId: string | null; + lmGlobalId: string | null; +} + +/** + * Component that returns the settings view in the chat panel. + */ +export function ModelSettings(props: ModelSettingsProps): JSX.Element { + // state fetched on initial render + const server = useServerInfo(); + + // initialize alert helper + const apiKeysAlert = useStackingAlert(); + + // user inputs + const [lmProvider, setLmProvider] = + useState(null); + const [showLmLocalId, setShowLmLocalId] = useState(false); + const [helpMarkdown, setHelpMarkdown] = useState(null); + const [lmLocalId, setLmLocalId] = useState(''); + const lmGlobalId = useMemo(() => { + if (!lmProvider) { + return null; + } + + return lmProvider.id + ':' + lmLocalId; + }, [lmProvider, lmLocalId]); + + const [emGlobalId, setEmGlobalId] = + useState(null); + const emProvider = useMemo(() => { + if (emGlobalId === null || server.state !== ServerInfoState.Ready) { + return null; + } + + return getProvider(emGlobalId, server.emProviders); + }, [emGlobalId, server]); + + const [apiKeys, setApiKeys] = useState({}); + const [fields, setFields] = useState({}); + + /** + * Effect: initialize inputs after fetching server info. + */ + useEffect(() => { + if (server.state !== ServerInfoState.Ready) { + return; + } + const kind = props.modelKind; + + setLmLocalId(server[kind].lmLocalId); + setEmGlobalId( + kind === 'chat' + ? server.config.embeddings_provider_id + : server.config.completions_embeddings_provider_id + ); + setHelpMarkdown(server[kind].lmProvider?.help ?? null); + if (server[kind].lmProvider?.registry) { + setShowLmLocalId(true); + } + setLmProvider(server[kind].lmProvider); + }, [server]); + + /** + * Effect: re-initialize apiKeys object whenever the selected LM/EM changes. + * Properties with a value of '' indicate necessary user input. + */ + useEffect(() => { + if (server.state !== ServerInfoState.Ready) { + return; + } + + const newApiKeys: Record = {}; + const lmAuth = lmProvider?.auth_strategy; + const emAuth = emProvider?.auth_strategy; + if ( + lmAuth?.type === 'env' && + !server.config.api_keys.includes(lmAuth.name) + ) { + newApiKeys[lmAuth.name] = ''; + } + if (lmAuth?.type === 'multienv') { + lmAuth.names.forEach(apiKey => { + if (!server.config.api_keys.includes(apiKey)) { + newApiKeys[apiKey] = ''; + } + }); + } + + if ( + emAuth?.type === 'env' && + !server.config.api_keys.includes(emAuth.name) + ) { + newApiKeys[emAuth.name] = ''; + } + if (emAuth?.type === 'multienv') { + emAuth.names.forEach(apiKey => { + if (!server.config.api_keys.includes(apiKey)) { + newApiKeys[apiKey] = ''; + } + }); + } + + setApiKeys(newApiKeys); + }, [lmProvider, emProvider, server]); + + /** + * Effect: re-initialize fields object whenever the selected LM changes. + */ + useEffect(() => { + if (server.state !== ServerInfoState.Ready || !lmGlobalId) { + return; + } + + const currFields: Record = + server.config.fields?.[lmGlobalId] ?? {}; + setFields(currFields); + }, [server, lmProvider]); + + useEffect(() => { + props.onChange({ + fields, + apiKeys, + lmGlobalId, + emGlobalId + }); + }, [lmProvider, emProvider, apiKeys, fields]); + + if (server.state === ServerInfoState.Loading) { + return ( + + + + ); + } + + if (server.state === ServerInfoState.Error) { + return ( + <> + + {server.error || + 'An unknown error occurred. Check the console for more details.'} + + + ); + } + + return ( + <> + {/* Language model section */} +

{props.label}

+ + {showLmLocalId && ( + setLmLocalId(e.target.value)} + fullWidth + /> + )} + {helpMarkdown && ( + + )} + {lmGlobalId && ( + + )} + + {/* Embedding model section */} +

Embedding model

+ + + {/* API Keys section */} +

API Keys

+ {/* API key inputs for newly-used providers */} + {Object.entries(apiKeys).map(([apiKeyName, apiKeyValue], idx) => ( + + setApiKeys(apiKeys => ({ + ...apiKeys, + [apiKeyName]: e.target.value + })) + } + /> + ))} + {/* Pre-existing API keys */} + + + ); +} + +function getProvider( + globalModelId: string, + providers: AiService.ListProvidersResponse +): AiService.ListProvidersEntry | null { + const providerId = getProviderId(globalModelId); + const provider = providers.providers.find(p => p.id === providerId); + return provider ?? null; +} diff --git a/packages/jupyter-ai/src/components/settings/use-server-info.ts b/packages/jupyter-ai/src/components/settings/use-server-info.ts index 80c77fbed..3695bfe41 100644 --- a/packages/jupyter-ai/src/components/settings/use-server-info.ts +++ b/packages/jupyter-ai/src/components/settings/use-server-info.ts @@ -2,15 +2,20 @@ import { useState, useEffect, useMemo, useCallback } from 'react'; import { AiService } from '../../handler'; import { getProviderId, getModelLocalId } from '../../utils'; -type ServerInfoProperties = { - config: AiService.DescribeConfigResponse; - lmProviders: AiService.ListProvidersResponse; - emProviders: AiService.ListProvidersResponse; +type ProvidersInfo = { lmProvider: AiService.ListProvidersEntry | null; emProvider: AiService.ListProvidersEntry | null; lmLocalId: string; }; +type ServerInfoProperties = { + lmProviders: AiService.ListProvidersResponse; + emProviders: AiService.ListProvidersResponse; + config: AiService.DescribeConfigResponse; + chat: ProvidersInfo; + completions: ProvidersInfo; +}; + type ServerInfoMethods = { refetchAll: () => Promise; refetchApiKeys: () => Promise; @@ -65,13 +70,29 @@ export function useServerInfo(): ServerInfo { const emProvider = emGid === null ? null : getProvider(emGid, emProviders); const lmLocalId = (lmGid && getModelLocalId(lmGid)) ?? ''; + + const cLmGid = config.completions_model_provider_id; + const cEmGid = config.completions_embeddings_provider_id; + const cLmProvider = + cLmGid === null ? null : getProvider(cLmGid, lmProviders); + const cEmProvider = + cEmGid === null ? null : getProvider(cEmGid, emProviders); + const cLmLocalId = (cLmGid && getModelLocalId(cLmGid)) ?? ''; + setServerInfoProps({ config, lmProviders, emProviders, - lmProvider, - emProvider, - lmLocalId + chat: { + lmProvider, + emProvider, + lmLocalId + }, + completions: { + lmProvider: cLmProvider, + emProvider: cEmProvider, + lmLocalId: cLmLocalId + } }); setState(ServerInfoState.Ready); diff --git a/packages/jupyter-ai/src/handler.ts b/packages/jupyter-ai/src/handler.ts index 7848dc20e..600a257f6 100644 --- a/packages/jupyter-ai/src/handler.ts +++ b/packages/jupyter-ai/src/handler.ts @@ -117,6 +117,8 @@ export namespace AiService { send_with_shift_enter: boolean; fields: Record>; last_read: number; + completions_model_provider_id: string | null; + completions_embeddings_provider_id: string | null; }; export type UpdateConfigRequest = { @@ -126,6 +128,9 @@ export namespace AiService { send_with_shift_enter?: boolean; fields?: Record>; last_read?: number; + completions_model_provider_id?: string | null; + completions_embeddings_provider_id?: string | null; + completions_fields?: Record>; }; export async function getConfig(): Promise { @@ -182,6 +187,8 @@ export namespace AiService { help?: string; auth_strategy: AuthStrategy; registry: boolean; + completion_models: string[]; + chat_models: string[]; fields: Field[]; };