## TENER: Adapting Transformer Encoder for Named Entity Recognition

This is the code for the paper [TENER](https://arxiv.org/abs/1911.04474).

TENER (Transformer Encoder for Named Entity Recognition) is a Transformer-based model for
the NER task. Compared with the vanilla Transformer, we found that relative position
embedding is quite important for NER. Experiments on English and Chinese NER datasets
demonstrate its effectiveness.

#### Requirements

This project depends on the natural language processing package
[fastNLP](https://github.com/fastnlp/fastNLP). You can install it with the following command:

```bash
pip install fastNLP
```

#### Run the code

(1) Prepare the English datasets.

##### Conll2003

Your file should look like the following (the first token on each line is the word and the
last token is the NER tag):
```
LONDON NNP B-NP B-LOC
1996-08-30 CD I-NP O
West NNP B-NP B-MISC
Indian NNP I-NP I-MISC
all-rounder NN I-NP O
Phil NNP I-NP B-PER
```

##### OntoNotes

We suggest using [OntoNotes-5.0-NER](https://github.com/yhcc/OntoNotes-5.0-NER) to prepare
your data. Alternatively, you can prepare the data in the Conll2003 style and replace
OntoNotesNERPipe with Conll2003NERPipe in the code, as sketched below.
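A rough sketch of that swap, assuming fastNLP's `Conll2003NERPipe`/`OntoNotesNERPipe` and
placeholder paths (the actual call lives in the training script, and the `encoding_type`
value is an assumption):

```python
# Hypothetical sketch: both pipes return a fastNLP DataBundle with
# 'words' and 'target' vocabularies; the paths below are placeholders.
from fastNLP.io import Conll2003NERPipe, OntoNotesNERPipe

# OntoNotes prepared with the OntoNotes-5.0-NER scripts:
data_bundle = OntoNotesNERPipe(encoding_type='bioes').process_from_file('path/to/en-ontonotes')
# ...or the same data written in the Conll2003 column format:
data_bundle = Conll2003NERPipe(encoding_type='bioes').process_from_file('path/to/conll2003-style')
```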

For the English datasets, we use the GloVe 100d pretrained embeddings; fastNLP will
download them automatically.
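For reference, a minimal sketch of how such an embedding is usually built with fastNLP (the
model name string and the `data_bundle` variable are assumptions; the training scripts do
this for you):

```python
# Sketch only: `data_bundle` is assumed to come from one of the NER pipes above.
from fastNLP.embeddings import StaticEmbedding

word_embed = StaticEmbedding(data_bundle.get_vocab('words'),
                             model_dir_or_name='en-glove-6b-100d')  # fetched on first use
```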

You can run the models with the following commands (make sure you have changed the data
path):

```
python train_tener_en.py --dataset conll2003
```
or
```
python train_tener_en.py --dataset en-ontonotes
```

Although we tried hard to make sure you can reproduce our results, your numbers may still
fall short. This usually happens because the best dev performance does not correlate well
with the test performance; running the experiment several times should help.

For the ELMo version, fastNLP will download the ELMo weights automatically; you just need
to change the data path in train_elmo_en.py:

```
python train_elmo_en.py --dataset en-ontonotes
```
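A hedged sketch of what that script does with the embedding, assuming fastNLP's
`ElmoEmbedding` (the model name string 'en-original' and the `requires_grad` setting are
assumptions, not taken from this repository):

```python
# Sketch: fastNLP downloads and caches the ELMo weights on first use.
from fastNLP.embeddings import ElmoEmbedding

elmo_embed = ElmoEmbedding(data_bundle.get_vocab('words'),
                           model_dir_or_name='en-original',
                           requires_grad=False)
```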

(2) Prepare the Chinese datasets.

##### MSRA, OntoNotes4.0, Weibo, Resume

Your data should have only two columns: the first is the character and the second is the
tag, like the following:
```
口 O
腔 O
溃 O
疡 O
加 O
上 O
```

For the Chinese datasets, you can download the pretrained unigram and bigram embeddings
from [Baidu Cloud](https://pan.baidu.com/s/1pLO6T9D#list/path=%2Fsharelink808087924-1080546002081577%2FNeuralSegmentation&parentPath=%2Fsharelink808087924-1080546002081577).
Download 'gigaword_chn.all.a2b.uni.iter50.vec' and 'gigaword_chn.all.a2b.bi.iter50.vec',
then replace the embedding path in train_tener_cn.py, as sketched below.
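A hedged sketch of what that replacement amounts to, assuming fastNLP's `StaticEmbedding`
and placeholder paths and vocabulary names (check train_tener_cn.py for the actual variable
names):

```python
# Sketch only: the paths point to wherever you saved the downloaded vectors,
# and the vocabulary names 'chars'/'bigrams' are assumptions about the data pipe.
from fastNLP.embeddings import StaticEmbedding

uni_embed = StaticEmbedding(data_bundle.get_vocab('chars'),
                            model_dir_or_name='path/to/gigaword_chn.all.a2b.uni.iter50.vec')
bi_embed = StaticEmbedding(data_bundle.get_vocab('bigrams'),
                           model_dir_or_name='path/to/gigaword_chn.all.a2b.bi.iter50.vec')
```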

You can run the code with the following command:

```
python train_tener_cn.py --dataset ontonotes
```
from fastNLP.modules import ConditionalRandomField, allowed_transitions
from modules.transformer import TransformerEncoder

from torch import nn
import torch
import torch.nn.functional as F

class TENER(nn.Module):
    def __init__(self, tag_vocab, embed, num_layers, d_model, n_head, feedforward_dim, dropout,
                 after_norm=True, attn_type='adatrans', bi_embed=None,
                 fc_dropout=0.3, pos_embed=None, scale=False, dropout_attn=None):
        """
        :param tag_vocab: fastNLP Vocabulary
        :param embed: fastNLP TokenEmbedding
        :param num_layers: number of self-attention layers
        :param d_model: input size of the Transformer
        :param n_head: number of attention heads
        :param feedforward_dim: dimension of the feed-forward network
        :param dropout: dropout in self-attention
        :param after_norm: where to place layer normalization
        :param attn_type: 'adatrans' or 'naive'
        :param pos_embed: type of position embedding; supports 'sin', 'fix' or None.
            Can be None when the attention is relative.
        :param bi_embed: bigram embedding, used in the Chinese scenario
        :param fc_dropout: dropout rate before the fc layer
        """
        super().__init__()

        self.embed = embed
        embed_size = self.embed.embed_size
        self.bi_embed = None
        if bi_embed is not None:  # concatenate the bigram embedding in the Chinese setting
            self.bi_embed = bi_embed
            embed_size += self.bi_embed.embed_size

        self.in_fc = nn.Linear(embed_size, d_model)

        self.transformer = TransformerEncoder(num_layers, d_model, n_head, feedforward_dim, dropout,
                                              after_norm=after_norm, attn_type=attn_type,
                                              scale=scale, dropout_attn=dropout_attn,
                                              pos_embed=pos_embed)
        self.fc_dropout = nn.Dropout(fc_dropout)
        self.out_fc = nn.Linear(d_model, len(tag_vocab))

        # CRF layer whose transitions are constrained by the tag scheme
        trans = allowed_transitions(tag_vocab, include_start_end=True)
        self.crf = ConditionalRandomField(len(tag_vocab), include_start_end_trans=True,
                                          allowed_transitions=trans)

    def _forward(self, chars, target, bigrams=None):
        mask = chars.ne(0)  # True at non-padding positions (pad index is 0)
        chars = self.embed(chars)
        if self.bi_embed is not None:
            bigrams = self.bi_embed(bigrams)
            chars = torch.cat([chars, bigrams], dim=-1)

        chars = self.in_fc(chars)
        chars = self.transformer(chars, mask)
        chars = self.fc_dropout(chars)
        chars = self.out_fc(chars)
        logits = F.log_softmax(chars, dim=-1)
        if target is None:
            # inference: Viterbi decoding over the CRF
            paths, _ = self.crf.viterbi_decode(logits, mask)
            return {'pred': paths}
        else:
            # training: negative log-likelihood of the gold tag sequence
            loss = self.crf(logits, target, mask)
            return {'loss': loss}

    def forward(self, chars, target, bigrams=None):
        return self._forward(chars, target, bigrams)

    def predict(self, chars, bigrams=None):
        return self._forward(chars, target=None, bigrams=bigrams)
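To make the constructor arguments concrete, here is a hedged sketch of instantiating the
model; the data bundle, embedding and hyper-parameter values are placeholders for
illustration, not the exact settings used in the paper (those live in the train_tener_*.py
scripts):

```python
# Sketch: `data_bundle` comes from a fastNLP NER pipe and `word_embed` from
# StaticEmbedding (see the README); all hyper-parameters here are illustrative.
model = TENER(tag_vocab=data_bundle.get_vocab('target'),
              embed=word_embed,
              num_layers=2,            # number of self-attention layers
              d_model=256,             # Transformer hidden size
              n_head=4,                # attention heads
              feedforward_dim=1024,    # FFN inner dimension
              dropout=0.15,
              attn_type='adatrans')    # the adapted (relative) attention variant
```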
from fastNLP.embeddings import TokenEmbedding
import torch
from fastNLP import Vocabulary
import torch.nn.functional as F
from fastNLP import logger
from fastNLP.embeddings.utils import _construct_char_vocab_from_vocab, get_embeddings
from torch import nn
from .transformer import TransformerEncoder

class TransformerCharEmbed(TokenEmbedding):
    def __init__(self, vocab: Vocabulary, embed_size: int = 30, char_emb_size: int = 30, word_dropout: float = 0,
                 dropout: float = 0, pool_method: str = 'max', activation='relu',
                 min_char_freq: int = 2, requires_grad=True, include_word_start_end=True,
                 char_attn_type='adatrans', char_n_head=3, char_dim_ffn=60, char_scale=False, char_pos_embed=None,
                 char_dropout=0.15, char_after_norm=False):
        """
        :param vocab: the word vocabulary
        :param embed_size: output dimension of TransformerCharEmbed. Defaults to 30.
        :param char_emb_size: dimension of the character embedding, which is also the d_model of the
            Transformer. Defaults to 30.
        :param float word_dropout: probability of replacing a word with unk, which both trains the unk
            embedding and acts as a regularizer.
        :param dropout: dropout probability applied to the character embeddings and the final word output.
        :param pool_method: 'max' or 'avg'.
        :param activation: activation function; supports 'relu', 'sigmoid', 'tanh', or a custom callable.
        :param min_char_freq: minimum character frequency. Defaults to 2.
        :param requires_grad: whether the parameters are trainable.
        :param include_word_start_end: whether to mark the start and end of each word with special tags.
        :param char_attn_type: 'adatrans' or 'naive'.
        :param char_n_head: number of attention heads.
        :param char_dim_ffn: size of the FFN hidden layer in the Transformer.
        :param char_scale: whether to scale the attention scores.
        :param char_pos_embed: None, 'fix' or 'sin'; the kind of position embedding. None is fine when
            char_attn_type is relative.
        :param char_dropout: dropout in the Transformer encoder.
        :param char_after_norm: where to place layer normalization.
        """
        super(TransformerCharEmbed, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)

        assert char_emb_size % char_n_head == 0, "char_emb_size must be divisible by char_n_head."

        assert pool_method in ('max', 'avg')
        self.pool_method = pool_method
        # activation function
        if isinstance(activation, str):
            if activation.lower() == 'relu':
                self.activation = F.relu
            elif activation.lower() == 'sigmoid':
                self.activation = F.sigmoid
            elif activation.lower() == 'tanh':
                self.activation = F.tanh
        elif activation is None:
            self.activation = lambda x: x
        elif callable(activation):
            self.activation = activation
        else:
            raise Exception(
                "Undefined activation function: choose from: [relu, tanh, sigmoid, or a callable function]")

        logger.info("Start constructing character vocabulary.")
        # build the character vocabulary
        self.char_vocab = _construct_char_vocab_from_vocab(vocab, min_freq=min_char_freq,
                                                           include_word_start_end=include_word_start_end)
        self.char_pad_index = self.char_vocab.padding_idx
        logger.info(f"In total, there are {len(self.char_vocab)} distinct characters.")
        # index each word in the vocabulary as a sequence of character ids
        max_word_len = max(map(lambda x: len(x[0]), vocab))
        if include_word_start_end:
            max_word_len += 2
        self.register_buffer('words_to_chars_embedding', torch.full((len(vocab), max_word_len),
                                                                    fill_value=self.char_pad_index, dtype=torch.long))
        self.register_buffer('word_lengths', torch.zeros(len(vocab)).long())
        for word, index in vocab:
            # if index != vocab.padding_idx:  # the pad word would simply stay at the pad value;
            # changed so that pad is not treated specially
            if include_word_start_end:
                word = ['<bow>'] + list(word) + ['<eow>']
            self.words_to_chars_embedding[index, :len(word)] = \
                torch.LongTensor([self.char_vocab.to_index(c) for c in word])
            self.word_lengths[index] = len(word)

        self.char_embedding = get_embeddings((len(self.char_vocab), char_emb_size))
        self.transformer = TransformerEncoder(1, char_emb_size, char_n_head, char_dim_ffn, dropout=char_dropout,
                                              after_norm=char_after_norm, attn_type=char_attn_type,
                                              pos_embed=char_pos_embed, scale=char_scale)
        self.fc = nn.Linear(char_emb_size, embed_size)

        self._embed_size = embed_size

        self.requires_grad = requires_grad

    def forward(self, words):
        """
        Given word indices, produce the corresponding word representations.

        :param words: [batch_size, max_len]
        :return: [batch_size, max_len, embed_size]
        """
        words = self.drop_word(words)
        batch_size, max_len = words.size()
        chars = self.words_to_chars_embedding[words]  # batch_size x max_len x max_word_len
        word_lengths = self.word_lengths[words]  # batch_size x max_len
        max_word_len = word_lengths.max()
        chars = chars[:, :, :max_word_len]
        # True at padding positions
        chars_masks = chars.eq(self.char_pad_index)  # batch_size x max_len x max_word_len
        char_embeds = self.char_embedding(chars)  # batch_size x max_len x max_word_len x embed_size
        char_embeds = self.dropout(char_embeds)
        reshaped_chars = char_embeds.reshape(batch_size * max_len, max_word_len, -1)

        # the Transformer expects a mask that is True at non-padding positions
        trans_chars = self.transformer(reshaped_chars, chars_masks.eq(0).reshape(-1, max_word_len))
        trans_chars = trans_chars.reshape(batch_size, max_len, max_word_len, -1)
        trans_chars = self.activation(trans_chars)
        if self.pool_method == 'max':
            trans_chars = trans_chars.masked_fill(chars_masks.unsqueeze(-1), float('-inf'))
            chars, _ = torch.max(trans_chars, dim=-2)  # batch_size x max_len x H
        else:
            trans_chars = trans_chars.masked_fill(chars_masks.unsqueeze(-1), 0)
            chars = torch.sum(trans_chars, dim=-2) / chars_masks.eq(0).sum(dim=-1, keepdim=True).float()

        chars = self.fc(chars)

        return self.dropout(chars)
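A hedged sketch of how this character encoder is typically combined with a word-level
embedding via fastNLP's StackEmbedding; the GloVe model name and the word vocabulary are
assumptions for illustration:

```python
# Sketch: `vocab` is the word-level fastNLP Vocabulary, e.g. data_bundle.get_vocab('words').
from fastNLP.embeddings import StackEmbedding, StaticEmbedding

char_embed = TransformerCharEmbed(vocab, embed_size=30, char_emb_size=30,
                                  char_n_head=3, char_attn_type='adatrans')
word_embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-100d')
embed = StackEmbedding([word_embed, char_embed])  # concatenates the two along the last dimension
```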