From c607585ecc3c813f69b730bf41c5993c622a218f Mon Sep 17 00:00:00 2001 From: AlexPasqua Date: Sat, 22 Jul 2023 21:31:07 +0200 Subject: [PATCH 01/17] Started PER - Created SumTree (to be ultimated) - Started PrioritizedReplayBuffer - constructor and 'sample' method - to be tested --- .../dqn/prioritized_replay_buffer.py | 160 ++++++++++++++++++ 1 file changed, 160 insertions(+) create mode 100644 stable_baselines3/dqn/prioritized_replay_buffer.py diff --git a/stable_baselines3/dqn/prioritized_replay_buffer.py b/stable_baselines3/dqn/prioritized_replay_buffer.py new file mode 100644 index 000000000..e85f6da6e --- /dev/null +++ b/stable_baselines3/dqn/prioritized_replay_buffer.py @@ -0,0 +1,160 @@ +import random +from typing import Optional, Union +from gymnasium import spaces +import torch as th +from stable_baselines3.common.buffers import ReplayBuffer +from stable_baselines3.common.type_aliases import ReplayBufferSamples +from stable_baselines3.common.vec_env.vec_normalize import VecNormalize + + +class SumTree: + """ + SumTree data structure for Prioritized Replay Buffer. + This code is inspired by: https://github.com/Howuhh/prioritized_experience_replay + + :param size: Max number of element in the buffer. + """ + def __init__(self, size: int): + self.nodes = [0] * (2 * size - 1) + self.data = [None] * size + self.size = size + self.count = 0 + self.real_size = 0 + + @property + def p_total(self): + return self.nodes[0] + + def update(self, data_idx, value): + idx = data_idx + self.size - 1 # child index in tree array + change = value - self.nodes[idx] + + self.nodes[idx] = value + + parent = (idx - 1) // 2 + while parent >= 0: + self.nodes[parent] += change + parent = (parent - 1) // 2 + + def add(self, value, data): + self.data[self.count] = data + self.update(self.count, value) + + self.count = (self.count + 1) % self.size + self.real_size = min(self.size, self.real_size + 1) + + def get(self, cumsum): + assert cumsum <= self.p_total + + idx = 0 + while 2 * idx + 1 < len(self.nodes): + left, right = 2*idx + 1, 2*idx + 2 + + if cumsum <= self.nodes[left]: + idx = left + else: + idx = right + cumsum = cumsum - self.nodes[left] + + data_idx = idx - self.size + 1 + + return data_idx, self.nodes[idx], self.data[data_idx] + + def __repr__(self): + return f"SumTree(nodes={self.nodes.__repr__()}, data={self.data.__repr__()})" + + +class PrioritizedReplayBuffer(ReplayBuffer): + """ + Prioritized Replay Buffer. 
+ Paper: https://arxiv.org/abs/1511.05952 + This code is inspired by: https://github.com/Howuhh/prioritized_experience_replay + + :param buffer_size: Max number of element in the buffer + :param observation_space: Observation space + :param action_space: Action space + :param device: PyTorch device + :param n_envs: Number of parallel environments + :param alpha: How much prioritization is used (0 - no prioritization, 1 - full prioritization) + :param beta: To what degree to use importance weights (0 - no corrections, 1 - full correction) + """ + + def __init__( + self, + buffer_size: int, + observation_space: spaces.Space, + action_space: spaces.Space, + device: Union[th.device, str] = "auto", + n_envs: int = 1, + alpha: float = 0.6, + beta: float = 0.4, + ): + super().__init__(buffer_size, observation_space, action_space, device, n_envs) + + # PER params + self.eps = 1e-8 # minimal priority, prevents zero probabilities + self.alpha = alpha # determines how much prioritization is used, alpha = 0 corresponding to the uniform case + self.beta = beta # determines the amount of importance-sampling correction, beta = 1 fully compensate for the non-uniform probabilities + self.max_priority = self.eps # priority for new samples, init as eps + + # SumTree: data structure to store priorities + self.tree = SumTree(size=buffer_size) + + def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayBufferSamples: + """ + Sample elements from the prioritized replay buffer. + + :param batch_size: Number of element to sample + :param env:associated gym VecEnv + to normalize the observations/rewards when sampling + :return: + """ + assert self.buffer_size >= batch_size, "The buffer contains less samples than the batch size requires." + + sample_idxs, tree_idxs = [], [] + priorities = th.empty(batch_size, 1, dtype=th.float) + + # To sample a minibatch of size k, the range [0, p_total] is divided equally into k ranges. + # Next, a value is uniformly sampled from each range. Finally the transitions that correspond + # to each of these sampled values are retrieved from the tree. + segment = self.tree.p_total / batch_size + for i in range(batch_size): + # extremes of the current segment + a, b = segment * i, segment * (i + 1) + + # uniformely sample a value from the current segment + cumsum = random.uniform(a, b) + + # tree_idx is a index of a sample in the tree, needed further to update priorities + # sample_idx is a sample index in buffer, needed further to sample actual transitions + tree_idx, priority, sample_idx = self.tree.get(cumsum) + + priorities[i] = priority + tree_idxs.append(tree_idx) + sample_idxs.append(sample_idx) + + # probability of sampling transition i as P(i) = p_i^alpha / \sum_{k} p_k^alpha + # where p_i > 0 is the priority of transition i. + probs = priorities / self.tree.p_total + + # The estimation of the expected value with stochastic updates relies on those updates corresponding + # to the same distribution as its expectation. Prioritized replay introduces bias because it changes this + # distribution in an uncontrolled fashion, and therefore changes the solution that the estimates will + # converge to (even if the policy and state distribution are fixed). We can correct this bias by using + # importance-sampling (IS) weights w_i = (1/N * 1/P(i))^β that fully compensates for the non-uniform + # probabilities P(i) if β = 1. These weights can be folded into the Q-learning update by using w_i * δ_i + # instead of δ_i (this is thus weighted IS, not ordinary IS, see e.g. 
Mahmood et al., 2014). + + # Importance sampling weights. + # All weights w_i were scaled so that max_i w_i = 1. + weights = (self.real_size * probs) ** -self.beta + weights = weights / weights.max() + + batch = ReplayBufferSamples( + self.observations[sample_idxs], + self.actions[sample_idxs], + self.next_observations[sample_idxs], + self.dones[sample_idxs], + self.rewards[sample_idxs], + ) + return batch, weights, tree_idxs From 57f11922509204c56df91e2ca8d55142655e89dc Mon Sep 17 00:00:00 2001 From: AlexPasqua Date: Sun, 23 Jul 2023 14:58:57 +0200 Subject: [PATCH 02/17] Added "add" method + other improvements --- .../dqn/prioritized_replay_buffer.py | 34 +++++++++++++++++-- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/stable_baselines3/dqn/prioritized_replay_buffer.py b/stable_baselines3/dqn/prioritized_replay_buffer.py index e85f6da6e..cbf6be9f6 100644 --- a/stable_baselines3/dqn/prioritized_replay_buffer.py +++ b/stable_baselines3/dqn/prioritized_replay_buffer.py @@ -1,6 +1,8 @@ import random -from typing import Optional, Union +from typing import Any, Dict, List, Optional, Union +import warnings from gymnasium import spaces +import numpy as np import torch as th from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.type_aliases import ReplayBufferSamples @@ -25,7 +27,7 @@ def __init__(self, size: int): def p_total(self): return self.nodes[0] - def update(self, data_idx, value): + def update(self, data_idx: int, value: float): idx = data_idx + self.size - 1 # child index in tree array change = value - self.nodes[idx] @@ -36,7 +38,7 @@ def update(self, data_idx, value): self.nodes[parent] += change parent = (parent - 1) // 2 - def add(self, value, data): + def add(self, value: float, data): self.data[self.count] = data self.update(self.count, value) @@ -88,9 +90,14 @@ def __init__( n_envs: int = 1, alpha: float = 0.6, beta: float = 0.4, + optimize_memory_usage: bool = False, ): super().__init__(buffer_size, observation_space, action_space, device, n_envs) + # TODO: check this + if optimize_memory_usage: + warnings.warn("PrioritizedReplayBuffer does not support optimize_memory_usage=True during sampling") + # PER params self.eps = 1e-8 # minimal priority, prevents zero probabilities self.alpha = alpha # determines how much prioritization is used, alpha = 0 corresponding to the uniform case @@ -99,6 +106,27 @@ def __init__( # SumTree: data structure to store priorities self.tree = SumTree(size=buffer_size) + + self.real_size = 0 + self.count = 0 + + def add(self, + obs: np.ndarray, + next_obs: np.ndarray, + action: np.ndarray, + reward: np.ndarray, + done: np.ndarray, + infos: List[Dict[str, Any]], + ) -> None: + # store transition index with maximum priority in sum tree + self.tree.add(self.max_priority, self.count) + + # update counters + self.count = (self.count + 1) % self.buffer_size + self.real_size = min(self.buffer_size, self.real_size + 1) + + # store transition in the buffer + super().add(obs, next_obs, action, reward, done, infos) def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayBufferSamples: """ From 2b9df33935220c209cc87decc9a4c599122b5010 Mon Sep 17 00:00:00 2001 From: AlexPasqua Date: Sun, 23 Jul 2023 19:34:18 +0200 Subject: [PATCH 03/17] Docstrings, type hints, doc --- docs/misc/changelog.rst | 1 + docs/modules/dqn.rst | 3 +- stable_baselines3/common/buffers.py | 2 +- .../dqn/prioritized_replay_buffer.py | 76 ++++++++++++------- 4 files changed, 54 insertions(+), 28 deletions(-) diff --git 
a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 7bb1f16a1..96041127f 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -15,6 +15,7 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ - Added Python 3.11 support +- Prioritized Experience Replay for DQN (@AlexPasqua) `SB3-Contrib`_ ^^^^^^^^^^^^^^ diff --git a/docs/modules/dqn.rst b/docs/modules/dqn.rst index 85d486661..894e201fa 100644 --- a/docs/modules/dqn.rst +++ b/docs/modules/dqn.rst @@ -27,7 +27,8 @@ Notes - Further reference: https://www.nature.com/articles/nature14236 .. note:: - This implementation provides only vanilla Deep Q-Learning and has no extensions such as Double-DQN, Dueling-DQN and Prioritized Experience Replay. + This implementation does **not** provide Rainbow DQN, but only vanilla Deep Q-Learning. + There are no extensions such as Double-DQN or Dueling-DQN, with the only exception being Prioritized Experience Replay. Can I use? diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index fe633e1af..c70d0a664 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -279,7 +279,7 @@ def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayB :param batch_size: Number of element to sample :param env: associated gym VecEnv to normalize the observations/rewards when sampling - :return: + :return: a batch of sampled experiences from the buffer. """ if not self.optimize_memory_usage: return super().sample(batch_size=batch_size, env=env) diff --git a/stable_baselines3/dqn/prioritized_replay_buffer.py b/stable_baselines3/dqn/prioritized_replay_buffer.py index cbf6be9f6..6854db44f 100644 --- a/stable_baselines3/dqn/prioritized_replay_buffer.py +++ b/stable_baselines3/dqn/prioritized_replay_buffer.py @@ -1,9 +1,11 @@ import random -from typing import Any, Dict, List, Optional, Union import warnings -from gymnasium import spaces +from typing import Any, Dict, List, Optional, Union + import numpy as np import torch as th +from gymnasium import spaces + from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.type_aliases import ReplayBufferSamples from stable_baselines3.common.vec_env.vec_normalize import VecNormalize @@ -16,6 +18,7 @@ class SumTree: :param size: Max number of element in the buffer. """ + def __init__(self, size: int): self.nodes = [0] * (2 * size - 1) self.data = [None] * size @@ -24,34 +27,53 @@ def __init__(self, size: int): self.real_size = 0 @property - def p_total(self): + def p_total(self) -> float: + """ + Returns the root node value, which represents the total sum of all priorities in the tree. + + :return: Total sum of all priorities in the tree. + """ return self.nodes[0] def update(self, data_idx: int, value: float): + """ + Update the priority of a leaf node. + + :param data_idx: Index of the leaf node to update. + :param value: New priority value. + """ idx = data_idx + self.size - 1 # child index in tree array change = value - self.nodes[idx] - self.nodes[idx] = value - parent = (idx - 1) // 2 while parent >= 0: self.nodes[parent] += change parent = (parent - 1) // 2 - def add(self, value: float, data): + def add(self, value: float, data: int): + """ + Add a new transition with priority value. + + :param value: Priority value. + :param data: Transition data. 
+ """ self.data[self.count] = data self.update(self.count, value) - self.count = (self.count + 1) % self.size self.real_size = min(self.size, self.real_size + 1) - def get(self, cumsum): + def get(self, cumsum) -> tuple[int, float, Any]: + """ + Get a leaf node index, its priority value and transition data by cumsum value. + + :param cumsum: Cumulative sum value. + :return: Leaf node index, its priority value and transition data. + """ assert cumsum <= self.p_total idx = 0 while 2 * idx + 1 < len(self.nodes): - left, right = 2*idx + 1, 2*idx + 2 - + left, right = 2 * idx + 1, 2 * idx + 2 if cumsum <= self.nodes[left]: idx = left else: @@ -59,7 +81,6 @@ def get(self, cumsum): cumsum = cumsum - self.nodes[left] data_idx = idx - self.size + 1 - return data_idx, self.nodes[idx], self.data[data_idx] def __repr__(self): @@ -101,7 +122,7 @@ def __init__( # PER params self.eps = 1e-8 # minimal priority, prevents zero probabilities self.alpha = alpha # determines how much prioritization is used, alpha = 0 corresponding to the uniform case - self.beta = beta # determines the amount of importance-sampling correction, beta = 1 fully compensate for the non-uniform probabilities + self.beta = beta # determines the amount of importance-sampling correction self.max_priority = self.eps # priority for new samples, init as eps # SumTree: data structure to store priorities @@ -109,8 +130,9 @@ def __init__( self.real_size = 0 self.count = 0 - - def add(self, + + def add( + self, obs: np.ndarray, next_obs: np.ndarray, action: np.ndarray, @@ -118,16 +140,26 @@ def add(self, done: np.ndarray, infos: List[Dict[str, Any]], ) -> None: + """ + Add a new transition to the buffer. + + :param obs: Starting observation of the transition to be stored. + :param next_obs: Destination observation of the transition to be stored. + :param action: Action performed in the transition to be stored. + :param reward: Reward received in the transition to be stored. + :param done: Whether the episode was finished after the transition to be stored. + :param infos: Eventual information given by the environment. + """ # store transition index with maximum priority in sum tree self.tree.add(self.max_priority, self.count) # update counters self.count = (self.count + 1) % self.buffer_size self.real_size = min(self.buffer_size, self.real_size + 1) - + # store transition in the buffer super().add(obs, next_obs, action, reward, done, infos) - + def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayBufferSamples: """ Sample elements from the prioritized replay buffer. @@ -135,7 +167,7 @@ def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayB :param batch_size: Number of element to sample :param env:associated gym VecEnv to normalize the observations/rewards when sampling - :return: + :return: a batch of sampled experiences from the buffer. """ assert self.buffer_size >= batch_size, "The buffer contains less samples than the batch size requires." @@ -165,14 +197,6 @@ def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayB # where p_i > 0 is the priority of transition i. probs = priorities / self.tree.p_total - # The estimation of the expected value with stochastic updates relies on those updates corresponding - # to the same distribution as its expectation. 
Prioritized replay introduces bias because it changes this - # distribution in an uncontrolled fashion, and therefore changes the solution that the estimates will - # converge to (even if the policy and state distribution are fixed). We can correct this bias by using - # importance-sampling (IS) weights w_i = (1/N * 1/P(i))^β that fully compensates for the non-uniform - # probabilities P(i) if β = 1. These weights can be folded into the Q-learning update by using w_i * δ_i - # instead of δ_i (this is thus weighted IS, not ordinary IS, see e.g. Mahmood et al., 2014). - # Importance sampling weights. # All weights w_i were scaled so that max_i w_i = 1. weights = (self.real_size * probs) ** -self.beta @@ -185,4 +209,4 @@ def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayB self.dones[sample_idxs], self.rewards[sample_idxs], ) - return batch, weights, tree_idxs + return batch From aee1d30c2aa8c3c669fe3dbe400c6c4317609b13 Mon Sep 17 00:00:00 2001 From: AlexPasqua Date: Sun, 6 Aug 2023 16:51:18 +0200 Subject: [PATCH 04/17] FIxed for pytype checks (partially) --- .../dqn/prioritized_replay_buffer.py | 29 ++++++++++++------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/stable_baselines3/dqn/prioritized_replay_buffer.py b/stable_baselines3/dqn/prioritized_replay_buffer.py index 6854db44f..457c24060 100644 --- a/stable_baselines3/dqn/prioritized_replay_buffer.py +++ b/stable_baselines3/dqn/prioritized_replay_buffer.py @@ -20,8 +20,8 @@ class SumTree: """ def __init__(self, size: int): - self.nodes = [0] * (2 * size - 1) - self.data = [None] * size + self.nodes = th.zeros(2 * size - 1) + self.data = th.empty(size) self.size = size self.count = 0 self.real_size = 0 @@ -33,7 +33,7 @@ def p_total(self) -> float: :return: Total sum of all priorities in the tree. """ - return self.nodes[0] + return self.nodes[0].item() def update(self, data_idx: int, value: float): """ @@ -62,7 +62,7 @@ def add(self, value: float, data: int): self.count = (self.count + 1) % self.size self.real_size = min(self.size, self.real_size + 1) - def get(self, cumsum) -> tuple[int, float, Any]: + def get(self, cumsum) -> tuple[int, float, th.Tensor]: """ Get a leaf node index, its priority value and transition data by cumsum value. @@ -81,7 +81,7 @@ def get(self, cumsum) -> tuple[int, float, Any]: cumsum = cumsum - self.nodes[left] data_idx = idx - self.size + 1 - return data_idx, self.nodes[idx], self.data[data_idx] + return data_idx, self.nodes[idx].item(), self.data[data_idx] def __repr__(self): return f"SumTree(nodes={self.nodes.__repr__()}, data={self.data.__repr__()})" @@ -191,7 +191,7 @@ def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayB priorities[i] = priority tree_idxs.append(tree_idx) - sample_idxs.append(sample_idx) + sample_idxs.append(int(sample_idx.item())) # probability of sampling transition i as P(i) = p_i^alpha / \sum_{k} p_k^alpha # where p_i > 0 is the priority of transition i. 
@@ -202,11 +202,18 @@ def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayB weights = (self.real_size * probs) ** -self.beta weights = weights / weights.max() - batch = ReplayBufferSamples( - self.observations[sample_idxs], - self.actions[sample_idxs], - self.next_observations[sample_idxs], + env_indices = np.random.randint(0, high=self.n_envs, size=(len(sample_idxs),)) + + if self.optimize_memory_usage: + next_obs = self._normalize_obs(self.observations[(np.array(sample_idxs) + 1) % self.buffer_size, env_indices, :], env) + else: + next_obs = self._normalize_obs(self.next_observations[sample_idxs, env_indices, :], env) + + batch = ( + self._normalize_obs(self.observations[sample_idxs, env_indices, :], env), + self.actions[sample_idxs, env_indices, :], + next_obs, self.dones[sample_idxs], self.rewards[sample_idxs], ) - return batch + return ReplayBufferSamples(*tuple(map(self.to_torch, batch))) From c51b173118990f61763b3e7ea51ead997ad3b7b7 Mon Sep 17 00:00:00 2001 From: AlexPasqua Date: Sun, 6 Aug 2023 16:59:17 +0200 Subject: [PATCH 05/17] make format --- stable_baselines3/dqn/prioritized_replay_buffer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/stable_baselines3/dqn/prioritized_replay_buffer.py b/stable_baselines3/dqn/prioritized_replay_buffer.py index 457c24060..76d120068 100644 --- a/stable_baselines3/dqn/prioritized_replay_buffer.py +++ b/stable_baselines3/dqn/prioritized_replay_buffer.py @@ -205,7 +205,9 @@ def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayB env_indices = np.random.randint(0, high=self.n_envs, size=(len(sample_idxs),)) if self.optimize_memory_usage: - next_obs = self._normalize_obs(self.observations[(np.array(sample_idxs) + 1) % self.buffer_size, env_indices, :], env) + next_obs = self._normalize_obs( + self.observations[(np.array(sample_idxs) + 1) % self.buffer_size, env_indices, :], env + ) else: next_obs = self._normalize_obs(self.next_observations[sample_idxs, env_indices, :], env) From 18c9d28c4b60e46c0aaeae1729ec1d1f57c3e3c5 Mon Sep 17 00:00:00 2001 From: AlexPasqua Date: Sun, 6 Aug 2023 19:53:07 +0200 Subject: [PATCH 06/17] Made pytype ignore type on PER's sample method --- stable_baselines3/dqn/prioritized_replay_buffer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable_baselines3/dqn/prioritized_replay_buffer.py b/stable_baselines3/dqn/prioritized_replay_buffer.py index 76d120068..0b197351d 100644 --- a/stable_baselines3/dqn/prioritized_replay_buffer.py +++ b/stable_baselines3/dqn/prioritized_replay_buffer.py @@ -218,4 +218,4 @@ def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayB self.dones[sample_idxs], self.rewards[sample_idxs], ) - return ReplayBufferSamples(*tuple(map(self.to_torch, batch))) + return ReplayBufferSamples(*tuple(map(self.to_torch, batch))) # type: ignore From fb33732d7ae6de6104d28bb740f9553eed5665bc Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 29 Sep 2023 10:31:48 +0200 Subject: [PATCH 07/17] Switch to numpy for the backend --- .../dqn/prioritized_replay_buffer.py | 34 ++++++++----------- 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/stable_baselines3/dqn/prioritized_replay_buffer.py b/stable_baselines3/dqn/prioritized_replay_buffer.py index 0b197351d..d3f69f16a 100644 --- a/stable_baselines3/dqn/prioritized_replay_buffer.py +++ b/stable_baselines3/dqn/prioritized_replay_buffer.py @@ -1,6 +1,5 @@ -import random import warnings -from typing import Any, Dict, List, Optional, 
Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import torch as th @@ -16,13 +15,13 @@ class SumTree: SumTree data structure for Prioritized Replay Buffer. This code is inspired by: https://github.com/Howuhh/prioritized_experience_replay - :param size: Max number of element in the buffer. + :param buffer_size: Max number of element in the buffer. """ - def __init__(self, size: int): - self.nodes = th.zeros(2 * size - 1) - self.data = th.empty(size) - self.size = size + def __init__(self, buffer_size: int) -> None: + self.nodes = np.zeros((2 * buffer_size - 1)) + self.data = np.zeros(buffer_size) + self.size = buffer_size self.count = 0 self.real_size = 0 @@ -35,7 +34,7 @@ def p_total(self) -> float: """ return self.nodes[0].item() - def update(self, data_idx: int, value: float): + def update(self, data_idx: int, value: float) -> None: """ Update the priority of a leaf node. @@ -50,7 +49,7 @@ def update(self, data_idx: int, value: float): self.nodes[parent] += change parent = (parent - 1) // 2 - def add(self, value: float, data: int): + def add(self, value: float, data: int) -> None: """ Add a new transition with priority value. @@ -62,7 +61,7 @@ def add(self, value: float, data: int): self.count = (self.count + 1) % self.size self.real_size = min(self.size, self.real_size + 1) - def get(self, cumsum) -> tuple[int, float, th.Tensor]: + def get(self, cumsum) -> Tuple[int, float, th.Tensor]: """ Get a leaf node index, its priority value and transition data by cumsum value. @@ -83,7 +82,7 @@ def get(self, cumsum) -> tuple[int, float, th.Tensor]: data_idx = idx - self.size + 1 return data_idx, self.nodes[idx].item(), self.data[data_idx] - def __repr__(self): + def __repr__(self) -> str: return f"SumTree(nodes={self.nodes.__repr__()}, data={self.data.__repr__()})" @@ -115,18 +114,15 @@ def __init__( ): super().__init__(buffer_size, observation_space, action_space, device, n_envs) - # TODO: check this - if optimize_memory_usage: - warnings.warn("PrioritizedReplayBuffer does not support optimize_memory_usage=True during sampling") + assert optimize_memory_usage is False, "PrioritizedReplayBuffer doesn't support optimize_memory_usage=True" - # PER params self.eps = 1e-8 # minimal priority, prevents zero probabilities self.alpha = alpha # determines how much prioritization is used, alpha = 0 corresponding to the uniform case self.beta = beta # determines the amount of importance-sampling correction self.max_priority = self.eps # priority for new samples, init as eps # SumTree: data structure to store priorities - self.tree = SumTree(size=buffer_size) + self.tree = SumTree(buffer_size=buffer_size) self.real_size = 0 self.count = 0 @@ -172,7 +168,7 @@ def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayB assert self.buffer_size >= batch_size, "The buffer contains less samples than the batch size requires." sample_idxs, tree_idxs = [], [] - priorities = th.empty(batch_size, 1, dtype=th.float) + priorities = np.zeros((batch_size, 1)) # To sample a minibatch of size k, the range [0, p_total] is divided equally into k ranges. # Next, a value is uniformly sampled from each range. 
Finally the transitions that correspond @@ -183,7 +179,7 @@ def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayB a, b = segment * i, segment * (i + 1) # uniformely sample a value from the current segment - cumsum = random.uniform(a, b) + cumsum = np.random.uniform(a, b) # tree_idx is a index of a sample in the tree, needed further to update priorities # sample_idx is a sample index in buffer, needed further to sample actual transitions @@ -218,4 +214,4 @@ def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayB self.dones[sample_idxs], self.rewards[sample_idxs], ) - return ReplayBufferSamples(*tuple(map(self.to_torch, batch))) # type: ignore + return ReplayBufferSamples(*tuple(map(self.to_torch, batch))) # type: ignore[arg-type] From f984e5c7a2e6e0c517b2e402fd90e66fa9c28be3 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 29 Sep 2023 10:37:44 +0200 Subject: [PATCH 08/17] Move to common and add tests --- .../{dqn => common}/prioritized_replay_buffer.py | 3 +-- tests/test_buffers.py | 8 ++++++-- tests/test_run.py | 5 ++++- 3 files changed, 11 insertions(+), 5 deletions(-) rename stable_baselines3/{dqn => common}/prioritized_replay_buffer.py (99%) diff --git a/stable_baselines3/dqn/prioritized_replay_buffer.py b/stable_baselines3/common/prioritized_replay_buffer.py similarity index 99% rename from stable_baselines3/dqn/prioritized_replay_buffer.py rename to stable_baselines3/common/prioritized_replay_buffer.py index d3f69f16a..852c77a49 100644 --- a/stable_baselines3/dqn/prioritized_replay_buffer.py +++ b/stable_baselines3/common/prioritized_replay_buffer.py @@ -1,4 +1,3 @@ -import warnings from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np @@ -19,7 +18,7 @@ class SumTree: """ def __init__(self, buffer_size: int) -> None: - self.nodes = np.zeros((2 * buffer_size - 1)) + self.nodes = np.zeros(2 * buffer_size - 1) self.data = np.zeros(buffer_size) self.size = buffer_size self.count = 0 diff --git a/tests/test_buffers.py b/tests/test_buffers.py index e7d4a1c57..84aa474b3 100644 --- a/tests/test_buffers.py +++ b/tests/test_buffers.py @@ -7,6 +7,7 @@ from stable_baselines3.common.buffers import DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.env_util import make_vec_env +from stable_baselines3.common.prioritized_replay_buffer import PrioritizedReplayBuffer from stable_baselines3.common.type_aliases import DictReplayBufferSamples, ReplayBufferSamples from stable_baselines3.common.utils import get_device from stable_baselines3.common.vec_env import VecNormalize @@ -108,7 +109,9 @@ def test_replay_buffer_normalization(replay_buffer_cls): assert np.allclose(sample.rewards.mean(0), np.zeros(1), atol=1) -@pytest.mark.parametrize("replay_buffer_cls", [DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer]) +@pytest.mark.parametrize( + "replay_buffer_cls", [DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer, PrioritizedReplayBuffer] +) @pytest.mark.parametrize("device", ["cpu", "cuda", "auto"]) def test_device_buffer(replay_buffer_cls, device): if device == "cuda" and not th.cuda.is_available(): @@ -119,6 +122,7 @@ def test_device_buffer(replay_buffer_cls, device): DictRolloutBuffer: DummyDictEnv, ReplayBuffer: DummyEnv, DictReplayBuffer: DummyDictEnv, + PrioritizedReplayBuffer: DummyEnv, }[replay_buffer_cls] env = make_vec_env(env) @@ -139,7 +143,7 @@ def 
test_device_buffer(replay_buffer_cls, device): # Get data from the buffer if replay_buffer_cls in [RolloutBuffer, DictRolloutBuffer]: data = buffer.get(50) - elif replay_buffer_cls in [ReplayBuffer, DictReplayBuffer]: + elif replay_buffer_cls in [ReplayBuffer, DictReplayBuffer, PrioritizedReplayBuffer]: data = buffer.sample(50) # Check that all data are on the desired device diff --git a/tests/test_run.py b/tests/test_run.py index 31c7b956e..a7a30eed5 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -5,6 +5,7 @@ from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3 from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise +from stable_baselines3.common.prioritized_replay_buffer import PrioritizedReplayBuffer normal_action_noise = NormalActionNoise(np.zeros(1), 0.1 * np.ones(1)) @@ -100,7 +101,8 @@ def test_n_critics(n_critics): model.learn(total_timesteps=200) -def test_dqn(): +@pytest.mark.parametrize("replay_buffer_class", [None, PrioritizedReplayBuffer]) +def test_dqn(replay_buffer_class): model = DQN( "MlpPolicy", "CartPole-v1", @@ -109,6 +111,7 @@ def test_dqn(): buffer_size=500, learning_rate=3e-4, verbose=1, + replay_buffer_class=replay_buffer_class, ) model.learn(total_timesteps=200) From 5edf8bf42a6d4ccae902180e121d0e9a4bad84ad Mon Sep 17 00:00:00 2001 From: AlexPasqua Date: Sat, 30 Sep 2023 19:40:46 +0200 Subject: [PATCH 09/17] Updated DQN docs Added list of rainbow extensions, specifying which ones are currently implemented in the library --- docs/modules/dqn.rst | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/docs/modules/dqn.rst b/docs/modules/dqn.rst index 894e201fa..2a06c91d5 100644 --- a/docs/modules/dqn.rst +++ b/docs/modules/dqn.rst @@ -28,7 +28,7 @@ Notes .. note:: This implementation does **not** provide Rainbow DQN, but only vanilla Deep Q-Learning. - There are no extensions such as Double-DQN or Dueling-DQN, with the only exception being Prioritized Experience Replay. + Currently, there are no extensions such as Double-DQN or Dueling-DQN, with the only exception being Prioritized Experience Replay. Can I use? 
@@ -49,6 +49,15 @@ MultiBinary ❌ ✔️ Dict ❌ ✔️️ ============= ====== =========== +- Rainbow DQN extensions: + + - Double Q-Learning: ❌ + - Prioritized Experience Replay: ✔️ + - Dueling Networks: ❌ + - Multi-step Learning: ❌ + - Distributional RL: ❌ + - Noisy Nets: ❌ + Example ------- From 2f76038c9eb398f28ebb7d346b0ef34747192bd5 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 2 Oct 2023 17:04:24 +0200 Subject: [PATCH 10/17] Update doc --- docs/misc/changelog.rst | 2 +- docs/modules/dqn.rst | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index d89f5da1d..8d857efba 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -14,6 +14,7 @@ New Features: ^^^^^^^^^^^^^ - Improved error message of the ``env_checker`` for env wrongly detected as GoalEnv (``compute_reward()`` is defined) - Improved error message when mixing Gym API with VecEnv API (see GH#1694) +- Added Prioritized Experience Replay for DQN (@AlexPasqua) Bug Fixes: ^^^^^^^^^^ @@ -71,7 +72,6 @@ New Features: ^^^^^^^^^^^^^ - Added Python 3.11 support - Added Gymnasium 0.29 support (@pseudo-rnd-thoughts) -- Prioritized Experience Replay for DQN (@AlexPasqua) `SB3-Contrib`_ ^^^^^^^^^^^^^^ diff --git a/docs/modules/dqn.rst b/docs/modules/dqn.rst index 2a06c91d5..401fc1a47 100644 --- a/docs/modules/dqn.rst +++ b/docs/modules/dqn.rst @@ -27,9 +27,9 @@ Notes - Further reference: https://www.nature.com/articles/nature14236 .. note:: - This implementation does **not** provide Rainbow DQN, but only vanilla Deep Q-Learning. - Currently, there are no extensions such as Double-DQN or Dueling-DQN, with the only exception being Prioritized Experience Replay. + This implementation provides only vanilla Deep Q-Learning and has no extensions such as Double-DQN or Dueling-DQN. + To use Prioritized Experience Replay, you need to pass it via the ``replay_buffer_class`` argument. Can I use?
---------- @@ -52,10 +52,10 @@ Dict ❌ ✔️️ - Rainbow DQN extensions: - Double Q-Learning: ❌ - - Prioritized Experience Replay: ✔️ + - Prioritized Experience Replay: ✔️ (``from stable_baselines3.common.prioritized_replay_buffer import PrioritizedReplayBuffer``) - Dueling Networks: ❌ - Multi-step Learning: ❌ - - Distributional RL: ❌ + - Distributional RL: ✔️ (``QR-DQN`` is implemented in the SB3 contrib repo) - Noisy Nets: ❌ From 42f2f4ad40337667636bc4cc0987bf827d970151 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 2 Oct 2023 17:04:31 +0200 Subject: [PATCH 11/17] Rename things to be consistent with buffers.py --- .../common/prioritized_replay_buffer.py | 104 +++++++++--------- 1 file changed, 54 insertions(+), 50 deletions(-) diff --git a/stable_baselines3/common/prioritized_replay_buffer.py b/stable_baselines3/common/prioritized_replay_buffer.py index 852c77a49..93990144b 100644 --- a/stable_baselines3/common/prioritized_replay_buffer.py +++ b/stable_baselines3/common/prioritized_replay_buffer.py @@ -20,12 +20,20 @@ class SumTree: def __init__(self, buffer_size: int) -> None: self.nodes = np.zeros(2 * buffer_size - 1) self.data = np.zeros(buffer_size) - self.size = buffer_size - self.count = 0 - self.real_size = 0 + self.buffer_size = buffer_size + self.pos = 0 + self.full = False + + def size(self) -> int: + """ + :return: The current size of the SumTree + """ + if self.full: + return self.buffer_size + return self.pos @property - def p_total(self) -> float: + def total_sum(self) -> float: """ Returns the root node value, which represents the total sum of all priorities in the tree. @@ -40,7 +48,7 @@ def update(self, data_idx: int, value: float) -> None: :param data_idx: Index of the leaf node to update. :param value: New priority value. """ - idx = data_idx + self.size - 1 # child index in tree array + idx = data_idx + self.buffer_size - 1 # child index in tree array change = value - self.nodes[idx] self.nodes[idx] = value parent = (idx - 1) // 2 @@ -55,34 +63,36 @@ def add(self, value: float, data: int) -> None: :param value: Priority value. :param data: Transition data. """ - self.data[self.count] = data - self.update(self.count, value) - self.count = (self.count + 1) % self.size - self.real_size = min(self.size, self.real_size + 1) - - def get(self, cumsum) -> Tuple[int, float, th.Tensor]: + self.data[self.pos] = data + self.update(self.pos, value) + self.pos += 1 + if self.pos == self.buffer_size: + self.full = True + self.pos = 0 + + def get(self, cumulative_sum: float) -> Tuple[int, float, th.Tensor]: """ - Get a leaf node index, its priority value and transition data by cumsum value. + Get a leaf node index, its priority value and transition data by cumulative_sum value. - :param cumsum: Cumulative sum value. + :param cumulative_sum: Cumulative sum value. :return: Leaf node index, its priority value and transition data. 
""" - assert cumsum <= self.p_total + assert cumulative_sum <= self.total_sum idx = 0 while 2 * idx + 1 < len(self.nodes): left, right = 2 * idx + 1, 2 * idx + 2 - if cumsum <= self.nodes[left]: + if cumulative_sum <= self.nodes[left]: idx = left else: idx = right - cumsum = cumsum - self.nodes[left] + cumulative_sum = cumulative_sum - self.nodes[left] - data_idx = idx - self.size + 1 + data_idx = idx - self.buffer_size + 1 return data_idx, self.nodes[idx].item(), self.data[data_idx] def __repr__(self) -> str: - return f"SumTree(nodes={self.nodes.__repr__()}, data={self.data.__repr__()})" + return f"SumTree(nodes={self.nodes!r}, data={self.data!r})" class PrioritizedReplayBuffer(ReplayBuffer): @@ -96,7 +106,7 @@ class PrioritizedReplayBuffer(ReplayBuffer): :param action_space: Action space :param device: PyTorch device :param n_envs: Number of parallel environments - :param alpha: How much prioritization is used (0 - no prioritization, 1 - full prioritization) + :param alpha: How much prioritization is used (0 - no prioritization aka uniform case, 1 - full prioritization) :param beta: To what degree to use importance weights (0 - no corrections, 1 - full correction) """ @@ -115,17 +125,14 @@ def __init__( assert optimize_memory_usage is False, "PrioritizedReplayBuffer doesn't support optimize_memory_usage=True" - self.eps = 1e-8 # minimal priority, prevents zero probabilities - self.alpha = alpha # determines how much prioritization is used, alpha = 0 corresponding to the uniform case - self.beta = beta # determines the amount of importance-sampling correction - self.max_priority = self.eps # priority for new samples, init as eps + self.min_priority = 1e-8 # minimal priority, prevents zero probabilities + self.alpha = alpha + self.beta = beta + self.max_priority = self.min_priority # priority for new samples, init as eps # SumTree: data structure to store priorities self.tree = SumTree(buffer_size=buffer_size) - self.real_size = 0 - self.count = 0 - def add( self, obs: np.ndarray, @@ -146,11 +153,7 @@ def add( :param infos: Eventual information given by the environment. """ # store transition index with maximum priority in sum tree - self.tree.add(self.max_priority, self.count) - - # update counters - self.count = (self.count + 1) % self.buffer_size - self.real_size = min(self.buffer_size, self.real_size + 1) + self.tree.add(self.max_priority, self.pos) # store transition in the buffer super().add(obs, next_obs, action, reward, done, infos) @@ -166,51 +169,52 @@ def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayB """ assert self.buffer_size >= batch_size, "The buffer contains less samples than the batch size requires." - sample_idxs, tree_idxs = [], [] + tree_indices = np.zeros(batch_size, dtype=np.uint32) priorities = np.zeros((batch_size, 1)) + sample_indices = np.zeros(batch_size, dtype=np.uint32) - # To sample a minibatch of size k, the range [0, p_total] is divided equally into k ranges. + # To sample a minibatch of size k, the range [0, total_sum] is divided equally into k ranges. # Next, a value is uniformly sampled from each range. Finally the transitions that correspond # to each of these sampled values are retrieved from the tree. 
- segment = self.tree.p_total / batch_size - for i in range(batch_size): + segment_size = self.tree.total_sum / batch_size + for batch_idx in range(batch_size): # extremes of the current segment - a, b = segment * i, segment * (i + 1) + start, end = segment_size * batch_idx, segment_size * (batch_idx + 1) # uniformely sample a value from the current segment - cumsum = np.random.uniform(a, b) + cumulative_sum = np.random.uniform(start, end) # tree_idx is a index of a sample in the tree, needed further to update priorities # sample_idx is a sample index in buffer, needed further to sample actual transitions - tree_idx, priority, sample_idx = self.tree.get(cumsum) + tree_idx, priority, sample_idx = self.tree.get(cumulative_sum) - priorities[i] = priority - tree_idxs.append(tree_idx) - sample_idxs.append(int(sample_idx.item())) + tree_indices[batch_idx] = tree_idx + priorities[batch_idx] = priority + sample_indices[batch_idx] = sample_idx # probability of sampling transition i as P(i) = p_i^alpha / \sum_{k} p_k^alpha # where p_i > 0 is the priority of transition i. - probs = priorities / self.tree.p_total + probs = priorities / self.tree.total_sum # Importance sampling weights. # All weights w_i were scaled so that max_i w_i = 1. - weights = (self.real_size * probs) ** -self.beta + weights = (self.size() * probs) ** -self.beta weights = weights / weights.max() - env_indices = np.random.randint(0, high=self.n_envs, size=(len(sample_idxs),)) + env_indices = np.random.randint(0, high=self.n_envs, size=(batch_size,)) if self.optimize_memory_usage: next_obs = self._normalize_obs( - self.observations[(np.array(sample_idxs) + 1) % self.buffer_size, env_indices, :], env + self.observations[(sample_indices + 1) % self.buffer_size, env_indices, :], env ) else: - next_obs = self._normalize_obs(self.next_observations[sample_idxs, env_indices, :], env) + next_obs = self._normalize_obs(self.next_observations[sample_indices, env_indices, :], env) batch = ( - self._normalize_obs(self.observations[sample_idxs, env_indices, :], env), - self.actions[sample_idxs, env_indices, :], + self._normalize_obs(self.observations[sample_indices, env_indices, :], env), + self.actions[sample_indices, env_indices, :], next_obs, - self.dones[sample_idxs], - self.rewards[sample_idxs], + self.dones[sample_indices], + self.rewards[sample_indices], ) return ReplayBufferSamples(*tuple(map(self.to_torch, batch))) # type: ignore[arg-type] From 007105f8b6b73697aac8740b3d71726c2d90fd9d Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 2 Oct 2023 18:12:14 +0200 Subject: [PATCH 12/17] Rename variables and add priority update --- .../common/prioritized_replay_buffer.py | 82 +++++++++++-------- stable_baselines3/common/type_aliases.py | 4 + stable_baselines3/dqn/dqn.py | 15 +++- 3 files changed, 67 insertions(+), 34 deletions(-) diff --git a/stable_baselines3/common/prioritized_replay_buffer.py b/stable_baselines3/common/prioritized_replay_buffer.py index 93990144b..d209f497d 100644 --- a/stable_baselines3/common/prioritized_replay_buffer.py +++ b/stable_baselines3/common/prioritized_replay_buffer.py @@ -19,19 +19,12 @@ class SumTree: def __init__(self, buffer_size: int) -> None: self.nodes = np.zeros(2 * buffer_size - 1) + # The data array stores transition indices self.data = np.zeros(buffer_size) self.buffer_size = buffer_size self.pos = 0 self.full = False - def size(self) -> int: - """ - :return: The current size of the SumTree - """ - if self.full: - return self.buffer_size - return self.pos - @property def total_sum(self) -> 
float: """ @@ -41,14 +34,14 @@ def total_sum(self) -> float: """ return self.nodes[0].item() - def update(self, data_idx: int, value: float) -> None: + def update(self, leaf_node_idx: int, value: float) -> None: """ Update the priority of a leaf node. - :param data_idx: Index of the leaf node to update. + :param leaf_node_idx: Index of the leaf node to update. :param value: New priority value. """ - idx = data_idx + self.buffer_size - 1 # child index in tree array + idx = leaf_node_idx + self.buffer_size - 1 # child index in tree array change = value - self.nodes[idx] self.nodes[idx] = value parent = (idx - 1) // 2 @@ -58,24 +51,25 @@ def update(self, data_idx: int, value: float) -> None: def add(self, value: float, data: int) -> None: """ - Add a new transition with priority value. + Add a new transition with priority value, + it adds a new leaf node and update cumulative sum. :param value: Priority value. - :param data: Transition data. + :param data: Data for the new leaf node, storing transition index + in the case of the prioritized replay buffer. """ + # Note: transition_indices should be constant + # as the replay buffer already updates a pointer self.data[self.pos] = data self.update(self.pos, value) - self.pos += 1 - if self.pos == self.buffer_size: - self.full = True - self.pos = 0 + self.pos = (self.pos + 1) % self.buffer_size def get(self, cumulative_sum: float) -> Tuple[int, float, th.Tensor]: """ - Get a leaf node index, its priority value and transition data by cumulative_sum value. + Get a leaf node index, its priority value and transition index by cumulative_sum value. :param cumulative_sum: Cumulative sum value. - :return: Leaf node index, its priority value and transition data. + :return: Leaf node index, its priority value and transition index. """ assert cumulative_sum <= self.total_sum @@ -88,8 +82,8 @@ def get(self, cumulative_sum: float) -> Tuple[int, float, th.Tensor]: idx = right cumulative_sum = cumulative_sum - self.nodes[left] - data_idx = idx - self.buffer_size + 1 - return data_idx, self.nodes[idx].item(), self.data[data_idx] + leaf_node_idx = idx - self.buffer_size + 1 + return leaf_node_idx, self.nodes[idx].item(), self.data[leaf_node_idx] def __repr__(self) -> str: return f"SumTree(nodes={self.nodes!r}, data={self.data!r})" @@ -97,7 +91,7 @@ def __repr__(self) -> str: class PrioritizedReplayBuffer(ReplayBuffer): """ - Prioritized Replay Buffer. + Prioritized Replay Buffer (proportional priorities version). Paper: https://arxiv.org/abs/1511.05952 This code is inspired by: https://github.com/Howuhh/prioritized_experience_replay @@ -108,6 +102,8 @@ class PrioritizedReplayBuffer(ReplayBuffer): :param n_envs: Number of parallel environments :param alpha: How much prioritization is used (0 - no prioritization aka uniform case, 1 - full prioritization) :param beta: To what degree to use importance weights (0 - no corrections, 1 - full correction) + :param min_priority: Minimum priority, prevents zero probabilities, so that all samples + always have a non-zero probability to be sampled. 
""" def __init__( @@ -120,12 +116,13 @@ def __init__( alpha: float = 0.6, beta: float = 0.4, optimize_memory_usage: bool = False, + min_priority: float = 1e-8, ): super().__init__(buffer_size, observation_space, action_space, device, n_envs) assert optimize_memory_usage is False, "PrioritizedReplayBuffer doesn't support optimize_memory_usage=True" - self.min_priority = 1e-8 # minimal priority, prevents zero probabilities + self.min_priority = 1e-8 self.alpha = alpha self.beta = beta self.max_priority = self.min_priority # priority for new samples, init as eps @@ -169,7 +166,7 @@ def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayB """ assert self.buffer_size >= batch_size, "The buffer contains less samples than the batch size requires." - tree_indices = np.zeros(batch_size, dtype=np.uint32) + leaf_nodes_indices = np.zeros(batch_size, dtype=np.uint32) priorities = np.zeros((batch_size, 1)) sample_indices = np.zeros(batch_size, dtype=np.uint32) @@ -184,11 +181,11 @@ def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayB # uniformely sample a value from the current segment cumulative_sum = np.random.uniform(start, end) - # tree_idx is a index of a sample in the tree, needed further to update priorities + # leaf_node_idx is a index of a sample in the tree, needed further to update priorities # sample_idx is a sample index in buffer, needed further to sample actual transitions - tree_idx, priority, sample_idx = self.tree.get(cumulative_sum) + leaf_node_idx, priority, sample_idx = self.tree.get(cumulative_sum) - tree_indices[batch_idx] = tree_idx + leaf_nodes_indices[batch_idx] = leaf_node_idx priorities[batch_idx] = priority sample_indices[batch_idx] = sample_idx @@ -201,12 +198,12 @@ def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayB weights = (self.size() * probs) ** -self.beta weights = weights / weights.max() - env_indices = np.random.randint(0, high=self.n_envs, size=(batch_size,)) + # TODO: add proper support for multi env + # env_indices = np.random.randint(0, high=self.n_envs, size=(batch_size,)) + env_indices = np.zeros(batch_size, dtype=np.uint32) if self.optimize_memory_usage: - next_obs = self._normalize_obs( - self.observations[(sample_indices + 1) % self.buffer_size, env_indices, :], env - ) + next_obs = self._normalize_obs(self.observations[(sample_indices + 1) % self.buffer_size, env_indices, :], env) else: next_obs = self._normalize_obs(self.next_observations[sample_indices, env_indices, :], env) @@ -216,5 +213,26 @@ def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayB next_obs, self.dones[sample_indices], self.rewards[sample_indices], + weights, ) - return ReplayBufferSamples(*tuple(map(self.to_torch, batch))) # type: ignore[arg-type] + return ReplayBufferSamples(*tuple(map(self.to_torch, batch)), leaf_nodes_indices) # type: ignore[arg-type] + + def update_priorities(self, leaf_nodes_indices: np.ndarray, td_errors: th.Tensor) -> None: + """ + Update transition priorities. + + :param leaf_nodes_indices: Indices for the leaf nodes to update + (correponding to the transitions) + :param td_errors: New priorities, td error in the case of + proportional prioritized replay buffer. 
+ """ + td_errors = td_errors.detach().cpu().numpy().flatten() + + for leaf_node_idx, td_error in zip(leaf_nodes_indices, td_errors): + # Proportional prioritization priority = (abs(td_error) + eps) ^ alpha + # where eps is a small positive constant that prevents the edge-case of transitions not being + # revisited once their error is zero. (Section 3.3) + priority = (abs(td_error) + self.min_priority) ** self.alpha + self.tree.update(leaf_node_idx, priority) + # Update max priority for new samples + self.max_priority = max(self.max_priority, priority) diff --git a/stable_baselines3/common/type_aliases.py b/stable_baselines3/common/type_aliases.py index 2f98ee198..37ae84926 100644 --- a/stable_baselines3/common/type_aliases.py +++ b/stable_baselines3/common/type_aliases.py @@ -48,6 +48,8 @@ class ReplayBufferSamples(NamedTuple): next_observations: th.Tensor dones: th.Tensor rewards: th.Tensor + weights: Union[th.Tensor, float] = 1.0 + leaf_nodes_indices: Optional[np.ndarray] = None class DictReplayBufferSamples(NamedTuple): @@ -56,6 +58,8 @@ class DictReplayBufferSamples(NamedTuple): next_observations: TensorDict dones: th.Tensor rewards: th.Tensor + weights: Union[th.Tensor, float] = 1.0 + leaf_nodes_indices: Optional[np.ndarray] = None class RolloutReturn(NamedTuple): diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index 42e3d0df0..de7a4d7e2 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -9,6 +9,7 @@ from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.policies import BasePolicy +from stable_baselines3.common.prioritized_replay_buffer import PrioritizedReplayBuffer from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, polyak_update from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy, QNetwork @@ -208,8 +209,18 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None: # Retrieve the q-values for the actions from the replay buffer current_q_values = th.gather(current_q_values, dim=1, index=replay_data.actions.long()) - # Compute Huber loss (less sensitive to outliers) - loss = F.smooth_l1_loss(current_q_values, target_q_values) + # Special case when using PrioritizedReplayBuffer (PER) + if isinstance(self.replay_buffer, PrioritizedReplayBuffer): + # TD error in absolute value + td_error = th.abs(current_q_values - target_q_values) + # Weighted Huber loss using importance sampling weights + loss = (replay_data.weights * th.where(td_error < 1.0, 0.5 * td_error**2, td_error - 0.5)).mean() + # Update priorities, they will be proportional to the td error + assert replay_data.leaf_nodes_indices is not None, "Node leaf node indices provided" + self.replay_buffer.update_priorities(replay_data.leaf_nodes_indices, td_error) + else: + # Compute Huber loss (less sensitive to outliers) + loss = F.smooth_l1_loss(current_q_values, target_q_values) losses.append(loss.item()) # Optimize the policy From cc37cba55ae7784a9e0506b0be8a3674df85ab0f Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 2 Oct 2023 18:31:21 +0200 Subject: [PATCH 13/17] Ignore mypy --- stable_baselines3/common/buffers.py | 2 +- stable_baselines3/common/prioritized_replay_buffer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/stable_baselines3/common/buffers.py 
b/stable_baselines3/common/buffers.py index 2d51b1da8..2ef6d0a49 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -321,7 +321,7 @@ def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = Non (self.dones[batch_inds, env_indices] * (1 - self.timeouts[batch_inds, env_indices])).reshape(-1, 1), self._normalize_reward(self.rewards[batch_inds, env_indices].reshape(-1, 1), env), ) - return ReplayBufferSamples(*tuple(map(self.to_torch, data))) + return ReplayBufferSamples(*tuple(map(self.to_torch, data))) # type: ignore[arg-type] @staticmethod def _maybe_cast_dtype(dtype: np.typing.DTypeLike) -> np.typing.DTypeLike: diff --git a/stable_baselines3/common/prioritized_replay_buffer.py b/stable_baselines3/common/prioritized_replay_buffer.py index d209f497d..8d707f6f5 100644 --- a/stable_baselines3/common/prioritized_replay_buffer.py +++ b/stable_baselines3/common/prioritized_replay_buffer.py @@ -215,7 +215,7 @@ def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayB self.rewards[sample_indices], weights, ) - return ReplayBufferSamples(*tuple(map(self.to_torch, batch)), leaf_nodes_indices) # type: ignore[arg-type] + return ReplayBufferSamples(*tuple(map(self.to_torch, batch)), leaf_nodes_indices) # type: ignore[arg-type,call-arg] def update_priorities(self, leaf_nodes_indices: np.ndarray, td_errors: th.Tensor) -> None: """ From b60ef03e211692a1ab539d43ffdee9b15f93e95a Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Tue, 3 Oct 2023 10:23:34 +0200 Subject: [PATCH 14/17] Add beta schedule --- .../common/prioritized_replay_buffer.py | 29 ++++++++++++++++--- stable_baselines3/dqn/dqn.py | 4 ++- 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/stable_baselines3/common/prioritized_replay_buffer.py b/stable_baselines3/common/prioritized_replay_buffer.py index 8d707f6f5..af72db4e6 100644 --- a/stable_baselines3/common/prioritized_replay_buffer.py +++ b/stable_baselines3/common/prioritized_replay_buffer.py @@ -6,6 +6,7 @@ from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.type_aliases import ReplayBufferSamples +from stable_baselines3.common.utils import get_linear_fn from stable_baselines3.common.vec_env.vec_normalize import VecNormalize @@ -102,6 +103,8 @@ class PrioritizedReplayBuffer(ReplayBuffer): :param n_envs: Number of parallel environments :param alpha: How much prioritization is used (0 - no prioritization aka uniform case, 1 - full prioritization) :param beta: To what degree to use importance weights (0 - no corrections, 1 - full correction) + :param final_beta: Value of beta at the end of training. + Linear annealing is used to interpolate between initial value of beta and final beta. :param min_priority: Minimum priority, prevents zero probabilities, so that all samples always have a non-zero probability to be sampled. 
""" @@ -113,8 +116,9 @@ def __init__( action_space: spaces.Space, device: Union[th.device, str] = "auto", n_envs: int = 1, - alpha: float = 0.6, + alpha: float = 0.5, beta: float = 0.4, + final_beta: float = 1.0, optimize_memory_usage: bool = False, min_priority: float = 1e-8, ): @@ -124,12 +128,25 @@ def __init__( self.min_priority = 1e-8 self.alpha = alpha - self.beta = beta self.max_priority = self.min_priority # priority for new samples, init as eps - + # Track the training progress remaining (from 1 to 0) + # this is used to update beta + self._current_progress_remaining = 1.0 + self.inital_beta = beta + self.final_beta = final_beta + self.beta_schedule = get_linear_fn( + self.inital_beta, + self.final_beta, + end_fraction=1.0, + ) # SumTree: data structure to store priorities self.tree = SumTree(buffer_size=buffer_size) + @property + def beta(self) -> float: + # Linear schedule + return self.beta_schedule(self._current_progress_remaining) + def add( self, obs: np.ndarray, @@ -217,7 +234,7 @@ def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayB ) return ReplayBufferSamples(*tuple(map(self.to_torch, batch)), leaf_nodes_indices) # type: ignore[arg-type,call-arg] - def update_priorities(self, leaf_nodes_indices: np.ndarray, td_errors: th.Tensor) -> None: + def update_priorities(self, leaf_nodes_indices: np.ndarray, td_errors: th.Tensor, progress_remaining: float) -> None: """ Update transition priorities. @@ -225,7 +242,11 @@ def update_priorities(self, leaf_nodes_indices: np.ndarray, td_errors: th.Tensor (correponding to the transitions) :param td_errors: New priorities, td error in the case of proportional prioritized replay buffer. + :param progress_remaining: Current progress remaining (starts from 1 and ends to 0) + to linearly anneal beta from its start value to 1.0 at the end of training """ + # Update beta schedule + self._current_progress_remaining = progress_remaining td_errors = td_errors.detach().cpu().numpy().flatten() for leaf_node_idx, td_error in zip(leaf_nodes_indices, td_errors): diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index de7a4d7e2..8b4f730ed 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -217,7 +217,9 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None: loss = (replay_data.weights * th.where(td_error < 1.0, 0.5 * td_error**2, td_error - 0.5)).mean() # Update priorities, they will be proportional to the td error assert replay_data.leaf_nodes_indices is not None, "Node leaf node indices provided" - self.replay_buffer.update_priorities(replay_data.leaf_nodes_indices, td_error) + self.replay_buffer.update_priorities( + replay_data.leaf_nodes_indices, td_error, self._current_progress_remaining + ) else: # Compute Huber loss (less sensitive to outliers) loss = F.smooth_l1_loss(current_q_values, target_q_values) From be002319667712080007aa50229781cbcf6d2039 Mon Sep 17 00:00:00 2001 From: AlexPasqua Date: Fri, 24 May 2024 19:19:44 +0200 Subject: [PATCH 15/17] Minor fix in PER --- stable_baselines3/common/prioritized_replay_buffer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable_baselines3/common/prioritized_replay_buffer.py b/stable_baselines3/common/prioritized_replay_buffer.py index af72db4e6..a97aca24f 100644 --- a/stable_baselines3/common/prioritized_replay_buffer.py +++ b/stable_baselines3/common/prioritized_replay_buffer.py @@ -126,7 +126,7 @@ def __init__( assert optimize_memory_usage is False, "PrioritizedReplayBuffer doesn't 
support optimize_memory_usage=True" - self.min_priority = 1e-8 + self.min_priority = min_priority self.alpha = alpha self.max_priority = self.min_priority # priority for new samples, init as eps # Track the training progress remaining (from 1 to 0) From fb1a9f7df3926494218535af422959d4756f8571 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 12 Jul 2024 13:41:34 +0200 Subject: [PATCH 16/17] Only convert to numpy if needed --- stable_baselines3/common/prioritized_replay_buffer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/stable_baselines3/common/prioritized_replay_buffer.py b/stable_baselines3/common/prioritized_replay_buffer.py index a97aca24f..41cce35ad 100644 --- a/stable_baselines3/common/prioritized_replay_buffer.py +++ b/stable_baselines3/common/prioritized_replay_buffer.py @@ -247,7 +247,8 @@ def update_priorities(self, leaf_nodes_indices: np.ndarray, td_errors: th.Tensor """ # Update beta schedule self._current_progress_remaining = progress_remaining - td_errors = td_errors.detach().cpu().numpy().flatten() + if isinstance(td_errors, th.Tensor): + td_errors = td_errors.detach().cpu().numpy().flatten() for leaf_node_idx, td_error in zip(leaf_nodes_indices, td_errors): # Proportional prioritization priority = (abs(td_error) + eps) ^ alpha From 150b09a858a783d7fdf64a96e41054d81972a50a Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 12 Jul 2024 14:01:07 +0200 Subject: [PATCH 17/17] Increase min priority to avoid division by zero --- stable_baselines3/common/prioritized_replay_buffer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable_baselines3/common/prioritized_replay_buffer.py b/stable_baselines3/common/prioritized_replay_buffer.py index 41cce35ad..f68388365 100644 --- a/stable_baselines3/common/prioritized_replay_buffer.py +++ b/stable_baselines3/common/prioritized_replay_buffer.py @@ -120,7 +120,7 @@ def __init__( beta: float = 0.4, final_beta: float = 1.0, optimize_memory_usage: bool = False, - min_priority: float = 1e-8, + min_priority: float = 1e-6, ): super().__init__(buffer_size, observation_space, action_space, device, n_envs)
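Usage sketch: the snippet below shows how the feature introduced by this patch series would be wired into DQN, assuming the stable_baselines3.common.prioritized_replay_buffer import path from PATCH 08 and the replay_buffer_class / replay_buffer_kwargs arguments exercised in tests/test_run.py; the environment and hyperparameter values are illustrative assumptions, not values mandated by the patches.

from stable_baselines3 import DQN
from stable_baselines3.common.prioritized_replay_buffer import PrioritizedReplayBuffer

# Pass the PER buffer via replay_buffer_class (see the note added to docs/modules/dqn.rst).
# alpha sets how much prioritization is used, beta the initial importance-sampling correction;
# beta is annealed towards final_beta by the linear schedule added in PATCH 14.
# The values below are illustrative only.
model = DQN(
    "MlpPolicy",
    "CartPole-v1",
    learning_starts=100,
    buffer_size=50_000,
    replay_buffer_class=PrioritizedReplayBuffer,
    replay_buffer_kwargs=dict(alpha=0.5, beta=0.4, final_beta=1.0),
    verbose=1,
)
model.learn(total_timesteps=10_000)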