非极大值抑制(NMS)在目标检测中的应用

在进行图像分析和目标检测时,经常会遇到一个问题:同一个物体在图像中被多次检测到。这种情况在构建高级分析功能,如计数或跟踪检测时尤为不便。幸运的是,这个问题可以通过非极大值抑制(Non-Maximum Suppression,简称NMS)来解决。本文将解释NMS的工作原理,并展示如何在Python中使用NumPy库来实现NMS算法。

如果正在寻找一个快速的解决方案,并且没有时间深入了解数学和代码,可以使用pip包来处理这个问题。"supervision"这个pip包提供了一个NMS算法,可以轻松地过滤掉不需要的检测,无论使用哪种模型。实际上,上面的图片就是使用"supervision"创建的。以下是如何使用这个库的示例代码:

import supervision as sv results = ... detections = sv.Detections.from_transformers(transformers_results=results) detections = detections.with_nms(threshold=0.5)

NMS的核心思想是寻找重叠度很高的一组边界框,并决定保留哪些框,移除哪些框。用来衡量重叠程度的指标叫做交集比(Intersection over Union,简称IoU)。要计算IoU,首先需要计算两个单独的框A和B的面积,以及它们的交集(I)。然后,可以使用这些项来计算并集(U)。最后,可以用I除以U来得到这个指标。

为了计算A和B的面积、它们的交集(I)和并集(U),可以使用向量化的IoU计算方法。这种方法在需要同时计算数十个或数百个框的IoU时非常有用。幸运的是,可以使用矩阵操作来加速这个过程。"box_iou_batch"函数是通用的,允许计算列表A中的每个框与组B中的每个框之间的IoU。在案例中,这些组是相等的,并且是模型提供的所有检测的集合。"boxes_a"和"boxes_b"是二维矩阵,其中每一行描述一个单独的框(x_min, y_min, x_max, y_max)。

def box_iou_batch( boxes_a: np.ndarray, boxes_b: np.ndarray ) -> np.ndarray: def box_area(box): return (box[2] - box[0]) * (box[3] - box[1]) area_a = box_area(boxes_a.T) area_b = box_area(boxes_b.T) top_left = np.maximum(boxes_a[:, None, :2], boxes_b[:, :2]) bottom_right = np.minimum(boxes_a[:, None, 2:], boxes_b[:, 2:]) area_inter = np.prod( np.clip(bottom_right - top_left, a_min=0, a_max=None), 2) return area_inter / (area_a[:, None] + area_b - area_inter)

有了IoU计算函数,就可以开始处理NMS了。首先,将检测结果打包成一个二维矩阵。前四列由边界框坐标占据——(x_min, y_min, x_max, y_max),后面是分数和分配的类别。按照分数降序排序矩阵。使用"box_iou_batch"计算所有边界框之间的IOUs。遍历矩阵的行,并使用IoU矩阵中的信息,丢弃具有相同类别且IoU超过定义阈值的所有检测。

def non_max_suppression( predictions: np.ndarray, iou_threshold: float = 0.5 ) -> np.ndarray: rows, columns = predictions.shape sort_index = np.flip(predictions[:, 4].argsort()) predictions = predictions[sort_index] boxes = predictions[:, :4] categories = predictions[:, 5] ious = box_iou_batch(boxes, boxes) ious = ious - np.eye(rows) keep = np.ones(rows, dtype=bool) for index, (iou, category) in enumerate(zip(ious, categories)): if not keep[index]: continue condition = (iou > iou_threshold) & (categories == category) keep = keep & ~condition return keep[sort_index.argsort()]
沪ICP备2024098111号-1
上海秋旦网络科技中心:上海市奉贤区金大公路8218号1幢 联系电话:17898875485