zl程序教程

您现在的位置是:首页 >  后端

当前栏目

Python视觉深度学习系列教程 第一卷 第18章 Checkpointing模型

Python教程学习 系列 模型 深度 视觉 18
2023-09-14 09:01:35 时间

         第一卷 第十八章 Checkpointing模型

        在第13章中,我们讨论了如何在训练完成后将模型保存和序列化到磁盘。在上一章中,我们学习了如何在发生时发现欠拟合和过拟合,从而使您能够终止表现不佳的实验,同时保持模型在训练时表现出的希望。

        但是,您可能想知道是否可以将这两种策略结合起来。每当我们的损失/准确性提高时,我们可以序列化模型吗?或者是否可以仅序列化训练过程中的最佳模型(即损失最低或准确率最高的模型)?你打赌。幸运的是,我们也不必构建自定义回调——此功能已直接集成到Keras中。

        1、检查点神经网络模型改进

        检查点的一个很好的应用是在训练期间每次有改进时将您的网络序列化到磁盘。我们将“改进”定义为损失的减少或准确性的提高——我们将在实际的Keras回调中设置此参数。

        在这个例子中,我们将在CIFAR-10数据集上训练MiniVGGNet架构,然后每次模型性能提高时将我们的网络权重序列化到磁盘。

        首先,打开一个新文件,将其命名为cifar10_checkpoint_improvements.py,并插入以下代码:

# import the necessary packages
from sklearn.preprocessing import LabelBinariz