PyTorch模型转换为ONNX格式指南

在本文中,将探讨如何将PyTorch模型转换为ONNX(Open Neural Network Exchange)格式。ONNX是一个开放的格式,用于表示深度学习模型,使得模型可以在不同的深度学习框架之间进行转换和使用。这种转换对于希望在生产环境中使用单一框架的组织来说尤其有价值,因为它允许他们将研究阶段开发的模型轻松地迁移到生产环境。

PyTorch简介

PyTorch是由Facebook的人工智能研究实验室(FAIR)在2016年发布的,它已经成为自然语言处理和计算机视觉领域研究人员的首选框架。尽管TensorFlow在生产环境中更为广泛地使用,但PyTorch在研究领域的流行度突显了像ONNX这样的标准的价值,它为模型提供了一个通用的格式和一个可以用于所有流行编程语言的运行时环境。

导入转换器

PyTorch的维护者已经将ONNX转换器集成到了PyTorch本身中。这意味着不需要安装任何额外的包。一旦安装了PyTorch,可以通过在模块中包含以下导入来访问PyTorch到ONNX的转换器:

import torch

一旦导入了torch模块,就可以按照以下方式访问转换函数:

torch.onnx.export()

希望其他框架也会采用这种做法。将转换器与框架本身打包和版本控制,意味着少了一个需要安装的包,同时也防止了框架和转换器之间的版本不匹配问题。

模型快速浏览

在转换PyTorch模型之前,需要查看创建模型的代码,以确定输入的形状。下面的代码创建了一个PyTorch模型,用于预测MNIST数据集中的数字。模型层的详细描述超出了本文的范围,但需要注意到输入的形状。在这里它是784。更具体地说,这段代码创建了一个模型,其中输入将是一个被展平的张量,是一个包含784个浮点数的数组。784有什么意义呢?嗯,MNIST数据集中的每张图片都是一个28×28像素的图片。28×28=784。所以一旦被展平,输入就是784个浮点数,每个浮点数代表一个灰度级别。底线是:这个模型期望从单张图片中得到784个浮点数。它不期望一个多维数组,也不期望一批图片。一次只预测一个。在将模型转换为ONNX时,这是一个重要的事实。

def build_model(): # 神经网络的层细节 input_size = 784 hidden_sizes = [128, 64] output_size = 10 # 构建一个前馈网络 model = nn.Sequential(nn.Linear(input_size, hidden_sizes[0]), nn.ReLU(), nn.Linear(hidden_sizes[0], hidden_sizes[1]), nn.ReLU(), nn.Linear(hidden_sizes[1], output_size), nn.LogSoftmax(dim=1)) return model

将PyTorch模型转换为ONNX

下面的函数展示了如何使用torch.onnx.export函数。正确使用这个函数有几个技巧。第一个也是最重要的技巧是正确设置样本输入。sample_input参数用于确定ONNX模型的输入。export_to_onnx函数将接受给它的任何东西——只要它是一个张量——转换就会工作,而不会出现错误。然而,如果样本输入的形状不正确,那么当尝试从ONNX运行时运行ONNX模型时,将得到一个错误。

def export_to_onnx(model): sample_input = torch.randn(1, 784) print(type(sample_input)) torch.onnx.export(model, sample_input, ONNX_MODEL_FILE, input_names=['input'], output_names=['output'])

添加到模型的元数据是一个最佳实践。随着用于训练模型的数据的演变,模型也会随之演变。因此,给模型添加元数据是一个好主意,这样就可以将它与以前的模型区分开来。上面的例子在doc_string属性中添加了模型的简短描述,并设置了版本。creation_dateauthor是添加到metadata_props属性包中的自定义属性。可以自由地使用这个属性包创建尽可能多的自定义属性。不幸的是,model_version属性需要一个整数或长整型,所以不能像服务那样使用major.minor.revision语法来版本控制。此外,导出函数会自动将模型保存到文件中,所以为了添加这个元数据,需要重新打开文件并重新保存它。

在本文中,为那些寻找用于构建和训练神经网络的深度学习框架的人提供了PyTorch的简要概述。然后,展示了如何使用已经是PyTorch一部分的转换工具将PyTorch模型转换ONNX格式。还展示了给导出的模型添加元数据的最佳实践。

沪ICP备2024098111号-1
上海秋旦网络科技中心:上海市奉贤区金大公路8218号1幢 联系电话:17898875485