Skip to content
This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Commit

Permalink
✨ implement health probe
Browse files Browse the repository at this point in the history
Signed-off-by: Joe Runde <[email protected]>
  • Loading branch information
joerunde committed Jul 30, 2024
1 parent 453939b commit d8372de
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 3 deletions.
3 changes: 3 additions & 0 deletions vllm/engine/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,6 @@ async def do_log_stats(
model_output: Optional[List[SamplerOutput]] = None,
) -> None:
pass

async def check_health(self) -> None:
    """Check that the backend is healthy.

    Returns ``None`` when the backend is healthy. Implementations signal an
    unhealthy backend by raising an exception instead of returning a status
    value.
    """
4 changes: 2 additions & 2 deletions vllm/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
from starlette.routing import Mount
from transformers import AutoTokenizer

from vllm.config import ModelConfig
import vllm.envs as envs
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import VLLMBackend
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.protocol import VLLMBackend
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import make_arg_parser
# yapf conflicts with isort for this block
Expand Down
1 change: 1 addition & 0 deletions vllm/entrypoints/openai/rpc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# ZeroMQ endpoints shared by the RPC server (which binds them) and the RPC
# client (which connects to them). Each kind of traffic has its own port.
VLLM_GENERATE_RPC_PATH = "tcp://localhost:5570"
VLLM_GET_DATA_RPC_PATH = "tcp://localhost:5571"
VLLM_IS_READY_RPC_PATH = "tcp://localhost:5572"
# Dedicated endpoint for health-check pings.
VLLM_HEALTH_CHECK_PATH = "tcp://localhost:5584"


@dataclass
Expand Down
20 changes: 20 additions & 0 deletions vllm/entrypoints/openai/rpc/client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import pickle
from typing import AsyncIterator, Mapping, Optional

Expand All @@ -7,6 +8,7 @@
from vllm.config import DecodingConfig, ModelConfig
from vllm.entrypoints.openai.rpc import (VLLM_GENERATE_RPC_PATH,
VLLM_GET_DATA_RPC_PATH,
VLLM_HEALTH_CHECK_PATH,
VLLM_IS_READY_RPC_PATH,
GenerateRequest, GetDataRequest)
from vllm.inputs import PromptInputs
Expand Down Expand Up @@ -34,6 +36,10 @@ def __init__(self, tokenizer):
self.get_data_socket = self.context.socket(zmq.constants.REQ)
self.get_data_socket.connect(VLLM_GET_DATA_RPC_PATH)

# Health check socket
self.health_check_socket = self.context.socket(zmq.constants.DEALER)
self.health_check_socket.connect(VLLM_HEALTH_CHECK_PATH)

async def wait_for_server(self):
await self.is_ready_socket.recv()

Expand Down Expand Up @@ -104,3 +110,17 @@ async def generate(

yield request_output
socket.close()

async def check_health(self, timeout: float = 5) -> None:
    """Ping the RPC server's health endpoint and raise if it is unhealthy.

    Sends a pickled ``0`` on the health-check socket, waits for a single
    reply and unpickles it. A reply that is an exception instance is
    re-raised in the caller; any other reply is treated as healthy.

    Args:
        timeout: Seconds to wait for the reply. Defaults to 5, the value
            that used to be hard-coded here.

    Raises:
        asyncio.TimeoutError: No reply arrived within ``timeout`` seconds.
        Exception: The exception instance sent back by the peer.
    """
    await self.health_check_socket.send_multipart([pickle.dumps(0)])

    health_future = self.health_check_socket.recv()

    # NOTE(review): after a timeout the late reply may still be delivered to
    # this socket and be consumed by the *next* health check, which would
    # then report a stale result -- confirm whether replies need draining
    # or correlating with their request.
    message = await asyncio.wait_for(health_future, timeout=timeout)

    # NOTE: unpickling can execute arbitrary code. This is acceptable only
    # while the endpoint is the trusted local RPC server (tcp://localhost);
    # do not point this socket at an untrusted peer.
    response = pickle.loads(message)

    if isinstance(response, Exception):
        raise response
22 changes: 21 additions & 1 deletion vllm/entrypoints/openai/rpc/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from vllm import AsyncLLMEngine
from vllm.entrypoints.openai.rpc import (VLLM_GENERATE_RPC_PATH,
VLLM_GET_DATA_RPC_PATH,
VLLM_HEALTH_CHECK_PATH,
VLLM_IS_READY_RPC_PATH,
GetDataRequest)
from vllm.logger import init_logger
Expand Down Expand Up @@ -40,10 +41,15 @@ def __init__(self, async_engine_args):
self.get_data_socket = self.context.socket(zmq.constants.REP)
self.get_data_socket.bind(VLLM_GET_DATA_RPC_PATH)

# Init socket for health checks
self.health_check_socket = self.context.socket(zmq.constants.ROUTER)
self.health_check_socket.bind(VLLM_HEALTH_CHECK_PATH)

# Setup polling so we can listen on both sockets.
self.poller = zmq.asyncio.Poller()
self.poller.register(self.generate_socket, zmq.constants.POLLIN)
self.poller.register(self.get_data_socket, zmq.constants.POLLIN)
self.poller.register(self.health_check_socket, zmq.constants.POLLIN)

def cleanup(self):
"""Shuts down the zmq context and closes all sockets"""
Expand Down Expand Up @@ -82,6 +88,15 @@ async def generate(self, identity, message):
self.generate_socket.send_multipart(
[identity, pickle.dumps(e, pickle.HIGHEST_PROTOCOL)])

async def check_health(self, _, identity):
    """Probe the engine and send the result back to the requesting client.

    Args:
        _: Payload frame of the health-check request; its content is unused.
        identity: Routing frame received with the request. It is sent back
            as the first frame so the reply reaches the same client.

    The reply payload is a pickled ``0`` when ``engine.check_health()``
    returns normally, or the pickled exception it raised otherwise.
    """
    # Only the probe belongs in the try block: a failure while *sending* the
    # healthy reply must not be reported as an unhealthy engine (and retried
    # on the same socket).
    try:
        await self.engine.check_health()
        response = 0
    except Exception as e:
        response = e

    try:
        payload = pickle.dumps(response, pickle.HIGHEST_PROTOCOL)
    except Exception:
        # Some exceptions carry unpicklable state. Send a picklable stand-in
        # so the client still gets an error reply instead of timing out.
        payload = pickle.dumps(RuntimeError(repr(response)),
                               pickle.HIGHEST_PROTOCOL)

    await self.health_check_socket.send_multipart([identity, payload])

async def run_loop(self):
# Notify the RPC client that we are ready to receive requests.
await self.is_ready_socket.send_string("Ready!")
Expand All @@ -92,7 +107,6 @@ async def run_loop(self):
while True:
self.poll_future = self.poller.poll()
socks = dict(await self.poll_future)

task = None
if self.generate_socket in socks:
identity, message = await self.generate_socket.recv_multipart()
Expand All @@ -102,6 +116,12 @@ async def run_loop(self):
message = await self.get_data_socket.recv()
task = asyncio.create_task(self.get_data(message))

elif self.health_check_socket in socks:
identity, message = \
await self.health_check_socket.recv_multipart()
task = asyncio.create_task(self.check_health(
message, identity))

# We need to keep around a strong reference to the task,
# to avoid the task disappearing mid-execution as running tasks
# can be GC'ed. Below is a common "fire-and-forget" tasks
Expand Down

0 comments on commit d8372de

Please sign in to comment.