torch.bmm() 与 torch.matmul()==>张量的相乘运算
torch.bmm()强制规定维度和大小相同
torch.matmul()没有强制规定维度和大小,可以用利用广播机制进行不同维度的相乘操作
当进行操作的两个tensor都是3D时,两者等同。
torch.bmm()
用法:
torch.bmm(input, mat2, out=None) → Tensor
torch.bmm()是tensor中的一个相乘操作,类似于矩阵中的A*B。
参数:
input,mat2:两个要进行相乘的tensor结构,两者必须是3D维度的,每个维度中的大小是相同的。
output:输出结果
并且相乘的两个矩阵,要满足一定的维度要求:input(p,m,n) * mat2(p,n,a) ->output(p,m,a)。这个要求,可以类比于矩阵相乘。前一个矩阵的列等于后面矩阵的行才可以相乘。
函数作用
计算两个tensor的矩阵乘法,torch.bmm(a,b),tensor a 的size为(b,h1,w),tensor b的size为(b,w,h2),注意两个tensor的维度必须为3.,其中a和b里面的h可以不相等,只要保证w相等就行了。
>>> cc=torch.randn((2,2,5))
>>>print(cc)
tensor([[[ 1.4873, -0.7482, -0.6734, -0.9682, 1.2869],
[ 0.0550, -0.4461, -0.1102, -0.0797, -0.8349]],
[[-0.6872, 1.1920, -0.9732, 0.4580, 0.7901],
[ 0.3035, 0.2022, 0.8815, 0.9982, -1.1892]]])
>>>dd=torch.reshape(cc,(2,5,2))
>>> print(dd)
tensor([[[ 1.4873, -0.7482],
[-0.6734, -0.9682],
[ 1.2869, 0.0550],
[-0.4461, -0.1102],
[-0.0797, -0.8349]],
[[-0.6872, 1.1920],
[-0.9732, 0.4580],
[ 0.7901, 0.3035],
[ 0.2022, 0.8815],
[ 0.9982, -1.1892]]])
>>>e=torch.bmm(cc,dd)
>>> print(e)
tensor([[[ 2.1787, -1.3931],
[ 0.3425, 1.0906]],
[[-0.5754, -1.1045],
[-0.6941, 3.0161]]])
>>> e.size()
torch.Size([2, 2, 2])
torch.matmul()
torch.matmul(input, other, out=None) → Tensor
torch.matmul()也是一种类似于矩阵相乘操作的tensor联乘操作。但是它可以利用python 中的广播机制,处理一些维度不同的tensor结构进行相乘操作。这也是该函数与torch.bmm()区别所在。
参数:
input,other:两个要进行操作的tensor结构
output:结果
一些规则约定:
(1)若两个都是1D(向量)的,则返回两个向量的点积
import torch
x = torch.rand(2)
y = torch.rand(2)
print(torch.matmul(x,y),torch.matmul(x,y).size())
output:
tensor(0.1353) torch.Size([])
(2)若两个都是2D(矩阵)的,则按照(矩阵相乘)规则返回2D
x = torch.rand(2,4)
y = torch.rand(4,3) ###维度也要对应才可以乘
print(torch.matmul(x,y),'\n',torch.matmul(x,y).size())
output:
tensor([[0.9128, 0.8425, 0.7269],
[1.4441, 1.5334, 1.3273]])
torch.Size([2, 3])
(3)若input维度1D,other维度2D,则先将1D的维度扩充到2D(1D的维数前面+1),然后得到结果后再将此维度去掉,得到的与input的维度相同。即使作扩充(广播)处理,input的维度也要和other维度做对应关系。
import torch
x = torch.rand(4) #1D
y = torch.rand(4,3) #2D
print(x.size())
print(y.size())
print(torch.matmul(x,y),'\n',torch.matmul(x,y).size())
### 扩充x =>(,4)
### 相乘x(,4) * y(4,3) =>(,3)
### 去掉1D =>(3)
output:
torch.Size([4])
torch.Size([4, 3])
tensor([0.9600, 0.5736, 1.0430])
torch.Size([3])
(4)若input是2D,other是1D,则返回两者的点积结果。(个人觉得这块也可以理解成给other添加了维度,然后再去掉此维度,只不过维度是(3, )而不是规则(3)中的( ,4)了,但是可能就是因为内部机制不同,所以官方说的是点积而不是维度的升高和下降)
import torch
x = torch.rand(3) #1D
y = torch.rand(4,3) #2D
print(torch.matmul(y,x),'\n',torch.matmul(y,x).size()) #2D*1D
output:
torch.Size([3])
torch.Size([4, 3])
tensor([0.8278, 0.5970, 1.0370, 0.2681])
torch.Size([4])
(5)如果一个维度至少是1D,另外一个大于2D,则返回的是一个批矩阵乘法( a batched matrix multiply)。
(a)若input是1D,other是大于2D的,则类似于规则(3)。
import torch
x = torch.randn(2, 3, 4)
y = torch.randn(3)
print(torch.matmul(y, x),'\n',torch.matmul(y, x).size()) #1D*3D
output:
tensor([[-0.9747, -0.6660, -1.1704, -1.0522],
[ 0.0901, -1.5353, 1.5601, -0.0252]])
torch.Size([2, 4])
(b)若other是1D,input是大于2D的,则类似于规则(4)。
import torch
x = torch.randn(2, 3, 4)
y = torch.randn(4)
print(torch.matmul(x, y),'\n',torch.matmul(x, y).size()) # 3D*1D
output:
tensor([[ 0.6217, -0.1259, -0.2377],
[ 0.6874, 0.0733, 0.1793]])
torch.Size([2, 3])
(c)若input和other都是3D的,则与torch.bmm()函数功能一样。
import torch
x = torch.randn(2,2,4)
y = torch.randn(2,4,5)
print(torch.matmul(x, y).size(),'\n',torch.bmm(x, y).size())
print(torch.equal(torch.matmul(x,y),torch.bmm(x,y)))
output:
torch.Size([2, 2, 5])
torch.Size([2, 2, 5])
True
(d)如果input中某一维度满足可以广播(扩充),那么也是可以进行相乘操作的。例如 input(j,1,n,m)* other (k,m,p) = output(j,k,n,p)。
import torch
x = torch.randn(10,1,2,4)
y = torch.randn(2,4,5)
print(torch.matmul(x, y).size())
output:
torch.Size([10, 2, 2, 5])
这个例子中,可以理解为x中dim=1这个维度可以扩充(广播),y中可以添加一个维度,然后在进行批乘操作。
相关文章
- 【数字图像处理】二值化图像腐蚀运算与膨胀运算
- Java实现 LeetCode 1111 有效括号的嵌套深度(阅读理解题,位运算)
- Java实现 LeetCode 1111 有效括号的嵌套深度(阅读理解题,位运算)
- Java实现 LeetCode 241 为运算表达式设计优先级
- Java实现复数运算
- SQL Server调优系列基础篇(并行运算总结篇二)
- LeetCode-1832. 判断句子是否为全字母句【哈希表,位运算】
- 【STM32H7的DSP教程】第24章 DSP变换运算-傅里叶变换
- 【STM32H7的DSP教程】第23章 DSP辅助运算-math_help中函数的使用
- 【STM32F407的DSP教程】第21章 DSP矩阵运算-加法,减法和逆矩阵
- Atitit 嵌入式系统与pc系统的对比 目录 1. 哈佛结构和冯诺依曼结构 普林斯顿结构区1 2. 中断程序 类库调用1 3. 指令集 三大流程语句 与 运算语句 赋值语句1 4. 异
- Py之cv2:cv2(OpenCV,opencv-python)库的简介、安装、使用方法(常见函数、图像基本运算等)最强详细攻略
- Python从0到1丨图像增强及运算:形态学开运算、闭运算和梯度运算
- Python图像处理丨三种实现图像形态学转化运算模式
- 7 分钟全面了解位运算
- 1318. 或运算的最小翻转次数-c语言函数实现
- shell脚本之求和运算
- shell脚本变量详解(自定义变量、环境变量、变量赋值、变量运算、变量内容替换)
- 运算放大电路的基础(秒懂)