Understanding the FPN Code (with Visualization)
Posted: 2023-09-27 14:20:15
Code source: facebookresearch/maskrcnn-benchmark
I have also uploaded the visualization code. Download it if you need it:
https://download.csdn.net/download/weixin_42899627/84720371
File: maskrcnn_benchmark/modeling/backbone/fpn.py
```python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
import torch.nn.functional as F
from torch import nn


class FPN(nn.Module):
    """
    Module that adds FPN on top of a list of feature maps.
    The feature maps are currently supposed to be in increasing depth
    order, and must be consecutive
    """

    def __init__(
        self, in_channels_list, out_channels, conv_block, top_blocks=None
    ):
        """
        Arguments:
            in_channels_list (list[int]): number of channels for each feature map that
                will be fed
            out_channels (int): number of channels of the FPN representation
            conv_block (callable): builds a conv layer, called here as
                conv_block(in_channels, out_channels, kernel_size[, stride])
            top_blocks (nn.Module or None): if provided, an extra operation will
                be performed on the output of the last (smallest resolution)
                FPN output, and the result will extend the result list
        """
        super(FPN, self).__init__()
        self.inner_blocks = []
        self.layer_blocks = []
        for idx, in_channels in enumerate(in_channels_list, 1):
            # print(idx, in_channels)  # 256 512 1024 2048
            inner_block = "fpn_inner{}".format(idx)
            layer_block = "fpn_layer{}".format(idx)

            if in_channels == 0:
                continue
            # 1x1 conv: maps in_channels to out_channels (256)
            inner_block_module = conv_block(in_channels, out_channels, 1)
            # 3x3 conv on each merged map; the FPN paper adds it to reduce
            # the aliasing effect of upsampling
            layer_block_module = conv_block(out_channels, out_channels, 3, 1)
            # print("inner_block_module, layer_block_module:\n", inner_block_module, layer_block_module, "\n")
            self.add_module(inner_block, inner_block_module)
            self.add_module(layer_block, layer_block_module)
            self.inner_blocks.append(inner_block)
            self.layer_blocks.append(layer_block)
        self.top_blocks = top_blocks

    def forward(self, x):
        """
        Arguments:
            x (list[Tensor]): feature maps for each feature level.
        Returns:
            results (tuple[Tensor]): feature maps after FPN layers.
                They are ordered from highest resolution first.
        """
        # Start from the last (coarsest) level:
        # last_inner: torch.Size([1, 2048, 25, 34]) -> torch.Size([1, 256, 25, 34])
        # getattr returns the submodule registered under that name ("fpn_inner4")
        last_inner = getattr(self, self.inner_blocks[-1])(x[-1])
        print("\nlast_inner:\n", x[-1].shape, last_inner.shape)  # debug print
        results = []
        # 3x3 conv; the channel count stays at 256
        results.append(getattr(self, self.layer_blocks[-1])(last_inner))
        for feature, inner_block, layer_block in zip(
            x[:-1][::-1], self.inner_blocks[:-1][::-1], self.layer_blocks[:-1][::-1]
        ):
            print("\nfeature, inner_block, layer_block:\n", feature.shape, inner_block, layer_block, "\n")
            if not inner_block:
                continue
            # Top-down path: upsample the coarser map by a factor of 2
            inner_top_down = F.interpolate(last_inner, scale_factor=2, mode="nearest")
            # Lateral path: 1x1 conv changes the channel count (calls inner_block_module)
            inner_lateral = getattr(self, inner_block)(feature)
            # TODO use size instead of scale to make it robust to different sizes
            # inner_top_down = F.interpolate(last_inner, size=inner_lateral.shape[-2:], mode="bilinear", align_corners=False)
            # Element-wise addition (feature fusion)
            last_inner = inner_lateral + inner_top_down
            # 3x3 conv to reduce aliasing (calls layer_block_module); inserting at
            # index 0 keeps results ordered from highest resolution first
            results.insert(0, getattr(self, layer_block)(last_inner))

        if isinstance(self.top_blocks, LastLevelP6P7):
            last_results = self.top_blocks(x[-1], results[-1])
            results.extend(last_results)
        elif isinstance(self.top_blocks, LastLevelMaxPool):
            last_results = self.top_blocks(results[-1])
            results.extend(last_results)

        # Debug print; works for any number of output levels
        print("\nresults:", len(results), [r.shape for r in results])
        return tuple(results)


class LastLevelMaxPool(nn.Module):
    def forward(self, x):
        # kernel_size=1, stride=2, padding=0: produces P6 from P5
        return [F.max_pool2d(x, 1, 2, 0)]


# Defined in the same upstream file but omitted from the excerpt;
# FPN.forward references it in an isinstance check, so it must exist.
class LastLevelP6P7(nn.Module):
    """
    This module is used in RetinaNet to generate extra layers, P6 and P7.
    """

    def __init__(self, in_channels, out_channels):
        super(LastLevelP6P7, self).__init__()
        self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)
        self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1)
        for module in [self.p6, self.p7]:
            nn.init.kaiming_uniform_(module.weight, a=1)
            nn.init.constant_(module.bias, 0)
        self.use_P5 = in_channels == out_channels

    def forward(self, c5, p5):
        x = p5 if self.use_P5 else c5
        p6 = self.p6(x)
        p7 = self.p7(F.relu(p6))
        return [p6, p7]
```
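A side note on LastLevelMaxPool: with kernel_size=1 every pooling window holds a single value, so the "max" is just that value, and the layer amounts to keeping every other row and column of P5. Below is a small torch-free check of that reading. `max_pool2d` here is a hypothetical naive helper (no padding, floor mode), not the PyTorch function.

```python
def max_pool2d(grid, kernel, stride):
    """Naive 2-D max pooling on a list of lists (no padding, floor mode)."""
    rows, cols = len(grid), len(grid[0])
    out_rows = (rows - kernel) // stride + 1
    out_cols = (cols - kernel) // stride + 1
    return [
        [
            max(grid[r * stride + i][c * stride + j]
                for i in range(kernel) for j in range(kernel))
            for c in range(out_cols)
        ]
        for r in range(out_rows)
    ]


# A 25x34 map, the size of P5 in this walkthrough, with distinct values
grid = [[r * 100 + c for c in range(34)] for r in range(25)]
pooled = max_pool2d(grid, kernel=1, stride=2)
# kernel_size=1 pooling equals plain subsampling with step 2
print(pooled == [row[::2] for row in grid[::2]], len(pooled), len(pooled[0]))
# True 13 17
```

So P6 here is a 13x17 map that carries no new computation, only a coarser sampling of P5.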
```text
(body): ResNet50    # only the backbone parts relevant to FPN are listed
  image_size: 800, 1066 (height, width)
  (stem):
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    >> max_pool2d(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))    # pooling shrinks the spatial size
    featuremap_size: 200, 272
  (layer1):
    (downsample): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    featuremap_size: 200, 272
  (layer2):
    (downsample): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
    featuremap_size: 100, 136
  (layer3):
    (downsample): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
    featuremap_size: 50, 68
  (layer4):
    (downsample): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
    featuremap_size: 25, 34
(fpn): FPN
  (fpn_inner1): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
  (fpn_layer1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fpn_inner2): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
  (fpn_layer2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fpn_inner3): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
  (fpn_layer3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fpn_inner4): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))
  (fpn_layer4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (top_blocks): LastLevelMaxPool()
```
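The featuremap_size values above can be reproduced with the usual output-size formula floor((n + 2p - k) / s) + 1. One detail: 1066 / 4 is not 272. The numbers only work out if the image is first zero-padded to 800 x 1088, a multiple of 32; as far as I know, the FPN configs in maskrcnn-benchmark do this through `DATALOADER.SIZE_DIVISIBILITY: 32`. The sketch below assumes that padding. The helper names are made up for illustration, and only shapes are computed.

```python
def conv_out(size, kernel, stride, padding):
    """Output length along one axis of a conv/pool (dilation 1, floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


def pad_to_multiple(size, divisor):
    """Round size up to the next multiple of divisor (zero-padding assumed)."""
    return -(-size // divisor) * divisor


def trace_shapes(height, width, divisor=32, fpn_channels=256):
    """Return {level: (channels, H, W)} for ResNet-50 C2..C5 and FPN P2..P6."""
    h, w = pad_to_multiple(height, divisor), pad_to_multiple(width, divisor)
    # stem: conv1 7x7, stride 2, padding 3, then max_pool 3x3, stride 2, padding 1
    h, w = conv_out(h, 7, 2, 3), conv_out(w, 7, 2, 3)
    h, w = conv_out(h, 3, 2, 1), conv_out(w, 3, 2, 1)
    shapes = {"C2": (256, h, w)}  # layer1 keeps the resolution (stride 1)
    for level, channels in (("C3", 512), ("C4", 1024), ("C5", 2048)):
        # layer2..layer4 halve the resolution (stride-2 convs, e.g. the 1x1 downsample)
        h, w = conv_out(h, 1, 2, 0), conv_out(w, 1, 2, 0)
        shapes[level] = (channels, h, w)
    # FPN: the 1x1 lateral and 3x3 output convs keep H, W and set channels to 256
    for c_level, p_level in (("C2", "P2"), ("C3", "P3"), ("C4", "P4"), ("C5", "P5")):
        _, ph, pw = shapes[c_level]
        shapes[p_level] = (fpn_channels, ph, pw)
    # LastLevelMaxPool: max_pool2d(kernel_size=1, stride=2, padding=0) on P5
    _, ph, pw = shapes["P5"]
    shapes["P6"] = (fpn_channels, conv_out(ph, 1, 2, 0), conv_out(pw, 1, 2, 0))
    return shapes


for level, shape in trace_shapes(800, 1066).items():
    print(level, shape)
```

Passing `divisor=1` (no padding) gives a C2 width of 267 instead of 272, which is how you can tell the padding is in play.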
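- inner_block_module: performs the 1x1 convolution (changes the channel count to 256)
- layer_block_module: performs the 3x3 convolution (reduces aliasing)
- inner_block: the name (a string such as "fpn_inner1") of the 1x1 convolution module
- layer_block: the name (a string such as "fpn_layer1") of the 3x3 convolution module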
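Why does inner_block_module change only the channel count? A 1x1 convolution applies the same linear map to the channel vector at every pixel, so H and W are untouched. Here is a minimal torch-free sketch, assuming a nested-list [C][H][W] layout. `conv1x1` is a hypothetical helper; in the model itself this is, for example, fpn_inner4, a Conv2d(2048, 256, kernel_size=1).

```python
def conv1x1(feature, weight, bias):
    """Apply a 1x1 convolution to a [C_in][H][W] nested-list feature map.

    weight is [C_out][C_in] and bias is [C_out]. Each output pixel is
    weight @ (channel vector at that pixel) + bias, so H and W are unchanged
    and only the channel count goes from C_in to C_out.
    """
    c_in, h, w = len(feature), len(feature[0]), len(feature[0][0])
    return [
        [
            [
                bias[o] + sum(weight[o][i] * feature[i][y][x] for i in range(c_in))
                for x in range(w)
            ]
            for y in range(h)
        ]
        for o in range(len(weight))
    ]


# 3 input channels of size 2x2 -> 2 output channels of size 2x2
feature = [[[1, 2], [3, 4]], [[10, 20], [30, 40]], [[100, 200], [300, 400]]]
weight = [[1, 0, 0],   # output channel 0 copies input channel 0
          [1, 1, 1]]   # output channel 1 sums all input channels
bias = [0, 1]
out = conv1x1(feature, weight, bias)
print(out[1])  # [[112, 223], [334, 445]]
```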
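getattr() is a Python built-in that returns the value of a named attribute of an object.
Put simply, it lets you call a function (here, a registered submodule) by its name.
Example: last_inner = getattr(self, self.inner_blocks[-1])(x[-1])
Explanation: x[-1] is the last-level feature map, torch.Size([1, 2048, 25, 34]). self.inner_blocks[-1] is the string "fpn_inner4", so getattr returns the fpn_inner4 module, which applies the 1x1 convolution to x[-1].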
```python
last_inner = getattr(self, self.inner_blocks[-1])(x[-1])  # getattr returns the submodule with that name
print("\nlast_inner:\n", x[-1].shape, last_inner.shape)
```
Output:
```text
last_inner:
torch.Size([1, 2048, 25, 34]) torch.Size([1, 256, 25, 34])
```
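To see the add_module/getattr pattern in isolation, here is a torch-free sketch. TinyFPNNames is a hypothetical class: setattr stands in for nn.Module.add_module, and each "block" only does the shape bookkeeping of a 1x1 conv on an (N, C, H, W) tuple.

```python
class TinyFPNNames:
    """Hypothetical, torch-free stand-in for the name bookkeeping in FPN.__init__."""

    def __init__(self, in_channels_list, out_channels=256):
        self.inner_blocks = []
        for idx, in_channels in enumerate(in_channels_list, 1):
            name = "fpn_inner{}".format(idx)
            # nn.Module.add_module(name, module) registers a submodule under a
            # string name; plain setattr plays that role here
            setattr(self, name, self._make_block(in_channels, out_channels))
            self.inner_blocks.append(name)  # only the *name* goes into the list

    @staticmethod
    def _make_block(in_channels, out_channels):
        def block(shape):
            # Shape bookkeeping of a 1x1 conv: swap the channel dim, keep N, H, W
            n, c, h, w = shape
            if c != in_channels:
                raise ValueError("expected {} channels, got {}".format(in_channels, c))
            return (n, out_channels, h, w)
        return block


m = TinyFPNNames([256, 512, 1024, 2048])
name = m.inner_blocks[-1]   # "fpn_inner4"
block = getattr(m, name)    # look the block up by its name, as FPN.forward does
print(name, block((1, 2048, 25, 34)))  # fpn_inner4 (1, 256, 25, 34)
```

Keeping names in a list and looking modules up with getattr lets the loop in forward walk the levels in reverse order without storing the modules twice.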
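Code for the upsampling, 1x1 convolution, element-wise addition, and 3x3 convolution steps:
```python
# Top-down path: upsample the coarser map by a factor of 2 (nearest neighbour)
inner_top_down = F.interpolate(last_inner, scale_factor=2, mode="nearest")
# 1x1 conv changes the channel count (calls inner_block_module)
inner_lateral = getattr(self, inner_block)(feature)
# Element-wise addition (feature fusion)
last_inner = inner_lateral + inner_top_down
# 3x3 conv to reduce aliasing (calls layer_block_module)
results.insert(0, getattr(self, layer_block)(last_inner))
```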
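The merge step on a single channel, in plain Python: nearest-neighbour upsampling by 2 repeats every pixel into a 2x2 block (what mode="nearest" with scale_factor=2 does), and the fusion is an element-wise sum, which requires both maps to have exactly the same H and W. That requirement is also why the TODO in the code suggests passing size instead of scale_factor. The helper names are hypothetical; real tensors have extra N and C dimensions, and the operation is applied to each channel independently.

```python
def upsample_nearest_2x(grid):
    """Nearest-neighbour 2x upsampling of a 2-D list: each pixel becomes a 2x2 block."""
    out = []
    for row in grid:
        wide = [v for v in row for _ in range(2)]  # repeat each column twice
        out.append(wide)
        out.append(list(wide))                     # repeat each row twice
    return out


def add_maps(a, b):
    """Element-wise sum; both maps must have the same height and width."""
    if len(a) != len(b) or len(a[0]) != len(b[0]):
        raise ValueError(
            "shape mismatch: {}x{} vs {}x{}".format(len(a), len(a[0]), len(b), len(b[0]))
        )
    return [[x + y for x, y in zip(row_a, row_b)] for row_a, row_b in zip(a, b)]


def merge_step(coarse, lateral):
    """One top-down step: upsample the coarser map, then add the lateral map."""
    top_down = upsample_nearest_2x(coarse)
    return add_maps(lateral, top_down)


coarse = [[1, 2], [3, 4]]
lateral = [[10] * 4 for _ in range(4)]
for row in merge_step(coarse, lateral):
    print(row)
# [11, 11, 12, 12]
# [11, 11, 12, 12]
# [13, 13, 14, 14]
# [13, 13, 14, 14]
```

With the sizes in this post a 25x34 map upsamples to 50x68 and matches C4 exactly. With an unpadded 1066-pixel width, C4 would be 67 wide while the upsampled C5 would be 68, and the addition would fail.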
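Visualization results
(Figure) Top: the original image.
Below it: the level-5 feature map (C5 after the 1x1 convolution, i.e. 256 channels). All of the subsequent upsampling starts from this map.
Subplot layout:
- Row 1: the original feature maps of levels 4, 3, and 2 (bottom-up pathway)
- Row 2: the upsampled feature maps (top-down pathway)
- Row 3: the element-wise sum of each original feature map and the corresponding upsampled map (feature fusion)
- Row 4: the fused feature maps after the 3x3 convolution (reduces aliasing)
Note: only the first channel of each level's feature map is shown. All 256 channels could be displayed, but that is hard to present.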
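The visualization script itself is only available from the download link above, so the following is not the author's code. It is a minimal sketch of one common way to turn "the first channel" of a feature map into a displayable grayscale image: take image 0, channel 0, and min-max scale it to 0..255. The helpers are hypothetical and work on nested lists; with a real tensor you would take `feature[0, 0].detach().cpu()` and could pass it straight to matplotlib's `imshow(..., cmap="gray")`, which normalizes by itself.

```python
def first_channel(feature_map):
    """Pick image 0, channel 0 from a nested [N][C][H][W] feature map."""
    return feature_map[0][0]


def to_uint8_image(channel):
    """Min-max scale one 2-D channel to integers in 0..255 for grayscale display."""
    flat = [v for row in channel for v in row]
    lo, hi = min(flat), max(flat)
    if hi == lo:  # constant map: avoid dividing by zero
        return [[0 for _ in row] for row in channel]
    scale = 255.0 / (hi - lo)
    return [[int(round((v - lo) * scale)) for v in row] for row in channel]


# One image, two channels of size 2x2; only channel 0 is displayed
fmap = [[[[0, 17], [34, 51]], [[9, 9], [9, 9]]]]
print(to_uint8_image(first_channel(fmap)))  # [[0, 85], [170, 255]]
```

Normalizing each panel separately makes maps with very different value ranges (e.g. before and after the addition) comparable by eye, at the cost of hiding their absolute magnitudes.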