diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 9c461f6ae..cc417a9d3 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -30,6 +30,8 @@ Bug Fixes: - Fixed error when loading a model that has ``net_arch`` manually set to ``None`` (@jak3122) - Set requirement numpy<2.0 until PyTorch is compatible (https://github.com/pytorch/pytorch/issues/107302) - Updated DQN optimizer input to only include q_network parameters, removing the target_q_network ones (@corentinlger) +- Fixed ``test_buffers.py::test_device_buffer`` which was not actually checking the device of tensors (@rhaps0dy) + `SB3-Contrib`_ ^^^^^^^^^^^^^^ diff --git a/tests/test_buffers.py b/tests/test_buffers.py index da6b44a34..18171dd21 100644 --- a/tests/test_buffers.py +++ b/tests/test_buffers.py @@ -139,18 +139,25 @@ def test_device_buffer(replay_buffer_cls, device): # Get data from the buffer if replay_buffer_cls in [RolloutBuffer, DictRolloutBuffer]: + # get returns an iterator over minibatches data = buffer.get(50) elif replay_buffer_cls in [ReplayBuffer, DictReplayBuffer]: - data = buffer.sample(50) + data = [buffer.sample(50)] # Check that all data are on the desired device desired_device = get_device(device).type - for value in list(data): - if isinstance(value, dict): - for key in value.keys(): - assert value[key].device.type == desired_device - elif isinstance(value, th.Tensor): - assert value.device.type == desired_device + for minibatch in list(data): + for value in minibatch: + if isinstance(value, dict): + for key in value.keys(): + assert value[key].device.type == desired_device + elif isinstance(value, th.Tensor): + assert value.device.type == desired_device + elif isinstance(value, np.ndarray): + # For prioritized replay weights/indices + pass + else: + raise TypeError(f"Unknown value type: {type(value)}") def test_custom_rollout_buffer():