这篇文章主要介绍了如何使用TensorRT实现自定义算子。
Note:
- 我使用的是TensorRT7.0,自定义算子使用的IPluginV2IOExt实现的。
- 模型框架是caffe,所以以下实现都只适用于caffe模型的解析,但理论上解析tf和onnx的改动不大。
- 实现细节不方便全部贴出,但是基本实现过程和结构都在下面了,照着写写没啥问题了。
其实自定义算子写多了发现其实还挺好写的,格式都差不多,主要区别是enqueue的前向计算逻辑可能写起来复杂些。
整个实现过程基本上是:
- 继承nvinfer1::IPluginV2IOExt,并实现相应的虚函数。
- 继承nvinfer1::IPluginCreator并实现相应的虚函数。
- 继承nvcaffeparser1::IPluginFactoryV2并实现相应的虚函数。
- 在解析网络之前调用REGISTER_TENSORRT_PLUGIN注册UpsampleCreator和调用parser->setPluginFactoryV2()以使用自定义层类型。
以Upsample为例,TensorRT不支持Caffe的Upsample层,所以这里实现了一个自定义层类型,即plugin。需要实现:
- Upsample类,继承自nvinfer1::IPluginV2IOExt。
- UpsampleCreator类,继承自nvinfer1::IPluginCreator。
- CaffePluginFactory类,继承自nvcaffeparser1::IPluginFactoryV2。
需要实现的函数详见如下代码段。
Upsample类的实现:
class Upsample : public nvinfer1::IPluginV2IOExt {
public:
// 直接解析网络时候需要用到
Upsample();
// 反序列化时候需要用到
Upsample(const void *data, size_t length);
~Upsample();
// 直接return输出节点数,
int getNbOutputs() override;
// return输出的维度信息,如:return Dims3(inputs[0].d[0], inputs[0].d[1], inputs[0].d[2]);
Dims getOutputDimensions(int index, const Dims *inputs, int num_input_dims) override;
// pos索引到的input/output的数据格式(format)和数据类型(datatype)如果都支持则返回true
bool supportsFormatCombination(int pos, const PluginTensorDesc* in_out, int num_inputs, int num_outputs) const override;
// 这个函数可以获取到数据类型和输入的维度信息,如果有需要用到的可以在这里将相关信息取出来
configurePlugin(const PluginTensorDesc* in, int num_inputs, const PluginTensorDesc* out, int num_outputs) override;
// 在这里返回正确的序列化数据的长度,如我要序列化数据类型和数据维度:return sizeof(data_type) + sizeof(chw);
size_t getSerializationSize() const override;
// 序列化函数,在这里把反序列化时需要用到的参数或数据序列化
void serialize(void *buffer) const override;
// 设置工作空间,不需要直接 return 0;
size_t getWorkspaceSize(int max_batch_size) const override;
// 前向计算的核心函数,计算逻辑在这里实现,可以使用cublas实现或者自己写cuda核函数实现
int enqueue(int batch_size, const void *const *inputs, void **outputs, void *workspace, cudaStream_t stream) override;
// 调用enqueue的时候需要用到的资源先在这里Initialize,这个函数是在engine创建之后enqueue调用之前调用的,不需要Initialize则直接 return 0;
int initialize() override;
// 释放Initialize申请的资源,在enqueue调用之后且engine销毁之后调用
void terminate() override;
// 返回输出的数据类型,如何输入相同,可以直接 return input_types[0];
nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* input_types, int num_inputs) const override;
// 返回自定义类型,如这里是:return Upsample
const char* getPluginType() const override;
// 返回plugin version,没啥说的
const char* getPluginVersion() const override;
// 销毁对象
void destroy() override {
delete this;
}
// 在这里new一个该自定义类型并返回
nvinfer1::IPluginV2Ext* clone() const override;
// 设置命名空间,用来在网络中查找和创建plugin
void setPluginNamespace(const char* lib_namespace) override;
// 返回plugin对象的命名空间
const char* getPluginNamespace() const override;
bool isOutputBroadcastAcrossBatch(int output_index, const bool* input_is_broadcasted, int num_inputs) const override;
bool canBroadcastInputAcrossBatch(int input_index) const override;
}
下面是对应的Creator类的实现:
class UpsampleCreator : public nvinfer1::IPluginCreator {
public:
const char* getPluginName() const override;
const char* getPluginVersion() const override;
const PluginFieldCollection* getFieldNames() override;
// 创建自定义层pluin的对象并返回
nvinfer1::IPluginV2* createPlugin(const char* name, const PluginFieldCollection* fc) override;
// 创建自定义层pluin的对象并返回,反序列化用到
nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serial_data, size_t serial_length) override;
void setPluginNamespace(const char* lib_namespace) override;
const char* getPluginNamespace() const override;
}
下面是对应的plugin factory类的实现:
class CaffePluginFactory : public nvcaffeparser1::IPluginFactoryV2 {
public:
// 在这里判断一个层是否为自定义层类型
bool isPluginV2(const char* name) override;
// 在这里创建自定义层类型的对象并返回
nvinfer1::IPluginV2* createPlugin(const char* layer_name, const nvinfer1::Weights* weights, int num_weights, const char* libNamespace="") override;
}
如有问题可加公众号交流:AI算法爱好者