初学TensorRT框架的时候,build、engine、context和runtime这几个类经常搞不清楚,不知道他们又什么作用,下面就这几个类进行简单说明。
1.builder:构建器,搜索cuda内核目录以获得最快的可用实现,必须使用和运行时的GPU相同的GPU来构建优化引擎。在构建引擎时,TensorRT会复制权重。
IBuilder* builder = createInferBuilder(gLogger);
INetworkDefinition* network = builder->createNetwork();
创建引擎
builder->setMaxBatchSize(maxBatchSize);
builder->setMaxWorkspaceSize(1 << 20);
ICudaEngine* engine = builder->buildCudaEngine(*network);
2.engine:引擎,不能跨平台和TensorRT版本移植。若要存储,需要将引擎转化为一种格式,即序列化,若要推理,需要反序列化引擎。引擎用于保存网络定义和模型参数。
IHostMemory *serializedModel = engine->serialize();
// store model to disk
std::ofstream p("resnet18.engine");
p.write(reinterpret_cast<const char*>(modelStream->data()), modelStream->size());
// <…>
serializedModel->destroy();
3.context:上下文,创建一些空间来存储中间值。一个engine可以创建多个context,分别执行多个推理任务。
IExecutionContext *context = engine->createExecutionContext();
4.runtime:用于反序列化引擎。
IRuntime* runtime = createInferRuntime(gLogger);
ICudaEngine* engine = runtime->deserializeCudaEngine(modelData, modelSize, nullptr);
TensorRT的C++API手册:
TensorRT:Main Page
tensorrt里的engine、context、buffer、profile等的关系: