In PyTorch, torch.diag only works on non-batched data: its input must be a 1-D or 2-D tensor.
Below is a batched version of diag:
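A short sketch of that limitation. A 2-D [batch_size, N] tensor is not rejected. torch.diag treats it as a single matrix and extracts its diagonal, which is easy to miss. The name loop_diag is a hypothetical helper for illustration, not a PyTorch API.

```python
import torch

# 1-D input: builds an N x N matrix with the vector on its main diagonal
v = torch.tensor([1.0, 2.0, 3.0])
print(torch.diag(v).shape)  # torch.Size([3, 3])

# 2-D input: treated as ONE matrix whose diagonal is extracted, so a
# [batch_size, N] tensor is silently NOT expanded row by row
batch = torch.arange(6.0).reshape(2, 3)
print(torch.diag(batch))  # tensor([0., 4.])

# 3-D (or higher) input is rejected outright
try:
    torch.diag(torch.zeros(2, 3, 3))
except RuntimeError as err:
    print("RuntimeError:", err)

# Naive workaround (hypothetical helper): one torch.diag call per row,
# then stack the results; correct, but runs a Python loop over the batch
def loop_diag(diagonal):
    return torch.stack([torch.diag(row) for row in diagonal])

print(loop_diag(batch).shape)  # torch.Size([2, 3, 3])
```

The per-row loop is the baseline that a vectorized batched diag aims to replace.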
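import torch

def matrix_diag(diagonal):
    """
    Batched diag.
    diagonal: [..., N], e.g. [batch_size, N]
    returns:  [..., N, N], with diagonal[..., i] at position (i, i)
    """
    N = diagonal.shape[-1]
    shape = diagonal.shape[:-1] + (N, N)
    device, dtype = diagonal.device, diagonal.dtype
    # Zero-filled output of shape [..., N, N]
    result = torch.zeros(shape, dtype=dtype, device=device)
    # Label every output element with its offset in the flattened buffer
    indices = torch.arange(result.numel(), device=device).reshape(shape)
    # Keep only the offsets on each matrix's main diagonal: shape [..., N]
    indices = indices.diagonal(dim1=-2, dim2=-1)
    # Write the input values into those positions of the flat view
    result.view(-1)[indices] = diagonal
    return result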
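To see why writing into result.view(-1) at those indices works, the sketch below prints the flat offsets that matrix_diag computes for a small [2, 3] input. It then checks them against a closed form that assumes a contiguous buffer: matrix b starts at offset b*N*N, and consecutive diagonal entries are N + 1 elements apart. The contiguity assumption holds for the freshly allocated torch.zeros result, and view(-1) itself requires it.

```python
import torch

batch_size, N = 2, 3
shape = (batch_size, N, N)

# Same construction as in matrix_diag: number every element of the
# [batch_size, N, N] result by its position in the flattened buffer...
flat = torch.arange(batch_size * N * N).reshape(shape)
# ...then keep only the numbers on each matrix's main diagonal
diag_idx = flat.diagonal(dim1=-2, dim2=-1)
print(diag_idx)
# tensor([[ 0,  4,  8],
#         [ 9, 13, 17]])

# Closed form for a contiguous buffer: base offset of matrix b plus
# i * (N + 1) for the i-th diagonal entry (broadcast to [batch_size, N])
b = torch.arange(batch_size).unsqueeze(-1)
i = torch.arange(N)
closed_form = b * N * N + i * (N + 1)
print(torch.equal(diag_idx, closed_form))  # True
```

Because the index tensor has shape [..., N], the same shape as the input, the single indexed assignment fills every diagonal of the batch at once, with no Python loop.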
PyTorch also ships a built-in equivalent: torch.diag_embed.
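A minimal usage sketch of torch.diag_embed, based on its documented signature torch.diag_embed(input, offset=0, dim1=-2, dim2=-1). With the default arguments, a [..., N] input should give the same [..., N, N] result as matrix_diag above. The offset argument and the inverse operation torch.diagonal are shown for completeness.

```python
import torch

x = torch.arange(1.0, 7.0).reshape(2, 3)  # [batch_size=2, N=3]

# One 3x3 diagonal matrix per row: shape [2, 3, 3]
d = torch.diag_embed(x)
print(d.shape)  # torch.Size([2, 3, 3])

# offset=1 fills the first superdiagonal, so each output matrix
# grows to (N + 1) x (N + 1): shape [2, 4, 4]
d_up = torch.diag_embed(x, offset=1)
print(d_up.shape)  # torch.Size([2, 4, 4])

# torch.diagonal over the last two dims undoes diag_embed
recovered = torch.diagonal(d, dim1=-2, dim2=-1)
print(torch.equal(recovered, x))  # True
```

Unless an older PyTorch version must be supported, the built-in is usually the simpler choice over a hand-written indexing trick.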