@zezhishao
Created October 30, 2021 16:49
In Torch, unsqueeze increases GPU memory usage (here: unsqueezing a 2-D tensor before matmul with a 4-D tensor)
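import torch as th

device = th.device("cuda:0")
data1 = th.randn(207, 621).to(device)         # 2-D, shape (207, 621)
data2 = th.randn(64, 13, 621, 32).to(device)  # 4-D, shape (64, 13, 621, 32)

# Situation 1: pass the 2-D tensor directly; matmul broadcasts it over the batch dims (64, 13)
data3 = th.matmul(data1, data2)  # result shape (64, 13, 207, 32); GPU memory used / total: 1175MiB / 11019MiB

# Situation 2: unsqueeze data1 to shape (1, 207, 621) first
# data4 = th.matmul(data1.unsqueeze(0), data2)  # GPU memory used / total: 1519MiB / 11019MiB
# equal = th.all(data3 == data4)  # True: both situations give the same result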
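The memory figures above appear to be device-wide readings in the format printed by nvidia-smi. Such readings also include the CUDA context and blocks cached by PyTorch's allocator. The following is a minimal sketch, assuming PyTorch's th.cuda memory-statistics functions, that measures only the extra memory PyTorch allocates during each matmul call. The helper name peak_matmul_memory is hypothetical. The printed numbers depend on the GPU, the PyTorch version and cuBLAS, so they are not expected to reproduce the gist's readings. On a machine without CUDA, the sketch falls back to CPU and reports 0 bytes.

```python
import torch as th
from typing import Tuple


def peak_matmul_memory(a: th.Tensor, b: th.Tensor) -> Tuple[th.Tensor, int]:
    """Hypothetical helper: run th.matmul(a, b) and return (result, extra peak bytes).

    On CUDA this reads the caching allocator's statistics, so it counts tensors
    PyTorch allocates during the call (the output plus any temporaries) but not
    the CUDA context or cached-but-free blocks. On CPU it returns 0 bytes.
    """
    if a.is_cuda:
        th.cuda.synchronize(a.device)
        th.cuda.reset_peak_memory_stats(a.device)  # peak := currently allocated
        baseline = th.cuda.memory_allocated(a.device)
        out = th.matmul(a, b)
        th.cuda.synchronize(a.device)
        peak = th.cuda.max_memory_allocated(a.device) - baseline
        return out, peak
    return th.matmul(a, b), 0


# Same shapes as the gist; falls back to CPU so the sketch runs anywhere.
device = th.device("cuda:0" if th.cuda.is_available() else "cpu")
data1 = th.randn(207, 621, device=device)
data2 = th.randn(64, 13, 621, 32, device=device)

out_2d, peak_2d = peak_matmul_memory(data1, data2)               # situation 1
out_3d, peak_3d = peak_matmul_memory(data1.unsqueeze(0), data2)  # situation 2

print(out_2d.shape)  # torch.Size([64, 13, 207, 32])
print("extra peak bytes, 2-D input:       ", peak_2d)
print("extra peak bytes, unsqueezed input:", peak_3d)
# Compare with a tolerance rather than exact equality, since different kernels
# may accumulate floating-point sums in a different order.
print(th.allclose(out_2d, out_3d, rtol=1e-4, atol=1e-4))
```

On the same GPU, a clearly larger peak_3d would indicate that the gap comes from tensors PyTorch allocates during the unsqueezed call. If the two values are close, the difference is likely in memory outside these counters. th.cuda.memory_reserved can be checked in the same way to see the caching allocator's reserved total.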