Restructure util.serde (#334)
* Split `util.serde` into submodules
Move `ModelCheckpointType` and co to `util.serde.checkpoint`

* `isort`
shadeMe authored Sep 28, 2023
1 parent 93cf07f commit 6773703
Showing 9 changed files with 97 additions and 98 deletions.
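In short, the split moves the checkpoint-type machinery into `util.serde.checkpoint` and the loading helpers into `util.serde.load`; for downstream code, only the import paths change. A minimal before/after sketch, using the paths that appear in the diffs below:

```python
# Before this commit (single `util.serde` module):
from curated_transformers.util.serde import (
    ModelCheckpointType,
    load_model_from_checkpoints,
)

# After this commit (`util.serde` split into submodules):
from curated_transformers.util.serde.checkpoint import ModelCheckpointType
from curated_transformers.util.serde.load import load_model_from_checkpoints
```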
2 changes: 1 addition & 1 deletion curated_transformers/models/hf_hub/mixin.py
@@ -10,7 +10,7 @@
 from ...repository.fsspec import FsspecArgs, FsspecRepository
 from ...repository.hf_hub import HfHubRepository
 from ...repository.repository import ModelRepository, Repository
-from ...util.serde import load_model_from_checkpoints
+from ...util.serde.load import load_model_from_checkpoints
 from ..module import TransformerModule
 
 # Only provided as typing.Self in Python 3.11+.
2 changes: 1 addition & 1 deletion curated_transformers/quantization/bnb/impl.py
@@ -6,7 +6,7 @@

 from ..._compat import has_bitsandbytes
 from ...util.pytorch import ModuleIterator, apply_to_module
-from ...util.serde import TensorToParameterConverterT
+from ...util.serde.load import TensorToParameterConverterT
 from .config import BitsAndBytesConfig, _4BitConfig, _8BitConfig
 
 if TYPE_CHECKING:
2 changes: 1 addition & 1 deletion curated_transformers/quantization/helpers.py
@@ -2,7 +2,7 @@

 from torch.nn import Module
 
-from ..util.serde import TensorToParameterConverterT
+from ..util.serde.load import TensorToParameterConverterT
 from .bnb import prepare_for_quantization as bnb_prepare_for_quantization
 from .bnb.config import BitsAndBytesConfig
 from .quantizable import Quantizable
88 changes: 2 additions & 86 deletions curated_transformers/repository/_hf.py
@@ -1,15 +1,6 @@
-from contextvars import ContextVar
-from enum import Enum
-from typing import TYPE_CHECKING, Callable, Iterable, Mapping, Optional
-
-import torch
-
-from .._compat import has_safetensors
-from ..repository.file import RepositoryFile
-
-if TYPE_CHECKING:
-    import safetensors
-
+from typing import TYPE_CHECKING
+
+from ..util.serde.checkpoint import ModelCheckpointType
 
 HF_MODEL_CONFIG = "config.json"
 HF_MODEL_CHECKPOINT = "pytorch_model.bin"
@@ -22,37 +13,6 @@
 TOKENIZER_JSON = "tokenizer.json"
 
 
-class ModelCheckpointType(Enum):
-    """
-    Types of model checkpoints supported by Curated Transformers.
-    """
-
-    #: PyTorch `checkpoint<https://pytorch.org/docs/stable/generated/torch.save.html>`_.
-    PYTORCH_STATE_DICT = 0
-
-    #: Hugging Face `Safetensors <https://github.com/huggingface/safetensors>`_ checkpoint.
-    SAFE_TENSORS = 1
-
-    @property
-    def loader(
-        self,
-    ) -> Callable[[Iterable[RepositoryFile]], Iterable[Mapping[str, torch.Tensor]]]:
-        checkpoint_type_to_loader = {
-            ModelCheckpointType.PYTORCH_STATE_DICT: _load_pytorch_state_dicts_from_checkpoints,
-            ModelCheckpointType.SAFE_TENSORS: _load_safetensor_state_dicts_from_checkpoints,
-        }
-        return checkpoint_type_to_loader[self]
-
-    @property
-    def pretty_name(self) -> str:
-        if self == ModelCheckpointType.PYTORCH_STATE_DICT:
-            return "PyTorch StateDict"
-        elif self == ModelCheckpointType.SAFE_TENSORS:
-            return "SafeTensors"
-        else:
-            return ""
-
-
 PRIMARY_CHECKPOINT_FILENAMES = {
     ModelCheckpointType.PYTORCH_STATE_DICT: HF_MODEL_CHECKPOINT,
     ModelCheckpointType.SAFE_TENSORS: HF_MODEL_CHECKPOINT_SAFETENSORS,
@@ -63,47 +23,3 @@ def pretty_name(self) -> str:
 }
 # Same for both checkpoint types.
 SHARDED_CHECKPOINT_INDEX_WEIGHTS_KEY = HF_MODEL_SHARDED_CHECKPOINT_INDEX_WEIGHTS_KEY
-
-
-# When `None`, behaviour is implementation-specific.
-_MODEL_CHECKPOINT_TYPE: ContextVar[Optional[ModelCheckpointType]] = ContextVar(
-    "model_checkpoint_type", default=None
-)
-
-
-def _load_safetensor_state_dicts_from_checkpoints(
-    checkpoints: Iterable[RepositoryFile],
-) -> Iterable[Mapping[str, torch.Tensor]]:
-    if not has_safetensors:
-        raise ValueError(
-            "The `safetensors` library is required to load models from Safetensors checkpoints"
-        )
-
-    import safetensors.torch
-
-    for checkpoint in checkpoints:
-        # Prefer to load from a path when possible. Since loading from a file
-        # temporarily puts the checkpoint in memory twice.
-        if checkpoint.path is not None:
-            # Map to CPU first to support all devices.
-            state_dict = safetensors.torch.load_file(checkpoint.path, device="cpu")
-        else:
-            with checkpoint.open() as f:
-                # This has memory overhead, since Safetensors does not have
-                # support for loading from a file object and cannot use
-                # the bytes in-place.
-                checkpoint_bytes = f.read()
-                state_dict = safetensors.torch.load(checkpoint_bytes)
-        yield state_dict
-
-
-def _load_pytorch_state_dicts_from_checkpoints(
-    checkpoints: Iterable[RepositoryFile],
-) -> Iterable[Mapping[str, torch.Tensor]]:
-    for checkpoint in checkpoints:
-        with checkpoint.open() as f:
-            # Map to CPU first to support all devices.
-            state_dict = torch.load(
-                f, map_location=torch.device("cpu"), weights_only=True
-            )
-        yield state_dict
3 changes: 1 addition & 2 deletions curated_transformers/repository/repository.py
@@ -4,16 +4,15 @@
 from typing import Any, Dict, List, Optional, Tuple
 
 from .._compat import has_safetensors
+from ..util.serde.checkpoint import _MODEL_CHECKPOINT_TYPE, ModelCheckpointType
 from ._hf import (
-    _MODEL_CHECKPOINT_TYPE,
     HF_MODEL_CONFIG,
     HF_TOKENIZER_CONFIG,
     PRIMARY_CHECKPOINT_FILENAMES,
     SHARDED_CHECKPOINT_INDEX_FILENAMES,
     SHARDED_CHECKPOINT_INDEX_WEIGHTS_KEY,
     SPECIAL_TOKENS_MAP,
     TOKENIZER_JSON,
-    ModelCheckpointType,
 )
 from .file import RepositoryFile
 
6 changes: 2 additions & 4 deletions curated_transformers/tests/models/test_hf_hub.py
@@ -6,10 +6,8 @@
 from curated_transformers.models.bert.encoder import BERTEncoder
 from curated_transformers.repository.hf_hub import HfHubRepository
 from curated_transformers.repository.repository import ModelRepository
-from curated_transformers.util.serde import (
-    ModelCheckpointType,
-    _use_model_checkpoint_type,
-)
+from curated_transformers.util.serde.checkpoint import ModelCheckpointType
+from curated_transformers.util.serde.load import _use_model_checkpoint_type
 
 from ..compat import has_hf_transformers, has_safetensors
 from ..conftest import TORCH_DEVICES
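For context, these tests pin the checkpoint format with `_use_model_checkpoint_type`. A hedged sketch of what such a test looks like, assuming the helper is a context manager (its usage in this test module suggests as much) and with a placeholder model name:

```python
from curated_transformers.models.bert.encoder import BERTEncoder
from curated_transformers.util.serde.checkpoint import ModelCheckpointType
from curated_transformers.util.serde.load import _use_model_checkpoint_type


def test_loads_from_safetensors_checkpoint():
    # Force the Safetensors code path rather than letting the loader pick a
    # checkpoint type; the context variable is restored when the block exits.
    with _use_model_checkpoint_type(ModelCheckpointType.SAFE_TENSORS):
        model = BERTEncoder.from_hf_hub(name="explosion-testing/bert-test")
```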
0 changes: 0 additions & 0 deletions curated_transformers/util/serde/__init__.py

Empty file.
86 changes: 86 additions & 0 deletions curated_transformers/util/serde/checkpoint.py
@@ -0,0 +1,86 @@
+from contextvars import ContextVar
+from enum import Enum
+from typing import TYPE_CHECKING, Callable, Iterable, Mapping, Optional
+
+import torch
+
+from ..._compat import has_safetensors
+from ...repository.file import RepositoryFile
+
+if TYPE_CHECKING:
+    import safetensors
+
+
+class ModelCheckpointType(Enum):
+    """
+    Types of model checkpoints supported by Curated Transformers.
+    """
+
+    #: PyTorch `checkpoint <https://pytorch.org/docs/stable/generated/torch.save.html>`_.
+    PYTORCH_STATE_DICT = 0
+
+    #: Hugging Face `Safetensors <https://github.com/huggingface/safetensors>`_ checkpoint.
+    SAFE_TENSORS = 1
+
+    @property
+    def loader(
+        self,
+    ) -> Callable[[Iterable[RepositoryFile]], Iterable[Mapping[str, torch.Tensor]]]:
+        checkpoint_type_to_loader = {
+            ModelCheckpointType.PYTORCH_STATE_DICT: _load_pytorch_state_dicts_from_checkpoints,
+            ModelCheckpointType.SAFE_TENSORS: _load_safetensor_state_dicts_from_checkpoints,
+        }
+        return checkpoint_type_to_loader[self]
+
+    @property
+    def pretty_name(self) -> str:
+        if self == ModelCheckpointType.PYTORCH_STATE_DICT:
+            return "PyTorch StateDict"
+        elif self == ModelCheckpointType.SAFE_TENSORS:
+            return "SafeTensors"
+        else:
+            return ""
+
+
+# When `None`, behaviour is implementation-specific.
+_MODEL_CHECKPOINT_TYPE: ContextVar[Optional[ModelCheckpointType]] = ContextVar(
+    "model_checkpoint_type", default=None
+)
+
+
+def _load_safetensor_state_dicts_from_checkpoints(
+    checkpoints: Iterable[RepositoryFile],
+) -> Iterable[Mapping[str, torch.Tensor]]:
+    if not has_safetensors:
+        raise ValueError(
+            "The `safetensors` library is required to load models from Safetensors checkpoints"
+        )
+
+    import safetensors.torch
+
+    for checkpoint in checkpoints:
+        # Prefer to load from a path when possible, since loading from a file
+        # object temporarily puts the checkpoint in memory twice.
+        if checkpoint.path is not None:
+            # Map to CPU first to support all devices.
+            state_dict = safetensors.torch.load_file(checkpoint.path, device="cpu")
+        else:
+            with checkpoint.open() as f:
+                # This has memory overhead, since Safetensors does not have
+                # support for loading from a file object and cannot use
+                # the bytes in-place.
+                checkpoint_bytes = f.read()
+                state_dict = safetensors.torch.load(checkpoint_bytes)
+        yield state_dict
+
+
+def _load_pytorch_state_dicts_from_checkpoints(
+    checkpoints: Iterable[RepositoryFile],
+) -> Iterable[Mapping[str, torch.Tensor]]:
+    for checkpoint in checkpoints:
+        with checkpoint.open() as f:
+            # Map to CPU first to support all devices.
+            state_dict = torch.load(
+                f, map_location=torch.device("cpu"), weights_only=True
+            )
+        yield state_dict
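The enum doubles as a dispatch table: `.loader` maps each checkpoint type to a generator that yields one state dict per (possibly sharded) checkpoint file. A small sketch of how a caller might combine it with the context variable; the empty checkpoint list is a placeholder, and real callers would pass `RepositoryFile` objects resolved from a repository:

```python
from curated_transformers.util.serde.checkpoint import (
    _MODEL_CHECKPOINT_TYPE,
    ModelCheckpointType,
)

# Placeholder; real code passes RepositoryFile objects, e.g. the shards
# resolved by a ModelRepository.
checkpoints = []

# When the context variable is unset (`None`), behaviour is up to the caller;
# we assume a Safetensors default purely for the sake of the example.
ckpt_type = _MODEL_CHECKPOINT_TYPE.get() or ModelCheckpointType.SAFE_TENSORS

print(f"Loading {ckpt_type.pretty_name} checkpoints")
for state_dict in ckpt_type.loader(checkpoints):
    print(f"{len(state_dict)} tensors in this shard")
```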
6 changes: 3 additions & 3 deletions curated_transformers/util/{serde.py → serde/load.py}
@@ -4,9 +4,9 @@
 import torch
 from torch.nn import Module, Parameter
 
-from ..repository._hf import _MODEL_CHECKPOINT_TYPE, ModelCheckpointType
-from ..repository.file import RepositoryFile
-from .pytorch import ModuleIterator, apply_to_module
+from ...repository.file import RepositoryFile
+from ..pytorch import ModuleIterator, apply_to_module
+from .checkpoint import _MODEL_CHECKPOINT_TYPE, ModelCheckpointType
 
 # Args: Parent module, module prefix, parameter name, tensor to convert, device.
 # Returns the new parameter.
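The `TensorToParameterConverterT` alias documented in the comment above takes (parent module, module prefix, parameter name, tensor to convert, device) and returns the new parameter. A hypothetical converter matching that contract, not part of the library, e.g. for casting checkpointed tensors while they are loaded:

```python
import torch
from torch.nn import Module, Parameter


# Hypothetical converter with the TensorToParameterConverterT shape described
# above; quantizers plug callables like this into checkpoint loading.
def convert_to_half(
    module: Module,
    prefix: str,
    name: str,
    tensor: torch.Tensor,
    device: torch.device,
) -> Parameter:
    # Cast to float16 and move to the target device before wrapping the
    # tensor as a module parameter.
    return Parameter(tensor.to(device=device, dtype=torch.float16))
```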
