@zezhishao
Created October 30, 2021 16:47
PyTorch Batch Diagonal

In PyTorch, `torch.diag` only works on non-batched data. Below is a batched version of `diag`:

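```python
import torch

def matrix_diag(diagonal):
    """
    Batched version of torch.diag.

    diagonal: [..., N], e.g. [batch_size, N]
    returns:  [..., N, N], with each length-N vector on the main diagonal
    """
    N = diagonal.shape[-1]
    shape = diagonal.shape[:-1] + (N, N)
    device, dtype = diagonal.device, diagonal.dtype
    result = torch.zeros(shape, dtype=dtype, device=device)
    # Flat (row-major) position of every element of `result`, laid out in the output shape
    indices = torch.arange(result.numel(), device=device).reshape(shape)
    # Keep only the flat positions on each matrix's main diagonal: [..., N]
    indices = indices.diagonal(dim1=-2, dim2=-1)
    # Scatter the diagonal values into the flattened (contiguous) result
    result.view(-1)[indices] = diagonal
    return result
```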
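To see why the flat-index trick in `matrix_diag` works, here is the same computation unrolled on a small, made-up batch of two length-3 diagonals (the values are arbitrary and only for illustration). With row-major flattening, element `[b, i, i]` of a `[B, N, N]` tensor sits at flat position `b*N*N + i*N + i`. For `N = 3` that gives positions `0, 4, 8` in the first matrix and `9, 13, 17` in the second:

```python
import torch

# Arbitrary example input: a batch of 2 diagonals, each of length N = 3
diagonal = torch.tensor([[1., 2., 3.],
                         [4., 5., 6.]])
N = diagonal.shape[-1]
shape = diagonal.shape[:-1] + (N, N)      # (2, 3, 3)

result = torch.zeros(shape)               # freshly allocated, so contiguous
flat_ids = torch.arange(result.numel()).reshape(shape)

# Pick the flat positions that sit on each matrix's main diagonal
diag_ids = flat_ids.diagonal(dim1=-2, dim2=-1)
print(diag_ids.tolist())  # [[0, 4, 8], [9, 13, 17]]

# Write the diagonal values through a flat view of `result`
result.view(-1)[diag_ids] = diagonal
print(result[1])          # 3x3 matrix with 4, 5, 6 on the diagonal
```

The trick depends on `result` being contiguous so that `view(-1)` is valid. A tensor freshly created by `torch.zeros` always satisfies this.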

@zezhishao (Author) commented:

PyTorch also provides a built-in equivalent, `torch.diag_embed`:

```python
import torch

degree = torch.rand(64, 207)           # [batch_size, num_nodes]
diagonal = torch.diag_embed(degree)    # batch of diagonal matrices
print(diagonal.shape)                  # torch.Size([64, 207, 207])
```
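The following is a quick sanity check, sketched with small arbitrary shapes. `torch.diag_embed` should match the result of applying `torch.diag` to each row and stacking the outputs, and `torch.diagonal(..., dim1=-2, dim2=-1)` should reverse it. Calling `torch.diag` directly on the 2-D batch does something else: it extracts the main diagonal of the `[batch_size, num_nodes]` matrix and does not build one matrix per row. That is the limitation described at the top.

```python
import torch

degree = torch.rand(4, 5)                  # small stand-in for [batch_size, num_nodes]
D = torch.diag_embed(degree)               # [4, 5, 5]

# Reference: build each diagonal matrix with torch.diag, then stack
reference = torch.stack([torch.diag(d) for d in degree])
assert torch.equal(D, reference)

# torch.diagonal over the last two dims recovers the input vectors
assert torch.equal(torch.diagonal(D, dim1=-2, dim2=-1), degree)

# Pitfall: torch.diag on a 2-D tensor extracts its main diagonal instead
print(torch.diag(degree).shape)            # torch.Size([4])

# diag_embed also handles extra leading dimensions
print(torch.diag_embed(torch.rand(2, 3, 5)).shape)  # torch.Size([2, 3, 5, 5])
```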
