diff --git a/README.md b/README.md index 74556d9..f3e368e 100644 --- a/README.md +++ b/README.md @@ -90,9 +90,8 @@ Click-through rate (CTR) prediction is a critical task for various industrial ap |:open_file_folder: **Multi-Task Modeling**| | 52 | Arxiv'17 | [ShareBottom](./model_zoo/multitask/ShareBottom) | [An Overview of Multi-Task Learning in Deep Neural Networks](https://arxiv.org/abs/1706.05098) | | `torch` | | 53 | KDD'18 | [MMoE](./model_zoo/multitask/MMOE) | [Modeling Task Relationships in Multi-task Learning with Multi-Gate Mixture-of-Experts](https://dl.acm.org/doi/pdf/10.1145/3219819.3220007) :triangular_flag_on_post:**Google** | | `torch` | -| 54 | KDD'18 | [PLE](./model_zoo/multitask/PLE) | [Progressive Layered Extraction (PLE): A Novel Multi-Task Learning (MTL) Model for Personalized Recommendations](https://dl.acm.org/doi/10.1145/3383313.3412236) :triangular_flag_on_post:**Tencent** | | `torch` | -|:open_file_folder: **Multi-Domain Modeling**| -| 55 | KDD'23 | PEPNet | [PEPNet: Parameter and Embedding Personalized Network for Infusing with Personalized Prior Information](https://arxiv.org/abs/2302.01115) :triangular_flag_on_post:**KuaiShou** | | `torch` | +| 54 | RecSys'20 | [PLE](./model_zoo/multitask/PLE) | [Progressive Layered Extraction (PLE): A Novel Multi-Task Learning (MTL) Model for Personalized Recommendations](https://dl.acm.org/doi/10.1145/3383313.3412236) :triangular_flag_on_post:**Tencent** | | `torch` | + ## Benchmarking diff --git a/model_zoo/BST/src/BST.py b/model_zoo/BST/src/BST.py index 0036f7e..673d8ec 100644 --- a/model_zoo/BST/src/BST.py +++ b/model_zoo/BST/src/BST.py @@ -136,19 +136,23 @@ def forward(self, inputs): return return_dict def get_mask(self, x): - """ padding_mask: 1 for masked positions - attn_mask: 1 for masked positions in nn.MultiheadAttention + """ padding_mask: B x L, 1 for masked positions + attn_mask: (B*H) x L x L, 1 for masked positions in nn.MultiheadAttention """ padding_mask = (x == 0) - padding_mask = 
torch.cat([padding_mask, torch.zeros(x.size(0), 1, dtype=torch.bool, device=x.device)], dim=-1) + padding_mask = torch.cat([padding_mask, torch.zeros(x.size(0), 1).bool().to(x.device)], + dim=-1) seq_len = padding_mask.size(1) - attn_mask = padding_mask.unsqueeze(1).repeat(1, seq_len * self.num_heads, 1).view(-1, seq_len, seq_len) - diag_zeros = (1 - torch.eye(seq_len, device=x.device)).bool().unsqueeze(0).expand_as(attn_mask) + attn_mask = padding_mask.unsqueeze(1).repeat(1, seq_len, 1) + diag_zeros = ~torch.eye(seq_len, device=x.device).bool().unsqueeze(0).expand_as(attn_mask) attn_mask = attn_mask & diag_zeros if self.use_causal_mask: - causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), 1).bool() \ - .unsqueeze(0).expand_as(attn_mask) + causal_mask = ( + torch.triu(torch.ones(seq_len, seq_len, device=x.device), 1) + .bool().unsqueeze(0).expand_as(attn_mask) + ) attn_mask = attn_mask | causal_mask + attn_mask = attn_mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(end_dim=1) return padding_mask, attn_mask def sequence_pooling(self, transformer_out, mask): diff --git a/model_zoo/DMIN/src/DMIN.py b/model_zoo/DMIN/src/DMIN.py index 38c648d..59056f2 100644 --- a/model_zoo/DMIN/src/DMIN.py +++ b/model_zoo/DMIN/src/DMIN.py @@ -175,17 +175,23 @@ def forward(self, inputs): return return_dict def get_mask(self, x): - """ padding_mask: 0 for masked positions + """ + Returns: + padding_mask: 0 for masked positions attn_mask: 0 for masked positions """ - padding_mask = (x > 0) + padding_mask = (x == 0) # 1 for masked positions seq_len = padding_mask.size(1) - attn_mask = padding_mask.unsqueeze(1).repeat(1, seq_len * self.num_heads, 1).view(-1, seq_len, seq_len) - diag_ones = torch.eye(seq_len, device=x.device).bool().unsqueeze(0).expand_as(attn_mask) - attn_mask = attn_mask | diag_ones - causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device)).bool() \ - .unsqueeze(0).expand_as(attn_mask) - attn_mask = attn_mask & causal_mask + 
attn_mask = padding_mask.unsqueeze(1).repeat(1, seq_len, 1) + diag_zeros = ~torch.eye(seq_len, device=x.device).bool().unsqueeze(0).expand_as(attn_mask) + attn_mask = attn_mask & diag_zeros + causal_mask = ( + torch.triu(torch.ones(seq_len, seq_len, device=x.device), 1) + .bool().unsqueeze(0).expand_as(attn_mask) + ) + attn_mask = attn_mask | causal_mask + attn_mask = attn_mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(end_dim=1) + padding_mask, attn_mask = ~padding_mask, ~attn_mask return padding_mask, attn_mask def add_loss(self, return_dict, y_true): diff --git a/model_zoo/DMR/src/DMR.py b/model_zoo/DMR/src/DMR.py index 41c3632..b105280 100644 --- a/model_zoo/DMR/src/DMR.py +++ b/model_zoo/DMR/src/DMR.py @@ -249,7 +249,7 @@ def forward(self, target_emb, sequence_emb, context_emb, sequence_emb2, neg_emb= attn_score = attn_score.view(-1, seq_len) # b x len attn_mask = self.get_mask(mask) # 0 for masked positions expand_score = attn_score.unsqueeze(1).repeat(1, seq_len, 1) # b x len x len - # expand_score = expand_score.masked_fill_(attn_mask.float() == 0, -1.e9) # fill -inf if mask=0 + expand_score = expand_score.masked_fill_(attn_mask == False, -1.e9) # fill -inf if mask=False expand_score = expand_score.softmax(dim=-1) user_embs = torch.bmm(expand_score, sequence_emb) # b x len x d user_embs = self.W_o(user_embs.reshape(-1, self.model_dim)).reshape(-1, seq_len, self.model_dim) @@ -264,13 +264,15 @@ def forward(self, target_emb, sequence_emb, context_emb, sequence_emb2, neg_emb= return rel_u2i, aux_loss def get_mask(self, mask): - """ attn_mask: 0 for masked positions + """ attn_mask: B x L, 0 for masked positions """ seq_len = mask.size(1) - attn_mask = mask.unsqueeze(1).repeat(1, seq_len, 1).view(-1, seq_len, seq_len) - causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=mask.device)).bool() \ - .unsqueeze(0).expand_as(attn_mask) + attn_mask = mask.unsqueeze(1).repeat(1, seq_len, 1) # B x L x L + causal_mask = (torch.tril(torch.ones(seq_len, 
seq_len, device=mask.device)).bool() + .unsqueeze(0).expand_as(attn_mask)) attn_mask = attn_mask & causal_mask + diag_ones = torch.eye(seq_len, device=mask.device).bool().unsqueeze(0).expand_as(attn_mask) + attn_mask = attn_mask | diag_ones return attn_mask diff --git a/model_zoo/InterHAt/src/InterHAt.py b/model_zoo/InterHAt/src/InterHAt.py index 84e84bd..9b59f56 100644 --- a/model_zoo/InterHAt/src/InterHAt.py +++ b/model_zoo/InterHAt/src/InterHAt.py @@ -122,20 +122,23 @@ def __init__(self, input_dim, attention_dim=None, num_heads=1, dropout_rate=0., self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0 else None def forward(self, query, key, value, mask=None): + # mask: B x L x L, 0 for masked positions + if mask is not None: + # Repeat to (B * heads) x L x L + mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(end_dim=1) residual = query # linear projection query = self.W_q(query) key = self.W_k(key) value = self.W_v(value) - + # split by heads batch_size = query.size(0) query = query.view(batch_size * self.num_heads, -1, self.attention_dim) key = key.view(batch_size * self.num_heads, -1, self.attention_dim) value = value.view(batch_size * self.num_heads, -1, self.attention_dim) - if mask: - mask = mask.repeat(self.num_heads, 1, 1) + # scaled dot product attention output, attention = self.dot_product_attention(query, key, value, self.scale, mask) # concat heads