diff --git a/domainlab/algos/trainers/fbopt_setpoint_ada.py b/domainlab/algos/trainers/fbopt_setpoint_ada.py index e7dede900..a789a1ccc 100644 --- a/domainlab/algos/trainers/fbopt_setpoint_ada.py +++ b/domainlab/algos/trainers/fbopt_setpoint_ada.py @@ -9,7 +9,7 @@ def list_add(list1, list2): def list_multiply(list1, coeff): return [ele * coeff for ele in list1] - + def is_less_list_any(list1, list2): """ judge if one list is less than the other @@ -71,8 +71,10 @@ def observe(self, epo_reg_loss, epo_task_loss): read current epo_reg_loss continuously FIXME: setpoint should also be able to be eliviated """ - self.state_epo_reg_loss = list_add(list_multiply(epo_reg_loss, self.coeff_ma_output), list_multiply(self.state_epo_reg_loss, 1 - self.coeff_ma_output)) - self.state_task_loss = epo_task_loss + self.state_epo_reg_loss = [self.coeff_ma_output*a + ( 1-self.coeff_ma_output )*b if a != 0.0 else b for a, b in zip(self.state_epo_reg_loss, epo_reg_loss)] + if self.state_task_loss == 0.0: + self.state_task_loss = epo_task_loss + self.state_task_loss = self.coeff_ma_output * self.state_task_loss + (1-self.coeff_ma_output) * epo_task_loss if self.state_updater.update_setpoint(): logger = Logger.get_logger(logger_name='main_out_logger', loglevel="INFO") self.setpoint4R = self.state_epo_reg_loss