-
Notifications
You must be signed in to change notification settings - Fork 2
/
test_fake_pg.py
262 lines (234 loc) · 8.94 KB
/
test_fake_pg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
from datetime import timedelta
import torch
import os
from datetime import timedelta
import torch.multiprocessing as mp
import torch.distributed as dist
import torch.distributed._composable.fsdp
from torch._C._distributed_c10d import ProcessGroup, Work
from torch.futures import Future
from functools import wraps
from contextlib import contextmanager, nullcontext
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor
from torch.utils._python_dispatch import TorchDispatchMode
import logging
from torch.distributed.device_mesh import init_device_mesh, DeviceMesh
from typing import Optional, Callable, NamedTuple
from torch._guards import active_fake_mode
# Shorthand for the ATen operator namespace (not referenced anywhere else in the visible code).
aten = torch.ops.aten
from torch.distributed._functional_collectives import all_gather_tensor_inplace
import torch.distributed._functional_collectives_impl as func_col_impl
# NOTE(review): import-time side effect — overwrites a private module attribute,
# presumably to force functional collectives onto their "native" implementation
# path. Confirm the attribute still exists (and is a plain flag, not a function)
# in the installed torch version; private switches like this can change silently.
func_col_impl._use_native_funcol = True
class IgnoreDistMode(TorchDispatchMode):
    """Dispatch mode that logs every intercepted op and then runs it unchanged.

    Despite the name, nothing is ignored: every op is executed through ``func``
    and its result is returned as-is. The mode only adds diagnostics:

    * every op's name, type and overload object are logged;
    * for ``c10d::_allgather_base_`` the argument types, the two tensor sizes and
      the unboxed ``ProcessGroup`` (``args[2]``) are logged;
    * when an op returns a tuple whose second element is a TorchScript object
      (the ``(tensor, Work)`` pair that c10d collectives such as
      ``_allgather_base_`` return), that object is unboxed as a ``Work`` and its
      future is inspected — but only if INFO logging is enabled.
    """

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        logging.info(str(func.__name__))
        logging.info(type(func))
        logging.info(func)
        if func == torch.ops.c10d._allgather_base_.default:
            logging.info("Arg types: %s", [type(arg) for arg in args])
            logging.info("Arg 0 size: %s", args[0].size())
            logging.info("Arg 1 size: %s", args[1].size())
            # args[2] is the ProcessGroup, boxed as a TorchScript object.
            logging.info("Torch Script Inp Obj: %s", ProcessGroup.unbox(args[2]))
        res = func(*args, **kwargs)
        # Only inspect results that really carry a boxed Work handle: other
        # tuple-returning ops (e.g. aten.max.dim -> (values, indices)) would make
        # Work.unbox fail. The inspection is log-only, so skip it when INFO is off.
        if self._returns_boxed_work(res) and logging.getLogger().isEnabledFor(
            logging.INFO
        ):
            self._log_work(res)
        return res

    @staticmethod
    def _returns_boxed_work(res) -> bool:
        """Return True if ``res`` looks like a ``(value, <boxed Work>)`` tuple."""
        return (
            isinstance(res, tuple)
            and len(res) > 1
            and isinstance(res[1], torch.ScriptObject)
        )

    @staticmethod
    def _log_work(res) -> None:
        """Log the unboxed ``Work`` in ``res[1]`` and the state of its future."""
        logging.info(" Res types: %s", [type(r) for r in res])
        work = Work.unbox(res[1])
        future = work.get_future()
        logging.info("Torch Script Op Obj: %s", dir(work))
        logging.info("Future: %s", dir(future))
        done = future.done()
        logging.info("Future done: %s", done)
        if not done:
            # Future.value() is only defined once the future has completed
            # (an async collective may still be in flight here).
            return
        value = future.value()
        logging.info("Future value: %s", value)
        logging.info("Future value type: %s", type(value))
        if (
            isinstance(value, (list, tuple))
            and value
            and isinstance(value[0], torch.Tensor)
        ):
            logging.info("Future value size: %s", value[0].size())
@contextmanager
def bypass_collectives(device_mesh: Optional[DeviceMesh] = None):
class _SavedCollectives(NamedTuple):
all_gather_into_tensor: Callable
reduce_scatter_tensor: Callable
all_reduce: Callable
barrier: Callable
gather: Callable
scatter: Callable
broadcast: Callable
saved_collectives = _SavedCollectives(
dist.all_gather_into_tensor,
dist.reduce_scatter_tensor,
dist.all_reduce,
dist.barrier,
dist.gather,
dist.scatter,
dist.broadcast,
)
# in_fake_mode = bool(active_fake_mode())
in_fake_mode = True
print(in_fake_mode)
class FakeWork(Work):
def __init__(self):
super().__init__()
def get_future(self) -> Future:
future: Future = Future()
future.set_result(None)
return future
def wait(self, timeout: Optional[timedelta] = None) -> bool:
return True
@wraps(dist.all_gather_into_tensor)
def all_gather_into_tensor(
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
group=None,
async_op=False,
):
if in_fake_mode:
return FakeWork() if async_op else None
else:
return saved_collectives.all_gather_into_tensor(
output_tensor, input_tensor, group, async_op
)
@wraps(dist.reduce_scatter_tensor)
def reduce_scatter_tensor(
output: torch.Tensor,
input: torch.Tensor,
op=dist.ReduceOp.SUM,
group=None,
async_op=False,
):
if in_fake_mode:
return FakeWork() if async_op else None
else:
return saved_collectives.reduce_scatter_tensor(
output, input, op, group, async_op
)
@wraps(dist.all_reduce)
def all_reduce(
tensor: torch.Tensor,
op=dist.ReduceOp.SUM,
group=None,
async_op=False,
):
if in_fake_mode:
return FakeWork() if async_op else None
else:
return saved_collectives.all_reduce(tensor, op, group, async_op)
@wraps(dist.barrier)
def barrier(group=dist.GroupMember.WORLD, async_op=False, device_ids=None):
if in_fake_mode:
return None
else:
return saved_collectives.barrier(group, async_op, device_ids)
@wraps(dist.gather)
def gather(
tensor: torch.Tensor,
gather_list=None,
dst=0,
group=None,
async_op=False,
):
if in_fake_mode:
return FakeWork() if async_op else None
else:
return saved_collectives.gather(tensor, gather_list, dst, group, async_op)
@wraps(dist.scatter)
def scatter(
tensor: torch.Tensor,
scatter_list=None,
src=0,
group=None,
async_op=False,
):
print("Custom Scatter")
if in_fake_mode:
fake_work = FakeWork()
fake_work.__setattr__("getFuture", fake_work.get_future)
return fake_work if async_op else None
else:
return saved_collectives.scatter(tensor, scatter_list, src, group, async_op)
@wraps(dist.broadcast)
def broadcast(
tensor: torch.Tensor,
src=0,
group=None,
async_op=False,
):
if in_fake_mode:
return FakeWork() if async_op else None
else:
return saved_collectives.broadcast(tensor, src, group, async_op)
try:
dist.all_gather_into_tensor = all_gather_into_tensor
dist.reduce_scatter_tensor = reduce_scatter_tensor
dist.all_reduce = all_reduce
dist.barrier = barrier
dist.gather = gather
dist.scatter = scatter
dist.broadcast = broadcast
if device_mesh is not None:
dm_pgs = device_mesh.get_all_groups()
for pg in dm_pgs:
object.__setattr__(pg, "barrier", barrier)
object.__setattr__(pg, "gather", gather)
object.__setattr__(pg, "scatter", scatter)
object.__setattr__(pg, "broadcast", broadcast)
yield
finally:
dist.all_gather_into_tensor = saved_collectives.all_gather_into_tensor
dist.reduce_scatter_tensor = saved_collectives.reduce_scatter_tensor
dist.all_reduce = saved_collectives.all_reduce
dist.barrier = saved_collectives.barrier
dist.gather = saved_collectives.gather
dist.scatter = saved_collectives.scatter
dist.broadcast = saved_collectives.broadcast
def run_worker(rank, world_size):
    """Per-process entry point: run one async all-gather over NCCL and inspect it.

    Args:
        rank: Index of this process; also used as its CUDA device index.
        world_size: Total number of participating processes.

    Rank 0 logs at DEBUG level and every other rank at CRITICAL, so the INFO
    diagnostics emitted by ``IgnoreDistMode`` show up only for rank 0.
    """
    logging.getLogger().setLevel(logging.DEBUG if rank == 0 else logging.CRITICAL)
    # Bind this rank's GPU before the NCCL process group is created so its
    # communicators are set up on the intended device.
    torch.cuda.set_device(rank)
    # To experiment without GPUs, use the fake backend instead:
    #   dist.init_process_group(
    #       "fake", rank=rank, world_size=world_size, store=FakeStore()
    #   )
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    try:
        logging.info(f"Number of visible devices: {torch.cuda.device_count()}")
        # nullcontext() is a placeholder for FakeTensorMode() in fake-tensor runs.
        with nullcontext(), IgnoreDistMode():
            test_tensor = torch.randn(10000, device="cuda")
            output_tensor = torch.empty(
                test_tensor.numel() * world_size, device="cuda"
            )
            work = dist.all_gather_into_tensor(
                output_tensor, test_tensor, group=None, async_op=True
            )
        if work is not None:
            # Wait on every rank before touching the future: Future.value() is
            # only defined once the future has completed.
            completed = work.wait()
            if rank == 0:
                future = work.get_future()
                print(type(work))
                print(future.done())
                result = future.value()
                print(type(result))
                print(result[0].size())
                # Same data pointer => the future's tensor aliases output_tensor.
                print(
                    result[0].untyped_storage().data_ptr()
                    == output_tensor.untyped_storage().data_ptr()
                )
                print(completed)
        dist.barrier()
    finally:
        dist.destroy_process_group()
if __name__ == "__main__":
    # Rendezvous settings for init_process_group's default env:// method.
    # Values already present in the environment take precedence.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    world_size = 2
    # mp.spawn calls run_worker(rank, world_size) for each rank in range(world_size).
    mp.spawn(run_worker, args=(world_size,), nprocs=world_size, join=True)