diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py index 9d843fe4d..8d353638d 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py @@ -216,6 +216,13 @@ async def _generate_in_executor( _call_with_args = functools.partial(self._generate, *args, **kwargs) return await loop.run_in_executor(executor, _call_with_args) + @classmethod + def is_api_key_exc(cls, _: Exception): + """ + Determine if the exception is an API key error. Can be implemented by subclasses. + """ + return False + def update_prompt_template(self, format: str, template: str): """ Changes the class-level prompt template for a given format. @@ -263,6 +270,15 @@ class AI21Provider(BaseProvider, AI21): async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]: return await self._call_in_executor(*args, **kwargs) + @classmethod + def is_api_key_exc(cls, e: Exception): + """ + Determine if the exception is an AI21 API key error. + """ + if isinstance(e, ValueError): + return "status code 401" in str(e) + return False + class AnthropicProvider(BaseProvider, Anthropic): id = "anthropic" @@ -285,6 +301,17 @@ class AnthropicProvider(BaseProvider, Anthropic): def allows_concurrency(self): return False + @classmethod + def is_api_key_exc(cls, e: Exception): + """ + Determine if the exception is an Anthropic API key error. + """ + import anthropic + + if isinstance(e, anthropic.AuthenticationError): + return e.status_code == 401 and "Invalid API Key" in str(e) + return False + class ChatAnthropicProvider(BaseProvider, ChatAnthropic): id = "anthropic-chat" @@ -498,6 +525,18 @@ class OpenAIProvider(BaseProvider, OpenAI): pypi_package_deps = ["openai"] auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY") + @classmethod + def is_api_key_exc(cls, e: Exception): + """ + Determine if the exception is an OpenAI API key error. + """ + import openai + + if isinstance(e, openai.error.AuthenticationError): + error_details = e.json_body.get("error", {}) + return error_details.get("code") == "invalid_api_key" + return False + class ChatOpenAIProvider(BaseProvider, OpenAIChat): id = "openai-chat" @@ -525,6 +564,18 @@ def append_exchange(self, prompt: str, output: str): self.prefix_messages.append({"role": "user", "content": prompt}) self.prefix_messages.append({"role": "assistant", "content": output}) + @classmethod + def is_api_key_exc(cls, e: Exception): + """ + Determine if the exception is an OpenAI API key error. + """ + import openai + + if isinstance(e, openai.error.AuthenticationError): + error_details = e.json_body.get("error", {}) + return error_details.get("code") == "invalid_api_key" + return False + # uses the new OpenAIChat provider. temporarily living as a separate class until # conflicts can be resolved @@ -558,6 +609,18 @@ class ChatOpenAINewProvider(BaseProvider, ChatOpenAI): TextField(key="openai_proxy", label="Proxy (optional)", format="text"), ] + @classmethod + def is_api_key_exc(cls, e: Exception): + """ + Determine if the exception is an OpenAI API key error. + """ + import openai + + if isinstance(e, openai.error.AuthenticationError): + error_details = e.json_body.get("error", {}) + return error_details.get("code") == "invalid_api_key" + return False + class AzureChatOpenAIProvider(BaseProvider, AzureChatOpenAI): id = "azure-chat-openai" diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py index fb8a5e3de..9f179de51 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py @@ -147,6 +147,13 @@ async def _default_handle_exc(self, e: Exception, message: HumanChatMessage): The default definition of `handle_exc()`. This is the default used when the `handle_exc()` excepts. """ + self.log.error(e) + lm_provider = self.config_manager.lm_provider + if lm_provider and lm_provider.is_api_key_exc(e): + provider_name = getattr(self.config_manager.lm_provider, "name", "") + response = f"Oops! There's a problem connecting to {provider_name}. Please update your {provider_name} API key in the chat settings." + self.reply(response, message) + return formatted_e = traceback.format_exc() response = ( f"Sorry, an error occurred. Details below:\n\n```\n{formatted_e}\n```"