[BugFix] Account for terminating data in SAC losses
ghstack-source-id: dc1870292786c262b4ab6a221b3afb551e0efb9b
Pull Request resolved: #2606
vmoens committed Nov 25, 2024
1 parent d90b9e3 commit c8676f4
Showing 2 changed files with 162 additions and 8 deletions.
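For orientation, the fix follows a select, compute, scatter pattern: rows whose `"next"` entry is flagged as done are filtered out before the actor is asked for a next action (and, in the discrete loss, before the target Q-network is queried), and the results are scattered back into a zero-filled tensor of the full batch shape. The sketch below illustrates that pattern in isolation; it is not TorchRL code, and the helper names (`bootstrap_values`, `value_fn`) are made up for the example.

```python
import torch


def bootstrap_values(next_obs, done, value_fn):
    """Hypothetical helper: query ``value_fn`` only on non-terminal next states."""
    mask = ~done.squeeze(-1)  # [B], True where the episode continues
    values = next_obs.new_zeros(done.shape)  # [B, 1], zeros stand in for terminal rows
    if mask.any():
        # Terminal next states may hold arbitrary data (even NaN); never feed them to the network.
        values[mask] = value_fn(next_obs[mask])
    return values


# NaN-poisoned terminal next states still produce a finite result.
value_fn = lambda obs: obs.sum(-1, keepdim=True)
next_obs = torch.randn(4, 3)
done = torch.tensor([[False], [True], [False], [True]])
next_obs[done.squeeze(-1)] = float("nan")
assert bootstrap_values(next_obs, done, value_fn).isfinite().all()
```

The placeholder zeros for terminal rows should not influence the loss, since the TD target masks out the bootstrap term for terminated steps.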
119 changes: 119 additions & 0 deletions test/test_cost.py
@@ -4459,6 +4459,69 @@ def test_sac_notensordict(
         assert loss_actor == loss_val_td["loss_actor"]
         assert loss_alpha == loss_val_td["loss_alpha"]

+    @pytest.mark.parametrize("action_key", ["action", "action2"])
+    @pytest.mark.parametrize("observation_key", ["observation", "observation2"])
+    @pytest.mark.parametrize("reward_key", ["reward", "reward2"])
+    @pytest.mark.parametrize("done_key", ["done", "done2"])
+    @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"])
+    def test_sac_terminating(
+        self, action_key, observation_key, reward_key, done_key, terminated_key, version
+    ):
+        torch.manual_seed(self.seed)
+        td = self._create_mock_data_sac(
+            action_key=action_key,
+            observation_key=observation_key,
+            reward_key=reward_key,
+            done_key=done_key,
+            terminated_key=terminated_key,
+        )
+
+        actor = self._create_mock_actor(
+            observation_key=observation_key, action_key=action_key
+        )
+        qvalue = self._create_mock_qvalue(
+            observation_key=observation_key,
+            action_key=action_key,
+            out_keys=["state_action_value"],
+        )
+        if version == 1:
+            value = self._create_mock_value(observation_key=observation_key)
+        else:
+            value = None
+
+        loss = SACLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+            value_network=value,
+        )
+        loss.set_keys(
+            action=action_key,
+            reward=reward_key,
+            done=done_key,
+            terminated=terminated_key,
+        )
+
+        torch.manual_seed(self.seed)
+
+        SoftUpdate(loss, eps=0.5)
+
+        done = td.get(("next", done_key))
+        while not (done.any() and not done.all()):
+            done.bernoulli_(0.1)
+        obs_nan = td.get(("next", terminated_key))
+        obs_nan[done.squeeze(-1)] = float("nan")
+
+        kwargs = {
+            action_key: td.get(action_key),
+            observation_key: td.get(observation_key),
+            f"next_{reward_key}": td.get(("next", reward_key)),
+            f"next_{done_key}": done,
+            f"next_{terminated_key}": obs_nan,
+            f"next_{observation_key}": td.get(("next", observation_key)),
+        }
+        td = TensorDict(kwargs, td.batch_size).unflatten_keys("_")
+        assert loss(td).isfinite().all()
+
     def test_state_dict(self, version):
         if version == 1:
             pytest.skip("Test not implemented for version 1.")
@@ -5112,6 +5175,62 @@ def test_discrete_sac_notensordict(
         assert loss_actor == loss_val_td["loss_actor"]
         assert loss_alpha == loss_val_td["loss_alpha"]

+    @pytest.mark.parametrize("action_key", ["action", "action2"])
+    @pytest.mark.parametrize("observation_key", ["observation", "observation2"])
+    @pytest.mark.parametrize("reward_key", ["reward", "reward2"])
+    @pytest.mark.parametrize("done_key", ["done", "done2"])
+    @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"])
+    def test_discrete_sac_terminating(
+        self, action_key, observation_key, reward_key, done_key, terminated_key
+    ):
+        torch.manual_seed(self.seed)
+        td = self._create_mock_data_sac(
+            action_key=action_key,
+            observation_key=observation_key,
+            reward_key=reward_key,
+            done_key=done_key,
+            terminated_key=terminated_key,
+        )
+
+        actor = self._create_mock_actor(
+            observation_key=observation_key, action_key=action_key
+        )
+        qvalue = self._create_mock_qvalue(
+            observation_key=observation_key,
+        )
+
+        loss = DiscreteSACLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+            num_actions=actor.spec[action_key].space.n,
+            action_space="one-hot",
+        )
+        loss.set_keys(
+            action=action_key,
+            reward=reward_key,
+            done=done_key,
+            terminated=terminated_key,
+        )
+
+        SoftUpdate(loss, eps=0.5)
+
+        torch.manual_seed(0)
+        done = td.get(("next", done_key))
+        while not (done.any() and not done.all()):
+            done = done.bernoulli_(0.1)
+        obs_none = td.get(("next", observation_key))
+        obs_none[done.squeeze(-1)] = float("nan")
+        kwargs = {
+            action_key: td.get(action_key),
+            observation_key: td.get(observation_key),
+            f"next_{reward_key}": td.get(("next", reward_key)),
+            f"next_{done_key}": done,
+            f"next_{terminated_key}": td.get(("next", terminated_key)),
+            f"next_{observation_key}": obs_none,
+        }
+        td = TensorDict(kwargs, td.batch_size).unflatten_keys("_")
+        assert loss(td).isfinite().all()
+
     @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
     def test_discrete_sac_reduction(self, reduction):
         torch.manual_seed(self.seed)
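A side note on how the new tests assemble their inputs: the flat `next_<key>` strings are converted into nested `("next", <key>)` entries by `unflatten_keys("_")`, which is the layout the loss modules read. A small self-contained illustration with toy shapes (not the mock data used above):

```python
import torch
from tensordict import TensorDict

batch = 4
flat = TensorDict(
    {
        "observation": torch.randn(batch, 3),
        "action": torch.randn(batch, 1),
        "next_observation": torch.randn(batch, 3),
        "next_reward": torch.randn(batch, 1),
        "next_done": torch.zeros(batch, 1, dtype=torch.bool),
        "next_terminated": torch.zeros(batch, 1, dtype=torch.bool),
    },
    batch_size=[batch],
)
nested = flat.unflatten_keys("_")  # "next_reward" becomes ("next", "reward"), etc.
assert ("next", "reward") in nested.keys(include_nested=True)
```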
51 changes: 43 additions & 8 deletions torchrl/objectives/sac.py
Expand Up @@ -16,7 +16,7 @@
 from tensordict import TensorDict, TensorDictBase, TensorDictParams

 from tensordict.nn import dispatch, TensorDictModule
-from tensordict.utils import NestedKey
+from tensordict.utils import expand_right, NestedKey
 from torch import Tensor
 from torchrl.data.tensor_specs import Composite, TensorSpec
 from torchrl.data.utils import _find_action_space
@@ -711,13 +711,37 @@ def _compute_target_v2(self, tensordict) -> Tensor:
         with set_exploration_type(
             ExplorationType.RANDOM
         ), self.actor_network_params.to_module(self.actor_network):
-            next_tensordict = tensordict.get("next").clone(False)
-            next_dist = self.actor_network.get_dist(next_tensordict)
+            next_tensordict = tensordict.get("next").copy()
+            # Check done state and avoid passing these to the actor
+            done = next_tensordict.get(self.tensor_keys.done)
+            if done is not None and done.any():
+                next_tensordict_select = next_tensordict[~done.squeeze(-1)]
+            else:
+                next_tensordict_select = next_tensordict
+            next_dist = self.actor_network.get_dist(next_tensordict_select)
             next_action = next_dist.rsample()
-            next_tensordict.set(self.tensor_keys.action, next_action)
             next_sample_log_prob = compute_log_prob(
                 next_dist, next_action, self.tensor_keys.log_prob
             )
+            if next_tensordict_select is not next_tensordict:
+                mask = ~done.squeeze(-1)
+                if mask.ndim < next_action.ndim:
+                    mask = expand_right(
+                        mask, (*mask.shape, *next_action.shape[mask.ndim :])
+                    )
+                next_action = next_action.new_zeros(mask.shape).masked_scatter_(
+                    mask, next_action
+                )
+                mask = ~done.squeeze(-1)
+                if mask.ndim < next_sample_log_prob.ndim:
+                    mask = expand_right(
+                        mask,
+                        (*mask.shape, *next_sample_log_prob.shape[mask.ndim :]),
+                    )
+                next_sample_log_prob = next_sample_log_prob.new_zeros(
+                    mask.shape
+                ).masked_scatter_(mask, next_sample_log_prob)
+            next_tensordict.set(self.tensor_keys.action, next_action)

             # get q-values
             next_tensordict_expand = self._vmap_qnetworkN0(
@@ -1194,15 +1218,21 @@ def _compute_target(self, tensordict) -> Tensor:
         with torch.no_grad():
             next_tensordict = tensordict.get("next").clone(False)

+            done = next_tensordict.get(self.tensor_keys.done)
+            if done is not None and done.any():
+                next_tensordict_select = next_tensordict[~done.squeeze(-1)]
+            else:
+                next_tensordict_select = next_tensordict
+
             # get probs and log probs for actions computed from "next"
             with self.actor_network_params.to_module(self.actor_network):
-                next_dist = self.actor_network.get_dist(next_tensordict)
-                next_prob = next_dist.probs
-                next_log_prob = torch.log(torch.where(next_prob == 0, 1e-8, next_prob))
+                next_dist = self.actor_network.get_dist(next_tensordict_select)
+                next_log_prob = next_dist.logits
+                next_prob = next_log_prob.exp()

             # get q-values for all actions
             next_tensordict_expand = self._vmap_qnetworkN0(
-                next_tensordict, self.target_qvalue_network_params
+                next_tensordict_select, self.target_qvalue_network_params
             )
             next_action_value = next_tensordict_expand.get(
                 self.tensor_keys.action_value
@@ -1212,6 +1242,11 @@
             next_state_value = next_action_value.min(0)[0] - self._alpha * next_log_prob
             # unlike in continuous SAC, we can compute the exact expectation over all discrete actions
             next_state_value = (next_prob * next_state_value).sum(-1).unsqueeze(-1)
+            if next_tensordict_select is not next_tensordict:
+                mask = ~done.squeeze(-1)
+                next_state_value = next_state_value.new_zeros(
+                    mask.shape
+                ).masked_scatter_(mask, next_state_value)

             tensordict.set(
                 ("next", self.value_estimator.tensor_keys.value), next_state_value
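The sketch below mirrors the re-expansion idiom the diff introduces, with toy shapes: `expand_right` broadcasts the boolean keep-mask over the trailing action dimensions, and `masked_scatter_` writes the per-subset results back into a zero tensor of the full batch shape, so downstream code sees unchanged shapes and terminal rows contribute zeros.

```python
import torch
from tensordict.utils import expand_right

# Values computed only for the non-terminal subset are scattered back into a
# zero tensor of the full batch shape.
batch, act_dim = 5, 3
done = torch.tensor([True, False, False, True, False]).unsqueeze(-1)  # [5, 1]
mask = ~done.squeeze(-1)  # [5], True where the episode continues
selected_actions = torch.randn(int(mask.sum()), act_dim)  # actions for the 3 live rows

mask_exp = expand_right(mask, (*mask.shape, act_dim))  # [5, 3]
full_actions = selected_actions.new_zeros(mask_exp.shape).masked_scatter_(
    mask_exp, selected_actions
)
assert full_actions[done.squeeze(-1)].eq(0).all()  # terminal rows stay zero
assert torch.equal(full_actions[mask], selected_actions)  # live rows keep their values
```

In the discrete loss only `next_state_value` needs this treatment, since the exact expectation over actions, `(next_prob * next_state_value).sum(-1)`, is taken on the selected rows before the scatter.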
