Skip to content

Commit

Permalink
migrate to langchain-aws
Browse files Browse the repository at this point in the history
  • Loading branch information
dlqqq committed Jul 24, 2024
1 parent 28d5009 commit c98ff5e
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 174 deletions.
4 changes: 0 additions & 4 deletions packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# expose embedding model providers on the package root
from .embedding_providers import (
BaseEmbeddingsProvider,
BedrockEmbeddingsProvider,
GPT4AllEmbeddingsProvider,
HfHubEmbeddingsProvider,
OllamaEmbeddingsProvider,
Expand All @@ -20,13 +19,10 @@
from .providers import (
AI21Provider,
BaseProvider,
BedrockChatProvider,
BedrockProvider,
GPT4AllProvider,
HfHubProvider,
OllamaProvider,
QianfanProvider,
SmEndpointProvider,
TogetherAIProvider,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@

from jupyter_ai_magics.providers import (
AuthStrategy,
AwsAuthStrategy,
EnvAuthStrategy,
Field,
MultiEnvAuthStrategy,
)
from langchain.pydantic_v1 import BaseModel, Extra
from langchain_community.embeddings import (
BedrockEmbeddings,
GPT4AllEmbeddings,
HuggingFaceHubEmbeddings,
OllamaEmbeddings,
Expand Down Expand Up @@ -93,16 +91,6 @@ class HfHubEmbeddingsProvider(BaseEmbeddingsProvider, HuggingFaceHubEmbeddings):
registry = True


# See model ID list here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
class BedrockEmbeddingsProvider(BaseEmbeddingsProvider, BedrockEmbeddings):
id = "bedrock"
name = "Bedrock"
models = ["amazon.titan-embed-text-v1", "amazon.titan-embed-text-v2:0"]
model_id_key = "model_id"
pypi_package_deps = ["boto3"]
auth_strategy = AwsAuthStrategy()


class GPT4AllEmbeddingsProvider(BaseEmbeddingsProvider, GPT4AllEmbeddings):
def __init__(self, **kwargs):
from gpt4all import GPT4All
Expand Down
167 changes: 167 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/aws.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
from typing import Any, Coroutine, Dict
import copy
import json
from jsonpath_ng import parse

from langchain_aws import Bedrock, ChatBedrock, SagemakerEndpoint, BedrockEmbeddings
from langchain_aws.llms.sagemaker_endpoint import LLMContentHandler
from langchain_core.outputs import LLMResult

from ..providers import BaseProvider, AwsAuthStrategy, TextField, MultilineTextField
from ..embedding_providers import BaseEmbeddingsProvider

# See model ID list here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
class BedrockProvider(BaseProvider, Bedrock):
id = "bedrock"
name = "Amazon Bedrock"
models = [
"amazon.titan-text-express-v1",
"amazon.titan-text-lite-v1",
"ai21.j2-ultra-v1",
"ai21.j2-mid-v1",
"cohere.command-light-text-v14",
"cohere.command-text-v14",
"cohere.command-r-v1:0",
"cohere.command-r-plus-v1:0",
"meta.llama2-13b-chat-v1",
"meta.llama2-70b-chat-v1",
"meta.llama3-8b-instruct-v1:0",
"meta.llama3-70b-instruct-v1:0",
"meta.llama3-1-8b-instruct-v1:0",
"meta.llama3-1-70b-instruct-v1:0",
"mistral.mistral-7b-instruct-v0:2",
"mistral.mixtral-8x7b-instruct-v0:1",
"mistral.mistral-large-2402-v1:0",
]
model_id_key = "model_id"
pypi_package_deps = ["langchain-aws"]
auth_strategy = AwsAuthStrategy()
fields = [
TextField(
key="credentials_profile_name",
label="AWS profile (optional)",
format="text",
),
TextField(key="region_name", label="Region name (optional)", format="text"),
]

async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
return await self._call_in_executor(*args, **kwargs)


# See model ID list here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
class BedrockChatProvider(BaseProvider, ChatBedrock):
id = "bedrock-chat"
name = "Amazon Bedrock Chat"
models = [
"anthropic.claude-v2",
"anthropic.claude-v2:1",
"anthropic.claude-instant-v1",
"anthropic.claude-3-sonnet-20240229-v1:0",
"anthropic.claude-3-haiku-20240307-v1:0",
"anthropic.claude-3-opus-20240229-v1:0",
"anthropic.claude-3-5-sonnet-20240620-v1:0",
]
model_id_key = "model_id"
pypi_package_deps = ["langchain-aws"]
auth_strategy = AwsAuthStrategy()
fields = [
TextField(
key="credentials_profile_name",
label="AWS profile (optional)",
format="text",
),
TextField(key="region_name", label="Region name (optional)", format="text"),
]

async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
return await self._call_in_executor(*args, **kwargs)

async def _agenerate(self, *args, **kwargs) -> Coroutine[Any, Any, LLMResult]:
return await self._generate_in_executor(*args, **kwargs)

@property
def allows_concurrency(self):
return not "anthropic" in self.model_id


# See model ID list here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
class BedrockEmbeddingsProvider(BaseEmbeddingsProvider, BedrockEmbeddings):
id = "bedrock"
name = "Bedrock"
models = ["amazon.titan-embed-text-v1", "amazon.titan-embed-text-v2:0"]
model_id_key = "model_id"
pypi_package_deps = ["langchain-aws"]
auth_strategy = AwsAuthStrategy()


class JsonContentHandler(LLMContentHandler):
content_type = "application/json"
accepts = "application/json"

def __init__(self, request_schema, response_path):
self.request_schema = json.loads(request_schema)
self.response_path = response_path
self.response_parser = parse(response_path)

def replace_values(self, old_val, new_val, d: Dict[str, Any]):
"""Replaces values of a dictionary recursively."""
for key, val in d.items():
if val == old_val:
d[key] = new_val
if isinstance(val, dict):
self.replace_values(old_val, new_val, val)

return d

def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
request_obj = copy.deepcopy(self.request_schema)
self.replace_values("<prompt>", prompt, request_obj)
request = json.dumps(request_obj).encode("utf-8")
return request

def transform_output(self, output: bytes) -> str:
response_json = json.loads(output.read().decode("utf-8"))
matches = self.response_parser.find(response_json)
return matches[0].value


class SmEndpointProvider(BaseProvider, SagemakerEndpoint):
id = "sagemaker-endpoint"
name = "SageMaker endpoint"
models = ["*"]
model_id_key = "endpoint_name"
model_id_label = "Endpoint name"
# This all needs to be on one line of markdown, for use in a table
help = (
"Specify an endpoint name as the model ID. "
"In addition, you must specify a region name, request schema, and response path. "
"For more information, see the documentation about [SageMaker endpoints deployment](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints-deployment.html) "
"and about [using magic commands with SageMaker endpoints](https://jupyter-ai.readthedocs.io/en/latest/users/index.html#using-magic-commands-with-sagemaker-endpoints)."
)

pypi_package_deps = ["langchain-aws"]
auth_strategy = AwsAuthStrategy()
registry = True
fields = [
TextField(key="region_name", label="Region name (required)", format="text"),
MultilineTextField(
key="request_schema", label="Request schema (required)", format="json"
),
TextField(
key="response_path", label="Response path (required)", format="jsonpath"
),
]

def __init__(self, *args, **kwargs):
request_schema = kwargs.pop("request_schema")
response_path = kwargs.pop("response_path")
content_handler = JsonContentHandler(
request_schema=request_schema, response_path=response_path
)

super().__init__(*args, **kwargs, content_handler=content_handler)

async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
return await self._call_in_executor(*args, **kwargs)

Loading

0 comments on commit c98ff5e

Please sign in to comment.