# GPT-OSS with Flash Attention and Memory-Efficient Attention via PyTorch-Native SDPA
import warnings

import torch
import transformers

def sdpa_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float,
    dropout: float | None = 0.,
    **_
):
    """
    Grouped query attention with sinks using `torch`-native scaled dot product attention. Let
    - N be the batch size,
    - H be the total number of attention heads,
    - G be the number of key/value groups,
    - L be the query sequence length,
    - S be the key/value sequence length, and
    - E be the per-head embedding dimensionality.

    Parameters
    ----------
    module: torch.nn.Module
        Attention module with a `sinks` `torch.Tensor` attribute of shape H
    query: torch.Tensor
        Query tensor of shape N x H x L x E
    key: torch.Tensor
        Key tensor of shape N x G x S x E
    value: torch.Tensor
        Value tensor of shape N x G x S x E
    attention_mask: torch.Tensor | None
        Additive attention mask of shape N x 1 x L x S, or None for a causal mask
    scaling: float
        Scaling factor applied to query-key dot products, typically 1 / sqrt(E)
    dropout: float | None = 0.
        Dropout probability

    Returns
    -------
    torch.Tensor
        Attention output tensor of shape N x L x H x E
    None
        Unused attention weights
    """
    N, H, L, E = query.shape
    _, G, S, _ = key.shape
    if attention_mask is None:
        # Fall back to an explicit additive causal mask: SDPA cannot combine `attn_mask`
        # with `is_causal=True`, and the sink column below needs an explicit mask anyway.
        attention_mask = torch.full(
            (N, 1, L, S), torch.finfo(query.dtype).min, device=query.device, dtype=query.dtype
        ).triu(S - L + 1)
    attention = torch.nn.functional.scaled_dot_product_attention(
        query,
        # Expand the G key/value heads to H query heads, then append one all-zero "sink"
        # position so every query row gains an extra logit with a zero value vector.
        *(
            torch.cat(
                [
                    tensor.repeat_interleave(H // G, dim=1),
                    torch.zeros(N, H, 1, E, device=tensor.device, dtype=tensor.dtype)
                ],
                dim=2
            )
            for tensor in (key, value)
        ),
        # Supply the per-head sink values as the mask bias of that extra position: because
        # its key is all zeros, the sink's attention logit is exactly `module.sinks[h]`.
        torch.cat(
            [
                attention_mask.expand(N, H, L, S).clone(),
                module.sinks.reshape(1, H, 1, 1).expand(N, H, L, 1)
            ],
            dim=3
        ),
        dropout_p=dropout,
        # Any causal structure is already baked into the explicit mask above.
        is_causal=False,
        scale=scaling
    )
    return attention.transpose(1, 2).contiguous(), None
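
# The sketch below is not part of the original gist: a minimal sanity check of the sink trick
# against a naive reference implementation, using small random tensors on CPU and a dummy
# module whose `sinks` buffer stands in for the real attention module's. The sink position
# contributes a per-head logit to the softmax but a zero value, so it only drains probability
# mass from the real positions.
def _check_sinks_attention(N=2, H=4, G=2, L=5, E=8):
    class _DummyAttention(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer('sinks', torch.randn(H))

    dummy = _DummyAttention()
    query = torch.randn(N, H, L, E)
    key = torch.randn(N, G, L, E)
    value = torch.randn(N, G, L, E)
    scaling = E ** -0.5
    out, _ = sdpa_attention_forward(dummy, query, key, value, None, scaling)
    # Naive reference: causal softmax over [scaled logits, per-head sink logit], with the
    # sink column's probability discarded before weighting the values.
    k_full = key.repeat_interleave(H // G, dim=1)
    v_full = value.repeat_interleave(H // G, dim=1)
    causal = torch.full((L, L), torch.finfo(query.dtype).min).triu(1)
    logits = query @ k_full.transpose(-2, -1) * scaling + causal
    logits = torch.cat([logits, dummy.sinks.reshape(1, H, 1, 1).expand(N, H, L, 1)], dim=-1)
    reference = (torch.softmax(logits, dim=-1)[..., :-1] @ v_full).transpose(1, 2)
    assert torch.allclose(out, reference, atol=1e-4)

_check_sinks_attention()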

# Silence UserWarnings raised from `torch.cuda`.
warnings.filterwarnings('ignore', category=UserWarning, module='torch.cuda')
# Allow PyTorch to dispatch scaled_dot_product_attention to the FlashAttention and
# memory-efficient kernels whenever the inputs qualify.
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
# Replace the stock 'sdpa' attention implementation with the sink-aware version above and
# route its mask creation through eager-style additive float masks, so the sink column can
# be concatenated onto an explicit mask.
transformers.AttentionInterface.register('sdpa', sdpa_attention_forward)
transformers.AttentionMaskInterface.register('sdpa', transformers.masking_utils.eager_mask)
# GPT-OSS does not advertise SDPA support out of the box; flip the flag so that
# `attn_implementation='sdpa'` is accepted.
transformers.models.gpt_oss.modeling_gpt_oss.GptOssPreTrainedModel._supports_sdpa = True
# Load GPT-OSS 20B in bfloat16, dequantizing its MXFP4-quantized weights.
model = transformers.AutoModelForCausalLM.from_pretrained(
    'openai/gpt-oss-20b',
    attn_implementation='sdpa',
    dtype=torch.bfloat16,
    quantization_config=transformers.Mxfp4Config(dequantize=True)
)
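
# The following is a usage sketch, not part of the original gist: run a short chat-formatted
# generation through the patched SDPA path. It assumes a CUDA device with enough memory for
# the dequantized bf16 weights and that the checkpoint's tokenizer ships a chat template;
# the prompt and decoding settings are illustrative only.
tokenizer = transformers.AutoTokenizer.from_pretrained('openai/gpt-oss-20b')
model = model.cuda().eval()
inputs = tokenizer.apply_chat_template(
    [{'role': 'user', 'content': 'In one sentence, what is an attention sink?'}],
    add_generation_prompt=True,
    return_tensors='pt',
    return_dict=True
).to(model.device)
with torch.inference_mode():
    output_ids = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(output_ids[0, inputs['input_ids'].shape[-1]:], skip_special_tokens=True))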