no drop for loader
smilesun committed Sep 13, 2023
commit ad0be7c (1 parent: a716e71)
Showing 4 changed files with 6 additions and 5 deletions.
domainlab/algos/trainers/a_trainer.py (2 additions, 0 deletions)
@@ -36,6 +36,7 @@ def __init__(self, successor_node=None):
         self.aconf = None
         #
         self.loader_tr = None
+        self.loader_tr_no_drop = None
         self.loader_te = None
         self.num_batches = None
         self.flag_update_hyper_per_epoch = None
@@ -67,6 +68,7 @@ def init_business(self, model, task, observer, device, aconf, flag_accept=True):
         self.aconf = aconf
         #
         self.loader_tr = task.loader_tr
+        self.loader_tr_no_drop = task._loader_tr_no_drop
         self.loader_te = task.loader_te
 
         if flag_accept:
domainlab/algos/trainers/train_fbopt.py (2 additions, 5 deletions)
@@ -71,9 +71,8 @@ def eval_p_loss(self, dict4mu, dict_theta):
         temp_model.hyper_update(epoch=None, fun_scheduler=HyperSetter(dict4mu))
         temp_model.set_params(dict_theta)
         epo_p_loss = 0 # penalized loss
-        # FIXME: will loader be corupted? if called at different places? if we do not make deep copy
         with torch.no_grad():
-            for _, (tensor_x, vec_y, vec_d, *_) in enumerate(self.loader_tr):
+            for _, (tensor_x, vec_y, vec_d, *_) in enumerate(self.loader_tr_no_drop):
                 tensor_x, vec_y, vec_d = \
                     tensor_x.to(self.device), vec_y.to(self.device), vec_d.to(self.device)
                 # sum will kill the dimension of the mini batch
@@ -87,14 +86,12 @@ def eval_r_loss(self):
         ERM loss on all available training data
         # TODO: normalize loss via batchsize
         """
-        # FIXME: move this to model instead of having it in trainer here
         temp_model = copy.deepcopy(self.model)
         # mock the model hyper-parameter to be from dict4mu
         epo_reg_loss = 0
         epo_task_loss = 0
-        # FIXME: will loader be corupted? if called at different places? if we do not make deep copy
         with torch.no_grad():
-            for _, (tensor_x, vec_y, vec_d, *_) in enumerate(self.loader_tr):
+            for _, (tensor_x, vec_y, vec_d, *_) in enumerate(self.loader_tr_no_drop):
                 tensor_x, vec_y, vec_d = \
                     tensor_x.to(self.device), vec_y.to(self.device), vec_d.to(self.device)
                 b_reg_loss = temp_model.cal_reg_loss(tensor_x, vec_y, vec_d).sum()
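Both eval_p_loss and eval_r_loss accumulate a loss over the entire training set, which is why they now iterate over self.loader_tr_no_drop, built with drop_last=False (see b_task.py below). The standalone sketch that follows is my own illustration with an assumed toy dataset and batch size, not repository code; it shows why the change matters: a loader that drops the last incomplete batch never visits the trailing samples, so an epoch-level loss sum silently excludes them.

import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset: 10 samples with batch size 4, so the last batch holds only 2 samples.
dset = TensorDataset(torch.arange(10).float())
loader_drop = DataLoader(dset, batch_size=4, drop_last=True)      # skips the trailing partial batch
loader_no_drop = DataLoader(dset, batch_size=4, drop_last=False)  # visits every sample

n_drop = sum(x.shape[0] for (x,) in loader_drop)        # 8  -- two samples never seen
n_no_drop = sum(x.shape[0] for (x,) in loader_no_drop)  # 10 -- all samples counted
print(n_drop, n_no_drop)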
domainlab/tasks/a_task.py (1 addition, 0 deletions)
@@ -16,6 +16,7 @@ class NodeTaskDG(AbstractChainNodeHandler):
     def __init__(self, succ=None):
         super().__init__(succ)
         self._loader_tr = None
+        self._loader_tr_no_drop = None
         self._loader_te = None
         self._loader_val = None
         self._list_domains = None
domainlab/tasks/b_task.py (1 addition, 0 deletions)
@@ -43,6 +43,7 @@ def init_business(self, args, node_algo_builder=None):
             self.dict_dset_val.update({na_domain: ddset_val})
         ddset_mix = ConcatDataset(tuple(self.dict_dset_tr.values()))
         self._loader_tr = mk_loader(ddset_mix, args.bs)
+        self._loader_tr_no_drop = mk_loader(ddset_mix, args.bs, drop_last=False, shuffle=False)
 
         ddset_mix_val = ConcatDataset(tuple(self.dict_dset_val.values()))
         self._loader_val = mk_loader(ddset_mix_val, args.bs,
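The full signature of mk_loader is not visible in this diff. Assuming it is a thin wrapper around torch.utils.data.DataLoader that forwards the batch size, drop_last, and shuffle arguments, the added call would behave roughly like the hypothetical helper sketched below; mk_loader_sketch is an illustrative stand-in, not the repository's function.

from torch.utils.data import DataLoader

def mk_loader_sketch(dset, bsize, drop_last=True, shuffle=True):
    """Hypothetical stand-in for mk_loader (the real signature is not shown here).

    drop_last=False keeps the final partial batch so that summed losses cover
    every sample; shuffle=False yields a deterministic pass over the data.
    """
    return DataLoader(dset, batch_size=bsize, drop_last=drop_last, shuffle=shuffle)

# Usage mirroring the added line, under the assumption above:
# loader_tr_no_drop = mk_loader_sketch(ddset_mix, args.bs, drop_last=False, shuffle=False)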
