forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
CPUFallback.cpp
343 lines (314 loc) · 14.6 KB
/
CPUFallback.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/CPUFallback.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/stack.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <sstream>
#include <vector>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/_copy_from_and_resize.h>
#include <ATen/ops/_to_cpu.h>
#endif
namespace at::native {
// convenience helper for converting tensors to cpu
// Convenience helper for converting tensors to CPU.
//
// Accepts a vector of either `at::Tensor` or `std::optional<at::Tensor>`.
// Undefined tensors (and, for the optional variant, empty optionals) are copied
// through unchanged. Only the defined tensors are gathered and handed to a
// single `at::_to_cpu` call, so the backend sees all of them at once.
//
// @param tensors  input tensors, possibly living on a non-CPU device.
// @return         a vector of the same length; element i is the CPU copy of
//                 tensors[i], or tensors[i] itself when it is not defined.
template<typename T, std::enable_if_t<std::is_same_v<T, at::Tensor> || std::is_same_v<T, std::optional<at::Tensor>>, int> = 1>
static std::vector<T> to_cpu(const std::vector<T>& tensors) {
  // We can't just call at::_to_cpu() on the entire list of Tensors
  // because it will break on undefined tensors. Separate out undefined tensors first.
  const auto num = tensors.size();
  std::vector<T> cpu_tensors(num);
  std::vector<at::Tensor> valid_tensors;
  valid_tensors.reserve(num);
  // to_translate[i] is true iff tensors[i] was gathered into valid_tensors.
  std::vector<bool> to_translate(num, false);
  for (const auto i : c10::irange(num)) {
    // Explicitly handling undefined tensors here instead of letting `at::_to_cpu` handle it.
    // Otherwise, we'd need to require all backends with their own implementation of _to_cpu
    // to properly handle undefined tensors.
    if constexpr (std::is_same_v<T, std::optional<at::Tensor>>) {
      if (tensors[i].has_value() && tensors[i].value().defined()) {
        to_translate[i] = true;
        valid_tensors.push_back(tensors[i].value());
      } else {
        cpu_tensors[i] = tensors[i];
      }
    } else {
      if (tensors[i].defined()) {
        to_translate[i] = true;
        valid_tensors.push_back(tensors[i]);
      } else {
        cpu_tensors[i] = tensors[i];
      }
    }
  }
  auto cpu_valid_tensors = at::_to_cpu(valid_tensors);
  // Scatter the converted tensors back into their original slots.
  // NOTE(review): this assumes _to_cpu returns one tensor per input, in input order.
  std::size_t defined_pos = 0;
  for (const auto i : c10::irange(num)) {
    if (to_translate[i]) {
      cpu_tensors[i] = std::move(cpu_valid_tensors[defined_pos++]);
    }
  }
  return cpu_tensors;
}
// Decide which device the output tensor(s) should be moved back to, for use
// when the operator did not receive an explicit Device argument.
//
// Convention: use the device of the first *defined* Tensor argument. Barring
// that, use the device of the first defined tensor in any TensorList argument.
// Undefined tensors are skipped because they have no device to report.
// NOTE(review): asking an undefined tensor for its device is an error in ATen,
// which is why the original first-element lookup could fail. Confirm.
//
// @param t_args      the operator's Tensor arguments, in argument order.
// @param tlist_args  the operator's TensorList arguments, in argument order.
// @return            the chosen device, or std::nullopt when no argument
//                    supplies one (e.g. only empty TensorLists were passed).
static std::optional<c10::Device> compute_target_device(const std::vector<at::Tensor>& t_args, const std::vector<c10::List<at::Tensor>>& tlist_args) {
  for (const auto& tensor : t_args) {
    if (tensor.defined()) {
      return tensor.device();
    }
  }
  // We need to loop through all of the (potentially multiple) TensorList arguments,
  // in case e.g. the first one is empty but the second is not.
  for (const auto& tens_list : tlist_args) {
    for (const auto i : c10::irange(tens_list.size())) {
      const at::Tensor tensor = tens_list.get(i);
      if (tensor.defined()) {
        return tensor.device();
      }
    }
  }
  return std::nullopt;
}
// Returns true if at least one tensor in `tensorlist` is defined, false if the
// list is empty or every element is an undefined tensor.
static bool validate_tensor_list(const c10::List<at::Tensor>& tensorlist) {
  for (const auto i : c10::irange(tensorlist.size())) {
    // Stop at the first defined tensor; the rest of the list can't change the answer.
    if (tensorlist[i].defined()) {
      return true;
    }
  }
  return false;
}
// Boxed CPU fallback: runs `op` on CPU on behalf of a backend that has no
// kernel for it.
//
// The boxed arguments on `stack` are rewritten in place so the CPU kernel can
// consume them. The operator is then redispatched to `cpu_dispatch_key`, and
// the returns on `stack` are fixed up so the caller sees results on the
// original device:
//   Step 1: every Tensor / TensorList / optional-TensorList argument is copied
//           to CPU. Any Device argument is replaced by the CPU device and
//           remembered as the target device for outputs.
//   Step 2: the operator is redispatched with only `cpu_dispatch_key`.
//   Step 3: for arguments whose schema alias info marks them as written to,
//           the CPU result is copied back into the original input.
//   Step 4: mutable-alias returns are replaced by the original input. Other
//           Tensor / TensorList returns are copied to the target device. View
//           (immutable-alias) returns warn, or error when `error_on_views`.
//
// @param op                the operator being executed.
// @param stack             boxed stack: holds the arguments on entry and the
//                          returns on exit.
// @param error_on_views    if true, raise instead of warn for view operators.
// @param cpu_dispatch_key  dispatch key to redispatch to; it must belong to the
//                          CPU backend component (checked below).
// Raises (via TORCH_CHECK) if `cpu_dispatch_key` is not a CPU key, if a
// mutable-alias return has no matching input, or if a view operator is hit
// while `error_on_views` is set.
void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool error_on_views,
                  c10::DispatchKey cpu_dispatch_key) {
  TORCH_CHECK(c10::BackendComponent::CPUBit == c10::toBackendComponent(cpu_dispatch_key),
              "Expected CPU backend DispatchKey but got ",
              c10::toString(cpu_dispatch_key));
  auto& schema_args = op.schema().arguments();
  const auto num_arguments = schema_args.size();
  auto arguments = torch::jit::last(stack, num_arguments);
  const auto arguments_begin = stack->size() - num_arguments;

  // Original (device-side) arguments, each paired with its position in the schema.
  std::vector<at::Tensor> tensor_args;
  std::vector<std::size_t> tensor_args_indices;

  std::vector<c10::List<at::Tensor>> tensorlist_args;
  std::vector<std::size_t> tensorlist_args_indices;

  std::vector<c10::List<std::optional<at::Tensor>>> optional_tensorlist_args;
  std::vector<std::size_t> optional_tensorlist_args_indices;

  std::optional<c10::Device> tgt_device = std::nullopt;
  // Save the converted CPU lists for TensorList and optional TensorList args.
  // These vectors are parallel to tensorlist_args / optional_tensorlist_args.
  std::vector<c10::IValue> tensorlist_cpu_args;
  std::vector<c10::IValue> optional_tensorlist_cpu_args;

  // Step 1: Convert all non-CPU tensor inputs into CPU tensors
  // and put them on the stack at the correct indices.
  for (const auto idx : c10::irange(arguments.size())) {
    const auto& ivalue = arguments[idx];
    if (ivalue.isTensor()) {
      // Plain tensors are converted all together after this loop.
      tensor_args.push_back(ivalue.toTensor());
      tensor_args_indices.push_back(idx);
    } else if (ivalue.isTensorList()) {
      // Note: we copy each TensorList argument to CPU individually out of convenience,
      // but XLA would benefit from materializing all tensor and TensorList args onto the CPU at the same time.
      // We can improve this if we need better perf for XLA's CPU fallbacks.
      tensorlist_args.push_back(ivalue.toTensorList());
      tensorlist_args_indices.push_back(idx);
      auto cpu_ivalue = c10::IValue(c10::List<at::Tensor>(to_cpu(ivalue.toTensorVector())));
      tensorlist_cpu_args.push_back(cpu_ivalue);
      (*stack)[arguments_begin + idx] = std::move(cpu_ivalue);
    } else if (ivalue.isOptionalTensorList()) {
      optional_tensorlist_args.push_back(ivalue.toOptionalTensorList());
      optional_tensorlist_args_indices.push_back(idx);
      auto cpu_ivalue = c10::IValue(c10::List<std::optional<at::Tensor>>(to_cpu(ivalue.toOptionalTensorVector())));
      optional_tensorlist_cpu_args.push_back(cpu_ivalue);
      (*stack)[arguments_begin + idx] = std::move(cpu_ivalue);
    } else if (ivalue.isDevice()) {
      // An explicit device argument decides where outputs go; the CPU kernel
      // itself must be told to run on CPU.
      tgt_device = ivalue.toDevice();
      (*stack)[arguments_begin + idx] = c10::IValue(c10::Device(kCPU));
    }
  }
  // XLA requires all of the tensor arguments to be gathered up and converted to CPU together.
  auto cpu_tensors = to_cpu(tensor_args);
  for (const auto i : c10::irange(tensor_args_indices.size())) {
    const auto idx = tensor_args_indices[i];
    (*stack)[arguments_begin + idx] = c10::IValue(cpu_tensors[i]);
  }

  // Step 2: Call the underlying CPU implementation of the operator
  op.redispatchBoxed(c10::DispatchKeySet(cpu_dispatch_key), stack);

  // Step 3: We need to take special care to handle mutable aliases properly:
  // If any input tensors are mutable aliases, we need to
  // directly copy the updated data on the CPU tensors back to the original inputs.
  for (const auto i : c10::irange(tensor_args_indices.size())) {
    const AliasInfo* alias_info = schema_args[tensor_args_indices[i]].alias_info();
    if (alias_info != nullptr && alias_info->isWrite() && tensor_args[i].defined()) {
      at::_copy_from_and_resize(cpu_tensors[i], tensor_args[i]);
    }
  }
  // We also need to explicitly reapply input mutations to inputs that are lists
  // of tensors
  for (const auto i : c10::irange(tensorlist_args_indices.size())) {
    const AliasInfo* alias_info = schema_args[tensorlist_args_indices[i]].alias_info();
    if (alias_info != nullptr && alias_info->isWrite()) {
      const auto cpu_list = tensorlist_cpu_args[i].toTensorVector();
      for (const auto idx : c10::irange(tensorlist_args[i].size())) {
        if (!cpu_list[idx].defined()) continue;
        at::_copy_from_and_resize(cpu_list[idx], tensorlist_args[i][idx]);
      }
    }
  }
  // We also need to explicitly reapply input mutations to inputs that are lists
  // of optional tensors
  for (const auto i : c10::irange(optional_tensorlist_args_indices.size())) {
    const AliasInfo* alias_info = schema_args[optional_tensorlist_args_indices[i]].alias_info();
    if (alias_info != nullptr && alias_info->isWrite()) {
      const auto cpu_list = optional_tensorlist_cpu_args[i].toOptionalTensorList();
      for (const auto idx : c10::irange(optional_tensorlist_args[i].size())) {
        // to_cpu() only produces a defined CPU tensor when the original element
        // held a defined tensor, so optional_tensor.value() below is safe.
        if (cpu_list[idx].has_value() && cpu_list[idx].value().defined()) {
          const std::optional<at::Tensor>& optional_tensor = optional_tensorlist_args[i][idx];
          at::_copy_from_and_resize(cpu_list[idx].value(), optional_tensor.value());
        }
      }
    }
  }

  // Step 4: Convert any CPU output tensors back to the original input device.
  // For mutable alias'd outputs, we also need to take special care
  // to move the ORIGINAL input tensor back onto the stack, in place of
  // the temporary CPU output tensor that we created.
  //
  // Note [CPU Fallback Does Not Handle View Operators]
  // Also note that we are incapable of handling immutable aliases properly.
  // Why?
  // Schemas with an immutable alias'd tensor outputs correspond to view operators.
  // For example, the `view_as` schema from native_functions.yaml:
  // `view_as(Tensor(a) self, Tensor other) -> Tensor(a)`
  // We can't handle these ops properly, because view ops are supposed to return
  // a NEW tensor that shares the SAME storage as the original tensor.
  // However, the new tensor that we created cannot share the same storage,
  // since it lives on CPU and the original tensor lives on a different device.
  // Because of that, we warn if someone attempts to call the
  // CPU fallback on a view operator (this is to maintain BC for view ops for XLA
  // that fall back to CPU).
  const auto& schema_returns = op.schema().returns();
  const auto num_returns = schema_returns.size();
  auto returns = torch::jit::last(stack, num_returns);
  const auto returns_begin = stack->size() - num_returns;

  if (!tgt_device.has_value()) {
    tgt_device = compute_target_device(tensor_args, tensorlist_args);
  }

  for (const auto idx : c10::irange(returns.size())) {
    const AliasInfo* alias_info = schema_returns[idx].alias_info();
    if (alias_info != nullptr && alias_info->isWrite()) {
      // Case (1): mutable alias case.
      // Move the input ivalue directly onto the stack in place of
      // the existing cpu output tensor.
      bool found_alias = false;
      if (returns[idx].isTensor() && returns[idx].toTensor().defined()) {
        // We could store some extra metadata on the function schema to avoid
        // the loop here if we need to improve perf.
        for (const auto i : c10::irange(tensor_args_indices.size())) {
          const auto input_tensor_idx = tensor_args_indices[i];
          const auto& input_tensor = cpu_tensors[i];
          const AliasInfo* input_alias_info =
              schema_args[input_tensor_idx].alias_info();
          // Checked above; adding assert to guard against breakage of the below
          // condition due to changing the above if test.
          TORCH_INTERNAL_ASSERT_DEBUG_ONLY(alias_info != nullptr);
          if (input_tensor.defined() &&
              (alias_info == input_alias_info ||
               (input_alias_info != nullptr &&
                *alias_info == *input_alias_info))) {
            // We've found the original input tensor that aliases with the
            // current output. Wrap it in an IValue and put it directly on the
            // stack.
            (*stack)[returns_begin + idx] = c10::IValue(tensor_args[i]);
            found_alias = true;
            break;
          }
        }
      } else if (
          returns[idx].isTensorList() &&
          validate_tensor_list(returns[idx].toTensorList())) {
        for (const auto i : c10::irange(tensorlist_args_indices.size())) {
          const auto input_tensor_idx = tensorlist_args_indices[i];
          const AliasInfo* input_alias_info =
              schema_args[input_tensor_idx].alias_info();
          // Checked above; adding assert to guard against breakage of the below
          // condition due to changing the above if test.
          TORCH_INTERNAL_ASSERT_DEBUG_ONLY(alias_info != nullptr);
          if (validate_tensor_list(tensorlist_args[i]) &&
              (alias_info == input_alias_info ||
               (input_alias_info != nullptr &&
                *alias_info == *input_alias_info))) {
            // We've found the original input tensor list that aliases with the
            // current output. Wrap it in an IValue and put it directly on the
            // stack.
            (*stack)[returns_begin + idx] = c10::IValue(tensorlist_args[i]);
            found_alias = true;
            break;
          }
        }
      }
      TORCH_CHECK(
          found_alias,
          "The operator ",
          op.schema().operator_name(),
          " appears to have invalid alias information. ",
          "Found a return tensor argument with a mismatched mutable alias: ",
          schema_returns[idx]);
    } else {
      if (alias_info != nullptr && !alias_info->isWrite()) {
        // Case (3): immutable alias (view) case.
        // Warn here, since we're copying and not creating a view.
        // If this operator is needed, the backend should provide a kernel for
        // it. See Note [CPU Fallback Does Not Handle View Operators]
        std::stringstream dev_str;
        if (tgt_device) {
          dev_str << *tgt_device;
        } else {
          dev_str << "<none>";
        }
        if (error_on_views) {
          TORCH_CHECK(
              false,
              "The operator ",
              op.schema().operator_name(),
              " appears to be a view operator, ",
              "but it has no implementation for the backend \"",
              dev_str.str(),
              "\". View operators don't support falling back to run on the CPU, ",
              "since the tensor's storage cannot be shared across devices.");
        } else {
          // TORCH_WARN takes no condition argument: every argument becomes part
          // of the message, so none must be passed here.
          TORCH_WARN(
              "The operator ",
              op.schema().operator_name(),
              " appears to be a view operator, ",
              "but it has no implementation for the backend \"",
              dev_str.str(),
              "\". View operators don't support falling back to run on the CPU, ",
              "since the tensor's storage cannot be shared across devices.");
        }
      }
      // Case (2): copy case.
      // Copy the cpu output tensor to the original device.
      // We technically might not have a target device, e.g. if you call
      // torch.cat() with an empty list. In that case, we shouldn't have any
      // tensors to schlep across devices anyway.
      if (tgt_device) {
        if (returns[idx].isTensor() && returns[idx].toTensor().defined()) {
          (*stack)[returns_begin + idx] =
              c10::IValue(returns[idx].toTensor().to(*tgt_device));
        } else if (
            returns[idx].isTensorList() &&
            validate_tensor_list(returns[idx].toTensorList())) {
          const auto cpu_outputs = returns[idx].toTensorList().vec();
          std::vector<at::Tensor> tensors;
          tensors.reserve(cpu_outputs.size());
          for (const auto& tensor : cpu_outputs) {
            // validate_tensor_list only guarantees that at least one element is
            // defined; pass undefined elements through instead of moving them.
            tensors.push_back(tensor.defined() ? tensor.to(*tgt_device) : tensor);
          }
          (*stack)[returns_begin + idx] =
              c10::IValue(c10::List<at::Tensor>(tensors));
        }
      }
    }
  }
}
} // namespace at::native