大规模人脸识别技术探索

深度学习领域,人脸识别技术取得了显著进展,尤其是在准确性和可扩展性方面。然而,随着身份数量的指数级增长,GPU内存的有限容量成为了一个重大挑战。以往的研究主要集中在优化用于面部特征提取网络的损失函数上,基于softmax的损失函数推动了人脸识别性能的提升。但是,随着身份数量的增加,GPU内存的限制变得越来越难以克服。本文将探讨如何使用部分全连接(Partial FC)层在大规模人脸识别中实现优化。

学习目标

了解大规模人脸识别中softmax损失函数带来的挑战,例如计算开销和身份数量。探索部分全连接(PFC)层,优化人脸识别任务中的内存和计算,包括其优缺点和应用场景。在人脸识别项目中实现部分FC,提供实用技巧、代码片段和资源。

Softmax瓶颈是什么?

Softmax损失及其变体被广泛采用作为人脸识别任务的目标函数。这些函数在嵌入特征和线性变换矩阵之间的乘法过程中进行全局特征到类别的比较。然而,在处理训练集中的大量身份时,存储和计算最终线性矩阵的成本常常超出当前GPU硬件的能力,这可能导致训练失败。

以前的加速尝试

研究人员探索了各种技术来缓解这个瓶颈,每种技术都有其自身的权衡和局限性。HF-softmax采用动态选择过程,在每个小批量中选择活动类别中心。这种选择是通过在嵌入空间构建随机哈希森林来实现的,使得基于特征检索近似最近的类别中心。然而,重要的是要注意,存储所有类别中心在RAM中,并且不可忽视特征检索的计算开销。

另一方面,Softmax Dissection将softmax损失分解为类内和类间目标,从而减少了类间部分的冗余计算。虽然这种方法值得称赞,但它在适应性和多功能性方面受到限制,因为它只适用于特定的基于softmax的损失函数。

模型并行:正确的一步

ArcFace损失函数引入了模型并行,它将softmax权重矩阵分散在不同的GPU上,并以最小的通信开销计算全类别softmax损失。这种方法成功地使用八块GPU在单台机器上训练了100万个身份。

模型并行方法将softmax权重矩阵W ∈ R (d×C)划分为k个子矩阵w,大小为d × (C/k),其中d是嵌入特征维度,C是类别数量。然后每个子矩阵wi放置在第i个GPU上。

为了计算最终的softmax输出,每个GPU独立计算分子e^((wi)T * X),其中X是输入特征。分母∑ j=1 to C e^((wj)T * X)需要从所有其他GPU收集信息,这是通过首先在每个GPU上计算局部和,然后通信局部和来计算全局和来完成的。

模型并行的内存限制

虽然模型并行减轻了存储权重矩阵W的内存负担,但它引入了一个新的瓶颈——存储预测logits。预测logits是在前向传递期间计算的中间值,它们的存储需求随着所有GPU上的总批量大小而扩展。随着身份和GPU数量的增加,存储logits的内存消耗可以迅速超过GPU内存容量。

引入部分FC

为了克服以前方法的限制,“部分FC”论文的提出了一个开创性的解决方案!部分FC引入了一个softmax近似算法,可以在只使用一小部分(例如,10%)类别中心的情况下保持最先进的准确性。通过在训练期间仔细选择类别中心的子集,可以显著减少内存和计算需求。这将进一步使训练具有前所未有的身份数量的人脸识别模型成为可能。

部分FC的优势

通过随机抽样负类别中心,部分FC受标签噪声或类间冲突的影响较小。在长尾分布中,一些类别的样本比其他类别少得多,部分FC避免了过度更新较少频率的类别,从而带来更好的性能。部分FC可以在仅有8个GPU的情况下训练超过1000万身份,而ArcFace在相同GPU数量下只能处理100万身份。

部分FC的缺点

选择合适的抽样率(r%)对于保持准确性和效率至关重要。太低的速率可能会降低性能,而太高的速率可能会抵消内存和计算优势。随机抽样过程可能会引入噪声,如果处理不当,可能会影响模型的性能。

释放部分FC的力量

def sample(self, labels, index_positive): with torch.no_grad(): positive = torch.unique(labels[index_positive], sorted=True).cuda() if self.num_sample - positive.size(0) >= 0: perm = torch.rand(size=[self.num_local]).cuda() perm[positive] = 2.0 index = torch.topk(perm, k=self.num_sample)[1].cuda() index = index.sort()[0].cuda() else: index = positive self.weight_index = index labels[index_positive] = torch.searchsorted(index, labels[index_positive]) return self.weight[self.weight_index]
沪ICP备2024098111号-1
上海秋旦网络科技中心:上海市奉贤区金大公路8218号1幢 联系电话:17898875485