Skip to content

Commit

Permalink
fix bug 0 drag state down
Browse files Browse the repository at this point in the history
  • Loading branch information
smilesun committed Oct 10, 2023
1 parent 1f857b1 commit 2f6df90
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions domainlab/algos/trainers/fbopt_setpoint_ada.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def list_add(list1, list2):

def list_multiply(list1, coeff):
return [ele * coeff for ele in list1]

def is_less_list_any(list1, list2):
"""
judge if one list is less than the other
Expand Down Expand Up @@ -71,8 +71,10 @@ def observe(self, epo_reg_loss, epo_task_loss):
read current epo_reg_loss continuously
FIXME: setpoint should also be able to be eliviated
"""
self.state_epo_reg_loss = list_add(list_multiply(epo_reg_loss, self.coeff_ma_output), list_multiply(self.state_epo_reg_loss, 1 - self.coeff_ma_output))
self.state_task_loss = epo_task_loss
self.state_epo_reg_loss = [self.coeff_ma_output*a + ( 1-self.coeff_ma_output )*b if a != 0.0 else b for a, b in zip(self.state_epo_reg_loss, epo_reg_loss)]
if self.state_task_loss == 0.0:
self.state_task_loss = epo_task_loss
self.state_task_loss = self.coeff_ma_output * self.state_task_loss + (1-self.coeff_ma_output) * epo_task_loss
if self.state_updater.update_setpoint():
logger = Logger.get_logger(logger_name='main_out_logger', loglevel="INFO")
self.setpoint4R = self.state_epo_reg_loss
Expand Down

0 comments on commit 2f6df90

Please sign in to comment.