diff --git a/examples/json-mode/src/json_mode.ts b/examples/json-mode/src/json_mode.ts index d1549d4e..c60951eb 100644 --- a/examples/json-mode/src/json_mode.ts +++ b/examples/json-mode/src/json_mode.ts @@ -12,7 +12,10 @@ async function main() { const initProgressCallback = (report: webllm.InitProgressReport) => { setLabel("init-label", report.text); }; - const selectedModel = "Llama-3.1-8B-Instruct-q4f32_1-MLC"; + // Pick any one of these models to start trying -- most models in WebLLM support grammar + const selectedModel = "Llama-3.2-3B-Instruct-q4f16_1-MLC"; + // const selectedModel = "Qwen2.5-1.5B-Instruct-q4f16_1-MLC"; + // const selectedModel = "Phi-3.5-mini-instruct-q4f16_1-MLC"; const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine( selectedModel, { initProgressCallback: initProgressCallback }, diff --git a/examples/json-schema/src/json_schema.ts b/examples/json-schema/src/json_schema.ts index d078ec38..dea87582 100644 --- a/examples/json-schema/src/json_schema.ts +++ b/examples/json-schema/src/json_schema.ts @@ -37,9 +37,14 @@ async function simpleStructuredTextExample() { const initProgressCallback = (report: webllm.InitProgressReport) => { setLabel("init-label", report.text); }; + + // Pick any one of these models to start trying -- most models in WebLLM support grammar + // const selectedModel = "Llama-3.2-3B-Instruct-q4f16_1-MLC"; + // const selectedModel = "Qwen2.5-1.5B-Instruct-q4f16_1-MLC"; + const selectedModel = "Phi-3.5-mini-instruct-q4f16_1-MLC"; const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine( - "Llama-3.1-8B-Instruct-q4f16_1-MLC", - { initProgressCallback: initProgressCallback }, + selectedModel, + { initProgressCallback: initProgressCallback, logLevel: "INFO" }, ); // Note that you'd need to prompt the model to answer in JSON either in @@ -106,9 +111,14 @@ async function harryPotterExample() { setLabel("init-label", report.text); }; + // Pick any one of these models to start trying -- most models in WebLLM support grammar + const selectedModel = "Llama-3.2-3B-Instruct-q4f16_1-MLC"; + // const selectedModel = "Qwen2.5-1.5B-Instruct-q4f16_1-MLC"; + // const selectedModel = "Phi-3.5-mini-instruct-q4f16_1-MLC"; + const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine( - "Llama-3.1-8B-Instruct-q4f16_1-MLC", - { initProgressCallback: initProgressCallback }, + selectedModel, + { initProgressCallback: initProgressCallback, logLevel: "INFO" }, ); // Note that you'd need to prompt the model to answer in JSON either in @@ -134,6 +144,7 @@ async function harryPotterExample() { console.log(reply); console.log("Output:\n" + (await engine.getMessage())); console.log(reply.usage); + console.log(reply.usage!.extra); } async function functionCallingExample() { @@ -214,10 +225,64 @@ async function functionCallingExample() { console.log(reply.usage); } +async function ebnfGrammarExample() { + // You can directly define an EBNFGrammar string with ResponseFormat.grammar + const jsonGrammarStr = String.raw` +root ::= basic_array | basic_object +basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object +basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? +basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? 
+basic_string ::= (([\"] basic_string_1 [\"])) +basic_string_1 ::= "" | [^"\\\x00-\x1F] basic_string_1 | "\\" escape basic_string_1 +escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] +basic_boolean ::= "true" | "false" +basic_null ::= "null" +basic_array ::= "[" ("" | ws basic_any (ws "," ws basic_any)*) ws "]" +basic_object ::= "{" ("" | ws basic_string ws ":" ws basic_any ( ws "," ws basic_string ws ":" ws basic_any)*) ws "}" +ws ::= [ \n\t]* +`; + + const initProgressCallback = (report: webllm.InitProgressReport) => { + setLabel("init-label", report.text); + }; + + // Pick any one of these models to start trying -- most models in WebLLM support grammar + const selectedModel = "Llama-3.2-3B-Instruct-q4f16_1-MLC"; + // const selectedModel = "Qwen2.5-1.5B-Instruct-q4f16_1-MLC"; + // const selectedModel = "Phi-3.5-mini-instruct-q4f16_1-MLC"; + const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine( + selectedModel, + { initProgressCallback: initProgressCallback, logLevel: "INFO" }, + ); + + // Note that you'd need to prompt the model to answer in JSON either in + // user's message or the system prompt + const request: webllm.ChatCompletionRequest = { + stream: false, // works with streaming, logprobs, top_logprobs as well + messages: [ + { + role: "user", + content: "Introduce yourself in JSON", + }, + ], + max_tokens: 128, + response_format: { + type: "grammar", + grammar: jsonGrammarStr, + } as webllm.ResponseFormat, + }; + + const reply0 = await engine.chatCompletion(request); + console.log(reply0); + console.log("Output:\n" + (await engine.getMessage())); + console.log(reply0.usage); +} + async function main() { // await simpleStructuredTextExample(); - // await harryPotterExample(); - await functionCallingExample(); + await harryPotterExample(); + // await functionCallingExample(); + // await ebnfGrammarExample(); } main(); diff --git a/package-lock.json b/package-lock.json index 70b60412..142bf7b4 100644 --- a/package-lock.json +++ b/package-lock.json @@ -14,6 +14,7 @@ "devDependencies": { "@mlc-ai/web-runtime": "0.18.0-dev2", "@mlc-ai/web-tokenizers": "^0.1.5", + "@mlc-ai/web-xgrammar": "../xgrammar/web", "@next/eslint-plugin-next": "^14.2.3", "@rollup/plugin-commonjs": "^20.0.0", "@rollup/plugin-node-resolve": "^13.0.4", @@ -39,6 +40,29 @@ "typescript": "^4.9.5" } }, + "../xgrammar/web": { + "name": "@mlc-ai/web-xgrammar", + "version": "0.1.0", + "dev": true, + "license": "Apache-2.0", + "devDependencies": { + "@jest/globals": "^29.7.0", + "@mlc-ai/web-tokenizers": "^0.1.5", + "@rollup/plugin-commonjs": "^20.0.0", + "@rollup/plugin-node-resolve": "^13.0.4", + "@rollup/plugin-wasm": "^5.1.2", + "@types/jest": "^29.5.12", + "@typescript-eslint/eslint-plugin": "^5.59.6", + "@typescript-eslint/parser": "^5.59.6", + "eslint": "^8.41.0", + "jest": "^29.7.0", + "rollup": "^2.56.2", + "rollup-plugin-typescript2": "^0.34.1", + "ts-jest": "^29.2.5", + "tslib": "^2.3.1", + "typescript": "^4.9.5" + } + }, "node_modules/@ampproject/remapping": { "version": "2.3.0", "resolved": "https://registry.npmjs.org/@ampproject/remapping/-/remapping-2.3.0.tgz", @@ -1232,6 +1256,10 @@ "integrity": "sha512-G7vjJzZyOFJvAfx42kPEU7Z2hkAAGWvKJHfMTLdvY8QDLFvvvVOwmEk89Mh+7PBVPpcfh3PW0npTTupFnLMwHw==", "dev": true }, + "node_modules/@mlc-ai/web-xgrammar": { + "resolved": "../xgrammar/web", + "link": true + }, "node_modules/@next/eslint-plugin-next": { "version": "14.2.13", "resolved": "https://registry.npmjs.org/@next/eslint-plugin-next/-/eslint-plugin-next-14.2.13.tgz", 
diff --git a/package.json b/package.json index 348239f6..8404bbb8 100644 --- a/package.json +++ b/package.json @@ -51,6 +51,7 @@ "ts-jest": "^29.1.2", "tslib": "^2.3.1", "@mlc-ai/web-runtime": "0.18.0-dev2", + "@mlc-ai/web-xgrammar": "../xgrammar/web", "typescript": "^4.9.5" }, "dependencies": { diff --git a/src/config.ts b/src/config.ts index c0352d70..33b335c5 100644 --- a/src/config.ts +++ b/src/config.ts @@ -127,6 +127,7 @@ export interface MLCEngineConfig { export interface GenerationConfig { // Only used in MLC repetition_penalty?: number; + ignore_eos?: boolean; // Shared by MLC and OpenAI APIs top_p?: number | null; temperature?: number | null; diff --git a/src/conversation.ts b/src/conversation.ts index 25b362ec..8813df12 100644 --- a/src/conversation.ts +++ b/src/conversation.ts @@ -257,10 +257,12 @@ export class Conversation { } getStopStr(): string[] { - if (this.config.stop_str.length > 0) { - return this.config.stop_str; - } - return [this.config.seps[this.config.seps.length - 1]]; + // TODO(Charlie): Is this needed? + // if (this.config.stop_str.length > 0) { + // return this.config.stop_str; + // } + // return [this.config.seps[this.config.seps.length - 1]]; + return this.config.stop_str; } getStopTokens() { diff --git a/src/engine.ts b/src/engine.ts index 82df2127..6fe3f577 100644 --- a/src/engine.ts +++ b/src/engine.ts @@ -465,6 +465,7 @@ export class MLCEngine implements MLCEngineInterface { pipeline: LLMChatPipeline, chatConfig: ChatConfig, genConfig: GenerationConfig, + timeReceived: number, ): AsyncGenerator; asyncGenerate( request: CompletionCreateParamsStreaming, @@ -472,6 +473,7 @@ export class MLCEngine implements MLCEngineInterface { pipeline: LLMChatPipeline, chatConfig: ChatConfig, genConfig: GenerationConfig, + timeReceived: number, ): AsyncGenerator; async *asyncGenerate( request: ChatCompletionRequestStreaming | CompletionCreateParamsStreaming, @@ -479,6 +481,7 @@ export class MLCEngine implements MLCEngineInterface { pipeline: LLMChatPipeline, chatConfig: ChatConfig, genConfig: GenerationConfig, + timeReceived: number, ): AsyncGenerator { // Since it is an async generator, we need to do fine-grained try-catch to ensure lock is // released only when errors occur. Then release at the very end when no error occurs. @@ -678,18 +681,39 @@ export class MLCEngine implements MLCEngineInterface { // 4. 
Usage chunk if (request.stream_options?.include_usage) { + const usedGrammar = + "response_format" in request && + (request.response_format?.type === "grammar" || + request.response_format?.type === "json_object"); const completion_tokens = pipeline.getCurRoundDecodingTotalTokens(); const prompt_tokens = pipeline.getCurRoundPrefillTotalTokens(); const prefill_tokens_per_s = pipeline.getCurRoundPrefillTokensPerSec(); const decode_tokens_per_s = pipeline.getCurRoundDecodingTokensPerSec(); + const grammar_init_s = pipeline.getCurRoundGrammarInitTotalTime(); + const prefill_time = pipeline.getCurRoundPrefillTotalTime(); + const decode_time = pipeline.getCurRoundDecodingTotalTime(); + const grammar_per_token_s = + pipeline.getCurRoundGrammarPerTokenTotalTime(); + const defaultExtra = { + e2e_latency_s: (Date.now() - timeReceived) / 1000, + prefill_tokens_per_s: prefill_tokens_per_s, + decode_tokens_per_s: decode_tokens_per_s, + time_to_first_token_s: prefill_time, + time_per_output_token_s: decode_time / completion_tokens, + }; const usage: CompletionUsage = { completion_tokens: completion_tokens, prompt_tokens: prompt_tokens, total_tokens: completion_tokens + prompt_tokens, - extra: { - prefill_tokens_per_s: prefill_tokens_per_s, - decode_tokens_per_s: decode_tokens_per_s, - }, + extra: usedGrammar + ? { + ...defaultExtra, + ...{ + grammar_init_s: grammar_init_s, + grammar_per_token_s: grammar_per_token_s / completion_tokens, + }, + } + : defaultExtra, }; if (isChatCompletion) { const usageChunk: ChatCompletionChunk = { @@ -745,6 +769,7 @@ export class MLCEngine implements MLCEngineInterface { async chatCompletion( request: ChatCompletionRequest, ): Promise | ChatCompletion> { + const timeReceived = Date.now(); // 0. Check model loaded and preprocess inputs const [selectedModelId, selectedPipeline, selectedChatConfig] = this.getLLMStates("ChatCompletionRequest", request.model); @@ -766,6 +791,7 @@ export class MLCEngine implements MLCEngineInterface { logprobs: request.logprobs, top_logprobs: request.top_logprobs, response_format: request.response_format, + ignore_eos: request.ignore_eos, }; // 0.5 Block wait until this pipeline finishes all previous requests @@ -780,6 +806,7 @@ export class MLCEngine implements MLCEngineInterface { selectedPipeline, selectedChatConfig, genConfig, + timeReceived, ); } @@ -796,6 +823,8 @@ export class MLCEngine implements MLCEngineInterface { let prompt_tokens = 0; let prefill_time = 0; let decode_time = 0; + let grammar_init_s = 0; + let grammar_per_token_s = 0; for (let i = 0; i < n; i++) { let outputMessage: string; if (this.interruptSignal) { @@ -852,8 +881,21 @@ export class MLCEngine implements MLCEngineInterface { prompt_tokens += selectedPipeline.getCurRoundPrefillTotalTokens(); prefill_time += selectedPipeline.getCurRoundPrefillTotalTime(); decode_time += selectedPipeline.getCurRoundDecodingTotalTime(); + grammar_init_s += selectedPipeline.getCurRoundGrammarInitTotalTime(); + grammar_per_token_s += + selectedPipeline.getCurRoundGrammarPerTokenTotalTime(); } - + const usedGrammar = + "response_format" in request && + (request.response_format?.type === "grammar" || + request.response_format?.type === "json_object"); + const defaultExtra = { + e2e_latency_s: (Date.now() - timeReceived) / 1000, + prefill_tokens_per_s: prompt_tokens / prefill_time, + decode_tokens_per_s: completion_tokens / decode_time, + time_to_first_token_s: prefill_time, + time_per_output_token_s: decode_time / completion_tokens, + }; const response: ChatCompletion = { id: 
crypto.randomUUID(), choices: choices, @@ -864,10 +906,15 @@ export class MLCEngine implements MLCEngineInterface { completion_tokens: completion_tokens, prompt_tokens: prompt_tokens, total_tokens: completion_tokens + prompt_tokens, - extra: { - prefill_tokens_per_s: prompt_tokens / prefill_time, - decode_tokens_per_s: completion_tokens / decode_time, - }, + extra: usedGrammar + ? { + ...defaultExtra, + ...{ + grammar_init_s: grammar_init_s, + grammar_per_token_s: grammar_per_token_s / completion_tokens, + }, + } + : defaultExtra, } as CompletionUsage, }; @@ -901,6 +948,8 @@ export class MLCEngine implements MLCEngineInterface { async completion( request: CompletionCreateParams, ): Promise | Completion> { + const timeReceived = Date.now(); + // 0. Check model loaded and preprocess inputs const [selectedModelId, selectedPipeline, selectedChatConfig] = this.getLLMStates("CompletionCreateParams", request.model); @@ -915,6 +964,7 @@ export class MLCEngine implements MLCEngineInterface { logit_bias: request.logit_bias, logprobs: request.logprobs, top_logprobs: request.top_logprobs, + ignore_eos: request.ignore_eos, }; // 0.5 Block wait until this pipeline finishes all previous requests @@ -929,6 +979,7 @@ export class MLCEngine implements MLCEngineInterface { selectedPipeline, selectedChatConfig, genConfig, + timeReceived, ); } @@ -989,8 +1040,11 @@ export class MLCEngine implements MLCEngineInterface { prompt_tokens: prompt_tokens, total_tokens: completion_tokens + prompt_tokens, extra: { + e2e_latency_s: (Date.now() - timeReceived) / 1000, prefill_tokens_per_s: prompt_tokens / prefill_time, decode_tokens_per_s: completion_tokens / decode_time, + time_to_first_token_s: prefill_time, + time_per_output_token_s: decode_time / completion_tokens, }, } as CompletionUsage, }; diff --git a/src/error.ts b/src/error.ts index 9e8a45ab..08d9454b 100644 --- a/src/error.ts +++ b/src/error.ts @@ -411,6 +411,16 @@ export class InvalidResponseFormatError extends Error { } } +export class InvalidResponseFormatGrammarError extends Error { + constructor() { + super( + "When ResponseFormat.type is `grammar`, ResponseFormat.grammar needs to be specified.\n" + + "When ResponseFormat.grammar is specified, ResponseFormat.type needs to be grammar.", + ); + this.name = "InvalidResponseFormatGrammarError"; + } +} + export class CustomResponseFormatError extends Error { constructor(currentFormat: any) { super( diff --git a/src/grammar.ts b/src/grammar.ts deleted file mode 100644 index 0614cb9b..00000000 --- a/src/grammar.ts +++ /dev/null @@ -1,203 +0,0 @@ -import * as tvmjs from "@mlc-ai/web-runtime"; - -export type BNFGrammar = tvmjs.TVMObject; -export type GrammarStateMatcher = tvmjs.TVMObject; - -/** - * A factory class for generating and calling GrammarStateMatcher (GrammarSM) and BNFGrammar related - * methods, essentially a wrapper of related global functions in the tvm instance's wasm. - * - * We implement a factory class rather than having classes of GrammarStateMatcher and BNFGrammar - * because factory class allows us to only get/dispose PackedFunc once -- especially when we need - * multiple instances of BNFGrammar or GrammarStateMatcher. 
- */ -export class GrammarFactory { - private fBNFGrammarGetGrammarOfJSON: tvmjs.PackedFunc; - private fBNFGrammarFromSchema: tvmjs.PackedFunc; - private fGrammarSMFromTokenTable: tvmjs.PackedFunc; - private fGrammarSMAcceptToken: tvmjs.PackedFunc; - private fGrammarSMFindNextTokenBitmaskAsNDArray: tvmjs.PackedFunc; - private fGrammarSMIsTerminated: tvmjs.PackedFunc; - private fGrammarSMResetState: tvmjs.PackedFunc; - - /** - * Extract TVM global functions from tvm runtime instance. - * - * @param tvm An instantiated tvm runtime instance. - */ - constructor(tvm: tvmjs.Instance) { - tvm.beginScope(); - // Get global functions. - this.fBNFGrammarGetGrammarOfJSON = tvm.detachFromCurrentScope( - tvm.getGlobalFunc("mlc.grammar.BNFGrammarGetGrammarOfJSON"), - ); - this.fBNFGrammarFromSchema = tvm.detachFromCurrentScope( - tvm.getGlobalFunc("mlc.grammar.BNFGrammarFromSchema"), - ); - this.fGrammarSMFromTokenTable = tvm.detachFromCurrentScope( - tvm.getGlobalFunc("mlc.grammar.GrammarStateMatcherFromTokenTable"), - ); - this.fGrammarSMAcceptToken = tvm.detachFromCurrentScope( - tvm.getGlobalFunc("mlc.grammar.GrammarStateMatcherAcceptToken"), - ); - this.fGrammarSMFindNextTokenBitmaskAsNDArray = tvm.detachFromCurrentScope( - tvm.getGlobalFunc( - "mlc.grammar.GrammarStateMatcherFindNextTokenBitmaskAsNDArray", - ), - ); - this.fGrammarSMIsTerminated = tvm.detachFromCurrentScope( - tvm.getGlobalFunc("mlc.grammar.GrammarStateMatcherIsTerminated"), - ); - this.fGrammarSMResetState = tvm.detachFromCurrentScope( - tvm.getGlobalFunc("mlc.grammar.GrammarStateMatcherResetState"), - ); - tvm.endScope(); - } - - /** - * @returns BNFGrammar of JSON. - * @note Caller needs to handle disposal of returned object. - */ - getBNFGrammarOfJSON(): BNFGrammar { - return this.fBNFGrammarGetGrammarOfJSON() as BNFGrammar; - } - - /** - * Construct a BNF grammar from the json schema string. The schema string should be in the format - * of the schema of a JSON file. We will parse the schema and generate a BNF grammar. - * - * @param schema The schema string. - * @param indent The number of spaces for indentation. If undefined, the grammar will enforce the - * output to be in one line. - * @param separators Two separators that will be enforced by the grammar: comma and colon. - * Examples: (",", ":"), (", ", ": "). If undefined, the default separators will be used: - * (",", ": ") when the indent is not undefined, and (", ", ": ") otherwise. This follows the - * convention in Python's json.dumps(). - * @param strictMode Whether to use strict mode. In strict mode, the generated grammar will not - * allow properties and items that is not specified in the schema. This is equivalent to - * setting unevaluatedProperties and unevaluatedItems to false. - * - * @note Caller needs to handle disposal of returned object. - */ - getBNFGrammarFromSchema( - schema_str: string, - indent?: number, - separators?: [string, string], - strictMode = true, - ): BNFGrammar { - // Convert indent to tvmjs.Scalar - let indentInput: tvmjs.Scalar | undefined; - if (indent !== undefined && indent !== null) { - indentInput = new tvmjs.Scalar(indent, "int32"); - } - // Convert strictMode to tvmjs.Scalar - const strictModeInput = strictMode - ? new tvmjs.Scalar(1, "int32") - : new tvmjs.Scalar(0, "int32"); - - return this.fBNFGrammarFromSchema( - schema_str, - indentInput, - separators, - strictModeInput, - ) as BNFGrammar; - } - - /** - * Creates a Grammar State Matcher from a specified BNFGrammar rule and a token table. 
- * - * @param grammar A BNFGrammar used to specify the rule for the state matcher. - * @param tokenTable A list of all tokens in the tokenizer in the order of their ids, post processed. - * @param maxRollbackSteps Max rollback steps to support. Currently not supported, has to be zero. - * @returns A Grammar state matcher - * @note Caller needs to handle disposal of returned object. - */ - getGrammarStateMatcherFromTokenTable( - grammar: BNFGrammar, - tokenTable: tvmjs.TVMObject, - maxRollbackSteps = 0, - ): GrammarStateMatcher { - if (maxRollbackSteps !== 0) { - throw Error( - "maxRollbackSteps has to be zero as rollback is not supported yet.", - ); - } - return this.fGrammarSMFromTokenTable( - grammar, - tokenTable, - new tvmjs.Scalar(maxRollbackSteps, "int32"), - ) as GrammarStateMatcher; - } - - /** - * Accept a new token to the grammar state matcher, updating its internal state. - * - * @param grammarStateMatcher The grammar state matcher that will accept a new token and update - * its state correspondingly. - * @param tokenID The token to be accepted in its ID. - * @returns Whether the token is accepted. - */ - acceptToken( - grammarStateMatcher: GrammarStateMatcher, - tokenID: number, - ): boolean { - let accepted = false; - try { - accepted = this.fGrammarSMAcceptToken( - grammarStateMatcher, - new tvmjs.Scalar(tokenID, "int32"), - /*verbose=*/ new tvmjs.Scalar(0, "int32"), - ); - } catch (error) { - throw Error( - "Encountered error when accepting token " + tokenID + ": " + error, - ); - } - return accepted; - } - - /** - * Returns a bitmask in the form of an NDArray of shape (max_num_token, ceildiv(vocab_size, 32)) - * based on what tokens can/cannot be accepted by the current state of the grammar state matcher. - * - * @param grammarStateMatcher The grammar state matcher that will produce the bit mask. - * @param fullVocabSize The vocab size read from `config.json`, used to calculate size of bitmask. - * @returns A bitmask in the form of an NDArray. - */ - findNextTokenBitmask( - grammarStateMatcher: GrammarStateMatcher, - fullVocabSize: number, - ): tvmjs.TVMObject { - return this.fGrammarSMFindNextTokenBitmaskAsNDArray( - grammarStateMatcher, - new tvmjs.Scalar(fullVocabSize, "int32"), - ); - } - - /** - * @returns Whether the grammar state matcher has reached the end and hence terminated. - */ - isTerminated(grammarStateMatcher: GrammarStateMatcher): boolean { - return this.fGrammarSMIsTerminated(grammarStateMatcher); - } - - /** - * Reset the state of matcher to the initial state. - */ - resetState(grammarStateMatcher: GrammarStateMatcher): void { - this.fGrammarSMResetState(grammarStateMatcher); - } - - /** - * Dispose all tvmjs.PackedFunc this factory is initialized with. 
- */ - dispose() { - this.fBNFGrammarGetGrammarOfJSON.dispose(); - this.fBNFGrammarFromSchema.dispose(); - this.fGrammarSMFromTokenTable.dispose(); - this.fGrammarSMAcceptToken.dispose(); - this.fGrammarSMFindNextTokenBitmaskAsNDArray.dispose(); - this.fGrammarSMIsTerminated.dispose(); - this.fGrammarSMResetState.dispose(); - } -} diff --git a/src/llm_chat.ts b/src/llm_chat.ts index 65589189..0451da3a 100644 --- a/src/llm_chat.ts +++ b/src/llm_chat.ts @@ -1,6 +1,7 @@ /* eslint-disable @typescript-eslint/no-non-null-assertion */ /* eslint-disable no-prototype-builtins */ import * as tvmjs from "@mlc-ai/web-runtime"; +import * as xgr from "@mlc-ai/web-xgrammar"; import log from "loglevel"; import { Tokenizer } from "@mlc-ai/web-tokenizers"; import { ChatConfig, GenerationConfig, Role } from "./config"; @@ -21,7 +22,6 @@ import { ResponseFormat, ChatCompletionContentPartImage, } from "./openai_api_protocols/index"; -import { BNFGrammar, GrammarFactory, GrammarStateMatcher } from "./grammar"; import { AttentionSinkSizeError, ContextWindowSizeExceededError, @@ -50,7 +50,6 @@ export class LLMChatPipeline { private image_embed: tvmjs.PackedFunc | undefined; private embed: tvmjs.PackedFunc; private fapplyBitmask: tvmjs.PackedFunc; - private fpostProcessTokenTable: tvmjs.PackedFunc; // Functions related to PagedKVCache private fclearKVCaches: tvmjs.PackedFunc; private fKVCacheAddSequence: tvmjs.PackedFunc; @@ -103,24 +102,32 @@ export class LLMChatPipeline { private logitProcessor?: LogitProcessor = undefined; // Grammar-related - // A factory to instantiate and maintain the BNF grammars and grammar state matchers. - private grammarFactory: GrammarFactory; - // A grammar state matcher for this current round if response_format is set. Reinitialized upon + // A grammar matcher for this current round if response_format is set. Reinitialized upon // each step regardless of whether the chat is multi-round or not. - private grammarStateMatcher?: GrammarStateMatcher = undefined; - // The current schema used for grammarStateMatcher; if undefined, grammarStateMatcher is simply - // using JSON mode. We use this field to determine whether we re-initiate a GrammarStateMatcher + private grammarMatcher?: xgr.GrammarMatcher = undefined; + // The current schema or grammar string used for grammarMatcher; if undefined, grammarMatcher is + // simply using JSON mode. We use this field to determine whether we re-initiate a GrammarMatcher // or simply reset the state during each round (i.e. during prefillStep). - private schema?: string = undefined; + private schemaOrGrammarStr?: string = undefined; // A string list of tokens ordered by their token id, post-processed. Once initialized, will not // be reinitialized since `this.tokenizer` does not change throughout the lifetime of LLMChatPipeline. - private tokenTable?: tvmjs.TVMObject = undefined; + private xgTokenizerInfo?: xgr.TokenizerInfo = undefined; + // Compiler for grammar. It is persistent since it specializes on xgTokenizerInfo. + private grammarCompiler?: xgr.GrammarCompiler = undefined; + // Size of the bitmask for grammar, determined by fullVocabSize private bitmaskSize: number; // `vocab_size` read from `config.json`. Can be different from the size of the tokenTable for some // models due to dummy padded tokens. private fullVocabSize: number; // Method to post process the token for grammar; either "byte_level" or default "byte_fallback". 
private token_postproc_method: string; + // Whether to prepend space for grammar + private prepend_space_in_encode: boolean; + // Stats for grammar-related overhead + // Time to initialize grammar matcher in seconds + private curRoundGrammarInitTotalTime = 0; + // Total time of getting the next bitmask and accepting tokens in seconds + private curRoundGrammarPerTokenTotalTime = 0; constructor( tvm: tvmjs.Instance, @@ -133,7 +140,6 @@ export class LLMChatPipeline { this.tokenizer = tokenizer; this.config = config; this.logitProcessor = logitProcessor; - this.grammarFactory = new GrammarFactory(tvm); this.fullVocabSize = this.config.vocab_size; this.bitmaskSize = Math.ceil(this.fullVocabSize / 32); @@ -150,18 +156,22 @@ export class LLMChatPipeline { // fallback mechanisms if (config.tokenizer_info !== undefined) { this.token_postproc_method = config.tokenizer_info.token_postproc_method; + this.prepend_space_in_encode = + config.tokenizer_info.prepend_space_in_encode; } else if (config.token_table_postproc_method !== undefined) { this.token_postproc_method = config.token_table_postproc_method; + this.prepend_space_in_encode = false; } else { log.warn( "Cannot find `tokenizer_info` or `token_table_postproc_method` in `mlc-chat-config.json`, " + - "using default token_postproc_method `byte_fallback`.\n" + - "Models that should not use `byte_fallback` include: Llama3, Qwen1.5-1.8B, StableLM-zerphyr-1.6B.\n" + + "using default token_postproc_method `raw`.\n" + "This field is only used for json mode.", ); - this.token_postproc_method = "byte_fallback"; + this.token_postproc_method = "raw"; + this.prepend_space_in_encode = false; } log.info("token_postproc_method: ", this.token_postproc_method); + log.info("prepend_space_in_encode: ", this.prepend_space_in_encode); this.device = this.tvm.webgpu(); @@ -180,9 +190,6 @@ export class LLMChatPipeline { this.fapplyBitmask = this.tvm.detachFromCurrentScope( this.vm.getFunction("apply_bitmask_inplace"), ); - this.fpostProcessTokenTable = this.tvm.detachFromCurrentScope( - tvm.getGlobalFunc("mlc.tokenizers.PostProcessTokenTable"), - ); try { this.image_embed = this.tvm.detachFromCurrentScope( this.vm.getFunction("image_embed"), @@ -285,8 +292,7 @@ export class LLMChatPipeline { dispose() { // TODO: Do we need to dispose all PackedFuncs here? - this.grammarFactory.dispose(); - this.grammarStateMatcher?.dispose(); + this.grammarMatcher?.dispose(); this.params.dispose(); this.decoding.dispose(); this.prefill.dispose(); @@ -298,7 +304,8 @@ export class LLMChatPipeline { this.logitsOnCPU?.dispose(); this.tvm.dispose(); this.tokenizer.dispose(); - this.tokenTable?.dispose(); + this.xgTokenizerInfo?.dispose(); + this.grammarCompiler?.dispose(); } /** @@ -399,6 +406,21 @@ export class LLMChatPipeline { return this.curRoundPrefillTotalTime; } + /** + * @returns the time (seconds) spent on initializing the grammar matcher for a single request. + */ + getCurRoundGrammarInitTotalTime(): number { + return this.curRoundGrammarInitTotalTime; + } + + /** + * @returns the total time (seconds) the grammar matcher spent on creating bitmasks and accepting tokens + * for all the generated tokens in a single request. + */ + getCurRoundGrammarPerTokenTotalTime(): number { + return this.curRoundGrammarPerTokenTotalTime; + } + /** * @returns Runtime stats information.
*/ @@ -453,6 +475,8 @@ export class LLMChatPipeline { */ setConversation(newConv: Conversation) { this.conversation = newConv; + this.stopStr = this.conversation.getStopStr(); + this.stopTokens = this.conversation.getStopTokens(); } async asyncLoadWebGPUPipelines() { @@ -488,9 +512,75 @@ export class LLMChatPipeline { this.curRoundPrefillTotalTokens = 0; this.curRoundPrefillTotalTime = 0; this.curRoundDecodingTotalTime = 0; + this.curRoundGrammarInitTotalTime = 0; + this.curRoundGrammarPerTokenTotalTime = 0; this.stopTriggered = false; const conversation = this.conversation; + // -1. Instantiate grammar matcher according to generation config. This step is overlapped + // with prefilling the prompt to hide overhead by using this promise. + let grammarMatcherInitPromise: Promise | undefined = undefined; + if ( + genConfig?.response_format?.type === "json_object" || + genConfig?.response_format?.type === "grammar" + ) { + const curSchemaOrGrammarStr = + genConfig.response_format.schema || genConfig.response_format.grammar; + if ( + curSchemaOrGrammarStr === this.schemaOrGrammarStr && + this.grammarMatcher + ) { + // If we did not change the schema and have instantiated a GrammarMatcher, we reuse it. + const tGrammarInitStart = performance.now(); + log.info("Reuse grammar matcher."); + this.grammarMatcher.reset(); + this.curRoundGrammarInitTotalTime = + (performance.now() - tGrammarInitStart) / 1e3; + } else { + // Else dispose current grammarMatcher, reinitialize, and update this.schema. + /* eslint-disable no-async-promise-executor */ + grammarMatcherInitPromise = new Promise(async (resolve) => { + const tGrammarInitStart = performance.now(); + log.info("Initialize new grammar matcher."); + if (this.grammarMatcher) { + this.grammarMatcher.dispose(); + } + if (this.xgTokenizerInfo === undefined) { + log.info("Initialize token table."); + // Post process entire table + const rawTokenTable = getTokenTableFromTokenizer(this.tokenizer); + this.xgTokenizerInfo = await xgr.TokenizerInfo.createTokenizerInfo( + rawTokenTable, + this.token_postproc_method, + this.prepend_space_in_encode, + this.fullVocabSize, + ); + this.grammarCompiler = + await xgr.GrammarCompiler.createGrammarCompiler( + this.xgTokenizerInfo, + ); + } + const grammar: xgr.CompiledGrammar = + curSchemaOrGrammarStr === undefined + ? await this.grammarCompiler!.compileBuiltinJSONGrammar() + : genConfig?.response_format?.type === "json_object" + ? await this.grammarCompiler!.compileJSONSchema( + curSchemaOrGrammarStr, + ) + : await this.grammarCompiler!.compileGrammar( + curSchemaOrGrammarStr, + ); + this.grammarMatcher = + await xgr.GrammarMatcher.createGrammarMatcher(grammar); + grammar.dispose(); + this.schemaOrGrammarStr = curSchemaOrGrammarStr; + this.curRoundGrammarInitTotalTime = + (performance.now() - tGrammarInitStart) / 1e3; + resolve(); + }); + } + } + // 0. Get inputData from conversation if (conversation.isTextCompletion) { conversation.prompt = inp; @@ -543,45 +633,11 @@ export class LLMChatPipeline { ); } } - - // 3. Instantiate grammar state matcher according to generation config - if (genConfig?.response_format?.type === "json_object") { - const curSchema = genConfig.response_format.schema; - if (curSchema === this.schema && this.grammarStateMatcher) { - // If we did not change the schema and have instantiated a GrammarStateMatcher, we reuse it. - this.grammarFactory.resetState(this.grammarStateMatcher); - } else { - // Else dispose current grammarStateMatcher, reinitialize, and update this.schema. 
- if (this.grammarStateMatcher) { - this.grammarStateMatcher.dispose(); - } - if (this.tokenTable === undefined) { - const rawTokenTable = getTokenTableFromTokenizer(this.tokenizer); - // Post process entire table - this.tokenTable = this.tvm.detachFromCurrentScope( - this.fpostProcessTokenTable( - rawTokenTable, - this.token_postproc_method, - ), - ); - } - const grammar: BNFGrammar = - curSchema === undefined - ? this.grammarFactory.getBNFGrammarOfJSON() - : this.grammarFactory.getBNFGrammarFromSchema(curSchema); - this.grammarStateMatcher = this.tvm.detachFromCurrentScope( - this.grammarFactory.getGrammarStateMatcherFromTokenTable( - grammar, - this.tokenTable!, - ), - ); - this.schema = curSchema; - } - } - this.tvm.endScope(); // 4. Sample, stats, post process token sampled. + // We wait for prefill and grammar matcher init to finish + await Promise.all([this.device.sync(), grammarMatcherInitPromise]); const nextToken = await this.sampleTokenFromLogits(logits!, genConfig); logits!.dispose(); const tend = performance.now(); @@ -667,14 +723,31 @@ export class LLMChatPipeline { if (max_tokens <= 0) { throw new MinValueError("max_tokens", 0); } + + // Get ignore_eos from generationConfig (specified by user in completion request) + let ignore_eos = false; + if ( + genConfig !== undefined && + genConfig.ignore_eos !== undefined && + genConfig.ignore_eos !== null + ) { + ignore_eos = genConfig.ignore_eos; + } + // Get stopStrs, possibly overridden by genConfig for this round let stopStrs = this.stopStr; if (genConfig !== undefined && genConfig.stop) { stopStrs = stopStrs.concat(genConfig.stop); } + let stopTokens = this.stopTokens; + if (ignore_eos) { + stopTokens = []; + stopStrs = []; + } + // Stop condition 1: stop token; otherwise, append to `this.outputIds` - if (this.stopTokens.includes(nextToken)) { + if (stopTokens.includes(nextToken)) { this.stopTriggered = true; this.finishReason = "stop"; } @@ -952,16 +1025,27 @@ export class LLMChatPipeline { } // 0. Update logitsOnGPU with on-GPU grammar bitmasking - if (response_format?.type === "json_object") { + if ( + response_format?.type === "json_object" || + response_format?.type === "grammar" + ) { this.tvm.beginScope(); - if (this.grammarStateMatcher === undefined) { - throw Error("Expect grammar state matcher to be initialized."); + if (this.grammarMatcher === undefined) { + throw Error("Expect grammar matcher to be initialized."); + } + + const tBitmaskStart = performance.now(); + const bitMaskOnCPU: Int32Array = + await this.grammarMatcher.getNextTokenBitmask(); + this.curRoundGrammarPerTokenTotalTime += + (performance.now() - tBitmaskStart) / 1e3; + + if (bitMaskOnCPU.length !== this.bitmaskSize) { + throw new Error( + `InternalError: Expect grammar bitmask to be ` + + `size ${this.bitmaskSize}, but got ${bitMaskOnCPU.length}.`, + ); } - // TODO(Charlie): Do we detach from current scope here for bitmask? - const bitMaskOnCPU = this.grammarFactory.findNextTokenBitmask( - this.grammarStateMatcher, - this.fullVocabSize, - ) as unknown as tvmjs.NDArray; const bitMaskOnGPU = this.tvm .empty([1, this.bitmaskSize], "int32", this.device) .copyFrom(bitMaskOnCPU); @@ -1081,20 +1165,21 @@ export class LLMChatPipeline { // 5. Update logit processor this.logitProcessor?.processSampledToken(sampledToken); - // 6. 
Update grammar state matcher with new token - if (response_format?.type === "json_object") { - this.tvm.beginScope(); - if (this.grammarStateMatcher === undefined) { - throw Error("Expect grammar state matcher to be initialized."); + // 6. Update grammar matcher with new token + if ( + response_format?.type === "json_object" || + response_format?.type === "grammar" + ) { + if (this.grammarMatcher === undefined) { + throw Error("Expect grammar matcher to be initialized."); } - const accepted = this.grammarFactory.acceptToken( - this.grammarStateMatcher, - sampledToken, - ); + const tAcceptStart = performance.now(); + const accepted = this.grammarMatcher.acceptToken(sampledToken); + this.curRoundGrammarPerTokenTotalTime += + (performance.now() - tAcceptStart) / 1e3; if (!accepted) { - throw Error("Grammar state matcher rejected the newly sampled token."); + throw Error("Grammar matcher rejected the newly sampled token."); } - this.tvm.endScope(); } return sampledToken; diff --git a/src/openai_api_protocols/chat_completion.ts b/src/openai_api_protocols/chat_completion.ts index b2c37456..17f472d3 100644 --- a/src/openai_api_protocols/chat_completion.ts +++ b/src/openai_api_protocols/chat_completion.ts @@ -29,6 +29,7 @@ import { CustomResponseFormatError, CustomSystemPromptError, InvalidResponseFormatError, + InvalidResponseFormatGrammarError, InvalidStreamOptionsError, MessageOrderError, MultipleTextContentError, @@ -240,6 +241,12 @@ export interface ChatCompletionRequestBase { */ response_format?: ResponseFormat; + /** + * If true, will ignore stop string and stop token and generate until max_tokens hit. + * If unset, will treat as false. + */ + ignore_eos?: boolean; + /** * ID of the model to use. This equals to `ModelRecord.model_id`, which needs to either be in * `webllm.prebuiltAppConfig` or in `engineConfig.appConfig`. @@ -474,6 +481,26 @@ export function postInitAndCheckFields( } } + // 6.1 When grammar is specified, the type needs to be grammar + if ( + request.response_format?.grammar !== undefined && + request.response_format?.grammar !== null + ) { + if (request.response_format?.type !== "grammar") { + throw new InvalidResponseFormatGrammarError(); + } + } + + // 6.2 When type is grammar, the grammar field needs to be specified. + if (request.response_format?.type === "grammar") { + if ( + request.response_format?.grammar === undefined || + request.response_format?.grammar === null + ) { + throw new InvalidResponseFormatGrammarError(); + } + } + // 7. Function calling hardcoded handlings if (request.tools !== undefined && request.tools !== null) { // 7.1 Check if model supports function calling @@ -902,6 +929,11 @@ export interface CompletionUsage { * Fields specific to WebLLM, not present in OpenAI. */ extra: { + /** + * Total seconds spent on this request, from receiving the request, to generating the response. + */ + e2e_latency_s: number; + /** * Number of tokens per second for prefilling. */ @@ -911,6 +943,30 @@ export interface CompletionUsage { * Number of tokens per second for autoregressive decoding. */ decode_tokens_per_s: number; + + /** + * Seconds spent to generate the first token since receiving the request. Mainly contains + * prefilling overhead. If n > 1, it is the sum over all choices. + */ + time_to_first_token_s: number; + + /** + * Seconds in between generated tokens. Mainly contains decoding overhead. If n > 1, it + * is the average over all choices. 
+ */ + time_per_output_token_s: number; + + /** + * Seconds spent on initializing grammar matcher for structured output. If n > 1, it + * is the sum over all choices. + */ + grammar_init_s?: number; + + /** + * Seconds per-token that grammar matcher spent on creating bitmask and accepting token for + * structured output. If n > 1, it is the average over all choices. + */ + grammar_per_token_s?: number; }; } @@ -1066,6 +1122,9 @@ export namespace ChatCompletionChunk { * Setting to `{ "type": "json_object" }` enables JSON mode, which guarantees the * message the model generates is valid JSON. * + * Setting to `{ "type": "grammar" }` requires you to also specify the `grammar` field, which + * is a BNFGrammar string. + * * Setting `schema` specifies the output format of the json object such as properties to include. * * **Important:** when using JSON mode, you **must** also instruct the model to produce JSON @@ -1078,11 +1137,26 @@ export namespace ChatCompletionChunk { */ export interface ResponseFormat { /** - * Must be one of `text` or `json_object`. + * Must be one of `text`, `json_object`, or `grammar`. */ - type?: "text" | "json_object"; + type?: "text" | "json_object" | "grammar"; /** * A schema string in the format of the schema of a JSON file. `type` needs to be `json_object`. */ schema?: string; + /** + * An EBNF-formatted string. Needs to be specified when, and only specified when, + * `type` is `grammar`. The grammar will be normalized (simplified) by default. + * EBNF grammar: see https://www.w3.org/TR/xml/#sec-notation. Note: + 1. Use # as the comment mark + 2. Use C-style unicode escape sequence \u01AB, \U000001AB, \xAB + 3. A-B (match A and not match B) is not supported yet + 4. Lookahead assertion can be added at the end of a rule to speed up matching. E.g. + ``` + main ::= "ab" a [a-z] + a ::= "cd" (=[a-z]) + ``` + The assertion (=[a-z]) means a must be followed by [a-z]. + */ + grammar?: string; } diff --git a/src/openai_api_protocols/completion.ts b/src/openai_api_protocols/completion.ts index 0ed869cd..fb6aa458 100644 --- a/src/openai_api_protocols/completion.ts +++ b/src/openai_api_protocols/completion.ts @@ -182,6 +182,12 @@ export interface CompletionCreateParamsBase { */ top_p?: number | null; + /** + * If true, will ignore stop string and stop token and generate until max_tokens hit. + * If unset, will treat as false. + */ + ignore_eos?: boolean; + /** * ID of the model to use. This equals to `ModelRecord.model_id`, which needs to either be in * `webllm.prebuiltAppConfig` or in `engineConfig.appConfig`. diff --git a/tests/openai_chat_completion.test.ts b/tests/openai_chat_completion.test.ts index b4ea8e3b..0604cda1 100644 --- a/tests/openai_chat_completion.test.ts +++ b/tests/openai_chat_completion.test.ts @@ -124,6 +124,46 @@ describe("Check chat completion unsupported requests", () => { ); }); + test("Grammar string without grammar type", () => { + expect(() => { + const request: ChatCompletionRequest = { + messages: [{ role: "user", content: "Hello! " }], + response_format: { grammar: "some grammar string" }, + }; + postInitAndCheckFields( + request, + "Llama-3.1-8B-Instruct-q4f32_1-MLC", + ModelType.LLM, + ); + }).toThrow("When ResponseFormat.type is `grammar`,"); + }); + + test("Grammar type without grammar string", () => { + expect(() => { + const request: ChatCompletionRequest = { + messages: [{ role: "user", content: "Hello! 
" }], + response_format: { type: "grammar" }, + }; + postInitAndCheckFields( + request, + "Llama-3.1-8B-Instruct-q4f32_1-MLC", + ModelType.LLM, + ); + }).toThrow("When ResponseFormat.type is `grammar`,"); + }); + + test("Valid: Grammar type with grammar string", () => { + const request: ChatCompletionRequest = { + messages: [{ role: "user", content: "Hello! " }], + response_format: { type: "grammar", grammar: "some grammar string" }, + }; + postInitAndCheckFields( + request, + "Llama-3.1-8B-Instruct-q4f32_1-MLC", + ModelType.LLM, + ); + }); + test("image_url.detail is unsupported", () => { expect(() => { const request: ChatCompletionRequest = {