diff --git a/docs/doc_usage_cmd.md b/docs/doc_usage_cmd.md
index 7230fec25..6ee916f06 100644
--- a/docs/doc_usage_cmd.md
+++ b/docs/doc_usage_cmd.md
@@ -21,7 +21,34 @@ To run DomainLab, the minimum necessary parameters are:
 ### Advanced Configuration
 
 - **Learning Rate (`--lr`):** Set the training learning rate.
-- **Regularization (`--gamma_reg`):** Weight of regularization loss.
+- **Regularization (`--gamma_reg`):** Sets the weight of the regularization
+  loss. This parameter can be configured either as
+  a single value applied to all models and trainers,
+  or as a dictionary specifying different
+  weights for different models and trainers.
+
+  - **Command Line Usage:**
+    - For a single value: `python script.py --gamma_reg=0.1`
+    - For multiple values: `python script.py --gamma_reg='default=0.1,dann=0.05,jigen=0.2'`
+
+  - **YAML Configuration:**
+    - For a single value:
+
+      ```yaml
+      gamma_reg: 0.1
+      ```
+
+    - For different values:
+
+      ```yaml
+      gamma_reg:
+        dann: 0.05
+        dial: 0.2
+        default: 0.1 # value for every other model and trainer
+      ```
+
+The `gamma_reg` parameter is available for all trainers, as well as for the
+dann and jigen models.
+
 - **Early Stopping (`--es`):** Steps for early stopping.
 - **Random Seed (`--seed`):** Seed for reproducibility.
 - **CUDA Options (`--nocu`, `--device`):** Configure CUDA usage and device settings.
diff --git a/domainlab/algos/builder_dann.py b/domainlab/algos/builder_dann.py
index 65b26a62a..73b373c15 100644
--- a/domainlab/algos/builder_dann.py
+++ b/domainlab/algos/builder_dann.py
@@ -13,6 +13,7 @@
 from domainlab.compos.zoo_nn import FeatExtractNNBuilderChainNodeGetter
 from domainlab.models.model_dann import mk_dann
 from domainlab.utils.utils_cuda import get_device
+from domainlab.utils.hyperparameter_retrieval import get_gamma_reg
 
 
 class NodeAlgoBuilderDANN(NodeAlgoBuilder):
@@ -55,7 +56,7 @@ def init_business(self, exp):
         model = mk_dann(list_str_y=task.list_str_y,
                         net_classifier=net_classifier)(
             list_d_tr=task.list_domain_tr,
-            alpha=args.gamma_reg,
+            alpha=get_gamma_reg(args, 'dann'),
             net_encoder=net_encoder,
             net_discriminator=net_discriminator,
             builder=self)
diff --git a/domainlab/algos/builder_jigen1.py b/domainlab/algos/builder_jigen1.py
index e899e32f2..de671affe 100644
--- a/domainlab/algos/builder_jigen1.py
+++ b/domainlab/algos/builder_jigen1.py
@@ -15,6 +15,7 @@
 from domainlab.dsets.utils_wrapdset_patches import WrapDsetPatches
 from domainlab.models.model_jigen import mk_jigen
 from domainlab.utils.utils_cuda import get_device
+from domainlab.utils.hyperparameter_retrieval import get_gamma_reg
 
 
 class NodeAlgoBuilderJiGen(NodeAlgoBuilder):
@@ -56,7 +57,7 @@ def init_business(self, exp):
         model = mk_jigen(
             list_str_y=task.list_str_y,
             net_classifier=net_classifier)(
-            coeff_reg=args.gamma_reg,
+            coeff_reg=get_gamma_reg(args, 'jigen'),
             net_encoder=net_encoder,
             net_classifier_permutation=net_classifier_perm,
             n_perm=args.nperm,
diff --git a/domainlab/algos/trainers/a_trainer.py b/domainlab/algos/trainers/a_trainer.py
index d9bda1513..f5fce7c16 100644
--- a/domainlab/algos/trainers/a_trainer.py
+++ b/domainlab/algos/trainers/a_trainer.py
@@ -258,3 +258,11 @@ def dset_decoration_args_algo(self, args, ddset):
         if self._decoratee is not None:
             return self._decoratee.dset_decoration_args_algo(args, ddset)
         return ddset
+
+    def print_parameters(self):
+        """
+        Function to print all parameters of the object.
+        Can be used to print the parameters of any child class.
+        """
+        params = vars(self)
+        print(f"Parameters of {type(self).__name__}: {params}")
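The `print_parameters` helper added here (and mirrored on `AModel` below) simply dumps `vars(self)`. A minimal sketch of the behaviour on a hypothetical stand-in class, not part of the patch:

```python
# Dummy is a hypothetical stand-in; only the method body matches the patch.
class Dummy:
    def __init__(self):
        self.lr = 1e-4
        self.gamma_reg = {"default": 0.1, "dann": 0.05}

    def print_parameters(self):
        """Print all instance attributes, as in the patched classes."""
        params = vars(self)
        print(f"Parameters of {type(self).__name__}: {params}")

Dummy().print_parameters()
# Parameters of Dummy: {'lr': 0.0001, 'gamma_reg': {'default': 0.1, 'dann': 0.05}}
```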
diff --git a/domainlab/algos/trainers/train_dial.py b/domainlab/algos/trainers/train_dial.py
index 75a5e34f0..4fe700f45 100644
--- a/domainlab/algos/trainers/train_dial.py
+++ b/domainlab/algos/trainers/train_dial.py
@@ -5,6 +5,7 @@
 from torch.autograd import Variable
 
 from domainlab.algos.trainers.train_basic import TrainerBasic
+from domainlab.utils.hyperparameter_retrieval import get_gamma_reg
 
 
 class TrainerDIAL(TrainerBasic):
@@ -49,4 +50,4 @@ def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
         tensor_x_adv = self.gen_adversarial(self.device, tensor_x, tensor_y)
         tensor_x_batch_adv_no_grad = Variable(tensor_x_adv, requires_grad=False)
         loss_dial = self.model.cal_task_loss(tensor_x_batch_adv_no_grad, tensor_y)
-        return [loss_dial], [self.aconf.gamma_reg]
+        return [loss_dial], [get_gamma_reg(self.aconf, self.name)]
diff --git a/domainlab/algos/trainers/train_fishr.py b/domainlab/algos/trainers/train_fishr.py
index 1a11e3780..3580a0721 100644
--- a/domainlab/algos/trainers/train_fishr.py
+++ b/domainlab/algos/trainers/train_fishr.py
@@ -13,6 +13,7 @@
     backpack = None
 
 from domainlab.algos.trainers.train_basic import TrainerBasic
+from domainlab.utils.hyperparameter_retrieval import get_gamma_reg
 
 
 class TrainerFishr(TrainerBasic):
@@ -39,7 +40,7 @@ def tr_epoch(self, epoch):
             dict_layerwise_var_var_grads_sum = \
                 {key: val.sum() for key, val in dict_layerwise_var_var_grads.items()}
             loss_fishr = sum(dict_layerwise_var_var_grads_sum.values())
-            loss = sum(list_loss_erm) + self.aconf.gamma_reg * loss_fishr
+            loss = sum(list_loss_erm) + get_gamma_reg(self.aconf, self.name) * loss_fishr
             loss.backward()
             self.optimizer.step()
             self.epo_loss_tr += loss.detach().item()
diff --git a/domainlab/algos/trainers/train_matchdg.py b/domainlab/algos/trainers/train_matchdg.py
index 6a3edd996..72c14ab83 100644
--- a/domainlab/algos/trainers/train_matchdg.py
+++ b/domainlab/algos/trainers/train_matchdg.py
@@ -13,6 +13,7 @@
 )
 from domainlab.tasks.utils_task_dset import DsetIndDecorator4XYD
 from domainlab.utils.logger import Logger
+from domainlab.utils.hyperparameter_retrieval import get_gamma_reg
 
 
 class TrainerMatchDG(AbstractTrainer):
@@ -36,7 +37,7 @@ def init_business(
         self.base_domain_size = get_base_domain_size4match_dg(self.task)
         self.epo_loss_tr = 0
         self.flag_erm = flag_erm
-        self.lambda_ctr = self.aconf.gamma_reg
+        self.lambda_ctr = get_gamma_reg(aconf, self.name)
         self.mk_match_tensor(epoch=0)
         self.flag_match_tensor_sweep_over = False
         self.tuple_tensor_ref_domain2each_y = None
diff --git a/domainlab/algos/trainers/train_mldg.py b/domainlab/algos/trainers/train_mldg.py
index 90318286c..e91310adf 100644
--- a/domainlab/algos/trainers/train_mldg.py
+++ b/domainlab/algos/trainers/train_mldg.py
@@ -10,6 +10,7 @@
 from domainlab.algos.trainers.train_basic import TrainerBasic
 from domainlab.tasks.utils_task import mk_loader
 from domainlab.tasks.utils_task_dset import DsetZip
+from domainlab.utils.hyperparameter_retrieval import get_gamma_reg
 
 
 class TrainerMLDG(AbstractTrainer):
@@ -108,7 +109,7 @@ def tr_epoch(self, epoch):
             loss = (
                 loss_source_task.sum()
                 + source_reg_tr.sum()
-                + self.aconf.gamma_reg * loss_look_forward.sum()
+                + get_gamma_reg(self.aconf, self.name) * loss_look_forward.sum()
             )
 
             # loss.backward()
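Each trainer above passes `self.name` into the lookup (e.g. `"dial"` for `TrainerDIAL`, following the trainer naming convention), so one `--gamma_reg` dict can address every trainer individually. A minimal sketch of that lookup, with a hand-built namespace standing in for `self.aconf`:

```python
from types import SimpleNamespace

from domainlab.utils.hyperparameter_retrieval import get_gamma_reg

# hand-built stand-in for the parsed configuration a trainer sees as self.aconf
aconf = SimpleNamespace(gamma_reg={"default": 0.1, "dial": 0.2, "fishr": 0.01})

assert get_gamma_reg(aconf, "dial") == 0.2  # TrainerDIAL gets its own weight
assert get_gamma_reg(aconf, "mldg") == 0.1  # no entry, falls back to 'default'
```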
@@ -12,6 +12,37 @@
 from domainlab.models.args_vae import add_args2parser_vae
 from domainlab.utils.logger import Logger
 
 
+class ParseValuesOrKeyValuePairs(argparse.Action):
+    """Argparse action accepting either a single float value or comma-separated key=value pairs"""
+
+    def __call__(self, parser: argparse.ArgumentParser,
+                 namespace: argparse.Namespace, values: str, option_string: str = None):
+        """
+        Handle parsing of key=value pairs, or a single float value instead.
+
+        Args:
+            parser (argparse.ArgumentParser): The ArgumentParser object.
+            namespace (argparse.Namespace): The namespace object to store parsed values.
+            values (str): The string containing key=value pairs or a single float value.
+            option_string (str, optional): The option string that triggered this action (unused).
+
+        Raises:
+            ValueError: If the values cannot be parsed to float.
+        """
+        if "=" in values:
+            my_dict = {}
+            for kv in values.split(","):
+                k, v = kv.split("=")
+                try:
+                    my_dict[k.strip()] = float(v.strip())
+                except ValueError:
+                    raise ValueError(f"Invalid value in key-value pair: '{kv}', must be float")
+            setattr(namespace, self.dest, my_dict)
+        else:
+            try:
+                setattr(namespace, self.dest, float(values))
+            except ValueError:
+                raise ValueError(f"Invalid value for {self.dest}: '{values}', must be float")
 
 
 def mk_parser_main():
     """
@@ -31,7 +62,13 @@ def mk_parser_main():
     parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
 
     parser.add_argument(
-        "--gamma_reg", type=float, default=0.1, help="weight of regularization loss"
+        "--gamma_reg",
+        default=0.1,
+        help=r"weight of regularization loss in the form of $$\ell(\cdot) + \mu \times R(\cdot)$$; "
+             "can be specified per model as 'default=3.0,dann=1.0,jigen=2.0', "
+             "where 'default' refers to the gamma used by the trainers; "
+             r"note diva is instead implemented as $$\mu \times \ell(\cdot) + R(\cdot)$$, so diva does not use gamma_reg",
+        action=ParseValuesOrKeyValuePairs
     )
 
     parser.add_argument("--es", type=int, default=1, help="early stop steps")
diff --git a/domainlab/models/a_model.py b/domainlab/models/a_model.py
index beb867167..71b0db334 100644
--- a/domainlab/models/a_model.py
+++ b/domainlab/models/a_model.py
@@ -178,3 +178,35 @@ def dset_decoration_args_algo(self, args, ddset):
         if self._decoratee is not None:
             return self._decoratee.dset_decoration_args_algo(args, ddset)
         return ddset
+
+    @property
+    def p_na_prefix(self):
+        """
+        common prefix for Models
+        """
+        return "Model"
+
+    @property
+    def name(self):
+        """
+        get the name of the algorithm
+        """
+        na_prefix = self.p_na_prefix
+        len_prefix = len(na_prefix)
+        na_class = type(self).__name__
+        if na_class[:len_prefix] != na_prefix:
+            raise RuntimeError(
+                "Model builder node class must start with ",
+                na_prefix,
+                "the current class is named: ",
+                na_class,
+            )
+        return type(self).__name__[len_prefix:].lower()
+
+    def print_parameters(self):
+        """
+        Function to print all parameters of the object.
+        Can be used to print the parameters of every child class.
+        """
+        params = vars(self)
+        print(f"Parameters of {type(self).__name__}: {params}")
diff --git a/domainlab/models/model_dann.py b/domainlab/models/model_dann.py
index 2abd3feda..59a619e7e 100644
--- a/domainlab/models/model_dann.py
+++ b/domainlab/models/model_dann.py
@@ -85,13 +85,15 @@ def hyper_update(self, epoch, fun_scheduler):
         dict_rst = fun_scheduler(
             epoch
         )  # the __call__ method of hyperparameter scheduler
-        self.alpha = dict_rst["alpha"]
+        self.alpha = dict_rst[self.name + "_alpha"]
 
     def hyper_init(self, functor_scheduler):
         """hyper_init.
 
         :param functor_scheduler:
         """
-        return functor_scheduler(trainer=None, alpha=self.alpha)
+        parameters = {}
+        parameters[self.name + "_alpha"] = self.alpha
+        return functor_scheduler(trainer=None, **parameters)
 
     def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others):
         _ = others
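The new `name` property is what makes the prefixed scheduler keys above work: it strips the `Model` prefix from the class name, so `ModelDann` publishes `dann_alpha` instead of a bare `alpha`, avoiding key collisions when models are combined (e.g. `dann_diva`). A self-contained sketch of the derivation, using a dummy class rather than the real DomainLab one:

```python
# Schematic stand-in mirroring AModel.name; ModelDann here is a dummy class.
class ModelDann:
    @property
    def name(self):
        return type(self).__name__[len("Model"):].lower()

model = ModelDann()
print(model.name)             # dann
print(model.name + "_alpha")  # dann_alpha -- the key hyper_update reads back
```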
diff --git a/domainlab/models/model_diva.py b/domainlab/models/model_diva.py
index b67272784..362c14518 100644
--- a/domainlab/models/model_diva.py
+++ b/domainlab/models/model_diva.py
@@ -98,9 +98,9 @@ def hyper_update(self, epoch, fun_scheduler):
         :param fun_scheduler:
         """
         dict_rst = fun_scheduler(epoch)
-        self.beta_d = dict_rst["beta_d"]
-        self.beta_y = dict_rst["beta_y"]
-        self.beta_x = dict_rst["beta_x"]
+        self.beta_d = dict_rst[self.name + "_beta_d"]
+        self.beta_y = dict_rst[self.name + "_beta_y"]
+        self.beta_x = dict_rst[self.name + "_beta_x"]
 
     def hyper_init(self, functor_scheduler):
         """
@@ -108,8 +108,12 @@ def hyper_init(self, functor_scheduler):
 
         :param functor_scheduler: the class name of the scheduler
         """
+        parameters = {}
+        parameters[self.name + "_beta_d"] = self.beta_d
+        parameters[self.name + "_beta_y"] = self.beta_y
+        parameters[self.name + "_beta_x"] = self.beta_x
         return functor_scheduler(
-            trainer=None, beta_d=self.beta_d, beta_y=self.beta_y, beta_x=self.beta_x
+            trainer=None, **parameters
        )
 
     def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
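`hyper_init` and `hyper_update` must agree exactly on these prefixed key names, since one publishes what the other reads back after warmup. A round-trip sketch with a plain dict standing in for the scheduler (the beta values are illustrative only):

```python
# Round-trip sketch: publish prefixed keys, then read the same keys back.
name = "diva"  # what ModelDiva.name would return
parameters = {}
parameters[name + "_beta_d"] = 1.0
parameters[name + "_beta_y"] = 7e5
parameters[name + "_beta_x"] = 1.0

dict_rst = dict(parameters)  # stands in for the scheduler's warmed-up output
beta_d = dict_rst[name + "_beta_d"]
beta_y = dict_rst[name + "_beta_y"]  # each attribute reads its own key
beta_x = dict_rst[name + "_beta_x"]
```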
diff --git a/domainlab/models/model_hduva.py b/domainlab/models/model_hduva.py
index 61411982e..5c7bb290d 100644
--- a/domainlab/models/model_hduva.py
+++ b/domainlab/models/model_hduva.py
@@ -66,10 +66,10 @@ def hyper_update(self, epoch, fun_scheduler):
         dict_rst = fun_scheduler(
             epoch
         )  # the __call__ function of hyper-para-scheduler object
-        self.beta_d = dict_rst["beta_d"]
-        self.beta_y = dict_rst["beta_y"]
-        self.beta_x = dict_rst["beta_x"]
-        self.beta_t = dict_rst["beta_t"]
+        self.beta_d = dict_rst[self.name + "_beta_d"]
+        self.beta_y = dict_rst[self.name + "_beta_y"]
+        self.beta_x = dict_rst[self.name + "_beta_x"]
+        self.beta_t = dict_rst[self.name + "_beta_t"]
 
     def hyper_init(self, functor_scheduler):
         """hyper_init.
@@ -78,12 +78,13 @@ def hyper_init(self, functor_scheduler):
         # calling the constructor of the hyper-parameter-scheduler class, so that this scheduler
         # class build a dictionary {"beta_d":self.beta_d, "beta_y":self.beta_y}
         # constructor signature is def __init__(self, **kwargs):
+        parameters = {}
+        parameters[self.name + "_beta_d"] = self.beta_d
+        parameters[self.name + "_beta_y"] = self.beta_y
+        parameters[self.name + "_beta_x"] = self.beta_x
+        parameters[self.name + "_beta_t"] = self.beta_t
         return functor_scheduler(
-            trainer=None,
-            beta_d=self.beta_d,
-            beta_y=self.beta_y,
-            beta_x=self.beta_x,
-            beta_t=self.beta_t,
+            trainer=None, **parameters
         )
 
     @store_args
diff --git a/domainlab/utils/hyperparameter_retrieval.py b/domainlab/utils/hyperparameter_retrieval.py
new file mode 100644
index 000000000..5a1960e2e
--- /dev/null
+++ b/domainlab/utils/hyperparameter_retrieval.py
@@ -0,0 +1,18 @@
+"""
+retrieval of hyperparameters
+"""
+
+def get_gamma_reg(args, component_name):
+    """
+    Retrieve either the shared gamma_reg value, or the individual one for the specified component
+    """
+    gamma_reg = args.gamma_reg
+    if isinstance(gamma_reg, dict):
+        if component_name in gamma_reg:
+            return gamma_reg[component_name]
+        if 'default' in gamma_reg:
+            return gamma_reg['default']
+        raise ValueError("""If a gamma_reg dict is specified,
+                         but it does not cover every model and trainer,
+                         a default value must be included.""")
+    return gamma_reg  # return the single value if it is not a dictionary
diff --git a/examples/conf/vlcs_diva_mldg_dial.yaml b/examples/conf/vlcs_diva_mldg_dial.yaml
index 5ed10d0c5..e24cfd22e 100644
--- a/examples/conf/vlcs_diva_mldg_dial.yaml
+++ b/examples/conf/vlcs_diva_mldg_dial.yaml
@@ -5,7 +5,10 @@ val_threshold: 0.8 # threashold before which training does not
 model: dann_diva # combine model DANN with DIVA
 epos: 1 # number of epochs
 trainer: mldg_dial # combine trainer MLDG and DIAL
-gamma_reg: 1.0 # hyperparameter of DANN
+gamma_reg:
+  default: 1.0
+  dann: 1.5
+# in this case, mldg and dial get the default gamma_reg value 1.0
 gamma_y: 700000.0 # hyperparameter of diva
 gamma_d: 100000.0 # hyperparameter of diva
 npath: examples/nets/resnet.py # neural network for class classification
diff --git a/run_benchmark_slurm.sh b/run_benchmark_slurm.sh
index 91f316f45..72b8bbd83 100755
--- a/run_benchmark_slurm.sh
+++ b/run_benchmark_slurm.sh
@@ -32,4 +32,4 @@ echo "Number of GPUs: $NUMBER_GPUS"
 echo "Results will be stored in: $results_dir"
 
 # Helmholtz
-snakemake --profile "examples/yaml/slurm" --config yaml_file="$CONFIGFILE" --keep-going --keep-incomplete --notemp --cores 3 -s "domainlab/exp_protocol/benchmark.smk" --configfile "$CONFIGFILE" --config output_dir="$results_dir" 2>&1 | tee "$logfile"
+snakemake --profile "examples/yaml/slurm" --config yaml_file="$CONFIGFILE" --keep-going --keep-incomplete --notemp --cores 3 -s "domainlab/exp_protocol/benchmark.smk" --configfile "$CONFIGFILE" --config output_dir="$results_dir" 2>&1 | tee "$logfile"
\ No newline at end of file
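The YAML dict form used in `vlcs_diva_mldg_dial.yaml` above deserializes to exactly the structure `get_gamma_reg` expects; a quick sketch (assuming PyYAML, which DomainLab already uses to read its configs):

```python
import yaml  # PyYAML, assumed available since DomainLab reads YAML configs

snippet = """
gamma_reg:
  default: 1.0
  dann: 1.5
"""
conf = yaml.safe_load(snippet)
print(conf["gamma_reg"])  # {'default': 1.0, 'dann': 1.5}
# mldg and dial have no entry of their own, so both resolve to 1.0
```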
diff --git a/tests/test_a_model.py b/tests/test_a_model.py
new file mode 100644
index 000000000..fac060f06
--- /dev/null
+++ b/tests/test_a_model.py
@@ -0,0 +1,69 @@
+"""
+Test AModel functionality
+"""
+import pytest
+
+from domainlab.models.a_model import AModel
+
+
+class ModelTest(AModel):
+    """
+    A test model class conforming to the model naming convention
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.test_param = 42
+
+    def cal_task_loss(self, tensor_x, tensor_y):
+        return 0
+
+    def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
+        return 0
+
+    @property
+    def metric4msel(self):
+        return ""
+
+
+class InvalidTest(AModel):
+    """
+    A test model class that does not conform to the "Model" prefix naming convention
+    """
+
+    def cal_task_loss(self, tensor_x, tensor_y):
+        return 0
+
+    def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
+        return 0
+
+    @property
+    def metric4msel(self):
+        return ""
+
+
+def test_model_name_valid():
+    """
+    Test a valid model name
+    """
+    model = ModelTest()
+    assert model.name == "test", f"Expected 'test' but got '{model.name}'"
+
+
+def test_model_name_invalid():
+    """
+    Test an invalid model name
+    """
+    model = InvalidTest()
+    with pytest.raises(RuntimeError, match="Model builder node class must start with"):
+        _ = model.name
+
+
+def test_print_parameters(capsys):
+    """
+    Test the printing of parameters
+    """
+    model = ModelTest()
+    model.print_parameters()
+    captured = capsys.readouterr()
+    assert "Parameters of ModelTest:" in captured.out
+    assert "'test_param': 42" in captured.out
\ No newline at end of file
diff --git a/tests/test_a_trainer.py b/tests/test_a_trainer.py
new file mode 100644
index 000000000..d4e073ae3
--- /dev/null
+++ b/tests/test_a_trainer.py
@@ -0,0 +1,34 @@
+"""
+Test AbstractTrainer functionality
+"""
+
+from domainlab.algos.trainers.a_trainer import AbstractTrainer
+
+
+class TrainerTest(AbstractTrainer):
+    """
+    A test trainer class conforming to the trainer naming convention
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.test_param = 42
+
+    def tr_epoch(self, epoch):
+        """
+        :param epoch:
+        """
+
+    def before_tr(self):
+        """
+        before training, probe model performance
+        """
+
+
+def test_print_parameters(capsys):
+    """
+    Test the printing of parameters
+    """
+    trainer = TrainerTest()
+    trainer.print_parameters()
+    captured = capsys.readouterr()
+    assert "Parameters of TrainerTest:" in captured.out
+    assert "'test_param': 42" in captured.out
parser.add_argument("--keypair", action=ParseValuesOrKeyValuePairs) + with pytest.raises(ValueError): + parser.parse_args(["--keypair", "value1=1,value2=invalid"]) diff --git a/tests/test_hyperparameter_retrieval.py b/tests/test_hyperparameter_retrieval.py new file mode 100644 index 000000000..a9a141329 --- /dev/null +++ b/tests/test_hyperparameter_retrieval.py @@ -0,0 +1,44 @@ +""" +unit test for hyperparameter parsing +""" +import pytest +from domainlab.arg_parser import mk_parser_main +from domainlab.utils.hyperparameter_retrieval import get_gamma_reg + +def test_store_dict_key_pair_single_value(): + """Test to parse a single gamma_reg parameter""" + parser = mk_parser_main() + args = parser.parse_args(['--gamma_reg', '0.5']) + assert args.gamma_reg == 0.5 + +def test_store_dict_key_pair_dict_value(): + """Test to parse a dict for the gamma_reg""" + parser = mk_parser_main() + args = parser.parse_args(['--gamma_reg', 'dann=1.0,jigen=2.0']) + assert args.gamma_reg == {'dann': 1.0, 'jigen': 2.0} + +def test_get_gamma_reg_single_value(): + """Test to retrieve a single gamma_reg parameter which is applied to all objects""" + parser = mk_parser_main() + args = parser.parse_args(['--gamma_reg', '0.5']) + assert get_gamma_reg(args, 'dann') == 0.5 + +def test_get_gamma_reg_dict_value(): + """Test to retrieve a dict of gamma_reg parameters for different objects""" + parser = mk_parser_main() + args = parser.parse_args(['--gamma_reg', 'default=5.0,dann=1.0,jigen=2.0']) + assert get_gamma_reg(args, 'dann') == 1.0 + assert get_gamma_reg(args, 'jigen') == 2.0 + assert get_gamma_reg(args, 'nonexistent') == 5.0 # if we implement other + # model/trainers, + # since not specified in command line arguments, the new model/trainer + # called "nonexistent" should + # get the default value 5.0. + +def test_exception(): + """Test to not specify a default value""" + parser = mk_parser_main() + args = parser.parse_args(['--gamma_reg', 'dann=1.0']) + + with pytest.raises(ValueError, match="If a gamma_reg dict is specified"): + get_gamma_reg(args, 'jigen')