diff --git a/docs/how_to/ebnf_guided_generation.rst b/docs/how_to/ebnf_guided_generation.rst
new file mode 100644
index 0000000..e830cc6
--- /dev/null
+++ b/docs/how_to/ebnf_guided_generation.rst
@@ -0,0 +1,184 @@
+.. _how-to-ebnf-generation:
+
+EBNF-Guided Generation
+======================
+
+XGrammar enables efficient structured generation. Besides JSON, you can use an EBNF
+grammar to guide the generation, providing more flexibility for customization.
+
+We first go over how to use XGrammar in an LLM engine to achieve this in
+:ref:`EBNF-Guided Generation in LLM Engines <how-to-ebnf-generation-engine>`. We then provide
+an end-to-end example of EBNF-guided generation using XGrammar with HF ``transformers`` in
+:ref:`Try out via HF Transformers <how-to-ebnf-generation-HF>`.
+
+Install XGrammar
+~~~~~~~~~~~~~~~~
+
+:ref:`XGrammar ` is available via pip.
+It is always recommended to install it in an isolated conda virtual environment.
+
+
+.. _how-to-ebnf-generation-engine:
+
+EBNF-Guided Generation in LLM Engines
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+In this section, we see how to use XGrammar in an LLM engine to ensure that the output follows
+an EBNF grammar.
+
+All code snippets below are actual runnable code as we simulate the LLM generation.
+
+First, import necessary libraries for the tutorial.
+
+.. code:: python
+
+    import xgrammar as xgr
+    import torch
+    import numpy as np
+    from transformers import AutoTokenizer, AutoConfig
+
+Then, we extract tokenizer info from the LLM we are using with ``xgr.TokenizerInfo``. With
+the ``tokenizer_info``, instantiate an ``xgr.GrammarCompiler`` that will compile a grammar of
+your choice.
+
+.. code:: python
+
+    # Get tokenizer info
+    model_id = "meta-llama/Llama-3.2-1B-Instruct"
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    config = AutoConfig.from_pretrained(model_id)
+    # This can be larger than tokenizer.vocab_size due to paddings
+    full_vocab_size = config.vocab_size
+    tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=full_vocab_size)
+
+    compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)
+
+Then specify an EBNF grammar string. We currently use
+the GBNF format (GGML BNF), with the specification
+`here `__.
+
+
+.. code:: python
+
+    ebnf_grammar_str = """root ::= (expr "=" term)+
+    expr ::= term ([-+*/] term)*
+    term ::= num | "(" expr ")"
+    num ::= [0-9]+"""
+
+    compiled_grammar = compiler.compile_grammar(ebnf_grammar_str)
+
+With the compiled grammar, we can instantiate an ``xgr.GrammarMatcher``, the main construct
+we interact with that maintains the state of the structured generation. We also allocate a
+bitmask that will be used to mask logits.
+
+.. code:: python
+
+    # Instantiate grammar matcher and allocate the bitmask
+    matcher = xgr.GrammarMatcher(compiled_grammar)
+    token_bitmask = xgr.allocate_token_bitmask(1, tokenizer_info.vocab_size)
+
+Now we simulate a single-request auto-regressive generation. See :ref:`how-to-engine-integration`
+for batched inference.
+
+.. code:: python
+
+    # Here we simulate a valid sampled response
+    sim_sampled_response = '(5+3)*2=16<|endoftext|>'
+    sim_sampled_token_ids = tokenizer.encode(sim_sampled_response)
+
+    # Each loop iteration is a simulated auto-regressive step
+    for i, sim_token_id in enumerate(sim_sampled_token_ids):
+        # LLM inference to get logits, here we use randn to simulate.
+        # logits is a tensor of shape (full_vocab_size,) on GPU
+        # logits = LLM.inference()
+        logits = torch.randn(full_vocab_size).cuda()
+
+        # Apply bitmask to logits to mask invalid tokens
+        matcher.fill_next_token_bitmask(token_bitmask)
+        xgr.apply_token_bitmask_inplace(logits, token_bitmask.to(logits.device))
+
+        # Sample next token
+        probs = torch.softmax(logits, dim=-1).cpu().numpy()
+        next_token_id = np.random.choice(list(range(full_vocab_size)), p=probs)
+
+        # Accept token from matcher to update its state, so that the next bitmask
+        # generated will enforce the next token to be generated. Assert to make
+        # sure the token is indeed valid. Here we accept the simulated response
+        # assert matcher.accept_token(next_token_id)
+        assert matcher.accept_token(sim_token_id)
+
+    # Since we accepted a stop token `<|endoftext|>`, we have terminated
+    assert matcher.is_terminated()
+
+    # Reset to be ready for the next auto-regressive generation
+    matcher.reset()
+
+
+
+.. _how-to-ebnf-generation-HF:
+
+Try out via HF Transformers
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+XGrammar can be easily integrated with HF ``transformers`` using a ``LogitsProcessor``. Note that
+this integration mainly aims for accessibility and may introduce extra overhead.
+
+First, instantiate a model, a tokenizer, and inputs.
+
+.. code:: python
+
+    import xgrammar as xgr
+
+    import torch
+    from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
+
+    device = "cuda"  # Or "cpu", etc.
+    model_name = "meta-llama/Llama-3.2-1B-Instruct"
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name, torch_dtype=torch.float32, device_map=device
+    )
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    config = AutoConfig.from_pretrained(model_name)
+
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Introduce yourself in JSON briefly."},
+    ]
+    texts = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    model_inputs = tokenizer(texts, return_tensors="pt").to(model.device)
+
+
+Then construct a ``GrammarCompiler`` and compile the grammar.
+
+.. code:: python
+
+    tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=config.vocab_size)
+    grammar_compiler = xgr.GrammarCompiler(tokenizer_info)
+    # An EBNF grammar string that describes generic JSON
+    json_grammar_ebnf_str = r"""
+    root ::= basic_array | basic_object
+    basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object
+    basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"?
+    basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)?
+    basic_string ::= (([\"] basic_string_1 [\"]))
+    basic_string_1 ::= "" | [^"\\\x00-\x1F] basic_string_1 | "\\" escape basic_string_1
+    escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9]
+    basic_boolean ::= "true" | "false"
+    basic_null ::= "null"
+    basic_array ::= "[" ("" | ws basic_any (ws "," ws basic_any)*) ws "]"
+    basic_object ::= "{" ("" | ws basic_string ws ":" ws basic_any ( ws "," ws basic_string ws ":" ws basic_any)*) ws "}"
+    ws ::= [ \n\t]*
+    """
+    compiled_grammar = grammar_compiler.compile_grammar(json_grammar_ebnf_str)
+
+
+Finally, use ``LogitsProcessor`` to generate with grammar.
+
+..
code:: python + + xgr_logits_processor = xgr.contrib.hf.LogitsProcessor(compiled_grammar) + generated_ids = model.generate( + **model_inputs, max_new_tokens=512, logits_processor=[xgr_logits_processor] + ) + generated_ids = generated_ids[0][len(model_inputs.input_ids[0]) :] + print(tokenizer.decode(generated_ids, skip_special_tokens=True)) diff --git a/docs/how_to/engine_integration.rst b/docs/how_to/engine_integration.rst new file mode 100644 index 0000000..adef9ed --- /dev/null +++ b/docs/how_to/engine_integration.rst @@ -0,0 +1,244 @@ +.. _how-to-engine-integration: + +Integration with LLM Engine +=========================== + +XGrammar enables efficient structured generation. In this tutorial, we go over the key components +of XGrammar and how to integrate XGrammar into an LLM engine. + +We first lay out the concepts in :ref:`High-Level Flow `. +We then demonstrate how XGrammar enables +:ref:`Structured Generation for Batched Inference `. + +The code snippets below are actual runnable code as we simulate the LLM generation. + + +Install XGrammar +---------------- + +:ref:`XGrammar ` is available via pip. +It is always recommended to install it in an isolated conda virtual environment. + + +.. _how-to-engine-integration-flow: + +High-Level Flow +--------------- + +In this section, we go over the key components of XGrammar when integrating it into an LLM engine +for structured generation. + +First, import necessary libraries for the tutorial. + +.. code:: python + + import xgrammar as xgr + import torch + import numpy as np + from transformers import AutoTokenizer, AutoConfig + +xgr.TokenizerInfo +^^^^^^^^^^^^^^^^^ + +``xgr.TokenizerInfo`` is a per-model construct that encapsulates tokenizer information, including +all its vocabulary. There are several ways of instantiating it, and the most convenient way +is using an ``AutoTokenizer``. Note that for some models, ``AutoConfig.vocab_size`` can be larger +than ``AutoTokenizer.vocab_size`` due to paddings, with the former being the shape of the model's +logits. To be safe, always pass in the former when instantiating ``xgr.TokenizerInfo``. + +.. code:: python + + # Get tokenizer info + model_id = "meta-llama/Llama-3.2-1B-Instruct" + tokenizer = AutoTokenizer.from_pretrained(model_id) + config = AutoConfig.from_pretrained(model_id) + # This can be larger than tokenizer.vocab_size due to paddings + full_vocab_size = config.vocab_size + tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=full_vocab_size) + + +xgr.GrammarCompiler +^^^^^^^^^^^^^^^^^^^ + +With an ``xgr.TokenizerInfo``, we can instantiate an ``xgr.GrammarCompiler``. This is a construct +that compiles a grammar according to the model's tokenizer info. Therefore, for each model, you +can use the same ``xgr.GrammarCompiler`` persistently, as it can compile different grammars for +the same ``xgr.TokenizerInfo``. Note that the ``compiler`` behavior can be configured with +``max_threads`` for multithreading, and ``enable_cache`` (defaults to true) for caching +compiled grammars. + +.. code:: python + + compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8) + + +xgr.CompiledGrammar +^^^^^^^^^^^^^^^^^^^ + +Then, using the ``xgr.GrammarCompiler``, we can compile a grammar, with the result being an +``xgr.CompiledGrammar``. Here we use a built-in JSON grammar. For other grammars, see +:ref:`how-to-json-generation` and :ref:`how-to-ebnf-generation`. +Every thing we have seen up to now are per-model (rather than per-generation). + +.. 
code:: python + + compiled_grammar: xgr.CompiledGrammar = compiler.compile_builtin_json_grammar() + +xgr.GrammarMatcher +^^^^^^^^^^^^^^^^^^ + +With the compiled grammar, we can instantiate a ``xgr.GrammarMatcher``. It is the main construct +an LLM engine interacts with that maintains the state of the structured generation. Note that +each request should have its own ``xgr.GrammarMatcher`` since each has a different generation state, +as we will see in :ref:`how-to-engine-integration-batched`. + +.. code:: python + + # Instantiate grammar matcher with the compiled grammar + matcher = xgr.GrammarMatcher(compiled_grammar) + +Auto-regressive Generation with xgr.GrammarMatcher +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Now we simulate a single-request auto-regressive generation. See later section for +:ref:`how-to-engine-integration-batched`. + +First, we pre-allocate a token bitmask with ``xgr.allocate_token_bitmask()``, +which is essentially a ``torch.Tensor`` of shape ``(batch_size, vocab_size)``. You can also +use your own implementation for allocating a bitmask. + +In each auto-regressive step, we fill the token bitmask according to the current state +of the matcher with ``xgr.GrammarMatcher.fill_next_token_bitmask()``. Then, we apply the bitmask +into the model's logits with ``xgr.apply_token_bitmask_inplace()``, which calls a CUDA kernel +if ``logits`` is on CUDA (recommended), otherwise a CPU implementation. + +After masking, the logits for illegal tokens are set to negative infinity, so that +we will never sample them. After sampling the token, update the ``xgr.GrammarMatcher``'s state with +``xgr.GrammarMatcher.accept_token()``. Finally, use ``xgr.GrammarMatcher.reset()`` to prepare +for the next generation. + +.. code:: python + + # Here we simulate a valid sampled response + sim_sampled_response = '{ "library": "xgrammar" }<|endoftext|>' + sim_sampled_token_ids = tokenizer.encode(sim_sampled_response) + + # Allocate a token bitmask + token_bitmask = xgr.allocate_token_bitmask(1, tokenizer_info.vocab_size) + + # Each loop iteration is a simulated auto-regressive step + for i, sim_token_id in enumerate(sim_sampled_token_ids): + # LLM inference to get logits, here we use randn to simulate. + # logits is a tensor of shape (full_vocab_size,) on GPU + # logits = LLM.inference() + logits = torch.randn(full_vocab_size).cuda() + + # Apply bitmask to logits to mask invalid tokens + matcher.fill_next_token_bitmask(token_bitmask) + xgr.apply_token_bitmask_inplace(logits, token_bitmask.to(logits.device)) + + # Sample next token + probs = torch.softmax(logits, dim=-1).cpu().numpy() + next_token_id = np.random.choice(list(range(full_vocab_size)), p=probs) + + # Accept token from matcher to update its state, so that the next bitmask + # generated will enforce the next token to be generated. Assert to make + # sure the token is indeed valid. Here we accept the simulated response + # assert matcher.accept_token(next_token_id) + assert matcher.accept_token(sim_token_id) + + # Since we accepted a stop token `<|endoftext|>`, we have terminated + assert matcher.is_terminated() + + # Reset to be ready for the next auto-regressive generation + matcher.reset() + + +.. _how-to-engine-integration-batched: + +Structured Generation for Batched Inference +------------------------------------------- + +The code snippets above assume a single request generation. +This section demonstrates how the same concept works with batched generation. 
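+Before walking through the batched code, it may help to see the per-step grammar work from the
+single-request loop above factored into small helpers; the batched loop below applies exactly the
+same operations once per request. This is only an illustrative sketch: the helper names are not
+part of the XGrammar API, and it reuses the ``matcher``, ``token_bitmask``, and ``full_vocab_size``
+defined in the previous section.
+
+.. code:: python
+
+    def constrain_logits(matcher: xgr.GrammarMatcher, bitmask: torch.Tensor, logits: torch.Tensor) -> None:
+        # Fill the bitmask from the matcher's current state, then mask invalid tokens in place
+        matcher.fill_next_token_bitmask(bitmask)
+        xgr.apply_token_bitmask_inplace(logits, bitmask.to(logits.device))
+
+    def advance(matcher: xgr.GrammarMatcher, token_id: int) -> bool:
+        # Feed the sampled token back into the matcher; returns True once a stop token is accepted
+        assert matcher.accept_token(token_id)
+        return matcher.is_terminated()
+
+    # One simulated decoding step, reusing the single-request objects from above
+    logits = torch.randn(full_vocab_size).cuda()
+    constrain_logits(matcher, token_bitmask, logits)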
+ +First, follow the exact same steps above for the per-model constructs +``xgr.TokenizerInfo`` and ``xgr.GrammarCompiler``. Say each request needs +to generate a valid JSON. + +.. code:: python + + import xgrammar as xgr + import torch + import numpy as np + from transformers import AutoTokenizer, AutoConfig + + # Get tokenizer info + model_id = "meta-llama/Llama-3.2-1B-Instruct" + tokenizer = AutoTokenizer.from_pretrained(model_id) + config = AutoConfig.from_pretrained(model_id) + # This can be larger than tokenizer.vocab_size due to paddings + full_vocab_size = config.vocab_size + tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=full_vocab_size) + + # Compile a JSON grammar + compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8) + compiled_grammar: xgr.CompiledGrammar = compiler.compile_builtin_json_grammar() + +Now, we need to maintain an ``xgr.GrammarMatcher`` for each request in the batch, since +each has a different generation state. Note that each request in the batch can follow a different +``xgr.CompiledGrammar``, but here for simplicity, they are all just following the general +JSON grammar. + +.. code:: python + + batch_size = 2 + matchers = [ + xgr.GrammarMatcher(compiled_grammar) + for i in range(batch_size) + ] + token_bitmask = xgr.allocate_token_bitmask(batch_size, tokenizer_info.vocab_size) + +We simulate an auto-regressive generation of batched inference. Note that here we +assume the generation lengths of the two requests are the same for simplicity. But +it should be easy to generalize based on how your engine supports batched inference. +The key difference from single-request generation is that, in batched-request generation, +each request has its own ``xgr.GrammarMatcher`` to maintain. + +.. code:: python + + sim_sampled_responses = ['{"name": "a"}<|endoftext|>', '{"name": "b"}<|endoftext|>'] + sim_sampled_token_ids = [tokenizer.encode(response) for response in sim_sampled_responses] + + # Each loop iteration is a simulated auto-regressive step + for loop_iter in range(len(sim_sampled_token_ids[0])): + # LLM batched inference to get logits, here we use randn to simulate + # Now, logits is a tensor of shape (batch_size, full_vocab_size) on GPU + # logits = LLM.inference() + logits = torch.randn(batch_size, full_vocab_size).cuda() + + # This for loop is parallelizable using threading.Thread. But estimate + # the overhead in your engine. + for i in range(batch_size): + matchers[i].fill_next_token_bitmask(token_bitmask, i) + xgr.apply_token_bitmask_inplace(logits, token_bitmask.to(logits.device)) + + # Sample next token + probs = torch.softmax(logits, dim=-1).cpu().numpy() + next_token_ids = [ + np.random.choice(list(range(full_vocab_size)), p=probs[i]) + for i in range(batch_size) + ] + + # Update the matcher for each request + for i in range(batch_size): + # Here we accept the simulated response + # assert matchers[i].accept_token(next_token_ids[i]) + matchers[i].accept_token(sim_sampled_token_ids[i][loop_iter]) + + # In our simulated case, all requests should have terminated since we accepted + # a stop token `<|endoftext|>` + for i in range(batch_size): + assert matchers[i].is_terminated() + # Reset to be ready for the next generation + matchers[i].reset() diff --git a/docs/how_to/json_generation.rst b/docs/how_to/json_generation.rst new file mode 100644 index 0000000..329fd28 --- /dev/null +++ b/docs/how_to/json_generation.rst @@ -0,0 +1,205 @@ +.. 
_how-to-json-generation:
+
+JSON Generation
+===============
+
+XGrammar enables efficient structured generation. Example structures include JSON and JSON Schema.
+In this tutorial, we go over how to use XGrammar to ensure that an LLM's output is
+valid JSON, or adheres to a customized JSON schema.
+
+We first go over how to use XGrammar in an LLM engine to achieve this in
+:ref:`JSON Generation in LLM Engines <how-to-json-generation-engine>`. We then provide
+an end-to-end JSON generation example using XGrammar with HF ``transformers`` in
+:ref:`Try out via HF Transformers <how-to-json-generation-HF>`.
+
+Install XGrammar
+~~~~~~~~~~~~~~~~
+
+:ref:`XGrammar ` is available via pip.
+It is always recommended to install it in an isolated conda virtual environment.
+
+
+.. _how-to-json-generation-engine:
+
+JSON Generation in LLM Engines
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+In this section, we see how to use XGrammar in an LLM engine to ensure that the output is
+always valid JSON.
+
+All code snippets below are actual runnable code as we simulate the LLM generation.
+
+First, import necessary libraries for the tutorial.
+
+.. code:: python
+
+    import xgrammar as xgr
+    import torch
+    import numpy as np
+    from transformers import AutoTokenizer, AutoConfig
+
+Then, we extract tokenizer info from the LLM we are using with ``xgr.TokenizerInfo``. With
+the ``tokenizer_info``, instantiate an ``xgr.GrammarCompiler`` that will compile a grammar of
+your choice.
+
+.. code:: python
+
+    # Get tokenizer info
+    model_id = "meta-llama/Llama-3.2-1B-Instruct"
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    config = AutoConfig.from_pretrained(model_id)
+    # This can be larger than tokenizer.vocab_size due to paddings
+    full_vocab_size = config.vocab_size
+    tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=full_vocab_size)
+
+    compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)
+
+For JSON generation, there are generally three options for compiling the grammar: using the
+built-in JSON grammar, specifying a JSON schema with a Pydantic model, or providing a JSON
+schema string. Pick one of the three below to run.
+
+.. code:: python
+
+    # Option 1: Compile with a built-in JSON grammar
+    compiled_grammar: xgr.CompiledGrammar = compiler.compile_builtin_json_grammar()
+
+.. code:: python
+
+    # Option 2: Compile with JSON schema from a pydantic model
+    from pydantic import BaseModel
+
+    class Person(BaseModel):
+        name: str
+        age: int
+
+    compiled_grammar = compiler.compile_json_schema(Person)
+
+.. code:: python
+
+    # Option 3: Compile with JSON schema from a JSON schema string
+    import json
+
+    person_schema = {
+        "title": "Person",
+        "type": "object",
+        "properties": {
+            "name": {
+                "type": "string"
+            },
+            "age": {
+                "type": "integer",
+            }
+        },
+        "required": ["name", "age"]
+    }
+    compiled_grammar = compiler.compile_json_schema(json.dumps(person_schema))
+
+With the compiled grammar, we can instantiate an ``xgr.GrammarMatcher``, the main construct
+we interact with that maintains the state of the structured generation. We also allocate a
+bitmask that will be used to mask logits.
+
+.. code:: python
+
+    # Instantiate grammar matcher and allocate the bitmask
+    matcher = xgr.GrammarMatcher(compiled_grammar)
+    token_bitmask = xgr.allocate_token_bitmask(1, tokenizer_info.vocab_size)
+
+Now we simulate a single-request auto-regressive generation. See :ref:`how-to-engine-integration`
+for batched inference.
+
+.. code:: python
+
+    # Here we simulate a valid sampled response
+    sim_sampled_response = '{ "library": "xgrammar" }<|endoftext|>'
+    sim_sampled_token_ids = tokenizer.encode(sim_sampled_response)
+
+    # Each loop iteration is a simulated auto-regressive step
+    for i, sim_token_id in enumerate(sim_sampled_token_ids):
+        # LLM inference to get logits, here we use randn to simulate.
+        # logits is a tensor of shape (full_vocab_size,) on GPU
+        # logits = LLM.inference()
+        logits = torch.randn(full_vocab_size).cuda()
+
+        # Apply bitmask to logits to mask invalid tokens
+        matcher.fill_next_token_bitmask(token_bitmask)
+        xgr.apply_token_bitmask_inplace(logits, token_bitmask.to(logits.device))
+
+        # Sample next token
+        probs = torch.softmax(logits, dim=-1).cpu().numpy()
+        next_token_id = np.random.choice(list(range(full_vocab_size)), p=probs)
+
+        # Accept token from matcher to update its state, so that the next bitmask
+        # generated will enforce the next token to be generated. Assert to make
+        # sure the token is indeed valid. Here we accept the simulated response
+        # assert matcher.accept_token(next_token_id)
+        assert matcher.accept_token(sim_token_id)
+
+    # Since we accepted a stop token `<|endoftext|>`, we have terminated
+    assert matcher.is_terminated()
+
+    # Reset to be ready for the next auto-regressive generation
+    matcher.reset()
+
+
+
+.. _how-to-json-generation-HF:
+
+Try out via HF Transformers
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+XGrammar can be easily integrated with HF ``transformers`` using a ``LogitsProcessor``. Note that
+this integration mainly aims for accessibility and may introduce extra overhead.
+
+First, instantiate a model, a tokenizer, and inputs.
+
+.. code:: python
+
+    import xgrammar as xgr
+
+    import torch
+    from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
+
+    device = "cuda"  # Or "cpu", etc.
+    model_name = "meta-llama/Llama-3.2-1B-Instruct"
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name, torch_dtype=torch.float32, device_map=device
+    )
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    config = AutoConfig.from_pretrained(model_name)
+
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Introduce yourself in JSON with two fields: name and age."},
+    ]
+    texts = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    model_inputs = tokenizer(texts, return_tensors="pt").to(model.device)
+
+
+Then construct a ``GrammarCompiler`` and compile the grammar.
+
+.. code:: python
+
+    tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=config.vocab_size)
+    grammar_compiler = xgr.GrammarCompiler(tokenizer_info)
+    # Option 1: Compile with a built-in JSON grammar
+    # compiled_grammar = grammar_compiler.compile_builtin_json_grammar()
+    # Option 2: Compile with JSON schema from a pydantic model
+    from pydantic import BaseModel
+
+    class Person(BaseModel):
+        name: str
+        age: int
+
+    compiled_grammar = grammar_compiler.compile_json_schema(Person)
+
+
+Finally, use ``LogitsProcessor`` to generate with grammar.
+
+..
code:: python + + xgr_logits_processor = xgr.contrib.hf.LogitsProcessor(compiled_grammar) + generated_ids = model.generate( + **model_inputs, max_new_tokens=512, logits_processor=[xgr_logits_processor] + ) + generated_ids = generated_ids[0][len(model_inputs.input_ids[0]) :] + print(tokenizer.decode(generated_ids, skip_special_tokens=True)) diff --git a/docs/index.rst b/docs/index.rst index dc5abfb..6ea3c72 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -24,11 +24,12 @@ Check out :ref:`quick-start` for quick start examples of using XGrammar. .. toctree:: :maxdepth: 1 - :caption: Tutorials + :caption: How To :hidden: - tutorials/structured_generation.rst - .. tutorials/backend_integration.rst .. TODO + how_to/json_generation.rst + how_to/ebnf_guided_generation.rst + how_to/engine_integration.rst .. tutorials/web_sdk.rst .. TODO diff --git a/docs/start/quick_start.rst b/docs/start/quick_start.rst index fe9fa66..f8245c7 100644 --- a/docs/start/quick_start.rst +++ b/docs/start/quick_start.rst @@ -23,7 +23,7 @@ Instantiate a model, a tokenizer, and inputs. from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig device = "cuda" # Or "cpu", etc. - model_name = "meta-llama/Llama-3.1-8B-Instruct" + model_name = "meta-llama/Llama-3.2-1B-Instruct" model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=torch.float32, device_map=device ) @@ -75,5 +75,5 @@ Use logits_processor to generate with grammar. What to Do Next --------------- -- Check out :ref:`tutorial-structured-generation` for the detailed usage guide of XGrammar. +- Check out :ref:`how-to-json-generation` and other How-To guides for the detailed usage guide of XGrammar. - Report any problem or ask any question: open new issues in our `GitHub repo `_. diff --git a/docs/tutorials/structured_generation.rst b/docs/tutorials/structured_generation.rst deleted file mode 100644 index f798476..0000000 --- a/docs/tutorials/structured_generation.rst +++ /dev/null @@ -1,388 +0,0 @@ -.. _tutorial-structured-generation: - -Structured Generation -====================== - -XGrammar enables efficient structured generation. In this tutorial, we go over how to -use XGrammar to ensure that an LLM's output adheres to the structure of a valid JSON, a -customized JSON schema, and a customized EBNF grammar string. - -We first lay out the concepts by going over :ref:`JSON generation ` -in detail. Then we go over how to generate with :ref:`customized JSON schemas ` -and :ref:`customized EBNF grammar strings `. Finally, we demonstrate -how xgrammar works with :ref:`batched inference `. - -Therefore, we encourage you to start with :ref:`JSON Generation `. - -The code snippets below are actual runnable code as we simulate the LLM generation. - - -Install XGrammar -~~~~~~~~~~~~~~~~ - -:ref:`XGrammar ` is available via pip. -It is always recommended to install it in an isolated conda virtual environment. - - -.. _tutorial-json-generation: - -JSON Generation -~~~~~~~~~~~~~~~ - -In this section, we see how to use XGrammar to ensure that an LLM's output is -always a valid JSON. - -First, import necessary libraries for the tutorial. - -.. code:: python - - import xgrammar as xgr - import torch - import numpy as np - from transformers import AutoTokenizer, AutoConfig - -Then, we extract tokenizer info from the LLM we are using with ``xgr.TokenizerInfo`` - -.. 
code:: python - - # Get tokenizer info - model_id = "Qwen/Qwen2.5-1.5B-Instruct" - tokenizer = AutoTokenizer.from_pretrained(model_id) - config = AutoConfig.from_pretrained(model_id) - # This can be larger than tokenizer.vocab_size due to paddings - full_vocab_size = config.vocab_size - tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=full_vocab_size) - -With the ``tokenizer_info``, instantiate ``xgr.GrammarCompiler`` that compiles a -grammar of your choice. Here we use a JSON grammar. Note that the ``compiler`` behavior -can be configured with ``max_threads`` for multithreading, and ``enable_cache`` (defaults to -true) for caching compiled grammars. Note that every thing we have seen up to now are per-model (rather -than per-generation). - -.. code:: python - - compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8) - compiled_grammar: xgr.CompiledGrammar = compiler.compile_builtin_json_grammar() - -With the compiled grammar, we can instantiate a ``xgr.GrammarMatcher``, the main construct -we interact with that maintains the state of the structured generation. We also allocate a -bitmask that will be used to mask logits. - -.. code:: python - - # Instantiate grammar matcher and allocate the bitmask - matcher = xgr.GrammarMatcher(compiled_grammar) - token_bitmask = xgr.allocate_token_bitmask(1, tokenizer_info.vocab_size) - -Now we simulate a single-request auto-regressive generation. See later section for :ref:`batched generation `. - -.. code:: python - - # Here we simulate a valid sampled response - sim_sampled_response = '{ "library": "xgrammar" }<|endoftext|>' - sim_sampled_token_ids = tokenizer.encode(sim_sampled_response) - - # Each loop iteration is a simulated auto-regressive step - for i, sim_token_id in enumerate(sim_sampled_token_ids): - # LLM inference to get logits, here we use randn to simulate. - # logits is a tensor of shape (full_vocab_size,) on GPU - # logits = LLM.inference() - logits = torch.randn(full_vocab_size).cuda() - - # Apply bitmask to logits to mask invalid tokens - matcher.fill_next_token_bitmask(token_bitmask) - xgr.apply_token_bitmask_inplace(logits, token_bitmask.to(logits.device)) - - # Sample next token - probs = torch.softmax(logits, dim=-1).cpu().numpy() - next_token_id = np.random.choice(list(range(full_vocab_size)), p=probs) - - # Accept token from matcher to update its state, so that the next bitmask - # generated will enforce the next token to be generated. Assert to make - # sure the token is indeed valid. Here we accept the simulated response - # assert matcher.accept_token(next_token_id) - assert matcher.accept_token(sim_token_id) - - # Since we accepted a stop token `<|endoftext|>`, we have terminated - assert matcher.is_terminated() - - # Reset to be ready for the next auto-regressive generation - matcher.reset() - - -.. _tutorial-json-schema-generation: - -JSON Schema Guided Generation -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -In this section, we see how to use XGrammar to generate an output that adheres -to a customized JSON schema. - -The flow is almost identical to the one above, except that the ``CompiledGrammar`` -is compiled based on the JSON schema, rather than being compiled with a generic JSON grammar. - -First, set up the tokenizer info and the grammar compiler as above. - -.. 
code:: python - - import xgrammar as xgr - import torch - import numpy as np - from transformers import AutoTokenizer, AutoConfig - - # Get tokenizer info - model_id = "Qwen/Qwen2.5-1.5B-Instruct" - tokenizer = AutoTokenizer.from_pretrained(model_id) - config = AutoConfig.from_pretrained(model_id) - # This can be larger than tokenizer.vocab_size due to paddings - full_vocab_size = config.vocab_size - tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=full_vocab_size) - - compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8) - -Now, to compile a grammar from a JSON schema, there are generically two methods: from a Pydantic model, -or from a JSON schema string. The two code snippets below are functionally identical, pick one to run. - -.. code:: python - - # Method 1. Compile with a pydantic model - from pydantic import BaseModel - - class Person(BaseModel): - name: str - age: int - - compiled_grammar = compiler.compile_json_schema(Person) - -.. code:: python - - # Method 2. Compile with a JSON schema string - import json - - person_schema = { - "title": "Person", - "type": "object", - "properties": { - "name": { - "type": "string" - }, - "age": { - "type": "integer", - } - }, - "required": ["name", "age"] - } - compiled_grammar = compiler.compile_json_schema(json.dumps(person_schema)) - - -Then, the remaining steps are identical to before, except that we now use a different -``xgr.CompiledGrammar`` and have a different simulated valid generation. - -.. code:: python - - # Instantiate grammar matcher and allocate the bitmask - matcher = xgr.GrammarMatcher(compiled_grammar) - token_bitmask = xgr.allocate_token_bitmask(1, tokenizer_info.vocab_size) - - # Here we simulate a valid sampled response - sim_sampled_response = '{"name": "xgrammar", "age": 0}<|endoftext|>' - sim_sampled_token_ids = tokenizer.encode(sim_sampled_response) - - # Each loop iteration is a simulated auto-regressive step - for i, sim_token_id in enumerate(sim_sampled_token_ids): - # LLM inference to get logits, here we use randn to simulate. - # logits is a tensor of shape (full_vocab_size,) on GPU - # logits = LLM.inference() - logits = torch.randn(full_vocab_size).cuda() - - # Apply bitmask to logits to mask invalid tokens - matcher.fill_next_token_bitmask(token_bitmask) - xgr.apply_token_bitmask_inplace(logits, token_bitmask.to(logits.device)) - - # Sample next token - probs = torch.softmax(logits, dim=-1).cpu().numpy() - next_token_id = np.random.choice(list(range(full_vocab_size)), p=probs) - - # Accept token from matcher to update its state, so that the next bitmask - # generated will enforce the next token to be generated. Assert to make - # sure the token is indeed valid. Here we accept the simulated response - # assert matcher.accept_token(next_token_id) - assert matcher.accept_token(sim_token_id) - - # Since we accepted a stop token `<|endoftext|>`, we have terminated - assert matcher.is_terminated() - - # Reset to be ready for the next auto-regressive generation - matcher.reset() - - -.. _tutorial-ebnf-generation: - -EBNF Guided Generation -~~~~~~~~~~~~~~~~~~~~~~~ - -XGrammar also enables generation that adheres to a customized EBNF grammar string. We currently use -the GBNF format (GGML BNF), with the specification `here `__. - -The code is largely identical to above, except that the ``CompiledGrammar`` is now compiled with -the provided EBNF grammar string. - -First, set up the tokenizer info and the grammar compiler as above. - -.. 
code:: python - - import xgrammar as xgr - import torch - import numpy as np - from transformers import AutoTokenizer, AutoConfig - - # Get tokenizer info - model_id = "Qwen/Qwen2.5-1.5B-Instruct" - tokenizer = AutoTokenizer.from_pretrained(model_id) - config = AutoConfig.from_pretrained(model_id) - # This can be larger than tokenizer.vocab_size due to paddings - full_vocab_size = config.vocab_size - tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=full_vocab_size) - - compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8) - -Now, compile ``CompiledGrammar`` with your EBNF grammar string. - -.. code:: python - - ebnf_grammar_str = """root ::= (expr "=" term)+ - expr ::= term ([-+*/] term)* - term ::= num | "(" expr ")" - num ::= [0-9]+""" - - compiled_grammar = compiler.compile_grammar(ebnf_grammar_str) - -Then, the remaining steps are identical to before, except that we now use a different -``xgr.CompiledGrammar`` and have a different simulated valid generation. - -.. code:: python - - # Instantiate grammar matcher and allocate the bitmask - matcher = xgr.GrammarMatcher(compiled_grammar) - token_bitmask = xgr.allocate_token_bitmask(1, tokenizer_info.vocab_size) - - # Here we simulate a valid sampled response - sim_sampled_response = '(5+3)*2=16<|endoftext|>' - sim_sampled_token_ids = tokenizer.encode(sim_sampled_response) - - # Each loop iteration is a simulated auto-regressive step - for i, sim_token_id in enumerate(sim_sampled_token_ids): - # LLM inference to get logits, here we use randn to simulate. - # logits is a tensor of shape (full_vocab_size,) on GPU - # logits = LLM.inference() - logits = torch.randn(full_vocab_size).cuda() - - # Apply bitmask to logits to mask invalid tokens - matcher.fill_next_token_bitmask(token_bitmask) - xgr.apply_token_bitmask_inplace(logits, token_bitmask.to(logits.device)) - - # Sample next token - probs = torch.softmax(logits, dim=-1).cpu().numpy() - next_token_id = np.random.choice(list(range(full_vocab_size)), p=probs) - - # Accept token from matcher to update its state, so that the next bitmask - # generated will enforce the next token to be generated. Assert to make - # sure the token is indeed valid. Here we accept the simulated response - # assert matcher.accept_token(next_token_id) - assert matcher.accept_token(sim_token_id) - - # Since we accepted a stop token `<|endoftext|>`, we have terminated - assert matcher.is_terminated() - - # Reset to be ready for the next auto-regressive generation - matcher.reset() - - -.. _tutorial-batched-inference: - -Structured Generation for Batched Inference -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -All the code snippets above assume a single request generation. -This section demonstrates how the same concept works with batched generation. - -First, follow the exact same steps above for the per-model constructs -``xgr.TokenizerInfo`` and ``xgr.GrammarCompiler``. Say each request needs -to generate a valid JSON. - -.. 
code:: python - - import xgrammar as xgr - import torch - import numpy as np - from transformers import AutoTokenizer, AutoConfig - - # Get tokenizer info - model_id = "Qwen/Qwen2.5-1.5B-Instruct" - tokenizer = AutoTokenizer.from_pretrained(model_id) - config = AutoConfig.from_pretrained(model_id) - # This can be larger than tokenizer.vocab_size due to paddings - full_vocab_size = config.vocab_size - tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=full_vocab_size) - - # Compile a JSON grammar - compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8) - compiled_grammar: xgr.CompiledGrammar = compiler.compile_builtin_json_grammar() - -Now, we need to maintain an ``xgr.GrammarMatcher`` for each request in the batch, since -each has a different generation state. Note that each request in the batch can follow a different -``xgr.CompiledGrammar``, but here for simplicity, they are all just following the general -JSON grammar. - -.. code:: python - - batch_size = 2 - matchers = [ - xgr.GrammarMatcher(compiled_grammar) - for i in range(batch_size) - ] - token_bitmask = xgr.allocate_token_bitmask(batch_size, tokenizer_info.vocab_size) - -We simulate an auto-regressive generation of batched inference. Note that here we -assume the generation lengths of the two requests are the same for simplicity. But -it should be easy to generalize based on how your engine supports batched inference. -The key difference from single-request generation is that, in batched-request generation, -each request has its own ``xgr.GrammarMatcher`` to maintain. - -.. code:: python - - sim_sampled_responses = ['{"name": "a"}<|endoftext|>', '{"name": "b"}<|endoftext|>'] - sim_sampled_token_ids = [tokenizer.encode(response) for response in sim_sampled_responses] - - # Each loop iteration is a simulated auto-regressive step - for loop_iter in range(len(sim_sampled_token_ids[0])): - # LLM batched inference to get logits, here we use randn to simulate - # Now, logits is a tensor of shape (batch_size, full_vocab_size) on GPU - # logits = LLM.inference() - logits = torch.randn(batch_size, full_vocab_size).cuda() - - # This for loop is parallelizable using threading.Thread. But estimate - # the overhead in your engine. - for i in range(batch_size): - matchers[i].fill_next_token_bitmask(token_bitmask, i) - xgr.apply_token_bitmask_inplace(logits, token_bitmask.to(logits.device)) - - # Sample next token - probs = torch.softmax(logits, dim=-1).cpu().numpy() - next_token_ids = [ - np.random.choice(list(range(full_vocab_size)), p=probs[i]) - for i in range(batch_size) - ] - - # Update the matcher for each request - for i in range(batch_size): - # Here we accept the simulated response - # assert matchers[i].accept_token(next_token_ids[i]) - matchers[i].accept_token(sim_sampled_token_ids[i][loop_iter]) - - # In our simulated case, all requests should have terminated since we accepted - # a stop token `<|endoftext|>` - for i in range(batch_size): - assert matchers[i].is_terminated() - # Reset to be ready for the next generation - matchers[i].reset() diff --git a/version.py b/version.py index aacc812..f494468 100644 --- a/version.py +++ b/version.py @@ -20,7 +20,7 @@ # --------------------------------------------------- -__version__ = "0.0.5" +__version__ = "0.1.0" PROJ_ROOT = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))