Skip to content

Commit

Permalink
Merge pull request #1 from foreverska/feature/multibuffalo
Browse files Browse the repository at this point in the history
multibuffalo
  • Loading branch information
foreverska authored Apr 14, 2024
2 parents 12f3851 + b370bed commit 5012dca
Show file tree
Hide file tree
Showing 8 changed files with 153 additions and 7 deletions.
15 changes: 15 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,21 @@ distribution (0, 1) and added to the chosen arm center
value. This is not intended to be challenging for an agent but
easy for the debugger to reason about.

## Multi-Buffalo ("MultiBuffalo-v0")

This serves as a contextual bandit implementation. It is a
k-armed bandit with n states. The current state is indicated to
the agent in the observation, and each state has different
reward offsets for each arm. The goal of the agent is to
learn and contextualize the best action for a given state. This is
a good stepping stone to Markov Decision Processes.

This environment has an extra parameter, pace. By default (None), a
new state is chosen on every step of the environment. It can
be set to a positive integer to determine how many steps pass between
randomly choosing a new state. Of course, transitioning to a different
state is not guaranteed, as the next state is chosen at random.

## Using

Install via pip and import buffalo_gym along with gymnasium.
Expand Down
6 changes: 6 additions & 0 deletions buffalo_gym/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,9 @@
entry_point='buffalo_gym.envs:BuffaloEnv',
max_episode_steps=1000
)

# Expose the multi-state (contextual) bandit under the id 'MultiBuffalo-v0'.
# The entry point names the MultiBuffaloEnv class in buffalo_gym.envs.
register(
    id='MultiBuffalo-v0',
    entry_point='buffalo_gym.envs:MultiBuffaloEnv',
    max_episode_steps=1000,
)
3 changes: 2 additions & 1 deletion buffalo_gym/envs/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from buffalo_gym.envs.buffalo_gym import BuffaloEnv
from buffalo_gym.envs.buffalo_gym import BuffaloEnv
from buffalo_gym.envs.multibuffalo_gym import MultiBuffaloEnv
9 changes: 6 additions & 3 deletions buffalo_gym/envs/buffalo_gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@


class BuffaloEnv(gym.Env):
"""
Standard multi-armed bandit environment with static reward distributions.
"""
metadata = {'render_modes': []}

def __init__(self, arms: int = 10):
Expand All @@ -16,7 +19,7 @@ def __init__(self, arms: int = 10):
:param arms: number of arms
"""
self.action_space = gym.spaces.Discrete(arms)
self.observation_space = gym.spaces.Box(low=0, high=1, shape=(1,), dtype=np.int64)
self.observation_space = gym.spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32)

self.offsets = np.random.normal(0, arms, (arms,))

Expand All @@ -31,7 +34,7 @@ def reset(self,
:return: observation, info
"""

return np.zeros((1,), dtype=np.int64), {}
return np.zeros((1,), dtype=np.float32), {}

def step(self, action: int) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""
Expand All @@ -41,4 +44,4 @@ def step(self, action: int) -> tuple[ObsType, SupportsFloat, bool, bool, dict[st
"""
reward = np.random.normal(0, 1, 1)[0] + self.offsets[action]

return np.zeros((1,), dtype=np.int64), reward, False, False, {}
return np.zeros((1,), dtype=np.float32), reward, False, False, {}
60 changes: 60 additions & 0 deletions buffalo_gym/envs/multibuffalo_gym.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from typing import Any, TypeVar, SupportsFloat
import random

import gymnasium as gym
import numpy as np

ObsType = TypeVar("ObsType")
ActType = TypeVar("ActType")


class MultiBuffaloEnv(gym.Env):
    """
    Multi-armed bandit environment with multiple states that have (probably) distinct rewards.

    Every (arm, state) pair gets a fixed reward offset drawn once at construction.
    The index of the current state is reported as the single-element observation.
    """
    metadata = {'render_modes': []}

    def __init__(self, arms: int = 10, states: int = 2, pace: int | None = None):
        """
        Multi-armed bandit environment with k arms and n states
        :param arms: number of arms
        :param states: number of states
        :param pace: number of steps between state changes, None for every step
        :raises ValueError: if pace is given but is smaller than 1
        """
        # pace is used as a modulus in step(); reject 0 here instead of
        # failing later with a ZeroDivisionError on the first step.
        if pace is not None and pace < 1:
            raise ValueError(f"pace must be a positive integer or None, got {pace!r}")

        self.action_space = gym.spaces.Discrete(arms)
        # Observations carry the state index (0 .. states - 1) as a float32.
        self.observation_space = gym.spaces.Box(low=0, high=states, shape=(1,), dtype=np.float32)

        # One reward offset per (arm, state) pair, fixed for the env's lifetime.
        self.offsets = np.random.normal(0, arms, (arms, states))
        self.pace = pace
        self.states = states
        self.state = 0
        # Steps taken since the last reset; drives the pace check in step().
        self.ssr = 0

    def reset(self,
              *,
              seed: int | None = None,
              options: dict[str, Any] | None = None) -> tuple[ObsType, dict[str, Any]]:
        """
        Resets the environment
        :param seed: seed for the environment's random generator, defaults to None
        :param options: WARN unused, defaults to None
        :return: observation, info
        """
        # Seeds self.np_random so that rewards and state changes drawn in
        # step() are reproducible when a seed is supplied.
        super().reset(seed=seed)

        self.state = 0
        self.ssr = 0

        return np.zeros((1,), dtype=np.float32), {}

    def step(self, action: int) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        """
        Steps the environment
        :param action: arm to pull
        :return: observation, reward, terminated (always False),
                 truncated (always False), info
        """
        rng = self.np_random

        # Reward is unit-variance noise around the offset of the pulled arm
        # in the state the agent was in when it acted.
        reward = rng.normal(0, 1) + self.offsets[action, self.state]

        self.ssr += 1
        if self.pace is None or self.ssr % self.pace == 0:
            # Cast to a plain int so the observation below stays float32.
            self.state = int(rng.integers(self.states))

        return np.full((1,), self.state, dtype=np.float32), reward, False, False, {}
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
description="Buffalo Gym environment",
long_description=long_description,
long_description_content_type="text/markdown",
version="0.0.1",
version="0.0.2",
author="foreverska",
install_requires=["gymnasium>=0.26.0", "numpy"],
keywords="gymnasium, gym",
Expand Down
17 changes: 15 additions & 2 deletions tests/test_buffalo.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,23 @@

import numpy as np
import gymnasium as gym
import buffalo_gym.envs.buffalo_gym

def test_buffalo():
    """Smoke test for Buffalo-v0: reset and a single step return a float32 zero observation."""
    env = gym.make('Buffalo-v0')

    obs, info = env.reset()

    assert obs.shape == (1,)
    assert obs.dtype == np.float32
    assert obs[0] == 0

    obs, reward, done, term, info = env.step(env.action_space.sample())

    assert obs.shape == (1,)
    assert obs.dtype == np.float32
    assert obs[0] == 0
    # A single step must neither terminate nor truncate the episode.
    assert done is False
    assert term is False
48 changes: 48 additions & 0 deletions tests/test_multibuffalo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@

import numpy as np
import gymnasium as gym
import buffalo_gym

def test_multibuffalo():
    """Check observation shape, dtype and range, and that both default states are visited."""
    env = gym.make('MultiBuffalo-v0')

    obs, info = env.reset()

    assert obs.shape == (1,)
    assert obs.dtype == np.float32
    assert obs[0] in (0, 1)

    # The next state is drawn at random on every step, so take enough steps
    # that never seeing one of the states is practically impossible (with only
    # 10 steps this failed roughly 0.2% of the time). 100 stays well below the
    # registered max_episode_steps of 1000, so the episode is never truncated.
    states = []
    for _ in range(100):
        obs, reward, done, term, info = env.step(env.action_space.sample())

        assert obs.shape == (1,)
        assert obs.dtype == np.float32
        assert obs[0] in (0, 1)
        states.append(obs[0])
        assert done is False
        assert term is False

    assert set(states) == {0, 1}

def test_multibuffalo_threestates():
    """Check observation shape, dtype and range with three states, and that all are visited."""
    env = gym.make('MultiBuffalo-v0', states=3)

    obs, info = env.reset()

    assert obs.shape == (1,)
    assert obs.dtype == np.float32
    assert obs[0] in (0, 1, 2)

    # The next state is drawn at random on every step. With only 10 steps at
    # least one of the three states was missed about 5% of the time, making
    # the final assertion flaky. 100 steps makes a miss practically impossible
    # while staying below the registered max_episode_steps of 1000.
    states = []
    for _ in range(100):
        obs, reward, done, term, info = env.step(env.action_space.sample())

        assert obs.shape == (1,)
        assert obs.dtype == np.float32
        assert obs[0] in (0, 1, 2)
        states.append(obs[0])
        assert done is False
        assert term is False

    assert set(states) == {0, 1, 2}

0 comments on commit 5012dca

Please sign in to comment.