forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
DeviceAccelerator.cpp
82 lines (72 loc) · 2.93 KB
/
DeviceAccelerator.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#include <c10/core/DeviceGuard.h>
#include <torch/csrc/DeviceAccelerator.h>
#include <torch/csrc/utils/device_lazy_init.h>
namespace torch::accelerator {
void initModule(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
m.def("_accelerator_getAccelerator", []() {
// If no accelerator is currently available, raise an exception.
return c10::Device(at::getAccelerator(true).value());
});
m.def("_accelerator_deviceCount", []() {
const auto device_type = at::getAccelerator(false);
if (!device_type.has_value()) {
return static_cast<c10::DeviceIndex>(0);
}
torch::utils::maybe_initialize_device(device_type.value());
c10::impl::VirtualGuardImpl impl(device_type.value());
return static_cast<c10::DeviceIndex>(impl.deviceCount());
});
m.def("_accelerator_setDeviceIndex", [](c10::DeviceIndex device_index) {
const auto device_type = at::getAccelerator(true).value();
// If device index is negative, no-op
if (device_index < 0) {
return;
}
torch::utils::maybe_initialize_device(device_type);
c10::impl::VirtualGuardImpl impl(device_type);
impl.setDevice({device_type, device_index});
});
m.def("_accelerator_getDeviceIndex", []() {
const auto device_type = at::getAccelerator(true).value();
torch::utils::maybe_initialize_device(device_type);
c10::impl::VirtualGuardImpl impl(device_type);
return static_cast<c10::DeviceIndex>(impl.getDevice().index());
});
m.def("_accelerator_setStream", [](c10::Stream stream) {
const auto device_type = at::getAccelerator(true).value();
TORCH_CHECK(
device_type == stream.device_type(),
"stream's device type ",
c10::DeviceTypeName(stream.device_type()),
" doesn't match the current accelerator ",
c10::DeviceTypeName(device_type));
torch::utils::maybe_initialize_device(device_type);
c10::impl::VirtualGuardImpl impl(device_type);
// Set the current device to the device of stream
if (impl.getDevice().index() != stream.device_index()) {
impl.setDevice(stream.device());
}
impl.exchangeStream(stream);
});
m.def("_accelerator_getStream", [](c10::DeviceIndex device_index) {
const auto device_type = at::getAccelerator(true).value();
torch::utils::maybe_initialize_device(device_type);
c10::impl::VirtualGuardImpl impl(device_type);
return impl.getStream({device_type, device_index});
});
m.def("_accelerator_synchronizeDevice", [](c10::DeviceIndex device_index) {
const auto device_type = at::getAccelerator(true).value();
if (!torch::utils::is_device_initialized(device_type)) {
return;
}
torch::utils::maybe_initialize_device(device_type);
c10::impl::VirtualGuardImpl impl(device_type);
// impl.synchronizeDevice should can be safely called from any device
{
py::gil_scoped_release no_gil;
impl.synchronizeDevice(device_index);
}
});
}
} // namespace torch::accelerator