目标检测4--Adaptive Training Sample Selection(ATSS)算法
文章目录
欢迎访问个人网络日志🌹🌹知行空间🌹🌹
1.简介
论文Bridging the Gap Between Anchor-based and Anchor-free Detection via
Adaptive Training Sample Selection
代码https://github.com/sfzhang15/ATSS
ATSS
是中科院自动化研究所的Shifeng Zhang
等最早于2019年12月份提交的论文中提出的方法,发表在CVPR2020会议上。
文中分析了Anchor Based
和Anchor Free
的检测方法,性能差异的主要原因在于正负训练样本的定义方式不同,而和回归目标是基于**点式(point)还是盒式(box)**关系不大。Anchor Free
检测常用的有两种方法,一种是keypoint_based
,另一种是center_based
。keypoint_based
的Anchor Free
目标检测算法同标准的keypoint estimation pipeline
,和anchor based
的目标检测算法差异较大。但center_based
的Anchor Free
目标检测算法与Anchor Based
的方法比较相近,center_based
方法将point
作为预设样本(如FCOS),Anchor Based
方法是将anchor
作为预设样本(如RetinaNet)。Anchor Based
的RetinatNet
与Center Based
的FCOS
的主要区别是:
- 1)
feature map
中每个位置的anchor
数量不同,RetinaNet
每个点生成多个anchor boxes
,FCOS
每个点生成一个anchor point
- 2)正负样本的定义方式不同,
RetinaNet
使用IoU
来判定正负样本,FCOS
使用patial and scale constraints
来判断。 - 3)回归起始状态不同。
RetinaNet
是基于Anchor Box
的 ( t x , t y . t ω , t h ) (t_x,t_y.t_\omega,t_h) (tx,ty.tω,th),FCOS
是基于Anchor Point
的 ( l l , l t , l r , l o ) (l_l,l_t,l_r,l_o) (ll,lt,lr,lo)。
ATSS
分析了Anchor Based
和Anchor Free
检测算法实现上的差异,得出的结论是正负样本定义方式的不同影响了两种方法检测效果的差异。基于此论文提出了Adaptive Training Sample Selection(ATSS)
算法以基于目标特征自动的计算正负样本。本文还基于实验得出了在同个位置没必要使用多个anchor box
做检测的结论。
2.目标检测相关
3.Anchor Based
与Anchor Free
目标检测算法的差异分析
Anchor Based
选择RetinaNet
作为代表,Anchor Free
选择FCOS
作为代表,从以下三方面进行分析:
- 1)正负样本定义
- 2)初始回归状态,是回归 t x , t y , t w , t h t_x,t_y,t_w,t_h tx,ty,tw,th还是 l l , l t , l r , l o l_l,l_t,l_r,l_o ll,lt,lr,lo
- 3)每个位置的
anchor
数量
3.1 RetinaNet
与FCOS
的对比
设置RetinaNet
的Anchor box
数量为1
。对FCOS
的改进:
- 1)将
centernerss
移到regression
分支 - 2)使用
GIoU Loss
- 3)将回归目标使用对应level的stride来归一化
这些提升了FCOS
的检测效果,coco minival
上的map
从37.1
提升到了37.8
,进一步拉开了Anchor=1
的RetinaNet
与FCOS
的差距。
FCOS
中使用的一些trick
在Anchor=1
的RetinaNet
中也能使用,如检测头中使用的Group Normlization
, GIoU
,限制ground truth box
中的正样本,对特征金字塔的每层加上一个中心度分支和可训练参数。将这些trick
逐一加到RetinaNet
上的对比结果为:
从上图可以看出,将所有的通用trick
都应用到RetinaNet
上后,MAP
依然有0.8的差距。除了以上指出的通用性差异后,还有两点不同,一个是正负样本的定义方式,另一个是回归任务本身,RetinaNet
是基于Anchor Box
回归,FCOS
是基于Anchor Point
回归。
3.1 正负样本定义的区别
如上图,RetinaNet
根据ground truth box
与anchor box
之间的IoU
的值来判断是正样本还是负样本,通常设置两个超参数
(
I
o
U
n
e
g
,
I
o
U
p
o
s
)
(IoU_{neg}, IoU_{pos})
(IoUneg,IoUpos),小于
I
o
U
n
e
g
IoU_{neg}
IoUneg的是负样本,大于
I
o
U
p
o
s
IoU_{pos}
IoUpos的是正样本,在两者之间的Anchor Box
被忽略,不参与训练,RPN
生产的Proposal Box
基于FPN
论文中提出的方程式2赋值给某个feature
层。FCOS
则先根据Anchor Point
的空间位置是否落在ground truth box
中找出可能为正的Anchor Point
,再根据Anchor Point
对应feature map
上的回归范围regression scale
来近一步确认是否为正样本,参考见博客FCOSNet。基于Spatial and Scale
的正样本判定方式决定了检测器的优秀性能,如下表,使用Spatial and Scale
后,Anchor=1的RetinaNet
的MAP
也提升到了37.8
,换用IoU
的FCOS
的MAP
降到了36.9
:
3.2 回归起始位置的差异
如下图,Anchor=1的RetinaNet
回归的是AnchorBox
相对于ground truth box
的平移缩放
(
t
x
,
t
y
,
t
w
,
t
h
)
(t_x,t_y,t_w,t_h)
(tx,ty,tw,th)即基于box
的回归,而FCOS
回归的是中心点距离ground truth box
四边的距离
l
l
,
l
t
,
l
r
,
l
b
l_l,l_t,l_r,l_b
ll,lt,lr,lb,即基于点的回归。从上图中按行方向比较可以发现,使用box
或point
的回归方式对最终的结果影响不大,37->36.9
,‵37.8->37.8`。
综合3.1和3.2的分析,可以得出结论:是正负样本的定义方式不同影响了Anchor Based
和Anchor Free
算法的性能。
4.自适应训练样本选择
从前面作者得出的结论,How to define positive and negative samples极大影响了检测器的性能,基于此作者提出了新的samples
分类算法,自适应训练样本选择(Adaptive Training Sample Selection, ATSS)。
Anchor Based
基于IoU
和Anchor Free
基于Scale Range
的正样本定义方法都依赖预先定义好的超参数,ATSS
提出了一种自适应取阈值的方法,减少了sample definition
所需的超参数。
以一张输入图像为例说明上图ATSS
算法的工作流程:
- 1)对于1个
ground truth box
,分别在每个金字塔特征层上取中心 L 2 L_2 L2距离最近的 k k k个anchor boxes
作为候选positive sample
,对于有 L \mathcal{L} L个金字塔特征层的网络,共得到 k L k\mathcal{L} kL个candidate positive anchor boxes
- 2)计算
candidates
与ground truth boxes
g ∈ D g g\in \mathcal{D}_{g} g∈Dg之间的IoU
- 3)计算2)中
IoU
的均值 m g m_{\mathcal{g}} mg和标准差 v g \mathcal{v}_{\mathcal{g}} vg - 4)取
t
g
=
m
g
+
v
g
t_g=m_{\mathcal{g}}+\mathcal{v}_{\mathcal{g}}
tg=mg+vg作为阈值,大于
t
g
t_g
tg的是
positive
,其余的Anchor Boxes
都是negative
作者指出,当一个anchor box
同时落入两个ground truth box
中时,会将其分配给IoU
比较大的ground truth box
。
从上图可以看出ATSS
的作用,对于某个ground truth box
,图a中标准差较大,意味着有某个金字塔特征层比较适合预测该box
,因此阈值
t
g
t_g
tg也比较大。图b中标准差不大,意味者可能有多个特征层适合预测当前box
,因此选取的阈值
t
g
t_g
tg也较小。
作者还指出使用ATSS
,可以使得对于不同大小的目标对象得到相同比例的正负训练样本。对于标准正态分布有16%
的样本落在
[
v
+
σ
,
1
]
[v+\sigma,1]
[v+σ,1]之间,虽然IoU of candidates
不是正态分布,正样本的比例依然保持在了20% of
k
L
k\mathcal{L}
kL 左右,和目标
s
c
a
l
e
/
a
s
p
e
c
t
r
a
t
i
o
/
l
o
c
a
t
i
o
n
scale/aspect ratio/location
scale/aspectratio/location无关。而RetinaNet
和FCOS
都会倾向于对大目标生成更多的正样本。
ATSS
使用的超参数很少,只有k
一个,且算法效果对k
不敏感。实验证明k
取[3, 5, 7, 9, 11, 13, 15, 17, 19]
时map
变化不大:
5.代码实现
mmdetection
中ATSS
算法的实现在ATSSAssigner
类中,assign
的部分代码如下:
# Selecting candidates based on the center distance
candidate_idxs = []
start_idx = 0
for level, bboxes_per_level in enumerate(num_level_bboxes):
# on each pyramid level, for each gt,
# select k bbox whose center are closest to the gt center
end_idx = start_idx + bboxes_per_level
distances_per_level = distances[start_idx:end_idx, :]
selectable_k = min(self.topk, bboxes_per_level)
_, topk_idxs_per_level = distances_per_level.topk(
selectable_k, dim=0, largest=False)
candidate_idxs.append(topk_idxs_per_level + start_idx)
start_idx = end_idx
candidate_idxs = torch.cat(candidate_idxs, dim=0)
# get corresponding iou for the these candidates, and compute the
# mean and std, set mean + std as the iou threshold
candidate_overlaps = overlaps[candidate_idxs, torch.arange(num_gt)]
overlaps_mean_per_gt = candidate_overlaps.mean(0)
overlaps_std_per_gt = candidate_overlaps.std(0)
overlaps_thr_per_gt = overlaps_mean_per_gt + overlaps_std_per_gt
is_pos = candidate_overlaps >= overlaps_thr_per_gt[None, :]
# limit the positive sample's center in gt
for gt_idx in range(num_gt):
candidate_idxs[:, gt_idx] += gt_idx * num_bboxes
ep_bboxes_cx = bboxes_cx.view(1, -1).expand(
num_gt, num_bboxes).contiguous().view(-1)
ep_bboxes_cy = bboxes_cy.view(1, -1).expand(
num_gt, num_bboxes).contiguous().view(-1)
candidate_idxs = candidate_idxs.view(-1)
# calculate the left, top, right, bottom distance between positive
# bbox center and gt side
l_ = ep_bboxes_cx[candidate_idxs].view(-1, num_gt) - gt_bboxes[:, 0]
t_ = ep_bboxes_cy[candidate_idxs].view(-1, num_gt) - gt_bboxes[:, 1]
r_ = gt_bboxes[:, 2] - ep_bboxes_cx[candidate_idxs].view(-1, num_gt)
b_ = gt_bboxes[:, 3] - ep_bboxes_cy[candidate_idxs].view(-1, num_gt)
is_in_gts = torch.stack([l_, t_, r_, b_], dim=1).min(dim=1)[0] > 0.01
is_pos = is_pos & is_in_gts
欢迎访问个人网络日志🌹🌹知行空间🌹🌹
参考资料
相关文章
- C/C++排序算法(四)—— 归并排序和计数排序
- 【算法】给定一个数组,求如果排序之后,相邻两数的最大差值,要求时间复杂度O(N),且要求不能用非基于比较的排序
- 字典序输出全排列算法
- 自动驾驶感知——视觉感知经典算法
- poj2186Popular Cows(Kosaraju算法--有向图的强连通分量的分解)
- 算法学习专栏简介
- 机器学习-有监督学习-集成学习方法(四):Bootstrap->Boosting(提升)方法-->Gradient Boosting(梯度提升)算法--+决策树-->GBDT梯度提升树
- 人工智能-损失函数-优化算法:梯度下降【SGD-->SGDM(梯度动量)-->AdaGrad(动态学习率)-->RMSProp(动态学习率)-->Adam(动态学习率+梯度动量)】、梯度下降优化技巧
- 基于改进量子粒子群算法的电力系统经济调度(Matlab代码实现)
- 第十五章 加密算法实例1--注册登录(消息摘要算法)
- 第四章 消息摘要算法--SHA
- 条件随机场(CRF) - 4 - 学习方法和预测算法(维特比算法)
- 计算年龄算法(周岁虚岁)
- 目标检测5--旷视YOLOX算法介绍
- 3.基于分割的文本检测算法--DBNet++
- ChatGPT 强化学习 Proximal Policy Optimization 近似策略优化算法
- js--算法--定值数组
- 分享天天爱消除算法
- 程序设计与算法--(枚举-完美立方)
- 基于协同过滤算法的推荐
- 蓝桥算法两周训练营--Day2:DP
- 回溯与深度优先算法的关系总结
- opencv-watershed分水岭算法--图像自动分割法
- 【数据结构与算法分析】0基础带你学数据结构与算法分析12--红黑树
- 【数据结构与算法分析】0基础带你学数据结构与算法分析10--树和森林
- 让你一学就会的那些算法知识总结--基础算法 二分
- 【Android 内存优化】Java 内存模型 ( Java 虚拟机内存模型 | 线程私有区 | 共享数据区 | 内存回收算法 | 引用计数 | 可达性分析 )
- 【檀越剑指大厂--算法】链表总结