[BugFix] Fixed shape for MultiStep returns + Distributional loss #2270
Hi @roger-creus! Thank you for your pull request and welcome to our community.
Action Required
In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.
Process
In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.
Once the CLA is signed, our tooling will perform checks and validations. If you have received this in error or have any questions, please contact us at [email protected]. Thanks!
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!
Thanks for this contribution! I left a couple of comments, would you have time to address them?
torchrl/objectives/dqn.py
@@ -560,7 +560,7 @@ def forward(self, input_tensordict: TensorDictBase) -> TensorDict:
 support = support.to("cpu")
 pns_a = pns_a.to("cpu")

-Tz = reward + (1 - terminated.to(reward.dtype)) * discount * support
+Tz = reward + (1 - terminated.to(reward.dtype)) * discount.unsqueeze(-1) * support.repeat(batch_size, 1)
Looks good
I think we're making a couple of assumptions about discount and support here.
I would feel safer if we had a block before this where we explicitly check the shapes:
    if not check(discount.shape):
        raise RuntimeError(f"Expected `discount` to have either shape X or Y, got {discount.shape}")
    if not other_check(support.shape):
        raise RuntimeError(f"Expected `support` to have either shape X or Y, got {support.shape}")
(check and other_check are just placeholders)
If possible, I would also make an instance of this loss class with the support and discount shapes that you are using, to test that it works OK (see test_cost.py:TestDQN:test_distributional_dqn).
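One way the placeholder checks above could be filled in is sketched below. This is only an illustration, not the code that was merged: the helper names `check_discount_shape`, `check_support_shape` and `validate_shapes` are hypothetical, and the accepted shapes are assumptions chosen to be compatible with `discount.unsqueeze(-1)` and `support.repeat(batch_size, 1)` (a scalar or per-sample discount, and a flat or single-row support). The helpers work on plain shape tuples, so `discount.shape` and `support.shape` can be passed directly.

```python
def check_discount_shape(shape, batch_size):
    # Assumption: discount is either a scalar () or one value per sample (batch_size,).
    return tuple(shape) in ((), (batch_size,))


def check_support_shape(shape, n_atoms):
    # Assumption: support is flat (n_atoms,) or a single row (1, n_atoms),
    # so that support.repeat(batch_size, 1) yields (batch_size, n_atoms).
    return tuple(shape) in ((n_atoms,), (1, n_atoms))


def validate_shapes(discount_shape, support_shape, batch_size, n_atoms):
    # Raise early with an explicit message instead of failing inside broadcasting.
    if not check_discount_shape(discount_shape, batch_size):
        raise RuntimeError(
            f"Expected `discount` to have either shape () or ({batch_size},), "
            f"got {tuple(discount_shape)}"
        )
    if not check_support_shape(support_shape, n_atoms):
        raise RuntimeError(
            f"Expected `support` to have either shape ({n_atoms},) or (1, {n_atoms}), "
            f"got {tuple(support_shape)}"
        )


# Example values (assumed): a batch of 32 samples and 51 atoms.
validate_shapes(discount_shape=(32,), support_shape=(51,), batch_size=32, n_atoms=51)
```

Keeping the checks on shape tuples means they can be unit-tested without building tensors, while the error messages follow the format suggested in the review.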
We should also make sure that discount is a tensor:
discount = torch.as_tensor(discount)
should do it (it's a no-op if it's already a tensor)
Description
I encountered an error when combining MultiStep returns with the distributional DQN loss. This line assumes discount is a scalar, but when using MultiStep returns it is a tensor of size (batch_size,). This PR only changes that line to:
Tz = reward + (1 - terminated.to(reward.dtype)) * discount.unsqueeze(-1) * support.repeat(batch_size, 1)
which gives the same results if the discount is a scalar but fixes the shape error if it is a batch.
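The shape arithmetic behind this change can be sketched without any tensors. The example below is a rough illustration under stated assumptions: `reward` and `terminated` are taken to have shape (batch_size, 1), `support` shape (n_atoms,), with 32 samples and 51 atoms; none of these values come from the PR. `broadcast_shape` is a hypothetical helper that applies the usual trailing-dimension broadcasting rule to shape tuples; it is not part of TorchRL or PyTorch.

```python
from itertools import zip_longest


def broadcast_shape(*shapes):
    """Broadcast shape tuples by aligning trailing dimensions; raise ValueError on mismatch."""
    result = []
    # Walk dimensions from last to first, treating missing dimensions as size 1.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"shapes {shapes} are not broadcastable")
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))


batch_size, n_atoms = 32, 51          # assumed example sizes
reward = (batch_size, 1)              # assumed shape of reward and terminated
support = (n_atoms,)

# Old line with a scalar discount: (32, 1) * () * (51,) broadcasts fine.
old_scalar = broadcast_shape(reward, (), support)

# Old line with a MultiStep discount of shape (32,): 32 and 51 clash.
try:
    broadcast_shape(reward, (batch_size,), support)
    old_batched_ok = True
except ValueError:
    old_batched_ok = False

# New line: discount.unsqueeze(-1) is (32, 1), support.repeat(batch_size, 1) is (32, 51).
new_batched = broadcast_shape(reward, (batch_size, 1), (batch_size, n_atoms))

# New line with a 0-dim tensor discount: unsqueeze(-1) gives shape (1,).
new_scalar = broadcast_shape(reward, (1,), (batch_size, n_atoms))

print(old_scalar, old_batched_ok, new_batched, new_scalar)
# prints: (32, 51) False (32, 51) (32, 51)
```

Under these assumptions the new expression yields (batch_size, n_atoms) in both the scalar and the batched case, while the old expression only works for the scalar one.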
Solves issue #2269