Skip to content

Instantly share code, notes, and snippets.

@zezhishao
Created October 30, 2021 16:44
Show Gist options
  • Select an option

  • Save zezhishao/a337a12b601fc508f1d1a92e75c9d21f to your computer and use it in GitHub Desktop.

Select an option

Save zezhishao/a337a12b601fc508f1d1a92e75c9d21f to your computer and use it in GitHub Desktop.
Pytorch的独热编码
import torch.nn.functional as F
F.one_hot(torch.arange(0, 5) % 3, num_classes=5)
>>> tensor([[1, 0, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [1, 0, 0, 0, 0],
        [0, 1, 0, 0, 0]])

其输入向量必须是torch.LongTenso的整型,IntTensor也不行。 可以支持多维度,但是最好要指定好num_class,尤其是minibatch训练的时候。

a = list(range(34272))
a = [_%288 for _ in a]
aa = torch.tensor(a).unsqueeze(1).expand(-1, 207)
aaa = aa[286:298, :].unsqueeze(0).expand(64, -1, -1)
v = F.one_hot(aaa, num_classes=288)

Usage

import torch
import torch.nn.functional as F
F.one_hot(torch.arange(0, 5) % 3, num_classes=5)
# >>> tensor([[1, 0, 0, 0, 0],
#             [0, 1, 0, 0, 0],
#             [0, 0, 1, 0, 0],
#             [1, 0, 0, 0, 0],
#             [0, 1, 0, 0, 0]])

# 其输入向量必须是torch.LongTenso的整型,IntTensor也不行。可以支持多维度,但是最好要指定好num_class,尤其是minibatch训练的时候。
a = list(range(34272))
a = [_%288 for _ in a]
aa = torch.tensor(a).unsqueeze(1).expand(-1, 207)
aaa = aa[286:298, :].unsqueeze(0).expand(64, -1, -1)
v = F.one_hot(aaa, num_classes=288)
# 切换数据类型的方法:
A = torch.Tensor([1.0, 2.0, 3.0])    # 非LongInt
A.type(torch.LongInt).to(device)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment