Running server and LLM in different processes (opendatahub-io#110)
* Putting server and LLM on different processes

* result queues cleanup

* Syncing server start with LLM init

* Renamed model to protocol to match the original endpoint

* Formatting

* Removing tuning parameters that are no longer used

* process spawn method from env

* Cleanup and refactor

* Where did this come from?

* mypy much

* One unhappy linter in the CI, one unhappy linter. Refactor it down, format it around, 3 unhappy linters in the CI
gshtras authored Jul 29, 2024
1 parent 904b5b8 commit a6414b8
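
The changes below move the LLM engine out of the server process: the two sides now communicate over multiprocessing queues, and the server blocks until the engine signals readiness. The following is a minimal, illustrative driver sketch inferred from the diff (the (request_id, prompt, sampling_params) request tuples and the ("Ready", None, None) sentinel appear there); the model name, request id, max_tokens value, and the hard-coded "spawn" start method are assumptions of this sketch, and the PR itself reads the start method from an environment variable rather than hard-coding it.

import multiprocessing as mp

from vllm import SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.entrypoints.fast_sync_llm import FastSyncLLM


def llm_process(engine_args, input_queue, result_queue):
    # The engine itself is constructed inside the child process (see run_engine).
    llm = FastSyncLLM(engine_args, input_queue, result_queue)
    llm.run_engine()


if __name__ == "__main__":
    ctx = mp.get_context("spawn")  # the PR takes the start method from an env var
    input_queue = ctx.Queue()
    result_queue = ctx.Queue()
    engine_args = EngineArgs(model="facebook/opt-125m")  # illustrative model
    worker = ctx.Process(target=llm_process,
                         args=(engine_args, input_queue, result_queue))
    worker.start()
    # Sync server start with LLM init: wait for the ("Ready", None, None) sentinel.
    assert result_queue.get()[0] == "Ready"
    # Requests are plain (request_id, prompt, sampling_params) tuples.
    input_queue.put(("req-0", "Hello, world", SamplingParams(max_tokens=8)))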
Showing 4 changed files with 287 additions and 264 deletions.
123 changes: 58 additions & 65 deletions vllm/entrypoints/fast_sync_llm.py
@@ -1,9 +1,7 @@
-import asyncio
-from asyncio import Queue as queue
-from asyncio import QueueEmpty
-from typing import Dict, Union
+import multiprocessing as mp
+from queue import Empty
+from typing import Union
 
-from vllm import envs
 from vllm.distributed.communication_op import broadcast_tensor_dict
 from vllm.engine.arg_utils import EngineArgs
 from vllm.engine.llm_engine import LLMEngine
@@ -22,38 +20,35 @@ class FastSyncLLM:
     def __init__(
         self,
         engine_args: EngineArgs,
-        input_queue: queue,
+        input_queue: mp.Queue,
+        result_queue: mp.Queue,
         **kwargs,
     ) -> None:
         if "disable_log_stats" not in kwargs:
             kwargs["disable_log_stats"] = True
-        self.llm_engine = LLMEngine.from_engine_args(
-            engine_args, usage_context=UsageContext.LLM_CLASS)
+        self.engine_args = engine_args
         self.request_counter = Counter()
 
         self.input_queue = input_queue
+        self.result_queue = result_queue
         self.finish = False
-        self.result_queues: Dict[str, asyncio.Queue] = {}
         self.need_restart = False
 
     def _add_request(
         self,
         inputs: PromptInputs,
         params: Union[SamplingParams, PoolingParams],
         request_id: str,
-        result_queue: queue,
     ) -> None:
-        self.result_queues[request_id] = result_queue
         if isinstance(inputs, list):
             inputs = TextTokensPrompt(prompt_token_ids=inputs)
         self.llm_engine.add_request(request_id, inputs, params)
 
-    async def _poll_requests(self):
+    def _poll_requests(self):
         while True:
             if not self.llm_engine.has_unfinished_requests():
                 logger.info("No unfinished requests. Waiting...")
-                (request_id, prompt, sampling_params,
-                 result_queue) = await self.input_queue.get()
+                (request_id, prompt, sampling_params) = self.input_queue.get()
                 if self.need_restart:
                     logger.info("Restarting worker loops")
                     for worker in self.llm_engine.model_executor.workers:
@@ -62,57 +57,55 @@ async def _poll_requests(self):
 
             else:
                 try:
-                    (request_id, prompt, sampling_params,
-                     result_queue) = self.input_queue.get_nowait()
-                except QueueEmpty:
+                    (request_id, prompt,
+                     sampling_params) = self.input_queue.get_nowait()
+                except Empty:
                     break
-                self._add_request(prompt, sampling_params, request_id,
-                                  result_queue)
+                self._add_request(prompt, sampling_params, request_id)
 
+    def run_engine(self):
+        self.llm_engine = LLMEngine.from_engine_args(
+            self.engine_args, usage_context=UsageContext.LLM_CLASS)
+
-    #@async_rpd_trace()
-    async def run_engine(self):
-        steps_before_yield = envs.VLLM_ENGINE_STEPS_BEFORE_YIELD
+        self.result_queue.put(("Ready", None, None))
         request_stats = {}
-        while True:
-            await self._poll_requests()
-            step_outputs = self.llm_engine.step()
-            if not self.llm_engine.has_unfinished_requests():
-                logger.info("Broadcast stop")
-                broadcast_tensor_dict({}, src=0)
-                self.need_restart = True
-                steps_before_yield = 0
-            for output in step_outputs:
-                assert len(output.outputs) == 1
-                output_len = len(output.outputs[0].text)
-                result_queue = self.result_queues[output.request_id]
-                stats = None
-                if output_len >= 0 and (output.request_id
-                                        not in request_stats):
-                    request_stats[output.request_id] = output_len
-                    result = output.outputs[0].text
-                    steps_before_yield = min(
-                        steps_before_yield, envs.VLLM_ENGINE_STEPS_FIRST_TOKEN)
-                else:
-                    result = output.outputs[0].text[
-                        request_stats[output.request_id]:output_len]
-                if output.finished:
-                    stats = {
-                        "prompt": len(output.prompt_token_ids),
-                        "tokens": len(output.outputs[0].token_ids),
-                        "finish_reason": output.outputs[0].finish_reason,
-                        "stop_reason": output.outputs[0].stop_reason,
-                    }
-                    del request_stats[output.request_id]
-                    del self.result_queues[output.request_id]
-                    steps_before_yield = min(
-                        steps_before_yield,
-                        envs.VLLM_ENGINE_STEPS_COMPLETED_REQUEST)
-                else:
-                    request_stats[output.request_id] = output_len
-                result_queue.put_nowait((output.request_id, result, stats))
-            steps_before_yield -= 1
-            if steps_before_yield <= 0:
-                logger.info("Engine yield. Requests: %d",
-                            self.llm_engine.get_num_unfinished_requests())
-                steps_before_yield = envs.VLLM_ENGINE_STEPS_BEFORE_YIELD
-                await asyncio.sleep(0)
+        log_interval = 100
+        try:
+            while True:
+                self._poll_requests()
+                step_outputs = self.llm_engine.step()
+                log_interval -= 1
+                if log_interval == 0:
+                    log_interval = 100
+                    logger.info("Step finished. Unfinished requests: %d",
+                                self.llm_engine.get_num_unfinished_requests())
+                if not self.llm_engine.has_unfinished_requests():
+                    logger.info("Broadcast stop")
+                    broadcast_tensor_dict({}, src=0)
+                    self.need_restart = True
+                for output in step_outputs:
+                    assert len(output.outputs) == 1
+                    output_len = len(output.outputs[0].text)
+                    stats = None
+                    if output_len >= 0 and (output.request_id
+                                            not in request_stats):
+                        request_stats[output.request_id] = output_len
+                        result = output.outputs[0].text
+                    else:
+                        result = output.outputs[0].text[
+                            request_stats[output.request_id]:output_len]
+                    if output.finished:
+                        stats = {
+                            "prompt": len(output.prompt_token_ids),
+                            "tokens": len(output.outputs[0].token_ids),
+                            "finish_reason": output.outputs[0].finish_reason,
+                            "stop_reason": output.outputs[0].stop_reason,
+                        }
+                        del request_stats[output.request_id]
+                    else:
+                        request_stats[output.request_id] = output_len
+                    self.result_queue.put_nowait(
+                        (output.request_id, result, stats))
+        except Exception as e:
+            logger.error("Error in run_engine: %s", e)
+            raise e
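
The rewritten run_engine() above streams each request's output as incremental text slices on the single shared result queue, attaching a stats dict only to the final chunk. A minimal consumer sketch under that reading (the helper name and the expected_requests bookkeeping are illustrative, and it presumes the ("Ready", None, None) handshake item has already been consumed):

def collect_results(result_queue, expected_requests):
    # request_id -> text accumulated from the streamed deltas
    texts = {}
    # request_id -> stats dict delivered with the final chunk
    finished = {}
    while len(finished) < expected_requests:
        request_id, delta, stats = result_queue.get()
        texts[request_id] = texts.get(request_id, "") + delta
        if stats is not None:
            finished[request_id] = stats
    return texts, finished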
