diff --git a/efficient_ad/CMakeLists.txt b/efficient_ad/CMakeLists.txt new file mode 100644 index 00000000..00b30b74 --- /dev/null +++ b/efficient_ad/CMakeLists.txt @@ -0,0 +1,37 @@ +cmake_minimum_required(VERSION 3.12) +project(EfficientAD-M) + +add_definitions(-w) +add_definitions(-D API_EXPORTS) +set(CMAKE_CXX_STANDARD 11) +set(CMAKE_BUILD_TYPE "Debug") +set(CMAKE_CUDA_ARCHITECTURES 61 75 86 89) +set(THREADS_PREFER_PTHREAD_FLAG ON) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /od") + +### nvcc +set(CMAKE_CUDA_COMPILER "D:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.8/bin/nvcc.exe") +enable_language(CUDA) +### cuda +include_directories("D:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.8/include") +link_directories("D:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.8/lib/x64") +### tensorrt +set(TRT_DIR "D:/Program Files/NVIDIA GPU Computing Toolkit/TensorRT-8.5.3.1/") +include_directories(${TRT_DIR}/include) +link_directories(${TRT_DIR}/lib) +### opencv +set(OpenCV_DIR "E:/OpenCV/OpenCV_4.6.0/opencv/build") +find_package(OpenCV) +include_directories(${OpenCV_INCLUDE_DIRS}) +### dirent +include_directories("E:/SDK/dirent-1.24/include") + +include_directories(${PROJECT_SOURCE_DIR}/src/) +file(GLOB_RECURSE SRCS ${PROJECT_SOURCE_DIR}/src/*.cpp ${PROJECT_SOURCE_DIR}/src/*.cu) + +add_executable(efficientAD_det "./efficientAD_det.cpp" ${SRCS}) +target_link_libraries(efficientAD_det nvinfer + cudart + nvinfer_plugin + ${OpenCV_LIBS} + ) \ No newline at end of file diff --git a/efficient_ad/datas/models/gen_wts.py b/efficient_ad/datas/models/gen_wts.py new file mode 100644 index 00000000..dbe7a6b4 --- /dev/null +++ b/efficient_ad/datas/models/gen_wts.py @@ -0,0 +1,19 @@ +import torch +import struct +import sys + +# Initialize +pt_file = sys.argv[1] +# Load model +model = torch.load(pt_file, map_location=torch.device('cpu'))['model'].float() # load to FP32 +model.to(device).eval() + +with open(pt_file.split('.')[0] + '.wts', 'w') as f: + f.write('{}\n'.format(len(model.state_dict().keys()))) + for k, v in model.state_dict().items(): + vr = v.reshape(-1).cpu().numpy() + f.write('{} {} '.format(k, len(vr))) + for vv in vr: + f.write(' ') + f.write(struct.pack('>f',float(vv)).hex()) + f.write('\n') \ No newline at end of file diff --git a/efficient_ad/efficientAD_det.cpp b/efficient_ad/efficientAD_det.cpp new file mode 100644 index 00000000..bde0f012 --- /dev/null +++ b/efficient_ad/efficientAD_det.cpp @@ -0,0 +1,247 @@ +#include + +#include +#include +#include +#include +#include + +#include "config.h" +#include "cuda_utils.h" +#include "logging.h" +#include "model.h" +#include "postprocess.h" +#include "utils.h" + +using namespace nvinfer1; + +static Logger gLogger; +// const static int kOutputSize = kMaxNumOutputBbox * sizeof(Detection) / sizeof(float) + 1; +const static int kInputSize = 3 * 256 * 256; +const static int kOutputSize = 1 * 256 * 256; + +bool parse_args(int argc, char** argv, std::string& wts, std::string& engine, float& gd, float& gw, std::string& img_dir) { + if (argc != 4) return false; + if (std::string(argv[1]) == "-s") { + wts = std::string(argv[2]); + engine = std::string(argv[3]); + } else if (std::string(argv[1]) == "-d") { + engine = std::string(argv[2]); + img_dir = std::string(argv[3]); + } else { + return false; + } + return true; +} + +void prepare_infer_buffers(ICudaEngine* engine, float** gpu_input_buffer, float** gpu_output_buffer, float** cpu_output_buffer) { + // assert(engine->getNbIOTensors() == 2); + assert(engine->getNbBindings() == 2); + + // In order to bind the buffers, we need to know the names of the input and output tensors. + // Note that indices are guaranteed to be less than IEngine::getNbBindings() + const int inputIndex = engine->getBindingIndex(kInputTensorName); + const int outputIndex = engine->getBindingIndex(kOutputTensorName); + // nvinfer1::Dims outputDims = engine->getBindingDimensions(outputIndex); + assert(inputIndex == 0); + assert(outputIndex == 1); + + // Create GPU in/output buffers on device + CUDA_CHECK(cudaMalloc((void**)gpu_input_buffer, kBatchSize * 3 * kInputH * kInputW * sizeof(float))); + CUDA_CHECK(cudaMalloc((void**)gpu_output_buffer, kBatchSize * 1 * kOutputSize * sizeof(float))); // 3 or 1 ?? + // Create CPU output buffers on host + *cpu_output_buffer = new float[kBatchSize * kOutputSize]; +} + +void preprocessImg(cv::Mat& img, int newh, int neww) { + cv::cvtColor(img, img, cv::COLOR_BGR2RGB); + cv::resize(img, img, cv::Size(neww, newh)); + img.convertTo(img, CV_32FC3); + // ImageNet normalize + img /= 255.0f; + img -= cv::Scalar(0.485, 0.456, 0.406); + img /= cv::Scalar(0.229, 0.224, 0.225); +} + +void infer(IExecutionContext& context, cudaStream_t& stream, std::vector& gpu_buffers, std::vector& cpu_input_data, + std::vector& cpu_output_data, int batchsize) { + // copy input data from host (CPU) to device (GPU) + CUDA_CHECK(cudaMemcpyAsync(gpu_buffers[0], cpu_input_data.data(), cpu_input_data.size() * sizeof(float), cudaMemcpyHostToDevice, stream)); + // execute inference using context provided by engine + context.enqueue(batchsize, gpu_buffers.data(), stream, nullptr); + // copy output back from device (GPU) to host (CPU) + CUDA_CHECK(cudaMemcpyAsync(cpu_output_data.data(), gpu_buffers[1], batchsize * kOutputSize * sizeof(float), cudaMemcpyDeviceToHost, stream)); + // synchronize the stream to prevent issues (block CUDA and wait for CUDA operations to be completed) + cudaStreamSynchronize(stream); +} + +void serialize_engine(unsigned int max_batchsize, float& gd, float& gw, std::string& wts_name, std::string& engine_name) { + // Create builder + IBuilder* builder = createInferBuilder(gLogger); + IBuilderConfig* config = builder->createBuilderConfig(); + + // Create model to populate the network, then set the outputs and create an engine + ICudaEngine* engine = nullptr; + engine = build_efficientAD_engine(max_batchsize, builder, config, DataType::kFLOAT, gd, gw, wts_name); + assert(engine != nullptr); + + // Serialize the engine + IHostMemory* serialized_engine = engine->serialize(); + assert(serialized_engine != nullptr); + + // Save engine to file + std::ofstream p(engine_name, std::ios::binary); + if (!p) { + std::cerr << "Could not open plan output file" << std::endl; + assert(false); + } + p.write(reinterpret_cast(serialized_engine->data()), serialized_engine->size()); + + // Close everything down + engine->destroy(); + config->destroy(); + serialized_engine->destroy(); + builder->destroy(); +} + +void deserialize_engine(std::string& engine_name, IRuntime** runtime, ICudaEngine** engine, IExecutionContext** context) { + std::ifstream file(engine_name, std::ios::binary); + if (!file.good()) { + std::cerr << "read " << engine_name << " error!" << std::endl; + assert(false); + } + size_t size = 0; + file.seekg(0, file.end); + size = file.tellg(); + file.seekg(0, file.beg); + char* serialized_engine = new char[size]; + assert(serialized_engine); + file.read(serialized_engine, size); + file.close(); + + *runtime = createInferRuntime(gLogger); + assert(*runtime); + *engine = (*runtime)->deserializeCudaEngine(serialized_engine, size); + assert(*engine != nullptr); + *context = (*engine)->createExecutionContext(); + assert(*context); + + delete[] serialized_engine; +} + +int main(int argc, char** argv) { + cudaSetDevice(kGpuId); + + std::string wts_name = ""; + std::string engine_name = ""; + float gd = 1.0f, gw = 1.0f; + std::string img_dir; + + if (!parse_args(argc, argv, wts_name, engine_name, gd, gw, img_dir)) { + std::cerr << "arguments not right!" << std::endl; + std::cerr << "./efficientad_det -s [.wts] [.engine] // serialize model to plan file" << std::endl; + std::cerr << "./efficientad_det -d [.engine] [../../datas/images/...] // deserialize plan file and run inference" << std::endl; + return -1; + } + + // Create a model using the API directly and serialize it to a file + if (!wts_name.empty()) { + serialize_engine(kBatchSize, gd, gw, wts_name, engine_name); + return 0; + } + + // Deserialize the engine from file + IRuntime* runtime = nullptr; + ICudaEngine* engine = nullptr; + IExecutionContext* context = nullptr; + deserialize_engine(engine_name, &runtime, &engine, &context); + + // create CUDA stream for simultaneous CUDA operations + cudaStream_t stream; + CUDA_CHECK(cudaStreamCreate(&stream)); + + // prepare cpu and gpu buffers + void *gpu_input_buffer, *gpu_output_buffer; + CUDA_CHECK(cudaMalloc(&gpu_input_buffer, kBatchSize * 3 * kInputH * kInputW * sizeof(float))); + CUDA_CHECK(cudaMalloc(&gpu_output_buffer, kBatchSize * 1 * kOutputSize * sizeof(float))); // 3 or 1 ?? + std::vector gpu_buffers = {gpu_input_buffer, gpu_output_buffer}; + std::vector cpu_input_data(kBatchSize * kInputSize, 0); + std::vector cpu_output_data(kBatchSize * kOutputSize, 0); + + // read images from directory + std::vector file_names; + if (read_files_in_dir(img_dir.c_str(), file_names) < 0) { + std::cerr << "read_files_in_dir failed." << std::endl; + return -1; + } + + std::vector originImg_batch; + for (size_t i = 0; i < file_names.size(); i += kBatchSize) { + // get a batch of images + std::vector img_batch; + std::vector img_name_batch; + + for (size_t j = i; j < i + kBatchSize && j < file_names.size(); j++) { + cv::Mat img = cv::imread(img_dir + "/" + file_names[j]); + originImg_batch.push_back(img.clone()); + preprocessImg(img, kInputW, kInputH); + assert(img.cols * img.rows * 3 == 3 * 256 * 256); + for (int c = 0; c < 3; c++) { + for (int h = 0; h < img.rows; h++) { + for (int w = 0; w < img.cols; w++) { + cpu_input_data[c * img.rows * img.cols + + h * img.cols + + w] = img.at(h, w)[c]; + } + } + } + img_batch.push_back(img); + img_name_batch.push_back(file_names[j]); + } + + // Run inference + auto start = std::chrono::system_clock::now(); + // infer(*context, stream, (void**)gpu_buffers, cpu_input_data, cpu_output_buffer, kBatchSize); + infer(*context, stream, gpu_buffers, cpu_input_data, cpu_output_data, kBatchSize); // change to save into vec `cpu_output_data` + auto end = std::chrono::system_clock::now(); + std::cout << "inference time: " << std::chrono::duration_cast(end - start).count() << "ms" << std::endl; + + // postProcess + cv::Mat img_1(256, 256, CV_8UC1); + for (int row = 0; row < 256; row++) { + for (int col = 0; col < 256; col++) { + float value = cpu_output_data[row * 256 + col]; + if (value < 0) // clip(0,1) + value = 0; + else if (value > 1) + value = 1; + img_1.at(row, col) = static_cast(value * 255); + } + } + + cv::Mat HeatMap, colorMap; + // genHeatMap(img_batch[0], img_1, HeatMap); + cv::applyColorMap(img_1, colorMap, cv::COLORMAP_JET); + cv::resize(originImg_batch[i], originImg_batch[i], cv::Size(256, 256)); + cv::cvtColor(originImg_batch[i], originImg_batch[i], cv::COLOR_RGB2BGR); + cv::addWeighted(originImg_batch[i], 0.5, colorMap, 0.5, 0, HeatMap); + + // Save images + for (size_t j = 0; j < img_batch.size(); j++) { + cv::imwrite("_output" + img_name_batch[j], img_1); + cv::imwrite("_heatmap" + img_name_batch[j], HeatMap); + } + } + + // Release stream and buffers + cudaStreamDestroy(stream); + CUDA_CHECK(cudaFree(gpu_buffers[0])); + CUDA_CHECK(cudaFree(gpu_buffers[1])); + + // Destroy the engine + context->destroy(); + engine->destroy(); + runtime->destroy(); + + return 0; +} diff --git a/efficient_ad/src/config.h b/efficient_ad/src/config.h new file mode 100644 index 00000000..5e1204c3 --- /dev/null +++ b/efficient_ad/src/config.h @@ -0,0 +1,32 @@ +#pragma once + +/* -------------------------------------------------------- + * These configs are related to tensorrt model, if these are changed, + * please re-compile and re-serialize the tensorrt model. + * --------------------------------------------------------*/ + +// For INT8, you need prepare the calibration dataset, please refer to +#define USE_FP32 // set USE_INT8 or USE_FP16 or USE_FP32 + +// These are used to define input/output tensor names, +// you can set them to whatever you want. +const static char* kInputTensorName = "data"; +const static char* kOutputTensorName = "prob"; + +constexpr static int kBatchSize = 1; + +// input width and height must by divisible by 32 +constexpr static int kInputH = 256; +constexpr static int kInputW = 256; + +/* -------------------------------------------------------- + * These configs are NOT related to tensorrt model, if these are changed, + * please re-compile, but no need to re-serialize the tensorrt model. + * --------------------------------------------------------*/ + +// default GPU_id +const static int kGpuId = 0; + +// If your image size is larger than 4096 * 3112, please increase this value +const static int kMaxInputImageSize = 4096 * 3112; + diff --git a/efficient_ad/src/cuda_utils.h b/efficient_ad/src/cuda_utils.h new file mode 100644 index 00000000..8fbd3199 --- /dev/null +++ b/efficient_ad/src/cuda_utils.h @@ -0,0 +1,18 @@ +#ifndef TRTX_CUDA_UTILS_H_ +#define TRTX_CUDA_UTILS_H_ + +#include + +#ifndef CUDA_CHECK +#define CUDA_CHECK(callstr)\ + {\ + cudaError_t error_code = callstr;\ + if (error_code != cudaSuccess) {\ + std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__;\ + assert(0);\ + }\ + } +#endif // CUDA_CHECK + +#endif // TRTX_CUDA_UTILS_H_ + diff --git a/efficient_ad/src/logging.h b/efficient_ad/src/logging.h new file mode 100644 index 00000000..6b79a8b9 --- /dev/null +++ b/efficient_ad/src/logging.h @@ -0,0 +1,504 @@ +/* + * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TENSORRT_LOGGING_H +#define TENSORRT_LOGGING_H + +#include "NvInferRuntimeCommon.h" +#include +#include +#include +#include +#include +#include +#include +#include "macros.h" + +using Severity = nvinfer1::ILogger::Severity; + +class LogStreamConsumerBuffer : public std::stringbuf +{ +public: + LogStreamConsumerBuffer(std::ostream& stream, const std::string& prefix, bool shouldLog) + : mOutput(stream) + , mPrefix(prefix) + , mShouldLog(shouldLog) + { + } + + LogStreamConsumerBuffer(LogStreamConsumerBuffer&& other) + : mOutput(other.mOutput) + { + } + + ~LogStreamConsumerBuffer() + { + // std::streambuf::pbase() gives a pointer to the beginning of the buffered part of the output sequence + // std::streambuf::pptr() gives a pointer to the current position of the output sequence + // if the pointer to the beginning is not equal to the pointer to the current position, + // call putOutput() to log the output to the stream + if (pbase() != pptr()) + { + putOutput(); + } + } + + // synchronizes the stream buffer and returns 0 on success + // synchronizing the stream buffer consists of inserting the buffer contents into the stream, + // resetting the buffer and flushing the stream + virtual int sync() + { + putOutput(); + return 0; + } + + void putOutput() + { + if (mShouldLog) + { + // prepend timestamp + std::time_t timestamp = std::time(nullptr); + tm* tm_local = std::localtime(×tamp); + std::cout << "["; + std::cout << std::setw(2) << std::setfill('0') << 1 + tm_local->tm_mon << "/"; + std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_mday << "/"; + std::cout << std::setw(4) << std::setfill('0') << 1900 + tm_local->tm_year << "-"; + std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_hour << ":"; + std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_min << ":"; + std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_sec << "] "; + // std::stringbuf::str() gets the string contents of the buffer + // insert the buffer contents pre-appended by the appropriate prefix into the stream + mOutput << mPrefix << str(); + // set the buffer to empty + str(""); + // flush the stream + mOutput.flush(); + } + } + + void setShouldLog(bool shouldLog) + { + mShouldLog = shouldLog; + } + +private: + std::ostream& mOutput; + std::string mPrefix; + bool mShouldLog; +}; + +//! +//! \class LogStreamConsumerBase +//! \brief Convenience object used to initialize LogStreamConsumerBuffer before std::ostream in LogStreamConsumer +//! +class LogStreamConsumerBase +{ +public: + LogStreamConsumerBase(std::ostream& stream, const std::string& prefix, bool shouldLog) + : mBuffer(stream, prefix, shouldLog) + { + } + +protected: + LogStreamConsumerBuffer mBuffer; +}; + +//! +//! \class LogStreamConsumer +//! \brief Convenience object used to facilitate use of C++ stream syntax when logging messages. +//! Order of base classes is LogStreamConsumerBase and then std::ostream. +//! This is because the LogStreamConsumerBase class is used to initialize the LogStreamConsumerBuffer member field +//! in LogStreamConsumer and then the address of the buffer is passed to std::ostream. +//! This is necessary to prevent the address of an uninitialized buffer from being passed to std::ostream. +//! Please do not change the order of the parent classes. +//! +class LogStreamConsumer : protected LogStreamConsumerBase, public std::ostream +{ +public: + //! \brief Creates a LogStreamConsumer which logs messages with level severity. + //! Reportable severity determines if the messages are severe enough to be logged. + LogStreamConsumer(Severity reportableSeverity, Severity severity) + : LogStreamConsumerBase(severityOstream(severity), severityPrefix(severity), severity <= reportableSeverity) + , std::ostream(&mBuffer) // links the stream buffer with the stream + , mShouldLog(severity <= reportableSeverity) + , mSeverity(severity) + { + } + + LogStreamConsumer(LogStreamConsumer&& other) + : LogStreamConsumerBase(severityOstream(other.mSeverity), severityPrefix(other.mSeverity), other.mShouldLog) + , std::ostream(&mBuffer) // links the stream buffer with the stream + , mShouldLog(other.mShouldLog) + , mSeverity(other.mSeverity) + { + } + + void setReportableSeverity(Severity reportableSeverity) + { + mShouldLog = mSeverity <= reportableSeverity; + mBuffer.setShouldLog(mShouldLog); + } + +private: + static std::ostream& severityOstream(Severity severity) + { + return severity >= Severity::kINFO ? std::cout : std::cerr; + } + + static std::string severityPrefix(Severity severity) + { + switch (severity) + { + case Severity::kINTERNAL_ERROR: return "[F] "; + case Severity::kERROR: return "[E] "; + case Severity::kWARNING: return "[W] "; + case Severity::kINFO: return "[I] "; + case Severity::kVERBOSE: return "[V] "; + default: assert(0); return ""; + } + } + + bool mShouldLog; + Severity mSeverity; +}; + +//! \class Logger +//! +//! \brief Class which manages logging of TensorRT tools and samples +//! +//! \details This class provides a common interface for TensorRT tools and samples to log information to the console, +//! and supports logging two types of messages: +//! +//! - Debugging messages with an associated severity (info, warning, error, or internal error/fatal) +//! - Test pass/fail messages +//! +//! The advantage of having all samples use this class for logging as opposed to emitting directly to stdout/stderr is +//! that the logic for controlling the verbosity and formatting of sample output is centralized in one location. +//! +//! In the future, this class could be extended to support dumping test results to a file in some standard format +//! (for example, JUnit XML), and providing additional metadata (e.g. timing the duration of a test run). +//! +//! TODO: For backwards compatibility with existing samples, this class inherits directly from the nvinfer1::ILogger +//! interface, which is problematic since there isn't a clean separation between messages coming from the TensorRT +//! library and messages coming from the sample. +//! +//! In the future (once all samples are updated to use Logger::getTRTLogger() to access the ILogger) we can refactor the +//! class to eliminate the inheritance and instead make the nvinfer1::ILogger implementation a member of the Logger +//! object. + +class Logger : public nvinfer1::ILogger +{ +public: + Logger(Severity severity = Severity::kWARNING) + : mReportableSeverity(severity) + { + } + + //! + //! \enum TestResult + //! \brief Represents the state of a given test + //! + enum class TestResult + { + kRUNNING, //!< The test is running + kPASSED, //!< The test passed + kFAILED, //!< The test failed + kWAIVED //!< The test was waived + }; + + //! + //! \brief Forward-compatible method for retrieving the nvinfer::ILogger associated with this Logger + //! \return The nvinfer1::ILogger associated with this Logger + //! + //! TODO Once all samples are updated to use this method to register the logger with TensorRT, + //! we can eliminate the inheritance of Logger from ILogger + //! + nvinfer1::ILogger& getTRTLogger() + { + return *this; + } + + //! + //! \brief Implementation of the nvinfer1::ILogger::log() virtual method + //! + //! Note samples should not be calling this function directly; it will eventually go away once we eliminate the + //! inheritance from nvinfer1::ILogger + //! + void log(Severity severity, const char* msg) TRT_NOEXCEPT override + { + LogStreamConsumer(mReportableSeverity, severity) << "[TRT] " << std::string(msg) << std::endl; + } + + //! + //! \brief Method for controlling the verbosity of logging output + //! + //! \param severity The logger will only emit messages that have severity of this level or higher. + //! + void setReportableSeverity(Severity severity) + { + mReportableSeverity = severity; + } + + //! + //! \brief Opaque handle that holds logging information for a particular test + //! + //! This object is an opaque handle to information used by the Logger to print test results. + //! The sample must call Logger::defineTest() in order to obtain a TestAtom that can be used + //! with Logger::reportTest{Start,End}(). + //! + class TestAtom + { + public: + TestAtom(TestAtom&&) = default; + + private: + friend class Logger; + + TestAtom(bool started, const std::string& name, const std::string& cmdline) + : mStarted(started) + , mName(name) + , mCmdline(cmdline) + { + } + + bool mStarted; + std::string mName; + std::string mCmdline; + }; + + //! + //! \brief Define a test for logging + //! + //! \param[in] name The name of the test. This should be a string starting with + //! "TensorRT" and containing dot-separated strings containing + //! the characters [A-Za-z0-9_]. + //! For example, "TensorRT.sample_googlenet" + //! \param[in] cmdline The command line used to reproduce the test + // + //! \return a TestAtom that can be used in Logger::reportTest{Start,End}(). + //! + static TestAtom defineTest(const std::string& name, const std::string& cmdline) + { + return TestAtom(false, name, cmdline); + } + + //! + //! \brief A convenience overloaded version of defineTest() that accepts an array of command-line arguments + //! as input + //! + //! \param[in] name The name of the test + //! \param[in] argc The number of command-line arguments + //! \param[in] argv The array of command-line arguments (given as C strings) + //! + //! \return a TestAtom that can be used in Logger::reportTest{Start,End}(). + static TestAtom defineTest(const std::string& name, int argc, char const* const* argv) + { + auto cmdline = genCmdlineString(argc, argv); + return defineTest(name, cmdline); + } + + //! + //! \brief Report that a test has started. + //! + //! \pre reportTestStart() has not been called yet for the given testAtom + //! + //! \param[in] testAtom The handle to the test that has started + //! + static void reportTestStart(TestAtom& testAtom) + { + reportTestResult(testAtom, TestResult::kRUNNING); + assert(!testAtom.mStarted); + testAtom.mStarted = true; + } + + //! + //! \brief Report that a test has ended. + //! + //! \pre reportTestStart() has been called for the given testAtom + //! + //! \param[in] testAtom The handle to the test that has ended + //! \param[in] result The result of the test. Should be one of TestResult::kPASSED, + //! TestResult::kFAILED, TestResult::kWAIVED + //! + static void reportTestEnd(const TestAtom& testAtom, TestResult result) + { + assert(result != TestResult::kRUNNING); + assert(testAtom.mStarted); + reportTestResult(testAtom, result); + } + + static int reportPass(const TestAtom& testAtom) + { + reportTestEnd(testAtom, TestResult::kPASSED); + return EXIT_SUCCESS; + } + + static int reportFail(const TestAtom& testAtom) + { + reportTestEnd(testAtom, TestResult::kFAILED); + return EXIT_FAILURE; + } + + static int reportWaive(const TestAtom& testAtom) + { + reportTestEnd(testAtom, TestResult::kWAIVED); + return EXIT_SUCCESS; + } + + static int reportTest(const TestAtom& testAtom, bool pass) + { + return pass ? reportPass(testAtom) : reportFail(testAtom); + } + + Severity getReportableSeverity() const + { + return mReportableSeverity; + } + +private: + //! + //! \brief returns an appropriate string for prefixing a log message with the given severity + //! + static const char* severityPrefix(Severity severity) + { + switch (severity) + { + case Severity::kINTERNAL_ERROR: return "[F] "; + case Severity::kERROR: return "[E] "; + case Severity::kWARNING: return "[W] "; + case Severity::kINFO: return "[I] "; + case Severity::kVERBOSE: return "[V] "; + default: assert(0); return ""; + } + } + + //! + //! \brief returns an appropriate string for prefixing a test result message with the given result + //! + static const char* testResultString(TestResult result) + { + switch (result) + { + case TestResult::kRUNNING: return "RUNNING"; + case TestResult::kPASSED: return "PASSED"; + case TestResult::kFAILED: return "FAILED"; + case TestResult::kWAIVED: return "WAIVED"; + default: assert(0); return ""; + } + } + + //! + //! \brief returns an appropriate output stream (cout or cerr) to use with the given severity + //! + static std::ostream& severityOstream(Severity severity) + { + return severity >= Severity::kINFO ? std::cout : std::cerr; + } + + //! + //! \brief method that implements logging test results + //! + static void reportTestResult(const TestAtom& testAtom, TestResult result) + { + severityOstream(Severity::kINFO) << "&&&& " << testResultString(result) << " " << testAtom.mName << " # " + << testAtom.mCmdline << std::endl; + } + + //! + //! \brief generate a command line string from the given (argc, argv) values + //! + static std::string genCmdlineString(int argc, char const* const* argv) + { + std::stringstream ss; + for (int i = 0; i < argc; i++) + { + if (i > 0) + ss << " "; + ss << argv[i]; + } + return ss.str(); + } + + Severity mReportableSeverity; +}; + +namespace +{ + +//! +//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kVERBOSE +//! +//! Example usage: +//! +//! LOG_VERBOSE(logger) << "hello world" << std::endl; +//! +inline LogStreamConsumer LOG_VERBOSE(const Logger& logger) +{ + return LogStreamConsumer(logger.getReportableSeverity(), Severity::kVERBOSE); +} + +//! +//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINFO +//! +//! Example usage: +//! +//! LOG_INFO(logger) << "hello world" << std::endl; +//! +inline LogStreamConsumer LOG_INFO(const Logger& logger) +{ + return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINFO); +} + +//! +//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kWARNING +//! +//! Example usage: +//! +//! LOG_WARN(logger) << "hello world" << std::endl; +//! +inline LogStreamConsumer LOG_WARN(const Logger& logger) +{ + return LogStreamConsumer(logger.getReportableSeverity(), Severity::kWARNING); +} + +//! +//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kERROR +//! +//! Example usage: +//! +//! LOG_ERROR(logger) << "hello world" << std::endl; +//! +inline LogStreamConsumer LOG_ERROR(const Logger& logger) +{ + return LogStreamConsumer(logger.getReportableSeverity(), Severity::kERROR); +} + +//! +//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINTERNAL_ERROR +// ("fatal" severity) +//! +//! Example usage: +//! +//! LOG_FATAL(logger) << "hello world" << std::endl; +//! +inline LogStreamConsumer LOG_FATAL(const Logger& logger) +{ + return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINTERNAL_ERROR); +} + +} // anonymous namespace + +#endif // TENSORRT_LOGGING_H diff --git a/efficient_ad/src/macros.h b/efficient_ad/src/macros.h new file mode 100644 index 00000000..17339a24 --- /dev/null +++ b/efficient_ad/src/macros.h @@ -0,0 +1,29 @@ +#ifndef __MACROS_H +#define __MACROS_H + +#include + +#ifdef API_EXPORTS +#if defined(_MSC_VER) +#define API __declspec(dllexport) +#else +#define API __attribute__((visibility("default"))) +#endif +#else + +#if defined(_MSC_VER) +#define API __declspec(dllimport) +#else +#define API +#endif +#endif // API_EXPORTS + +#if NV_TENSORRT_MAJOR >= 8 +#define TRT_NOEXCEPT noexcept +#define TRT_CONST_ENQUEUE const +#else +#define TRT_NOEXCEPT +#define TRT_CONST_ENQUEUE +#endif + +#endif // __MACROS_H diff --git a/efficient_ad/src/model.cpp b/efficient_ad/src/model.cpp new file mode 100644 index 00000000..3ba3f6db --- /dev/null +++ b/efficient_ad/src/model.cpp @@ -0,0 +1,428 @@ +#include "model.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "config.h" + +using namespace nvinfer1; + +// TensorRT weight files have a simple space delimited format: +// [type] [size] +static std::map loadWeights(const std::string file) { + std::cout << "Loading weights: " << file << std::endl; + std::map weightMap; + + // Open weights file + std::ifstream input(file); + assert(input.is_open() && "Unable to load weight file. please check if the .wts file path is right!!!!!!"); + + // Read number of weight blobs + int32_t count; + input >> count; + assert(count > 0 && "Invalid weight map file."); + + while (count--) { + Weights wt{DataType::kFLOAT, nullptr, 0}; + uint32_t size; + + // Read name and type of blob + std::string name; + input >> name >> std::dec >> size; + wt.type = DataType::kFLOAT; + + // Load blob + uint32_t* val = reinterpret_cast(malloc(sizeof(val) * size)); + for (uint32_t x = 0, y = size; x < y; ++x) { + input >> std::hex >> val[x]; + } + wt.values = val; + + wt.count = size; + weightMap[name] = wt; + } + + return weightMap; +} + +void printNetworkLayers(INetworkDefinition* network) { + int numLayers = network->getNbLayers(); + // std::cout << "currently num of layers: " << numLayers << std::endl; + + auto dataTypeToString = [](DataType type) { + switch (type) { + case DataType::kFLOAT: + return "kFLOAT"; + case DataType::kHALF: + return "kHALF"; + case DataType::kINT8: + return "kINT8"; + case DataType::kINT32: + return "kINT32"; + case DataType::kBOOL: + return "kBOOL"; + default: + return "Unknown"; + } + }; + + for (int i = 0; i < numLayers; ++i) { + ILayer* layer = network->getLayer(i); + std::cout << "--- Layer" << i << " = " << layer->getName() << std::endl; + std::cout << "input & output tensor type: " + << dataTypeToString(layer->getInput(0)->getType()) << "\t" + << dataTypeToString(layer->getOutput(0)->getType()) << std::endl; + + // input + int inTensorNum = layer->getNbInputs(); + for (int j = 0; j < inTensorNum; ++j) { + // std::cout << layer->getInput(j)->getDimensions().nbDims; + Dims dims_in = layer->getInput(j)->getDimensions(); + std::cout << "input shape[" << j << "]: ("; + for (int k = 0; k < dims_in.nbDims; ++k) { + std::cout << dims_in.d[k]; + if (k < dims_in.nbDims - 1) { + std::cout << ", "; + } + } + std::cout << ")\t"; + } + std::cout << std::endl; + + // output + int outTensorNum = layer->getNbOutputs(); + for (int j = 0; j < outTensorNum; ++j) { + // std::cout << layer->getOutput(j)->getName(); + Dims dims_out = layer->getOutput(j)->getDimensions(); + std::cout << "output shape: ("; + for (int k = 0; k < dims_out.nbDims; ++k) { + std::cout << dims_out.d[k]; + if (k < dims_out.nbDims - 1) { + std::cout << ", "; + } + } + std::cout << ")"; + } + std::cout << "\n" + << std::endl; + } +} + +static IScaleLayer* NormalizeInput(INetworkDefinition* network, ITensor& input) { + float meanValues[3] = {-0.485f, -0.456f, -0.406f}; + float stdValues[3] = {1.0f / 0.229f, 1.0f / 0.224f, 1.0f / 0.225f}; + Weights meanWeights{DataType::kFLOAT, meanValues, 3}; + Weights stdWeights{DataType::kFLOAT, stdValues, 3}; + + IScaleLayer* NormaLayer = network->addScale(input, ScaleMode::kCHANNEL, meanWeights, stdWeights, Weights{}); + assert(NormaLayer != nullptr); + + return NormaLayer; +} + +static IScaleLayer* NormalizeTeacherMap(INetworkDefinition* network, std::map& weightMap, + ITensor& input) { + float* mean = (float*)weightMap["mean_std.mean"].values; + float* std = (float*)weightMap["mean_std.std"].values; + int len = weightMap["mean_std.mean"].count; + + // 1.scale + float* scaleVal = reinterpret_cast(malloc(sizeof(float) * len)); + for (int i = 0; i < len; i++) { + scaleVal[i] = 1.0 / std[i]; + } + Weights scale{DataType::kFLOAT, scaleVal, len}; + + // 2.shift + float* shiftVal = nullptr; + shiftVal = reinterpret_cast(malloc(sizeof(float) * len)); + for (int i = 0; i < len; i++) { + shiftVal[i] = -mean[i]; + } + Weights shift{DataType::kFLOAT, shiftVal, len}; + + IScaleLayer* scale_1 = network->addScale(input, ScaleMode::kCHANNEL, shift, Weights{}, Weights{}); + assert(scale_1); + IScaleLayer* scale_2 = network->addScale(*scale_1->getOutput(0), ScaleMode::kCHANNEL, Weights{}, scale, Weights{}); + assert(scale_2); + + return scale_2; +} + +static ILayer* NormalizeFinalMap(INetworkDefinition* network, std::map& weightMap, ITensor& input, + std::string name) { + float* qa = (float*)weightMap["quantiles.qa_" + name].values; + float* qb = (float*)weightMap["quantiles.qb_" + name].values; + int len = weightMap["quantiles.qa_" + name].count; + + Weights qbWeight_2{DataType::kFLOAT, qb, len}; + + // fmap_st - qa_st + float* shiftVal_1 = nullptr; + shiftVal_1 = reinterpret_cast(malloc(sizeof(float) * len)); + for (int i = 0; i < len; i++) { + shiftVal_1[i] = -qa[i]; + } + Weights qa_shiftWeight_1{DataType::kFLOAT, shiftVal_1, len}; + IScaleLayer* mapNorm_subLayer_1 = network->addScale(input, ScaleMode::kUNIFORM, qa_shiftWeight_1, Weights{}, Weights{}); + assert(mapNorm_subLayer_1); + + // qb_st - qa_st + float* shiftVal_2 = nullptr; + shiftVal_2 = reinterpret_cast(malloc(sizeof(float) * len)); + for (int i = 0; i < len; i++) { + shiftVal_2[i] = qb[i] - qa[i]; + } + + // (fmap_st - qa_st) / (qb_st - qa_st) + float* scaleVal_1 = nullptr; + scaleVal_1 = reinterpret_cast(malloc(sizeof(float) * len)); + for (int i = 0; i < len; i++) { + scaleVal_1[i] = 1.0f / shiftVal_2[i]; + } + Weights scaleWeight_1{DataType::kFLOAT, scaleVal_1, len}; + IScaleLayer* mapNorm_divLayer_1 = network->addScale(*mapNorm_subLayer_1->getOutput(0), ScaleMode::kUNIFORM, Weights{}, scaleWeight_1, Weights{}); + assert(mapNorm_divLayer_1); + + // ((fmap_st - qa_st) / (qb_st - qa_st)) * 0.1 + float* scaleVal_2 = nullptr; + scaleVal_2 = reinterpret_cast(malloc(sizeof(float) * len)); + for (int i = 0; i < len; i++) { + scaleVal_2[i] = 0.1f; + } + Weights scaleWeight_2{DataType::kFLOAT, scaleVal_2, 1}; + IScaleLayer* mapNorm_Layer = network->addScale(*mapNorm_divLayer_1->getOutput(0), ScaleMode::kUNIFORM, Weights{}, scaleWeight_2, Weights{}); + assert(mapNorm_Layer); + + return mapNorm_Layer; +} + +static ILayer* convRelu(INetworkDefinition* network, std::map& weightMap, ITensor& input, + int outch, int ksize, int s, int p, int g, std::string lname, bool withRelu) { + Weights emptywts{DataType::kFLOAT, nullptr, 0}; + IConvolutionLayer* conv1 = network->addConvolutionNd(input, outch, DimsHW{ksize, ksize}, + weightMap[lname + ".weight"], + weightMap[lname + ".bias"]); // if without bias weights, the results won't match with torch version + assert(conv1); + conv1->setStrideNd(DimsHW{s, s}); + conv1->setPaddingNd(DimsHW{p, p}); + conv1->setNbGroups(g); + conv1->setName((lname).c_str()); + + if (!withRelu) + return conv1; + + auto relu = network->addActivation(*conv1->getOutput(0), ActivationType::kRELU); + assert(relu); + + return relu; +} + +static IResizeLayer* interpolate(INetworkDefinition* network, ITensor& input, Dims upsampleScale, ResizeMode resizeMode) { + IResizeLayer* interpolateLayer = network->addResize(input); + assert(interpolateLayer); + interpolateLayer->setOutputDimensions(upsampleScale); + interpolateLayer->setResizeMode(resizeMode); + + return interpolateLayer; +} + +static ILayer* interpConvRelu(INetworkDefinition* network, std::map& weightMap, ITensor& input, + int outch, int ksize, int s, int p, int g, std::string lname, int dim) { + IResizeLayer* interpolateLayer = network->addResize(input); + assert(interpolateLayer != nullptr); + interpolateLayer->setOutputDimensions(Dims3{input.getDimensions().d[0], dim, dim}); + interpolateLayer->setResizeMode(ResizeMode::kLINEAR); + + IConvolutionLayer* conv1 = network->addConvolutionNd(*interpolateLayer->getOutput(0), outch, DimsHW{ksize, ksize}, + weightMap[lname + ".weight"], + weightMap[lname + ".bias"]); + assert(conv1); + conv1->setStrideNd(DimsHW{s, s}); + conv1->setPaddingNd(DimsHW{p, p}); + conv1->setNbGroups(g); + conv1->setName((lname + ".conv").c_str()); + + auto relu = network->addActivation(*conv1->getOutput(0), ActivationType::kRELU); + assert(relu); + + return relu; +} + +static IPoolingLayer* avgPool2d(INetworkDefinition* network, ITensor& input, int kernelSize, int stride, int padding) { + IPoolingLayer* poolLayer = network->addPooling(input, PoolingType::kAVERAGE, DimsHW{kernelSize, kernelSize}); + assert(poolLayer); + poolLayer->setStride(DimsHW{stride, stride}); + poolLayer->setPadding(DimsHW{padding, padding}); + + return poolLayer; +} + +static void slice(INetworkDefinition* network, ITensor& input, std::vector& layer_vec) { + Dims inputDims = input.getDimensions(); + ISliceLayer* slice1 = network->addSlice(input, + Dims3{0, 0, 0}, + Dims3{inputDims.d[0] / 2, inputDims.d[1], inputDims.d[2]}, + Dims3{1, 1, 1}); + assert(slice1); + + ISliceLayer* slice2 = network->addSlice(input, + Dims3{inputDims.d[0] / 2, 0, 0}, + Dims3{inputDims.d[0] / 2, inputDims.d[1], inputDims.d[2]}, + Dims3{1, 1, 1}); + assert(slice2); + + layer_vec.push_back(slice1->getOutput(0)); + layer_vec.push_back(slice2->getOutput(0)); +} + +static IElementWiseLayer* mergeMap(INetworkDefinition* network, ITensor& input1, ITensor& input2) { + float* scaleVal = nullptr; + scaleVal = reinterpret_cast(malloc(sizeof(float) * 1)); + for (int i = 0; i < 1; i++) { + scaleVal[i] = 0.5f; + } + Weights scaleWeight{DataType::kFLOAT, scaleVal, 1}; + IScaleLayer* mergeMapLayer1 = network->addScale(input1, ScaleMode::kUNIFORM, Weights{}, scaleWeight, Weights{}); + assert(mergeMapLayer1); + + IScaleLayer* mergeMapLayer2 = network->addScale(input2, ScaleMode::kUNIFORM, Weights{}, scaleWeight, Weights{}); + assert(mergeMapLayer2); + + IElementWiseLayer* mergedMapLayer = network->addElementWise(*mergeMapLayer1->getOutput(0), *mergeMapLayer2->getOutput(0), ElementWiseOperation::kSUM); + assert(mergedMapLayer); + + return mergedMapLayer; +} + +ICudaEngine* build_efficientAD_engine(unsigned int maxBatchSize, IBuilder* builder, IBuilderConfig* config, + DataType dt, float& gd, float& gw, std::string& wts_name) { + /* create network object */ + INetworkDefinition* network = builder->createNetworkV2(0U); + + /* create input tensor {3, kInputH, kInputW} */ + ITensor* InputData = network->addInput(kInputTensorName, dt, Dims3{3, kInputH, kInputW}); + assert(InputData); + + /* create weight map */ + std::map weightMap = loadWeights(wts_name); + + /* AE */ + // auto BN1 = NormalizeInput(network, *InputData); + // encoder + auto enconv1 = convRelu(network, weightMap, *InputData, 32, 4, 2, 1, 1, "ae.encoder.enconv1", true); + auto enconv2 = convRelu(network, weightMap, *enconv1->getOutput(0), 32, 4, 2, 1, 1, "ae.encoder.enconv2", true); + auto enconv3 = convRelu(network, weightMap, *enconv2->getOutput(0), 64, 4, 2, 1, 1, "ae.encoder.enconv3", true); + auto enconv4 = convRelu(network, weightMap, *enconv3->getOutput(0), 64, 4, 2, 1, 1, "ae.encoder.enconv4", true); + auto enconv5 = convRelu(network, weightMap, *enconv4->getOutput(0), 64, 4, 2, 1, 1, "ae.encoder.enconv5", true); + auto enconv6 = convRelu(network, weightMap, *enconv5->getOutput(0), 64, 8, 1, 0, 1, "ae.encoder.enconv6", false); + // decoder + auto deconv1 = interpConvRelu(network, weightMap, *enconv6->getOutput(0), 64, 4, 1, 2, 1, "ae.decoder.deconv1", 3); + auto deconv2 = interpConvRelu(network, weightMap, *deconv1->getOutput(0), 64, 4, 1, 2, 1, "ae.decoder.deconv2", 8); + auto deconv3 = interpConvRelu(network, weightMap, *deconv2->getOutput(0), 64, 4, 1, 2, 1, "ae.decoder.deconv3", 15); + auto deconv4 = interpConvRelu(network, weightMap, *deconv3->getOutput(0), 64, 4, 1, 2, 1, "ae.decoder.deconv4", 32); + auto deconv5 = interpConvRelu(network, weightMap, *deconv4->getOutput(0), 64, 4, 1, 2, 1, "ae.decoder.deconv5", 63); + auto deconv6 = interpConvRelu(network, weightMap, *deconv5->getOutput(0), 64, 4, 1, 2, 1, "ae.decoder.deconv6", 127); + auto deconv7 = interpConvRelu(network, weightMap, *deconv6->getOutput(0), 64, 3, 1, 1, 1, "ae.decoder.deconv7", 56); + auto deconv8 = convRelu(network, weightMap, *deconv7->getOutput(0), 384, 3, 1, 1, 1, "ae.decoder.deconv8", false); + + /* PDN_medium_teacher */ + // no BN added after the convolutional layer + auto teacher1 = convRelu(network, weightMap, *InputData, 256, 4, 1, 0, 1, "teacher.conv1", true); + auto avgPool1 = avgPool2d(network, *teacher1->getOutput(0), 2, 2, 0); + auto teacher2 = convRelu(network, weightMap, *avgPool1->getOutput(0), 512, 4, 1, 0, 1, "teacher.conv2", true); + auto avgPool2 = avgPool2d(network, *teacher2->getOutput(0), 2, 2, 0); + auto teacher3 = convRelu(network, weightMap, *avgPool2->getOutput(0), 512, 1, 1, 0, 1, "teacher.conv3", true); + auto teacher4 = convRelu(network, weightMap, *teacher3->getOutput(0), 512, 3, 1, 0, 1, "teacher.conv4", true); + auto teacher5 = convRelu(network, weightMap, *teacher4->getOutput(0), 384, 4, 1, 0, 1, "teacher.conv5", true); + auto teacher6 = convRelu(network, weightMap, *teacher5->getOutput(0), 384, 1, 1, 0, 1, "teacher.conv6", false); + + /* PDN_medium_student */ + auto student1 = convRelu(network, weightMap, *InputData, 256, 4, 1, 0, 1, "student.conv1", true); + auto avgPool3 = avgPool2d(network, *student1->getOutput(0), 2, 2, 0); + auto student2 = convRelu(network, weightMap, *avgPool3->getOutput(0), 512, 4, 1, 0, 1, "student.conv2", true); + auto avgPool4 = avgPool2d(network, *student2->getOutput(0), 2, 2, 0); + auto student3 = convRelu(network, weightMap, *avgPool4->getOutput(0), 512, 1, 1, 0, 1, "student.conv3", true); + auto student4 = convRelu(network, weightMap, *student3->getOutput(0), 512, 3, 1, 0, 1, "student.conv4", true); + auto student5 = convRelu(network, weightMap, *student4->getOutput(0), 768, 4, 1, 0, 1, "student.conv5", true); + auto student6 = convRelu(network, weightMap, *student5->getOutput(0), 768, 1, 1, 0, 1, "student.conv6", false); + + /* postCalculate */ + auto normal_teacher_output = NormalizeTeacherMap(network, weightMap, *teacher6->getOutput(0)); + std::vector layer_vec{}; + slice(network, *student6->getOutput(0), layer_vec); + ITensor* y_st = layer_vec[0]; + ITensor* y_stae = layer_vec[1]; + + // distance_st + IElementWiseLayer* sub_st = network->addElementWise(*normal_teacher_output->getOutput(0), *y_st, ElementWiseOperation::kSUB); + assert(sub_st); + IElementWiseLayer* distance_st = network->addElementWise(*sub_st->getOutput(0), *sub_st->getOutput(0), ElementWiseOperation::kPROD); + assert(distance_st); + + // distance_stae + IElementWiseLayer* sub_stae = network->addElementWise(*deconv8->getOutput(0), *y_stae, ElementWiseOperation::kSUB); + assert(sub_stae); + IElementWiseLayer* distance_stae = network->addElementWise(*sub_stae->getOutput(0), *sub_stae->getOutput(0), ElementWiseOperation::kPROD); + assert(distance_stae); + + IReduceLayer* map_st = network->addReduce(*distance_st->getOutput(0), ReduceOperation::kAVG, 1, true); + assert(map_st); + IReduceLayer* map_stae = network->addReduce(*distance_stae->getOutput(0), ReduceOperation::kAVG, 1, true); + assert(map_stae); + + IPaddingLayer* padMap_st = network->addPadding(*map_st->getOutput(0), DimsHW{4, 4}, DimsHW{4, 4}); + assert(padMap_st); + IPaddingLayer* padMap_stae = network->addPadding(*map_stae->getOutput(0), DimsHW{4, 4}, DimsHW{4, 4}); + assert(padMap_stae); + + IResizeLayer* interpMap_st = interpolate(network, *padMap_st->getOutput(0), Dims3{padMap_st->getOutput(0)->getDimensions().d[0], 256, 256}, ResizeMode::kLINEAR); + assert(interpMap_st); + IResizeLayer* interpMap_stae = interpolate(network, *padMap_stae->getOutput(0), Dims3{padMap_stae->getOutput(0)->getDimensions().d[0], 256, 256}, ResizeMode::kLINEAR); + assert(interpMap_stae); + + ILayer* normalizedMap_st = NormalizeFinalMap(network, weightMap, *interpMap_st->getOutput(0), "st"); + assert(normalizedMap_st); + ILayer* normalizedMap_stae = NormalizeFinalMap(network, weightMap, *interpMap_stae->getOutput(0), "ae"); + assert(normalizedMap_stae); + + IElementWiseLayer* mergedMapLayer = mergeMap(network, *normalizedMap_st->getOutput(0), *normalizedMap_st->getOutput(0)); + printNetworkLayers(network); + + /* ouput */ + mergedMapLayer->getOutput(0)->setName(kOutputTensorName); + network->markOutput(*mergedMapLayer->getOutput(0)); + + /* Engine config */ + builder->setMaxBatchSize(maxBatchSize); + config->setMaxWorkspaceSize(16 * (1 << 20)); // 16MB +#if defined(USE_FP16) + config->setFlag(BuilderFlag::kFP16); +#elif defined(USE_INT8) + std::cout << "Your platform support int8: " << (builder->platformHasFastInt8() ? "true" : "false") << std::endl; + assert(builder->platformHasFastInt8()); + config->setFlag(BuilderFlag::kINT8); + Int8EntropyCalibrator2* calibrator = new Int8EntropyCalibrator2(1, kInputW, kInputH, "./coco_calib/", "int8calib.table", kInputTensorName); + config->setInt8Calibrator(calibrator); +#endif + std::cout << "Building engine, please wait for a while..." << std::endl; + ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config); + std::cout << "Build engine successfully!" << std::endl; + + // Don't need the network any more + network->destroy(); + + // Release host memory + for (auto& mem : weightMap) { + free((void*)(mem.second.values)); + } + + return engine; +} \ No newline at end of file diff --git a/efficient_ad/src/model.h b/efficient_ad/src/model.h new file mode 100644 index 00000000..7230795f --- /dev/null +++ b/efficient_ad/src/model.h @@ -0,0 +1,9 @@ +#pragma once + +#include + +#include + +nvinfer1::ICudaEngine* build_efficientAD_engine(unsigned int maxBatchSize, nvinfer1::IBuilder* builder, + nvinfer1::IBuilderConfig* config, nvinfer1::DataType dt, + float& gd, float& gw, std::string& wts_name); \ No newline at end of file diff --git a/efficient_ad/src/postprocess.h b/efficient_ad/src/postprocess.h new file mode 100644 index 00000000..64bfd2db --- /dev/null +++ b/efficient_ad/src/postprocess.h @@ -0,0 +1,9 @@ +#pragma once + +#include + +void genHeatMap(cv::Mat originImg, cv::Mat& anomalyGrayMap, cv::Mat& HeatMap) { + cv::Mat colorMap; + cv::applyColorMap(colorMap, anomalyGrayMap, cv::COLORMAP_JET); + cv::addWeighted(originImg, 0.5, colorMap, 0.5, 0, HeatMap); +} \ No newline at end of file diff --git a/efficient_ad/src/utils.h b/efficient_ad/src/utils.h new file mode 100644 index 00000000..4b478689 --- /dev/null +++ b/efficient_ad/src/utils.h @@ -0,0 +1,33 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +static inline int read_files_in_dir(const char* p_dir_name, std::vector& file_names) { + DIR *p_dir = opendir(p_dir_name); + if (p_dir == nullptr) { + return -1; + } + + struct dirent* p_file = nullptr; + while ((p_file = readdir(p_dir)) != nullptr) { + if (strcmp(p_file->d_name, ".") != 0 && + strcmp(p_file->d_name, "..") != 0) { + //std::string cur_file_name(p_dir_name); + //cur_file_name += "/"; + //cur_file_name += p_file->d_name; + std::string cur_file_name(p_file->d_name); + file_names.push_back(cur_file_name); + } + } + + closedir(p_dir); + return 0; +} + +