[Doc] Update Ray Data distributed offline inference example (vllm-pro…
Yard1 authored May 17, 2024
1 parent 48d5985 commit c5711ef
Showing 1 changed file with 42 additions and 6 deletions.
examples/offline_inference_distributed.py
@@ -9,19 +9,31 @@

import numpy as np
import ray
from packaging.version import Version
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

from vllm import LLM, SamplingParams

assert Version(ray.__version__) >= Version(
    "2.22.0"), "Ray version must be at least 2.22.0"

# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Set tensor parallelism per instance.
tensor_parallel_size = 1

# Set number of instances. Each instance will use tensor_parallel_size GPUs.
num_instances = 1
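# (Illustration: num_instances = 4 with tensor_parallel_size = 2 would
# occupy 4 * 2 = 8 GPUs in total.)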


# Create a class to do batch inference.
class LLMPredictor:

    def __init__(self):
        # Create an LLM.
        self.llm = LLM(model="meta-llama/Llama-2-7b-chat-hf",
                       tensor_parallel_size=tensor_parallel_size)

    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]:
        # Generate texts from the prompts.
@@ -43,17 +55,41 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]:
# from cloud storage (such as JSONL, Parquet, CSV, binary format).
ds = ray.data.read_text("s3://anonymous@air-example-data/prompts.txt")
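# Other formats have matching readers, e.g. Parquet input could be read with
# ray.data.read_parquet("s3://<your-bucket>/<path>") (placeholder path).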


# For tensor_parallel_size > 1, we need to create placement groups for vLLM
# to use. Every actor has to have its own placement group.
def scheduling_strategy_fn():
    # One bundle per tensor parallel worker.
    pg = ray.util.placement_group(
        [{
            "GPU": 1,
            "CPU": 1
        }] * tensor_parallel_size,
        # STRICT_PACK places all bundles on the same node, so each instance's
        # tensor-parallel workers stay together on one machine.
        strategy="STRICT_PACK",
    )
    return dict(scheduling_strategy=PlacementGroupSchedulingStrategy(
        pg, placement_group_capture_child_tasks=True))


resources_kwarg = {}
if tensor_parallel_size == 1:
    # For tensor_parallel_size == 1, we simply set num_gpus=1.
    resources_kwarg["num_gpus"] = 1
else:
    # Otherwise, we have to set num_gpus=0 and provide
    # a function that will create a placement group for
    # each instance.
    resources_kwarg["num_gpus"] = 0
    resources_kwarg["ray_remote_args_fn"] = scheduling_strategy_fn

# Apply batch inference for all input data.
ds = ds.map_batches(
    LLMPredictor,
    # Set the concurrency to the number of LLM instances.
    concurrency=num_instances,
    # Specify the batch size for inference.
    batch_size=32,
    **resources_kwarg,
)
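# NOTE: Because map_batches receives a class rather than a function, Ray Data
# runs it as a pool of long-lived actors; each of the num_instances workers
# loads the model once and reuses it across batches.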

# Peek first 10 results.
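The remainder of the example is collapsed in this view. As a minimal sketch of
the peek step, assuming the predictor returns "prompt" and "generated_text"
columns as constructed above (ds.take and ds.write_parquet are standard Ray
Data APIs; the output bucket is a placeholder):

outputs = ds.take(10)
for output in outputs:
    prompt = output["prompt"]
    generated_text = output["generated_text"]
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

# For a full run, write all results out instead of peeking, e.g.:
# ds.write_parquet("s3://<your-output-bucket>")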
