From 5743050cc41f0f21986c9f938db5ac1726052629 Mon Sep 17 00:00:00 2001
From: Kaeli <1214490813@qq.com>
Date: Thu, 9 May 2024 19:12:55 +0800
Subject: [PATCH] Update (#403)

* update

* update

---------

Co-authored-by: kaeli
---
 uer/layers/multi_headed_attn.py |  3 ++-
 uer/layers/transformer.py       |  2 +-
 uer/trainer.py                  |  2 +-
 uer/utils/constants.py          | 18 ++++++++----------
 uer/utils/dataloader.py         |  2 +-
 uer/utils/dataset.py            | 23 +++++++++++++++--------
 uer/utils/vocab.py              | 20 ++++++++++++++------
 7 files changed, 42 insertions(+), 28 deletions(-)

diff --git a/uer/layers/multi_headed_attn.py b/uer/layers/multi_headed_attn.py
index 3d29e974..4a9408aa 100755
--- a/uer/layers/multi_headed_attn.py
+++ b/uer/layers/multi_headed_attn.py
@@ -68,7 +68,8 @@ def unshape(x):
         if prev_attn is not None:
             scores += prev_attn
             prev_attn_out = scores
-        probs = nn.Softmax(dim=-1)(scores)
+        # probs = nn.Softmax(dim=-1)(scores)
+        probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
         probs = self.dropout(probs)
         output = unshape(torch.matmul(probs, value))
         output = self.final_linear(output)
diff --git a/uer/layers/transformer.py b/uer/layers/transformer.py
index fa0342b8..07300396 100755
--- a/uer/layers/transformer.py
+++ b/uer/layers/transformer.py
@@ -47,7 +47,7 @@ def __init__(self, args):
         self.layer_norm_1 = LayerNorm(args.hidden_size)
         self.layer_norm_2 = LayerNorm(args.hidden_size)
 
-    def forward(self, hidden, mask, position_bias = None, has_residual_attention=False, prev_attn=None):
+    def forward(self, hidden, mask, position_bias=None, has_residual_attention=False, prev_attn=None):
         """
         Args:
             hidden: [batch_size x seq_length x emb_size]
diff --git a/uer/trainer.py b/uer/trainer.py
index 2aa97b44..179bb58d 100644
--- a/uer/trainer.py
+++ b/uer/trainer.py
@@ -44,7 +44,7 @@ def init_model(args):
 def init_optimizer(args, model):
     # Build optimizer.
     param_optimizer = list(model.named_parameters())
-    no_decay = ["bias", "gamma", "beta"]
+    no_decay = ["bias", "gamma", "beta", "layer_norm"]
     optimizer_grouped_parameters = [
         {"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], "weight_decay": 0.01},
         {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0}
diff --git a/uer/utils/constants.py b/uer/utils/constants.py
index 5cc2c2e6..c5e7801e 100644
--- a/uer/utils/constants.py
+++ b/uer/utils/constants.py
@@ -4,13 +4,11 @@
 with open("models/special_tokens_map.json", mode="r", encoding="utf-8") as f:
     special_tokens_map = json.load(f)
 
-UNK_TOKEN = special_tokens_map["unk_token"]
-CLS_TOKEN = special_tokens_map["cls_token"]
-SEP_TOKEN = special_tokens_map["sep_token"]
-MASK_TOKEN = special_tokens_map["mask_token"]
-PAD_TOKEN = special_tokens_map["pad_token"]
-try:
-    # e.g. <extra_id_0>, <extra_id_1>, ... , should have consecutive IDs.
-    SENTINEL_TOKEN = special_tokens_map["sentinel_token"]
-except KeyError:
-    pass
+UNK_TOKEN = special_tokens_map.get("unk_token")
+CLS_TOKEN = special_tokens_map.get("cls_token")
+SEP_TOKEN = special_tokens_map.get("sep_token")
+MASK_TOKEN = special_tokens_map.get("mask_token")
+PAD_TOKEN = special_tokens_map.get("pad_token")
+
+# e.g. <extra_id_0>, <extra_id_1>, ... , should have consecutive IDs.
+SENTINEL_TOKEN = special_tokens_map.get("sentinel_token")
diff --git a/uer/utils/dataloader.py b/uer/utils/dataloader.py
index 3bc02eb6..d1bcf3eb 100644
--- a/uer/utils/dataloader.py
+++ b/uer/utils/dataloader.py
@@ -502,7 +502,7 @@ def __iter__(self):
                     seg_single += [0] * pad_num
                 seg.append(seg_single)
 
-                if len(ins) == 4 :
+                if len(ins) == 4:
                     src.append(src_single)
                     masked_words_num += len(ins[1])
                     tgt_mlm.append([0] * len(src_single))
diff --git a/uer/utils/dataset.py b/uer/utils/dataset.py
index fc1af0a6..fb2ee8e8 100644
--- a/uer/utils/dataset.py
+++ b/uer/utils/dataset.py
@@ -16,7 +16,7 @@ def merge_dataset(dataset_path, workers_num):
     for i in range(workers_num):
         tmp_dataset_reader = open("dataset-tmp-" + str(i) + ".pt", "rb")
         while True:
-            tmp_data = tmp_dataset_reader.read(2**20)
+            tmp_data = tmp_dataset_reader.read(2 ** 20)
             if tmp_data:
                 dataset_writer.write(tmp_data)
             else:
@@ -210,7 +210,8 @@ def create_ins_from_doc(self, all_documents, document_index):
                     pad_num = self.seq_length - len(src)
 
                     if not self.dynamic_masking:
-                        src, tgt_mlm = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
+                        src, tgt_mlm = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking,
+                                                self.span_geo_prob, self.span_max_length)
                         src = (src, pad_num)
                         instance = (src, tgt_mlm, is_random_next, seg_pos)
                     else:
@@ -244,7 +245,8 @@ def worker(self, proc_id, start, end):
                 line = f.readline()
                 pos += 1
 
-                document = [self.vocab.get(CLS_TOKEN)] + self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(line)) + [self.vocab.get(SEP_TOKEN)]
+                document = [self.vocab.get(CLS_TOKEN)] + self.tokenizer.convert_tokens_to_ids(
+                    self.tokenizer.tokenize(line)) + [self.vocab.get(SEP_TOKEN)]
 
                 if self.full_sentences:
                     if len(document) > 0:
@@ -292,7 +294,8 @@ def build_instances(self, all_documents):
             seg_pos = [len(src)]
 
             if not self.dynamic_masking:
-                src, tgt = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
+                src, tgt = mask_seq(src, self.tokenizer, self.whole_word_masking,
+                                    self.span_masking, self.span_geo_prob, self.span_max_length)
                 instance = ((src, 0), tgt, seg_pos)
             else:
                 instance = ((src, 0), seg_pos)
@@ -309,7 +312,8 @@
         pad_num = self.seq_length - len(src)
 
         if not self.dynamic_masking:
-            src, tgt = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
+            src, tgt = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking,
+                                self.span_geo_prob, self.span_max_length)
             instance = ((src, pad_num), tgt, seg_pos)
         else:
             instance = ((src, pad_num), seg_pos)
@@ -416,7 +420,8 @@ def create_ins_from_doc(self, document):
                     pad_num = self.seq_length - len(src)
 
                     if not self.dynamic_masking:
-                        src, tgt_mlm = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
+                        src, tgt_mlm = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking,
+                                                self.span_geo_prob, self.span_max_length)
                         src = (src, pad_num)
                         instance = (src, tgt_mlm, is_wrong_order, seg_pos)
                     else:
@@ -811,7 +816,8 @@ def worker(self, proc_id, start, end):
                 if len(line) == 2:
                     label = int(line[0])
                     text = line[1]
-                    src = [self.vocab.get(CLS_TOKEN)] + self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text)) + [self.vocab.get(SEP_TOKEN)]
+                    src = [self.vocab.get(CLS_TOKEN)] + self.tokenizer.convert_tokens_to_ids(
+                        self.tokenizer.tokenize(text)) + [self.vocab.get(SEP_TOKEN)]
                     tgt_cls = label
                     seg_pos = [len(src)]
                 elif len(line) == 3:  # For sentence pair input.
@@ -847,7 +853,8 @@ def worker(self, proc_id, start, end):
 
                 if not self.dynamic_masking:
                     src_single, pad_num = src
-                    src_single, tgt_mlm = mask_seq(src_single, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
+                    src_single, tgt_mlm = mask_seq(src_single, self.tokenizer, self.whole_word_masking,
+                                                   self.span_masking, self.span_geo_prob, self.span_max_length)
                     src = (src_single, pad_num)
                     instance = (src, tgt_mlm, tgt_cls, seg_pos)
                 else:
diff --git a/uer/utils/vocab.py b/uer/utils/vocab.py
index 6a62fe86..f9763cec 100644
--- a/uer/utils/vocab.py
+++ b/uer/utils/vocab.py
@@ -15,12 +15,20 @@ def __init__(self):
         self.reserved_vocab_path = \
             os.path.abspath(os.path.join(os.path.dirname(__file__), "../../models/reserved_vocab.txt"))
 
-    def load(self, vocab_path, is_quiet=False):
-        with open(vocab_path, mode="r", encoding="utf-8") as reader:
-            for index, line in enumerate(reader):
-                w = line.strip("\r\n").split()[0] if line.strip() else line.strip("\r\n")
-                self.w2i[w] = index
-                self.i2w.append(w)
+    def load(self, vocab_path, is_quiet=False, is_vocab_json=False):
+        if is_vocab_json:
+            with open(vocab_path, 'r') as file:
+                voc = json.load(file)
+            sorted_voc = sorted(voc.items(), key=lambda x: x[1])
+            for w, i in sorted_voc:
+                self.w2i[w] = i
+                self.i2w.append(w)
+        else:
+            with open(vocab_path, mode="r", encoding="utf-8") as reader:
+                for index, line in enumerate(reader):
+                    w = line.strip("\r\n").split()[0] if line.strip() else line.strip("\r\n")
+                    self.w2i[w] = index
+                    self.i2w.append(w)
         if not is_quiet:
             print("Vocabulary size: ", len(self))
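
Note on the attention change: the patched line computes the softmax in float32 and casts the result back to the query's dtype, which avoids overflow/underflow of the attention probabilities when training in fp16/bf16. A minimal standalone sketch of the pattern (illustrative tensor shapes, assuming a reasonably recent PyTorch; not UER code):

    import torch
    import torch.nn as nn

    # Illustrative attention scores: [batch, heads, query_len, key_len],
    # in a low-precision dtype as under mixed-precision training.
    scores = torch.randn(2, 8, 16, 16, dtype=torch.bfloat16)

    # Old path: softmax runs entirely in the scores' own low-precision dtype.
    probs_old = nn.Softmax(dim=-1)(scores)

    # New path from the patch: accumulate the softmax in float32 for stability,
    # then cast back so the downstream matmul keeps the low-precision dtype.
    probs_new = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(scores.dtype)

    print(probs_old.dtype, probs_new.dtype)  # both torch.bfloat16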
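Note on the optimizer change: parameter grouping in init_optimizer works by substring match on parameter names, so any parameter whose name contains "bias", "gamma", "beta", or now "layer_norm" gets weight_decay 0.0, presumably to also exclude layer-norm modules whose parameters are not named gamma/beta. A standalone sketch of that grouping logic on a toy module (the module and its names are made up, not UER's model):

    import torch.nn as nn

    # Toy module whose parameter names resemble "layer_norm.weight" etc.
    model = nn.ModuleDict({"linear": nn.Linear(8, 8), "layer_norm": nn.LayerNorm(8)})
    param_optimizer = list(model.named_parameters())

    no_decay = ["bias", "gamma", "beta", "layer_norm"]
    decay_names = [n for n, _ in param_optimizer if not any(nd in n for nd in no_decay)]
    no_decay_names = [n for n, _ in param_optimizer if any(nd in n for nd in no_decay)]

    print(decay_names)     # ['linear.weight']
    print(no_decay_names)  # ['linear.bias', 'layer_norm.weight', 'layer_norm.bias']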
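Note on the vocabulary change: the new is_vocab_json branch of Vocab.load reads a vocab.json-style token-to-id mapping and rebuilds i2w in id order rather than file-line order. A self-contained sketch that mirrors that branch on a made-up vocabulary (the real code lives in uer/utils/vocab.py and fills self.w2i/self.i2w on the Vocab instance):

    import json
    import tempfile

    # Hypothetical vocab.json content: token -> id, not listed in id order.
    toy_vocab = {"[PAD]": 0, "[UNK]": 1, "hello": 3, "world": 2}
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
        json.dump(toy_vocab, f)
        vocab_path = f.name

    # Mirror of the is_vocab_json branch added to Vocab.load:
    w2i, i2w = {}, []
    with open(vocab_path, "r") as file:
        voc = json.load(file)
    for w, i in sorted(voc.items(), key=lambda x: x[1]):  # sort tokens by id
        w2i[w] = i
        i2w.append(w)

    print(i2w)           # ['[PAD]', '[UNK]', 'world', 'hello']
    print(w2i["hello"])  # 3

With the patched class this corresponds to calling vocab.load(vocab_path, is_vocab_json=True).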