Consistently use VecEnv environments in docs/algorithms
jas-ho committed Aug 7, 2023
1 parent a4d2cbd commit 441d8df
Showing 4 changed files with 44 additions and 58 deletions.
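The recurring change in all four files is to stop constructing a bare gym.make(...) environment (plus, in some examples, a separate DummyVecEnv used only for rollout collection) and instead to build one vectorized environment up front and reuse it for expert rollouts, training, and evaluation. A minimal sketch of the resulting pattern, assembled from the added lines in the diffs below; module paths and keyword arguments are exactly those visible there, and the n_envs comment reflects the per-file values:

import numpy as np
import seals  # noqa: F401  # needed to load "seals/" environments

from imitation.data.wrappers import RolloutInfoWrapper
from imitation.util.util import make_vec_env

rng = np.random.default_rng(0)

# A single VecEnv, reused for expert rollouts, training, and evaluation; it replaces
# the old gym.make("seals/CartPole-v0") / DummyVecEnv pair from the previous docs.
env = make_vec_env(
    "seals/CartPole-v0",
    rng=rng,
    n_envs=8,  # the AIRL/GAIL examples use 8 parallel envs; BC and DAgger use 1
    post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],  # for computing rollouts
)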
40 changes: 14 additions & 26 deletions docs/algorithms/airl.rst
@@ -23,60 +23,48 @@ Detailed example notebook: :doc:`../tutorials/4_train_airl`
    :skipif: skip_doctests

    import numpy as np
-    import gym
-    import seals  # needed to load "seals/" environments
    from stable_baselines3 import PPO
    from stable_baselines3.common.evaluation import evaluate_policy
    from stable_baselines3.ppo import MlpPolicy

    from imitation.algorithms.adversarial.airl import AIRL
    from imitation.data import rollout
-    from imitation.data.wrappers import RolloutInfoWrapper
+    import seals  # noqa: F401  # needed to load "seals/" environments
    from imitation.policies.serialize import load_policy
    from imitation.rewards.reward_nets import BasicShapedRewardNet
    from imitation.util.networks import RunningNorm
    from imitation.util.util import make_vec_env
+    from imitation.data.wrappers import RolloutInfoWrapper

    rng = np.random.default_rng(0)

-    env = gym.make("seals/CartPole-v0")
+    env = make_vec_env(
+        "seals/CartPole-v0",
+        rng=rng,
+        n_envs=8,
+        post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],  # for computing rollouts
+    )
    expert = load_policy(
        "ppo-huggingface",
        organization="HumanCompatibleAI",
        env_name="seals-CartPole-v0",
        venv=env,
    )

    rollouts = rollout.rollout(
        expert,
-        make_vec_env(
-            "seals/CartPole-v0",
-            rng=rng,
-            n_envs=5,
-            post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],
-        ),
+        env,
        rollout.make_sample_until(min_timesteps=None, min_episodes=60),
        rng=rng,
    )

-    venv = make_vec_env("seals/CartPole-v0", rng=rng, n_envs=8)
-    learner = PPO(env=venv, policy=MlpPolicy)
+    learner = PPO(env=env, policy=MlpPolicy)
    reward_net = BasicShapedRewardNet(
-        venv.observation_space,
-        venv.action_space,
+        env.observation_space,
+        env.action_space,
        normalize_input_layer=RunningNorm,
    )
    airl_trainer = AIRL(
        demonstrations=rollouts,
        demo_batch_size=1024,
        gen_replay_buffer_capacity=2048,
        n_disc_updates_per_round=4,
-        venv=venv,
+        venv=env,
        gen_algo=learner,
        reward_net=reward_net,
    )
    airl_trainer.train(20000)
-    rewards, _ = evaluate_policy(learner, venv, 100, return_episode_rewards=True)
+    rewards, _ = evaluate_policy(learner, env, 100, return_episode_rewards=True)
    print("Rewards:", rewards)

.. testoutput::
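A point worth noting about the AIRL example above: rollout.rollout is now given the same wrapped VecEnv that the generator trains in, which is why RolloutInfoWrapper must be attached through post_wrappers at construction time. A short sketch with the lambda written out as a named function, assuming (as imitation's make_vec_env suggests) that each post-wrapper is called with a sub-environment and its index:

import numpy as np

from imitation.data.wrappers import RolloutInfoWrapper
from imitation.util.util import make_vec_env

rng = np.random.default_rng(0)

def add_rollout_info(sub_env, sub_env_index):
    # Record full episode trajectories in the step info dicts so that
    # rollout.rollout() can reassemble complete demonstrations later.
    return RolloutInfoWrapper(sub_env)

env = make_vec_env(
    "seals/CartPole-v0",
    rng=rng,
    n_envs=8,
    post_wrappers=[add_rollout_info],
)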
20 changes: 10 additions & 10 deletions docs/algorithms/bc.rst
@@ -20,19 +20,19 @@ Detailed example notebook: :doc:`../tutorials/1_train_bc`
.. testcode::
    :skipif: skip_doctests

-    import gym
-    import seals  # needed to load "seals/" environments
    import numpy as np
    from stable_baselines3.common.evaluation import evaluate_policy
-    from stable_baselines3.common.vec_env import DummyVecEnv

    from imitation.algorithms import bc
    from imitation.data import rollout
-    from imitation.data.wrappers import RolloutInfoWrapper
+    import seals  # noqa: F401  # needed to load "seals/" environments
    from imitation.policies.serialize import load_policy
+    from imitation.util.util import make_vec_env
+    from imitation.data.wrappers import RolloutInfoWrapper

    rng = np.random.default_rng(0)
-    env = gym.make("seals/CartPole-v0")
+    env = make_vec_env(
+        "seals/CartPole-v0",
+        rng=rng,
+        n_envs=1,
+        post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],  # for computing rollouts
+    )
    expert = load_policy(
        "ppo-huggingface",
        organization="HumanCompatibleAI",
@@ -41,7 +41,7 @@ Detailed example notebook: :doc:`../tutorials/1_train_bc`
    )
    rollouts = rollout.rollout(
        expert,
-        DummyVecEnv([lambda: RolloutInfoWrapper(env)]),
+        env,
        rollout.make_sample_until(min_timesteps=None, min_episodes=50),
        rng=rng,
    )
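The visible part of the bc.rst diff ends right after the rollouts are collected; the rest of the example (not shown here) presumably passes them to the BC trainer. A hedged sketch of that step, reusing env, rollouts, and rng from the snippet above; bc.BC and rollout.flatten_trajectories are existing imitation APIs, but the exact arguments used in the doc may differ:

from imitation.algorithms import bc
from imitation.data import rollout

# Behavior cloning consumes individual transitions, so flatten the trajectories first.
transitions = rollout.flatten_trajectories(rollouts)

bc_trainer = bc.BC(
    observation_space=env.observation_space,
    action_space=env.action_space,
    demonstrations=transitions,
    rng=rng,
)
bc_trainer.train(n_epochs=1)  # epoch count is illustrative, not taken from the doc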
13 changes: 7 additions & 6 deletions docs/algorithms/dagger.rst
@@ -25,25 +25,26 @@ Detailed example notebook: :doc:`../tutorials/2_train_dagger`

    import tempfile
-    import numpy as np
-    import gym
-    import seals  # needed to load "seals/" environments
+    import seals  # noqa: F401  # needed to load "seals/" environments
+    import numpy as np
    from stable_baselines3.common.evaluation import evaluate_policy
-    from stable_baselines3.common.vec_env import DummyVecEnv

    from imitation.algorithms import bc
    from imitation.algorithms.dagger import SimpleDAggerTrainer
    from imitation.policies.serialize import load_policy

    rng = np.random.default_rng(0)
-    env = gym.make("seals/CartPole-v0")
+    env = make_vec_env(
+        "seals/CartPole-v0",
+        rng=rng,
+        n_envs=1,
+    )
    expert = load_policy(
        "ppo-huggingface",
        organization="HumanCompatibleAI",
        env_name="seals-CartPole-v0",
        venv=env,
    )
-    venv = DummyVecEnv([lambda: gym.make("seals/CartPole-v0")])

    bc_trainer = bc.BC(
        observation_space=env.observation_space,
@@ -53,7 +54,7 @@ Detailed example notebook: :doc:`../tutorials/2_train_dagger`
    with tempfile.TemporaryDirectory(prefix="dagger_example_") as tmpdir:
        print(tmpdir)
        dagger_trainer = SimpleDAggerTrainer(
-            venv=venv,
+            venv=env,
            scratch_dir=tmpdir,
            expert_policy=expert,
            bc_trainer=bc_trainer,
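The dagger.rst hunk ends inside the SimpleDAggerTrainer constructor, and the training call itself sits below the fold; note also that make_vec_env is called at the top of the example but its import is not visible in this hunk, so it is presumably imported in a part of the file that is not shown. A hedged sketch of how the trainer is typically completed and run, continuing from the objects defined above (env, expert, bc_trainer, rng, tempfile); the rng argument and the timestep budget are assumptions, not lines from the doc:

with tempfile.TemporaryDirectory(prefix="dagger_example_") as tmpdir:
    dagger_trainer = SimpleDAggerTrainer(
        venv=env,
        scratch_dir=tmpdir,     # stores demonstrations gathered during the DAgger loop
        expert_policy=expert,   # queried for corrective actions on learner-visited states
        bc_trainer=bc_trainer,  # fits the policy to the aggregated dataset
        rng=rng,
    )
    dagger_trainer.train(2000)  # total environment timesteps (illustrative value)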
29 changes: 13 additions & 16 deletions docs/algorithms/gail.rst
@@ -19,9 +19,8 @@ Detailed example notebook: :doc:`../tutorials/3_train_gail`
.. testcode::
    :skipif: skip_doctests

-    import gym
-    import seals  # needed to load "seals/" environments
    import numpy as np
+    import seals  # noqa: F401  # needed to load "seals/" environments
    from stable_baselines3 import PPO
    from stable_baselines3.common.evaluation import evaluate_policy
    from stable_baselines3.ppo import MlpPolicy
@@ -36,45 +35,43 @@ Detailed example notebook: :doc:`../tutorials/3_train_gail`

    rng = np.random.default_rng(0)

-    env = gym.make("seals/CartPole-v0")
+    env = make_vec_env(
+        "seals/CartPole-v0",
+        rng=rng,
+        n_envs=8,
+        post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],  # for computing rollouts
+    )
    expert = load_policy(
        "ppo-huggingface",
        organization="HumanCompatibleAI",
        env_name="seals-CartPole-v0",
        venv=env,
    )

    rollouts = rollout.rollout(
        expert,
-        make_vec_env(
-            "seals/CartPole-v0",
-            n_envs=5,
-            post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],
-            rng=rng,
-        ),
+        env,
        rollout.make_sample_until(min_timesteps=None, min_episodes=60),
        rng=rng,
    )

-    venv = make_vec_env("seals/CartPole-v0", n_envs=8, rng=rng)
-    learner = PPO(env=venv, policy=MlpPolicy)
+    learner = PPO(env=env, policy=MlpPolicy)
    reward_net = BasicRewardNet(
-        venv.observation_space,
-        venv.action_space,
+        env.observation_space,
+        env.action_space,
        normalize_input_layer=RunningNorm,
    )
    gail_trainer = GAIL(
        demonstrations=rollouts,
        demo_batch_size=1024,
        gen_replay_buffer_capacity=2048,
        n_disc_updates_per_round=4,
-        venv=venv,
+        venv=env,
        gen_algo=learner,
        reward_net=reward_net,
    )

    gail_trainer.train(20000)
-    rewards, _ = evaluate_policy(learner, venv, 100, return_episode_rewards=True)
+    rewards, _ = evaluate_policy(learner, env, 100, return_episode_rewards=True)
    print("Rewards:", rewards)

.. testoutput::
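After this commit the GAIL and AIRL examples are nearly identical; they differ only in the trainer class and the reward network (BasicRewardNet for GAIL, BasicShapedRewardNet for AIRL, the latter adding a learned potential-shaping term on top of the same base network, as its name suggests). A minimal sketch of the two reward networks built against the same VecEnv, with constructor arguments exactly as in the diffs above; env is the make_vec_env result from those examples:

from imitation.rewards.reward_nets import BasicRewardNet, BasicShapedRewardNet
from imitation.util.networks import RunningNorm

# GAIL: plain state-action discriminator/reward network.
gail_reward_net = BasicRewardNet(
    env.observation_space,
    env.action_space,
    normalize_input_layer=RunningNorm,
)

# AIRL: same inputs, plus a shaping potential learned alongside the base reward.
airl_reward_net = BasicShapedRewardNet(
    env.observation_space,
    env.action_space,
    normalize_input_layer=RunningNorm,
)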
