diff --git a/modules/relative_transformer.py b/modules/relative_transformer.py
index 93cb9ca..e076db9 100644
--- a/modules/relative_transformer.py
+++ b/modules/relative_transformer.py
@@ -124,9 +124,10 @@ def forward(self, x, mask):
 
         D_ = torch.einsum('nd,ld->nl', self.r_w_bias, pos_embed)[None, :, None] # head x 2max_len, 每个head对位置的bias
         B_ = torch.einsum('bnqd,ld->bnql', q, pos_embed) # bsz x head x max_len x 2max_len,每个query对每个shift的偏移
-        BD = B_ + D_ # bsz x head x max_len x 2max_len, 要转换为bsz x head x max_len x max_len
-        BD = self._shift(BD)
-        attn = AC + BD
+        E_ = torch.einsum('bnqd,ld->bnql', k, pos_embed)  # bsz x head x max_len x 2max_len: each key scored against every relative position
+        BD = B_ + D_  # bsz x head x max_len x 2max_len; still has to be reduced to bsz x head x max_len x max_len
+        BDE = self._shift(BD) + self._transpose_shift(E_)  # query/position term plus the re-indexed key term
+        attn = AC + BDE
 
         attn = attn / self.scale
 
@@ -159,3 +160,32 @@ def _shift(self, BD):
         BD = BD[:, :, :-1].view(bsz, n_head, max_len, -1) # bsz x n_head x 2max_len x max_len
         BD = BD[:, :, :, max_len:]
         return BD
+
+    def _transpose_shift(self, E):
+        """
+        Re-index the key-side relative-position scores and transpose them.
+
+        For max_len = 3, an input (per batch element and head) such as
+          -3   -2   -1    0    1    2
+         -30  -20  -10   00   10   20
+        -300 -200 -100  000  100  200
+
+        is converted to
+          0  -10  -200
+          1   00  -100
+          2   10   000
+
+        i.e. out[..., i, j] = E[..., j, max_len + i - j].
+
+        :param E: batch_size x n_head x max_len x 2max_len
+        :return: batch_size x n_head x max_len x max_len
+        """
+        bsz, n_head, max_len, _ = E.size()
+        zero_pad = E.new_zeros(bsz, n_head, max_len, 1)
+        # Pad every row to 2max_len+1 entries and re-view as (2max_len+1) rows of max_len entries;
+        # odd row 2j+1 then holds exactly E[j, max_len-j : 2max_len-j].
+        E = torch.cat([E, zero_pad], dim=-1).view(bsz, n_head, -1, max_len)
+        indice = torch.arange(1, 2 * max_len, 2, device=E.device)  # odd rows 1, 3, ..., 2max_len-1
+        E = E.index_select(index=indice, dim=-2).transpose(-1, -2)  # bsz x n_head x max_len x max_len
+
+        return E