
Commit: Update preprocess.py
taishan1994 authored Jul 29, 2022
1 parent 5082095 commit d355d50
Showing 1 changed file with 28 additions and 4 deletions.
preprocess.py
@@ -126,9 +126,11 @@ def convert_bert_example(ex_idx, example: InputExample, tokenizer: BertTokenizer
     label_ids = label_ids + [0] * pad_length  # CLS, SEP and PAD labels are all O
 
     assert len(label_ids) == max_seq_len, f'{len(label_ids)}'
 
     # ========================
     encode_dict = tokenizer.encode_plus(text=tokens,
                                         max_length=max_seq_len,
+                                        padding="max_length",
+                                        truncation='longest_first',
                                         return_token_type_ids=True,
                                         return_attention_mask=True)
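This hunk pins the tokenizer output to a fixed length: with padding="max_length" and truncation enabled, encode_plus returns exactly max_seq_len ids, which keeps input_ids, attention_mask and token_type_ids aligned with the label_ids padded above. A minimal sketch of that guarantee, using the public "bert-base-chinese" checkpoint as an assumed stand-in for the repo's local chinese-bert-wwm-ext directory:

from transformers import BertTokenizer

# "bert-base-chinese" is an assumed stand-in checkpoint, not the
# repo's '../model_hub/chinese-bert-wwm-ext/' path.
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")

max_seq_len = 16
tokens = list("今天天气很好")  # character-level tokens, as the repo uses

encode_dict = tokenizer.encode_plus(text=tokens,
                                    max_length=max_seq_len,
                                    padding="max_length",
                                    truncation='longest_first',
                                    return_token_type_ids=True,
                                    return_attention_mask=True)

# Every returned sequence is now exactly max_seq_len long, so it lines
# up with label_ids after the `[0] * pad_length` padding above.
assert len(encode_dict['input_ids']) == max_seq_len
assert len(encode_dict['attention_mask']) == max_seq_len
assert len(encode_dict['token_type_ids']) == max_seq_len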
@@ -139,8 +141,8 @@ def convert_bert_example(ex_idx, example: InputExample, tokenizer: BertTokenizer
 
     if ex_idx < 3:
         logger.info(f"*** {set_type}_example-{ex_idx} ***")
-        print(tokenizer.decode(token_ids[:len(raw_text)]))
-        logger.info(f'text: {" ".join(tokens)}')
+        print(tokenizer.decode(token_ids[:len(raw_text)+2]))
+        logger.info(f'text: {str(" ".join(tokens))}')
         logger.info(f"token_ids: {token_ids}")
         logger.info(f"attention_masks: {attention_masks}")
         logger.info(f"token_type_ids: {token_type_ids}")
@@ -167,6 +169,9 @@ def convert_examples_to_features(examples, max_seq_len, bert_dir, ent2id, labels
     logger.info(f'Convert {len(examples)} examples to features')
 
     for i, example in enumerate(examples):
+        # the text may be empty, so filter such examples out
+        if not example.text:
+            continue
         feature, tmp_callback = convert_bert_example(
             ex_idx=i,
             example=example,
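The new guard skips examples with empty text before feature conversion; without it an empty example would still be encoded into a feature of nothing but special tokens and padding. A small sketch of the predicate, with InputExample reduced to a hypothetical stub that carries only the .text field used here:

from dataclasses import dataclass

@dataclass
class InputExample:  # hypothetical stub; the repo's class has more fields
    text: str = None

examples = [InputExample("今天天气很好"), InputExample(""), InputExample(None)]

# Same predicate as `if not example.text: continue` in the loop above:
kept = [ex for ex in examples if ex.text]
print(len(kept))  # -> 1; both the empty string and None are filtered out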
@@ -212,7 +217,7 @@ def save_file(filename, data ,id2ent):
 
 if __name__ == '__main__':
 
-    dataset = "clue"
+    dataset = "attr"
     args = config.Args().get_parser()
     args.bert_dir = '../model_hub/chinese-bert-wwm-ext/'
     commonUtils.set_logger(os.path.join(args.log_dir, 'preprocess.log'))
@@ -279,7 +284,6 @@ def save_file(filename, data ,id2ent):
         save_file(os.path.join(mid_data_path,"clue_{}_cut.txt".format(args.max_seq_len)), train_data, id2ent)
         dev_data = get_data(processor, mid_data_path, "dev.json", "dev", ent2id, labels, args)
         # test_data = get_data(processor, mid_data_path, "test.json", "test", ent2id, labels, args)
-
     elif dataset == "addr":
         args.data_dir = './data/addr'
         args.max_seq_len = 64
@@ -296,6 +300,26 @@ def save_file(filename, data ,id2ent):
         mid_data_path = os.path.join(args.data_dir, 'mid_data')
         processor = NerProcessor(cut_sent=True, cut_sent_len=args.max_seq_len)
 
         train_data = get_data(processor, mid_data_path, "train.json", "train", ent2id, labels, args)
         save_file(os.path.join(mid_data_path,"clue_{}_cut.txt".format(args.max_seq_len)), train_data, id2ent)
         dev_data = get_data(processor, mid_data_path, "dev.json", "dev", ent2id, labels, args)
+        # test_data = get_data(processor, mid_data_path, "test.json", "test", ent2id, labels, args)
+    elif dataset == "attr":
+        args.data_dir = './data/attr'
+        args.max_seq_len = 128
+
+        labels_path = os.path.join(args.data_dir, 'mid_data', 'labels.json')
+        with open(labels_path, 'r') as fp:
+            labels = json.load(fp)
+
+        ent2id_path = os.path.join(args.data_dir, 'mid_data')
+        with open(os.path.join(ent2id_path, 'nor_ent2id.json'), encoding='utf-8') as f:
+            ent2id = json.load(f)
+        id2ent = {v: k for k, v in ent2id.items()}
+
+        mid_data_path = os.path.join(args.data_dir, 'mid_data')
+        processor = NerProcessor(cut_sent=True, cut_sent_len=args.max_seq_len)
+
+        train_data = get_data(processor, mid_data_path, "train.json", "train", ent2id, labels, args)
+        save_file(os.path.join(mid_data_path,"clue_{}_cut.txt".format(args.max_seq_len)), train_data, id2ent)
+        dev_data = get_data(processor, mid_data_path, "dev.json", "dev", ent2id, labels, args)
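The added "attr" branch mirrors the existing "clue" and "addr" branches, down to reusing the clue_{}_cut.txt filename in save_file, and expects the same mid_data layout under ./data/attr/. A hedged sketch of the two JSON files it loads; the attribute labels below are illustrative placeholders, not the repository's actual schema:

import json
import os

mid_data_path = "./data/attr/mid_data"
os.makedirs(mid_data_path, exist_ok=True)

# labels.json: a flat list of entity types (placeholder values)
with open(os.path.join(mid_data_path, "labels.json"), "w", encoding="utf-8") as fp:
    json.dump(["品牌", "型号"], fp, ensure_ascii=False)

# nor_ent2id.json: BIO-style tag -> id mapping (placeholder values)
with open(os.path.join(mid_data_path, "nor_ent2id.json"), "w", encoding="utf-8") as fp:
    json.dump({"O": 0, "B-品牌": 1, "I-品牌": 2, "B-型号": 3, "I-型号": 4}, fp, ensure_ascii=False)

with open(os.path.join(mid_data_path, "nor_ent2id.json"), encoding="utf-8") as f:
    ent2id = json.load(f)
id2ent = {v: k for k, v in ent2id.items()}  # same inversion as in the diff
print(id2ent[0], id2ent[1])  # -> O B-品牌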
