在本文中,将探讨如何将一个用于实时驾驶危险检测的模型加载到Android项目中,并准备使用TensorFlow Lite进行图像处理。首先,需要将模型添加到项目中。为此,请在项目的src/main
目录下创建一个名为assets
的新文件夹。然后将TensorFlow Lite模型和包含标签的文本文件复制到src/main/assets
中,使其成为项目的一部分。
在开始编写代码之前,需要了解模型期望其输入数据的结构。数据以多维数组的形式传入和传出,这也被称为数据的“形状”。通常,当找到一个模型时,这些信息会被记录下来。也可以使用工具Netron
来检查数据。当使用这个工具打开模型时,会显示构成网络的节点。点击输入节点(显示在图的顶部)会显示输入数据的格式(在这种情况下,是图像)和网络的输出。在本例中,可以看到输入数据是一个32位浮点数数组。数组的维度是1x416x416x3。这意味着网络将接受一次一个416x416像素的图像,带有红色、绿色和蓝色组件。如果在这个项目中使用不同的模型,需要检查模型的输入和输出,并相应地调整代码。将在解释结果时更详细地检查输出数据。
将添加一个名为Detector
的新类到项目中。所有管理训练网络的代码都将添加到这个类中。当构建这个类时,它将接受一个图像,并以更易于使用的形式提供结果。应该向类中添加一些常量和字段,以便开始使用它。字段包括一个TensorFlow解释器对象,用于包含训练网络,一个模型识别的对象类别列表,以及应用程序上下文。
这个类的构造函数将创建输出缓冲区,加载网络模型,并从assets文件夹中加载对象类别的名称。
执行网络模型只需要几行代码。当一个图像被提供给Detector
类时,它将被调整大小以匹配网络的要求。Bitmap
图像中的数据被编码为字节。这些值必须转换为32位浮点值。TensorFlow Lite库包含了使这种常见转换变得简单的功能。TensorImage
类型还有一个方便的方法,允许它被用作需要输入缓冲区的方法的缓冲区。
现在可以选取一个图像,检测器将处理这个图像,识别其中的物体。但是这些结果意味着什么?如何使用这些结果来警告用户关于危险?在本系列的下一篇文章中,将解释结果,并为用户提供相关信息。
以下是Detector
类的一个简化示例:
public class Detector {
private final String TF_MODEL_NAME = "yolov4.tflite";
private final int IMAGE_WIDTH = 416;
private final int IMAGE_HEIGHT = 416;
private final String TAG = "Detector";
private final boolean useGpuDelegate = false;
private final boolean useNNAPI = true;
private Context context;
private Interpreter tfLiteInterpreter;
private List labelList;
// 这些输出值的结构与使用的训练模型的输出相匹配
private float[][][] buf0 = new float[1][52][52][3];
private float[][][] buf1 = new float[1][26][26][3];
private float[][][] buf2 = new float[1][13][13][3];
private HashMap outputBuffers;
public Detector(Context context) {
this.context = context;
// 初始化TensorFlow解释器
// 加载模型
// 加载类别标签
}
public void processImage(Bitmap sourceImage) {
// 图像处理逻辑
}
}
<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout>
<ImageView
android:id="@+id/selected_image_view"
/>
<Button
android:id="@+id/select_image_button"
android:onClick="onSelectImageClicked"
/>
</androidx.constraintlayout.widget.ConstraintLayout>
public override fun onActivityResult(reqCode: Int, resultCode: Int, data: Intent?) {
super.onActivityResult(reqCode, resultCode, data)
if (resultCode == RESULT_OK) {
if (reqCode == SELECT_PICTURE) {
val selectedUri = data!!.data
val fileString = selectedUri!!.path
selected_image_view!!.setImageURI(selectedUri)
var sourceBitmap: Bitmap? = null
try {
sourceBitmap = MediaStore.Images.Media.getBitmap(this.contentResolver, selectedUri)
RunDetector(sourceBitmap)
} catch (e: IOException) {
e.printStackTrace()
}
}
}
}
fun RunDetector(bitmap: Bitmap?) {
if (detector == null) detector = Detector(this)
detector!!.processImage(bitmap!!)
}