
Commit

Squash 10647
Signed-off-by: Jefferson Fialho <[email protected]>
fialhocoelho committed Dec 12, 2024
1 parent 9c7773c commit 97a6423
Showing 5 changed files with 51 additions and 4 deletions.
3 changes: 3 additions & 0 deletions docs/source/index.rst
@@ -88,6 +88,9 @@ Documentation
serving/metrics
serving/integrations
serving/tensorizer
serving/compatibility_matrix
serving/weights_loading_with_fastsafetensor
serving/faq

.. toctree::
:maxdepth: 1
5 changes: 5 additions & 0 deletions docs/source/serving/weights_loading_with_fastsafetensor.rst
@@ -0,0 +1,5 @@
Loading model weights with fastsafetensors
===========================================

Using the fastsafetensors library enables loading model weights into GPU memory by leveraging GPU Direct Storage. See https://github.com/foundation-model-stack/fastsafetensors for more details.
To enable this feature, set the environment variable ``USE_FASTSAFETENSOR`` to ``true``.
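
A minimal usage sketch, assuming the flag is read when the engine is constructed and weights are loaded; the model name below is an arbitrary placeholder, not something mandated by this change:

import os

# Assumption: set the flag before vLLM loads any weights.
os.environ["USE_FASTSAFETENSOR"] = "true"

from vllm import LLM  # noqa: E402 (imported after setting the env var on purpose)

# Placeholder model; any safetensors checkpoint is loaded the same way.
llm = LLM(model="facebook/opt-125m")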
1 change: 1 addition & 0 deletions requirements-cuda.txt
@@ -8,3 +8,4 @@ torch == 2.5.1
# These must be updated alongside torch
torchvision == 0.20.1 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
xformers == 0.0.28.post3; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch 2.5.1
fastsafetensors # Required for model loading via GPU Direct Storage
17 changes: 13 additions & 4 deletions vllm/model_executor/model_loader/loader.py
@@ -42,9 +42,10 @@
set_default_torch_dtype)
from vllm.model_executor.model_loader.weight_utils import (
download_safetensors_index_file_from_hf, download_weights_from_hf,
filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
get_gguf_extra_tensor_names, gguf_quant_weights_iterator,
initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
fastsafetensors_weights_iterator, filter_duplicate_safetensors_files,
filter_files_not_needed_for_inference, get_gguf_extra_tensor_names,
gguf_quant_weights_iterator, initialize_dummy_weights,
np_cache_weights_iterator, pt_weights_iterator,
safetensors_weights_iterator)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
@@ -305,7 +306,15 @@ def _get_weights_iterator(
hf_weights_files,
)
elif use_safetensors:
weights_iterator = safetensors_weights_iterator(hf_weights_files)
use_fastsafe_tensor = os.getenv('USE_FASTSAFETENSOR',
'False').lower() == 'true'
if use_fastsafe_tensor:
logger.info("Using fastsafetensor for loading weights")
weights_iterator = fastsafetensors_weights_iterator(
hf_weights_files)
else:
weights_iterator = safetensors_weights_iterator(
hf_weights_files)
else:
weights_iterator = pt_weights_iterator(hf_weights_files)

29 changes: 29 additions & 0 deletions vllm/model_executor/model_loader/weight_utils.py
Expand Up @@ -14,6 +14,7 @@
import huggingface_hub.constants
import numpy as np
import torch
from fastsafetensors import SafeTensorsFileLoader, SingleGroup
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
from safetensors.torch import load_file, safe_open, save_file
from tqdm.auto import tqdm
@@ -410,6 +411,34 @@ def safetensors_weights_iterator(
yield name, param


def fastsafetensors_weights_iterator(
hf_weights_files: List[str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files
using fastsafetensor library."""
pg = SingleGroup()
if torch.distributed.is_initialized():
pg = torch.distributed.group.WORLD

device = torch.device(f'cuda:{pg.rank()}')
weight_files_sub_lists = [
hf_weights_files[i:i + pg.size()]
for i in range(0, len(hf_weights_files), pg.size())
]

for f_list in weight_files_sub_lists:
loader = SafeTensorsFileLoader(pg, device)
rank_file_map = {i: [f] for i, f in enumerate(f_list)}
loader.add_filenames(rank_file_map)
fb = loader.copy_files_to_device()
keys = list(fb.key_to_rank_lidx.keys())
for k in keys:
t = fb.get_tensor(k)
yield k, t
fb.close()
loader.close()


def pt_weights_iterator(
hf_weights_files: List[str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
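
For reference, a minimal sketch of consuming the new iterator directly; the shard names are placeholders that must exist on disk, a CUDA device is required since tensors are materialized on GPU, and inside vLLM the model loader drives this for you:

from typing import Dict

import torch

from vllm.model_executor.model_loader.weight_utils import (
    fastsafetensors_weights_iterator)

# Placeholder shard names; the iterator yields (name, tensor) pairs already on GPU.
files = ["model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"]

state_dict: Dict[str, torch.Tensor] = {}
for name, tensor in fastsafetensors_weights_iterator(files):
    state_dict[name] = tensor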
