@sayef
Created May 4, 2023 10:11
PyTorch Positional Encoding
import math

import torch
from torch import nn, Tensor


class PositionalEncoding(nn.Module):
    """Adds sinusoidal positional information to a batch of token embeddings."""

    def __init__(self, embedding_dim: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Precompute the sinusoidal table once, for every position up to max_len.
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, embedding_dim, 2) * (-math.log(10000.0) / embedding_dim)
        )
        pe = torch.zeros(1, max_len, embedding_dim)
        pe[0, :, 0::2] = torch.sin(position * div_term)  # even dimensions: sine
        pe[0, :, 1::2] = torch.cos(position * div_term)  # odd dimensions: cosine
        # Buffer: moves with the module (e.g. .to(device)) but is not a trainable parameter.
        self.register_buffer("pe", pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Arguments:
            x: Tensor, shape ``[batch_size, seq_len, embedding_dim]``
        """
        x = x + self.pe[:, : x.size(1), :]
        return self.dropout(x)


positional_encoder = PositionalEncoding(embedding_dim=512, max_len=1000)
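
A minimal usage sketch (not part of the original gist): the vocabulary size, batch size, sequence length, and the nn.Embedding layer below are illustrative assumptions.

# Usage sketch: feed embedded tokens through the encoder defined above.
token_ids = torch.randint(0, 10000, (8, 100))                      # [batch_size=8, seq_len=100]
embedding = nn.Embedding(num_embeddings=10000, embedding_dim=512)  # assumed embedding layer
x = embedding(token_ids)                                           # [8, 100, 512]
x = positional_encoder(x)                                          # same shape, positions added
print(x.shape)                                                     # torch.Size([8, 100, 512])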