diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index c2160fee348e5..b1602dd9496ba 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -182,7 +182,11 @@ steps: - pip install -r requirements-docs.txt - SPHINXOPTS=\"-W\" make html -- label: A100 status +- label: Distributed Tests (A100) gpu: a100 commands: - - nvidia-smi + # NOTE: don't test llama model here, it seems hf implementation is buggy + # see https://github.com/vllm-project/vllm/pull/5689 for details + - pytest -v -s distributed/test_custom_all_reduce.py + - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py + - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py index 3776c1f91a3f2..9a39160b8a462 100644 --- a/tests/distributed/test_custom_all_reduce.py +++ b/tests/distributed/test_custom_all_reduce.py @@ -11,7 +11,8 @@ from vllm.distributed.parallel_state import (get_tensor_model_parallel_group, get_tp_group, graph_capture) -from ..utils import (init_test_distributed_environment, +from ..utils import (ensure_model_parallel_initialized, + init_test_distributed_environment, multi_process_tensor_parallel) random.seed(42) @@ -27,8 +28,8 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port): torch.cuda.set_device(device) init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) - - group = get_tensor_model_parallel_group() + ensure_model_parallel_initialized(tp_size, pp_size) + group = get_tensor_model_parallel_group().device_group # A small all_reduce for warmup. # this is needed because device communicators might be created lazily