【偷偷卷死小伙伴Pytorch20天】-【day6】-【自动微分机制】
自动 机制 小伙伴 微分 偷偷 Day6
2023-09-14 09:14:57 时间
系统教程20天拿下Pytorch
最近和中哥、会哥进行一个小打卡活动,20天pytorch,这是第6天。欢迎一键三连。
文章目录
神经网络通常依赖反向传播求梯度来更新网络参数,求梯度过程通常是一件非常复杂而容易出错的事情。
而深度学习框架可以帮助我们自动地完成这种求梯度运算。
Pytorch一般通过反向传播 backward 方法 实现这种求梯度计算。该方法求得的梯度将存在对应自变量张量的grad属性下。
除此之外,也能够调用torch.autograd.grad 函数来实现求梯度计算。
这就是Pytorch的自动微分机制。
一、利用backward方法求导数
backward 方法通常在一个标量张量上调用,该方法求得的梯度将存在对应自变量张量的grad属性下。
如果调用的张量非标量,则要传入一个和它同形状 的gradient参数张量。
相当于用该gradient参数张量与调用张量作向量点乘,得到的标量结果再反向传播。
1, 标量的反向传播
import numpy as np
import torch
# f(x) = a*x**2 + b*x + c的导数
x = torch.tensor(0.0,requires_grad = True) # x需要被求导
a = torch.tensor(1.0)
b = torch.tensor(-2.0)
c = torch.tensor(1.0)
y = a*torch.pow(x,2) + b*x + c
y.backward()
dy_dx = x.grad
print(dy_dx)
requires_grad = True可以将其理解为开启自变量
2, 非标量的反向传播
import numpy as np
import torch
# f(x) = a*x**2 + b*x + c
x = torch.tensor([[0.0,0.0],[1.0,2.0]],requires_grad = True) # x需要被求导
a = torch.tensor(1.0)
b = torch.tensor(-2.0)
c = torch.tensor(1.0)
y = a*torch.pow(x,2) + b*x + c
gradient = torch.tensor([[1.0,1.0],[1.0,1.0]])
print("x:\n",x)
print("y:\n",y)
y.backward(gradient = gradient)
x_grad = x.grad
print("x_grad:\n",x_grad)
对于非标量来说,是求点积,到最后出来一个标量才能求导
3, 非标量的反向传播可以用标量的反向传播实现
这里也就是解释了 2, 非标量的反向传播
import numpy as np
import torch
# f(x) = a*x**2 + b*x + c
x = torch.tensor([[0.0,0.0],[1.0,2.0]],requires_grad = True) # x需要被求导
a = torch.tensor(1.0)
b = torch.tensor(-2.0)
c = torch.tensor(1.0)
y = a*torch.pow(x,2) + b*x + c
gradient = torch.tensor([[1.0,1.0],[1.0,1.0]])
z = torch.sum(y*gradient)
# y.backward(gradient=gradient)
print("x:",x)
print("y:",y)
z.backward()
x_grad = x.grad
print("x_grad:\n",x_grad)
二、利用autograd.grad函数求导数
import numpy as np
import torch
# f(x) = a*x**2 + b*x + c的导数
x = torch.tensor(0.0,requires_grad = True) # x需要被求导
a = torch.tensor(1.0,requires_grad = True)
b = torch.tensor(-2.0)
c = torch.tensor(1.0)
y = a*torch.pow(x,2) + b*x + c
# create_graph 设置为 True 将允许创建更高阶的导数
dy_dx = torch.autograd.grad(y,x,create_graph=True) # 返回数组
print(dy_dx)
# 求二阶导数
dy2_dx2 = torch.autograd.grad(dy_dx,x)[0]
print(dy2_dx2.data)
import numpy as np
import torch
x1 = torch.tensor(1.0,requires_grad = True) # x需要被求导
x2 = torch.tensor(2.0,requires_grad = True)
y1 = x1*x2
y2 = x1+x2
# 允许同时对多个自变量求导数
(dy1_dx1,dy1_dx2) = torch.autograd.grad(outputs=y1,inputs = [x1,x2],retain_graph = True)
print(dy1_dx1,dy1_dx2)
# 如果有多个因变量,相当于把多个因变量的梯度结果求和
(dy12_dx1,dy12_dx2) = torch.autograd.grad(outputs=[y1,y2],inputs = [x1,x2])
print(dy12_dx1,dy12_dx2)
三、利用自动微分和优化器求最小值
import numpy as np
import torch
# f(x) = a*x**2 + b*x + c的最小值
x = torch.tensor(0.0,requires_grad = True) # x需要被求导
a = torch.tensor(1.0)
b = torch.tensor(-2.0)
c = torch.tensor(1.0)
optimizer = torch.optim.SGD(params=[x],lr = 0.01)
def f(x):
result = a*torch.pow(x,2) + b*x + c
return(result)
for i in range(500):
optimizer.zero_grad()
y = f(x)
y.backward()
optimizer.step()
print("y=",f(x).data,";","x=",x.data)
总结
标量求导和非标量求导及其本质(requires_grad=True)
autograd.grad函数求导数
torch.autograd.grad(y,x,create_graph=True) # 返回数组
(dy12_dx1,dy12_dx2) = torch.autograd.grad(outputs=[y1,y2],inputs = [x1,x2])如果有多个因变量,相当于对梯度求和
梯度下降算法求最小值(梯度清0,求梯度,梯度更新)
相关文章
- Java 版下载必应每日壁纸并自动设置 Windows 系统桌面(改编自 C# 版)
- 办公技巧 SecureCRTPortable如何设置自动保存日志[通俗易懂]
- Oracle Number 自动转Decimal问题修正「建议收藏」
- 高频面试题:谈谈你对 Spring Boot 自动装配机制的理解
- 介绍下 npm 模块安装机制,为什么输入 npm install 就可以自动安装对应的模块?
- Code For Better 谷歌开发者之声——协议栈收发数据(拼接网络包,自动重发,滑动窗口机制)
- Kubernetes 领进门 | traefik 自动签发证书及可视化面板
- 一个合格的服务器自动备份案例,闭环备份机制出错邮件报警
- 实时自动驾驶车辆定位技术概述
- mysql decimal设置默认值0 无效,设置后自动变为null(navicat设置decimal默认值失效问题)
- jQuery 实现富文本的标题自动生成目录
- MM模块自动过账原理及后台配置详解编程语言
- 微软IOC容器Unity简单代码示例3-基于约定的自动注册机制详解编程语言
- MySQL自动生成主键的好处(mysql生成主键)
- 学习Oracle中字段自动增长机制(oracle字段自动增长)
- Linux防火墙自动关闭机制研究(linux防火墙自动关闭)
- 自动清理Redis中Java实现自动过期清理(redisjava过期)
- 机制使用Redis Java过期机制实现自动清理(redisjava过期)
- 处理Java实现Redis中键值对自动过期机制(redisjava过期)
- Oracle自动排序机制——精确准确有序(oracle自动排序)
- Linux下PCI设备自动枚举机制简介(linuxpci枚举)
- 功能Oracle禁用自动维护功能(Oracle关闭自动维护)
- Oracle表行号自动生成机制(oracle中表的序号)
- Oracle数据库中主键的自动生成机制(oracle主键自己生成)
- Oracle中实现自动类型转换的机制(oracle中的自动转换)
- Redis实现自动重连机制(redis重连机制)
- Redis技术实现自动关闭订单(redis订单自动关闭)