-
Notifications
You must be signed in to change notification settings - Fork 1.8k
/
BatchedNms.cu
executable file
·163 lines (144 loc) · 6.62 KB
/
BatchedNms.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
#include <cuda.h>
#include <thrust/device_ptr.h>
#include <thrust/sequence.h>
#include <thrust/execution_policy.h>
#include <thrust/gather.h>
#include <cmath>
#include <algorithm>
#include <iostream>
#include <stdexcept>
#include <cstdint>
#include <vector>
#include "BatchedNmsPlugin.h"
#include "./cuda_utils.h"
#include "macros.h"
#ifdef CUDA_11
#include <cub/device/device_radix_sort.cuh>
#include <cub/iterator/counting_input_iterator.cuh>
#else
#include <thrust/system/cuda/detail/cub/device/device_radix_sort.cuh>
#include <thrust/system/cuda/detail/cub/iterator/counting_input_iterator.cuh>
namespace cub = thrust::cuda_cub::cub;
#endif
namespace nvinfer1 {

// Intersection-over-union of two boxes stored as corners (x1, y1, x2, y2) in
// the (x, y, z, w) fields. A zero union (degenerate boxes) is not guarded
// against; the resulting NaN never compares greater than a threshold.
static __device__ __forceinline__ float box_iou(const float4 &a, const float4 &b) {
    const float ix1 = max(a.x, b.x);
    const float iy1 = max(a.y, b.y);
    const float ix2 = min(a.z, b.z);
    const float iy2 = min(a.w, b.w);
    const float inter = max(0.0f, ix2 - ix1) * max(0.0f, iy2 - iy1);
    const float area_a = (a.z - a.x) * (a.w - a.y);
    const float area_b = (b.z - b.x) * (b.w - b.y);
    return inter / (area_a + area_b - inter);
}

// Suppresses overlapping same-class detections whose scores are already sorted
// in descending order. For every reference position m (in sorted order) that is
// still alive (scores[m] > 0), each lower-ranked position i > m of the same
// class whose IoU with m exceeds `threshold` has scores[i] updated in place:
//   nms_method 1              : scores[i] *= (1 - iou)              (linear soft-NMS)
//   nms_method 2              : scores[i] *= exp(-iou^2 / 0.5)      (gaussian soft-NMS)
//   nms_method 0 / any other  : scores[i]  = 0                      (hard NMS)
//
// indices[k] maps sorted position k to a row of `classes` / `boxes`.
//
// Launch layout: iteration m must observe every write made in iterations < m,
// and __syncthreads() only orders threads of a single block. All work is
// therefore done by block 0, whose threads stride over i with step blockDim.x;
// any additional blocks exit immediately. Intended launch: <<<1, <=1024>>>.
__global__ void batched_nms_kernel(
    const int nms_method, const float threshold, const int num_detections,
    const int *indices, float *scores, const float *classes, const float4 *boxes) {
    // Whole-block early exit: keeps the block-scoped barrier below non-divergent.
    if (blockIdx.x != 0) {
        return;
    }
    const float sigma = 0.5f;  // empirical gaussian soft-NMS parameter
    const int tid = static_cast<int>(threadIdx.x);
    const int stride = static_cast<int>(blockDim.x);

    // Go through detections by descending score.
    for (int m = 0; m < num_detections; m++) {
        // scores[m] is only written while some j < m is the reference, and the
        // barrier at the end of that iteration published it, so every thread
        // reads the same value here and takes the same branch.
        if (scores[m] > 0.0f) {
            const int max_idx = indices[m];
            const int mcls = classes[max_idx];  // float class id truncated to int
            const float4 mbox = boxes[max_idx];
            for (int i = m + 1 + tid; i < num_detections; i += stride) {
                const int idx = indices[i];
                if (static_cast<int>(classes[idx]) != mcls) {
                    continue;
                }
                const float overlap = box_iou(boxes[idx], mbox);
                if (overlap > threshold) {
                    switch (nms_method) {
                        case 1:
                            scores[i] *= (1.0f - overlap);
                            break;
                        case 2:
                            scores[i] *= expf(-(overlap * overlap) / sigma);
                            break;
                        default:
                            scores[i] = 0.0f;
                            break;
                    }
                }
            }
        }
        // Publish this iteration's suppressions before the next reference.
        __syncthreads();
    }
}

// Runs NMS independently for each of `batch_size` images with `count`
// candidate detections per image.
//
// inputs[0]: scores  (float,  batch_size * count)
// inputs[1]: boxes   (float4, batch_size * count)
// inputs[2]: classes (float,  batch_size * count)
// outputs[0..2]: the same three fields, detections_per_im rows per image,
//   ordered by descending post-NMS score. When count < detections_per_im the
//   trailing rows are zero-filled.
//
// Workspace protocol (cub style): if `workspace` is null or `workspace_size`
// is 0, nothing is enqueued and the required workspace bytes are returned
// (added to the incoming workspace_size). Otherwise all work is enqueued on
// `stream`; returns 0 on success, 1 if a CUDA/CUB call reports an error.
int batchedNms(int nms_method, int batch_size,
    const void *const *inputs, void *TRT_CONST_ENQUEUE*outputs,
    size_t count, int detections_per_im, float nms_thresh,
    void *workspace, size_t workspace_size, cudaStream_t stream) {
    if (!workspace || !workspace_size) {
        // Return required scratch space size cub style. Must stay in sync
        // with the get_next_ptr carve-up below.
        workspace_size += get_size_aligned<int>(count);    // indices
        workspace_size += get_size_aligned<int>(count);    // indices_sorted
        workspace_size += get_size_aligned<float>(count);  // scores_sorted
        workspace_size += get_size_aligned<float>(count);  // scores_resorted
        size_t temp_size_sort = 0;
        cub::DeviceRadixSort::SortPairsDescending(
            static_cast<void*>(nullptr), temp_size_sort,
            static_cast<float*>(nullptr),
            static_cast<float*>(nullptr),
            static_cast<int*>(nullptr),
            static_cast<int*>(nullptr), count);
        workspace_size += temp_size_sort;
        return workspace_size;
    }

    auto on_stream = thrust::cuda::par.on(stream);
    // Fixed buffers are carved off the front of the workspace; get_next_ptr
    // presumably advances `workspace` / shrinks `workspace_size`, leaving the
    // remainder as CUB temp storage (TODO confirm against cuda_utils.h).
    auto indices = get_next_ptr<int>(count, workspace, workspace_size);
    auto indices_sorted = get_next_ptr<int>(count, workspace, workspace_size);
    auto scores_sorted = get_next_ptr<float>(count, workspace, workspace_size);
    auto scores_resorted = get_next_ptr<float>(count, workspace, workspace_size);

    const int num_detections = static_cast<int>(count);
    const int num_kept = std::min(detections_per_im, num_detections);
    const int nms_threads = 1024;  // max threads per block on all CUDA GPUs

    for (int batch = 0; batch < batch_size; batch++) {
        auto in_scores = static_cast<const float *>(inputs[0]) + batch * count;
        auto in_boxes = static_cast<const float4 *>(inputs[1]) + batch * count;
        auto in_classes = static_cast<const float *>(inputs[2]) + batch * count;
        auto out_scores = static_cast<float *>(outputs[0]) + batch * detections_per_im;
        auto out_boxes = static_cast<float4 *>(outputs[1]) + batch * detections_per_im;
        auto out_classes = static_cast<float *>(outputs[2]) + batch * detections_per_im;

        // Identity permutation, rebuilt for every image because the re-sort
        // below overwrites `indices` with the final ordering.
        thrust::sequence(on_stream, indices, indices + num_detections);

        // Sort scores and corresponding indices.
        if (cub::DeviceRadixSort::SortPairsDescending(workspace, workspace_size,
                in_scores, scores_sorted, indices, indices_sorted,
                num_detections, 0, sizeof(*scores_sorted) * 8, stream) != cudaSuccess) {
            return 1;
        }

        // One block only: the kernel relies on block-wide barriers between
        // reference detections, which cannot span multiple blocks.
        batched_nms_kernel<<<1, nms_threads, 0, stream>>>(nms_method, nms_thresh,
            num_detections, indices_sorted, scores_sorted, in_classes, in_boxes);
        if (cudaGetLastError() != cudaSuccess) {
            return 1;
        }

        // Re-sort with updated scores. CUB radix sort does not support
        // overlapping input/output ranges, hence the separate key buffer.
        if (cub::DeviceRadixSort::SortPairsDescending(workspace, workspace_size,
                scores_sorted, scores_resorted, indices_sorted, indices,
                num_detections, 0, sizeof(*scores_sorted) * 8, stream) != cudaSuccess) {
            return 1;
        }

        // Gather the top num_kept scores, boxes, classes.
        if (cudaMemcpyAsync(out_scores, scores_resorted, num_kept * sizeof(*scores_resorted),
                cudaMemcpyDeviceToDevice, stream) != cudaSuccess) {
            return 1;
        }
        thrust::gather(on_stream, indices, indices + num_kept, in_boxes, out_boxes);
        thrust::gather(on_stream, indices, indices + num_kept, in_classes, out_classes);

        // Zero-fill rows that have no candidate so no output is left uninitialized.
        if (num_kept < detections_per_im) {
            const int pad = detections_per_im - num_kept;
            thrust::fill_n(on_stream, out_scores + num_kept, pad, 0.0f);
            thrust::fill_n(on_stream, out_boxes + num_kept, pad,
                           make_float4(0.0f, 0.0f, 0.0f, 0.0f));
            thrust::fill_n(on_stream, out_classes + num_kept, pad, 0.0f);
        }
    }
    return 0;
}
} // namespace nvinfer1