diff --git a/vllm/config.py b/vllm/config.py
index dc0efc6ac..dd291ad60 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -87,12 +87,14 @@ def __init__(
         enforce_eager: bool = False,
         max_context_len_to_capture: Optional[int] = None,
         max_logprobs: int = 5,
+        local_files_only: bool = False,
     ) -> None:
         self.model = model
         self.tokenizer = tokenizer
         self.tokenizer_mode = tokenizer_mode
         self.trust_remote_code = trust_remote_code
         self.download_dir = download_dir
+        self.local_files_only = local_files_only
         self.load_format = load_format
         self.seed = seed
         self.revision = revision
@@ -110,6 +112,10 @@ def __init__(
             from modelscope.hub.snapshot_download import snapshot_download

             if not os.path.exists(model):
+                if self.local_files_only:
+                    raise ValueError(
+                        f"Unable to find cached ModelScope model for {model} "
+                        f"with local_files_only==True")
                 model_path = snapshot_download(model_id=model,
                                                cache_dir=download_dir,
                                                revision=revision)
@@ -119,8 +125,23 @@ def __init__(
             self.download_dir = model_path
             self.tokenizer = model_path

-        self.hf_config = get_config(self.model, trust_remote_code, revision,
-                                    code_revision)
+        elif self.local_files_only:
+            # TODO: fully support local_files_only propagation through
+            # each model class's load_weights function
+            #
+            # For places where we don't propagate local_files_only, modify
+            # the env var...
+            os.environ['HF_HUB_OFFLINE'] = "1"
+            # and monkey patch...
+            import huggingface_hub
+            huggingface_hub.constants.HF_HUB_OFFLINE = True
+
+        self.hf_config = get_config(self.model,
+                                    trust_remote_code=trust_remote_code,
+                                    local_files_only=local_files_only,
+                                    cache_dir=download_dir,
+                                    revision=revision,
+                                    code_revision=code_revision)
         self.hf_text_config = get_hf_text_config(self.hf_config)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
         self.max_model_len = _get_and_verify_max_len(self.hf_text_config,
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 834b1eec1..2ae8e0f8f 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -3,6 +3,8 @@
 from dataclasses import dataclass
 from typing import Optional, Tuple

+from huggingface_hub.constants import HF_HUB_OFFLINE
+
 from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, TokenizerPoolConfig,
                          VisionLanguageConfig)
@@ -17,6 +19,7 @@ class EngineArgs:
     tokenizer_mode: str = 'auto'
     trust_remote_code: bool = False
     download_dir: Optional[str] = None
+    local_files_only: bool = HF_HUB_OFFLINE  # checks TRANSFORMERS_OFFLINE too
     load_format: str = 'auto'
     dtype: str = 'auto'
     kv_cache_dtype: str = 'auto'
@@ -124,6 +127,11 @@ def add_cli_args(
                             help='directory to download and load the weights, '
                             'default to the default cache dir of '
                             'huggingface')
+        parser.add_argument(
+            '--local-files-only',
+            action='store_true',
+            default=EngineArgs.local_files_only,
+            help='disable downloads and only look at local files')
         parser.add_argument(
             '--load-format',
             type=str,
@@ -395,7 +403,7 @@ def create_engine_configs(
             self.dtype, self.seed, self.revision, self.code_revision,
             self.tokenizer_revision, self.max_model_len, self.quantization,
             self.enforce_eager, self.max_context_len_to_capture,
-            self.max_logprobs)
+            self.max_logprobs, self.local_files_only)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
                                    self.swap_space, self.kv_cache_dtype,
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 7928b36d8..bb95852f0 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -83,6 +83,7 @@ def __init__(
f"dtype={model_config.dtype}, " f"max_seq_len={model_config.max_model_len}, " f"download_dir={model_config.download_dir!r}, " + f"local_files_only={model_config.local_files_only}, " f"load_format={model_config.load_format}, " f"tensor_parallel_size={parallel_config.tensor_parallel_size}, " f"disable_custom_all_reduce=" @@ -231,6 +232,8 @@ def _init_tokenizer(self, **tokenizer_init_kwargs): max_input_length=None, tokenizer_mode=self.model_config.tokenizer_mode, trust_remote_code=self.model_config.trust_remote_code, + local_files_only=self.model_config.local_files_only, + cache_dir=self.model_config.download_dir, revision=self.model_config.tokenizer_revision) init_kwargs.update(tokenizer_init_kwargs) self.tokenizer: BaseTokenizerGroup = get_tokenizer_group( diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index ee047d08f..09fe0f042 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -66,7 +66,10 @@ async def _post_init(self): self.tokenizer = get_tokenizer( engine_model_config.tokenizer, tokenizer_mode=engine_model_config.tokenizer_mode, - trust_remote_code=engine_model_config.trust_remote_code) + trust_remote_code=engine_model_config.trust_remote_code, + local_files_only=engine_model_config.local_files_only, + cache_dir=engine_model_config.download_dir, + ) async def show_available_models(self) -> ModelList: """Show available models. Right now we only have one model.""" diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py index 9181f2988..e156d5b07 100644 --- a/vllm/model_executor/weight_utils.py +++ b/vllm/model_executor/weight_utils.py @@ -1,4 +1,5 @@ """Utilities for downloading and initializing model weights.""" +import contextlib import fnmatch import glob import hashlib @@ -9,8 +10,11 @@ import filelock import numpy as np +import requests import torch from huggingface_hub import HfFileSystem, snapshot_download +from huggingface_hub.constants import HF_HUB_OFFLINE +from huggingface_hub.utils import OfflineModeIsEnabled, RevisionNotFoundError from safetensors.torch import load_file, safe_open, save_file from tqdm.auto import tqdm @@ -143,14 +147,11 @@ def prepare_hf_model_weights( fall_back_to_pt: bool = True, revision: Optional[str] = None, ) -> Tuple[str, List[str], bool]: - # Download model weights from huggingface. - is_local = os.path.isdir(model_name_or_path) - use_safetensors = False + # Determine the format of weights to load # Some quantized models use .pt files for storing the weights. if load_format == "auto": allow_patterns = ["*.safetensors", "*.bin"] elif load_format == "safetensors": - use_safetensors = True allow_patterns = ["*.safetensors"] elif load_format == "pt": allow_patterns = ["*.pt"] @@ -162,29 +163,68 @@ def prepare_hf_model_weights( if fall_back_to_pt: allow_patterns += ["*.pt"] - if not is_local: - # Before we download we look at that is available: - fs = HfFileSystem() - file_list = fs.ls(model_name_or_path, detail=False, revision=revision) - - # depending on what is available we download different things - for pattern in allow_patterns: - matching = fnmatch.filter(file_list, pattern) - if len(matching) > 0: - allow_patterns = [pattern] - break - - logger.info(f"Using model weights format {allow_patterns}") - # Use file lock to prevent multiple processes from - # downloading the same model weights at the same time. 
-        with get_lock(model_name_or_path, cache_dir):
+    # Find the model weights to load:
+    # - check if pointing at a local directory
+    # - download weights from HuggingFace Hub (including a newer revision if it
+    #   exists)
+    # - discover weights in the local HuggingFace Hub cache (fallback to this if
+    #   download fails)
+
+    if os.path.isdir(model_name_or_path):
+        hf_folder = model_name_or_path
+    else:
+        # If there is an error downloading from the HF API, we'll fallback to
+        # loading from the local cache
+        local_files_only = False
+        if HF_HUB_OFFLINE:
+            local_files_only = True
+        else:
+            try:
+                # Before we download we check the available files
+                fs = HfFileSystem()
+                file_list = fs.ls(model_name_or_path,
+                                  detail=False,
+                                  revision=revision)
+
+                # Depending on what is available we download different things
+                for pattern in allow_patterns:
+                    matching = fnmatch.filter(file_list, pattern)
+                    if len(matching) > 0:
+                        allow_patterns = [pattern]
+                        break
+
+                logger.info(f"Using model weights format {allow_patterns}")
+            except (
+                    requests.exceptions.SSLError,
+                    requests.exceptions.ProxyError,
+                    requests.exceptions.ConnectionError,
+                    requests.exceptions.Timeout,
+                    OfflineModeIsEnabled,
+                    RevisionNotFoundError,
+                    FileNotFoundError,
+                    requests.HTTPError,
+            ) as error:
+                # If querying the repo fails (eg. Network is down / HF Hub is
+                # down / HF Hub returns access error / or HF_HUB_OFFLINE=1), see
+                # if we can fallback to load from locally cached files instead
+                # of crashing
+                logger.warning(f"Error in call to HF Hub: {error}. "
+                               f"Attempting to load from local cache instead.")
+                local_files_only = True
+
+        # Use file lock to prevent multiple processes from downloading the same
+        # model weights at the same time. If we fallback to local files only,
+        # we don't need the lock, but we still use snapshot_download to resolve
+        # the path to the model files in the cache
+        with get_lock(model_name_or_path, cache_dir
+                      ) if not local_files_only else contextlib.nullcontext():
             hf_folder = snapshot_download(model_name_or_path,
                                           allow_patterns=allow_patterns,
                                           cache_dir=cache_dir,
+                                          local_files_only=local_files_only,
                                           tqdm_class=Disabledtqdm,
                                           revision=revision)
-    else:
-        hf_folder = model_name_or_path
+
+    use_safetensors = False
     hf_weights_files: List[str] = []
     for pattern in allow_patterns:
         hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index 8a6ba6c5b..2b3d9b1dd 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -14,16 +14,23 @@
 }


-def get_config(model: str,
-               trust_remote_code: bool,
-               revision: Optional[str] = None,
-               code_revision: Optional[str] = None) -> PretrainedConfig:
+def get_config(
+    model: str,
+    trust_remote_code: bool,
+    revision: Optional[str] = None,
+    code_revision: Optional[str] = None,
+    cache_dir: Optional[str] = None,
+    local_files_only: bool = False,
+) -> PretrainedConfig:
     try:
         config = AutoConfig.from_pretrained(
             model,
             trust_remote_code=trust_remote_code,
             revision=revision,
-            code_revision=code_revision)
+            code_revision=code_revision,
+            cache_dir=cache_dir,
+            local_files_only=local_files_only,
+        )
     except ValueError as e:
         if (not trust_remote_code and
                 "requires you to execute the configuration file" in str(e)):
@@ -37,9 +44,13 @@ def get_config(model: str,
         raise e
     if config.model_type in _CONFIG_REGISTRY:
         config_class = _CONFIG_REGISTRY[config.model_type]
-        config = config_class.from_pretrained(model,
-                                              revision=revision,
-                                              code_revision=code_revision)
+        config = config_class.from_pretrained(
+            model,
+            revision=revision,
+            code_revision=code_revision,
+            cache_dir=cache_dir,
+            local_files_only=local_files_only,
+        )
     return config
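
A quick way to exercise the new path end to end (a sketch for reviewers, not part of the patch: the model ID is illustrative, and it assumes `LLM` forwards extra keyword arguments to `EngineArgs`, as the `LLM` entrypoint does today):

    # Assumes the weights are already in the local HF cache from a previous
    # online run; any cached model works, "facebook/opt-125m" is illustrative.
    from vllm import LLM

    # With this patch, local_files_only=True flows through ModelConfig into
    # config, tokenizer, and weight loading, so everything resolves from the
    # local cache and no Hub requests are made.
    llm = LLM(model="facebook/opt-125m", local_files_only=True)
    print(llm.generate("Hello, my name is")[0].outputs[0].text)

The same behavior is reachable on the OpenAI server via the new --local-files-only flag, and per the EngineArgs default above, setting HF_HUB_OFFLINE=1 (or TRANSFORMERS_OFFLINE=1) flips it on without any CLI change.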