
[BUG] Transform._call is not called from reset #2595

Open
codingWhale13 opened this issue Nov 22, 2024 · 0 comments

Describe the bug

The docstring of Transform._call() says it will be called by TransformedEnv.step() and TransformedEnv.reset(). However, resetting the transformed environment does not trigger _call().

To Reproduce

from tensordict import TensorDictBase
from torchrl.envs.transforms import Transform, TransformedEnv
from torchrl.envs import GymEnv

class PrintHiTransform(Transform):
    def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
        print("Hi")
        return tensordict

env = GymEnv("Pendulum-v1")
env = TransformedEnv(env, PrintHiTransform())

print("Calling env.reset()")
initial_state = env.reset()  # Does NOT print "Hi"

action = env.rand_action(initial_state)
print("Calling env.step()")
env.step(action)  # Prints "Hi" as desired
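The following sketch makes the asymmetry concrete without requiring Gym. It is a simplified pure-Python model of the dispatch described above. The classes `MiniTransform`, `MiniTransformedEnv` and `RecordingTransform` are hypothetical stand-ins, and plain dicts replace TensorDicts. They reproduce only the call pattern reported in this issue: `_step` delegates to `_call`, while `_reset` returns the reset data untouched. They are not torchrl's actual implementation.

```python
from typing import Dict, List

# Plain dicts stand in for TensorDicts in this simplified model.
State = Dict[str, float]


class MiniTransform:
    """Hypothetical stand-in for torchrl's Transform (not the real class)."""

    def _call(self, td: State) -> State:
        return td

    def _step(self, td: State, next_td: State) -> State:
        # Step path: delegates to _call, as the issue describes.
        return self._call(next_td)

    def _reset(self, td: State, td_reset: State) -> State:
        # Reset path: returns the reset data untouched, so _call is never reached.
        return td_reset


class MiniTransformedEnv:
    """Hypothetical wrapper that routes step/reset through a transform."""

    def __init__(self, transform: MiniTransform) -> None:
        self.transform = transform

    def reset(self) -> State:
        td_reset = {"obs": 0.0}
        return self.transform._reset({}, td_reset)

    def step(self, td: State) -> State:
        next_td = {"obs": td["obs"] + 1.0}
        return self.transform._step(td, next_td)


class RecordingTransform(MiniTransform):
    """Records each _call invocation instead of printing "Hi"."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def _call(self, td: State) -> State:
        self.calls.append("Hi")
        return td


t = RecordingTransform()
env = MiniTransformedEnv(t)
state = env.reset()
print(t.calls)  # [] : reset did not reach _call
env.step(state)
print(t.calls)  # ['Hi'] : step did
```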

System info

Describe the characteristics of your environment:

  • Python version: 3.9
  • torchrl==0.6.0
import torchrl, numpy, sys
print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)

0.6.0 2.0.2 3.9.20 | packaged by conda-forge | (main, Sep 30 2024, 17:49:10)
[GCC 13.3.0] linux

Possible fix

I would fix this by changing Transform._reset().

def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
    """Resets a transform if it is stateful."""
    return self._call(tensordict_reset)  # Was before: return tensordict_reset

Maybe there's a good reason why this is not the case, but it seems inconsistent to me: why should _step call _call while _reset does not?
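Here is a minimal sketch of the proposed change, written in the same kind of simplified model. `MiniTransform` and `RecordingTransform` are hypothetical classes, plain dicts stand in for TensorDicts, and none of this is torchrl code. Once `_reset` forwards the reset data to `_call`, both the reset path and the step path invoke `_call`.

```python
from typing import Dict, List

# Plain dicts stand in for TensorDicts in this simplified model.
State = Dict[str, float]


class MiniTransform:
    """Hypothetical stand-in for torchrl's Transform, with the proposed fix."""

    def _call(self, td: State) -> State:
        return td

    def _step(self, td: State, next_td: State) -> State:
        return self._call(next_td)

    def _reset(self, td: State, td_reset: State) -> State:
        # Proposed fix: route reset data through _call, mirroring _step.
        return self._call(td_reset)


class RecordingTransform(MiniTransform):
    """Records each _call invocation instead of printing "Hi"."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def _call(self, td: State) -> State:
        self.calls.append("Hi")
        return td


t = RecordingTransform()
state = t._reset({}, {"obs": 0.0})  # reset path
t._step(state, {"obs": 1.0})        # step path
print(t.calls)  # ['Hi', 'Hi']
```

If the library keeps its current behavior, the same one-line `_reset` override could live in a user subclass such as `PrintHiTransform` rather than in the base class. This assumes `TransformedEnv.reset()` calls `transform._reset(tensordict, tensordict_reset)`, as the signature above suggests. Keeping the override in the subclass confines the change to one transform instead of altering every transform.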

codingWhale13 added the bug (Something isn't working) label on Nov 22, 2024.