diff --git a/README.md b/README.md
index 74556d9..f3e368e 100644
--- a/README.md
+++ b/README.md
@@ -90,9 +90,8 @@ Click-through rate (CTR) prediction is a critical task for various industrial ap
|
:open_file_folder: **Multi-Task Modeling** |
|
| 52 | Arxiv'17 | [ShareBottom](./model_zoo/multitask/ShareBottom) | [An Overview of Multi-Task Learning in Deep Neural Networks](https://arxiv.org/abs/1706.05098) | | `torch` |
| 53 | KDD'18 | [MMoE](./model_zoo/multitask/MMOE) | [Modeling Task Relationships in Multi-task Learning with Multi-Gate Mixture-of-Experts](https://dl.acm.org/doi/pdf/10.1145/3219819.3220007) :triangular_flag_on_post:**Google** | | `torch` |
-| 54 | KDD'18 | [PLE](./model_zoo/multitask/PLE) | [Progressive Layered Extraction (PLE): A Novel Multi-Task Learning (MTL) Model for Personalized Recommendations](https://dl.acm.org/doi/10.1145/3383313.3412236) :triangular_flag_on_post:**Tencent** | | `torch` |
-|:open_file_folder: **Multi-Domain Modeling** |
|
-| 55 | KDD'23 | PEPNet | [PEPNet: Parameter and Embedding Personalized Network for Infusing with Personalized Prior Information](https://arxiv.org/abs/2302.01115) :triangular_flag_on_post:**KuaiShou** | | `torch` |
+| 54 | RecSys'20 | [PLE](./model_zoo/multitask/PLE) | [Progressive Layered Extraction (PLE): A Novel Multi-Task Learning (MTL) Model for Personalized Recommendations](https://dl.acm.org/doi/10.1145/3383313.3412236) :triangular_flag_on_post:**Tencent** | | `torch` |
+
## Benchmarking
diff --git a/model_zoo/BST/src/BST.py b/model_zoo/BST/src/BST.py
index 0036f7e..673d8ec 100644
--- a/model_zoo/BST/src/BST.py
+++ b/model_zoo/BST/src/BST.py
@@ -136,19 +136,23 @@ def forward(self, inputs):
return return_dict
def get_mask(self, x):
- """ padding_mask: 1 for masked positions
- attn_mask: 1 for masked positions in nn.MultiheadAttention
+ """ padding_mask: B x L, 1 for masked positions
+ attn_mask: (B*H) x L x L, 1 for masked positions in nn.MultiheadAttention
"""
padding_mask = (x == 0)
- padding_mask = torch.cat([padding_mask, torch.zeros(x.size(0), 1, dtype=torch.bool, device=x.device)], dim=-1)
+ padding_mask = torch.cat([padding_mask, torch.zeros(x.size(0), 1).bool().to(x.device)],
+ dim=-1)
seq_len = padding_mask.size(1)
- attn_mask = padding_mask.unsqueeze(1).repeat(1, seq_len * self.num_heads, 1).view(-1, seq_len, seq_len)
- diag_zeros = (1 - torch.eye(seq_len, device=x.device)).bool().unsqueeze(0).expand_as(attn_mask)
+ attn_mask = padding_mask.unsqueeze(1).repeat(1, seq_len, 1)
+ diag_zeros = ~torch.eye(seq_len, device=x.device).bool().unsqueeze(0).expand_as(attn_mask)
attn_mask = attn_mask & diag_zeros
if self.use_causal_mask:
- causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), 1).bool() \
- .unsqueeze(0).expand_as(attn_mask)
+ causal_mask = (
+ torch.triu(torch.ones(seq_len, seq_len, device=x.device), 1)
+ .bool().unsqueeze(0).expand_as(attn_mask)
+ )
attn_mask = attn_mask | causal_mask
+ attn_mask = attn_mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(end_dim=1)
return padding_mask, attn_mask
def sequence_pooling(self, transformer_out, mask):
diff --git a/model_zoo/DMIN/src/DMIN.py b/model_zoo/DMIN/src/DMIN.py
index 38c648d..59056f2 100644
--- a/model_zoo/DMIN/src/DMIN.py
+++ b/model_zoo/DMIN/src/DMIN.py
@@ -175,17 +175,23 @@ def forward(self, inputs):
return return_dict
def get_mask(self, x):
- """ padding_mask: 0 for masked positions
+ """
+ Returns:
+ padding_mask: 0 for masked positions
attn_mask: 0 for masked positions
"""
- padding_mask = (x > 0)
+ padding_mask = (x == 0) # 1 for masked positions
seq_len = padding_mask.size(1)
- attn_mask = padding_mask.unsqueeze(1).repeat(1, seq_len * self.num_heads, 1).view(-1, seq_len, seq_len)
- diag_ones = torch.eye(seq_len, device=x.device).bool().unsqueeze(0).expand_as(attn_mask)
- attn_mask = attn_mask | diag_ones
- causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device)).bool() \
- .unsqueeze(0).expand_as(attn_mask)
- attn_mask = attn_mask & causal_mask
+ attn_mask = padding_mask.unsqueeze(1).repeat(1, seq_len, 1)
+ diag_zeros = ~torch.eye(seq_len, device=x.device).bool().unsqueeze(0).expand_as(attn_mask)
+ attn_mask = attn_mask & diag_zeros
+ causal_mask = (
+ torch.triu(torch.ones(seq_len, seq_len, device=x.device), 1)
+ .bool().unsqueeze(0).expand_as(attn_mask)
+ )
+ attn_mask = attn_mask | causal_mask
+ attn_mask = attn_mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(end_dim=1)
+ padding_mask, attn_mask = ~padding_mask, ~attn_mask
return padding_mask, attn_mask
def add_loss(self, return_dict, y_true):
diff --git a/model_zoo/DMR/src/DMR.py b/model_zoo/DMR/src/DMR.py
index 41c3632..b105280 100644
--- a/model_zoo/DMR/src/DMR.py
+++ b/model_zoo/DMR/src/DMR.py
@@ -249,7 +249,7 @@ def forward(self, target_emb, sequence_emb, context_emb, sequence_emb2, neg_emb=
attn_score = attn_score.view(-1, seq_len) # b x len
attn_mask = self.get_mask(mask) # 0 for masked positions
expand_score = attn_score.unsqueeze(1).repeat(1, seq_len, 1) # b x len x len
- # expand_score = expand_score.masked_fill_(attn_mask.float() == 0, -1.e9) # fill -inf if mask=0
+ expand_score = expand_score.masked_fill_(attn_mask == False, -1.e9) # fill -inf if mask=False
expand_score = expand_score.softmax(dim=-1)
user_embs = torch.bmm(expand_score, sequence_emb) # b x len x d
user_embs = self.W_o(user_embs.reshape(-1, self.model_dim)).reshape(-1, seq_len, self.model_dim)
@@ -264,13 +264,15 @@ def forward(self, target_emb, sequence_emb, context_emb, sequence_emb2, neg_emb=
return rel_u2i, aux_loss
def get_mask(self, mask):
- """ attn_mask: 0 for masked positions
+ """ attn_mask: B x L, 0 for masked positions
"""
seq_len = mask.size(1)
- attn_mask = mask.unsqueeze(1).repeat(1, seq_len, 1).view(-1, seq_len, seq_len)
- causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=mask.device)).bool() \
- .unsqueeze(0).expand_as(attn_mask)
+ attn_mask = mask.unsqueeze(1).repeat(1, seq_len, 1) # B x L x L
+ causal_mask = (torch.tril(torch.ones(seq_len, seq_len, device=mask.device)).bool()
+ .unsqueeze(0).expand_as(attn_mask))
attn_mask = attn_mask & causal_mask
+ diag_ones = torch.eye(seq_len, device=mask.device).bool().unsqueeze(0).expand_as(attn_mask)
+ attn_mask = attn_mask | diag_ones
return attn_mask
diff --git a/model_zoo/InterHAt/src/InterHAt.py b/model_zoo/InterHAt/src/InterHAt.py
index 84e84bd..9b59f56 100644
--- a/model_zoo/InterHAt/src/InterHAt.py
+++ b/model_zoo/InterHAt/src/InterHAt.py
@@ -122,20 +122,23 @@ def __init__(self, input_dim, attention_dim=None, num_heads=1, dropout_rate=0.,
self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0 else None
def forward(self, query, key, value, mask=None):
+ # mask: B x L x L, 0 for masked positions
+ if mask:
+ # Repeat to (B * heads) x L x L
+ mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(end_dim=1)
residual = query
# linear projection
query = self.W_q(query)
key = self.W_k(key)
value = self.W_v(value)
-
+
# split by heads
batch_size = query.size(0)
query = query.view(batch_size * self.num_heads, -1, self.attention_dim)
key = key.view(batch_size * self.num_heads, -1, self.attention_dim)
value = value.view(batch_size * self.num_heads, -1, self.attention_dim)
- if mask:
- mask = mask.repeat(self.num_heads, 1, 1)
+
# scaled dot product attention
output, attention = self.dot_product_attention(query, key, value, self.scale, mask)
# concat heads