Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
julian-risch committed Jun 24, 2024
1 parent 12eec95 commit 510192a
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -338,10 +338,12 @@ def __init__(self, generation_kwargs: Dict[str, Any]):
tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
else:
tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf")
logger.warning(f"Gated mistralai/Mistral-7B-Instruct-v0.1 model cannot be used as a tokenizer for "
f"estimating the prompt length because no HF_TOKEN was found. Using "
f"NousResearch/Llama-2-7b-chat-hf instead. To use a mistral tokenizer export an env var "
f"HF_TOKEN containing a Hugging Face token and make sure you have access to the model.")
logger.warning(
"Gated mistralai/Mistral-7B-Instruct-v0.1 model cannot be used as a tokenizer for "
"estimating the prompt length because no HF_TOKEN was found. Using "
"NousResearch/Llama-2-7b-chat-hf instead. To use a mistral tokenizer export an env var "
"HF_TOKEN containing a Hugging Face token and make sure you have access to the model."
)

self.prompt_handler = DefaultPromptHandler(
tokenizer=tokenizer,
Expand Down
8 changes: 4 additions & 4 deletions integrations/amazon_bedrock/tests/test_chat_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def test_use_mistral_adapter_without_hf_token(self, monkeypatch: MonkeyPatch, ca
with (
patch("transformers.AutoTokenizer.from_pretrained") as mock_pretrained,
patch("haystack_integrations.components.generators.amazon_bedrock.chat.adapters.DefaultPromptHandler"),
caplog.at_level(logging.WARNING)
caplog.at_level(logging.WARNING),
):
MistralChatAdapter(generation_kwargs={})
mock_pretrained.assert_called_with("NousResearch/Llama-2-7b-chat-hf")
Expand All @@ -256,16 +256,16 @@ def test_use_mistral_adapter_with_hf_token(self, monkeypatch: MonkeyPatch) -> No
monkeypatch.setenv("HF_TOKEN", "test")
with (
patch("transformers.AutoTokenizer.from_pretrained") as mock_pretrained,
patch("haystack_integrations.components.generators.amazon_bedrock.chat.adapters.DefaultPromptHandler")
patch("haystack_integrations.components.generators.amazon_bedrock.chat.adapters.DefaultPromptHandler"),
):
MistralChatAdapter(generation_kwargs={})
mock_pretrained.assert_called_with("mistralai/Mistral-7B-Instruct-v0.1")

@pytest.mark.skipif(
not os.environ.get("HF_API_TOKEN", None),
reason=(
"To run this test, you need to set the HF_API_TOKEN environment variable. The associated account must also "
"have requested access to the gated model `mistralai/Mistral-7B-Instruct-v0.1`"
"To run this test, you need to set the HF_API_TOKEN environment variable. The associated account must also "
"have requested access to the gated model `mistralai/Mistral-7B-Instruct-v0.1`"
),
)
@pytest.mark.parametrize("model_name", MISTRAL_MODELS)
Expand Down

0 comments on commit 510192a

Please sign in to comment.