Update Gymnasium to v1.0.0 (#1837)

* Update Gymnasium to v1.0.0a1

* Comment out `gymnasium.wrappers.monitor` (todo update to VideoRecord)

* Fix ruff warnings

* Register Atari envs

* Update `getattr` to `Env.get_wrapper_attr`

* Reorder imports

* Fix `seed` order

* Fix collecting `max_steps`

* Copy and paste video recorder to prevent the need to rewrite the vec video recorder wrapper

* Use `typing.List` rather than list

* Fix env attribute forwarding

* Separate out env attribute collection from its utilisation

* Update for Gymnasium alpha 2

* Remove assert for OrderedDict

* Update setup.py

* Add type: ignore

* Test with Gymnasium main

* Remove `gymnasium.logger.debug/info`

* Fix github CI yaml

* Run gym 0.29.1 on python 3.10

* Update lower bounds

* Integrate video recorder

* Remove ordered dict

* Update changelog

---------

Co-authored-by: Antonin RAFFIN <[email protected]>
pseudo-rnd-thoughts and araffin authored Nov 4, 2024
1 parent dd3d0ac commit 8f0b488
Showing 16 changed files with 148 additions and 120 deletions.
20 changes: 12 additions & 8 deletions .github/workflows/ci.yml
@@ -21,7 +21,12 @@ jobs:
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]

include:
# Default version
- gymnasium-version: "1.0.0"
# Add a new config to test gym<1.0
- python-version: "3.10"
gymnasium-version: "0.29.1"
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
@@ -37,15 +42,14 @@ jobs:
# See https://github.com/astral-sh/uv/issues/1497
uv pip install --system torch==2.3.1+cpu --index https://download.pytorch.org/whl/cpu
# Install Atari Roms
uv pip install --system autorom
wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
AutoROM --accept-license --source-file Roms.tar.gz
uv pip install --system .[extra_no_roms,tests,docs]
uv pip install --system .[extra,tests,docs]
# Use headless version
uv pip install --system opencv-python-headless
- name: Install specific version of gym
run: |
uv pip install --system gymnasium==${{ matrix.gymnasium-version }}
# Only run for python 3.10, downgrade gym to 0.29.1
if: matrix.gymnasium-version != '1.0.0'
- name: Lint with ruff
run: |
make lint
2 changes: 1 addition & 1 deletion docs/conda_env.yml
@@ -8,7 +8,7 @@ dependencies:
- python=3.11
- pytorch=2.5.0=py3.11_cpu_0
- pip:
- gymnasium>=0.28.1,<0.30
- gymnasium>=0.29.1,<1.1.0
- cloudpickle
- opencv-python-headless
- pandas
7 changes: 5 additions & 2 deletions docs/misc/changelog.rst
@@ -3,10 +3,10 @@
Changelog
==========

Release 2.4.0a10 (WIP)
Release 2.4.0a11 (WIP)
--------------------------

**New algorithm: CrossQ in SB3 Contrib**
**New algorithm: CrossQ in SB3 Contrib, Gymnasium v1.0 support**

.. note::

@@ -24,12 +24,14 @@ Release 2.4.0a10 (WIP)

Breaking Changes:
^^^^^^^^^^^^^^^^^
- Increase minimum required version of Gymnasium to 0.29.1

New Features:
^^^^^^^^^^^^^
- Added support for ``pre_linear_modules`` and ``post_linear_modules`` in ``create_mlp`` (useful for adding normalization layers, like in DroQ or CrossQ)
- Enabled np.ndarray logging for TensorBoardOutputFormat as histogram (see GH#1634) (@iwishwasaneagle)
- Updated env checker to warn users when using multi-dim array to define `MultiDiscrete` spaces
- Added support for Gymnasium v1.0

Bug Fixes:
^^^^^^^^^^
@@ -69,6 +71,7 @@ Others:
- Added a warning to recommend using CPU with on policy algorithms (A2C/PPO) and ``MlpPolicy``
- Switched to uv to download packages faster on GitHub CI
- Updated dependencies for read the doc
- Removed unnecessary ``copy_obs_dict`` method for ``SubprocVecEnv``, removed the use of ordered dict and renamed ``flatten_obs`` to ``stack_obs``

Bug Fixes:
^^^^^^^^^^
1 change: 0 additions & 1 deletion pyproject.toml
@@ -18,7 +18,6 @@ ignore = ["B028", "RUF013"]
# ClassVar, implicit optional check not needed for tests
"./tests/*.py" = ["RUF012", "RUF013"]


[tool.ruff.lint.mccabe]
# Unlike Flake8, default to a complexity level of 10.
max-complexity = 15
43 changes: 16 additions & 27 deletions setup.py
@@ -70,37 +70,13 @@
""" # noqa:E501

# Atari Games download is sometimes problematic:
# https://github.com/Farama-Foundation/AutoROM/issues/39
# That's why we define extra packages without it.
extra_no_roms = [
# For render
"opencv-python",
"pygame",
# Tensorboard support
"tensorboard>=2.9.1",
# Checking memory taken by replay buffer
"psutil",
# For progress bar callback
"tqdm",
"rich",
# For atari games,
"shimmy[atari]~=1.3.0",
"pillow",
]

extra_packages = extra_no_roms + [ # noqa: RUF005
# For atari roms,
"autorom[accept-rom-license]~=0.6.1",
]


setup(
name="stable_baselines3",
packages=[package for package in find_packages() if package.startswith("stable_baselines3")],
package_data={"stable_baselines3": ["py.typed", "version.txt"]},
install_requires=[
"gymnasium>=0.28.1,<0.30",
"gymnasium>=0.29.1,<1.1.0",
"numpy>=1.20,<2.0", # PyTorch not compatible https://github.com/pytorch/pytorch/issues/107302
"torch>=1.13",
# For saving models
@@ -133,8 +109,21 @@
# Copy button for code snippets
"sphinx_copybutton",
],
"extra": extra_packages,
"extra_no_roms": extra_no_roms,
"extra": [
# For render
"opencv-python",
"pygame",
# Tensorboard support
"tensorboard>=2.9.1",
# Checking memory taken by replay buffer
"psutil",
# For progress bar callback
"tqdm",
"rich",
# For atari games,
"ale-py>=0.9.0",
"pillow",
],
},
description="Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.",
author="Antonin Raffin",
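
Note on the change above: the `extra` option now depends on `ale-py>=0.9.0` directly instead of AutoROM plus `shimmy[atari]`, since recent `ale-py` releases bundle the ROMs and register the ALE environments with Gymnasium themselves. A minimal sketch of how the Atari environments become available once the extra is installed (environment id chosen for illustration; this mirrors the "Register Atari envs" commit item):

import ale_py
import gymnasium as gym

# Importing ale_py registers the ALE namespace via entry points;
# gym.register_envs (added in Gymnasium v1.0) makes that dependency explicit
# so linters do not flag the import as unused.
gym.register_envs(ale_py)

env = gym.make("ALE/Breakout-v5")
obs, info = env.reset(seed=0)
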
8 changes: 4 additions & 4 deletions stable_baselines3/common/vec_env/dummy_vec_env.py
@@ -8,7 +8,7 @@

from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn
from stable_baselines3.common.vec_env.patch_gym import _patch_env
from stable_baselines3.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info
from stable_baselines3.common.vec_env.util import dict_to_obs, obs_space_info


class DummyVecEnv(VecEnv):
@@ -110,12 +110,12 @@ def _save_obs(self, env_idx: int, obs: VecEnvObs) -> None:
self.buf_obs[key][env_idx] = obs[key] # type: ignore[call-overload]

def _obs_from_buf(self) -> VecEnvObs:
return dict_to_obs(self.observation_space, copy_obs_dict(self.buf_obs))
return dict_to_obs(self.observation_space, deepcopy(self.buf_obs))

def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
"""Return attribute from vectorized environment (see base class)."""
target_envs = self._get_target_envs(indices)
return [getattr(env_i, attr_name) for env_i in target_envs]
return [env_i.get_wrapper_attr(attr_name) for env_i in target_envs]

def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
"""Set attribute inside vectorized environments (see base class)."""
@@ -126,7 +126,7 @@ def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) ->
def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
"""Call instance methods of vectorized environments."""
target_envs = self._get_target_envs(indices)
return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in target_envs]
return [env_i.get_wrapper_attr(method_name)(*method_args, **method_kwargs) for env_i in target_envs]

def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
"""Check if worker environments are wrapped with a given wrapper"""
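
For context on the `getattr` to `get_wrapper_attr` switch above: Gymnasium v1.0 no longer forwards arbitrary attribute lookups through its wrappers, so the vec envs must ask the wrapper stack explicitly. A minimal sketch of the difference (environment and attribute chosen for illustration):

import gymnasium as gym

# gym.make returns the base env wrapped in e.g. TimeLimit and OrderEnforcing
env = gym.make("CartPole-v1")

# In Gymnasium v1.0, plain attribute access is no longer forwarded by wrappers,
# so `getattr(env, "gravity")` can raise AttributeError on a wrapped env.
# get_wrapper_attr searches the whole wrapper stack for the attribute instead:
gravity = env.get_wrapper_attr("gravity")
print(gravity)  # 9.8 for the CartPole base environment
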
2 changes: 1 addition & 1 deletion stable_baselines3/common/vec_env/patch_gym.py
@@ -43,7 +43,7 @@ def _patch_env(env: Union["gym.Env", gymnasium.Env]) -> gymnasium.Env: # pragma
"Missing shimmy installation. You provided an OpenAI Gym environment. "
"Stable-Baselines3 (SB3) has transitioned to using Gymnasium internally. "
"In order to use OpenAI Gym environments with SB3, you need to "
"install shimmy (`pip install 'shimmy>=0.2.1'`)."
"install shimmy (`pip install 'shimmy>=2.0'`)."
) from e

warnings.warn(
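
The version bump in the error message above points at the shimmy release line that works with Gymnasium v1.0. As a reminder of what `_patch_env` does with that dependency, here is a hedged sketch of the conversion it performs for legacy OpenAI Gym environments (assumes both `gym` and `shimmy>=2.0` are installed):

import gym  # legacy OpenAI Gym package
import shimmy

legacy_env = gym.make("CartPole-v1")
# Shimmy's compatibility wrapper exposes the Gymnasium API on top of the old env;
# this mirrors what _patch_env does internally for gym >= 0.26 environments.
env = shimmy.GymV26CompatibilityV0(env=legacy_env)
obs, info = env.reset(seed=0)
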
34 changes: 17 additions & 17 deletions stable_baselines3/common/vec_env/subproc_vec_env.py
@@ -1,6 +1,5 @@
import multiprocessing as mp
import warnings
from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union

import gymnasium as gym
@@ -54,10 +53,10 @@ def _worker(
elif cmd == "get_spaces":
remote.send((env.observation_space, env.action_space))
elif cmd == "env_method":
method = getattr(env, data[0])
method = env.get_wrapper_attr(data[0])
remote.send(method(*data[1], **data[2]))
elif cmd == "get_attr":
remote.send(getattr(env, data))
remote.send(env.get_wrapper_attr(data))
elif cmd == "set_attr":
remote.send(setattr(env, data[0], data[1])) # type: ignore[func-returns-value]
elif cmd == "is_wrapped":
@@ -129,7 +128,7 @@ def step_wait(self) -> VecEnvStepReturn:
results = [remote.recv() for remote in self.remotes]
self.waiting = False
obs, rews, dones, infos, self.reset_infos = zip(*results) # type: ignore[assignment]
return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos # type: ignore[return-value]
return _stack_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos # type: ignore[return-value]

def reset(self) -> VecEnvObs:
for env_idx, remote in enumerate(self.remotes):
@@ -139,7 +138,7 @@ def reset(self) -> VecEnvObs:
# Seeds and options are only used once
self._reset_seeds()
self._reset_options()
return _flatten_obs(obs, self.observation_space)
return _stack_obs(obs, self.observation_space)

def close(self) -> None:
if self.closed:
@@ -206,27 +205,28 @@ def _get_target_remotes(self, indices: VecEnvIndices) -> List[Any]:
return [self.remotes[i] for i in indices]


def _flatten_obs(obs: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: spaces.Space) -> VecEnvObs:
def _stack_obs(obs_list: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: spaces.Space) -> VecEnvObs:
"""
Flatten observations, depending on the observation space.
Stack observations (convert from a list of single env obs to a stack of obs),
depending on the observation space.
:param obs: observations.
A list or tuple of observations, one per environment.
Each environment observation may be a NumPy array, or a dict or tuple of NumPy arrays.
:return: flattened observations.
A flattened NumPy array or an OrderedDict or tuple of flattened numpy arrays.
:return: Concatenated observations.
A NumPy array or a dict or tuple of stacked numpy arrays.
Each NumPy array has the environment index as its first axis.
"""
assert isinstance(obs, (list, tuple)), "expected list or tuple of observations per environment"
assert len(obs) > 0, "need observations from at least one environment"
assert isinstance(obs_list, (list, tuple)), "expected list or tuple of observations per environment"
assert len(obs_list) > 0, "need observations from at least one environment"

if isinstance(space, spaces.Dict):
assert isinstance(space.spaces, OrderedDict), "Dict space must have ordered subspaces"
assert isinstance(obs[0], dict), "non-dict observation for environment with Dict observation space"
return OrderedDict([(k, np.stack([o[k] for o in obs])) for k in space.spaces.keys()])
assert isinstance(space.spaces, dict), "Dict space must have ordered subspaces"
assert isinstance(obs_list[0], dict), "non-dict observation for environment with Dict observation space"
return {key: np.stack([single_obs[key] for single_obs in obs_list]) for key in space.spaces.keys()} # type: ignore[call-overload]
elif isinstance(space, spaces.Tuple):
assert isinstance(obs[0], tuple), "non-tuple observation for environment with Tuple observation space"
assert isinstance(obs_list[0], tuple), "non-tuple observation for environment with Tuple observation space"
obs_len = len(space.spaces)
return tuple(np.stack([o[i] for o in obs]) for i in range(obs_len)) # type: ignore[index]
return tuple(np.stack([single_obs[i] for single_obs in obs_list]) for i in range(obs_len)) # type: ignore[index]
else:
return np.stack(obs) # type: ignore[arg-type]
return np.stack(obs_list) # type: ignore[arg-type]
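
To make the `_flatten_obs` to `_stack_obs` rename and the `OrderedDict` removal concrete, here is a minimal sketch of the Dict-space branch (space and number of envs chosen for illustration); plain `dict` preserves insertion order on all supported Python versions, so the ordered variant is no longer needed:

import numpy as np
from gymnasium import spaces

space = spaces.Dict(
    {"position": spaces.Box(-1.0, 1.0, shape=(3,)), "velocity": spaces.Box(-1.0, 1.0, shape=(2,))}
)
# One observation per environment (two envs here), as received from the workers
obs_list = [space.sample(), space.sample()]

# What _stack_obs does for Dict spaces: one stacked array per key,
# with the environment index as the first axis
stacked = {key: np.stack([single_obs[key] for single_obs in obs_list]) for key in space.spaces}
assert stacked["position"].shape == (2, 3)
assert stacked["velocity"].shape == (2, 2)
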
20 changes: 4 additions & 16 deletions stable_baselines3/common/vec_env/util.py
@@ -2,7 +2,6 @@
Helpers for dealing with vectorized environments.
"""

from collections import OrderedDict
from typing import Any, Dict, List, Tuple

import numpy as np
@@ -12,17 +11,6 @@
from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs


def copy_obs_dict(obs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
"""
Deep-copy a dict of numpy arrays.
:param obs: a dict of numpy arrays.
:return: a dict of copied numpy arrays.
"""
assert isinstance(obs, OrderedDict), f"unexpected type for observations '{type(obs)}'"
return OrderedDict([(k, np.copy(v)) for k, v in obs.items()])


def dict_to_obs(obs_space: spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> VecEnvObs:
"""
Convert an internal representation raw_obs into the appropriate type
@@ -60,18 +48,18 @@ def obs_space_info(obs_space: spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[
"""
check_for_nested_spaces(obs_space)
if isinstance(obs_space, spaces.Dict):
assert isinstance(obs_space.spaces, OrderedDict), "Dict space must have ordered subspaces"
assert isinstance(obs_space.spaces, dict), "Dict space must have ordered subspaces"
subspaces = obs_space.spaces
elif isinstance(obs_space, spaces.Tuple):
subspaces = {i: space for i, space in enumerate(obs_space.spaces)} # type: ignore[assignment]
subspaces = {i: space for i, space in enumerate(obs_space.spaces)} # type: ignore[assignment,misc]
else:
assert not hasattr(obs_space, "spaces"), f"Unsupported structured space '{type(obs_space)}'"
subspaces = {None: obs_space} # type: ignore[assignment]
subspaces = {None: obs_space} # type: ignore[assignment,dict-item]
keys = []
shapes = {}
dtypes = {}
for key, box in subspaces.items():
keys.append(key)
shapes[key] = box.shape
dtypes[key] = box.dtype
return keys, shapes, dtypes
return keys, shapes, dtypes # type: ignore[return-value]
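
For reference, a hedged sketch of what `obs_space_info` returns and how `DummyVecEnv` uses it to allocate its observation buffers (space and number of envs chosen for illustration):

import numpy as np
from gymnasium import spaces

from stable_baselines3.common.vec_env.util import obs_space_info

obs_space = spaces.Dict(
    {"image": spaces.Box(0, 255, shape=(4, 4, 3), dtype=np.uint8), "vector": spaces.Box(-1.0, 1.0, shape=(5,))}
)
keys, shapes, dtypes = obs_space_info(obs_space)
# keys   -> ["image", "vector"]
# shapes -> {"image": (4, 4, 3), "vector": (5,)}
# dtypes -> {"image": dtype('uint8'), "vector": dtype('float32')}

# DummyVecEnv pre-allocates one buffer per key, with the env index as the first axis
n_envs = 2
buffers = {key: np.zeros((n_envs, *shapes[key]), dtype=dtypes[key]) for key in keys}
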