From fa7a3168f31d9331a72b099468763ed4abec34b6 Mon Sep 17 00:00:00 2001 From: BertrandDecoster <70576987+BertrandDecoster@users.noreply.github.com> Date: Tue, 18 Jul 2023 13:02:47 +0200 Subject: [PATCH 1/7] Update the Callbacks: Evaluate Agent Performance section of the Examples (#1604) * Update examples.rst section "Callbacks: Evaluate Agent Performance" Two typos fixed * Update changelog --------- Co-authored-by: Antonin Raffin --- docs/guide/examples.rst | 4 ++-- docs/misc/changelog.rst | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index e5b9db0fc..7a0586c33 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -319,7 +319,7 @@ You can control the evaluation frequency with ``eval_freq`` to monitor your agen from stable_baselines3 import SAC from stable_baselines3.common.callbacks import EvalCallback - from stable-baselines3.common.env_util import make_vec_env + from stable_baselines3.common.env_util import make_vec_env env_id = "Pendulum-v1" n_training_envs = 1 @@ -330,7 +330,7 @@ You can control the evaluation frequency with ``eval_freq`` to monitor your agen os.makedirs(eval_log_dir, exist_ok=True) # Initialize a vectorized training environment with default parameters - train_env = make_vec_env(env_id, n_env=n_training_envs, seed=0) + train_env = make_vec_env(env_id, n_envs=n_training_envs, seed=0) # Separate evaluation env, with different parameters passed via env_kwargs # Eval environments can be vectorized to speed up evaluation. diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 7bb1f16a1..38255d0d1 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -35,6 +35,7 @@ Others: Documentation: ^^^^^^^^^^^^^^ +- Fixed callback example (@BertrandDecoster) Release 2.0.0 (2023-06-22) @@ -1395,4 +1396,4 @@ And all the contributors: @Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto -@lutogniew @lbergmann1 @lukashass +@lutogniew @lbergmann1 @lukashass @BertrandDecoster From 61e106052555f8f049bb5a0c981b3dc5899244e1 Mon Sep 17 00:00:00 2001 From: Mark Towers Date: Tue, 18 Jul 2023 13:22:22 +0100 Subject: [PATCH 2/7] Update Gymnasium to v0.29.0 (#1610) * Update setup.py to v0.29.0 * Remove invalid test * Loosen version and update changelog --------- Co-authored-by: Antonin Raffin --- docs/misc/changelog.rst | 5 +++-- setup.py | 2 +- stable_baselines3/version.txt | 2 +- tests/test_envs.py | 2 -- 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 38255d0d1..eaa761da4 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Release 2.1.0a0 (WIP) +Release 2.1.0a1 (WIP) -------------------------- Breaking Changes: @@ -15,6 +15,7 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ - Added Python 3.11 support +- Added Gymnasium 0.29 support (@pseudo-rnd-thoughts) `SB3-Contrib`_ ^^^^^^^^^^^^^^ @@ -1396,4 +1397,4 @@ And all the contributors: @Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto -@lutogniew @lbergmann1 @lukashass @BertrandDecoster +@lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts diff --git a/setup.py b/setup.py index 6ddb9b65d..deb9f5498 100644 --- a/setup.py +++ b/setup.py @@ -100,7 +100,7 @@ packages=[package for package in find_packages() if package.startswith("stable_baselines3")], package_data={"stable_baselines3": ["py.typed", "version.txt"]}, install_requires=[ - "gymnasium==0.28.1", + "gymnasium>=0.28.1,<0.30", "numpy>=1.20", "torch>=1.13", # For saving models diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index ecaf4eea7..0e4065ce8 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.1.0a0 +2.1.0a1 diff --git a/tests/test_envs.py b/tests/test_envs.py index e6c973852..e82ef5768 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -156,8 +156,6 @@ def patched_step(_action): spaces.Box(low=-1000, high=1000, shape=(3,), dtype=np.float32), # Too small range spaces.Box(low=-0.1, high=0.1, shape=(2,), dtype=np.float32), - # Inverted boundaries - spaces.Box(low=1, high=-1, shape=(2,), dtype=np.float32), # Same boundaries spaces.Box(low=1, high=1, shape=(2,), dtype=np.float32), # Unbounded action space From a730b9b66a666d7263c4dd43706eb566851e5726 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 21 Jul 2023 07:02:38 +0200 Subject: [PATCH 3/7] Relax logger check for Windows (#1615) * Relax logger check for Windows * Update tests --- docs/misc/changelog.rst | 3 ++- stable_baselines3/common/logger.py | 6 ++++-- stable_baselines3/version.txt | 2 +- tests/test_logger.py | 5 +++-- 4 files changed, 10 insertions(+), 6 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index eaa761da4..4ecf61b73 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Release 2.1.0a1 (WIP) +Release 2.1.0a2 (WIP) -------------------------- Breaking Changes: @@ -25,6 +25,7 @@ New Features: Bug Fixes: ^^^^^^^^^^ +- Relaxed check in logger, that was causing issue on Windows with colorama Deprecations: ^^^^^^^^^^^^^ diff --git a/stable_baselines3/common/logger.py b/stable_baselines3/common/logger.py index 4d0d3461e..3955131c5 100644 --- a/stable_baselines3/common/logger.py +++ b/stable_baselines3/common/logger.py @@ -164,8 +164,10 @@ def __init__(self, filename_or_file: Union[str, TextIO], max_length: int = 36): if isinstance(filename_or_file, str): self.file = open(filename_or_file, "w") self.own_file = True - elif isinstance(filename_or_file, TextIOBase): - self.file = filename_or_file + elif isinstance(filename_or_file, TextIOBase) or hasattr(filename_or_file, "write"): + # Note: in theory `TextIOBase` check should be sufficient, + # in practice, libraries don't always inherit from it, see GH#1598 + self.file = filename_or_file # type: ignore[assignment] self.own_file = False else: raise ValueError(f"Expected file or str, got {filename_or_file}") diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 0e4065ce8..55c98c93c 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.1.0a1 +2.1.0a2 diff --git a/tests/test_logger.py b/tests/test_logger.py index 9d275a2ec..05bf196a3 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -437,8 +437,9 @@ def test_ep_buffers_stats_window_size(algo, stats_window_size): assert model.ep_success_buffer.maxlen == stats_window_size -def test_human_output_format_custom_test_io(): - class DummyTextIO(TextIOBase): +@pytest.mark.parametrize("base_class", [object, TextIOBase]) +def test_human_output_format_custom_test_io(base_class): + class DummyTextIO(base_class): def __init__(self) -> None: super().__init__() self.lines = [[]] From 5abd50a853e0f667f48ae10769f478c4972eda35 Mon Sep 17 00:00:00 2001 From: Stefan Schneider <28340802+stefanbschneider@users.noreply.github.com> Date: Fri, 21 Jul 2023 16:33:01 +0200 Subject: [PATCH 4/7] Docs: Add mobile-env to community projects (#1617) * Docs: Add mobile-env to community projects * Update docs Readme with correct install command Without the quotes, I get `no matches found: .[docs]` * Add changelog entry for adding mobile-env * Fix format in projects.rst Co-authored-by: Antonin RAFFIN --------- Co-authored-by: Antonin RAFFIN --- docs/README.md | 2 +- docs/misc/changelog.rst | 3 ++- docs/misc/projects.rst | 15 +++++++++++++++ 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/docs/README.md b/docs/README.md index 169a5e3db..1fc4d762e 100644 --- a/docs/README.md +++ b/docs/README.md @@ -8,7 +8,7 @@ This folder contains documentation for the RL baselines. #### Install Sphinx and Theme Execute this command in the project root: ``` -pip install -e .[docs] +pip install -e ".[docs]" ``` #### Building the Docs diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 4ecf61b73..e16e4b279 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -38,6 +38,7 @@ Others: Documentation: ^^^^^^^^^^^^^^ - Fixed callback example (@BertrandDecoster) +- Added mobile-env as new community project (@stefanbschneider) Release 2.0.0 (2023-06-22) @@ -1398,4 +1399,4 @@ And all the contributors: @Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto -@lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts +@lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider diff --git a/docs/misc/projects.rst b/docs/misc/projects.rst index 9d05d54c1..53ad8f749 100644 --- a/docs/misc/projects.rst +++ b/docs/misc/projects.rst @@ -197,3 +197,18 @@ A simple library for pink noise exploration with deterministic (DDPG / TD3) and | Authors: Onno Eberhard, Jakob Hollenstein, Cristina Pinneri, Georg Martius | Github: https://github.com/martius-lab/pink-noise-rl | Paper: https://openreview.net/forum?id=hQ9V5QN27eS (Oral at ICLR 2023) + + +mobile-env +---------- + +An open, minimalist Gymnasium environment for autonomous coordination in wireless mobile networks. +It allows simulating various scenarios with moving users in a cellular network with multiple base stations. + +- Written in pure Python, easy to modify and extend, and can be installed directly via PyPI. +- Implements the standard Gymnasium interface such that it can be used with all common frameworks for reinforcement learning. +- There are examples for both single-agent and multi-agent RL using either `stable-baselines3` or Ray RLlib. + +| Authors: Stefan Schneider, Stefan Werner +| Github: https://github.com/stefanbschneider/mobile-env +| Paper: https://ris.uni-paderborn.de/download/30236/30237 (2022 IEEE/IFIP Network Operations and Management Symposium (NOMS)) From 72c124d90727170dca2be9ca9b82ab154d9bfe4c Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 24 Jul 2023 14:38:22 +0200 Subject: [PATCH 5/7] Ignore pytype error (#1623) --- stable_baselines3/common/monitor.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/stable_baselines3/common/monitor.py b/stable_baselines3/common/monitor.py index 5253954e8..f421a4df2 100644 --- a/stable_baselines3/common/monitor.py +++ b/stable_baselines3/common/monitor.py @@ -193,7 +193,9 @@ def __init__( mode = "w" if override_existing else "a" # Prevent newline issue on Windows, see GH issue #692 self.file_handler = open(filename, f"{mode}t", newline="\n") - self.logger = csv.DictWriter(self.file_handler, fieldnames=("r", "l", "t", *extra_keys)) + self.logger = csv.DictWriter( + self.file_handler, fieldnames=("r", "l", "t", *extra_keys) + ) # pytype: disable=wrong-arg-types if override_existing: self.file_handler.write(f"#{json.dumps(header)}\n") self.logger.writeheader() From ba77dd7c6180c0ec9a47dfa98291c2103e6750df Mon Sep 17 00:00:00 2001 From: Tobias Rohrer Date: Mon, 24 Jul 2023 16:38:03 +0200 Subject: [PATCH 6/7] Fix to use float64 actions for off policy algorithms (#1572) * Added test cases where off policy algorithms fail with float64 actionspace * casting observations and actions to `np.float32` to unify behaviour between `ReplayBuffer` and `RolloutBuffer`. Fixing issue #1145 * reformatted using black * making test more restrictive by checking models action is float64 * added changelog entry * undo cast of observations as `preprocessing.preprocess_obs()` casts them to float32 anyways. * - Casting to float32 only, if action.dtype is float64 - Added cast to `DictReplayBuffer` as well * Added tests for multiple variations of continuous action types and observation spaces * applied reformatting by `make commit-checks` * Added typing and comment referring to description in merge request * Apply linter for single element slice * Rename helper and refactor tests * Update changelog and docstring --------- Co-authored-by: Antonin RAFFIN --- docs/misc/changelog.rst | 4 +- stable_baselines3/common/buffers.py | 23 ++++++- stable_baselines3/dqn/dqn.py | 2 +- stable_baselines3/version.txt | 2 +- tests/test_spaces.py | 101 +++++++++++++++++++--------- 5 files changed, 97 insertions(+), 35 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index e16e4b279..039a094ab 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Release 2.1.0a2 (WIP) +Release 2.1.0a3 (WIP) -------------------------- Breaking Changes: @@ -26,6 +26,7 @@ New Features: Bug Fixes: ^^^^^^^^^^ - Relaxed check in logger, that was causing issue on Windows with colorama +- Fixed off-policy algorithms with continuous float64 actions (see #1145) (@tobirohrer) Deprecations: ^^^^^^^^^^^^^ @@ -34,6 +35,7 @@ Others: ^^^^^^^ - Updated GitHub issue templates - Fix typo in gym patch error message (@lukashass) +- Refactor ``test_spaces.py`` tests Documentation: ^^^^^^^^^^^^^^ diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index fe633e1af..576e10a8b 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -207,7 +207,9 @@ def __init__( else: self.next_observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=observation_space.dtype) - self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=action_space.dtype) + self.actions = np.zeros( + (self.buffer_size, self.n_envs, self.action_dim), dtype=self._maybe_cast_dtype(action_space.dtype) + ) self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) @@ -311,6 +313,21 @@ def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = Non ) return ReplayBufferSamples(*tuple(map(self.to_torch, data))) + @staticmethod + def _maybe_cast_dtype(dtype: np.typing.DTypeLike) -> np.typing.DTypeLike: + """ + Cast `np.float64` action datatype to `np.float32`, + keep the others dtype unchanged. + See GH#1572 for more information. + + :param dtype: The original action space dtype + :return: ``np.float32`` if the dtype was float64, + the original dtype otherwise. + """ + if dtype == np.float64: + return np.float32 + return dtype + class RolloutBuffer(BaseBuffer): """ @@ -543,7 +560,9 @@ def __init__( for key, _obs_shape in self.obs_shape.items() } - self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=action_space.dtype) + self.actions = np.zeros( + (self.buffer_size, self.n_envs, self.action_dim), dtype=self._maybe_cast_dtype(action_space.dtype) + ) self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index 6b44254cc..42e3d0df0 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -245,7 +245,7 @@ def predict( if not deterministic and np.random.rand() < self.exploration_rate: if self.policy.is_vectorized_observation(observation): if isinstance(observation, dict): - n_batch = observation[list(observation.keys())[0]].shape[0] + n_batch = observation[next(iter(observation.keys()))].shape[0] else: n_batch = observation.shape[0] action = np.array([self.action_space.sample() for _ in range(n_batch)]) diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 55c98c93c..a4a6a877e 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.1.0a2 +2.1.0a3 diff --git a/tests/test_spaces.py b/tests/test_spaces.py index fb70d0a33..e4a933976 100644 --- a/tests/test_spaces.py +++ b/tests/test_spaces.py @@ -1,63 +1,67 @@ +from dataclasses import dataclass from typing import Dict, Optional import gymnasium as gym import numpy as np import pytest from gymnasium import spaces +from gymnasium.spaces.space import Space from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3 from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.evaluation import evaluate_policy +BOX_SPACE_FLOAT64 = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float64) +BOX_SPACE_FLOAT32 = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) -class DummyMultiDiscreteSpace(gym.Env): - def __init__(self, nvec): - super().__init__() - self.observation_space = spaces.MultiDiscrete(nvec) - self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) - def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): - if seed is not None: - super().reset(seed=seed) - return self.observation_space.sample(), {} +@dataclass +class DummyEnv(gym.Env): + observation_space: Space + action_space: Space def step(self, action): return self.observation_space.sample(), 0.0, False, False, {} - -class DummyMultiBinary(gym.Env): - def __init__(self, n): - super().__init__() - self.observation_space = spaces.MultiBinary(n) - self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) - def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): if seed is not None: super().reset(seed=seed) return self.observation_space.sample(), {} - def step(self, action): - return self.observation_space.sample(), 0.0, False, False, {} - -class DummyMultidimensionalAction(gym.Env): +class DummyMultidimensionalAction(DummyEnv): def __init__(self): - super().__init__() - self.observation_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) - self.action_space = spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32) + super().__init__( + BOX_SPACE_FLOAT32, + spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32), + ) - def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): - if seed is not None: - super().reset(seed=seed) - return self.observation_space.sample(), {} - def step(self, action): - return self.observation_space.sample(), 0.0, False, False, {} +class DummyMultiBinary(DummyEnv): + def __init__(self, n): + super().__init__( + spaces.MultiBinary(n), + BOX_SPACE_FLOAT32, + ) + + +class DummyMultiDiscreteSpace(DummyEnv): + def __init__(self, nvec): + super().__init__( + spaces.MultiDiscrete(nvec), + BOX_SPACE_FLOAT32, + ) @pytest.mark.parametrize( - "env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8), DummyMultiBinary((3, 2)), DummyMultidimensionalAction()] + "env", + [ + DummyMultiDiscreteSpace([4, 3]), + DummyMultiBinary(8), + DummyMultiBinary((3, 2)), + DummyMultidimensionalAction(), + ], ) def test_env(env): # Check the env used for testing @@ -127,3 +131,40 @@ def test_discrete_obs_space(model_class, env): else: kwargs = dict(n_steps=256) model_class("MlpPolicy", env, **kwargs).learn(256) + + +@pytest.mark.parametrize("model_class", [SAC, TD3, PPO, DDPG, A2C]) +@pytest.mark.parametrize( + "obs_space", + [ + BOX_SPACE_FLOAT32, + BOX_SPACE_FLOAT64, + spaces.Dict({"a": BOX_SPACE_FLOAT32, "b": BOX_SPACE_FLOAT32}), + spaces.Dict({"a": BOX_SPACE_FLOAT32, "b": BOX_SPACE_FLOAT64}), + ], +) +@pytest.mark.parametrize( + "action_space", + [ + BOX_SPACE_FLOAT32, + BOX_SPACE_FLOAT64, + ], +) +def test_float64_action_space(model_class, obs_space, action_space): + env = DummyEnv(obs_space, action_space) + env = gym.wrappers.TimeLimit(env, max_episode_steps=200) + if isinstance(env.observation_space, spaces.Dict): + policy = "MultiInputPolicy" + else: + policy = "MlpPolicy" + + if model_class in [PPO, A2C]: + kwargs = dict(n_steps=64, policy_kwargs=dict(net_arch=[12])) + else: + kwargs = dict(learning_starts=60, policy_kwargs=dict(net_arch=[12])) + + model = model_class(policy, env, **kwargs) + model.learn(64) + initial_obs, _ = env.reset() + action, _ = model.predict(initial_obs, deterministic=False) + assert action.dtype == env.action_space.dtype From d43400b46460522bfd5b26518357721af3e087ce Mon Sep 17 00:00:00 2001 From: Kyle He Date: Tue, 1 Aug 2023 04:20:29 -0700 Subject: [PATCH 7/7] Fix typo in the documentation for Custom Policy Networks (#1620) * Update custom_policy.rst * Update changelog.rst --------- Co-authored-by: Antonin RAFFIN --- docs/guide/custom_policy.rst | 2 +- docs/misc/changelog.rst | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/guide/custom_policy.rst b/docs/guide/custom_policy.rst index a136bfa59..0807498e4 100644 --- a/docs/guide/custom_policy.rst +++ b/docs/guide/custom_policy.rst @@ -215,7 +215,7 @@ downsampling and "vector" with a single linear layer. from stable_baselines3.common.torch_layers import BaseFeaturesExtractor class CustomCombinedExtractor(BaseFeaturesExtractor): - def __init__(self, observation_space: spaces.Dict): + def __init__(self, observation_space: gym.spaces.Dict): # We do not know features-dim here before going over all the items, # so put something dummy for now. PyTorch requires calling # nn.Module.__init__ before adding modules diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 039a094ab..4cff7277f 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -40,6 +40,7 @@ Others: Documentation: ^^^^^^^^^^^^^^ - Fixed callback example (@BertrandDecoster) +- Fixed policy network example (@kyle-he) - Added mobile-env as new community project (@stefanbschneider) @@ -1401,4 +1402,4 @@ And all the contributors: @Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto -@lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider +@lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he