文章

快速了解一个网络:CenterNet (Objects as Points)

快速了解一个网络:CenterNet (Objects as Points)

以下内容偏向于记录个人学习过程及思考,非常规教学内容

背景

一阶段目标检测网络通常是按照一定规则在图像/feature map上滑动选取一系列anchors进行检测

二阶段目标检测网络通常是基于proposals的方式进行检测

以上两种方案都需要增加后处理NMS来去除大量的重复框,但后处理过程很难求导和训练。

核心思想

将目标用其中心点表示,然后利用回归网络预测其其它属性,如尺寸、朝向等。

文章基于CNN预测heatmap,heatmap的峰值即为目标的中心点。使用峰值位置部份的img features来回归预测目标属性。

峰值点的提取过程则等价地取代了NMS计算部份。

Pipeline

backbone仍使用hourglass network / up-convolutional residual networks (ResNet) / deep layer aggregation (DLA)等

只是输出建模为了heatmap形式

亿些细节

目标是基于CNN预测一个H/R x W/R x C的heatmap,其中H,W为图像的原始长宽,R为下采样倍率,C为类别数。预测值为1表示center point,预测值为0表示背景。

真值生成过程使用Gaussian kernel。均值使用目标真值的中心点位置下采样R,方差则根据目标尺寸和下采样倍率计算。

如果同一个类别有两个Gaussian分布重叠,则重叠部份取max

为了减少下采样R的离散过程带来的位置误差,额外预测一个local offset去调整位置,所有的class共享相同的offset预测

同样,为了减小计算量,所有类别也是共用一个size预测

因此,对于每个位置,网络会预测C+4个输出,其中C为对应类别的中心点概率,4分别为size和offset的预测

在预测阶段,从heatmap中提取峰值即可(对应位置的值大于周围8个临近点的值)

进一步了解

  • hourglass network
  • up-convolutional residual networks (ResNet)
  • deep layer aggregation (DLA)

原文和代码

https://arxiv.org/abs/1904.07850

参考资料

本文由作者按照 CC BY 4.0 进行授权