Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 23, 2024
1 parent 062be0f commit 655115d
Show file tree
Hide file tree
Showing 4 changed files with 210 additions and 51 deletions.
8 changes: 5 additions & 3 deletions test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1038,11 +1038,13 @@ def _step(
tensordict: TensorDictBase,
) -> TensorDictBase:
action = tensordict.get(self.action_key)
try:
device = self.full_action_spec[self.action_key].device
except KeyError:
device = self.device
self.count += action.to(
dtype=torch.int,
device=self.full_action_spec[self.action_key].device
if self.device is None
else self.device,
device=device if self.device is None else self.device,
)
tensordict = TensorDict(
source={
Expand Down
28 changes: 28 additions & 0 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3526,6 +3526,34 @@ def test_single_env_spec():
assert env.input_spec.is_in(env.input_spec_unbatched.zeros(env.shape))


def test_auto_spec():
    """Blank out every spec of a ``CountingEnv``, re-infer them, then validate.

    The specs are replaced by empty ``Composite`` instances (same shape and
    device), ``auto_specs_`` is asked to fill them back in from a short
    rollout, and ``check_env_specs`` is run on the result.
    """
    env = CountingEnv()
    reset_td = env.reset()

    # Snapshot the action spec *before* the specs are emptied below, so the
    # policy can still sample valid actions afterwards.
    sampling_spec = env.full_action_spec.clone()

    def policy(td, action_spec=sampling_spec):
        return td.update(action_spec.rand())

    # Same order as a manual, attribute-by-attribute reset.
    for spec_name in (
        "full_observation_spec",
        "full_action_spec",
        "full_reward_spec",
        "full_done_spec",
        "full_state_spec",
    ):
        current = getattr(env, spec_name)
        setattr(
            env,
            spec_name,
            Composite(shape=current.shape, device=current.device),
        )

    env._action_keys = ["action"]
    env.auto_specs_(policy, tensordict=reset_td.copy())
    env.check_env_specs(tensordict=reset_td.copy())


if __name__ == "__main__":
    # Forward any command-line arguments argparse does not recognise to pytest.
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest_argv = [__file__, "--capture", "no", "--exitfirst"]
    pytest_argv = pytest_argv + unknown
    pytest.main(pytest_argv)
195 changes: 156 additions & 39 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from __future__ import annotations

import abc
import functools
import warnings
from copy import deepcopy
from functools import partial, wraps
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple

import numpy as np
Expand All @@ -33,6 +33,7 @@
_StepMDP,
_terminated_or_truncated,
_update_during_reset,
check_env_specs as check_env_specs_func,
get_available_libraries,
)

Expand Down Expand Up @@ -390,6 +391,141 @@ def __init__(
self.batch_size = torch.Size(batch_size)
self._run_type_checks = run_type_checks
self._allow_done_after_reset = allow_done_after_reset
self.output_spec.unlock_()
self.input_spec.unlock_()
if "full_observation_spec" not in self.output_spec:
self.output_spec["full_observation_spec"] = Composite()
if "full_done_spec" not in self.output_spec:
self.output_spec["full_done_spec"] = Composite()
if "full_reward_spec" not in self.output_spec:
self.output_spec["full_reward_spec"] = Composite()

if "full_state_spec" not in self.input_spec:
self.input_spec["full_state_spec"] = Composite()
if "full_action_spec" not in self.input_spec:
self.input_spec["full_action_spec"] = Composite()

self.output_spec.lock_()
self.input_spec.lock_()

def auto_specs_(
    self,
    policy: Callable[[TensorDictBase], TensorDictBase],
    *,
    tensordict: TensorDictBase | None = None,
    action_key: NestedKey | List[NestedKey] = "action",
    done_key: NestedKey | List[NestedKey] | None = None,
    observation_key: NestedKey | List[NestedKey] = "observation",
    reward_key: NestedKey | List[NestedKey] = "reward",
    batch_size: torch.Size | None = None,
):
    """Automatically sets the specifications (specs) of the environment based on a random rollout using a given policy.

    This method performs two consecutive steps using the provided policy and
    infers the input and output specifications of the environment from the
    data collected. Dimensions whose size differs between the two steps are
    recorded as ``-1``. The specs for actions, observations, rewards, done
    signals and state are only written to the environment once every output
    key has been accounted for.

    Args:
        policy (Callable[[TensorDictBase], TensorDictBase]):
            A callable policy that takes a `TensorDictBase` as input and returns a `TensorDictBase` as output.
            This policy is used to perform the rollout and determine the specs.

    Keyword Args:
        tensordict (TensorDictBase, optional):
            An optional `TensorDictBase` instance to be used as the initial state for the rollout.
            If not provided, the environment's `reset` method will be called to obtain the initial state.
        action_key (NestedKey or List[NestedKey], optional):
            The key(s) used to identify actions in the `TensorDictBase`. Defaults to "action".
        done_key (NestedKey or List[NestedKey], optional):
            The key(s) used to identify done signals in the `TensorDictBase`. Defaults to ``None``, which will
            attempt to use ["done", "terminated", "truncated"] as potential keys.
        observation_key (NestedKey or List[NestedKey], optional):
            The key(s) used to identify observations in the `TensorDictBase`. Defaults to "observation".
        reward_key (NestedKey or List[NestedKey], optional):
            The key(s) used to identify rewards in the `TensorDictBase`. Defaults to "reward".
        batch_size (torch.Size, optional):
            Currently ignored: the batch size is always taken from the environment (when it is
            batch-locked or no ``tensordict`` is given) or from ``tensordict`` otherwise.

    Returns:
        EnvBase: The environment instance with updated specs.

    Raises:
        RuntimeError: If there are keys in the output specs that are not accounted for in the provided keys.
    """

    def _as_key_list(key):
        # A lone nested key must be wrapped so that ``*keys`` unpacks whole
        # keys rather than the characters / items of a single key.
        return [key] if isinstance(key, NestedKey) else key

    if self.batch_locked or tensordict is None:
        batch_size = self.batch_size
    else:
        batch_size = tensordict.batch_size
    if tensordict is None:
        tensordict = self.reset()

    # Two consecutive steps: comparing them reveals which dims are dynamic.
    tensordict = policy(tensordict)
    step_0 = self.step(tensordict.copy())
    tensordict2 = step_0.get("next").copy()
    step_1 = self.step(policy(tensordict2).copy())
    nexts_0: TensorDictBase = step_0.pop("next")
    nexts_1: TensorDictBase = step_1.pop("next")

    # When the data carries a batch size the env does not have, the inferred
    # specs are indexed down to their unbatched form.
    strip_batch_dims = not self.batch_locked and batch_size != self.batch_size

    # Input specs
    input_spec_stack = {}
    tensordict.apply(
        partial(_tensor_to_spec, stack=input_spec_stack),
        tensordict2,
        named=True,
        nested_keys=True,
    )
    input_spec = Composite(input_spec_stack, batch_size=batch_size)
    if strip_batch_dims:
        while input_spec.shape:
            input_spec = input_spec[0]
    full_action_spec = input_spec.separates(
        *_as_key_list(action_key), default=None
    )

    # Output specs
    output_spec_stack = {}
    nexts_0.apply(
        partial(_tensor_to_spec, stack=output_spec_stack),
        nexts_1,
        named=True,
        nested_keys=True,
    )
    output_spec = Composite(output_spec_stack, batch_size=batch_size)
    if strip_batch_dims:
        while output_spec.shape:
            output_spec = output_spec[0]

    if done_key is None:
        done_key = ["done", "terminated", "truncated"]
    full_done_spec = output_spec.separates(*_as_key_list(done_key), default=None)
    full_reward_spec = output_spec.separates(
        *_as_key_list(reward_key), default=None
    )
    full_observation_spec = output_spec.separates(
        *_as_key_list(observation_key), default=None
    )

    # Validate before touching the env so a failure leaves its specs intact.
    if not output_spec.is_empty(recurse=True):
        raise RuntimeError(
            f"Keys {list(output_spec.keys(True, True))} are unaccounted for."
        )

    if full_action_spec is not None:
        self.full_action_spec = full_action_spec
    if full_done_spec is not None:
        self.full_done_spec = full_done_spec
    if full_observation_spec is not None:
        self.full_observation_spec = full_observation_spec
    if full_reward_spec is not None:
        self.full_reward_spec = full_reward_spec
    # Whatever is left of the inputs once the actions are removed is state.
    self.full_state_spec = input_spec

    return self

@wraps(check_env_specs_func)
def check_env_specs(self, *args, **kwargs):
    # Method-style entry point: forwards to the module-level checker with
    # this environment as its first argument.
    outcome = check_env_specs_func(self, *args, **kwargs)
    return outcome
Expand Down Expand Up @@ -665,8 +801,6 @@ def action_keys(self) -> List[NestedKey]:
if action_keys is not None:
return action_keys
keys = self.full_action_spec.keys(True, True)
if not len(keys):
raise AttributeError("Could not find action spec")
keys = sorted(keys, key=_repr_by_depth)
self.__dict__["_action_keys"] = keys
return keys
Expand Down Expand Up @@ -825,15 +959,7 @@ def action_spec(self, value: TensorSpec) -> None:
"Please use `env.action_spec_unbatched = value` to set unbatched versions instead."
)

if isinstance(value, Composite):
for _ in value.values(True, True): # noqa: B007
break
else:
raise RuntimeError(
"An empty Composite was passed for the action spec. "
"This is currently not permitted."
)
else:
if not isinstance(value, Composite):
value = Composite(
action=value.to(device), shape=self.batch_size, device=device
)
Expand Down Expand Up @@ -890,7 +1016,6 @@ def reward_keys(self) -> List[NestedKey]:
reward_keys = self.__dict__.get("_reward_keys")
if reward_keys is not None:
return reward_keys

reward_keys = sorted(self.full_reward_spec.keys(True, True), key=_repr_by_depth)
self.__dict__["_reward_keys"] = reward_keys
return reward_keys
Expand Down Expand Up @@ -1028,15 +1153,7 @@ def reward_spec(self, value: TensorSpec) -> None:
f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size}). "
"Please use `env.reward_spec_unbatched = value` to set unbatched versions instead."
)
if isinstance(value, Composite):
for _ in value.values(True, True): # noqa: B007
break
else:
raise RuntimeError(
"An empty Composite was passed for the reward spec. "
"This is currently not permitted."
)
else:
if not isinstance(value, Composite):
value = Composite(
reward=value.to(device), shape=self.batch_size, device=device
)
Expand Down Expand Up @@ -1317,15 +1434,7 @@ def done_spec(self, value: TensorSpec) -> None:
raise ValueError(
f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})."
)
if isinstance(value, Composite):
for _ in value.values(True, True): # noqa: B007
break
else:
raise RuntimeError(
"An empty Composite was passed for the done spec. "
"This is currently not permitted."
)
else:
if not isinstance(value, Composite):
value = Composite(
done=value.to(device),
terminated=value.to(device),
Expand Down Expand Up @@ -2034,7 +2143,7 @@ def _register_gym(

if entry_point is None:
entry_point = cls
entry_point = functools.partial(
entry_point = partial(
_TorchRLGymWrapper,
entry_point=entry_point,
info_keys=info_keys,
Expand Down Expand Up @@ -2083,7 +2192,7 @@ def _register_gym( # noqa: F811

if entry_point is None:
entry_point = cls
entry_point = functools.partial(
entry_point = partial(
_TorchRLGymWrapper,
entry_point=entry_point,
info_keys=info_keys,
Expand Down Expand Up @@ -2137,7 +2246,7 @@ def _register_gym( # noqa: F811

if entry_point is None:
entry_point = cls
entry_point = functools.partial(
entry_point = partial(
_TorchRLGymWrapper,
entry_point=entry_point,
info_keys=info_keys,
Expand Down Expand Up @@ -2194,7 +2303,7 @@ def _register_gym( # noqa: F811

if entry_point is None:
entry_point = cls
entry_point = functools.partial(
entry_point = partial(
_TorchRLGymWrapper,
entry_point=entry_point,
info_keys=info_keys,
Expand Down Expand Up @@ -2253,7 +2362,7 @@ def _register_gym( # noqa: F811
)
if entry_point is None:
entry_point = cls
entry_point = functools.partial(
entry_point = partial(
_TorchRLGymWrapper,
entry_point=entry_point,
info_keys=info_keys,
Expand Down Expand Up @@ -2292,7 +2401,7 @@ def _register_gym( # noqa: F811
if entry_point is None:
entry_point = cls

entry_point = functools.partial(
entry_point = partial(
_TorchRLGymnasiumWrapper,
entry_point=entry_point,
info_keys=info_keys,
Expand Down Expand Up @@ -3421,11 +3530,11 @@ def _get_sync_func(policy_device, env_device):
if policy_device is not None and policy_device.type == "cuda":
if env_device is None or env_device.type == "cuda":
return torch.cuda.synchronize
return functools.partial(torch.cuda.synchronize, device=policy_device)
return partial(torch.cuda.synchronize, device=policy_device)
if env_device is not None and env_device.type == "cuda":
if policy_device is None:
return torch.cuda.synchronize
return functools.partial(torch.cuda.synchronize, device=env_device)
return partial(torch.cuda.synchronize, device=env_device)
return torch.cuda.synchronize
if torch.backends.mps.is_available():
return torch.mps.synchronize
Expand All @@ -3443,3 +3552,11 @@ def _has_dynamic_specs(spec: Composite):
any(s == -1 for s in spec.shape)
for spec in spec.values(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS)
)


def _tensor_to_spec(name, leaf, leaf_compare=None, *, stack):
    """Store an ``Unbounded`` spec describing ``leaf`` in ``stack[name]``.

    Args:
        name: key under which the spec is written into ``stack``.
        leaf: object exposing ``shape``, ``device`` and ``dtype``.
        leaf_compare: optional second leaf; every dimension whose size
            differs from ``leaf``'s is recorded as ``-1``.

    Keyword Args:
        stack: mutable mapping that receives the spec.
    """
    if leaf_compare is None:
        dims = leaf.shape
    else:
        dims = [
            size if size == other_size else -1
            for size, other_size in zip(leaf.shape, leaf_compare.shape)
        ]
    stack[name] = Unbounded(dims, device=leaf.device, dtype=leaf.dtype)
Loading

0 comments on commit 655115d

Please sign in to comment.