pytorch函数记录
一、torch.cat
torch的拼接函数,将2个tensor拼接起来
按列拼接
A=torch.ones(2,3) B=2*torch.ones(4,3) A B C=torch.cat((A,B),0)#0表示按行拼接,即一行行拼上,就是直接上下堆接 C
A
tensor([[1., 1., 1.], [1., 1., 1.]])
tensor([[2., 2., 2.], [2., 2., 2.], [2., 2., 2.], [2., 2., 2.]])
C
tensor([[1., 1., 1.], [1., 1., 1.], [2., 2., 2.], [2., 2., 2.], [2., 2., 2.], [2., 2., 2.]])
A=torch.ones(2,3) B=2*torch.ones(2,3) A B C=torch.cat((A,B),1) #按列进行拼接,即一列列拼上 C
A
tensor([[1., 1., 1.], [1., 1., 1.]])
tensor([[2., 2., 2.], [2., 2., 2.]])
tensor([[1., 1., 1., 2., 2., 2.], [1., 1., 1., 2., 2., 2.]])
二、torch中的转置
torch.transpose(input, dim0, dim1, out=None) → Tensor
函数返回输入矩阵input
的转置。交换维度dim0
和dim1
参数:
- input (Tensor) – 输入张量,必填
- dim0 (int) – 转置的第一维,默认0,可选
- dim1 (int) – 转置的第二维,默认1,可选
注意:一次只能转置交换2个维度。
permute 将tensor的维度换位,可以同时交换多个维度。
# 创造二维数据x,dim=0时候2,dim=1时候3 x = torch.randn(2,3) 'x.shape → [2,3]' # 创造三维数据y,dim=0时候2,dim=1时候3,dim=2时候4 y = torch.randn(2,3,4) 'y.shape → [2,3,4]'
# 对于transpose x.transpose(0,1) 'shape→[3,2] ' x.transpose(1,0) 'shape→[3,2] ' y.transpose(0,1) 'shape→[3,2,4]' y.transpose(0,2,1) 'error,操作不了多维' # 对于permute() x.permute(0,1) 'shape→[2,3]' x.permute(1,0) 'shape→[3,2], 注意返回的shape不同于x.transpose(1,0) ' y.permute(0,1) "error 没有传入所有维度数" y.permute(1,0,2) 'shape→[3,2,4]'
合法性不同 torch.transpose(x)合法, x.transpose()合法。 tensor.permute(x)不合法,x.permute()合法。 参考第二点的举例 操作dim不同: transpose()只能一次操作两个维度;permute()可以一次操作多维数据,且必须传入所有维度数,因为permute()的参数是int*。
三、矩阵相乘
1)点乘(Element-wise)
a*a
torch.mul(a,a)
2)点积(常规矩阵相乘)
torch.mm() 强制规定维度和大小相同,维度必须是2个维度
torch.bmm() 强制规定维度和大小相同,维度必须是3个维度。
torch.matmul() 没有强制规定维度和大小,可以用利用广播机制进行不同维度的相乘操作,当进行操作的两个tensor都是3D时,与bmm等同。
四、masked_fill(mask, value)
其中的参数mask必须是一个 ByteTensor ,而且shape的最大维度必须和 a一样 并且元素只能是 0或者1,是将 mask中为1的 元素所在的索引,在a中相同的的索引处替换为 value。
import torch
a=torch.tensor([[1,2,3],[4,5,6]])
print(a.size())
mask = torch.ByteTensor([[1],[0]])
print(mask.size())
b = a.masked_fill(mask, value=torch.tensor(-1e3))
print(a)
print(b)
torch.Size([2, 3]) torch.Size([2, 1]) tensor([[1, 2, 3], [4, 5, 6]]) tensor([[-1000, -1000, -1000], [ 4, 5, 6]])
五、expand函数
注意:只能对维度值为1的维度进行扩展
六、view函数
在pytorch中view函数的作用为重构张量的维度,相当于numpy中resize()的功能,但是用法可能不太一样。
七.常用简单函数记录
1.zero_()函数
将Tensor清零,是对已有Tensor的操作,可用于梯度清零;
而zeros()是创建值为0的新的Tensor,两者不要混淆了。
2.Tensor.data
相当于重新复制了一份tensor,与原tensor内存不共享
3.Tensor.detach()
跟tensor.data一样也是开辟了一块新内容,可以代替t.data使用,主要作用是将该tensor参数从网络中隔离开,不再参与参数更新,相当于从这个点截断了,不再向后继续反向传播了
4.torch.randn(1,8)
随机生成标准正态分布的数据,数据的shape=1*8
5.repeat函数
相关文章
- linux用gpu运行pytorch框架
- pytorch的两个函数 .detach() .detach_() 的作用和区别
- PyTorch 学习笔记(六):PyTorch的十八个损失函数
- Pytorch线性回归测试
- 神经网络架构PYTORCH-几个概念
- pytorch学习:准备自己的图片数据
- Pytorch速成教程(二) 常用函数
- 【转载】 浅谈PyTorch的可重复性问题(如何使实验结果可复现)
- 【官网链接】 REPRODUCIBILITY —— pytorch的可复现性
- 【转载】 Pytorch(1) pytorch中的BN层的注意事项
- 【PyTorch教程】02-如何获取张量的形状维度、数据类型和所在设备
- 【PyTorch教程】03-张量运算详细总结
- 【动手学深度学习】基于pytorch的深度学习本地环境搭建
- pytorch下可训练分段函数的写法
- pytorch实现topk剪枝
- pytorch中反向传播和求导的理解
- 使用pytorch-lightning漂亮地进行深度学习研究(转)
- 利用pytorch构建一个完整的自定义神经网络
- Pytorch下二进制语义分割Focal Loss的实现
- PyTorch初学者指南:数据操作
- pytorch将其他类型转为tensor torch.as_tensor()、torch.from_numpy()
- pytorch view函数
- pytorch学习---dataset