forked from jupyterlab/jupyter-ai
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Migrate to
langchain-aws
for AWS providers (jupyterlab#909)
* migrate to langchain-aws * pre-commit * update aws provider dependencies in docs * correct SM endpoints docs URL Co-authored-by: Jason Weill <[email protected]> * add new Cohere model IDs to BedrockEmbeddings Co-authored-by: Piyush Jain <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * use BedrockLLM instead of Bedrock class * add Amazon, Meta, Mistral models to BedrockChatProvider * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: Jason Weill <[email protected]> Co-authored-by: Piyush Jain <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- Loading branch information
Showing
6 changed files
with
193 additions
and
177 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
183 changes: 183 additions & 0 deletions
183
packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/aws.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,183 @@ | ||
import copy | ||
import json | ||
from typing import Any, Coroutine, Dict | ||
|
||
from jsonpath_ng import parse | ||
from langchain_aws import BedrockEmbeddings, BedrockLLM, ChatBedrock, SagemakerEndpoint | ||
from langchain_aws.llms.sagemaker_endpoint import LLMContentHandler | ||
from langchain_core.outputs import LLMResult | ||
|
||
from ..embedding_providers import BaseEmbeddingsProvider | ||
from ..providers import AwsAuthStrategy, BaseProvider, MultilineTextField, TextField | ||
|
||
|
||
# See model ID list here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html | ||
class BedrockProvider(BaseProvider, BedrockLLM): | ||
id = "bedrock" | ||
name = "Amazon Bedrock" | ||
models = [ | ||
"amazon.titan-text-express-v1", | ||
"amazon.titan-text-lite-v1", | ||
"ai21.j2-ultra-v1", | ||
"ai21.j2-mid-v1", | ||
"cohere.command-light-text-v14", | ||
"cohere.command-text-v14", | ||
"cohere.command-r-v1:0", | ||
"cohere.command-r-plus-v1:0", | ||
"meta.llama2-13b-chat-v1", | ||
"meta.llama2-70b-chat-v1", | ||
"meta.llama3-8b-instruct-v1:0", | ||
"meta.llama3-70b-instruct-v1:0", | ||
"meta.llama3-1-8b-instruct-v1:0", | ||
"meta.llama3-1-70b-instruct-v1:0", | ||
"mistral.mistral-7b-instruct-v0:2", | ||
"mistral.mixtral-8x7b-instruct-v0:1", | ||
"mistral.mistral-large-2402-v1:0", | ||
] | ||
model_id_key = "model_id" | ||
pypi_package_deps = ["langchain-aws"] | ||
auth_strategy = AwsAuthStrategy() | ||
fields = [ | ||
TextField( | ||
key="credentials_profile_name", | ||
label="AWS profile (optional)", | ||
format="text", | ||
), | ||
TextField(key="region_name", label="Region name (optional)", format="text"), | ||
] | ||
|
||
async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]: | ||
return await self._call_in_executor(*args, **kwargs) | ||
|
||
|
||
# See model ID list here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html | ||
class BedrockChatProvider(BaseProvider, ChatBedrock): | ||
id = "bedrock-chat" | ||
name = "Amazon Bedrock Chat" | ||
models = [ | ||
"amazon.titan-text-express-v1", | ||
"amazon.titan-text-lite-v1", | ||
"anthropic.claude-v2", | ||
"anthropic.claude-v2:1", | ||
"anthropic.claude-instant-v1", | ||
"anthropic.claude-3-sonnet-20240229-v1:0", | ||
"anthropic.claude-3-haiku-20240307-v1:0", | ||
"anthropic.claude-3-opus-20240229-v1:0", | ||
"anthropic.claude-3-5-sonnet-20240620-v1:0", | ||
"meta.llama2-13b-chat-v1", | ||
"meta.llama2-70b-chat-v1", | ||
"meta.llama3-8b-instruct-v1:0", | ||
"meta.llama3-70b-instruct-v1:0", | ||
"meta.llama3-1-8b-instruct-v1:0", | ||
"meta.llama3-1-70b-instruct-v1:0", | ||
"mistral.mistral-7b-instruct-v0:2", | ||
"mistral.mixtral-8x7b-instruct-v0:1", | ||
"mistral.mistral-large-2402-v1:0", | ||
] | ||
model_id_key = "model_id" | ||
pypi_package_deps = ["langchain-aws"] | ||
auth_strategy = AwsAuthStrategy() | ||
fields = [ | ||
TextField( | ||
key="credentials_profile_name", | ||
label="AWS profile (optional)", | ||
format="text", | ||
), | ||
TextField(key="region_name", label="Region name (optional)", format="text"), | ||
] | ||
|
||
async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]: | ||
return await self._call_in_executor(*args, **kwargs) | ||
|
||
async def _agenerate(self, *args, **kwargs) -> Coroutine[Any, Any, LLMResult]: | ||
return await self._generate_in_executor(*args, **kwargs) | ||
|
||
@property | ||
def allows_concurrency(self): | ||
return not "anthropic" in self.model_id | ||
|
||
|
||
# See model ID list here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html | ||
class BedrockEmbeddingsProvider(BaseEmbeddingsProvider, BedrockEmbeddings): | ||
id = "bedrock" | ||
name = "Bedrock" | ||
models = [ | ||
"amazon.titan-embed-text-v1", | ||
"amazon.titan-embed-text-v2:0", | ||
"cohere.embed-english-v3", | ||
"cohere.embed-multilingual-v3", | ||
] | ||
model_id_key = "model_id" | ||
pypi_package_deps = ["langchain-aws"] | ||
auth_strategy = AwsAuthStrategy() | ||
|
||
|
||
class JsonContentHandler(LLMContentHandler): | ||
content_type = "application/json" | ||
accepts = "application/json" | ||
|
||
def __init__(self, request_schema, response_path): | ||
self.request_schema = json.loads(request_schema) | ||
self.response_path = response_path | ||
self.response_parser = parse(response_path) | ||
|
||
def replace_values(self, old_val, new_val, d: Dict[str, Any]): | ||
"""Replaces values of a dictionary recursively.""" | ||
for key, val in d.items(): | ||
if val == old_val: | ||
d[key] = new_val | ||
if isinstance(val, dict): | ||
self.replace_values(old_val, new_val, val) | ||
|
||
return d | ||
|
||
def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes: | ||
request_obj = copy.deepcopy(self.request_schema) | ||
self.replace_values("<prompt>", prompt, request_obj) | ||
request = json.dumps(request_obj).encode("utf-8") | ||
return request | ||
|
||
def transform_output(self, output: bytes) -> str: | ||
response_json = json.loads(output.read().decode("utf-8")) | ||
matches = self.response_parser.find(response_json) | ||
return matches[0].value | ||
|
||
|
||
class SmEndpointProvider(BaseProvider, SagemakerEndpoint): | ||
id = "sagemaker-endpoint" | ||
name = "SageMaker endpoint" | ||
models = ["*"] | ||
model_id_key = "endpoint_name" | ||
model_id_label = "Endpoint name" | ||
# This all needs to be on one line of markdown, for use in a table | ||
help = ( | ||
"Specify an endpoint name as the model ID. " | ||
"In addition, you must specify a region name, request schema, and response path. " | ||
"For more information, see the documentation about [SageMaker endpoints deployment](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints-deploy-models.html) " | ||
"and about [using magic commands with SageMaker endpoints](https://jupyter-ai.readthedocs.io/en/latest/users/index.html#using-magic-commands-with-sagemaker-endpoints)." | ||
) | ||
|
||
pypi_package_deps = ["langchain-aws"] | ||
auth_strategy = AwsAuthStrategy() | ||
registry = True | ||
fields = [ | ||
TextField(key="region_name", label="Region name (required)", format="text"), | ||
MultilineTextField( | ||
key="request_schema", label="Request schema (required)", format="json" | ||
), | ||
TextField( | ||
key="response_path", label="Response path (required)", format="jsonpath" | ||
), | ||
] | ||
|
||
def __init__(self, *args, **kwargs): | ||
request_schema = kwargs.pop("request_schema") | ||
response_path = kwargs.pop("response_path") | ||
content_handler = JsonContentHandler( | ||
request_schema=request_schema, response_path=response_path | ||
) | ||
|
||
super().__init__(*args, **kwargs, content_handler=content_handler) | ||
|
||
async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]: | ||
return await self._call_in_executor(*args, **kwargs) |
Oops, something went wrong.