forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
AdaptiveAveragePooling3d.cpp
347 lines (310 loc) · 11.3 KB
/
AdaptiveAveragePooling3d.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/OpMathType.h>
#include <ATen/Parallel.h>
#include <c10/util/irange.h>
#include <ATen/native/AdaptivePooling.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_adaptive_avg_pool3d.h>
#include <ATen/ops/_adaptive_avg_pool3d_backward_native.h>
#include <ATen/ops/_adaptive_avg_pool3d_native.h>
#include <ATen/ops/adaptive_avg_pool3d_backward_native.h>
#include <ATen/ops/adaptive_avg_pool3d_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/zeros_like.h>
#endif
namespace at::native {
namespace {
template <typename scalar_t>
static void adaptive_avg_pool3d_out_frame(
const scalar_t* input_p,
scalar_t* output_p,
int64_t sizeD,
int64_t isizeT,
int64_t isizeH,
int64_t isizeW,
int64_t osizeT,
int64_t osizeH,
int64_t osizeW,
int64_t istrideD,
int64_t istrideT,
int64_t istrideH,
int64_t istrideW) {
at::parallel_for(0, sizeD, 1, [&](int64_t start, int64_t end) {
for (const auto d : c10::irange(start, end)) {
/* loop over output */
for (const auto ot : c10::irange(osizeT)) {
auto istartT = start_index(ot, osizeT, isizeT);
auto iendT = end_index(ot, osizeT, isizeT);
auto kT = iendT - istartT;
for (const auto oh : c10::irange(osizeH)) {
auto istartH = start_index(oh, osizeH, isizeH);
auto iendH = end_index(oh, osizeH, isizeH);
auto kH = iendH - istartH;
for (const auto ow : c10::irange(osizeW)) {
auto istartW = start_index(ow, osizeW, isizeW);
auto iendW = end_index(ow, osizeW, isizeW);
auto kW = iendW - istartW;
/* local pointers */
const scalar_t* ip = input_p + d * istrideD + istartT * istrideT +
istartH * istrideH + istartW * istrideW;
scalar_t* op = output_p + d * osizeT * osizeH * osizeW +
ot * osizeH * osizeW + oh * osizeW + ow;
/* compute local average: */
scalar_t sum = 0;
for (const auto it : c10::irange(kT)) {
for (const auto ih : c10::irange(kH)) {
for (const auto iw : c10::irange(kW)) {
scalar_t val =
*(ip + it * istrideT + ih * istrideH + iw * istrideW);
sum += val;
}
}
}
/* set output to local average */
*op = sum / kT / kH / kW;
}
}
}
}
});
}
// Forward driver for adaptive_avg_pool3d on CPU.
//
// input:       4D (D, T, H, W) or 5D (N, D, T, H, W) tensor whose non-batch
//              dimensions must be non-empty; it is read through its strides.
// output_size: {oT, oH, oW}; must have exactly three elements.
// output:      must already have input's dtype; it is resized to
//              (D, oT, oH, oW) or (N, D, oT, oH, oW) and filled with the
//              adaptive averages.
void adaptive_avg_pool3d_out_cpu_template(
    Tensor& output,
    Tensor const& input,
    IntArrayRef output_size) {
  TORCH_CHECK(output_size.size() == 3, "adaptive_avg_pool3d: output_size must be 3");
  for (const auto i : c10::irange(1, input.ndimension())) {
    TORCH_CHECK(
        input.size(i) > 0,
        "adaptive_avg_pool3d(): Expected input to have non-zero size for non-batch dimensions, "
        "but input has sizes ",
        input.sizes(),
        " with dimension ",
        i,
        " being "
        "empty");
  }
  TORCH_CHECK(
      (input.ndimension() == 4 || input.ndimension() == 5),
      "adaptive_avg_pool3d(): Expected 4D or 5D tensor, but got ",
      input.sizes());
  TORCH_CHECK(input.dtype() == output.dtype(),
      "expected dtype ", input.dtype(), " for `output` but got dtype ", output.dtype());

  /* sizes */
  const int64_t sizeD = input.size(-4);
  const int64_t isizeT = input.size(-3);
  const int64_t isizeH = input.size(-2);
  const int64_t isizeW = input.size(-1);
  /* strides */
  const int64_t istrideD = input.stride(-4);
  const int64_t istrideT = input.stride(-3);
  const int64_t istrideH = input.stride(-2);
  const int64_t istrideW = input.stride(-1);
  /* output sizes */
  const auto osizeT = output_size[0];
  const auto osizeH = output_size[1];
  const auto osizeW = output_size[2];

  const bool batched = input.ndimension() == 5;
  if (batched) {
    output.resize_({input.size(-5), sizeD, osizeT, osizeH, osizeW});
  } else {
    output.resize_({sizeD, osizeT, osizeH, osizeW});
  }

  // The kernel writes output as a dense row-major array. resize_ keeps the
  // existing strides when the size already matches, so a caller-supplied
  // (out=) tensor can still be non-contiguous here. In that case compute into
  // a contiguous buffer and copy the result back at the end.
  Tensor output_c = output.is_contiguous()
      ? output
      : at::empty(output.sizes(), output.options());

  if (!batched) {
    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
        input.scalar_type(), "adaptive_avg_pool3d_cpu", [&] {
          auto input_data = input.const_data_ptr<scalar_t>();
          auto output_data = output_c.data_ptr<scalar_t>();
          adaptive_avg_pool3d_out_frame<scalar_t>(
              input_data,
              output_data,
              sizeD,
              isizeT,
              isizeH,
              isizeW,
              osizeT,
              osizeH,
              osizeW,
              istrideD,
              istrideT,
              istrideH,
              istrideW);
        });
  } else {
    const int64_t n = input.size(0);
    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
        input.scalar_type(), "adaptive_avg_pool3d_cpu", [&] {
          auto input_data = input.const_data_ptr<scalar_t>();
          auto output_data = output_c.data_ptr<scalar_t>();
          at::parallel_for(0, n, 1, [&](int64_t start, int64_t end) {
            for (const auto b : c10::irange(start, end)) {
              // Input samples are located via the batch stride; output
              // samples are packed densely.
              adaptive_avg_pool3d_out_frame<scalar_t>(
                  input_data + b * input.stride(0),
                  output_data + b * sizeD * osizeT * osizeH * osizeW,
                  sizeD,
                  isizeT,
                  isizeH,
                  isizeW,
                  osizeT,
                  osizeH,
                  osizeW,
                  istrideD,
                  istrideT,
                  istrideH,
                  istrideW);
            }
          });
        });
  }

  if (!output_c.is_same(output)) {
    output.copy_(output_c);
  }
}
template <typename scalar_t>
static void adaptive_avg_pool3d_backward_out_frame(
scalar_t* gradInput_p,
const scalar_t* gradOutput_p,
int64_t sizeD,
int64_t isizeT,
int64_t isizeH,
int64_t isizeW,
int64_t osizeT,
int64_t osizeH,
int64_t osizeW) {
at::parallel_for(0, sizeD, 1, [&](int64_t start, int64_t end) {
for (const auto d : c10::irange(start, end)) {
scalar_t* gradInput_p_d = gradInput_p + d * isizeT * isizeW * isizeH;
const scalar_t* gradOutput_p_d = gradOutput_p + d * osizeT * osizeW * osizeH;
/* calculate average */
for (const auto ot : c10::irange(osizeT)) {
auto istartT = start_index(ot, osizeT, isizeT);
auto iendT = end_index(ot, osizeT, isizeT);
auto kT = iendT - istartT;
for (const auto oh : c10::irange(osizeH)) {
auto istartH = start_index(oh, osizeH, isizeH);
auto iendH = end_index(oh, osizeH, isizeH);
auto kH = iendH - istartH;
for (const auto ow : c10::irange(osizeW)) {
auto istartW = start_index(ow, osizeW, isizeW);
auto iendW = end_index(ow, osizeW, isizeW);
auto kW = iendW - istartW;
scalar_t grad_delta =
gradOutput_p_d[ot * osizeH * osizeW + oh * osizeW + ow] / kT /
kH / kW;
for (const auto it : c10::irange(istartT, iendT)) {
for (const auto ih : c10::irange(istartH, iendH)) {
for (const auto iw : c10::irange(istartW, iendW)) {
/* update gradient */
gradInput_p_d[it * isizeH * isizeW + ih * isizeW + iw] +=
grad_delta;
}
}
}
}
}
}
}
});
}
// Backward driver for adaptive_avg_pool3d on CPU.
//
// gradInput:   must already be shaped like `input` and hold the values to
//              accumulate onto (the callers in this file zero-fill it). Each
//              grad_output element is spread evenly over its input window with
//              +=. The same tensor is returned.
// gradOutput_: gradient w.r.t. the pooled output. It must have input's rank
//              and the same leading (batch/channel) sizes, plus non-empty
//              non-batch dimensions.
// input:       the forward input (4D or 5D); only its sizes are used.
Tensor& adaptive_avg_pool3d_backward_out_cpu_template(
    Tensor& gradInput,
    const Tensor& gradOutput_,
    const Tensor& input) {
  /* get contiguous gradOutput */
  auto gradOutput = gradOutput_.contiguous();
  adaptive_pool_empty_output_check(gradOutput_, "adaptive_avg_pool3d_backward");

  // The kernels walk gradOutput with raw pointers using input's leading
  // sizes. A rank or leading-size mismatch would read out of bounds, and a
  // rank other than 4/5 would make the batched branch silently ignore dims.
  const int64_t ndim = input.dim();
  TORCH_CHECK(
      ndim == 4 || ndim == 5,
      "adaptive_avg_pool3d_backward(): Expected 4D or 5D input, but got ",
      input.sizes());
  TORCH_CHECK(
      gradOutput.dim() == ndim,
      "adaptive_avg_pool3d_backward(): Expected grad_output to have ",
      ndim,
      " dimensions like input, but got grad_output of sizes ",
      gradOutput.sizes());
  for (const auto i : c10::irange(ndim - 3)) {
    TORCH_CHECK(
        gradOutput.size(i) == input.size(i),
        "adaptive_avg_pool3d_backward(): grad_output sizes ",
        gradOutput.sizes(),
        " do not match input sizes ",
        input.sizes(),
        " at dimension ",
        i);
  }

  /* sizes */
  const int64_t sizeD = input.size(-4);
  const int64_t isizeT = input.size(-3);
  const int64_t isizeH = input.size(-2);
  const int64_t isizeW = input.size(-1);
  const int64_t osizeT = gradOutput.size(-3);
  const int64_t osizeH = gradOutput.size(-2);
  const int64_t osizeW = gradOutput.size(-1);

  // The kernels index gradInput as a dense row-major array, but a
  // caller-supplied (out=) gradInput may be non-contiguous. contiguous()
  // copies its current contents, so += accumulation keeps its meaning. The
  // result is copied back below when a separate buffer was needed.
  Tensor gradInput_c = gradInput.contiguous();

  /* backprop */
  if (ndim == 4) {
    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
        input.scalar_type(), "adaptive_avg_pool3d_backward_cpu", [&] {
          /* get raw pointers */
          scalar_t* gradInput_data = gradInput_c.data_ptr<scalar_t>();
          const scalar_t* gradOutput_data = gradOutput.const_data_ptr<scalar_t>();
          adaptive_avg_pool3d_backward_out_frame<scalar_t>(
              gradInput_data,
              gradOutput_data,
              sizeD,
              isizeT,
              isizeH,
              isizeW,
              osizeT,
              osizeH,
              osizeW);
        });
  } else {
    const int64_t n = input.size(0);
    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
        input.scalar_type(), "adaptive_avg_pool3d_backward_cpu", [&] {
          /* get raw pointers */
          scalar_t* gradInput_data = gradInput_c.data_ptr<scalar_t>();
          const scalar_t* gradOutput_data = gradOutput.const_data_ptr<scalar_t>();
          at::parallel_for(0, n, 1, [&](int64_t start, int64_t end) {
            for (const auto b : c10::irange(start, end)) {
              adaptive_avg_pool3d_backward_out_frame<scalar_t>(
                  gradInput_data + b * sizeD * isizeT * isizeH * isizeW,
                  gradOutput_data + b * sizeD * osizeT * osizeH * osizeW,
                  sizeD,
                  isizeT,
                  isizeH,
                  isizeW,
                  osizeT,
                  osizeH,
                  osizeW);
            }
          });
        });
  }

  if (!gradInput_c.is_same(gradInput)) {
    gradInput.copy_(gradInput_c);
  }
  return gradInput;
}
} // namespace
// out= variant of adaptive_avg_pool3d on CPU. All validation, resizing and
// computation happen in adaptive_avg_pool3d_out_cpu_template; `output` is
// returned by reference.
Tensor& adaptive_avg_pool3d_out_cpu(
    const Tensor& input,
    IntArrayRef output_size,
    Tensor& output) {
  adaptive_avg_pool3d_out_cpu_template(
      /*output=*/output, /*input=*/input, /*output_size=*/output_size);
  return output;
}
// Functional adaptive_avg_pool3d on CPU: allocates an empty tensor with the
// input's options and lets the template resize and fill it.
Tensor adaptive_avg_pool3d_cpu(Tensor const& input, IntArrayRef output_size) {
  Tensor result = at::empty({0}, input.options());
  adaptive_avg_pool3d_out_cpu_template(
      /*output=*/result, /*input=*/input, /*output_size=*/output_size);
  return result;
}
// Public adaptive_avg_pool3d entry point (symbolic output size).
// Validates output_size (three non-negative elements). A 1x1x1 output is
// computed as a mean over the last three dims; every other size is forwarded
// to _adaptive_avg_pool3d.
Tensor adaptive_avg_pool3d_symint(Tensor const& input, SymIntArrayRef output_size) {
  TORCH_CHECK(output_size.size() == 3, "adaptive_avg_pool3d: output_size must be 3");
  TORCH_CHECK(
      (output_size[0] >= 0 && output_size[1] >= 0 && output_size[2] >= 0),
      "adaptive_avg_pool3d: elements of output_size must be greater than or equal to 0 ",
      "but received {", output_size[0], ", ", output_size[1], ", ", output_size[2], "}");
  if (output_size[0] == 1 && output_size[1] == 1 && output_size[2] == 1) {
    // In this case adaptive pooling is just the mean over the last three
    // (T, H, W) dimensions, which mean() computes more efficiently.
    Tensor out = input.mean({-1, -2, -3}, /* keepdim = */ true);
    if (input.suggest_memory_format() == at::MemoryFormat::ChannelsLast3d) {
      // ndim must be 5 here, since a 4D input never suggests ChannelsLast3d.
      // Restride the (N, C, 1, 1, 1) result so it matches the input's layout.
      const auto n = input.sym_size(0);
      const auto c = input.sym_size(1);
      out.as_strided__symint({n, c, 1, 1, 1}, {c, 1, c, c, c});
    }
    return out;
  }
  return _adaptive_avg_pool3d_symint(input, output_size);
}
// out= variant of the adaptive_avg_pool3d backward on CPU. The template
// accumulates with +=, so gradInput is first shaped like `input` and cleared.
// It is returned by reference.
Tensor& adaptive_avg_pool3d_backward_out_cpu(
    const Tensor& gradOutput_,
    const Tensor& input,
    Tensor& gradInput) {
  gradInput.resize_as_(input);
  gradInput.zero_();
  return adaptive_avg_pool3d_backward_out_cpu_template(gradInput, gradOutput_, input);
}
// Functional adaptive_avg_pool3d backward on CPU. It starts from a zero-filled
// gradient in the legacy contiguous layout, which the template accumulates into.
Tensor adaptive_avg_pool3d_backward_cpu(
    const Tensor& gradOutput_,
    const Tensor& input) {
  Tensor grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  return adaptive_avg_pool3d_backward_out_cpu_template(grad_input, gradOutput_, input);
}
} // namespace at::native