【pytorch】requires_grad、volatile、no_grad()==>节点不保存梯度,即不进行反向传播
2023-09-14 09:06:09 时间
requires_grad
Variable变量的requires_grad的属性默认为False,若一个节点requires_grad被设置为True,那么所有依赖它的节点的requires_grad都为True。
x=Variable(torch.ones(1))
w=Variable(torch.ones(1),requires_grad=True)
y=x*w
x.requires_grad,w.requires_grad,y.requires_grad
Out[23]: (False, True, True)
y依赖于w,w的requires_grad=True,因此y的requires_grad=True (类似or操作)
volatile
volatile=True是Variable的另一个重要的标识,它能够将所有依赖它的节点全部设为volatile=True,其优先级比requires_grad=True高。因而volatile=True的节点不会求导,即使requires_grad=True,也不会进行反向传播,对于不需要反向传播的情景(inference,测试推断),该参数可以实现一定速度的提升,并节省一半的显存,因为其不需要保存梯度
前方高能预警:如果你看完了前面volatile,请及时把它从你的脑海中擦除掉,因为
UserWarning: volatile was removed (Variable.volatile is always False)
该属性已经在0.4版本中被移除了,并提示你可以使用with torch.no_grad()代替该功能
>>> x = torch.tensor([1], requires_grad=True)
>>> with torch.no_grad():
... y = x * 2
>>> y.requires_grad
False
>>> @torch.no_grad()
... def doubler(x):
... return x * 2
>>> z = doubler(x)
>>> z.requires_grad
False
即使一个tensor(命名为x)的requires_grad = True,由x得到的新tensor(带有with torch.no_grad())(命名为w-标量)requires_grad也为False,且grad_fn也为None,即不会对w求导。
x = torch.randn(10, 5, requires_grad = True)
y = torch.randn(10, 5, requires_grad = True)
z = torch.randn(10, 5, requires_grad = True)
with torch.no_grad():
w = x + y + z
print(w.requires_grad)
print(w.grad_fn)
print(w.requires_grad)
False
None
False
pytorch笔记:06)requires_grad和volatile_Javis486的专栏-CSDN博客
链接:https://www.jianshu.com/p/1cea017f5d11
相关文章
- k8s存储节点和POD存储数据
- Java 根据XPATH批量替换XML节点中的值
- minio节点扩展_多节点部署定时任务
- 以太坊私有链搭建_以太坊节点减少
- 多节点、长路径桑基图在线编辑工具上线
- 2023-03-20:给定一个无向图,保证所有节点连成一棵树,没有环, 给定一个正数n为节点数,所以节点编号为0~n-1,那么就一定有n-1条边, 每条边形式为
- 到主节点:Redis从新手走向大师:梦想中的主节点(redis从节点)
- 探究Oracle树形结构中是否存在叶节点(oracle中有叶节点吗)
- 节点Redis集群至少需要3个节点(redis集群至少多少个)
- 数最小化部署Redis集群的最低节点数(redis集群最小节点)
- 节点Redis集群安装3个节点实现高可用(redis集群安装3个)
- Cocos2d-x3.x入门教程(二):Node节点类