from __future__ import annotations
import torch
import helion.language as hl
import triton
import triton.language as tl
from torch._inductor.runtime.triton_compat import libdevice
from helion.runtime import default_launcher as _default_launcher
import __main__ as _source_module
@triton.jit
def symm_mem_sync(signal_pad_ptrs, block_id, rank: tl.constexpr, world_size: tl.constexpr, hasPreviousMemAccess: tl.constexpr=False, hasSubsequentMemAccess: tl.constexpr=False) -> None:
"""
Synchronizes blocks with matching block_id across participating devices.
Note: the function itself is not a system level barrier/fence. It is a
building block for expressing different synchronization patterns.
Pattern 0: Ensures that all writes to symm_mem buffers from previous
kernels across all devices are visible to the current kernel:
symm_mem_sync(..., hasPreviousMemAccess=False, hasSubsequentMemAccess=True)
Pattern 1: Ensures that all writes to symm_mem buffers from the current
block are visible to all remote blocks with matching blockIdx:
symm_mem_sync(..., hasPreviousMemAccess=True, hasSubsequentMemAccess=True)
Pattern 2: Ensures that symm_mem buffers read by the current kernel are safe
for writing by subsequent kernels across all devices.
symm_mem_sync(..., hasPreviousMemAccess=True, hasSubsequentMemAccess=False)
CUDA graph friendliness:
This barrier operates through atomic operations on a zero-filled signal
pad, which resets to a zero-filled state after each successful
synchronization. This design eliminates the need for incrementing a
flag from host.
"""
    if block_id is None:
        block_id = _get_flat_bid()
    flat_tid = _get_flat_tid()
    remote_ranks = tl.arange(0, world_size)
    signal_pad_ptrs = signal_pad_ptrs.to(tl.pointer_type(tl.uint64))
    remote_signal_pad_addrs = tl.load(signal_pad_ptrs + remote_ranks).to(tl.pointer_type(tl.uint32))
    send_addrs = remote_signal_pad_addrs + block_id * world_size + rank
    local_signal_pad_addr = tl.load(signal_pad_ptrs + rank).to(tl.pointer_type(tl.uint32))
    wait_addrs = local_signal_pad_addr + block_id * world_size + remote_ranks
    if hasPreviousMemAccess:
        tl.debug_barrier()
    if flat_tid < world_size:
        _send_signal(send_addrs, 'release' if hasPreviousMemAccess else 'relaxed')
        _wait_signal(wait_addrs, 'acquire' if hasSubsequentMemAccess else 'relaxed')
    if hasSubsequentMemAccess:
        tl.debug_barrier()
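
# Barrier mechanics for symm_mem_sync above: the first `world_size` threads of the
# block each target one peer rank. Each such thread sets the (block_id, rank) slot
# on that peer's signal pad to 1 via _send_signal, then spins in _wait_signal until
# the matching slot on the local pad (written by the peer) can be flipped from 1
# back to 0. Since every slot ends the exchange at 0, the signal pad needs no
# host-side reset between launches, which is what makes the barrier CUDA-graph
# friendly (see the docstring above).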
@triton.jit
def _get_flat_bid():
    # Flatten the 3-D program id into a single linear block id (x fastest).
    return tl.program_id(2) * tl.num_programs(1) * tl.num_programs(0) + tl.program_id(1) * tl.num_programs(0) + tl.program_id(0)
@triton.jit
def _get_flat_tid():
    # Flatten the 3-D thread id within the block into a single linear thread id.
    tid_x, tid_y, tid_z = _get_tid()
    ntid_x, ntid_y, _ = _get_ntid()
    return tid_z * ntid_y * ntid_x + tid_y * ntid_x + tid_x
@triton.jit
def _get_tid():
    # Read the hardware thread-index registers (%tid.{x,y,z}) via inline PTX.
    return tl.inline_asm_elementwise('\n mov.u32 $0, %tid.x;\n mov.u32 $1, %tid.y;\n mov.u32 $2, %tid.z;\n ', '=r,=r,=r', [], dtype=(tl.uint32, tl.uint32, tl.uint32), is_pure=True, pack=1)
@triton.jit
def _get_ntid():
    # Read the hardware block-dimension registers (%ntid.{x,y,z}) via inline PTX.
    return tl.inline_asm_elementwise('\n mov.u32 $0, %ntid.x;\n mov.u32 $1, %ntid.y;\n mov.u32 $2, %ntid.z;\n ', '=r,=r,=r', [], dtype=(tl.uint32, tl.uint32, tl.uint32), is_pure=True, pack=1)
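
# The four helpers above produce flat block/thread indices: _get_tid and _get_ntid
# read the hardware %tid/%ntid registers through inline PTX, and _get_flat_bid /
# _get_flat_tid combine the three components in row-major order (x fastest).
# symm_mem_sync uses the flat thread id to pick which lanes perform the per-rank
# signal exchange.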
@triton.jit
def _send_signal(addrs, sem: tl.constexpr) -> None:
    # Spin until the flag at `addrs` is flipped from 0 to 1 with the requested
    # memory-ordering semantics (`sem` is 'release' or 'relaxed').
    tl.inline_asm_elementwise(f'\n {{\n .reg .u32 %tmp32_<1>;\n .reg .pred %p<1>;\n\n send_signal:\n atom.global.{sem}.sys.cas.b32 %tmp32_0, [$1], 0, 1;\n setp.eq.u32 %p0, %tmp32_0, 0;\n @!%p0 bra send_signal;\n }}\n ', '=r, l', [addrs], dtype=addrs.dtype, is_pure=False, pack=1)
@triton.jit
def _wait_signal(addrs, sem: tl.constexpr) -> None:
    # Spin until the flag at `addrs` is flipped from 1 back to 0 with the
    # requested memory-ordering semantics (`sem` is 'acquire' or 'relaxed').
    tl.inline_asm_elementwise(f'\n {{\n .reg .u32 %tmp32_<1>;\n .reg .pred %p<1>;\n\n wait_signal:\n atom.global.sys.{sem}.cas.b32 %tmp32_0, [$1], 1, 0;\n setp.eq.u32 %p0, %tmp32_0, 1;\n @!%p0 bra wait_signal;\n }}\n ', '=r, l', [addrs], dtype=tl.int32, is_pure=False, pack=1)
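
# _send_signal and _wait_signal are the two halves of the handshake, each a
# compare-and-swap spin loop in inline PTX: the sender retries until it flips a
# peer's flag from 0 to 1 (with release ordering when prior memory accesses must
# become visible first), and the waiter retries until it flips its local flag from
# 1 back to 0 (with acquire ordering when subsequent accesses must observe the
# peer's writes). The 0 -> 1 -> 0 round trip leaves the signal pad zero-filled
# after every barrier.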
_BLOCK_SIZE_0 = tl.constexpr(8)
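
# The generated device kernel below handles a tile of _BLOCK_SIZE_0 rows per
# program. After the symm_mem_sync barrier it sweeps the 4096-wide reduction
# dimension twice: the first pass copies x into the local symmetric buffer,
# accumulates bias plus the buffers from all 8 ranks, and sums the squares for
# the variance; the second pass rebuilds the same reduced row, scales it by
# rsqrt(variance + eps) and the RMSNorm weight, and stores the result to `output`.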
@triton.jit
def _helion_one_shot_allreduce_bias_rmsnorm_kernel(x, symm_mem_buffer, bias, buffer_tuple_item_0, buffer_tuple_item_1, buffer_tuple_item_2, buffer_tuple_item_3, buffer_tuple_item_4, buffer_tuple_item_5, buffer_tuple_item_6, buffer_tuple_item_7, weight, output, signal_pad_ptrs, _REDUCTION_BLOCK_1: tl.constexpr):
    # src[allreduce_bias_rmsnorm.py:56]: for tile_n in hl.tile(N):
    pid_0 = tl.program_id(0)
    offset_0 = pid_0 * _BLOCK_SIZE_0
    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
    # src[allreduce_bias_rmsnorm.py:63]: hl.triton_kernel(
    # src[allreduce_bias_rmsnorm.py:64]: symm_mem_sync,
    # src[allreduce_bias_rmsnorm.py:65]: args=(signal_pad_ptrs, None, RANK, WORLD_SIZE, True, True),
    # src[allreduce_bias_rmsnorm.py:63-67]: ...
    symm_mem_sync(signal_pad_ptrs, None, 0, 8, True, True)
    # src[allreduce_bias_rmsnorm.py:78]: variance = torch.mean(acc * acc, dim=-1, keepdim=True)
    variance_extra_acc = tl.full([_BLOCK_SIZE_0, _REDUCTION_BLOCK_1], 0, tl.float32)
    # src[allreduce_bias_rmsnorm.py:58]: symm_mem_buffer[tile_n, :] = x[tile_n, :]
    for roffset_1 in tl.range(0, 4096, _REDUCTION_BLOCK_1):
        rindex_1 = roffset_1 + tl.arange(0, _REDUCTION_BLOCK_1).to(tl.int32)
        load = tl.load(x + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), None)
        tl.store(symm_mem_buffer + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), load, None)
        # src[allreduce_bias_rmsnorm.py:71]: acc = symm_mem_buffer[tile_n, :].to(torch.float32) * 0.0 + bias[None, :].to(
        load_1 = tl.load(symm_mem_buffer + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), None)
        v_0 = tl.full([], 0.0, tl.float32)
        v_1 = load_1 * v_0
        load_2 = tl.load(bias + rindex_1[None, :] * 1, None)
        # src[allreduce_bias_rmsnorm.py:71]: acc = symm_mem_buffer[tile_n, :].to(torch.float32) * 0.0 + bias[None, :].to(
        # src[allreduce_bias_rmsnorm.py:72]: torch.float32
        # src[allreduce_bias_rmsnorm.py:73]: )
        v_2 = v_1 + load_2
        # src[allreduce_bias_rmsnorm.py:75]: acc = acc + remote_buffer[tile_n, :].to(torch.float32)
        load_3 = tl.load(buffer_tuple_item_0 + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), None)
        v_3 = v_2 + load_3
        load_4 = tl.load(buffer_tuple_item_1 + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), None)
        v_4 = v_3 + load_4
        load_5 = tl.load(buffer_tuple_item_2 + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), None)
        v_5 = v_4 + load_5
        load_6 = tl.load(buffer_tuple_item_3 + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), None)
        v_6 = v_5 + load_6
        load_7 = tl.load(buffer_tuple_item_4 + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), None)
        v_7 = v_6 + load_7
        load_8 = tl.load(buffer_tuple_item_5 + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), None)
        v_8 = v_7 + load_8
        load_9 = tl.load(buffer_tuple_item_6 + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), None)
        v_9 = v_8 + load_9
        load_10 = tl.load(buffer_tuple_item_7 + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), None)
        v_10 = v_9 + load_10
        # src[allreduce_bias_rmsnorm.py:78]: variance = torch.mean(acc * acc, dim=-1, keepdim=True)
        v_11 = v_10 * v_10
        v_12 = variance_extra_acc + v_11
        variance_extra_acc = v_12
    variance_extra = tl.cast(tl.reshape(tl.sum(variance_extra_acc, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
    v_13 = 4096
    v_14 = tl.cast(v_13, tl.float32)
    v_15 = variance_extra / v_14
    # src[allreduce_bias_rmsnorm.py:79]: rstd = torch.rsqrt(variance + EPS) # type: ignore[unsupported-operation]
    v_16 = tl.full([], 1e-05, tl.float32)
    v_17 = v_15 + v_16
    v_18 = libdevice.rsqrt(v_17)
    # src[allreduce_bias_rmsnorm.py:58]: symm_mem_buffer[tile_n, :] = x[tile_n, :]
    for roffset_1 in tl.range(0, 4096, _REDUCTION_BLOCK_1):
        rindex_1 = roffset_1 + tl.arange(0, _REDUCTION_BLOCK_1).to(tl.int32)
        v_18_copy = v_18
        # src[allreduce_bias_rmsnorm.py:71]: acc = symm_mem_buffer[tile_n, :].to(torch.float32) * 0.0 + bias[None, :].to(
        load_11 = tl.load(symm_mem_buffer + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), None)
        v_19 = tl.full([], 0.0, tl.float32)
        v_20 = load_11 * v_19
        load_12 = tl.load(bias + rindex_1[None, :] * 1, None)
        # src[allreduce_bias_rmsnorm.py:71]: acc = symm_mem_buffer[tile_n, :].to(torch.float32) * 0.0 + bias[None, :].to(
        # src[allreduce_bias_rmsnorm.py:72]: torch.float32
        # src[allreduce_bias_rmsnorm.py:73]: )
        v_21 = v_20 + load_12
        # src[allreduce_bias_rmsnorm.py:75]: acc = acc + remote_buffer[tile_n, :].to(torch.float32)
        load_13 = tl.load(buffer_tuple_item_0 + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), None)
        v_22 = v_21 + load_13
        load_14 = tl.load(buffer_tuple_item_1 + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), None)
        v_23 = v_22 + load_14
        load_15 = tl.load(buffer_tuple_item_2 + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), None)
        v_24 = v_23 + load_15
        load_16 = tl.load(buffer_tuple_item_3 + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), None)
        v_25 = v_24 + load_16
        load_17 = tl.load(buffer_tuple_item_4 + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), None)
        v_26 = v_25 + load_17
        load_18 = tl.load(buffer_tuple_item_5 + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), None)
        v_27 = v_26 + load_18
        load_19 = tl.load(buffer_tuple_item_6 + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), None)
        v_28 = v_27 + load_19
        load_20 = tl.load(buffer_tuple_item_7 + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), None)
        v_29 = v_28 + load_20
        # src[allreduce_bias_rmsnorm.py:80]: normalized = acc * rstd
        v_30 = v_29 * v_18_copy
        # src[allreduce_bias_rmsnorm.py:81]: output[tile_n, :] = (normalized * weight[None, :].to(torch.float32)).to(x.dtype)
        load_21 = tl.load(weight + rindex_1[None, :] * 1, None)
        v_31 = v_30 * load_21
        tl.store(output + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), v_31, None)
def one_shot_allreduce_bias_rmsnorm_kernel(symm_mem_buffer: torch.Tensor, x: torch.Tensor, bias: torch.Tensor, weight: torch.Tensor, signal_pad_ptrs: torch.Tensor, EPS: hl.constexpr, RANK: hl.constexpr, WORLD_SIZE: hl.constexpr, GROUP_NAME: hl.constexpr, *, _launcher=_default_launcher):
"""
Fused one-shot all-reduce + bias addition + RMS normalization.
"""
# src[allreduce_bias_rmsnorm.py:50]: N, D = x.size()
N, D = x.size()
# src[allreduce_bias_rmsnorm.py:51]: output = torch.empty_like(x)
output = torch.empty_like(x)
# src[allreduce_bias_rmsnorm.py:54]: buffer_tuple = torch.ops.symm_mem.get_remote_tensors(symm_mem_buffer, GROUP_NAME)
buffer_tuple = torch.ops.symm_mem.get_remote_tensors(symm_mem_buffer, '0')
# src[allreduce_bias_rmsnorm.py:56]: for tile_n in hl.tile(N):
_BLOCK_SIZE_0 = 8
# src[allreduce_bias_rmsnorm.py:58]: symm_mem_buffer[tile_n, :] = x[tile_n, :]
_REDUCTION_BLOCK_1 = 512
# src[allreduce_bias_rmsnorm.py:56]: for tile_n in hl.tile(N):
# src[allreduce_bias_rmsnorm.py:57]: # Step 1: Copy input x to our symmetric memory buffer
# src[allreduce_bias_rmsnorm.py:58]: symm_mem_buffer[tile_n, :] = x[tile_n, :]
# src[allreduce_bias_rmsnorm.py:56-81]: ...
_launcher(_helion_one_shot_allreduce_bias_rmsnorm_kernel, (triton.cdiv(128, _BLOCK_SIZE_0),), x, symm_mem_buffer, bias, buffer_tuple[0], buffer_tuple[1], buffer_tuple[2], buffer_tuple[3], buffer_tuple[4], buffer_tuple[5], buffer_tuple[6], buffer_tuple[7], weight, output, signal_pad_ptrs, _REDUCTION_BLOCK_1, num_warps=8, num_stages=1)
# src[allreduce_bias_rmsnorm.py:90]: return output
return output
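
# --- Usage sketch (editor's addition, not part of the generated output) ---
# A minimal illustration of how the wrapper above might be driven, assuming an
# already-initialized NCCL process group and PyTorch's symmetric-memory API.
# The module path `torch.distributed._symmetric_memory`, the allocator
# `symm_mem.empty`, and the handle attribute `signal_pad_ptrs_dev` are assumptions
# based on the src comments above (only `symm_mem.rendezvous` appears there);
# verify them against your PyTorch build. Note that this particular compilation is
# specialized for world size 8, GROUP_NAME '0', and a (128, 4096) float32 input.
def _example_usage() -> torch.Tensor:
    import torch.distributed as dist
    import torch.distributed._symmetric_memory as symm_mem  # assumed module path

    group = dist.group.WORLD
    rank, world_size = dist.get_rank(), dist.get_world_size()
    device = torch.device('cuda', rank)

    x = torch.randn(128, 4096, device=device)
    bias = torch.randn(4096, device=device)
    weight = torch.randn(4096, device=device)

    # Allocate the symmetric-memory buffer and rendezvous with the other ranks
    # (mirrors src[allreduce_bias_rmsnorm.py:107] above).
    buf = symm_mem.empty(x.shape, dtype=x.dtype, device=device)  # assumed allocator
    hdl = symm_mem.rendezvous(buf, group.group_name)

    return one_shot_allreduce_bias_rmsnorm_kernel(
        buf, x, bias, weight,
        hdl.signal_pad_ptrs_dev,  # assumed attribute holding per-rank signal pad pointers
        1e-05, rank, world_size, group.group_name,
    )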