diff --git a/domainlab/algos/trainers/a_trainer.py b/domainlab/algos/trainers/a_trainer.py index 67120a1b6..db5c70c6a 100644 --- a/domainlab/algos/trainers/a_trainer.py +++ b/domainlab/algos/trainers/a_trainer.py @@ -194,13 +194,17 @@ def before_tr(self): self.cal_reg_loss_over_task_loss_ratio() def cal_reg_loss_over_task_loss_ratio(self): + """ + estimate the scale of each loss term, match each loss term to the major + loss via a ratio, this ratio will be multiplied with multiplier + """ list_accum_reg_loss = [] loss_task_agg = 0 for ind_batch, (tensor_x, tensor_y, tensor_d, *others) in enumerate( self.loader_tr ): if ind_batch >= self.aconf.nb4reg_over_task_ratio: - break + return tensor_x, tensor_y, tensor_d = ( tensor_x.to(self.device), tensor_y.to(self.device), diff --git a/sh_list_error.sh b/sh_list_error.sh deleted file mode 100644 index 5f725e15b..000000000 --- a/sh_list_error.sh +++ /dev/null @@ -1,5 +0,0 @@ -# find $1 -type f -print0 | xargs -0 grep -li error -# B means before, A means after, some erros have long stack exception message so we need at least -# 100 lines before the error, the last line usually indicate the root cause of error -grep -B 100 -wnr "error" --group-separator="=========begin_slurm_error===============" $1 > slurm_errors.txt -cat slurm_errors.txt