Commit

Merge branch 'master' into lb_fix_yaml
lisab00 authored Sep 29, 2023
2 parents fd8a5e2 + 3479847 commit 1d04cde
Showing 14 changed files with 48 additions and 331 deletions.
2 changes: 0 additions & 2 deletions domainlab/algos/trainers/a_trainer.py
@@ -37,7 +37,6 @@ def __init__(self, successor_node=None):
#
self.loader_tr = None
self.loader_te = None
self.dict_loader_tr = None
self.num_batches = None
self.flag_update_hyper_per_epoch = None
self.flag_update_hyper_per_batch = None
@@ -67,7 +66,6 @@ def init_business(self, model, task, observer, device, aconf, flag_accept=True):
#
self.loader_tr = task.loader_tr
self.loader_te = task.loader_te
self.dict_loader_tr = task.dict_loader_tr

if flag_accept:
self.observer.accept(self)
172 changes: 0 additions & 172 deletions domainlab/algos/trainers/train_fishr.py

This file was deleted.

3 changes: 0 additions & 3 deletions domainlab/algos/trainers/zoo_trainer.py
@@ -5,11 +5,9 @@
from domainlab.algos.trainers.train_dial import TrainerDIAL
from domainlab.algos.trainers.train_matchdg import TrainerMatchDG
from domainlab.algos.trainers.train_mldg import TrainerMLDG
from domainlab.algos.trainers.train_fishr import TrainerFishr
from domainlab.algos.trainers.train_hyper_scheduler import TrainerHyperScheduler



class TrainerChainNodeGetter(object):
"""
Chain of Responsibility: node is named in pattern Trainer[XXX] where the string
@@ -40,7 +38,6 @@ def __call__(self, lst_candidates=None, default=None, lst_excludes=None):
chain = TrainerDIAL(chain)
chain = TrainerMatchDG(chain)
chain = TrainerMLDG(chain)
chain = TrainerFishr(chain)
chain = TrainerHyperScheduler(chain) # FIXME: change to warmup
node = chain.handle(self.request)
return node
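
For readers unfamiliar with the pattern named in the docstring, a minimal, self-contained sketch of chain-of-responsibility dispatch in the style used above; the class and method names here are illustrative placeholders, not DomainLab's actual base classes:

    class NodeBase:
        """Illustrative chain node: handle the request or pass it to the successor."""
        def __init__(self, successor=None):
            self._successor = successor

        def is_myjob(self, request):
            # match the Trainer[XXX] naming convention mentioned in the docstring
            return type(self).__name__.lower() == ("trainer" + request).lower()

        def handle(self, request):
            if self.is_myjob(request):
                return self
            if self._successor is None:
                raise NotImplementedError(f"no node can handle request: {request}")
            return self._successor.handle(request)

    class TrainerBasic(NodeBase):
        pass

    class TrainerMLDG(NodeBase):
        pass

    chain = TrainerBasic(None)      # innermost node, tried last
    chain = TrainerMLDG(chain)      # outermost node, tried first
    node = chain.handle("mldg")     # -> the TrainerMLDG instance
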
3 changes: 3 additions & 0 deletions domainlab/compos/exp/exp_utils.py
@@ -63,6 +63,9 @@ def mk_model_na(self, tag=None, dd_cut=19):
model_name = "_".join(list4mname)
if self.host.args.debug:
model_name = "debug_" + model_name
slurm = os.environ.get('SLURM_JOB_ID')
if slurm:
model_name = model_name + '_' + slurm
logger = Logger.get_logger()
logger.info(f"model name: {model_name}")
return model_name
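
The three lines added above tag the model name with the SLURM job id when one is present. A standalone sketch of the effect; the model_name value is a made-up placeholder, not DomainLab's real naming scheme:

    import os

    model_name = "deepall_mnistcolor10_seed1"   # placeholder value
    slurm = os.environ.get('SLURM_JOB_ID')      # set by SLURM inside a batch/array job
    if slurm:
        model_name = model_name + '_' + slurm   # e.g. "..._4711"; keeps parallel jobs from colliding
    print(model_name)
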
8 changes: 0 additions & 8 deletions domainlab/models/a_model.py
@@ -6,13 +6,6 @@

from torch import nn

try:
from backpack import extend
except:
backpack = None




class AModel(nn.Module, metaclass=abc.ABCMeta):
"""
@@ -65,4 +58,3 @@ def forward(self, tensor_x, tensor_y, tensor_d, others=None):
:param d:
"""
return self.cal_loss(tensor_x, tensor_y, tensor_d, others)

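
The deleted block was an optional-dependency guard: import backpack if it is installed, otherwise fall back to None. A minimal sketch of that idiom (using an explicit ImportError instead of the bare except that was removed; maybe_extend is an illustrative helper, not project code):

    try:
        from backpack import extend          # optional dependency
    except ImportError:
        extend = None

    def maybe_extend(module):
        """Return the module extended for backpack if the package is installed, else unchanged."""
        return module if extend is None else extend(module)
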
13 changes: 1 addition & 12 deletions domainlab/models/a_model_classif.py
@@ -10,24 +10,13 @@
from torch import nn as nn
from torch.nn import functional as F


try:
from backpack import backpack, extend
from backpack.extensions import BatchGrad, Variance
except:
backpack = None



from domainlab.models.a_model import AModel
from domainlab.utils.utils_class import store_args
from domainlab.utils.utils_classif import get_label_na, logit2preds_vpic
from domainlab.utils.perf import PerfClassif
from domainlab.utils.perf_metrics import PerfMetricClassif
from domainlab.utils.logger import Logger

loss_cross_entropy_extended = extend(nn.CrossEntropyLoss())


class AModelClassif(AModel, metaclass=abc.ABCMeta):
"""
@@ -124,7 +113,7 @@ def cal_task_loss(self, tensor_x, tensor_y):
y_target = tensor_y
else:
_, y_target = tensor_y.max(dim=1)
lc_y = loss_cross_entropy_extended(logit_y, y_target)
lc_y = F.cross_entropy(logit_y, y_target, reduction="none")
# cross entropy always return a scalar, no need for inside instance reduction
return lc_y

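
The replacement call computes the loss per instance instead of relying on the previously extended loss object. A quick standalone check of that behaviour with dummy tensors:

    import torch
    from torch.nn import functional as F

    logit_y = torch.randn(4, 3)              # dummy logits: batch of 4, 3 classes
    y_target = torch.tensor([0, 2, 1, 2])    # dummy class indices
    lc_y = F.cross_entropy(logit_y, y_target, reduction="none")
    print(lc_y.shape)                        # torch.Size([4]): one loss value per instance
    print(torch.isclose(lc_y.mean(),         # the default reduction ("mean") would return
          F.cross_entropy(logit_y, y_target)))   # this single scalar instead
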
12 changes: 0 additions & 12 deletions domainlab/models/model_deep_all.py
@@ -1,10 +1,3 @@

try:
from backpack import extend
except:
backpack = None


from domainlab.models.a_model_classif import AModelClassif
from domainlab.utils.override_interface import override_interface

@@ -59,9 +52,4 @@ def cal_loss(self, tensor_x, tensor_y, tensor_d, others=None):
lc_y = self.cal_task_loss(tensor_x, tensor_y)
return lc_y

def convert4backpack(self):
"""
convert the module to backpack for 2nd order gradients
"""
self.net = extend(self.net, use_converter=True)
return ModelDeepAll
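
The removed convert4backpack wrapped the network with backpack's extend so that per-sample and second-order quantities could be collected during backward. A rough, hedged sketch of that kind of usage, assuming the optional backpack-for-pytorch package and following its documented API; this is not DomainLab code:

    import torch
    from torch import nn
    from backpack import backpack, extend
    from backpack.extensions import Variance

    net = extend(nn.Linear(3, 2))                 # extend the module
    lossfunc = extend(nn.CrossEntropyLoss())      # extend the loss as well
    x = torch.randn(4, 3)
    y = torch.tensor([0, 1, 1, 0])
    loss = lossfunc(net(x), y)
    with backpack(Variance()):                    # request gradient variance over the batch
        loss.backward()
    for param in net.parameters():
        print(param.variance.shape)               # same shape as the parameter
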
1 change: 0 additions & 1 deletion domainlab/tasks/a_task.py
@@ -24,7 +24,6 @@ def __init__(self, succ=None):
self._args = None
self.dict_dset_all = {} # persist
self.dict_dset_tr = {} # versatile variable: which domains to use as training
self.dict_loader_tr = {}
self.dict_dset_te = {} # versatile
self.dict_dset_val = {} # versatile
self.dict_domain_class_count = {}
1 change: 0 additions & 1 deletion domainlab/tasks/b_task.py
@@ -40,7 +40,6 @@ def init_business(self, args, node_algo_builder=None):
ddset_tr = node_algo_builder.dset_decoration_args_algo(args, ddset_tr)
ddset_val = node_algo_builder.dset_decoration_args_algo(args, ddset_val)
self.dict_dset_tr.update({na_domain: ddset_tr})
self.dict_loader_tr.update({na_domain: mk_loader(ddset_tr, args.bs)})
self.dict_dset_val.update({na_domain: ddset_val})
ddset_mix = ConcatDataset(tuple(self.dict_dset_tr.values()))
self._loader_tr = mk_loader(ddset_mix, args.bs)
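
With the per-domain loader dict gone, training batches come only from the mixed loader built in the two lines above. A small stand-in for that path, using plain torch utilities instead of DomainLab's mk_loader helper and two dummy domains:

    import torch
    from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

    # two dummy training domains
    dset_domain_a = TensorDataset(torch.randn(8, 3), torch.zeros(8, dtype=torch.long))
    dset_domain_b = TensorDataset(torch.randn(8, 3), torch.ones(8, dtype=torch.long))

    ddset_mix = ConcatDataset((dset_domain_a, dset_domain_b))     # pool the domains
    loader_tr = DataLoader(ddset_mix, batch_size=4, shuffle=True)
    tensor_x, tensor_y = next(iter(loader_tr))
    print(tensor_x.shape, tensor_y.shape)   # torch.Size([4, 3]) torch.Size([4])
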