|
"""Attention benchmark for PyTorch and JAX. |
|
|
|
This module benchmarks variable-length self-attention implementations across |
|
PyTorch and JAX, reporting: |
|
|
|
- Median latency (ms) |
|
- Peak device memory (MB) |
|
- Throughput (TFLOP/s) |
|
|
|
Benchmarked variants: |
|
- PyTorch naive masked attention on padded tensors |
|
- PyTorch SDPA on padded tensors |
|
- PyTorch SDPA on NestedTensor inputs |
|
- PyTorch Flex Attention on packed tokens with block masking |
|
- PyTorch varlen_attn on packed tokens with cumulative sequence lengths |
|
- JAX naive masked attention on padded tensors
|
- JAX dot_product_attention (XLA) on padded tensors |
|
- JAX dot_product_attention (cuDNN) using sequence lengths |
|
|
|
The benchmark samples per-sequence lengths in a configured range, runs warmup |
|
and timed iterations, and prints a summary table for all variants. |
|
""" |
|
|
|
from __future__ import annotations |
|
|
|
import argparse |
|
import enum |
|
import gc |
|
import time |
|
from abc import ABC, abstractmethod |
|
from contextlib import AbstractContextManager |
|
from dataclasses import dataclass, field |
|
from typing import Any, Callable, Self, TypedDict |
|
|
|
import jax |
|
import jax.numpy as jnp |
|
import numpy as np |
|
import pandas as pd |
|
import torch |
|
from torch.nn import functional as F |
|
from torch.nn.attention.flex_attention import ( |
|
BlockMask, |
|
create_block_mask, |
|
flex_attention, |
|
) |
|
from torch.nn.attention.varlen import varlen_attn |
|
from packaging import version |
|
|
|
# Fail fast at import time if the installed frameworks predate the APIs this
# module uses (e.g. torch.nn.attention.varlen). Raise instead of assert:
# asserts are stripped when Python runs with -O.
for _fw_name, _fw_version, _fw_minimum in (
    ("PyTorch", torch.__version__, "2.10"),
    ("JAX", jax.__version__, "0.7.2"),
):
    if version.parse(_fw_version) < version.parse(_fw_minimum):
        raise RuntimeError(
            f"{_fw_name} >= {_fw_minimum} is required, found {_fw_version}"
        )
|
|
|
|
|
class DType(enum.StrEnum): |
|
FLOAT16 = "float16" |
|
BFLOAT16 = "bfloat16" |
|
|
|
@property |
|
def jax(self) -> jnp.dtype: |
|
match self: |
|
case DType.FLOAT16: |
|
return jnp.float16 |
|
case DType.BFLOAT16: |
|
return jnp.bfloat16 |
|
case _: |
|
raise ValueError(f"Unsupported dtype: {self}") |
|
|
|
@property |
|
def torch(self) -> torch.dtype: |
|
match self: |
|
case DType.FLOAT16: |
|
return torch.float16 |
|
case DType.BFLOAT16: |
|
return torch.bfloat16 |
|
case _: |
|
raise ValueError(f"Unsupported dtype: {self}") |
|
|
|
|
|
@dataclass(frozen=True, slots=True)
class Config:
    """Benchmark settings."""

    batch_size: int = 8
    num_heads: int = 32
    head_dim: int = 128
    seq_min: int = 128
    seq_max: int = 2048
    dtype: DType = DType.BFLOAT16
    is_causal: bool = True
    warmup: int = 2
    iters: int = 20
    seed: int = 42
    _seq_lens: tuple[int, ...] = field(init=False)

    def __post_init__(self) -> None:
        """Sample per-sequence lengths once, deterministically from the seed."""
        rng = np.random.default_rng(self.seed)
        lengths = rng.integers(self.seq_min, self.seq_max + 1, size=self.batch_size)
        # Frozen dataclass: bypass the blocked __setattr__.
        object.__setattr__(self, "_seq_lens", tuple(int(n) for n in lengths))

    @property
    def seq_lens(self) -> list[int]:
        """Sampled sequence lengths as a list."""
        return list(self._seq_lens)

    @property
    def flops(self) -> int:
        """Theoretical forward FLOPs for ragged self-attention."""
        total = sum(4 * self.num_heads * self.head_dim * s * s for s in self._seq_lens)
        # Causal masking computes only the lower triangle: roughly half the work.
        return total // 2 if self.is_causal else total

    @property
    def max_seq_len(self) -> int:
        """Longest sampled sequence length."""
        return max(self._seq_lens)

    @property
    def total_tokens(self) -> int:
        """Total token count across all sequences."""
        return sum(self._seq_lens)
|
|
|
|
|
# Any attention forward callable benchmarked below; signatures vary per variant.
type AttentionFn = Callable[..., Any]
# Builds benchmark inputs from a Config; T selects the framework's array type.
type InputFactory[T: (jax.Array, torch.Tensor)] = Callable[[Config], T]
|
|
|
|
|
class Bench(AbstractContextManager, ABC): |
|
"""Base class for framework-specific attention benchmarks.""" |
|
|
|
def __init__( |
|
self, |
|
variant: str, |
|
strategy: str, |
|
attention_fn: AttentionFn, |
|
input_factory: InputFactory[Any], |
|
cfg: Config, |
|
*, |
|
compile_fn: Callable[[AttentionFn], AttentionFn] | None = None, |
|
) -> None: |
|
self.variant = variant |
|
self.strategy = strategy |
|
self.attention_fn = attention_fn |
|
self.input_factory = input_factory |
|
self.cfg = cfg |
|
self._compile_fn = compile_fn |
|
self._inputs: Any = None |
|
self._compiled_fn: AttentionFn | None = None |
|
|
|
def __enter__(self) -> Self: |
|
self._inputs = self.input_factory(self.cfg) |
|
self._compiled_fn = ( |
|
self._compile_fn(self.attention_fn) |
|
if self._compile_fn is not None |
|
else self.attention_fn |
|
) |
|
return self |
|
|
|
def __exit__(self, *_exc: object) -> None: |
|
del self._inputs |
|
self._compiled_fn = None |
|
gc.collect() |
|
|
|
@abstractmethod |
|
def _call_fn(self) -> Any: |
|
"""Invoke the prepared function on inputs and block until complete.""" |
|
assert self._compiled_fn is not None |
|
assert self._inputs is not None |
|
return self._compiled_fn(self._inputs) |
|
|
|
def _warmup(self) -> None: |
|
"""Run warmup iterations to prime caches / JIT.""" |
|
for _ in range(self.cfg.warmup): |
|
self._call_fn() |
|
|
|
@abstractmethod |
|
def _profile_memory(self) -> float: |
|
"""Run once and return peak memory in MB.""" |
|
|
|
def _profile_speed(self) -> list[float]: |
|
"""Run timed iterations and return per-iteration durations in seconds.""" |
|
timings: list[float] = [] |
|
for _ in range(self.cfg.iters): |
|
t0 = time.perf_counter() |
|
self._call_fn() |
|
timings.append(time.perf_counter() - t0) |
|
return timings |
|
|
|
def profile(self) -> Result: |
|
"""Run warmup, memory, and speed measurements and return a Result.""" |
|
self._warmup() |
|
peak_mb = self._profile_memory() |
|
timings = self._profile_speed() |
|
median_s = float(np.median(timings)) |
|
return Result( |
|
variant=self.variant, |
|
strategy=self.strategy, |
|
median_ms=median_s * 1e3, |
|
peak_mb=peak_mb, |
|
tflops=self.cfg.flops / median_s / 1e12, |
|
) |
|
|
|
|
|
class TorchBench(Bench):
    """PyTorch CUDA benchmark."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Compile by default; callers may still pass compile_fn=None explicitly.
        kwargs.setdefault("compile_fn", torch.compile)
        super().__init__(*args, **kwargs)

    def _call_fn(self) -> Any:
        """Run the attention fn and synchronize so the call covers the kernel.

        Bug fix: the result was previously discarded, breaking the base-class
        contract ("invoke ... and block until complete", returning the value)
        and diverging from JaxBench; return it.
        """
        out = super()._call_fn()
        torch.cuda.synchronize()
        return out

    def _profile_memory(self) -> float:
        """Peak CUDA memory delta (MB) for a single call."""
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()
        # After the reset, max_memory_allocated == currently allocated bytes.
        before = torch.cuda.max_memory_allocated()
        self._call_fn()
        # Peak reached during the call.
        after = torch.cuda.max_memory_allocated()
        return (after - before) / 1e6

    def __exit__(self, *_exc: object) -> None:
        """Release inputs, then return cached CUDA blocks to the allocator."""
        super().__exit__(*_exc)
        torch.cuda.empty_cache()
|
|
|
|
|
class JaxBench(Bench):
    """JAX benchmark."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # JIT by default; callers may still pass compile_fn=None explicitly.
        kwargs.setdefault("compile_fn", jax.jit)
        super().__init__(*args, **kwargs)

    def _call_fn(self) -> Any:
        """Run the attention fn and block until the result is ready."""
        out = super()._call_fn()
        return jax.block_until_ready(out)

    def _profile_memory(self) -> float:
        """Peak device memory delta (MB) for a single call.

        Returns NaN when the backend does not report memory statistics
        (``memory_stats()`` may be None or lack keys, e.g. on CPU) — the
        original code crashed in that case.

        NOTE(review): JAX exposes no way to reset ``peak_bytes_in_use``, so
        the reported peak may include allocations from earlier variants on
        this device — confirm acceptable for comparison purposes.
        """
        device = jax.local_devices()[0]
        before = (device.memory_stats() or {}).get("bytes_in_use")
        self._call_fn()
        after = (device.memory_stats() or {}).get("peak_bytes_in_use")
        if before is None or after is None:
            return float("nan")
        return (after - before) / 1e6
|
|
|
|
|
class Result(TypedDict):
    """One benchmark row, as produced by `Bench.profile`."""

    variant: str  # implementation name, e.g. "torch_sdpa"
    strategy: str  # input layout, e.g. "padded" / "nested" / "packed"
    median_ms: float  # median latency per call, in milliseconds
    peak_mb: float  # peak device memory delta, in megabytes
    tflops: float  # achieved throughput, in TFLOP/s
|
|
|
|
|
@dataclass(slots=True) |
|
class ResultRecorder: |
|
"""Records benchmark results and reports them.""" |
|
|
|
results: list[Result] = field(default_factory=list) |
|
|
|
def add(self, result: Result) -> None: |
|
"""Add a single benchmark result.""" |
|
self.results.append(result) |
|
|
|
def to_dataframe(self) -> pd.DataFrame: |
|
"""Return results as a pandas DataFrame.""" |
|
return pd.DataFrame(self.results) |
|
|
|
def print_header(self, cfg: Config) -> None: |
|
"""Print the benchmark configuration and results table header.""" |
|
# Print hardware and software info |
|
print(f"GPU: {torch.cuda.get_device_name(0)}") |
|
print(f"PyTorch version: {torch.__version__}") |
|
print(f"JAX version: {jax.__version__}") |
|
print(f"Config: {cfg}") |
|
print(f"Seq lens: {cfg.seq_lens}") |
|
print(f"Forward FLOPs: {cfg.flops / 1e9:.2f} GFLOP") |
|
|
|
def print_results(self) -> None: |
|
"""Print all recorded results as a table.""" |
|
if not self.results: |
|
print("No results recorded.") |
|
return |
|
|
|
df = ( |
|
self.to_dataframe() |
|
.rename( |
|
columns={ |
|
"variant": "Variant", |
|
"strategy": "Strategy", |
|
"median_ms": "Median (ms)", |
|
"peak_mb": "Peak (MB)", |
|
"tflops": "TFLOP/s", |
|
} |
|
) |
|
.sort_values("Median (ms)") |
|
) |
|
print(df.to_string(index=False, float_format=lambda value: f"{value:.2f}")) |
|
|
|
|
|
@dataclass(frozen=True, slots=True)
class InputsBase[T: (jax.Array, torch.Tensor)](ABC):
    """Abstract base class for attention inputs.

    Concrete subclasses add layout-specific fields (masks, offsets, ...) and
    implement `from_config` to build device arrays from a `Config`.
    """

    # Query / key / value tensors; the layout is defined by each subclass.
    q: T
    k: T
    v: T

    @classmethod
    @abstractmethod
    def from_config(cls, cfg: Config) -> Self:
        """Build inputs from benchmark config."""
|
|
|
@dataclass(frozen=True, slots=True)
class TorchPaddedInputs(InputsBase[torch.Tensor]):
    """Prepared padded inputs for torch attention variants.

    q, k, v : Shape (batch, num_heads, max_seq_len, head_dim) with padding for shorter sequences.
    """

    mask: torch.Tensor
    """Shape (batch, 1, max_seq_len, max_seq_len) bool tensor where True indicates valid query-key pairs."""

    def __post_init__(self) -> None:
        """Validate that padded inputs are regular (non-nested) tensors."""
        for name in ("q", "k", "v", "mask"):
            tensor = getattr(self, name)
            if bool(getattr(tensor, "is_nested", False)):
                raise TypeError(f"TorchPaddedInputs.{name} must be a non-nested tensor")

    @classmethod
    def from_config(cls, cfg: Config) -> Self:
        """Build padded Q/K/V and bool attention mask."""
        # Reuse the nested construction, then densify with zero padding.
        nested = TorchNestedInputs.from_config(cfg)
        q = torch.nested.to_padded_tensor(nested.q, 0.0)
        k = torch.nested.to_padded_tensor(nested.k, 0.0)
        v = torch.nested.to_padded_tensor(nested.v, 0.0)

        # valid[b, t] is True when token t lies within sequence b's length.
        lengths = torch.tensor(cfg.seq_lens, device="cuda")
        valid = torch.arange(cfg.max_seq_len, device="cuda")[None, :] < lengths[:, None]
        # A query-key pair is valid only when both tokens are real.
        mask = valid[:, None, :, None] & valid[:, None, None, :]
        if cfg.is_causal:
            causal = torch.ones(
                cfg.max_seq_len, cfg.max_seq_len, dtype=torch.bool, device="cuda"
            ).tril()
            mask = mask & causal
        return cls(q=q, k=k, v=v, mask=mask)
|
|
|
|
|
@dataclass(frozen=True, slots=True)
class TorchNestedInputs(InputsBase[torch.Tensor]):
    """Prepared nested inputs for torch SDPA.

    q, k, v : Shape (batch, num_heads, jagged_len, head_dim)
    """

    def __post_init__(self) -> None:
        """Validate that nested inputs are NestedTensor instances."""
        tensors = {"q": self.q, "k": self.k, "v": self.v}
        for name, tensor in tensors.items():
            if not bool(getattr(tensor, "is_nested", False)):
                raise TypeError(f"TorchNestedInputs.{name} must be a nested tensor")

    @classmethod
    def from_config(cls, cfg: Config) -> Self:
        """Build NestedTensor Q/K/V sharing one jagged layout.

        All three tensors are views of one packed random buffer, so they share
        the same per-sequence lengths. (Cleanup: a cumulative `offsets` tensor
        was previously computed but never used — the jagged layout is specified
        via `lengths` alone; the dead code and the commented-out `offsets=`
        argument are removed.)
        """
        qkv = torch.nested.nested_tensor_from_jagged(
            values=torch.randn(
                (cfg.total_tokens, cfg.num_heads, cfg.head_dim, 3),
                dtype=cfg.dtype.torch,
                device="cuda",
            ),
            lengths=torch.as_tensor(cfg.seq_lens, dtype=torch.int32, device="cuda"),
            min_seqlen=min(cfg.seq_lens),
            max_seqlen=cfg.max_seq_len,
        )  # (B, jagged_len, H, D, 3)
        q = qkv.select(dim=-1, index=0).transpose(2, 1).contiguous()
        k = qkv.select(dim=-1, index=1).transpose(2, 1).contiguous()
        v = qkv.select(dim=-1, index=2).transpose(2, 1).contiguous()
        assert q.shape[0] == k.shape[0] == v.shape[0] == cfg.batch_size
        assert q.shape[1] == k.shape[1] == v.shape[1] == cfg.num_heads
        assert q.shape[-1] == k.shape[-1] == v.shape[-1] == cfg.head_dim
        return cls(q=q, k=k, v=v)
|
|
|
|
|
@dataclass(frozen=True, slots=True)
class TorchFlexInputs(InputsBase[torch.Tensor]):
    """Packed inputs for Flex Attention with document-ID block masking.

    Tensors have shape (1, num_heads, total_tokens, head_dim) — all sequences
    are concatenated into a single "batch" and a BlockMask encodes which
    query-key pairs belong to the same document.
    """

    block_mask: BlockMask

    @classmethod
    def from_config(cls, cfg: Config) -> Self:
        """Build packed Q/K/V and document-ID block mask."""
        q, k, v = torch.randn(
            3,
            1,
            cfg.num_heads,
            cfg.total_tokens,
            cfg.head_dim,
            dtype=cfg.dtype.torch,
            device="cuda",
        ).unbind(0)

        # Build document-ID tensor: each token maps to its sequence index.
        # Bug fix: int8 silently overflows once there are more than 127
        # sequences; int32 covers any realistic batch size.
        document_id = torch.cat(
            [
                torch.full((s,), i, dtype=torch.int32, device="cuda")
                for i, s in enumerate(cfg.seq_lens)
            ]
        )

        def mask_mod(
            _b: torch.Tensor,
            _h: torch.Tensor,
            q_idx: torch.Tensor,
            kv_idx: torch.Tensor,
        ) -> torch.Tensor:
            # Tokens attend only within their own document; optionally also
            # restrict to non-future positions.
            same_doc = document_id[q_idx] == document_id[kv_idx]
            if cfg.is_causal:
                return same_doc & (q_idx >= kv_idx)
            return same_doc

        block_mask = create_block_mask(
            mask_mod,
            B=None,
            H=None,
            Q_LEN=cfg.total_tokens,
            KV_LEN=cfg.total_tokens,
            device="cuda",
        )
        return cls(q=q, k=k, v=v, block_mask=block_mask)
|
|
|
|
|
@dataclass(frozen=True, slots=True)
class TorchVarlenInputs(InputsBase[torch.Tensor]):
    """Packed inputs for varlen_attn with cumulative sequence lengths.

    Tensors have shape (total_tokens, num_heads, head_dim) — all sequences
    are concatenated along the token dimension. ``cu_seqlens`` is the
    cumulative-sum offset array of shape (N+1,).
    """

    cu_seqlens: torch.Tensor
    max_seqlen: int

    @classmethod
    def from_config(cls, cfg: Config) -> Self:
        """Build packed Q/K/V and cumulative sequence length tensor."""
        packed = torch.randn(
            3,
            cfg.total_tokens,
            cfg.num_heads,
            cfg.head_dim,
            dtype=cfg.dtype.torch,
            device="cuda",
        )
        q, k, v = packed.unbind(0)

        # cu_seqlens[i] is the start offset of sequence i; the final entry
        # equals total_tokens.
        lengths = torch.tensor(cfg.seq_lens, dtype=torch.int32, device="cuda")
        cu_seqlens = torch.zeros(cfg.batch_size + 1, dtype=torch.int32, device="cuda")
        cu_seqlens[1:] = lengths
        cu_seqlens = cu_seqlens.cumsum(0, dtype=torch.int32)
        return cls(q=q, k=k, v=v, cu_seqlens=cu_seqlens, max_seqlen=cfg.max_seq_len)
|
|
|
|
|
@jax.tree_util.register_dataclass
@dataclass(frozen=True, slots=True)
class JaxInputs(InputsBase[jax.Array]):
    """Prepared inputs and sequence lengths for JAX attention variants.

    Shape (batch, max_seq_len, num_heads, head_dim) with padding for shorter sequences.
    """

    mask: jax.Array
    """Shape (batch, 1, max_seq_len, max_seq_len) bool array where True indicates valid query-key pairs."""
    seq_lengths: jax.Array
    """Will be used by JAX cuDNN attention to specify valid token counts per sequence."""

    @classmethod
    def from_config(cls, cfg: Config) -> Self:
        """Build padded Q/K/V + mask + seq lengths."""
        seq_lengths = jnp.asarray(cfg.seq_lens, dtype=jnp.int32)
        max_len = cfg.max_seq_len
        shape = (3, cfg.batch_size, max_len, cfg.num_heads, cfg.head_dim)
        # One RNG draw, split along the leading axis into Q, K and V.
        q, k, v = jax.random.normal(jax.random.key(cfg.seed), shape, dtype=cfg.dtype.jax)

        # A token position is valid when it lies inside its sequence's length.
        valid = jnp.arange(max_len)[None, :] < seq_lengths[:, None]
        mask = valid[:, None, :, None] & valid[:, None, None, :]
        if cfg.is_causal:
            mask = mask & jnp.tril(jnp.ones((max_len, max_len), dtype=jnp.bool))
        return cls(q=q, k=k, v=v, mask=mask, seq_lengths=seq_lengths)
|
|
|
|
|
def torch_naive_attention(inputs: TorchPaddedInputs) -> torch.Tensor:
    """Manual attention baseline: explicit matmul, mask, softmax, matmul."""
    q, k, v = inputs.q, inputs.k, inputs.v
    inv_sqrt_d = q.shape[-1] ** -0.5
    logits = torch.matmul(q, k.transpose(-2, -1)) * inv_sqrt_d
    # Invalid query-key pairs get -inf so softmax assigns them zero weight.
    logits = logits.masked_fill(~inputs.mask, float("-inf"))
    weights = torch.softmax(logits, dim=-1)
    return torch.matmul(weights, v)
|
|
|
|
|
def torch_sdpa_padded(inputs: TorchPaddedInputs) -> torch.Tensor:
    """PyTorch SDPA with explicit bool mask on padded tensors."""
    query, key, value = inputs.q, inputs.k, inputs.v
    return F.scaled_dot_product_attention(query, key, value, attn_mask=inputs.mask)
|
|
|
|
|
def torch_sdpa_nested(inputs: TorchNestedInputs) -> torch.Tensor:
    """PyTorch SDPA on NestedTensor inputs.

    NOTE(review): causal masking is hardcoded here rather than taken from
    Config.is_causal — confirm this matches the benchmark's intent.
    """
    query, key, value = inputs.q, inputs.k, inputs.v
    return F.scaled_dot_product_attention(query, key, value, is_causal=True)
|
|
|
|
|
def torch_flex_packed(inputs: TorchFlexInputs) -> torch.Tensor:
    """Flex Attention on packed sequences with document-ID block masking."""
    q, k, v = inputs.q, inputs.k, inputs.v
    return flex_attention(q, k, v, block_mask=inputs.block_mask)
|
|
|
|
|
def torch_varlen_attention(inputs: TorchVarlenInputs) -> torch.Tensor:
    """Variable-length attention via FlashAttention-backed varlen_attn."""
    # Self-attention: queries and keys share the same offsets and max length.
    offsets = inputs.cu_seqlens
    longest = inputs.max_seqlen
    return varlen_attn(
        inputs.q,
        inputs.k,
        inputs.v,
        cu_seq_q=offsets,
        cu_seq_k=offsets,
        max_q=longest,
        max_k=longest,
        is_causal=True,
    )
|
|
|
|
|
def jax_sdpa_xla(inputs: JaxInputs) -> jax.Array: |
|
"""JAX attention with XLA implementation.""" |
|
return jax.nn.dot_product_attention( |
|
inputs.q, inputs.k, inputs.v, mask=inputs.mask, implementation="xla" |
|
) |
|
|
|
|
|
def jax_naive(inputs: JaxInputs) -> jax.Array: |
|
"""Manual attention baseline in JAX.""" |
|
q, k, v = inputs.q, inputs.k, inputs.v |
|
q = q.transpose(0, 2, 1, 3) # (B, H, Q_LEN, D) |
|
k = k.transpose(0, 2, 1, 3) # (B, H, K_LEN, D) |
|
v = v.transpose(0, 2, 1, 3) # (B, H, V_LEN, D) |
|
scale = q.shape[-1] ** -0.5 |
|
scores = (scale * q) @ k.transpose(0, 1, 3, 2) |
|
scores = jnp.where(inputs.mask, scores, float("-inf")) |
|
probs = jax.nn.softmax(scores, axis=-1) |
|
return probs @ v |
|
|
|
|
|
def jax_sdpa_cudnn(inputs: JaxInputs) -> jax.Array:
    """JAX attention with cuDNN implementation and ragged seq lengths.

    NOTE(review): is_causal is hardcoded rather than taken from
    Config.is_causal — confirm this matches the benchmark's intent.
    """
    # Self-attention: query and key/value share the same per-sequence lengths.
    lengths = inputs.seq_lengths
    return jax.nn.dot_product_attention(
        inputs.q,
        inputs.k,
        inputs.v,
        is_causal=True,
        query_seq_lengths=lengths,
        key_value_seq_lengths=lengths,
        implementation="cudnn",
    )
|
|
|
|
|
def run(cfg: Config) -> list[Result]:
    """Run all benchmark groups and collect results."""
    recorder = ResultRecorder()
    recorder.print_header(cfg)

    # (bench class, variant, strategy, attention fn, input factory)
    bench_specs: list[tuple[type[Bench], str, str, AttentionFn, InputFactory[Any]]] = [
        (TorchBench, "torch_naive", "padded", torch_naive_attention, TorchPaddedInputs.from_config),
        (TorchBench, "torch_sdpa", "padded", torch_sdpa_padded, TorchPaddedInputs.from_config),
        (TorchBench, "torch_sdpa", "nested", torch_sdpa_nested, TorchNestedInputs.from_config),
        (TorchBench, "torch_flex", "packed", torch_flex_packed, TorchFlexInputs.from_config),
        (TorchBench, "torch_varlen", "packed", torch_varlen_attention, TorchVarlenInputs.from_config),
        (JaxBench, "jax_cudnn", "padded using seq_lens", jax_sdpa_cudnn, JaxInputs.from_config),
        (JaxBench, "jax_xla", "padded", jax_sdpa_xla, JaxInputs.from_config),
        (JaxBench, "jax_naive", "padded", jax_naive, JaxInputs.from_config),
    ]

    for bench_cls, variant, strategy, attention_fn, input_factory in bench_specs:
        print(f"Running {variant} ({strategy})...")
        # Each bench is a context manager: it builds inputs on entry and
        # frees device memory on exit before the next variant runs.
        with bench_cls(variant, strategy, attention_fn, input_factory, cfg) as bench:
            recorder.add(bench.profile())

    recorder.print_results()
    return recorder.results
|
|
|
|
|
def parse_args() -> Config:
    """Parse CLI options into a Config."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--num-heads", type=int, default=32)
    parser.add_argument("--head-dim", type=int, default=128)
    parser.add_argument("--seq-min", type=int, default=128)
    parser.add_argument("--seq-max", type=int, default=2048)
    parser.add_argument(
        "--dtype", choices=list(DType), type=DType, default=DType.BFLOAT16
    )
    parser.add_argument("--warmup", type=int, default=2)
    # NOTE(review): this default (30) disagrees with Config.iters default (20);
    # the CLI value always wins here — confirm which was intended.
    parser.add_argument("--iters", type=int, default=30)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    field_names = (
        "batch_size",
        "num_heads",
        "head_dim",
        "seq_min",
        "seq_max",
        "dtype",
        "warmup",
        "iters",
        "seed",
    )
    return Config(**{name: getattr(args, name) for name in field_names})
|
|
|
|
|
def main() -> None:
    """CLI entrypoint: require CUDA, then run the full benchmark suite."""
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA not available")
    run(parse_args())
|
|
|
|
|
# Script entry point: run the benchmark when executed directly.
if __name__ == "__main__":
    main()