master: Fixed hyperparameter collisions #806

Merged Jul 2, 2024 · 33 commits

Changes from 8 commits

Commits (33):
7baea73
fixed issue where hyperparameters are colliding
MatteoWohlrapp Apr 15, 2024
07086a2
changed id to name for model identification
MatteoWohlrapp Apr 16, 2024
871374a
Solved gamma_reg naming collision by introducing functionality to pas…
MatteoWohlrapp Apr 22, 2024
b680380
Merge branch 'master' into gamma_reg_collision
MatteoWohlrapp Apr 22, 2024
4be4a02
fixed codacy
MatteoWohlrapp Apr 23, 2024
394d23a
fixed codacy
MatteoWohlrapp Apr 23, 2024
f720bf7
Merge branch 'master' into gamma_reg_collision
smilesun Apr 25, 2024
cc650ad
Removed diva from tests and yaml for gamma hyperparam
MatteoWohlrapp Apr 30, 2024
b64fe6b
Merge branch 'master' into gamma_reg_collision
smilesun May 3, 2024
8fef65d
Merge branch 'master' into gamma_reg_collision
smilesun May 6, 2024
febe876
Increased test coverage
MatteoWohlrapp May 7, 2024
fd36056
fixed codacity
MatteoWohlrapp May 7, 2024
6520834
fixed codacity
MatteoWohlrapp May 7, 2024
c5e3e6a
fixed codacity
MatteoWohlrapp May 7, 2024
884682b
Added quotes around variables
MatteoWohlrapp May 7, 2024
57ceef7
Merge branch 'master' into gamma_reg_collision
smilesun May 7, 2024
91ce3c5
Merge branch 'master' into gamma_reg_collision
smilesun May 10, 2024
aac6510
Merge branch 'master' into gamma_reg_collision
smilesun May 10, 2024
0a1b0b9
Merge branch 'master' into gamma_reg_collision
smilesun May 10, 2024
be59cf2
Merge branch 'master' into gamma_reg_collision
smilesun May 13, 2024
32360dd
Merge branch 'master' into gamma_reg_collision
smilesun May 15, 2024
5d2fb82
Fixed codacity
Jun 11, 2024
570e8a8
fixed codacity
Jun 11, 2024
1198eeb
Corrected doc for gamma_reg, added docstring to call method in argpar…
Jul 2, 2024
bc060ba
Merge branch 'master' into gamma_reg_collision
smilesun Jul 2, 2024
97d9320
Update arg_parser.py, update documentation , remove diva gamma_reg
smilesun Jul 2, 2024
05d228b
Update doc_usage_cmd.md, remove diva gamma_reg in doc
smilesun Jul 2, 2024
5c0f6b1
Update doc_usage_cmd.md
smilesun Jul 2, 2024
5ff5fd3
merge commit
MatteoWohlrapp Jul 2, 2024
31ce1ec
fixed argparser syntax, renamed argument
MatteoWohlrapp Jul 2, 2024
6392ac5
Update vlcs_diva_mldg_dial.yaml
smilesun Jul 2, 2024
0d62919
Update test_hyperparameter_retrieval.py, comments in unit test
smilesun Jul 2, 2024
bdf84c2
style
smilesun Jul 2, 2024
20 changes: 18 additions & 2 deletions docs/doc_usage_cmd.md
@@ -21,8 +21,24 @@ To run DomainLab, the minimum necessary parameters are:
### Advanced Configuration

- **Learning Rate (`--lr`):** Set the training learning rate.
- **Regularization (`--gamma_reg`):** Weight of regularization loss.
- **Early Stopping (`--es`):** Steps for early stopping.
- **Regularization (`--gamma_reg`):** Sets the weight of the regularization loss. This parameter can be configured either as a single value applied to individual classes, or using a dictionary to specify different weights for different models and trainers.

  - **Command Line Usage:**
    - For a single value: `python script.py --gamma_reg=0.1`
    - For multiple values: `python script.py --gamma_reg='default=0.1,dann=0.05,diva=0.2'`
Collaborator:

Actually, due to the specific implementation of DIVA, the loss is not of the form $\ell(\cdot) + \mu R(\cdot)$; for DIVA it is $\mu \ell(\cdot) + R(\cdot)$, where $\mu$ is `gamma_reg` in the code.

Collaborator (Author):

But where is this specified for model DIVA? I can only find this initialization of DIVA, which does not take a regularization weight:

model = mk_diva(list_str_y=task.list_str_y)(
    node,
    zd_dim=args.zd_dim,
    zy_dim=args.zy_dim,
    zx_dim=args.zx_dim,
    list_d_tr=task.list_domain_tr,
    gamma_d=args.gamma_d,
    gamma_y=args.gamma_y,
    beta_x=args.beta_x,
    beta_y=args.beta_y,
    beta_d=args.beta_d,
)

It's here

Collaborator:

It is specified here:

and in

    def multiplier4task_loss(self):
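
To make the contrast in this thread concrete, here is a minimal sketch (placeholder functions, not DomainLab code) of the two loss compositions being discussed:

```python
# Placeholder functions, not DomainLab code: how gamma_reg (mu) enters the
# total loss in the two cases discussed in this thread.
def loss_common(task_loss, reg_loss, mu):
    """Most models/trainers: loss = task + mu * regularization."""
    return task_loss + mu * reg_loss


def loss_diva_style(task_loss, reg_loss, mu):
    """DIVA-style composition: mu scales the task loss instead."""
    return mu * task_loss + reg_loss
```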


  - **YAML Configuration:**
    - For a single value:
      ```yaml
      gamma_reg: 0.1
      ```
    - For different values:
      ```yaml
      gamma_reg:
        default: 0.1 # every other instance that is not listed below will get this value assigned
        dann: 0.05
        diva: 0.2
      ```
- **Early Stopping (`--es`):** Steps for early stopping.
- **Random Seed (`--seed`):** Seed for reproducibility.
- **CUDA Options (`--nocu`, `--device`):** Configure CUDA usage and device settings.
- **Generated Images (`--gen`):** Option to save generated images.
3 changes: 2 additions & 1 deletion domainlab/algos/builder_dann.py
@@ -13,6 +13,7 @@
from domainlab.compos.zoo_nn import FeatExtractNNBuilderChainNodeGetter
from domainlab.models.model_dann import mk_dann
from domainlab.utils.utils_cuda import get_device
from domainlab.utils.hyperparameter_retrieval import get_gamma_reg


class NodeAlgoBuilderDANN(NodeAlgoBuilder):
@@ -55,7 +56,7 @@ def init_business(self, exp):
model = mk_dann(list_str_y=task.list_str_y,
net_classifier=net_classifier)(
list_d_tr=task.list_domain_tr,
alpha=args.gamma_reg,
alpha=get_gamma_reg(args, 'dann'),
net_encoder=net_encoder,
net_discriminator=net_discriminator,
builder=self)
3 changes: 2 additions & 1 deletion domainlab/algos/builder_jigen1.py
@@ -15,6 +15,7 @@
from domainlab.dsets.utils_wrapdset_patches import WrapDsetPatches
from domainlab.models.model_jigen import mk_jigen
from domainlab.utils.utils_cuda import get_device
from domainlab.utils.hyperparameter_retrieval import get_gamma_reg


class NodeAlgoBuilderJiGen(NodeAlgoBuilder):
@@ -56,7 +57,7 @@ def init_business(self, exp):
model = mk_jigen(
list_str_y=task.list_str_y,
net_classifier=net_classifier)(
coeff_reg=args.gamma_reg,
coeff_reg=get_gamma_reg(args, 'jigen'),
net_encoder=net_encoder,
net_classifier_permutation=net_classifier_perm,
n_perm=args.nperm,
8 changes: 8 additions & 0 deletions domainlab/algos/trainers/a_trainer.py
@@ -256,3 +256,11 @@ def dset_decoration_args_algo(self, args, ddset):
        if self._decoratee is not None:
            return self._decoratee.dset_decoration_args_algo(args, ddset)
        return ddset

    def print_parameters(self):
        """
        Function to print all parameters of the object.
        Can be used to print the parameters of any child class
        """
        params = vars(self)
        print(f"Parameters of {type(self).__name__}: {params}")
3 changes: 2 additions & 1 deletion domainlab/algos/trainers/train_dial.py
@@ -5,6 +5,7 @@
from torch.autograd import Variable

from domainlab.algos.trainers.train_basic import TrainerBasic
from domainlab.utils.hyperparameter_retrieval import get_gamma_reg


class TrainerDIAL(TrainerBasic):
@@ -49,4 +50,4 @@ def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
tensor_x_adv = self.gen_adversarial(self.device, tensor_x, tensor_y)
tensor_x_batch_adv_no_grad = Variable(tensor_x_adv, requires_grad=False)
loss_dial = self.model.cal_task_loss(tensor_x_batch_adv_no_grad, tensor_y)
return [loss_dial], [self.aconf.gamma_reg]
return [loss_dial], [get_gamma_reg(self.aconf, self.name)]
3 changes: 2 additions & 1 deletion domainlab/algos/trainers/train_fishr.py
@@ -13,6 +13,7 @@
backpack = None

from domainlab.algos.trainers.train_basic import TrainerBasic
from domainlab.utils.hyperparameter_retrieval import get_gamma_reg


class TrainerFishr(TrainerBasic):
@@ -39,7 +40,7 @@ def tr_epoch(self, epoch):
dict_layerwise_var_var_grads_sum = \
{key: val.sum() for key, val in dict_layerwise_var_var_grads.items()}
loss_fishr = sum(dict_layerwise_var_var_grads_sum.values())
loss = sum(list_loss_erm) + self.aconf.gamma_reg * loss_fishr
loss = sum(list_loss_erm) + get_gamma_reg(self.aconf, self.name) * loss_fishr
loss.backward()
self.optimizer.step()
self.epo_loss_tr += loss.detach().item()
3 changes: 2 additions & 1 deletion domainlab/algos/trainers/train_matchdg.py
@@ -13,6 +13,7 @@
)
from domainlab.tasks.utils_task_dset import DsetIndDecorator4XYD
from domainlab.utils.logger import Logger
from domainlab.utils.hyperparameter_retrieval import get_gamma_reg


class TrainerMatchDG(AbstractTrainer):
@@ -36,7 +37,7 @@ def init_business(
self.base_domain_size = get_base_domain_size4match_dg(self.task)
self.epo_loss_tr = 0
self.flag_erm = flag_erm
self.lambda_ctr = self.aconf.gamma_reg
self.lambda_ctr = get_gamma_reg(aconf, self.name)
self.mk_match_tensor(epoch=0)
self.flag_match_tensor_sweep_over = False
self.tuple_tensor_ref_domain2each_y = None
3 changes: 2 additions & 1 deletion domainlab/algos/trainers/train_mldg.py
@@ -10,6 +10,7 @@
from domainlab.algos.trainers.train_basic import TrainerBasic
from domainlab.tasks.utils_task import mk_loader
from domainlab.tasks.utils_task_dset import DsetZip
from domainlab.utils.hyperparameter_retrieval import get_gamma_reg


class TrainerMLDG(AbstractTrainer):
@@ -108,7 +109,7 @@ def tr_epoch(self, epoch):
loss = (
loss_source_task.sum()
+ source_reg_tr.sum()
+ self.aconf.gamma_reg * loss_look_forward.sum()
+ get_gamma_reg(self.aconf, self.name) * loss_look_forward.sum()
)
#
loss.backward()
20 changes: 19 additions & 1 deletion domainlab/arg_parser.py
@@ -12,6 +12,21 @@
from domainlab.models.args_vae import add_args2parser_vae
from domainlab.utils.logger import Logger

class StoreDictKeyPair(argparse.Action):
    """Class used for arg parsing where values are provided in a key value format"""

    def __call__(self, parser, namespace, values, option_string=None):
        try:
            if "=" in values:
                my_dict = {}
                for kv in values.split(","):
                    k, v = kv.split("=")
                    my_dict[k.strip()] = float(v.strip())  # Assuming values are floats
                setattr(namespace, self.dest, my_dict)
            else:
                setattr(namespace, self.dest, float(values))  # Single float value
        except ValueError:
            raise argparse.ArgumentError(self, f"Invalid value for {self.dest}: {values}")

def mk_parser_main():
"""
@@ -31,7 +46,10 @@ def mk_parser_main():
parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")

parser.add_argument(
"--gamma_reg", type=float, default=0.1, help="weight of regularization loss"
"--gamma_reg",
default=0.1,
help="weight of regularization loss, can specify per model as 'dann=1.0,diva=2.0'",
action=StoreDictKeyPair
)

parser.add_argument("--es", type=int, default=1, help="early stop steps")
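
A minimal sketch of how the new `--gamma_reg` action behaves from the command line (it mirrors the unit tests added in this PR and assumes DomainLab is importable):

```python
from domainlab.arg_parser import mk_parser_main

parser = mk_parser_main()

# Single shared value: stored as a plain float.
args = parser.parse_args(["--gamma_reg", "0.1"])
print(args.gamma_reg)  # 0.1

# Per-model/trainer values: stored as a dict of floats.
args = parser.parse_args(["--gamma_reg", "default=0.1,dann=0.05,diva=0.2"])
print(args.gamma_reg)  # {'default': 0.1, 'dann': 0.05, 'diva': 0.2}
```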
32 changes: 32 additions & 0 deletions domainlab/models/a_model.py
@@ -178,3 +178,35 @@ def dset_decoration_args_algo(self, args, ddset):
        if self._decoratee is not None:
            return self._decoratee.dset_decoration_args_algo(args, ddset)
        return ddset

    @property
    def p_na_prefix(self):
        """
        common prefix for Models
        """
        return "Model"

    @property
    def name(self):
        """
        get the name of the algorithm
        """
        na_prefix = self.p_na_prefix
        len_prefix = len(na_prefix)
        na_class = type(self).__name__
        if na_class[:len_prefix] != na_prefix:
            raise RuntimeError(
                "Model builder node class must start with ",
                na_prefix,
                "the current class is named: ",
                na_class,
            )
        return type(self).__name__[len_prefix:].lower()

    def print_parameters(self):
        """
        Function to print all parameters of the object.
        Can be used to print the parameters of every child class.
        """
        params = vars(self)
        print(f"Parameters of {type(self).__name__}: {params}")
6 changes: 4 additions & 2 deletions domainlab/models/model_dann.py
@@ -85,13 +85,15 @@ def hyper_update(self, epoch, fun_scheduler):
dict_rst = fun_scheduler(
epoch
) # the __call__ method of hyperparameter scheduler
self.alpha = dict_rst["alpha"]
self.alpha = dict_rst[self.name + "_alpha"]

def hyper_init(self, functor_scheduler):
"""hyper_init.
:param functor_scheduler:
"""
return functor_scheduler(trainer=None, alpha=self.alpha)
parameters = {}
parameters[self.name + "_alpha"] = self.alpha
return functor_scheduler(trainer=None, **parameters)

def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others):
_ = others
12 changes: 8 additions & 4 deletions domainlab/models/model_diva.py
@@ -98,18 +98,22 @@ def hyper_update(self, epoch, fun_scheduler):
:param fun_scheduler:
"""
dict_rst = fun_scheduler(epoch)
self.beta_d = dict_rst["beta_d"]
self.beta_y = dict_rst["beta_y"]
self.beta_x = dict_rst["beta_x"]
self.beta_d = dict_rst[self.name + "_beta_d"]
self.beta_y = dict_rst[self.name + "_beta_y"]
self.beta_x = dict_rst[self.name + "_beta_x"]

def hyper_init(self, functor_scheduler):
"""
initiate a scheduler object via class name and things inside this model

:param functor_scheduler: the class name of the scheduler
"""
parameters = {}
parameters[self.name + "_beta_d"] = self.beta_d
parameters[self.name + "_beta_y"] = self.beta_y
parameters[self.name + "_beta_x"] = self.beta_x
return functor_scheduler(
trainer=None, beta_d=self.beta_d, beta_y=self.beta_y, beta_x=self.beta_x
trainer=None, **parameters
)

def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
19 changes: 10 additions & 9 deletions domainlab/models/model_hduva.py
@@ -66,10 +66,10 @@ def hyper_update(self, epoch, fun_scheduler):
dict_rst = fun_scheduler(
epoch
) # the __call__ function of hyper-para-scheduler object
self.beta_d = dict_rst["beta_d"]
self.beta_y = dict_rst["beta_y"]
self.beta_x = dict_rst["beta_x"]
self.beta_t = dict_rst["beta_t"]
self.beta_d = dict_rst[self.name + "_beta_d"]
self.beta_y = dict_rst[self.name + "_beta_y"]
self.beta_x = dict_rst[self.name + "_beta_x"]
self.beta_t = dict_rst[self.name + "_beta_t"]

def hyper_init(self, functor_scheduler):
"""hyper_init.
@@ -78,12 +78,13 @@ def hyper_init(self, functor_scheduler):
# calling the constructor of the hyper-parameter-scheduler class, so that this scheduler
# class build a dictionary {"beta_d":self.beta_d, "beta_y":self.beta_y}
# constructor signature is def __init__(self, **kwargs):
parameters = {}
parameters[self.name + "_beta_d"] = self.beta_d
parameters[self.name + "_beta_y"] = self.beta_y
parameters[self.name + "_beta_x"] = self.beta_x
parameters[self.name + "_beta_t"] = self.beta_t
return functor_scheduler(
trainer=None,
beta_d=self.beta_d,
beta_y=self.beta_y,
beta_x=self.beta_x,
beta_t=self.beta_t,
trainer=None, **parameters
)

@store_args
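
Since the three model diffs above all switch to name-prefixed scheduler keys, here is a hedged sketch (made-up values, not DomainLab API) of the dictionaries that result and why the prefix avoids collisions:

```python
# Hypothetical values for illustration; keys follow the new model-name prefix.
scheduler_params = {
    "diva_beta_d": 1.0, "diva_beta_y": 1.0, "diva_beta_x": 1.0,
    "hduva_beta_d": 2.0, "hduva_beta_y": 2.0, "hduva_beta_x": 2.0, "hduva_beta_t": 2.0,
}
# Each model's hyper_update reads back only its own entries, e.g.
# dict_rst[self.name + "_beta_d"], so combined models such as dann_diva
# no longer overwrite each other's scheduler values.
```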
18 changes: 18 additions & 0 deletions domainlab/utils/hyperparameter_retrieval.py
@@ -0,0 +1,18 @@
"""
retrieval for hyperparameters
"""

def get_gamma_reg(args, model_name):
smilesun marked this conversation as resolved.
Show resolved Hide resolved
"""
Retrieves either a shared gamma regularization, or individual ones for each specified object
"""
gamma_reg = args.gamma_reg
if isinstance(gamma_reg, dict):
if model_name in gamma_reg:
return gamma_reg[model_name]
if 'default' in gamma_reg:
return gamma_reg['default']
raise ValueError("""If a gamma_reg dict is specified,
but no value set for every model and trainer,
a default value must be specified.""")
return gamma_reg # Return the single value if it's not a dictionary
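
A short usage sketch of `get_gamma_reg` (an `argparse.Namespace` stands in for the parsed DomainLab args):

```python
from argparse import Namespace

from domainlab.utils.hyperparameter_retrieval import get_gamma_reg

args = Namespace(gamma_reg={"default": 0.1, "dann": 0.05})
print(get_gamma_reg(args, "dann"))  # 0.05 -> model-specific entry
print(get_gamma_reg(args, "mldg"))  # 0.1  -> falls back to 'default'

args = Namespace(gamma_reg=0.3)
print(get_gamma_reg(args, "diva"))  # 0.3  -> single shared value
```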
4 changes: 3 additions & 1 deletion examples/conf/vlcs_diva_mldg_dial.yaml
@@ -5,7 +5,9 @@ val_threshold: 0.8
model: dann_diva # combine model DANN with DIVA
epos: 1 # number of epochs
trainer: mldg_dial # combine trainer MLDG and DIAL
gamma_reg: 1.0 # hyperparameter of DANN
gamma_reg:
  default: 1.0
  dann: 1.5
Collaborator:

@MatteoWohlrapp , how about mldg?

Collaborator (Author):

It will get the default value assigned.

Collaborator:

Is this behavior tested somewhere? In which test? For instance, the gamma_reg value for mldg should now be 1.0. @MatteoWohlrapp

Collaborator:

I made a pull request to test this, still under work @MatteoWohlrapp

#847

gamma_y: 700000.0 # hyperparameter of diva
gamma_d: 100000.0 # hyperparameter of diva
npath: examples/nets/resnet.py # neural network for class classification
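
Following up on the mldg question above, a small Python sketch (values copied from the yaml above) of how the dict resolves per component under `get_gamma_reg`'s fallback rule:

```python
# Values copied from the yaml above; resolution mirrors get_gamma_reg.
gamma_reg = {"default": 1.0, "dann": 1.5}
for component in ("dann", "mldg", "dial"):
    print(component, gamma_reg.get(component, gamma_reg["default"]))
# dann 1.5, mldg 1.0, dial 1.0
```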
32 changes: 32 additions & 0 deletions tests/test_hyperparameter_retrieval.py
@@ -0,0 +1,32 @@
"""
unit test for hyperparameter parsing
"""
from domainlab.arg_parser import mk_parser_main
from domainlab.utils.hyperparameter_retrieval import get_gamma_reg

def test_store_dict_key_pair_single_value():
"""Test to parse a single gamma_reg parameter"""
parser = mk_parser_main()
args = parser.parse_args(['--gamma_reg', '0.5'])
assert args.gamma_reg == 0.5

def test_store_dict_key_pair_dict_value():
"""Test to parse a dict for the gamma_reg"""
parser = mk_parser_main()
args = parser.parse_args(['--gamma_reg', 'dann=1.0,jigen=2.0'])
assert args.gamma_reg == {'dann': 1.0, 'jigen': 2.0}

def test_get_gamma_reg_single_value():
"""Test to retrieve a single gamma_reg parameter which is applied to all objects"""
parser = mk_parser_main()
args = parser.parse_args(['--gamma_reg', '0.5'])
assert get_gamma_reg(args, 'dann') == 0.5

def test_get_gamma_reg_dict_value():
"""Test to retrieve a dict of gamma_reg parameters for different objects"""
parser = mk_parser_main()
args = parser.parse_args(['--gamma_reg', 'default=5.0,dann=1.0,jigen=2.0'])
print(args)
assert get_gamma_reg(args, 'dann') == 1.0
assert get_gamma_reg(args, 'jigen') == 2.0
assert get_gamma_reg(args, 'nonexistent') == 5.0