Merge pull request #806 from marrlab/gamma_reg_collision
master: Fixed hyperparameter collisions
smilesun authored Jul 2, 2024
2 parents faf7589 + bdf84c2 commit ca493e5
Showing 20 changed files with 339 additions and 29 deletions.
29 changes: 28 additions & 1 deletion docs/doc_usage_cmd.md
@@ -21,7 +21,34 @@ To run DomainLab, the minimum necessary parameters are:
### Advanced Configuration

- **Learning Rate (`--lr`):** Set the training learning rate.
- **Regularization (`--gamma_reg`):** Weight of regularization loss.
- **Regularization (`--gamma_reg`):** Sets the weight of the regularization
  loss. This parameter can be configured either as a single value applied
  to all models and trainers, or as a dictionary specifying different
  weights for individual models and trainers.

- **Command Line Usage:**
- For a single value: `python script.py --gamma_reg=0.1`
- For multiple values: `python script.py --gamma_reg='default=0.1,dann=0.05,jigen=0.2'`

- **YAML Configuration:**
- For a single value:

```yaml
gamma_reg: 0.1
```
- For different values:
```yaml
gamma_reg:
  dann: 0.05
  dial: 0.2
  default: 0.1 # value for every other instance
```
  `gamma_reg` is available for the trainers, as well as for the dann and jigen models.
- **Early Stopping (`--es`):** Steps for early stopping.
- **Random Seed (`--seed`):** Seed for reproducibility.
- **CUDA Options (`--nocu`, `--device`):** Configure CUDA usage and device settings.
3 changes: 2 additions & 1 deletion domainlab/algos/builder_dann.py
@@ -13,6 +13,7 @@
from domainlab.compos.zoo_nn import FeatExtractNNBuilderChainNodeGetter
from domainlab.models.model_dann import mk_dann
from domainlab.utils.utils_cuda import get_device
from domainlab.utils.hyperparameter_retrieval import get_gamma_reg


class NodeAlgoBuilderDANN(NodeAlgoBuilder):
@@ -55,7 +56,7 @@ def init_business(self, exp):
model = mk_dann(list_str_y=task.list_str_y,
net_classifier=net_classifier)(
list_d_tr=task.list_domain_tr,
alpha=args.gamma_reg,
alpha=get_gamma_reg(args, 'dann'),
net_encoder=net_encoder,
net_discriminator=net_discriminator,
builder=self)
3 changes: 2 additions & 1 deletion domainlab/algos/builder_jigen1.py
@@ -15,6 +15,7 @@
from domainlab.dsets.utils_wrapdset_patches import WrapDsetPatches
from domainlab.models.model_jigen import mk_jigen
from domainlab.utils.utils_cuda import get_device
from domainlab.utils.hyperparameter_retrieval import get_gamma_reg


class NodeAlgoBuilderJiGen(NodeAlgoBuilder):
@@ -56,7 +57,7 @@ def init_business(self, exp):
model = mk_jigen(
list_str_y=task.list_str_y,
net_classifier=net_classifier)(
coeff_reg=args.gamma_reg,
coeff_reg=get_gamma_reg(args, 'jigen'),
net_encoder=net_encoder,
net_classifier_permutation=net_classifier_perm,
n_perm=args.nperm,
8 changes: 8 additions & 0 deletions domainlab/algos/trainers/a_trainer.py
@@ -258,3 +258,11 @@ def dset_decoration_args_algo(self, args, ddset):
if self._decoratee is not None:
return self._decoratee.dset_decoration_args_algo(args, ddset)
return ddset

def print_parameters(self):
"""
Function to print all parameters of the object.
Can be used to print the parameters of any child class
"""
params = vars(self)
print(f"Parameters of {type(self).__name__}: {params}")
3 changes: 2 additions & 1 deletion domainlab/algos/trainers/train_dial.py
@@ -5,6 +5,7 @@
from torch.autograd import Variable

from domainlab.algos.trainers.train_basic import TrainerBasic
from domainlab.utils.hyperparameter_retrieval import get_gamma_reg


class TrainerDIAL(TrainerBasic):
@@ -49,4 +50,4 @@ def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
tensor_x_adv = self.gen_adversarial(self.device, tensor_x, tensor_y)
tensor_x_batch_adv_no_grad = Variable(tensor_x_adv, requires_grad=False)
loss_dial = self.model.cal_task_loss(tensor_x_batch_adv_no_grad, tensor_y)
return [loss_dial], [self.aconf.gamma_reg]
return [loss_dial], [get_gamma_reg(self.aconf, self.name)]
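A hedged sketch of the assumption behind this return value: the base trainer presumably multiplies each regularization loss by its matching weight and sums the products; the numbers below are illustrative only.

```python
# Hedged sketch (assumed consumption of the return value, not code from this commit):
# each regularization loss is presumably weighted by its matching multiplier and summed.
list_loss_reg = [0.7]   # e.g. loss_dial computed on the adversarial batch
list_mul = [0.1]        # e.g. get_gamma_reg(self.aconf, "dial")
reg_total = sum(loss * mul for loss, mul in zip(list_loss_reg, list_mul))
print(reg_total)  # 0.07
```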
3 changes: 2 additions & 1 deletion domainlab/algos/trainers/train_fishr.py
@@ -13,6 +13,7 @@
backpack = None

from domainlab.algos.trainers.train_basic import TrainerBasic
from domainlab.utils.hyperparameter_retrieval import get_gamma_reg


class TrainerFishr(TrainerBasic):
@@ -39,7 +40,7 @@ def tr_epoch(self, epoch):
dict_layerwise_var_var_grads_sum = \
{key: val.sum() for key, val in dict_layerwise_var_var_grads.items()}
loss_fishr = sum(dict_layerwise_var_var_grads_sum.values())
loss = sum(list_loss_erm) + self.aconf.gamma_reg * loss_fishr
loss = sum(list_loss_erm) + get_gamma_reg(self.aconf, self.name) * loss_fishr
loss.backward()
self.optimizer.step()
self.epo_loss_tr += loss.detach().item()
3 changes: 2 additions & 1 deletion domainlab/algos/trainers/train_matchdg.py
@@ -13,6 +13,7 @@
)
from domainlab.tasks.utils_task_dset import DsetIndDecorator4XYD
from domainlab.utils.logger import Logger
from domainlab.utils.hyperparameter_retrieval import get_gamma_reg


class TrainerMatchDG(AbstractTrainer):
@@ -36,7 +37,7 @@ def init_business(
self.base_domain_size = get_base_domain_size4match_dg(self.task)
self.epo_loss_tr = 0
self.flag_erm = flag_erm
self.lambda_ctr = self.aconf.gamma_reg
self.lambda_ctr = get_gamma_reg(aconf, self.name)
self.mk_match_tensor(epoch=0)
self.flag_match_tensor_sweep_over = False
self.tuple_tensor_ref_domain2each_y = None
3 changes: 2 additions & 1 deletion domainlab/algos/trainers/train_mldg.py
@@ -10,6 +10,7 @@
from domainlab.algos.trainers.train_basic import TrainerBasic
from domainlab.tasks.utils_task import mk_loader
from domainlab.tasks.utils_task_dset import DsetZip
from domainlab.utils.hyperparameter_retrieval import get_gamma_reg


class TrainerMLDG(AbstractTrainer):
@@ -108,7 +109,7 @@ def tr_epoch(self, epoch):
loss = (
loss_source_task.sum()
+ source_reg_tr.sum()
+ self.aconf.gamma_reg * loss_look_forward.sum()
+ get_gamma_reg(self.aconf, self.name) * loss_look_forward.sum()
)
#
loss.backward()
39 changes: 38 additions & 1 deletion domainlab/arg_parser.py
@@ -12,6 +12,37 @@
from domainlab.models.args_vae import add_args2parser_vae
from domainlab.utils.logger import Logger

class ParseValuesOrKeyValuePairs(argparse.Action):
"""Class used for arg parsing where values are provided in a key value format"""

def __call__(self, parser: argparse.ArgumentParser,
namespace: argparse.Namespace, values: str, option_string: str = None):
"""
Handle parsing of key value pairs, or a single value instead
Args:
parser (argparse.ArgumentParser): The ArgumentParser object.
namespace (argparse.Namespace): The namespace object to store parsed values.
values (str): The string containing key=value pairs or a single float value.
option_string (str, optional): The option string that triggered this action (unused).
Raises:
ValueError: If the values cannot be parsed to float.
"""
if "=" in values:
my_dict = {}
for kv in values.split(","):
k, v = kv.split("=")
try:
my_dict[k.strip()] = float(v.strip())
except ValueError:
raise ValueError(f"Invalid value in key-value pair: '{kv}', must be float")
setattr(namespace, self.dest, my_dict)
else:
try:
setattr(namespace, self.dest, float(values))
except ValueError:
raise ValueError(f"Invalid value for {self.dest}: '{values}', must be float")

def mk_parser_main():
"""
@@ -31,7 +62,13 @@ def mk_parser_main():
parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")

parser.add_argument(
"--gamma_reg", type=float, default=0.1, help="weight of regularization loss"
"--gamma_reg",
default=0.1,
help="weight of regularization loss in the form of $$\ell(\cdot) + \mu \times R(\cdot)$$ \
can specify per model as 'default=3.0, dann=1.0,jigen=2.0', where default refer to gamma for trainer \
note diva is implemented $$\ell(\cdot) + \mu \times R(\cdot)$$ \
so diva does not have gamma_reg",
action=ParseValuesOrKeyValuePairs
)

parser.add_argument("--es", type=int, default=1, help="early stop steps")
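A minimal usage sketch (not part of the commit) of how `ParseValuesOrKeyValuePairs` resolves both accepted forms; it assumes the `domainlab` package containing the file above is importable.

```python
import argparse

from domainlab.arg_parser import ParseValuesOrKeyValuePairs

parser = argparse.ArgumentParser()
parser.add_argument("--gamma_reg", default=0.1, action=ParseValuesOrKeyValuePairs)

# single float: stored as a plain float
args = parser.parse_args(["--gamma_reg=0.5"])
print(args.gamma_reg)  # 0.5

# key=value pairs: stored as a dict of floats
args = parser.parse_args(["--gamma_reg=default=0.1,dann=0.05,jigen=0.2"])
print(args.gamma_reg)  # {'default': 0.1, 'dann': 0.05, 'jigen': 0.2}
```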
32 changes: 32 additions & 0 deletions domainlab/models/a_model.py
@@ -178,3 +178,35 @@ def dset_decoration_args_algo(self, args, ddset):
if self._decoratee is not None:
return self._decoratee.dset_decoration_args_algo(args, ddset)
return ddset

@property
def p_na_prefix(self):
"""
common prefix for Models
"""
return "Model"

@property
def name(self):
"""
get the name of the algorithm
"""
na_prefix = self.p_na_prefix
len_prefix = len(na_prefix)
na_class = type(self).__name__
if na_class[:len_prefix] != na_prefix:
raise RuntimeError(
f"model class name must start with '{na_prefix}', "
f"but the current class is named: {na_class}"
)
return type(self).__name__[len_prefix:].lower()

def print_parameters(self):
"""
Function to print all parameters of the object.
Can be used to print the parameters of every child class.
"""
params = vars(self)
print(f"Parameters of {type(self).__name__}: {params}")
6 changes: 4 additions & 2 deletions domainlab/models/model_dann.py
@@ -85,13 +85,15 @@ def hyper_update(self, epoch, fun_scheduler):
dict_rst = fun_scheduler(
epoch
) # the __call__ method of hyperparameter scheduler
self.alpha = dict_rst["alpha"]
self.alpha = dict_rst[self.name + "_alpha"]

def hyper_init(self, functor_scheduler):
"""hyper_init.
:param functor_scheduler:
"""
return functor_scheduler(trainer=None, alpha=self.alpha)
parameters = {}
parameters[self.name + "_alpha"] = self.alpha
return functor_scheduler(trainer=None, **parameters)

def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others):
_ = others
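For context, a hedged sketch of the key collision these prefixed entries prevent when models are combined (for example `dann_diva`); the scheduler and values below are invented for illustration.

```python
# Illustrative only: with model-name prefixes, a combined model can register
# both sets of hyperparameters in one dictionary without key collisions.
parameters = {}
parameters["dann_alpha"] = 1.0    # contributed by the DANN part via hyper_init
parameters["diva_beta_d"] = 1.0   # contributed by the DIVA part via hyper_init
parameters["diva_beta_y"] = 1.0
parameters["diva_beta_x"] = 1.0


def fake_scheduler(epoch, params, total_epochs=10):
    """made-up warm-up scheduler: scales every entry by the epoch ratio"""
    ratio = min(epoch / total_epochs, 1.0)
    return {key: val * ratio for key, val in params.items()}


dict_rst = fake_scheduler(epoch=5, params=parameters)
print(dict_rst["dann_alpha"])  # what hyper_update reads back for the DANN part
```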
12 changes: 8 additions & 4 deletions domainlab/models/model_diva.py
@@ -98,18 +98,22 @@ def hyper_update(self, epoch, fun_scheduler):
:param fun_scheduler:
"""
dict_rst = fun_scheduler(epoch)
self.beta_d = dict_rst["beta_d"]
self.beta_y = dict_rst["beta_y"]
self.beta_x = dict_rst["beta_x"]
self.beta_d = dict_rst[self.name + "_beta_d"]
self.beta_y = dict_rst[self.name + "_beta_y"]
self.beta_x = dict_rst[self.name + "_beta_x"]

def hyper_init(self, functor_scheduler):
"""
initiate a scheduler object via class name and things inside this model
:param functor_scheduler: the class name of the scheduler
"""
parameters = {}
parameters[self.name + "_beta_d"] = self.beta_d
parameters[self.name + "_beta_y"] = self.beta_y
parameters[self.name + "_beta_x"] = self.beta_x
return functor_scheduler(
trainer=None, beta_d=self.beta_d, beta_y=self.beta_y, beta_x=self.beta_x
trainer=None, **parameters
)

def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
19 changes: 10 additions & 9 deletions domainlab/models/model_hduva.py
@@ -66,10 +66,10 @@ def hyper_update(self, epoch, fun_scheduler):
dict_rst = fun_scheduler(
epoch
) # the __call__ function of hyper-para-scheduler object
self.beta_d = dict_rst["beta_d"]
self.beta_y = dict_rst["beta_y"]
self.beta_x = dict_rst["beta_x"]
self.beta_t = dict_rst["beta_t"]
self.beta_d = dict_rst[self.name + "_beta_d"]
self.beta_y = dict_rst[self.name + "_beta_y"]
self.beta_x = dict_rst[self.name + "_beta_x"]
self.beta_t = dict_rst[self.name + "_beta_t"]

def hyper_init(self, functor_scheduler):
"""hyper_init.
@@ -78,12 +78,13 @@ def hyper_init(self, functor_scheduler):
# calling the constructor of the hyper-parameter-scheduler class, so that this scheduler
# class build a dictionary {"beta_d":self.beta_d, "beta_y":self.beta_y}
# constructor signature is def __init__(self, **kwargs):
parameters = {}
parameters[self.name + "_beta_d"] = self.beta_d
parameters[self.name + "_beta_y"] = self.beta_y
parameters[self.name + "_beta_x"] = self.beta_x
parameters[self.name + "_beta_t"] = self.beta_t
return functor_scheduler(
trainer=None,
beta_d=self.beta_d,
beta_y=self.beta_y,
beta_x=self.beta_x,
beta_t=self.beta_t,
trainer=None, **parameters
)

@store_args
18 changes: 18 additions & 0 deletions domainlab/utils/hyperparameter_retrieval.py
@@ -0,0 +1,18 @@
"""
retrieval for hyperparameters
"""

def get_gamma_reg(args, component_name):
"""
Retrieve either the shared gamma regularization weight or the component-specific one for the given model/trainer name
"""
gamma_reg = args.gamma_reg
if isinstance(gamma_reg, dict):
if component_name in gamma_reg:
return gamma_reg[component_name]
if 'default' in gamma_reg:
return gamma_reg['default']
raise ValueError("""If a gamma_reg dict is specified,
but no value set for every model and trainer,
a default value must be specified.""")
return gamma_reg # Return the single value if it's not a dictionary
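A usage sketch (not part of the commit) of the retrieval helper; `argparse.Namespace` stands in here for the parsed command-line arguments.

```python
from argparse import Namespace

from domainlab.utils.hyperparameter_retrieval import get_gamma_reg

args_scalar = Namespace(gamma_reg=0.1)
args_dict = Namespace(gamma_reg={"default": 0.1, "dann": 0.05})

print(get_gamma_reg(args_scalar, "dial"))  # 0.1  -> single value shared by every component
print(get_gamma_reg(args_dict, "dann"))    # 0.05 -> component-specific entry
print(get_gamma_reg(args_dict, "jigen"))   # 0.1  -> falls back to the 'default' entry
```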
5 changes: 4 additions & 1 deletion examples/conf/vlcs_diva_mldg_dial.yaml
@@ -5,7 +5,10 @@ val_threshold: 0.8 # threashold before which training does not
model: dann_diva # combine model DANN with DIVA
epos: 1 # number of epochs
trainer: mldg_dial # combine trainer MLDG and DIAL
gamma_reg: 1.0 # hyperparameter of DANN
gamma_reg:
  default: 1.0
  dann: 1.5
# in this case, mldg and dial get the default gamma_reg value 1.0
gamma_y: 700000.0 # hyperparameter of diva
gamma_d: 100000.0 # hyperparameter of diva
npath: examples/nets/resnet.py # neural network for class classification
2 changes: 1 addition & 1 deletion run_benchmark_slurm.sh
@@ -32,4 +32,4 @@ echo "Number of GPUs: $NUMBER_GPUS"
echo "Results will be stored in: $results_dir"

# Helmholtz
snakemake --profile "examples/yaml/slurm" --config yaml_file="$CONFIGFILE" --keep-going --keep-incomplete --notemp --cores 3 -s "domainlab/exp_protocol/benchmark.smk" --configfile "$CONFIGFILE" --config output_dir="$results_dir" 2>&1 | tee "$logfile"
snakemake --profile "examples/yaml/slurm" --config yaml_file="$CONFIGFILE" --keep-going --keep-incomplete --notemp --cores 3 -s "domainlab/exp_protocol/benchmark.smk" --configfile "$CONFIGFILE" --config output_dir="$results_dir" 2>&1 | tee "$logfile"