Commit: Add support for custom/provisioned models in Bedrock (#922)
* add BedrockCustomProvider
* pre-commit
* format help text
* show model ID field
* edit BedrockCustomProvider help text
* add mention of BedrockCustom in docs
dlqqq authored Jul 31, 2024
1 parent 73d72e9 commit b953568
Showing 3 changed files with 44 additions and 17 deletions.
36 changes: 19 additions & 17 deletions docs/source/users/index.md
```diff
@@ -153,23 +153,24 @@ Jupyter AI supports a wide range of model providers and models. To use Jupyter A
 
 Jupyter AI supports the following model providers:
 
-| Provider            | Provider ID          | Environment variable(s)    | Python package(s)               |
-|---------------------|----------------------|----------------------------|---------------------------------|
-| AI21                | `ai21`               | `AI21_API_KEY`             | `ai21`                          |
-| Anthropic           | `anthropic`          | `ANTHROPIC_API_KEY`        | `langchain-anthropic`           |
-| Anthropic (chat)    | `anthropic-chat`     | `ANTHROPIC_API_KEY`        | `langchain-anthropic`           |
-| Bedrock             | `bedrock`            | N/A                        | `langchain-aws`                 |
-| Bedrock (chat)      | `bedrock-chat`       | N/A                        | `langchain-aws`                 |
-| Cohere              | `cohere`             | `COHERE_API_KEY`           | `langchain_cohere`              |
-| ERNIE-Bot           | `qianfan`            | `QIANFAN_AK`, `QIANFAN_SK` | `qianfan`                       |
-| Gemini              | `gemini`             | `GOOGLE_API_KEY`           | `langchain-google-genai`        |
-| GPT4All             | `gpt4all`            | N/A                        | `gpt4all`                       |
-| Hugging Face Hub    | `huggingface_hub`    | `HUGGINGFACEHUB_API_TOKEN` | `huggingface_hub`, `ipywidgets`, `pillow` |
-| MistralAI           | `mistralai`          | `MISTRAL_API_KEY`          | `langchain-mistralai`           |
-| NVIDIA              | `nvidia-chat`        | `NVIDIA_API_KEY`           | `langchain_nvidia_ai_endpoints` |
-| OpenAI              | `openai`             | `OPENAI_API_KEY`           | `langchain-openai`              |
-| OpenAI (chat)       | `openai-chat`        | `OPENAI_API_KEY`           | `langchain-openai`              |
-| SageMaker           | `sagemaker-endpoint` | N/A                        | `langchain-aws`                 |
+| Provider                     | Provider ID          | Environment variable(s)    | Python package(s)                         |
+|------------------------------|----------------------|----------------------------|-------------------------------------------|
+| AI21                         | `ai21`               | `AI21_API_KEY`             | `ai21`                                    |
+| Anthropic                    | `anthropic`          | `ANTHROPIC_API_KEY`        | `langchain-anthropic`                     |
+| Anthropic (chat)             | `anthropic-chat`     | `ANTHROPIC_API_KEY`        | `langchain-anthropic`                     |
+| Bedrock                      | `bedrock`            | N/A                        | `langchain-aws`                           |
+| Bedrock (chat)               | `bedrock-chat`       | N/A                        | `langchain-aws`                           |
+| Bedrock (custom/provisioned) | `bedrock-custom`     | N/A                        | `langchain-aws`                           |
+| Cohere                       | `cohere`             | `COHERE_API_KEY`           | `langchain-cohere`                        |
+| ERNIE-Bot                    | `qianfan`            | `QIANFAN_AK`, `QIANFAN_SK` | `qianfan`                                 |
+| Gemini                       | `gemini`             | `GOOGLE_API_KEY`           | `langchain-google-genai`                  |
+| GPT4All                      | `gpt4all`            | N/A                        | `gpt4all`                                 |
+| Hugging Face Hub             | `huggingface_hub`    | `HUGGINGFACEHUB_API_TOKEN` | `huggingface_hub`, `ipywidgets`, `pillow` |
+| MistralAI                    | `mistralai`          | `MISTRAL_API_KEY`          | `langchain-mistralai`                     |
+| NVIDIA                       | `nvidia-chat`        | `NVIDIA_API_KEY`           | `langchain_nvidia_ai_endpoints`           |
+| OpenAI                       | `openai`             | `OPENAI_API_KEY`           | `langchain-openai`                        |
+| OpenAI (chat)                | `openai-chat`        | `OPENAI_API_KEY`           | `langchain-openai`                        |
+| SageMaker endpoint           | `sagemaker-endpoint` | N/A                        | `langchain-aws`                           |
 
 The environment variable names shown above are also the names of the settings keys used when setting up the chat interface.
 If multiple variables are listed for a provider, **all** must be specified.
```
```diff
@@ -615,6 +616,7 @@ We currently support the following language model providers:
 - `anthropic-chat`
 - `bedrock`
 - `bedrock-chat`
+- `bedrock-custom`
 - `cohere`
 - `huggingface_hub`
 - `nvidia-chat`
```
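Unlike the other Bedrock providers, `bedrock-custom` expects the model ID to be the ARN of a custom or provisioned model rather than a plain base-model ID such as `anthropic.claude-v2`. As a rough illustration (this helper is not part of Jupyter AI, and the ARN shapes are an assumption based on AWS's usual `arn:aws:bedrock:region:account:resource` pattern):

```python
import re

# Hypothetical helper, NOT part of Jupyter AI: checks whether a model ID
# looks like a Bedrock custom/provisioned model ARN rather than a plain
# base-model ID. The two resource types matched here are an assumption.
ARN_PATTERN = re.compile(
    r"^arn:aws:bedrock:"
    r"(?P<region>[a-z0-9-]+):"       # e.g. us-east-1
    r"(?P<account>\d{12}):"          # 12-digit AWS account ID
    r"(?:provisioned-model|custom-model)/(?P<resource>.+)$"
)

def is_custom_model_arn(model_id: str) -> bool:
    """Return True if model_id looks like a custom/provisioned model ARN."""
    return ARN_PATTERN.match(model_id) is not None
```

For example, `is_custom_model_arn("arn:aws:bedrock:us-east-1:123456789012:provisioned-model/abc123")` is true, while a plain ID like `"anthropic.claude-v2"` is not.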
jupyter_ai_magics/partner_providers/aws.py

```diff
@@ -104,6 +104,30 @@ def allows_concurrency(self):
         return not "anthropic" in self.model_id
 
 
+class BedrockCustomProvider(BaseProvider, ChatBedrock):
+    id = "bedrock-custom"
+    name = "Amazon Bedrock (custom/provisioned)"
+    models = ["*"]
+    model_id_key = "model_id"
+    model_id_label = "Model ID"
+    pypi_package_deps = ["langchain-aws"]
+    auth_strategy = AwsAuthStrategy()
+    fields = [
+        TextField(key="provider", label="Provider (required)", format="text"),
+        TextField(key="region_name", label="Region name (optional)", format="text"),
+        TextField(
+            key="credentials_profile_name",
+            label="AWS profile (optional)",
+            format="text",
+        ),
+    ]
+    help = (
+        "Specify the ARN (Amazon Resource Name) of the custom/provisioned model as the model ID. For more information, see the [Amazon Bedrock model IDs documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html).\n\n"
+        "The model provider must also be specified below. This is the provider of your foundation model *in lowercase*, e.g. `amazon`, `anthropic`, `meta`, or `mistral`."
+    )
+    registry = True
+
+
 # See model ID list here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
 class BedrockEmbeddingsProvider(BaseEmbeddingsProvider, BedrockEmbeddings):
     id = "bedrock"
```
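The `fields` list in `BedrockCustomProvider` drives the extra text inputs shown in the settings UI, and the values the user fills in are forwarded as keyword arguments when the underlying `ChatBedrock` model is constructed. A minimal sketch of that mapping, using a mock `TextField` (a stand-in assumed for illustration; the real class in jupyter-ai-magics has more behavior):

```python
from dataclasses import dataclass

# Illustrative stand-in for jupyter-ai-magics' TextField; only the
# key/label/format attributes are used here.
@dataclass
class TextField:
    key: str
    label: str
    format: str = "text"

fields = [
    TextField(key="provider", label="Provider (required)"),
    TextField(key="region_name", label="Region name (optional)"),
    TextField(key="credentials_profile_name", label="AWS profile (optional)"),
]

def collect_kwargs(fields, user_input):
    """Keep only the settings the user actually filled in, so optional
    fields left blank are not passed to the model constructor at all."""
    return {f.key: user_input[f.key] for f in fields if user_input.get(f.key)}

kwargs = collect_kwargs(fields, {"provider": "anthropic", "region_name": "us-east-1"})
# kwargs now holds only the populated settings
```

This is why `region_name` and `credentials_profile_name` can be optional: leaving them blank simply falls back to the defaults resolved by the AWS SDK.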
1 change: 1 addition & 0 deletions packages/jupyter-ai-magics/pyproject.toml

```diff
@@ -68,6 +68,7 @@ azure-chat-openai = "jupyter_ai_magics.partner_providers.openai:AzureChatOpenAIP
 sagemaker-endpoint = "jupyter_ai_magics.partner_providers.aws:SmEndpointProvider"
 amazon-bedrock = "jupyter_ai_magics.partner_providers.aws:BedrockProvider"
 amazon-bedrock-chat = "jupyter_ai_magics.partner_providers.aws:BedrockChatProvider"
+amazon-bedrock-custom = "jupyter_ai_magics.partner_providers.aws:BedrockCustomProvider"
 qianfan = "jupyter_ai_magics:QianfanProvider"
 nvidia-chat = "jupyter_ai_magics.partner_providers.nvidia:ChatNVIDIAProvider"
 together-ai = "jupyter_ai_magics:TogetherAIProvider"
```
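The new `amazon-bedrock-custom` entry point is how Jupyter AI discovers `BedrockCustomProvider` at runtime: the `"module:attr"` value is split into an importable module path and an attribute name. A small sketch using the stdlib `importlib.metadata.EntryPoint` (constructed by hand here; normally `entry_points()` discovers these from installed package metadata, and the group name shown is an assumption):

```python
from importlib.metadata import EntryPoint

# Build the entry point manually to show how the "module:attr" value from
# pyproject.toml is parsed; the group name is assumed for illustration.
ep = EntryPoint(
    name="amazon-bedrock-custom",
    value="jupyter_ai_magics.partner_providers.aws:BedrockCustomProvider",
    group="jupyter_ai.model_providers",  # assumed group name
)

print(ep.module)  # module path to import
print(ep.attr)    # attribute to load from that module
```

Calling `ep.load()` would then import the module and return the provider class, which is why adding this one line to `pyproject.toml` is enough to register the provider.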