Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix env checker bounds, expose all invalid indices at once #1638

Merged
merged 5 commits into from
Aug 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Changelog
==========


Release 2.1.0a3 (WIP)
Release 2.1.0a4 (WIP)
--------------------------

Breaking Changes:
Expand All @@ -27,7 +27,8 @@ Bug Fixes:
^^^^^^^^^^
- Relaxed check in logger, that was causing issue on Windows with colorama
- Fixed off-policy algorithms with continuous float64 actions (see #1145) (@tobirohrer)

- Fixed env_checker.py warning messages for out of bounds in complex observation spaces (@Gabo-Tor)

Deprecations:
^^^^^^^^^^^^^

Expand Down Expand Up @@ -1398,7 +1399,7 @@ And all the contributors:
@eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP
@simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485
@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede
@carlosluis @arjun-kg @tlpss @JonathanKuelz
@carlosluis @arjun-kg @tlpss @JonathanKuelz @Gabo-Tor
@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto
Expand Down
30 changes: 18 additions & 12 deletions stable_baselines3/common/env_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,18 +203,24 @@ def _check_obs(obs: Union[tuple, dict, np.ndarray, int], observation_space: spac
f"Expected: {observation_space.dtype}, actual dtype: {obs.dtype}"
)
if isinstance(observation_space, spaces.Box):
assert np.all(obs >= observation_space.low), (
f"The observation returned by the `{method_name}()` method does not match the lower bound "
f"of the given observation space {observation_space}."
f"Expected: obs >= {np.min(observation_space.low)}, "
f"actual min value: {np.min(obs)} at index {np.argmin(obs)}"
)
assert np.all(obs <= observation_space.high), (
f"The observation returned by the `{method_name}()` method does not match the upper bound "
f"of the given observation space {observation_space}. "
f"Expected: obs <= {np.max(observation_space.high)}, "
f"actual max value: {np.max(obs)} at index {np.argmax(obs)}"
)
lower_bounds, upper_bounds = observation_space.low, observation_space.high
# Expose all invalid indices at once
invalid_indices = np.where(np.logical_or(obs < lower_bounds, obs > upper_bounds))
if (obs > upper_bounds).any() or (obs < lower_bounds).any():
message = (
f"The observation returned by the `{method_name}()` method does not match the bounds "
f"of the given observation space {observation_space}. \n"
)
message += f"{len(invalid_indices[0])} invalid indices: \n"

for index in zip(*invalid_indices):
index_str = ",".join(map(str, index))
message += (
f"Expected: {lower_bounds[index]} <= obs[{index_str}] <= {upper_bounds[index]}, "
f"actual value: {obs[index]} \n"
)

raise AssertionError(message)

assert observation_space.contains(obs), (
f"The observation returned by the `{method_name}()` method "
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.1.0a3
2.1.0a4
23 changes: 18 additions & 5 deletions tests/test_env_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,28 @@ def test_check_env_dict_action():
[
# Above upper bound
(
spaces.Box(low=0.0, high=1.0, shape=(3,), dtype=np.float32),
spaces.Box(low=np.array([0.0, 0.0, 0.0]), high=np.array([2.0, 1.0, 1.0]), shape=(3,), dtype=np.float32),
np.array([1.0, 1.5, 0.5], dtype=np.float32),
r"Expected: obs <= 1\.0, actual max value: 1\.5 at index 1",
r"Expected: 0\.0 <= obs\[1] <= 1\.0, actual value: 1\.5",
),
# Above upper bound (multi-dim)
(
spaces.Box(low=-1.0, high=2.0, shape=(2, 3, 3, 1), dtype=np.float32),
3.0 * np.ones((2, 3, 3, 1), dtype=np.float32),
# Note: this is one of the 18 invalid indices
r"Expected: -1\.0 <= obs\[1,2,1,0\] <= 2\.0, actual value: 3\.0",
),
# Below lower bound
(
spaces.Box(low=0.0, high=2.0, shape=(3,), dtype=np.float32),
spaces.Box(low=np.array([0.0, -10.0, 0.0]), high=np.array([2.0, 1.0, 1.0]), shape=(3,), dtype=np.float32),
np.array([-1.0, 1.5, 0.5], dtype=np.float32),
r"Expected: obs >= 0\.0, actual min value: -1\.0 at index 0",
r"Expected: 0\.0 <= obs\[0] <= 2\.0, actual value: -1\.0",
),
# Below lower bound (multi-dim)
(
spaces.Box(low=-1.0, high=2.0, shape=(2, 3, 3, 1), dtype=np.float32),
-2 * np.ones((2, 3, 3, 1), dtype=np.float32),
r"18 invalid indices:",
),
# Wrong dtype
(
Expand Down Expand Up @@ -111,7 +124,7 @@ def step(self, action):

test_env = TestEnv()
with pytest.raises(AssertionError, match=error_message):
check_env(env=test_env)
check_env(env=test_env, warn=False)


class LimitedStepsTestEnv(gym.Env):
Expand Down