diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py index 4daf42bba..e86613038 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py @@ -258,6 +258,16 @@ class Config: provider is selected. """ + @classmethod + def chat_models(self): + """Models which are suitable for chat.""" + return self.models + + @classmethod + def completion_models(self): + """Models which are suitable for completions.""" + return self.models + # # instance attrs # diff --git a/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py b/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py index 9eb4f845a..32920dc83 100644 --- a/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py @@ -5,7 +5,7 @@ from typing import Union import tornado -from jupyter_ai.completions.handlers.llm_mixin import LLMHandlerMixin +from jupyter_ai.completions.handlers.model_mixin import CompletionsModelMixin from jupyter_ai.completions.models import ( CompletionError, InlineCompletionList, @@ -18,7 +18,7 @@ class BaseInlineCompletionHandler( - LLMHandlerMixin, JupyterHandler, tornado.websocket.WebSocketHandler + CompletionsModelMixin, JupyterHandler, tornado.websocket.WebSocketHandler ): """A Tornado WebSocket handler that receives inline completion requests and fulfills them accordingly. This class is instantiated once per WebSocket diff --git a/packages/jupyter-ai/jupyter_ai/completions/handlers/llm_mixin.py b/packages/jupyter-ai/jupyter_ai/completions/handlers/model_mixin.py similarity index 89% rename from packages/jupyter-ai/jupyter_ai/completions/handlers/llm_mixin.py rename to packages/jupyter-ai/jupyter_ai/completions/handlers/model_mixin.py index ed454fff6..fd19498bd 100644 --- a/packages/jupyter-ai/jupyter_ai/completions/handlers/llm_mixin.py +++ b/packages/jupyter-ai/jupyter_ai/completions/handlers/model_mixin.py @@ -5,8 +5,8 @@ from jupyter_ai_magics.providers import BaseProvider -class LLMHandlerMixin: - """Base class containing shared methods and attributes used by LLM handler classes.""" +class CompletionsModelMixin: + """Mixin class containing methods and attributes used by completions LLM handler.""" handler_kind: str settings: dict @@ -26,8 +26,8 @@ def __init__(self, *args, **kwargs) -> None: self._llm_params = None def get_llm(self) -> Optional[BaseProvider]: - lm_provider = self.jai_config_manager.lm_provider - lm_provider_params = self.jai_config_manager.lm_provider_params + lm_provider = self.jai_config_manager.completions_lm_provider + lm_provider_params = self.jai_config_manager.completions_lm_provider_params if not lm_provider or not lm_provider_params: return None diff --git a/packages/jupyter-ai/jupyter_ai/config/config_schema.json b/packages/jupyter-ai/jupyter_ai/config/config_schema.json index ff7c717c4..d276dd31d 100644 --- a/packages/jupyter-ai/jupyter_ai/config/config_schema.json +++ b/packages/jupyter-ai/jupyter_ai/config/config_schema.json @@ -16,6 +16,12 @@ "default": null, "readOnly": false }, + "completions_model_provider_id": { + "$comment": "Language model global ID for completions.", + "type": ["string", "null"], + "default": null, + "readOnly": false + }, "api_keys": { "$comment": "Dictionary of API keys, mapping key names to key values.", "type": "object", @@ -37,6 +43,17 @@ } }, "additionalProperties": false + }, + "completions_fields": { + "$comment": "Dictionary of model-specific fields, mapping LM GIDs to sub-dictionaries of field key-value pairs for completions.", + "type": "object", + "default": {}, + "patternProperties": { + "^.*$": { + "anyOf": [{ "type": "object" }] + } + }, + "additionalProperties": false } } } diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index 392e44601..a05f9b3ca 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -97,6 +97,10 @@ class ConfigManager(Configurable): config=True, ) + model_provider_id: Optional[str] + embeddings_provider_id: Optional[str] + completions_model_provider_id: Optional[str] + def __init__( self, log: Logger, @@ -164,41 +168,49 @@ def _process_existing_config(self, default_config): {k: v for k, v in existing_config.items() if v is not None}, ) config = GlobalConfig(**merged_config) - validated_config = self._validate_lm_em_id(config) + validated_config = self._validate_model_ids(config) # re-write to the file to validate the config and apply any # updates to the config file immediately self._write_config(validated_config) - def _validate_lm_em_id(self, config): - lm_id = config.model_provider_id - em_id = config.embeddings_provider_id + def _validate_model_ids(self, config): + lm_provider_keys = ["model_provider_id", "completions_model_provider_id"] + em_provider_keys = ["embeddings_provider_id"] # if the currently selected language or embedding model are # forbidden, set them to `None` and log a warning. - if lm_id is not None and not self._validate_model(lm_id, raise_exc=False): - self.log.warning( - f"Language model {lm_id} is forbidden by current allow/blocklists. Setting to None." - ) - config.model_provider_id = None - if em_id is not None and not self._validate_model(em_id, raise_exc=False): - self.log.warning( - f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None." - ) - config.embeddings_provider_id = None + for lm_key in lm_provider_keys: + lm_id = getattr(config, lm_key) + if lm_id is not None and not self._validate_model(lm_id, raise_exc=False): + self.log.warning( + f"Language model {lm_id} is forbidden by current allow/blocklists. Setting to None." + ) + setattr(config, lm_key, None) + for em_key in em_provider_keys: + em_id = getattr(config, em_key) + if em_id is not None and not self._validate_model(em_id, raise_exc=False): + self.log.warning( + f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None." + ) + setattr(config, em_key, None) # if the currently selected language or embedding model ids are # not associated with models, set them to `None` and log a warning. - if lm_id is not None and not get_lm_provider(lm_id, self._lm_providers)[1]: - self.log.warning( - f"No language model is associated with '{lm_id}'. Setting to None." - ) - config.model_provider_id = None - if em_id is not None and not get_em_provider(em_id, self._em_providers)[1]: - self.log.warning( - f"No embedding model is associated with '{em_id}'. Setting to None." - ) - config.embeddings_provider_id = None + for lm_key in lm_provider_keys: + lm_id = getattr(config, lm_key) + if lm_id is not None and not get_lm_provider(lm_id, self._lm_providers)[1]: + self.log.warning( + f"No language model is associated with '{lm_id}'. Setting to None." + ) + setattr(config, lm_key, None) + for em_key in em_provider_keys: + em_id = getattr(config, em_key) + if em_id is not None and not get_em_provider(em_id, self._em_providers)[1]: + self.log.warning( + f"No embedding model is associated with '{em_id}'. Setting to None." + ) + setattr(config, em_key, None) return config @@ -321,6 +333,9 @@ def _write_config(self, new_config: GlobalConfig): complete `GlobalConfig` object, and should not be called publicly.""" # remove any empty field dictionaries new_config.fields = {k: v for k, v in new_config.fields.items() if v} + new_config.completions_fields = { + k: v for k, v in new_config.completions_fields.items() if v + } self._validate_config(new_config) with open(self.config_path, "w") as f: @@ -328,21 +343,18 @@ def _write_config(self, new_config: GlobalConfig): def delete_api_key(self, key_name: str): config_dict = self._read_config().dict() - lm_provider = self.lm_provider - em_provider = self.em_provider required_keys = [] - if ( - lm_provider - and lm_provider.auth_strategy - and lm_provider.auth_strategy.type == "env" - ): - required_keys.append(lm_provider.auth_strategy.name) - if ( - em_provider - and em_provider.auth_strategy - and em_provider.auth_strategy.type == "env" - ): - required_keys.append(self.em_provider.auth_strategy.name) + for provider in [ + self.lm_provider, + self.em_provider, + self.completions_lm_provider, + ]: + if ( + provider + and provider.auth_strategy + and provider.auth_strategy.type == "env" + ): + required_keys.append(provider.auth_strategy.name) if key_name in required_keys: raise KeyInUseError( @@ -390,67 +402,57 @@ def em_gid(self): @property def lm_provider(self): - config = self._read_config() - lm_gid = config.model_provider_id - if lm_gid is None: - return None - - _, Provider = get_lm_provider(config.model_provider_id, self._lm_providers) - return Provider + return self._get_provider("model_provider_id", self._lm_providers) @property def em_provider(self): + return self._get_provider("embeddings_provider_id", self._em_providers) + + @property + def completions_lm_provider(self): + return self._get_provider("completions_model_provider_id", self._lm_providers) + + def _get_provider(self, key, listing): config = self._read_config() - em_gid = config.embeddings_provider_id - if em_gid is None: + gid = getattr(config, key) + if gid is None: return None - _, Provider = get_em_provider(em_gid, self._em_providers) + _, Provider = get_lm_provider(gid, listing) return Provider @property def lm_provider_params(self): - # get generic fields - config = self._read_config() - lm_gid = config.model_provider_id - if not lm_gid: - return None - - lm_lid = lm_gid.split(":", 1)[1] - fields = config.fields.get(lm_gid, {}) - - # get authn fields - _, Provider = get_lm_provider(lm_gid, self._lm_providers) - authn_fields = {} - if Provider.auth_strategy and Provider.auth_strategy.type == "env": - key_name = Provider.auth_strategy.name - authn_fields[key_name.lower()] = config.api_keys[key_name] - - return { - "model_id": lm_lid, - **fields, - **authn_fields, - } + return self._provider_params("model_provider_id", self._lm_providers) @property def em_provider_params(self): + return self._provider_params("embeddings_provider_id", self._em_providers) + + @property + def completions_lm_provider_params(self): + return self._provider_params( + "completions_model_provider_id", self._lm_providers + ) + + def _provider_params(self, key, listing): # get generic fields config = self._read_config() - em_gid = config.embeddings_provider_id - if not em_gid: + gid = getattr(config, key) + if not gid: return None - em_lid = em_gid.split(":", 1)[1] + lid = gid.split(":", 1)[1] # get authn fields - _, Provider = get_em_provider(em_gid, self._em_providers) + _, Provider = get_em_provider(gid, listing) authn_fields = {} if Provider.auth_strategy and Provider.auth_strategy.type == "env": key_name = Provider.auth_strategy.name authn_fields[key_name.lower()] = config.api_keys[key_name] return { - "model_id": em_lid, + "model_id": lid, **authn_fields, } diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index 6002310ba..976b7df62 100644 --- a/packages/jupyter-ai/jupyter_ai/handlers.py +++ b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -290,6 +290,10 @@ def filter_predicate(local_model_id: str): # filter out every model w/ model ID according to allow/blocklist for provider in providers: provider.models = list(filter(filter_predicate, provider.models)) + provider.chat_models = list(filter(filter_predicate, provider.chat_models)) + provider.completion_models = list( + filter(filter_predicate, provider.completion_models) + ) # filter out every provider with no models which satisfy the allow/blocklist, then return return filter((lambda p: len(p.models) > 0), providers) @@ -311,6 +315,8 @@ def get(self): id=provider.id, name=provider.name, models=provider.models, + chat_models=provider.chat_models(), + completion_models=provider.completion_models(), help=provider.help, auth_strategy=provider.auth_strategy, registry=provider.registry, diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index 32353a694..6742a922d 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -94,6 +94,8 @@ class ListProvidersEntry(BaseModel): auth_strategy: AuthStrategy registry: bool fields: List[Field] + chat_models: Optional[List[str]] + completion_models: Optional[List[str]] class ListProvidersResponse(BaseModel): @@ -121,6 +123,8 @@ class DescribeConfigResponse(BaseModel): # timestamp indicating when the configuration file was last read. should be # passed to the subsequent UpdateConfig request. last_read: int + completions_model_provider_id: Optional[str] + completions_fields: Dict[str, Dict[str, Any]] def forbid_none(cls, v): @@ -137,6 +141,8 @@ class UpdateConfigRequest(BaseModel): # if passed, this will raise an Error if the config was written to after the # time specified by `last_read` to prevent write-write conflicts. last_read: Optional[int] + completions_model_provider_id: Optional[str] + completions_fields: Optional[Dict[str, Dict[str, Any]]] _validate_send_wse = validator("send_with_shift_enter", allow_reuse=True)( forbid_none @@ -154,3 +160,5 @@ class GlobalConfig(BaseModel): send_with_shift_enter: bool fields: Dict[str, Dict[str, Any]] api_keys: Dict[str, str] + completions_model_provider_id: Optional[str] + completions_fields: Dict[str, Dict[str, Any]] diff --git a/packages/jupyter-ai/jupyter_ai/tests/__snapshots__/test_config_manager.ambr b/packages/jupyter-ai/jupyter_ai/tests/__snapshots__/test_config_manager.ambr index cd254bedc..111ebe6ea 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/__snapshots__/test_config_manager.ambr +++ b/packages/jupyter-ai/jupyter_ai/tests/__snapshots__/test_config_manager.ambr @@ -3,6 +3,9 @@ dict({ 'api_keys': list([ ]), + 'completions_fields': dict({ + }), + 'completions_model_provider_id': None, 'embeddings_provider_id': None, 'fields': dict({ }), diff --git a/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py b/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py index 8597463f2..c5b5d1eea 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py +++ b/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py @@ -35,8 +35,8 @@ def __init__(self, lm_provider=None, lm_provider_params=None): self.messages = [] self.tasks = [] self.settings["jai_config_manager"] = SimpleNamespace( - lm_provider=lm_provider or MockProvider, - lm_provider_params=lm_provider_params or {"model_id": "model"}, + completions_lm_provider=lm_provider or MockProvider, + completions_lm_provider_params=lm_provider_params or {"model_id": "model"}, ) self.settings["jai_event_loop"] = SimpleNamespace( create_task=lambda x: self.tasks.append(x) diff --git a/packages/jupyter-ai/src/completions/plugin.ts b/packages/jupyter-ai/src/completions/plugin.ts index 835561143..f5963f36f 100644 --- a/packages/jupyter-ai/src/completions/plugin.ts +++ b/packages/jupyter-ai/src/completions/plugin.ts @@ -9,7 +9,7 @@ import { IEditorLanguage } from '@jupyterlab/codemirror'; import { getEditor } from '../selection-watcher'; -import { IJaiStatusItem } from '../tokens'; +import { IJaiStatusItem, IJaiCompletionProvider } from '../tokens'; import { displayName, JaiInlineProvider } from './provider'; import { CompletionWebsocketHandler } from './handler'; @@ -46,7 +46,9 @@ type IcPluginSettings = ISettingRegistry.ISettings & { }; }; -export const completionPlugin: JupyterFrontEndPlugin = { +type JaiCompletionToken = IJaiCompletionProvider | null; + +export const completionPlugin: JupyterFrontEndPlugin = { id: 'jupyter_ai:inline-completions', autoStart: true, requires: [ @@ -55,19 +57,20 @@ export const completionPlugin: JupyterFrontEndPlugin = { ISettingRegistry ], optional: [IJaiStatusItem], + provides: IJaiCompletionProvider, activate: async ( app: JupyterFrontEnd, completionManager: ICompletionProviderManager, languageRegistry: IEditorLanguageRegistry, settingRegistry: ISettingRegistry, statusItem: IJaiStatusItem | null - ): Promise => { + ): Promise => { if (typeof completionManager.registerInlineProvider === 'undefined') { // Gracefully short-circuit on JupyterLab 4.0 and Notebook 7.0 console.warn( 'Inline completions are only supported in JupyterLab 4.1+ and Jupyter Notebook 7.1+' ); - return; + return null; } const completionHandler = new CompletionWebsocketHandler(); @@ -186,5 +189,6 @@ export const completionPlugin: JupyterFrontEndPlugin = { rank: 2 }); } + return provider; } }; diff --git a/packages/jupyter-ai/src/completions/provider.ts b/packages/jupyter-ai/src/completions/provider.ts index b1e0f6a89..786ced851 100644 --- a/packages/jupyter-ai/src/completions/provider.ts +++ b/packages/jupyter-ai/src/completions/provider.ts @@ -9,11 +9,13 @@ import { import { ISettingRegistry } from '@jupyterlab/settingregistry'; import { Notification, showErrorMessage } from '@jupyterlab/apputils'; import { JSONValue, PromiseDelegate } from '@lumino/coreutils'; +import { ISignal, Signal } from '@lumino/signaling'; import { IEditorLanguageRegistry, IEditorLanguage } from '@jupyterlab/codemirror'; import { NotebookPanel } from '@jupyterlab/notebook'; +import { IJaiCompletionProvider } from '../tokens'; import { AiCompleterService as AiService } from './types'; import { DocumentWidget } from '@jupyterlab/docregistry'; import { jupyternautIcon } from '../icons'; @@ -34,7 +36,9 @@ export function displayName(language: IEditorLanguage): string { return language.displayName ?? language.name; } -export class JaiInlineProvider implements IInlineCompletionProvider { +export class JaiInlineProvider + implements IInlineCompletionProvider, IJaiCompletionProvider +{ readonly identifier = JaiInlineProvider.ID; readonly icon = jupyternautIcon.bindprops({ width: 16, top: 1 }); @@ -181,6 +185,7 @@ export class JaiInlineProvider implements IInlineCompletionProvider { async configure(settings: { [property: string]: JSONValue }): Promise { this._settings = settings as unknown as JaiInlineProvider.ISettings; + this._settingsChanged.emit(); } isEnabled(): boolean { @@ -191,6 +196,10 @@ export class JaiInlineProvider implements IInlineCompletionProvider { return !this._settings.disabledLanguages.includes(language); } + get settingsChanged(): ISignal { + return this._settingsChanged; + } + /** * Process the stream chunk to make it available in the awaiting generator. */ @@ -250,6 +259,7 @@ export class JaiInlineProvider implements IInlineCompletionProvider { private _settings: JaiInlineProvider.ISettings = JaiInlineProvider.DEFAULT_SETTINGS; + private _settingsChanged = new Signal(this); private _streamPromises: Map> = new Map(); diff --git a/packages/jupyter-ai/src/components/chat-settings.tsx b/packages/jupyter-ai/src/components/chat-settings.tsx index 889342a24..87009ecf5 100644 --- a/packages/jupyter-ai/src/components/chat-settings.tsx +++ b/packages/jupyter-ai/src/components/chat-settings.tsx @@ -4,6 +4,7 @@ import { Box } from '@mui/system'; import { Alert, Button, + IconButton, FormControl, FormControlLabel, FormLabel, @@ -11,8 +12,11 @@ import { Radio, RadioGroup, TextField, + Tooltip, CircularProgress } from '@mui/material'; +import SettingsIcon from '@mui/icons-material/Settings'; +import WarningAmberIcon from '@mui/icons-material/WarningAmber'; import { Select } from './select'; import { AiService } from '../handler'; @@ -23,10 +27,13 @@ import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; import { minifyUpdate } from './settings/minify'; import { useStackingAlert } from './mui-extras/stacking-alert'; import { RendermimeMarkdown } from './rendermime-markdown'; +import { IJaiCompletionProvider } from '../tokens'; import { getProviderId, getModelLocalId } from '../utils'; type ChatSettingsProps = { rmRegistry: IRenderMimeRegistry; + completionProvider: IJaiCompletionProvider | null; + openInlineCompleterSettings: () => void; }; /** @@ -43,9 +50,17 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { // user inputs const [lmProvider, setLmProvider] = useState(null); + const [clmProvider, setClmProvider] = + useState(null); const [showLmLocalId, setShowLmLocalId] = useState(false); - const [helpMarkdown, setHelpMarkdown] = useState(null); + const [showClmLocalId, setShowClmLocalId] = useState(false); + const [chatHelpMarkdown, setChatHelpMarkdown] = useState(null); + const [completionHelpMarkdown, setCompletionHelpMarkdown] = useState< + string | null + >(null); const [lmLocalId, setLmLocalId] = useState(''); + const [clmLocalId, setClmLocalId] = useState(''); + const lmGlobalId = useMemo(() => { if (!lmProvider) { return null; @@ -53,6 +68,13 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { return lmProvider.id + ':' + lmLocalId; }, [lmProvider, lmLocalId]); + const clmGlobalId = useMemo(() => { + if (!clmProvider) { + return null; + } + + return clmProvider.id + ':' + clmLocalId; + }, [clmProvider, clmLocalId]); const [emGlobalId, setEmGlobalId] = useState(null); const emProvider = useMemo(() => { @@ -67,6 +89,25 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { const [sendWse, setSendWse] = useState(false); const [fields, setFields] = useState>({}); + const [isCompleterEnabled, setIsCompleterEnabled] = useState( + props.completionProvider && props.completionProvider.isEnabled() + ); + + const refreshCompleterState = () => { + setIsCompleterEnabled( + props.completionProvider && props.completionProvider.isEnabled() + ); + }; + + useEffect(() => { + props.completionProvider?.settingsChanged.connect(refreshCompleterState); + return () => { + props.completionProvider?.settingsChanged.disconnect( + refreshCompleterState + ); + }; + }, [props.completionProvider]); + // whether the form is currently saving const [saving, setSaving] = useState(false); @@ -78,14 +119,20 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { return; } - setLmLocalId(server.lmLocalId); + setLmLocalId(server.chat.lmLocalId); + setClmLocalId(server.completions.lmLocalId); setEmGlobalId(server.config.embeddings_provider_id); setSendWse(server.config.send_with_shift_enter); - setHelpMarkdown(server.lmProvider?.help ?? null); - if (server.lmProvider?.registry) { + setChatHelpMarkdown(server.chat.lmProvider?.help ?? null); + setCompletionHelpMarkdown(server.completions.lmProvider?.help ?? null); + if (server.chat.lmProvider?.registry) { setShowLmLocalId(true); } - setLmProvider(server.lmProvider); + if (server.completions.lmProvider?.registry) { + setShowClmLocalId(true); + } + setLmProvider(server.chat.lmProvider); + setClmProvider(server.completions.lmProvider); }, [server]); /** @@ -169,11 +216,17 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { model_provider_id: lmGlobalId, embeddings_provider_id: emGlobalId, api_keys: apiKeys, - ...(lmGlobalId && { + ...((lmGlobalId || clmGlobalId) && { fields: { - [lmGlobalId]: fields + ...(lmGlobalId && { + [lmGlobalId]: fields + }), + ...(clmGlobalId && { + [clmGlobalId]: fields + }) } }), + completions_model_provider_id: clmGlobalId, send_with_shift_enter: sendWse }; updateRequest = minifyUpdate(server.config, updateRequest); @@ -244,11 +297,16 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { overflowY: 'auto' }} > - {/* Language model section */} -

Language model

+ {/* Chat language model section */} +

+ Language model +

{showLmLocalId && ( @@ -290,10 +350,10 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { fullWidth /> )} - {helpMarkdown && ( + {chatHelpMarkdown && ( )} {lmGlobalId && ( @@ -327,6 +387,77 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { )} + {/* Completer language model section */} +

+ Inline completions model + +

+ + {showClmLocalId && ( + setClmLocalId(e.target.value)} + fullWidth + /> + )} + {completionHelpMarkdown && ( + + )} + {clmGlobalId && ( + + )} + {/* API Keys section */}

API Keys

{/* API key inputs for newly-used providers */} @@ -392,6 +523,37 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { ); } +function CompleterSettingsButton(props: { + selection: AiService.ListProvidersEntry | null; + provider: IJaiCompletionProvider | null; + isCompleterEnabled: boolean | null; + openSettings: () => void; +}): JSX.Element { + if (props.selection && !props.isCompleterEnabled) { + return ( + + + + + + ); + } + return ( + + + + + + ); +} + function getProvider( globalModelId: string, providers: AiService.ListProvidersResponse diff --git a/packages/jupyter-ai/src/components/chat.tsx b/packages/jupyter-ai/src/components/chat.tsx index c33d36f25..721d157f9 100644 --- a/packages/jupyter-ai/src/components/chat.tsx +++ b/packages/jupyter-ai/src/components/chat.tsx @@ -18,6 +18,7 @@ import { import { SelectionWatcher } from '../selection-watcher'; import { ChatHandler } from '../chat_handler'; import { CollaboratorsContextProvider } from '../contexts/collaborators-context'; +import { IJaiCompletionProvider } from '../tokens'; import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; import { ScrollContainer } from './scroll-container'; @@ -185,6 +186,8 @@ export type ChatProps = { themeManager: IThemeManager | null; rmRegistry: IRenderMimeRegistry; chatView?: ChatView; + completionProvider: IJaiCompletionProvider | null; + openInlineCompleterSettings: () => void; }; enum ChatView { @@ -237,7 +240,11 @@ export function Chat(props: ChatProps): JSX.Element { /> )} {view === ChatView.Settings && ( - + )} diff --git a/packages/jupyter-ai/src/components/settings/use-server-info.ts b/packages/jupyter-ai/src/components/settings/use-server-info.ts index 80c77fbed..d583b9579 100644 --- a/packages/jupyter-ai/src/components/settings/use-server-info.ts +++ b/packages/jupyter-ai/src/components/settings/use-server-info.ts @@ -2,15 +2,20 @@ import { useState, useEffect, useMemo, useCallback } from 'react'; import { AiService } from '../../handler'; import { getProviderId, getModelLocalId } from '../../utils'; -type ServerInfoProperties = { - config: AiService.DescribeConfigResponse; - lmProviders: AiService.ListProvidersResponse; - emProviders: AiService.ListProvidersResponse; +type ProvidersInfo = { lmProvider: AiService.ListProvidersEntry | null; emProvider: AiService.ListProvidersEntry | null; lmLocalId: string; }; +type ServerInfoProperties = { + lmProviders: AiService.ListProvidersResponse; + emProviders: AiService.ListProvidersResponse; + config: AiService.DescribeConfigResponse; + chat: ProvidersInfo; + completions: Omit; +}; + type ServerInfoMethods = { refetchAll: () => Promise; refetchApiKeys: () => Promise; @@ -65,13 +70,25 @@ export function useServerInfo(): ServerInfo { const emProvider = emGid === null ? null : getProvider(emGid, emProviders); const lmLocalId = (lmGid && getModelLocalId(lmGid)) ?? ''; + + const cLmGid = config.completions_model_provider_id; + const cLmProvider = + cLmGid === null ? null : getProvider(cLmGid, lmProviders); + const cLmLocalId = (cLmGid && getModelLocalId(cLmGid)) ?? ''; + setServerInfoProps({ config, lmProviders, emProviders, - lmProvider, - emProvider, - lmLocalId + chat: { + lmProvider, + emProvider, + lmLocalId + }, + completions: { + lmProvider: cLmProvider, + lmLocalId: cLmLocalId + } }); setState(ServerInfoState.Ready); diff --git a/packages/jupyter-ai/src/handler.ts b/packages/jupyter-ai/src/handler.ts index 7848dc20e..512479b7e 100644 --- a/packages/jupyter-ai/src/handler.ts +++ b/packages/jupyter-ai/src/handler.ts @@ -117,6 +117,7 @@ export namespace AiService { send_with_shift_enter: boolean; fields: Record>; last_read: number; + completions_model_provider_id: string | null; }; export type UpdateConfigRequest = { @@ -126,6 +127,8 @@ export namespace AiService { send_with_shift_enter?: boolean; fields?: Record>; last_read?: number; + completions_model_provider_id?: string | null; + completions_fields?: Record>; }; export async function getConfig(): Promise { @@ -182,6 +185,8 @@ export namespace AiService { help?: string; auth_strategy: AuthStrategy; registry: boolean; + completion_models: string[]; + chat_models: string[]; fields: Field[]; }; diff --git a/packages/jupyter-ai/src/index.ts b/packages/jupyter-ai/src/index.ts index df2ac69d0..d07a0dc34 100644 --- a/packages/jupyter-ai/src/index.ts +++ b/packages/jupyter-ai/src/index.ts @@ -18,6 +18,7 @@ import { ChatHandler } from './chat_handler'; import { buildErrorWidget } from './widgets/chat-error'; import { completionPlugin } from './completions'; import { statusItemPlugin } from './status'; +import { IJaiCompletionProvider } from './tokens'; import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; export type DocumentTracker = IWidgetTracker; @@ -28,14 +29,20 @@ export type DocumentTracker = IWidgetTracker; const plugin: JupyterFrontEndPlugin = { id: 'jupyter_ai:plugin', autoStart: true, - optional: [IGlobalAwareness, ILayoutRestorer, IThemeManager], + optional: [ + IGlobalAwareness, + ILayoutRestorer, + IThemeManager, + IJaiCompletionProvider + ], requires: [IRenderMimeRegistry], activate: async ( app: JupyterFrontEnd, rmRegistry: IRenderMimeRegistry, globalAwareness: Awareness | null, restorer: ILayoutRestorer | null, - themeManager: IThemeManager | null + themeManager: IThemeManager | null, + completionProvider: IJaiCompletionProvider | null ) => { /** * Initialize selection watcher singleton @@ -47,6 +54,12 @@ const plugin: JupyterFrontEndPlugin = { */ const chatHandler = new ChatHandler(); + const openInlineCompleterSettings = () => { + app.commands.execute('settingeditor:open', { + query: 'Inline Completer' + }); + }; + let chatWidget: ReactWidget | null = null; try { await chatHandler.initialize(); @@ -55,7 +68,9 @@ const plugin: JupyterFrontEndPlugin = { chatHandler, globalAwareness, themeManager, - rmRegistry + rmRegistry, + completionProvider, + openInlineCompleterSettings ); } catch (e) { chatWidget = buildErrorWidget(themeManager); diff --git a/packages/jupyter-ai/src/tokens.ts b/packages/jupyter-ai/src/tokens.ts index f240a8197..f0f301982 100644 --- a/packages/jupyter-ai/src/tokens.ts +++ b/packages/jupyter-ai/src/tokens.ts @@ -1,4 +1,5 @@ import { Token } from '@lumino/coreutils'; +import { ISignal } from '@lumino/signaling'; import type { IRankedMenu } from '@jupyterlab/ui-components'; export interface IJaiStatusItem { @@ -12,3 +13,16 @@ export const IJaiStatusItem = new Token( 'jupyter_ai:IJupyternautStatus', 'Status indicator displayed in the statusbar' ); + +export interface IJaiCompletionProvider { + isEnabled(): boolean; + settingsChanged: ISignal; +} + +/** + * The inline completion provider token. + */ +export const IJaiCompletionProvider = new Token( + 'jupyter_ai:IJaiCompletionProvider', + 'The jupyter-ai inline completion provider API' +); diff --git a/packages/jupyter-ai/src/widgets/chat-sidebar.tsx b/packages/jupyter-ai/src/widgets/chat-sidebar.tsx index 291593b41..bb575febc 100644 --- a/packages/jupyter-ai/src/widgets/chat-sidebar.tsx +++ b/packages/jupyter-ai/src/widgets/chat-sidebar.tsx @@ -7,6 +7,7 @@ import { Chat } from '../components/chat'; import { chatIcon } from '../icons'; import { SelectionWatcher } from '../selection-watcher'; import { ChatHandler } from '../chat_handler'; +import { IJaiCompletionProvider } from '../tokens'; import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; export function buildChatSidebar( @@ -14,7 +15,9 @@ export function buildChatSidebar( chatHandler: ChatHandler, globalAwareness: Awareness | null, themeManager: IThemeManager | null, - rmRegistry: IRenderMimeRegistry + rmRegistry: IRenderMimeRegistry, + completionProvider: IJaiCompletionProvider | null, + openInlineCompleterSettings: () => void ): ReactWidget { const ChatWidget = ReactWidget.create( ); ChatWidget.id = 'jupyter-ai::chat';