Skip to content

Instantly share code, notes, and snippets.

@LiutongZhou
Last active February 27, 2026 16:09
Show Gist options
  • Select an option

  • Save LiutongZhou/ec06bfc6e66fe2ffffaea6569f47e35e to your computer and use it in GitHub Desktop.

Select an option

Save LiutongZhou/ec06bfc6e66fe2ffffaea6569f47e35e to your computer and use it in GitHub Desktop.
Benchmarking JAX 0.7+ vs PyTorch 2.10+ Attention Speed
"""Attention benchmark for PyTorch and JAX.
This module benchmarks variable-length self-attention implementations across
PyTorch and JAX, reporting:
- Median latency (ms)
- Peak device memory (MB)
- Throughput (TFLOP/s)
Benchmarked variants:
- PyTorch naive masked attention on padded tensors
- PyTorch SDPA on padded tensors
- PyTorch SDPA on NestedTensor inputs
- PyTorch Flex Attention on packed tokens with block masking
- PyTorch varlen_attn on packed tokens with cumulative sequence lengths
- Jax Naive masked attention on padded tensors
- JAX dot_product_attention (XLA) on padded tensors
- JAX dot_product_attention (cuDNN) using sequence lengths
The benchmark samples per-sequence lengths in a configured range, runs warmup
and timed iterations, and prints a summary table for all variants.
"""
from __future__ import annotations
import argparse
import enum
import gc
import time
from abc import ABC, abstractmethod
from contextlib import AbstractContextManager
from dataclasses import dataclass, field
from typing import Any, Callable, Self, TypedDict
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import torch
from torch.nn import functional as F
from torch.nn.attention.flex_attention import (
BlockMask,
create_block_mask,
flex_attention,
)
from torch.nn.attention.varlen import varlen_attn
from packaging import version
# Fail fast with clear messages if the installed frameworks are too old for
# the APIs used below (varlen_attn needs torch>=2.10; ragged cuDNN attention
# needs jax>=0.7.2). Explicit raises instead of `assert` so the checks
# survive `python -O`, which strips assert statements.
if version.parse(torch.__version__) < version.parse("2.10"):
    raise RuntimeError(f"PyTorch >= 2.10 required, found {torch.__version__}")
if version.parse(jax.__version__) < version.parse("0.7.2"):
    raise RuntimeError(f"JAX >= 0.7.2 required, found {jax.__version__}")
class DType(enum.StrEnum):
FLOAT16 = "float16"
BFLOAT16 = "bfloat16"
@property
def jax(self) -> jnp.dtype:
match self:
case DType.FLOAT16:
return jnp.float16
case DType.BFLOAT16:
return jnp.bfloat16
case _:
raise ValueError(f"Unsupported dtype: {self}")
@property
def torch(self) -> torch.dtype:
match self:
case DType.FLOAT16:
return torch.float16
case DType.BFLOAT16:
return torch.bfloat16
case _:
raise ValueError(f"Unsupported dtype: {self}")
@dataclass(frozen=True, slots=True)
class Config:
    """Benchmark settings.

    Per-sequence lengths are sampled once, deterministically from ``seed``,
    uniformly in ``[seq_min, seq_max]`` — one length per batch element.
    """

    batch_size: int = 8
    num_heads: int = 32
    head_dim: int = 128
    seq_min: int = 128
    seq_max: int = 2048
    dtype: DType = DType.BFLOAT16
    is_causal: bool = True
    warmup: int = 2
    iters: int = 20
    seed: int = 42
    # Sampled per-sequence lengths; derived in __post_init__, not a user input.
    _seq_lens: tuple[int, ...] = field(init=False)

    def __post_init__(self) -> None:
        """Sample and freeze the per-sequence lengths."""
        lengths = np.random.default_rng(self.seed).integers(
            self.seq_min, self.seq_max + 1, size=self.batch_size
        )
        # object.__setattr__ is required because the dataclass is frozen.
        object.__setattr__(self, "_seq_lens", tuple(lengths.tolist()))

    @property
    def seq_lens(self) -> list[int]:
        """Sampled sequence lengths as a fresh list."""
        return list(self._seq_lens)

    @property
    def flops(self) -> int:
        """Theoretical forward FLOPs for ragged self-attention.

        4 * H * s^2 * D per sequence (QK^T plus PV, 2 FLOPs per MAC each);
        causal attention touches roughly half of the score matrix.
        """
        total = 4 * self.num_heads * self.head_dim * sum(s * s for s in self._seq_lens)
        return total // 2 if self.is_causal else total

    @property
    def max_seq_len(self) -> int:
        """Longest sampled sequence length."""
        return max(self._seq_lens)

    @property
    def total_tokens(self) -> int:
        """Sum of all sampled sequence lengths."""
        return sum(self._seq_lens)
# Any attention implementation: takes prepared inputs, returns the attention output.
type AttentionFn = Callable[..., Any]
# Factory building framework-specific inputs (T is the array type) from a Config.
type InputFactory[T: (jax.Array, torch.Tensor)] = Callable[[Config], T]
class Bench(AbstractContextManager, ABC):
    """Base class for framework-specific attention benchmarks.

    Used as a context manager: ``__enter__`` materializes inputs and the
    (optionally compiled) attention function; ``__exit__`` drops both so
    successive benchmarks do not accumulate device memory.
    """

    def __init__(
        self,
        variant: str,
        strategy: str,
        attention_fn: AttentionFn,
        input_factory: InputFactory[Any],
        cfg: Config,
        *,
        compile_fn: Callable[[AttentionFn], AttentionFn] | None = None,
    ) -> None:
        self.variant = variant
        self.strategy = strategy
        self.attention_fn = attention_fn
        self.input_factory = input_factory
        self.cfg = cfg
        self._compile_fn = compile_fn
        self._inputs: Any = None
        self._compiled_fn: AttentionFn | None = None

    def __enter__(self) -> Self:
        """Materialize inputs and (optionally) compile the attention fn."""
        self._inputs = self.input_factory(self.cfg)
        if self._compile_fn is None:
            self._compiled_fn = self.attention_fn
        else:
            self._compiled_fn = self._compile_fn(self.attention_fn)
        return self

    def __exit__(self, *_exc: object) -> None:
        """Release inputs and the compiled function, then force a GC pass."""
        del self._inputs
        self._compiled_fn = None
        gc.collect()

    @abstractmethod
    def _call_fn(self) -> Any:
        """Invoke the prepared function on inputs and block until complete."""
        assert self._compiled_fn is not None
        assert self._inputs is not None
        return self._compiled_fn(self._inputs)

    def _warmup(self) -> None:
        """Run warmup iterations to prime caches / JIT."""
        for _ in range(self.cfg.warmup):
            self._call_fn()

    @abstractmethod
    def _profile_memory(self) -> float:
        """Run once and return peak memory in MB."""

    def _profile_speed(self) -> list[float]:
        """Time ``cfg.iters`` calls and return per-call durations in seconds."""
        durations: list[float] = []
        for _ in range(self.cfg.iters):
            start = time.perf_counter()
            self._call_fn()
            durations.append(time.perf_counter() - start)
        return durations

    def profile(self) -> Result:
        """Warm up, then measure memory and speed; return one Result row."""
        self._warmup()
        peak_mb = self._profile_memory()
        median_s = float(np.median(self._profile_speed()))
        return Result(
            variant=self.variant,
            strategy=self.strategy,
            median_ms=median_s * 1e3,
            peak_mb=peak_mb,
            tflops=self.cfg.flops / median_s / 1e12,
        )
class TorchBench(Bench):
    """PyTorch CUDA benchmark.

    Defaults to compiling the attention function with ``torch.compile`` and
    synchronizes the CUDA device after every call so wall-clock timings
    measure actual kernel execution.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        kwargs.setdefault("compile_fn", torch.compile)
        super().__init__(*args, **kwargs)

    def _call_fn(self) -> Any:
        """Run the attention fn, wait for CUDA work to finish, return output."""
        # Bug fix: the output was previously discarded (implicit None return);
        # return it so TorchBench honors the Bench._call_fn contract, matching
        # JaxBench._call_fn.
        out = super()._call_fn()
        torch.cuda.synchronize()
        return out

    def _profile_memory(self) -> float:
        """Return the peak extra CUDA memory allocated by one call, in MB."""
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()
        before = torch.cuda.max_memory_allocated()
        self._call_fn()
        after = torch.cuda.max_memory_allocated()
        return (after - before) / 1e6

    def __exit__(self, *_exc: object) -> None:
        super().__exit__(*_exc)
        # Return cached blocks to the allocator so the next variant starts clean.
        torch.cuda.empty_cache()
class JaxBench(Bench):
    """JAX benchmark.

    Defaults to ``jax.jit`` compilation and blocks on the returned arrays so
    timings cover actual device execution.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        kwargs.setdefault("compile_fn", jax.jit)
        super().__init__(*args, **kwargs)

    def _call_fn(self) -> Any:
        """Run the attention fn and wait for the result to be ready."""
        out = super()._call_fn()
        return jax.block_until_ready(out)

    def _profile_memory(self) -> float:
        """Return peak-minus-baseline device memory for one call, in MB.

        NOTE(review): ``peak_bytes_in_use`` is a high-water mark that JAX does
        not expose a reset for, so earlier runs in the same process can
        inflate this number — treat it as an upper bound.
        """
        device = jax.local_devices()[0]
        # memory_stats() may be None or omit keys on some backends; previously
        # that raised a TypeError on the subtraction below.
        before = (device.memory_stats() or {}).get("bytes_in_use", 0)
        self._call_fn()
        peak = (device.memory_stats() or {}).get("peak_bytes_in_use", 0)
        return (peak - before) / 1e6
class Result(TypedDict):
    """One benchmark row."""

    variant: str  # benchmark implementation name, e.g. "torch_sdpa"
    strategy: str  # input layout label, e.g. "padded" / "nested" / "packed"
    median_ms: float  # median per-iteration latency in milliseconds
    peak_mb: float  # peak extra device memory in megabytes
    tflops: float  # achieved throughput in TFLOP/s (cfg.flops / median time)
@dataclass(slots=True)
class ResultRecorder:
"""Records benchmark results and reports them."""
results: list[Result] = field(default_factory=list)
def add(self, result: Result) -> None:
"""Add a single benchmark result."""
self.results.append(result)
def to_dataframe(self) -> pd.DataFrame:
"""Return results as a pandas DataFrame."""
return pd.DataFrame(self.results)
def print_header(self, cfg: Config) -> None:
"""Print the benchmark configuration and results table header."""
# Print hardware and software info
print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"PyTorch version: {torch.__version__}")
print(f"JAX version: {jax.__version__}")
print(f"Config: {cfg}")
print(f"Seq lens: {cfg.seq_lens}")
print(f"Forward FLOPs: {cfg.flops / 1e9:.2f} GFLOP")
def print_results(self) -> None:
"""Print all recorded results as a table."""
if not self.results:
print("No results recorded.")
return
df = (
self.to_dataframe()
.rename(
columns={
"variant": "Variant",
"strategy": "Strategy",
"median_ms": "Median (ms)",
"peak_mb": "Peak (MB)",
"tflops": "TFLOP/s",
}
)
.sort_values("Median (ms)")
)
print(df.to_string(index=False, float_format=lambda value: f"{value:.2f}"))
@dataclass(frozen=True, slots=True)
class InputsBase[T: (jax.Array, torch.Tensor)](ABC):
    """Abstract base class for attention inputs.

    Generic over the array type: ``T`` is constrained to ``jax.Array`` or
    ``torch.Tensor`` so each concrete subclass is framework-homogeneous.
    """

    q: T  # query array/tensor
    k: T  # key array/tensor
    v: T  # value array/tensor

    @classmethod
    @abstractmethod
    def from_config(cls, cfg: Config) -> Self:
        """Build inputs from benchmark config."""
@dataclass(frozen=True, slots=True)
class TorchPaddedInputs(InputsBase[torch.Tensor]):
    """Prepared padded inputs for torch attention variants.

    q, k, v : Shape (batch, num_heads, max_seq_len, head_dim) with padding
    for shorter sequences.
    """

    # Shape (batch, 1, max_seq_len, max_seq_len); True marks valid query-key
    # pairs (padding — and, when causal, future positions — are False).
    mask: torch.Tensor

    def __post_init__(self) -> None:
        """Reject NestedTensor inputs — this container holds padded tensors."""
        for name in ("q", "k", "v", "mask"):
            if bool(getattr(getattr(self, name), "is_nested", False)):
                raise TypeError(f"TorchPaddedInputs.{name} must be a non-nested tensor")

    @classmethod
    def from_config(cls, cfg: Config) -> Self:
        """Build padded Q/K/V and bool attention mask."""
        nested = TorchNestedInputs.from_config(cfg)
        q, k, v = (
            torch.nested.to_padded_tensor(t, 0.0) for t in (nested.q, nested.k, nested.v)
        )
        lengths = torch.tensor(cfg.seq_lens, device="cuda")
        positions = torch.arange(cfg.max_seq_len, device="cuda")
        valid = positions[None, :] < lengths[:, None]
        # Pairwise validity: both the query and the key token must be real.
        mask = valid[:, None, :, None] & valid[:, None, None, :]
        if cfg.is_causal:
            causal = torch.ones(
                cfg.max_seq_len, cfg.max_seq_len, dtype=torch.bool, device="cuda"
            ).tril()
            mask = mask & causal
        return cls(q=q, k=k, v=v, mask=mask)
@dataclass(frozen=True, slots=True)
class TorchNestedInputs(InputsBase[torch.Tensor]):
    """Prepared nested inputs for torch SDPA.

    q, k, v : NestedTensors of shape (batch, num_heads, jagged_len, head_dim).
    """

    def __post_init__(self) -> None:
        """Validate that nested inputs are NestedTensor instances."""
        tensors = {"q": self.q, "k": self.k, "v": self.v}
        for name, tensor in tensors.items():
            if not bool(getattr(tensor, "is_nested", False)):
                raise TypeError(f"TorchNestedInputs.{name} must be a nested tensor")

    @classmethod
    def from_config(cls, cfg: Config) -> Self:
        """Build NestedTensor Q/K/V sharing one jagged values buffer.

        A single (total_tokens, H, D, 3) buffer is wrapped as a jagged
        NestedTensor, then Q/K/V are sliced off the trailing axis.
        """
        # Dead code removed: an `offsets` cumsum tensor used to be built here
        # but was never passed (its `offsets=` kwarg was commented out);
        # `lengths` alone is sufficient for nested_tensor_from_jagged.
        qkv = torch.nested.nested_tensor_from_jagged(
            values=torch.randn(
                (cfg.total_tokens, cfg.num_heads, cfg.head_dim, 3),
                dtype=cfg.dtype.torch,
                device="cuda",
            ),
            lengths=torch.as_tensor(cfg.seq_lens, dtype=torch.int32, device="cuda"),
            min_seqlen=min(cfg.seq_lens),
            max_seqlen=cfg.max_seq_len,
        )  # (B, jagged_len, H, D, 3)
        q = qkv.select(dim=-1, index=0).transpose(2, 1).contiguous()
        k = qkv.select(dim=-1, index=1).transpose(2, 1).contiguous()
        v = qkv.select(dim=-1, index=2).transpose(2, 1).contiguous()
        assert q.shape[0] == k.shape[0] == v.shape[0] == cfg.batch_size
        assert q.shape[1] == k.shape[1] == v.shape[1] == cfg.num_heads
        assert q.shape[-1] == k.shape[-1] == v.shape[-1] == cfg.head_dim
        return cls(q=q, k=k, v=v)
@dataclass(frozen=True, slots=True)
class TorchFlexInputs(InputsBase[torch.Tensor]):
    """Packed inputs for Flex Attention with document-ID block masking.

    Tensors have shape (1, num_heads, total_tokens, head_dim) — all sequences
    are concatenated into a single "batch" and a BlockMask encodes which
    query-key pairs belong to the same document.
    """

    block_mask: BlockMask

    @classmethod
    def from_config(cls, cfg: Config) -> Self:
        """Build packed Q/K/V and document-ID block mask."""
        q, k, v = torch.randn(
            3,
            1,
            cfg.num_heads,
            cfg.total_tokens,
            cfg.head_dim,
            dtype=cfg.dtype.torch,
            device="cuda",
        ).unbind(0)
        # Document-ID tensor: maps each token to the index of its sequence.
        # int32 (was int8) so batches of more than 128 sequences work —
        # torch.full rejects fill values outside the dtype's range.
        document_id = torch.cat(
            [
                torch.full((s,), i, dtype=torch.int32, device="cuda")
                for i, s in enumerate(cfg.seq_lens)
            ]
        )

        def mask_mod(
            _b: torch.Tensor,
            _h: torch.Tensor,
            q_idx: torch.Tensor,
            kv_idx: torch.Tensor,
        ) -> torch.Tensor:
            """Allow attention only within one document (causally if set)."""
            same_doc = document_id[q_idx] == document_id[kv_idx]
            if cfg.is_causal:
                return same_doc & (q_idx >= kv_idx)
            return same_doc

        block_mask = create_block_mask(
            mask_mod,
            B=None,
            H=None,
            Q_LEN=cfg.total_tokens,
            KV_LEN=cfg.total_tokens,
            device="cuda",
        )
        return cls(q=q, k=k, v=v, block_mask=block_mask)
@dataclass(frozen=True, slots=True)
class TorchVarlenInputs(InputsBase[torch.Tensor]):
    """Packed inputs for varlen_attn with cumulative sequence lengths.

    Tensors have shape (total_tokens, num_heads, head_dim) — all sequences
    are concatenated along the token dimension. ``cu_seqlens`` is the
    cumulative-sum offset array of shape (N+1,).
    """

    cu_seqlens: torch.Tensor
    max_seqlen: int

    @classmethod
    def from_config(cls, cfg: Config) -> Self:
        """Build packed Q/K/V and the (N+1,) cumulative-length tensor."""
        packed = torch.randn(
            3,
            cfg.total_tokens,
            cfg.num_heads,
            cfg.head_dim,
            dtype=cfg.dtype.torch,
            device="cuda",
        )
        q, k, v = packed.unbind(0)
        # cu_seqlens = [0, s0, s0+s1, ...]: running offsets into the packed dim.
        lengths = torch.tensor(cfg.seq_lens, dtype=torch.int32, device="cuda")
        cu_seqlens = torch.zeros(cfg.batch_size + 1, dtype=torch.int32, device="cuda")
        cu_seqlens[1:] = lengths
        cu_seqlens = cu_seqlens.cumsum(0, dtype=torch.int32)
        return cls(q=q, k=k, v=v, cu_seqlens=cu_seqlens, max_seqlen=cfg.max_seq_len)
@jax.tree_util.register_dataclass
@dataclass(frozen=True, slots=True)
class JaxInputs(InputsBase[jax.Array]):
    """Prepared inputs and sequence lengths for JAX attention variants.

    q, k, v : Shape (batch, max_seq_len, num_heads, head_dim), padded
    for shorter sequences.
    """

    # Shape (batch, 1, max_seq_len, max_seq_len); True marks valid
    # query-key pairs.
    mask: jax.Array
    # Per-sequence valid token counts, consumed by the cuDNN attention path.
    seq_lengths: jax.Array

    @classmethod
    def from_config(cls, cfg: Config) -> Self:
        """Build padded Q/K/V + mask + seq lengths."""
        seq_lengths = jnp.array(cfg.seq_lens, dtype=jnp.int32)
        max_len = cfg.max_seq_len
        q, k, v = jax.random.normal(
            jax.random.key(cfg.seed),
            (3, cfg.batch_size, max_len, cfg.num_heads, cfg.head_dim),
            dtype=cfg.dtype.jax,
        )
        token_valid = jnp.arange(max_len)[None, :] < seq_lengths[:, None]
        # Pairwise validity: both the query and the key token must be real.
        mask = token_valid[:, None, :, None] & token_valid[:, None, None, :]
        if cfg.is_causal:
            causal = jnp.tril(jnp.ones((max_len, max_len), dtype=jnp.bool))
            mask = mask & causal
        return cls(q=q, k=k, v=v, mask=mask, seq_lengths=seq_lengths)
def torch_naive_attention(inputs: TorchPaddedInputs) -> torch.Tensor:
    """Manual attention baseline on padded tensors.

    Materializes the full (B, H, S, S) score matrix before softmax, which is
    why this variant dominates peak memory in the benchmark.
    """
    q, k, v = inputs.q, inputs.k, inputs.v
    scale = q.shape[-1] ** -0.5
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    scores = torch.where(inputs.mask, scores, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v
def torch_sdpa_padded(inputs: TorchPaddedInputs) -> torch.Tensor:
    """PyTorch SDPA with an explicit bool mask on padded tensors."""
    q, k, v = inputs.q, inputs.k, inputs.v
    return F.scaled_dot_product_attention(q, k, v, attn_mask=inputs.mask)
def torch_sdpa_nested(inputs: TorchNestedInputs) -> torch.Tensor:
    """PyTorch SDPA on NestedTensor inputs (causal, no explicit mask)."""
    q, k, v = inputs.q, inputs.k, inputs.v
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)
def torch_flex_packed(inputs: TorchFlexInputs) -> torch.Tensor:
"""Flex Attention on packed sequences with document-ID block masking."""
return flex_attention(inputs.q, inputs.k, inputs.v, block_mask=inputs.block_mask)
def torch_varlen_attention(inputs: TorchVarlenInputs) -> torch.Tensor:
    """Variable-length attention via FlashAttention-backed varlen_attn."""
    # Self-attention: queries and keys share one cumulative-offset tensor.
    offsets = inputs.cu_seqlens
    return varlen_attn(
        inputs.q,
        inputs.k,
        inputs.v,
        cu_seq_q=offsets,
        cu_seq_k=offsets,
        max_q=inputs.max_seqlen,
        max_k=inputs.max_seqlen,
        is_causal=True,
    )
def jax_sdpa_xla(inputs: JaxInputs) -> jax.Array:
"""JAX attention with XLA implementation."""
return jax.nn.dot_product_attention(
inputs.q, inputs.k, inputs.v, mask=inputs.mask, implementation="xla"
)
def jax_naive(inputs: JaxInputs) -> jax.Array:
"""Manual attention baseline in JAX."""
q, k, v = inputs.q, inputs.k, inputs.v
q = q.transpose(0, 2, 1, 3) # (B, H, Q_LEN, D)
k = k.transpose(0, 2, 1, 3) # (B, H, K_LEN, D)
v = v.transpose(0, 2, 1, 3) # (B, H, V_LEN, D)
scale = q.shape[-1] ** -0.5
scores = (scale * q) @ k.transpose(0, 1, 3, 2)
scores = jnp.where(inputs.mask, scores, float("-inf"))
probs = jax.nn.softmax(scores, axis=-1)
return probs @ v
def jax_sdpa_cudnn(inputs: JaxInputs) -> jax.Array:
    """JAX fused attention pinned to cuDNN, using ragged sequence lengths."""
    lengths = inputs.seq_lengths
    return jax.nn.dot_product_attention(
        inputs.q,
        inputs.k,
        inputs.v,
        is_causal=True,
        query_seq_lengths=lengths,
        key_value_seq_lengths=lengths,
        implementation="cudnn",
    )
def run(cfg: Config) -> list[Result]:
    """Run every benchmark variant and collect one Result row per variant.

    Each benchmark runs inside its context manager so inputs and compiled
    functions are released between variants.
    """
    recorder = ResultRecorder()
    recorder.print_header(cfg)
    # Rows: (bench class, variant label, strategy label, attention fn, input factory).
    specs: list[tuple[type[Bench], str, str, AttentionFn, InputFactory[Any]]] = [
        (TorchBench, "torch_naive", "padded", torch_naive_attention, TorchPaddedInputs.from_config),
        (TorchBench, "torch_sdpa", "padded", torch_sdpa_padded, TorchPaddedInputs.from_config),
        (TorchBench, "torch_sdpa", "nested", torch_sdpa_nested, TorchNestedInputs.from_config),
        (TorchBench, "torch_flex", "packed", torch_flex_packed, TorchFlexInputs.from_config),
        (TorchBench, "torch_varlen", "packed", torch_varlen_attention, TorchVarlenInputs.from_config),
        (JaxBench, "jax_cudnn", "padded using seq_lens", jax_sdpa_cudnn, JaxInputs.from_config),
        (JaxBench, "jax_xla", "padded", jax_sdpa_xla, JaxInputs.from_config),
        (JaxBench, "jax_naive", "padded", jax_naive, JaxInputs.from_config),
    ]
    for bench_cls, variant, strategy, attention_fn, input_factory in specs:
        print(f"Running {variant} ({strategy})...")
        with bench_cls(variant, strategy, attention_fn, input_factory, cfg) as bench:
            recorder.add(bench.profile())
    recorder.print_results()
    return recorder.results
def parse_args() -> Config:
    """Parse CLI options into a Config.

    NOTE(review): the CLI default for --iters (30) follows the original
    script and differs from Config's dataclass default (20); the CLI value
    always wins when running via this entry point.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--num-heads", type=int, default=32)
    parser.add_argument("--head-dim", type=int, default=128)
    parser.add_argument("--seq-min", type=int, default=128)
    parser.add_argument("--seq-max", type=int, default=2048)
    parser.add_argument(
        "--dtype", choices=list(DType), type=DType, default=DType.BFLOAT16
    )
    # New: expose causal masking on the CLI — previously Config.is_causal
    # was not settable from the command line. Pass --no-is-causal to disable;
    # the default (True) preserves prior behavior.
    parser.add_argument(
        "--is-causal", action=argparse.BooleanOptionalAction, default=True
    )
    parser.add_argument("--warmup", type=int, default=2)
    parser.add_argument("--iters", type=int, default=30)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()
    return Config(
        batch_size=args.batch_size,
        num_heads=args.num_heads,
        head_dim=args.head_dim,
        seq_min=args.seq_min,
        seq_max=args.seq_max,
        dtype=args.dtype,
        is_causal=args.is_causal,
        warmup=args.warmup,
        iters=args.iters,
        seed=args.seed,
    )
def main() -> None:
    """CLI entrypoint; requires a CUDA device."""
    # Guard clause: every torch variant allocates on "cuda".
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA not available")
    run(parse_args())


if __name__ == "__main__":
    main()

Benchmarking JAX 0.7+ vs PyTorch 2.10+ Attention Variants' Speed and Memory

torch.nn.attention.varlen.varlen_attn is the absolute winner on GPU

GPU: NVIDIA GeForce RTX 4090 Laptop GPU
PyTorch version: 2.10.0
JAX version: 0.7.2
Config: Config(batch_size=8, num_heads=32, head_dim=128, seq_min=128, seq_max=2048, dtype=<DType.BFLOAT16: 'bfloat16'>, is_causal=True, warmup=2, iters=30, seed=42, _seq_lens=(299, 1614, 1385, 971, 959, 1777, 293, 1467))
Seq lens: [299, 1614, 1385, 971, 959, 1777, 293, 1467]
Forward FLOPs: 97.25 GFLOP

     Variant              Strategy  Median (ms)  Peak (MB)  TFLOP/s
torch_varlen                packed         1.50      72.93    64.68
  torch_flex                packed         2.01      74.05    48.42
   jax_cudnn padded using seq_lens         2.41     440.57    40.31
  torch_sdpa                nested         2.57     289.41    37.84
  torch_sdpa                padded        10.41     167.18     9.34
   jax_naive                padded        16.92    6848.02     5.75
 torch_naive                padded        18.55    3271.29     5.24
     jax_xla                padded        22.98    6848.02     4.23
     
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment