Skip to content

Commit

Permalink
Implement automatic cross-region inference (CRI) to the "us" region area
Browse files Browse the repository at this point in the history
  • Loading branch information
dlqqq committed Nov 25, 2024
1 parent 37ade04 commit 7b9c7cf
Showing 1 changed file with 40 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,6 @@ class BedrockProvider(BaseProvider, BedrockLLM):
"cohere.command-text-v14",
"cohere.command-r-v1:0",
"cohere.command-r-plus-v1:0",
"meta.llama2-13b-chat-v1",
"meta.llama2-70b-chat-v1",
"meta.llama3-8b-instruct-v1:0",
"meta.llama3-70b-instruct-v1:0",
"meta.llama3-1-8b-instruct-v1:0",
"meta.llama3-1-70b-instruct-v1:0",
"meta.llama3-1-405b-instruct-v1:0",
"mistral.mistral-7b-instruct-v0:2",
"mistral.mixtral-8x7b-instruct-v0:1",
"mistral.mistral-large-2402-v1:0",
Expand All @@ -58,30 +51,53 @@ async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
class BedrockChatProvider(BaseProvider, ChatBedrock):
    # Provider identifier used internally and the human-readable display name.
    id = "bedrock-chat"
    name = "Amazon Bedrock Chat"

    # Model IDs eligible for automatic "us." inference-profile prefixing in
    # __init__ below. Membership here is what triggers CRI — keep this list in
    # sync with the AWS inference-profile support table (see docstring link).
    # NOTE(review): list contents presumably reflect AWS CRI support at the
    # time of writing — verify against the linked AWS documentation before
    # adding or removing entries.
    cri_models = [
        # Anthropic models
        "anthropic.claude-3-haiku-20240307-v1:0",
        "anthropic.claude-3-5-haiku-20241022-v1:0",
        "anthropic.claude-3-sonnet-20240229-v1:0",
        "anthropic.claude-3-5-sonnet-20240620-v1:0",
        "anthropic.claude-3-5-sonnet-20241022-v2:0",
        "anthropic.claude-3-opus-20240229-v1:0",
        # Meta Llama 3.1 models
        "meta.llama3-1-8b-instruct-v1:0",
        "meta.llama3-1-70b-instruct-v1:0",
        # Meta Llama 3.2 models
        "meta.llama3-2-1b-instruct-v1:0",
        "meta.llama3-2-3b-instruct-v1:0",
        "meta.llama3-2-11b-instruct-v1:0",
        "meta.llama3-2-90b-instruct-v1:0",
    ]
    """
    List of model IDs that support cross-region inference (CRI) in the "us" region area.
    - If a model supports CRI, we default to invoking CRI as it allows the model to be called regardless of the user's region.
    - To invoke CRI, we prepend "us." to the model ID internally, which transforms it into an inference profile ID, indicating that CRI should be used.
    - Currently, we only support CRI to the "us" region area.
    Source: https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html
    """

models = [
*cri_models,
# Amazon models (no CRI)
"amazon.titan-text-express-v1",
"amazon.titan-text-lite-v1",
"amazon.titan-text-premier-v1:0",
# Anthropic v1 & v2 models (no CRI)
"anthropic.claude-v2",
"anthropic.claude-v2:1",
"anthropic.claude-instant-v1",
"anthropic.claude-3-sonnet-20240229-v1:0",
"anthropic.claude-3-haiku-20240307-v1:0",
"anthropic.claude-3-opus-20240229-v1:0",
"anthropic.claude-3-5-haiku-20241022-v1:0",
"anthropic.claude-3-5-sonnet-20240620-v1:0",
"anthropic.claude-3-5-sonnet-20241022-v2:0",
# Meta Llama 2 models (no CRI)
"meta.llama2-13b-chat-v1",
"meta.llama2-70b-chat-v1",
# Meta Llama 3 models (no CRI)
"meta.llama3-8b-instruct-v1:0",
"meta.llama3-70b-instruct-v1:0",
"meta.llama3-1-8b-instruct-v1:0",
"meta.llama3-1-70b-instruct-v1:0",
# Meta Llama 3.1 models (only one not supporting CRI)
"meta.llama3-1-405b-instruct-v1:0",
"meta.llama3-2-1b-instruct-v1:0",
"meta.llama3-2-3b-instruct-v1:0",
"meta.llama3-2-11b-instruct-v1:0",
"meta.llama3-2-90b-instruct-v1:0",
# Mistral models
"mistral.mistral-7b-instruct-v0:2",
"mistral.mixtral-8x7b-instruct-v0:1",
"mistral.mistral-large-2402-v1:0",
Expand All @@ -91,26 +107,20 @@ class BedrockChatProvider(BaseProvider, ChatBedrock):
# Python package required for this provider to be importable at runtime.
pypi_package_deps = ["langchain-aws"]
# Authentication is delegated to the standard AWS credential chain.
auth_strategy = AwsAuthStrategy()
# User-configurable settings surfaced in the UI. Each TextField's `key`
# becomes a kwarg passed to __init__ (e.g. "region_area" is popped there).
fields = [
    TextField(
        key="region_area",
        # "possibly required": some models can only be invoked through a
        # cross-region inference profile, in which case an area is mandatory.
        label="Cross-region inference area (possibly required)",
        format="text",
    ),
    TextField(
        key="credentials_profile_name",
        label="AWS profile (optional)",
        format="text",
    ),
    TextField(key="region_name", label="Region name (optional)", format="text"),
]
# Help text shown alongside the fields above; points users at the AWS table
# of valid CRI area names.
help = "Specify the Cross Region Inference (CRI) Area Name. \
Look this up [here](https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html#inference-profiles-support-system)."

def __init__(self, *args, **kwargs):
    """Initialize the chat provider, resolving the effective Bedrock model ID.

    Resolution order for cross-region inference (CRI):
    1. If the user configured an explicit ``region_area`` (see ``fields``),
       prefix the model ID with it (e.g. ``"eu." + model_id``), turning it
       into an inference profile ID.
    2. Otherwise, if the model is known to support CRI in the "us" area
       (``cri_models``), default to the ``"us."`` prefix so the model can be
       invoked regardless of the user's region.
    3. Otherwise, pass the model ID through unchanged.

    All remaining args/kwargs are forwarded to ``ChatBedrock``.

    NOTE: the merged original called ``super().__init__`` twice (once before
    and once after popping ``model_id``), which both double-initialized the
    parent and popped a key already consumed. This version performs a single,
    coherent initialization.
    """
    # Pop so the resolved ID is passed exactly once, via the explicit kwarg
    # below. Assumes callers always supply model_id as a keyword — the
    # original code made the same assumption.
    model_id = kwargs.pop("model_id")
    # Explicit user-configured area takes precedence over the automatic
    # "us" default, preserving the old region_area behavior.
    region_area = kwargs.pop("region_area", None)
    if region_area:
        model_id = f"{region_area}.{model_id}"
    elif model_id in self.cri_models:
        model_id = "us." + model_id

    super().__init__(*args, **kwargs, model_id=model_id)

async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
    """Asynchronously invoke the model by delegating to ``_call_in_executor``.

    Presumably ``_call_in_executor`` runs the synchronous call path on a
    thread-pool executor so the event loop is not blocked — confirm against
    BaseProvider.

    NOTE(review): since this is an ``async def``, awaiting it yields the
    final value, so the annotation arguably should be ``-> str`` rather than
    ``-> Coroutine[Any, Any, str]``; kept as-is for consistency with the
    sibling provider's ``_acall``.
    """
    return await self._call_in_executor(*args, **kwargs)
Expand Down

0 comments on commit 7b9c7cf

Please sign in to comment.