-
Notifications
You must be signed in to change notification settings - Fork 0
/
masklm.py
72 lines (58 loc) · 3.56 KB
/
masklm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
from transformers import AutoTokenizer
class MaskLM(object):
    """Build masked-language-modeling inputs and labels from token-id tensors.

    Args:
        tokenizer_path: Name or path handed to ``AutoTokenizer.from_pretrained``
            when no ``tokenizer`` object is supplied.
        mlm_probability: Probability that each non-special token is selected
            for masking.
        tokenizer: Optional, already-constructed tokenizer. When given, it is
            used as-is and ``tokenizer_path`` is ignored (nothing is loaded).
    """

    def __init__(self, tokenizer_path='bert-base-chinese', mlm_probability=0.15, tokenizer=None):
        # Honour the caller's probability (it used to be hard-coded to 0.15).
        self.mlm_probability = mlm_probability
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.tokenizer = tokenizer

    def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]:
        """Prepare masked token inputs/labels for MLM: 80% MASK, 10% random, 10% original.

        Args:
            inputs: Integer tensor of token ids. It is modified IN PLACE and
                also returned.
            special_tokens_mask: Optional tensor, same shape as ``inputs``;
                truthy entries are never selected for masking. When omitted,
                it is derived row by row with
                ``tokenizer.get_special_tokens_mask``.

        Returns:
            ``(inputs, labels)``. ``labels`` is a copy of the original ids,
            with every position that was not selected for masking set to -100.
        """
        labels = inputs.clone()
        # Keep every helper tensor on the same device as the ids, so CUDA
        # batches work.
        device = labels.device
        # Sample tokens in each sequence for MLM training with probability
        # `self.mlm_probability`; float() lets integer 0/1 work with bernoulli.
        probability_matrix = torch.full(labels.shape, float(self.mlm_probability), device=device)
        if special_tokens_mask is None:
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
                for val in labels.tolist()
            ]
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool, device=device)
        else:
            special_tokens_mask = special_tokens_mask.bool().to(device)
        # Special tokens are never masked.
        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the selected tokens are replaced with tokenizer.mask_token ([MASK]).
        indices_replaced = (
            torch.bernoulli(torch.full(labels.shape, 0.8, device=device)).bool() & masked_indices
        )
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # Half of the remaining 20% (i.e. 10% overall) get a random token id.
        indices_random = (
            torch.bernoulli(torch.full(labels.shape, 0.5, device=device)).bool()
            & masked_indices
            & ~indices_replaced
        )
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long, device=device)
        inputs[indices_random] = random_words[indices_random]

        # The remaining 10% of selected tokens keep their original id.
        return inputs, labels
class MaskVideo(object):
    """Randomly zero out whole video frames for masked-frame modeling.

    Args:
        mlm_probability: Base masking probability. Each frame is selected with
            probability ``0.9 * mlm_probability`` scaled by its ``video_mask``
            value.
    """

    def __init__(self, mlm_probability=0.15):
        # Honour the caller's probability (it used to be hard-coded to 0.15).
        self.mlm_probability = mlm_probability

    def torch_mask_frames(self, video_feature, video_mask):
        """Zero out randomly selected frames of ``video_feature``.

        Args:
            video_feature: Tensor expected to be (batch, frames, feature_dim).
                The code reads ``size(0)`` and ``size(1)`` and broadcasts the
                per-frame mask over the trailing dim.
            video_mask: Tensor shaped like the first two dims of
                ``video_feature``. Its values scale the masking probability,
                so frames where it is 0 (presumably padding) are never masked.

        Returns:
            ``(inputs, video_labels_index)``.
            ``inputs`` is a copy of ``video_feature`` with masked frames set to
            0.0. It is built from ``.data``, so it is detached from autograd.
            ``video_labels_index`` has shape (batch, frames). A masked frame
            holds its flat position ``b * frames + t`` within the batch; every
            other frame holds -100.
        """
        # Create helper tensors on the mask's device so CUDA inputs work.
        device = video_mask.device
        probability_matrix = torch.full(video_mask.shape, 0.9 * self.mlm_probability, device=device)
        probability_matrix = probability_matrix * video_mask
        masked_indices = torch.bernoulli(probability_matrix).bool()

        num_frames = video_feature.size(1)
        flat_frame_index = torch.arange(
            video_feature.size(0) * num_frames, device=device
        ).view(-1, num_frames)
        # -100 for unmasked frames, the flat frame index for masked ones.
        video_labels_index = -100 * ~masked_indices + flat_frame_index * masked_indices

        # Fill every feature of each masked frame with 0.0.
        masked_indices_unsqueeze = masked_indices.unsqueeze(-1).expand_as(video_feature)
        inputs = video_feature.data.masked_fill(masked_indices_unsqueeze, 0.0)
        return inputs, video_labels_index
class ShuffleVideo(object):
    """Reorder a batch so that its second half of rows appears in reverse order."""

    def __init__(self):
        pass

    def torch_shuf_video(self, video_feature):
        """Keep the first half of the batch in place and reverse the second half.

        Args:
            video_feature: Tensor whose first dimension is the batch.

        Returns:
            ``(video_feature, label)``. The first item is the reordered tensor.
            ``label`` is a float tensor of length batch_size with
            ``label[i] == 1.0`` if the row at position ``i`` is still at its
            original position and 0.0 otherwise. A 1.0 occurs for the whole
            first half, plus the middle row of an odd-sized second half.
            An empty batch yields empty outputs.
        """
        bs = video_feature.size(0)
        device = video_feature.device
        # torch.arange keeps an integer dtype even when bs == 0, so it is
        # always valid as an index (torch.tensor([]) would be float).
        original_index = torch.arange(bs, device=device)
        half = bs // 2
        # Within the batch, the first half of the videos keeps its order and
        # the second half is reversed.
        shuf_index = original_index.clone()
        shuf_index[half:] = original_index[half:].flip(0)
        # Label after shuffling: 1.0 where a row did not move.
        label = (original_index == shuf_index).float()
        return video_feature[shuf_index], label