ONNX在Java中的安装与使用

ONNX(Open Neural Network Exchange)是由微软、Facebook和AWS共同开发的一个开放格式,旨在使不同的机器学习框架能够导出它们的模型到ONNX格式,并确保这些模型能在任何硬件配置上运行。ONNX Runtime是一个用于运行转换为ONNX格式的机器学习模型的引擎,无论是传统的机器学习模型还是深度学习模型(神经网络),都可以被导出到ONNX格式。ONNX Runtime可以在Linux、Windows和Mac上运行,并且支持多种芯片架构,包括利用硬件加速器如GPU和TPU。但是,并非所有的操作系统、芯片架构和加速器组合都有现成的安装包,如果使用的是非常见的组合,可能需要从源代码构建运行时。本文将展示如何在x64架构上安装ONNX Runtime,包括默认CPU和带有GPU的情况。

安装和导入ONNX Runtime

在开始使用ONNX Runtime之前,需要为构建工具添加适当的依赖。Maven仓库是为包括Maven和Gradle在内的多种工具设置ONNX Runtime的好来源。要使用默认CPU的x64架构,请参考下面的链接。

https://mvnrepository.com/artifact/org.bytedeco/onnxruntime-platform

要使用带有GPU的x64架构,请使用以下链接。

https://mvnrepository.com/artifact/org.bytedeco/onnxruntime-platform-gpu

一旦安装了运行时,就可以使用以下Java代码中的import语句将其导入到Java代码文件中。这些import语句将帮助为ONNX模型创建输入,并解释ONNX模型的输出(预测)。

import ai.onnxruntime.OnnxMl.TensorProto; import ai.onnxruntime.OnnxMl.TensorProto.DataType; import ai.onnxruntime.OrtSession.Result; import ai.onnxruntime.OrtSession.SessionOptions; import ai.onnxruntime.OrtSession.SessionOptions.ExecutionMode; import ai.onnxruntime.OrtSession.SessionOptions.OptLevel;

加载ONNX模型

下面的代码片段展示了如何在Java中运行ONNX Runtime来加载ONNX模型。这段代码创建了一个会话对象,可以用来进行预测。这里使用的模型是从PyTorch导出的ONNX模型。

这里值得注意几点。首先,需要查询会话以获取其输入。这是通过会话的getInputInfo方法完成的。MNIST模型只有一个输入参数:一个表示MNIST数据集中一张图像的784个浮点数数组。如果模型有多个输入参数,那么InputMetadata将为每个参数有一个条目。

String modelPath = "pytorch_mnist.onnx"; try (OrtSession session = env.createSession(modelPath, options)) { Map inputMetaMap = session.getInputInfo(); Map container = new HashMap<>(); NodeInfo inputMeta = inputMetaMap.values().iterator().next(); float[] inputData = Utilities.ImageData[imageIndex]; string label = Utilities.ImageLabels[imageIndex]; System.out.println("Selected image is the number: " + label); // 这是这个模型唯一的输入张量的数据 Object tensorData = OrtUtil.reshape(inputData, ((TensorInfo) inputMeta.getInfo()).getShape()); OnnxTensor inputTensor = OnnxTensor.createTensor(env, tensorData); container.put(inputMeta.getName(), inputTensor); // 运行代码省略,以简洁起见。 }

上述代码中未显示的实用工具读取原始MNIST图像,并将每张图像转换为784个浮点数的数组。每张图像的标签也从MNIST数据集中读取,以便确定预测的准确性。这段代码是标准的Java代码,但仍然鼓励查看并使用它。如果需要读取类似于MNIST数据集的图像,它将为节省时间。

使用ONNX Runtime进行预测

下面的函数展示了如何使用加载ONNX模型时创建的ONNX会话。

try (OrtSession session = env.createSession(modelPath, options)) { // 加载代码省略,以简洁起见。 // 运行推理 try (OrtSession.Result results = session.run(container)) { for (Map.Entry r : results) { OnnxValue resultValue = r.getValue(); OnnxTensor resultTensor = (OnnxTensor) resultValue; resultTensor.getValue(); System.out.println("Output Name: " + r.getName()); int prediction = MaxProbability(resultTensor); System.out.println("Prediction: " + prediction.ToString()); } } }

大多数神经网络不会直接返回预测。它们返回每个输出类别的概率列表。在MNIST模型中,每张图像的返回值将是10个概率的列表。具有最高概率的条目是预测。一个有趣的测试是将ONNX模型返回的概率与在创建模型的框架内运行原始模型时返回的概率进行比较。理想情况下,模型格式和运行时的变化不应该改变产生的任何概率。这将是一个良好的单元测试,每次模型发生变化时都会运行。

在本文中,简要概述了ONNXRuntime和ONNX格式。然后展示了如何使用Java在ONNX Runtime中加载和运行ONNX模型。本文的代码示例包含一个工作控制台应用程序,演示了这里展示的所有技术。这个代码示例是GitHub仓库的一部分,该仓库探索了使用神经网络预测MNIST数据集中发现的数字。具体来说,有示例展示了如何在Keras、PyTorch、TensorFlow 1.0和TensorFlow 2.0中创建神经网络。

https://microsoft.github.io/onnxruntime/ https://microsoft.github.io/onnxruntime/ https://github.com/microsoft/onnxruntime/blob/master/docs/Java_API.md#getting-started https://github.com/keithpij/onnx-lab
沪ICP备2024098111号-1
上海秋旦网络科技中心:上海市奉贤区金大公路8218号1幢 联系电话:17898875485