【pytorch】bn
2023-04-18 13:06:28 时间
bn接口定义:
torch.nn.BatchNorm2d:
def init(self,
num_features,
eps=1e-5,
momentum=0.1,
affine=True,
track_running_stats=True)
args:
- momentum:
- 默认为 0.1 。
- 要freeze的时候就设置为0.0(和 tf 里面是反着来的,tf是设置为1.0才能freeze)。
- rack_running_stats:
- 计算running_mean和running_var(即moving_mean和moving_var)。
- 默认为True。
- 当设置为True时:
- train 的时候用当前batch的mean和var,并更新running_mean和running_var。
- eval 的时候用存储的running_mean和running_var,不会更新running_mean和running_var。
- 当设置为False时:
- train 的时候用当前batch的mean和var,不会更新running_mean和running_var。
- eval 的时候用当前batch的mean和var,不会更新running_mean和running_var。
- 此时所有的 xx.xx.bn.running_mean、xx.xx.bn.running_var 和 xx.xx.bn.num_batches_tracked 都会被从 model.state_dict() 里面移除。(这三类缺失值可以从其他state_dict导入来补充)
- 具体参见《Pytorch的BatchNorm层使用中容易出现的问题》。
- affine:
- 是否要“乘上缩放矩阵,加上平移向量”(也就是仿射矩阵)的开关。
- 默认为True。
Note:
- 即使对bn设置了 requires_grad = False ,一旦 model.train() ,bn还是会偷偷开启update( model.eval()模式下就又停止update )。
相关文章
- 【技术种草】cdn+轻量服务器+hugo=让博客“云原生”一下
- CLB运维&运营最佳实践 ---访问日志大洞察
- vnc方式登陆服务器
- 轻松学排序算法:眼睛直观感受几种常用排序算法
- 十二个经典的大数据项目
- 为什么使用 CDN 内容分发网络?
- 大数据——大数据默认端口号列表
- Weld 1.1.5.Final,JSR-299 的框架
- JavaFX 2012:彻底开源
- 提升as3程序性能的十大要点
- 通过凸面几何学进行独立于边际的在线多类学习
- 利用行动影响的规律性和部分已知的模型进行离线强化学习
- ModelLight:基于模型的交通信号控制的元强化学习
- 浅谈Visual Source Safe项目分支
- 基于先验知识的递归卡尔曼滤波的代理人联合状态和输入估计
- 结合网络结构和非线性恢复来提高声誉评估的性能
- 最佳实践丨云开发CloudBase多环境管理实践
- TimeVAE:用于生成多变量时间序列的变异自动编码器
- 具有线性阈值激活的神经网络:结构和算法
- 内网渗透之横向移动 -- 从域外向域内进行密码喷洒攻击