Protect against malformed JSON schema #235

Merged (2 commits) on Mar 19, 2024. Showing changes from all commits.
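Summary (inferred from the diff below): converting a user-supplied JSON schema into a regex-constrained FSM can throw, and previously that exception propagated out of request batching. This change adds a _schema_to_regex_fsm helper in engine_common.py that wraps the conversion and raises a typed JSONModeError, wraps the per-request construction in get_requests_to_process in try/except blocks, and returns failures as FailedRequest records (a new dataclass in model_module.py). The staging engine worker turns those failures into aborted outputs carrying a ValidationError, so a malformed schema cancels only the offending request instead of taking down the batch; the synchronous engine is updated for the new return signature.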
239 changes: 144 additions & 95 deletions serve/mlc_serve/engine/engine_common.py
@@ -1,42 +1,47 @@
"""
Common utilites for engine classes.
"""Common utilites for engine classes.
"""

import torch
import time
import copy
import json
from typing import Tuple, Deque, Dict, Optional, Callable, List
from collections import deque
from threading import Condition, Lock
import time
from collections import defaultdict, deque
from dataclasses import dataclass
from pathlib import Path
from threading import Condition, Lock
from typing import Any, Callable, Deque, Dict, List, Optional, Tuple

import structlog
import torch

from ..errors import JSONModeError
from ..model.base import ModelArtifactConfig
from ..openai_logprob_protocol import LogprobsContent, TopLogprobs
from .base import (
GenerationSequence,
RawLogprobsInfo,
Request,
RequestId,
RequestState,
SequenceId,
StoppingCriteria,
RawLogprobsInfo,
)
from .constrained import build_regex_from_schema
from .constrained.fsm_cache import FSMCache
from .model_module import (
ConversationTemplate,
DecodeRequest,
PrefillRequest,
EvalMultiQueryRequest,
EvictedTokens,
ConversationTemplate,
FailedRequest,
KVCacheManager,
ModelModule,
PrefillRequest,
RequestType,
TextGenerator,
)
from .model_module import (
Tokenizer as TokenizerP,
)
from ..model.base import ModelArtifactConfig
from ..openai_logprob_protocol import LogprobsContent, TopLogprobs
from .constrained.fsm_cache import FSMCache
from .constrained import build_regex_from_schema

LOG = structlog.stdlib.get_logger(__name__)

@@ -131,6 +136,7 @@ def detokenize_incrementally(
prefix_end_offset = max(len(output_tokens) - 1, 0)
else:
# Put new_token_id in a list so skip_special_tokens is respected
# TODO(@jroesch): guard here for out of bound token ids
new_tokens = tokenizer.convert_ids_to_tokens([new_token_id])
output_tokens = generation_sequence.prev_tokens + new_tokens

@@ -257,13 +263,27 @@ def prepare_output(
return delta, out_logprob_info


# TODO(@jroesch): fix typing here
def _schema_to_regex_fsm(regex_fsm_cache, json_schema: Any) -> Any:
try:
# Convert schema into json string
json_schema = json.dumps(json_schema)
# Build a regex (grammar) from json string
json_regex = build_regex_from_schema(json_schema, whitespace_pattern=r"[ \n\t]?")
# Query fsm cache for FSM object
return regex_fsm_cache.query(json_regex)
except Exception as exc:
LOG.exception("An error occurred while building JSON mode FSM.", exc=exc)
raise JSONModeError("Failed to construct FSM.") from exc


def get_requests_to_process(
current_states: List[RequestState],
cache_manager: KVCacheManager,
regex_fsm_cache: FSMCache,
tokenizer: TokenizerP,
) -> Tuple[List[RequestType], bool, int]:
) -> Tuple[List[RequestType], List[FailedRequest], bool, int]:
requests: List[RequestType] = []
failed_requests: List[FailedRequest] = []
# TODO: consider having hybrid batch if the underlying attention kernel supports
# mixing prefill and decode.
is_prompt_batch = any(not state.is_prefilled for state in current_states)
@@ -281,106 +301,135 @@ def get_requests_to_process(

if is_prompt_batch:
for state in current_states:
if is_evicted_parallel_sampling_request(state):
requests.append(
PrefillRequest(
request_id=state.request_id,
token_ids=state.prompt_token_ids,
prompt_mask=state.prompt_mask,
num_sequence=state.num_sequences,
sampling_params=state.sampling_params,
try:
if is_evicted_parallel_sampling_request(state):
requests.append(
PrefillRequest(
request_id=state.request_id,
token_ids=state.prompt_token_ids,
prompt_mask=state.prompt_mask,
num_sequence=state.num_sequences,
sampling_params=state.sampling_params,
)
)
)

token_counts += len(state.prompt_token_ids)
token_counts += len(state.prompt_token_ids)

for gen_seq in state.generation_sequences:
# TODO(vvchernov): This is for repetion penalty
# Not obvious EvalMultiQueryRequest needs this
# Now empty instead of state.prompt_mask
vocab_size = state.sampling_params.vocab_size
prompt_mask = torch.zeros((vocab_size,), dtype=torch.bool)
requests.append(
EvalMultiQueryRequest(
sequence_id=gen_seq.seq_id,
num_past_tokens=state.prompt_len,
prompt_mask=prompt_mask,
queries=EvictedTokens(gen_seq.generated_token_ids),
sampling_params=state.sampling_params,
)
)
cache_manager.extend(
gen_seq.seq_id,
len(gen_seq.generated_token_ids) + 1,
)

# TODO(masahi): How to account for token counts in EvalMultiQueryRequest in
# Prometheus metric?
elif not state.is_prefilled:
# Check if JSON mode is enabled
if state.sampling_params.json_schema is not None:
# Convert schema into json string
json_schema = json.dumps(state.sampling_params.json_schema)
# Build a regex (grammar) from json string
json_regex = build_regex_from_schema(
json_schema, whitespace_pattern=r"[ \n\t]?"
)
# Query fsm cache for FSM object
state.sampling_params.regex_fsm = regex_fsm_cache.query(json_regex)

if (
state.num_sequences == 1
and state.generation_sequences[0].generated_token_ids
):
# generated_token_ids is added for the case where the request is
# recovering from cache eviction.
token_ids = (
state.prompt_token_ids
+ state.generation_sequences[0].generated_token_ids
)
else:
token_ids = state.prompt_token_ids

for gen_seq in state.generation_sequences:
# TODO(vvchernov): This is for repetion penalty
# Not obvious EvalMultiQueryRequest needs this
# Now empty instead of state.prompt_mask
vocab_size = state.sampling_params.vocab_size
prompt_mask = torch.zeros((vocab_size,), dtype=torch.bool)
requests.append(
EvalMultiQueryRequest(
sequence_id=gen_seq.seq_id,
num_past_tokens=state.prompt_len,
prompt_mask=prompt_mask,
queries=EvictedTokens(gen_seq.generated_token_ids),
PrefillRequest(
request_id=state.request_id,
token_ids=token_ids,
prompt_mask=state.prompt_mask,
num_sequence=state.num_sequences,
sampling_params=state.sampling_params,
)
)
cache_manager.extend(
gen_seq.seq_id,
len(gen_seq.generated_token_ids) + 1,
)

# TODO(masahi): How to account for token counts in EvalMultiQueryRequest in
# Prometheus metric?
elif not state.is_prefilled:
# Check if JSON mode is enabled
if state.sampling_params.json_schema is not None:
# Convert schema into json string
json_schema = json.dumps(state.sampling_params.json_schema)
# Build a regex (grammar) from json string
json_regex = build_regex_from_schema(json_schema, whitespace_pattern=r"[ \n\t]?")
# Query fsm cache for FSM object
state.sampling_params.regex_fsm = regex_fsm_cache.query(json_regex)

if (
state.num_sequences == 1
and state.generation_sequences[0].generated_token_ids
):
# generated_token_ids is added for the case where the request is
# recovering from cache eviction.
token_ids = (
state.prompt_token_ids
+ state.generation_sequences[0].generated_token_ids
)
else:
token_ids = state.prompt_token_ids
token_counts += len(state.prompt_token_ids)

requests.append(
PrefillRequest(
except Exception as exc:
LOG.exception(
"An exception occurred creating internal request types.",
request_id=state.request_id,
exc=exc,
)
failed_requests.append(
FailedRequest(
request_id=state.request_id,
token_ids=token_ids,
prompt_mask=state.prompt_mask,
num_sequence=state.num_sequences,
sampling_params=state.sampling_params,
error=exc,
)
)

token_counts += len(state.prompt_token_ids)

LOG.debug(
"Creating prompt batch.",
num_requests=len(requests),
total_tokens=token_counts,
)
LOG.debug(
"Creating prompt batch.",
num_requests=len(requests),
total_tokens=token_counts,
)
else:
for state in current_states:
for gen_seq in state.generation_sequences:
if not gen_seq.is_finished:
prompt_counts = len(state.prompt_token_ids)
requests.append(
DecodeRequest(
sequence_id=gen_seq.seq_id,
prompt_token_counts=prompt_counts,
prompt_mask=state.prompt_mask,
token_ids=gen_seq.generated_token_ids,
sampling_params=state.sampling_params,
try:
for gen_seq in state.generation_sequences:
if not gen_seq.is_finished:
prompt_counts = len(state.prompt_token_ids)
requests.append(
DecodeRequest(
sequence_id=gen_seq.seq_id,
prompt_token_counts=prompt_counts,
prompt_mask=state.prompt_mask,
token_ids=gen_seq.generated_token_ids,
sampling_params=state.sampling_params,
)
)
cache_manager.extend(
gen_seq.seq_id,
prompt_counts
+ len(gen_seq.generated_token_ids)
- gen_seq.next_start_position,
)
except Exception as exc:
LOG.exception(
"An exception occurred creating internal request types.",
request_id=state.request_id,
exc=exc,
)
failed_requests.append(
FailedRequest(
request_id=state.request_id,
error=exc,
)
cache_manager.extend(
gen_seq.seq_id,
prompt_counts
+ len(gen_seq.generated_token_ids)
- gen_seq.next_start_position,
)
)

token_counts = len(requests)
LOG.debug("Creating decode batch with %s requests.", token_counts)

return requests, is_prompt_batch, token_counts
return requests, failed_requests, is_prompt_batch, token_counts


def should_stop_by_length(
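A minimal sketch of the guard pattern the new _schema_to_regex_fsm helper applies, using only names visible in this diff (build_regex_from_schema from mlc_serve.engine.constrained, JSONModeError from mlc_serve.errors). The FSM-cache lookup is omitted here because FSMCache's constructor is not shown in the diff; treat this as an illustration of the pattern, not the exact PR code.

```python
import json
from typing import Any

from mlc_serve.engine.constrained import build_regex_from_schema
from mlc_serve.errors import JSONModeError


def schema_to_json_regex(json_schema: Any) -> str:
    """Convert a (possibly malformed) JSON schema into a constraining regex."""
    try:
        # json.dumps fails on schemas that are not JSON-serializable;
        # build_regex_from_schema fails on schemas it cannot translate.
        schema_str = json.dumps(json_schema)
        return build_regex_from_schema(schema_str, whitespace_pattern=r"[ \n\t]?")
    except Exception as exc:
        # Re-raise as a typed error so the engine can reject just this request.
        raise JSONModeError("Failed to construct FSM.") from exc


# A malformed schema (here, one containing a non-serializable value) now
# surfaces as JSONModeError instead of an unhandled exception in the engine.
try:
    schema_to_json_regex({"type": "object", "properties": set()})
except JSONModeError as err:
    print(f"Rejected JSON-mode request: {err}")
```

Catching a broad Exception and converting it to JSONModeError keeps the policy decision, abort only the offending request, at the engine layer rather than inside the schema-translation code.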
6 changes: 6 additions & 0 deletions serve/mlc_serve/engine/model_module.py
@@ -71,6 +71,12 @@ class EvalMultiQueryRequest:
sampling_params: SamplingParams


@dataclass
class FailedRequest:
request_id: RequestId
error: Exception


RequestType = Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]


18 changes: 14 additions & 4 deletions serve/mlc_serve/engine/staging_engine_worker.py
@@ -233,7 +233,17 @@ def step(self) -> GenerationLoopWorkerOutput:
)
return result

requests, is_prompt_batch = self._get_requests_to_process()
requests, failed_requests, is_prompt_batch = self._get_requests_to_process()

if len(failed_requests) > 0:
states = []

for request in failed_requests:
states.append(self.current_batch[request.request_id])
states[-1].validation_err = ValidationError(str(request.error))

outputs += self.create_aborted_outputs(states, finish_reason=FinishReason.Cancelled)

results = self.text_generator.generate(requests, self.cache_manager.get_cache())
LOG.debug("Finished text generation.")

@@ -334,16 +344,16 @@ def _adjust_batch(self):
num_new_batched_tokens = self.try_grow_batch(num_new_batched_tokens)

def _get_requests_to_process(self):
requests, is_prompt_batch, token_counts = get_requests_to_process(
self.current_batch.values(), self.cache_manager, self.regex_fsm_cache, self.tokenizer
requests, failed_requests, is_prompt_batch, token_counts = get_requests_to_process(
self.current_batch.values(), self.cache_manager, self.regex_fsm_cache,
)

if is_prompt_batch:
self.prom_metrics.histogram(BATCHED_PREFILL_TOKENS).observe(token_counts)
else:
self.prom_metrics.histogram(BATCHED_DECODE_TOKENS).observe(token_counts)

return requests, is_prompt_batch
return requests, failed_requests, is_prompt_batch

def _has_request_to_process(self) -> bool:
return bool(self.queue or self.current_batch)
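For reference, a condensed, hypothetical sketch of the consumer-side handling added to step() above: each FailedRequest is looked up in the current batch, tagged with a ValidationError, and emitted as a cancelled output so the rest of the batch keeps generating. The import location of ValidationError and FinishReason is not shown in this diff and is assumed below.

```python
# Assumption: ValidationError and FinishReason live in the engine's base module;
# the diff does not show their import, so verify the path before reusing this.
from mlc_serve.engine.base import FinishReason, ValidationError


def abort_failed_requests(worker, failed_requests):
    """Turn per-request construction failures into cancelled outputs."""
    states = []
    for failed in failed_requests:
        state = worker.current_batch[failed.request_id]
        # Record the reason so the client receives a validation error
        # rather than a generic engine failure.
        state.validation_err = ValidationError(str(failed.error))
        states.append(state)
    # Abort only these requests; the remaining batch continues unaffected.
    return worker.create_aborted_outputs(states, finish_reason=FinishReason.Cancelled)
```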
5 changes: 3 additions & 2 deletions serve/mlc_serve/engine/sync_engine.py
@@ -145,9 +145,10 @@ def step(self) -> InferenceStepResult:
if not self.current_batch:
return InferenceStepResult(outputs)

requests, _, _ = get_requests_to_process(
list(self.current_batch.values()), self.cache_manager, self.regex_fsm_cache, self.tokenizer
requests, _, _, _ = get_requests_to_process(
list(self.current_batch.values()), self.cache_manager, self.regex_fsm_cache,
)

results = self.text_generator.generate(requests, self.cache_manager.get_cache())
logger.debug("Finished text generation.")
