TensorRT的自定义算子Plugin的实现

   日期:2020-06-05     浏览:428    评论:0    
核心提示:介绍了如何使用TensorRT实现自定义算子,如何实现IPluginV2IOExt,如何实现IPluginCreator,以及如何实现IPluginFactoryV2。

这篇文章主要介绍了如何使用TensorRT实现自定义算子。

Note:

  1. 我使用的是TensorRT7.0,自定义算子使用的IPluginV2IOExt实现的。
  2. 模型框架是caffe,所以以下实现都只适用于caffe模型的解析,但理论上解析tf和onnx的改动不大。
  3. 实现细节不方便全部贴出,但是基本实现过程和结构都在下面了,照着写写没啥问题了。

其实自定义算子写多了发现其实还挺好写的,格式都差不多,主要区别是enqueue的前向计算逻辑可能写起来复杂些。
整个实现过程基本上是:

  1. 继承nvinfer1::IPluginV2IOExt,并实现相应的虚函数。
  2. 继承nvinfer1::IPluginCreator并实现相应的虚函数。
  3. 继承nvcaffeparser1::IPluginFactoryV2并实现相应的虚函数。
  4. 在解析网络之前调用REGISTER_TENSORRT_PLUGIN注册UpsampleCreator和调用parser->setPluginFactoryV2()以使用自定义层类型。

以Upsample为例,TensorRT不支持Caffe的Upsample层,所以这里实现了一个自定义层类型,即plugin。需要实现:

  1. Upsample类,继承自nvinfer1::IPluginV2IOExt。
  2. UpsampleCreator类,继承自nvinfer1::IPluginCreator。
  3. CaffePluginFactory类,继承自nvcaffeparser1::IPluginFactoryV2。

需要实现的函数详见如下代码段。

Upsample类的实现:

class Upsample : public nvinfer1::IPluginV2IOExt {
public:
    // 直接解析网络时候需要用到
    Upsample();
    // 反序列化时候需要用到
    Upsample(const void *data, size_t length);
    ~Upsample();
    
    // 直接return输出节点数,
    int getNbOutputs() override;
    
    // return输出的维度信息,如:return Dims3(inputs[0].d[0], inputs[0].d[1], inputs[0].d[2]);
    Dims getOutputDimensions(int index, const Dims *inputs, int num_input_dims) override;
    
    // pos索引到的input/output的数据格式(format)和数据类型(datatype)如果都支持则返回true
    bool supportsFormatCombination(int pos, const PluginTensorDesc* in_out, int num_inputs, int num_outputs) const override;
    
    // 这个函数可以获取到数据类型和输入的维度信息,如果有需要用到的可以在这里将相关信息取出来
    configurePlugin(const PluginTensorDesc* in, int num_inputs, const PluginTensorDesc* out, int num_outputs) override;

    // 在这里返回正确的序列化数据的长度,如我要序列化数据类型和数据维度:return sizeof(data_type) + sizeof(chw);
    size_t getSerializationSize() const override;
    
    // 序列化函数,在这里把反序列化时需要用到的参数或数据序列化
    void serialize(void *buffer) const override;
    
    // 设置工作空间,不需要直接 return 0;
    size_t getWorkspaceSize(int max_batch_size) const override;
    
    // 前向计算的核心函数,计算逻辑在这里实现,可以使用cublas实现或者自己写cuda核函数实现
    int enqueue(int batch_size, const void *const *inputs, void **outputs, void *workspace, cudaStream_t stream) override;
    
    // 调用enqueue的时候需要用到的资源先在这里Initialize,这个函数是在engine创建之后enqueue调用之前调用的,不需要Initialize则直接 return 0;
    int initialize() override;
    
    // 释放Initialize申请的资源,在enqueue调用之后且engine销毁之后调用
    void terminate() override;
    
    // 返回输出的数据类型,如何输入相同,可以直接 return input_types[0];
    nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* input_types, int num_inputs) const override;
    
    // 返回自定义类型,如这里是:return Upsample
    const char* getPluginType() const override;
    
    // 返回plugin version,没啥说的
    const char* getPluginVersion() const override;
      
    // 销毁对象
    void destroy() override {
        delete this;
    }
    
    // 在这里new一个该自定义类型并返回
    nvinfer1::IPluginV2Ext* clone() const override;
    
    // 设置命名空间,用来在网络中查找和创建plugin
    void setPluginNamespace(const char* lib_namespace) override;
    // 返回plugin对象的命名空间
    const char* getPluginNamespace() const override;
    bool isOutputBroadcastAcrossBatch(int output_index, const bool* input_is_broadcasted, int num_inputs) const override;
    bool canBroadcastInputAcrossBatch(int input_index) const override;
}

下面是对应的Creator类的实现

class UpsampleCreator : public nvinfer1::IPluginCreator {
public:
    const char* getPluginName() const override;
    const char* getPluginVersion() const override;
    const PluginFieldCollection* getFieldNames() override;
    // 创建自定义层pluin的对象并返回
    nvinfer1::IPluginV2* createPlugin(const char* name, const PluginFieldCollection* fc) override;
    // 创建自定义层pluin的对象并返回,反序列化用到
    nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serial_data, size_t serial_length) override;
    void setPluginNamespace(const char* lib_namespace) override;
    const char* getPluginNamespace() const override;
}

下面是对应的plugin factory类的实现

class CaffePluginFactory : public nvcaffeparser1::IPluginFactoryV2 {
public:
    // 在这里判断一个层是否为自定义层类型
    bool isPluginV2(const char* name) override;
    // 在这里创建自定义层类型的对象并返回
    nvinfer1::IPluginV2* createPlugin(const char* layer_name, const nvinfer1::Weights* weights, int num_weights, const char* libNamespace="") override;
}

如有问题可加公众号交流:AI算法爱好者

 
打赏
 本文转载自:网络 
所有权利归属于原作者,如文章来源标示错误或侵犯了您的权利请联系微信13520258486
更多>最近资讯中心
更多>最新资讯中心
更多>相关资讯中心
0相关评论

推荐图文
推荐资讯中心
点击排行
最新信息
新手指南
采购商服务
供应商服务
交易安全
关注我们
手机网站:
新浪微博:
微信关注:

13520258486

周一至周五 9:00-18:00
(其他时间联系在线客服)

24小时在线客服