From 39ccac8ea5927f83e94cb1f647458c15f9acc82f Mon Sep 17 00:00:00 2001 From: Srikant Garnaik Date: Wed, 7 Aug 2024 19:21:47 +0530 Subject: [PATCH 1/2] Add support for Azure Open AI Embeddings to jupyter AI --- .../partner_providers/openai.py | 19 ++++++++++++++++++- packages/jupyter-ai-magics/pyproject.toml | 1 + 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openai.py b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openai.py index afba7c2b6..e804968f8 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openai.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openai.py @@ -1,4 +1,4 @@ -from langchain_openai import AzureChatOpenAI, ChatOpenAI, OpenAI, OpenAIEmbeddings +from langchain_openai import AzureChatOpenAI, ChatOpenAI, OpenAI, OpenAIEmbeddings, AzureOpenAIEmbeddings from ..embedding_providers import BaseEmbeddingsProvider from ..providers import BaseProvider, EnvAuthStrategy, TextField @@ -106,3 +106,20 @@ class OpenAIEmbeddingsProvider(BaseEmbeddingsProvider, OpenAIEmbeddings): model_id_key = "model" pypi_package_deps = ["langchain_openai"] auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY") + + +class AzureOpenAIEmbeddingsProvider(BaseEmbeddingsProvider, AzureOpenAIEmbeddings): + id = "azure" + name = "Azure OpenAI" + models = [ + "text-embedding-ada-002", + "text-embedding-3-small", + "text-embedding-3-large", + ] + model_id_key = "azure_deployment" + pypi_package_deps = ["langchain_openai"] + auth_strategy = EnvAuthStrategy( + name="AZURE_OPENAI_API_KEY", keyword_param="openai_api_key" + ) + + registry = True diff --git a/packages/jupyter-ai-magics/pyproject.toml b/packages/jupyter-ai-magics/pyproject.toml index 02eb8b9e5..d42a9bc7b 100644 --- a/packages/jupyter-ai-magics/pyproject.toml +++ b/packages/jupyter-ai-magics/pyproject.toml @@ -76,6 +76,7 @@ gemini = "jupyter_ai_magics.partner_providers.gemini:GeminiProvider" mistralai = "jupyter_ai_magics.partner_providers.mistralai:MistralAIProvider" [project.entry-points."jupyter_ai.embeddings_model_providers"] +azure = "jupyter_ai_magics.partner_providers.openai:AzureOpenAIEmbeddingsProvider" bedrock = "jupyter_ai_magics.partner_providers.aws:BedrockEmbeddingsProvider" cohere = "jupyter_ai_magics.partner_providers.cohere:CohereEmbeddingsProvider" mistralai = "jupyter_ai_magics.partner_providers.mistralai:MistralAIEmbeddingsProvider" From 3ac5700bf01c76add9aaefd8a8ca646d1449b2a6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 7 Aug 2024 13:59:42 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../jupyter_ai_magics/partner_providers/openai.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openai.py b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openai.py index e804968f8..5d8f49073 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openai.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openai.py @@ -1,4 +1,10 @@ -from langchain_openai import AzureChatOpenAI, ChatOpenAI, OpenAI, OpenAIEmbeddings, AzureOpenAIEmbeddings +from langchain_openai import ( + AzureChatOpenAI, + AzureOpenAIEmbeddings, + ChatOpenAI, + OpenAI, + OpenAIEmbeddings, +) from ..embedding_providers import BaseEmbeddingsProvider from ..providers import BaseProvider, EnvAuthStrategy, TextField @@ -121,5 +127,5 @@ class AzureOpenAIEmbeddingsProvider(BaseEmbeddingsProvider, AzureOpenAIEmbedding auth_strategy = EnvAuthStrategy( name="AZURE_OPENAI_API_KEY", keyword_param="openai_api_key" ) - + registry = True