| # AOT ID: ['0_backward'] | |
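| # (Inductor-generated output: the wrapper code and Triton kernels for the backward graph | |
| # of a torch.compile'd flex_attention call with fp16 inputs of shape [2, 4, 277, 16].) | |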
| from ctypes import c_void_p, c_long, c_int | |
| import torch | |
| import math | |
| import random | |
| import os | |
| import tempfile | |
| from math import inf, nan | |
| from cmath import nanj | |
| from torch._inductor.hooks import run_intermediate_hooks | |
| from torch._inductor.utils import maybe_profile | |
| from torch._inductor.codegen.memory_planning import _align as align | |
| from torch import device, empty_strided | |
| from torch._inductor.async_compile import AsyncCompile | |
| from torch._inductor.select_algorithm import extern_kernels | |
| import triton | |
| import triton.language as tl | |
| from torch._inductor.runtime.triton_heuristics import start_graph, end_graph | |
| from torch._C import _cuda_getCurrentRawStream as get_raw_stream | |
| from torch._C import _cuda_getCurrentRawStream as get_raw_stream | |
| aten = torch.ops.aten | |
| inductor_ops = torch.ops.inductor | |
| _quantized = torch.ops._quantized | |
| assert_size_stride = torch._C._dynamo.guards.assert_size_stride | |
| assert_alignment = torch._C._dynamo.guards.assert_alignment | |
| empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu | |
| empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned | |
| empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda | |
| empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu | |
| empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia | |
| reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor | |
| alloc_from_pool = torch.ops.inductor._alloc_from_pool | |
| empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p | |
| # kernel path: /tmp/tmpnrq5e8tu/ti/ctikyxfidllsntvl7cvzx6cddmbfuafg343fg4xdbg6gycvue3hn.py | |
| # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
| # Source node to ATen node mapping: | |
| # Graph fragment: | |
| # %full_default_4 : Tensor "f32[2, 4, 277][1108, 277, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 4, 277], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
| # %flex_attention_backward : [num_users=4] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem_2, %getitem_3, %tangents_1, %full_default_4, %fw_graph0, %joint_graph0, (1, 1, %full, %full_default, None, None, %convert_element_type, %convert_element_type_1, None, None, 1073741824, 1073741824, %mask_graph0), 0.25, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True}, (%primals_4,), ()), kwargs = {}) | |
| # return %getitem_8 | |
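| # The pointwise kernel below only zero-fills buf0, a 4-element fp32 accumulator | |
| # (one slot per query head). The flex-attention backward template further down adds the | |
| # gradient of the per-head score_mod parameter (primals_4) into this buffer with | |
| # tl.atomic_add, and buf0 is returned from call() as %getitem_8. | |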
| import triton | |
| import triton.language as tl | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton.jit | |
| def triton_poi_fused_zeros_0(out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
| xnumel = 4 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
| xmask = xindex < xnumel | |
| x0 = xindex | |
| tmp0 = 0.0 | |
| tl.store(out_ptr0 + (x0), tmp0, xmask) | |
| # kernel path: /tmp/tmpnrq5e8tu/xb/cxbf4sgfrfeefdlfmzcrsustqoaopn6vkn3b2rd23antyudmewm7.py | |
| # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
| # Source node to ATen node mapping: | |
| # Graph fragment: | |
| # %getitem_2 : Tensor "f16[2, 4, 277, 16][17728, 4432, 16, 1]cuda:0" = PlaceHolder[target=getitem_2] | |
| # %tangents_1 : Tensor "f16[2, 4, 277, 16][17728, 4432, 16, 1]cuda:0" = PlaceHolder[target=tangents_1] | |
| # %buf1 : Tensor "f16[2, 4, 277][1152, 277, 1]cuda:0" = PlaceHolder[target=buf1] | |
| # %full_default_4 : Tensor "f32[2, 4, 277][1108, 277, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 4, 277], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
| # %flex_attention_backward : [num_users=4] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem_2, %getitem_3, %tangents_1, %full_default_4, %fw_graph0, %joint_graph0, (1, 1, %full, %full_default, None, None, %convert_element_type, %convert_element_type_1, None, None, 1073741824, 1073741824, %mask_graph0), 0.25, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True}, (%primals_4,), ()), kwargs = {}) | |
| # return %buf1,%buf2 | |
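| # The reduction kernel below precomputes DELTA = sum(out * grad_out, axis=-1): it | |
| # multiplies the saved forward output (getitem_2) by the incoming gradient (tangents_1) | |
| # and sums over the head dim (r0_numel = 16) for each of the 2*4*277 = 2216 query rows, | |
| # writing the result to buf2, which feeds the template kernel as arg_DELTA. | |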
| import triton | |
| import triton.language as tl | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton.jit | |
| def triton_per_fused_zeros_1(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr): | |
| xnumel = 2216 | |
| r0_numel = 16 | |
| R0_BLOCK: tl.constexpr = 16 | |
| rnumel = r0_numel | |
| RBLOCK: tl.constexpr = R0_BLOCK | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
| xmask = xindex < xnumel | |
| r0_index = tl.arange(0, R0_BLOCK)[None, :] | |
| r0_offset = 0 | |
| r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) | |
| roffset = r0_offset | |
| rindex = r0_index | |
| r0_2 = r0_index | |
| x3 = xindex | |
| x0 = (xindex % 1108) | |
| x1 = xindex // 1108 | |
| tmp0 = tl.load(in_ptr0 + (r0_2 + 16*x3), xmask, other=0.0).to(tl.float32) | |
| tmp1 = tl.load(in_ptr1 + (r0_2 + 16*x3), xmask, other=0.0).to(tl.float32) | |
| tmp2 = tmp0 * tmp1 | |
| tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) | |
| tmp5 = tl.where(xmask, tmp3, 0) | |
| tmp6 = tl.sum(tmp5, 1)[:, None].to(tl.float32) | |
| tmp7 = tmp6.to(tl.float32) | |
| tmp8 = 0.0 | |
| tmp9 = tmp7 - tmp8 | |
| tl.store(out_ptr1 + (x3), tmp9, xmask) | |
| # kernel path: /tmp/tmpnrq5e8tu/oy/coybprchvtfhw7zf24gvxn66tedl6ow54cdbe5cktus4md4gcr5f.py | |
| # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
| # Source node to ATen node mapping: | |
| # Graph fragment: | |
| # %primals_1 : Tensor "f16[2, 4, 277, 16][17728, 4432, 16, 1]cuda:0" = PlaceHolder[target=primals_1] | |
| # %primals_2 : Tensor "f16[2, 4, 277, 16][17728, 4432, 16, 1]cuda:0" = PlaceHolder[target=primals_2] | |
| # %primals_3 : Tensor "f16[2, 4, 277, 16][17728, 4432, 16, 1]cuda:0" = PlaceHolder[target=primals_3] | |
| # %getitem_3 : Tensor "f32[2, 4, 277][1108, 277, 1]cuda:0" = PlaceHolder[target=getitem_3] | |
| # %buf2 : Tensor "f32[2, 4, 277][1108, 277, 1]cuda:0" = PlaceHolder[target=buf2] | |
| # %tangents_1 : Tensor "f16[2, 4, 277, 16][17728, 4432, 16, 1]cuda:0" = PlaceHolder[target=tangents_1] | |
| # %getitem_4 : Tensor "f16[2, 4, 277, 16][17728, 4432, 16, 1]cuda:0" = PlaceHolder[target=getitem_4] | |
| # %getitem_6 : Tensor "f16[2, 4, 277, 16][17728, 4432, 16, 1]cuda:0" = PlaceHolder[target=getitem_6] | |
| # %full : Tensor "i32[1, 1, 1][1, 1, 1]cuda:0" = PlaceHolder[target=full] | |
| # %full_default : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:0" = PlaceHolder[target=full_default] | |
| # %convert_element_type : Tensor "i32[1, 1, 1][1, 1, 1]cuda:0" = PlaceHolder[target=convert_element_type] | |
| # %convert_element_type_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:0" = PlaceHolder[target=convert_element_type_1] | |
| # %buf6 : Tensor "f32[0][1]cuda:0" = PlaceHolder[target=buf6] | |
| # %buf7 : Tensor "f32[0][1]cuda:0" = PlaceHolder[target=buf7] | |
| # %buf8 : Tensor "f32[0][1]cuda:0" = PlaceHolder[target=buf8] | |
| # %buf9 : Tensor "f32[0][1]cuda:0" = PlaceHolder[target=buf9] | |
| # %primals_4 : Tensor "f16[4][1]cuda:0" = PlaceHolder[target=primals_4] | |
| # %getitem_8 : Tensor "f32[4][1]cuda:0" = PlaceHolder[target=getitem_8] | |
| # %full_default_4 : Tensor "f32[2, 4, 277][1108, 277, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 4, 277], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
| # %flex_attention_backward : [num_users=4] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem_2, %getitem_3, %tangents_1, %full_default_4, %fw_graph0, %joint_graph0, (1, 1, %full, %full_default, None, None, %convert_element_type, %convert_element_type_1, None, None, 1073741824, 1073741824, %mask_graph0), 0.25, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True}, (%primals_4,), ()), kwargs = {}) | |
| # return %getitem_5 | |
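| # The template kernel below is the flex_attention backward pass proper. It is launched | |
| # on an (18, 2, 4) grid: program_id(0) covers 9 dK/dV blocks (cdiv(277, BLOCK_N1=32)) | |
| # followed by 9 dQ blocks (cdiv(277, BLOCK_M2=32)), program_id(1) is the batch (ZQ=2), | |
| # and program_id(2) is the KV head (HKV=4). It writes dQ to arg_DQ (buf4), dV to arg_DV | |
| # (buf5), dK to out_ptr0 (buf10), and accumulates the score_mod parameter gradient into | |
| # in_ptr17 (buf0). | |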
| import triton | |
| import triton.language as tl | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| @triton.jit | |
| def triton_tem_fused_zeros_2(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, in_ptr17, out_ptr0): | |
| PRESCALE_QK : tl.constexpr = False | |
| ROWS_GUARANTEED_SAFE : tl.constexpr = False | |
| BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False | |
| WRITE_DQ : tl.constexpr = True | |
| OUTPUT_LOGSUMEXP : tl.constexpr = True | |
| FLOAT32_PRECISION : tl.constexpr = 'ieee' | |
| IS_DIVISIBLE : tl.constexpr = False | |
| SM_SCALE : tl.constexpr = 0.25 | |
| GQA_SHARED_HEADS : tl.constexpr = 1 | |
| HAS_FULL_BLOCKS : tl.constexpr = False | |
| QK_HEAD_DIM : tl.constexpr = 16 | |
| QK_HEAD_DIM_ROUNDED : tl.constexpr = 16 | |
| V_HEAD_DIM : tl.constexpr = 16 | |
| V_HEAD_DIM_ROUNDED : tl.constexpr = 16 | |
| SAFE_HEAD_DIM : tl.constexpr = True | |
| BLOCK_M1 : tl.constexpr = 16 | |
| BLOCK_N1 : tl.constexpr = 32 | |
| BLOCK_M2 : tl.constexpr = 32 | |
| BLOCK_N2 : tl.constexpr = 16 | |
| SPARSE_Q_BLOCK_SIZE : tl.constexpr = 1073741824 | |
| SPARSE_KV_BLOCK_SIZE : tl.constexpr = 1073741824 | |
| kpack : tl.constexpr = 2 | |
| matrix_instr_nonkdim : tl.constexpr = 16 | |
| waves_per_eu : tl.constexpr = 0 | |
| INDEX_DTYPE : tl.constexpr = tl.int32 | |
| Q = arg_Q | |
| K = arg_K | |
| V = arg_V | |
| LSE = arg_LSE | |
| DELTA = arg_DELTA | |
| DO = arg_DO | |
| DQ = arg_DQ | |
| DV = arg_DV | |
| KV_NUM_BLKS = arg_KV_NUM_BLKS | |
| KV_IDX = arg_KV_IDX | |
| Q_NUM_BLKS = arg_Q_NUM_BLKS | |
| Q_IDX = arg_Q_IDX | |
| FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS | |
| FULL_KV_IDX = arg_FULL_KV_IDX | |
| FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS | |
| FULL_Q_IDX = arg_FULL_Q_IDX | |
| # Sub notation for this kernel: | |
| # | |
| # Q: Query, K: Key, V: Value | |
| # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) | |
| # DELTA: Precomputed sum(OUT*DO, axis=-1) | |
| # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value | |
| # DK: Derivative of Key, is written to via the store_output call due to some limitations with | |
| # inductor codegen | |
| # M: Number of queries, N: Number of keys/values | |
| # QK_HEAD_DIM: The dimension of the query and key embeddings | |
| # V_HEAD_DIM: The dimension of the value embeddings | |
| # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim | |
| # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. | |
| # (Modifiable) Performance tuning options | |
| # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. | |
| # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. | |
| # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. | |
| # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. | |
| # | |
| # The following FULL_* and PARTIAL_* values are defined on the block-sparse mask grid, rather than the thread-block grid. | |
| # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. | |
| # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. | |
| # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. | |
| # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. | |
| # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. | |
| # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. | |
| # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. | |
| # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. | |
| # The below are kernel options that can be applied for certain score_mods, | |
| # or involve a numerics vs. perf tradeoff | |
| # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has | |
| # about 20% more numerical error, but slightly faster. | |
| # Define strides of inputs | |
| stride_qz, stride_qh, stride_qm, stride_qd = 17728, 4432, 16, 1 | |
| stride_kz, stride_kh, stride_kn, stride_kd = 17728, 4432, 16, 1 | |
| stride_vz, stride_vh, stride_vn, stride_vd = 17728, 4432, 16, 1 | |
| stride_doz, stride_doh, stride_dom, stride_dod = 17728, 4432, 16, 1 | |
| stride_dqz, stride_dqh, stride_dqm, stride_dqd = 17728, 4432, 16, 1 | |
| stride_dvz, stride_dvh, stride_dvm, stride_dvd = 17728, 4432, 16, 1 | |
| ZQ = 2 | |
| HQ = 4 | |
| HKV = 4 | |
| Q_LEN = 277 | |
| ZKV = 2 | |
| KV_LEN = 277 | |
| MATMUL_PRECISION = Q.dtype.element_ty | |
| pid = tl.program_id(0).to(INDEX_DTYPE) | |
| NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) | |
| NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) | |
| off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx | |
| off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx | |
| off_zkv = off_zq % ZKV # kv batch idx | |
| SPARSE_Z = 1 | |
| SPARSE_HQ = 1 | |
| sparse_idx_z = off_zq % SPARSE_Z | |
| k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) | |
| v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) | |
| # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] | |
| # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] | |
| dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) | |
| # offset K, V, DV pointers for batch/kv-head | |
| K += k_adj | |
| V += v_adj | |
| DV += dv_adj | |
| RCP_LN2 = 1.44269504 | |
| offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) | |
| offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) | |
| if pid >= NUM_KV_BLOCKS: | |
| off_pid = pid - NUM_KV_BLOCKS | |
| # THIS BLOCK DOES DQ | |
| SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) | |
| SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) | |
| off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS | |
| start_m2_block = off_pid % NUM_Q_BLOCKS | |
| off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE | |
| stride_kv_num_blks_h = 1 | |
| stride_kv_idx_h = 1 | |
| stride_kv_idx_m = 1 | |
| sparse_idx_hq2 = off_hq2 % SPARSE_HQ | |
| sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 | |
| sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask | |
| sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 | |
| # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. | |
| q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) | |
| do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) | |
| dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) | |
| off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) | |
| Q2 = Q + q_adj2 | |
| DO2 = DO + do_adj2 | |
| # TODO: This does not work if DQ is not the same layout as Q (for example, | |
| # if Q is broadcasted) | |
| DQ2 = DQ + dq_adj2 | |
| LSE2 = LSE + off_chz2 | |
| DELTA2 = DELTA + off_chz2 | |
| # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) | |
| dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) | |
| start_m2 = start_m2_block * BLOCK_M2 | |
| offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) | |
| # load Q and do: they stay in SRAM throughout the inner loop. | |
| q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) | |
| do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) | |
| if PRESCALE_QK: | |
| q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) | |
| if IS_DIVISIBLE: | |
| Di = tl.load(DELTA2 + offs_m2) | |
| lse = tl.load(LSE2 + offs_m2) | |
| else: | |
| Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) | |
| lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) | |
| lse = tl.where(lse == -float("inf"), 0.0, lse) | |
| lse = lse[:, None] | |
| # ~~~~~~~~~~~ partially masked blocks (mask_mod applied) ~~~~~~~~~~~~~~~~~~ | |
| # KV_IDX and KV_NUM_BLKS are always contiguous. | |
| kv_indices = KV_IDX + sparse_kv_idx_offset | |
| kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading | |
| sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) | |
| offs_n2 = kv_start + tl.arange(0, BLOCK_N2) | |
| dq = bwd_dq_inner( | |
| arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, in_ptr17, out_ptr0, | |
| K, V, | |
| dq, q, do, Di, lse, | |
| off_zq, off_hq2, offs_m2, offs_n2, | |
| stride_kn, stride_kd, stride_vn, stride_vd, | |
| kv_indices, sparse_kv_num_blocks, | |
| MATMUL_PRECISION, | |
| IS_FULL_BLOCKS=False, | |
| ) | |
| if HAS_FULL_BLOCKS: | |
| # ~~~~~~~~~~~ fully unmasked blocks (no mask_mod needed) ~~~~~~~~~~~~~~~~~~ | |
| # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. | |
| kv_indices = FULL_KV_IDX + sparse_kv_idx_offset | |
| kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading | |
| sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) | |
| offs_n2 = kv_start + tl.arange(0, BLOCK_N2) | |
| dq = bwd_dq_inner( | |
| arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, in_ptr17, out_ptr0, | |
| K, V, | |
| dq, q, do, Di, lse, | |
| off_zq, off_hq2, offs_m2, offs_n2, | |
| stride_kn, stride_kd, stride_vn, stride_vd, | |
| kv_indices, sparse_kv_num_blocks, | |
| MATMUL_PRECISION, | |
| IS_FULL_BLOCKS=True, | |
| ) | |
| # Write back dQ. | |
| dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd | |
| dq *= SM_SCALE | |
| if IS_DIVISIBLE and SAFE_HEAD_DIM: | |
| tl.store(dq_ptrs, dq) | |
| else: | |
| tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) | |
| else: | |
| # THIS BLOCK DOES DK & DV | |
| SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) | |
| SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) | |
| pid_mask = pid // SPARSE_KV_MULTIPLE | |
| stride_q_num_blks_h = 1 | |
| stride_q_idx_h = 1 | |
| stride_q_idx_n = 1 | |
| dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) | |
| dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) | |
| start_n1 = pid * BLOCK_N1 | |
| offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) | |
| # load K and V: they stay in SRAM throughout the inner loop. | |
| k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) | |
| v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) | |
| if PRESCALE_QK: | |
| k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) | |
| for off_g in range(0, GQA_SHARED_HEADS): | |
| off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g | |
| # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. | |
| q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) | |
| do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) | |
| dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) | |
| off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) | |
| Q1 = Q + q_adj1 | |
| DO1 = DO + do_adj1 | |
| # TODO: This does not work if DQ is not the same layout as Q (for example, | |
| # if Q is broadcasted) | |
| LSE1 = LSE + off_chz1 | |
| DELTA1 = DELTA + off_chz1 | |
| sparse_idx_hq1 = off_hq1 % SPARSE_HQ | |
| sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 | |
| sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask | |
| sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 | |
| # ~~~~~~~~~~~~~~~ partially masked blocks (mask_mod applied) ~~~~~~~~~~~~~~~~~~~~ | |
| # Q_IDX and Q_NUM_BLKS are always contiguous. | |
| q_indices = Q_IDX + sparse_q_idx_offset | |
| q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading | |
| sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) | |
| offs_m1 = q_start + tl.arange(0, BLOCK_M1) | |
| dk, dv = bwd_dkdv_inner( | |
| arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, in_ptr17, out_ptr0, | |
| Q1, DO1, DELTA1, LSE1, | |
| dk, dv, k, v, | |
| off_zq, off_hq1, offs_n1, offs_m1, | |
| stride_qm, stride_qd, stride_dom, stride_dod, | |
| q_indices, sparse_q_num_blocks, | |
| MATMUL_PRECISION, | |
| IS_FULL_BLOCKS=False, | |
| ) | |
| if HAS_FULL_BLOCKS: | |
| # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
| # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. | |
| q_indices = FULL_Q_IDX + sparse_q_idx_offset | |
| q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading | |
| sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) | |
| offs_m1 = q_start + tl.arange(0, BLOCK_M1) | |
| dk, dv = bwd_dkdv_inner( | |
| arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, in_ptr17, out_ptr0, | |
| Q1, DO1, DELTA1, LSE1, | |
| dk, dv, k, v, | |
| off_zq, off_hq1, offs_n1, offs_m1, | |
| stride_qm, stride_qd, stride_dom, stride_dod, | |
| q_indices, sparse_q_num_blocks, | |
| MATMUL_PRECISION, | |
| IS_FULL_BLOCKS=True, | |
| ) | |
| # Write back dV and dK. | |
| dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd | |
| index_n = offs_n1[:, None] | |
| index_k = offs_k[None, :] | |
| index_v = offs_v[None, :] | |
| if IS_DIVISIBLE and SAFE_HEAD_DIM: | |
| tl.store(dv_ptrs, dv) | |
| else: | |
| tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) | |
| dk *= SM_SCALE | |
| if SAFE_HEAD_DIM: | |
| mask = index_n < KV_LEN | |
| else: | |
| mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) | |
| # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, QK_HEAD_DIM] | |
| # then reduce to dk of shape [Bkv, Hkv, KV_LEN, QK_HEAD_DIM] | |
| tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) | |
| xindex = index_k + 16*index_n + 4432*off_hkv + 17728*off_zq | |
| tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) | |
| @triton.jit | |
| def bwd_dq_inner( | |
| arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, in_ptr17, out_ptr0, | |
| K, V, # pointers | |
| dq, q, do, Di, lse, | |
| off_z, off_hq, offs_m2, offs_n2, | |
| stride_kn, stride_kd, stride_vn, stride_vd, | |
| kv_indices, sparse_kv_num_blocks, | |
| MATMUL_PRECISION, | |
| IS_FULL_BLOCKS, | |
| ): | |
| PRESCALE_QK : tl.constexpr = False | |
| ROWS_GUARANTEED_SAFE : tl.constexpr = False | |
| BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False | |
| WRITE_DQ : tl.constexpr = True | |
| OUTPUT_LOGSUMEXP : tl.constexpr = True | |
| FLOAT32_PRECISION : tl.constexpr = 'ieee' | |
| IS_DIVISIBLE : tl.constexpr = False | |
| SM_SCALE : tl.constexpr = 0.25 | |
| GQA_SHARED_HEADS : tl.constexpr = 1 | |
| HAS_FULL_BLOCKS : tl.constexpr = False | |
| QK_HEAD_DIM : tl.constexpr = 16 | |
| QK_HEAD_DIM_ROUNDED : tl.constexpr = 16 | |
| V_HEAD_DIM : tl.constexpr = 16 | |
| V_HEAD_DIM_ROUNDED : tl.constexpr = 16 | |
| SAFE_HEAD_DIM : tl.constexpr = True | |
| BLOCK_M1 : tl.constexpr = 16 | |
| BLOCK_N1 : tl.constexpr = 32 | |
| BLOCK_M2 : tl.constexpr = 32 | |
| BLOCK_N2 : tl.constexpr = 16 | |
| SPARSE_Q_BLOCK_SIZE : tl.constexpr = 1073741824 | |
| SPARSE_KV_BLOCK_SIZE : tl.constexpr = 1073741824 | |
| kpack : tl.constexpr = 2 | |
| matrix_instr_nonkdim : tl.constexpr = 16 | |
| waves_per_eu : tl.constexpr = 0 | |
| INDEX_DTYPE : tl.constexpr = tl.int32 | |
| SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) | |
| RCP_LN2: tl.constexpr = 1.44269504 | |
| Q_LEN = 277 | |
| KV_LEN = 277 | |
| offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) | |
| offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) | |
| kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd | |
| vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd | |
| # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. | |
| tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) | |
| hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) | |
| if not IS_DIVISIBLE: | |
| if hi >= 1: | |
| for start_n in range(0, hi - 1): | |
| dq = bwd_dq_block_mn( | |
| arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, in_ptr17, out_ptr0, | |
| dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, | |
| off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, | |
| stride_kn, stride_kd, stride_vn, stride_vd, | |
| kv_indices, sparse_kv_num_blocks, | |
| MATMUL_PRECISION, RCP_LN2, | |
| IS_FULL_BLOCKS, | |
| ) | |
| # Increment pointers. | |
| offset = get_offset_for_next_block( | |
| start_n, kv_indices, sparse_kv_num_blocks, | |
| SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS | |
| ) | |
| kT_ptrs += offset * stride_kn | |
| vT_ptrs += offset * stride_vn | |
| offs_n2 += offset | |
| dq = bwd_dq_block_mn( | |
| arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, in_ptr17, out_ptr0, | |
| dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, | |
| off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, | |
| stride_kn, stride_kd, stride_vn, stride_vd, | |
| kv_indices, sparse_kv_num_blocks, | |
| MATMUL_PRECISION, RCP_LN2, | |
| IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, | |
| ) | |
| else: | |
| for start_n in range(0, hi): | |
| dq = bwd_dq_block_mn( | |
| arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, in_ptr17, out_ptr0, | |
| dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, | |
| off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, | |
| stride_kn, stride_kd, stride_vn, stride_vd, | |
| kv_indices, sparse_kv_num_blocks, | |
| MATMUL_PRECISION, RCP_LN2, | |
| IS_FULL_BLOCKS, | |
| ) | |
| # Increment pointers. | |
| offset = get_offset_for_next_block( | |
| start_n, kv_indices, sparse_kv_num_blocks, | |
| SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS | |
| ) | |
| kT_ptrs += offset * stride_kn | |
| vT_ptrs += offset * stride_vn | |
| offs_n2 += offset | |
| return dq | |
| @triton.jit | |
| def bwd_dq_block_mn( | |
| arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, in_ptr17, out_ptr0, | |
| dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, | |
| off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, | |
| stride_kn, stride_kd, stride_vn, stride_vd, | |
| kv_indices, sparse_kv_num_blocks, | |
| MATMUL_PRECISION, RCP_LN2, | |
| IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, | |
| ): | |
| PRESCALE_QK : tl.constexpr = False | |
| ROWS_GUARANTEED_SAFE : tl.constexpr = False | |
| BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False | |
| WRITE_DQ : tl.constexpr = True | |
| OUTPUT_LOGSUMEXP : tl.constexpr = True | |
| FLOAT32_PRECISION : tl.constexpr = 'ieee' | |
| IS_DIVISIBLE : tl.constexpr = False | |
| SM_SCALE : tl.constexpr = 0.25 | |
| GQA_SHARED_HEADS : tl.constexpr = 1 | |
| HAS_FULL_BLOCKS : tl.constexpr = False | |
| QK_HEAD_DIM : tl.constexpr = 16 | |
| QK_HEAD_DIM_ROUNDED : tl.constexpr = 16 | |
| V_HEAD_DIM : tl.constexpr = 16 | |
| V_HEAD_DIM_ROUNDED : tl.constexpr = 16 | |
| SAFE_HEAD_DIM : tl.constexpr = True | |
| BLOCK_M1 : tl.constexpr = 16 | |
| BLOCK_N1 : tl.constexpr = 32 | |
| BLOCK_M2 : tl.constexpr = 32 | |
| BLOCK_N2 : tl.constexpr = 16 | |
| SPARSE_Q_BLOCK_SIZE : tl.constexpr = 1073741824 | |
| SPARSE_KV_BLOCK_SIZE : tl.constexpr = 1073741824 | |
| kpack : tl.constexpr = 2 | |
| matrix_instr_nonkdim : tl.constexpr = 16 | |
| waves_per_eu : tl.constexpr = 0 | |
| INDEX_DTYPE : tl.constexpr = tl.int32 | |
| # NB reversed order since K is transposed | |
| kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) | |
| qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) | |
| if not PRESCALE_QK: | |
| qk *= SM_SCALE | |
| # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ | |
| pre_mod_scores = qk | |
| n = get_bounded_indices(offs_n2[None, :], KV_LEN if CHECK_BLOCK_BOUNDARY else None) | |
| # The boundary check is done for the outer loop, but because we iterate across the N dim here, | |
| # the M indices can read out of bounds before the last iteration | |
| m = get_bounded_indices(offs_m2[:, None], Q_LEN if (not IS_DIVISIBLE or CHECK_BLOCK_BOUNDARY) else None) | |
| tmp0 = (qk) | |
| tmp1 = tmp0.to(tl.float32) | |
| tmp2 = (off_hq) | |
| tmp3 = tl.load(in_ptr16 + tmp2).to(tl.float32) | |
| tmp4 = tmp3.to(tl.float32) | |
| tmp5 = tl.sigmoid(tmp4) | |
| tmp6 = tmp1 * tmp5 | |
| post_mod_scores = tmp6 | |
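| # i.e. post_mod_scores = qk * sigmoid(primals_4[off_hq]): the inlined score_mod | |
| # (fw_graph0) scales each score by a per-head sigmoid factor read from in_ptr16. | |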
| if CHECK_BLOCK_BOUNDARY: | |
| # Mask out the elements that are out of the KV_LEN for non divisible seqlen. | |
| post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) | |
| if not IS_FULL_BLOCKS: | |
| tmp7 = tl.full([1], True, tl.int1) | |
| mask_mod_output = tmp7 | |
| if CHECK_BLOCK_BOUNDARY: | |
| mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False) | |
| # apply mask for partially masked block | |
| post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) | |
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
| if not PRESCALE_QK: | |
| post_mod_scores *= RCP_LN2 | |
| p = tl.math.exp2(post_mod_scores - lse) | |
| # Compute dP and dS. | |
| # NB reversed order since V is transposed | |
| vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) | |
| dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) | |
| ds = p * (dp - Di[:, None]) | |
| # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ | |
| tmp8 = (ds) | |
| tmp9 = tmp8.to(tl.float32) | |
| tmp10 = (off_hq) | |
| tmp11 = tl.load(in_ptr16 + tmp10).to(tl.float32) | |
| tmp12 = tmp11.to(tl.float32) | |
| tmp13 = tl.sigmoid(tmp12) | |
| tmp14 = tmp9 * tmp13 | |
| tmp15 = tmp14.to(tl.float32) | |
| grad_scores = tmp15 | |
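| # i.e. grad_scores = ds * sigmoid(primals_4[off_hq]): the inlined joint_graph0 applies | |
| # the chain rule through the same per-head factor to get the gradient w.r.t. the raw scores. | |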
| if CHECK_BLOCK_BOUNDARY: | |
| grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) | |
| # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ | |
| if WRITE_DQ: | |
| scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) | |
| tmp16 = (off_hq) | |
| tmp17 = (ds) | |
| tmp18 = (pre_mod_scores) | |
| tmp19 = tmp17 * tmp18 | |
| tmp20 = tmp19.to(tl.float32) | |
| tmp21 = tl.load(in_ptr16 + tmp16).to(tl.float32) | |
| tmp22 = tmp21.to(tl.float32) | |
| tmp23 = tl.sigmoid(tmp22) | |
| tmp24 = 1.0 | |
| tmp25 = tmp24 - tmp23 | |
| tmp26 = tmp23 * tmp25 | |
| tmp27 = tmp20 * tmp26 | |
| tmp28 = tmp27.to(tl.float32) | |
| tmp29 = tmp28.to(tl.float32) | |
| tl.atomic_add(in_ptr17 + tl.broadcast_to(tmp16, tmp29.shape), tmp29, scatter_mask, sem='relaxed') | |
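| # Accumulates the gradient of the per-head parameter into buf0 (in_ptr17): | |
| # d(primals_4)[off_hq] += ds * score * s * (1 - s), with s = sigmoid(primals_4[off_hq]), | |
| # using relaxed atomics since many thread blocks update the same 4 entries. | |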
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
| ds = grad_scores | |
| if not IS_FULL_BLOCKS: | |
| if CHECK_BLOCK_BOUNDARY: | |
| mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False) | |
| # (grads) apply mask for partially unmasked block | |
| ds = tl.where(mask_mod_output, ds, 0.0) | |
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
| ds = ds.to(MATMUL_PRECISION) | |
| # Compute dQ. | |
| dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) | |
| return dq | |
| @triton.jit | |
| def bwd_dkdv_inner( | |
| arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, in_ptr17, out_ptr0, | |
| Q, DO, DELTA, LSE, # pointers | |
| dk, dv, k, v, | |
| off_z, off_hq, offs_n1, offs_m1, | |
| stride_qm, stride_qd, stride_dom, stride_dod, | |
| q_indices, sparse_q_num_blocks, | |
| MATMUL_PRECISION, | |
| IS_FULL_BLOCKS, | |
| ): | |
| PRESCALE_QK : tl.constexpr = False | |
| ROWS_GUARANTEED_SAFE : tl.constexpr = False | |
| BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False | |
| WRITE_DQ : tl.constexpr = True | |
| OUTPUT_LOGSUMEXP : tl.constexpr = True | |
| FLOAT32_PRECISION : tl.constexpr = 'ieee' | |
| IS_DIVISIBLE : tl.constexpr = False | |
| SM_SCALE : tl.constexpr = 0.25 | |
| GQA_SHARED_HEADS : tl.constexpr = 1 | |
| HAS_FULL_BLOCKS : tl.constexpr = False | |
| QK_HEAD_DIM : tl.constexpr = 16 | |
| QK_HEAD_DIM_ROUNDED : tl.constexpr = 16 | |
| V_HEAD_DIM : tl.constexpr = 16 | |
| V_HEAD_DIM_ROUNDED : tl.constexpr = 16 | |
| SAFE_HEAD_DIM : tl.constexpr = True | |
| BLOCK_M1 : tl.constexpr = 16 | |
| BLOCK_N1 : tl.constexpr = 32 | |
| BLOCK_M2 : tl.constexpr = 32 | |
| BLOCK_N2 : tl.constexpr = 16 | |
| SPARSE_Q_BLOCK_SIZE : tl.constexpr = 1073741824 | |
| SPARSE_KV_BLOCK_SIZE : tl.constexpr = 1073741824 | |
| kpack : tl.constexpr = 2 | |
| matrix_instr_nonkdim : tl.constexpr = 16 | |
| waves_per_eu : tl.constexpr = 0 | |
| INDEX_DTYPE : tl.constexpr = tl.int32 | |
| SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) | |
| RCP_LN2: tl.constexpr = 1.44269504 | |
| Q_LEN = 277 | |
| KV_LEN = 277 | |
| offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) | |
| offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) | |
| qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd | |
| do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod | |
| # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. | |
| tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) | |
| hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) | |
| if not IS_DIVISIBLE: | |
| if hi >= 1: | |
| for start_m in range(0, hi - 1): | |
| dk, dv = bwd_dkdv_block_mn( | |
| arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, in_ptr17, out_ptr0, | |
| dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, | |
| off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, | |
| stride_qm, stride_qd, stride_dom, stride_dod, | |
| q_indices, sparse_q_num_blocks, | |
| MATMUL_PRECISION, RCP_LN2, | |
| IS_FULL_BLOCKS, | |
| ) | |
| # Increment pointers. | |
| offset = get_offset_for_next_block( | |
| start_m, q_indices, sparse_q_num_blocks, | |
| SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS | |
| ) | |
| qT_ptrs += offset * stride_qm | |
| do_ptrs += offset * stride_dom | |
| offs_m1 += offset | |
| dk, dv = bwd_dkdv_block_mn( | |
| arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, in_ptr17, out_ptr0, | |
| dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, | |
| off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, | |
| stride_qm, stride_qd, stride_dom, stride_dod, | |
| q_indices, sparse_q_num_blocks, | |
| MATMUL_PRECISION, RCP_LN2, | |
| IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, | |
| ) | |
| else: | |
| for start_m in range(0, hi): | |
| dk, dv = bwd_dkdv_block_mn( | |
| arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, in_ptr17, out_ptr0, | |
| dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, | |
| off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, | |
| stride_qm, stride_qd, stride_dom, stride_dod, | |
| q_indices, sparse_q_num_blocks, | |
| MATMUL_PRECISION, RCP_LN2, | |
| IS_FULL_BLOCKS, | |
| ) | |
| # Increment pointers. | |
| offset = get_offset_for_next_block( | |
| start_m, q_indices, sparse_q_num_blocks, | |
| SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS | |
| ) | |
| qT_ptrs += offset * stride_qm | |
| do_ptrs += offset * stride_dom | |
| offs_m1 += offset | |
| return dk, dv | |
| @triton.jit | |
| def bwd_dkdv_block_mn( | |
| arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, in_ptr17, out_ptr0, | |
| dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, | |
| off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, | |
| stride_qm, stride_qd, stride_dom, stride_dod, | |
| q_indices, sparse_q_num_blocks, | |
| MATMUL_PRECISION, RCP_LN2, | |
| IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, | |
| ): | |
| PRESCALE_QK : tl.constexpr = False | |
| ROWS_GUARANTEED_SAFE : tl.constexpr = False | |
| BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False | |
| WRITE_DQ : tl.constexpr = True | |
| OUTPUT_LOGSUMEXP : tl.constexpr = True | |
| FLOAT32_PRECISION : tl.constexpr = 'ieee' | |
| IS_DIVISIBLE : tl.constexpr = False | |
| SM_SCALE : tl.constexpr = 0.25 | |
| GQA_SHARED_HEADS : tl.constexpr = 1 | |
| HAS_FULL_BLOCKS : tl.constexpr = False | |
| QK_HEAD_DIM : tl.constexpr = 16 | |
| QK_HEAD_DIM_ROUNDED : tl.constexpr = 16 | |
| V_HEAD_DIM : tl.constexpr = 16 | |
| V_HEAD_DIM_ROUNDED : tl.constexpr = 16 | |
| SAFE_HEAD_DIM : tl.constexpr = True | |
| BLOCK_M1 : tl.constexpr = 16 | |
| BLOCK_N1 : tl.constexpr = 32 | |
| BLOCK_M2 : tl.constexpr = 32 | |
| BLOCK_N2 : tl.constexpr = 16 | |
| SPARSE_Q_BLOCK_SIZE : tl.constexpr = 1073741824 | |
| SPARSE_KV_BLOCK_SIZE : tl.constexpr = 1073741824 | |
| kpack : tl.constexpr = 2 | |
| matrix_instr_nonkdim : tl.constexpr = 16 | |
| waves_per_eu : tl.constexpr = 0 | |
| INDEX_DTYPE : tl.constexpr = tl.int32 | |
| # NB reversed order since Q is transposed | |
| qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) | |
| # Load LSE before computing qk to reduce pipeline stall. | |
| if IS_DIVISIBLE: | |
| lse = tl.load(LSE + offs_m1) | |
| else: | |
| lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) | |
| lse = tl.where(lse == -float("inf"), 0.0, lse) | |
| qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) | |
| if not PRESCALE_QK: | |
| qkT *= SM_SCALE | |
| # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ | |
| m = get_bounded_indices(offs_m1[None, :], Q_LEN if CHECK_BLOCK_BOUNDARY else None) | |
| # The boundary check is done for the outer loop, but because we iterate across the M dim here, | |
| # the n indices can read out of bounds before the last iteration | |
| n = get_bounded_indices(offs_n1[:, None], KV_LEN if (not IS_DIVISIBLE or CHECK_BLOCK_BOUNDARY) else None) | |
| pre_mod_scores = qkT | |
| tmp30 = (qkT) | |
| tmp31 = tmp30.to(tl.float32) | |
| tmp32 = (off_hq) | |
| tmp33 = tl.load(in_ptr16 + tmp32).to(tl.float32) | |
| tmp34 = tmp33.to(tl.float32) | |
| tmp35 = tl.sigmoid(tmp34) | |
| tmp36 = tmp31 * tmp35 | |
| post_mod_scores = tmp36 | |
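| # Same score_mod as in the dQ path: qkT * sigmoid(primals_4[off_hq]), recomputed here | |
| # for the dK/dV tile. | |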
| if CHECK_BLOCK_BOUNDARY: | |
| # Mask out the elements that are out of the KV_LEN for non divisible seqlen. | |
| post_mod_scores = tl.where(offs_n1[:, None] < KV_LEN, post_mod_scores, float("-inf")) | |
| if not IS_FULL_BLOCKS: | |
| tmp37 = tl.full([1], True, tl.int1) | |
| mask_mod_output = tmp37 | |
| if CHECK_BLOCK_BOUNDARY: | |
| mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False) | |
| # apply mask for partially masked block | |
| post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) | |
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
| if not PRESCALE_QK: | |
| post_mod_scores *= RCP_LN2 | |
| pT = tl.math.exp2(post_mod_scores - lse[None, :]) | |
| do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) | |
| # Compute dV. | |
| ppT = pT | |
| dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) | |
| if IS_DIVISIBLE: | |
| Di = tl.load(DELTA + offs_m1) | |
| else: | |
| Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) | |
| # Compute dP and dS. | |
| dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) | |
| dsT = pT * (dpT - Di[None, :]) | |
| # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ | |
| tmp38 = (dsT) | |
| tmp39 = tmp38.to(tl.float32) | |
| tmp40 = (off_hq) | |
| tmp41 = tl.load(in_ptr16 + tmp40).to(tl.float32) | |
| tmp42 = tmp41.to(tl.float32) | |
| tmp43 = tl.sigmoid(tmp42) | |
| tmp44 = tmp39 * tmp43 | |
| tmp45 = tmp44.to(tl.float32) | |
| grad_scores = tmp45 | |
| # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ | |
| if not WRITE_DQ: | |
| idx_b = off_z | |
| idx_h = off_hq | |
| idx_m = m | |
| idx_n = n | |
| scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) | |
| tmp46 = (idx_h) | |
| tmp47 = (dsT) | |
| tmp48 = (pre_mod_scores) | |
| tmp49 = tmp47 * tmp48 | |
| tmp50 = tmp49.to(tl.float32) | |
| tmp51 = tl.load(in_ptr16 + tmp46).to(tl.float32) | |
| tmp52 = tmp51.to(tl.float32) | |
| tmp53 = tl.sigmoid(tmp52) | |
| tmp54 = 1.0 | |
| tmp55 = tmp54 - tmp53 | |
| tmp56 = tmp53 * tmp55 | |
| tmp57 = tmp50 * tmp56 | |
| tmp58 = tmp57.to(tl.float32) | |
| tmp59 = tmp58.to(tl.float32) | |
| tl.atomic_add(in_ptr17 + tl.broadcast_to(tmp46, tmp59.shape), tmp59, scatter_mask, sem='relaxed') | |
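| # Note: this accumulation path is dead code in this compilation since WRITE_DQ is | |
| # constexpr True; the parameter gradient is accumulated from the dQ blocks above instead. | |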
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
| if CHECK_BLOCK_BOUNDARY: | |
| grad_scores = tl.where(offs_n1[:, None] < KV_LEN, grad_scores, 0.0) | |
| dsT = grad_scores | |
| if not IS_FULL_BLOCKS: | |
| if CHECK_BLOCK_BOUNDARY: | |
| mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False) | |
| # (grads) apply mask for partially unmasked block | |
| dsT = tl.where(mask_mod_output, dsT, 0.0) | |
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
| dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) | |
| return dk, dv | |
| # Utility triton funcs | |
| @triton.jit | |
| def get_offset_for_next_block( | |
| loop_iter, col_indices, total_blocks, | |
| SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, | |
| BLOCKS_ARE_CONTIGUOUS: tl.constexpr | |
| ): | |
| if BLOCKS_ARE_CONTIGUOUS: | |
| return BLOCK | |
| cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE | |
| cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") | |
| next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) | |
| needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 | |
| jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK | |
| offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK | |
| return offset | |
| @triton.jit | |
| def get_bounded_indices(indices, max_len=None): | |
| return indices % max_len if max_len is not None else indices | |
| @triton.jit | |
| def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): | |
| if IS_DIVISIBLE and SAFE_HEAD_DIM: | |
| return tl.load(block_ptr) | |
| elif IS_DIVISIBLE and not SAFE_HEAD_DIM: | |
| return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") | |
| elif not IS_DIVISIBLE and SAFE_HEAD_DIM: | |
| return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") | |
| else: | |
| return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") | |
| @triton.jit | |
| def load_checked_2d( | |
| ptr, | |
| offs_m, | |
| offs_n, | |
| stride_m, | |
| stride_n, | |
| IS_DIVISIBLE_M: tl.constexpr, | |
| IS_DIVISIBLE_N: tl.constexpr, | |
| M_LEN: tl.constexpr, | |
| N_DIM: tl.constexpr, | |
| ): | |
| # Calculate final pointer if strides are provided | |
| if stride_m is not None and stride_n is not None: | |
| ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n | |
| # Handle all masking cases | |
| if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: | |
| return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_DIM), other=0.0) | |
| elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: | |
| return tl.load(ptr, mask=(offs_n[None, :] < N_DIM), other=0.0) | |
| elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: | |
| return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) | |
| else: # Both divisible | |
| return tl.load(ptr) | |
| class Runner: | |
| def __init__(self, partitions): | |
| self.partitions = partitions | |
| def recursively_apply_fns(self, fns): | |
| new_callables = [] | |
| for fn, c in zip(fns, self.partitions): | |
| new_callables.append(fn(c)) | |
| self.partitions = new_callables | |
| def call(self, args): | |
| primals_1, primals_2, primals_3, primals_4, full, full_default, convert_element_type, convert_element_type_1, getitem_2, getitem_3, tangents_1 = args | |
| args.clear() | |
| assert_size_stride(primals_1, (2, 4, 277, 16), (17728, 4432, 16, 1)) | |
| assert_size_stride(primals_2, (2, 4, 277, 16), (17728, 4432, 16, 1)) | |
| assert_size_stride(primals_3, (2, 4, 277, 16), (17728, 4432, 16, 1)) | |
| assert_size_stride(primals_4, (4, ), (1, )) | |
| assert_size_stride(full, (1, 1, 1), (1, 1, 1)) | |
| assert_size_stride(full_default, (1, 1, 1, 1), (1, 1, 1, 1)) | |
| assert_size_stride(convert_element_type, (1, 1, 1), (1, 1, 1)) | |
| assert_size_stride(convert_element_type_1, (1, 1, 1, 1), (1, 1, 1, 1)) | |
| assert_size_stride(getitem_2, (2, 4, 277, 16), (17728, 4432, 16, 1)) | |
| assert_size_stride(getitem_3, (2, 4, 277), (1108, 277, 1)) | |
| assert_size_stride(tangents_1, (2, 4, 277, 16), (17728, 4432, 16, 1)) | |
| with torch.cuda._DeviceGuard(0): | |
| torch.cuda.set_device(0) | |
| buf0 = empty_strided_cuda((4, ), (1, ), torch.float32) | |
| # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_zeros_0[(1, 1, 1)](buf0, 4, XBLOCK=4, num_warps=1, num_stages=1) | |
| buf2 = empty_strided_cuda((2, 4, 277), (1108, 277, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused_zeros_1[(70, 1, 1)](getitem_2, tangents_1, buf2, 2216, 16, XBLOCK=32, num_warps=2, num_stages=1) | |
| del getitem_2 | |
| buf4 = empty_strided_cuda((2, 4, 277, 16), (17728, 4432, 16, 1), torch.float16) | |
| buf5 = empty_strided_cuda((2, 4, 277, 16), (17728, 4432, 16, 1), torch.float16) | |
| buf6 = empty_strided_cuda((0, ), (1, ), torch.float32) | |
| buf7 = empty_strided_cuda((0, ), (1, ), torch.float32) | |
| buf8 = empty_strided_cuda((0, ), (1, ), torch.float32) | |
| buf9 = empty_strided_cuda((0, ), (1, ), torch.float32) | |
| buf10 = empty_strided_cuda((2, 4, 277, 16), (17728, 4432, 16, 1), torch.float16) | |
| # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
| stream0 = get_raw_stream(0) | |
| triton_tem_fused_zeros_2[(18, 2, 4)](primals_1, primals_2, primals_3, getitem_3, buf2, tangents_1, buf4, buf5, full, full_default, convert_element_type, convert_element_type_1, buf6, buf7, buf8, buf9, primals_4, buf0, buf10, num_warps=4, num_stages=1) | |
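| # Argument mapping for the template launch above: Q/K/V = primals_1..3, LSE = getitem_3, | |
| # DELTA = buf2, DO = tangents_1, DQ = buf4, DV = buf5; the block-sparse metadata comes | |
| # from full / full_default / convert_element_type{,_1} plus four empty FULL_* buffers | |
| # (buf6..buf9); in_ptr16 = primals_4 (score_mod parameter), in_ptr17 = buf0 (its grad | |
| # accumulator), out_ptr0 = buf10 (dK). | |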
| del buf2 | |
| del buf6 | |
| del buf7 | |
| del buf8 | |
| del buf9 | |
| del convert_element_type | |
| del convert_element_type_1 | |
| del full | |
| del full_default | |
| del getitem_3 | |
| del primals_1 | |
| del primals_2 | |
| del primals_3 | |
| del primals_4 | |
| del tangents_1 | |
| return (buf4, buf10, buf5, buf0, ) | |
| runner = Runner(partitions=[]) | |
| call = runner.call | |
| recursively_apply_fns = runner.recursively_apply_fns | |
| def benchmark_compiled_module(times=10, repeat=10): | |
| from torch._dynamo.testing import rand_strided | |
| from torch._inductor.utils import print_performance | |
| primals_1 = rand_strided((2, 4, 277, 16), (17728, 4432, 16, 1), device='cuda:0', dtype=torch.float16) | |
| primals_2 = rand_strided((2, 4, 277, 16), (17728, 4432, 16, 1), device='cuda:0', dtype=torch.float16) | |
| primals_3 = rand_strided((2, 4, 277, 16), (17728, 4432, 16, 1), device='cuda:0', dtype=torch.float16) | |
| primals_4 = rand_strided((4, ), (1, ), device='cuda:0', dtype=torch.float16) | |
| full = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:0', dtype=torch.int32) | |
| full_default = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:0', dtype=torch.int32) | |
| convert_element_type = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:0', dtype=torch.int32) | |
| convert_element_type_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:0', dtype=torch.int32) | |
| getitem_2 = rand_strided((2, 4, 277, 16), (17728, 4432, 16, 1), device='cuda:0', dtype=torch.float16) | |
| getitem_3 = rand_strided((2, 4, 277), (1108, 277, 1), device='cuda:0', dtype=torch.float32) | |
| tangents_1 = rand_strided((2, 4, 277, 16), (17728, 4432, 16, 1), device='cuda:0', dtype=torch.float16) | |
| fn = lambda: call([primals_1, primals_2, primals_3, primals_4, full, full_default, convert_element_type, convert_element_type_1, getitem_2, getitem_3, tangents_1]) | |
| return print_performance(fn, times=times, repeat=repeat) | |
| if __name__ == "__main__": | |
| from torch._inductor.wrapper_benchmark import compiled_module_main | |
| compiled_module_main('None', benchmark_compiled_module) |
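For reference, below is a minimal sketch (an assumption, not part of this gist) of the kind of eager program whose compiled backward would produce code like the above. The fp16 inputs of shape (2, 4, 277, 16), the default scale 0.25 = 1/sqrt(16), and the inlined sigmoid of a captured 4-element parameter (primals_4 / in_ptr16) suggest a score_mod that multiplies each score by a per-head sigmoid factor; the name `gate` is hypothetical.

import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, S, D = 2, 4, 277, 16
q = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16, requires_grad=True)
# hypothetical per-head parameter; plays the role of primals_4 / in_ptr16 in the kernels above
gate = torch.randn(H, device="cuda", dtype=torch.float16, requires_grad=True)

def score_mod(score, b, h, q_idx, kv_idx):
    # consistent with the inlined fw_graph0: scale each attention score by a per-head sigmoid
    return score * torch.sigmoid(gate[h])

compiled = torch.compile(flex_attention)
out = compiled(q, k, v, score_mod=score_mod)  # default scale = 1/sqrt(D) = 0.25
out.sum().backward()                          # traces a flex_attention_backward graph

Calling backward on a program like this goes through torch.ops.higher_order.flex_attention_backward, which Inductor lowers to the wrapper and Triton template kernel shown in this file.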