Skip to content

Commit

Permalink
Add np.ndarray as a recognized type for TB histograms. (#1635)
Browse files Browse the repository at this point in the history
* Add np.ndarray as a recognized type for TB histograms.

Torch histograms allow th.Tensor, np.ndarray, and caffe2 formatted strings. This commits expands the TensorBoardOutputFormat's capabilities to log the two former types.

* Update changelog to reflect bug fix

* fix: try/catch for if either np or torch aren't at the required versions. See #1635 for more details

* fix: Add comment describing the test for when add_histogram should not have been called

* Cleanup

---------

Co-authored-by: Antonin RAFFIN <[email protected]>
  • Loading branch information
iwishiwasaneagle and araffin authored Aug 2, 2024
1 parent 6ad6fa5 commit 4a1137b
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 13 deletions.
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Changelog
==========

Release 2.4.0a7 (WIP)
Release 2.4.0a8 (WIP)
--------------------------

.. note::
Expand All @@ -19,6 +19,7 @@ Breaking Changes:
New Features:
^^^^^^^^^^^^^
- Added support for ``pre_linear_modules`` and ``post_linear_modules`` in ``create_mlp`` (useful for adding normalization layers, like in DroQ or CrossQ)
- Enabled np.ndarray logging for TensorBoardOutputFormat as histogram (see GH#1634) (@iwishwasaneagle)

Bug Fixes:
^^^^^^^^^^
Expand Down
5 changes: 3 additions & 2 deletions stable_baselines3/common/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,8 +412,9 @@ def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, .
else:
self.writer.add_scalar(key, value, step)

if isinstance(value, th.Tensor):
self.writer.add_histogram(key, value, step)
if isinstance(value, (th.Tensor, np.ndarray)):
# Convert to Torch so it works with numpy<1.24 and torch<2.0
self.writer.add_histogram(key, th.as_tensor(value), step)

if isinstance(value, Video):
self.writer.add_video(key, value.frames, step, value.fps)
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.4.0a7
2.4.0a8
70 changes: 61 additions & 9 deletions tests/test_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
"f": np.array(1),
"g": np.array([[[1]]]),
"h": 'this ", ;is a \n tes:,t',
"i": th.ones(3),
}

KEY_EXCLUDED = {}
Expand Down Expand Up @@ -176,6 +177,9 @@ def test_main(tmp_path):
logger.record_mean("b", -22.5)
logger.record_mean("b", -44.4)
logger.record("a", 5.5)
# Converted to string:
logger.record("hist1", th.ones(2))
logger.record("hist2", np.ones(2))
logger.dump()

logger.record("a", "longasslongasslongasslongasslongasslongassvalue")
Expand Down Expand Up @@ -241,7 +245,7 @@ def is_moviepy_installed():


@pytest.mark.parametrize("unsupported_format", ["stdout", "log", "json", "csv"])
def test_report_video_to_unsupported_format_raises_error(tmp_path, unsupported_format):
def test_unsupported_video_format(tmp_path, unsupported_format):
writer = make_output_format(unsupported_format, tmp_path)

with pytest.raises(FormatUnsupportedError) as exec_info:
Expand All @@ -251,6 +255,54 @@ def test_report_video_to_unsupported_format_raises_error(tmp_path, unsupported_f
writer.close()


@pytest.mark.parametrize(
"histogram",
[
th.rand(100),
np.random.rand(100),
np.ones(1),
np.ones(1, dtype="int"),
],
)
def test_log_histogram(tmp_path, read_log, histogram):
pytest.importorskip("tensorboard")

writer = make_output_format("tensorboard", tmp_path)
writer.write({"data": histogram}, key_excluded={"data": ()})

log = read_log("tensorboard")

assert not log.empty
assert any("data" in line for line in log.lines)
assert any("Histogram" in line for line in log.lines)

writer.close()


@pytest.mark.parametrize(
"histogram",
[
list(np.random.rand(100)),
tuple(np.random.rand(100)),
"1 2 3 4",
np.ones(1).item(),
th.ones(1).item(),
],
)
def test_unsupported_type_histogram(tmp_path, read_log, histogram):
"""
Check that other types aren't accidentally logged as a Histogram
"""
pytest.importorskip("tensorboard")

writer = make_output_format("tensorboard", tmp_path)
writer.write({"data": histogram}, key_excluded={"data": ()})

assert all("Histogram" not in line for line in read_log("tensorboard").lines)

writer.close()


def test_report_image_to_tensorboard(tmp_path, read_log):
pytest.importorskip("tensorboard")

Expand All @@ -263,7 +315,7 @@ def test_report_image_to_tensorboard(tmp_path, read_log):


@pytest.mark.parametrize("unsupported_format", ["stdout", "log", "json", "csv"])
def test_report_image_to_unsupported_format_raises_error(tmp_path, unsupported_format):
def test_unsupported_image_format(tmp_path, unsupported_format):
writer = make_output_format(unsupported_format, tmp_path)

with pytest.raises(FormatUnsupportedError) as exec_info:
Expand All @@ -287,7 +339,7 @@ def test_report_figure_to_tensorboard(tmp_path, read_log):


@pytest.mark.parametrize("unsupported_format", ["stdout", "log", "json", "csv"])
def test_report_figure_to_unsupported_format_raises_error(tmp_path, unsupported_format):
def test_unsupported_figure_format(tmp_path, unsupported_format):
writer = make_output_format(unsupported_format, tmp_path)

with pytest.raises(FormatUnsupportedError) as exec_info:
Expand All @@ -300,7 +352,7 @@ def test_report_figure_to_unsupported_format_raises_error(tmp_path, unsupported_


@pytest.mark.parametrize("unsupported_format", ["stdout", "log", "json", "csv"])
def test_report_hparam_to_unsupported_format_raises_error(tmp_path, unsupported_format):
def test_unsupported_hparam(tmp_path, unsupported_format):
writer = make_output_format(unsupported_format, tmp_path)

with pytest.raises(FormatUnsupportedError) as exec_info:
Expand Down Expand Up @@ -419,9 +471,9 @@ def test_fps_no_div_zero(algo):
model.learn(total_timesteps=100)


def test_human_output_format_no_crash_on_same_keys_different_tags():
o = HumanOutputFormat(sys.stdout, max_length=60)
o.write(
def test_human_output_same_keys_different_tags():
human_out = HumanOutputFormat(sys.stdout, max_length=60)
human_out.write(
{"key1/foo": "value1", "key1/bar": "value2", "key2/bizz": "value3", "key2/foo": "value4"},
{"key1/foo": None, "key2/bizz": None, "key1/bar": None, "key2/foo": None},
)
Expand All @@ -439,7 +491,7 @@ def test_ep_buffers_stats_window_size(algo, stats_window_size):


@pytest.mark.parametrize("base_class", [object, TextIOBase])
def test_human_output_format_custom_test_io(base_class):
def test_human_out_custom_text_io(base_class):
class DummyTextIO(base_class):
def __init__(self) -> None:
super().__init__()
Expand Down Expand Up @@ -531,7 +583,7 @@ def step(self, action):
return self.observation_space.sample(), 0.0, False, truncated, info


def test_rollout_success_rate_on_policy_algorithm(tmp_path):
def test_rollout_success_rate_onpolicy_algo(tmp_path):
"""
Test if the rollout/success_rate information is correctly logged with on policy algorithms
Expand Down

0 comments on commit 4a1137b

Please sign in to comment.