diff --git a/boptestGymEnv.py b/boptestGymEnv.py index 9ba2850..307183d 100644 --- a/boptestGymEnv.py +++ b/boptestGymEnv.py @@ -1013,8 +1013,8 @@ def __init__(self, env, n_bins_act=10): env: gym.Env Original gym environment n_bins_obs: integer - Number of bins to be used in the transformed observation space - for each observation. + Number of bins to be used in the transformed action space + for each action. ''' @@ -1059,26 +1059,27 @@ def _get_indices(self, action_wrapper): ------- Suppose: self.n_act = 3 (number of actions) - self.n_bins_act = 4 (number of bins per action) + self.n_bins_act = 3 (number of bins per action; this means 4 possible values per action) self.val_bins_act = [[0, 1, 2, 3], [10, 11, 12, 13], [20, 21, 22, 23]] (value bins for each action) Then, `_get_indices` example, for action_wrapper = 37: indices = [] Loop 3 times: - Iteration 1: indices.append((37+1) % 4) -> indices = [2], action_wrapper //= 4 -> action_wrapper = 9 - Iteration 2: indices.append((9+1) % 4) -> indices = [2, 2], action_wrapper //= 4 -> action_wrapper = 2 - Iteration 3: indices.append((2+1) % 4) -> indices = [2, 2, 3], action_wrapper //= 4 -> action_wrapper = 0 - Reverse indices: [3, 2, 2] + Iteration 1: indices.append(37 % (3+1)) -> indices = [1], action_wrapper //= 4 -> action_wrapper = 9 + Iteration 2: indices.append(9 % (3+1)) -> indices = [1, 1], action_wrapper //= 4 -> action_wrapper = 2 + Iteration 3: indices.append(2 % (3+1)) -> indices = [1, 1, 2], action_wrapper //= 4 -> action_wrapper = 0 + Reverse indices: [2, 1, 1] Note ---- - To understand why we need to add 1 in `action_wrapper+1)%self.n_bins_act` think of the corner case + To understand why we need to add 1 in `action_wrapper%(self.n_bins_act+1)` think of the edge case where we only have one bin. If the action_wrapper is 1, then the index should be 1, but if we do not - add 1, the index would be 0 (because 1%1=0). + add 1 to `self.n_bins_act`, the index would be 0 (because 1%1=0). 
The underlying reason is that + n_bins_act is the number of bins, not the number of possible action values. """ indices=[] for _ in range(self.n_act): - indices.append((action_wrapper+1)%self.n_bins_act) + indices.append(action_wrapper%(self.n_bins_act+1)) action_wrapper //= self.n_bins_act return indices[::-1] diff --git a/testing/references/multiaction_training.csv b/testing/references/multiaction_training.csv index 1aff302..8b4092a 100644 --- a/testing/references/multiaction_training.csv +++ b/testing/references/multiaction_training.csv @@ -1,2 +1,2 @@ keys,value -0,439 +0,841