From 92bfaab00567a9f275ef00f77ff2934779e1e56f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Thu, 14 Sep 2023 17:40:50 +0200 Subject: [PATCH] Add support for loading models from an fsspec filesystem (#327) * Add support for loading models from an fsspec filesystem This change adds a method `from_fsspec` to the `FromHFHub` mixins. This new method can be used to load a model from an fsspec filesystem. We should rename `FromHFHub` in the next major semver release, since this PR extends loading support beyond Hugging Face hub. * Doc fix Co-authored-by: Madeesh Kannan * Name functions more generically * Apply suggestions from code review Add fixes from @shadeMe Co-authored-by: Madeesh Kannan * Add missing type hints * Fix argument order * Remove an exception * black * Type fix --------- Co-authored-by: Madeesh Kannan --- curated_transformers/models/auto_model.py | 127 +++++++- curated_transformers/models/hf_hub.py | 110 +++++-- .../tests/models/test_auto_models.py | 19 +- .../tests/models/test_hf_hub.py | 54 +++- curated_transformers/tests/models/util.py | 15 +- .../tokenizers/legacy/test_bert_tokenizer.py | 2 +- .../legacy/test_camembert_tokenizer.py | 3 +- .../tokenizers/legacy/test_llama_tokenizer.py | 2 +- .../legacy/test_roberta_tokenizer.py | 2 +- .../tokenizers/legacy/test_xlmr_tokenizer.py | 3 +- .../tests/tokenizers/test_auto_tokenizer.py | 11 + .../tests/tokenizers/test_hf_hub.py | 17 + curated_transformers/tests/tokenizers/util.py | 12 +- .../tokenizers/auto_tokenizer.py | 94 +++++- curated_transformers/tokenizers/hf_hub.py | 56 +++- .../tokenizers/legacy/bert_tokenizer.py | 16 +- .../tokenizers/legacy/camembert_tokenizer.py | 17 +- .../tokenizers/legacy/llama_tokenizer.py | 19 +- .../tokenizers/legacy/roberta_tokenizer.py | 26 +- .../tokenizers/legacy/xlmr_tokenizer.py | 17 +- curated_transformers/tokenizers/tokenizer.py | 29 +- curated_transformers/util/fsspec.py | 291 ++++++++++++++++++ curated_transformers/util/hf.py | 58 ++-- curated_transformers/util/serde.py | 137 ++++++++- docs/source/api-compat.rst | 1 + docs/source/usage.rst | 35 +++ requirements.txt | 2 +- setup.cfg | 2 +- 28 files changed, 1016 insertions(+), 161 deletions(-) create mode 100644 curated_transformers/util/fsspec.py diff --git a/curated_transformers/models/auto_model.py b/curated_transformers/models/auto_model.py index 4ba24234..b57f5181 100644 --- a/curated_transformers/models/auto_model.py +++ b/curated_transformers/models/auto_model.py @@ -1,11 +1,13 @@ from abc import ABC, abstractmethod -from typing import Dict, Generic, Optional, Type, TypeVar +from typing import Any, Dict, Generic, Optional, Type, TypeVar import torch +from fsspec import AbstractFileSystem from ..layers.cache import KeyValueCache from ..quantization.bnb.config import BitsAndBytesConfig -from ..util.hf import get_hf_config_model_type +from ..util.fsspec import get_config_model_type as get_config_model_type_fsspec +from ..util.hf import get_config_model_type from .albert import ALBERTEncoder from .bert import BERTEncoder from .camembert import CamemBERTEncoder @@ -30,13 +32,36 @@ class AutoModel(ABC, Generic[ModelT]): _hf_model_type_to_curated: Dict[str, Type[FromHFHub]] = {} + @classmethod + def _resolve_model_cls_fsspec( + cls, + fs: AbstractFileSystem, + model_path: str, + fsspec_args: Optional[Dict[str, Any]] = None, + ) -> Type[FromHFHub]: + model_type = get_config_model_type_fsspec( + fs, model_path, fsspec_args=fsspec_args + ) + if model_type is None: + raise ValueError( + "The model type is not defined 
in the model configuration." + ) + module_cls = cls._hf_model_type_to_curated.get(model_type) + if module_cls is None: + raise ValueError( + f"Unsupported model type `{model_type}` for {cls.__name__}. " + f"Supported model types: {tuple(cls._hf_model_type_to_curated.keys())}" + ) + assert issubclass(module_cls, FromHFHub) + return module_cls + @classmethod def _resolve_model_cls( cls, name: str, revision: str, ) -> Type[FromHFHub]: - model_type = get_hf_config_model_type(name, revision) + model_type = get_config_model_type(name, revision) module_cls = cls._hf_model_type_to_curated.get(model_type) if module_cls is None: raise ValueError( @@ -46,6 +71,25 @@ def _resolve_model_cls( assert issubclass(module_cls, FromHFHub) return module_cls + @classmethod + def _instantiate_model_from_fsspec( + cls, + fs: AbstractFileSystem, + model_path: str, + fsspec_args: Optional[Dict[str, Any]], + device: Optional[torch.device], + quantization_config: Optional[BitsAndBytesConfig], + ) -> FromHFHub: + module_cls = cls._resolve_model_cls_fsspec(fs, model_path) + module = module_cls.from_fsspec( + fs=fs, + model_path=model_path, + fsspec_args=fsspec_args, + device=device, + quantization_config=quantization_config, + ) + return module + @classmethod def _instantiate_model_from_hf_hub( cls, @@ -63,6 +107,35 @@ def _instantiate_model_from_hf_hub( ) return module + @classmethod + def from_fsspec( + cls, + *, + fs: AbstractFileSystem, + model_path: str, + fsspec_args: Optional[Dict[str, Any]] = None, + device: Optional[torch.device] = None, + quantization_config: Optional[BitsAndBytesConfig] = None, + ) -> ModelT: + """ + Construct a module and load its parameters from a fsspec filesystem. + + :param fs: + The filesystem to load the model from. + :param model_path: + The path of the model on the filesystem. + :param fsspec_args: + Implementation-specific keyword arguments to pass to fsspec + filesystem operations. + :param device: + Device on which the model is initialized. + :param quantization_config: + Configuration for loading quantized weights. + :returns: + Module with the parameters loaded. 
+ """ + raise NotImplementedError + @classmethod @abstractmethod def from_hf_hub( @@ -124,6 +197,22 @@ class AutoEncoder(AutoModel[EncoderModule]): "xlm-roberta": XLMREncoder, } + @classmethod + def from_fsspec( + cls, + *, + fs: AbstractFileSystem, + model_path: str, + fsspec_args: Optional[Dict[str, Any]] = None, + device: Optional[torch.device] = None, + quantization_config: Optional[BitsAndBytesConfig] = None, + ) -> EncoderModule: + encoder = cls._instantiate_model_from_fsspec( + fs, model_path, fsspec_args, device, quantization_config + ) + assert isinstance(encoder, EncoderModule) + return encoder + @classmethod def from_hf_hub( cls, @@ -154,6 +243,22 @@ class AutoDecoder(AutoModel[DecoderModule]): "RefinedWebModel": FalconDecoder, } + @classmethod + def from_fsspec( + cls, + *, + fs: AbstractFileSystem, + model_path: str, + fsspec_args: Optional[Dict[str, Any]] = None, + device: Optional[torch.device] = None, + quantization_config: Optional[BitsAndBytesConfig] = None, + ) -> DecoderModule: + decoder = cls._instantiate_model_from_fsspec( + fs, model_path, fsspec_args, device, quantization_config + ) + assert isinstance(decoder, DecoderModule) + return decoder + @classmethod def from_hf_hub( cls, @@ -184,6 +289,22 @@ class AutoCausalLM(AutoModel[CausalLMModule[KeyValueCache]]): "RefinedWebModel": FalconCausalLM, } + @classmethod + def from_fsspec( + cls, + *, + fs: AbstractFileSystem, + model_path: str, + fsspec_args: Optional[Dict[str, Any]] = None, + device: Optional[torch.device] = None, + quantization_config: Optional[BitsAndBytesConfig] = None, + ) -> CausalLMModule[KeyValueCache]: + causal_lm = cls._instantiate_model_from_fsspec( + fs, model_path, fsspec_args, device, quantization_config + ) + assert isinstance(causal_lm, CausalLMModule) + return causal_lm + @classmethod def from_hf_hub( cls, diff --git a/curated_transformers/models/hf_hub.py b/curated_transformers/models/hf_hub.py index b632db78..a9b4a552 100644 --- a/curated_transformers/models/hf_hub.py +++ b/curated_transformers/models/hf_hub.py @@ -1,4 +1,3 @@ -import json from abc import ABC, abstractmethod from typing import ( Any, @@ -14,12 +13,17 @@ ) import torch +from fsspec import AbstractFileSystem from torch import Tensor from ..quantization import prepare_module_for_quantization from ..quantization.bnb.config import BitsAndBytesConfig -from ..util.hf import get_model_checkpoint_filepaths, get_model_config_filepath -from ..util.serde import load_model_from_checkpoints +from ..util.fsspec import ( + get_model_checkpoint_files as get_model_checkpoint_files_fsspec, +) +from ..util.fsspec import get_model_config as get_model_config_fsspec +from ..util.hf import get_model_checkpoint_files, get_model_config +from ..util.serde import ModelCheckpointType, ModelFile, load_model_from_checkpoints # Only provided as typing.Self in Python 3.11+. Self = TypeVar("Self", bound="FromHFHub") @@ -89,8 +93,46 @@ def from_hf_hub_to_cache( :param revision: Model revision. """ - _ = get_model_config_filepath(name, revision) - _ = get_model_checkpoint_filepaths(name, revision) + _ = get_model_config(name, revision) + _ = get_model_checkpoint_files(name, revision) + + @classmethod + def from_fsspec( + cls: Type[Self], + *, + fs: AbstractFileSystem, + model_path: str, + fsspec_args: Optional[Dict[str, Any]] = None, + device: Optional[torch.device] = None, + quantization_config: Optional[BitsAndBytesConfig] = None, + ) -> Self: + """ + Construct a module and load its parameters from a fsspec filesystem. 
+ + :param fs: + The filesystem to load the model from. + :param model_path: + The path of the model on the filesystem. + :param fsspec_args: + Implementation-specific keyword arguments to pass to fsspec + filesystem operations. + :param device: + Device on which the model is initialized. + :param quantization_config: + Configuration for loading quantized weights. + :returns: + Module with the parameters loaded. + """ + return cls._create_and_load_model( + get_config=lambda: get_model_config_fsspec( + fs, model_path, fsspec_args=fsspec_args + ), + get_checkpoint_files=lambda: get_model_checkpoint_files_fsspec( + fs, model_path, fsspec_args=fsspec_args + ), + device=device, + quantization_config=quantization_config, + ) @classmethod def from_hf_hub( @@ -115,11 +157,39 @@ def from_hf_hub( :returns: Module with the parameters loaded. """ - # Download configuration and construct model. - config_filename = get_model_config_filepath(name, revision) - with open(config_filename, "r") as f: - config = json.load(f) - # Initialize the model on the torch `meta` device to avoid unnecessary allocations. + return cls._create_and_load_model( + get_config=lambda: get_model_config(name, revision), + get_checkpoint_files=lambda: get_model_checkpoint_files(name, revision), + device=device, + quantization_config=quantization_config, + ) + + @abstractmethod + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + non_blocking: bool = False, + ): + """ + Moves and/or casts the parameters and buffers. + + This method is automatically implemented by also deriving from + ``torch.nn.Module``. This mixin does not derive from ``Module`` in + order to be an abstract base class. + """ + ... + + @classmethod + def _create_and_load_model( + cls: Type[Self], + *, + get_config: Callable[[], Dict[Any, str]], + get_checkpoint_files: Callable[[], Tuple[List[ModelFile], ModelCheckpointType]], + device: Optional[torch.device] = None, + quantization_config: Optional[BitsAndBytesConfig] = None, + ) -> Self: + config = get_config() model = cls.from_hf_config(hf_config=config, device=torch.device("meta")) # Convert the model to the expected dtype. @@ -137,9 +207,7 @@ def from_hf_hub( tensor2param = None # Download model and convert HF parameter names to ours. - checkpoint_filenames, checkpoint_type = get_model_checkpoint_filepaths( - name, revision - ) + checkpoint_filenames, checkpoint_type = get_checkpoint_files() load_model_from_checkpoints( model, # type:ignore filepaths=checkpoint_filenames, @@ -156,22 +224,6 @@ def from_hf_hub( return model - @abstractmethod - def to( - self, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - non_blocking: bool = False, - ): - """ - Moves and/or casts the parameters and buffers. - - This method is automatically implemented by also deriving from - ``torch.nn.Module``. This mixin does not derive from ``Module`` in - order to be an abstract base class. - """ - ... 
- def _process_hf_keys( model_name: str, diff --git a/curated_transformers/tests/models/test_auto_models.py b/curated_transformers/tests/models/test_auto_models.py index 17b677cb..0a6c3fe8 100644 --- a/curated_transformers/tests/models/test_auto_models.py +++ b/curated_transformers/tests/models/test_auto_models.py @@ -1,4 +1,5 @@ import pytest +from huggingface_hub import HfFileSystem from curated_transformers.models import ( ALBERTEncoder, @@ -22,8 +23,9 @@ from curated_transformers.models.mpt.decoder import MPTDecoder -def test_auto_encoder(): - model_encoder_map = { +@pytest.fixture +def model_encoder_map(): + return { "explosion-testing/bert-test": BERTEncoder, "explosion-testing/albert-test": ALBERTEncoder, "explosion-testing/roberta-test": RoBERTaEncoder, @@ -31,6 +33,8 @@ def test_auto_encoder(): "explosion-testing/xlm-roberta-test": XLMREncoder, } + +def test_auto_encoder(model_encoder_map): for name, encoder_cls in model_encoder_map.items(): encoder = AutoEncoder.from_hf_hub(name=name) assert isinstance(encoder, encoder_cls) @@ -39,6 +43,17 @@ def test_auto_encoder(): AutoEncoder.from_hf_hub(name="explosion-testing/falcon-test") +@pytest.mark.slow +def test_auto_encoder_fsspec(model_encoder_map): + for name, encoder_cls in model_encoder_map.items(): + # The default revision is 'main', but we pass it anyway to test + # that the function acceps fsspec_args. + encoder = AutoEncoder.from_fsspec( + fs=HfFileSystem(), model_path=name, fsspec_args={"revision": "main"} + ) + assert isinstance(encoder, encoder_cls) + + def test_auto_decoder(): model_decoder_map = { "explosion-testing/falcon-test": FalconDecoder, diff --git a/curated_transformers/tests/models/test_hf_hub.py b/curated_transformers/tests/models/test_hf_hub.py index b5633d93..0a33067a 100644 --- a/curated_transformers/tests/models/test_hf_hub.py +++ b/curated_transformers/tests/models/test_hf_hub.py @@ -4,7 +4,7 @@ from huggingface_hub import _CACHED_NO_EXIST, try_to_load_from_cache from curated_transformers.models.bert.encoder import BERTEncoder -from curated_transformers.util.hf import get_model_checkpoint_filepaths +from curated_transformers.util.hf import get_model_checkpoint_files from curated_transformers.util.serde import ( ModelCheckpointType, _use_model_checkpoint_type, @@ -54,11 +54,11 @@ def test_checkpoint_type_without_safetensors(): # By default, we expect the torch checkpoint to be loaded # even if the safetensor checkpoints are present # (as long as the library is not installed). - ckp_paths, ckp_type = get_model_checkpoint_filepaths( + ckp_paths, ckp_type = get_model_checkpoint_files( "explosion-testing/safetensors-test", revision="main" ) assert len(ckp_paths) == 1 - assert Path(ckp_paths[0]).suffix == ".bin" + assert Path(ckp_paths[0].path).suffix == ".bin" assert ckp_type == ModelCheckpointType.PYTORCH_STATE_DICT with pytest.raises(ValueError, match="`safetensors` library is required"): @@ -70,11 +70,11 @@ def test_checkpoint_type_without_safetensors(): def test_checkpoint_type_with_safetensors(): # Since the safetensors library is installed, we should be # loading from those checkpoints. 
- ckp_paths, ckp_type = get_model_checkpoint_filepaths( + ckp_paths, ckp_type = get_model_checkpoint_files( "explosion-testing/safetensors-test", revision="main" ) assert len(ckp_paths) == 1 - assert Path(ckp_paths[0]).suffix == ".safetensors" + assert Path(ckp_paths[0].path).suffix == ".safetensors" assert ckp_type == ModelCheckpointType.SAFE_TENSORS encoder = BERTEncoder.from_hf_hub(name="explosion-testing/safetensors-test") @@ -83,21 +83,57 @@ def test_checkpoint_type_with_safetensors(): @pytest.mark.skipif(not has_safetensors, reason="requires huggingface safetensors") def test_forced_checkpoint_type(): with _use_model_checkpoint_type(ModelCheckpointType.PYTORCH_STATE_DICT): - ckp_paths, ckp_type = get_model_checkpoint_filepaths( + ckp_paths, ckp_type = get_model_checkpoint_files( "explosion-testing/safetensors-sharded-test", revision="main" ) assert len(ckp_paths) == 3 - assert all(Path(p).suffix == ".bin" for p in ckp_paths) + assert all(Path(p.path).suffix == ".bin" for p in ckp_paths) assert ckp_type == ModelCheckpointType.PYTORCH_STATE_DICT encoder = BERTEncoder.from_hf_hub(name="explosion-testing/safetensors-test") with _use_model_checkpoint_type(ModelCheckpointType.SAFE_TENSORS): - ckp_paths, ckp_type = get_model_checkpoint_filepaths( + ckp_paths, ckp_type = get_model_checkpoint_files( "explosion-testing/safetensors-sharded-test", revision="main" ) assert len(ckp_paths) == 3 - assert all(Path(p).suffix == ".safetensors" for p in ckp_paths) + assert all(Path(p.path).suffix == ".safetensors" for p in ckp_paths) assert ckp_type == ModelCheckpointType.SAFE_TENSORS encoder = BERTEncoder.from_hf_hub(name="explosion-testing/safetensors-test") + + +@pytest.mark.slow +@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers") +@pytest.mark.parametrize("torch_device", TORCH_DEVICES) +def test_fsspec(torch_device): + assert_encoder_output_equals_hf( + BERTEncoder, + "explosion-testing/bert-test", + torch_device, + with_fsspec=True, + ) + + +@pytest.mark.slow +@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers") +@pytest.mark.parametrize("torch_device", TORCH_DEVICES) +def test_fsspec_sharded(torch_device): + assert_encoder_output_equals_hf( + BERTEncoder, + "explosion-testing/bert-test-sharded", + torch_device, + with_fsspec=True, + ) + + +@pytest.mark.slow +@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers") +@pytest.mark.parametrize("torch_device", TORCH_DEVICES) +def test_fsspec_safetensors(torch_device): + assert_encoder_output_equals_hf( + BERTEncoder, + "explosion-testing/safetensors-test", + torch_device, + with_fsspec=True, + ) diff --git a/curated_transformers/tests/models/util.py b/curated_transformers/tests/models/util.py index 0230001a..aaf4adbd 100644 --- a/curated_transformers/tests/models/util.py +++ b/curated_transformers/tests/models/util.py @@ -2,6 +2,7 @@ from typing import Callable, Dict, List, Tuple, Type, Union import torch +from huggingface_hub import HfFileSystem from torch import Tensor from torch.nn import Module @@ -198,12 +199,18 @@ def assert_encoder_output_equals_hf( model_name: str, torch_device: torch.device, *, - atol=1e-5, - rtol=1e-5, + atol: float = 1e-5, + rtol: float = 1e-5, jit_method: JITMethod = JITMethod.Disable, - with_torch_sdp=False, + with_fsspec: bool = False, + with_torch_sdp: bool = False, ): - orig_model = model_class.from_hf_hub(name=model_name, device=torch_device) + if with_fsspec: + orig_model = model_class.from_fsspec( + 
fs=HfFileSystem(), model_path=model_name, device=torch_device + ) + else: + orig_model = model_class.from_hf_hub(name=model_name, device=torch_device) orig_model.eval() for _, param in orig_model.state_dict().items(): diff --git a/curated_transformers/tests/tokenizers/legacy/test_bert_tokenizer.py b/curated_transformers/tests/tokenizers/legacy/test_bert_tokenizer.py index 1ebe6656..f8418a5e 100644 --- a/curated_transformers/tests/tokenizers/legacy/test_bert_tokenizer.py +++ b/curated_transformers/tests/tokenizers/legacy/test_bert_tokenizer.py @@ -21,7 +21,7 @@ @pytest.fixture def toy_tokenizer_from_files(test_dir): return BERTTokenizer.from_files( - vocab_path=test_dir / "toy.wordpieces", + vocab_file=test_dir / "toy.wordpieces", ) diff --git a/curated_transformers/tests/tokenizers/legacy/test_camembert_tokenizer.py b/curated_transformers/tests/tokenizers/legacy/test_camembert_tokenizer.py index b897bde2..4181a899 100644 --- a/curated_transformers/tests/tokenizers/legacy/test_camembert_tokenizer.py +++ b/curated_transformers/tests/tokenizers/legacy/test_camembert_tokenizer.py @@ -5,6 +5,7 @@ from curated_transformers.tokenizers.legacy.camembert_tokenizer import ( CamemBERTTokenizer, ) +from curated_transformers.util.serde import LocalModelFile from ...compat import has_hf_transformers from ...utils import torch_assertclose @@ -14,7 +15,7 @@ @pytest.fixture def toy_tokenizer(test_dir): return CamemBERTTokenizer.from_files( - model_path=test_dir / "toy.model", + model_file=LocalModelFile(path=test_dir / "toy.model"), ) diff --git a/curated_transformers/tests/tokenizers/legacy/test_llama_tokenizer.py b/curated_transformers/tests/tokenizers/legacy/test_llama_tokenizer.py index 0d73a869..cd9ef27a 100644 --- a/curated_transformers/tests/tokenizers/legacy/test_llama_tokenizer.py +++ b/curated_transformers/tests/tokenizers/legacy/test_llama_tokenizer.py @@ -15,6 +15,6 @@ def test_from_hf_hub_equals_hf_tokenizer(sample_texts): sample_texts, "openlm-research/open_llama_3b", LlamaTokenizer, - hf_use_fast=False, + with_hf_fast=False, pad_token="", ) diff --git a/curated_transformers/tests/tokenizers/legacy/test_roberta_tokenizer.py b/curated_transformers/tests/tokenizers/legacy/test_roberta_tokenizer.py index e7f8573d..da19c64e 100644 --- a/curated_transformers/tests/tokenizers/legacy/test_roberta_tokenizer.py +++ b/curated_transformers/tests/tokenizers/legacy/test_roberta_tokenizer.py @@ -12,7 +12,7 @@ @pytest.fixture def toy_tokenizer_from_files(test_dir): return RoBERTaTokenizer.from_files( - vocab_path=test_dir / "toy-vocab.json", merges_path=test_dir / "toy-merges.txt" + vocab_file=test_dir / "toy-vocab.json", merges_file=test_dir / "toy-merges.txt" ) diff --git a/curated_transformers/tests/tokenizers/legacy/test_xlmr_tokenizer.py b/curated_transformers/tests/tokenizers/legacy/test_xlmr_tokenizer.py index a3034e49..6a3de92e 100644 --- a/curated_transformers/tests/tokenizers/legacy/test_xlmr_tokenizer.py +++ b/curated_transformers/tests/tokenizers/legacy/test_xlmr_tokenizer.py @@ -3,6 +3,7 @@ from curated_transformers.tokenizers import PiecesWithIds from curated_transformers.tokenizers.legacy.xlmr_tokenizer import XLMRTokenizer +from curated_transformers.util.serde import LocalModelFile from ...compat import has_hf_transformers from ...utils import torch_assertclose @@ -12,7 +13,7 @@ @pytest.fixture def toy_tokenizer(test_dir): return XLMRTokenizer.from_files( - model_path=test_dir / "toy.model", + model_file=LocalModelFile(path=test_dir / "toy.model"), ) diff --git 
a/curated_transformers/tests/tokenizers/test_auto_tokenizer.py b/curated_transformers/tests/tokenizers/test_auto_tokenizer.py index bc708463..77a70152 100644 --- a/curated_transformers/tests/tokenizers/test_auto_tokenizer.py +++ b/curated_transformers/tests/tokenizers/test_auto_tokenizer.py @@ -1,4 +1,5 @@ import pytest +from huggingface_hub import HfFileSystem from curated_transformers.tokenizers import AutoTokenizer @@ -18,6 +19,16 @@ def test_auto_tokenizer(model_revision): AutoTokenizer.from_hf_hub(name=name, revision=revision) +@pytest.mark.slow +@pytest.mark.parametrize("model_revision", _MODELS) +def test_auto_tokenizer_fsspec(model_revision): + name, revision = model_revision + AutoTokenizer.from_fsspec( + fs=HfFileSystem(), model_path=name, fsspec_args={"revision": revision} + ) + AutoTokenizer.from_hf_hub(name=name, revision=revision) + + def test_cannot_infer(): # This repo/revision does not have a tokenizer and doesn't match a # legacy tokenizer. diff --git a/curated_transformers/tests/tokenizers/test_hf_hub.py b/curated_transformers/tests/tokenizers/test_hf_hub.py index e1fcf8c1..183145e4 100644 --- a/curated_transformers/tests/tokenizers/test_hf_hub.py +++ b/curated_transformers/tests/tokenizers/test_hf_hub.py @@ -1,8 +1,12 @@ +import pytest from huggingface_hub import _CACHED_NO_EXIST, try_to_load_from_cache from curated_transformers.tokenizers import Tokenizer from curated_transformers.tokenizers.legacy import BERTTokenizer +from ..compat import has_hf_transformers +from .util import compare_tokenizer_outputs_with_hf_tokenizer + def test_from_hf_hub_to_cache(): Tokenizer.from_hf_hub_to_cache( @@ -45,3 +49,16 @@ def test_from_hf_hub_to_cache_legacy(): ) != _CACHED_NO_EXIST ) + + +@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers") +def test_fsspec(sample_texts): + # We only test one model, since using fsspec downloads the model + # each time. 
+ compare_tokenizer_outputs_with_hf_tokenizer( + sample_texts, + "EleutherAI/gpt-neox-20b", + Tokenizer, + pad_token="", + with_fsspec=True, + ) diff --git a/curated_transformers/tests/tokenizers/util.py b/curated_transformers/tests/tokenizers/util.py index e6f9d630..91ea2104 100644 --- a/curated_transformers/tests/tokenizers/util.py +++ b/curated_transformers/tests/tokenizers/util.py @@ -1,5 +1,7 @@ from typing import Optional +from huggingface_hub import HfFileSystem + from ..compat import transformers from ..utils import torch_assertclose @@ -9,14 +11,18 @@ def compare_tokenizer_outputs_with_hf_tokenizer( hf_name, tokenizer_cls, pad_token: Optional[str] = None, - hf_use_fast: bool = True, + with_hf_fast: bool = True, + with_fsspec: bool = False, revision: str = "main", ): - tokenizer = tokenizer_cls.from_hf_hub(name=hf_name, revision=revision) + if with_fsspec: + tokenizer = tokenizer_cls.from_fsspec(fs=HfFileSystem(), model_path=hf_name) + else: + tokenizer = tokenizer_cls.from_hf_hub(name=hf_name, revision=revision) pieces = tokenizer(sample_texts) hf_tokenizer = transformers.AutoTokenizer.from_pretrained( - hf_name, revision=revision, use_fast=hf_use_fast + hf_name, revision=revision, use_fast=with_hf_fast ) hf_tokenizer.padding_side = "right" if pad_token is not None: diff --git a/curated_transformers/tokenizers/auto_tokenizer.py b/curated_transformers/tokenizers/auto_tokenizer.py index e4496055..b8361996 100644 --- a/curated_transformers/tokenizers/auto_tokenizer.py +++ b/curated_transformers/tokenizers/auto_tokenizer.py @@ -1,8 +1,11 @@ -from typing import Dict, Optional, Type, cast +from typing import Any, Dict, Optional, Type, cast +from fsspec import AbstractFileSystem from huggingface_hub.utils import EntryNotFoundError -from ..util.hf import TOKENIZER_JSON, get_file_metadata, get_hf_config_model_type +from ..util.fsspec import get_config_model_type as get_model_type_fsspec +from ..util.fsspec import get_tokenizer_config as get_tokenizer_config_fsspec +from ..util.hf import TOKENIZER_JSON, get_config_model_type, get_file_metadata from .hf_hub import FromHFHub, get_tokenizer_config from .legacy.bert_tokenizer import BERTTokenizer from .legacy.camembert_tokenizer import CamemBERTTokenizer @@ -59,9 +62,41 @@ def from_hf_hub_to_cache( :param revision: Model revision. """ - tokenizer_cls = _resolve_tokenizer_class(name, revision) + tokenizer_cls = _resolve_tokenizer_class_hf_hub(name, revision) tokenizer_cls.from_hf_hub_to_cache(name=name, revision=revision) + @classmethod + def from_fsspec( + cls, + *, + fs: AbstractFileSystem, + model_path: str, + fsspec_args: Optional[Dict[str, Any]] = None, + ) -> TokenizerBase: + """ + Construct a tokenizer and load its parameters from an fsspec filesystem. + + :param fs: + Filesystem. + :param model_path: + The model path. + :param fsspec_args: + Implementation-specific keyword arguments to pass to fsspec + filesystem operations. + :returns: + The tokenizer. + """ + tokenizer_cls = _resolve_tokenizer_class_fsspec( + fs=fs, model_path=model_path, fsspec_args=fsspec_args + ) + # This cast is safe, because we only return tokenizers. + return cast( + TokenizerBase, + tokenizer_cls.from_fsspec( + fs=fs, model_path=model_path, fsspec_args=fsspec_args + ), + ) + @classmethod def from_hf_hub(cls, *, name: str, revision: str = "main") -> TokenizerBase: """ @@ -75,7 +110,7 @@ def from_hf_hub(cls, *, name: str, revision: str = "main") -> TokenizerBase: The tokenizer. 
""" - tokenizer_cls = _resolve_tokenizer_class(name, revision) + tokenizer_cls = _resolve_tokenizer_class_hf_hub(name, revision) # This cast is safe, because we only return tokenizers. return cast( TokenizerBase, tokenizer_cls.from_hf_hub(name=name, revision=revision) @@ -83,28 +118,52 @@ def from_hf_hub(cls, *, name: str, revision: str = "main") -> TokenizerBase: def _get_tokenizer_class_from_config( - *, name: str, revision: str + tokenizer_config: Dict[str, Any] ) -> Optional[Type[FromHFHub]]: """ Infer the tokenizer class from the tokenizer configuration. - :param name: - Model name. + :param tokenizer_config: + The tokenizer configuration. :param revision: Model revision. :returns: Inferred class. """ + return HF_TOKENIZER_MAPPING.get(tokenizer_config.get("tokenizer_class", None), None) - try: - tokenizer_config = get_tokenizer_config(name=name, revision=revision) - except EntryNotFoundError: - return None - return HF_TOKENIZER_MAPPING.get(tokenizer_config.get("tokenizer_class", None), None) +def _resolve_tokenizer_class_fsspec( + fs: AbstractFileSystem, + model_path: str, + fsspec_args: Optional[Dict[str, Any]] = None, +) -> Type[FromHFHub]: + fsspec_args = {} if fsspec_args is None else fsspec_args + tokenizer_cls: Optional[Type[FromHFHub]] = None + if fs.exists(f"{model_path}/{TOKENIZER_JSON}", **fsspec_args): + return Tokenizer + + if tokenizer_cls is None: + tokenizer_config = get_tokenizer_config_fsspec( + fs=fs, model_path=model_path, fsspec_args=fsspec_args + ) + if tokenizer_config is not None: + tokenizer_cls = _get_tokenizer_class_from_config(tokenizer_config) + + if tokenizer_cls is None: + model_type = get_model_type_fsspec( + fs=fs, model_path=model_path, fsspec_args=fsspec_args + ) + if model_type is not None: + tokenizer_cls = HF_MODEL_MAPPING.get(model_type, None) + + if tokenizer_cls is None: + raise ValueError(f"Cannot infer tokenizer for model at path: {model_path}") + + return tokenizer_cls -def _resolve_tokenizer_class(name: str, revision: str) -> Type[FromHFHub]: +def _resolve_tokenizer_class_hf_hub(name: str, revision: str) -> Type[FromHFHub]: tokenizer_cls: Optional[Type[FromHFHub]] = None try: # We will try to fetch metadata to avoid potentially downloading @@ -116,11 +175,16 @@ def _resolve_tokenizer_class(name: str, revision: str) -> Type[FromHFHub]: tokenizer_cls = Tokenizer if tokenizer_cls is None: - tokenizer_cls = _get_tokenizer_class_from_config(name=name, revision=revision) + try: + tokenizer_config = get_tokenizer_config(name=name, revision=revision) + except EntryNotFoundError: + pass + else: + tokenizer_cls = _get_tokenizer_class_from_config(tokenizer_config) if tokenizer_cls is None: try: - model_type = get_hf_config_model_type(name=name, revision=revision) + model_type = get_config_model_type(name=name, revision=revision) except EntryNotFoundError: pass else: diff --git a/curated_transformers/tokenizers/hf_hub.py b/curated_transformers/tokenizers/hf_hub.py index a628062b..2e719792 100644 --- a/curated_transformers/tokenizers/hf_hub.py +++ b/curated_transformers/tokenizers/hf_hub.py @@ -1,10 +1,12 @@ from abc import ABC, abstractmethod -from pathlib import Path -from typing import Any, Dict, List, Optional, Type, TypeVar +from typing import Any, Dict, Mapping, Optional, Type, TypeVar +from fsspec import AbstractFileSystem from huggingface_hub.utils import EntryNotFoundError +from ..util.fsspec import get_tokenizer_config as get_tokenizer_config_fsspec from ..util.hf import get_tokenizer_config, hf_hub_download +from ..util.serde import 
FsspecModelFile, LocalModelFile, ModelFile SelfFromHFHub = TypeVar("SelfFromHFHub", bound="FromHFHub") @@ -38,6 +40,30 @@ def from_hf_hub_to_cache( """ raise NotImplementedError + @classmethod + @abstractmethod + def from_fsspec( + cls: Type[SelfFromHFHub], + *, + fs: AbstractFileSystem, + model_path: str, + fsspec_args: Optional[Dict[str, Any]] = None, + ) -> SelfFromHFHub: + """ + Construct a tokenizer and load its parameters from an fsspec filesystem. + + :param fs: + Filesystem. + :param model_path: + The model path. + :param fsspec_args: + Implementation-specific keyword arguments to pass to fsspec + filesystem operations. + :returns: + The tokenizer. + """ + raise NotImplementedError + @classmethod @abstractmethod def from_hf_hub( @@ -77,7 +103,7 @@ class LegacyFromHFHub(FromHFHub): def _load_from_vocab_files( cls: Type[SelfLegacyFromHFHub], *, - vocab_files: Dict[str, Path], + vocab_files: Mapping[str, ModelFile], tokenizer_config: Optional[Dict[str, Any]], ) -> SelfLegacyFromHFHub: """ @@ -108,13 +134,35 @@ def from_hf_hub_to_cache( except EntryNotFoundError: pass + @classmethod + def from_fsspec( + cls: Type[SelfLegacyFromHFHub], + *, + fs: AbstractFileSystem, + model_path: str, + fsspec_args: Optional[Dict[str, Any]] = None, + ) -> SelfLegacyFromHFHub: + vocab_files = {} + for vocab_file, filename in cls.vocab_files.items(): + vocab_files[vocab_file] = FsspecModelFile( + fs, f"{model_path}/{filename}", fsspec_args + ) + + tokenizer_config = get_tokenizer_config_fsspec( + fs=fs, model_path=model_path, fsspec_args=fsspec_args + ) + + return cls._load_from_vocab_files( + vocab_files=vocab_files, tokenizer_config=tokenizer_config + ) + @classmethod def from_hf_hub( cls: Type[SelfLegacyFromHFHub], *, name: str, revision: str = "main" ) -> SelfLegacyFromHFHub: vocab_files = {} for vocab_file, filename in cls.vocab_files.items(): - vocab_files[vocab_file] = Path( + vocab_files[vocab_file] = LocalModelFile( hf_hub_download(repo_id=name, filename=filename, revision=revision) ) diff --git a/curated_transformers/tokenizers/legacy/bert_tokenizer.py b/curated_transformers/tokenizers/legacy/bert_tokenizer.py index 6649a4b2..86c7accb 100644 --- a/curated_transformers/tokenizers/legacy/bert_tokenizer.py +++ b/curated_transformers/tokenizers/legacy/bert_tokenizer.py @@ -1,9 +1,9 @@ import unicodedata -from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar +from typing import Any, Dict, Iterable, List, Mapping, Optional, Type, TypeVar from curated_tokenizers import WordPieceProcessor +from ...util.serde import ModelFile from .._hf_compat import clean_up_decoded_string_like_hf, tokenize_chinese_chars_bert from ..chunks import ( InputChunks, @@ -260,7 +260,7 @@ def __init__( def from_files( cls: Type[Self], *, - vocab_path: Path, + vocab_file: ModelFile, bos_piece: str = "[CLS]", eos_piece: str = "[SEP]", unk_piece: str = "[UNK]", @@ -270,8 +270,8 @@ def from_files( """ Construct a tokenizer from the vocabulary file. - :param vocab_path: - Path to the vocabulary file. + :param vocab_file: + The vocabulary file. :param bos_piece: The piece to use to mark the beginning of a sequence. :param eos_piece: @@ -284,7 +284,7 @@ def from_files( Strip accents from text. 
""" vocab: Dict[str, int] = {} - with open(vocab_path, encoding="utf8") as f: + with vocab_file.open(mode="r", encoding="utf8") as f: for line in f: vocab[line.strip()] = len(vocab) @@ -311,7 +311,7 @@ def eos_piece(self) -> Optional[str]: def _load_from_vocab_files( cls: Type[Self], *, - vocab_files: Dict[str, Path], + vocab_files: Mapping[str, ModelFile], tokenizer_config: Optional[Dict[str, Any]], ) -> Self: extra_kwargs = {} @@ -329,7 +329,7 @@ def _load_from_vocab_files( strip_accents is not False and lowercase ) - return cls.from_files(vocab_path=vocab_files["vocab"], **extra_kwargs) + return cls.from_files(vocab_file=vocab_files["vocab"], **extra_kwargs) def _encode(self, input: Iterable[MergedInputChunks]) -> PiecesWithIds: ids = [] diff --git a/curated_transformers/tokenizers/legacy/camembert_tokenizer.py b/curated_transformers/tokenizers/legacy/camembert_tokenizer.py index 1b0594fa..d76f0033 100644 --- a/curated_transformers/tokenizers/legacy/camembert_tokenizer.py +++ b/curated_transformers/tokenizers/legacy/camembert_tokenizer.py @@ -1,8 +1,8 @@ -from pathlib import Path -from typing import Any, Dict, Optional, Type, TypeVar +from typing import Any, Dict, Mapping, Optional, Type, TypeVar from curated_tokenizers import SentencePieceProcessor +from ...util.serde import ModelFile from ..hf_hub import LegacyFromHFHub from ._fairseq import FAIRSEQ_PIECE_IDS, FairSeqPostEncoder, FairSeqPreDecoder from .legacy_tokenizer import AddBosEosPreEncoder @@ -111,21 +111,22 @@ def __init__( def from_files( cls: Type[Self], *, - model_path: Path, + model_file: ModelFile, bos_piece: str = "", eos_piece: str = "", ) -> Self: """ Construct a tokenizer from vocabulary and merge files. - :param model_path: - Path to the SentencePiece model file. + :param model_file: + The SentencePiece model file. :param bos_piece: The piece to use to mark the beginning of a sequence. :param eos_piece: The piece to use to mark the end of a sequence. """ - processor = SentencePieceProcessor.from_file(str(model_path)) + with model_file.open() as f: + processor = SentencePieceProcessor.from_file(f) return cls( processor=processor, bos_piece=bos_piece, @@ -136,10 +137,10 @@ def from_files( def _load_from_vocab_files( cls: Type[Self], *, - vocab_files: Dict[str, Path], + vocab_files: Mapping[str, ModelFile], tokenizer_config: Optional[Dict[str, Any]], ) -> Self: - return cls.from_files(model_path=vocab_files["model"]) + return cls.from_files(model_file=vocab_files["model"]) def _get_piece_id_or_fail(processor: SentencePieceProcessor, piece: str) -> int: diff --git a/curated_transformers/tokenizers/legacy/llama_tokenizer.py b/curated_transformers/tokenizers/legacy/llama_tokenizer.py index 4650a0bd..d8bffe37 100644 --- a/curated_transformers/tokenizers/legacy/llama_tokenizer.py +++ b/curated_transformers/tokenizers/legacy/llama_tokenizer.py @@ -1,8 +1,8 @@ -from pathlib import Path -from typing import Any, Dict, Optional, Type, TypeVar +from typing import Any, Dict, Mapping, Optional, Type, TypeVar from curated_tokenizers import SentencePieceProcessor +from ...util.serde import ModelFile from ..hf_hub import LegacyFromHFHub from .legacy_tokenizer import AddBosEosPreEncoder from .sentencepiece_tokenizer import SentencePieceTokenizer @@ -53,21 +53,22 @@ def __init__( def from_files( cls: Type[Self], *, - model_path: Path, + model_file: ModelFile, add_bos_piece: bool = True, add_eos_piece: bool = False, ) -> Self: """ Construct a Llama tokenizer from a SentencePiece model. 
- :param model_path: - Path to the SentencePiece model file. + :param model_file: + The SentencePiece model file. :param add_bos_piece: Add a begin-of-sequence piece. :param add_eos_piece: Add an end-of-sequence piece. """ - processor = SentencePieceProcessor.from_file(str(model_path)) + with model_file.open() as f: + processor = SentencePieceProcessor.from_file(f) return cls( processor=processor, add_bos_piece=add_bos_piece, @@ -78,17 +79,17 @@ def from_files( def _load_from_vocab_files( cls: Type[Self], *, - vocab_files: Dict[str, Path], + vocab_files: Mapping[str, ModelFile], tokenizer_config: Optional[Dict[str, Any]], ) -> Self: if tokenizer_config is None: - return cls.from_files(model_path=vocab_files["model"]) + return cls.from_files(model_file=vocab_files["model"]) add_bos_piece = tokenizer_config.get("add_bos_token", True) add_eos_piece = tokenizer_config.get("add_eos_token", False) return cls.from_files( - model_path=vocab_files["model"], + model_file=vocab_files["model"], add_bos_piece=add_bos_piece, add_eos_piece=add_eos_piece, ) diff --git a/curated_transformers/tokenizers/legacy/roberta_tokenizer.py b/curated_transformers/tokenizers/legacy/roberta_tokenizer.py index d92fbb74..a3213dad 100644 --- a/curated_transformers/tokenizers/legacy/roberta_tokenizer.py +++ b/curated_transformers/tokenizers/legacy/roberta_tokenizer.py @@ -1,8 +1,8 @@ -from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar +from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Type, TypeVar from curated_tokenizers import ByteBPEProcessor +from ...util.serde import ModelFile from ..hf_hub import LegacyFromHFHub from ..util import remove_pieces_from_sequence from .bbpe_tokenizer import ByteBPETokenizer @@ -87,26 +87,26 @@ def __init__( def from_files( cls: Type[Self], *, - vocab_path: Path, - merges_path: Path, + vocab_file: ModelFile, + merges_file: ModelFile, bos_piece: str = "", eos_piece: str = "", ) -> Self: """ Construct a tokenizer from vocabulary and merge files. - :param vocab_path: - Path to the vocabulary file. - :param merges_path: - Path to the merges file. + :param vocab_file: + The vocabulary file. + :param merges_file: + The merges file. :param bos_piece: The piece to use to mark the beginning of a sequence. :param eos_piece: The piece to use to mark the end of a sequence. """ - processor = ByteBPEProcessor.load_from_files( - vocab=vocab_path, merges=merges_path - ) + with vocab_file.open(mode="r", encoding="utf-8") as vocab: + with merges_file.open(mode="r", encoding="utf-8") as merges: + processor = ByteBPEProcessor.load_from_files(vocab=vocab, merges=merges) return cls( # This is a bit annoying, but we want to avoid these extremely # overloaded constructors. 
@@ -124,11 +124,11 @@ def eos_piece(self) -> Optional[str]: def _load_from_vocab_files( cls: Type[Self], *, - vocab_files: Dict[str, Path], + vocab_files: Mapping[str, ModelFile], tokenizer_config: Optional[Dict[str, Any]], ) -> Self: return cls.from_files( - vocab_path=vocab_files["vocab"], merges_path=vocab_files["merges"] + vocab_file=vocab_files["vocab"], merges_file=vocab_files["merges"] ) diff --git a/curated_transformers/tokenizers/legacy/xlmr_tokenizer.py b/curated_transformers/tokenizers/legacy/xlmr_tokenizer.py index 9289e610..75a11091 100644 --- a/curated_transformers/tokenizers/legacy/xlmr_tokenizer.py +++ b/curated_transformers/tokenizers/legacy/xlmr_tokenizer.py @@ -1,8 +1,8 @@ -from pathlib import Path -from typing import Any, Dict, Optional, Type, TypeVar +from typing import Any, Dict, Mapping, Optional, Type, TypeVar from curated_tokenizers import SentencePieceProcessor +from ...util.serde import ModelFile from ..hf_hub import LegacyFromHFHub from ._fairseq import FAIRSEQ_PIECE_IDS, FairSeqPostEncoder, FairSeqPreDecoder from .legacy_tokenizer import AddBosEosPreEncoder @@ -112,22 +112,23 @@ def __init__( def from_files( cls: Type[Self], *, - model_path: Path, + model_file: ModelFile, ) -> Self: """ Construct a XLM-R tokenizer from a SentencePiece model. - :param model_path: - Path to the SentencePiece model file. + :param model_file: + The SentencePiece model file. """ - processor = SentencePieceProcessor.from_file(str(model_path)) + with model_file.open() as f: + processor = SentencePieceProcessor.from_file(f) return cls(processor=processor) @classmethod def _load_from_vocab_files( cls: Type[Self], *, - vocab_files: Dict[str, Path], + vocab_files: Mapping[str, ModelFile], tokenizer_config: Optional[Dict[str, Any]], ) -> Self: - return cls.from_files(model_path=vocab_files["model"]) + return cls.from_files(model_file=vocab_files["model"]) diff --git a/curated_transformers/tokenizers/tokenizer.py b/curated_transformers/tokenizers/tokenizer.py index 85dc6395..a69e46e5 100644 --- a/curated_transformers/tokenizers/tokenizer.py +++ b/curated_transformers/tokenizers/tokenizer.py @@ -5,11 +5,14 @@ from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar, Union, cast import torch +from fsspec import AbstractFileSystem from huggingface_hub.utils import EntryNotFoundError from tokenizers import Tokenizer as HFTokenizer from torch import Tensor from ..layers.attention import AttentionMask +from ..util.fsspec import get_special_tokens_map as get_special_tokens_map_fsspec +from ..util.fsspec import get_tokenizer_config as get_tokenizer_config_fsspec from ..util.hf import ( HF_TOKENIZER_CONFIG, SPECIAL_TOKENS_MAP, @@ -317,7 +320,7 @@ def from_dir(cls: Type[Self], path: Path) -> Self: Load the tokenizer from a directory with a ``tokenizer.json`` file. :param path: - Path to the tokenizer file. + Path to the tokenizer directory. 
""" tokenizer_path = path / TOKENIZER_JSON config_path = path / HF_TOKENIZER_CONFIG @@ -355,6 +358,30 @@ def from_hf_hub_to_cache( except EntryNotFoundError: pass + @classmethod + def from_fsspec( + cls: Type[Self], + *, + fs: AbstractFileSystem, + model_path: str, + fsspec_args: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> Self: + tokenizer_path = f"{model_path}/tokenizer.json" + if not fs.exists(tokenizer_path, **kwargs): + raise ValueError(f"Path cannot be found: {tokenizer_path}") + with fs.open(tokenizer_path) as f: + hf_tokenizer = HFTokenizer.from_buffer(f.read()) + + config = get_tokenizer_config_fsspec(fs, model_path, fsspec_args) + special_tokens_map = get_special_tokens_map_fsspec(fs, model_path, fsspec_args) + + return cls( + tokenizer=hf_tokenizer, + config=config, + special_tokens_map=special_tokens_map, + ) + @classmethod def from_hf_hub(cls: Type[Self], *, name: str, revision: str = "main") -> Self: # We cannot directly use `HFTokenizer.from_pretrained`` to instantiate the HF diff --git a/curated_transformers/util/fsspec.py b/curated_transformers/util/fsspec.py new file mode 100644 index 00000000..3a27c79a --- /dev/null +++ b/curated_transformers/util/fsspec.py @@ -0,0 +1,291 @@ +import json +import os +from typing import Any, Dict, List, Optional, Tuple + +from fsspec import AbstractFileSystem + +from .._compat import has_safetensors +from .hf import ( + HF_MODEL_CONFIG, + HF_TOKENIZER_CONFIG, + PRIMARY_CHECKPOINT_FILENAMES, + SHARDED_CHECKPOINT_INDEX_FILENAMES, + SHARDED_CHECKPOINT_INDEX_WEIGHTS_KEY, + SPECIAL_TOKENS_MAP, +) +from .serde import ( + _MODEL_CHECKPOINT_TYPE, + FsspecModelFile, + ModelCheckpointType, + ModelFile, +) + + +def get_file_metadata( + *, + fs: AbstractFileSystem, + model_path: str, + filename: str, + fsspec_args: Optional[Dict[str, Any]] = None, +) -> Optional[Dict[str, Any]]: + """ + Get a file from a model on an fsspec filesystem. + + :param fs: + The filesystem on which the model is stored. + :param model_path: + The path of the model on the filesystem. + :param filename: + The file to get metadata for. + :param fsspec_args: + Implementation-specific keyword arguments to pass to fsspec + filesystem operations. + :returns: + File metadata as a dictionary or ``None`` if the file does not + exist. + + """ + index = get_path_index(fs, model_path, fsspec_args=fsspec_args) + return index.get(filename) + + +def get_model_config( + fs: AbstractFileSystem, + model_path: str, + fsspec_args: Optional[Dict[str, Any]] = None, +) -> Dict[str, Any]: + """ + Get the configuation of a model on an fsspec filesystem. + + :param fs: + The filesystem on which the model is stored. + :param model_path: + The path of the model on the filesystem. + :param fsspec_args: + Implementation-specific keyword arguments to pass to fsspec + filesystem operations. + :returns: + The model configuration. + """ + config = _get_and_parse_json_file( + fs, + path=f"{model_path}/{HF_MODEL_CONFIG}", + fsspec_args=fsspec_args, + ) + if config is None: + raise ValueError( + f"Cannot open model config path: {model_path}/{HF_MODEL_CONFIG}" + ) + return config + + +def get_config_model_type( + fs: AbstractFileSystem, + model_path: str, + fsspec_args: Optional[Dict[str, Any]] = None, +) -> Optional[str]: + """ + Get the type of a model on an fsspec filesystem. + + :param fs: + The filesystem on which the model is stored. + :param model_path: + The path of the model on the filesystem. 
+    :param fsspec_args:
+        Implementation-specific keyword arguments to pass to fsspec
+        filesystem operations.
+    :returns:
+        The model type.
+    """
+    config = get_model_config(fs, model_path, fsspec_args=fsspec_args)
+    return config.get("model_type")
+
+
+def get_tokenizer_config(
+    fs: AbstractFileSystem,
+    model_path: str,
+    fsspec_args: Optional[Dict[str, Any]] = None,
+) -> Optional[Dict[str, Any]]:
+    """
+    Get the configuration of a tokenizer on an fsspec filesystem.
+
+    :param fs:
+        The filesystem on which the model is stored.
+    :param model_path:
+        The path of the model on the filesystem.
+    :param fsspec_args:
+        Implementation-specific keyword arguments to pass to fsspec
+        filesystem operations.
+    :returns:
+        Deserialized tokenizer configuration.
+    """
+    return _get_and_parse_json_file(
+        fs,
+        path=f"{model_path}/{HF_TOKENIZER_CONFIG}",
+        fsspec_args=fsspec_args,
+    )
+
+
+def get_special_tokens_map(
+    fs: AbstractFileSystem,
+    model_path: str,
+    fsspec_args: Optional[Dict[str, Any]] = None,
+) -> Optional[Dict[str, Any]]:
+    """
+    Get the special token mapping of a tokenizer on an fsspec filesystem.
+
+    :param fs:
+        The filesystem on which the model is stored.
+    :param model_path:
+        The path of the model on the filesystem.
+    :param fsspec_args:
+        Implementation-specific keyword arguments to pass to fsspec
+        filesystem operations.
+    :returns:
+        Deserialized special tokens map.
+    """
+    return _get_and_parse_json_file(
+        fs, path=f"{model_path}/{SPECIAL_TOKENS_MAP}", fsspec_args=fsspec_args
+    )
+
+
+def _get_and_parse_json_file(
+    fs: AbstractFileSystem,
+    *,
+    path: str,
+    fsspec_args: Optional[Dict[str, Any]] = None,
+) -> Optional[Dict[str, Any]]:
+    """
+    Get a JSON file from an fsspec filesystem and parse it.
+
+    :param fs:
+        The filesystem on which the model is stored.
+    :param path:
+        The path of the JSON file.
+    :param fsspec_args:
+        Implementation-specific keyword arguments to pass to fsspec
+        filesystem operations.
+    :returns:
+        The deserialized JSON, or ``None`` if the file
+        does not exist.
+    """
+    fsspec_args = {} if fsspec_args is None else fsspec_args
+
+    if not fs.exists(path, **fsspec_args):
+        return None
+
+    with fs.open(path, "r", encoding="utf-8", **fsspec_args) as f:
+        config = json.load(f)
+    return config
+
+
+def get_model_checkpoint_files(
+    fs: AbstractFileSystem,
+    model_path: str,
+    fsspec_args: Optional[Dict[str, Any]] = None,
+) -> Tuple[List[ModelFile], ModelCheckpointType]:
+    """
+    Return a list of the checkpoint files that belong to the model on an
+    fsspec filesystem. In case of non-sharded models, a single file is
+    returned. In case of sharded models, multiple files are returned.
+
+    :param fs:
+        The filesystem on which the model is stored.
+    :param model_path:
+        The path of the model on the filesystem.
+    :param fsspec_args:
+        Implementation-specific keyword arguments to pass to fsspec
+        filesystem operations.
+    :returns:
+        List of the model checkpoint files
+        and the checkpoint type.
+    """
+    fsspec_args = {} if fsspec_args is None else fsspec_args
+
+    def get_checkpoint_paths(
+        checkpoint_type: ModelCheckpointType,
+    ) -> List[ModelFile]:
+        index = get_path_index(fs, model_path, fsspec_args=fsspec_args)
+
+        # Attempt to load a non-sharded checkpoint first.
+        entry = index.get(PRIMARY_CHECKPOINT_FILENAMES[checkpoint_type])
+        if entry is not None:
+            return [FsspecModelFile(fs, entry["name"], fsspec_args)]
+
+        # Try sharded checkpoint.
+        index_filename = SHARDED_CHECKPOINT_INDEX_FILENAMES[checkpoint_type]
+        entry = index.get(index_filename)
+        if entry is None:
+            raise ValueError(
+                f"Couldn't find a valid {checkpoint_type.pretty_name} checkpoint for "
+                f"model with path `{model_path}`. Could not open {index_filename}"
+            )
+
+        with fs.open(entry["name"], "rb", **fsspec_args) as f:
+            index = json.load(f)
+
+        weight_map = index.get(SHARDED_CHECKPOINT_INDEX_WEIGHTS_KEY)
+        if not isinstance(weight_map, dict):
+            raise ValueError(
+                f"Invalid index file in sharded {checkpoint_type.pretty_name} "
+                f"checkpoint for model with path `{model_path}`"
+            )
+
+        filepaths = []
+        # We shouldn't need to hold on to the weights map in the index as each checkpoint
+        # should contain its constituent parameter names.
+        for filename in set(weight_map.values()):
+            filepaths.append(f"{model_path}/{filename}")
+
+        return [FsspecModelFile(fs, path, fsspec_args) for path in sorted(filepaths)]
+
+    checkpoint_type = _MODEL_CHECKPOINT_TYPE.get()
+    checkpoint_paths: Optional[List[ModelFile]] = None
+
+    if checkpoint_type is None:
+        # Precedence: Safetensors > PyTorch
+        if has_safetensors:
+            try:
+                checkpoint_type = ModelCheckpointType.SAFE_TENSORS
+                checkpoint_paths = get_checkpoint_paths(checkpoint_type)
+            except ValueError:
+                pass
+        if checkpoint_paths is None:
+            checkpoint_type = ModelCheckpointType.PYTORCH_STATE_DICT
+            checkpoint_paths = get_checkpoint_paths(checkpoint_type)
+    else:
+        checkpoint_paths = get_checkpoint_paths(checkpoint_type)
+
+    assert checkpoint_paths is not None
+    assert checkpoint_type is not None
+    return checkpoint_paths, checkpoint_type
+
+
+def get_path_index(
+    fs: AbstractFileSystem,
+    path: str,
+    fsspec_args: Optional[Dict[str, Any]] = None,
+) -> Dict[str, Dict[str, Any]]:
+    """
+    Get the files and their metadata of a model on an fsspec filesystem.
+
+    :param fs:
+        The filesystem on which the model is stored.
+    :param path:
+        The path to return the index for.
+    :param fsspec_args:
+        Implementation-specific keyword arguments to pass to fsspec
+        filesystem operations.
+    :returns:
+        Mapping of file names within the path to
+        their filesystem metadata.
+    """
+    fsspec_args = {} if fsspec_args is None else fsspec_args
+
+    try:
+        return {
+            os.path.basename(entry["name"]): entry
+            for entry in fs.ls(path, **fsspec_args)
+        }
+    except FileNotFoundError:
+        raise ValueError(f"Path cannot be found: {path}")
diff --git a/curated_transformers/util/hf.py b/curated_transformers/util/hf.py
index 49ce21c9..3fafef09 100644
--- a/curated_transformers/util/hf.py
+++ b/curated_transformers/util/hf.py
@@ -6,7 +6,12 @@
 from requests import HTTPError, ReadTimeout  # type: ignore
 
 from .._compat import has_safetensors
-from .serde import _MODEL_CHECKPOINT_TYPE, ModelCheckpointType
+from .serde import (
+    _MODEL_CHECKPOINT_TYPE,
+    LocalModelFile,
+    ModelCheckpointType,
+    ModelFile,
+)
 
 HF_MODEL_CONFIG = "config.json"
 HF_MODEL_CHECKPOINT = "pytorch_model.bin"
@@ -47,39 +52,38 @@ def get_file_metadata(
     return huggingface_hub.get_hf_file_metadata(url)
 
 
-def get_hf_config_model_type(name: str, revision: str) -> str:
+def get_config_model_type(name: str, revision: str) -> str:
     """
     Get the type of a model on Hugging Face Hub.
 
-    :param filename:
-        The file to get the type of.
     :param name:
-        Model name.
+        The model to get the type of.
+    :param revision:
+        The revision of the model.
""" - config_filename = get_model_config_filepath(name, revision) - with open(config_filename, "r") as f: - config = json.load(f) - model_type = config.get("model_type") - if model_type is None: - raise ValueError("Model type not found in Hugging Face model config") - return model_type + config = get_model_config(name, revision) + model_type = config.get("model_type") + if model_type is None: + raise ValueError( + f"Model type not found in Hugging Face model config for model '{name}' ({revision})" + ) + return model_type -def get_model_config_filepath(name: str, revision: str) -> str: +def get_model_config(name: str, revision: str) -> Dict[str, Any]: """ - Return the local file path of the Hugging Face model's config. - If the config is not found in the cache, it is downloaded from - Hugging Face Hub. + Return the model's configuration. If the config is not found in the + cache, it is downloaded from Hugging Face Hub. :param name: Model name. :param revision: Model revision. :returns: - Absolute path to the configuration file. + Model configuration. """ try: - return hf_hub_download( + path = hf_hub_download( repo_id=name, filename=HF_MODEL_CONFIG, revision=revision ) except: @@ -88,10 +92,14 @@ def get_model_config_filepath(name: str, revision: str) -> str: f"(revision `{revision}`) on HuggingFace Model Hub" ) + with open(path, "r") as f: + config = json.load(f) + return config + -def get_model_checkpoint_filepaths( +def get_model_checkpoint_files( name: str, revision: str -) -> Tuple[List[str], ModelCheckpointType]: +) -> Tuple[List[ModelFile], ModelCheckpointType]: """ Return a list of local file paths to checkpoints that belong to the Hugging Face model. In case of non-sharded models, a single file path is returned. In @@ -111,7 +119,7 @@ def get_model_checkpoint_filepaths( def get_checkpoint_paths( checkpoint_type: ModelCheckpointType, - ) -> List[str]: + ) -> List[ModelFile]: # Attempt to download a non-sharded checkpoint first. 
try: model_filename = hf_hub_download( @@ -124,7 +132,7 @@ def get_checkpoint_paths( model_filename = None if model_filename is not None: - return [model_filename] + return [LocalModelFile(model_filename)] try: model_index_filename = hf_hub_download( @@ -142,7 +150,7 @@ def get_checkpoint_paths( index = json.load(f) weight_map = index.get(SHARDED_CHECKPOINT_INDEX_WEIGHTS_KEY) - if weight_map is None or not isinstance(weight_map, dict): + if not isinstance(weight_map, dict): raise ValueError( f"Invalid index file in sharded {checkpoint_type.pretty_name} " f"checkpoint for model `{name}`" @@ -157,10 +165,10 @@ def get_checkpoint_paths( ) filepaths.append(resolved_filename) - return sorted(filepaths) + return [LocalModelFile(path) for path in sorted(filepaths)] checkpoint_type = _MODEL_CHECKPOINT_TYPE.get() - checkpoint_paths: Optional[List[str]] = None + checkpoint_paths: Optional[List[ModelFile]] = None if checkpoint_type is None: # Precedence: Safetensors > PyTorch diff --git a/curated_transformers/util/serde.py b/curated_transformers/util/serde.py index 8a579992..6794e5e0 100644 --- a/curated_transformers/util/serde.py +++ b/curated_transformers/util/serde.py @@ -1,8 +1,11 @@ +from abc import ABC, abstractmethod from contextlib import contextmanager from contextvars import ContextVar from enum import Enum from typing import ( + IO, TYPE_CHECKING, + Any, Callable, Dict, Iterable, @@ -13,6 +16,7 @@ ) import torch +from fsspec import AbstractFileSystem from torch.nn import Module, Parameter from .._compat import has_safetensors @@ -33,6 +37,100 @@ [Mapping[str, torch.Tensor]], Mapping[str, torch.Tensor] ] +PathOrFileDescriptor = Union[str, IO] + + +class ModelFile(ABC): + """ + Model files can be a local path or a remote path exposed as e.g. an I/O + stream. This is a common base class for such different types of model + files. + """ + + @abstractmethod + def open(self, mode: str = "rb", encoding: Optional[str] = None) -> IO: + """ + Get the model file as an I/O stream. + + :param mode: + Mode to open the file with (see Python ``open``). + :param encoding: + Encoding to use when the file is opened as text. + :returns: + An I/O stream. + """ + ... + + @property + @abstractmethod + def path(self) -> Optional[str]: + """ + Get the model file as a local path. If the model file is not + available as a local path, the value of this property is + ``None``. + """ + ... + + +class FsspecModelFile(ModelFile): + """ + Model file on an fsspec filesystem. + """ + + def __init__( + self, + fs: AbstractFileSystem, + path: str, + fsspec_args: Optional[Dict[str, Any]] = None, + ): + """ + Construct an fsspec model file representation. + + :param fs: + The filesystem. + :param path: + The path of the model file on the filesystem. + :param fsspec_args: + Implementation-specific keyword arguments to pass to fsspec + filesystem operations. + """ + super().__init__() + self._fs = fs + self._path = path + self._fsspec_args = fsspec_args + + def open(self, mode: str = "rb", encoding: Optional[str] = None) -> IO: + return self._fs.open( + self._path, mode=mode, encoding=encoding, **self._fsspec_args + ) + + @property + def path(self) -> Optional[str]: + return None + + +class LocalModelFile(ModelFile): + """ + Model file on the local host machine. + """ + + def __init__(self, path: str): + """ + Construct a local model file representation. + + :param path: + The path of the model file on the local filesystem. 
+ """ + super().__init__() + self._path = path + + def open(self, mode: str = "rb", encoding: Optional[str] = None) -> IO: + return open(self._path, mode=mode, encoding=encoding) + + @property + def path(self) -> Optional[str]: + return self._path + class ModelCheckpointType(Enum): """ @@ -46,7 +144,9 @@ class ModelCheckpointType(Enum): SAFE_TENSORS = 1 @property - def loader(self) -> Callable[[Iterable[str]], Iterable[Mapping[str, torch.Tensor]]]: + def loader( + self, + ) -> Callable[[Iterable[ModelFile]], Iterable[Mapping[str, torch.Tensor]]]: checkpoint_type_to_loader = { ModelCheckpointType.PYTORCH_STATE_DICT: _load_pytorch_state_dicts_from_checkpoints, ModelCheckpointType.SAFE_TENSORS: _load_safetensor_state_dicts_from_checkpoints, @@ -95,7 +195,7 @@ def _use_model_checkpoint_type( def load_model_from_checkpoints( model: Module, *, - filepaths: Iterable[str], + filepaths: Iterable[ModelFile], checkpoint_type: ModelCheckpointType, state_dict_converter: HFStateDictConverterT, tensor_to_param_converter: Optional[TensorToParameterConverterT] = None, @@ -278,27 +378,38 @@ def _validate_replacement( def _load_safetensor_state_dicts_from_checkpoints( - filepaths: Iterable[str], + checkpoints: Iterable[ModelFile], ) -> Iterable[Mapping[str, torch.Tensor]]: if not has_safetensors: raise ValueError( "The `safetensors` library is required to load models from Safetensors checkpoints" ) - import safetensors + import safetensors.torch - for path in filepaths: - # Map to CPU first to support all devices. - state_dict = safetensors.torch.load_file(path, device="cpu") + for checkpoint in checkpoints: + # Prefer to load from a path when possible. Since loading from a file + # temporarily puts the checkpoint in memory twice. + if checkpoint.path is not None: + # Map to CPU first to support all devices. + state_dict = safetensors.torch.load_file(checkpoint.path, device="cpu") + else: + with checkpoint.open() as f: + # This has memory overhead, since Safetensors does not have + # support for loading from a file object and cannot use + # the bytes in-place. + checkpoint_bytes = f.read() + state_dict = safetensors.torch.load(checkpoint_bytes) yield state_dict def _load_pytorch_state_dicts_from_checkpoints( - filepaths: Iterable[str], + checkpoints: Iterable[ModelFile], ) -> Iterable[Mapping[str, torch.Tensor]]: - for path in filepaths: - # Map to CPU first to support all devices. - state_dict = torch.load( - path, map_location=torch.device("cpu"), weights_only=True - ) + for checkpoint in checkpoints: + with checkpoint.open() as f: + # Map to CPU first to support all devices. + state_dict = torch.load( + f, map_location=torch.device("cpu"), weights_only=True + ) yield state_dict diff --git a/docs/source/api-compat.rst b/docs/source/api-compat.rst index a38cd4dd..f1dee1fb 100644 --- a/docs/source/api-compat.rst +++ b/docs/source/api-compat.rst @@ -109,3 +109,4 @@ Version 1 to 2 * The factory methods of :py:class:`~curated_transformers.layers.AttentionHeads` add a new ``qkv_split`` argument which is mandatory in future versions. +* The ``FromHFHub`` mixins will be renamed to ``FromHF``. diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 13efc508..f13a5182 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -85,6 +85,9 @@ For more information about the different configs and generators supported by Cur Loading a Model --------------- +Hugging Face Hub +^^^^^^^^^^^^^^^^ + Curated Transformers allows users to easily load model weights from the `Hugging Face Model Hub`_. 
 All models provide a ``from_hf_hub`` method that allows directly loading
 pre-trained model parameters from Hugging Face Model Hub.
@@ -126,6 +129,38 @@ and :py:class:`~curated_transformers.models.AutoCausalLM` classes can be used to
 
     lm = AutoCausalLM.from_hf_hub(name="databricks/dolly-v2-3b", revision="main")
 
+fsspec filesystem
+^^^^^^^^^^^^^^^^^
+
+Curated Transformers also supports loading models from `fsspec`_ filesystems. This
+makes it possible to load local models or to load models from cloud services
+without using any local storage. A model can be loaded from an fsspec filesystem
+using the ``from_fsspec`` method.
+
+.. _fsspec: https://filesystem-spec.readthedocs.io
+
+.. code-block:: python
+
+    import torch
+    from curated_transformers.models import BERTEncoder
+    from fsspec.implementations.local import LocalFileSystem
+    from huggingface_hub import HfFileSystem
+
+    encoder = BERTEncoder.from_fsspec(
+        fs=LocalFileSystem(),
+        model_path="/srv/models/bert-base-uncased",
+        device=torch.device("cuda", index=0),
+    )
+
+    # Pass additional arguments to the specific fsspec implementation.
+    encoder = BERTEncoder.from_fsspec(
+        fs=HfFileSystem(),
+        model_path="bert-base-uncased",
+        fsspec_args={"revision": "a265f773a47193eed794233aa2a0f0bb6d3eaa63"},
+        device=torch.device("cuda", index=0),
+    )
+
+
 Quantization
 ------------
 
diff --git a/requirements.txt b/requirements.txt
index 458cbf3e..8d0472b6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-curated-tokenizers>=0.9.0.dev0,<1.0.0
+curated-tokenizers>=0.9.1,<1.0.0
 huggingface-hub>=0.14
 tokenizers>=0.13.3
 torch>=1.12.0
diff --git a/setup.cfg b/setup.cfg
index 0632ffe6..70b92871 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -14,7 +14,7 @@ zip_safe = true
 include_package_data = true
 python_requires = >=3.8
 install_requires =
-    curated-tokenizers>=0.9.0.dev0,<1.0.0
+    curated-tokenizers>=0.9.1,<1.0.0
     huggingface-hub>=0.14
     tokenizers>=0.13.3
     torch>=1.12.0
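The ``AutoModel`` classes expose the same ``from_fsspec`` entry point as the encoder classes shown in the usage examples above, so a checkpoint can also be loaded straight from cloud object storage. The snippet below is a minimal sketch, assuming the optional ``s3fs`` package is installed and a hypothetical bucket that already holds the model files in the usual Hugging Face layout:

.. code-block:: python

    import fsspec
    import torch
    from curated_transformers.models import AutoCausalLM

    # "s3" resolves to s3fs.S3FileSystem; the bucket and prefix are
    # placeholders for wherever the config and checkpoint files actually live.
    fs = fsspec.filesystem("s3", anon=False)

    lm = AutoCausalLM.from_fsspec(
        fs=fs,
        model_path="my-bucket/models/dolly-v2-3b",
        device=torch.device("cuda", index=0),
    )

Backend-specific options such as credentials can be passed to the filesystem constructor, while per-operation options go through ``fsspec_args``, as in the ``HfFileSystem`` example above.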