From 5e66f1d370cc33d6e4f29019bfdc368fc63182c8 Mon Sep 17 00:00:00 2001 From: tstadel <60758086+tstadel@users.noreply.github.com> Date: Fri, 14 Jun 2024 17:06:51 +0200 Subject: [PATCH] feat: support Claude v3, Llama3 and Command R models on Amazon Bedrock (#809) * feat: support Claude v3 and Cohere Command R models on Amazon Bedrock * revert chat pattern change * rename llama adapter * fix tests after llama adapter rename --- .../generators/amazon_bedrock/adapters.py | 106 +++++- .../generators/amazon_bedrock/generator.py | 8 +- .../amazon_bedrock/tests/test_generator.py | 331 +++++++++++++++++- 3 files changed, 413 insertions(+), 32 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py index f5bd4aa07..7c7fdd7ce 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py @@ -98,6 +98,10 @@ class AnthropicClaudeAdapter(BedrockModelAdapter): Adapter for the Anthropic Claude models. """ + def __init__(self, model_kwargs: Dict[str, Any], max_length: Optional[int]) -> None: + self.use_messages_api = model_kwargs.get("use_messages_api", True) + super().__init__(model_kwargs, max_length) + def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]: """ Prepares the body for the Claude model @@ -108,16 +112,30 @@ def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]: - `prompt`: The prompt to be sent to the model. - specified inference parameters. """ - default_params = { - "max_tokens_to_sample": self.max_length, - "stop_sequences": ["\n\nHuman:"], - "temperature": None, - "top_p": None, - "top_k": None, - } - params = self._get_params(inference_kwargs, default_params) - - body = {"prompt": f"\n\nHuman: {prompt}\n\nAssistant:", **params} + if self.use_messages_api: + default_params: Dict[str, Any] = { + "anthropic_version": "bedrock-2023-05-31", + "max_tokens": self.max_length, + "system": None, + "stop_sequences": None, + "temperature": None, + "top_p": None, + "top_k": None, + } + params = self._get_params(inference_kwargs, default_params) + + body = {"messages": [{"role": "user", "content": prompt}], **params} + else: + default_params = { + "max_tokens_to_sample": self.max_length, + "stop_sequences": ["\n\nHuman:"], + "temperature": None, + "top_p": None, + "top_k": None, + } + params = self._get_params(inference_kwargs, default_params) + + body = {"prompt": f"\n\nHuman: {prompt}\n\nAssistant:", **params} return body def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]: @@ -127,6 +145,9 @@ def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> L :param response_body: The response body from the Amazon Bedrock request. :returns: A list of string responses. """ + if self.use_messages_api: + return [content["text"] for content in response_body["content"]] + return [response_body["completion"]] def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: @@ -136,6 +157,9 @@ def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: :param chunk: The streaming chunk. :returns: A string token. """ + if self.use_messages_api: + return chunk.get("delta", {}).get("text", "") + return chunk.get("completion", "") @@ -240,6 +264,66 @@ def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: return chunk.get("text", "") +class CohereCommandRAdapter(BedrockModelAdapter): + """ + Adapter for the Cohere Command R models. + """ + + def prepare_body(self, prompt: str, **inference_kwargs: Any) -> Dict[str, Any]: + """ + Prepares the body for the Command model + + :param prompt: The prompt to be sent to the model. + :param inference_kwargs: Additional keyword arguments passed to the handler. + :returns: A dictionary with the following keys: + - `prompt`: The prompt to be sent to the model. + - specified inference parameters. + """ + default_params = { + "chat_history": None, + "documents": None, + "search_query_only": None, + "preamble": None, + "max_tokens": self.max_length, + "temperature": None, + "p": None, + "k": None, + "prompt_truncation": None, + "frequency_penalty": None, + "presence_penalty": None, + "seed": None, + "return_prompt": None, + "tools": None, + "tool_results": None, + "stop_sequences": None, + "raw_prompting": None, + } + params = self._get_params(inference_kwargs, default_params) + + body = {"message": prompt, **params} + return body + + def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]: + """ + Extracts the responses from the Cohere Command model response. + + :param response_body: The response body from the Amazon Bedrock request. + :returns: A list of string responses. + """ + responses = [response_body["text"]] + return responses + + def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: + """ + Extracts the token from a streaming chunk. + + :param chunk: The streaming chunk. + :returns: A string token. + """ + token: str = chunk.get("text", "") + return token + + class AI21LabsJurassic2Adapter(BedrockModelAdapter): """ Model adapter for AI21 Labs' Jurassic 2 models. @@ -324,7 +408,7 @@ def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: return chunk.get("outputText", "") -class MetaLlama2ChatAdapter(BedrockModelAdapter): +class MetaLlamaAdapter(BedrockModelAdapter): """ Adapter for Meta's Llama2 models. """ diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py index 81a02b749..b93ba1d3f 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py @@ -19,7 +19,8 @@ AnthropicClaudeAdapter, BedrockModelAdapter, CohereCommandAdapter, - MetaLlama2ChatAdapter, + CohereCommandRAdapter, + MetaLlamaAdapter, MistralAdapter, ) from .handlers import ( @@ -56,9 +57,10 @@ class AmazonBedrockGenerator: SUPPORTED_MODEL_PATTERNS: ClassVar[Dict[str, Type[BedrockModelAdapter]]] = { r"amazon.titan-text.*": AmazonTitanAdapter, r"ai21.j2.*": AI21LabsJurassic2Adapter, - r"cohere.command.*": CohereCommandAdapter, + r"cohere.command-[^r].*": CohereCommandAdapter, + r"cohere.command-r.*": CohereCommandRAdapter, r"anthropic.claude.*": AnthropicClaudeAdapter, - r"meta.llama2.*": MetaLlama2ChatAdapter, + r"meta.llama.*": MetaLlamaAdapter, r"mistral.*": MistralAdapter, } diff --git a/integrations/amazon_bedrock/tests/test_generator.py b/integrations/amazon_bedrock/tests/test_generator.py index 10fc1eca8..e603c8853 100644 --- a/integrations/amazon_bedrock/tests/test_generator.py +++ b/integrations/amazon_bedrock/tests/test_generator.py @@ -10,7 +10,8 @@ AnthropicClaudeAdapter, BedrockModelAdapter, CohereCommandAdapter, - MetaLlama2ChatAdapter, + CohereCommandRAdapter, + MetaLlamaAdapter, MistralAdapter, ) @@ -203,6 +204,9 @@ def test_long_prompt_is_truncated(mock_boto3_session): ("cohere.command-text-v14", CohereCommandAdapter), ("cohere.command-light-text-v14", CohereCommandAdapter), ("cohere.command-text-v21", CohereCommandAdapter), # artificial + ("cohere.command-r-v1:0", CohereCommandRAdapter), + ("cohere.command-r-plus-v1:0", CohereCommandRAdapter), + ("cohere.command-r-v8:9", CohereCommandRAdapter), # artificial ("ai21.j2-mid-v1", AI21LabsJurassic2Adapter), ("ai21.j2-ultra-v1", AI21LabsJurassic2Adapter), ("ai21.j2-mega-v5", AI21LabsJurassic2Adapter), # artificial @@ -210,9 +214,16 @@ def test_long_prompt_is_truncated(mock_boto3_session): ("amazon.titan-text-express-v1", AmazonTitanAdapter), ("amazon.titan-text-agile-v1", AmazonTitanAdapter), ("amazon.titan-text-lightning-v8", AmazonTitanAdapter), # artificial - ("meta.llama2-13b-chat-v1", MetaLlama2ChatAdapter), - ("meta.llama2-70b-chat-v1", MetaLlama2ChatAdapter), - ("meta.llama2-130b-v5", MetaLlama2ChatAdapter), # artificial + ("meta.llama2-13b-chat-v1", MetaLlamaAdapter), + ("meta.llama2-70b-chat-v1", MetaLlamaAdapter), + ("meta.llama2-130b-v5", MetaLlamaAdapter), # artificial + ("meta.llama3-8b-instruct-v1:0", MetaLlamaAdapter), + ("meta.llama3-70b-instruct-v1:0", MetaLlamaAdapter), + ("meta.llama3-130b-instruct-v5:9", MetaLlamaAdapter), # artificial + ("mistral.mistral-7b-instruct-v0:2", MistralAdapter), + ("mistral.mixtral-8x7b-instruct-v0:1", MistralAdapter), + ("mistral.mistral-large-2402-v1:0", MistralAdapter), + ("mistral.mistral-medium-v8:0", MistralAdapter), # artificial ("unknown_model", None), ], ) @@ -225,9 +236,183 @@ def test_get_model_adapter(model: str, expected_model_adapter: Optional[Type[Bed class TestAnthropicClaudeAdapter: + def test_default_init(self) -> None: + adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=100) + assert adapter.use_messages_api is True + + def test_use_messages_api_false(self) -> None: + adapter = AnthropicClaudeAdapter(model_kwargs={"use_messages_api": False}, max_length=100) + assert adapter.use_messages_api is False + + +class TestAnthropicClaudeAdapterMessagesAPI: def test_prepare_body_with_default_params(self) -> None: layer = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) prompt = "Hello, how are you?" + expected_body = { + "messages": [{"role": "user", "content": "Hello, how are you?"}], + "max_tokens": 99, + "anthropic_version": "bedrock-2023-05-31", + } + + body = layer.prepare_body(prompt) + + assert body == expected_body + + def test_prepare_body_with_custom_inference_params(self) -> None: + layer = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) + prompt = "Hello, how are you?" + expected_body = { + "messages": [{"role": "user", "content": "Hello, how are you?"}], + "max_tokens": 50, + "stop_sequences": ["CUSTOM_STOP"], + "temperature": 0.7, + "top_p": 0.8, + "top_k": 5, + "system": "system prompt", + "anthropic_version": "custom_version", + } + + body = layer.prepare_body( + prompt, + temperature=0.7, + top_p=0.8, + top_k=5, + max_tokens=50, + stop_sequences=["CUSTOM_STOP"], + system="system prompt", + anthropic_version="custom_version", + unknown_arg="unknown_value", + ) + + assert body == expected_body + + def test_prepare_body_with_model_kwargs(self) -> None: + layer = AnthropicClaudeAdapter( + model_kwargs={ + "temperature": 0.7, + "top_p": 0.8, + "top_k": 5, + "max_tokens": 50, + "stop_sequences": ["CUSTOM_STOP"], + "system": "system prompt", + "anthropic_version": "custom_version", + "unknown_arg": "unknown_value", + }, + max_length=99, + ) + prompt = "Hello, how are you?" + expected_body = { + "messages": [{"role": "user", "content": "Hello, how are you?"}], + "max_tokens": 50, + "stop_sequences": ["CUSTOM_STOP"], + "temperature": 0.7, + "top_p": 0.8, + "top_k": 5, + "system": "system prompt", + "anthropic_version": "custom_version", + } + + body = layer.prepare_body(prompt) + + assert body == expected_body + + def test_prepare_body_with_model_kwargs_and_custom_inference_params(self) -> None: + layer = AnthropicClaudeAdapter( + model_kwargs={ + "temperature": 0.6, + "top_p": 0.7, + "top_k": 4, + "max_tokens": 49, + "stop_sequences": ["CUSTOM_STOP_MODEL_KWARGS"], + "system": "system prompt", + "anthropic_version": "custom_version", + }, + max_length=99, + ) + prompt = "Hello, how are you?" + expected_body = { + "messages": [{"role": "user", "content": "Hello, how are you?"}], + "max_tokens": 50, + "stop_sequences": ["CUSTOM_STOP_MODEL_KWARGS"], + "temperature": 0.7, + "top_p": 0.8, + "top_k": 5, + "system": "new system prompt", + "anthropic_version": "new_custom_version", + } + + body = layer.prepare_body( + prompt, + temperature=0.7, + top_p=0.8, + top_k=5, + max_tokens=50, + system="new system prompt", + anthropic_version="new_custom_version", + ) + + assert body == expected_body + + def test_get_responses(self) -> None: + adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) + response_body = {"content": [{"text": "This is a single response."}]} + expected_responses = ["This is a single response."] + assert adapter.get_responses(response_body) == expected_responses + + def test_get_responses_leading_whitespace(self) -> None: + adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) + response_body = {"content": [{"text": "\n\t This is a single response."}]} + expected_responses = ["This is a single response."] + assert adapter.get_responses(response_body) == expected_responses + + def test_get_stream_responses(self) -> None: + stream_mock = MagicMock() + stream_handler_mock = MagicMock() + + stream_mock.__iter__.return_value = [ + {"chunk": {"bytes": b'{"delta": {"text": " This"}}'}}, + {"chunk": {"bytes": b'{"delta": {"text": " is"}}'}}, + {"chunk": {"bytes": b'{"delta": {"text": " a"}}'}}, + {"chunk": {"bytes": b'{"delta": {"text": " single"}}'}}, + {"chunk": {"bytes": b'{"delta": {"text": " response."}}'}}, + ] + + stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received + + adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) + expected_responses = ["This is a single response."] + assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses + + stream_handler_mock.assert_has_calls( + [ + call(" This", event_data={"delta": {"text": " This"}}), + call(" is", event_data={"delta": {"text": " is"}}), + call(" a", event_data={"delta": {"text": " a"}}), + call(" single", event_data={"delta": {"text": " single"}}), + call(" response.", event_data={"delta": {"text": " response."}}), + ] + ) + + def test_get_stream_responses_empty(self) -> None: + stream_mock = MagicMock() + stream_handler_mock = MagicMock() + + stream_mock.__iter__.return_value = [] + + stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received + + adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) + expected_responses = [""] + assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses + + stream_handler_mock.assert_not_called() + + +class TestAnthropicClaudeAdapterNoMessagesAPI: + def test_prepare_body_with_default_params(self) -> None: + layer = AnthropicClaudeAdapter(model_kwargs={"use_messages_api": False}, max_length=99) + prompt = "Hello, how are you?" expected_body = { "prompt": "\n\nHuman: Hello, how are you?\n\nAssistant:", "max_tokens_to_sample": 99, @@ -239,7 +424,7 @@ def test_prepare_body_with_default_params(self) -> None: assert body == expected_body def test_prepare_body_with_custom_inference_params(self) -> None: - layer = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) + layer = AnthropicClaudeAdapter(model_kwargs={"use_messages_api": False}, max_length=99) prompt = "Hello, how are you?" expected_body = { "prompt": "\n\nHuman: Hello, how are you?\n\nAssistant:", @@ -265,6 +450,7 @@ def test_prepare_body_with_custom_inference_params(self) -> None: def test_prepare_body_with_model_kwargs(self) -> None: layer = AnthropicClaudeAdapter( model_kwargs={ + "use_messages_api": False, "temperature": 0.7, "top_p": 0.8, "top_k": 5, @@ -291,6 +477,7 @@ def test_prepare_body_with_model_kwargs(self) -> None: def test_prepare_body_with_model_kwargs_and_custom_inference_params(self) -> None: layer = AnthropicClaudeAdapter( model_kwargs={ + "use_messages_api": False, "temperature": 0.6, "top_p": 0.7, "top_k": 4, @@ -314,13 +501,13 @@ def test_prepare_body_with_model_kwargs_and_custom_inference_params(self) -> Non assert body == expected_body def test_get_responses(self) -> None: - adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) + adapter = AnthropicClaudeAdapter(model_kwargs={"use_messages_api": False}, max_length=99) response_body = {"completion": "This is a single response."} expected_responses = ["This is a single response."] assert adapter.get_responses(response_body) == expected_responses def test_get_responses_leading_whitespace(self) -> None: - adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) + adapter = AnthropicClaudeAdapter(model_kwargs={"use_messages_api": False}, max_length=99) response_body = {"completion": "\n\t This is a single response."} expected_responses = ["This is a single response."] assert adapter.get_responses(response_body) == expected_responses @@ -339,7 +526,7 @@ def test_get_stream_responses(self) -> None: stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received - adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) + adapter = AnthropicClaudeAdapter(model_kwargs={"use_messages_api": False}, max_length=99) expected_responses = ["This is a single response."] assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses @@ -361,7 +548,7 @@ def test_get_stream_responses_empty(self) -> None: stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received - adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) + adapter = AnthropicClaudeAdapter(model_kwargs={"use_messages_api": False}, max_length=99) expected_responses = [""] assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses @@ -698,6 +885,114 @@ def test_get_stream_responses_empty(self) -> None: stream_handler_mock.assert_not_called() +class TestCohereCommandRAdapter: + def test_prepare_body(self) -> None: + adapter = CohereCommandRAdapter( + model_kwargs={ + "chat_history": [ + {"role": "CHATBOT", "content": "How can I help you today?"}, + ], + "documents": [ + {"title": "France", "snippet": "Paris is the capital of France."}, + {"title": "Germany", "snippet": "Berlin is the capital of Germany."}, + ], + "search_query_only": False, + "preamble": "preamble", + "temperature": 0, + "p": 0.9, + "k": 50, + "prompt_truncation": "AUTO_PRESERVE_ORDER", + "frequency_penalty": 0.3, + "presence_penalty": 0.4, + "seed": 42, + "return_prompt": True, + "tools": [ + { + "name": "query_daily_sales_report", + "description": "Connects to a database to retrieve overall sales volumes and sales " + "information for a given day.", + "parameter_definitions": { + "day": { + "description": "Retrieves sales data for this day, formatted as YYYY-MM-DD.", + "type": "str", + "required": True, + } + }, + } + ], + "tool_results": [ + { + "call": {"name": "query_daily_sales_report", "parameters": {"day": "2023-09-29"}}, + "outputs": [ + {"date": "2023-09-29", "summary": "Total Sales Amount: 10000, Total Units Sold: 250"} + ], + } + ], + "stop_sequences": ["\n\n"], + "raw_prompting": True, + "stream": True, + "unknown_arg": "unknown_arg", + }, + max_length=100, + ) + body = adapter.prepare_body(prompt="test") + assert body == { + "message": "test", + "chat_history": [ + {"role": "CHATBOT", "content": "How can I help you today?"}, + ], + "documents": [ + {"title": "France", "snippet": "Paris is the capital of France."}, + {"title": "Germany", "snippet": "Berlin is the capital of Germany."}, + ], + "search_query_only": False, + "preamble": "preamble", + "max_tokens": 100, + "temperature": 0, + "p": 0.9, + "k": 50, + "prompt_truncation": "AUTO_PRESERVE_ORDER", + "frequency_penalty": 0.3, + "presence_penalty": 0.4, + "seed": 42, + "return_prompt": True, + "tools": [ + { + "name": "query_daily_sales_report", + "description": "Connects to a database to retrieve overall sales volumes and sales " + "information for a given day.", + "parameter_definitions": { + "day": { + "description": "Retrieves sales data for this day, formatted as YYYY-MM-DD.", + "type": "str", + "required": True, + } + }, + } + ], + "tool_results": [ + { + "call": {"name": "query_daily_sales_report", "parameters": {"day": "2023-09-29"}}, + "outputs": [{"date": "2023-09-29", "summary": "Total Sales Amount: 10000, Total Units Sold: 250"}], + } + ], + "stop_sequences": ["\n\n"], + "raw_prompting": True, + } + + def test_extract_completions_from_response(self) -> None: + adapter = CohereCommandRAdapter(model_kwargs={}, max_length=100) + response_body = {"text": "response"} + completions = adapter._extract_completions_from_response(response_body=response_body) + assert completions == ["response"] + + def test_extract_token_from_stream(self) -> None: + adapter = CohereCommandRAdapter(model_kwargs={}, max_length=100) + chunk = {"text": "response_token"} + token = adapter._extract_token_from_stream(chunk=chunk) + assert token == "response_token" + + class TestAI21LabsJurassic2Adapter: def test_prepare_body_with_default_params(self) -> None: layer = AI21LabsJurassic2Adapter(model_kwargs={}, max_length=99) @@ -995,9 +1290,9 @@ def test_get_stream_responses_empty(self) -> None: stream_handler_mock.assert_not_called() -class TestMetaLlama2ChatAdapter: +class TestMetaLlamaAdapter: def test_prepare_body_with_default_params(self) -> None: - layer = MetaLlama2ChatAdapter(model_kwargs={}, max_length=99) + layer = MetaLlamaAdapter(model_kwargs={}, max_length=99) prompt = "Hello, how are you?" expected_body = {"prompt": "Hello, how are you?", "max_gen_len": 99} @@ -1006,7 +1301,7 @@ def test_prepare_body_with_default_params(self) -> None: assert body == expected_body def test_prepare_body_with_custom_inference_params(self) -> None: - layer = MetaLlama2ChatAdapter(model_kwargs={}, max_length=99) + layer = MetaLlamaAdapter(model_kwargs={}, max_length=99) prompt = "Hello, how are you?" expected_body = { "prompt": "Hello, how are you?", @@ -1026,7 +1321,7 @@ def test_prepare_body_with_custom_inference_params(self) -> None: assert body == expected_body def test_prepare_body_with_model_kwargs(self) -> None: - layer = MetaLlama2ChatAdapter( + layer = MetaLlamaAdapter( model_kwargs={ "temperature": 0.7, "top_p": 0.8, @@ -1048,7 +1343,7 @@ def test_prepare_body_with_model_kwargs(self) -> None: assert body == expected_body def test_prepare_body_with_model_kwargs_and_custom_inference_params(self) -> None: - layer = MetaLlama2ChatAdapter( + layer = MetaLlamaAdapter( model_kwargs={ "temperature": 0.6, "top_p": 0.7, @@ -1070,13 +1365,13 @@ def test_prepare_body_with_model_kwargs_and_custom_inference_params(self) -> Non assert body == expected_body def test_get_responses(self) -> None: - adapter = MetaLlama2ChatAdapter(model_kwargs={}, max_length=99) + adapter = MetaLlamaAdapter(model_kwargs={}, max_length=99) response_body = {"generation": "This is a single response."} expected_responses = ["This is a single response."] assert adapter.get_responses(response_body) == expected_responses def test_get_responses_leading_whitespace(self) -> None: - adapter = MetaLlama2ChatAdapter(model_kwargs={}, max_length=99) + adapter = MetaLlamaAdapter(model_kwargs={}, max_length=99) response_body = {"generation": "\n\t This is a single response."} expected_responses = ["This is a single response."] assert adapter.get_responses(response_body) == expected_responses @@ -1095,7 +1390,7 @@ def test_get_stream_responses(self) -> None: stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received - adapter = MetaLlama2ChatAdapter(model_kwargs={}, max_length=99) + adapter = MetaLlamaAdapter(model_kwargs={}, max_length=99) expected_responses = ["This is a single response."] assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses @@ -1117,7 +1412,7 @@ def test_get_stream_responses_empty(self) -> None: stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received - adapter = MetaLlama2ChatAdapter(model_kwargs={}, max_length=99) + adapter = MetaLlamaAdapter(model_kwargs={}, max_length=99) expected_responses = [""] assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses