
PyTorch: Implementing Model Pruning by Hand


PyTorch pruning

Although PyTorch ships with its own pruning utilities (torch.nn.utils.prune), they are ultimately not as flexible as hand-written pruning code. What follows is a simple first attempt at pruning a model by hand.
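For reference, here is a minimal sketch of the built-in route (torch.nn.utils.prune is PyTorch's official pruning module; the layer below is standalone and purely illustrative):

import torch.nn as nn
import torch.nn.utils.prune as prune

conv = nn.Conv2d(1, 16, kernel_size=5, padding=2)   # a toy layer, not the model below
# L1 unstructured pruning: zero out the 50% smallest-magnitude weights
prune.l1_unstructured(conv, name='weight', amount=0.5)
# Pruning replaces conv.weight with weight_orig (parameter) plus weight_mask (buffer)
print(conv.weight_mask.sum())   # number of surviving weights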

Code:

import torch.nn as nn
import torch
import torch.nn.functional as F
from torch import optim
from torchvision import datasets,transforms
from torch.utils.data import DataLoader

device = torch.device('cpu')
# Load the training set
train_dataset = datasets.MNIST(root='./MNIST/',
                               train=True,                       # training split
                               transform=transforms.ToTensor(),  # convert images to tensors
                               download=True)                    # download if not present
# Load the test set
test_dataset = datasets.MNIST(root='./MNIST/',
                              train=False,                       # test split
                              transform=transforms.ToTensor(),   # convert images to tensors
                              download=True)                     # download if not present

# Batch size (number of samples fed in per step)
batch_size = 64  # train on 64 images at a time

# Wrap the datasets in data loaders
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,  # samples per batch
                          shuffle=True)           # shuffle the data
test_loader = DataLoader(dataset=test_dataset,
                         batch_size=batch_size,   # samples per batch
                         shuffle=True)            # shuffle the data

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # Weights and biases are created by hand (instead of using nn.Conv2d /
        # nn.Linear) so that the pruning logic can manipulate them directly.
        weight_params1 = torch.nn.init.xavier_uniform_(torch.Tensor(16, 1, 5, 5))
        bias_params1 = torch.zeros((16,))
        self.conv1_weight = nn.Parameter(weight_params1)
        self.conv1_bias = nn.Parameter(bias_params1)
        weight_params2 = torch.nn.init.xavier_uniform_(torch.Tensor(32, 16, 5, 5))
        bias_params2 = torch.zeros((32,))
        self.conv2_weight = nn.Parameter(weight_params2)
        self.conv2_bias = nn.Parameter(bias_params2)
        self.fc_weight = nn.Parameter(torch.nn.init.xavier_uniform_(torch.Tensor(10, 32*7*7)))
        self.fc_bias = nn.Parameter(torch.randn((10,)))
        # No per-parameter .to(device) calls are needed here:
        # model.to(device) moves all registered parameters and buffers.

        self.sparsity = 0.5
        # Mask tensors for pruning: 1 = keep the weight, 0 = pruned.
        # Registered as buffers so they are saved in state_dict but not trained.
        self.register_buffer('conv1_mask', torch.ones((16, 1, 5, 5), dtype=torch.uint8))
        self.register_buffer('conv2_mask', torch.ones((32, 16, 5, 5), dtype=torch.uint8))

    def forward(self, x):
        # Update the mask for conv layer 1
        w1 = self.conv1_weight.clone().detach()
        w1 = torch.where(self.conv1_mask == 1, w1, torch.zeros_like(w1))  # zero out already-pruned weights
        w1 = torch.abs(w1)                                                # magnitude as importance score
        sorted_w, _ = torch.sort(w1.view(-1), descending=False)           # ascending sort
        threshold1 = sorted_w[int(sorted_w.size(0) * self.sparsity)]
        self.conv1_mask = w1.ge(threshold1).to(torch.uint8)               # keep weights at or above the threshold
        #print(torch.sum(self.conv1_mask))
        # Update the mask for conv layer 2
        w2 = self.conv2_weight.clone().detach()
        w2 = torch.where(self.conv2_mask == 1, w2, torch.zeros_like(w2))
        w2 = torch.abs(w2)
        sorted_w, _ = torch.sort(w2.view(-1), descending=False)
        threshold2 = sorted_w[int(sorted_w.size(0) * self.sparsity)]
        self.conv2_mask = w2.ge(threshold2).to(torch.uint8)
        # First conv layer: the mask multiplication happens inside the
        # computation graph, so pruned positions get zero gradient on backward
        x = F.conv2d(input=x, weight=self.conv1_weight * self.conv1_mask,
                     bias=self.conv1_bias, stride=1, padding=2)  # (1,28,28) -> (16,28,28)
        x = F.relu(x)
        # Pooling
        x = F.max_pool2d(x, kernel_size=2, stride=2)             # (16,28,28) -> (16,14,14)
        # Second conv layer
        x = F.conv2d(input=x, weight=self.conv2_weight * self.conv2_mask,
                     bias=self.conv2_bias, stride=1, padding=2)  # (16,14,14) -> (32,14,14)
        x = F.relu(x)
        # Pooling
        x = F.max_pool2d(x, kernel_size=2, stride=2)             # (32,14,14) -> (32,7,7)
        x = x.view(x.size(0), -1)                                # flatten to (batch_size, 32*7*7)
        # Fully connected layer
        x = F.linear(x, self.fc_weight, bias=self.fc_bias)
        x = F.softmax(x, dim=1)                                  # probabilities, to pair with the MSE loss below
        return x


model = Net()
model.to(device)
# Loss function (MSE against one-hot targets)
mse_loss = nn.MSELoss()
# Optimizer
LR = 0.01  # learning rate
optimizer = optim.SGD(model.parameters(),lr=LR)

def train_model():
    for i, data in enumerate(train_loader):
        # Each iteration yields one batch of inputs and labels
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        # Forward pass
        out = model(inputs)
        # Convert the labels to one-hot encoding: (64,) -> (64,1) -> (64,10)
        labels = labels.reshape(-1, 1)
        one_hot = torch.zeros(inputs.shape[0], 10, device=device).scatter(1, labels, 1)
        loss = mse_loss(out, one_hot)
        # Zero the gradients
        optimizer.zero_grad()
        # Backpropagate
        loss.backward()
        # Update the weights
        optimizer.step()

def test_model():
    correct = 0
    for i, data in enumerate(test_loader):
        # One batch of test data
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        # Forward pass
        out = model(inputs)
        # Index of the largest output = predicted class
        _, predicted = torch.max(out, 1)
        # Accumulate the number of correct predictions
        correct += (predicted == labels).sum()
    print("Test acc:{0}".format(correct.item() / len(test_dataset)))

for epoch in range(20):
    print('epoch:', epoch)
    train_model()
    test_model()

print(model.conv1_weight*model.conv1_mask)
print(model.conv2_weight*model.conv2_mask)
print(torch.sum(model.conv1_mask))
print(torch.sum(model.conv2_mask))

Workflow overview

The main steps are (a small gradient check follows the list):
1. Create a mask tensor: 0 means the weight at the corresponding position has been pruned, 1 means it has not.
2. On every forward pass, take a detached copy of the weight tensor, use the mask to zero out the weights that were already pruned, and then take absolute values.
3. Flatten the result into a 1-D tensor, sort it in ascending order, and take the k-th smallest value as the pruning threshold (k = number of weights × sparsity).
4. Prune every weight below the threshold, i.e. set its mask entry to 0.
5. Before each convolution, multiply the weight tensor by the mask so that the pruned weights contribute 0, then convolve.
(Note: since the mask multiplication is part of the forward computation graph, the gradients at pruned positions are also multiplied by 0 on the backward pass, which stops pruned weights from being updated.)
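The note above is easy to verify with a standalone check (not part of the post's code; the shapes and values are illustrative):

import torch

w = torch.randn(4, requires_grad=True)
mask = torch.tensor([1, 0, 1, 0], dtype=torch.uint8)
# The mask multiplication is inside the graph, so pruned slots contribute 0
y = (w * mask).sum()
y.backward()
print(w.grad)  # tensor([1., 0., 1., 0.]): zero gradient at pruned positions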

Modification

Next, the fully connected layer is pruned as well, using the Bank-Balanced Sparsity (BBS) strategy: each row of the weight matrix is split into equally sized banks, and fine-grained magnitude pruning is applied within each bank at the same sparsity. This makes BBS a middle ground between unstructured pruning and block pruning; for details see the paper "Efficient and Effective Sparse LSTM on FPGA with Bank-Balanced Sparsity". A toy sketch of the idea comes first, followed by the full code.
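The toy sketch (the values, bank size of 4, and 50% sparsity are illustrative): a row of 8 weights is split into two banks, and each bank keeps exactly its top half by magnitude.

import torch

row = torch.tensor([0.9, 0.1, 0.5, 0.2, 0.05, 0.8, 0.3, 0.7])
bank_size, sparsity = 4, 0.5
banks = row.abs().view(-1, bank_size)          # shape (2, 4): one bank per row
k = bank_size - int(bank_size * sparsity)      # weights kept per bank
idx = torch.topk(banks, k, dim=1).indices      # largest-magnitude entries per bank
mask = torch.zeros_like(banks, dtype=torch.uint8)
mask.scatter_(1, idx, 1)                       # 1 = keep
print(mask.view(-1))  # tensor([1, 0, 1, 0, 0, 1, 0, 1], dtype=torch.uint8)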
The full code:

import torch.nn as nn
import torch
import torch.nn.functional as F
from torch import optim
from torchvision import datasets,transforms
from torch.utils.data import DataLoader

device = torch.device('cpu')
# Load the training set
train_dataset = datasets.MNIST(root='./MNIST/',
                               train=True,                       # training split
                               transform=transforms.ToTensor(),  # convert images to tensors
                               download=True)                    # download if not present
# Load the test set
test_dataset = datasets.MNIST(root='./MNIST/',
                              train=False,                       # test split
                              transform=transforms.ToTensor(),   # convert images to tensors
                              download=True)                     # download if not present

# Batch size (number of samples fed in per step)
batch_size = 64  # train on 64 images at a time

# Wrap the datasets in data loaders
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,  # samples per batch
                          shuffle=True)           # shuffle the data
test_loader = DataLoader(dataset=test_dataset,
                         batch_size=batch_size,   # samples per batch
                         shuffle=True)            # shuffle the data

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # Weights and biases are created by hand so that the pruning logic
        # can manipulate them directly.
        weight_params1 = torch.nn.init.xavier_uniform_(torch.Tensor(16, 1, 5, 5))
        bias_params1 = torch.zeros((16,))
        self.conv1_weight = nn.Parameter(weight_params1)
        self.conv1_bias = nn.Parameter(bias_params1)
        weight_params2 = torch.nn.init.xavier_uniform_(torch.Tensor(32, 16, 5, 5))
        bias_params2 = torch.zeros((32,))
        self.conv2_weight = nn.Parameter(weight_params2)
        self.conv2_bias = nn.Parameter(bias_params2)
        self.fc_weight = nn.Parameter(torch.nn.init.xavier_uniform_(torch.Tensor(10, 32*7*7)))
        self.fc_bias = nn.Parameter(torch.randn((10,)))
        # model.to(device) moves all registered parameters and buffers,
        # so no per-parameter .to(device) calls are needed here.

        self.sparsity = 0.5
        self.bank_size = 32
        # Mask tensors for pruning: 1 = keep the weight, 0 = pruned.
        self.register_buffer('conv1_mask', torch.ones((16, 1, 5, 5), dtype=torch.uint8))
        self.register_buffer('conv2_mask', torch.ones((32, 16, 5, 5), dtype=torch.uint8))
        self.register_buffer('fc_mask', torch.ones((10, 32*7*7), dtype=torch.uint8))

    def forward(self, x):
        # Update the mask for conv layer 1
        w1 = self.conv1_weight.clone().detach()
        w1 = torch.where(self.conv1_mask == 1, w1, torch.zeros_like(w1))  # zero out already-pruned weights
        w1 = torch.abs(w1)                                                # magnitude as importance score
        sorted_w, _ = torch.sort(w1.view(-1), descending=False)
        threshold1 = sorted_w[int(sorted_w.size(0) * self.sparsity)]
        self.conv1_mask = w1.ge(threshold1).to(torch.uint8)
        #print(torch.sum(self.conv1_mask))
        # Update the mask for conv layer 2
        w2 = self.conv2_weight.clone().detach()
        w2 = torch.where(self.conv2_mask == 1, w2, torch.zeros_like(w2))
        w2 = torch.abs(w2)
        sorted_w, _ = torch.sort(w2.view(-1), descending=False)
        threshold2 = sorted_w[int(sorted_w.size(0) * self.sparsity)]
        self.conv2_mask = w2.ge(threshold2).to(torch.uint8)
        # Update the mask for the fully connected layer (bank-balanced)
        w3 = self.fc_weight.clone().detach()
        w3 = torch.where(self.fc_mask == 1, w3, torch.zeros_like(w3))     # zero out already-pruned weights
        w3 = torch.abs(w3)                                                # magnitude as importance score
        for i in range(10):
            for j in range(32*7*7 // self.bank_size):                     # fine-grained pruning inside each bank
                bank_weight = w3[i, j*self.bank_size : (j+1)*self.bank_size]          # current bank
                sorted_w, _ = torch.sort(bank_weight, descending=False)               # ascending sort
                threshold3 = sorted_w[int(self.bank_size * self.sparsity)]            # per-bank threshold
                bank_mask = bank_weight.ge(threshold3).to(torch.uint8)                # keep >= threshold, prune the rest
                self.fc_mask[i, j*self.bank_size : (j+1)*self.bank_size] = bank_mask  # write back into the mask
        #print(torch.sum(self.fc_mask))
        # First conv layer: the mask multiplication is inside the graph,
        # so pruned positions get zero gradient on backward
        x = F.conv2d(input=x, weight=self.conv1_weight * self.conv1_mask,
                     bias=self.conv1_bias, stride=1, padding=2)  # (1,28,28) -> (16,28,28)
        x = F.relu(x)
        # Pooling
        x = F.max_pool2d(x, kernel_size=2, stride=2)             # (16,28,28) -> (16,14,14)
        # Second conv layer
        x = F.conv2d(input=x, weight=self.conv2_weight * self.conv2_mask,
                     bias=self.conv2_bias, stride=1, padding=2)  # (16,14,14) -> (32,14,14)
        x = F.relu(x)
        # Pooling
        x = F.max_pool2d(x, kernel_size=2, stride=2)             # (32,14,14) -> (32,7,7)
        x = x.view(x.size(0), -1)                                # flatten to (batch_size, 32*7*7)
        # Fully connected layer, masked inside the graph as well
        x = F.linear(x, self.fc_weight * self.fc_mask, bias=self.fc_bias)
        x = F.softmax(x, dim=1)
        return x


model = Net()
model.to(device)
# Loss function (MSE against one-hot targets)
mse_loss = nn.MSELoss()
# Optimizer
LR = 0.01  # learning rate
optimizer = optim.SGD(model.parameters(),lr=LR)

def train_model():
    for i, data in enumerate(train_loader):
        # Each iteration yields one batch of inputs and labels
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        # Forward pass
        out = model(inputs)
        # Convert the labels to one-hot encoding: (64,) -> (64,1) -> (64,10)
        labels = labels.reshape(-1, 1)
        one_hot = torch.zeros(inputs.shape[0], 10, device=device).scatter(1, labels, 1)
        loss = mse_loss(out, one_hot)
        # Zero the gradients
        optimizer.zero_grad()
        # Backpropagate
        loss.backward()
        # Update the weights
        optimizer.step()

def test_model():
    correct = 0
    for i, data in enumerate(test_loader):
        # One batch of test data
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        # Forward pass
        out = model(inputs)
        # Index of the largest output = predicted class
        _, predicted = torch.max(out, 1)
        # Accumulate the number of correct predictions
        correct += (predicted == labels).sum()
    print("Test acc:{0}".format(correct.item() / len(test_dataset)))

trained = True

if not trained:
    for epoch in range(20):
        print('epoch:', epoch)
        train_model()
        test_model()
    torch.save(model.state_dict(), "mynet.pth")
else:
    model = Net()
    # Load the trained parameters; the masks are buffers, so they are restored too
    model.load_state_dict(torch.load("mynet.pth"))
    test_model()
    print(torch.sum(model.conv1_mask))
    print(torch.sum(model.conv2_mask))
    print(torch.sum(model.fc_mask))
    wfc = model.fc_mask
    for i in range(10):
        for j in range(32*7*7 // 32):
            print(torch.sum(wfc[i, j*32:j*32+32]))  # surviving weights per bank: balanced by construction
    print(model.fc_weight)
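A possible refinement that is not in the original post: the Python double loop over banks is slow, and an equivalent bank-balanced mask can be computed in one shot with reshape and topk (a sketch; it assumes the row length is divisible by bank_size and ignores ties at the threshold):

import torch

def bank_balanced_mask(w, bank_size=32, sparsity=0.5):
    # Keep the largest-magnitude (1 - sparsity) fraction of each bank.
    rows, cols = w.shape
    banks = w.abs().view(rows, cols // bank_size, bank_size)
    keep = bank_size - int(bank_size * sparsity)   # entries kept per bank
    idx = torch.topk(banks, keep, dim=2).indices   # top-k positions per bank
    mask = torch.zeros_like(banks, dtype=torch.uint8)
    mask.scatter_(2, idx, 1)
    return mask.view(rows, cols)

# Hypothetical usage against the model above:
# model.fc_mask = bank_balanced_mask(model.fc_weight.detach(), bank_size=32, sparsity=0.5)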