diff --git a/src/transformers/adapters/__init__.py b/src/transformers/adapters/__init__.py
index fea1f63968..dd96b8bb32 100644
--- a/src/transformers/adapters/__init__.py
+++ b/src/transformers/adapters/__init__.py
@@ -114,6 +114,10 @@
         "GPT2ModelWithHeads",
     ],
     "models.gptj": ["GPTJAdapterModel"],
+    "models.hubert": [
+        "HubertAdapterModel",
+        "HubertModelWithHeads",
+    ],
     "models.mbart": [
         "MBartAdapterModel",
         "MBartModelWithHeads",
@@ -218,6 +222,7 @@
     from .models.distilbert import DistilBertAdapterModel, DistilBertModelWithHeads
     from .models.gpt2 import GPT2AdapterModel, GPT2ModelWithHeads
     from .models.gptj import GPTJAdapterModel
+    from .models.hubert import HubertAdapterModel, HubertModelWithHeads
     from .models.mbart import MBartAdapterModel, MBartModelWithHeads
     from .models.roberta import RobertaAdapterModel, RobertaModelWithHeads
     from .models.t5 import T5AdapterModel, T5ModelWithHeads
diff --git a/src/transformers/adapters/mixins/hubert.py b/src/transformers/adapters/mixins/hubert.py
new file mode 100644
index 0000000000..b294b7cbf6
--- /dev/null
+++ b/src/transformers/adapters/mixins/hubert.py
@@ -0,0 +1,38 @@
+from typing import Iterable, Tuple
+
+import torch.nn as nn
+
+from ..layer import AdapterLayer
+from ..model_mixin import ModelAdaptersMixin, ModelWithHeadsAdaptersMixin
+
+
+class HubertEncoderLayerAdaptersMixin:
+    """Adds adapters to the EncoderLayer module of Hubert."""
+
+    def _init_adapter_modules(self):
+        self.attention_adapters = AdapterLayer("mh_adapter", self.config)
+        self.output_adapters = AdapterLayer("output_adapter", self.config)
+        self.attention_adapters._init_adapter_modules()
+        self.output_adapters._init_adapter_modules()
+
+
+class HubertEncoderLayerStableLayerNormAdaptersMixin:
+    """Adds adapters to the EncoderLayerStableLayerNorm module of Hubert."""
+
+    def _init_adapter_modules(self):
+        self.attention_adapters = AdapterLayer("mh_adapter", self.config)
+        self.output_adapters = AdapterLayer("output_adapter", self.config)
+        self.attention_adapters._init_adapter_modules()
+        self.output_adapters._init_adapter_modules()
+
+
+class HubertModelAdaptersMixin(ModelAdaptersMixin):
+    """Adds adapters to the HubertModel module."""
+
+    def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
+        for i, layer in enumerate(self.encoder.layers):
+            yield i, layer
+
+
+class HubertModelWithHeadsAdaptersMixin(ModelWithHeadsAdaptersMixin):
+    pass
diff --git a/src/transformers/adapters/models/auto/adapter_model.py b/src/transformers/adapters/models/auto/adapter_model.py
index cfd159bad6..e822598eb6 100644
--- a/src/transformers/adapters/models/auto/adapter_model.py
+++ b/src/transformers/adapters/models/auto/adapter_model.py
@@ -21,6 +21,7 @@
         ("mbart", "MBartAdapterModel"),
         ("gpt2", "GPT2AdapterModel"),
         ("gptj", "GPTJAdapterModel"),
+        ("hubert", "HubertAdapterModel"),
         ("t5", "T5AdapterModel"),
         ("vit", "ViTAdapterModel"),
     ]
@@ -34,6 +35,7 @@
         ("bart", "BartModelWithHeads"),
         ("mbart", "MBartModelWithHeads"),
         ("gpt2", "GPT2ModelWithHeads"),
+        ("hubert", "HubertModelWithHeads"),
         ("t5", "T5ModelWithHeads"),
     ]
 )
diff --git a/src/transformers/adapters/models/hubert/__init__.py b/src/transformers/adapters/models/hubert/__init__.py
new file mode 100644
index 0000000000..75f12b7f01
--- /dev/null
+++ b/src/transformers/adapters/models/hubert/__init__.py
@@ -0,0 +1,42 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2020 The Adapter-Hub Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from ....utils import _LazyModule
+
+
+_import_structure = {
+    "adapter_model": [
+        "HubertAdapterModel",
+        "HubertModelWithHeads",
+    ],
+}
+
+
+if TYPE_CHECKING:
+    from .adapter_model import HubertAdapterModel, HubertModelWithHeads
+
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+    )
diff --git a/src/transformers/adapters/models/hubert/adapter_model.py b/src/transformers/adapters/models/hubert/adapter_model.py
new file mode 100644
index 0000000000..32d0605cff
--- /dev/null
+++ b/src/transformers/adapters/models/hubert/adapter_model.py
@@ -0,0 +1,208 @@
+import warnings
+
+import torch.nn as nn
+
+from ....models.hubert.modeling_hubert import (
+    HUBERT_INPUTS_DOCSTRING,
+    HUBERT_START_DOCSTRING,
+    HubertModel,
+    HubertPreTrainedModel,
+)
+from ....utils import add_start_docstrings, add_start_docstrings_to_model_forward
+from ...heads import (
+    ClassificationHead,
+    ModelWithFlexibleHeadsAdaptersMixin,
+    MultiLabelClassificationHead,
+    MultipleChoiceHead,
+)
+
+
+@add_start_docstrings(
+    """Hubert Model with the option to add multiple flexible heads on top.""",
+    HUBERT_START_DOCSTRING,
+)
+class HubertAdapterModel(ModelWithFlexibleHeadsAdaptersMixin, HubertPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.hubert = HubertModel(config)
+
+        self._init_head_modules()
+
+        self.init_weights()
+
+    def get_position_embeddings(self) -> nn.Embedding:
+        """
+        Returns the position embeddings
+        """
+        return self.hubert.get_position_embeddings()
+
+    def resize_position_embeddings(self, new_num_position_embeddings: int):
+        """
+        Resizes position embeddings of the model if :obj:`new_num_position_embeddings !=
+        config.max_position_embeddings`.
+
+        Arguments:
+            new_num_position_embeddings (:obj:`int`):
+                The number of new position embedding matrix. If position embeddings are learned, increasing the size
+                will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the
+                end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the
+                size will add correct vectors at the end following the position encoding algorithm, whereas reducing
+                the size will remove vectors from the end.
+ """ + self.hubert.resize_position_embeddings(new_num_position_embeddings) + + @add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING.format("batch_size, num_choices")) + def forward( + self, + input_ids=None, + attention_mask=None, + head_mask=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + head=None, + output_adapter_gating_scores=False, + output_adapter_fusion_attentions=False, + **kwargs + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + hubert_output = self.hubert( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + output_adapter_gating_scores=output_adapter_gating_scores, + output_adapter_fusion_attentions=output_adapter_fusion_attentions, + adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + ) + + outputs = self.forward_head( + hubert_output, head_name=head, attention_mask=attention_mask, return_dict=return_dict, **kwargs + ) + + return outputs + + # Copied from RobertaForCausalLM + def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "past_key_values": past, + "adapter_input_parallelized": model_kwargs.pop("adapter_input_parallelized", False), + } + + head_types = { + "classification": ClassificationHead, + "multilabel_classification": MultiLabelClassificationHead, + "multiple_choice": MultipleChoiceHead, + } + + def add_classification_head( + self, + head_name, + num_labels=2, + layers=2, + activation_function="tanh", + overwrite_ok=False, + multilabel=False, + id2label=None, + use_pooler=False, + ): + """ + Adds a sequence classification head on top of the model. + + Args: + head_name (str): The name of the head. + num_labels (int, optional): Number of classification labels. Defaults to 2. + layers (int, optional): Number of layers. Defaults to 2. + activation_function (str, optional): Activation function. Defaults to 'tanh'. + overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False. + multilabel (bool, optional): Enable multilabel classification setup. Defaults to False. 
+ """ + + if multilabel: + head = MultiLabelClassificationHead( + self, head_name, num_labels, layers, activation_function, id2label, use_pooler + ) + else: + head = ClassificationHead(self, head_name, num_labels, layers, activation_function, id2label, use_pooler) + self.add_prediction_head(head, overwrite_ok) + + def add_multiple_choice_head( + self, + head_name, + num_choices=2, + layers=2, + activation_function="tanh", + overwrite_ok=False, + id2label=None, + use_pooler=False, + ): + """ + Adds a multiple choice head on top of the model. + + Args: + head_name (str): The name of the head. + num_choices (int, optional): Number of choices. Defaults to 2. + layers (int, optional): Number of layers. Defaults to 2. + activation_function (str, optional): Activation function. Defaults to 'tanh'. + overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False. + """ + head = MultipleChoiceHead(self, head_name, num_choices, layers, activation_function, id2label, use_pooler) + self.add_prediction_head(head, overwrite_ok) + + +class HubertModelWithHeads(HubertAdapterModel): + def __init__(self, *args, **kwargs): + warnings.warn( + "This class has been renamed to `{}` in v3. " + "Please use the new class instead as this class might be removed in a future version.".format( + self.__class__.__bases__[0].__name__ + ), + FutureWarning, + ) + super().__init__(*args, **kwargs) + + @classmethod + def from_config(cls, config): + warnings.warn( + "This class has been renamed to `{}` in v3. " + "Please use the new class instead as this class might be removed in a future version.".format( + cls.__bases__[0].__name__ + ), + FutureWarning, + ) + return super().from_config(config) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + warnings.warn( + "This class has been renamed to `{}` in v3. 
" + "Please use the new class instead as this class might be removed in a future version.".format( + cls.__bases__[0].__name__ + ), + FutureWarning, + ) + return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) diff --git a/src/transformers/adapters/wrappers/configuration.py b/src/transformers/adapters/wrappers/configuration.py index 3506d93f70..3702b2b18d 100644 --- a/src/transformers/adapters/wrappers/configuration.py +++ b/src/transformers/adapters/wrappers/configuration.py @@ -39,6 +39,7 @@ "hidden_dropout_prob": "resid_pdrop", "attention_probs_dropout_prob": "attn_pdrop", }, + "hubert": {}, "mbart": { "num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model", diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index a96ef5cf5d..276b8b0c65 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -26,6 +26,16 @@ from transformers.deepspeed import is_deepspeed_zero3_enabled from ...activations import ACT2FN +from ...adapters.composition import adjust_tensors_for_parallel +from ...adapters.context import ForwardContext +from ...adapters.lora import Linear as LoRALinear +from ...adapters.mixins.hubert import ( + HubertEncoderLayerAdaptersMixin, + HubertEncoderLayerStableLayerNormAdaptersMixin, + HubertModelAdaptersMixin, + HubertModelWithHeadsAdaptersMixin, +) +from ...adapters.prefix_tuning import PrefixTuningShim from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import torch_int_div @@ -396,6 +406,7 @@ class HubertAttention(nn.Module): def __init__( self, + config: HubertConfig, embed_dim: int, num_heads: int, dropout: float = 0.0, @@ -416,11 +427,13 @@ def __init__( self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder - self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.k_proj = LoRALinear(embed_dim, embed_dim, "selfattn", config, attn_key="k", bias=bias) + self.v_proj = LoRALinear(embed_dim, embed_dim, "selfattn", config, attn_key="v", bias=bias) + self.q_proj = LoRALinear(embed_dim, embed_dim, "selfattn", config, attn_key="q", bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.prefix_tuning = PrefixTuningShim(None, config) + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -480,6 +493,12 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) + key_states, value_states, attention_mask = self.prefix_tuning( + key_states, value_states, hidden_states, attention_mask + ) + (query_states,) = adjust_tensors_for_parallel(key_states, query_states) + bsz = query_states.size(0) + proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) key_states = key_states.view(*proj_shape) @@ -547,17 +566,17 @@ def forward( # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->Hubert class HubertFeedForward(nn.Module): - def __init__(self, config): + def __init__(self, config: HubertConfig): super().__init__() self.intermediate_dropout = 
 
-        self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
+        self.intermediate_dense = LoRALinear(config.hidden_size, config.intermediate_size, "intermediate", config)
         if isinstance(config.hidden_act, str):
             self.intermediate_act_fn = ACT2FN[config.hidden_act]
         else:
             self.intermediate_act_fn = config.hidden_act
 
-        self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
+        self.output_dense = LoRALinear(config.intermediate_size, config.hidden_size, "output", config)
         self.output_dropout = nn.Dropout(config.hidden_dropout)
 
     def forward(self, hidden_states):
@@ -571,45 +590,51 @@
 
 
 # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayer with Wav2Vec2->Hubert
-class HubertEncoderLayer(nn.Module):
-    def __init__(self, config):
+class HubertEncoderLayer(HubertEncoderLayerAdaptersMixin, nn.Module):
+    def __init__(self, config: HubertConfig):
         super().__init__()
         self.attention = HubertAttention(
+            config,
             embed_dim=config.hidden_size,
             num_heads=config.num_attention_heads,
             dropout=config.attention_dropout,
             is_decoder=False,
         )
         self.dropout = nn.Dropout(config.hidden_dropout)
-        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.sa_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.feed_forward = HubertFeedForward(config)
         self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
+        self._init_adapter_modules()
+
     def forward(self, hidden_states, attention_mask=None, output_attentions=False):
         attn_residual = hidden_states
-        hidden_states, attn_weights, _ = self.attention(
+        sa_output, sa_weights, _ = self.attention(
             hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
         )
-        hidden_states = self.dropout(hidden_states)
-        hidden_states = attn_residual + hidden_states
+        sa_output = self.dropout(sa_output)
+        sa_output = self.attention_adapters(sa_output, attn_residual, self.sa_layer_norm)  # (bs, seq_length, dim)
 
-        hidden_states = self.layer_norm(hidden_states)
-        hidden_states = hidden_states + self.feed_forward(hidden_states)
-        hidden_states = self.final_layer_norm(hidden_states)
+        # Feed Forward Network
+        ffn_output = self.feed_forward(sa_output)  # (bs, seq_length, dim)
+        ffn_output: torch.Tensor = self.output_adapters(
+            ffn_output, sa_output, self.final_layer_norm
+        )  # (bs, seq_length, dim)
 
-        outputs = (hidden_states,)
+        outputs = (ffn_output,)
 
         if output_attentions:
-            outputs += (attn_weights,)
+            outputs += (sa_weights,)
 
         return outputs
 
 
 # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayerStableLayerNorm with Wav2Vec2->Hubert
-class HubertEncoderLayerStableLayerNorm(nn.Module):
-    def __init__(self, config):
+class HubertEncoderLayerStableLayerNorm(HubertEncoderLayerStableLayerNormAdaptersMixin, nn.Module):
+    def __init__(self, config: HubertConfig):
         super().__init__()
         self.attention = HubertAttention(
+            config,
             embed_dim=config.hidden_size,
             num_heads=config.num_attention_heads,
             dropout=config.attention_dropout,
@@ -620,6 +645,8 @@ def __init__(self, config):
         self.feed_forward = HubertFeedForward(config)
         self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
+        self._init_adapter_modules()
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -628,24 +655,29 @@ def forward(
     ):
         attn_residual = hidden_states
         hidden_states = self.layer_norm(hidden_states)
-        hidden_states, attn_weights, _ = self.attention(
+        sa_output, sa_weights, _ = self.attention(
             hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
         )
-        hidden_states = self.dropout(hidden_states)
-        hidden_states = attn_residual + hidden_states
-        hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
+        sa_output = self.dropout(sa_output)
+        sa_output = self.attention_adapters(sa_output, attn_residual, None)  # (bs, seq_length, dim)
+
+        # Feed Forward Network
+        ffn_output = self.feed_forward(sa_output)  # (bs, seq_length, dim)
+        ffn_output: torch.Tensor = self.output_adapters(
+            ffn_output, sa_output, self.final_layer_norm
+        )  # (bs, seq_length, dim)
 
-        outputs = (hidden_states,)
+        outputs = (ffn_output,)
 
         if output_attentions:
-            outputs += (attn_weights,)
+            outputs += (sa_weights,)
 
         return outputs
 
 
 # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Encoder with Wav2Vec2->Hubert
 class HubertEncoder(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: HubertConfig):
         super().__init__()
         self.config = config
         self.pos_conv_embed = HubertPositionalConvEmbedding(config)
@@ -733,7 +765,7 @@ def custom_forward(*inputs):
 
 # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderStableLayerNorm with Wav2Vec2->Hubert
 class HubertEncoderStableLayerNorm(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: HubertConfig):
         super().__init__()
         self.config = config
         self.pos_conv_embed = HubertPositionalConvEmbedding(config)
@@ -953,7 +985,7 @@ def _get_feature_vector_attention_mask(self, feature_vector_length: int, attenti
     "The bare Hubert Model transformer outputting raw hidden-states without any specific head on top.",
     HUBERT_START_DOCSTRING,
 )
-class HubertModel(HubertPreTrainedModel):
+class HubertModel(HubertModelAdaptersMixin, HubertPreTrainedModel):
     def __init__(self, config: HubertConfig):
         super().__init__(config)
         self.config = config
@@ -968,6 +1000,8 @@ def __init__(self, config: HubertConfig):
         else:
             self.encoder = HubertEncoder(config)
 
+        self._init_adapter_modules()
+
         # Initialize weights and apply final processing
         self.post_init()
 
@@ -1020,6 +1054,7 @@ def _mask_hidden_states(
 
     @add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
+    @ForwardContext.wrap
     def forward(
         self,
         input_values: Optional[torch.Tensor],
@@ -1097,7 +1132,7 @@ def forward(
     HUBERT_START_DOCSTRING,
 )
 # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->Hubert, wav2vec2->hubert, WAV_2_VEC_2->HUBERT
-class HubertForCTC(HubertPreTrainedModel):
+class HubertForCTC(HubertModelWithHeadsAdaptersMixin, HubertPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
 
@@ -1227,7 +1262,7 @@ def forward(
     HUBERT_START_DOCSTRING,
 )
 # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification with Wav2Vec2->Hubert, wav2vec2->hubert, WAV_2_VEC_2->HUBERT
-class HubertForSequenceClassification(HubertPreTrainedModel):
+class HubertForSequenceClassification(HubertModelWithHeadsAdaptersMixin, HubertPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
 
diff --git a/utils/check_adapters.py b/utils/check_adapters.py
index fe8c902c45..92eb8ff315 100644
--- a/utils/check_adapters.py
+++ b/utils/check_adapters.py
@@ -20,6 +20,7 @@
     "deberta_v2",
     "vit",
     "clip",
+    "hubert",
 ]
 
 IGNORE_NOT_IMPLEMENTING_MIXIN = [
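
For reference, a minimal usage sketch of the class this patch introduces, assuming the standard adapter-transformers v3 workflow (add_adapter / add_classification_head / train_adapter); the checkpoint, adapter, and head names below are illustrative and not part of the patch:

from transformers.adapters import HubertAdapterModel

# Illustrative checkpoint name; any Hubert checkpoint is loaded the same way.
model = HubertAdapterModel.from_pretrained("facebook/hubert-base-ls960")

# Add a bottleneck adapter plus the classification head defined in adapter_model.py,
# then freeze the base model and activate the adapter for training.
model.add_adapter("keyword_spotting")
model.add_classification_head("keyword_spotting", num_labels=12)
model.train_adapter("keyword_spotting")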