misc. fixes to unflatten (pytorch#141066)
Handling of nested modules in unflatten had several bugs, which were caught by trying to preserve module call signatures for nested modules.
* A module `k` that was first encountered through a nested call like `k.n()`, before `k()` itself was called, used to become an empty `nn.Module`, so information was dropped when `k()` was eventually called (see the sketch below). Relatedly, we also lost call counts for `k.n()` reached through different paths (say, when `k()` itself calls `n()`).
* Deleting call-indexed modules and patching up their call sites was broken for nested modules when creating dispatcher modules, because their fully qualified names (fqns) were handled incorrectly.
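
As a rough illustration of the first bug, here is a minimal sketch (with hypothetical module names, not taken from this PR's tests) of the call pattern that used to leave `k` behind as an empty module after unflattening:

```python
import torch

class N(torch.nn.Module):
    def forward(self, x):
        return x + 1

class K(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.n = N()

    def forward(self, x):
        return self.n(x) + 1

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.k = K()

    def forward(self, x):
        # `k` is first encountered through the nested call `k.n()` ...
        x = self.k.n(x)
        # ... and only called directly afterwards; unflatten used to create an
        # empty module for `k` at the first encounter and drop information here.
        return self.k(x)

ep = torch.export.export(M(), (torch.ones(1),))
ufm = torch.export.unflatten(ep)
assert torch.allclose(ufm(torch.ones(1)), M()(torch.ones(1)))
```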

An interesting aside is that we used random graph generation for testing some of these changes. A future PR will add the infra to create tests using these random graphs.
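
The new tests below encode such graphs as adjacency dicts, e.g. `{0: [1, 2, 3], 1: [2, 4], 2: [4], 3: [], 4: []}`, where module `i` only ever calls modules with larger indices. A generator along those lines, purely a sketch since the actual infra only lands in that future PR, might look like:

```python
import random
from typing import Dict, List

def random_module_dag(num_modules: int, seed: int = 0) -> Dict[int, List[int]]:
    # Hypothetical helper, not part of this PR: module i may only call modules
    # with a larger index, so the resulting call graph is acyclic.
    rng = random.Random(seed)
    dag: Dict[int, List[int]] = {}
    for i in range(num_modules):
        candidates = range(i + 1, num_modules)
        dag[i] = sorted(rng.sample(candidates, rng.randint(0, len(candidates))))
    return dag

print(random_module_dag(5))  # e.g. {0: [1, 3], 1: [2, 4], 2: [], 3: [4], 4: []}
```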

Differential Revision: D66192799

Pull Request resolved: pytorch#141066
Approved by: https://github.com/angelayi
avikchaudhuri authored and pytorchmergebot committed Nov 23, 2024
1 parent 5268754 commit 8b4ae29
Showing 4 changed files with 251 additions and 48 deletions.
206 changes: 181 additions & 25 deletions test/export/test_export.py
@@ -6793,6 +6793,131 @@ def forward(self, x):

        self.assertEqual(gm_flat_non_strict(*inp), gm_flat_strict(*inp))

    def test_unflatten_random_dag_5_modules(self):
        # dag: {0: [1, 2, 3], 1: [2, 4], 2: [4], 3: [], 4: []}

        class N4(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x + 1

        class N3(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.n4 = N4()

            def forward(self, x):
                return x + 1

        class N2(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.n3 = N3()

            def forward(self, x):
                x = self.n3.n4(x + 1)
                return x + 1

        class N1(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.n2 = N2()

            def forward(self, x):
                x = self.n2(x + 1)
                x = self.n2.n3.n4(x + 1)
                return x + 1

        class N0(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.n1 = N1()

            def forward(self, x):
                x = self.n1(x + 1)
                x = self.n1.n2(x + 1)
                x = self.n1.n2.n3(x + 1)
                return x + 1

        n0 = N0()
        inp = (torch.ones(1),)
        eager = n0(*inp)
        ep = torch.export.export(n0, inp)
        epm = ep.module()
        ufm = torch.export.unflatten(ep)
        self.assertTrue(torch.allclose(epm(*inp), eager))
        self.assertTrue(torch.allclose(ufm(*inp), eager))

    def test_unflatten_random_dag_6_modules(self):
        # dag: {0: [1, 2, 4, 5], 1: [3, 5], 2: [4, 5], 3: [], 4: [5], 5: []}

        class N5(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x + 1

        class N4(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.n5 = N5()

            def forward(self, x):
                x = self.n5(x + 1)
                return x + 1

        class N3(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.n4 = N4()

            def forward(self, x):
                return x + 1

        class N2(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.n3 = N3()

            def forward(self, x):
                x = self.n3.n4(x + 1)
                x = self.n3.n4.n5(x + 1)
                return x + 1

        class N1(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.n2 = N2()

            def forward(self, x):
                x = self.n2.n3(x + 1)
                x = self.n2.n3.n4.n5(x + 1)
                return x + 1

        class N0(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.n1 = N1()

            def forward(self, x):
                x = self.n1(x + 1)
                x = self.n1.n2(x + 1)
                x = self.n1.n2.n3.n4(x + 1)
                x = self.n1.n2.n3.n4.n5(x + 1)
                return x + 1

        n0 = N0()
        inp = (torch.ones(1),)
        eager = n0(*inp)
        ep = torch.export.export(n0, inp)
        epm = ep.module()
        ufm = torch.export.unflatten(ep)
        self.assertTrue(torch.allclose(epm(*inp), eager))
        self.assertTrue(torch.allclose(ufm(*inp), eager))

    def test_unflatten_no_unroll(self):
        inp = (torch.ones(1),)

@@ -6808,7 +6933,15 @@ def forward(self, x, b):
                else:
                    return x + 2 * (self.buf + 1) - self.const

        class M(torch.nn.Module):
        class K(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.n = N()

            def forward(self, x0):
                return self.n(x0, True)

        class P(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.n = N()
@@ -6819,29 +6952,27 @@ def forward(self, x):
                x2 = self.n(x0, False)
                return x1 + x2

        m = M()
        eager_result = m(*inp)

        def test(ep, swap):
            epm = ep.module()
            ufm = torch.export.unflatten(ep)

            exported_result = epm(*inp)
            self.assertTrue(torch.allclose(exported_result, eager_result))
        class Q(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.k = K()

            unflattened_result = ufm(*inp)
            self.assertTrue(torch.allclose(unflattened_result, eager_result))
            def forward(self, x):
                x0 = x + 3
                x1 = self.k.n(x0, True)
                x2 = self.k.n(x0, False)
                return x1 + x2

            for fqn, mod in swap.items():
                ufm.set_submodule(fqn, mod)
            unflattened_result = ufm(*inp)
            self.assertTrue(torch.allclose(unflattened_result, eager_result))
        class R(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.k = K()

        if not is_retracebility_test(self._testMethodName):
            test(
                export(M(), inp, preserve_module_call_signature=("n",)),
                swap={"n": N()},
            )
            def forward(self, x):
                x0 = x + 3
                x1 = self.k(x0)
                x2 = self.k.n(x0, False)
                return x1 + x2

        class _N(torch.nn.Module):
            def forward(self, x):
Expand All @@ -6851,10 +6982,35 @@ class _N_1(torch.nn.Module):
            def forward(self, x):
                return x + 6

        test(
            export(M(), inp),
            swap={"n": _N(), "n@1": _N_1()},
        )
        for Mod, path_n in [(P, "n"), (Q, "k.n"), (R, "k.n")]:
            m = Mod()
            eager_result = m(*inp)

            def test(ep, swap):
                epm = ep.module()
                ufm = torch.export.unflatten(ep)

                exported_result = epm(*inp)
                self.assertTrue(torch.allclose(exported_result, eager_result))

                unflattened_result = ufm(*inp)
                self.assertTrue(torch.allclose(unflattened_result, eager_result))

                for fqn, mod in swap.items():
                    ufm.set_submodule(fqn, mod)
                unflattened_result = ufm(*inp)
                self.assertTrue(torch.allclose(unflattened_result, eager_result))

            if not is_retracebility_test(self._testMethodName):
                test(
                    export(Mod(), inp, preserve_module_call_signature=(path_n,)),
                    swap={path_n: N()},
                )

            test(
                export(Mod(), inp),
                swap={path_n: _N(), path_n + "@1": _N_1()},
            )

    def test_preserve_module_call_signature_unflatten_specialization(self):
        class N(torch.nn.Module):
1 change: 1 addition & 0 deletions torch/_dynamo/symbolic_convert.py
@@ -3260,6 +3260,7 @@ def __init__(
            distributed_state=parent.distributed_state,
        )
        self.parent = parent
        self.num_calls = parent.num_calls
        self.symbolic_result = None
        self.nn_module_stack = parent.nn_module_stack.copy()
        self.one_graph = parent.one_graph
2 changes: 2 additions & 0 deletions torch/distributed/pipelining/_unflatten.py
@@ -13,12 +13,14 @@ def _outline_submodules(orig_graph: torch.fx.Graph):
    seen_nodes: Dict[str, torch.fx.Node] = {}
    seen_modules: Dict[int, List[_SubmoduleEntry]] = defaultdict(list)
    seen_attrs: Dict[str, Set[str]] = defaultdict(set)
    created_modules: Dict[str, torch.nn.Module] = {}
    _ModuleFrame(
        orig_graph,
        tuple(orig_graph.nodes),
        seen_nodes,
        seen_modules,
        seen_attrs,
        created_modules,
        None,
        [("", 0)],
        "",
