Skip to content

Commit

Permalink
ma setpiont update
Browse files Browse the repository at this point in the history
  • Loading branch information
smilesun committed Oct 10, 2023
1 parent 682811e commit ee78952
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions domainlab/algos/trainers/fbopt_setpoint_ada.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,12 @@ def transition_to(self, state):
self.state_updater = state
self.state_updater.accept(self)

def update_setpoint_ma(self, target):
def update_setpoint_ma(self, list_target):
"""
using moving average
"""
temp_ma = self.coeff_ma * torch.tensor(target)
temp_ma += (1 - self.coeff_ma) * torch.tensor(self.setpoint4R)
temp_ma = temp_ma.tolist()
self.setpoint4R = temp_ma
target_ma = [self.coeff_ma * a + (1 - self.coeff_ma) *b for a, b in zip(self.setpoint4R, list_target)]
self.setpoint4R = target_ma

def observe(self, epo_reg_loss, epo_task_loss):
"""
Expand All @@ -69,7 +67,7 @@ def observe(self, epo_reg_loss, epo_task_loss):
self.state_task_loss = epo_task_loss
if self.state_updater.update_setpoint():
logger = Logger.get_logger(logger_name='main_out_logger', loglevel="INFO")
self.setpoint4R = self.state_epo_reg_loss
self.update_setpoint_ma(self.state_epo_reg_loss)
logger.info(f"!!!!!set point updated to {self.setpoint4R}!")


Expand Down

0 comments on commit ee78952

Please sign in to comment.