-
Notifications
You must be signed in to change notification settings - Fork 0
/
yolov8_det.cpp
301 lines (276 loc) · 12.1 KB
/
yolov8_det.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
#include <cassert>
#include <chrono>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
#include <opencv2/opencv.hpp>
#include "cuda_utils.h"
#include "logging.h"
#include "model.h"
#include "postprocess.h"
#include "preprocess.h"
#include "utils.h"
// Logger instance handed to createInferBuilder() and createInferRuntime() below.
Logger gLogger;
// Lets the TensorRT types (IBuilder, ICudaEngine, IExecutionContext, ...) be used unqualified.
using namespace nvinfer1;
// Number of floats in the raw network output for one image: the storage of
// kMaxNumOutputBbox Detection records measured in floats, plus one extra float.
// NOTE(review): the extra leading slot presumably carries the detection count - confirm
// against the output layer / batch_nms.
const int kOutputSize = kMaxNumOutputBbox * sizeof(Detection) / sizeof(float) + 1;
// Builds a YOLOv8 detection engine from a weight file and writes the serialized
// plan to disk.
//
//   wts_name      weight file forwarded to the engine builder helper
//   engine_name   path of the plan file to create
//   is_p          6 -> buildEngineYolov8DetP6, 2 -> buildEngineYolov8DetP2,
//                 anything else -> buildEngineYolov8Det
//   sub_type      not used by this function; kept for interface compatibility
//   gd, gw, max_channels   forwarded unchanged to the builder helper
//
// Every failure (builder/config creation, engine build, opening or writing the
// plan file) is reported on stderr and aborts the process. The checks are
// explicit rather than assert() so they are still active in NDEBUG builds.
void serialize_engine(std::string &wts_name, std::string &engine_name, int &is_p, std::string &sub_type, float &gd,
                      float &gw, int &max_channels) {
    (void) sub_type;  // accepted only so existing callers keep compiling

    IBuilder *builder = createInferBuilder(gLogger);
    if (builder == nullptr) {
        std::cerr << "could not create TensorRT builder" << std::endl;
        std::abort();
    }
    IBuilderConfig *config = builder->createBuilderConfig();
    if (config == nullptr) {
        std::cerr << "could not create TensorRT builder config" << std::endl;
        std::abort();
    }

    // Pick the network variant requested by the caller.
    IHostMemory *serialized_engine = nullptr;
    if (is_p == 6) {
        serialized_engine = buildEngineYolov8DetP6(builder, config, DataType::kFLOAT, wts_name, gd, gw, max_channels);
    } else if (is_p == 2) {
        serialized_engine = buildEngineYolov8DetP2(builder, config, DataType::kFLOAT, wts_name, gd, gw, max_channels);
    } else {
        serialized_engine = buildEngineYolov8Det(builder, config, DataType::kFLOAT, wts_name, gd, gw, max_channels);
    }
    if (serialized_engine == nullptr) {
        std::cerr << "could not build engine from " << wts_name << std::endl;
        std::abort();
    }

    std::ofstream p(engine_name, std::ios::binary);
    if (!p) {
        std::cerr << "could not open plan output file" << std::endl;
        std::abort();
    }
    p.write(reinterpret_cast<const char *>(serialized_engine->data()),
            static_cast<std::streamsize>(serialized_engine->size()));
    // close() flushes; a short write (e.g. disk full) shows up in the stream state.
    p.close();
    if (!p) {
        std::cerr << "could not write plan output file" << std::endl;
        std::abort();
    }

    delete serialized_engine;
    delete config;
    delete builder;
}
// Loads a serialized plan file and creates the runtime, engine and execution
// context from it.
//
//   engine_name   path of the plan file to read
//   runtime, engine, context   out-parameters; on return each points to a newly
//                              created object that the caller must delete
//
// Any failure (unreadable or empty file, short read, runtime/engine/context
// creation) is reported on stderr and aborts the process. The checks are
// explicit rather than assert() so they are still active in NDEBUG builds.
void deserialize_engine(std::string &engine_name, IRuntime **runtime, ICudaEngine **engine,
                        IExecutionContext **context) {
    std::ifstream file(engine_name, std::ios::binary);
    if (!file.good()) {
        std::cerr << "read " << engine_name << " error!" << std::endl;
        std::abort();
    }

    // Determine the file length; tellg() yields -1 on failure, which must not
    // be turned into an allocation size.
    file.seekg(0, std::ios::end);
    const std::streamoff length = file.tellg();
    if (length <= 0) {
        std::cerr << "read " << engine_name << " error!" << std::endl;
        std::abort();
    }
    file.seekg(0, std::ios::beg);

    // The vector owns the plan bytes and frees them on every exit path.
    std::vector<char> plan(static_cast<size_t>(length));
    if (!file.read(plan.data(), static_cast<std::streamsize>(length))) {
        std::cerr << "read " << engine_name << " error!" << std::endl;
        std::abort();
    }
    file.close();

    *runtime = createInferRuntime(gLogger);
    if (*runtime == nullptr) {
        std::cerr << "could not create TensorRT runtime" << std::endl;
        std::abort();
    }
    *engine = (*runtime)->deserializeCudaEngine(plan.data(), plan.size());
    if (*engine == nullptr) {
        std::cerr << "could not deserialize engine from " << engine_name << std::endl;
        std::abort();
    }
    *context = (*engine)->createExecutionContext();
    if (*context == nullptr) {
        std::cerr << "could not create execution context" << std::endl;
        std::abort();
    }
}
// Validates the engine's I/O tensors and allocates the buffers used by infer().
//
//   engine                 engine expected to expose exactly two I/O tensors,
//                          kInputTensorName (input) and kOutputTensorName (output)
//   input_buffer_device    receives a device buffer of kBatchSize * 3 * kInputH * kInputW floats
//   output_buffer_device   receives a device buffer of kBatchSize * kOutputSize floats
//   output_buffer_host     "c" mode only: receives a new[] host buffer of kBatchSize * kOutputSize floats
//   decode_ptr_host        "g" mode only: receives a new[] host buffer of
//                          1 + kMaxNumOutputBbox * bbox_element floats
//   decode_ptr_device      "g" mode only: receives a device buffer of the same size
//   cuda_post_process      "c" or "g"; any other value allocates only the two device I/O buffers
//
// The caller owns every buffer handed back. An engine with unexpected tensors
// aborts the process; requesting "g" with kBatchSize > 1 exits with a failure status.
void prepare_buffer(ICudaEngine *engine, float **input_buffer_device, float **output_buffer_device,
                    float **output_buffer_host, float **decode_ptr_host, float **decode_ptr_device,
                    std::string cuda_post_process) {
    if (engine->getNbIOTensors() != 2) {
        std::cerr << "engine should have exactly two I/O tensors" << std::endl;
        std::abort();
    }
    // Buffers are bound by tensor name, so make sure both names exist with the expected direction.
    if (engine->getTensorIOMode(kInputTensorName) != TensorIOMode::kINPUT) {
        std::cerr << kInputTensorName << " should be input tensor" << std::endl;
        std::abort();
    }
    if (engine->getTensorIOMode(kOutputTensorName) != TensorIOMode::kOUTPUT) {
        std::cerr << kOutputTensorName << " should be output tensor" << std::endl;
        std::abort();
    }

    // Element counts are computed in size_t so the products cannot overflow int.
    const size_t input_count = static_cast<size_t>(kBatchSize) * 3 * kInputH * kInputW;
    const size_t output_count = static_cast<size_t>(kBatchSize) * kOutputSize;

    // Create GPU buffers on device
    CUDA_CHECK(cudaMalloc((void **) input_buffer_device, input_count * sizeof(float)));
    CUDA_CHECK(cudaMalloc((void **) output_buffer_device, output_count * sizeof(float)));

    if (cuda_post_process == "c") {
        *output_buffer_host = new float[output_count];
    } else if (cuda_post_process == "g") {
        if (kBatchSize > 1) {
            std::cerr << "Do not yet support GPU post processing for multiple batches" << std::endl;
            // This is an error path, so report failure to the calling shell.
            std::exit(EXIT_FAILURE);
        }
        // Host and device buffers that receive the decoded boxes.
        const size_t decode_count = 1 + static_cast<size_t>(kMaxNumOutputBbox) * bbox_element;
        *decode_ptr_host = new float[decode_count];
        CUDA_CHECK(cudaMalloc((void **) decode_ptr_device, sizeof(float) * decode_count));
    }
}
// Runs one inference on `stream` and brings the result back to the host.
//
//   context             execution context used to enqueue the network
//   stream              CUDA stream all work is issued on; synchronized before returning
//   buffers             buffers[0] is bound to kInputTensorName, buffers[1] to kOutputTensorName
//   output              "c" mode: host destination for batchsize * kOutputSize floats
//   batchsize           number of per-image output blocks copied back in "c" mode
//   decode_ptr_host     "g" mode: host destination for 1 + kMaxNumOutputBbox * bbox_element floats
//   decode_ptr_device   "g" mode: device scratch buffer of the same size, zeroed on every call
//   model_bboxes        forwarded to cuda_decode
//   cuda_post_process   "c" copies the raw output back; "g" runs cuda_decode + cuda_nms on the
//                       device and copies the decoded boxes back; any other value only runs the network
//
// The elapsed time is printed after the stream has been synchronized, so it
// covers the GPU work itself and not just the cost of enqueueing it.
// A failure to bind a tensor or enqueue the network aborts the process.
void infer(IExecutionContext &context, cudaStream_t &stream, void **buffers, float *output, int batchsize,
           float *decode_ptr_host, float *decode_ptr_device, int model_bboxes, std::string cuda_post_process) {
    const bool cpu_post = (cuda_post_process == "c");
    const bool gpu_post = (cuda_post_process == "g");

    // steady_clock is monotonic, which is what an interval measurement needs.
    const auto start = std::chrono::steady_clock::now();

    if (!context.setInputTensorAddress(kInputTensorName, buffers[0]) ||
        !context.setOutputTensorAddress(kOutputTensorName, buffers[1])) {
        std::cerr << "could not bind input/output tensor addresses" << std::endl;
        std::abort();
    }
    if (!context.enqueueV3(stream)) {
        std::cerr << "enqueueV3 failed" << std::endl;
        std::abort();
    }

    if (cpu_post) {
        CUDA_CHECK(cudaMemcpyAsync(output, buffers[1], batchsize * kOutputSize * sizeof(float), cudaMemcpyDeviceToHost,
                                   stream));
    } else if (gpu_post) {
        const size_t decode_bytes = sizeof(float) * (1 + kMaxNumOutputBbox * bbox_element);
        CUDA_CHECK(cudaMemsetAsync(decode_ptr_device, 0, decode_bytes, stream));
        cuda_decode((float *) buffers[1], model_bboxes, kConfThresh, decode_ptr_device, kMaxNumOutputBbox, stream);
        cuda_nms(decode_ptr_device, kNmsThresh, kMaxNumOutputBbox, stream);  //cuda nms
        CUDA_CHECK(cudaMemcpyAsync(decode_ptr_host, decode_ptr_device, decode_bytes, cudaMemcpyDeviceToHost, stream));
    }

    // Everything above is asynchronous; wait for it before stopping the clock.
    CUDA_CHECK(cudaStreamSynchronize(stream));

    const auto end = std::chrono::steady_clock::now();
    const auto elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
    if (cpu_post) {
        std::cout << "inference time: " << elapsed_ms << "ms" << std::endl;
    } else if (gpu_post) {
        std::cout << "inference and gpu postprocess time: " << elapsed_ms << "ms" << std::endl;
    }
}
// Parses the command line into the caller's variables.
//
// Accepted forms:
//   -s <wts> <engine> <sub_type>     (argc == 5, or 7 with two trailing arguments that are ignored)
//   -d <engine> <img_dir> <c|g>      (argc == 5)
//
// For -s the first character of <sub_type> (n/s/m/l/x) selects gd, gw and
// max_channels; a two-character sub_type ending in '6' or '2' sets is_p to 6 or
// 2, otherwise is_p is left untouched. wts, engine and sub_type receive the
// corresponding arguments.
// For -d, engine, img_dir and cuda_post_process receive the corresponding
// arguments; the post-process mode must be exactly "c" or "g".
//
// Returns true on success. Returns false for any unrecognized or malformed
// command line, in which case none of the output parameters are modified.
bool parse_args(int argc, char **argv, std::string &wts, std::string &engine, int &is_p, std::string &img_dir,
                std::string &sub_type, std::string &cuda_post_process, float &gd, float &gw, int &max_channels) {
    if (argc < 4)
        return false;

    const std::string mode(argv[1]);

    if (mode == "-s" && (argc == 5 || argc == 7)) {
        const std::string model(argv[4]);

        float depth = 0.0f;
        float width = 0.0f;
        int channels = 0;
        // operator[] on an empty std::string yields '\0', which lands in default.
        switch (model[0]) {
            case 'n':
                depth = 0.33f;
                width = 0.25f;
                channels = 1024;
                break;
            case 's':
                depth = 0.33f;
                width = 0.50f;
                channels = 1024;
                break;
            case 'm':
                depth = 0.67f;
                width = 0.75f;
                channels = 576;
                break;
            case 'l':
                depth = 1.0f;
                width = 1.0f;
                channels = 512;
                break;
            case 'x':
                depth = 1.0f;
                width = 1.25f;
                channels = 640;
                break;
            default:
                return false;
        }

        wts = std::string(argv[2]);
        engine = std::string(argv[3]);
        // Write through to the caller's sub_type (a local of the same name used to shadow it).
        sub_type = model;
        gd = depth;
        gw = width;
        max_channels = channels;
        if (model.size() == 2 && model[1] == '6') {
            is_p = 6;
        } else if (model.size() == 2 && model[1] == '2') {
            is_p = 2;
        }
        return true;
    }

    if (mode == "-d" && argc == 5) {
        const std::string post(argv[4]);
        // Only the CPU ("c") and GPU ("g") post-processing paths exist.
        if (post != "c" && post != "g")
            return false;
        engine = std::string(argv[2]);
        img_dir = std::string(argv[3]);
        cuda_post_process = post;
        return true;
    }

    return false;
}
int main(int argc, char **argv) {
// -s ../models/yolov8n.wts ../models/yolov8n.fp32.trt n
// -d ../models/yolov8n.fp32.trt ../images c
cudaSetDevice(kGpuId);
std::string wts_name = "";
std::string engine_name = "";
std::string img_dir;
std::string sub_type = "";
std::string cuda_post_process = "";
int model_bboxes;
int is_p = 0;
float gd = 0.0f, gw = 0.0f;
int max_channels = 0;
if (!parse_args(argc, argv, wts_name, engine_name, is_p, img_dir, sub_type, cuda_post_process, gd, gw,
max_channels)) {
std::cerr << "Arguments not right!" << std::endl;
std::cerr << "./yolov8 -s [.wts] [.engine] [n/s/m/l/x/n2/s2/m2/l2/x2/n6/s6/m6/l6/x6] // serialize model to "
"plan file"
<< std::endl;
std::cerr << "./yolov8 -d [.engine] ../samples [c/g]// deserialize plan file and run inference" << std::endl;
return -1;
}
// Create a model using the API directly and serialize it to a file
if (!wts_name.empty()) {
serialize_engine(wts_name, engine_name, is_p, sub_type, gd, gw, max_channels);
return 0;
}
// Deserialize the engine from file
IRuntime *runtime = nullptr;
ICudaEngine *engine = nullptr;
IExecutionContext *context = nullptr;
deserialize_engine(engine_name, &runtime, &engine, &context);
cudaStream_t stream;
CUDA_CHECK(cudaStreamCreate(&stream));
cuda_preprocess_init(kMaxInputImageSize);
auto out_dims = engine->getTensorShape(kOutputTensorName);
model_bboxes = out_dims.d[1];
// Prepare cpu and gpu buffers
float *device_buffers[2];
float *output_buffer_host = nullptr;
float *decode_ptr_host = nullptr;
float *decode_ptr_device = nullptr;
// Read images from directory
std::vector<std::string> file_names;
if (read_files_in_dir(img_dir.c_str(), file_names) < 0) {
std::cerr << "read_files_in_dir failed." << std::endl;
return -1;
}
prepare_buffer(engine, &device_buffers[0], &device_buffers[1], &output_buffer_host, &decode_ptr_host,
&decode_ptr_device, cuda_post_process);
// batch predict
for (size_t i = 0; i < file_names.size(); i += kBatchSize) {
// Get a batch of images
std::vector<cv::Mat> img_batch;
std::vector<std::string> img_name_batch;
for (size_t j = i; j < i + kBatchSize && j < file_names.size(); j++) {
cv::Mat img = cv::imread(img_dir + "/" + file_names[j]);
if (img.empty()) {
std::cerr << "Fatal error: image cannot open!" << std::endl;
return -1;
}
img_batch.push_back(img);
img_name_batch.push_back(file_names[j]);
}
// Preprocess
cuda_batch_preprocess(img_batch, device_buffers[0], kInputW, kInputH, stream);
// Run inference
infer(*context, stream, (void **) device_buffers, output_buffer_host, kBatchSize, decode_ptr_host,
decode_ptr_device, model_bboxes, cuda_post_process);
std::vector<std::vector<Detection>> res_batch;
if (cuda_post_process == "c") {
// NMS
batch_nms(res_batch, output_buffer_host, img_batch.size(), kOutputSize, kConfThresh, kNmsThresh);
} else if (cuda_post_process == "g") {
//Process gpu decode and nms results
batch_process(res_batch, decode_ptr_host, img_batch.size(), bbox_element, img_batch);
}
// print results
for (size_t j = 0; j < res_batch.size(); j++) {
for (size_t k = 0; k < res_batch[j].size(); k++) {
std::cout << "image: " << img_name_batch[j] << ", bbox: " << res_batch[j][k].bbox[0] << ", "
<< res_batch[j][k].bbox[1] << ", " << res_batch[j][k].bbox[2] << ", "
<< res_batch[j][k].bbox[3] << ", conf: " << res_batch[j][k].conf << ", class_id: "
<< res_batch[j][k].class_id << std::endl;
}
}
// Draw bounding boxes
draw_bbox(img_batch, res_batch);
// Save images
for (size_t j = 0; j < img_batch.size(); j++) {
cv::imwrite("_" + img_name_batch[j], img_batch[j]);
}
}
// Release stream and buffers
cudaStreamDestroy(stream);
CUDA_CHECK(cudaFree(device_buffers[0]));
CUDA_CHECK(cudaFree(device_buffers[1]));
CUDA_CHECK(cudaFree(decode_ptr_device));
delete[] decode_ptr_host;
delete[] output_buffer_host;
cuda_preprocess_destroy();
// Destroy the engine
delete context;
delete engine;
delete runtime;
// Print histogram of the output distribution
//std::cout << "\nOutput:\n\n";
//for (unsigned int i = 0; i < kOutputSize; i++)
//{
// std::cout << prob[i] << ", ";
// if (i % 10 == 0) std::cout << std::endl;
//}
//std::cout << std::endl;
return 0;
}