Skip to content

Instantly share code, notes, and snippets.

@shunting314
Created December 8, 2025 23:56
Show Gist options
  • Select an option

  • Save shunting314/2ce3cae23f98b17da67b783e059f7102 to your computer and use it in GitHub Desktop.

Select an option

Save shunting314/2ce3cae23f98b17da67b783e059f7102 to your computer and use it in GitHub Desktop.
def triton_per_fused__to_copy_add_div_expand_mul_pow_squeeze_sum_unsqueeze_0(
        in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ws_ptr, xnumel, r0_numel,
        XBLOCK : tl.constexpr, RSPLIT_SIZE : tl.constexpr, NUM_STAGES : tl.constexpr):
    """Fused row-wise reduction + elementwise update, plus a per-program column sum.

    Each program owns the rows starting at ``program_id(0) * RSPLIT_SIZE`` and
    walks them ``XBLOCK`` rows at a time. For every row tile it:

    1. loads ``a = in_ptr0[row, col]``, ``w = in_ptr1[col]``,
       ``b = in_ptr2[row, col]`` and ``s = in_ptr3[row]``;
    2. reduces ``a * w * b`` over the columns of each row;
    3. writes ``a * w * s + 2 * (row_sum * -0.5 * s**3 * c * b)`` to
       ``out_ptr1[row, col]`` (the factor 2 is realised by adding the term
       twice), where ``c = 0.0013020833333333333`` (numerically 1/768);
    4. adds the sum over the tile's rows of ``a * (b * s)`` to a per-column
       accumulator.

    After the loop the accumulator is stored to
    ``ws_ptr[program_id(0) * r0_numel + col]``.

    Notes:
        * The ``xnumel`` / ``r0_numel`` arguments are overwritten with the
          constants 32768 and 768, and the row stride 768 is hard-coded in the
          addressing, so the values passed by the caller are ignored.
        * Columns are masked against ``r0_numel``; rows are NOT masked, so the
          kernel assumes every row it touches is in bounds (presumably the
          launcher picks RSPLIT_SIZE / XBLOCK accordingly -- TODO confirm).
        * NOTE(review): the arithmetic looks like an RMSNorm-style backward
          pass with a partial weight-gradient reduction -- confirm against the
          graph that generated this kernel.
    """
    # Problem sizes this kernel is specialised for; they shadow the arguments.
    xnumel = 32768
    r0_numel = 768
    R0_BLOCK: tl.constexpr = 1024

    # Scalar constants used by the per-row correction term.
    neg_half = -0.5
    inv_cols = 0.0013020833333333333

    # Row indices form a column vector, column indices a row vector, so that
    # their combination broadcasts to an [XBLOCK, R0_BLOCK] tile.
    row_start = tl.program_id(0) * RSPLIT_SIZE
    rows = row_start + tl.arange(0, XBLOCK)[:, None]
    cols = tl.arange(0, R0_BLOCK)[None, :]
    col_mask = cols < r0_numel

    # Per-column running sum carried across loop iterations.
    col_acc = tl.full([R0_BLOCK], 0, tl.float32)[None, :]

    # The last program may own fewer than RSPLIT_SIZE rows.
    rows_todo = min(RSPLIT_SIZE, xnumel - row_start)
    for _ in tl.range(0, rows_todo, XBLOCK, num_stages=NUM_STAGES):
        cur_rows = rows
        rows += XBLOCK
        tile_offsets = cols + 768*cur_rows

        a = tl.load(in_ptr0 + tile_offsets, col_mask, other=0.0).to(tl.float32)
        col_vec = tl.load(in_ptr1 + cols, col_mask, eviction_policy='evict_last', other=0.0)
        b = tl.load(in_ptr2 + tile_offsets, col_mask, other=0.0).to(tl.float32)
        row_val = tl.load(in_ptr3 + cur_rows, None, eviction_policy='evict_last')

        # Row-wise reduction of a * col_vec * b over the valid columns.
        a_scaled = a * col_vec
        prod = tl.broadcast_to(a_scaled * b, [XBLOCK, R0_BLOCK])
        row_sum = tl.sum(tl.where(col_mask, prod, 0), 1)[:, None].to(tl.float32)

        # Elementwise output: direct term plus the correction term added twice.
        direct = a_scaled * row_val
        row_val_cubed = (row_val * row_val) * row_val
        correction = (((row_sum * neg_half) * row_val_cubed) * inv_cols) * b
        result = ((direct + correction) + correction).to(tl.float32)
        tl.store(out_ptr1 + tile_offsets, result, col_mask)

        # Column-wise partial sum over this tile's rows.
        col_acc = col_acc + tl.sum(a * (b * row_val), 0)

    # One workspace row per program; a later step presumably combines them.
    ws_row = tl.program_id(0) + 0 * tl.num_programs(0)
    tl.store(ws_ptr + ws_row * r0_numel + cols, col_acc, col_mask)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment