GPT-OSS with Flash Attention and Memory-Efficient Attention via PyTorch-Native SDPA

from time import time
import warnings

import torch
import transformers


def sdpa_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float,
    dropout: float | None = 0.,
    **_
):
    """
    Grouped query attention with sinks using `torch`-native scaled dot product attention. Let
    - N be the batch size,
    - H be the total number of attention heads,
    - G be the number of groups, and
    - E be the per-head embedding dimensionality.

    Parameters
    ----------
    module: torch.nn.Module
        Attention module with a `sinks` attribute, a `torch.Tensor` of shape H
    query: torch.Tensor
        Query tensor of shape N x H x L x E
    key: torch.Tensor
        Key tensor of shape N x G x L x E
    value: torch.Tensor
        Value tensor of shape N x G x L x E
    attention_mask: torch.Tensor | None
        Attention mask of shape N x 1 x L x L
    scaling: float
        Scaling factor applied to query-key dot products, typically 1 / sqrt(E)
    dropout: float | None = 0.
        Dropout probability

    Returns
    -------
    torch.Tensor
        Attention output tensor of shape N x L x H x E
    None
        Unused attention weights
    """
    N, H, L, E = query.shape
    _, G, *_ = key.shape
    # With no caller-provided mask, start from an all-zero additive mask (attend
    # everywhere); an explicit mask is still needed below to carry the sink column.
    if attention_mask is None:
        attention_mask = torch.zeros(N, 1, L, L, device=query.device, dtype=query.dtype)
    attention = torch.nn.functional.scaled_dot_product_attention(
        query,
        # Expand the grouped key/value heads to all H query heads and append one
        # zero-valued "sink" position along the sequence dimension.
        *(
            torch.cat(
                [
                    tensor.repeat_interleave(H // G, dim=1),
                    torch.zeros(N, H, 1, E, device=tensor.device, dtype=tensor.dtype)
                ],
                dim=2
            )
            for tensor in (key, value)
        ),
        # Append each head's learned sink logit as an extra mask column so the
        # softmax also normalizes over the sink position.
        torch.cat(
            [
                attention_mask.expand(N, H, L, L).clone(),
                module.sinks.reshape(1, H, 1, 1).expand(N, H, L, 1)
            ],
            dim=3
        ),
        dropout_p=dropout or 0.,
        # An explicit mask (carrying the sink column) is always passed, so
        # `is_causal` must stay False; any causal structure is expected to come
        # from `attention_mask` itself.
        is_causal=False,
        scale=scaling
    )
    # (N, H, L, E) -> (N, L, H, E); the appended sink position only absorbs softmax
    # mass and contributes nothing to the output, since its value vector is zero.
    return attention.transpose(1, 2).contiguous(), None
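
# Optional shape sanity check (not part of the original gist): run the sink-aware
# SDPA forward on small random tensors. The sizes below (N=2, H=8, G=2, L=16, E=64)
# are illustrative assumptions, not GPT-OSS values.
from types import SimpleNamespace

_N, _H, _G, _L, _E = 2, 8, 2, 16, 64
_out, _ = sdpa_attention_forward(
    SimpleNamespace(sinks=torch.zeros(_H)),  # stand-in module exposing only `sinks`
    torch.randn(_N, _H, _L, _E),
    torch.randn(_N, _G, _L, _E),
    torch.randn(_N, _G, _L, _E),
    None,
    scaling=_E ** -0.5
)
assert _out.shape == (_N, _L, _H, _E)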


# Suppress the SDPA backend-selection warnings that torch.cuda emits.
warnings.filterwarnings('ignore', category=UserWarning, module='torch.cuda')
# Allow the flash-attention and memory-efficient SDPA kernels to be selected.
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
# Route the 'sdpa' attention implementation to the sink-aware forward above, pair it
# with a dense (eager-style) additive mask, and mark GPT-OSS as supporting SDPA.
transformers.AttentionInterface.register('sdpa', sdpa_attention_forward)
transformers.AttentionMaskInterface.register('sdpa', transformers.masking_utils.eager_mask)
transformers.models.gpt_oss.modeling_gpt_oss.GptOssPreTrainedModel._supports_sdpa = True
# Load GPT-OSS-20B with the custom SDPA attention; the MXFP4 checkpoint is
# dequantized to bfloat16.
model = transformers.AutoModelForCausalLM.from_pretrained(
    'openai/gpt-oss-20b',
    attn_implementation='sdpa',
    dtype=torch.bfloat16,
    quantization_config=transformers.Mxfp4Config(dequantize=True)
)
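
# Minimal usage sketch (assumed, not part of the original gist): run a short
# generation to exercise the SDPA path end to end. The prompt, token budget, and
# decoding settings below are illustrative.
tokenizer = transformers.AutoTokenizer.from_pretrained('openai/gpt-oss-20b')
inputs = tokenizer.apply_chat_template(
    [{'role': 'user', 'content': 'Explain attention sinks in one sentence.'}],
    add_generation_prompt=True,
    return_tensors='pt'
).to(model.device)
start = time()
with torch.inference_mode():
    outputs = model.generate(inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0, inputs.shape[-1]:], skip_special_tokens=True))
print(f'generated in {time() - start:.1f}s')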