-
Notifications
You must be signed in to change notification settings - Fork 1.8k
/
yololayer.h
154 lines (113 loc) · 4.92 KB
/
yololayer.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
#ifndef _YOLO_LAYER_H
#define _YOLO_LAYER_H
#include <iostream>
#include <string>
#include <vector>
#include "NvInfer.h"
#include "macros.h"
namespace Yolo
{
    // ---------------------------------------------------------------------
    // Compile-time network configuration shared by the plugin and its kernel.
    // ---------------------------------------------------------------------

    // Network input resolution in pixels; every head's grid size below is
    // derived from these by the head's downsampling factor.
    static constexpr int INPUT_W = 608;
    static constexpr int INPUT_H = 608;

    // Number of object classes the network predicts.
    static constexpr int CLASS_NUM = 80;

    // Number of anchors per head (each anchor is a (w, h) pair, hence the
    // 2 * CHECK_COUNT floats stored in YoloKernel::anchors).
    static constexpr int CHECK_COUNT = 3;

    // NOTE(review): threshold value consumed by the kernel; how it is applied
    // (e.g. against which score) is not visible here — confirm in the .cu.
    static constexpr float IGNORE_THRESH = 0.1f;

    // Upper bound on detections emitted per image.
    static constexpr int MAX_OUTPUT_BBOX_COUNT = 1000;

    // Number of box coordinates stored per detection (x, y, w, h).
    static constexpr int LOCATIONS = 4;

    /// Describes one detection head: its output grid size and anchor boxes.
    struct YoloKernel
    {
        int width;                        // grid columns
        int height;                       // grid rows
        float anchors[2 * CHECK_COUNT];   // CHECK_COUNT consecutive (w, h) pairs
    };

    // Stride-32 head (coarsest grid, largest anchors).
    static constexpr YoloKernel yolo1 = {
        INPUT_W / 32, INPUT_H / 32,
        { 116.0f, 90.0f,  156.0f, 198.0f,  373.0f, 326.0f }
    };

    // Stride-16 head.
    static constexpr YoloKernel yolo2 = {
        INPUT_W / 16, INPUT_H / 16,
        { 30.0f, 61.0f,  62.0f, 45.0f,  59.0f, 119.0f }
    };

    // Stride-8 head (finest grid, smallest anchors).
    static constexpr YoloKernel yolo3 = {
        INPUT_W / 8, INPUT_H / 8,
        { 10.0f, 13.0f,  16.0f, 30.0f,  33.0f, 23.0f }
    };

    /// One decoded detection record, laid out as consecutive floats.
    struct alignas(float) Detection
    {
        float bbox[LOCATIONS];    // x y w h
        float det_confidence;
        float class_id;           // stored as float to keep the record all-float
        float class_confidence;
    };
}
namespace nvinfer1
{
/// TensorRT IPluginV2IOExt implementation of the YOLO output layer.
/// Only the declarations and a few trivial inline overrides live here; the
/// remaining member functions are defined in a separate implementation file.
class YoloLayerPlugin: public IPluginV2IOExt
{
public:
    explicit YoloLayerPlugin();
    // NOTE(review): presumably the deserialization constructor paired with
    // serialize()/getSerializationSize() — confirm in the implementation file.
    YoloLayerPlugin(const void* data, size_t length);
    // Marked `override` (as YoloPluginCreator's destructor is) so the
    // compiler checks that the base destructor is virtual.
    ~YoloLayerPlugin() override;

    /// The layer produces exactly one output tensor.
    int getNbOutputs() const TRT_NOEXCEPT override
    {
        return 1;
    }
    Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) TRT_NOEXCEPT override;
    int initialize() TRT_NOEXCEPT override;
    /// Intentionally a no-op.
    void terminate() TRT_NOEXCEPT override {}
    /// No scratch workspace is requested, regardless of batch size.
    size_t getWorkspaceSize(int maxBatchSize) const TRT_NOEXCEPT override { return 0; }
    int enqueue(int batchSize, const void*const * inputs, void*TRT_CONST_ENQUEUE* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT override;
    size_t getSerializationSize() const TRT_NOEXCEPT override;
    void serialize(void* buffer) const TRT_NOEXCEPT override;
    /// Accepts only linear-layout FP32 tensors at every input/output position.
    bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) const TRT_NOEXCEPT override
    {
        return inOut[pos].format == TensorFormat::kLINEAR && inOut[pos].type == DataType::kFLOAT;
    }
    const char* getPluginType() const TRT_NOEXCEPT override;
    const char* getPluginVersion() const TRT_NOEXCEPT override;
    void destroy() TRT_NOEXCEPT override;
    IPluginV2IOExt* clone() const TRT_NOEXCEPT override;
    void setPluginNamespace(const char* pluginNamespace) TRT_NOEXCEPT override;
    const char* getPluginNamespace() const TRT_NOEXCEPT override;
    DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const TRT_NOEXCEPT override;
    bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const TRT_NOEXCEPT override;
    bool canBroadcastInputAcrossBatch(int inputIndex) const TRT_NOEXCEPT override;
    void attachToContext(
        cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) TRT_NOEXCEPT override;
    void configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) TRT_NOEXCEPT override;
    void detachFromContext() TRT_NOEXCEPT override;

private:
    void forwardGpu(const float *const * inputs, float * output, cudaStream_t stream, int batchSize = 1);

    // Zero defaults keep these determinate even if a constructor in the
    // implementation file forgets to set them; explicit assignments still win.
    int mClassCount = 0;
    int mKernelCount = 0;
    // NOTE(review): presumably populated from Yolo::yolo1..yolo3 — confirm.
    std::vector<Yolo::YoloKernel> mYoloKernel;
    // NOTE(review): presumably the CUDA threads-per-block used by forwardGpu — confirm.
    int mThreadCount = 256;
    // Defaults to "" so the pointer is never indeterminate if it is read
    // (e.g. by getPluginNamespace()) before setPluginNamespace() is called.
    const char* mPluginNamespace = "";
};

/// Factory that TensorRT uses to create/deserialize YoloLayerPlugin instances.
class YoloPluginCreator : public IPluginCreator
{
public:
    YoloPluginCreator();
    ~YoloPluginCreator() override = default;
    const char* getPluginName() const TRT_NOEXCEPT override;
    const char* getPluginVersion() const TRT_NOEXCEPT override;
    const PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override;
    IPluginV2IOExt* createPlugin(const char* name, const PluginFieldCollection* fc) TRT_NOEXCEPT override;
    IPluginV2IOExt* deserializePlugin(const char* name, const void* serialData, size_t serialLength) TRT_NOEXCEPT override;
    /// Stores a copy of the namespace string; a null pointer is treated as "".
    void setPluginNamespace(const char* libNamespace) TRT_NOEXCEPT override
    {
        // Assigning a null const char* to std::string is undefined behaviour,
        // so normalise null to the empty namespace.
        mNamespace = libNamespace != nullptr ? libNamespace : "";
    }
    /// Returns the stored namespace; the pointer stays valid until the next
    /// setPluginNamespace() call or destruction of this creator.
    const char* getPluginNamespace() const TRT_NOEXCEPT override
    {
        return mNamespace.c_str();
    }

private:
    std::string mNamespace;
    static PluginFieldCollection mFC;
    static std::vector<PluginField> mPluginAttributes;
};

// NOTE(review): in TensorRT's headers this macro defines a namespace-scope
// `static` PluginRegistrar object, so every translation unit that includes
// this header registers the creator again. Confirm only one TU includes this
// header, or move the registration into the implementation file.
REGISTER_TENSORRT_PLUGIN(YoloPluginCreator);
}  // namespace nvinfer1
#endif