From 4a1137ba3ac0ff0ae095aca564edc82ec37b7f1c Mon Sep 17 00:00:00 2001 From: Jan-Hendrik Ewers Date: Fri, 2 Aug 2024 10:55:27 +0100 Subject: [PATCH] Add np.ndarray as a recognized type for TB histograms. (#1635) * Add np.ndarray as a recognized type for TB histograms. Torch histograms allow th.Tensor, np.ndarray, and caffe2 formatted strings. This commit expands the TensorBoardOutputFormat's capabilities to log the two former types. * Update changelog to reflect bug fix * fix: try/catch for if either np or torch aren't at the required versions. See https://github.com/DLR-RM/stable-baselines3/pull/1635 for more details * fix: Add comment describing the test for when add_histogram should not have been called * Cleanup --------- Co-authored-by: Antonin RAFFIN --- docs/misc/changelog.rst | 3 +- stable_baselines3/common/logger.py | 5 ++- stable_baselines3/version.txt | 2 +- tests/test_logger.py | 70 ++++++++++++++++++++++++++---- 4 files changed, 67 insertions(+), 13 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 37a035478..9c461f6ae 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.4.0a7 (WIP) +Release 2.4.0a8 (WIP) -------------------------- .. note:: @@ -19,6 +19,7 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ - Added support for ``pre_linear_modules`` and ``post_linear_modules`` in ``create_mlp`` (useful for adding normalization layers, like in DroQ or CrossQ) +- Enabled np.ndarray logging for TensorBoardOutputFormat as histogram (see GH#1634) (@iwishwasaneagle) Bug Fixes: ^^^^^^^^^^ diff --git a/stable_baselines3/common/logger.py b/stable_baselines3/common/logger.py index 363a9d2e8..8ceda71ed 100644 --- a/stable_baselines3/common/logger.py +++ b/stable_baselines3/common/logger.py @@ -412,8 +412,9 @@ def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, . 
else: self.writer.add_scalar(key, value, step) - if isinstance(value, th.Tensor): - self.writer.add_histogram(key, value, step) + if isinstance(value, (th.Tensor, np.ndarray)): + # Convert to Torch so it works with numpy<1.24 and torch<2.0 + self.writer.add_histogram(key, th.as_tensor(value), step) if isinstance(value, Video): self.writer.add_video(key, value.frames, step, value.fps) diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index f5230e413..ee717ba15 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.4.0a7 +2.4.0a8 diff --git a/tests/test_logger.py b/tests/test_logger.py index dfa3691ed..bc18bf2ce 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -44,6 +44,7 @@ "f": np.array(1), "g": np.array([[[1]]]), "h": 'this ", ;is a \n tes:,t', + "i": th.ones(3), } KEY_EXCLUDED = {} @@ -176,6 +177,9 @@ def test_main(tmp_path): logger.record_mean("b", -22.5) logger.record_mean("b", -44.4) logger.record("a", 5.5) + # Converted to string: + logger.record("hist1", th.ones(2)) + logger.record("hist2", np.ones(2)) logger.dump() logger.record("a", "longasslongasslongasslongasslongasslongassvalue") @@ -241,7 +245,7 @@ def is_moviepy_installed(): @pytest.mark.parametrize("unsupported_format", ["stdout", "log", "json", "csv"]) -def test_report_video_to_unsupported_format_raises_error(tmp_path, unsupported_format): +def test_unsupported_video_format(tmp_path, unsupported_format): writer = make_output_format(unsupported_format, tmp_path) with pytest.raises(FormatUnsupportedError) as exec_info: @@ -251,6 +255,54 @@ def test_report_video_to_unsupported_format_raises_error(tmp_path, unsupported_f writer.close() +@pytest.mark.parametrize( + "histogram", + [ + th.rand(100), + np.random.rand(100), + np.ones(1), + np.ones(1, dtype="int"), + ], +) +def test_log_histogram(tmp_path, read_log, histogram): + pytest.importorskip("tensorboard") + + writer = make_output_format("tensorboard", tmp_path) + 
writer.write({"data": histogram}, key_excluded={"data": ()}) + + log = read_log("tensorboard") + + assert not log.empty + assert any("data" in line for line in log.lines) + assert any("Histogram" in line for line in log.lines) + + writer.close() + + +@pytest.mark.parametrize( + "histogram", + [ + list(np.random.rand(100)), + tuple(np.random.rand(100)), + "1 2 3 4", + np.ones(1).item(), + th.ones(1).item(), + ], +) +def test_unsupported_type_histogram(tmp_path, read_log, histogram): + """ + Check that other types aren't accidentally logged as a Histogram + """ + pytest.importorskip("tensorboard") + + writer = make_output_format("tensorboard", tmp_path) + writer.write({"data": histogram}, key_excluded={"data": ()}) + + assert all("Histogram" not in line for line in read_log("tensorboard").lines) + + writer.close() + + def test_report_image_to_tensorboard(tmp_path, read_log): pytest.importorskip("tensorboard") @@ -263,7 +315,7 @@ def test_report_image_to_tensorboard(tmp_path, read_log): @pytest.mark.parametrize("unsupported_format", ["stdout", "log", "json", "csv"]) -def test_report_image_to_unsupported_format_raises_error(tmp_path, unsupported_format): +def test_unsupported_image_format(tmp_path, unsupported_format): writer = make_output_format(unsupported_format, tmp_path) with pytest.raises(FormatUnsupportedError) as exec_info: @@ -287,7 +339,7 @@ def test_report_figure_to_tensorboard(tmp_path, read_log): @pytest.mark.parametrize("unsupported_format", ["stdout", "log", "json", "csv"]) -def test_report_figure_to_unsupported_format_raises_error(tmp_path, unsupported_format): +def test_unsupported_figure_format(tmp_path, unsupported_format): writer = make_output_format(unsupported_format, tmp_path) with pytest.raises(FormatUnsupportedError) as exec_info: @@ -300,7 +352,7 @@ def test_report_figure_to_unsupported_format_raises_error(tmp_path, unsupported_ @pytest.mark.parametrize("unsupported_format", ["stdout", "log", "json", "csv"]) -def 
test_report_hparam_to_unsupported_format_raises_error(tmp_path, unsupported_format): +def test_unsupported_hparam(tmp_path, unsupported_format): writer = make_output_format(unsupported_format, tmp_path) with pytest.raises(FormatUnsupportedError) as exec_info: @@ -419,9 +471,9 @@ def test_fps_no_div_zero(algo): model.learn(total_timesteps=100) -def test_human_output_format_no_crash_on_same_keys_different_tags(): - o = HumanOutputFormat(sys.stdout, max_length=60) - o.write( +def test_human_output_same_keys_different_tags(): + human_out = HumanOutputFormat(sys.stdout, max_length=60) + human_out.write( {"key1/foo": "value1", "key1/bar": "value2", "key2/bizz": "value3", "key2/foo": "value4"}, {"key1/foo": None, "key2/bizz": None, "key1/bar": None, "key2/foo": None}, ) @@ -439,7 +491,7 @@ def test_ep_buffers_stats_window_size(algo, stats_window_size): @pytest.mark.parametrize("base_class", [object, TextIOBase]) -def test_human_output_format_custom_test_io(base_class): +def test_human_out_custom_text_io(base_class): class DummyTextIO(base_class): def __init__(self) -> None: super().__init__() @@ -531,7 +583,7 @@ def step(self, action): return self.observation_space.sample(), 0.0, False, truncated, info -def test_rollout_success_rate_on_policy_algorithm(tmp_path): +def test_rollout_success_rate_onpolicy_algo(tmp_path): """ Test if the rollout/success_rate information is correctly logged with on policy algorithms