Skip to content

Commit

Permalink
Merge branch 'main' into add-jina-ranker
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM authored Mar 6, 2024
2 parents 5e9e24e + 67decea commit d6a302f
Show file tree
Hide file tree
Showing 46 changed files with 1,156 additions and 210 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,4 @@ dmypy.json

# Docs generation artifacts
_readme_*.md
+ .idea
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def get_aws_session(
:param kwargs: The kwargs passed down to the service client. Supported kwargs depend on the model chosen.
See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html.
:raises AWSConfigurationError: If the provided AWS credentials are invalid.
- :return: The created AWS session.
+ :returns: The created AWS session.
"""
try:
return boto3.Session(
Expand All @@ -54,7 +54,7 @@ def aws_configured(**kwargs) -> bool:
"""
Checks whether AWS configuration is provided.
:param kwargs: The kwargs passed down to the generator.
- :return: True if AWS configuration is provided, False otherwise.
+ :returns: True if AWS configuration is provided, False otherwise.
"""
aws_config_provided = any(key in kwargs for key in AWS_CONFIGURATION_KEYS)
return aws_config_provided
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def run(self, documents: List[Document]):
def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
- :return: The serialized component as a dictionary.
+ :returns: The serialized component as a dictionary.
"""
return default_to_dict(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def run(self, text: str):
def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
- :return: The serialized component as a dictionary.
+ :returns: The serialized component as a dictionary.
"""
return default_to_dict(
self,
Expand All @@ -172,7 +172,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockTextEmbedder":
"""
Deserialize this component from a dictionary.
:param data: The dictionary representation of this component.
- :return: The deserialized component instance.
+ :returns: The deserialized component instance.
"""
deserialize_secrets_inplace(
data["init_parameters"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
:param prompt: The prompt to be sent to the model.
:param inference_kwargs: Additional keyword arguments passed to the handler.
- :return: A dictionary containing the body for the request.
+ :returns: A dictionary containing the body for the request.
"""

def get_responses(self, response_body: Dict[str, Any]) -> List[str]:
"""
Extracts the responses from the Amazon Bedrock response.
:param response_body: The response body from the Amazon Bedrock request.
- :return: A list of responses.
+ :returns: A list of responses.
"""
completions = self._extract_completions_from_response(response_body)
responses = [completion.lstrip() for completion in completions]
Expand All @@ -45,7 +45,7 @@ def get_stream_responses(self, stream, stream_handler: TokenStreamingHandler) ->
:param stream: The streaming response from the Amazon Bedrock request.
:param stream_handler: The handler for the streaming response.
- :return: A list of string responses.
+ :returns: A list of string responses.
"""
tokens: List[str] = []
for event in stream:
Expand All @@ -64,7 +64,7 @@ def _get_params(self, inference_kwargs: Dict[str, Any], default_params: Dict[str
Includes param if it's in kwargs or its default is not None (i.e. it is actually defined).
:param inference_kwargs: The inference kwargs.
:param default_params: The default params.
- :return: A dictionary containing the merged params.
+ :returns: A dictionary containing the merged params.
"""
kwargs = self.model_kwargs.copy()
kwargs.update(inference_kwargs)
Expand All @@ -80,7 +80,7 @@ def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> L
Extracts the responses from the Amazon Bedrock response.
:param response_body: The response body from the Amazon Bedrock request.
- :return: A list of string responses.
+ :returns: A list of string responses.
"""

@abstractmethod
Expand All @@ -89,7 +89,7 @@ def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
Extracts the token from a streaming chunk.
:param chunk: The streaming chunk.
- :return: A string token.
+ :returns: A string token.
"""


Expand Down Expand Up @@ -121,7 +121,7 @@ def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> L
Extracts the responses from the Amazon Bedrock response.
:param response_body: The response body from the Amazon Bedrock request.
- :return: A list of string responses.
+ :returns: A list of string responses.
"""
return [response_body["completion"]]

Expand All @@ -130,7 +130,7 @@ def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
Extracts the token from a streaming chunk.
:param chunk: The streaming chunk.
- :return: A string token.
+ :returns: A string token.
"""
return chunk.get("completion", "")

Expand All @@ -146,7 +146,7 @@ def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
:param prompt: The prompt to be sent to the model.
:param inference_kwargs: Additional keyword arguments passed to the handler.
- :return: A dictionary containing the body for the request.
+ :returns: A dictionary containing the body for the request.
"""
default_params = {
"max_tokens": self.max_length,
Expand All @@ -170,7 +170,7 @@ def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> L
Extracts the responses from the Cohere Command model response.
:param response_body: The response body from the Amazon Bedrock request.
- :return: A list of string responses.
+ :returns: A list of string responses.
"""
responses = [generation["text"] for generation in response_body["generations"]]
return responses
Expand All @@ -180,7 +180,7 @@ def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
Extracts the token from a streaming chunk.
:param chunk: The streaming chunk.
- :return: A string token.
+ :returns: A string token.
"""
return chunk.get("text", "")

Expand Down Expand Up @@ -226,7 +226,7 @@ def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
:param prompt: The prompt to be sent to the model.
:param inference_kwargs: Additional keyword arguments passed to the handler.
- :return: A dictionary containing the body for the request.
+ :returns: A dictionary containing the body for the request.
"""
default_params = {
"maxTokenCount": self.max_length,
Expand All @@ -244,7 +244,7 @@ def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> L
Extracts the responses from the Titan model response.
:param response_body: The response body for Titan model response.
- :return: A list of string responses.
+ :returns: A list of string responses.
"""
responses = [result["outputText"] for result in response_body["results"]]
return responses
Expand All @@ -254,7 +254,7 @@ def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
Extracts the token from a streaming chunk.
:param chunk: The streaming chunk.
- :return: A string token.
+ :returns: A string token.
"""
return chunk.get("outputText", "")

Expand All @@ -270,7 +270,7 @@ def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
:param prompt: The prompt to be sent to the model.
:param inference_kwargs: Additional keyword arguments passed to the handler.
- :return: A dictionary containing the body for the request.
+ :returns: A dictionary containing the body for the request.
"""
default_params = {
"max_gen_len": self.max_length,
Expand All @@ -287,7 +287,7 @@ def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> L
Extracts the responses from the Llama2 model response.
:param response_body: The response body from the Llama2 model request.
- :return: A list of string responses.
+ :returns: A list of string responses.
"""
return [response_body["generation"]]

Expand All @@ -296,6 +296,6 @@ def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
Extracts the token from a streaming chunk.
:param chunk: The streaming chunk.
- :return: A string token.
+ :returns: A string token.
"""
return chunk.get("generation", "")
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,15 @@ def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[
:param messages: The chat messages to package into the request.
:param inference_kwargs: Additional inference kwargs to use.
- :return: The prepared body.
+ :returns: The prepared body.
"""

def get_responses(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
"""
Extracts the responses from the Amazon Bedrock response.
:param response_body: The response body.
- :return: The extracted responses.
+ :returns: The extracted responses.
"""
return self._extract_messages_from_response(self.response_body_message_key(), response_body)

Expand Down Expand Up @@ -85,7 +85,7 @@ def _get_params(self, inference_kwargs: Dict[str, Any], default_params: Dict[str
:param inference_kwargs: The inference kwargs to merge.
:param default_params: The default params to start with.
- :return: The merged params.
+ :returns: The merged params.
"""
# Start with a copy of default_params
kwargs = default_params.copy()
Expand All @@ -100,7 +100,7 @@ def _ensure_token_limit(self, prompt: str) -> str:
"""
Ensures that the prompt is within the token limit for the model.
:param prompt: The prompt to check.
- :return: The resized prompt.
+ :returns: The resized prompt.
"""
resize_info = self.check_prompt(prompt)
if resize_info["prompt_length"] != resize_info["new_prompt_length"]:
Expand All @@ -121,7 +121,7 @@ def check_prompt(self, prompt: str) -> Dict[str, Any]:
Checks the prompt length and resizes it if necessary. If the prompt is too long, it will be truncated.
:param prompt: The prompt to check.
- :return: A dictionary containing the resized prompt and additional information.
+ :returns: A dictionary containing the resized prompt and additional information.
"""

def _extract_messages_from_response(self, message_tag: str, response_body: Dict[str, Any]) -> List[ChatMessage]:
Expand All @@ -130,7 +130,7 @@ def _extract_messages_from_response(self, message_tag: str, response_body: Dict[
:param message_tag: The key for the message in the response body.
:param response_body: The response body.
- :return: The extracted ChatMessage list.
+ :returns: The extracted ChatMessage list.
"""
metadata = {k: v for (k, v) in response_body.items() if k != message_tag}
return [ChatMessage.from_assistant(response_body[message_tag], meta=metadata)]
Expand All @@ -141,7 +141,7 @@ def response_body_message_key(self) -> str:
Returns the key for the message in the response body.
Subclasses should override this method to return the correct message key - where the response is located.
- :return: The key for the message in the response body.
+ :returns: The key for the message in the response body.
"""

@abstractmethod
Expand All @@ -150,7 +150,7 @@ def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
Extracts the token from a streaming chunk.
:param chunk: The streaming chunk.
- :return: The extracted token.
+ :returns: The extracted token.
"""


Expand Down Expand Up @@ -192,7 +192,7 @@ def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[
:param messages: The chat messages to package into the request.
:param inference_kwargs: Additional inference kwargs to use.
- :return: The prepared body.
+ :returns: The prepared body.
"""
default_params = {
"max_tokens_to_sample": self.generation_kwargs.get("max_tokens_to_sample") or 512,
Expand All @@ -212,7 +212,7 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> str:
Prepares the chat messages for the Anthropic Claude request.
:param messages: The chat messages to prepare.
- :return: The prepared chat messages as a string.
+ :returns: The prepared chat messages as a string.
"""
conversation = []
for index, message in enumerate(messages):
Expand Down Expand Up @@ -241,15 +241,15 @@ def check_prompt(self, prompt: str) -> Dict[str, Any]:
Checks the prompt length and resizes it if necessary. If the prompt is too long, it will be truncated.
:param prompt: The prompt to check.
- :return: A dictionary containing the resized prompt and additional information.
+ :returns: A dictionary containing the resized prompt and additional information.
"""
return self.prompt_handler(prompt)

def response_body_message_key(self) -> str:
"""
Returns the key for the message in the response body for Anthropic Claude i.e. "completion".
- :return: The key for the message in the response body.
+ :returns: The key for the message in the response body.
"""
return "completion"

Expand All @@ -258,7 +258,7 @@ def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
Extracts the token from a streaming chunk.
:param chunk: The streaming chunk.
- :return: The extracted token.
+ :returns: The extracted token.
"""
return chunk.get("completion", "")

Expand Down Expand Up @@ -340,7 +340,7 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> str:
Prepares the chat messages for the Meta Llama 2 request.
:param messages: The chat messages to prepare.
- :return: The prepared chat messages as a string ready for the model.
+ :returns: The prepared chat messages as a string ready for the model.
"""
prepared_prompt: str = self.prompt_handler.tokenizer.apply_chat_template(
conversation=messages, tokenize=False, chat_template=self.chat_template
Expand All @@ -352,7 +352,7 @@ def check_prompt(self, prompt: str) -> Dict[str, Any]:
Checks the prompt length and resizes it if necessary. If the prompt is too long, it will be truncated.
:param prompt: The prompt to check.
- :return: A dictionary containing the resized prompt and additional information.
+ :returns: A dictionary containing the resized prompt and additional information.
"""
return self.prompt_handler(prompt)
Expand All @@ -361,7 +361,7 @@ def response_body_message_key(self) -> str:
"""
Returns the key for the message in the response body for Meta Llama 2 i.e. "generation".
- :return: The key for the message in the response body.
+ :returns: The key for the message in the response body.
"""
return "generation"

Expand All @@ -370,6 +370,6 @@ def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
Extracts the token from a streaming chunk.
:param chunk: The streaming chunk.
- :return: The extracted token.
+ :returns: The extracted token.
"""
return chunk.get("generation", "")
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def invoke(self, *args, **kwargs):
:param args: The positional arguments passed to the generator.
:param kwargs: The keyword arguments passed to the generator.
- :return: List of `ChatMessage` generated by LLM.
+ :returns: List of `ChatMessage` generated by LLM.
"""

kwargs = kwargs.copy()
Expand Down Expand Up @@ -183,7 +183,7 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
:param messages: The messages to generate a response to.
:param generation_kwargs: Additional generation keyword arguments passed to the model.
- :return: A dictionary with the following keys:
+ :returns: A dictionary with the following keys:
- `replies`: The generated List of `ChatMessage` objects.
"""
return {"replies": self.invoke(messages=messages, **(generation_kwargs or {}))}
Expand All @@ -194,7 +194,7 @@ def get_model_adapter(cls, model: str) -> Optional[Type[BedrockModelChatAdapter]
Returns the model adapter for the given model.
:param model: The model to get the adapter for.
- :return: The model adapter for the given model, or None if the model is not supported.
+ :returns: The model adapter for the given model, or None if the model is not supported.
"""
for pattern, adapter in cls.SUPPORTED_MODEL_PATTERNS.items():
if re.fullmatch(pattern, model):
Expand Down
Loading

0 comments on commit d6a302f

Please sign in to comment.