Created
March 2, 2026 23:07
-
-
Save shunting314/6fe4e931f7e3bd98e1c936b4b1135a5f to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# AOT ID: ['0_backward']
# --------------------------------------------------------------------------
# Inductor-generated wrapper preamble for an AOTAutograd *backward* graph.
# Everything below is machine-generated boilerplate: imports plus short
# module-level aliases for the private torch._C / torch.ops helpers that the
# generated `call(...)` function (elsewhere in this file) uses at runtime.
# --------------------------------------------------------------------------
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream

# Shorthand handles for ATen / Inductor op namespaces.
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized

# Dynamo guard helpers: runtime size/stride/alignment assertions and raw
# strided-tensor allocators for each device backend.
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
assert_alignment = torch._C._dynamo.guards.assert_alignment
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool

# Compiles the Triton kernel sources defined below (possibly asynchronously).
async_compile = AsyncCompile()

# Symmetric-memory P2P allocator used for collective/distributed buffers.
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
| # kernel path: /tmp/torchinductor/rank0/wb/cwbxlmflrtfpjh5rwf4t5k27entxvmvn33pehlygc5askur3dbxd.py | |
| # Topologically Sorted Source Nodes: [convert_element_type_410, redistribute_4, convert_element_type_412, mul_153, mul_154, sum_1, x_26, mul_155, sum_2, mul_156, sub_29, sub_30, div_3, mul_157, mul_158, sum_3, sum_4, convert_element_type_414], Original ATen: [aten.native_layer_norm_backward, aten._to_copy, aten.native_layer_norm] | |
| # Source node to ATen node mapping: | |
| # convert_element_type_410 => convert_element_type_410 | |
| # convert_element_type_412 => convert_element_type_412 | |
| # convert_element_type_414 => convert_element_type_414 | |
| # div_3 => div_3 | |
| # mul_153 => mul_153 | |
| # mul_154 => mul_154 | |
| # mul_155 => mul_155 | |
| # mul_156 => mul_156 | |
| # mul_157 => mul_157 | |
| # mul_158 => mul_158 | |
| # redistribute_4 => convert_element_type_406 | |
| # sub_29 => sub_29 | |
| # sub_30 => sub_30 | |
| # sum_1 => sum_1 | |
| # sum_2 => sum_2 | |
| # sum_3 => sum_3 | |
| # sum_4 => sum_4 | |
| # x_26 => convert_element_type_408, mul_150, sub_27 | |
| # Graph fragment: | |
| # %tangents_1 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=tangents_1] | |
| # %primals_167 : Tensor "f32[2048][1]cuda:0" = PlaceHolder[target=primals_167] | |
| # %add_107 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=add_107] | |
| # %getitem_89 : Tensor "f32[327680, 1][1, 1]cuda:0" = PlaceHolder[target=getitem_89] | |
| # %rsqrt_26 : Tensor "f32[327680, 1][1, 1]cuda:0" = PlaceHolder[target=rsqrt_26] | |
| # %sum_1 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=sum_1] | |
| # %sum_2 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=sum_2] | |
| # %convert_element_type_410 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%tangents_1, torch.float32), kwargs = {}) | |
| # %convert_element_type_406 : Tensor "bf16[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_167, torch.bfloat16), kwargs = {}) | |
| # %convert_element_type_412 : Tensor "f32[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_406, torch.float32), kwargs = {}) | |
| # %mul_153 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_410, %convert_element_type_412), kwargs = {}) | |
| # %mul_154 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_153, 2048), kwargs = {}) | |
| # %sum_1 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_153, [1], True), kwargs = {}) | |
| # %convert_element_type_408 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_107, torch.float32), kwargs = {}) | |
| # %sub_27 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%convert_element_type_408, %getitem_89), kwargs = {}) | |
| # %mul_150 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_27, %rsqrt_26), kwargs = {}) | |
| # %mul_155 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_153, %mul_150), kwargs = {}) | |
| # %sum_2 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_155, [1], True), kwargs = {}) | |
| # %mul_156 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_150, %sum_2), kwargs = {}) | |
| # %sub_29 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%mul_154, %sum_1), kwargs = {}) | |
| # %sub_30 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sub_29, %mul_156), kwargs = {}) | |
| # %div_3 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%rsqrt_26, 2048), kwargs = {}) | |
| # %mul_157 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_3, %sub_30), kwargs = {}) | |
| # %mul_158 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_410, %mul_150), kwargs = {}) | |
| # %sum_3 : Tensor "f32[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_158, [0]), kwargs = {}) | |
| # %sum_4 : Tensor "f32[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_410, [0]), kwargs = {}) | |
| # %convert_element_type_414 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_157, torch.bfloat16), kwargs = {}) | |
| # return %sum_1,%sum_2,%convert_element_type_414 | |
# Fused layer-norm backward kernel (persistent reduction, mix-order grid).
# Per row x0 (327680 rows of 2048 elements) it:
#   * loads grad_output (bf16, in_ptr0), weight (fp32, in_ptr1), the saved
#     forward input (bf16, in_ptr2), saved mean (in_ptr3) and saved rsqrt
#     (in_ptr4);
#   * recomputes x_hat = (x - mean) * rsqrt and the two row reductions
#     sum(dy*w) and sum(dy*w*x_hat);
#   * emits grad_input = rsqrt/2048 * (2048*dy*w - sum(dy*w) - x_hat*sum(dy*w*x_hat))
#     as bf16 to out_ptr2 (0.00048828125 == 1/2048).
# Across its RSPLIT_SIZE slice of rows each program also accumulates the
# column-wise partials sum(dy * x_hat) and sum(dy) (inputs to the weight/bias
# grads) and writes them to per-program slots in ws_ptr — presumably reduced
# across programs by a later kernel not visible in this chunk (TODO confirm).
triton_per_fused__to_copy_native_layer_norm_native_layer_norm_backward_0 = async_compile.triton('triton_per_fused__to_copy_native_layer_norm_native_layer_norm_backward_0', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
    size_hints={'x': 524288, 'r0_': 2048},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*fp32', 'out_ptr2': '*bf16', 'ws_ptr': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'RSPLIT_SIZE': 'constexpr', 'NUM_STAGES': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'MixOrderReductionGrid', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_native_layer_norm_native_layer_norm_backward_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 5, 'num_store': -1, 'num_reduction': 2, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'RSPLIT_SIZE': 128, 'has_loadstore_with_contiguous_rdim': True}
)
@triton.jit
def triton_per_fused__to_copy_native_layer_norm_native_layer_norm_backward_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr2, ws_ptr, xnumel, r0_numel, XBLOCK : tl.constexpr, RSPLIT_SIZE : tl.constexpr, NUM_STAGES : tl.constexpr):
    xnumel = 327680
    r0_numel = 2048
    R0_BLOCK: tl.constexpr = 2048
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * RSPLIT_SIZE
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)[None, :]
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    accum0 = tl.full([R0_BLOCK], 0, tl.float32)[None, :]
    accum1 = tl.full([R0_BLOCK], 0, tl.float32)[None, :]
    split_size = min(RSPLIT_SIZE, xnumel - xoffset)
    for _ in tl.range(0, split_size, XBLOCK, num_stages=NUM_STAGES):
        x0 = xindex
        xindex += XBLOCK
        tmp0 = tl.load(in_ptr0 + (r0_1 + 2048*x0), None).to(tl.float32)
        tmp2 = tl.load(in_ptr1 + (r0_1), None, eviction_policy='evict_last')
        tmp9 = tl.load(in_ptr2 + (r0_1 + 2048*x0), None).to(tl.float32)
        tmp11 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last')
        tmp13 = tl.load(in_ptr4 + (x0), None, eviction_policy='evict_last')
        tmp1 = tmp0.to(tl.float32)
        tmp3 = tmp2.to(tl.float32)
        tmp4 = tmp3.to(tl.float32)
        tmp5 = tmp1 * tmp4
        tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK])
        tmp8 = tl.sum(tmp6, 1)[:, None].to(tl.float32)
        tmp10 = tmp9.to(tl.float32)
        tmp12 = tmp10 - tmp11
        tmp14 = tmp12 * tmp13
        tmp15 = tmp5 * tmp14
        tmp16 = tl.broadcast_to(tmp15, [XBLOCK, R0_BLOCK])
        tmp18 = tl.sum(tmp16, 1)[:, None].to(tl.float32)
        tmp19 = tl.full([1, 1], 0.00048828125, tl.float32)
        tmp20 = tmp13 * tmp19
        tmp21 = tl.full([1, 1], 2048.0, tl.float32)
        tmp22 = tmp5 * tmp21
        tmp23 = tmp22 - tmp8
        tmp24 = tmp14 * tmp18
        tmp25 = tmp23 - tmp24
        tmp26 = tmp20 * tmp25
        tmp27 = tmp26.to(tl.float32)
        tmp28 = tmp1 * tmp14
        tl.store(out_ptr2 + (r0_1 + 2048*x0), tmp27, None)
        tmp29 = tl.sum(tmp28, 0)
        tmp30 = accum0 + tmp29
        accum0 = tmp30
        tmp31 = tl.sum(tmp1, 0)
        tmp32 = accum1 + tmp31
        accum1 = tmp32
    tl.store(ws_ptr + (tl.program_id(0) + 0 * tl.num_programs(0)) * r0_numel + r0_index, accum0, r0_mask)
    tl.store(ws_ptr + (tl.program_id(0) + 1 * tl.num_programs(0)) * r0_numel + r0_index, accum1, r0_mask)
''', device_str='cuda')
| # kernel path: /tmp/torchinductor/rank0/zm/czmvhutdr45dzcouphyn6ypgzhuegd4eldyd5fc7qmihrwdbp2aa.py | |
| # Topologically Sorted Source Nodes: [convert_element_type_415, all_reduce_1], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| # Source node to ATen node mapping: | |
| # all_reduce_1 => all_reduce_1 | |
| # convert_element_type_415 => convert_element_type_415 | |
| # Graph fragment: | |
| # %sum_3 : Tensor "f32[2048][1]cuda:0" = PlaceHolder[target=sum_3] | |
| # %convert_element_type_415 : Tensor "bf16[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_3, torch.bfloat16), kwargs = {}) | |
| # %all_reduce_1 : Tensor "bf16[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops._c10d_functional.all_reduce_.default](args = (%convert_element_type_415, avg, 0), kwargs = {}) | |
| # return %wait_tensor_1 | |
# Pointwise dtype-cast kernel over 2048 elements: reads fp32 (in_ptr0) and
# stores into a bf16 output (out_ptr0; the downcast happens at the store,
# since out_ptr0 is declared '*bf16' in the signature). Per the banner above,
# this prepares the fp32 weight-grad sum for the bf16 avg all_reduce.
triton_poi_fused_all_reduce_native_layer_norm_backward_1 = async_compile.triton('triton_poi_fused_all_reduce_native_layer_norm_backward_1', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
    size_hints={'x': 2048},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_all_reduce_native_layer_norm_backward_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 16384}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_all_reduce_native_layer_norm_backward_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2048
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, xmask)
''', device_str='cuda')
| # kernel path: /tmp/torchinductor/rank0/oo/coovzos7dfgo6q6d2x5csilspzo54pvavar6efs5zmoocbr3baq6.py | |
| # Topologically Sorted Source Nodes: [convert_element_type_417], Original ATen: [aten._to_copy] | |
| # Source node to ATen node mapping: | |
| # convert_element_type_417 => convert_element_type_417 | |
| # Graph fragment: | |
| # %buf11 : Tensor = PlaceHolder[target=buf11] | |
| # %convert_element_type_417 : Tensor "f32[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%wait_tensor, torch.float32), kwargs = {}) | |
| # return %convert_element_type_417 | |
# Pointwise dtype-cast kernel over 2048 elements in the opposite direction:
# reads bf16 (in_ptr0, upcast to fp32 at load) and stores fp32 (out_ptr0).
# Per the banner above, this converts the all_reduce'd bf16 result back to
# fp32 (convert_element_type_417).
triton_poi_fused__to_copy_2 = async_compile.triton('triton_poi_fused__to_copy_2', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
    size_hints={'x': 2048},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 16384}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_2(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2048
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, xmask)
''', device_str='cuda')
| # kernel path: /tmp/torchinductor/rank0/74/c745s7utu4xbymzze36ogp6jnsxqu2olhyem7egezxhjzg2idgqd.py | |
| # Topologically Sorted Source Nodes: [redistribute_4], Original ATen: [aten._to_copy] | |
| # Source node to ATen node mapping: | |
| # redistribute_4 => convert_element_type_401 | |
| # Graph fragment: | |
| # %primals_165 : Tensor "f32[2048, 8192][8192, 1]cuda:0" = PlaceHolder[target=primals_165] | |
| # %convert_element_type_401 : Tensor "bf16[2048, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_165, torch.bfloat16), kwargs = {}) | |
| # return %convert_element_type_401 | |
# Pointwise fp32 -> bf16 cast of a full 2048x8192 (= 16777216-element) weight
# tensor (primals_165 per the banner above). xnumel divides the launch size
# exactly, so no bounds mask is needed (xmask is all-true).
triton_poi_fused__to_copy_3 = async_compile.triton('triton_poi_fused__to_copy_3', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
    size_hints={'x': 16777216},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 134217728}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_3(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 16777216
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')
| # kernel path: /tmp/torchinductor/rank0/f2/cf2xntzp4edf6e25wjnwlziltupvbuqpcnzw5jxzmb4pxuc3x7r4.py | |
| # Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm] | |
| # Source node to ATen node mapping: | |
| # layer_norm => add_103, add_104, convert_element_type_392, convert_element_type_393, mul_142, mul_143, rsqrt_25, sub_26, var_mean_25 | |
| # redistribute => convert_element_type_390 | |
| # redistribute_1 => convert_element_type_391 | |
| # Graph fragment: | |
| # %add_102 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=add_102] | |
| # %getitem_87 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=getitem_87] | |
| # %buf22 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=buf22] | |
| # %primals_161 : Tensor "f32[2048][1]cuda:0" = PlaceHolder[target=primals_161] | |
| # %primals_162 : Tensor "f32[2048][1]cuda:0" = PlaceHolder[target=primals_162] | |
| # %convert_element_type_390 : Tensor "bf16[2048][1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_161, torch.bfloat16), kwargs = {}) | |
| # %convert_element_type_391 : Tensor "bf16[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_162, torch.bfloat16), kwargs = {}) | |
| # %convert_element_type_392 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_102, torch.float32), kwargs = {}) | |
| # %var_mean_25 : [num_users=2] = call_function[target=torch.ops.aten.var_mean.correction](args = (%convert_element_type_392, [1]), kwargs = {correction: 0, keepdim: True}) | |
| # %add_103 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%getitem_86, 1e-05), kwargs = {}) | |
| # %rsqrt_25 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_103,), kwargs = {}) | |
| # %sub_26 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%convert_element_type_392, %getitem_87), kwargs = {}) | |
| # %mul_142 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_26, %rsqrt_25), kwargs = {}) | |
| # %mul_143 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_142, %convert_element_type_390), kwargs = {}) | |
| # %add_104 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_143, %convert_element_type_391), kwargs = {}) | |
| # %convert_element_type_393 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_104, torch.bfloat16), kwargs = {}) | |
| # return %getitem_87,%buf22,%convert_element_type_393 | |
# Fused layer-norm *forward* kernel (two-pass looped reduction) over 327680
# rows of 2048 elements. Pass 1 computes per-row mean and M2 (sum of squared
# deviations) via Welford and stores them to out_ptr0/out_ptr1 — out_ptr1
# holds raw M2, not variance or rsqrt (downstream consumers presumably
# re-derive rsqrt; TODO confirm against the rest of the file). Pass 2
# re-reads the bf16 input (in_ptr0), normalizes with
# rsqrt(M2/2048 + 1e-05), applies fp32 weight (in_ptr1) and bias (in_ptr2),
# and stores the bf16 result to out_ptr2.
triton_red_fused__to_copy_native_layer_norm_4 = async_compile.triton('triton_red_fused__to_copy_native_layer_norm_4', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
    size_hints={'x': 524288, 'r0_': 2048},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_native_layer_norm_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 4, 'num_store': 3, 'num_reduction': 2, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': True, 'tiling_scores': {'x': 5242880, 'r0_': 4026548224}}
)
@triton.jit
def triton_red_fused__to_copy_native_layer_norm_4(in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr1, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 327680
    r0_numel = 2048
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = xindex
    tmp3_mean = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
    tmp3_m2 = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
    tmp3_weight = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp0 = tl.load(in_ptr0 + (r0_1 + 2048*x0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp1 = tmp0.to(tl.float32)
        tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
        tmp3_mean_next, tmp3_m2_next, tmp3_weight_next = triton_helpers.welford_reduce(
            tmp2, tmp3_mean, tmp3_m2, tmp3_weight, roffset == 0
        )
        tmp3_mean = tl.where(r0_mask, tmp3_mean_next, tmp3_mean)
        tmp3_m2 = tl.where(r0_mask, tmp3_m2_next, tmp3_m2)
        tmp3_weight = tl.where(r0_mask, tmp3_weight_next, tmp3_weight)
    tmp4, tmp5, tmp6 = triton_helpers.welford(tmp3_mean, tmp3_m2, tmp3_weight, 1)
    tmp3 = tmp4[:, None]
    tmp7 = tmp5[:, None]
    tmp8 = tmp6[:, None]
    tl.store(out_ptr0 + (x0), tmp3, None)
    tl.store(out_ptr1 + (x0), tmp7, None)
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp9 = tl.load(in_ptr0 + (r0_1 + 2048*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp18 = tl.load(in_ptr1 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0)
        tmp22 = tl.load(in_ptr2 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0)
        tmp10 = tmp9.to(tl.float32)
        tmp11 = tmp10 - tmp3
        tmp12 = tl.full([1, 1], 2048.0, tl.float32)
        tmp13 = (tmp7 / tmp12)
        tmp14 = tl.full([1, 1], 1e-05, tl.float32)
        tmp15 = tmp13 + tmp14
        tmp16 = libdevice.rsqrt(tmp15)
        tmp17 = tmp11 * tmp16
        tmp19 = tmp18.to(tl.float32)
        tmp20 = tmp19.to(tl.float32)
        tmp21 = tmp17 * tmp20
        tmp23 = tmp22.to(tl.float32)
        tmp24 = tmp23.to(tl.float32)
        tmp25 = tmp21 + tmp24
        tmp26 = tmp25.to(tl.float32)
        tl.store(out_ptr2 + (r0_1 + 2048*x0), tmp26, r0_mask)
''', device_str='cuda')
| # kernel path: /tmp/torchinductor/rank0/k2/ck2gagwbhzuttn76q4burnrkjixhz2qr32zm5ndumumvzjoihsxw.py | |
| # Topologically Sorted Source Nodes: [redistribute_3, x, x_25, convert_element_type_425, mul_164, mul_165, sub_31, mul_166, add_112, mul_167, mul_168, mul_169, add_113, mul_170, convert_element_type_427], Original ATen: [aten._to_copy, aten.addmm, aten.gelu, aten.gelu_backward] | |
| # Source node to ATen node mapping: | |
| # add_112 => add_112 | |
| # add_113 => add_113 | |
| # convert_element_type_425 => convert_element_type_425 | |
| # convert_element_type_427 => convert_element_type_427 | |
| # mul_164 => mul_164 | |
| # mul_165 => mul_165 | |
| # mul_166 => mul_166 | |
| # mul_167 => mul_167 | |
| # mul_168 => mul_168 | |
| # mul_169 => mul_169 | |
| # mul_170 => mul_170 | |
| # redistribute_3 => convert_element_type_395 | |
| # sub_31 => sub_31 | |
| # x => add_tensor_11 | |
| # x_25 => add_105, add_106, convert_element_type_399, convert_element_type_400, mul_144, mul_145, mul_146, mul_147, mul_148, mul_149, tanh_11 | |
| # Graph fragment: | |
| # %primals_164 : Tensor "f32[8192][1]cuda:0" = PlaceHolder[target=primals_164] | |
| # %mm_default_11 : Tensor "bf16[327680, 8192][8192, 1]cuda:0" = PlaceHolder[target=mm_default_11] | |
| # %mm_49 : Tensor "bf16[327680, 8192][8192, 1]cuda:0" = PlaceHolder[target=mm_49] | |
| # %convert_element_type_395 : Tensor "bf16[8192][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_164, torch.bfloat16), kwargs = {}) | |
| # %add_tensor_11 : Tensor "bf16[327680, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%convert_element_type_395, %mm_default_11), kwargs = {}) | |
| # %convert_element_type_399 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_tensor_11, torch.float32), kwargs = {}) | |
| # %mul_144 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_399, %convert_element_type_399), kwargs = {}) | |
| # %mul_145 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_144, %convert_element_type_399), kwargs = {}) | |
| # %mul_146 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_145, 0.044715), kwargs = {}) | |
| # %add_105 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%convert_element_type_399, %mul_146), kwargs = {}) | |
| # %mul_147 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_105, 0.7978845608028654), kwargs = {}) | |
| # %mul_148 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_399, 0.5), kwargs = {}) | |
| # %tanh_11 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.tanh.default](args = (%mul_147,), kwargs = {}) | |
| # %add_106 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%tanh_11, 1), kwargs = {}) | |
| # %mul_149 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_148, %add_106), kwargs = {}) | |
| # %convert_element_type_400 : Tensor "bf16[327680, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_149, torch.bfloat16), kwargs = {}) | |
| # %convert_element_type_425 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mm_49, torch.float32), kwargs = {}) | |
| # %mul_164 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_106, 0.5), kwargs = {}) | |
| # %mul_165 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tanh_11, %tanh_11), kwargs = {}) | |
| # %sub_31 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (1, %mul_165), kwargs = {}) | |
| # %mul_166 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_144, 0.134145), kwargs = {}) | |
| # %add_112 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_166, 1), kwargs = {}) | |
| # %mul_167 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_112, 0.7978845608028654), kwargs = {}) | |
| # %mul_168 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_148, %sub_31), kwargs = {}) | |
| # %mul_169 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_168, %mul_167), kwargs = {}) | |
| # %add_113 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_164, %mul_169), kwargs = {}) | |
| # %mul_170 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_425, %add_113), kwargs = {}) | |
| # %convert_element_type_427 : Tensor "bf16[327680, 8192][8192, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_170, torch.bfloat16), kwargs = {}) | |
| # return %convert_element_type_400,%convert_element_type_427 | |
# NOTE(review): TorchInductor-generated kernel — do not hand-edit the source string;
# it is hashed for compilation caching and the inductor_meta fields (num_load/num_store,
# tiling_scores, backend_hash) are derived from this exact body.
# Fused pointwise pass over xnumel = 2684354560 = 327680 * 8192 elements:
#   - recomputes the addmm output x = bias(in_ptr0, broadcast per column via x0) + mm(in_ptr1),
#     then the tanh-approximation GELU (constants 0.5, 0.044715, 0.7978845608028654 = sqrt(2/pi))
#     and stores the bf16 activation to out_ptr0 (recompute-in-backward);
#   - multiplies the upstream gradient (in_out_ptr0) by the GELU derivative
#     (0.134145 = 3 * 0.044715 term) and writes it back in place
#     (mutated_arg_names=['in_out_ptr0']).
# int64 indexing (program_id/arange .to(tl.int64)) because xnumel exceeds int32 range.
triton_poi_fused__to_copy_addmm_gelu_gelu_backward_5 = async_compile.triton('triton_poi_fused__to_copy_addmm_gelu_gelu_backward_5', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
    size_hints={'x': 4294967296},
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*fp32', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i64', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_addmm_gelu_gelu_backward_5', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 2, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 32212287488}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_addmm_gelu_gelu_backward_5(in_out_ptr0, in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2684354560
    xoffset = tl.program_id(0).to(tl.int64) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:].to(tl.int64)
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = (xindex % 8192)
    x2 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last')
    tmp2 = tl.load(in_ptr1 + (x2), None).to(tl.float32)
    tmp19 = tl.load(in_out_ptr0 + (x2), None).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp3 = tmp1 + tmp2
    tmp4 = tmp3.to(tl.float32)
    tmp5 = tl.full([1], 0.5, tl.float32)
    tmp6 = tmp4 * tmp5
    tmp7 = tmp4 * tmp4
    tmp8 = tmp7 * tmp4
    tmp9 = tl.full([1], 0.044715, tl.float32)
    tmp10 = tmp8 * tmp9
    tmp11 = tmp4 + tmp10
    tmp12 = tl.full([1], 0.7978845608028654, tl.float32)
    tmp13 = tmp11 * tmp12
    tmp14 = libdevice.tanh(tmp13)
    tmp15 = tl.full([1], 1.0, tl.float32)
    tmp16 = tmp14 + tmp15
    tmp17 = tmp6 * tmp16
    tmp18 = tmp17.to(tl.float32)
    tmp20 = tmp19.to(tl.float32)
    tmp21 = tmp16 * tmp5
    tmp22 = tmp14 * tmp14
    tmp23 = tmp15 - tmp22
    tmp24 = tmp6 * tmp23
    tmp25 = tl.full([1], 0.134145, tl.float32)
    tmp26 = tmp7 * tmp25
    tmp27 = tmp26 + tmp15
    tmp28 = tmp27 * tmp12
    tmp29 = tmp24 * tmp28
    tmp30 = tmp21 + tmp29
    tmp31 = tmp20 * tmp30
    tmp32 = tmp31.to(tl.float32)
    tl.store(out_ptr0 + (x2), tmp18, None)
    tl.store(in_out_ptr0 + (x2), tmp32, None)
''', device_str='cuda')
| # kernel path: /tmp/torchinductor/rank0/t6/ct6itkhhsr67p6nullejwydl2p56bqr5cedys6o4gndtdlcev3mu.py | |
| # Topologically Sorted Source Nodes: [sum_5], Original ATen: [aten.sum] | |
| # Source node to ATen node mapping: | |
| # sum_5 => sum_5 | |
| # Graph fragment: | |
| # %convert_element_type_414 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=convert_element_type_414] | |
| # %sum_5 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_414, [0], True), kwargs = {}) | |
| # return %buf29 | |
# NOTE(review): TorchInductor-generated kernel — the source string is cache-hashed;
# do not hand-edit.
# Stage 1 of a two-stage split column sum (bias gradient): reduces a bf16
# [327680, 2048] input over dim 0. xnumel = 327680 = 160 * 2048 output slots
# (x0 = column, x1 = row-chunk index); each program accumulates 2048 rows
# (stride 2048*r0_2 + 4194304*x1) into an fp32 accumulator and writes fp32
# partials to out_ptr0 (consumed by triton_red_fused_sum_7 below).
triton_red_fused_sum_6 = async_compile.triton('triton_red_fused_sum_6', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
    size_hints={'x': 524288, 'r0_': 2048},
    reduction_hint=ReductionHint.OUTER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_6', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 1344798720, 'r0_': 0}}
)
@triton.jit
def triton_red_fused_sum_6(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 327680
    r0_numel = 2048
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = (xindex % 2048)
    x1 = xindex // 2048
    _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    x3 = xindex
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_2 = r0_index
        tmp0 = tl.load(in_ptr0 + (x0 + 2048*r0_2 + 4194304*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
        tmp3 = _tmp2 + tmp1
        _tmp2 = tl.where(r0_mask, tmp3, _tmp2)
    tmp2 = tl.sum(_tmp2, 1)[:, None]
    tl.store(out_ptr0 + (x3), tmp2, None)
''', device_str='cuda')
| # kernel path: /tmp/torchinductor/rank0/p2/cp2yuujssx45mwam4gmghu2bmmldkfzv4gngwd75ivayqdq2igwq.py | |
| # Topologically Sorted Source Nodes: [sum_5], Original ATen: [aten.sum] | |
| # Source node to ATen node mapping: | |
| # sum_5 => sum_5 | |
| # Graph fragment: | |
| # %buf29 : Tensor "f32[1, 2048, 160][327680, 1, 2048]cuda:0" = PlaceHolder[target=buf29] | |
| # %sum_5 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_414, [0], True), kwargs = {}) | |
| # return %sum_5 | |
# NOTE(review): TorchInductor-generated kernel — the source string is cache-hashed;
# do not hand-edit.
# Stage 2 of the split column sum: folds the 160 fp32 partials per column
# (produced by triton_red_fused_sum_6, laid out [160, 2048] with column stride 1)
# into the final 2048-wide result; the fp32 accumulator is downcast on store
# because out_ptr0 is *bf16.
triton_red_fused_sum_7 = async_compile.triton('triton_red_fused_sum_7', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
    size_hints={'x': 2048, 'r0_': 256},
    reduction_hint=ReductionHint.OUTER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_7', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 1318912, 'r0_': 0}}
)
@triton.jit
def triton_red_fused_sum_7(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 2048
    r0_numel = 160
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = xindex
    _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp0 = tl.load(in_ptr0 + (x0 + 2048*r0_1), r0_mask & xmask, eviction_policy='evict_first', other=0.0)
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
        tmp3 = _tmp2 + tmp1
        _tmp2 = tl.where(r0_mask & xmask, tmp3, _tmp2)
    tmp2 = tl.sum(_tmp2, 1)[:, None]
    tl.store(out_ptr0 + (x0), tmp2, xmask)
''', device_str='cuda')
| # kernel path: /tmp/torchinductor/rank0/hh/chhznpxzkrjd6hw2dultp3lkcxpgo3xbcosobcrvsssw26crs3pc.py | |
| # Topologically Sorted Source Nodes: [convert_element_type_424], Original ATen: [aten._to_copy] | |
| # Source node to ATen node mapping: | |
| # convert_element_type_424 => convert_element_type_424 | |
| # Graph fragment: | |
| # %buf39 : Tensor = PlaceHolder[target=buf39] | |
| # %convert_element_type_424 : Tensor "f32[2048, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%wait_tensor_3, torch.float32), kwargs = {}) | |
| # return %convert_element_type_424 | |
# NOTE(review): TorchInductor-generated kernel — the source string is cache-hashed;
# do not hand-edit.
# Elementwise upcast of 16777216 (= 2048 * 8192) bf16 values to fp32.
# Per the generated banner above, the input is a wait_tensor — presumably the
# result of a collective; verify against the call-site buffer if this matters.
triton_poi_fused__to_copy_8 = async_compile.triton('triton_poi_fused__to_copy_8', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
    size_hints={'x': 16777216},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_8', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 134217728}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_8(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 16777216
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')
| # kernel path: /tmp/torchinductor/rank0/y4/cy4awr42ljhsqyf72s5jzm7g3pbdgrndgk7ry4ij6viltpmia7db.py | |
| # Topologically Sorted Source Nodes: [sum_6], Original ATen: [aten.sum] | |
| # Source node to ATen node mapping: | |
| # sum_6 => sum_6 | |
| # Graph fragment: | |
| # %convert_element_type_427 : Tensor "bf16[327680, 8192][8192, 1]cuda:0" = PlaceHolder[target=convert_element_type_427] | |
| # %sum_6 : Tensor "bf16[1, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_427, [0], True), kwargs = {}) | |
| # return %buf44 | |
# NOTE(review): TorchInductor-generated kernel — the source string is cache-hashed;
# do not hand-edit.
# Stage 1 of a two-stage split column sum over a bf16 [327680, 8192] input
# (dim 0). xnumel = 1310720 = 160 * 8192 partial slots (x0 = column,
# x1 = row-chunk of 2048 rows); fp32 partials go to out_ptr0, finished by
# triton_red_fused_sum_10. Uses int64 offsets (i64 signature, .to(tl.int64))
# because the input spans 327680 * 8192 elements, beyond int32 addressing.
triton_red_fused_sum_9 = async_compile.triton('triton_red_fused_sum_9', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
    size_hints={'x': 2097152, 'r0_': 2048},
    reduction_hint=ReductionHint.OUTER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i64', 'r0_numel': 'i64', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_9', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 5379194880, 'r0_': 0}}
)
@triton.jit
def triton_red_fused_sum_9(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 1310720
    r0_numel = 2048
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0).to(tl.int64) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None].to(tl.int64)
    xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
    r0_base = tl.arange(0, R0_BLOCK)[None, :].to(tl.int64)
    rbase = r0_base
    x0 = (xindex % 8192)
    x1 = xindex // 8192
    _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    x3 = xindex
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_2 = r0_index
        tmp0 = tl.load(in_ptr0 + (x0 + 8192*r0_2 + 16777216*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
        tmp3 = _tmp2 + tmp1
        _tmp2 = tl.where(r0_mask, tmp3, _tmp2)
    tmp2 = tl.sum(_tmp2, 1)[:, None]
    tl.store(out_ptr0 + (x3), tmp2, None)
''', device_str='cuda')
| # kernel path: /tmp/torchinductor/rank0/qk/cqkqenwa2t4jzfp5qb2mrzo3mnfgjl2ns4mywtxjzal2ymd6ew3v.py | |
| # Topologically Sorted Source Nodes: [sum_6], Original ATen: [aten.sum] | |
| # Source node to ATen node mapping: | |
| # sum_6 => sum_6 | |
| # Graph fragment: | |
| # %buf44 : Tensor "f32[1, 8192, 160][1310720, 1, 8192]cuda:0" = PlaceHolder[target=buf44] | |
| # %sum_6 : Tensor "bf16[1, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_427, [0], True), kwargs = {}) | |
| # return %sum_6 | |
# NOTE(review): TorchInductor-generated kernel — the source string is cache-hashed;
# do not hand-edit.
# Stage 2 of the split column sum: folds the 160 fp32 partials per column
# (from triton_red_fused_sum_9, laid out with column stride 1, row stride 8192)
# into the final 8192-wide result; fp32 accumulator downcast on store to the
# *bf16 output. No x bound check needed (xmask is constant True): 8192 output
# slots are covered exactly by the launch grid.
triton_red_fused_sum_10 = async_compile.triton('triton_red_fused_sum_10', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
    size_hints={'x': 8192, 'r0_': 256},
    reduction_hint=ReductionHint.OUTER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_10', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 5275648, 'r0_': 0}}
)
@triton.jit
def triton_red_fused_sum_10(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 8192
    r0_numel = 160
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = xindex
    _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp0 = tl.load(in_ptr0 + (x0 + 8192*r0_1), r0_mask, eviction_policy='evict_first', other=0.0)
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
        tmp3 = _tmp2 + tmp1
        _tmp2 = tl.where(r0_mask, tmp3, _tmp2)
    tmp2 = tl.sum(_tmp2, 1)[:, None]
    tl.store(out_ptr0 + (x0), tmp2, None)
''', device_str='cuda')
| # kernel path: /tmp/torchinductor/rank0/wt/cwt4nuarmo4faznyu2vyoyay5466bt27gffcjcknubmh3gfamc73.py | |
| # Topologically Sorted Source Nodes: [view_280, convert_element_type_432], Original ATen: [aten.view, aten._to_copy] | |
| # Source node to ATen node mapping: | |
| # convert_element_type_432 => convert_element_type_432 | |
| # view_280 => view_280 | |
| # Graph fragment: | |
| # %buf49 : Tensor = PlaceHolder[target=buf49] | |
| # %view_280 : Tensor "bf16[8192][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%sum_6, [8192]), kwargs = {}) | |
| # %convert_element_type_432 : Tensor "f32[8192][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%wait_tensor_4, torch.float32), kwargs = {}) | |
| # return %convert_element_type_432 | |
# NOTE(review): TorchInductor-generated kernel — the source string is cache-hashed;
# do not hand-edit.
# Elementwise upcast of an 8192-element bf16 vector to fp32 (the view op in the
# fused name is a no-op reshape). Per the generated banner above, the input is a
# wait_tensor — presumably a collective result; confirm against the call site.
triton_poi_fused__to_copy_view_11 = async_compile.triton('triton_poi_fused__to_copy_view_11', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
    size_hints={'x': 8192},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_view_11', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 65536}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_view_11(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 8192
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')
| # kernel path: /tmp/torchinductor/rank0/gu/cguot2zrmpx5qc3kxoddiexeuvmukqdb3jgmu5cfce2w2eczd7c7.py | |
| # Topologically Sorted Source Nodes: [redistribute, layer_norm, convert_element_type_434, convert_element_type_436, mul_172, mul_173, sum_7, mul_174, sum_8, mul_175, sub_33, sub_34, div_4, mul_176, mul_177, sum_9, sum_10, convert_element_type_438, add_114], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.native_layer_norm_backward, aten.add] | |
| # Source node to ATen node mapping: | |
| # add_114 => add_114 | |
| # convert_element_type_434 => convert_element_type_434 | |
| # convert_element_type_436 => convert_element_type_436 | |
| # convert_element_type_438 => convert_element_type_438 | |
| # div_4 => div_4 | |
| # layer_norm => add_103, convert_element_type_392, mul_142, rsqrt_25, sub_26, var_mean_25 | |
| # mul_172 => mul_172 | |
| # mul_173 => mul_173 | |
| # mul_174 => mul_174 | |
| # mul_175 => mul_175 | |
| # mul_176 => mul_176 | |
| # mul_177 => mul_177 | |
| # redistribute => convert_element_type_390 | |
| # sub_33 => sub_33 | |
| # sub_34 => sub_34 | |
| # sum_10 => sum_10 | |
| # sum_7 => sum_7 | |
| # sum_8 => sum_8 | |
| # sum_9 => sum_9 | |
| # Graph fragment: | |
| # %mm_51 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_51] | |
| # %primals_161 : Tensor "f32[2048][1]cuda:0" = PlaceHolder[target=primals_161] | |
| # %add_102 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=add_102] | |
| # %getitem_87 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=getitem_87] | |
| # %buf22 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=buf22] | |
| # %convert_element_type_414 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=convert_element_type_414] | |
| # %sum_7 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=sum_7] | |
| # %sum_8 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=sum_8] | |
| # %convert_element_type_390 : Tensor "bf16[2048][1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_161, torch.bfloat16), kwargs = {}) | |
| # %convert_element_type_392 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_102, torch.float32), kwargs = {}) | |
| # %var_mean_25 : [num_users=2] = call_function[target=torch.ops.aten.var_mean.correction](args = (%convert_element_type_392, [1]), kwargs = {correction: 0, keepdim: True}) | |
| # %add_103 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%getitem_86, 1e-05), kwargs = {}) | |
| # %rsqrt_25 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_103,), kwargs = {}) | |
| # %sub_26 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%convert_element_type_392, %getitem_87), kwargs = {}) | |
| # %mul_142 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_26, %rsqrt_25), kwargs = {}) | |
| # %convert_element_type_434 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mm_51, torch.float32), kwargs = {}) | |
| # %convert_element_type_436 : Tensor "f32[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_390, torch.float32), kwargs = {}) | |
| # %mul_172 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_434, %convert_element_type_436), kwargs = {}) | |
| # %mul_173 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_172, 2048), kwargs = {}) | |
| # %sum_7 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_172, [1], True), kwargs = {}) | |
| # %mul_174 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_172, %mul_142), kwargs = {}) | |
| # %sum_8 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_174, [1], True), kwargs = {}) | |
| # %mul_175 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_142, %sum_8), kwargs = {}) | |
| # %sub_33 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%mul_173, %sum_7), kwargs = {}) | |
| # %sub_34 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sub_33, %mul_175), kwargs = {}) | |
| # %div_4 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%rsqrt_25, 2048), kwargs = {}) | |
| # %mul_176 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_4, %sub_34), kwargs = {}) | |
| # %mul_177 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_434, %mul_142), kwargs = {}) | |
| # %sum_9 : Tensor "f32[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_177, [0]), kwargs = {}) | |
| # %sum_10 : Tensor "f32[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_434, [0]), kwargs = {}) | |
| # %convert_element_type_438 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_176, torch.bfloat16), kwargs = {}) | |
| # %add_114 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%convert_element_type_414, %convert_element_type_438), kwargs = {}) | |
| # return %sum_7,%sum_8,%add_114 | |
# Fused LayerNorm-backward kernel (Inductor "mix-order reduction"): each program
# owns a slab of RSPLIT_SIZE rows x 2048 columns. Per row it computes the
# input-gradient dX and adds it in place into in_out_ptr0; across rows it keeps
# two [2048]-wide fp32 accumulators (the column-wise partial sums feeding the
# weight/bias gradients, sum_9 / sum_10 in the graph fragment above) which are
# flushed to a per-program workspace slot at the end.
triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_12 = async_compile.triton('triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_12', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
    size_hints={'x': 524288, 'r0_': 2048},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*fp32', 'ws_ptr': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'RSPLIT_SIZE': 'constexpr', 'NUM_STAGES': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'MixOrderReductionGrid', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_12', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 6, 'num_store': -1, 'num_reduction': 2, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'RSPLIT_SIZE': 128, 'has_loadstore_with_contiguous_rdim': True}
)
@triton.jit
def triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_12(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, ws_ptr, xnumel, r0_numel, XBLOCK : tl.constexpr, RSPLIT_SIZE : tl.constexpr, NUM_STAGES : tl.constexpr):
    # Problem extents are baked in at compile time: 327680 rows x 2048 features.
    xnumel = 327680
    r0_numel = 2048
    R0_BLOCK: tl.constexpr = 2048
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    # Each program starts at its own slab of RSPLIT_SIZE rows; the full 2048-wide
    # feature (reduction) axis is held in registers (persistent reduction).
    xoffset = tl.program_id(0) * RSPLIT_SIZE
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)[None, :]
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    # Column-wise fp32 accumulators carried across the row loop:
    # accum0 -> sum over rows of (grad_out * normalized)  (weight grad partial)
    # accum1 -> sum over rows of grad_out                 (bias grad partial)
    accum0 = tl.full([R0_BLOCK], 0, tl.float32)[None, :]
    accum1 = tl.full([R0_BLOCK], 0, tl.float32)[None, :]
    # Last program may own fewer than RSPLIT_SIZE rows.
    split_size = min(RSPLIT_SIZE, xnumel - xoffset)
    for _ in tl.range(0, split_size, XBLOCK, num_stages=NUM_STAGES):
        x0 = xindex
        xindex += XBLOCK
        tmp0 = tl.load(in_ptr0 + (r0_1 + 2048*x0), None).to(tl.float32)  # upstream grad (mm_51), bf16 -> fp32
        tmp2 = tl.load(in_ptr1 + (r0_1), None, eviction_policy='evict_last')  # fp32 LayerNorm weight (per-column, reused by every row)
        tmp9 = tl.load(in_ptr2 + (r0_1 + 2048*x0), None).to(tl.float32)  # forward input of the LayerNorm (add_102)
        tmp11 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last')  # per-row mean
        tmp13 = tl.load(in_ptr4 + (x0), None, eviction_policy='evict_last')  # per-row N*variance (divided by 2048 below) -- TODO confirm against producer
        tmp24 = tl.load(in_out_ptr0 + (r0_1 + 2048*x0), None).to(tl.float32)  # existing grad to accumulate into (convert_element_type_414)
        tmp1 = tmp0.to(tl.float32)
        tmp3 = tmp2.to(tl.float32)
        tmp4 = tmp3.to(tl.float32)
        # mul_172: grad scaled by the (bf16-rounded) LayerNorm weight.
        tmp5 = tmp1 * tmp4
        tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK])
        tmp8 = tl.sum(tmp6, 1)[:, None].to(tl.float32)  # sum_7: row-sum of weighted grad
        tmp10 = tmp9.to(tl.float32)
        tmp12 = tmp10 - tmp11  # centered input (sub_26)
        tmp14 = tl.full([1, 1], 2048.0, tl.float32)
        tmp15 = (tmp13 / tmp14)
        tmp16 = tl.full([1, 1], 1e-05, tl.float32)
        tmp17 = tmp15 + tmp16
        tmp18 = libdevice.rsqrt(tmp17)  # rsqrt_25 = 1/sqrt(var + eps)
        tmp19 = tmp12 * tmp18  # mul_142: normalized input x_hat
        tmp20 = tmp5 * tmp19
        tmp21 = tl.broadcast_to(tmp20, [XBLOCK, R0_BLOCK])
        tmp23 = tl.sum(tmp21, 1)[:, None].to(tl.float32)  # sum_8: row-sum of weighted grad * x_hat
        # 0.00048828125 == 1/2048; together with rsqrt this is div_4 = rsqrt/N.
        tmp25 = tl.full([1, 1], 0.00048828125, tl.float32)
        tmp26 = tmp18 * tmp25
        # Standard LayerNorm dX: (rsqrt/N) * (N*g*w - sum(g*w) - x_hat*sum(g*w*x_hat)).
        tmp27 = tmp5 * tmp14
        tmp28 = tmp27 - tmp8
        tmp29 = tmp19 * tmp23
        tmp30 = tmp28 - tmp29
        tmp31 = tmp26 * tmp30
        tmp32 = tmp31.to(tl.float32)
        tmp33 = tmp24 + tmp32  # add_114: accumulate dX into the in/out buffer
        tmp34 = tmp1 * tmp19
        tl.store(in_out_ptr0 + (r0_1 + 2048*x0), tmp33, None)
        # Fold this XBLOCK of rows into the running column-wise partials.
        tmp35 = tl.sum(tmp34, 0)
        tmp36 = accum0 + tmp35
        accum0 = tmp36
        tmp37 = tl.sum(tmp1, 0)
        tmp38 = accum1 + tmp37
        accum1 = tmp38
    # Spill the two per-program partial sums to the workspace; a later pass
    # reduces across programs (workspace layout: [2 accumulators][num_programs][2048]).
    tl.store(ws_ptr + (tl.program_id(0) + 0 * tl.num_programs(0)) * r0_numel + r0_index, accum0, r0_mask)
    tl.store(ws_ptr + (tl.program_id(0) + 1 * tl.num_programs(0)) * r0_numel + r0_index, accum1, r0_mask)
''', device_str='cuda')
| # kernel path: /tmp/torchinductor/rank0/wd/cwd4pi47eukzrwowauw3dsgtyi5t2umcku6ki7ijnndn5vrptyak.py | |
| # Topologically Sorted Source Nodes: [redistribute_2, linear], Original ATen: [aten._to_copy, aten.t] | |
| # Source node to ATen node mapping: | |
| # linear => permute_111 | |
| # redistribute_2 => convert_element_type_378 | |
| # Graph fragment: | |
| # %primals_157 : Tensor "f32[2048, 2048][2048, 1]cuda:0" = PlaceHolder[target=primals_157] | |
| # %convert_element_type_378 : Tensor "bf16[2048, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_157, torch.bfloat16), kwargs = {}) | |
| # %permute_111 : Tensor "bf16[2048, 2048][1, 2048]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.permute.default](args = (%convert_element_type_378, [1, 0]), kwargs = {}) | |
| # return %permute_111 | |
# Pointwise fp32 -> bf16 cast of the 2048x2048 weight (primals_157). The
# transpose (permute_111) in the mapped graph is realized purely as a stride
# change on the output tensor, so the kernel itself is a flat 1:1 element copy.
triton_poi_fused__to_copy_t_13 = async_compile.triton('triton_poi_fused__to_copy_t_13', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
    size_hints={'x': 4194304},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_t_13', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 33554432}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_t_13(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    # Total element count baked in at compile time: 2048 * 2048.
    xnumel = 4194304
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    # Extent divides the launch grid, so no bounds mask is applied on loads/stores.
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None)
    tmp1 = tmp0.to(tl.float32)
    # The down-cast happens implicitly at the store: out_ptr0 is *bf16 per the
    # signature above.
    tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')
| # kernel path: /tmp/torchinductor/rank0/we/cwe4ednvdc4atd6stlt2h5bnqtx47uoxynggv7myu77yu54lorzj.py | |
| # Topologically Sorted Source Nodes: [redistribute_3, linear_1], Original ATen: [aten._to_copy, aten.t] | |
| # Source node to ATen node mapping: | |
| # linear_1 => permute_113 | |
| # redistribute_3 => convert_element_type_381 | |
| # Graph fragment: | |
| # %primals_158 : Tensor "f32[512, 2048][2048, 1]cuda:0" = PlaceHolder[target=primals_158] | |
| # %convert_element_type_381 : Tensor "bf16[512, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_158, torch.bfloat16), kwargs = {}) | |
| # %permute_113 : Tensor "bf16[2048, 512][1, 2048]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.permute.default](args = (%convert_element_type_381, [1, 0]), kwargs = {}) | |
| # return %permute_113 | |
# Pointwise fp32 -> bf16 cast of the 512x2048 weight (primals_158). Mirrors
# triton_poi_fused__to_copy_t_13 with a smaller extent; the transpose
# (permute_113) is again expressed only through output strides.
triton_poi_fused__to_copy_t_14 = async_compile.triton('triton_poi_fused__to_copy_t_14', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
    size_hints={'x': 1048576},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_t_14', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 8388608}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_t_14(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    # Total element count baked in at compile time: 512 * 2048.
    xnumel = 1048576
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    # Extent divides the launch grid, so no bounds mask is applied on loads/stores.
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None)
    tmp1 = tmp0.to(tl.float32)
    # Down-cast to bf16 happens implicitly at the store (out_ptr0 is *bf16).
    tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')
| # kernel path: /tmp/torchinductor/rank0/zs/czstcwxp4irb2bhqocdicv4g7b67ruhh3iwzcbk4zgkvjkekysvk.py | |
| # Topologically Sorted Source Nodes: [q, k, v, flex_attention], Original ATen: [aten.view, aten.permute, flex_attention] | |
| # Source node to ATen node mapping: | |
| # flex_attention => flex_attention_11 | |
| # k => permute_114, view_261, view_262 | |
| # q => permute_112, view_258, view_259 | |
| # v => permute_116, view_264, view_265 | |
| # Graph fragment: | |
| # %mm_45 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_45] | |
| # %mm_46 : Tensor "bf16[327680, 512][512, 1]cuda:0" = PlaceHolder[target=mm_46] | |
| # %mm_47 : Tensor "bf16[327680, 512][512, 1]cuda:0" = PlaceHolder[target=mm_47] | |
| # %getitem_84 : Tensor "f32[1, 16, 327680][5242880, 327680, 1]cuda:0" = PlaceHolder[target=getitem_84] | |
| # %buf86 : Tensor "f32[1, 16, 327680][5242880, 327680, 1]cuda:0" = PlaceHolder[target=buf86] | |
| # %primals_18 : Tensor "i32[1, 1, 2560][2560, 2560, 1]cuda:0" = PlaceHolder[target=primals_18] | |
| # %primals_17 : Tensor "i32[1, 1, 2560, s91][2560*s91, 2560*s91, s91, 1]cuda:0" = PlaceHolder[target=primals_17] | |
| # %primals_19 : Tensor "i32[1, 1, 2560][2560, 2560, 1]cuda:0" = PlaceHolder[target=primals_19] | |
| # %primals_21 : Tensor "i32[1, 1, 2560, s6][2560*s6, 2560*s6, s6, 1]cuda:0" = PlaceHolder[target=primals_21] | |
| # %primals_14 : Tensor "i64[327680][1]cuda:0" = PlaceHolder[target=primals_14] | |
| # %primals_15 : Tensor "i64[327680][1]cuda:0" = PlaceHolder[target=primals_15] | |
| # %view_258 : Tensor "bf16[327680, 16, 128][2048, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_45, [327680, 16, 128]), kwargs = {}) | |
| # %permute_112 : Tensor "bf16[16, 327680, 128][128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_258, [1, 0, 2]), kwargs = {}) | |
| # %view_259 : Tensor "bf16[1, 16, 327680, 128][2048, 128, 2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_112, [1, 16, 327680, 128]), kwargs = {}) | |
| # %view_261 : Tensor "bf16[327680, 4, 128][512, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_46, [327680, 4, 128]), kwargs = {}) | |
| # %permute_114 : Tensor "bf16[4, 327680, 128][128, 512, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_261, [1, 0, 2]), kwargs = {}) | |
| # %view_262 : Tensor "bf16[1, 4, 327680, 128][512, 128, 512, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_114, [1, 4, 327680, 128]), kwargs = {}) | |
| # %view_264 : Tensor "bf16[327680, 4, 128][512, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_47, [327680, 4, 128]), kwargs = {}) | |
| # %permute_116 : Tensor "bf16[4, 327680, 128][128, 512, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_264, [1, 0, 2]), kwargs = {}) | |
| # %view_265 : Tensor "bf16[1, 4, 327680, 128][512, 128, 512, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_116, [1, 4, 327680, 128]), kwargs = {}) | |
| # %flex_attention_11 : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%view_259, %view_262, %view_265, %sdpa_score11, (327680, 327680, %primals_18, %primals_17, %primals_19, %primals_21, %primals_22, %primals_24, %primals_25, %primals_27, 128, 128, %sdpa_mask11), 0.08838834764831843, {BACKEND: AUTO, PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: True}, (), (%primals_14, %primals_15)), kwargs = {}) | |
| # return %getitem_83 | |
| triton_tem_fused_flex_attention_permute_view_15 = async_compile.triton('triton_tem_fused_flex_attention_permute_view_15', ''' | |
| import triton | |
| import triton.language as tl | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| @triton_heuristics.template( | |
| num_stages=3, | |
| num_warps=8, | |
| triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'in_ptr10': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]]}]}, | |
| inductor_meta={'kernel_name': 'triton_tem_fused_flex_attention_permute_view_15', 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': True, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, | |
| ) | |
| @triton.jit | |
| def triton_tem_fused_flex_attention_permute_view_15(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, in_ptr10, out_ptr0, ks0): | |
| PRESCALE_QK : tl.constexpr = False | |
| ROWS_GUARANTEED_SAFE : tl.constexpr = False | |
| BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False | |
| WRITE_DQ : tl.constexpr = True | |
| OUTPUT_LOGSUMEXP : tl.constexpr = True | |
| OUTPUT_MAX : tl.constexpr = True | |
| FLOAT32_PRECISION : tl.constexpr = 'tf32' | |
| IS_DIVISIBLE : tl.constexpr = True | |
| SM_SCALE : tl.constexpr = 0.08838834764831843 | |
| GQA_SHARED_HEADS : tl.constexpr = 4 | |
| HAS_FULL_BLOCKS : tl.constexpr = True | |
| QK_HEAD_DIM : tl.constexpr = 128 | |
| QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 | |
| V_HEAD_DIM : tl.constexpr = 128 | |
| V_HEAD_DIM_ROUNDED : tl.constexpr = 128 | |
| SAFE_HEAD_DIM : tl.constexpr = True | |
| USE_TMA : tl.constexpr = False | |
| BLOCK_M : tl.constexpr = 128 | |
| BLOCK_N : tl.constexpr = 64 | |
| SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 | |
| SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 | |
| INDEX_DTYPE : tl.constexpr = tl.int32 | |
| Q = arg_Q | |
| K = arg_K | |
| V = arg_V | |
| LSE = arg_LSE | |
| MAX = arg_MAX | |
| KV_NUM_BLKS = arg_KV_NUM_BLKS | |
| KV_IDX = arg_KV_IDX | |
| FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS | |
| FULL_KV_IDX = arg_FULL_KV_IDX | |
| # Sub notation for this kernel: | |
| # | |
| # Q: Query, K: Key, V: Value | |
| # M: Number of queries, N: Number of keys/values, D: Model dimension | |
| # QK_HEAD_DIM: The dimension of the query and key embeddings | |
| # V_HEAD_DIM: The dimension of the value embeddings | |
| # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head | |
| # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. | |
| # | |
| # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. | |
| # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. | |
| # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. | |
| # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. | |
| # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. | |
| # | |
| # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad | |
| # | |
| # (Modifiable) Performance tuning options | |
| # BLOCK_M: The thread block size across the seqlen dim of Q. | |
| # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. | |
| # The below are kernel options that can be applied for certain score_mods, | |
| # or involve a numerics vs. perf tradeoff | |
| # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has | |
| # about 20% more numerical error, but slightly faster. | |
| # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row | |
| # is not masked out? If so, we can skip an extra safety check | |
| # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are | |
| # contiguous? If so, we don't need to do an indirect jump for every block | |
| tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) | |
| tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) | |
| # Define strides of inputs | |
| stride_qz, stride_qh, stride_qm, stride_qk = 2048, 128, 2048, 1 | |
| stride_kz, stride_kh, stride_kn, stride_kk = 512, 128, 512, 1 | |
| stride_vz, stride_vh, stride_vn, stride_vk = 512, 128, 512, 1 | |
| ZQ = 1 | |
| HQ = 16 | |
| Q_LEN = 327680 | |
| ZKV = 1 | |
| KV_LEN = 327680 | |
| MATMUL_PRECISION = Q.dtype.element_ty | |
| q_start = tl.program_id(0).to(INDEX_DTYPE) | |
| off_zq = tl.program_id(1).to(INDEX_DTYPE) | |
| off_hq = tl.program_id(2).to(INDEX_DTYPE) | |
| # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. | |
| # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. | |
| off_zkv = off_zq % ZKV | |
| off_hkv = off_hq // GQA_SHARED_HEADS | |
| off_g = off_hq % GQA_SHARED_HEADS | |
| q_offset = off_zq * stride_qz + off_hq * stride_qh | |
| k_offset = off_zkv * stride_kz + off_hkv * stride_kh | |
| v_offset = off_zkv * stride_vz + off_hkv * stride_vh | |
| Q = Q + q_offset | |
| K = K + k_offset | |
| V = V + v_offset | |
| # Setting up the TMA descriptors for Q, K, V | |
| desc_q = None | |
| desc_k = None | |
| desc_v = None | |
| SPARSE_Z = 1 | |
| SPARSE_HQ = 1 | |
| sparse_idx_z = off_zq % SPARSE_Z | |
| sparse_idx_hq = off_hq % SPARSE_HQ | |
| SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) | |
| SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) | |
| stride_kv_num_blks_h = 2560 | |
| stride_kv_idx_h = 2560*ks0 | |
| stride_kv_idx_m = ks0 | |
| # initialize pointer to m and l | |
| m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") | |
| l_i = tl.zeros([BLOCK_M], dtype=tl.float32) | |
| acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) | |
| offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) | |
| # KV_IDX and KV_NUM_BLKS are always contiguous. | |
| sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq | |
| sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE | |
| sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 | |
| offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) | |
| offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) | |
| q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) | |
| # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
| # We don't know anything "special" about these blocks, so we need to apply | |
| # both score_mod and mask_mod to it | |
| kv_indices = KV_IDX + sparse_kv_idx_offset | |
| kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading | |
| kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) | |
| block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) | |
| # K and V pointers will be passed directly to forward_inner | |
| offs_n = kv_start + tl.arange(0, BLOCK_N) | |
| acc, l_i, m_i = forward_inner( | |
| arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, in_ptr10, out_ptr0, ks0, | |
| q, K, V, | |
| desc_k, desc_v, Q_LEN, KV_LEN, | |
| acc, l_i, m_i, | |
| off_zq, off_hq, offs_m[:, None], offs_n[None, :], | |
| kv_start, | |
| kv_indices, kv_num_blocks, | |
| 0, block_n_end, | |
| MATMUL_PRECISION, | |
| stride_kk, stride_kn, stride_vn, stride_vk, | |
| IS_FULL_BLOCKS=False, | |
| ) | |
| # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
| # We know these blocks are guaranteed to be "full", so we don't need to | |
| # apply mask_mod to them - only score_mod | |
| if HAS_FULL_BLOCKS: | |
| # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. | |
| kv_indices = FULL_KV_IDX + sparse_kv_idx_offset | |
| kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading | |
| kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) | |
| block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) | |
| # K and V pointers will be passed directly to forward_inner | |
| offs_n = kv_start + tl.arange(0, BLOCK_N) | |
| acc, l_i, m_i = forward_inner( | |
| arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, in_ptr10, out_ptr0, ks0, | |
| q, K, V, | |
| desc_k, desc_v, Q_LEN, KV_LEN, | |
| acc, l_i, m_i, | |
| off_zq, off_hq, offs_m[:, None], offs_n[None, :], | |
| kv_start, | |
| kv_indices, kv_num_blocks, | |
| 0, block_n_end, | |
| MATMUL_PRECISION, | |
| stride_kk, stride_kn, stride_vn, stride_vk, | |
| IS_FULL_BLOCKS=True, | |
| ) | |
| # [Note] Handle fully masked out rows: | |
| # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. | |
| # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step | |
| l_i = tl.where(l_i == 0.0, 1, l_i) | |
| acc = acc / l_i[:, None] | |
| idx_zq = tl.program_id(1).to(INDEX_DTYPE) | |
| idx_hq = tl.program_id(2).to(INDEX_DTYPE) | |
| idx_m = offs_m[:, None].to(INDEX_DTYPE) | |
| idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) | |
| mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) | |
| tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) | |
| xindex = idx_d + 128*idx_m + 41943040*idx_hq + 671088640*idx_zq | |
| tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 2048*idx_m, [BLOCK_M, V_HEAD_DIM_ROUNDED])), acc, mask) | |
| if OUTPUT_LOGSUMEXP: | |
| off_hz = off_zq * HQ + off_hq | |
| l_ptrs = LSE + off_hz * Q_LEN + offs_m | |
| lse = m_i + tl.math.log2(l_i) | |
| if IS_DIVISIBLE: | |
| tl.store(l_ptrs, lse) | |
| else: | |
| tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) | |
| if OUTPUT_MAX: | |
| off_hz = off_zq * HQ + off_hq | |
| max_ptrs = MAX + off_hz * Q_LEN + offs_m | |
| if IS_DIVISIBLE: | |
| tl.store(max_ptrs, m_i) | |
| else: | |
| tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) | |
| # Utility triton funcs | |
@triton.jit
def get_offset_for_next_block(
    loop_iter, col_indices, total_blocks,
    SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr
):
    # Compute how many elements to advance the K/V pointers to reach the next
    # BLOCK-sized tile, following the block-sparse column-index list
    # `col_indices`. SPARSE_BLOCK_MULTIPLE = SPARSE_BLOCK // BLOCK, i.e. how
    # many BLOCK tiles fit in one sparse block.
    if BLOCKS_ARE_CONTIGUOUS:
        # Sparse blocks are laid out back to back, so the next tile is always
        # exactly BLOCK elements further along.
        return BLOCK
    cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
    cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
    # Mask the lookahead load so we never read past the end of the index list.
    next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
    # A jump to the next sparse block is only needed on the last BLOCK tile of
    # the current sparse block.
    needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    # Distance from the current position to the start of the next sparse
    # block, net of the (SPARSE_BLOCK_MULTIPLE - 1) BLOCK steps already taken.
    jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
    # Branchless select via 0/1 arithmetic: jump on sparse-block boundaries,
    # otherwise step BLOCK elements within the current sparse block.
    offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
    return offset
@triton.jit
def get_bounded_indices(indices, max_len=None):
    # Keep indices inside [0, max_len) via modulo wrap-around; when no bound
    # is given, pass them through untouched. The None check is resolved at
    # trace/compile time, so only one of the two paths is ever generated.
    if max_len is None:
        return indices
    return indices % max_len
@triton.jit
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
    # Load a 2D tile through a Triton block pointer, enabling boundary checks
    # only on the axes that can actually overrun the tensor:
    #   dim 0 needs a check when the sequence length is not an exact multiple
    #         of the block size (IS_DIVISIBLE == False);
    #   dim 1 needs a check when the head dim is not already padded to its
    #         rounded size (SAFE_HEAD_DIM == False).
    # Out-of-bounds lanes are zero-filled. Both flags are constexpr, so only
    # one tl.load variant survives compilation.
    if IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr)
    elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
    elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
    else:
        return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
@triton.jit
def load_checked_2d(
    ptr,
    offs_m,
    offs_n,
    stride_m,
    stride_n,
    IS_DIVISIBLE_M: tl.constexpr,
    IS_DIVISIBLE_N: tl.constexpr,
    M_LEN: tl.constexpr,
    N_LEN: tl.constexpr,
):
    # Load an [M, N] tile addressed by the outer product of offs_m and offs_n,
    # masking whichever axes may run past their extent; masked-out elements
    # read as 0.0. Both divisibility flags are constexpr, so exactly one of
    # the four tl.load variants below is compiled.
    # Calculate final pointer if strides are provided
    if stride_m is not None and stride_n is not None:
        ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
    # Handle all masking cases
    if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
    elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
    elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
    else:  # Both divisible
        return tl.load(ptr)
| # Common Imports | |
@triton.jit
def forward_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, in_ptr10, out_ptr0, ks0,
    q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
    # accumulated values
    acc, l_i, m_i,
    # Offsets
    off_z, off_h, offs_m, offs_n,
    # Offsets needed for TMA loads
    kv_start,
    kv_offset,
    MATMUL_PRECISION, RCP_LN2,
    # Strides for K and V
    stride_kk, stride_kn, stride_vn, stride_vk,
    IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
):
    # One online-softmax step of the flex-attention forward pass: consume a
    # single [BLOCK_M, BLOCK_N] K/V tile and fold it into the running output
    # accumulator `acc`, row sum `l_i` and row max `m_i`. All softmax math is
    # done in base-2 exponent space (exp2 / RCP_LN2) for speed.
    # When IS_FULL_BLOCKS is True the tile is known to be fully unmasked, so
    # the inlined mask_mod computation is skipped entirely.
    # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = True
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = True
    SM_SCALE : tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    USE_TMA : tl.constexpr = False
    BLOCK_M : tl.constexpr = 128
    BLOCK_N : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32
    # -- load k --
    # NB reversed order to since K is transposed
    kv_base_offset = kv_start + kv_offset
    # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
    k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
    k = tl.trans(k)
    k = k.to(q.dtype)
    # -- compute qk ---
    qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
    if not PRESCALE_QK:
        qk *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
    # which is larger than the actual number of elements. To avoid access memory out of bound,
    # we need to mask out the elements that are out of Q_LEN & KV_LEN.
    m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
    n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
    # score_mod is the identity here: scores pass through unmodified.
    tmp0 = (qk)
    post_mod_scores = tmp0
    if CHECK_BLOCK_BOUNDARY:
        # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
        post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
    if not IS_FULL_BLOCKS:
        # Inlined mask_mod graph. Mechanically, position pair (m, n) may
        # attend iff:
        #   (m >= n  OR  (in_ptr9[m] > 0 AND in_ptr9[n] > 0 AND in_ptr9[m] == in_ptr9[n]))
        #   AND in_ptr10[m] == in_ptr10[n]
        #   AND in_ptr9[m] != -1 AND in_ptr9[n] != -1
        # i.e. causal, relaxed within matching positive in_ptr9 groups,
        # restricted to matching in_ptr10 values, with -1 acting as a
        # fully-masked sentinel. (in_ptr9/in_ptr10 are presumably per-token
        # id tensors — semantics come from the captured mask_mod closure;
        # verify against the model code.)
        tmp1 = (m)
        tmp2 = (n)
        tmp3 = tmp1 >= tmp2
        tmp4 = tl.load(in_ptr9 + tmp1)
        tmp5 = tl.full([1], 0, tl.int64)
        tmp6 = tmp4 > tmp5
        tmp7 = tl.load(in_ptr9 + tmp2)
        tmp8 = tmp7 > tmp5
        tmp9 = tmp6 & tmp8
        tmp10 = tmp4 == tmp7
        tmp11 = tmp9 & tmp10
        tmp12 = tmp3 | tmp11
        tmp13 = tl.load(in_ptr10 + tmp1)
        tmp14 = tl.load(in_ptr10 + tmp2)
        tmp15 = tmp13 == tmp14
        tmp16 = tmp12 & tmp15
        tmp17 = tl.full([1], -1, tl.int64)
        tmp18 = tmp4 == tmp17
        tmp19 = tmp7 == tmp17
        tmp20 = tmp18 | tmp19
        tmp21 = tmp20 == 0
        tmp22 = tmp16 & tmp21
        mask_mod_output = tmp22
        if CHECK_BLOCK_BOUNDARY:
            mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
        # apply mask for partially unmasked blocks
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    if not PRESCALE_QK:
        # Convert scores from natural-log space to log2 space so exp2 can be
        # used below.
        post_mod_scores *= RCP_LN2
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # -- compute scaling constant ---
    m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
    if not ROWS_GUARANTEED_SAFE:
        # Rows where every score is -inf would otherwise produce NaN via
        # exp2(-inf - (-inf)); substitute 0 for the max on those rows.
        masked_out_rows = (m_ij == float("-inf"))
        m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
    else:
        m_ij_masked = m_ij
    # alpha rescales the previous accumulator/sum to the new running max.
    alpha = tl.math.exp2(m_i - m_ij_masked)
    p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
    # NB: l_i update is pulled up here since it's a bit faster
    # NB: For headdim=256, it's faster to move it back down to after m_i =
    # m_ij
    l_i = l_i * alpha + tl.sum(p, 1)
    # # -- scale and update acc --
    acc = acc * alpha[:, None]
    # Calculate offsets for V loading - reuse kv_base_offset from K loading
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
    v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
    acc = tl.dot(p.to(MATMUL_PRECISION), v.to(q.dtype), acc, input_precision=FLOAT32_PRECISION)
    # -- update m_i
    m_i = m_ij
    return acc, l_i, m_i
@triton.jit
def forward_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, in_ptr10, out_ptr0, ks0,
    q, K, V,
    desc_k, desc_v, Q_LEN, KV_LEN,
    # accumulated values
    acc, l_i, m_i,
    # Offsets used as inputs to score_mod & mask_mod
    # of size [BLOCK_M, BLOCK_N] or scalar.
    off_z, off_h, offs_m, offs_n,
    # Offsets needed for TMA loads
    kv_start,
    # blocksparse data
    kv_indices, kv_num_blocks,
    # start kv and end kv block
    block_n_start, block_n_end,
    MATMUL_PRECISION,
    # Strides for K and V
    stride_kk, stride_kn, stride_vn, stride_vk,
    IS_FULL_BLOCKS,
):
    # Iterate over the K/V tiles scheduled for this query block (tile indices
    # [block_n_start, block_n_end) of the block-sparse schedule), folding each
    # tile into the online-softmax state (acc, l_i, m_i) via forward_block_mn,
    # and returning the updated state to the caller.
    # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = True
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = True
    SM_SCALE : tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    USE_TMA : tl.constexpr = False
    BLOCK_M : tl.constexpr = 128
    BLOCK_N : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32
    # Number of BLOCK_N tiles per sparse KV block.
    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
    # 1/ln(2): lets the softmax use exp2 instead of exp.
    RCP_LN2: tl.constexpr = 1.44269504
    if PRESCALE_QK:
        q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
    kv_offset = 0
    # loop over k, v and update accumulator until block_n_end
    for start_n in range(block_n_start, block_n_end):
        # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
        if IS_DIVISIBLE:
            acc, l_i, m_i = forward_block_mn(
                arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, in_ptr10, out_ptr0, ks0,
                q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
                # accumulated values
                acc, l_i, m_i,
                # Offsets
                off_z, off_h, offs_m, offs_n,
                # Offsets needed for TMA loads
                kv_start,
                kv_offset,
                MATMUL_PRECISION, RCP_LN2,
                # Strides for K and V
                stride_kk, stride_kn, stride_vn, stride_vk,
                IS_FULL_BLOCKS,
            )
        else:
            # Benchmark shows even we applied mod & mask to each block for non divisible seqlen,
            # it's on par or slightly faster than only applying to the last block in fwd.
            # However, we choose different strategy for bwd, where we only apply mod & mask
            # to the last block because it's faster a lot.
            acc, l_i, m_i = forward_block_mn(
                arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, in_ptr10, out_ptr0, ks0,
                q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
                # accumulated values
                acc, l_i, m_i,
                # Offsets
                off_z, off_h, offs_m, offs_n,
                # Offsets needed for TMA loads
                kv_start,
                kv_offset,
                MATMUL_PRECISION, RCP_LN2,
                # Strides for K and V
                stride_kk, stride_kn, stride_vn, stride_vk,
                IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
            )
        # Advance to the next tile of the block-sparse schedule; `offset` may
        # jump across masked-out sparse blocks (see get_offset_for_next_block).
        offset = get_offset_for_next_block(
            start_n, kv_indices, kv_num_blocks,
            SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
        )
        offs_n = offs_n + offset
        kv_offset += offset
    return acc, l_i, m_i
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor/rank0/l2/cl2tvo5hcgbtcyeecfjmb2uuyw7bdiqs5eiobc6gbmdzikudnimp.py | |
| # Topologically Sorted Source Nodes: [convert_element_type_447], Original ATen: [aten._to_copy] | |
| # Source node to ATen node mapping: | |
| # convert_element_type_447 => convert_element_type_447 | |
| # Graph fragment: | |
| # %buf96 : Tensor = PlaceHolder[target=buf96] | |
| # %convert_element_type_447 : Tensor "f32[2048, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%wait_tensor_8, torch.float32), kwargs = {}) | |
| # return %convert_element_type_447 | |
| triton_poi_fused__to_copy_16 = async_compile.triton('triton_poi_fused__to_copy_16', ''' | |
| import triton | |
| import triton.language as tl | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.pointwise( | |
| size_hints={'x': 4194304}, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, | |
| inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_16', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 33554432}}, | |
| min_elem_per_thread=0 | |
| ) | |
@triton.jit
def triton_poi_fused__to_copy_16(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    # Pointwise dtype conversion: copy a contiguous bf16 buffer of
    # 4194304 (= 2048 * 2048) elements into an fp32 output buffer.
    xnumel = 4194304
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    # xnumel is an exact multiple of the launch size, so no bounds mask is
    # needed (loads/stores below pass mask=None).
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    # Second .to() is redundant codegen (value is already fp32) but harmless.
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, None)
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor/rank0/6i/c6itkxmaq3yfex4sobzcrmuof6hwih6dimfyghykx2wj6eoen4gq.py | |
| # Topologically Sorted Source Nodes: [q, k, v, view_283, view_284, permute_133, flex_attention_backward], Original ATen: [aten.view, aten.permute, flex_attention_backward] | |
| # Source node to ATen node mapping: | |
| # flex_attention_backward => flex_attention_backward | |
| # k => permute_114, view_261, view_262 | |
| # permute_133 => permute_133 | |
| # q => permute_112, view_258, view_259 | |
| # v => permute_116, view_264, view_265 | |
| # view_283 => view_283 | |
| # view_284 => view_284 | |
| # Graph fragment: | |
| # %getitem_83 : Tensor "bf16[1, 16, 327680, 128][671088640, 128, 2048, 1]cuda:0" = PlaceHolder[target=getitem_83] | |
| # %mm_54 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_54] | |
| # %buf98 : Tensor "bf16[1, 16, 327680][5242880, 1, 16]cuda:0" = PlaceHolder[target=buf98] | |
| # %view_258 : Tensor "bf16[327680, 16, 128][2048, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_45, [327680, 16, 128]), kwargs = {}) | |
| # %permute_112 : Tensor "bf16[16, 327680, 128][128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_258, [1, 0, 2]), kwargs = {}) | |
| # %view_259 : Tensor "bf16[1, 16, 327680, 128][2048, 128, 2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_112, [1, 16, 327680, 128]), kwargs = {}) | |
| # %view_261 : Tensor "bf16[327680, 4, 128][512, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_46, [327680, 4, 128]), kwargs = {}) | |
| # %permute_114 : Tensor "bf16[4, 327680, 128][128, 512, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_261, [1, 0, 2]), kwargs = {}) | |
| # %view_262 : Tensor "bf16[1, 4, 327680, 128][512, 128, 512, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_114, [1, 4, 327680, 128]), kwargs = {}) | |
| # %view_264 : Tensor "bf16[327680, 4, 128][512, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_47, [327680, 4, 128]), kwargs = {}) | |
| # %permute_116 : Tensor "bf16[4, 327680, 128][128, 512, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_264, [1, 0, 2]), kwargs = {}) | |
| # %view_265 : Tensor "bf16[1, 4, 327680, 128][512, 128, 512, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_116, [1, 4, 327680, 128]), kwargs = {}) | |
| # %view_283 : Tensor "bf16[1, 327680, 2048][671088640, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_54, [1, 327680, 2048]), kwargs = {}) | |
| # %view_284 : Tensor "bf16[1, 327680, 16, 128][671088640, 2048, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_283, [1, 327680, 16, 128]), kwargs = {}) | |
| # %permute_133 : Tensor "bf16[1, 16, 327680, 128][671088640, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_284, [0, 2, 1, 3]), kwargs = {}) | |
| # %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%view_259, %view_262, %view_265, %getitem_83, %getitem_84, %permute_133, None, %fw_graph0, %joint_graph0, (327680, 327680, %primals_18, %primals_17, %primals_19, %primals_21, %primals_22, %primals_24, %primals_25, %primals_27, 128, 128, %mask_graph0), 0.08838834764831843, {BACKEND: AUTO, PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: True}, (), (%primals_14, %primals_15)), kwargs = {}) | |
| # return %buf98,%buf99 | |
| triton_per_fused_flex_attention_backward_permute_view_17 = async_compile.triton('triton_per_fused_flex_attention_backward_permute_view_17', ''' | |
| import triton | |
| import triton.language as tl | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.persistent_reduction( | |
| size_hints={'x': 8388608, 'r0_': 128}, | |
| reduction_hint=ReductionHint.INNER, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, | |
| inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_flex_attention_backward_permute_view_17', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 2, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': True, 'tiling_scores': {'x': 0, 'r0_': 2684354560}} | |
| ) | |
@triton.jit
def triton_per_fused_flex_attention_backward_permute_view_17(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr):
    # Persistent per-row reduction: out[x] = sum_r(in0[x, r] * in1[x, r]) over
    # the last (head_dim = 128) axis, writing the result in a transposed
    # layout. Presumably this is the DELTA = sum(OUT * DO, axis=-1)
    # precompute consumed by the flex-attention backward template — verify
    # against the surrounding graph comments.
    xnumel = 5242880
    r0_numel = 128
    R0_BLOCK: tl.constexpr = 128
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    # xnumel divides the launch exactly, so all masks are constant True and
    # the loads/stores below use mask=None.
    xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)[None, :]
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    x0 = xindex
    # Decompose the flat row index: xnumel = 5242880 = 327680 * 16.
    x2 = (xindex % 16)
    x3 = xindex // 16
    tmp0 = tl.load(in_ptr0 + (r0_1 + 128*x0), None).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (r0_1 + 128*x0), None).to(tl.float32)
    tmp2 = tmp0 * tmp1
    tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
    tmp5 = tl.sum(tmp3, 1)[:, None].to(tl.float32)
    tmp6 = tmp5.to(tl.float32)
    # Store transposed relative to the input row order: row x goes to
    # out[x3 + 327680 * x2], i.e. the x2 (mod-16) component becomes the
    # slow-moving axis of the [16, 327680]-strided output.
    tl.store(out_ptr1 + (x3 + 327680*x2), tmp6, None)
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor/rank0/xd/cxd2dykhtsl345elda55jiiq6ml5mrbsssan4gy5n6s4njycn65u.py | |
| # Topologically Sorted Source Nodes: [q, k, v, view_283, view_284, permute_133, flex_attention_backward], Original ATen: [aten.view, aten.permute, flex_attention_backward] | |
| # Source node to ATen node mapping: | |
| # flex_attention_backward => flex_attention_backward | |
| # k => permute_114, view_261, view_262 | |
| # permute_133 => permute_133 | |
| # q => permute_112, view_258, view_259 | |
| # v => permute_116, view_264, view_265 | |
| # view_283 => view_283 | |
| # view_284 => view_284 | |
| # Graph fragment: | |
| # %mm_45 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_45] | |
| # %mm_46 : Tensor "bf16[327680, 512][512, 1]cuda:0" = PlaceHolder[target=mm_46] | |
| # %mm_47 : Tensor "bf16[327680, 512][512, 1]cuda:0" = PlaceHolder[target=mm_47] | |
| # %buf88 : Tensor = PlaceHolder[target=buf88] | |
| # %buf99 : Tensor "f32[1, 16, 327680][5242880, 327680, 1]cuda:0" = PlaceHolder[target=buf99] | |
| # %mm_54 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_54] | |
| # %getitem_90 : Tensor "bf16[1, 16, 327680, 128][671088640, 128, 2048, 1]cuda:0" = PlaceHolder[target=getitem_90] | |
| # %getitem_92 : Tensor "bf16[1, 4, 327680, 128][167772160, 128, 512, 1]cuda:0" = PlaceHolder[target=getitem_92] | |
| # %primals_18 : Tensor "i32[1, 1, 2560][2560, 2560, 1]cuda:0" = PlaceHolder[target=primals_18] | |
| # %primals_17 : Tensor "i32[1, 1, 2560, s91][2560*s91, 2560*s91, s91, 1]cuda:0" = PlaceHolder[target=primals_17] | |
| # %primals_22 : Tensor "i32[1, 1, 2560][2560, 2560, 1]cuda:0" = PlaceHolder[target=primals_22] | |
| # %primals_24 : Tensor "i32[1, 1, 2560, s16][2560*s16, 2560*s16, s16, 1]cuda:0" = PlaceHolder[target=primals_24] | |
| # %primals_19 : Tensor "i32[1, 1, 2560][2560, 2560, 1]cuda:0" = PlaceHolder[target=primals_19] | |
| # %primals_21 : Tensor "i32[1, 1, 2560, s6][2560*s6, 2560*s6, s6, 1]cuda:0" = PlaceHolder[target=primals_21] | |
| # %primals_25 : Tensor "i32[1, 1, 2560][2560, 2560, 1]cuda:0" = PlaceHolder[target=primals_25] | |
| # %primals_27 : Tensor "i32[1, 1, 2560, s18][2560*s18, 2560*s18, s18, 1]cuda:0" = PlaceHolder[target=primals_27] | |
| # %primals_14 : Tensor "i64[327680][1]cuda:0" = PlaceHolder[target=primals_14] | |
| # %primals_15 : Tensor "i64[327680][1]cuda:0" = PlaceHolder[target=primals_15] | |
| # %view_258 : Tensor "bf16[327680, 16, 128][2048, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_45, [327680, 16, 128]), kwargs = {}) | |
| # %permute_112 : Tensor "bf16[16, 327680, 128][128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_258, [1, 0, 2]), kwargs = {}) | |
| # %view_259 : Tensor "bf16[1, 16, 327680, 128][2048, 128, 2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_112, [1, 16, 327680, 128]), kwargs = {}) | |
| # %view_261 : Tensor "bf16[327680, 4, 128][512, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_46, [327680, 4, 128]), kwargs = {}) | |
| # %permute_114 : Tensor "bf16[4, 327680, 128][128, 512, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_261, [1, 0, 2]), kwargs = {}) | |
| # %view_262 : Tensor "bf16[1, 4, 327680, 128][512, 128, 512, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_114, [1, 4, 327680, 128]), kwargs = {}) | |
| # %view_264 : Tensor "bf16[327680, 4, 128][512, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_47, [327680, 4, 128]), kwargs = {}) | |
| # %permute_116 : Tensor "bf16[4, 327680, 128][128, 512, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_264, [1, 0, 2]), kwargs = {}) | |
| # %view_265 : Tensor "bf16[1, 4, 327680, 128][512, 128, 512, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_116, [1, 4, 327680, 128]), kwargs = {}) | |
| # %view_283 : Tensor "bf16[1, 327680, 2048][671088640, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_54, [1, 327680, 2048]), kwargs = {}) | |
| # %view_284 : Tensor "bf16[1, 327680, 16, 128][671088640, 2048, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_283, [1, 327680, 16, 128]), kwargs = {}) | |
| # %permute_133 : Tensor "bf16[1, 16, 327680, 128][671088640, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_284, [0, 2, 1, 3]), kwargs = {}) | |
| # %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%view_259, %view_262, %view_265, %getitem_83, %getitem_84, %permute_133, None, %fw_graph0, %joint_graph0, (327680, 327680, %primals_18, %primals_17, %primals_19, %primals_21, %primals_22, %primals_24, %primals_25, %primals_27, 128, 128, %mask_graph0), 0.08838834764831843, {BACKEND: AUTO, PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: True}, (), (%primals_14, %primals_15)), kwargs = {}) | |
| # return %getitem_91 | |
| triton_tem_fused_flex_attention_backward_permute_view_18 = async_compile.triton('triton_tem_fused_flex_attention_backward_permute_view_18', ''' | |
| import triton | |
| import triton.language as tl | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| @triton_heuristics.template( | |
| num_stages=3, | |
| num_warps=8, | |
| triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'in_ptr17': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]], (18,): [['tt.divisibility', 16]]}]}, | |
| inductor_meta={'kernel_name': 'triton_tem_fused_flex_attention_backward_permute_view_18', 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': True, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, | |
| ) | |
| @triton.jit | |
| def triton_tem_fused_flex_attention_backward_permute_view_18(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, in_ptr17, out_ptr0, ks0, ks1): | |
| PRESCALE_QK : tl.constexpr = False | |
| ROWS_GUARANTEED_SAFE : tl.constexpr = False | |
| BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False | |
| WRITE_DQ : tl.constexpr = True | |
| OUTPUT_LOGSUMEXP : tl.constexpr = True | |
| OUTPUT_MAX : tl.constexpr = True | |
| FLOAT32_PRECISION : tl.constexpr = 'tf32' | |
| IS_DIVISIBLE : tl.constexpr = True | |
| SM_SCALE : tl.constexpr = 0.08838834764831843 | |
| GQA_SHARED_HEADS : tl.constexpr = 4 | |
| HAS_FULL_BLOCKS : tl.constexpr = True | |
| QK_HEAD_DIM : tl.constexpr = 128 | |
| QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 | |
| V_HEAD_DIM : tl.constexpr = 128 | |
| V_HEAD_DIM_ROUNDED : tl.constexpr = 128 | |
| SAFE_HEAD_DIM : tl.constexpr = True | |
| USE_TMA : tl.constexpr = False | |
| BLOCK_M1 : tl.constexpr = 64 | |
| BLOCK_N1 : tl.constexpr = 128 | |
| BLOCK_M2 : tl.constexpr = 128 | |
| BLOCK_N2 : tl.constexpr = 64 | |
| SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 | |
| SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 | |
| INDEX_DTYPE : tl.constexpr = tl.int32 | |
| Q = arg_Q | |
| K = arg_K | |
| V = arg_V | |
| LSE = arg_LSE | |
| DELTA = arg_DELTA | |
| DO = arg_DO | |
| DQ = arg_DQ | |
| DV = arg_DV | |
| KV_NUM_BLKS = arg_KV_NUM_BLKS | |
| KV_IDX = arg_KV_IDX | |
| Q_NUM_BLKS = arg_Q_NUM_BLKS | |
| Q_IDX = arg_Q_IDX | |
| FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS | |
| FULL_KV_IDX = arg_FULL_KV_IDX | |
| FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS | |
| FULL_Q_IDX = arg_FULL_Q_IDX | |
| # Sub notation for this kernel: | |
| # | |
| # Q: Query, K: Key, V: Value | |
| # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) | |
| # DELTA: Precomputed sum(OUT*DO, axis=-1) | |
| # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value | |
| # DK: Derivative of Key, is the written to via the store_output call due to some limitations with | |
| # inductor codegen | |
| # M: Number of queries, N: Number of keys/values | |
| # QK_HEAD_DIM: The dimension of the query and key embeddings | |
| # V_HEAD_DIM: The dimension of the value embeddings | |
| # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim | |
| # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. | |
| # (Modifiable) Performance tuning options | |
| # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. | |
| # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. | |
| # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. | |
| # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. | |
| # | |
| # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. | |
| # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. | |
| # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. | |
| # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. | |
| # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. | |
| # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. | |
| # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. | |
| # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. | |
| # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. | |
| # The below are kernel options that can be applied for certain score_mods, | |
| # or involve a numerics vs. perf tradeoff | |
| # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has | |
| # about 20% more numerical error, but slightly faster. | |
| # Define strides of inputs | |
| stride_qz, stride_qh, stride_qm, stride_qd = 2048, 128, 2048, 1 | |
| stride_kz, stride_kh, stride_kn, stride_kd = 512, 128, 512, 1 | |
| stride_vz, stride_vh, stride_vn, stride_vd = 512, 128, 512, 1 | |
| stride_doz, stride_doh, stride_dom, stride_dod = 671088640, 128, 2048, 1 | |
| stride_dqz, stride_dqh, stride_dqm, stride_dqd = 671088640, 128, 2048, 1 | |
| stride_dvz, stride_dvh, stride_dvm, stride_dvd = 167772160, 128, 512, 1 | |
| ZQ = 1 | |
| HQ = 16 | |
| HKV = 4 | |
| Q_LEN = 327680 | |
| ZKV = 1 | |
| KV_LEN = 327680 | |
| MATMUL_PRECISION = Q.dtype.element_ty | |
| pid = tl.program_id(0).to(INDEX_DTYPE) | |
| NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) | |
| NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) | |
| off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx | |
| off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx | |
| off_zkv = off_zq % ZKV # kv batch idx | |
| SPARSE_Z = 1 | |
| SPARSE_HQ = 1 | |
| sparse_idx_z = off_zq % SPARSE_Z | |
| k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) | |
| v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) | |
| # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] | |
| # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] | |
| dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) | |
| # offset K, V, DV pointers for batch/kv-head | |
| K += k_adj | |
| V += v_adj | |
| DV += dv_adj | |
| RCP_LN2 = 1.44269504 | |
| offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) | |
| offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) | |
| if pid >= NUM_KV_BLOCKS: | |
| off_pid = pid - NUM_KV_BLOCKS | |
| # THIS BLOCK DOES DQ | |
| SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) | |
| SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) | |
| off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS | |
| start_m2_block = off_pid % NUM_Q_BLOCKS | |
| off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE | |
| stride_kv_num_blks_h = 2560 | |
| stride_kv_idx_h = 2560*ks0 | |
| stride_kv_idx_m = ks0 | |
| sparse_idx_hq2 = off_hq2 % SPARSE_HQ | |
| sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 | |
| sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask | |
| sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 | |
| # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. | |
| q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) | |
| do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) | |
| dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) | |
| off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) | |
| Q2 = Q + q_adj2 | |
| DO2 = DO + do_adj2 | |
| # TODO: This does not work if DQ is not the same layout as Q (for example, | |
| # if Q is broadcasted) | |
| DQ2 = DQ + dq_adj2 | |
| LSE2 = LSE + off_chz2 | |
| DELTA2 = DELTA + off_chz2 | |
| dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) | |
| start_m2 = start_m2_block * BLOCK_M2 | |
| offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) | |
| desc_q = None | |
| desc_do = None | |
| desc_k_dq = None | |
| desc_v_dq = None | |
| # load Q and do: they stay in SRAM throughout the inner loop. | |
| q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) | |
| do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) | |
| if PRESCALE_QK: | |
| q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) | |
| if IS_DIVISIBLE: | |
| Di = tl.load(DELTA2 + offs_m2) | |
| lse = tl.load(LSE2 + offs_m2) | |
| else: | |
| Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) | |
| lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) | |
| lse = tl.where(lse == -float("inf"), 0.0, lse) | |
| lse = lse[:, None] | |
| # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
| # KV_IDX and KV_NUM_BLKS are always contiguous. | |
| kv_indices = KV_IDX + sparse_kv_idx_offset | |
| kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading | |
| sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) | |
| dq = bwd_dq_inner( | |
| arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, in_ptr17, out_ptr0, ks0, ks1, | |
| K, V, desc_k_dq, desc_v_dq, kv_start, start_m2, | |
| dq, q, do, Di, lse, | |
| off_zq, off_hq2, | |
| stride_kn, stride_kd, stride_vn, stride_vd, | |
| kv_indices, sparse_kv_num_blocks, | |
| MATMUL_PRECISION, | |
| IS_FULL_BLOCKS=False, | |
| ) | |
| if HAS_FULL_BLOCKS: | |
| # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
| # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. | |
| kv_indices = FULL_KV_IDX + sparse_kv_idx_offset | |
| kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading | |
| sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) | |
| offs_n2 = kv_start + tl.arange(0, BLOCK_N2) | |
| dq = bwd_dq_inner( | |
| arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, in_ptr17, out_ptr0, ks0, ks1, | |
| K, V, desc_k_dq, desc_v_dq, kv_start, start_m2, | |
| dq, q, do, Di, lse, | |
| off_zq, off_hq2, | |
| stride_kn, stride_kd, stride_vn, stride_vd, | |
| kv_indices, sparse_kv_num_blocks, | |
| MATMUL_PRECISION, | |
| IS_FULL_BLOCKS=True, | |
| ) | |
| # Write back dQ. | |
| dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd | |
| dq *= SM_SCALE | |
| if IS_DIVISIBLE and SAFE_HEAD_DIM: | |
| tl.store(dq_ptrs, dq) | |
| else: | |
| tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) | |
| else: | |
| # THIS BLOCK DOES DK & DV | |
| SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) | |
| SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) | |
| pid_mask = pid // SPARSE_KV_MULTIPLE | |
| stride_q_num_blks_h = 2560 | |
| stride_q_idx_h = 2560*ks1 | |
| stride_q_idx_n = ks1 | |
| dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) | |
| dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) | |
| start_n1 = pid * BLOCK_N1 | |
| offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) | |
| # load K and V: they stay in SRAM throughout the inner loop. | |
| k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) | |
| v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) | |
| if PRESCALE_QK: | |
| k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) | |
| for off_g in range(0, GQA_SHARED_HEADS): | |
| off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g | |
| # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. | |
| q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) | |
| do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) | |
| dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) | |
| off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) | |
| Q1 = Q + q_adj1 | |
| DO1 = DO + do_adj1 | |
| # TODO: This does not work if DQ is not the same layout as Q (for example, | |
| # if Q is broadcasted) | |
| LSE1 = LSE + off_chz1 | |
| DELTA1 = DELTA + off_chz1 | |
| sparse_idx_hq1 = off_hq1 % SPARSE_HQ | |
| sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 | |
| sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask | |
| sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 | |
| # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
| # Q_IDX and Q_NUM_BLKS are always contiguous. | |
| q_indices = Q_IDX + sparse_q_idx_offset | |
| q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading | |
| sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) | |
| offs_m1 = q_start + tl.arange(0, BLOCK_M1) | |
| desc_q = None | |
| desc_do = None | |
| dk, dv = bwd_dkdv_inner( | |
| arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, in_ptr17, out_ptr0, ks0, ks1, | |
| Q1, DO1, DELTA1, LSE1, | |
| desc_q, desc_do, q_start, | |
| dk, dv, k, v, | |
| off_zq, off_hq1, offs_n1, offs_m1, | |
| stride_qm, stride_qd, stride_dom, stride_dod, | |
| q_indices, sparse_q_num_blocks, | |
| MATMUL_PRECISION, | |
| IS_FULL_BLOCKS=False, | |
| ) | |
| if HAS_FULL_BLOCKS: | |
| # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
| # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. | |
| q_indices = FULL_Q_IDX + sparse_q_idx_offset | |
| q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading | |
| sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) | |
| offs_m1 = q_start + tl.arange(0, BLOCK_M1) | |
| dk, dv = bwd_dkdv_inner( | |
| arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, in_ptr17, out_ptr0, ks0, ks1, | |
| Q1, DO1, DELTA1, LSE1, | |
| desc_q, desc_do, q_start, | |
| dk, dv, k, v, | |
| off_zq, off_hq1, offs_n1, offs_m1, | |
| stride_qm, stride_qd, stride_dom, stride_dod, | |
| q_indices, sparse_q_num_blocks, | |
| MATMUL_PRECISION, | |
| IS_FULL_BLOCKS=True, | |
| ) | |
| # Write back dV and dK. | |
| dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd | |
| index_n = offs_n1[:, None] | |
| index_k = offs_k[None, :] | |
| index_v = offs_v[None, :] | |
| if IS_DIVISIBLE and SAFE_HEAD_DIM: | |
| tl.store(dv_ptrs, dv) | |
| else: | |
| tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) | |
| dk *= SM_SCALE | |
| if SAFE_HEAD_DIM: | |
| mask = index_n < KV_LEN | |
| else: | |
| mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) | |
| # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] | |
| # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] | |
| tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) | |
| xindex = index_k + 128*index_n + 41943040*off_hkv + 167772160*off_zq | |
| tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 512*index_n, [BLOCK_N1, QK_HEAD_DIM_ROUNDED])), dk, mask) | |
@triton.jit
def bwd_dq_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, in_ptr17, out_ptr0, ks0, ks1,
    K, V, desc_k, desc_v, kv_start, start_m2,
    dq, q, do, Di, lse,
    off_z, off_hq,
    stride_kn, stride_kd, stride_vn, stride_vd,
    kv_indices, sparse_kv_num_blocks,
    MATMUL_PRECISION,
    IS_FULL_BLOCKS,
):
    # dQ inner loop: walks the block-sparse list of KV blocks in `kv_indices`
    # and accumulates each block's contribution into the fp32 accumulator
    # `dq` via bwd_dq_block_mn. Returns the updated `dq`.
    # The caller invokes this twice: once with IS_FULL_BLOCKS=False over the
    # partially-masked KV blocks and once with IS_FULL_BLOCKS=True over the
    # fully-unmasked ones.
    # NOTE: the constexpr block below is duplicated verbatim into every jitted
    # helper by Inductor codegen; values are specialized for this compilation.
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = True
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = True
    SM_SCALE : tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    USE_TMA : tl.constexpr = False
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32
    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
    RCP_LN2: tl.constexpr = 1.44269504
    # Sequence lengths are baked into this compiled specialization.
    Q_LEN = 327680
    KV_LEN = 327680
    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
    # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
    # Cap the trip count at the dense number of KV blocks (handles a sparse
    # block size larger than KV_LEN).
    hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
    kv_offset = 0
    for start_n in range(0, hi):
        dq = bwd_dq_block_mn(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, in_ptr17, out_ptr0, ks0, ks1,
            dq, q, K, V, desc_k, desc_v, kv_start, kv_offset, start_m2,
            do, Di, lse, Q_LEN, KV_LEN,
            off_z, off_hq, offs_k, offs_v,
            stride_kn, stride_kd, stride_vn, stride_vd,
            kv_indices, sparse_kv_num_blocks,
            MATMUL_PRECISION, RCP_LN2,
            IS_FULL_BLOCKS,
        )
        # Increment pointers.
        # Advance along the seqlen dim to the next KV block recorded in the
        # sparse index list (may be a jump for non-contiguous blocks).
        offset = get_offset_for_next_block(
            start_n, kv_indices, sparse_kv_num_blocks,
            SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
        )
        kv_offset += offset
    return dq
@triton.jit
def bwd_dq_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, in_ptr17, out_ptr0, ks0, ks1,
    dq, q, K, V, desc_k, desc_v, kv_start, kv_offset, start_m2,
    do, Di, lse, Q_LEN, KV_LEN,
    off_z, off_hq, offs_k, offs_v,
    stride_kn, stride_kd, stride_vn, stride_vd,
    kv_indices, sparse_kv_num_blocks,
    MATMUL_PRECISION, RCP_LN2,
    IS_FULL_BLOCKS,
):
    # Processes one (BLOCK_M2 x BLOCK_N2) tile of the dQ pass: recomputes the
    # attention probabilities p from q/kT and lse, forms dS = p * (dP - Di),
    # applies the mask mod for partially-masked blocks, and accumulates
    # dS @ K into the fp32 accumulator `dq`.
    # NOTE: constexpr block duplicated by Inductor codegen into every helper.
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = True
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = True
    SM_SCALE : tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    USE_TMA : tl.constexpr = False
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32
    offs_n2 = kv_start + kv_offset + tl.arange(0, BLOCK_N2)
    offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
    kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
    vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
    # NB reversed order to since K is transposed
    kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
    qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
    if not PRESCALE_QK:
        qk *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    pre_mod_scores = qk
    n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
    # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
    # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary
    m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
    # Identity score_mod: scores pass through unchanged.
    tmp0 = (qk)
    post_mod_scores = tmp0
    if not IS_DIVISIBLE:
        post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
    if not IS_FULL_BLOCKS:
        # Inlined mask_mod graph for partially-masked blocks. Reading the
        # generated ops: keep a score when (m >= n) OR both positions carry
        # the same positive id in in_ptr16, AND their in_ptr17 ids match,
        # AND neither in_ptr16 id is the -1 sentinel.
        # NOTE(review): semantics inferred from the generated compares/loads;
        # confirm against the original Python mask_mod.
        tmp1 = (m)
        tmp2 = (n)
        tmp3 = tmp1 >= tmp2
        tmp4 = tl.load(in_ptr16 + tmp1)
        tmp5 = tl.full([1], 0, tl.int64)
        tmp6 = tmp4 > tmp5
        tmp7 = tl.load(in_ptr16 + tmp2)
        tmp8 = tmp7 > tmp5
        tmp9 = tmp6 & tmp8
        tmp10 = tmp4 == tmp7
        tmp11 = tmp9 & tmp10
        tmp12 = tmp3 | tmp11
        tmp13 = tl.load(in_ptr17 + tmp1)
        tmp14 = tl.load(in_ptr17 + tmp2)
        tmp15 = tmp13 == tmp14
        tmp16 = tmp12 & tmp15
        tmp17 = tl.full([1], -1, tl.int64)
        tmp18 = tmp4 == tmp17
        tmp19 = tmp7 == tmp17
        tmp20 = tmp18 | tmp19
        tmp21 = tmp20 == 0
        tmp22 = tmp16 & tmp21
        mask_mod_output = tmp22
        # apply mask for partial masked block
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if not PRESCALE_QK:
        post_mod_scores *= RCP_LN2
    # Recompute probabilities in base-2 (lse is already in that base here).
    p = tl.math.exp2(post_mod_scores - lse)
    # Compute dP and dS.
    # NB reversed order to since V is transposed
    vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
    dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
    ds = p * (dp - Di[:, None])
    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
    tmp23 = (ds)
    grad_scores = tmp23
    if not IS_DIVISIBLE:
        grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
    # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
    if WRITE_DQ:
        # scatter_mask is set up for other-buffer gradient scatters; it is
        # not referenced later in this block for this specialization.
        scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    ds = grad_scores
    if not IS_FULL_BLOCKS:
        # (grads) apply mask for partially unmasked block
        ds = tl.where(mask_mod_output, ds, 0.0)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    ds = ds.to(MATMUL_PRECISION)
    # Compute dQ.
    dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
    return dq
@triton.jit
def bwd_dkdv_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, in_ptr17, out_ptr0, ks0, ks1,
    Q, DO, DELTA, LSE, # pointers
    desc_q, desc_do, q_start,
    dk, dv, k, v,
    off_z, off_hq, offs_n1, offs_m1,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION,
    IS_FULL_BLOCKS,
):
    # dK/dV inner loop: walks the block-sparse list of Q blocks in
    # `q_indices` and accumulates each block's contribution into the fp32
    # accumulators `dk` and `dv` via bwd_dkdv_block_mn. Returns (dk, dv).
    # The caller invokes this twice per GQA query head: partially-masked
    # blocks (IS_FULL_BLOCKS=False) then fully-unmasked ones (True).
    # NOTE: constexpr block duplicated by Inductor codegen into every helper.
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = True
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = True
    SM_SCALE : tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    USE_TMA : tl.constexpr = False
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32
    SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
    RCP_LN2: tl.constexpr = 1.44269504
    # Sequence lengths are baked into this compiled specialization.
    Q_LEN = 327680
    KV_LEN = 327680
    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
    # The minimum is needed to handle the case where we run with a super large
    # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
    hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
    offset_block = 0
    for start_m in range(0, hi):
        dk, dv = bwd_dkdv_block_mn(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, in_ptr17, out_ptr0, ks0, ks1,
            dk, dv, Q, k, v, DO, DELTA, LSE,
            desc_q, desc_do, q_start, offset_block, Q_LEN, KV_LEN,
            off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
            stride_qm, stride_qd, stride_dom, stride_dod,
            q_indices, sparse_q_num_blocks,
            MATMUL_PRECISION, RCP_LN2,
            IS_FULL_BLOCKS,
        )
        # Increment pointers.
        # Advance along the Q seqlen dim to the next Q block recorded in the
        # sparse index list; both the index vector and the running block
        # offset move together.
        offset = get_offset_for_next_block(
            start_m, q_indices, sparse_q_num_blocks,
            SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
        )
        offs_m1 += offset
        offset_block += offset
    return dk, dv
@triton.jit
def bwd_dkdv_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, in_ptr17, out_ptr0, ks0, ks1,
    dk, dv, Q, k, v, DO, DELTA, LSE,
    desc_q, desc_do, q_start, offset, Q_LEN, KV_LEN,
    off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION, RCP_LN2,
    IS_FULL_BLOCKS,
):
    # Processes one (BLOCK_N1 x BLOCK_M1) tile of the dK/dV pass: recomputes
    # the transposed probabilities pT from k/qT and lse, accumulates
    # pT @ dO into `dv`, forms dsT = pT * (dpT - Di), applies the mask mod
    # for partially-masked blocks, and accumulates dsT @ Q into `dk`.
    # NOTE: constexpr block duplicated by Inductor codegen into every helper.
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = True
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = True
    SM_SCALE : tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    USE_TMA : tl.constexpr = False
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32
    qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
    do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
    # NB reversed order since Q is transposed
    qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
    # Load LSE before computing qk to reduce pipeline stall.
    if IS_DIVISIBLE:
        lse = tl.load(LSE + offs_m1)
    else:
        lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
    # Fully-masked rows have lse == -inf; zero them so exp2 below yields
    # finite values instead of NaN.
    lse = tl.where(lse == -float("inf"), 0.0, lse)
    qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
    if not PRESCALE_QK:
        qkT *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
    # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim
    # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary
    n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
    pre_mod_scores = qkT
    # Identity score_mod: scores pass through unchanged.
    tmp24 = (qkT)
    post_mod_scores = tmp24
    if not IS_DIVISIBLE:
        post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
    if not IS_FULL_BLOCKS:
        # Inlined mask_mod graph (same structure as the dQ path, transposed).
        # Reading the generated ops: keep a score when (m >= n) OR both
        # positions carry the same positive id in in_ptr16, AND their
        # in_ptr17 ids match, AND neither in_ptr16 id is the -1 sentinel.
        # NOTE(review): semantics inferred from the generated compares/loads;
        # confirm against the original Python mask_mod.
        tmp25 = (m)
        tmp26 = (n)
        tmp27 = tmp25 >= tmp26
        tmp28 = tl.load(in_ptr16 + tmp25)
        tmp29 = tl.full([1], 0, tl.int64)
        tmp30 = tmp28 > tmp29
        tmp31 = tl.load(in_ptr16 + tmp26)
        tmp32 = tmp31 > tmp29
        tmp33 = tmp30 & tmp32
        tmp34 = tmp28 == tmp31
        tmp35 = tmp33 & tmp34
        tmp36 = tmp27 | tmp35
        tmp37 = tl.load(in_ptr17 + tmp25)
        tmp38 = tl.load(in_ptr17 + tmp26)
        tmp39 = tmp37 == tmp38
        tmp40 = tmp36 & tmp39
        tmp41 = tl.full([1], -1, tl.int64)
        tmp42 = tmp28 == tmp41
        tmp43 = tmp31 == tmp41
        tmp44 = tmp42 | tmp43
        tmp45 = tmp44 == 0
        tmp46 = tmp40 & tmp45
        mask_mod_output = tmp46
        # (grads) apply mask for fully masked block
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if not PRESCALE_QK:
        post_mod_scores *= RCP_LN2
    # Recompute transposed probabilities in base-2.
    pT = tl.math.exp2(post_mod_scores - lse[None, :])
    do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
    # Compute dV.
    ppT = pT
    dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
    if IS_DIVISIBLE:
        Di = tl.load(DELTA + offs_m1)
    else:
        Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
    # Compute dP and dS.
    dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
    dsT = pT * (dpT - Di[None, :])
    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
    tmp47 = (dsT)
    grad_scores = tmp47
    if not IS_DIVISIBLE:
        grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
    # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
    if not WRITE_DQ:
        # These index/mask values support other-buffer gradient scatters;
        # with WRITE_DQ=True in this specialization the branch is dead code.
        idx_b = off_z
        idx_h = off_hq
        idx_m = m
        idx_n = n
        scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    dsT = grad_scores
    if not IS_FULL_BLOCKS:
        # (grads) apply mask for partially unmasked block
        dsT = tl.where(mask_mod_output, dsT, 0.0)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
    return dk, dv
| # Utility triton funcs | |
@triton.jit
def get_offset_for_next_block(
    loop_iter, col_indices, total_blocks,
    SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr
):
    # Distance (in elements along the seqlen dim) to advance after this loop
    # iteration, following the block-sparse index list `col_indices`.
    if BLOCKS_ARE_CONTIGUOUS:
        # Contiguous layout: every step is exactly one thread block.
        return BLOCK
    sparse_slot = loop_iter // SPARSE_BLOCK_MULTIPLE
    this_block = tl.load(col_indices + sparse_slot, eviction_policy="evict_last")
    following_block = tl.load(col_indices + sparse_slot + 1, eviction_policy="evict_last", mask=sparse_slot + 1 < total_blocks)
    crossing = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    # When we finish the current sparse block, hop to the next listed sparse
    # block minus the distance already covered inside this one.
    jump_distance = (following_block - this_block) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
    # Branchless select on the 0/1 `crossing` flag.
    return jump_distance * crossing + (1 - crossing) * BLOCK
@triton.jit
def get_bounded_indices(indices, max_len=None):
    # Wrap indices modulo max_len when a bound is given; pass them through
    # untouched when max_len is None (the divisible / no-bounds-check case).
    if max_len is None:
        return indices
    return indices % max_len
@triton.jit
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
    # Load a 2D block pointer, boundary-checking only the dimensions that may
    # be ragged: dim 0 (seqlen) unless IS_DIVISIBLE, dim 1 (head dim) unless
    # SAFE_HEAD_DIM. Out-of-bounds lanes are zero-padded.
    if IS_DIVISIBLE:
        if SAFE_HEAD_DIM:
            return tl.load(block_ptr)
        return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
    if SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
    return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
@triton.jit
def load_checked_2d(
    ptr,
    offs_m,
    offs_n,
    stride_m,
    stride_n,
    IS_DIVISIBLE_M: tl.constexpr,
    IS_DIVISIBLE_N: tl.constexpr,
    M_LEN: tl.constexpr,
    N_LEN: tl.constexpr,
):
    # 2D gather with optional per-dimension bounds masking; out-of-range
    # elements read as 0.0. When both strides are None the caller has already
    # materialized element pointers in `ptr`.
    if stride_m is not None and stride_n is not None:
        ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
    # Exhaustive, mutually-exclusive masking cases.
    if IS_DIVISIBLE_M and IS_DIVISIBLE_N:
        return tl.load(ptr)
    if IS_DIVISIBLE_M:
        return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
    if IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
    return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor/rank0/x5/cx5sqi2rjoxy6l3kpblixq6vlpyrk2tnz5mjwgm76szawvechnfv.py | |
| # Topologically Sorted Source Nodes: [convert_element_type_452], Original ATen: [aten._to_copy] | |
| # Source node to ATen node mapping: | |
| # convert_element_type_452 => convert_element_type_452 | |
| # Graph fragment: | |
| # %buf110 : Tensor = PlaceHolder[target=buf110] | |
| # %convert_element_type_452 : Tensor "f32[512, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%wait_tensor_9, torch.float32), kwargs = {}) | |
| # return %convert_element_type_452 | |
# NOTE(review): Inductor-generated pointwise kernel, compiled lazily via
# async_compile.triton. It copies a contiguous bf16 buffer of 1048576
# elements to an fp32 output (aten._to_copy; see the Graph fragment above).
# The second `.to(tl.float32)` on tmp1 is a redundant cast emitted by the
# code generator; it is a no-op here. The Triton source string below is part
# of the compilation cache key — do not edit it by hand.
triton_poi_fused__to_copy_19 = async_compile.triton('triton_poi_fused__to_copy_19', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 1048576},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_19', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 8388608}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_19(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 1048576
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')
| # kernel path: /tmp/torchinductor/rank0/2x/c2xyq7pyq5ybvlnjstymir6orquvs2w2dnuv2qnoitnx4xrvmhzx.py | |
| # Topologically Sorted Source Nodes: [redistribute, layer_norm, add_115, add_116, convert_element_type_463, convert_element_type_465, mul_179, mul_180, sum_11, mul_181, sum_12, mul_182, sub_36, sub_37, div_5, mul_183, mul_184, sum_13, sum_14, convert_element_type_467, add_117], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.add, aten.native_layer_norm_backward] | |
| # Source node to ATen node mapping: | |
| # add_115 => add_115 | |
| # add_116 => add_116 | |
| # add_117 => add_117 | |
| # convert_element_type_463 => convert_element_type_463 | |
| # convert_element_type_465 => convert_element_type_465 | |
| # convert_element_type_467 => convert_element_type_467 | |
| # div_5 => div_5 | |
| # layer_norm => add_100, convert_element_type_376, mul_139, rsqrt_24, sub_25, var_mean_24 | |
| # mul_179 => mul_179 | |
| # mul_180 => mul_180 | |
| # mul_181 => mul_181 | |
| # mul_182 => mul_182 | |
| # mul_183 => mul_183 | |
| # mul_184 => mul_184 | |
| # redistribute => convert_element_type_374 | |
| # sub_36 => sub_36 | |
| # sub_37 => sub_37 | |
| # sum_11 => sum_11 | |
| # sum_12 => sum_12 | |
| # sum_13 => sum_13 | |
| # sum_14 => sum_14 | |
| # Graph fragment: | |
| # %mm_56 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_56] | |
| # %mm_58 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_58] | |
| # %mm_60 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_60] | |
| # %primals_155 : Tensor "f32[2048][1]cuda:0" = PlaceHolder[target=primals_155] | |
| # %add_99 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=add_99] | |
| # %getitem_82 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=getitem_82] | |
| # %buf76 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=buf76] | |
| # %sum_11 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=sum_11] | |
| # %sum_12 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=sum_12] | |
| # %add_114 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=add_114] | |
| # %sub_37 : Tensor "f32[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=sub_37] | |
| # %convert_element_type_374 : Tensor "bf16[2048][1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_155, torch.bfloat16), kwargs = {}) | |
| # %convert_element_type_376 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_99, torch.float32), kwargs = {}) | |
| # %var_mean_24 : [num_users=2] = call_function[target=torch.ops.aten.var_mean.correction](args = (%convert_element_type_376, [1]), kwargs = {correction: 0, keepdim: True}) | |
| # %add_100 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%getitem_81, 1e-05), kwargs = {}) | |
| # %rsqrt_24 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_100,), kwargs = {}) | |
| # %sub_25 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%convert_element_type_376, %getitem_82), kwargs = {}) | |
| # %mul_139 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_25, %rsqrt_24), kwargs = {}) | |
| # %add_115 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mm_56, %mm_58), kwargs = {}) | |
| # %add_116 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_115, %mm_60), kwargs = {}) | |
| # %convert_element_type_463 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_116, torch.float32), kwargs = {}) | |
| # %convert_element_type_465 : Tensor "f32[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_374, torch.float32), kwargs = {}) | |
| # %mul_179 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_463, %convert_element_type_465), kwargs = {}) | |
| # %mul_180 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_179, 2048), kwargs = {}) | |
| # %sum_11 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_179, [1], True), kwargs = {}) | |
| # %mul_181 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_179, %mul_139), kwargs = {}) | |
| # %sum_12 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_181, [1], True), kwargs = {}) | |
| # %mul_182 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_139, %sum_12), kwargs = {}) | |
| # %sub_36 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%mul_180, %sum_11), kwargs = {}) | |
| # %sub_37 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sub_36, %mul_182), kwargs = {}) | |
| # %div_5 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%rsqrt_24, 2048), kwargs = {}) | |
| # %mul_183 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_5, %sub_37), kwargs = {}) | |
| # %mul_184 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_463, %mul_139), kwargs = {}) | |
| # %sum_13 : Tensor "f32[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_184, [0]), kwargs = {}) | |
| # %sum_14 : Tensor "f32[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_463, [0]), kwargs = {}) | |
| # %convert_element_type_467 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_183, torch.bfloat16), kwargs = {}) | |
| # %add_117 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_114, %convert_element_type_467), kwargs = {}) | |
| # return %sum_11,%sum_12,%sub_37,%add_117 | |
# NOTE(review): Inductor-generated persistent-reduction kernel for the fused
# layer-norm backward (see the Graph fragment above). Each program walks
# RSPLIT_SIZE rows (XBLOCK at a time); for each 2048-wide row it
#   * sums three bf16 matmul grads (in_ptr0..2) into the upstream grad dy,
#   * recomputes x_hat = (x - mean) * rsqrt(var/2048 + 1e-05) from the saved
#     input (in_ptr4), mean (in_ptr5) and the value in in_ptr6 (presumably the
#     row's sum of squared deviations, given the /2048 — confirm upstream),
#   * forms the layer-norm input grad (0.00048828125 == 1/2048) and adds it
#     in place into in_out_ptr0 (the %add_117 accumulation),
#   * accumulates column-wise partials accum0 = sum(dy * x_hat) and
#     accum1 = sum(dy) — the gamma/beta grad reductions (%sum_13 / %sum_14).
# The per-program partials are written to ws_ptr laid out as
# [2 * num_programs, r0_numel]; a follow-up kernel presumably reduces them
# across programs — verify against the caller. Cache-hashed source; do not
# edit the string by hand.
triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_20 = async_compile.triton('triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_20', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 524288, 'r0_': 2048},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*bf16', 'in_ptr5': '*fp32', 'in_ptr6': '*fp32', 'ws_ptr': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'RSPLIT_SIZE': 'constexpr', 'NUM_STAGES': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'MixOrderReductionGrid', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_20', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 8, 'num_store': -2, 'num_reduction': 2, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'RSPLIT_SIZE': 128, 'has_loadstore_with_contiguous_rdim': True}
)
@triton.jit
def triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_20(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, ws_ptr, xnumel, r0_numel, XBLOCK : tl.constexpr, RSPLIT_SIZE : tl.constexpr, NUM_STAGES : tl.constexpr):
    xnumel = 327680
    r0_numel = 2048
    R0_BLOCK: tl.constexpr = 2048
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * RSPLIT_SIZE
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)[None, :]
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    accum0 = tl.full([R0_BLOCK], 0, tl.float32)[None, :]
    accum1 = tl.full([R0_BLOCK], 0, tl.float32)[None, :]
    split_size = min(RSPLIT_SIZE, xnumel - xoffset)
    for _ in tl.range(0, split_size, XBLOCK, num_stages=NUM_STAGES):
        x0 = xindex
        xindex += XBLOCK
        tmp0 = tl.load(in_ptr0 + (r0_1 + 2048*x0), None).to(tl.float32)
        tmp1 = tl.load(in_ptr1 + (r0_1 + 2048*x0), None).to(tl.float32)
        tmp3 = tl.load(in_ptr2 + (r0_1 + 2048*x0), None).to(tl.float32)
        tmp6 = tl.load(in_ptr3 + (r0_1), None, eviction_policy='evict_last')
        tmp13 = tl.load(in_ptr4 + (r0_1 + 2048*x0), None).to(tl.float32)
        tmp15 = tl.load(in_ptr5 + (x0), None, eviction_policy='evict_last')
        tmp17 = tl.load(in_ptr6 + (x0), None, eviction_policy='evict_last')
        tmp32 = tl.load(in_out_ptr0 + (r0_1 + 2048*x0), None).to(tl.float32)
        tmp2 = tmp0 + tmp1
        tmp4 = tmp2 + tmp3
        tmp5 = tmp4.to(tl.float32)
        tmp7 = tmp6.to(tl.float32)
        tmp8 = tmp7.to(tl.float32)
        tmp9 = tmp5 * tmp8
        tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK])
        tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32)
        tmp14 = tmp13.to(tl.float32)
        tmp16 = tmp14 - tmp15
        tmp18 = tl.full([1, 1], 2048.0, tl.float32)
        tmp19 = (tmp17 / tmp18)
        tmp20 = tl.full([1, 1], 1e-05, tl.float32)
        tmp21 = tmp19 + tmp20
        tmp22 = libdevice.rsqrt(tmp21)
        tmp23 = tmp16 * tmp22
        tmp24 = tmp9 * tmp23
        tmp25 = tl.broadcast_to(tmp24, [XBLOCK, R0_BLOCK])
        tmp27 = tl.sum(tmp25, 1)[:, None].to(tl.float32)
        tmp28 = tmp9 * tmp18
        tmp29 = tmp28 - tmp12
        tmp30 = tmp23 * tmp27
        tmp31 = tmp29 - tmp30
        tmp33 = tl.full([1, 1], 0.00048828125, tl.float32)
        tmp34 = tmp22 * tmp33
        tmp35 = tmp34 * tmp31
        tmp36 = tmp35.to(tl.float32)
        tmp37 = tmp32 + tmp36
        tmp38 = tmp5 * tmp23
        tl.store(in_out_ptr0 + (r0_1 + 2048*x0), tmp37, None)
        tmp39 = tl.sum(tmp38, 0)
        tmp40 = accum0 + tmp39
        accum0 = tmp40
        tmp41 = tl.sum(tmp5, 0)
        tmp42 = accum1 + tmp41
        accum1 = tmp42
    tl.store(ws_ptr + (tl.program_id(0) + 0 * tl.num_programs(0)) * r0_numel + r0_index, accum0, r0_mask)
    tl.store(ws_ptr + (tl.program_id(0) + 1 * tl.num_programs(0)) * r0_numel + r0_index, accum1, r0_mask)
''', device_str='cuda')
| # kernel path: /tmp/torchinductor/rank0/gq/cgqx5tvvnyzs57gxny2ebeaaudmxya3lic2feiq6op2k26gzmj56.py | |
| # Topologically Sorted Source Nodes: [getitem, mask, getitem_1, data, getitem_3, contiguous_1], Original ATen: [aten.select, aten.eq, aten.unsqueeze, aten.mul, aten.slice, aten.clone] | |
| # Source node to ATen node mapping: | |
| # contiguous_1 => clone_3 | |
| # data => mul_7 | |
| # getitem => select | |
| # getitem_1 => unsqueeze_6 | |
| # getitem_3 => slice_19 | |
| # mask => eq_1 | |
| # Graph fragment: | |
| # %primals_1 : Tensor "u8[327680, 785][785, 1]cuda:0" = PlaceHolder[target=primals_1] | |
| # %select : Tensor "u8[327680][785]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.select.int](args = (%primals_1, 1, -1), kwargs = {}) | |
| # %eq_1 : Tensor "b8[327680][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%select, 2), kwargs = {}) | |
| # %unsqueeze_6 : Tensor "b8[327680, 1][1, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%eq_1, 1), kwargs = {}) | |
| # %mul_7 : Tensor "u8[327680, 785][785, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_1, %unsqueeze_6), kwargs = {}) | |
| # %slice_19 : Tensor "u8[327680, 16][785, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_7, 1, 768, 784), kwargs = {}) | |
| # %clone_3 : Tensor "u8[327680, 16][16, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_19,), kwargs = {memory_format: torch.contiguous_format}) | |
| # return %clone_3 | |
| triton_poi_fused_clone_eq_mul_select_slice_unsqueeze_21 = async_compile.triton('triton_poi_fused_clone_eq_mul_select_slice_unsqueeze_21', ''' | |
| import triton | |
| import triton.language as tl | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.pointwise( | |
| size_hints={'x': 8388608}, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*u8', 'out_ptr0': '*u8', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, | |
| inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_eq_mul_select_slice_unsqueeze_21', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 2, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 15728640}}, | |
| min_elem_per_thread=0 | |
| ) | |
| @triton.jit | |
| def triton_poi_fused_clone_eq_mul_select_slice_unsqueeze_21(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
| xnumel = 5242880 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
| xmask = tl.full([XBLOCK], True, tl.int1)[:] | |
| x0 = (xindex % 16) | |
| x1 = xindex // 16 | |
| x2 = xindex | |
| tmp0 = tl.load(in_ptr0 + (768 + x0 + 785*x1), None) | |
| tmp1 = tl.load(in_ptr0 + (784 + 785*x1), None, eviction_policy='evict_last') | |
| tmp2 = tl.full([1], 2, tl.uint8) | |
| tmp3 = tmp1 == tmp2 | |
| tmp4 = tmp3.to(tl.uint8) | |
| tmp5 = tmp0 * tmp4 | |
| tl.store(out_ptr0 + (x2), tmp5, None) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor/rank0/io/cioh57ngo45ltcro4ob3mshvmir5ds5aetufnjfpio2asj7ih22f.py | |
| # Topologically Sorted Source Nodes: [i, j, ifreqs, ifreqs_1, neg, freqs, getitem_6, getitem_7, mul_1, sin, cos, getitem_10, mul_3, sin_1, cos_1, posemb, to_1, layer_norm_1], Original ATen: [aten.select, aten.arange, aten.div, aten.neg, aten.pow, aten.unsqueeze, aten.mul, aten.sin, aten.cos, aten.cat, aten._to_copy, aten.native_layer_norm] | |
| # Source node to ATen node mapping: | |
| # cos => cos_1 | |
| # cos_1 => cos_2 | |
| # freqs => pow_2 | |
| # getitem_10 => unsqueeze_11 | |
| # getitem_6 => unsqueeze_7 | |
| # getitem_7 => unsqueeze_8 | |
| # i => select_4 | |
| # ifreqs => add_5, convert_element_type_16, iota_1, mul_10 | |
| # ifreqs_1 => div_2 | |
| # j => select_5 | |
| # layer_norm_1 => convert_element_type_20, var_mean_1 | |
| # mul_1 => mul_11 | |
| # mul_3 => mul_13 | |
| # neg => neg_1 | |
| # posemb => cat | |
| # sin => sin_1 | |
| # sin_1 => sin_2 | |
| # to_1 => convert_element_type_17 | |
| # Graph fragment: | |
| # %view_6 : Tensor "i32[327680, 4][4, 1]cuda:0" = PlaceHolder[target=view_6] | |
| # %cat : Tensor "f32[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=cat] | |
| # %select_4 : Tensor "i32[327680][4]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%view_6, 1, 0), kwargs = {}) | |
| # %select_5 : Tensor "i32[327680][4]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%view_6, 1, 1), kwargs = {}) | |
| # %iota_1 : Tensor "i64[512][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (512,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:0, requires_grad: False}) | |
| # %mul_10 : Tensor "i64[512][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%iota_1, 4), kwargs = {}) | |
| # %add_5 : Tensor "i64[512][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_10, 0), kwargs = {}) | |
| # %convert_element_type_16 : Tensor "f32[512][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_5, torch.float32), kwargs = {}) | |
| # %div_2 : Tensor "f32[512][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%convert_element_type_16, 2047), kwargs = {}) | |
| # %neg_1 : Tensor "f32[512][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%div_2,), kwargs = {}) | |
| # %pow_2 : Tensor "f32[512][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Scalar](args = (10000.0, %neg_1), kwargs = {}) | |
| # %unsqueeze_7 : Tensor "i32[327680, 1][4, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%select_4, 1), kwargs = {}) | |
| # %unsqueeze_8 : Tensor "f32[1, 512][512, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%pow_2, 0), kwargs = {}) | |
| # %mul_11 : Tensor "f32[327680, 512][512, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%unsqueeze_7, %unsqueeze_8), kwargs = {}) | |
| # %sin_1 : Tensor "f32[327680, 512][512, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sin.default](args = (%mul_11,), kwargs = {}) | |
| # %cos_1 : Tensor "f32[327680, 512][512, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cos.default](args = (%mul_11,), kwargs = {}) | |
| # %unsqueeze_11 : Tensor "i32[327680, 1][4, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%select_5, 1), kwargs = {}) | |
| # %mul_13 : Tensor "f32[327680, 512][512, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%unsqueeze_11, %unsqueeze_8), kwargs = {}) | |
| # %sin_2 : Tensor "f32[327680, 512][512, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sin.default](args = (%mul_13,), kwargs = {}) | |
| # %cos_2 : Tensor "f32[327680, 512][512, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cos.default](args = (%mul_13,), kwargs = {}) | |
| # %cat : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%sin_1, %cos_1, %sin_2, %cos_2], -1), kwargs = {}) | |
| # %convert_element_type_17 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%cat, torch.bfloat16), kwargs = {}) | |
| # %convert_element_type_20 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_17, torch.float32), kwargs = {}) | |
| # %var_mean_1 : [num_users=2] = call_function[target=torch.ops.aten.var_mean.correction](args = (%convert_element_type_20, [1]), kwargs = {correction: 0, keepdim: True}) | |
| # return %cat,%getitem_3,%buf1547 | |
# NOTE(review): Inductor-generated reduction kernel (see the Graph fragment
# above). For each of 327680 rows it materializes a 2048-wide sinusoidal
# positional embedding from the int32 coordinates in in_ptr0 (stride 4:
# element 4*x0 is coordinate i, 1 + 4*x0 is coordinate j). The row is the
# concatenation [sin(i*f), cos(i*f), sin(j*f), cos(j*f)] with 512 entries per
# segment; the four tl.where chains select the segment by r0_1's range, and
# the constant 0.0004885197850512946 is 1/2047 (the `div_2` / 2047 node
# above), giving frequencies f = 10000**(-4k/2047). The embedding is stored
# to out_ptr0, and a Welford reduction across the row feeds the following
# layer norm: out_ptr1 gets the row mean, out_ptr2 gets the Welford M2
# (sum of squared deviations — downstream presumably divides by the row
# length to obtain the variance; confirm against the consumer kernel).
# Cache-hashed source; do not edit the string by hand.
triton_red_fused__to_copy_arange_cat_cos_div_mul_native_layer_norm_neg_pow_select_sin_unsqueeze_22 = async_compile.triton('triton_red_fused__to_copy_arange_cat_cos_div_mul_native_layer_norm_neg_pow_select_sin_unsqueeze_22', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.reduction(
    size_hints={'x': 524288, 'r0_': 2048},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_cat_cos_div_mul_native_layer_norm_neg_pow_select_sin_unsqueeze_22', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 4, 'num_store': 3, 'num_reduction': 2, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': True, 'tiling_scores': {'x': 5242880, 'r0_': 5368709120}}
)
@triton.jit
def triton_red_fused__to_copy_arange_cat_cos_div_mul_native_layer_norm_neg_pow_select_sin_unsqueeze_22(in_ptr0, out_ptr0, out_ptr1, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 327680
    r0_numel = 2048
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = xindex
    tmp74_mean = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
    tmp74_m2 = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
    tmp74_weight = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp0 = r0_1
        tmp1 = tl.full([1, 1], 0, tl.int64)
        tmp2 = tmp0 >= tmp1
        tmp3 = tl.full([1, 1], 512, tl.int64)
        tmp4 = tmp0 < tmp3
        tmp5 = tl.load(in_ptr0 + (tl.broadcast_to(4*x0, [XBLOCK, R0_BLOCK])), r0_mask & tmp4, eviction_policy='evict_last', other=0.0)
        tmp6 = tmp5.to(tl.float32)
        tmp7 = tl.broadcast_to(4*(r0_1), [XBLOCK, R0_BLOCK])
        tmp8 = tmp7.to(tl.float32)
        tmp9 = tl.full([1, 1], 0.0004885197850512946, tl.float32)
        tmp10 = tmp8 * tmp9
        tmp11 = -tmp10
        tmp12 = tl.full([1, 1], 10000.0, tl.float32)
        tmp13 = libdevice.pow(tmp12, tmp11)
        tmp14 = tmp6 * tmp13
        tmp15 = tl_math.sin(tmp14)
        tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype)
        tmp17 = tl.where(tmp4, tmp15, tmp16)
        tmp18 = tmp0 >= tmp3
        tmp19 = tl.full([1, 1], 1024, tl.int64)
        tmp20 = tmp0 < tmp19
        tmp21 = tmp18 & tmp20
        tmp22 = tl.load(in_ptr0 + (tl.broadcast_to(4*x0, [XBLOCK, R0_BLOCK])), r0_mask & tmp21, eviction_policy='evict_last', other=0.0)
        tmp23 = tmp22.to(tl.float32)
        tmp24 = tl.broadcast_to(4*((-512) + r0_1), [XBLOCK, R0_BLOCK])
        tmp25 = tmp24.to(tl.float32)
        tmp26 = tl.full([1, 1], 0.0004885197850512946, tl.float32)
        tmp27 = tmp25 * tmp26
        tmp28 = -tmp27
        tmp29 = tl.full([1, 1], 10000.0, tl.float32)
        tmp30 = libdevice.pow(tmp29, tmp28)
        tmp31 = tmp23 * tmp30
        tmp32 = tl_math.cos(tmp31)
        tmp33 = tl.full(tmp32.shape, 0.0, tmp32.dtype)
        tmp34 = tl.where(tmp21, tmp32, tmp33)
        tmp35 = tmp0 >= tmp19
        tmp36 = tl.full([1, 1], 1536, tl.int64)
        tmp37 = tmp0 < tmp36
        tmp38 = tmp35 & tmp37
        tmp39 = tl.load(in_ptr0 + (tl.broadcast_to(1 + 4*x0, [XBLOCK, R0_BLOCK])), r0_mask & tmp38, eviction_policy='evict_last', other=0.0)
        tmp40 = tmp39.to(tl.float32)
        tmp41 = tl.broadcast_to(4*((-1024) + r0_1), [XBLOCK, R0_BLOCK])
        tmp42 = tmp41.to(tl.float32)
        tmp43 = tl.full([1, 1], 0.0004885197850512946, tl.float32)
        tmp44 = tmp42 * tmp43
        tmp45 = -tmp44
        tmp46 = tl.full([1, 1], 10000.0, tl.float32)
        tmp47 = libdevice.pow(tmp46, tmp45)
        tmp48 = tmp40 * tmp47
        tmp49 = tl_math.sin(tmp48)
        tmp50 = tl.full(tmp49.shape, 0.0, tmp49.dtype)
        tmp51 = tl.where(tmp38, tmp49, tmp50)
        tmp52 = tmp0 >= tmp36
        tmp53 = tl.full([1, 1], 2048, tl.int64)
        tmp54 = tmp0 < tmp53
        tmp55 = tl.load(in_ptr0 + (tl.broadcast_to(1 + 4*x0, [XBLOCK, R0_BLOCK])), r0_mask & tmp52, eviction_policy='evict_last', other=0.0)
        tmp56 = tmp55.to(tl.float32)
        tmp57 = tl.broadcast_to(4*((-1536) + r0_1), [XBLOCK, R0_BLOCK])
        tmp58 = tmp57.to(tl.float32)
        tmp59 = tl.full([1, 1], 0.0004885197850512946, tl.float32)
        tmp60 = tmp58 * tmp59
        tmp61 = -tmp60
        tmp62 = tl.full([1, 1], 10000.0, tl.float32)
        tmp63 = libdevice.pow(tmp62, tmp61)
        tmp64 = tmp56 * tmp63
        tmp65 = tl_math.cos(tmp64)
        tmp66 = tl.full(tmp65.shape, 0.0, tmp65.dtype)
        tmp67 = tl.where(tmp52, tmp65, tmp66)
        tmp68 = tl.where(tmp38, tmp51, tmp67)
        tmp69 = tl.where(tmp21, tmp34, tmp68)
        tmp70 = tl.where(tmp4, tmp17, tmp69)
        tmp71 = tmp70.to(tl.float32)
        tmp72 = tmp71.to(tl.float32)
        tmp73 = tl.broadcast_to(tmp72, [XBLOCK, R0_BLOCK])
        tmp74_mean_next, tmp74_m2_next, tmp74_weight_next = triton_helpers.welford_reduce(
            tmp73, tmp74_mean, tmp74_m2, tmp74_weight, roffset == 0
        )
        tmp74_mean = tl.where(r0_mask, tmp74_mean_next, tmp74_mean)
        tmp74_m2 = tl.where(r0_mask, tmp74_m2_next, tmp74_m2)
        tmp74_weight = tl.where(r0_mask, tmp74_weight_next, tmp74_weight)
        tl.store(out_ptr0 + (r0_1 + 2048*x0), tmp70, r0_mask)
    tmp75, tmp76, tmp77 = triton_helpers.welford(tmp74_mean, tmp74_m2, tmp74_weight, 1)
    tmp74 = tmp75[:, None]
    tmp78 = tmp76[:, None]
    tmp79 = tmp77[:, None]
    tl.store(out_ptr1 + (x0), tmp74, None)
    tl.store(out_ptr2 + (x0), tmp78, None)
''', device_str='cuda')
| # kernel path: /tmp/torchinductor/rank0/5t/c5tw3okxvozf5cmhbwpwep3xfgakwg676mtpfbahctf6vbhiy2t6.py | |
| # Topologically Sorted Source Nodes: [getitem, mask, getitem_1, data, getitem_2, patches, to, truediv, patches_1], Original ATen: [aten.select, aten.eq, aten.unsqueeze, aten.mul, aten.slice, aten.clone, aten._to_copy, aten.div, aten.sub] | |
| # Source node to ATen node mapping: | |
| # data => mul_7 | |
| # getitem => select | |
| # getitem_1 => unsqueeze_6 | |
| # getitem_2 => slice_18 | |
| # mask => eq_1 | |
| # patches => clone_2 | |
| # patches_1 => sub | |
| # to => convert_element_type_8 | |
| # truediv => div_1 | |
| # Graph fragment: | |
| # %primals_1 : Tensor "u8[327680, 785][785, 1]cuda:0" = PlaceHolder[target=primals_1] | |
| # %select : Tensor "u8[327680][785]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.select.int](args = (%primals_1, 1, -1), kwargs = {}) | |
| # %eq_1 : Tensor "b8[327680][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%select, 2), kwargs = {}) | |
| # %unsqueeze_6 : Tensor "b8[327680, 1][1, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%eq_1, 1), kwargs = {}) | |
| # %mul_7 : Tensor "u8[327680, 785][785, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_1, %unsqueeze_6), kwargs = {}) | |
| # %slice_18 : Tensor "u8[327680, 768][785, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_7, 1, 0, 768), kwargs = {}) | |
| # %clone_2 : Tensor "u8[327680, 768][768, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_18,), kwargs = {memory_format: torch.contiguous_format}) | |
| # %convert_element_type_8 : Tensor "bf16[327680, 768][768, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%clone_2, torch.bfloat16), kwargs = {}) | |
| # %div_1 : Tensor "bf16[327680, 768][768, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%convert_element_type_8, 127.5), kwargs = {}) | |
| # %sub : Tensor "bf16[327680, 768][768, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.sub.Tensor](args = (%div_1, 1.0), kwargs = {}) | |
| # return %sub | |
| triton_poi_fused__to_copy_clone_div_eq_mul_select_slice_sub_unsqueeze_23 = async_compile.triton('triton_poi_fused__to_copy_clone_div_eq_mul_select_slice_sub_unsqueeze_23', ''' | |
| import triton | |
| import triton.language as tl | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.pointwise( | |
| size_hints={'x': 268435456}, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*u8', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, | |
| inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_div_eq_mul_select_slice_sub_unsqueeze_23', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 2, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 1258291200}}, | |
| min_elem_per_thread=0 | |
| ) | |
| @triton.jit | |
| def triton_poi_fused__to_copy_clone_div_eq_mul_select_slice_sub_unsqueeze_23(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
| xnumel = 251658240 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
| xmask = tl.full([XBLOCK], True, tl.int1)[:] | |
| x0 = (xindex % 768) | |
| x1 = xindex // 768 | |
| x2 = xindex | |
| tmp0 = tl.load(in_ptr0 + (x0 + 785*x1), None) | |
| tmp1 = tl.load(in_ptr0 + (784 + 785*x1), None, eviction_policy='evict_last') | |
| tmp2 = tl.full([1], 2, tl.uint8) | |
| tmp3 = tmp1 == tmp2 | |
| tmp4 = tmp3.to(tl.uint8) | |
| tmp5 = tmp0 * tmp4 | |
| tmp6 = tmp5.to(tl.float32) | |
| tmp7 = tl.full([1], 0.00784313725490196, tl.float32) | |
| tmp8 = tmp6 * tmp7 | |
| tmp9 = tl.full([1], 1.0, tl.float32) | |
| tmp10 = tmp8 - tmp9 | |
| tl.store(out_ptr0 + (x2), tmp10, None) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor/rank0/cr/ccrpjojnaxipr3oeop55qzfsbpuwnmdmswoonbypn43u5pixgimi.py | |
| # Topologically Sorted Source Nodes: [redistribute_1], Original ATen: [aten._to_copy] | |
| # Source node to ATen node mapping: | |
| # redistribute_1 => convert_element_type_9 | |
| # Graph fragment: | |
| # %primals_4 : Tensor "f32[2048, 768][768, 1]cuda:0" = PlaceHolder[target=primals_4] | |
| # %convert_element_type_9 : Tensor "bf16[2048, 768][768, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_4, torch.bfloat16), kwargs = {}) | |
| # return %convert_element_type_9 | |
| triton_poi_fused__to_copy_24 = async_compile.triton('triton_poi_fused__to_copy_24', ''' | |
| import triton | |
| import triton.language as tl | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.pointwise( | |
| size_hints={'x': 2097152}, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, | |
| inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_24', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 12582912}}, | |
| min_elem_per_thread=0 | |
| ) | |
| @triton.jit | |
| def triton_poi_fused__to_copy_24(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
| xnumel = 1572864 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
| xmask = tl.full([XBLOCK], True, tl.int1)[:] | |
| x0 = xindex | |
| tmp0 = tl.load(in_ptr0 + (x0), None) | |
| tmp1 = tmp0.to(tl.float32) | |
| tl.store(out_ptr0 + (x0), tmp1, None) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor/rank0/hh/chhxb3thm46b3idor7qfb5sr3aucc5huuj3uzxii67e2uiglboqx.py | |
| # Topologically Sorted Source Nodes: [x], Original ATen: [aten.native_layer_norm] | |
| # Source node to ATen node mapping: | |
| # x => convert_element_type_14, var_mean | |
| # Graph fragment: | |
| # %mm : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm] | |
| # %convert_element_type_14 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mm, torch.float32), kwargs = {}) | |
| # %var_mean : [num_users=2] = call_function[target=torch.ops.aten.var_mean.correction](args = (%convert_element_type_14, [1]), kwargs = {correction: 0, keepdim: True}) | |
| # return %getitem_1,%buf1572 | |
| triton_red_fused_native_layer_norm_25 = async_compile.triton('triton_red_fused_native_layer_norm_25', ''' | |
| import triton | |
| import triton.language as tl | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.reduction( | |
| size_hints={'x': 524288, 'r0_': 2048}, | |
| reduction_hint=ReductionHint.INNER, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, | |
| inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_native_layer_norm_25', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 2, 'num_reduction': 2, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': True, 'tiling_scores': {'x': 5242880, 'r0_': 1342177280}} | |
| ) | |
| @triton.jit | |
| def triton_red_fused_native_layer_norm_25(in_ptr0, out_ptr0, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): | |
| xnumel = 327680 | |
| r0_numel = 2048 | |
| rnumel = r0_numel | |
| RBLOCK: tl.constexpr = R0_BLOCK | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
| xmask = tl.full([XBLOCK], True, tl.int1)[:, None] | |
| r0_base = tl.arange(0, R0_BLOCK)[None, :] | |
| rbase = r0_base | |
| x0 = xindex | |
| tmp3_mean = tl.zeros([XBLOCK, R0_BLOCK], tl.float32) | |
| tmp3_m2 = tl.zeros([XBLOCK, R0_BLOCK], tl.float32) | |
| tmp3_weight = tl.zeros([XBLOCK, R0_BLOCK], tl.float32) | |
| for r0_offset in tl.range(0, r0_numel, R0_BLOCK): | |
| r0_index = r0_offset + r0_base | |
| r0_mask = r0_index < r0_numel | |
| roffset = r0_offset | |
| rindex = r0_index | |
| r0_1 = r0_index | |
| tmp0 = tl.load(in_ptr0 + (r0_1 + 2048*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
| tmp1 = tmp0.to(tl.float32) | |
| tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) | |
| tmp3_mean_next, tmp3_m2_next, tmp3_weight_next = triton_helpers.welford_reduce( | |
| tmp2, tmp3_mean, tmp3_m2, tmp3_weight, roffset == 0 | |
| ) | |
| tmp3_mean = tl.where(r0_mask, tmp3_mean_next, tmp3_mean) | |
| tmp3_m2 = tl.where(r0_mask, tmp3_m2_next, tmp3_m2) | |
| tmp3_weight = tl.where(r0_mask, tmp3_weight_next, tmp3_weight) | |
| tmp4, tmp5, tmp6 = triton_helpers.welford(tmp3_mean, tmp3_m2, tmp3_weight, 1) | |
| tmp3 = tmp4[:, None] | |
| tmp7 = tmp5[:, None] | |
| tmp8 = tmp6[:, None] | |
| tl.store(out_ptr0 + (x0), tmp3, None) | |
| tl.store(out_ptr1 + (x0), tmp7, None) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor/rank0/ri/cri52ajrwpssovjadiworgquy76migc7fywttykfs5xlisxg6tb7.py | |
| # Topologically Sorted Source Nodes: [redistribute, layer_norm, add_203, add_204, convert_element_type_1046, convert_element_type_1048, mul_465, mul_466, sum_121, mul_467, sum_122, mul_468, sub_113, sub_114, div_27, mul_469, mul_470, sum_123, sum_124, convert_element_type_1050, add_205, getitem, mask, getitem_1, mul_471, convert_element_type_1055, to_1, layer_norm_1, mul_477, redistribute_2, convert_element_type_1065, mul_479, mul_480, sum_129, x, mul_481, sum_130, mul_482, sub_119, sub_120, div_28, mul_483, mul_484, convert_element_type_1067, mul_485, slice_21, full_default, copy_3], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.add, aten.native_layer_norm_backward, aten.select, aten.eq, aten.unsqueeze, aten.mul, aten.slice, aten.zeros_like, aten.copy] | |
| # Source node to ATen node mapping: | |
| # add_203 => add_203 | |
| # add_204 => add_204 | |
| # add_205 => add_205 | |
| # convert_element_type_1046 => convert_element_type_1046 | |
| # convert_element_type_1048 => convert_element_type_1048 | |
| # convert_element_type_1050 => convert_element_type_1050 | |
| # convert_element_type_1055 => convert_element_type_1055 | |
| # convert_element_type_1065 => convert_element_type_1065 | |
| # convert_element_type_1067 => convert_element_type_1067 | |
| # copy_3 => copy_3 | |
| # div_27 => div_27 | |
| # div_28 => div_28 | |
| # full_default => full_default | |
| # getitem => select | |
| # getitem_1 => unsqueeze, unsqueeze_6 | |
| # layer_norm => add_12, convert_element_type_24, mul_18, rsqrt_2, sub_3, var_mean_2 | |
| # layer_norm_1 => add_6, convert_element_type_20, mul_15, rsqrt_1, sub_2, var_mean_1 | |
| # mask => eq, eq_1 | |
| # mul_465 => mul_465 | |
| # mul_466 => mul_466 | |
| # mul_467 => mul_467 | |
| # mul_468 => mul_468 | |
| # mul_469 => mul_469 | |
| # mul_470 => mul_470 | |
| # mul_471 => mul_471 | |
| # mul_477 => mul_477 | |
| # mul_479 => mul_479 | |
| # mul_480 => mul_480 | |
| # mul_481 => mul_481 | |
| # mul_482 => mul_482 | |
| # mul_483 => mul_483 | |
| # mul_484 => mul_484 | |
| # mul_485 => mul_485 | |
| # redistribute => convert_element_type_22 | |
| # redistribute_2 => convert_element_type_12 | |
| # slice_21 => slice_21 | |
| # sub_113 => sub_113 | |
| # sub_114 => sub_114 | |
| # sub_119 => sub_119 | |
| # sub_120 => sub_120 | |
| # sum_121 => sum_121 | |
| # sum_122 => sum_122 | |
| # sum_123 => sum_123 | |
| # sum_124 => sum_124 | |
| # sum_129 => sum_129 | |
| # sum_130 => sum_130 | |
| # to_1 => convert_element_type_17 | |
| # x => add_3, convert_element_type_14, mul_8, rsqrt, sub_1, var_mean | |
| # Graph fragment: | |
| # %mm_188 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_188] | |
| # %mm_190 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_190] | |
| # %mm_192 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_192] | |
| # %primals_9 : Tensor "f32[2048][1]cuda:0" = PlaceHolder[target=primals_9] | |
| # %add_11 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=add_11] | |
| # %getitem_5 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=getitem_5] | |
| # %buf1473 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=buf1473] | |
| # %sum_121 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=sum_121] | |
| # %sum_122 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=sum_122] | |
| # %add_202 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=add_202] | |
| # %sub_114 : Tensor "f32[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=sub_114] | |
| # %primals_1 : Tensor "u8[327680, 785][785, 1]cuda:0" = PlaceHolder[target=primals_1] | |
| # %cat : Tensor "f32[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=cat] | |
| # %getitem_3 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=getitem_3] | |
| # %buf1547 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=buf1547] | |
| # %primals_5 : Tensor "f32[2048][1]cuda:0" = PlaceHolder[target=primals_5] | |
| # %mm : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm] | |
| # %getitem_1 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=getitem_1] | |
| # %buf1572 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=buf1572] | |
| # %mul_479 : Tensor "f32[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=mul_479] | |
| # %sum_129 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=sum_129] | |
| # %sum_130 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=sum_130] | |
| # %convert_element_type_22 : Tensor "bf16[2048][1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_9, torch.bfloat16), kwargs = {}) | |
| # %convert_element_type_24 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_11, torch.float32), kwargs = {}) | |
| # %var_mean_2 : [num_users=2] = call_function[target=torch.ops.aten.var_mean.correction](args = (%convert_element_type_24, [1]), kwargs = {correction: 0, keepdim: True}) | |
| # %add_12 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%getitem_4, 1e-05), kwargs = {}) | |
| # %rsqrt_2 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_12,), kwargs = {}) | |
| # %sub_3 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%convert_element_type_24, %getitem_5), kwargs = {}) | |
| # %mul_18 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_3, %rsqrt_2), kwargs = {}) | |
| # %add_203 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mm_188, %mm_190), kwargs = {}) | |
| # %add_204 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_203, %mm_192), kwargs = {}) | |
| # %convert_element_type_1046 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_204, torch.float32), kwargs = {}) | |
| # %convert_element_type_1048 : Tensor "f32[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_22, torch.float32), kwargs = {}) | |
| # %mul_465 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_1046, %convert_element_type_1048), kwargs = {}) | |
| # %mul_466 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_465, 2048), kwargs = {}) | |
| # %sum_121 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_465, [1], True), kwargs = {}) | |
| # %mul_467 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_465, %mul_18), kwargs = {}) | |
| # %sum_122 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_467, [1], True), kwargs = {}) | |
| # %mul_468 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_18, %sum_122), kwargs = {}) | |
| # %sub_113 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%mul_466, %sum_121), kwargs = {}) | |
| # %sub_114 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sub_113, %mul_468), kwargs = {}) | |
| # %div_27 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%rsqrt_2, 2048), kwargs = {}) | |
| # %mul_469 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_27, %sub_114), kwargs = {}) | |
| # %mul_470 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_1046, %mul_18), kwargs = {}) | |
| # %sum_123 : Tensor "f32[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_470, [0]), kwargs = {}) | |
| # %sum_124 : Tensor "f32[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_1046, [0]), kwargs = {}) | |
| # %convert_element_type_1050 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_469, torch.bfloat16), kwargs = {}) | |
| # %add_205 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_202, %convert_element_type_1050), kwargs = {}) | |
| # %select : Tensor "u8[327680][785]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.select.int](args = (%primals_1, 1, -1), kwargs = {}) | |
| # %eq_1 : Tensor "b8[327680][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%select, 2), kwargs = {}) | |
| # %unsqueeze_6 : Tensor "b8[327680, 1][1, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%eq_1, 1), kwargs = {}) | |
| # %mul_471 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_205, %unsqueeze_6), kwargs = {}) | |
| # %convert_element_type_1055 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_471, torch.float32), kwargs = {}) | |
| # %convert_element_type_17 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%cat, torch.bfloat16), kwargs = {}) | |
| # %convert_element_type_20 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_17, torch.float32), kwargs = {}) | |
| # %var_mean_1 : [num_users=2] = call_function[target=torch.ops.aten.var_mean.correction](args = (%convert_element_type_20, [1]), kwargs = {correction: 0, keepdim: True}) | |
| # %add_6 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%getitem_2, 1e-05), kwargs = {}) | |
| # %rsqrt_1 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_6,), kwargs = {}) | |
| # %sub_2 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%convert_element_type_20, %getitem_3), kwargs = {}) | |
| # %mul_15 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_2, %rsqrt_1), kwargs = {}) | |
| # %mul_477 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_1055, %mul_15), kwargs = {}) | |
| # %convert_element_type_12 : Tensor "bf16[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_5, torch.bfloat16), kwargs = {}) | |
| # %convert_element_type_1065 : Tensor "f32[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_12, torch.float32), kwargs = {}) | |
| # %mul_479 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_1055, %convert_element_type_1065), kwargs = {}) | |
| # %mul_480 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_479, 2048), kwargs = {}) | |
| # %sum_129 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_479, [1], True), kwargs = {}) | |
| # %convert_element_type_14 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mm, torch.float32), kwargs = {}) | |
| # %var_mean : [num_users=2] = call_function[target=torch.ops.aten.var_mean.correction](args = (%convert_element_type_14, [1]), kwargs = {correction: 0, keepdim: True}) | |
| # %add_3 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%getitem, 1e-05), kwargs = {}) | |
| # %rsqrt : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_3,), kwargs = {}) | |
| # %sub_1 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%convert_element_type_14, %getitem_1), kwargs = {}) | |
| # %mul_8 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_1, %rsqrt), kwargs = {}) | |
| # %mul_481 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_479, %mul_8), kwargs = {}) | |
| # %sum_130 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_481, [1], True), kwargs = {}) | |
| # %mul_482 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_8, %sum_130), kwargs = {}) | |
| # %sub_119 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%mul_480, %sum_129), kwargs = {}) | |
| # %sub_120 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sub_119, %mul_482), kwargs = {}) | |
| # %div_28 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%rsqrt, 2048), kwargs = {}) | |
| # %mul_483 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_28, %sub_120), kwargs = {}) | |
| # %mul_484 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_1055, %mul_8), kwargs = {}) | |
| # %convert_element_type_1067 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_483, torch.bfloat16), kwargs = {}) | |
| # %eq : Tensor "b8[327680][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%select, 1), kwargs = {}) | |
| # %unsqueeze : Tensor "b8[327680, 1][1, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%eq, 1), kwargs = {}) | |
| # %mul_485 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_205, %unsqueeze), kwargs = {}) | |
| # %slice_21 : Tensor "bf16[327680, 1024][2048, 2]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_485, 1, 1, 9223372036854775807, 2), kwargs = {}) | |
| # %full_default : Tensor "bf16[327680, 1024][1024, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.full.default](args = ([327680, 1024], 0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
| # %copy_3 : Tensor "bf16[327680, 1024][2048, 2]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.copy.default](args = (%slice_21, %full_default), kwargs = {}) | |
| # %slice_scatter_default : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%mul_485, %copy_3, 1, 1, 9223372036854775807, 2), kwargs = {}) | |
| # return %sum_121,%sum_122,%sub_114,%mul_477,%mul_479,%mul_484,%slice_scatter_default,%sum_129,%sum_130,%convert_element_type_1067 | |
| triton_per_fused__to_copy_add_copy_eq_mul_native_layer_norm_native_layer_norm_backward_select_slice_unsqueeze_zeros_like_26 = async_compile.triton('triton_per_fused__to_copy_add_copy_eq_mul_native_layer_norm_native_layer_norm_backward_select_slice_unsqueeze_zeros_like_26', ''' | |
| import triton | |
| import triton.language as tl | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.persistent_reduction( | |
| size_hints={'x': 524288, 'r0_': 2048}, | |
| reduction_hint=ReductionHint.DEFAULT, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_out_ptr0': '*fp32', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*bf16', 'in_ptr5': '*fp32', 'in_ptr6': '*fp32', 'in_ptr7': '*bf16', 'in_ptr8': '*u8', 'in_ptr9': '*fp32', 'in_ptr10': '*fp32', 'in_ptr11': '*fp32', 'in_ptr12': '*bf16', 'in_ptr13': '*fp32', 'in_ptr14': '*fp32', 'out_ptr2': '*fp32', 'out_ptr4': '*fp32', 'out_ptr5': '*bf16', 'out_ptr8': '*bf16', 'ws_ptr': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'RSPLIT_SIZE': 'constexpr', 'NUM_STAGES': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]], (18,): [['tt.divisibility', 16]], (19,): [['tt.divisibility', 16]], (20,): [['tt.divisibility', 16]], (21,): [['tt.divisibility', 16]], (22,): [['tt.divisibility', 16]]}]}, | |
| inductor_meta={'grid_type': 'MixOrderReductionGrid', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_copy_eq_mul_native_layer_norm_native_layer_norm_backward_select_slice_unsqueeze_zeros_like_26', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 16, 'num_store': 0, 'num_reduction': 4, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'RSPLIT_SIZE': 128, 'has_loadstore_with_contiguous_rdim': True} | |
| ) | |
| @triton.jit | |
| def triton_per_fused__to_copy_add_copy_eq_mul_native_layer_norm_native_layer_norm_backward_select_slice_unsqueeze_zeros_like_26(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, in_ptr10, in_ptr11, in_ptr12, in_ptr13, in_ptr14, out_ptr2, out_ptr4, out_ptr5, out_ptr8, ws_ptr, xnumel, r0_numel, XBLOCK : tl.constexpr, RSPLIT_SIZE : tl.constexpr, NUM_STAGES : tl.constexpr): | |
| # TorchInductor-generated persistent-reduction kernel (grid type | |
| # 'MixOrderReductionGrid'): each program owns RSPLIT_SIZE rows of a | |
| # 327680 x 2048 problem and walks them XBLOCK at a time, with the whole | |
| # 2048-wide feature axis resident (R0_BLOCK == r0_numel, so no r-loop). | |
| # Per row it recomputes three rsqrt(var/2048 + 1e-05)-style normalizations | |
| # (per-row stats from in_ptr5/in_ptr6, in_ptr9/in_ptr10, in_ptr13/in_ptr14 | |
| # -- presumably mean and summed variance of a 2048-wide layer norm; the | |
| # 0.00048828125 constant is 1/2048), gates gradients by a u8 selector read | |
| # from column 784 of a 785-wide tensor (in_ptr8, compared against 1 and 2), | |
| # and emits several per-element gradient buffers plus two column-wise | |
| # partial sums spilled to ws_ptr for a later combining pass. | |
| xnumel = 327680 | |
| r0_numel = 2048 | |
| R0_BLOCK: tl.constexpr = 2048 | |
| rnumel = r0_numel | |
| RBLOCK: tl.constexpr = R0_BLOCK | |
| xoffset = tl.program_id(0) * RSPLIT_SIZE | |
| xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
| xmask = tl.full([XBLOCK], True, tl.int1)[:, None] | |
| r0_index = tl.arange(0, R0_BLOCK)[None, :] | |
| r0_offset = 0 | |
| r0_mask = tl.full([R0_BLOCK], True, tl.int1)[None, :] | |
| roffset = r0_offset | |
| rindex = r0_index | |
| r0_1 = r0_index | |
| # Cross-row fp32 accumulators for the two column sums written to ws_ptr | |
| # after the row loop (workspace slots 0 and 1). | |
| accum0 = tl.full([R0_BLOCK], 0, tl.float32)[None, :] | |
| accum1 = tl.full([R0_BLOCK], 0, tl.float32)[None, :] | |
| # The last program may own fewer than RSPLIT_SIZE rows. | |
| split_size = min(RSPLIT_SIZE, xnumel - xoffset) | |
| for _ in tl.range(0, split_size, XBLOCK, num_stages=NUM_STAGES): | |
| x0 = xindex | |
| xindex += XBLOCK | |
| # Row-major loads at stride 2048; per-row scalars (in_ptr5/6/9/10/13/14) | |
| # and per-feature vectors (in_ptr3, in_ptr11) use evict_last since they | |
| # are reused across the feature / row axes respectively. | |
| tmp0 = tl.load(in_ptr0 + (r0_1 + 2048*x0), None).to(tl.float32) | |
| tmp1 = tl.load(in_ptr1 + (r0_1 + 2048*x0), None).to(tl.float32) | |
| tmp3 = tl.load(in_ptr2 + (r0_1 + 2048*x0), None).to(tl.float32) | |
| tmp6 = tl.load(in_ptr3 + (r0_1), None, eviction_policy='evict_last') | |
| tmp13 = tl.load(in_ptr4 + (r0_1 + 2048*x0), None).to(tl.float32) | |
| tmp15 = tl.load(in_ptr5 + (x0), None, eviction_policy='evict_last') | |
| tmp17 = tl.load(in_ptr6 + (x0), None, eviction_policy='evict_last') | |
| tmp32 = tl.load(in_ptr7 + (r0_1 + 2048*x0), None).to(tl.float32) | |
| tmp38 = tl.load(in_ptr8 + (784 + 785*x0), None, eviction_policy='evict_last') | |
| tmp44 = tl.load(in_out_ptr0 + (r0_1 + 2048*x0), None) | |
| tmp47 = tl.load(in_ptr9 + (x0), None, eviction_policy='evict_last') | |
| tmp49 = tl.load(in_ptr10 + (x0), None, eviction_policy='evict_last') | |
| tmp55 = tl.load(in_ptr11 + (r0_1), None, eviction_policy='evict_last') | |
| tmp59 = tl.load(in_ptr12 + (r0_1 + 2048*x0), None).to(tl.float32) | |
| tmp61 = tl.load(in_ptr13 + (x0), None, eviction_policy='evict_last') | |
| tmp63 = tl.load(in_ptr14 + (x0), None, eviction_policy='evict_last') | |
| # First normalization: upstream grad (in_ptr0+in_ptr1+in_ptr2) scaled by | |
| # the weight vector in_ptr3, reduced over the feature axis (tmp12, tmp27) | |
| # and combined into the classic layer-norm-backward form tmp31. | |
| tmp2 = tmp0 + tmp1 | |
| tmp4 = tmp2 + tmp3 | |
| tmp5 = tmp4.to(tl.float32) | |
| tmp7 = tmp6.to(tl.float32) | |
| tmp8 = tmp7.to(tl.float32) | |
| tmp9 = tmp5 * tmp8 | |
| tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) | |
| tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32) | |
| tmp14 = tmp13.to(tl.float32) | |
| tmp16 = tmp14 - tmp15 | |
| tmp18 = tl.full([1, 1], 2048.0, tl.float32) | |
| tmp19 = (tmp17 / tmp18) | |
| tmp20 = tl.full([1, 1], 1e-05, tl.float32) | |
| tmp21 = tmp19 + tmp20 | |
| tmp22 = libdevice.rsqrt(tmp21) | |
| tmp23 = tmp16 * tmp22 | |
| tmp24 = tmp9 * tmp23 | |
| tmp25 = tl.broadcast_to(tmp24, [XBLOCK, R0_BLOCK]) | |
| tmp27 = tl.sum(tmp25, 1)[:, None].to(tl.float32) | |
| tmp28 = tmp9 * tmp18 | |
| tmp29 = tmp28 - tmp12 | |
| tmp30 = tmp23 * tmp27 | |
| tmp31 = tmp29 - tmp30 | |
| # tmp37 = residual grad (in_ptr7) + rsqrt/2048-scaled tmp31; tmp43 gates | |
| # it to rows whose u8 selector equals 2 (zeroed elsewhere). | |
| tmp33 = tl.full([1, 1], 0.00048828125, tl.float32) | |
| tmp34 = tmp22 * tmp33 | |
| tmp35 = tmp34 * tmp31 | |
| tmp36 = tmp35.to(tl.float32) | |
| tmp37 = tmp32 + tmp36 | |
| tmp39 = tl.full([1, 1], 2, tl.uint8) | |
| tmp40 = tmp38 == tmp39 | |
| tmp41 = tmp40.to(tl.float32) | |
| tmp42 = tmp37 * tmp41 | |
| tmp43 = tmp42.to(tl.float32) | |
| # Second normalization (stats in_ptr9/in_ptr10, data read in place from | |
| # in_out_ptr0): produces tmp54, which overwrites in_out_ptr0 below. | |
| tmp45 = tmp44.to(tl.float32) | |
| tmp46 = tmp45.to(tl.float32) | |
| tmp48 = tmp46 - tmp47 | |
| tmp50 = (tmp49 / tmp18) | |
| tmp51 = tmp50 + tmp20 | |
| tmp52 = libdevice.rsqrt(tmp51) | |
| tmp53 = tmp48 * tmp52 | |
| tmp54 = tmp43 * tmp53 | |
| # Third normalization (stats in_ptr13/in_ptr14, data in_ptr12) with the | |
| # weight vector in_ptr11; tmp58/tmp67 feed the tmp96 backward expression. | |
| tmp56 = tmp55.to(tl.float32) | |
| tmp57 = tmp56.to(tl.float32) | |
| tmp58 = tmp43 * tmp57 | |
| tmp60 = tmp59.to(tl.float32) | |
| tmp62 = tmp60 - tmp61 | |
| tmp64 = (tmp63 / tmp18) | |
| tmp65 = tmp64 + tmp20 | |
| tmp66 = libdevice.rsqrt(tmp65) | |
| tmp67 = tmp62 * tmp66 | |
| tmp68 = tmp43 * tmp67 | |
| # Feature positions with r0_1 >= 1 and (r0_1 - 1) % 2 == 0 (i.e. odd | |
| # indices -- the slice(1, None, 2) backward) get 0.0 in out_ptr5; the | |
| # remaining lanes carry tmp37 gated by (selector == 1). | |
| tmp69 = r0_1 | |
| tmp70 = tl.full([1, 1], 1, tl.int64) | |
| tmp71 = tmp69 >= tmp70 | |
| tmp72 = (((-1) + r0_1) % 2) | |
| tmp73 = tl.full([1, 1], 0, tl.int64) | |
| tmp74 = tmp72 == tmp73 | |
| tmp75 = tmp71 & tmp74 | |
| tmp76 = tl.full([1, 1], 0.0, tl.float32) | |
| tmp77 = tl.full(tmp76.shape, 0.0, tmp76.dtype) | |
| tmp78 = tl.where(tmp75, tmp76, tmp77) | |
| tmp79 = tl.full([1, 1], 1, tl.uint8) | |
| tmp80 = tmp38 == tmp79 | |
| tmp81 = tmp80.to(tl.float32) | |
| tmp82 = tmp37 * tmp81 | |
| tmp83 = tl.where(tmp75, tmp78, tmp82) | |
| tmp84 = tl.broadcast_to(tmp58, [XBLOCK, R0_BLOCK]) | |
| tmp86 = tl.sum(tmp84, 1)[:, None].to(tl.float32) | |
| tmp87 = tmp58 * tmp67 | |
| tmp88 = tl.broadcast_to(tmp87, [XBLOCK, R0_BLOCK]) | |
| tmp90 = tl.sum(tmp88, 1)[:, None].to(tl.float32) | |
| tmp91 = tmp66 * tmp33 | |
| tmp92 = tmp58 * tmp18 | |
| tmp93 = tmp92 - tmp86 | |
| tmp94 = tmp67 * tmp90 | |
| tmp95 = tmp93 - tmp94 | |
| tmp96 = tmp91 * tmp95 | |
| tmp97 = tmp96.to(tl.float32) | |
| tmp98 = tmp5 * tmp23 | |
| # Per-row outputs for this XBLOCK slab; in_out_ptr0 is read above and | |
| # overwritten here (declared in mutated_arg_names). | |
| tl.store(out_ptr2 + (r0_1 + 2048*x0), tmp31, None) | |
| tl.store(in_out_ptr0 + (r0_1 + 2048*x0), tmp54, None) | |
| tl.store(out_ptr4 + (r0_1 + 2048*x0), tmp68, None) | |
| tl.store(out_ptr5 + (r0_1 + 2048*x0), tmp83, None) | |
| tl.store(out_ptr8 + (r0_1 + 2048*x0), tmp97, None) | |
| # Fold this slab's column sums (grad*normalized and raw grad) into the | |
| # cross-row accumulators; tl.sum over axis 0 collapses the XBLOCK rows. | |
| tmp99 = tl.sum(tmp98, 0) | |
| tmp100 = accum0 + tmp99 | |
| accum0 = tmp100 | |
| tmp101 = tl.sum(tmp5, 0) | |
| tmp102 = accum1 + tmp101 | |
| accum1 = tmp102 | |
| # Spill the per-program partial column sums to workspace slots 0 and 1; | |
| # a follow-up pass combines them across programs. | |
| tl.store(ws_ptr + (tl.program_id(0) + 0 * tl.num_programs(0)) * r0_numel + r0_index, accum0, r0_mask) | |
| tl.store(ws_ptr + (tl.program_id(0) + 1 * tl.num_programs(0)) * r0_numel + r0_index, accum1, r0_mask) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor/rank0/j7/cj7bnwjnv456srouilgyxwlg3q2bznu3ssquf7zbuoyp7gjfxyik.py | |
| # Topologically Sorted Source Nodes: [sum_127], Original ATen: [aten.native_layer_norm_backward] | |
| # Source node to ATen node mapping: | |
| # sum_127 => sum_127 | |
| # Graph fragment: | |
| # %mul_477 : Tensor "f32[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=mul_477] | |
| # %sum_127 : Tensor "f32[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_477, [0]), kwargs = {}) | |
| # return %buf1550 | |
| # Stage 1 of a split column-sum: sum_127 = aten.sum(mul_477 [327680, 2048] fp32, dim 0) | |
| # (see graph fragment above). Each output element x3 = x0 + 2048*x1 accumulates one | |
| # 2048-row chunk (x1 of 160) of column x0, yielding a partial-sum buffer that a | |
| # later kernel reduces over the 160 chunks. | |
| triton_red_fused_native_layer_norm_backward_27 = async_compile.triton('triton_red_fused_native_layer_norm_backward_27', ''' | |
| import triton | |
| import triton.language as tl | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.reduction( | |
| size_hints={'x': 524288, 'r0_': 2048}, | |
| reduction_hint=ReductionHint.OUTER, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, | |
| inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_native_layer_norm_backward_27', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 2686976000, 'r0_': 0}} | |
| ) | |
| @triton.jit | |
| def triton_red_fused_native_layer_norm_backward_27(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): | |
| # x0 = column (0..2047), x1 = 2048-row chunk index (0..159); the r-loop | |
| # walks the 2048 rows of chunk x1 at input stride 2048 (4194304 = 2048*2048 | |
| # per chunk). | |
| xnumel = 327680 | |
| r0_numel = 2048 | |
| rnumel = r0_numel | |
| RBLOCK: tl.constexpr = R0_BLOCK | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
| xmask = tl.full([XBLOCK], True, tl.int1)[:, None] | |
| r0_base = tl.arange(0, R0_BLOCK)[None, :] | |
| rbase = r0_base | |
| x0 = (xindex % 2048) | |
| x1 = xindex // 2048 | |
| _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) | |
| x3 = xindex | |
| for r0_offset in tl.range(0, r0_numel, R0_BLOCK): | |
| r0_index = r0_offset + r0_base | |
| r0_mask = r0_index < r0_numel | |
| roffset = r0_offset | |
| rindex = r0_index | |
| r0_2 = r0_index | |
| tmp0 = tl.load(in_ptr0 + (x0 + 2048*r0_2 + 4194304*x1), r0_mask, eviction_policy='evict_first', other=0.0) | |
| tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) | |
| tmp3 = _tmp2 + tmp1 | |
| # Masked lanes keep the previous accumulator so tail iterations add 0. | |
| _tmp2 = tl.where(r0_mask, tmp3, _tmp2) | |
| tmp2 = tl.sum(_tmp2, 1)[:, None] | |
| tl.store(out_ptr0 + (x3), tmp2, None) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor/rank0/h6/ch6htjq4lzhko3lbglu5qfgog4kxolyrgb4ics57nzdmzydgmext.py | |
| # Topologically Sorted Source Nodes: [sum_127, convert_element_type_1059, all_reduce_147], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| # Source node to ATen node mapping: | |
| # all_reduce_147 => all_reduce_147 | |
| # convert_element_type_1059 => convert_element_type_1059 | |
| # sum_127 => sum_127 | |
| # Graph fragment: | |
| # %buf1550 : Tensor "f32[2048, 160][1, 2048]cuda:0" = PlaceHolder[target=buf1550] | |
| # %sum_127 : Tensor "f32[2048][1]cuda:0" = PlaceHolder[target=sum_127] | |
| # %sum_127 : Tensor "f32[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_477, [0]), kwargs = {}) | |
| # %convert_element_type_1059 : Tensor "bf16[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_127, torch.bfloat16), kwargs = {}) | |
| # %all_reduce_147 : Tensor "bf16[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops._c10d_functional.all_reduce_.default](args = (%convert_element_type_1059, avg, 0), kwargs = {}) | |
| # return %sum_127,%wait_tensor_147 | |
| # Stage 2 of the split column-sum: reduce the [2048, 160] fp32 partial-sum | |
| # buffer over its 160 chunks per column, then store as bf16 (out_ptr1) -- the | |
| # graph fragment above shows this result is subsequently all-reduced (avg) | |
| # across ranks by _c10d_functional. | |
| triton_red_fused_all_reduce_native_layer_norm_backward_28 = async_compile.triton('triton_red_fused_all_reduce_native_layer_norm_backward_28', ''' | |
| import triton | |
| import triton.language as tl | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.reduction( | |
| size_hints={'x': 2048, 'r0_': 256}, | |
| reduction_hint=ReductionHint.OUTER, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, | |
| inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_all_reduce_native_layer_norm_backward_28', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 1318912, 'r0_': 0}} | |
| ) | |
| @triton.jit | |
| def triton_red_fused_all_reduce_native_layer_norm_backward_28(in_ptr0, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): | |
| # x0 = column (0..2047), r0_1 = partial-sum chunk (0..159); input is laid | |
| # out column-major for this pass (stride 2048 along r). | |
| xnumel = 2048 | |
| r0_numel = 160 | |
| rnumel = r0_numel | |
| RBLOCK: tl.constexpr = R0_BLOCK | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
| xmask = xindex < xnumel | |
| r0_base = tl.arange(0, R0_BLOCK)[None, :] | |
| rbase = r0_base | |
| x0 = xindex | |
| _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) | |
| for r0_offset in tl.range(0, r0_numel, R0_BLOCK): | |
| r0_index = r0_offset + r0_base | |
| r0_mask = r0_index < r0_numel | |
| roffset = r0_offset | |
| rindex = r0_index | |
| r0_1 = r0_index | |
| tmp0 = tl.load(in_ptr0 + (x0 + 2048*r0_1), r0_mask & xmask, eviction_policy='evict_first', other=0.0) | |
| tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) | |
| tmp3 = _tmp2 + tmp1 | |
| # Masked lanes keep the previous accumulator (no contribution). | |
| _tmp2 = tl.where(r0_mask & xmask, tmp3, _tmp2) | |
| tmp2 = tl.sum(_tmp2, 1)[:, None] | |
| # Store narrows to bf16 implicitly (out_ptr1 is declared *bf16). | |
| tmp4 = tmp2.to(tl.float32) | |
| tl.store(out_ptr1 + (x0), tmp4, xmask) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor/rank0/pa/cpa4kcho55jmuviwxacygdlktmp5752fifrqu64nhoefgqtwxj23.py | |
| # Topologically Sorted Source Nodes: [layer_norm, div_27, mul_469, convert_element_type_1050, add_205, getitem, mask, getitem_1, mul_471, convert_element_type_1055, sum_128], Original ATen: [aten.native_layer_norm, aten.native_layer_norm_backward, aten.add, aten.select, aten.eq, aten.unsqueeze, aten.mul] | |
| # Source node to ATen node mapping: | |
| # add_205 => add_205 | |
| # convert_element_type_1050 => convert_element_type_1050 | |
| # convert_element_type_1055 => convert_element_type_1055 | |
| # div_27 => div_27 | |
| # getitem => select | |
| # getitem_1 => unsqueeze_6 | |
| # layer_norm => add_12, convert_element_type_24, rsqrt_2, var_mean_2 | |
| # mask => eq_1 | |
| # mul_469 => mul_469 | |
| # mul_471 => mul_471 | |
| # sum_128 => sum_128 | |
| # Graph fragment: | |
| # %add_202 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=add_202] | |
| # %buf1473 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=buf1473] | |
| # %sub_114 : Tensor "f32[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=sub_114] | |
| # %primals_1 : Tensor "u8[327680, 785][785, 1]cuda:0" = PlaceHolder[target=primals_1] | |
| # %convert_element_type_24 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_11, torch.float32), kwargs = {}) | |
| # %var_mean_2 : [num_users=2] = call_function[target=torch.ops.aten.var_mean.correction](args = (%convert_element_type_24, [1]), kwargs = {correction: 0, keepdim: True}) | |
| # %add_12 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%getitem_4, 1e-05), kwargs = {}) | |
| # %rsqrt_2 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_12,), kwargs = {}) | |
| # %div_27 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%rsqrt_2, 2048), kwargs = {}) | |
| # %mul_469 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_27, %sub_114), kwargs = {}) | |
| # %convert_element_type_1050 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_469, torch.bfloat16), kwargs = {}) | |
| # %add_205 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_202, %convert_element_type_1050), kwargs = {}) | |
| # %select : Tensor "u8[327680][785]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.select.int](args = (%primals_1, 1, -1), kwargs = {}) | |
| # %eq_1 : Tensor "b8[327680][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%select, 2), kwargs = {}) | |
| # %unsqueeze_6 : Tensor "b8[327680, 1][1, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%eq_1, 1), kwargs = {}) | |
| # %mul_471 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_205, %unsqueeze_6), kwargs = {}) | |
| # %convert_element_type_1055 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_471, torch.float32), kwargs = {}) | |
| # %sum_128 : Tensor "f32[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_1055, [0]), kwargs = {}) | |
| # return %buf1552 | |
| # Fused column-sum for sum_128 (see fragment above): per element it rebuilds | |
| # add_205 = add_202 + bf16(rsqrt(var/2048 + 1e-05)/2048 * sub_114), gates rows | |
| # whose u8 selector (column 784 of the 785-wide primals_1) equals 2, and sums | |
| # the result over 2048-row chunks -- same split-reduction layout as the | |
| # preceding stage-1 kernel (x0 = column, x1 = chunk of 160). | |
| triton_red_fused_add_eq_mul_native_layer_norm_native_layer_norm_backward_select_unsqueeze_29 = async_compile.triton('triton_red_fused_add_eq_mul_native_layer_norm_native_layer_norm_backward_select_unsqueeze_29', ''' | |
| import triton | |
| import triton.language as tl | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.reduction( | |
| size_hints={'x': 524288, 'r0_': 2048}, | |
| reduction_hint=ReductionHint.OUTER, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*u8', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, | |
| inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_add_eq_mul_native_layer_norm_native_layer_norm_backward_select_unsqueeze_29', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 4, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': True, 'tiling_scores': {'x': 4029153280, 'r0_': 1310720}} | |
| ) | |
| @triton.jit | |
| def triton_red_fused_add_eq_mul_native_layer_norm_native_layer_norm_backward_select_unsqueeze_29(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): | |
| # x0 = column (0..2047), x1 = 2048-row chunk (0..159), r0_2 = row within | |
| # the chunk; row-indexed inputs use r0_2 + <rows-per-chunk> * x1 offsets. | |
| xnumel = 327680 | |
| r0_numel = 2048 | |
| rnumel = r0_numel | |
| RBLOCK: tl.constexpr = R0_BLOCK | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
| xmask = tl.full([XBLOCK], True, tl.int1)[:, None] | |
| r0_base = tl.arange(0, R0_BLOCK)[None, :] | |
| rbase = r0_base | |
| x0 = (xindex % 2048) | |
| x1 = xindex // 2048 | |
| _tmp20 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) | |
| x3 = xindex | |
| for r0_offset in tl.range(0, r0_numel, R0_BLOCK): | |
| r0_index = r0_offset + r0_base | |
| r0_mask = r0_index < r0_numel | |
| roffset = r0_offset | |
| rindex = r0_index | |
| r0_2 = r0_index | |
| tmp0 = tl.load(in_ptr0 + (x0 + 2048*r0_2 + 4194304*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
| tmp1 = tl.load(in_ptr1 + (r0_2 + 2048*x1), r0_mask, eviction_policy='evict_last', other=0.0) | |
| tmp9 = tl.load(in_ptr2 + (x0 + 2048*r0_2 + 4194304*x1), r0_mask, eviction_policy='evict_first', other=0.0) | |
| tmp13 = tl.load(in_ptr3 + (784 + 785*r0_2 + 1607680*x1), r0_mask, eviction_policy='evict_last', other=0.0) | |
| # rsqrt(var/2048 + eps) * (1/2048) * sub_114, rounded through bf16 and | |
| # added to the bf16 residual grad (in_ptr0) to rebuild add_205. | |
| tmp2 = tl.full([1, 1], 2048.0, tl.float32) | |
| tmp3 = (tmp1 / tmp2) | |
| tmp4 = tl.full([1, 1], 1e-05, tl.float32) | |
| tmp5 = tmp3 + tmp4 | |
| tmp6 = libdevice.rsqrt(tmp5) | |
| tmp7 = tl.full([1, 1], 0.00048828125, tl.float32) | |
| tmp8 = tmp6 * tmp7 | |
| tmp10 = tmp8 * tmp9 | |
| tmp11 = tmp10.to(tl.float32) | |
| tmp12 = tmp0 + tmp11 | |
| # Keep only rows whose u8 selector equals 2 (eq_1 / unsqueeze_6 mask). | |
| tmp14 = tl.full([1, 1], 2, tl.uint8) | |
| tmp15 = tmp13 == tmp14 | |
| tmp16 = tmp15.to(tl.float32) | |
| tmp17 = tmp12 * tmp16 | |
| tmp18 = tmp17.to(tl.float32) | |
| tmp19 = tl.broadcast_to(tmp18, [XBLOCK, R0_BLOCK]) | |
| tmp21 = _tmp20 + tmp19 | |
| _tmp20 = tl.where(r0_mask, tmp21, _tmp20) | |
| tmp20 = tl.sum(_tmp20, 1)[:, None] | |
| tl.store(out_ptr0 + (x3), tmp20, None) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor/rank0/4f/c4fosyx7zmpcnmjxjiskxj5txwzq4rjvox5xazko4zt622kteyye.py | |
| # Topologically Sorted Source Nodes: [convert_element_type_1061, convert_element_type_1070], Original ATen: [aten._to_copy] | |
| # Source node to ATen node mapping: | |
| # convert_element_type_1061 => convert_element_type_1061 | |
| # convert_element_type_1070 => convert_element_type_1070 | |
| # Graph fragment: | |
| # %buf1558 : Tensor = PlaceHolder[target=buf1558] | |
| # %convert_element_type_1061 : Tensor "f32[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%wait_tensor_146, torch.float32), kwargs = {}) | |
| # %convert_element_type_1070 : Tensor "f32[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%wait_tensor_146, torch.float32), kwargs = {}) | |
| # return %convert_element_type_1061,%convert_element_type_1070 | |
| # Pointwise bf16 -> fp32 upcast of a 2048-element vector, duplicated into two | |
| # output buffers (the fragment above shows two convert_element_type consumers | |
| # of the same wait_tensor result, fused into one kernel with two stores). | |
| triton_poi_fused__to_copy_30 = async_compile.triton('triton_poi_fused__to_copy_30', ''' | |
| import triton | |
| import triton.language as tl | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.pointwise( | |
| size_hints={'x': 2048}, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, | |
| inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_30', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 2, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 32768}}, | |
| min_elem_per_thread=0 | |
| ) | |
| @triton.jit | |
| def triton_poi_fused__to_copy_30(in_ptr0, out_ptr0, out_ptr1, xnumel, XBLOCK : tl.constexpr): | |
| # One flat 1-D grid over 2048 elements; xmask guards the tail block. | |
| xnumel = 2048 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
| xmask = xindex < xnumel | |
| x0 = xindex | |
| tmp0 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32) | |
| tmp1 = tmp0.to(tl.float32) | |
| # Same fp32 value written to both consumers' buffers. | |
| tl.store(out_ptr0 + (x0), tmp1, xmask) | |
| tl.store(out_ptr1 + (x0), tmp1, xmask) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor/rank0/vv/cvvvlepxuobkrbn55fwahnnr5slz6gc3plw3ii5n7qpgwcdewpmm.py | |
| # Topologically Sorted Source Nodes: [convert_element_type_1074], Original ATen: [aten._to_copy] | |
| # Source node to ATen node mapping: | |
| # convert_element_type_1074 => convert_element_type_1074 | |
| # Graph fragment: | |
| # %buf1590 : Tensor = PlaceHolder[target=buf1590] | |
| # %convert_element_type_1074 : Tensor "f32[2048, 768][768, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%wait_tensor_150, torch.float32), kwargs = {}) | |
| # return %convert_element_type_1074 | |
| # Pointwise bf16 -> fp32 upcast of a contiguous 1572864-element buffer | |
| # (2048 x 768 per the fragment above); xnumel divides the launch size, so no | |
| # tail mask is needed. | |
| triton_poi_fused__to_copy_31 = async_compile.triton('triton_poi_fused__to_copy_31', ''' | |
| import triton | |
| import triton.language as tl | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.pointwise( | |
| size_hints={'x': 2097152}, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, | |
| inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_31', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 12582912}}, | |
| min_elem_per_thread=0 | |
| ) | |
| @triton.jit | |
| def triton_poi_fused__to_copy_31(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
| # Flat 1-D copy; mask is constant-true because the grid exactly tiles xnumel. | |
| xnumel = 1572864 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
| xmask = tl.full([XBLOCK], True, tl.int1)[:] | |
| x0 = xindex | |
| tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32) | |
| tmp1 = tmp0.to(tl.float32) | |
| tl.store(out_ptr0 + (x0), tmp1, None) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor/rank0/pe/cpea3exprjpgbfkgsy3hdvwmv3wa542d5oqiyyahnqjznbrvarss.py | |
| # Topologically Sorted Source Nodes: [layer_norm, div_27, mul_469, convert_element_type_1050, add_205, getitem, mask, getitem_1, mul_485, slice_21, clone_112, full_default_1], Original ATen: [aten.native_layer_norm, aten.native_layer_norm_backward, aten.add, aten.select, aten.eq, aten.unsqueeze, aten.mul, aten.slice, aten.clone, aten.slice_backward] | |
| # Source node to ATen node mapping: | |
| # add_205 => add_205 | |
| # clone_112 => clone_112 | |
| # convert_element_type_1050 => convert_element_type_1050 | |
| # div_27 => div_27 | |
| # full_default_1 => full_default_1 | |
| # getitem => select | |
| # getitem_1 => unsqueeze | |
| # layer_norm => add_12, convert_element_type_24, rsqrt_2, var_mean_2 | |
| # mask => eq | |
| # mul_469 => mul_469 | |
| # mul_485 => mul_485 | |
| # slice_21 => slice_21 | |
| # Graph fragment: | |
| # %add_202 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=add_202] | |
| # %buf1473 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=buf1473] | |
| # %sub_114 : Tensor "f32[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=sub_114] | |
| # %primals_1 : Tensor "u8[327680, 785][785, 1]cuda:0" = PlaceHolder[target=primals_1] | |
| # %convert_element_type_24 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_11, torch.float32), kwargs = {}) | |
| # %var_mean_2 : [num_users=2] = call_function[target=torch.ops.aten.var_mean.correction](args = (%convert_element_type_24, [1]), kwargs = {correction: 0, keepdim: True}) | |
| # %add_12 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%getitem_4, 1e-05), kwargs = {}) | |
| # %rsqrt_2 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_12,), kwargs = {}) | |
| # %div_27 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%rsqrt_2, 2048), kwargs = {}) | |
| # %mul_469 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_27, %sub_114), kwargs = {}) | |
| # %convert_element_type_1050 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_469, torch.bfloat16), kwargs = {}) | |
| # %add_205 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_202, %convert_element_type_1050), kwargs = {}) | |
| # %select : Tensor "u8[327680][785]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.select.int](args = (%primals_1, 1, -1), kwargs = {}) | |
| # %eq : Tensor "b8[327680][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%select, 1), kwargs = {}) | |
| # %unsqueeze : Tensor "b8[327680, 1][1, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%eq, 1), kwargs = {}) | |
| # %mul_485 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_205, %unsqueeze), kwargs = {}) | |
| # %slice_21 : Tensor "bf16[327680, 1024][2048, 2]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_485, 1, 1, 9223372036854775807, 2), kwargs = {}) | |
| # %clone_112 : Tensor "bf16[327680, 1024][1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_21,), kwargs = {memory_format: torch.contiguous_format}) | |
| # %full_default_1 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.full.default](args = ([327680, 2048], 0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
| # %slice_scatter_default_1 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_1, %clone_112, 1, 1, 9223372036854775807, 2), kwargs = {}) | |
| # return %slice_scatter_default_1 | |
# NOTE(review): Inductor-generated pointwise Triton kernel (compiler output;
# do not hand-edit the source string — it is compiled/cached by async_compile).
# Per the graph fragment above: over a [327680, 2048] bf16 grid it
#   1. recomputes the layer-norm backward scale rsqrt(var/2048 + 1e-05) * (1/2048)
#      (0.00048828125 == 1/2048) and multiplies it by sub_114,
#   2. adds that (as bf16) to add_202 (in_ptr0),
#   3. masks by (primals_1[:, 784] == 1), i.e. the last column of the u8 input
#      (in_ptr3, row stride 785),
#   4. scatters only the odd columns (slice 1::2) of the result into a
#      zero-initialized output; even columns stay 0 (slice_scatter into
#      full_default_1).
# The stripped original indentation has been restored; tokens are unchanged.
triton_poi_fused_add_clone_eq_mul_native_layer_norm_native_layer_norm_backward_select_slice_slice_backward_unsqueeze_32 = async_compile.triton('triton_poi_fused_add_clone_eq_mul_native_layer_norm_native_layer_norm_backward_select_slice_slice_backward_unsqueeze_32', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
    size_hints={'x': 1073741824},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*u8', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_clone_eq_mul_native_layer_norm_native_layer_norm_backward_select_slice_slice_backward_unsqueeze_32', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 4, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 2684354560}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_clone_eq_mul_native_layer_norm_native_layer_norm_backward_select_slice_slice_backward_unsqueeze_32(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 671088640
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = (xindex % 2048)
    x1 = xindex // 2048
    x2 = xindex
    tmp0 = x0
    tmp1 = tl.full([1], 1, tl.int64)
    tmp2 = tmp0 >= tmp1
    tmp3 = (((-1) + x0) % 2)
    tmp4 = tl.full([1], 0, tl.int64)
    tmp5 = tmp3 == tmp4
    tmp6 = tmp2 & tmp5
    tmp7 = tl.load(in_ptr0 + (1 + 2*(triton_helpers.div_floor_integer((-1) + x0, 2)) + 2048*x1), tmp6, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp8 = tl.load(in_ptr1 + (x1), tmp6, eviction_policy='evict_last', other=0.0)
    tmp9 = tl.full([1], 2048.0, tl.float32)
    tmp10 = (tmp8 / tmp9)
    tmp11 = tl.full([1], 1e-05, tl.float32)
    tmp12 = tmp10 + tmp11
    tmp13 = libdevice.rsqrt(tmp12)
    tmp14 = tl.full([1], 0.00048828125, tl.float32)
    tmp15 = tmp13 * tmp14
    tmp16 = tl.load(in_ptr2 + (1 + 2*(triton_helpers.div_floor_integer((-1) + x0, 2)) + 2048*x1), tmp6, eviction_policy='evict_last', other=0.0)
    tmp17 = tmp15 * tmp16
    tmp18 = tmp17.to(tl.float32)
    tmp19 = tmp7 + tmp18
    tmp20 = tl.load(in_ptr3 + (784 + 785*x1), tmp6, eviction_policy='evict_last', other=0.0)
    tmp21 = tl.full([1], 1, tl.uint8)
    tmp22 = tmp20 == tmp21
    tmp23 = tmp22.to(tl.float32)
    tmp24 = tmp19 * tmp23
    tmp25 = tl.full(tmp24.shape, 0.0, tmp24.dtype)
    tmp26 = tl.where(tmp6, tmp24, tmp25)
    tmp27 = tl.full([1], 0.0, tl.float32)
    tmp28 = tl.where(tmp6, tmp26, tmp27)
    tl.store(out_ptr0 + (x2), tmp28, None)
''', device_str='cuda')
| # kernel path: /tmp/torchinductor/rank0/33/c33os57d226grg6ristuoxyi6krgwdsthxorz2uor6zyzc6wuzww.py | |
| # Topologically Sorted Source Nodes: [add_206, slice_24, clone_113, convert_element_type_1075, mul_486, sum_133], Original ATen: [aten.add, aten.slice, aten.clone, aten._to_copy, aten.mul, aten.sum] | |
| # Source node to ATen node mapping: | |
| # add_206 => add_206 | |
| # clone_113 => clone_113 | |
| # convert_element_type_1075 => convert_element_type_1075 | |
| # mul_486 => mul_486 | |
| # slice_24 => slice_24 | |
| # sum_133 => sum_133 | |
| # Graph fragment: | |
| # %slice_scatter_default : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=slice_scatter_default] | |
| # %slice_scatter_default_1 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=slice_scatter_default_1] | |
| # %cos : Tensor "f32[327680, 1024][1024, 1]cuda:0" = PlaceHolder[target=cos] | |
| # %add_206 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_default, %slice_scatter_default_1), kwargs = {}) | |
| # %slice_24 : Tensor "bf16[327680, 1024][2048, 2]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.slice.Tensor](args = (%add_206, 1, 1, 9223372036854775807, 2), kwargs = {}) | |
| # %clone_113 : Tensor "bf16[327680, 1024][1024, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.clone.default](args = (%slice_24,), kwargs = {memory_format: torch.contiguous_format}) | |
| # %convert_element_type_1075 : Tensor "f32[327680, 1024][1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%clone_113, torch.float32), kwargs = {}) | |
| # %mul_486 : Tensor "f32[327680, 1024][1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_1075, %cos), kwargs = {}) | |
| # %sum_133 : Tensor "f32[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_486,), kwargs = {}) | |
| # return %buf1595 | |
# NOTE(review): Inductor-generated reduction Triton kernel (compiler output;
# do not hand-edit the source string). Stage 1 of a 3-stage global sum:
# for each of 163840 x-slots it reduces 2048 elements, computing
#   sum over r of fp32(bf16(in_ptr0) + bf16(in_ptr1)) * in_ptr2
# where in_ptr0/in_ptr1 are read at odd offsets (1 + 2*r, row stride 4096,
# i.e. the 1::2 slice of add_206 = slice_scatter_default +
# slice_scatter_default_1) and in_ptr2 is the contiguous fp32 `cos` buffer.
# Partial sums land in out_ptr0[163840]; stages 2/3 are the kernels below.
# The stripped original indentation has been restored; tokens are unchanged.
triton_red_fused__to_copy_add_clone_mul_slice_sum_33 = async_compile.triton('triton_red_fused__to_copy_add_clone_mul_slice_sum_33', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
    size_hints={'x': 262144, 'r0_': 2048},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_clone_mul_slice_sum_33', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': True, 'tiling_scores': {'x': 1310720, 'r0_': 1342177280}}
)
@triton.jit
def triton_red_fused__to_copy_add_clone_mul_slice_sum_33(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 163840
    r0_numel = 2048
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = xindex
    _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp0 = tl.load(in_ptr0 + (1 + 2*r0_1 + 4096*x0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_ptr1 + (1 + 2*r0_1 + 4096*x0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp4 = tl.load(in_ptr2 + (r0_1 + 2048*x0), r0_mask, eviction_policy='evict_first', other=0.0)
        tmp2 = tmp0 + tmp1
        tmp3 = tmp2.to(tl.float32)
        tmp5 = tmp3 * tmp4
        tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK])
        tmp8 = _tmp7 + tmp6
        _tmp7 = tl.where(r0_mask, tmp8, _tmp7)
    tmp7 = tl.sum(_tmp7, 1)[:, None]
    tl.store(out_ptr0 + (x0), tmp7, None)
''', device_str='cuda')
| # kernel path: /tmp/torchinductor/rank0/sh/cshwswgpmiqr5iwybcqpwf4nwd2h47twq7yufqtiqf5ytlbgoe4o.py | |
| # Topologically Sorted Source Nodes: [add_206, slice_24, clone_113, convert_element_type_1075, mul_486, sum_133], Original ATen: [aten.add, aten.slice, aten.clone, aten._to_copy, aten.mul, aten.sum] | |
| # Source node to ATen node mapping: | |
| # add_206 => add_206 | |
| # clone_113 => clone_113 | |
| # convert_element_type_1075 => convert_element_type_1075 | |
| # mul_486 => mul_486 | |
| # slice_24 => slice_24 | |
| # sum_133 => sum_133 | |
| # Graph fragment: | |
| # %buf1595 : Tensor "f32[2560, 64][64, 1]cuda:0" = PlaceHolder[target=buf1595] | |
| # %add_206 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_default, %slice_scatter_default_1), kwargs = {}) | |
| # %slice_24 : Tensor "bf16[327680, 1024][2048, 2]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.slice.Tensor](args = (%add_206, 1, 1, 9223372036854775807, 2), kwargs = {}) | |
| # %clone_113 : Tensor "bf16[327680, 1024][1024, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.clone.default](args = (%slice_24,), kwargs = {memory_format: torch.contiguous_format}) | |
| # %convert_element_type_1075 : Tensor "f32[327680, 1024][1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%clone_113, torch.float32), kwargs = {}) | |
| # %mul_486 : Tensor "f32[327680, 1024][1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_1075, %cos), kwargs = {}) | |
| # %sum_133 : Tensor "f32[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_486,), kwargs = {}) | |
| # return %buf1596 | |
# NOTE(review): Inductor-generated persistent-reduction Triton kernel
# (compiler output; do not hand-edit the source string). Stage 2 of the
# 3-stage sum_133 reduction: collapses the 163840 fp32 partials produced by
# the stage-1 kernel into 2560 values by summing contiguous groups of 64
# (out_ptr0[x] = sum(in_ptr0[64*x : 64*x + 64])). The whole 64-wide r-range
# fits in one block, hence persistent_reduction with no r-loop.
# The stripped original indentation has been restored; tokens are unchanged.
triton_per_fused__to_copy_add_clone_mul_slice_sum_34 = async_compile.triton('triton_per_fused__to_copy_add_clone_mul_slice_sum_34', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
    size_hints={'x': 4096, 'r0_': 64},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_clone_mul_slice_sum_34', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': True, 'tiling_scores': {'x': 20480, 'r0_': 655360}}
)
@triton.jit
def triton_per_fused__to_copy_add_clone_mul_slice_sum_34(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr):
    xnumel = 2560
    r0_numel = 64
    R0_BLOCK: tl.constexpr = 64
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)[None, :]
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r0_1 + 64*x0), xmask, other=0.0)
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
    tmp3 = tl.where(xmask, tmp1, 0)
    tmp4 = tl.sum(tmp3, 1)[:, None].to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp4, xmask)
''', device_str='cuda')
| # kernel path: /tmp/torchinductor/rank0/bv/cbvk7epkxauekcwgof5v56sdrua37n7ccsnkbqppguepvf2mfyzm.py | |
| # Topologically Sorted Source Nodes: [add_206, slice_24, clone_113, convert_element_type_1075, mul_486, sum_133, convert_element_type_1076], Original ATen: [aten.add, aten.slice, aten.clone, aten._to_copy, aten.mul, aten.sum] | |
| # Source node to ATen node mapping: | |
| # add_206 => add_206 | |
| # clone_113 => clone_113 | |
| # convert_element_type_1075 => convert_element_type_1075 | |
| # convert_element_type_1076 => convert_element_type_1076 | |
| # mul_486 => mul_486 | |
| # slice_24 => slice_24 | |
| # sum_133 => sum_133 | |
| # Graph fragment: | |
| # %buf1596 : Tensor "f32[2560][1]cuda:0" = PlaceHolder[target=buf1596] | |
| # %sum_133 : Tensor "f32[][]cuda:0" = PlaceHolder[target=sum_133] | |
| # %add_206 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_default, %slice_scatter_default_1), kwargs = {}) | |
| # %slice_24 : Tensor "bf16[327680, 1024][2048, 2]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.slice.Tensor](args = (%add_206, 1, 1, 9223372036854775807, 2), kwargs = {}) | |
| # %clone_113 : Tensor "bf16[327680, 1024][1024, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.clone.default](args = (%slice_24,), kwargs = {memory_format: torch.contiguous_format}) | |
| # %convert_element_type_1075 : Tensor "f32[327680, 1024][1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%clone_113, torch.float32), kwargs = {}) | |
| # %mul_486 : Tensor "f32[327680, 1024][1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_1075, %cos), kwargs = {}) | |
| # %sum_133 : Tensor "f32[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_486,), kwargs = {}) | |
| # %convert_element_type_1076 : Tensor "bf16[][]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_133, torch.bfloat16), kwargs = {}) | |
| # return %sum_133,%wait_tensor_151 | |
# NOTE(review): Inductor-generated reduction Triton kernel (compiler output;
# do not hand-edit the source string). Stage 3 (final) of the sum_133
# reduction: sums the 2560 fp32 partials from stage 2 into a single scalar
# and stores it as bf16 (convert_element_type_1076 in the graph fragment).
# xnumel is a constexpr 1; the single output slot is written via a
# broadcast zero index. Note r0_numel = 2560 here overrides the 4096
# size hint used for autotuning.
# The stripped original indentation has been restored; tokens are unchanged.
triton_red_fused__to_copy_add_clone_mul_slice_sum_35 = async_compile.triton('triton_red_fused__to_copy_add_clone_mul_slice_sum_35', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
    size_hints={'x': 1, 'r0_': 4096},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {'xnumel': 1}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_clone_mul_slice_sum_35', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': True, 'tiling_scores': {'r0_': 10240}}
)
@triton.jit
def triton_red_fused__to_copy_add_clone_mul_slice_sum_35(in_ptr0, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 1
    r0_numel = 2560
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_0 = r0_index
        tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
        tmp3 = _tmp2 + tmp1
        _tmp2 = tl.where(r0_mask, tmp3, _tmp2)
    tmp2 = tl.sum(_tmp2, 1)[:, None]
    tmp4 = tmp2.to(tl.float32)
    tl.store(out_ptr1 + (tl.full([1, 1], 0, tl.int32).broadcast_to(XBLOCK, 1)), tmp4, None)
''', device_str='cuda')
| # kernel path: /tmp/torchinductor/rank0/ul/cul436w24uow7x4gca343floynluerpavkg3l7at6smg35fvjsfe.py | |
| # Topologically Sorted Source Nodes: [full_default, full_default_1, add_206, slice_24, clone_113, copy_5, slice_27, clone_114, copy_7, add_207], Original ATen: [aten.zeros_like, aten.slice_backward, aten.add, aten.slice, aten.clone, aten.copy] | |
| # Source node to ATen node mapping: | |
| # add_206 => add_206 | |
| # add_207 => add_207 | |
| # clone_113 => clone_113 | |
| # clone_114 => clone_114 | |
| # copy_5 => copy_5 | |
| # copy_7 => copy_7 | |
| # full_default => full_default | |
| # full_default_1 => full_default_1 | |
| # slice_24 => slice_24 | |
| # slice_27 => slice_27 | |
| # Graph fragment: | |
| # %slice_scatter_default : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=slice_scatter_default] | |
| # %slice_scatter_default_1 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=slice_scatter_default_1] | |
| # %full_default : Tensor "bf16[327680, 1024][1024, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.full.default](args = ([327680, 1024], 0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
| # %full_default_1 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.full.default](args = ([327680, 2048], 0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
| # %add_206 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_default, %slice_scatter_default_1), kwargs = {}) | |
| # %slice_24 : Tensor "bf16[327680, 1024][2048, 2]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.slice.Tensor](args = (%add_206, 1, 1, 9223372036854775807, 2), kwargs = {}) | |
| # %clone_113 : Tensor "bf16[327680, 1024][1024, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.clone.default](args = (%slice_24,), kwargs = {memory_format: torch.contiguous_format}) | |
| # %copy_5 : Tensor "bf16[327680, 1024][2048, 2]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.copy.default](args = (%slice_24, %clone_113), kwargs = {}) | |
| # %slice_scatter_default_2 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%add_206, %copy_5, 1, 1, 9223372036854775807, 2), kwargs = {}) | |
| # %slice_27 : Tensor "bf16[327680, 1024][2048, 2]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_scatter_default_2, 1, 0, 9223372036854775807, 2), kwargs = {}) | |
| # %clone_114 : Tensor "bf16[327680, 1024][1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_27,), kwargs = {memory_format: torch.contiguous_format}) | |
| # %copy_7 : Tensor "bf16[327680, 1024][2048, 2]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.copy.default](args = (%slice_27, %full_default), kwargs = {}) | |
| # %slice_scatter_default_3 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%slice_scatter_default_2, %copy_7, 1, 0, 9223372036854775807, 2), kwargs = {}) | |
| # %slice_scatter_default_4 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_1, %clone_114, 1, 0, 9223372036854775807, 2), kwargs = {}) | |
| # %add_207 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_default_3, %slice_scatter_default_4), kwargs = {}) | |
| # return %add_207 | |
# NOTE(review): Inductor-generated pointwise Triton kernel (compiler output;
# do not hand-edit the source string). Computes add_207 from the graph
# fragment above by fusing a chain of slice_scatter ops on
# add_206 = in_ptr0 + in_ptr1 ([327680, 2048] bf16):
#   - odd output columns (x0 >= 1 and (x0-1) % 2 == 0) receive the
#     corresponding odd-column value of add_206 (tmp12+tmp13 path);
#   - even output columns are first zeroed (copy of full_default, tmp3/tmp5)
#     and then get the slice_scatter_default_4 contribution, which re-reads
#     add_206 at offset (-1 + 2*(x0//2)) for even x0 >= 2 (tmp30+tmp31 path),
#     i.e. the preceding odd column; column 0 of each row stays 0.
# The two tl.where chains mirror slice_scatter_default_3 + _4 being added.
# The stripped original indentation has been restored; tokens are unchanged.
triton_poi_fused_add_clone_copy_slice_slice_backward_zeros_like_36 = async_compile.triton('triton_poi_fused_add_clone_copy_slice_slice_backward_zeros_like_36', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
    size_hints={'x': 1073741824},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_clone_copy_slice_slice_backward_zeros_like_36', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 8, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 10737418240}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_clone_copy_slice_slice_backward_zeros_like_36(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 671088640
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x2 = xindex
    x0 = (xindex % 2048)
    x1 = xindex // 2048
    tmp17 = tl.load(in_ptr0 + (x2), None).to(tl.float32)
    tmp18 = tl.load(in_ptr1 + (x2), None).to(tl.float32)
    tmp0 = (x2 % 2)
    tmp1 = tl.full([1], 0, tl.int64)
    tmp2 = tmp0 == tmp1
    tmp3 = tl.full([1], 0.0, tl.float32)
    tmp4 = tl.full(tmp3.shape, 0.0, tmp3.dtype)
    tmp5 = tl.where(tmp2, tmp3, tmp4)
    tmp6 = x0
    tmp7 = tl.full([1], 1, tl.int64)
    tmp8 = tmp6 >= tmp7
    tmp9 = (((-1) + x0) % 2)
    tmp10 = tmp9 == tmp1
    tmp11 = tmp8 & tmp10
    tmp12 = tl.load(in_ptr0 + (1 + 2*(triton_helpers.div_floor_integer((-1) + x0, 2)) + 2048*x1), tmp11, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp13 = tl.load(in_ptr1 + (1 + 2*(triton_helpers.div_floor_integer((-1) + x0, 2)) + 2048*x1), tmp11, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp14 = tmp12 + tmp13
    tmp15 = tl.full(tmp14.shape, 0.0, tmp14.dtype)
    tmp16 = tl.where(tmp11, tmp14, tmp15)
    tmp19 = tmp17 + tmp18
    tmp20 = tl.where(tmp11, tmp16, tmp19)
    tmp21 = tl.where(tmp2, tmp5, tmp20)
    tmp22 = 2*(x0 // 2)
    tmp23 = tl.full([1], 1, tl.int64)
    tmp24 = tmp22 >= tmp23
    tmp25 = (((-1) + 2*(x0 // 2)) % 2)
    tmp26 = tl.full([1], 0, tl.int64)
    tmp27 = tmp25 == tmp26
    tmp28 = tmp24 & tmp27
    tmp29 = tmp28 & tmp2
    tmp30 = tl.load(in_ptr0 + ((-1) + 2*(x0 // 2) + 2048*x1), tmp29, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp31 = tl.load(in_ptr1 + ((-1) + 2*(x0 // 2) + 2048*x1), tmp29, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp32 = tmp30 + tmp31
    tmp33 = tl.full(tmp32.shape, 0.0, tmp32.dtype)
    tmp34 = tl.where(tmp29, tmp32, tmp33)
    tmp35 = tl.load(in_ptr0 + (2*(x0 // 2) + 2048*x1), tmp2, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp36 = tl.load(in_ptr1 + (2*(x0 // 2) + 2048*x1), tmp2, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp37 = tmp35 + tmp36
    tmp38 = tl.where(tmp28, tmp34, tmp37)
    tmp39 = tl.full(tmp38.shape, 0.0, tmp38.dtype)
    tmp40 = tl.where(tmp2, tmp38, tmp39)
    tmp41 = tl.full([1], 0.0, tl.float32)
    tmp42 = tl.where(tmp2, tmp40, tmp41)
    tmp43 = tmp21 + tmp42
    tl.store(out_ptr0 + (x2), tmp43, None)
''', device_str='cuda')
| # kernel path: /tmp/torchinductor/rank0/fo/cfoz6dfisizdj2zmeupbzqtnjgko6ew4rh2fhiswleuzftetokpz.py | |
| # Topologically Sorted Source Nodes: [getitem, mask, getitem_1, data, getitem_4, contiguous_1], Original ATen: [aten.select, aten.eq, aten.unsqueeze, aten.mul, aten.slice, aten.clone] | |
| # Source node to ATen node mapping: | |
| # contiguous_1 => clone_1 | |
| # data => mul | |
| # getitem => select | |
| # getitem_1 => unsqueeze | |
| # getitem_4 => slice_2 | |
| # mask => eq | |
| # Graph fragment: | |
| # %primals_1 : Tensor "u8[327680, 785][785, 1]cuda:0" = PlaceHolder[target=primals_1] | |
| # %select : Tensor "u8[327680][785]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.select.int](args = (%primals_1, 1, -1), kwargs = {}) | |
| # %eq : Tensor "b8[327680][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%select, 1), kwargs = {}) | |
| # %unsqueeze : Tensor "b8[327680, 1][1, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%eq, 1), kwargs = {}) | |
| # %mul : Tensor "u8[327680, 785][785, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_1, %unsqueeze), kwargs = {}) | |
| # %slice_2 : Tensor "u8[327680, 4][785, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul, 1, 8, 12), kwargs = {}) | |
| # %clone_1 : Tensor "u8[327680, 4][4, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_2,), kwargs = {memory_format: torch.contiguous_format}) | |
| # return %clone_1 | |
| triton_poi_fused_clone_eq_mul_select_slice_unsqueeze_37 = async_compile.triton('triton_poi_fused_clone_eq_mul_select_slice_unsqueeze_37', ''' | |
| import triton | |
| import triton.language as tl | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.pointwise( | |
| size_hints={'x': 2097152}, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*u8', 'out_ptr0': '*u8', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, | |
| inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_eq_mul_select_slice_unsqueeze_37', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 2, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 3932160}}, | |
| min_elem_per_thread=0 | |
| ) | |
| @triton.jit | |
| def triton_poi_fused_clone_eq_mul_select_slice_unsqueeze_37(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
| xnumel = 1310720 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
| xmask = tl.full([XBLOCK], True, tl.int1)[:] | |
| x0 = (xindex % 4) | |
| x1 = xindex // 4 | |
| x2 = xindex | |
| tmp0 = tl.load(in_ptr0 + (8 + x0 + 785*x1), None) | |
| tmp1 = tl.load(in_ptr0 + (784 + 785*x1), None, eviction_policy='evict_last') | |
| tmp2 = tl.full([1], 1, tl.uint8) | |
| tmp3 = tmp1 == tmp2 | |
| tmp4 = tmp3.to(tl.uint8) | |
| tmp5 = tmp0 * tmp4 | |
| tl.store(out_ptr0 + (x2), tmp5, None) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor/rank0/wm/cwmzns2lkf7z7qgwcvjva2ofngj6wsanegij4ylfdjsl2sy4hspo.py | |
| # Topologically Sorted Source Nodes: [slice_30, clone_115, convert_element_type_1078, positions, ifreqs, ifreqs_1, neg, freqs, getitem_7, getitem_8, mul_1, sin, mul_487, sum_134], Original ATen: [aten.slice, aten.clone, aten._to_copy, aten.select, aten.arange, aten.div, aten.neg, aten.pow, aten.unsqueeze, aten.mul, aten.sin, aten.sum] | |
| # Source node to ATen node mapping: | |
| # clone_115 => clone_115 | |
| # convert_element_type_1078 => convert_element_type_1078 | |
| # freqs => pow_1 | |
| # getitem_7 => unsqueeze_1 | |
| # getitem_8 => unsqueeze_2 | |
| # ifreqs => add, convert_element_type_2, iota, mul_1 | |
| # ifreqs_1 => div | |
| # mul_1 => mul_2 | |
| # mul_487 => mul_487 | |
| # neg => neg | |
| # positions => select_2 | |
| # sin => sin | |
| # slice_30 => slice_30 | |
| # sum_134 => sum_134 | |
| # Graph fragment: | |
| # %add_207 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=add_207] | |
| # %view_1 : Tensor "i32[327680, 1][1, 1]cuda:0" = PlaceHolder[target=view_1] | |
| # %slice_30 : Tensor "bf16[327680, 1024][2048, 2]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.slice.Tensor](args = (%add_207, 1, 0, 9223372036854775807, 2), kwargs = {}) | |
| # %clone_115 : Tensor "bf16[327680, 1024][1024, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.clone.default](args = (%slice_30,), kwargs = {memory_format: torch.contiguous_format}) | |
| # %convert_element_type_1078 : Tensor "f32[327680, 1024][1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%clone_115, torch.float32), kwargs = {}) | |
| # %select_2 : Tensor "i32[327680][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%view_1, 1, 0), kwargs = {}) | |
| # %iota : Tensor "i64[1024][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1024,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:0, requires_grad: False}) | |
| # %mul_1 : Tensor "i64[1024][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%iota, 2), kwargs = {}) | |
| # %add : Tensor "i64[1024][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_1, 0), kwargs = {}) | |
| # %convert_element_type_2 : Tensor "f32[1024][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add, torch.float32), kwargs = {}) | |
| # %div : Tensor "f32[1024][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%convert_element_type_2, 2048), kwargs = {}) | |
| # %neg : Tensor "f32[1024][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%div,), kwargs = {}) | |
| # %pow_1 : Tensor "f32[1024][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Scalar](args = (10000.0, %neg), kwargs = {}) | |
| # %unsqueeze_1 : Tensor "i32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%select_2, 1), kwargs = {}) | |
| # %unsqueeze_2 : Tensor "f32[1, 1024][1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%pow_1, 0), kwargs = {}) | |
| # %mul_2 : Tensor "f32[327680, 1024][1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%unsqueeze_1, %unsqueeze_2), kwargs = {}) | |
| # %sin : Tensor "f32[327680, 1024][1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sin.default](args = (%mul_2,), kwargs = {}) | |
| # %mul_487 : Tensor "f32[327680, 1024][1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_1078, %sin), kwargs = {}) | |
| # %sum_134 : Tensor "f32[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_487,), kwargs = {}) | |
| # return %buf1608 | |
| triton_red_fused__to_copy_arange_clone_div_mul_neg_pow_select_sin_slice_sum_unsqueeze_38 = async_compile.triton('triton_red_fused__to_copy_arange_clone_div_mul_neg_pow_select_sin_slice_sum_unsqueeze_38', ''' | |
| import triton | |
| import triton.language as tl | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.reduction( | |
| size_hints={'x': 262144, 'r0_': 2048}, | |
| reduction_hint=ReductionHint.INNER, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, | |
| inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_clone_div_mul_neg_pow_select_sin_slice_sum_unsqueeze_38', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 2, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 1310720, 'r0_': 1310720}} | |
| ) | |
| @triton.jit | |
| def triton_red_fused__to_copy_arange_clone_div_mul_neg_pow_select_sin_slice_sum_unsqueeze_38(in_ptr0, in_ptr1, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): | |
| xnumel = 163840 | |
| r0_numel = 2048 | |
| rnumel = r0_numel | |
| RBLOCK: tl.constexpr = R0_BLOCK | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
| xmask = tl.full([XBLOCK], True, tl.int1)[:, None] | |
| r0_base = tl.arange(0, R0_BLOCK)[None, :] | |
| rbase = r0_base | |
| x0 = xindex | |
| _tmp15 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) | |
| for r0_offset in tl.range(0, r0_numel, R0_BLOCK): | |
| r0_index = r0_offset + r0_base | |
| r0_mask = r0_index < r0_numel | |
| roffset = r0_offset | |
| rindex = r0_index | |
| r0_1 = r0_index | |
| tmp0 = tl.load(in_ptr0 + (2*r0_1 + 4096*x0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
| tmp2 = tl.load(in_ptr1 + (2*x0 + (r0_1 // 1024)), r0_mask, eviction_policy='evict_last', other=0.0) | |
| tmp1 = tmp0.to(tl.float32) | |
| tmp3 = tmp2.to(tl.float32) | |
| tmp4 = 2*((r0_1 % 1024)) | |
| tmp5 = tmp4.to(tl.float32) | |
| tmp6 = tl.full([1, 1], 0.00048828125, tl.float32) | |
| tmp7 = tmp5 * tmp6 | |
| tmp8 = -tmp7 | |
| tmp9 = tl.full([1, 1], 10000.0, tl.float32) | |
| tmp10 = libdevice.pow(tmp9, tmp8) | |
| tmp11 = tmp3 * tmp10 | |
| tmp12 = tl_math.sin(tmp11) | |
| tmp13 = tmp1 * tmp12 | |
| tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK]) | |
| tmp16 = _tmp15 + tmp14 | |
| _tmp15 = tl.where(r0_mask, tmp16, _tmp15) | |
| tmp15 = tl.sum(_tmp15, 1)[:, None] | |
| tl.store(out_ptr0 + (x0), tmp15, None) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor/rank0/hy/chy7u7jsfy2qsmcd33o6b7i45qzyyoeearxatlqrn2wgkrjafsap.py | |
| # Topologically Sorted Source Nodes: [full_default_5], Original ATen: [aten.embedding_dense_backward] | |
| # Source node to ATen node mapping: | |
| # full_default_5 => full_default_5 | |
| # Graph fragment: | |
| # %full_default_5 : Tensor "f32[259, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([259, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
| # return %index_put | |
| triton_poi_fused_embedding_dense_backward_39 = async_compile.triton('triton_poi_fused_embedding_dense_backward_39', ''' | |
| import triton | |
| import triton.language as tl | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.pointwise( | |
| size_hints={'x': 1048576}, | |
| filename=__file__, | |
| triton_meta={'signature': {'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, | |
| inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_embedding_dense_backward_39', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 0, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 4243456}}, | |
| min_elem_per_thread=0 | |
| ) | |
| @triton.jit | |
| def triton_poi_fused_embedding_dense_backward_39(out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
| xnumel = 530432 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
| xmask = xindex < xnumel | |
| x0 = xindex | |
| tmp0 = tl.full([1], 0.0, tl.float32) | |
| tl.store(out_ptr0 + (x0), tmp0, xmask) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor/rank0/wr/cwrdkxk3wqg66t3ipy6nagxwmmip3afro46w3ozyntss2wr52t3a.py | |
| # Topologically Sorted Source Nodes: [slice_30, clone_115, copy_9, convert_element_type_1081, eq_2, unsqueeze_16, full_default_4, where], Original ATen: [aten.slice, aten.clone, aten.copy, aten.embedding_dense_backward] | |
| # Source node to ATen node mapping: | |
| # clone_115 => clone_115 | |
| # convert_element_type_1081 => convert_element_type_1081 | |
| # copy_9 => copy_9 | |
| # eq_2 => eq_2 | |
| # full_default_4 => full_default_4 | |
| # slice_30 => slice_30 | |
| # unsqueeze_16 => unsqueeze_16 | |
| # where => where | |
| # Graph fragment: | |
| # %select_1 : Tensor "i64[327680][1]cuda:0" = PlaceHolder[target=select_1] | |
| # %add_207 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=add_207] | |
| # %slice_30 : Tensor "bf16[327680, 1024][2048, 2]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.slice.Tensor](args = (%add_207, 1, 0, 9223372036854775807, 2), kwargs = {}) | |
| # %clone_115 : Tensor "bf16[327680, 1024][1024, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.clone.default](args = (%slice_30,), kwargs = {memory_format: torch.contiguous_format}) | |
| # %copy_9 : Tensor "bf16[327680, 1024][2048, 2]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.copy.default](args = (%slice_30, %clone_115), kwargs = {}) | |
| # %slice_scatter_default_5 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%add_207, %copy_9, 1, 0, 9223372036854775807, 2), kwargs = {}) | |
| # %convert_element_type_1081 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%slice_scatter_default_5, torch.float32), kwargs = {}) | |
| # %eq_2 : Tensor "b8[327680][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%select_1, -1), kwargs = {}) | |
| # %unsqueeze_16 : Tensor "b8[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%eq_2, -1), kwargs = {}) | |
| # %full_default_4 : Tensor "f32[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
| # %where : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%unsqueeze_16, %full_default_4, %convert_element_type_1081), kwargs = {}) | |
| # return %where | |
| triton_poi_fused_clone_copy_embedding_dense_backward_slice_40 = async_compile.triton('triton_poi_fused_clone_copy_embedding_dense_backward_slice_40', ''' | |
| import triton | |
| import triton.language as tl | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.pointwise( | |
| size_hints={'x': 1073741824}, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, | |
| inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_copy_embedding_dense_backward_slice_40', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 8053063680}}, | |
| min_elem_per_thread=0 | |
| ) | |
| @triton.jit | |
| def triton_poi_fused_clone_copy_embedding_dense_backward_slice_40(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
| xnumel = 671088640 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
| xmask = tl.full([XBLOCK], True, tl.int1)[:] | |
| x1 = xindex // 2048 | |
| x2 = xindex | |
| x0 = (xindex % 2048) | |
| tmp0 = tl.load(in_ptr0 + (x1), None, eviction_policy='evict_last') | |
| tmp7 = tl.load(in_ptr1 + (x2), None).to(tl.float32) | |
| tmp1 = tl.full([1], -1, tl.int64) | |
| tmp2 = tmp0 == tmp1 | |
| tmp3 = (x2 % 2) | |
| tmp4 = tl.full([1], 0, tl.int64) | |
| tmp5 = tmp3 == tmp4 | |
| tmp6 = tl.load(in_ptr1 + (2*(x0 // 2) + 2048*x1), tmp5, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
| tmp8 = tl.where(tmp5, tmp6, tmp7) | |
| tmp9 = tmp8.to(tl.float32) | |
| tmp10 = tl.full([1], 0.0, tl.float32) | |
| tmp11 = tl.where(tmp2, tmp10, tmp9) | |
| tl.store(out_ptr0 + (x2), tmp11, None) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor/rank0/rd/crdcbvdrllgakbwzs7hop6q5msvy2l25yyur5n73bke6t4apfi67.py | |
| # Topologically Sorted Source Nodes: [convert_element_type_1082, all_reduce_153], Original ATen: [aten.embedding_dense_backward, _c10d_functional.all_reduce] | |
| # Source node to ATen node mapping: | |
| # all_reduce_153 => all_reduce_153 | |
| # convert_element_type_1082 => convert_element_type_1082 | |
| # Graph fragment: | |
| # %buf1618 : Tensor = PlaceHolder[target=buf1618] | |
| # %convert_element_type_1082 : Tensor "bf16[259, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%index_put, torch.bfloat16), kwargs = {}) | |
| # %all_reduce_153 : Tensor "bf16[259, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops._c10d_functional.all_reduce_.default](args = (%convert_element_type_1082, avg, 0), kwargs = {}) | |
| # return %wait_tensor_153 | |
| triton_poi_fused_all_reduce_embedding_dense_backward_41 = async_compile.triton('triton_poi_fused_all_reduce_embedding_dense_backward_41', ''' | |
| import triton | |
| import triton.language as tl | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.pointwise( | |
| size_hints={'x': 1048576}, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, | |
| inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_all_reduce_embedding_dense_backward_41', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 2121728}}, | |
| min_elem_per_thread=0 | |
| ) | |
| @triton.jit | |
| def triton_poi_fused_all_reduce_embedding_dense_backward_41(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
| xnumel = 530432 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
| xmask = xindex < xnumel | |
| x0 = xindex | |
| tmp0 = tl.load(in_ptr0 + (x0), xmask) | |
| tmp1 = tmp0.to(tl.float32) | |
| tl.store(out_ptr0 + (x0), tmp1, xmask) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor/rank0/wq/cwqyv2tz7qk2wcxmq773ashkr4z7appc4qvtyxetah43gcu357k2.py | |
| # Topologically Sorted Source Nodes: [convert_element_type_1083], Original ATen: [aten._to_copy] | |
| # Source node to ATen node mapping: | |
| # convert_element_type_1083 => convert_element_type_1083 | |
| # Graph fragment: | |
| # %buf1623 : Tensor = PlaceHolder[target=buf1623] | |
| # %convert_element_type_1083 : Tensor "f32[259, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%wait_tensor_153, torch.float32), kwargs = {}) | |
| # return %convert_element_type_1083 | |
| triton_poi_fused__to_copy_42 = async_compile.triton('triton_poi_fused__to_copy_42', ''' | |
| import triton | |
| import triton.language as tl | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.pointwise( | |
| size_hints={'x': 1048576}, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, | |
| inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_42', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 4243456}}, | |
| min_elem_per_thread=0 | |
| ) | |
| @triton.jit | |
| def triton_poi_fused__to_copy_42(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
| xnumel = 530432 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
| xmask = xindex < xnumel | |
| x0 = xindex | |
| tmp0 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32) | |
| tmp1 = tmp0.to(tl.float32) | |
| tl.store(out_ptr0 + (x0), tmp1, xmask) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor/rank0/nj/cnjfakvke3hak7tptij6sluwkwua32gs4pbkbabtcnn6lqzhdh6d.py | |
| # Topologically Sorted Source Nodes: [convert_element_type_1077, convert_element_type_1080, add_209], Original ATen: [aten._to_copy, aten.add] | |
| # Source node to ATen node mapping: | |
| # add_209 => add_209 | |
| # convert_element_type_1077 => convert_element_type_1077 | |
| # convert_element_type_1080 => convert_element_type_1080 | |
| # Graph fragment: | |
| # %buf1602 : Tensor = PlaceHolder[target=buf1602] | |
| # %buf1615 : Tensor = PlaceHolder[target=buf1615] | |
| # %convert_element_type_1077 : Tensor "f32[][]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%wait_tensor_151, torch.float32), kwargs = {}) | |
| # %convert_element_type_1080 : Tensor "f32[][]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%wait_tensor_152, torch.float32), kwargs = {}) | |
| # %add_209 : Tensor "f32[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%convert_element_type_1077, %convert_element_type_1080), kwargs = {}) | |
| # return %add_209 | |
| triton_poi_fused__to_copy_add_43 = async_compile.triton('triton_poi_fused__to_copy_add_43', ''' | |
| import triton | |
| import triton.language as tl | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.pointwise( | |
| size_hints={'x': 1}, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {'xnumel': 1}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, | |
| inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_43', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 2, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False}, | |
| min_elem_per_thread=0 | |
| ) | |
| @triton.jit | |
| def triton_poi_fused__to_copy_add_43(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
| xnumel = 1 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
| xmask = tl.full([XBLOCK], True, tl.int1)[:] | |
| tmp0 = tl.load(in_ptr0 + (0)).to(tl.float32) | |
| tmp1 = tl.broadcast_to(tmp0, [XBLOCK]) | |
| tmp3 = tl.load(in_ptr1 + (0)).to(tl.float32) | |
| tmp4 = tl.broadcast_to(tmp3, [XBLOCK]) | |
| tmp2 = tmp1.to(tl.float32) | |
| tmp5 = tmp4.to(tl.float32) | |
| tmp6 = tmp2 + tmp5 | |
| tl.store(out_ptr0 + (tl.full([XBLOCK], 0, tl.int32).broadcast_to(XBLOCK)), tmp6, None) | |
| ''', device_str='cuda') | |
| async_compile.wait(globals()) | |
| del async_compile | |
| class Runner: | |
| def __init__(self, partitions): | |
| self.partitions = partitions | |
| def recursively_apply_fns(self, fns): | |
| new_callables = [] | |
| for fn, c in zip(fns, self.partitions): | |
| new_callables.append(fn(c)) | |
| self.partitions = new_callables | |
| def call(self, args): | |
| primals_16, primals_20, primals_23, primals_26, primals_1, primals_4, primals_5, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_17, primals_18, primals_19, primals_21, primals_22, primals_24, primals_25, primals_27, primals_28, primals_29, primals_30, primals_31, primals_32, primals_33, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_47, primals_48, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, primals_59, primals_60, primals_61, primals_62, primals_63, primals_64, primals_65, primals_66, primals_67, primals_68, primals_69, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_78, primals_79, primals_80, primals_81, primals_83, primals_84, primals_85, primals_86, primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_95, primals_96, primals_97, primals_98, primals_99, primals_100, primals_101, primals_102, primals_103, primals_104, primals_105, primals_107, primals_108, primals_109, primals_110, primals_111, primals_112, primals_113, primals_114, primals_115, primals_116, primals_117, primals_119, primals_120, primals_121, primals_122, primals_123, primals_124, primals_125, primals_126, primals_127, primals_128, primals_129, primals_131, primals_132, primals_133, primals_134, primals_135, primals_136, primals_137, primals_138, primals_139, primals_140, primals_141, primals_143, primals_144, primals_145, primals_146, primals_147, primals_148, primals_149, primals_150, primals_151, primals_152, primals_153, primals_155, primals_156, primals_157, primals_158, primals_159, primals_160, primals_161, primals_162, primals_163, primals_164, primals_165, primals_167, select_1, cos, add_11, add_14, add_19, add_22, add_27, add_30, add_35, add_38, add_43, add_46, add_51, add_54, add_59, add_62, add_67, 
add_70, add_75, add_78, add_83, add_86, add_91, add_94, add_99, add_102, add_107, getitem_89, rsqrt_26, tangents_1 = args | |
| args.clear() | |
| s91 = primals_16 | |
| s6 = primals_20 | |
| s16 = primals_23 | |
| s18 = primals_26 | |
| assert_size_stride(primals_1, (327680, 785), (785, 1)) | |
| assert_size_stride(primals_4, (2048, 768), (768, 1)) | |
| assert_size_stride(primals_5, (2048, ), (1, )) | |
| assert_size_stride(primals_9, (2048, ), (1, )) | |
| assert_size_stride(primals_10, (2048, ), (1, )) | |
| assert_size_stride(primals_11, (2048, 2048), (2048, 1)) | |
| assert_size_stride(primals_12, (512, 2048), (2048, 1)) | |
| assert_size_stride(primals_13, (512, 2048), (2048, 1)) | |
| assert_size_stride(primals_14, (327680, ), (1, )) | |
| assert_size_stride(primals_15, (327680, ), (1, )) | |
| assert_size_stride(primals_17, (1, 1, 2560, s91), (2560*s91, 2560*s91, s91, 1)) | |
| assert_size_stride(primals_18, (1, 1, 2560), (2560, 2560, 1)) | |
| assert_size_stride(primals_19, (1, 1, 2560), (2560, 2560, 1)) | |
| assert_size_stride(primals_21, (1, 1, 2560, s6), (2560*s6, 2560*s6, s6, 1)) | |
| assert_size_stride(primals_22, (1, 1, 2560), (2560, 2560, 1)) | |
| assert_size_stride(primals_24, (1, 1, 2560, s16), (2560*s16, 2560*s16, s16, 1)) | |
| assert_size_stride(primals_25, (1, 1, 2560), (2560, 2560, 1)) | |
| assert_size_stride(primals_27, (1, 1, 2560, s18), (2560*s18, 2560*s18, s18, 1)) | |
| assert_size_stride(primals_28, (2048, 2048), (2048, 1)) | |
| assert_size_stride(primals_29, (2048, ), (1, )) | |
| assert_size_stride(primals_30, (2048, ), (1, )) | |
| assert_size_stride(primals_31, (8192, 2048), (2048, 1)) | |
| assert_size_stride(primals_32, (8192, ), (1, )) | |
| assert_size_stride(primals_33, (2048, 8192), (8192, 1)) | |
| assert_size_stride(primals_35, (2048, ), (1, )) | |
| assert_size_stride(primals_36, (2048, ), (1, )) | |
| assert_size_stride(primals_37, (2048, 2048), (2048, 1)) | |
| assert_size_stride(primals_38, (512, 2048), (2048, 1)) | |
| assert_size_stride(primals_39, (512, 2048), (2048, 1)) | |
| assert_size_stride(primals_40, (2048, 2048), (2048, 1)) | |
| assert_size_stride(primals_41, (2048, ), (1, )) | |
| assert_size_stride(primals_42, (2048, ), (1, )) | |
| assert_size_stride(primals_43, (8192, 2048), (2048, 1)) | |
| assert_size_stride(primals_44, (8192, ), (1, )) | |
| assert_size_stride(primals_45, (2048, 8192), (8192, 1)) | |
| assert_size_stride(primals_47, (2048, ), (1, )) | |
| assert_size_stride(primals_48, (2048, ), (1, )) | |
| assert_size_stride(primals_49, (2048, 2048), (2048, 1)) | |
| assert_size_stride(primals_50, (512, 2048), (2048, 1)) | |
| assert_size_stride(primals_51, (512, 2048), (2048, 1)) | |
| assert_size_stride(primals_52, (2048, 2048), (2048, 1)) | |
| assert_size_stride(primals_53, (2048, ), (1, )) | |
| assert_size_stride(primals_54, (2048, ), (1, )) | |
| assert_size_stride(primals_55, (8192, 2048), (2048, 1)) | |
| assert_size_stride(primals_56, (8192, ), (1, )) | |
| assert_size_stride(primals_57, (2048, 8192), (8192, 1)) | |
| assert_size_stride(primals_59, (2048, ), (1, )) | |
| assert_size_stride(primals_60, (2048, ), (1, )) | |
| assert_size_stride(primals_61, (2048, 2048), (2048, 1)) | |
| assert_size_stride(primals_62, (512, 2048), (2048, 1)) | |
| assert_size_stride(primals_63, (512, 2048), (2048, 1)) | |
| assert_size_stride(primals_64, (2048, 2048), (2048, 1)) | |
| assert_size_stride(primals_65, (2048, ), (1, )) | |
| assert_size_stride(primals_66, (2048, ), (1, )) | |
| assert_size_stride(primals_67, (8192, 2048), (2048, 1)) | |
| assert_size_stride(primals_68, (8192, ), (1, )) | |
| assert_size_stride(primals_69, (2048, 8192), (8192, 1)) | |
| assert_size_stride(primals_71, (2048, ), (1, )) | |
| assert_size_stride(primals_72, (2048, ), (1, )) | |
| assert_size_stride(primals_73, (2048, 2048), (2048, 1)) | |
| assert_size_stride(primals_74, (512, 2048), (2048, 1)) | |
| assert_size_stride(primals_75, (512, 2048), (2048, 1)) | |
| assert_size_stride(primals_76, (2048, 2048), (2048, 1)) | |
| assert_size_stride(primals_77, (2048, ), (1, )) | |
| assert_size_stride(primals_78, (2048, ), (1, )) | |
| assert_size_stride(primals_79, (8192, 2048), (2048, 1)) | |
| assert_size_stride(primals_80, (8192, ), (1, )) | |
| assert_size_stride(primals_81, (2048, 8192), (8192, 1)) | |
| assert_size_stride(primals_83, (2048, ), (1, )) | |
| assert_size_stride(primals_84, (2048, ), (1, )) | |
| assert_size_stride(primals_85, (2048, 2048), (2048, 1)) | |
| assert_size_stride(primals_86, (512, 2048), (2048, 1)) | |
| assert_size_stride(primals_87, (512, 2048), (2048, 1)) | |
| assert_size_stride(primals_88, (2048, 2048), (2048, 1)) | |
| assert_size_stride(primals_89, (2048, ), (1, )) | |
| assert_size_stride(primals_90, (2048, ), (1, )) | |
| assert_size_stride(primals_91, (8192, 2048), (2048, 1)) | |
| assert_size_stride(primals_92, (8192, ), (1, )) | |
| assert_size_stride(primals_93, (2048, 8192), (8192, 1)) | |
| assert_size_stride(primals_95, (2048, ), (1, )) | |
| assert_size_stride(primals_96, (2048, ), (1, )) | |
| assert_size_stride(primals_97, (2048, 2048), (2048, 1)) | |
| assert_size_stride(primals_98, (512, 2048), (2048, 1)) | |
| assert_size_stride(primals_99, (512, 2048), (2048, 1)) | |
| assert_size_stride(primals_100, (2048, 2048), (2048, 1)) | |
| assert_size_stride(primals_101, (2048, ), (1, )) | |
| assert_size_stride(primals_102, (2048, ), (1, )) | |
| assert_size_stride(primals_103, (8192, 2048), (2048, 1)) | |
| assert_size_stride(primals_104, (8192, ), (1, )) | |
| assert_size_stride(primals_105, (2048, 8192), (8192, 1)) | |
| assert_size_stride(primals_107, (2048, ), (1, )) | |
| assert_size_stride(primals_108, (2048, ), (1, )) | |
| assert_size_stride(primals_109, (2048, 2048), (2048, 1)) | |
| assert_size_stride(primals_110, (512, 2048), (2048, 1)) | |
| assert_size_stride(primals_111, (512, 2048), (2048, 1)) | |
| assert_size_stride(primals_112, (2048, 2048), (2048, 1)) | |
| assert_size_stride(primals_113, (2048, ), (1, )) | |
| assert_size_stride(primals_114, (2048, ), (1, )) | |
| assert_size_stride(primals_115, (8192, 2048), (2048, 1)) | |
| assert_size_stride(primals_116, (8192, ), (1, )) | |
| assert_size_stride(primals_117, (2048, 8192), (8192, 1)) | |
| assert_size_stride(primals_119, (2048, ), (1, )) | |
| assert_size_stride(primals_120, (2048, ), (1, )) | |
| assert_size_stride(primals_121, (2048, 2048), (2048, 1)) | |
| assert_size_stride(primals_122, (512, 2048), (2048, 1)) | |
| assert_size_stride(primals_123, (512, 2048), (2048, 1)) | |
| assert_size_stride(primals_124, (2048, 2048), (2048, 1)) | |
| assert_size_stride(primals_125, (2048, ), (1, )) | |
| assert_size_stride(primals_126, (2048, ), (1, )) | |
| assert_size_stride(primals_127, (8192, 2048), (2048, 1)) | |
| assert_size_stride(primals_128, (8192, ), (1, )) | |
| assert_size_stride(primals_129, (2048, 8192), (8192, 1)) | |
| assert_size_stride(primals_131, (2048, ), (1, )) | |
| assert_size_stride(primals_132, (2048, ), (1, )) | |
| assert_size_stride(primals_133, (2048, 2048), (2048, 1)) | |
| assert_size_stride(primals_134, (512, 2048), (2048, 1)) | |
| assert_size_stride(primals_135, (512, 2048), (2048, 1)) | |
| assert_size_stride(primals_136, (2048, 2048), (2048, 1)) | |
| assert_size_stride(primals_137, (2048, ), (1, )) | |
| assert_size_stride(primals_138, (2048, ), (1, )) | |
| assert_size_stride(primals_139, (8192, 2048), (2048, 1)) | |
| assert_size_stride(primals_140, (8192, ), (1, )) | |
| assert_size_stride(primals_141, (2048, 8192), (8192, 1)) | |
| assert_size_stride(primals_143, (2048, ), (1, )) | |
| assert_size_stride(primals_144, (2048, ), (1, )) | |
| assert_size_stride(primals_145, (2048, 2048), (2048, 1)) | |
| assert_size_stride(primals_146, (512, 2048), (2048, 1)) | |
| assert_size_stride(primals_147, (512, 2048), (2048, 1)) | |
| assert_size_stride(primals_148, (2048, 2048), (2048, 1)) | |
| assert_size_stride(primals_149, (2048, ), (1, )) | |
| assert_size_stride(primals_150, (2048, ), (1, )) | |
| assert_size_stride(primals_151, (8192, 2048), (2048, 1)) | |
| assert_size_stride(primals_152, (8192, ), (1, )) | |
| assert_size_stride(primals_153, (2048, 8192), (8192, 1)) | |
| assert_size_stride(primals_155, (2048, ), (1, )) | |
| assert_size_stride(primals_156, (2048, ), (1, )) | |
| assert_size_stride(primals_157, (2048, 2048), (2048, 1)) | |
| assert_size_stride(primals_158, (512, 2048), (2048, 1)) | |
| assert_size_stride(primals_159, (512, 2048), (2048, 1)) | |
| assert_size_stride(primals_160, (2048, 2048), (2048, 1)) | |
| assert_size_stride(primals_161, (2048, ), (1, )) | |
| assert_size_stride(primals_162, (2048, ), (1, )) | |
| assert_size_stride(primals_163, (8192, 2048), (2048, 1)) | |
| assert_size_stride(primals_164, (8192, ), (1, )) | |
| assert_size_stride(primals_165, (2048, 8192), (8192, 1)) | |
| assert_size_stride(primals_167, (2048, ), (1, )) | |
| assert_size_stride(select_1, (327680, ), (1, )) | |
| assert_size_stride(cos, (327680, 1024), (1024, 1)) | |
| assert_size_stride(add_11, (327680, 2048), (2048, 1)) | |
| assert_size_stride(add_14, (327680, 2048), (2048, 1)) | |
| assert_size_stride(add_19, (327680, 2048), (2048, 1)) | |
| assert_size_stride(add_22, (327680, 2048), (2048, 1)) | |
| assert_size_stride(add_27, (327680, 2048), (2048, 1)) | |
| assert_size_stride(add_30, (327680, 2048), (2048, 1)) | |
| assert_size_stride(add_35, (327680, 2048), (2048, 1)) | |
| assert_size_stride(add_38, (327680, 2048), (2048, 1)) | |
| assert_size_stride(add_43, (327680, 2048), (2048, 1)) | |
| assert_size_stride(add_46, (327680, 2048), (2048, 1)) | |
| assert_size_stride(add_51, (327680, 2048), (2048, 1)) | |
| assert_size_stride(add_54, (327680, 2048), (2048, 1)) | |
| assert_size_stride(add_59, (327680, 2048), (2048, 1)) | |
| assert_size_stride(add_62, (327680, 2048), (2048, 1)) | |
| assert_size_stride(add_67, (327680, 2048), (2048, 1)) | |
| assert_size_stride(add_70, (327680, 2048), (2048, 1)) | |
| assert_size_stride(add_75, (327680, 2048), (2048, 1)) | |
| assert_size_stride(add_78, (327680, 2048), (2048, 1)) | |
| assert_size_stride(add_83, (327680, 2048), (2048, 1)) | |
| assert_size_stride(add_86, (327680, 2048), (2048, 1)) | |
| assert_size_stride(add_91, (327680, 2048), (2048, 1)) | |
| assert_size_stride(add_94, (327680, 2048), (2048, 1)) | |
| assert_size_stride(add_99, (327680, 2048), (2048, 1)) | |
| assert_size_stride(add_102, (327680, 2048), (2048, 1)) | |
| assert_size_stride(add_107, (327680, 2048), (2048, 1)) | |
| assert_size_stride(getitem_89, (327680, 1), (1, 1)) | |
| assert_size_stride(rsqrt_26, (327680, 1), (1, 1)) | |
| assert_size_stride(tangents_1, (327680, 2048), (2048, 1)) | |
| with torch.cuda._DeviceGuard(0): | |
| torch.cuda.set_device(0) | |
| buf6 = empty_strided_cuda((327680, 2048), (2048, 1), torch.bfloat16) | |
| # Call mix order reduction kernel | |
| # Topologically Sorted Source Nodes: [convert_element_type_410, redistribute_4, convert_element_type_412, mul_153, mul_154, sum_1, x_26, mul_155, sum_2, mul_156, sub_29, sub_30, div_3, mul_157, mul_158, sum_3, sum_4, convert_element_type_414], Original ATen: [aten.native_layer_norm_backward, aten._to_copy, aten.native_layer_norm] | |
| workspace_0 = empty_strided_cuda((10485760, ), (1, ), torch.float32) | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_native_layer_norm_native_layer_norm_backward_0.run(tangents_1, primals_167, add_107, getitem_89, rsqrt_26, buf6, workspace_0, 327680, 2048, stream=stream0) | |
| buf3 = workspace_0[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| buf5 = workspace_0[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| del workspace_0 | |
| del add_107 | |
| del getitem_89 | |
| del primals_167 | |
| del rsqrt_26 | |
| del tangents_1 | |
| buf13 = empty_strided_cuda((2048, ), (1, ), torch.bfloat16) | |
| # Topologically Sorted Source Nodes: [convert_element_type_415, all_reduce_1], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf3, buf13, 2048, stream=stream0) | |
| buf7 = empty_strided_cuda((2048, ), (1, ), torch.bfloat16) | |
| # Topologically Sorted Source Nodes: [convert_element_type_416, all_reduce], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf5, buf7, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_416, all_reduce], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf7, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf7) | |
| buf12 = buf5; del buf5 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_417], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf7, buf12, 2048, stream=stream0) | |
| del buf7 | |
| # Topologically Sorted Source Nodes: [convert_element_type_415, all_reduce_1], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf13, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_1], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf13) | |
| buf18 = buf3; del buf3 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_418], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf13, buf18, 2048, stream=stream0) | |
| del buf13 | |
| buf19 = empty_strided_cuda((2048, 8192), (8192, 1), torch.bfloat16) | |
| # Topologically Sorted Source Nodes: [redistribute_4], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_3.run(primals_165, buf19, 16777216, stream=stream0) | |
| del primals_165 | |
| buf20 = empty_strided_cuda((327680, 8192), (8192, 1), torch.bfloat16) | |
| # Topologically Sorted Source Nodes: [redistribute_4, x_26, permute_121, mm_49], Original ATen: [aten._to_copy, aten.t, aten.mm] | |
| extern_kernels.mm(buf6, buf19, out=buf20) | |
| buf21 = empty_strided_cuda((327680, 1), (1, 327680), torch.float32) | |
| buf22 = empty_strided_cuda((327680, 1), (1, 327680), torch.float32) | |
| buf24 = empty_strided_cuda((327680, 2048), (2048, 1), torch.bfloat16) | |
| # Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused__to_copy_native_layer_norm_4.run(add_102, primals_161, primals_162, buf21, buf22, buf24, 327680, 2048, stream=stream0) | |
| del primals_162 | |
| buf25 = reinterpret_tensor(buf19, (2048, 8192), (1, 2048), 0); del buf19 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_2, x], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_3.run(primals_163, buf25, 16777216, stream=stream0) | |
| del primals_163 | |
| buf26 = empty_strided_cuda((327680, 8192), (8192, 1), torch.bfloat16) | |
| # Topologically Sorted Source Nodes: [x], Original ATen: [aten.addmm] | |
| extern_kernels.mm(buf24, buf25, out=buf26) | |
| buf27 = empty_strided_cuda((327680, 8192), (8192, 1), torch.bfloat16) | |
| buf41 = buf20; del buf20 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_3, x, x_25, convert_element_type_425, mul_164, mul_165, sub_31, mul_166, add_112, mul_167, mul_168, mul_169, add_113, mul_170, convert_element_type_427], Original ATen: [aten._to_copy, aten.addmm, aten.gelu, aten.gelu_backward] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_addmm_gelu_gelu_backward_5.run(buf41, primals_164, buf26, buf27, 2684354560, stream=stream0) | |
| del buf26 | |
| del primals_164 | |
| buf28 = empty_strided_cuda((2048, 8192), (8192, 1), torch.bfloat16) | |
| # Topologically Sorted Source Nodes: [permute_122, redistribute_3, x, x_25, mm_50], Original ATen: [aten.t, aten._to_copy, aten.addmm, aten.gelu, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf6, (2048, 327680), (1, 2048), 0), buf27, out=buf28) | |
| buf29 = empty_strided_cuda((1, 2048, 160), (327680, 1, 2048), torch.float32) | |
| # Topologically Sorted Source Nodes: [sum_5], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_6.run(buf6, buf29, 327680, 2048, stream=stream0) | |
| buf30 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16) | |
| # Topologically Sorted Source Nodes: [sum_5], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_7.run(buf29, buf30, 2048, 160, stream=stream0) | |
| # Topologically Sorted Source Nodes: [view_279, all_reduce_2], Original ATen: [aten.view, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf30, (2048, ), (1, ), 0), 'avg', '0') | |
| # Topologically Sorted Source Nodes: [view_279, wait_tensor_2], Original ATen: [aten.view, _c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf30, (2048, ), (1, ), 0)) | |
| buf35 = empty_strided_cuda((2048, ), (1, ), torch.float32) | |
| # Topologically Sorted Source Nodes: [view_279, convert_element_type_423], Original ATen: [aten.view, aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf30, buf35, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [all_reduce_3], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf28, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_3], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf28) | |
| buf40 = empty_strided_cuda((2048, 8192), (8192, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_424], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_8.run(buf28, buf40, 16777216, stream=stream0) | |
| buf42 = empty_strided_cuda((327680, 2048), (2048, 1), torch.bfloat16) | |
| # Topologically Sorted Source Nodes: [permute_125, mm_51], Original ATen: [aten.t, aten.mm] | |
| extern_kernels.mm(buf41, reinterpret_tensor(buf25, (8192, 2048), (2048, 1), 0), out=buf42) | |
| buf43 = reinterpret_tensor(buf25, (8192, 2048), (2048, 1), 0); del buf25 # reuse | |
| # Topologically Sorted Source Nodes: [permute_126, mm_52], Original ATen: [aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf41, (8192, 327680), (1, 8192), 0), buf24, out=buf43) | |
| buf44 = empty_strided_cuda((1, 8192, 160), (1310720, 1, 8192), torch.float32) | |
| # Topologically Sorted Source Nodes: [sum_6], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_9.run(buf41, buf44, 1310720, 2048, stream=stream0) | |
| buf45 = empty_strided_cuda((1, 8192), (8192, 1), torch.bfloat16) | |
| # Topologically Sorted Source Nodes: [sum_6], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_10.run(buf44, buf45, 8192, 160, stream=stream0) | |
| # Topologically Sorted Source Nodes: [view_280, all_reduce_4], Original ATen: [aten.view, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf45, (8192, ), (1, ), 0), 'avg', '0') | |
| # Topologically Sorted Source Nodes: [view_280, wait_tensor_4], Original ATen: [aten.view, _c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf45, (8192, ), (1, ), 0)) | |
| buf50 = empty_strided_cuda((8192, ), (1, ), torch.float32) | |
| # Topologically Sorted Source Nodes: [view_280, convert_element_type_432], Original ATen: [aten.view, aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_view_11.run(buf45, buf50, 8192, stream=stream0) | |
| # Topologically Sorted Source Nodes: [all_reduce_5], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf43, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_5], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf43) | |
| buf55 = empty_strided_cuda((8192, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_433], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_8.run(buf43, buf55, 16777216, stream=stream0) | |
| buf62 = buf6; del buf6 # reuse | |
| # Call mix order reduction kernel | |
| # Topologically Sorted Source Nodes: [redistribute, layer_norm, convert_element_type_434, convert_element_type_436, mul_172, mul_173, sum_7, mul_174, sum_8, mul_175, sub_33, sub_34, div_4, mul_176, mul_177, sum_9, sum_10, convert_element_type_438, add_114], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.native_layer_norm_backward, aten.add] | |
| workspace_1 = empty_strided_cuda((10485760, ), (1, ), torch.float32) | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_12.run(buf62, buf42, primals_161, add_102, buf21, buf22, workspace_1, 327680, 2048, stream=stream0) | |
| buf59 = workspace_1[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| buf61 = workspace_1[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| del add_102 | |
| del primals_161 | |
| buf69 = reinterpret_tensor(buf30, (2048, ), (1, ), 0); del buf30 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_439, all_reduce_7], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf59, buf69, 2048, stream=stream0) | |
| buf63 = empty_strided_cuda((2048, ), (1, ), torch.bfloat16) | |
| # Topologically Sorted Source Nodes: [convert_element_type_440, all_reduce_6], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf61, buf63, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_440, all_reduce_6], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf63, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_6], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf63) | |
| buf68 = buf61; del buf61 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_441], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf63, buf68, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_439, all_reduce_7], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf69, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_7], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf69) | |
| buf74 = buf59; del buf59 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_442], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf69, buf74, 2048, stream=stream0) | |
| buf75 = buf22; del buf22 # reuse | |
| buf76 = buf21; del buf21 # reuse | |
| buf78 = buf42; del buf42 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused__to_copy_native_layer_norm_4.run(add_99, primals_155, primals_156, buf75, buf76, buf78, 327680, 2048, stream=stream0) | |
| del primals_156 | |
| buf79 = empty_strided_cuda((2048, 2048), (1, 2048), torch.bfloat16) | |
| # Topologically Sorted Source Nodes: [redistribute_2, linear], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_13.run(primals_157, buf79, 4194304, stream=stream0) | |
| del primals_157 | |
| buf80 = buf24; del buf24 # reuse | |
| # Topologically Sorted Source Nodes: [linear], Original ATen: [aten.mm] | |
| extern_kernels.mm(buf78, buf79, out=buf80) | |
| buf81 = empty_strided_cuda((2048, 512), (1, 2048), torch.bfloat16) | |
| # Topologically Sorted Source Nodes: [redistribute_3, linear_1], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_14.run(primals_158, buf81, 1048576, stream=stream0) | |
| del primals_158 | |
| buf82 = empty_strided_cuda((327680, 512), (512, 1), torch.bfloat16) | |
| # Topologically Sorted Source Nodes: [linear_1], Original ATen: [aten.mm] | |
| extern_kernels.mm(buf78, buf81, out=buf82) | |
| buf83 = empty_strided_cuda((2048, 512), (1, 2048), torch.bfloat16) | |
| # Topologically Sorted Source Nodes: [redistribute_4, linear_2], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_14.run(primals_159, buf83, 1048576, stream=stream0) | |
| del primals_159 | |
| buf84 = empty_strided_cuda((327680, 512), (512, 1), torch.bfloat16) | |
| # Topologically Sorted Source Nodes: [linear_2], Original ATen: [aten.mm] | |
| extern_kernels.mm(buf78, buf83, out=buf84) | |
| buf85 = empty_strided_cuda((1, 16, 327680), (5242880, 327680, 1), torch.float32) | |
| buf86 = empty_strided_cuda((1, 16, 327680), (5242880, 327680, 1), torch.float32) | |
| buf87 = empty_strided_cuda((1, 16, 327680, 128), (671088640, 128, 2048, 1), torch.bfloat16) | |
| # Topologically Sorted Source Nodes: [q, k, v, flex_attention], Original ATen: [aten.view, aten.permute, flex_attention] | |
| stream0 = get_raw_stream(0) | |
| triton_tem_fused_flex_attention_permute_view_15.run(buf80, buf82, buf84, buf85, buf86, primals_18, primals_17, primals_19, primals_21, primals_14, primals_15, buf87, s91, 2560, 1, 16, stream=stream0) | |
| buf90 = empty_strided_cuda((2048, 2048), (2048, 1), torch.bfloat16) | |
| # Topologically Sorted Source Nodes: [permute_129, rearrange_3, o, mm_53], Original ATen: [aten.t, aten.permute, aten.view, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf62, (2048, 327680), (1, 2048), 0), reinterpret_tensor(buf87, (327680, 2048), (2048, 1), 0), out=buf90) | |
| buf91 = empty_strided_cuda((2048, 2048), (2048, 1), torch.bfloat16) | |
| # Topologically Sorted Source Nodes: [redistribute_5], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_13.run(primals_160, buf91, 4194304, stream=stream0) | |
| del primals_160 | |
| buf92 = empty_strided_cuda((327680, 2048), (2048, 1), torch.bfloat16) | |
| # Topologically Sorted Source Nodes: [redistribute_5, o, permute_131, mm_54], Original ATen: [aten._to_copy, aten.t, aten.mm] | |
| extern_kernels.mm(buf62, buf91, out=buf92) | |
| # Topologically Sorted Source Nodes: [all_reduce_8], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf90, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_8], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf90) | |
| buf97 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_447], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_16.run(buf90, buf97, 4194304, stream=stream0) | |
| buf99 = buf86; del buf86 # reuse | |
| # Topologically Sorted Source Nodes: [q, k, v, view_283, view_284, permute_133, flex_attention_backward], Original ATen: [aten.view, aten.permute, flex_attention_backward] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused_flex_attention_backward_permute_view_17.run(buf87, buf92, buf99, 5242880, 128, stream=stream0) | |
| buf100 = buf87; del buf87 # reuse | |
| buf101 = empty_strided_cuda((1, 4, 327680, 128), (167772160, 128, 512, 1), torch.bfloat16) | |
| buf102 = empty_strided_cuda((1, 4, 327680, 128), (167772160, 128, 512, 1), torch.bfloat16) | |
| # Topologically Sorted Source Nodes: [q, k, v, view_283, view_284, permute_133, flex_attention_backward], Original ATen: [aten.view, aten.permute, flex_attention_backward] | |
| stream0 = get_raw_stream(0) | |
| triton_tem_fused_flex_attention_backward_permute_view_18.run(buf80, buf82, buf84, buf85, buf99, buf92, buf100, buf101, primals_18, primals_17, primals_22, primals_24, primals_19, primals_21, primals_25, primals_27, primals_14, primals_15, buf102, s91, s16, 12800, 1, 4, stream=stream0) | |
| del buf82 | |
| buf105 = empty_strided_cuda((512, 2048), (2048, 1), torch.bfloat16) | |
| # Topologically Sorted Source Nodes: [view_285, permute_134, view_286, permute_135, mm_55], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf101, (512, 327680), (1, 512), 0), buf78, out=buf105) | |
| buf106 = buf92; del buf92 # reuse | |
| # Topologically Sorted Source Nodes: [view_285, permute_134, view_286, permute_137, mm_56], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf101, (327680, 512), (512, 1), 0), reinterpret_tensor(buf83, (512, 2048), (2048, 1), 0), out=buf106) | |
| # Topologically Sorted Source Nodes: [all_reduce_9], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf105, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_9], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf105) | |
| buf111 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_452], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_19.run(buf105, buf111, 1048576, stream=stream0) | |
| buf112 = buf105; del buf105 # reuse | |
| # Topologically Sorted Source Nodes: [view_287, permute_139, view_288, permute_140, mm_57], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf102, (512, 327680), (1, 512), 0), buf78, out=buf112) | |
| buf113 = buf80; del buf80 # reuse | |
| # Topologically Sorted Source Nodes: [view_287, permute_139, view_288, permute_142, mm_58], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf102, (327680, 512), (512, 1), 0), reinterpret_tensor(buf81, (512, 2048), (2048, 1), 0), out=buf113) | |
| # Topologically Sorted Source Nodes: [all_reduce_10], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf112, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_10], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf112) | |
| buf118 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_457], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_19.run(buf112, buf118, 1048576, stream=stream0) | |
| buf119 = buf90; del buf90 # reuse | |
| # Topologically Sorted Source Nodes: [view_289, permute_144, view_290, permute_145, mm_59], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf100, (2048, 327680), (1, 2048), 0), buf78, out=buf119) | |
| buf120 = buf78; del buf78 # reuse | |
| # Topologically Sorted Source Nodes: [view_289, permute_144, view_290, permute_147, mm_60], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf100, (327680, 2048), (2048, 1), 0), reinterpret_tensor(buf79, (2048, 2048), (2048, 1), 0), out=buf120) | |
| del buf100 | |
| # Topologically Sorted Source Nodes: [all_reduce_11], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf119, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_11], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf119) | |
| buf125 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_462], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_16.run(buf119, buf125, 4194304, stream=stream0) | |
| buf133 = buf62; del buf62 # reuse | |
| # Call mix order reduction kernel | |
| # Topologically Sorted Source Nodes: [redistribute, layer_norm, add_115, add_116, convert_element_type_463, convert_element_type_465, mul_179, mul_180, sum_11, mul_181, sum_12, mul_182, sub_36, sub_37, div_5, mul_183, mul_184, sum_13, sum_14, convert_element_type_467, add_117], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.add, aten.native_layer_norm_backward] | |
| workspace_2 = workspace_1; del workspace_1 # reuse | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_20.run(buf133, buf106, buf113, buf120, primals_155, add_99, buf75, buf76, workspace_2, 327680, 2048, stream=stream0) | |
| buf130 = workspace_2[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| buf132 = workspace_2[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| del add_99 | |
| del buf106 | |
| del primals_155 | |
| buf140 = buf69; del buf69 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_468, all_reduce_13], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf130, buf140, 2048, stream=stream0) | |
| buf134 = buf63; del buf63 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_469, all_reduce_12], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf132, buf134, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_469, all_reduce_12], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf134, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_12], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf134) | |
| buf139 = buf132; del buf132 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_470], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf134, buf139, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_468, all_reduce_13], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf140, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_13], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf140) | |
| buf145 = buf130; del buf130 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_471], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf140, buf145, 2048, stream=stream0) | |
| buf146 = reinterpret_tensor(buf43, (2048, 8192), (8192, 1), 0); del buf43 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_4], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_3.run(primals_153, buf146, 16777216, stream=stream0) | |
| del primals_153 | |
| buf147 = buf41; del buf41 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_4, x_24, permute_149, mm_61], Original ATen: [aten._to_copy, aten.t, aten.mm] | |
| extern_kernels.mm(buf133, buf146, out=buf147) | |
| buf148 = buf76; del buf76 # reuse | |
| buf149 = buf75; del buf75 # reuse | |
| buf151 = buf120; del buf120 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused__to_copy_native_layer_norm_4.run(add_94, primals_149, primals_150, buf148, buf149, buf151, 327680, 2048, stream=stream0) | |
| del primals_150 | |
| buf152 = reinterpret_tensor(buf146, (2048, 8192), (1, 2048), 0); del buf146 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_2, x], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_3.run(primals_151, buf152, 16777216, stream=stream0) | |
| del primals_151 | |
| buf153 = buf27; del buf27 # reuse | |
| # Topologically Sorted Source Nodes: [x], Original ATen: [aten.addmm] | |
| extern_kernels.mm(buf151, buf152, out=buf153) | |
| buf154 = empty_strided_cuda((327680, 8192), (8192, 1), torch.bfloat16) | |
| buf168 = buf147; del buf147 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_3, x, x_23, convert_element_type_478, mul_190, mul_191, sub_38, mul_192, add_120, mul_193, mul_194, mul_195, add_121, mul_196, convert_element_type_480], Original ATen: [aten._to_copy, aten.addmm, aten.gelu, aten.gelu_backward] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_addmm_gelu_gelu_backward_5.run(buf168, primals_152, buf153, buf154, 2684354560, stream=stream0) | |
| del primals_152 | |
| buf155 = buf28; del buf28 # reuse | |
| # Topologically Sorted Source Nodes: [permute_150, redistribute_3, x, x_23, mm_62], Original ATen: [aten.t, aten._to_copy, aten.addmm, aten.gelu, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf133, (2048, 327680), (1, 2048), 0), buf154, out=buf155) | |
| buf156 = buf29; del buf29 # reuse | |
| # Topologically Sorted Source Nodes: [sum_15], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_6.run(buf133, buf156, 327680, 2048, stream=stream0) | |
| buf157 = reinterpret_tensor(buf140, (1, 2048), (2048, 1), 0); del buf140 # reuse | |
| # Topologically Sorted Source Nodes: [sum_15], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_7.run(buf156, buf157, 2048, 160, stream=stream0) | |
| # Topologically Sorted Source Nodes: [view_291, all_reduce_14], Original ATen: [aten.view, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf157, (2048, ), (1, ), 0), 'avg', '0') | |
| # Topologically Sorted Source Nodes: [view_291, wait_tensor_14], Original ATen: [aten.view, _c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf157, (2048, ), (1, ), 0)) | |
| buf162 = empty_strided_cuda((2048, ), (1, ), torch.float32) | |
| # Topologically Sorted Source Nodes: [view_291, convert_element_type_476], Original ATen: [aten.view, aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf157, buf162, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [all_reduce_15], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf155, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_15], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf155) | |
| buf167 = empty_strided_cuda((2048, 8192), (8192, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_477], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_8.run(buf155, buf167, 16777216, stream=stream0) | |
| buf169 = buf113; del buf113 # reuse | |
| # Topologically Sorted Source Nodes: [permute_153, mm_63], Original ATen: [aten.t, aten.mm] | |
| extern_kernels.mm(buf168, reinterpret_tensor(buf152, (8192, 2048), (2048, 1), 0), out=buf169) | |
| buf170 = reinterpret_tensor(buf152, (8192, 2048), (2048, 1), 0); del buf152 # reuse | |
| # Topologically Sorted Source Nodes: [permute_154, mm_64], Original ATen: [aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf168, (8192, 327680), (1, 8192), 0), buf151, out=buf170) | |
| buf171 = buf44; del buf44 # reuse | |
| # Topologically Sorted Source Nodes: [sum_16], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_9.run(buf168, buf171, 1310720, 2048, stream=stream0) | |
| buf172 = buf45; del buf45 # reuse | |
| # Topologically Sorted Source Nodes: [sum_16], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_10.run(buf171, buf172, 8192, 160, stream=stream0) | |
| # Topologically Sorted Source Nodes: [view_292, all_reduce_16], Original ATen: [aten.view, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf172, (8192, ), (1, ), 0), 'avg', '0') | |
| # Topologically Sorted Source Nodes: [view_292, wait_tensor_16], Original ATen: [aten.view, _c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf172, (8192, ), (1, ), 0)) | |
| buf177 = empty_strided_cuda((8192, ), (1, ), torch.float32) | |
| # Topologically Sorted Source Nodes: [view_292, convert_element_type_485], Original ATen: [aten.view, aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_view_11.run(buf172, buf177, 8192, stream=stream0) | |
| # Topologically Sorted Source Nodes: [all_reduce_17], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf170, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_17], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf170) | |
| buf182 = empty_strided_cuda((8192, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_486], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_8.run(buf170, buf182, 16777216, stream=stream0) | |
| buf189 = buf133; del buf133 # reuse | |
| # Call mix order reduction kernel | |
| # Topologically Sorted Source Nodes: [redistribute, layer_norm, convert_element_type_487, convert_element_type_489, mul_198, mul_199, sum_17, mul_200, sum_18, mul_201, sub_40, sub_41, div_6, mul_202, mul_203, sum_19, sum_20, convert_element_type_491, add_122], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.native_layer_norm_backward, aten.add] | |
| workspace_3 = workspace_2; del workspace_2 # reuse | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_12.run(buf189, buf169, primals_149, add_94, buf148, buf149, workspace_3, 327680, 2048, stream=stream0) | |
| buf186 = workspace_3[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| buf188 = workspace_3[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| del add_94 | |
| del primals_149 | |
| buf196 = reinterpret_tensor(buf157, (2048, ), (1, ), 0); del buf157 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_492, all_reduce_19], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf186, buf196, 2048, stream=stream0) | |
| buf190 = buf134; del buf134 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_493, all_reduce_18], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf188, buf190, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_493, all_reduce_18], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf190, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_18], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf190) | |
| buf195 = buf188; del buf188 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_494], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf190, buf195, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_492, all_reduce_19], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf196, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_19], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf196) | |
| buf201 = buf186; del buf186 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_495], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf196, buf201, 2048, stream=stream0) | |
| buf202 = buf149; del buf149 # reuse | |
| buf203 = buf148; del buf148 # reuse | |
| buf205 = buf169; del buf169 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused__to_copy_native_layer_norm_4.run(add_91, primals_143, primals_144, buf202, buf203, buf205, 327680, 2048, stream=stream0) | |
| del primals_144 | |
| buf206 = reinterpret_tensor(buf119, (2048, 2048), (1, 2048), 0); del buf119 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_2, linear], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_13.run(primals_145, buf206, 4194304, stream=stream0) | |
| del primals_145 | |
| buf207 = buf151; del buf151 # reuse | |
| # Topologically Sorted Source Nodes: [linear], Original ATen: [aten.mm] | |
| extern_kernels.mm(buf205, buf206, out=buf207) | |
| buf208 = reinterpret_tensor(buf112, (2048, 512), (1, 2048), 0); del buf112 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_3, linear_1], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_14.run(primals_146, buf208, 1048576, stream=stream0) | |
| del primals_146 | |
| buf209 = reinterpret_tensor(buf102, (327680, 512), (512, 1), 0); del buf102 # reuse | |
| # Topologically Sorted Source Nodes: [linear_1], Original ATen: [aten.mm] | |
| extern_kernels.mm(buf205, buf208, out=buf209) | |
| buf210 = buf81; del buf81 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_4, linear_2], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_14.run(primals_147, buf210, 1048576, stream=stream0) | |
| del primals_147 | |
| buf211 = reinterpret_tensor(buf101, (327680, 512), (512, 1), 0); del buf101 # reuse | |
| # Topologically Sorted Source Nodes: [linear_2], Original ATen: [aten.mm] | |
| extern_kernels.mm(buf205, buf210, out=buf211) | |
| buf212 = buf99; del buf99 # reuse | |
| buf213 = buf85; del buf85 # reuse | |
| buf214 = empty_strided_cuda((1, 16, 327680, 128), (671088640, 128, 2048, 1), torch.bfloat16) | |
| # Topologically Sorted Source Nodes: [q, k, v, flex_attention], Original ATen: [aten.view, aten.permute, flex_attention] | |
| stream0 = get_raw_stream(0) | |
| triton_tem_fused_flex_attention_permute_view_15.run(buf207, buf209, buf211, buf212, buf213, primals_18, primals_17, primals_19, primals_21, primals_14, primals_15, buf214, s91, 2560, 1, 16, stream=stream0) | |
| buf217 = reinterpret_tensor(buf79, (2048, 2048), (2048, 1), 0); del buf79 # reuse | |
| # Topologically Sorted Source Nodes: [permute_157, rearrange_3, o, mm_65], Original ATen: [aten.t, aten.permute, aten.view, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf189, (2048, 327680), (1, 2048), 0), reinterpret_tensor(buf214, (327680, 2048), (2048, 1), 0), out=buf217) | |
| buf218 = buf91; del buf91 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_5], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_13.run(primals_148, buf218, 4194304, stream=stream0) | |
| del primals_148 | |
| buf219 = empty_strided_cuda((327680, 2048), (2048, 1), torch.bfloat16) | |
| # Topologically Sorted Source Nodes: [redistribute_5, o, permute_159, mm_66], Original ATen: [aten._to_copy, aten.t, aten.mm] | |
| extern_kernels.mm(buf189, buf218, out=buf219) | |
| # Topologically Sorted Source Nodes: [all_reduce_20], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf217, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_20], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf217) | |
| buf224 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_500], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_16.run(buf217, buf224, 4194304, stream=stream0) | |
| buf226 = buf213; del buf213 # reuse | |
| # Topologically Sorted Source Nodes: [q, k, v, view_295, view_296, permute_161, flex_attention_backward_1], Original ATen: [aten.view, aten.permute, flex_attention_backward] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused_flex_attention_backward_permute_view_17.run(buf214, buf219, buf226, 5242880, 128, stream=stream0) | |
| buf227 = buf214; del buf214 # reuse | |
| buf228 = reinterpret_tensor(buf84, (1, 4, 327680, 128), (167772160, 128, 512, 1), 0); del buf84 # reuse | |
| buf229 = empty_strided_cuda((1, 4, 327680, 128), (167772160, 128, 512, 1), torch.bfloat16) | |
| # Topologically Sorted Source Nodes: [q, k, v, view_295, view_296, permute_161, flex_attention_backward_1], Original ATen: [aten.view, aten.permute, flex_attention_backward] | |
| stream0 = get_raw_stream(0) | |
| triton_tem_fused_flex_attention_backward_permute_view_18.run(buf207, buf209, buf211, buf212, buf226, buf219, buf227, buf228, primals_18, primals_17, primals_22, primals_24, primals_19, primals_21, primals_25, primals_27, primals_14, primals_15, buf229, s91, s16, 12800, 1, 4, stream=stream0) | |
| del buf209 | |
| del buf211 | |
| buf232 = reinterpret_tensor(buf83, (512, 2048), (2048, 1), 0); del buf83 # reuse | |
| # Topologically Sorted Source Nodes: [view_297, permute_162, view_298, permute_163, mm_67], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf228, (512, 327680), (1, 512), 0), buf205, out=buf232) | |
| buf233 = buf219; del buf219 # reuse | |
| # Topologically Sorted Source Nodes: [view_297, permute_162, view_298, permute_165, mm_68], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf228, (327680, 512), (512, 1), 0), reinterpret_tensor(buf210, (512, 2048), (2048, 1), 0), out=buf233) | |
| # Topologically Sorted Source Nodes: [all_reduce_21], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf232, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_21], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf232) | |
| buf238 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_505], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_19.run(buf232, buf238, 1048576, stream=stream0) | |
| buf239 = buf232; del buf232 # reuse | |
| # Topologically Sorted Source Nodes: [view_299, permute_167, view_300, permute_168, mm_69], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf229, (512, 327680), (1, 512), 0), buf205, out=buf239) | |
| buf240 = buf207; del buf207 # reuse | |
| # Topologically Sorted Source Nodes: [view_299, permute_167, view_300, permute_170, mm_70], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf229, (327680, 512), (512, 1), 0), reinterpret_tensor(buf208, (512, 2048), (2048, 1), 0), out=buf240) | |
| # Topologically Sorted Source Nodes: [all_reduce_22], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf239, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_22], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf239) | |
| buf245 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_510], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_19.run(buf239, buf245, 1048576, stream=stream0) | |
| buf246 = buf217; del buf217 # reuse | |
| # Topologically Sorted Source Nodes: [view_301, permute_172, view_302, permute_173, mm_71], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf227, (2048, 327680), (1, 2048), 0), buf205, out=buf246) | |
| buf247 = buf205; del buf205 # reuse | |
| # Topologically Sorted Source Nodes: [view_301, permute_172, view_302, permute_175, mm_72], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf227, (327680, 2048), (2048, 1), 0), reinterpret_tensor(buf206, (2048, 2048), (2048, 1), 0), out=buf247) | |
| # Topologically Sorted Source Nodes: [all_reduce_23], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf246, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_23], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf246) | |
| buf252 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_515], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_16.run(buf246, buf252, 4194304, stream=stream0) | |
| buf260 = buf189; del buf189 # reuse | |
| # Call mix order reduction kernel | |
| # Topologically Sorted Source Nodes: [redistribute, layer_norm, add_123, add_124, convert_element_type_516, convert_element_type_518, mul_205, mul_206, sum_21, mul_207, sum_22, mul_208, sub_43, sub_44, div_7, mul_209, mul_210, sum_23, sum_24, convert_element_type_520, add_125], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.add, aten.native_layer_norm_backward] | |
| workspace_4 = workspace_3; del workspace_3 # reuse | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_20.run(buf260, buf233, buf240, buf247, primals_143, add_91, buf202, buf203, workspace_4, 327680, 2048, stream=stream0) | |
| buf257 = workspace_4[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| buf259 = workspace_4[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| del add_91 | |
| del primals_143 | |
| buf267 = buf196; del buf196 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_521, all_reduce_25], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf257, buf267, 2048, stream=stream0) | |
| buf261 = buf190; del buf190 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_522, all_reduce_24], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf259, buf261, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_522, all_reduce_24], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf261, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_24], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf261) | |
| buf266 = buf259; del buf259 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_523], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf261, buf266, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_521, all_reduce_25], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf267, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_25], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf267) | |
| buf272 = buf257; del buf257 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_524], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf267, buf272, 2048, stream=stream0) | |
| buf273 = reinterpret_tensor(buf170, (2048, 8192), (8192, 1), 0); del buf170 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_4], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_3.run(primals_141, buf273, 16777216, stream=stream0) | |
| del primals_141 | |
| buf274 = buf168; del buf168 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_4, x_22, permute_177, mm_73], Original ATen: [aten._to_copy, aten.t, aten.mm] | |
| extern_kernels.mm(buf260, buf273, out=buf274) | |
| buf275 = buf203; del buf203 # reuse | |
| buf276 = buf202; del buf202 # reuse | |
| buf278 = buf247; del buf247 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused__to_copy_native_layer_norm_4.run(add_86, primals_137, primals_138, buf275, buf276, buf278, 327680, 2048, stream=stream0) | |
| del primals_138 | |
| buf279 = reinterpret_tensor(buf273, (2048, 8192), (1, 2048), 0); del buf273 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_2, x], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_3.run(primals_139, buf279, 16777216, stream=stream0) | |
| del primals_139 | |
| buf280 = buf154; del buf154 # reuse | |
| # Topologically Sorted Source Nodes: [x], Original ATen: [aten.addmm] | |
| extern_kernels.mm(buf278, buf279, out=buf280) | |
| buf281 = buf153; del buf153 # reuse | |
| buf295 = buf274; del buf274 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_3, x, x_21, convert_element_type_531, mul_216, mul_217, sub_45, mul_218, add_128, mul_219, mul_220, mul_221, add_129, mul_222, convert_element_type_533], Original ATen: [aten._to_copy, aten.addmm, aten.gelu, aten.gelu_backward] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_addmm_gelu_gelu_backward_5.run(buf295, primals_140, buf280, buf281, 2684354560, stream=stream0) | |
| del primals_140 | |
| buf282 = buf155; del buf155 # reuse | |
| # Topologically Sorted Source Nodes: [permute_178, redistribute_3, x, x_21, mm_74], Original ATen: [aten.t, aten._to_copy, aten.addmm, aten.gelu, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf260, (2048, 327680), (1, 2048), 0), buf281, out=buf282) | |
| buf283 = buf156; del buf156 # reuse | |
| # Topologically Sorted Source Nodes: [sum_25], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_6.run(buf260, buf283, 327680, 2048, stream=stream0) | |
| buf284 = reinterpret_tensor(buf267, (1, 2048), (2048, 1), 0); del buf267 # reuse | |
| # Topologically Sorted Source Nodes: [sum_25], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_7.run(buf283, buf284, 2048, 160, stream=stream0) | |
| # Topologically Sorted Source Nodes: [view_303, all_reduce_26], Original ATen: [aten.view, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf284, (2048, ), (1, ), 0), 'avg', '0') | |
| # Topologically Sorted Source Nodes: [view_303, wait_tensor_26], Original ATen: [aten.view, _c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf284, (2048, ), (1, ), 0)) | |
| buf289 = empty_strided_cuda((2048, ), (1, ), torch.float32) | |
| # Topologically Sorted Source Nodes: [view_303, convert_element_type_529], Original ATen: [aten.view, aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf284, buf289, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [all_reduce_27], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf282, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_27], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf282) | |
| buf294 = empty_strided_cuda((2048, 8192), (8192, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_530], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_8.run(buf282, buf294, 16777216, stream=stream0) | |
| buf296 = buf240; del buf240 # reuse | |
| # Topologically Sorted Source Nodes: [permute_181, mm_75], Original ATen: [aten.t, aten.mm] | |
| extern_kernels.mm(buf295, reinterpret_tensor(buf279, (8192, 2048), (2048, 1), 0), out=buf296) | |
| buf297 = reinterpret_tensor(buf279, (8192, 2048), (2048, 1), 0); del buf279 # reuse | |
| # Topologically Sorted Source Nodes: [permute_182, mm_76], Original ATen: [aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf295, (8192, 327680), (1, 8192), 0), buf278, out=buf297) | |
| buf298 = buf171; del buf171 # reuse | |
| # Topologically Sorted Source Nodes: [sum_26], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_9.run(buf295, buf298, 1310720, 2048, stream=stream0) | |
| buf299 = buf172; del buf172 # reuse | |
| # Topologically Sorted Source Nodes: [sum_26], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_10.run(buf298, buf299, 8192, 160, stream=stream0) | |
| # Topologically Sorted Source Nodes: [view_304, all_reduce_28], Original ATen: [aten.view, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf299, (8192, ), (1, ), 0), 'avg', '0') | |
| # Topologically Sorted Source Nodes: [view_304, wait_tensor_28], Original ATen: [aten.view, _c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf299, (8192, ), (1, ), 0)) | |
| buf304 = empty_strided_cuda((8192, ), (1, ), torch.float32) | |
| # Topologically Sorted Source Nodes: [view_304, convert_element_type_538], Original ATen: [aten.view, aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_view_11.run(buf299, buf304, 8192, stream=stream0) | |
| # Topologically Sorted Source Nodes: [all_reduce_29], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf297, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_29], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf297) | |
| buf309 = empty_strided_cuda((8192, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_539], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_8.run(buf297, buf309, 16777216, stream=stream0) | |
| buf316 = buf260; del buf260 # reuse | |
| # Call mix order reduction kernel | |
| # Topologically Sorted Source Nodes: [redistribute, layer_norm, convert_element_type_540, convert_element_type_542, mul_224, mul_225, sum_27, mul_226, sum_28, mul_227, sub_47, sub_48, div_8, mul_228, mul_229, sum_29, sum_30, convert_element_type_544, add_130], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.native_layer_norm_backward, aten.add] | |
| workspace_5 = workspace_4; del workspace_4 # reuse | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_12.run(buf316, buf296, primals_137, add_86, buf275, buf276, workspace_5, 327680, 2048, stream=stream0) | |
| buf313 = workspace_5[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| buf315 = workspace_5[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| del add_86 | |
| del primals_137 | |
| buf323 = reinterpret_tensor(buf284, (2048, ), (1, ), 0); del buf284 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_545, all_reduce_31], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf313, buf323, 2048, stream=stream0) | |
| buf317 = buf261; del buf261 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_546, all_reduce_30], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf315, buf317, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_546, all_reduce_30], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf317, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_30], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf317) | |
| buf322 = buf315; del buf315 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_547], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf317, buf322, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_545, all_reduce_31], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf323, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_31], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf323) | |
| buf328 = buf313; del buf313 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_548], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf323, buf328, 2048, stream=stream0) | |
| buf329 = buf276; del buf276 # reuse | |
| buf330 = buf275; del buf275 # reuse | |
| buf332 = buf296; del buf296 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused__to_copy_native_layer_norm_4.run(add_83, primals_131, primals_132, buf329, buf330, buf332, 327680, 2048, stream=stream0) | |
| del primals_132 | |
| buf333 = reinterpret_tensor(buf246, (2048, 2048), (1, 2048), 0); del buf246 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_2, linear], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_13.run(primals_133, buf333, 4194304, stream=stream0) | |
| del primals_133 | |
| buf334 = buf278; del buf278 # reuse | |
| # Topologically Sorted Source Nodes: [linear], Original ATen: [aten.mm] | |
| extern_kernels.mm(buf332, buf333, out=buf334) | |
| buf335 = reinterpret_tensor(buf239, (2048, 512), (1, 2048), 0); del buf239 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_3, linear_1], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_14.run(primals_134, buf335, 1048576, stream=stream0) | |
| del primals_134 | |
| buf336 = reinterpret_tensor(buf229, (327680, 512), (512, 1), 0); del buf229 # reuse | |
| # Topologically Sorted Source Nodes: [linear_1], Original ATen: [aten.mm] | |
| extern_kernels.mm(buf332, buf335, out=buf336) | |
| buf337 = buf208; del buf208 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_4, linear_2], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_14.run(primals_135, buf337, 1048576, stream=stream0) | |
| del primals_135 | |
| buf338 = reinterpret_tensor(buf228, (327680, 512), (512, 1), 0); del buf228 # reuse | |
| # Topologically Sorted Source Nodes: [linear_2], Original ATen: [aten.mm] | |
| extern_kernels.mm(buf332, buf337, out=buf338) | |
| buf339 = buf226; del buf226 # reuse | |
| buf340 = buf212; del buf212 # reuse | |
| buf341 = reinterpret_tensor(buf233, (1, 16, 327680, 128), (671088640, 128, 2048, 1), 0); del buf233 # reuse | |
| # Topologically Sorted Source Nodes: [q, k, v, flex_attention], Original ATen: [aten.view, aten.permute, flex_attention] | |
| stream0 = get_raw_stream(0) | |
| triton_tem_fused_flex_attention_permute_view_15.run(buf334, buf336, buf338, buf339, buf340, primals_18, primals_17, primals_19, primals_21, primals_14, primals_15, buf341, s91, 2560, 1, 16, stream=stream0) | |
| buf344 = reinterpret_tensor(buf206, (2048, 2048), (2048, 1), 0); del buf206 # reuse | |
| # Topologically Sorted Source Nodes: [permute_185, rearrange_3, o, mm_77], Original ATen: [aten.t, aten.permute, aten.view, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf316, (2048, 327680), (1, 2048), 0), reinterpret_tensor(buf341, (327680, 2048), (2048, 1), 0), out=buf344) | |
| buf345 = buf218; del buf218 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_5], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_13.run(primals_136, buf345, 4194304, stream=stream0) | |
| del primals_136 | |
| buf346 = reinterpret_tensor(buf227, (327680, 2048), (2048, 1), 0); del buf227 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_5, o, permute_187, mm_78], Original ATen: [aten._to_copy, aten.t, aten.mm] | |
| extern_kernels.mm(buf316, buf345, out=buf346) | |
| # Topologically Sorted Source Nodes: [all_reduce_32], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf344, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_32], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf344) | |
| buf351 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_553], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_16.run(buf344, buf351, 4194304, stream=stream0) | |
| buf353 = buf340; del buf340 # reuse | |
| # Topologically Sorted Source Nodes: [q, k, v, view_307, view_308, permute_189, flex_attention_backward_2], Original ATen: [aten.view, aten.permute, flex_attention_backward] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused_flex_attention_backward_permute_view_17.run(buf341, buf346, buf353, 5242880, 128, stream=stream0) | |
| buf354 = buf341; del buf341 # reuse | |
| buf355 = empty_strided_cuda((1, 4, 327680, 128), (167772160, 128, 512, 1), torch.bfloat16) | |
| buf356 = empty_strided_cuda((1, 4, 327680, 128), (167772160, 128, 512, 1), torch.bfloat16) | |
| # Topologically Sorted Source Nodes: [q, k, v, view_307, view_308, permute_189, flex_attention_backward_2], Original ATen: [aten.view, aten.permute, flex_attention_backward] | |
| stream0 = get_raw_stream(0) | |
| triton_tem_fused_flex_attention_backward_permute_view_18.run(buf334, buf336, buf338, buf339, buf353, buf346, buf354, buf355, primals_18, primals_17, primals_22, primals_24, primals_19, primals_21, primals_25, primals_27, primals_14, primals_15, buf356, s91, s16, 12800, 1, 4, stream=stream0) | |
| buf359 = reinterpret_tensor(buf210, (512, 2048), (2048, 1), 0); del buf210 # reuse | |
| # Topologically Sorted Source Nodes: [view_309, permute_190, view_310, permute_191, mm_79], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf355, (512, 327680), (1, 512), 0), buf332, out=buf359) | |
| buf360 = buf346; del buf346 # reuse | |
| # Topologically Sorted Source Nodes: [view_309, permute_190, view_310, permute_193, mm_80], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf355, (327680, 512), (512, 1), 0), reinterpret_tensor(buf337, (512, 2048), (2048, 1), 0), out=buf360) | |
| # Topologically Sorted Source Nodes: [all_reduce_33], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf359, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_33], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf359) | |
| buf365 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_558], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_19.run(buf359, buf365, 1048576, stream=stream0) | |
| buf366 = buf359; del buf359 # reuse | |
| # Topologically Sorted Source Nodes: [view_311, permute_195, view_312, permute_196, mm_81], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf356, (512, 327680), (1, 512), 0), buf332, out=buf366) | |
| buf367 = buf334; del buf334 # reuse | |
| # Topologically Sorted Source Nodes: [view_311, permute_195, view_312, permute_198, mm_82], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf356, (327680, 512), (512, 1), 0), reinterpret_tensor(buf335, (512, 2048), (2048, 1), 0), out=buf367) | |
| # Topologically Sorted Source Nodes: [all_reduce_34], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf366, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_34], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf366) | |
| buf372 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_563], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_19.run(buf366, buf372, 1048576, stream=stream0) | |
| buf373 = buf344; del buf344 # reuse | |
| # Topologically Sorted Source Nodes: [view_313, permute_200, view_314, permute_201, mm_83], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf354, (2048, 327680), (1, 2048), 0), buf332, out=buf373) | |
| buf374 = buf332; del buf332 # reuse | |
| # Topologically Sorted Source Nodes: [view_313, permute_200, view_314, permute_203, mm_84], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf354, (327680, 2048), (2048, 1), 0), reinterpret_tensor(buf333, (2048, 2048), (2048, 1), 0), out=buf374) | |
| # Topologically Sorted Source Nodes: [all_reduce_35], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf373, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_35], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf373) | |
| buf379 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_568], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_16.run(buf373, buf379, 4194304, stream=stream0) | |
| buf387 = buf316; del buf316 # reuse | |
| # Call mix order reduction kernel | |
| # Topologically Sorted Source Nodes: [redistribute, layer_norm, add_131, add_132, convert_element_type_569, convert_element_type_571, mul_231, mul_232, sum_31, mul_233, sum_32, mul_234, sub_50, sub_51, div_9, mul_235, mul_236, sum_33, sum_34, convert_element_type_573, add_133], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.add, aten.native_layer_norm_backward] | |
| workspace_6 = workspace_5; del workspace_5 # reuse | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_20.run(buf387, buf360, buf367, buf374, primals_131, add_83, buf329, buf330, workspace_6, 327680, 2048, stream=stream0) | |
| buf384 = workspace_6[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| buf386 = workspace_6[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| del add_83 | |
| del primals_131 | |
| buf394 = buf323; del buf323 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_574, all_reduce_37], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf384, buf394, 2048, stream=stream0) | |
| buf388 = buf317; del buf317 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_575, all_reduce_36], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf386, buf388, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_575, all_reduce_36], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf388, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_36], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf388) | |
| buf393 = buf386; del buf386 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_576], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf388, buf393, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_574, all_reduce_37], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf394, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_37], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf394) | |
| buf399 = buf384; del buf384 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_577], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf394, buf399, 2048, stream=stream0) | |
| buf400 = reinterpret_tensor(buf297, (2048, 8192), (8192, 1), 0); del buf297 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_4], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_3.run(primals_129, buf400, 16777216, stream=stream0) | |
| del primals_129 | |
| buf401 = buf295; del buf295 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_4, x_20, permute_205, mm_85], Original ATen: [aten._to_copy, aten.t, aten.mm] | |
| extern_kernels.mm(buf387, buf400, out=buf401) | |
| buf402 = buf330; del buf330 # reuse | |
| buf403 = buf329; del buf329 # reuse | |
| buf405 = buf374; del buf374 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused__to_copy_native_layer_norm_4.run(add_78, primals_125, primals_126, buf402, buf403, buf405, 327680, 2048, stream=stream0) | |
| del primals_126 | |
| buf406 = reinterpret_tensor(buf400, (2048, 8192), (1, 2048), 0); del buf400 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_2, x], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_3.run(primals_127, buf406, 16777216, stream=stream0) | |
| del primals_127 | |
| buf407 = buf281; del buf281 # reuse | |
| # Topologically Sorted Source Nodes: [x], Original ATen: [aten.addmm] | |
| extern_kernels.mm(buf405, buf406, out=buf407) | |
| buf408 = buf280; del buf280 # reuse | |
| buf422 = buf401; del buf401 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_3, x, x_19, convert_element_type_584, mul_242, mul_243, sub_52, mul_244, add_136, mul_245, mul_246, mul_247, add_137, mul_248, convert_element_type_586], Original ATen: [aten._to_copy, aten.addmm, aten.gelu, aten.gelu_backward] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_addmm_gelu_gelu_backward_5.run(buf422, primals_128, buf407, buf408, 2684354560, stream=stream0) | |
| del primals_128 | |
| buf409 = buf282; del buf282 # reuse | |
| # Topologically Sorted Source Nodes: [permute_206, redistribute_3, x, x_19, mm_86], Original ATen: [aten.t, aten._to_copy, aten.addmm, aten.gelu, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf387, (2048, 327680), (1, 2048), 0), buf408, out=buf409) | |
| buf410 = buf283; del buf283 # reuse | |
| # Topologically Sorted Source Nodes: [sum_35], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_6.run(buf387, buf410, 327680, 2048, stream=stream0) | |
| buf411 = reinterpret_tensor(buf394, (1, 2048), (2048, 1), 0); del buf394 # reuse | |
| # Topologically Sorted Source Nodes: [sum_35], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_7.run(buf410, buf411, 2048, 160, stream=stream0) | |
| # Topologically Sorted Source Nodes: [view_315, all_reduce_38], Original ATen: [aten.view, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf411, (2048, ), (1, ), 0), 'avg', '0') | |
| # Topologically Sorted Source Nodes: [view_315, wait_tensor_38], Original ATen: [aten.view, _c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf411, (2048, ), (1, ), 0)) | |
| buf416 = empty_strided_cuda((2048, ), (1, ), torch.float32) | |
| # Topologically Sorted Source Nodes: [view_315, convert_element_type_582], Original ATen: [aten.view, aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf411, buf416, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [all_reduce_39], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf409, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_39], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf409) | |
| buf421 = empty_strided_cuda((2048, 8192), (8192, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_583], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_8.run(buf409, buf421, 16777216, stream=stream0) | |
| buf423 = buf367; del buf367 # reuse | |
| # Topologically Sorted Source Nodes: [permute_209, mm_87], Original ATen: [aten.t, aten.mm] | |
| extern_kernels.mm(buf422, reinterpret_tensor(buf406, (8192, 2048), (2048, 1), 0), out=buf423) | |
| buf424 = reinterpret_tensor(buf406, (8192, 2048), (2048, 1), 0); del buf406 # reuse | |
| # Topologically Sorted Source Nodes: [permute_210, mm_88], Original ATen: [aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf422, (8192, 327680), (1, 8192), 0), buf405, out=buf424) | |
| buf425 = buf298; del buf298 # reuse | |
| # Topologically Sorted Source Nodes: [sum_36], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_9.run(buf422, buf425, 1310720, 2048, stream=stream0) | |
| buf426 = buf299; del buf299 # reuse | |
| # Topologically Sorted Source Nodes: [sum_36], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_10.run(buf425, buf426, 8192, 160, stream=stream0) | |
| # Topologically Sorted Source Nodes: [view_316, all_reduce_40], Original ATen: [aten.view, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf426, (8192, ), (1, ), 0), 'avg', '0') | |
| # Topologically Sorted Source Nodes: [view_316, wait_tensor_40], Original ATen: [aten.view, _c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf426, (8192, ), (1, ), 0)) | |
| buf431 = empty_strided_cuda((8192, ), (1, ), torch.float32) | |
| # Topologically Sorted Source Nodes: [view_316, convert_element_type_591], Original ATen: [aten.view, aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_view_11.run(buf426, buf431, 8192, stream=stream0) | |
| # Topologically Sorted Source Nodes: [all_reduce_41], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf424, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_41], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf424) | |
| buf436 = empty_strided_cuda((8192, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_592], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_8.run(buf424, buf436, 16777216, stream=stream0) | |
| buf443 = buf387; del buf387 # reuse | |
| # Call mix order reduction kernel | |
| # Topologically Sorted Source Nodes: [redistribute, layer_norm, convert_element_type_593, convert_element_type_595, mul_250, mul_251, sum_37, mul_252, sum_38, mul_253, sub_54, sub_55, div_10, mul_254, mul_255, sum_39, sum_40, convert_element_type_597, add_138], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.native_layer_norm_backward, aten.add] | |
| workspace_7 = workspace_6; del workspace_6 # reuse | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_12.run(buf443, buf423, primals_125, add_78, buf402, buf403, workspace_7, 327680, 2048, stream=stream0) | |
| buf440 = workspace_7[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| buf442 = workspace_7[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| del add_78 | |
| del primals_125 | |
| buf450 = reinterpret_tensor(buf411, (2048, ), (1, ), 0); del buf411 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_598, all_reduce_43], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf440, buf450, 2048, stream=stream0) | |
| buf444 = buf388; del buf388 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_599, all_reduce_42], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf442, buf444, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_599, all_reduce_42], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf444, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_42], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf444) | |
| buf449 = buf442; del buf442 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_600], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf444, buf449, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_598, all_reduce_43], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf450, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_43], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf450) | |
| buf455 = buf440; del buf440 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_601], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf450, buf455, 2048, stream=stream0) | |
| buf456 = buf403; del buf403 # reuse | |
| buf457 = buf402; del buf402 # reuse | |
| buf459 = buf423; del buf423 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused__to_copy_native_layer_norm_4.run(add_75, primals_119, primals_120, buf456, buf457, buf459, 327680, 2048, stream=stream0) | |
| del primals_120 | |
| buf460 = reinterpret_tensor(buf373, (2048, 2048), (1, 2048), 0); del buf373 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_2, linear], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_13.run(primals_121, buf460, 4194304, stream=stream0) | |
| del primals_121 | |
| buf461 = buf405; del buf405 # reuse | |
| # Topologically Sorted Source Nodes: [linear], Original ATen: [aten.mm] | |
| extern_kernels.mm(buf459, buf460, out=buf461) | |
| buf462 = reinterpret_tensor(buf366, (2048, 512), (1, 2048), 0); del buf366 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_3, linear_1], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_14.run(primals_122, buf462, 1048576, stream=stream0) | |
| del primals_122 | |
| buf463 = reinterpret_tensor(buf356, (327680, 512), (512, 1), 0); del buf356 # reuse | |
| # Topologically Sorted Source Nodes: [linear_1], Original ATen: [aten.mm] | |
| extern_kernels.mm(buf459, buf462, out=buf463) | |
| buf464 = buf335; del buf335 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_4, linear_2], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_14.run(primals_123, buf464, 1048576, stream=stream0) | |
| del primals_123 | |
| buf465 = reinterpret_tensor(buf355, (327680, 512), (512, 1), 0); del buf355 # reuse | |
| # Topologically Sorted Source Nodes: [linear_2], Original ATen: [aten.mm] | |
| extern_kernels.mm(buf459, buf464, out=buf465) | |
| buf466 = buf353; del buf353 # reuse | |
| buf467 = buf339; del buf339 # reuse | |
| buf468 = reinterpret_tensor(buf360, (1, 16, 327680, 128), (671088640, 128, 2048, 1), 0); del buf360 # reuse | |
| # Topologically Sorted Source Nodes: [q, k, v, flex_attention], Original ATen: [aten.view, aten.permute, flex_attention] | |
| stream0 = get_raw_stream(0) | |
| triton_tem_fused_flex_attention_permute_view_15.run(buf461, buf463, buf465, buf466, buf467, primals_18, primals_17, primals_19, primals_21, primals_14, primals_15, buf468, s91, 2560, 1, 16, stream=stream0) | |
| buf471 = reinterpret_tensor(buf333, (2048, 2048), (2048, 1), 0); del buf333 # reuse | |
| # Topologically Sorted Source Nodes: [permute_213, rearrange_3, o, mm_89], Original ATen: [aten.t, aten.permute, aten.view, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf443, (2048, 327680), (1, 2048), 0), reinterpret_tensor(buf468, (327680, 2048), (2048, 1), 0), out=buf471) | |
| buf472 = buf345; del buf345 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_5], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_13.run(primals_124, buf472, 4194304, stream=stream0) | |
| del primals_124 | |
| buf473 = reinterpret_tensor(buf354, (327680, 2048), (2048, 1), 0); del buf354 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_5, o, permute_215, mm_90], Original ATen: [aten._to_copy, aten.t, aten.mm] | |
| extern_kernels.mm(buf443, buf472, out=buf473) | |
| # Topologically Sorted Source Nodes: [all_reduce_44], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf471, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_44], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf471) | |
| buf478 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_606], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_16.run(buf471, buf478, 4194304, stream=stream0) | |
| buf480 = buf467; del buf467 # reuse | |
| # Topologically Sorted Source Nodes: [q, k, v, view_319, view_320, permute_217, flex_attention_backward_3], Original ATen: [aten.view, aten.permute, flex_attention_backward] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused_flex_attention_backward_permute_view_17.run(buf468, buf473, buf480, 5242880, 128, stream=stream0) | |
| buf481 = buf468; del buf468 # reuse | |
| buf482 = reinterpret_tensor(buf338, (1, 4, 327680, 128), (167772160, 128, 512, 1), 0); del buf338 # reuse | |
| buf483 = reinterpret_tensor(buf336, (1, 4, 327680, 128), (167772160, 128, 512, 1), 0); del buf336 # reuse | |
| # Topologically Sorted Source Nodes: [q, k, v, view_319, view_320, permute_217, flex_attention_backward_3], Original ATen: [aten.view, aten.permute, flex_attention_backward] | |
| stream0 = get_raw_stream(0) | |
| triton_tem_fused_flex_attention_backward_permute_view_18.run(buf461, buf463, buf465, buf466, buf480, buf473, buf481, buf482, primals_18, primals_17, primals_22, primals_24, primals_19, primals_21, primals_25, primals_27, primals_14, primals_15, buf483, s91, s16, 12800, 1, 4, stream=stream0) | |
| buf486 = reinterpret_tensor(buf337, (512, 2048), (2048, 1), 0); del buf337 # reuse | |
| # Topologically Sorted Source Nodes: [view_321, permute_218, view_322, permute_219, mm_91], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf482, (512, 327680), (1, 512), 0), buf459, out=buf486) | |
| buf487 = buf473; del buf473 # reuse | |
| # Topologically Sorted Source Nodes: [view_321, permute_218, view_322, permute_221, mm_92], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf482, (327680, 512), (512, 1), 0), reinterpret_tensor(buf464, (512, 2048), (2048, 1), 0), out=buf487) | |
| # Topologically Sorted Source Nodes: [all_reduce_45], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf486, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_45], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf486) | |
| buf492 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_611], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_19.run(buf486, buf492, 1048576, stream=stream0) | |
| buf493 = buf486; del buf486 # reuse | |
| # Topologically Sorted Source Nodes: [view_323, permute_223, view_324, permute_224, mm_93], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf483, (512, 327680), (1, 512), 0), buf459, out=buf493) | |
| buf494 = buf461; del buf461 # reuse | |
| # Topologically Sorted Source Nodes: [view_323, permute_223, view_324, permute_226, mm_94], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf483, (327680, 512), (512, 1), 0), reinterpret_tensor(buf462, (512, 2048), (2048, 1), 0), out=buf494) | |
| # Topologically Sorted Source Nodes: [all_reduce_46], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf493, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_46], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf493) | |
| buf499 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_616], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_19.run(buf493, buf499, 1048576, stream=stream0) | |
| buf500 = buf471; del buf471 # reuse | |
| # Topologically Sorted Source Nodes: [view_325, permute_228, view_326, permute_229, mm_95], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf481, (2048, 327680), (1, 2048), 0), buf459, out=buf500) | |
| buf501 = buf459; del buf459 # reuse | |
| # Topologically Sorted Source Nodes: [view_325, permute_228, view_326, permute_231, mm_96], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf481, (327680, 2048), (2048, 1), 0), reinterpret_tensor(buf460, (2048, 2048), (2048, 1), 0), out=buf501) | |
| # Topologically Sorted Source Nodes: [all_reduce_47], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf500, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_47], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf500) | |
| buf506 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_621], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_16.run(buf500, buf506, 4194304, stream=stream0) | |
| buf514 = buf443; del buf443 # reuse | |
| # Call mix order reduction kernel | |
| # Topologically Sorted Source Nodes: [redistribute, layer_norm, add_139, add_140, convert_element_type_622, convert_element_type_624, mul_257, mul_258, sum_41, mul_259, sum_42, mul_260, sub_57, sub_58, div_11, mul_261, mul_262, sum_43, sum_44, convert_element_type_626, add_141], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.add, aten.native_layer_norm_backward] | |
| workspace_8 = workspace_7; del workspace_7 # reuse | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_20.run(buf514, buf487, buf494, buf501, primals_119, add_75, buf456, buf457, workspace_8, 327680, 2048, stream=stream0) | |
| buf511 = workspace_8[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| buf513 = workspace_8[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| del add_75 | |
| del primals_119 | |
| buf521 = buf450; del buf450 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_627, all_reduce_49], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf511, buf521, 2048, stream=stream0) | |
| buf515 = buf444; del buf444 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_628, all_reduce_48], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf513, buf515, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_628, all_reduce_48], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf515, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_48], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf515) | |
| buf520 = buf513; del buf513 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_629], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf515, buf520, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_627, all_reduce_49], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf521, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_49], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf521) | |
| buf526 = buf511; del buf511 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_630], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf521, buf526, 2048, stream=stream0) | |
| buf527 = reinterpret_tensor(buf424, (2048, 8192), (8192, 1), 0); del buf424 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_4], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_3.run(primals_117, buf527, 16777216, stream=stream0) | |
| del primals_117 | |
| buf528 = buf422; del buf422 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_4, x_18, permute_233, mm_97], Original ATen: [aten._to_copy, aten.t, aten.mm] | |
| extern_kernels.mm(buf514, buf527, out=buf528) | |
| buf529 = buf457; del buf457 # reuse | |
| buf530 = buf456; del buf456 # reuse | |
| buf532 = buf501; del buf501 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused__to_copy_native_layer_norm_4.run(add_70, primals_113, primals_114, buf529, buf530, buf532, 327680, 2048, stream=stream0) | |
| del primals_114 | |
| buf533 = reinterpret_tensor(buf527, (2048, 8192), (1, 2048), 0); del buf527 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_2, x], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_3.run(primals_115, buf533, 16777216, stream=stream0) | |
| del primals_115 | |
| buf534 = buf408; del buf408 # reuse | |
| # Topologically Sorted Source Nodes: [x], Original ATen: [aten.addmm] | |
| extern_kernels.mm(buf532, buf533, out=buf534) | |
| buf535 = buf407; del buf407 # reuse | |
| buf549 = buf528; del buf528 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_3, x, x_17, convert_element_type_637, mul_268, mul_269, sub_59, mul_270, add_144, mul_271, mul_272, mul_273, add_145, mul_274, convert_element_type_639], Original ATen: [aten._to_copy, aten.addmm, aten.gelu, aten.gelu_backward] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_addmm_gelu_gelu_backward_5.run(buf549, primals_116, buf534, buf535, 2684354560, stream=stream0) | |
| del primals_116 | |
| buf536 = buf409; del buf409 # reuse | |
| # Topologically Sorted Source Nodes: [permute_234, redistribute_3, x, x_17, mm_98], Original ATen: [aten.t, aten._to_copy, aten.addmm, aten.gelu, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf514, (2048, 327680), (1, 2048), 0), buf535, out=buf536) | |
| buf537 = buf410; del buf410 # reuse | |
| # Topologically Sorted Source Nodes: [sum_45], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_6.run(buf514, buf537, 327680, 2048, stream=stream0) | |
| buf538 = reinterpret_tensor(buf521, (1, 2048), (2048, 1), 0); del buf521 # reuse | |
| # Topologically Sorted Source Nodes: [sum_45], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_7.run(buf537, buf538, 2048, 160, stream=stream0) | |
| # Topologically Sorted Source Nodes: [view_327, all_reduce_50], Original ATen: [aten.view, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf538, (2048, ), (1, ), 0), 'avg', '0') | |
| # Topologically Sorted Source Nodes: [view_327, wait_tensor_50], Original ATen: [aten.view, _c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf538, (2048, ), (1, ), 0)) | |
| buf543 = empty_strided_cuda((2048, ), (1, ), torch.float32) | |
| # Topologically Sorted Source Nodes: [view_327, convert_element_type_635], Original ATen: [aten.view, aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf538, buf543, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [all_reduce_51], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf536, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_51], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf536) | |
| buf548 = empty_strided_cuda((2048, 8192), (8192, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_636], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_8.run(buf536, buf548, 16777216, stream=stream0) | |
| buf550 = buf494; del buf494 # reuse | |
| # Topologically Sorted Source Nodes: [permute_237, mm_99], Original ATen: [aten.t, aten.mm] | |
| extern_kernels.mm(buf549, reinterpret_tensor(buf533, (8192, 2048), (2048, 1), 0), out=buf550) | |
| buf551 = reinterpret_tensor(buf533, (8192, 2048), (2048, 1), 0); del buf533 # reuse | |
| # Topologically Sorted Source Nodes: [permute_238, mm_100], Original ATen: [aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf549, (8192, 327680), (1, 8192), 0), buf532, out=buf551) | |
| buf552 = buf425; del buf425 # reuse | |
| # Topologically Sorted Source Nodes: [sum_46], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_9.run(buf549, buf552, 1310720, 2048, stream=stream0) | |
| buf553 = buf426; del buf426 # reuse | |
| # Topologically Sorted Source Nodes: [sum_46], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_10.run(buf552, buf553, 8192, 160, stream=stream0) | |
| # Topologically Sorted Source Nodes: [view_328, all_reduce_52], Original ATen: [aten.view, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf553, (8192, ), (1, ), 0), 'avg', '0') | |
| # Topologically Sorted Source Nodes: [view_328, wait_tensor_52], Original ATen: [aten.view, _c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf553, (8192, ), (1, ), 0)) | |
| buf558 = empty_strided_cuda((8192, ), (1, ), torch.float32) | |
| # Topologically Sorted Source Nodes: [view_328, convert_element_type_644], Original ATen: [aten.view, aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_view_11.run(buf553, buf558, 8192, stream=stream0) | |
| # Topologically Sorted Source Nodes: [all_reduce_53], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf551, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_53], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf551) | |
| buf563 = empty_strided_cuda((8192, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_645], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_8.run(buf551, buf563, 16777216, stream=stream0) | |
| buf570 = buf514; del buf514 # reuse | |
| # Call mix order reduction kernel | |
| # Topologically Sorted Source Nodes: [redistribute, layer_norm, convert_element_type_646, convert_element_type_648, mul_276, mul_277, sum_47, mul_278, sum_48, mul_279, sub_61, sub_62, div_12, mul_280, mul_281, sum_49, sum_50, convert_element_type_650, add_146], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.native_layer_norm_backward, aten.add] | |
| workspace_9 = workspace_8; del workspace_8 # reuse | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_12.run(buf570, buf550, primals_113, add_70, buf529, buf530, workspace_9, 327680, 2048, stream=stream0) | |
| buf567 = workspace_9[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| buf569 = workspace_9[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| del add_70 | |
| del primals_113 | |
| buf577 = reinterpret_tensor(buf538, (2048, ), (1, ), 0); del buf538 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_651, all_reduce_55], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf567, buf577, 2048, stream=stream0) | |
| buf571 = buf515; del buf515 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_652, all_reduce_54], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf569, buf571, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_652, all_reduce_54], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf571, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_54], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf571) | |
| buf576 = buf569; del buf569 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_653], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf571, buf576, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_651, all_reduce_55], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf577, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_55], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf577) | |
| buf582 = buf567; del buf567 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_654], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf577, buf582, 2048, stream=stream0) | |
| buf583 = buf530; del buf530 # reuse | |
| buf584 = buf529; del buf529 # reuse | |
| buf586 = buf550; del buf550 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused__to_copy_native_layer_norm_4.run(add_67, primals_107, primals_108, buf583, buf584, buf586, 327680, 2048, stream=stream0) | |
| del primals_108 | |
| buf587 = reinterpret_tensor(buf500, (2048, 2048), (1, 2048), 0); del buf500 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_2, linear], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_13.run(primals_109, buf587, 4194304, stream=stream0) | |
| del primals_109 | |
| buf588 = buf532; del buf532 # reuse | |
| # Topologically Sorted Source Nodes: [linear], Original ATen: [aten.mm] | |
| extern_kernels.mm(buf586, buf587, out=buf588) | |
| buf589 = reinterpret_tensor(buf493, (2048, 512), (1, 2048), 0); del buf493 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_3, linear_1], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_14.run(primals_110, buf589, 1048576, stream=stream0) | |
| del primals_110 | |
| buf590 = reinterpret_tensor(buf483, (327680, 512), (512, 1), 0); del buf483 # reuse | |
| # Topologically Sorted Source Nodes: [linear_1], Original ATen: [aten.mm] | |
| extern_kernels.mm(buf586, buf589, out=buf590) | |
| buf591 = buf462; del buf462 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_4, linear_2], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_14.run(primals_111, buf591, 1048576, stream=stream0) | |
| del primals_111 | |
| buf592 = reinterpret_tensor(buf482, (327680, 512), (512, 1), 0); del buf482 # reuse | |
| # Topologically Sorted Source Nodes: [linear_2], Original ATen: [aten.mm] | |
| extern_kernels.mm(buf586, buf591, out=buf592) | |
| buf593 = buf480; del buf480 # reuse | |
| buf594 = buf466; del buf466 # reuse | |
| buf595 = reinterpret_tensor(buf487, (1, 16, 327680, 128), (671088640, 128, 2048, 1), 0); del buf487 # reuse | |
| # Topologically Sorted Source Nodes: [q, k, v, flex_attention], Original ATen: [aten.view, aten.permute, flex_attention] | |
| stream0 = get_raw_stream(0) | |
| triton_tem_fused_flex_attention_permute_view_15.run(buf588, buf590, buf592, buf593, buf594, primals_18, primals_17, primals_19, primals_21, primals_14, primals_15, buf595, s91, 2560, 1, 16, stream=stream0) | |
| buf598 = reinterpret_tensor(buf460, (2048, 2048), (2048, 1), 0); del buf460 # reuse | |
| # Topologically Sorted Source Nodes: [permute_241, rearrange_3, o, mm_101], Original ATen: [aten.t, aten.permute, aten.view, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf570, (2048, 327680), (1, 2048), 0), reinterpret_tensor(buf595, (327680, 2048), (2048, 1), 0), out=buf598) | |
| buf599 = buf472; del buf472 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_5], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_13.run(primals_112, buf599, 4194304, stream=stream0) | |
| del primals_112 | |
| buf600 = reinterpret_tensor(buf481, (327680, 2048), (2048, 1), 0); del buf481 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_5, o, permute_243, mm_102], Original ATen: [aten._to_copy, aten.t, aten.mm] | |
| extern_kernels.mm(buf570, buf599, out=buf600) | |
| # Topologically Sorted Source Nodes: [all_reduce_56], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf598, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_56], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf598) | |
| buf605 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_659], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_16.run(buf598, buf605, 4194304, stream=stream0) | |
| buf607 = buf594; del buf594 # reuse | |
| # Topologically Sorted Source Nodes: [q, k, v, view_331, view_332, permute_245, flex_attention_backward_4], Original ATen: [aten.view, aten.permute, flex_attention_backward] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused_flex_attention_backward_permute_view_17.run(buf595, buf600, buf607, 5242880, 128, stream=stream0) | |
| buf608 = buf595; del buf595 # reuse | |
| buf609 = reinterpret_tensor(buf465, (1, 4, 327680, 128), (167772160, 128, 512, 1), 0); del buf465 # reuse | |
| buf610 = reinterpret_tensor(buf463, (1, 4, 327680, 128), (167772160, 128, 512, 1), 0); del buf463 # reuse | |
| # Topologically Sorted Source Nodes: [q, k, v, view_331, view_332, permute_245, flex_attention_backward_4], Original ATen: [aten.view, aten.permute, flex_attention_backward] | |
| stream0 = get_raw_stream(0) | |
| triton_tem_fused_flex_attention_backward_permute_view_18.run(buf588, buf590, buf592, buf593, buf607, buf600, buf608, buf609, primals_18, primals_17, primals_22, primals_24, primals_19, primals_21, primals_25, primals_27, primals_14, primals_15, buf610, s91, s16, 12800, 1, 4, stream=stream0) | |
| buf613 = reinterpret_tensor(buf464, (512, 2048), (2048, 1), 0); del buf464 # reuse | |
| # Topologically Sorted Source Nodes: [view_333, permute_246, view_334, permute_247, mm_103], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf609, (512, 327680), (1, 512), 0), buf586, out=buf613) | |
| buf614 = buf600; del buf600 # reuse | |
| # Topologically Sorted Source Nodes: [view_333, permute_246, view_334, permute_249, mm_104], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf609, (327680, 512), (512, 1), 0), reinterpret_tensor(buf591, (512, 2048), (2048, 1), 0), out=buf614) | |
| # Topologically Sorted Source Nodes: [all_reduce_57], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf613, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_57], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf613) | |
| buf619 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_664], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_19.run(buf613, buf619, 1048576, stream=stream0) | |
| buf620 = buf613; del buf613 # reuse | |
| # Topologically Sorted Source Nodes: [view_335, permute_251, view_336, permute_252, mm_105], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf610, (512, 327680), (1, 512), 0), buf586, out=buf620) | |
| buf621 = buf588; del buf588 # reuse | |
| # Topologically Sorted Source Nodes: [view_335, permute_251, view_336, permute_254, mm_106], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf610, (327680, 512), (512, 1), 0), reinterpret_tensor(buf589, (512, 2048), (2048, 1), 0), out=buf621) | |
| # Topologically Sorted Source Nodes: [all_reduce_58], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf620, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_58], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf620) | |
| buf626 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_669], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_19.run(buf620, buf626, 1048576, stream=stream0) | |
| buf627 = buf598; del buf598 # reuse | |
| # Topologically Sorted Source Nodes: [view_337, permute_256, view_338, permute_257, mm_107], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf608, (2048, 327680), (1, 2048), 0), buf586, out=buf627) | |
| buf628 = buf586; del buf586 # reuse | |
| # Topologically Sorted Source Nodes: [view_337, permute_256, view_338, permute_259, mm_108], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf608, (327680, 2048), (2048, 1), 0), reinterpret_tensor(buf587, (2048, 2048), (2048, 1), 0), out=buf628) | |
| # Topologically Sorted Source Nodes: [all_reduce_59], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf627, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_59], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf627) | |
| buf633 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_674], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_16.run(buf627, buf633, 4194304, stream=stream0) | |
| buf641 = buf570; del buf570 # reuse | |
| # Call mix order reduction kernel | |
| # Topologically Sorted Source Nodes: [redistribute, layer_norm, add_147, add_148, convert_element_type_675, convert_element_type_677, mul_283, mul_284, sum_51, mul_285, sum_52, mul_286, sub_64, sub_65, div_13, mul_287, mul_288, sum_53, sum_54, convert_element_type_679, add_149], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.add, aten.native_layer_norm_backward] | |
| workspace_10 = workspace_9; del workspace_9 # reuse | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_20.run(buf641, buf614, buf621, buf628, primals_107, add_67, buf583, buf584, workspace_10, 327680, 2048, stream=stream0) | |
| buf638 = workspace_10[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| buf640 = workspace_10[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| del add_67 | |
| del primals_107 | |
| buf648 = buf577; del buf577 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_680, all_reduce_61], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf638, buf648, 2048, stream=stream0) | |
| buf642 = buf571; del buf571 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_681, all_reduce_60], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf640, buf642, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_681, all_reduce_60], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf642, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_60], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf642) | |
| buf647 = buf640; del buf640 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_682], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf642, buf647, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_680, all_reduce_61], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf648, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_61], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf648) | |
| buf653 = buf638; del buf638 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_683], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf648, buf653, 2048, stream=stream0) | |
| buf654 = reinterpret_tensor(buf551, (2048, 8192), (8192, 1), 0); del buf551 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_4], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_3.run(primals_105, buf654, 16777216, stream=stream0) | |
| del primals_105 | |
| buf655 = buf549; del buf549 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_4, x_16, permute_261, mm_109], Original ATen: [aten._to_copy, aten.t, aten.mm] | |
| extern_kernels.mm(buf641, buf654, out=buf655) | |
| buf656 = buf584; del buf584 # reuse | |
| buf657 = buf583; del buf583 # reuse | |
| buf659 = buf628; del buf628 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused__to_copy_native_layer_norm_4.run(add_62, primals_101, primals_102, buf656, buf657, buf659, 327680, 2048, stream=stream0) | |
| del primals_102 | |
| buf660 = reinterpret_tensor(buf654, (2048, 8192), (1, 2048), 0); del buf654 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_2, x], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_3.run(primals_103, buf660, 16777216, stream=stream0) | |
| del primals_103 | |
| buf661 = buf535; del buf535 # reuse | |
| # Topologically Sorted Source Nodes: [x], Original ATen: [aten.addmm] | |
| extern_kernels.mm(buf659, buf660, out=buf661) | |
| buf662 = buf534; del buf534 # reuse | |
| buf676 = buf655; del buf655 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_3, x, x_15, convert_element_type_690, mul_294, mul_295, sub_66, mul_296, add_152, mul_297, mul_298, mul_299, add_153, mul_300, convert_element_type_692], Original ATen: [aten._to_copy, aten.addmm, aten.gelu, aten.gelu_backward] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_addmm_gelu_gelu_backward_5.run(buf676, primals_104, buf661, buf662, 2684354560, stream=stream0) | |
| del primals_104 | |
| buf663 = buf536; del buf536 # reuse | |
| # Topologically Sorted Source Nodes: [permute_262, redistribute_3, x, x_15, mm_110], Original ATen: [aten.t, aten._to_copy, aten.addmm, aten.gelu, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf641, (2048, 327680), (1, 2048), 0), buf662, out=buf663) | |
| buf664 = buf537; del buf537 # reuse | |
| # Topologically Sorted Source Nodes: [sum_55], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_6.run(buf641, buf664, 327680, 2048, stream=stream0) | |
| buf665 = reinterpret_tensor(buf648, (1, 2048), (2048, 1), 0); del buf648 # reuse | |
| # Topologically Sorted Source Nodes: [sum_55], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_7.run(buf664, buf665, 2048, 160, stream=stream0) | |
| # Topologically Sorted Source Nodes: [view_339, all_reduce_62], Original ATen: [aten.view, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf665, (2048, ), (1, ), 0), 'avg', '0') | |
| # Topologically Sorted Source Nodes: [view_339, wait_tensor_62], Original ATen: [aten.view, _c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf665, (2048, ), (1, ), 0)) | |
| buf670 = empty_strided_cuda((2048, ), (1, ), torch.float32) | |
| # Topologically Sorted Source Nodes: [view_339, convert_element_type_688], Original ATen: [aten.view, aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf665, buf670, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [all_reduce_63], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf663, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_63], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf663) | |
| buf675 = empty_strided_cuda((2048, 8192), (8192, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_689], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_8.run(buf663, buf675, 16777216, stream=stream0) | |
| buf677 = buf621; del buf621 # reuse | |
| # Topologically Sorted Source Nodes: [permute_265, mm_111], Original ATen: [aten.t, aten.mm] | |
| extern_kernels.mm(buf676, reinterpret_tensor(buf660, (8192, 2048), (2048, 1), 0), out=buf677) | |
| buf678 = reinterpret_tensor(buf660, (8192, 2048), (2048, 1), 0); del buf660 # reuse | |
| # Topologically Sorted Source Nodes: [permute_266, mm_112], Original ATen: [aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf676, (8192, 327680), (1, 8192), 0), buf659, out=buf678) | |
| buf679 = buf552; del buf552 # reuse | |
| # Topologically Sorted Source Nodes: [sum_56], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_9.run(buf676, buf679, 1310720, 2048, stream=stream0) | |
| buf680 = buf553; del buf553 # reuse | |
| # Topologically Sorted Source Nodes: [sum_56], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_10.run(buf679, buf680, 8192, 160, stream=stream0) | |
| # Topologically Sorted Source Nodes: [view_340, all_reduce_64], Original ATen: [aten.view, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf680, (8192, ), (1, ), 0), 'avg', '0') | |
| # Topologically Sorted Source Nodes: [view_340, wait_tensor_64], Original ATen: [aten.view, _c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf680, (8192, ), (1, ), 0)) | |
| buf685 = empty_strided_cuda((8192, ), (1, ), torch.float32) | |
| # Topologically Sorted Source Nodes: [view_340, convert_element_type_697], Original ATen: [aten.view, aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_view_11.run(buf680, buf685, 8192, stream=stream0) | |
| # Topologically Sorted Source Nodes: [all_reduce_65], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf678, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_65], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf678) | |
| buf690 = empty_strided_cuda((8192, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_698], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_8.run(buf678, buf690, 16777216, stream=stream0) | |
| buf697 = buf641; del buf641 # reuse | |
| # Call mix order reduction kernel | |
| # Topologically Sorted Source Nodes: [redistribute, layer_norm, convert_element_type_699, convert_element_type_701, mul_302, mul_303, sum_57, mul_304, sum_58, mul_305, sub_68, sub_69, div_14, mul_306, mul_307, sum_59, sum_60, convert_element_type_703, add_154], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.native_layer_norm_backward, aten.add] | |
| workspace_11 = workspace_10; del workspace_10 # reuse | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_12.run(buf697, buf677, primals_101, add_62, buf656, buf657, workspace_11, 327680, 2048, stream=stream0) | |
| buf694 = workspace_11[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| buf696 = workspace_11[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| del add_62 | |
| del primals_101 | |
| buf704 = reinterpret_tensor(buf665, (2048, ), (1, ), 0); del buf665 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_704, all_reduce_67], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf694, buf704, 2048, stream=stream0) | |
| buf698 = buf642; del buf642 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_705, all_reduce_66], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf696, buf698, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_705, all_reduce_66], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf698, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_66], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf698) | |
| buf703 = buf696; del buf696 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_706], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf698, buf703, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_704, all_reduce_67], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf704, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_67], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf704) | |
| buf709 = buf694; del buf694 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_707], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf704, buf709, 2048, stream=stream0) | |
| buf710 = buf657; del buf657 # reuse | |
| buf711 = buf656; del buf656 # reuse | |
| buf713 = buf677; del buf677 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused__to_copy_native_layer_norm_4.run(add_59, primals_95, primals_96, buf710, buf711, buf713, 327680, 2048, stream=stream0) | |
| del primals_96 | |
| buf714 = reinterpret_tensor(buf627, (2048, 2048), (1, 2048), 0); del buf627 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_2, linear], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_13.run(primals_97, buf714, 4194304, stream=stream0) | |
| del primals_97 | |
| buf715 = buf659; del buf659 # reuse | |
| # Topologically Sorted Source Nodes: [linear], Original ATen: [aten.mm] | |
| extern_kernels.mm(buf713, buf714, out=buf715) | |
| buf716 = reinterpret_tensor(buf620, (2048, 512), (1, 2048), 0); del buf620 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_3, linear_1], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_14.run(primals_98, buf716, 1048576, stream=stream0) | |
| del primals_98 | |
| buf717 = reinterpret_tensor(buf610, (327680, 512), (512, 1), 0); del buf610 # reuse | |
| # Topologically Sorted Source Nodes: [linear_1], Original ATen: [aten.mm] | |
| extern_kernels.mm(buf713, buf716, out=buf717) | |
| buf718 = buf589; del buf589 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_4, linear_2], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_14.run(primals_99, buf718, 1048576, stream=stream0) | |
| del primals_99 | |
| buf719 = reinterpret_tensor(buf609, (327680, 512), (512, 1), 0); del buf609 # reuse | |
| # Topologically Sorted Source Nodes: [linear_2], Original ATen: [aten.mm] | |
| extern_kernels.mm(buf713, buf718, out=buf719) | |
| buf720 = buf607; del buf607 # reuse | |
| buf721 = buf593; del buf593 # reuse | |
| buf722 = reinterpret_tensor(buf614, (1, 16, 327680, 128), (671088640, 128, 2048, 1), 0); del buf614 # reuse | |
| # Topologically Sorted Source Nodes: [q, k, v, flex_attention], Original ATen: [aten.view, aten.permute, flex_attention] | |
| stream0 = get_raw_stream(0) | |
| triton_tem_fused_flex_attention_permute_view_15.run(buf715, buf717, buf719, buf720, buf721, primals_18, primals_17, primals_19, primals_21, primals_14, primals_15, buf722, s91, 2560, 1, 16, stream=stream0) | |
| buf725 = reinterpret_tensor(buf587, (2048, 2048), (2048, 1), 0); del buf587 # reuse | |
| # Topologically Sorted Source Nodes: [permute_269, rearrange_3, o, mm_113], Original ATen: [aten.t, aten.permute, aten.view, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf697, (2048, 327680), (1, 2048), 0), reinterpret_tensor(buf722, (327680, 2048), (2048, 1), 0), out=buf725) | |
| buf726 = buf599; del buf599 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_5], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_13.run(primals_100, buf726, 4194304, stream=stream0) | |
| del primals_100 | |
| buf727 = reinterpret_tensor(buf608, (327680, 2048), (2048, 1), 0); del buf608 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_5, o, permute_271, mm_114], Original ATen: [aten._to_copy, aten.t, aten.mm] | |
| extern_kernels.mm(buf697, buf726, out=buf727) | |
| # Topologically Sorted Source Nodes: [all_reduce_68], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf725, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_68], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf725) | |
| buf732 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_712], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_16.run(buf725, buf732, 4194304, stream=stream0) | |
| buf734 = buf721; del buf721 # reuse | |
| # Topologically Sorted Source Nodes: [q, k, v, view_343, view_344, permute_273, flex_attention_backward_5], Original ATen: [aten.view, aten.permute, flex_attention_backward] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused_flex_attention_backward_permute_view_17.run(buf722, buf727, buf734, 5242880, 128, stream=stream0) | |
| buf735 = buf722; del buf722 # reuse | |
| buf736 = reinterpret_tensor(buf592, (1, 4, 327680, 128), (167772160, 128, 512, 1), 0); del buf592 # reuse | |
| buf737 = reinterpret_tensor(buf590, (1, 4, 327680, 128), (167772160, 128, 512, 1), 0); del buf590 # reuse | |
| # Topologically Sorted Source Nodes: [q, k, v, view_343, view_344, permute_273, flex_attention_backward_5], Original ATen: [aten.view, aten.permute, flex_attention_backward] | |
| stream0 = get_raw_stream(0) | |
| triton_tem_fused_flex_attention_backward_permute_view_18.run(buf715, buf717, buf719, buf720, buf734, buf727, buf735, buf736, primals_18, primals_17, primals_22, primals_24, primals_19, primals_21, primals_25, primals_27, primals_14, primals_15, buf737, s91, s16, 12800, 1, 4, stream=stream0) | |
| buf740 = reinterpret_tensor(buf591, (512, 2048), (2048, 1), 0); del buf591 # reuse | |
| # Topologically Sorted Source Nodes: [view_345, permute_274, view_346, permute_275, mm_115], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf736, (512, 327680), (1, 512), 0), buf713, out=buf740) | |
| buf741 = buf727; del buf727 # reuse | |
| # Topologically Sorted Source Nodes: [view_345, permute_274, view_346, permute_277, mm_116], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf736, (327680, 512), (512, 1), 0), reinterpret_tensor(buf718, (512, 2048), (2048, 1), 0), out=buf741) | |
| # Topologically Sorted Source Nodes: [all_reduce_69], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf740, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_69], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf740) | |
| buf746 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_717], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_19.run(buf740, buf746, 1048576, stream=stream0) | |
| buf747 = buf740; del buf740 # reuse | |
| # Topologically Sorted Source Nodes: [view_347, permute_279, view_348, permute_280, mm_117], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf737, (512, 327680), (1, 512), 0), buf713, out=buf747) | |
| buf748 = buf715; del buf715 # reuse | |
| # Topologically Sorted Source Nodes: [view_347, permute_279, view_348, permute_282, mm_118], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf737, (327680, 512), (512, 1), 0), reinterpret_tensor(buf716, (512, 2048), (2048, 1), 0), out=buf748) | |
| # Topologically Sorted Source Nodes: [all_reduce_70], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf747, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_70], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf747) | |
| buf753 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_722], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_19.run(buf747, buf753, 1048576, stream=stream0) | |
| buf754 = buf725; del buf725 # reuse | |
| # Topologically Sorted Source Nodes: [view_349, permute_284, view_350, permute_285, mm_119], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf735, (2048, 327680), (1, 2048), 0), buf713, out=buf754) | |
| buf755 = buf713; del buf713 # reuse | |
| # Topologically Sorted Source Nodes: [view_349, permute_284, view_350, permute_287, mm_120], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf735, (327680, 2048), (2048, 1), 0), reinterpret_tensor(buf714, (2048, 2048), (2048, 1), 0), out=buf755) | |
| # Topologically Sorted Source Nodes: [all_reduce_71], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf754, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_71], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf754) | |
| buf760 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_727], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_16.run(buf754, buf760, 4194304, stream=stream0) | |
| buf768 = buf697; del buf697 # reuse | |
| # Call mix order reduction kernel | |
| # Topologically Sorted Source Nodes: [redistribute, layer_norm, add_155, add_156, convert_element_type_728, convert_element_type_730, mul_309, mul_310, sum_61, mul_311, sum_62, mul_312, sub_71, sub_72, div_15, mul_313, mul_314, sum_63, sum_64, convert_element_type_732, add_157], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.add, aten.native_layer_norm_backward] | |
| workspace_12 = workspace_11; del workspace_11 # reuse | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_20.run(buf768, buf741, buf748, buf755, primals_95, add_59, buf710, buf711, workspace_12, 327680, 2048, stream=stream0) | |
| buf765 = workspace_12[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| buf767 = workspace_12[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| del add_59 | |
| del primals_95 | |
| buf775 = buf704; del buf704 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_733, all_reduce_73], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf765, buf775, 2048, stream=stream0) | |
| buf769 = buf698; del buf698 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_734, all_reduce_72], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf767, buf769, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_734, all_reduce_72], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf769, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_72], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf769) | |
| buf774 = buf767; del buf767 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_735], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf769, buf774, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_733, all_reduce_73], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf775, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_73], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf775) | |
| buf780 = buf765; del buf765 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_736], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf775, buf780, 2048, stream=stream0) | |
| buf781 = reinterpret_tensor(buf678, (2048, 8192), (8192, 1), 0); del buf678 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_4], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_3.run(primals_93, buf781, 16777216, stream=stream0) | |
| del primals_93 | |
| buf782 = buf676; del buf676 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_4, x_14, permute_289, mm_121], Original ATen: [aten._to_copy, aten.t, aten.mm] | |
| extern_kernels.mm(buf768, buf781, out=buf782) | |
| buf783 = buf711; del buf711 # reuse | |
| buf784 = buf710; del buf710 # reuse | |
| buf786 = buf755; del buf755 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused__to_copy_native_layer_norm_4.run(add_54, primals_89, primals_90, buf783, buf784, buf786, 327680, 2048, stream=stream0) | |
| del primals_90 | |
| buf787 = reinterpret_tensor(buf781, (2048, 8192), (1, 2048), 0); del buf781 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_2, x], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_3.run(primals_91, buf787, 16777216, stream=stream0) | |
| del primals_91 | |
| buf788 = buf662; del buf662 # reuse | |
| # Topologically Sorted Source Nodes: [x], Original ATen: [aten.addmm] | |
| extern_kernels.mm(buf786, buf787, out=buf788) | |
| buf789 = buf661; del buf661 # reuse | |
| buf803 = buf782; del buf782 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_3, x, x_13, convert_element_type_743, mul_320, mul_321, sub_73, mul_322, add_160, mul_323, mul_324, mul_325, add_161, mul_326, convert_element_type_745], Original ATen: [aten._to_copy, aten.addmm, aten.gelu, aten.gelu_backward] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_addmm_gelu_gelu_backward_5.run(buf803, primals_92, buf788, buf789, 2684354560, stream=stream0) | |
| del primals_92 | |
| buf790 = buf663; del buf663 # reuse | |
| # Topologically Sorted Source Nodes: [permute_290, redistribute_3, x, x_13, mm_122], Original ATen: [aten.t, aten._to_copy, aten.addmm, aten.gelu, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf768, (2048, 327680), (1, 2048), 0), buf789, out=buf790) | |
| buf791 = buf664; del buf664 # reuse | |
| # Topologically Sorted Source Nodes: [sum_65], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_6.run(buf768, buf791, 327680, 2048, stream=stream0) | |
| buf792 = reinterpret_tensor(buf775, (1, 2048), (2048, 1), 0); del buf775 # reuse | |
| # Topologically Sorted Source Nodes: [sum_65], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_7.run(buf791, buf792, 2048, 160, stream=stream0) | |
| # Topologically Sorted Source Nodes: [view_351, all_reduce_74], Original ATen: [aten.view, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf792, (2048, ), (1, ), 0), 'avg', '0') | |
| # Topologically Sorted Source Nodes: [view_351, wait_tensor_74], Original ATen: [aten.view, _c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf792, (2048, ), (1, ), 0)) | |
| buf797 = empty_strided_cuda((2048, ), (1, ), torch.float32) | |
| # Topologically Sorted Source Nodes: [view_351, convert_element_type_741], Original ATen: [aten.view, aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf792, buf797, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [all_reduce_75], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf790, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_75], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf790) | |
| buf802 = empty_strided_cuda((2048, 8192), (8192, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_742], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_8.run(buf790, buf802, 16777216, stream=stream0) | |
| buf804 = buf748; del buf748 # reuse | |
| # Topologically Sorted Source Nodes: [permute_293, mm_123], Original ATen: [aten.t, aten.mm] | |
| extern_kernels.mm(buf803, reinterpret_tensor(buf787, (8192, 2048), (2048, 1), 0), out=buf804) | |
| buf805 = reinterpret_tensor(buf787, (8192, 2048), (2048, 1), 0); del buf787 # reuse | |
| # Topologically Sorted Source Nodes: [permute_294, mm_124], Original ATen: [aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf803, (8192, 327680), (1, 8192), 0), buf786, out=buf805) | |
| buf806 = buf679; del buf679 # reuse | |
| # Topologically Sorted Source Nodes: [sum_66], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_9.run(buf803, buf806, 1310720, 2048, stream=stream0) | |
| buf807 = buf680; del buf680 # reuse | |
| # Topologically Sorted Source Nodes: [sum_66], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_10.run(buf806, buf807, 8192, 160, stream=stream0) | |
| # Topologically Sorted Source Nodes: [view_352, all_reduce_76], Original ATen: [aten.view, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf807, (8192, ), (1, ), 0), 'avg', '0') | |
| # Topologically Sorted Source Nodes: [view_352, wait_tensor_76], Original ATen: [aten.view, _c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf807, (8192, ), (1, ), 0)) | |
| buf812 = empty_strided_cuda((8192, ), (1, ), torch.float32) | |
| # Topologically Sorted Source Nodes: [view_352, convert_element_type_750], Original ATen: [aten.view, aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_view_11.run(buf807, buf812, 8192, stream=stream0) | |
| # Topologically Sorted Source Nodes: [all_reduce_77], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf805, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_77], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf805) | |
| buf817 = empty_strided_cuda((8192, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_751], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_8.run(buf805, buf817, 16777216, stream=stream0) | |
| buf824 = buf768; del buf768 # reuse | |
| # Call mix order reduction kernel | |
| # Topologically Sorted Source Nodes: [redistribute, layer_norm, convert_element_type_752, convert_element_type_754, mul_328, mul_329, sum_67, mul_330, sum_68, mul_331, sub_75, sub_76, div_16, mul_332, mul_333, sum_69, sum_70, convert_element_type_756, add_162], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.native_layer_norm_backward, aten.add] | |
| workspace_13 = workspace_12; del workspace_12 # reuse | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_12.run(buf824, buf804, primals_89, add_54, buf783, buf784, workspace_13, 327680, 2048, stream=stream0) | |
| buf821 = workspace_13[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| buf823 = workspace_13[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| del add_54 | |
| del primals_89 | |
| buf831 = reinterpret_tensor(buf792, (2048, ), (1, ), 0); del buf792 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_757, all_reduce_79], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf821, buf831, 2048, stream=stream0) | |
| buf825 = buf769; del buf769 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_758, all_reduce_78], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf823, buf825, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_758, all_reduce_78], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf825, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_78], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf825) | |
| buf830 = buf823; del buf823 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_759], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf825, buf830, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_757, all_reduce_79], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf831, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_79], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf831) | |
| buf836 = buf821; del buf821 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_760], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf831, buf836, 2048, stream=stream0) | |
| buf837 = buf784; del buf784 # reuse | |
| buf838 = buf783; del buf783 # reuse | |
| buf840 = buf804; del buf804 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused__to_copy_native_layer_norm_4.run(add_51, primals_83, primals_84, buf837, buf838, buf840, 327680, 2048, stream=stream0) | |
| del primals_84 | |
| buf841 = reinterpret_tensor(buf754, (2048, 2048), (1, 2048), 0); del buf754 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_2, linear], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_13.run(primals_85, buf841, 4194304, stream=stream0) | |
| del primals_85 | |
| buf842 = buf786; del buf786 # reuse | |
| # Topologically Sorted Source Nodes: [linear], Original ATen: [aten.mm] | |
| extern_kernels.mm(buf840, buf841, out=buf842) | |
| buf843 = reinterpret_tensor(buf747, (2048, 512), (1, 2048), 0); del buf747 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_3, linear_1], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_14.run(primals_86, buf843, 1048576, stream=stream0) | |
| del primals_86 | |
| buf844 = reinterpret_tensor(buf737, (327680, 512), (512, 1), 0); del buf737 # reuse | |
| # Topologically Sorted Source Nodes: [linear_1], Original ATen: [aten.mm] | |
| extern_kernels.mm(buf840, buf843, out=buf844) | |
| buf845 = buf716; del buf716 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_4, linear_2], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_14.run(primals_87, buf845, 1048576, stream=stream0) | |
| del primals_87 | |
| buf846 = reinterpret_tensor(buf736, (327680, 512), (512, 1), 0); del buf736 # reuse | |
| # Topologically Sorted Source Nodes: [linear_2], Original ATen: [aten.mm] | |
| extern_kernels.mm(buf840, buf845, out=buf846) | |
| buf847 = buf734; del buf734 # reuse | |
| buf848 = buf720; del buf720 # reuse | |
| buf849 = reinterpret_tensor(buf741, (1, 16, 327680, 128), (671088640, 128, 2048, 1), 0); del buf741 # reuse | |
| # Topologically Sorted Source Nodes: [q, k, v, flex_attention], Original ATen: [aten.view, aten.permute, flex_attention] | |
| stream0 = get_raw_stream(0) | |
| triton_tem_fused_flex_attention_permute_view_15.run(buf842, buf844, buf846, buf847, buf848, primals_18, primals_17, primals_19, primals_21, primals_14, primals_15, buf849, s91, 2560, 1, 16, stream=stream0) | |
| buf852 = reinterpret_tensor(buf714, (2048, 2048), (2048, 1), 0); del buf714 # reuse | |
| # Topologically Sorted Source Nodes: [permute_297, rearrange_3, o, mm_125], Original ATen: [aten.t, aten.permute, aten.view, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf824, (2048, 327680), (1, 2048), 0), reinterpret_tensor(buf849, (327680, 2048), (2048, 1), 0), out=buf852) | |
| buf853 = buf726; del buf726 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_5], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_13.run(primals_88, buf853, 4194304, stream=stream0) | |
| del primals_88 | |
| buf854 = reinterpret_tensor(buf735, (327680, 2048), (2048, 1), 0); del buf735 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_5, o, permute_299, mm_126], Original ATen: [aten._to_copy, aten.t, aten.mm] | |
| extern_kernels.mm(buf824, buf853, out=buf854) | |
| # Topologically Sorted Source Nodes: [all_reduce_80], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf852, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_80], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf852) | |
| buf859 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_765], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_16.run(buf852, buf859, 4194304, stream=stream0) | |
| buf861 = buf848; del buf848 # reuse | |
| # Topologically Sorted Source Nodes: [q, k, v, view_355, view_356, permute_301, flex_attention_backward_6], Original ATen: [aten.view, aten.permute, flex_attention_backward] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused_flex_attention_backward_permute_view_17.run(buf849, buf854, buf861, 5242880, 128, stream=stream0) | |
| buf862 = buf849; del buf849 # reuse | |
| buf863 = reinterpret_tensor(buf719, (1, 4, 327680, 128), (167772160, 128, 512, 1), 0); del buf719 # reuse | |
| buf864 = reinterpret_tensor(buf717, (1, 4, 327680, 128), (167772160, 128, 512, 1), 0); del buf717 # reuse | |
| # Topologically Sorted Source Nodes: [q, k, v, view_355, view_356, permute_301, flex_attention_backward_6], Original ATen: [aten.view, aten.permute, flex_attention_backward] | |
| stream0 = get_raw_stream(0) | |
| triton_tem_fused_flex_attention_backward_permute_view_18.run(buf842, buf844, buf846, buf847, buf861, buf854, buf862, buf863, primals_18, primals_17, primals_22, primals_24, primals_19, primals_21, primals_25, primals_27, primals_14, primals_15, buf864, s91, s16, 12800, 1, 4, stream=stream0) | |
| buf867 = reinterpret_tensor(buf718, (512, 2048), (2048, 1), 0); del buf718 # reuse | |
| # Topologically Sorted Source Nodes: [view_357, permute_302, view_358, permute_303, mm_127], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf863, (512, 327680), (1, 512), 0), buf840, out=buf867) | |
| buf868 = buf854; del buf854 # reuse | |
| # Topologically Sorted Source Nodes: [view_357, permute_302, view_358, permute_305, mm_128], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf863, (327680, 512), (512, 1), 0), reinterpret_tensor(buf845, (512, 2048), (2048, 1), 0), out=buf868) | |
| # Topologically Sorted Source Nodes: [all_reduce_81], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf867, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_81], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf867) | |
| buf873 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_770], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_19.run(buf867, buf873, 1048576, stream=stream0) | |
| buf874 = buf867; del buf867 # reuse | |
| # Topologically Sorted Source Nodes: [view_359, permute_307, view_360, permute_308, mm_129], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf864, (512, 327680), (1, 512), 0), buf840, out=buf874) | |
| buf875 = buf842; del buf842 # reuse | |
| # Topologically Sorted Source Nodes: [view_359, permute_307, view_360, permute_310, mm_130], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf864, (327680, 512), (512, 1), 0), reinterpret_tensor(buf843, (512, 2048), (2048, 1), 0), out=buf875) | |
| # Topologically Sorted Source Nodes: [all_reduce_82], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf874, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_82], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf874) | |
| buf880 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_775], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_19.run(buf874, buf880, 1048576, stream=stream0) | |
| buf881 = buf852; del buf852 # reuse | |
| # Topologically Sorted Source Nodes: [view_361, permute_312, view_362, permute_313, mm_131], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf862, (2048, 327680), (1, 2048), 0), buf840, out=buf881) | |
| buf882 = buf840; del buf840 # reuse | |
| # Topologically Sorted Source Nodes: [view_361, permute_312, view_362, permute_315, mm_132], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf862, (327680, 2048), (2048, 1), 0), reinterpret_tensor(buf841, (2048, 2048), (2048, 1), 0), out=buf882) | |
| # Topologically Sorted Source Nodes: [all_reduce_83], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf881, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_83], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf881) | |
| buf887 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_780], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_16.run(buf881, buf887, 4194304, stream=stream0) | |
| buf895 = buf824; del buf824 # reuse | |
| # Call mix order reduction kernel | |
| # Topologically Sorted Source Nodes: [redistribute, layer_norm, add_163, add_164, convert_element_type_781, convert_element_type_783, mul_335, mul_336, sum_71, mul_337, sum_72, mul_338, sub_78, sub_79, div_17, mul_339, mul_340, sum_73, sum_74, convert_element_type_785, add_165], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.add, aten.native_layer_norm_backward] | |
| workspace_14 = workspace_13; del workspace_13 # reuse | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_20.run(buf895, buf868, buf875, buf882, primals_83, add_51, buf837, buf838, workspace_14, 327680, 2048, stream=stream0) | |
| buf892 = workspace_14[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| buf894 = workspace_14[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| del add_51 | |
| del primals_83 | |
| buf902 = buf831; del buf831 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_786, all_reduce_85], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf892, buf902, 2048, stream=stream0) | |
| buf896 = buf825; del buf825 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_787, all_reduce_84], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf894, buf896, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_787, all_reduce_84], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf896, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_84], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf896) | |
| buf901 = buf894; del buf894 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_788], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf896, buf901, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_786, all_reduce_85], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf902, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_85], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf902) | |
| buf907 = buf892; del buf892 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_789], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf902, buf907, 2048, stream=stream0) | |
| buf908 = reinterpret_tensor(buf805, (2048, 8192), (8192, 1), 0); del buf805 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_4], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_3.run(primals_81, buf908, 16777216, stream=stream0) | |
| del primals_81 | |
| buf909 = buf803; del buf803 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_4, x_12, permute_317, mm_133], Original ATen: [aten._to_copy, aten.t, aten.mm] | |
| extern_kernels.mm(buf895, buf908, out=buf909) | |
| buf910 = buf838; del buf838 # reuse | |
| buf911 = buf837; del buf837 # reuse | |
| buf913 = buf882; del buf882 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused__to_copy_native_layer_norm_4.run(add_46, primals_77, primals_78, buf910, buf911, buf913, 327680, 2048, stream=stream0) | |
| del primals_78 | |
| buf914 = reinterpret_tensor(buf908, (2048, 8192), (1, 2048), 0); del buf908 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_2, x], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_3.run(primals_79, buf914, 16777216, stream=stream0) | |
| del primals_79 | |
| buf915 = buf789; del buf789 # reuse | |
| # Topologically Sorted Source Nodes: [x], Original ATen: [aten.addmm] | |
| extern_kernels.mm(buf913, buf914, out=buf915) | |
| buf916 = buf788; del buf788 # reuse | |
| buf930 = buf909; del buf909 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_3, x, x_11, convert_element_type_796, mul_346, mul_347, sub_80, mul_348, add_168, mul_349, mul_350, mul_351, add_169, mul_352, convert_element_type_798], Original ATen: [aten._to_copy, aten.addmm, aten.gelu, aten.gelu_backward] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_addmm_gelu_gelu_backward_5.run(buf930, primals_80, buf915, buf916, 2684354560, stream=stream0) | |
| del primals_80 | |
| buf917 = buf790; del buf790 # reuse | |
| # Topologically Sorted Source Nodes: [permute_318, redistribute_3, x, x_11, mm_134], Original ATen: [aten.t, aten._to_copy, aten.addmm, aten.gelu, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf895, (2048, 327680), (1, 2048), 0), buf916, out=buf917) | |
| buf918 = buf791; del buf791 # reuse | |
| # Topologically Sorted Source Nodes: [sum_75], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_6.run(buf895, buf918, 327680, 2048, stream=stream0) | |
| buf919 = reinterpret_tensor(buf902, (1, 2048), (2048, 1), 0); del buf902 # reuse | |
| # Topologically Sorted Source Nodes: [sum_75], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_7.run(buf918, buf919, 2048, 160, stream=stream0) | |
| # Topologically Sorted Source Nodes: [view_363, all_reduce_86], Original ATen: [aten.view, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf919, (2048, ), (1, ), 0), 'avg', '0') | |
| # Topologically Sorted Source Nodes: [view_363, wait_tensor_86], Original ATen: [aten.view, _c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf919, (2048, ), (1, ), 0)) | |
| buf924 = empty_strided_cuda((2048, ), (1, ), torch.float32) | |
| # Topologically Sorted Source Nodes: [view_363, convert_element_type_794], Original ATen: [aten.view, aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf919, buf924, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [all_reduce_87], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf917, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_87], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf917) | |
| buf929 = empty_strided_cuda((2048, 8192), (8192, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_795], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_8.run(buf917, buf929, 16777216, stream=stream0) | |
| buf931 = buf875; del buf875 # reuse | |
| # Topologically Sorted Source Nodes: [permute_321, mm_135], Original ATen: [aten.t, aten.mm] | |
| extern_kernels.mm(buf930, reinterpret_tensor(buf914, (8192, 2048), (2048, 1), 0), out=buf931) | |
| buf932 = reinterpret_tensor(buf914, (8192, 2048), (2048, 1), 0); del buf914 # reuse | |
| # Topologically Sorted Source Nodes: [permute_322, mm_136], Original ATen: [aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf930, (8192, 327680), (1, 8192), 0), buf913, out=buf932) | |
| buf933 = buf806; del buf806 # reuse | |
| # Topologically Sorted Source Nodes: [sum_76], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_9.run(buf930, buf933, 1310720, 2048, stream=stream0) | |
| buf934 = buf807; del buf807 # reuse | |
| # Topologically Sorted Source Nodes: [sum_76], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_10.run(buf933, buf934, 8192, 160, stream=stream0) | |
| # Topologically Sorted Source Nodes: [view_364, all_reduce_88], Original ATen: [aten.view, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf934, (8192, ), (1, ), 0), 'avg', '0') | |
| # Topologically Sorted Source Nodes: [view_364, wait_tensor_88], Original ATen: [aten.view, _c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf934, (8192, ), (1, ), 0)) | |
| buf939 = empty_strided_cuda((8192, ), (1, ), torch.float32) | |
| # Topologically Sorted Source Nodes: [view_364, convert_element_type_803], Original ATen: [aten.view, aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_view_11.run(buf934, buf939, 8192, stream=stream0) | |
| # Topologically Sorted Source Nodes: [all_reduce_89], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf932, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_89], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf932) | |
| buf944 = empty_strided_cuda((8192, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_804], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_8.run(buf932, buf944, 16777216, stream=stream0) | |
| buf951 = buf895; del buf895 # reuse | |
| # Call mix order reduction kernel | |
| # Topologically Sorted Source Nodes: [redistribute, layer_norm, convert_element_type_805, convert_element_type_807, mul_354, mul_355, sum_77, mul_356, sum_78, mul_357, sub_82, sub_83, div_18, mul_358, mul_359, sum_79, sum_80, convert_element_type_809, add_170], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.native_layer_norm_backward, aten.add] | |
| workspace_15 = workspace_14; del workspace_14 # reuse | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_12.run(buf951, buf931, primals_77, add_46, buf910, buf911, workspace_15, 327680, 2048, stream=stream0) | |
| buf948 = workspace_15[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| buf950 = workspace_15[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| del add_46 | |
| del primals_77 | |
| buf958 = reinterpret_tensor(buf919, (2048, ), (1, ), 0); del buf919 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_810, all_reduce_91], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf948, buf958, 2048, stream=stream0) | |
| buf952 = buf896; del buf896 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_811, all_reduce_90], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf950, buf952, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_811, all_reduce_90], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf952, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_90], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf952) | |
| buf957 = buf950; del buf950 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_812], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf952, buf957, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_810, all_reduce_91], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf958, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_91], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf958) | |
| buf963 = buf948; del buf948 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_813], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf958, buf963, 2048, stream=stream0) | |
| buf964 = buf911; del buf911 # reuse | |
| buf965 = buf910; del buf910 # reuse | |
| buf967 = buf931; del buf931 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused__to_copy_native_layer_norm_4.run(add_43, primals_71, primals_72, buf964, buf965, buf967, 327680, 2048, stream=stream0) | |
| del primals_72 | |
| buf968 = reinterpret_tensor(buf881, (2048, 2048), (1, 2048), 0); del buf881 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_2, linear], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_13.run(primals_73, buf968, 4194304, stream=stream0) | |
| del primals_73 | |
| buf969 = buf913; del buf913 # reuse | |
| # Topologically Sorted Source Nodes: [linear], Original ATen: [aten.mm] | |
| extern_kernels.mm(buf967, buf968, out=buf969) | |
| buf970 = reinterpret_tensor(buf874, (2048, 512), (1, 2048), 0); del buf874 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_3, linear_1], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_14.run(primals_74, buf970, 1048576, stream=stream0) | |
| del primals_74 | |
| buf971 = reinterpret_tensor(buf864, (327680, 512), (512, 1), 0); del buf864 # reuse | |
| # Topologically Sorted Source Nodes: [linear_1], Original ATen: [aten.mm] | |
| extern_kernels.mm(buf967, buf970, out=buf971) | |
| buf972 = buf843; del buf843 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_4, linear_2], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_14.run(primals_75, buf972, 1048576, stream=stream0) | |
| del primals_75 | |
| buf973 = reinterpret_tensor(buf863, (327680, 512), (512, 1), 0); del buf863 # reuse | |
| # Topologically Sorted Source Nodes: [linear_2], Original ATen: [aten.mm] | |
| extern_kernels.mm(buf967, buf972, out=buf973) | |
| buf974 = buf861; del buf861 # reuse | |
| buf975 = buf847; del buf847 # reuse | |
| buf976 = reinterpret_tensor(buf868, (1, 16, 327680, 128), (671088640, 128, 2048, 1), 0); del buf868 # reuse | |
| # Topologically Sorted Source Nodes: [q, k, v, flex_attention], Original ATen: [aten.view, aten.permute, flex_attention] | |
| stream0 = get_raw_stream(0) | |
| triton_tem_fused_flex_attention_permute_view_15.run(buf969, buf971, buf973, buf974, buf975, primals_18, primals_17, primals_19, primals_21, primals_14, primals_15, buf976, s91, 2560, 1, 16, stream=stream0) | |
| buf979 = reinterpret_tensor(buf841, (2048, 2048), (2048, 1), 0); del buf841 # reuse | |
| # Topologically Sorted Source Nodes: [permute_325, rearrange_3, o, mm_137], Original ATen: [aten.t, aten.permute, aten.view, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf951, (2048, 327680), (1, 2048), 0), reinterpret_tensor(buf976, (327680, 2048), (2048, 1), 0), out=buf979) | |
| buf980 = buf853; del buf853 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_5], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_13.run(primals_76, buf980, 4194304, stream=stream0) | |
| del primals_76 | |
| buf981 = reinterpret_tensor(buf862, (327680, 2048), (2048, 1), 0); del buf862 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_5, o, permute_327, mm_138], Original ATen: [aten._to_copy, aten.t, aten.mm] | |
| extern_kernels.mm(buf951, buf980, out=buf981) | |
| # Topologically Sorted Source Nodes: [all_reduce_92], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf979, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_92], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf979) | |
| buf986 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_818], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_16.run(buf979, buf986, 4194304, stream=stream0) | |
| buf988 = buf975; del buf975 # reuse | |
| # Topologically Sorted Source Nodes: [q, k, v, view_367, view_368, permute_329, flex_attention_backward_7], Original ATen: [aten.view, aten.permute, flex_attention_backward] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused_flex_attention_backward_permute_view_17.run(buf976, buf981, buf988, 5242880, 128, stream=stream0) | |
| buf989 = buf976; del buf976 # reuse | |
| buf990 = reinterpret_tensor(buf846, (1, 4, 327680, 128), (167772160, 128, 512, 1), 0); del buf846 # reuse | |
| buf991 = reinterpret_tensor(buf844, (1, 4, 327680, 128), (167772160, 128, 512, 1), 0); del buf844 # reuse | |
| # Topologically Sorted Source Nodes: [q, k, v, view_367, view_368, permute_329, flex_attention_backward_7], Original ATen: [aten.view, aten.permute, flex_attention_backward] | |
| stream0 = get_raw_stream(0) | |
| triton_tem_fused_flex_attention_backward_permute_view_18.run(buf969, buf971, buf973, buf974, buf988, buf981, buf989, buf990, primals_18, primals_17, primals_22, primals_24, primals_19, primals_21, primals_25, primals_27, primals_14, primals_15, buf991, s91, s16, 12800, 1, 4, stream=stream0) | |
| buf994 = reinterpret_tensor(buf845, (512, 2048), (2048, 1), 0); del buf845 # reuse | |
| # Topologically Sorted Source Nodes: [view_369, permute_330, view_370, permute_331, mm_139], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf990, (512, 327680), (1, 512), 0), buf967, out=buf994) | |
| buf995 = buf981; del buf981 # reuse | |
| # Topologically Sorted Source Nodes: [view_369, permute_330, view_370, permute_333, mm_140], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf990, (327680, 512), (512, 1), 0), reinterpret_tensor(buf972, (512, 2048), (2048, 1), 0), out=buf995) | |
| # Topologically Sorted Source Nodes: [all_reduce_93], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf994, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_93], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf994) | |
| buf1000 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_823], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_19.run(buf994, buf1000, 1048576, stream=stream0) | |
| buf1001 = buf994; del buf994 # reuse | |
| # Topologically Sorted Source Nodes: [view_371, permute_335, view_372, permute_336, mm_141], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf991, (512, 327680), (1, 512), 0), buf967, out=buf1001) | |
| buf1002 = buf969; del buf969 # reuse | |
| # Topologically Sorted Source Nodes: [view_371, permute_335, view_372, permute_338, mm_142], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf991, (327680, 512), (512, 1), 0), reinterpret_tensor(buf970, (512, 2048), (2048, 1), 0), out=buf1002) | |
| # Topologically Sorted Source Nodes: [all_reduce_94], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1001, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_94], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1001) | |
| buf1007 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_828], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_19.run(buf1001, buf1007, 1048576, stream=stream0) | |
| buf1008 = buf979; del buf979 # reuse | |
| # Topologically Sorted Source Nodes: [view_373, permute_340, view_374, permute_341, mm_143], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf989, (2048, 327680), (1, 2048), 0), buf967, out=buf1008) | |
| buf1009 = buf967; del buf967 # reuse | |
| # Topologically Sorted Source Nodes: [view_373, permute_340, view_374, permute_343, mm_144], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf989, (327680, 2048), (2048, 1), 0), reinterpret_tensor(buf968, (2048, 2048), (2048, 1), 0), out=buf1009) | |
| # Topologically Sorted Source Nodes: [all_reduce_95], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1008, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_95], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1008) | |
| buf1014 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_833], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_16.run(buf1008, buf1014, 4194304, stream=stream0) | |
| buf1022 = buf951; del buf951 # reuse | |
| # Call mix order reduction kernel | |
| # Topologically Sorted Source Nodes: [redistribute, layer_norm, add_171, add_172, convert_element_type_834, convert_element_type_836, mul_361, mul_362, sum_81, mul_363, sum_82, mul_364, sub_85, sub_86, div_19, mul_365, mul_366, sum_83, sum_84, convert_element_type_838, add_173], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.add, aten.native_layer_norm_backward] | |
| workspace_16 = workspace_15; del workspace_15 # reuse | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_20.run(buf1022, buf995, buf1002, buf1009, primals_71, add_43, buf964, buf965, workspace_16, 327680, 2048, stream=stream0) | |
| buf1019 = workspace_16[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| buf1021 = workspace_16[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| del add_43 | |
| del primals_71 | |
| buf1029 = buf958; del buf958 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_839, all_reduce_97], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf1019, buf1029, 2048, stream=stream0) | |
| buf1023 = buf952; del buf952 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_840, all_reduce_96], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf1021, buf1023, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_840, all_reduce_96], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1023, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_96], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1023) | |
| buf1028 = buf1021; del buf1021 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_841], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf1023, buf1028, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_839, all_reduce_97], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1029, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_97], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1029) | |
| buf1034 = buf1019; del buf1019 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_842], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf1029, buf1034, 2048, stream=stream0) | |
| buf1035 = reinterpret_tensor(buf932, (2048, 8192), (8192, 1), 0); del buf932 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_4], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_3.run(primals_69, buf1035, 16777216, stream=stream0) | |
| del primals_69 | |
| buf1036 = buf930; del buf930 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_4, x_10, permute_345, mm_145], Original ATen: [aten._to_copy, aten.t, aten.mm] | |
| extern_kernels.mm(buf1022, buf1035, out=buf1036) | |
| buf1037 = buf965; del buf965 # reuse | |
| buf1038 = buf964; del buf964 # reuse | |
| buf1040 = buf995; del buf995 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused__to_copy_native_layer_norm_4.run(add_38, primals_65, primals_66, buf1037, buf1038, buf1040, 327680, 2048, stream=stream0) | |
| del primals_66 | |
| buf1041 = reinterpret_tensor(buf1035, (2048, 8192), (1, 2048), 0); del buf1035 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_2, x], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_3.run(primals_67, buf1041, 16777216, stream=stream0) | |
| del primals_67 | |
| buf1042 = buf916; del buf916 # reuse | |
| # Topologically Sorted Source Nodes: [x], Original ATen: [aten.addmm] | |
| extern_kernels.mm(buf1040, buf1041, out=buf1042) | |
| buf1043 = buf915; del buf915 # reuse | |
| buf1057 = buf1036; del buf1036 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_3, x, x_9, convert_element_type_849, mul_372, mul_373, sub_87, mul_374, add_176, mul_375, mul_376, mul_377, add_177, mul_378, convert_element_type_851], Original ATen: [aten._to_copy, aten.addmm, aten.gelu, aten.gelu_backward] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_addmm_gelu_gelu_backward_5.run(buf1057, primals_68, buf1042, buf1043, 2684354560, stream=stream0) | |
| del primals_68 | |
| buf1044 = buf917; del buf917 # reuse | |
| # Topologically Sorted Source Nodes: [permute_346, redistribute_3, x, x_9, mm_146], Original ATen: [aten.t, aten._to_copy, aten.addmm, aten.gelu, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf1022, (2048, 327680), (1, 2048), 0), buf1043, out=buf1044) | |
| buf1045 = buf918; del buf918 # reuse | |
| # Topologically Sorted Source Nodes: [sum_85], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_6.run(buf1022, buf1045, 327680, 2048, stream=stream0) | |
| buf1046 = reinterpret_tensor(buf1029, (1, 2048), (2048, 1), 0); del buf1029 # reuse | |
| # Topologically Sorted Source Nodes: [sum_85], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_7.run(buf1045, buf1046, 2048, 160, stream=stream0) | |
| # Topologically Sorted Source Nodes: [view_375, all_reduce_98], Original ATen: [aten.view, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf1046, (2048, ), (1, ), 0), 'avg', '0') | |
| # Topologically Sorted Source Nodes: [view_375, wait_tensor_98], Original ATen: [aten.view, _c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf1046, (2048, ), (1, ), 0)) | |
| buf1051 = empty_strided_cuda((2048, ), (1, ), torch.float32) | |
| # Topologically Sorted Source Nodes: [view_375, convert_element_type_847], Original ATen: [aten.view, aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf1046, buf1051, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [all_reduce_99], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1044, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_99], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1044) | |
| buf1056 = empty_strided_cuda((2048, 8192), (8192, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_848], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_8.run(buf1044, buf1056, 16777216, stream=stream0) | |
| buf1058 = buf1009; del buf1009 # reuse | |
| # Topologically Sorted Source Nodes: [permute_349, mm_147], Original ATen: [aten.t, aten.mm] | |
| extern_kernels.mm(buf1057, reinterpret_tensor(buf1041, (8192, 2048), (2048, 1), 0), out=buf1058) | |
| buf1059 = reinterpret_tensor(buf1041, (8192, 2048), (2048, 1), 0); del buf1041 # reuse | |
| # Topologically Sorted Source Nodes: [permute_350, mm_148], Original ATen: [aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf1057, (8192, 327680), (1, 8192), 0), buf1040, out=buf1059) | |
| buf1060 = buf933; del buf933 # reuse | |
| # Topologically Sorted Source Nodes: [sum_86], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_9.run(buf1057, buf1060, 1310720, 2048, stream=stream0) | |
| buf1061 = buf934; del buf934 # reuse | |
| # Topologically Sorted Source Nodes: [sum_86], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_10.run(buf1060, buf1061, 8192, 160, stream=stream0) | |
| # Topologically Sorted Source Nodes: [view_376, all_reduce_100], Original ATen: [aten.view, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf1061, (8192, ), (1, ), 0), 'avg', '0') | |
| # Topologically Sorted Source Nodes: [view_376, wait_tensor_100], Original ATen: [aten.view, _c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf1061, (8192, ), (1, ), 0)) | |
| buf1066 = empty_strided_cuda((8192, ), (1, ), torch.float32) | |
| # Topologically Sorted Source Nodes: [view_376, convert_element_type_856], Original ATen: [aten.view, aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_view_11.run(buf1061, buf1066, 8192, stream=stream0) | |
| # Topologically Sorted Source Nodes: [all_reduce_101], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1059, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_101], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1059) | |
| buf1071 = empty_strided_cuda((8192, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_857], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_8.run(buf1059, buf1071, 16777216, stream=stream0) | |
| buf1078 = buf1022; del buf1022 # reuse | |
| # Call mix order reduction kernel | |
| # Topologically Sorted Source Nodes: [redistribute, layer_norm, convert_element_type_858, convert_element_type_860, mul_380, mul_381, sum_87, mul_382, sum_88, mul_383, sub_89, sub_90, div_20, mul_384, mul_385, sum_89, sum_90, convert_element_type_862, add_178], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.native_layer_norm_backward, aten.add] | |
| workspace_17 = workspace_16; del workspace_16 # reuse | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_12.run(buf1078, buf1058, primals_65, add_38, buf1037, buf1038, workspace_17, 327680, 2048, stream=stream0) | |
| buf1075 = workspace_17[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| buf1077 = workspace_17[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| del add_38 | |
| del primals_65 | |
| buf1085 = reinterpret_tensor(buf1046, (2048, ), (1, ), 0); del buf1046 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_863, all_reduce_103], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf1075, buf1085, 2048, stream=stream0) | |
| buf1079 = buf1023; del buf1023 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_864, all_reduce_102], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf1077, buf1079, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_864, all_reduce_102], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1079, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_102], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1079) | |
| buf1084 = buf1077; del buf1077 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_865], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf1079, buf1084, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_863, all_reduce_103], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1085, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_103], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1085) | |
| buf1090 = buf1075; del buf1075 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_866], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf1085, buf1090, 2048, stream=stream0) | |
| buf1091 = buf1038; del buf1038 # reuse | |
| buf1092 = buf1037; del buf1037 # reuse | |
| buf1094 = buf1058; del buf1058 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused__to_copy_native_layer_norm_4.run(add_35, primals_59, primals_60, buf1091, buf1092, buf1094, 327680, 2048, stream=stream0) | |
| del primals_60 | |
| buf1095 = reinterpret_tensor(buf1008, (2048, 2048), (1, 2048), 0); del buf1008 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_2, linear], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_13.run(primals_61, buf1095, 4194304, stream=stream0) | |
| del primals_61 | |
| buf1096 = buf1040; del buf1040 # reuse | |
| # Topologically Sorted Source Nodes: [linear], Original ATen: [aten.mm] | |
| extern_kernels.mm(buf1094, buf1095, out=buf1096) | |
| buf1097 = reinterpret_tensor(buf1001, (2048, 512), (1, 2048), 0); del buf1001 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_3, linear_1], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_14.run(primals_62, buf1097, 1048576, stream=stream0) | |
| del primals_62 | |
| buf1098 = reinterpret_tensor(buf991, (327680, 512), (512, 1), 0); del buf991 # reuse | |
| # Topologically Sorted Source Nodes: [linear_1], Original ATen: [aten.mm] | |
| extern_kernels.mm(buf1094, buf1097, out=buf1098) | |
| buf1099 = buf970; del buf970 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_4, linear_2], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_14.run(primals_63, buf1099, 1048576, stream=stream0) | |
| del primals_63 | |
| buf1100 = reinterpret_tensor(buf990, (327680, 512), (512, 1), 0); del buf990 # reuse | |
| # Topologically Sorted Source Nodes: [linear_2], Original ATen: [aten.mm] | |
| extern_kernels.mm(buf1094, buf1099, out=buf1100) | |
| buf1101 = buf988; del buf988 # reuse | |
| buf1102 = buf974; del buf974 # reuse | |
| buf1103 = reinterpret_tensor(buf1002, (1, 16, 327680, 128), (671088640, 128, 2048, 1), 0); del buf1002 # reuse | |
| # Topologically Sorted Source Nodes: [q, k, v, flex_attention], Original ATen: [aten.view, aten.permute, flex_attention] | |
| stream0 = get_raw_stream(0) | |
| triton_tem_fused_flex_attention_permute_view_15.run(buf1096, buf1098, buf1100, buf1101, buf1102, primals_18, primals_17, primals_19, primals_21, primals_14, primals_15, buf1103, s91, 2560, 1, 16, stream=stream0) | |
| buf1106 = reinterpret_tensor(buf968, (2048, 2048), (2048, 1), 0); del buf968 # reuse | |
| # Topologically Sorted Source Nodes: [permute_353, rearrange_3, o, mm_149], Original ATen: [aten.t, aten.permute, aten.view, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf1078, (2048, 327680), (1, 2048), 0), reinterpret_tensor(buf1103, (327680, 2048), (2048, 1), 0), out=buf1106) | |
| buf1107 = buf980; del buf980 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_5], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_13.run(primals_64, buf1107, 4194304, stream=stream0) | |
| del primals_64 | |
| buf1108 = reinterpret_tensor(buf989, (327680, 2048), (2048, 1), 0); del buf989 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_5, o, permute_355, mm_150], Original ATen: [aten._to_copy, aten.t, aten.mm] | |
| extern_kernels.mm(buf1078, buf1107, out=buf1108) | |
| # Topologically Sorted Source Nodes: [all_reduce_104], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1106, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_104], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1106) | |
| buf1113 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_871], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_16.run(buf1106, buf1113, 4194304, stream=stream0) | |
| buf1115 = buf1102; del buf1102 # reuse | |
| # Topologically Sorted Source Nodes: [q, k, v, view_379, view_380, permute_357, flex_attention_backward_8], Original ATen: [aten.view, aten.permute, flex_attention_backward] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused_flex_attention_backward_permute_view_17.run(buf1103, buf1108, buf1115, 5242880, 128, stream=stream0) | |
| buf1116 = buf1103; del buf1103 # reuse | |
| buf1117 = reinterpret_tensor(buf973, (1, 4, 327680, 128), (167772160, 128, 512, 1), 0); del buf973 # reuse | |
| buf1118 = reinterpret_tensor(buf971, (1, 4, 327680, 128), (167772160, 128, 512, 1), 0); del buf971 # reuse | |
| # Topologically Sorted Source Nodes: [q, k, v, view_379, view_380, permute_357, flex_attention_backward_8], Original ATen: [aten.view, aten.permute, flex_attention_backward] | |
| stream0 = get_raw_stream(0) | |
| triton_tem_fused_flex_attention_backward_permute_view_18.run(buf1096, buf1098, buf1100, buf1101, buf1115, buf1108, buf1116, buf1117, primals_18, primals_17, primals_22, primals_24, primals_19, primals_21, primals_25, primals_27, primals_14, primals_15, buf1118, s91, s16, 12800, 1, 4, stream=stream0) | |
| buf1121 = reinterpret_tensor(buf972, (512, 2048), (2048, 1), 0); del buf972 # reuse | |
| # Topologically Sorted Source Nodes: [view_381, permute_358, view_382, permute_359, mm_151], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf1117, (512, 327680), (1, 512), 0), buf1094, out=buf1121) | |
| buf1122 = buf1108; del buf1108 # reuse | |
| # Topologically Sorted Source Nodes: [view_381, permute_358, view_382, permute_361, mm_152], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf1117, (327680, 512), (512, 1), 0), reinterpret_tensor(buf1099, (512, 2048), (2048, 1), 0), out=buf1122) | |
| # Topologically Sorted Source Nodes: [all_reduce_105], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1121, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_105], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1121) | |
| buf1127 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_876], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_19.run(buf1121, buf1127, 1048576, stream=stream0) | |
| buf1128 = buf1121; del buf1121 # reuse | |
| # Topologically Sorted Source Nodes: [view_383, permute_363, view_384, permute_364, mm_153], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf1118, (512, 327680), (1, 512), 0), buf1094, out=buf1128) | |
| buf1129 = buf1096; del buf1096 # reuse | |
| # Topologically Sorted Source Nodes: [view_383, permute_363, view_384, permute_366, mm_154], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf1118, (327680, 512), (512, 1), 0), reinterpret_tensor(buf1097, (512, 2048), (2048, 1), 0), out=buf1129) | |
| # Topologically Sorted Source Nodes: [all_reduce_106], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1128, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_106], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1128) | |
| buf1134 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_881], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_19.run(buf1128, buf1134, 1048576, stream=stream0) | |
| buf1135 = buf1106; del buf1106 # reuse | |
| # Topologically Sorted Source Nodes: [view_385, permute_368, view_386, permute_369, mm_155], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf1116, (2048, 327680), (1, 2048), 0), buf1094, out=buf1135) | |
| buf1136 = buf1094; del buf1094 # reuse | |
| # Topologically Sorted Source Nodes: [view_385, permute_368, view_386, permute_371, mm_156], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf1116, (327680, 2048), (2048, 1), 0), reinterpret_tensor(buf1095, (2048, 2048), (2048, 1), 0), out=buf1136) | |
| # Topologically Sorted Source Nodes: [all_reduce_107], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1135, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_107], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1135) | |
| buf1141 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_886], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_16.run(buf1135, buf1141, 4194304, stream=stream0) | |
| buf1149 = buf1078; del buf1078 # reuse | |
| # Call mix order reduction kernel | |
| # Topologically Sorted Source Nodes: [redistribute, layer_norm, add_179, add_180, convert_element_type_887, convert_element_type_889, mul_387, mul_388, sum_91, mul_389, sum_92, mul_390, sub_92, sub_93, div_21, mul_391, mul_392, sum_93, sum_94, convert_element_type_891, add_181], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.add, aten.native_layer_norm_backward] | |
| workspace_18 = workspace_17; del workspace_17 # reuse | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_20.run(buf1149, buf1122, buf1129, buf1136, primals_59, add_35, buf1091, buf1092, workspace_18, 327680, 2048, stream=stream0) | |
| buf1146 = workspace_18[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| buf1148 = workspace_18[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| del add_35 | |
| del primals_59 | |
| buf1156 = buf1085; del buf1085 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_892, all_reduce_109], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf1146, buf1156, 2048, stream=stream0) | |
| buf1150 = buf1079; del buf1079 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_893, all_reduce_108], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf1148, buf1150, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_893, all_reduce_108], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1150, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_108], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1150) | |
| buf1155 = buf1148; del buf1148 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_894], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf1150, buf1155, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_892, all_reduce_109], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1156, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_109], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1156) | |
| buf1161 = buf1146; del buf1146 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_895], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf1156, buf1161, 2048, stream=stream0) | |
| buf1162 = reinterpret_tensor(buf1059, (2048, 8192), (8192, 1), 0); del buf1059 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_4], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_3.run(primals_57, buf1162, 16777216, stream=stream0) | |
| del primals_57 | |
| buf1163 = buf1057; del buf1057 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_4, x_8, permute_373, mm_157], Original ATen: [aten._to_copy, aten.t, aten.mm] | |
| extern_kernels.mm(buf1149, buf1162, out=buf1163) | |
| buf1164 = buf1092; del buf1092 # reuse | |
| buf1165 = buf1091; del buf1091 # reuse | |
| buf1167 = buf1136; del buf1136 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused__to_copy_native_layer_norm_4.run(add_30, primals_53, primals_54, buf1164, buf1165, buf1167, 327680, 2048, stream=stream0) | |
| del primals_54 | |
| buf1168 = reinterpret_tensor(buf1162, (2048, 8192), (1, 2048), 0); del buf1162 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_2, x], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_3.run(primals_55, buf1168, 16777216, stream=stream0) | |
| del primals_55 | |
| buf1169 = buf1043; del buf1043 # reuse | |
| # Topologically Sorted Source Nodes: [x], Original ATen: [aten.addmm] | |
| extern_kernels.mm(buf1167, buf1168, out=buf1169) | |
| buf1170 = buf1042; del buf1042 # reuse | |
| buf1184 = buf1163; del buf1163 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_3, x, x_7, convert_element_type_902, mul_398, mul_399, sub_94, mul_400, add_184, mul_401, mul_402, mul_403, add_185, mul_404, convert_element_type_904], Original ATen: [aten._to_copy, aten.addmm, aten.gelu, aten.gelu_backward] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_addmm_gelu_gelu_backward_5.run(buf1184, primals_56, buf1169, buf1170, 2684354560, stream=stream0) | |
| del primals_56 | |
| buf1171 = buf1044; del buf1044 # reuse | |
| # Topologically Sorted Source Nodes: [permute_374, redistribute_3, x, x_7, mm_158], Original ATen: [aten.t, aten._to_copy, aten.addmm, aten.gelu, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf1149, (2048, 327680), (1, 2048), 0), buf1170, out=buf1171) | |
| buf1172 = buf1045; del buf1045 # reuse | |
| # Topologically Sorted Source Nodes: [sum_95], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_6.run(buf1149, buf1172, 327680, 2048, stream=stream0) | |
| buf1173 = reinterpret_tensor(buf1156, (1, 2048), (2048, 1), 0); del buf1156 # reuse | |
| # Topologically Sorted Source Nodes: [sum_95], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_7.run(buf1172, buf1173, 2048, 160, stream=stream0) | |
| # Topologically Sorted Source Nodes: [view_387, all_reduce_110], Original ATen: [aten.view, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf1173, (2048, ), (1, ), 0), 'avg', '0') | |
| # Topologically Sorted Source Nodes: [view_387, wait_tensor_110], Original ATen: [aten.view, _c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf1173, (2048, ), (1, ), 0)) | |
| buf1178 = empty_strided_cuda((2048, ), (1, ), torch.float32) | |
| # Topologically Sorted Source Nodes: [view_387, convert_element_type_900], Original ATen: [aten.view, aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf1173, buf1178, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [all_reduce_111], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1171, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_111], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1171) | |
| buf1183 = empty_strided_cuda((2048, 8192), (8192, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_901], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_8.run(buf1171, buf1183, 16777216, stream=stream0) | |
| buf1185 = buf1129; del buf1129 # reuse | |
| # Topologically Sorted Source Nodes: [permute_377, mm_159], Original ATen: [aten.t, aten.mm] | |
| extern_kernels.mm(buf1184, reinterpret_tensor(buf1168, (8192, 2048), (2048, 1), 0), out=buf1185) | |
| buf1186 = reinterpret_tensor(buf1168, (8192, 2048), (2048, 1), 0); del buf1168 # reuse | |
| # Topologically Sorted Source Nodes: [permute_378, mm_160], Original ATen: [aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf1184, (8192, 327680), (1, 8192), 0), buf1167, out=buf1186) | |
| buf1187 = buf1060; del buf1060 # reuse | |
| # Topologically Sorted Source Nodes: [sum_96], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_9.run(buf1184, buf1187, 1310720, 2048, stream=stream0) | |
| buf1188 = buf1061; del buf1061 # reuse | |
| # Topologically Sorted Source Nodes: [sum_96], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_10.run(buf1187, buf1188, 8192, 160, stream=stream0) | |
| # Topologically Sorted Source Nodes: [view_388, all_reduce_112], Original ATen: [aten.view, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf1188, (8192, ), (1, ), 0), 'avg', '0') | |
| # Topologically Sorted Source Nodes: [view_388, wait_tensor_112], Original ATen: [aten.view, _c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf1188, (8192, ), (1, ), 0)) | |
| buf1193 = empty_strided_cuda((8192, ), (1, ), torch.float32) | |
| # Topologically Sorted Source Nodes: [view_388, convert_element_type_909], Original ATen: [aten.view, aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_view_11.run(buf1188, buf1193, 8192, stream=stream0) | |
| # Topologically Sorted Source Nodes: [all_reduce_113], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1186, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_113], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1186) | |
| buf1198 = empty_strided_cuda((8192, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_910], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_8.run(buf1186, buf1198, 16777216, stream=stream0) | |
| buf1205 = buf1149; del buf1149 # reuse | |
| # Call mix order reduction kernel | |
| # Topologically Sorted Source Nodes: [redistribute, layer_norm, convert_element_type_911, convert_element_type_913, mul_406, mul_407, sum_97, mul_408, sum_98, mul_409, sub_96, sub_97, div_22, mul_410, mul_411, sum_99, sum_100, convert_element_type_915, add_186], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.native_layer_norm_backward, aten.add] | |
| workspace_19 = workspace_18; del workspace_18 # reuse | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_12.run(buf1205, buf1185, primals_53, add_30, buf1164, buf1165, workspace_19, 327680, 2048, stream=stream0) | |
| buf1202 = workspace_19[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| buf1204 = workspace_19[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| del add_30 | |
| del primals_53 | |
| buf1212 = reinterpret_tensor(buf1173, (2048, ), (1, ), 0); del buf1173 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_916, all_reduce_115], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf1202, buf1212, 2048, stream=stream0) | |
| buf1206 = buf1150; del buf1150 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_917, all_reduce_114], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf1204, buf1206, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_917, all_reduce_114], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1206, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_114], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1206) | |
| buf1211 = buf1204; del buf1204 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_918], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf1206, buf1211, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_916, all_reduce_115], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1212, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_115], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1212) | |
| buf1217 = buf1202; del buf1202 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_919], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf1212, buf1217, 2048, stream=stream0) | |
| buf1218 = buf1165; del buf1165 # reuse | |
| buf1219 = buf1164; del buf1164 # reuse | |
| buf1221 = buf1185; del buf1185 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused__to_copy_native_layer_norm_4.run(add_27, primals_47, primals_48, buf1218, buf1219, buf1221, 327680, 2048, stream=stream0) | |
| del primals_48 | |
| buf1222 = reinterpret_tensor(buf1135, (2048, 2048), (1, 2048), 0); del buf1135 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_2, linear], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_13.run(primals_49, buf1222, 4194304, stream=stream0) | |
| del primals_49 | |
| buf1223 = buf1167; del buf1167 # reuse | |
| # Topologically Sorted Source Nodes: [linear], Original ATen: [aten.mm] | |
| extern_kernels.mm(buf1221, buf1222, out=buf1223) | |
| buf1224 = reinterpret_tensor(buf1128, (2048, 512), (1, 2048), 0); del buf1128 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_3, linear_1], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_14.run(primals_50, buf1224, 1048576, stream=stream0) | |
| del primals_50 | |
| buf1225 = reinterpret_tensor(buf1118, (327680, 512), (512, 1), 0); del buf1118 # reuse | |
| # Topologically Sorted Source Nodes: [linear_1], Original ATen: [aten.mm] | |
| extern_kernels.mm(buf1221, buf1224, out=buf1225) | |
| buf1226 = buf1097; del buf1097 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_4, linear_2], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_14.run(primals_51, buf1226, 1048576, stream=stream0) | |
| del primals_51 | |
| buf1227 = reinterpret_tensor(buf1117, (327680, 512), (512, 1), 0); del buf1117 # reuse | |
| # Topologically Sorted Source Nodes: [linear_2], Original ATen: [aten.mm] | |
| extern_kernels.mm(buf1221, buf1226, out=buf1227) | |
| buf1228 = buf1115; del buf1115 # reuse | |
| buf1229 = buf1101; del buf1101 # reuse | |
| buf1230 = reinterpret_tensor(buf1122, (1, 16, 327680, 128), (671088640, 128, 2048, 1), 0); del buf1122 # reuse | |
| # Topologically Sorted Source Nodes: [q, k, v, flex_attention], Original ATen: [aten.view, aten.permute, flex_attention] | |
| stream0 = get_raw_stream(0) | |
| triton_tem_fused_flex_attention_permute_view_15.run(buf1223, buf1225, buf1227, buf1228, buf1229, primals_18, primals_17, primals_19, primals_21, primals_14, primals_15, buf1230, s91, 2560, 1, 16, stream=stream0) | |
| buf1233 = reinterpret_tensor(buf1095, (2048, 2048), (2048, 1), 0); del buf1095 # reuse | |
| # Topologically Sorted Source Nodes: [permute_381, rearrange_3, o, mm_161], Original ATen: [aten.t, aten.permute, aten.view, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf1205, (2048, 327680), (1, 2048), 0), reinterpret_tensor(buf1230, (327680, 2048), (2048, 1), 0), out=buf1233) | |
| buf1234 = buf1107; del buf1107 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_5], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_13.run(primals_52, buf1234, 4194304, stream=stream0) | |
| del primals_52 | |
| buf1235 = reinterpret_tensor(buf1116, (327680, 2048), (2048, 1), 0); del buf1116 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_5, o, permute_383, mm_162], Original ATen: [aten._to_copy, aten.t, aten.mm] | |
| extern_kernels.mm(buf1205, buf1234, out=buf1235) | |
| # Topologically Sorted Source Nodes: [all_reduce_116], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1233, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_116], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1233) | |
| buf1240 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_924], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_16.run(buf1233, buf1240, 4194304, stream=stream0) | |
| buf1242 = buf1229; del buf1229 # reuse | |
| # Topologically Sorted Source Nodes: [q, k, v, view_391, view_392, permute_385, flex_attention_backward_9], Original ATen: [aten.view, aten.permute, flex_attention_backward] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused_flex_attention_backward_permute_view_17.run(buf1230, buf1235, buf1242, 5242880, 128, stream=stream0) | |
| buf1243 = buf1230; del buf1230 # reuse | |
| buf1244 = reinterpret_tensor(buf1100, (1, 4, 327680, 128), (167772160, 128, 512, 1), 0); del buf1100 # reuse | |
| buf1245 = reinterpret_tensor(buf1098, (1, 4, 327680, 128), (167772160, 128, 512, 1), 0); del buf1098 # reuse | |
| # Topologically Sorted Source Nodes: [q, k, v, view_391, view_392, permute_385, flex_attention_backward_9], Original ATen: [aten.view, aten.permute, flex_attention_backward] | |
| stream0 = get_raw_stream(0) | |
| triton_tem_fused_flex_attention_backward_permute_view_18.run(buf1223, buf1225, buf1227, buf1228, buf1242, buf1235, buf1243, buf1244, primals_18, primals_17, primals_22, primals_24, primals_19, primals_21, primals_25, primals_27, primals_14, primals_15, buf1245, s91, s16, 12800, 1, 4, stream=stream0) | |
| buf1248 = reinterpret_tensor(buf1099, (512, 2048), (2048, 1), 0); del buf1099 # reuse | |
| # Topologically Sorted Source Nodes: [view_393, permute_386, view_394, permute_387, mm_163], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf1244, (512, 327680), (1, 512), 0), buf1221, out=buf1248) | |
| buf1249 = buf1235; del buf1235 # reuse | |
| # Topologically Sorted Source Nodes: [view_393, permute_386, view_394, permute_389, mm_164], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf1244, (327680, 512), (512, 1), 0), reinterpret_tensor(buf1226, (512, 2048), (2048, 1), 0), out=buf1249) | |
| # Topologically Sorted Source Nodes: [all_reduce_117], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1248, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_117], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1248) | |
| buf1254 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_929], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_19.run(buf1248, buf1254, 1048576, stream=stream0) | |
| buf1255 = buf1248; del buf1248 # reuse | |
| # Topologically Sorted Source Nodes: [view_395, permute_391, view_396, permute_392, mm_165], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf1245, (512, 327680), (1, 512), 0), buf1221, out=buf1255) | |
| buf1256 = buf1223; del buf1223 # reuse | |
| # Topologically Sorted Source Nodes: [view_395, permute_391, view_396, permute_394, mm_166], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf1245, (327680, 512), (512, 1), 0), reinterpret_tensor(buf1224, (512, 2048), (2048, 1), 0), out=buf1256) | |
| # Topologically Sorted Source Nodes: [all_reduce_118], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1255, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_118], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1255) | |
| buf1261 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_934], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_19.run(buf1255, buf1261, 1048576, stream=stream0) | |
| buf1262 = buf1233; del buf1233 # reuse | |
| # Topologically Sorted Source Nodes: [view_397, permute_396, view_398, permute_397, mm_167], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf1243, (2048, 327680), (1, 2048), 0), buf1221, out=buf1262) | |
| buf1263 = buf1221; del buf1221 # reuse | |
| # Topologically Sorted Source Nodes: [view_397, permute_396, view_398, permute_399, mm_168], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf1243, (327680, 2048), (2048, 1), 0), reinterpret_tensor(buf1222, (2048, 2048), (2048, 1), 0), out=buf1263) | |
| # Topologically Sorted Source Nodes: [all_reduce_119], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1262, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_119], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1262) | |
| buf1268 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_939], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_16.run(buf1262, buf1268, 4194304, stream=stream0) | |
| buf1276 = buf1205; del buf1205 # reuse | |
| # Call mix order reduction kernel | |
| # Topologically Sorted Source Nodes: [redistribute, layer_norm, add_187, add_188, convert_element_type_940, convert_element_type_942, mul_413, mul_414, sum_101, mul_415, sum_102, mul_416, sub_99, sub_100, div_23, mul_417, mul_418, sum_103, sum_104, convert_element_type_944, add_189], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.add, aten.native_layer_norm_backward] | |
| workspace_20 = workspace_19; del workspace_19 # reuse | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_20.run(buf1276, buf1249, buf1256, buf1263, primals_47, add_27, buf1218, buf1219, workspace_20, 327680, 2048, stream=stream0) | |
| buf1273 = workspace_20[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| buf1275 = workspace_20[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| del add_27 | |
| del primals_47 | |
| buf1283 = buf1212; del buf1212 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_945, all_reduce_121], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf1273, buf1283, 2048, stream=stream0) | |
| buf1277 = buf1206; del buf1206 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_946, all_reduce_120], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf1275, buf1277, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_946, all_reduce_120], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1277, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_120], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1277) | |
| buf1282 = buf1275; del buf1275 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_947], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf1277, buf1282, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_945, all_reduce_121], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1283, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_121], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1283) | |
| buf1288 = buf1273; del buf1273 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_948], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf1283, buf1288, 2048, stream=stream0) | |
| buf1289 = reinterpret_tensor(buf1186, (2048, 8192), (8192, 1), 0); del buf1186 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_4], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_3.run(primals_45, buf1289, 16777216, stream=stream0) | |
| del primals_45 | |
| buf1290 = buf1184; del buf1184 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_4, x_6, permute_401, mm_169], Original ATen: [aten._to_copy, aten.t, aten.mm] | |
| extern_kernels.mm(buf1276, buf1289, out=buf1290) | |
| buf1291 = buf1219; del buf1219 # reuse | |
| buf1292 = buf1218; del buf1218 # reuse | |
| buf1294 = buf1263; del buf1263 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused__to_copy_native_layer_norm_4.run(add_22, primals_41, primals_42, buf1291, buf1292, buf1294, 327680, 2048, stream=stream0) | |
| del primals_42 | |
| buf1295 = reinterpret_tensor(buf1289, (2048, 8192), (1, 2048), 0); del buf1289 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_2, x], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_3.run(primals_43, buf1295, 16777216, stream=stream0) | |
| del primals_43 | |
| buf1296 = buf1170; del buf1170 # reuse | |
| # Topologically Sorted Source Nodes: [x], Original ATen: [aten.addmm] | |
| extern_kernels.mm(buf1294, buf1295, out=buf1296) | |
| buf1297 = buf1169; del buf1169 # reuse | |
| buf1311 = buf1290; del buf1290 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_3, x, x_5, convert_element_type_955, mul_424, mul_425, sub_101, mul_426, add_192, mul_427, mul_428, mul_429, add_193, mul_430, convert_element_type_957], Original ATen: [aten._to_copy, aten.addmm, aten.gelu, aten.gelu_backward] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_addmm_gelu_gelu_backward_5.run(buf1311, primals_44, buf1296, buf1297, 2684354560, stream=stream0) | |
| del primals_44 | |
| buf1298 = buf1171; del buf1171 # reuse | |
| # Topologically Sorted Source Nodes: [permute_402, redistribute_3, x, x_5, mm_170], Original ATen: [aten.t, aten._to_copy, aten.addmm, aten.gelu, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf1276, (2048, 327680), (1, 2048), 0), buf1297, out=buf1298) | |
| buf1299 = buf1172; del buf1172 # reuse | |
| # Topologically Sorted Source Nodes: [sum_105], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_6.run(buf1276, buf1299, 327680, 2048, stream=stream0) | |
| buf1300 = reinterpret_tensor(buf1283, (1, 2048), (2048, 1), 0); del buf1283 # reuse | |
| # Topologically Sorted Source Nodes: [sum_105], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_7.run(buf1299, buf1300, 2048, 160, stream=stream0) | |
| # Topologically Sorted Source Nodes: [view_399, all_reduce_122], Original ATen: [aten.view, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf1300, (2048, ), (1, ), 0), 'avg', '0') | |
| # Topologically Sorted Source Nodes: [view_399, wait_tensor_122], Original ATen: [aten.view, _c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf1300, (2048, ), (1, ), 0)) | |
| buf1305 = empty_strided_cuda((2048, ), (1, ), torch.float32) | |
| # Topologically Sorted Source Nodes: [view_399, convert_element_type_953], Original ATen: [aten.view, aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf1300, buf1305, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [all_reduce_123], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1298, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_123], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1298) | |
| buf1310 = empty_strided_cuda((2048, 8192), (8192, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_954], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_8.run(buf1298, buf1310, 16777216, stream=stream0) | |
| buf1312 = buf1256; del buf1256 # reuse | |
| # Topologically Sorted Source Nodes: [permute_405, mm_171], Original ATen: [aten.t, aten.mm] | |
| extern_kernels.mm(buf1311, reinterpret_tensor(buf1295, (8192, 2048), (2048, 1), 0), out=buf1312) | |
| buf1313 = reinterpret_tensor(buf1295, (8192, 2048), (2048, 1), 0); del buf1295 # reuse | |
| # Topologically Sorted Source Nodes: [permute_406, mm_172], Original ATen: [aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf1311, (8192, 327680), (1, 8192), 0), buf1294, out=buf1313) | |
| buf1314 = buf1187; del buf1187 # reuse | |
| # Topologically Sorted Source Nodes: [sum_106], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_9.run(buf1311, buf1314, 1310720, 2048, stream=stream0) | |
| buf1315 = buf1188; del buf1188 # reuse | |
| # Topologically Sorted Source Nodes: [sum_106], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_10.run(buf1314, buf1315, 8192, 160, stream=stream0) | |
| # Topologically Sorted Source Nodes: [view_400, all_reduce_124], Original ATen: [aten.view, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf1315, (8192, ), (1, ), 0), 'avg', '0') | |
| # Topologically Sorted Source Nodes: [view_400, wait_tensor_124], Original ATen: [aten.view, _c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf1315, (8192, ), (1, ), 0)) | |
| buf1320 = empty_strided_cuda((8192, ), (1, ), torch.float32) | |
| # Topologically Sorted Source Nodes: [view_400, convert_element_type_962], Original ATen: [aten.view, aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_view_11.run(buf1315, buf1320, 8192, stream=stream0) | |
| # Topologically Sorted Source Nodes: [all_reduce_125], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1313, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_125], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1313) | |
| buf1325 = empty_strided_cuda((8192, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_963], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_8.run(buf1313, buf1325, 16777216, stream=stream0) | |
| buf1332 = buf1276; del buf1276 # reuse | |
| # Call mix order reduction kernel | |
| # Topologically Sorted Source Nodes: [redistribute, layer_norm, convert_element_type_964, convert_element_type_966, mul_432, mul_433, sum_107, mul_434, sum_108, mul_435, sub_103, sub_104, div_24, mul_436, mul_437, sum_109, sum_110, convert_element_type_968, add_194], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.native_layer_norm_backward, aten.add] | |
| workspace_21 = workspace_20; del workspace_20 # reuse | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_12.run(buf1332, buf1312, primals_41, add_22, buf1291, buf1292, workspace_21, 327680, 2048, stream=stream0) | |
| buf1329 = workspace_21[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| buf1331 = workspace_21[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| del add_22 | |
| del primals_41 | |
| buf1339 = reinterpret_tensor(buf1300, (2048, ), (1, ), 0); del buf1300 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_969, all_reduce_127], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf1329, buf1339, 2048, stream=stream0) | |
| buf1333 = buf1277; del buf1277 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_970, all_reduce_126], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf1331, buf1333, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_970, all_reduce_126], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1333, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_126], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1333) | |
| buf1338 = buf1331; del buf1331 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_971], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf1333, buf1338, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_969, all_reduce_127], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1339, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_127], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1339) | |
| buf1344 = buf1329; del buf1329 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_972], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf1339, buf1344, 2048, stream=stream0) | |
| buf1345 = buf1292; del buf1292 # reuse | |
| buf1346 = buf1291; del buf1291 # reuse | |
| buf1348 = buf1312; del buf1312 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused__to_copy_native_layer_norm_4.run(add_19, primals_35, primals_36, buf1345, buf1346, buf1348, 327680, 2048, stream=stream0) | |
| del primals_36 | |
| buf1349 = reinterpret_tensor(buf1262, (2048, 2048), (1, 2048), 0); del buf1262 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_2, linear], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_13.run(primals_37, buf1349, 4194304, stream=stream0) | |
| del primals_37 | |
| buf1350 = buf1294; del buf1294 # reuse | |
| # Topologically Sorted Source Nodes: [linear], Original ATen: [aten.mm] | |
| extern_kernels.mm(buf1348, buf1349, out=buf1350) | |
| buf1351 = reinterpret_tensor(buf1255, (2048, 512), (1, 2048), 0); del buf1255 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_3, linear_1], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_14.run(primals_38, buf1351, 1048576, stream=stream0) | |
| del primals_38 | |
| buf1352 = reinterpret_tensor(buf1245, (327680, 512), (512, 1), 0); del buf1245 # reuse | |
| # Topologically Sorted Source Nodes: [linear_1], Original ATen: [aten.mm] | |
| extern_kernels.mm(buf1348, buf1351, out=buf1352) | |
| buf1353 = buf1224; del buf1224 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_4, linear_2], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_14.run(primals_39, buf1353, 1048576, stream=stream0) | |
| del primals_39 | |
| buf1354 = reinterpret_tensor(buf1244, (327680, 512), (512, 1), 0); del buf1244 # reuse | |
| # Topologically Sorted Source Nodes: [linear_2], Original ATen: [aten.mm] | |
| extern_kernels.mm(buf1348, buf1353, out=buf1354) | |
| buf1355 = buf1242; del buf1242 # reuse | |
| buf1356 = buf1228; del buf1228 # reuse | |
| buf1357 = reinterpret_tensor(buf1249, (1, 16, 327680, 128), (671088640, 128, 2048, 1), 0); del buf1249 # reuse | |
| # Topologically Sorted Source Nodes: [q, k, v, flex_attention], Original ATen: [aten.view, aten.permute, flex_attention] | |
| stream0 = get_raw_stream(0) | |
| triton_tem_fused_flex_attention_permute_view_15.run(buf1350, buf1352, buf1354, buf1355, buf1356, primals_18, primals_17, primals_19, primals_21, primals_14, primals_15, buf1357, s91, 2560, 1, 16, stream=stream0) | |
| buf1360 = reinterpret_tensor(buf1222, (2048, 2048), (2048, 1), 0); del buf1222 # reuse | |
| # Topologically Sorted Source Nodes: [permute_409, rearrange_3, o, mm_173], Original ATen: [aten.t, aten.permute, aten.view, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf1332, (2048, 327680), (1, 2048), 0), reinterpret_tensor(buf1357, (327680, 2048), (2048, 1), 0), out=buf1360) | |
| buf1361 = buf1234; del buf1234 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_5], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_13.run(primals_40, buf1361, 4194304, stream=stream0) | |
| del primals_40 | |
| buf1362 = reinterpret_tensor(buf1243, (327680, 2048), (2048, 1), 0); del buf1243 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_5, o, permute_411, mm_174], Original ATen: [aten._to_copy, aten.t, aten.mm] | |
| extern_kernels.mm(buf1332, buf1361, out=buf1362) | |
| # Topologically Sorted Source Nodes: [all_reduce_128], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1360, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_128], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1360) | |
| buf1367 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_977], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_16.run(buf1360, buf1367, 4194304, stream=stream0) | |
| buf1369 = buf1356; del buf1356 # reuse | |
| # Topologically Sorted Source Nodes: [q, k, v, view_403, view_404, permute_413, flex_attention_backward_10], Original ATen: [aten.view, aten.permute, flex_attention_backward] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused_flex_attention_backward_permute_view_17.run(buf1357, buf1362, buf1369, 5242880, 128, stream=stream0) | |
| buf1370 = buf1357; del buf1357 # reuse | |
| buf1371 = reinterpret_tensor(buf1227, (1, 4, 327680, 128), (167772160, 128, 512, 1), 0); del buf1227 # reuse | |
| buf1372 = reinterpret_tensor(buf1225, (1, 4, 327680, 128), (167772160, 128, 512, 1), 0); del buf1225 # reuse | |
| # Topologically Sorted Source Nodes: [q, k, v, view_403, view_404, permute_413, flex_attention_backward_10], Original ATen: [aten.view, aten.permute, flex_attention_backward] | |
| stream0 = get_raw_stream(0) | |
| triton_tem_fused_flex_attention_backward_permute_view_18.run(buf1350, buf1352, buf1354, buf1355, buf1369, buf1362, buf1370, buf1371, primals_18, primals_17, primals_22, primals_24, primals_19, primals_21, primals_25, primals_27, primals_14, primals_15, buf1372, s91, s16, 12800, 1, 4, stream=stream0) | |
| buf1375 = reinterpret_tensor(buf1226, (512, 2048), (2048, 1), 0); del buf1226 # reuse | |
| # Topologically Sorted Source Nodes: [view_405, permute_414, view_406, permute_415, mm_175], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf1371, (512, 327680), (1, 512), 0), buf1348, out=buf1375) | |
| buf1376 = buf1362; del buf1362 # reuse | |
| # Topologically Sorted Source Nodes: [view_405, permute_414, view_406, permute_417, mm_176], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf1371, (327680, 512), (512, 1), 0), reinterpret_tensor(buf1353, (512, 2048), (2048, 1), 0), out=buf1376) | |
| # Topologically Sorted Source Nodes: [all_reduce_129], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1375, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_129], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1375) | |
| buf1381 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_982], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_19.run(buf1375, buf1381, 1048576, stream=stream0) | |
| buf1382 = buf1375; del buf1375 # reuse | |
| # Topologically Sorted Source Nodes: [view_407, permute_419, view_408, permute_420, mm_177], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf1372, (512, 327680), (1, 512), 0), buf1348, out=buf1382) | |
| buf1383 = buf1350; del buf1350 # reuse | |
| # Topologically Sorted Source Nodes: [view_407, permute_419, view_408, permute_422, mm_178], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf1372, (327680, 512), (512, 1), 0), reinterpret_tensor(buf1351, (512, 2048), (2048, 1), 0), out=buf1383) | |
| # Topologically Sorted Source Nodes: [all_reduce_130], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1382, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_130], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1382) | |
| buf1388 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_987], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_19.run(buf1382, buf1388, 1048576, stream=stream0) | |
| buf1389 = buf1360; del buf1360 # reuse | |
| # Topologically Sorted Source Nodes: [view_409, permute_424, view_410, permute_425, mm_179], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf1370, (2048, 327680), (1, 2048), 0), buf1348, out=buf1389) | |
| buf1390 = buf1348; del buf1348 # reuse | |
| # Topologically Sorted Source Nodes: [view_409, permute_424, view_410, permute_427, mm_180], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf1370, (327680, 2048), (2048, 1), 0), reinterpret_tensor(buf1349, (2048, 2048), (2048, 1), 0), out=buf1390) | |
| # Topologically Sorted Source Nodes: [all_reduce_131], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1389, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_131], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1389) | |
| buf1395 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_992], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_16.run(buf1389, buf1395, 4194304, stream=stream0) | |
| buf1403 = buf1332; del buf1332 # reuse | |
| # Call mix order reduction kernel | |
| # Topologically Sorted Source Nodes: [redistribute, layer_norm, add_195, add_196, convert_element_type_993, convert_element_type_995, mul_439, mul_440, sum_111, mul_441, sum_112, mul_442, sub_106, sub_107, div_25, mul_443, mul_444, sum_113, sum_114, convert_element_type_997, add_197], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.add, aten.native_layer_norm_backward] | |
| workspace_22 = workspace_21; del workspace_21 # reuse | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_20.run(buf1403, buf1376, buf1383, buf1390, primals_35, add_19, buf1345, buf1346, workspace_22, 327680, 2048, stream=stream0) | |
| buf1400 = workspace_22[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| buf1402 = workspace_22[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| del add_19 | |
| del primals_35 | |
| buf1410 = buf1339; del buf1339 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_998, all_reduce_133], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf1400, buf1410, 2048, stream=stream0) | |
| buf1404 = buf1333; del buf1333 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_999, all_reduce_132], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf1402, buf1404, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_999, all_reduce_132], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1404, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_132], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1404) | |
| buf1409 = buf1402; del buf1402 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_1000], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf1404, buf1409, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_998, all_reduce_133], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1410, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_133], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1410) | |
| buf1415 = buf1400; del buf1400 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_1001], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf1410, buf1415, 2048, stream=stream0) | |
| buf1416 = reinterpret_tensor(buf1313, (2048, 8192), (8192, 1), 0); del buf1313 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_4], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_3.run(primals_33, buf1416, 16777216, stream=stream0) | |
| del primals_33 | |
| buf1417 = buf1311; del buf1311 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_4, x_4, permute_429, mm_181], Original ATen: [aten._to_copy, aten.t, aten.mm] | |
| extern_kernels.mm(buf1403, buf1416, out=buf1417) | |
| buf1418 = buf1346; del buf1346 # reuse | |
| buf1419 = buf1345; del buf1345 # reuse | |
| buf1421 = buf1390; del buf1390 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused__to_copy_native_layer_norm_4.run(add_14, primals_29, primals_30, buf1418, buf1419, buf1421, 327680, 2048, stream=stream0) | |
| del primals_30 | |
| buf1422 = reinterpret_tensor(buf1416, (2048, 8192), (1, 2048), 0); del buf1416 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_2, x], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_3.run(primals_31, buf1422, 16777216, stream=stream0) | |
| del primals_31 | |
| buf1423 = buf1297; del buf1297 # reuse | |
| # Topologically Sorted Source Nodes: [x], Original ATen: [aten.addmm] | |
| extern_kernels.mm(buf1421, buf1422, out=buf1423) | |
| buf1424 = buf1296; del buf1296 # reuse | |
| buf1438 = buf1417; del buf1417 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_3, x, x_3, convert_element_type_1008, mul_450, mul_451, sub_108, mul_452, add_200, mul_453, mul_454, mul_455, add_201, mul_456, convert_element_type_1010], Original ATen: [aten._to_copy, aten.addmm, aten.gelu, aten.gelu_backward] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_addmm_gelu_gelu_backward_5.run(buf1438, primals_32, buf1423, buf1424, 2684354560, stream=stream0) | |
| del buf1423 | |
| del primals_32 | |
| buf1425 = buf1298; del buf1298 # reuse | |
| # Topologically Sorted Source Nodes: [permute_430, redistribute_3, x, x_3, mm_182], Original ATen: [aten.t, aten._to_copy, aten.addmm, aten.gelu, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf1403, (2048, 327680), (1, 2048), 0), buf1424, out=buf1425) | |
| del buf1424 | |
| buf1426 = buf1299; del buf1299 # reuse | |
| # Topologically Sorted Source Nodes: [sum_115], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_6.run(buf1403, buf1426, 327680, 2048, stream=stream0) | |
| buf1427 = reinterpret_tensor(buf1410, (1, 2048), (2048, 1), 0); del buf1410 # reuse | |
| # Topologically Sorted Source Nodes: [sum_115], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_7.run(buf1426, buf1427, 2048, 160, stream=stream0) | |
| # Topologically Sorted Source Nodes: [view_411, all_reduce_134], Original ATen: [aten.view, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf1427, (2048, ), (1, ), 0), 'avg', '0') | |
| # Topologically Sorted Source Nodes: [view_411, wait_tensor_134], Original ATen: [aten.view, _c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf1427, (2048, ), (1, ), 0)) | |
| buf1432 = empty_strided_cuda((2048, ), (1, ), torch.float32) | |
| # Topologically Sorted Source Nodes: [view_411, convert_element_type_1006], Original ATen: [aten.view, aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf1427, buf1432, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [all_reduce_135], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1425, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_135], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1425) | |
| buf1437 = empty_strided_cuda((2048, 8192), (8192, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_1007], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_8.run(buf1425, buf1437, 16777216, stream=stream0) | |
| del buf1425 | |
| buf1439 = buf1383; del buf1383 # reuse | |
| # Topologically Sorted Source Nodes: [permute_433, mm_183], Original ATen: [aten.t, aten.mm] | |
| extern_kernels.mm(buf1438, reinterpret_tensor(buf1422, (8192, 2048), (2048, 1), 0), out=buf1439) | |
| buf1440 = reinterpret_tensor(buf1422, (8192, 2048), (2048, 1), 0); del buf1422 # reuse | |
| # Topologically Sorted Source Nodes: [permute_434, mm_184], Original ATen: [aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf1438, (8192, 327680), (1, 8192), 0), buf1421, out=buf1440) | |
| buf1441 = buf1314; del buf1314 # reuse | |
| # Topologically Sorted Source Nodes: [sum_116], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_9.run(buf1438, buf1441, 1310720, 2048, stream=stream0) | |
| del buf1438 | |
| buf1442 = buf1315; del buf1315 # reuse | |
| # Topologically Sorted Source Nodes: [sum_116], Original ATen: [aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_sum_10.run(buf1441, buf1442, 8192, 160, stream=stream0) | |
| del buf1441 | |
| # Topologically Sorted Source Nodes: [view_412, all_reduce_136], Original ATen: [aten.view, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf1442, (8192, ), (1, ), 0), 'avg', '0') | |
| # Topologically Sorted Source Nodes: [view_412, wait_tensor_136], Original ATen: [aten.view, _c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf1442, (8192, ), (1, ), 0)) | |
| buf1447 = empty_strided_cuda((8192, ), (1, ), torch.float32) | |
| # Topologically Sorted Source Nodes: [view_412, convert_element_type_1015], Original ATen: [aten.view, aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_view_11.run(buf1442, buf1447, 8192, stream=stream0) | |
| del buf1442 | |
| # Topologically Sorted Source Nodes: [all_reduce_137], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1440, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_137], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1440) | |
| buf1452 = empty_strided_cuda((8192, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_1016], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_8.run(buf1440, buf1452, 16777216, stream=stream0) | |
| del buf1440 | |
| buf1459 = buf1403; del buf1403 # reuse | |
| # Call mix order reduction kernel | |
| # Topologically Sorted Source Nodes: [redistribute, layer_norm, convert_element_type_1017, convert_element_type_1019, mul_458, mul_459, sum_117, mul_460, sum_118, mul_461, sub_110, sub_111, div_26, mul_462, mul_463, sum_119, sum_120, convert_element_type_1021, add_202], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.native_layer_norm_backward, aten.add] | |
| workspace_23 = workspace_22; del workspace_22 # reuse | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_12.run(buf1459, buf1439, primals_29, add_14, buf1418, buf1419, workspace_23, 327680, 2048, stream=stream0) | |
| buf1456 = workspace_23[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| buf1458 = workspace_23[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| del add_14 | |
| del primals_29 | |
| buf1466 = reinterpret_tensor(buf1427, (2048, ), (1, ), 0); del buf1427 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_1022, all_reduce_139], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf1456, buf1466, 2048, stream=stream0) | |
| buf1460 = buf1404; del buf1404 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_1023, all_reduce_138], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf1458, buf1460, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_1023, all_reduce_138], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1460, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_138], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1460) | |
| buf1465 = buf1458; del buf1458 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_1024], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf1460, buf1465, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_1022, all_reduce_139], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1466, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_139], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1466) | |
| buf1471 = buf1456; del buf1456 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_1025], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf1466, buf1471, 2048, stream=stream0) | |
| buf1472 = buf1419; del buf1419 # reuse | |
| buf1473 = buf1418; del buf1418 # reuse | |
| buf1475 = buf1439; del buf1439 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused__to_copy_native_layer_norm_4.run(add_11, primals_9, primals_10, buf1472, buf1473, buf1475, 327680, 2048, stream=stream0) | |
| del primals_10 | |
| buf1476 = reinterpret_tensor(buf1389, (2048, 2048), (1, 2048), 0); del buf1389 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_2, linear], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_13.run(primals_11, buf1476, 4194304, stream=stream0) | |
| del primals_11 | |
| buf1477 = buf1421; del buf1421 # reuse | |
| # Topologically Sorted Source Nodes: [linear], Original ATen: [aten.mm] | |
| extern_kernels.mm(buf1475, buf1476, out=buf1477) | |
| buf1478 = reinterpret_tensor(buf1382, (2048, 512), (1, 2048), 0); del buf1382 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_3, linear_1], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_14.run(primals_12, buf1478, 1048576, stream=stream0) | |
| del primals_12 | |
| buf1479 = reinterpret_tensor(buf1372, (327680, 512), (512, 1), 0); del buf1372 # reuse | |
| # Topologically Sorted Source Nodes: [linear_1], Original ATen: [aten.mm] | |
| extern_kernels.mm(buf1475, buf1478, out=buf1479) | |
| buf1480 = buf1351; del buf1351 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_4, linear_2], Original ATen: [aten._to_copy, aten.t] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_14.run(primals_13, buf1480, 1048576, stream=stream0) | |
| del primals_13 | |
| buf1481 = reinterpret_tensor(buf1371, (327680, 512), (512, 1), 0); del buf1371 # reuse | |
| # Topologically Sorted Source Nodes: [linear_2], Original ATen: [aten.mm] | |
| extern_kernels.mm(buf1475, buf1480, out=buf1481) | |
| buf1482 = buf1369; del buf1369 # reuse | |
| buf1483 = buf1355; del buf1355 # reuse | |
| buf1484 = reinterpret_tensor(buf1376, (1, 16, 327680, 128), (671088640, 128, 2048, 1), 0); del buf1376 # reuse | |
| # Topologically Sorted Source Nodes: [q, k, v, flex_attention], Original ATen: [aten.view, aten.permute, flex_attention] | |
| stream0 = get_raw_stream(0) | |
| triton_tem_fused_flex_attention_permute_view_15.run(buf1477, buf1479, buf1481, buf1482, buf1483, primals_18, primals_17, primals_19, primals_21, primals_14, primals_15, buf1484, s91, 2560, 1, 16, stream=stream0) | |
| buf1487 = reinterpret_tensor(buf1349, (2048, 2048), (2048, 1), 0); del buf1349 # reuse | |
| # Topologically Sorted Source Nodes: [permute_437, rearrange_3, o, mm_185], Original ATen: [aten.t, aten.permute, aten.view, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf1459, (2048, 327680), (1, 2048), 0), reinterpret_tensor(buf1484, (327680, 2048), (2048, 1), 0), out=buf1487) | |
| buf1488 = buf1361; del buf1361 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_5], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_t_13.run(primals_28, buf1488, 4194304, stream=stream0) | |
| del primals_28 | |
| buf1489 = reinterpret_tensor(buf1370, (327680, 2048), (2048, 1), 0); del buf1370 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_5, o, permute_439, mm_186], Original ATen: [aten._to_copy, aten.t, aten.mm] | |
| extern_kernels.mm(buf1459, buf1488, out=buf1489) | |
| del buf1488 | |
| # Topologically Sorted Source Nodes: [all_reduce_140], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1487, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_140], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1487) | |
| buf1494 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_1030], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_16.run(buf1487, buf1494, 4194304, stream=stream0) | |
| buf1496 = buf1483; del buf1483 # reuse | |
| # Topologically Sorted Source Nodes: [q, k, v, view_415, view_416, permute_441, flex_attention_backward_11], Original ATen: [aten.view, aten.permute, flex_attention_backward] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused_flex_attention_backward_permute_view_17.run(buf1484, buf1489, buf1496, 5242880, 128, stream=stream0) | |
| buf1497 = buf1484; del buf1484 # reuse | |
| buf1498 = reinterpret_tensor(buf1354, (1, 4, 327680, 128), (167772160, 128, 512, 1), 0); del buf1354 # reuse | |
| buf1499 = reinterpret_tensor(buf1352, (1, 4, 327680, 128), (167772160, 128, 512, 1), 0); del buf1352 # reuse | |
| # Topologically Sorted Source Nodes: [q, k, v, view_415, view_416, permute_441, flex_attention_backward_11], Original ATen: [aten.view, aten.permute, flex_attention_backward] | |
| stream0 = get_raw_stream(0) | |
| triton_tem_fused_flex_attention_backward_permute_view_18.run(buf1477, buf1479, buf1481, buf1482, buf1496, buf1489, buf1497, buf1498, primals_18, primals_17, primals_22, primals_24, primals_19, primals_21, primals_25, primals_27, primals_14, primals_15, buf1499, s91, s16, 12800, 1, 4, stream=stream0) | |
| del buf1479 | |
| del buf1481 | |
| del buf1482 | |
| del buf1496 | |
| del primals_14 | |
| del primals_15 | |
| del primals_17 | |
| del primals_18 | |
| del primals_19 | |
| del primals_21 | |
| del primals_22 | |
| del primals_24 | |
| del primals_25 | |
| del primals_27 | |
| buf1502 = reinterpret_tensor(buf1353, (512, 2048), (2048, 1), 0); del buf1353 # reuse | |
| # Topologically Sorted Source Nodes: [view_417, permute_442, view_418, permute_443, mm_187], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf1498, (512, 327680), (1, 512), 0), buf1475, out=buf1502) | |
| buf1503 = buf1489; del buf1489 # reuse | |
| # Topologically Sorted Source Nodes: [view_417, permute_442, view_418, permute_445, mm_188], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf1498, (327680, 512), (512, 1), 0), reinterpret_tensor(buf1480, (512, 2048), (2048, 1), 0), out=buf1503) | |
| del buf1480 | |
| del buf1498 | |
| # Topologically Sorted Source Nodes: [all_reduce_141], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1502, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_141], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1502) | |
| buf1508 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_1035], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_19.run(buf1502, buf1508, 1048576, stream=stream0) | |
| buf1509 = buf1502; del buf1502 # reuse | |
| # Topologically Sorted Source Nodes: [view_419, permute_447, view_420, permute_448, mm_189], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf1499, (512, 327680), (1, 512), 0), buf1475, out=buf1509) | |
| buf1510 = buf1477; del buf1477 # reuse | |
| # Topologically Sorted Source Nodes: [view_419, permute_447, view_420, permute_450, mm_190], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf1499, (327680, 512), (512, 1), 0), reinterpret_tensor(buf1478, (512, 2048), (2048, 1), 0), out=buf1510) | |
| del buf1478 | |
| del buf1499 | |
| # Topologically Sorted Source Nodes: [all_reduce_142], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1509, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_142], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1509) | |
| buf1515 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_1040], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_19.run(buf1509, buf1515, 1048576, stream=stream0) | |
| del buf1509 | |
| buf1516 = buf1487; del buf1487 # reuse | |
| # Topologically Sorted Source Nodes: [view_421, permute_452, view_422, permute_453, mm_191], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf1497, (2048, 327680), (1, 2048), 0), buf1475, out=buf1516) | |
| buf1517 = buf1475; del buf1475 # reuse | |
| # Topologically Sorted Source Nodes: [view_421, permute_452, view_422, permute_455, mm_192], Original ATen: [aten.view, aten.permute, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf1497, (327680, 2048), (2048, 1), 0), reinterpret_tensor(buf1476, (2048, 2048), (2048, 1), 0), out=buf1517) | |
| del buf1476 | |
| # Topologically Sorted Source Nodes: [all_reduce_143], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1516, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_143], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1516) | |
| buf1522 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_1045], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_16.run(buf1516, buf1522, 4194304, stream=stream0) | |
| del buf1516 | |
| buf1542 = empty_strided_cuda((327680, 16), (16, 1), torch.uint8) | |
| # Topologically Sorted Source Nodes: [getitem, mask, getitem_1, data, getitem_3, contiguous_1], Original ATen: [aten.select, aten.eq, aten.unsqueeze, aten.mul, aten.slice, aten.clone] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_clone_eq_mul_select_slice_unsqueeze_21.run(primals_1, buf1542, 5242880, stream=stream0) | |
| # Topologically Sorted Source Nodes: [getitem, mask, getitem_1, data, getitem_3, contiguous_1, positions], Original ATen: [aten.select, aten.eq, aten.unsqueeze, aten.mul, aten.slice, aten.clone, aten.view] | |
| buf1543 = torch.ops.aten.view.dtype(buf1542, torch.int32) | |
| buf1544 = buf1543 | |
| assert_size_stride(buf1544, (327680, 4), (4, 1), 'torch.ops.aten.view.dtype') | |
| assert_alignment(buf1544, 16, 'torch.ops.aten.view.dtype') | |
| buf1545 = empty_strided_cuda((327680, 2048), (2048, 1), torch.float32) | |
| buf1546 = reinterpret_tensor(buf1426, (327680, 1), (1, 327680), 0); del buf1426 # reuse | |
| buf1547 = empty_strided_cuda((327680, 1), (1, 327680), torch.float32) | |
| # Topologically Sorted Source Nodes: [i, j, ifreqs, ifreqs_1, neg, freqs, getitem_6, getitem_7, mul_1, sin, cos, getitem_10, mul_3, sin_1, cos_1, posemb, to_1, layer_norm_1], Original ATen: [aten.select, aten.arange, aten.div, aten.neg, aten.pow, aten.unsqueeze, aten.mul, aten.sin, aten.cos, aten.cat, aten._to_copy, aten.native_layer_norm] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused__to_copy_arange_cat_cos_div_mul_native_layer_norm_neg_pow_select_sin_unsqueeze_22.run(buf1544, buf1545, buf1546, buf1547, 327680, 2048, stream=stream0) | |
| del buf1542 | |
| del buf1543 | |
| del buf1544 | |
| buf1568 = empty_strided_cuda((327680, 768), (768, 1), torch.bfloat16) | |
| # Topologically Sorted Source Nodes: [getitem, mask, getitem_1, data, getitem_2, patches, to, truediv, patches_1], Original ATen: [aten.select, aten.eq, aten.unsqueeze, aten.mul, aten.slice, aten.clone, aten._to_copy, aten.div, aten.sub] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_clone_div_eq_mul_select_slice_sub_unsqueeze_23.run(primals_1, buf1568, 251658240, stream=stream0) | |
| buf1569 = empty_strided_cuda((2048, 768), (768, 1), torch.bfloat16) | |
| # Topologically Sorted Source Nodes: [redistribute_1], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_24.run(primals_4, buf1569, 1572864, stream=stream0) | |
| del primals_4 | |
| buf1570 = reinterpret_tensor(buf1497, (327680, 2048), (2048, 1), 0); del buf1497 # reuse | |
| # Topologically Sorted Source Nodes: [redistribute_1, linear], Original ATen: [aten._to_copy, aten.t, aten.mm] | |
| extern_kernels.mm(buf1568, reinterpret_tensor(buf1569, (768, 2048), (1, 768), 0), out=buf1570) | |
| buf1571 = empty_strided_cuda((327680, 1), (1, 327680), torch.float32) | |
| buf1572 = empty_strided_cuda((327680, 1), (1, 327680), torch.float32) | |
| # Topologically Sorted Source Nodes: [x], Original ATen: [aten.native_layer_norm] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_native_layer_norm_25.run(buf1570, buf1571, buf1572, 327680, 2048, stream=stream0) | |
| buf1525 = empty_strided_cuda((327680, 2048), (2048, 1), torch.float32) | |
| buf1549 = buf1545; del buf1545 # reuse | |
| buf1575 = empty_strided_cuda((327680, 2048), (2048, 1), torch.float32) | |
| buf1592 = empty_strided_cuda((327680, 2048), (2048, 1), torch.bfloat16) | |
| buf1585 = empty_strided_cuda((327680, 2048), (2048, 1), torch.bfloat16) | |
| # Call mix order reduction kernel | |
| # Topologically Sorted Source Nodes: [redistribute, layer_norm, add_203, add_204, convert_element_type_1046, convert_element_type_1048, mul_465, mul_466, sum_121, mul_467, sum_122, mul_468, sub_113, sub_114, div_27, mul_469, mul_470, sum_123, sum_124, convert_element_type_1050, add_205, getitem, mask, getitem_1, mul_471, convert_element_type_1055, to_1, layer_norm_1, mul_477, redistribute_2, convert_element_type_1065, mul_479, mul_480, sum_129, x, mul_481, sum_130, mul_482, sub_119, sub_120, div_28, mul_483, mul_484, convert_element_type_1067, mul_485, slice_21, full_default, copy_3], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.add, aten.native_layer_norm_backward, aten.select, aten.eq, aten.unsqueeze, aten.mul, aten.slice, aten.zeros_like, aten.copy] | |
| workspace_24 = workspace_23; del workspace_23 # reuse | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_copy_eq_mul_native_layer_norm_native_layer_norm_backward_select_slice_unsqueeze_zeros_like_26.run(buf1549, buf1503, buf1510, buf1517, primals_9, add_11, buf1472, buf1473, buf1459, primals_1, buf1546, buf1547, primals_5, buf1570, buf1571, buf1572, buf1525, buf1575, buf1592, buf1585, workspace_24, 327680, 2048, stream=stream0) | |
| buf1527 = workspace_24[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| buf1529 = workspace_24[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0) | |
| del workspace_24 | |
| del add_11 | |
| del buf1472 | |
| del buf1503 | |
| del buf1510 | |
| del buf1517 | |
| del buf1546 | |
| del buf1547 | |
| del buf1570 | |
| del buf1571 | |
| del primals_5 | |
| del primals_9 | |
| buf1536 = buf1466; del buf1466 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_1051, all_reduce_145], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf1527, buf1536, 2048, stream=stream0) | |
| buf1530 = buf1460; del buf1460 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_1052, all_reduce_144], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf1529, buf1530, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_1052, all_reduce_144], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1530, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_144], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1530) | |
| buf1535 = buf1529; del buf1529 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_1053], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf1530, buf1535, 2048, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_1051, all_reduce_145], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1536, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_145], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1536) | |
| buf1541 = buf1527; del buf1527 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_1054], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf1536, buf1541, 2048, stream=stream0) | |
| buf1550 = reinterpret_tensor(buf1572, (2048, 160), (1, 2048), 0); del buf1572 # reuse | |
| # Topologically Sorted Source Nodes: [sum_127], Original ATen: [aten.native_layer_norm_backward] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_native_layer_norm_backward_27.run(buf1549, buf1550, 327680, 2048, stream=stream0) | |
| del buf1549 | |
| buf1560 = buf1536; del buf1536 # reuse | |
| # Topologically Sorted Source Nodes: [sum_127, convert_element_type_1059, all_reduce_147], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_all_reduce_native_layer_norm_backward_28.run(buf1550, buf1560, 2048, 160, stream=stream0) | |
| buf1552 = buf1550; del buf1550 # reuse | |
| # Topologically Sorted Source Nodes: [layer_norm, div_27, mul_469, convert_element_type_1050, add_205, getitem, mask, getitem_1, mul_471, convert_element_type_1055, sum_128], Original ATen: [aten.native_layer_norm, aten.native_layer_norm_backward, aten.add, aten.select, aten.eq, aten.unsqueeze, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_add_eq_mul_native_layer_norm_native_layer_norm_backward_select_unsqueeze_29.run(buf1459, buf1473, buf1525, primals_1, buf1552, 327680, 2048, stream=stream0) | |
| buf1554 = buf1530; del buf1530 # reuse | |
| # Topologically Sorted Source Nodes: [layer_norm, div_27, mul_469, convert_element_type_1050, add_205, getitem, mask, getitem_1, mul_471, convert_element_type_1055, sum_128, convert_element_type_1060, all_reduce_146], Original ATen: [aten.native_layer_norm, aten.native_layer_norm_backward, aten.add, aten.select, aten.eq, aten.unsqueeze, aten.mul, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_all_reduce_native_layer_norm_backward_28.run(buf1552, buf1554, 2048, 160, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_1060, all_reduce_146], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1554, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_146], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1554) | |
| buf1559 = empty_strided_cuda((2048, ), (1, ), torch.float32) | |
| buf1578 = empty_strided_cuda((2048, ), (1, ), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_1061, convert_element_type_1070], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_30.run(buf1554, buf1559, buf1578, 2048, stream=stream0) | |
| del buf1554 | |
| # Topologically Sorted Source Nodes: [convert_element_type_1059, all_reduce_147], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1560, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_147], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1560) | |
| buf1565 = empty_strided_cuda((2048, ), (1, ), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_1062], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf1560, buf1565, 2048, stream=stream0) | |
| buf1576 = buf1552; del buf1552 # reuse | |
| # Topologically Sorted Source Nodes: [sum_131], Original ATen: [aten.native_layer_norm_backward] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_native_layer_norm_backward_27.run(buf1575, buf1576, 327680, 2048, stream=stream0) | |
| del buf1575 | |
| buf1579 = buf1560; del buf1560 # reuse | |
| # Topologically Sorted Source Nodes: [sum_131, convert_element_type_1068, all_reduce_149], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused_all_reduce_native_layer_norm_backward_28.run(buf1576, buf1579, 2048, 160, stream=stream0) | |
| del buf1576 | |
| # Topologically Sorted Source Nodes: [convert_element_type_1068, all_reduce_149], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1579, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_149], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1579) | |
| buf1584 = empty_strided_cuda((2048, ), (1, ), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_1071], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_2.run(buf1579, buf1584, 2048, stream=stream0) | |
| del buf1579 | |
| buf1586 = buf1569; del buf1569 # reuse | |
| # Topologically Sorted Source Nodes: [mul_480, x, mul_482, sub_119, sub_120, div_28, mul_483, convert_element_type_1067, permute_457, mm_193], Original ATen: [aten.native_layer_norm_backward, aten.native_layer_norm, aten.t, aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf1585, (2048, 327680), (1, 2048), 0), buf1568, out=buf1586) | |
| del buf1568 | |
| # Topologically Sorted Source Nodes: [all_reduce_150], Original ATen: [_c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1586, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_150], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1586) | |
| buf1591 = empty_strided_cuda((2048, 768), (768, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_1074], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_31.run(buf1586, buf1591, 1572864, stream=stream0) | |
| del buf1586 | |
| buf1593 = buf1585; del buf1585 # reuse | |
| # Topologically Sorted Source Nodes: [layer_norm, div_27, mul_469, convert_element_type_1050, add_205, getitem, mask, getitem_1, mul_485, slice_21, clone_112, full_default_1], Original ATen: [aten.native_layer_norm, aten.native_layer_norm_backward, aten.add, aten.select, aten.eq, aten.unsqueeze, aten.mul, aten.slice, aten.clone, aten.slice_backward] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_add_clone_eq_mul_native_layer_norm_native_layer_norm_backward_select_slice_slice_backward_unsqueeze_32.run(buf1459, buf1473, buf1525, primals_1, buf1593, 671088640, stream=stream0) | |
| del buf1473 | |
| buf1595 = empty_strided_cuda((2560, 64), (64, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [add_206, slice_24, clone_113, convert_element_type_1075, mul_486, sum_133], Original ATen: [aten.add, aten.slice, aten.clone, aten._to_copy, aten.mul, aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused__to_copy_add_clone_mul_slice_sum_33.run(buf1592, buf1593, cos, buf1595, 163840, 2048, stream=stream0) | |
| del cos | |
| buf1596 = empty_strided_cuda((2560, ), (1, ), torch.float32) | |
| # Topologically Sorted Source Nodes: [add_206, slice_24, clone_113, convert_element_type_1075, mul_486, sum_133], Original ATen: [aten.add, aten.slice, aten.clone, aten._to_copy, aten.mul, aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_clone_mul_slice_sum_34.run(buf1595, buf1596, 2560, 64, stream=stream0) | |
| buf1598 = empty_strided_cuda((), (), torch.bfloat16) | |
| # Topologically Sorted Source Nodes: [add_206, slice_24, clone_113, convert_element_type_1075, mul_486, sum_133, convert_element_type_1076], Original ATen: [aten.add, aten.slice, aten.clone, aten._to_copy, aten.mul, aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused__to_copy_add_clone_mul_slice_sum_35.run(buf1596, buf1598, 1, 2560, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_1076, all_reduce_151], Original ATen: [aten._to_copy, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1598, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_151], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1598) | |
| buf1603 = buf1459; del buf1459 # reuse | |
| # Topologically Sorted Source Nodes: [full_default, full_default_1, add_206, slice_24, clone_113, copy_5, slice_27, clone_114, copy_7, add_207], Original ATen: [aten.zeros_like, aten.slice_backward, aten.add, aten.slice, aten.clone, aten.copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_add_clone_copy_slice_slice_backward_zeros_like_36.run(buf1592, buf1593, buf1603, 671088640, stream=stream0) | |
| del buf1592 | |
| del buf1593 | |
| buf1604 = empty_strided_cuda((327680, 4), (4, 1), torch.uint8) | |
| # Topologically Sorted Source Nodes: [getitem, mask, getitem_1, data, getitem_4, contiguous_1], Original ATen: [aten.select, aten.eq, aten.unsqueeze, aten.mul, aten.slice, aten.clone] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_clone_eq_mul_select_slice_unsqueeze_37.run(primals_1, buf1604, 1310720, stream=stream0) | |
| del primals_1 | |
| # Topologically Sorted Source Nodes: [getitem, mask, getitem_1, data, getitem_4, contiguous_1, view_1], Original ATen: [aten.select, aten.eq, aten.unsqueeze, aten.mul, aten.slice, aten.clone, aten.view] | |
| buf1605 = torch.ops.aten.view.dtype(buf1604, torch.int32) | |
| buf1606 = buf1605 | |
| assert_size_stride(buf1606, (327680, 1), (1, 1), 'torch.ops.aten.view.dtype') | |
| assert_alignment(buf1606, 16, 'torch.ops.aten.view.dtype') | |
| buf1608 = buf1595; del buf1595 # reuse | |
| # Topologically Sorted Source Nodes: [slice_30, clone_115, convert_element_type_1078, positions, ifreqs, ifreqs_1, neg, freqs, getitem_7, getitem_8, mul_1, sin, mul_487, sum_134], Original ATen: [aten.slice, aten.clone, aten._to_copy, aten.select, aten.arange, aten.div, aten.neg, aten.pow, aten.unsqueeze, aten.mul, aten.sin, aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused__to_copy_arange_clone_div_mul_neg_pow_select_sin_slice_sum_unsqueeze_38.run(buf1603, buf1606, buf1608, 163840, 2048, stream=stream0) | |
| del buf1604 | |
| del buf1605 | |
| del buf1606 | |
| buf1609 = buf1596; del buf1596 # reuse | |
| # Topologically Sorted Source Nodes: [slice_30, clone_115, convert_element_type_1078, positions, ifreqs, ifreqs_1, neg, freqs, getitem_7, getitem_8, mul_1, sin, mul_487, sum_134], Original ATen: [aten.slice, aten.clone, aten._to_copy, aten.select, aten.arange, aten.div, aten.neg, aten.pow, aten.unsqueeze, aten.mul, aten.sin, aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_clone_mul_slice_sum_34.run(buf1608, buf1609, 2560, 64, stream=stream0) | |
| del buf1608 | |
| buf1611 = empty_strided_cuda((), (), torch.bfloat16) | |
| # Topologically Sorted Source Nodes: [slice_30, clone_115, convert_element_type_1078, positions, ifreqs, ifreqs_1, neg, freqs, getitem_7, getitem_8, mul_1, sin, mul_487, sum_134, convert_element_type_1079], Original ATen: [aten.slice, aten.clone, aten._to_copy, aten.select, aten.arange, aten.div, aten.neg, aten.pow, aten.unsqueeze, aten.mul, aten.sin, aten.sum] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused__to_copy_add_clone_mul_slice_sum_35.run(buf1609, buf1611, 1, 2560, stream=stream0) | |
| del buf1609 | |
| # Topologically Sorted Source Nodes: [convert_element_type_1079, all_reduce_152], Original ATen: [aten._to_copy, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1611, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_152], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1611) | |
| buf1616 = empty_strided_cuda((259, 2048), (2048, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [full_default_5], Original ATen: [aten.embedding_dense_backward] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_embedding_dense_backward_39.run(buf1616, 530432, stream=stream0) | |
| buf1617 = buf1525; del buf1525 # reuse | |
| # Topologically Sorted Source Nodes: [slice_30, clone_115, copy_9, convert_element_type_1081, eq_2, unsqueeze_16, full_default_4, where], Original ATen: [aten.slice, aten.clone, aten.copy, aten.embedding_dense_backward] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_clone_copy_embedding_dense_backward_slice_40.run(select_1, buf1603, buf1617, 671088640, stream=stream0) | |
| del buf1603 | |
| aten.index_put_(buf1616, [select_1], buf1617, True) | |
| del buf1617 | |
| del select_1 | |
| buf1619 = empty_strided_cuda((259, 2048), (2048, 1), torch.bfloat16) | |
| # Topologically Sorted Source Nodes: [convert_element_type_1082, all_reduce_153], Original ATen: [aten.embedding_dense_backward, _c10d_functional.all_reduce] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_all_reduce_embedding_dense_backward_41.run(buf1616, buf1619, 530432, stream=stream0) | |
| # Topologically Sorted Source Nodes: [convert_element_type_1082, all_reduce_153], Original ATen: [aten.embedding_dense_backward, _c10d_functional.all_reduce] | |
| torch.ops._c10d_functional.all_reduce_.default(buf1619, 'avg', '0') | |
| # Topologically Sorted Source Nodes: [wait_tensor_153], Original ATen: [_c10d_functional.wait_tensor] | |
| torch.ops._c10d_functional.wait_tensor.default(buf1619) | |
| buf1624 = buf1616; del buf1616 # reuse | |
| # Topologically Sorted Source Nodes: [convert_element_type_1083], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_42.run(buf1619, buf1624, 530432, stream=stream0) | |
| del buf1619 | |
| buf1625 = empty_strided_cuda((), (), torch.float32) | |
| # Topologically Sorted Source Nodes: [convert_element_type_1077, convert_element_type_1080, add_209], Original ATen: [aten._to_copy, aten.add] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_add_43.run(buf1598, buf1611, buf1625, 1, stream=stream0) | |
| del buf1598 | |
| del buf1611 | |
| return (None, buf1624, buf1625, buf1591, buf1584, buf1578, buf1565, buf1559, buf1541, buf1535, buf1522, buf1515, buf1508, None, None, None, None, None, None, None, None, None, None, None, None, None, None, buf1494, buf1471, buf1465, buf1452, buf1447, buf1437, buf1432, buf1415, buf1409, buf1395, buf1388, buf1381, buf1367, buf1344, buf1338, buf1325, buf1320, buf1310, buf1305, buf1288, buf1282, buf1268, buf1261, buf1254, buf1240, buf1217, buf1211, buf1198, buf1193, buf1183, buf1178, buf1161, buf1155, buf1141, buf1134, buf1127, buf1113, buf1090, buf1084, buf1071, buf1066, buf1056, buf1051, buf1034, buf1028, buf1014, buf1007, buf1000, buf986, buf963, buf957, buf944, buf939, buf929, buf924, buf907, buf901, buf887, buf880, buf873, buf859, buf836, buf830, buf817, buf812, buf802, buf797, buf780, buf774, buf760, buf753, buf746, buf732, buf709, buf703, buf690, buf685, buf675, buf670, buf653, buf647, buf633, buf626, buf619, buf605, buf582, buf576, buf563, buf558, buf548, buf543, buf526, buf520, buf506, buf499, buf492, buf478, buf455, buf449, buf436, buf431, buf421, buf416, buf399, buf393, buf379, buf372, buf365, buf351, buf328, buf322, buf309, buf304, buf294, buf289, buf272, buf266, buf252, buf245, buf238, buf224, buf201, buf195, buf182, buf177, buf167, buf162, buf145, buf139, buf125, buf118, buf111, buf97, buf74, buf68, buf55, buf50, buf40, buf35, buf18, buf12, ) | |
# Module-level entry-point wiring for the compiled graph.
# NOTE(review): `Runner` is defined earlier in this generated file (not visible
# in this chunk) — presumably Inductor's partition runner; confirm against the
# file head.  An empty partition list means the whole graph runs as one `call`.
runner = Runner(partitions=[])
# `call` is the public entry point consumed by benchmark_compiled_module below.
call = runner.call
recursively_apply_fns = runner.recursively_apply_fns
def get_args():
    """Build randomly-initialized example inputs for the compiled backward graph.

    Each tensor is created with the exact shape, stride, device ('cuda:0') and
    dtype that the compiled `call` expects, using `rand_strided`; the scalar
    ints are the symbolic-size arguments baked in at compile time.  The return
    order must match `call`'s positional argument order exactly — do not
    reorder entries.
    """
    from torch._dynamo.testing import rand_strided
    # Scalar (symbolic-int) arguments.
    primals_16 = 131
    primals_20 = 131
    primals_23 = 131
    primals_26 = 131
    # Forward-graph parameters/buffers saved for backward.
    primals_1 = rand_strided((327680, 785), (785, 1), device='cuda:0', dtype=torch.uint8)
    primals_4 = rand_strided((2048, 768), (768, 1), device='cuda:0', dtype=torch.float32)
    primals_5 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_9 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_10 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_11 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_12 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_13 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_14 = rand_strided((327680, ), (1, ), device='cuda:0', dtype=torch.int64)
    primals_15 = rand_strided((327680, ), (1, ), device='cuda:0', dtype=torch.int64)
    primals_17 = rand_strided((1, 1, 2560, 131), (335360, 335360, 131, 1), device='cuda:0', dtype=torch.int32)
    primals_18 = rand_strided((1, 1, 2560), (2560, 2560, 1), device='cuda:0', dtype=torch.int32)
    primals_19 = rand_strided((1, 1, 2560), (2560, 2560, 1), device='cuda:0', dtype=torch.int32)
    primals_21 = rand_strided((1, 1, 2560, 131), (335360, 335360, 131, 1), device='cuda:0', dtype=torch.int32)
    primals_22 = rand_strided((1, 1, 2560), (2560, 2560, 1), device='cuda:0', dtype=torch.int32)
    primals_24 = rand_strided((1, 1, 2560, 131), (335360, 335360, 131, 1), device='cuda:0', dtype=torch.int32)
    primals_25 = rand_strided((1, 1, 2560), (2560, 2560, 1), device='cuda:0', dtype=torch.int32)
    primals_27 = rand_strided((1, 1, 2560, 131), (335360, 335360, 131, 1), device='cuda:0', dtype=torch.int32)
    # Per-layer weight groups; the (2048, 2048)/(512, 2048)/(8192, 2048)
    # pattern repeats once per transformer layer.
    primals_28 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_29 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_30 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_31 = rand_strided((8192, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_32 = rand_strided((8192, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_33 = rand_strided((2048, 8192), (8192, 1), device='cuda:0', dtype=torch.float32)
    primals_35 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_36 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_37 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_38 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_39 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_40 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_41 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_42 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_43 = rand_strided((8192, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_44 = rand_strided((8192, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_45 = rand_strided((2048, 8192), (8192, 1), device='cuda:0', dtype=torch.float32)
    primals_47 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_48 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_49 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_50 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_51 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_52 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_53 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_54 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_55 = rand_strided((8192, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_56 = rand_strided((8192, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_57 = rand_strided((2048, 8192), (8192, 1), device='cuda:0', dtype=torch.float32)
    primals_59 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_60 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_61 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_62 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_63 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_64 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_65 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_66 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_67 = rand_strided((8192, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_68 = rand_strided((8192, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_69 = rand_strided((2048, 8192), (8192, 1), device='cuda:0', dtype=torch.float32)
    primals_71 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_72 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_73 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_74 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_75 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_76 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_77 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_78 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_79 = rand_strided((8192, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_80 = rand_strided((8192, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_81 = rand_strided((2048, 8192), (8192, 1), device='cuda:0', dtype=torch.float32)
    primals_83 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_84 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_85 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_86 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_87 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_88 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_89 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_90 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_91 = rand_strided((8192, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_92 = rand_strided((8192, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_93 = rand_strided((2048, 8192), (8192, 1), device='cuda:0', dtype=torch.float32)
    primals_95 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_96 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_97 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_98 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_99 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_100 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_101 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_102 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_103 = rand_strided((8192, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_104 = rand_strided((8192, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_105 = rand_strided((2048, 8192), (8192, 1), device='cuda:0', dtype=torch.float32)
    primals_107 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_108 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_109 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_110 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_111 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_112 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_113 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_114 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_115 = rand_strided((8192, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_116 = rand_strided((8192, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_117 = rand_strided((2048, 8192), (8192, 1), device='cuda:0', dtype=torch.float32)
    primals_119 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_120 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_121 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_122 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_123 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_124 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_125 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_126 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_127 = rand_strided((8192, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_128 = rand_strided((8192, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_129 = rand_strided((2048, 8192), (8192, 1), device='cuda:0', dtype=torch.float32)
    primals_131 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_132 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_133 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_134 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_135 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_136 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_137 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_138 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_139 = rand_strided((8192, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_140 = rand_strided((8192, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_141 = rand_strided((2048, 8192), (8192, 1), device='cuda:0', dtype=torch.float32)
    primals_143 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_144 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_145 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_146 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_147 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_148 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_149 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_150 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_151 = rand_strided((8192, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_152 = rand_strided((8192, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_153 = rand_strided((2048, 8192), (8192, 1), device='cuda:0', dtype=torch.float32)
    primals_155 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_156 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_157 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_158 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_159 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_160 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_161 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_162 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_163 = rand_strided((8192, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    primals_164 = rand_strided((8192, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_165 = rand_strided((2048, 8192), (8192, 1), device='cuda:0', dtype=torch.float32)
    primals_167 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    # Saved forward tensors and the incoming gradient (per AOT autograd
    # naming: `add_*`/`getitem_*`/`rsqrt_*` are saved activations,
    # `tangents_1` is the output gradient) — presumed from names; confirm
    # against the forward graph.
    select_1 = rand_strided((327680, ), (1, ), device='cuda:0', dtype=torch.int64)
    cos = rand_strided((327680, 1024), (1024, 1), device='cuda:0', dtype=torch.float32)
    add_11 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    add_14 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    add_19 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    add_22 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    add_27 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    add_30 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    add_35 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    add_38 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    add_43 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    add_46 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    add_51 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    add_54 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    add_59 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    add_62 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    add_67 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    add_70 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    add_75 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    add_78 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    add_83 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    add_86 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    add_91 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    add_94 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    add_99 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    add_102 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    add_107 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    getitem_89 = rand_strided((327680, 1), (1, 1), device='cuda:0', dtype=torch.float32)
    rsqrt_26 = rand_strided((327680, 1), (1, 1), device='cuda:0', dtype=torch.float32)
    tangents_1 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    # Positional order must match the compiled call()'s signature.
    return [primals_16, primals_20, primals_23, primals_26, primals_1, primals_4, primals_5, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_17, primals_18, primals_19, primals_21, primals_22, primals_24, primals_25, primals_27, primals_28, primals_29, primals_30, primals_31, primals_32, primals_33, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_47, primals_48, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, primals_59, primals_60, primals_61, primals_62, primals_63, primals_64, primals_65, primals_66, primals_67, primals_68, primals_69, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_78, primals_79, primals_80, primals_81, primals_83, primals_84, primals_85, primals_86, primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_95, primals_96, primals_97, primals_98, primals_99, primals_100, primals_101, primals_102, primals_103, primals_104, primals_105, primals_107, primals_108, primals_109, primals_110, primals_111, primals_112, primals_113, primals_114, primals_115, primals_116, primals_117, primals_119, primals_120, primals_121, primals_122, primals_123, primals_124, primals_125, primals_126, primals_127, primals_128, primals_129, primals_131, primals_132, primals_133, primals_134, primals_135, primals_136, primals_137, primals_138, primals_139, primals_140, primals_141, primals_143, primals_144, primals_145, primals_146, primals_147, primals_148, primals_149, primals_150, primals_151, primals_152, primals_153, primals_155, primals_156, primals_157, primals_158, primals_159, primals_160, primals_161, primals_162, primals_163, primals_164, primals_165, primals_167, select_1, cos, add_11, add_14, add_19, add_22, add_27, add_30, add_35, add_38, add_43, add_46, add_51, add_54, add_59, add_62,
            add_67, add_70, add_75, add_78, add_83, add_86, add_91, add_94, add_99, add_102, add_107, getitem_89, rsqrt_26, tangents_1]
def benchmark_compiled_module(args, times=10, repeat=10):
    """Benchmark the compiled graph entry point.

    Runs ``call`` on *args* via Inductor's ``print_performance`` harness,
    executing it ``times`` times per measurement over ``repeat`` measurements,
    and returns whatever ``print_performance`` returns.
    """
    from torch._inductor.utils import print_performance

    def run_once():
        # `call` consumes (clears) its argument list, so hand it a fresh
        # copy on every invocation.
        return call(list(args))

    return print_performance(run_once, times=times, repeat=repeat)
| if __name__ == "__main__": | |
| from torch._inductor.wrapper_benchmark import compiled_module_main | |
| args = get_args() | |
| compiled_module_main('None', lambda times, repeat: benchmark_compiled_module(args, times=times, repeat=repeat)) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment