Skip to content

Commit

Permalink
Update relative_transformer.py
Browse files Browse the repository at this point in the history
  • Loading branch information
yhcc authored Jun 11, 2020
1 parent 1d529dc commit a287270
Showing 1 changed file with 29 additions and 3 deletions.
32 changes: 29 additions & 3 deletions modules/relative_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,10 @@ def forward(self, x, mask):

D_ = torch.einsum('nd,ld->nl', self.r_w_bias, pos_embed)[None, :, None] # head x 2max_len, 每个head对位置的bias
B_ = torch.einsum('bnqd,ld->bnql', q, pos_embed) # bsz x head x max_len x 2max_len,每个query对每个shift的偏移
BD = B_ + D_ # bsz x head x max_len x 2max_len, 要转换为bsz x head x max_len x max_len
BD = self._shift(BD)
attn = AC + BD
E_ = torch.einsum('bnqd,ld->bnql', k, pos_embed) # bsz x head x max_len x 2max_len, key对relative的bias
BDE = B_ + D_ # bsz x head x max_len x 2max_len, 要转换为bsz x head x max_len x max_len
BD = self._shift(BD) + self._transpose_shift(E_)
attn = AC + BDE

attn = attn / self.scale

Expand Down Expand Up @@ -159,3 +160,28 @@ def _shift(self, BD):
BD = BD[:, :, :-1].view(bsz, n_head, max_len, -1) # bsz x n_head x 2max_len x max_len
BD = BD[:, :, :, max_len:]
return BD

def _transpose_shift(self, E):
    """Pick a sliding window of relative offsets per row, then transpose.

    Row ``i`` of ``E`` holds one score per relative offset (``2 * max_len``
    of them). This method keeps, for row ``i``, the ``max_len`` consecutive
    columns starting at ``max_len - i`` and returns the transposed result,
    i.e. ``out[..., j, i] == E[..., i, max_len - i + j]``.

    Example with ``max_len = 3``; the input

        -3    -2    -1    0    1    2
        -30   -20   -10   00   10   20
        -300  -200  -100  000  100  200

    is converted to

        0   -10  -200
        1    00  -100
        2    10   000

    :param E: batch_size x n_head x max_len x 2max_len
    :return: batch_size x n_head x max_len x max_len
    """
    bsz, n_head, max_len, _ = E.size()
    zero_pad = E.new_zeros(bsz, n_head, max_len, 1)
    # Appending one zero column makes every row 2*max_len + 1 long. Re-viewing
    # that buffer with rows of width max_len skews the data so that reshaped
    # row 2*i + 1 is exactly E[i, max_len - i : 2*max_len - i].
    # Shape after the view: bsz x n_head x (2max_len + 1) x max_len
    E = torch.cat([E, zero_pad], dim=-1).view(bsz, n_head, -1, max_len)
    # Odd rows 1, 3, ..., 2*max_len - 1, created directly on E's device so no
    # host-to-device copy of the index tensor is needed on every call.
    indice = torch.arange(1, 2 * max_len, 2, device=E.device)
    # bsz x n_head x max_len x max_len
    E = E.index_select(index=indice, dim=-2).transpose(-1, -2)

    return E

0 comments on commit a287270

Please sign in to comment.