Skip to content

Commit

Permalink
comments per @3coins
Browse files Browse the repository at this point in the history
  • Loading branch information
andrii-i committed Dec 19, 2023
1 parent 56af7c9 commit 439b21f
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 20 deletions.
56 changes: 38 additions & 18 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
Union,
)

import anthropic
import openai
from jsonpath_ng import parse
from langchain.chat_models import (
AzureChatOpenAI,
Expand Down Expand Up @@ -308,6 +306,8 @@ def is_api_key_exc(cls, e: Exception):
"""
Determine if the exception is an Anthropic API key error.
"""
import anthropic

if isinstance(e, anthropic.AuthenticationError):
return e.status_code == 401 and "Invalid API Key" in str(e)
return False
Expand Down Expand Up @@ -502,20 +502,7 @@ async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
return await self._call_in_executor(*args, **kwargs)


class OpenAIBaseProvider(BaseProvider):
"""
Determine if the exception is an OpenAI API key error.
"""

@classmethod
def is_api_key_exc(cls, e: Exception):
if isinstance(e, openai.error.AuthenticationError):
error_details = e.json_body.get("error", {})
return error_details.get("code") == "invalid_api_key"
return False


class OpenAIProvider(OpenAIBaseProvider, OpenAI):
class OpenAIProvider(BaseProvider, OpenAI):
id = "openai"
name = "OpenAI"
models = [
Expand All @@ -533,8 +520,19 @@ class OpenAIProvider(OpenAIBaseProvider, OpenAI):
pypi_package_deps = ["openai"]
auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")

def is_api_key_exc(cls, e: Exception):
"""
Determine if the exception is an OpenAI API key error.
"""
import openai

if isinstance(e, openai.error.AuthenticationError):
error_details = e.json_body.get("error", {})
return error_details.get("code") == "invalid_api_key"
return False


class ChatOpenAIProvider(OpenAIBaseProvider, OpenAIChat):
class ChatOpenAIProvider(BaseProvider, OpenAIChat):
id = "openai-chat"
name = "OpenAI"
models = [
Expand All @@ -560,10 +558,21 @@ def append_exchange(self, prompt: str, output: str):
self.prefix_messages.append({"role": "user", "content": prompt})
self.prefix_messages.append({"role": "assistant", "content": output})

def is_api_key_exc(cls, e: Exception):
"""
Determine if the exception is an OpenAI API key error.
"""
import openai

if isinstance(e, openai.error.AuthenticationError):
error_details = e.json_body.get("error", {})
return error_details.get("code") == "invalid_api_key"
return False


# uses the new OpenAIChat provider. temporarily living as a separate class until
# conflicts can be resolved
class ChatOpenAINewProvider(OpenAIBaseProvider, ChatOpenAI):
class ChatOpenAINewProvider(BaseProvider, ChatOpenAI):
id = "openai-chat-new"
name = "OpenAI"
models = [
Expand Down Expand Up @@ -593,6 +602,17 @@ class ChatOpenAINewProvider(OpenAIBaseProvider, ChatOpenAI):
TextField(key="openai_proxy", label="Proxy (optional)", format="text"),
]

def is_api_key_exc(cls, e: Exception):
"""
Determine if the exception is an OpenAI API key error.
"""
import openai

if isinstance(e, openai.error.AuthenticationError):
error_details = e.json_body.get("error", {})
return error_details.get("code") == "invalid_api_key"
return False


class AzureChatOpenAIProvider(BaseProvider, AzureChatOpenAI):
id = "azure-chat-openai"
Expand Down
4 changes: 2 additions & 2 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,9 @@ def reply(
agent_msg = AgentChatMessage(
id=uuid4().hex,
time=time.time(),
body=response,
body=response, # append info about error to the string
reply_to=human_msg.id if human_msg else "",
show_edit_settings=show_edit_settings,
#error_type=APIAuthenticationError
)

for handler in self._root_chat_handlers.values():
Expand Down

0 comments on commit 439b21f

Please sign in to comment.