Skip to content

Commit

Permalink
Fix double reset bug (hill-a#968)
Browse files Browse the repository at this point in the history
  • Loading branch information
araffin authored Aug 4, 2020
1 parent 6fbc9a9 commit b21e4dc
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 14 deletions.
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ Bug Fixes:
- Fixed a bug in the ``close()`` method of ``SubprocVecEnv``, causing wrappers further down in the wrapper stack to not be closed. (@NeoExtended)
- Fixed a bug in the ``generate_expert_traj()`` method in ``record_expert.py`` when using a non-image vectorized environment (@jbarsce)
- Fixed a bug in CloudPickleWrapper's (used by VecEnvs) ``__setstate___`` where loading was incorrectly using ``pickle.loads`` (@shwang).
- Fixed a bug in ``SAC`` and ``TD3`` where the log timesteps was not correct(@YangRui2015)
- Fixed a bug in ``SAC`` and ``TD3`` where the log timesteps was not correct(@YangRui2015)
- Fixed a bug where the environment was reset twice when using ``evaluate_policy``


Deprecations:
Expand Down
39 changes: 26 additions & 13 deletions stable_baselines/common/evaluation.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,31 @@
import typing
from typing import Callable, List, Optional, Tuple, Union

import gym
import numpy as np

from stable_baselines.common.vec_env import VecEnv

if typing.TYPE_CHECKING:
from stable_baselines.common.base_class import BaseRLModel


def evaluate_policy(model, env, n_eval_episodes=10, deterministic=True,
render=False, callback=None, reward_threshold=None,
return_episode_rewards=False):
def evaluate_policy(
model: "BaseRLModel",
env: Union[gym.Env, VecEnv],
n_eval_episodes: int = 10,
deterministic: bool = True,
render: bool = False,
callback: Optional[Callable] = None,
reward_threshold: Optional[float] = None,
return_episode_rewards: bool = False,
) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]:
"""
Runs policy for `n_eval_episodes` episodes and returns average reward.
Runs policy for ``n_eval_episodes`` episodes and returns average reward.
This is made to work only with one env.
:param model: (BaseRLModel) The RL agent you want to evaluate.
:param env: (gym.Env or VecEnv) The gym environment. In the case of a `VecEnv`
:param env: (gym.Env or VecEnv) The gym environment. In the case of a ``VecEnv``
this must contain only one environment.
:param n_eval_episodes: (int) Number of episode to evaluate the agent
:param deterministic: (bool) Whether to use deterministic or stochastic actions
Expand All @@ -20,17 +34,19 @@ def evaluate_policy(model, env, n_eval_episodes=10, deterministic=True,
called after each step.
:param reward_threshold: (float) Minimum expected reward per episode,
this will raise an error if the performance is not met
:param return_episode_rewards: (bool) If True, a list of reward per episode
:param return_episode_rewards: (Optional[float]) If True, a list of reward per episode
will be returned instead of the mean.
:return: (float, float) Mean reward per episode, std of reward per episode
returns ([float], [int]) when `return_episode_rewards` is True
returns ([float], [int]) when ``return_episode_rewards`` is True
"""
if isinstance(env, VecEnv):
assert env.num_envs == 1, "You must pass only one environment when using this function"

episode_rewards, episode_lengths = [], []
for _ in range(n_eval_episodes):
obs = env.reset()
for i in range(n_eval_episodes):
# Avoid double reset, as VecEnv are reset automatically
if not isinstance(env, VecEnv) or i == 0:
obs = env.reset()
done, state = False, None
episode_reward = 0.0
episode_length = 0
Expand All @@ -45,13 +61,10 @@ def evaluate_policy(model, env, n_eval_episodes=10, deterministic=True,
env.render()
episode_rewards.append(episode_reward)
episode_lengths.append(episode_length)

mean_reward = np.mean(episode_rewards)
std_reward = np.std(episode_rewards)

if reward_threshold is not None:
assert mean_reward > reward_threshold, 'Mean reward below threshold: '\
'{:.2f} < {:.2f}'.format(mean_reward, reward_threshold)
assert mean_reward > reward_threshold, "Mean reward below threshold: {:.2f} < {:.2f}".format(mean_reward, reward_threshold)
if return_episode_rewards:
return episode_rewards, episode_lengths
return mean_reward, std_reward

0 comments on commit b21e4dc

Please sign in to comment.