Protect against invalid request format
masahi committed Mar 14, 2024
1 parent 2b22f56 commit d68566b
Showing 1 changed file with 29 additions and 10 deletions.
39 changes: 29 additions & 10 deletions serve/mlc_serve/engine/staging_engine.py
@@ -57,6 +57,9 @@ def __init__(
         self.next_generation_output = None
         self.requests_lock = Lock()
         self.requests = dict[RequestId, RequestState]()
+        self.requests_to_be_cancelled_lock = Lock()
+        # Error message for each request that fails to be added to the engine
+        self.requests_to_be_cancelled = dict[RequestId, str]()

         # TODO(@team): This is a temporary solution to expose model config to higher API layer.
         # Follow-up with the proper solution
@@ -119,13 +122,17 @@ def add(self, requests: list[Request]):
                 assert isinstance(req.stopping_criteria.stop_sequences, list)

             # If the request violates the tokenization, this returns None, so skip.
-            state = get_new_request_state(
-                req,
-                self.conversation_template,
-                self.tokenizer,
-                self.model_artifact_config.vocab_size,
-            )
-            new_request_states.append(state)
+            try:
+                state = get_new_request_state(
+                    req,
+                    self.conversation_template,
+                    self.tokenizer,
+                    self.model_artifact_config.vocab_size,
+                )
+                new_request_states.append(state)
+            except Exception as e:
+                with self.requests_to_be_cancelled_lock:
+                    self.requests_to_be_cancelled[req.request_id] = str(e)

         self.command_queue.put(AddRequestsCommand(request_states=new_request_states))

@@ -171,11 +178,25 @@ def step(self) -> InferenceStepResult:
             has_pending_requests=self.has_pending_requests(),
         )

+        outputs = list[RequestOutput]()
+
+        with self.requests_to_be_cancelled_lock:
+            if len(self.requests_to_be_cancelled) > 0:
+                for req_id, err_msg in self.requests_to_be_cancelled.items():
+                    outputs.append(
+                        RequestOutput(
+                            req_id,
+                            sequences=[],
+                            error=err_msg,
+                        )
+                    )
+                self.requests_to_be_cancelled.clear()
+
         if not self._is_ready_to_serve():
             raise RuntimeError("GenerationLoopWorker process is not running")

         if not self.has_pending_requests():
-            return InferenceStepResult([])
+            return InferenceStepResult(outputs)

         if self.next_generation_output is None:
             generation_output = self.result_queue.get()
@@ -188,8 +209,6 @@
                 f"Error from GenerationLoopWorker process: {generation_output.error}"
             ) from generation_output.error

-        outputs = list[RequestOutput]()
-
         with self.requests_lock:
             LOG.debug(
                 "StagingInferenceEngine.step obtained requests_lock",

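In short, add() now wraps get_new_request_state in a try/except, records the error message for any request whose format cannot be processed in requests_to_be_cancelled, and step() drains that dict into RequestOutput entries carrying the error instead of failing the whole add() call. Below is a minimal sketch of how a caller might observe such a rejection; it assumes an already-constructed engine and request, and that RequestOutput exposes request_id and error attributes and InferenceStepResult exposes an outputs list (these names are assumptions based on the diff, not code from this commit).

    # Illustrative sketch only: `engine`, `bad_request`, and the attribute names
    # below are assumptions inferred from the diff, not part of this commit.
    engine.add([bad_request])   # a request whose format fails validation in add()
    result = engine.step()      # step() now reports the failure instead of raising

    for output in result.outputs:           # assumed list of RequestOutput
        if output.error is not None:
            # error carries the message stored in requests_to_be_cancelled
            print(f"request {output.request_id} was rejected: {output.error}")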