zl程序教程

您现在的位置是:首页 >  其它

当前栏目

Lenet5网络结构

网络结构
2023-09-27 14:25:47 时间

Lenet5网络是深度学习中最基本的网络结构,开始于90年代,最早是应用于手写数字识别。受限于当时的环境,所以一开始不怎么出名。但是,在2012年,出现了Alexnet,在图像分类领域打败了所有机器学习方法。深度学习开始变得火热。

import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNet5(nn.Module):
    def __init__(self, num_classes, grayscale=False):
        """
        num_classes: 分类的数量
        grayscale:是否为灰度图
        """
        super(LeNet5, self).__init__()

        self.grayscale = grayscale
        self.num_classes = num_classes
        if self.grayscale: # 可以适用单通道和三通道的图像
            in_channels = 1
        else:
            in_channels = 3

        # 卷积神经网络
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 6, kernel_size=5),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(6, 16, kernel_size=5),
            nn.MaxPool2d(kernel_size=2)   # 原始的模型使用的是 平均池化
        )
        # 分类器
        self.classifier = nn.Sequential(
            nn.Linear(16*5*5, 120),  # 这里把第三个卷积当作是全连接层了
            nn.Linear(120, 84),
            nn.Linear(84, num_classes)
        )

    def forward(self, x):
        x = self.features(x) # 输出 16*5*5 特征图
        x = torch.flatten(x, 1) # 展平 (1, 16*5*5)
        logits = self.classifier(x) # 输出 10
        probas = F.softmax(logits, dim=1)
        return logits, probas