如何正确的理解RPN网络的train和test[通俗易懂]
大家好,又见面了,我是你们的朋友全栈君。
刚开始学Faster RCNN时,遇到些困惑不知其他人有没有: 1、RPN网络训练的输出是什么? 2、RPN网络在train中的作用是什么? 3、RPN网络在test中的作用是什么? 其实这些我们如果不看源码都很难真正理解! 以Faster-RCNN_TF的源码为例,以下代码取自./lib/networks/VGGnet_train.py
#========= RPN ============
#以下代码的先后顺序我调整了一下,便于理解
(self.feed('conv5_3')
.conv(3,3,512,1,1,name='rpn_conv/3x3')
.conv(1,1,len(anchor_scales)*3*2 ,1 , 1, padding='VALID', relu = False, name='rpn_cls_score'))
(self.feed('rpn_conv/3x3')
.conv(1,1,len(anchor_scales)*3*4, 1, 1, padding='VALID', relu = False, name='rpn_bbox_pred'))
.anchor_target_layer(_feat_stride, anchor_scales, name = 'rpn-data' ))
重点:
anchor_target_layer的返回值’rpn-data’,这是一个字典 key分别是:rpn_labels, rpn_bbox_targets, rpn_bbox_inside_weights, rpn_bbox_outside_weights
rpn_labels 是 [1,1,A*height,width],如果把它reshape成[1,A,height,width]会更好理解,即feature map上每一点 都是一个anchor,每个anchor对应A个bbox,如果一个bbox与gt_box的重叠度大于0.7(其实还有一个条件),就认为这个bbox包含一个前景,则 rpn_labels 矩阵中相应位置就设置为1。 gt_box的label不能直接用来做训练的目标(target),在训练中使用rpn_labels作为训练的目标 gt_box的唯一作用就在于判断产生的共A*W*H个bbox哪些属于前景,哪些不属于,将那些属于前景的bbox设置为训练目标去训练rpn_cls_score_reshape。 在test中,正好相反,训练好的网络会产生一个rpn_cls_score_reshape,它可以转化成一个[1,A,height,width]的矩阵 #proposal_layer 产生的[1,A,height,width]个bbox哪些属于前景,哪些属于背景。我们会把属于前景的挑出来, 按照得分排序,取前300个输入后面的fc层,fc层会产生两个输出: 一个是cls_score,用于判断bbox中物体的类型 另一个是bbox_pred,用于微调bbox,使其向gt_box进一步靠近(由于bbox都是从anchor产生的,他们不会和gt_box重合,还需要进一步微调)
rpn_bbox_targets 根据 rpn_labels 我们已经可以挑选出300个bbox,这些bbox都是在[1,W,H,A*4]中根据与gt_box的重合程度挑选出来的,与gt_box并不相同,有一些偏差,这些偏差表示为[dx,dy,dw,dh],这就是rpn_bbox_targets。 因为传进后面全卷积网络的是bbox,与gt_boxes不完全重合,为了使最终的结果更加接近gt_box,还需要进一步微调 而全卷积层的输出bbox_pred就是用于微调的,rpn_bbox_targets就是它训练的目标(target) 损失函数的计算:
# RPN
# classification loss
rpn_cls_score = tf.reshape(self.net.get_output('rpn_cls_score_reshape'),[-1,2])
rpn_label = tf.reshape(self.net.get_output('rpn-data')[0],[-1])
rpn_cls_score = tf.reshape(tf.gather(rpn_cls_score,tf.where(tf.not_equal(rpn_label,-1))),[-1,2])
rpn_label = tf.reshape(tf.gather(rpn_label,tf.where(tf.not_equal(rpn_label,-1))),[-1])
rpn_cross_entropy = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=rpn_cls_score, labels=rpn_label))
# bounding box regression L1 loss
rpn_bbox_pred = self.net.get_output('rpn_bbox_pred')
rpn_bbox_targets = tf.transpose(self.net.get_output('rpn-data')[1],[0,2,3,1])
rpn_bbox_inside_weights = tf.transpose(self.net.get_output('rpn-data')[2],[0,2,3,1])
rpn_bbox_outside_weights = tf.transpose(self.net.get_output('rpn-data')[3],[0,2,3,1])
rpn_smooth_l1 = self._modified_smooth_l1(3.0, rpn_bbox_pred, rpn_bbox_targets, rpn_bbox_inside_weights, rpn_bbox_outside_weights)
rpn_loss_box = tf.reduce_mean(tf.reduce_sum(rpn_smooth_l1, reduction_indices=[1, 2, 3]))
其余代码:
# Loss of rpn_cls & rpn_boxes
(self.feed('rpn_conv/3x3')
.conv(1,1,len(anchor_scales)*3*4, 1, 1, padding='VALID', relu = False, name='rpn_bbox_pred'))
#========= RoI Proposal ============
(self.feed('rpn_cls_score')
.reshape_layer(2,name = 'rpn_cls_score_reshape')
.softmax(name='rpn_cls_prob'))
(self.feed('rpn_cls_prob')
.reshape_layer(len(anchor_scales)*3*2,name = 'rpn_cls_prob_reshape'))
(self.feed('rpn_cls_prob_reshape','rpn_bbox_pred','im_info')
.proposal_layer(_feat_stride, anchor_scales, 'TRAIN',name = 'rpn_rois'))
(self.feed('rpn_rois','gt_boxes')
.proposal_target_layer(n_classes,name = 'roi-data'))
#========= RCNN ============
(self.feed('conv5_3', 'roi-data')
.roi_pool(7, 7, 1.0/16, name='pool_5')
.fc(4096, name='fc6')
.dropout(0.5, name='drop6')
.fc(4096, name='fc7')
.dropout(0.5, name='drop7')
.fc(n_classes, relu=False, name='cls_score')
.softmax(name='cls_prob'))
(self.feed('drop7')
.fc(n_classes*4, relu=False, name='bbox_pred'))
发布者:全栈程序员栈长,转载请注明出处:https://javaforall.cn/152315.html原文链接:https://javaforall.cn
相关文章
- 从云南大象保护,看全光网络如何照亮城市数智未来
- 根据IP地址和子网掩码求网络号、主机号
- 算力网络,到底是如何工作的?
- 网络编程java版简述
- 没有网络的工控现场如何传输文件?教你一招——就近共享
- 如何mount到网络为NAT方式的虚拟机
- NeurIPS 2022 | 如何提高存储、传输效率?参数集约型掩码网络效果显著
- linux网络时间同步详解程序员
- 配置Linux实现网络配置开启(linux打开网络)
- 设置如何快速设置Mac OS WiFi网络(macoswifi)
- Linux快速刷流量开启网络创意之路(Linux刷流量)
- 细节决定成效:Linux网络传输实践(linux网络传输)
- SDN如何取代传统网络基础设施?
- 【硬创邦】跟hoowa学做智能路由(十三):网络音箱之Android篇
- Linux下如何测量网络速度(linux测网速)
- 深入探索Linux网络设置方法(linux如何配置上网)
- 如何使用Linux命令查看网络网关(linux如何查看网关)
- Linux网络设置指南:快速简便的联网方法分享(linux如何联网)
- 轻松学习:Linux 如何安装锐捷网络认证系统(linux安装锐捷)
- Linux下高效的网络编程实践(linux高级网络编程)
- DNS检查Oracle确保安全畅通的网络访问(dns检查oracle)
- 交通运输部:网络货运平台不得恶性低价竞争
- 谈谈新手如何学习PHP网络编程
- 关于.NET/C#/WCF/WPF打造IP网络智能视频监控系统的介绍