From d355d500381c32da317aff585e08a02ff1a16829 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=A5=BF=E8=A5=BF=E5=98=9B=E5=91=A6?= <461600371@qq.com>
Date: Fri, 29 Jul 2022 16:07:06 +0800
Subject: [PATCH] Update preprocess.py

---
 preprocess.py | 32 ++++++++++++++++++++++++++++----
 1 file changed, 28 insertions(+), 4 deletions(-)

diff --git a/preprocess.py b/preprocess.py
index 6b53645..3b73069 100644
--- a/preprocess.py
+++ b/preprocess.py
@@ -126,9 +126,11 @@ def convert_bert_example(ex_idx, example: InputExample, tokenizer: BertTokenizer
         label_ids = label_ids + [0] * pad_length  # CLS/SEP/PAD labels are all O
 
     assert len(label_ids) == max_seq_len, f'{len(label_ids)}'
 
+    # ========================
     encode_dict = tokenizer.encode_plus(text=tokens,
                                         max_length=max_seq_len,
+                                        padding="max_length",
                                         truncation='longest_first',
                                         return_token_type_ids=True,
                                         return_attention_mask=True)
@@ -139,8 +141,8 @@ def convert_bert_example(ex_idx, example: InputExample, tokenizer: BertTokenizer
 
     if ex_idx < 3:
         logger.info(f"*** {set_type}_example-{ex_idx} ***")
-        print(tokenizer.decode(token_ids[:len(raw_text)]))
-        logger.info(f'text: {" ".join(tokens)}')
+        print(tokenizer.decode(token_ids[:len(raw_text)+2]))
+        logger.info(f'text: {str(" ".join(tokens))}')
         logger.info(f"token_ids: {token_ids}")
         logger.info(f"attention_masks: {attention_masks}")
         logger.info(f"token_type_ids: {token_type_ids}")
@@ -167,6 +169,9 @@ def convert_examples_to_features(examples, max_seq_len, bert_dir, ent2id, labels
     logger.info(f'Convert {len(examples)} examples to features')
 
     for i, example in enumerate(examples):
+        # the text may be empty, so filter those examples out
+        if not example.text:
+            continue
         feature, tmp_callback = convert_bert_example(
             ex_idx=i,
             example=example,
@@ -212,7 +217,7 @@ def save_file(filename, data ,id2ent):
 
 if __name__ == '__main__':
 
-    dataset = "clue"
+    dataset = "attr"
     args = config.Args().get_parser()
     args.bert_dir = '../model_hub/chinese-bert-wwm-ext/'
     commonUtils.set_logger(os.path.join(args.log_dir, 'preprocess.log'))
@@ -279,7 +284,6 @@ def save_file(filename, data ,id2ent):
         save_file(os.path.join(mid_data_path,"clue_{}_cut.txt".format(args.max_seq_len)), train_data, id2ent)
         dev_data = get_data(processor, mid_data_path, "dev.json", "dev", ent2id, labels, args)
         # test_data = get_data(processor, mid_data_path, "test.json", "test", ent2id, labels, args)
-
     elif dataset == "addr":
         args.data_dir = './data/addr'
         args.max_seq_len = 64
@@ -296,6 +300,26 @@ def save_file(filename, data ,id2ent):
         mid_data_path = os.path.join(args.data_dir, 'mid_data')
         processor = NerProcessor(cut_sent=True, cut_sent_len=args.max_seq_len)
 
+        train_data = get_data(processor, mid_data_path, "train.json", "train", ent2id, labels, args)
+        save_file(os.path.join(mid_data_path,"clue_{}_cut.txt".format(args.max_seq_len)), train_data, id2ent)
+        dev_data = get_data(processor, mid_data_path, "dev.json", "dev", ent2id, labels, args)
+        # test_data = get_data(processor, mid_data_path, "test.json", "test", ent2id, labels, args)
+    elif dataset == "attr":
+        args.data_dir = './data/attr'
+        args.max_seq_len = 128
+
+        labels_path = os.path.join(args.data_dir, 'mid_data', 'labels.json')
+        with open(labels_path, 'r') as fp:
+            labels = json.load(fp)
+
+        ent2id_path = os.path.join(args.data_dir, 'mid_data')
+        with open(os.path.join(ent2id_path, 'nor_ent2id.json'), encoding='utf-8') as f:
+            ent2id = json.load(f)
+        id2ent = {v: k for k, v in ent2id.items()}
+
+        mid_data_path = os.path.join(args.data_dir, 'mid_data')
+        processor = NerProcessor(cut_sent=True, cut_sent_len=args.max_seq_len)
+        train_data = get_data(processor, mid_data_path, "train.json", "train", ent2id, labels, args)
         save_file(os.path.join(mid_data_path,"clue_{}_cut.txt".format(args.max_seq_len)), train_data, id2ent)
         dev_data = get_data(processor, mid_data_path, "dev.json", "dev", ent2id, labels, args)
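
Note (not part of the patch): the key functional change above is passing padding="max_length" to tokenizer.encode_plus, so token_ids, attention_masks and token_type_ids come back padded to max_seq_len, the same fixed length as the padded label_ids. Below is a minimal sketch of that behaviour, assuming the Hugging Face transformers BertTokenizer; "bert-base-chinese" is only a stand-in for the repo's local ../model_hub/chinese-bert-wwm-ext/ checkpoint, and the sample text is made up.

from transformers import BertTokenizer

max_seq_len = 16
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
tokens = list("今天天气不错")  # character-level tokens, as preprocess.py produces

encode_dict = tokenizer.encode_plus(text=tokens,
                                    max_length=max_seq_len,
                                    padding="max_length",
                                    truncation='longest_first',
                                    return_token_type_ids=True,
                                    return_attention_mask=True)

token_ids = encode_dict['input_ids']
attention_masks = encode_dict['attention_mask']
token_type_ids = encode_dict['token_type_ids']

# All three sequences are padded to max_seq_len, matching label_ids in the patch.
assert len(token_ids) == len(attention_masks) == len(token_type_ids) == max_seq_len
# [CLS] + tokens + [SEP], mirroring the len(raw_text)+2 slice in the logging fix.
print(tokenizer.decode(token_ids[:len(tokens) + 2]))

Without an explicit padding argument (or the older pad_to_max_length=True), recent transformers versions return unpadded encodings by default, which would break the downstream length assertions this preprocessing relies on.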