在深度学习领域,模型的可移植性是一个重要的考量。ONNX(Open Neural Network Exchange)是一个开放的格式,旨在使不同的深度学习框架之间能够交换模型。本文将介绍如何将TensorFlow模型转换为ONNX格式。首先,将简要介绍TensorFlow 1.0和TensorFlow 2.0,然后展示如何安装和使用转换工具,最后讨论转换过程中可能遇到的问题。
TensorFlow最初由Google Brain团队创建。1.0版本于2017年2月发布。它具有易于在任何计算设备上部署的架构,例如CPU、GPU或TPU服务器集群,以及移动设备和边缘设备。这种灵活的部署能力使其成为生产环境的宠儿。然而,由于PyTorch引入了“定义并运行”方案,TensorFlow 1.0在研究环境中的市场份额有所下降。TensorFlow 1.0基于“定义并运行”方案,其中网络是预先定义并固定的,运行时唯一发生的是数据被送入网络。
为了吸引研究人员,TensorFlow 2.0引入了“定义并运行”方案,并整合了Keras。Keras是一个用于构建神经网络的高级API,易于使用。TensorFlow 2.0文档中的大多数示例都使用Keras,但也可以仅使用TensorFlow 2.0构建神经网络。如果需要对网络进行低级控制,或者正在将现有的TensorFlow 1.0网络迁移到TensorFlow 2.0,那么这可能是想要的。
在将TensorFlow模型转换为ONNX之前,需要安装tf2onnx包,因为它不包含在TensorFlow的任何版本中。以下命令将安装这个工具:
pip install tf2onnx
安装完成后,可以使用以下导入语句将转换器导入到模块中。然而,正如将在接下来的几个部分中看到的,从命令行使用tf2onnx工具更容易。
import tf2onnx
TensorFlow 1.0和TensorFlow 2.0都提供了一个较低级别的API。例如,它们允许设置权重和偏差,以获得对模型训练的更多控制。然而,为了将TensorFlow 1.0和TensorFlow 2.0模型转换为ONNX,只关心指定输入、输出以及模型在TensorFlow格式之一中保存时的位置的代码。(本文有一个完整的示例,用于从MNIST数据集中的手写样本预测数字。)
TensorFlow 1.0使用占位符方法创建用于指示模型输入和输出的特殊变量。下面是一个示例。为了使转换到ONNX更容易,最好在设置占位符时指定一个名称。这里输入名称为"input",输出名称为"output"。
# tf Graph input
X = tf.placeholder("float", [batch_size, num_input], name="input")
Y = tf.placeholder("float", [batch_size, num_classes], name="output")
一旦在TensorFlow 1.0中创建了会话,就可以使用以下代码保存模型。可能会有点混淆的是,文件格式从TensorFlow 1.0变到了TensorFlow 2.0。TensorFlow 1.0使用"checkpoint"文件来持久化模型。下面的示例指定了一个checkpoint文件。约定是使用扩展名.ckpt。还会将其他文件保存到与ckpt文件相同的目录中,因此最好为保存的模型创建一个目录。
saver = tf.train.Saver()
save_path = saver.save(sess, './tensorflow/tensorflow_model.ckpt')
TensorFlow 2.0使用"SavedModel"格式,这使得转换过程稍微容易一些。以下是保存TensorFlow 2.0模型的命令。注意没有指定文件名,而是指定了一个目录。(顺便说一句,整合到TensorFlow 2.0中的Keras使用相同的格式。)
tf.saved_model.save(model, './tensorflow')
将TensorFlow模型转换为ONNX的最简单方法是使用命令行中的tf2onnx工具。当从命令行使用tf2onnx时,它将转换一个保存的TensorFlow模型到另一个表示模型的ONNX格式的文件。也可以从代码中运行转换,但是对于内存中的TensorFlow模型,tf2onnx可能会冻结图。冻结是一个过程,其中图中的所有变量都转换为常量。这对于ONNX是必要的,因为它是一个推理图,没有变量。tf2onnx工具包含一个名为process_tf_graph的函数,如果想要在代码中转换,可以尝试使用它。然而,如果最终得到错误消息KeyError: tf.float32_ref,那么最好从命令行转换文件。
以下是将TensorFlow 1.0 checkpoint文件转换为ONNX的命令。注意需要找到meta文件并将其传递给tf2onnx。还需要指定输入名称和输出名称。
python -m tf2onnx.convert --checkpoint ./tensorflow/tensorflow_model.ckpt.meta --output tfmodel.onnx --inputs input:0 --outputs output:0
以下是将TensorFlow 2.0模型转换的命令。需要指定保存模型到磁盘的目录。(它不是保存在单个文件中。)还需要指定ONNX输出文件。不需要指定输入名称和输出名称。
python -m tf2onnx.convert --saved-model ./tensorflow --output tfmodel.onnx
在本文中,为那些寻找用于构建和训练神经网络的深度学习框架的人提供了TensorFlow 1.0和TensorFlow 2.0的简要概述。然后展示了如何安装tf2onnx转换包并将TensorFlow模型转换为ONNX格式。还展示了在指定输入参数的形状时容易犯的错误。