Skip to content

Commit

Permalink
revert blackify
Browse files Browse the repository at this point in the history
  • Loading branch information
agnieszka-m committed Aug 5, 2024
1 parent d1a2ff0 commit e43d3e6
Show file tree
Hide file tree
Showing 219 changed files with 2,139 additions and 7,606 deletions.
8 changes: 2 additions & 6 deletions integrations/amazon_bedrock/examples/chatgenerator_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@

from haystack.dataclasses import ChatMessage

from haystack_integrations.components.generators.amazon_bedrock import (
AmazonBedrockChatGenerator,
)
from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator

generator = AmazonBedrockChatGenerator(
model="anthropic.claude-3-haiku-20240307-v1:0",
Expand All @@ -31,9 +29,7 @@
# which allows for more portability of code across generators
messages = [
ChatMessage.from_system(system_prompt),
ChatMessage.from_user(
"Which service should I use to train custom Machine Learning models?"
),
ChatMessage.from_user("Which service should I use to train custom Machine Learning models?"),
]

results = generator.run(messages)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@
AmazonBedrockDocumentEmbedder,
AmazonBedrockTextEmbedder,
)
from haystack_integrations.components.generators.amazon_bedrock import (
AmazonBedrockGenerator,
)
from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator

generator_model_name = "amazon.titan-text-lite-v1"
embedder_model_name = "amazon.titan-embed-text-v1"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,7 @@ def get_aws_session(
profile_name=aws_profile_name,
)
except BotoCoreError as e:
provided_aws_config = {
k: v for k, v in kwargs.items() if k in AWS_CONFIGURATION_KEYS
}
provided_aws_config = {k: v for k, v in kwargs.items() if k in AWS_CONFIGURATION_KEYS}
msg = f"Failed to initialize the session with provided AWS credentials {provided_aws_config}"
raise AWSConfigurationError(msg) from e

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,21 +62,13 @@ def __init__(
"cohere.embed-multilingual-v3",
"amazon.titan-embed-text-v2:0",
],
aws_access_key_id: Optional[Secret] = Secret.from_env_var(
"AWS_ACCESS_KEY_ID", strict=False
), # noqa: B008
aws_access_key_id: Optional[Secret] = Secret.from_env_var("AWS_ACCESS_KEY_ID", strict=False), # noqa: B008
aws_secret_access_key: Optional[Secret] = Secret.from_env_var( # noqa: B008
"AWS_SECRET_ACCESS_KEY", strict=False
),
aws_session_token: Optional[Secret] = Secret.from_env_var(
"AWS_SESSION_TOKEN", strict=False
), # noqa: B008
aws_region_name: Optional[Secret] = Secret.from_env_var(
"AWS_DEFAULT_REGION", strict=False
), # noqa: B008
aws_profile_name: Optional[Secret] = Secret.from_env_var(
"AWS_PROFILE", strict=False
), # noqa: B008
aws_session_token: Optional[Secret] = Secret.from_env_var("AWS_SESSION_TOKEN", strict=False), # noqa: B008
aws_region_name: Optional[Secret] = Secret.from_env_var("AWS_DEFAULT_REGION", strict=False), # noqa: B008
aws_profile_name: Optional[Secret] = Secret.from_env_var("AWS_PROFILE", strict=False), # noqa: B008
batch_size: int = 32,
progress_bar: bool = True,
meta_fields_to_embed: Optional[List[str]] = None,
Expand Down Expand Up @@ -113,9 +105,8 @@ def __init__(
"""

if not model or model not in SUPPORTED_EMBEDDING_MODELS:
msg = (
"Please provide a valid model from the list of supported models: "
+ ", ".join(SUPPORTED_EMBEDDING_MODELS)
msg = "Please provide a valid model from the list of supported models: " + ", ".join(
SUPPORTED_EMBEDDING_MODELS
)
raise ValueError(msg)

Expand Down Expand Up @@ -156,15 +147,9 @@ def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
"""
texts_to_embed = []
for doc in documents:
meta_values_to_embed = [
str(doc.meta[key])
for key in self.meta_fields_to_embed
if doc.meta.get(key)
]

text_to_embed = self.embedding_separator.join(
[*meta_values_to_embed, doc.content or ""]
)
meta_values_to_embed = [str(doc.meta[key]) for key in self.meta_fields_to_embed if doc.meta.get(key)]

text_to_embed = self.embedding_separator.join([*meta_values_to_embed, doc.content or ""])

texts_to_embed.append(text_to_embed)
return texts_to_embed
Expand All @@ -178,28 +163,21 @@ def _embed_cohere(self, documents: List[Document]) -> List[Document]:
texts_to_embed = self._prepare_texts_to_embed(documents=documents)

cohere_body = {
"input_type": self.kwargs.get(
"input_type", "search_document"
), # mandatory parameter for Cohere models
"input_type": self.kwargs.get("input_type", "search_document"), # mandatory parameter for Cohere models
}
if truncate := self.kwargs.get("truncate"):
cohere_body["truncate"] = truncate # optional parameter for Cohere models

all_embeddings = []
for i in tqdm(
range(0, len(texts_to_embed), self.batch_size),
disable=not self.progress_bar,
desc="Creating embeddings",
range(0, len(texts_to_embed), self.batch_size), disable=not self.progress_bar, desc="Creating embeddings"
):
batch = texts_to_embed[i : i + self.batch_size]
body = {"texts": batch, **cohere_body}

try:
response = self._client.invoke_model(
body=json.dumps(body),
modelId=self.model,
accept="*/*",
contentType="application/json",
body=json.dumps(body), modelId=self.model, accept="*/*", contentType="application/json"
)
except ClientError as exception:
msg = (
Expand All @@ -226,16 +204,11 @@ def _embed_titan(self, documents: List[Document]) -> List[Document]:
texts_to_embed = self._prepare_texts_to_embed(documents=documents)

all_embeddings = []
for text in tqdm(
texts_to_embed, disable=not self.progress_bar, desc="Creating embeddings"
):
for text in tqdm(texts_to_embed, disable=not self.progress_bar, desc="Creating embeddings"):
body = {"inputText": text}
try:
response = self._client.invoke_model(
body=json.dumps(body),
modelId=self.model,
accept="*/*",
contentType="application/json",
body=json.dumps(body), modelId=self.model, accept="*/*", contentType="application/json"
)
except ClientError as exception:
msg = (
Expand Down Expand Up @@ -263,11 +236,7 @@ def run(self, documents: List[Document]):
- `documents`: The `Document`s with the `embedding` field populated.
:raises AmazonBedrockInferenceError: If the inference fails.
"""
if (
not isinstance(documents, list)
or documents
and not isinstance(documents[0], Document)
):
if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
msg = (
"AmazonBedrockDocumentEmbedder expects a list of Documents as input."
"In case you want to embed a string, please use the AmazonBedrockTextEmbedder."
Expand All @@ -290,21 +259,11 @@ def to_dict(self) -> Dict[str, Any]:
"""
return default_to_dict(
self,
aws_access_key_id=self.aws_access_key_id.to_dict()
if self.aws_access_key_id
else None,
aws_secret_access_key=self.aws_secret_access_key.to_dict()
if self.aws_secret_access_key
else None,
aws_session_token=self.aws_session_token.to_dict()
if self.aws_session_token
else None,
aws_region_name=self.aws_region_name.to_dict()
if self.aws_region_name
else None,
aws_profile_name=self.aws_profile_name.to_dict()
if self.aws_profile_name
else None,
aws_access_key_id=self.aws_access_key_id.to_dict() if self.aws_access_key_id else None,
aws_secret_access_key=self.aws_secret_access_key.to_dict() if self.aws_secret_access_key else None,
aws_session_token=self.aws_session_token.to_dict() if self.aws_session_token else None,
aws_region_name=self.aws_region_name.to_dict() if self.aws_region_name else None,
aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None,
model=self.model,
batch_size=self.batch_size,
progress_bar=self.progress_bar,
Expand All @@ -325,12 +284,6 @@ def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockDocumentEmbedder":
"""
deserialize_secrets_inplace(
data["init_parameters"],
[
"aws_access_key_id",
"aws_secret_access_key",
"aws_session_token",
"aws_region_name",
"aws_profile_name",
],
["aws_access_key_id", "aws_secret_access_key", "aws_session_token", "aws_region_name", "aws_profile_name"],
)
return default_from_dict(cls, data)
Original file line number Diff line number Diff line change
Expand Up @@ -55,21 +55,13 @@ def __init__(
"cohere.embed-multilingual-v3",
"amazon.titan-embed-text-v2:0",
],
aws_access_key_id: Optional[Secret] = Secret.from_env_var(
"AWS_ACCESS_KEY_ID", strict=False
), # noqa: B008
aws_access_key_id: Optional[Secret] = Secret.from_env_var("AWS_ACCESS_KEY_ID", strict=False), # noqa: B008
aws_secret_access_key: Optional[Secret] = Secret.from_env_var( # noqa: B008
"AWS_SECRET_ACCESS_KEY", strict=False
),
aws_session_token: Optional[Secret] = Secret.from_env_var(
"AWS_SESSION_TOKEN", strict=False
), # noqa: B008
aws_region_name: Optional[Secret] = Secret.from_env_var(
"AWS_DEFAULT_REGION", strict=False
), # noqa: B008
aws_profile_name: Optional[Secret] = Secret.from_env_var(
"AWS_PROFILE", strict=False
), # noqa: B008
aws_session_token: Optional[Secret] = Secret.from_env_var("AWS_SESSION_TOKEN", strict=False), # noqa: B008
aws_region_name: Optional[Secret] = Secret.from_env_var("AWS_DEFAULT_REGION", strict=False), # noqa: B008
aws_profile_name: Optional[Secret] = Secret.from_env_var("AWS_PROFILE", strict=False), # noqa: B008
**kwargs,
):
"""
Expand All @@ -95,9 +87,8 @@ def __init__(
:raises AmazonBedrockConfigurationError: If the AWS environment is not configured correctly.
"""
if not model or model not in SUPPORTED_EMBEDDING_MODELS:
msg = (
"Please provide a valid model from the list of supported models: "
+ ", ".join(SUPPORTED_EMBEDDING_MODELS)
msg = "Please provide a valid model from the list of supported models: " + ", ".join(
SUPPORTED_EMBEDDING_MODELS
)
raise ValueError(msg)

Expand Down Expand Up @@ -148,9 +139,7 @@ def run(self, text: str):
if "cohere" in self.model:
body = {
"texts": [text],
"input_type": self.kwargs.get(
"input_type", "search_query"
), # mandatory parameter for Cohere models
"input_type": self.kwargs.get("input_type", "search_query"), # mandatory parameter for Cohere models
}
if truncate := self.kwargs.get("truncate"):
body["truncate"] = truncate # optional parameter for Cohere models
Expand All @@ -162,10 +151,7 @@ def run(self, text: str):

try:
response = self._client.invoke_model(
body=json.dumps(body),
modelId=self.model,
accept="*/*",
contentType="application/json",
body=json.dumps(body), modelId=self.model, accept="*/*", contentType="application/json"
)
except ClientError as exception:
msg = (
Expand Down Expand Up @@ -193,21 +179,11 @@ def to_dict(self) -> Dict[str, Any]:
"""
return default_to_dict(
self,
aws_access_key_id=self.aws_access_key_id.to_dict()
if self.aws_access_key_id
else None,
aws_secret_access_key=self.aws_secret_access_key.to_dict()
if self.aws_secret_access_key
else None,
aws_session_token=self.aws_session_token.to_dict()
if self.aws_session_token
else None,
aws_region_name=self.aws_region_name.to_dict()
if self.aws_region_name
else None,
aws_profile_name=self.aws_profile_name.to_dict()
if self.aws_profile_name
else None,
aws_access_key_id=self.aws_access_key_id.to_dict() if self.aws_access_key_id else None,
aws_secret_access_key=self.aws_secret_access_key.to_dict() if self.aws_secret_access_key else None,
aws_session_token=self.aws_session_token.to_dict() if self.aws_session_token else None,
aws_region_name=self.aws_region_name.to_dict() if self.aws_region_name else None,
aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None,
model=self.model,
**self.kwargs,
)
Expand All @@ -224,12 +200,6 @@ def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockTextEmbedder":
"""
deserialize_secrets_inplace(
data["init_parameters"],
[
"aws_access_key_id",
"aws_secret_access_key",
"aws_session_token",
"aws_region_name",
"aws_profile_name",
],
["aws_access_key_id", "aws_secret_access_key", "aws_session_token", "aws_region_name", "aws_profile_name"],
)
return default_from_dict(cls, data)
Loading

0 comments on commit e43d3e6

Please sign in to comment.