-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add the support of YOLO11' s det/cls/seg/pose in TensorRT8. (#1584)
* Add the support of YOLO11' s det/cls/seg/pose in TensorRT8. * add train code link
- Loading branch information
Showing
30 changed files
with
6,580 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
cmake_minimum_required(VERSION 3.10) | ||
|
||
project(yolov11) | ||
|
||
add_definitions(-std=c++11) | ||
add_definitions(-DAPI_EXPORTS) | ||
set(CMAKE_CXX_STANDARD 11) | ||
set(CMAKE_BUILD_TYPE Debug) | ||
|
||
set(CMAKE_CUDA_COMPILER /usr/local/cuda/bin/nvcc) | ||
enable_language(CUDA) | ||
|
||
include_directories(${PROJECT_SOURCE_DIR}/include) | ||
include_directories(${PROJECT_SOURCE_DIR}/plugin) | ||
|
||
# include and link dirs of cuda and tensorrt, you need adapt them if yours are different | ||
if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64") | ||
message("embed_platform on") | ||
include_directories(/usr/local/cuda/targets/aarch64-linux/include) | ||
link_directories(/usr/local/cuda/targets/aarch64-linux/lib) | ||
else() | ||
message("embed_platform off") | ||
|
||
# cuda | ||
include_directories(/usr/local/cuda/include) | ||
link_directories(/usr/local/cuda/lib64) | ||
|
||
# tensorrt | ||
include_directories(/workspace/shared/TensorRT-8.6.1.6/include) | ||
link_directories(/workspace/shared/TensorRT-8.6.1.6/lib) | ||
endif() | ||
|
||
add_library(myplugins SHARED ${PROJECT_SOURCE_DIR}/plugin/yololayer.cu) | ||
target_link_libraries(myplugins nvinfer cudart) | ||
|
||
find_package(OpenCV) | ||
include_directories(${OpenCV_INCLUDE_DIRS}) | ||
|
||
file(GLOB_RECURSE SRCS ${PROJECT_SOURCE_DIR}/src/*.cpp ${PROJECT_SOURCE_DIR}/src/*.cu) | ||
|
||
add_executable(yolo11_det ${PROJECT_SOURCE_DIR}/yolo11_det.cpp ${SRCS}) | ||
target_link_libraries(yolo11_det nvinfer) | ||
target_link_libraries(yolo11_det cudart) | ||
target_link_libraries(yolo11_det myplugins) | ||
target_link_libraries(yolo11_det ${OpenCV_LIBS}) | ||
|
||
add_executable(yolo11_cls ${PROJECT_SOURCE_DIR}/yolo11_cls.cpp ${SRCS}) | ||
target_link_libraries(yolo11_cls nvinfer) | ||
target_link_libraries(yolo11_cls cudart) | ||
target_link_libraries(yolo11_cls myplugins) | ||
target_link_libraries(yolo11_cls ${OpenCV_LIBS}) | ||
|
||
add_executable(yolo11_seg ${PROJECT_SOURCE_DIR}/yolo11_seg.cpp ${SRCS}) | ||
target_link_libraries(yolo11_seg nvinfer) | ||
target_link_libraries(yolo11_seg cudart) | ||
target_link_libraries(yolo11_seg myplugins) | ||
target_link_libraries(yolo11_seg ${OpenCV_LIBS}) | ||
|
||
add_executable(yolo11_pose ${PROJECT_SOURCE_DIR}/yolo11_pose.cpp ${SRCS}) | ||
target_link_libraries(yolo11_pose nvinfer) | ||
target_link_libraries(yolo11_pose cudart) | ||
target_link_libraries(yolo11_pose myplugins) | ||
target_link_libraries(yolo11_pose ${OpenCV_LIBS}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import sys # noqa: F401 | ||
import argparse | ||
import os | ||
import struct | ||
import torch | ||
|
||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser(description='Convert .pt file to .wts') | ||
parser.add_argument('-w', '--weights', required=True, | ||
help='Input weights (.pt) file path (required)') | ||
parser.add_argument( | ||
'-o', '--output', help='Output (.wts) file path (optional)') | ||
parser.add_argument( | ||
'-t', '--type', type=str, default='detect', choices=['detect', 'cls', 'seg', 'pose'], | ||
help='determines the model is detection/classification') | ||
args = parser.parse_args() | ||
if not os.path.isfile(args.weights): | ||
raise SystemExit('Invalid input file') | ||
if not args.output: | ||
args.output = os.path.splitext(args.weights)[0] + '.wts' | ||
elif os.path.isdir(args.output): | ||
args.output = os.path.join( | ||
args.output, | ||
os.path.splitext(os.path.basename(args.weights))[0] + '.wts') | ||
return args.weights, args.output, args.type | ||
|
||
|
||
pt_file, wts_file, m_type = parse_args() | ||
|
||
print(f'Generating .wts for {m_type} model') | ||
|
||
# Load model | ||
print(f'Loading {pt_file}') | ||
|
||
# Initialize | ||
device = 'cpu' | ||
|
||
# Load model | ||
model = torch.load(pt_file, map_location=device)['model'].float() # load to FP32 | ||
|
||
if m_type in ['detect', 'seg', 'pose']: | ||
anchor_grid = model.model[-1].anchors * model.model[-1].stride[..., None, None] | ||
|
||
delattr(model.model[-1], 'anchors') | ||
|
||
model.to(device).eval() | ||
|
||
with open(wts_file, 'w') as f: | ||
f.write('{}\n'.format(len(model.state_dict().keys()))) | ||
for k, v in model.state_dict().items(): | ||
vr = v.reshape(-1).cpu().numpy() | ||
f.write('{} {} '.format(k, len(vr))) | ||
for vv in vr: | ||
f.write(' ') | ||
f.write(struct.pack('>f', float(vv)).hex()) | ||
f.write('\n') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
#pragma once | ||
|
||
#include <map> | ||
#include <string> | ||
#include <vector> | ||
#include "NvInfer.h" | ||
|
||
std::map<std::string, nvinfer1::Weights> loadWeights(const std::string file); | ||
|
||
nvinfer1::IScaleLayer* addBatchNorm2d(nvinfer1::INetworkDefinition* network, | ||
std::map<std::string, nvinfer1::Weights> weightMap, nvinfer1::ITensor& input, | ||
std::string lname, float eps); | ||
|
||
nvinfer1::IElementWiseLayer* convBnSiLU(nvinfer1::INetworkDefinition* network, | ||
std::map<std::string, nvinfer1::Weights> weightMap, nvinfer1::ITensor& input, | ||
int ch, std::vector<int> k, int s, std::string lname); | ||
|
||
nvinfer1::IElementWiseLayer* C2F(nvinfer1::INetworkDefinition* network, | ||
std::map<std::string, nvinfer1::Weights> weightMap, nvinfer1::ITensor& input, int c1, | ||
int c2, int n, bool shortcut, float e, std::string lname); | ||
|
||
nvinfer1::IElementWiseLayer* C2(nvinfer1::INetworkDefinition* network, | ||
std::map<std::string, nvinfer1::Weights>& weightMap, nvinfer1::ITensor& input, int c1, | ||
int c2, int n, bool shortcut, float e, std::string lname); | ||
|
||
nvinfer1::IElementWiseLayer* SPPF(nvinfer1::INetworkDefinition* network, | ||
std::map<std::string, nvinfer1::Weights> weightMap, nvinfer1::ITensor& input, int c1, | ||
int c2, int k, std::string lname); | ||
|
||
nvinfer1::IShuffleLayer* DFL(nvinfer1::INetworkDefinition* network, std::map<std::string, nvinfer1::Weights> weightMap, | ||
nvinfer1::ITensor& input, int ch, int grid, int k, int s, int p, std::string lname); | ||
|
||
nvinfer1::IPluginV2Layer* addYoLoLayer(nvinfer1::INetworkDefinition* network, | ||
std::vector<nvinfer1::IConcatenationLayer*> dets, const int* px_arry, | ||
int px_arry_num, bool is_segmentation, bool is_pose); | ||
|
||
nvinfer1::IElementWiseLayer* C3K2(nvinfer1::INetworkDefinition* network, | ||
std::map<std::string, nvinfer1::Weights>& weightMap, nvinfer1::ITensor& input, int c1, | ||
int c2, int n, bool c3k, bool shortcut, float e, std::string lname); | ||
|
||
nvinfer1::ILayer* C2PSA(nvinfer1::INetworkDefinition* network, std::map<std::string, nvinfer1::Weights>& weightMap, | ||
nvinfer1::ITensor& input, int c1, int c2, int n, float e, std::string lname); | ||
|
||
nvinfer1::ILayer* DWConv(nvinfer1::INetworkDefinition* network, std::map<std::string, nvinfer1::Weights> weightMap, | ||
nvinfer1::ITensor& input, int ch, std::vector<int> k, int s, std::string lname); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
#ifndef ENTROPY_CALIBRATOR_H | ||
#define ENTROPY_CALIBRATOR_H | ||
|
||
#include <NvInfer.h> | ||
#include <string> | ||
#include <vector> | ||
#include "macros.h" | ||
|
||
//! \class Int8EntropyCalibrator2 | ||
//! | ||
//! \brief Implements Entropy calibrator 2. | ||
//! CalibrationAlgoType is kENTROPY_CALIBRATION_2. | ||
//! | ||
class Int8EntropyCalibrator2 : public nvinfer1::IInt8EntropyCalibrator2 { | ||
public: | ||
Int8EntropyCalibrator2(int batchsize, int input_w, int input_h, const char* img_dir, const char* calib_table_name, | ||
const char* input_blob_name, bool read_cache = true); | ||
virtual ~Int8EntropyCalibrator2(); | ||
int getBatchSize() const TRT_NOEXCEPT override; | ||
bool getBatch(void* bindings[], const char* names[], int nbBindings) TRT_NOEXCEPT override; | ||
const void* readCalibrationCache(size_t& length) TRT_NOEXCEPT override; | ||
void writeCalibrationCache(const void* cache, size_t length) TRT_NOEXCEPT override; | ||
|
||
private: | ||
int batchsize_; | ||
int input_w_; | ||
int input_h_; | ||
int img_idx_; | ||
std::string img_dir_; | ||
std::vector<std::string> img_files_; | ||
size_t input_count_; | ||
std::string calib_table_name_; | ||
const char* input_blob_name_; | ||
bool read_cache_; | ||
void* device_input_; | ||
std::vector<char> calib_cache_; | ||
}; | ||
|
||
#endif // ENTROPY_CALIBRATOR_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
//#define USE_FP16 | ||
// #define USE_FP32 | ||
#define USE_INT8 | ||
|
||
const static char* kInputTensorName = "images"; | ||
const static char* kOutputTensorName = "output"; | ||
const static char* kProtoTensorName = "proto"; | ||
const static int kNumClass = 80; | ||
const static int kPoseNumClass = 1; | ||
const static int kNumberOfPoints = 17; // number of keypoints total | ||
const static int kBatchSize = 1; | ||
const static int kGpuId = 0; | ||
const static int kInputH = 640; | ||
const static int kInputW = 640; | ||
const static float kNmsThresh = 0.45f; | ||
const static float kConfThresh = 0.5f; | ||
const static float kConfThreshKeypoints = 0.5f; // keypoints confidence | ||
const static int kMaxInputImageSize = 3000 * 3000; | ||
const static int kMaxNumOutputBbox = 1000; | ||
//Quantization input image folder path | ||
const static char* kInputQuantizationFolder = "./coco_calib"; | ||
|
||
// Classfication model's number of classes | ||
constexpr static int kClsNumClass = 1000; | ||
// Classfication model's input shape | ||
constexpr static int kClsInputH = 224; | ||
constexpr static int kClsInputW = 224; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
#ifndef TRTX_CUDA_UTILS_H_ | ||
#define TRTX_CUDA_UTILS_H_ | ||
|
||
#include <cuda_runtime_api.h> | ||
|
||
#ifndef CUDA_CHECK | ||
#define CUDA_CHECK(callstr) \ | ||
{ \ | ||
cudaError_t error_code = callstr; \ | ||
if (error_code != cudaSuccess) { \ | ||
std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \ | ||
assert(0); \ | ||
} \ | ||
} | ||
#endif // CUDA_CHECK | ||
|
||
#endif // TRTX_CUDA_UTILS_H_ |
Oops, something went wrong.