From 4a6336103e6f1799dc0eced7d98a73501d0d6a3b Mon Sep 17 00:00:00 2001
From: Steven <tartakovsky.developer@gmail.com>
Date: Wed, 27 Mar 2024 08:54:19 -0700
Subject: [PATCH] Use new `langchain-openai` partner package (#653)

* Import ChatOpenAI and AzureChatOpenAI from the new `langchain-openai` partner package (see the sketch below)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update pyproject.toml to add `langchain-openai` as an optional dependency

* Remove the direct dependency on the `openai` and `tiktoken` packages

* Move the OpenAI providers to a separate module so that `langchain_openai` stays optional

* Update the OpenAI provider dependencies in the docs

* Apply pre-commit auto-formatting

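In short, the import change looks like this (a minimal sketch; the full set of moves is in the diff below):

    # old location, via langchain_community
    from langchain_community.chat_models import AzureChatOpenAI, ChatOpenAI

    # new location, via the langchain-openai partner package
    from langchain_openai import AzureChatOpenAI, ChatOpenAI, OpenAI, OpenAIEmbeddings
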
---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: David L. Qiu <david@qiu.dev>
---
 docs/source/users/index.md                    |   4 +-
 .../jupyter_ai_magics/__init__.py             |   4 -
 .../jupyter_ai_magics/embedding_providers.py  |  13 ---
 .../partner_providers/openai.py               | 108 ++++++++++++++++++
 .../jupyter_ai_magics/providers.py            |  93 ---------------
 packages/jupyter-ai-magics/pyproject.toml     |  10 +-
 packages/jupyter-ai/pyproject.toml            |   1 -
 7 files changed, 115 insertions(+), 118 deletions(-)
 create mode 100644 packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openai.py

diff --git a/docs/source/users/index.md b/docs/source/users/index.md
index ad15da84f..9cc38ba93 100644
--- a/docs/source/users/index.md
+++ b/docs/source/users/index.md
@@ -144,8 +144,8 @@ Jupyter AI supports the following model providers:
 | GPT4All             | `gpt4all`            | N/A                        | `gpt4all`                       |
 | Hugging Face Hub    | `huggingface_hub`    | `HUGGINGFACEHUB_API_TOKEN` | `huggingface_hub`, `ipywidgets`, `pillow` |
 | NVIDIA              | `nvidia-chat`        | `NVIDIA_API_KEY`           | `langchain_nvidia_ai_endpoints` |
-| OpenAI              | `openai`             | `OPENAI_API_KEY`           | `openai`                        |
-| OpenAI (chat)       | `openai-chat`        | `OPENAI_API_KEY`           | `openai`                        |
+| OpenAI              | `openai`             | `OPENAI_API_KEY`           | `langchain-openai`              |
+| OpenAI (chat)       | `openai-chat`        | `OPENAI_API_KEY`           | `langchain-openai`              |
 | SageMaker           | `sagemaker-endpoint` | N/A                        | `boto3`                         |
 
 The environment variable names shown above are also the names of the settings keys used when setting up the chat interface.
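For example, a minimal sketch of using the OpenAI chat provider after this change (the package and variable names come from the table above; the snippet itself is illustrative):

    # requires the partner package: pip install langchain-openai
    import os

    # OPENAI_API_KEY is also the settings key name in the chat interface
    os.environ.setdefault("OPENAI_API_KEY", "<your-api-key>")

    from langchain_openai import ChatOpenAI
    llm = ChatOpenAI(model="gpt-3.5-turbo")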
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py b/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py
index a8f4225de..b6db9f957 100644
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py
@@ -7,7 +7,6 @@
     CohereEmbeddingsProvider,
     GPT4AllEmbeddingsProvider,
     HfHubEmbeddingsProvider,
-    OpenAIEmbeddingsProvider,
     QianfanEmbeddingsEndpointProvider,
 )
 from .exception import store_exception
@@ -17,16 +16,13 @@
 from .providers import (
     AI21Provider,
     AnthropicProvider,
-    AzureChatOpenAIProvider,
     BaseProvider,
     BedrockChatProvider,
     BedrockProvider,
     ChatAnthropicProvider,
-    ChatOpenAIProvider,
     CohereProvider,
     GPT4AllProvider,
     HfHubProvider,
-    OpenAIProvider,
     QianfanProvider,
     SmEndpointProvider,
     TogetherAIProvider,
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py
index 55f74bc1c..3b187e1c2 100644
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py
@@ -68,19 +68,6 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs, **model_kwargs)
 
 
-class OpenAIEmbeddingsProvider(BaseEmbeddingsProvider, OpenAIEmbeddings):
-    id = "openai"
-    name = "OpenAI"
-    models = [
-        "text-embedding-ada-002",
-        "text-embedding-3-small",
-        "text-embedding-3-large",
-    ]
-    model_id_key = "model"
-    pypi_package_deps = ["openai"]
-    auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")
-
-
 class CohereEmbeddingsProvider(BaseEmbeddingsProvider, CohereEmbeddings):
     id = "cohere"
     name = "Cohere"
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openai.py b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openai.py
new file mode 100644
index 000000000..382a480e1
--- /dev/null
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openai.py
@@ -0,0 +1,108 @@
+from langchain_openai import AzureChatOpenAI, ChatOpenAI, OpenAI, OpenAIEmbeddings
+
+from ..embedding_providers import BaseEmbeddingsProvider
+from ..providers import BaseProvider, EnvAuthStrategy, TextField
+
+
+class OpenAIProvider(BaseProvider, OpenAI):
+    id = "openai"
+    name = "OpenAI"
+    models = ["babbage-002", "davinci-002", "gpt-3.5-turbo-instruct"]
+    model_id_key = "model_name"
+    pypi_package_deps = ["langchain_openai"]
+    auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")
+
+    @classmethod
+    def is_api_key_exc(cls, e: Exception):
+        """
+        Determine if the exception is an OpenAI API key error.
+        """
+        import openai
+
+        if isinstance(e, openai.AuthenticationError):
+            error_details = e.json_body.get("error", {})
+            return error_details.get("code") == "invalid_api_key"
+        return False
+
+
+class ChatOpenAIProvider(BaseProvider, ChatOpenAI):
+    id = "openai-chat"
+    name = "OpenAI"
+    models = [
+        "gpt-3.5-turbo",
+        "gpt-3.5-turbo-0125",
+        "gpt-3.5-turbo-0301",  # Deprecated as of 2024-06-13
+        "gpt-3.5-turbo-0613",  # Deprecated as of 2024-06-13
+        "gpt-3.5-turbo-1106",
+        "gpt-3.5-turbo-16k",
+        "gpt-3.5-turbo-16k-0613",  # Deprecated as of 2024-06-13
+        "gpt-4",
+        "gpt-4-turbo-preview",
+        "gpt-4-0613",
+        "gpt-4-32k",
+        "gpt-4-32k-0613",
+        "gpt-4-0125-preview",
+        "gpt-4-1106-preview",
+    ]
+    model_id_key = "model_name"
+    pypi_package_deps = ["langchain_openai"]
+    auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")
+
+    fields = [
+        TextField(
+            key="openai_api_base", label="Base API URL (optional)", format="text"
+        ),
+        TextField(
+            key="openai_organization", label="Organization (optional)", format="text"
+        ),
+        TextField(key="openai_proxy", label="Proxy (optional)", format="text"),
+    ]
+
+    @classmethod
+    def is_api_key_exc(cls, e: Exception):
+        """
+        Determine if the exception is an OpenAI API key error.
+        """
+        import openai
+
+        if isinstance(e, openai.AuthenticationError):
+            error_details = e.json_body.get("error", {})
+            return error_details.get("code") == "invalid_api_key"
+        return False
+
+
+class AzureChatOpenAIProvider(BaseProvider, AzureChatOpenAI):
+    id = "azure-chat-openai"
+    name = "Azure OpenAI"
+    models = ["*"]
+    model_id_key = "deployment_name"
+    model_id_label = "Deployment name"
+    pypi_package_deps = ["langchain_openai"]
+    auth_strategy = EnvAuthStrategy(name="AZURE_OPENAI_API_KEY")
+    registry = True
+
+    fields = [
+        TextField(
+            key="openai_api_base", label="Base API URL (required)", format="text"
+        ),
+        TextField(
+            key="openai_api_version", label="API version (required)", format="text"
+        ),
+        TextField(
+            key="openai_organization", label="Organization (optional)", format="text"
+        ),
+        TextField(key="openai_proxy", label="Proxy (optional)", format="text"),
+    ]
+
+
+class OpenAIEmbeddingsProvider(BaseEmbeddingsProvider, OpenAIEmbeddings):
+    id = "openai"
+    name = "OpenAI"
+    models = [
+        "text-embedding-ada-002",
+        "text-embedding-3-small",
+        "text-embedding-3-large",
+    ]
+    model_id_key = "model"
+    pypi_package_deps = ["langchain_openai"]
+    auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
index 821a7afa5..85af4579a 100644
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
@@ -22,10 +22,8 @@
 from langchain.schema import LLMResult
 from langchain.utils import get_from_dict_or_env
 from langchain_community.chat_models import (
-    AzureChatOpenAI,
     BedrockChat,
     ChatAnthropic,
-    ChatOpenAI,
     QianfanChatEndpoint,
 )
 from langchain_community.llms import (
@@ -632,97 +630,6 @@ async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
         return await self._call_in_executor(*args, **kwargs)
 
 
-class OpenAIProvider(BaseProvider, OpenAI):
-    id = "openai"
-    name = "OpenAI"
-    models = ["babbage-002", "davinci-002", "gpt-3.5-turbo-instruct"]
-    model_id_key = "model_name"
-    pypi_package_deps = ["openai"]
-    auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")
-
-    @classmethod
-    def is_api_key_exc(cls, e: Exception):
-        """
-        Determine if the exception is an OpenAI API key error.
-        """
-        import openai
-
-        if isinstance(e, openai.AuthenticationError):
-            error_details = e.json_body.get("error", {})
-            return error_details.get("code") == "invalid_api_key"
-        return False
-
-
-class ChatOpenAIProvider(BaseProvider, ChatOpenAI):
-    id = "openai-chat"
-    name = "OpenAI"
-    models = [
-        "gpt-3.5-turbo",
-        "gpt-3.5-turbo-0125",
-        "gpt-3.5-turbo-0301",  # Deprecated as of 2024-06-13
-        "gpt-3.5-turbo-0613",  # Deprecated as of 2024-06-13
-        "gpt-3.5-turbo-1106",
-        "gpt-3.5-turbo-16k",
-        "gpt-3.5-turbo-16k-0613",  # Deprecated as of 2024-06-13
-        "gpt-4",
-        "gpt-4-turbo-preview",
-        "gpt-4-0613",
-        "gpt-4-32k",
-        "gpt-4-32k-0613",
-        "gpt-4-0125-preview",
-        "gpt-4-1106-preview",
-    ]
-    model_id_key = "model_name"
-    pypi_package_deps = ["openai"]
-    auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")
-
-    fields = [
-        TextField(
-            key="openai_api_base", label="Base API URL (optional)", format="text"
-        ),
-        TextField(
-            key="openai_organization", label="Organization (optional)", format="text"
-        ),
-        TextField(key="openai_proxy", label="Proxy (optional)", format="text"),
-    ]
-
-    @classmethod
-    def is_api_key_exc(cls, e: Exception):
-        """
-        Determine if the exception is an OpenAI API key error.
-        """
-        import openai
-
-        if isinstance(e, openai.AuthenticationError):
-            error_details = e.json_body.get("error", {})
-            return error_details.get("code") == "invalid_api_key"
-        return False
-
-
-class AzureChatOpenAIProvider(BaseProvider, AzureChatOpenAI):
-    id = "azure-chat-openai"
-    name = "Azure OpenAI"
-    models = ["*"]
-    model_id_key = "deployment_name"
-    model_id_label = "Deployment name"
-    pypi_package_deps = ["openai"]
-    auth_strategy = EnvAuthStrategy(name="AZURE_OPENAI_API_KEY")
-    registry = True
-
-    fields = [
-        TextField(
-            key="openai_api_base", label="Base API URL (required)", format="text"
-        ),
-        TextField(
-            key="openai_api_version", label="API version (required)", format="text"
-        ),
-        TextField(
-            key="openai_organization", label="Organization (optional)", format="text"
-        ),
-        TextField(key="openai_proxy", label="Proxy (optional)", format="text"),
-    ]
-
-
 class JsonContentHandler(LLMContentHandler):
     content_type = "application/json"
     accepts = "application/json"
diff --git a/packages/jupyter-ai-magics/pyproject.toml b/packages/jupyter-ai-magics/pyproject.toml
index f6637d36c..00f7fe0d5 100644
--- a/packages/jupyter-ai-magics/pyproject.toml
+++ b/packages/jupyter-ai-magics/pyproject.toml
@@ -42,8 +42,8 @@ all = [
     "huggingface_hub",
     "ipywidgets",
     "langchain_nvidia_ai_endpoints",
+    "langchain-openai",
     "pillow",
-    "openai~=1.6.1",
     "boto3",
     "qianfan",
     "together",
@@ -56,9 +56,9 @@ anthropic = "jupyter_ai_magics:AnthropicProvider"
 cohere = "jupyter_ai_magics:CohereProvider"
 gpt4all = "jupyter_ai_magics:GPT4AllProvider"
 huggingface_hub = "jupyter_ai_magics:HfHubProvider"
-openai = "jupyter_ai_magics:OpenAIProvider"
-openai-chat = "jupyter_ai_magics:ChatOpenAIProvider"
-azure-chat-openai = "jupyter_ai_magics:AzureChatOpenAIProvider"
+openai = "jupyter_ai_magics.partner_providers.openai:OpenAIProvider"
+openai-chat = "jupyter_ai_magics.partner_providers.openai:ChatOpenAIProvider"
+azure-chat-openai = "jupyter_ai_magics.partner_providers.openai:AzureChatOpenAIProvider"
 sagemaker-endpoint = "jupyter_ai_magics:SmEndpointProvider"
 amazon-bedrock = "jupyter_ai_magics:BedrockProvider"
 anthropic-chat = "jupyter_ai_magics:ChatAnthropicProvider"
@@ -73,7 +73,7 @@ bedrock = "jupyter_ai_magics:BedrockEmbeddingsProvider"
 cohere = "jupyter_ai_magics:CohereEmbeddingsProvider"
 gpt4all = "jupyter_ai_magics:GPT4AllEmbeddingsProvider"
 huggingface_hub = "jupyter_ai_magics:HfHubEmbeddingsProvider"
-openai = "jupyter_ai_magics:OpenAIEmbeddingsProvider"
+openai = "jupyter_ai_magics.partner_providers.openai:OpenAIEmbeddingsProvider"
 qianfan = "jupyter_ai_magics:QianfanEmbeddingsEndpointProvider"
 
 [tool.hatch.version]
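A minimal sketch of why these entry points keep `langchain_openai` optional (the entry-point group name used here is an assumption for illustration):

    from importlib.metadata import entry_points  # Python 3.10+ selection API

    # Provider classes are resolved lazily by entry-point name, so importing
    # jupyter_ai_magics itself no longer imports langchain_openai.
    eps = entry_points(group="jupyter_ai.model_providers")  # assumed group name
    openai_ep = next(ep for ep in eps if ep.name == "openai")
    OpenAIProvider = openai_ep.load()  # `import langchain_openai` happens only here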
diff --git a/packages/jupyter-ai/pyproject.toml b/packages/jupyter-ai/pyproject.toml
index 073abb8de..e650188cb 100644
--- a/packages/jupyter-ai/pyproject.toml
+++ b/packages/jupyter-ai/pyproject.toml
@@ -26,7 +26,6 @@ dependencies = [
     "jupyterlab~=4.0",
     "aiosqlite>=0.18",
     "importlib_metadata>=5.2.0",
-    "tiktoken",                  # required for OpenAIEmbeddings
     "jupyter_ai_magics",
     "dask[distributed]",
     "faiss-cpu",                 # Not distributed by official repo