- torch.mul()
- torch.matmul()
- torch.mm()
- torch.bmm()
Broadcast。
矩阵的element-wise相乘,other可以可以是标量,也可以是任意维度的矩阵,只要能满足broadcast就行。
# Normal version
import torch
A = torch.randn(3,4)
B = torch.randn(3, 4)
print (torch.mul(A,B).shape) # 输出 torch.size([3,4])
# Broadcast version
import torch
A = torch.randn(2,3,4)
B = torch.randn(3, 4)
print (torch.mul(A,B).shape) # 输出 torch.size([2,3,4)本操作也支持Broadcast。其使用比较复杂,情况较多:
- 两个都是一维向量:返回dot product(点积、数量积).
import torch as th
a = th.randn(5)
b = th.randn(5)
torch.matmul(a, b) # = sum(a * b)- 两个都是二维向量:返回矩阵乘法
import torch as th
a = th.randn(3, 4)
b = th.randn(4, 5)
torch.matmul(a, b) # shape: [3, 5]- 前面是二维,后面是一维,给后面添加一维,返回矩阵乘法。
import torch as th
a = th.randn(2, 2) # 2 x 2
b = th.randn(2) # 2 ( x 1)
torch.matmul(a, b) # shape: [2, 1]- 前面是一维,后面是二维,给前面添加一维,也是矩阵乘法。
import torch as th
a = th.randn(2) # (1 x ) 2
b = th.randn(2, 2) # 2 x 2
torch.matmul(a, b) # shape: [2, 1]- 更高维度的$matmal()$
在这种情况下,我们可以认为该乘法使用使用两个参数的后两个维度来计算,其他的维度都可以认为是batch维度,所有的batch维度上都会进行broadcast。
例如:
import torch as th
a = th.randn(15, 1, 64, 32)
b = th.randn(3, 32, 16)
torch.matmul(a, b).shape # [15, 3, 64, 16]一般只用来计算两个二维矩阵的矩阵乘法,而且不支持broadcast操作。
Batched torch.mm()。mat2和mat2在torch.mm的基础上添加第一维batchsize,对应位置作矩阵相乘,且不支持broad cast操作。
import torch as th
a = th.randn(64, 16, 32)
b = th.randn(64, 32, 16)
torch.matmul(a, b).shape # [64, 16, 16]Element wise用mul就行,直接用*也是一样的。
Matrix Mul用MatMul就足够。MM、BMM都可以用MatMul来实现。
Matmul不能替代bmm,matmul必须满足broadcast,bmm的才做不能实现。