Skip to content

Commit

Permalink
Corrected BanditScheduler ask method
Browse files Browse the repository at this point in the history
  • Loading branch information
Tekexa committed Dec 10, 2024
1 parent 301bd7f commit 4329d4a
Showing 1 changed file with 11 additions and 15 deletions.
26 changes: 11 additions & 15 deletions ribs/schedulers/_bandit_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,13 +244,14 @@ def ask(self):
reselect = self._active_arr.copy()

# If not enough emitters are active, activate the first num_active.
num_needed = self._num_active - self._active_arr.sum()
# This always happens on the first iteration(s).
nb_activated = self._num_active - self._active_arr.sum()
i = 0
while num_needed > 0:
while nb_activated > 0:
reselect[i] = False
if not self._active_arr[i]:
self._active_arr[i] = True
num_needed -= 1
nb_activated -= 1
i += 1

# Deactivate emitters to be reselected.
Expand All @@ -273,19 +274,14 @@ def ask(self):
self._zeta * np.sqrt(
np.log(self._success.sum()) /
self._selection[update_ucb]))
# Activate top emitters based on UCB1.
activate = np.argsort(ucb1)[-reselect.sum():]
self._active_arr[activate] = True

# Deactivate emitters if there are too many active emitters.
nb_deactivate = self.emitters.sum() - self._num_active
deactivate = np.argsort(ucb1)
for i in deactivate:
if nb_deactivate == 0:
# Activate top emitters based on UCB1, until there are num_active
# active emitters. Activate only inactive emitters.
activate = np.argsort(ucb1)[::-1]
for i in activate:
if self._active_arr.sum() == self._num_active:
break
if self._active_arr[i]:
self._active_arr[i] = False
nb_deactivate -= 1
if not self._active_arr[i]:
self._active_arr[i] = True

self._cur_solutions = []

Expand Down

0 comments on commit 4329d4a

Please sign in to comment.