
Commit

Fix get_mask issue when num_heads > 1 (#130)
xpai committed Dec 11, 2024
1 parent 169376d commit 9d89a67
Showing 5 changed files with 40 additions and 26 deletions.
5 changes: 2 additions & 3 deletions README.md
@@ -90,9 +90,8 @@ Click-through rate (CTR) prediction is a critical task for various industrial ap
|<tr><th colspan=6 align="center">:open_file_folder: **Multi-Task Modeling**</th></tr>|
| 52 | Arxiv'17 | [ShareBottom](./model_zoo/multitask/ShareBottom) | [An Overview of Multi-Task Learning in Deep Neural Networks](https://arxiv.org/abs/1706.05098) | | `torch` |
| 53 | KDD'18 | [MMoE](./model_zoo/multitask/MMOE) | [Modeling Task Relationships in Multi-task Learning with Multi-Gate Mixture-of-Experts](https://dl.acm.org/doi/pdf/10.1145/3219819.3220007) :triangular_flag_on_post:**Google** | | `torch` |
| 54 | KDD'18 | [PLE](./model_zoo/multitask/PLE) | [Progressive Layered Extraction (PLE): A Novel Multi-Task Learning (MTL) Model for Personalized Recommendations](https://dl.acm.org/doi/10.1145/3383313.3412236) :triangular_flag_on_post:**Tencent** | | `torch` |
|<tr><th colspan=6 align="center">:open_file_folder: **Multi-Domain Modeling**</th></tr>|
| 55 | KDD'23 | PEPNet | [PEPNet: Parameter and Embedding Personalized Network for Infusing with Personalized Prior Information](https://arxiv.org/abs/2302.01115) :triangular_flag_on_post:**KuaiShou** | | `torch` |
| 54 | RecSys'20 | [PLE](./model_zoo/multitask/PLE) | [Progressive Layered Extraction (PLE): A Novel Multi-Task Learning (MTL) Model for Personalized Recommendations](https://dl.acm.org/doi/10.1145/3383313.3412236) :triangular_flag_on_post:**Tencent** | | `torch` |


## Benchmarking

18 changes: 11 additions & 7 deletions model_zoo/BST/src/BST.py
@@ -136,19 +136,23 @@ def forward(self, inputs):
         return return_dict

     def get_mask(self, x):
-        """ padding_mask: 1 for masked positions
-            attn_mask: 1 for masked positions in nn.MultiheadAttention
+        """ padding_mask: B x L, 1 for masked positions
+            attn_mask: (B*H) x L x L, 1 for masked positions in nn.MultiheadAttention
         """
         padding_mask = (x == 0)
-        padding_mask = torch.cat([padding_mask, torch.zeros(x.size(0), 1, dtype=torch.bool, device=x.device)], dim=-1)
+        padding_mask = torch.cat([padding_mask, torch.zeros(x.size(0), 1).bool().to(x.device)],
+                                 dim=-1)
         seq_len = padding_mask.size(1)
-        attn_mask = padding_mask.unsqueeze(1).repeat(1, seq_len * self.num_heads, 1).view(-1, seq_len, seq_len)
-        diag_zeros = (1 - torch.eye(seq_len, device=x.device)).bool().unsqueeze(0).expand_as(attn_mask)
+        attn_mask = padding_mask.unsqueeze(1).repeat(1, seq_len, 1)
+        diag_zeros = ~torch.eye(seq_len, device=x.device).bool().unsqueeze(0).expand_as(attn_mask)
         attn_mask = attn_mask & diag_zeros
         if self.use_causal_mask:
-            causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), 1).bool() \
-                .unsqueeze(0).expand_as(attn_mask)
+            causal_mask = (
+                torch.triu(torch.ones(seq_len, seq_len, device=x.device), 1)
+                .bool().unsqueeze(0).expand_as(attn_mask)
+            )
             attn_mask = attn_mask | causal_mask
+        attn_mask = attn_mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(end_dim=1)
         return padding_mask, attn_mask

     def sequence_pooling(self, transformer_out, mask):
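Note (not part of the commit; a minimal sketch with made-up shapes and values): when attn_mask is a 3-D boolean tensor, nn.MultiheadAttention expects shape (batch_size * num_heads, L, L), with True marking positions that must not be attended to. The line appended at the end of get_mask produces exactly that layout; the sketch below mirrors it, omitting the extra target-item column that BST concatenates onto padding_mask.

import torch
import torch.nn as nn

B, L, D, H = 2, 5, 16, 4  # toy sizes, chosen only for illustration

padding_mask = torch.tensor([[0, 0, 0, 1, 1],
                             [0, 1, 1, 1, 1]]).bool()      # B x L, True = padded step
attn_mask = padding_mask.unsqueeze(1).repeat(1, L, 1)      # B x L x L
attn_mask = attn_mask & ~torch.eye(L, dtype=torch.bool)    # keep the diagonal visible
# The expansion added by this commit: H consecutive L x L masks per sample.
attn_mask = attn_mask.unsqueeze(1).repeat(1, H, 1, 1).flatten(end_dim=1)  # (B*H) x L x L

mha = nn.MultiheadAttention(embed_dim=D, num_heads=H, batch_first=True)
x = torch.randn(B, L, D)
out, _ = mha(x, x, x, attn_mask=attn_mask)  # accepted because the mask is (B*H) x L x L
print(out.shape)  # torch.Size([2, 5, 16])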
22 changes: 14 additions & 8 deletions model_zoo/DMIN/src/DMIN.py
@@ -175,17 +175,23 @@ def forward(self, inputs):
         return return_dict

     def get_mask(self, x):
-        """ padding_mask: 0 for masked positions
+        """
+        Returns:
+            padding_mask: 0 for masked positions
             attn_mask: 0 for masked positions
         """
-        padding_mask = (x > 0)
+        padding_mask = (x == 0) # 1 for masked positions
         seq_len = padding_mask.size(1)
-        attn_mask = padding_mask.unsqueeze(1).repeat(1, seq_len * self.num_heads, 1).view(-1, seq_len, seq_len)
-        diag_ones = torch.eye(seq_len, device=x.device).bool().unsqueeze(0).expand_as(attn_mask)
-        attn_mask = attn_mask | diag_ones
-        causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device)).bool() \
-            .unsqueeze(0).expand_as(attn_mask)
-        attn_mask = attn_mask & causal_mask
+        attn_mask = padding_mask.unsqueeze(1).repeat(1, seq_len, 1)
+        diag_zeros = ~torch.eye(seq_len, device=x.device).bool().unsqueeze(0).expand_as(attn_mask)
+        attn_mask = attn_mask & diag_zeros
+        causal_mask = (
+            torch.triu(torch.ones(seq_len, seq_len, device=x.device), 1)
+            .bool().unsqueeze(0).expand_as(attn_mask)
+        )
+        attn_mask = attn_mask | causal_mask
+        attn_mask = attn_mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(end_dim=1)
+        padding_mask, attn_mask = ~padding_mask, ~attn_mask
         return padding_mask, attn_mask

     def add_loss(self, return_dict, y_true):
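For reference, a toy walk-through (values invented, per-head expansion omitted) of the convention the rewritten get_mask follows: the masks are built with True meaning masked, combined with the diagonal and causal constraints, and only flipped at the end so the returned tensors use 0 for masked positions, as the docstring states.

import torch

x = torch.tensor([[3, 7, 0, 0]])               # one toy sequence, last two steps padded
padding_mask = (x == 0)                         # True = masked while building
L = x.size(1)
attn_mask = padding_mask.unsqueeze(1).repeat(1, L, 1)
attn_mask = attn_mask & ~torch.eye(L, dtype=torch.bool)   # diagonal always visible
causal_mask = torch.triu(torch.ones(L, L), 1).bool()      # no attending to the future
attn_mask = attn_mask | causal_mask
padding_mask, attn_mask = ~padding_mask, ~attn_mask        # flip to 0 = masked
print(attn_mask.int())
# tensor([[[1, 0, 0, 0],
#          [1, 1, 0, 0],
#          [1, 1, 1, 0],
#          [1, 1, 0, 1]]])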
12 changes: 7 additions & 5 deletions model_zoo/DMR/src/DMR.py
@@ -249,7 +249,7 @@ def forward(self, target_emb, sequence_emb, context_emb, sequence_emb2, neg_emb=
         attn_score = attn_score.view(-1, seq_len) # b x len
         attn_mask = self.get_mask(mask) # 0 for masked positions
         expand_score = attn_score.unsqueeze(1).repeat(1, seq_len, 1) # b x len x len
-        # expand_score = expand_score.masked_fill_(attn_mask.float() == 0, -1.e9) # fill -inf if mask=0
+        expand_score = expand_score.masked_fill_(attn_mask == False, -1.e9) # fill -inf if mask=False
         expand_score = expand_score.softmax(dim=-1)
         user_embs = torch.bmm(expand_score, sequence_emb) # b x len x d
         user_embs = self.W_o(user_embs.reshape(-1, self.model_dim)).reshape(-1, seq_len, self.model_dim)
@@ -264,13 +264,15 @@
         return rel_u2i, aux_loss

     def get_mask(self, mask):
-        """ attn_mask: 0 for masked positions
+        """ attn_mask: B x L, 0 for masked positions
         """
         seq_len = mask.size(1)
-        attn_mask = mask.unsqueeze(1).repeat(1, seq_len, 1).view(-1, seq_len, seq_len)
-        causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=mask.device)).bool() \
-            .unsqueeze(0).expand_as(attn_mask)
+        attn_mask = mask.unsqueeze(1).repeat(1, seq_len, 1) # B x L x L
+        causal_mask = (torch.tril(torch.ones(seq_len, seq_len, device=mask.device)).bool()
+                       .unsqueeze(0).expand_as(attn_mask))
         attn_mask = attn_mask & causal_mask
+        diag_ones = torch.eye(seq_len, device=mask.device).bool().unsqueeze(0).expand_as(attn_mask)
+        attn_mask = attn_mask | diag_ones
         return attn_mask


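A small illustration (numbers invented) of the masking step this commit enables in DMR's forward pass: entries where attn_mask is False are filled with a large negative score before the softmax, so they receive (near-)zero attention weight.

import torch

attn_score = torch.tensor([[2.0, 1.0, 0.5]])        # toy scores for one query position
attn_mask = torch.tensor([[True, True, False]])     # False = masked position
attn_score = attn_score.masked_fill(attn_mask == False, -1.e9)
weights = attn_score.softmax(dim=-1)
print(weights)  # ~ tensor([[0.7311, 0.2689, 0.0000]])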
9 changes: 6 additions & 3 deletions model_zoo/InterHAt/src/InterHAt.py
@@ -122,20 +122,23 @@ def __init__(self, input_dim, attention_dim=None, num_heads=1, dropout_rate=0.,
         self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0 else None

     def forward(self, query, key, value, mask=None):
+        # mask: B x L x L, 0 for masked positions
+        if mask:
+            # Repeat to (B * heads) x L x L
+            mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(end_dim=1)
         residual = query

         # linear projection
         query = self.W_q(query)
         key = self.W_k(key)
         value = self.W_v(value)

         # split by heads
         batch_size = query.size(0)
         query = query.view(batch_size * self.num_heads, -1, self.attention_dim)
         key = key.view(batch_size * self.num_heads, -1, self.attention_dim)
         value = value.view(batch_size * self.num_heads, -1, self.attention_dim)
-        if mask:
-            mask = mask.repeat(self.num_heads, 1, 1)

         # scaled dot product attention
         output, attention = self.dot_product_attention(query, key, value, self.scale, mask)
         # concat heads
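Why the repeat pattern matters here (a toy check with invented values): query, key and value are reshaped with view(batch_size * num_heads, ...), so each sample occupies num_heads consecutive slots along the batch axis, and the mask has to be expanded in the same sample-major order. The added unsqueeze(1).repeat(1, num_heads, 1, 1).flatten(end_dim=1) does this, whereas the removed mask.repeat(num_heads, 1, 1) tiles the whole batch head-major.

import torch

B, H = 2, 2
mask = torch.tensor([[[1]], [[2]]])   # B x L x L with L = 1; values tag the sample

old_style = mask.repeat(H, 1, 1)                                     # head-major order
new_style = mask.unsqueeze(1).repeat(1, H, 1, 1).flatten(end_dim=1)  # sample-major order
print(old_style.flatten().tolist())  # [1, 2, 1, 2] -> samples 0, 1, 0, 1
print(new_style.flatten().tolist())  # [1, 1, 2, 2] -> samples 0, 0, 1, 1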
