Add support for loading models from an fsspec filesystem (#327)
* Add support for loading models from an fsspec filesystem

This change adds a method `from_fsspec` to the `FromHFHub` mixins. This
new method can be used to load a model from an fsspec filesystem.

We should rename `FromHFHub` in the next major semver release, since
this PR extends loading support beyond the Hugging Face Hub.
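
As an illustration, a minimal usage sketch of the new entry point. The
`HfFileSystem` call mirrors the test added in this PR; the local-filesystem
variant and its path are assumptions:

    from fsspec.implementations.local import LocalFileSystem
    from huggingface_hub import HfFileSystem

    from curated_transformers.models import AutoEncoder

    # Load from the Hugging Face Hub through its fsspec implementation,
    # pinning a revision via `fsspec_args`.
    encoder = AutoEncoder.from_fsspec(
        fs=HfFileSystem(),
        model_path="explosion-testing/bert-test",
        fsspec_args={"revision": "main"},
    )

    # Load from a local checkout instead (hypothetical path).
    encoder = AutoEncoder.from_fsspec(
        fs=LocalFileSystem(),
        model_path="/data/models/bert-test",
    )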

* Doc fix

Co-authored-by: Madeesh Kannan <[email protected]>

* Name functions more generically

* Apply suggestions from code review

Add fixes from @shadeMe

Co-authored-by: Madeesh Kannan <[email protected]>

* Add missing type hints

* Fix argument order

* Remove an exception

* black

* Type fix

---------

Co-authored-by: Madeesh Kannan <[email protected]>
danieldk and shadeMe authored Sep 14, 2023
1 parent 9b3db13 commit 92bfaab
Showing 28 changed files with 1,016 additions and 161 deletions.
127 changes: 124 additions & 3 deletions curated_transformers/models/auto_model.py
@@ -1,11 +1,13 @@
 from abc import ABC, abstractmethod
-from typing import Dict, Generic, Optional, Type, TypeVar
+from typing import Any, Dict, Generic, Optional, Type, TypeVar

 import torch
+from fsspec import AbstractFileSystem

 from ..layers.cache import KeyValueCache
 from ..quantization.bnb.config import BitsAndBytesConfig
-from ..util.hf import get_hf_config_model_type
+from ..util.fsspec import get_config_model_type as get_config_model_type_fsspec
+from ..util.hf import get_config_model_type
 from .albert import ALBERTEncoder
 from .bert import BERTEncoder
 from .camembert import CamemBERTEncoder
@@ -30,13 +32,36 @@ class AutoModel(ABC, Generic[ModelT]):

     _hf_model_type_to_curated: Dict[str, Type[FromHFHub]] = {}

+    @classmethod
+    def _resolve_model_cls_fsspec(
+        cls,
+        fs: AbstractFileSystem,
+        model_path: str,
+        fsspec_args: Optional[Dict[str, Any]] = None,
+    ) -> Type[FromHFHub]:
+        model_type = get_config_model_type_fsspec(
+            fs, model_path, fsspec_args=fsspec_args
+        )
+        if model_type is None:
+            raise ValueError(
+                "The model type is not defined in the model configuration."
+            )
+        module_cls = cls._hf_model_type_to_curated.get(model_type)
+        if module_cls is None:
+            raise ValueError(
+                f"Unsupported model type `{model_type}` for {cls.__name__}. "
+                f"Supported model types: {tuple(cls._hf_model_type_to_curated.keys())}"
+            )
+        assert issubclass(module_cls, FromHFHub)
+        return module_cls
+
     @classmethod
     def _resolve_model_cls(
         cls,
         name: str,
         revision: str,
     ) -> Type[FromHFHub]:
-        model_type = get_hf_config_model_type(name, revision)
+        model_type = get_config_model_type(name, revision)
         module_cls = cls._hf_model_type_to_curated.get(model_type)
         if module_cls is None:
             raise ValueError(
@@ -46,6 +71,25 @@ def _resolve_model_cls(
         assert issubclass(module_cls, FromHFHub)
         return module_cls

+    @classmethod
+    def _instantiate_model_from_fsspec(
+        cls,
+        fs: AbstractFileSystem,
+        model_path: str,
+        fsspec_args: Optional[Dict[str, Any]],
+        device: Optional[torch.device],
+        quantization_config: Optional[BitsAndBytesConfig],
+    ) -> FromHFHub:
+        module_cls = cls._resolve_model_cls_fsspec(fs, model_path)
+        module = module_cls.from_fsspec(
+            fs=fs,
+            model_path=model_path,
+            fsspec_args=fsspec_args,
+            device=device,
+            quantization_config=quantization_config,
+        )
+        return module
+
     @classmethod
     def _instantiate_model_from_hf_hub(
         cls,
@@ -63,6 +107,35 @@ def _instantiate_model_from_hf_hub(
         )
         return module

+    @classmethod
+    def from_fsspec(
+        cls,
+        *,
+        fs: AbstractFileSystem,
+        model_path: str,
+        fsspec_args: Optional[Dict[str, Any]] = None,
+        device: Optional[torch.device] = None,
+        quantization_config: Optional[BitsAndBytesConfig] = None,
+    ) -> ModelT:
+        """
+        Construct a module and load its parameters from an fsspec filesystem.
+
+        :param fs:
+            The filesystem to load the model from.
+        :param model_path:
+            The path of the model on the filesystem.
+        :param fsspec_args:
+            Implementation-specific keyword arguments to pass to fsspec
+            filesystem operations.
+        :param device:
+            Device on which the model is initialized.
+        :param quantization_config:
+            Configuration for loading quantized weights.
+        :returns:
+            Module with the parameters loaded.
+        """
+        raise NotImplementedError
+
     @classmethod
     @abstractmethod
     def from_hf_hub(
@@ -124,6 +197,22 @@ class AutoEncoder(AutoModel[EncoderModule]):
         "xlm-roberta": XLMREncoder,
     }

+    @classmethod
+    def from_fsspec(
+        cls,
+        *,
+        fs: AbstractFileSystem,
+        model_path: str,
+        fsspec_args: Optional[Dict[str, Any]] = None,
+        device: Optional[torch.device] = None,
+        quantization_config: Optional[BitsAndBytesConfig] = None,
+    ) -> EncoderModule:
+        encoder = cls._instantiate_model_from_fsspec(
+            fs, model_path, fsspec_args, device, quantization_config
+        )
+        assert isinstance(encoder, EncoderModule)
+        return encoder
+
     @classmethod
     def from_hf_hub(
         cls,
@@ -154,6 +243,22 @@ class AutoDecoder(AutoModel[DecoderModule]):
         "RefinedWebModel": FalconDecoder,
     }

+    @classmethod
+    def from_fsspec(
+        cls,
+        *,
+        fs: AbstractFileSystem,
+        model_path: str,
+        fsspec_args: Optional[Dict[str, Any]] = None,
+        device: Optional[torch.device] = None,
+        quantization_config: Optional[BitsAndBytesConfig] = None,
+    ) -> DecoderModule:
+        decoder = cls._instantiate_model_from_fsspec(
+            fs, model_path, fsspec_args, device, quantization_config
+        )
+        assert isinstance(decoder, DecoderModule)
+        return decoder
+
     @classmethod
     def from_hf_hub(
         cls,
@@ -184,6 +289,22 @@ class AutoCausalLM(AutoModel[CausalLMModule[KeyValueCache]]):
         "RefinedWebModel": FalconCausalLM,
     }

+    @classmethod
+    def from_fsspec(
+        cls,
+        *,
+        fs: AbstractFileSystem,
+        model_path: str,
+        fsspec_args: Optional[Dict[str, Any]] = None,
+        device: Optional[torch.device] = None,
+        quantization_config: Optional[BitsAndBytesConfig] = None,
+    ) -> CausalLMModule[KeyValueCache]:
+        causal_lm = cls._instantiate_model_from_fsspec(
+            fs, model_path, fsspec_args, device, quantization_config
+        )
+        assert isinstance(causal_lm, CausalLMModule)
+        return causal_lm
+
     @classmethod
     def from_hf_hub(
         cls,
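Because `from_fsspec` accepts any fsspec-compatible filesystem, loading from
object storage should also work; a hedged sketch, assuming `s3fs` is installed
and the bucket mirrors a Hub repository layout (bucket and path are
hypothetical):

    import s3fs

    from curated_transformers.models import AutoDecoder

    # S3FileSystem comes from the s3fs package; credentials are picked up
    # from the environment. Bucket and key prefix are hypothetical.
    fs = s3fs.S3FileSystem(anon=False)
    decoder = AutoDecoder.from_fsspec(
        fs=fs,
        model_path="my-bucket/models/falcon-test",
    )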
110 changes: 81 additions & 29 deletions curated_transformers/models/hf_hub.py
@@ -1,4 +1,3 @@
-import json
 from abc import ABC, abstractmethod
 from typing import (
     Any,
@@ -14,12 +13,17 @@
 )

 import torch
+from fsspec import AbstractFileSystem
 from torch import Tensor

 from ..quantization import prepare_module_for_quantization
 from ..quantization.bnb.config import BitsAndBytesConfig
-from ..util.hf import get_model_checkpoint_filepaths, get_model_config_filepath
-from ..util.serde import load_model_from_checkpoints
+from ..util.fsspec import (
+    get_model_checkpoint_files as get_model_checkpoint_files_fsspec,
+)
+from ..util.fsspec import get_model_config as get_model_config_fsspec
+from ..util.hf import get_model_checkpoint_files, get_model_config
+from ..util.serde import ModelCheckpointType, ModelFile, load_model_from_checkpoints

 # Only provided as typing.Self in Python 3.11+.
 Self = TypeVar("Self", bound="FromHFHub")
@@ -89,8 +93,46 @@ def from_hf_hub_to_cache(
         :param revision:
             Model revision.
         """
-        _ = get_model_config_filepath(name, revision)
-        _ = get_model_checkpoint_filepaths(name, revision)
+        _ = get_model_config(name, revision)
+        _ = get_model_checkpoint_files(name, revision)
+
+    @classmethod
+    def from_fsspec(
+        cls: Type[Self],
+        *,
+        fs: AbstractFileSystem,
+        model_path: str,
+        fsspec_args: Optional[Dict[str, Any]] = None,
+        device: Optional[torch.device] = None,
+        quantization_config: Optional[BitsAndBytesConfig] = None,
+    ) -> Self:
+        """
+        Construct a module and load its parameters from an fsspec filesystem.
+
+        :param fs:
+            The filesystem to load the model from.
+        :param model_path:
+            The path of the model on the filesystem.
+        :param fsspec_args:
+            Implementation-specific keyword arguments to pass to fsspec
+            filesystem operations.
+        :param device:
+            Device on which the model is initialized.
+        :param quantization_config:
+            Configuration for loading quantized weights.
+        :returns:
+            Module with the parameters loaded.
+        """
+        return cls._create_and_load_model(
+            get_config=lambda: get_model_config_fsspec(
+                fs, model_path, fsspec_args=fsspec_args
+            ),
+            get_checkpoint_files=lambda: get_model_checkpoint_files_fsspec(
+                fs, model_path, fsspec_args=fsspec_args
+            ),
+            device=device,
+            quantization_config=quantization_config,
+        )

     @classmethod
     def from_hf_hub(
@@ -115,11 +157,39 @@ def from_hf_hub(
         :returns:
             Module with the parameters loaded.
         """
-        # Download configuration and construct model.
-        config_filename = get_model_config_filepath(name, revision)
-        with open(config_filename, "r") as f:
-            config = json.load(f)
-        # Initialize the model on the torch `meta` device to avoid unnecessary allocations.
+        return cls._create_and_load_model(
+            get_config=lambda: get_model_config(name, revision),
+            get_checkpoint_files=lambda: get_model_checkpoint_files(name, revision),
+            device=device,
+            quantization_config=quantization_config,
+        )
+
+    @abstractmethod
+    def to(
+        self,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
+    ):
+        """
+        Moves and/or casts the parameters and buffers.
+
+        This method is automatically implemented by also deriving from
+        ``torch.nn.Module``. This mixin does not derive from ``Module`` in
+        order to be an abstract base class.
+        """
+        ...
+
+    @classmethod
+    def _create_and_load_model(
+        cls: Type[Self],
+        *,
+        get_config: Callable[[], Dict[Any, str]],
+        get_checkpoint_files: Callable[[], Tuple[List[ModelFile], ModelCheckpointType]],
+        device: Optional[torch.device] = None,
+        quantization_config: Optional[BitsAndBytesConfig] = None,
+    ) -> Self:
+        config = get_config()
         model = cls.from_hf_config(hf_config=config, device=torch.device("meta"))

         # Convert the model to the expected dtype.
@@ -137,9 +207,7 @@ def from_hf_hub(
             tensor2param = None

         # Download model and convert HF parameter names to ours.
-        checkpoint_filenames, checkpoint_type = get_model_checkpoint_filepaths(
-            name, revision
-        )
+        checkpoint_filenames, checkpoint_type = get_checkpoint_files()
         load_model_from_checkpoints(
             model,  # type:ignore
             filepaths=checkpoint_filenames,
@@ -156,22 +224,6 @@ def from_hf_hub(

         return model

-    @abstractmethod
-    def to(
-        self,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        non_blocking: bool = False,
-    ):
-        """
-        Moves and/or casts the parameters and buffers.
-
-        This method is automatically implemented by also deriving from
-        ``torch.nn.Module``. This mixin does not derive from ``Module`` in
-        order to be an abstract base class.
-        """
-        ...
-

 def _process_hf_keys(
     model_name: str,
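Note for reviewers: `from_hf_hub` and `from_fsspec` now converge on
`_create_and_load_model`, which receives the config and checkpoint lookups as
zero-argument callables, so the two front ends differ only in how they fetch
files and all I/O is deferred until loading actually starts. A schematic
sketch of the pattern (simplified names, not the module's actual code):

    from typing import Any, Callable, Dict, List, Tuple

    def create_and_load(
        get_config: Callable[[], Dict[str, Any]],
        get_checkpoint_files: Callable[[], Tuple[List[str], str]],
    ) -> None:
        # Nothing is fetched until the callables run; each front end
        # (Hub download, fsspec read) supplies its own fetchers.
        config = get_config()
        checkpoint_files, checkpoint_type = get_checkpoint_files()
        print(f"{len(checkpoint_files)} file(s), format: {checkpoint_type}")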
19 changes: 17 additions & 2 deletions curated_transformers/tests/models/test_auto_models.py
@@ -1,4 +1,5 @@
 import pytest
+from huggingface_hub import HfFileSystem

 from curated_transformers.models import (
     ALBERTEncoder,
@@ -22,15 +23,18 @@
 from curated_transformers.models.mpt.decoder import MPTDecoder


-def test_auto_encoder():
-    model_encoder_map = {
+@pytest.fixture
+def model_encoder_map():
+    return {
         "explosion-testing/bert-test": BERTEncoder,
         "explosion-testing/albert-test": ALBERTEncoder,
         "explosion-testing/roberta-test": RoBERTaEncoder,
         "explosion-testing/camembert-test": CamemBERTEncoder,
         "explosion-testing/xlm-roberta-test": XLMREncoder,
     }

+
+def test_auto_encoder(model_encoder_map):
     for name, encoder_cls in model_encoder_map.items():
         encoder = AutoEncoder.from_hf_hub(name=name)
         assert isinstance(encoder, encoder_cls)
@@ -39,6 +43,17 @@ def test_auto_encoder():
         AutoEncoder.from_hf_hub(name="explosion-testing/falcon-test")


+@pytest.mark.slow
+def test_auto_encoder_fsspec(model_encoder_map):
+    for name, encoder_cls in model_encoder_map.items():
+        # The default revision is 'main', but we pass it anyway to test
+        # that the function accepts fsspec_args.
+        encoder = AutoEncoder.from_fsspec(
+            fs=HfFileSystem(), model_path=name, fsspec_args={"revision": "main"}
+        )
+        assert isinstance(encoder, encoder_cls)
+
+
 def test_auto_decoder():
     model_decoder_map = {
         "explosion-testing/falcon-test": FalconDecoder,
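Since the new fsspec test downloads real models, it is gated behind the
`slow` marker; assuming the project's pytest configuration registers that
marker, it can be selected with
`pytest -m slow curated_transformers/tests/models/test_auto_models.py`.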
Diffs for the remaining 25 changed files are not shown.
