Skip to content

Commit

Permalink
[inductor][debug] fix draw_buffers (pytorch#135266)
Browse files Browse the repository at this point in the history
  • Loading branch information
xuanzhang816 authored and pytorchmergebot committed Sep 6, 2024
1 parent 5f57be7 commit c05a7ad
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions torch/_inductor/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def func1(*args: Any) -> int:
FusionMeta = collections.namedtuple("FusionMeta", ["group", "snode", "type"])

buf_to_fx_node = {}
node_to_fx_node = {}
graph = torch.fx.Graph()
first_node = None

Expand Down Expand Up @@ -162,10 +163,9 @@ def in_output(snode: Union[BaseSchedulerNode, FusedSchedulerNode]) -> bool:

fx_node.meta["fusion_meta"] = FusionMeta(group, snode, node_type)

if isinstance(snode, FusedSchedulerNode):
for x in snode.snodes:
buf_to_fx_node[x.get_name()] = fx_node
buf_to_fx_node[name] = fx_node
node_to_fx_node[name] = fx_node
for buf in snode.get_outputs():
buf_to_fx_node[buf.get_name()] = fx_node

if first_node is None:
first_node = fx_node
Expand All @@ -175,7 +175,7 @@ def in_output(snode: Union[BaseSchedulerNode, FusedSchedulerNode]) -> bool:
name = snode.get_name()
deps = snode.read_writes.reads

fx_node = buf_to_fx_node[name]
fx_node = node_to_fx_node[name]
new_args = []
for dep in deps:
if dep.name in buf_to_fx_node:
Expand All @@ -184,6 +184,8 @@ def in_output(snode: Union[BaseSchedulerNode, FusedSchedulerNode]) -> bool:
with graph.inserting_before(first_node):
dep_node = graph.placeholder(dep.name)
buf_to_fx_node[dep.name] = dep_node
if dep_node == fx_node: # to avoid cycles
continue
new_args.append(dep_node)

fx_node.args = tuple(new_args)
Expand Down

0 comments on commit c05a7ad

Please sign in to comment.