MLP-Mixer [Code Implementation (Based on MNIST)]


This post implements a demo of MLP-Mixer on the MNIST dataset, based on lucidrains' open-source mlp-mixer-pytorch on GitHub.

lucidrains/mlp-mixer-pytorch

Fig. 1 gives a macro-level schematic of the MLP-Mixer architecture. It takes a sequence of linearly projected image patches (shaped patches × channels) as input. Mixer uses two types of MLP layers (these two types of layers are applied alternately to exchange information between the two dimensions):

  • channel-mixing MLP: communicates across different channels; each token is processed independently, i.e., each row of the input is fed to the MLP;
  • token-mixing MLP: communicates across different spatial locations; each channel is processed independently, i.e., each column of the input is fed to the MLP.

[Fig. 1: MLP-Mixer macro architecture]

As introduced in "GLOM: How to represent part-whole hierarchies in a neural network? [Code Implementation (Based on MNIST)]": one property of Conv1d is that, when kernel_size is 1, it is equivalent to an nn.Linear. Although both nn.Conv1d and nn.Linear can implement an MLP, there is an unstated difference in how they do so.

Suppose we stack 10 pieces of 8-dimensional data into 10 channels, giving a matrix $X \in R^{10 \times 8}$. Feeding $X$ into Conv1d amounts to taking each column of the matrix as input, whereas nn.Linear takes each row as input. The point to watch in lucidrains/mlp-mixer-pytorch is exactly this choice between the two MLP implementations, which realizes the alternating row/column MLP types of the MLP-Mixer paper.
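To make the difference concrete, here is a minimal sketch of my own (not from the repo), assuming a batch containing one such matrix: nn.Linear mixes along the last dimension (the rows), while nn.Conv1d with kernel_size=1 mixes along dim 1 (the columns, i.e., the channels).

import torch
from torch import nn

x = torch.randn(1, 10, 8)                # a batch of one "X in R^{10x8}"

linear = nn.Linear(8, 8)                 # acts on each row (length-8 vectors)
conv = nn.Conv1d(10, 10, kernel_size=1)  # acts on each column (length-10 vectors)

print(linear(x).shape)                   # torch.Size([1, 10, 8])
print(conv(x).shape)                     # torch.Size([1, 10, 8])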

from torch import nn
from functools import partial
from einops.layers.torch import Rearrange, Reduce

class PreNormResidual(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x)) + x

def FeedForward(dim, expansion_factor = 4, dropout = 0., dense = nn.Linear):
    return nn.Sequential(
        dense(dim, dim * expansion_factor),
        nn.GELU(),
        nn.Dropout(dropout),
        dense(dim * expansion_factor, dim),
        nn.Dropout(dropout)
    )

def MLPMixer(*, image_size, channels = 3, patch_size, dim, depth, num_classes, expansion_factor = 4, dropout = 0.):
    assert (image_size % patch_size) == 0, 'image must be divisible by patch size'
    num_patches = (image_size // patch_size) ** 2
    chan_first, chan_last = partial(nn.Conv1d, kernel_size = 1), nn.Linear  # the key to the two alternating (column-wise / row-wise) MLP types

    return nn.Sequential(
        Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
        nn.Linear((patch_size ** 2) * channels, dim),  # patch embedding; channels is 1 for single-channel images such as MNIST, 3 for RGB
        *[nn.Sequential(
            PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),  # token-mixing MLP (mixes across patches)
            PreNormResidual(dim, FeedForward(dim, expansion_factor, dropout, chan_last))            # channel-mixing MLP (mixes across features)
        ) for _ in range(depth)],
        nn.LayerNorm(dim),
        Reduce('b n c -> b c', 'mean'),
        nn.Linear(dim, num_classes)
    )
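As a quick sanity check (a sketch added here, not part of the original repository), a random batch can be pushed through the model; the shapes below assume the channels argument shown in the snippet above:

import torch

mixer = MLPMixer(image_size = 28, channels = 1, patch_size = 7, dim = 14, depth = 3, num_classes = 10)
x = torch.randn(2, 1, 28, 28)   # two fake single-channel 28x28 images
print(mixer(x).shape)           # expected: torch.Size([2, 10])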

Demo on MNIST

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

n_epochs = 30
batch_size_train = 52
batch_size_test = 100
learning_rate = 0.01
momentum = 0.5
log_interval = 10
random_seed = 1
torch.manual_seed(random_seed)
img_height = 28
img_width = 28


from torchvision.datasets import MNIST

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.1307,), (0.3081,))])

trainset = MNIST(root = './data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size_train, shuffle=True, num_workers=2)

testset = MNIST(root = './data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size_test, shuffle=False, num_workers=2)
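Since MNIST images are single-channel, it is worth peeking at one batch from the loader; this small check is added for illustration and is not part of the original demo:

# Quick check (for illustration): MNIST batches are single-channel 28x28 images.
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)   # e.g. torch.Size([52, 1, 28, 28]) torch.Size([52])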


# DEVICE = torch.device("cpu")
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = MLPMixer(
    image_size = 28,
    channels = 1,      # MNIST images are single-channel
    patch_size = 7,
    dim = 14,
    depth = 3,
    num_classes = 10
)

model.to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate)

mse = nn.MSELoss()
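The demo trains with MSE against one-hot targets, as defined in the functions below. For reference, a more conventional choice for classification would be cross-entropy on the raw class indices; the following is only an alternative sketch, not what this post uses:

# Alternative (not used in this demo): cross-entropy directly on class indices,
# which avoids the one-hot conversion done in train()/test() below.
criterion = nn.CrossEntropyLoss()
# loss = criterion(model(data), target)   # target holds class indices 0-9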


# Training function
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        batch_size_train = data.shape[0]
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        pre_out = model(data)
        targ_out = torch.nn.functional.one_hot(target,num_classes=10)
        targ_out = targ_out.view((batch_size_train,10)).float()
        loss = mse(pre_out, targ_out)
        loss.backward()
        optimizer.step()
        if (batch_idx + 1) % 300 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

# Test function
def test(model, device, test_loader):
    model.eval()
    test_loss =0
    with torch.no_grad():
        for data, target in test_loader:
            batch_size_test = data.shape[0]
            data, target = data.to(device), target.to(device)
            pre_out = model(data)
            targ_out = torch.nn.functional.one_hot(target,num_classes=10)
            targ_out = targ_out.view((batch_size_test,10)).float()
            test_loss += mse(pre_out, targ_out).item()  # accumulate the loss of each batch
    
    test_loss /= len(test_loader)  # average over batches (MSELoss already averages within each batch)
    print("\nTest set: Average loss: {:.4f}".format(test_loss))

    
for epoch in range(n_epochs):               
    train(model, DEVICE, train_loader, optimizer, epoch)
    test(model, DEVICE, test_loader)
    torch.save(model.state_dict(), './model.pth')
    torch.save(optimizer.state_dict(), './optimizer.pth')
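The test function above only reports the MSE loss. To back up the classification claim in the note below, a simple accuracy check could be run after training; this is a sketch added here, not part of the original post:

# Hypothetical accuracy evaluation, added for illustration (not in the original demo).
model.eval()
correct = 0
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(DEVICE), target.to(DEVICE)
        pred = model(data).argmax(dim=1)          # predicted class per sample
        correct += (pred == target).sum().item()
print('Test accuracy: {:.2f}%'.format(100.0 * correct / len(test_loader.dataset)))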

Note: MLP-Mixer is viable for the classification task on MNIST, but an autoencoding task on MNIST fails to converge. If you obtain new results on this, please share them with me.