From 2469b1eb95d5e4bd92336fd9ae6df47144e3aa99 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=A5=BF=E8=A5=BF=E5=98=9B=E5=91=A6?= <461600371@qq.com> Date: Wed, 15 Jun 2022 12:20:27 +0800 Subject: [PATCH] Update preprocess.py --- preprocess.py | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/preprocess.py b/preprocess.py index 5baa489..8182352 100644 --- a/preprocess.py +++ b/preprocess.py @@ -213,9 +213,9 @@ def save_file(filename, data ,id2ent): if __name__ == '__main__': - dataset = "c" + dataset = "chip" args = config.Args().get_parser() - args.bert_dir = '../model_hub/bert-base-chinese/' + args.bert_dir = '../model_hub/chinese-bert-wwm-ext/' commonUtils.set_logger(os.path.join(args.log_dir, 'preprocess.log')) if dataset == "c": @@ -237,4 +237,25 @@ def save_file(filename, data ,id2ent): train_data = get_data(processor, mid_data_path, "train.json", "train", ent2id, labels, args) save_file(os.path.join(mid_data_path,"cner_{}_cut.txt".format(args.max_seq_len)), train_data, id2ent) dev_data = get_data(processor, mid_data_path, "dev.json", "dev", ent2id, labels, args) - test_data = get_data(processor, mid_data_path, "test.json", "test", ent2id, labels, args) \ No newline at end of file + test_data = get_data(processor, mid_data_path, "test.json", "test", ent2id, labels, args) + + elif dataset == "chip": + args.data_dir = './data/CHIP2020' + args.max_seq_len = 150 + + labels_path = os.path.join(args.data_dir, 'mid_data', 'labels.json') + with open(labels_path, 'r') as fp: + labels = json.load(fp) + + ent2id_path = os.path.join(args.data_dir, 'mid_data') + with open(os.path.join(ent2id_path, 'nor_ent2id.json'), encoding='utf-8') as f: + ent2id = json.load(f) + id2ent = {v: k for k, v in ent2id.items()} + + mid_data_path = os.path.join(args.data_dir, 'mid_data') + processor = NerProcessor(cut_sent=True, cut_sent_len=args.max_seq_len) + + train_data = get_data(processor, mid_data_path, "train.json", "train", ent2id, labels, args) + save_file(os.path.join(mid_data_path,"chip_{}_cut.txt".format(args.max_seq_len)), train_data, id2ent) + dev_data = get_data(processor, mid_data_path, "dev.json", "dev", ent2id, labels, args) + # test_data = get_data(processor, mid_data_path, "test.json", "test", ent2id, labels, args)