Skip to content

Commit

Permalink
Persist model config to aid serialization (#328)
Browse files Browse the repository at this point in the history
* Persist model config and expose it to the API

* Add test

* Exclude tests with JIT

* Add protocol as a config typevar bound
  • Loading branch information
shadeMe authored Sep 18, 2023
1 parent 92bfaab commit 01e8902
Show file tree
Hide file tree
Showing 21 changed files with 112 additions and 47 deletions.
7 changes: 4 additions & 3 deletions curated_transformers/generation/default_generator.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import dataclasses
from typing import List, Optional, Type, TypeVar
from typing import Any, Generic, List, Optional, Type, TypeVar

import torch

from ..models.auto_model import AutoCausalLM
from ..models.module import CausalLMModule
from ..models.output import CacheT
from ..quantization.bnb.config import BitsAndBytesConfig
from ..tokenizers.auto_tokenizer import AutoTokenizer
from ..tokenizers.chunks import InputChunks, TextChunk
Expand All @@ -19,15 +20,15 @@
Self = TypeVar("Self", bound="DefaultGenerator")


class DefaultGenerator(GeneratorWrapper, FromHFHub):
class DefaultGenerator(Generic[CacheT], GeneratorWrapper, FromHFHub):
"""
Generator wrapper for models that do not need specific prompting.
"""

def __init__(
self,
tokenizer: TokenizerBase,
causal_lm: CausalLMModule,
causal_lm: CausalLMModule[Any, CacheT],
default_config: Optional[GeneratorConfig] = None,
):
"""
Expand Down
6 changes: 3 additions & 3 deletions curated_transformers/generation/generator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Generic, Iterator, List, Optional, Tuple
from typing import Any, Generic, Iterator, List, Optional, Tuple

import torch
from torch import Tensor
Expand All @@ -17,9 +17,9 @@ class Generator(Generic[CacheT]):
Generator base class for causal language models.
"""

model: CausalLMModule[CacheT]
model: CausalLMModule[Any, CacheT]

def __init__(self, model: CausalLMModule[CacheT]):
def __init__(self, model: CausalLMModule[Any, CacheT]):
"""
Construct a generator.
Expand Down
4 changes: 2 additions & 2 deletions curated_transformers/models/albert/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
Self = TypeVar("Self", bound="ALBERTEncoder")


class ALBERTEncoder(EncoderModule, FromHFHub):
class ALBERTEncoder(EncoderModule[ALBERTConfig], FromHFHub):
"""
ALBERT (`Lan et al., 2022`_) encoder.
Expand All @@ -39,7 +39,7 @@ def __init__(self, config: ALBERTConfig, *, device: Optional[torch.device] = Non
:returns:
The encoder.
"""
super().__init__()
super().__init__(config)

self.max_seq_len = config.model_max_length
self.n_hidden_layers = config.layer.n_hidden_layers
Expand Down
19 changes: 10 additions & 9 deletions curated_transformers/models/auto_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .albert import ALBERTEncoder
from .bert import BERTEncoder
from .camembert import CamemBERTEncoder
from .config import ConfigDataclass
from .falcon import FalconCausalLM, FalconDecoder
from .gpt_neox import GPTNeoXCausalLM, GPTNeoXDecoder
from .hf_hub import FromHFHub
Expand Down Expand Up @@ -184,7 +185,7 @@ def from_hf_hub_to_cache(
module_cls.from_hf_hub_to_cache(name=name, revision=revision)


class AutoEncoder(AutoModel[EncoderModule]):
class AutoEncoder(AutoModel[EncoderModule[ConfigDataclass]]):
"""
Encoder model loaded from the Hugging Face Model Hub.
"""
Expand All @@ -206,7 +207,7 @@ def from_fsspec(
fsspec_args: Optional[Dict[str, Any]] = None,
device: Optional[torch.device] = None,
quantization_config: Optional[BitsAndBytesConfig] = None,
) -> EncoderModule:
) -> EncoderModule[ConfigDataclass]:
encoder = cls._instantiate_model_from_fsspec(
fs, model_path, fsspec_args, device, quantization_config
)
Expand All @@ -221,15 +222,15 @@ def from_hf_hub(
revision: str = "main",
device: Optional[torch.device] = None,
quantization_config: Optional[BitsAndBytesConfig] = None,
) -> EncoderModule:
) -> EncoderModule[ConfigDataclass]:
encoder = cls._instantiate_model_from_hf_hub(
name, revision, device, quantization_config
)
assert isinstance(encoder, EncoderModule)
return encoder


class AutoDecoder(AutoModel[DecoderModule]):
class AutoDecoder(AutoModel[DecoderModule[ConfigDataclass, KeyValueCache]]):
"""
Decoder module loaded from the Hugging Face Model Hub.
"""
Expand All @@ -252,7 +253,7 @@ def from_fsspec(
fsspec_args: Optional[Dict[str, Any]] = None,
device: Optional[torch.device] = None,
quantization_config: Optional[BitsAndBytesConfig] = None,
) -> DecoderModule:
) -> DecoderModule[ConfigDataclass, KeyValueCache]:
decoder = cls._instantiate_model_from_fsspec(
fs, model_path, fsspec_args, device, quantization_config
)
Expand All @@ -267,15 +268,15 @@ def from_hf_hub(
revision: str = "main",
device: Optional[torch.device] = None,
quantization_config: Optional[BitsAndBytesConfig] = None,
) -> DecoderModule:
) -> DecoderModule[ConfigDataclass, KeyValueCache]:
decoder = cls._instantiate_model_from_hf_hub(
name, revision, device, quantization_config
)
assert isinstance(decoder, DecoderModule)
return decoder


class AutoCausalLM(AutoModel[CausalLMModule[KeyValueCache]]):
class AutoCausalLM(AutoModel[CausalLMModule[ConfigDataclass, KeyValueCache]]):
"""
Causal LM model loaded from the Hugging Face Model Hub.
"""
Expand All @@ -298,7 +299,7 @@ def from_fsspec(
fsspec_args: Optional[Dict[str, Any]] = None,
device: Optional[torch.device] = None,
quantization_config: Optional[BitsAndBytesConfig] = None,
) -> CausalLMModule[KeyValueCache]:
) -> CausalLMModule[ConfigDataclass, KeyValueCache]:
causal_lm = cls._instantiate_model_from_fsspec(
fs, model_path, fsspec_args, device, quantization_config
)
Expand All @@ -313,7 +314,7 @@ def from_hf_hub(
revision: str = "main",
device: Optional[torch.device] = None,
quantization_config: Optional[BitsAndBytesConfig] = None,
) -> CausalLMModule[KeyValueCache]:
) -> CausalLMModule[ConfigDataclass, KeyValueCache]:
causal_lm = cls._instantiate_model_from_hf_hub(
name, revision, device, quantization_config
)
Expand Down
4 changes: 2 additions & 2 deletions curated_transformers/models/bert/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
Self = TypeVar("Self", bound="BERTEncoder")


class BERTEncoder(TransformerEncoder, FromHFHub):
class BERTEncoder(TransformerEncoder[BERTConfig], FromHFHub):
"""
BERT (`Devlin et al., 2018`_) encoder.
Expand All @@ -42,7 +42,7 @@ def __init__(self, config: BERTConfig, *, device: Optional[torch.device] = None)
:returns:
The encoder.
"""
super().__init__()
super().__init__(config)

self.embeddings = TransformerEmbeddings(
dropouts=EmbeddingDropouts(
Expand Down
10 changes: 9 additions & 1 deletion curated_transformers/models/config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
from dataclasses import dataclass
from typing import Optional
from typing import ClassVar, Optional, Protocol

from ..layers.activations import Activation


class ConfigDataclass(Protocol):
"""
Protocol that describes a config data class.
"""

__dataclass_fields__: ClassVar[dict]


@dataclass
class RotaryEmbeddingConfig:
"""
Expand Down
4 changes: 2 additions & 2 deletions curated_transformers/models/falcon/causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
Self = TypeVar("Self", bound="FalconCausalLM")


class FalconCausalLM(TransformerCausalLM, FromHFHub, Quantizable):
class FalconCausalLM(TransformerCausalLM[FalconConfig], FromHFHub, Quantizable):
"""
Falcon (`Penedo et al., 2019`_) causal language model.
Expand All @@ -35,7 +35,7 @@ def __init__(
:returns:
The causal LM.
"""
super().__init__()
super().__init__(config)

self.decoder = FalconDecoder(config, device=device)
self.output_embeddings = Linear(
Expand Down
4 changes: 2 additions & 2 deletions curated_transformers/models/falcon/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
Self = TypeVar("Self", bound="FalconDecoder")


class FalconDecoder(TransformerDecoder, FromHFHub):
class FalconDecoder(TransformerDecoder[FalconConfig], FromHFHub):
"""
Falcon (`Penedo et al., 2019`_) decoder.
Expand All @@ -51,7 +51,7 @@ def __init__(
:returns:
The decoder.
"""
super().__init__()
super().__init__(config)

self.embeddings = TransformerEmbeddings(
dropouts=EmbeddingDropouts(
Expand Down
4 changes: 2 additions & 2 deletions curated_transformers/models/gpt_neox/causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
Self = TypeVar("Self", bound="GPTNeoXCausalLM")


class GPTNeoXCausalLM(TransformerCausalLM, FromHFHub, Quantizable):
class GPTNeoXCausalLM(TransformerCausalLM[GPTNeoXConfig], FromHFHub, Quantizable):
"""
GPT-NeoX (`Black et al., 2022`_) causal language model.
Expand All @@ -35,7 +35,7 @@ def __init__(
:returns:
The causal LM.
"""
super().__init__()
super().__init__(config)

self.decoder = GPTNeoXDecoder(config, device=device)
self.output_embeddings = Linear(
Expand Down
4 changes: 2 additions & 2 deletions curated_transformers/models/gpt_neox/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
Self = TypeVar("Self", bound="GPTNeoXDecoder")


class GPTNeoXDecoder(TransformerDecoder, FromHFHub):
class GPTNeoXDecoder(TransformerDecoder[GPTNeoXConfig], FromHFHub):
"""
GPT-NeoX (`Black et al., 2022`_) decoder.
Expand All @@ -45,7 +45,7 @@ def __init__(
:returns:
The decoder.
"""
super().__init__()
super().__init__(config)

self.embeddings = TransformerEmbeddings(
dropouts=EmbeddingDropouts(
Expand Down
4 changes: 2 additions & 2 deletions curated_transformers/models/llama/causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
Self = TypeVar("Self", bound="LlamaCausalLM")


class LlamaCausalLM(TransformerCausalLM, FromHFHub, Quantizable):
class LlamaCausalLM(TransformerCausalLM[LlamaConfig], FromHFHub, Quantizable):
"""
Llama (`Touvron et al., 2023 [a]`_, `Touvron et al., 2023 [b]`_) causal language model.
Expand All @@ -36,7 +36,7 @@ def __init__(
:returns:
The causal LM.
"""
super().__init__()
super().__init__(config)

self.decoder = LlamaDecoder(config, device=device)
self.output_embeddings = Linear(
Expand Down
4 changes: 2 additions & 2 deletions curated_transformers/models/llama/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
Self = TypeVar("Self", bound="LlamaDecoder")


class LlamaDecoder(TransformerDecoder, FromHFHub):
class LlamaDecoder(TransformerDecoder[LlamaConfig], FromHFHub):
"""
Llama (`Touvron et al., 2023 [a]`_, `Touvron et al., 2023 [b]`_) decoder.
Expand All @@ -47,7 +47,7 @@ def __init__(
:returns:
The decoder.
"""
super().__init__()
super().__init__(config)

self.embeddings = TransformerEmbeddings(
dropouts=EmbeddingDropouts(
Expand Down
40 changes: 36 additions & 4 deletions curated_transformers/models/module.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,44 @@
from abc import abstractmethod
from typing import Generic, List, Optional
from typing import Generic, List, Optional, TypeVar

from torch import Tensor
from torch.nn import Module

from ..layers.attention import AttentionMask
from .config import ConfigDataclass
from .output import CacheT, CausalLMOutputWithCache, ModelOutput, ModelOutputWithCache

ConfigT = TypeVar("ConfigT", bound=ConfigDataclass)

class CausalLMModule(Generic[CacheT], Module):

class TransformerModule(Generic[ConfigT], Module):
"""
Base class for transformer modules.
"""

_config: ConfigT

def __init__(self, config: ConfigT):
super().__init__()

self._config = config

@property
def config(self) -> ConfigT:
"""
Returns the model's configuration.
"""
return self._config


class CausalLMModule(Generic[ConfigT, CacheT], TransformerModule[ConfigT]):
"""
Base class for causal language model modules.
"""

def __init__(self, config: ConfigT):
super().__init__(config)

@abstractmethod
def forward(
self,
Expand Down Expand Up @@ -51,11 +77,14 @@ def forward(
raise NotImplementedError


class DecoderModule(Generic[CacheT], Module):
class DecoderModule(Generic[ConfigT, CacheT], TransformerModule[ConfigT]):
"""
Base class for decoder modules.
"""

def __init__(self, config: ConfigT):
super().__init__(config)

@abstractmethod
def forward(
self,
Expand Down Expand Up @@ -94,11 +123,14 @@ def forward(
raise NotImplementedError


class EncoderModule(Module):
class EncoderModule(Generic[ConfigT], TransformerModule[ConfigT]):
"""
Base class for encoder modules.
"""

def __init__(self, config: ConfigT):
super().__init__(config)

@abstractmethod
def forward(
self,
Expand Down
4 changes: 2 additions & 2 deletions curated_transformers/models/mpt/causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
Self = TypeVar("Self", bound="MPTCausalLM")


class MPTCausalLM(TransformerCausalLM, FromHFHub, Quantizable):
class MPTCausalLM(TransformerCausalLM[MPTConfig], FromHFHub, Quantizable):
"""
`MosaicML MPT`_ causal language model.
Expand All @@ -39,7 +39,7 @@ def __init__(
:returns:
The causal LM.
"""
super().__init__()
super().__init__(config)

self.decoder = MPTDecoder(config, device=device)

Expand Down
Loading

0 comments on commit 01e8902

Please sign in to comment.