PyTorch中模型的保存和加载
保存和加载模型参数(state_dict)(推荐)
state_dice 是一个简单的 python 字典,映射了每一层的参数名称和数值。
# save
torch.save(model.state_dict,PATH)#load
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
保存和加载整个模型 (不推荐)
要清晰的理解这种模型保存和加载的方式,此保存和加载的过程比较直观,可以使用更少的代码。这种保存的方式是使用 python 的 pickle 模块保存整个模块。这种方式的缺点也是往往被大家误解忽视的地方是序列化的数据绑定到特定的类,并且使用确切的目录结构。pickle 不会保存模型累本身,而且将其保存包含类的文件路径。该路径在加载时使用。
# save
torch.save(model, PATH)# load
# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()
直观地来说,这种保存整个模型并不是表面意义上的,在加载中仍需使用模型的相关定义代码,只是不再代码中人为的显示定义,而是 pickle 反序列时,去加载训练时相同的文件路径下的模型定义。这就导致在其他项目中使用或者代码重构时,容易出现错误。
相关问题 issue:torch.load() requires model module in the same folder · Issue #3678 · pytorch/pytorch · GitHub
注:torch.save 使用的是 python 中的 pickle 来进行序列化,PyTorch 1.6 版本切换了新的 zipfile-base 文件格式。torch.load 仍然可以加载以前的格式,如果在保存中仍想使用之前的格式,通过传入参数 _use_new_zipfile_serialization=False 来控制
使用 REST API 在 python 中部署PyTorch
针对PyTorch的模型部署,最简单的方式就是在 python 中使用 restful 的方式提供服务,这是最简单的方式,但是不适用于具有高性能要求的用例,一些对系统性能要求不高的场景下可以使用。 接下来我们以 flask 为例,你也可以使用 django 等其他 python web framework。在生产环境,比较简单的一种部署方式就是每块卡对应一个容器环境,在每个容器中使用 flask 来启动服务。
flask 是轻量级的 web 应用框架,使用它可以轻松的来部署我们的模型推理服务。 下面给出简单的示例伪代码:
from flask import Flask
from flask import requestapp = Flask(__name__)
model = load_model() # load your model
@app.route(‘/modelinference, methods=[‘POST’])
def model_inference():
input_data = request.form.get(‘input_data’, ”)
output_data = model(input_data)
return outputapp.run(host=’0.0.0.0′, port=7000,debug=True)
上面的代码是一个简单的示例,在 flask 启动的时候,我们就将模型加载进去,这样在每次请求的时候无需在加载模型,减少了每次请求的时间。上述例子中使用的是表单的方式向服务端 post 输入数据,然后进行模型推理的相关逻辑计算,可以将预处理和后处理等都集成在此服务中。
需要注意的是上述的代码是采用的 debug 方式启动的,在生产环境中,我们不能用这种方式来部署。常用的方式是借助 gunicorn 来启动 flask 应用。对于 gunicore 可以通过设置一些参数,来实现异步和多进程、多线程等。
在生产测试的时候,大量请求同时并发发送时,如遇到服务端返回 502、504 等错误代码,可通过抓包工具等进行请求排查,在相应的进行调整 gunicore 中的参数。
在 c++中加载PyTorch模型
在 python 中训练,在 python 中部署,可以说真的十分简单。但是往往由于 python 本身的性能和应用场景,一些对系统稳定性等要求比较高的情况下,我们就要使用 c++语言进行部署开发。接下来,我们就会一步一步的介绍如何将在 python 中训练的模型在 c++中部署。
TorchScript
TorchScript 是 Pytorch 模型的 intermediate representation(IR),可以在更高性能环境下运行,例如 c++。TorchScript 代码可以在其自己的解释器中调用,同时这种格式允许我们将整个模型保存到磁盘上,然后将其加载到另一个环境中。如何获得 TorchScript 形式的模型,官方提供了两种方式,这两种方式有各自的限制,很多时候都是两种方式混合使用。
trace Modules
使用 torch.jit.trace,并传入了 Module 和示例输入,该方法记录了运行 Module 时发生的操作,并创建了 torch.jit.ScriptModule 的实例。
#使用示例
traced_cell = torch.jit.trace(my_cell, (x, h))
另外我们可以通过.graph 属性来检查图,但是这是一种非常低级的表示形式,图中包含的大多数信息对最终的用户没有用,我们也可以使用.code 属性查看 python 的语法解释
# 查看图
print(traced_cell.graph)
#查看 python 语法解释
print(traced_cell.code)
notice——trace 会完全按照我们所说的去做:运行代码,记录发生的操作,并构造一个可以做到这一点的 ScriptModule。这样就导致控制流相关的操作不会被记录(只会记录输入示例的那种情况)
使用 script compiler 去转换模块
如上文所述,trace module 虽然用法简单,但是无法捕捉到控制流相关的操作。因为官方还提供了脚本编译器来直接分析 python 源代码并将其转化为 TorchScript。
然而 script compiler 的方式也不是万能的,例如操作有动态输入等情况,script 就会出问题。因此,往往会将两者混合使用来得到我们期待的 TorchScript 模型。
保存和加载 TorchScript 模型
具体的如果使用 trace 和 script,此处就不将仔细叙述,如果有需要可以详细查看官方文档的使用示例。当我们成功地得到了 TorchScript 模型后,我们将如何保存和加载呢?
# save
traced.save(‘wrapped_rnn.zip’)# load
loaded = torch.jit.load(‘wrapped_rnn.zip’)
c++中加载 TorchScript 模型
前面我们介绍了 TorchScript,并且介绍了两种将 python 中的 Pytorch 转化为 TorchScript 模型的方式。接下来我们将进入正题,在 c++中部署PyTorch模型,显而易见地,我们需要先将模型转化并保存为 TorchScript 模型,然后在 c++中进行模型的加载。 对于推理服务的接口通信,一般我们会使用 grpc 或者 http service 来处理。对于一个简单的 c++中加载 PyTorch 模型示例,一般包括以下几个过程:
1.环境
要想在 c++中加载序列化的 PyTorch 模型,必须依赖 PyTorch 的 c++ api 即 LibTorch。所以请先按照官方所提供的方式进行相关的环境配置。
2.PyTorch Model 转为 Torch Script 并进行保存
上面已经进行了相关阐述。
3.c++中加载
# 一个最小的 CMakeLists.txt 文件
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(custom_ops)find_package(Torch REQUIRED)
add_executable(example-app example-app.cpp)
target_link_libraries(example-app “${TORCH_LIBRARIES}”)
set_property(TARGET example-app PROPERTY CXX_STANDARD 14)
#include <torch/script.h> // One-stop header.
#include <iostream>
#include <memory>int main(int argc, const char* argv[]) {
if (argc != 2) {
std::cerr << “usage: example-app <path-to-exported-script-module>\n”;
return -1;
}torch::jit::script::Module module;
try {
// Deserialize the ScriptModule from a file using torch::jit::load().
module = torch::jit::load(argv[1]);
}
catch (const c10::Error& e) {
std::cerr << “error loading the model\n”;
return -1;
}std::cout << “ok\n”;
}
// Create a vector of inputs.
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::ones({1, 3, 224, 224}));// Execute the model and turn its output into a tensor.
at::Tensor output = module.forward(inputs).toTensor();
std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << ‘\n’;
将模型从 PyTorch 导出到 ONNX
什么是 ONNX
ONNX 是 Open Neural Network Exchange 的缩写,ONNX 为 AI 模型(深度学习模型和传统机器学习模型)提供了开源标准的格式。我们可以将模型转化为 ONNX 进行生产部署,同时也可以将其作为中介格式,使模型在不同的框架间相互转化。
相关扩展:ONNX RUNTIM 是一种跨平台的推理和训练加速器。使用 ONNX RUNTIM 可以提升模型的推理性能、减少训练大型模型的时间和成本、可在 python 中训练并部署到其他语言应用中、可在不同的硬件和操作系统上运行。
将 PyTorch 模型转为 ONNX
我们以图片为例进行示例说明(动态尺寸输入设置):
# 导出模型
def export_onnx_model(model, input_shape, onnx_path, input_names=None, output_names=None, dynamic_axes=None):
inputs = torch.ones(*input_shape)
torch.onnx.export(model, inputs, onnx_path, input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes)def convert_model():
model = load_model() # PyTorch 模型
batch_size = 1
width = 112
height = 112
#将 batch_size,width,height 都设置成动态的大小
input_shape = (batch_size, 3, width, height)input_names = [‘input’]
output_names = [‘output’]
dynamic_axes = {‘input’: {0: ‘batch_size’, 2: ‘width’, 3: ‘height’},
‘output’: {0: ‘batch_size’, 2: ‘width’, 3: ‘height’}}export_onnx_model(model, input_shape, onnx_path, input_names, output_names, dynamic_axes=dynamic_axes)
总结的来说,torch.onnx.export 方法传入模型和相关的示例输入,通过设置动态输入可以使导出的模型接受不同的输入尺度。
加载部署 ONNX 模型
对于 onnx 模型,很多推理框架都支持,我们也可以简单的使用 onnx 的 API 进行加载或者使用 onnx runtime 的 api 进行部署。
1.使用 onnx 的 API 检查 onnx 模型:
import onnx
onnx_model = onnx.load(“your_model.onnx”)
onnx.checker.check_model(onnx_model)
这里也推荐使用 Netron 可视化模型软件,进行模型前后转换的比对。
2.使用 onnx runtime 运行模型
import onnxruntime
ort_session = onnxruntime.InferenceSession(“your_model.onnx”)def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)
NOTICE
onnx 和 onnxruntime 的安装
pip install onnx onnxruntime
注意:安装的 onnxruntime 是 cpu 版本的还是 gpu 版本的,对于 gpu 版本的安装对 cuda 的版本有具体的要求。在安装时仔细看官方文档的版本对应关系,避免踩坑
1.转化 onnx
在转化自己的模型时,经常会遇到一些不支持的操作。比如 avg_pool 的动态输入等,一般常发生在一些输入是动态的时候,如果遇到相应的问题,可以多去 onnx GitHub 查看有没有相应的解决办法和操作支持,在了解自己的网络的情况下,也可以考虑进行模型拆分或者底层 op 的自己实现。
2.其他
今年亚马逊和 facebook 还联合推出了 TorchServe,针对 PyTorch 进行服务部署。但是这一工作似乎评价不一,我并没有实际使用,有兴趣和需求的可以尝试。
相关推荐阅读:
TorchServe github: https://github.com/pytorch/serve
如何评价 PyTorch 在 2020 年 4 月推出的 TorchServe
另外现在有很多优秀的模型推理服务项目,比如 nvidia 家的 triton 等。这些开源服务已经将服务通信、模型管理等都做好了,我们也可以在这些开源项目上面进行应用和二次开发。值得一提的话,如果你真的希望模型服务性能上有很大的改善,要针对性的进行优化工作,寻找瓶颈在哪里。如果需要模型压缩和优化的就需要对模型的结构有一定的了解。
参考文献
LOADING A TORCHSCRIPT MODEL IN C++
EXPORTING A MODEL FROM PYTORCH TO ONNX AND RUNNING IT USING ONNX RUNTIME