【PyTorch】教程:学习基础知识-(7) Optimization
OPTIMIZING MODEL PARAMETERS (模型参数优化)
现在我们有了模型和数据,是时候通过优化数据上的参数来训练了,验证和测试我们的模型。训练一个模型是一个迭代的过程,在每次迭代中,模型会对输出进行猜测,计算猜测数据与真实数据的误差(损失),收集误差对其参数的导数(正如前一节我们看到的那样),并使用梯度下降优化这些参数。
Prerequisite Code ( 先决代码 )
We load the code from the previous sections on
我们从前面的章节中直接加载这些代码。Datasets-DataLoader ,构建模型
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
training_data = datasets.FashionMNIST(
root = "../../data/",
train = True,
download = True,
transform = transforms.ToTensor()
)
test_data = datasets.FashionMNIST(
root = "../../data/",
train = False,
download = True,
transform = transforms.ToTensor()
)
train_dataloader = DataLoader(training_data, batch_size = 32, shuffle = True)
test_dataloader = DataLoader(test_data, batch_size = 32, shuffle = True)
class NeuralNetwork(nn.Module):
def __init__(self):
super(NeuralNetwork, self).__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28 * 28, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 10)
)
def forward(self, x):
out = self.flatten(x)
out = self.linear_relu_stack(out)
return out
model = NeuralNetwork()
Hyperparameters ( 超参数 )
超参数是可调节的参数,允许控制模型优化过程,不同的超参数会影响模型的训练和收敛速度。read more
我们定义如下的超参数进行训练:
- Number of Epochs: 遍历数据集的次数
- Batch Size: 每一次使用的数据集大小,即每一次用于训练的样本数量
- Learning Rate: 每个 batch/epoch 更新模型参数的速度,较小的值会导致较慢的学习速度,而较大的值可能会导致训练过程中不可预测的行为,例如训练抖动频繁,有可能会发散等。
learning_rate = 1e-3
batch_size = 32
epochs = 5
Optimization Loop ( 优化循环 )
我们设置完超参数后,就可以利用优化循环训练和优化模型;优化循环的每次迭代称为一个 epoch, 每个 epoch 包含两个主要部分:
- The Train Loop: 遍历训练数据集并尝试收敛到最优参数。
- The Validation/Test Loop: 验证/测试循环—遍历测试数据集以检查模型性能是否得到改善。
让我们简单地熟悉一下训练循环中使用的一些概念。跳转到前面以查看优化循环的完整实现。
Loss Function ( 损失函数 )
当给出一些训练数据时,我们未经训练的网络可能不会给出正确的答案。 Loss function 衡量的是得到的结果与目标值的不相似程度,是我们在训练过程中想要最小化的 Loss function。为了计算 loss ,我们使用给定数据样本的输入进行预测,并将其与真实的数据标签值进行比较。
常见的损失函数包括nn.MSELoss (均方误差)用于回归任务,nn.NLLLoss(负对数似然)用于分类神经网络。nn.CrossEntropyLoss 结合 nn.LogSoftmax 和 nn.NLLLoss 。
我们将模型的输出 logits 传递给 nn.CrossEntropyLoss ,它将规范化 logits 并计算预测误差。
# Initialize the loss function
loss_fn = nn.CrossEntropyLoss()
Optimizer ( 优化器 )
优化是在每个训练步骤中调整模型参数以减少模型误差的过程。优化算法定义了如何执行这个过程(在这个例子中,我们使用随机梯度下降)。所有优化逻辑都封装在优化器对象中。这里,我们使用 SGD 优化器; 此外,PyTorch 中还有许多不同的优化器,如 ADAM 和 RMSProp ,它们可以更好地用于不同类型的模型和数据。
我们通过注册需要训练的模型参数来初始化优化器,并传入学习率超参数。
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
在训练的循环中,优化分为3个步骤:
- 调用 optimizer.zero_grad() 重置模型参数的梯度,默认情况下,梯度是累加的。为了防止重复计算,我们在每次迭代中显式将他们归零。
- 通过调用 loss.backward() 反向传播预测损失, PyTorch 保存每个参数的损失梯度。
- 一旦我们有了梯度,我们调用 optimizer.step() 在向后传递中收集梯度调整参数。
Full Implementation (完整实现)
我们定义了遍历优化参数代码的 train loop, 以及根据测试数据定义了test loop。
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
## 数据集
training_data = datasets.FashionMNIST(
root="../../data/",
train=True,
download=True,
transform=transforms.ToTensor()
)
test_data = datasets.FashionMNIST(
root="../../data/",
train=False,
download=True,
transform=transforms.ToTensor()
)
## dataloader
train_dataloader = DataLoader(training_data, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=32, shuffle=True)
## 定义神经网络
class NeuralNetwork(nn.Module):
def __init__(self):
super(NeuralNetwork, self).__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28 * 28, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 10)
)
def forward(self, x):
out = self.flatten(x)
out = self.linear_relu_stack(out)
return out
## 实例化模型
model = NeuralNetwork()
## 损失函数
loss_fn = nn.CrossEntropyLoss()
## 优化器
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
## 超参数
learning_rate = 1e-3
batch_size = 32
epochs = 5
## 训练循环
def train_loop(dataloader, model, loss_fn, optimizer):
size = len(dataloader.dataset)
for batch, (X, y) in enumerate(dataloader):
# 计算预测和损失
pred = model(X)
loss = loss_fn(pred, y)
## 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
if batch % 100 == 0:
loss, current = loss.item(), batch * len(X)
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
## 测试循环
def test_loop(dataloader, model, loss_fn):
size = len(dataloader.dataset)
num_batches = len(dataloader)
test_loss, correct = 0, 0
with torch.no_grad():
for X, y in dataloader:
pred = model(X)
test_loss += loss_fn(pred, y).item()
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
test_loss /= num_batches
correct /= size
print(f"Test error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f}\n")
## 训练网络
for t in range(epochs):
print(f"Epoch {t+1}\n-------------------------------")
train_loop(train_dataloader, model, loss_fn, optimizer)
test_loop(test_dataloader, model, loss_fn)
print("Done!")
Output exceeds the size limit. Open the full output data in a text editor
Epoch 1
-------------------------------
loss: 2.286209 [ 0/60000]
loss: 2.276216 [ 3200/60000]
loss: 2.284355 [ 6400/60000]
loss: 2.261775 [ 9600/60000]
loss: 2.263164 [12800/60000]
loss: 2.248842 [16000/60000]
loss: 2.220280 [19200/60000]
loss: 2.189240 [22400/60000]
loss: 2.196523 [25600/60000]
loss: 2.185551 [28800/60000]
loss: 2.168148 [32000/60000]
loss: 2.121402 [35200/60000]
loss: 2.117749 [38400/60000]
loss: 2.077955 [41600/60000]
loss: 2.069381 [44800/60000]
loss: 2.021929 [48000/60000]
loss: 1.964773 [51200/60000]
loss: 1.981333 [54400/60000]
loss: 2.008120 [57600/60000]
Test error:
Accuracy: 50.4%, Avg loss: 1.910086
Epoch 2
...
Test error:
Accuracy: 70.5%, Avg loss: 0.797899
Done!
【参考】
Optimizing Model Parameters — PyTorch Tutorials 1.13.1+cu117 documentation
相关文章
- 《Dreamweaver CS6 完全自学教程》笔记 第一章:网页制作基础知识
- Java基础知识总结之面向对象
- 深度神经网络基础知识
- 代码审计 | Java EE 基础知识
- Java面试|Java基础知识总结一
- C语言基础知识梳理总结[通俗易懂]
- 菜鸟也能懂的 - 音视频基础知识
- PostgreSQL基础知识之SQL操作符实践指南
- nginx应用总结(1)– 基础知识和应用配置梳理详解程序员
- Linux驱动编程入门:基础知识篇(linux驱动基础知识)
- Oracle教程:掌握数据库管理的基础知识(oracle教程)
- 深入理解Oracle用户和实例名的基础知识(oracle用户和实例名)
- 探索Linux中声卡的基础知识(linux查看声卡)
- 「Linux软件开发教程」:学习Linux基础知识与开发技巧,掌握软件开发流程,增强编程实践能力,提高编程效率。(linux软件开发教程)
- 探究MySQL性能:从基础知识到测试实践(mysql的性能测试)
- MySQL中文版下载轻松学习数据库基础知识(mysql下载中文版下载)
- MySQL数据库详细教程,逐步介绍MySQL基础知识SQL语法存储引擎表设计备份与恢复等内容,帮助初学者快速入门MySQL
- [基础知识]Linux新手系列之一
- C#基础知识系列八const和readonly关键字详细介绍
- Go语言基础知识总结(语法、变量、数值类型、表达式、控制结构等)
- AngularJS基础知识