diff --git a/test/mocking_classes.py b/test/mocking_classes.py index eb517429c08..8d4c5fe961e 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -1038,11 +1038,13 @@ def _step( tensordict: TensorDictBase, ) -> TensorDictBase: action = tensordict.get(self.action_key) + try: + device = self.full_action_spec[self.action_key].device + except KeyError: + device = self.device self.count += action.to( dtype=torch.int, - device=self.full_action_spec[self.action_key].device - if self.device is None - else self.device, + device=device if self.device is None else self.device, ) tensordict = TensorDict( source={ diff --git a/test/test_env.py b/test/test_env.py index ab854a3b4be..81708b0b9a6 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -3526,6 +3526,34 @@ def test_single_env_spec(): assert env.input_spec.is_in(env.input_spec_unbatched.zeros(env.shape)) +def test_auto_spec(): + env = CountingEnv() + td = env.reset() + + policy = lambda td, action_spec=env.full_action_spec.clone(): td.update( + action_spec.rand() + ) + + env.full_observation_spec = Composite( + shape=env.full_observation_spec.shape, device=env.full_observation_spec.device + ) + env.full_action_spec = Composite( + shape=env.full_action_spec.shape, device=env.full_action_spec.device + ) + env.full_reward_spec = Composite( + shape=env.full_reward_spec.shape, device=env.full_reward_spec.device + ) + env.full_done_spec = Composite( + shape=env.full_done_spec.shape, device=env.full_done_spec.device + ) + env.full_state_spec = Composite( + shape=env.full_state_spec.shape, device=env.full_state_spec.device + ) + env._action_keys = ["action"] + env.auto_specs_(policy, tensordict=td.copy()) + env.check_env_specs(tensordict=td.copy()) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 0611af20b45..5c74b005541 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -6,9 +6,9 @@ from __future__ import annotations import abc -import functools import warnings from copy import deepcopy +from functools import partial, wraps from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple import numpy as np @@ -33,6 +33,7 @@ _StepMDP, _terminated_or_truncated, _update_during_reset, + check_env_specs as check_env_specs_func, get_available_libraries, ) @@ -390,6 +391,141 @@ def __init__( self.batch_size = torch.Size(batch_size) self._run_type_checks = run_type_checks self._allow_done_after_reset = allow_done_after_reset + self.output_spec.unlock_() + self.input_spec.unlock_() + if "full_observation_spec" not in self.output_spec: + self.output_spec["full_observation_spec"] = Composite() + if "full_done_spec" not in self.output_spec: + self.output_spec["full_done_spec"] = Composite() + if "full_reward_spec" not in self.output_spec: + self.output_spec["full_reward_spec"] = Composite() + + if "full_state_spec" not in self.input_spec: + self.input_spec["full_state_spec"] = Composite() + if "full_action_spec" not in self.input_spec: + self.input_spec["full_action_spec"] = Composite() + + self.output_spec.lock_() + self.input_spec.lock_() + + def auto_specs_( + self, + policy: Callable[[TensorDictBase], TensorDictBase], + *, + tensordict: TensorDictBase | None = None, + action_key: NestedKey | List[NestedKey] = "action", + done_key: NestedKey | List[NestedKey] | None = None, + observation_key: NestedKey | List[NestedKey] = "observation", + reward_key: NestedKey | List[NestedKey] = "reward", + batch_size: torch.Size | None = None, + ): + """Automatically sets the specifications (specs) of the environment based on a random rollout using a given policy. + + This method performs a rollout using the provided policy to infer the input and output specifications of the environment. + It updates the environment's specs for actions, observations, rewards, and done signals based on the data collected + during the rollout. + + Args: + policy (Callable[[TensorDictBase], TensorDictBase]): + A callable policy that takes a `TensorDictBase` as input and returns a `TensorDictBase` as output. + This policy is used to perform the rollout and determine the specs. + + Keyword Args: + tensordict (TensorDictBase, optional): + An optional `TensorDictBase` instance to be used as the initial state for the rollout. + If not provided, the environment's `reset` method will be called to obtain the initial state. + action_key (NestedKey or List[NestedKey], optional): + The key(s) used to identify actions in the `TensorDictBase`. Defaults to "action". + done_key (NestedKey or List[NestedKey], optional): + The key(s) used to identify done signals in the `TensorDictBase`. Defaults to ``None``, which will + attempt to use ["done", "terminated", "truncated"] as potential keys. + observation_key (NestedKey or List[NestedKey], optional): + The key(s) used to identify observations in the `TensorDictBase`. Defaults to "observation". + reward_key (NestedKey or List[NestedKey], optional): + The key(s) used to identify rewards in the `TensorDictBase`. Defaults to "reward". + + Returns: + EnvBase: The environment instance with updated specs. + + Raises: + RuntimeError: If there are keys in the output specs that are not accounted for in the provided keys. + """ + if self.batch_locked or tensordict is None: + batch_size = self.batch_size + else: + batch_size = tensordict.batch_size + if tensordict is None: + tensordict = self.reset() + + # Input specs + tensordict = policy(tensordict) + step_0 = self.step(tensordict.copy()) + tensordict2 = step_0.get("next").copy() + step_1 = self.step(policy(tensordict2).copy()) + nexts_0: TensorDictBase = step_0.pop("next") + nexts_1: TensorDictBase = step_1.pop("next") + + input_spec_stack = {} + tensordict.apply( + partial(_tensor_to_spec, stack=input_spec_stack), + tensordict2, + named=True, + nested_keys=True, + ) + input_spec = Composite(input_spec_stack, batch_size=batch_size) + if not self.batch_locked and batch_size != self.batch_size: + while input_spec.shape: + input_spec = input_spec[0] + if isinstance(action_key, NestedKey): + action_key = [action_key] + full_action_spec = input_spec.separates(*action_key, default=None) + + # Output specs + + output_spec_stack = {} + nexts_0.apply( + partial(_tensor_to_spec, stack=output_spec_stack), + nexts_1, + named=True, + nested_keys=True, + ) + + output_spec = Composite(output_spec_stack, batch_size=batch_size) + if not self.batch_locked and batch_size != self.batch_size: + while output_spec.shape: + output_spec = output_spec[0] + + if done_key is None: + done_key = ["done", "terminated", "truncated"] + full_done_spec = output_spec.separates(*done_key, default=None) + if full_done_spec is not None: + self.full_done_spec = full_done_spec + + if isinstance(reward_key, NestedKey): + reward_key = [reward_key] + full_reward_spec = output_spec.separates(*reward_key, default=None) + + if isinstance(observation_key, NestedKey): + observation_key = [observation_key] + full_observation_spec = output_spec.separates(*observation_key, default=None) + if not output_spec.is_empty(recurse=True): + raise RuntimeError( + f"Keys {list(output_spec.keys(True, True))} are unaccounted for." + ) + + if full_action_spec is not None: + self.full_action_spec = full_action_spec + if full_done_spec is not None: + self.full_done_specs = full_done_spec + if full_observation_spec is not None: + self.full_observation_spec = full_observation_spec + if full_reward_spec is not None: + self.full_reward_spec = full_reward_spec + full_state_spec = input_spec + self.full_state_spec = full_state_spec + + return self + @wraps(check_env_specs_func) def check_env_specs(self, *args, **kwargs): return check_env_specs_func(self, *args, **kwargs) @@ -665,8 +801,6 @@ def action_keys(self) -> List[NestedKey]: if action_keys is not None: return action_keys keys = self.full_action_spec.keys(True, True) - if not len(keys): - raise AttributeError("Could not find action spec") keys = sorted(keys, key=_repr_by_depth) self.__dict__["_action_keys"] = keys return keys @@ -825,15 +959,7 @@ def action_spec(self, value: TensorSpec) -> None: "Please use `env.action_spec_unbatched = value` to set unbatched versions instead." ) - if isinstance(value, Composite): - for _ in value.values(True, True): # noqa: B007 - break - else: - raise RuntimeError( - "An empty Composite was passed for the action spec. " - "This is currently not permitted." - ) - else: + if not isinstance(value, Composite): value = Composite( action=value.to(device), shape=self.batch_size, device=device ) @@ -890,7 +1016,6 @@ def reward_keys(self) -> List[NestedKey]: reward_keys = self.__dict__.get("_reward_keys") if reward_keys is not None: return reward_keys - reward_keys = sorted(self.full_reward_spec.keys(True, True), key=_repr_by_depth) self.__dict__["_reward_keys"] = reward_keys return reward_keys @@ -1028,15 +1153,7 @@ def reward_spec(self, value: TensorSpec) -> None: f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size}). " "Please use `env.reward_spec_unbatched = value` to set unbatched versions instead." ) - if isinstance(value, Composite): - for _ in value.values(True, True): # noqa: B007 - break - else: - raise RuntimeError( - "An empty Composite was passed for the reward spec. " - "This is currently not permitted." - ) - else: + if not isinstance(value, Composite): value = Composite( reward=value.to(device), shape=self.batch_size, device=device ) @@ -1317,15 +1434,7 @@ def done_spec(self, value: TensorSpec) -> None: raise ValueError( f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." ) - if isinstance(value, Composite): - for _ in value.values(True, True): # noqa: B007 - break - else: - raise RuntimeError( - "An empty Composite was passed for the done spec. " - "This is currently not permitted." - ) - else: + if not isinstance(value, Composite): value = Composite( done=value.to(device), terminated=value.to(device), @@ -2034,7 +2143,7 @@ def _register_gym( if entry_point is None: entry_point = cls - entry_point = functools.partial( + entry_point = partial( _TorchRLGymWrapper, entry_point=entry_point, info_keys=info_keys, @@ -2083,7 +2192,7 @@ def _register_gym( # noqa: F811 if entry_point is None: entry_point = cls - entry_point = functools.partial( + entry_point = partial( _TorchRLGymWrapper, entry_point=entry_point, info_keys=info_keys, @@ -2137,7 +2246,7 @@ def _register_gym( # noqa: F811 if entry_point is None: entry_point = cls - entry_point = functools.partial( + entry_point = partial( _TorchRLGymWrapper, entry_point=entry_point, info_keys=info_keys, @@ -2194,7 +2303,7 @@ def _register_gym( # noqa: F811 if entry_point is None: entry_point = cls - entry_point = functools.partial( + entry_point = partial( _TorchRLGymWrapper, entry_point=entry_point, info_keys=info_keys, @@ -2253,7 +2362,7 @@ def _register_gym( # noqa: F811 ) if entry_point is None: entry_point = cls - entry_point = functools.partial( + entry_point = partial( _TorchRLGymWrapper, entry_point=entry_point, info_keys=info_keys, @@ -2292,7 +2401,7 @@ def _register_gym( # noqa: F811 if entry_point is None: entry_point = cls - entry_point = functools.partial( + entry_point = partial( _TorchRLGymnasiumWrapper, entry_point=entry_point, info_keys=info_keys, @@ -3421,11 +3530,11 @@ def _get_sync_func(policy_device, env_device): if policy_device is not None and policy_device.type == "cuda": if env_device is None or env_device.type == "cuda": return torch.cuda.synchronize - return functools.partial(torch.cuda.synchronize, device=policy_device) + return partial(torch.cuda.synchronize, device=policy_device) if env_device is not None and env_device.type == "cuda": if policy_device is None: return torch.cuda.synchronize - return functools.partial(torch.cuda.synchronize, device=env_device) + return partial(torch.cuda.synchronize, device=env_device) return torch.cuda.synchronize if torch.backends.mps.is_available(): return torch.mps.synchronize @@ -3443,3 +3552,11 @@ def _has_dynamic_specs(spec: Composite): any(s == -1 for s in spec.shape) for spec in spec.values(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS) ) + + +def _tensor_to_spec(name, leaf, leaf_compare=None, *, stack): + shape = leaf.shape + if leaf_compare is not None: + shape_compare = leaf_compare.shape + shape = [s0 if s0 == s1 else -1 for s0, s1 in zip(shape, shape_compare)] + stack[name] = Unbounded(shape, device=leaf.device, dtype=leaf.dtype) diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index c83591acb63..da3d4175f9d 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -16,8 +16,6 @@ from enum import Enum from typing import Any, Dict, List, Union -import tensordict.base - import torch from tensordict import ( @@ -29,7 +27,7 @@ TensorDictBase, unravel_key, ) -from tensordict.base import _is_leaf_nontensor +from tensordict.base import _default_is_leaf, _is_leaf_nontensor from tensordict.nn import TensorDictModule, TensorDictModuleBase from tensordict.nn.probabilistic import ( # noqa interaction_type as exploration_type, @@ -691,7 +689,11 @@ def _per_level_env_check(data0, data1, check_dtype): def check_env_specs( - env, return_contiguous=True, check_dtype=True, seed: int | None = None + env, + return_contiguous=True, + check_dtype=True, + seed: int | None = None, + tensordict: TensorDictBase | None = None, ): """Tests an environment specs against the results of short rollout. @@ -715,6 +717,7 @@ def check_env_specs( setting the rng state back to what is was isn't a feature of most environment, we leave it to the user to accomplish that. Defaults to ``None``. + tensordict (TensorDict, optional): an optional tensordict instance to use for reset. Caution: this function resets the env seed. It should be used "offline" to check that an env is adequately constructed, but it may affect the seeding @@ -732,7 +735,16 @@ def check_env_specs( ) fake_tensordict = env.fake_tensordict() - real_tensordict = env.rollout(3, return_contiguous=return_contiguous) + if not env._batch_locked and tensordict is not None: + shape = torch.broadcast_shapes(fake_tensordict.shape, tensordict.shape) + fake_tensordict = fake_tensordict.expand(shape) + tensordict = tensordict.expand(shape) + real_tensordict = env.rollout( + 3, + return_contiguous=return_contiguous, + tensordict=tensordict, + auto_reset=tensordict is not None, + ) if return_contiguous: fake_tensordict = fake_tensordict.unsqueeze(real_tensordict.batch_dims - 1) @@ -743,17 +755,17 @@ def check_env_specs( ) # eliminate empty containers fake_tensordict_select = fake_tensordict.select( - *fake_tensordict.keys(True, True, is_leaf=tensordict.base._default_is_leaf) + *fake_tensordict.keys(True, True, is_leaf=_default_is_leaf) ) real_tensordict_select = real_tensordict.select( - *real_tensordict.keys(True, True, is_leaf=tensordict.base._default_is_leaf) + *real_tensordict.keys(True, True, is_leaf=_default_is_leaf) ) # check keys fake_tensordict_keys = set( - fake_tensordict.keys(True, True, is_leaf=tensordict.base._is_leaf_nontensor) + fake_tensordict.keys(True, True, is_leaf=_is_leaf_nontensor) ) real_tensordict_keys = set( - real_tensordict.keys(True, True, is_leaf=tensordict.base._is_leaf_nontensor) + real_tensordict.keys(True, True, is_leaf=_is_leaf_nontensor) ) if fake_tensordict_keys != real_tensordict_keys: raise AssertionError(