Fix the add_loss() issue that does not work (Fix #122)
xpai committed Nov 5, 2024
1 parent 6b6a7ca commit 7b98a49
Showing 16 changed files with 71 additions and 50 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -105,4 +105,4 @@ venv.bak/
 .mypy_cache/
 _build
 checkpoints
-.csv
+*.csv
28 changes: 25 additions & 3 deletions fuxictr/pytorch/models/multitask_model.py
@@ -21,7 +21,7 @@
 import os, sys
 import logging
 from fuxictr.pytorch.models import BaseModel
-from fuxictr.pytorch.torch_utils import get_device, get_optimizer, get_loss
+from fuxictr.pytorch.torch_utils import get_device, get_optimizer, get_loss, get_regularizer
 from tqdm import tqdm
 from collections import defaultdict

@@ -79,14 +79,36 @@ def get_labels(self, inputs):
                  for i in range(len(labels))]
         return y

-    def compute_loss(self, return_dict, y_true):
+    def regularization_loss(self):
+        reg_loss = 0
+        if self._embedding_regularizer or self._net_regularizer:
+            emb_reg = get_regularizer(self._embedding_regularizer)
+            net_reg = get_regularizer(self._net_regularizer)
+            for _, module in self.named_modules():
+                for p_name, param in module.named_parameters():
+                    if param.requires_grad:
+                        if p_name in ["weight", "bias"]:
+                            if type(module) == nn.Embedding:
+                                if self._embedding_regularizer:
+                                    for emb_p, emb_lambda in emb_reg:
+                                        reg_loss += (emb_lambda / emb_p) * torch.norm(param, emb_p) ** emb_p
+                            else:
+                                if self._net_regularizer:
+                                    for net_p, net_lambda in net_reg:
+                                        reg_loss += (net_lambda / net_p) * torch.norm(param, net_p) ** net_p
+        return reg_loss
+
+    def add_loss(self, return_dict, y_true):
         labels = self.feature_map.labels
         loss = [self.loss_fn[i](return_dict["{}_pred".format(labels[i])], y_true[i], reduction='mean')
                 for i in range(len(labels))]
         if self.loss_weight == 'EQ':
             # Default: All losses are weighted equally
             loss = torch.sum(torch.stack(loss))
-        loss += self.regularization_loss()
         return loss
+
+    def compute_loss(self, return_dict, y_true):
+        loss = self.add_loss(return_dict, y_true) + self.regularization_loss()
+        return loss

     def evaluate(self, data_generator, metrics=None):
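With this split, add_loss() returns only the data loss and compute_loss() layers the regularization term on top, so a subclass can override add_loss() without re-implementing (or double-counting) regularization. A minimal sketch of the training-step flow this implies; the trainer-side calls below are an assumption for illustration, not code from this commit:

    # Hypothetical trainer step under the new contract: forward() and
    # get_labels() run once, and their outputs feed compute_loss(), which
    # delegates the data loss to add_loss() and adds regularization itself.
    def train_step(model, batch_data):
        return_dict = model.forward(batch_data)          # e.g. {"click_pred": ..., "like_pred": ...}
        y_true = model.get_labels(batch_data)            # one label tensor per task
        loss = model.compute_loss(return_dict, y_true)   # add_loss(...) + regularization_loss()
        loss.backward()
        return loss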
15 changes: 9 additions & 6 deletions fuxictr/pytorch/models/rank_model.py
@@ -65,7 +65,7 @@ def compile(self, optimizer, loss, lr):
         self.loss_fn = get_loss(loss)

     def regularization_loss(self):
-        reg_term = 0
+        reg_loss = 0
         if self._embedding_regularizer or self._net_regularizer:
             emb_reg = get_regularizer(self._embedding_regularizer)
             net_reg = get_regularizer(self._net_regularizer)
@@ -76,16 +76,19 @@ def regularization_loss(self):
                         if type(module) == nn.Embedding:
                             if self._embedding_regularizer:
                                 for emb_p, emb_lambda in emb_reg:
-                                    reg_term += (emb_lambda / emb_p) * torch.norm(param, emb_p) ** emb_p
+                                    reg_loss += (emb_lambda / emb_p) * torch.norm(param, emb_p) ** emb_p
                         else:
                             if self._net_regularizer:
                                 for net_p, net_lambda in net_reg:
-                                    reg_term += (net_lambda / net_p) * torch.norm(param, net_p) ** net_p
-        return reg_term
+                                    reg_loss += (net_lambda / net_p) * torch.norm(param, net_p) ** net_p
+        return reg_loss

-    def compute_loss(self, return_dict, y_true):
+    def add_loss(self, return_dict, y_true):
         loss = self.loss_fn(return_dict["y_pred"], y_true, reduction='mean')
-        loss += self.regularization_loss()
         return loss
+
+    def compute_loss(self, return_dict, y_true):
+        loss = self.add_loss(return_dict, y_true) + self.regularization_loss()
+        return loss

     def reset_parameters(self):
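Both copies of regularization_loss() accumulate, for each (p, λ) pair in the regularizer spec, a term (λ/p)·‖param‖_p^p over embedding weights (self._embedding_regularizer) and over all other weights and biases (self._net_regularizer). A self-contained sketch of that penalty for a single tensor; the convention that get_regularizer parses a spec like "l2(1.e-5)" into [(2, 1e-5)] is an assumption here:

    import torch

    def p_norm_penalty(param, reg_pairs):
        # reg_pairs: list of (p, lambda) tuples, e.g. [(2, 1e-5)] for an L2 penalty.
        # Each term is (lambda / p) * ||param||_p ** p, matching the loops above.
        penalty = 0.0
        for p, lam in reg_pairs:
            penalty += (lam / p) * torch.norm(param, p) ** p
        return penalty

    emb_table = torch.randn(1000, 16, requires_grad=True)  # stand-in embedding weight
    reg = p_norm_penalty(emb_table, [(2, 1e-5)])           # 0.5e-5 * ||W||_2^2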
4 changes: 2 additions & 2 deletions fuxictr/tensorflow/models/rank_model.py
@@ -65,7 +65,7 @@ def add_loss(self, inputs):
         loss = self.loss_fn(return_dict["y_pred"], y_true)
         return loss

-    def get_total_loss(self, inputs):
+    def compute_loss(self, inputs):
         total_loss = self.add_loss(inputs) + sum(self.losses) # with regularization
         return total_loss

@@ -144,7 +144,7 @@ def train_epoch(self, data_generator):
     @tf.function
     def train_step(self, batch_data):
         with tf.GradientTape() as tape:
-            loss = self.get_total_loss(batch_data)
+            loss = self.compute_loss(batch_data)
         grads = tape.gradient(loss, self.trainable_variables)
         grads, _ = tf.clip_by_global_norm(grads, self._max_gradient_norm)
         self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
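After this rename, the TensorFlow and PyTorch base models expose the same compute_loss() entry point, though the TensorFlow version still takes raw batch inputs while the PyTorch version now takes (return_dict, y_true).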
2 changes: 1 addition & 1 deletion fuxictr/version.py
@@ -1 +1 @@
-__version__="2.3.3"
+__version__="2.3.4"
4 changes: 1 addition & 3 deletions model_zoo/DCNv3/src/DCNv3.py
@@ -73,9 +73,7 @@ def forward(self, inputs):
                        "y_s": self.output_activation(slogit)}
         return return_dict

-    def add_loss(self, inputs):
-        return_dict = self.forward(inputs)
-        y_true = self.get_labels(inputs)
+    def add_loss(self, return_dict, y_true):
         y_pred = return_dict["y_pred"]
         y_d = return_dict["y_d"]
         y_s = return_dict["y_s"]
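The same signature change, add_loss(self, inputs) → add_loss(self, return_dict, y_true), is applied to DIEN, DMIN, DMR, and FinalNet below. Each override previously called self.forward(inputs) and self.get_labels(inputs) itself, which under the new compute_loss() flow would have meant a redundant second forward pass per step; the overrides now reuse the outputs the trainer already computed. A generic sketch of a custom override under the new contract; the auxiliary term is illustrative, not DCNv3's actual formula:

    # Hypothetical model-specific override following the new convention:
    # the trainer supplies return_dict and y_true, so no forward() call here.
    def add_loss(self, return_dict, y_true):
        main_loss = self.loss_fn(return_dict["y_pred"], y_true, reduction='mean')
        aux_loss = self.loss_fn(return_dict["y_d"], y_true, reduction='mean')  # illustrative
        return main_loss + aux_loss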
4 changes: 1 addition & 3 deletions model_zoo/DIEN/src/DIEN.py
@@ -162,9 +162,7 @@ def get_unmasked_tensor(self, h, non_zero_mask):
         out[non_zero_mask] = h
         return out

-    def add_loss(self, inputs):
-        y_true = self.get_labels(inputs)
-        return_dict = self.forward(inputs)
+    def add_loss(self, return_dict, y_true):
         loss = self.loss_fn(return_dict["y_pred"], y_true, reduction='mean')
         if self.aux_loss_alpha > 0:
             # padding post required
4 changes: 1 addition & 3 deletions model_zoo/DMIN/src/DMIN.py
@@ -188,9 +188,7 @@ def get_mask(self, x):
         attn_mask = attn_mask & causal_mask
         return padding_mask, attn_mask

-    def add_loss(self, inputs):
-        y_true = self.get_labels(inputs)
-        return_dict = self.forward(inputs)
+    def add_loss(self, return_dict, y_true):
         loss = self.loss_fn(return_dict["y_pred"], y_true, reduction='mean')
         if self.aux_loss_lambda > 0:
             for i in range(len(self.target_field)):
4 changes: 1 addition & 3 deletions model_zoo/DMR/src/DMR.py
@@ -183,9 +183,7 @@ def forward(self, inputs):
         return_dict = {"y_pred": y_pred, "aux_loss": aux_loss_sum}
         return return_dict

-    def add_loss(self, inputs):
-        y_true = self.get_labels(inputs)
-        return_dict = self.forward(inputs)
+    def add_loss(self, return_dict, y_true):
         loss = self.loss_fn(return_dict["y_pred"], y_true, reduction='mean')
         if self.aux_loss_beta > 0:
             # padding post required
4 changes: 2 additions & 2 deletions model_zoo/EulerNet/config/model_config.yaml
@@ -12,7 +12,7 @@ Base:
     feature_specs: null
     feature_config: null

-EulerNet_test: # This is a config template
+EulerNet_test:
     model: EulerNet
     dataset_id: tiny_npz
     loss: 'binary_crossentropy'
@@ -34,7 +34,7 @@ EulerNet_test:
     monitor: {'AUC': 1, 'logloss': -1}
     monitor_mode: 'max'

-EulerNet_test: # This is a config template
+EulerNet_default: # This is a config template
     model: EulerNet
     dataset_id: TBD
     loss: 'binary_crossentropy'
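This fixes a duplicate YAML key: both blocks were named EulerNet_test, and since later top-level keys silently override earlier ones when the YAML is loaded into a dict, the template block shadowed the runnable test config.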
4 changes: 1 addition & 3 deletions model_zoo/FinalNet/src/FinalNet.py
@@ -103,9 +103,7 @@ def forward2(self, X):
         y_pred = self.fc2(block2_out)
         return y_pred

-    def add_loss(self, inputs):
-        return_dict = self.forward(inputs)
-        y_true = self.get_labels(inputs)
+    def add_loss(self, return_dict, y_true):
         loss = self.loss_fn(return_dict["y_pred"], y_true, reduction='mean')
         if self.block_type == "2B":
             y1 = self.output_activation(return_dict["y1"])
7 changes: 3 additions & 4 deletions model_zoo/GDCN/config/model_config.yaml
@@ -35,8 +35,8 @@ GDCNP_test:
     monitor: 'AUC'
     monitor_mode: 'max'

-GDCNS_test:
-    model: GDCNS
+GDCN_test:
+    model: GDCN
     dataset_id: tiny_npz
     loss: 'binary_crossentropy'
     metrics: ['logloss', 'AUC']
@@ -59,7 +59,7 @@ GDCNS_test:
     monitor_mode: 'max'

 GDCN_default: # This is a config template
-    model: GDCNP
+    model: GDCN
     dataset_id: TBD
     loss: 'binary_crossentropy'
     metrics: ['logloss', 'AUC']
@@ -80,4 +80,3 @@ GDCN_default: # This is a config template
     seed: 20222023
     monitor: {'AUC': 1, 'logloss': -1}
     monitor_mode: 'max'
-
8 changes: 5 additions & 3 deletions model_zoo/GDCN/src/GDCN.py
@@ -69,10 +69,11 @@ def forward(self, inputs):
         return_dict = {"y_pred": y_pred}
         return return_dict

-class GDCNS(BaseModel):
+
+class GDCN(BaseModel):
     def __init__(self,
                  feature_map,
-                 model_id="GDCNS",
+                 model_id="GDCN",
                  gpu=-1,
                  learning_rate=1e-3,
                  embedding_dim=10,
@@ -84,7 +85,7 @@ def __init__(self,
                  embedding_regularizer=None,
                  net_regularizer=None,
                  **kwargs):
-        super(GDCNS, self).__init__(feature_map,
+        super(GDCN, self).__init__(feature_map,
                                     model_id=model_id,
                                     gpu=gpu,
                                     embedding_regularizer=embedding_regularizer,
@@ -115,6 +116,7 @@ def forward(self, inputs):
         return_dict = {"y_pred": y_pred}
         return return_dict

+
 class GateCorssLayer(nn.Module):
     # The core structure: gated corss layer.
     def __init__(self, input_dim, cn_layers=3):
2 changes: 1 addition & 1 deletion model_zoo/GDCN/src/__init__.py
@@ -1,4 +1,4 @@
-from .GDCN import GDCNP, GDCNS
+from .GDCN import GDCNP, GDCN

2 changes: 1 addition & 1 deletion setup.py
@@ -5,7 +5,7 @@

 setuptools.setup(
     name="fuxictr",
-    version="2.3.3",
+    version="2.3.4",
     author="RECZOO",
     author_email="[email protected]",
     description="A configurable, tunable, and reproducible library for CTR prediction",
27 changes: 16 additions & 11 deletions tests/test_torch.sh
@@ -4,47 +4,52 @@ home="$(pwd)/../model_zoo"
 echo "=== Testing AFM ===" && cd $home/AFM && python run_expid.py --expid AFM_test && \
 echo "=== Testing AFN ===" && cd $home/AFN && python run_expid.py --expid AFN_test && \
 echo "=== Testing AOANet ===" && cd $home/AOANet && python run_expid.py --expid AOANet_test && \
+echo "=== Testing APG ===" && cd $home/APG && python run_expid.py --expid APG_test && \
 echo "=== Testing AutoInt ===" && cd $home/AutoInt && python run_expid.py --expid AutoInt_test && \
 echo "=== Testing BST ===" && cd $home/BST && python run_expid.py --expid BST_test && \
 echo "=== Testing CCPM ===" && cd $home/CCPM && python run_expid.py --expid CCPM_test && \
-echo "=== Testing DCN_torch ===" && cd $home/DCN/DCN_torch && python run_expid.py --expid DCN_test && \
+echo "=== Testing DCN ===" && cd $home/DCN/DCN_torch && python run_expid.py --expid DCN_test && \
 echo "=== Testing DCNv2 ===" && cd $home/DCNv2 && python run_expid.py --expid DCNv2_test && \
+echo "=== Testing DCNv3 ===" && cd $home/DCNv3 && python run_expid.py --expid DCNv3_test && \
 echo "=== Testing DeepCrossing ===" && cd $home/DeepCrossing && python run_expid.py --expid DeepCrossing_test && \
-echo "=== Testing DeepFM_torch ===" && cd $home/DeepFM/DeepFM_torch && python run_expid.py --expid DeepFM_test && \
-echo "=== Testing DeepFM_tf ===" && cd $home/DeepFM/DeepFM_tf && python run_expid.py --expid DeepFM_test && \
+echo "=== Testing DeepFM ===" && cd $home/DeepFM/DeepFM_torch && python run_expid.py --expid DeepFM_test && \
 echo "=== Testing DeepIM ===" && cd $home/DeepIM && python run_expid.py --expid DeepIM_test && \
 echo "=== Testing DESTINE ===" && cd $home/DESTINE && python run_expid.py --expid DESTINE_test && \
 echo "=== Testing DIEN ===" && cd $home/DIEN && python run_expid.py --expid DIEN_test && \
 echo "=== Testing DIN ===" && cd $home/DIN && python run_expid.py --expid DIN_test && \
+echo "=== Testing DLRM ===" && cd $home/DLRM && python run_expid.py --expid DLRM_test && \
 echo "=== Testing DMIN ===" && cd $home/DMIN && python run_expid.py --expid DMIN_test && \
 echo "=== Testing DMR ===" && cd $home/DMR && python run_expid.py --expid DMR_test && \
-echo "=== Testing DLRM ===" && cd $home/DLRM && python run_expid.py --expid DLRM_test && \
-echo "=== Testing DNN_torch ===" && cd $home/DNN/DNN_torch && python run_expid.py --expid DNN_test && \
+echo "=== Testing DNN ===" && cd $home/DNN/DNN_torch && python run_expid.py --expid DNN_test && \
 echo "=== Testing DSSM ===" && cd $home/DSSM && python run_expid.py --expid DSSM_test && \
 echo "=== Testing EDCN ===" && cd $home/EDCN && python run_expid.py --expid EDCN_test && \
 echo "=== Testing EulerNet ===" && cd $home/EulerNet && python run_expid.py --expid EulerNet_test && \
 echo "=== Testing FFM ===" && cd $home/FFM && python run_expid.py --expid FFM_test && \
 echo "=== Testing FFMv2 ===" && python run_expid.py --expid FFMv2_test
 echo "=== Testing FGCNN ===" && cd $home/FGCNN && python run_expid.py --expid FGCNN_test && \
-echo "=== Testing DNN ===" && cd $home/FiBiNET && python run_expid.py --expid FiBiNET_test && \
+echo "=== Testing FiBiNET ===" && cd $home/FiBiNET && python run_expid.py --expid FiBiNET_test && \
 echo "=== Testing FiGNN ===" && cd $home/FiGNN && python run_expid.py --expid FiGNN_test && \
+echo "=== Testing FinalMLP ===" && cd $home/FinalMLP && python run_expid.py --expid FinalMLP_test && \
+echo "=== Testing FinalNet ===" && cd $home/FinalNet && python run_expid.py --expid FinalNet_test && \
 echo "=== Testing FLEN ===" && cd $home/FLEN && python run_expid.py --expid FLEN_test && \
 echo "=== Testing FM ===" && cd $home/FM && python run_expid.py --expid FM_test && \
 echo "=== Testing FmFM ===" && cd $home/FmFM && python run_expid.py --expid FmFM_test && \
 echo "=== Testing FwFM ===" && cd $home/FwFM && python run_expid.py --expid FwFM_test && \
-echo "=== Testing FinalMLP ===" && cd $home/FinalMLP && python run_expid.py --expid FinalMLP_test && \
-echo "=== Testing FinalNet ===" && cd $home/FinalNet && python run_expid.py --expid FinalNet_test && \
+echo "=== Testing GDCN ===" && cd $home/GDCN && python run_expid.py --expid GDCN_test && \
 echo "=== Testing HFM ===" && cd $home/HFM && python run_expid.py --expid HFM_test && \
 echo "=== Testing HOFM ===" && cd $home/HOFM && python run_expid.py --expid HOFM_test && \
 echo "=== Testing InterHAt ===" && cd $home/InterHAt && python run_expid.py --expid InterHAt_test && \
 echo "=== Testing LorentzFM ===" && cd $home/LorentzFM && python run_expid.py --expid LorentzFM_test && \
 echo "=== Testing LR ===" && cd $home/LR && python run_expid.py --expid LR_test && \
 echo "=== Testing MaskNet ===" && cd $home/MaskNet && python run_expid.py --expid MaskNet_test && \
 echo "=== Testing NFM ===" && cd $home/NFM && python run_expid.py --expid NFM_test && \
-echo "=== Testing ONN ===" && cd $home/ONN && python run_expid.py --expid ONN_test && \
-echo "=== Testing ONNv2 ===" && python run_expid.py --expid ONNv2_test && \
+echo "=== Testing ONN ===" && cd $home/ONN/ONN_torch && python run_expid.py --expid ONN_test && \
+echo "=== Testing ONNv2 ===" && cd $home/ONN/ONN_torch && python run_expid.py --expid ONNv2_test && \
 echo "=== Testing PPNet ===" && cd $home/PEPNet && python run_expid.py --expid PPNet_test && \
 echo "=== Testing PNN ===" && cd $home/PNN && python run_expid.py --expid PNN_test && \
 echo "=== Testing SAM ===" && cd $home/SAM && python run_expid.py --expid SAM_test && \
-echo "=== Testing WideDeep_torch ===" && cd $home/WideDeep/WideDeep_torch && python run_expid.py --expid WideDeep_test && \
+echo "=== Testing TransAct ===" && cd $home/TransAct && python run_expid.py --expid TransAct_test && \
+echo "=== Testing WideDeep ===" && cd $home/WideDeep/WideDeep_torch && python run_expid.py --expid WideDeep_test && \
 echo "=== Testing xDeepFM ===" && cd $home/xDeepFM && python run_expid.py --expid xDeepFM_test && \
+
+echo "All tests done."
