zl程序教程

您现在的位置是:首页 >  其他

当前栏目

PyTorch 加载预训练模型时修改输入通道数

输入PyTorch 模型 修改 加载 通道
2023-09-11 14:20:14 时间

修改 resnet50 的第一层卷积（conv1），使模型接受单通道灰度图输入。做法：将预训练 conv1 的卷积核沿输入通道维（RGB）求和，得到 1 通道卷积核；对于三个通道取值相同的灰度图，新卷积层的输出与原卷积层完全一致。

#!/usr/bin/env python
# _*_ coding:utf-8 _*_
import torch
from torch import nn


def to_grayscale_conv(conv: nn.Conv2d) -> nn.Conv2d:
    """Return a 1-input-channel copy of a (pretrained) convolution layer.

    The new kernel is the sum of ``conv``'s kernels over the input-channel
    axis.  For a single-channel image ``x``, the returned layer therefore
    produces the same output as ``conv`` applied to ``x`` replicated across
    all of ``conv``'s input channels.  Every other hyper-parameter (output
    channels, kernel size, stride, padding, dilation, padding mode, presence
    of a bias) and the device/dtype are copied from ``conv``.  ``conv`` itself
    is not modified.

    NOTE: torchvision's ImageNet preprocessing normalizes each RGB channel
    with its own mean/std, so the equivalence is exact only when the
    replicated gray input is identical across channels after normalization.

    Args:
        conv: the layer to convert; must be an ungrouped ``nn.Conv2d``.

    Returns:
        A new ``nn.Conv2d`` with ``in_channels == 1`` whose parameters are
        fresh leaf tensors (``requires_grad=True``).

    Raises:
        ValueError: if ``conv`` is a grouped convolution, where summing over
            the input channels does not yield an equivalent layer.
    """
    if conv.groups != 1:
        raise ValueError(
            f"grouped convolutions are not supported (groups={conv.groups})"
        )
    gray = nn.Conv2d(
        1,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        bias=conv.bias is not None,
        padding_mode=conv.padding_mode,
        device=conv.weight.device,
        dtype=conv.weight.dtype,
    )
    # Copy without recording autograd history so the new parameters are
    # clean leaves and share no storage with the old layer.
    with torch.no_grad():
        gray.weight.copy_(conv.weight.sum(dim=1, keepdim=True))
        if conv.bias is not None:
            gray.bias.copy_(conv.bias)
    return gray


def main() -> None:
    """Load ImageNet ResNet-50 and make it accept grayscale (1-channel) input."""
    # Imported here so that importing this module (e.g. to reuse
    # ``to_grayscale_conv``) does not require torchvision or a weight download.
    from torchvision.models import resnet50, ResNet50_Weights

    resnet_50 = resnet50(weights=ResNet50_Weights.DEFAULT)
    # A forward pass in train mode would update the pretrained BatchNorm
    # running statistics with the dummy all-zero inputs below.
    resnet_50.eval()

    with torch.no_grad():
        rgb_input = torch.zeros((1, 3, 224, 224))
        print(resnet_50(rgb_input).shape)
    print(resnet_50.conv1)

    resnet_50.conv1 = to_grayscale_conv(resnet_50.conv1)
    print(resnet_50.conv1.weight.shape)
    print(resnet_50.conv1.weight.requires_grad)

    with torch.no_grad():
        gray_input = torch.zeros((1, 1, 224, 224))
        print(resnet_50(gray_input).shape)


if __name__ == "__main__":
    main()

修改 vgg16 的第一层卷积（features[0]），使模型接受单通道灰度图输入：将预训练卷积核沿输入通道维（RGB）求和得到 1 通道卷积核，偏置保持不变；对于三个通道取值相同的灰度图，新卷积层的输出与原卷积层一致。

#!/usr/bin/env python
# _*_ coding:utf-8 _*_
import torch
from torch import nn


def to_grayscale_conv(conv: nn.Conv2d) -> nn.Conv2d:
    """Return a 1-input-channel copy of a (pretrained) convolution layer.

    The new kernel is the sum of ``conv``'s kernels over the input-channel
    axis.  For a single-channel image ``x``, the returned layer therefore
    produces the same output as ``conv`` applied to ``x`` replicated across
    all of ``conv``'s input channels.  Every other hyper-parameter (output
    channels, kernel size, stride, padding, dilation, padding mode, presence
    of a bias) and the device/dtype are copied from ``conv``.  The bias is
    copied unchanged.  ``conv`` itself is not modified.

    NOTE: torchvision's ImageNet preprocessing normalizes each RGB channel
    with its own mean/std, so the equivalence is exact only when the
    replicated gray input is identical across channels after normalization.

    Args:
        conv: the layer to convert; must be an ungrouped ``nn.Conv2d``.

    Returns:
        A new ``nn.Conv2d`` with ``in_channels == 1`` whose parameters are
        fresh leaf tensors (``requires_grad=True``).

    Raises:
        ValueError: if ``conv`` is a grouped convolution, where summing over
            the input channels does not yield an equivalent layer.
    """
    if conv.groups != 1:
        raise ValueError(
            f"grouped convolutions are not supported (groups={conv.groups})"
        )
    gray = nn.Conv2d(
        1,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        bias=conv.bias is not None,
        padding_mode=conv.padding_mode,
        device=conv.weight.device,
        dtype=conv.weight.dtype,
    )
    # Copy without recording autograd history so the new parameters are
    # clean leaves and share no storage with the old layer.
    with torch.no_grad():
        gray.weight.copy_(conv.weight.sum(dim=1, keepdim=True))
        if conv.bias is not None:
            gray.bias.copy_(conv.bias)
    return gray


def main() -> None:
    """Load ImageNet VGG-16 and make it accept grayscale (1-channel) input."""
    # Imported here so that importing this module (e.g. to reuse
    # ``to_grayscale_conv``) does not require torchvision or a weight download.
    from torchvision.models import vgg16, VGG16_Weights

    vgg_16 = vgg16(weights=VGG16_Weights.DEFAULT)
    # Inference mode for the shape probes: disables the classifier's dropout.
    # no_grad() below avoids building autograd graphs for them.
    vgg_16.eval()

    with torch.no_grad():
        rgb_input = torch.zeros((1, 3, 224, 224))
        print(vgg_16(rgb_input).shape)
    print(vgg_16.features[0])

    vgg_16.features[0] = to_grayscale_conv(vgg_16.features[0])
    print(vgg_16.features[0].weight.shape)
    print(vgg_16.features[0].weight.requires_grad)
    print(vgg_16.features[0].bias.requires_grad)

    with torch.no_grad():
        gray_input = torch.zeros((1, 1, 224, 224))
        print(vgg_16(gray_input).shape)


if __name__ == "__main__":
    main()