Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate to langchain-aws for AWS providers #909

Merged
merged 9 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/source/users/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ Jupyter AI supports the following model providers:
| AI21 | `ai21` | `AI21_API_KEY` | `ai21` |
| Anthropic | `anthropic` | `ANTHROPIC_API_KEY` | `langchain-anthropic` |
| Anthropic (chat) | `anthropic-chat` | `ANTHROPIC_API_KEY` | `langchain-anthropic` |
| Bedrock | `bedrock` | N/A | `boto3` |
| Bedrock (chat) | `bedrock-chat` | N/A | `boto3` |
| Bedrock | `bedrock` | N/A | `langchain-aws` |
| Bedrock (chat) | `bedrock-chat` | N/A | `langchain-aws` |
| Cohere | `cohere` | `COHERE_API_KEY` | `langchain_cohere` |
| ERNIE-Bot | `qianfan` | `QIANFAN_AK`, `QIANFAN_SK` | `qianfan` |
| Gemini | `gemini` | `GOOGLE_API_KEY` | `langchain-google-genai` |
Expand All @@ -169,7 +169,7 @@ Jupyter AI supports the following model providers:
| NVIDIA | `nvidia-chat` | `NVIDIA_API_KEY` | `langchain_nvidia_ai_endpoints` |
| OpenAI | `openai` | `OPENAI_API_KEY` | `langchain-openai` |
| OpenAI (chat) | `openai-chat` | `OPENAI_API_KEY` | `langchain-openai` |
| SageMaker | `sagemaker-endpoint` | N/A | `boto3` |
| SageMaker | `sagemaker-endpoint` | N/A | `langchain-aws` |

The environment variable names shown above are also the names of the settings keys used when setting up the chat interface.
If multiple variables are listed for a provider, **all** must be specified.
Expand Down
4 changes: 0 additions & 4 deletions packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# expose embedding model providers on the package root
from .embedding_providers import (
BaseEmbeddingsProvider,
BedrockEmbeddingsProvider,
GPT4AllEmbeddingsProvider,
HfHubEmbeddingsProvider,
OllamaEmbeddingsProvider,
Expand All @@ -20,13 +19,10 @@
from .providers import (
AI21Provider,
BaseProvider,
BedrockChatProvider,
BedrockProvider,
GPT4AllProvider,
HfHubProvider,
OllamaProvider,
QianfanProvider,
SmEndpointProvider,
TogetherAIProvider,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@

from jupyter_ai_magics.providers import (
AuthStrategy,
AwsAuthStrategy,
EnvAuthStrategy,
Field,
MultiEnvAuthStrategy,
)
from langchain.pydantic_v1 import BaseModel, Extra
from langchain_community.embeddings import (
BedrockEmbeddings,
GPT4AllEmbeddings,
HuggingFaceHubEmbeddings,
OllamaEmbeddings,
Expand Down Expand Up @@ -93,16 +91,6 @@ class HfHubEmbeddingsProvider(BaseEmbeddingsProvider, HuggingFaceHubEmbeddings):
registry = True


# See model ID list here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
class BedrockEmbeddingsProvider(BaseEmbeddingsProvider, BedrockEmbeddings):
id = "bedrock"
name = "Bedrock"
models = ["amazon.titan-embed-text-v1", "amazon.titan-embed-text-v2:0"]
model_id_key = "model_id"
pypi_package_deps = ["boto3"]
auth_strategy = AwsAuthStrategy()


class GPT4AllEmbeddingsProvider(BaseEmbeddingsProvider, GPT4AllEmbeddings):
def __init__(self, **kwargs):
from gpt4all import GPT4All
Expand Down
183 changes: 183 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/aws.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
import copy
import json
from typing import Any, Coroutine, Dict

from jsonpath_ng import parse
from langchain_aws import BedrockEmbeddings, BedrockLLM, ChatBedrock, SagemakerEndpoint
from langchain_aws.llms.sagemaker_endpoint import LLMContentHandler
from langchain_core.outputs import LLMResult

from ..embedding_providers import BaseEmbeddingsProvider
from ..providers import AwsAuthStrategy, BaseProvider, MultilineTextField, TextField


# See model ID list here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
# Model ID reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
class BedrockProvider(BaseProvider, BedrockLLM):
    """Provider for text-completion models hosted on Amazon Bedrock."""

    id = "bedrock"
    name = "Amazon Bedrock"
    model_id_key = "model_id"
    # Text-completion model IDs offered through this provider.
    models = [
        "amazon.titan-text-express-v1",
        "amazon.titan-text-lite-v1",
        "ai21.j2-ultra-v1",
        "ai21.j2-mid-v1",
        "cohere.command-light-text-v14",
        "cohere.command-text-v14",
        "cohere.command-r-v1:0",
        "cohere.command-r-plus-v1:0",
        "meta.llama2-13b-chat-v1",
        "meta.llama2-70b-chat-v1",
        "meta.llama3-8b-instruct-v1:0",
        "meta.llama3-70b-instruct-v1:0",
        "meta.llama3-1-8b-instruct-v1:0",
        "meta.llama3-1-70b-instruct-v1:0",
        "mistral.mistral-7b-instruct-v0:2",
        "mistral.mixtral-8x7b-instruct-v0:1",
        "mistral.mistral-large-2402-v1:0",
    ]
    pypi_package_deps = ["langchain-aws"]
    auth_strategy = AwsAuthStrategy()
    # Optional per-user overrides for the underlying AWS client configuration.
    fields = [
        TextField(
            key="credentials_profile_name",
            label="AWS profile (optional)",
            format="text",
        ),
        TextField(key="region_name", label="Region name (optional)", format="text"),
    ]

    async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
        """Run the synchronous call in an executor so the event loop is not blocked."""
        return await self._call_in_executor(*args, **kwargs)


# See model ID list here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
# See model ID list here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
class BedrockChatProvider(BaseProvider, ChatBedrock):
    """Provider for chat models hosted on Amazon Bedrock."""

    id = "bedrock-chat"
    name = "Amazon Bedrock Chat"
    # Chat model IDs offered through this provider.
    models = [
        "amazon.titan-text-express-v1",
        "amazon.titan-text-lite-v1",
        "anthropic.claude-v2",
        "anthropic.claude-v2:1",
        "anthropic.claude-instant-v1",
        "anthropic.claude-3-sonnet-20240229-v1:0",
        "anthropic.claude-3-haiku-20240307-v1:0",
        "anthropic.claude-3-opus-20240229-v1:0",
        "anthropic.claude-3-5-sonnet-20240620-v1:0",
        "meta.llama2-13b-chat-v1",
        "meta.llama2-70b-chat-v1",
        "meta.llama3-8b-instruct-v1:0",
        "meta.llama3-70b-instruct-v1:0",
        "meta.llama3-1-8b-instruct-v1:0",
        "meta.llama3-1-70b-instruct-v1:0",
        "mistral.mistral-7b-instruct-v0:2",
        "mistral.mixtral-8x7b-instruct-v0:1",
        "mistral.mistral-large-2402-v1:0",
    ]
    model_id_key = "model_id"
    pypi_package_deps = ["langchain-aws"]
    auth_strategy = AwsAuthStrategy()
    # Optional per-user overrides for the underlying AWS client configuration.
    fields = [
        TextField(
            key="credentials_profile_name",
            label="AWS profile (optional)",
            format="text",
        ),
        TextField(key="region_name", label="Region name (optional)", format="text"),
    ]

    async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
        """Run the synchronous call in an executor so the event loop is not blocked."""
        return await self._call_in_executor(*args, **kwargs)

    async def _agenerate(self, *args, **kwargs) -> Coroutine[Any, Any, LLMResult]:
        """Run the synchronous generate in an executor so the event loop is not blocked."""
        return await self._generate_in_executor(*args, **kwargs)

    @property
    def allows_concurrency(self) -> bool:
        # Concurrent requests are disabled for Anthropic model IDs —
        # presumably to stay within Bedrock's rate limits; confirm if revisiting.
        return "anthropic" not in self.model_id


# See model ID list here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
# Model ID reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
class BedrockEmbeddingsProvider(BaseEmbeddingsProvider, BedrockEmbeddings):
    """Provider for embedding models hosted on Amazon Bedrock."""

    id = "bedrock"
    name = "Bedrock"
    model_id_key = "model_id"
    # Embedding model IDs offered through this provider.
    models = [
        "amazon.titan-embed-text-v1",
        "amazon.titan-embed-text-v2:0",
        "cohere.embed-english-v3",
        "cohere.embed-multilingual-v3",
    ]
    pypi_package_deps = ["langchain-aws"]
    auth_strategy = AwsAuthStrategy()


class JsonContentHandler(LLMContentHandler):
    """Content handler that renders prompts into a user-supplied JSON request
    schema and extracts the model reply with a JSONPath expression."""

    content_type = "application/json"
    accepts = "application/json"

    def __init__(self, request_schema, response_path):
        # Parse the schema string once; each request works on a deep copy of it.
        self.request_schema = json.loads(request_schema)
        self.response_path = response_path
        # Pre-compile the JSONPath expression used to locate the reply.
        self.response_parser = parse(response_path)

    def replace_values(self, old_val, new_val, d: Dict[str, Any]):
        """Replaces values of a dictionary recursively."""
        for key in d:
            current = d[key]
            if current == old_val:
                d[key] = new_val
            # Recurse into nested objects to replace matches at any depth.
            if isinstance(current, dict):
                self.replace_values(old_val, new_val, current)

        return d

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        # Copy the template so the stored schema is never mutated.
        body = copy.deepcopy(self.request_schema)
        self.replace_values("<prompt>", prompt, body)
        return json.dumps(body).encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        # NOTE(review): despite the `bytes` annotation, `output` is consumed
        # like a file object via `.read()` — confirm the caller passes a stream.
        payload = json.loads(output.read().decode("utf-8"))
        matches = self.response_parser.find(payload)
        return matches[0].value


class SmEndpointProvider(BaseProvider, SagemakerEndpoint):
    """Provider for models deployed behind a SageMaker inference endpoint."""

    id = "sagemaker-endpoint"
    name = "SageMaker endpoint"
    # Any endpoint name is accepted as the model ID.
    models = ["*"]
    model_id_key = "endpoint_name"
    model_id_label = "Endpoint name"
    # Kept as a single markdown line so it renders correctly inside a table.
    help = (
        "Specify an endpoint name as the model ID. "
        "In addition, you must specify a region name, request schema, and response path. "
        "For more information, see the documentation about [SageMaker endpoints deployment](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints-deploy-models.html) "
        "and about [using magic commands with SageMaker endpoints](https://jupyter-ai.readthedocs.io/en/latest/users/index.html#using-magic-commands-with-sagemaker-endpoints)."
    )

    pypi_package_deps = ["langchain-aws"]
    auth_strategy = AwsAuthStrategy()
    registry = True
    fields = [
        TextField(key="region_name", label="Region name (required)", format="text"),
        MultilineTextField(
            key="request_schema", label="Request schema (required)", format="json"
        ),
        TextField(
            key="response_path", label="Response path (required)", format="jsonpath"
        ),
    ]

    def __init__(self, *args, **kwargs):
        # Consume the two custom fields and turn them into a content handler
        # before delegating the remaining kwargs to the parent constructor.
        schema = kwargs.pop("request_schema")
        path = kwargs.pop("response_path")
        handler = JsonContentHandler(request_schema=schema, response_path=path)

        super().__init__(*args, **kwargs, content_handler=handler)

    async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
        """Run the synchronous call in an executor so the event loop is not blocked."""
        return await self._call_in_executor(*args, **kwargs)
Loading
Loading