From 575f3f48cc854cd247459eeea5357d266171577a Mon Sep 17 00:00:00 2001 From: krassowski <5832902+krassowski@users.noreply.github.com> Date: Wed, 17 Apr 2024 12:54:11 +0100 Subject: [PATCH] Implement the completion model selection in chat UI --- packages/jupyter-ai/src/completions/plugin.ts | 296 ++++++-------- .../jupyter-ai/src/completions/provider.ts | 12 +- .../jupyter-ai/src/completions/settings.tsx | 178 --------- .../src/components/chat-settings.tsx | 378 +++++++++++++++++- packages/jupyter-ai/src/components/chat.tsx | 9 +- .../src/components/model-settings.tsx | 308 -------------- packages/jupyter-ai/src/index.ts | 21 +- packages/jupyter-ai/src/tokens.ts | 14 + .../jupyter-ai/src/widgets/chat-sidebar.tsx | 7 +- 9 files changed, 543 insertions(+), 680 deletions(-) delete mode 100644 packages/jupyter-ai/src/completions/settings.tsx delete mode 100644 packages/jupyter-ai/src/components/model-settings.tsx diff --git a/packages/jupyter-ai/src/completions/plugin.ts b/packages/jupyter-ai/src/completions/plugin.ts index 3b2b6f2cf..adf3db556 100644 --- a/packages/jupyter-ai/src/completions/plugin.ts +++ b/packages/jupyter-ai/src/completions/plugin.ts @@ -4,18 +4,14 @@ import { } from '@jupyterlab/application'; import { ICompletionProviderManager } from '@jupyterlab/completer'; import { ISettingRegistry } from '@jupyterlab/settingregistry'; -import { MainAreaWidget } from '@jupyterlab/apputils'; -import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; import { IEditorLanguageRegistry, IEditorLanguage } from '@jupyterlab/codemirror'; import { getEditor } from '../selection-watcher'; -import { IJaiStatusItem } from '../tokens'; +import { IJaiStatusItem, IJaiCompletionProvider } from '../tokens'; import { displayName, JaiInlineProvider } from './provider'; import { CompletionWebsocketHandler } from './handler'; -import { jupyternautIcon } from '../icons'; -import { ModelSettingsWidget } from './settings'; export namespace CommandIDs { /** @@ -27,10 +23,6 @@ export namespace CommandIDs { */ export const toggleLanguageCompletions = 'jupyter-ai:toggle-language-completions'; - /** - * Command to open provider/model configuration. - */ - export const configureModel = 'jupyter-ai:configure-completions'; } const INLINE_COMPLETER_PLUGIN = @@ -54,182 +46,148 @@ type IcPluginSettings = ISettingRegistry.ISettings & { }; }; -export const completionPlugin: JupyterFrontEndPlugin = { - id: 'jupyter_ai:inline-completions', - autoStart: true, - requires: [ - ICompletionProviderManager, - IEditorLanguageRegistry, - ISettingRegistry, - IRenderMimeRegistry - ], - optional: [IJaiStatusItem], - activate: async ( - app: JupyterFrontEnd, - completionManager: ICompletionProviderManager, - languageRegistry: IEditorLanguageRegistry, - settingRegistry: ISettingRegistry, - rmRegistry: IRenderMimeRegistry, - statusItem: IJaiStatusItem | null - ): Promise => { - if (typeof completionManager.registerInlineProvider === 'undefined') { - // Gracefully short-circuit on JupyterLab 4.0 and Notebook 7.0 - console.warn( - 'Inline completions are only supported in JupyterLab 4.1+ and Jupyter Notebook 7.1+' - ); - return; - } - - const completionHandler = new CompletionWebsocketHandler(); - const provider = new JaiInlineProvider({ - completionHandler, - languageRegistry - }); - - await completionHandler.initialize(); - completionManager.registerInlineProvider(provider); - - const findCurrentLanguage = (): IEditorLanguage | null => { - const widget = app.shell.currentWidget; - const editor = getEditor(widget); - if (!editor) { +export const completionPlugin: JupyterFrontEndPlugin = + { + id: 'jupyter_ai:inline-completions', + autoStart: true, + requires: [ + ICompletionProviderManager, + IEditorLanguageRegistry, + ISettingRegistry + ], + optional: [IJaiStatusItem], + provides: IJaiCompletionProvider, + activate: async ( + app: JupyterFrontEnd, + completionManager: ICompletionProviderManager, + languageRegistry: IEditorLanguageRegistry, + settingRegistry: ISettingRegistry, + statusItem: IJaiStatusItem | null + ): Promise => { + if (typeof completionManager.registerInlineProvider === 'undefined') { + // Gracefully short-circuit on JupyterLab 4.0 and Notebook 7.0 + console.warn( + 'Inline completions are only supported in JupyterLab 4.1+ and Jupyter Notebook 7.1+' + ); return null; } - return languageRegistry.findByMIME(editor.model.mimeType); - }; - // ic := inline completion - async function getIcSettings() { - return (await settingRegistry.load( - INLINE_COMPLETER_PLUGIN - )) as IcPluginSettings; - } + const completionHandler = new CompletionWebsocketHandler(); + const provider = new JaiInlineProvider({ + completionHandler, + languageRegistry + }); - /** - * Gets the composite settings for the Jupyter AI inline completion provider - * (JaiIcp). - * - * This reads from the `ISettings.composite` property, which merges the user - * settings with the provider defaults, defined in - * `JaiInlineProvider.DEFAULT_SETTINGS`. - */ - async function getJaiIcpSettings() { - const icSettings = await getIcSettings(); - return icSettings.composite.providers[JaiInlineProvider.ID]; - } + await completionHandler.initialize(); + completionManager.registerInlineProvider(provider); - /** - * Updates the JaiIcp user settings. - */ - async function updateJaiIcpSettings( - newJaiIcpSettings: Partial - ) { - const icSettings = await getIcSettings(); - const oldUserIcpSettings = icSettings.user.providers; - const newUserIcpSettings = { - ...oldUserIcpSettings, - [JaiInlineProvider.ID]: { - ...oldUserIcpSettings?.[JaiInlineProvider.ID], - ...newJaiIcpSettings + const findCurrentLanguage = (): IEditorLanguage | null => { + const widget = app.shell.currentWidget; + const editor = getEditor(widget); + if (!editor) { + return null; } + return languageRegistry.findByMIME(editor.model.mimeType); }; - icSettings.set('providers', newUserIcpSettings); - } - app.commands.addCommand(CommandIDs.toggleCompletions, { - execute: async () => { - const jaiIcpSettings = await getJaiIcpSettings(); - updateJaiIcpSettings({ - enabled: !jaiIcpSettings.enabled - }); - }, - label: 'Enable completions by Jupyternaut', - isToggled: () => { - return provider.isEnabled(); + // ic := inline completion + async function getIcSettings() { + return (await settingRegistry.load( + INLINE_COMPLETER_PLUGIN + )) as IcPluginSettings; } - }); - app.commands.addCommand(CommandIDs.toggleLanguageCompletions, { - execute: async () => { - const jaiIcpSettings = await getJaiIcpSettings(); - const language = findCurrentLanguage(); - if (!language) { - return; - } - - const disabledLanguages = [...jaiIcpSettings.disabledLanguages]; - const newDisabledLanguages = disabledLanguages.includes(language.name) - ? disabledLanguages.filter(l => l !== language.name) - : disabledLanguages.concat(language.name); + /** + * Gets the composite settings for the Jupyter AI inline completion provider + * (JaiIcp). + * + * This reads from the `ISettings.composite` property, which merges the user + * settings with the provider defaults, defined in + * `JaiInlineProvider.DEFAULT_SETTINGS`. + */ + async function getJaiIcpSettings() { + const icSettings = await getIcSettings(); + return icSettings.composite.providers[JaiInlineProvider.ID]; + } - updateJaiIcpSettings({ - disabledLanguages: newDisabledLanguages - }); - }, - label: () => { - const language = findCurrentLanguage(); - return language - ? `Disable completions in ${displayName(language)}` - : 'Disable completions in files'; - }, - isToggled: () => { - const language = findCurrentLanguage(); - return !!language && !provider.isLanguageEnabled(language.name); - }, - isVisible: () => { - const language = findCurrentLanguage(); - return !!language; - }, - isEnabled: () => { - const language = findCurrentLanguage(); - return !!language && provider.isEnabled(); + /** + * Updates the JaiIcp user settings. + */ + async function updateJaiIcpSettings( + newJaiIcpSettings: Partial + ) { + const icSettings = await getIcSettings(); + const oldUserIcpSettings = icSettings.user.providers; + const newUserIcpSettings = { + ...oldUserIcpSettings, + [JaiInlineProvider.ID]: { + ...oldUserIcpSettings?.[JaiInlineProvider.ID], + ...newJaiIcpSettings + } + }; + icSettings.set('providers', newUserIcpSettings); } - }); - let settingsWidget: MainAreaWidget | null = null; - const newSettingsWidget = () => { - const content = new ModelSettingsWidget({ - rmRegistry, - isProviderEnabled: () => provider.isEnabled(), - openInlineCompleterSettings: () => { - app.commands.execute('settingeditor:open', { - query: 'Inline Completer' + app.commands.addCommand(CommandIDs.toggleCompletions, { + execute: async () => { + const jaiIcpSettings = await getJaiIcpSettings(); + updateJaiIcpSettings({ + enabled: !jaiIcpSettings.enabled }); + }, + label: 'Enable completions by Jupyternaut', + isToggled: () => { + return provider.isEnabled(); } }); - const widget = new MainAreaWidget({ content }); - widget.id = 'jupyterlab-inline-completions-model'; - widget.title.label = 'Completer Model Settings'; - widget.title.closable = true; - widget.title.icon = jupyternautIcon; - return widget; - }; - app.commands.addCommand(CommandIDs.configureModel, { - execute: () => { - if (!settingsWidget || settingsWidget.isDisposed) { - settingsWidget = newSettingsWidget(); - } - if (!settingsWidget.isAttached) { - app.shell.add(settingsWidget, 'main'); - } - app.shell.activateById(settingsWidget.id); - }, - label: 'Configure Jupyternaut Completions Model' - }); - if (statusItem) { - statusItem.addItem({ - command: CommandIDs.toggleCompletions, - rank: 1 - }); - statusItem.addItem({ - command: CommandIDs.toggleLanguageCompletions, - rank: 2 - }); - statusItem.addItem({ - command: CommandIDs.configureModel, - rank: 3 + app.commands.addCommand(CommandIDs.toggleLanguageCompletions, { + execute: async () => { + const jaiIcpSettings = await getJaiIcpSettings(); + const language = findCurrentLanguage(); + if (!language) { + return; + } + + const disabledLanguages = [...jaiIcpSettings.disabledLanguages]; + const newDisabledLanguages = disabledLanguages.includes(language.name) + ? disabledLanguages.filter(l => l !== language.name) + : disabledLanguages.concat(language.name); + + updateJaiIcpSettings({ + disabledLanguages: newDisabledLanguages + }); + }, + label: () => { + const language = findCurrentLanguage(); + return language + ? `Disable completions in ${displayName(language)}` + : 'Disable completions in files'; + }, + isToggled: () => { + const language = findCurrentLanguage(); + return !!language && !provider.isLanguageEnabled(language.name); + }, + isVisible: () => { + const language = findCurrentLanguage(); + return !!language; + }, + isEnabled: () => { + const language = findCurrentLanguage(); + return !!language && provider.isEnabled(); + } }); + + if (statusItem) { + statusItem.addItem({ + command: CommandIDs.toggleCompletions, + rank: 1 + }); + statusItem.addItem({ + command: CommandIDs.toggleLanguageCompletions, + rank: 2 + }); + } + return provider; } - } -}; + }; diff --git a/packages/jupyter-ai/src/completions/provider.ts b/packages/jupyter-ai/src/completions/provider.ts index b1e0f6a89..786ced851 100644 --- a/packages/jupyter-ai/src/completions/provider.ts +++ b/packages/jupyter-ai/src/completions/provider.ts @@ -9,11 +9,13 @@ import { import { ISettingRegistry } from '@jupyterlab/settingregistry'; import { Notification, showErrorMessage } from '@jupyterlab/apputils'; import { JSONValue, PromiseDelegate } from '@lumino/coreutils'; +import { ISignal, Signal } from '@lumino/signaling'; import { IEditorLanguageRegistry, IEditorLanguage } from '@jupyterlab/codemirror'; import { NotebookPanel } from '@jupyterlab/notebook'; +import { IJaiCompletionProvider } from '../tokens'; import { AiCompleterService as AiService } from './types'; import { DocumentWidget } from '@jupyterlab/docregistry'; import { jupyternautIcon } from '../icons'; @@ -34,7 +36,9 @@ export function displayName(language: IEditorLanguage): string { return language.displayName ?? language.name; } -export class JaiInlineProvider implements IInlineCompletionProvider { +export class JaiInlineProvider + implements IInlineCompletionProvider, IJaiCompletionProvider +{ readonly identifier = JaiInlineProvider.ID; readonly icon = jupyternautIcon.bindprops({ width: 16, top: 1 }); @@ -181,6 +185,7 @@ export class JaiInlineProvider implements IInlineCompletionProvider { async configure(settings: { [property: string]: JSONValue }): Promise { this._settings = settings as unknown as JaiInlineProvider.ISettings; + this._settingsChanged.emit(); } isEnabled(): boolean { @@ -191,6 +196,10 @@ export class JaiInlineProvider implements IInlineCompletionProvider { return !this._settings.disabledLanguages.includes(language); } + get settingsChanged(): ISignal { + return this._settingsChanged; + } + /** * Process the stream chunk to make it available in the awaiting generator. */ @@ -250,6 +259,7 @@ export class JaiInlineProvider implements IInlineCompletionProvider { private _settings: JaiInlineProvider.ISettings = JaiInlineProvider.DEFAULT_SETTINGS; + private _settingsChanged = new Signal(this); private _streamPromises: Map> = new Map(); diff --git a/packages/jupyter-ai/src/completions/settings.tsx b/packages/jupyter-ai/src/completions/settings.tsx deleted file mode 100644 index fea8d7c89..000000000 --- a/packages/jupyter-ai/src/completions/settings.tsx +++ /dev/null @@ -1,178 +0,0 @@ -import { ReactWidget } from '@jupyterlab/ui-components'; -import React, { useState } from 'react'; - -import { Box } from '@mui/system'; -import { Alert, Button, CircularProgress } from '@mui/material'; - -import { AiService } from '../handler'; -import { - ServerInfoState, - useServerInfo -} from '../components/settings/use-server-info'; -import { ModelSettings, IModelSettings } from '../components/model-settings'; -import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; -import { minifyUpdate } from '../components/settings/minify'; -import { useStackingAlert } from '../components/mui-extras/stacking-alert'; - -type CompleterSettingsProps = { - rmRegistry: IRenderMimeRegistry; - isProviderEnabled: () => boolean; - openInlineCompleterSettings: () => void; -}; - -/** - * Component that returns the settings view. - */ -export function CompleterSettings(props: CompleterSettingsProps): JSX.Element { - // state fetched on initial render - const server = useServerInfo(); - - // initialize alert helper - const alert = useStackingAlert(); - - // whether the form is currently saving - const [saving, setSaving] = useState(false); - - // provider/model settings - const [modelSettings, setModelSettings] = useState({ - fields: {}, - apiKeys: {}, - emGlobalId: null, - lmGlobalId: null - }); - - const handleSave = async () => { - // compress fields with JSON values - if (server.state !== ServerInfoState.Ready) { - return; - } - - const { fields, lmGlobalId, emGlobalId, apiKeys } = modelSettings; - - for (const fieldKey in fields) { - const fieldVal = fields[fieldKey]; - if (typeof fieldVal !== 'string' || !fieldVal.trim().startsWith('{')) { - continue; - } - - try { - const parsedFieldVal = JSON.parse(fieldVal); - const compressedFieldVal = JSON.stringify(parsedFieldVal); - fields[fieldKey] = compressedFieldVal; - } catch (e) { - continue; - } - } - - let updateRequest: AiService.UpdateConfigRequest = { - completions_model_provider_id: lmGlobalId, - completions_embeddings_provider_id: emGlobalId, - api_keys: apiKeys, - ...(lmGlobalId && { - completions_fields: { - [lmGlobalId]: fields - } - }) - }; - updateRequest = minifyUpdate(server.config, updateRequest); - updateRequest.last_read = server.config.last_read; - - setSaving(true); - try { - await AiService.updateConfig(updateRequest); - } catch (e) { - console.error(e); - const msg = - e instanceof Error || typeof e === 'string' - ? e.toString() - : 'An unknown error occurred. Check the console for more details.'; - alert.show('error', msg); - return; - } finally { - setSaving(false); - } - await server.refetchAll(); - alert.show('success', 'Settings saved successfully.'); - }; - - if (server.state === ServerInfoState.Loading) { - return ( - - - - ); - } - - if (server.state === ServerInfoState.Error) { - return ( - - - {server.error || - 'An unknown error occurred. Check the console for more details.'} - - - ); - } - - return ( - - {props.isProviderEnabled() ? null : ( - - The jupyter-ai inline completion provider is not enabled in the Inline - Completer settings. - - )} - - - - - - - - {alert.jsx} - - ); -} - -export class ModelSettingsWidget extends ReactWidget { - constructor(protected options: CompleterSettingsProps) { - super(); - } - render(): JSX.Element { - return ; - } -} diff --git a/packages/jupyter-ai/src/components/chat-settings.tsx b/packages/jupyter-ai/src/components/chat-settings.tsx index 2a2da034e..0d811ca7e 100644 --- a/packages/jupyter-ai/src/components/chat-settings.tsx +++ b/packages/jupyter-ai/src/components/chat-settings.tsx @@ -1,26 +1,39 @@ -import React, { useEffect, useState } from 'react'; +import React, { useEffect, useState, useMemo } from 'react'; import { Box } from '@mui/system'; import { Alert, Button, + IconButton, FormControl, FormControlLabel, FormLabel, + MenuItem, Radio, RadioGroup, + TextField, CircularProgress } from '@mui/material'; +import SettingsIcon from '@mui/icons-material/Settings'; +import WarningAmberIcon from '@mui/icons-material/WarningAmber'; +import { UseSignal } from '@jupyterlab/ui-components'; +import { Select } from './select'; import { AiService } from '../handler'; +import { ModelFields } from './settings/model-fields'; import { ServerInfoState, useServerInfo } from './settings/use-server-info'; -import { ModelSettings, IModelSettings } from './model-settings'; +import { ExistingApiKeys } from './settings/existing-api-keys'; import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; import { minifyUpdate } from './settings/minify'; import { useStackingAlert } from './mui-extras/stacking-alert'; +import { RendermimeMarkdown } from './rendermime-markdown'; +import { IJaiCompletionProvider } from '../tokens'; +import { getProviderId, getModelLocalId } from '../utils'; type ChatSettingsProps = { rmRegistry: IRenderMimeRegistry; + completionProvider: IJaiCompletionProvider | null; + openInlineCompleterSettings: () => void; }; /** @@ -32,21 +45,53 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { // initialize alert helper const alert = useStackingAlert(); + const apiKeysAlert = useStackingAlert(); // user inputs + const [lmProvider, setLmProvider] = + useState(null); + const [clmProvider, setClmProvider] = + useState(null); + const [showLmLocalId, setShowLmLocalId] = useState(false); + const [showClmLocalId, setShowClmLocalId] = useState(false); + const [chatHelpMarkdown, setChatHelpMarkdown] = useState(null); + const [completionHelpMarkdown, setCompletionHelpMarkdown] = useState< + string | null + >(null); + const [lmLocalId, setLmLocalId] = useState(''); + const [clmLocalId, setClmLocalId] = useState(''); + + const lmGlobalId = useMemo(() => { + if (!lmProvider) { + return null; + } + + return lmProvider.id + ':' + lmLocalId; + }, [lmProvider, lmLocalId]); + const clmGlobalId = useMemo(() => { + if (!clmProvider) { + return null; + } + + return clmProvider.id + ':' + clmLocalId; + }, [clmProvider, clmLocalId]); + + const [emGlobalId, setEmGlobalId] = useState(null); + const emProvider = useMemo(() => { + if (emGlobalId === null || server.state !== ServerInfoState.Ready) { + return null; + } + + return getProvider(emGlobalId, server.emProviders); + }, [emGlobalId, server]); + + const [apiKeys, setApiKeys] = useState>({}); const [sendWse, setSendWse] = useState(false); + const [fields, setFields] = useState>({}); // whether the form is currently saving const [saving, setSaving] = useState(false); - // provider/model settings - const [modelSettings, setModelSettings] = useState({ - fields: {}, - apiKeys: {}, - emGlobalId: null, - lmGlobalId: null - }); - /** * Effect: initialize inputs after fetching server info. */ @@ -54,17 +99,85 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { if (server.state !== ServerInfoState.Ready) { return; } + + setLmLocalId(server.chat.lmLocalId); + setClmLocalId(server.completions.lmLocalId); + setEmGlobalId(server.config.embeddings_provider_id); setSendWse(server.config.send_with_shift_enter); + setChatHelpMarkdown(server.chat.lmProvider?.help ?? null); + setCompletionHelpMarkdown(server.completions.lmProvider?.help ?? null); + if (server.chat.lmProvider?.registry) { + setShowLmLocalId(true); + } + if (server.completions.lmProvider?.registry) { + setShowClmLocalId(true); + } + setLmProvider(server.chat.lmProvider); + setClmProvider(server.completions.lmProvider); }, [server]); + /** + * Effect: re-initialize apiKeys object whenever the selected LM/EM changes. + * Properties with a value of '' indicate necessary user input. + */ + useEffect(() => { + if (server.state !== ServerInfoState.Ready) { + return; + } + + const newApiKeys: Record = {}; + const lmAuth = lmProvider?.auth_strategy; + const emAuth = emProvider?.auth_strategy; + if ( + lmAuth?.type === 'env' && + !server.config.api_keys.includes(lmAuth.name) + ) { + newApiKeys[lmAuth.name] = ''; + } + if (lmAuth?.type === 'multienv') { + lmAuth.names.forEach(apiKey => { + if (!server.config.api_keys.includes(apiKey)) { + newApiKeys[apiKey] = ''; + } + }); + } + + if ( + emAuth?.type === 'env' && + !server.config.api_keys.includes(emAuth.name) + ) { + newApiKeys[emAuth.name] = ''; + } + if (emAuth?.type === 'multienv') { + emAuth.names.forEach(apiKey => { + if (!server.config.api_keys.includes(apiKey)) { + newApiKeys[apiKey] = ''; + } + }); + } + + setApiKeys(newApiKeys); + }, [lmProvider, emProvider, server]); + + /** + * Effect: re-initialize fields object whenever the selected LM changes. + */ + useEffect(() => { + if (server.state !== ServerInfoState.Ready || !lmGlobalId) { + return; + } + + const currFields: Record = + server.config.fields?.[lmGlobalId] ?? {}; + setFields(currFields); + }, [server, lmProvider]); + const handleSave = async () => { // compress fields with JSON values if (server.state !== ServerInfoState.Ready) { return; } - const { fields, lmGlobalId, emGlobalId, apiKeys } = modelSettings; - for (const fieldKey in fields) { const fieldVal = fields[fieldKey]; if (typeof fieldVal !== 'string' || !fieldVal.trim().startsWith('{')) { @@ -84,11 +197,17 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { model_provider_id: lmGlobalId, embeddings_provider_id: emGlobalId, api_keys: apiKeys, - ...(lmGlobalId && { + ...((lmGlobalId || clmGlobalId) && { fields: { - [lmGlobalId]: fields + ...(lmGlobalId && { + [lmGlobalId]: fields + }), + ...(clmGlobalId && { + [clmGlobalId]: fields + }) } }), + completions_model_provider_id: clmGlobalId, send_with_shift_enter: sendWse }; updateRequest = minifyUpdate(server.config, updateRequest); @@ -96,6 +215,7 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { setSaving(true); try { + await apiKeysAlert.clear(); await AiService.updateConfig(updateRequest); } catch (e) { console.error(e); @@ -158,11 +278,195 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { overflowY: 'auto' }} > - Chat language model + + {showLmLocalId && ( + setLmLocalId(e.target.value)} + fullWidth + /> + )} + {chatHelpMarkdown && ( + + )} + {lmGlobalId && ( + + )} + + {/* Embedding model section */} +

Embedding model

+ + + {/* Completer language model section */} +

+ Completer model + {props.completionProvider ? ( + + {(): JSX.Element => ( + + )} + + ) : ( + + )} +

+ + {showClmLocalId && ( + setClmLocalId(e.target.value)} + fullWidth + /> + )} + {completionHelpMarkdown && ( + + )} + {clmGlobalId && ( + + )} + + {/* API Keys section */} +

API Keys

+ {/* API key inputs for newly-used providers */} + {Object.entries(apiKeys).map(([apiKeyName, apiKeyValue], idx) => ( + + setApiKeys(apiKeys => ({ + ...apiKeys, + [apiKeyName]: e.target.value + })) + } + /> + ))} + {/* Pre-existing API keys */} + {/* Input */} @@ -204,3 +508,39 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { ); } + +function CompleterSettingsButton(props: { + selection: AiService.ListProvidersEntry | null; + provider: IJaiCompletionProvider | null; + openSettings: () => void; +}): JSX.Element { + if (props.selection && !props.provider?.isEnabled()) { + return ( + + + + ); + } + return ( + + + + ); +} + +function getProvider( + globalModelId: string, + providers: AiService.ListProvidersResponse +): AiService.ListProvidersEntry | null { + const providerId = getProviderId(globalModelId); + const provider = providers.providers.find(p => p.id === providerId); + return provider ?? null; +} diff --git a/packages/jupyter-ai/src/components/chat.tsx b/packages/jupyter-ai/src/components/chat.tsx index c33d36f25..721d157f9 100644 --- a/packages/jupyter-ai/src/components/chat.tsx +++ b/packages/jupyter-ai/src/components/chat.tsx @@ -18,6 +18,7 @@ import { import { SelectionWatcher } from '../selection-watcher'; import { ChatHandler } from '../chat_handler'; import { CollaboratorsContextProvider } from '../contexts/collaborators-context'; +import { IJaiCompletionProvider } from '../tokens'; import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; import { ScrollContainer } from './scroll-container'; @@ -185,6 +186,8 @@ export type ChatProps = { themeManager: IThemeManager | null; rmRegistry: IRenderMimeRegistry; chatView?: ChatView; + completionProvider: IJaiCompletionProvider | null; + openInlineCompleterSettings: () => void; }; enum ChatView { @@ -237,7 +240,11 @@ export function Chat(props: ChatProps): JSX.Element { /> )} {view === ChatView.Settings && ( - + )} diff --git a/packages/jupyter-ai/src/components/model-settings.tsx b/packages/jupyter-ai/src/components/model-settings.tsx deleted file mode 100644 index e99aa7480..000000000 --- a/packages/jupyter-ai/src/components/model-settings.tsx +++ /dev/null @@ -1,308 +0,0 @@ -import React, { useEffect, useState, useMemo } from 'react'; - -import { Box } from '@mui/system'; -import { Alert, MenuItem, TextField, CircularProgress } from '@mui/material'; - -import { Select } from './select'; -import { AiService } from '../handler'; -import { ModelFields } from './settings/model-fields'; -import { ServerInfoState, useServerInfo } from './settings/use-server-info'; -import { ExistingApiKeys } from './settings/existing-api-keys'; -import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; -import { useStackingAlert } from './mui-extras/stacking-alert'; -import { RendermimeMarkdown } from './rendermime-markdown'; -import { getProviderId, getModelLocalId } from '../utils'; - -type ModelSettingsProps = { - rmRegistry: IRenderMimeRegistry; - label: string; - onChange: (settings: IModelSettings) => void; - modelKind: 'chat' | 'completions'; -}; - -export interface IModelSettings { - fields: Record; - apiKeys: Record; - emGlobalId: string | null; - lmGlobalId: string | null; -} - -/** - * Component that returns the settings view in the chat panel. - */ -export function ModelSettings(props: ModelSettingsProps): JSX.Element { - // state fetched on initial render - const server = useServerInfo(); - - // initialize alert helper - const apiKeysAlert = useStackingAlert(); - - // user inputs - const [lmProvider, setLmProvider] = - useState(null); - const [showLmLocalId, setShowLmLocalId] = useState(false); - const [helpMarkdown, setHelpMarkdown] = useState(null); - const [lmLocalId, setLmLocalId] = useState(''); - const lmGlobalId = useMemo(() => { - if (!lmProvider) { - return null; - } - - return lmProvider.id + ':' + lmLocalId; - }, [lmProvider, lmLocalId]); - - const [emGlobalId, setEmGlobalId] = - useState(null); - const emProvider = useMemo(() => { - if (emGlobalId === null || server.state !== ServerInfoState.Ready) { - return null; - } - - return getProvider(emGlobalId, server.emProviders); - }, [emGlobalId, server]); - - const [apiKeys, setApiKeys] = useState({}); - const [fields, setFields] = useState({}); - - /** - * Effect: initialize inputs after fetching server info. - */ - useEffect(() => { - if (server.state !== ServerInfoState.Ready) { - return; - } - const kind = props.modelKind; - - setLmLocalId(server[kind].lmLocalId); - setEmGlobalId( - kind === 'chat' - ? server.config.embeddings_provider_id - : server.config.completions_embeddings_provider_id - ); - setHelpMarkdown(server[kind].lmProvider?.help ?? null); - if (server[kind].lmProvider?.registry) { - setShowLmLocalId(true); - } - setLmProvider(server[kind].lmProvider); - }, [server]); - - /** - * Effect: re-initialize apiKeys object whenever the selected LM/EM changes. - * Properties with a value of '' indicate necessary user input. - */ - useEffect(() => { - if (server.state !== ServerInfoState.Ready) { - return; - } - - const newApiKeys: Record = {}; - const lmAuth = lmProvider?.auth_strategy; - const emAuth = emProvider?.auth_strategy; - if ( - lmAuth?.type === 'env' && - !server.config.api_keys.includes(lmAuth.name) - ) { - newApiKeys[lmAuth.name] = ''; - } - if (lmAuth?.type === 'multienv') { - lmAuth.names.forEach(apiKey => { - if (!server.config.api_keys.includes(apiKey)) { - newApiKeys[apiKey] = ''; - } - }); - } - - if ( - emAuth?.type === 'env' && - !server.config.api_keys.includes(emAuth.name) - ) { - newApiKeys[emAuth.name] = ''; - } - if (emAuth?.type === 'multienv') { - emAuth.names.forEach(apiKey => { - if (!server.config.api_keys.includes(apiKey)) { - newApiKeys[apiKey] = ''; - } - }); - } - - setApiKeys(newApiKeys); - }, [lmProvider, emProvider, server]); - - /** - * Effect: re-initialize fields object whenever the selected LM changes. - */ - useEffect(() => { - if (server.state !== ServerInfoState.Ready || !lmGlobalId) { - return; - } - - const currFields: Record = - server.config.fields?.[lmGlobalId] ?? {}; - setFields(currFields); - }, [server, lmProvider]); - - useEffect(() => { - props.onChange({ - fields, - apiKeys, - lmGlobalId, - emGlobalId - }); - }, [lmProvider, emProvider, apiKeys, fields]); - - if (server.state === ServerInfoState.Loading) { - return ( - - - - ); - } - - if (server.state === ServerInfoState.Error) { - return ( - <> - - {server.error || - 'An unknown error occurred. Check the console for more details.'} - - - ); - } - - return ( - <> - {/* Language model section */} -

{props.label}

- - {showLmLocalId && ( - setLmLocalId(e.target.value)} - fullWidth - /> - )} - {helpMarkdown && ( - - )} - {lmGlobalId && ( - - )} - - {/* Embedding model section */} -

Embedding model

- - - {/* API Keys section */} -

API Keys

- {/* API key inputs for newly-used providers */} - {Object.entries(apiKeys).map(([apiKeyName, apiKeyValue], idx) => ( - - setApiKeys(apiKeys => ({ - ...apiKeys, - [apiKeyName]: e.target.value - })) - } - /> - ))} - {/* Pre-existing API keys */} - - - ); -} - -function getProvider( - globalModelId: string, - providers: AiService.ListProvidersResponse -): AiService.ListProvidersEntry | null { - const providerId = getProviderId(globalModelId); - const provider = providers.providers.find(p => p.id === providerId); - return provider ?? null; -} diff --git a/packages/jupyter-ai/src/index.ts b/packages/jupyter-ai/src/index.ts index df2ac69d0..d07a0dc34 100644 --- a/packages/jupyter-ai/src/index.ts +++ b/packages/jupyter-ai/src/index.ts @@ -18,6 +18,7 @@ import { ChatHandler } from './chat_handler'; import { buildErrorWidget } from './widgets/chat-error'; import { completionPlugin } from './completions'; import { statusItemPlugin } from './status'; +import { IJaiCompletionProvider } from './tokens'; import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; export type DocumentTracker = IWidgetTracker; @@ -28,14 +29,20 @@ export type DocumentTracker = IWidgetTracker; const plugin: JupyterFrontEndPlugin = { id: 'jupyter_ai:plugin', autoStart: true, - optional: [IGlobalAwareness, ILayoutRestorer, IThemeManager], + optional: [ + IGlobalAwareness, + ILayoutRestorer, + IThemeManager, + IJaiCompletionProvider + ], requires: [IRenderMimeRegistry], activate: async ( app: JupyterFrontEnd, rmRegistry: IRenderMimeRegistry, globalAwareness: Awareness | null, restorer: ILayoutRestorer | null, - themeManager: IThemeManager | null + themeManager: IThemeManager | null, + completionProvider: IJaiCompletionProvider | null ) => { /** * Initialize selection watcher singleton @@ -47,6 +54,12 @@ const plugin: JupyterFrontEndPlugin = { */ const chatHandler = new ChatHandler(); + const openInlineCompleterSettings = () => { + app.commands.execute('settingeditor:open', { + query: 'Inline Completer' + }); + }; + let chatWidget: ReactWidget | null = null; try { await chatHandler.initialize(); @@ -55,7 +68,9 @@ const plugin: JupyterFrontEndPlugin = { chatHandler, globalAwareness, themeManager, - rmRegistry + rmRegistry, + completionProvider, + openInlineCompleterSettings ); } catch (e) { chatWidget = buildErrorWidget(themeManager); diff --git a/packages/jupyter-ai/src/tokens.ts b/packages/jupyter-ai/src/tokens.ts index f240a8197..924745ee1 100644 --- a/packages/jupyter-ai/src/tokens.ts +++ b/packages/jupyter-ai/src/tokens.ts @@ -1,4 +1,5 @@ import { Token } from '@lumino/coreutils'; +import { ISignal } from '@lumino/signaling'; import type { IRankedMenu } from '@jupyterlab/ui-components'; export interface IJaiStatusItem { @@ -12,3 +13,16 @@ export const IJaiStatusItem = new Token( 'jupyter_ai:IJupyternautStatus', 'Status indicator displayed in the statusbar' ); + +export interface IJaiCompletionProvider { + isEnabled(): boolean; + settingsChanged: ISignal; +} + +/** + * The incline completion provider token. + */ +export const IJaiCompletionProvider = new Token( + 'jupyter_ai:IJaiCompletionProvider', + 'Status the incline completion provider' +); diff --git a/packages/jupyter-ai/src/widgets/chat-sidebar.tsx b/packages/jupyter-ai/src/widgets/chat-sidebar.tsx index 291593b41..bb575febc 100644 --- a/packages/jupyter-ai/src/widgets/chat-sidebar.tsx +++ b/packages/jupyter-ai/src/widgets/chat-sidebar.tsx @@ -7,6 +7,7 @@ import { Chat } from '../components/chat'; import { chatIcon } from '../icons'; import { SelectionWatcher } from '../selection-watcher'; import { ChatHandler } from '../chat_handler'; +import { IJaiCompletionProvider } from '../tokens'; import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; export function buildChatSidebar( @@ -14,7 +15,9 @@ export function buildChatSidebar( chatHandler: ChatHandler, globalAwareness: Awareness | null, themeManager: IThemeManager | null, - rmRegistry: IRenderMimeRegistry + rmRegistry: IRenderMimeRegistry, + completionProvider: IJaiCompletionProvider | null, + openInlineCompleterSettings: () => void ): ReactWidget { const ChatWidget = ReactWidget.create( ); ChatWidget.id = 'jupyter-ai::chat';