diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/aws.py b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/aws.py index 07a29a706..66a6838c5 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/aws.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/aws.py @@ -26,13 +26,6 @@ class BedrockProvider(BaseProvider, BedrockLLM): "cohere.command-text-v14", "cohere.command-r-v1:0", "cohere.command-r-plus-v1:0", - "meta.llama2-13b-chat-v1", - "meta.llama2-70b-chat-v1", - "meta.llama3-8b-instruct-v1:0", - "meta.llama3-70b-instruct-v1:0", - "meta.llama3-1-8b-instruct-v1:0", - "meta.llama3-1-70b-instruct-v1:0", - "meta.llama3-1-405b-instruct-v1:0", "mistral.mistral-7b-instruct-v0:2", "mistral.mixtral-8x7b-instruct-v0:1", "mistral.mistral-large-2402-v1:0", @@ -58,30 +51,53 @@ async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]: class BedrockChatProvider(BaseProvider, ChatBedrock): id = "bedrock-chat" name = "Amazon Bedrock Chat" + + cri_models = [ + # Anthropic models + "anthropic.claude-3-haiku-20240307-v1:0", + "anthropic.claude-3-5-haiku-20241022-v1:0", + "anthropic.claude-3-sonnet-20240229-v1:0", + "anthropic.claude-3-5-sonnet-20240620-v1:0", + "anthropic.claude-3-5-sonnet-20241022-v2:0", + "anthropic.claude-3-opus-20240229-v1:0", + # Meta Llama 3.1 models + "meta.llama3-1-8b-instruct-v1:0", + "meta.llama3-1-70b-instruct-v1:0", + # Meta Llama 3.2 models + "meta.llama3-2-1b-instruct-v1:0", + "meta.llama3-2-3b-instruct-v1:0", + "meta.llama3-2-11b-instruct-v1:0", + "meta.llama3-2-90b-instruct-v1:0", + ] + """ + List of model IDs that support cross-region inference (CRI) in the "us" region area. + + - If a model supports CRI, we default to invoking CRI as it allows the model to be called regardless of the user's region. + - To invoke CRI, we prepend "us." to the model ID internally, which transforms it into an inference profile ID, indicating that CRI should be used.
+ - Currently, we only support CRI to the "us" region area. + + Source: https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html + """ + models = [ + *cri_models, + # Amazon models (no CRI) "amazon.titan-text-express-v1", "amazon.titan-text-lite-v1", "amazon.titan-text-premier-v1:0", + # Anthropic v1 & v2 models (no CRI) "anthropic.claude-v2", "anthropic.claude-v2:1", "anthropic.claude-instant-v1", - "anthropic.claude-3-sonnet-20240229-v1:0", - "anthropic.claude-3-haiku-20240307-v1:0", - "anthropic.claude-3-opus-20240229-v1:0", - "anthropic.claude-3-5-haiku-20241022-v1:0", - "anthropic.claude-3-5-sonnet-20240620-v1:0", - "anthropic.claude-3-5-sonnet-20241022-v2:0", + # Meta Llama 2 models (no CRI) "meta.llama2-13b-chat-v1", "meta.llama2-70b-chat-v1", + # Meta Llama 3 models (no CRI) "meta.llama3-8b-instruct-v1:0", "meta.llama3-70b-instruct-v1:0", - "meta.llama3-1-8b-instruct-v1:0", - "meta.llama3-1-70b-instruct-v1:0", + # Meta Llama 3.1 models (only one not supporting CRI) "meta.llama3-1-405b-instruct-v1:0", - "meta.llama3-2-1b-instruct-v1:0", - "meta.llama3-2-3b-instruct-v1:0", - "meta.llama3-2-11b-instruct-v1:0", - "meta.llama3-2-90b-instruct-v1:0", + # Mistral models "mistral.mistral-7b-instruct-v0:2", "mistral.mixtral-8x7b-instruct-v0:1", "mistral.mistral-large-2402-v1:0", @@ -91,11 +107,6 @@ class BedrockChatProvider(BaseProvider, ChatBedrock): pypi_package_deps = ["langchain-aws"] auth_strategy = AwsAuthStrategy() fields = [ - TextField( - key="region_area", - label="Cross-region inference area (possibly required)", - format="text", - ), TextField( key="credentials_profile_name", label="AWS profile (optional)", @@ -103,14 +114,13 @@ class BedrockChatProvider(BaseProvider, ChatBedrock): ), TextField(key="region_name", label="Region name (optional)", format="text"), ] - help = "Specify the Cross Region Inference (CRI) Area Name. 
\ - Look this up [here](https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html#inference-profiles-support-system)." def __init__(self, *args, **kwargs): - region_area = kwargs.pop("region_area", None) - if region_area: - kwargs["model_id"] = region_area + "." + kwargs["model_id"] - super().__init__(*args, **kwargs) + model_id = kwargs.pop("model_id") + if model_id in self.cri_models: + model_id = "us." + model_id + + super().__init__(*args, **kwargs, model_id=model_id) async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]: return await self._call_in_executor(*args, **kwargs)