From dc68dca51af74fc430dbb22c4d6bb1bd6603ddde Mon Sep 17 00:00:00 2001
From: markus583
Date: Sun, 18 Feb 2024 11:30:24 +0000
Subject: [PATCH] proper lookahead

---
 wtpsplit/configs.py     |  2 ++
 wtpsplit/models.py      | 16 ++++++++--------
 wtpsplit/train/train.py | 25 ++++++++++++++-----------
 wtpsplit/train/utils.py |  2 --
 4 files changed, 24 insertions(+), 21 deletions(-)

diff --git a/wtpsplit/configs.py b/wtpsplit/configs.py
index f1d570a3..4eb5f25f 100644
--- a/wtpsplit/configs.py
+++ b/wtpsplit/configs.py
@@ -51,10 +51,12 @@ class SubwordXLMConfig(XLMRobertaConfig):
 
     def __init__(
         self,
+        lookahead=None,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.mixture_name = "xlm-token"
+        self.lookahead = lookahead
 
 
 AutoConfig.register("bert-char", BertCharConfig)
diff --git a/wtpsplit/models.py b/wtpsplit/models.py
index 118427cf..29582e08 100644
--- a/wtpsplit/models.py
+++ b/wtpsplit/models.py
@@ -840,7 +840,6 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         hashed_ids: Optional[torch.Tensor] = None,
         return_dict: Optional[bool] = None,
-        lookahead: Optional[int] = None,
     ) -> Union[Tuple, TokenClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1006,7 +1005,6 @@ def forward(
         hashed_ids: Optional[torch.Tensor] = None,
         language_ids=None,
         return_dict: Optional[bool] = None,
-        lookahead: Optional[int] = None,
     ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1024,7 +1022,6 @@
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            lookahead=lookahead,
         )
 
         sequence_output = outputs[0]
@@ -1063,6 +1060,9 @@ def __init__(self, config, add_pooling_layer=True):
         self.encoder = XLMRobertaEncoder(config)
 
         self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None
+        self.effective_lookahead = (
+            config.lookahead // config.num_hidden_layers if config.lookahead is not None else None
+        )
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -1083,7 +1083,6 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        lookahead: Optional[int] = None,
     ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
         r"""
         encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -1144,7 +1143,9 @@
 
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
-        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, lookahead)
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
+            attention_mask, input_shape, self.effective_lookahead
+        )
 
         # If a 2D or 3D attention mask is provided for the cross-attention
         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
@@ -1289,7 +1290,7 @@ def get_extended_attention_mask(
     print(summary(backbone, depth=4))
 
     # some sample input
-    text = "A sentence.\n And"
+    text = "A sentence. Now we move on. And on and this is the last sentence. Now, we are starting to move on to the next sentence. This is the last sentence."
     tokenizer = AutoTokenizer.from_pretrained(model_str)
 
     tokens = tokenizer(text, return_tensors="pt", add_special_tokens=False, pad_to_multiple_of=512, padding=True)
@@ -1301,5 +1302,4 @@ def get_extended_attention_mask(
     print(tokens)
 
     # forward pass
-    lookahead = 512
-    print(backbone(**tokens, lookahead=lookahead))
+    print(backbone(**tokens))
diff --git a/wtpsplit/train/train.py b/wtpsplit/train/train.py
index b9a24c40..55ea91dc 100644
--- a/wtpsplit/train/train.py
+++ b/wtpsplit/train/train.py
@@ -40,8 +40,6 @@
 
 logger = logging.getLogger(__name__)
 
-# TODO: double-check checkpointing and saving (also to txt)
-
 # os.environ["PJRT_DEVICE"] = "None"
 
 
@@ -186,7 +184,6 @@ def collate_fn(batch, args, label_args, label_dict, tokenizer):
         "language_ids": torch.tensor(all_language_ids, dtype=torch.long),
         "label_weights": torch.stack(all_label_weights, 0),
         "labels": torch.stack(all_labels, 0),
-        "lookahead": args.lookahead,
     }
 
     return out
@@ -214,6 +211,7 @@ def main():
                 args.model_name_or_path,
                 num_hidden_layers=args.num_hidden_layers,
                 num_labels=num_labels,
+                lookahead=args.lookahead,
             )
             backbone = SubwordXLMForTokenClassification(config)
 
@@ -222,6 +220,7 @@
                 args.model_name_or_path,
                 num_hidden_layers=args.num_hidden_layers,
                 num_labels=num_labels,
+                lookahead=args.lookahead,
             )
             backbone = SubwordXLMForTokenClassification.from_pretrained(
                 args.model_name_or_path,
@@ -236,6 +235,9 @@
         # used later to filter out special tokens
         special_tokens_ids = set(tokenizer.all_special_ids)
         special_tokens_ids.discard(custom_token_id)
+        if args.lookahead:
+            assert args.lookahead % args.num_hidden_layers == 0
+
 
     else:
         tokenizer = None
@@ -265,7 +267,8 @@
             backbone = LACanineForTokenClassification.from_pretrained(
                 args.model_name_or_path, ignore_mismatched_sizes=True, config=config
             )
-
+
+
     model = Model(
         backbone,
         loss_margin=args.loss_margin,
@@ -637,13 +640,13 @@ def compute_metrics(trainer):
     training_args.adapter_lr_multiplier = args.adapter_lr_multiplier
 
     # give .map in multiprocessing enough of time to finish, to be safe
-    time.sleep(10)
-    if training_args.local_rank == 0:
-        # since both share the *same* cache_dir, we cannot simply call dataset.cleanup_cache_files()
-        # because that would remove the cache files of the other dataset!
-        cleanup_cache_files([train_dataset, valid_dataset])
-        logger.warning("Cleaned up cache files.")
-    time.sleep(10)
+    # time.sleep(10)
+    # if training_args.local_rank == 0:
+    #     # since both share the *same* cache_dir, we cannot simply call dataset.cleanup_cache_files()
+    #     # because that would remove the cache files of the other dataset!
+    #     cleanup_cache_files([train_dataset, valid_dataset])
+    #     logger.warning("Cleaned up cache files.")
+    # time.sleep(10)
 
     trainer = Trainer(
         model,
diff --git a/wtpsplit/train/utils.py b/wtpsplit/train/utils.py
index 1132e450..92dafd98 100644
--- a/wtpsplit/train/utils.py
+++ b/wtpsplit/train/utils.py
@@ -41,7 +41,6 @@ def forward(
         position_ids=None,
         labels=None,
         label_weights=None,
-        lookahead=None,
         **kwargs,
     ):
         if position_ids is not None:
@@ -57,7 +56,6 @@
                 language_ids=language_ids,
                 attention_mask=attention_mask,
                 position_ids=position_ids,
-                lookahead=lookahead,
                 **kwargs,
             )
         )
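
Below is a minimal usage sketch of the config-driven lookahead introduced by this patch. It is not part of the patch itself: the checkpoint name and the num_labels / lookahead values are illustrative, while the import paths, class names, and the divisibility constraint are taken from the diff above.

# Sketch only; illustrative values. The effective (per-layer) lookahead is
# derived inside the model as config.lookahead // config.num_hidden_layers.
from transformers import AutoTokenizer

from wtpsplit.configs import SubwordXLMConfig
from wtpsplit.models import SubwordXLMForTokenClassification

model_str = "xlm-roberta-base"  # example checkpoint
num_hidden_layers = 3           # example value
lookahead = 48                  # example value; train.py asserts divisibility
assert lookahead % num_hidden_layers == 0

config = SubwordXLMConfig.from_pretrained(
    model_str,
    num_hidden_layers=num_hidden_layers,
    num_labels=1,               # example value
    lookahead=lookahead,        # stored on the config, not passed to forward()
)
backbone = SubwordXLMForTokenClassification(config)

tokenizer = AutoTokenizer.from_pretrained(model_str)
tokens = tokenizer("A sentence. And another one.", return_tensors="pt", add_special_tokens=False)
print(backbone(**tokens))       # no lookahead argument at call time anymore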