Protect against malformed JSON schema (#235)
* Protect against malformed JSON schema

* fix typing
masahi authored Mar 19, 2024
1 parent fd4ea46 commit 4727147
Showing 4 changed files with 167 additions and 101 deletions.
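
The core of the change is a small guard around schema-to-FSM construction. For readability, here is the new helper from the `engine_common.py` diff below, reproduced as a standalone sketch: it converts a request's JSON schema into a regex FSM and wraps any failure in a `JSONModeError`, so a malformed schema fails only that request instead of crashing the engine loop. The absolute import paths are an assumption based on the repo layout (the file itself uses relative imports); the logic is taken directly from the diff.

```python
import json
from typing import Any

import structlog

# Import paths assumed from the repo layout; engine_common.py itself uses
# relative imports (..errors, .constrained).
from mlc_serve.errors import JSONModeError
from mlc_serve.engine.constrained import build_regex_from_schema

LOG = structlog.stdlib.get_logger(__name__)


def _schema_to_regex_fsm(regex_fsm_cache, json_schema: Any) -> Any:
    try:
        # Convert the schema (a dict) into a JSON string.
        json_schema = json.dumps(json_schema)
        # Build a regex (grammar) from the JSON string.
        json_regex = build_regex_from_schema(json_schema, whitespace_pattern=r"[ \n\t]?")
        # Query the FSM cache for the compiled FSM object.
        return regex_fsm_cache.query(json_regex)
    except Exception as exc:
        LOG.exception("An error occurred while building JSON mode FSM.", exc=exc)
        raise JSONModeError("Failed to construct FSM.") from exc
```
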
239 changes: 144 additions & 95 deletions serve/mlc_serve/engine/engine_common.py
@@ -1,42 +1,47 @@
"""
Common utilities for engine classes.
"""Common utilities for engine classes.
"""

import torch
import time
import copy
import json
from typing import Tuple, Deque, Dict, Optional, Callable, List
from collections import deque
from threading import Condition, Lock
import time
from collections import defaultdict, deque
from dataclasses import dataclass
from pathlib import Path
from threading import Condition, Lock
from typing import Any, Callable, Deque, Dict, List, Optional, Tuple

import structlog
import torch

from ..errors import JSONModeError
from ..model.base import ModelArtifactConfig
from ..openai_logprob_protocol import LogprobsContent, TopLogprobs
from .base import (
GenerationSequence,
RawLogprobsInfo,
Request,
RequestId,
RequestState,
SequenceId,
StoppingCriteria,
RawLogprobsInfo,
)
from .constrained import build_regex_from_schema
from .constrained.fsm_cache import FSMCache
from .model_module import (
ConversationTemplate,
DecodeRequest,
PrefillRequest,
EvalMultiQueryRequest,
EvictedTokens,
ConversationTemplate,
FailedRequest,
KVCacheManager,
ModelModule,
PrefillRequest,
RequestType,
TextGenerator,
)
from .model_module import (
Tokenizer as TokenizerP,
)
from ..model.base import ModelArtifactConfig
from ..openai_logprob_protocol import LogprobsContent, TopLogprobs
from .constrained.fsm_cache import FSMCache
from .constrained import build_regex_from_schema

LOG = structlog.stdlib.get_logger(__name__)

@@ -131,6 +136,7 @@ def detokenize_incrementally(
prefix_end_offset = max(len(output_tokens) - 1, 0)
else:
# Put new_token_id in a list so skip_special_tokens is respected
# TODO(@jroesch): guard here for out of bound token ids
new_tokens = tokenizer.convert_ids_to_tokens([new_token_id])
output_tokens = generation_sequence.prev_tokens + new_tokens

@@ -257,13 +263,27 @@ def prepare_output(
return delta, out_logprob_info


# TODO(@jroesch): fix typing here
def _schema_to_regex_fsm(regex_fsm_cache, json_schema: Any) -> Any:
try:
# Convert schema into json string
json_schema = json.dumps(json_schema)
# Build a regex (grammar) from json string
json_regex = build_regex_from_schema(json_schema, whitespace_pattern=r"[ \n\t]?")
# Query fsm cache for FSM object
return regex_fsm_cache.query(json_regex)
except Exception as exc:
LOG.exception("An error occurred while building JSON mode FSM.", exc=exc)
raise JSONModeError("Failed to construct FSM.") from exc


def get_requests_to_process(
current_states: List[RequestState],
cache_manager: KVCacheManager,
regex_fsm_cache: FSMCache,
tokenizer: TokenizerP,
) -> Tuple[List[RequestType], bool, int]:
) -> Tuple[List[RequestType], List[FailedRequest], bool, int]:
requests: List[RequestType] = []
failed_requests: List[FailedRequest] = []
# TODO: consider having hybrid batch if the underlying attention kernel supports
# mixing prefill and decode.
is_prompt_batch = any(not state.is_prefilled for state in current_states)
@@ -281,106 +301,135 @@ def get_requests_to_process(

if is_prompt_batch:
for state in current_states:
if is_evicted_parallel_sampling_request(state):
requests.append(
PrefillRequest(
request_id=state.request_id,
token_ids=state.prompt_token_ids,
prompt_mask=state.prompt_mask,
num_sequence=state.num_sequences,
sampling_params=state.sampling_params,
try:
if is_evicted_parallel_sampling_request(state):
requests.append(
PrefillRequest(
request_id=state.request_id,
token_ids=state.prompt_token_ids,
prompt_mask=state.prompt_mask,
num_sequence=state.num_sequences,
sampling_params=state.sampling_params,
)
)
)

token_counts += len(state.prompt_token_ids)
token_counts += len(state.prompt_token_ids)

for gen_seq in state.generation_sequences:
# TODO(vvchernov): This is for repetition penalty
# Not obvious EvalMultiQueryRequest needs this
# Now empty instead of state.prompt_mask
vocab_size = state.sampling_params.vocab_size
prompt_mask = torch.zeros((vocab_size,), dtype=torch.bool)
requests.append(
EvalMultiQueryRequest(
sequence_id=gen_seq.seq_id,
num_past_tokens=state.prompt_len,
prompt_mask=prompt_mask,
queries=EvictedTokens(gen_seq.generated_token_ids),
sampling_params=state.sampling_params,
)
)
cache_manager.extend(
gen_seq.seq_id,
len(gen_seq.generated_token_ids) + 1,
)

# TODO(masahi): How to account for token counts in EvalMultiQueryRequest in
# Prometheus metric?
elif not state.is_prefilled:
# Check if JSON mode is enabled
if state.sampling_params.json_schema is not None:
# Convert schema into json string
json_schema = json.dumps(state.sampling_params.json_schema)
# Build a regex (grammar) from json string
json_regex = build_regex_from_schema(
json_schema, whitespace_pattern=r"[ \n\t]?"
)
# Query fsm cache for FSM object
state.sampling_params.regex_fsm = regex_fsm_cache.query(json_regex)

if (
state.num_sequences == 1
and state.generation_sequences[0].generated_token_ids
):
# generated_token_ids is added for the case where the request is
# recovering from cache eviction.
token_ids = (
state.prompt_token_ids
+ state.generation_sequences[0].generated_token_ids
)
else:
token_ids = state.prompt_token_ids

for gen_seq in state.generation_sequences:
# TODO(vvchernov): This is for repetition penalty
# Not obvious EvalMultiQueryRequest needs this
# Now empty instead of state.prompt_mask
vocab_size = state.sampling_params.vocab_size
prompt_mask = torch.zeros((vocab_size,), dtype=torch.bool)
requests.append(
EvalMultiQueryRequest(
sequence_id=gen_seq.seq_id,
num_past_tokens=state.prompt_len,
prompt_mask=prompt_mask,
queries=EvictedTokens(gen_seq.generated_token_ids),
PrefillRequest(
request_id=state.request_id,
token_ids=token_ids,
prompt_mask=state.prompt_mask,
num_sequence=state.num_sequences,
sampling_params=state.sampling_params,
)
)
cache_manager.extend(
gen_seq.seq_id,
len(gen_seq.generated_token_ids) + 1,
)

# TODO(masahi): How to account for token counts in EvalMultiQueryRequest in
# Prometheus metric?
elif not state.is_prefilled:
# Check if JSON mode is enabled
if state.sampling_params.json_schema is not None:
# Convert schema into json string
json_schema = json.dumps(state.sampling_params.json_schema)
# Build a regex (grammar) from json string
json_regex = build_regex_from_schema(json_schema, whitespace_pattern=r"[ \n\t]?")
# Query fsm cache for FSM object
state.sampling_params.regex_fsm = regex_fsm_cache.query(json_regex)

if (
state.num_sequences == 1
and state.generation_sequences[0].generated_token_ids
):
# generated_token_ids is added for the case where the request is
# recovering from cache eviction.
token_ids = (
state.prompt_token_ids
+ state.generation_sequences[0].generated_token_ids
)
else:
token_ids = state.prompt_token_ids
token_counts += len(state.prompt_token_ids)

requests.append(
PrefillRequest(
except Exception as exc:
LOG.exception(
"An exception occurred creating internal request types.",
request_id=state.request_id,
exc=exc,
)
failed_requests.append(
FailedRequest(
request_id=state.request_id,
token_ids=token_ids,
prompt_mask=state.prompt_mask,
num_sequence=state.num_sequences,
sampling_params=state.sampling_params,
error=exc,
)
)

token_counts += len(state.prompt_token_ids)

LOG.debug(
"Creating prompt batch.",
num_requests=len(requests),
total_tokens=token_counts,
)
LOG.debug(
"Creating prompt batch.",
num_requests=len(requests),
total_tokens=token_counts,
)
else:
for state in current_states:
for gen_seq in state.generation_sequences:
if not gen_seq.is_finished:
prompt_counts = len(state.prompt_token_ids)
requests.append(
DecodeRequest(
sequence_id=gen_seq.seq_id,
prompt_token_counts=prompt_counts,
prompt_mask=state.prompt_mask,
token_ids=gen_seq.generated_token_ids,
sampling_params=state.sampling_params,
try:
for gen_seq in state.generation_sequences:
if not gen_seq.is_finished:
prompt_counts = len(state.prompt_token_ids)
requests.append(
DecodeRequest(
sequence_id=gen_seq.seq_id,
prompt_token_counts=prompt_counts,
prompt_mask=state.prompt_mask,
token_ids=gen_seq.generated_token_ids,
sampling_params=state.sampling_params,
)
)
cache_manager.extend(
gen_seq.seq_id,
prompt_counts
+ len(gen_seq.generated_token_ids)
- gen_seq.next_start_position,
)
except Exception as exc:
LOG.exception(
"An exception occurred creating internal request types.",
request_id=state.request_id,
exc=exc,
)
failed_requests.append(
FailedRequest(
request_id=state.request_id,
error=exc,
)
cache_manager.extend(
gen_seq.seq_id,
prompt_counts
+ len(gen_seq.generated_token_ids)
- gen_seq.next_start_position,
)
)

token_counts = len(requests)
LOG.debug("Creating decode batch with %s requests.", token_counts)

return requests, is_prompt_batch, token_counts
return requests, failed_requests, is_prompt_batch, token_counts


def should_stop_by_length(
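For reference, a minimal sketch of how a caller is expected to consume the new return value: `failed_requests` entries are mapped back to their request states and reported as validation errors rather than raised, which is what the `staging_engine_worker.py` hunk further down does. Names such as `validation_err`, `ValidationError`, `FinishReason.Cancelled`, and `create_aborted_outputs` come from that hunk; the import paths and the free-standing helper function are assumptions for illustration only.

```python
from typing import Callable, Dict, List

# Assumed import paths; RequestId and RequestState are imported from
# .base in engine_common.py, while ValidationError and FinishReason are
# referenced in the staging_engine_worker.py hunk below.
from mlc_serve.engine.base import FinishReason, RequestId, RequestState, ValidationError
from mlc_serve.engine.model_module import FailedRequest


def report_failed_requests(
    failed_requests: List[FailedRequest],
    current_batch: Dict[RequestId, RequestState],
    create_aborted_outputs: Callable,
):
    """Turn FailedRequest entries into aborted outputs instead of raising."""
    states = []
    for failed in failed_requests:
        state = current_batch[failed.request_id]
        # Surface the original exception to the client as a validation error.
        state.validation_err = ValidationError(str(failed.error))
        states.append(state)
    # Healthy requests in the same batch continue through text generation.
    return create_aborted_outputs(states, finish_reason=FinishReason.Cancelled)
```
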
6 changes: 6 additions & 0 deletions serve/mlc_serve/engine/model_module.py
@@ -71,6 +71,12 @@ class EvalMultiQueryRequest:
sampling_params: SamplingParams


@dataclass
class FailedRequest:
request_id: RequestId
error: Exception


RequestType = Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]


18 changes: 14 additions & 4 deletions serve/mlc_serve/engine/staging_engine_worker.py
@@ -233,7 +233,17 @@ def step(self) -> GenerationLoopWorkerOutput:
)
return result

requests, is_prompt_batch = self._get_requests_to_process()
requests, failed_requests, is_prompt_batch = self._get_requests_to_process()

if len(failed_requests) > 0:
states = []

for request in failed_requests:
states.append(self.current_batch[request.request_id])
states[-1].validation_err = ValidationError(str(request.error))

outputs += self.create_aborted_outputs(states, finish_reason=FinishReason.Cancelled)

results = self.text_generator.generate(requests, self.cache_manager.get_cache())
LOG.debug("Finished text generation.")

@@ -334,16 +344,16 @@ def _adjust_batch(self):
num_new_batched_tokens = self.try_grow_batch(num_new_batched_tokens)

def _get_requests_to_process(self):
requests, is_prompt_batch, token_counts = get_requests_to_process(
self.current_batch.values(), self.cache_manager, self.regex_fsm_cache, self.tokenizer
requests, failed_requests, is_prompt_batch, token_counts = get_requests_to_process(
self.current_batch.values(), self.cache_manager, self.regex_fsm_cache,
)

if is_prompt_batch:
self.prom_metrics.histogram(BATCHED_PREFILL_TOKENS).observe(token_counts)
else:
self.prom_metrics.histogram(BATCHED_DECODE_TOKENS).observe(token_counts)

return requests, is_prompt_batch
return requests, failed_requests, is_prompt_batch

def _has_request_to_process(self) -> bool:
return bool(self.queue or self.current_batch)
5 changes: 3 additions & 2 deletions serve/mlc_serve/engine/sync_engine.py
@@ -145,9 +145,10 @@ def step(self) -> InferenceStepResult:
if not self.current_batch:
return InferenceStepResult(outputs)

requests, _, _ = get_requests_to_process(
list(self.current_batch.values()), self.cache_manager, self.regex_fsm_cache, self.tokenizer
requests, _, _, _ = get_requests_to_process(
list(self.current_batch.values()), self.cache_manager, self.regex_fsm_cache,
)

results = self.text_generator.generate(requests, self.cache_manager.get_cache())
logger.debug("Finished text generation.")

