@zezhishao
Last active October 30, 2021 16:42
Several Kinds of Multiplication in PyTorch

1. Matrix Multiplication in PyTorch

  • torch.mul()
  • torch.matmul()
  • torch.mm()
  • torch.bmm()

1.0. Broadcast

Broadcasting lets operations accept tensors of different shapes: shapes are aligned from the trailing dimensions, and each pair of dimensions must either be equal or one of them must be 1 (or missing), in which case that dimension is virtually expanded.
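A minimal sketch of the broadcasting rule (shapes chosen arbitrarily for illustration):

```python
import torch

# Shapes are aligned from the trailing dimensions:
# (2, 3, 4) vs (3, 4) -> b is treated as (1, 3, 4) and expanded to (2, 3, 4).
a = torch.randn(2, 3, 4)
b = torch.randn(3, 4)
c = a * b
print(c.shape)  # torch.Size([2, 3, 4])
```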

1.1. torch.mul(input, other, out=None)

Element-wise multiplication. `other` can be a scalar or a tensor of any shape, as long as the two shapes are broadcastable.

# Normal version
import torch
A = torch.randn(3, 4)
B = torch.randn(3, 4)
print(torch.mul(A, B).shape)  # torch.Size([3, 4])

# Broadcast version
import torch
A = torch.randn(2, 3, 4)
B = torch.randn(3, 4)
print(torch.mul(A, B).shape)  # torch.Size([2, 3, 4])

1.2. torch.matmul(input, other, out=None)

This operation also supports broadcasting. Its behavior is more involved and depends on the dimensionality of the inputs:

  • Both arguments are 1-D: returns the dot product.
import torch as th
a = th.randn(5)
b = th.randn(5)
th.matmul(a, b)    # = sum(a * b), a 0-D scalar tensor
  • Both arguments are 2-D: returns the matrix product.
import torch as th
a = th.randn(3, 4)
b = th.randn(4, 5)
th.matmul(a, b)  # shape: [3, 5]
  • First argument 2-D, second 1-D: a dimension of size 1 is appended to the second argument, the matrix product is computed, and the appended dimension is removed.
import torch as th
a = th.randn(2, 2) # 2 x 2
b = th.randn(2)    # 2 ( x 1)
th.matmul(a, b)  # shape: [2]
  • First argument 1-D, second 2-D: a dimension of size 1 is prepended to the first argument, the matrix product is computed, and the prepended dimension is removed.
import torch as th
a = th.randn(2)    # (1 x ) 2
b = th.randn(2, 2) # 2 x 2
th.matmul(a, b)  # shape: [2]
  • Higher-dimensional matmul()

In this case, the multiplication is computed over the last two dimensions of the two arguments; all remaining dimensions are treated as batch dimensions, and broadcasting is applied across all batch dimensions.

For example:

import torch as th
a = th.randn(15, 1, 64, 32)
b = th.randn(3, 32, 16)
th.matmul(a, b).shape    # [15, 3, 64, 16]

1.3. torch.mm(mat1, mat2, out=None)

Computes the matrix product of exactly two 2-D matrices; it does not support broadcasting.
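A short illustration of the 2-D-only restriction (shapes chosen arbitrarily):

```python
import torch

# torch.mm multiplies exactly two 2-D matrices; no broadcasting is performed.
a = torch.randn(3, 4)
b = torch.randn(4, 5)
out = torch.mm(a, b)
print(out.shape)  # torch.Size([3, 5])

# A non-2-D argument raises a RuntimeError rather than broadcasting:
# torch.mm(torch.randn(2, 3, 4), b)  # RuntimeError
```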

1.4. torch.bmm(mat1, mat2, out=None)

Batched torch.mm(). On top of torch.mm's 2-D inputs, mat1 and mat2 each gain a leading batch dimension of the same size; matrix multiplication is applied at each batch index, and broadcasting is not supported.

import torch as th
a = th.randn(64, 16, 32)
b = th.randn(64, 32, 16)
th.bmm(a, b).shape    # [64, 16, 16]

1.5. Summary

For element-wise multiplication, use mul (the * operator is equivalent).

For matrix multiplication, matmul is generally sufficient: both mm and bmm can be expressed with matmul.
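A quick check of this equivalence claim, comparing matmul against mm and bmm on matching shapes (shapes chosen arbitrarily):

```python
import torch

# 2-D case: matmul produces the same result as mm
a, b = torch.randn(3, 4), torch.randn(4, 5)
same_2d = torch.allclose(torch.matmul(a, b), torch.mm(a, b))

# Batched 3-D case with equal batch sizes: matmul produces the same result as bmm
x, y = torch.randn(64, 16, 32), torch.randn(64, 32, 16)
same_3d = torch.allclose(torch.matmul(x, y), torch.bmm(x, y))

print(same_2d, same_3d)  # True True
```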

Comment by @zezhishao (author):

matmul cannot fully replace bmm: matmul requires its inputs to satisfy the broadcasting rules, so bmm's operation cannot be reproduced when they do not.
