feat: Use non-gated tokenizer as fallback for mistral
julian-risch committed Jun 24, 2024
1 parent 75bb792 commit 12eec95
Showing 2 changed files with 41 additions and 8 deletions.
@@ -1,5 +1,6 @@
import json
import logging
import os
from abc import ABC, abstractmethod
from typing import Any, Callable, ClassVar, Dict, List

@@ -332,7 +333,16 @@ def __init__(self, generation_kwargs: Dict[str, Any]):
            # Use `mistralai/Mistral-7B-v0.1` as the tokenizer; all Mistral models likely use the same tokenizer, so
            # a) we should get good estimates for the prompt length
            # b) we can use apply_chat_template with the template above to delineate ChatMessages
            tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
            # Mistral models are gated on the HF Hub. If no HF_TOKEN is found, we use a non-gated alternative tokenizer model.
            if os.environ.get("HF_TOKEN"):
                tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
            else:
                tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf")
                logger.warning(
                    "Gated mistralai/Mistral-7B-Instruct-v0.1 model cannot be used as a tokenizer for "
                    "estimating the prompt length because no HF_TOKEN was found. Using "
                    "NousResearch/Llama-2-7b-chat-hf instead. To use a Mistral tokenizer, export an env var "
                    "HF_TOKEN containing a Hugging Face token and make sure you have access to the model."
                )

            self.prompt_handler = DefaultPromptHandler(
                tokenizer=tokenizer,
                model_max_length=model_max_length,
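For readers skimming the diff, the fallback boils down to an environment-variable check. Below is a minimal, self-contained sketch of the same selection logic (not code from the repository; `pick_tokenizer_repo` is a hypothetical helper used only for illustration, and loading either tokenizer requires network access to the Hugging Face Hub):

```python
import os

from transformers import AutoTokenizer, PreTrainedTokenizer


def pick_tokenizer_repo() -> str:
    """Choose the tokenizer repo the adapter would load, mirroring the fallback above."""
    if os.environ.get("HF_TOKEN"):
        # Gated repo: requires a valid token and granted access on the HF Hub.
        return "mistralai/Mistral-7B-Instruct-v0.1"
    # Non-gated stand-in that is close enough for prompt-length estimation.
    return "NousResearch/Llama-2-7b-chat-hf"


tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(pick_tokenizer_repo())
print(f"Loaded tokenizer from {tokenizer.name_or_path}")
```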
37 changes: 30 additions & 7 deletions integrations/amazon_bedrock/tests/test_chat_generator.py
@@ -1,7 +1,10 @@
import logging
import os
from typing import Optional, Type
from unittest.mock import patch

import pytest
from _pytest.monkeypatch import MonkeyPatch
from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk

@@ -183,13 +186,6 @@ def test_prepare_body_with_custom_inference_params(self) -> None:
        assert body == expected_body


@pytest.mark.skipif(
    not os.environ.get("HF_API_TOKEN", None),
    reason=(
        "To run this test, you need to set the HF_API_TOKEN environment variable. The associated account must also "
        "have requested access to the gated model `mistralai/Mistral-7B-Instruct-v0.1`"
    ),
)
class TestMistralAdapter:
    def test_prepare_body_with_default_params(self) -> None:
        layer = MistralChatAdapter(generation_kwargs={})
@@ -245,6 +241,33 @@ def test_mistral_chat_template_incorrect_order(self):
        except Exception as e:
            assert "Conversation roles must alternate user/assistant/" in str(e)

    def test_use_mistral_adapter_without_hf_token(self, monkeypatch: MonkeyPatch, caplog) -> None:
        monkeypatch.delenv("HF_TOKEN", raising=False)
        with (
            patch("transformers.AutoTokenizer.from_pretrained") as mock_pretrained,
            patch("haystack_integrations.components.generators.amazon_bedrock.chat.adapters.DefaultPromptHandler"),
            caplog.at_level(logging.WARNING),
        ):
            MistralChatAdapter(generation_kwargs={})
            mock_pretrained.assert_called_with("NousResearch/Llama-2-7b-chat-hf")
            assert "no HF_TOKEN was found" in caplog.text

    def test_use_mistral_adapter_with_hf_token(self, monkeypatch: MonkeyPatch) -> None:
        monkeypatch.setenv("HF_TOKEN", "test")
        with (
            patch("transformers.AutoTokenizer.from_pretrained") as mock_pretrained,
            patch("haystack_integrations.components.generators.amazon_bedrock.chat.adapters.DefaultPromptHandler"),
        ):
            MistralChatAdapter(generation_kwargs={})
            mock_pretrained.assert_called_with("mistralai/Mistral-7B-Instruct-v0.1")

    @pytest.mark.skipif(
        not os.environ.get("HF_API_TOKEN", None),
        reason=(
            "To run this test, you need to set the HF_API_TOKEN environment variable. The associated account must also "
            "have requested access to the gated model `mistralai/Mistral-7B-Instruct-v0.1`"
        ),
    )
    @pytest.mark.parametrize("model_name", MISTRAL_MODELS)
    @pytest.mark.integration
    def test_default_inference_params(self, model_name, chat_messages):
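Outside of pytest, the same fallback can be checked by hand. A hedged sketch follows; it assumes `MistralChatAdapter` is importable from the `adapters` module that the tests above patch, and that the amazon_bedrock integration and its dependencies are installed:

```python
import logging
import os

# Import path inferred from the patch target used in the tests above; treat it as an assumption.
from haystack_integrations.components.generators.amazon_bedrock.chat.adapters import MistralChatAdapter

logging.basicConfig(level=logging.WARNING)

# With no HF_TOKEN set, constructing the adapter should log the fallback warning
# and load the non-gated NousResearch/Llama-2-7b-chat-hf tokenizer.
os.environ.pop("HF_TOKEN", None)
adapter = MistralChatAdapter(generation_kwargs={})
```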
