Skip to content

Commit

Permalink
Simplified agents
Browse files Browse the repository at this point in the history
  • Loading branch information
vadim0x60 committed Feb 14, 2024
1 parent af2adc7 commit 055724a
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 42 deletions.
2 changes: 1 addition & 1 deletion examples/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

env = gym.make('MountainCarContinuous-v0', max_episode_steps=500, render_mode='human')
program = Program(source=mountain_car_solver, language='Python')
agent = program.spawn().rl(env.action_space, env.observation_space)
agent = program.spawn()

obs, info = env.reset()
print(obs, info)
Expand Down
2 changes: 1 addition & 1 deletion examples/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
gym[classic_control]~=0.26
gymnasium[classic_control]
GitPython~=3.1
programlib
61 changes: 23 additions & 38 deletions programlib/agent.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
import gymnasium as gym
import numpy as np

def decode_action(action_space, action):
def decode_action(action):
try:
a = eval(action)

if not isinstance(action_space, gym.spaces.Discrete):
return [a]
x = np.array(eval(action))
except SyntaxError:
return list(map(eval, action.split(r'[\p\s]+')))
x = np.array(map(eval, action.split(r'[\p\s]+')))

return x.reshape(-1)

def encode_obs(obs_space, obs):
def encode_obs(obs):
try:
obs = obs.tolist()
except AttributeError:
Expand Down Expand Up @@ -39,53 +38,39 @@ def act(self, input_lines):

self.process.expect(self.delimiter)
return self.process.before.decode()

def rl(self, action_space, obs_space):
return RLAgent(self, action_space, obs_space)

def close(self):
self.process.close()
self.program.exitstatus = self.process.exitstatus

def __del__(self):
self.close()

class RLAgent():
"""
Reinforcement Learning Agent: represents a running program for control in
an OpenAI gym environment. Mimics the interface of a stable-baselines model.
"""

def __init__(self, agent, action_space, obs_space) -> None:
self.agent = agent
self.action_space = action_space
self.obs_space = obs_space

def predict(self, obs, deterministic=True):
"""
Predict what the next action should be given the current observation
Predict what the next action should be given the current observation.
Same as act(), but designed to work with reinforcement learning envs.
Mimics the interface of a stable-baselines model.
The observations will be passed to stdin of the program, and the action
will be read from stdout.
Parameters
----------
obs - the current observation
deterministic - whether to return the action or a pseudo-stochastic
vector of action probabilities (one-hot)
deterministic - should always be set to True,
for compatibility with stable-baselines
Returns (action, state) tuple
-------
action - the action to take
state - a reference to the process to examine the execution state
"""

obs_str = encode_obs(self.obs_space, obs)
action_str = self.agent.act(obs_str)
action = decode_action(self.action_space, action_str)
assert deterministic, "Pseudo-stochastic actions not supported"

if not deterministic:
actions_probs = np.zeros(self.action_space.n)
actions_probs[action] = 1.0
obs_str = encode_obs(obs)
action_str = self.act(obs_str)
action = decode_action(action_str)

return action, self.agent.process
return action, self.process

def close(self):
self.process.close()
self.program.exitstatus = self.process.exitstatus

def __del__(self):
self.close()
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "programlib"
version = "9.0.3"
version = "10.0.0"
description = "Programs as Objects"
authors = ["Vadim Liventsev <[email protected]>"]
license = "MIT"
Expand All @@ -14,7 +14,6 @@ classifiers = [
[tool.poetry.dependencies]
python = "^3.8"
pexpect = "^4.8.0"
gymnasium = " >=0.0.0"
numpy = "^1.24.2"
pyte = "^0.8.0"
contextlib-chdir = {version = "^1.0.2", python = "<3.11"}
Expand Down

0 comments on commit 055724a

Please sign in to comment.