Skip to content

Commit

Permalink
Update preprocess.py
Browse files Browse the repository at this point in the history
  • Loading branch information
taishan1994 authored Jun 15, 2022
1 parent da32f45 commit 2469b1e
Showing 1 changed file with 24 additions and 3 deletions.
27 changes: 24 additions & 3 deletions preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,9 @@ def save_file(filename, data ,id2ent):

if __name__ == '__main__':

dataset = "c"
dataset = "chip"
args = config.Args().get_parser()
args.bert_dir = '../model_hub/bert-base-chinese/'
args.bert_dir = '../model_hub/chinese-bert-wwm-ext/'
commonUtils.set_logger(os.path.join(args.log_dir, 'preprocess.log'))

if dataset == "c":
Expand All @@ -237,4 +237,25 @@ def save_file(filename, data ,id2ent):
train_data = get_data(processor, mid_data_path, "train.json", "train", ent2id, labels, args)
save_file(os.path.join(mid_data_path,"cner_{}_cut.txt".format(args.max_seq_len)), train_data, id2ent)
dev_data = get_data(processor, mid_data_path, "dev.json", "dev", ent2id, labels, args)
test_data = get_data(processor, mid_data_path, "test.json", "test", ent2id, labels, args)
test_data = get_data(processor, mid_data_path, "test.json", "test", ent2id, labels, args)

elif dataset == "chip":
args.data_dir = './data/CHIP2020'
args.max_seq_len = 150

labels_path = os.path.join(args.data_dir, 'mid_data', 'labels.json')
with open(labels_path, 'r') as fp:
labels = json.load(fp)

ent2id_path = os.path.join(args.data_dir, 'mid_data')
with open(os.path.join(ent2id_path, 'nor_ent2id.json'), encoding='utf-8') as f:
ent2id = json.load(f)
id2ent = {v: k for k, v in ent2id.items()}

mid_data_path = os.path.join(args.data_dir, 'mid_data')
processor = NerProcessor(cut_sent=True, cut_sent_len=args.max_seq_len)

train_data = get_data(processor, mid_data_path, "train.json", "train", ent2id, labels, args)
save_file(os.path.join(mid_data_path,"chip_{}_cut.txt".format(args.max_seq_len)), train_data, id2ent)
dev_data = get_data(processor, mid_data_path, "dev.json", "dev", ent2id, labels, args)
# test_data = get_data(processor, mid_data_path, "test.json", "test", ent2id, labels, args)

0 comments on commit 2469b1e

Please sign in to comment.