From 559d157860d2f107a41bb9048d29f1762bc451a1 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 2 Oct 2023 06:15:26 -0400 Subject: [PATCH] amend --- torchrl/envs/transforms/transforms.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index f56701c390d..16f0be8998f 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -2080,7 +2080,9 @@ def __init__( ): if in_keys is None: in_keys = IMAGE_KEYS - super(GrayScale, self).__init__(in_keys=in_keys, out_keys=out_keys) + if out_keys is None: + out_keys = copy(in_keys) + super().__init__(in_keys=in_keys, out_keys=out_keys) def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor: observation = F.rgb_to_grayscale(observation)