import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
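
# Inductor-generated looped-reduction ("red") kernel that fuses the reduction
# part of two RMSNorm backward passes: for each of the 4096 hidden units it
# sums grad * normalized_input over the 8192-element reduction dimension.
# Given the (4096,) bf16 outputs, these are presumably the two RMSNorm
# weight gradients.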
@triton_heuristics.reduction(
size_hints={'x': 4096, 'r0_': 8192},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*fp32', 'in_ptr5': '*bf16', 'in_ptr6': '*bf16', 'in_ptr7': '*bf16', 'in_ptr8': '*fp32', 'out_ptr2': '*bf16', 'out_ptr3': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__fused_rms_norm_backward__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_view_13', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 9, 'num_store': 2, 'num_reduction': 2, 'backend_hash': '58367EC428ADC15B85CB9CF138B580A95422950F92C44A257A91C53B1E76C9F7', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': True, 'tiling_scores': {'x': 469794816, 'r0_': 65536}, 'kernel_num_gb': 0.469843968, 'kernel_flop': 0}
)
@triton.jit
def triton_red_fused__fused_rms_norm_backward__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_view_13(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, out_ptr2, out_ptr3, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 4096
    r0_numel = 8192
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = xindex
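    # Two independent fp32 accumulators, one per fused reduction; partial sums
    # are kept in fp32 across the loop even though inputs and outputs are bf16.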
    _tmp17 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    _tmp33 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
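    # Walk the 8192-element reduction dimension in R0_BLOCK-sized chunks.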
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
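        # The (x0 + 4096*r0_1) tensors are streamed tile-by-tile (evict_first);
        # the per-row fp32 statistics in in_ptr4/in_ptr8 are loaded with
        # evict_last, presumably so they stay cached for reuse across blocks.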
        tmp0 = tl.load(in_ptr0 + (x0 + 4096*r0_1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_ptr1 + (x0 + 4096*r0_1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp4 = tl.load(in_ptr2 + (x0 + 4096*r0_1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp5 = tl.load(in_ptr3 + (x0 + 4096*r0_1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp8 = tl.load(in_ptr4 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0)
        tmp19 = tl.load(in_ptr5 + (x0 + 4096*r0_1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp20 = tl.load(in_ptr6 + (x0 + 4096*r0_1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp22 = tl.load(in_ptr7 + (x0 + 4096*r0_1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp26 = tl.load(in_ptr8 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0)
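        # First reduction: rsqrt(in_ptr4 / 4096 + eps) is the per-row inverse
        # RMS (in_ptr4 presumably holds a saved sum of squares; eps is 1e-5
        # rounded to fp32). The summed gradients (tmp0 + tmp1) are multiplied
        # by the normalized input ((tmp4 + tmp5) * inv_rms) and accumulated
        # into _tmp17.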
        tmp2 = tmp0 + tmp1
        tmp3 = tmp2.to(tl.float32)
        tmp6 = tmp4 + tmp5
        tmp7 = tmp6.to(tl.float32)
        tmp9 = 4096.0
        tmp10 = (tmp8 / tmp9)
        tmp11 = 9.999999747378752e-06
        tmp12 = tmp10 + tmp11
        tmp13 = libdevice.rsqrt(tmp12)
        tmp14 = tmp7 * tmp13
        tmp15 = tmp3 * tmp14
        tmp16 = tl.broadcast_to(tmp15, [XBLOCK, R0_BLOCK])
        tmp18 = _tmp17 + tmp16
        _tmp17 = tl.where(r0_mask, tmp18, _tmp17)
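        # Second reduction: same pattern with its own statistics (in_ptr8), a
        # three-way gradient sum (tmp19 + tmp20 + tmp22), and in_ptr2 alone as
        # the normalized operand; accumulated into _tmp33.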
        tmp21 = tmp19 + tmp20
        tmp23 = tmp21 + tmp22
        tmp24 = tmp23.to(tl.float32)
        tmp25 = tmp4.to(tl.float32)
        tmp27 = (tmp26 / tmp9)
        tmp28 = tmp27 + tmp11
        tmp29 = libdevice.rsqrt(tmp28)
        tmp30 = tmp25 * tmp29
        tmp31 = tmp24 * tmp30
        tmp32 = tl.broadcast_to(tmp31, [XBLOCK, R0_BLOCK])
        tmp34 = _tmp33 + tmp32
        _tmp33 = tl.where(r0_mask, tmp34, _tmp33)
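    # Collapse each accumulator across the reduction lanes and emit one value
    # per hidden unit.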
    tmp17 = tl.sum(_tmp17, 1)[:, None]
    tmp33 = tl.sum(_tmp33, 1)[:, None]
    tmp35 = tmp17.to(tl.float32)
    tmp36 = tmp33.to(tl.float32)
    tl.store(out_ptr2 + (x0), tmp35, None)
    tl.store(out_ptr3 + (x0), tmp36, None)
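
# Standalone benchmark harness emitted alongside the kernel: get_args() builds
# random tensors matching the original shapes, strides, and dtypes.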
def get_args():
    arg_0 = rand_strided((8192, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg_1 = rand_strided((8192, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg_2 = rand_strided((1, 8192, 4096), (33554432, 4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg_3 = rand_strided((8192, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg_4 = rand_strided((1, 8192, 1), (8192, 1, 8192), device='cuda:0', dtype=torch.float32)
    arg_5 = rand_strided((8192, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg_6 = rand_strided((8192, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg_7 = rand_strided((8192, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg_8 = rand_strided((1, 8192, 1), (8192, 1, 8192), device='cuda:0', dtype=torch.float32)
    arg_9 = rand_strided((4096,), (1,), device='cuda:0', dtype=torch.bfloat16)
    arg_10 = rand_strided((4096,), (1,), device='cuda:0', dtype=torch.bfloat16)
    return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7, arg_8, arg_9, arg_10, 4096, 8192,
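
# Launch the compiled kernel on device 0's current CUDA stream.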
def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_red_fused__fused_rms_norm_backward__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_view_13.run(*args, stream=stream0)
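
# Time every autotuning config, not just the one the cache selected.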
def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_red_fused__fused_rms_norm_backward__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_view_13.benchmark_all_configs(*args)
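
# Benchmark the launch and report achieved bandwidth against the kernel's
# memory-traffic estimate (kernel_num_gb in inductor_meta above).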
if __name__ == '__main__':
    from torch._inductor.runtime.benchmarking import benchmarker
    args = get_args()
    ms = benchmarker.benchmark_gpu(lambda: call(args), rep=40)
    num_gb = 0.469843968
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")