diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openai.py b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openai.py index 106580c8a..30bfd628f 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openai.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openai.py @@ -83,7 +83,9 @@ class AzureChatOpenAIProvider(BaseProvider, AzureChatOpenAI): pypi_package_deps = ["langchain_openai"] # Confusingly, langchain uses both OPENAI_API_KEY and AZURE_OPENAI_API_KEY for azure # https://github.com/langchain-ai/langchain/blob/f2579096993ae460516a0aae1d3e09f3eb5c1772/libs/partners/openai/langchain_openai/llms/azure.py#L85 - auth_strategy = EnvAuthStrategy(name="AZURE_OPENAI_API_KEY") + auth_strategy = EnvAuthStrategy( + name="AZURE_OPENAI_API_KEY", keyword_param="openai_api_key" + ) registry = True fields = [ diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py index baff893a9..1bbc4ce57 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py @@ -33,19 +33,13 @@ from langchain.schema.output_parser import StrOutputParser from langchain.schema.runnable import Runnable from langchain.utils import get_from_dict_or_env -from langchain_community.chat_models import ( - BedrockChat, - ChatAnthropic, - QianfanChatEndpoint, -) +from langchain_community.chat_models import BedrockChat, QianfanChatEndpoint from langchain_community.llms import ( AI21, - Anthropic, Bedrock, Cohere, GPT4All, HuggingFaceEndpoint, - OpenAI, SagemakerEndpoint, Together, ) @@ -111,10 +105,23 @@ class EnvAuthStrategy(BaseModel): - """Require one auth token via an environment variable.""" + """ + Describes a provider that uses a single authentication token, which is + passed either as an environment variable or as a keyword argument. + """ type: Literal["env"] = "env" + name: str + """The name of the environment variable, e.g. `'ANTHROPIC_API_KEY'`.""" + + keyword_param: Optional[str] + """ + If unset (default), the authentication token is provided as a keyword + argument with the parameter equal to the environment variable name in + lowercase. If set to some string `k`, the authentication token will be + passed using the keyword parameter `k`. + """ class MultiEnvAuthStrategy(BaseModel): diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index a05f9b3ca..d7674c5a2 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -448,8 +448,12 @@ def _provider_params(self, key, listing): _, Provider = get_em_provider(gid, listing) authn_fields = {} if Provider.auth_strategy and Provider.auth_strategy.type == "env": + keyword_param = ( + Provider.auth_strategy.keyword_param + or Provider.auth_strategy.name.lower() + ) key_name = Provider.auth_strategy.name - authn_fields[key_name.lower()] = config.api_keys[key_name] + authn_fields[keyword_param] = config.api_keys[key_name] return { "model_id": lid,