Migrate to langchain-aws for AWS providers (jupyterlab#909)
* migrate to langchain-aws

* pre-commit

* update aws provider dependencies in docs

* correct SM endpoints docs URL

Co-authored-by: Jason Weill <[email protected]>

* add new Cohere model IDs to BedrockEmbeddings

Co-authored-by: Piyush Jain <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use BedrockLLM instead of Bedrock class

* add Amazon, Meta, Mistral models to BedrockChatProvider

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Jason Weill <[email protected]>
Co-authored-by: Piyush Jain <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
4 people authored and Marchlak committed Oct 28, 2024
1 parent 01e5611 commit f59a9fe
Showing 6 changed files with 193 additions and 177 deletions.
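
At a high level, this commit swaps the older `langchain_community` AWS classes (backed directly by `boto3`) for the `langchain-aws` partner package. A rough before/after of the imports involved; the "before" lines are an approximation of the old community-package paths rather than text taken from this diff:

```python
# Before (approximate): community classes used by the old providers.
from langchain_community.embeddings import BedrockEmbeddings
from langchain_community.llms import Bedrock, SagemakerEndpoint
from langchain_community.chat_models import BedrockChat

# After: the langchain-aws partner package, as imported in the new aws.py below.
from langchain_aws import BedrockEmbeddings, BedrockLLM, ChatBedrock, SagemakerEndpoint
```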
6 changes: 3 additions & 3 deletions docs/source/users/index.md
@@ -158,8 +158,8 @@ Jupyter AI supports the following model providers:
| AI21 | `ai21` | `AI21_API_KEY` | `ai21` |
| Anthropic | `anthropic` | `ANTHROPIC_API_KEY` | `langchain-anthropic` |
| Anthropic (chat) | `anthropic-chat` | `ANTHROPIC_API_KEY` | `langchain-anthropic` |
| Bedrock | `bedrock` | N/A | `boto3` |
| Bedrock (chat) | `bedrock-chat` | N/A | `boto3` |
| Bedrock | `bedrock` | N/A | `langchain-aws` |
| Bedrock (chat) | `bedrock-chat` | N/A | `langchain-aws` |
| Cohere | `cohere` | `COHERE_API_KEY` | `langchain_cohere` |
| ERNIE-Bot | `qianfan` | `QIANFAN_AK`, `QIANFAN_SK` | `qianfan` |
| Gemini | `gemini` | `GOOGLE_API_KEY` | `langchain-google-genai` |
@@ -169,7 +169,7 @@ Jupyter AI supports the following model providers:
| NVIDIA | `nvidia-chat` | `NVIDIA_API_KEY` | `langchain_nvidia_ai_endpoints` |
| OpenAI | `openai` | `OPENAI_API_KEY` | `langchain-openai` |
| OpenAI (chat) | `openai-chat` | `OPENAI_API_KEY` | `langchain-openai` |
| SageMaker | `sagemaker-endpoint` | N/A | `boto3` |
| SageMaker | `sagemaker-endpoint` | N/A | `langchain-aws` |

The environment variable names shown above are also the names of the settings keys used when setting up the chat interface.
If multiple variables are listed for a provider, **all** must be specified.
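
For the AWS rows there is no API-key environment variable: credentials come from the usual AWS configuration (for example, the optional profile and region fields defined on the new providers below). A hedged sketch of exercising a Bedrock chat model through the magics once `langchain-aws` is installed; each magic belongs in its own notebook cell, and the model ID is just one of the entries added later in this commit:

```python
# Run each magic in its own notebook cell; all values here are illustrative.
%pip install jupyter-ai langchain-aws     # langchain-aws is the new dependency
%load_ext jupyter_ai_magics

%%ai bedrock-chat:anthropic.claude-3-sonnet-20240229-v1:0
Summarize what this notebook does in one sentence.
```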
4 changes: 0 additions & 4 deletions packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py
@@ -3,7 +3,6 @@
# expose embedding model providers on the package root
from .embedding_providers import (
BaseEmbeddingsProvider,
BedrockEmbeddingsProvider,
GPT4AllEmbeddingsProvider,
HfHubEmbeddingsProvider,
OllamaEmbeddingsProvider,
@@ -20,13 +19,10 @@
from .providers import (
AI21Provider,
BaseProvider,
BedrockChatProvider,
BedrockProvider,
GPT4AllProvider,
HfHubProvider,
OllamaProvider,
QianfanProvider,
SmEndpointProvider,
TogetherAIProvider,
)

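The four AWS classes dropped from the package root here are moved, not deleted: they reappear in the new `partner_providers/aws.py` module added later in this commit. Assuming they are no longer re-exported from `jupyter_ai_magics` itself, downstream imports would change roughly like this:

```python
# Sketch: new import location for the relocated AWS providers (assumes no
# re-export from the jupyter_ai_magics package root after this commit).
from jupyter_ai_magics.partner_providers.aws import (
    BedrockChatProvider,
    BedrockEmbeddingsProvider,
    BedrockProvider,
    SmEndpointProvider,
)
```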
12 changes: 0 additions & 12 deletions packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py
@@ -2,14 +2,12 @@

from jupyter_ai_magics.providers import (
AuthStrategy,
AwsAuthStrategy,
EnvAuthStrategy,
Field,
MultiEnvAuthStrategy,
)
from langchain.pydantic_v1 import BaseModel, Extra
from langchain_community.embeddings import (
BedrockEmbeddings,
GPT4AllEmbeddings,
HuggingFaceHubEmbeddings,
OllamaEmbeddings,
@@ -93,16 +91,6 @@ class HfHubEmbeddingsProvider(BaseEmbeddingsProvider, HuggingFaceHubEmbeddings):
registry = True


# See model ID list here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
class BedrockEmbeddingsProvider(BaseEmbeddingsProvider, BedrockEmbeddings):
id = "bedrock"
name = "Bedrock"
models = ["amazon.titan-embed-text-v1", "amazon.titan-embed-text-v2:0"]
model_id_key = "model_id"
pypi_package_deps = ["boto3"]
auth_strategy = AwsAuthStrategy()


class GPT4AllEmbeddingsProvider(BaseEmbeddingsProvider, GPT4AllEmbeddings):
def __init__(self, **kwargs):
from gpt4all import GPT4All
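`BedrockEmbeddingsProvider` likewise moves out of `embedding_providers.py` into the new AWS partner module, where it is rebuilt on `langchain_aws.BedrockEmbeddings` and gains the two Cohere embedding model IDs. A small sketch of the underlying langchain-aws class the relocated provider wraps; the model ID and region are example values:

```python
# Illustrative use of the langchain-aws embeddings class that the relocated
# BedrockEmbeddingsProvider builds on. Model ID and region are examples;
# AWS credentials come from the standard credential chain.
from langchain_aws import BedrockEmbeddings

embeddings = BedrockEmbeddings(
    model_id="cohere.embed-english-v3",   # one of the newly added Cohere model IDs
    region_name="us-east-1",
)
vector = embeddings.embed_query("Jupyter AI now targets langchain-aws for Bedrock.")
print(len(vector))  # dimensionality of the returned embedding
```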
183 changes: 183 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/aws.py
@@ -0,0 +1,183 @@
import copy
import json
from typing import Any, Coroutine, Dict

from jsonpath_ng import parse
from langchain_aws import BedrockEmbeddings, BedrockLLM, ChatBedrock, SagemakerEndpoint
from langchain_aws.llms.sagemaker_endpoint import LLMContentHandler
from langchain_core.outputs import LLMResult

from ..embedding_providers import BaseEmbeddingsProvider
from ..providers import AwsAuthStrategy, BaseProvider, MultilineTextField, TextField


# See model ID list here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
class BedrockProvider(BaseProvider, BedrockLLM):
id = "bedrock"
name = "Amazon Bedrock"
models = [
"amazon.titan-text-express-v1",
"amazon.titan-text-lite-v1",
"ai21.j2-ultra-v1",
"ai21.j2-mid-v1",
"cohere.command-light-text-v14",
"cohere.command-text-v14",
"cohere.command-r-v1:0",
"cohere.command-r-plus-v1:0",
"meta.llama2-13b-chat-v1",
"meta.llama2-70b-chat-v1",
"meta.llama3-8b-instruct-v1:0",
"meta.llama3-70b-instruct-v1:0",
"meta.llama3-1-8b-instruct-v1:0",
"meta.llama3-1-70b-instruct-v1:0",
"mistral.mistral-7b-instruct-v0:2",
"mistral.mixtral-8x7b-instruct-v0:1",
"mistral.mistral-large-2402-v1:0",
]
model_id_key = "model_id"
pypi_package_deps = ["langchain-aws"]
auth_strategy = AwsAuthStrategy()
fields = [
TextField(
key="credentials_profile_name",
label="AWS profile (optional)",
format="text",
),
TextField(key="region_name", label="Region name (optional)", format="text"),
]

async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
return await self._call_in_executor(*args, **kwargs)


# See model ID list here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
class BedrockChatProvider(BaseProvider, ChatBedrock):
id = "bedrock-chat"
name = "Amazon Bedrock Chat"
models = [
"amazon.titan-text-express-v1",
"amazon.titan-text-lite-v1",
"anthropic.claude-v2",
"anthropic.claude-v2:1",
"anthropic.claude-instant-v1",
"anthropic.claude-3-sonnet-20240229-v1:0",
"anthropic.claude-3-haiku-20240307-v1:0",
"anthropic.claude-3-opus-20240229-v1:0",
"anthropic.claude-3-5-sonnet-20240620-v1:0",
"meta.llama2-13b-chat-v1",
"meta.llama2-70b-chat-v1",
"meta.llama3-8b-instruct-v1:0",
"meta.llama3-70b-instruct-v1:0",
"meta.llama3-1-8b-instruct-v1:0",
"meta.llama3-1-70b-instruct-v1:0",
"mistral.mistral-7b-instruct-v0:2",
"mistral.mixtral-8x7b-instruct-v0:1",
"mistral.mistral-large-2402-v1:0",
]
model_id_key = "model_id"
pypi_package_deps = ["langchain-aws"]
auth_strategy = AwsAuthStrategy()
fields = [
TextField(
key="credentials_profile_name",
label="AWS profile (optional)",
format="text",
),
TextField(key="region_name", label="Region name (optional)", format="text"),
]

async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
return await self._call_in_executor(*args, **kwargs)

async def _agenerate(self, *args, **kwargs) -> Coroutine[Any, Any, LLMResult]:
return await self._generate_in_executor(*args, **kwargs)

@property
def allows_concurrency(self):
return not "anthropic" in self.model_id


# See model ID list here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
class BedrockEmbeddingsProvider(BaseEmbeddingsProvider, BedrockEmbeddings):
id = "bedrock"
name = "Bedrock"
models = [
"amazon.titan-embed-text-v1",
"amazon.titan-embed-text-v2:0",
"cohere.embed-english-v3",
"cohere.embed-multilingual-v3",
]
model_id_key = "model_id"
pypi_package_deps = ["langchain-aws"]
auth_strategy = AwsAuthStrategy()


class JsonContentHandler(LLMContentHandler):
content_type = "application/json"
accepts = "application/json"

def __init__(self, request_schema, response_path):
self.request_schema = json.loads(request_schema)
self.response_path = response_path
self.response_parser = parse(response_path)

def replace_values(self, old_val, new_val, d: Dict[str, Any]):
"""Replaces values of a dictionary recursively."""
for key, val in d.items():
if val == old_val:
d[key] = new_val
if isinstance(val, dict):
self.replace_values(old_val, new_val, val)

return d

def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
request_obj = copy.deepcopy(self.request_schema)
self.replace_values("<prompt>", prompt, request_obj)
request = json.dumps(request_obj).encode("utf-8")
return request

def transform_output(self, output: bytes) -> str:
response_json = json.loads(output.read().decode("utf-8"))
matches = self.response_parser.find(response_json)
return matches[0].value


class SmEndpointProvider(BaseProvider, SagemakerEndpoint):
id = "sagemaker-endpoint"
name = "SageMaker endpoint"
models = ["*"]
model_id_key = "endpoint_name"
model_id_label = "Endpoint name"
# This all needs to be on one line of markdown, for use in a table
help = (
"Specify an endpoint name as the model ID. "
"In addition, you must specify a region name, request schema, and response path. "
"For more information, see the documentation about [SageMaker endpoints deployment](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints-deploy-models.html) "
"and about [using magic commands with SageMaker endpoints](https://jupyter-ai.readthedocs.io/en/latest/users/index.html#using-magic-commands-with-sagemaker-endpoints)."
)

pypi_package_deps = ["langchain-aws"]
auth_strategy = AwsAuthStrategy()
registry = True
fields = [
TextField(key="region_name", label="Region name (required)", format="text"),
MultilineTextField(
key="request_schema", label="Request schema (required)", format="json"
),
TextField(
key="response_path", label="Response path (required)", format="jsonpath"
),
]

def __init__(self, *args, **kwargs):
request_schema = kwargs.pop("request_schema")
response_path = kwargs.pop("response_path")
content_handler = JsonContentHandler(
request_schema=request_schema, response_path=response_path
)

super().__init__(*args, **kwargs, content_handler=content_handler)

async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
return await self._call_in_executor(*args, **kwargs)
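
The `SmEndpointProvider` help text above asks users for an endpoint name, region, request schema, and response path; `JsonContentHandler` is the piece that applies the last two. A self-contained sketch of that round trip, with a made-up schema, response path, and an in-memory stand-in for the SageMaker response body:

```python
# Standalone illustration of how JsonContentHandler (defined above) uses a
# request schema and a JSONPath response path. The schema, path, and fake
# endpoint response are example values, not ones prescribed by Jupyter AI.
import io
import json

from jupyter_ai_magics.partner_providers.aws import JsonContentHandler

handler = JsonContentHandler(
    request_schema=json.dumps({"text_inputs": "<prompt>", "max_new_tokens": 256}),
    response_path="generated_texts[0]",
)

# transform_input: every "<prompt>" value in the schema is replaced by the prompt.
body = handler.transform_input("What is Amazon SageMaker?", model_kwargs={})
print(body)  # b'{"text_inputs": "What is Amazon SageMaker?", "max_new_tokens": 256}'

# transform_output: the endpoint's JSON body is read and the JSONPath pulls out the text.
fake_body = io.BytesIO(json.dumps({"generated_texts": ["A managed ML service."]}).encode("utf-8"))
print(handler.transform_output(fake_body))  # A managed ML service.
```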