Commit b9a16de

[HF] Change quick start example to HF, make it run on both CUDA and Mac (#74)

This PR uses `transformers_example.py` as the Quick Start example in the docs. Tried on both CUDA and Mac (with `device="cpu"`). It also changes `allocate_token_bitmask()` to use `pin_memory=torch.cuda.is_available()`, since pinned memory is not supported on all devices (e.g. Mac).
CharlieFRuan authored Nov 22, 2024
1 parent 1483b0f commit b9a16de
Showing 5 changed files with 48 additions and 44 deletions.
67 changes: 35 additions & 32 deletions docs/start/quick_start.rst
@@ -6,44 +6,47 @@ Quick Start
 Example
 -------

 The easiest way of trying out XGrammar is to use the ``transformers`` library in Python.
 After :ref:`installing XGrammar <installation>`, run the following example to see how XGrammar enables
 structured generation -- a JSON in this case.

 .. code:: python

-    import xgrammar as xgr
+    from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
     import torch
-    from transformers import AutoTokenizer, AutoConfig
-
-    # Get tokenizer info
-    model_id = "Qwen/Qwen2.5-1.5B-Instruct"
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
-    config = AutoConfig.from_pretrained(model_id)
-    # This can be larger than tokenizer.vocab_size due to paddings
-    full_vocab_size = config.vocab_size
-    tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=full_vocab_size)
-
-    # Compile a JSON grammar
-    compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)
-    compiled_grammar: xgr.CompiledGrammar = compiler.compile_builtin_json_grammar()
-
-    # Instantiate grammar matcher and allocate the bitmask
-    matcher = xgr.GrammarMatcher(compiled_grammar)
-    token_bitmask = xgr.allocate_token_bitmask(1, tokenizer_info.vocab_size)
-
-    # Each loop iteration is a simulated auto-regressive step. Here we use
-    # simulated logits and sampled tokens. In real application, use XGrammar
-    # in a LLM generation loop and sample with the masked logits.
-    sim_sampled_response = '{ "library": "xgrammar" }<|endoftext|>'
-    sim_sampled_token_ids = tokenizer.encode(sim_sampled_response)
-
-    for i, sim_token_id in enumerate(sim_sampled_token_ids):
-        logits = torch.randn(full_vocab_size).cuda()
-        matcher.fill_next_token_bitmask(token_bitmask)
-        xgr.apply_token_bitmask_inplace(logits, token_bitmask.to(logits.device))
-        assert matcher.accept_token(sim_token_id)
-
-    assert matcher.is_terminated()
-    matcher.reset()
+    import xgrammar as xgr
+
+    device = "cuda"  # Or "cpu", etc.
+
+    # 0. Instantiate with any HF model you want
+    model_name = "Qwen/Qwen2.5-0.5B-Instruct"
+    # model_name = "microsoft/Phi-3.5-mini-instruct"
+    # model_name = "meta-llama/Llama-3.2-1B-Instruct"
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name, torch_dtype=torch.float32, device_map=device
+    )
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    config = AutoConfig.from_pretrained(model_name)
+
+    # 1. Compile grammar (NOTE: you can substitute this with other grammars like EBNF, JSON Schema)
+    tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=config.vocab_size)
+    grammar_compiler = xgr.GrammarCompiler(tokenizer_info)
+    compiled_grammar = grammar_compiler.compile_builtin_json_grammar()
+
+    # 2. Prepare inputs
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Introduce yourself in JSON briefly."},
+    ]
+    texts = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    model_inputs = tokenizer(texts, return_tensors="pt").to(model.device)
+
+    # 3. Instantiate logits_processor per each generate, generate, and print response
+    xgr_logits_processor = xgr.contrib.hf.LogitsProcessor(compiled_grammar)
+    generated_ids = model.generate(
+        **model_inputs, max_new_tokens=512, logits_processor=[xgr_logits_processor]
+    )
+    generated_ids = generated_ids[0][len(model_inputs.input_ids[0]) :]
+    print(tokenizer.decode(generated_ids, skip_special_tokens=True))

 What to Do Next
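
As the commit message notes, the new example was tried on both CUDA and on Mac with `device="cpu"`. A minimal, illustrative tweak (not part of this diff) that lets the same snippet pick its device automatically:

    import torch

    # Fall back to CPU when no CUDA device is present (e.g. on Mac).
    device = "cuda" if torch.cuda.is_available() else "cpu"

The rest of the example is unchanged: `device_map=device` places the model accordingly, and the grammar-constrained generation path works with logits on either device.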
12 changes: 6 additions & 6 deletions examples/hf_transformers/transformers_example.py
@@ -3,20 +3,20 @@
 a minimal LogitsProcessor.
 """

-import xgrammar as xgr
-import torch
-
 from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
+import torch
+import xgrammar as xgr

+device = "cuda"
+# device = "cpu"

-# 0. Instantiate model
-# Or any HF model you want
+# 0. Instantiate with any HF model you want
 model_name = "Qwen/Qwen2.5-0.5B-Instruct"
 # model_name = "microsoft/Phi-3.5-mini-instruct"
 # model_name = "meta-llama/Llama-3.2-1B-Instruct"

 model = AutoModelForCausalLM.from_pretrained(
-    model_name, torch_dtype=torch.float32, device_map="auto"
+    model_name, torch_dtype=torch.float32, device_map=device
 )
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 config = AutoConfig.from_pretrained(model_name)
2 changes: 1 addition & 1 deletion python/xgrammar/apply_token_bitmask_cpu.py
@@ -50,7 +50,7 @@ def apply_token_bitmask_inplace_cpu(
         logits.masked_fill_(~bool_mask, -float("inf"))
     else:
         if not isinstance(indices, torch.Tensor):
-            indices = torch.tensor(indices, dtype=torch.long, device=logits.device)
+            indices = torch.tensor(indices, dtype=torch.int32, device=logits.device)
         len_indices = len(indices)
         if len_indices != bitmask.size(0):
             raise ValueError("The length of indices and bitmask's batch size must match.")
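
For context, a rough sketch of exercising this CPU path directly. The signature `apply_token_bitmask_inplace_cpu(logits, bitmask, indices=None)` and the hand-built mask are assumptions for illustration (the hunk only shows the body); in real use a `GrammarMatcher` fills the bitmask:

    import torch
    from xgrammar.apply_token_bitmask_cpu import apply_token_bitmask_inplace_cpu

    vocab_size = 128
    logits = torch.randn(3, vocab_size)  # CPU logits for a batch of 3 sequences
    # One int32 mask row covering the vocab; all bits set here, purely as a placeholder.
    bitmask = torch.full((1, (vocab_size + 31) // 32), -1, dtype=torch.int32)
    # Apply the single mask row to batch row 1 only; a plain Python list is accepted for
    # `indices` and converted to an int32 tensor internally, as shown in the hunk above.
    apply_token_bitmask_inplace_cpu(logits, bitmask, indices=[1])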
2 changes: 0 additions & 2 deletions python/xgrammar/contrib/hf.py
@@ -74,8 +74,6 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
                 "Expect input_ids.shape[0] to be LogitsProcessor.batch_size."
                 + f"Got {input_ids.shape[0]} for the former, and {self.batch_size} for the latter."
             )
-        if scores.device.type != "cuda":
-            raise RuntimeError("logits must be on CUDA")

        if not self.prefilled:
            # Have not sampled a token yet
9 changes: 6 additions & 3 deletions python/xgrammar/matcher.py
@@ -30,6 +30,8 @@

 bitmask_dtype = torch.int32

+is_cuda_available = torch.cuda.is_available()
+

 def get_bitmask_shape(batch_size: int, vocab_size: int) -> Tuple[int, int]:
     """Return the shape of the bitmask (batch_size, ceil(vocab_size / 32))"""
@@ -45,7 +47,7 @@ def allocate_token_bitmask(batch_size: int, vocab_size: int) -> torch.Tensor:
         return torch.empty(
             xgr.get_bitmask_shape(batch_size, vocab_size),
             dtype=xgr.bitmask_dtype,
-            pin_memory=True,
+            pin_memory=torch.cuda.is_available(),
         )

     Parameters
@@ -63,12 +65,13 @@ def allocate_token_bitmask(batch_size: int, vocab_size: int) -> torch.Tensor:
     Note
     ----
-    This is the default way of allocating a bitmask. You can also customize the implementation.
+    - This is the default way of allocating a bitmask. You can also customize the implementation.
+    - For CUDA, use `pin_memory` in `torch.empty()` to speed up data transfer from CPU to GPU.
     """
     return torch.empty(
         get_bitmask_shape(batch_size, vocab_size),
         dtype=bitmask_dtype,
-        pin_memory=True,
+        pin_memory=is_cuda_available,
     )
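
The added docstring note is the practical reason for this change: the bitmask lives on CPU and is only pinned when CUDA is available. A small sketch of the allocation it documents (sizes are illustrative, not from this diff):

    import xgrammar as xgr

    batch_size, vocab_size = 1, 1024  # illustrative; in practice use config.vocab_size
    bitmask = xgr.allocate_token_bitmask(batch_size, vocab_size)
    print(bitmask.shape, bitmask.dtype)  # torch.Size([1, 32]) torch.int32, i.e. ceil(1024 / 32) columns
    # The matcher fills this CPU-side mask via matcher.fill_next_token_bitmask(bitmask); for CUDA
    # logits it is moved with bitmask.to(logits.device) before xgr.apply_token_bitmask_inplace,
    # as in the Quick Start example above.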
