Commit

merge master
smilesun committed Jul 11, 2024
2 parents a40cf96 + 344936b commit 054485c
Showing 44 changed files with 582 additions and 1,700 deletions.
10 changes: 8 additions & 2 deletions .github/workflows/ci.yml
@@ -44,6 +44,12 @@ jobs:
- name: check if readme yaml works
run: rm -r zoutput && python main_out.py -c ./examples/conf/vlcs_diva_mldg_dial.yaml
- name: test if examples in markdown works
run: bash -x -v ci_run_examples.sh
run: bash -x -v scripts/ci_run_examples.sh
- name: test if benchmark works
run: pip install snakemake==7.32.0 && pip install pulp==2.7.0 && sed -i '1s/^/#!\/bin\/bash -x -v\n/' run_benchmark_standalone.sh && bash -x -v run_benchmark_standalone.sh examples/benchmark/demo_shared_hyper_grid.yaml && cat zoutput/benchmarks/mnist_benchmark_grid/hyperparameters.csv && cat zoutput/benchmarks/mnist_benchmark_grid/results.csv
run: |
pip install snakemake==7.32.0 && pip install pulp==2.7.0
echo "insert a shebang line (#!/bin/bash -x -v) at the beginning of the bash script"
sed -i '1s/^/#!\/bin\/bash -x -v\n/' run_benchmark_standalone.sh
bash -x -v run_benchmark_standalone.sh examples/benchmark/demo_shared_hyper_grid.yaml
cat zoutput/benchmarks/mnist_benchmark_grid*/hyperparameters.csv
cat zoutput/benchmarks/mnist_benchmark_grid*/results.csv
10 changes: 8 additions & 2 deletions .gitignore
@@ -1,7 +1,13 @@
.ropeproject
./zdpath
./zoutput
/zdpath
/zoutput
tests/__pycache__/
*.pyc
.vscode/
domainlab/zdata/pacs
/data/
/.snakemake/
/dist
/domainlab.egg-info
/runs
/slurm_errors.txt
4 changes: 2 additions & 2 deletions docs/doc_benchmark.md
@@ -74,10 +74,10 @@ hyperparameter sampling and pytorch.
The following script helps to find out which job has failed and its error message, so that you can navigate to the specific log file:
```cluster
bash ./sh_list_error.sh ./zoutput/slurm_logs
bash ./sh_list_error.sh ./zoutput/benchmarks/[output folder of the specified benchmark in the yaml file]/slurm_logs
```
#### Map between slurm job id and sampled hyperparameter index
Suppose the slurm job id is 14144163; one can find the corresponding log file in the `./zoutput/slurm_logs` folder via
Suppose the slurm job id is 14144163; one can find the corresponding log file in the `./zoutput/[output folder of the specified benchmark in the yaml file]/slurm_logs` folder via
`find . | grep -i "14144163"`

the results can be
29 changes: 28 additions & 1 deletion docs/doc_usage_cmd.md
@@ -21,7 +21,34 @@ To run DomainLab, the minimum necessary parameters are:
### Advanced Configuration

- **Learning Rate (`--lr`):** Set the training learning rate.
- **Regularization (`--gamma_reg`):** Weight of regularization loss.
- **Regularization (`--gamma_reg`):** Sets the weight of the regularization loss. This parameter can be configured either as a single value applied uniformly, or as a dictionary that specifies different weights for different models and trainers.

- **Command Line Usage:**
- For a single value: `python script.py --gamma_reg=0.1`
- For multiple values: `python script.py --gamma_reg='default=0.1,dann=0.05,jigen=0.2'`

- **YAML Configuration:**
- For a single value:

```yaml
gamma_reg: 0.1
```
- For different values:
```yaml
gamma_reg:
dann: 0.05
dial: 0.2
default: 0.1 # value for every other instance
```
`gamma_reg` is available for the trainers, as well as for the dann and jigen models.
- **Early Stopping (`--es`):** Steps for early stopping.
- **Random Seed (`--seed`):** Seed for reproducibility.
- **CUDA Options (`--nocu`, `--device`):** Configure CUDA usage and device settings.
3 changes: 2 additions & 1 deletion domainlab/algos/builder_dann.py
@@ -13,6 +13,7 @@
from domainlab.compos.zoo_nn import FeatExtractNNBuilderChainNodeGetter
from domainlab.models.model_dann import mk_dann
from domainlab.utils.utils_cuda import get_device
from domainlab.utils.hyperparameter_retrieval import get_gamma_reg


class NodeAlgoBuilderDANN(NodeAlgoBuilder):
@@ -55,7 +56,7 @@ def init_business(self, exp):
model = mk_dann(list_str_y=task.list_str_y,
net_classifier=net_classifier)(
list_d_tr=task.list_domain_tr,
alpha=args.gamma_reg,
alpha=get_gamma_reg(args, 'dann'),
net_encoder=net_encoder,
net_discriminator=net_discriminator,
builder=self)
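The implementation of `get_gamma_reg` is not part of this diff; a minimal sketch of the retrieval logic the calls above appear to rely on (per-name lookup with a `default` fallback when `--gamma_reg` was parsed into a dictionary, plain pass-through when it is a single float) could look like the following. Only the call signature is taken from the builders and trainers shown here; the body is an assumption.

```python
# Hedged sketch: the real function lives in domainlab/utils/hyperparameter_retrieval.py
# and is not shown in this commit.
def get_gamma_reg(conf, component_name):
    """Return the regularization weight for a given model/trainer name.

    conf.gamma_reg is either a float (applied to every component) or a dict
    such as {"default": 0.1, "dann": 0.05} produced by the new
    ParseValuesOrKeyValuePairs argparse action.
    """
    gamma = conf.gamma_reg
    if isinstance(gamma, dict):
        # fall back to the "default" entry when the component has no explicit value
        return gamma.get(component_name, gamma["default"])
    return gamma
```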
3 changes: 2 additions & 1 deletion domainlab/algos/builder_jigen1.py
@@ -15,6 +15,7 @@
from domainlab.dsets.utils_wrapdset_patches import WrapDsetPatches
from domainlab.models.model_jigen import mk_jigen
from domainlab.utils.utils_cuda import get_device
from domainlab.utils.hyperparameter_retrieval import get_gamma_reg


class NodeAlgoBuilderJiGen(NodeAlgoBuilder):
@@ -56,7 +57,7 @@ def init_business(self, exp):
model = mk_jigen(
list_str_y=task.list_str_y,
net_classifier=net_classifier)(
coeff_reg=args.gamma_reg,
coeff_reg=get_gamma_reg(args, 'jigen'),
net_encoder=net_encoder,
net_classifier_permutation=net_classifier_perm,
n_perm=args.nperm,
8 changes: 8 additions & 0 deletions domainlab/algos/trainers/a_trainer.py
@@ -258,3 +258,11 @@ def dset_decoration_args_algo(self, args, ddset):
if self._decoratee is not None:
return self._decoratee.dset_decoration_args_algo(args, ddset)
return ddset

def print_parameters(self):
"""
Function to print all parameters of the object.
Can be used to print the parameters of any child class
"""
params = vars(self)
print(f"Parameters of {type(self).__name__}: {params}")
3 changes: 2 additions & 1 deletion domainlab/algos/trainers/train_dial.py
@@ -5,6 +5,7 @@
from torch.autograd import Variable

from domainlab.algos.trainers.train_basic import TrainerBasic
from domainlab.utils.hyperparameter_retrieval import get_gamma_reg


class TrainerDIAL(TrainerBasic):
@@ -49,4 +50,4 @@ def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
tensor_x_adv = self.gen_adversarial(self.device, tensor_x, tensor_y)
tensor_x_batch_adv_no_grad = Variable(tensor_x_adv, requires_grad=False)
loss_dial = self.model.cal_task_loss(tensor_x_batch_adv_no_grad, tensor_y)
return [loss_dial], [self.aconf.gamma_reg]
return [loss_dial], [get_gamma_reg(self.aconf, self.name)]
3 changes: 2 additions & 1 deletion domainlab/algos/trainers/train_fishr.py
@@ -13,6 +13,7 @@
backpack = None

from domainlab.algos.trainers.train_basic import TrainerBasic
from domainlab.utils.hyperparameter_retrieval import get_gamma_reg


class TrainerFishr(TrainerBasic):
@@ -39,7 +40,7 @@ def tr_epoch(self, epoch):
dict_layerwise_var_var_grads_sum = \
{key: val.sum() for key, val in dict_layerwise_var_var_grads.items()}
loss_fishr = sum(dict_layerwise_var_var_grads_sum.values())
loss = sum(list_loss_erm) + self.aconf.gamma_reg * loss_fishr
loss = sum(list_loss_erm) + get_gamma_reg(self.aconf, self.name) * loss_fishr
loss.backward()
self.optimizer.step()
self.epo_loss_tr += loss.detach().item()
70 changes: 70 additions & 0 deletions domainlab/algos/trainers/train_irm.py
@@ -0,0 +1,70 @@
"""
use random start to generate adversarial images
"""
import torch
from torch import autograd
from torch.nn import functional as F
from domainlab.algos.trainers.train_basic import TrainerBasic


class TrainerIRM(TrainerBasic):
"""
IRMv1 split a minibatch into half, and use an unbiased estimate of the
squared gradient norm via inner product
$$\\nabla_{w|w=1} \\ell(w \\cdot \\Phi(X^{e, i}), Y^{e, i})$$
of dimension dim(Grad)
with
$$\\nabla_{w|w=1} \\ell(w \\cdot \\Phi(X^{e, j}), Y^{e, j})$$
of dimension dim(Grad)
For more details, see section 3.2 and Appendix D of :
Arjovsky et al., “Invariant Risk Minimization.”
"""
def tr_epoch(self, epoch):
list_loaders = list(self.dict_loader_tr.values())
loaders_zip = zip(*list_loaders)
self.model.train()
self.epo_loss_tr = 0

for ind_batch, tuple_data_domains_batch in enumerate(loaders_zip):
self.optimizer.zero_grad()
list_domain_loss_erm = []
list_domain_reg = []
for batch_domain_e in tuple_data_domains_batch:
tensor_x, tensor_y, tensor_d, *others = batch_domain_e
tensor_x, tensor_y, tensor_d = \
tensor_x.to(self.device), tensor_y.to(self.device), \
tensor_d.to(self.device)
list_domain_loss_erm.append(
self.model.cal_task_loss(tensor_x, tensor_y))
list_1ele_loss_irm, _ = \
self._cal_reg_loss(tensor_x, tensor_y, tensor_d, others)
list_domain_reg += list_1ele_loss_irm
loss = torch.sum(torch.stack(list_domain_loss_erm)) + \
self.aconf.gamma_reg * torch.sum(torch.stack(list_domain_reg))
loss.backward()
self.optimizer.step()
self.epo_loss_tr += loss.detach().item()
self.after_batch(epoch, ind_batch)

flag_stop = self.observer.update(epoch) # notify observer
return flag_stop

def _cal_phi(self, tensor_x):
logits = self.model.cal_logit_y(tensor_x)
return logits

def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
"""
Let the trainer behave like a model, so that other trainers can use it
"""
_ = tensor_d
_ = others
y = tensor_y
phi = self._cal_phi(tensor_x)
dummy_w_scale = torch.tensor(1.).to(tensor_x.device).requires_grad_()
loss_1 = F.cross_entropy(phi[::2] * dummy_w_scale, y[::2])
loss_2 = F.cross_entropy(phi[1::2] * dummy_w_scale, y[1::2])
grad_1 = autograd.grad(loss_1, [dummy_w_scale], create_graph=True)[0]
grad_2 = autograd.grad(loss_2, [dummy_w_scale], create_graph=True)[0]
loss_irm = torch.sum(grad_1 * grad_2)
return [loss_irm], [self.aconf.gamma_reg]
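As a reading aid (an editorial note, not part of the commit): `_cal_reg_loss` scales the logits by a dummy weight fixed at 1, computes the cross-entropy loss separately on the even- and odd-indexed halves of the minibatch, and returns the inner product of the two gradients with respect to that dummy weight, which serves as the unbiased IRMv1 penalty estimate; `tr_epoch` then adds it to the ERM terms:

```latex
\mathcal{L}_{\mathrm{IRM}}^{e}
  = \Big\langle
      \nabla_{w}\,\ell\big(w \cdot \Phi(X^{e,\mathrm{even}}), Y^{e,\mathrm{even}}\big)\big|_{w=1},\;
      \nabla_{w}\,\ell\big(w \cdot \Phi(X^{e,\mathrm{odd}}), Y^{e,\mathrm{odd}}\big)\big|_{w=1}
    \Big\rangle,
\qquad
\mathcal{L} = \sum_{e} \ell_{\mathrm{ERM}}^{e}
  + \gamma_{\mathrm{reg}} \sum_{e} \mathcal{L}_{\mathrm{IRM}}^{e}
```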
3 changes: 2 additions & 1 deletion domainlab/algos/trainers/train_matchdg.py
@@ -13,6 +13,7 @@
)
from domainlab.tasks.utils_task_dset import DsetIndDecorator4XYD
from domainlab.utils.logger import Logger
from domainlab.utils.hyperparameter_retrieval import get_gamma_reg


class TrainerMatchDG(AbstractTrainer):
@@ -36,7 +37,7 @@ def init_business(
self.base_domain_size = get_base_domain_size4match_dg(self.task)
self.epo_loss_tr = 0
self.flag_erm = flag_erm
self.lambda_ctr = self.aconf.gamma_reg
self.lambda_ctr = get_gamma_reg(aconf, self.name)
self.mk_match_tensor(epoch=0)
self.flag_match_tensor_sweep_over = False
self.tuple_tensor_ref_domain2each_y = None
3 changes: 2 additions & 1 deletion domainlab/algos/trainers/train_mldg.py
@@ -10,6 +10,7 @@
from domainlab.algos.trainers.train_basic import TrainerBasic
from domainlab.tasks.utils_task import mk_loader
from domainlab.tasks.utils_task_dset import DsetZip
from domainlab.utils.hyperparameter_retrieval import get_gamma_reg


class TrainerMLDG(AbstractTrainer):
@@ -108,7 +109,7 @@ def tr_epoch(self, epoch):
loss = (
loss_source_task.sum()
+ source_reg_tr.sum()
+ self.aconf.gamma_reg * loss_look_forward.sum()
+ get_gamma_reg(self.aconf, self.name) * loss_look_forward.sum()
)
#
loss.backward()
2 changes: 2 additions & 0 deletions domainlab/algos/trainers/zoo_trainer.py
@@ -7,6 +7,7 @@
from domainlab.algos.trainers.train_matchdg import TrainerMatchDG
from domainlab.algos.trainers.train_mldg import TrainerMLDG
from domainlab.algos.trainers.train_fishr import TrainerFishr
from domainlab.algos.trainers.train_irm import TrainerIRM


class TrainerChainNodeGetter(object):
@@ -49,6 +50,7 @@ def __call__(self, lst_candidates=None, default=None, lst_excludes=None):
chain = TrainerMatchDG(chain)
chain = TrainerMLDG(chain)
chain = TrainerFishr(chain)
chain = TrainerIRM(chain)
chain = TrainerHyperScheduler(chain)
node = chain.handle(self.request)
head = node
39 changes: 38 additions & 1 deletion domainlab/arg_parser.py
@@ -12,6 +12,37 @@
from domainlab.models.args_vae import add_args2parser_vae
from domainlab.utils.logger import Logger

class ParseValuesOrKeyValuePairs(argparse.Action):
"""Class used for arg parsing where values are provided in a key value format"""

def __call__(self, parser: argparse.ArgumentParser,
namespace: argparse.Namespace, values: str, option_string: str = None):
"""
Handle parsing of key value pairs, or a single value instead
Args:
parser (argparse.ArgumentParser): The ArgumentParser object.
namespace (argparse.Namespace): The namespace object to store parsed values.
values (str): The string containing key=value pairs or a single float value.
option_string (str, optional): The option string that triggered this action (unused).
Raises:
ValueError: If the values cannot be parsed to float.
"""
if "=" in values:
my_dict = {}
for kv in values.split(","):
k, v = kv.split("=")
try:
my_dict[k.strip()] = float(v.strip())
except ValueError:
raise ValueError(f"Invalid value in key-value pair: '{kv}', must be float")
setattr(namespace, self.dest, my_dict)
else:
try:
setattr(namespace, self.dest, float(values))
except ValueError:
raise ValueError(f"Invalid value for {self.dest}: '{values}', must be float")

def mk_parser_main():
"""
@@ -31,7 +62,13 @@ def mk_parser_main():
parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")

parser.add_argument(
"--gamma_reg", type=float, default=0.1, help="weight of regularization loss"
"--gamma_reg",
default=0.1,
help="weight of regularization loss in the form of $$\ell(\cdot) + \mu \times R(\cdot)$$ \
can specify per model as 'default=3.0, dann=1.0,jigen=2.0', where default refer to gamma for trainer \
note diva is implemented $$\ell(\cdot) + \mu \times R(\cdot)$$ \
so diva does not have gamma_reg",
action=ParseValuesOrKeyValuePairs
)

parser.add_argument("--es", type=int, default=1, help="early stop steps")
Expand Down
1 change: 1 addition & 0 deletions domainlab/exp_protocol/aggregate_results.py
@@ -43,6 +43,7 @@ def agg_from_directory(input_dir: str, output_file: str):

def agg_main(bm_dir: str, skip_plotting: bool = False):
"""Aggregates partial results and generate plots."""
bm_dir = bm_dir.rstrip("/")
agg_output = f"{bm_dir}/results.csv"
agg_input = f"{bm_dir}/rule_results"
agg_from_directory(agg_input, agg_output)
4 changes: 4 additions & 0 deletions domainlab/exp_protocol/benchmark.smk
@@ -72,6 +72,8 @@ rule parameter_sampling:
expand("{path}", path=config_path)
output:
dest=expand("{output_dir}/hyperparameters.csv", output_dir=config["output_dir"])
# resources:
# log_dir="slurm_logs_test"
params:
sampling_seed=os.environ["DOMAINLAB_CUDA_HYPERPARAM_SEED"]
run:
@@ -159,6 +161,8 @@ rule agg_results:
# put different csv file in a big csv file
input:
exp_results=experiment_result_files
# resources:
# log_dir="slurm_logs_test"
output:
out_file=expand("{output_dir}/results.csv", output_dir=config["output_dir"])
run:
32 changes: 32 additions & 0 deletions domainlab/models/a_model.py
@@ -178,3 +178,35 @@ def dset_decoration_args_algo(self, args, ddset):
if self._decoratee is not None:
return self._decoratee.dset_decoration_args_algo(args, ddset)
return ddset

@property
def p_na_prefix(self):
"""
common prefix for Models
"""
return "Model"

@property
def name(self):
"""
get the name of the algorithm
"""
na_prefix = self.p_na_prefix
len_prefix = len(na_prefix)
na_class = type(self).__name__
if na_class[:len_prefix] != na_prefix:
raise RuntimeError(
"Model builder node class must start with ",
na_prefix,
"the current class is named: ",
na_class,
)
return type(self).__name__[len_prefix:].lower()

def print_parameters(self):
"""
Function to print all parameters of the object.
Can be used to print the parameters of every child class.
"""
params = vars(self)
print(f"Parameters of {type(self).__name__}: {params}")
