@zezhishao
Created October 30, 2021 16:45
PyTorch Element-Wise Matrix Product

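Suppose we have two matrices [A1, A2] and two more matrices [W1, W2], where each A has shape [N, D] and each W has shape [D, H]. The goal is to compute the pairwise products [A1W1, A2W2] and concatenate them along the feature dimension (e.g., as in MixHop GCN). This can be done with a single batched torch.matmul call, as follows: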

import torch
N = 207
D = 64
H = 32
A1 = torch.randn(N, D)
A2 = torch.randn(N, D)
W1 = torch.randn(D, H)
W2 = torch.randn(D, H)
# pairwise (batched) matrix product: stack each list, then multiply A_i by W_i
A  = torch.stack([A1, A2], dim=0)
W  = torch.stack([W1, W2], dim=0)
result = torch.matmul(A, W)                 # shape: [2, N, H]
result = result.permute(1, 0, 2)            # shape: [N, 2, H]
result = result.reshape(N, A.shape[0]*H)    # shape: [N, 2*H], i.e. cat([A1W1, A2W2], dim=1)
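
As a sanity check, the batched result can be compared against the straightforward version that multiplies each pair separately and concatenates. The sketch below is not part of the original snippet: the names `expected`, `batched`, and `via_einsum`, the fixed seed, and the `atol` tolerance are assumptions made for illustration. It also shows `torch.einsum` as one possible alternative that produces the [N, 2, H] layout directly, without the explicit permute.

```python
import torch

torch.manual_seed(0)  # assumption: fixed seed only to make the check repeatable
N, D, H = 207, 64, 32
A1, A2 = torch.randn(N, D), torch.randn(N, D)
W1, W2 = torch.randn(D, H), torch.randn(D, H)

# Reference: multiply each pair separately, then concatenate along features.
expected = torch.cat([A1 @ W1, A2 @ W2], dim=1)    # shape: [N, 2*H]

# Batched version, same steps as the snippet above.
A = torch.stack([A1, A2], dim=0)                   # shape: [2, N, D]
W = torch.stack([W1, W2], dim=0)                   # shape: [2, D, H]
batched = torch.matmul(A, W)                       # shape: [2, N, H]
batched = batched.permute(1, 0, 2).reshape(N, A.shape[0] * H)

# Alternative: einsum writes the output as [N, 2, H] in one call.
via_einsum = torch.einsum("knd,kdh->nkh", A, W).reshape(N, A.shape[0] * H)

# Loose tolerance (an assumption) to absorb float32 rounding differences.
print(torch.allclose(batched, expected, atol=1e-4))
print(torch.allclose(via_einsum, expected, atol=1e-4))
```

The permute before the reshape matters: reshaping the [2, N, H] tensor directly would interleave rows of A1W1 and A2W2 instead of placing them side by side for each of the N rows.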