在进行图像分析和目标检测时,经常会遇到一个问题:同一个物体在图像中被多次检测到。这种情况在构建高级分析功能,如计数或跟踪检测时尤为不便。幸运的是,这个问题可以通过非极大值抑制(Non-Maximum Suppression,简称NMS)来解决。本文将解释NMS的工作原理,并展示如何在Python中使用NumPy库来实现NMS算法。
如果正在寻找一个快速的解决方案,并且没有时间深入了解数学和代码,可以使用pip包来处理这个问题。"supervision"这个pip包提供了一个NMS算法,可以轻松地过滤掉不需要的检测,无论使用哪种模型。实际上,上面的图片就是使用"supervision"创建的。以下是如何使用这个库的示例代码:
import supervision as sv
results = ...
detections = sv.Detections.from_transformers(transformers_results=results)
detections = detections.with_nms(threshold=0.5)
NMS的核心思想是寻找重叠度很高的一组边界框,并决定保留哪些框,移除哪些框。用来衡量重叠程度的指标叫做交集比(Intersection over Union,简称IoU)。要计算IoU,首先需要计算两个单独的框A和B的面积,以及它们的交集(I)。然后,可以使用这些项来计算并集(U)。最后,可以用I除以U来得到这个指标。
为了计算A和B的面积、它们的交集(I)和并集(U),可以使用向量化的IoU计算方法。这种方法在需要同时计算数十个或数百个框的IoU时非常有用。幸运的是,可以使用矩阵操作来加速这个过程。"box_iou_batch"函数是通用的,允许计算列表A中的每个框与组B中的每个框之间的IoU。在案例中,这些组是相等的,并且是模型提供的所有检测的集合。"boxes_a"和"boxes_b"是二维矩阵,其中每一行描述一个单独的框(x_min, y_min, x_max, y_max)。
def box_iou_batch(
boxes_a: np.ndarray, boxes_b: np.ndarray
) -> np.ndarray:
def box_area(box):
return (box[2] - box[0]) * (box[3] - box[1])
area_a = box_area(boxes_a.T)
area_b = box_area(boxes_b.T)
top_left = np.maximum(boxes_a[:, None, :2], boxes_b[:, :2])
bottom_right = np.minimum(boxes_a[:, None, 2:], boxes_b[:, 2:])
area_inter = np.prod(
np.clip(bottom_right - top_left, a_min=0, a_max=None), 2)
return area_inter / (area_a[:, None] + area_b - area_inter)
有了IoU计算函数,就可以开始处理NMS了。首先,将检测结果打包成一个二维矩阵。前四列由边界框坐标占据——(x_min, y_min, x_max, y_max),后面是分数和分配的类别。按照分数降序排序矩阵。使用"box_iou_batch"计算所有边界框之间的IOUs。遍历矩阵的行,并使用IoU矩阵中的信息,丢弃具有相同类别且IoU超过定义阈值的所有检测。
def non_max_suppression(
predictions: np.ndarray, iou_threshold: float = 0.5
) -> np.ndarray:
rows, columns = predictions.shape
sort_index = np.flip(predictions[:, 4].argsort())
predictions = predictions[sort_index]
boxes = predictions[:, :4]
categories = predictions[:, 5]
ious = box_iou_batch(boxes, boxes)
ious = ious - np.eye(rows)
keep = np.ones(rows, dtype=bool)
for index, (iou, category) in enumerate(zip(ious, categories)):
if not keep[index]:
continue
condition = (iou > iou_threshold) & (categories == category)
keep = keep & ~condition
return keep[sort_index.argsort()]