diff --git a/curated_transformers/generation/default_generator.py b/curated_transformers/generation/default_generator.py
index 16d6c9b6..358161d3 100644
--- a/curated_transformers/generation/default_generator.py
+++ b/curated_transformers/generation/default_generator.py
@@ -1,10 +1,11 @@
 import dataclasses
-from typing import List, Optional, Type, TypeVar
+from typing import Any, Generic, List, Optional, Type, TypeVar
 
 import torch
 
 from ..models.auto_model import AutoCausalLM
 from ..models.module import CausalLMModule
+from ..models.output import CacheT
 from ..quantization.bnb.config import BitsAndBytesConfig
 from ..tokenizers.auto_tokenizer import AutoTokenizer
 from ..tokenizers.chunks import InputChunks, TextChunk
@@ -19,7 +20,7 @@
 Self = TypeVar("Self", bound="DefaultGenerator")
 
 
-class DefaultGenerator(GeneratorWrapper, FromHFHub):
+class DefaultGenerator(Generic[CacheT], GeneratorWrapper, FromHFHub):
     """
     Generator wrapper for models that do not need specific prompting.
     """
@@ -27,7 +28,7 @@ class DefaultGenerator(GeneratorWrapper, FromHFHub):
     def __init__(
         self,
         tokenizer: TokenizerBase,
-        causal_lm: CausalLMModule,
+        causal_lm: CausalLMModule[Any, CacheT],
         default_config: Optional[GeneratorConfig] = None,
     ):
         """
diff --git a/curated_transformers/generation/generator.py b/curated_transformers/generation/generator.py
index 3d1b2d7b..c7bbd927 100644
--- a/curated_transformers/generation/generator.py
+++ b/curated_transformers/generation/generator.py
@@ -1,4 +1,4 @@
-from typing import Generic, Iterator, List, Optional, Tuple
+from typing import Any, Generic, Iterator, List, Optional, Tuple
 
 import torch
 from torch import Tensor
@@ -17,9 +17,9 @@ class Generator(Generic[CacheT]):
     Generator base class for causal language models.
     """
 
-    model: CausalLMModule[CacheT]
+    model: CausalLMModule[Any, CacheT]
 
-    def __init__(self, model: CausalLMModule[CacheT]):
+    def __init__(self, model: CausalLMModule[Any, CacheT]):
         """
         Construct a generator.
diff --git a/curated_transformers/models/albert/encoder.py b/curated_transformers/models/albert/encoder.py
index 51c801fa..ffaf6489 100644
--- a/curated_transformers/models/albert/encoder.py
+++ b/curated_transformers/models/albert/encoder.py
@@ -21,7 +21,7 @@
 Self = TypeVar("Self", bound="ALBERTEncoder")
 
 
-class ALBERTEncoder(EncoderModule, FromHFHub):
+class ALBERTEncoder(EncoderModule[ALBERTConfig], FromHFHub):
     """
     ALBERT (`Lan et al., 2022`_) encoder.
@@ -39,7 +39,7 @@ def __init__(self, config: ALBERTConfig, *, device: Optional[torch.device] = Non
         :returns:
             The encoder.
         """
-        super().__init__()
+        super().__init__(config)
 
         self.max_seq_len = config.model_max_length
         self.n_hidden_layers = config.layer.n_hidden_layers
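With this change, the cache type is threaded from the causal LM through to its generator: `DefaultGenerator` is now generic over `CacheT`, and `Generator.model` accepts a `CausalLMModule` with any config type. A minimal typing sketch of how the pieces fit together; the `wrap` helper is hypothetical and the `TokenizerBase` import path is assumed:

    from typing import Any

    from curated_transformers.generation.default_generator import DefaultGenerator
    from curated_transformers.layers.cache import KeyValueCache
    from curated_transformers.models.module import CausalLMModule
    from curated_transformers.tokenizers.tokenizer import TokenizerBase

    def wrap(
        tokenizer: TokenizerBase,
        causal_lm: CausalLMModule[Any, KeyValueCache],
    ) -> DefaultGenerator[KeyValueCache]:
        # The generator inherits the cache type of the wrapped causal LM.
        return DefaultGenerator(tokenizer, causal_lm)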
""" @@ -206,7 +207,7 @@ def from_fsspec( fsspec_args: Optional[Dict[str, Any]] = None, device: Optional[torch.device] = None, quantization_config: Optional[BitsAndBytesConfig] = None, - ) -> EncoderModule: + ) -> EncoderModule[ConfigDataclass]: encoder = cls._instantiate_model_from_fsspec( fs, model_path, fsspec_args, device, quantization_config ) @@ -221,7 +222,7 @@ def from_hf_hub( revision: str = "main", device: Optional[torch.device] = None, quantization_config: Optional[BitsAndBytesConfig] = None, - ) -> EncoderModule: + ) -> EncoderModule[ConfigDataclass]: encoder = cls._instantiate_model_from_hf_hub( name, revision, device, quantization_config ) @@ -229,7 +230,7 @@ def from_hf_hub( return encoder -class AutoDecoder(AutoModel[DecoderModule]): +class AutoDecoder(AutoModel[DecoderModule[ConfigDataclass, KeyValueCache]]): """ Decoder module loaded from the Hugging Face Model Hub. """ @@ -252,7 +253,7 @@ def from_fsspec( fsspec_args: Optional[Dict[str, Any]] = None, device: Optional[torch.device] = None, quantization_config: Optional[BitsAndBytesConfig] = None, - ) -> DecoderModule: + ) -> DecoderModule[ConfigDataclass, KeyValueCache]: decoder = cls._instantiate_model_from_fsspec( fs, model_path, fsspec_args, device, quantization_config ) @@ -267,7 +268,7 @@ def from_hf_hub( revision: str = "main", device: Optional[torch.device] = None, quantization_config: Optional[BitsAndBytesConfig] = None, - ) -> DecoderModule: + ) -> DecoderModule[ConfigDataclass, KeyValueCache]: decoder = cls._instantiate_model_from_hf_hub( name, revision, device, quantization_config ) @@ -275,7 +276,7 @@ def from_hf_hub( return decoder -class AutoCausalLM(AutoModel[CausalLMModule[KeyValueCache]]): +class AutoCausalLM(AutoModel[CausalLMModule[ConfigDataclass, KeyValueCache]]): """ Causal LM model loaded from the Hugging Face Model Hub. """ @@ -298,7 +299,7 @@ def from_fsspec( fsspec_args: Optional[Dict[str, Any]] = None, device: Optional[torch.device] = None, quantization_config: Optional[BitsAndBytesConfig] = None, - ) -> CausalLMModule[KeyValueCache]: + ) -> CausalLMModule[ConfigDataclass, KeyValueCache]: causal_lm = cls._instantiate_model_from_fsspec( fs, model_path, fsspec_args, device, quantization_config ) @@ -313,7 +314,7 @@ def from_hf_hub( revision: str = "main", device: Optional[torch.device] = None, quantization_config: Optional[BitsAndBytesConfig] = None, - ) -> CausalLMModule[KeyValueCache]: + ) -> CausalLMModule[ConfigDataclass, KeyValueCache]: causal_lm = cls._instantiate_model_from_hf_hub( name, revision, device, quantization_config ) diff --git a/curated_transformers/models/bert/encoder.py b/curated_transformers/models/bert/encoder.py index c8048f93..0dc55d1b 100644 --- a/curated_transformers/models/bert/encoder.py +++ b/curated_transformers/models/bert/encoder.py @@ -24,7 +24,7 @@ Self = TypeVar("Self", bound="BERTEncoder") -class BERTEncoder(TransformerEncoder, FromHFHub): +class BERTEncoder(TransformerEncoder[BERTConfig], FromHFHub): """ BERT (`Devlin et al., 2018`_) encoder. @@ -42,7 +42,7 @@ def __init__(self, config: BERTConfig, *, device: Optional[torch.device] = None) :returns: The encoder. 
""" - super().__init__() + super().__init__(config) self.embeddings = TransformerEmbeddings( dropouts=EmbeddingDropouts( diff --git a/curated_transformers/models/config.py b/curated_transformers/models/config.py index 4990e19e..189ab254 100644 --- a/curated_transformers/models/config.py +++ b/curated_transformers/models/config.py @@ -1,9 +1,17 @@ from dataclasses import dataclass -from typing import Optional +from typing import ClassVar, Optional, Protocol from ..layers.activations import Activation +class ConfigDataclass(Protocol): + """ + Protocol that describes a config data class. + """ + + __dataclass_fields__: ClassVar[dict] + + @dataclass class RotaryEmbeddingConfig: """ diff --git a/curated_transformers/models/falcon/causal_lm.py b/curated_transformers/models/falcon/causal_lm.py index 34228f06..ae570d48 100644 --- a/curated_transformers/models/falcon/causal_lm.py +++ b/curated_transformers/models/falcon/causal_lm.py @@ -15,7 +15,7 @@ Self = TypeVar("Self", bound="FalconCausalLM") -class FalconCausalLM(TransformerCausalLM, FromHFHub, Quantizable): +class FalconCausalLM(TransformerCausalLM[FalconConfig], FromHFHub, Quantizable): """ Falcon (`Penedo et al., 2019`_) causal language model. @@ -35,7 +35,7 @@ def __init__( :returns: The causal LM. """ - super().__init__() + super().__init__(config) self.decoder = FalconDecoder(config, device=device) self.output_embeddings = Linear( diff --git a/curated_transformers/models/falcon/decoder.py b/curated_transformers/models/falcon/decoder.py index cec6f5b9..0a16154e 100644 --- a/curated_transformers/models/falcon/decoder.py +++ b/curated_transformers/models/falcon/decoder.py @@ -31,7 +31,7 @@ Self = TypeVar("Self", bound="FalconDecoder") -class FalconDecoder(TransformerDecoder, FromHFHub): +class FalconDecoder(TransformerDecoder[FalconConfig], FromHFHub): """ Falcon (`Penedo et al., 2019`_) decoder. @@ -51,7 +51,7 @@ def __init__( :returns: The decoder. """ - super().__init__() + super().__init__(config) self.embeddings = TransformerEmbeddings( dropouts=EmbeddingDropouts( diff --git a/curated_transformers/models/gpt_neox/causal_lm.py b/curated_transformers/models/gpt_neox/causal_lm.py index 7a1b9c2f..cc6f9bb7 100644 --- a/curated_transformers/models/gpt_neox/causal_lm.py +++ b/curated_transformers/models/gpt_neox/causal_lm.py @@ -15,7 +15,7 @@ Self = TypeVar("Self", bound="GPTNeoXCausalLM") -class GPTNeoXCausalLM(TransformerCausalLM, FromHFHub, Quantizable): +class GPTNeoXCausalLM(TransformerCausalLM[GPTNeoXConfig], FromHFHub, Quantizable): """ GPT-NeoX (`Black et al., 2022`_) causal language model. @@ -35,7 +35,7 @@ def __init__( :returns: The causal LM. """ - super().__init__() + super().__init__(config) self.decoder = GPTNeoXDecoder(config, device=device) self.output_embeddings = Linear( diff --git a/curated_transformers/models/gpt_neox/decoder.py b/curated_transformers/models/gpt_neox/decoder.py index 39c06eea..d04189ca 100644 --- a/curated_transformers/models/gpt_neox/decoder.py +++ b/curated_transformers/models/gpt_neox/decoder.py @@ -25,7 +25,7 @@ Self = TypeVar("Self", bound="GPTNeoXDecoder") -class GPTNeoXDecoder(TransformerDecoder, FromHFHub): +class GPTNeoXDecoder(TransformerDecoder[GPTNeoXConfig], FromHFHub): """ GPT-NeoX (`Black et al., 2022`_) decoder. @@ -45,7 +45,7 @@ def __init__( :returns: The decoder. 
""" - super().__init__() + super().__init__(config) self.embeddings = TransformerEmbeddings( dropouts=EmbeddingDropouts( diff --git a/curated_transformers/models/llama/causal_lm.py b/curated_transformers/models/llama/causal_lm.py index 58e971e1..135178cc 100644 --- a/curated_transformers/models/llama/causal_lm.py +++ b/curated_transformers/models/llama/causal_lm.py @@ -15,7 +15,7 @@ Self = TypeVar("Self", bound="LlamaCausalLM") -class LlamaCausalLM(TransformerCausalLM, FromHFHub, Quantizable): +class LlamaCausalLM(TransformerCausalLM[LlamaConfig], FromHFHub, Quantizable): """ Llama (`Touvron et al., 2023 [a]`_, `Touvron et al., 2023 [b]`_) causal language model. @@ -36,7 +36,7 @@ def __init__( :returns: The causal LM. """ - super().__init__() + super().__init__(config) self.decoder = LlamaDecoder(config, device=device) self.output_embeddings = Linear( diff --git a/curated_transformers/models/llama/decoder.py b/curated_transformers/models/llama/decoder.py index d26c9950..072a2cc3 100644 --- a/curated_transformers/models/llama/decoder.py +++ b/curated_transformers/models/llama/decoder.py @@ -26,7 +26,7 @@ Self = TypeVar("Self", bound="LlamaDecoder") -class LlamaDecoder(TransformerDecoder, FromHFHub): +class LlamaDecoder(TransformerDecoder[LlamaConfig], FromHFHub): """ Llama (`Touvron et al., 2023 [a]`_, `Touvron et al., 2023 [b]`_) decoder. @@ -47,7 +47,7 @@ def __init__( :returns: The decoder. """ - super().__init__() + super().__init__(config) self.embeddings = TransformerEmbeddings( dropouts=EmbeddingDropouts( diff --git a/curated_transformers/models/module.py b/curated_transformers/models/module.py index eaaa2bb9..b5bcf91d 100644 --- a/curated_transformers/models/module.py +++ b/curated_transformers/models/module.py @@ -1,18 +1,44 @@ from abc import abstractmethod -from typing import Generic, List, Optional +from typing import Generic, List, Optional, TypeVar from torch import Tensor from torch.nn import Module from ..layers.attention import AttentionMask +from .config import ConfigDataclass from .output import CacheT, CausalLMOutputWithCache, ModelOutput, ModelOutputWithCache +ConfigT = TypeVar("ConfigT", bound=ConfigDataclass) -class CausalLMModule(Generic[CacheT], Module): + +class TransformerModule(Generic[ConfigT], Module): + """ + Base class for transformer modules. + """ + + _config: ConfigT + + def __init__(self, config: ConfigT): + super().__init__() + + self._config = config + + @property + def config(self) -> ConfigT: + """ + Returns the model's configuration. + """ + return self._config + + +class CausalLMModule(Generic[ConfigT, CacheT], TransformerModule[ConfigT]): """ Base class for causal language model modules. """ + def __init__(self, config: ConfigT): + super().__init__(config) + @abstractmethod def forward( self, @@ -51,11 +77,14 @@ def forward( raise NotImplementedError -class DecoderModule(Generic[CacheT], Module): +class DecoderModule(Generic[ConfigT, CacheT], TransformerModule[ConfigT]): """ Base class for decoder modules. """ + def __init__(self, config: ConfigT): + super().__init__(config) + @abstractmethod def forward( self, @@ -94,11 +123,14 @@ def forward( raise NotImplementedError -class EncoderModule(Module): +class EncoderModule(Generic[ConfigT], TransformerModule[ConfigT]): """ Base class for encoder modules. 
""" + def __init__(self, config: ConfigT): + super().__init__(config) + @abstractmethod def forward( self, diff --git a/curated_transformers/models/mpt/causal_lm.py b/curated_transformers/models/mpt/causal_lm.py index 8e850420..0006c064 100644 --- a/curated_transformers/models/mpt/causal_lm.py +++ b/curated_transformers/models/mpt/causal_lm.py @@ -19,7 +19,7 @@ Self = TypeVar("Self", bound="MPTCausalLM") -class MPTCausalLM(TransformerCausalLM, FromHFHub, Quantizable): +class MPTCausalLM(TransformerCausalLM[MPTConfig], FromHFHub, Quantizable): """ `MosaicML MPT`_ causal language model. @@ -39,7 +39,7 @@ def __init__( :returns: The causal LM. """ - super().__init__() + super().__init__(config) self.decoder = MPTDecoder(config, device=device) diff --git a/curated_transformers/models/mpt/decoder.py b/curated_transformers/models/mpt/decoder.py index 5a535065..1207dc5c 100644 --- a/curated_transformers/models/mpt/decoder.py +++ b/curated_transformers/models/mpt/decoder.py @@ -29,7 +29,7 @@ Self = TypeVar("Self", bound="MPTDecoder") -class MPTDecoder(TransformerDecoder, FromHFHub): +class MPTDecoder(TransformerDecoder[MPTConfig], FromHFHub): """ `MosaicML MPT`_ decoder. @@ -49,7 +49,7 @@ def __init__( :returns: The decoder. """ - super().__init__() + super().__init__(config) self.embeddings = TransformerEmbeddings( dropouts=EmbeddingDropouts( diff --git a/curated_transformers/models/roberta/encoder.py b/curated_transformers/models/roberta/encoder.py index dad33a59..11b06c1d 100644 --- a/curated_transformers/models/roberta/encoder.py +++ b/curated_transformers/models/roberta/encoder.py @@ -24,7 +24,7 @@ Self = TypeVar("Self", bound="RoBERTaEncoder") -class RoBERTaEncoder(TransformerEncoder, FromHFHub): +class RoBERTaEncoder(TransformerEncoder[RoBERTaConfig], FromHFHub): """ RoBERTa (`Liu et al., 2019`_) encoder. @@ -42,7 +42,7 @@ def __init__(self, config: RoBERTaConfig, *, device: Optional[torch.device] = No :returns: The encoder. """ - super().__init__() + super().__init__(config) self.embeddings = RoBERTaEmbeddings( dropouts=EmbeddingDropouts( diff --git a/curated_transformers/models/transformer.py b/curated_transformers/models/transformer.py index 44fcf1ad..54d635fa 100644 --- a/curated_transformers/models/transformer.py +++ b/curated_transformers/models/transformer.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Generic, List, Optional import torch from torch import Tensor @@ -6,11 +6,11 @@ from ..layers.attention import AttentionMask from ..layers.cache import KeyValueCache -from .module import CausalLMModule, DecoderModule, EncoderModule +from .module import CausalLMModule, ConfigT, DecoderModule, EncoderModule from .output import CausalLMOutputWithCache, ModelOutput, ModelOutputWithCache -class TransformerDecoder(DecoderModule): +class TransformerDecoder(Generic[ConfigT], DecoderModule[ConfigT, KeyValueCache]): """ Transformer decoder (`Vaswani et al., 2017`_) base class. @@ -64,7 +64,7 @@ def forward( ) -class TransformerCausalLM(CausalLMModule[KeyValueCache]): +class TransformerCausalLM(Generic[ConfigT], CausalLMModule[ConfigT, KeyValueCache]): """ Transformer causal LM (`Vaswani et al., 2017`_) base class. @@ -100,7 +100,7 @@ def forward( ) -class TransformerEncoder(EncoderModule): +class TransformerEncoder(Generic[ConfigT], EncoderModule[ConfigT]): """ Transformer encoder (`Vaswani et al., 2017`_) base class. 
diff --git a/curated_transformers/tests/models/util.py b/curated_transformers/tests/models/util.py
index aaf4adbd..d9a12926 100644
--- a/curated_transformers/tests/models/util.py
+++ b/curated_transformers/tests/models/util.py
@@ -13,6 +13,7 @@
     CausalLMModule,
     DecoderModule,
     EncoderModule,
+    TransformerModule,
 )
 from curated_transformers.models.output import ModelOutput, ModelOutputWithCache
@@ -147,6 +148,7 @@ def assert_decoder_output_equals_hf(
     with_mask: bool = True,
     jit_method: JITMethod = JITMethod.Disable,
     with_torch_sdp=False,
+    check_config=True,
 ):
     orig_model = model_class.from_hf_hub(
         name=model_name, revision=model_revision, device=torch_device
     )
@@ -193,6 +195,9 @@
         orig_model, hf_model, torch_device, atol, rtol, jit_method
     )
 
+    if check_config and jit_method == JITMethod.Disable:
+        assert_model_config(model, Y)
+
 
 def assert_encoder_output_equals_hf(
     model_class: Type[FromHFHub],
@@ -204,6 +209,7 @@
     jit_method: JITMethod = JITMethod.Disable,
     with_fsspec: bool = False,
     with_torch_sdp: bool = False,
+    check_config=True,
 ):
     if with_fsspec:
         orig_model = model_class.from_fsspec(
@@ -240,6 +246,9 @@
         orig_model, hf_model, torch_device, atol, rtol, jit_method
     )
 
+    if check_config and jit_method == JITMethod.Disable:
+        assert_model_config(model, Y)
+
 
 def assert_decoder_with_cache_output_equals_hf(
     orig_model: DecoderModule,
@@ -347,3 +356,11 @@ def assert_decoder_with_positions_equals_hf(
     Y_hf = hf_model(X, position_ids=positions).last_hidden_state
 
     torch_assertclose(Y, Y_hf, atol=atol, rtol=rtol)
+
+
+def assert_model_config(model: TransformerModule, model_output: Tensor):
+    assert isinstance(model, TransformerModule)
+    config = model.config
+
+    hidden_width = model_output.size(-1)
+    assert config.layer.feedforward.hidden_width == hidden_width
diff --git a/docs/source/causal-lm.rst b/docs/source/causal-lm.rst
index def1f1de..f9507669 100644
--- a/docs/source/causal-lm.rst
+++ b/docs/source/causal-lm.rst
@@ -7,10 +7,12 @@ Base Classes
 .. autoclass:: curated_transformers.models.CausalLMModule
    :members:
    :show-inheritance:
+   :inherited-members: Module
 
 .. autoclass:: curated_transformers.models.TransformerCausalLM
    :members:
    :show-inheritance:
+   :inherited-members: Module
 
 Architectures
 -------------
diff --git a/docs/source/decoders.rst b/docs/source/decoders.rst
index 49019545..fdfd2a5e 100644
--- a/docs/source/decoders.rst
+++ b/docs/source/decoders.rst
@@ -7,10 +7,12 @@ Base Classes
 .. autoclass:: curated_transformers.models.DecoderModule
    :members:
    :show-inheritance:
+   :inherited-members: Module
 
 .. autoclass:: curated_transformers.models.TransformerDecoder
    :members:
    :show-inheritance:
+   :inherited-members: Module
 
 Architectures
 -------------
diff --git a/docs/source/encoders.rst b/docs/source/encoders.rst
index ea9860bb..31b305c4 100644
--- a/docs/source/encoders.rst
+++ b/docs/source/encoders.rst
@@ -7,10 +7,12 @@ Base Classes
 .. autoclass:: curated_transformers.models.EncoderModule
    :members:
    :show-inheritance:
+   :inherited-members: Module
 
 .. autoclass:: curated_transformers.models.TransformerEncoder
    :members:
    :show-inheritance:
+   :inherited-members: Module
 
 Architectures
 -------------
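The new `assert_model_config` helper cross-checks a loaded model's config against the width of its output; the sketch below mirrors that check with hypothetical toy dataclasses. On the docs side, `:inherited-members: Module` should make Sphinx list members inherited from the new bases (notably `config`) while skipping the large `torch.nn.Module` API.

    from dataclasses import dataclass, field

    import torch

    @dataclass
    class ToyFeedForward:
        hidden_width: int = 64

    @dataclass
    class ToyLayer:
        feedforward: ToyFeedForward = field(default_factory=ToyFeedForward)

    @dataclass
    class ToyConfig:
        layer: ToyLayer = field(default_factory=ToyLayer)

    # The trailing dimension of the model output must match the configured
    # hidden width.
    config = ToyConfig()
    Y = torch.zeros(2, 10, 64)  # (batch, seq, hidden)
    assert config.layer.feedforward.hidden_width == Y.size(-1)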