From c29877e0107bda57f09121e459102c65a5eaf5e1 Mon Sep 17 00:00:00 2001 From: Xudong Sun Date: Wed, 4 Dec 2024 18:50:58 +0100 Subject: [PATCH 01/17] Update ci.yml --- .github/workflows/ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2d65938a4..b6d9e7cde 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,9 +2,9 @@ name: CI on: push: - branches: mhof_dev_merge + branches: mhof_dev pull_request: - branches: mhof_dev_merge + branches: mhof_dev workflow_dispatch: jobs: test: From a715cfe60b14b7d0bb5e074cb73094e22ce98418 Mon Sep 17 00:00:00 2001 From: Xudong Sun Date: Wed, 4 Dec 2024 19:15:32 +0100 Subject: [PATCH 02/17] Update config.yaml --- examples/yaml/slurm/config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/yaml/slurm/config.yaml b/examples/yaml/slurm/config.yaml index cc0f6901a..df55039e7 100644 --- a/examples/yaml/slurm/config.yaml +++ b/examples/yaml/slurm/config.yaml @@ -8,7 +8,7 @@ cluster: --qos=gpu_normal --gres=gpu:1 --nice=10000 - -t 48:00:00 + -t 24:00:00 -c 2 --mem=160G --job-name=$outputfolder-{rule}-{wildcards} From 6597acbb2907df5f49eaad07e6577e9f593dbf41 Mon Sep 17 00:00:00 2001 From: smilesun Date: Wed, 4 Dec 2024 19:24:30 +0100 Subject: [PATCH 03/17] train_causalIRL.py self.cal_loss return a tuple --- domainlab/algos/trainers/train_causIRL.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/domainlab/algos/trainers/train_causIRL.py b/domainlab/algos/trainers/train_causIRL.py index 085abcbfa..194386b43 100644 --- a/domainlab/algos/trainers/train_causIRL.py +++ b/domainlab/algos/trainers/train_causIRL.py @@ -69,7 +69,7 @@ def tr_batch(self, tensor_x, tensor_y, tensor_d, others, ind_batch, epoch): penalty = torch.nan_to_num(self.mmd(first, second)) else: penalty = torch.tensor(0) - loss = self.cal_loss(tensor_x, tensor_y, tensor_d, others) + loss, *_ = self.cal_loss(tensor_x, tensor_y, tensor_d, others) loss = loss + penalty loss.backward() self.optimizer.step() From de004b5c1334b3060013c66991179e7c415a5711 Mon Sep 17 00:00:00 2001 From: smilesun Date: Wed, 4 Dec 2024 19:35:01 +0100 Subject: [PATCH 04/17] unit test --- tests/test_lr_scheduler.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 tests/test_lr_scheduler.py diff --git a/tests/test_lr_scheduler.py b/tests/test_lr_scheduler.py new file mode 100644 index 000000000..7bb7ab92f --- /dev/null +++ b/tests/test_lr_scheduler.py @@ -0,0 +1,14 @@ + +""" +unit and end-end test for lr scheduler +""" +from tests.utils_test import utils_test_algo + + +def test_lr_scheduler(): + """ + train + """ + args = "--te_d=2 --tr_d 0 1 --task=mnistcolor10 --debug --bs=100 --model=erm \ + --nname=conv_bn_pool_2 --no_dump --lr_scheduler=CosineAnnealingLR" + utils_test_algo(args) From 06257726b48178d84188c690b99ab466321c2d4b Mon Sep 17 00:00:00 2001 From: Xudong Sun Date: Wed, 4 Dec 2024 19:39:40 +0100 Subject: [PATCH 05/17] Update arg_parser.py --- domainlab/arg_parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/domainlab/arg_parser.py b/domainlab/arg_parser.py index 47272d7fb..bb7bda2b4 100644 --- a/domainlab/arg_parser.py +++ b/domainlab/arg_parser.py @@ -267,7 +267,7 @@ def mk_parser_main(): parser.add_argument( "--lr_scheduler", type=str, - default="CosineAnnealingLR", + default=None, help="name of pytorch learning rate scheduler", ) From 3eed76644bd69b01f6788c7b097a6015285e0394 Mon Sep 17 00:00:00 2001 From: smilesun Date: Thu, 5 Dec 2024 10:17:22 +0100 Subject: [PATCH 06/17] fix bug of deleting multipliers, now only do for erm --- domainlab/algos/trainers/a_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/domainlab/algos/trainers/a_trainer.py b/domainlab/algos/trainers/a_trainer.py index 7e95ebb00..a05d64b86 100644 --- a/domainlab/algos/trainers/a_trainer.py +++ b/domainlab/algos/trainers/a_trainer.py @@ -319,7 +319,7 @@ def cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): list_reg_loss_trainer_tensor list_mu = list_mu_model + list_mu_trainer # ERM return a tensor of all zeros, delete here - if len(list_mu) > 1: + if len(list_mu) > 1 and "ModelERM" == type(self.get_model()).__name__: list_boolean_zero = [torch.all(torch.eq(list_loss_tensor[i], 0)).item() for i in range(len(list_mu))] list_loss_tensor = [list_loss_tensor[i] for (i, flag) in From 87cbc6888ca685eb818687d3c35d467bd05da5c9 Mon Sep 17 00:00:00 2001 From: smilesun Date: Thu, 5 Dec 2024 12:55:45 +0100 Subject: [PATCH 07/17] not allowing naked erm be combined with fbopt --- domainlab/algos/trainers/fbopt_mu_controller.py | 3 +++ tests/test_fbopt.py | 12 ++++++++++++ 2 files changed, 15 insertions(+) diff --git a/domainlab/algos/trainers/fbopt_mu_controller.py b/domainlab/algos/trainers/fbopt_mu_controller.py index ce53fc0df..272d34908 100644 --- a/domainlab/algos/trainers/fbopt_mu_controller.py +++ b/domainlab/algos/trainers/fbopt_mu_controller.py @@ -45,6 +45,9 @@ def __init__(self, trainer, **kwargs): self.mu_min = trainer.aconf.mu_min self.mu_clip = trainer.aconf.mu_clip + if not kwargs: + raise RuntimeError("feedback scheduler requires **kwargs, the set \ + of multipliers non-empty") self.mmu = kwargs # force initial value of mu self.mmu = {key: self.init_mu for key, val in self.mmu.items()} diff --git a/tests/test_fbopt.py b/tests/test_fbopt.py index c442bf090..84b97f0f9 100644 --- a/tests/test_fbopt.py +++ b/tests/test_fbopt.py @@ -1,6 +1,7 @@ """ unit and end-end test for deep all, mldg """ +import pytest from tests.utils_test import utils_test_algo @@ -27,13 +28,24 @@ def test_diva_fbopt(): args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=diva --gamma_y=1.0 --trainer=fbopt --nname=alexnet --epos=3" utils_test_algo(args) + def test_erm_fbopt(): """ erm """ args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm --trainer=fbopt --nname=alexnet --epos=3" # pylint: disable=line-too-long + with pytest.raises(RuntimeError): + utils_test_algo(args) + + +def test_irm_fbopt(): + """ + irm + """ + args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm --trainer=fbopt_irm --nname=alexnet --epos=3" # pylint: disable=line-too-long utils_test_algo(args) + def test_forcesetpoint_fbopt(): """ diva From 51c26d1e724ff9305aab216c33c2a9a26a5c56f5 Mon Sep 17 00:00:00 2001 From: smilesun Date: Thu, 5 Dec 2024 13:37:20 +0100 Subject: [PATCH 08/17] latex table in scirpt --- scripts/generate_latex_table.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 scripts/generate_latex_table.py diff --git a/scripts/generate_latex_table.py b/scripts/generate_latex_table.py new file mode 100644 index 000000000..4b35f4209 --- /dev/null +++ b/scripts/generate_latex_table.py @@ -0,0 +1,25 @@ +""" +aggregate benchmark csv file to generate latex table +""" +import argparse +import pandas as pd + + +def gen_latex_table(raw_df, fname="table_perf.tex", + group="method", str_perf="acc"): + """ + aggregate benchmark csv file to generate latex table + """ + df_result = raw_df.groupby(group)[str_perf].agg(["mean", "std"]) + latex_table = df_result.to_latex(float_format="%.3f") + with open(fname, 'w') as file: + file.write(latex_table) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Read a CSV file") + parser.add_argument("filename", help="Name of the CSV file to read") + args = parser.parse_args() + + df = pd.read_csv(args.filename, index_col=False, skipinitialspace=True) + gen_latex_table(df) From 1b98e5831a75d3b1c160f42a4666c3cee01abf05 Mon Sep 17 00:00:00 2001 From: smilesun Date: Thu, 5 Dec 2024 13:43:24 +0100 Subject: [PATCH 09/17] csv agg mean, std to text table --- scripts/generate_latex_table.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/scripts/generate_latex_table.py b/scripts/generate_latex_table.py index 4b35f4209..caa333aab 100644 --- a/scripts/generate_latex_table.py +++ b/scripts/generate_latex_table.py @@ -12,6 +12,8 @@ def gen_latex_table(raw_df, fname="table_perf.tex", """ df_result = raw_df.groupby(group)[str_perf].agg(["mean", "std"]) latex_table = df_result.to_latex(float_format="%.3f") + str_table = df_result.to_string() + print(str_table) with open(fname, 'w') as file: file.write(latex_table) From 7d8d081dc563c4d9cf68059369ae62dbbf209a90 Mon Sep 17 00:00:00 2001 From: smilesun Date: Thu, 5 Dec 2024 13:47:44 +0100 Subject: [PATCH 10/17] fix typo new argument type --- tests/test_fbopt_setpoint_rewind.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_fbopt_setpoint_rewind.py b/tests/test_fbopt_setpoint_rewind.py index 3c1011bab..3ce13079b 100644 --- a/tests/test_fbopt_setpoint_rewind.py +++ b/tests/test_fbopt_setpoint_rewind.py @@ -8,5 +8,5 @@ def test_jigen_fbopt(): """ jigen """ - args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=jigen --trainer=fbopt --nname=alexnet --epos=300 --setpoint_rewind=yes" + args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=jigen --trainer=fbopt --nname=alexnet --epos=300 --setpoint_rewind" utils_test_algo(args) From b2962297f7a8b0e87693b7738d2412a08cf14d3a Mon Sep 17 00:00:00 2001 From: Xudong Sun Date: Thu, 5 Dec 2024 15:17:27 +0100 Subject: [PATCH 11/17] Update train_ema.py --- domainlab/algos/trainers/train_ema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/domainlab/algos/trainers/train_ema.py b/domainlab/algos/trainers/train_ema.py index 8b57368b0..a935b6570 100644 --- a/domainlab/algos/trainers/train_ema.py +++ b/domainlab/algos/trainers/train_ema.py @@ -57,7 +57,7 @@ def move_average(self, dict_data, epoch): self._ma_iter += 1 return dict_return_ema_para_curr_iter - def after_epoch(self, epoch): + def after_epoch(self, epoch, flag_info=None): torch_model = self.get_model() dict_para = torch_model.state_dict() # only for trainable parameters new_dict_para = self.move_average(dict_para, epoch) From 3cef1a2bdc3e6b92f46e89550cc1d5eef7ca8048 Mon Sep 17 00:00:00 2001 From: smilesun Date: Thu, 5 Dec 2024 16:10:27 +0100 Subject: [PATCH 12/17] fix test_ma --- domainlab/algos/trainers/train_ema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/domainlab/algos/trainers/train_ema.py b/domainlab/algos/trainers/train_ema.py index a935b6570..b2ccede40 100644 --- a/domainlab/algos/trainers/train_ema.py +++ b/domainlab/algos/trainers/train_ema.py @@ -63,4 +63,4 @@ def after_epoch(self, epoch, flag_info=None): new_dict_para = self.move_average(dict_para, epoch) # without deepcopy, this seems to work torch_model.load_state_dict(new_dict_para) - super().after_epoch(epoch) + super().after_epoch(epoch, flag_info) From 6f08403bbb92cab8398ce1154f9f28e573e0562c Mon Sep 17 00:00:00 2001 From: smilesun Date: Fri, 6 Dec 2024 13:08:18 +0100 Subject: [PATCH 13/17] no_dump in test_fbopt.py --- tests/test_fbopt.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_fbopt.py b/tests/test_fbopt.py index 84b97f0f9..03dad9e16 100644 --- a/tests/test_fbopt.py +++ b/tests/test_fbopt.py @@ -9,7 +9,7 @@ def test_dann_fbopt(): """ dann """ - args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=dann --trainer=fbopt --nname=alexnet --epos=3" + args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=dann --trainer=fbopt --nname=alexnet --epos=3 --no_dump" utils_test_algo(args) @@ -17,7 +17,7 @@ def test_jigen_fbopt(): """ jigen """ - args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=jigen --trainer=fbopt --nname=alexnet --epos=3" + args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=jigen --trainer=fbopt --nname=alexnet --epos=3 --no_dump" utils_test_algo(args) @@ -25,7 +25,7 @@ def test_diva_fbopt(): """ diva """ - args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=diva --gamma_y=1.0 --trainer=fbopt --nname=alexnet --epos=3" + args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=diva --gamma_y=1.0 --trainer=fbopt --nname=alexnet --epos=3 --no_dump" utils_test_algo(args) @@ -33,7 +33,7 @@ def test_erm_fbopt(): """ erm """ - args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm --trainer=fbopt --nname=alexnet --epos=3" # pylint: disable=line-too-long + args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm --trainer=fbopt --nname=alexnet --epos=3 --no_dump" # pylint: disable=line-too-long with pytest.raises(RuntimeError): utils_test_algo(args) @@ -42,7 +42,7 @@ def test_irm_fbopt(): """ irm """ - args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm --trainer=fbopt_irm --nname=alexnet --epos=3" # pylint: disable=line-too-long + args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm --trainer=fbopt_irm --nname=alexnet --epos=3 --no_dump" # pylint: disable=line-too-long utils_test_algo(args) @@ -50,5 +50,5 @@ def test_forcesetpoint_fbopt(): """ diva """ - args = "--te_d=0 --tr_d 1 2 --task=mnistcolor10 --bs=16 --model=jigen --trainer=fbopt --nname=conv_bn_pool_2 --epos=10 --es=0 --mu_init=0.00001 --coeff_ma_setpoint=0.5 --coeff_ma_output_state=0.99 --force_setpoint_change_once" + args = "--te_d=0 --tr_d 1 2 --task=mnistcolor10 --bs=16 --model=jigen --trainer=fbopt --nname=conv_bn_pool_2 --epos=10 --es=0 --mu_init=0.00001 --coeff_ma_setpoint=0.5 --coeff_ma_output_state=0.99 --force_setpoint_change_once --no_dump" utils_test_algo(args) From e6ac695ebfc38c32c0b9da0e12f1e1209c206b73 Mon Sep 17 00:00:00 2001 From: Xudong Sun Date: Mon, 9 Dec 2024 16:55:19 +0100 Subject: [PATCH 14/17] Update generate_latex_table.py --- scripts/generate_latex_table.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/generate_latex_table.py b/scripts/generate_latex_table.py index caa333aab..ebade0583 100644 --- a/scripts/generate_latex_table.py +++ b/scripts/generate_latex_table.py @@ -10,7 +10,7 @@ def gen_latex_table(raw_df, fname="table_perf.tex", """ aggregate benchmark csv file to generate latex table """ - df_result = raw_df.groupby(group)[str_perf].agg(["mean", "std"]) + df_result = raw_df.groupby(group)[str_perf].agg(["mean", "std", "count"]) latex_table = df_result.to_latex(float_format="%.3f") str_table = df_result.to_string() print(str_table) From b5cc60a76d68c70d7d75e973a44c8c31981aa788 Mon Sep 17 00:00:00 2001 From: smilesun Date: Mon, 9 Dec 2024 17:06:08 +0100 Subject: [PATCH 15/17] . --- scripts/sh_genplot.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/sh_genplot.sh b/scripts/sh_genplot.sh index a1b1ecfad..9906e5f2e 100755 --- a/scripts/sh_genplot.sh +++ b/scripts/sh_genplot.sh @@ -1,2 +1,3 @@ mkdir $2 -python main_out.py --gen_plots $1 --outp_dir $2 +merge_csvs.sh +python main_out.py --gen_plots merged_data.csv --outp_dir $2 From 163ca911f4b7803f63404ccf376ed18e77108de7 Mon Sep 17 00:00:00 2001 From: Xudong Sun Date: Mon, 9 Dec 2024 17:11:18 +0100 Subject: [PATCH 16/17] Update sh_genplot.sh --- scripts/sh_genplot.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/sh_genplot.sh b/scripts/sh_genplot.sh index 9906e5f2e..39f71f16d 100755 --- a/scripts/sh_genplot.sh +++ b/scripts/sh_genplot.sh @@ -1,3 +1,3 @@ -mkdir $2 -merge_csvs.sh -python main_out.py --gen_plots merged_data.csv --outp_dir $2 +# mkdir $2 +sh scripts/merge_csvs.sh $1 +python main_out.py --gen_plots merged_data.csv --outp_dir partial_agg_plots \ No newline at end of file From ae907b29dc6b9f8fc2695ddf3edf601b10ff079b Mon Sep 17 00:00:00 2001 From: smilesun Date: Tue, 10 Dec 2024 10:32:06 +0100 Subject: [PATCH 17/17] unit test no_dump --- tests/test_fbopt_irm.py | 2 +- tests/test_fbopt_setpoint_rewind.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_fbopt_irm.py b/tests/test_fbopt_irm.py index fb1f109f6..10cfbcb53 100644 --- a/tests/test_fbopt_irm.py +++ b/tests/test_fbopt_irm.py @@ -10,5 +10,5 @@ def test_mhof_irm(): """ args = "--te_d=0 --task=mnistcolor10 --model=erm \ --trainer=fbopt_irm --nname=conv_bn_pool_2 \ - --k_i_gain_ratio=0.5" + --k_i_gain_ratio=0.5 --no_dump" utils_test_algo(args) diff --git a/tests/test_fbopt_setpoint_rewind.py b/tests/test_fbopt_setpoint_rewind.py index 3ce13079b..3fcc8660e 100644 --- a/tests/test_fbopt_setpoint_rewind.py +++ b/tests/test_fbopt_setpoint_rewind.py @@ -8,5 +8,5 @@ def test_jigen_fbopt(): """ jigen """ - args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=jigen --trainer=fbopt --nname=alexnet --epos=300 --setpoint_rewind" + args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=jigen --trainer=fbopt --nname=alexnet --epos=300 --setpoint_rewind --no_dump" utils_test_algo(args)