vLLM model handler efficiency improvements #32687

Merged: 5 commits, Oct 15, 2024
2 changes: 1 addition & 1 deletion .github/trigger_files/beam_PostCommit_Python.json
@@ -1,5 +1,5 @@
 {
   "comment": "Modify this file in a trivial way to cause this test suite to run.",
-  "modification": 3
+  "modification": 4
 }

103 changes: 70 additions & 33 deletions sdks/python/apache_beam/ml/inference/vllm_inference.py
@@ -17,6 +17,7 @@
 
 # pytype: skip-file
 
+import asyncio
 import logging
 import os
 import subprocess
@@ -35,6 +36,7 @@
 from apache_beam.ml.inference.base import ModelHandler
 from apache_beam.ml.inference.base import PredictionResult
 from apache_beam.utils import subprocess_server
+from openai import AsyncOpenAI
 from openai import OpenAI
 
 try:
@@ -94,6 +96,15 @@ def getVLLMClient(port) -> OpenAI:
   )
 
 
+def getAsyncVLLMClient(port) -> AsyncOpenAI:
+  openai_api_key = "EMPTY"
+  openai_api_base = f"http://localhost:{port}/v1"
+  return AsyncOpenAI(
+      api_key=openai_api_key,
+      base_url=openai_api_base,
+  )
+
+
 class _VLLMModelServer():
   def __init__(self, model_name: str, vllm_server_kwargs: Dict[str, str]):
     self._model_name = model_name
@@ -184,6 +195,34 @@ def __init__(
   def load_model(self) -> _VLLMModelServer:
     return _VLLMModelServer(self._model_name, self._vllm_server_kwargs)
 
+  async def _async_run_inference(
+      self,
+      batch: Sequence[str],
+      model: _VLLMModelServer,
+      inference_args: Optional[Dict[str, Any]] = None
+  ) -> Iterable[PredictionResult]:
+    client = getAsyncVLLMClient(model.get_server_port())
+    inference_args = inference_args or {}
+    async_predictions = []
+    for prompt in batch:
+      try:
+        completion = client.completions.create(
+            model=self._model_name, prompt=prompt, **inference_args)
+        async_predictions.append(completion)
+      except Exception as e:
+        model.check_connectivity()
+        raise e
+
+    predictions = []
+    for p in async_predictions:
+      try:
+        predictions.append(await p)
+      except Exception as e:
+        model.check_connectivity()
+        raise e
+
+    return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
+
   def run_inference(
       self,
       batch: Sequence[str],
@@ -200,22 +239,7 @@ def run_inference(
     Returns:
       An Iterable of type PredictionResult.
     """
-    client = getVLLMClient(model.get_server_port())
-    inference_args = inference_args or {}
-    predictions = []
-    # TODO(https://github.com/apache/beam/issues/32528): We should add support
-    # for taking in batches and doing a bunch of async calls. That will end up
-    # being more efficient when we can do in bundle batching.
-    for prompt in batch:
-      try:
-        completion = client.completions.create(
-            model=self._model_name, prompt=prompt, **inference_args)
-        predictions.append(completion)
-      except Exception as e:
-        model.check_connectivity()
-        raise e
-
-    return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
+    return asyncio.run(self._async_run_inference(batch, model, inference_args))
 
   def share_model_across_processes(self) -> bool:
     return True
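
The completions handler above now builds its requests against the AsyncOpenAI client inside _async_run_inference and drives them from the synchronous run_inference entry point via asyncio.run. For orientation, here is a minimal, hypothetical sketch (not code from this diff) of dispatching several completion requests on one event loop with asyncio.gather; the port, model name, and prompts are placeholder values.

# Illustrative sketch only, not part of this PR. The port (8000), model name,
# and prompts are placeholders.
import asyncio

from openai import AsyncOpenAI


async def fan_out_completions(prompts, port, model_name):
  client = AsyncOpenAI(api_key="EMPTY", base_url=f"http://localhost:{port}/v1")
  # Each create() call on the async client returns an awaitable; gather
  # schedules them on the running event loop and collects the responses.
  requests = [
      client.completions.create(model=model_name, prompt=p) for p in prompts
  ]
  return await asyncio.gather(*requests)


# Synchronous callers can wrap the coroutine with asyncio.run, mirroring the
# run_inference wrapper above.
responses = asyncio.run(
    fan_out_completions(["Hello", "Bonjour"], 8000, "facebook/opt-125m"))
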
@@ -272,41 +296,54 @@ def load_model(self) -> _VLLMModelServer:
 
     return _VLLMModelServer(self._model_name, self._vllm_server_kwargs)
 
-  def run_inference(
+  async def _async_run_inference(
       self,
       batch: Sequence[Sequence[OpenAIChatMessage]],
       model: _VLLMModelServer,
       inference_args: Optional[Dict[str, Any]] = None
   ) -> Iterable[PredictionResult]:
-    """Runs inferences on a batch of text strings.
-
-    Args:
-      batch: A sequence of examples as OpenAI messages.
-      model: A _VLLMModelServer for connecting to the spun up server.
-      inference_args: Any additional arguments for an inference.
-
-    Returns:
-      An Iterable of type PredictionResult.
-    """
-    client = getVLLMClient(model.get_server_port())
+    client = getAsyncVLLMClient(model.get_server_port())
     inference_args = inference_args or {}
-    predictions = []
-    # TODO(https://github.com/apache/beam/issues/32528): We should add support
-    # for taking in batches and doing a bunch of async calls. That will end up
-    # being more efficient when we can do in bundle batching.
+    async_predictions = []
     for messages in batch:
       formatted = []
       for message in messages:
         formatted.append({"role": message.role, "content": message.content})
       try:
         completion = client.chat.completions.create(
             model=self._model_name, messages=formatted, **inference_args)
-        predictions.append(completion)
+        async_predictions.append(completion)
       except Exception as e:
         model.check_connectivity()
         raise e
 
+    predictions = []
+    for p in async_predictions:
+      try:
+        predictions.append(await p)
+      except Exception as e:
+        model.check_connectivity()
+        raise e
+
     return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
 
+  def run_inference(
+      self,
+      batch: Sequence[Sequence[OpenAIChatMessage]],
+      model: _VLLMModelServer,
+      inference_args: Optional[Dict[str, Any]] = None
+  ) -> Iterable[PredictionResult]:
+    """Runs inferences on a batch of text strings.
+
+    Args:
+      batch: A sequence of examples as OpenAI messages.
+      model: A _VLLMModelServer for connecting to the spun up server.
+      inference_args: Any additional arguments for an inference.
+
+    Returns:
+      An Iterable of type PredictionResult.
+    """
+    return asyncio.run(self._async_run_inference(batch, model, inference_args))
+
   def share_model_across_processes(self) -> bool:
     return True
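
Neither handler changes its public run_inference signature, so existing pipelines pick up the async client without code changes. As a reminder of how these handlers plug into a Beam pipeline, a hedged usage sketch follows; the class name VLLMCompletionsModelHandler and its model_name argument are assumptions based on the module this diff touches, so check them against the Beam release you use.

# Hypothetical pipeline sketch; the handler class name and constructor
# argument are assumed from apache_beam/ml/inference/vllm_inference.py.
import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.vllm_inference import VLLMCompletionsModelHandler

with beam.Pipeline() as pipeline:
  _ = (
      pipeline
      | beam.Create(["What is Apache Beam?", "Summarize vLLM in one line."])
      | RunInference(VLLMCompletionsModelHandler(model_name="facebook/opt-125m"))
      | beam.Map(print))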