# Source provenance (GitHub gist scrape header, preserved as a comment):
#   Gist by @shunting314, created March 2, 2026 23:07
#   https://gist.github.com/shunting314/6fe4e931f7e3bd98e1c936b4b1135a5f
# The original HTML boilerplate ("Skip to content", "Show Gist options",
# duplicated "Select an option" / save-link lines) was non-code scrape
# residue; it is folded into this comment so the file can parse as Python.
# AOT ID: ['0_backward']
# NOTE(review): this file is machine-generated TorchInductor output (backward
# graph, AOT ID '0_backward'). Hand edits are normally overwritten on the next
# compile. The scrape also appears to have stripped indentation inside the
# embedded Triton kernel strings below — the file will not run as-is until
# indentation is restored from the original artifact.
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
# Shorthand handles for the ATen / inductor / quantized op namespaces used by
# the generated wrapper code.
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
# Dynamo guard helpers: runtime tensor size/stride/alignment assertions and
# raw strided-allocation primitives, one per supported device backend.
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
assert_alignment = torch._C._dynamo.guards.assert_alignment
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
# Async compiler: each async_compile.triton(...) call below registers a kernel
# (source given as a string literal) for compilation.
async_compile = AsyncCompile()
# Symmetric-memory peer-to-peer allocator (distributed c10d).
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
# kernel path: /tmp/torchinductor/rank0/wb/cwbxlmflrtfpjh5rwf4t5k27entxvmvn33pehlygc5askur3dbxd.py
# Topologically Sorted Source Nodes: [convert_element_type_410, redistribute_4, convert_element_type_412, mul_153, mul_154, sum_1, x_26, mul_155, sum_2, mul_156, sub_29, sub_30, div_3, mul_157, mul_158, sum_3, sum_4, convert_element_type_414], Original ATen: [aten.native_layer_norm_backward, aten._to_copy, aten.native_layer_norm]
# Source node to ATen node mapping:
# convert_element_type_410 => convert_element_type_410
# convert_element_type_412 => convert_element_type_412
# convert_element_type_414 => convert_element_type_414
# div_3 => div_3
# mul_153 => mul_153
# mul_154 => mul_154
# mul_155 => mul_155
# mul_156 => mul_156
# mul_157 => mul_157
# mul_158 => mul_158
# redistribute_4 => convert_element_type_406
# sub_29 => sub_29
# sub_30 => sub_30
# sum_1 => sum_1
# sum_2 => sum_2
# sum_3 => sum_3
# sum_4 => sum_4
# x_26 => convert_element_type_408, mul_150, sub_27
# Graph fragment:
# %tangents_1 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=tangents_1]
# %primals_167 : Tensor "f32[2048][1]cuda:0" = PlaceHolder[target=primals_167]
# %add_107 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=add_107]
# %getitem_89 : Tensor "f32[327680, 1][1, 1]cuda:0" = PlaceHolder[target=getitem_89]
# %rsqrt_26 : Tensor "f32[327680, 1][1, 1]cuda:0" = PlaceHolder[target=rsqrt_26]
# %sum_1 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=sum_1]
# %sum_2 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=sum_2]
# %convert_element_type_410 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%tangents_1, torch.float32), kwargs = {})
# %convert_element_type_406 : Tensor "bf16[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_167, torch.bfloat16), kwargs = {})
# %convert_element_type_412 : Tensor "f32[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_406, torch.float32), kwargs = {})
# %mul_153 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_410, %convert_element_type_412), kwargs = {})
# %mul_154 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_153, 2048), kwargs = {})
# %sum_1 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_153, [1], True), kwargs = {})
# %convert_element_type_408 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_107, torch.float32), kwargs = {})
# %sub_27 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%convert_element_type_408, %getitem_89), kwargs = {})
# %mul_150 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_27, %rsqrt_26), kwargs = {})
# %mul_155 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_153, %mul_150), kwargs = {})
# %sum_2 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_155, [1], True), kwargs = {})
# %mul_156 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_150, %sum_2), kwargs = {})
# %sub_29 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%mul_154, %sum_1), kwargs = {})
# %sub_30 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sub_29, %mul_156), kwargs = {})
# %div_3 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%rsqrt_26, 2048), kwargs = {})
# %mul_157 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_3, %sub_30), kwargs = {})
# %mul_158 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_410, %mul_150), kwargs = {})
# %sum_3 : Tensor "f32[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_158, [0]), kwargs = {})
# %sum_4 : Tensor "f32[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_410, [0]), kwargs = {})
# %convert_element_type_414 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_157, torch.bfloat16), kwargs = {})
# return %sum_1,%sum_2,%convert_element_type_414
# NOTE(review): machine-generated kernel. The Triton source is a string
# literal handed to async_compile, so it must stay byte-identical to what
# Inductor emitted; review annotations therefore live out here rather than
# inside the kernel. (Indentation inside the string was lost in the scrape —
# restore from the original artifact before running.)
# Per the graph fragment above, each program handles a chunk of RSPLIT_SIZE
# rows (327680 rows x 2048 features total) and:
#   * recomputes the normalized activations ((x - mean) * rsqrt, tmp14),
#   * computes and stores the bf16 input gradient dL/dx to out_ptr2
#     (convert_element_type_414),
#   * accumulates per-program column partial sums into ws_ptr: slot 0 is the
#     weight-grad partial (sum_3: grad_out * normalized, accum0) and slot 1
#     the bias-grad partial (sum_4: grad_out, accum1).
# The per-program partials in ws_ptr are presumably reduced by a downstream
# kernel that is not visible in this chunk — confirm against the full file.
triton_per_fused__to_copy_native_layer_norm_native_layer_norm_backward_0 = async_compile.triton('triton_per_fused__to_copy_native_layer_norm_native_layer_norm_backward_0', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 524288, 'r0_': 2048},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*fp32', 'out_ptr2': '*bf16', 'ws_ptr': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'RSPLIT_SIZE': 'constexpr', 'NUM_STAGES': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'MixOrderReductionGrid', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_native_layer_norm_native_layer_norm_backward_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 5, 'num_store': -1, 'num_reduction': 2, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'RSPLIT_SIZE': 128, 'has_loadstore_with_contiguous_rdim': True}
)
@triton.jit
def triton_per_fused__to_copy_native_layer_norm_native_layer_norm_backward_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr2, ws_ptr, xnumel, r0_numel, XBLOCK : tl.constexpr, RSPLIT_SIZE : tl.constexpr, NUM_STAGES : tl.constexpr):
xnumel = 327680
r0_numel = 2048
R0_BLOCK: tl.constexpr = 2048
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * RSPLIT_SIZE
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
r0_index = tl.arange(0, R0_BLOCK)[None, :]
r0_offset = 0
r0_mask = tl.full([R0_BLOCK], True, tl.int1)[None, :]
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
accum0 = tl.full([R0_BLOCK], 0, tl.float32)[None, :]
accum1 = tl.full([R0_BLOCK], 0, tl.float32)[None, :]
split_size = min(RSPLIT_SIZE, xnumel - xoffset)
for _ in tl.range(0, split_size, XBLOCK, num_stages=NUM_STAGES):
x0 = xindex
xindex += XBLOCK
tmp0 = tl.load(in_ptr0 + (r0_1 + 2048*x0), None).to(tl.float32)
tmp2 = tl.load(in_ptr1 + (r0_1), None, eviction_policy='evict_last')
tmp9 = tl.load(in_ptr2 + (r0_1 + 2048*x0), None).to(tl.float32)
tmp11 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last')
tmp13 = tl.load(in_ptr4 + (x0), None, eviction_policy='evict_last')
tmp1 = tmp0.to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp3.to(tl.float32)
tmp5 = tmp1 * tmp4
tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK])
tmp8 = tl.sum(tmp6, 1)[:, None].to(tl.float32)
tmp10 = tmp9.to(tl.float32)
tmp12 = tmp10 - tmp11
tmp14 = tmp12 * tmp13
tmp15 = tmp5 * tmp14
tmp16 = tl.broadcast_to(tmp15, [XBLOCK, R0_BLOCK])
tmp18 = tl.sum(tmp16, 1)[:, None].to(tl.float32)
tmp19 = tl.full([1, 1], 0.00048828125, tl.float32)
tmp20 = tmp13 * tmp19
tmp21 = tl.full([1, 1], 2048.0, tl.float32)
tmp22 = tmp5 * tmp21
tmp23 = tmp22 - tmp8
tmp24 = tmp14 * tmp18
tmp25 = tmp23 - tmp24
tmp26 = tmp20 * tmp25
tmp27 = tmp26.to(tl.float32)
tmp28 = tmp1 * tmp14
tl.store(out_ptr2 + (r0_1 + 2048*x0), tmp27, None)
tmp29 = tl.sum(tmp28, 0)
tmp30 = accum0 + tmp29
accum0 = tmp30
tmp31 = tl.sum(tmp1, 0)
tmp32 = accum1 + tmp31
accum1 = tmp32
tl.store(ws_ptr + (tl.program_id(0) + 0 * tl.num_programs(0)) * r0_numel + r0_index, accum0, r0_mask)
tl.store(ws_ptr + (tl.program_id(0) + 1 * tl.num_programs(0)) * r0_numel + r0_index, accum1, r0_mask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/zm/czmvhutdr45dzcouphyn6ypgzhuegd4eldyd5fc7qmihrwdbp2aa.py
# Topologically Sorted Source Nodes: [convert_element_type_415, all_reduce_1], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
# Source node to ATen node mapping:
# all_reduce_1 => all_reduce_1
# convert_element_type_415 => convert_element_type_415
# Graph fragment:
# %sum_3 : Tensor "f32[2048][1]cuda:0" = PlaceHolder[target=sum_3]
# %convert_element_type_415 : Tensor "bf16[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_3, torch.bfloat16), kwargs = {})
# %all_reduce_1 : Tensor "bf16[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops._c10d_functional.all_reduce_.default](args = (%convert_element_type_415, avg, 0), kwargs = {})
# return %wait_tensor_1
# NOTE(review): generated kernel — annotations stay outside the string literal
# because the Triton source must remain byte-identical for compilation.
# Elementwise fp32 -> bf16 downcast of the 2048-element weight-grad sum
# (sum_3 -> convert_element_type_415). Per the graph fragment above this
# prepares the buffer for an in-place avg all_reduce; the collective itself is
# issued by the wrapper code, not by this kernel.
triton_poi_fused_all_reduce_native_layer_norm_backward_1 = async_compile.triton('triton_poi_fused_all_reduce_native_layer_norm_backward_1', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 2048},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_all_reduce_native_layer_norm_backward_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 16384}},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_all_reduce_native_layer_norm_backward_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 2048
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), xmask)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x0), tmp1, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/oo/coovzos7dfgo6q6d2x5csilspzo54pvavar6efs5zmoocbr3baq6.py
# Topologically Sorted Source Nodes: [convert_element_type_417], Original ATen: [aten._to_copy]
# Source node to ATen node mapping:
# convert_element_type_417 => convert_element_type_417
# Graph fragment:
# %buf11 : Tensor = PlaceHolder[target=buf11]
# %convert_element_type_417 : Tensor "f32[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%wait_tensor, torch.float32), kwargs = {})
# return %convert_element_type_417
# NOTE(review): generated kernel — annotations stay outside the string literal
# because the Triton source must remain byte-identical for compilation.
# Elementwise bf16 -> fp32 upcast of a 2048-element buffer (the all-reduced
# wait_tensor -> convert_element_type_417 per the graph fragment above).
triton_poi_fused__to_copy_2 = async_compile.triton('triton_poi_fused__to_copy_2', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 2048},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 16384}},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_2(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 2048
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x0), tmp1, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/74/c745s7utu4xbymzze36ogp6jnsxqu2olhyem7egezxhjzg2idgqd.py
# Topologically Sorted Source Nodes: [redistribute_4], Original ATen: [aten._to_copy]
# Source node to ATen node mapping:
# redistribute_4 => convert_element_type_401
# Graph fragment:
# %primals_165 : Tensor "f32[2048, 8192][8192, 1]cuda:0" = PlaceHolder[target=primals_165]
# %convert_element_type_401 : Tensor "bf16[2048, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_165, torch.bfloat16), kwargs = {})
# return %convert_element_type_401
# NOTE(review): generated kernel — annotations stay outside the string literal
# because the Triton source must remain byte-identical for compilation.
# Elementwise fp32 -> bf16 downcast of primals_165 (2048x8192 = 16777216
# elements, the weight cast for redistribute_4). xmask is constant-true here:
# codegen emitted no bounds check, so the launch must cover exactly xnumel
# elements.
triton_poi_fused__to_copy_3 = async_compile.triton('triton_poi_fused__to_copy_3', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 16777216},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 134217728}},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_3(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 16777216
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)[:]
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/f2/cf2xntzp4edf6e25wjnwlziltupvbuqpcnzw5jxzmb4pxuc3x7r4.py
# Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm]
# Source node to ATen node mapping:
# layer_norm => add_103, add_104, convert_element_type_392, convert_element_type_393, mul_142, mul_143, rsqrt_25, sub_26, var_mean_25
# redistribute => convert_element_type_390
# redistribute_1 => convert_element_type_391
# Graph fragment:
# %add_102 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=add_102]
# %getitem_87 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=getitem_87]
# %buf22 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=buf22]
# %primals_161 : Tensor "f32[2048][1]cuda:0" = PlaceHolder[target=primals_161]
# %primals_162 : Tensor "f32[2048][1]cuda:0" = PlaceHolder[target=primals_162]
# %convert_element_type_390 : Tensor "bf16[2048][1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_161, torch.bfloat16), kwargs = {})
# %convert_element_type_391 : Tensor "bf16[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_162, torch.bfloat16), kwargs = {})
# %convert_element_type_392 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_102, torch.float32), kwargs = {})
# %var_mean_25 : [num_users=2] = call_function[target=torch.ops.aten.var_mean.correction](args = (%convert_element_type_392, [1]), kwargs = {correction: 0, keepdim: True})
# %add_103 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%getitem_86, 1e-05), kwargs = {})
# %rsqrt_25 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_103,), kwargs = {})
# %sub_26 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%convert_element_type_392, %getitem_87), kwargs = {})
# %mul_142 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_26, %rsqrt_25), kwargs = {})
# %mul_143 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_142, %convert_element_type_390), kwargs = {})
# %add_104 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_143, %convert_element_type_391), kwargs = {})
# %convert_element_type_393 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_104, torch.bfloat16), kwargs = {})
# return %getitem_87,%buf22,%convert_element_type_393
# NOTE(review): generated kernel — annotations stay outside the string literal
# because the Triton source must remain byte-identical for compilation.
# Fused forward layer norm over 327680 rows x 2048 features, two passes per
# row: pass 1 runs a Welford reduction over the features to produce per-row
# mean (out_ptr0) and M2 (out_ptr1); pass 2 reloads the input, normalizes
# with rsqrt(M2/2048 + 1e-05), applies the affine weight (in_ptr1) and bias
# (in_ptr2), and stores the bf16 result to out_ptr2 (convert_element_type_393).
# The weight/bias loads take a bf16 round trip (fp32 -> bf16 -> fp32,
# matching convert_element_type_390/391 in the graph fragment above).
triton_red_fused__to_copy_native_layer_norm_4 = async_compile.triton('triton_red_fused__to_copy_native_layer_norm_4', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 524288, 'r0_': 2048},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_native_layer_norm_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 4, 'num_store': 3, 'num_reduction': 2, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': True, 'tiling_scores': {'x': 5242880, 'r0_': 4026548224}}
)
@triton.jit
def triton_red_fused__to_copy_native_layer_norm_4(in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr1, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 327680
r0_numel = 2048
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x0 = xindex
tmp3_mean = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
tmp3_m2 = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
tmp3_weight = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_1 + 2048*x0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp1 = tmp0.to(tl.float32)
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
tmp3_mean_next, tmp3_m2_next, tmp3_weight_next = triton_helpers.welford_reduce(
tmp2, tmp3_mean, tmp3_m2, tmp3_weight, roffset == 0
)
tmp3_mean = tl.where(r0_mask, tmp3_mean_next, tmp3_mean)
tmp3_m2 = tl.where(r0_mask, tmp3_m2_next, tmp3_m2)
tmp3_weight = tl.where(r0_mask, tmp3_weight_next, tmp3_weight)
tmp4, tmp5, tmp6 = triton_helpers.welford(tmp3_mean, tmp3_m2, tmp3_weight, 1)
tmp3 = tmp4[:, None]
tmp7 = tmp5[:, None]
tmp8 = tmp6[:, None]
tl.store(out_ptr0 + (x0), tmp3, None)
tl.store(out_ptr1 + (x0), tmp7, None)
for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
tmp9 = tl.load(in_ptr0 + (r0_1 + 2048*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp18 = tl.load(in_ptr1 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0)
tmp22 = tl.load(in_ptr2 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0)
tmp10 = tmp9.to(tl.float32)
tmp11 = tmp10 - tmp3
tmp12 = tl.full([1, 1], 2048.0, tl.float32)
tmp13 = (tmp7 / tmp12)
tmp14 = tl.full([1, 1], 1e-05, tl.float32)
tmp15 = tmp13 + tmp14
tmp16 = libdevice.rsqrt(tmp15)
tmp17 = tmp11 * tmp16
tmp19 = tmp18.to(tl.float32)
tmp20 = tmp19.to(tl.float32)
tmp21 = tmp17 * tmp20
tmp23 = tmp22.to(tl.float32)
tmp24 = tmp23.to(tl.float32)
tmp25 = tmp21 + tmp24
tmp26 = tmp25.to(tl.float32)
tl.store(out_ptr2 + (r0_1 + 2048*x0), tmp26, r0_mask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/k2/ck2gagwbhzuttn76q4burnrkjixhz2qr32zm5ndumumvzjoihsxw.py
# Topologically Sorted Source Nodes: [redistribute_3, x, x_25, convert_element_type_425, mul_164, mul_165, sub_31, mul_166, add_112, mul_167, mul_168, mul_169, add_113, mul_170, convert_element_type_427], Original ATen: [aten._to_copy, aten.addmm, aten.gelu, aten.gelu_backward]
# Source node to ATen node mapping:
# add_112 => add_112
# add_113 => add_113
# convert_element_type_425 => convert_element_type_425
# convert_element_type_427 => convert_element_type_427
# mul_164 => mul_164
# mul_165 => mul_165
# mul_166 => mul_166
# mul_167 => mul_167
# mul_168 => mul_168
# mul_169 => mul_169
# mul_170 => mul_170
# redistribute_3 => convert_element_type_395
# sub_31 => sub_31
# x => add_tensor_11
# x_25 => add_105, add_106, convert_element_type_399, convert_element_type_400, mul_144, mul_145, mul_146, mul_147, mul_148, mul_149, tanh_11
# Graph fragment:
# %primals_164 : Tensor "f32[8192][1]cuda:0" = PlaceHolder[target=primals_164]
# %mm_default_11 : Tensor "bf16[327680, 8192][8192, 1]cuda:0" = PlaceHolder[target=mm_default_11]
# %mm_49 : Tensor "bf16[327680, 8192][8192, 1]cuda:0" = PlaceHolder[target=mm_49]
# %convert_element_type_395 : Tensor "bf16[8192][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_164, torch.bfloat16), kwargs = {})
# %add_tensor_11 : Tensor "bf16[327680, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%convert_element_type_395, %mm_default_11), kwargs = {})
# %convert_element_type_399 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_tensor_11, torch.float32), kwargs = {})
# %mul_144 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_399, %convert_element_type_399), kwargs = {})
# %mul_145 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_144, %convert_element_type_399), kwargs = {})
# %mul_146 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_145, 0.044715), kwargs = {})
# %add_105 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%convert_element_type_399, %mul_146), kwargs = {})
# %mul_147 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_105, 0.7978845608028654), kwargs = {})
# %mul_148 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_399, 0.5), kwargs = {})
# %tanh_11 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.tanh.default](args = (%mul_147,), kwargs = {})
# %add_106 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%tanh_11, 1), kwargs = {})
# %mul_149 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_148, %add_106), kwargs = {})
# %convert_element_type_400 : Tensor "bf16[327680, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_149, torch.bfloat16), kwargs = {})
# %convert_element_type_425 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mm_49, torch.float32), kwargs = {})
# %mul_164 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_106, 0.5), kwargs = {})
# %mul_165 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tanh_11, %tanh_11), kwargs = {})
# %sub_31 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (1, %mul_165), kwargs = {})
# %mul_166 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_144, 0.134145), kwargs = {})
# %add_112 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_166, 1), kwargs = {})
# %mul_167 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_112, 0.7978845608028654), kwargs = {})
# %mul_168 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_148, %sub_31), kwargs = {})
# %mul_169 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_168, %mul_167), kwargs = {})
# %add_113 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_164, %mul_169), kwargs = {})
# %mul_170 : Tensor "f32[327680, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_425, %add_113), kwargs = {})
# %convert_element_type_427 : Tensor "bf16[327680, 8192][8192, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_170, torch.bfloat16), kwargs = {})
# return %convert_element_type_400,%convert_element_type_427
triton_poi_fused__to_copy_addmm_gelu_gelu_backward_5 = async_compile.triton('triton_poi_fused__to_copy_addmm_gelu_gelu_backward_5', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 4294967296},
filename=__file__,
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*fp32', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i64', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_addmm_gelu_gelu_backward_5', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 2, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 32212287488}},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_addmm_gelu_gelu_backward_5(in_out_ptr0, in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 2684354560
xoffset = tl.program_id(0).to(tl.int64) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:].to(tl.int64)
xmask = tl.full([XBLOCK], True, tl.int1)[:]
x0 = (xindex % 8192)
x2 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last')
tmp2 = tl.load(in_ptr1 + (x2), None).to(tl.float32)
tmp19 = tl.load(in_out_ptr0 + (x2), None).to(tl.float32)
tmp1 = tmp0.to(tl.float32)
tmp3 = tmp1 + tmp2
tmp4 = tmp3.to(tl.float32)
tmp5 = tl.full([1], 0.5, tl.float32)
tmp6 = tmp4 * tmp5
tmp7 = tmp4 * tmp4
tmp8 = tmp7 * tmp4
tmp9 = tl.full([1], 0.044715, tl.float32)
tmp10 = tmp8 * tmp9
tmp11 = tmp4 + tmp10
tmp12 = tl.full([1], 0.7978845608028654, tl.float32)
tmp13 = tmp11 * tmp12
tmp14 = libdevice.tanh(tmp13)
tmp15 = tl.full([1], 1.0, tl.float32)
tmp16 = tmp14 + tmp15
tmp17 = tmp6 * tmp16
tmp18 = tmp17.to(tl.float32)
tmp20 = tmp19.to(tl.float32)
tmp21 = tmp16 * tmp5
tmp22 = tmp14 * tmp14
tmp23 = tmp15 - tmp22
tmp24 = tmp6 * tmp23
tmp25 = tl.full([1], 0.134145, tl.float32)
tmp26 = tmp7 * tmp25
tmp27 = tmp26 + tmp15
tmp28 = tmp27 * tmp12
tmp29 = tmp24 * tmp28
tmp30 = tmp21 + tmp29
tmp31 = tmp20 * tmp30
tmp32 = tmp31.to(tl.float32)
tl.store(out_ptr0 + (x2), tmp18, None)
tl.store(in_out_ptr0 + (x2), tmp32, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/t6/ct6itkhhsr67p6nullejwydl2p56bqr5cedys6o4gndtdlcev3mu.py
# Topologically Sorted Source Nodes: [sum_5], Original ATen: [aten.sum]
# Source node to ATen node mapping:
# sum_5 => sum_5
# Graph fragment:
# %convert_element_type_414 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=convert_element_type_414]
# %sum_5 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_414, [0], True), kwargs = {})
# return %buf29
triton_red_fused_sum_6 = async_compile.triton('triton_red_fused_sum_6', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 524288, 'r0_': 2048},
reduction_hint=ReductionHint.OUTER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_6', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 1344798720, 'r0_': 0}}
)
@triton.jit
def triton_red_fused_sum_6(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 327680
r0_numel = 2048
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x0 = (xindex % 2048)
x1 = xindex // 2048
_tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
x3 = xindex
for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp0 = tl.load(in_ptr0 + (x0 + 2048*r0_2 + 4194304*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
tmp3 = _tmp2 + tmp1
_tmp2 = tl.where(r0_mask, tmp3, _tmp2)
tmp2 = tl.sum(_tmp2, 1)[:, None]
tl.store(out_ptr0 + (x3), tmp2, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/p2/cp2yuujssx45mwam4gmghu2bmmldkfzv4gngwd75ivayqdq2igwq.py
# Topologically Sorted Source Nodes: [sum_5], Original ATen: [aten.sum]
# Source node to ATen node mapping:
# sum_5 => sum_5
# Graph fragment:
# %buf29 : Tensor "f32[1, 2048, 160][327680, 1, 2048]cuda:0" = PlaceHolder[target=buf29]
# %sum_5 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_414, [0], True), kwargs = {})
# return %sum_5
triton_red_fused_sum_7 = async_compile.triton('triton_red_fused_sum_7', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 2048, 'r0_': 256},
reduction_hint=ReductionHint.OUTER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_7', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 1318912, 'r0_': 0}}
)
@triton.jit
def triton_red_fused_sum_7(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 2048
r0_numel = 160
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x0 = xindex
_tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
tmp0 = tl.load(in_ptr0 + (x0 + 2048*r0_1), r0_mask & xmask, eviction_policy='evict_first', other=0.0)
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
tmp3 = _tmp2 + tmp1
_tmp2 = tl.where(r0_mask & xmask, tmp3, _tmp2)
tmp2 = tl.sum(_tmp2, 1)[:, None]
tl.store(out_ptr0 + (x0), tmp2, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/hh/chhznpxzkrjd6hw2dultp3lkcxpgo3xbcosobcrvsssw26crs3pc.py
# Topologically Sorted Source Nodes: [convert_element_type_424], Original ATen: [aten._to_copy]
# Source node to ATen node mapping:
# convert_element_type_424 => convert_element_type_424
# Graph fragment:
# %buf39 : Tensor = PlaceHolder[target=buf39]
# %convert_element_type_424 : Tensor "f32[2048, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%wait_tensor_3, torch.float32), kwargs = {})
# return %convert_element_type_424
triton_poi_fused__to_copy_8 = async_compile.triton('triton_poi_fused__to_copy_8', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 16777216},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_8', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 134217728}},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_8(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 16777216
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)[:]
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/y4/cy4awr42ljhsqyf72s5jzm7g3pbdgrndgk7ry4ij6viltpmia7db.py
# Topologically Sorted Source Nodes: [sum_6], Original ATen: [aten.sum]
# Source node to ATen node mapping:
# sum_6 => sum_6
# Graph fragment:
# %convert_element_type_427 : Tensor "bf16[327680, 8192][8192, 1]cuda:0" = PlaceHolder[target=convert_element_type_427]
# %sum_6 : Tensor "bf16[1, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_427, [0], True), kwargs = {})
# return %buf44
triton_red_fused_sum_9 = async_compile.triton('triton_red_fused_sum_9', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 2097152, 'r0_': 2048},
reduction_hint=ReductionHint.OUTER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i64', 'r0_numel': 'i64', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_9', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 5379194880, 'r0_': 0}}
)
@triton.jit
def triton_red_fused_sum_9(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 1310720
r0_numel = 2048
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0).to(tl.int64) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None].to(tl.int64)
xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
r0_base = tl.arange(0, R0_BLOCK)[None, :].to(tl.int64)
rbase = r0_base
x0 = (xindex % 8192)
x1 = xindex // 8192
_tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
x3 = xindex
for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp0 = tl.load(in_ptr0 + (x0 + 8192*r0_2 + 16777216*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
tmp3 = _tmp2 + tmp1
_tmp2 = tl.where(r0_mask, tmp3, _tmp2)
tmp2 = tl.sum(_tmp2, 1)[:, None]
tl.store(out_ptr0 + (x3), tmp2, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/qk/cqkqenwa2t4jzfp5qb2mrzo3mnfgjl2ns4mywtxjzal2ymd6ew3v.py
# Topologically Sorted Source Nodes: [sum_6], Original ATen: [aten.sum]
# Source node to ATen node mapping:
# sum_6 => sum_6
# Graph fragment:
# %buf44 : Tensor "f32[1, 8192, 160][1310720, 1, 8192]cuda:0" = PlaceHolder[target=buf44]
# %sum_6 : Tensor "bf16[1, 8192][8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_427, [0], True), kwargs = {})
# return %sum_6
triton_red_fused_sum_10 = async_compile.triton('triton_red_fused_sum_10', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 8192, 'r0_': 256},
reduction_hint=ReductionHint.OUTER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_10', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 5275648, 'r0_': 0}}
)
@triton.jit
def triton_red_fused_sum_10(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 8192
r0_numel = 160
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x0 = xindex
_tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
tmp0 = tl.load(in_ptr0 + (x0 + 8192*r0_1), r0_mask, eviction_policy='evict_first', other=0.0)
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
tmp3 = _tmp2 + tmp1
_tmp2 = tl.where(r0_mask, tmp3, _tmp2)
tmp2 = tl.sum(_tmp2, 1)[:, None]
tl.store(out_ptr0 + (x0), tmp2, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/wt/cwt4nuarmo4faznyu2vyoyay5466bt27gffcjcknubmh3gfamc73.py
# Topologically Sorted Source Nodes: [view_280, convert_element_type_432], Original ATen: [aten.view, aten._to_copy]
# Source node to ATen node mapping:
# convert_element_type_432 => convert_element_type_432
# view_280 => view_280
# Graph fragment:
# %buf49 : Tensor = PlaceHolder[target=buf49]
# %view_280 : Tensor "bf16[8192][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%sum_6, [8192]), kwargs = {})
# %convert_element_type_432 : Tensor "f32[8192][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%wait_tensor_4, torch.float32), kwargs = {})
# return %convert_element_type_432
triton_poi_fused__to_copy_view_11 = async_compile.triton('triton_poi_fused__to_copy_view_11', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 8192},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_view_11', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 65536}},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_view_11(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 8192
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)[:]
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/gu/cguot2zrmpx5qc3kxoddiexeuvmukqdb3jgmu5cfce2w2eczd7c7.py
# Topologically Sorted Source Nodes: [redistribute, layer_norm, convert_element_type_434, convert_element_type_436, mul_172, mul_173, sum_7, mul_174, sum_8, mul_175, sub_33, sub_34, div_4, mul_176, mul_177, sum_9, sum_10, convert_element_type_438, add_114], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.native_layer_norm_backward, aten.add]
# Source node to ATen node mapping:
# add_114 => add_114
# convert_element_type_434 => convert_element_type_434
# convert_element_type_436 => convert_element_type_436
# convert_element_type_438 => convert_element_type_438
# div_4 => div_4
# layer_norm => add_103, convert_element_type_392, mul_142, rsqrt_25, sub_26, var_mean_25
# mul_172 => mul_172
# mul_173 => mul_173
# mul_174 => mul_174
# mul_175 => mul_175
# mul_176 => mul_176
# mul_177 => mul_177
# redistribute => convert_element_type_390
# sub_33 => sub_33
# sub_34 => sub_34
# sum_10 => sum_10
# sum_7 => sum_7
# sum_8 => sum_8
# sum_9 => sum_9
# Graph fragment:
# %mm_51 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_51]
# %primals_161 : Tensor "f32[2048][1]cuda:0" = PlaceHolder[target=primals_161]
# %add_102 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=add_102]
# %getitem_87 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=getitem_87]
# %buf22 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=buf22]
# %convert_element_type_414 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=convert_element_type_414]
# %sum_7 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=sum_7]
# %sum_8 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=sum_8]
# %convert_element_type_390 : Tensor "bf16[2048][1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_161, torch.bfloat16), kwargs = {})
# %convert_element_type_392 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_102, torch.float32), kwargs = {})
# %var_mean_25 : [num_users=2] = call_function[target=torch.ops.aten.var_mean.correction](args = (%convert_element_type_392, [1]), kwargs = {correction: 0, keepdim: True})
# %add_103 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%getitem_86, 1e-05), kwargs = {})
# %rsqrt_25 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_103,), kwargs = {})
# %sub_26 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%convert_element_type_392, %getitem_87), kwargs = {})
# %mul_142 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_26, %rsqrt_25), kwargs = {})
# %convert_element_type_434 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mm_51, torch.float32), kwargs = {})
# %convert_element_type_436 : Tensor "f32[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_390, torch.float32), kwargs = {})
# %mul_172 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_434, %convert_element_type_436), kwargs = {})
# %mul_173 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_172, 2048), kwargs = {})
# %sum_7 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_172, [1], True), kwargs = {})
# %mul_174 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_172, %mul_142), kwargs = {})
# %sum_8 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_174, [1], True), kwargs = {})
# %mul_175 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_142, %sum_8), kwargs = {})
# %sub_33 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%mul_173, %sum_7), kwargs = {})
# %sub_34 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sub_33, %mul_175), kwargs = {})
# %div_4 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%rsqrt_25, 2048), kwargs = {})
# %mul_176 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_4, %sub_34), kwargs = {})
# %mul_177 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_434, %mul_142), kwargs = {})
# %sum_9 : Tensor "f32[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_177, [0]), kwargs = {})
# %sum_10 : Tensor "f32[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_434, [0]), kwargs = {})
# %convert_element_type_438 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_176, torch.bfloat16), kwargs = {})
# %add_114 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%convert_element_type_414, %convert_element_type_438), kwargs = {})
# return %sum_7,%sum_8,%add_114
triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_12 = async_compile.triton('triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_12', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 524288, 'r0_': 2048},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*fp32', 'ws_ptr': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'RSPLIT_SIZE': 'constexpr', 'NUM_STAGES': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'MixOrderReductionGrid', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_12', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 6, 'num_store': -1, 'num_reduction': 2, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'RSPLIT_SIZE': 128, 'has_loadstore_with_contiguous_rdim': True}
)
@triton.jit
def triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_12(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, ws_ptr, xnumel, r0_numel, XBLOCK : tl.constexpr, RSPLIT_SIZE : tl.constexpr, NUM_STAGES : tl.constexpr):
xnumel = 327680
r0_numel = 2048
R0_BLOCK: tl.constexpr = 2048
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * RSPLIT_SIZE
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
r0_index = tl.arange(0, R0_BLOCK)[None, :]
r0_offset = 0
r0_mask = tl.full([R0_BLOCK], True, tl.int1)[None, :]
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
accum0 = tl.full([R0_BLOCK], 0, tl.float32)[None, :]
accum1 = tl.full([R0_BLOCK], 0, tl.float32)[None, :]
split_size = min(RSPLIT_SIZE, xnumel - xoffset)
for _ in tl.range(0, split_size, XBLOCK, num_stages=NUM_STAGES):
x0 = xindex
xindex += XBLOCK
tmp0 = tl.load(in_ptr0 + (r0_1 + 2048*x0), None).to(tl.float32)
tmp2 = tl.load(in_ptr1 + (r0_1), None, eviction_policy='evict_last')
tmp9 = tl.load(in_ptr2 + (r0_1 + 2048*x0), None).to(tl.float32)
tmp11 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last')
tmp13 = tl.load(in_ptr4 + (x0), None, eviction_policy='evict_last')
tmp24 = tl.load(in_out_ptr0 + (r0_1 + 2048*x0), None).to(tl.float32)
tmp1 = tmp0.to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp3.to(tl.float32)
tmp5 = tmp1 * tmp4
tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK])
tmp8 = tl.sum(tmp6, 1)[:, None].to(tl.float32)
tmp10 = tmp9.to(tl.float32)
tmp12 = tmp10 - tmp11
tmp14 = tl.full([1, 1], 2048.0, tl.float32)
tmp15 = (tmp13 / tmp14)
tmp16 = tl.full([1, 1], 1e-05, tl.float32)
tmp17 = tmp15 + tmp16
tmp18 = libdevice.rsqrt(tmp17)
tmp19 = tmp12 * tmp18
tmp20 = tmp5 * tmp19
tmp21 = tl.broadcast_to(tmp20, [XBLOCK, R0_BLOCK])
tmp23 = tl.sum(tmp21, 1)[:, None].to(tl.float32)
tmp25 = tl.full([1, 1], 0.00048828125, tl.float32)
tmp26 = tmp18 * tmp25
tmp27 = tmp5 * tmp14
tmp28 = tmp27 - tmp8
tmp29 = tmp19 * tmp23
tmp30 = tmp28 - tmp29
tmp31 = tmp26 * tmp30
tmp32 = tmp31.to(tl.float32)
tmp33 = tmp24 + tmp32
tmp34 = tmp1 * tmp19
tl.store(in_out_ptr0 + (r0_1 + 2048*x0), tmp33, None)
tmp35 = tl.sum(tmp34, 0)
tmp36 = accum0 + tmp35
accum0 = tmp36
tmp37 = tl.sum(tmp1, 0)
tmp38 = accum1 + tmp37
accum1 = tmp38
tl.store(ws_ptr + (tl.program_id(0) + 0 * tl.num_programs(0)) * r0_numel + r0_index, accum0, r0_mask)
tl.store(ws_ptr + (tl.program_id(0) + 1 * tl.num_programs(0)) * r0_numel + r0_index, accum1, r0_mask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/wd/cwd4pi47eukzrwowauw3dsgtyi5t2umcku6ki7ijnndn5vrptyak.py
# Topologically Sorted Source Nodes: [redistribute_2, linear], Original ATen: [aten._to_copy, aten.t]
# Source node to ATen node mapping:
# linear => permute_111
# redistribute_2 => convert_element_type_378
# Graph fragment:
# %primals_157 : Tensor "f32[2048, 2048][2048, 1]cuda:0" = PlaceHolder[target=primals_157]
# %convert_element_type_378 : Tensor "bf16[2048, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_157, torch.bfloat16), kwargs = {})
# %permute_111 : Tensor "bf16[2048, 2048][1, 2048]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.permute.default](args = (%convert_element_type_378, [1, 0]), kwargs = {})
# return %permute_111
triton_poi_fused__to_copy_t_13 = async_compile.triton('triton_poi_fused__to_copy_t_13', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 4194304},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_t_13', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 33554432}},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_t_13(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 4194304
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)[:]
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/we/cwe4ednvdc4atd6stlt2h5bnqtx47uoxynggv7myu77yu54lorzj.py
# Topologically Sorted Source Nodes: [redistribute_3, linear_1], Original ATen: [aten._to_copy, aten.t]
# Source node to ATen node mapping:
# linear_1 => permute_113
# redistribute_3 => convert_element_type_381
# Graph fragment:
# %primals_158 : Tensor "f32[512, 2048][2048, 1]cuda:0" = PlaceHolder[target=primals_158]
# %convert_element_type_381 : Tensor "bf16[512, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_158, torch.bfloat16), kwargs = {})
# %permute_113 : Tensor "bf16[2048, 512][1, 2048]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.permute.default](args = (%convert_element_type_381, [1, 0]), kwargs = {})
# return %permute_113
triton_poi_fused__to_copy_t_14 = async_compile.triton('triton_poi_fused__to_copy_t_14', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 1048576},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_t_14', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 8388608}},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_t_14(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 1048576
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)[:]
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/zs/czstcwxp4irb2bhqocdicv4g7b67ruhh3iwzcbk4zgkvjkekysvk.py
# Topologically Sorted Source Nodes: [q, k, v, flex_attention], Original ATen: [aten.view, aten.permute, flex_attention]
# Source node to ATen node mapping:
# flex_attention => flex_attention_11
# k => permute_114, view_261, view_262
# q => permute_112, view_258, view_259
# v => permute_116, view_264, view_265
# Graph fragment:
# %mm_45 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_45]
# %mm_46 : Tensor "bf16[327680, 512][512, 1]cuda:0" = PlaceHolder[target=mm_46]
# %mm_47 : Tensor "bf16[327680, 512][512, 1]cuda:0" = PlaceHolder[target=mm_47]
# %getitem_84 : Tensor "f32[1, 16, 327680][5242880, 327680, 1]cuda:0" = PlaceHolder[target=getitem_84]
# %buf86 : Tensor "f32[1, 16, 327680][5242880, 327680, 1]cuda:0" = PlaceHolder[target=buf86]
# %primals_18 : Tensor "i32[1, 1, 2560][2560, 2560, 1]cuda:0" = PlaceHolder[target=primals_18]
# %primals_17 : Tensor "i32[1, 1, 2560, s91][2560*s91, 2560*s91, s91, 1]cuda:0" = PlaceHolder[target=primals_17]
# %primals_19 : Tensor "i32[1, 1, 2560][2560, 2560, 1]cuda:0" = PlaceHolder[target=primals_19]
# %primals_21 : Tensor "i32[1, 1, 2560, s6][2560*s6, 2560*s6, s6, 1]cuda:0" = PlaceHolder[target=primals_21]
# %primals_14 : Tensor "i64[327680][1]cuda:0" = PlaceHolder[target=primals_14]
# %primals_15 : Tensor "i64[327680][1]cuda:0" = PlaceHolder[target=primals_15]
# %view_258 : Tensor "bf16[327680, 16, 128][2048, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_45, [327680, 16, 128]), kwargs = {})
# %permute_112 : Tensor "bf16[16, 327680, 128][128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_258, [1, 0, 2]), kwargs = {})
# %view_259 : Tensor "bf16[1, 16, 327680, 128][2048, 128, 2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_112, [1, 16, 327680, 128]), kwargs = {})
# %view_261 : Tensor "bf16[327680, 4, 128][512, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_46, [327680, 4, 128]), kwargs = {})
# %permute_114 : Tensor "bf16[4, 327680, 128][128, 512, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_261, [1, 0, 2]), kwargs = {})
# %view_262 : Tensor "bf16[1, 4, 327680, 128][512, 128, 512, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_114, [1, 4, 327680, 128]), kwargs = {})
# %view_264 : Tensor "bf16[327680, 4, 128][512, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_47, [327680, 4, 128]), kwargs = {})
# %permute_116 : Tensor "bf16[4, 327680, 128][128, 512, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_264, [1, 0, 2]), kwargs = {})
# %view_265 : Tensor "bf16[1, 4, 327680, 128][512, 128, 512, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_116, [1, 4, 327680, 128]), kwargs = {})
# %flex_attention_11 : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%view_259, %view_262, %view_265, %sdpa_score11, (327680, 327680, %primals_18, %primals_17, %primals_19, %primals_21, %primals_22, %primals_24, %primals_25, %primals_27, 128, 128, %sdpa_mask11), 0.08838834764831843, {BACKEND: AUTO, PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: True}, (), (%primals_14, %primals_15)), kwargs = {})
# return %getitem_83
triton_tem_fused_flex_attention_permute_view_15 = async_compile.triton('triton_tem_fused_flex_attention_permute_view_15', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
@triton_heuristics.template(
num_stages=3,
num_warps=8,
triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'in_ptr10': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]]}]},
inductor_meta={'kernel_name': 'triton_tem_fused_flex_attention_permute_view_15', 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': True, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
)
@triton.jit
def triton_tem_fused_flex_attention_permute_view_15(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, in_ptr10, out_ptr0, ks0):
PRESCALE_QK : tl.constexpr = False
ROWS_GUARANTEED_SAFE : tl.constexpr = False
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
WRITE_DQ : tl.constexpr = True
OUTPUT_LOGSUMEXP : tl.constexpr = True
OUTPUT_MAX : tl.constexpr = True
FLOAT32_PRECISION : tl.constexpr = 'tf32'
IS_DIVISIBLE : tl.constexpr = True
SM_SCALE : tl.constexpr = 0.08838834764831843
GQA_SHARED_HEADS : tl.constexpr = 4
HAS_FULL_BLOCKS : tl.constexpr = True
QK_HEAD_DIM : tl.constexpr = 128
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
V_HEAD_DIM : tl.constexpr = 128
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
SAFE_HEAD_DIM : tl.constexpr = True
USE_TMA : tl.constexpr = False
BLOCK_M : tl.constexpr = 128
BLOCK_N : tl.constexpr = 64
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
INDEX_DTYPE : tl.constexpr = tl.int32
Q = arg_Q
K = arg_K
V = arg_V
LSE = arg_LSE
MAX = arg_MAX
KV_NUM_BLKS = arg_KV_NUM_BLKS
KV_IDX = arg_KV_IDX
FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
FULL_KV_IDX = arg_FULL_KV_IDX
# Sub notation for this kernel:
#
# Q: Query, K: Key, V: Value
# M: Number of queries, N: Number of keys/values, D: Model dimension
# QK_HEAD_DIM: The dimension of the query and key embeddings
# V_HEAD_DIM: The dimension of the value embeddings
# z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
#
# The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
# FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
# FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
#
# OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
#
# (Modifiable) Performance tuning options
# BLOCK_M: The thread block size across the seqlen dim of Q.
# BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
# The below are kernel options that can be applied for certain score_mods,
# or involve a numerics vs. perf tradeoff
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
# about 20% more numerical error, but slightly faster.
# ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
# is not masked out? If so, we can skip an extra safety check
# BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
# contiguous? If so, we don't need to do an indirect jump for every block
tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
# Define strides of inputs
stride_qz, stride_qh, stride_qm, stride_qk = 2048, 128, 2048, 1
stride_kz, stride_kh, stride_kn, stride_kk = 512, 128, 512, 1
stride_vz, stride_vh, stride_vn, stride_vk = 512, 128, 512, 1
ZQ = 1
HQ = 16
Q_LEN = 327680
ZKV = 1
KV_LEN = 327680
MATMUL_PRECISION = Q.dtype.element_ty
q_start = tl.program_id(0).to(INDEX_DTYPE)
off_zq = tl.program_id(1).to(INDEX_DTYPE)
off_hq = tl.program_id(2).to(INDEX_DTYPE)
# We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
# b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
off_zkv = off_zq % ZKV
off_hkv = off_hq // GQA_SHARED_HEADS
off_g = off_hq % GQA_SHARED_HEADS
q_offset = off_zq * stride_qz + off_hq * stride_qh
k_offset = off_zkv * stride_kz + off_hkv * stride_kh
v_offset = off_zkv * stride_vz + off_hkv * stride_vh
Q = Q + q_offset
K = K + k_offset
V = V + v_offset
# Setting up the TMA descriptors for Q, K, V
desc_q = None
desc_k = None
desc_v = None
SPARSE_Z = 1
SPARSE_HQ = 1
sparse_idx_z = off_zq % SPARSE_Z
sparse_idx_hq = off_hq % SPARSE_HQ
SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
stride_kv_num_blks_h = 2560
stride_kv_idx_h = 2560*ks0
stride_kv_idx_m = ks0
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
# KV_IDX and KV_NUM_BLKS are always contiguous.
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
# ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# We don't know anything "special" about these blocks, so we need to apply
# both score_mod and mask_mod to it
kv_indices = KV_IDX + sparse_kv_idx_offset
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
# K and V pointers will be passed directly to forward_inner
offs_n = kv_start + tl.arange(0, BLOCK_N)
acc, l_i, m_i = forward_inner(
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, in_ptr10, out_ptr0, ks0,
q, K, V,
desc_k, desc_v, Q_LEN, KV_LEN,
acc, l_i, m_i,
off_zq, off_hq, offs_m[:, None], offs_n[None, :],
kv_start,
kv_indices, kv_num_blocks,
0, block_n_end,
MATMUL_PRECISION,
stride_kk, stride_kn, stride_vn, stride_vk,
IS_FULL_BLOCKS=False,
)
# ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# We know these blocks are guaranteed to be "full", so we don't need to
# apply mask_mod to them - only score_mod
if HAS_FULL_BLOCKS:
# FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
# K and V pointers will be passed directly to forward_inner
offs_n = kv_start + tl.arange(0, BLOCK_N)
acc, l_i, m_i = forward_inner(
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, in_ptr10, out_ptr0, ks0,
q, K, V,
desc_k, desc_v, Q_LEN, KV_LEN,
acc, l_i, m_i,
off_zq, off_hq, offs_m[:, None], offs_n[None, :],
kv_start,
kv_indices, kv_num_blocks,
0, block_n_end,
MATMUL_PRECISION,
stride_kk, stride_kn, stride_vn, stride_vk,
IS_FULL_BLOCKS=True,
)
# [Note] Handle fully masked out rows:
# Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
# We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
l_i = tl.where(l_i == 0.0, 1, l_i)
acc = acc / l_i[:, None]
idx_zq = tl.program_id(1).to(INDEX_DTYPE)
idx_hq = tl.program_id(2).to(INDEX_DTYPE)
idx_m = offs_m[:, None].to(INDEX_DTYPE)
idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
xindex = idx_d + 128*idx_m + 41943040*idx_hq + 671088640*idx_zq
tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 2048*idx_m, [BLOCK_M, V_HEAD_DIM_ROUNDED])), acc, mask)
if OUTPUT_LOGSUMEXP:
off_hz = off_zq * HQ + off_hq
l_ptrs = LSE + off_hz * Q_LEN + offs_m
lse = m_i + tl.math.log2(l_i)
if IS_DIVISIBLE:
tl.store(l_ptrs, lse)
else:
tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
if OUTPUT_MAX:
off_hz = off_zq * HQ + off_hq
max_ptrs = MAX + off_hz * Q_LEN + offs_m
if IS_DIVISIBLE:
tl.store(max_ptrs, m_i)
else:
tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
# Utility triton funcs
@triton.jit
def get_offset_for_next_block(
loop_iter, col_indices, total_blocks,
SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
BLOCKS_ARE_CONTIGUOUS: tl.constexpr
):
if BLOCKS_ARE_CONTIGUOUS:
return BLOCK
cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
return offset
@triton.jit
def get_bounded_indices(indices, max_len=None):
return indices % max_len if max_len is not None else indices
@triton.jit
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
if IS_DIVISIBLE and SAFE_HEAD_DIM:
return tl.load(block_ptr)
elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
else:
return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
@triton.jit
def load_checked_2d(
ptr,
offs_m,
offs_n,
stride_m,
stride_n,
IS_DIVISIBLE_M: tl.constexpr,
IS_DIVISIBLE_N: tl.constexpr,
M_LEN: tl.constexpr,
N_LEN: tl.constexpr,
):
# Calculate final pointer if strides are provided
if stride_m is not None and stride_n is not None:
ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
# Handle all masking cases
if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
else: # Both divisible
return tl.load(ptr)
# Common Imports
@triton.jit
def forward_block_mn(
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, in_ptr10, out_ptr0, ks0,
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
# accumulated values
acc, l_i, m_i,
# Offsets
off_z, off_h, offs_m, offs_n,
# Offsets needed for TMA loads
kv_start,
kv_offset,
MATMUL_PRECISION, RCP_LN2,
# Strides for K and V
stride_kk, stride_kn, stride_vn, stride_vk,
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
):
# Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
PRESCALE_QK : tl.constexpr = False
ROWS_GUARANTEED_SAFE : tl.constexpr = False
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
WRITE_DQ : tl.constexpr = True
OUTPUT_LOGSUMEXP : tl.constexpr = True
OUTPUT_MAX : tl.constexpr = True
FLOAT32_PRECISION : tl.constexpr = 'tf32'
IS_DIVISIBLE : tl.constexpr = True
SM_SCALE : tl.constexpr = 0.08838834764831843
GQA_SHARED_HEADS : tl.constexpr = 4
HAS_FULL_BLOCKS : tl.constexpr = True
QK_HEAD_DIM : tl.constexpr = 128
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
V_HEAD_DIM : tl.constexpr = 128
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
SAFE_HEAD_DIM : tl.constexpr = True
USE_TMA : tl.constexpr = False
BLOCK_M : tl.constexpr = 128
BLOCK_N : tl.constexpr = 64
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
INDEX_DTYPE : tl.constexpr = tl.int32
# -- load k --
# NB reversed order to since K is transposed
kv_base_offset = kv_start + kv_offset
# Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
k = tl.trans(k)
k = k.to(q.dtype)
# -- compute qk ---
qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
if not PRESCALE_QK:
qk *= SM_SCALE
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
# If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
# which is larger than the actual number of elements. To avoid access memory out of bound,
# we need to mask out the elements that are out of Q_LEN & KV_LEN.
m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
tmp0 = (qk)
post_mod_scores = tmp0
if CHECK_BLOCK_BOUNDARY:
# Mask out the elements that are out of the KV_LEN for non divisible seqlen.
post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
if not IS_FULL_BLOCKS:
tmp1 = (m)
tmp2 = (n)
tmp3 = tmp1 >= tmp2
tmp4 = tl.load(in_ptr9 + tmp1)
tmp5 = tl.full([1], 0, tl.int64)
tmp6 = tmp4 > tmp5
tmp7 = tl.load(in_ptr9 + tmp2)
tmp8 = tmp7 > tmp5
tmp9 = tmp6 & tmp8
tmp10 = tmp4 == tmp7
tmp11 = tmp9 & tmp10
tmp12 = tmp3 | tmp11
tmp13 = tl.load(in_ptr10 + tmp1)
tmp14 = tl.load(in_ptr10 + tmp2)
tmp15 = tmp13 == tmp14
tmp16 = tmp12 & tmp15
tmp17 = tl.full([1], -1, tl.int64)
tmp18 = tmp4 == tmp17
tmp19 = tmp7 == tmp17
tmp20 = tmp18 | tmp19
tmp21 = tmp20 == 0
tmp22 = tmp16 & tmp21
mask_mod_output = tmp22
if CHECK_BLOCK_BOUNDARY:
mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
# apply mask for partially unmasked blocks
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
if not PRESCALE_QK:
post_mod_scores *= RCP_LN2
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# -- compute scaling constant ---
m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
if not ROWS_GUARANTEED_SAFE:
masked_out_rows = (m_ij == float("-inf"))
m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
else:
m_ij_masked = m_ij
alpha = tl.math.exp2(m_i - m_ij_masked)
p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
# NB: l_i update is pulled up here since it's a bit faster
# NB: For headdim=256, it's faster to move it back down to after m_i =
# m_ij
l_i = l_i * alpha + tl.sum(p, 1)
# # -- scale and update acc --
acc = acc * alpha[:, None]
# Calculate offsets for V loading - reuse kv_base_offset from K loading
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
acc = tl.dot(p.to(MATMUL_PRECISION), v.to(q.dtype), acc, input_precision=FLOAT32_PRECISION)
# -- update m_i
m_i = m_ij
return acc, l_i, m_i
@triton.jit
def forward_inner(
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, in_ptr10, out_ptr0, ks0,
q, K, V,
desc_k, desc_v, Q_LEN, KV_LEN,
# accumulated values
acc, l_i, m_i,
# Offsets used as inputs to score_mod & mask_mod
# of size [BLOCK_M, BLOCK_N] or scalar.
off_z, off_h, offs_m, offs_n,
# Offsets needed for TMA loads
kv_start,
# blocksparse data
kv_indices, kv_num_blocks,
# start kv and end kv block
block_n_start, block_n_end,
MATMUL_PRECISION,
# Strides for K and V
stride_kk, stride_kn, stride_vn, stride_vk,
IS_FULL_BLOCKS,
):
# Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
PRESCALE_QK : tl.constexpr = False
ROWS_GUARANTEED_SAFE : tl.constexpr = False
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
WRITE_DQ : tl.constexpr = True
OUTPUT_LOGSUMEXP : tl.constexpr = True
OUTPUT_MAX : tl.constexpr = True
FLOAT32_PRECISION : tl.constexpr = 'tf32'
IS_DIVISIBLE : tl.constexpr = True
SM_SCALE : tl.constexpr = 0.08838834764831843
GQA_SHARED_HEADS : tl.constexpr = 4
HAS_FULL_BLOCKS : tl.constexpr = True
QK_HEAD_DIM : tl.constexpr = 128
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
V_HEAD_DIM : tl.constexpr = 128
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
SAFE_HEAD_DIM : tl.constexpr = True
USE_TMA : tl.constexpr = False
BLOCK_M : tl.constexpr = 128
BLOCK_N : tl.constexpr = 64
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
INDEX_DTYPE : tl.constexpr = tl.int32
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
RCP_LN2: tl.constexpr = 1.44269504
if PRESCALE_QK:
q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
kv_offset = 0
# loop over k, v and update accumulator until block_n_end
for start_n in range(block_n_start, block_n_end):
# Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
if IS_DIVISIBLE:
acc, l_i, m_i = forward_block_mn(
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, in_ptr10, out_ptr0, ks0,
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
# accumulated values
acc, l_i, m_i,
# Offsets
off_z, off_h, offs_m, offs_n,
# Offsets needed for TMA loads
kv_start,
kv_offset,
MATMUL_PRECISION, RCP_LN2,
# Strides for K and V
stride_kk, stride_kn, stride_vn, stride_vk,
IS_FULL_BLOCKS,
)
else:
# Benchmark shows even we applied mod & mask to each block for non divisible seqlen,
# it's on par or slightly faster than only applying to the last block in fwd.
# However, we choose different strategy for bwd, where we only apply mod & mask
# to the last block because it's faster a lot.
acc, l_i, m_i = forward_block_mn(
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, in_ptr10, out_ptr0, ks0,
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
# accumulated values
acc, l_i, m_i,
# Offsets
off_z, off_h, offs_m, offs_n,
# Offsets needed for TMA loads
kv_start,
kv_offset,
MATMUL_PRECISION, RCP_LN2,
# Strides for K and V
stride_kk, stride_kn, stride_vn, stride_vk,
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
)
offset = get_offset_for_next_block(
start_n, kv_indices, kv_num_blocks,
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
)
offs_n = offs_n + offset
kv_offset += offset
return acc, l_i, m_i
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/l2/cl2tvo5hcgbtcyeecfjmb2uuyw7bdiqs5eiobc6gbmdzikudnimp.py
# Topologically Sorted Source Nodes: [convert_element_type_447], Original ATen: [aten._to_copy]
# Source node to ATen node mapping:
# convert_element_type_447 => convert_element_type_447
# Graph fragment:
# %buf96 : Tensor = PlaceHolder[target=buf96]
# %convert_element_type_447 : Tensor "f32[2048, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%wait_tensor_8, torch.float32), kwargs = {})
# return %convert_element_type_447
triton_poi_fused__to_copy_16 = async_compile.triton('triton_poi_fused__to_copy_16', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 4194304},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_16', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 33554432}},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_16(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 4194304
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)[:]
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/6i/c6itkxmaq3yfex4sobzcrmuof6hwih6dimfyghykx2wj6eoen4gq.py
# Topologically Sorted Source Nodes: [q, k, v, view_283, view_284, permute_133, flex_attention_backward], Original ATen: [aten.view, aten.permute, flex_attention_backward]
# Source node to ATen node mapping:
# flex_attention_backward => flex_attention_backward
# k => permute_114, view_261, view_262
# permute_133 => permute_133
# q => permute_112, view_258, view_259
# v => permute_116, view_264, view_265
# view_283 => view_283
# view_284 => view_284
# Graph fragment:
# %getitem_83 : Tensor "bf16[1, 16, 327680, 128][671088640, 128, 2048, 1]cuda:0" = PlaceHolder[target=getitem_83]
# %mm_54 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_54]
# %buf98 : Tensor "bf16[1, 16, 327680][5242880, 1, 16]cuda:0" = PlaceHolder[target=buf98]
# %view_258 : Tensor "bf16[327680, 16, 128][2048, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_45, [327680, 16, 128]), kwargs = {})
# %permute_112 : Tensor "bf16[16, 327680, 128][128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_258, [1, 0, 2]), kwargs = {})
# %view_259 : Tensor "bf16[1, 16, 327680, 128][2048, 128, 2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_112, [1, 16, 327680, 128]), kwargs = {})
# %view_261 : Tensor "bf16[327680, 4, 128][512, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_46, [327680, 4, 128]), kwargs = {})
# %permute_114 : Tensor "bf16[4, 327680, 128][128, 512, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_261, [1, 0, 2]), kwargs = {})
# %view_262 : Tensor "bf16[1, 4, 327680, 128][512, 128, 512, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_114, [1, 4, 327680, 128]), kwargs = {})
# %view_264 : Tensor "bf16[327680, 4, 128][512, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_47, [327680, 4, 128]), kwargs = {})
# %permute_116 : Tensor "bf16[4, 327680, 128][128, 512, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_264, [1, 0, 2]), kwargs = {})
# %view_265 : Tensor "bf16[1, 4, 327680, 128][512, 128, 512, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_116, [1, 4, 327680, 128]), kwargs = {})
# %view_283 : Tensor "bf16[1, 327680, 2048][671088640, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_54, [1, 327680, 2048]), kwargs = {})
# %view_284 : Tensor "bf16[1, 327680, 16, 128][671088640, 2048, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_283, [1, 327680, 16, 128]), kwargs = {})
# %permute_133 : Tensor "bf16[1, 16, 327680, 128][671088640, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_284, [0, 2, 1, 3]), kwargs = {})
# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%view_259, %view_262, %view_265, %getitem_83, %getitem_84, %permute_133, None, %fw_graph0, %joint_graph0, (327680, 327680, %primals_18, %primals_17, %primals_19, %primals_21, %primals_22, %primals_24, %primals_25, %primals_27, 128, 128, %mask_graph0), 0.08838834764831843, {BACKEND: AUTO, PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: True}, (), (%primals_14, %primals_15)), kwargs = {})
# return %buf98,%buf99
triton_per_fused_flex_attention_backward_permute_view_17 = async_compile.triton('triton_per_fused_flex_attention_backward_permute_view_17', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 8388608, 'r0_': 128},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_flex_attention_backward_permute_view_17', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 2, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': True, 'tiling_scores': {'x': 0, 'r0_': 2684354560}}
)
@triton.jit
def triton_per_fused_flex_attention_backward_permute_view_17(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr):
xnumel = 5242880
r0_numel = 128
R0_BLOCK: tl.constexpr = 128
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
r0_index = tl.arange(0, R0_BLOCK)[None, :]
r0_offset = 0
r0_mask = tl.full([R0_BLOCK], True, tl.int1)[None, :]
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
x0 = xindex
x2 = (xindex % 16)
x3 = xindex // 16
tmp0 = tl.load(in_ptr0 + (r0_1 + 128*x0), None).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r0_1 + 128*x0), None).to(tl.float32)
tmp2 = tmp0 * tmp1
tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
tmp5 = tl.sum(tmp3, 1)[:, None].to(tl.float32)
tmp6 = tmp5.to(tl.float32)
tl.store(out_ptr1 + (x3 + 327680*x2), tmp6, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/xd/cxd2dykhtsl345elda55jiiq6ml5mrbsssan4gy5n6s4njycn65u.py
# Topologically Sorted Source Nodes: [q, k, v, view_283, view_284, permute_133, flex_attention_backward], Original ATen: [aten.view, aten.permute, flex_attention_backward]
# Source node to ATen node mapping:
# flex_attention_backward => flex_attention_backward
# k => permute_114, view_261, view_262
# permute_133 => permute_133
# q => permute_112, view_258, view_259
# v => permute_116, view_264, view_265
# view_283 => view_283
# view_284 => view_284
# Graph fragment:
# %mm_45 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_45]
# %mm_46 : Tensor "bf16[327680, 512][512, 1]cuda:0" = PlaceHolder[target=mm_46]
# %mm_47 : Tensor "bf16[327680, 512][512, 1]cuda:0" = PlaceHolder[target=mm_47]
# %buf88 : Tensor = PlaceHolder[target=buf88]
# %buf99 : Tensor "f32[1, 16, 327680][5242880, 327680, 1]cuda:0" = PlaceHolder[target=buf99]
# %mm_54 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_54]
# %getitem_90 : Tensor "bf16[1, 16, 327680, 128][671088640, 128, 2048, 1]cuda:0" = PlaceHolder[target=getitem_90]
# %getitem_92 : Tensor "bf16[1, 4, 327680, 128][167772160, 128, 512, 1]cuda:0" = PlaceHolder[target=getitem_92]
# %primals_18 : Tensor "i32[1, 1, 2560][2560, 2560, 1]cuda:0" = PlaceHolder[target=primals_18]
# %primals_17 : Tensor "i32[1, 1, 2560, s91][2560*s91, 2560*s91, s91, 1]cuda:0" = PlaceHolder[target=primals_17]
# %primals_22 : Tensor "i32[1, 1, 2560][2560, 2560, 1]cuda:0" = PlaceHolder[target=primals_22]
# %primals_24 : Tensor "i32[1, 1, 2560, s16][2560*s16, 2560*s16, s16, 1]cuda:0" = PlaceHolder[target=primals_24]
# %primals_19 : Tensor "i32[1, 1, 2560][2560, 2560, 1]cuda:0" = PlaceHolder[target=primals_19]
# %primals_21 : Tensor "i32[1, 1, 2560, s6][2560*s6, 2560*s6, s6, 1]cuda:0" = PlaceHolder[target=primals_21]
# %primals_25 : Tensor "i32[1, 1, 2560][2560, 2560, 1]cuda:0" = PlaceHolder[target=primals_25]
# %primals_27 : Tensor "i32[1, 1, 2560, s18][2560*s18, 2560*s18, s18, 1]cuda:0" = PlaceHolder[target=primals_27]
# %primals_14 : Tensor "i64[327680][1]cuda:0" = PlaceHolder[target=primals_14]
# %primals_15 : Tensor "i64[327680][1]cuda:0" = PlaceHolder[target=primals_15]
# %view_258 : Tensor "bf16[327680, 16, 128][2048, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_45, [327680, 16, 128]), kwargs = {})
# %permute_112 : Tensor "bf16[16, 327680, 128][128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_258, [1, 0, 2]), kwargs = {})
# %view_259 : Tensor "bf16[1, 16, 327680, 128][2048, 128, 2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_112, [1, 16, 327680, 128]), kwargs = {})
# %view_261 : Tensor "bf16[327680, 4, 128][512, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_46, [327680, 4, 128]), kwargs = {})
# %permute_114 : Tensor "bf16[4, 327680, 128][128, 512, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_261, [1, 0, 2]), kwargs = {})
# %view_262 : Tensor "bf16[1, 4, 327680, 128][512, 128, 512, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_114, [1, 4, 327680, 128]), kwargs = {})
# %view_264 : Tensor "bf16[327680, 4, 128][512, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_47, [327680, 4, 128]), kwargs = {})
# %permute_116 : Tensor "bf16[4, 327680, 128][128, 512, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_264, [1, 0, 2]), kwargs = {})
# %view_265 : Tensor "bf16[1, 4, 327680, 128][512, 128, 512, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_116, [1, 4, 327680, 128]), kwargs = {})
# %view_283 : Tensor "bf16[1, 327680, 2048][671088640, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_54, [1, 327680, 2048]), kwargs = {})
# %view_284 : Tensor "bf16[1, 327680, 16, 128][671088640, 2048, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_283, [1, 327680, 16, 128]), kwargs = {})
# %permute_133 : Tensor "bf16[1, 16, 327680, 128][671088640, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_284, [0, 2, 1, 3]), kwargs = {})
# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%view_259, %view_262, %view_265, %getitem_83, %getitem_84, %permute_133, None, %fw_graph0, %joint_graph0, (327680, 327680, %primals_18, %primals_17, %primals_19, %primals_21, %primals_22, %primals_24, %primals_25, %primals_27, 128, 128, %mask_graph0), 0.08838834764831843, {BACKEND: AUTO, PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: True}, (), (%primals_14, %primals_15)), kwargs = {})
# return %getitem_91
triton_tem_fused_flex_attention_backward_permute_view_18 = async_compile.triton('triton_tem_fused_flex_attention_backward_permute_view_18', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
@triton_heuristics.template(
num_stages=3,
num_warps=8,
triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'in_ptr17': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]], (18,): [['tt.divisibility', 16]]}]},
inductor_meta={'kernel_name': 'triton_tem_fused_flex_attention_backward_permute_view_18', 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': True, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
)
@triton.jit
def triton_tem_fused_flex_attention_backward_permute_view_18(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, in_ptr17, out_ptr0, ks0, ks1):
PRESCALE_QK : tl.constexpr = False
ROWS_GUARANTEED_SAFE : tl.constexpr = False
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
WRITE_DQ : tl.constexpr = True
OUTPUT_LOGSUMEXP : tl.constexpr = True
OUTPUT_MAX : tl.constexpr = True
FLOAT32_PRECISION : tl.constexpr = 'tf32'
IS_DIVISIBLE : tl.constexpr = True
SM_SCALE : tl.constexpr = 0.08838834764831843
GQA_SHARED_HEADS : tl.constexpr = 4
HAS_FULL_BLOCKS : tl.constexpr = True
QK_HEAD_DIM : tl.constexpr = 128
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
V_HEAD_DIM : tl.constexpr = 128
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
SAFE_HEAD_DIM : tl.constexpr = True
USE_TMA : tl.constexpr = False
BLOCK_M1 : tl.constexpr = 64
BLOCK_N1 : tl.constexpr = 128
BLOCK_M2 : tl.constexpr = 128
BLOCK_N2 : tl.constexpr = 64
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
INDEX_DTYPE : tl.constexpr = tl.int32
Q = arg_Q
K = arg_K
V = arg_V
LSE = arg_LSE
DELTA = arg_DELTA
DO = arg_DO
DQ = arg_DQ
DV = arg_DV
KV_NUM_BLKS = arg_KV_NUM_BLKS
KV_IDX = arg_KV_IDX
Q_NUM_BLKS = arg_Q_NUM_BLKS
Q_IDX = arg_Q_IDX
FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
FULL_KV_IDX = arg_FULL_KV_IDX
FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
FULL_Q_IDX = arg_FULL_Q_IDX
# Sub notation for this kernel:
#
# Q: Query, K: Key, V: Value
# LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
# DELTA: Precomputed sum(OUT*DO, axis=-1)
# DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
# DK: Derivative of Key, is the written to via the store_output call due to some limitations with
# inductor codegen
# M: Number of queries, N: Number of keys/values
# QK_HEAD_DIM: The dimension of the query and key embeddings
# V_HEAD_DIM: The dimension of the value embeddings
# z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
# (Modifiable) Performance tuning options
# BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
# BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
# BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
# BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
#
# The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
# Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
# Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
# FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
# FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
# FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
# FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
# The below are kernel options that can be applied for certain score_mods,
# or involve a numerics vs. perf tradeoff
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
# about 20% more numerical error, but slightly faster.
# Define strides of inputs
stride_qz, stride_qh, stride_qm, stride_qd = 2048, 128, 2048, 1
stride_kz, stride_kh, stride_kn, stride_kd = 512, 128, 512, 1
stride_vz, stride_vh, stride_vn, stride_vd = 512, 128, 512, 1
stride_doz, stride_doh, stride_dom, stride_dod = 671088640, 128, 2048, 1
stride_dqz, stride_dqh, stride_dqm, stride_dqd = 671088640, 128, 2048, 1
stride_dvz, stride_dvh, stride_dvm, stride_dvd = 167772160, 128, 512, 1
ZQ = 1
HQ = 16
HKV = 4
Q_LEN = 327680
ZKV = 1
KV_LEN = 327680
MATMUL_PRECISION = Q.dtype.element_ty
pid = tl.program_id(0).to(INDEX_DTYPE)
NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
off_zkv = off_zq % ZKV # kv batch idx
SPARSE_Z = 1
SPARSE_HQ = 1
sparse_idx_z = off_zq % SPARSE_Z
k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
# first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
# then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
# offset K, V, DV pointers for batch/kv-head
K += k_adj
V += v_adj
DV += dv_adj
RCP_LN2 = 1.44269504
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
if pid >= NUM_KV_BLOCKS:
off_pid = pid - NUM_KV_BLOCKS
# THIS BLOCK DOES DQ
SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
start_m2_block = off_pid % NUM_Q_BLOCKS
off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
stride_kv_num_blks_h = 2560
stride_kv_idx_h = 2560*ks0
stride_kv_idx_m = ks0
sparse_idx_hq2 = off_hq2 % SPARSE_HQ
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
# Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
Q2 = Q + q_adj2
DO2 = DO + do_adj2
# TODO: This does not work if DQ is not the same layout as Q (for example,
# if Q is broadcasted)
DQ2 = DQ + dq_adj2
LSE2 = LSE + off_chz2
DELTA2 = DELTA + off_chz2
dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
start_m2 = start_m2_block * BLOCK_M2
offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
desc_q = None
desc_do = None
desc_k_dq = None
desc_v_dq = None
# load Q and do: they stay in SRAM throughout the inner loop.
q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
if PRESCALE_QK:
q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
if IS_DIVISIBLE:
Di = tl.load(DELTA2 + offs_m2)
lse = tl.load(LSE2 + offs_m2)
else:
Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
lse = tl.where(lse == -float("inf"), 0.0, lse)
lse = lse[:, None]
# ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# KV_IDX and KV_NUM_BLKS are always contiguous.
kv_indices = KV_IDX + sparse_kv_idx_offset
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
dq = bwd_dq_inner(
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, in_ptr17, out_ptr0, ks0, ks1,
K, V, desc_k_dq, desc_v_dq, kv_start, start_m2,
dq, q, do, Di, lse,
off_zq, off_hq2,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION,
IS_FULL_BLOCKS=False,
)
if HAS_FULL_BLOCKS:
# ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
dq = bwd_dq_inner(
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, in_ptr17, out_ptr0, ks0, ks1,
K, V, desc_k_dq, desc_v_dq, kv_start, start_m2,
dq, q, do, Di, lse,
off_zq, off_hq2,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION,
IS_FULL_BLOCKS=True,
)
# Write back dQ.
dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
dq *= SM_SCALE
if IS_DIVISIBLE and SAFE_HEAD_DIM:
tl.store(dq_ptrs, dq)
else:
tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
else:
# THIS BLOCK DOES DK & DV
SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
pid_mask = pid // SPARSE_KV_MULTIPLE
stride_q_num_blks_h = 2560
stride_q_idx_h = 2560*ks1
stride_q_idx_n = ks1
dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
start_n1 = pid * BLOCK_N1
offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
# load K and V: they stay in SRAM throughout the inner loop.
k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
if PRESCALE_QK:
k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
for off_g in range(0, GQA_SHARED_HEADS):
off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
# Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
Q1 = Q + q_adj1
DO1 = DO + do_adj1
# TODO: This does not work if DQ is not the same layout as Q (for example,
# if Q is broadcasted)
LSE1 = LSE + off_chz1
DELTA1 = DELTA + off_chz1
sparse_idx_hq1 = off_hq1 % SPARSE_HQ
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
# ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Q_IDX and Q_NUM_BLKS are always contiguous.
q_indices = Q_IDX + sparse_q_idx_offset
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
offs_m1 = q_start + tl.arange(0, BLOCK_M1)
desc_q = None
desc_do = None
dk, dv = bwd_dkdv_inner(
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, in_ptr17, out_ptr0, ks0, ks1,
Q1, DO1, DELTA1, LSE1,
desc_q, desc_do, q_start,
dk, dv, k, v,
off_zq, off_hq1, offs_n1, offs_m1,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION,
IS_FULL_BLOCKS=False,
)
if HAS_FULL_BLOCKS:
# ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
q_indices = FULL_Q_IDX + sparse_q_idx_offset
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
offs_m1 = q_start + tl.arange(0, BLOCK_M1)
dk, dv = bwd_dkdv_inner(
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, in_ptr17, out_ptr0, ks0, ks1,
Q1, DO1, DELTA1, LSE1,
desc_q, desc_do, q_start,
dk, dv, k, v,
off_zq, off_hq1, offs_n1, offs_m1,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION,
IS_FULL_BLOCKS=True,
)
# Write back dV and dK.
dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
index_n = offs_n1[:, None]
index_k = offs_k[None, :]
index_v = offs_v[None, :]
if IS_DIVISIBLE and SAFE_HEAD_DIM:
tl.store(dv_ptrs, dv)
else:
tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
dk *= SM_SCALE
if SAFE_HEAD_DIM:
mask = index_n < KV_LEN
else:
mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
# first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
# then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
xindex = index_k + 128*index_n + 41943040*off_hkv + 167772160*off_zq
tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 512*index_n, [BLOCK_N1, QK_HEAD_DIM_ROUNDED])), dk, mask)
@triton.jit
def bwd_dq_inner(
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, in_ptr17, out_ptr0, ks0, ks1,
K, V, desc_k, desc_v, kv_start, start_m2,
dq, q, do, Di, lse,
off_z, off_hq,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION,
IS_FULL_BLOCKS,
):
PRESCALE_QK : tl.constexpr = False
ROWS_GUARANTEED_SAFE : tl.constexpr = False
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
WRITE_DQ : tl.constexpr = True
OUTPUT_LOGSUMEXP : tl.constexpr = True
OUTPUT_MAX : tl.constexpr = True
FLOAT32_PRECISION : tl.constexpr = 'tf32'
IS_DIVISIBLE : tl.constexpr = True
SM_SCALE : tl.constexpr = 0.08838834764831843
GQA_SHARED_HEADS : tl.constexpr = 4
HAS_FULL_BLOCKS : tl.constexpr = True
QK_HEAD_DIM : tl.constexpr = 128
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
V_HEAD_DIM : tl.constexpr = 128
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
SAFE_HEAD_DIM : tl.constexpr = True
USE_TMA : tl.constexpr = False
BLOCK_M1 : tl.constexpr = 64
BLOCK_N1 : tl.constexpr = 128
BLOCK_M2 : tl.constexpr = 128
BLOCK_N2 : tl.constexpr = 64
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
INDEX_DTYPE : tl.constexpr = tl.int32
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
RCP_LN2: tl.constexpr = 1.44269504
Q_LEN = 327680
KV_LEN = 327680
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
# BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
kv_offset = 0
for start_n in range(0, hi):
dq = bwd_dq_block_mn(
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, in_ptr17, out_ptr0, ks0, ks1,
dq, q, K, V, desc_k, desc_v, kv_start, kv_offset, start_m2,
do, Di, lse, Q_LEN, KV_LEN,
off_z, off_hq, offs_k, offs_v,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS,
)
# Increment pointers.
offset = get_offset_for_next_block(
start_n, kv_indices, sparse_kv_num_blocks,
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
)
kv_offset += offset
return dq
# One (BLOCK_M2 x BLOCK_N2) tile of the flex-attention dQ backward pass:
# recomputes the attention scores qk = q @ k^T for this tile, re-applies the
# generated mask_mod, recovers p = exp2(scores - lse), forms
# dS = p * (dP - Di), and accumulates dq += dS @ K.  Returns the updated dq
# accumulator.
@triton.jit
def bwd_dq_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, in_ptr17, out_ptr0, ks0, ks1,
    dq, q, K, V, desc_k, desc_v, kv_start, kv_offset, start_m2,
    do, Di, lse, Q_LEN, KV_LEN,
    off_z, off_hq, offs_k, offs_v,
    stride_kn, stride_kd, stride_vn, stride_vd,
    kv_indices, sparse_kv_num_blocks,
    MATMUL_PRECISION, RCP_LN2,
    IS_FULL_BLOCKS,
):
    # Compile-time constants inlined by the Inductor template expansion.
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = True
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = True
    SM_SCALE : tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    USE_TMA : tl.constexpr = False
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32
    # Tile coordinates: KV columns for this step, Q rows for this program.
    offs_n2 = kv_start + kv_offset + tl.arange(0, BLOCK_N2)
    offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
    kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
    vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
    # NB reversed order to since K is transposed
    kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
    qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
    if not PRESCALE_QK:
        qk *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    pre_mod_scores = qk
    n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
    # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
    # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary
    m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
    tmp0 = (qk)
    post_mod_scores = tmp0
    if not IS_DIVISIBLE:
        post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
    if not IS_FULL_BLOCKS:
        # Generated mask_mod for partially-masked blocks.  NOTE(review): the
        # comparisons below suggest in_ptr16 holds a per-token id where -1
        # disables the token, and in_ptr17 a second grouping id that must
        # match between query (m) and key (n) -- confirm against the
        # flex_attention mask_mod that produced this kernel.
        tmp1 = (m)
        tmp2 = (n)
        tmp3 = tmp1 >= tmp2
        tmp4 = tl.load(in_ptr16 + tmp1)
        tmp5 = tl.full([1], 0, tl.int64)
        tmp6 = tmp4 > tmp5
        tmp7 = tl.load(in_ptr16 + tmp2)
        tmp8 = tmp7 > tmp5
        tmp9 = tmp6 & tmp8
        tmp10 = tmp4 == tmp7
        tmp11 = tmp9 & tmp10
        tmp12 = tmp3 | tmp11
        tmp13 = tl.load(in_ptr17 + tmp1)
        tmp14 = tl.load(in_ptr17 + tmp2)
        tmp15 = tmp13 == tmp14
        tmp16 = tmp12 & tmp15
        tmp17 = tl.full([1], -1, tl.int64)
        tmp18 = tmp4 == tmp17
        tmp19 = tmp7 == tmp17
        tmp20 = tmp18 | tmp19
        tmp21 = tmp20 == 0
        tmp22 = tmp16 & tmp21
        mask_mod_output = tmp22
        # apply mask for partial masked block
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if not PRESCALE_QK:
        # Scale into log2 space so exp2 can be used below.
        post_mod_scores *= RCP_LN2
    p = tl.math.exp2(post_mod_scores - lse)
    # Compute dP and dS.
    # NB reversed order to since V is transposed
    vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
    dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
    ds = p * (dp - Di[:, None])
    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
    tmp23 = (ds)
    grad_scores = tmp23
    if not IS_DIVISIBLE:
        grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
    # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
    if WRITE_DQ:
        # scatter_mask is computed here by the template; unused in this
        # particular expansion (no captured-buffer gradient writes emitted).
        scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    ds = grad_scores
    if not IS_FULL_BLOCKS:
        # (grads) apply mask for partially unmasked block
        ds = tl.where(mask_mod_output, ds, 0.0)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    ds = ds.to(MATMUL_PRECISION)
    # Compute dQ.
    dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
    return dq
# Outer loop of the flex-attention dK/dV backward pass: walks the sparse Q
# blocks that attend to this KV tile (via q_indices) and delegates each
# (BLOCK_M1 x BLOCK_N1) tile to bwd_dkdv_block_mn, advancing offs_m1 with the
# per-iteration offset from the sparse block index table.  Returns the
# accumulated (dk, dv).
@triton.jit
def bwd_dkdv_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, in_ptr17, out_ptr0, ks0, ks1,
    Q, DO, DELTA, LSE, # pointers
    desc_q, desc_do, q_start,
    dk, dv, k, v,
    off_z, off_hq, offs_n1, offs_m1,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION,
    IS_FULL_BLOCKS,
):
    # Compile-time constants inlined by the Inductor template expansion.
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = True
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = True
    SM_SCALE : tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    USE_TMA : tl.constexpr = False
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32
    SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
    RCP_LN2: tl.constexpr = 1.44269504
    # Sequence lengths baked in at compile time for this graph.
    Q_LEN = 327680
    KV_LEN = 327680
    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
    # The minimum is needed to handle the case where we run with a super large
    # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
    hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
    offset_block = 0
    for start_m in range(0, hi):
        dk, dv = bwd_dkdv_block_mn(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, in_ptr17, out_ptr0, ks0, ks1,
            dk, dv, Q, k, v, DO, DELTA, LSE,
            desc_q, desc_do, q_start, offset_block, Q_LEN, KV_LEN,
            off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
            stride_qm, stride_qd, stride_dom, stride_dod,
            q_indices, sparse_q_num_blocks,
            MATMUL_PRECISION, RCP_LN2,
            IS_FULL_BLOCKS,
        )
        # Increment pointers.
        offset = get_offset_for_next_block(
            start_m, q_indices, sparse_q_num_blocks,
            SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
        )
        offs_m1 += offset
        offset_block += offset
    return dk, dv
# One (BLOCK_M1 x BLOCK_N1) tile of the flex-attention dK/dV backward pass:
# recomputes the transposed probabilities pT = exp2(k @ q^T - lse), accumulates
# dv += pT @ do, then forms dsT = pT * (dpT - Di) and accumulates
# dk += dsT @ q.  Returns the updated (dk, dv) accumulators.
@triton.jit
def bwd_dkdv_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, in_ptr17, out_ptr0, ks0, ks1,
    dk, dv, Q, k, v, DO, DELTA, LSE,
    desc_q, desc_do, q_start, offset, Q_LEN, KV_LEN,
    off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION, RCP_LN2,
    IS_FULL_BLOCKS,
):
    # Compile-time constants inlined by the Inductor template expansion.
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = True
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = True
    SM_SCALE : tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    USE_TMA : tl.constexpr = False
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32
    qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
    do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
    # NB reversed order since Q is transposed
    qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
    # Load LSE before computing qk to reduce pipeline stall.
    if IS_DIVISIBLE:
        lse = tl.load(LSE + offs_m1)
    else:
        lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
    # Fully-masked rows carry lse == -inf; zero them so exp2 below is finite.
    lse = tl.where(lse == -float("inf"), 0.0, lse)
    qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
    if not PRESCALE_QK:
        qkT *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
    # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim
    # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary
    n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
    pre_mod_scores = qkT
    tmp24 = (qkT)
    post_mod_scores = tmp24
    if not IS_DIVISIBLE:
        post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
    if not IS_FULL_BLOCKS:
        # Generated mask_mod (same logic as the dQ kernel, with m/n transposed).
        # NOTE(review): in_ptr16/in_ptr17 appear to be per-token id buffers
        # where -1 disables a token -- confirm against the source mask_mod.
        tmp25 = (m)
        tmp26 = (n)
        tmp27 = tmp25 >= tmp26
        tmp28 = tl.load(in_ptr16 + tmp25)
        tmp29 = tl.full([1], 0, tl.int64)
        tmp30 = tmp28 > tmp29
        tmp31 = tl.load(in_ptr16 + tmp26)
        tmp32 = tmp31 > tmp29
        tmp33 = tmp30 & tmp32
        tmp34 = tmp28 == tmp31
        tmp35 = tmp33 & tmp34
        tmp36 = tmp27 | tmp35
        tmp37 = tl.load(in_ptr17 + tmp25)
        tmp38 = tl.load(in_ptr17 + tmp26)
        tmp39 = tmp37 == tmp38
        tmp40 = tmp36 & tmp39
        tmp41 = tl.full([1], -1, tl.int64)
        tmp42 = tmp28 == tmp41
        tmp43 = tmp31 == tmp41
        tmp44 = tmp42 | tmp43
        tmp45 = tmp44 == 0
        tmp46 = tmp40 & tmp45
        mask_mod_output = tmp46
        # (grads) apply mask for fully masked block
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if not PRESCALE_QK:
        # Scale into log2 space so exp2 can be used below.
        post_mod_scores *= RCP_LN2
    pT = tl.math.exp2(post_mod_scores - lse[None, :])
    do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
    # Compute dV.
    ppT = pT
    dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
    if IS_DIVISIBLE:
        Di = tl.load(DELTA + offs_m1)
    else:
        Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
    # Compute dP and dS.
    dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
    dsT = pT * (dpT - Di[None, :])
    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
    tmp47 = (dsT)
    grad_scores = tmp47
    if not IS_DIVISIBLE:
        grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
    # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
    if not WRITE_DQ:
        # Template slot for captured-buffer gradient scatters; these indices
        # are computed but unused in this expansion (WRITE_DQ is True).
        idx_b = off_z
        idx_h = off_hq
        idx_m = m
        idx_n = n
        scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    dsT = grad_scores
    if not IS_FULL_BLOCKS:
        # (grads) apply mask for partially unmasked block
        dsT = tl.where(mask_mod_output, dsT, 0.0)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
    return dk, dv
# Utility triton funcs
@triton.jit
def get_offset_for_next_block(
    loop_iter, col_indices, total_blocks,
    SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr
):
    # How far to advance the pointers after one inner iteration.  With a
    # contiguous block layout every step is a fixed BLOCK stride; otherwise we
    # consult the sparse index table and, at the end of each sparse block,
    # jump to wherever the next sparse block lives.
    if BLOCKS_ARE_CONTIGUOUS:
        return BLOCK
    sparse_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
    this_block = tl.load(col_indices + sparse_idx, eviction_policy="evict_last")
    next_block = tl.load(col_indices + sparse_idx + 1, eviction_policy="evict_last", mask=sparse_idx + 1 < total_blocks)
    at_boundary = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    # Net displacement when crossing into the next sparse block: the gap
    # between block starts minus the strides already taken inside this block.
    jump = (next_block - this_block) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
    return jump * at_boundary + (1 - at_boundary) * BLOCK
@triton.jit
def get_bounded_indices(indices, max_len=None):
    # Wrap indices into [0, max_len) via modulo when a bound is supplied;
    # pass them through untouched otherwise.
    return indices if max_len is None else indices % max_len
@triton.jit
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
    # Load a 2D block, adding a zero-padded boundary check only on each axis
    # that is not statically known to divide evenly (axis 0 guarded by
    # IS_DIVISIBLE, axis 1 by SAFE_HEAD_DIM).  Branches resolve at compile
    # time since both flags are constexpr.
    if IS_DIVISIBLE:
        if SAFE_HEAD_DIM:
            return tl.load(block_ptr)
        return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
    if SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
    return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
@triton.jit
def load_checked_2d(
    ptr,
    offs_m,
    offs_n,
    stride_m,
    stride_n,
    IS_DIVISIBLE_M: tl.constexpr,
    IS_DIVISIBLE_N: tl.constexpr,
    M_LEN: tl.constexpr,
    N_LEN: tl.constexpr,
):
    # Masked 2D load helper.  When strides are given, the final address is
    # formed here from the row/column offsets; otherwise `ptr` is assumed to
    # already be fully indexed.  A zero-filled mask is applied on each axis
    # that is not statically divisible (flags are constexpr, so the branches
    # below are resolved at compile time).
    if stride_m is not None and stride_n is not None:
        ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
    if IS_DIVISIBLE_M and IS_DIVISIBLE_N:
        return tl.load(ptr)
    if IS_DIVISIBLE_M:
        return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
    if IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
    return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/x5/cx5sqi2rjoxy6l3kpblixq6vlpyrk2tnz5mjwgm76szawvechnfv.py
# Topologically Sorted Source Nodes: [convert_element_type_452], Original ATen: [aten._to_copy]
# Source node to ATen node mapping:
# convert_element_type_452 => convert_element_type_452
# Graph fragment:
# %buf110 : Tensor = PlaceHolder[target=buf110]
# %convert_element_type_452 : Tensor "f32[512, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%wait_tensor_9, torch.float32), kwargs = {})
# return %convert_element_type_452
triton_poi_fused__to_copy_19 = async_compile.triton('triton_poi_fused__to_copy_19', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 1048576},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_19', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 8388608}},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_19(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 1048576
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)[:]
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/2x/c2xyq7pyq5ybvlnjstymir6orquvs2w2dnuv2qnoitnx4xrvmhzx.py
# Topologically Sorted Source Nodes: [redistribute, layer_norm, add_115, add_116, convert_element_type_463, convert_element_type_465, mul_179, mul_180, sum_11, mul_181, sum_12, mul_182, sub_36, sub_37, div_5, mul_183, mul_184, sum_13, sum_14, convert_element_type_467, add_117], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.add, aten.native_layer_norm_backward]
# Source node to ATen node mapping:
# add_115 => add_115
# add_116 => add_116
# add_117 => add_117
# convert_element_type_463 => convert_element_type_463
# convert_element_type_465 => convert_element_type_465
# convert_element_type_467 => convert_element_type_467
# div_5 => div_5
# layer_norm => add_100, convert_element_type_376, mul_139, rsqrt_24, sub_25, var_mean_24
# mul_179 => mul_179
# mul_180 => mul_180
# mul_181 => mul_181
# mul_182 => mul_182
# mul_183 => mul_183
# mul_184 => mul_184
# redistribute => convert_element_type_374
# sub_36 => sub_36
# sub_37 => sub_37
# sum_11 => sum_11
# sum_12 => sum_12
# sum_13 => sum_13
# sum_14 => sum_14
# Graph fragment:
# %mm_56 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_56]
# %mm_58 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_58]
# %mm_60 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_60]
# %primals_155 : Tensor "f32[2048][1]cuda:0" = PlaceHolder[target=primals_155]
# %add_99 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=add_99]
# %getitem_82 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=getitem_82]
# %buf76 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=buf76]
# %sum_11 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=sum_11]
# %sum_12 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=sum_12]
# %add_114 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=add_114]
# %sub_37 : Tensor "f32[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=sub_37]
# %convert_element_type_374 : Tensor "bf16[2048][1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_155, torch.bfloat16), kwargs = {})
# %convert_element_type_376 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_99, torch.float32), kwargs = {})
# %var_mean_24 : [num_users=2] = call_function[target=torch.ops.aten.var_mean.correction](args = (%convert_element_type_376, [1]), kwargs = {correction: 0, keepdim: True})
# %add_100 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%getitem_81, 1e-05), kwargs = {})
# %rsqrt_24 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_100,), kwargs = {})
# %sub_25 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%convert_element_type_376, %getitem_82), kwargs = {})
# %mul_139 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_25, %rsqrt_24), kwargs = {})
# %add_115 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mm_56, %mm_58), kwargs = {})
# %add_116 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_115, %mm_60), kwargs = {})
# %convert_element_type_463 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_116, torch.float32), kwargs = {})
# %convert_element_type_465 : Tensor "f32[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_374, torch.float32), kwargs = {})
# %mul_179 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_463, %convert_element_type_465), kwargs = {})
# %mul_180 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_179, 2048), kwargs = {})
# %sum_11 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_179, [1], True), kwargs = {})
# %mul_181 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_179, %mul_139), kwargs = {})
# %sum_12 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_181, [1], True), kwargs = {})
# %mul_182 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_139, %sum_12), kwargs = {})
# %sub_36 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%mul_180, %sum_11), kwargs = {})
# %sub_37 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sub_36, %mul_182), kwargs = {})
# %div_5 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%rsqrt_24, 2048), kwargs = {})
# %mul_183 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_5, %sub_37), kwargs = {})
# %mul_184 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_463, %mul_139), kwargs = {})
# %sum_13 : Tensor "f32[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_184, [0]), kwargs = {})
# %sum_14 : Tensor "f32[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_463, [0]), kwargs = {})
# %convert_element_type_467 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_183, torch.bfloat16), kwargs = {})
# %add_117 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_114, %convert_element_type_467), kwargs = {})
# return %sum_11,%sum_12,%sub_37,%add_117
triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_20 = async_compile.triton('triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_20', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 524288, 'r0_': 2048},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*bf16', 'in_ptr5': '*fp32', 'in_ptr6': '*fp32', 'ws_ptr': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'RSPLIT_SIZE': 'constexpr', 'NUM_STAGES': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'MixOrderReductionGrid', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_20', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 8, 'num_store': -2, 'num_reduction': 2, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'RSPLIT_SIZE': 128, 'has_loadstore_with_contiguous_rdim': True}
)
@triton.jit
def triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_20(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, ws_ptr, xnumel, r0_numel, XBLOCK : tl.constexpr, RSPLIT_SIZE : tl.constexpr, NUM_STAGES : tl.constexpr):
xnumel = 327680
r0_numel = 2048
R0_BLOCK: tl.constexpr = 2048
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * RSPLIT_SIZE
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
r0_index = tl.arange(0, R0_BLOCK)[None, :]
r0_offset = 0
r0_mask = tl.full([R0_BLOCK], True, tl.int1)[None, :]
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
accum0 = tl.full([R0_BLOCK], 0, tl.float32)[None, :]
accum1 = tl.full([R0_BLOCK], 0, tl.float32)[None, :]
split_size = min(RSPLIT_SIZE, xnumel - xoffset)
for _ in tl.range(0, split_size, XBLOCK, num_stages=NUM_STAGES):
x0 = xindex
xindex += XBLOCK
tmp0 = tl.load(in_ptr0 + (r0_1 + 2048*x0), None).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r0_1 + 2048*x0), None).to(tl.float32)
tmp3 = tl.load(in_ptr2 + (r0_1 + 2048*x0), None).to(tl.float32)
tmp6 = tl.load(in_ptr3 + (r0_1), None, eviction_policy='evict_last')
tmp13 = tl.load(in_ptr4 + (r0_1 + 2048*x0), None).to(tl.float32)
tmp15 = tl.load(in_ptr5 + (x0), None, eviction_policy='evict_last')
tmp17 = tl.load(in_ptr6 + (x0), None, eviction_policy='evict_last')
tmp32 = tl.load(in_out_ptr0 + (r0_1 + 2048*x0), None).to(tl.float32)
tmp2 = tmp0 + tmp1
tmp4 = tmp2 + tmp3
tmp5 = tmp4.to(tl.float32)
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp7.to(tl.float32)
tmp9 = tmp5 * tmp8
tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK])
tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32)
tmp14 = tmp13.to(tl.float32)
tmp16 = tmp14 - tmp15
tmp18 = tl.full([1, 1], 2048.0, tl.float32)
tmp19 = (tmp17 / tmp18)
tmp20 = tl.full([1, 1], 1e-05, tl.float32)
tmp21 = tmp19 + tmp20
tmp22 = libdevice.rsqrt(tmp21)
tmp23 = tmp16 * tmp22
tmp24 = tmp9 * tmp23
tmp25 = tl.broadcast_to(tmp24, [XBLOCK, R0_BLOCK])
tmp27 = tl.sum(tmp25, 1)[:, None].to(tl.float32)
tmp28 = tmp9 * tmp18
tmp29 = tmp28 - tmp12
tmp30 = tmp23 * tmp27
tmp31 = tmp29 - tmp30
tmp33 = tl.full([1, 1], 0.00048828125, tl.float32)
tmp34 = tmp22 * tmp33
tmp35 = tmp34 * tmp31
tmp36 = tmp35.to(tl.float32)
tmp37 = tmp32 + tmp36
tmp38 = tmp5 * tmp23
tl.store(in_out_ptr0 + (r0_1 + 2048*x0), tmp37, None)
tmp39 = tl.sum(tmp38, 0)
tmp40 = accum0 + tmp39
accum0 = tmp40
tmp41 = tl.sum(tmp5, 0)
tmp42 = accum1 + tmp41
accum1 = tmp42
tl.store(ws_ptr + (tl.program_id(0) + 0 * tl.num_programs(0)) * r0_numel + r0_index, accum0, r0_mask)
tl.store(ws_ptr + (tl.program_id(0) + 1 * tl.num_programs(0)) * r0_numel + r0_index, accum1, r0_mask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/gq/cgqx5tvvnyzs57gxny2ebeaaudmxya3lic2feiq6op2k26gzmj56.py
# Topologically Sorted Source Nodes: [getitem, mask, getitem_1, data, getitem_3, contiguous_1], Original ATen: [aten.select, aten.eq, aten.unsqueeze, aten.mul, aten.slice, aten.clone]
# Source node to ATen node mapping:
# contiguous_1 => clone_3
# data => mul_7
# getitem => select
# getitem_1 => unsqueeze_6
# getitem_3 => slice_19
# mask => eq_1
# Graph fragment:
# %primals_1 : Tensor "u8[327680, 785][785, 1]cuda:0" = PlaceHolder[target=primals_1]
# %select : Tensor "u8[327680][785]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.select.int](args = (%primals_1, 1, -1), kwargs = {})
# %eq_1 : Tensor "b8[327680][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%select, 2), kwargs = {})
# %unsqueeze_6 : Tensor "b8[327680, 1][1, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%eq_1, 1), kwargs = {})
# %mul_7 : Tensor "u8[327680, 785][785, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_1, %unsqueeze_6), kwargs = {})
# %slice_19 : Tensor "u8[327680, 16][785, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_7, 1, 768, 784), kwargs = {})
# %clone_3 : Tensor "u8[327680, 16][16, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_19,), kwargs = {memory_format: torch.contiguous_format})
# return %clone_3
triton_poi_fused_clone_eq_mul_select_slice_unsqueeze_21 = async_compile.triton('triton_poi_fused_clone_eq_mul_select_slice_unsqueeze_21', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 8388608},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*u8', 'out_ptr0': '*u8', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_eq_mul_select_slice_unsqueeze_21', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 2, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 15728640}},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_clone_eq_mul_select_slice_unsqueeze_21(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 5242880
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)[:]
x0 = (xindex % 16)
x1 = xindex // 16
x2 = xindex
tmp0 = tl.load(in_ptr0 + (768 + x0 + 785*x1), None)
tmp1 = tl.load(in_ptr0 + (784 + 785*x1), None, eviction_policy='evict_last')
tmp2 = tl.full([1], 2, tl.uint8)
tmp3 = tmp1 == tmp2
tmp4 = tmp3.to(tl.uint8)
tmp5 = tmp0 * tmp4
tl.store(out_ptr0 + (x2), tmp5, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/io/cioh57ngo45ltcro4ob3mshvmir5ds5aetufnjfpio2asj7ih22f.py
# Topologically Sorted Source Nodes: [i, j, ifreqs, ifreqs_1, neg, freqs, getitem_6, getitem_7, mul_1, sin, cos, getitem_10, mul_3, sin_1, cos_1, posemb, to_1, layer_norm_1], Original ATen: [aten.select, aten.arange, aten.div, aten.neg, aten.pow, aten.unsqueeze, aten.mul, aten.sin, aten.cos, aten.cat, aten._to_copy, aten.native_layer_norm]
# Source node to ATen node mapping:
# cos => cos_1
# cos_1 => cos_2
# freqs => pow_2
# getitem_10 => unsqueeze_11
# getitem_6 => unsqueeze_7
# getitem_7 => unsqueeze_8
# i => select_4
# ifreqs => add_5, convert_element_type_16, iota_1, mul_10
# ifreqs_1 => div_2
# j => select_5
# layer_norm_1 => convert_element_type_20, var_mean_1
# mul_1 => mul_11
# mul_3 => mul_13
# neg => neg_1
# posemb => cat
# sin => sin_1
# sin_1 => sin_2
# to_1 => convert_element_type_17
# Graph fragment:
# %view_6 : Tensor "i32[327680, 4][4, 1]cuda:0" = PlaceHolder[target=view_6]
# %cat : Tensor "f32[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=cat]
# %select_4 : Tensor "i32[327680][4]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%view_6, 1, 0), kwargs = {})
# %select_5 : Tensor "i32[327680][4]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%view_6, 1, 1), kwargs = {})
# %iota_1 : Tensor "i64[512][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (512,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:0, requires_grad: False})
# %mul_10 : Tensor "i64[512][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%iota_1, 4), kwargs = {})
# %add_5 : Tensor "i64[512][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_10, 0), kwargs = {})
# %convert_element_type_16 : Tensor "f32[512][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_5, torch.float32), kwargs = {})
# %div_2 : Tensor "f32[512][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%convert_element_type_16, 2047), kwargs = {})
# %neg_1 : Tensor "f32[512][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%div_2,), kwargs = {})
# %pow_2 : Tensor "f32[512][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Scalar](args = (10000.0, %neg_1), kwargs = {})
# %unsqueeze_7 : Tensor "i32[327680, 1][4, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%select_4, 1), kwargs = {})
# %unsqueeze_8 : Tensor "f32[1, 512][512, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%pow_2, 0), kwargs = {})
# %mul_11 : Tensor "f32[327680, 512][512, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%unsqueeze_7, %unsqueeze_8), kwargs = {})
# %sin_1 : Tensor "f32[327680, 512][512, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sin.default](args = (%mul_11,), kwargs = {})
# %cos_1 : Tensor "f32[327680, 512][512, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cos.default](args = (%mul_11,), kwargs = {})
# %unsqueeze_11 : Tensor "i32[327680, 1][4, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%select_5, 1), kwargs = {})
# %mul_13 : Tensor "f32[327680, 512][512, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%unsqueeze_11, %unsqueeze_8), kwargs = {})
# %sin_2 : Tensor "f32[327680, 512][512, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sin.default](args = (%mul_13,), kwargs = {})
# %cos_2 : Tensor "f32[327680, 512][512, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cos.default](args = (%mul_13,), kwargs = {})
# %cat : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%sin_1, %cos_1, %sin_2, %cos_2], -1), kwargs = {})
# %convert_element_type_17 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%cat, torch.bfloat16), kwargs = {})
# %convert_element_type_20 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_17, torch.float32), kwargs = {})
# %var_mean_1 : [num_users=2] = call_function[target=torch.ops.aten.var_mean.correction](args = (%convert_element_type_20, [1]), kwargs = {correction: 0, keepdim: True})
# return %cat,%getitem_3,%buf1547
triton_red_fused__to_copy_arange_cat_cos_div_mul_native_layer_norm_neg_pow_select_sin_unsqueeze_22 = async_compile.triton('triton_red_fused__to_copy_arange_cat_cos_div_mul_native_layer_norm_neg_pow_select_sin_unsqueeze_22', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 524288, 'r0_': 2048},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_cat_cos_div_mul_native_layer_norm_neg_pow_select_sin_unsqueeze_22', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 4, 'num_store': 3, 'num_reduction': 2, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': True, 'tiling_scores': {'x': 5242880, 'r0_': 5368709120}}
)
@triton.jit
def triton_red_fused__to_copy_arange_cat_cos_div_mul_native_layer_norm_neg_pow_select_sin_unsqueeze_22(in_ptr0, out_ptr0, out_ptr1, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 327680
r0_numel = 2048
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x0 = xindex
tmp74_mean = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
tmp74_m2 = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
tmp74_weight = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
tmp0 = r0_1
tmp1 = tl.full([1, 1], 0, tl.int64)
tmp2 = tmp0 >= tmp1
tmp3 = tl.full([1, 1], 512, tl.int64)
tmp4 = tmp0 < tmp3
tmp5 = tl.load(in_ptr0 + (tl.broadcast_to(4*x0, [XBLOCK, R0_BLOCK])), r0_mask & tmp4, eviction_policy='evict_last', other=0.0)
tmp6 = tmp5.to(tl.float32)
tmp7 = tl.broadcast_to(4*(r0_1), [XBLOCK, R0_BLOCK])
tmp8 = tmp7.to(tl.float32)
tmp9 = tl.full([1, 1], 0.0004885197850512946, tl.float32)
tmp10 = tmp8 * tmp9
tmp11 = -tmp10
tmp12 = tl.full([1, 1], 10000.0, tl.float32)
tmp13 = libdevice.pow(tmp12, tmp11)
tmp14 = tmp6 * tmp13
tmp15 = tl_math.sin(tmp14)
tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype)
tmp17 = tl.where(tmp4, tmp15, tmp16)
tmp18 = tmp0 >= tmp3
tmp19 = tl.full([1, 1], 1024, tl.int64)
tmp20 = tmp0 < tmp19
tmp21 = tmp18 & tmp20
tmp22 = tl.load(in_ptr0 + (tl.broadcast_to(4*x0, [XBLOCK, R0_BLOCK])), r0_mask & tmp21, eviction_policy='evict_last', other=0.0)
tmp23 = tmp22.to(tl.float32)
tmp24 = tl.broadcast_to(4*((-512) + r0_1), [XBLOCK, R0_BLOCK])
tmp25 = tmp24.to(tl.float32)
tmp26 = tl.full([1, 1], 0.0004885197850512946, tl.float32)
tmp27 = tmp25 * tmp26
tmp28 = -tmp27
tmp29 = tl.full([1, 1], 10000.0, tl.float32)
tmp30 = libdevice.pow(tmp29, tmp28)
tmp31 = tmp23 * tmp30
tmp32 = tl_math.cos(tmp31)
tmp33 = tl.full(tmp32.shape, 0.0, tmp32.dtype)
tmp34 = tl.where(tmp21, tmp32, tmp33)
tmp35 = tmp0 >= tmp19
tmp36 = tl.full([1, 1], 1536, tl.int64)
tmp37 = tmp0 < tmp36
tmp38 = tmp35 & tmp37
tmp39 = tl.load(in_ptr0 + (tl.broadcast_to(1 + 4*x0, [XBLOCK, R0_BLOCK])), r0_mask & tmp38, eviction_policy='evict_last', other=0.0)
tmp40 = tmp39.to(tl.float32)
tmp41 = tl.broadcast_to(4*((-1024) + r0_1), [XBLOCK, R0_BLOCK])
tmp42 = tmp41.to(tl.float32)
tmp43 = tl.full([1, 1], 0.0004885197850512946, tl.float32)
tmp44 = tmp42 * tmp43
tmp45 = -tmp44
tmp46 = tl.full([1, 1], 10000.0, tl.float32)
tmp47 = libdevice.pow(tmp46, tmp45)
tmp48 = tmp40 * tmp47
tmp49 = tl_math.sin(tmp48)
tmp50 = tl.full(tmp49.shape, 0.0, tmp49.dtype)
tmp51 = tl.where(tmp38, tmp49, tmp50)
tmp52 = tmp0 >= tmp36
tmp53 = tl.full([1, 1], 2048, tl.int64)
tmp54 = tmp0 < tmp53
tmp55 = tl.load(in_ptr0 + (tl.broadcast_to(1 + 4*x0, [XBLOCK, R0_BLOCK])), r0_mask & tmp52, eviction_policy='evict_last', other=0.0)
tmp56 = tmp55.to(tl.float32)
tmp57 = tl.broadcast_to(4*((-1536) + r0_1), [XBLOCK, R0_BLOCK])
tmp58 = tmp57.to(tl.float32)
tmp59 = tl.full([1, 1], 0.0004885197850512946, tl.float32)
tmp60 = tmp58 * tmp59
tmp61 = -tmp60
tmp62 = tl.full([1, 1], 10000.0, tl.float32)
tmp63 = libdevice.pow(tmp62, tmp61)
tmp64 = tmp56 * tmp63
tmp65 = tl_math.cos(tmp64)
tmp66 = tl.full(tmp65.shape, 0.0, tmp65.dtype)
tmp67 = tl.where(tmp52, tmp65, tmp66)
tmp68 = tl.where(tmp38, tmp51, tmp67)
tmp69 = tl.where(tmp21, tmp34, tmp68)
tmp70 = tl.where(tmp4, tmp17, tmp69)
tmp71 = tmp70.to(tl.float32)
tmp72 = tmp71.to(tl.float32)
tmp73 = tl.broadcast_to(tmp72, [XBLOCK, R0_BLOCK])
tmp74_mean_next, tmp74_m2_next, tmp74_weight_next = triton_helpers.welford_reduce(
tmp73, tmp74_mean, tmp74_m2, tmp74_weight, roffset == 0
)
tmp74_mean = tl.where(r0_mask, tmp74_mean_next, tmp74_mean)
tmp74_m2 = tl.where(r0_mask, tmp74_m2_next, tmp74_m2)
tmp74_weight = tl.where(r0_mask, tmp74_weight_next, tmp74_weight)
tl.store(out_ptr0 + (r0_1 + 2048*x0), tmp70, r0_mask)
tmp75, tmp76, tmp77 = triton_helpers.welford(tmp74_mean, tmp74_m2, tmp74_weight, 1)
tmp74 = tmp75[:, None]
tmp78 = tmp76[:, None]
tmp79 = tmp77[:, None]
tl.store(out_ptr1 + (x0), tmp74, None)
tl.store(out_ptr2 + (x0), tmp78, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/5t/c5tw3okxvozf5cmhbwpwep3xfgakwg676mtpfbahctf6vbhiy2t6.py
# Topologically Sorted Source Nodes: [getitem, mask, getitem_1, data, getitem_2, patches, to, truediv, patches_1], Original ATen: [aten.select, aten.eq, aten.unsqueeze, aten.mul, aten.slice, aten.clone, aten._to_copy, aten.div, aten.sub]
# Source node to ATen node mapping:
# data => mul_7
# getitem => select
# getitem_1 => unsqueeze_6
# getitem_2 => slice_18
# mask => eq_1
# patches => clone_2
# patches_1 => sub
# to => convert_element_type_8
# truediv => div_1
# Graph fragment:
# %primals_1 : Tensor "u8[327680, 785][785, 1]cuda:0" = PlaceHolder[target=primals_1]
# %select : Tensor "u8[327680][785]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.select.int](args = (%primals_1, 1, -1), kwargs = {})
# %eq_1 : Tensor "b8[327680][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%select, 2), kwargs = {})
# %unsqueeze_6 : Tensor "b8[327680, 1][1, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%eq_1, 1), kwargs = {})
# %mul_7 : Tensor "u8[327680, 785][785, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_1, %unsqueeze_6), kwargs = {})
# %slice_18 : Tensor "u8[327680, 768][785, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_7, 1, 0, 768), kwargs = {})
# %clone_2 : Tensor "u8[327680, 768][768, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_18,), kwargs = {memory_format: torch.contiguous_format})
# %convert_element_type_8 : Tensor "bf16[327680, 768][768, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%clone_2, torch.bfloat16), kwargs = {})
# %div_1 : Tensor "bf16[327680, 768][768, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%convert_element_type_8, 127.5), kwargs = {})
# %sub : Tensor "bf16[327680, 768][768, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.sub.Tensor](args = (%div_1, 1.0), kwargs = {})
# return %sub
triton_poi_fused__to_copy_clone_div_eq_mul_select_slice_sub_unsqueeze_23 = async_compile.triton('triton_poi_fused__to_copy_clone_div_eq_mul_select_slice_sub_unsqueeze_23', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 268435456},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*u8', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_div_eq_mul_select_slice_sub_unsqueeze_23', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 2, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 1258291200}},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_div_eq_mul_select_slice_sub_unsqueeze_23(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 251658240
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)[:]
x0 = (xindex % 768)
x1 = xindex // 768
x2 = xindex
tmp0 = tl.load(in_ptr0 + (x0 + 785*x1), None)
tmp1 = tl.load(in_ptr0 + (784 + 785*x1), None, eviction_policy='evict_last')
tmp2 = tl.full([1], 2, tl.uint8)
tmp3 = tmp1 == tmp2
tmp4 = tmp3.to(tl.uint8)
tmp5 = tmp0 * tmp4
tmp6 = tmp5.to(tl.float32)
tmp7 = tl.full([1], 0.00784313725490196, tl.float32)
tmp8 = tmp6 * tmp7
tmp9 = tl.full([1], 1.0, tl.float32)
tmp10 = tmp8 - tmp9
tl.store(out_ptr0 + (x2), tmp10, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/cr/ccrpjojnaxipr3oeop55qzfsbpuwnmdmswoonbypn43u5pixgimi.py
# Topologically Sorted Source Nodes: [redistribute_1], Original ATen: [aten._to_copy]
# Source node to ATen node mapping:
# redistribute_1 => convert_element_type_9
# Graph fragment:
# %primals_4 : Tensor "f32[2048, 768][768, 1]cuda:0" = PlaceHolder[target=primals_4]
# %convert_element_type_9 : Tensor "bf16[2048, 768][768, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_4, torch.bfloat16), kwargs = {})
# return %convert_element_type_9
triton_poi_fused__to_copy_24 = async_compile.triton('triton_poi_fused__to_copy_24', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 2097152},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_24', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 12582912}},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_24(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 1572864
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)[:]
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/hh/chhxb3thm46b3idor7qfb5sr3aucc5huuj3uzxii67e2uiglboqx.py
# Topologically Sorted Source Nodes: [x], Original ATen: [aten.native_layer_norm]
# Source node to ATen node mapping:
# x => convert_element_type_14, var_mean
# Graph fragment:
# %mm : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm]
# %convert_element_type_14 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mm, torch.float32), kwargs = {})
# %var_mean : [num_users=2] = call_function[target=torch.ops.aten.var_mean.correction](args = (%convert_element_type_14, [1]), kwargs = {correction: 0, keepdim: True})
# return %getitem_1,%buf1572
triton_red_fused_native_layer_norm_25 = async_compile.triton('triton_red_fused_native_layer_norm_25', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 524288, 'r0_': 2048},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_native_layer_norm_25', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 2, 'num_reduction': 2, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': True, 'tiling_scores': {'x': 5242880, 'r0_': 1342177280}}
)
@triton.jit
def triton_red_fused_native_layer_norm_25(in_ptr0, out_ptr0, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 327680
r0_numel = 2048
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x0 = xindex
tmp3_mean = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
tmp3_m2 = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
tmp3_weight = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_1 + 2048*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp1 = tmp0.to(tl.float32)
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
tmp3_mean_next, tmp3_m2_next, tmp3_weight_next = triton_helpers.welford_reduce(
tmp2, tmp3_mean, tmp3_m2, tmp3_weight, roffset == 0
)
tmp3_mean = tl.where(r0_mask, tmp3_mean_next, tmp3_mean)
tmp3_m2 = tl.where(r0_mask, tmp3_m2_next, tmp3_m2)
tmp3_weight = tl.where(r0_mask, tmp3_weight_next, tmp3_weight)
tmp4, tmp5, tmp6 = triton_helpers.welford(tmp3_mean, tmp3_m2, tmp3_weight, 1)
tmp3 = tmp4[:, None]
tmp7 = tmp5[:, None]
tmp8 = tmp6[:, None]
tl.store(out_ptr0 + (x0), tmp3, None)
tl.store(out_ptr1 + (x0), tmp7, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/ri/cri52ajrwpssovjadiworgquy76migc7fywttykfs5xlisxg6tb7.py
# Topologically Sorted Source Nodes: [redistribute, layer_norm, add_203, add_204, convert_element_type_1046, convert_element_type_1048, mul_465, mul_466, sum_121, mul_467, sum_122, mul_468, sub_113, sub_114, div_27, mul_469, mul_470, sum_123, sum_124, convert_element_type_1050, add_205, getitem, mask, getitem_1, mul_471, convert_element_type_1055, to_1, layer_norm_1, mul_477, redistribute_2, convert_element_type_1065, mul_479, mul_480, sum_129, x, mul_481, sum_130, mul_482, sub_119, sub_120, div_28, mul_483, mul_484, convert_element_type_1067, mul_485, slice_21, full_default, copy_3], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.add, aten.native_layer_norm_backward, aten.select, aten.eq, aten.unsqueeze, aten.mul, aten.slice, aten.zeros_like, aten.copy]
# Source node to ATen node mapping:
# add_203 => add_203
# add_204 => add_204
# add_205 => add_205
# convert_element_type_1046 => convert_element_type_1046
# convert_element_type_1048 => convert_element_type_1048
# convert_element_type_1050 => convert_element_type_1050
# convert_element_type_1055 => convert_element_type_1055
# convert_element_type_1065 => convert_element_type_1065
# convert_element_type_1067 => convert_element_type_1067
# copy_3 => copy_3
# div_27 => div_27
# div_28 => div_28
# full_default => full_default
# getitem => select
# getitem_1 => unsqueeze, unsqueeze_6
# layer_norm => add_12, convert_element_type_24, mul_18, rsqrt_2, sub_3, var_mean_2
# layer_norm_1 => add_6, convert_element_type_20, mul_15, rsqrt_1, sub_2, var_mean_1
# mask => eq, eq_1
# mul_465 => mul_465
# mul_466 => mul_466
# mul_467 => mul_467
# mul_468 => mul_468
# mul_469 => mul_469
# mul_470 => mul_470
# mul_471 => mul_471
# mul_477 => mul_477
# mul_479 => mul_479
# mul_480 => mul_480
# mul_481 => mul_481
# mul_482 => mul_482
# mul_483 => mul_483
# mul_484 => mul_484
# mul_485 => mul_485
# redistribute => convert_element_type_22
# redistribute_2 => convert_element_type_12
# slice_21 => slice_21
# sub_113 => sub_113
# sub_114 => sub_114
# sub_119 => sub_119
# sub_120 => sub_120
# sum_121 => sum_121
# sum_122 => sum_122
# sum_123 => sum_123
# sum_124 => sum_124
# sum_129 => sum_129
# sum_130 => sum_130
# to_1 => convert_element_type_17
# x => add_3, convert_element_type_14, mul_8, rsqrt, sub_1, var_mean
# Graph fragment:
# %mm_188 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_188]
# %mm_190 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_190]
# %mm_192 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_192]
# %primals_9 : Tensor "f32[2048][1]cuda:0" = PlaceHolder[target=primals_9]
# %add_11 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=add_11]
# %getitem_5 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=getitem_5]
# %buf1473 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=buf1473]
# %sum_121 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=sum_121]
# %sum_122 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=sum_122]
# %add_202 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=add_202]
# %sub_114 : Tensor "f32[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=sub_114]
# %primals_1 : Tensor "u8[327680, 785][785, 1]cuda:0" = PlaceHolder[target=primals_1]
# %cat : Tensor "f32[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=cat]
# %getitem_3 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=getitem_3]
# %buf1547 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=buf1547]
# %primals_5 : Tensor "f32[2048][1]cuda:0" = PlaceHolder[target=primals_5]
# %mm : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm]
# %getitem_1 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=getitem_1]
# %buf1572 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=buf1572]
# %mul_479 : Tensor "f32[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=mul_479]
# %sum_129 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=sum_129]
# %sum_130 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=sum_130]
# %convert_element_type_22 : Tensor "bf16[2048][1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_9, torch.bfloat16), kwargs = {})
# %convert_element_type_24 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_11, torch.float32), kwargs = {})
# %var_mean_2 : [num_users=2] = call_function[target=torch.ops.aten.var_mean.correction](args = (%convert_element_type_24, [1]), kwargs = {correction: 0, keepdim: True})
# %add_12 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%getitem_4, 1e-05), kwargs = {})
# %rsqrt_2 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_12,), kwargs = {})
# %sub_3 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%convert_element_type_24, %getitem_5), kwargs = {})
# %mul_18 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_3, %rsqrt_2), kwargs = {})
# %add_203 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mm_188, %mm_190), kwargs = {})
# %add_204 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_203, %mm_192), kwargs = {})
# %convert_element_type_1046 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_204, torch.float32), kwargs = {})
# %convert_element_type_1048 : Tensor "f32[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_22, torch.float32), kwargs = {})
# %mul_465 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_1046, %convert_element_type_1048), kwargs = {})
# %mul_466 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_465, 2048), kwargs = {})
# %sum_121 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_465, [1], True), kwargs = {})
# %mul_467 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_465, %mul_18), kwargs = {})
# %sum_122 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_467, [1], True), kwargs = {})
# %mul_468 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_18, %sum_122), kwargs = {})
# %sub_113 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%mul_466, %sum_121), kwargs = {})
# %sub_114 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sub_113, %mul_468), kwargs = {})
# %div_27 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%rsqrt_2, 2048), kwargs = {})
# %mul_469 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_27, %sub_114), kwargs = {})
# %mul_470 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_1046, %mul_18), kwargs = {})
# %sum_123 : Tensor "f32[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_470, [0]), kwargs = {})
# %sum_124 : Tensor "f32[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_1046, [0]), kwargs = {})
# %convert_element_type_1050 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_469, torch.bfloat16), kwargs = {})
# %add_205 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_202, %convert_element_type_1050), kwargs = {})
# %select : Tensor "u8[327680][785]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.select.int](args = (%primals_1, 1, -1), kwargs = {})
# %eq_1 : Tensor "b8[327680][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%select, 2), kwargs = {})
# %unsqueeze_6 : Tensor "b8[327680, 1][1, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%eq_1, 1), kwargs = {})
# %mul_471 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_205, %unsqueeze_6), kwargs = {})
# %convert_element_type_1055 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_471, torch.float32), kwargs = {})
# %convert_element_type_17 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%cat, torch.bfloat16), kwargs = {})
# %convert_element_type_20 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_17, torch.float32), kwargs = {})
# %var_mean_1 : [num_users=2] = call_function[target=torch.ops.aten.var_mean.correction](args = (%convert_element_type_20, [1]), kwargs = {correction: 0, keepdim: True})
# %add_6 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%getitem_2, 1e-05), kwargs = {})
# %rsqrt_1 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_6,), kwargs = {})
# %sub_2 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%convert_element_type_20, %getitem_3), kwargs = {})
# %mul_15 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_2, %rsqrt_1), kwargs = {})
# %mul_477 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_1055, %mul_15), kwargs = {})
# %convert_element_type_12 : Tensor "bf16[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_5, torch.bfloat16), kwargs = {})
# %convert_element_type_1065 : Tensor "f32[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_12, torch.float32), kwargs = {})
# %mul_479 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_1055, %convert_element_type_1065), kwargs = {})
# %mul_480 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_479, 2048), kwargs = {})
# %sum_129 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_479, [1], True), kwargs = {})
# %convert_element_type_14 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mm, torch.float32), kwargs = {})
# %var_mean : [num_users=2] = call_function[target=torch.ops.aten.var_mean.correction](args = (%convert_element_type_14, [1]), kwargs = {correction: 0, keepdim: True})
# %add_3 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%getitem, 1e-05), kwargs = {})
# %rsqrt : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_3,), kwargs = {})
# %sub_1 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%convert_element_type_14, %getitem_1), kwargs = {})
# %mul_8 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_1, %rsqrt), kwargs = {})
# %mul_481 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_479, %mul_8), kwargs = {})
# %sum_130 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_481, [1], True), kwargs = {})
# %mul_482 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_8, %sum_130), kwargs = {})
# %sub_119 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%mul_480, %sum_129), kwargs = {})
# %sub_120 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sub_119, %mul_482), kwargs = {})
# %div_28 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%rsqrt, 2048), kwargs = {})
# %mul_483 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_28, %sub_120), kwargs = {})
# %mul_484 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_1055, %mul_8), kwargs = {})
# %convert_element_type_1067 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_483, torch.bfloat16), kwargs = {})
# %eq : Tensor "b8[327680][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%select, 1), kwargs = {})
# %unsqueeze : Tensor "b8[327680, 1][1, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%eq, 1), kwargs = {})
# %mul_485 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_205, %unsqueeze), kwargs = {})
# %slice_21 : Tensor "bf16[327680, 1024][2048, 2]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_485, 1, 1, 9223372036854775807, 2), kwargs = {})
# %full_default : Tensor "bf16[327680, 1024][1024, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.full.default](args = ([327680, 1024], 0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:0, pin_memory: False})
# %copy_3 : Tensor "bf16[327680, 1024][2048, 2]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.copy.default](args = (%slice_21, %full_default), kwargs = {})
# %slice_scatter_default : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%mul_485, %copy_3, 1, 1, 9223372036854775807, 2), kwargs = {})
# return %sum_121,%sum_122,%sub_114,%mul_477,%mul_479,%mul_484,%slice_scatter_default,%sum_129,%sum_130,%convert_element_type_1067
triton_per_fused__to_copy_add_copy_eq_mul_native_layer_norm_native_layer_norm_backward_select_slice_unsqueeze_zeros_like_26 = async_compile.triton('triton_per_fused__to_copy_add_copy_eq_mul_native_layer_norm_native_layer_norm_backward_select_slice_unsqueeze_zeros_like_26', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 524288, 'r0_': 2048},
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {'in_out_ptr0': '*fp32', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*bf16', 'in_ptr5': '*fp32', 'in_ptr6': '*fp32', 'in_ptr7': '*bf16', 'in_ptr8': '*u8', 'in_ptr9': '*fp32', 'in_ptr10': '*fp32', 'in_ptr11': '*fp32', 'in_ptr12': '*bf16', 'in_ptr13': '*fp32', 'in_ptr14': '*fp32', 'out_ptr2': '*fp32', 'out_ptr4': '*fp32', 'out_ptr5': '*bf16', 'out_ptr8': '*bf16', 'ws_ptr': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'RSPLIT_SIZE': 'constexpr', 'NUM_STAGES': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]], (18,): [['tt.divisibility', 16]], (19,): [['tt.divisibility', 16]], (20,): [['tt.divisibility', 16]], (21,): [['tt.divisibility', 16]], (22,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'MixOrderReductionGrid', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_copy_eq_mul_native_layer_norm_native_layer_norm_backward_select_slice_unsqueeze_zeros_like_26', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 16, 'num_store': 0, 'num_reduction': 4, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'RSPLIT_SIZE': 128, 'has_loadstore_with_contiguous_rdim': True}
)
@triton.jit
def triton_per_fused__to_copy_add_copy_eq_mul_native_layer_norm_native_layer_norm_backward_select_slice_unsqueeze_zeros_like_26(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, in_ptr10, in_ptr11, in_ptr12, in_ptr13, in_ptr14, out_ptr2, out_ptr4, out_ptr5, out_ptr8, ws_ptr, xnumel, r0_numel, XBLOCK : tl.constexpr, RSPLIT_SIZE : tl.constexpr, NUM_STAGES : tl.constexpr):
xnumel = 327680
r0_numel = 2048
R0_BLOCK: tl.constexpr = 2048
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * RSPLIT_SIZE
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
r0_index = tl.arange(0, R0_BLOCK)[None, :]
r0_offset = 0
r0_mask = tl.full([R0_BLOCK], True, tl.int1)[None, :]
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
accum0 = tl.full([R0_BLOCK], 0, tl.float32)[None, :]
accum1 = tl.full([R0_BLOCK], 0, tl.float32)[None, :]
split_size = min(RSPLIT_SIZE, xnumel - xoffset)
for _ in tl.range(0, split_size, XBLOCK, num_stages=NUM_STAGES):
x0 = xindex
xindex += XBLOCK
tmp0 = tl.load(in_ptr0 + (r0_1 + 2048*x0), None).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r0_1 + 2048*x0), None).to(tl.float32)
tmp3 = tl.load(in_ptr2 + (r0_1 + 2048*x0), None).to(tl.float32)
tmp6 = tl.load(in_ptr3 + (r0_1), None, eviction_policy='evict_last')
tmp13 = tl.load(in_ptr4 + (r0_1 + 2048*x0), None).to(tl.float32)
tmp15 = tl.load(in_ptr5 + (x0), None, eviction_policy='evict_last')
tmp17 = tl.load(in_ptr6 + (x0), None, eviction_policy='evict_last')
tmp32 = tl.load(in_ptr7 + (r0_1 + 2048*x0), None).to(tl.float32)
tmp38 = tl.load(in_ptr8 + (784 + 785*x0), None, eviction_policy='evict_last')
tmp44 = tl.load(in_out_ptr0 + (r0_1 + 2048*x0), None)
tmp47 = tl.load(in_ptr9 + (x0), None, eviction_policy='evict_last')
tmp49 = tl.load(in_ptr10 + (x0), None, eviction_policy='evict_last')
tmp55 = tl.load(in_ptr11 + (r0_1), None, eviction_policy='evict_last')
tmp59 = tl.load(in_ptr12 + (r0_1 + 2048*x0), None).to(tl.float32)
tmp61 = tl.load(in_ptr13 + (x0), None, eviction_policy='evict_last')
tmp63 = tl.load(in_ptr14 + (x0), None, eviction_policy='evict_last')
tmp2 = tmp0 + tmp1
tmp4 = tmp2 + tmp3
tmp5 = tmp4.to(tl.float32)
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp7.to(tl.float32)
tmp9 = tmp5 * tmp8
tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK])
tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32)
tmp14 = tmp13.to(tl.float32)
tmp16 = tmp14 - tmp15
tmp18 = tl.full([1, 1], 2048.0, tl.float32)
tmp19 = (tmp17 / tmp18)
tmp20 = tl.full([1, 1], 1e-05, tl.float32)
tmp21 = tmp19 + tmp20
tmp22 = libdevice.rsqrt(tmp21)
tmp23 = tmp16 * tmp22
tmp24 = tmp9 * tmp23
tmp25 = tl.broadcast_to(tmp24, [XBLOCK, R0_BLOCK])
tmp27 = tl.sum(tmp25, 1)[:, None].to(tl.float32)
tmp28 = tmp9 * tmp18
tmp29 = tmp28 - tmp12
tmp30 = tmp23 * tmp27
tmp31 = tmp29 - tmp30
tmp33 = tl.full([1, 1], 0.00048828125, tl.float32)
tmp34 = tmp22 * tmp33
tmp35 = tmp34 * tmp31
tmp36 = tmp35.to(tl.float32)
tmp37 = tmp32 + tmp36
tmp39 = tl.full([1, 1], 2, tl.uint8)
tmp40 = tmp38 == tmp39
tmp41 = tmp40.to(tl.float32)
tmp42 = tmp37 * tmp41
tmp43 = tmp42.to(tl.float32)
tmp45 = tmp44.to(tl.float32)
tmp46 = tmp45.to(tl.float32)
tmp48 = tmp46 - tmp47
tmp50 = (tmp49 / tmp18)
tmp51 = tmp50 + tmp20
tmp52 = libdevice.rsqrt(tmp51)
tmp53 = tmp48 * tmp52
tmp54 = tmp43 * tmp53
tmp56 = tmp55.to(tl.float32)
tmp57 = tmp56.to(tl.float32)
tmp58 = tmp43 * tmp57
tmp60 = tmp59.to(tl.float32)
tmp62 = tmp60 - tmp61
tmp64 = (tmp63 / tmp18)
tmp65 = tmp64 + tmp20
tmp66 = libdevice.rsqrt(tmp65)
tmp67 = tmp62 * tmp66
tmp68 = tmp43 * tmp67
tmp69 = r0_1
tmp70 = tl.full([1, 1], 1, tl.int64)
tmp71 = tmp69 >= tmp70
tmp72 = (((-1) + r0_1) % 2)
tmp73 = tl.full([1, 1], 0, tl.int64)
tmp74 = tmp72 == tmp73
tmp75 = tmp71 & tmp74
tmp76 = tl.full([1, 1], 0.0, tl.float32)
tmp77 = tl.full(tmp76.shape, 0.0, tmp76.dtype)
tmp78 = tl.where(tmp75, tmp76, tmp77)
tmp79 = tl.full([1, 1], 1, tl.uint8)
tmp80 = tmp38 == tmp79
tmp81 = tmp80.to(tl.float32)
tmp82 = tmp37 * tmp81
tmp83 = tl.where(tmp75, tmp78, tmp82)
tmp84 = tl.broadcast_to(tmp58, [XBLOCK, R0_BLOCK])
tmp86 = tl.sum(tmp84, 1)[:, None].to(tl.float32)
tmp87 = tmp58 * tmp67
tmp88 = tl.broadcast_to(tmp87, [XBLOCK, R0_BLOCK])
tmp90 = tl.sum(tmp88, 1)[:, None].to(tl.float32)
tmp91 = tmp66 * tmp33
tmp92 = tmp58 * tmp18
tmp93 = tmp92 - tmp86
tmp94 = tmp67 * tmp90
tmp95 = tmp93 - tmp94
tmp96 = tmp91 * tmp95
tmp97 = tmp96.to(tl.float32)
tmp98 = tmp5 * tmp23
tl.store(out_ptr2 + (r0_1 + 2048*x0), tmp31, None)
tl.store(in_out_ptr0 + (r0_1 + 2048*x0), tmp54, None)
tl.store(out_ptr4 + (r0_1 + 2048*x0), tmp68, None)
tl.store(out_ptr5 + (r0_1 + 2048*x0), tmp83, None)
tl.store(out_ptr8 + (r0_1 + 2048*x0), tmp97, None)
tmp99 = tl.sum(tmp98, 0)
tmp100 = accum0 + tmp99
accum0 = tmp100
tmp101 = tl.sum(tmp5, 0)
tmp102 = accum1 + tmp101
accum1 = tmp102
tl.store(ws_ptr + (tl.program_id(0) + 0 * tl.num_programs(0)) * r0_numel + r0_index, accum0, r0_mask)
tl.store(ws_ptr + (tl.program_id(0) + 1 * tl.num_programs(0)) * r0_numel + r0_index, accum1, r0_mask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/j7/cj7bnwjnv456srouilgyxwlg3q2bznu3ssquf7zbuoyp7gjfxyik.py
# Topologically Sorted Source Nodes: [sum_127], Original ATen: [aten.native_layer_norm_backward]
# Source node to ATen node mapping:
# sum_127 => sum_127
# Graph fragment:
# %mul_477 : Tensor "f32[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=mul_477]
# %sum_127 : Tensor "f32[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_477, [0]), kwargs = {})
# return %buf1550
triton_red_fused_native_layer_norm_backward_27 = async_compile.triton('triton_red_fused_native_layer_norm_backward_27', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 524288, 'r0_': 2048},
reduction_hint=ReductionHint.OUTER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_native_layer_norm_backward_27', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 2686976000, 'r0_': 0}}
)
@triton.jit
def triton_red_fused_native_layer_norm_backward_27(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 327680
r0_numel = 2048
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x0 = (xindex % 2048)
x1 = xindex // 2048
_tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
x3 = xindex
for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp0 = tl.load(in_ptr0 + (x0 + 2048*r0_2 + 4194304*x1), r0_mask, eviction_policy='evict_first', other=0.0)
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
tmp3 = _tmp2 + tmp1
_tmp2 = tl.where(r0_mask, tmp3, _tmp2)
tmp2 = tl.sum(_tmp2, 1)[:, None]
tl.store(out_ptr0 + (x3), tmp2, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/h6/ch6htjq4lzhko3lbglu5qfgog4kxolyrgb4ics57nzdmzydgmext.py
# Topologically Sorted Source Nodes: [sum_127, convert_element_type_1059, all_reduce_147], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
# Source node to ATen node mapping:
# all_reduce_147 => all_reduce_147
# convert_element_type_1059 => convert_element_type_1059
# sum_127 => sum_127
# Graph fragment:
# %buf1550 : Tensor "f32[2048, 160][1, 2048]cuda:0" = PlaceHolder[target=buf1550]
# %sum_127 : Tensor "f32[2048][1]cuda:0" = PlaceHolder[target=sum_127]
# %sum_127 : Tensor "f32[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_477, [0]), kwargs = {})
# %convert_element_type_1059 : Tensor "bf16[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_127, torch.bfloat16), kwargs = {})
# %all_reduce_147 : Tensor "bf16[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops._c10d_functional.all_reduce_.default](args = (%convert_element_type_1059, avg, 0), kwargs = {})
# return %sum_127,%wait_tensor_147
triton_red_fused_all_reduce_native_layer_norm_backward_28 = async_compile.triton('triton_red_fused_all_reduce_native_layer_norm_backward_28', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 2048, 'r0_': 256},
reduction_hint=ReductionHint.OUTER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_all_reduce_native_layer_norm_backward_28', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 1318912, 'r0_': 0}}
)
@triton.jit
def triton_red_fused_all_reduce_native_layer_norm_backward_28(in_ptr0, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 2048
r0_numel = 160
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x0 = xindex
_tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
tmp0 = tl.load(in_ptr0 + (x0 + 2048*r0_1), r0_mask & xmask, eviction_policy='evict_first', other=0.0)
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
tmp3 = _tmp2 + tmp1
_tmp2 = tl.where(r0_mask & xmask, tmp3, _tmp2)
tmp2 = tl.sum(_tmp2, 1)[:, None]
tmp4 = tmp2.to(tl.float32)
tl.store(out_ptr1 + (x0), tmp4, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/pa/cpa4kcho55jmuviwxacygdlktmp5752fifrqu64nhoefgqtwxj23.py
# Topologically Sorted Source Nodes: [layer_norm, div_27, mul_469, convert_element_type_1050, add_205, getitem, mask, getitem_1, mul_471, convert_element_type_1055, sum_128], Original ATen: [aten.native_layer_norm, aten.native_layer_norm_backward, aten.add, aten.select, aten.eq, aten.unsqueeze, aten.mul]
# Source node to ATen node mapping:
# add_205 => add_205
# convert_element_type_1050 => convert_element_type_1050
# convert_element_type_1055 => convert_element_type_1055
# div_27 => div_27
# getitem => select
# getitem_1 => unsqueeze_6
# layer_norm => add_12, convert_element_type_24, rsqrt_2, var_mean_2
# mask => eq_1
# mul_469 => mul_469
# mul_471 => mul_471
# sum_128 => sum_128
# Graph fragment:
# %add_202 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=add_202]
# %buf1473 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=buf1473]
# %sub_114 : Tensor "f32[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=sub_114]
# %primals_1 : Tensor "u8[327680, 785][785, 1]cuda:0" = PlaceHolder[target=primals_1]
# %convert_element_type_24 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_11, torch.float32), kwargs = {})
# %var_mean_2 : [num_users=2] = call_function[target=torch.ops.aten.var_mean.correction](args = (%convert_element_type_24, [1]), kwargs = {correction: 0, keepdim: True})
# %add_12 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%getitem_4, 1e-05), kwargs = {})
# %rsqrt_2 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_12,), kwargs = {})
# %div_27 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%rsqrt_2, 2048), kwargs = {})
# %mul_469 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_27, %sub_114), kwargs = {})
# %convert_element_type_1050 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_469, torch.bfloat16), kwargs = {})
# %add_205 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_202, %convert_element_type_1050), kwargs = {})
# %select : Tensor "u8[327680][785]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.select.int](args = (%primals_1, 1, -1), kwargs = {})
# %eq_1 : Tensor "b8[327680][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%select, 2), kwargs = {})
# %unsqueeze_6 : Tensor "b8[327680, 1][1, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%eq_1, 1), kwargs = {})
# %mul_471 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_205, %unsqueeze_6), kwargs = {})
# %convert_element_type_1055 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_471, torch.float32), kwargs = {})
# %sum_128 : Tensor "f32[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_1055, [0]), kwargs = {})
# return %buf1552
triton_red_fused_add_eq_mul_native_layer_norm_native_layer_norm_backward_select_unsqueeze_29 = async_compile.triton('triton_red_fused_add_eq_mul_native_layer_norm_native_layer_norm_backward_select_unsqueeze_29', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 524288, 'r0_': 2048},
reduction_hint=ReductionHint.OUTER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*u8', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_add_eq_mul_native_layer_norm_native_layer_norm_backward_select_unsqueeze_29', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 4, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': True, 'tiling_scores': {'x': 4029153280, 'r0_': 1310720}}
)
@triton.jit
def triton_red_fused_add_eq_mul_native_layer_norm_native_layer_norm_backward_select_unsqueeze_29(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 327680
r0_numel = 2048
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x0 = (xindex % 2048)
x1 = xindex // 2048
_tmp20 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
x3 = xindex
for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp0 = tl.load(in_ptr0 + (x0 + 2048*r0_2 + 4194304*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r0_2 + 2048*x1), r0_mask, eviction_policy='evict_last', other=0.0)
tmp9 = tl.load(in_ptr2 + (x0 + 2048*r0_2 + 4194304*x1), r0_mask, eviction_policy='evict_first', other=0.0)
tmp13 = tl.load(in_ptr3 + (784 + 785*r0_2 + 1607680*x1), r0_mask, eviction_policy='evict_last', other=0.0)
tmp2 = tl.full([1, 1], 2048.0, tl.float32)
tmp3 = (tmp1 / tmp2)
tmp4 = tl.full([1, 1], 1e-05, tl.float32)
tmp5 = tmp3 + tmp4
tmp6 = libdevice.rsqrt(tmp5)
tmp7 = tl.full([1, 1], 0.00048828125, tl.float32)
tmp8 = tmp6 * tmp7
tmp10 = tmp8 * tmp9
tmp11 = tmp10.to(tl.float32)
tmp12 = tmp0 + tmp11
tmp14 = tl.full([1, 1], 2, tl.uint8)
tmp15 = tmp13 == tmp14
tmp16 = tmp15.to(tl.float32)
tmp17 = tmp12 * tmp16
tmp18 = tmp17.to(tl.float32)
tmp19 = tl.broadcast_to(tmp18, [XBLOCK, R0_BLOCK])
tmp21 = _tmp20 + tmp19
_tmp20 = tl.where(r0_mask, tmp21, _tmp20)
tmp20 = tl.sum(_tmp20, 1)[:, None]
tl.store(out_ptr0 + (x3), tmp20, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/4f/c4fosyx7zmpcnmjxjiskxj5txwzq4rjvox5xazko4zt622kteyye.py
# Topologically Sorted Source Nodes: [convert_element_type_1061, convert_element_type_1070], Original ATen: [aten._to_copy]
# Source node to ATen node mapping:
# convert_element_type_1061 => convert_element_type_1061
# convert_element_type_1070 => convert_element_type_1070
# Graph fragment:
# %buf1558 : Tensor = PlaceHolder[target=buf1558]
# %convert_element_type_1061 : Tensor "f32[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%wait_tensor_146, torch.float32), kwargs = {})
# %convert_element_type_1070 : Tensor "f32[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%wait_tensor_146, torch.float32), kwargs = {})
# return %convert_element_type_1061,%convert_element_type_1070
triton_poi_fused__to_copy_30 = async_compile.triton('triton_poi_fused__to_copy_30', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 2048},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_30', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 2, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 32768}},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_30(in_ptr0, out_ptr0, out_ptr1, xnumel, XBLOCK : tl.constexpr):
xnumel = 2048
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x0), tmp1, xmask)
tl.store(out_ptr1 + (x0), tmp1, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/vv/cvvvlepxuobkrbn55fwahnnr5slz6gc3plw3ii5n7qpgwcdewpmm.py
# Topologically Sorted Source Nodes: [convert_element_type_1074], Original ATen: [aten._to_copy]
# Source node to ATen node mapping:
# convert_element_type_1074 => convert_element_type_1074
# Graph fragment:
# %buf1590 : Tensor = PlaceHolder[target=buf1590]
# %convert_element_type_1074 : Tensor "f32[2048, 768][768, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%wait_tensor_150, torch.float32), kwargs = {})
# return %convert_element_type_1074
triton_poi_fused__to_copy_31 = async_compile.triton('triton_poi_fused__to_copy_31', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 2097152},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_31', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 12582912}},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_31(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 1572864
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)[:]
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/pe/cpea3exprjpgbfkgsy3hdvwmv3wa542d5oqiyyahnqjznbrvarss.py
# Topologically Sorted Source Nodes: [layer_norm, div_27, mul_469, convert_element_type_1050, add_205, getitem, mask, getitem_1, mul_485, slice_21, clone_112, full_default_1], Original ATen: [aten.native_layer_norm, aten.native_layer_norm_backward, aten.add, aten.select, aten.eq, aten.unsqueeze, aten.mul, aten.slice, aten.clone, aten.slice_backward]
# Source node to ATen node mapping:
# add_205 => add_205
# clone_112 => clone_112
# convert_element_type_1050 => convert_element_type_1050
# div_27 => div_27
# full_default_1 => full_default_1
# getitem => select
# getitem_1 => unsqueeze
# layer_norm => add_12, convert_element_type_24, rsqrt_2, var_mean_2
# mask => eq
# mul_469 => mul_469
# mul_485 => mul_485
# slice_21 => slice_21
# Graph fragment:
# %add_202 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=add_202]
# %buf1473 : Tensor "f32[327680, 1][1, 327680]cuda:0" = PlaceHolder[target=buf1473]
# %sub_114 : Tensor "f32[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=sub_114]
# %primals_1 : Tensor "u8[327680, 785][785, 1]cuda:0" = PlaceHolder[target=primals_1]
# %convert_element_type_24 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_11, torch.float32), kwargs = {})
# %var_mean_2 : [num_users=2] = call_function[target=torch.ops.aten.var_mean.correction](args = (%convert_element_type_24, [1]), kwargs = {correction: 0, keepdim: True})
# %add_12 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%getitem_4, 1e-05), kwargs = {})
# %rsqrt_2 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_12,), kwargs = {})
# %div_27 : Tensor "f32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%rsqrt_2, 2048), kwargs = {})
# %mul_469 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_27, %sub_114), kwargs = {})
# %convert_element_type_1050 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_469, torch.bfloat16), kwargs = {})
# %add_205 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_202, %convert_element_type_1050), kwargs = {})
# %select : Tensor "u8[327680][785]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.select.int](args = (%primals_1, 1, -1), kwargs = {})
# %eq : Tensor "b8[327680][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%select, 1), kwargs = {})
# %unsqueeze : Tensor "b8[327680, 1][1, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%eq, 1), kwargs = {})
# %mul_485 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_205, %unsqueeze), kwargs = {})
# %slice_21 : Tensor "bf16[327680, 1024][2048, 2]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_485, 1, 1, 9223372036854775807, 2), kwargs = {})
# %clone_112 : Tensor "bf16[327680, 1024][1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_21,), kwargs = {memory_format: torch.contiguous_format})
# %full_default_1 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.full.default](args = ([327680, 2048], 0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:0, pin_memory: False})
# %slice_scatter_default_1 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_1, %clone_112, 1, 1, 9223372036854775807, 2), kwargs = {})
# return %slice_scatter_default_1
triton_poi_fused_add_clone_eq_mul_native_layer_norm_native_layer_norm_backward_select_slice_slice_backward_unsqueeze_32 = async_compile.triton('triton_poi_fused_add_clone_eq_mul_native_layer_norm_native_layer_norm_backward_select_slice_slice_backward_unsqueeze_32', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 1073741824},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*u8', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_clone_eq_mul_native_layer_norm_native_layer_norm_backward_select_slice_slice_backward_unsqueeze_32', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 4, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 2684354560}},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_clone_eq_mul_native_layer_norm_native_layer_norm_backward_select_slice_slice_backward_unsqueeze_32(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 671088640
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)[:]
x0 = (xindex % 2048)
x1 = xindex // 2048
x2 = xindex
tmp0 = x0
tmp1 = tl.full([1], 1, tl.int64)
tmp2 = tmp0 >= tmp1
tmp3 = (((-1) + x0) % 2)
tmp4 = tl.full([1], 0, tl.int64)
tmp5 = tmp3 == tmp4
tmp6 = tmp2 & tmp5
tmp7 = tl.load(in_ptr0 + (1 + 2*(triton_helpers.div_floor_integer((-1) + x0, 2)) + 2048*x1), tmp6, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp8 = tl.load(in_ptr1 + (x1), tmp6, eviction_policy='evict_last', other=0.0)
tmp9 = tl.full([1], 2048.0, tl.float32)
tmp10 = (tmp8 / tmp9)
tmp11 = tl.full([1], 1e-05, tl.float32)
tmp12 = tmp10 + tmp11
tmp13 = libdevice.rsqrt(tmp12)
tmp14 = tl.full([1], 0.00048828125, tl.float32)
tmp15 = tmp13 * tmp14
tmp16 = tl.load(in_ptr2 + (1 + 2*(triton_helpers.div_floor_integer((-1) + x0, 2)) + 2048*x1), tmp6, eviction_policy='evict_last', other=0.0)
tmp17 = tmp15 * tmp16
tmp18 = tmp17.to(tl.float32)
tmp19 = tmp7 + tmp18
tmp20 = tl.load(in_ptr3 + (784 + 785*x1), tmp6, eviction_policy='evict_last', other=0.0)
tmp21 = tl.full([1], 1, tl.uint8)
tmp22 = tmp20 == tmp21
tmp23 = tmp22.to(tl.float32)
tmp24 = tmp19 * tmp23
tmp25 = tl.full(tmp24.shape, 0.0, tmp24.dtype)
tmp26 = tl.where(tmp6, tmp24, tmp25)
tmp27 = tl.full([1], 0.0, tl.float32)
tmp28 = tl.where(tmp6, tmp26, tmp27)
tl.store(out_ptr0 + (x2), tmp28, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/33/c33os57d226grg6ristuoxyi6krgwdsthxorz2uor6zyzc6wuzww.py
# Topologically Sorted Source Nodes: [add_206, slice_24, clone_113, convert_element_type_1075, mul_486, sum_133], Original ATen: [aten.add, aten.slice, aten.clone, aten._to_copy, aten.mul, aten.sum]
# Source node to ATen node mapping:
# add_206 => add_206
# clone_113 => clone_113
# convert_element_type_1075 => convert_element_type_1075
# mul_486 => mul_486
# slice_24 => slice_24
# sum_133 => sum_133
# Graph fragment:
# %slice_scatter_default : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=slice_scatter_default]
# %slice_scatter_default_1 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=slice_scatter_default_1]
# %cos : Tensor "f32[327680, 1024][1024, 1]cuda:0" = PlaceHolder[target=cos]
# %add_206 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_default, %slice_scatter_default_1), kwargs = {})
# %slice_24 : Tensor "bf16[327680, 1024][2048, 2]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.slice.Tensor](args = (%add_206, 1, 1, 9223372036854775807, 2), kwargs = {})
# %clone_113 : Tensor "bf16[327680, 1024][1024, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.clone.default](args = (%slice_24,), kwargs = {memory_format: torch.contiguous_format})
# %convert_element_type_1075 : Tensor "f32[327680, 1024][1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%clone_113, torch.float32), kwargs = {})
# %mul_486 : Tensor "f32[327680, 1024][1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_1075, %cos), kwargs = {})
# %sum_133 : Tensor "f32[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_486,), kwargs = {})
# return %buf1595
triton_red_fused__to_copy_add_clone_mul_slice_sum_33 = async_compile.triton('triton_red_fused__to_copy_add_clone_mul_slice_sum_33', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 262144, 'r0_': 2048},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_clone_mul_slice_sum_33', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': True, 'tiling_scores': {'x': 1310720, 'r0_': 1342177280}}
)
@triton.jit
def triton_red_fused__to_copy_add_clone_mul_slice_sum_33(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 163840
r0_numel = 2048
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x0 = xindex
_tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
tmp0 = tl.load(in_ptr0 + (1 + 2*r0_1 + 4096*x0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (1 + 2*r0_1 + 4096*x0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp4 = tl.load(in_ptr2 + (r0_1 + 2048*x0), r0_mask, eviction_policy='evict_first', other=0.0)
tmp2 = tmp0 + tmp1
tmp3 = tmp2.to(tl.float32)
tmp5 = tmp3 * tmp4
tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK])
tmp8 = _tmp7 + tmp6
_tmp7 = tl.where(r0_mask, tmp8, _tmp7)
tmp7 = tl.sum(_tmp7, 1)[:, None]
tl.store(out_ptr0 + (x0), tmp7, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/sh/cshwswgpmiqr5iwybcqpwf4nwd2h47twq7yufqtiqf5ytlbgoe4o.py
# Topologically Sorted Source Nodes: [add_206, slice_24, clone_113, convert_element_type_1075, mul_486, sum_133], Original ATen: [aten.add, aten.slice, aten.clone, aten._to_copy, aten.mul, aten.sum]
# Source node to ATen node mapping:
# add_206 => add_206
# clone_113 => clone_113
# convert_element_type_1075 => convert_element_type_1075
# mul_486 => mul_486
# slice_24 => slice_24
# sum_133 => sum_133
# Graph fragment:
# %buf1595 : Tensor "f32[2560, 64][64, 1]cuda:0" = PlaceHolder[target=buf1595]
# %add_206 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_default, %slice_scatter_default_1), kwargs = {})
# %slice_24 : Tensor "bf16[327680, 1024][2048, 2]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.slice.Tensor](args = (%add_206, 1, 1, 9223372036854775807, 2), kwargs = {})
# %clone_113 : Tensor "bf16[327680, 1024][1024, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.clone.default](args = (%slice_24,), kwargs = {memory_format: torch.contiguous_format})
# %convert_element_type_1075 : Tensor "f32[327680, 1024][1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%clone_113, torch.float32), kwargs = {})
# %mul_486 : Tensor "f32[327680, 1024][1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_1075, %cos), kwargs = {})
# %sum_133 : Tensor "f32[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_486,), kwargs = {})
# return %buf1596
triton_per_fused__to_copy_add_clone_mul_slice_sum_34 = async_compile.triton('triton_per_fused__to_copy_add_clone_mul_slice_sum_34', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 4096, 'r0_': 64},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_clone_mul_slice_sum_34', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': True, 'tiling_scores': {'x': 20480, 'r0_': 655360}}
)
@triton.jit
def triton_per_fused__to_copy_add_clone_mul_slice_sum_34(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr):
xnumel = 2560
r0_numel = 64
R0_BLOCK: tl.constexpr = 64
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
r0_index = tl.arange(0, R0_BLOCK)[None, :]
r0_offset = 0
r0_mask = tl.full([R0_BLOCK], True, tl.int1)[None, :]
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
x0 = xindex
tmp0 = tl.load(in_ptr0 + (r0_1 + 64*x0), xmask, other=0.0)
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
tmp3 = tl.where(xmask, tmp1, 0)
tmp4 = tl.sum(tmp3, 1)[:, None].to(tl.float32)
tl.store(out_ptr0 + (x0), tmp4, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/bv/cbvk7epkxauekcwgof5v56sdrua37n7ccsnkbqppguepvf2mfyzm.py
# Topologically Sorted Source Nodes: [add_206, slice_24, clone_113, convert_element_type_1075, mul_486, sum_133, convert_element_type_1076], Original ATen: [aten.add, aten.slice, aten.clone, aten._to_copy, aten.mul, aten.sum]
# Source node to ATen node mapping:
# add_206 => add_206
# clone_113 => clone_113
# convert_element_type_1075 => convert_element_type_1075
# convert_element_type_1076 => convert_element_type_1076
# mul_486 => mul_486
# slice_24 => slice_24
# sum_133 => sum_133
# Graph fragment:
# %buf1596 : Tensor "f32[2560][1]cuda:0" = PlaceHolder[target=buf1596]
# %sum_133 : Tensor "f32[][]cuda:0" = PlaceHolder[target=sum_133]
# %add_206 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_default, %slice_scatter_default_1), kwargs = {})
# %slice_24 : Tensor "bf16[327680, 1024][2048, 2]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.slice.Tensor](args = (%add_206, 1, 1, 9223372036854775807, 2), kwargs = {})
# %clone_113 : Tensor "bf16[327680, 1024][1024, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.clone.default](args = (%slice_24,), kwargs = {memory_format: torch.contiguous_format})
# %convert_element_type_1075 : Tensor "f32[327680, 1024][1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%clone_113, torch.float32), kwargs = {})
# %mul_486 : Tensor "f32[327680, 1024][1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_1075, %cos), kwargs = {})
# %sum_133 : Tensor "f32[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_486,), kwargs = {})
# %convert_element_type_1076 : Tensor "bf16[][]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_133, torch.bfloat16), kwargs = {})
# return %sum_133,%wait_tensor_151
triton_red_fused__to_copy_add_clone_mul_slice_sum_35 = async_compile.triton('triton_red_fused__to_copy_add_clone_mul_slice_sum_35', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 1, 'r0_': 4096},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {'xnumel': 1}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_clone_mul_slice_sum_35', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': True, 'tiling_scores': {'r0_': 10240}}
)
@triton.jit
def triton_red_fused__to_copy_add_clone_mul_slice_sum_35(in_ptr0, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 1
r0_numel = 2560
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
_tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_0 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
tmp3 = _tmp2 + tmp1
_tmp2 = tl.where(r0_mask, tmp3, _tmp2)
tmp2 = tl.sum(_tmp2, 1)[:, None]
tmp4 = tmp2.to(tl.float32)
tl.store(out_ptr1 + (tl.full([1, 1], 0, tl.int32).broadcast_to(XBLOCK, 1)), tmp4, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/ul/cul436w24uow7x4gca343floynluerpavkg3l7at6smg35fvjsfe.py
# Topologically Sorted Source Nodes: [full_default, full_default_1, add_206, slice_24, clone_113, copy_5, slice_27, clone_114, copy_7, add_207], Original ATen: [aten.zeros_like, aten.slice_backward, aten.add, aten.slice, aten.clone, aten.copy]
# Source node to ATen node mapping:
# add_206 => add_206
# add_207 => add_207
# clone_113 => clone_113
# clone_114 => clone_114
# copy_5 => copy_5
# copy_7 => copy_7
# full_default => full_default
# full_default_1 => full_default_1
# slice_24 => slice_24
# slice_27 => slice_27
# Graph fragment:
# %slice_scatter_default : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=slice_scatter_default]
# %slice_scatter_default_1 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=slice_scatter_default_1]
# %full_default : Tensor "bf16[327680, 1024][1024, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.full.default](args = ([327680, 1024], 0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:0, pin_memory: False})
# %full_default_1 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.full.default](args = ([327680, 2048], 0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:0, pin_memory: False})
# %add_206 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_default, %slice_scatter_default_1), kwargs = {})
# %slice_24 : Tensor "bf16[327680, 1024][2048, 2]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.slice.Tensor](args = (%add_206, 1, 1, 9223372036854775807, 2), kwargs = {})
# %clone_113 : Tensor "bf16[327680, 1024][1024, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.clone.default](args = (%slice_24,), kwargs = {memory_format: torch.contiguous_format})
# %copy_5 : Tensor "bf16[327680, 1024][2048, 2]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.copy.default](args = (%slice_24, %clone_113), kwargs = {})
# %slice_scatter_default_2 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%add_206, %copy_5, 1, 1, 9223372036854775807, 2), kwargs = {})
# %slice_27 : Tensor "bf16[327680, 1024][2048, 2]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_scatter_default_2, 1, 0, 9223372036854775807, 2), kwargs = {})
# %clone_114 : Tensor "bf16[327680, 1024][1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_27,), kwargs = {memory_format: torch.contiguous_format})
# %copy_7 : Tensor "bf16[327680, 1024][2048, 2]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.copy.default](args = (%slice_27, %full_default), kwargs = {})
# %slice_scatter_default_3 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%slice_scatter_default_2, %copy_7, 1, 0, 9223372036854775807, 2), kwargs = {})
# %slice_scatter_default_4 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_1, %clone_114, 1, 0, 9223372036854775807, 2), kwargs = {})
# %add_207 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_default_3, %slice_scatter_default_4), kwargs = {})
# return %add_207
triton_poi_fused_add_clone_copy_slice_slice_backward_zeros_like_36 = async_compile.triton('triton_poi_fused_add_clone_copy_slice_slice_backward_zeros_like_36', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 1073741824},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_clone_copy_slice_slice_backward_zeros_like_36', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 8, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 10737418240}},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_clone_copy_slice_slice_backward_zeros_like_36(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 671088640
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)[:]
x2 = xindex
x0 = (xindex % 2048)
x1 = xindex // 2048
tmp17 = tl.load(in_ptr0 + (x2), None).to(tl.float32)
tmp18 = tl.load(in_ptr1 + (x2), None).to(tl.float32)
tmp0 = (x2 % 2)
tmp1 = tl.full([1], 0, tl.int64)
tmp2 = tmp0 == tmp1
tmp3 = tl.full([1], 0.0, tl.float32)
tmp4 = tl.full(tmp3.shape, 0.0, tmp3.dtype)
tmp5 = tl.where(tmp2, tmp3, tmp4)
tmp6 = x0
tmp7 = tl.full([1], 1, tl.int64)
tmp8 = tmp6 >= tmp7
tmp9 = (((-1) + x0) % 2)
tmp10 = tmp9 == tmp1
tmp11 = tmp8 & tmp10
tmp12 = tl.load(in_ptr0 + (1 + 2*(triton_helpers.div_floor_integer((-1) + x0, 2)) + 2048*x1), tmp11, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp13 = tl.load(in_ptr1 + (1 + 2*(triton_helpers.div_floor_integer((-1) + x0, 2)) + 2048*x1), tmp11, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp14 = tmp12 + tmp13
tmp15 = tl.full(tmp14.shape, 0.0, tmp14.dtype)
tmp16 = tl.where(tmp11, tmp14, tmp15)
tmp19 = tmp17 + tmp18
tmp20 = tl.where(tmp11, tmp16, tmp19)
tmp21 = tl.where(tmp2, tmp5, tmp20)
tmp22 = 2*(x0 // 2)
tmp23 = tl.full([1], 1, tl.int64)
tmp24 = tmp22 >= tmp23
tmp25 = (((-1) + 2*(x0 // 2)) % 2)
tmp26 = tl.full([1], 0, tl.int64)
tmp27 = tmp25 == tmp26
tmp28 = tmp24 & tmp27
tmp29 = tmp28 & tmp2
tmp30 = tl.load(in_ptr0 + ((-1) + 2*(x0 // 2) + 2048*x1), tmp29, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp31 = tl.load(in_ptr1 + ((-1) + 2*(x0 // 2) + 2048*x1), tmp29, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp32 = tmp30 + tmp31
tmp33 = tl.full(tmp32.shape, 0.0, tmp32.dtype)
tmp34 = tl.where(tmp29, tmp32, tmp33)
tmp35 = tl.load(in_ptr0 + (2*(x0 // 2) + 2048*x1), tmp2, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp36 = tl.load(in_ptr1 + (2*(x0 // 2) + 2048*x1), tmp2, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp37 = tmp35 + tmp36
tmp38 = tl.where(tmp28, tmp34, tmp37)
tmp39 = tl.full(tmp38.shape, 0.0, tmp38.dtype)
tmp40 = tl.where(tmp2, tmp38, tmp39)
tmp41 = tl.full([1], 0.0, tl.float32)
tmp42 = tl.where(tmp2, tmp40, tmp41)
tmp43 = tmp21 + tmp42
tl.store(out_ptr0 + (x2), tmp43, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/fo/cfoz6dfisizdj2zmeupbzqtnjgko6ew4rh2fhiswleuzftetokpz.py
# Topologically Sorted Source Nodes: [getitem, mask, getitem_1, data, getitem_4, contiguous_1], Original ATen: [aten.select, aten.eq, aten.unsqueeze, aten.mul, aten.slice, aten.clone]
# Source node to ATen node mapping:
# contiguous_1 => clone_1
# data => mul
# getitem => select
# getitem_1 => unsqueeze
# getitem_4 => slice_2
# mask => eq
# Graph fragment:
# %primals_1 : Tensor "u8[327680, 785][785, 1]cuda:0" = PlaceHolder[target=primals_1]
# %select : Tensor "u8[327680][785]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.select.int](args = (%primals_1, 1, -1), kwargs = {})
# %eq : Tensor "b8[327680][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%select, 1), kwargs = {})
# %unsqueeze : Tensor "b8[327680, 1][1, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%eq, 1), kwargs = {})
# %mul : Tensor "u8[327680, 785][785, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_1, %unsqueeze), kwargs = {})
# %slice_2 : Tensor "u8[327680, 4][785, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul, 1, 8, 12), kwargs = {})
# %clone_1 : Tensor "u8[327680, 4][4, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_2,), kwargs = {memory_format: torch.contiguous_format})
# return %clone_1
triton_poi_fused_clone_eq_mul_select_slice_unsqueeze_37 = async_compile.triton('triton_poi_fused_clone_eq_mul_select_slice_unsqueeze_37', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 2097152},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*u8', 'out_ptr0': '*u8', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_eq_mul_select_slice_unsqueeze_37', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 2, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 3932160}},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_clone_eq_mul_select_slice_unsqueeze_37(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 1310720
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)[:]
x0 = (xindex % 4)
x1 = xindex // 4
x2 = xindex
tmp0 = tl.load(in_ptr0 + (8 + x0 + 785*x1), None)
tmp1 = tl.load(in_ptr0 + (784 + 785*x1), None, eviction_policy='evict_last')
tmp2 = tl.full([1], 1, tl.uint8)
tmp3 = tmp1 == tmp2
tmp4 = tmp3.to(tl.uint8)
tmp5 = tmp0 * tmp4
tl.store(out_ptr0 + (x2), tmp5, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/wm/cwmzns2lkf7z7qgwcvjva2ofngj6wsanegij4ylfdjsl2sy4hspo.py
# Topologically Sorted Source Nodes: [slice_30, clone_115, convert_element_type_1078, positions, ifreqs, ifreqs_1, neg, freqs, getitem_7, getitem_8, mul_1, sin, mul_487, sum_134], Original ATen: [aten.slice, aten.clone, aten._to_copy, aten.select, aten.arange, aten.div, aten.neg, aten.pow, aten.unsqueeze, aten.mul, aten.sin, aten.sum]
# Source node to ATen node mapping:
# clone_115 => clone_115
# convert_element_type_1078 => convert_element_type_1078
# freqs => pow_1
# getitem_7 => unsqueeze_1
# getitem_8 => unsqueeze_2
# ifreqs => add, convert_element_type_2, iota, mul_1
# ifreqs_1 => div
# mul_1 => mul_2
# mul_487 => mul_487
# neg => neg
# positions => select_2
# sin => sin
# slice_30 => slice_30
# sum_134 => sum_134
# Graph fragment:
# %add_207 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=add_207]
# %view_1 : Tensor "i32[327680, 1][1, 1]cuda:0" = PlaceHolder[target=view_1]
# %slice_30 : Tensor "bf16[327680, 1024][2048, 2]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.slice.Tensor](args = (%add_207, 1, 0, 9223372036854775807, 2), kwargs = {})
# %clone_115 : Tensor "bf16[327680, 1024][1024, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.clone.default](args = (%slice_30,), kwargs = {memory_format: torch.contiguous_format})
# %convert_element_type_1078 : Tensor "f32[327680, 1024][1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%clone_115, torch.float32), kwargs = {})
# %select_2 : Tensor "i32[327680][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%view_1, 1, 0), kwargs = {})
# %iota : Tensor "i64[1024][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1024,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:0, requires_grad: False})
# %mul_1 : Tensor "i64[1024][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%iota, 2), kwargs = {})
# %add : Tensor "i64[1024][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_1, 0), kwargs = {})
# %convert_element_type_2 : Tensor "f32[1024][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add, torch.float32), kwargs = {})
# %div : Tensor "f32[1024][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%convert_element_type_2, 2048), kwargs = {})
# %neg : Tensor "f32[1024][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%div,), kwargs = {})
# %pow_1 : Tensor "f32[1024][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Scalar](args = (10000.0, %neg), kwargs = {})
# %unsqueeze_1 : Tensor "i32[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%select_2, 1), kwargs = {})
# %unsqueeze_2 : Tensor "f32[1, 1024][1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%pow_1, 0), kwargs = {})
# %mul_2 : Tensor "f32[327680, 1024][1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%unsqueeze_1, %unsqueeze_2), kwargs = {})
# %sin : Tensor "f32[327680, 1024][1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sin.default](args = (%mul_2,), kwargs = {})
# %mul_487 : Tensor "f32[327680, 1024][1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_1078, %sin), kwargs = {})
# %sum_134 : Tensor "f32[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_487,), kwargs = {})
# return %buf1608
triton_red_fused__to_copy_arange_clone_div_mul_neg_pow_select_sin_slice_sum_unsqueeze_38 = async_compile.triton('triton_red_fused__to_copy_arange_clone_div_mul_neg_pow_select_sin_slice_sum_unsqueeze_38', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 262144, 'r0_': 2048},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_clone_div_mul_neg_pow_select_sin_slice_sum_unsqueeze_38', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 2, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 1310720, 'r0_': 1310720}}
)
@triton.jit
def triton_red_fused__to_copy_arange_clone_div_mul_neg_pow_select_sin_slice_sum_unsqueeze_38(in_ptr0, in_ptr1, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 163840
r0_numel = 2048
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x0 = xindex
_tmp15 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
tmp0 = tl.load(in_ptr0 + (2*r0_1 + 4096*x0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp2 = tl.load(in_ptr1 + (2*x0 + (r0_1 // 1024)), r0_mask, eviction_policy='evict_last', other=0.0)
tmp1 = tmp0.to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = 2*((r0_1 % 1024))
tmp5 = tmp4.to(tl.float32)
tmp6 = tl.full([1, 1], 0.00048828125, tl.float32)
tmp7 = tmp5 * tmp6
tmp8 = -tmp7
tmp9 = tl.full([1, 1], 10000.0, tl.float32)
tmp10 = libdevice.pow(tmp9, tmp8)
tmp11 = tmp3 * tmp10
tmp12 = tl_math.sin(tmp11)
tmp13 = tmp1 * tmp12
tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK])
tmp16 = _tmp15 + tmp14
_tmp15 = tl.where(r0_mask, tmp16, _tmp15)
tmp15 = tl.sum(_tmp15, 1)[:, None]
tl.store(out_ptr0 + (x0), tmp15, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/hy/chy7u7jsfy2qsmcd33o6b7i45qzyyoeearxatlqrn2wgkrjafsap.py
# Topologically Sorted Source Nodes: [full_default_5], Original ATen: [aten.embedding_dense_backward]
# Source node to ATen node mapping:
# full_default_5 => full_default_5
# Graph fragment:
# %full_default_5 : Tensor "f32[259, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([259, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# return %index_put
triton_poi_fused_embedding_dense_backward_39 = async_compile.triton('triton_poi_fused_embedding_dense_backward_39', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 1048576},
filename=__file__,
triton_meta={'signature': {'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_embedding_dense_backward_39', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 0, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 4243456}},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_embedding_dense_backward_39(out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 530432
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
tmp0 = tl.full([1], 0.0, tl.float32)
tl.store(out_ptr0 + (x0), tmp0, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/wr/cwrdkxk3wqg66t3ipy6nagxwmmip3afro46w3ozyntss2wr52t3a.py
# Topologically Sorted Source Nodes: [slice_30, clone_115, copy_9, convert_element_type_1081, eq_2, unsqueeze_16, full_default_4, where], Original ATen: [aten.slice, aten.clone, aten.copy, aten.embedding_dense_backward]
# Source node to ATen node mapping:
# clone_115 => clone_115
# convert_element_type_1081 => convert_element_type_1081
# copy_9 => copy_9
# eq_2 => eq_2
# full_default_4 => full_default_4
# slice_30 => slice_30
# unsqueeze_16 => unsqueeze_16
# where => where
# Graph fragment:
# %select_1 : Tensor "i64[327680][1]cuda:0" = PlaceHolder[target=select_1]
# %add_207 : Tensor "bf16[327680, 2048][2048, 1]cuda:0" = PlaceHolder[target=add_207]
# %slice_30 : Tensor "bf16[327680, 1024][2048, 2]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.slice.Tensor](args = (%add_207, 1, 0, 9223372036854775807, 2), kwargs = {})
# %clone_115 : Tensor "bf16[327680, 1024][1024, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.clone.default](args = (%slice_30,), kwargs = {memory_format: torch.contiguous_format})
# %copy_9 : Tensor "bf16[327680, 1024][2048, 2]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.copy.default](args = (%slice_30, %clone_115), kwargs = {})
# %slice_scatter_default_5 : Tensor "bf16[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%add_207, %copy_9, 1, 0, 9223372036854775807, 2), kwargs = {})
# %convert_element_type_1081 : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%slice_scatter_default_5, torch.float32), kwargs = {})
# %eq_2 : Tensor "b8[327680][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%select_1, -1), kwargs = {})
# %unsqueeze_16 : Tensor "b8[327680, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%eq_2, -1), kwargs = {})
# %full_default_4 : Tensor "f32[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %where : Tensor "f32[327680, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%unsqueeze_16, %full_default_4, %convert_element_type_1081), kwargs = {})
# return %where
triton_poi_fused_clone_copy_embedding_dense_backward_slice_40 = async_compile.triton('triton_poi_fused_clone_copy_embedding_dense_backward_slice_40', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 1073741824},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_copy_embedding_dense_backward_slice_40', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 8053063680}},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_clone_copy_embedding_dense_backward_slice_40(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 671088640
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)[:]
x1 = xindex // 2048
x2 = xindex
x0 = (xindex % 2048)
tmp0 = tl.load(in_ptr0 + (x1), None, eviction_policy='evict_last')
tmp7 = tl.load(in_ptr1 + (x2), None).to(tl.float32)
tmp1 = tl.full([1], -1, tl.int64)
tmp2 = tmp0 == tmp1
tmp3 = (x2 % 2)
tmp4 = tl.full([1], 0, tl.int64)
tmp5 = tmp3 == tmp4
tmp6 = tl.load(in_ptr1 + (2*(x0 // 2) + 2048*x1), tmp5, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp8 = tl.where(tmp5, tmp6, tmp7)
tmp9 = tmp8.to(tl.float32)
tmp10 = tl.full([1], 0.0, tl.float32)
tmp11 = tl.where(tmp2, tmp10, tmp9)
tl.store(out_ptr0 + (x2), tmp11, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/rd/crdcbvdrllgakbwzs7hop6q5msvy2l25yyur5n73bke6t4apfi67.py
# Topologically Sorted Source Nodes: [convert_element_type_1082, all_reduce_153], Original ATen: [aten.embedding_dense_backward, _c10d_functional.all_reduce]
# Source node to ATen node mapping:
# all_reduce_153 => all_reduce_153
# convert_element_type_1082 => convert_element_type_1082
# Graph fragment:
# %buf1618 : Tensor = PlaceHolder[target=buf1618]
# %convert_element_type_1082 : Tensor "bf16[259, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%index_put, torch.bfloat16), kwargs = {})
# %all_reduce_153 : Tensor "bf16[259, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops._c10d_functional.all_reduce_.default](args = (%convert_element_type_1082, avg, 0), kwargs = {})
# return %wait_tensor_153
triton_poi_fused_all_reduce_embedding_dense_backward_41 = async_compile.triton('triton_poi_fused_all_reduce_embedding_dense_backward_41', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 1048576},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_all_reduce_embedding_dense_backward_41', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 2121728}},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_all_reduce_embedding_dense_backward_41(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 530432
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), xmask)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x0), tmp1, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/wq/cwqyv2tz7qk2wcxmq773ashkr4z7appc4qvtyxetah43gcu357k2.py
# Topologically Sorted Source Nodes: [convert_element_type_1083], Original ATen: [aten._to_copy]
# Source node to ATen node mapping:
# convert_element_type_1083 => convert_element_type_1083
# Graph fragment:
# %buf1623 : Tensor = PlaceHolder[target=buf1623]
# %convert_element_type_1083 : Tensor "f32[259, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%wait_tensor_153, torch.float32), kwargs = {})
# return %convert_element_type_1083
triton_poi_fused__to_copy_42 = async_compile.triton('triton_poi_fused__to_copy_42', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 1048576},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_42', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False, 'tiling_scores': {'x': 4243456}},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_42(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 530432
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x0), tmp1, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor/rank0/nj/cnjfakvke3hak7tptij6sluwkwua32gs4pbkbabtcnn6lqzhdh6d.py
# Topologically Sorted Source Nodes: [convert_element_type_1077, convert_element_type_1080, add_209], Original ATen: [aten._to_copy, aten.add]
# Source node to ATen node mapping:
# add_209 => add_209
# convert_element_type_1077 => convert_element_type_1077
# convert_element_type_1080 => convert_element_type_1080
# Graph fragment:
# %buf1602 : Tensor = PlaceHolder[target=buf1602]
# %buf1615 : Tensor = PlaceHolder[target=buf1615]
# %convert_element_type_1077 : Tensor "f32[][]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%wait_tensor_151, torch.float32), kwargs = {})
# %convert_element_type_1080 : Tensor "f32[][]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%wait_tensor_152, torch.float32), kwargs = {})
# %add_209 : Tensor "f32[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%convert_element_type_1077, %convert_element_type_1080), kwargs = {})
# return %add_209
triton_poi_fused__to_copy_add_43 = async_compile.triton('triton_poi_fused__to_copy_add_43', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 1},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=32), 'constants': {'xnumel': 1}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_43', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 2, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'D2386747DC7DD0AECB9BA32040289DCFF8D245CBE872A9EAD656A275916E43AA', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': True, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': True, 'has_loadstore_with_contiguous_rdim': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_add_43(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 1
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)[:]
tmp0 = tl.load(in_ptr0 + (0)).to(tl.float32)
tmp1 = tl.broadcast_to(tmp0, [XBLOCK])
tmp3 = tl.load(in_ptr1 + (0)).to(tl.float32)
tmp4 = tl.broadcast_to(tmp3, [XBLOCK])
tmp2 = tmp1.to(tl.float32)
tmp5 = tmp4.to(tl.float32)
tmp6 = tmp2 + tmp5
tl.store(out_ptr0 + (tl.full([XBLOCK], 0, tl.int32).broadcast_to(XBLOCK)), tmp6, None)
''', device_str='cuda')
# Block until every kernel registered with async_compile above has finished
# compiling, then drop the compiler handle — it is not needed at run time.
async_compile.wait(globals())
del async_compile
class Runner:
    """Holds the compiled graph-partition callables for this AOT graph."""

    def __init__(self, partitions):
        # Callables for each graph partition, kept in execution order.
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        """Replace each partition with ``fn(partition)``, pairing positionally.

        ``fns`` is zipped against ``self.partitions``, so each wrapper is
        applied to the partition at the same index.
        """
        self.partitions = [
            wrap(part) for wrap, part in zip(fns, self.partitions)
        ]
def call(self, args):
primals_16, primals_20, primals_23, primals_26, primals_1, primals_4, primals_5, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_17, primals_18, primals_19, primals_21, primals_22, primals_24, primals_25, primals_27, primals_28, primals_29, primals_30, primals_31, primals_32, primals_33, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_47, primals_48, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, primals_59, primals_60, primals_61, primals_62, primals_63, primals_64, primals_65, primals_66, primals_67, primals_68, primals_69, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_78, primals_79, primals_80, primals_81, primals_83, primals_84, primals_85, primals_86, primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_95, primals_96, primals_97, primals_98, primals_99, primals_100, primals_101, primals_102, primals_103, primals_104, primals_105, primals_107, primals_108, primals_109, primals_110, primals_111, primals_112, primals_113, primals_114, primals_115, primals_116, primals_117, primals_119, primals_120, primals_121, primals_122, primals_123, primals_124, primals_125, primals_126, primals_127, primals_128, primals_129, primals_131, primals_132, primals_133, primals_134, primals_135, primals_136, primals_137, primals_138, primals_139, primals_140, primals_141, primals_143, primals_144, primals_145, primals_146, primals_147, primals_148, primals_149, primals_150, primals_151, primals_152, primals_153, primals_155, primals_156, primals_157, primals_158, primals_159, primals_160, primals_161, primals_162, primals_163, primals_164, primals_165, primals_167, select_1, cos, add_11, add_14, add_19, add_22, add_27, add_30, add_35, add_38, add_43, add_46, add_51, add_54, add_59, add_62, add_67, 
add_70, add_75, add_78, add_83, add_86, add_91, add_94, add_99, add_102, add_107, getitem_89, rsqrt_26, tangents_1 = args
args.clear()
s91 = primals_16
s6 = primals_20
s16 = primals_23
s18 = primals_26
assert_size_stride(primals_1, (327680, 785), (785, 1))
assert_size_stride(primals_4, (2048, 768), (768, 1))
assert_size_stride(primals_5, (2048, ), (1, ))
assert_size_stride(primals_9, (2048, ), (1, ))
assert_size_stride(primals_10, (2048, ), (1, ))
assert_size_stride(primals_11, (2048, 2048), (2048, 1))
assert_size_stride(primals_12, (512, 2048), (2048, 1))
assert_size_stride(primals_13, (512, 2048), (2048, 1))
assert_size_stride(primals_14, (327680, ), (1, ))
assert_size_stride(primals_15, (327680, ), (1, ))
assert_size_stride(primals_17, (1, 1, 2560, s91), (2560*s91, 2560*s91, s91, 1))
assert_size_stride(primals_18, (1, 1, 2560), (2560, 2560, 1))
assert_size_stride(primals_19, (1, 1, 2560), (2560, 2560, 1))
assert_size_stride(primals_21, (1, 1, 2560, s6), (2560*s6, 2560*s6, s6, 1))
assert_size_stride(primals_22, (1, 1, 2560), (2560, 2560, 1))
assert_size_stride(primals_24, (1, 1, 2560, s16), (2560*s16, 2560*s16, s16, 1))
assert_size_stride(primals_25, (1, 1, 2560), (2560, 2560, 1))
assert_size_stride(primals_27, (1, 1, 2560, s18), (2560*s18, 2560*s18, s18, 1))
assert_size_stride(primals_28, (2048, 2048), (2048, 1))
assert_size_stride(primals_29, (2048, ), (1, ))
assert_size_stride(primals_30, (2048, ), (1, ))
assert_size_stride(primals_31, (8192, 2048), (2048, 1))
assert_size_stride(primals_32, (8192, ), (1, ))
assert_size_stride(primals_33, (2048, 8192), (8192, 1))
assert_size_stride(primals_35, (2048, ), (1, ))
assert_size_stride(primals_36, (2048, ), (1, ))
assert_size_stride(primals_37, (2048, 2048), (2048, 1))
assert_size_stride(primals_38, (512, 2048), (2048, 1))
assert_size_stride(primals_39, (512, 2048), (2048, 1))
assert_size_stride(primals_40, (2048, 2048), (2048, 1))
assert_size_stride(primals_41, (2048, ), (1, ))
assert_size_stride(primals_42, (2048, ), (1, ))
assert_size_stride(primals_43, (8192, 2048), (2048, 1))
assert_size_stride(primals_44, (8192, ), (1, ))
assert_size_stride(primals_45, (2048, 8192), (8192, 1))
assert_size_stride(primals_47, (2048, ), (1, ))
assert_size_stride(primals_48, (2048, ), (1, ))
assert_size_stride(primals_49, (2048, 2048), (2048, 1))
assert_size_stride(primals_50, (512, 2048), (2048, 1))
assert_size_stride(primals_51, (512, 2048), (2048, 1))
assert_size_stride(primals_52, (2048, 2048), (2048, 1))
assert_size_stride(primals_53, (2048, ), (1, ))
assert_size_stride(primals_54, (2048, ), (1, ))
assert_size_stride(primals_55, (8192, 2048), (2048, 1))
assert_size_stride(primals_56, (8192, ), (1, ))
assert_size_stride(primals_57, (2048, 8192), (8192, 1))
assert_size_stride(primals_59, (2048, ), (1, ))
assert_size_stride(primals_60, (2048, ), (1, ))
assert_size_stride(primals_61, (2048, 2048), (2048, 1))
assert_size_stride(primals_62, (512, 2048), (2048, 1))
assert_size_stride(primals_63, (512, 2048), (2048, 1))
assert_size_stride(primals_64, (2048, 2048), (2048, 1))
assert_size_stride(primals_65, (2048, ), (1, ))
assert_size_stride(primals_66, (2048, ), (1, ))
assert_size_stride(primals_67, (8192, 2048), (2048, 1))
assert_size_stride(primals_68, (8192, ), (1, ))
assert_size_stride(primals_69, (2048, 8192), (8192, 1))
assert_size_stride(primals_71, (2048, ), (1, ))
assert_size_stride(primals_72, (2048, ), (1, ))
assert_size_stride(primals_73, (2048, 2048), (2048, 1))
assert_size_stride(primals_74, (512, 2048), (2048, 1))
assert_size_stride(primals_75, (512, 2048), (2048, 1))
assert_size_stride(primals_76, (2048, 2048), (2048, 1))
assert_size_stride(primals_77, (2048, ), (1, ))
assert_size_stride(primals_78, (2048, ), (1, ))
assert_size_stride(primals_79, (8192, 2048), (2048, 1))
assert_size_stride(primals_80, (8192, ), (1, ))
assert_size_stride(primals_81, (2048, 8192), (8192, 1))
assert_size_stride(primals_83, (2048, ), (1, ))
assert_size_stride(primals_84, (2048, ), (1, ))
assert_size_stride(primals_85, (2048, 2048), (2048, 1))
assert_size_stride(primals_86, (512, 2048), (2048, 1))
assert_size_stride(primals_87, (512, 2048), (2048, 1))
assert_size_stride(primals_88, (2048, 2048), (2048, 1))
assert_size_stride(primals_89, (2048, ), (1, ))
assert_size_stride(primals_90, (2048, ), (1, ))
assert_size_stride(primals_91, (8192, 2048), (2048, 1))
assert_size_stride(primals_92, (8192, ), (1, ))
assert_size_stride(primals_93, (2048, 8192), (8192, 1))
assert_size_stride(primals_95, (2048, ), (1, ))
assert_size_stride(primals_96, (2048, ), (1, ))
assert_size_stride(primals_97, (2048, 2048), (2048, 1))
assert_size_stride(primals_98, (512, 2048), (2048, 1))
assert_size_stride(primals_99, (512, 2048), (2048, 1))
assert_size_stride(primals_100, (2048, 2048), (2048, 1))
assert_size_stride(primals_101, (2048, ), (1, ))
assert_size_stride(primals_102, (2048, ), (1, ))
assert_size_stride(primals_103, (8192, 2048), (2048, 1))
assert_size_stride(primals_104, (8192, ), (1, ))
assert_size_stride(primals_105, (2048, 8192), (8192, 1))
assert_size_stride(primals_107, (2048, ), (1, ))
assert_size_stride(primals_108, (2048, ), (1, ))
assert_size_stride(primals_109, (2048, 2048), (2048, 1))
assert_size_stride(primals_110, (512, 2048), (2048, 1))
assert_size_stride(primals_111, (512, 2048), (2048, 1))
assert_size_stride(primals_112, (2048, 2048), (2048, 1))
assert_size_stride(primals_113, (2048, ), (1, ))
assert_size_stride(primals_114, (2048, ), (1, ))
assert_size_stride(primals_115, (8192, 2048), (2048, 1))
assert_size_stride(primals_116, (8192, ), (1, ))
assert_size_stride(primals_117, (2048, 8192), (8192, 1))
assert_size_stride(primals_119, (2048, ), (1, ))
assert_size_stride(primals_120, (2048, ), (1, ))
assert_size_stride(primals_121, (2048, 2048), (2048, 1))
assert_size_stride(primals_122, (512, 2048), (2048, 1))
assert_size_stride(primals_123, (512, 2048), (2048, 1))
assert_size_stride(primals_124, (2048, 2048), (2048, 1))
assert_size_stride(primals_125, (2048, ), (1, ))
assert_size_stride(primals_126, (2048, ), (1, ))
assert_size_stride(primals_127, (8192, 2048), (2048, 1))
assert_size_stride(primals_128, (8192, ), (1, ))
assert_size_stride(primals_129, (2048, 8192), (8192, 1))
assert_size_stride(primals_131, (2048, ), (1, ))
assert_size_stride(primals_132, (2048, ), (1, ))
assert_size_stride(primals_133, (2048, 2048), (2048, 1))
assert_size_stride(primals_134, (512, 2048), (2048, 1))
assert_size_stride(primals_135, (512, 2048), (2048, 1))
assert_size_stride(primals_136, (2048, 2048), (2048, 1))
assert_size_stride(primals_137, (2048, ), (1, ))
assert_size_stride(primals_138, (2048, ), (1, ))
assert_size_stride(primals_139, (8192, 2048), (2048, 1))
assert_size_stride(primals_140, (8192, ), (1, ))
assert_size_stride(primals_141, (2048, 8192), (8192, 1))
assert_size_stride(primals_143, (2048, ), (1, ))
assert_size_stride(primals_144, (2048, ), (1, ))
assert_size_stride(primals_145, (2048, 2048), (2048, 1))
assert_size_stride(primals_146, (512, 2048), (2048, 1))
assert_size_stride(primals_147, (512, 2048), (2048, 1))
assert_size_stride(primals_148, (2048, 2048), (2048, 1))
assert_size_stride(primals_149, (2048, ), (1, ))
assert_size_stride(primals_150, (2048, ), (1, ))
assert_size_stride(primals_151, (8192, 2048), (2048, 1))
assert_size_stride(primals_152, (8192, ), (1, ))
assert_size_stride(primals_153, (2048, 8192), (8192, 1))
assert_size_stride(primals_155, (2048, ), (1, ))
assert_size_stride(primals_156, (2048, ), (1, ))
assert_size_stride(primals_157, (2048, 2048), (2048, 1))
assert_size_stride(primals_158, (512, 2048), (2048, 1))
assert_size_stride(primals_159, (512, 2048), (2048, 1))
assert_size_stride(primals_160, (2048, 2048), (2048, 1))
assert_size_stride(primals_161, (2048, ), (1, ))
assert_size_stride(primals_162, (2048, ), (1, ))
assert_size_stride(primals_163, (8192, 2048), (2048, 1))
assert_size_stride(primals_164, (8192, ), (1, ))
assert_size_stride(primals_165, (2048, 8192), (8192, 1))
assert_size_stride(primals_167, (2048, ), (1, ))
assert_size_stride(select_1, (327680, ), (1, ))
assert_size_stride(cos, (327680, 1024), (1024, 1))
assert_size_stride(add_11, (327680, 2048), (2048, 1))
assert_size_stride(add_14, (327680, 2048), (2048, 1))
assert_size_stride(add_19, (327680, 2048), (2048, 1))
assert_size_stride(add_22, (327680, 2048), (2048, 1))
assert_size_stride(add_27, (327680, 2048), (2048, 1))
assert_size_stride(add_30, (327680, 2048), (2048, 1))
assert_size_stride(add_35, (327680, 2048), (2048, 1))
assert_size_stride(add_38, (327680, 2048), (2048, 1))
assert_size_stride(add_43, (327680, 2048), (2048, 1))
assert_size_stride(add_46, (327680, 2048), (2048, 1))
assert_size_stride(add_51, (327680, 2048), (2048, 1))
assert_size_stride(add_54, (327680, 2048), (2048, 1))
assert_size_stride(add_59, (327680, 2048), (2048, 1))
assert_size_stride(add_62, (327680, 2048), (2048, 1))
assert_size_stride(add_67, (327680, 2048), (2048, 1))
assert_size_stride(add_70, (327680, 2048), (2048, 1))
assert_size_stride(add_75, (327680, 2048), (2048, 1))
assert_size_stride(add_78, (327680, 2048), (2048, 1))
assert_size_stride(add_83, (327680, 2048), (2048, 1))
assert_size_stride(add_86, (327680, 2048), (2048, 1))
assert_size_stride(add_91, (327680, 2048), (2048, 1))
assert_size_stride(add_94, (327680, 2048), (2048, 1))
assert_size_stride(add_99, (327680, 2048), (2048, 1))
assert_size_stride(add_102, (327680, 2048), (2048, 1))
assert_size_stride(add_107, (327680, 2048), (2048, 1))
assert_size_stride(getitem_89, (327680, 1), (1, 1))
assert_size_stride(rsqrt_26, (327680, 1), (1, 1))
assert_size_stride(tangents_1, (327680, 2048), (2048, 1))
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
buf6 = empty_strided_cuda((327680, 2048), (2048, 1), torch.bfloat16)
# Call mix order reduction kernel
# Topologically Sorted Source Nodes: [convert_element_type_410, redistribute_4, convert_element_type_412, mul_153, mul_154, sum_1, x_26, mul_155, sum_2, mul_156, sub_29, sub_30, div_3, mul_157, mul_158, sum_3, sum_4, convert_element_type_414], Original ATen: [aten.native_layer_norm_backward, aten._to_copy, aten.native_layer_norm]
workspace_0 = empty_strided_cuda((10485760, ), (1, ), torch.float32)
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_native_layer_norm_native_layer_norm_backward_0.run(tangents_1, primals_167, add_107, getitem_89, rsqrt_26, buf6, workspace_0, 327680, 2048, stream=stream0)
buf3 = workspace_0[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
buf5 = workspace_0[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
del workspace_0
del add_107
del getitem_89
del primals_167
del rsqrt_26
del tangents_1
buf13 = empty_strided_cuda((2048, ), (1, ), torch.bfloat16)
# Topologically Sorted Source Nodes: [convert_element_type_415, all_reduce_1], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf3, buf13, 2048, stream=stream0)
buf7 = empty_strided_cuda((2048, ), (1, ), torch.bfloat16)
# Topologically Sorted Source Nodes: [convert_element_type_416, all_reduce], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf5, buf7, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_416, all_reduce], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf7, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf7)
buf12 = buf5; del buf5 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_417], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf7, buf12, 2048, stream=stream0)
del buf7
# Topologically Sorted Source Nodes: [convert_element_type_415, all_reduce_1], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf13, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_1], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf13)
buf18 = buf3; del buf3 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_418], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf13, buf18, 2048, stream=stream0)
del buf13
buf19 = empty_strided_cuda((2048, 8192), (8192, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [redistribute_4], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_3.run(primals_165, buf19, 16777216, stream=stream0)
del primals_165
buf20 = empty_strided_cuda((327680, 8192), (8192, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [redistribute_4, x_26, permute_121, mm_49], Original ATen: [aten._to_copy, aten.t, aten.mm]
extern_kernels.mm(buf6, buf19, out=buf20)
buf21 = empty_strided_cuda((327680, 1), (1, 327680), torch.float32)
buf22 = empty_strided_cuda((327680, 1), (1, 327680), torch.float32)
buf24 = empty_strided_cuda((327680, 2048), (2048, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_native_layer_norm_4.run(add_102, primals_161, primals_162, buf21, buf22, buf24, 327680, 2048, stream=stream0)
del primals_162
buf25 = reinterpret_tensor(buf19, (2048, 8192), (1, 2048), 0); del buf19 # reuse
# Topologically Sorted Source Nodes: [redistribute_2, x], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_3.run(primals_163, buf25, 16777216, stream=stream0)
del primals_163
buf26 = empty_strided_cuda((327680, 8192), (8192, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [x], Original ATen: [aten.addmm]
extern_kernels.mm(buf24, buf25, out=buf26)
buf27 = empty_strided_cuda((327680, 8192), (8192, 1), torch.bfloat16)
buf41 = buf20; del buf20 # reuse
# Topologically Sorted Source Nodes: [redistribute_3, x, x_25, convert_element_type_425, mul_164, mul_165, sub_31, mul_166, add_112, mul_167, mul_168, mul_169, add_113, mul_170, convert_element_type_427], Original ATen: [aten._to_copy, aten.addmm, aten.gelu, aten.gelu_backward]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_addmm_gelu_gelu_backward_5.run(buf41, primals_164, buf26, buf27, 2684354560, stream=stream0)
del buf26
del primals_164
buf28 = empty_strided_cuda((2048, 8192), (8192, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [permute_122, redistribute_3, x, x_25, mm_50], Original ATen: [aten.t, aten._to_copy, aten.addmm, aten.gelu, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf6, (2048, 327680), (1, 2048), 0), buf27, out=buf28)
buf29 = empty_strided_cuda((1, 2048, 160), (327680, 1, 2048), torch.float32)
# Topologically Sorted Source Nodes: [sum_5], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_6.run(buf6, buf29, 327680, 2048, stream=stream0)
buf30 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [sum_5], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_7.run(buf29, buf30, 2048, 160, stream=stream0)
# Topologically Sorted Source Nodes: [view_279, all_reduce_2], Original ATen: [aten.view, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf30, (2048, ), (1, ), 0), 'avg', '0')
# Topologically Sorted Source Nodes: [view_279, wait_tensor_2], Original ATen: [aten.view, _c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf30, (2048, ), (1, ), 0))
buf35 = empty_strided_cuda((2048, ), (1, ), torch.float32)
# Topologically Sorted Source Nodes: [view_279, convert_element_type_423], Original ATen: [aten.view, aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf30, buf35, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [all_reduce_3], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf28, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_3], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf28)
buf40 = empty_strided_cuda((2048, 8192), (8192, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_424], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_8.run(buf28, buf40, 16777216, stream=stream0)
buf42 = empty_strided_cuda((327680, 2048), (2048, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [permute_125, mm_51], Original ATen: [aten.t, aten.mm]
extern_kernels.mm(buf41, reinterpret_tensor(buf25, (8192, 2048), (2048, 1), 0), out=buf42)
buf43 = reinterpret_tensor(buf25, (8192, 2048), (2048, 1), 0); del buf25 # reuse
# Topologically Sorted Source Nodes: [permute_126, mm_52], Original ATen: [aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf41, (8192, 327680), (1, 8192), 0), buf24, out=buf43)
buf44 = empty_strided_cuda((1, 8192, 160), (1310720, 1, 8192), torch.float32)
# Topologically Sorted Source Nodes: [sum_6], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_9.run(buf41, buf44, 1310720, 2048, stream=stream0)
buf45 = empty_strided_cuda((1, 8192), (8192, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [sum_6], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_10.run(buf44, buf45, 8192, 160, stream=stream0)
# Topologically Sorted Source Nodes: [view_280, all_reduce_4], Original ATen: [aten.view, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf45, (8192, ), (1, ), 0), 'avg', '0')
# Topologically Sorted Source Nodes: [view_280, wait_tensor_4], Original ATen: [aten.view, _c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf45, (8192, ), (1, ), 0))
buf50 = empty_strided_cuda((8192, ), (1, ), torch.float32)
# Topologically Sorted Source Nodes: [view_280, convert_element_type_432], Original ATen: [aten.view, aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_view_11.run(buf45, buf50, 8192, stream=stream0)
# Topologically Sorted Source Nodes: [all_reduce_5], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf43, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_5], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf43)
buf55 = empty_strided_cuda((8192, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_433], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_8.run(buf43, buf55, 16777216, stream=stream0)
buf62 = buf6; del buf6 # reuse
# Call mix order reduction kernel
# Topologically Sorted Source Nodes: [redistribute, layer_norm, convert_element_type_434, convert_element_type_436, mul_172, mul_173, sum_7, mul_174, sum_8, mul_175, sub_33, sub_34, div_4, mul_176, mul_177, sum_9, sum_10, convert_element_type_438, add_114], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.native_layer_norm_backward, aten.add]
workspace_1 = empty_strided_cuda((10485760, ), (1, ), torch.float32)
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_12.run(buf62, buf42, primals_161, add_102, buf21, buf22, workspace_1, 327680, 2048, stream=stream0)
buf59 = workspace_1[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
buf61 = workspace_1[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
del add_102
del primals_161
buf69 = reinterpret_tensor(buf30, (2048, ), (1, ), 0); del buf30 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_439, all_reduce_7], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf59, buf69, 2048, stream=stream0)
buf63 = empty_strided_cuda((2048, ), (1, ), torch.bfloat16)
# Topologically Sorted Source Nodes: [convert_element_type_440, all_reduce_6], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf61, buf63, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_440, all_reduce_6], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf63, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_6], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf63)
buf68 = buf61; del buf61 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_441], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf63, buf68, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_439, all_reduce_7], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf69, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_7], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf69)
buf74 = buf59; del buf59 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_442], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf69, buf74, 2048, stream=stream0)
buf75 = buf22; del buf22 # reuse
buf76 = buf21; del buf21 # reuse
buf78 = buf42; del buf42 # reuse
# Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_native_layer_norm_4.run(add_99, primals_155, primals_156, buf75, buf76, buf78, 327680, 2048, stream=stream0)
del primals_156
buf79 = empty_strided_cuda((2048, 2048), (1, 2048), torch.bfloat16)
# Topologically Sorted Source Nodes: [redistribute_2, linear], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_13.run(primals_157, buf79, 4194304, stream=stream0)
del primals_157
buf80 = buf24; del buf24 # reuse
# Topologically Sorted Source Nodes: [linear], Original ATen: [aten.mm]
extern_kernels.mm(buf78, buf79, out=buf80)
buf81 = empty_strided_cuda((2048, 512), (1, 2048), torch.bfloat16)
# Topologically Sorted Source Nodes: [redistribute_3, linear_1], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_14.run(primals_158, buf81, 1048576, stream=stream0)
del primals_158
buf82 = empty_strided_cuda((327680, 512), (512, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [linear_1], Original ATen: [aten.mm]
extern_kernels.mm(buf78, buf81, out=buf82)
buf83 = empty_strided_cuda((2048, 512), (1, 2048), torch.bfloat16)
# Topologically Sorted Source Nodes: [redistribute_4, linear_2], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_14.run(primals_159, buf83, 1048576, stream=stream0)
del primals_159
buf84 = empty_strided_cuda((327680, 512), (512, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [linear_2], Original ATen: [aten.mm]
extern_kernels.mm(buf78, buf83, out=buf84)
buf85 = empty_strided_cuda((1, 16, 327680), (5242880, 327680, 1), torch.float32)
buf86 = empty_strided_cuda((1, 16, 327680), (5242880, 327680, 1), torch.float32)
buf87 = empty_strided_cuda((1, 16, 327680, 128), (671088640, 128, 2048, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [q, k, v, flex_attention], Original ATen: [aten.view, aten.permute, flex_attention]
stream0 = get_raw_stream(0)
triton_tem_fused_flex_attention_permute_view_15.run(buf80, buf82, buf84, buf85, buf86, primals_18, primals_17, primals_19, primals_21, primals_14, primals_15, buf87, s91, 2560, 1, 16, stream=stream0)
buf90 = empty_strided_cuda((2048, 2048), (2048, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [permute_129, rearrange_3, o, mm_53], Original ATen: [aten.t, aten.permute, aten.view, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf62, (2048, 327680), (1, 2048), 0), reinterpret_tensor(buf87, (327680, 2048), (2048, 1), 0), out=buf90)
buf91 = empty_strided_cuda((2048, 2048), (2048, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [redistribute_5], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_13.run(primals_160, buf91, 4194304, stream=stream0)
del primals_160
buf92 = empty_strided_cuda((327680, 2048), (2048, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [redistribute_5, o, permute_131, mm_54], Original ATen: [aten._to_copy, aten.t, aten.mm]
extern_kernels.mm(buf62, buf91, out=buf92)
# Topologically Sorted Source Nodes: [all_reduce_8], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf90, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_8], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf90)
buf97 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_447], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_16.run(buf90, buf97, 4194304, stream=stream0)
buf99 = buf86; del buf86 # reuse
# Topologically Sorted Source Nodes: [q, k, v, view_283, view_284, permute_133, flex_attention_backward], Original ATen: [aten.view, aten.permute, flex_attention_backward]
stream0 = get_raw_stream(0)
triton_per_fused_flex_attention_backward_permute_view_17.run(buf87, buf92, buf99, 5242880, 128, stream=stream0)
buf100 = buf87; del buf87 # reuse
buf101 = empty_strided_cuda((1, 4, 327680, 128), (167772160, 128, 512, 1), torch.bfloat16)
buf102 = empty_strided_cuda((1, 4, 327680, 128), (167772160, 128, 512, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [q, k, v, view_283, view_284, permute_133, flex_attention_backward], Original ATen: [aten.view, aten.permute, flex_attention_backward]
stream0 = get_raw_stream(0)
triton_tem_fused_flex_attention_backward_permute_view_18.run(buf80, buf82, buf84, buf85, buf99, buf92, buf100, buf101, primals_18, primals_17, primals_22, primals_24, primals_19, primals_21, primals_25, primals_27, primals_14, primals_15, buf102, s91, s16, 12800, 1, 4, stream=stream0)
del buf82
buf105 = empty_strided_cuda((512, 2048), (2048, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [view_285, permute_134, view_286, permute_135, mm_55], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf101, (512, 327680), (1, 512), 0), buf78, out=buf105)
buf106 = buf92; del buf92 # reuse
# --- Segment: weight/input gradients for the attention K/V/Q projections ---
# NOTE(review): auto-generated by torch Inductor — do not hand-edit. Buffer
# reuse (`bufX = bufY; del bufY`) and statement order are chosen by the
# compiler's memory planner; reordering changes peak memory and correctness.
# The 'avg' all_reduce_ calls on '0' appear to be tensor-parallel gradient
# reductions over process group "0" — confirm against the parallelism setup.
# Topologically Sorted Source Nodes: [view_285, permute_134, view_286, permute_137, mm_56], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf101, (327680, 512), (512, 1), 0), reinterpret_tensor(buf83, (512, 2048), (2048, 1), 0), out=buf106)
# Topologically Sorted Source Nodes: [all_reduce_9], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf105, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_9], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf105)
buf111 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32)
# bf16 -> fp32 cast of the reduced weight gradient (convert_element_type_452).
# Topologically Sorted Source Nodes: [convert_element_type_452], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_19.run(buf105, buf111, 1048576, stream=stream0)
buf112 = buf105; del buf105 # reuse
# Topologically Sorted Source Nodes: [view_287, permute_139, view_288, permute_140, mm_57], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf102, (512, 327680), (1, 512), 0), buf78, out=buf112)
buf113 = buf80; del buf80 # reuse
# Topologically Sorted Source Nodes: [view_287, permute_139, view_288, permute_142, mm_58], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf102, (327680, 512), (512, 1), 0), reinterpret_tensor(buf81, (512, 2048), (2048, 1), 0), out=buf113)
# Topologically Sorted Source Nodes: [all_reduce_10], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf112, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_10], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf112)
buf118 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_457], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_19.run(buf112, buf118, 1048576, stream=stream0)
buf119 = buf90; del buf90 # reuse
# Topologically Sorted Source Nodes: [view_289, permute_144, view_290, permute_145, mm_59], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf100, (2048, 327680), (1, 2048), 0), buf78, out=buf119)
buf120 = buf78; del buf78 # reuse
# Topologically Sorted Source Nodes: [view_289, permute_144, view_290, permute_147, mm_60], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf100, (327680, 2048), (2048, 1), 0), reinterpret_tensor(buf79, (2048, 2048), (2048, 1), 0), out=buf120)
del buf100
# Topologically Sorted Source Nodes: [all_reduce_11], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf119, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_11], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf119)
buf125 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_462], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_16.run(buf119, buf125, 4194304, stream=stream0)
# --- Segment: fused layer-norm backward (mix-order reduction) ---
# One persistent-reduction Triton kernel computes the input gradient in-place
# into buf133 and writes two per-block partial sums (for the LN weight/bias
# grads) into a shared workspace; the partials are then summed on the host
# side with plain torch ops and all-reduced across ranks.
buf133 = buf62; del buf62 # reuse
# Call mix order reduction kernel
# Topologically Sorted Source Nodes: [redistribute, layer_norm, add_115, add_116, convert_element_type_463, convert_element_type_465, mul_179, mul_180, sum_11, mul_181, sum_12, mul_182, sub_36, sub_37, div_5, mul_183, mul_184, sum_13, sum_14, convert_element_type_467, add_117], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.add, aten.native_layer_norm_backward]
workspace_2 = workspace_1; del workspace_1 # reuse
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_20.run(buf133, buf106, buf113, buf120, primals_155, add_99, buf75, buf76, workspace_2, 327680, 2048, stream=stream0)
# Workspace layout: two consecutive (2560, 2048) partial-sum slabs; reduce
# each over the block axis to get the 2048-wide parameter gradients.
buf130 = workspace_2[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
buf132 = workspace_2[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
del add_99
del buf106
del primals_155
buf140 = buf69; del buf69 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_468, all_reduce_13], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf130, buf140, 2048, stream=stream0)
buf134 = buf63; del buf63 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_469, all_reduce_12], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf132, buf134, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_469, all_reduce_12], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf134, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_12], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf134)
buf139 = buf132; del buf132 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_470], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf134, buf139, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_468, all_reduce_13], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf140, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_13], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf140)
buf145 = buf130; del buf130 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_471], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf140, buf145, 2048, stream=stream0)
# --- Segment: MLP (up-proj -> GELU -> down-proj) backward for one layer ---
# Presumably activation-checkpointed: the forward layer-norm and up-proj
# output are recomputed here before the GELU backward — verify against the
# forward graph. Weight grads are all-reduced ('avg') and cast to fp32.
buf146 = reinterpret_tensor(buf43, (2048, 8192), (8192, 1), 0); del buf43 # reuse
# Topologically Sorted Source Nodes: [redistribute_4], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_3.run(primals_153, buf146, 16777216, stream=stream0)
del primals_153
buf147 = buf41; del buf41 # reuse
# Topologically Sorted Source Nodes: [redistribute_4, x_24, permute_149, mm_61], Original ATen: [aten._to_copy, aten.t, aten.mm]
extern_kernels.mm(buf133, buf146, out=buf147)
buf148 = buf76; del buf76 # reuse
buf149 = buf75; del buf75 # reuse
buf151 = buf120; del buf120 # reuse
# Recompute of the forward layer norm (saved stats would otherwise be needed).
# Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_native_layer_norm_4.run(add_94, primals_149, primals_150, buf148, buf149, buf151, 327680, 2048, stream=stream0)
del primals_150
buf152 = reinterpret_tensor(buf146, (2048, 8192), (1, 2048), 0); del buf146 # reuse
# Topologically Sorted Source Nodes: [redistribute_2, x], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_3.run(primals_151, buf152, 16777216, stream=stream0)
del primals_151
buf153 = buf27; del buf27 # reuse
# Topologically Sorted Source Nodes: [x], Original ATen: [aten.addmm]
extern_kernels.mm(buf151, buf152, out=buf153)
buf154 = empty_strided_cuda((327680, 8192), (8192, 1), torch.bfloat16)
buf168 = buf147; del buf147 # reuse
# Fused GELU forward-recompute + GELU backward, written in-place into buf168.
# Topologically Sorted Source Nodes: [redistribute_3, x, x_23, convert_element_type_478, mul_190, mul_191, sub_38, mul_192, add_120, mul_193, mul_194, mul_195, add_121, mul_196, convert_element_type_480], Original ATen: [aten._to_copy, aten.addmm, aten.gelu, aten.gelu_backward]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_addmm_gelu_gelu_backward_5.run(buf168, primals_152, buf153, buf154, 2684354560, stream=stream0)
del primals_152
buf155 = buf28; del buf28 # reuse
# Topologically Sorted Source Nodes: [permute_150, redistribute_3, x, x_23, mm_62], Original ATen: [aten.t, aten._to_copy, aten.addmm, aten.gelu, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf133, (2048, 327680), (1, 2048), 0), buf154, out=buf155)
buf156 = buf29; del buf29 # reuse
# Bias gradient: two-stage column sum (327680 rows -> 160 partials -> 1 row).
# Topologically Sorted Source Nodes: [sum_15], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_6.run(buf133, buf156, 327680, 2048, stream=stream0)
buf157 = reinterpret_tensor(buf140, (1, 2048), (2048, 1), 0); del buf140 # reuse
# Topologically Sorted Source Nodes: [sum_15], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_7.run(buf156, buf157, 2048, 160, stream=stream0)
# Topologically Sorted Source Nodes: [view_291, all_reduce_14], Original ATen: [aten.view, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf157, (2048, ), (1, ), 0), 'avg', '0')
# Topologically Sorted Source Nodes: [view_291, wait_tensor_14], Original ATen: [aten.view, _c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf157, (2048, ), (1, ), 0))
buf162 = empty_strided_cuda((2048, ), (1, ), torch.float32)
# Topologically Sorted Source Nodes: [view_291, convert_element_type_476], Original ATen: [aten.view, aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf157, buf162, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [all_reduce_15], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf155, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_15], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf155)
buf167 = empty_strided_cuda((2048, 8192), (8192, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_477], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_8.run(buf155, buf167, 16777216, stream=stream0)
buf169 = buf113; del buf113 # reuse
# Up-projection backward: grad wrt input (mm_63) and wrt weight (mm_64).
# Topologically Sorted Source Nodes: [permute_153, mm_63], Original ATen: [aten.t, aten.mm]
extern_kernels.mm(buf168, reinterpret_tensor(buf152, (8192, 2048), (2048, 1), 0), out=buf169)
buf170 = reinterpret_tensor(buf152, (8192, 2048), (2048, 1), 0); del buf152 # reuse
# Topologically Sorted Source Nodes: [permute_154, mm_64], Original ATen: [aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf168, (8192, 327680), (1, 8192), 0), buf151, out=buf170)
buf171 = buf44; del buf44 # reuse
# Topologically Sorted Source Nodes: [sum_16], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_9.run(buf168, buf171, 1310720, 2048, stream=stream0)
buf172 = buf45; del buf45 # reuse
# Topologically Sorted Source Nodes: [sum_16], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_10.run(buf171, buf172, 8192, 160, stream=stream0)
# Topologically Sorted Source Nodes: [view_292, all_reduce_16], Original ATen: [aten.view, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf172, (8192, ), (1, ), 0), 'avg', '0')
# Topologically Sorted Source Nodes: [view_292, wait_tensor_16], Original ATen: [aten.view, _c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf172, (8192, ), (1, ), 0))
buf177 = empty_strided_cuda((8192, ), (1, ), torch.float32)
# Topologically Sorted Source Nodes: [view_292, convert_element_type_485], Original ATen: [aten.view, aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_view_11.run(buf172, buf177, 8192, stream=stream0)
# Topologically Sorted Source Nodes: [all_reduce_17], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf170, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_17], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf170)
buf182 = empty_strided_cuda((8192, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_486], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_8.run(buf170, buf182, 16777216, stream=stream0)
# --- Segment: fused layer-norm backward #2 (same pattern as workspace_2) ---
# Input grad accumulated in-place into buf189; weight/bias partials land in
# workspace_3 and are reduced + all-reduced below.
buf189 = buf133; del buf133 # reuse
# Call mix order reduction kernel
# Topologically Sorted Source Nodes: [redistribute, layer_norm, convert_element_type_487, convert_element_type_489, mul_198, mul_199, sum_17, mul_200, sum_18, mul_201, sub_40, sub_41, div_6, mul_202, mul_203, sum_19, sum_20, convert_element_type_491, add_122], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.native_layer_norm_backward, aten.add]
workspace_3 = workspace_2; del workspace_2 # reuse
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_12.run(buf189, buf169, primals_149, add_94, buf148, buf149, workspace_3, 327680, 2048, stream=stream0)
buf186 = workspace_3[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
buf188 = workspace_3[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
del add_94
del primals_149
buf196 = reinterpret_tensor(buf157, (2048, ), (1, ), 0); del buf157 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_492, all_reduce_19], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf186, buf196, 2048, stream=stream0)
buf190 = buf134; del buf134 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_493, all_reduce_18], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf188, buf190, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_493, all_reduce_18], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf190, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_18], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf190)
buf195 = buf188; del buf188 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_494], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf190, buf195, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_492, all_reduce_19], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf196, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_19], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf196)
buf201 = buf186; del buf186 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_495], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf196, buf201, 2048, stream=stream0)
# --- Segment: flex-attention backward for one layer ---
# The forward path (layer norm, Q/K/V projections, flex_attention) is
# recomputed first — presumably activation checkpointing; confirm against the
# forward graph. Then flex_attention_backward produces dQ/dK/dV, followed by
# the projection weight grads, their 'avg' all-reduces, and fp32 casts.
buf202 = buf149; del buf149 # reuse
buf203 = buf148; del buf148 # reuse
buf205 = buf169; del buf169 # reuse
# Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_native_layer_norm_4.run(add_91, primals_143, primals_144, buf202, buf203, buf205, 327680, 2048, stream=stream0)
del primals_144
buf206 = reinterpret_tensor(buf119, (2048, 2048), (1, 2048), 0); del buf119 # reuse
# Topologically Sorted Source Nodes: [redistribute_2, linear], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_13.run(primals_145, buf206, 4194304, stream=stream0)
del primals_145
buf207 = buf151; del buf151 # reuse
# Topologically Sorted Source Nodes: [linear], Original ATen: [aten.mm]
extern_kernels.mm(buf205, buf206, out=buf207)
buf208 = reinterpret_tensor(buf112, (2048, 512), (1, 2048), 0); del buf112 # reuse
# Topologically Sorted Source Nodes: [redistribute_3, linear_1], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_14.run(primals_146, buf208, 1048576, stream=stream0)
del primals_146
buf209 = reinterpret_tensor(buf102, (327680, 512), (512, 1), 0); del buf102 # reuse
# Topologically Sorted Source Nodes: [linear_1], Original ATen: [aten.mm]
extern_kernels.mm(buf205, buf208, out=buf209)
buf210 = buf81; del buf81 # reuse
# Topologically Sorted Source Nodes: [redistribute_4, linear_2], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_14.run(primals_147, buf210, 1048576, stream=stream0)
del primals_147
buf211 = reinterpret_tensor(buf101, (327680, 512), (512, 1), 0); del buf101 # reuse
# Topologically Sorted Source Nodes: [linear_2], Original ATen: [aten.mm]
extern_kernels.mm(buf205, buf210, out=buf211)
buf212 = buf99; del buf99 # reuse
buf213 = buf85; del buf85 # reuse
buf214 = empty_strided_cuda((1, 16, 327680, 128), (671088640, 128, 2048, 1), torch.bfloat16)
# Recomputed flex-attention forward (templated Triton kernel).
# Topologically Sorted Source Nodes: [q, k, v, flex_attention], Original ATen: [aten.view, aten.permute, flex_attention]
stream0 = get_raw_stream(0)
triton_tem_fused_flex_attention_permute_view_15.run(buf207, buf209, buf211, buf212, buf213, primals_18, primals_17, primals_19, primals_21, primals_14, primals_15, buf214, s91, 2560, 1, 16, stream=stream0)
buf217 = reinterpret_tensor(buf79, (2048, 2048), (2048, 1), 0); del buf79 # reuse
# Topologically Sorted Source Nodes: [permute_157, rearrange_3, o, mm_65], Original ATen: [aten.t, aten.permute, aten.view, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf189, (2048, 327680), (1, 2048), 0), reinterpret_tensor(buf214, (327680, 2048), (2048, 1), 0), out=buf217)
buf218 = buf91; del buf91 # reuse
# Topologically Sorted Source Nodes: [redistribute_5], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_13.run(primals_148, buf218, 4194304, stream=stream0)
del primals_148
buf219 = empty_strided_cuda((327680, 2048), (2048, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [redistribute_5, o, permute_159, mm_66], Original ATen: [aten._to_copy, aten.t, aten.mm]
extern_kernels.mm(buf189, buf218, out=buf219)
# Topologically Sorted Source Nodes: [all_reduce_20], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf217, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_20], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf217)
buf224 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_500], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_16.run(buf217, buf224, 4194304, stream=stream0)
buf226 = buf213; del buf213 # reuse
# Preprocessing reduction for flex-attention backward (delta computation).
# Topologically Sorted Source Nodes: [q, k, v, view_295, view_296, permute_161, flex_attention_backward_1], Original ATen: [aten.view, aten.permute, flex_attention_backward]
stream0 = get_raw_stream(0)
triton_per_fused_flex_attention_backward_permute_view_17.run(buf214, buf219, buf226, 5242880, 128, stream=stream0)
buf227 = buf214; del buf214 # reuse
buf228 = reinterpret_tensor(buf84, (1, 4, 327680, 128), (167772160, 128, 512, 1), 0); del buf84 # reuse
buf229 = empty_strided_cuda((1, 4, 327680, 128), (167772160, 128, 512, 1), torch.bfloat16)
# Main flex-attention backward kernel: buf227/buf228/buf229 appear to be
# dQ/dK/dV respectively (16 query heads vs 4 KV heads) — confirm mapping.
# Topologically Sorted Source Nodes: [q, k, v, view_295, view_296, permute_161, flex_attention_backward_1], Original ATen: [aten.view, aten.permute, flex_attention_backward]
stream0 = get_raw_stream(0)
triton_tem_fused_flex_attention_backward_permute_view_18.run(buf207, buf209, buf211, buf212, buf226, buf219, buf227, buf228, primals_18, primals_17, primals_22, primals_24, primals_19, primals_21, primals_25, primals_27, primals_14, primals_15, buf229, s91, s16, 12800, 1, 4, stream=stream0)
del buf209
del buf211
buf232 = reinterpret_tensor(buf83, (512, 2048), (2048, 1), 0); del buf83 # reuse
# Topologically Sorted Source Nodes: [view_297, permute_162, view_298, permute_163, mm_67], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf228, (512, 327680), (1, 512), 0), buf205, out=buf232)
buf233 = buf219; del buf219 # reuse
# Topologically Sorted Source Nodes: [view_297, permute_162, view_298, permute_165, mm_68], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf228, (327680, 512), (512, 1), 0), reinterpret_tensor(buf210, (512, 2048), (2048, 1), 0), out=buf233)
# Topologically Sorted Source Nodes: [all_reduce_21], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf232, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_21], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf232)
buf238 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_505], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_19.run(buf232, buf238, 1048576, stream=stream0)
buf239 = buf232; del buf232 # reuse
# Topologically Sorted Source Nodes: [view_299, permute_167, view_300, permute_168, mm_69], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf229, (512, 327680), (1, 512), 0), buf205, out=buf239)
buf240 = buf207; del buf207 # reuse
# Topologically Sorted Source Nodes: [view_299, permute_167, view_300, permute_170, mm_70], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf229, (327680, 512), (512, 1), 0), reinterpret_tensor(buf208, (512, 2048), (2048, 1), 0), out=buf240)
# Topologically Sorted Source Nodes: [all_reduce_22], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf239, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_22], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf239)
buf245 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_510], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_19.run(buf239, buf245, 1048576, stream=stream0)
buf246 = buf217; del buf217 # reuse
# Topologically Sorted Source Nodes: [view_301, permute_172, view_302, permute_173, mm_71], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf227, (2048, 327680), (1, 2048), 0), buf205, out=buf246)
buf247 = buf205; del buf205 # reuse
# Topologically Sorted Source Nodes: [view_301, permute_172, view_302, permute_175, mm_72], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf227, (327680, 2048), (2048, 1), 0), reinterpret_tensor(buf206, (2048, 2048), (2048, 1), 0), out=buf247)
# Topologically Sorted Source Nodes: [all_reduce_23], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf246, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_23], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf246)
buf252 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_515], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_16.run(buf246, buf252, 4194304, stream=stream0)
# --- Segment: fused layer-norm backward #3 (mirrors the workspace_2 case,
# here also summing the three Q/K/V input grads buf233/buf240/buf247) ---
buf260 = buf189; del buf189 # reuse
# Call mix order reduction kernel
# Topologically Sorted Source Nodes: [redistribute, layer_norm, add_123, add_124, convert_element_type_516, convert_element_type_518, mul_205, mul_206, sum_21, mul_207, sum_22, mul_208, sub_43, sub_44, div_7, mul_209, mul_210, sum_23, sum_24, convert_element_type_520, add_125], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.add, aten.native_layer_norm_backward]
workspace_4 = workspace_3; del workspace_3 # reuse
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_20.run(buf260, buf233, buf240, buf247, primals_143, add_91, buf202, buf203, workspace_4, 327680, 2048, stream=stream0)
buf257 = workspace_4[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
buf259 = workspace_4[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
del add_91
del primals_143
buf267 = buf196; del buf196 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_521, all_reduce_25], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf257, buf267, 2048, stream=stream0)
buf261 = buf190; del buf190 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_522, all_reduce_24], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf259, buf261, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_522, all_reduce_24], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf261, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_24], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf261)
buf266 = buf259; del buf259 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_523], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf261, buf266, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_521, all_reduce_25], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf267, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_25], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf267)
buf272 = buf257; del buf257 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_524], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf267, buf272, 2048, stream=stream0)
# --- Segment: MLP backward for the previous layer (structurally identical to
# the earlier MLP segment; same recompute-then-backward pattern with new
# primals/buffers). Keep in sync with that segment when debugging. ---
buf273 = reinterpret_tensor(buf170, (2048, 8192), (8192, 1), 0); del buf170 # reuse
# Topologically Sorted Source Nodes: [redistribute_4], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_3.run(primals_141, buf273, 16777216, stream=stream0)
del primals_141
buf274 = buf168; del buf168 # reuse
# Topologically Sorted Source Nodes: [redistribute_4, x_22, permute_177, mm_73], Original ATen: [aten._to_copy, aten.t, aten.mm]
extern_kernels.mm(buf260, buf273, out=buf274)
buf275 = buf203; del buf203 # reuse
buf276 = buf202; del buf202 # reuse
buf278 = buf247; del buf247 # reuse
# Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_native_layer_norm_4.run(add_86, primals_137, primals_138, buf275, buf276, buf278, 327680, 2048, stream=stream0)
del primals_138
buf279 = reinterpret_tensor(buf273, (2048, 8192), (1, 2048), 0); del buf273 # reuse
# Topologically Sorted Source Nodes: [redistribute_2, x], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_3.run(primals_139, buf279, 16777216, stream=stream0)
del primals_139
buf280 = buf154; del buf154 # reuse
# Topologically Sorted Source Nodes: [x], Original ATen: [aten.addmm]
extern_kernels.mm(buf278, buf279, out=buf280)
buf281 = buf153; del buf153 # reuse
buf295 = buf274; del buf274 # reuse
# Fused GELU recompute + GELU backward, in-place into buf295.
# Topologically Sorted Source Nodes: [redistribute_3, x, x_21, convert_element_type_531, mul_216, mul_217, sub_45, mul_218, add_128, mul_219, mul_220, mul_221, add_129, mul_222, convert_element_type_533], Original ATen: [aten._to_copy, aten.addmm, aten.gelu, aten.gelu_backward]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_addmm_gelu_gelu_backward_5.run(buf295, primals_140, buf280, buf281, 2684354560, stream=stream0)
del primals_140
buf282 = buf155; del buf155 # reuse
# Topologically Sorted Source Nodes: [permute_178, redistribute_3, x, x_21, mm_74], Original ATen: [aten.t, aten._to_copy, aten.addmm, aten.gelu, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf260, (2048, 327680), (1, 2048), 0), buf281, out=buf282)
buf283 = buf156; del buf156 # reuse
# Topologically Sorted Source Nodes: [sum_25], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_6.run(buf260, buf283, 327680, 2048, stream=stream0)
buf284 = reinterpret_tensor(buf267, (1, 2048), (2048, 1), 0); del buf267 # reuse
# Topologically Sorted Source Nodes: [sum_25], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_7.run(buf283, buf284, 2048, 160, stream=stream0)
# Topologically Sorted Source Nodes: [view_303, all_reduce_26], Original ATen: [aten.view, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf284, (2048, ), (1, ), 0), 'avg', '0')
# Topologically Sorted Source Nodes: [view_303, wait_tensor_26], Original ATen: [aten.view, _c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf284, (2048, ), (1, ), 0))
buf289 = empty_strided_cuda((2048, ), (1, ), torch.float32)
# Topologically Sorted Source Nodes: [view_303, convert_element_type_529], Original ATen: [aten.view, aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf284, buf289, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [all_reduce_27], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf282, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_27], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf282)
buf294 = empty_strided_cuda((2048, 8192), (8192, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_530], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_8.run(buf282, buf294, 16777216, stream=stream0)
buf296 = buf240; del buf240 # reuse
# Topologically Sorted Source Nodes: [permute_181, mm_75], Original ATen: [aten.t, aten.mm]
extern_kernels.mm(buf295, reinterpret_tensor(buf279, (8192, 2048), (2048, 1), 0), out=buf296)
buf297 = reinterpret_tensor(buf279, (8192, 2048), (2048, 1), 0); del buf279 # reuse
# Topologically Sorted Source Nodes: [permute_182, mm_76], Original ATen: [aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf295, (8192, 327680), (1, 8192), 0), buf278, out=buf297)
buf298 = buf171; del buf171 # reuse
# Topologically Sorted Source Nodes: [sum_26], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_9.run(buf295, buf298, 1310720, 2048, stream=stream0)
buf299 = buf172; del buf172 # reuse
# Topologically Sorted Source Nodes: [sum_26], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_10.run(buf298, buf299, 8192, 160, stream=stream0)
# Topologically Sorted Source Nodes: [view_304, all_reduce_28], Original ATen: [aten.view, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf299, (8192, ), (1, ), 0), 'avg', '0')
# Topologically Sorted Source Nodes: [view_304, wait_tensor_28], Original ATen: [aten.view, _c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf299, (8192, ), (1, ), 0))
buf304 = empty_strided_cuda((8192, ), (1, ), torch.float32)
# Topologically Sorted Source Nodes: [view_304, convert_element_type_538], Original ATen: [aten.view, aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_view_11.run(buf299, buf304, 8192, stream=stream0)
# Topologically Sorted Source Nodes: [all_reduce_29], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf297, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_29], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf297)
buf309 = empty_strided_cuda((8192, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_539], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_8.run(buf297, buf309, 16777216, stream=stream0)
# --- Segment: fused layer-norm backward #4 (same workspace pattern) ---
buf316 = buf260; del buf260 # reuse
# Call mix order reduction kernel
# Topologically Sorted Source Nodes: [redistribute, layer_norm, convert_element_type_540, convert_element_type_542, mul_224, mul_225, sum_27, mul_226, sum_28, mul_227, sub_47, sub_48, div_8, mul_228, mul_229, sum_29, sum_30, convert_element_type_544, add_130], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.native_layer_norm_backward, aten.add]
workspace_5 = workspace_4; del workspace_4 # reuse
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_12.run(buf316, buf296, primals_137, add_86, buf275, buf276, workspace_5, 327680, 2048, stream=stream0)
buf313 = workspace_5[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
buf315 = workspace_5[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
del add_86
del primals_137
buf323 = reinterpret_tensor(buf284, (2048, ), (1, ), 0); del buf284 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_545, all_reduce_31], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf313, buf323, 2048, stream=stream0)
buf317 = buf261; del buf261 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_546, all_reduce_30], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf315, buf317, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_546, all_reduce_30], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf317, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_30], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf317)
buf322 = buf315; del buf315 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_547], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf317, buf322, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_545, all_reduce_31], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf323, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_31], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf323)
buf328 = buf313; del buf313 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_548], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf323, buf328, 2048, stream=stream0)
buf329 = buf276; del buf276 # reuse
buf330 = buf275; del buf275 # reuse
buf332 = buf296; del buf296 # reuse
# Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_native_layer_norm_4.run(add_83, primals_131, primals_132, buf329, buf330, buf332, 327680, 2048, stream=stream0)
del primals_132
buf333 = reinterpret_tensor(buf246, (2048, 2048), (1, 2048), 0); del buf246 # reuse
# Topologically Sorted Source Nodes: [redistribute_2, linear], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_13.run(primals_133, buf333, 4194304, stream=stream0)
del primals_133
buf334 = buf278; del buf278 # reuse
# Topologically Sorted Source Nodes: [linear], Original ATen: [aten.mm]
extern_kernels.mm(buf332, buf333, out=buf334)
buf335 = reinterpret_tensor(buf239, (2048, 512), (1, 2048), 0); del buf239 # reuse
# Topologically Sorted Source Nodes: [redistribute_3, linear_1], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_14.run(primals_134, buf335, 1048576, stream=stream0)
del primals_134
buf336 = reinterpret_tensor(buf229, (327680, 512), (512, 1), 0); del buf229 # reuse
# Topologically Sorted Source Nodes: [linear_1], Original ATen: [aten.mm]
extern_kernels.mm(buf332, buf335, out=buf336)
buf337 = buf208; del buf208 # reuse
# Topologically Sorted Source Nodes: [redistribute_4, linear_2], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_14.run(primals_135, buf337, 1048576, stream=stream0)
del primals_135
buf338 = reinterpret_tensor(buf228, (327680, 512), (512, 1), 0); del buf228 # reuse
# Topologically Sorted Source Nodes: [linear_2], Original ATen: [aten.mm]
extern_kernels.mm(buf332, buf337, out=buf338)
buf339 = buf226; del buf226 # reuse
buf340 = buf212; del buf212 # reuse
buf341 = reinterpret_tensor(buf233, (1, 16, 327680, 128), (671088640, 128, 2048, 1), 0); del buf233 # reuse
# Topologically Sorted Source Nodes: [q, k, v, flex_attention], Original ATen: [aten.view, aten.permute, flex_attention]
stream0 = get_raw_stream(0)
triton_tem_fused_flex_attention_permute_view_15.run(buf334, buf336, buf338, buf339, buf340, primals_18, primals_17, primals_19, primals_21, primals_14, primals_15, buf341, s91, 2560, 1, 16, stream=stream0)
buf344 = reinterpret_tensor(buf206, (2048, 2048), (2048, 1), 0); del buf206 # reuse
# Topologically Sorted Source Nodes: [permute_185, rearrange_3, o, mm_77], Original ATen: [aten.t, aten.permute, aten.view, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf316, (2048, 327680), (1, 2048), 0), reinterpret_tensor(buf341, (327680, 2048), (2048, 1), 0), out=buf344)
buf345 = buf218; del buf218 # reuse
# Topologically Sorted Source Nodes: [redistribute_5], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_13.run(primals_136, buf345, 4194304, stream=stream0)
del primals_136
buf346 = reinterpret_tensor(buf227, (327680, 2048), (2048, 1), 0); del buf227 # reuse
# Topologically Sorted Source Nodes: [redistribute_5, o, permute_187, mm_78], Original ATen: [aten._to_copy, aten.t, aten.mm]
extern_kernels.mm(buf316, buf345, out=buf346)
# Topologically Sorted Source Nodes: [all_reduce_32], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf344, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_32], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf344)
buf351 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_553], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_16.run(buf344, buf351, 4194304, stream=stream0)
buf353 = buf340; del buf340 # reuse
# Topologically Sorted Source Nodes: [q, k, v, view_307, view_308, permute_189, flex_attention_backward_2], Original ATen: [aten.view, aten.permute, flex_attention_backward]
stream0 = get_raw_stream(0)
triton_per_fused_flex_attention_backward_permute_view_17.run(buf341, buf346, buf353, 5242880, 128, stream=stream0)
buf354 = buf341; del buf341 # reuse
buf355 = empty_strided_cuda((1, 4, 327680, 128), (167772160, 128, 512, 1), torch.bfloat16)
buf356 = empty_strided_cuda((1, 4, 327680, 128), (167772160, 128, 512, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [q, k, v, view_307, view_308, permute_189, flex_attention_backward_2], Original ATen: [aten.view, aten.permute, flex_attention_backward]
stream0 = get_raw_stream(0)
triton_tem_fused_flex_attention_backward_permute_view_18.run(buf334, buf336, buf338, buf339, buf353, buf346, buf354, buf355, primals_18, primals_17, primals_22, primals_24, primals_19, primals_21, primals_25, primals_27, primals_14, primals_15, buf356, s91, s16, 12800, 1, 4, stream=stream0)
buf359 = reinterpret_tensor(buf210, (512, 2048), (2048, 1), 0); del buf210 # reuse
# Topologically Sorted Source Nodes: [view_309, permute_190, view_310, permute_191, mm_79], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf355, (512, 327680), (1, 512), 0), buf332, out=buf359)
buf360 = buf346; del buf346 # reuse
# Topologically Sorted Source Nodes: [view_309, permute_190, view_310, permute_193, mm_80], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf355, (327680, 512), (512, 1), 0), reinterpret_tensor(buf337, (512, 2048), (2048, 1), 0), out=buf360)
# Topologically Sorted Source Nodes: [all_reduce_33], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf359, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_33], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf359)
buf365 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_558], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_19.run(buf359, buf365, 1048576, stream=stream0)
buf366 = buf359; del buf359 # reuse
# Topologically Sorted Source Nodes: [view_311, permute_195, view_312, permute_196, mm_81], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf356, (512, 327680), (1, 512), 0), buf332, out=buf366)
buf367 = buf334; del buf334 # reuse
# Topologically Sorted Source Nodes: [view_311, permute_195, view_312, permute_198, mm_82], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf356, (327680, 512), (512, 1), 0), reinterpret_tensor(buf335, (512, 2048), (2048, 1), 0), out=buf367)
# Topologically Sorted Source Nodes: [all_reduce_34], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf366, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_34], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf366)
buf372 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_563], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_19.run(buf366, buf372, 1048576, stream=stream0)
buf373 = buf344; del buf344 # reuse
# Topologically Sorted Source Nodes: [view_313, permute_200, view_314, permute_201, mm_83], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf354, (2048, 327680), (1, 2048), 0), buf332, out=buf373)
buf374 = buf332; del buf332 # reuse
# Topologically Sorted Source Nodes: [view_313, permute_200, view_314, permute_203, mm_84], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf354, (327680, 2048), (2048, 1), 0), reinterpret_tensor(buf333, (2048, 2048), (2048, 1), 0), out=buf374)
# Topologically Sorted Source Nodes: [all_reduce_35], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf373, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_35], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf373)
buf379 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_568], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_16.run(buf373, buf379, 4194304, stream=stream0)
buf387 = buf316; del buf316 # reuse
# Call mix order reduction kernel
# Topologically Sorted Source Nodes: [redistribute, layer_norm, add_131, add_132, convert_element_type_569, convert_element_type_571, mul_231, mul_232, sum_31, mul_233, sum_32, mul_234, sub_50, sub_51, div_9, mul_235, mul_236, sum_33, sum_34, convert_element_type_573, add_133], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.add, aten.native_layer_norm_backward]
workspace_6 = workspace_5; del workspace_5 # reuse
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_20.run(buf387, buf360, buf367, buf374, primals_131, add_83, buf329, buf330, workspace_6, 327680, 2048, stream=stream0)
buf384 = workspace_6[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
buf386 = workspace_6[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
del add_83
del primals_131
buf394 = buf323; del buf323 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_574, all_reduce_37], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf384, buf394, 2048, stream=stream0)
buf388 = buf317; del buf317 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_575, all_reduce_36], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf386, buf388, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_575, all_reduce_36], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf388, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_36], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf388)
buf393 = buf386; del buf386 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_576], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf388, buf393, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_574, all_reduce_37], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf394, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_37], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf394)
buf399 = buf384; del buf384 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_577], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf394, buf399, 2048, stream=stream0)
buf400 = reinterpret_tensor(buf297, (2048, 8192), (8192, 1), 0); del buf297 # reuse
# Topologically Sorted Source Nodes: [redistribute_4], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_3.run(primals_129, buf400, 16777216, stream=stream0)
del primals_129
buf401 = buf295; del buf295 # reuse
# Topologically Sorted Source Nodes: [redistribute_4, x_20, permute_205, mm_85], Original ATen: [aten._to_copy, aten.t, aten.mm]
extern_kernels.mm(buf387, buf400, out=buf401)
buf402 = buf330; del buf330 # reuse
buf403 = buf329; del buf329 # reuse
buf405 = buf374; del buf374 # reuse
# Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_native_layer_norm_4.run(add_78, primals_125, primals_126, buf402, buf403, buf405, 327680, 2048, stream=stream0)
del primals_126
buf406 = reinterpret_tensor(buf400, (2048, 8192), (1, 2048), 0); del buf400 # reuse
# Topologically Sorted Source Nodes: [redistribute_2, x], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_3.run(primals_127, buf406, 16777216, stream=stream0)
del primals_127
buf407 = buf281; del buf281 # reuse
# Topologically Sorted Source Nodes: [x], Original ATen: [aten.addmm]
extern_kernels.mm(buf405, buf406, out=buf407)
buf408 = buf280; del buf280 # reuse
buf422 = buf401; del buf401 # reuse
# Topologically Sorted Source Nodes: [redistribute_3, x, x_19, convert_element_type_584, mul_242, mul_243, sub_52, mul_244, add_136, mul_245, mul_246, mul_247, add_137, mul_248, convert_element_type_586], Original ATen: [aten._to_copy, aten.addmm, aten.gelu, aten.gelu_backward]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_addmm_gelu_gelu_backward_5.run(buf422, primals_128, buf407, buf408, 2684354560, stream=stream0)
del primals_128
buf409 = buf282; del buf282 # reuse
# Topologically Sorted Source Nodes: [permute_206, redistribute_3, x, x_19, mm_86], Original ATen: [aten.t, aten._to_copy, aten.addmm, aten.gelu, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf387, (2048, 327680), (1, 2048), 0), buf408, out=buf409)
buf410 = buf283; del buf283 # reuse
# Topologically Sorted Source Nodes: [sum_35], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_6.run(buf387, buf410, 327680, 2048, stream=stream0)
buf411 = reinterpret_tensor(buf394, (1, 2048), (2048, 1), 0); del buf394 # reuse
# Topologically Sorted Source Nodes: [sum_35], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_7.run(buf410, buf411, 2048, 160, stream=stream0)
# Topologically Sorted Source Nodes: [view_315, all_reduce_38], Original ATen: [aten.view, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf411, (2048, ), (1, ), 0), 'avg', '0')
# Topologically Sorted Source Nodes: [view_315, wait_tensor_38], Original ATen: [aten.view, _c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf411, (2048, ), (1, ), 0))
buf416 = empty_strided_cuda((2048, ), (1, ), torch.float32)
# Topologically Sorted Source Nodes: [view_315, convert_element_type_582], Original ATen: [aten.view, aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf411, buf416, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [all_reduce_39], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf409, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_39], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf409)
buf421 = empty_strided_cuda((2048, 8192), (8192, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_583], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_8.run(buf409, buf421, 16777216, stream=stream0)
buf423 = buf367; del buf367 # reuse
# Topologically Sorted Source Nodes: [permute_209, mm_87], Original ATen: [aten.t, aten.mm]
extern_kernels.mm(buf422, reinterpret_tensor(buf406, (8192, 2048), (2048, 1), 0), out=buf423)
buf424 = reinterpret_tensor(buf406, (8192, 2048), (2048, 1), 0); del buf406 # reuse
# Topologically Sorted Source Nodes: [permute_210, mm_88], Original ATen: [aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf422, (8192, 327680), (1, 8192), 0), buf405, out=buf424)
buf425 = buf298; del buf298 # reuse
# Topologically Sorted Source Nodes: [sum_36], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_9.run(buf422, buf425, 1310720, 2048, stream=stream0)
buf426 = buf299; del buf299 # reuse
# Topologically Sorted Source Nodes: [sum_36], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_10.run(buf425, buf426, 8192, 160, stream=stream0)
# Topologically Sorted Source Nodes: [view_316, all_reduce_40], Original ATen: [aten.view, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf426, (8192, ), (1, ), 0), 'avg', '0')
# Topologically Sorted Source Nodes: [view_316, wait_tensor_40], Original ATen: [aten.view, _c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf426, (8192, ), (1, ), 0))
buf431 = empty_strided_cuda((8192, ), (1, ), torch.float32)
# Topologically Sorted Source Nodes: [view_316, convert_element_type_591], Original ATen: [aten.view, aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_view_11.run(buf426, buf431, 8192, stream=stream0)
# Topologically Sorted Source Nodes: [all_reduce_41], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf424, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_41], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf424)
buf436 = empty_strided_cuda((8192, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_592], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_8.run(buf424, buf436, 16777216, stream=stream0)
buf443 = buf387; del buf387 # reuse
# Call mix order reduction kernel
# Topologically Sorted Source Nodes: [redistribute, layer_norm, convert_element_type_593, convert_element_type_595, mul_250, mul_251, sum_37, mul_252, sum_38, mul_253, sub_54, sub_55, div_10, mul_254, mul_255, sum_39, sum_40, convert_element_type_597, add_138], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.native_layer_norm_backward, aten.add]
workspace_7 = workspace_6; del workspace_6 # reuse
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_12.run(buf443, buf423, primals_125, add_78, buf402, buf403, workspace_7, 327680, 2048, stream=stream0)
buf440 = workspace_7[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
buf442 = workspace_7[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
del add_78
del primals_125
buf450 = reinterpret_tensor(buf411, (2048, ), (1, ), 0); del buf411 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_598, all_reduce_43], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf440, buf450, 2048, stream=stream0)
buf444 = buf388; del buf388 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_599, all_reduce_42], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf442, buf444, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_599, all_reduce_42], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf444, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_42], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf444)
buf449 = buf442; del buf442 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_600], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf444, buf449, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_598, all_reduce_43], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf450, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_43], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf450)
buf455 = buf440; del buf440 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_601], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf450, buf455, 2048, stream=stream0)
buf456 = buf403; del buf403 # reuse
buf457 = buf402; del buf402 # reuse
buf459 = buf423; del buf423 # reuse
# Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_native_layer_norm_4.run(add_75, primals_119, primals_120, buf456, buf457, buf459, 327680, 2048, stream=stream0)
del primals_120
buf460 = reinterpret_tensor(buf373, (2048, 2048), (1, 2048), 0); del buf373 # reuse
# Topologically Sorted Source Nodes: [redistribute_2, linear], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_13.run(primals_121, buf460, 4194304, stream=stream0)
del primals_121
buf461 = buf405; del buf405 # reuse
# Topologically Sorted Source Nodes: [linear], Original ATen: [aten.mm]
extern_kernels.mm(buf459, buf460, out=buf461)
buf462 = reinterpret_tensor(buf366, (2048, 512), (1, 2048), 0); del buf366 # reuse
# Topologically Sorted Source Nodes: [redistribute_3, linear_1], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_14.run(primals_122, buf462, 1048576, stream=stream0)
del primals_122
buf463 = reinterpret_tensor(buf356, (327680, 512), (512, 1), 0); del buf356 # reuse
# Topologically Sorted Source Nodes: [linear_1], Original ATen: [aten.mm]
extern_kernels.mm(buf459, buf462, out=buf463)
buf464 = buf335; del buf335 # reuse
# Topologically Sorted Source Nodes: [redistribute_4, linear_2], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_14.run(primals_123, buf464, 1048576, stream=stream0)
del primals_123
buf465 = reinterpret_tensor(buf355, (327680, 512), (512, 1), 0); del buf355 # reuse
# Topologically Sorted Source Nodes: [linear_2], Original ATen: [aten.mm]
extern_kernels.mm(buf459, buf464, out=buf465)
buf466 = buf353; del buf353 # reuse
buf467 = buf339; del buf339 # reuse
buf468 = reinterpret_tensor(buf360, (1, 16, 327680, 128), (671088640, 128, 2048, 1), 0); del buf360 # reuse
# Topologically Sorted Source Nodes: [q, k, v, flex_attention], Original ATen: [aten.view, aten.permute, flex_attention]
stream0 = get_raw_stream(0)
triton_tem_fused_flex_attention_permute_view_15.run(buf461, buf463, buf465, buf466, buf467, primals_18, primals_17, primals_19, primals_21, primals_14, primals_15, buf468, s91, 2560, 1, 16, stream=stream0)
buf471 = reinterpret_tensor(buf333, (2048, 2048), (2048, 1), 0); del buf333 # reuse
# Topologically Sorted Source Nodes: [permute_213, rearrange_3, o, mm_89], Original ATen: [aten.t, aten.permute, aten.view, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf443, (2048, 327680), (1, 2048), 0), reinterpret_tensor(buf468, (327680, 2048), (2048, 1), 0), out=buf471)
buf472 = buf345; del buf345 # reuse
# Topologically Sorted Source Nodes: [redistribute_5], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_13.run(primals_124, buf472, 4194304, stream=stream0)
del primals_124
buf473 = reinterpret_tensor(buf354, (327680, 2048), (2048, 1), 0); del buf354 # reuse
# Topologically Sorted Source Nodes: [redistribute_5, o, permute_215, mm_90], Original ATen: [aten._to_copy, aten.t, aten.mm]
extern_kernels.mm(buf443, buf472, out=buf473)
# Topologically Sorted Source Nodes: [all_reduce_44], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf471, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_44], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf471)
buf478 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_606], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_16.run(buf471, buf478, 4194304, stream=stream0)
buf480 = buf467; del buf467 # reuse
# Topologically Sorted Source Nodes: [q, k, v, view_319, view_320, permute_217, flex_attention_backward_3], Original ATen: [aten.view, aten.permute, flex_attention_backward]
stream0 = get_raw_stream(0)
triton_per_fused_flex_attention_backward_permute_view_17.run(buf468, buf473, buf480, 5242880, 128, stream=stream0)
buf481 = buf468; del buf468 # reuse
buf482 = reinterpret_tensor(buf338, (1, 4, 327680, 128), (167772160, 128, 512, 1), 0); del buf338 # reuse
buf483 = reinterpret_tensor(buf336, (1, 4, 327680, 128), (167772160, 128, 512, 1), 0); del buf336 # reuse
# Topologically Sorted Source Nodes: [q, k, v, view_319, view_320, permute_217, flex_attention_backward_3], Original ATen: [aten.view, aten.permute, flex_attention_backward]
stream0 = get_raw_stream(0)
triton_tem_fused_flex_attention_backward_permute_view_18.run(buf461, buf463, buf465, buf466, buf480, buf473, buf481, buf482, primals_18, primals_17, primals_22, primals_24, primals_19, primals_21, primals_25, primals_27, primals_14, primals_15, buf483, s91, s16, 12800, 1, 4, stream=stream0)
buf486 = reinterpret_tensor(buf337, (512, 2048), (2048, 1), 0); del buf337 # reuse
# Topologically Sorted Source Nodes: [view_321, permute_218, view_322, permute_219, mm_91], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf482, (512, 327680), (1, 512), 0), buf459, out=buf486)
buf487 = buf473; del buf473 # reuse
# Topologically Sorted Source Nodes: [view_321, permute_218, view_322, permute_221, mm_92], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf482, (327680, 512), (512, 1), 0), reinterpret_tensor(buf464, (512, 2048), (2048, 1), 0), out=buf487)
# Topologically Sorted Source Nodes: [all_reduce_45], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf486, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_45], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf486)
buf492 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_611], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_19.run(buf486, buf492, 1048576, stream=stream0)
buf493 = buf486; del buf486 # reuse
# Topologically Sorted Source Nodes: [view_323, permute_223, view_324, permute_224, mm_93], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf483, (512, 327680), (1, 512), 0), buf459, out=buf493)
buf494 = buf461; del buf461 # reuse
# Topologically Sorted Source Nodes: [view_323, permute_223, view_324, permute_226, mm_94], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf483, (327680, 512), (512, 1), 0), reinterpret_tensor(buf462, (512, 2048), (2048, 1), 0), out=buf494)
# Topologically Sorted Source Nodes: [all_reduce_46], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf493, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_46], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf493)
buf499 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_616], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_19.run(buf493, buf499, 1048576, stream=stream0)
buf500 = buf471; del buf471 # reuse
# Topologically Sorted Source Nodes: [view_325, permute_228, view_326, permute_229, mm_95], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf481, (2048, 327680), (1, 2048), 0), buf459, out=buf500)
buf501 = buf459; del buf459 # reuse
# Topologically Sorted Source Nodes: [view_325, permute_228, view_326, permute_231, mm_96], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf481, (327680, 2048), (2048, 1), 0), reinterpret_tensor(buf460, (2048, 2048), (2048, 1), 0), out=buf501)
# Topologically Sorted Source Nodes: [all_reduce_47], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf500, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_47], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf500)
buf506 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_621], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_16.run(buf500, buf506, 4194304, stream=stream0)
buf514 = buf443; del buf443 # reuse
# Call mix order reduction kernel
# Topologically Sorted Source Nodes: [redistribute, layer_norm, add_139, add_140, convert_element_type_622, convert_element_type_624, mul_257, mul_258, sum_41, mul_259, sum_42, mul_260, sub_57, sub_58, div_11, mul_261, mul_262, sum_43, sum_44, convert_element_type_626, add_141], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.add, aten.native_layer_norm_backward]
workspace_8 = workspace_7; del workspace_7 # reuse
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_20.run(buf514, buf487, buf494, buf501, primals_119, add_75, buf456, buf457, workspace_8, 327680, 2048, stream=stream0)
buf511 = workspace_8[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
buf513 = workspace_8[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
del add_75
del primals_119
buf521 = buf450; del buf450 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_627, all_reduce_49], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf511, buf521, 2048, stream=stream0)
buf515 = buf444; del buf444 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_628, all_reduce_48], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf513, buf515, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_628, all_reduce_48], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf515, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_48], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf515)
buf520 = buf513; del buf513 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_629], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf515, buf520, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_627, all_reduce_49], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf521, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_49], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf521)
buf526 = buf511; del buf511 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_630], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf521, buf526, 2048, stream=stream0)
buf527 = reinterpret_tensor(buf424, (2048, 8192), (8192, 1), 0); del buf424 # reuse
# Topologically Sorted Source Nodes: [redistribute_4], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_3.run(primals_117, buf527, 16777216, stream=stream0)
del primals_117
buf528 = buf422; del buf422 # reuse
# Topologically Sorted Source Nodes: [redistribute_4, x_18, permute_233, mm_97], Original ATen: [aten._to_copy, aten.t, aten.mm]
extern_kernels.mm(buf514, buf527, out=buf528)
buf529 = buf457; del buf457 # reuse
buf530 = buf456; del buf456 # reuse
buf532 = buf501; del buf501 # reuse
# Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_native_layer_norm_4.run(add_70, primals_113, primals_114, buf529, buf530, buf532, 327680, 2048, stream=stream0)
del primals_114
buf533 = reinterpret_tensor(buf527, (2048, 8192), (1, 2048), 0); del buf527 # reuse
# Topologically Sorted Source Nodes: [redistribute_2, x], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_3.run(primals_115, buf533, 16777216, stream=stream0)
del primals_115
buf534 = buf408; del buf408 # reuse
# Topologically Sorted Source Nodes: [x], Original ATen: [aten.addmm]
extern_kernels.mm(buf532, buf533, out=buf534)
buf535 = buf407; del buf407 # reuse
buf549 = buf528; del buf528 # reuse
# Topologically Sorted Source Nodes: [redistribute_3, x, x_17, convert_element_type_637, mul_268, mul_269, sub_59, mul_270, add_144, mul_271, mul_272, mul_273, add_145, mul_274, convert_element_type_639], Original ATen: [aten._to_copy, aten.addmm, aten.gelu, aten.gelu_backward]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_addmm_gelu_gelu_backward_5.run(buf549, primals_116, buf534, buf535, 2684354560, stream=stream0)
del primals_116
buf536 = buf409; del buf409 # reuse
# Topologically Sorted Source Nodes: [permute_234, redistribute_3, x, x_17, mm_98], Original ATen: [aten.t, aten._to_copy, aten.addmm, aten.gelu, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf514, (2048, 327680), (1, 2048), 0), buf535, out=buf536)
buf537 = buf410; del buf410 # reuse
# Topologically Sorted Source Nodes: [sum_45], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_6.run(buf514, buf537, 327680, 2048, stream=stream0)
buf538 = reinterpret_tensor(buf521, (1, 2048), (2048, 1), 0); del buf521 # reuse
# Topologically Sorted Source Nodes: [sum_45], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_7.run(buf537, buf538, 2048, 160, stream=stream0)
# Topologically Sorted Source Nodes: [view_327, all_reduce_50], Original ATen: [aten.view, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf538, (2048, ), (1, ), 0), 'avg', '0')
# Topologically Sorted Source Nodes: [view_327, wait_tensor_50], Original ATen: [aten.view, _c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf538, (2048, ), (1, ), 0))
buf543 = empty_strided_cuda((2048, ), (1, ), torch.float32)
# Topologically Sorted Source Nodes: [view_327, convert_element_type_635], Original ATen: [aten.view, aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf538, buf543, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [all_reduce_51], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf536, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_51], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf536)
buf548 = empty_strided_cuda((2048, 8192), (8192, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_636], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_8.run(buf536, buf548, 16777216, stream=stream0)
buf550 = buf494; del buf494 # reuse
# Topologically Sorted Source Nodes: [permute_237, mm_99], Original ATen: [aten.t, aten.mm]
extern_kernels.mm(buf549, reinterpret_tensor(buf533, (8192, 2048), (2048, 1), 0), out=buf550)
buf551 = reinterpret_tensor(buf533, (8192, 2048), (2048, 1), 0); del buf533 # reuse
# Topologically Sorted Source Nodes: [permute_238, mm_100], Original ATen: [aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf549, (8192, 327680), (1, 8192), 0), buf532, out=buf551)
buf552 = buf425; del buf425 # reuse
# Topologically Sorted Source Nodes: [sum_46], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_9.run(buf549, buf552, 1310720, 2048, stream=stream0)
buf553 = buf426; del buf426 # reuse
# Topologically Sorted Source Nodes: [sum_46], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_10.run(buf552, buf553, 8192, 160, stream=stream0)
# Topologically Sorted Source Nodes: [view_328, all_reduce_52], Original ATen: [aten.view, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf553, (8192, ), (1, ), 0), 'avg', '0')
# Topologically Sorted Source Nodes: [view_328, wait_tensor_52], Original ATen: [aten.view, _c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf553, (8192, ), (1, ), 0))
buf558 = empty_strided_cuda((8192, ), (1, ), torch.float32)
# Topologically Sorted Source Nodes: [view_328, convert_element_type_644], Original ATen: [aten.view, aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_view_11.run(buf553, buf558, 8192, stream=stream0)
# Topologically Sorted Source Nodes: [all_reduce_53], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf551, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_53], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf551)
buf563 = empty_strided_cuda((8192, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_645], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_8.run(buf551, buf563, 16777216, stream=stream0)
buf570 = buf514; del buf514 # reuse
# Call mix order reduction kernel
# Topologically Sorted Source Nodes: [redistribute, layer_norm, convert_element_type_646, convert_element_type_648, mul_276, mul_277, sum_47, mul_278, sum_48, mul_279, sub_61, sub_62, div_12, mul_280, mul_281, sum_49, sum_50, convert_element_type_650, add_146], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.native_layer_norm_backward, aten.add]
workspace_9 = workspace_8; del workspace_8 # reuse
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_12.run(buf570, buf550, primals_113, add_70, buf529, buf530, workspace_9, 327680, 2048, stream=stream0)
buf567 = workspace_9[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
buf569 = workspace_9[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
del add_70
del primals_113
buf577 = reinterpret_tensor(buf538, (2048, ), (1, ), 0); del buf538 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_651, all_reduce_55], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf567, buf577, 2048, stream=stream0)
buf571 = buf515; del buf515 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_652, all_reduce_54], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf569, buf571, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_652, all_reduce_54], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf571, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_54], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf571)
buf576 = buf569; del buf569 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_653], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf571, buf576, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_651, all_reduce_55], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf577, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_55], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf577)
buf582 = buf567; del buf567 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_654], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf577, buf582, 2048, stream=stream0)
buf583 = buf530; del buf530 # reuse
buf584 = buf529; del buf529 # reuse
buf586 = buf550; del buf550 # reuse
# Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_native_layer_norm_4.run(add_67, primals_107, primals_108, buf583, buf584, buf586, 327680, 2048, stream=stream0)
del primals_108
buf587 = reinterpret_tensor(buf500, (2048, 2048), (1, 2048), 0); del buf500 # reuse
# Topologically Sorted Source Nodes: [redistribute_2, linear], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_13.run(primals_109, buf587, 4194304, stream=stream0)
del primals_109
buf588 = buf532; del buf532 # reuse
# Topologically Sorted Source Nodes: [linear], Original ATen: [aten.mm]
extern_kernels.mm(buf586, buf587, out=buf588)
buf589 = reinterpret_tensor(buf493, (2048, 512), (1, 2048), 0); del buf493 # reuse
# Topologically Sorted Source Nodes: [redistribute_3, linear_1], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_14.run(primals_110, buf589, 1048576, stream=stream0)
del primals_110
buf590 = reinterpret_tensor(buf483, (327680, 512), (512, 1), 0); del buf483 # reuse
# Topologically Sorted Source Nodes: [linear_1], Original ATen: [aten.mm]
extern_kernels.mm(buf586, buf589, out=buf590)
buf591 = buf462; del buf462 # reuse
# Topologically Sorted Source Nodes: [redistribute_4, linear_2], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_14.run(primals_111, buf591, 1048576, stream=stream0)
del primals_111
buf592 = reinterpret_tensor(buf482, (327680, 512), (512, 1), 0); del buf482 # reuse
# Topologically Sorted Source Nodes: [linear_2], Original ATen: [aten.mm]
extern_kernels.mm(buf586, buf591, out=buf592)
buf593 = buf480; del buf480 # reuse
buf594 = buf466; del buf466 # reuse
buf595 = reinterpret_tensor(buf487, (1, 16, 327680, 128), (671088640, 128, 2048, 1), 0); del buf487 # reuse
# Topologically Sorted Source Nodes: [q, k, v, flex_attention], Original ATen: [aten.view, aten.permute, flex_attention]
stream0 = get_raw_stream(0)
triton_tem_fused_flex_attention_permute_view_15.run(buf588, buf590, buf592, buf593, buf594, primals_18, primals_17, primals_19, primals_21, primals_14, primals_15, buf595, s91, 2560, 1, 16, stream=stream0)
buf598 = reinterpret_tensor(buf460, (2048, 2048), (2048, 1), 0); del buf460 # reuse
# Topologically Sorted Source Nodes: [permute_241, rearrange_3, o, mm_101], Original ATen: [aten.t, aten.permute, aten.view, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf570, (2048, 327680), (1, 2048), 0), reinterpret_tensor(buf595, (327680, 2048), (2048, 1), 0), out=buf598)
buf599 = buf472; del buf472 # reuse
# Topologically Sorted Source Nodes: [redistribute_5], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_13.run(primals_112, buf599, 4194304, stream=stream0)
del primals_112
buf600 = reinterpret_tensor(buf481, (327680, 2048), (2048, 1), 0); del buf481 # reuse
# Topologically Sorted Source Nodes: [redistribute_5, o, permute_243, mm_102], Original ATen: [aten._to_copy, aten.t, aten.mm]
extern_kernels.mm(buf570, buf599, out=buf600)
# Topologically Sorted Source Nodes: [all_reduce_56], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf598, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_56], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf598)
buf605 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_659], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_16.run(buf598, buf605, 4194304, stream=stream0)
buf607 = buf594; del buf594 # reuse
# Topologically Sorted Source Nodes: [q, k, v, view_331, view_332, permute_245, flex_attention_backward_4], Original ATen: [aten.view, aten.permute, flex_attention_backward]
stream0 = get_raw_stream(0)
triton_per_fused_flex_attention_backward_permute_view_17.run(buf595, buf600, buf607, 5242880, 128, stream=stream0)
buf608 = buf595; del buf595 # reuse
buf609 = reinterpret_tensor(buf465, (1, 4, 327680, 128), (167772160, 128, 512, 1), 0); del buf465 # reuse
buf610 = reinterpret_tensor(buf463, (1, 4, 327680, 128), (167772160, 128, 512, 1), 0); del buf463 # reuse
# Topologically Sorted Source Nodes: [q, k, v, view_331, view_332, permute_245, flex_attention_backward_4], Original ATen: [aten.view, aten.permute, flex_attention_backward]
stream0 = get_raw_stream(0)
triton_tem_fused_flex_attention_backward_permute_view_18.run(buf588, buf590, buf592, buf593, buf607, buf600, buf608, buf609, primals_18, primals_17, primals_22, primals_24, primals_19, primals_21, primals_25, primals_27, primals_14, primals_15, buf610, s91, s16, 12800, 1, 4, stream=stream0)
buf613 = reinterpret_tensor(buf464, (512, 2048), (2048, 1), 0); del buf464 # reuse
# Topologically Sorted Source Nodes: [view_333, permute_246, view_334, permute_247, mm_103], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf609, (512, 327680), (1, 512), 0), buf586, out=buf613)
buf614 = buf600; del buf600 # reuse
# Topologically Sorted Source Nodes: [view_333, permute_246, view_334, permute_249, mm_104], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf609, (327680, 512), (512, 1), 0), reinterpret_tensor(buf591, (512, 2048), (2048, 1), 0), out=buf614)
# Topologically Sorted Source Nodes: [all_reduce_57], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf613, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_57], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf613)
buf619 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_664], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_19.run(buf613, buf619, 1048576, stream=stream0)
buf620 = buf613; del buf613 # reuse
# Topologically Sorted Source Nodes: [view_335, permute_251, view_336, permute_252, mm_105], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf610, (512, 327680), (1, 512), 0), buf586, out=buf620)
buf621 = buf588; del buf588 # reuse
# Topologically Sorted Source Nodes: [view_335, permute_251, view_336, permute_254, mm_106], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf610, (327680, 512), (512, 1), 0), reinterpret_tensor(buf589, (512, 2048), (2048, 1), 0), out=buf621)
# Topologically Sorted Source Nodes: [all_reduce_58], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf620, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_58], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf620)
buf626 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_669], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_19.run(buf620, buf626, 1048576, stream=stream0)
buf627 = buf598; del buf598 # reuse
# Topologically Sorted Source Nodes: [view_337, permute_256, view_338, permute_257, mm_107], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf608, (2048, 327680), (1, 2048), 0), buf586, out=buf627)
buf628 = buf586; del buf586 # reuse
# Topologically Sorted Source Nodes: [view_337, permute_256, view_338, permute_259, mm_108], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf608, (327680, 2048), (2048, 1), 0), reinterpret_tensor(buf587, (2048, 2048), (2048, 1), 0), out=buf628)
# Topologically Sorted Source Nodes: [all_reduce_59], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf627, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_59], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf627)
buf633 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_674], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_16.run(buf627, buf633, 4194304, stream=stream0)
buf641 = buf570; del buf570 # reuse
# Call mix order reduction kernel
# Topologically Sorted Source Nodes: [redistribute, layer_norm, add_147, add_148, convert_element_type_675, convert_element_type_677, mul_283, mul_284, sum_51, mul_285, sum_52, mul_286, sub_64, sub_65, div_13, mul_287, mul_288, sum_53, sum_54, convert_element_type_679, add_149], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.add, aten.native_layer_norm_backward]
workspace_10 = workspace_9; del workspace_9 # reuse
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_20.run(buf641, buf614, buf621, buf628, primals_107, add_67, buf583, buf584, workspace_10, 327680, 2048, stream=stream0)
buf638 = workspace_10[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
buf640 = workspace_10[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
del add_67
del primals_107
buf648 = buf577; del buf577 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_680, all_reduce_61], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf638, buf648, 2048, stream=stream0)
buf642 = buf571; del buf571 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_681, all_reduce_60], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf640, buf642, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_681, all_reduce_60], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf642, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_60], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf642)
buf647 = buf640; del buf640 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_682], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf642, buf647, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_680, all_reduce_61], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf648, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_61], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf648)
buf653 = buf638; del buf638 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_683], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf648, buf653, 2048, stream=stream0)
buf654 = reinterpret_tensor(buf551, (2048, 8192), (8192, 1), 0); del buf551 # reuse
# Topologically Sorted Source Nodes: [redistribute_4], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_3.run(primals_105, buf654, 16777216, stream=stream0)
del primals_105
buf655 = buf549; del buf549 # reuse
# Topologically Sorted Source Nodes: [redistribute_4, x_16, permute_261, mm_109], Original ATen: [aten._to_copy, aten.t, aten.mm]
extern_kernels.mm(buf641, buf654, out=buf655)
buf656 = buf584; del buf584 # reuse
buf657 = buf583; del buf583 # reuse
buf659 = buf628; del buf628 # reuse
# Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_native_layer_norm_4.run(add_62, primals_101, primals_102, buf656, buf657, buf659, 327680, 2048, stream=stream0)
del primals_102
buf660 = reinterpret_tensor(buf654, (2048, 8192), (1, 2048), 0); del buf654 # reuse
# Topologically Sorted Source Nodes: [redistribute_2, x], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_3.run(primals_103, buf660, 16777216, stream=stream0)
del primals_103
buf661 = buf535; del buf535 # reuse
# Topologically Sorted Source Nodes: [x], Original ATen: [aten.addmm]
extern_kernels.mm(buf659, buf660, out=buf661)
buf662 = buf534; del buf534 # reuse
buf676 = buf655; del buf655 # reuse
# Topologically Sorted Source Nodes: [redistribute_3, x, x_15, convert_element_type_690, mul_294, mul_295, sub_66, mul_296, add_152, mul_297, mul_298, mul_299, add_153, mul_300, convert_element_type_692], Original ATen: [aten._to_copy, aten.addmm, aten.gelu, aten.gelu_backward]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_addmm_gelu_gelu_backward_5.run(buf676, primals_104, buf661, buf662, 2684354560, stream=stream0)
del primals_104
buf663 = buf536; del buf536 # reuse
# Topologically Sorted Source Nodes: [permute_262, redistribute_3, x, x_15, mm_110], Original ATen: [aten.t, aten._to_copy, aten.addmm, aten.gelu, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf641, (2048, 327680), (1, 2048), 0), buf662, out=buf663)
buf664 = buf537; del buf537 # reuse
# Topologically Sorted Source Nodes: [sum_55], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_6.run(buf641, buf664, 327680, 2048, stream=stream0)
buf665 = reinterpret_tensor(buf648, (1, 2048), (2048, 1), 0); del buf648 # reuse
# Topologically Sorted Source Nodes: [sum_55], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_7.run(buf664, buf665, 2048, 160, stream=stream0)
# Topologically Sorted Source Nodes: [view_339, all_reduce_62], Original ATen: [aten.view, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf665, (2048, ), (1, ), 0), 'avg', '0')
# Topologically Sorted Source Nodes: [view_339, wait_tensor_62], Original ATen: [aten.view, _c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf665, (2048, ), (1, ), 0))
buf670 = empty_strided_cuda((2048, ), (1, ), torch.float32)
# Topologically Sorted Source Nodes: [view_339, convert_element_type_688], Original ATen: [aten.view, aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf665, buf670, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [all_reduce_63], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf663, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_63], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf663)
buf675 = empty_strided_cuda((2048, 8192), (8192, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_689], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_8.run(buf663, buf675, 16777216, stream=stream0)
buf677 = buf621; del buf621 # reuse
# Topologically Sorted Source Nodes: [permute_265, mm_111], Original ATen: [aten.t, aten.mm]
extern_kernels.mm(buf676, reinterpret_tensor(buf660, (8192, 2048), (2048, 1), 0), out=buf677)
buf678 = reinterpret_tensor(buf660, (8192, 2048), (2048, 1), 0); del buf660 # reuse
# Topologically Sorted Source Nodes: [permute_266, mm_112], Original ATen: [aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf676, (8192, 327680), (1, 8192), 0), buf659, out=buf678)
buf679 = buf552; del buf552 # reuse
# Topologically Sorted Source Nodes: [sum_56], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_9.run(buf676, buf679, 1310720, 2048, stream=stream0)
buf680 = buf553; del buf553 # reuse
# Topologically Sorted Source Nodes: [sum_56], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_10.run(buf679, buf680, 8192, 160, stream=stream0)
# Topologically Sorted Source Nodes: [view_340, all_reduce_64], Original ATen: [aten.view, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf680, (8192, ), (1, ), 0), 'avg', '0')
# Topologically Sorted Source Nodes: [view_340, wait_tensor_64], Original ATen: [aten.view, _c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf680, (8192, ), (1, ), 0))
buf685 = empty_strided_cuda((8192, ), (1, ), torch.float32)
# Topologically Sorted Source Nodes: [view_340, convert_element_type_697], Original ATen: [aten.view, aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_view_11.run(buf680, buf685, 8192, stream=stream0)
# Topologically Sorted Source Nodes: [all_reduce_65], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf678, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_65], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf678)
buf690 = empty_strided_cuda((8192, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_698], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_8.run(buf678, buf690, 16777216, stream=stream0)
buf697 = buf641; del buf641 # reuse
# Call mix order reduction kernel
# Topologically Sorted Source Nodes: [redistribute, layer_norm, convert_element_type_699, convert_element_type_701, mul_302, mul_303, sum_57, mul_304, sum_58, mul_305, sub_68, sub_69, div_14, mul_306, mul_307, sum_59, sum_60, convert_element_type_703, add_154], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.native_layer_norm_backward, aten.add]
workspace_11 = workspace_10; del workspace_10 # reuse
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_12.run(buf697, buf677, primals_101, add_62, buf656, buf657, workspace_11, 327680, 2048, stream=stream0)
buf694 = workspace_11[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
buf696 = workspace_11[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
del add_62
del primals_101
buf704 = reinterpret_tensor(buf665, (2048, ), (1, ), 0); del buf665 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_704, all_reduce_67], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf694, buf704, 2048, stream=stream0)
buf698 = buf642; del buf642 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_705, all_reduce_66], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf696, buf698, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_705, all_reduce_66], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf698, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_66], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf698)
buf703 = buf696; del buf696 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_706], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf698, buf703, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_704, all_reduce_67], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf704, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_67], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf704)
buf709 = buf694; del buf694 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_707], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf704, buf709, 2048, stream=stream0)
buf710 = buf657; del buf657 # reuse
buf711 = buf656; del buf656 # reuse
buf713 = buf677; del buf677 # reuse
# Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_native_layer_norm_4.run(add_59, primals_95, primals_96, buf710, buf711, buf713, 327680, 2048, stream=stream0)
del primals_96
buf714 = reinterpret_tensor(buf627, (2048, 2048), (1, 2048), 0); del buf627 # reuse
# Topologically Sorted Source Nodes: [redistribute_2, linear], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_13.run(primals_97, buf714, 4194304, stream=stream0)
del primals_97
buf715 = buf659; del buf659 # reuse
# Topologically Sorted Source Nodes: [linear], Original ATen: [aten.mm]
extern_kernels.mm(buf713, buf714, out=buf715)
buf716 = reinterpret_tensor(buf620, (2048, 512), (1, 2048), 0); del buf620 # reuse
# Topologically Sorted Source Nodes: [redistribute_3, linear_1], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_14.run(primals_98, buf716, 1048576, stream=stream0)
del primals_98
buf717 = reinterpret_tensor(buf610, (327680, 512), (512, 1), 0); del buf610 # reuse
# Topologically Sorted Source Nodes: [linear_1], Original ATen: [aten.mm]
extern_kernels.mm(buf713, buf716, out=buf717)
buf718 = buf589; del buf589 # reuse
# Topologically Sorted Source Nodes: [redistribute_4, linear_2], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_14.run(primals_99, buf718, 1048576, stream=stream0)
del primals_99
buf719 = reinterpret_tensor(buf609, (327680, 512), (512, 1), 0); del buf609 # reuse
# Topologically Sorted Source Nodes: [linear_2], Original ATen: [aten.mm]
extern_kernels.mm(buf713, buf718, out=buf719)
buf720 = buf607; del buf607 # reuse
buf721 = buf593; del buf593 # reuse
buf722 = reinterpret_tensor(buf614, (1, 16, 327680, 128), (671088640, 128, 2048, 1), 0); del buf614 # reuse
# Topologically Sorted Source Nodes: [q, k, v, flex_attention], Original ATen: [aten.view, aten.permute, flex_attention]
stream0 = get_raw_stream(0)
triton_tem_fused_flex_attention_permute_view_15.run(buf715, buf717, buf719, buf720, buf721, primals_18, primals_17, primals_19, primals_21, primals_14, primals_15, buf722, s91, 2560, 1, 16, stream=stream0)
buf725 = reinterpret_tensor(buf587, (2048, 2048), (2048, 1), 0); del buf587 # reuse
# Topologically Sorted Source Nodes: [permute_269, rearrange_3, o, mm_113], Original ATen: [aten.t, aten.permute, aten.view, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf697, (2048, 327680), (1, 2048), 0), reinterpret_tensor(buf722, (327680, 2048), (2048, 1), 0), out=buf725)
buf726 = buf599; del buf599 # reuse
# Topologically Sorted Source Nodes: [redistribute_5], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_13.run(primals_100, buf726, 4194304, stream=stream0)
del primals_100
buf727 = reinterpret_tensor(buf608, (327680, 2048), (2048, 1), 0); del buf608 # reuse
# Topologically Sorted Source Nodes: [redistribute_5, o, permute_271, mm_114], Original ATen: [aten._to_copy, aten.t, aten.mm]
extern_kernels.mm(buf697, buf726, out=buf727)
# Topologically Sorted Source Nodes: [all_reduce_68], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf725, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_68], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf725)
buf732 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_712], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_16.run(buf725, buf732, 4194304, stream=stream0)
buf734 = buf721; del buf721 # reuse
# Topologically Sorted Source Nodes: [q, k, v, view_343, view_344, permute_273, flex_attention_backward_5], Original ATen: [aten.view, aten.permute, flex_attention_backward]
stream0 = get_raw_stream(0)
triton_per_fused_flex_attention_backward_permute_view_17.run(buf722, buf727, buf734, 5242880, 128, stream=stream0)
buf735 = buf722; del buf722 # reuse
buf736 = reinterpret_tensor(buf592, (1, 4, 327680, 128), (167772160, 128, 512, 1), 0); del buf592 # reuse
buf737 = reinterpret_tensor(buf590, (1, 4, 327680, 128), (167772160, 128, 512, 1), 0); del buf590 # reuse
# Topologically Sorted Source Nodes: [q, k, v, view_343, view_344, permute_273, flex_attention_backward_5], Original ATen: [aten.view, aten.permute, flex_attention_backward]
stream0 = get_raw_stream(0)
triton_tem_fused_flex_attention_backward_permute_view_18.run(buf715, buf717, buf719, buf720, buf734, buf727, buf735, buf736, primals_18, primals_17, primals_22, primals_24, primals_19, primals_21, primals_25, primals_27, primals_14, primals_15, buf737, s91, s16, 12800, 1, 4, stream=stream0)
buf740 = reinterpret_tensor(buf591, (512, 2048), (2048, 1), 0); del buf591 # reuse
# Topologically Sorted Source Nodes: [view_345, permute_274, view_346, permute_275, mm_115], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf736, (512, 327680), (1, 512), 0), buf713, out=buf740)
buf741 = buf727; del buf727 # reuse
# Topologically Sorted Source Nodes: [view_345, permute_274, view_346, permute_277, mm_116], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf736, (327680, 512), (512, 1), 0), reinterpret_tensor(buf718, (512, 2048), (2048, 1), 0), out=buf741)
# Topologically Sorted Source Nodes: [all_reduce_69], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf740, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_69], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf740)
buf746 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_717], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_19.run(buf740, buf746, 1048576, stream=stream0)
buf747 = buf740; del buf740 # reuse
# Topologically Sorted Source Nodes: [view_347, permute_279, view_348, permute_280, mm_117], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf737, (512, 327680), (1, 512), 0), buf713, out=buf747)
buf748 = buf715; del buf715 # reuse
# Topologically Sorted Source Nodes: [view_347, permute_279, view_348, permute_282, mm_118], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf737, (327680, 512), (512, 1), 0), reinterpret_tensor(buf716, (512, 2048), (2048, 1), 0), out=buf748)
# Topologically Sorted Source Nodes: [all_reduce_70], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf747, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_70], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf747)
buf753 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_722], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_19.run(buf747, buf753, 1048576, stream=stream0)
buf754 = buf725; del buf725 # reuse
# Topologically Sorted Source Nodes: [view_349, permute_284, view_350, permute_285, mm_119], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf735, (2048, 327680), (1, 2048), 0), buf713, out=buf754)
buf755 = buf713; del buf713 # reuse
# Topologically Sorted Source Nodes: [view_349, permute_284, view_350, permute_287, mm_120], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf735, (327680, 2048), (2048, 1), 0), reinterpret_tensor(buf714, (2048, 2048), (2048, 1), 0), out=buf755)
# Topologically Sorted Source Nodes: [all_reduce_71], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf754, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_71], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf754)
buf760 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_727], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_16.run(buf754, buf760, 4194304, stream=stream0)
buf768 = buf697; del buf697 # reuse
# Call mix order reduction kernel
# Topologically Sorted Source Nodes: [redistribute, layer_norm, add_155, add_156, convert_element_type_728, convert_element_type_730, mul_309, mul_310, sum_61, mul_311, sum_62, mul_312, sub_71, sub_72, div_15, mul_313, mul_314, sum_63, sum_64, convert_element_type_732, add_157], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.add, aten.native_layer_norm_backward]
workspace_12 = workspace_11; del workspace_11 # reuse
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_20.run(buf768, buf741, buf748, buf755, primals_95, add_59, buf710, buf711, workspace_12, 327680, 2048, stream=stream0)
buf765 = workspace_12[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
buf767 = workspace_12[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
del add_59
del primals_95
buf775 = buf704; del buf704 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_733, all_reduce_73], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf765, buf775, 2048, stream=stream0)
buf769 = buf698; del buf698 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_734, all_reduce_72], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf767, buf769, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_734, all_reduce_72], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf769, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_72], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf769)
buf774 = buf767; del buf767 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_735], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf769, buf774, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_733, all_reduce_73], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf775, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_73], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf775)
buf780 = buf765; del buf765 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_736], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf775, buf780, 2048, stream=stream0)
buf781 = reinterpret_tensor(buf678, (2048, 8192), (8192, 1), 0); del buf678 # reuse
# Topologically Sorted Source Nodes: [redistribute_4], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_3.run(primals_93, buf781, 16777216, stream=stream0)
del primals_93
buf782 = buf676; del buf676 # reuse
# Topologically Sorted Source Nodes: [redistribute_4, x_14, permute_289, mm_121], Original ATen: [aten._to_copy, aten.t, aten.mm]
extern_kernels.mm(buf768, buf781, out=buf782)
buf783 = buf711; del buf711 # reuse
buf784 = buf710; del buf710 # reuse
buf786 = buf755; del buf755 # reuse
# Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_native_layer_norm_4.run(add_54, primals_89, primals_90, buf783, buf784, buf786, 327680, 2048, stream=stream0)
del primals_90
buf787 = reinterpret_tensor(buf781, (2048, 8192), (1, 2048), 0); del buf781 # reuse
# Topologically Sorted Source Nodes: [redistribute_2, x], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_3.run(primals_91, buf787, 16777216, stream=stream0)
del primals_91
buf788 = buf662; del buf662 # reuse
# Topologically Sorted Source Nodes: [x], Original ATen: [aten.addmm]
extern_kernels.mm(buf786, buf787, out=buf788)
buf789 = buf661; del buf661 # reuse
buf803 = buf782; del buf782 # reuse
# Topologically Sorted Source Nodes: [redistribute_3, x, x_13, convert_element_type_743, mul_320, mul_321, sub_73, mul_322, add_160, mul_323, mul_324, mul_325, add_161, mul_326, convert_element_type_745], Original ATen: [aten._to_copy, aten.addmm, aten.gelu, aten.gelu_backward]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_addmm_gelu_gelu_backward_5.run(buf803, primals_92, buf788, buf789, 2684354560, stream=stream0)
del primals_92
buf790 = buf663; del buf663 # reuse
# Topologically Sorted Source Nodes: [permute_290, redistribute_3, x, x_13, mm_122], Original ATen: [aten.t, aten._to_copy, aten.addmm, aten.gelu, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf768, (2048, 327680), (1, 2048), 0), buf789, out=buf790)
buf791 = buf664; del buf664 # reuse
# Topologically Sorted Source Nodes: [sum_65], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_6.run(buf768, buf791, 327680, 2048, stream=stream0)
buf792 = reinterpret_tensor(buf775, (1, 2048), (2048, 1), 0); del buf775 # reuse
# Topologically Sorted Source Nodes: [sum_65], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_7.run(buf791, buf792, 2048, 160, stream=stream0)
# Topologically Sorted Source Nodes: [view_351, all_reduce_74], Original ATen: [aten.view, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf792, (2048, ), (1, ), 0), 'avg', '0')
# Topologically Sorted Source Nodes: [view_351, wait_tensor_74], Original ATen: [aten.view, _c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf792, (2048, ), (1, ), 0))
buf797 = empty_strided_cuda((2048, ), (1, ), torch.float32)
# Topologically Sorted Source Nodes: [view_351, convert_element_type_741], Original ATen: [aten.view, aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf792, buf797, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [all_reduce_75], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf790, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_75], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf790)
buf802 = empty_strided_cuda((2048, 8192), (8192, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_742], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_8.run(buf790, buf802, 16777216, stream=stream0)
buf804 = buf748; del buf748 # reuse
# Topologically Sorted Source Nodes: [permute_293, mm_123], Original ATen: [aten.t, aten.mm]
extern_kernels.mm(buf803, reinterpret_tensor(buf787, (8192, 2048), (2048, 1), 0), out=buf804)
buf805 = reinterpret_tensor(buf787, (8192, 2048), (2048, 1), 0); del buf787 # reuse
# Topologically Sorted Source Nodes: [permute_294, mm_124], Original ATen: [aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf803, (8192, 327680), (1, 8192), 0), buf786, out=buf805)
buf806 = buf679; del buf679 # reuse
# Topologically Sorted Source Nodes: [sum_66], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_9.run(buf803, buf806, 1310720, 2048, stream=stream0)
buf807 = buf680; del buf680 # reuse
# Topologically Sorted Source Nodes: [sum_66], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_10.run(buf806, buf807, 8192, 160, stream=stream0)
# Topologically Sorted Source Nodes: [view_352, all_reduce_76], Original ATen: [aten.view, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf807, (8192, ), (1, ), 0), 'avg', '0')
# Topologically Sorted Source Nodes: [view_352, wait_tensor_76], Original ATen: [aten.view, _c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf807, (8192, ), (1, ), 0))
buf812 = empty_strided_cuda((8192, ), (1, ), torch.float32)
# Topologically Sorted Source Nodes: [view_352, convert_element_type_750], Original ATen: [aten.view, aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_view_11.run(buf807, buf812, 8192, stream=stream0)
# Topologically Sorted Source Nodes: [all_reduce_77], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf805, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_77], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf805)
buf817 = empty_strided_cuda((8192, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_751], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_8.run(buf805, buf817, 16777216, stream=stream0)
buf824 = buf768; del buf768 # reuse
# Call mix order reduction kernel
# Topologically Sorted Source Nodes: [redistribute, layer_norm, convert_element_type_752, convert_element_type_754, mul_328, mul_329, sum_67, mul_330, sum_68, mul_331, sub_75, sub_76, div_16, mul_332, mul_333, sum_69, sum_70, convert_element_type_756, add_162], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.native_layer_norm_backward, aten.add]
workspace_13 = workspace_12; del workspace_12 # reuse
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_12.run(buf824, buf804, primals_89, add_54, buf783, buf784, workspace_13, 327680, 2048, stream=stream0)
buf821 = workspace_13[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
buf823 = workspace_13[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
del add_54
del primals_89
buf831 = reinterpret_tensor(buf792, (2048, ), (1, ), 0); del buf792 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_757, all_reduce_79], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf821, buf831, 2048, stream=stream0)
buf825 = buf769; del buf769 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_758, all_reduce_78], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf823, buf825, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_758, all_reduce_78], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf825, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_78], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf825)
buf830 = buf823; del buf823 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_759], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf825, buf830, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_757, all_reduce_79], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf831, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_79], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf831)
buf836 = buf821; del buf821 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_760], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf831, buf836, 2048, stream=stream0)
buf837 = buf784; del buf784 # reuse
buf838 = buf783; del buf783 # reuse
buf840 = buf804; del buf804 # reuse
# Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_native_layer_norm_4.run(add_51, primals_83, primals_84, buf837, buf838, buf840, 327680, 2048, stream=stream0)
del primals_84
buf841 = reinterpret_tensor(buf754, (2048, 2048), (1, 2048), 0); del buf754 # reuse
# Topologically Sorted Source Nodes: [redistribute_2, linear], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_13.run(primals_85, buf841, 4194304, stream=stream0)
del primals_85
buf842 = buf786; del buf786 # reuse
# Topologically Sorted Source Nodes: [linear], Original ATen: [aten.mm]
extern_kernels.mm(buf840, buf841, out=buf842)
buf843 = reinterpret_tensor(buf747, (2048, 512), (1, 2048), 0); del buf747 # reuse
# Topologically Sorted Source Nodes: [redistribute_3, linear_1], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_14.run(primals_86, buf843, 1048576, stream=stream0)
del primals_86
buf844 = reinterpret_tensor(buf737, (327680, 512), (512, 1), 0); del buf737 # reuse
# Topologically Sorted Source Nodes: [linear_1], Original ATen: [aten.mm]
extern_kernels.mm(buf840, buf843, out=buf844)
buf845 = buf716; del buf716 # reuse
# Topologically Sorted Source Nodes: [redistribute_4, linear_2], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_14.run(primals_87, buf845, 1048576, stream=stream0)
del primals_87
buf846 = reinterpret_tensor(buf736, (327680, 512), (512, 1), 0); del buf736 # reuse
# Topologically Sorted Source Nodes: [linear_2], Original ATen: [aten.mm]
extern_kernels.mm(buf840, buf845, out=buf846)
buf847 = buf734; del buf734 # reuse
buf848 = buf720; del buf720 # reuse
buf849 = reinterpret_tensor(buf741, (1, 16, 327680, 128), (671088640, 128, 2048, 1), 0); del buf741 # reuse
# Topologically Sorted Source Nodes: [q, k, v, flex_attention], Original ATen: [aten.view, aten.permute, flex_attention]
stream0 = get_raw_stream(0)
triton_tem_fused_flex_attention_permute_view_15.run(buf842, buf844, buf846, buf847, buf848, primals_18, primals_17, primals_19, primals_21, primals_14, primals_15, buf849, s91, 2560, 1, 16, stream=stream0)
buf852 = reinterpret_tensor(buf714, (2048, 2048), (2048, 1), 0); del buf714 # reuse
# Topologically Sorted Source Nodes: [permute_297, rearrange_3, o, mm_125], Original ATen: [aten.t, aten.permute, aten.view, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf824, (2048, 327680), (1, 2048), 0), reinterpret_tensor(buf849, (327680, 2048), (2048, 1), 0), out=buf852)
buf853 = buf726; del buf726 # reuse
# Topologically Sorted Source Nodes: [redistribute_5], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_13.run(primals_88, buf853, 4194304, stream=stream0)
del primals_88
buf854 = reinterpret_tensor(buf735, (327680, 2048), (2048, 1), 0); del buf735 # reuse
# Topologically Sorted Source Nodes: [redistribute_5, o, permute_299, mm_126], Original ATen: [aten._to_copy, aten.t, aten.mm]
extern_kernels.mm(buf824, buf853, out=buf854)
# Topologically Sorted Source Nodes: [all_reduce_80], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf852, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_80], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf852)
buf859 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_765], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_16.run(buf852, buf859, 4194304, stream=stream0)
buf861 = buf848; del buf848 # reuse
# Topologically Sorted Source Nodes: [q, k, v, view_355, view_356, permute_301, flex_attention_backward_6], Original ATen: [aten.view, aten.permute, flex_attention_backward]
stream0 = get_raw_stream(0)
triton_per_fused_flex_attention_backward_permute_view_17.run(buf849, buf854, buf861, 5242880, 128, stream=stream0)
buf862 = buf849; del buf849 # reuse
buf863 = reinterpret_tensor(buf719, (1, 4, 327680, 128), (167772160, 128, 512, 1), 0); del buf719 # reuse
buf864 = reinterpret_tensor(buf717, (1, 4, 327680, 128), (167772160, 128, 512, 1), 0); del buf717 # reuse
# Topologically Sorted Source Nodes: [q, k, v, view_355, view_356, permute_301, flex_attention_backward_6], Original ATen: [aten.view, aten.permute, flex_attention_backward]
stream0 = get_raw_stream(0)
triton_tem_fused_flex_attention_backward_permute_view_18.run(buf842, buf844, buf846, buf847, buf861, buf854, buf862, buf863, primals_18, primals_17, primals_22, primals_24, primals_19, primals_21, primals_25, primals_27, primals_14, primals_15, buf864, s91, s16, 12800, 1, 4, stream=stream0)
buf867 = reinterpret_tensor(buf718, (512, 2048), (2048, 1), 0); del buf718 # reuse
# Topologically Sorted Source Nodes: [view_357, permute_302, view_358, permute_303, mm_127], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf863, (512, 327680), (1, 512), 0), buf840, out=buf867)
buf868 = buf854; del buf854 # reuse
# Topologically Sorted Source Nodes: [view_357, permute_302, view_358, permute_305, mm_128], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf863, (327680, 512), (512, 1), 0), reinterpret_tensor(buf845, (512, 2048), (2048, 1), 0), out=buf868)
# Topologically Sorted Source Nodes: [all_reduce_81], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf867, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_81], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf867)
buf873 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_770], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_19.run(buf867, buf873, 1048576, stream=stream0)
buf874 = buf867; del buf867 # reuse
# Topologically Sorted Source Nodes: [view_359, permute_307, view_360, permute_308, mm_129], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf864, (512, 327680), (1, 512), 0), buf840, out=buf874)
buf875 = buf842; del buf842 # reuse
# Topologically Sorted Source Nodes: [view_359, permute_307, view_360, permute_310, mm_130], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf864, (327680, 512), (512, 1), 0), reinterpret_tensor(buf843, (512, 2048), (2048, 1), 0), out=buf875)
# Topologically Sorted Source Nodes: [all_reduce_82], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf874, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_82], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf874)
buf880 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_775], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_19.run(buf874, buf880, 1048576, stream=stream0)
buf881 = buf852; del buf852 # reuse
# Topologically Sorted Source Nodes: [view_361, permute_312, view_362, permute_313, mm_131], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf862, (2048, 327680), (1, 2048), 0), buf840, out=buf881)
buf882 = buf840; del buf840 # reuse
# Topologically Sorted Source Nodes: [view_361, permute_312, view_362, permute_315, mm_132], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf862, (327680, 2048), (2048, 1), 0), reinterpret_tensor(buf841, (2048, 2048), (2048, 1), 0), out=buf882)
# Topologically Sorted Source Nodes: [all_reduce_83], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf881, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_83], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf881)
buf887 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_780], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_16.run(buf881, buf887, 4194304, stream=stream0)
buf895 = buf824; del buf824 # reuse
# Call mix order reduction kernel
# Topologically Sorted Source Nodes: [redistribute, layer_norm, add_163, add_164, convert_element_type_781, convert_element_type_783, mul_335, mul_336, sum_71, mul_337, sum_72, mul_338, sub_78, sub_79, div_17, mul_339, mul_340, sum_73, sum_74, convert_element_type_785, add_165], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.add, aten.native_layer_norm_backward]
workspace_14 = workspace_13; del workspace_13 # reuse
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_20.run(buf895, buf868, buf875, buf882, primals_83, add_51, buf837, buf838, workspace_14, 327680, 2048, stream=stream0)
buf892 = workspace_14[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
buf894 = workspace_14[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
del add_51
del primals_83
buf902 = buf831; del buf831 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_786, all_reduce_85], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf892, buf902, 2048, stream=stream0)
buf896 = buf825; del buf825 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_787, all_reduce_84], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf894, buf896, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_787, all_reduce_84], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf896, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_84], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf896)
buf901 = buf894; del buf894 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_788], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf896, buf901, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_786, all_reduce_85], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf902, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_85], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf902)
buf907 = buf892; del buf892 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_789], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf902, buf907, 2048, stream=stream0)
buf908 = reinterpret_tensor(buf805, (2048, 8192), (8192, 1), 0); del buf805 # reuse
# Topologically Sorted Source Nodes: [redistribute_4], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_3.run(primals_81, buf908, 16777216, stream=stream0)
del primals_81
buf909 = buf803; del buf803 # reuse
# Topologically Sorted Source Nodes: [redistribute_4, x_12, permute_317, mm_133], Original ATen: [aten._to_copy, aten.t, aten.mm]
extern_kernels.mm(buf895, buf908, out=buf909)
buf910 = buf838; del buf838 # reuse
buf911 = buf837; del buf837 # reuse
buf913 = buf882; del buf882 # reuse
# Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_native_layer_norm_4.run(add_46, primals_77, primals_78, buf910, buf911, buf913, 327680, 2048, stream=stream0)
del primals_78
buf914 = reinterpret_tensor(buf908, (2048, 8192), (1, 2048), 0); del buf908 # reuse
# Topologically Sorted Source Nodes: [redistribute_2, x], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_3.run(primals_79, buf914, 16777216, stream=stream0)
del primals_79
buf915 = buf789; del buf789 # reuse
# Topologically Sorted Source Nodes: [x], Original ATen: [aten.addmm]
extern_kernels.mm(buf913, buf914, out=buf915)
buf916 = buf788; del buf788 # reuse
buf930 = buf909; del buf909 # reuse
# Topologically Sorted Source Nodes: [redistribute_3, x, x_11, convert_element_type_796, mul_346, mul_347, sub_80, mul_348, add_168, mul_349, mul_350, mul_351, add_169, mul_352, convert_element_type_798], Original ATen: [aten._to_copy, aten.addmm, aten.gelu, aten.gelu_backward]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_addmm_gelu_gelu_backward_5.run(buf930, primals_80, buf915, buf916, 2684354560, stream=stream0)
del primals_80
buf917 = buf790; del buf790 # reuse
# Topologically Sorted Source Nodes: [permute_318, redistribute_3, x, x_11, mm_134], Original ATen: [aten.t, aten._to_copy, aten.addmm, aten.gelu, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf895, (2048, 327680), (1, 2048), 0), buf916, out=buf917)
buf918 = buf791; del buf791 # reuse
# Topologically Sorted Source Nodes: [sum_75], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_6.run(buf895, buf918, 327680, 2048, stream=stream0)
buf919 = reinterpret_tensor(buf902, (1, 2048), (2048, 1), 0); del buf902 # reuse
# Topologically Sorted Source Nodes: [sum_75], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_7.run(buf918, buf919, 2048, 160, stream=stream0)
# Topologically Sorted Source Nodes: [view_363, all_reduce_86], Original ATen: [aten.view, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf919, (2048, ), (1, ), 0), 'avg', '0')
# Topologically Sorted Source Nodes: [view_363, wait_tensor_86], Original ATen: [aten.view, _c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf919, (2048, ), (1, ), 0))
buf924 = empty_strided_cuda((2048, ), (1, ), torch.float32)
# Topologically Sorted Source Nodes: [view_363, convert_element_type_794], Original ATen: [aten.view, aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf919, buf924, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [all_reduce_87], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf917, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_87], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf917)
buf929 = empty_strided_cuda((2048, 8192), (8192, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_795], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_8.run(buf917, buf929, 16777216, stream=stream0)
buf931 = buf875; del buf875 # reuse
# Topologically Sorted Source Nodes: [permute_321, mm_135], Original ATen: [aten.t, aten.mm]
extern_kernels.mm(buf930, reinterpret_tensor(buf914, (8192, 2048), (2048, 1), 0), out=buf931)
buf932 = reinterpret_tensor(buf914, (8192, 2048), (2048, 1), 0); del buf914 # reuse
# Topologically Sorted Source Nodes: [permute_322, mm_136], Original ATen: [aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf930, (8192, 327680), (1, 8192), 0), buf913, out=buf932)
buf933 = buf806; del buf806 # reuse
# Topologically Sorted Source Nodes: [sum_76], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_9.run(buf930, buf933, 1310720, 2048, stream=stream0)
buf934 = buf807; del buf807 # reuse
# Topologically Sorted Source Nodes: [sum_76], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_10.run(buf933, buf934, 8192, 160, stream=stream0)
# Topologically Sorted Source Nodes: [view_364, all_reduce_88], Original ATen: [aten.view, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf934, (8192, ), (1, ), 0), 'avg', '0')
# Topologically Sorted Source Nodes: [view_364, wait_tensor_88], Original ATen: [aten.view, _c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf934, (8192, ), (1, ), 0))
buf939 = empty_strided_cuda((8192, ), (1, ), torch.float32)
# Topologically Sorted Source Nodes: [view_364, convert_element_type_803], Original ATen: [aten.view, aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_view_11.run(buf934, buf939, 8192, stream=stream0)
# Topologically Sorted Source Nodes: [all_reduce_89], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf932, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_89], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf932)
buf944 = empty_strided_cuda((8192, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_804], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_8.run(buf932, buf944, 16777216, stream=stream0)
buf951 = buf895; del buf895 # reuse
# Call mix order reduction kernel
# Topologically Sorted Source Nodes: [redistribute, layer_norm, convert_element_type_805, convert_element_type_807, mul_354, mul_355, sum_77, mul_356, sum_78, mul_357, sub_82, sub_83, div_18, mul_358, mul_359, sum_79, sum_80, convert_element_type_809, add_170], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.native_layer_norm_backward, aten.add]
workspace_15 = workspace_14; del workspace_14 # reuse
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_12.run(buf951, buf931, primals_77, add_46, buf910, buf911, workspace_15, 327680, 2048, stream=stream0)
buf948 = workspace_15[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
buf950 = workspace_15[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
del add_46
del primals_77
buf958 = reinterpret_tensor(buf919, (2048, ), (1, ), 0); del buf919 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_810, all_reduce_91], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf948, buf958, 2048, stream=stream0)
buf952 = buf896; del buf896 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_811, all_reduce_90], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf950, buf952, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_811, all_reduce_90], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf952, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_90], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf952)
buf957 = buf950; del buf950 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_812], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf952, buf957, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_810, all_reduce_91], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf958, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_91], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf958)
buf963 = buf948; del buf948 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_813], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf958, buf963, 2048, stream=stream0)
buf964 = buf911; del buf911 # reuse
buf965 = buf910; del buf910 # reuse
buf967 = buf931; del buf931 # reuse
# Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_native_layer_norm_4.run(add_43, primals_71, primals_72, buf964, buf965, buf967, 327680, 2048, stream=stream0)
del primals_72
buf968 = reinterpret_tensor(buf881, (2048, 2048), (1, 2048), 0); del buf881 # reuse
# Topologically Sorted Source Nodes: [redistribute_2, linear], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_13.run(primals_73, buf968, 4194304, stream=stream0)
del primals_73
buf969 = buf913; del buf913 # reuse
# Topologically Sorted Source Nodes: [linear], Original ATen: [aten.mm]
extern_kernels.mm(buf967, buf968, out=buf969)
buf970 = reinterpret_tensor(buf874, (2048, 512), (1, 2048), 0); del buf874 # reuse
# Topologically Sorted Source Nodes: [redistribute_3, linear_1], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_14.run(primals_74, buf970, 1048576, stream=stream0)
del primals_74
buf971 = reinterpret_tensor(buf864, (327680, 512), (512, 1), 0); del buf864 # reuse
# Topologically Sorted Source Nodes: [linear_1], Original ATen: [aten.mm]
extern_kernels.mm(buf967, buf970, out=buf971)
buf972 = buf843; del buf843 # reuse
# Topologically Sorted Source Nodes: [redistribute_4, linear_2], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_14.run(primals_75, buf972, 1048576, stream=stream0)
del primals_75
buf973 = reinterpret_tensor(buf863, (327680, 512), (512, 1), 0); del buf863 # reuse
# Topologically Sorted Source Nodes: [linear_2], Original ATen: [aten.mm]
extern_kernels.mm(buf967, buf972, out=buf973)
buf974 = buf861; del buf861 # reuse
buf975 = buf847; del buf847 # reuse
buf976 = reinterpret_tensor(buf868, (1, 16, 327680, 128), (671088640, 128, 2048, 1), 0); del buf868 # reuse
# Topologically Sorted Source Nodes: [q, k, v, flex_attention], Original ATen: [aten.view, aten.permute, flex_attention]
stream0 = get_raw_stream(0)
triton_tem_fused_flex_attention_permute_view_15.run(buf969, buf971, buf973, buf974, buf975, primals_18, primals_17, primals_19, primals_21, primals_14, primals_15, buf976, s91, 2560, 1, 16, stream=stream0)
buf979 = reinterpret_tensor(buf841, (2048, 2048), (2048, 1), 0); del buf841 # reuse
# Topologically Sorted Source Nodes: [permute_325, rearrange_3, o, mm_137], Original ATen: [aten.t, aten.permute, aten.view, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf951, (2048, 327680), (1, 2048), 0), reinterpret_tensor(buf976, (327680, 2048), (2048, 1), 0), out=buf979)
buf980 = buf853; del buf853 # reuse
# Topologically Sorted Source Nodes: [redistribute_5], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_13.run(primals_76, buf980, 4194304, stream=stream0)
del primals_76
buf981 = reinterpret_tensor(buf862, (327680, 2048), (2048, 1), 0); del buf862 # reuse
# Topologically Sorted Source Nodes: [redistribute_5, o, permute_327, mm_138], Original ATen: [aten._to_copy, aten.t, aten.mm]
extern_kernels.mm(buf951, buf980, out=buf981)
# Topologically Sorted Source Nodes: [all_reduce_92], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf979, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_92], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf979)
buf986 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_818], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_16.run(buf979, buf986, 4194304, stream=stream0)
buf988 = buf975; del buf975 # reuse
# Topologically Sorted Source Nodes: [q, k, v, view_367, view_368, permute_329, flex_attention_backward_7], Original ATen: [aten.view, aten.permute, flex_attention_backward]
stream0 = get_raw_stream(0)
triton_per_fused_flex_attention_backward_permute_view_17.run(buf976, buf981, buf988, 5242880, 128, stream=stream0)
buf989 = buf976; del buf976 # reuse
buf990 = reinterpret_tensor(buf846, (1, 4, 327680, 128), (167772160, 128, 512, 1), 0); del buf846 # reuse
buf991 = reinterpret_tensor(buf844, (1, 4, 327680, 128), (167772160, 128, 512, 1), 0); del buf844 # reuse
# Topologically Sorted Source Nodes: [q, k, v, view_367, view_368, permute_329, flex_attention_backward_7], Original ATen: [aten.view, aten.permute, flex_attention_backward]
stream0 = get_raw_stream(0)
triton_tem_fused_flex_attention_backward_permute_view_18.run(buf969, buf971, buf973, buf974, buf988, buf981, buf989, buf990, primals_18, primals_17, primals_22, primals_24, primals_19, primals_21, primals_25, primals_27, primals_14, primals_15, buf991, s91, s16, 12800, 1, 4, stream=stream0)
buf994 = reinterpret_tensor(buf845, (512, 2048), (2048, 1), 0); del buf845 # reuse
# Topologically Sorted Source Nodes: [view_369, permute_330, view_370, permute_331, mm_139], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf990, (512, 327680), (1, 512), 0), buf967, out=buf994)
buf995 = buf981; del buf981 # reuse
# Topologically Sorted Source Nodes: [view_369, permute_330, view_370, permute_333, mm_140], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf990, (327680, 512), (512, 1), 0), reinterpret_tensor(buf972, (512, 2048), (2048, 1), 0), out=buf995)
# Topologically Sorted Source Nodes: [all_reduce_93], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf994, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_93], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf994)
buf1000 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_823], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_19.run(buf994, buf1000, 1048576, stream=stream0)
buf1001 = buf994; del buf994 # reuse
# Topologically Sorted Source Nodes: [view_371, permute_335, view_372, permute_336, mm_141], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf991, (512, 327680), (1, 512), 0), buf967, out=buf1001)
buf1002 = buf969; del buf969 # reuse
# Topologically Sorted Source Nodes: [view_371, permute_335, view_372, permute_338, mm_142], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf991, (327680, 512), (512, 1), 0), reinterpret_tensor(buf970, (512, 2048), (2048, 1), 0), out=buf1002)
# Topologically Sorted Source Nodes: [all_reduce_94], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1001, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_94], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1001)
buf1007 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_828], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_19.run(buf1001, buf1007, 1048576, stream=stream0)
buf1008 = buf979; del buf979 # reuse
# Topologically Sorted Source Nodes: [view_373, permute_340, view_374, permute_341, mm_143], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf989, (2048, 327680), (1, 2048), 0), buf967, out=buf1008)
buf1009 = buf967; del buf967 # reuse
# Topologically Sorted Source Nodes: [view_373, permute_340, view_374, permute_343, mm_144], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf989, (327680, 2048), (2048, 1), 0), reinterpret_tensor(buf968, (2048, 2048), (2048, 1), 0), out=buf1009)
# Topologically Sorted Source Nodes: [all_reduce_95], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1008, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_95], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1008)
buf1014 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_833], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_16.run(buf1008, buf1014, 4194304, stream=stream0)
buf1022 = buf951; del buf951 # reuse
# Call mix order reduction kernel
# Topologically Sorted Source Nodes: [redistribute, layer_norm, add_171, add_172, convert_element_type_834, convert_element_type_836, mul_361, mul_362, sum_81, mul_363, sum_82, mul_364, sub_85, sub_86, div_19, mul_365, mul_366, sum_83, sum_84, convert_element_type_838, add_173], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.add, aten.native_layer_norm_backward]
workspace_16 = workspace_15; del workspace_15 # reuse
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_20.run(buf1022, buf995, buf1002, buf1009, primals_71, add_43, buf964, buf965, workspace_16, 327680, 2048, stream=stream0)
buf1019 = workspace_16[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
buf1021 = workspace_16[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
del add_43
del primals_71
buf1029 = buf958; del buf958 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_839, all_reduce_97], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf1019, buf1029, 2048, stream=stream0)
buf1023 = buf952; del buf952 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_840, all_reduce_96], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf1021, buf1023, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_840, all_reduce_96], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1023, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_96], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1023)
buf1028 = buf1021; del buf1021 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_841], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf1023, buf1028, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_839, all_reduce_97], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1029, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_97], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1029)
buf1034 = buf1019; del buf1019 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_842], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf1029, buf1034, 2048, stream=stream0)
buf1035 = reinterpret_tensor(buf932, (2048, 8192), (8192, 1), 0); del buf932 # reuse
# Topologically Sorted Source Nodes: [redistribute_4], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_3.run(primals_69, buf1035, 16777216, stream=stream0)
del primals_69
buf1036 = buf930; del buf930 # reuse
# Topologically Sorted Source Nodes: [redistribute_4, x_10, permute_345, mm_145], Original ATen: [aten._to_copy, aten.t, aten.mm]
extern_kernels.mm(buf1022, buf1035, out=buf1036)
buf1037 = buf965; del buf965 # reuse
buf1038 = buf964; del buf964 # reuse
buf1040 = buf995; del buf995 # reuse
# Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_native_layer_norm_4.run(add_38, primals_65, primals_66, buf1037, buf1038, buf1040, 327680, 2048, stream=stream0)
del primals_66
buf1041 = reinterpret_tensor(buf1035, (2048, 8192), (1, 2048), 0); del buf1035 # reuse
# Topologically Sorted Source Nodes: [redistribute_2, x], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_3.run(primals_67, buf1041, 16777216, stream=stream0)
del primals_67
buf1042 = buf916; del buf916 # reuse
# Topologically Sorted Source Nodes: [x], Original ATen: [aten.addmm]
extern_kernels.mm(buf1040, buf1041, out=buf1042)
buf1043 = buf915; del buf915 # reuse
buf1057 = buf1036; del buf1036 # reuse
# Topologically Sorted Source Nodes: [redistribute_3, x, x_9, convert_element_type_849, mul_372, mul_373, sub_87, mul_374, add_176, mul_375, mul_376, mul_377, add_177, mul_378, convert_element_type_851], Original ATen: [aten._to_copy, aten.addmm, aten.gelu, aten.gelu_backward]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_addmm_gelu_gelu_backward_5.run(buf1057, primals_68, buf1042, buf1043, 2684354560, stream=stream0)
del primals_68
buf1044 = buf917; del buf917 # reuse
# Topologically Sorted Source Nodes: [permute_346, redistribute_3, x, x_9, mm_146], Original ATen: [aten.t, aten._to_copy, aten.addmm, aten.gelu, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf1022, (2048, 327680), (1, 2048), 0), buf1043, out=buf1044)
buf1045 = buf918; del buf918 # reuse
# Topologically Sorted Source Nodes: [sum_85], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_6.run(buf1022, buf1045, 327680, 2048, stream=stream0)
buf1046 = reinterpret_tensor(buf1029, (1, 2048), (2048, 1), 0); del buf1029 # reuse
# Topologically Sorted Source Nodes: [sum_85], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_7.run(buf1045, buf1046, 2048, 160, stream=stream0)
# Topologically Sorted Source Nodes: [view_375, all_reduce_98], Original ATen: [aten.view, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf1046, (2048, ), (1, ), 0), 'avg', '0')
# Topologically Sorted Source Nodes: [view_375, wait_tensor_98], Original ATen: [aten.view, _c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf1046, (2048, ), (1, ), 0))
buf1051 = empty_strided_cuda((2048, ), (1, ), torch.float32)
# Topologically Sorted Source Nodes: [view_375, convert_element_type_847], Original ATen: [aten.view, aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf1046, buf1051, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [all_reduce_99], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1044, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_99], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1044)
buf1056 = empty_strided_cuda((2048, 8192), (8192, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_848], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_8.run(buf1044, buf1056, 16777216, stream=stream0)
buf1058 = buf1009; del buf1009 # reuse
# Topologically Sorted Source Nodes: [permute_349, mm_147], Original ATen: [aten.t, aten.mm]
extern_kernels.mm(buf1057, reinterpret_tensor(buf1041, (8192, 2048), (2048, 1), 0), out=buf1058)
buf1059 = reinterpret_tensor(buf1041, (8192, 2048), (2048, 1), 0); del buf1041 # reuse
# Topologically Sorted Source Nodes: [permute_350, mm_148], Original ATen: [aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf1057, (8192, 327680), (1, 8192), 0), buf1040, out=buf1059)
buf1060 = buf933; del buf933 # reuse
# Topologically Sorted Source Nodes: [sum_86], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_9.run(buf1057, buf1060, 1310720, 2048, stream=stream0)
buf1061 = buf934; del buf934 # reuse
# Topologically Sorted Source Nodes: [sum_86], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_10.run(buf1060, buf1061, 8192, 160, stream=stream0)
# Topologically Sorted Source Nodes: [view_376, all_reduce_100], Original ATen: [aten.view, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf1061, (8192, ), (1, ), 0), 'avg', '0')
# Topologically Sorted Source Nodes: [view_376, wait_tensor_100], Original ATen: [aten.view, _c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf1061, (8192, ), (1, ), 0))
buf1066 = empty_strided_cuda((8192, ), (1, ), torch.float32)
# Topologically Sorted Source Nodes: [view_376, convert_element_type_856], Original ATen: [aten.view, aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_view_11.run(buf1061, buf1066, 8192, stream=stream0)
# Topologically Sorted Source Nodes: [all_reduce_101], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1059, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_101], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1059)
buf1071 = empty_strided_cuda((8192, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_857], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_8.run(buf1059, buf1071, 16777216, stream=stream0)
buf1078 = buf1022; del buf1022 # reuse
# Call mix order reduction kernel
# Topologically Sorted Source Nodes: [redistribute, layer_norm, convert_element_type_858, convert_element_type_860, mul_380, mul_381, sum_87, mul_382, sum_88, mul_383, sub_89, sub_90, div_20, mul_384, mul_385, sum_89, sum_90, convert_element_type_862, add_178], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.native_layer_norm_backward, aten.add]
workspace_17 = workspace_16; del workspace_16 # reuse
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_12.run(buf1078, buf1058, primals_65, add_38, buf1037, buf1038, workspace_17, 327680, 2048, stream=stream0)
buf1075 = workspace_17[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
buf1077 = workspace_17[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
del add_38
del primals_65
buf1085 = reinterpret_tensor(buf1046, (2048, ), (1, ), 0); del buf1046 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_863, all_reduce_103], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf1075, buf1085, 2048, stream=stream0)
buf1079 = buf1023; del buf1023 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_864, all_reduce_102], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf1077, buf1079, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_864, all_reduce_102], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1079, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_102], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1079)
buf1084 = buf1077; del buf1077 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_865], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf1079, buf1084, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_863, all_reduce_103], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1085, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_103], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1085)
buf1090 = buf1075; del buf1075 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_866], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf1085, buf1090, 2048, stream=stream0)
buf1091 = buf1038; del buf1038 # reuse
buf1092 = buf1037; del buf1037 # reuse
buf1094 = buf1058; del buf1058 # reuse
# Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_native_layer_norm_4.run(add_35, primals_59, primals_60, buf1091, buf1092, buf1094, 327680, 2048, stream=stream0)
del primals_60
buf1095 = reinterpret_tensor(buf1008, (2048, 2048), (1, 2048), 0); del buf1008 # reuse
# Topologically Sorted Source Nodes: [redistribute_2, linear], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_13.run(primals_61, buf1095, 4194304, stream=stream0)
del primals_61
buf1096 = buf1040; del buf1040 # reuse
# Topologically Sorted Source Nodes: [linear], Original ATen: [aten.mm]
extern_kernels.mm(buf1094, buf1095, out=buf1096)
buf1097 = reinterpret_tensor(buf1001, (2048, 512), (1, 2048), 0); del buf1001 # reuse
# Topologically Sorted Source Nodes: [redistribute_3, linear_1], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_14.run(primals_62, buf1097, 1048576, stream=stream0)
del primals_62
buf1098 = reinterpret_tensor(buf991, (327680, 512), (512, 1), 0); del buf991 # reuse
# Topologically Sorted Source Nodes: [linear_1], Original ATen: [aten.mm]
extern_kernels.mm(buf1094, buf1097, out=buf1098)
buf1099 = buf970; del buf970 # reuse
# Topologically Sorted Source Nodes: [redistribute_4, linear_2], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_14.run(primals_63, buf1099, 1048576, stream=stream0)
del primals_63
buf1100 = reinterpret_tensor(buf990, (327680, 512), (512, 1), 0); del buf990 # reuse
# Topologically Sorted Source Nodes: [linear_2], Original ATen: [aten.mm]
extern_kernels.mm(buf1094, buf1099, out=buf1100)
buf1101 = buf988; del buf988 # reuse
buf1102 = buf974; del buf974 # reuse
buf1103 = reinterpret_tensor(buf1002, (1, 16, 327680, 128), (671088640, 128, 2048, 1), 0); del buf1002 # reuse
# Topologically Sorted Source Nodes: [q, k, v, flex_attention], Original ATen: [aten.view, aten.permute, flex_attention]
stream0 = get_raw_stream(0)
# NOTE(review): machine-generated TorchInductor wrapper code (AOT backward graph).
# All comments below are reviewer annotations; tensor roles are inferred from the
# compiler-embedded "Source Nodes" names and buffer shapes -- confirm against the
# traced graph before relying on them. Statement order and the aggressive buffer
# reuse (`reinterpret_tensor(...); del ... # reuse`) are load-bearing: do not
# reorder or hand-edit these lines.
# Recompute the flex-attention forward output for this layer (activations are
# rematerialized in the backward pass rather than saved from the forward).
triton_tem_fused_flex_attention_permute_view_15.run(buf1096, buf1098, buf1100, buf1101, buf1102, primals_18, primals_17, primals_19, primals_21, primals_14, primals_15, buf1103, s91, 2560, 1, 16, stream=stream0)
# Attention output-projection weight grad: dY^T (2048 x 327680) @ O (327680 x 2048).
buf1106 = reinterpret_tensor(buf968, (2048, 2048), (2048, 1), 0); del buf968 # reuse
# Topologically Sorted Source Nodes: [permute_353, rearrange_3, o, mm_149], Original ATen: [aten.t, aten.permute, aten.view, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf1078, (2048, 327680), (1, 2048), 0), reinterpret_tensor(buf1103, (327680, 2048), (2048, 1), 0), out=buf1106)
buf1107 = buf980; del buf980 # reuse
# Topologically Sorted Source Nodes: [redistribute_5], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
# Cast/transposed copy of the o-proj weight (primals_64) for the input-grad mm.
triton_poi_fused__to_copy_t_13.run(primals_64, buf1107, 4194304, stream=stream0)
del primals_64
buf1108 = reinterpret_tensor(buf989, (327680, 2048), (2048, 1), 0); del buf989 # reuse
# Topologically Sorted Source Nodes: [redistribute_5, o, permute_355, mm_150], Original ATen: [aten._to_copy, aten.t, aten.mm]
extern_kernels.mm(buf1078, buf1107, out=buf1108)
# Weight-grad sync: in-place averaging all_reduce over process group '0',
# followed by an explicit wait on the same buffer (tensor-parallel pattern
# repeated for every weight grad below).
# Topologically Sorted Source Nodes: [all_reduce_104], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1106, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_104], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1106)
buf1113 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32)
# Up-cast the reduced weight grad to float32 (grads appear to be kept in fp32).
# Topologically Sorted Source Nodes: [convert_element_type_871], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_16.run(buf1106, buf1113, 4194304, stream=stream0)
buf1115 = buf1102; del buf1102 # reuse
# Flex-attention backward: first a per-row reduction (delta preprocessing),
# then the templated backward kernel. Outputs inferred from shapes:
# buf1116 is the 2048-wide (16-head) grad, buf1117/buf1118 the 512-wide
# (4 KV-head) grads -- presumably dQ and dK/dV; confirm against the graph.
# Topologically Sorted Source Nodes: [q, k, v, view_379, view_380, permute_357, flex_attention_backward_8], Original ATen: [aten.view, aten.permute, flex_attention_backward]
stream0 = get_raw_stream(0)
triton_per_fused_flex_attention_backward_permute_view_17.run(buf1103, buf1108, buf1115, 5242880, 128, stream=stream0)
buf1116 = buf1103; del buf1103 # reuse
buf1117 = reinterpret_tensor(buf973, (1, 4, 327680, 128), (167772160, 128, 512, 1), 0); del buf973 # reuse
buf1118 = reinterpret_tensor(buf971, (1, 4, 327680, 128), (167772160, 128, 512, 1), 0); del buf971 # reuse
# Topologically Sorted Source Nodes: [q, k, v, view_379, view_380, permute_357, flex_attention_backward_8], Original ATen: [aten.view, aten.permute, flex_attention_backward]
stream0 = get_raw_stream(0)
triton_tem_fused_flex_attention_backward_permute_view_18.run(buf1096, buf1098, buf1100, buf1101, buf1115, buf1108, buf1116, buf1117, primals_18, primals_17, primals_22, primals_24, primals_19, primals_21, primals_25, primals_27, primals_14, primals_15, buf1118, s91, s16, 12800, 1, 4, stream=stream0)
# One of the 512-wide KV projections: weight grad, then input grad, then the
# same all_reduce('avg') + wait + fp32-copy pattern.
buf1121 = reinterpret_tensor(buf972, (512, 2048), (2048, 1), 0); del buf972 # reuse
# Topologically Sorted Source Nodes: [view_381, permute_358, view_382, permute_359, mm_151], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf1117, (512, 327680), (1, 512), 0), buf1094, out=buf1121)
buf1122 = buf1108; del buf1108 # reuse
# Topologically Sorted Source Nodes: [view_381, permute_358, view_382, permute_361, mm_152], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf1117, (327680, 512), (512, 1), 0), reinterpret_tensor(buf1099, (512, 2048), (2048, 1), 0), out=buf1122)
# Topologically Sorted Source Nodes: [all_reduce_105], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1121, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_105], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1121)
buf1127 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_876], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_19.run(buf1121, buf1127, 1048576, stream=stream0)
buf1128 = buf1121; del buf1121 # reuse
# The other 512-wide KV projection, same grad + sync + cast pattern.
# Topologically Sorted Source Nodes: [view_383, permute_363, view_384, permute_364, mm_153], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf1118, (512, 327680), (1, 512), 0), buf1094, out=buf1128)
buf1129 = buf1096; del buf1096 # reuse
# Topologically Sorted Source Nodes: [view_383, permute_363, view_384, permute_366, mm_154], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf1118, (327680, 512), (512, 1), 0), reinterpret_tensor(buf1097, (512, 2048), (2048, 1), 0), out=buf1129)
# Topologically Sorted Source Nodes: [all_reduce_106], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1128, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_106], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1128)
buf1134 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_881], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_19.run(buf1128, buf1134, 1048576, stream=stream0)
buf1135 = buf1106; del buf1106 # reuse
# Q projection (2048-wide): weight grad, input grad, sync, fp32 cast.
# Topologically Sorted Source Nodes: [view_385, permute_368, view_386, permute_369, mm_155], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf1116, (2048, 327680), (1, 2048), 0), buf1094, out=buf1135)
buf1136 = buf1094; del buf1094 # reuse
# Topologically Sorted Source Nodes: [view_385, permute_368, view_386, permute_371, mm_156], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf1116, (327680, 2048), (2048, 1), 0), reinterpret_tensor(buf1095, (2048, 2048), (2048, 1), 0), out=buf1136)
# Topologically Sorted Source Nodes: [all_reduce_107], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1135, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_107], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1135)
buf1141 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_886], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_16.run(buf1135, buf1141, 4194304, stream=stream0)
# NOTE(review): generated Inductor code; annotations only -- do not reorder.
# Fused layer-norm backward combined with the residual adds: sums the three
# incoming grads (buf1122/buf1129/buf1136) into buf1149 (dX) and stages
# per-block partial sums for the affine params in a shared workspace slab.
buf1149 = buf1078; del buf1078 # reuse
# Call mix order reduction kernel
# Topologically Sorted Source Nodes: [redistribute, layer_norm, add_179, add_180, convert_element_type_887, convert_element_type_889, mul_387, mul_388, sum_91, mul_389, sum_92, mul_390, sub_92, sub_93, div_21, mul_391, mul_392, sum_93, sum_94, convert_element_type_891, add_181], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.add, aten.native_layer_norm_backward]
workspace_18 = workspace_17; del workspace_17 # reuse
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_20.run(buf1149, buf1122, buf1129, buf1136, primals_59, add_35, buf1091, buf1092, workspace_18, 327680, 2048, stream=stream0)
# Reduce the two 2560x2048 partial-sum slabs in the workspace to the two
# per-feature layer-norm param grads (presumably dweight/dbias; confirm).
buf1146 = workspace_18[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
buf1148 = workspace_18[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
del add_35
del primals_59
buf1156 = buf1085; del buf1085 # reuse
# Layer-norm param grads: copy into comm buffers, averaging all_reduce over
# group '0', wait, then fp32 cast -- done independently for each of the two.
# Topologically Sorted Source Nodes: [convert_element_type_892, all_reduce_109], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf1146, buf1156, 2048, stream=stream0)
buf1150 = buf1079; del buf1079 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_893, all_reduce_108], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf1148, buf1150, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_893, all_reduce_108], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1150, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_108], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1150)
buf1155 = buf1148; del buf1148 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_894], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf1150, buf1155, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_892, all_reduce_109], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1156, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_109], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1156)
buf1161 = buf1146; del buf1146 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_895], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf1156, buf1161, 2048, stream=stream0)
# MLP backward. First, input grad through the 8192x2048 down projection
# (primals_57), then recompute the MLP forward (layer norm + up projection)
# since the intermediate activations were not saved.
buf1162 = reinterpret_tensor(buf1059, (2048, 8192), (8192, 1), 0); del buf1059 # reuse
# Topologically Sorted Source Nodes: [redistribute_4], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_3.run(primals_57, buf1162, 16777216, stream=stream0)
del primals_57
buf1163 = buf1057; del buf1057 # reuse
# Topologically Sorted Source Nodes: [redistribute_4, x_8, permute_373, mm_157], Original ATen: [aten._to_copy, aten.t, aten.mm]
extern_kernels.mm(buf1149, buf1162, out=buf1163)
buf1164 = buf1092; del buf1092 # reuse
buf1165 = buf1091; del buf1091 # reuse
buf1167 = buf1136; del buf1136 # reuse
# Recomputed layer-norm forward: buf1164/buf1165 hold the per-row stats,
# buf1167 the normalized output fed to the MLP up projection.
# Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_native_layer_norm_4.run(add_30, primals_53, primals_54, buf1164, buf1165, buf1167, 327680, 2048, stream=stream0)
del primals_54
buf1168 = reinterpret_tensor(buf1162, (2048, 8192), (1, 2048), 0); del buf1162 # reuse
# Topologically Sorted Source Nodes: [redistribute_2, x], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_3.run(primals_55, buf1168, 16777216, stream=stream0)
del primals_55
buf1169 = buf1043; del buf1043 # reuse
# Topologically Sorted Source Nodes: [x], Original ATen: [aten.addmm]
extern_kernels.mm(buf1167, buf1168, out=buf1169)
buf1170 = buf1042; del buf1042 # reuse
buf1184 = buf1163; del buf1163 # reuse
# Fused GELU forward recompute + GELU backward: buf1184 becomes the grad
# w.r.t. the up-projection output; buf1170 the recomputed GELU activation.
# Topologically Sorted Source Nodes: [redistribute_3, x, x_7, convert_element_type_902, mul_398, mul_399, sub_94, mul_400, add_184, mul_401, mul_402, mul_403, add_185, mul_404, convert_element_type_904], Original ATen: [aten._to_copy, aten.addmm, aten.gelu, aten.gelu_backward]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_addmm_gelu_gelu_backward_5.run(buf1184, primals_56, buf1169, buf1170, 2684354560, stream=stream0)
del primals_56
buf1171 = buf1044; del buf1044 # reuse
# Down-projection weight grad.
# Topologically Sorted Source Nodes: [permute_374, redistribute_3, x, x_7, mm_158], Original ATen: [aten.t, aten._to_copy, aten.addmm, aten.gelu, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf1149, (2048, 327680), (1, 2048), 0), buf1170, out=buf1171)
buf1172 = buf1045; del buf1045 # reuse
# Down-projection bias grad: two-stage column sum (327680 rows -> 160
# partials -> 2048), then the averaging all_reduce on a flat (2048,) view.
# Topologically Sorted Source Nodes: [sum_95], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_6.run(buf1149, buf1172, 327680, 2048, stream=stream0)
buf1173 = reinterpret_tensor(buf1156, (1, 2048), (2048, 1), 0); del buf1156 # reuse
# Topologically Sorted Source Nodes: [sum_95], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_7.run(buf1172, buf1173, 2048, 160, stream=stream0)
# Topologically Sorted Source Nodes: [view_387, all_reduce_110], Original ATen: [aten.view, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf1173, (2048, ), (1, ), 0), 'avg', '0')
# Topologically Sorted Source Nodes: [view_387, wait_tensor_110], Original ATen: [aten.view, _c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf1173, (2048, ), (1, ), 0))
buf1178 = empty_strided_cuda((2048, ), (1, ), torch.float32)
# Topologically Sorted Source Nodes: [view_387, convert_element_type_900], Original ATen: [aten.view, aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf1173, buf1178, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [all_reduce_111], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1171, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_111], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1171)
buf1183 = empty_strided_cuda((2048, 8192), (8192, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_901], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_8.run(buf1171, buf1183, 16777216, stream=stream0)
buf1185 = buf1129; del buf1129 # reuse
# Topologically Sorted Source Nodes: [permute_377, mm_159], Original ATen: [aten.t, aten.mm]
extern_kernels.mm(buf1184, reinterpret_tensor(buf1168, (8192, 2048), (2048, 1), 0), out=buf1185)
buf1186 = reinterpret_tensor(buf1168, (8192, 2048), (2048, 1), 0); del buf1168 # reuse
# Topologically Sorted Source Nodes: [permute_378, mm_160], Original ATen: [aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf1184, (8192, 327680), (1, 8192), 0), buf1167, out=buf1186)
buf1187 = buf1060; del buf1060 # reuse
# Topologically Sorted Source Nodes: [sum_96], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_9.run(buf1184, buf1187, 1310720, 2048, stream=stream0)
buf1188 = buf1061; del buf1061 # reuse
# Topologically Sorted Source Nodes: [sum_96], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_10.run(buf1187, buf1188, 8192, 160, stream=stream0)
# Topologically Sorted Source Nodes: [view_388, all_reduce_112], Original ATen: [aten.view, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf1188, (8192, ), (1, ), 0), 'avg', '0')
# Topologically Sorted Source Nodes: [view_388, wait_tensor_112], Original ATen: [aten.view, _c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf1188, (8192, ), (1, ), 0))
buf1193 = empty_strided_cuda((8192, ), (1, ), torch.float32)
# Topologically Sorted Source Nodes: [view_388, convert_element_type_909], Original ATen: [aten.view, aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_view_11.run(buf1188, buf1193, 8192, stream=stream0)
# Topologically Sorted Source Nodes: [all_reduce_113], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1186, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_113], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1186)
buf1198 = empty_strided_cuda((8192, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_910], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_8.run(buf1186, buf1198, 16777216, stream=stream0)
buf1205 = buf1149; del buf1149 # reuse
# Call mix order reduction kernel
# Topologically Sorted Source Nodes: [redistribute, layer_norm, convert_element_type_911, convert_element_type_913, mul_406, mul_407, sum_97, mul_408, sum_98, mul_409, sub_96, sub_97, div_22, mul_410, mul_411, sum_99, sum_100, convert_element_type_915, add_186], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.native_layer_norm_backward, aten.add]
workspace_19 = workspace_18; del workspace_18 # reuse
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_12.run(buf1205, buf1185, primals_53, add_30, buf1164, buf1165, workspace_19, 327680, 2048, stream=stream0)
buf1202 = workspace_19[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
buf1204 = workspace_19[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
del add_30
del primals_53
buf1212 = reinterpret_tensor(buf1173, (2048, ), (1, ), 0); del buf1173 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_916, all_reduce_115], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf1202, buf1212, 2048, stream=stream0)
buf1206 = buf1150; del buf1150 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_917, all_reduce_114], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf1204, buf1206, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_917, all_reduce_114], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1206, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_114], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1206)
buf1211 = buf1204; del buf1204 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_918], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf1206, buf1211, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_916, all_reduce_115], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1212, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_115], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1212)
buf1217 = buf1202; del buf1202 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_919], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf1212, buf1217, 2048, stream=stream0)
buf1218 = buf1165; del buf1165 # reuse
buf1219 = buf1164; del buf1164 # reuse
buf1221 = buf1185; del buf1185 # reuse
# Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_native_layer_norm_4.run(add_27, primals_47, primals_48, buf1218, buf1219, buf1221, 327680, 2048, stream=stream0)
del primals_48
buf1222 = reinterpret_tensor(buf1135, (2048, 2048), (1, 2048), 0); del buf1135 # reuse
# Topologically Sorted Source Nodes: [redistribute_2, linear], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_13.run(primals_49, buf1222, 4194304, stream=stream0)
del primals_49
buf1223 = buf1167; del buf1167 # reuse
# Topologically Sorted Source Nodes: [linear], Original ATen: [aten.mm]
extern_kernels.mm(buf1221, buf1222, out=buf1223)
buf1224 = reinterpret_tensor(buf1128, (2048, 512), (1, 2048), 0); del buf1128 # reuse
# Topologically Sorted Source Nodes: [redistribute_3, linear_1], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_14.run(primals_50, buf1224, 1048576, stream=stream0)
del primals_50
buf1225 = reinterpret_tensor(buf1118, (327680, 512), (512, 1), 0); del buf1118 # reuse
# Topologically Sorted Source Nodes: [linear_1], Original ATen: [aten.mm]
extern_kernels.mm(buf1221, buf1224, out=buf1225)
buf1226 = buf1097; del buf1097 # reuse
# Topologically Sorted Source Nodes: [redistribute_4, linear_2], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_14.run(primals_51, buf1226, 1048576, stream=stream0)
del primals_51
buf1227 = reinterpret_tensor(buf1117, (327680, 512), (512, 1), 0); del buf1117 # reuse
# Topologically Sorted Source Nodes: [linear_2], Original ATen: [aten.mm]
extern_kernels.mm(buf1221, buf1226, out=buf1227)
buf1228 = buf1115; del buf1115 # reuse
buf1229 = buf1101; del buf1101 # reuse
buf1230 = reinterpret_tensor(buf1122, (1, 16, 327680, 128), (671088640, 128, 2048, 1), 0); del buf1122 # reuse
# Topologically Sorted Source Nodes: [q, k, v, flex_attention], Original ATen: [aten.view, aten.permute, flex_attention]
stream0 = get_raw_stream(0)
triton_tem_fused_flex_attention_permute_view_15.run(buf1223, buf1225, buf1227, buf1228, buf1229, primals_18, primals_17, primals_19, primals_21, primals_14, primals_15, buf1230, s91, 2560, 1, 16, stream=stream0)
# NOTE(review): generated Inductor code; annotations only -- do not reorder.
# Previous layer's attention backward: output-projection weight/input grads
# with the all_reduce('avg', group '0') + wait + fp32-cast pattern.
buf1233 = reinterpret_tensor(buf1095, (2048, 2048), (2048, 1), 0); del buf1095 # reuse
# Topologically Sorted Source Nodes: [permute_381, rearrange_3, o, mm_161], Original ATen: [aten.t, aten.permute, aten.view, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf1205, (2048, 327680), (1, 2048), 0), reinterpret_tensor(buf1230, (327680, 2048), (2048, 1), 0), out=buf1233)
buf1234 = buf1107; del buf1107 # reuse
# Topologically Sorted Source Nodes: [redistribute_5], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_13.run(primals_52, buf1234, 4194304, stream=stream0)
del primals_52
buf1235 = reinterpret_tensor(buf1116, (327680, 2048), (2048, 1), 0); del buf1116 # reuse
# Topologically Sorted Source Nodes: [redistribute_5, o, permute_383, mm_162], Original ATen: [aten._to_copy, aten.t, aten.mm]
extern_kernels.mm(buf1205, buf1234, out=buf1235)
# Topologically Sorted Source Nodes: [all_reduce_116], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1233, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_116], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1233)
buf1240 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_924], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_16.run(buf1233, buf1240, 4194304, stream=stream0)
buf1242 = buf1229; del buf1229 # reuse
# Flex-attention backward (same two-kernel structure as the later layer):
# reduction pass, then templated kernel producing the 2048-wide grad
# (buf1243) and the two 512-wide KV-head grads (buf1244/buf1245).
# Topologically Sorted Source Nodes: [q, k, v, view_391, view_392, permute_385, flex_attention_backward_9], Original ATen: [aten.view, aten.permute, flex_attention_backward]
stream0 = get_raw_stream(0)
triton_per_fused_flex_attention_backward_permute_view_17.run(buf1230, buf1235, buf1242, 5242880, 128, stream=stream0)
buf1243 = buf1230; del buf1230 # reuse
buf1244 = reinterpret_tensor(buf1100, (1, 4, 327680, 128), (167772160, 128, 512, 1), 0); del buf1100 # reuse
buf1245 = reinterpret_tensor(buf1098, (1, 4, 327680, 128), (167772160, 128, 512, 1), 0); del buf1098 # reuse
# Topologically Sorted Source Nodes: [q, k, v, view_391, view_392, permute_385, flex_attention_backward_9], Original ATen: [aten.view, aten.permute, flex_attention_backward]
stream0 = get_raw_stream(0)
triton_tem_fused_flex_attention_backward_permute_view_18.run(buf1223, buf1225, buf1227, buf1228, buf1242, buf1235, buf1243, buf1244, primals_18, primals_17, primals_22, primals_24, primals_19, primals_21, primals_25, primals_27, primals_14, primals_15, buf1245, s91, s16, 12800, 1, 4, stream=stream0)
# Grads for the three input projections (two 512-wide KV, one 2048-wide Q),
# each: weight grad, input grad, all_reduce('avg') + wait, fp32 cast.
buf1248 = reinterpret_tensor(buf1099, (512, 2048), (2048, 1), 0); del buf1099 # reuse
# Topologically Sorted Source Nodes: [view_393, permute_386, view_394, permute_387, mm_163], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf1244, (512, 327680), (1, 512), 0), buf1221, out=buf1248)
buf1249 = buf1235; del buf1235 # reuse
# Topologically Sorted Source Nodes: [view_393, permute_386, view_394, permute_389, mm_164], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf1244, (327680, 512), (512, 1), 0), reinterpret_tensor(buf1226, (512, 2048), (2048, 1), 0), out=buf1249)
# Topologically Sorted Source Nodes: [all_reduce_117], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1248, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_117], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1248)
buf1254 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_929], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_19.run(buf1248, buf1254, 1048576, stream=stream0)
buf1255 = buf1248; del buf1248 # reuse
# Topologically Sorted Source Nodes: [view_395, permute_391, view_396, permute_392, mm_165], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf1245, (512, 327680), (1, 512), 0), buf1221, out=buf1255)
buf1256 = buf1223; del buf1223 # reuse
# Topologically Sorted Source Nodes: [view_395, permute_391, view_396, permute_394, mm_166], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf1245, (327680, 512), (512, 1), 0), reinterpret_tensor(buf1224, (512, 2048), (2048, 1), 0), out=buf1256)
# Topologically Sorted Source Nodes: [all_reduce_118], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1255, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_118], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1255)
buf1261 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_934], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_19.run(buf1255, buf1261, 1048576, stream=stream0)
buf1262 = buf1233; del buf1233 # reuse
# Topologically Sorted Source Nodes: [view_397, permute_396, view_398, permute_397, mm_167], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf1243, (2048, 327680), (1, 2048), 0), buf1221, out=buf1262)
buf1263 = buf1221; del buf1221 # reuse
# Topologically Sorted Source Nodes: [view_397, permute_396, view_398, permute_399, mm_168], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf1243, (327680, 2048), (2048, 1), 0), reinterpret_tensor(buf1222, (2048, 2048), (2048, 1), 0), out=buf1263)
# Topologically Sorted Source Nodes: [all_reduce_119], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1262, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_119], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1262)
buf1268 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_939], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_16.run(buf1262, buf1268, 4194304, stream=stream0)
# Fused layer-norm backward with the residual adds for this layer; the
# workspace again stages two 2560x2048 partial-sum slabs for the param grads.
buf1276 = buf1205; del buf1205 # reuse
# Call mix order reduction kernel
# Topologically Sorted Source Nodes: [redistribute, layer_norm, add_187, add_188, convert_element_type_940, convert_element_type_942, mul_413, mul_414, sum_101, mul_415, sum_102, mul_416, sub_99, sub_100, div_23, mul_417, mul_418, sum_103, sum_104, convert_element_type_944, add_189], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.add, aten.native_layer_norm_backward]
workspace_20 = workspace_19; del workspace_19 # reuse
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_20.run(buf1276, buf1249, buf1256, buf1263, primals_47, add_27, buf1218, buf1219, workspace_20, 327680, 2048, stream=stream0)
buf1273 = workspace_20[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
buf1275 = workspace_20[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
del add_27
del primals_47
buf1283 = buf1212; del buf1212 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_945, all_reduce_121], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf1273, buf1283, 2048, stream=stream0)
buf1277 = buf1206; del buf1206 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_946, all_reduce_120], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf1275, buf1277, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_946, all_reduce_120], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1277, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_120], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1277)
buf1282 = buf1275; del buf1275 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_947], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf1277, buf1282, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_945, all_reduce_121], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1283, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_121], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1283)
buf1288 = buf1273; del buf1273 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_948], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf1283, buf1288, 2048, stream=stream0)
# NOTE(review): generated Inductor code; annotations only -- do not reorder.
# Previous layer's MLP backward: input grad through the down projection
# (primals_45), then recompute layer norm + up projection + GELU forward.
buf1289 = reinterpret_tensor(buf1186, (2048, 8192), (8192, 1), 0); del buf1186 # reuse
# Topologically Sorted Source Nodes: [redistribute_4], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_3.run(primals_45, buf1289, 16777216, stream=stream0)
del primals_45
buf1290 = buf1184; del buf1184 # reuse
# Topologically Sorted Source Nodes: [redistribute_4, x_6, permute_401, mm_169], Original ATen: [aten._to_copy, aten.t, aten.mm]
extern_kernels.mm(buf1276, buf1289, out=buf1290)
buf1291 = buf1219; del buf1219 # reuse
buf1292 = buf1218; del buf1218 # reuse
buf1294 = buf1263; del buf1263 # reuse
# Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_native_layer_norm_4.run(add_22, primals_41, primals_42, buf1291, buf1292, buf1294, 327680, 2048, stream=stream0)
del primals_42
buf1295 = reinterpret_tensor(buf1289, (2048, 8192), (1, 2048), 0); del buf1289 # reuse
# Topologically Sorted Source Nodes: [redistribute_2, x], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_3.run(primals_43, buf1295, 16777216, stream=stream0)
del primals_43
buf1296 = buf1170; del buf1170 # reuse
# Topologically Sorted Source Nodes: [x], Original ATen: [aten.addmm]
extern_kernels.mm(buf1294, buf1295, out=buf1296)
buf1297 = buf1169; del buf1169 # reuse
buf1311 = buf1290; del buf1290 # reuse
# Fused GELU forward recompute + GELU backward (buf1311 becomes the grad
# w.r.t. the up-projection output, buf1297 the recomputed activation).
# Topologically Sorted Source Nodes: [redistribute_3, x, x_5, convert_element_type_955, mul_424, mul_425, sub_101, mul_426, add_192, mul_427, mul_428, mul_429, add_193, mul_430, convert_element_type_957], Original ATen: [aten._to_copy, aten.addmm, aten.gelu, aten.gelu_backward]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_addmm_gelu_gelu_backward_5.run(buf1311, primals_44, buf1296, buf1297, 2684354560, stream=stream0)
del primals_44
buf1298 = buf1171; del buf1171 # reuse
# Down-projection weight grad, then bias grad via two-stage column sum and
# the averaging all_reduce on a flat (2048,) view -- same pattern as before.
# Topologically Sorted Source Nodes: [permute_402, redistribute_3, x, x_5, mm_170], Original ATen: [aten.t, aten._to_copy, aten.addmm, aten.gelu, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf1276, (2048, 327680), (1, 2048), 0), buf1297, out=buf1298)
buf1299 = buf1172; del buf1172 # reuse
# Topologically Sorted Source Nodes: [sum_105], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_6.run(buf1276, buf1299, 327680, 2048, stream=stream0)
buf1300 = reinterpret_tensor(buf1283, (1, 2048), (2048, 1), 0); del buf1283 # reuse
# Topologically Sorted Source Nodes: [sum_105], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_7.run(buf1299, buf1300, 2048, 160, stream=stream0)
# Topologically Sorted Source Nodes: [view_399, all_reduce_122], Original ATen: [aten.view, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf1300, (2048, ), (1, ), 0), 'avg', '0')
# Topologically Sorted Source Nodes: [view_399, wait_tensor_122], Original ATen: [aten.view, _c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf1300, (2048, ), (1, ), 0))
buf1305 = empty_strided_cuda((2048, ), (1, ), torch.float32)
# Topologically Sorted Source Nodes: [view_399, convert_element_type_953], Original ATen: [aten.view, aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf1300, buf1305, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [all_reduce_123], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1298, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_123], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1298)
buf1310 = empty_strided_cuda((2048, 8192), (8192, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_954], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_8.run(buf1298, buf1310, 16777216, stream=stream0)
buf1312 = buf1256; del buf1256 # reuse
# Up-projection backward: input grad, weight grad, bias-grad two-stage sum,
# and the all_reduce/wait/fp32-cast sequence.
# Topologically Sorted Source Nodes: [permute_405, mm_171], Original ATen: [aten.t, aten.mm]
extern_kernels.mm(buf1311, reinterpret_tensor(buf1295, (8192, 2048), (2048, 1), 0), out=buf1312)
buf1313 = reinterpret_tensor(buf1295, (8192, 2048), (2048, 1), 0); del buf1295 # reuse
# Topologically Sorted Source Nodes: [permute_406, mm_172], Original ATen: [aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf1311, (8192, 327680), (1, 8192), 0), buf1294, out=buf1313)
buf1314 = buf1187; del buf1187 # reuse
# Topologically Sorted Source Nodes: [sum_106], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_9.run(buf1311, buf1314, 1310720, 2048, stream=stream0)
buf1315 = buf1188; del buf1188 # reuse
# Topologically Sorted Source Nodes: [sum_106], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_10.run(buf1314, buf1315, 8192, 160, stream=stream0)
# Topologically Sorted Source Nodes: [view_400, all_reduce_124], Original ATen: [aten.view, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf1315, (8192, ), (1, ), 0), 'avg', '0')
# Topologically Sorted Source Nodes: [view_400, wait_tensor_124], Original ATen: [aten.view, _c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf1315, (8192, ), (1, ), 0))
buf1320 = empty_strided_cuda((8192, ), (1, ), torch.float32)
# Topologically Sorted Source Nodes: [view_400, convert_element_type_962], Original ATen: [aten.view, aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_view_11.run(buf1315, buf1320, 8192, stream=stream0)
# Topologically Sorted Source Nodes: [all_reduce_125], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1313, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_125], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1313)
buf1325 = empty_strided_cuda((8192, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_963], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_8.run(buf1313, buf1325, 16777216, stream=stream0)
buf1332 = buf1276; del buf1276 # reuse
# Call mix order reduction kernel
# Topologically Sorted Source Nodes: [redistribute, layer_norm, convert_element_type_964, convert_element_type_966, mul_432, mul_433, sum_107, mul_434, sum_108, mul_435, sub_103, sub_104, div_24, mul_436, mul_437, sum_109, sum_110, convert_element_type_968, add_194], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.native_layer_norm_backward, aten.add]
workspace_21 = workspace_20; del workspace_20 # reuse
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_12.run(buf1332, buf1312, primals_41, add_22, buf1291, buf1292, workspace_21, 327680, 2048, stream=stream0)
buf1329 = workspace_21[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
buf1331 = workspace_21[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
del add_22
del primals_41
buf1339 = reinterpret_tensor(buf1300, (2048, ), (1, ), 0); del buf1300 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_969, all_reduce_127], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf1329, buf1339, 2048, stream=stream0)
buf1333 = buf1277; del buf1277 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_970, all_reduce_126], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf1331, buf1333, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_970, all_reduce_126], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1333, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_126], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1333)
buf1338 = buf1331; del buf1331 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_971], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf1333, buf1338, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_969, all_reduce_127], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1339, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_127], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1339)
buf1344 = buf1329; del buf1329 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_972], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf1339, buf1344, 2048, stream=stream0)
buf1345 = buf1292; del buf1292 # reuse
buf1346 = buf1291; del buf1291 # reuse
buf1348 = buf1312; del buf1312 # reuse
# Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_native_layer_norm_4.run(add_19, primals_35, primals_36, buf1345, buf1346, buf1348, 327680, 2048, stream=stream0)
del primals_36
buf1349 = reinterpret_tensor(buf1262, (2048, 2048), (1, 2048), 0); del buf1262 # reuse
# Topologically Sorted Source Nodes: [redistribute_2, linear], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_13.run(primals_37, buf1349, 4194304, stream=stream0)
del primals_37
buf1350 = buf1294; del buf1294 # reuse
# Topologically Sorted Source Nodes: [linear], Original ATen: [aten.mm]
extern_kernels.mm(buf1348, buf1349, out=buf1350)
buf1351 = reinterpret_tensor(buf1255, (2048, 512), (1, 2048), 0); del buf1255 # reuse
# Topologically Sorted Source Nodes: [redistribute_3, linear_1], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_14.run(primals_38, buf1351, 1048576, stream=stream0)
del primals_38
buf1352 = reinterpret_tensor(buf1245, (327680, 512), (512, 1), 0); del buf1245 # reuse
# Topologically Sorted Source Nodes: [linear_1], Original ATen: [aten.mm]
extern_kernels.mm(buf1348, buf1351, out=buf1352)
buf1353 = buf1224; del buf1224 # reuse
# Topologically Sorted Source Nodes: [redistribute_4, linear_2], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_14.run(primals_39, buf1353, 1048576, stream=stream0)
del primals_39
buf1354 = reinterpret_tensor(buf1244, (327680, 512), (512, 1), 0); del buf1244 # reuse
# Topologically Sorted Source Nodes: [linear_2], Original ATen: [aten.mm]
extern_kernels.mm(buf1348, buf1353, out=buf1354)
buf1355 = buf1242; del buf1242 # reuse
buf1356 = buf1228; del buf1228 # reuse
buf1357 = reinterpret_tensor(buf1249, (1, 16, 327680, 128), (671088640, 128, 2048, 1), 0); del buf1249 # reuse
# Topologically Sorted Source Nodes: [q, k, v, flex_attention], Original ATen: [aten.view, aten.permute, flex_attention]
stream0 = get_raw_stream(0)
triton_tem_fused_flex_attention_permute_view_15.run(buf1350, buf1352, buf1354, buf1355, buf1356, primals_18, primals_17, primals_19, primals_21, primals_14, primals_15, buf1357, s91, 2560, 1, 16, stream=stream0)
buf1360 = reinterpret_tensor(buf1222, (2048, 2048), (2048, 1), 0); del buf1222 # reuse
# Topologically Sorted Source Nodes: [permute_409, rearrange_3, o, mm_173], Original ATen: [aten.t, aten.permute, aten.view, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf1332, (2048, 327680), (1, 2048), 0), reinterpret_tensor(buf1357, (327680, 2048), (2048, 1), 0), out=buf1360)
buf1361 = buf1234; del buf1234 # reuse
# Topologically Sorted Source Nodes: [redistribute_5], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_13.run(primals_40, buf1361, 4194304, stream=stream0)
del primals_40
buf1362 = reinterpret_tensor(buf1243, (327680, 2048), (2048, 1), 0); del buf1243 # reuse
# Topologically Sorted Source Nodes: [redistribute_5, o, permute_411, mm_174], Original ATen: [aten._to_copy, aten.t, aten.mm]
extern_kernels.mm(buf1332, buf1361, out=buf1362)
# Topologically Sorted Source Nodes: [all_reduce_128], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1360, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_128], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1360)
buf1367 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_977], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_16.run(buf1360, buf1367, 4194304, stream=stream0)
buf1369 = buf1356; del buf1356 # reuse
# Topologically Sorted Source Nodes: [q, k, v, view_403, view_404, permute_413, flex_attention_backward_10], Original ATen: [aten.view, aten.permute, flex_attention_backward]
stream0 = get_raw_stream(0)
triton_per_fused_flex_attention_backward_permute_view_17.run(buf1357, buf1362, buf1369, 5242880, 128, stream=stream0)
buf1370 = buf1357; del buf1357 # reuse
buf1371 = reinterpret_tensor(buf1227, (1, 4, 327680, 128), (167772160, 128, 512, 1), 0); del buf1227 # reuse
buf1372 = reinterpret_tensor(buf1225, (1, 4, 327680, 128), (167772160, 128, 512, 1), 0); del buf1225 # reuse
# Topologically Sorted Source Nodes: [q, k, v, view_403, view_404, permute_413, flex_attention_backward_10], Original ATen: [aten.view, aten.permute, flex_attention_backward]
stream0 = get_raw_stream(0)
triton_tem_fused_flex_attention_backward_permute_view_18.run(buf1350, buf1352, buf1354, buf1355, buf1369, buf1362, buf1370, buf1371, primals_18, primals_17, primals_22, primals_24, primals_19, primals_21, primals_25, primals_27, primals_14, primals_15, buf1372, s91, s16, 12800, 1, 4, stream=stream0)
buf1375 = reinterpret_tensor(buf1226, (512, 2048), (2048, 1), 0); del buf1226 # reuse
# Topologically Sorted Source Nodes: [view_405, permute_414, view_406, permute_415, mm_175], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf1371, (512, 327680), (1, 512), 0), buf1348, out=buf1375)
buf1376 = buf1362; del buf1362 # reuse
# Topologically Sorted Source Nodes: [view_405, permute_414, view_406, permute_417, mm_176], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf1371, (327680, 512), (512, 1), 0), reinterpret_tensor(buf1353, (512, 2048), (2048, 1), 0), out=buf1376)
# Topologically Sorted Source Nodes: [all_reduce_129], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1375, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_129], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1375)
buf1381 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_982], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_19.run(buf1375, buf1381, 1048576, stream=stream0)
buf1382 = buf1375; del buf1375 # reuse
# Topologically Sorted Source Nodes: [view_407, permute_419, view_408, permute_420, mm_177], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf1372, (512, 327680), (1, 512), 0), buf1348, out=buf1382)
buf1383 = buf1350; del buf1350 # reuse
# Topologically Sorted Source Nodes: [view_407, permute_419, view_408, permute_422, mm_178], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf1372, (327680, 512), (512, 1), 0), reinterpret_tensor(buf1351, (512, 2048), (2048, 1), 0), out=buf1383)
# Topologically Sorted Source Nodes: [all_reduce_130], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1382, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_130], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1382)
buf1388 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_987], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_19.run(buf1382, buf1388, 1048576, stream=stream0)
buf1389 = buf1360; del buf1360 # reuse
# Topologically Sorted Source Nodes: [view_409, permute_424, view_410, permute_425, mm_179], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf1370, (2048, 327680), (1, 2048), 0), buf1348, out=buf1389)
buf1390 = buf1348; del buf1348 # reuse
# Topologically Sorted Source Nodes: [view_409, permute_424, view_410, permute_427, mm_180], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf1370, (327680, 2048), (2048, 1), 0), reinterpret_tensor(buf1349, (2048, 2048), (2048, 1), 0), out=buf1390)
# Topologically Sorted Source Nodes: [all_reduce_131], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1389, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_131], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1389)
buf1395 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_992], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_16.run(buf1389, buf1395, 4194304, stream=stream0)
buf1403 = buf1332; del buf1332 # reuse
# Call mix order reduction kernel
# Topologically Sorted Source Nodes: [redistribute, layer_norm, add_195, add_196, convert_element_type_993, convert_element_type_995, mul_439, mul_440, sum_111, mul_441, sum_112, mul_442, sub_106, sub_107, div_25, mul_443, mul_444, sum_113, sum_114, convert_element_type_997, add_197], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.add, aten.native_layer_norm_backward]
workspace_22 = workspace_21; del workspace_21 # reuse
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_20.run(buf1403, buf1376, buf1383, buf1390, primals_35, add_19, buf1345, buf1346, workspace_22, 327680, 2048, stream=stream0)
buf1400 = workspace_22[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
buf1402 = workspace_22[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
del add_19
del primals_35
buf1410 = buf1339; del buf1339 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_998, all_reduce_133], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf1400, buf1410, 2048, stream=stream0)
buf1404 = buf1333; del buf1333 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_999, all_reduce_132], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf1402, buf1404, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_999, all_reduce_132], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1404, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_132], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1404)
buf1409 = buf1402; del buf1402 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_1000], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf1404, buf1409, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_998, all_reduce_133], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1410, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_133], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1410)
buf1415 = buf1400; del buf1400 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_1001], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf1410, buf1415, 2048, stream=stream0)
buf1416 = reinterpret_tensor(buf1313, (2048, 8192), (8192, 1), 0); del buf1313 # reuse
# Topologically Sorted Source Nodes: [redistribute_4], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_3.run(primals_33, buf1416, 16777216, stream=stream0)
del primals_33
buf1417 = buf1311; del buf1311 # reuse
# Topologically Sorted Source Nodes: [redistribute_4, x_4, permute_429, mm_181], Original ATen: [aten._to_copy, aten.t, aten.mm]
extern_kernels.mm(buf1403, buf1416, out=buf1417)
buf1418 = buf1346; del buf1346 # reuse
buf1419 = buf1345; del buf1345 # reuse
buf1421 = buf1390; del buf1390 # reuse
# Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_native_layer_norm_4.run(add_14, primals_29, primals_30, buf1418, buf1419, buf1421, 327680, 2048, stream=stream0)
del primals_30
buf1422 = reinterpret_tensor(buf1416, (2048, 8192), (1, 2048), 0); del buf1416 # reuse
# Topologically Sorted Source Nodes: [redistribute_2, x], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_3.run(primals_31, buf1422, 16777216, stream=stream0)
del primals_31
buf1423 = buf1297; del buf1297 # reuse
# Topologically Sorted Source Nodes: [x], Original ATen: [aten.addmm]
extern_kernels.mm(buf1421, buf1422, out=buf1423)
buf1424 = buf1296; del buf1296 # reuse
buf1438 = buf1417; del buf1417 # reuse
# Topologically Sorted Source Nodes: [redistribute_3, x, x_3, convert_element_type_1008, mul_450, mul_451, sub_108, mul_452, add_200, mul_453, mul_454, mul_455, add_201, mul_456, convert_element_type_1010], Original ATen: [aten._to_copy, aten.addmm, aten.gelu, aten.gelu_backward]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_addmm_gelu_gelu_backward_5.run(buf1438, primals_32, buf1423, buf1424, 2684354560, stream=stream0)
del buf1423
del primals_32
buf1425 = buf1298; del buf1298 # reuse
# Topologically Sorted Source Nodes: [permute_430, redistribute_3, x, x_3, mm_182], Original ATen: [aten.t, aten._to_copy, aten.addmm, aten.gelu, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf1403, (2048, 327680), (1, 2048), 0), buf1424, out=buf1425)
del buf1424
buf1426 = buf1299; del buf1299 # reuse
# Topologically Sorted Source Nodes: [sum_115], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_6.run(buf1403, buf1426, 327680, 2048, stream=stream0)
buf1427 = reinterpret_tensor(buf1410, (1, 2048), (2048, 1), 0); del buf1410 # reuse
# Topologically Sorted Source Nodes: [sum_115], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_7.run(buf1426, buf1427, 2048, 160, stream=stream0)
# Topologically Sorted Source Nodes: [view_411, all_reduce_134], Original ATen: [aten.view, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf1427, (2048, ), (1, ), 0), 'avg', '0')
# Topologically Sorted Source Nodes: [view_411, wait_tensor_134], Original ATen: [aten.view, _c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf1427, (2048, ), (1, ), 0))
buf1432 = empty_strided_cuda((2048, ), (1, ), torch.float32)
# Topologically Sorted Source Nodes: [view_411, convert_element_type_1006], Original ATen: [aten.view, aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf1427, buf1432, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [all_reduce_135], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1425, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_135], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1425)
buf1437 = empty_strided_cuda((2048, 8192), (8192, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_1007], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_8.run(buf1425, buf1437, 16777216, stream=stream0)
del buf1425
buf1439 = buf1383; del buf1383 # reuse
# Topologically Sorted Source Nodes: [permute_433, mm_183], Original ATen: [aten.t, aten.mm]
extern_kernels.mm(buf1438, reinterpret_tensor(buf1422, (8192, 2048), (2048, 1), 0), out=buf1439)
buf1440 = reinterpret_tensor(buf1422, (8192, 2048), (2048, 1), 0); del buf1422 # reuse
# Topologically Sorted Source Nodes: [permute_434, mm_184], Original ATen: [aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf1438, (8192, 327680), (1, 8192), 0), buf1421, out=buf1440)
buf1441 = buf1314; del buf1314 # reuse
# Topologically Sorted Source Nodes: [sum_116], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_9.run(buf1438, buf1441, 1310720, 2048, stream=stream0)
del buf1438
buf1442 = buf1315; del buf1315 # reuse
# Topologically Sorted Source Nodes: [sum_116], Original ATen: [aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused_sum_10.run(buf1441, buf1442, 8192, 160, stream=stream0)
del buf1441
# Topologically Sorted Source Nodes: [view_412, all_reduce_136], Original ATen: [aten.view, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf1442, (8192, ), (1, ), 0), 'avg', '0')
# Topologically Sorted Source Nodes: [view_412, wait_tensor_136], Original ATen: [aten.view, _c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf1442, (8192, ), (1, ), 0))
buf1447 = empty_strided_cuda((8192, ), (1, ), torch.float32)
# Topologically Sorted Source Nodes: [view_412, convert_element_type_1015], Original ATen: [aten.view, aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_view_11.run(buf1442, buf1447, 8192, stream=stream0)
del buf1442
# Topologically Sorted Source Nodes: [all_reduce_137], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1440, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_137], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1440)
buf1452 = empty_strided_cuda((8192, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_1016], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_8.run(buf1440, buf1452, 16777216, stream=stream0)
del buf1440
buf1459 = buf1403; del buf1403 # reuse
# Call mix order reduction kernel
# Topologically Sorted Source Nodes: [redistribute, layer_norm, convert_element_type_1017, convert_element_type_1019, mul_458, mul_459, sum_117, mul_460, sum_118, mul_461, sub_110, sub_111, div_26, mul_462, mul_463, sum_119, sum_120, convert_element_type_1021, add_202], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.native_layer_norm_backward, aten.add]
workspace_23 = workspace_22; del workspace_22 # reuse
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_12.run(buf1459, buf1439, primals_29, add_14, buf1418, buf1419, workspace_23, 327680, 2048, stream=stream0)
buf1456 = workspace_23[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
buf1458 = workspace_23[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
del add_14
del primals_29
buf1466 = reinterpret_tensor(buf1427, (2048, ), (1, ), 0); del buf1427 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_1022, all_reduce_139], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf1456, buf1466, 2048, stream=stream0)
buf1460 = buf1404; del buf1404 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_1023, all_reduce_138], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf1458, buf1460, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_1023, all_reduce_138], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1460, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_138], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1460)
buf1465 = buf1458; del buf1458 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_1024], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf1460, buf1465, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_1022, all_reduce_139], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1466, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_139], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1466)
buf1471 = buf1456; del buf1456 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_1025], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf1466, buf1471, 2048, stream=stream0)
buf1472 = buf1419; del buf1419 # reuse
buf1473 = buf1418; del buf1418 # reuse
buf1475 = buf1439; del buf1439 # reuse
# Topologically Sorted Source Nodes: [redistribute, redistribute_1, layer_norm], Original ATen: [aten._to_copy, aten.native_layer_norm]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_native_layer_norm_4.run(add_11, primals_9, primals_10, buf1472, buf1473, buf1475, 327680, 2048, stream=stream0)
del primals_10
buf1476 = reinterpret_tensor(buf1389, (2048, 2048), (1, 2048), 0); del buf1389 # reuse
# Topologically Sorted Source Nodes: [redistribute_2, linear], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_13.run(primals_11, buf1476, 4194304, stream=stream0)
del primals_11
buf1477 = buf1421; del buf1421 # reuse
# Topologically Sorted Source Nodes: [linear], Original ATen: [aten.mm]
extern_kernels.mm(buf1475, buf1476, out=buf1477)
buf1478 = reinterpret_tensor(buf1382, (2048, 512), (1, 2048), 0); del buf1382 # reuse
# Topologically Sorted Source Nodes: [redistribute_3, linear_1], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_14.run(primals_12, buf1478, 1048576, stream=stream0)
del primals_12
buf1479 = reinterpret_tensor(buf1372, (327680, 512), (512, 1), 0); del buf1372 # reuse
# Topologically Sorted Source Nodes: [linear_1], Original ATen: [aten.mm]
extern_kernels.mm(buf1475, buf1478, out=buf1479)
buf1480 = buf1351; del buf1351 # reuse
# Topologically Sorted Source Nodes: [redistribute_4, linear_2], Original ATen: [aten._to_copy, aten.t]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_14.run(primals_13, buf1480, 1048576, stream=stream0)
del primals_13
buf1481 = reinterpret_tensor(buf1371, (327680, 512), (512, 1), 0); del buf1371 # reuse
# Topologically Sorted Source Nodes: [linear_2], Original ATen: [aten.mm]
extern_kernels.mm(buf1475, buf1480, out=buf1481)
buf1482 = buf1369; del buf1369 # reuse
buf1483 = buf1355; del buf1355 # reuse
buf1484 = reinterpret_tensor(buf1376, (1, 16, 327680, 128), (671088640, 128, 2048, 1), 0); del buf1376 # reuse
# Topologically Sorted Source Nodes: [q, k, v, flex_attention], Original ATen: [aten.view, aten.permute, flex_attention]
stream0 = get_raw_stream(0)
triton_tem_fused_flex_attention_permute_view_15.run(buf1477, buf1479, buf1481, buf1482, buf1483, primals_18, primals_17, primals_19, primals_21, primals_14, primals_15, buf1484, s91, 2560, 1, 16, stream=stream0)
buf1487 = reinterpret_tensor(buf1349, (2048, 2048), (2048, 1), 0); del buf1349 # reuse
# Topologically Sorted Source Nodes: [permute_437, rearrange_3, o, mm_185], Original ATen: [aten.t, aten.permute, aten.view, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf1459, (2048, 327680), (1, 2048), 0), reinterpret_tensor(buf1484, (327680, 2048), (2048, 1), 0), out=buf1487)
buf1488 = buf1361; del buf1361 # reuse
# Topologically Sorted Source Nodes: [redistribute_5], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_t_13.run(primals_28, buf1488, 4194304, stream=stream0)
del primals_28
buf1489 = reinterpret_tensor(buf1370, (327680, 2048), (2048, 1), 0); del buf1370 # reuse
# Topologically Sorted Source Nodes: [redistribute_5, o, permute_439, mm_186], Original ATen: [aten._to_copy, aten.t, aten.mm]
extern_kernels.mm(buf1459, buf1488, out=buf1489)
del buf1488
# Topologically Sorted Source Nodes: [all_reduce_140], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1487, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_140], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1487)
buf1494 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_1030], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_16.run(buf1487, buf1494, 4194304, stream=stream0)
buf1496 = buf1483; del buf1483 # reuse
# Topologically Sorted Source Nodes: [q, k, v, view_415, view_416, permute_441, flex_attention_backward_11], Original ATen: [aten.view, aten.permute, flex_attention_backward]
stream0 = get_raw_stream(0)
triton_per_fused_flex_attention_backward_permute_view_17.run(buf1484, buf1489, buf1496, 5242880, 128, stream=stream0)
buf1497 = buf1484; del buf1484 # reuse
buf1498 = reinterpret_tensor(buf1354, (1, 4, 327680, 128), (167772160, 128, 512, 1), 0); del buf1354 # reuse
buf1499 = reinterpret_tensor(buf1352, (1, 4, 327680, 128), (167772160, 128, 512, 1), 0); del buf1352 # reuse
# Topologically Sorted Source Nodes: [q, k, v, view_415, view_416, permute_441, flex_attention_backward_11], Original ATen: [aten.view, aten.permute, flex_attention_backward]
stream0 = get_raw_stream(0)
triton_tem_fused_flex_attention_backward_permute_view_18.run(buf1477, buf1479, buf1481, buf1482, buf1496, buf1489, buf1497, buf1498, primals_18, primals_17, primals_22, primals_24, primals_19, primals_21, primals_25, primals_27, primals_14, primals_15, buf1499, s91, s16, 12800, 1, 4, stream=stream0)
del buf1479
del buf1481
del buf1482
del buf1496
del primals_14
del primals_15
del primals_17
del primals_18
del primals_19
del primals_21
del primals_22
del primals_24
del primals_25
del primals_27
buf1502 = reinterpret_tensor(buf1353, (512, 2048), (2048, 1), 0); del buf1353 # reuse
# Topologically Sorted Source Nodes: [view_417, permute_442, view_418, permute_443, mm_187], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf1498, (512, 327680), (1, 512), 0), buf1475, out=buf1502)
buf1503 = buf1489; del buf1489 # reuse
# Topologically Sorted Source Nodes: [view_417, permute_442, view_418, permute_445, mm_188], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf1498, (327680, 512), (512, 1), 0), reinterpret_tensor(buf1480, (512, 2048), (2048, 1), 0), out=buf1503)
del buf1480
del buf1498
# Topologically Sorted Source Nodes: [all_reduce_141], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1502, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_141], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1502)
buf1508 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_1035], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_19.run(buf1502, buf1508, 1048576, stream=stream0)
buf1509 = buf1502; del buf1502 # reuse
# Topologically Sorted Source Nodes: [view_419, permute_447, view_420, permute_448, mm_189], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf1499, (512, 327680), (1, 512), 0), buf1475, out=buf1509)
buf1510 = buf1477; del buf1477 # reuse
# Topologically Sorted Source Nodes: [view_419, permute_447, view_420, permute_450, mm_190], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf1499, (327680, 512), (512, 1), 0), reinterpret_tensor(buf1478, (512, 2048), (2048, 1), 0), out=buf1510)
del buf1478
del buf1499
# Topologically Sorted Source Nodes: [all_reduce_142], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1509, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_142], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1509)
buf1515 = empty_strided_cuda((512, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_1040], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_19.run(buf1509, buf1515, 1048576, stream=stream0)
del buf1509
buf1516 = buf1487; del buf1487 # reuse
# Topologically Sorted Source Nodes: [view_421, permute_452, view_422, permute_453, mm_191], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf1497, (2048, 327680), (1, 2048), 0), buf1475, out=buf1516)
buf1517 = buf1475; del buf1475 # reuse
# Topologically Sorted Source Nodes: [view_421, permute_452, view_422, permute_455, mm_192], Original ATen: [aten.view, aten.permute, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf1497, (327680, 2048), (2048, 1), 0), reinterpret_tensor(buf1476, (2048, 2048), (2048, 1), 0), out=buf1517)
del buf1476
# Topologically Sorted Source Nodes: [all_reduce_143], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1516, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_143], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1516)
buf1522 = empty_strided_cuda((2048, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_1045], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_16.run(buf1516, buf1522, 4194304, stream=stream0)
del buf1516
buf1542 = empty_strided_cuda((327680, 16), (16, 1), torch.uint8)
# Topologically Sorted Source Nodes: [getitem, mask, getitem_1, data, getitem_3, contiguous_1], Original ATen: [aten.select, aten.eq, aten.unsqueeze, aten.mul, aten.slice, aten.clone]
stream0 = get_raw_stream(0)
triton_poi_fused_clone_eq_mul_select_slice_unsqueeze_21.run(primals_1, buf1542, 5242880, stream=stream0)
# Topologically Sorted Source Nodes: [getitem, mask, getitem_1, data, getitem_3, contiguous_1, positions], Original ATen: [aten.select, aten.eq, aten.unsqueeze, aten.mul, aten.slice, aten.clone, aten.view]
buf1543 = torch.ops.aten.view.dtype(buf1542, torch.int32)
buf1544 = buf1543
assert_size_stride(buf1544, (327680, 4), (4, 1), 'torch.ops.aten.view.dtype')
assert_alignment(buf1544, 16, 'torch.ops.aten.view.dtype')
buf1545 = empty_strided_cuda((327680, 2048), (2048, 1), torch.float32)
buf1546 = reinterpret_tensor(buf1426, (327680, 1), (1, 327680), 0); del buf1426 # reuse
buf1547 = empty_strided_cuda((327680, 1), (1, 327680), torch.float32)
# Topologically Sorted Source Nodes: [i, j, ifreqs, ifreqs_1, neg, freqs, getitem_6, getitem_7, mul_1, sin, cos, getitem_10, mul_3, sin_1, cos_1, posemb, to_1, layer_norm_1], Original ATen: [aten.select, aten.arange, aten.div, aten.neg, aten.pow, aten.unsqueeze, aten.mul, aten.sin, aten.cos, aten.cat, aten._to_copy, aten.native_layer_norm]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_arange_cat_cos_div_mul_native_layer_norm_neg_pow_select_sin_unsqueeze_22.run(buf1544, buf1545, buf1546, buf1547, 327680, 2048, stream=stream0)
del buf1542
del buf1543
del buf1544
buf1568 = empty_strided_cuda((327680, 768), (768, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [getitem, mask, getitem_1, data, getitem_2, patches, to, truediv, patches_1], Original ATen: [aten.select, aten.eq, aten.unsqueeze, aten.mul, aten.slice, aten.clone, aten._to_copy, aten.div, aten.sub]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_clone_div_eq_mul_select_slice_sub_unsqueeze_23.run(primals_1, buf1568, 251658240, stream=stream0)
buf1569 = empty_strided_cuda((2048, 768), (768, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [redistribute_1], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_24.run(primals_4, buf1569, 1572864, stream=stream0)
del primals_4
buf1570 = reinterpret_tensor(buf1497, (327680, 2048), (2048, 1), 0); del buf1497 # reuse
# Topologically Sorted Source Nodes: [redistribute_1, linear], Original ATen: [aten._to_copy, aten.t, aten.mm]
extern_kernels.mm(buf1568, reinterpret_tensor(buf1569, (768, 2048), (1, 768), 0), out=buf1570)
buf1571 = empty_strided_cuda((327680, 1), (1, 327680), torch.float32)
buf1572 = empty_strided_cuda((327680, 1), (1, 327680), torch.float32)
# Topologically Sorted Source Nodes: [x], Original ATen: [aten.native_layer_norm]
stream0 = get_raw_stream(0)
triton_red_fused_native_layer_norm_25.run(buf1570, buf1571, buf1572, 327680, 2048, stream=stream0)
buf1525 = empty_strided_cuda((327680, 2048), (2048, 1), torch.float32)
buf1549 = buf1545; del buf1545 # reuse
buf1575 = empty_strided_cuda((327680, 2048), (2048, 1), torch.float32)
buf1592 = empty_strided_cuda((327680, 2048), (2048, 1), torch.bfloat16)
buf1585 = empty_strided_cuda((327680, 2048), (2048, 1), torch.bfloat16)
# Call mix order reduction kernel
# Topologically Sorted Source Nodes: [redistribute, layer_norm, add_203, add_204, convert_element_type_1046, convert_element_type_1048, mul_465, mul_466, sum_121, mul_467, sum_122, mul_468, sub_113, sub_114, div_27, mul_469, mul_470, sum_123, sum_124, convert_element_type_1050, add_205, getitem, mask, getitem_1, mul_471, convert_element_type_1055, to_1, layer_norm_1, mul_477, redistribute_2, convert_element_type_1065, mul_479, mul_480, sum_129, x, mul_481, sum_130, mul_482, sub_119, sub_120, div_28, mul_483, mul_484, convert_element_type_1067, mul_485, slice_21, full_default, copy_3], Original ATen: [aten._to_copy, aten.native_layer_norm, aten.add, aten.native_layer_norm_backward, aten.select, aten.eq, aten.unsqueeze, aten.mul, aten.slice, aten.zeros_like, aten.copy]
workspace_24 = workspace_23; del workspace_23 # reuse
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_copy_eq_mul_native_layer_norm_native_layer_norm_backward_select_slice_unsqueeze_zeros_like_26.run(buf1549, buf1503, buf1510, buf1517, primals_9, add_11, buf1472, buf1473, buf1459, primals_1, buf1546, buf1547, primals_5, buf1570, buf1571, buf1572, buf1525, buf1575, buf1592, buf1585, workspace_24, 327680, 2048, stream=stream0)
buf1527 = workspace_24[0 * 2560 * 2048 : (0 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
buf1529 = workspace_24[1 * 2560 * 2048 : (1 + 1) * 2560 * 2048].view(2560, 2048).sum(dim=0)
del workspace_24
del add_11
del buf1472
del buf1503
del buf1510
del buf1517
del buf1546
del buf1547
del buf1570
del buf1571
del primals_5
del primals_9
buf1536 = buf1466; del buf1466 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_1051, all_reduce_145], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf1527, buf1536, 2048, stream=stream0)
buf1530 = buf1460; del buf1460 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_1052, all_reduce_144], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_native_layer_norm_backward_1.run(buf1529, buf1530, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_1052, all_reduce_144], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1530, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_144], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1530)
buf1535 = buf1529; del buf1529 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_1053], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf1530, buf1535, 2048, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_1051, all_reduce_145], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1536, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_145], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1536)
buf1541 = buf1527; del buf1527 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_1054], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf1536, buf1541, 2048, stream=stream0)
buf1550 = reinterpret_tensor(buf1572, (2048, 160), (1, 2048), 0); del buf1572 # reuse
# Topologically Sorted Source Nodes: [sum_127], Original ATen: [aten.native_layer_norm_backward]
stream0 = get_raw_stream(0)
triton_red_fused_native_layer_norm_backward_27.run(buf1549, buf1550, 327680, 2048, stream=stream0)
del buf1549
buf1560 = buf1536; del buf1536 # reuse
# Topologically Sorted Source Nodes: [sum_127, convert_element_type_1059, all_reduce_147], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_red_fused_all_reduce_native_layer_norm_backward_28.run(buf1550, buf1560, 2048, 160, stream=stream0)
buf1552 = buf1550; del buf1550 # reuse
# Topologically Sorted Source Nodes: [layer_norm, div_27, mul_469, convert_element_type_1050, add_205, getitem, mask, getitem_1, mul_471, convert_element_type_1055, sum_128], Original ATen: [aten.native_layer_norm, aten.native_layer_norm_backward, aten.add, aten.select, aten.eq, aten.unsqueeze, aten.mul]
stream0 = get_raw_stream(0)
triton_red_fused_add_eq_mul_native_layer_norm_native_layer_norm_backward_select_unsqueeze_29.run(buf1459, buf1473, buf1525, primals_1, buf1552, 327680, 2048, stream=stream0)
buf1554 = buf1530; del buf1530 # reuse
# Topologically Sorted Source Nodes: [layer_norm, div_27, mul_469, convert_element_type_1050, add_205, getitem, mask, getitem_1, mul_471, convert_element_type_1055, sum_128, convert_element_type_1060, all_reduce_146], Original ATen: [aten.native_layer_norm, aten.native_layer_norm_backward, aten.add, aten.select, aten.eq, aten.unsqueeze, aten.mul, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_red_fused_all_reduce_native_layer_norm_backward_28.run(buf1552, buf1554, 2048, 160, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_1060, all_reduce_146], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1554, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_146], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1554)
buf1559 = empty_strided_cuda((2048, ), (1, ), torch.float32)
buf1578 = empty_strided_cuda((2048, ), (1, ), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_1061, convert_element_type_1070], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_30.run(buf1554, buf1559, buf1578, 2048, stream=stream0)
del buf1554
# Topologically Sorted Source Nodes: [convert_element_type_1059, all_reduce_147], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1560, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_147], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1560)
buf1565 = empty_strided_cuda((2048, ), (1, ), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_1062], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf1560, buf1565, 2048, stream=stream0)
buf1576 = buf1552; del buf1552 # reuse
# Topologically Sorted Source Nodes: [sum_131], Original ATen: [aten.native_layer_norm_backward]
stream0 = get_raw_stream(0)
triton_red_fused_native_layer_norm_backward_27.run(buf1575, buf1576, 327680, 2048, stream=stream0)
del buf1575
buf1579 = buf1560; del buf1560 # reuse
# Topologically Sorted Source Nodes: [sum_131, convert_element_type_1068, all_reduce_149], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_red_fused_all_reduce_native_layer_norm_backward_28.run(buf1576, buf1579, 2048, 160, stream=stream0)
del buf1576
# Topologically Sorted Source Nodes: [convert_element_type_1068, all_reduce_149], Original ATen: [aten.native_layer_norm_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1579, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_149], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1579)
buf1584 = empty_strided_cuda((2048, ), (1, ), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_1071], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_2.run(buf1579, buf1584, 2048, stream=stream0)
del buf1579
buf1586 = buf1569; del buf1569 # reuse
# Topologically Sorted Source Nodes: [mul_480, x, mul_482, sub_119, sub_120, div_28, mul_483, convert_element_type_1067, permute_457, mm_193], Original ATen: [aten.native_layer_norm_backward, aten.native_layer_norm, aten.t, aten.mm]
extern_kernels.mm(reinterpret_tensor(buf1585, (2048, 327680), (1, 2048), 0), buf1568, out=buf1586)
del buf1568
# Topologically Sorted Source Nodes: [all_reduce_150], Original ATen: [_c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1586, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_150], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1586)
buf1591 = empty_strided_cuda((2048, 768), (768, 1), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_1074], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_31.run(buf1586, buf1591, 1572864, stream=stream0)
del buf1586
buf1593 = buf1585; del buf1585 # reuse
# Topologically Sorted Source Nodes: [layer_norm, div_27, mul_469, convert_element_type_1050, add_205, getitem, mask, getitem_1, mul_485, slice_21, clone_112, full_default_1], Original ATen: [aten.native_layer_norm, aten.native_layer_norm_backward, aten.add, aten.select, aten.eq, aten.unsqueeze, aten.mul, aten.slice, aten.clone, aten.slice_backward]
stream0 = get_raw_stream(0)
triton_poi_fused_add_clone_eq_mul_native_layer_norm_native_layer_norm_backward_select_slice_slice_backward_unsqueeze_32.run(buf1459, buf1473, buf1525, primals_1, buf1593, 671088640, stream=stream0)
del buf1473
buf1595 = empty_strided_cuda((2560, 64), (64, 1), torch.float32)
# Topologically Sorted Source Nodes: [add_206, slice_24, clone_113, convert_element_type_1075, mul_486, sum_133], Original ATen: [aten.add, aten.slice, aten.clone, aten._to_copy, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_clone_mul_slice_sum_33.run(buf1592, buf1593, cos, buf1595, 163840, 2048, stream=stream0)
del cos
buf1596 = empty_strided_cuda((2560, ), (1, ), torch.float32)
# Topologically Sorted Source Nodes: [add_206, slice_24, clone_113, convert_element_type_1075, mul_486, sum_133], Original ATen: [aten.add, aten.slice, aten.clone, aten._to_copy, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_clone_mul_slice_sum_34.run(buf1595, buf1596, 2560, 64, stream=stream0)
buf1598 = empty_strided_cuda((), (), torch.bfloat16)
# Topologically Sorted Source Nodes: [add_206, slice_24, clone_113, convert_element_type_1075, mul_486, sum_133, convert_element_type_1076], Original ATen: [aten.add, aten.slice, aten.clone, aten._to_copy, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_clone_mul_slice_sum_35.run(buf1596, buf1598, 1, 2560, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_1076, all_reduce_151], Original ATen: [aten._to_copy, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1598, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_151], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1598)
buf1603 = buf1459; del buf1459 # reuse
# Topologically Sorted Source Nodes: [full_default, full_default_1, add_206, slice_24, clone_113, copy_5, slice_27, clone_114, copy_7, add_207], Original ATen: [aten.zeros_like, aten.slice_backward, aten.add, aten.slice, aten.clone, aten.copy]
stream0 = get_raw_stream(0)
triton_poi_fused_add_clone_copy_slice_slice_backward_zeros_like_36.run(buf1592, buf1593, buf1603, 671088640, stream=stream0)
del buf1592
del buf1593
buf1604 = empty_strided_cuda((327680, 4), (4, 1), torch.uint8)
# Topologically Sorted Source Nodes: [getitem, mask, getitem_1, data, getitem_4, contiguous_1], Original ATen: [aten.select, aten.eq, aten.unsqueeze, aten.mul, aten.slice, aten.clone]
stream0 = get_raw_stream(0)
triton_poi_fused_clone_eq_mul_select_slice_unsqueeze_37.run(primals_1, buf1604, 1310720, stream=stream0)
del primals_1
# Topologically Sorted Source Nodes: [getitem, mask, getitem_1, data, getitem_4, contiguous_1, view_1], Original ATen: [aten.select, aten.eq, aten.unsqueeze, aten.mul, aten.slice, aten.clone, aten.view]
buf1605 = torch.ops.aten.view.dtype(buf1604, torch.int32)
buf1606 = buf1605
assert_size_stride(buf1606, (327680, 1), (1, 1), 'torch.ops.aten.view.dtype')
assert_alignment(buf1606, 16, 'torch.ops.aten.view.dtype')
buf1608 = buf1595; del buf1595 # reuse
# Topologically Sorted Source Nodes: [slice_30, clone_115, convert_element_type_1078, positions, ifreqs, ifreqs_1, neg, freqs, getitem_7, getitem_8, mul_1, sin, mul_487, sum_134], Original ATen: [aten.slice, aten.clone, aten._to_copy, aten.select, aten.arange, aten.div, aten.neg, aten.pow, aten.unsqueeze, aten.mul, aten.sin, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_arange_clone_div_mul_neg_pow_select_sin_slice_sum_unsqueeze_38.run(buf1603, buf1606, buf1608, 163840, 2048, stream=stream0)
del buf1604
del buf1605
del buf1606
buf1609 = buf1596; del buf1596 # reuse
# Topologically Sorted Source Nodes: [slice_30, clone_115, convert_element_type_1078, positions, ifreqs, ifreqs_1, neg, freqs, getitem_7, getitem_8, mul_1, sin, mul_487, sum_134], Original ATen: [aten.slice, aten.clone, aten._to_copy, aten.select, aten.arange, aten.div, aten.neg, aten.pow, aten.unsqueeze, aten.mul, aten.sin, aten.sum]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_clone_mul_slice_sum_34.run(buf1608, buf1609, 2560, 64, stream=stream0)
del buf1608
buf1611 = empty_strided_cuda((), (), torch.bfloat16)
# Topologically Sorted Source Nodes: [slice_30, clone_115, convert_element_type_1078, positions, ifreqs, ifreqs_1, neg, freqs, getitem_7, getitem_8, mul_1, sin, mul_487, sum_134, convert_element_type_1079], Original ATen: [aten.slice, aten.clone, aten._to_copy, aten.select, aten.arange, aten.div, aten.neg, aten.pow, aten.unsqueeze, aten.mul, aten.sin, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_clone_mul_slice_sum_35.run(buf1609, buf1611, 1, 2560, stream=stream0)
del buf1609
# Topologically Sorted Source Nodes: [convert_element_type_1079, all_reduce_152], Original ATen: [aten._to_copy, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1611, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_152], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1611)
buf1616 = empty_strided_cuda((259, 2048), (2048, 1), torch.float32)
# Topologically Sorted Source Nodes: [full_default_5], Original ATen: [aten.embedding_dense_backward]
stream0 = get_raw_stream(0)
triton_poi_fused_embedding_dense_backward_39.run(buf1616, 530432, stream=stream0)
buf1617 = buf1525; del buf1525 # reuse
# Topologically Sorted Source Nodes: [slice_30, clone_115, copy_9, convert_element_type_1081, eq_2, unsqueeze_16, full_default_4, where], Original ATen: [aten.slice, aten.clone, aten.copy, aten.embedding_dense_backward]
stream0 = get_raw_stream(0)
triton_poi_fused_clone_copy_embedding_dense_backward_slice_40.run(select_1, buf1603, buf1617, 671088640, stream=stream0)
del buf1603
aten.index_put_(buf1616, [select_1], buf1617, True)
del buf1617
del select_1
buf1619 = empty_strided_cuda((259, 2048), (2048, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [convert_element_type_1082, all_reduce_153], Original ATen: [aten.embedding_dense_backward, _c10d_functional.all_reduce]
stream0 = get_raw_stream(0)
triton_poi_fused_all_reduce_embedding_dense_backward_41.run(buf1616, buf1619, 530432, stream=stream0)
# Topologically Sorted Source Nodes: [convert_element_type_1082, all_reduce_153], Original ATen: [aten.embedding_dense_backward, _c10d_functional.all_reduce]
torch.ops._c10d_functional.all_reduce_.default(buf1619, 'avg', '0')
# Topologically Sorted Source Nodes: [wait_tensor_153], Original ATen: [_c10d_functional.wait_tensor]
torch.ops._c10d_functional.wait_tensor.default(buf1619)
buf1624 = buf1616; del buf1616 # reuse
# Topologically Sorted Source Nodes: [convert_element_type_1083], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_42.run(buf1619, buf1624, 530432, stream=stream0)
del buf1619
buf1625 = empty_strided_cuda((), (), torch.float32)
# Topologically Sorted Source Nodes: [convert_element_type_1077, convert_element_type_1080, add_209], Original ATen: [aten._to_copy, aten.add]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_add_43.run(buf1598, buf1611, buf1625, 1, stream=stream0)
del buf1598
del buf1611
return (None, buf1624, buf1625, buf1591, buf1584, buf1578, buf1565, buf1559, buf1541, buf1535, buf1522, buf1515, buf1508, None, None, None, None, None, None, None, None, None, None, None, None, None, None, buf1494, buf1471, buf1465, buf1452, buf1447, buf1437, buf1432, buf1415, buf1409, buf1395, buf1388, buf1381, buf1367, buf1344, buf1338, buf1325, buf1320, buf1310, buf1305, buf1288, buf1282, buf1268, buf1261, buf1254, buf1240, buf1217, buf1211, buf1198, buf1193, buf1183, buf1178, buf1161, buf1155, buf1141, buf1134, buf1127, buf1113, buf1090, buf1084, buf1071, buf1066, buf1056, buf1051, buf1034, buf1028, buf1014, buf1007, buf1000, buf986, buf963, buf957, buf944, buf939, buf929, buf924, buf907, buf901, buf887, buf880, buf873, buf859, buf836, buf830, buf817, buf812, buf802, buf797, buf780, buf774, buf760, buf753, buf746, buf732, buf709, buf703, buf690, buf685, buf675, buf670, buf653, buf647, buf633, buf626, buf619, buf605, buf582, buf576, buf563, buf558, buf548, buf543, buf526, buf520, buf506, buf499, buf492, buf478, buf455, buf449, buf436, buf431, buf421, buf416, buf399, buf393, buf379, buf372, buf365, buf351, buf328, buf322, buf309, buf304, buf294, buf289, buf272, buf266, buf252, buf245, buf238, buf224, buf201, buf195, buf182, buf177, buf167, buf162, buf145, buf139, buf125, buf118, buf111, buf97, buf74, buf68, buf55, buf50, buf40, buf35, buf18, buf12, )
# Instantiate the graph-partition runner (this compiled graph has no
# partitions) and re-export its entry points at module level so callers can
# invoke the compiled kernel via the conventional `call(...)` symbol.
runner = Runner(partitions=[])
call, recursively_apply_fns = runner.call, runner.recursively_apply_fns
def get_args():
from torch._dynamo.testing import rand_strided
primals_16 = 131
primals_20 = 131
primals_23 = 131
primals_26 = 131
primals_1 = rand_strided((327680, 785), (785, 1), device='cuda:0', dtype=torch.uint8)
primals_4 = rand_strided((2048, 768), (768, 1), device='cuda:0', dtype=torch.float32)
primals_5 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_9 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_10 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_11 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_12 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_13 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_14 = rand_strided((327680, ), (1, ), device='cuda:0', dtype=torch.int64)
primals_15 = rand_strided((327680, ), (1, ), device='cuda:0', dtype=torch.int64)
primals_17 = rand_strided((1, 1, 2560, 131), (335360, 335360, 131, 1), device='cuda:0', dtype=torch.int32)
primals_18 = rand_strided((1, 1, 2560), (2560, 2560, 1), device='cuda:0', dtype=torch.int32)
primals_19 = rand_strided((1, 1, 2560), (2560, 2560, 1), device='cuda:0', dtype=torch.int32)
primals_21 = rand_strided((1, 1, 2560, 131), (335360, 335360, 131, 1), device='cuda:0', dtype=torch.int32)
primals_22 = rand_strided((1, 1, 2560), (2560, 2560, 1), device='cuda:0', dtype=torch.int32)
primals_24 = rand_strided((1, 1, 2560, 131), (335360, 335360, 131, 1), device='cuda:0', dtype=torch.int32)
primals_25 = rand_strided((1, 1, 2560), (2560, 2560, 1), device='cuda:0', dtype=torch.int32)
primals_27 = rand_strided((1, 1, 2560, 131), (335360, 335360, 131, 1), device='cuda:0', dtype=torch.int32)
primals_28 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_29 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_30 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_31 = rand_strided((8192, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_32 = rand_strided((8192, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_33 = rand_strided((2048, 8192), (8192, 1), device='cuda:0', dtype=torch.float32)
primals_35 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_36 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_37 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_38 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_39 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_40 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_41 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_42 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_43 = rand_strided((8192, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_44 = rand_strided((8192, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_45 = rand_strided((2048, 8192), (8192, 1), device='cuda:0', dtype=torch.float32)
primals_47 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_48 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_49 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_50 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_51 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_52 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_53 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_54 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_55 = rand_strided((8192, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_56 = rand_strided((8192, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_57 = rand_strided((2048, 8192), (8192, 1), device='cuda:0', dtype=torch.float32)
primals_59 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_60 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_61 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_62 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_63 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_64 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_65 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_66 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_67 = rand_strided((8192, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_68 = rand_strided((8192, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_69 = rand_strided((2048, 8192), (8192, 1), device='cuda:0', dtype=torch.float32)
primals_71 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_72 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_73 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_74 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_75 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_76 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_77 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_78 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_79 = rand_strided((8192, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_80 = rand_strided((8192, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_81 = rand_strided((2048, 8192), (8192, 1), device='cuda:0', dtype=torch.float32)
primals_83 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_84 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_85 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_86 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_87 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_88 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_89 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_90 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_91 = rand_strided((8192, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_92 = rand_strided((8192, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_93 = rand_strided((2048, 8192), (8192, 1), device='cuda:0', dtype=torch.float32)
primals_95 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_96 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_97 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_98 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_99 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_100 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_101 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_102 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_103 = rand_strided((8192, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_104 = rand_strided((8192, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_105 = rand_strided((2048, 8192), (8192, 1), device='cuda:0', dtype=torch.float32)
primals_107 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_108 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_109 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_110 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_111 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_112 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_113 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_114 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_115 = rand_strided((8192, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_116 = rand_strided((8192, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_117 = rand_strided((2048, 8192), (8192, 1), device='cuda:0', dtype=torch.float32)
primals_119 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_120 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_121 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_122 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_123 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_124 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_125 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_126 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_127 = rand_strided((8192, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_128 = rand_strided((8192, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_129 = rand_strided((2048, 8192), (8192, 1), device='cuda:0', dtype=torch.float32)
primals_131 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_132 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_133 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_134 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_135 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_136 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_137 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_138 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_139 = rand_strided((8192, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_140 = rand_strided((8192, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_141 = rand_strided((2048, 8192), (8192, 1), device='cuda:0', dtype=torch.float32)
primals_143 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_144 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_145 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_146 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_147 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_148 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_149 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_150 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_151 = rand_strided((8192, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_152 = rand_strided((8192, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_153 = rand_strided((2048, 8192), (8192, 1), device='cuda:0', dtype=torch.float32)
primals_155 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_156 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_157 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_158 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_159 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_160 = rand_strided((2048, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_161 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_162 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_163 = rand_strided((8192, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
primals_164 = rand_strided((8192, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_165 = rand_strided((2048, 8192), (8192, 1), device='cuda:0', dtype=torch.float32)
primals_167 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
select_1 = rand_strided((327680, ), (1, ), device='cuda:0', dtype=torch.int64)
cos = rand_strided((327680, 1024), (1024, 1), device='cuda:0', dtype=torch.float32)
add_11 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
add_14 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
add_19 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
add_22 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
add_27 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
add_30 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
add_35 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
add_38 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
add_43 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
add_46 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
add_51 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
add_54 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
add_59 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
add_62 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
add_67 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
add_70 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
add_75 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
add_78 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
add_83 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
add_86 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
add_91 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
add_94 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
add_99 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
add_102 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
add_107 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_89 = rand_strided((327680, 1), (1, 1), device='cuda:0', dtype=torch.float32)
rsqrt_26 = rand_strided((327680, 1), (1, 1), device='cuda:0', dtype=torch.float32)
tangents_1 = rand_strided((327680, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
return [primals_16, primals_20, primals_23, primals_26, primals_1, primals_4, primals_5, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_17, primals_18, primals_19, primals_21, primals_22, primals_24, primals_25, primals_27, primals_28, primals_29, primals_30, primals_31, primals_32, primals_33, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_47, primals_48, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, primals_59, primals_60, primals_61, primals_62, primals_63, primals_64, primals_65, primals_66, primals_67, primals_68, primals_69, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_78, primals_79, primals_80, primals_81, primals_83, primals_84, primals_85, primals_86, primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_95, primals_96, primals_97, primals_98, primals_99, primals_100, primals_101, primals_102, primals_103, primals_104, primals_105, primals_107, primals_108, primals_109, primals_110, primals_111, primals_112, primals_113, primals_114, primals_115, primals_116, primals_117, primals_119, primals_120, primals_121, primals_122, primals_123, primals_124, primals_125, primals_126, primals_127, primals_128, primals_129, primals_131, primals_132, primals_133, primals_134, primals_135, primals_136, primals_137, primals_138, primals_139, primals_140, primals_141, primals_143, primals_144, primals_145, primals_146, primals_147, primals_148, primals_149, primals_150, primals_151, primals_152, primals_153, primals_155, primals_156, primals_157, primals_158, primals_159, primals_160, primals_161, primals_162, primals_163, primals_164, primals_165, primals_167, select_1, cos, add_11, add_14, add_19, add_22, add_27, add_30, add_35, add_38, add_43, add_46, add_51, add_54, add_59, add_62, 
add_67, add_70, add_75, add_78, add_83, add_86, add_91, add_94, add_99, add_102, add_107, getitem_89, rsqrt_26, tangents_1]
def benchmark_compiled_module(args, times=10, repeat=10):
    """Time the compiled graph's `call` entry point using Inductor's timing helper.

    `args` is consumed as a fresh list on every invocation so `call` may clear it;
    `times`/`repeat` are forwarded directly to `print_performance`.
    """
    from torch._inductor.utils import print_performance

    def run():
        # `call` pops its inputs, so hand it a shallow copy each time.
        return call(list(args))

    return print_performance(run, times=times, repeat=repeat)
if __name__ == "__main__":
    # Entry point for standalone benchmarking of this generated module.
    from torch._inductor.wrapper_benchmark import compiled_module_main

    module_args = get_args()

    def _bench(times, repeat):
        # Adapter matching the (times, repeat) callback signature expected
        # by compiled_module_main.
        return benchmark_compiled_module(module_args, times=times, repeat=repeat)

    compiled_module_main('None', _bench)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment