PyTorch实现ResNet18
PyTorch 实现 Resnet18
2023-06-13 09:11:51 时间
大家好,又见面了,我是你们的朋友全栈君。
ResNet-18结构
基本结点
代码实现
import torch
import torch.nn as nn
from torch.nn import functional as F
class RestNetBasicBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride):
super(RestNetBasicBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
self.bn1 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)
self.bn2 = nn.BatchNorm2d(out_channels)
def forward(self, x):
output = self.conv1(x)
output = F.relu(self.bn1(output))
output = self.conv2(output)
output = self.bn2(output)
return F.relu(x + output)
class RestNetDownBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride):
super(RestNetDownBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride[0], padding=1)
self.bn1 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride[1], padding=1)
self.bn2 = nn.BatchNorm2d(out_channels)
self.extra = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride[0], padding=0),
nn.BatchNorm2d(out_channels)
)
def forward(self, x):
extra_x = self.extra(x)
output = self.conv1(x)
out = F.relu(self.bn1(output))
out = self.conv2(out)
out = self.bn2(out)
return F.relu(extra_x + out)
class RestNet18(nn.Module):
def __init__(self):
super(RestNet18, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
self.bn1 = nn.BatchNorm2d(64)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = nn.Sequential(RestNetBasicBlock(64, 64, 1),
RestNetBasicBlock(64, 64, 1))
self.layer2 = nn.Sequential(RestNetDownBlock(64, 128, [2, 1]),
RestNetBasicBlock(128, 128, 1))
self.layer3 = nn.Sequential(RestNetDownBlock(128, 256, [2, 1]),
RestNetBasicBlock(256, 256, 1))
self.layer4 = nn.Sequential(RestNetDownBlock(256, 512, [2, 1]),
RestNetBasicBlock(512, 512, 1))
self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
self.fc = nn.Linear(512, 10)
def forward(self, x):
out = self.conv1(x)
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
out = self.avgpool(out)
out = out.reshape(x.shape[0], -1)
out = self.fc(out)
return out
用来预测CIFAR-10数据集
数据集
官网链接:CIFAR-10 DATASET
测试代码
import torch
from torch import nn, optim
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader
from restnet18.restnet18 import RestNet18
# 用CIFAR-10 数据集进行实验
def main():
batchsz = 128
cifar_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
]), download=True)
cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)
cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
]), download=True)
cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)
x, label = iter(cifar_train).next()
print('x:', x.shape, 'label:', label.shape)
device = torch.device('cuda')
# model = Lenet5().to(device)
model = RestNet18().to(device)
criteon = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
print(model)
for epoch in range(1000):
model.train()
for batchidx, (x, label) in enumerate(cifar_train):
# [b, 3, 32, 32]
# [b]
x, label = x.to(device), label.to(device)
logits = model(x)
# logits: [b, 10]
# label: [b]
# loss: tensor scalar
loss = criteon(logits, label)
# backprop
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(epoch, 'loss:', loss.item())
model.eval()
with torch.no_grad():
# test
total_correct = 0
total_num = 0
for x, label in cifar_test:
# [b, 3, 32, 32]
# [b]
x, label = x.to(device), label.to(device)
# [b, 10]
logits = model(x)
# [b]
pred = logits.argmax(dim=1)
# [b] vs [b] => scalar tensor
correct = torch.eq(pred, label).float().sum().item()
total_correct += correct
total_num += x.size(0)
# print(correct)
acc = total_correct / total_num
print(epoch, 'test acc:', acc)
if __name__ == '__main__':
main()
运行结果
感觉挺low的,迭代50多次能达到80多的准确率
发布者:全栈程序员栈长,转载请注明出处:https://javaforall.cn/141287.html原文链接:https://javaforall.cn
相关文章
- 深度学习四大框架之争(Tensorflow、Pytorch、Keras和Paddle)
- 计算机视觉中自注意力构建块的PyTorch实现
- 【干货书】深度强化学习Python实战:算法的简洁实现,简化数学,以及TensorFlow和PyTorch的使用
- D2L学习笔记00:Pytorch操作
- pytorch – ohem 代码实现
- pytorch 自定义卷积核进行卷积操作[通俗易懂]
- OHEM的pytorch代码实现细节
- python2.7安装pytorch_PyTorch安装「建议收藏」
- Pytorch(一) Pytorch 的安装[通俗易懂]
- pytorch实现textCNN
- Pytorch实现YOLOv3训练自己的数据集
- CBAM:Convolutional Block Attention Module流程详解及Pytorch实现
- BN层pytorch实现[通俗易懂]
- 联邦学习算法FedPer的PyTorch实现
- 深度学习环境搭建之cuda、cudnn以及pytorch和torchvision的whl文件安装方法
- PyTorch实现的ResNet50、ResNet101和ResNet152
- PyTorch: 计算图与动态图机制
- 超分辨率——基于SRGAN的图像超分辨率重建(Pytorch实现)
- 【Pytorch基础】逻辑回归
- 用少于256KB内存实现边缘训练,开销不到PyTorch千分之一
- 超简单的pyTorch训练->onnx模型->C++ OpenCV DNN推理(附源码地址)
- Windows 版的3D目标检测框架 smoke PyTorch 实现
- 一行代码,炼丹2倍速!PyTorch 2.0惊喜问世,LeCun激情转发
- 快手八卦!突破TensorFlow、PyTorch并行瓶颈的开源分布式训练框架来了!
- 分离硬件和代码、稳定 API,PyTorch Lightning 1.0.0 版本正式发布
- 图像风格迁移也有框架了:使用Python编写,与PyTorch完美兼容,外行也能用