Skip to content

Commit

Permalink
[TPU] Support tensor parallelism in async llm engine (vllm-project#6891)
Browse files Browse the repository at this point in the history
  • Loading branch information
etwk authored Jul 29, 2024
1 parent 60d1c6e commit 7f8d612
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
3 changes: 3 additions & 0 deletions Dockerfile.tpu
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ RUN pip install "numpy<2"
RUN pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
RUN pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

# Fix FastAPI dependence
RUN pip install "starlette<0.38.0"

# Build vLLM.
COPY . /workspace/vllm
ENV VLLM_TARGET_DEVICE="tpu"
Expand Down
10 changes: 8 additions & 2 deletions vllm/engine/async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,8 +407,14 @@ def _get_executor_cls(
from vllm.executor.neuron_executor import NeuronExecutorAsync
executor_class = NeuronExecutorAsync
elif engine_config.device_config.device_type == "tpu":
from vllm.executor.tpu_executor import TPUExecutorAsync
executor_class = TPUExecutorAsync
if distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
executor_class = RayTPUExecutorAsync
else:
assert distributed_executor_backend is None
from vllm.executor.tpu_executor import TPUExecutorAsync
executor_class = TPUExecutorAsync
elif engine_config.device_config.device_type == "cpu":
from vllm.executor.cpu_executor import CPUExecutorAsync
executor_class = CPUExecutorAsync
Expand Down

0 comments on commit 7f8d612

Please sign in to comment.