Tensor的操作(合并分割;数学计算;统计;比较)
目录
拼接
cat
torch.cat([a,b], dim=n)
:将tensor按指定的维度拼接。 合并的tensor维度要一样,除了合并以外的其他维度数据量也要一样。
In [1]: import torch
In [2]: a = torch.rand(4,32,8)
In [3]: b = torch.rand(5,32,8)
In [4]: torch.cat([a,b],dim=0).shape # 按dim=0 合并
Out[4]: torch.Size([9, 32, 8])
stack
torch.stack([a,b],dim=n)
:创建一个新的维度。
两个tensor的维度必须一摸一样
In [5]: a1 = torch.rand(4,3,16,32)
In [6]: a2 = torch.rand(4,3,16,32)
In [7]: torch.stack([a1,a2],dim=2).shape
Out[7]: torch.Size([4, 3, 2, 16, 32])
分割
split
a.split([len1,len2],dim=n)
:将tensor在n维上按len1,len2 进行拆分。
In [8]: a = torch.rand(3,32,8)
In [9]: aa,bb = a.split([2,1],dim=0)
In [10]: aa.shape,bb.shape
Out[10]: (torch.Size([2, 32, 8]), torch.Size([1, 32, 8]))
In [11]: aa,bb,cc = a.split(1,dim=0)
In [12]: aa.shape,bb.shape,cc.shape
Out[12]: (torch.Size([1, 32, 8]), torch.Size([1, 32, 8]), torch.Size([1, 32, 8]))
chunk
a.chunl(n,dim=n)
:在n维上将tensor分成n个
In [13]: aa,bb = a.chunk(2,dim=0)
In [14]: aa.shape
Out[14]: torch.Size([2, 32, 8])
In [15]: bb.shape
Out[15]: torch.Size([1, 32, 8])
数学计算
基本加减乘除
+ = (torch.add(a,b))
;
- = (torch.sub(a,b))
;
* = (torch.mul(a,b))
;
/ = (torch.div(a,b))
;
** = (a.pow(2))
自然指数和自然对数
In [21]: a = torch.exp(torch.ones(2,2))
In [22]: a
Out[22]:
tensor([[2.7183, 2.7183],
[2.7183, 2.7183]])
In [23]: torch.log(a)
Out[23]:
tensor([[1., 1.],
[1., 1.]])
In [24]: torch.log2(a)
Out[24]:
tensor([[1.4427, 1.4427],
[1.4427, 1.4427]])
矩阵的乘法
@ = (torch.matmul(a,b))
, 大于2维的tensor做矩阵的乘法,取后面两维计算,前面不变
In [16]: a = torch.rand(4,3,28,64)
In [17]: b = torch.rand(4,3,64,32)
In [18]: torch.matmul(a,b).shape
Out[18]: torch.Size([4, 3, 28, 32])
近似
a.floor()
:向下近似;
a.ceil()
:向上近似;
a.round()
:四舍五入;
a.trunc()
:取整数部分;
a.frac()
:取小数部分.
clamp
a.clamp(n)
:a中不到n的值变成n;
a.clamp(n,m)
:a中超过m的变成m.
In [25]: a = torch.rand(2,3)*10
In [26]: a
Out[26]:
tensor([[7.2193, 0.8880, 8.3444],
[9.9666, 6.6064, 8.0220]])
In [27]: a.max()
Out[27]: tensor(9.9666)
In [28]: a.median()
Out[28]: tensor(7.2193)
In [29]: a.clamp(8)
Out[29]:
tensor([[8.0000, 8.0000, 8.3444],
[9.9666, 8.0000, 8.0220]])
In [30]: a.clamp(1,8)
Out[30]:
tensor([[7.2193, 1.0000, 8.0000],
[8.0000, 6.6064, 8.0000]])
统计
norm(范数)
1-范数
∥
X
∥
1
=
∑
i
=
1
n
∣
a
i
∣
\left \| X \right \|_{1}=\sum_{i=1}^{n}\left | a_{i} \right |
∥X∥1=i=1∑n∣ai∣
a.norm(1,dim=n)
:在第n维求1-范数
2-范数(F-范数)
∥
X
∥
2
=
∑
i
=
1
n
a
i
2
\left \| X \right \|_{2}=\sqrt{\sum_{i=1}^{n} a_{i} ^{2}}
∥X∥2=i=1∑nai2
a.norm(2,dim=n)
:在第n维求2-范数
P-范数
∥
X
∥
2
=
(
∑
i
=
1
n
a
i
p
)
1
/
p
\left \| X \right \|_{2}=(\sum_{i=1}^{n} a_{i} ^{p})^{1/p}
∥X∥2=(i=1∑naip)1/p
a.norm(3,dim=n)
:在第n维求P-范数
mean, sum, min, max, prod
In [2]: a=torch.arange(8).view(2,4).float()
In [3]: a
Out[3]:
tensor([[0., 1., 2., 3.],
[4., 5., 6., 7.]])
In [4]: a.min(),a.max(),a.mean(),a.prod()
Out[4]: (tensor(0.), tensor(7.), tensor(3.5000), tensor(0.))
a.prod
:求累乘
argmin,argmax
如果没有指定维度,那么会将tensor先打平成维度为1的tensor然后返回索引。
In [5]: a.argmin(),a.argmax()
Out[5]: (tensor(0), tensor(7))
In [6]: a=torch.randn(4,10)
In [7]: a
Out[7]:
tensor([[-1.3142, 1.3024, -0.6175, -1.9788, -0.6462, -1.7694, 0.9113, 0.2646,
-1.4072, -0.3260],
[ 1.5882, -0.1465, 0.1068, -0.2997, -1.5346, 0.7147, 1.0094, 0.8033,
0.7738, -0.7427],
[-1.7755, 0.5528, -0.6305, -0.9959, 1.0716, -1.2245, -0.5265, -0.4162,
0.8127, 1.2548],
[ 0.3259, 0.4556, 0.4757, -0.7854, 0.6494, -1.2899, -0.7239, 0.7183,
-0.3027, 0.2722]])
In [8]: a.argmax()
Out[8]: tensor(10)
In [9]: a.argmax(dim=1)
Out[9]: tensor([1, 0, 9, 7])
In [10]: a.argmax(dim=1,keepdim=True)
Out[10]:
tensor([[1],
[0],
[9],
[7]])
Top-k or k-th
In [11]: a.topk(3,dim=1) # 返回最大的三个
Out[11]:
torch.return_types.topk(
values=tensor([[1.3024, 0.9113, 0.2646],
[1.5882, 1.0094, 0.8033],
[1.2548, 1.0716, 0.8127],
[0.7183, 0.6494, 0.4757]]),
indices=tensor([[1, 6, 7],
[0, 6, 7],
[9, 4, 8],
[7, 4, 2]]))
In [12]: a.topk(3,dim=1,largest=False) # 返回最小的三个
Out[12]:
torch.return_types.topk(
values=tensor([[-1.9788, -1.7694, -1.4072],
[-1.5346, -0.7427, -0.2997],
[-1.7755, -1.2245, -0.9959],
[-1.2899, -0.7854, -0.7239]]),
indices=tensor([[3, 5, 8],
[4, 9, 3],
[0, 5, 3],
[5, 3, 6]]))
比较
In [13]: a>0
Out[13]:
tensor([[False, True, False, False, False, False, True, True, False, False],
[ True, False, True, False, False, True, True, True, True, False],
[False, True, False, False, True, False, False, False, True, True],
[ True, True, True, False, True, False, False, True, False, True]])
In [14]: torch.gt(a,0)
Out[14]:
tensor([[False, True, False, False, False, False, True, True, False, False],
[ True, False, True, False, False, True, True, True, True, False],
[False, True, False, False, True, False, False, False, True, True],
[ True, True, True, False, True, False, False, True, False, True]])
In [15]: a!=0
Out[15]:
tensor([[True, True, True, True, True, True, True, True, True, True],
[True, True, True, True, True, True, True, True, True, True],
[True, True, True, True, True, True, True, True, True, True],
[True, True, True, True, True, True, True, True, True, True]])
torch.eq(a,b)
:比较两个tensor是否一样。
相关文章
- 【BZOJ3992】序列统计(动态规划,NTT)
- Linux查看网络连接数,统计网络连接数(netstat、Apache连接数)
- BZOJ3992 : [SDOI2015]序列统计
- 统计一行字符中有多少个单词
- Google Earth Engine(GEE)——河流的可视化操作使用ee.Reducer.anyNonZero()统计结果
- Google Earth Engine——空间统计reduceregion()统计和分析以及csv表格数据的导出
- 【PAT乙级】1038 统计同成绩学生 (20 分)
- 数学即音乐,统计即文学
- 自定义百度统计功能使用帮助文档
- MapReduce案例-关于流量统计的求和分区规约排序操作
- Django实现adminx后台统计外键关联内容数据
- shell 网站相关统计命令
- java 统计字符个数
- (JAVA编程练习):输入一行字符,分别统计出其中英文字母、空格、数字和其它字符的个数。
- Linux 压缩指令及名称意义统计
- 浅析Java8新特性-Stream流操作:Stream概念、常见中间/终止操作符、创建stream的3种方式、串行流/并行流的区分、使用示例(遍历/匹配、过滤、聚合、映射、归约、归集、统计、分区分组、接合、排序、组合/提取、分页、并行、集合转Map、使用并行流注意点)
- LeetCode1109之航班预订统计(相关话题:差分数组)
- 字符串统计
- 统计中的t检验
- 详解NGINX如何统计网站的PV、UV、独立IP