MLP-Mixer [代码实现(基于MNIST)]
本文基于lucidrains在gitHub上开源的mlp-mixer-pytorch实现mlp-mixer在MNIST数据上的demo。
- MLP-Mixer: An all-MLP Architecture for Vision[PDF]
- MLP-Mixer简介和一些想法 [CSDN博客]
- 新坑来了!谷歌提出MLP-Mixer:一种用于视觉的全MLP架构[CSDN博客]
- lucidrains/mlp-mixer-pytorch [pytorch 实现]
lucidrains/mlp-mixer-pytorch
Fig 1给出了MLP-Mixer的宏观建构示意图,它以一系列图像块的线性投影(其形状为patches x channels)作为输入。Mixer采用了两种类型的MLP层(注:这两种类型的层交替执行以促进两个维度见的信息交互):
- channel-mixingMLP:用于不同通道前通讯,每个token独立处理,即采用每一行作为输入;
- token-mixingMLP:用于不同空域位置通讯,每个通道图例处理,即采用每一列作为输入。
在GLOM:如何在神经网络中表示部分-整体层次结构?[代码实现(基于MNIST)]中有介绍:Conv1d的一个特点,就是当kernel_size为1时,它等价于一个nn.Linear。虽然,np.conv1d与np.linear都能实现MLP,但是,它们在实现上有一个没有明说的不同之处。
假设10个8维的数据堆成10个channel,记为 X ∈ R 10 × 8 X\in R^{10\times 8} X∈R10×8的矩阵。将 X X X输入到Conv1d,是相当取矩阵的每列作为输入。而np.linear则是将矩阵的每行作为输入。lucidrains/mlp-mixer-pytorch中实现的相对需要注意的地方就是这两种MLP实现的选择,即达到MLP-Mixer论文中,行列交替两种类型的MLP。
from torch import nn
from functools import partial
from einops.layers.torch import Rearrange, Reduce
class PreNormResidual(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.fn = fn
self.norm = nn.LayerNorm(dim)
def forward(self, x):
return self.fn(self.norm(x)) + x
def FeedForward(dim, expansion_factor = 4, dropout = 0., dense = nn.Linear):
return nn.Sequential(
dense(dim, dim * expansion_factor),
nn.GELU(),
nn.Dropout(dropout),
dense(dim * expansion_factor, dim),
nn.Dropout(dropout)
)
def MLPMixer(*, image_size, patch_size, dim, depth, num_classes, expansion_factor = 4, dropout = 0.):
assert (image_size % patch_size) == 0, 'image must be divisible by patch size'
num_patches = (image_size // patch_size) ** 2
chan_first, chan_last = partial(nn.Conv1d, kernel_size = 1), nn.Linear # 实现行列交替两种MLP的关键
return nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
nn.Linear((patch_size ** 2) * 3, dim), # 当图片channel为1,例如Mnist图片数据,则3要改为1.
*[nn.Sequential(
PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),
PreNormResidual(dim, FeedForward(dim, expansion_factor, dropout, chan_last))
) for _ in range(depth)],
nn.LayerNorm(dim),
Reduce('b n c -> b c', 'mean'),
nn.Linear(dim, num_classes)
)
基于MNIST的demo
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
n_epochs = 30
batch_size_train = 52
batch_size_test = 100
learning_rate = 0.01
momentum = 0.5
log_interval = 10
random_seed = 1
torch.manual_seed(random_seed)
img_height = 28
img_width = 28
from torchvision.datasets import MNIST
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))])
trainset = MNIST(root = './data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size_train, shuffle=True, num_workers=2)
testset = MNIST(root = './data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size_test, shuffle=False, num_workers=2)
# DEVICE = torch.device("cpu")
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = MLPMixer(
image_size = 28,
patch_size = 7,
dim = 14,
depth = 3,
num_classes = 10
)
model.to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate)
mse = nn.MSELoss()
# 定义训练函数
def train(model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
batch_size_train = data.shape[0]
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
pre_out = model(data)
targ_out = torch.nn.functional.one_hot(target,num_classes=10)
targ_out = targ_out.view((batch_size_train,10)).float()
loss = mse(pre_out, targ_out)
loss.backward()
optimizer.step()
if (batch_idx + 1) % 300 == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
# 定义测试函数
def test(model, device, test_loader):
model.eval()
test_loss =0
with torch.no_grad():
for data, target in test_loader:
batch_size_test = data.shape[0]
data, target = data.to(device), target.to(device)
pre_out = model(data)
targ_out = torch.nn.functional.one_hot(target,num_classes=10)
targ_out = targ_out.view((batch_size_test,10)).float()
test_loss += mse(pre_out, targ_out) # 将一批的损失相加
test_loss /= len(test_loader.dataset)
print("nTest set: Average loss: {:.4f}".format(test_loss))
for epoch in range(n_epochs):
train(model, DEVICE, train_loader, optimizer, epoch)
test(model, DEVICE, test_loader)
torch.save(model.state_dict(), './model.pth')
torch.save(optimizer.state_dict(), './optimizer.pth')
注:MLP-Mixer在MNIST数据集上的分类任务是可行的,但是在MNIST数据集上的自编码任务收敛不了。希望有新结论的小伙伴能跟我分享你们的结果。
相关文章
- 【场景削减】基于 Kantorovich 距离的 SBR 算法场景削减研究(Matlab代码实现)
- 【轴承故障检测】滚动轴承中进行基于振动的故障诊断研究(Matlab代码实现)
- 基于小生境粒子群优化算法的考虑光伏波动性的主动配电网有功无功协调优化(Matlab代码实现)
- 基于激励的需求响应计划下弹性微电网的短期可靠性和经济性评估(Matlab代码实现)
- 基于风光储能和需求响应的微电网日前经济调度(Python代码实现)【1】
- 基于改进粒子群算法的混合储能系统容量优化(Matlab代码实现)
- 基于蝙蝠算法实现电力系统经济调度(Matlab代码实现)
- 电力系统电价与温度模型(Matlab代码实现)
- 新型智能优化算法——海鸥优化算法(基于Matlab代码实现)
- 【VRP问题】基于企鹅优化算法求解冷链配送物流车辆调度优化研究(Matlab代码实现)
- 基于改进遗传算法的卡车和两架无人机旅行推销员问题(D2TSP)(Matlab代码实现)
- 基于智能算法的无人机路径规划研究(Matlab代码实现)
- 利用Astar算法实现飞行轨迹的三维规划(基于Matlab代码实现)
- 基于遗传算法和非线性规划的函数寻优算法(Matlab代码实现)
- 利用 MLP(多层感知器)和 RBF(径向基函数)神经网络解决的近似和分类示例问题(Matlab代码实现)
- 【车间调度】基于GA/PSO/SA/ACO/TS优化算法的车间调度比较(Matlab代码实现)
- [ci] 基于1 上文实现拉取代码后能自动触发sonar-runner实现代码扫描评测,job1完成
- 「经典题」JavaScript中的预加载—概念、作用、以及代码实现方式
- 《数学与泛型编程:高效编程的奥秘》一3.3 实现该算法并优化其代码
- java代码连接数据库编码实现
- 第7.24节 Python案例详解:使用property函数定义属性简化属性访问代码实现
- 一致性哈希算法——PHP实现代码
- 基于jQuery实现文字倾斜显示代码
- 一遍记住Java常用的八种排序算法与代码实现
- js金额数字格式化实现代码(三位加逗号处理保留两位置小数)
- RabbitMQ延时队列的详细介绍以及Java代码实现