Pytorch加载模型时修改输入通道数
2023-09-11 14:20:14 时间
修改resnet50的输入,将输入改为灰度图
#!/usr/bin/env python
# _*_ coding:utf-8 _*_
import torch
from torch import nn
from torchvision.models import resnet50, ResNet50_Weights
resnet_50 = resnet50(weights=ResNet50_Weights.DEFAULT)
rgb_input = torch.zeros((1, 3, 224, 224))
print(resnet_50(rgb_input).shape)
print(resnet_50.conv1)
weight = resnet_50.conv1.weight.sum(dim=1, keepdim=True)
print(weight.shape)
resnet_50.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
resnet_50.conv1.weight = nn.Parameter(weight)
# resnet_50.conv1.weight.data = temp
print(resnet_50.conv1.weight.requires_grad)
gray_input = torch.zeros((1, 1, 224, 224))
print(resnet_50(gray_input).shape)
修改vgg16的输入,将输入改为灰度图
#!/usr/bin/env python
# _*_ coding:utf-8 _*_
import torch
from torch import nn
from torchvision.models import vgg16, VGG16_Weights
vgg_16 = vgg16(weights=VGG16_Weights.DEFAULT)
rgb_input = torch.zeros((1, 3, 224, 224))
print(vgg_16(rgb_input).shape)
print(vgg_16.features[0])
weight = vgg_16.features[0].weight.sum(dim=1, keepdim=True)
bias = vgg_16.features[0].bias
print(weight.shape)
vgg_16.features[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
vgg_16.features[0].weight = nn.Parameter(weight)
vgg_16.features[0].bias = nn.Parameter(bias)
print(vgg_16.features[0].weight.requires_grad)
print(vgg_16.features[0].bias.requires_grad)
gray_input = torch.zeros((1, 1, 224, 224))
print(vgg_16(gray_input).shape)
相关文章
- Google Earth Engine ——数据全解析专辑(世界第 4 版网格化人口 (GPWv4) 修订版30 弧秒1公里格网)人口计数和密度网格的输入单元的平均面积数据集
- js 验证 输入值 全是数字
- Vue - 根据输入关键字过滤数组列表(列表搜索功能)
- 判断输入的数是否为数字,不使用isNaN
- formValidator输入验证、异步验证实例 + licenseImage验证码插件实例应用
- 华为商品管理系统批量更新商品时提示:请至少输入一组国家码和价格
- 华为商品管理系统批量更新商品时提示:请至少输入一组国家码和价格
- SwiftUI 问答之 如何使用帮助将用户输入的值与表格进行比较
- 【STM32】输入捕获实验原理
- 【1.1】shell基本实践——密码输入三次错误则结束
- 输入 1,2,4,5,78 返回 (1, 78, 2, 4, 5, 90) 返回形式:最小值 最大值 其余值 及 总和
- [LeetCode] 1320. Minimum Distance to Type a Word Using Two Fingers 二指输入的的最小距离
- [LeetCode] Two Sum IV - Input is a BST 两数之和之四 - 输入是二叉搜索树