"""
NOTE: the script right now only benchmark the latency of the attention
kernel itself. The following things are excluded
- add new key/value to the cache
- setup BlockMask for flex-attention
- etc.
"""
import math
import torch
from torch import nn
from torch import distributed
import contextlib
import os
# Configure vLLM via env vars before importing it (safest ordering).
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
os.environ["VLLM_ATTENTION_BACKEND"] = os.getenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
from vllm import LLM, SamplingParams
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
from torch._dynamo.testing import rand_strided
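# ---------------------------------------------------------------------------
# Hedged illustration (not part of the original script) of the NOTE above:
# only the attention kernel sits inside the timed region, while BlockMask
# construction (and, in a real server, KV-cache updates) happens up front.
# Assumes PyTorch's flex_attention API and triton.testing.do_bench; the shapes
# below are made up for illustration.
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
import triton.testing

def _causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

_q = torch.randn(1, 8, 4096, 128, device="cuda", dtype=torch.bfloat16)
_k = torch.randn(1, 8, 4096, 128, device="cuda", dtype=torch.bfloat16)
_v = torch.randn(1, 8, 4096, 128, device="cuda", dtype=torch.bfloat16)

# Setup work that the benchmark deliberately excludes from timing.
_block_mask = create_block_mask(_causal, 1, 8, 4096, 4096, device="cuda")
_compiled_attn = torch.compile(flex_attention)

# Only the attention kernel itself is timed.
_ms = triton.testing.do_bench(
    lambda: _compiled_attn(_q, _k, _v, block_mask=_block_mask)
)
print(f"flex-attention kernel latency: {_ms:.3f} ms")
# ---------------------------------------------------------------------------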
diff --git a/src/liger_kernel/ops/rms_norm.py b/src/liger_kernel/ops/rms_norm.py
index af867f4..5f7f2ac 100644
--- a/src/liger_kernel/ops/rms_norm.py
+++ b/src/liger_kernel/ops/rms_norm.py
@@ -450,6 +450,7 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
     elif X.device.type == "xpu":
         sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
+        sm_count = sm_count * 32
     # fp32 for numerical stability especially.
def triton_per_fused__to_copy_add_div_expand_mul_pow_squeeze_sum_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ws_ptr, xnumel, r0_numel, XBLOCK : tl.constexpr, RSPLIT_SIZE : tl.constexpr, NUM_STAGES : tl.constexpr):
    xnumel = 32768
    r0_numel = 768
    R0_BLOCK: tl.constexpr = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * RSPLIT_SIZE
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
diff --git a/vllm/benchmarks/latency.py b/vllm/benchmarks/latency.py
index b4f175183..b539fd1f8 100644
--- a/vllm/benchmarks/latency.py
+++ b/vllm/benchmarks/latency.py
@@ -101,7 +101,7 @@ def main(args: argparse.Namespace):
     sampling_params = SamplingParams(
         n=args.n,
-        temperature=1.0,
+        temperature=0.0,
import torch

# Enable Inductor combo kernels (horizontal fusion of independent kernels) and
# disable the FX graph cache so code is regenerated rather than reused.
torch._inductor.config.combo_kernels = True
torch._inductor.config.fx_graph_cache = False

@torch.compile
def f(x, y):
    return x + 1, y * 2

# x = torch.randn(1024, device="cuda")
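# Hedged note (assumption, not part of the gist): combo_kernels asks Inductor to
# horizontally fuse independent kernels, so the two independent outputs of f()
# above can be emitted as a single generated Triton kernel. One way to inspect
# the code Inductor actually generates is to enable output-code logging:
#
#   TORCH_LOGS="output_code" python script.py
#
# Quick smoke test (shapes made up for illustration):
# out0, out1 = f(torch.randn(1024, device="cuda"), torch.randn(1024, device="cuda"))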
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
@triton_heuristics.pointwise(
size_hints={'x': 67108864}, tile_hint=TileHint.DEFAULT,
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
@triton_heuristics.pointwise(
size_hints={'x': 16777216}, tile_hint=TileHint.DEFAULT,