From 742f20bc9c6a10f84ccb9479f73a2a560b4140ae Mon Sep 17 00:00:00 2001 From: Andrii Ieroshenko Date: Fri, 8 Dec 2023 10:43:40 -0800 Subject: [PATCH] initialize config manager with errors --- .../jupyter-ai/jupyter_ai/config_manager.py | 104 +++++++++++------- packages/jupyter-ai/jupyter_ai/extension.py | 38 +++++-- packages/jupyter-ai/src/components/chat.tsx | 34 ++++-- packages/jupyter-ai/src/index.ts | 24 +++- .../jupyter-ai/src/widgets/chat-sidebar.tsx | 2 +- 5 files changed, 135 insertions(+), 67 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index c937ff614..325be3dd5 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -164,19 +164,29 @@ def _init_validator(self) -> Validator: self.validator = Validator(schema) def _init_config(self): - try: - if os.path.exists(self.config_path): - self._process_existing_config() - else: - self._create_default_config() - except ValidationError as e: - self._handle_validation_error(e) + # try: + if os.path.exists(self.config_path): + self._process_existing_config() + else: + self._create_default_config() + # except ValidationError as e: + # self._handle_validation_error(e) + # self._config = GlobalConfig( + # send_with_shift_enter=False, fields={}, api_keys={} + # ) def _process_existing_config(self): with open(self.config_path, encoding="utf-8") as f: - config = GlobalConfig(**json.loads(f.read())) - self._validate_lm_em_id(config) + raw_config = json.loads(f.read()) + + validated_raw_config = self._validate_lm_em_id(raw_config) + + try: + config = GlobalConfig(**validated_raw_config) self._write_config(config) + except ValidationError as e: + corrected_config = self._handle_validation_error(e, validated_raw_config) + self._write_config(corrected_config) def _create_default_config(self): properties = self.validator.schema.get("properties", {}) @@ -187,16 +197,16 @@ def _create_default_config(self): default_config = GlobalConfig(**field_dict) self._write_config(default_config) - def _validate_lm_em_id(self, config): - lm_id = config.model_provider_id - em_id = config.embeddings_provider_id + def _validate_lm_em_id(self, raw_config): + lm_id = raw_config.get("model_provider_id") + em_id = raw_config.get("embeddings_provider_id") # if the currently selected language or embedding model are # forbidden, set them to `None` and log a warning. if lm_id is not None and not self._validate_model(lm_id, raise_exc=False): warning_message = f"Language model {lm_id} is forbidden by current allow/blocklists. Setting to None." self.log.warning(warning_message) - config.model_provider_id = None + raw_config["model_provider_id"] = None self._config_errors.append( ConfigErrorModel( error_type=ConfigErrorType.WARNING, message=warning_message @@ -206,7 +216,7 @@ def _validate_lm_em_id(self, config): if em_id is not None and not self._validate_model(em_id, raise_exc=False): warning_message = f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None." self.log.warning(warning_message) - config.embeddings_provider_id = None + raw_config["embeddings_provider_id"] = None self._config_errors.append( ConfigErrorModel( error_type=ConfigErrorType.WARNING, message=warning_message @@ -220,7 +230,7 @@ def _validate_lm_em_id(self, config): f"No language model is associated with '{lm_id}'. Setting to None." ) self.log.warning(warning_message) - config.model_provider_id = None + raw_config["model_provider_id"] = None self._config_errors.append( ConfigErrorModel( error_type=ConfigErrorType.WARNING, message=warning_message @@ -232,28 +242,43 @@ def _validate_lm_em_id(self, config): f"No embedding model is associated with '{em_id}'. Setting to None." ) self.log.warning(warning_message) - config.embeddings_provider_id = None + raw_config["embeddings_provider_id"] = None self._config_errors.append( ConfigErrorModel( error_type=ConfigErrorType.WARNING, message=warning_message ) ) - # re-write to the file to validate the config and apply any - # updates to the config file immediately - self._write_config(config) - - def _handle_validation_error(self, e: ValidationError): - formatted_error = _format_validation_errors(e) - error_message = "Configuration validation failed" - self._config_errors.append( - ConfigErrorModel( - error_type=ConfigErrorType.CRITICAL, - message=error_message, - details=formatted_error, - ) - ) - self.log.error(f"{error_message}: {formatted_error}") + return raw_config + + def _handle_validation_error(self, e: ValidationError, raw_config): + # Extract default values from schema + properties = self.validator.schema.get("properties", {}) + field_list = GlobalConfig.__fields__.keys() + default_values = { + field: properties.get(field).get("default") for field in field_list + } + + # Apply default values to erroneous fields + for error in e.errors(): + field = error["loc"][0] + if field in default_values: + raw_config[field] = default_values[field] + warning_message = f"Error in '{field}': {error['msg']}. Resetting to default value ('{default_values[field]}')." + self.log.warning(warning_message) + self._config_errors.append( + ConfigErrorModel( + error_type=ConfigErrorType.WARNING, message=warning_message + ) + ) + + # Create a config with default values for erroneous fields + config = GlobalConfig(**raw_config) + self.log.warning("\n\n\n Config \n\n\n") + + self.log.warning(config) + self._validate_config(config) + return config def _read_config(self) -> GlobalConfig: """Returns the user's current configuration as a GlobalConfig object. @@ -264,12 +289,15 @@ def _read_config(self) -> GlobalConfig: if last_write <= self._last_read: return self._config - with open(self.config_path, encoding="utf-8") as f: - self._last_read = time.time_ns() - raw_config = json.loads(f.read()) - config = GlobalConfig(**raw_config) - self._validate_config(config) - return config + with open(self.config_path, encoding="utf-8") as f: + self._last_read = time.time_ns() + raw_config = json.loads(f.read()) + try: + config = GlobalConfig(**raw_config) + except ValidationError as e: + config = self._handle_validation_error(e, raw_config) + self._validate_config(config) + return config def _validate_config(self, config: GlobalConfig): """Method used to validate the configuration. This is called after every @@ -414,7 +442,7 @@ def update_config(self, config_update: UpdateConfigRequest): def get_config(self): config = self._read_config() config_dict = config.dict(exclude_unset=True) - api_key_names = list(config_dict.pop("api_keys").keys()) + api_key_names = list(config_dict.pop("api_keys", {}).keys()) return DescribeConfigResponse( **config_dict, api_keys=api_key_names, diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 3dc8bbb40..50a91d1e1 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -28,14 +28,7 @@ class AiExtension(ExtensionApp): name = "jupyter_ai" - handlers = [ - (r"api/ai/api_keys/(?P\w+)", ApiKeysHandler), - (r"api/ai/config/?", GlobalConfigHandler), - (r"api/ai/chats/?", RootChatHandler), - (r"api/ai/chats/history?", ChatHistoryHandler), - (r"api/ai/providers?", ModelProviderHandler), - (r"api/ai/providers/embeddings?", EmbeddingsModelProviderHandler), - ] + handlers = [(r"api/ai/config/?", GlobalConfigHandler)] allowed_providers = List( Unicode(), @@ -139,15 +132,24 @@ def initialize_settings(self): else: # Log the error and proceed with limited functionality self.log.error(f"Configuration errors detected: {config_errors}") - # TODO: self._initialize_limited_functionality() + self._initialize_limited_functionality(config_errors) - self.log.info("Registered providers.") self.log.info(f"Registered {self.name} server extension") latency_ms = round((time.time() - start) * 1000) self.log.info(f"Initialized Jupyter AI server extension in {latency_ms} ms.") def _initialize_full_functionality(self): + self.handlers.extend( + [ + (r"api/ai/api_keys/(?P\w+)", ApiKeysHandler), + (r"api/ai/chats/?", RootChatHandler), + (r"api/ai/chats/history?", ChatHistoryHandler), + (r"api/ai/providers?", ModelProviderHandler), + (r"api/ai/providers/embeddings?", EmbeddingsModelProviderHandler), + ] + ) + # Store chat clients in a dictionary self.settings["chat_clients"] = {} self.settings["jai_root_chat_handlers"] = {} @@ -204,5 +206,21 @@ def _initialize_full_functionality(self): "/help": help_chat_handler, } + self.log.info("Registered providers.") + + def _initialize_limited_functionality(self, config_errors): + """ + Initialize the extension with limited functionality due to configuration errors. + """ + self.log.warning( + "Initializing Jupyter AI extension with limited functionality due to configuration errors." + ) + + # Capture configuration error details + config_errors = self.settings["jai_config_manager"].get_config_errors() + self.settings["config_errors"] = config_errors + + self.settings["jai_chat_handlers"] = [] + async def _get_dask_client(self): return DaskClient(processes=False, asynchronous=True) diff --git a/packages/jupyter-ai/src/components/chat.tsx b/packages/jupyter-ai/src/components/chat.tsx index 83ba69c42..e87e0eb13 100644 --- a/packages/jupyter-ai/src/components/chat.tsx +++ b/packages/jupyter-ai/src/components/chat.tsx @@ -20,7 +20,7 @@ import { CollaboratorsContextProvider } from '../contexts/collaborators-context' import { ScrollContainer } from './scroll-container'; type ChatBodyProps = { - chatHandler: ChatHandler; + chatHandler: ChatHandler | null; setChatView: (view: ChatView) => void; }; @@ -43,12 +43,22 @@ function ChatBody({ useEffect(() => { async function fetchHistory() { try { - const [history, config] = await Promise.all([ - chatHandler.getHistory(), - AiService.getConfig() - ]); + const config = await AiService.getConfig(); setSendWithShiftEnter(config.send_with_shift_enter ?? false); - setMessages(history.messages); + + // Check if there are critical errors + const hasCriticalErrors = config.config_errors?.some( + error => error.error_type === AiService.ConfigErrorType.CRITICAL + ); + console.log('\n\n\n *** \n\n\n'); + console.log(hasCriticalErrors); + if (!hasCriticalErrors && chatHandler) { + const history = await chatHandler.getHistory(); + setMessages(history.messages); + } else { + setMessages([]); + } + if (!config.model_provider_id) { setShowWelcomeMessage(true); } @@ -78,9 +88,9 @@ function ChatBody({ setMessages(messageGroups => [...messageGroups, message]); } - chatHandler.addListener(handleChatEvents); + chatHandler?.addListener(handleChatEvents); return function cleanup() { - chatHandler.removeListener(handleChatEvents); + chatHandler?.removeListener(handleChatEvents); }; }, [chatHandler]); @@ -96,18 +106,18 @@ function ChatBody({ : ''); // send message to backend - const messageId = await chatHandler.sendMessage({ prompt }); + const messageId = await chatHandler?.sendMessage({ prompt }); // await reply from agent // no need to append to messageGroups state variable, since that's already // handled in the effect hooks. - const reply = await chatHandler.replyFor(messageId); + const reply = await chatHandler?.replyFor(messageId ?? ''); if (replaceSelection && selection) { const { cellId, ...selectionProps } = selection; replaceSelectionFn({ ...selectionProps, ...(cellId && { cellId }), - text: reply.body + text: reply?.body ?? '' }); } }; @@ -187,7 +197,7 @@ function ChatBody({ export type ChatProps = { selectionWatcher: SelectionWatcher; - chatHandler: ChatHandler; + chatHandler: ChatHandler | null; globalAwareness: Awareness | null; chatView?: ChatView; }; diff --git a/packages/jupyter-ai/src/index.ts b/packages/jupyter-ai/src/index.ts index e48e2b211..c8e57bd93 100644 --- a/packages/jupyter-ai/src/index.ts +++ b/packages/jupyter-ai/src/index.ts @@ -12,6 +12,7 @@ import { buildChatSidebar } from './widgets/chat-sidebar'; import { SelectionWatcher } from './selection-watcher'; import { ChatHandler } from './chat_handler'; import { buildErrorWidget } from './widgets/chat-error'; +import { AiService } from './handler'; export type DocumentTracker = IWidgetTracker; @@ -32,14 +33,25 @@ const plugin: JupyterFrontEndPlugin = { */ const selectionWatcher = new SelectionWatcher(app.shell); - /** - * Initialize chat handler, open WS connection - */ - const chatHandler = new ChatHandler(); - let chatWidget: ReactWidget | null = null; + let chatHandler: ChatHandler | null = null; + try { - await chatHandler.initialize(); + // Fetch configuration to check for critical errors + const config = await AiService.getConfig(); + console.log('\n\n\n *** \n\n\n'); + console.log(config.config_errors); + const hasCriticalErrors = config.config_errors?.some( + error => error.error_type === AiService.ConfigErrorType.CRITICAL + ); + + if (!hasCriticalErrors) { + /** + * Initialize chat handler, open WS connection + */ + chatHandler = new ChatHandler(); + await chatHandler.initialize(); + } chatWidget = buildChatSidebar( selectionWatcher, chatHandler, diff --git a/packages/jupyter-ai/src/widgets/chat-sidebar.tsx b/packages/jupyter-ai/src/widgets/chat-sidebar.tsx index 8bc7df12c..abcd81ba4 100644 --- a/packages/jupyter-ai/src/widgets/chat-sidebar.tsx +++ b/packages/jupyter-ai/src/widgets/chat-sidebar.tsx @@ -9,7 +9,7 @@ import { ChatHandler } from '../chat_handler'; export function buildChatSidebar( selectionWatcher: SelectionWatcher, - chatHandler: ChatHandler, + chatHandler: ChatHandler | null, globalAwareness: Awareness | null ): ReactWidget { const ChatWidget = ReactWidget.create(