Merge branch 'mhof_dev' into mhof_dev_lr_scheduler
smilesun authored Dec 10, 2024
2 parents b0279dc + 924f4cc commit 7ac297a
Showing 13 changed files with 74 additions and 17 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
@@ -2,9 +2,9 @@ name: CI
 
 on:
   push:
-    branches: mhof_dev_merge
+    branches: mhof_dev
   pull_request:
-    branches: mhof_dev_merge
+    branches: mhof_dev
   workflow_dispatch:
 jobs:
   test:
2 changes: 1 addition & 1 deletion domainlab/algos/trainers/a_trainer.py
@@ -327,7 +327,7 @@ def cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
             list_reg_loss_trainer_tensor
         list_mu = list_mu_model + list_mu_trainer
         # ERM return a tensor of all zeros, delete here
-        if len(list_mu) > 1:
+        if len(list_mu) > 1 and "ModelERM" == type(self.get_model()).__name__:
             list_boolean_zero = [torch.all(torch.eq(list_loss_tensor[i], 0)).item()
                                  for i in range(len(list_mu))]
             list_loss_tensor = [list_loss_tensor[i] for (i, flag) in
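The new condition restricts the zero-loss pruning to ModelERM, whose regularization loss is a tensor of all zeros. A minimal sketch of the pruning idiom itself, with hypothetical stand-in tensors:

```python
import torch

# hypothetical stand-ins for per-multiplier regularization losses
list_loss_tensor = [torch.zeros(4), torch.ones(4), torch.zeros(4)]
list_mu = [1.0, 0.5, 0.1]

# flag losses that are identically zero (what ERM contributes)
list_boolean_zero = [torch.all(torch.eq(t, 0)).item() for t in list_loss_tensor]
# drop the all-zero losses together with their multipliers
list_loss_tensor = [t for t, flag in zip(list_loss_tensor, list_boolean_zero) if not flag]
list_mu = [mu for mu, flag in zip(list_mu, list_boolean_zero) if not flag]
```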
3 changes: 3 additions & 0 deletions domainlab/algos/trainers/fbopt_mu_controller.py
@@ -45,6 +45,9 @@ def __init__(self, trainer, **kwargs):
         self.mu_min = trainer.aconf.mu_min
         self.mu_clip = trainer.aconf.mu_clip
 
+        if not kwargs:
+            raise RuntimeError("feedback scheduler requires **kwargs, the set \
+                of multipliers non-empty")
         self.mmu = kwargs
         # force initial value of mu
         self.mmu = {key: self.init_mu for key, val in self.mmu.items()}
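This guard is plausibly why test_erm_fbopt below now expects a RuntimeError: plain ERM supplies no multipliers. A minimal sketch of the check, with a hypothetical factory standing in for the controller class:

```python
def make_mu_controller(init_mu=0.01, **kwargs):
    # mirror the new guard: an empty multiplier set is a configuration error
    if not kwargs:
        raise RuntimeError("feedback scheduler requires **kwargs, the set "
                           "of multipliers non-empty")
    # force initial value of mu, as the constructor does (0.01 assumed here)
    return {key: init_mu for key in kwargs}

make_mu_controller(mu_task=1.0)  # ok
# make_mu_controller()           # raises RuntimeError
```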
2 changes: 1 addition & 1 deletion domainlab/algos/trainers/train_causIRL.py
@@ -69,7 +69,7 @@ def tr_batch(self, tensor_x, tensor_y, tensor_d, others, ind_batch, epoch):
             penalty = torch.nan_to_num(self.mmd(first, second))
         else:
             penalty = torch.tensor(0)
-        loss = self.cal_loss(tensor_x, tensor_y, tensor_d, others)
+        loss, *_ = self.cal_loss(tensor_x, tensor_y, tensor_d, others)
         loss = loss + penalty
         loss.backward()
         self.optimizer.step()
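cal_loss evidently returns a tuple now, so the trainer keeps only the leading scalar via starred unpacking. A minimal illustration of the idiom, with a hypothetical return shape:

```python
def cal_loss_like():
    # hypothetical: (total_loss, list_of_loss_components, list_of_multipliers)
    return 0.7, [0.4, 0.3], [1.0, 1.0]

loss, *_ = cal_loss_like()  # keep the total, discard the rest
assert loss == 0.7
```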
4 changes: 2 additions & 2 deletions domainlab/algos/trainers/train_ema.py
@@ -57,10 +57,10 @@ def move_average(self, dict_data, epoch):
         self._ma_iter += 1
         return dict_return_ema_para_curr_iter
 
-    def after_epoch(self, epoch):
+    def after_epoch(self, epoch, flag_info=None):
         torch_model = self.get_model()
         dict_para = torch_model.state_dict()  # only for trainable parameters
         new_dict_para = self.move_average(dict_para, epoch)
         # without deepcopy, this seems to work
         torch_model.load_state_dict(new_dict_para)
-        super().after_epoch(epoch)
+        super().after_epoch(epoch, flag_info)
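The signature change only threads flag_info through to the base-class hook; the EMA logic itself is unchanged. A minimal sketch of a state-dict moving average, assuming a fixed decay coefficient (the real decay lives inside move_average):

```python
import torch

def ema_state_dict(dict_prev, dict_curr, decay=0.99):
    """blend the previous averaged parameters with the current ones, key by key"""
    return {key: decay * dict_prev[key] + (1.0 - decay) * val
            for key, val in dict_curr.items()}

# usage sketch: after each epoch, write the averaged weights back
# new_dict_para = ema_state_dict(dict_para_prev, torch_model.state_dict())
# torch_model.load_state_dict(new_dict_para)
```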
2 changes: 1 addition & 1 deletion domainlab/arg_parser.py
@@ -267,7 +267,7 @@ def mk_parser_main():
     parser.add_argument(
         "--lr_scheduler",
         type=str,
-        default="CosineAnnealingLR",
+        default=None,
         help="name of pytorch learning rate scheduler",
     )
 
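With the default now None, a scheduler should only be constructed when the flag is actually passed. A sketch of resolving the name against torch.optim.lr_scheduler (assumed wiring, not the repository's exact code; T_max is required by CosineAnnealingLR and its value here is assumed):

```python
from torch import nn, optim

model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)

lr_scheduler = "CosineAnnealingLR"  # value of --lr_scheduler, possibly None
scheduler = None
if lr_scheduler is not None:
    # look up the scheduler class by name; other schedulers need other kwargs
    cls_scheduler = getattr(optim.lr_scheduler, lr_scheduler)
    scheduler = cls_scheduler(optimizer, T_max=10)

for _ in range(10):  # training-loop sketch
    optimizer.step()
    if scheduler is not None:
        scheduler.step()
```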
2 changes: 1 addition & 1 deletion examples/yaml/slurm/config.yaml
@@ -8,7 +8,7 @@ cluster:
     --qos=gpu_normal
     --gres=gpu:1
     --nice=10000
-    -t 48:00:00
+    -t 24:00:00
     -c 2
     --mem=160G
     --job-name=$outputfolder-{rule}-{wildcards}
27 changes: 27 additions & 0 deletions scripts/generate_latex_table.py
@@ -0,0 +1,27 @@
+"""
+aggregate benchmark csv file to generate latex table
+"""
+import argparse
+import pandas as pd
+
+
+def gen_latex_table(raw_df, fname="table_perf.tex",
+                    group="method", str_perf="acc"):
+    """
+    aggregate benchmark csv file to generate latex table
+    """
+    df_result = raw_df.groupby(group)[str_perf].agg(["mean", "std", "count"])
+    latex_table = df_result.to_latex(float_format="%.3f")
+    str_table = df_result.to_string()
+    print(str_table)
+    with open(fname, 'w') as file:
+        file.write(latex_table)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Read a CSV file")
+    parser.add_argument("filename", help="Name of the CSV file to read")
+    args = parser.parse_args()
+
+    df = pd.read_csv(args.filename, index_col=False, skipinitialspace=True)
+    gen_latex_table(df)
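The new script can also be driven from Python rather than the command line; a usage sketch, assuming the repository root is on sys.path and a merged benchmark csv with method and acc columns exists:

```python
import pandas as pd
from scripts.generate_latex_table import gen_latex_table

# "merged_data.csv" is the aggregate produced by scripts/merge_csvs.sh below
df = pd.read_csv("merged_data.csv", index_col=False, skipinitialspace=True)
gen_latex_table(df, fname="table_perf.tex", group="method", str_perf="acc")
```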
5 changes: 3 additions & 2 deletions scripts/sh_genplot.sh
@@ -1,2 +1,3 @@
-mkdir $2
-python main_out.py --gen_plots $1 --outp_dir $2
+# mkdir $2
+sh scripts/merge_csvs.sh $1
+python main_out.py --gen_plots merged_data.csv --outp_dir partial_agg_plots
22 changes: 17 additions & 5 deletions tests/test_fbopt.py
@@ -1,42 +1,54 @@
 """
 unit and end-end test for deep all, mldg
 """
 import pytest
 from tests.utils_test import utils_test_algo
 
 
 def test_dann_fbopt():
     """
     dann
     """
-    args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=dann --trainer=fbopt --nname=alexnet --epos=3"
+    args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=dann --trainer=fbopt --nname=alexnet --epos=3 --no_dump"
     utils_test_algo(args)
 
 
 def test_jigen_fbopt():
     """
     jigen
     """
-    args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=jigen --trainer=fbopt --nname=alexnet --epos=3"
+    args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=jigen --trainer=fbopt --nname=alexnet --epos=3 --no_dump"
     utils_test_algo(args)
 
 
 def test_diva_fbopt():
     """
     diva
     """
-    args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=diva --gamma_y=1.0 --trainer=fbopt --nname=alexnet --epos=3"
+    args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=diva --gamma_y=1.0 --trainer=fbopt --nname=alexnet --epos=3 --no_dump"
     utils_test_algo(args)
 
 
 def test_erm_fbopt():
     """
     erm
     """
-    args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm --trainer=fbopt --nname=alexnet --epos=3"  # pylint: disable=line-too-long
+    args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm --trainer=fbopt --nname=alexnet --epos=3 --no_dump"  # pylint: disable=line-too-long
     with pytest.raises(RuntimeError):
         utils_test_algo(args)
 
 
+def test_irm_fbopt():
+    """
+    irm
+    """
+    args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm --trainer=fbopt_irm --nname=alexnet --epos=3 --no_dump"  # pylint: disable=line-too-long
+    utils_test_algo(args)
+
+
 def test_forcesetpoint_fbopt():
     """
     diva
     """
-    args = "--te_d=0 --tr_d 1 2 --task=mnistcolor10 --bs=16 --model=jigen --trainer=fbopt --nname=conv_bn_pool_2 --epos=10 --es=0 --mu_init=0.00001 --coeff_ma_setpoint=0.5 --coeff_ma_output_state=0.99 --force_setpoint_change_once"
+    args = "--te_d=0 --tr_d 1 2 --task=mnistcolor10 --bs=16 --model=jigen --trainer=fbopt --nname=conv_bn_pool_2 --epos=10 --es=0 --mu_init=0.00001 --coeff_ma_setpoint=0.5 --coeff_ma_output_state=0.99 --force_setpoint_change_once --no_dump"
     utils_test_algo(args)
2 changes: 1 addition & 1 deletion tests/test_fbopt_irm.py
@@ -10,5 +10,5 @@ def test_mhof_irm():
     """
     args = "--te_d=0 --task=mnistcolor10 --model=erm \
         --trainer=fbopt_irm --nname=conv_bn_pool_2 \
-        --k_i_gain_ratio=0.5"
+        --k_i_gain_ratio=0.5 --no_dump"
     utils_test_algo(args)
2 changes: 1 addition & 1 deletion tests/test_fbopt_setpoint_rewind.py
@@ -8,5 +8,5 @@ def test_jigen_fbopt():
     """
     jigen
     """
-    args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=jigen --trainer=fbopt --nname=alexnet --epos=300 --setpoint_rewind=yes"
+    args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=jigen --trainer=fbopt --nname=alexnet --epos=300 --setpoint_rewind --no_dump"
     utils_test_algo(args)
14 changes: 14 additions & 0 deletions tests/test_lr_scheduler.py
@@ -0,0 +1,14 @@
+
+"""
+unit and end-end test for lr scheduler
+"""
+from tests.utils_test import utils_test_algo
+
+
+def test_lr_scheduler():
+    """
+    train
+    """
+    args = "--te_d=2 --tr_d 0 1 --task=mnistcolor10 --debug --bs=100 --model=erm \
+        --nname=conv_bn_pool_2 --no_dump --lr_scheduler=CosineAnnealingLR"
+    utils_test_algo(args)