From be2131396cbb14b9bf37c707c53d2ef925e1538b Mon Sep 17 00:00:00 2001 From: Jason Weill Date: Thu, 28 Dec 2023 15:13:47 -0800 Subject: [PATCH] WIP: Upgrades openai, merges in "new" provider --- .../jupyter_ai_magics/__init__.py | 1 - .../jupyter_ai_magics/magics.py | 10 +---- .../jupyter_ai_magics/providers.py | 44 ++----------------- packages/jupyter-ai-magics/pyproject.toml | 1 - 4 files changed, 6 insertions(+), 50 deletions(-) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py b/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py index 1bfdaeb24..a00d4877c 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py @@ -22,7 +22,6 @@ BedrockChatProvider, BedrockProvider, ChatAnthropicProvider, - ChatOpenAINewProvider, ChatOpenAIProvider, CohereProvider, GPT4AllProvider, diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py index f34e28a22..a5f0fa6f6 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py @@ -124,13 +124,6 @@ def __init__(self, shell): super().__init__(shell) self.transcript_openai = [] - # suppress warning when using old OpenAIChat provider - warnings.filterwarnings( - "ignore", - message="You are trying to use a chat model. This way of initializing it is " - "no longer supported. Instead, please use: " - "`from langchain.chat_models import ChatOpenAI`", - ) # suppress warning when using old Anthropic provider warnings.filterwarnings( "ignore", @@ -527,7 +520,8 @@ def run_ai_cell(self, args: CellArgs, prompt: str): # configure and instantiate provider provider_params = {"model_id": local_model_id} if provider_id == "openai-chat": - provider_params["prefix_messages"] = self.transcript_openai + # provider_params["messages"] = self.transcript_openai + pass # for SageMaker, validate that required params are specified if provider_id == "sagemaker-endpoint": if ( diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py index 4df236c7e..0007c711a 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py @@ -22,7 +22,6 @@ AzureChatOpenAI, BedrockChat, ChatAnthropic, - ChatOpenAI, QianfanChatEndpoint, ) from langchain.chat_models.base import BaseChatModel @@ -45,6 +44,7 @@ from langchain.schema import LLMResult from langchain.utils import get_from_dict_or_env +from langchain_community.chat_models import ChatOpenAI class EnvAuthStrategy(BaseModel): """Require one auth token via an environment variable.""" @@ -540,7 +540,7 @@ def is_api_key_exc(cls, e: Exception): return False -class ChatOpenAIProvider(BaseProvider, OpenAIChat): +class ChatOpenAIProvider(BaseProvider, ChatOpenAI): id = "openai-chat" name = "OpenAI" models = [ @@ -564,44 +564,8 @@ class ChatOpenAIProvider(BaseProvider, OpenAIChat): def append_exchange(self, prompt: str, output: str): """Appends a conversational exchange between user and an OpenAI Chat model to a transcript that will be included in future exchanges.""" - self.prefix_messages.append({"role": "user", "content": prompt}) - self.prefix_messages.append({"role": "assistant", "content": output}) - - @classmethod - def is_api_key_exc(cls, e: Exception): - """ - Determine if the exception is an OpenAI API key error. - """ - import openai - - if isinstance(e, openai.AuthenticationError): - error_details = e.json_body.get("error", {}) - return error_details.get("code") == "invalid_api_key" - return False - - -# uses the new OpenAIChat provider. temporarily living as a separate class until -# conflicts can be resolved -class ChatOpenAINewProvider(BaseProvider, ChatOpenAI): - id = "openai-chat-new" - name = "OpenAI" - models = [ - "gpt-3.5-turbo", - "gpt-3.5-turbo-16k", - "gpt-3.5-turbo-0301", - "gpt-3.5-turbo-0613", - "gpt-3.5-turbo-16k-0613", - "gpt-4", - "gpt-4-0314", - "gpt-4-0613", - "gpt-4-32k", - "gpt-4-32k-0314", - "gpt-4-32k-0613", - "gpt-4-1106-preview", - ] - model_id_key = "model_name" - pypi_package_deps = ["openai"] - auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY") + self.messages.append({"role": "user", "content": prompt}) + self.messages.append({"role": "assistant", "content": output}) fields = [ TextField( diff --git a/packages/jupyter-ai-magics/pyproject.toml b/packages/jupyter-ai-magics/pyproject.toml index 05e11a10d..90e988947 100644 --- a/packages/jupyter-ai-magics/pyproject.toml +++ b/packages/jupyter-ai-magics/pyproject.toml @@ -62,7 +62,6 @@ gpt4all = "jupyter_ai_magics:GPT4AllProvider" huggingface_hub = "jupyter_ai_magics:HfHubProvider" openai = "jupyter_ai_magics:OpenAIProvider" openai-chat = "jupyter_ai_magics:ChatOpenAIProvider" -openai-chat-new = "jupyter_ai_magics:ChatOpenAINewProvider" azure-chat-openai = "jupyter_ai_magics:AzureChatOpenAIProvider" sagemaker-endpoint = "jupyter_ai_magics:SmEndpointProvider" amazon-bedrock = "jupyter_ai_magics:BedrockProvider"