Skip to content

Commit

Permalink
Update preprocess.py
Browse files Browse the repository at this point in the history
  • Loading branch information
taishan1994 authored Jul 29, 2022
1 parent 1c02b52 commit 562bf73
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,3 +280,24 @@ def save_file(filename, data ,id2ent):
save_file(os.path.join(mid_data_path,"clue_{}_cut.txt".format(args.max_seq_len)), train_data, id2ent)
dev_data = get_data(processor, mid_data_path, "dev.json", "dev", ent2id, labels, args)
# test_data = get_data(processor, mid_data_path, "test.json", "test", ent2id, labels, args)

elif dataset == "addr":
args.data_dir = './data/addr'
args.max_seq_len = 64

labels_path = os.path.join(args.data_dir, 'mid_data', 'labels.json')
with open(labels_path, 'r') as fp:
labels = json.load(fp)

ent2id_path = os.path.join(args.data_dir, 'mid_data')
with open(os.path.join(ent2id_path, 'nor_ent2id.json'), encoding='utf-8') as f:
ent2id = json.load(f)
id2ent = {v: k for k, v in ent2id.items()}

mid_data_path = os.path.join(args.data_dir, 'mid_data')
processor = NerProcessor(cut_sent=True, cut_sent_len=args.max_seq_len)

train_data = get_data(processor, mid_data_path, "train.json", "train", ent2id, labels, args)
save_file(os.path.join(mid_data_path,"clue_{}_cut.txt".format(args.max_seq_len)), train_data, id2ent)
dev_data = get_data(processor, mid_data_path, "dev.json", "dev", ent2id, labels, args)
# test_data = get_data(processor, mid_data_path, "test.json", "test", ent2id, labels, args)

0 comments on commit 562bf73

Please sign in to comment.