zl程序教程

您现在的位置是:首页 >  大数据

当前栏目

【pytorch】requires_grad、volatile、no_grad()==>节点不保存梯度,即不进行反向传播

节点PyTorch 进行 No 保存 反向 梯度 传播
2023-09-14 09:06:09 时间

requires_grad

Variable变量的requires_grad的属性默认为False,若一个节点requires_grad被设置为True,那么所有依赖它的节点的requires_grad都为True。

x=Variable(torch.ones(1))
w=Variable(torch.ones(1),requires_grad=True)
y=x*w
x.requires_grad,w.requires_grad,y.requires_grad
Out[23]: (False, True, True)

y依赖于w,w的requires_grad=True,因此y的requires_grad=True (类似or操作)

volatile

volatile=True是Variable的另一个重要的标识,它能够将所有依赖它的节点全部设为volatile=True,其优先级比requires_grad=True高。因而volatile=True的节点不会求导,即使requires_grad=True,也不会进行反向传播对于不需要反向传播的情景(inference,测试推断),该参数可以实现一定速度的提升,并节省一半的显存,因为其不需要保存梯度

前方高能预警:如果你看完了前面volatile,请及时把它从你的脑海中擦除掉,因为

UserWarning: volatile was removed (Variable.volatile is always False)

该属性已经在0.4版本中被移除了,并提示你可以使用with torch.no_grad()代替该功能

>>> x = torch.tensor([1], requires_grad=True)
>>> with torch.no_grad():
...   y = x * 2
>>> y.requires_grad
False
>>> @torch.no_grad()
... def doubler(x):
...     return x * 2
>>> z = doubler(x)
>>> z.requires_grad
False

即使一个tensor(命名为x)的requires_grad = True,由x得到的新tensor(带有with torch.no_grad())(命名为w-标量)requires_grad也为False,且grad_fn也为None,即不会对w求导。

x = torch.randn(10, 5, requires_grad = True)
y = torch.randn(10, 5, requires_grad = True)
z = torch.randn(10, 5, requires_grad = True)
with torch.no_grad():
    w = x + y + z
    print(w.requires_grad)
    print(w.grad_fn)
print(w.requires_grad)


False
None
False


pytorch笔记:06)requires_grad和volatile_Javis486的专栏-CSDN博客

链接:https://www.jianshu.com/p/1cea017f5d11