From 67bc59f746fd328bf618fb95ed89d74aa40c0ec3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Krassowski?= <5832902+krassowski@users.noreply.github.com> Date: Fri, 3 May 2024 19:32:17 +0100 Subject: [PATCH] Distinguish between completion and chat models (#711) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Distinguish between completion and chat models * Fix tests * Shorten the tab name, move settings button Lint * Implement the completion model selection in chat UI * Improve docstring * Call `_validate_lm_em_id` only once, add typing annotations * Remove embeddings provider for completions as the team has no plans to support it :( * Use type alias to reduce changeset/make review easier Without this change prettier reformats the plugin with an extra indentation, which leads to bad changeset display on GitHub. * Rename `_validate_lm_em_id` to `_validate_model_ids` * Rename `LLMHandlerMixin` to `CompletionsModelMixin` and rename the file from `llm_mixin` to `model_mixin` for consistency. Of note, the file name does not need `completions_` prefix as the file is in `completions/` subdirectory. * Rename "Chat LM" to "LM"; add title attribute; note using the title attribute because getting the icon to show up nicely (getting the nice grey color and positioning as it gets in buttons, compared to just plain black) was not trivial; I think the icon might be the way to go in the future but I would postpone it to another PR. That said, I still think it should say "Chat LM" because it has no effect on magics nor completions. 
* Rename heading "Completer model" → "Inline completions model" * Move `UseSignal` down to `CompleterSettingsButton` implementation * Rename the label in the select to "Inline completion model" * Disable selection when completer is not enabled * Remove use of `UseSignal`, tweak naming of `useState` from `completerIsEnabled` to `isCompleterEnabled` * Use mui tooltips * Fix use of `jai_config_manager` * Fix tests --- .../jupyter_ai_magics/providers.py | 10 + .../jupyter_ai/completions/handlers/base.py | 4 +- .../handlers/{llm_mixin.py => model_mixin.py} | 8 +- .../jupyter_ai/config/config_schema.json | 17 ++ .../jupyter-ai/jupyter_ai/config_manager.py | 150 ++++++------- packages/jupyter-ai/jupyter_ai/handlers.py | 6 + packages/jupyter-ai/jupyter_ai/models.py | 8 + .../__snapshots__/test_config_manager.ambr | 3 + .../tests/completions/test_handlers.py | 4 +- packages/jupyter-ai/src/completions/plugin.ts | 12 +- .../jupyter-ai/src/completions/provider.ts | 12 +- .../src/components/chat-settings.tsx | 198 ++++++++++++++++-- packages/jupyter-ai/src/components/chat.tsx | 9 +- .../components/settings/use-server-info.ts | 31 ++- packages/jupyter-ai/src/handler.ts | 5 + packages/jupyter-ai/src/index.ts | 21 +- packages/jupyter-ai/src/tokens.ts | 14 ++ .../jupyter-ai/src/widgets/chat-sidebar.tsx | 7 +- 18 files changed, 402 insertions(+), 117 deletions(-) rename packages/jupyter-ai/jupyter_ai/completions/handlers/{llm_mixin.py => model_mixin.py} (89%) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py index 4daf42bba..e86613038 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py @@ -258,6 +258,16 @@ class Config: provider is selected. 
""" + @classmethod + def chat_models(self): + """Models which are suitable for chat.""" + return self.models + + @classmethod + def completion_models(self): + """Models which are suitable for completions.""" + return self.models + # # instance attrs # diff --git a/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py b/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py index 9eb4f845a..32920dc83 100644 --- a/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py @@ -5,7 +5,7 @@ from typing import Union import tornado -from jupyter_ai.completions.handlers.llm_mixin import LLMHandlerMixin +from jupyter_ai.completions.handlers.model_mixin import CompletionsModelMixin from jupyter_ai.completions.models import ( CompletionError, InlineCompletionList, @@ -18,7 +18,7 @@ class BaseInlineCompletionHandler( - LLMHandlerMixin, JupyterHandler, tornado.websocket.WebSocketHandler + CompletionsModelMixin, JupyterHandler, tornado.websocket.WebSocketHandler ): """A Tornado WebSocket handler that receives inline completion requests and fulfills them accordingly. 
This class is instantiated once per WebSocket diff --git a/packages/jupyter-ai/jupyter_ai/completions/handlers/llm_mixin.py b/packages/jupyter-ai/jupyter_ai/completions/handlers/model_mixin.py similarity index 89% rename from packages/jupyter-ai/jupyter_ai/completions/handlers/llm_mixin.py rename to packages/jupyter-ai/jupyter_ai/completions/handlers/model_mixin.py index ed454fff6..fd19498bd 100644 --- a/packages/jupyter-ai/jupyter_ai/completions/handlers/llm_mixin.py +++ b/packages/jupyter-ai/jupyter_ai/completions/handlers/model_mixin.py @@ -5,8 +5,8 @@ from jupyter_ai_magics.providers import BaseProvider -class LLMHandlerMixin: - """Base class containing shared methods and attributes used by LLM handler classes.""" +class CompletionsModelMixin: + """Mixin class containing methods and attributes used by completions LLM handler.""" handler_kind: str settings: dict @@ -26,8 +26,8 @@ def __init__(self, *args, **kwargs) -> None: self._llm_params = None def get_llm(self) -> Optional[BaseProvider]: - lm_provider = self.jai_config_manager.lm_provider - lm_provider_params = self.jai_config_manager.lm_provider_params + lm_provider = self.jai_config_manager.completions_lm_provider + lm_provider_params = self.jai_config_manager.completions_lm_provider_params if not lm_provider or not lm_provider_params: return None diff --git a/packages/jupyter-ai/jupyter_ai/config/config_schema.json b/packages/jupyter-ai/jupyter_ai/config/config_schema.json index ff7c717c4..d276dd31d 100644 --- a/packages/jupyter-ai/jupyter_ai/config/config_schema.json +++ b/packages/jupyter-ai/jupyter_ai/config/config_schema.json @@ -16,6 +16,12 @@ "default": null, "readOnly": false }, + "completions_model_provider_id": { + "$comment": "Language model global ID for completions.", + "type": ["string", "null"], + "default": null, + "readOnly": false + }, "api_keys": { "$comment": "Dictionary of API keys, mapping key names to key values.", "type": "object", @@ -37,6 +43,17 @@ } }, "additionalProperties": 
false + }, + "completions_fields": { + "$comment": "Dictionary of model-specific fields, mapping LM GIDs to sub-dictionaries of field key-value pairs for completions.", + "type": "object", + "default": {}, + "patternProperties": { + "^.*$": { + "anyOf": [{ "type": "object" }] + } + }, + "additionalProperties": false } } } diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index 392e44601..a05f9b3ca 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -97,6 +97,10 @@ class ConfigManager(Configurable): config=True, ) + model_provider_id: Optional[str] + embeddings_provider_id: Optional[str] + completions_model_provider_id: Optional[str] + def __init__( self, log: Logger, @@ -164,41 +168,49 @@ def _process_existing_config(self, default_config): {k: v for k, v in existing_config.items() if v is not None}, ) config = GlobalConfig(**merged_config) - validated_config = self._validate_lm_em_id(config) + validated_config = self._validate_model_ids(config) # re-write to the file to validate the config and apply any # updates to the config file immediately self._write_config(validated_config) - def _validate_lm_em_id(self, config): - lm_id = config.model_provider_id - em_id = config.embeddings_provider_id + def _validate_model_ids(self, config): + lm_provider_keys = ["model_provider_id", "completions_model_provider_id"] + em_provider_keys = ["embeddings_provider_id"] # if the currently selected language or embedding model are # forbidden, set them to `None` and log a warning. - if lm_id is not None and not self._validate_model(lm_id, raise_exc=False): - self.log.warning( - f"Language model {lm_id} is forbidden by current allow/blocklists. Setting to None." 
- ) - config.model_provider_id = None - if em_id is not None and not self._validate_model(em_id, raise_exc=False): - self.log.warning( - f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None." - ) - config.embeddings_provider_id = None + for lm_key in lm_provider_keys: + lm_id = getattr(config, lm_key) + if lm_id is not None and not self._validate_model(lm_id, raise_exc=False): + self.log.warning( + f"Language model {lm_id} is forbidden by current allow/blocklists. Setting to None." + ) + setattr(config, lm_key, None) + for em_key in em_provider_keys: + em_id = getattr(config, em_key) + if em_id is not None and not self._validate_model(em_id, raise_exc=False): + self.log.warning( + f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None." + ) + setattr(config, em_key, None) # if the currently selected language or embedding model ids are # not associated with models, set them to `None` and log a warning. - if lm_id is not None and not get_lm_provider(lm_id, self._lm_providers)[1]: - self.log.warning( - f"No language model is associated with '{lm_id}'. Setting to None." - ) - config.model_provider_id = None - if em_id is not None and not get_em_provider(em_id, self._em_providers)[1]: - self.log.warning( - f"No embedding model is associated with '{em_id}'. Setting to None." - ) - config.embeddings_provider_id = None + for lm_key in lm_provider_keys: + lm_id = getattr(config, lm_key) + if lm_id is not None and not get_lm_provider(lm_id, self._lm_providers)[1]: + self.log.warning( + f"No language model is associated with '{lm_id}'. Setting to None." + ) + setattr(config, lm_key, None) + for em_key in em_provider_keys: + em_id = getattr(config, em_key) + if em_id is not None and not get_em_provider(em_id, self._em_providers)[1]: + self.log.warning( + f"No embedding model is associated with '{em_id}'. Setting to None." 
+ ) + setattr(config, em_key, None) return config @@ -321,6 +333,9 @@ def _write_config(self, new_config: GlobalConfig): complete `GlobalConfig` object, and should not be called publicly.""" # remove any empty field dictionaries new_config.fields = {k: v for k, v in new_config.fields.items() if v} + new_config.completions_fields = { + k: v for k, v in new_config.completions_fields.items() if v + } self._validate_config(new_config) with open(self.config_path, "w") as f: @@ -328,21 +343,18 @@ def _write_config(self, new_config: GlobalConfig): def delete_api_key(self, key_name: str): config_dict = self._read_config().dict() - lm_provider = self.lm_provider - em_provider = self.em_provider required_keys = [] - if ( - lm_provider - and lm_provider.auth_strategy - and lm_provider.auth_strategy.type == "env" - ): - required_keys.append(lm_provider.auth_strategy.name) - if ( - em_provider - and em_provider.auth_strategy - and em_provider.auth_strategy.type == "env" - ): - required_keys.append(self.em_provider.auth_strategy.name) + for provider in [ + self.lm_provider, + self.em_provider, + self.completions_lm_provider, + ]: + if ( + provider + and provider.auth_strategy + and provider.auth_strategy.type == "env" + ): + required_keys.append(provider.auth_strategy.name) if key_name in required_keys: raise KeyInUseError( @@ -390,67 +402,57 @@ def em_gid(self): @property def lm_provider(self): - config = self._read_config() - lm_gid = config.model_provider_id - if lm_gid is None: - return None - - _, Provider = get_lm_provider(config.model_provider_id, self._lm_providers) - return Provider + return self._get_provider("model_provider_id", self._lm_providers) @property def em_provider(self): + return self._get_provider("embeddings_provider_id", self._em_providers) + + @property + def completions_lm_provider(self): + return self._get_provider("completions_model_provider_id", self._lm_providers) + + def _get_provider(self, key, listing): config = self._read_config() - em_gid = 
config.embeddings_provider_id - if em_gid is None: + gid = getattr(config, key) + if gid is None: return None - _, Provider = get_em_provider(em_gid, self._em_providers) + _, Provider = get_lm_provider(gid, listing) return Provider @property def lm_provider_params(self): - # get generic fields - config = self._read_config() - lm_gid = config.model_provider_id - if not lm_gid: - return None - - lm_lid = lm_gid.split(":", 1)[1] - fields = config.fields.get(lm_gid, {}) - - # get authn fields - _, Provider = get_lm_provider(lm_gid, self._lm_providers) - authn_fields = {} - if Provider.auth_strategy and Provider.auth_strategy.type == "env": - key_name = Provider.auth_strategy.name - authn_fields[key_name.lower()] = config.api_keys[key_name] - - return { - "model_id": lm_lid, - **fields, - **authn_fields, - } + return self._provider_params("model_provider_id", self._lm_providers) @property def em_provider_params(self): + return self._provider_params("embeddings_provider_id", self._em_providers) + + @property + def completions_lm_provider_params(self): + return self._provider_params( + "completions_model_provider_id", self._lm_providers + ) + + def _provider_params(self, key, listing): # get generic fields config = self._read_config() - em_gid = config.embeddings_provider_id - if not em_gid: + gid = getattr(config, key) + if not gid: return None - em_lid = em_gid.split(":", 1)[1] + lid = gid.split(":", 1)[1] # get authn fields - _, Provider = get_em_provider(em_gid, self._em_providers) + _, Provider = get_em_provider(gid, listing) authn_fields = {} if Provider.auth_strategy and Provider.auth_strategy.type == "env": key_name = Provider.auth_strategy.name authn_fields[key_name.lower()] = config.api_keys[key_name] return { - "model_id": em_lid, + "model_id": lid, **authn_fields, } diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index 6002310ba..976b7df62 100644 --- a/packages/jupyter-ai/jupyter_ai/handlers.py +++ 
b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -290,6 +290,10 @@ def filter_predicate(local_model_id: str): # filter out every model w/ model ID according to allow/blocklist for provider in providers: provider.models = list(filter(filter_predicate, provider.models)) + provider.chat_models = list(filter(filter_predicate, provider.chat_models)) + provider.completion_models = list( + filter(filter_predicate, provider.completion_models) + ) # filter out every provider with no models which satisfy the allow/blocklist, then return return filter((lambda p: len(p.models) > 0), providers) @@ -311,6 +315,8 @@ def get(self): id=provider.id, name=provider.name, models=provider.models, + chat_models=provider.chat_models(), + completion_models=provider.completion_models(), help=provider.help, auth_strategy=provider.auth_strategy, registry=provider.registry, diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index 32353a694..6742a922d 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -94,6 +94,8 @@ class ListProvidersEntry(BaseModel): auth_strategy: AuthStrategy registry: bool fields: List[Field] + chat_models: Optional[List[str]] + completion_models: Optional[List[str]] class ListProvidersResponse(BaseModel): @@ -121,6 +123,8 @@ class DescribeConfigResponse(BaseModel): # timestamp indicating when the configuration file was last read. should be # passed to the subsequent UpdateConfig request. last_read: int + completions_model_provider_id: Optional[str] + completions_fields: Dict[str, Dict[str, Any]] def forbid_none(cls, v): @@ -137,6 +141,8 @@ class UpdateConfigRequest(BaseModel): # if passed, this will raise an Error if the config was written to after the # time specified by `last_read` to prevent write-write conflicts. 
last_read: Optional[int] + completions_model_provider_id: Optional[str] + completions_fields: Optional[Dict[str, Dict[str, Any]]] _validate_send_wse = validator("send_with_shift_enter", allow_reuse=True)( forbid_none @@ -154,3 +160,5 @@ class GlobalConfig(BaseModel): send_with_shift_enter: bool fields: Dict[str, Dict[str, Any]] api_keys: Dict[str, str] + completions_model_provider_id: Optional[str] + completions_fields: Dict[str, Dict[str, Any]] diff --git a/packages/jupyter-ai/jupyter_ai/tests/__snapshots__/test_config_manager.ambr b/packages/jupyter-ai/jupyter_ai/tests/__snapshots__/test_config_manager.ambr index cd254bedc..111ebe6ea 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/__snapshots__/test_config_manager.ambr +++ b/packages/jupyter-ai/jupyter_ai/tests/__snapshots__/test_config_manager.ambr @@ -3,6 +3,9 @@ dict({ 'api_keys': list([ ]), + 'completions_fields': dict({ + }), + 'completions_model_provider_id': None, 'embeddings_provider_id': None, 'fields': dict({ }), diff --git a/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py b/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py index 8597463f2..c5b5d1eea 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py +++ b/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py @@ -35,8 +35,8 @@ def __init__(self, lm_provider=None, lm_provider_params=None): self.messages = [] self.tasks = [] self.settings["jai_config_manager"] = SimpleNamespace( - lm_provider=lm_provider or MockProvider, - lm_provider_params=lm_provider_params or {"model_id": "model"}, + completions_lm_provider=lm_provider or MockProvider, + completions_lm_provider_params=lm_provider_params or {"model_id": "model"}, ) self.settings["jai_event_loop"] = SimpleNamespace( create_task=lambda x: self.tasks.append(x) diff --git a/packages/jupyter-ai/src/completions/plugin.ts b/packages/jupyter-ai/src/completions/plugin.ts index 835561143..f5963f36f 100644 --- 
a/packages/jupyter-ai/src/completions/plugin.ts +++ b/packages/jupyter-ai/src/completions/plugin.ts @@ -9,7 +9,7 @@ import { IEditorLanguage } from '@jupyterlab/codemirror'; import { getEditor } from '../selection-watcher'; -import { IJaiStatusItem } from '../tokens'; +import { IJaiStatusItem, IJaiCompletionProvider } from '../tokens'; import { displayName, JaiInlineProvider } from './provider'; import { CompletionWebsocketHandler } from './handler'; @@ -46,7 +46,9 @@ type IcPluginSettings = ISettingRegistry.ISettings & { }; }; -export const completionPlugin: JupyterFrontEndPlugin = { +type JaiCompletionToken = IJaiCompletionProvider | null; + +export const completionPlugin: JupyterFrontEndPlugin = { id: 'jupyter_ai:inline-completions', autoStart: true, requires: [ @@ -55,19 +57,20 @@ export const completionPlugin: JupyterFrontEndPlugin = { ISettingRegistry ], optional: [IJaiStatusItem], + provides: IJaiCompletionProvider, activate: async ( app: JupyterFrontEnd, completionManager: ICompletionProviderManager, languageRegistry: IEditorLanguageRegistry, settingRegistry: ISettingRegistry, statusItem: IJaiStatusItem | null - ): Promise => { + ): Promise => { if (typeof completionManager.registerInlineProvider === 'undefined') { // Gracefully short-circuit on JupyterLab 4.0 and Notebook 7.0 console.warn( 'Inline completions are only supported in JupyterLab 4.1+ and Jupyter Notebook 7.1+' ); - return; + return null; } const completionHandler = new CompletionWebsocketHandler(); @@ -186,5 +189,6 @@ export const completionPlugin: JupyterFrontEndPlugin = { rank: 2 }); } + return provider; } }; diff --git a/packages/jupyter-ai/src/completions/provider.ts b/packages/jupyter-ai/src/completions/provider.ts index b1e0f6a89..786ced851 100644 --- a/packages/jupyter-ai/src/completions/provider.ts +++ b/packages/jupyter-ai/src/completions/provider.ts @@ -9,11 +9,13 @@ import { import { ISettingRegistry } from '@jupyterlab/settingregistry'; import { Notification, showErrorMessage } 
from '@jupyterlab/apputils'; import { JSONValue, PromiseDelegate } from '@lumino/coreutils'; +import { ISignal, Signal } from '@lumino/signaling'; import { IEditorLanguageRegistry, IEditorLanguage } from '@jupyterlab/codemirror'; import { NotebookPanel } from '@jupyterlab/notebook'; +import { IJaiCompletionProvider } from '../tokens'; import { AiCompleterService as AiService } from './types'; import { DocumentWidget } from '@jupyterlab/docregistry'; import { jupyternautIcon } from '../icons'; @@ -34,7 +36,9 @@ export function displayName(language: IEditorLanguage): string { return language.displayName ?? language.name; } -export class JaiInlineProvider implements IInlineCompletionProvider { +export class JaiInlineProvider + implements IInlineCompletionProvider, IJaiCompletionProvider +{ readonly identifier = JaiInlineProvider.ID; readonly icon = jupyternautIcon.bindprops({ width: 16, top: 1 }); @@ -181,6 +185,7 @@ export class JaiInlineProvider implements IInlineCompletionProvider { async configure(settings: { [property: string]: JSONValue }): Promise { this._settings = settings as unknown as JaiInlineProvider.ISettings; + this._settingsChanged.emit(); } isEnabled(): boolean { @@ -191,6 +196,10 @@ export class JaiInlineProvider implements IInlineCompletionProvider { return !this._settings.disabledLanguages.includes(language); } + get settingsChanged(): ISignal { + return this._settingsChanged; + } + /** * Process the stream chunk to make it available in the awaiting generator. 
*/ @@ -250,6 +259,7 @@ export class JaiInlineProvider implements IInlineCompletionProvider { private _settings: JaiInlineProvider.ISettings = JaiInlineProvider.DEFAULT_SETTINGS; + private _settingsChanged = new Signal(this); private _streamPromises: Map> = new Map(); diff --git a/packages/jupyter-ai/src/components/chat-settings.tsx b/packages/jupyter-ai/src/components/chat-settings.tsx index 889342a24..87009ecf5 100644 --- a/packages/jupyter-ai/src/components/chat-settings.tsx +++ b/packages/jupyter-ai/src/components/chat-settings.tsx @@ -4,6 +4,7 @@ import { Box } from '@mui/system'; import { Alert, Button, + IconButton, FormControl, FormControlLabel, FormLabel, @@ -11,8 +12,11 @@ import { Radio, RadioGroup, TextField, + Tooltip, CircularProgress } from '@mui/material'; +import SettingsIcon from '@mui/icons-material/Settings'; +import WarningAmberIcon from '@mui/icons-material/WarningAmber'; import { Select } from './select'; import { AiService } from '../handler'; @@ -23,10 +27,13 @@ import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; import { minifyUpdate } from './settings/minify'; import { useStackingAlert } from './mui-extras/stacking-alert'; import { RendermimeMarkdown } from './rendermime-markdown'; +import { IJaiCompletionProvider } from '../tokens'; import { getProviderId, getModelLocalId } from '../utils'; type ChatSettingsProps = { rmRegistry: IRenderMimeRegistry; + completionProvider: IJaiCompletionProvider | null; + openInlineCompleterSettings: () => void; }; /** @@ -43,9 +50,17 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { // user inputs const [lmProvider, setLmProvider] = useState(null); + const [clmProvider, setClmProvider] = + useState(null); const [showLmLocalId, setShowLmLocalId] = useState(false); - const [helpMarkdown, setHelpMarkdown] = useState(null); + const [showClmLocalId, setShowClmLocalId] = useState(false); + const [chatHelpMarkdown, setChatHelpMarkdown] = useState(null); + const 
[completionHelpMarkdown, setCompletionHelpMarkdown] = useState< + string | null + >(null); const [lmLocalId, setLmLocalId] = useState(''); + const [clmLocalId, setClmLocalId] = useState(''); + const lmGlobalId = useMemo(() => { if (!lmProvider) { return null; @@ -53,6 +68,13 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { return lmProvider.id + ':' + lmLocalId; }, [lmProvider, lmLocalId]); + const clmGlobalId = useMemo(() => { + if (!clmProvider) { + return null; + } + + return clmProvider.id + ':' + clmLocalId; + }, [clmProvider, clmLocalId]); const [emGlobalId, setEmGlobalId] = useState(null); const emProvider = useMemo(() => { @@ -67,6 +89,25 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { const [sendWse, setSendWse] = useState(false); const [fields, setFields] = useState>({}); + const [isCompleterEnabled, setIsCompleterEnabled] = useState( + props.completionProvider && props.completionProvider.isEnabled() + ); + + const refreshCompleterState = () => { + setIsCompleterEnabled( + props.completionProvider && props.completionProvider.isEnabled() + ); + }; + + useEffect(() => { + props.completionProvider?.settingsChanged.connect(refreshCompleterState); + return () => { + props.completionProvider?.settingsChanged.disconnect( + refreshCompleterState + ); + }; + }, [props.completionProvider]); + // whether the form is currently saving const [saving, setSaving] = useState(false); @@ -78,14 +119,20 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { return; } - setLmLocalId(server.lmLocalId); + setLmLocalId(server.chat.lmLocalId); + setClmLocalId(server.completions.lmLocalId); setEmGlobalId(server.config.embeddings_provider_id); setSendWse(server.config.send_with_shift_enter); - setHelpMarkdown(server.lmProvider?.help ?? null); - if (server.lmProvider?.registry) { + setChatHelpMarkdown(server.chat.lmProvider?.help ?? null); + setCompletionHelpMarkdown(server.completions.lmProvider?.help ?? 
null); + if (server.chat.lmProvider?.registry) { setShowLmLocalId(true); } - setLmProvider(server.lmProvider); + if (server.completions.lmProvider?.registry) { + setShowClmLocalId(true); + } + setLmProvider(server.chat.lmProvider); + setClmProvider(server.completions.lmProvider); }, [server]); /** @@ -169,11 +216,17 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { model_provider_id: lmGlobalId, embeddings_provider_id: emGlobalId, api_keys: apiKeys, - ...(lmGlobalId && { + ...((lmGlobalId || clmGlobalId) && { fields: { - [lmGlobalId]: fields + ...(lmGlobalId && { + [lmGlobalId]: fields + }), + ...(clmGlobalId && { + [clmGlobalId]: fields + }) } }), + completions_model_provider_id: clmGlobalId, send_with_shift_enter: sendWse }; updateRequest = minifyUpdate(server.config, updateRequest); @@ -244,11 +297,16 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { overflowY: 'auto' }} > - {/* Language model section */} -

Language model

+ {/* Chat language model section */} +

+ Language model +

{showLmLocalId && ( @@ -290,10 +350,10 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { fullWidth /> )} - {helpMarkdown && ( + {chatHelpMarkdown && ( )} {lmGlobalId && ( @@ -327,6 +387,77 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { )} + {/* Completer language model section */} +

+ Inline completions model + +

+ + {showClmLocalId && ( + setClmLocalId(e.target.value)} + fullWidth + /> + )} + {completionHelpMarkdown && ( + + )} + {clmGlobalId && ( + + )} + {/* API Keys section */}

API Keys

{/* API key inputs for newly-used providers */} @@ -392,6 +523,37 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { ); } +function CompleterSettingsButton(props: { + selection: AiService.ListProvidersEntry | null; + provider: IJaiCompletionProvider | null; + isCompleterEnabled: boolean | null; + openSettings: () => void; +}): JSX.Element { + if (props.selection && !props.isCompleterEnabled) { + return ( + + + + + + ); + } + return ( + + + + + + ); +} + function getProvider( globalModelId: string, providers: AiService.ListProvidersResponse diff --git a/packages/jupyter-ai/src/components/chat.tsx b/packages/jupyter-ai/src/components/chat.tsx index c33d36f25..721d157f9 100644 --- a/packages/jupyter-ai/src/components/chat.tsx +++ b/packages/jupyter-ai/src/components/chat.tsx @@ -18,6 +18,7 @@ import { import { SelectionWatcher } from '../selection-watcher'; import { ChatHandler } from '../chat_handler'; import { CollaboratorsContextProvider } from '../contexts/collaborators-context'; +import { IJaiCompletionProvider } from '../tokens'; import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; import { ScrollContainer } from './scroll-container'; @@ -185,6 +186,8 @@ export type ChatProps = { themeManager: IThemeManager | null; rmRegistry: IRenderMimeRegistry; chatView?: ChatView; + completionProvider: IJaiCompletionProvider | null; + openInlineCompleterSettings: () => void; }; enum ChatView { @@ -237,7 +240,11 @@ export function Chat(props: ChatProps): JSX.Element { /> )} {view === ChatView.Settings && ( - + )} diff --git a/packages/jupyter-ai/src/components/settings/use-server-info.ts b/packages/jupyter-ai/src/components/settings/use-server-info.ts index 80c77fbed..d583b9579 100644 --- a/packages/jupyter-ai/src/components/settings/use-server-info.ts +++ b/packages/jupyter-ai/src/components/settings/use-server-info.ts @@ -2,15 +2,20 @@ import { useState, useEffect, useMemo, useCallback } from 'react'; import { AiService } from '../../handler'; 
import { getProviderId, getModelLocalId } from '../../utils'; -type ServerInfoProperties = { - config: AiService.DescribeConfigResponse; - lmProviders: AiService.ListProvidersResponse; - emProviders: AiService.ListProvidersResponse; +type ProvidersInfo = { lmProvider: AiService.ListProvidersEntry | null; emProvider: AiService.ListProvidersEntry | null; lmLocalId: string; }; +type ServerInfoProperties = { + lmProviders: AiService.ListProvidersResponse; + emProviders: AiService.ListProvidersResponse; + config: AiService.DescribeConfigResponse; + chat: ProvidersInfo; + completions: Omit; +}; + type ServerInfoMethods = { refetchAll: () => Promise; refetchApiKeys: () => Promise; @@ -65,13 +70,25 @@ export function useServerInfo(): ServerInfo { const emProvider = emGid === null ? null : getProvider(emGid, emProviders); const lmLocalId = (lmGid && getModelLocalId(lmGid)) ?? ''; + + const cLmGid = config.completions_model_provider_id; + const cLmProvider = + cLmGid === null ? null : getProvider(cLmGid, lmProviders); + const cLmLocalId = (cLmGid && getModelLocalId(cLmGid)) ?? 
''; + setServerInfoProps({ config, lmProviders, emProviders, - lmProvider, - emProvider, - lmLocalId + chat: { + lmProvider, + emProvider, + lmLocalId + }, + completions: { + lmProvider: cLmProvider, + lmLocalId: cLmLocalId + } }); setState(ServerInfoState.Ready); diff --git a/packages/jupyter-ai/src/handler.ts b/packages/jupyter-ai/src/handler.ts index 7848dc20e..512479b7e 100644 --- a/packages/jupyter-ai/src/handler.ts +++ b/packages/jupyter-ai/src/handler.ts @@ -117,6 +117,7 @@ export namespace AiService { send_with_shift_enter: boolean; fields: Record>; last_read: number; + completions_model_provider_id: string | null; }; export type UpdateConfigRequest = { @@ -126,6 +127,8 @@ export namespace AiService { send_with_shift_enter?: boolean; fields?: Record>; last_read?: number; + completions_model_provider_id?: string | null; + completions_fields?: Record>; }; export async function getConfig(): Promise { @@ -182,6 +185,8 @@ export namespace AiService { help?: string; auth_strategy: AuthStrategy; registry: boolean; + completion_models: string[]; + chat_models: string[]; fields: Field[]; }; diff --git a/packages/jupyter-ai/src/index.ts b/packages/jupyter-ai/src/index.ts index df2ac69d0..d07a0dc34 100644 --- a/packages/jupyter-ai/src/index.ts +++ b/packages/jupyter-ai/src/index.ts @@ -18,6 +18,7 @@ import { ChatHandler } from './chat_handler'; import { buildErrorWidget } from './widgets/chat-error'; import { completionPlugin } from './completions'; import { statusItemPlugin } from './status'; +import { IJaiCompletionProvider } from './tokens'; import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; export type DocumentTracker = IWidgetTracker; @@ -28,14 +29,20 @@ export type DocumentTracker = IWidgetTracker; const plugin: JupyterFrontEndPlugin = { id: 'jupyter_ai:plugin', autoStart: true, - optional: [IGlobalAwareness, ILayoutRestorer, IThemeManager], + optional: [ + IGlobalAwareness, + ILayoutRestorer, + IThemeManager, + IJaiCompletionProvider + ], requires: 
[IRenderMimeRegistry], activate: async ( app: JupyterFrontEnd, rmRegistry: IRenderMimeRegistry, globalAwareness: Awareness | null, restorer: ILayoutRestorer | null, - themeManager: IThemeManager | null + themeManager: IThemeManager | null, + completionProvider: IJaiCompletionProvider | null ) => { /** * Initialize selection watcher singleton @@ -47,6 +54,12 @@ const plugin: JupyterFrontEndPlugin = { */ const chatHandler = new ChatHandler(); + const openInlineCompleterSettings = () => { + app.commands.execute('settingeditor:open', { + query: 'Inline Completer' + }); + }; + let chatWidget: ReactWidget | null = null; try { await chatHandler.initialize(); @@ -55,7 +68,9 @@ const plugin: JupyterFrontEndPlugin = { chatHandler, globalAwareness, themeManager, - rmRegistry + rmRegistry, + completionProvider, + openInlineCompleterSettings ); } catch (e) { chatWidget = buildErrorWidget(themeManager); diff --git a/packages/jupyter-ai/src/tokens.ts b/packages/jupyter-ai/src/tokens.ts index f240a8197..f0f301982 100644 --- a/packages/jupyter-ai/src/tokens.ts +++ b/packages/jupyter-ai/src/tokens.ts @@ -1,4 +1,5 @@ import { Token } from '@lumino/coreutils'; +import { ISignal } from '@lumino/signaling'; import type { IRankedMenu } from '@jupyterlab/ui-components'; export interface IJaiStatusItem { @@ -12,3 +13,16 @@ export const IJaiStatusItem = new Token( 'jupyter_ai:IJupyternautStatus', 'Status indicator displayed in the statusbar' ); + +export interface IJaiCompletionProvider { + isEnabled(): boolean; + settingsChanged: ISignal; +} + +/** + * The inline completion provider token. 
+ */ +export const IJaiCompletionProvider = new Token( + 'jupyter_ai:IJaiCompletionProvider', + 'The jupyter-ai inline completion provider API' +); diff --git a/packages/jupyter-ai/src/widgets/chat-sidebar.tsx b/packages/jupyter-ai/src/widgets/chat-sidebar.tsx index 291593b41..bb575febc 100644 --- a/packages/jupyter-ai/src/widgets/chat-sidebar.tsx +++ b/packages/jupyter-ai/src/widgets/chat-sidebar.tsx @@ -7,6 +7,7 @@ import { Chat } from '../components/chat'; import { chatIcon } from '../icons'; import { SelectionWatcher } from '../selection-watcher'; import { ChatHandler } from '../chat_handler'; +import { IJaiCompletionProvider } from '../tokens'; import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; export function buildChatSidebar( @@ -14,7 +15,9 @@ export function buildChatSidebar( chatHandler: ChatHandler, globalAwareness: Awareness | null, themeManager: IThemeManager | null, - rmRegistry: IRenderMimeRegistry + rmRegistry: IRenderMimeRegistry, + completionProvider: IJaiCompletionProvider | null, + openInlineCompleterSettings: () => void ): ReactWidget { const ChatWidget = ReactWidget.create( ); ChatWidget.id = 'jupyter-ai::chat';