Skip to content

Commit

Permalink
Backport PR #513: Make Jupyternaut reply for API auth errors user-fri…
Browse files Browse the repository at this point in the history
…endly (#532)

Co-authored-by: Andrii Ieroshenko <[email protected]>
  • Loading branch information
meeseeksmachine and andrii-i authored Dec 20, 2023
1 parent e76e352 commit 62b1a2a
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 0 deletions.
63 changes: 63 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,13 @@ async def _generate_in_executor(
_call_with_args = functools.partial(self._generate, *args, **kwargs)
return await loop.run_in_executor(executor, _call_with_args)

@classmethod
def is_api_key_exc(cls, _: Exception):
"""
Determine if the exception is an API key error. Can be implemented by subclasses.
"""
return False

def update_prompt_template(self, format: str, template: str):
"""
Changes the class-level prompt template for a given format.
Expand Down Expand Up @@ -263,6 +270,15 @@ class AI21Provider(BaseProvider, AI21):
async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
return await self._call_in_executor(*args, **kwargs)

@classmethod
def is_api_key_exc(cls, e: Exception):
"""
Determine if the exception is an AI21 API key error.
"""
if isinstance(e, ValueError):
return "status code 401" in str(e)
return False


class AnthropicProvider(BaseProvider, Anthropic):
id = "anthropic"
Expand All @@ -285,6 +301,17 @@ class AnthropicProvider(BaseProvider, Anthropic):
def allows_concurrency(self):
return False

@classmethod
def is_api_key_exc(cls, e: Exception):
"""
Determine if the exception is an Anthropic API key error.
"""
import anthropic

if isinstance(e, anthropic.AuthenticationError):
return e.status_code == 401 and "Invalid API Key" in str(e)
return False


class ChatAnthropicProvider(BaseProvider, ChatAnthropic):
id = "anthropic-chat"
Expand Down Expand Up @@ -498,6 +525,18 @@ class OpenAIProvider(BaseProvider, OpenAI):
pypi_package_deps = ["openai"]
auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")

@classmethod
def is_api_key_exc(cls, e: Exception):
"""
Determine if the exception is an OpenAI API key error.
"""
import openai

if isinstance(e, openai.error.AuthenticationError):
error_details = e.json_body.get("error", {})
return error_details.get("code") == "invalid_api_key"
return False


class ChatOpenAIProvider(BaseProvider, OpenAIChat):
id = "openai-chat"
Expand Down Expand Up @@ -525,6 +564,18 @@ def append_exchange(self, prompt: str, output: str):
self.prefix_messages.append({"role": "user", "content": prompt})
self.prefix_messages.append({"role": "assistant", "content": output})

@classmethod
def is_api_key_exc(cls, e: Exception):
"""
Determine if the exception is an OpenAI API key error.
"""
import openai

if isinstance(e, openai.error.AuthenticationError):
error_details = e.json_body.get("error", {})
return error_details.get("code") == "invalid_api_key"
return False


# uses the new OpenAIChat provider. temporarily living as a separate class until
# conflicts can be resolved
Expand Down Expand Up @@ -558,6 +609,18 @@ class ChatOpenAINewProvider(BaseProvider, ChatOpenAI):
TextField(key="openai_proxy", label="Proxy (optional)", format="text"),
]

@classmethod
def is_api_key_exc(cls, e: Exception):
"""
Determine if the exception is an OpenAI API key error.
"""
import openai

if isinstance(e, openai.error.AuthenticationError):
error_details = e.json_body.get("error", {})
return error_details.get("code") == "invalid_api_key"
return False


class AzureChatOpenAIProvider(BaseProvider, AzureChatOpenAI):
id = "azure-chat-openai"
Expand Down
7 changes: 7 additions & 0 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,13 @@ async def _default_handle_exc(self, e: Exception, message: HumanChatMessage):
The default definition of `handle_exc()`. This is the default used when
the `handle_exc()` excepts.
"""
self.log.error(e)
lm_provider = self.config_manager.lm_provider
if lm_provider and lm_provider.is_api_key_exc(e):
provider_name = getattr(self.config_manager.lm_provider, "name", "")
response = f"Oops! There's a problem connecting to {provider_name}. Please update your {provider_name} API key in the chat settings."
self.reply(response, message)
return
formatted_e = traceback.format_exc()
response = (
f"Sorry, an error occurred. Details below:\n\n```\n{formatted_e}\n```"
Expand Down

0 comments on commit 62b1a2a

Please sign in to comment.