diff --git a/configs/xlmr_stratify_0.1_3layers_p_v2_look10.json b/configs/xlmr_stratify_0.1_3layers_p_v2_look10.json
new file mode 100644
index 00000000..72ae58cb
--- /dev/null
+++ b/configs/xlmr_stratify_0.1_3layers_p_v2_look10.json
@@ -0,0 +1,43 @@
+{
+    "model_name_or_path": "xlm-roberta-base",
+    "output_dir": "xlmr-normal-v2-look10",
+    "train_text_path": "data/sentence/train.parquet",
+    "valid_text_path": "data/sentence/valid.parquet",
+    "block_size": 512,
+    "use_bert": true,
+    "do_train": true,
+    "do_eval": true,
+    "evaluation_strategy": "steps",
+    "per_device_train_batch_size": 64,
+    "per_device_eval_batch_size": 32,
+    "gradient_accumulation_steps": 1,
+    "eval_accumulation_steps": 8,
+    "dataloader_num_workers": 4,
+    "preprocessing_num_workers": 32,
+    "learning_rate": 1e-4,
+    "save_strategy": "steps",
+    "fp16": false,
+    "max_steps": 200000,
+    "save_steps": 200000,
+    "eval_steps": 5000,
+    "logging_steps": 50,
+    "report_to": "wandb",
+    "is_decoder": false,
+    "remove_unused_columns": false,
+    "lookahead": 10,
+    "one_sample_per_line": false,
+    "do_sentence_training": true,
+    "do_auxiliary_training": true,
+    "warmup_steps": 5000,
+    "adapter_warmup_steps": 0,
+    "adapter_lr_multiplier": 1,
+    "ngram_order": 1,
+    "non_punctuation_sample_ratio": 0.1,
+    "prediction_loss_only": true,
+    "use_auxiliary": true,
+    "ddp_timeout": 3600,
+    "use_subwords": true,
+    "num_hidden_layers": 3,
+    "custom_punctuation_file": "punctuation_xlmr_unk.txt",
+    "log_level": "warning"
+}
\ No newline at end of file
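The new knob in this config is "lookahead": 10; it is passed through unchanged to the model's forward pass by the collate_fn and Model-wrapper changes at the end of this diff. A minimal sketch of reading the value, assuming the file is loaded as plain JSON (the snippet is illustrative, not taken from the repo):

import json

with open("configs/xlmr_stratify_0.1_3layers_p_v2_look10.json") as f:
    cfg = json.load(f)

print(cfg["lookahead"])          # 10: each position may attend to at most 10 future subword tokens
print(cfg["num_hidden_layers"])  # 3: the XLM-R backbone is truncated to 3 layers
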
diff --git a/wtpsplit/models.py b/wtpsplit/models.py
index c0659593..cfbd9eca 100644
--- a/wtpsplit/models.py
+++ b/wtpsplit/models.py
@@ -1,12 +1,16 @@
 import copy
 import math
-from typing import Optional, Tuple, Union
+import warnings
+from typing import List, Optional, Tuple, Union
 
 import torch
-from torch import nn
+from torch import Tensor, nn
+from torch.nn import CrossEntropyLoss
+from torchinfo import summary
 from transformers import AutoModel, AutoModelForTokenClassification
+from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
+from transformers.modeling_utils import ModuleUtilsMixin
 from transformers.models.bert.modeling_bert import BertEncoder, BertForTokenClassification, BertModel, BertPooler
-from transformers.models.xlm_roberta import XLMRobertaModel, XLMRobertaForTokenClassification
 from transformers.models.canine.modeling_canine import (
     _PRIMES,
     ACT2FN,
@@ -27,7 +31,15 @@
     ConvProjection,
     TokenClassifierOutput,
 )
-from torchinfo import summary
+from transformers.models.xlm_roberta import (
+    XLMRobertaForTokenClassification,
+    XLMRobertaModel,
+)
+from transformers.models.xlm_roberta.modeling_xlm_roberta import (
+    XLMRobertaEmbeddings,
+    XLMRobertaEncoder,
+    XLMRobertaPooler,
+)
 
 from wtpsplit.configs import BertCharConfig, LACanineConfig, SubwordXLMConfig
 from wtpsplit.utils import Constants
@@ -969,7 +981,7 @@ def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
 
-        self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
+        self.roberta = SubwordXLMRobertaModel(config, add_pooling_layer=False)
         classifier_dropout = (
             config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
         )
@@ -993,20 +1005,267 @@ def forward(
         hashed_ids: Optional[torch.Tensor] = None,
         language_ids=None,
         return_dict: Optional[bool] = None,
+        lookahead: Optional[int] = None,
     ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
-        return super().forward(
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.roberta(
             input_ids,
-            attention_mask,
-            token_type_ids,
-            position_ids,
-            head_mask,
-            inputs_embeds,
-            labels,
-            output_attentions,
-            output_hidden_states,
-            return_dict,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            lookahead=lookahead,
         )
+        sequence_output = outputs[0]
+
+        sequence_output = self.dropout(sequence_output)
+        logits = self.classifier(sequence_output)
+
+        loss = None
+        if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(logits.device)
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+class SubwordXLMRobertaModel(XLMRobertaModel):
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+    # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->XLMRoberta
+    def __init__(self, config, add_pooling_layer=True):
+        super().__init__(config)
+        self.config = config
+
+        self.embeddings = XLMRobertaEmbeddings(config)
+        self.encoder = XLMRobertaEncoder(config)
+
+        self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    # Copied from transformers.models.bert.modeling_bert.BertModel.forward
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        lookahead: Optional[int] = None,
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+        r"""
+        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+            the model is configured as a decoder.
+        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if self.config.is_decoder:
+            use_cache = use_cache if use_cache is not None else self.config.use_cache
+        else:
+            use_cache = False
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        batch_size, seq_length = input_shape
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+        # past_key_values_length
+        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+        if attention_mask is None:
+            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
+
+        if token_type_ids is None:
+            if hasattr(self.embeddings, "token_type_ids"):
+                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+        # ourselves in which case we just need to make it broadcastable to all heads.
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, lookahead)
+
+        # If a 2D or 3D attention mask is provided for the cross-attention
+        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+        if self.config.is_decoder and encoder_hidden_states is not None:
+            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+            if encoder_attention_mask is None:
+                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+        else:
+            encoder_extended_attention_mask = None
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+        embedding_output = self.embeddings(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            token_type_ids=token_type_ids,
+            inputs_embeds=inputs_embeds,
+            past_key_values_length=past_key_values_length,
+        )
+        encoder_outputs = self.encoder(
+            embedding_output,
+            attention_mask=extended_attention_mask,
+            head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_extended_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = encoder_outputs[0]
+        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+        if not return_dict:
+            return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+        return BaseModelOutputWithPoolingAndCrossAttentions(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            past_key_values=encoder_outputs.past_key_values,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+            cross_attentions=encoder_outputs.cross_attentions,
+        )
+
+    def get_extended_attention_mask(
+        self,
+        attention_mask: Tensor,
+        input_shape: Tuple[int],
+        lookahead: Optional[int] = None,
+        device: torch.device = None,
+        dtype: torch.float = None,
+    ) -> Tensor:
+        """
+        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
+
+        Arguments:
+            attention_mask (`torch.Tensor`):
+                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
+            input_shape (`Tuple[int]`):
+                The shape of the input to the model.
+
+        Returns:
+            `torch.Tensor` The extended attention mask, with the same dtype as `attention_mask.dtype`.
+        """
+        if dtype is None:
+            dtype = self.dtype
+
+        if not (attention_mask.dim() == 2 and self.config.is_decoder):
+            # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
+            if device is not None:
+                warnings.warn(
+                    "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
+                )
+        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+        # ourselves in which case we just need to make it broadcastable to all heads.
+        if attention_mask.dim() == 3:
+            extended_attention_mask = attention_mask[:, None, :, :]
+        elif attention_mask.dim() == 2:
+            # Provided a padding mask of dimensions [batch_size, seq_length]
+            # - if the model is a decoder, apply a causal mask in addition to the padding mask
+            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+            if self.config.is_decoder:
+                extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
+                    input_shape, attention_mask, device
+                )
+            if lookahead:
+                # lookahead mask of shape [batch_size, 1, seq_length, seq_length]:
+                # each token may attend to all previous tokens and to at most the next `lookahead` tokens
+                _, seq_length = attention_mask.shape
+                # Create a lookahead mask
+                lookahead_mask = torch.tril(torch.ones(seq_length, seq_length), diagonal=lookahead, out=None).to(
+                    attention_mask.device
+                )
+                # Combine the attention mask with the lookahead mask
+                extended_attention_mask = attention_mask[:, None, None, :] * lookahead_mask
+            else:
+                extended_attention_mask = attention_mask[:, None, None, :]
+        else:
+            raise ValueError(
+                f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
+            )
+
+        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+        # masked positions, this operation will create a tensor which is 0.0 for
+        # positions we want to attend and the dtype's smallest value for masked positions.
+        # Since we are adding it to the raw scores before the softmax, this is
+        # effectively the same as removing these entirely.
+        extended_attention_mask = extended_attention_mask.to(dtype=dtype)  # fp16 compatibility
+        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
+        return extended_attention_mask
+
 
 AutoModel.register(LACanineConfig, LACanineModel)
 AutoModelForTokenClassification.register(LACanineConfig, LACanineForTokenClassification)
@@ -1024,7 +1283,7 @@ def forward(
     model_str = "xlm-roberta-base"
    config = AutoConfig.from_pretrained(model_str)
     config.num_labels = 4
-    config.num_hidden_layers = 9
+    config.num_hidden_layers = 1
     backbone = SubwordXLMForTokenClassification.from_pretrained(model_str, config=config)
     print(summary(backbone, depth=4))
 
@@ -1032,12 +1291,14 @@ def forward(
     text = "This is a test\n sentence \n\n"
     tokenizer = AutoTokenizer.from_pretrained(model_str)
-    tokens = tokenizer(text, return_tensors="pt", add_special_tokens=False)
+    tokens = tokenizer(text, return_tensors="pt", add_special_tokens=False, pad_to_multiple_of=8, padding=True)
     from tokenizers import AddedToken
 
     tokenizer.add_special_tokens({"additional_special_tokens": [AddedToken("\n")]})
     print(tokenizer.tokenize(text))
     print(tokenizer.encode(text))
     print(tokens)
+
     # forward pass
-    print(backbone(**tokens))
+    lookahead = 1
+    print(backbone(**tokens, lookahead=lookahead))
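The essential change above is the lookahead branch of get_extended_attention_mask. A minimal standalone sketch of the masking rule it implements (variable names and sizes here are illustrative, not from the repo): position i may attend to every earlier position and to at most `lookahead` future positions.

import torch

seq_length, lookahead = 6, 2
padding_mask = torch.ones(1, seq_length)  # [batch_size, seq_length], 1 = real token
lookahead_mask = torch.tril(torch.ones(seq_length, seq_length), diagonal=lookahead)
extended = padding_mask[:, None, None, :] * lookahead_mask  # [batch_size, 1, seq_length, seq_length]

print(extended[0, 0].long())
# row i (query) may attend to columns j <= i + lookahead:
# tensor([[1, 1, 1, 0, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1, 1],
#         [1, 1, 1, 1, 1, 1],
#         [1, 1, 1, 1, 1, 1]])

Because the same mask is applied in every encoder layer, information can still propagate up to num_hidden_layers * lookahead positions ahead across the stack.
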
diff --git a/wtpsplit/summary_plot.py b/wtpsplit/summary_plot.py
index f8f3438a..3d512e11 100644
--- a/wtpsplit/summary_plot.py
+++ b/wtpsplit/summary_plot.py
@@ -3,9 +3,9 @@
 import json
 
 FILES = [
-    ".cache/xlmr-normal-v2_b512+s_64_intrinsic_results_u0.01.json",
+    ".cache/xlmr-normal-v2_b128+s64_intrinsic_results_u0.01.json",
+    ".cache/xlmr-normal-v2_b256+s64_intrinsic_results_u0.01.json",
     ".cache/xlmr-normal-v2_b512+s64_intrinsic_results_u0.01.json",
-    ".cache/xlm-tokenv2_intrinsic_results_u001.json",
 ]
 
 NAME = "test"
@@ -21,9 +21,9 @@
 
 def plot_violin_from_json(files, name):
     # Prepare data
-    data = {"score": [], "metric": [], "file": [], "x": []}
+    data = {"score": [], "metric": [], "file": [], "x": [], "lang": []}
     spacing = 1.0  # Space between groups of metrics
-    violin_width = 0.3  # Width of each violin
+    violin_width = 0.5 / len(files)  # Width of each violin
 
     # Base colors for each metric
     base_colors = {"u": (0, 123, 255, 0.6), "t": (40, 167, 69, 0.6), "punct": (255, 193, 7, 0.6)}
@@ -52,6 +52,7 @@ def plot_violin_from_json(files, name):
                     file.split("/")[-1].split(".")[0]
                 )  # Use file base name without extension for legend
                 data["x"].append(x_positions[metric][file])  # Use computed x position
+                data["lang"].append(lang)
 
     # Convert to DataFrame
     df = pd.DataFrame(data)
@@ -74,6 +75,8 @@ def plot_violin_from_json(files, name):
                 box_visible=True,
                 meanline_visible=True,
                 width=violin_width,
+                points="all",
+                hovertext=file_df["lang"],
             )
         )
diff --git a/wtpsplit/train/train.py b/wtpsplit/train/train.py
index 928599e7..4349420b 100644
--- a/wtpsplit/train/train.py
+++ b/wtpsplit/train/train.py
@@ -185,6 +185,7 @@ def collate_fn(batch, args, label_args, label_dict, tokenizer):
         "language_ids": torch.tensor(all_language_ids, dtype=torch.long),
         "label_weights": torch.stack(all_label_weights, 0),
         "labels": torch.stack(all_labels, 0),
+        "lookahead": args.lookahead,
     }
 
     return out
diff --git a/wtpsplit/train/utils.py b/wtpsplit/train/utils.py
index 1d76cb90..150d90a2 100644
--- a/wtpsplit/train/utils.py
+++ b/wtpsplit/train/utils.py
@@ -41,6 +41,7 @@ def forward(
         position_ids=None,
         labels=None,
         label_weights=None,
+        lookahead=None,
         **kwargs,
     ):
         if position_ids is not None:
@@ -55,6 +56,7 @@
                language_ids=language_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
+                lookahead=lookahead,
                **kwargs,
            )
        )
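Taken together, the lookahead value travels from the config through collate_fn ("lookahead": args.lookahead) and the Model wrapper into SubwordXLMForTokenClassification. A rough usage sketch mirroring the __main__ smoke test in models.py, with num_hidden_layers and lookahead taken from the new config (the input text and tokenizer call are illustrative):

import torch
from transformers import AutoConfig, AutoTokenizer

from wtpsplit.models import SubwordXLMForTokenClassification

model_str = "xlm-roberta-base"
config = AutoConfig.from_pretrained(model_str)
config.num_labels = 4
config.num_hidden_layers = 3  # as in the new config file
model = SubwordXLMForTokenClassification.from_pretrained(model_str, config=config)

tokenizer = AutoTokenizer.from_pretrained(model_str)
tokens = tokenizer("This is a test sentence.", return_tensors="pt")

with torch.no_grad():
    out = model(**tokens, lookahead=10)  # each token sees at most 10 future tokens per layer
print(out.logits.shape)  # torch.Size([1, seq_length, 4])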