Skip to content

Commit

Permalink
docs: review integrations sagemaker (#544)
Browse files Browse the repository at this point in the history
* refactor: remove reimplementing exceptions

* docs: review docs

* style: reformat

* style: shorten line
  • Loading branch information
wochinge authored Mar 6, 2024
1 parent 6d34079 commit 2c6b218
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 48 deletions.
Original file line number Diff line number Diff line change
@@ -1,46 +1,16 @@
from typing import Optional


class SagemakerError(Exception):
    """
    Error generated by the Amazon Sagemaker integration.

    Parent class for all exceptions raised by the Sagemaker component.

    :param message: Optional human-readable description of the error. It is stored on
        the instance as `message` and is what `str()` and `repr()` return.
    """

    def __init__(
        self,
        message: Optional[str] = None,
    ):
        # Forward the message to Exception so that `args` carries it too.
        if message is None:
            super().__init__()
        else:
            super().__init__(message)
        # Always define the attribute so that __str__ never has to fall back to
        # __getattr__ (which would fail when there is no __cause__).
        self.message: Optional[str] = message

    def __getattr__(self, attr):
        # Only invoked when normal attribute lookup fails: delegate to the chained
        # exception so callers can read its attributes through this wrapper.
        # If self.__cause__ is None, it will raise the expected AttributeError
        return getattr(self.__cause__, attr)

    def __str__(self):
        # __str__ must return a string even when no message was supplied.
        return self.message or ""

    def __repr__(self):
        return str(self)


class AWSConfigurationError(SagemakerError):
    """
    Signals that AWS is not configured correctly.

    The constructor is inherited from `SagemakerError` and accepts an optional `message`.
    """


class SagemakerNotReadyError(SagemakerError):
    """
    Signals that the SageMaker model is not ready to accept requests.

    The constructor is inherited from `SagemakerError` and accepts an optional `message`.
    """


class SagemakerInferenceError(SagemakerError):
    """
    Signals a problem that occurred during SageMaker inference.

    The constructor is inherited from `SagemakerError` and accepts an optional `message`.
    """
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,17 @@ class SagemakerGenerator:
[SageMaker JumpStart foundation models documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/jumpstart-foundation-models-use.html).
Usage example:
Make sure your AWS credentials are set up correctly. You can use environment variables or a shared credentials file.
Then you can use the generator as follows:
```python
# Make sure your AWS credentials are set up correctly. You can use environment variables or a shared credentials
# file. Then you can use the generator as follows:
from haystack_integrations.components.generators.amazon_sagemaker import SagemakerGenerator
generator = SagemakerGenerator(model="jumpstart-dft-hf-llm-falcon-7b-bf16")
response = generator.run("What's Natural Language Processing? Be brief.")
print(response)
```
```
>> {'replies': ['Natural Language Processing (NLP) is a branch of artificial intelligence that focuses on
>> the interaction between computers and human language. It involves enabling computers to understand, interpret,
>> and respond to natural human language in a way that is both meaningful and useful.'], 'meta': [{}]}
>>> {'replies': ['Natural Language Processing (NLP) is a branch of artificial intelligence that focuses on
>>> the interaction between computers and human language. It involves enabling computers to understand, interpret,
>>> and respond to natural human language in a way that is both meaningful and useful.'], 'meta': [{}]}
```
"""

Expand Down Expand Up @@ -73,7 +71,6 @@ def __init__(
:param model: The name for SageMaker Model Endpoint.
:param aws_custom_attributes: Custom attributes to be passed to SageMaker, for example `{"accept_eula": True}`
in case of Llama-2 models.
:param generation_kwargs: Additional keyword arguments for text generation. For a list of supported parameters
see your model's documentation page, for example here for HuggingFace models:
https://huggingface.co/blog/sagemaker-huggingface-llm#4-run-inference-and-chat-with-our-model
Expand Down Expand Up @@ -121,15 +118,15 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Returns data that is sent to Posthog for usage analytics.
:returns: a dictionary with following keys:
- model: The name of the model.
:returns: A dictionary with the following keys:
- `model`: The name of the model.
"""
return {"model": self.model}

def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.
:returns:
Dictionary with serialized data.
"""
Expand All @@ -149,10 +146,11 @@ def to_dict(self) -> Dict[str, Any]:
def from_dict(cls, data) -> "SagemakerGenerator":
"""
Deserializes the component from a dictionary.
:param data:
Dictionary to deserialize from.
:returns:
Deserialized component.
Deserialized component.
"""
deserialize_secrets_inplace(
data["init_parameters"],
Expand All @@ -170,6 +168,7 @@ def _get_aws_session(
):
"""
Creates an AWS Session with the given parameters.
Checks if the provided AWS credentials are valid and can be used to connect to AWS.
:param aws_access_key_id: AWS access key ID.
Expand Down Expand Up @@ -200,8 +199,10 @@ def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
:param prompt: The string prompt to use for text generation.
:param generation_kwargs: Additional keyword arguments for text generation. These parameters will
potentially override the parameters passed in the `__init__` method.
potentially override the parameters passed in the `__init__` method.
:raises ValueError: If the model response type is not a list of dictionaries or a single dictionary.
:raises SagemakerNotReadyError: If the SageMaker model is not ready to accept requests.
:raises SagemakerInferenceError: If the SageMaker Inference returns an error.
:returns: A dictionary with the following keys:
- `replies`: A list of strings containing the generated responses
- `meta`: A list of dictionaries containing the metadata for each response.
Expand Down Expand Up @@ -249,5 +250,5 @@ def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
msg = f"Sagemaker model not ready: {res.text}"
raise SagemakerNotReadyError(msg) from err

msg = f"SageMaker Inference returned an error. Status code: {res.status_code} Response body: {res.text}"
raise SagemakerInferenceError(msg, status_code=res.status_code) from err
msg = f"SageMaker Inference returned an error. Status code: {res.status_code}. Response body: {res.text}"
raise SagemakerInferenceError(msg) from err

0 comments on commit 2c6b218

Please sign in to comment.