[BugFix] RewardSum transform for multiple reward keys (#1544)
Signed-off-by: Matteo Bettini <[email protected]>
Co-authored-by: vmoens <[email protected]>
matteobettini and vmoens authored Oct 2, 2023
1 parent 106368f commit 1697102
Showing 2 changed files with 173 additions and 98 deletions.
41 changes: 39 additions & 2 deletions test/test_transforms.py
@@ -32,11 +32,14 @@
IncrementingEnv,
MockBatchedLockedEnv,
MockBatchedUnLockedEnv,
MultiKeyCountingEnv,
MultiKeyCountingEnvPolicy,
NestedCountingEnv,
)
from tensordict import unravel_key
from tensordict.nn import TensorDictSequential
from tensordict.tensordict import TensorDict, TensorDictBase
from tensordict.utils import _unravel_key_to_tuple
from torch import multiprocessing as mp, nn, Tensor
from torchrl._utils import prod
from torchrl.data import (
@@ -104,7 +107,7 @@
from torchrl.envs.transforms.transforms import _has_tv
from torchrl.envs.transforms.vc1 import _has_vc
from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform
from torchrl.envs.utils import check_env_specs, step_mdp
from torchrl.envs.utils import _replace_last, check_env_specs, step_mdp
from torchrl.modules import LSTMModule, MLP, ProbabilisticActor, TanhNormal

TIMEOUT = 100.0
@@ -4527,6 +4530,36 @@ def test_trans_parallel_env_check(self):
r = env.rollout(4)
assert r["next", "episode_reward"].unique().numel() > 1

@pytest.mark.parametrize("has_in_keys,", [True, False])
def test_trans_multi_key(
self, has_in_keys, n_workers=2, batch_size=(3, 2), max_steps=5
):
torch.manual_seed(0)
env_fun = lambda: MultiKeyCountingEnv(batch_size=batch_size)
base_env = SerialEnv(n_workers, env_fun)
if has_in_keys:
t = RewardSum(in_keys=base_env.reward_keys, reset_keys=base_env.reset_keys)
else:
t = RewardSum()
env = TransformedEnv(
base_env,
Compose(t),
)
policy = MultiKeyCountingEnvPolicy(
full_action_spec=env.action_spec, deterministic=True
)

check_env_specs(env)
td = env.rollout(max_steps, policy=policy)
for reward_key in env.reward_keys:
reward_key = _unravel_key_to_tuple(reward_key)
assert (
td.get(
("next", _replace_last(reward_key, f"episode_{reward_key[-1]}"))
)[(0,) * (len(batch_size) + 1)][-1]
== max_steps
).all()

@pytest.mark.parametrize("in_key", ["reward", ("some", "nested")])
def test_transform_no_env(self, in_key):
t = RewardSum(in_keys=[in_key], out_keys=[("some", "nested_sum")])
@@ -4550,7 +4583,8 @@ def test_transform_no_env(self, in_key):
def test_transform_compose(
self,
):
t = Compose(RewardSum())
# reset keys should not be needed for offline run
t = Compose(RewardSum(in_keys=["reward"], out_keys=["episode_reward"]))
reward = torch.randn(10)
td = TensorDict({("next", "reward"): reward}, [])
with pytest.raises(
@@ -4649,6 +4683,9 @@ def test_sum_reward(self, keys, device):

# reset environments
td.set("_reset", torch.ones(batch, dtype=torch.bool, device=device))
with pytest.raises(TypeError, match="reset_keys not provided but parent"):
rs.reset(td)
rs._reset_keys = ["_reset"]
rs.reset(td)

# apply a third time, episode_reward should be equal to reward again
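A minimal sketch (not part of the diff) of the key-naming convention the new test_trans_multi_key assertion relies on: each reward key, plain or nested, is mapped to a matching entry whose last element gains an "episode_" prefix. The helper episode_key below is hypothetical and only mirrors how _replace_last is used in the test; the nested key in the example is illustrative.

def episode_key(key):
    # Hypothetical helper, not part of TorchRL: replace the last element of a
    # (possibly nested) reward key with "episode_<name>".
    if isinstance(key, str):
        return f"episode_{key}"
    return key[:-1] + (f"episode_{key[-1]}",)

assert episode_key("reward") == "episode_reward"
assert episode_key(("agents", "reward")) == ("agents", "episode_reward")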
230 changes: 134 additions & 96 deletions torchrl/envs/transforms/transforms.py
@@ -40,7 +40,7 @@
from torchrl.envs.common import _EnvPostInit, EnvBase, make_tensordict
from torchrl.envs.transforms import functional as F
from torchrl.envs.transforms.utils import check_finite
from torchrl.envs.utils import _sort_keys, step_mdp
from torchrl.envs.utils import _replace_last, _sort_keys, step_mdp
from torchrl.objectives.value.functional import reward2go

try:
@@ -242,7 +242,7 @@ def _apply_transform(self, obs: torch.Tensor) -> None:
"""
raise NotImplementedError(
f"{self.__class__.__name__}_apply_transform is not coded. If the transform is coded in "
f"{self.__class__.__name__}._apply_transform is not coded. If the transform is coded in "
"transform._call, make sure that this method is called instead of"
"transform.forward, which is reserved for usage inside nn.Modules"
"or appended to a replay buffer."
@@ -4342,74 +4342,140 @@ class RewardSum(Transform):
"""Tracks episode cumulative rewards.
This transform accepts a list of tensordict reward keys (i.e. ´in_keys´) and tracks their cumulative
value along each episode. When called, the transform creates a new tensordict key for each in_key named
´episode_{in_key}´ where the cumulative values are written. All ´in_keys´ should be part of the env
reward and be present in the env reward_spec.
value along the time dimension for each episode.
If no in_keys are specified, this transform assumes ´reward´ to be the input key. However, multiple rewards
(e.g. reward1 and reward2) can also be specified. If ´in_keys´ are not present in the provided tensordict,
this transform has no effect.
When called, the transform writes a new tensordict entry for each ``in_key`` named
``episode_{in_key}`` where the cumulative values are written.
.. note:: :class:`~RewardSum` currently only supports ``"done"`` signal at the root.
Nested ``"done"``, such as those found in MARL settings, are currently not supported.
If this feature is needed, please raise an issue on TorchRL repo.
Args:
in_keys (list of NestedKeys, optional): Input reward keys.
All ``in_keys`` should be part of the environment reward_spec.
If no ``in_keys`` are specified, this transform assumes ``"reward"`` to be the input key.
However, multiple rewards (e.g. ``"reward1"`` and ``"reward2"``) can also be specified.
out_keys (list of NestedKeys, optional): The output sum keys; there should be exactly one per input key.
reset_keys (list of NestedKeys, optional): the list of reset_keys to be
used, if the parent environment cannot be found. If provided, this
value will prevail over the environment ``reset_keys``.
Examples:
>>> from torchrl.envs.transforms import RewardSum, TransformedEnv
>>> from torchrl.envs.libs.gym import GymEnv
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), RewardSum())
>>> td = env.reset()
>>> print(td["episode_reward"])
tensor([0.])
>>> td = env.rollout(3)
>>> print(td["next", "episode_reward"])
tensor([[-0.5926],
[-1.4578],
[-2.7885]])
"""

def __init__(
self,
in_keys: Optional[Sequence[NestedKey]] = None,
out_keys: Optional[Sequence[NestedKey]] = None,
reset_keys: Optional[Sequence[NestedKey]] = None,
):
"""Initialises the transform. Filters out non-reward input keys and defines output keys."""
if in_keys is None:
in_keys = ["reward"]
if out_keys is None and in_keys == ["reward"]:
out_keys = ["episode_reward"]
elif out_keys is None:
raise RuntimeError(
"the out_keys must be specified for non-conventional in-keys in RewardSum."
super().__init__(in_keys=in_keys, out_keys=out_keys)
self._reset_keys = reset_keys

@property
def in_keys(self):
in_keys = self.__dict__.get("_in_keys", None)
if in_keys in (None, []):
# retrieve rewards from parent env
parent = self.parent
if parent is None:
in_keys = ["reward"]
else:
in_keys = copy(parent.reward_keys)
self._in_keys = in_keys
return in_keys

@in_keys.setter
def in_keys(self, value):
if value is not None:
if isinstance(value, (str, tuple)):
value = [value]
value = [unravel_key(val) for val in value]
self._in_keys = value

@property
def out_keys(self):
out_keys = self.__dict__.get("_out_keys", None)
if out_keys in (None, []):
out_keys = [
_replace_last(in_key, f"episode_{_unravel_key_to_tuple(in_key)[-1]}")
for in_key in self.in_keys
]
self._out_keys = out_keys
return out_keys

@out_keys.setter
def out_keys(self, value):
# we must access the private attribute because this check occurs before
# the parent env is defined
if value is not None and len(self._in_keys) != len(value):
raise ValueError(
"RewardSum expects the same number of input and output keys"
)
if value is not None:
if isinstance(value, (str, tuple)):
value = [value]
value = [unravel_key(val) for val in value]
self._out_keys = value

super().__init__(in_keys=in_keys, out_keys=out_keys)
@property
def reset_keys(self):
reset_keys = self.__dict__.get("_reset_keys", None)
if reset_keys is None:
parent = self.parent
if parent is None:
raise TypeError(
"reset_keys not provided but parent env not found. "
"Make sure that the reset_keys are provided during "
"construction if the transform does not have a container env."
)
reset_keys = copy(parent.reset_keys)
self._reset_keys = reset_keys
return reset_keys

@reset_keys.setter
def reset_keys(self, value):
if value is not None:
if isinstance(value, (str, tuple)):
value = [value]
value = [unravel_key(val) for val in value]
self._reset_keys = value

def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
"""Resets episode rewards."""
# Non-batched environments
_reset = tensordict.get("_reset", None)
if _reset is None:
_reset = torch.ones(
self.parent.done_spec.shape if self.parent else tensordict.batch_size,
dtype=torch.bool,
device=tensordict.device,
)
for in_key, reset_key, out_key in zip(
self.in_keys, self.reset_keys, self.out_keys
):
_reset = tensordict.get(reset_key, None)

if _reset.any():
_reset = _reset.sum(
tuple(range(tensordict.batch_dims, _reset.ndim)), dtype=torch.bool
)
reward_key = self.parent.reward_key if self.parent else "reward"
for in_key, out_key in zip(self.in_keys, self.out_keys):
if out_key in tensordict.keys(True, True):
value = tensordict[out_key]
tensordict[out_key] = value.masked_fill(
expand_as_right(_reset, value), 0.0
)
elif unravel_key(in_key) == unravel_key(reward_key):
if _reset is None or _reset.any():
value = tensordict.get(out_key, default=None)
if value is not None:
if _reset is None:
tensordict.set(out_key, torch.zeros_like(value))
else:
tensordict.set(
out_key,
value.masked_fill(
expand_as_right(_reset.squeeze(-1), value), 0.0
),
)
else:
# Since the episode reward is not in the tensordict, we need to allocate it
# with zeros entirely (regardless of the _reset mask)
tensordict[out_key] = self.parent.reward_spec.zero()
else:
try:
tensordict[out_key] = self.parent.observation_spec[
in_key
].zero()
except KeyError as err:
raise KeyError(
f"The key {in_key} was not found in the parent "
f"observation_spec with keys "
f"{list(self.parent.observation_spec.keys(True))}. "
) from err
tensordict.set(
out_key,
self.parent.full_reward_spec[in_key].zero(),
)
return tensordict

def _step(
@@ -4430,76 +4496,48 @@ def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec:
state_spec = input_spec["full_state_spec"]
if state_spec is None:
state_spec = CompositeSpec(shape=input_spec.shape, device=input_spec.device)
reward_spec = self.parent.output_spec["full_reward_spec"]
reward_spec_keys = list(reward_spec.keys(True, True))
state_spec.update(self._generate_episode_reward_spec())
input_spec["full_state_spec"] = state_spec
return input_spec

def _generate_episode_reward_spec(self) -> CompositeSpec:
episode_reward_spec = CompositeSpec()
reward_spec = self.parent.full_reward_spec
reward_spec_keys = self.parent.reward_keys
# Define episode specs for all out_keys
for in_key, out_key in zip(self.in_keys, self.out_keys):
if (
in_key in reward_spec_keys
): # if this out_key has a corresponding key in reward_spec
out_key = _unravel_key_to_tuple(out_key)
temp_state_spec = state_spec
temp_episode_reward_spec = episode_reward_spec
temp_rew_spec = reward_spec
for sub_key in out_key[:-1]:
if (
not isinstance(temp_rew_spec, CompositeSpec)
or sub_key not in temp_rew_spec.keys()
):
break
if sub_key not in temp_state_spec.keys():
temp_state_spec[sub_key] = temp_rew_spec[sub_key].empty()
if sub_key not in temp_episode_reward_spec.keys():
temp_episode_reward_spec[sub_key] = temp_rew_spec[
sub_key
].empty()
temp_rew_spec = temp_rew_spec[sub_key]
temp_state_spec = temp_state_spec[sub_key]
state_spec[out_key] = reward_spec[in_key].clone()
temp_episode_reward_spec = temp_episode_reward_spec[sub_key]
episode_reward_spec[out_key] = reward_spec[in_key].clone()
else:
raise ValueError(
f"The in_key: {in_key} is not present in the reward spec {reward_spec}."
)
input_spec["full_state_spec"] = state_spec
return input_spec
return episode_reward_spec

def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
"""Transforms the observation spec, adding the new keys generated by RewardSum."""
# Retrieve parent reward spec
reward_spec = self.parent.reward_spec
reward_key = self.parent.reward_key if self.parent else "reward"

episode_specs = {}
if isinstance(reward_spec, CompositeSpec):
# If reward_spec is a CompositeSpec, all in_keys should be keys of reward_spec
if not all(k in reward_spec.keys(True, True) for k in self.in_keys):
raise KeyError("Not all in_keys are present in ´reward_spec´")

# Define episode specs for all out_keys
for out_key in self.out_keys:
episode_spec = UnboundedContinuousTensorSpec(
shape=reward_spec.shape,
device=reward_spec.device,
dtype=reward_spec.dtype,
)
episode_specs.update({out_key: episode_spec})

else:
# If reward_spec is not a CompositeSpec, the only in_key should be ´reward´
if set(unravel_key_list(self.in_keys)) != {unravel_key(reward_key)}:
raise KeyError(
"reward_spec is not a CompositeSpec class, in_keys should only include ´reward´"
)

# Define episode spec
episode_spec = UnboundedContinuousTensorSpec(
device=reward_spec.device,
dtype=reward_spec.dtype,
shape=reward_spec.shape,
)
episode_specs.update({self.out_keys[0]: episode_spec})

# Update observation_spec with episode_specs
if not isinstance(observation_spec, CompositeSpec):
observation_spec = CompositeSpec(
observation=observation_spec, shape=self.parent.batch_size
)
observation_spec.update(episode_specs)
observation_spec.update(self._generate_episode_reward_spec())
return observation_spec

def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
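To summarize, here is a minimal standalone sketch of the per-reward-key accumulate-and-reset behaviour described in the docstring above, assuming a plain dict of tensors instead of a TensorDict and one boolean reset flag per reward key; the function name accumulate and the keys are illustrative, not TorchRL API.

import torch

def accumulate(episode_rewards, rewards, resets):
    # episode_rewards, rewards: dicts mapping reward key -> tensor of the same shape
    # resets: dict mapping reward key -> boolean tensor of the same shape
    out = {}
    for key, reward in rewards.items():
        running = episode_rewards.get(key, torch.zeros_like(reward))
        # zero the running sum for environments that were just reset ...
        running = running.masked_fill(resets[key], 0.0)
        # ... then accumulate the reward of the current step
        out[key] = running + reward
    return out

rewards = {"reward": torch.ones(3, 1)}
resets = {"reward": torch.tensor([[False], [True], [False]])}
totals = accumulate({}, rewards, resets)       # [[1.], [1.], [1.]]
totals = accumulate(totals, rewards, resets)   # env 1 reset again: [[2.], [1.], [2.]]
print(totals["reward"])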
