PyTorch模型的多种导出方式提供给其他程序使用

flyfish

1 模型可视化

以下使用模型可视化工具时netron

工具下载到本地
https://github.com/lutzroeder/netron/releases/
或者在使用
https://netron.app/

2 预训练模型

当下载一个预训练模型时,只是一个一个的module
PyTorch模型的多种导出方式提供给其他程序使用-LMLPHP

3 ONNX模型导出有输入有输出

import torch
import torchvision

if __name__ == '__main__':
    input = torch.randn(1, 3, 224, 224)        
    model = torchvision.models.resnet18()                          
    model.load_state_dict(torch.load("resnet18-f37072fd.pth")) 
    model.eval()                               
    torch.onnx.export(model, input, "a.onnx", training=torch.onnx.TrainingMode.TRAINING) 
    torch.onnx.export(model, input, "b.onnx", training=torch.onnx.TrainingMode.EVAL) 

TRAINING导出方式

算子没有融合
PyTorch模型的多种导出方式提供给其他程序使用-LMLPHP

EVAL导出方式

PyTorch模型的多种导出方式提供给其他程序使用-LMLPHP
当采用EVAL方式进行模型导出的时候,Conv和BatchNorm层进行了合并

4 自定义输入输出的名字,并可批量推理

import torch
import torchvision

if __name__ == '__main__':

	model = torchvision.models.resnet18()                          
	model.load_state_dict(torch.load("resnet18-f37072fd.pth")) 
	model.eval()                               


	batch_size = 4 
	input_data = torch.randn(batch_size, 3, 224, 224)

	output_path = "c.onnx"
	torch.onnx.export(model, input_data, output_path,
		          input_names=["input"], output_names=["output"],
		          dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}})

PyTorch模型的多种导出方式提供给其他程序使用-LMLPHP

5 导出JIT模型

JIT(Just-In-Time)
在Yolov5中叫torchscript

import torch
import torchvision

if __name__ == '__main__':

	model = torchvision.models.resnet18()                          
	model.load_state_dict(torch.load("resnet18-f37072fd.pth")) 
	model.eval()                               
	input = torch.rand(1, 3, 224, 224)
	jit_model = torch.jit.trace(model, input)
	torch.jit.save(jit_model, 'resnet18_jit.trace.pth')

	#script_model = torch.jit.script(model, input)
	#torch.jit.save(script_model, 'resnet18_jit.script.pth')

PyTorch模型的多种导出方式提供给其他程序使用-LMLPHP
本文使用的PyTorch版本 1.10.1

10-10 20:32