Assume we have two matrices [A1, A2] and two more matrices [W1, W2], where each A has shape [N, D] and each W has shape [D, H].
The goal is to compute A1W1 and A2W2 and concatenate them along the feature dimension (as in MixHop GCN). We can do this with a single batched torch.matmul call, as follows:
import torch
N = 207
D = 64
H = 32
A1 = torch.randn(N, D)
A2 = torch.randn(N, D)
W1 = torch.randn(D, H)
W2 = torch.randn(D, H)
# stack along a new batch dimension; torch.matmul then computes A[i] @ W[i] for each i (batched matrix product, not element-wise)
A = torch.stack([A1, A2], dim=0)
W = torch.stack([W1, W2], dim=0)
result = torch.matmul(A, W) # shape: [2, N, H]
result = result.permute(1, 0, 2) # shape: [N, 2, H]
result = result.reshape(N, A.shape[0]*H) # shape: [N, 2*H], each row is [A1W1 row | A2W2 row]
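A quick sanity check (a sketch, not part of the original snippet): the batched result should match computing each product separately and concatenating with torch.cat. The torch.einsum line is an alternative formulation added here for comparison. The random seed and the allclose tolerance are arbitrary choices.

```python
import torch

# Same shapes as above; the seed only makes the check reproducible
torch.manual_seed(0)
N, D, H = 207, 64, 32
A1, A2 = torch.randn(N, D), torch.randn(N, D)
W1, W2 = torch.randn(D, H), torch.randn(D, H)

# Batched version from above
A = torch.stack([A1, A2], dim=0)                   # [2, N, D]
W = torch.stack([W1, W2], dim=0)                   # [2, D, H]
batched = torch.matmul(A, W)                       # [2, N, H]
batched = batched.permute(1, 0, 2).reshape(N, -1)  # [N, 2*H]

# Reference: two separate products concatenated along the feature dimension
reference = torch.cat([A1 @ W1, A2 @ W2], dim=1)   # [N, 2*H]

# Same computation with einsum (b = batch, n = node, d = input dim, h = output dim);
# writing the output as "nbh" folds the permute into the einsum
via_einsum = torch.einsum("bnd,bdh->nbh", A, W).reshape(N, -1)

print(batched.shape)  # torch.Size([207, 64])
print(torch.allclose(batched, reference, atol=1e-4))
print(torch.allclose(via_einsum, reference, atol=1e-4))
```

The permute before the reshape matters: reshaping the [2, N, H] tensor directly to [N, 2*H] would interleave rows from A1W1 and A2W2 instead of placing them side by side.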