# Gist: https://gist.github.com/shunting314/1661aed073d5f0811cd3e0bd9020f6ef
# Author: @shunting314 — created March 9, 2026
# Standalone repro/benchmark harness for an Inductor-generated Triton kernel.
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
# Inductor-generated persistent ("mix-order") reduction kernel.
# Each program instance owns a split of RSPLIT_SIZE consecutive x-rows and,
# for every row, reduces over the whole r0 axis (r0_numel = 960) in one tile,
# while simultaneously accumulating per-column (x-axis) sums into a
# per-program workspace slice for a second-stage reduction in another kernel.
@triton_heuristics.persistent_reduction(
    size_hints={'x': 524288, 'r0_': 1024},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*fp32', 'in_ptr5': '*fp32', 'out_ptr2': '*bf16', 'out_ptr3': '*bf16', 'ws_ptr': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'RSPLIT_SIZE': 'constexpr', 'NUM_STAGES': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'MixOrderReductionGrid', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_as_strided_23', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 6, 'num_store': 0, 'num_reduction': 2, 'backend_hash': '891A923834BDB77D597EFFDCFF1ED74A6C07A5F7B08559989641F4DBBF828EA6', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'is_fbcode': True, 'RSPLIT_SIZE': 128, 'kernel_num_gb': 2.949064952, 'kernel_flop': 0}
)
@triton.jit
def triton_per_fused_as_strided_23(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, out_ptr2, out_ptr3, ws_ptr, xnumel, r0_numel, XBLOCK : tl.constexpr, RSPLIT_SIZE : tl.constexpr, NUM_STAGES : tl.constexpr):
    # Compile-time problem sizes baked in by Inductor (they shadow the runtime args).
    xnumel = 511279
    r0_numel = 960
    # R0_BLOCK is the next power of two >= r0_numel, so the whole r0 axis fits in one tile.
    R0_BLOCK: tl.constexpr = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    # This program covers rows [xoffset, xoffset + RSPLIT_SIZE) of the x axis.
    xoffset = tl.program_id(0) * RSPLIT_SIZE
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]   # shape [XBLOCK, 1]: row indices
    r0_index = tl.arange(0, R0_BLOCK)[None, :]         # shape [1, R0_BLOCK]: column indices
    r0_offset = 0
    # Masks off the R0_BLOCK - r0_numel padding lanes (1024 - 960 = 64 columns).
    r0_mask = r0_index < r0_numel
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    # Cross-row accumulators, one fp32 value per r0 column; flushed to ws_ptr at the end.
    accum0 = tl.full([R0_BLOCK], 0, tl.float32)[None, :]
    accum1 = tl.full([R0_BLOCK], 0, tl.float32)[None, :]
    # The last program may own fewer than RSPLIT_SIZE rows.
    split_size = min(RSPLIT_SIZE, xnumel - xoffset)
    for _ in tl.range(0, split_size, XBLOCK, num_stages=NUM_STAGES):
        xmask = xindex < xnumel
        x0 = xindex
        xindex += XBLOCK  # advance to the next XBLOCK rows of this program's split
        # in_ptr0/in_ptr2 are indexed (x, r0) with row stride 960; in_ptr1/in_ptr5 are
        # per-column (r0) vectors; in_ptr3/in_ptr4 are per-row scalars.
        tmp0 = tl.load(in_ptr0 + (r0_1 + 960*x0), r0_mask & xmask, other=0.0).to(tl.float32)
        tmp2 = tl.load(in_ptr1 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0)
        tmp8 = tl.load(in_ptr2 + (r0_1 + 960*x0), r0_mask & xmask, other=0.0).to(tl.float32)
        tmp10 = tl.load(in_ptr3 + (x0), xmask, eviction_policy='evict_last')
        tmp12 = tl.load(in_ptr4 + (x0), xmask, eviction_policy='evict_last')
        tmp20 = tl.load(in_ptr5 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0)
        tmp1 = tmp0.to(tl.float32)
        tmp3 = tmp1 * tmp2                 # in0 scaled by the per-column vector in1
        tmp4 = tl.broadcast_to(tmp3, [XBLOCK, R0_BLOCK])
        tmp6 = tl.where(r0_mask & xmask, tmp4, 0)
        tmp7 = tl.sum(tmp6, 1)[:, None].to(tl.float32)   # per-row sum over r0
        tmp9 = tmp8.to(tl.float32)
        tmp11 = tmp9 - tmp10               # in2 centered by the per-row scalar in3
        tmp13 = tmp11 * tmp12              # ... scaled by the per-row scalar in4
        tmp14 = tmp3 * tmp13
        tmp15 = tl.broadcast_to(tmp14, [XBLOCK, R0_BLOCK])
        tmp17 = tl.where(r0_mask & xmask, tmp15, 0)
        tmp18 = tl.sum(tmp17, 1)[:, None].to(tl.float32)  # second per-row sum over r0
        tmp19 = tmp13 * tmp2
        tmp21 = tmp19 + tmp20
        tmp22 = tmp21.to(tl.float32)
        # NOTE(review): the 1/960 factor and the (960*g - sum(g) - n*sum(g*n)) shape below
        # look like a layer-norm backward over the 960-wide axis — confirm against the FX graph.
        tmp23 = tl.full([1, 1], 0.0010416666666666667, tl.float32)
        tmp24 = tmp12 * tmp23
        tmp25 = tl.full([1, 1], 960.0, tl.float32)
        tmp26 = tmp3 * tmp25
        tmp27 = tmp26 - tmp7
        tmp28 = tmp13 * tmp18
        tmp29 = tmp27 - tmp28
        tmp30 = tmp24 * tmp29
        tmp31 = tmp30.to(tl.float32)
        tmp32 = tmp1 * tmp13
        # Elementwise outputs, cast back to bf16 by the '*bf16' signature of the out pointers.
        tl.store(out_ptr2 + (r0_1 + 960*x0), tmp22, r0_mask & xmask)
        tl.store(out_ptr3 + (r0_1 + 960*x0), tmp31, r0_mask & xmask)
        # Cross-row (axis-0) partial sums accumulated across loop iterations.
        tmp33 = tl.sum(tmp32, 0)
        tmp34 = accum0 + tmp33
        accum0 = tmp34
        tmp35 = tl.sum(tmp1, 0)
        tmp36 = accum1 + tmp35
        accum1 = tmp36
    # Workspace layout: [reduction_id][program_id][r0_numel]; with 3995 programs this
    # matches the 2 * 3995 * 960 = 7,670,400-element buffer allocated in get_args.
    # A follow-up kernel is expected to reduce over the program dimension.
    tl.store(ws_ptr + (tl.program_id(0) + 0 * tl.num_programs(0)) * r0_numel + r0_index, accum0, r0_mask)
    tl.store(ws_ptr + (tl.program_id(0) + 1 * tl.num_programs(0)) * r0_numel + r0_index, accum1, r0_mask)
def get_args():
    """Build the kernel's argument tuple for a repro run.

    The ``rand_strided`` tensors document the expected shapes/strides/dtypes
    of the nine pointer arguments.  The active return statement instead
    reloads the originally-captured inputs (args 0..6) from ``./saved.pt``
    and appends a fresh second output buffer, a zeroed workspace, and the
    two size arguments ``xnumel=511279`` and ``r0_numel=960``.
    """
    arg_0 = rand_strided((511279, 960), (960, 1), device='cuda:0', dtype=torch.bfloat16)
    arg_1 = rand_strided((960,), (1,), device='cuda:0', dtype=torch.float32)
    arg_2 = rand_strided((511279, 960), (960, 1), device='cuda:0', dtype=torch.bfloat16)
    arg_3 = rand_strided((511279, 1), (1, 1), device='cuda:0', dtype=torch.float32)
    arg_4 = rand_strided((511279, 1), (1, 1), device='cuda:0', dtype=torch.float32)
    arg_5 = rand_strided((960,), (1,), device='cuda:0', dtype=torch.float32)
    arg_6 = rand_strided((511279, 960), (960, 1), device='cuda:0', dtype=torch.bfloat16)
    arg_7 = rand_strided((511279, 960), (960, 1), device='cuda:0', dtype=torch.bfloat16)
    # Workspace for the cross-program partial sums: 2 reductions * 3995 programs * 960.
    arg_8 = torch.zeros(7670400, device='cuda:0', dtype=torch.float32)
    # Fully synthetic variant (no saved.pt needed):
    # return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7, arg_8, 511279, 960,
    # BUG FIX: the trailing r0_numel had been truncated to 96; the kernel hard-codes
    # r0_numel = 960 and the commented return above passes 960.
    return *torch.load("./saved.pt"), arg_7, arg_8, 511279, 960
def call(args):
    """Launch the fused kernel on device 0 using the current raw CUDA stream."""
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        triton_per_fused_as_strided_23.run(*args, stream=get_raw_stream(0))
def benchmark_all_configs(args):
    """Benchmark every autotune config of the kernel on device 0; return its timings."""
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        timings = triton_per_fused_as_strided_23.benchmark_all_configs(*args)
    return timings
if __name__ == '__main__':
    # NOTE(review): `benchmarker` is imported but never referenced below —
    # presumably kept for interactive use alongside benchmark_all_configs.
    from torch._inductor.runtime.benchmarking import benchmarker
    args = get_args()
    # Sanity-check every tensor argument for NaN/Inf before launching the kernel.
    for i, x in enumerate(args):
        if isinstance(x, torch.Tensor):
            print(f"input {i} is nan/inf? {x.isnan().any()} {x.isinf().any()}")
    # NOTE(review): `out` is unused; args[-4] is the second output buffer (arg_7),
    # presumably bound here for inspection in a debugger.
    out = args[-4]
    call(args)
    # Re-check after the launch to see whether the kernel wrote NaN/Inf anywhere.
    for i, x in enumerate(args):
        if isinstance(x, torch.Tensor):
            print(f"post run input {i} is nan/inf? {x.isnan().any()} {x.isinf().any()}")
# End of gist.