From eb264aba06ec2ba03f832782405cd65a43faf016 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Tue, 19 Mar 2024 20:17:07 +0000
Subject: [PATCH 1/2] Protect against malformed JSON schema

---
 serve/mlc_serve/engine/engine_common.py       | 239 +++++++++++-------
 serve/mlc_serve/engine/model_module.py        |   6 +
 .../mlc_serve/engine/staging_engine_worker.py |  18 +-
 serve/mlc_serve/engine/sync_engine.py         |   5 +-
 4 files changed, 167 insertions(+), 101 deletions(-)

diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py
index 918330f032..f6ae51bf9c 100644
--- a/serve/mlc_serve/engine/engine_common.py
+++ b/serve/mlc_serve/engine/engine_common.py
@@ -1,42 +1,47 @@
-"""
-Common utilites for engine classes.
+"""Common utilities for engine classes.
 """
-import torch
-import time
+import copy
 import json
-from typing import Tuple, Deque, Dict, Optional, Callable, List
-from collections import deque
-from threading import Condition, Lock
+import time
+from collections import defaultdict, deque
+from dataclasses import dataclass
 from pathlib import Path
+from threading import Condition, Lock
+from typing import Any, Callable, Deque, Dict, List, Optional, Tuple
 
 import structlog
+import torch
 
+from ..errors import JSONModeError
+from ..model.base import ModelArtifactConfig
+from ..openai_logprob_protocol import LogprobsContent, TopLogprobs
 from .base import (
     GenerationSequence,
+    RawLogprobsInfo,
     Request,
     RequestId,
     RequestState,
     SequenceId,
     StoppingCriteria,
-    RawLogprobsInfo,
 )
+from .constrained import build_regex_from_schema
+from .constrained.fsm_cache import FSMCache
 from .model_module import (
+    ConversationTemplate,
     DecodeRequest,
-    PrefillRequest,
     EvalMultiQueryRequest,
     EvictedTokens,
-    ConversationTemplate,
+    FailedRequest,
     KVCacheManager,
     ModelModule,
+    PrefillRequest,
     RequestType,
     TextGenerator,
+)
+from .model_module import (
     Tokenizer as TokenizerP,
 )
-from ..model.base import ModelArtifactConfig
-from ..openai_logprob_protocol import LogprobsContent, TopLogprobs
-from .constrained.fsm_cache import FSMCache
-from .constrained import build_regex_from_schema
 
 LOG = structlog.stdlib.get_logger(__name__)
 
@@ -131,6 +136,7 @@ def detokenize_incrementally(
         prefix_end_offset = max(len(output_tokens) - 1, 0)
     else:
         # Put new_token_id in a list so skip_special_tokens is respected
+        # TODO(@jroesch): guard here for out of bound token ids
         new_tokens = tokenizer.convert_ids_to_tokens([new_token_id])
         output_tokens = generation_sequence.prev_tokens + new_tokens
 
@@ -257,13 +263,27 @@ def prepare_output(
     return delta, out_logprob_info
 
 
+# TODO(@jroesch): fix typing here
+def _schema_to_regex_fsm(regex_fsm_cache, json_schema: Any) -> Any:
+    try:
+        # Convert schema into json string
+        json_schema = json.dumps(json_schema)
+        # Build a regex (grammar) from json string
+        json_regex = build_regex_from_schema(json_schema, whitespace_pattern=r"[ \n\t]?")
+        # Query fsm cache for FSM object
+        return regex_fsm_cache.query(json_regex)
+    except Exception as exc:
+        LOG.exception("An error occurred while building JSON mode FSM.", exc=exc)
+        raise JSONModeError("Failed to construct FSM.") from exc
+
+
 def get_requests_to_process(
     current_states: List[RequestState],
     cache_manager: KVCacheManager,
     regex_fsm_cache: FSMCache,
-    tokenizer: TokenizerP,
-) -> Tuple[List[RequestType], bool, int]:
+) -> Tuple[List[RequestType], List[FailedRequest], bool, int]:
     requests: List[RequestType] = []
+    failed_requests: List[FailedRequest] = []
     # TODO: consider having hybrid batch if the underlying attention kernel supports
     # mixing prefill and decode.
     is_prompt_batch = any(not state.is_prefilled for state in current_states)
 
@@ -281,106 +301,130 @@ def get_requests_to_process(
     if is_prompt_batch:
         for state in current_states:
-            if is_evicted_parallel_sampling_request(state):
-                requests.append(
-                    PrefillRequest(
-                        request_id=state.request_id,
-                        token_ids=state.prompt_token_ids,
-                        prompt_mask=state.prompt_mask,
-                        num_sequence=state.num_sequences,
-                        sampling_params=state.sampling_params,
+            try:
+                if is_evicted_parallel_sampling_request(state):
+                    requests.append(
+                        PrefillRequest(
+                            request_id=state.request_id,
+                            token_ids=state.prompt_token_ids,
+                            prompt_mask=state.prompt_mask,
+                            num_sequence=state.num_sequences,
+                            sampling_params=state.sampling_params,
+                        )
                     )
-                )
-                token_counts += len(state.prompt_token_ids)
+                    token_counts += len(state.prompt_token_ids)
+
+                    for gen_seq in state.generation_sequences:
+                        # TODO(vvchernov): This is for repetition penalty
+                        # Not obvious EvalMultiQueryRequest needs this
+                        # Now empty instead of state.prompt_mask
+                        vocab_size = state.sampling_params.vocab_size
+                        prompt_mask = torch.zeros((vocab_size,), dtype=torch.bool)
+                        requests.append(
+                            EvalMultiQueryRequest(
+                                sequence_id=gen_seq.seq_id,
+                                num_past_tokens=state.prompt_len,
+                                prompt_mask=prompt_mask,
+                                queries=EvictedTokens(gen_seq.generated_token_ids),
+                                sampling_params=state.sampling_params,
+                            )
+                        )
+                        cache_manager.extend(
+                            gen_seq.seq_id,
+                            len(gen_seq.generated_token_ids) + 1,
+                        )
+
+                    # TODO(masahi): How to account for token counts in EvalMultiQueryRequest in
+                    # Prometheus metric?
+                elif not state.is_prefilled:
+                    # Check if JSON mode is enabled
+                    if state.sampling_params.json_schema is not None:
+                        state.sampling_params.regex_fsm = _schema_to_regex_fsm(
+                            regex_fsm_cache, state.sampling_params.json_schema
+                        )
+
+                    if (
+                        state.num_sequences == 1
+                        and state.generation_sequences[0].generated_token_ids
+                    ):
+                        # generated_token_ids is added for the case where the request is
+                        # recovering from cache eviction.
+                        token_ids = (
+                            state.prompt_token_ids
+                            + state.generation_sequences[0].generated_token_ids
+                        )
+                    else:
+                        token_ids = state.prompt_token_ids
 
-                for gen_seq in state.generation_sequences:
-                    # TODO(vvchernov): This is for repetion penalty
-                    # Not obvious EvalMultiQueryRequest needs this
-                    # Now empty instead of state.prompt_mask
-                    vocab_size = state.sampling_params.vocab_size
-                    prompt_mask = torch.zeros((vocab_size,), dtype=torch.bool)
                     requests.append(
-                        EvalMultiQueryRequest(
-                            sequence_id=gen_seq.seq_id,
-                            num_past_tokens=state.prompt_len,
-                            prompt_mask=prompt_mask,
-                            queries=EvictedTokens(gen_seq.generated_token_ids),
+                        PrefillRequest(
+                            request_id=state.request_id,
+                            token_ids=token_ids,
+                            prompt_mask=state.prompt_mask,
+                            num_sequence=state.num_sequences,
                             sampling_params=state.sampling_params,
                         )
                     )
-                    cache_manager.extend(
-                        gen_seq.seq_id,
-                        len(gen_seq.generated_token_ids) + 1,
-                    )
-
-                # TODO(masahi): How to account for token counts in EvalMultiQueryRequest in
-                # Prometheus metric?
-            elif not state.is_prefilled:
-                # Check if JSON mode is enabled
-                if state.sampling_params.json_schema is not None:
-                    # Convert schema into json string
-                    json_schema = json.dumps(state.sampling_params.json_schema)
-                    # Build a regex (grammar) from json string
-                    json_regex = build_regex_from_schema(json_schema, whitespace_pattern=r"[ \n\t]?")
-                    # Query fsm cache for FSM object
-                    state.sampling_params.regex_fsm = regex_fsm_cache.query(json_regex)
-
-                if (
-                    state.num_sequences == 1
-                    and state.generation_sequences[0].generated_token_ids
-                ):
-                    # generated_token_ids is added for the case where the request is
-                    # recovering from cache eviction.
-                    token_ids = (
-                        state.prompt_token_ids
-                        + state.generation_sequences[0].generated_token_ids
-                    )
-                else:
-                    token_ids = state.prompt_token_ids
+                    token_counts += len(state.prompt_token_ids)
 
-                requests.append(
-                    PrefillRequest(
+            except Exception as exc:
+                LOG.exception(
+                    "An exception occurred while creating internal request types.",
+                    request_id=state.request_id,
+                    exc=exc,
+                )
+                failed_requests.append(
+                    FailedRequest(
                         request_id=state.request_id,
-                        token_ids=token_ids,
-                        prompt_mask=state.prompt_mask,
-                        num_sequence=state.num_sequences,
-                        sampling_params=state.sampling_params,
+                        error=exc,
                     )
                 )
-                token_counts += len(state.prompt_token_ids)
-
         LOG.debug(
             "Creating prompt batch.",
             num_requests=len(requests),
             total_tokens=token_counts,
         )
     else:
         for state in current_states:
-            for gen_seq in state.generation_sequences:
-                if not gen_seq.is_finished:
-                    prompt_counts = len(state.prompt_token_ids)
-                    requests.append(
-                        DecodeRequest(
-                            sequence_id=gen_seq.seq_id,
-                            prompt_token_counts=prompt_counts,
-                            prompt_mask=state.prompt_mask,
-                            token_ids=gen_seq.generated_token_ids,
-                            sampling_params=state.sampling_params,
+            try:
+                for gen_seq in state.generation_sequences:
+                    if not gen_seq.is_finished:
+                        prompt_counts = len(state.prompt_token_ids)
+                        requests.append(
+                            DecodeRequest(
+                                sequence_id=gen_seq.seq_id,
+                                prompt_token_counts=prompt_counts,
+                                prompt_mask=state.prompt_mask,
+                                token_ids=gen_seq.generated_token_ids,
+                                sampling_params=state.sampling_params,
+                            )
                         )
+                        cache_manager.extend(
+                            gen_seq.seq_id,
+                            prompt_counts
+                            + len(gen_seq.generated_token_ids)
+                            - gen_seq.next_start_position,
+                        )
+            except Exception as exc:
+                LOG.exception(
+                    "An exception occurred while creating internal request types.",
+                    request_id=state.request_id,
+                    exc=exc,
+                )
+                failed_requests.append(
+                    FailedRequest(
+                        request_id=state.request_id,
+                        error=exc,
                     )
-                    cache_manager.extend(
-                        gen_seq.seq_id,
-                        prompt_counts
-                        + len(gen_seq.generated_token_ids)
-                        - gen_seq.next_start_position,
-                    )
+                )
 
     token_counts = len(requests)
     LOG.debug("Creating decode batch with %s requests.", token_counts)
 
-    return requests, is_prompt_batch, token_counts
+    return requests, failed_requests, is_prompt_batch, token_counts
 
 
 def should_stop_by_length(
diff --git a/serve/mlc_serve/engine/model_module.py b/serve/mlc_serve/engine/model_module.py
index 302cd623ff..a97322ff3e 100644
--- a/serve/mlc_serve/engine/model_module.py
+++ b/serve/mlc_serve/engine/model_module.py
@@ -71,6 +71,12 @@ class EvalMultiQueryRequest:
     sampling_params: SamplingParams
 
 
+@dataclass
+class FailedRequest:
+    request_id: RequestId
+    error: Exception
+
+
 RequestType = Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]
 
 
diff --git a/serve/mlc_serve/engine/staging_engine_worker.py b/serve/mlc_serve/engine/staging_engine_worker.py
index e33834bac3..312661d09b 100644
--- a/serve/mlc_serve/engine/staging_engine_worker.py
+++ b/serve/mlc_serve/engine/staging_engine_worker.py
@@ -233,7 +233,17 @@ def step(self) -> GenerationLoopWorkerOutput:
             )
             return result
 
-        requests, is_prompt_batch = self._get_requests_to_process()
+        requests, failed_requests, is_prompt_batch = self._get_requests_to_process()
+
+        if len(failed_requests) > 0:
+            states = []
+
+            for request in failed_requests:
+                states.append(self.current_batch[request.request_id])
+                states[-1].validation_err = str(request.error)
+
+            outputs += self.create_aborted_outputs(states, finish_reason=FinishReason.Cancelled)
+
         results = self.text_generator.generate(requests, self.cache_manager.get_cache())
         LOG.debug("Finished text generation.")
 
@@ -334,8 +344,8 @@ def _adjust_batch(self):
             num_new_batched_tokens = self.try_grow_batch(num_new_batched_tokens)
 
     def _get_requests_to_process(self):
-        requests, is_prompt_batch, token_counts = get_requests_to_process(
-            self.current_batch.values(), self.cache_manager, self.regex_fsm_cache, self.tokenizer
+        requests, failed_requests, is_prompt_batch, token_counts = get_requests_to_process(
+            self.current_batch.values(), self.cache_manager, self.regex_fsm_cache,
         )
 
         if is_prompt_batch:
@@ -343,7 +353,7 @@ def _get_requests_to_process(self):
         else:
             self.prom_metrics.histogram(BATCHED_DECODE_TOKENS).observe(token_counts)
 
-        return requests, is_prompt_batch
+        return requests, failed_requests, is_prompt_batch
 
     def _has_request_to_process(self) -> bool:
         return bool(self.queue or self.current_batch)
diff --git a/serve/mlc_serve/engine/sync_engine.py b/serve/mlc_serve/engine/sync_engine.py
index c164f4ecf8..a5eb53ae02 100644
--- a/serve/mlc_serve/engine/sync_engine.py
+++ b/serve/mlc_serve/engine/sync_engine.py
@@ -145,9 +145,10 @@ def step(self) -> InferenceStepResult:
         if not self.current_batch:
             return InferenceStepResult(outputs)
 
-        requests, _, _ = get_requests_to_process(
-            list(self.current_batch.values()), self.cache_manager, self.regex_fsm_cache, self.tokenizer
+        requests, _, _, _ = get_requests_to_process(
+            list(self.current_batch.values()), self.cache_manager, self.regex_fsm_cache,
         )
+
         results = self.text_generator.generate(requests, self.cache_manager.get_cache())
         logger.debug("Finished text generation.")
 

From 4da115a26b659aa8d593744cd259ef14c62f1f4b Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Tue, 19 Mar 2024 20:22:55 +0000
Subject: [PATCH 2/2] fix typing

---
 serve/mlc_serve/engine/staging_engine_worker.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/serve/mlc_serve/engine/staging_engine_worker.py b/serve/mlc_serve/engine/staging_engine_worker.py
index 312661d09b..af286150fb 100644
--- a/serve/mlc_serve/engine/staging_engine_worker.py
+++ b/serve/mlc_serve/engine/staging_engine_worker.py
@@ -240,7 +240,7 @@ def step(self) -> GenerationLoopWorkerOutput:
 
             for request in failed_requests:
                 states.append(self.current_batch[request.request_id])
-                states[-1].validation_err = str(request.error)
+                states[-1].validation_err = ValidationError(str(request.error))
 
             outputs += self.create_aborted_outputs(states, finish_reason=FinishReason.Cancelled)
 
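
For review context: the guard added to get_requests_to_process reduces to a per-request try/except that records failures instead of propagating them, so one malformed request can no longer abort the whole batch. The following is a minimal, self-contained sketch of that pattern only; build_all and FailedItem are illustrative stand-ins, not the engine's actual API.

import json
from dataclasses import dataclass
from typing import Any, Callable, List, Tuple


@dataclass
class FailedItem:
    # Illustrative stand-in for FailedRequest: an id paired with the error.
    item_id: str
    error: Exception


def build_all(
    items: List[Tuple[str, Any]],
    build: Callable[[Any], Any],
) -> Tuple[List[Any], List[FailedItem]]:
    # Build one result per item; a failing item is recorded and skipped
    # rather than allowed to take down the entire batch.
    built: List[Any] = []
    failed: List[FailedItem] = []
    for item_id, payload in items:
        try:
            built.append(build(payload))
        except Exception as exc:
            failed.append(FailedItem(item_id, exc))
    return built, failed


ok, failed = build_all(
    [("r1", '{"type": "object"}'), ("r2", "{not valid json")],
    json.loads,
)
assert len(ok) == 1
assert failed[0].item_id == "r2"

The caller decides what to do with the second return value; in the staging engine worker above, failed requests are surfaced as aborted outputs carrying a validation error.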
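
Likewise, _schema_to_regex_fsm converts any failure in the schema-to-FSM pipeline into a single domain-level JSONModeError with the original exception chained. Below is a hedged sketch of that wrapping under stated assumptions: compile_schema is a hypothetical placeholder for build_regex_from_schema plus the FSM-cache lookup, and the local JSONModeError stands in for the engine's own type in ..errors.

import json


class JSONModeError(ValueError):
    # Stand-in for the engine's domain error type.
    pass


def compile_schema(schema_text: str) -> str:
    # Hypothetical placeholder for the real schema -> regex -> FSM steps;
    # here it only checks that the schema declares a "type".
    parsed = json.loads(schema_text)
    if "type" not in parsed:
        raise ValueError("schema must declare a 'type'")
    return f"<fsm:{parsed['type']}>"


def schema_to_fsm(json_schema: dict) -> str:
    # Mirrors the patch's guard: any exception becomes a JSONModeError,
    # with the original chained via `from` for debugging.
    try:
        return compile_schema(json.dumps(json_schema))
    except Exception as exc:
        raise JSONModeError("Failed to construct FSM.") from exc


try:
    schema_to_fsm({"properties": {}})  # no "type": trips the guard
except JSONModeError as err:
    validation_err = str(err)  # the worker stores this on the request state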