Created
March 10, 2026 22:07
-
-
Save shunting314/2a82a1c2404f647861cecbbe70cda74a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from __future__ import annotations | |
| import torch | |
| import helion.language as hl | |
| import triton | |
| import triton.language as tl | |
| from torch._inductor.runtime.triton_compat import libdevice | |
| from helion.runtime import default_launcher as _default_launcher | |
| import __main__ as _source_module | |
@triton.jit
def symm_mem_sync(signal_pad_ptrs, block_id, rank: tl.constexpr, world_size: tl.constexpr, hasPreviousMemAccess: tl.constexpr=False, hasSubsequentMemAccess: tl.constexpr=False) -> None:
    """
    Synchronizes blocks with matching block_id across participating devices.

    Note: the function itself is not a system level barrier/fence. It is a
    building block for expressing different synchronization patterns.

    Pattern 0: Ensures that all writes to symm_mem buffers from previous
    kernels across all devices are visible to the current kernel:

        symm_mem_sync(..., hasPreviousMemAccess=False, hasSubsequentMemAccess=True)

    Pattern 1: Ensures that all writes to symm_mem buffers from the current
    block are visible to all remote blocks with matching blockIdx:

        symm_mem_sync(..., hasPreviousMemAccess=True, hasSubsequentMemAccess=True)

    Pattern 2: Ensures that symm_mem buffers read by the current kernel are safe
    for writing by subsequent kernels across all devices.

        symm_mem_sync(..., hasPreviousMemAccess=True, hasSubsequentMemAccess=False)

    CUDA graph friendliness:
        This barrier operates through atomic operations on a zero-filled signal
        pad, which resets to a zero-filled state after each successful
        synchronization. This design eliminates the need for incrementing a
        flag from host.
    """
    # Default block_id: this program's linear index in the 3-D launch grid.
    if block_id is None:
        block_id = _get_flat_bid()
    flat_tid = _get_flat_tid()
    remote_ranks = tl.arange(0, world_size)
    # signal_pad_ptrs holds one 64-bit base address per rank.
    signal_pad_ptrs = signal_pad_ptrs.to(tl.pointer_type(tl.uint64))
    # Slot `rank` of row `block_id` in EVERY rank's signal pad: where we
    # announce our arrival to each peer.
    remote_signal_pad_addrs = tl.load(signal_pad_ptrs + remote_ranks).to(tl.pointer_type(tl.uint32))
    send_addrs = remote_signal_pad_addrs + block_id * world_size + rank
    # All `world_size` slots of row `block_id` in OUR signal pad: where each
    # peer announces its arrival to us.
    local_signal_pad_addr = tl.load(signal_pad_ptrs + rank).to(tl.pointer_type(tl.uint32))
    wait_addrs = local_signal_pad_addr + block_id * world_size + remote_ranks
    if hasPreviousMemAccess:
        # Order this block's earlier memory accesses before the signal.
        tl.debug_barrier()
    # One thread per peer rank performs the signal/wait handshake.
    if flat_tid < world_size:
        _send_signal(send_addrs, 'release' if hasPreviousMemAccess else 'relaxed')
        _wait_signal(wait_addrs, 'acquire' if hasSubsequentMemAccess else 'relaxed')
    if hasSubsequentMemAccess:
        # Hold back this block's later memory accesses until the handshake
        # has completed on all threads.
        tl.debug_barrier()
@triton.jit
def _get_flat_bid():
    """Return this program's linear index within the 3-D launch grid."""
    pid_x = tl.program_id(0)
    pid_y = tl.program_id(1)
    pid_z = tl.program_id(2)
    grid_x = tl.num_programs(0)
    grid_y = tl.num_programs(1)
    # Row-major flattening: axis 2 is slowest-varying, axis 0 fastest.
    return (pid_z * grid_y + pid_y) * grid_x + pid_x
@triton.jit
def _get_flat_tid():
    """Return this thread's linear index within its thread block."""
    x, y, z = _get_tid()
    dim_x, dim_y, _unused_dim_z = _get_ntid()
    # Row-major flattening of the (z, y, x) thread coordinates.
    return (z * dim_y + y) * dim_x + x
@triton.jit
def _get_tid():
    """Return this thread's (x, y, z) coordinates, read from the PTX
    %tid special registers via inline assembly."""
    return tl.inline_asm_elementwise('\n mov.u32 $0, %tid.x;\n mov.u32 $1, %tid.y;\n mov.u32 $2, %tid.z;\n ', '=r,=r,=r', [], dtype=(tl.uint32, tl.uint32, tl.uint32), is_pure=True, pack=1)
@triton.jit
def _get_ntid():
    """Return the (x, y, z) thread-block dimensions, read from the PTX
    %ntid special registers via inline assembly."""
    return tl.inline_asm_elementwise('\n mov.u32 $0, %ntid.x;\n mov.u32 $1, %ntid.y;\n mov.u32 $2, %ntid.z;\n ', '=r,=r,=r', [], dtype=(tl.uint32, tl.uint32, tl.uint32), is_pure=True, pack=1)
@triton.jit
def _send_signal(addrs, sem: tl.constexpr) -> None:
    """Set each signal slot in `addrs` from 0 to 1.

    Implemented as a PTX spin loop: `atom...cas.b32 [addr], 0, 1` retried
    until the CAS observes 0, i.e. it blocks while a previous signal (value
    1) has not yet been consumed by the waiter. `sem` is spliced into the
    instruction as the memory-ordering qualifier ('release' or 'relaxed'
    at the call sites in this file), with system scope (.sys).
    """
    tl.inline_asm_elementwise(f'\n {{\n .reg .u32 %tmp32_<1>;\n .reg .pred %p<1>;\n\n send_signal:\n atom.global.{sem}.sys.cas.b32 %tmp32_0, [$1], 0, 1;\n setp.eq.u32 %p0, %tmp32_0, 0;\n @!%p0 bra send_signal;\n }}\n ', '=r, l', [addrs], dtype=addrs.dtype, is_pure=False, pack=1)
@triton.jit
def _wait_signal(addrs, sem: tl.constexpr) -> None:
    """Wait until each signal slot in `addrs` becomes 1, consuming it back to 0.

    Implemented as a PTX spin loop: `atom...cas.b32 [addr], 1, 0` retried
    until the CAS observes 1. Resetting the slot to 0 restores the
    zero-filled pad state the barrier relies on (see symm_mem_sync). `sem`
    is the memory-ordering qualifier ('acquire' or 'relaxed' at the call
    sites in this file), with system scope (.sys).

    NOTE(review): the qualifier order here is `.sys.{sem}` while
    _send_signal uses `.{sem}.sys` — presumably both are accepted by ptxas,
    but worth confirming the two were intended to differ only in order.
    """
    tl.inline_asm_elementwise(f'\n {{\n .reg .u32 %tmp32_<1>;\n .reg .pred %p<1>;\n\n wait_signal:\n atom.global.sys.{sem}.cas.b32 %tmp32_0, [$1], 1, 0;\n setp.eq.u32 %p0, %tmp32_0, 1;\n @!%p0 bra wait_signal;\n }}\n ', '=r, l', [addrs], dtype=tl.int32, is_pure=False, pack=1)
# Rows processed per program along dim 0; baked in by codegen and mirrored
# by the `_BLOCK_SIZE_0 = 8` local in the host wrapper below.
_BLOCK_SIZE_0 = tl.constexpr(8)
@triton.jit
def _helion_one_shot_allreduce_bias_rmsnorm_kernel(x, symm_mem_buffer, bias, buffer_tuple_item_0, buffer_tuple_item_1, buffer_tuple_item_2, buffer_tuple_item_3, buffer_tuple_item_4, buffer_tuple_item_5, buffer_tuple_item_6, buffer_tuple_item_7, weight, output, signal_pad_ptrs, _REDUCTION_BLOCK_1: tl.constexpr):
    """Fused one-shot all-reduce + bias + RMSNorm over one tile of rows.

    Helion-generated device kernel: the distributed and shape parameters are
    specialized into the code (row stride D=4096, WORLD_SIZE=8, RANK=0,
    EPS=1e-05). Each program handles _BLOCK_SIZE_0 (=8) rows; the 4096-wide
    feature dimension is processed in chunks of _REDUCTION_BLOCK_1.
    """
    pid_0 = tl.program_id(0)
    offset_0 = pid_0 * _BLOCK_SIZE_0
    # Row indices of this program's tile. No bounds mask is generated, so N
    # must be a multiple of _BLOCK_SIZE_0 (it is specialized to 128).
    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
    # Cross-device barrier (RANK=0, WORLD_SIZE=8 baked in), with both
    # release and acquire behavior (hasPrevious/hasSubsequentMemAccess=True).
    # NOTE(review): this is the only sync in the kernel, and the store of x
    # into symm_mem_buffer happens below, AFTER it — confirm remote ranks
    # cannot observe a stale buffer when they read this rank's tile.
    symm_mem_sync(signal_pad_ptrs, None, 0, 8, True, True)
    # Running per-element sum of acc^2 across chunks; reduced after pass 1.
    variance_extra_acc = tl.full([_BLOCK_SIZE_0, _REDUCTION_BLOCK_1], 0, tl.float32)
    # Pass 1: stage x into the local symmetric buffer, then accumulate
    # acc = bias + sum over all 8 ranks' buffers, and sum acc^2 for variance.
    for roffset_1 in tl.range(0, 4096, _REDUCTION_BLOCK_1):
        rindex_1 = roffset_1 + tl.arange(0, _REDUCTION_BLOCK_1).to(tl.int32)
        load = tl.load(x + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), None)
        tl.store(symm_mem_buffer + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), load, None)
        # acc starts as bias: the `* 0.0` read of symm_mem_buffer exists only
        # to broadcast bias to the tile's [rows, chunk] shape.
        load_1 = tl.load(symm_mem_buffer + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), None)
        v_0 = tl.full([], 0.0, tl.float32)
        v_1 = load_1 * v_0
        load_2 = tl.load(bias + rindex_1[None, :] * 1, None)
        v_2 = v_1 + load_2
        # Unrolled all-reduce: add the same tile from every rank's buffer.
        load_3 = tl.load(buffer_tuple_item_0 + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), None)
        v_3 = v_2 + load_3
        load_4 = tl.load(buffer_tuple_item_1 + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), None)
        v_4 = v_3 + load_4
        load_5 = tl.load(buffer_tuple_item_2 + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), None)
        v_5 = v_4 + load_5
        load_6 = tl.load(buffer_tuple_item_3 + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), None)
        v_6 = v_5 + load_6
        load_7 = tl.load(buffer_tuple_item_4 + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), None)
        v_7 = v_6 + load_7
        load_8 = tl.load(buffer_tuple_item_5 + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), None)
        v_8 = v_7 + load_8
        load_9 = tl.load(buffer_tuple_item_6 + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), None)
        v_9 = v_8 + load_9
        load_10 = tl.load(buffer_tuple_item_7 + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), None)
        v_10 = v_9 + load_10
        # Accumulate acc^2 for the mean-of-squares reduction.
        v_11 = v_10 * v_10
        v_12 = variance_extra_acc + v_11
        variance_extra_acc = v_12
    # variance = mean(acc^2) over the 4096-wide feature dimension.
    variance_extra = tl.cast(tl.reshape(tl.sum(variance_extra_acc, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
    v_13 = 4096
    v_14 = tl.cast(v_13, tl.float32)
    v_15 = variance_extra / v_14
    # rstd = 1 / sqrt(variance + EPS), with EPS specialized to 1e-05.
    v_16 = tl.full([], 1e-05, tl.float32)
    v_17 = v_15 + v_16
    v_18 = libdevice.rsqrt(v_17)
    # Pass 2: acc was not kept live across the reduction, so recompute it
    # chunk-by-chunk, then write output = acc * rstd * weight.
    for roffset_1 in tl.range(0, 4096, _REDUCTION_BLOCK_1):
        rindex_1 = roffset_1 + tl.arange(0, _REDUCTION_BLOCK_1).to(tl.int32)
        v_18_copy = v_18
        # Same bias-broadcast trick as pass 1.
        load_11 = tl.load(symm_mem_buffer + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), None)
        v_19 = tl.full([], 0.0, tl.float32)
        v_20 = load_11 * v_19
        load_12 = tl.load(bias + rindex_1[None, :] * 1, None)
        v_21 = v_20 + load_12
        # Same unrolled all-reduce as pass 1.
        load_13 = tl.load(buffer_tuple_item_0 + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), None)
        v_22 = v_21 + load_13
        load_14 = tl.load(buffer_tuple_item_1 + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), None)
        v_23 = v_22 + load_14
        load_15 = tl.load(buffer_tuple_item_2 + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), None)
        v_24 = v_23 + load_15
        load_16 = tl.load(buffer_tuple_item_3 + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), None)
        v_25 = v_24 + load_16
        load_17 = tl.load(buffer_tuple_item_4 + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), None)
        v_26 = v_25 + load_17
        load_18 = tl.load(buffer_tuple_item_5 + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), None)
        v_27 = v_26 + load_18
        load_19 = tl.load(buffer_tuple_item_6 + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), None)
        v_28 = v_27 + load_19
        load_20 = tl.load(buffer_tuple_item_7 + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), None)
        v_29 = v_28 + load_20
        # normalized = acc * rstd; output = normalized * weight.
        v_30 = v_29 * v_18_copy
        load_21 = tl.load(weight + rindex_1[None, :] * 1, None)
        v_31 = v_30 * load_21
        tl.store(output + (indices_0[:, None] * 4096 + rindex_1[None, :] * 1), v_31, None)
def one_shot_allreduce_bias_rmsnorm_kernel(symm_mem_buffer: torch.Tensor, x: torch.Tensor, bias: torch.Tensor, weight: torch.Tensor, signal_pad_ptrs: torch.Tensor, EPS: hl.constexpr, RANK: hl.constexpr, WORLD_SIZE: hl.constexpr, GROUP_NAME: hl.constexpr, *, _launcher=_default_launcher):
    """
    Fused one-shot all-reduce + bias addition + RMS normalization.

    Helion-generated host wrapper. The EPS / RANK / WORLD_SIZE / GROUP_NAME
    arguments are kept for interface compatibility, but their values were
    specialized at compile time: GROUP_NAME is hard-coded to '0' below, and
    the device kernel bakes in RANK=0, WORLD_SIZE=8, EPS=1e-05, N=128 and
    D=4096. Passing different values here will silently use the baked-in
    ones — recompile via Helion to retarget.
    """
    N, D = x.size()  # unused: N/D were specialized into the kernel (128 x 4096)
    output = torch.empty_like(x)
    # Per-rank views of the symmetric buffer; items 0..7 feed the unrolled
    # reduction in the device kernel. '0' is the specialized GROUP_NAME.
    buffer_tuple = torch.ops.symm_mem.get_remote_tensors(symm_mem_buffer, '0')
    # Rows per program; mirrors the module-level tl.constexpr _BLOCK_SIZE_0.
    _BLOCK_SIZE_0 = 8
    # Chunk size along the feature dimension for the two reduction loops.
    _REDUCTION_BLOCK_1 = 512
    # Grid is hard-coded for the specialized N=128 (128 / 8 = 16 programs);
    # the kernel emits no row mask, so other N would read out of bounds.
    _launcher(_helion_one_shot_allreduce_bias_rmsnorm_kernel, (triton.cdiv(128, _BLOCK_SIZE_0),), x, symm_mem_buffer, bias, buffer_tuple[0], buffer_tuple[1], buffer_tuple[2], buffer_tuple[3], buffer_tuple[4], buffer_tuple[5], buffer_tuple[6], buffer_tuple[7], weight, output, signal_pad_ptrs, _REDUCTION_BLOCK_1, num_warps=8, num_stages=1)
    return output
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment