master: Fixed hyperparameter collisions #806

Merged
merged 33 commits into master from gamma_reg_collision on Jul 2, 2024
Changes from all commits
33 commits
7baea73
fixed issue where hyperparameters are colliding
MatteoWohlrapp Apr 15, 2024
07086a2
changed id to name for model identification
MatteoWohlrapp Apr 16, 2024
871374a
Solved gamma_reg naming collision by introducing functionality to pas…
MatteoWohlrapp Apr 22, 2024
b680380
Merge branch 'master' into gamma_reg_collision
MatteoWohlrapp Apr 22, 2024
4be4a02
fixed codacy
MatteoWohlrapp Apr 23, 2024
394d23a
fixed codacy
MatteoWohlrapp Apr 23, 2024
f720bf7
Merge branch 'master' into gamma_reg_collision
smilesun Apr 25, 2024
cc650ad
Removed diva from tests and yaml for gamma hyperparam
MatteoWohlrapp Apr 30, 2024
b64fe6b
Merge branch 'master' into gamma_reg_collision
smilesun May 3, 2024
8fef65d
Merge branch 'master' into gamma_reg_collision
smilesun May 6, 2024
febe876
Increased test coverage
MatteoWohlrapp May 7, 2024
fd36056
fixed codacity
MatteoWohlrapp May 7, 2024
6520834
fixed codacity
MatteoWohlrapp May 7, 2024
c5e3e6a
fixed codacity
MatteoWohlrapp May 7, 2024
884682b
Added quotes around variables
MatteoWohlrapp May 7, 2024
57ceef7
Merge branch 'master' into gamma_reg_collision
smilesun May 7, 2024
91ce3c5
Merge branch 'master' into gamma_reg_collision
smilesun May 10, 2024
aac6510
Merge branch 'master' into gamma_reg_collision
smilesun May 10, 2024
0a1b0b9
Merge branch 'master' into gamma_reg_collision
smilesun May 10, 2024
be59cf2
Merge branch 'master' into gamma_reg_collision
smilesun May 13, 2024
32360dd
Merge branch 'master' into gamma_reg_collision
smilesun May 15, 2024
5d2fb82
Fixed codacity
Jun 11, 2024
570e8a8
fixed codacity
Jun 11, 2024
1198eeb
Corrected doc for gamma_reg, added docstring to call method in argpar…
Jul 2, 2024
bc060ba
Merge branch 'master' into gamma_reg_collision
smilesun Jul 2, 2024
97d9320
Update arg_parser.py, update documentation , remove diva gamma_reg
smilesun Jul 2, 2024
05d228b
Update doc_usage_cmd.md, remove diva gamma_reg in doc
smilesun Jul 2, 2024
5c0f6b1
Update doc_usage_cmd.md
smilesun Jul 2, 2024
5ff5fd3
merge commit
MatteoWohlrapp Jul 2, 2024
31ce1ec
fixed argparser syntax, renamed argument
MatteoWohlrapp Jul 2, 2024
6392ac5
Update vlcs_diva_mldg_dial.yaml
smilesun Jul 2, 2024
0d62919
Update test_hyperparameter_retrieval.py, comments in unit test
smilesun Jul 2, 2024
bdf84c2
style
smilesun Jul 2, 2024
29 changes: 28 additions & 1 deletion docs/doc_usage_cmd.md
@@ -21,7 +21,34 @@ To run DomainLab, the minimum necessary parameters are:
### Advanced Configuration

- **Learning Rate (`--lr`):** Set the training learning rate.
- **Regularization (`--gamma_reg`):** Weight of regularization loss.
- **Regularization (`--gamma_reg`):** Sets the weight of the regularization loss. This parameter can be configured either as a single value applied to all models and trainers, or as a dictionary that assigns different weights to individual models and trainers.

- **Command Line Usage:**
- For a single value: `python script.py --gamma_reg=0.1`
- For multiple values: `python script.py --gamma_reg='default=0.1,dann=0.05,jigen=0.2'`

- **YAML Configuration:**
- For a single value:

```yaml
gamma_reg: 0.1
```

- For different values:

```yaml
gamma_reg:
  dann: 0.05
  dial: 0.2
  default: 0.1 # value for every other instance
```
`gamma_reg` is available for the trainers, as well as for the dann and jigen models.

- **Early Stopping (`--es`):** Steps for early stopping.
- **Random Seed (`--seed`):** Seed for reproducibility.
- **CUDA Options (`--nocu`, `--device`):** Configure CUDA usage and device settings.
3 changes: 2 additions & 1 deletion domainlab/algos/builder_dann.py
@@ -13,6 +13,7 @@
from domainlab.compos.zoo_nn import FeatExtractNNBuilderChainNodeGetter
from domainlab.models.model_dann import mk_dann
from domainlab.utils.utils_cuda import get_device
from domainlab.utils.hyperparameter_retrieval import get_gamma_reg


class NodeAlgoBuilderDANN(NodeAlgoBuilder):
@@ -55,7 +56,7 @@ def init_business(self, exp):
model = mk_dann(list_str_y=task.list_str_y,
net_classifier=net_classifier)(
list_d_tr=task.list_domain_tr,
alpha=args.gamma_reg,
alpha=get_gamma_reg(args, 'dann'),
net_encoder=net_encoder,
net_discriminator=net_discriminator,
builder=self)
3 changes: 2 additions & 1 deletion domainlab/algos/builder_jigen1.py
@@ -15,6 +15,7 @@
from domainlab.dsets.utils_wrapdset_patches import WrapDsetPatches
from domainlab.models.model_jigen import mk_jigen
from domainlab.utils.utils_cuda import get_device
from domainlab.utils.hyperparameter_retrieval import get_gamma_reg


class NodeAlgoBuilderJiGen(NodeAlgoBuilder):
@@ -56,7 +57,7 @@ def init_business(self, exp):
model = mk_jigen(
list_str_y=task.list_str_y,
net_classifier=net_classifier)(
coeff_reg=args.gamma_reg,
coeff_reg=get_gamma_reg(args, 'jigen'),
net_encoder=net_encoder,
net_classifier_permutation=net_classifier_perm,
n_perm=args.nperm,
8 changes: 8 additions & 0 deletions domainlab/algos/trainers/a_trainer.py
@@ -258,3 +258,11 @@ def dset_decoration_args_algo(self, args, ddset):
if self._decoratee is not None:
return self._decoratee.dset_decoration_args_algo(args, ddset)
return ddset

    def print_parameters(self):
"""
Function to print all parameters of the object.
Can be used to print the parameters of any child class
"""
params = vars(self)
print(f"Parameters of {type(self).__name__}: {params}")
3 changes: 2 additions & 1 deletion domainlab/algos/trainers/train_dial.py
@@ -5,6 +5,7 @@
from torch.autograd import Variable

from domainlab.algos.trainers.train_basic import TrainerBasic
from domainlab.utils.hyperparameter_retrieval import get_gamma_reg


class TrainerDIAL(TrainerBasic):
@@ -49,4 +50,4 @@ def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
tensor_x_adv = self.gen_adversarial(self.device, tensor_x, tensor_y)
tensor_x_batch_adv_no_grad = Variable(tensor_x_adv, requires_grad=False)
loss_dial = self.model.cal_task_loss(tensor_x_batch_adv_no_grad, tensor_y)
return [loss_dial], [self.aconf.gamma_reg]
return [loss_dial], [get_gamma_reg(self.aconf, self.name)]
3 changes: 2 additions & 1 deletion domainlab/algos/trainers/train_fishr.py
@@ -13,6 +13,7 @@
backpack = None

from domainlab.algos.trainers.train_basic import TrainerBasic
from domainlab.utils.hyperparameter_retrieval import get_gamma_reg


class TrainerFishr(TrainerBasic):
@@ -39,7 +40,7 @@ def tr_epoch(self, epoch):
dict_layerwise_var_var_grads_sum = \
{key: val.sum() for key, val in dict_layerwise_var_var_grads.items()}
loss_fishr = sum(dict_layerwise_var_var_grads_sum.values())
loss = sum(list_loss_erm) + self.aconf.gamma_reg * loss_fishr
loss = sum(list_loss_erm) + get_gamma_reg(self.aconf, self.name) * loss_fishr
loss.backward()
self.optimizer.step()
self.epo_loss_tr += loss.detach().item()
3 changes: 2 additions & 1 deletion domainlab/algos/trainers/train_matchdg.py
@@ -13,6 +13,7 @@
)
from domainlab.tasks.utils_task_dset import DsetIndDecorator4XYD
from domainlab.utils.logger import Logger
from domainlab.utils.hyperparameter_retrieval import get_gamma_reg


class TrainerMatchDG(AbstractTrainer):
@@ -36,7 +37,7 @@ def init_business(
self.base_domain_size = get_base_domain_size4match_dg(self.task)
self.epo_loss_tr = 0
self.flag_erm = flag_erm
self.lambda_ctr = self.aconf.gamma_reg
self.lambda_ctr = get_gamma_reg(aconf, self.name)
self.mk_match_tensor(epoch=0)
self.flag_match_tensor_sweep_over = False
self.tuple_tensor_ref_domain2each_y = None
3 changes: 2 additions & 1 deletion domainlab/algos/trainers/train_mldg.py
@@ -10,6 +10,7 @@
from domainlab.algos.trainers.train_basic import TrainerBasic
from domainlab.tasks.utils_task import mk_loader
from domainlab.tasks.utils_task_dset import DsetZip
from domainlab.utils.hyperparameter_retrieval import get_gamma_reg


class TrainerMLDG(AbstractTrainer):
@@ -108,7 +109,7 @@ def tr_epoch(self, epoch):
loss = (
loss_source_task.sum()
+ source_reg_tr.sum()
+ self.aconf.gamma_reg * loss_look_forward.sum()
+ get_gamma_reg(self.aconf, self.name) * loss_look_forward.sum()
)
#
loss.backward()
39 changes: 38 additions & 1 deletion domainlab/arg_parser.py
@@ -12,6 +12,37 @@
from domainlab.models.args_vae import add_args2parser_vae
from domainlab.utils.logger import Logger

class ParseValuesOrKeyValuePairs(argparse.Action):
    """Argparse action that accepts either a single float value or comma-separated key=value pairs"""

    def __call__(self, parser: argparse.ArgumentParser,
                 namespace: argparse.Namespace, values: str, option_string: str = None):
        """
        Handle parsing of key=value pairs, or a single float value instead

        Args:
            parser (argparse.ArgumentParser): The ArgumentParser object.
            namespace (argparse.Namespace): The namespace object to store parsed values.
            values (str): The string containing key=value pairs or a single float value.
            option_string (str, optional): The option string that triggered this action (unused).

        Raises:
            ValueError: If the values cannot be parsed to float.
        """
        if "=" in values:
            my_dict = {}
            for kv in values.split(","):
                k, v = kv.split("=")
                try:
                    my_dict[k.strip()] = float(v.strip())
                except ValueError:
                    raise ValueError(f"Invalid value in key-value pair: '{kv}', must be float")
            setattr(namespace, self.dest, my_dict)
        else:
            try:
                setattr(namespace, self.dest, float(values))
            except ValueError:
                raise ValueError(f"Invalid value for {self.dest}: '{values}', must be float")

def mk_parser_main():
"""
@@ -31,7 +62,13 @@ def mk_parser_main():
parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")

parser.add_argument(
"--gamma_reg", type=float, default=0.1, help="weight of regularization loss"
"--gamma_reg",
default=0.1,
help="weight of regularization loss in the form of $$\ell(\cdot) + \mu \times R(\cdot)$$ \
can specify per model as 'default=3.0, dann=1.0,jigen=2.0', where default refer to gamma for trainer \
note diva is implemented $$\ell(\cdot) + \mu \times R(\cdot)$$ \
so diva does not have gamma_reg",
action=ParseValuesOrKeyValuePairs
)

parser.add_argument("--es", type=int, default=1, help="early stop steps")
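
To illustrate the new argument handling, here is a minimal sketch that feeds both accepted formats through a stand-alone parser; it assumes `ParseValuesOrKeyValuePairs` is importable from `domainlab.arg_parser` as defined in this diff:

```python
# Sketch only; assumes ParseValuesOrKeyValuePairs is exported by domainlab.arg_parser.
import argparse

from domainlab.arg_parser import ParseValuesOrKeyValuePairs

parser = argparse.ArgumentParser()
parser.add_argument("--gamma_reg", default=0.1, action=ParseValuesOrKeyValuePairs)

# key=value pairs are parsed into a dict of floats
args = parser.parse_args(["--gamma_reg", "default=0.1,dann=0.05,jigen=0.2"])
print(args.gamma_reg)   # {'default': 0.1, 'dann': 0.05, 'jigen': 0.2}

# a plain value stays a single float shared by all models and trainers
args = parser.parse_args(["--gamma_reg", "0.3"])
print(args.gamma_reg)   # 0.3
```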
32 changes: 32 additions & 0 deletions domainlab/models/a_model.py
@@ -178,3 +178,35 @@ def dset_decoration_args_algo(self, args, ddset):
if self._decoratee is not None:
return self._decoratee.dset_decoration_args_algo(args, ddset)
return ddset

    @property
    def p_na_prefix(self):
        """
        common prefix for Models
        """
        return "Model"

    @property
    def name(self):
        """
        get the name of the model
        """
        na_prefix = self.p_na_prefix
        len_prefix = len(na_prefix)
        na_class = type(self).__name__
        if na_class[:len_prefix] != na_prefix:
            raise RuntimeError(
                f"model class name must start with '{na_prefix}', "
                f"but the current class is named: {na_class}")
        return type(self).__name__[len_prefix:].lower()

    def print_parameters(self):
        """
        Function to print all parameters of the object.
        Can be used to print the parameters of every child class.
        """
        params = vars(self)
        print(f"Parameters of {type(self).__name__}: {params}")
6 changes: 4 additions & 2 deletions domainlab/models/model_dann.py
@@ -85,13 +85,15 @@ def hyper_update(self, epoch, fun_scheduler):
dict_rst = fun_scheduler(
epoch
) # the __call__ method of hyperparameter scheduler
self.alpha = dict_rst["alpha"]
self.alpha = dict_rst[self.name + "_alpha"]

def hyper_init(self, functor_scheduler):
"""hyper_init.
:param functor_scheduler:
"""
return functor_scheduler(trainer=None, alpha=self.alpha)
parameters = {}
parameters[self.name + "_alpha"] = self.alpha
return functor_scheduler(trainer=None, **parameters)

def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others):
_ = others
12 changes: 8 additions & 4 deletions domainlab/models/model_diva.py
@@ -98,18 +98,22 @@ def hyper_update(self, epoch, fun_scheduler):
:param fun_scheduler:
"""
dict_rst = fun_scheduler(epoch)
self.beta_d = dict_rst["beta_d"]
self.beta_y = dict_rst["beta_y"]
self.beta_x = dict_rst["beta_x"]
self.beta_d = dict_rst[self.name + "_beta_d"]
self.beta_y = dict_rst[self.name + "_beta_y"]
self.beta_x = dict_rst[self.name + "_beta_x"]

def hyper_init(self, functor_scheduler):
"""
initiate a scheduler object via class name and things inside this model

:param functor_scheduler: the class name of the scheduler
"""
parameters = {}
parameters[self.name + "_beta_d"] = self.beta_d
parameters[self.name + "_beta_y"] = self.beta_y
parameters[self.name + "_beta_x"] = self.beta_x
return functor_scheduler(
trainer=None, beta_d=self.beta_d, beta_y=self.beta_y, beta_x=self.beta_x
trainer=None, **parameters
)

def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
19 changes: 10 additions & 9 deletions domainlab/models/model_hduva.py
@@ -66,10 +66,10 @@ def hyper_update(self, epoch, fun_scheduler):
dict_rst = fun_scheduler(
epoch
) # the __call__ function of hyper-para-scheduler object
self.beta_d = dict_rst["beta_d"]
self.beta_y = dict_rst["beta_y"]
self.beta_x = dict_rst["beta_x"]
self.beta_t = dict_rst["beta_t"]
self.beta_d = dict_rst[self.name + "_beta_d"]
self.beta_y = dict_rst[self.name + "_beta_y"]
self.beta_x = dict_rst[self.name + "_beta_x"]
self.beta_t = dict_rst[self.name + "_beta_t"]

def hyper_init(self, functor_scheduler):
"""hyper_init.
@@ -78,12 +78,13 @@ def hyper_init(self, functor_scheduler):
# calling the constructor of the hyper-parameter-scheduler class, so that this scheduler
# class build a dictionary {"beta_d":self.beta_d, "beta_y":self.beta_y}
# constructor signature is def __init__(self, **kwargs):
parameters = {}
parameters[self.name + "_beta_d"] = self.beta_d
parameters[self.name + "_beta_y"] = self.beta_y
parameters[self.name + "_beta_x"] = self.beta_x
parameters[self.name + "_beta_t"] = self.beta_t
return functor_scheduler(
trainer=None,
beta_d=self.beta_d,
beta_y=self.beta_y,
beta_x=self.beta_x,
beta_t=self.beta_t,
trainer=None, **parameters
)

@store_args
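
The per-model prefix is what prevents the collision this PR targets: diva and hduva share hyperparameter names (`beta_d`, `beta_y`, `beta_x`), so without the prefix their entries would overwrite each other in the scheduler's dictionary. A hedged sketch with made-up values:

```python
# Hypothetical illustration of the kwargs collected by the hyperparameter scheduler
# when models sharing hyperparameter names are combined.
scheduler_kwargs = {}
scheduler_kwargs.update({"diva_beta_d": 1.0, "diva_beta_y": 1.0,
                         "diva_beta_x": 1.0})                        # from ModelDiva.hyper_init
scheduler_kwargs.update({"hduva_beta_d": 2.0, "hduva_beta_y": 2.0,
                         "hduva_beta_x": 2.0, "hduva_beta_t": 2.0})  # from ModelHduva.hyper_init
print(len(scheduler_kwargs))  # 7 -- all entries survive
# With the old unprefixed keys ("beta_d", "beta_y", "beta_x"), the second update()
# would silently overwrite the first model's values.
```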
18 changes: 18 additions & 0 deletions domainlab/utils/hyperparameter_retrieval.py
@@ -0,0 +1,18 @@
"""
retrieval for hyperparameters
"""

def get_gamma_reg(args, component_name):
"""
Retrieves either a shared gamma regularization, or individual ones for each specified object
"""
gamma_reg = args.gamma_reg
if isinstance(gamma_reg, dict):
if component_name in gamma_reg:
return gamma_reg[component_name]
if 'default' in gamma_reg:
return gamma_reg['default']
raise ValueError("""If a gamma_reg dict is specified,
but no value set for every model and trainer,
a default value must be specified.""")
return gamma_reg # Return the single value if it's not a dictionary
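
A short usage sketch of the retrieval helper; the `SimpleNamespace` objects stand in for the parsed argument namespace:

```python
# Sketch: how get_gamma_reg resolves per-component values and the shared default.
from types import SimpleNamespace

from domainlab.utils.hyperparameter_retrieval import get_gamma_reg

args = SimpleNamespace(gamma_reg={"default": 0.1, "dann": 0.05})
print(get_gamma_reg(args, "dann"))   # 0.05 -> explicit entry for the dann model
print(get_gamma_reg(args, "mldg"))   # 0.1  -> falls back to the 'default' entry

args = SimpleNamespace(gamma_reg=0.3)
print(get_gamma_reg(args, "jigen"))  # 0.3  -> single value shared by all components
```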
5 changes: 4 additions & 1 deletion examples/conf/vlcs_diva_mldg_dial.yaml
@@ -5,7 +5,10 @@ val_threshold: 0.8 # threshold before which training does not
model: dann_diva # combine model DANN with DIVA
epos: 1 # number of epochs
trainer: mldg_dial # combine trainer MLDG and DIAL
gamma_reg: 1.0 # hyperparameter of DANN
gamma_reg:
  default: 1.0
  dann: 1.5
# in this case, mldg and dial get the default gamma_reg value 1.0
gamma_y: 700000.0 # hyperparameter of diva
gamma_d: 100000.0 # hyperparameter of diva
npath: examples/nets/resnet.py # neural network for class classification
2 changes: 1 addition & 1 deletion run_benchmark_slurm.sh
@@ -32,4 +32,4 @@ echo "Number of GPUs: $NUMBER_GPUS"
echo "Results will be stored in: $results_dir"

# Helmholtz
snakemake --profile "examples/yaml/slurm" --config yaml_file="$CONFIGFILE" --keep-going --keep-incomplete --notemp --cores 3 -s "domainlab/exp_protocol/benchmark.smk" --configfile "$CONFIGFILE" --config output_dir="$results_dir" 2>&1 | tee "$logfile"
snakemake --profile "examples/yaml/slurm" --config yaml_file="$CONFIGFILE" --keep-going --keep-incomplete --notemp --cores 3 -s "domainlab/exp_protocol/benchmark.smk" --configfile "$CONFIGFILE" --config output_dir="$results_dir" 2>&1 | tee "$logfile"