From b9a16de54e1e0eff58da14c65750414cceaf1a6f Mon Sep 17 00:00:00 2001 From: Charlie Ruan <53290280+CharlieFRuan@users.noreply.github.com> Date: Thu, 21 Nov 2024 18:16:55 -0800 Subject: [PATCH] [HF] Change quick start example to HF, make it run on both CUDA and Mac (#74) This PR uses the `transformers_example.py` as Quick Start example in the doc. Tried on both CUDA and Mac (with `device="cpu"`). Also change `pin_memory=torch.cuda.is_available(),` for `allocate_token_bitmask()` since it is not supported for all devices (e.g. Mac). --- docs/start/quick_start.rst | 67 ++++++++++--------- .../hf_transformers/transformers_example.py | 12 ++-- python/xgrammar/apply_token_bitmask_cpu.py | 2 +- python/xgrammar/contrib/hf.py | 2 - python/xgrammar/matcher.py | 9 ++- 5 files changed, 48 insertions(+), 44 deletions(-) diff --git a/docs/start/quick_start.rst b/docs/start/quick_start.rst index 08be4b1..2f371ad 100644 --- a/docs/start/quick_start.rst +++ b/docs/start/quick_start.rst @@ -6,44 +6,47 @@ Quick Start Example ------- +The easiest way of trying out XGrammar is to use the ``transformers`` library in Python. After :ref:`installing XGrammar `, run the following example to see how XGrammar enables structured generation -- a JSON in this case. .. code:: python - import xgrammar as xgr + from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig import torch - from transformers import AutoTokenizer, AutoConfig - - # Get tokenizer info - model_id = "Qwen/Qwen2.5-1.5B-Instruct" - tokenizer = AutoTokenizer.from_pretrained(model_id) - config = AutoConfig.from_pretrained(model_id) - # This can be larger than tokenizer.vocab_size due to paddings - full_vocab_size = config.vocab_size - tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=full_vocab_size) - - # Compile a JSON grammar - compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8) - compiled_grammar: xgr.CompiledGrammar = compiler.compile_builtin_json_grammar() - - # Instantiate grammar matcher and allocate the bitmask - matcher = xgr.GrammarMatcher(compiled_grammar) - token_bitmask = xgr.allocate_token_bitmask(1, tokenizer_info.vocab_size) - - # Each loop iteration is a simulated auto-regressive step. Here we use - # simulated logits and sampled tokens. In real application, use XGrammar - # in a LLM generation loop and sample with the masked logits. - sim_sampled_response = '{ "library": "xgrammar" }<|endoftext|>' - sim_sampled_token_ids = tokenizer.encode(sim_sampled_response) - for i, sim_token_id in enumerate(sim_sampled_token_ids): - logits = torch.randn(full_vocab_size).cuda() - matcher.fill_next_token_bitmask(token_bitmask) - xgr.apply_token_bitmask_inplace(logits, token_bitmask.to(logits.device)) - assert matcher.accept_token(sim_token_id) - - assert matcher.is_terminated() - matcher.reset() + import xgrammar as xgr + + device = "cuda" # Or "cpu", etc. + # 0. Instantiate with any HF model you want + model_name = "Qwen/Qwen2.5-0.5B-Instruct" + # model_name = "microsoft/Phi-3.5-mini-instruct" + # model_name = "meta-llama/Llama-3.2-1B-Instruct" + model = AutoModelForCausalLM.from_pretrained( + model_name, torch_dtype=torch.float32, device_map=device + ) + tokenizer = AutoTokenizer.from_pretrained(model_name) + config = AutoConfig.from_pretrained(model_name) + + # 1. 
Compile grammar (NOTE: you can substitute this with other grammars like EBNF, JSON Schema) + tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=config.vocab_size) + grammar_compiler = xgr.GrammarCompiler(tokenizer_info) + compiled_grammar = grammar_compiler.compile_builtin_json_grammar() + + # 2. Prepare inputs + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Introduce yourself in JSON briefly."}, + ] + texts = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + model_inputs = tokenizer(texts, return_tensors="pt").to(model.device) + + # 3. Instantiate logits_processor per each generate, generate, and print response + xgr_logits_processor = xgr.contrib.hf.LogitsProcessor(compiled_grammar) + generated_ids = model.generate( + **model_inputs, max_new_tokens=512, logits_processor=[xgr_logits_processor] + ) + generated_ids = generated_ids[0][len(model_inputs.input_ids[0]) :] + print(tokenizer.decode(generated_ids, skip_special_tokens=True)) What to Do Next diff --git a/examples/hf_transformers/transformers_example.py b/examples/hf_transformers/transformers_example.py index 5e588a7..902d7c3 100644 --- a/examples/hf_transformers/transformers_example.py +++ b/examples/hf_transformers/transformers_example.py @@ -3,20 +3,20 @@ a minimal LogitsProcessor. """ -import xgrammar as xgr -import torch - from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig +import torch +import xgrammar as xgr +device = "cuda" +# device = "cpu" -# 0. Instantiate model -# Or any HF model you want +# 0. Instantiate with any HF model you want model_name = "Qwen/Qwen2.5-0.5B-Instruct" # model_name = "microsoft/Phi-3.5-mini-instruct" # model_name = "meta-llama/Llama-3.2-1B-Instruct" model = AutoModelForCausalLM.from_pretrained( - model_name, torch_dtype=torch.float32, device_map="auto" + model_name, torch_dtype=torch.float32, device_map=device ) tokenizer = AutoTokenizer.from_pretrained(model_name) config = AutoConfig.from_pretrained(model_name) diff --git a/python/xgrammar/apply_token_bitmask_cpu.py b/python/xgrammar/apply_token_bitmask_cpu.py index 781b493..06b5d98 100644 --- a/python/xgrammar/apply_token_bitmask_cpu.py +++ b/python/xgrammar/apply_token_bitmask_cpu.py @@ -50,7 +50,7 @@ def apply_token_bitmask_inplace_cpu( logits.masked_fill_(~bool_mask, -float("inf")) else: if not isinstance(indices, torch.Tensor): - indices = torch.tensor(indices, dtype=torch.long, device=logits.device) + indices = torch.tensor(indices, dtype=torch.int32, device=logits.device) len_indices = len(indices) if len_indices != bitmask.size(0): raise ValueError("The length of indices and bitmask's batch size must match.") diff --git a/python/xgrammar/contrib/hf.py b/python/xgrammar/contrib/hf.py index cb71d49..bd23034 100644 --- a/python/xgrammar/contrib/hf.py +++ b/python/xgrammar/contrib/hf.py @@ -74,8 +74,6 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to "Expect input_ids.shape[0] to be LogitsProcessor.batch_size." + f"Got {input_ids.shape[0]} for the former, and {self.batch_size} for the latter." 
) - if scores.device.type != "cuda": - raise RuntimeError("logits must be on CUDA") if not self.prefilled: # Have not sampled a token yet diff --git a/python/xgrammar/matcher.py b/python/xgrammar/matcher.py index c9e2941..dfc604d 100644 --- a/python/xgrammar/matcher.py +++ b/python/xgrammar/matcher.py @@ -30,6 +30,8 @@ bitmask_dtype = torch.int32 +is_cuda_available = torch.cuda.is_available() + def get_bitmask_shape(batch_size: int, vocab_size: int) -> Tuple[int, int]: """Return the shape of the bitmask (batch_size, ceil(vocab_size / 32))""" @@ -45,7 +47,7 @@ def allocate_token_bitmask(batch_size: int, vocab_size: int) -> torch.Tensor: return torch.empty( xgr.get_bitmask_shape(batch_size, vocab_size), dtype=xgr.bitmask_dtype, - pin_memory=True, + pin_memory=torch.cuda.is_available(), ) Parameters @@ -63,12 +65,13 @@ def allocate_token_bitmask(batch_size: int, vocab_size: int) -> torch.Tensor: Note ---- - This is the default way of allocating a bitmask. You can also customize the implementation. + - This is the default way of allocating a bitmask. You can also customize the implementation. + - For CUDA, use `pin_memory` in `torch.empty()` to speed up data transfer from CPU to GPU. """ return torch.empty( get_bitmask_shape(batch_size, vocab_size), dtype=bitmask_dtype, - pin_memory=True, + pin_memory=is_cuda_available, )
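
For reference, here is a minimal, device-agnostic sketch of the lower-level matcher loop that the previous quick start demonstrated, adjusted in the spirit of this patch so it runs on both CUDA and CPU-only machines (e.g. Mac). It only uses APIs that already appear in this diff (``TokenizerInfo.from_huggingface``, ``GrammarCompiler``, ``GrammarMatcher``, ``allocate_token_bitmask``, ``fill_next_token_bitmask``, ``apply_token_bitmask_inplace``); the model choice and the simulated response string are illustrative assumptions carried over from the old example, not part of the patch.

.. code:: python

    import torch
    import xgrammar as xgr
    from transformers import AutoConfig, AutoTokenizer

    # Pick the device the same way allocate_token_bitmask() now gates pin_memory:
    # use CUDA only when a GPU is actually available, otherwise fall back to CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model_id = "Qwen/Qwen2.5-0.5B-Instruct"  # assumption: any HF model works here
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    config = AutoConfig.from_pretrained(model_id)

    # config.vocab_size can be larger than tokenizer.vocab_size due to padding.
    tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=config.vocab_size)
    compiler = xgr.GrammarCompiler(tokenizer_info)
    compiled_grammar = compiler.compile_builtin_json_grammar()

    matcher = xgr.GrammarMatcher(compiled_grammar)
    # With this patch, the bitmask is pinned only when CUDA is available,
    # so the allocation also succeeds on devices without pinned-memory support.
    token_bitmask = xgr.allocate_token_bitmask(1, tokenizer_info.vocab_size)

    # Simulated auto-regressive loop with random logits; in a real application
    # the logits come from the model's forward pass and tokens are sampled
    # from the masked distribution.
    sim_sampled_response = '{ "library": "xgrammar" }<|endoftext|>'
    for token_id in tokenizer.encode(sim_sampled_response):
        logits = torch.randn(config.vocab_size, device=device)
        matcher.fill_next_token_bitmask(token_bitmask)
        xgr.apply_token_bitmask_inplace(logits, token_bitmask.to(logits.device))
        assert matcher.accept_token(token_id)

    assert matcher.is_terminated()
    matcher.reset()

Gating ``pin_memory`` on ``torch.cuda.is_available()`` keeps the faster host-to-device copy when a GPU is present while avoiding the pinned-memory allocation that some backends do not support, which is why the same loop also works with ``device = "cpu"``.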