Skip to content

Commit

Permalink
Update to new XGrammar pre-release APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
CharlieFRuan committed Nov 22, 2024
1 parent 6e23f5d commit 2fed3e8
Showing 1 changed file with 22 additions and 20 deletions.
42 changes: 22 additions & 20 deletions src/llm_chat.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/* eslint-disable @typescript-eslint/no-non-null-assertion */
/* eslint-disable no-prototype-builtins */
import * as tvmjs from "@mlc-ai/web-runtime";
import * as xgrammar from "@mlc-ai/web-xgrammar";
import * as xgr from "@mlc-ai/web-xgrammar";
import log from "loglevel";
import { Tokenizer } from "@mlc-ai/web-tokenizers";
import { ChatConfig, GenerationConfig, Role } from "./config";
Expand Down Expand Up @@ -104,14 +104,17 @@ export class LLMChatPipeline {
// Grammar-related
// A grammar matcher for this current round if response_format is set. Reinitialized upon
// each step regardless of whether the chat is multi-round or not.
private grammarMatcher?: xgrammar.GrammarMatcher = undefined;
private grammarMatcher?: xgr.GrammarMatcher = undefined;
// The current schema or grammar string used for grammarMatcher; if undefined, grammarMatcher is
// simply using JSON mode. We use this field to determine whether to re-instantiate the
// GrammarMatcher or simply reset its state during each round (i.e. during prefillStep).
private schemaOrGrammarStr?: string = undefined;
// Tokenizer info built from the post-processed token table (token strings ordered by their token
// id). Once initialized, it will not be reinitialized, since `this.tokenizer` does not change
// throughout the lifetime of LLMChatPipeline.
private xgTokenizerInfo?: xgrammar.TokenizerInfo = undefined;
private xgTokenizerInfo?: xgr.TokenizerInfo = undefined;
// Compiler for grammar. It is persistent since it specializes on xgTokenizerInfo.
private grammarCompiler?: xgr.GrammarCompiler = undefined;
// Size of the bitmask for grammar, determined by fullVocabSize
private bitmaskSize: number;
// `vocab_size` read from `config.json`. Can be different from the size of the tokenTable for some
// models due to dummy padded tokens.
Expand Down Expand Up @@ -302,6 +305,7 @@ export class LLMChatPipeline {
this.tvm.dispose();
this.tokenizer.dispose();
this.xgTokenizerInfo?.dispose();
this.grammarCompiler?.dispose();
}

/**
Expand Down Expand Up @@ -545,31 +549,29 @@ export class LLMChatPipeline {
log.info("Initialize token table.");
// Post process entire table
const rawTokenTable = getTokenTableFromTokenizer(this.tokenizer);
this.xgTokenizerInfo =
await xgrammar.TokenizerInfo.createTokenizerInfo(
rawTokenTable,
this.token_postproc_method,
this.prepend_space_in_encode,
this.xgTokenizerInfo = await xgr.TokenizerInfo.createTokenizerInfo(
rawTokenTable,
this.token_postproc_method,
this.prepend_space_in_encode,
this.fullVocabSize,
);
this.grammarCompiler =
await xgr.GrammarCompiler.createGrammarCompiler(
this.xgTokenizerInfo,
);
}
const grammar: xgrammar.BNFGrammar =
const grammar: xgr.CompiledGrammar =
curSchemaOrGrammarStr === undefined
? await xgrammar.BuiltinGrammar.json()
? await this.grammarCompiler!.compileBuiltinJSONGrammar()
: genConfig?.response_format?.type === "json_object"
? await xgrammar.BuiltinGrammar.jsonSchema(
? await this.grammarCompiler!.compileJSONSchema(
curSchemaOrGrammarStr,
)
: await xgrammar.BNFGrammar.createBNFGrammar(
: await this.grammarCompiler!.compileGrammar(
curSchemaOrGrammarStr,
);
this.grammarMatcher =
await xgrammar.GrammarMatcher.createGrammarMatcher(
grammar,
this.xgTokenizerInfo,
undefined,
undefined,
this.fullVocabSize,
);
await xgr.GrammarMatcher.createGrammarMatcher(grammar);
grammar.dispose();
this.schemaOrGrammarStr = curSchemaOrGrammarStr;
this.curRoundGrammarInitTotalTime =
Expand Down Expand Up @@ -1034,7 +1036,7 @@ export class LLMChatPipeline {

const tBitmaskStart = performance.now();
const bitMaskOnCPU: Int32Array =
await this.grammarMatcher.findNextTokenBitmask();
await this.grammarMatcher.getNextTokenBitmask();
this.curRoundGrammarPerTokenTotalTime +=
(performance.now() - tBitmaskStart) / 1e3;

Expand Down

0 comments on commit 2fed3e8

Please sign in to comment.