Merge branch 'master' into feat/set_options

araffin authored Aug 1, 2023
2 parents 30e8107 + d43400b commit 842faef
Showing 14 changed files with 133 additions and 48 deletions.
2 changes: 1 addition & 1 deletion docs/README.md
@@ -8,7 +8,7 @@ This folder contains documentation for the RL baselines.
#### Install Sphinx and Theme
Execute this command in the project root:
```
pip install -e .[docs]
pip install -e ".[docs]"
```

#### Building the Docs
2 changes: 1 addition & 1 deletion docs/guide/custom_policy.rst
@@ -215,7 +215,7 @@ downsampling and "vector" with a single linear layer.
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
class CustomCombinedExtractor(BaseFeaturesExtractor):
def __init__(self, observation_space: spaces.Dict):
def __init__(self, observation_space: gym.spaces.Dict):
# We do not know features-dim here before going over all the items,
# so put something dummy for now. PyTorch requires calling
# nn.Module.__init__ before adding modules
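
For context, a minimal sketch of how the corrected annotation fits into the surrounding example from the SB3 docs (the per-key CNN/linear branches are elided here; `features_dim=1` is the placeholder the comment above refers to):

```
import gymnasium as gym

from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


class CustomCombinedExtractor(BaseFeaturesExtractor):
    def __init__(self, observation_space: gym.spaces.Dict):
        # The real features dim is only known after iterating over the
        # subspaces, so pass a placeholder: PyTorch requires
        # nn.Module.__init__ to run before any submodule is registered.
        super().__init__(observation_space, features_dim=1)
```
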
4 changes: 2 additions & 2 deletions docs/guide/examples.rst
@@ -319,7 +319,7 @@ You can control the evaluation frequency with ``eval_freq`` to monitor your agen
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import EvalCallback
from stable-baselines3.common.env_util import make_vec_env
from stable_baselines3.common.env_util import make_vec_env
env_id = "Pendulum-v1"
n_training_envs = 1
@@ -330,7 +330,7 @@ You can control the evaluation frequency with ``eval_freq`` to monitor your agen
os.makedirs(eval_log_dir, exist_ok=True)
# Initialize a vectorized training environment with default parameters
train_env = make_vec_env(env_id, n_env=n_training_envs, seed=0)
train_env = make_vec_env(env_id, n_envs=n_training_envs, seed=0)
# Separate evaluation env, with different parameters passed via env_kwargs
# Eval environments can be vectorized to speed up evaluation.
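
A quick check of the corrected keyword (`n_envs` matches `make_vec_env`'s actual signature; `n_env` raises a `TypeError`):

```
from stable_baselines3.common.env_util import make_vec_env

# The keyword argument is `n_envs`, not `n_env`.
train_env = make_vec_env("Pendulum-v1", n_envs=1, seed=0)
```
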
11 changes: 9 additions & 2 deletions docs/misc/changelog.rst
@@ -4,7 +4,7 @@ Changelog
==========


Release 2.1.0a0 (WIP)
Release 2.1.0a3 (WIP)
--------------------------

Breaking Changes:
@@ -16,6 +16,7 @@ New Features:
^^^^^^^^^^^^^
- Added Python 3.11 support
- Add `options` argument to pass to `env.reset()`. As with the seeds logic, options are reset at the end of an episode (@ReHoss); see the sketch after this list
- Added Gymnasium 0.29 support (@pseudo-rnd-thoughts)
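
A minimal sketch of how the new `options` argument could be used on a vectorized env, assuming the `set_options` setter introduced on this feature branch (feat/set_options):

```
from stable_baselines3.common.env_util import make_vec_env

vec_env = make_vec_env("Pendulum-v1", n_envs=2, seed=0)
# Like seeds, options are forwarded to each sub-env's reset()
# and cleared again at the end of an episode.
vec_env.set_options({"difficulty": 1})
obs = vec_env.reset()
```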

`SB3-Contrib`_
^^^^^^^^^^^^^^
@@ -25,6 +26,8 @@

Bug Fixes:
^^^^^^^^^^
- Relaxed check in logger that was causing issues on Windows with colorama
- Fixed off-policy algorithms with continuous float64 actions (see #1145) (@tobirohrer)

Deprecations:
^^^^^^^^^^^^^
@@ -33,9 +36,13 @@ Others:
^^^^^^^
- Updated GitHub issue templates
- Fix typo in gym patch error message (@lukashass)
- Refactor ``test_spaces.py`` tests

Documentation:
^^^^^^^^^^^^^^
- Fixed callback example (@BertrandDecoster)
- Fixed policy network example (@kyle-he)
- Added mobile-env as new community project (@stefanbschneider)


Release 2.0.0 (2023-06-22)
@@ -1396,4 +1403,4 @@ And all the contributors:
@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto
@lutogniew @lbergmann1 @lukashass @ReHoss
@lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @ReHoss
15 changes: 15 additions & 0 deletions docs/misc/projects.rst
@@ -197,3 +197,18 @@ A simple library for pink noise exploration with deterministic (DDPG / TD3) and
| Authors: Onno Eberhard, Jakob Hollenstein, Cristina Pinneri, Georg Martius
| Github: https://github.com/martius-lab/pink-noise-rl
| Paper: https://openreview.net/forum?id=hQ9V5QN27eS (Oral at ICLR 2023)

mobile-env
----------

An open, minimalist Gymnasium environment for autonomous coordination in wireless mobile networks.
It allows simulating various scenarios with moving users in a cellular network with multiple base stations.

- Written in pure Python, easy to modify and extend, and can be installed directly via PyPI.
- Implements the standard Gymnasium interface such that it can be used with all common frameworks for reinforcement learning.
- There are examples for both single-agent and multi-agent RL using either `stable-baselines3` or Ray RLlib.

| Authors: Stefan Schneider, Stefan Werner
| Github: https://github.com/stefanbschneider/mobile-env
| Paper: https://ris.uni-paderborn.de/download/30236/30237 (2022 IEEE/IFIP Network Operations and Management Symposium (NOMS))
2 changes: 1 addition & 1 deletion setup.py
@@ -100,7 +100,7 @@
packages=[package for package in find_packages() if package.startswith("stable_baselines3")],
package_data={"stable_baselines3": ["py.typed", "version.txt"]},
install_requires=[
"gymnasium==0.28.1",
"gymnasium>=0.28.1,<0.30",
"numpy>=1.20",
"torch>=1.13",
# For saving models
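
To sanity-check the relaxed pin, the version range can be evaluated with the `packaging` library (a standalone check, not part of this diff):

```
from packaging.specifiers import SpecifierSet

spec = SpecifierSet(">=0.28.1,<0.30")
# The relaxed pin admits the 0.29 series but still excludes 0.30.
print("0.28.1" in spec)  # True
print("0.29.0" in spec)  # True
print("0.30.0" in spec)  # False
```
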
23 changes: 21 additions & 2 deletions stable_baselines3/common/buffers.py
@@ -207,7 +207,9 @@ def __init__(
else:
self.next_observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=observation_space.dtype)

self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=action_space.dtype)
self.actions = np.zeros(
(self.buffer_size, self.n_envs, self.action_dim), dtype=self._maybe_cast_dtype(action_space.dtype)
)

self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
@@ -311,6 +313,21 @@ def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = Non
)
return ReplayBufferSamples(*tuple(map(self.to_torch, data)))

@staticmethod
def _maybe_cast_dtype(dtype: np.typing.DTypeLike) -> np.typing.DTypeLike:
"""
Cast `np.float64` action datatype to `np.float32`,
keep the other dtypes unchanged.
See GH#1572 for more information.
:param dtype: The original action space dtype
:return: ``np.float32`` if the dtype was float64,
the original dtype otherwise.
"""
if dtype == np.float64:
return np.float32
return dtype


class RolloutBuffer(BaseBuffer):
"""
@@ -543,7 +560,9 @@ def __init__(
for key, _obs_shape in self.obs_shape.items()
}

self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=action_space.dtype)
self.actions = np.zeros(
(self.buffer_size, self.n_envs, self.action_dim), dtype=self._maybe_cast_dtype(action_space.dtype)
)
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)

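
The effect of the new helper, re-implemented outside the class as a standalone sketch:

```
import numpy as np
from gymnasium import spaces


def maybe_cast_dtype(dtype):
    # Mirrors ReplayBuffer._maybe_cast_dtype: store float64 actions
    # as float32, keep every other dtype unchanged (see GH#1572).
    return np.float32 if dtype == np.float64 else dtype


action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float64)
actions = np.zeros((100, 1, 2), dtype=maybe_cast_dtype(action_space.dtype))
print(actions.dtype)  # float32
```
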
6 changes: 4 additions & 2 deletions stable_baselines3/common/logger.py
@@ -164,8 +164,10 @@ def __init__(self, filename_or_file: Union[str, TextIO], max_length: int = 36):
if isinstance(filename_or_file, str):
self.file = open(filename_or_file, "w")
self.own_file = True
elif isinstance(filename_or_file, TextIOBase):
self.file = filename_or_file
elif isinstance(filename_or_file, TextIOBase) or hasattr(filename_or_file, "write"):
# Note: in theory `TextIOBase` check should be sufficient,
# in practice, libraries don't always inherit from it, see GH#1598
self.file = filename_or_file # type: ignore[assignment]
self.own_file = False
else:
raise ValueError(f"Expected file or str, got {filename_or_file}")
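
Why the relaxed check matters, in a small sketch: stream wrappers such as the one colorama installs on Windows expose `write()` without inheriting from `TextIOBase`, so the strict `isinstance` check rejected them (see GH#1598):

```
from io import TextIOBase


class DuckWriter:
    """A file-like object that does not inherit from TextIOBase."""

    def write(self, text: str) -> int:
        print(text, end="")
        return len(text)


stream = DuckWriter()
print(isinstance(stream, TextIOBase))  # False: the old check rejected it
print(hasattr(stream, "write"))        # True: the relaxed check accepts it
```
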
4 changes: 3 additions & 1 deletion stable_baselines3/common/monitor.py
@@ -193,7 +193,9 @@ def __init__(
mode = "w" if override_existing else "a"
# Prevent newline issue on Windows, see GH issue #692
self.file_handler = open(filename, f"{mode}t", newline="\n")
self.logger = csv.DictWriter(self.file_handler, fieldnames=("r", "l", "t", *extra_keys))
self.logger = csv.DictWriter(
self.file_handler, fieldnames=("r", "l", "t", *extra_keys)
) # pytype: disable=wrong-arg-types
if override_existing:
self.file_handler.write(f"#{json.dumps(header)}\n")
self.logger.writeheader()
2 changes: 1 addition & 1 deletion stable_baselines3/dqn/dqn.py
@@ -245,7 +245,7 @@ def predict(
if not deterministic and np.random.rand() < self.exploration_rate:
if self.policy.is_vectorized_observation(observation):
if isinstance(observation, dict):
n_batch = observation[list(observation.keys())[0]].shape[0]
n_batch = observation[next(iter(observation.keys()))].shape[0]
else:
n_batch = observation.shape[0]
action = np.array([self.action_space.sample() for _ in range(n_batch)])
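
The refactor reads the first key without materializing the whole key list; a quick illustration (the observation dict here is made up):

```
import numpy as np

observation = {"image": np.zeros((4, 84, 84)), "vector": np.zeros((4, 7))}
first_key = next(iter(observation.keys()))  # O(1); list(...)[0] builds a full list
n_batch = observation[first_key].shape[0]   # 4, the batch size
```
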
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
@@ -1 +1 @@
2.1.0a0
2.1.0a3
2 changes: 0 additions & 2 deletions tests/test_envs.py
@@ -156,8 +156,6 @@ def patched_step(_action):
spaces.Box(low=-1000, high=1000, shape=(3,), dtype=np.float32),
# Too small range
spaces.Box(low=-0.1, high=0.1, shape=(2,), dtype=np.float32),
# Inverted boundaries
spaces.Box(low=1, high=-1, shape=(2,), dtype=np.float32),
# Same boundaries
spaces.Box(low=1, high=1, shape=(2,), dtype=np.float32),
# Unbounded action space
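
Presumably this case was dropped because recent Gymnasium versions reject an inverted `Box` at construction time, before the env checker can flag it; a sketch of that assumption:

```
import numpy as np
from gymnasium import spaces

# Expected to raise ValueError in recent Gymnasium: low must not exceed high.
spaces.Box(low=1, high=-1, shape=(2,), dtype=np.float32)
```
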
5 changes: 3 additions & 2 deletions tests/test_logger.py
@@ -437,8 +437,9 @@ def test_ep_buffers_stats_window_size(algo, stats_window_size):
assert model.ep_success_buffer.maxlen == stats_window_size


def test_human_output_format_custom_test_io():
class DummyTextIO(TextIOBase):
@pytest.mark.parametrize("base_class", [object, TextIOBase])
def test_human_output_format_custom_test_io(base_class):
class DummyTextIO(base_class):
def __init__(self) -> None:
super().__init__()
self.lines = [[]]
101 changes: 71 additions & 30 deletions tests/test_spaces.py
@@ -1,63 +1,67 @@
from dataclasses import dataclass
from typing import Dict, Optional

import gymnasium as gym
import numpy as np
import pytest
from gymnasium import spaces
from gymnasium.spaces.space import Space

from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy

BOX_SPACE_FLOAT64 = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float64)
BOX_SPACE_FLOAT32 = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)

class DummyMultiDiscreteSpace(gym.Env):
def __init__(self, nvec):
super().__init__()
self.observation_space = spaces.MultiDiscrete(nvec)
self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)

def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
if seed is not None:
super().reset(seed=seed)
return self.observation_space.sample(), {}
@dataclass
class DummyEnv(gym.Env):
observation_space: Space
action_space: Space

def step(self, action):
return self.observation_space.sample(), 0.0, False, False, {}


class DummyMultiBinary(gym.Env):
def __init__(self, n):
super().__init__()
self.observation_space = spaces.MultiBinary(n)
self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)

def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
if seed is not None:
super().reset(seed=seed)
return self.observation_space.sample(), {}

def step(self, action):
return self.observation_space.sample(), 0.0, False, False, {}


class DummyMultidimensionalAction(gym.Env):
class DummyMultidimensionalAction(DummyEnv):
def __init__(self):
super().__init__()
self.observation_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
self.action_space = spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32)
super().__init__(
BOX_SPACE_FLOAT32,
spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32),
)

def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
if seed is not None:
super().reset(seed=seed)
return self.observation_space.sample(), {}

def step(self, action):
return self.observation_space.sample(), 0.0, False, False, {}
class DummyMultiBinary(DummyEnv):
def __init__(self, n):
super().__init__(
spaces.MultiBinary(n),
BOX_SPACE_FLOAT32,
)


class DummyMultiDiscreteSpace(DummyEnv):
def __init__(self, nvec):
super().__init__(
spaces.MultiDiscrete(nvec),
BOX_SPACE_FLOAT32,
)


@pytest.mark.parametrize(
"env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8), DummyMultiBinary((3, 2)), DummyMultidimensionalAction()]
"env",
[
DummyMultiDiscreteSpace([4, 3]),
DummyMultiBinary(8),
DummyMultiBinary((3, 2)),
DummyMultidimensionalAction(),
],
)
def test_env(env):
# Check the env used for testing
@@ -127,3 +131,40 @@ def test_discrete_obs_space(model_class, env):
else:
kwargs = dict(n_steps=256)
model_class("MlpPolicy", env, **kwargs).learn(256)


@pytest.mark.parametrize("model_class", [SAC, TD3, PPO, DDPG, A2C])
@pytest.mark.parametrize(
"obs_space",
[
BOX_SPACE_FLOAT32,
BOX_SPACE_FLOAT64,
spaces.Dict({"a": BOX_SPACE_FLOAT32, "b": BOX_SPACE_FLOAT32}),
spaces.Dict({"a": BOX_SPACE_FLOAT32, "b": BOX_SPACE_FLOAT64}),
],
)
@pytest.mark.parametrize(
"action_space",
[
BOX_SPACE_FLOAT32,
BOX_SPACE_FLOAT64,
],
)
def test_float64_action_space(model_class, obs_space, action_space):
env = DummyEnv(obs_space, action_space)
env = gym.wrappers.TimeLimit(env, max_episode_steps=200)
if isinstance(env.observation_space, spaces.Dict):
policy = "MultiInputPolicy"
else:
policy = "MlpPolicy"

if model_class in [PPO, A2C]:
kwargs = dict(n_steps=64, policy_kwargs=dict(net_arch=[12]))
else:
kwargs = dict(learning_starts=60, policy_kwargs=dict(net_arch=[12]))

model = model_class(policy, env, **kwargs)
model.learn(64)
initial_obs, _ = env.reset()
action, _ = model.predict(initial_obs, deterministic=False)
assert action.dtype == env.action_space.dtype
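
For reference, the refactored dataclass env can be constructed directly, since `@dataclass` generates an `__init__` that stores the two spaces in order (assuming `DummyEnv` keeps the `reset` shown for the old classes above):

```
# No boilerplate constructor is needed on the dummy envs.
env = DummyEnv(BOX_SPACE_FLOAT64, BOX_SPACE_FLOAT32)
obs, info = env.reset()
print(obs.dtype)  # float64
```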
