diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py index b6ee525a6..170e17478 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py @@ -17,7 +17,7 @@ from dask.distributed import Client as DaskClient from jupyter_ai.config_manager import ConfigManager, Logger from jupyter_ai.models import AgentChatMessage, ChatMessage, HumanChatMessage -from jupyter_ai.utils import OpenAIErrorUtil +from jupyter_ai.utils import AI21ErrorUtility, OpenAIErrorUtil from jupyter_ai_magics.providers import BaseProvider from openai.error import AuthenticationError as OpenAIAuthenticationError @@ -125,7 +125,7 @@ def is_api_key_exc(self, e: Exception): """ Checks if the exception is an API key exception. """ - return OpenAIErrorUtil.is_api_key_exc(e) + return OpenAIErrorUtil.is_api_key_exc(e) or AI21ErrorUtility.is_api_key_exc(e) def handle_api_key_exc(self, e: Exception, message: HumanChatMessage): provider_name = "" diff --git a/packages/jupyter-ai/jupyter_ai/utils.py b/packages/jupyter-ai/jupyter_ai/utils.py index d22d64899..dd0f770ad 100644 --- a/packages/jupyter-ai/jupyter_ai/utils.py +++ b/packages/jupyter-ai/jupyter_ai/utils.py @@ -14,3 +14,12 @@ def is_api_key_exc(e: Exception): error_details = e.json_body.get("error", {}) return error_details.get("code") == "invalid_api_key" return False + + +class AI21ErrorUtility: + @staticmethod + def is_api_key_exc(e: Exception): + if isinstance(e, ValueError): + # Check if the exception message contains "status code 401" + return "status code 401" in str(e) + return False