
Commit

[Bugfix] Ensure special tokens are properly filtered out for guided structured output with MistralTokenizer (vllm-project#10363)

Signed-off-by: Guillaume Calmettes <[email protected]>
gcalmettes authored Nov 15, 2024
1 parent 3a763ba commit 691a3ec
Showing 2 changed files with 17 additions and 6 deletions.
requirements-common.txt: 4 changes (2 additions, 2 deletions)
@@ -17,7 +17,7 @@ pillow # Required for image processing
 prometheus_client >= 0.18.0
 prometheus-fastapi-instrumentator >= 7.0.0
 tiktoken >= 0.6.0 # Required for DBRX tokenizer
-lm-format-enforcer == 0.10.6
+lm-format-enforcer >= 0.10.9, < 0.11
 outlines >= 0.0.43, < 0.1
 typing_extensions >= 4.10
 filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
@@ -31,4 +31,4 @@ pyyaml
 six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
 setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12
 einops # Required for Qwen2-VL.
-compressed-tensors == 0.8.0 # required for compressed-tensors
+compressed-tensors == 0.8.0 # required for compressed-tensors
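
The new bound pins lm-format-enforcer to the 0.10.x series starting at 0.10.9. As a quick illustration (not part of the commit), a runtime check of the installed version could look like the sketch below; it assumes the packaging library is available in the environment.

from importlib.metadata import version

from packaging.version import Version

# Read the installed lm-format-enforcer version and verify it falls
# inside the requirement window declared above.
installed = Version(version("lm-format-enforcer"))
assert Version("0.10.9") <= installed < Version("0.11"), (
    f"lm-format-enforcer {installed} does not satisfy '>= 0.10.9, < 0.11'")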
vllm/transformers_utils/tokenizers/mistral.py: 19 changes (15 additions, 4 deletions)
@@ -174,18 +174,29 @@ def _download_mistral_tokenizer_from_hf(tokenizer_name: str,
revision=revision)
return tokenizer_file

-    # the following attributes are set to fit VLLM's design
+    # the following attributes are set to fit VLLM's design and are used
+    # by the guided structured output backends.
     @property
     def all_special_tokens_extended(self) -> List[str]:
-        return []
+        # tekken defines its own extended special tokens list
+        if hasattr(self.tokenizer, "SPECIAL_TOKENS"):
+            special_tokens = self.tokenizer.SPECIAL_TOKENS
+        else:
+            special_tokens = list(SpecialTokens)
+        return [
+            s.value if isinstance(s, SpecialTokens) else s
+            for s in special_tokens
+        ]

     @property
     def all_special_tokens(self) -> List[str]:
-        return []
+        return self.all_special_tokens_extended

     @property
     def all_special_ids(self) -> List[int]:
-        return []
+        return [
+            self.all_special_tokens.index(t) for t in self.all_special_tokens
+        ]

     @property
     def bos_token_id(self) -> int:
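
For context, the stubbed properties previously returned empty lists, so guided structured output backends saw no special tokens to exclude and control tokens could leak into generated output. The sketch below is not the vllm or outlines source; build_vocab and its tokenizer argument are illustrative stand-ins showing how a guided-decoding backend typically consumes all_special_tokens when collecting the vocabulary it is allowed to emit.

from typing import Dict


def build_vocab(tokenizer) -> Dict[str, int]:
    # Tokens the guide must never emit; with the old stubs this set was
    # empty, so control tokens such as "[INST]" were not filtered out.
    special = set(tokenizer.all_special_tokens)
    vocab: Dict[str, int] = {}
    for token, token_id in tokenizer.get_vocab().items():
        if token in special:
            continue
        vocab[token] = token_id
    return vocab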
