Skip to content

Commit

Permalink
Update (#403)
Browse files Browse the repository at this point in the history
* update

* update

---------

Co-authored-by: kaeli <[email protected]>
  • Loading branch information
wmpscc and kaeli authored May 9, 2024
1 parent 6bcb2eb commit 5743050
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 28 deletions.
3 changes: 2 additions & 1 deletion uer/layers/multi_headed_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ def unshape(x):
if prev_attn is not None:
scores += prev_attn
prev_attn_out = scores
probs = nn.Softmax(dim=-1)(scores)
# probs = nn.Softmax(dim=-1)(scores)
probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
probs = self.dropout(probs)
output = unshape(torch.matmul(probs, value))
output = self.final_linear(output)
Expand Down
2 changes: 1 addition & 1 deletion uer/layers/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(self, args):
self.layer_norm_1 = LayerNorm(args.hidden_size)
self.layer_norm_2 = LayerNorm(args.hidden_size)

def forward(self, hidden, mask, position_bias = None, has_residual_attention=False, prev_attn=None):
def forward(self, hidden, mask, position_bias=None, has_residual_attention=False, prev_attn=None):
"""
Args:
hidden: [batch_size x seq_length x emb_size]
Expand Down
2 changes: 1 addition & 1 deletion uer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def init_model(args):
def init_optimizer(args, model):
# Build optimizer.
param_optimizer = list(model.named_parameters())
no_decay = ["bias", "gamma", "beta"]
no_decay = ["bias", "gamma", "beta", "layer_norm"]
optimizer_grouped_parameters = [
{"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], "weight_decay": 0.01},
{"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0}
Expand Down
18 changes: 8 additions & 10 deletions uer/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@
with open("models/special_tokens_map.json", mode="r", encoding="utf-8") as f:
special_tokens_map = json.load(f)

UNK_TOKEN = special_tokens_map["unk_token"]
CLS_TOKEN = special_tokens_map["cls_token"]
SEP_TOKEN = special_tokens_map["sep_token"]
MASK_TOKEN = special_tokens_map["mask_token"]
PAD_TOKEN = special_tokens_map["pad_token"]
try:
# e.g. <extra_id_0>, <extra_id_1>, ... , should have consecutive IDs.
SENTINEL_TOKEN = special_tokens_map["sentinel_token"]
except KeyError:
pass
UNK_TOKEN = special_tokens_map.get("unk_token")
CLS_TOKEN = special_tokens_map.get("cls_token")
SEP_TOKEN = special_tokens_map.get("sep_token")
MASK_TOKEN = special_tokens_map.get("mask_token")
PAD_TOKEN = special_tokens_map.get("pad_token")

# e.g. <extra_id_0>, <extra_id_1>, ... , should have consecutive IDs.
SENTINEL_TOKEN = special_tokens_map.get("sentinel_token")
2 changes: 1 addition & 1 deletion uer/utils/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ def __iter__(self):
seg_single += [0] * pad_num
seg.append(seg_single)

if len(ins) == 4 :
if len(ins) == 4:
src.append(src_single)
masked_words_num += len(ins[1])
tgt_mlm.append([0] * len(src_single))
Expand Down
23 changes: 15 additions & 8 deletions uer/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def merge_dataset(dataset_path, workers_num):
for i in range(workers_num):
tmp_dataset_reader = open("dataset-tmp-" + str(i) + ".pt", "rb")
while True:
tmp_data = tmp_dataset_reader.read(2**20)
tmp_data = tmp_dataset_reader.read(2 ** 20)
if tmp_data:
dataset_writer.write(tmp_data)
else:
Expand Down Expand Up @@ -210,7 +210,8 @@ def create_ins_from_doc(self, all_documents, document_index):
pad_num = self.seq_length - len(src)

if not self.dynamic_masking:
src, tgt_mlm = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
src, tgt_mlm = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking,
self.span_geo_prob, self.span_max_length)
src = (src, pad_num)
instance = (src, tgt_mlm, is_random_next, seg_pos)
else:
Expand Down Expand Up @@ -244,7 +245,8 @@ def worker(self, proc_id, start, end):
line = f.readline()
pos += 1

document = [self.vocab.get(CLS_TOKEN)] + self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(line)) + [self.vocab.get(SEP_TOKEN)]
document = [self.vocab.get(CLS_TOKEN)] + self.tokenizer.convert_tokens_to_ids(
self.tokenizer.tokenize(line)) + [self.vocab.get(SEP_TOKEN)]

if self.full_sentences:
if len(document) > 0:
Expand Down Expand Up @@ -292,7 +294,8 @@ def build_instances(self, all_documents):
seg_pos = [len(src)]

if not self.dynamic_masking:
src, tgt = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
src, tgt = mask_seq(src, self.tokenizer, self.whole_word_masking,
self.span_masking, self.span_geo_prob, self.span_max_length)
instance = ((src, 0), tgt, seg_pos)
else:
instance = ((src, 0), seg_pos)
Expand All @@ -309,7 +312,8 @@ def build_instances(self, all_documents):
pad_num = self.seq_length - len(src)

if not self.dynamic_masking:
src, tgt = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
src, tgt = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking,
self.span_geo_prob, self.span_max_length)
instance = ((src, pad_num), tgt, seg_pos)
else:
instance = ((src, pad_num), seg_pos)
Expand Down Expand Up @@ -416,7 +420,8 @@ def create_ins_from_doc(self, document):
pad_num = self.seq_length - len(src)

if not self.dynamic_masking:
src, tgt_mlm = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
src, tgt_mlm = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking,
self.span_geo_prob, self.span_max_length)
src = (src, pad_num)
instance = (src, tgt_mlm, is_wrong_order, seg_pos)
else:
Expand Down Expand Up @@ -811,7 +816,8 @@ def worker(self, proc_id, start, end):
if len(line) == 2:
label = int(line[0])
text = line[1]
src = [self.vocab.get(CLS_TOKEN)] + self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text)) + [self.vocab.get(SEP_TOKEN)]
src = [self.vocab.get(CLS_TOKEN)] + self.tokenizer.convert_tokens_to_ids(
self.tokenizer.tokenize(text)) + [self.vocab.get(SEP_TOKEN)]
tgt_cls = label
seg_pos = [len(src)]
elif len(line) == 3: # For sentence pair input.
Expand Down Expand Up @@ -847,7 +853,8 @@ def worker(self, proc_id, start, end):

if not self.dynamic_masking:
src_single, pad_num = src
src_single, tgt_mlm = mask_seq(src_single, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
src_single, tgt_mlm = mask_seq(src_single, self.tokenizer, self.whole_word_masking,
self.span_masking, self.span_geo_prob, self.span_max_length)
src = (src_single, pad_num)
instance = (src, tgt_mlm, tgt_cls, seg_pos)
else:
Expand Down
20 changes: 14 additions & 6 deletions uer/utils/vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,20 @@ def __init__(self):
self.reserved_vocab_path = \
os.path.abspath(os.path.join(os.path.dirname(__file__), "../../models/reserved_vocab.txt"))

def load(self, vocab_path, is_quiet=False):
with open(vocab_path, mode="r", encoding="utf-8") as reader:
for index, line in enumerate(reader):
w = line.strip("\r\n").split()[0] if line.strip() else line.strip("\r\n")
self.w2i[w] = index
self.i2w.append(w)
def load(self, vocab_path, is_quiet=False, is_vocab_json=False):
if is_vocab_json:
with open(vocab_path, 'r') as file:
voc = json.load(file)
sorted_voc = sorted(voc.items(), key=lambda x: x[1])
for w, i in sorted_voc:
self.w2i[w] = i
self.i2w.append(w)
else:
with open(vocab_path, mode="r", encoding="utf-8") as reader:
for index, line in enumerate(reader):
w = line.strip("\r\n").split()[0] if line.strip() else line.strip("\r\n")
self.w2i[w] = index
self.i2w.append(w)
if not is_quiet:
print("Vocabulary size: ", len(self))

Expand Down

0 comments on commit 5743050

Please sign in to comment.