Utils.py
import re

import torch
from tqdm import tqdm


class Utils:
    def __init__(self):
        pass
    @staticmethod
    def remove_html(line):
        """
        Strip HTML tags from a line with a regular expression.
        """
        reg = re.compile('<[^>]*>')
        content = reg.sub('', line)
        return content.strip()
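    # Usage sketch (not part of the original file; assumes the method is
    # called on the class, since it takes no self):
    #   Utils.remove_html("<p>Hello <b>world</b></p>")  ->  "Hello world"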
    @staticmethod
    def logger():
        """
        Logging setup: write records to a timestamped file under ./log/.
        """
        import datetime
        import logging
        import os
        # Step 1: get the root logger.
        logger = logging.getLogger()
        logger.setLevel(logging.DEBUG)
        now_date = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        # Step 2: create a handler that writes to a log file.
        os.makedirs("./log", exist_ok=True)  # make sure the directory exists
        file_handler = logging.FileHandler("./log/" + now_date + ".log", mode='w')
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(
            logging.Formatter(
                fmt='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s',
                datefmt='%Y-%m-%d %H:%M:%S')
        )
        # Attach the handler to the logger.
        logger.addHandler(file_handler)
        # Step 3 (optional): a handler that also echoes to the console.
        # console_handler = logging.StreamHandler()
        # console_handler.setLevel(logging.INFO)
        # console_handler.setFormatter(
        #     logging.Formatter(
        #         fmt='%(asctime)s - %(levelname)s: %(message)s',
        #         datefmt='%Y-%m-%d %H:%M:%S')
        # )
        # logger.addHandler(console_handler)
        # logger.critical('this is a logger critical message')
        logger.info("XXXX model")
        return logger
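    # Usage sketch (assumption: ./log/ is writable; the handler is attached
    # to the root logger, so logging.getLogger() picks it up as well):
    #   log = Utils.logger()
    #   log.info("training started")   # written to ./log/<timestamp>.log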
    @staticmethod
    def extract_entity(chars, tags):
        """
        Return the entities in a sentence given its BIO tags.
        chars: one sentence, e.g. ["CLS","张","三","是","我","们","班","主","任","SEP"]
        tags: the label list, e.g. ["O","B-PER","I-PER","O","O","O","B-PER","I-PER","I-PER","O"]
        Returns the entities in the sentence: [['张三', 'PER'], ['班主任', 'PER']]
        """
        result = []
        pre = ''
        w = []
        for idx, tag in enumerate(tags):
            if not pre:
                if tag.startswith('B'):
                    pre = tag.split('-')[1]  # entity type, e.g. 'PER'
                    w.append(chars[idx])     # w = ['张']
            else:
                if tag == f'I-{pre}':        # continuation of the current entity
                    w.append(chars[idx])     # w = ['张', '三']
                else:
                    result.append([w, pre])
                    w = []
                    pre = ''
                    if tag.startswith('B'):  # a new entity starts immediately
                        pre = tag.split('-')[1]
                        w.append(chars[idx])
        if w:  # flush an entity that runs to the end of the sentence
            result.append([w, pre])
        return [[''.join(x[0]), x[1]] for x in result]
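    # Usage sketch, reusing the docstring's own example (not in the original
    # file):
    #   chars = ["CLS","张","三","是","我","们","班","主","任","SEP"]
    #   tags  = ["O","B-PER","I-PER","O","O","O","B-PER","I-PER","I-PER","O"]
    #   Utils.extract_entity(chars, tags)  ->  [['张三', 'PER'], ['班主任', 'PER']]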
    @staticmethod
    def extract_entity_with_span(chars, tags):
        """
        Like extract_entity, but also returns each entity's token span.
        chars: one sentence, e.g. ["CLS","张","三","是","我","们","班","主","任","SEP"]
        tags: the label list, e.g. ["O","B-PER","I-PER","O","O","O","B-PER","I-PER","I-PER","O"]
        Returns: [{'tokens': '张三', 'type': 'PER', 'start': 1, 'end': 2},
                  {'tokens': '班主任', 'type': 'PER', 'start': 6, 'end': 8}]
        """
        result = []
        pre = ''
        w = []
        s_index = []
        e_index = []
        for idx, tag in enumerate(tags):
            if not pre:
                if tag.startswith('B'):
                    pre = tag.split('-')[1]  # entity type, e.g. 'PER'
                    w.append(chars[idx])     # w = ['张']
                    s_index.append(idx)      # start index, e.g. 1
                    e_index.append(idx)      # a one-token entity ends where it starts
            else:
                if tag == f'I-{pre}':        # continuation of the current entity
                    w.append(chars[idx])     # w = ['张'] -> ['张', '三']
                    e_index.append(idx)      # end index, e.g. 2
                else:
                    result.append([w, pre, s_index, e_index])
                    w = []
                    pre = ''
                    s_index = []
                    e_index = []
                    if tag.startswith('B'):  # a new entity starts immediately
                        pre = tag.split('-')[1]
                        w.append(chars[idx])
                        s_index.append(idx)
                        e_index.append(idx)
        if w:  # flush an entity that runs to the end of the sentence
            result.append([w, pre, s_index, e_index])
        res = []
        for x in result:
            item = {"tokens": ''.join(x[0]), "type": x[1], "start": x[2][0], "end": x[3][-1]}
            res.append(item)
        return res
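    # Usage sketch (same inputs as above; start/end are token indices into
    # chars, inclusive):
    #   Utils.extract_entity_with_span(chars, tags)
    #   -> [{'tokens': '张三', 'type': 'PER', 'start': 1, 'end': 2},
    #       {'tokens': '班主任', 'type': 'PER', 'start': 6, 'end': 8}]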
    @staticmethod
    def read_CoNLL(filename):
        """
        Read a CoNLL-format file and return the sentences and their label sequences.
        """
        X, y = [], []
        with open(filename, 'r', encoding='utf-8') as f:
            x0, y0 = [], []
            for line in f:
                data = line.strip()
                if data:
                    token, label = data.split()[0], data.split()[1]
                    x0.append(token)
                    y0.append(label)
                else:
                    if len(x0) != 0:
                        X.append(x0)
                        y.append(y0)
                    x0, y0 = [], []
            if len(x0) != 0:  # flush the last sentence if the file has no trailing blank line
                X.append(x0)
                y.append(y0)
        return X, y
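    # Expected input format: one token and one label per line, sentences
    # separated by blank lines, e.g.
    #   张 B-PER
    #   三 I-PER
    #   是 O
    # Usage sketch ("train.txt" is a hypothetical path, not from the
    # original file):
    #   X, y = Utils.read_CoNLL("train.txt")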
    @staticmethod
    def encode_plus(tokenizer, sequence):
        """
        Encode a single sentence (a list of words/characters) into input ids.
        """
        # sequence: ["中", "国", "的", "首", "都", "是", "北", "京"]
        input_ids = []
        pred_mask = []
        # WordPiece may split a word into sub-tokens; only the first sub-token
        # of each word is used for prediction.
        for word in sequence:
            sub_tokens_ids = tokenizer.encode(word, add_special_tokens=False)
            input_ids = input_ids + sub_tokens_ids
            pred_mask = pred_mask + [1] + [0 for _ in range(len(sub_tokens_ids) - 1)]
        assert len(input_ids) == len(pred_mask)
        return input_ids, pred_mask
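    # Usage sketch (assumption: tokenizer is a HuggingFace transformers
    # tokenizer, e.g. BertTokenizer.from_pretrained("bert-base-chinese")):
    #   ids, mask = Utils.encode_plus(tokenizer, ["中", "国"])
    #   # len(ids) == len(mask); mask is 1 on the first sub-token of each word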
    @staticmethod
    def sequence_padding_for_bert(X, y, tokenizer, labels, max_len):
        """
        Encode and pad a batch of sentences for BERT.
        """
        input_ids_list = []
        attention_mask_list = []
        pred_mask_list = []
        input_labels_list = []
        cls_id = tokenizer.convert_tokens_to_ids("[CLS]")
        sep_id = tokenizer.convert_tokens_to_ids("[SEP]")
        pad_id = tokenizer.convert_tokens_to_ids("[PAD]")
        for i, sequence in tqdm(enumerate(X)):
            # get input_ids, pred_mask
            input_ids, pred_mask = Utils.encode_plus(tokenizer, sequence)
            attention_mask = [1] * len(input_ids)
            # truncate to max_len - 2, add [CLS]/[SEP], then pad with [PAD]
            input_ids = [cls_id] + input_ids[:max_len - 2] + [sep_id] + [pad_id] * (max_len - len(input_ids) - 2)
            pred_mask = [0] + pred_mask[:max_len - 2] + [0] + [0] * (max_len - len(pred_mask) - 2)
            # pad the attention mask the same way
            attention_mask = [1] + attention_mask[:max_len - 2] + [1] + [0] * (max_len - len(attention_mask) - 2)
            # map the labels of the kept words to label ids; sub-token and
            # padding positions fall back to the "O" label
            sequence_labels = [labels.index(l) for l in y[i][:sum(pred_mask)]]
            sequence_labels = sequence_labels[::-1]
            input_labels = [sequence_labels.pop() if pred_mask[j] == 1 else labels.index("O")
                            for j in range(len(pred_mask))]
            input_ids_list.append(input_ids)
            attention_mask_list.append(attention_mask)
            pred_mask_list.append(pred_mask)
            input_labels_list.append(input_labels)
        return torch.LongTensor(input_ids_list), \
               torch.ByteTensor(attention_mask_list), \
               torch.ByteTensor(pred_mask_list), \
               torch.LongTensor(input_labels_list)
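    # Usage sketch (assumption: labels is the full tag inventory, e.g.
    # ["O", "B-PER", "I-PER", ...], and X, y come from read_CoNLL):
    #   input_ids, attn, pred_mask, label_ids = Utils.sequence_padding_for_bert(
    #       X, y, tokenizer, labels, max_len=128)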
    @staticmethod
    def q():
        """Placeholder; currently unused."""
        pass
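

# Minimal runnable demo of the pure-Python helpers above (added for
# illustration; the sample strings and tags are made up, not taken from the
# original file).
if __name__ == "__main__":
    print(Utils.remove_html("<p>Hello <b>world</b></p>"))  # -> Hello world
    chars = ["CLS", "张", "三", "是", "我", "们", "班", "主", "任", "SEP"]
    tags = ["O", "B-PER", "I-PER", "O", "O", "O", "B-PER", "I-PER", "I-PER", "O"]
    print(Utils.extract_entity(chars, tags))            # [['张三', 'PER'], ['班主任', 'PER']]
    print(Utils.extract_entity_with_span(chars, tags))  # adds start/end token indices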