diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index f56701c390d..16f0be8998f 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -2080,7 +2080,9 @@ def __init__( ): if in_keys is None: in_keys = IMAGE_KEYS - super(GrayScale, self).__init__(in_keys=in_keys, out_keys=out_keys) + if out_keys is None: + out_keys = copy(in_keys) + super().__init__(in_keys=in_keys, out_keys=out_keys) def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor: observation = F.rgb_to_grayscale(observation)