From 3f448102e5dd33d5cace41c8609d6e42eb8a248c Mon Sep 17 00:00:00 2001 From: Xudong Sun Date: Sun, 8 Oct 2023 16:06:25 +0200 Subject: [PATCH 01/10] Update fbopt_setpoint_ada.py --- domainlab/algos/trainers/fbopt_setpoint_ada.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/domainlab/algos/trainers/fbopt_setpoint_ada.py b/domainlab/algos/trainers/fbopt_setpoint_ada.py index 1361708d5..bc9e43cb8 100644 --- a/domainlab/algos/trainers/fbopt_setpoint_ada.py +++ b/domainlab/algos/trainers/fbopt_setpoint_ada.py @@ -4,7 +4,12 @@ import torch from domainlab.utils.logger import Logger +def list_add(list1, list2): + return [a + b for a, b in zip(list1, list2)] +def list_multiply(list1, coeff): + return [ele * coeff for ele in list1] + def is_less_list_any(list1, list2): """ judge if one list is less than the other @@ -36,9 +41,10 @@ def __init__(self, state=None, args=None): state = DominateAnyComponent() self.transition_to(state) self.ma_epo_reg_loss = None - self.state_epo_reg_loss = None self.coeff_ma = 0.5 # FIXME - self.state_task_loss = None + self.state_task_loss = 0.0 + self.state_epo_reg_loss = 0.0 + self.coeff_ma_output = 0.9 # FIXME # initial value will be set via trainer self.setpoint4R = None self.setpoint4ell = None @@ -65,7 +71,7 @@ def observe(self, epo_reg_loss, epo_task_loss): read current epo_reg_loss continuously FIXME: setpoint should also be able to be eliviated """ - self.state_epo_reg_loss = epo_reg_loss + self.state_epo_reg_loss = list_add(list_multiply(epo_reg_loss, self.coeff_ma_output), list_multiply(self.state_epo_reg_loss, 1 - self.coeff_ma_output)) self.state_task_loss = epo_task_loss if self.state_updater.update_setpoint(): logger = Logger.get_logger(logger_name='main_out_logger', loglevel="INFO") From 1276cd92cc7c74cebf83e358f3ba871c8f588982 Mon Sep 17 00:00:00 2001 From: Xudong Sun Date: Sun, 8 Oct 2023 17:15:00 +0200 Subject: [PATCH 02/10] Update fbopt_setpoint_ada.py --- domainlab/algos/trainers/fbopt_setpoint_ada.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/domainlab/algos/trainers/fbopt_setpoint_ada.py b/domainlab/algos/trainers/fbopt_setpoint_ada.py index bc9e43cb8..6660e9a1d 100644 --- a/domainlab/algos/trainers/fbopt_setpoint_ada.py +++ b/domainlab/algos/trainers/fbopt_setpoint_ada.py @@ -43,7 +43,7 @@ def __init__(self, state=None, args=None): self.ma_epo_reg_loss = None self.coeff_ma = 0.5 # FIXME self.state_task_loss = 0.0 - self.state_epo_reg_loss = 0.0 + self.state_epo_reg_loss = [0.0 for _ in range(10)] # FIXME self.coeff_ma_output = 0.9 # FIXME # initial value will be set via trainer self.setpoint4R = None From 3d08ba3b09259e44b92ea73817a8e069fb9e46f4 Mon Sep 17 00:00:00 2001 From: Xudong Sun Date: Sun, 8 Oct 2023 19:12:41 +0200 Subject: [PATCH 03/10] Update fbopt_setpoint_ada.py --- domainlab/algos/trainers/fbopt_setpoint_ada.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/domainlab/algos/trainers/fbopt_setpoint_ada.py b/domainlab/algos/trainers/fbopt_setpoint_ada.py index 6660e9a1d..e7dede900 100644 --- a/domainlab/algos/trainers/fbopt_setpoint_ada.py +++ b/domainlab/algos/trainers/fbopt_setpoint_ada.py @@ -44,7 +44,7 @@ def __init__(self, state=None, args=None): self.coeff_ma = 0.5 # FIXME self.state_task_loss = 0.0 self.state_epo_reg_loss = [0.0 for _ in range(10)] # FIXME - self.coeff_ma_output = 0.9 # FIXME + self.coeff_ma_output = args.coeff_ma_output_state # initial value will be set via trainer self.setpoint4R = None self.setpoint4ell = None From 6d11c7b16b5d117bea256398d953cf40510c4604 Mon Sep 17 00:00:00 2001 From: Xudong Sun Date: Sun, 8 Oct 2023 19:14:04 +0200 Subject: [PATCH 04/10] Update args_fbopt.py --- domainlab/algos/trainers/args_fbopt.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/domainlab/algos/trainers/args_fbopt.py b/domainlab/algos/trainers/args_fbopt.py index a84969efd..2d05159a6 100644 --- a/domainlab/algos/trainers/args_fbopt.py +++ b/domainlab/algos/trainers/args_fbopt.py @@ -22,6 +22,10 @@ def add_args2parser_fbopt(parser): parser.add_argument('--coeff_ma', type=float, default=0.5, help='exponential moving average') + + parser.add_argument('--coeff_ma_output_state', type=float, default=0.9, + help='setpoint output as state exponential moving average') + parser.add_argument('--exp_shoulder_clip', type=float, default=10, help='clip before exponential operation') From 26c365884c94b7ca47e5bd591729bee4673f0ffa Mon Sep 17 00:00:00 2001 From: smilesun Date: Mon, 9 Oct 2023 11:31:35 +0200 Subject: [PATCH 05/10] . --- run_fbopt_mnist.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/run_fbopt_mnist.sh b/run_fbopt_mnist.sh index 212244e1d..86940605f 100644 --- a/run_fbopt_mnist.sh +++ b/run_fbopt_mnist.sh @@ -3,4 +3,4 @@ # although garbage collector has been explicitly called, sometimes there is still CUDA out of memory error # so it is better not to use GPU to do the pytest to ensure every time there is no CUDA out of memory error occuring # pytest -s tests/test_fbopt.py -python main_out.py --te_d=1 --tr_d 0 3 --task=mnistcolor10 --bs=16 --aname=jigen --trainer=fbopt --nname=conv_bn_pool_2 --epos=200 --es=100 --mu_init=0.00001 +python main_out.py --te_d=1 --tr_d 0 3 --task=mnistcolor10 --bs=16 --aname=jigen --trainer=fbopt --nname=conv_bn_pool_2 --epos=500 --es=50 --mu_init=0.00001 From 643abde7d0463a5f82571d8b9942231c575b444c Mon Sep 17 00:00:00 2001 From: smilesun Date: Mon, 9 Oct 2023 13:56:41 +0200 Subject: [PATCH 06/10] ma 0.5 --- run_fbopt_mnist.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/run_fbopt_mnist.sh b/run_fbopt_mnist.sh index 86940605f..53325386b 100644 --- a/run_fbopt_mnist.sh +++ b/run_fbopt_mnist.sh @@ -3,4 +3,4 @@ # although garbage collector has been explicitly called, sometimes there is still CUDA out of memory error # so it is better not to use GPU to do the pytest to ensure every time there is no CUDA out of memory error occuring # pytest -s tests/test_fbopt.py -python main_out.py --te_d=1 --tr_d 0 3 --task=mnistcolor10 --bs=16 --aname=jigen --trainer=fbopt --nname=conv_bn_pool_2 --epos=500 --es=50 --mu_init=0.00001 +python main_out.py --te_d=1 --tr_d 0 3 --task=mnistcolor10 --bs=16 --aname=jigen --trainer=fbopt --nname=conv_bn_pool_2 --epos=500 --es=50 --mu_init=0.00001 --coeff_ma_output_state=0.5 From 688f5de23eae6e90c633487deab7768c2fc59ad3 Mon Sep 17 00:00:00 2001 From: smilesun Date: Mon, 9 Oct 2023 13:59:00 +0200 Subject: [PATCH 07/10] . --- run_fbopt_mnist.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/run_fbopt_mnist.sh b/run_fbopt_mnist.sh index 53325386b..ea43bf77e 100644 --- a/run_fbopt_mnist.sh +++ b/run_fbopt_mnist.sh @@ -3,4 +3,4 @@ # although garbage collector has been explicitly called, sometimes there is still CUDA out of memory error # so it is better not to use GPU to do the pytest to ensure every time there is no CUDA out of memory error occuring # pytest -s tests/test_fbopt.py -python main_out.py --te_d=1 --tr_d 0 3 --task=mnistcolor10 --bs=16 --aname=jigen --trainer=fbopt --nname=conv_bn_pool_2 --epos=500 --es=50 --mu_init=0.00001 --coeff_ma_output_state=0.5 +python main_out.py --te_d=1 --tr_d 0 3 --task=mnistcolor10 --bs=16 --aname=jigen --trainer=fbopt --nname=conv_bn_pool_2 --epos=2000 --es=50 --mu_init=0.00001 --coeff_ma_output_state=0.5 From 03a67201e29cf85d26a24353b0d2635be907b8d2 Mon Sep 17 00:00:00 2001 From: Xudong Sun Date: Mon, 9 Oct 2023 17:55:34 +0200 Subject: [PATCH 08/10] Update args_fbopt.py --- domainlab/algos/trainers/args_fbopt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/domainlab/algos/trainers/args_fbopt.py b/domainlab/algos/trainers/args_fbopt.py index b26123612..bc1d5aac8 100644 --- a/domainlab/algos/trainers/args_fbopt.py +++ b/domainlab/algos/trainers/args_fbopt.py @@ -23,7 +23,7 @@ def add_args2parser_fbopt(parser): parser.add_argument('--coeff_ma', type=float, default=0.5, help='exponential moving average') - parser.add_argument('--coeff_ma_output_state', type=float, default=0.9, + parser.add_argument('--coeff_ma_output_state', type=float, default=0.5, help='setpoint output as state exponential moving average') From 2f6df907311e6f96c2171081d3f304334c751457 Mon Sep 17 00:00:00 2001 From: smilesun Date: Tue, 10 Oct 2023 12:00:14 +0200 Subject: [PATCH 09/10] fix bug 0 drag state down --- domainlab/algos/trainers/fbopt_setpoint_ada.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/domainlab/algos/trainers/fbopt_setpoint_ada.py b/domainlab/algos/trainers/fbopt_setpoint_ada.py index e7dede900..a789a1ccc 100644 --- a/domainlab/algos/trainers/fbopt_setpoint_ada.py +++ b/domainlab/algos/trainers/fbopt_setpoint_ada.py @@ -9,7 +9,7 @@ def list_add(list1, list2): def list_multiply(list1, coeff): return [ele * coeff for ele in list1] - + def is_less_list_any(list1, list2): """ judge if one list is less than the other @@ -71,8 +71,10 @@ def observe(self, epo_reg_loss, epo_task_loss): read current epo_reg_loss continuously FIXME: setpoint should also be able to be eliviated """ - self.state_epo_reg_loss = list_add(list_multiply(epo_reg_loss, self.coeff_ma_output), list_multiply(self.state_epo_reg_loss, 1 - self.coeff_ma_output)) - self.state_task_loss = epo_task_loss + self.state_epo_reg_loss = [self.coeff_ma_output*a + ( 1-self.coeff_ma_output )*b if a != 0.0 else b for a, b in zip(self.state_epo_reg_loss, epo_reg_loss)] + if self.state_task_loss == 0.0: + self.state_task_loss = epo_task_loss + self.state_task_loss = self.coeff_ma_output * self.state_task_loss + (1-self.coeff_ma_output) * epo_task_loss if self.state_updater.update_setpoint(): logger = Logger.get_logger(logger_name='main_out_logger', loglevel="INFO") self.setpoint4R = self.state_epo_reg_loss From 13d32b307363b3cf808e891049f6b8a08df93c33 Mon Sep 17 00:00:00 2001 From: Xudong Sun Date: Tue, 10 Oct 2023 12:16:49 +0200 Subject: [PATCH 10/10] Update README.md --- README.md | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/README.md b/README.md index cff6a3692..23ed163d6 100644 --- a/README.md +++ b/README.md @@ -46,6 +46,23 @@ pip install -r requirements_notorch.txt conda install tensorboard ``` +#### Download PACS + +step 1: + +use the following script to download PACS to your local laptop and upload it to your cluster + +https://github.com/marrlab/DomainLab/blob/fbopt/data/script/download_pacs.py + +step 2: +make a symbolic link following the example script in https://github.com/marrlab/DomainLab/blob/master/sh_pacs.sh + +where `mkdir -p data/pacs` is executed under the repository directory, + +`ln -s /dir/to/yourdata/pacs/raw ./data/pacs/PACS` +will create a symbolic link under the repository directory + + #### Windows installation details To install DomainLab on Windows, please remove the `snakemake` dependency from the `requirements.txt` file.