Rename rematerialization of saved for backward symbols #1367
base: main
Conversation
Other reviewers, please don't merge this PR without my review.
thunder/core/transforms.py (outdated diff)
```diff
@@ -3148,6 +3148,9 @@ def recompute_saved_for_backward(fwd_trace: Trace, bwd_trace: Trace) -> tuple[Tr

     producers = find_producer_symbols(fwd_trace, tuple(unvariableify(i) for i in rematerializable), fwd_trace.args)

+    trace_tok = set_tracectx(bwd_trace)
```
Please set and reset traces only with `try: finally:` blocks. If an error is raised between the two calls, the trace will never be reset.
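A minimal sketch of the suggested pattern (assuming `set_tracectx` and `reset_tracectx` from `thunder.core.trace`, with `bwd_trace` standing in for the trace being set):

```python
from thunder.core.trace import set_tracectx, reset_tracectx

trace_tok = set_tracectx(bwd_trace)
try:
    ...  # any work that needs the active trace context goes here
finally:
    # Runs even if the block above raises, so the context is always reset.
    reset_tracectx(trace_tok)
```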
Why do you set the input `bwd_trace` as the active trace? There are no Thunder operation calls between the set and the reset, and the input trace shouldn't be modified.
Would we not use `with tracectx(bwd_trace):` instead?
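For reference, a sketch of that alternative (assuming `tracectx` is the context-manager wrapper in `thunder.core.trace` around the set/reset pair):

```python
from thunder.core.trace import tracectx

# The context manager pairs the set with the reset automatically,
# including on error, so no explicit try/finally is needed.
with tracectx(bwd_trace):
    ...  # work that needs the active trace context
```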
For this I've taken inspiration from the code in the torch_autograd executor; in particular, these lines explain why the trace context needs to be set:
lightning-thunder/thunder/executors/torch_autograd.py
Lines 33 to 40 in 3390c92
```python
# [note: why setting trace ctx?]
# [`TensorProxy.replace_name`](https://github.com/Lightning-AI/lightning-thunder/blob/561b699/thunder/core/proxies.py#L1221-L1223) calls
# [`tensorproxy`](https://github.com/Lightning-AI/lightning-thunder/blob/561b699/thunder/core/proxies.py#L1506-L1520)
# which then calls `TensorProxy.__init__`. `TensorProxy.__init__` of course calls
# [`Proxy.__init__`](https://github.com/Lightning-AI/lightning-thunder/blob/561b699/thunder/core/proxies.py#L81-L86).
# `Proxy`'s dunder init calls [`make_proxy_name`](https://github.com/Lightning-AI/lightning-thunder/blob/561b699/thunder/core/proxies.py#L81-L86)
# which depends on a tracectx.
trace_tok = set_tracectx(bwd_trace)
```
@IvanYashchuk Would an acceptable workaround be to create a new empty trace and use it as ctx?
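If that's acceptable, a hedged sketch of the workaround (assuming `TraceCtx` in `thunder.core.trace` can be constructed empty):

```python
from thunder.core.trace import TraceCtx, tracectx

# Use a fresh, throwaway trace as the active context so the input
# bwd_trace is left unmodified while proxy names are registered.
with tracectx(TraceCtx()):
    ...  # e.g. renaming proxies via TensorProxy.replace_name
```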
Yes, or allow creating Proxies with any name without an active tracectx. Maybe all that is needed is to return `True` if `trc is None` in this function:
lightning-thunder/thunder/core/proxies.py
Lines 75 to 82 in 3390c92
```python
def register_proxy_name(name: None | str = None):
    trc = get_tracectx()
    if name is not None and not trc.has_name(name):
        trc.add_name(name)
        return True
    return False
```
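A minimal sketch of that change (modifying the `register_proxy_name` quoted above; the early return for a missing trace context is the only addition):

```python
def register_proxy_name(name: None | str = None):
    trc = get_tracectx()
    # Assumption per the suggestion above: with no active trace context,
    # accept any name instead of failing on trc being None.
    if trc is None:
        return True
    if name is not None and not trc.has_name(name):
        trc.add_name(name)
        return True
    return False
```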
Note that this does not fix #1232, but it helps with debugging it by producing an overall clearer backward trace when rematerialization of saved-for-backward tensors is enabled.
This is part of #1232. The PR renames the outputs of recomputed symbols so that they do not overlap with names used in the forward trace; fusion rematerialization requires the names used in producer and consumer fusions to be unique.
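For illustration only, a hedged sketch of the renaming idea (the `_remat` suffix and the loose `proxy` variable are hypothetical, not the PR's actual scheme):

```python
from thunder.core.trace import tracectx

# replace_name builds a new proxy and therefore needs an active trace
# context (see the torch_autograd note above).
with tracectx(bwd_trace):
    renamed = proxy.replace_name(proxy.name + "_remat")
```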