Skip to content

Commit

Permalink
WIP: Upgrades openai, merges in "new" provider
Browse files Browse the repository at this point in the history
  • Loading branch information
JasonWeill committed Dec 28, 2023
1 parent f564b1c commit be21313
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 50 deletions.
1 change: 0 additions & 1 deletion packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
BedrockChatProvider,
BedrockProvider,
ChatAnthropicProvider,
ChatOpenAINewProvider,
ChatOpenAIProvider,
CohereProvider,
GPT4AllProvider,
Expand Down
10 changes: 2 additions & 8 deletions packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,6 @@ def __init__(self, shell):
super().__init__(shell)
self.transcript_openai = []

# suppress warning when using old OpenAIChat provider
warnings.filterwarnings(
"ignore",
message="You are trying to use a chat model. This way of initializing it is "
"no longer supported. Instead, please use: "
"`from langchain.chat_models import ChatOpenAI`",
)
# suppress warning when using old Anthropic provider
warnings.filterwarnings(
"ignore",
Expand Down Expand Up @@ -527,7 +520,8 @@ def run_ai_cell(self, args: CellArgs, prompt: str):
# configure and instantiate provider
provider_params = {"model_id": local_model_id}
if provider_id == "openai-chat":
provider_params["prefix_messages"] = self.transcript_openai
# provider_params["messages"] = self.transcript_openai
pass
# for SageMaker, validate that required params are specified
if provider_id == "sagemaker-endpoint":
if (
Expand Down
44 changes: 4 additions & 40 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
AzureChatOpenAI,
BedrockChat,
ChatAnthropic,
ChatOpenAI,
QianfanChatEndpoint,
)
from langchain.chat_models.base import BaseChatModel
Expand All @@ -45,6 +44,7 @@
from langchain.schema import LLMResult
from langchain.utils import get_from_dict_or_env

from langchain_community.chat_models import ChatOpenAI

class EnvAuthStrategy(BaseModel):
"""Require one auth token via an environment variable."""
Expand Down Expand Up @@ -540,7 +540,7 @@ def is_api_key_exc(cls, e: Exception):
return False


class ChatOpenAIProvider(BaseProvider, OpenAIChat):
class ChatOpenAIProvider(BaseProvider, ChatOpenAI):
id = "openai-chat"
name = "OpenAI"
models = [
Expand All @@ -564,44 +564,8 @@ class ChatOpenAIProvider(BaseProvider, OpenAIChat):
def append_exchange(self, prompt: str, output: str):
"""Appends a conversational exchange between user and an OpenAI Chat
model to a transcript that will be included in future exchanges."""
self.prefix_messages.append({"role": "user", "content": prompt})
self.prefix_messages.append({"role": "assistant", "content": output})

@classmethod
def is_api_key_exc(cls, e: Exception):
"""
Determine if the exception is an OpenAI API key error.
"""
import openai

if isinstance(e, openai.AuthenticationError):
error_details = e.json_body.get("error", {})
return error_details.get("code") == "invalid_api_key"
return False


# uses the new OpenAIChat provider. temporarily living as a separate class until
# conflicts can be resolved
class ChatOpenAINewProvider(BaseProvider, ChatOpenAI):
id = "openai-chat-new"
name = "OpenAI"
models = [
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-0301",
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-4",
"gpt-4-0314",
"gpt-4-0613",
"gpt-4-32k",
"gpt-4-32k-0314",
"gpt-4-32k-0613",
"gpt-4-1106-preview",
]
model_id_key = "model_name"
pypi_package_deps = ["openai"]
auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")
self.messages.append({"role": "user", "content": prompt})
self.messages.append({"role": "assistant", "content": output})

fields = [
TextField(
Expand Down
1 change: 0 additions & 1 deletion packages/jupyter-ai-magics/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ gpt4all = "jupyter_ai_magics:GPT4AllProvider"
huggingface_hub = "jupyter_ai_magics:HfHubProvider"
openai = "jupyter_ai_magics:OpenAIProvider"
openai-chat = "jupyter_ai_magics:ChatOpenAIProvider"
openai-chat-new = "jupyter_ai_magics:ChatOpenAINewProvider"
azure-chat-openai = "jupyter_ai_magics:AzureChatOpenAIProvider"
sagemaker-endpoint = "jupyter_ai_magics:SmEndpointProvider"
amazon-bedrock = "jupyter_ai_magics:BedrockProvider"
Expand Down

0 comments on commit be21313

Please sign in to comment.