proper lookahead

markus583 committed Feb 18, 2024
1 parent e5ee9d2 commit dc68dca
Showing 4 changed files with 24 additions and 21 deletions.
2 changes: 2 additions & 0 deletions wtpsplit/configs.py
@@ -51,10 +51,12 @@ class SubwordXLMConfig(XLMRobertaConfig):

def __init__(
self,
lookahead=None,
**kwargs,
):
super().__init__(**kwargs)
self.mixture_name = "xlm-token"
self.lookahead = lookahead


AutoConfig.register("bert-char", BertCharConfig)
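With this change, `lookahead` is carried on the config instead of being threaded through every `forward()` call. A minimal construction sketch; the concrete values below are illustrative assumptions, not taken from this commit:

```python
from wtpsplit.configs import SubwordXLMConfig

# lookahead now lives on the config; downstream code reads config.lookahead
# rather than receiving a lookahead argument in each forward() call.
config = SubwordXLMConfig(
    num_hidden_layers=3,  # illustrative value
    lookahead=48,         # illustrative; should split evenly across the layers
)
print(config.lookahead)          # 48
print(config.num_hidden_layers)  # 3
```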
16 changes: 8 additions & 8 deletions wtpsplit/models.py
@@ -840,7 +840,6 @@ def forward(
output_hidden_states: Optional[bool] = None,
hashed_ids: Optional[torch.Tensor] = None,
return_dict: Optional[bool] = None,
lookahead: Optional[int] = None,
) -> Union[Tuple, TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1006,7 +1005,6 @@ def forward(
hashed_ids: Optional[torch.Tensor] = None,
language_ids=None,
return_dict: Optional[bool] = None,
lookahead: Optional[int] = None,
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1024,7 +1022,6 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
lookahead=lookahead,
)

sequence_output = outputs[0]
@@ -1063,6 +1060,9 @@ def __init__(self, config, add_pooling_layer=True):
self.encoder = XLMRobertaEncoder(config)

self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None
self.effective_lookahead = (
config.lookahead // config.num_hidden_layers if config.lookahead is not None else None
)

# Initialize weights and apply final processing
self.post_init()
@@ -1083,7 +1083,6 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
lookahead: Optional[int] = None,
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -1144,7 +1143,9 @@ def forward(

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, lookahead)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
attention_mask, input_shape, self.effective_lookahead
)

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
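The model now derives a per-layer budget, `effective_lookahead = config.lookahead // config.num_hidden_layers`, and hands it to `get_extended_attention_mask` here, rather than accepting a `lookahead` argument on every call. Presumably each self-attention layer may look only `effective_lookahead` tokens ahead, so that after all hidden layers the receptive field reaches roughly `lookahead` tokens into the future. A rough stand-alone sketch of such a lookahead-limited mask (an illustrative assumption, not the actual `get_extended_attention_mask` override):

```python
import torch

def lookahead_attention_mask(attention_mask: torch.Tensor, lookahead: int) -> torch.Tensor:
    """Illustrative only: each position may attend to all previous tokens and
    to at most `lookahead` tokens ahead, while padding stays masked out."""
    batch_size, seq_len = attention_mask.shape
    positions = torch.arange(seq_len, device=attention_mask.device)
    # allowed[i, j] is True when token i may attend to token j, i.e. j - i <= lookahead
    allowed = (positions[None, :] - positions[:, None]) <= lookahead
    # combine with the padding mask; broadcast over attention heads downstream
    mask = allowed[None, :, :] & attention_mask[:, None, :].bool()
    return mask  # shape: [batch_size, seq_len, seq_len]

# usage sketch: per-layer budget as computed in __init__ above
# mask = lookahead_attention_mask(tokens["attention_mask"], config.lookahead // config.num_hidden_layers)
```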
@@ -1289,7 +1290,7 @@ def get_extended_attention_mask(
print(summary(backbone, depth=4))

# some sample input
text = "A sentence.\n And"
text = "A sentence. Now we move on. And on and this is the last sentence. Now, we are starting to move on to the next sentence. This is the last sentence."
tokenizer = AutoTokenizer.from_pretrained(model_str)

tokens = tokenizer(text, return_tensors="pt", add_special_tokens=False, pad_to_multiple_of=512, padding=True)
@@ -1301,5 +1302,4 @@ def get_extended_attention_mask(
print(tokens)

# forward pass
lookahead = 512
print(backbone(**tokens, lookahead=lookahead))
print(backbone(**tokens))
25 changes: 14 additions & 11 deletions wtpsplit/train/train.py
@@ -40,8 +40,6 @@
logger = logging.getLogger(__name__)


# TODO: double-check checkpointing and saving (also to txt)

# os.environ["PJRT_DEVICE"] = "None"


@@ -186,7 +184,6 @@ def collate_fn(batch, args, label_args, label_dict, tokenizer):
"language_ids": torch.tensor(all_language_ids, dtype=torch.long),
"label_weights": torch.stack(all_label_weights, 0),
"labels": torch.stack(all_labels, 0),
"lookahead": args.lookahead,
}

return out
@@ -214,6 +211,7 @@ def main():
args.model_name_or_path,
num_hidden_layers=args.num_hidden_layers,
num_labels=num_labels,
lookahead=args.lookahead,
)
backbone = SubwordXLMForTokenClassification(config)

@@ -222,6 +220,7 @@
args.model_name_or_path,
num_hidden_layers=args.num_hidden_layers,
num_labels=num_labels,
lookahead=args.lookahead,
)
backbone = SubwordXLMForTokenClassification.from_pretrained(
args.model_name_or_path,
@@ -236,6 +235,9 @@
# used later to filter out special tokens
special_tokens_ids = set(tokenizer.all_special_ids)
special_tokens_ids.discard(custom_token_id)
if args.lookahead:
assert args.lookahead % args.num_hidden_layers == 0


else:
tokenizer = None
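The new assertion guards the integer division used for the per-layer budget: the total lookahead has to split evenly across the hidden layers, otherwise part of the budget would silently be dropped. A small worked example with assumed values (not taken from the commit):

```python
# Illustrative values only.
lookahead, num_hidden_layers = 48, 3
assert lookahead % num_hidden_layers == 0, "lookahead must be a multiple of num_hidden_layers"
effective_lookahead = lookahead // num_hidden_layers
print(effective_lookahead)  # 16 tokens of lookahead per layer, 48 in total
```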
@@ -265,7 +267,8 @@ def main():
backbone = LACanineForTokenClassification.from_pretrained(
args.model_name_or_path, ignore_mismatched_sizes=True, config=config
)



model = Model(
backbone,
loss_margin=args.loss_margin,
@@ -637,13 +640,13 @@ def compute_metrics(trainer):
training_args.adapter_lr_multiplier = args.adapter_lr_multiplier

# give .map in multiprocessing enough of time to finish, to be safe
time.sleep(10)
if training_args.local_rank == 0:
# since both share the *same* cache_dir, we cannot simply call dataset.cleanup_cache_files()
# because that would remove the cache files of the other dataset!
cleanup_cache_files([train_dataset, valid_dataset])
logger.warning("Cleaned up cache files.")
time.sleep(10)
# time.sleep(10)
# if training_args.local_rank == 0:
# # since both share the *same* cache_dir, we cannot simply call dataset.cleanup_cache_files()
# # because that would remove the cache files of the other dataset!
# cleanup_cache_files([train_dataset, valid_dataset])
# logger.warning("Cleaned up cache files.")
# time.sleep(10)

trainer = Trainer(
model,
2 changes: 0 additions & 2 deletions wtpsplit/train/utils.py
@@ -41,7 +41,6 @@ def forward(
position_ids=None,
labels=None,
label_weights=None,
lookahead=None,
**kwargs,
):
if position_ids is not None:
@@ -57,7 +56,6 @@
language_ids=language_ids,
attention_mask=attention_mask,
position_ids=position_ids,
lookahead=lookahead,
**kwargs,
)
)
