Revert 9_compare_baselines.ipynb to (almost) the version on master
jas-ho committed Aug 8, 2023
1 parent 50e5970 commit b9e07e9
Showing 1 changed file with 24 additions and 26 deletions.
docs/tutorials/9_compare_baselines.ipynb (24 additions, 26 deletions)
@@ -16,7 +16,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "As in the first tutorial, we will start by downloading an expert from the HuggingFace model hub."
+    "We will start by training a good (but not perfect) expert."
    ]
   },
   {
@@ -25,26 +25,22 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "import numpy as np\n",
-    "import seals  # noqa: F401  # needed to load \"seals/\" environments\n",
-    "from imitation.policies.serialize import load_policy\n",
-    "from imitation.util.util import make_vec_env\n",
-    "from imitation.data.wrappers import RolloutInfoWrapper\n",
+    "import gym\n",
+    "from stable_baselines3 import PPO\n",
+    "from stable_baselines3.ppo import MlpPolicy\n",
     "\n",
-    "env = make_vec_env(\n",
-    "    \"seals/CartPole-v0\",\n",
-    "    rng=np.random.default_rng(),\n",
-    "    n_envs=1,\n",
-    "    post_wrappers=[\n",
-    "        lambda env, _: RolloutInfoWrapper(env)\n",
-    "    ],  # needed for computing rollouts later\n",
+    "env = gym.make(\"CartPole-v1\")\n",
+    "expert = PPO(\n",
+    "    policy=MlpPolicy,\n",
+    "    env=env,\n",
+    "    seed=0,\n",
+    "    batch_size=64,\n",
+    "    ent_coef=0.0,\n",
+    "    learning_rate=0.0003,\n",
+    "    n_epochs=10,\n",
+    "    n_steps=64,\n",
     ")\n",
-    "expert = load_policy(\n",
-    "    \"ppo-huggingface\",\n",
-    "    organization=\"HumanCompatibleAI\",\n",
-    "    env_name=\"seals-CartPole-v0\",\n",
-    "    venv=env,\n",
-    ")"
+    "expert.learn(10_000)  # set to 100_000 for better performance"
     ]
   },
   {
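With the pretrained HuggingFace expert gone, the notebook now trains its own PPO expert from scratch. A quick sanity check of the result could follow this cell; the snippet below is a sketch using Stable Baselines3's evaluate_policy, assuming the expert and env names from the hunk above:

    from stable_baselines3.common.evaluation import evaluate_policy

    # CartPole-v1 episode returns are capped at 500, so a strong expert
    # averages close to that; 10k steps yields a decent but imperfect
    # policy, matching the "good (but not perfect)" framing above.
    mean_reward, std_reward = evaluate_policy(expert, env, n_eval_episodes=10)
    print(f"expert return: {mean_reward:.1f} +/- {std_reward:.1f}")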
@@ -60,9 +56,6 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from stable_baselines3 import PPO\n",
-    "from stable_baselines3.ppo import MlpPolicy\n",
-    "\n",
     "not_expert = PPO(\n",
     "    policy=MlpPolicy,\n",
     "    env=env,\n",
@@ -166,6 +159,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "from imitation.testing.reward_improvement import is_significant_reward_improvement\n",
+    "\n",
     "expert_rewards, _ = evaluate_policy(expert, env, 100, return_episode_rewards=True)\n",
     "not_expert_rewards, _ = evaluate_policy(\n",
     "    not_expert, env, 100, return_episode_rewards=True\n",
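This hunk pulls in is_significant_reward_improvement, which the notebook later applies to the two 100-episode reward samples collected here. A sketch of that call (argument order taken from the imitation tutorials; treat the exact signature as an assumption and check imitation.testing.reward_improvement):

    # A nonparametric test on per-episode returns: True if the second
    # sample is a statistically significant improvement over the first
    # at the given significance level.
    significant = is_significant_reward_improvement(
        not_expert_rewards, expert_rewards, 0.001
    )
    print("expert significantly better than not_expert:", significant)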
@@ -208,11 +203,14 @@
    "outputs": [],
    "source": [
     "from imitation.data import rollout\n",
+    "from imitation.data.wrappers import RolloutInfoWrapper\n",
+    "from stable_baselines3.common.vec_env import DummyVecEnv\n",
     "import numpy as np\n",
     "\n",
+    "rng = np.random.default_rng()\n",
     "expert_rollouts = rollout.rollout(\n",
     "    expert,\n",
-    "    env,\n",
+    "    DummyVecEnv([lambda: RolloutInfoWrapper(env)]),\n",
     "    rollout.make_sample_until(min_timesteps=None, min_episodes=50),\n",
     "    rng=rng,\n",
     ")\n",
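The substantive change here is that rollout.rollout no longer receives env directly: it expects a vectorized environment, and RolloutInfoWrapper records each complete episode trajectory in the step info dict so it can be recovered afterwards, hence the DummyVecEnv([lambda: RolloutInfoWrapper(env)]) wrapping. Inspecting the collected data could look like this (a sketch; the stats keys are assumptions based on the imitation docs for rollout.rollout_stats):

    # Summarize the 50 collected expert episodes.
    stats = rollout.rollout_stats(expert_rollouts)
    print(
        f"{len(expert_rollouts)} expert trajectories, "
        f"mean return {stats['return_mean']:.1f} +/- {stats['return_std']:.1f}"
    )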
@@ -221,7 +219,7 @@
     "\n",
     "not_expert_rollouts = rollout.rollout(\n",
     "    not_expert,\n",
-    "    env,\n",
+    "    DummyVecEnv([lambda: RolloutInfoWrapper(env)]),\n",
     "    rollout.make_sample_until(min_timesteps=None, min_episodes=50),\n",
     "    rng=rng,\n",
     ")\n",
@@ -348,7 +346,7 @@
    "source": [
     "rollouts = rollout.rollout(\n",
     "    expert,\n",
-    "    env,\n",
+    "    DummyVecEnv([lambda: RolloutInfoWrapper(env)]),\n",
     "    rollout.make_sample_until(min_timesteps=None, min_episodes=1),\n",
     "    rng=rng,\n",
     ")\n",
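Here only a single expert episode is collected (min_episodes=1). Downstream, the tutorials typically flatten such trajectories into individual transitions before handing them to a behavioral cloning trainer; a sketch, assuming rollout.flatten_trajectories as documented in imitation.data.rollout:

    # Convert whole trajectories into a flat dataset of (obs, act) pairs.
    transitions = rollout.flatten_trajectories(rollouts)
    print(f"{len(transitions)} transitions from {len(rollouts)} trajectory(ies)")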
@@ -394,7 +392,7 @@
     "        rng=np.random.default_rng(),\n",
     "    )\n",
     "    dagger_trainer = SimpleDAggerTrainer(\n",
-    "        venv=env,\n",
+    "        venv=DummyVecEnv([lambda: RolloutInfoWrapper(env)]),\n",
     "        scratch_dir=tmpdir,\n",
     "        expert_policy=expert,\n",
     "        bc_trainer=dagger_bc_trainer,\n",
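The same wrapped vectorized environment is now also handed to SimpleDAggerTrainer. Inside the same tmpdir block, training and evaluation would then proceed roughly as follows (a sketch: the timestep budget is arbitrary, and the train signature and policy attribute are assumptions based on the imitation DAgger docs):

    dagger_trainer.train(2_000)  # total environment timesteps for DAgger
    dagger_rewards, _ = evaluate_policy(
        dagger_trainer.policy, env, 100, return_episode_rewards=True
    )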