Commit c84d7a9: initial
yhcc committed Dec 10, 2019
Showing 14 changed files with 1,525 additions and 0 deletions.
102 changes: 102 additions & 0 deletions READ.md
## TENER: Adapting Transformer Encoder for Named Entity Recognition


This is the code for the paper [TENER](https://arxiv.org/abs/1911.04474).

TENER (Transformer Encoder for Named Entity Recognition) is a Transformer-based model
that aims to tackle the NER task. Compared with the vanilla Transformer, we
found that relative position embeddings are quite important for NER. Experiments
on English and Chinese NER datasets demonstrate its effectiveness.

#### Requirements
This project depends on the natural language processing Python package
[fastNLP](https://github.com/fastnlp/fastNLP). You can install it with
the following command:

```bash
pip install fastNLP
```

#### Run the code

(1) Prepare the English dataset.

##### Conll2003

Your file should look like the following (the first token on each line
is the word, and the last token is the NER tag):

```
LONDON NNP B-NP B-LOC
1996-08-30 CD I-NP O
West NNP B-NP B-MISC
Indian NNP I-NP I-MISC
all-rounder NN I-NP O
Phil NNP I-NP B-PER
```
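
If you want to sanity-check your files before training, the sketch below (not part of the repository) shows one way such a Conll2003-style file can be loaded with fastNLP's Conll2003NERPipe; the paths are placeholders.

```python
from fastNLP.io import Conll2003NERPipe

# Hypothetical locations of Conll2003 column-format files like the sample above.
paths = {
    'train': 'data/conll2003/train.txt',
    'dev': 'data/conll2003/dev.txt',
    'test': 'data/conll2003/test.txt',
}
# The pipe reads the word and NER-tag columns, converts the tags to BIOES,
# and returns an indexed DataBundle ready for training.
data_bundle = Conll2003NERPipe(encoding_type='bioes').process_from_file(paths)
print(data_bundle)
```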

##### OntoNotes

We suggest using the following repository to prepare your data:
[OntoNotes-5.0-NER](https://github.com/yhcc/OntoNotes-5.0-NER).
Alternatively, you can prepare the data in the Conll2003 style and then replace
OntoNotesNERPipe with Conll2003NERPipe in the code.

For the English datasets, we use the pretrained GloVe 100d embeddings. fastNLP will
download them automatically.
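
As a rough sketch of how such an embedding is usually constructed with fastNLP (the vocabulary below is a toy one, and the 'en-glove-6b-100d' alias is an assumption; check train_tener_en.py for the exact setting):

```python
from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding

# Toy vocabulary just to illustrate the call; in the training script the
# vocabulary comes from the processed DataBundle.
vocab = Vocabulary()
vocab.add_word_lst(['LONDON', 'West', 'Indian', 'all-rounder'])

# fastNLP downloads the referenced GloVe vectors on first use; the alias
# string below is an assumption, not taken from the repository.
embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-100d')
```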

You can run the training with one of the following commands (make sure you have changed the
data path):

```bash
python train_tener_en.py --dataset conll2003
```
or
```bash
python train_tener_en.py --dataset en-ontonotes
```

Although we tried hard to make our results reproducible, your results may still
fall short. This is usually because the best performance on the development set
does not correlate well with the test performance. Running the experiment several
times should help.

For the ELMo version (fastNLP will download the ELMo weights automatically; you only need
to change the data path in train_elmo_en.py):

```bash
python train_elmo_en.py --dataset en-ontonotes
```



##### MSRA, OntoNotes4.0, Weibo, Resume
Your data should have only two columns: the first is the character and
the second is the tag, like the following:
```
口 O
腔 O
溃 O
疡 O
加 O
上 O
```
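
The Chinese models also use character bigrams (the `bi_embed` argument in models/TENER.py). As a minimal illustration of how a bigram sequence is commonly derived from the character column above (the end-of-sequence placeholder is an assumption):

```python
def make_bigrams(chars, eos='</s>'):
    # Pair each character with its right neighbour; the last character is
    # paired with a placeholder end-of-sequence symbol.
    return [c + n for c, n in zip(chars, chars[1:] + [eos])]

chars = ['口', '腔', '溃', '疡', '加', '上']
print(make_bigrams(chars))
# ['口腔', '腔溃', '溃疡', '疡加', '加上', '上</s>']
```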

For the Chinese datasets, you can download the pretrained unigram and
bigram embeddings from [Baidu Cloud](https://pan.baidu.com/s/1pLO6T9D#list/path=%2Fsharelink808087924-1080546002081577%2FNeuralSegmentation&parentPath=%2Fsharelink808087924-1080546002081577).
Download 'gigaword_chn.all.a2b.uni.iter50.vec' and 'gigaword_chn.all.a2b.bi.iter50.vec',
then update the embedding paths in train_tener_cn.py.
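
As a sketch of what that replacement typically looks like (toy vocabularies and placeholder paths; the real ones are built in train_tener_cn.py):

```python
from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding

# Toy vocabularies for illustration only.
char_vocab, bigram_vocab = Vocabulary(), Vocabulary()
char_vocab.add_word_lst(list('口腔溃疡加上'))
bigram_vocab.add_word_lst(['口腔', '腔溃', '溃疡', '疡加', '加上'])

# Point model_dir_or_name at the files downloaded from Baidu Cloud
# (placeholder paths below).
uni_embed = StaticEmbedding(char_vocab,
                            model_dir_or_name='path/to/gigaword_chn.all.a2b.uni.iter50.vec')
bi_embed = StaticEmbedding(bigram_vocab,
                           model_dir_or_name='path/to/gigaword_chn.all.a2b.bi.iter50.vec')
```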

You can run the code with the following command:

```bash
python train_tener_cn.py --dataset ontonotes
```






74 changes: 74 additions & 0 deletions models/TENER.py


from fastNLP.modules import ConditionalRandomField, allowed_transitions
from modules.transformer import TransformerEncoder

from torch import nn
import torch
import torch.nn.functional as F


class TENER(nn.Module):
def __init__(self, tag_vocab, embed, num_layers, d_model, n_head, feedforward_dim, dropout,
after_norm=True, attn_type='adatrans', bi_embed=None,
fc_dropout=0.3, pos_embed=None, scale=False, dropout_attn=None):
"""
        :param tag_vocab: fastNLP Vocabulary
        :param embed: fastNLP TokenEmbedding
        :param num_layers: number of self-attention layers
        :param d_model: input size of the Transformer
        :param n_head: number of attention heads
        :param feedforward_dim: dimension of the feed-forward network
        :param dropout: dropout rate in self-attention
        :param after_norm: where to place the layer normalization
        :param attn_type: 'adatrans' or 'naive'
        :param pos_embed: type of position embedding; supports 'sin', 'fix', or None. May be None when relative attention is used
        :param bi_embed: bigram embedding, used in the Chinese scenario
        :param fc_dropout: dropout rate before the fc layer
        :param scale: whether to scale the attention scores
        :param dropout_attn: dropout rate applied to the attention weights
"""
super().__init__()

self.embed = embed
embed_size = self.embed.embed_size
self.bi_embed = None
if bi_embed is not None:
self.bi_embed = bi_embed
embed_size += self.bi_embed.embed_size

self.in_fc = nn.Linear(embed_size, d_model)

self.transformer = TransformerEncoder(num_layers, d_model, n_head, feedforward_dim, dropout,
after_norm=after_norm, attn_type=attn_type,
scale=scale, dropout_attn=dropout_attn,
pos_embed=pos_embed)
self.fc_dropout = nn.Dropout(fc_dropout)
self.out_fc = nn.Linear(d_model, len(tag_vocab))

trans = allowed_transitions(tag_vocab, include_start_end=True)
self.crf = ConditionalRandomField(len(tag_vocab), include_start_end_trans=True, allowed_transitions=trans)

def _forward(self, chars, target, bigrams=None):
mask = chars.ne(0)
chars = self.embed(chars)
if self.bi_embed is not None:
bigrams = self.bi_embed(bigrams)
chars = torch.cat([chars, bigrams], dim=-1)

chars = self.in_fc(chars)
chars = self.transformer(chars, mask)
chars = self.fc_dropout(chars)
chars = self.out_fc(chars)
logits = F.log_softmax(chars, dim=-1)
if target is None:
paths, _ = self.crf.viterbi_decode(logits, mask)
return {'pred': paths}
else:
loss = self.crf(logits, target, mask)
return {'loss': loss}

def forward(self, chars, target, bigrams=None):
return self._forward(chars, target, bigrams)

def predict(self, chars, bigrams=None):
return self._forward(chars, target=None, bigrams=bigrams)
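
A rough usage sketch for this model (the hyperparameters, toy vocabularies, and the randomly initialized embedding are illustrative assumptions, not the settings from the training scripts; run it from the repository root):

```python
import torch
from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding
from models.TENER import TENER

# Toy tag and word vocabularies; the real ones come from the data pipes.
tag_vocab = Vocabulary(unknown=None, padding=None)
tag_vocab.add_word_lst(['O', 'B-PER', 'I-PER', 'B-LOC', 'I-LOC'])
word_vocab = Vocabulary()
word_vocab.add_word_lst(['LONDON', 'West', 'Indian'])

# Randomly initialized 100d embedding, standing in for GloVe.
embed = StaticEmbedding(word_vocab, model_dir_or_name=None, embedding_dim=100)

model = TENER(tag_vocab=tag_vocab, embed=embed, num_layers=2, d_model=256, n_head=4,
              feedforward_dim=512, dropout=0.15, attn_type='adatrans')

# One sentence of three words; index 0 is reserved for padding.
chars = torch.tensor([[word_vocab.to_index(w) for w in ['LONDON', 'West', 'Indian']]])
print(model.predict(chars)['pred'])  # Viterbi-decoded tag indices
```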
Empty file added models/__init__.py
Empty file.
122 changes: 122 additions & 0 deletions modules/TransformerEmbedding.py


from fastNLP.embeddings import TokenEmbedding
import torch
from fastNLP import Vocabulary
import torch.nn.functional as F
from fastNLP import logger
from fastNLP.embeddings.utils import _construct_char_vocab_from_vocab, get_embeddings
from torch import nn
from .transformer import TransformerEncoder


class TransformerCharEmbed(TokenEmbedding):
def __init__(self, vocab: Vocabulary, embed_size: int = 30, char_emb_size: int = 30, word_dropout: float = 0,
dropout: float = 0, pool_method: str = 'max', activation='relu',
min_char_freq: int = 2, requires_grad=True, include_word_start_end=True,
char_attn_type='adatrans', char_n_head=3, char_dim_ffn=60, char_scale=False, char_pos_embed=None,
char_dropout=0.15, char_after_norm=False):
"""
        :param vocab: word vocabulary
        :param embed_size: output dimension of TransformerCharEmbed. Default: 30
        :param char_emb_size: dimension of the character embeddings, which is also the d_model of the Transformer. Default: 30
        :param float word_dropout: probability of replacing a word with unk; this both trains the unk embedding and acts as a regularizer
        :param dropout: probability of dropping the character embedding output as well as the final word representation
        :param pool_method: 'max' or 'avg'
        :param activation: activation function; supports 'relu', 'sigmoid', 'tanh', or a custom callable
        :param min_char_freq: minimum character frequency. Default: 2
        :param requires_grad: whether the embedding parameters are trainable
        :param include_word_start_end: whether to mark the start and end of each word with special tags
        :param char_attn_type: 'adatrans' or 'naive'
        :param char_n_head: number of attention heads
        :param char_dim_ffn: size of the intermediate layer of the feed-forward network in the Transformer
        :param char_scale: whether to scale the attention scores
        :param char_pos_embed: kind of position embedding: None, 'fix', or 'sin'. May be None when char_attn_type is relative
        :param char_dropout: dropout rate in the Transformer encoder
        :param char_after_norm: where to place the layer normalization
"""
super(TransformerCharEmbed, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)

        assert char_emb_size % char_n_head == 0, "char_emb_size should be divisible by char_n_head."

assert pool_method in ('max', 'avg')
self.pool_method = pool_method
# activation function
if isinstance(activation, str):
if activation.lower() == 'relu':
self.activation = F.relu
elif activation.lower() == 'sigmoid':
self.activation = F.sigmoid
elif activation.lower() == 'tanh':
self.activation = F.tanh
elif activation is None:
self.activation = lambda x: x
elif callable(activation):
self.activation = activation
else:
raise Exception(
"Undefined activation function: choose from: [relu, tanh, sigmoid, or a callable function]")

logger.info("Start constructing character vocabulary.")
        # build the character vocabulary
self.char_vocab = _construct_char_vocab_from_vocab(vocab, min_freq=min_char_freq,
include_word_start_end=include_word_start_end)
self.char_pad_index = self.char_vocab.padding_idx
logger.info(f"In total, there are {len(self.char_vocab)} distinct characters.")
        # index the vocabulary: record each word's character indices and length
max_word_len = max(map(lambda x: len(x[0]), vocab))
if include_word_start_end:
max_word_len += 2
self.register_buffer('words_to_chars_embedding', torch.full((len(vocab), max_word_len),
fill_value=self.char_pad_index, dtype=torch.long))
self.register_buffer('word_lengths', torch.zeros(len(vocab)).long())
for word, index in vocab:
            # if index != vocab.padding_idx:  # the pad word used to be left as pad_value; changed so pad is indexed like any other word
if include_word_start_end:
word = ['<bow>'] + list(word) + ['<eow>']
self.words_to_chars_embedding[index, :len(word)] = \
torch.LongTensor([self.char_vocab.to_index(c) for c in word])
self.word_lengths[index] = len(word)

self.char_embedding = get_embeddings((len(self.char_vocab), char_emb_size))
self.transformer = TransformerEncoder(1, char_emb_size, char_n_head, char_dim_ffn, dropout=char_dropout, after_norm=char_after_norm,
attn_type=char_attn_type, pos_embed=char_pos_embed, scale=char_scale)
self.fc = nn.Linear(char_emb_size, embed_size)

self._embed_size = embed_size

self.requires_grad = requires_grad

def forward(self, words):
"""
        Given the word indices, produce the corresponding word representations.
:param words: [batch_size, max_len]
:return: [batch_size, max_len, embed_size]
"""
words = self.drop_word(words)
batch_size, max_len = words.size()
chars = self.words_to_chars_embedding[words] # batch_size x max_len x max_word_len
word_lengths = self.word_lengths[words] # batch_size x max_len
max_word_len = word_lengths.max()
chars = chars[:, :, :max_word_len]
        # positions to be masked out are marked 1
        chars_masks = chars.eq(self.char_pad_index)  # batch_size x max_len x max_word_len; True at padding positions
char_embeds = self.char_embedding(chars) # batch_size x max_len x max_word_len x embed_size
char_embeds = self.dropout(char_embeds)
reshaped_chars = char_embeds.reshape(batch_size * max_len, max_word_len, -1)

trans_chars = self.transformer(reshaped_chars, chars_masks.eq(0).reshape(-1, max_word_len))
trans_chars = trans_chars.reshape(batch_size, max_len, max_word_len, -1)
trans_chars = self.activation(trans_chars)
if self.pool_method == 'max':
trans_chars = trans_chars.masked_fill(chars_masks.unsqueeze(-1), float('-inf'))
chars, _ = torch.max(trans_chars, dim=-2) # batch_size x max_len x H
else:
trans_chars = trans_chars.masked_fill(chars_masks.unsqueeze(-1), 0)
chars = torch.sum(trans_chars, dim=-2) / chars_masks.eq(0).sum(dim=-1, keepdim=True).float()

chars = self.fc(chars)

return self.dropout(chars)
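
A rough usage sketch for this character encoder (the toy vocabulary and hyperparameters are assumptions; run it from the repository root):

```python
import torch
from fastNLP import Vocabulary
from modules.TransformerEmbedding import TransformerCharEmbed

# Toy word vocabulary; the character vocabulary is derived from it internally.
vocab = Vocabulary()
vocab.add_word_lst(['LONDON', 'West', 'Indian', 'all-rounder'])

# char_emb_size must be divisible by char_n_head (30 / 3 here).
char_embed = TransformerCharEmbed(vocab, embed_size=30, char_emb_size=30,
                                  char_n_head=3, char_attn_type='adatrans')

words = torch.tensor([[vocab.to_index('LONDON'), vocab.to_index('West')]])
print(char_embed(words).shape)  # expected: torch.Size([1, 2, 30])
```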
Empty file added modules/__init__.py
Empty file.