-
-
Notifications
You must be signed in to change notification settings - Fork 341
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
174 additions
and
174 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
167 changes: 167 additions & 0 deletions
167
packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/aws.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
from typing import Any, Coroutine, Dict | ||
import copy | ||
import json | ||
from jsonpath_ng import parse | ||
|
||
from langchain_aws import Bedrock, ChatBedrock, SagemakerEndpoint, BedrockEmbeddings | ||
from langchain_aws.llms.sagemaker_endpoint import LLMContentHandler | ||
from langchain_core.outputs import LLMResult | ||
|
||
from ..providers import BaseProvider, AwsAuthStrategy, TextField, MultilineTextField | ||
from ..embedding_providers import BaseEmbeddingsProvider | ||
|
||
# See model ID list here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html | ||
class BedrockProvider(BaseProvider, Bedrock):
    """Provider that exposes Amazon Bedrock text models through ``Bedrock``.

    The class combines the project's ``BaseProvider`` with the
    ``langchain_aws`` ``Bedrock`` LLM and declares the provider metadata
    (ID, display name, selectable model IDs, auth strategy and optional
    configuration fields) as class attributes.
    """

    id = "bedrock"
    name = "Amazon Bedrock"
    # Model IDs offered for selection; see the AWS model ID list linked above
    # the class for the authoritative catalogue.
    models = [
        "amazon.titan-text-express-v1",
        "amazon.titan-text-lite-v1",
        "ai21.j2-ultra-v1",
        "ai21.j2-mid-v1",
        "cohere.command-light-text-v14",
        "cohere.command-text-v14",
        "cohere.command-r-v1:0",
        "cohere.command-r-plus-v1:0",
        "meta.llama2-13b-chat-v1",
        "meta.llama2-70b-chat-v1",
        "meta.llama3-8b-instruct-v1:0",
        "meta.llama3-70b-instruct-v1:0",
        "meta.llama3-1-8b-instruct-v1:0",
        "meta.llama3-1-70b-instruct-v1:0",
        "mistral.mistral-7b-instruct-v0:2",
        "mistral.mixtral-8x7b-instruct-v0:1",
        "mistral.mistral-large-2402-v1:0",
    ]
    # Name of the constructor keyword that receives the selected model ID.
    model_id_key = "model_id"
    pypi_package_deps = ["langchain-aws"]
    auth_strategy = AwsAuthStrategy()
    # Both fields are labelled optional in the UI.
    fields = [
        TextField(
            key="credentials_profile_name",
            label="AWS profile (optional)",
            format="text",
        ),
        TextField(key="region_name", label="Region name (optional)", format="text"),
    ]

    async def _acall(self, *args, **kwargs) -> str:
        """Asynchronous call hook; delegates to ``self._call_in_executor``.

        All positional and keyword arguments are forwarded unchanged.

        Returns:
            The awaited result of ``_call_in_executor``.
            NOTE(review): presumably the completion text produced by the
            synchronous ``_call`` run in an executor -- confirm against
            ``BaseProvider``.
        """
        # The return annotation of an ``async def`` names the *awaited* value,
        # so it is ``str`` rather than ``Coroutine[Any, Any, str]``.
        return await self._call_in_executor(*args, **kwargs)
|
||
|
||
# See model ID list here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html | ||
class BedrockChatProvider(BaseProvider, ChatBedrock):
    """Provider that exposes Amazon Bedrock chat models through ``ChatBedrock``.

    The class combines the project's ``BaseProvider`` with the
    ``langchain_aws`` ``ChatBedrock`` model and declares the provider
    metadata (ID, display name, selectable model IDs, auth strategy and
    optional configuration fields) as class attributes.
    """

    id = "bedrock-chat"
    name = "Amazon Bedrock Chat"
    # Model IDs offered for selection; see the AWS model ID list linked above
    # the class for the authoritative catalogue.
    models = [
        "anthropic.claude-v2",
        "anthropic.claude-v2:1",
        "anthropic.claude-instant-v1",
        "anthropic.claude-3-sonnet-20240229-v1:0",
        "anthropic.claude-3-haiku-20240307-v1:0",
        "anthropic.claude-3-opus-20240229-v1:0",
        "anthropic.claude-3-5-sonnet-20240620-v1:0",
    ]
    # Name of the constructor keyword that receives the selected model ID.
    model_id_key = "model_id"
    pypi_package_deps = ["langchain-aws"]
    auth_strategy = AwsAuthStrategy()
    # Both fields are labelled optional in the UI.
    fields = [
        TextField(
            key="credentials_profile_name",
            label="AWS profile (optional)",
            format="text",
        ),
        TextField(key="region_name", label="Region name (optional)", format="text"),
    ]

    async def _acall(self, *args, **kwargs) -> str:
        """Asynchronous call hook; delegates to ``self._call_in_executor``.

        All positional and keyword arguments are forwarded unchanged.

        Returns:
            The awaited result of ``_call_in_executor``.
            NOTE(review): presumably the completion text -- confirm against
            ``BaseProvider``.
        """
        # The return annotation of an ``async def`` names the *awaited* value,
        # so it is ``str`` rather than ``Coroutine[Any, Any, str]``.
        return await self._call_in_executor(*args, **kwargs)

    async def _agenerate(self, *args, **kwargs) -> LLMResult:
        """Asynchronous generate hook; delegates to ``self._generate_in_executor``.

        All positional and keyword arguments are forwarded unchanged.

        Returns:
            The awaited result of ``_generate_in_executor``.
            NOTE(review): annotated as ``LLMResult`` to match the original
            declaration -- confirm the concrete result type against
            ``BaseProvider`` / ``ChatBedrock``.
        """
        return await self._generate_in_executor(*args, **kwargs)

    @property
    def allows_concurrency(self):
        """Return ``False`` when the model ID contains ``"anthropic"``, else ``True``.

        NOTE(review): the reason Anthropic models are excluded from
        concurrent use is not visible in this file.
        """
        return "anthropic" not in self.model_id
|
||
|
||
# See model ID list here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html | ||
class BedrockEmbeddingsProvider(BaseEmbeddingsProvider, BedrockEmbeddings):
    # Embeddings provider that pairs the project's ``BaseEmbeddingsProvider``
    # with the ``langchain_aws`` ``BedrockEmbeddings`` class. Everything here
    # is declarative provider metadata; no behavior is overridden.

    id = "bedrock"
    name = "Bedrock"
    # Embedding model IDs offered for selection.
    models = [
        "amazon.titan-embed-text-v1",
        "amazon.titan-embed-text-v2:0",
    ]
    # Name of the constructor keyword that receives the selected model ID.
    model_id_key = "model_id"
    pypi_package_deps = ["langchain-aws"]
    auth_strategy = AwsAuthStrategy()
|
||
|
||
class JsonContentHandler(LLMContentHandler):
    """Content handler that builds JSON requests and parses JSON responses.

    The request body is produced from a user-supplied JSON template in which
    every value equal to the literal string ``"<prompt>"`` is replaced by the
    prompt. The response text is extracted from the JSON response with a
    JSONPath expression.
    """

    content_type = "application/json"
    accepts = "application/json"

    def __init__(self, request_schema, response_path):
        """Parse the request template and compile the response JSONPath.

        Args:
            request_schema: JSON text of the request template; decoded with
                ``json.loads``.
            response_path: JSONPath expression (``jsonpath_ng`` syntax) that
                locates the generated text in the response.

        Raises:
            json.JSONDecodeError: If ``request_schema`` is not valid JSON.
        """
        self.request_schema = json.loads(request_schema)
        self.response_path = response_path
        self.response_parser = parse(response_path)

    def replace_values(self, old_val, new_val, d: Dict[str, Any]):
        """Replace values equal to ``old_val`` with ``new_val`` recursively.

        Nested dictionaries and lists are traversed, so a placeholder inside
        a JSON array (for example ``{"inputs": ["<prompt>"]}``) is replaced
        as well. Only whole values are compared; a placeholder embedded in a
        longer string is left untouched.

        Args:
            old_val: Value to search for (compared with ``==``).
            new_val: Replacement value.
            d: Dictionary to update **in place**.

        Returns:
            The same dictionary object ``d``.
        """
        # Assigning to an existing key does not change the dict's size, so
        # updating it while iterating is safe.
        for key, val in d.items():
            if val == old_val:
                d[key] = new_val
            elif isinstance(val, dict):
                self.replace_values(old_val, new_val, val)
            elif isinstance(val, list):
                self._replace_in_list(old_val, new_val, val)

        return d

    def _replace_in_list(self, old_val, new_val, items: list):
        """Apply the ``replace_values`` substitution to a list, in place."""
        for index, item in enumerate(items):
            if item == old_val:
                items[index] = new_val
            elif isinstance(item, dict):
                self.replace_values(old_val, new_val, item)
            elif isinstance(item, list):
                self._replace_in_list(old_val, new_val, item)

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        """Build the UTF-8 encoded JSON request body for ``prompt``.

        Args:
            prompt: Text substituted for every ``"<prompt>"`` value.
            model_kwargs: Accepted for interface compatibility; not used.

        Returns:
            The serialized request as bytes.
        """
        # Work on a copy so the parsed template is reusable across calls.
        request_obj = copy.deepcopy(self.request_schema)
        self.replace_values("<prompt>", prompt, request_obj)
        return json.dumps(request_obj).encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        """Extract the generated text from the endpoint response.

        Args:
            output: Response body. Despite the annotation it is consumed via
                ``output.read()``, so it must be a file-like object yielding
                bytes. NOTE(review): presumably a botocore streaming body --
                confirm against the caller.

        Returns:
            The value of the first JSONPath match.

        Raises:
            IndexError: If the response path matches nothing in the response.
        """
        response_json = json.loads(output.read().decode("utf-8"))
        matches = self.response_parser.find(response_json)
        if not matches:
            # Same exception type as the bare ``matches[0]`` lookup would
            # raise, but with a message that identifies the actual problem.
            raise IndexError(
                f"Response path {self.response_path!r} did not match anything "
                "in the endpoint response."
            )
        return matches[0].value
|
||
|
||
class SmEndpointProvider(BaseProvider, SagemakerEndpoint):
    """Provider that exposes a SageMaker endpoint through ``SagemakerEndpoint``.

    The endpoint name acts as the model ID. The caller must also supply a
    JSON request template and a JSONPath response path, which are wrapped in
    a ``JsonContentHandler`` and passed to the parent constructor.
    """

    id = "sagemaker-endpoint"
    name = "SageMaker endpoint"
    # "*" rather than an enumerated list: the model ID is a free-form endpoint name.
    models = ["*"]
    model_id_key = "endpoint_name"
    model_id_label = "Endpoint name"
    # This all needs to be on one line of markdown, for use in a table
    help = (
        "Specify an endpoint name as the model ID. "
        "In addition, you must specify a region name, request schema, and response path. "
        "For more information, see the documentation about [SageMaker endpoints deployment](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints-deployment.html) "
        "and about [using magic commands with SageMaker endpoints](https://jupyter-ai.readthedocs.io/en/latest/users/index.html#using-magic-commands-with-sagemaker-endpoints)."
    )

    pypi_package_deps = ["langchain-aws"]
    auth_strategy = AwsAuthStrategy()
    registry = True
    # All three fields are labelled required in the UI.
    fields = [
        TextField(key="region_name", label="Region name (required)", format="text"),
        MultilineTextField(
            key="request_schema", label="Request schema (required)", format="json"
        ),
        TextField(
            key="response_path", label="Response path (required)", format="jsonpath"
        ),
    ]

    def __init__(self, *args, **kwargs):
        """Build the content handler and initialize the parent classes.

        ``request_schema`` and ``response_path`` are removed from ``kwargs``
        before the remaining arguments are forwarded, together with the
        resulting ``content_handler``, to the parent constructor.

        Raises:
            KeyError: If ``request_schema`` or ``response_path`` is missing
                from ``kwargs``.
        """
        request_schema = kwargs.pop("request_schema")
        response_path = kwargs.pop("response_path")
        content_handler = JsonContentHandler(
            request_schema=request_schema, response_path=response_path
        )

        super().__init__(*args, **kwargs, content_handler=content_handler)

    async def _acall(self, *args, **kwargs) -> str:
        """Asynchronous call hook; delegates to ``self._call_in_executor``.

        All positional and keyword arguments are forwarded unchanged.

        Returns:
            The awaited result of ``_call_in_executor``.
            NOTE(review): presumably the completion text -- confirm against
            ``BaseProvider``.
        """
        # The return annotation of an ``async def`` names the *awaited* value,
        # so it is ``str`` rather than ``Coroutine[Any, Any, str]``.
        return await self._call_in_executor(*args, **kwargs)
|
Oops, something went wrong.