-
-
Save shunting314/326029f1bcbe72d65246d2f2a8dcf3e4 to your computer and use it in GitHub Desktop.
This file has been truncated, but you can view the full file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| --- /home/shunting/runnable.py 2026-03-02 15:48:29.094844374 -0800 | |
| +++ /home/shunting/old_runnable.py 2026-03-02 15:51:57.647109947 -0800 | |
| @@ -5,14 +5,12 @@ | |
| os.environ['TORCHELASTIC_ENABLE_FILE_TIMER'] = '1' | |
| os.environ['TORCH_NCCL_DESYNC_DEBUG'] = '1' | |
| os.environ['TORCH_NCCL_RETHROW_CUDA_ERRORS'] = '0' | |
| -os.environ['TORCHX_INTERNAL_SESSION_ID'] = '05e3f00a-a074-49de-a031-c987eb224489' | |
| +os.environ['TORCHX_INTERNAL_SESSION_ID'] = '9f52e4a7-0c61-472a-8fb2-b3a3e913eb2e' | |
| os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' | |
| -os.environ['TORCHX_JOB_ID'] = 'quickflow://aps/aps-f1042578221-1042601092' | |
| +os.environ['TORCHX_JOB_ID'] = 'quickflow://aps/aps-f1035405763-1035436703' | |
| os.environ['TORCH_NCCL_ASYNC_ERROR_HANDLING'] = '3' | |
| -os.environ['TORCHELASTIC_HEALTH_CHECK_PORT'] = '29906' | |
| +os.environ['TORCHELASTIC_HEALTH_CHECK_PORT'] = '30066' | |
| os.environ['TORCH_SHOW_CPP_STACKTRACES'] = '1' | |
| -os.environ['TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE'] = '0' | |
| -os.environ['TORCH_COMPILE_OVERRIDE_INDUCTOR_CONFIGS'] = '19:triton.mix_order_reduction=True' | |
| os.environ['PYTORCH_DDP_USE_SIDE_STREAM'] = '0' | |
| os.environ['TRITON_ALLOW_NON_CONSTEXPR_GLOBALS'] = '1' | |
| os.environ['TRITON_LIBHIP_PATH'] = '/usr/local/fbcode/platform010/lib/rocm-7.0/lib/libamdhip64.so' | |
| @@ -23,10 +21,10 @@ | |
| os.environ['TORCHELASTIC_SIGNALS_TO_HANDLE'] = 'SIGTERM,SIGINT,SIGHUP,SIGQUIT' | |
| os.environ['TORCHELASTIC_RESTART_COUNT'] = '0' | |
| os.environ['TORCHELASTIC_MAX_RESTARTS'] = '0' | |
| -os.environ['TORCHELASTIC_RUN_ID'] = 'aps-f1042578221-1042601092' | |
| +os.environ['TORCHELASTIC_RUN_ID'] = 'aps-f1035405763-1035436703' | |
| os.environ['TORCHELASTIC_USE_AGENT_STORE'] = 'False' | |
| -os.environ['TORCHELASTIC_TIMER_FILE'] = '/tmp/watchdog_timer_ba26a513-0990-4ea1-b20b-3c0a64a1e2b7' | |
| -os.environ['TORCHELASTIC_ERROR_FILE'] = '/tmp/chronos_secgrp_ads_training_p9e/torchelastic/03063d64-143f-11f1-8445-19efcccf41ee/aps-f1042578221-1042601092_xy49z3l9/attempt_0/0/error.json' | |
| +os.environ['TORCHELASTIC_TIMER_FILE'] = '/tmp/watchdog_timer_6b8388c2-99bd-4ef7-81bc-666b63a9878e' | |
| +os.environ['TORCHELASTIC_ERROR_FILE'] = '/tmp/chronos_secgrp_ads_training_p9e/torchelastic/18d058f0-08b9-11f1-bef1-b3bbbb82408b/aps-f1035405763-1035436703_3jp37eq8/attempt_0/0/error.json' | |
| os.environ['TORCH_NCCL_CUDA_EVENT_CACHE'] = 'True' | |
| os.environ['TORCH_NCCL_DUMP_ON_TIMEOUT'] = 'True' | |
| os.environ['TORCH_NCCL_TRACE_BUFFER_SIZE'] = '2000' | |
| @@ -38,7 +36,7 @@ | |
| os.environ['TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN'] = 'True' | |
| os.environ['TORCH_PROFILER_ENABLE_COLLECTIVE_PROFILING'] = '0' | |
| os.environ['TORCH_NCCL_DISABLE_DYNAMIC_LOADING'] = 'False' | |
| -os.environ['TORCH_NCCL_DEBUG_INFO_PIPE_FILE'] = '/tmp/nccl_trace_1772239691_rank_' | |
| +os.environ['TORCH_NCCL_DEBUG_INFO_PIPE_FILE'] = '/tmp/nccl_trace_1770972689_rank_' | |
| os.environ['TORCH_NCCL_AVOID_RECORD_STREAMS'] = '1' | |
| os.environ['TRITON_CACHE_DIR'] = '/var/tmp/torchinductor_nobody/triton/0' | |
| @@ -79,7 +77,6 @@ | |
| torch._inductor.config.deterministic = False | |
| torch._inductor.config.min_num_split = 256 | |
| torch._inductor.config.compile_threads = 32 | |
| -torch._inductor.config.static_launch_user_defined_triton_kernels = True | |
| torch._inductor.config.shape_padding = False | |
| torch._inductor.config.force_shape_pad = False | |
| torch._inductor.config.decompose_mem_bound_mm = False | |
| @@ -87,11 +84,10 @@ | |
| torch._inductor.config.triton.cudagraphs = False | |
| torch._inductor.config.triton.unique_kernel_names = True | |
| torch._inductor.config.triton.store_cubin = False | |
| -torch._inductor.config.triton.mix_order_reduction = False | |
| +torch._inductor.config.triton.mix_order_reduction = True | |
| torch._inductor.config.triton.mix_order_reduction_non_strict_mode = False | |
| torch._inductor.config.test_configs.runtime_triton_dtype_assert = False | |
| torch._functorch.config.functionalize_rng_ops = False | |
| -torch._functorch.config.autograd_cache_allow_custom_autograd_functions = True | |
| torch._functorch.config.bundled_autograd_cache = False | |
| torch._functorch.config.is_non_builtin_to_include = False | |
| torch._functorch.config.activation_memory_budget = 0.1 | |
| @@ -105,10 +101,10 @@ | |
| isolate_fails_code_str = None | |
| -# torch.ops.load_library("//caffe2/torch/fb/sparsenn:sparsenn_operators_gpu") | |
| -# torch.ops.load_library("//caffe2/torch/fb/sparsenn:sparsenn_operators") | |
| -# torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") | |
| -# torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") | |
| +torch.ops.load_library("//caffe2/torch/fb/sparsenn:sparsenn_operators_gpu") | |
| +torch.ops.load_library("//caffe2/torch/fb/sparsenn:sparsenn_operators") | |
| +torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") | |
| +torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") | |
| """ | |
| To run this script in fbcode: | |
| @@ -147,14 +143,13 @@ | |
| __compile_source__.splitlines(True), | |
| __after_aot_filename, | |
| ) | |
| -# torch version: 2.12.0a0+fb | |
| +# torch version: 2.11.0a0+fb | |
| # torch cuda version: 12.8.0 | |
| # CUDA Info: | |
| # nvcc not found | |
| # GPU Hardware Info: | |
| # NVIDIA H100 : 8 | |
| -torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.reset_table() | |
| @triton.autotune( | |
| configs=[ | |
| triton.Config( | |
| @@ -194,7 +189,7 @@ | |
| key=[] | |
| ) | |
| @triton.jit | |
| -def fused_group_contiguous_nan_clamp_0(group_a_ptrs0, group_a_ptrs1, group_a_ptrs2, group_a_ptrs3, group_a_ptrs4, group_a_ptrs5, group_a_ptrs6, group_a_ptrs7, group_a_ptrs8, group_a_ptrs9, group_a_ptrs10, group_a_ptrs11, group_a_ptrs12, group_a_ptrs13, group_a_ptrs14, group_out_ptrs0, group_out_ptrs1, group_out_ptrs2, group_out_ptrs3, group_out_ptrs4, group_out_ptrs5, group_out_ptrs6, group_out_ptrs7, group_out_ptrs8, group_out_ptrs9, group_out_ptrs10, group_out_ptrs11, group_out_ptrs12, group_out_ptrs13, group_out_ptrs14, n_list0, n_list1, n_list2, n_list3, n_list4, n_list5, n_list6, n_list7, n_list8, n_list9, n_list10, n_list11, n_list12, n_list13, n_list14, stride_in_a0, stride_in_a1, stride_in_a2, stride_in_a3, stride_in_a4, stride_in_a5, stride_in_a6, stride_in_a7, stride_in_a8, stride_in_a9, stride_in_a10, stride_in_a11, stride_in_a12, stride_in_a13, stride_in_a14, M, replace_nan: tl.constexpr, clip_limit, first_ptr_for_dtype, needs_clip_limit: tl.constexpr, XBLOCK: tl.constexpr): | |
| +def fused_group_contiguous_nan_clamp(group_a_ptrs0, group_a_ptrs1, group_a_ptrs2, group_a_ptrs3, group_a_ptrs4, group_a_ptrs5, group_a_ptrs6, group_a_ptrs7, group_a_ptrs8, group_a_ptrs9, group_a_ptrs10, group_a_ptrs11, group_a_ptrs12, group_a_ptrs13, group_a_ptrs14, group_out_ptrs0, group_out_ptrs1, group_out_ptrs2, group_out_ptrs3, group_out_ptrs4, group_out_ptrs5, group_out_ptrs6, group_out_ptrs7, group_out_ptrs8, group_out_ptrs9, group_out_ptrs10, group_out_ptrs11, group_out_ptrs12, group_out_ptrs13, group_out_ptrs14, n_list0, n_list1, n_list2, n_list3, n_list4, n_list5, n_list6, n_list7, n_list8, n_list9, n_list10, n_list11, n_list12, n_list13, n_list14, stride_in_a0, stride_in_a1, stride_in_a2, stride_in_a3, stride_in_a4, stride_in_a5, stride_in_a6, stride_in_a7, stride_in_a8, stride_in_a9, stride_in_a10, stride_in_a11, stride_in_a12, stride_in_a13, stride_in_a14, M, replace_nan: tl.constexpr, clip_limit, first_ptr_for_dtype, needs_clip_limit: tl.constexpr, XBLOCK: tl.constexpr): | |
| xoffset = tl.program_id(1) * XBLOCK | |
| xindex = xoffset + tl.arange(0, XBLOCK) | |
| xindex = xindex[:] | |
| @@ -220,8 +215,7 @@ | |
| if needs_clip_limit: | |
| tmp0 = tl.clamp(tmp0, min=clip_limit_min, max=clip_limit_max) | |
| tl.store(out_ptr + x2, tmp0, x1 < M) | |
| -torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(fused_group_contiguous_nan_clamp_0) | |
| -MIN_ELEMENT_ALIGNMENT_SIZE = 8 | |
| +torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(fused_group_contiguous_nan_clamp) | |
| @triton.autotune( | |
| configs=[ | |
| triton.Config( | |
| @@ -293,7 +287,7 @@ | |
| key=[] | |
| ) | |
| @triton.jit | |
| -def group_matmul_add_fixed_mn_kernel_1(group_a_ptrs0, group_a_ptrs1, group_a_ptrs2, group_a_ptrs3, group_a_ptrs4, group_a_ptrs5, group_a_ptrs6, group_a_ptrs7, group_a_ptrs8, group_a_ptrs9, group_a_ptrs10, group_a_ptrs11, group_a_ptrs12, group_a_ptrs13, group_a_ptrs14, group_b_ptrs0, group_b_ptrs1, group_b_ptrs2, group_b_ptrs3, group_b_ptrs4, group_b_ptrs5, group_b_ptrs6, group_b_ptrs7, group_b_ptrs8, group_b_ptrs9, group_b_ptrs10, group_b_ptrs11, group_b_ptrs12, group_b_ptrs13, group_b_ptrs14, group_c_ptr, group_out_ptr, k_sizes0, k_sizes1, k_sizes2, k_sizes3, k_sizes4, k_sizes5, k_sizes6, k_sizes7, k_sizes8, k_sizes9, k_sizes10, k_sizes11, k_sizes12, k_sizes13, k_sizes14, group_size, M, N, stride_out_a, stride_out_b, replace_nan: tl.constexpr, clip_limit, NUM_SM: tl.constexpr, INPUT_PRECISION: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr): | |
| +def group_matmul_add_fixed_mn_kernel(group_a_ptrs0, group_a_ptrs1, group_a_ptrs2, group_a_ptrs3, group_a_ptrs4, group_a_ptrs5, group_a_ptrs6, group_a_ptrs7, group_a_ptrs8, group_a_ptrs9, group_a_ptrs10, group_a_ptrs11, group_a_ptrs12, group_a_ptrs13, group_a_ptrs14, group_b_ptrs0, group_b_ptrs1, group_b_ptrs2, group_b_ptrs3, group_b_ptrs4, group_b_ptrs5, group_b_ptrs6, group_b_ptrs7, group_b_ptrs8, group_b_ptrs9, group_b_ptrs10, group_b_ptrs11, group_b_ptrs12, group_b_ptrs13, group_b_ptrs14, group_c_ptr, group_out_ptr, k_sizes0, k_sizes1, k_sizes2, k_sizes3, k_sizes4, k_sizes5, k_sizes6, k_sizes7, k_sizes8, k_sizes9, k_sizes10, k_sizes11, k_sizes12, k_sizes13, k_sizes14, group_size, M, N, stride_out_a, stride_out_b, replace_nan: tl.constexpr, clip_limit, NUM_SM: tl.constexpr, INPUT_PRECISION: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr): | |
| dtype = group_c_ptr.dtype.element_ty | |
| tile_idx = tl.program_id(0) | |
| last_problem_end = 0 | |
| @@ -348,7 +342,7 @@ | |
| tl.store(out_ptrs, out, mask=out_mask) | |
| tile_idx += NUM_SM | |
| last_problem_end = last_problem_end + num_tiles | |
| -torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_matmul_add_fixed_mn_kernel_1) | |
| +torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_matmul_add_fixed_mn_kernel) | |
| @triton.autotune( | |
| configs=[ | |
| triton.Config( | |
| @@ -386,35 +380,7 @@ | |
| ) | |
| ], | |
| key=[] | |
| -) | |
| -@triton.jit | |
| -def fused_group_contiguous_nan_clamp_2(group_a_ptrs0, group_a_ptrs1, group_a_ptrs2, group_a_ptrs3, group_a_ptrs4, group_a_ptrs5, group_a_ptrs6, group_a_ptrs7, group_a_ptrs8, group_a_ptrs9, group_a_ptrs10, group_a_ptrs11, group_a_ptrs12, group_a_ptrs13, group_a_ptrs14, group_a_ptrs15, group_a_ptrs16, group_a_ptrs17, group_a_ptrs18, group_a_ptrs19, group_a_ptrs20, group_a_ptrs21, group_a_ptrs22, group_a_ptrs23, group_a_ptrs24, group_a_ptrs25, group_a_ptrs26, group_a_ptrs27, group_a_ptrs28, group_a_ptrs29, group_a_ptrs30, group_a_ptrs31, group_a_ptrs32, group_a_ptrs33, group_a_ptrs34, group_a_ptrs35, group_a_ptrs36, group_a_ptrs37, group_a_ptrs38, group_a_ptrs39, group_a_ptrs40, group_a_ptrs41, group_a_ptrs42, group_a_ptrs43, group_a_ptrs44, group_a_ptrs45, group_a_ptrs46, group_a_ptrs47, group_a_ptrs48, group_a_ptrs49, group_a_ptrs50, group_a_ptrs51, group_a_ptrs52, group_a_ptrs53, group_a_ptrs54, group_a_ptrs55, group_a_ptrs56, group_a_ptrs57, group_a_ptrs58, group_a_ptrs59, group_a_ptrs60, group_a_ptrs61, group_a_ptrs62, group_a_ptrs63, group_a_ptrs64, group_a_ptrs65, group_a_ptrs66, group_a_ptrs67, group_a_ptrs68, group_a_ptrs69, group_a_ptrs70, group_a_ptrs71, group_a_ptrs72, group_a_ptrs73, group_a_ptrs74, group_a_ptrs75, group_a_ptrs76, group_a_ptrs77, group_a_ptrs78, group_a_ptrs79, group_a_ptrs80, group_a_ptrs81, group_a_ptrs82, group_a_ptrs83, group_out_ptrs0, group_out_ptrs1, group_out_ptrs2, group_out_ptrs3, group_out_ptrs4, group_out_ptrs5, group_out_ptrs6, group_out_ptrs7, group_out_ptrs8, group_out_ptrs9, group_out_ptrs10, group_out_ptrs11, group_out_ptrs12, group_out_ptrs13, group_out_ptrs14, group_out_ptrs15, group_out_ptrs16, group_out_ptrs17, group_out_ptrs18, group_out_ptrs19, group_out_ptrs20, group_out_ptrs21, group_out_ptrs22, group_out_ptrs23, group_out_ptrs24, group_out_ptrs25, group_out_ptrs26, group_out_ptrs27, group_out_ptrs28, group_out_ptrs29, group_out_ptrs30, group_out_ptrs31, group_out_ptrs32, group_out_ptrs33, group_out_ptrs34, 
group_out_ptrs35, group_out_ptrs36, group_out_ptrs37, group_out_ptrs38, group_out_ptrs39, group_out_ptrs40, group_out_ptrs41, group_out_ptrs42, group_out_ptrs43, group_out_ptrs44, group_out_ptrs45, group_out_ptrs46, group_out_ptrs47, group_out_ptrs48, group_out_ptrs49, group_out_ptrs50, group_out_ptrs51, group_out_ptrs52, group_out_ptrs53, group_out_ptrs54, group_out_ptrs55, group_out_ptrs56, group_out_ptrs57, group_out_ptrs58, group_out_ptrs59, group_out_ptrs60, group_out_ptrs61, group_out_ptrs62, group_out_ptrs63, group_out_ptrs64, group_out_ptrs65, group_out_ptrs66, group_out_ptrs67, group_out_ptrs68, group_out_ptrs69, group_out_ptrs70, group_out_ptrs71, group_out_ptrs72, group_out_ptrs73, group_out_ptrs74, group_out_ptrs75, group_out_ptrs76, group_out_ptrs77, group_out_ptrs78, group_out_ptrs79, group_out_ptrs80, group_out_ptrs81, group_out_ptrs82, group_out_ptrs83, n_list0, n_list1, n_list2, n_list3, n_list4, n_list5, n_list6, n_list7, n_list8, n_list9, n_list10, n_list11, n_list12, n_list13, n_list14, n_list15, n_list16, n_list17, n_list18, n_list19, n_list20, n_list21, n_list22, n_list23, n_list24, n_list25, n_list26, n_list27, n_list28, n_list29, n_list30, n_list31, n_list32, n_list33, n_list34, n_list35, n_list36, n_list37, n_list38, n_list39, n_list40, n_list41, n_list42, n_list43, n_list44, n_list45, n_list46, n_list47, n_list48, n_list49, n_list50, n_list51, n_list52, n_list53, n_list54, n_list55, n_list56, n_list57, n_list58, n_list59, n_list60, n_list61, n_list62, n_list63, n_list64, n_list65, n_list66, n_list67, n_list68, n_list69, n_list70, n_list71, n_list72, n_list73, n_list74, n_list75, n_list76, n_list77, n_list78, n_list79, n_list80, n_list81, n_list82, n_list83, stride_in_a0, stride_in_a1, stride_in_a2, stride_in_a3, stride_in_a4, stride_in_a5, stride_in_a6, stride_in_a7, stride_in_a8, stride_in_a9, stride_in_a10, stride_in_a11, stride_in_a12, stride_in_a13, stride_in_a14, stride_in_a15, stride_in_a16, stride_in_a17, stride_in_a18, 
stride_in_a19, stride_in_a20, stride_in_a21, stride_in_a22, stride_in_a23, stride_in_a24, stride_in_a25, stride_in_a26, stride_in_a27, stride_in_a28, stride_in_a29, stride_in_a30, stride_in_a31, stride_in_a32, stride_in_a33, stride_in_a34, stride_in_a35, stride_in_a36, stride_in_a37, stride_in_a38, stride_in_a39, stride_in_a40, stride_in_a41, stride_in_a42, stride_in_a43, stride_in_a44, stride_in_a45, stride_in_a46, stride_in_a47, stride_in_a48, stride_in_a49, stride_in_a50, stride_in_a51, stride_in_a52, stride_in_a53, stride_in_a54, stride_in_a55, stride_in_a56, stride_in_a57, stride_in_a58, stride_in_a59, stride_in_a60, stride_in_a61, stride_in_a62, stride_in_a63, stride_in_a64, stride_in_a65, stride_in_a66, stride_in_a67, stride_in_a68, stride_in_a69, stride_in_a70, stride_in_a71, stride_in_a72, stride_in_a73, stride_in_a74, stride_in_a75, stride_in_a76, stride_in_a77, stride_in_a78, stride_in_a79, stride_in_a80, stride_in_a81, stride_in_a82, stride_in_a83, M, replace_nan: tl.constexpr, clip_limit, first_ptr_for_dtype, needs_clip_limit: tl.constexpr, XBLOCK: tl.constexpr): | |
| - xoffset = tl.program_id(1) * XBLOCK | |
| - xindex = xoffset + tl.arange(0, XBLOCK) | |
| - xindex = xindex[:] | |
| - g = tl.program_id(0) | |
| - dtype = first_ptr_for_dtype.dtype.element_ty | |
| - replace_nan_tl = tl.full((1,), 0.0, dtype=dtype) | |
| - if clip_limit is not None: | |
| - clip_limit_min = tl.full((1,), -clip_limit, dtype=dtype) | |
| - clip_limit_max = tl.full((1,), clip_limit, dtype=dtype) | |
| - else: | |
| - clip_limit_min = tl.full((1,), 0.0, dtype=dtype) | |
| - clip_limit_max = tl.full((1,), 0.0, dtype=dtype) | |
| - a_ptr = group_a_ptrs0 if g == 0 else group_a_ptrs1 if g == 1 else group_a_ptrs2 if g == 2 else group_a_ptrs3 if g == 3 else group_a_ptrs4 if g == 4 else group_a_ptrs5 if g == 5 else group_a_ptrs6 if g == 6 else group_a_ptrs7 if g == 7 else group_a_ptrs8 if g == 8 else group_a_ptrs9 if g == 9 else group_a_ptrs10 if g == 10 else group_a_ptrs11 if g == 11 else group_a_ptrs12 if g == 12 else group_a_ptrs13 if g == 13 else group_a_ptrs14 if g == 14 else group_a_ptrs15 if g == 15 else group_a_ptrs16 if g == 16 else group_a_ptrs17 if g == 17 else group_a_ptrs18 if g == 18 else group_a_ptrs19 if g == 19 else group_a_ptrs20 if g == 20 else group_a_ptrs21 if g == 21 else group_a_ptrs22 if g == 22 else group_a_ptrs23 if g == 23 else group_a_ptrs24 if g == 24 else group_a_ptrs25 if g == 25 else group_a_ptrs26 if g == 26 else group_a_ptrs27 if g == 27 else group_a_ptrs28 if g == 28 else group_a_ptrs29 if g == 29 else group_a_ptrs30 if g == 30 else group_a_ptrs31 if g == 31 else group_a_ptrs32 if g == 32 else group_a_ptrs33 if g == 33 else group_a_ptrs34 if g == 34 else group_a_ptrs35 if g == 35 else group_a_ptrs36 if g == 36 else group_a_ptrs37 if g == 37 else group_a_ptrs38 if g == 38 else group_a_ptrs39 if g == 39 else group_a_ptrs40 if g == 40 else group_a_ptrs41 if g == 41 else group_a_ptrs42 if g == 42 else group_a_ptrs43 if g == 43 else group_a_ptrs44 if g == 44 else group_a_ptrs45 if g == 45 else group_a_ptrs46 if g == 46 else group_a_ptrs47 if g == 47 else group_a_ptrs48 if g == 48 else group_a_ptrs49 if g == 49 else group_a_ptrs50 if g == 50 else group_a_ptrs51 if g == 51 else group_a_ptrs52 if g == 52 else group_a_ptrs53 if g == 53 else group_a_ptrs54 if g == 54 else group_a_ptrs55 if g == 55 else group_a_ptrs56 if g == 56 else group_a_ptrs57 if g == 57 else group_a_ptrs58 if g == 58 else group_a_ptrs59 if g == 59 else group_a_ptrs60 if g == 60 else group_a_ptrs61 if g == 61 else group_a_ptrs62 if g == 62 else group_a_ptrs63 if g == 63 else group_a_ptrs64 if g == 
64 else group_a_ptrs65 if g == 65 else group_a_ptrs66 if g == 66 else group_a_ptrs67 if g == 67 else group_a_ptrs68 if g == 68 else group_a_ptrs69 if g == 69 else group_a_ptrs70 if g == 70 else group_a_ptrs71 if g == 71 else group_a_ptrs72 if g == 72 else group_a_ptrs73 if g == 73 else group_a_ptrs74 if g == 74 else group_a_ptrs75 if g == 75 else group_a_ptrs76 if g == 76 else group_a_ptrs77 if g == 77 else group_a_ptrs78 if g == 78 else group_a_ptrs79 if g == 79 else group_a_ptrs80 if g == 80 else group_a_ptrs81 if g == 81 else group_a_ptrs82 if g == 82 else group_a_ptrs83 | |
| - N = n_list0 if g == 0 else n_list1 if g == 1 else n_list2 if g == 2 else n_list3 if g == 3 else n_list4 if g == 4 else n_list5 if g == 5 else n_list6 if g == 6 else n_list7 if g == 7 else n_list8 if g == 8 else n_list9 if g == 9 else n_list10 if g == 10 else n_list11 if g == 11 else n_list12 if g == 12 else n_list13 if g == 13 else n_list14 if g == 14 else n_list15 if g == 15 else n_list16 if g == 16 else n_list17 if g == 17 else n_list18 if g == 18 else n_list19 if g == 19 else n_list20 if g == 20 else n_list21 if g == 21 else n_list22 if g == 22 else n_list23 if g == 23 else n_list24 if g == 24 else n_list25 if g == 25 else n_list26 if g == 26 else n_list27 if g == 27 else n_list28 if g == 28 else n_list29 if g == 29 else n_list30 if g == 30 else n_list31 if g == 31 else n_list32 if g == 32 else n_list33 if g == 33 else n_list34 if g == 34 else n_list35 if g == 35 else n_list36 if g == 36 else n_list37 if g == 37 else n_list38 if g == 38 else n_list39 if g == 39 else n_list40 if g == 40 else n_list41 if g == 41 else n_list42 if g == 42 else n_list43 if g == 43 else n_list44 if g == 44 else n_list45 if g == 45 else n_list46 if g == 46 else n_list47 if g == 47 else n_list48 if g == 48 else n_list49 if g == 49 else n_list50 if g == 50 else n_list51 if g == 51 else n_list52 if g == 52 else n_list53 if g == 53 else n_list54 if g == 54 else n_list55 if g == 55 else n_list56 if g == 56 else n_list57 if g == 57 else n_list58 if g == 58 else n_list59 if g == 59 else n_list60 if g == 60 else n_list61 if g == 61 else n_list62 if g == 62 else n_list63 if g == 63 else n_list64 if g == 64 else n_list65 if g == 65 else n_list66 if g == 66 else n_list67 if g == 67 else n_list68 if g == 68 else n_list69 if g == 69 else n_list70 if g == 70 else n_list71 if g == 71 else n_list72 if g == 72 else n_list73 if g == 73 else n_list74 if g == 74 else n_list75 if g == 75 else n_list76 if g == 76 else n_list77 if g == 77 else n_list78 if g == 78 else n_list79 if g == 79 else n_list80 if 
g == 80 else n_list81 if g == 81 else n_list82 if g == 82 else n_list83 | |
| - a_stride_am = stride_in_a0 if g == 0 else stride_in_a1 if g == 1 else stride_in_a2 if g == 2 else stride_in_a3 if g == 3 else stride_in_a4 if g == 4 else stride_in_a5 if g == 5 else stride_in_a6 if g == 6 else stride_in_a7 if g == 7 else stride_in_a8 if g == 8 else stride_in_a9 if g == 9 else stride_in_a10 if g == 10 else stride_in_a11 if g == 11 else stride_in_a12 if g == 12 else stride_in_a13 if g == 13 else stride_in_a14 if g == 14 else stride_in_a15 if g == 15 else stride_in_a16 if g == 16 else stride_in_a17 if g == 17 else stride_in_a18 if g == 18 else stride_in_a19 if g == 19 else stride_in_a20 if g == 20 else stride_in_a21 if g == 21 else stride_in_a22 if g == 22 else stride_in_a23 if g == 23 else stride_in_a24 if g == 24 else stride_in_a25 if g == 25 else stride_in_a26 if g == 26 else stride_in_a27 if g == 27 else stride_in_a28 if g == 28 else stride_in_a29 if g == 29 else stride_in_a30 if g == 30 else stride_in_a31 if g == 31 else stride_in_a32 if g == 32 else stride_in_a33 if g == 33 else stride_in_a34 if g == 34 else stride_in_a35 if g == 35 else stride_in_a36 if g == 36 else stride_in_a37 if g == 37 else stride_in_a38 if g == 38 else stride_in_a39 if g == 39 else stride_in_a40 if g == 40 else stride_in_a41 if g == 41 else stride_in_a42 if g == 42 else stride_in_a43 if g == 43 else stride_in_a44 if g == 44 else stride_in_a45 if g == 45 else stride_in_a46 if g == 46 else stride_in_a47 if g == 47 else stride_in_a48 if g == 48 else stride_in_a49 if g == 49 else stride_in_a50 if g == 50 else stride_in_a51 if g == 51 else stride_in_a52 if g == 52 else stride_in_a53 if g == 53 else stride_in_a54 if g == 54 else stride_in_a55 if g == 55 else stride_in_a56 if g == 56 else stride_in_a57 if g == 57 else stride_in_a58 if g == 58 else stride_in_a59 if g == 59 else stride_in_a60 if g == 60 else stride_in_a61 if g == 61 else stride_in_a62 if g == 62 else stride_in_a63 if g == 63 else stride_in_a64 if g == 64 else stride_in_a65 if g == 65 else stride_in_a66 if g == 
66 else stride_in_a67 if g == 67 else stride_in_a68 if g == 68 else stride_in_a69 if g == 69 else stride_in_a70 if g == 70 else stride_in_a71 if g == 71 else stride_in_a72 if g == 72 else stride_in_a73 if g == 73 else stride_in_a74 if g == 74 else stride_in_a75 if g == 75 else stride_in_a76 if g == 76 else stride_in_a77 if g == 77 else stride_in_a78 if g == 78 else stride_in_a79 if g == 79 else stride_in_a80 if g == 80 else stride_in_a81 if g == 81 else stride_in_a82 if g == 82 else stride_in_a83 | |
| - out_ptr = group_out_ptrs0 if g == 0 else group_out_ptrs1 if g == 1 else group_out_ptrs2 if g == 2 else group_out_ptrs3 if g == 3 else group_out_ptrs4 if g == 4 else group_out_ptrs5 if g == 5 else group_out_ptrs6 if g == 6 else group_out_ptrs7 if g == 7 else group_out_ptrs8 if g == 8 else group_out_ptrs9 if g == 9 else group_out_ptrs10 if g == 10 else group_out_ptrs11 if g == 11 else group_out_ptrs12 if g == 12 else group_out_ptrs13 if g == 13 else group_out_ptrs14 if g == 14 else group_out_ptrs15 if g == 15 else group_out_ptrs16 if g == 16 else group_out_ptrs17 if g == 17 else group_out_ptrs18 if g == 18 else group_out_ptrs19 if g == 19 else group_out_ptrs20 if g == 20 else group_out_ptrs21 if g == 21 else group_out_ptrs22 if g == 22 else group_out_ptrs23 if g == 23 else group_out_ptrs24 if g == 24 else group_out_ptrs25 if g == 25 else group_out_ptrs26 if g == 26 else group_out_ptrs27 if g == 27 else group_out_ptrs28 if g == 28 else group_out_ptrs29 if g == 29 else group_out_ptrs30 if g == 30 else group_out_ptrs31 if g == 31 else group_out_ptrs32 if g == 32 else group_out_ptrs33 if g == 33 else group_out_ptrs34 if g == 34 else group_out_ptrs35 if g == 35 else group_out_ptrs36 if g == 36 else group_out_ptrs37 if g == 37 else group_out_ptrs38 if g == 38 else group_out_ptrs39 if g == 39 else group_out_ptrs40 if g == 40 else group_out_ptrs41 if g == 41 else group_out_ptrs42 if g == 42 else group_out_ptrs43 if g == 43 else group_out_ptrs44 if g == 44 else group_out_ptrs45 if g == 45 else group_out_ptrs46 if g == 46 else group_out_ptrs47 if g == 47 else group_out_ptrs48 if g == 48 else group_out_ptrs49 if g == 49 else group_out_ptrs50 if g == 50 else group_out_ptrs51 if g == 51 else group_out_ptrs52 if g == 52 else group_out_ptrs53 if g == 53 else group_out_ptrs54 if g == 54 else group_out_ptrs55 if g == 55 else group_out_ptrs56 if g == 56 else group_out_ptrs57 if g == 57 else group_out_ptrs58 if g == 58 else group_out_ptrs59 if g == 59 else group_out_ptrs60 if g == 
60 else group_out_ptrs61 if g == 61 else group_out_ptrs62 if g == 62 else group_out_ptrs63 if g == 63 else group_out_ptrs64 if g == 64 else group_out_ptrs65 if g == 65 else group_out_ptrs66 if g == 66 else group_out_ptrs67 if g == 67 else group_out_ptrs68 if g == 68 else group_out_ptrs69 if g == 69 else group_out_ptrs70 if g == 70 else group_out_ptrs71 if g == 71 else group_out_ptrs72 if g == 72 else group_out_ptrs73 if g == 73 else group_out_ptrs74 if g == 74 else group_out_ptrs75 if g == 75 else group_out_ptrs76 if g == 76 else group_out_ptrs77 if g == 77 else group_out_ptrs78 if g == 78 else group_out_ptrs79 if g == 79 else group_out_ptrs80 if g == 80 else group_out_ptrs81 if g == 81 else group_out_ptrs82 if g == 82 else group_out_ptrs83 | |
| - x0 = xindex % N | |
| - x1 = xindex // N | |
| - x2 = xindex | |
| - tmp0 = tl.load(a_ptr + (x0 + a_stride_am * x1), x1 < M) | |
| - if replace_nan: | |
| - tmp0 = tl.where(tmp0 != tmp0, replace_nan_tl, tmp0) | |
| - if needs_clip_limit: | |
| - tmp0 = tl.clamp(tmp0, min=clip_limit_min, max=clip_limit_max) | |
| - tl.store(out_ptr + x2, tmp0, x1 < M) | |
| -torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(fused_group_contiguous_nan_clamp_2) | |
| +) | |
| +torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(fused_group_contiguous_nan_clamp) | |
| @triton.autotune( | |
| configs=[ | |
| triton.Config( | |
| @@ -484,64 +450,7 @@ | |
| ) | |
| ], | |
| key=[] | |
| -) | |
| -@triton.jit | |
| -def group_matmul_add_fixed_mn_kernel_3(group_a_ptrs0, group_a_ptrs1, group_a_ptrs2, group_a_ptrs3, group_a_ptrs4, group_a_ptrs5, group_a_ptrs6, group_a_ptrs7, group_a_ptrs8, group_a_ptrs9, group_a_ptrs10, group_a_ptrs11, group_a_ptrs12, group_a_ptrs13, group_a_ptrs14, group_a_ptrs15, group_a_ptrs16, group_a_ptrs17, group_a_ptrs18, group_a_ptrs19, group_a_ptrs20, group_a_ptrs21, group_a_ptrs22, group_a_ptrs23, group_a_ptrs24, group_a_ptrs25, group_a_ptrs26, group_a_ptrs27, group_a_ptrs28, group_a_ptrs29, group_a_ptrs30, group_a_ptrs31, group_a_ptrs32, group_a_ptrs33, group_a_ptrs34, group_a_ptrs35, group_a_ptrs36, group_a_ptrs37, group_a_ptrs38, group_a_ptrs39, group_a_ptrs40, group_a_ptrs41, group_a_ptrs42, group_a_ptrs43, group_a_ptrs44, group_a_ptrs45, group_a_ptrs46, group_a_ptrs47, group_a_ptrs48, group_a_ptrs49, group_a_ptrs50, group_a_ptrs51, group_a_ptrs52, group_a_ptrs53, group_a_ptrs54, group_a_ptrs55, group_a_ptrs56, group_a_ptrs57, group_a_ptrs58, group_a_ptrs59, group_a_ptrs60, group_a_ptrs61, group_a_ptrs62, group_a_ptrs63, group_a_ptrs64, group_a_ptrs65, group_a_ptrs66, group_a_ptrs67, group_a_ptrs68, group_a_ptrs69, group_a_ptrs70, group_a_ptrs71, group_a_ptrs72, group_a_ptrs73, group_a_ptrs74, group_a_ptrs75, group_a_ptrs76, group_a_ptrs77, group_a_ptrs78, group_a_ptrs79, group_a_ptrs80, group_a_ptrs81, group_a_ptrs82, group_a_ptrs83, group_b_ptrs0, group_b_ptrs1, group_b_ptrs2, group_b_ptrs3, group_b_ptrs4, group_b_ptrs5, group_b_ptrs6, group_b_ptrs7, group_b_ptrs8, group_b_ptrs9, group_b_ptrs10, group_b_ptrs11, group_b_ptrs12, group_b_ptrs13, group_b_ptrs14, group_b_ptrs15, group_b_ptrs16, group_b_ptrs17, group_b_ptrs18, group_b_ptrs19, group_b_ptrs20, group_b_ptrs21, group_b_ptrs22, group_b_ptrs23, group_b_ptrs24, group_b_ptrs25, group_b_ptrs26, group_b_ptrs27, group_b_ptrs28, group_b_ptrs29, group_b_ptrs30, group_b_ptrs31, group_b_ptrs32, group_b_ptrs33, group_b_ptrs34, group_b_ptrs35, group_b_ptrs36, group_b_ptrs37, group_b_ptrs38, 
group_b_ptrs39, group_b_ptrs40, group_b_ptrs41, group_b_ptrs42, group_b_ptrs43, group_b_ptrs44, group_b_ptrs45, group_b_ptrs46, group_b_ptrs47, group_b_ptrs48, group_b_ptrs49, group_b_ptrs50, group_b_ptrs51, group_b_ptrs52, group_b_ptrs53, group_b_ptrs54, group_b_ptrs55, group_b_ptrs56, group_b_ptrs57, group_b_ptrs58, group_b_ptrs59, group_b_ptrs60, group_b_ptrs61, group_b_ptrs62, group_b_ptrs63, group_b_ptrs64, group_b_ptrs65, group_b_ptrs66, group_b_ptrs67, group_b_ptrs68, group_b_ptrs69, group_b_ptrs70, group_b_ptrs71, group_b_ptrs72, group_b_ptrs73, group_b_ptrs74, group_b_ptrs75, group_b_ptrs76, group_b_ptrs77, group_b_ptrs78, group_b_ptrs79, group_b_ptrs80, group_b_ptrs81, group_b_ptrs82, group_b_ptrs83, group_c_ptr, group_out_ptr, k_sizes0, k_sizes1, k_sizes2, k_sizes3, k_sizes4, k_sizes5, k_sizes6, k_sizes7, k_sizes8, k_sizes9, k_sizes10, k_sizes11, k_sizes12, k_sizes13, k_sizes14, k_sizes15, k_sizes16, k_sizes17, k_sizes18, k_sizes19, k_sizes20, k_sizes21, k_sizes22, k_sizes23, k_sizes24, k_sizes25, k_sizes26, k_sizes27, k_sizes28, k_sizes29, k_sizes30, k_sizes31, k_sizes32, k_sizes33, k_sizes34, k_sizes35, k_sizes36, k_sizes37, k_sizes38, k_sizes39, k_sizes40, k_sizes41, k_sizes42, k_sizes43, k_sizes44, k_sizes45, k_sizes46, k_sizes47, k_sizes48, k_sizes49, k_sizes50, k_sizes51, k_sizes52, k_sizes53, k_sizes54, k_sizes55, k_sizes56, k_sizes57, k_sizes58, k_sizes59, k_sizes60, k_sizes61, k_sizes62, k_sizes63, k_sizes64, k_sizes65, k_sizes66, k_sizes67, k_sizes68, k_sizes69, k_sizes70, k_sizes71, k_sizes72, k_sizes73, k_sizes74, k_sizes75, k_sizes76, k_sizes77, k_sizes78, k_sizes79, k_sizes80, k_sizes81, k_sizes82, k_sizes83, group_size, M, N, stride_out_a, stride_out_b, replace_nan: tl.constexpr, clip_limit, NUM_SM: tl.constexpr, INPUT_PRECISION: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr): | |
| - dtype = group_c_ptr.dtype.element_ty | |
| - tile_idx = tl.program_id(0) | |
| - last_problem_end = 0 | |
| - replace_nan_tl = tl.full((1,), 0.0, dtype=dtype) | |
| - if clip_limit is not None: | |
| - clip_limit_min = tl.full((1,), -clip_limit, dtype=dtype) | |
| - clip_limit_max = tl.full((1,), clip_limit, dtype=dtype) | |
| - else: | |
| - clip_limit_min = tl.full((1,), 0.0, dtype=dtype) | |
| - clip_limit_max = tl.full((1,), 0.0, dtype=dtype) | |
| - num_m_tiles = tl.cdiv(M, BLOCK_SIZE_M) | |
| - num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N) | |
| - num_tiles = num_m_tiles * num_n_tiles | |
| - for g in range(84): | |
| - K = (k_sizes0 if g == 0 else k_sizes1 if g == 1 else k_sizes2 if g == 2 else k_sizes3 if g == 3 else k_sizes4 if g == 4 else k_sizes5 if g == 5 else k_sizes6 if g == 6 else k_sizes7 if g == 7 else k_sizes8 if g == 8 else k_sizes9 if g == 9 else k_sizes10 if g == 10 else k_sizes11 if g == 11 else k_sizes12 if g == 12 else k_sizes13 if g == 13 else k_sizes14 if g == 14 else k_sizes15 if g == 15 else k_sizes16 if g == 16 else k_sizes17 if g == 17 else k_sizes18 if g == 18 else k_sizes19 if g == 19 else k_sizes20 if g == 20 else k_sizes21 if g == 21 else k_sizes22 if g == 22 else k_sizes23 if g == 23 else k_sizes24 if g == 24 else k_sizes25 if g == 25 else k_sizes26 if g == 26 else k_sizes27 if g == 27 else k_sizes28 if g == 28 else k_sizes29 if g == 29 else k_sizes30 if g == 30 else k_sizes31 if g == 31 else k_sizes32 if g == 32 else k_sizes33 if g == 33 else k_sizes34 if g == 34 else k_sizes35 if g == 35 else k_sizes36 if g == 36 else k_sizes37 if g == 37 else k_sizes38 if g == 38 else k_sizes39 if g == 39 else k_sizes40 if g == 40 else k_sizes41 if g == 41 else k_sizes42 if g == 42 else k_sizes43 if g == 43 else k_sizes44 if g == 44 else k_sizes45 if g == 45 else k_sizes46 if g == 46 else k_sizes47 if g == 47 else k_sizes48 if g == 48 else k_sizes49 if g == 49 else k_sizes50 if g == 50 else k_sizes51 if g == 51 else k_sizes52 if g == 52 else k_sizes53 if g == 53 else k_sizes54 if g == 54 else k_sizes55 if g == 55 else k_sizes56 if g == 56 else k_sizes57 if g == 57 else k_sizes58 if g == 58 else k_sizes59 if g == 59 else k_sizes60 if g == 60 else k_sizes61 if g == 61 else k_sizes62 if g == 62 else k_sizes63 if g == 63 else k_sizes64 if g == 64 else k_sizes65 if g == 65 else k_sizes66 if g == 66 else k_sizes67 if g == 67 else k_sizes68 if g == 68 else k_sizes69 if g == 69 else k_sizes70 if g == 70 else k_sizes71 if g == 71 else k_sizes72 if g == 72 else k_sizes73 if g == 73 else k_sizes74 if g == 74 else k_sizes75 if g == 75 else k_sizes76 if g == 76 else 
k_sizes77 if g == 77 else k_sizes78 if g == 78 else k_sizes79 if g == 79 else k_sizes80 if g == 80 else k_sizes81 if g == 81 else k_sizes82 if g == 82 else k_sizes83) * MIN_ELEMENT_ALIGNMENT_SIZE | |
| - a_ptr = group_a_ptrs0 if g == 0 else group_a_ptrs1 if g == 1 else group_a_ptrs2 if g == 2 else group_a_ptrs3 if g == 3 else group_a_ptrs4 if g == 4 else group_a_ptrs5 if g == 5 else group_a_ptrs6 if g == 6 else group_a_ptrs7 if g == 7 else group_a_ptrs8 if g == 8 else group_a_ptrs9 if g == 9 else group_a_ptrs10 if g == 10 else group_a_ptrs11 if g == 11 else group_a_ptrs12 if g == 12 else group_a_ptrs13 if g == 13 else group_a_ptrs14 if g == 14 else group_a_ptrs15 if g == 15 else group_a_ptrs16 if g == 16 else group_a_ptrs17 if g == 17 else group_a_ptrs18 if g == 18 else group_a_ptrs19 if g == 19 else group_a_ptrs20 if g == 20 else group_a_ptrs21 if g == 21 else group_a_ptrs22 if g == 22 else group_a_ptrs23 if g == 23 else group_a_ptrs24 if g == 24 else group_a_ptrs25 if g == 25 else group_a_ptrs26 if g == 26 else group_a_ptrs27 if g == 27 else group_a_ptrs28 if g == 28 else group_a_ptrs29 if g == 29 else group_a_ptrs30 if g == 30 else group_a_ptrs31 if g == 31 else group_a_ptrs32 if g == 32 else group_a_ptrs33 if g == 33 else group_a_ptrs34 if g == 34 else group_a_ptrs35 if g == 35 else group_a_ptrs36 if g == 36 else group_a_ptrs37 if g == 37 else group_a_ptrs38 if g == 38 else group_a_ptrs39 if g == 39 else group_a_ptrs40 if g == 40 else group_a_ptrs41 if g == 41 else group_a_ptrs42 if g == 42 else group_a_ptrs43 if g == 43 else group_a_ptrs44 if g == 44 else group_a_ptrs45 if g == 45 else group_a_ptrs46 if g == 46 else group_a_ptrs47 if g == 47 else group_a_ptrs48 if g == 48 else group_a_ptrs49 if g == 49 else group_a_ptrs50 if g == 50 else group_a_ptrs51 if g == 51 else group_a_ptrs52 if g == 52 else group_a_ptrs53 if g == 53 else group_a_ptrs54 if g == 54 else group_a_ptrs55 if g == 55 else group_a_ptrs56 if g == 56 else group_a_ptrs57 if g == 57 else group_a_ptrs58 if g == 58 else group_a_ptrs59 if g == 59 else group_a_ptrs60 if g == 60 else group_a_ptrs61 if g == 61 else group_a_ptrs62 if g == 62 else group_a_ptrs63 if g == 63 else group_a_ptrs64 if g == 
64 else group_a_ptrs65 if g == 65 else group_a_ptrs66 if g == 66 else group_a_ptrs67 if g == 67 else group_a_ptrs68 if g == 68 else group_a_ptrs69 if g == 69 else group_a_ptrs70 if g == 70 else group_a_ptrs71 if g == 71 else group_a_ptrs72 if g == 72 else group_a_ptrs73 if g == 73 else group_a_ptrs74 if g == 74 else group_a_ptrs75 if g == 75 else group_a_ptrs76 if g == 76 else group_a_ptrs77 if g == 77 else group_a_ptrs78 if g == 78 else group_a_ptrs79 if g == 79 else group_a_ptrs80 if g == 80 else group_a_ptrs81 if g == 81 else group_a_ptrs82 if g == 82 else group_a_ptrs83 | |
| - b_ptr = group_b_ptrs0 if g == 0 else group_b_ptrs1 if g == 1 else group_b_ptrs2 if g == 2 else group_b_ptrs3 if g == 3 else group_b_ptrs4 if g == 4 else group_b_ptrs5 if g == 5 else group_b_ptrs6 if g == 6 else group_b_ptrs7 if g == 7 else group_b_ptrs8 if g == 8 else group_b_ptrs9 if g == 9 else group_b_ptrs10 if g == 10 else group_b_ptrs11 if g == 11 else group_b_ptrs12 if g == 12 else group_b_ptrs13 if g == 13 else group_b_ptrs14 if g == 14 else group_b_ptrs15 if g == 15 else group_b_ptrs16 if g == 16 else group_b_ptrs17 if g == 17 else group_b_ptrs18 if g == 18 else group_b_ptrs19 if g == 19 else group_b_ptrs20 if g == 20 else group_b_ptrs21 if g == 21 else group_b_ptrs22 if g == 22 else group_b_ptrs23 if g == 23 else group_b_ptrs24 if g == 24 else group_b_ptrs25 if g == 25 else group_b_ptrs26 if g == 26 else group_b_ptrs27 if g == 27 else group_b_ptrs28 if g == 28 else group_b_ptrs29 if g == 29 else group_b_ptrs30 if g == 30 else group_b_ptrs31 if g == 31 else group_b_ptrs32 if g == 32 else group_b_ptrs33 if g == 33 else group_b_ptrs34 if g == 34 else group_b_ptrs35 if g == 35 else group_b_ptrs36 if g == 36 else group_b_ptrs37 if g == 37 else group_b_ptrs38 if g == 38 else group_b_ptrs39 if g == 39 else group_b_ptrs40 if g == 40 else group_b_ptrs41 if g == 41 else group_b_ptrs42 if g == 42 else group_b_ptrs43 if g == 43 else group_b_ptrs44 if g == 44 else group_b_ptrs45 if g == 45 else group_b_ptrs46 if g == 46 else group_b_ptrs47 if g == 47 else group_b_ptrs48 if g == 48 else group_b_ptrs49 if g == 49 else group_b_ptrs50 if g == 50 else group_b_ptrs51 if g == 51 else group_b_ptrs52 if g == 52 else group_b_ptrs53 if g == 53 else group_b_ptrs54 if g == 54 else group_b_ptrs55 if g == 55 else group_b_ptrs56 if g == 56 else group_b_ptrs57 if g == 57 else group_b_ptrs58 if g == 58 else group_b_ptrs59 if g == 59 else group_b_ptrs60 if g == 60 else group_b_ptrs61 if g == 61 else group_b_ptrs62 if g == 62 else group_b_ptrs63 if g == 63 else group_b_ptrs64 if g == 
64 else group_b_ptrs65 if g == 65 else group_b_ptrs66 if g == 66 else group_b_ptrs67 if g == 67 else group_b_ptrs68 if g == 68 else group_b_ptrs69 if g == 69 else group_b_ptrs70 if g == 70 else group_b_ptrs71 if g == 71 else group_b_ptrs72 if g == 72 else group_b_ptrs73 if g == 73 else group_b_ptrs74 if g == 74 else group_b_ptrs75 if g == 75 else group_b_ptrs76 if g == 76 else group_b_ptrs77 if g == 77 else group_b_ptrs78 if g == 78 else group_b_ptrs79 if g == 79 else group_b_ptrs80 if g == 80 else group_b_ptrs81 if g == 81 else group_b_ptrs82 if g == 82 else group_b_ptrs83 | |
| - c_ptr = group_c_ptr + g * N | |
| - out_ptr = group_out_ptr + g * stride_out_a | |
| - a_stride_am = K | |
| - a_stride_ak = 1 | |
| - while tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles: | |
| - tile_idx_in_gemm = tile_idx - last_problem_end | |
| - tile_m_idx = tile_idx_in_gemm // num_n_tiles | |
| - tile_n_idx = tile_idx_in_gemm % num_n_tiles | |
| - offs_am = (tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M | |
| - offs_bn = (tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N | |
| - offs_k = tl.arange(0, BLOCK_SIZE_K) | |
| - a_ptrs = a_ptr + offs_am[:, None] * a_stride_am + offs_k[None, :] * a_stride_ak | |
| - b_ptrs = b_ptr + offs_k[:, None] * N + offs_bn[None, :] | |
| - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) | |
| - for kk in range(0, tl.cdiv(K, BLOCK_SIZE_K)): | |
| - tl.multiple_of(a_ptrs, [16, 16]) | |
| - tl.multiple_of(b_ptrs, [16, 16]) | |
| - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - kk * BLOCK_SIZE_K, other=0.0) | |
| - b = tl.load(b_ptrs, mask=offs_k[:, None] < K - kk * BLOCK_SIZE_K, other=0.0) | |
| - if replace_nan: | |
| - a = tl.where(a != a, replace_nan_tl, a) | |
| - if clip_limit is not None: | |
| - a = tl.clamp(a, min=clip_limit_min, max=clip_limit_max) | |
| - accumulator += tl.dot(a, b, input_precision=INPUT_PRECISION) | |
| - a_ptrs += BLOCK_SIZE_K | |
| - b_ptrs += BLOCK_SIZE_K * N | |
| - c_ptrs = c_ptr + offs_bn[None, :] | |
| - c = tl.load(c_ptrs) | |
| - accumulator += c | |
| - out = accumulator.to(a_ptrs.dtype.element_ty) | |
| - offs_out_m = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) | |
| - offs_out_n = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) | |
| - out_ptrs = out_ptr + stride_out_b * offs_out_m[:, None] + offs_out_n[None, :] | |
| - out_mask = (offs_out_m[:, None] < M) & (offs_out_n[None, :] < N) | |
| - tl.store(out_ptrs, out, mask=out_mask) | |
| - tile_idx += NUM_SM | |
| - last_problem_end = last_problem_end + num_tiles | |
| -torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_matmul_add_fixed_mn_kernel_3) | |
| +)torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_matmul_add_fixed_mn_kernel) | |
| @triton.autotune( | |
| configs=[ | |
| triton.Config( | |
| @@ -553,7 +462,7 @@ | |
| key=[] | |
| ) | |
| @triton.jit | |
| -def group_index_select_4(input_group_ptrs0, input_group_ptrs1, input_group_ptrs2, input_group_ptrs3, input_group_ptrs4, input_group_ptrs5, input_group_ptrs6, indices_group_ptrs0, indices_group_ptrs1, indices_group_ptrs2, indices_group_ptrs3, indices_group_ptrs4, indices_group_ptrs5, indices_group_ptrs6, output_group_ptrs0, output_group_ptrs1, output_group_ptrs2, output_group_ptrs3, output_group_ptrs4, output_group_ptrs5, output_group_ptrs6, single_output_ptr, input_stride_0: tl.constexpr, input_stride_1: tl.constexpr, indices_stride: tl.constexpr, output_stride_0: tl.constexpr, output_stride_1: tl.constexpr, D: tl.constexpr, O, O_BUCKET, BLOCK_D: tl.constexpr, BLOCK_O: tl.constexpr, SM_RATIO: tl.constexpr, FUSED_OUTPUT: tl.constexpr, USE_I64_PID: tl.constexpr) -> None: | |
| +def group_index_select(input_group_ptrs0, input_group_ptrs1, input_group_ptrs2, input_group_ptrs3, input_group_ptrs4, input_group_ptrs5, input_group_ptrs6, indices_group_ptrs0, indices_group_ptrs1, indices_group_ptrs2, indices_group_ptrs3, indices_group_ptrs4, indices_group_ptrs5, indices_group_ptrs6, output_group_ptrs0, output_group_ptrs1, output_group_ptrs2, output_group_ptrs3, output_group_ptrs4, output_group_ptrs5, output_group_ptrs6, single_output_ptr, input_stride_0: tl.constexpr, input_stride_1: tl.constexpr, indices_stride: tl.constexpr, output_stride_0: tl.constexpr, output_stride_1: tl.constexpr, D: tl.constexpr, O, O_BUCKET, BLOCK_D: tl.constexpr, BLOCK_O: tl.constexpr, SM_RATIO: tl.constexpr, FUSED_OUTPUT: tl.constexpr, USE_I64_PID: tl.constexpr) -> None: | |
| pid_idx = tl.program_id(0) | |
| if USE_I64_PID: | |
| pid_idx = pid_idx.to(tl.int64) | |
| @@ -680,7 +589,7 @@ | |
| value = tl.load(input_block_ptr, mask=mask) | |
| output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| -torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select_4) | |
| +torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select) | |
| @triton.autotune( | |
| configs=[ | |
| triton.Config( | |
| @@ -690,170 +599,7 @@ | |
| ) | |
| ], | |
| key=[] | |
| -) | |
| -@triton.jit | |
| -def group_index_select_5(input_group_ptrs0, input_group_ptrs1, input_group_ptrs2, input_group_ptrs3, input_group_ptrs4, input_group_ptrs5, input_group_ptrs6, input_group_ptrs7, input_group_ptrs8, indices_group_ptrs0, indices_group_ptrs1, indices_group_ptrs2, indices_group_ptrs3, indices_group_ptrs4, indices_group_ptrs5, indices_group_ptrs6, indices_group_ptrs7, indices_group_ptrs8, output_group_ptrs0, output_group_ptrs1, output_group_ptrs2, output_group_ptrs3, output_group_ptrs4, output_group_ptrs5, output_group_ptrs6, output_group_ptrs7, output_group_ptrs8, single_output_ptr, input_stride_0: tl.constexpr, input_stride_1: tl.constexpr, indices_stride: tl.constexpr, output_stride_0: tl.constexpr, output_stride_1: tl.constexpr, D: tl.constexpr, O, O_BUCKET, BLOCK_D: tl.constexpr, BLOCK_O: tl.constexpr, SM_RATIO: tl.constexpr, FUSED_OUTPUT: tl.constexpr, USE_I64_PID: tl.constexpr) -> None: | |
| - pid_idx = tl.program_id(0) | |
| - if USE_I64_PID: | |
| - pid_idx = pid_idx.to(tl.int64) | |
| - num_pids = tl.num_programs(0) | |
| - if pid_idx * BLOCK_O >= O: | |
| - return | |
| - max_pid = tl.cdiv(O, BLOCK_O) | |
| - input_ptr = input_group_ptrs0 | |
| - indices_ptr = indices_group_ptrs0 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 0 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs0 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs1 | |
| - indices_ptr = indices_group_ptrs1 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 1 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs1 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs2 | |
| - indices_ptr = indices_group_ptrs2 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 2 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs2 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs3 | |
| - indices_ptr = indices_group_ptrs3 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 3 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs3 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs4 | |
| - indices_ptr = indices_group_ptrs4 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 4 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs4 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs5 | |
| - indices_ptr = indices_group_ptrs5 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 5 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs5 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs6 | |
| - indices_ptr = indices_group_ptrs6 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 6 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs6 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs7 | |
| - indices_ptr = indices_group_ptrs7 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 7 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs7 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs8 | |
| - indices_ptr = indices_group_ptrs8 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 8 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs8 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| -torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select_5) | |
| +)torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select) | |
| @triton.autotune( | |
| configs=[ | |
| triton.Config( | |
| @@ -863,187 +609,7 @@ | |
| ) | |
| ], | |
| key=[] | |
| -) | |
| -@triton.jit | |
| -def group_index_select_6(input_group_ptrs0, input_group_ptrs1, input_group_ptrs2, input_group_ptrs3, input_group_ptrs4, input_group_ptrs5, input_group_ptrs6, input_group_ptrs7, input_group_ptrs8, input_group_ptrs9, indices_group_ptrs0, indices_group_ptrs1, indices_group_ptrs2, indices_group_ptrs3, indices_group_ptrs4, indices_group_ptrs5, indices_group_ptrs6, indices_group_ptrs7, indices_group_ptrs8, indices_group_ptrs9, output_group_ptrs0, output_group_ptrs1, output_group_ptrs2, output_group_ptrs3, output_group_ptrs4, output_group_ptrs5, output_group_ptrs6, output_group_ptrs7, output_group_ptrs8, output_group_ptrs9, single_output_ptr, input_stride_0: tl.constexpr, input_stride_1: tl.constexpr, indices_stride: tl.constexpr, output_stride_0: tl.constexpr, output_stride_1: tl.constexpr, D: tl.constexpr, O, O_BUCKET, BLOCK_D: tl.constexpr, BLOCK_O: tl.constexpr, SM_RATIO: tl.constexpr, FUSED_OUTPUT: tl.constexpr, USE_I64_PID: tl.constexpr) -> None: | |
| - pid_idx = tl.program_id(0) | |
| - if USE_I64_PID: | |
| - pid_idx = pid_idx.to(tl.int64) | |
| - num_pids = tl.num_programs(0) | |
| - if pid_idx * BLOCK_O >= O: | |
| - return | |
| - max_pid = tl.cdiv(O, BLOCK_O) | |
| - input_ptr = input_group_ptrs0 | |
| - indices_ptr = indices_group_ptrs0 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 0 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs0 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs1 | |
| - indices_ptr = indices_group_ptrs1 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 1 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs1 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs2 | |
| - indices_ptr = indices_group_ptrs2 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 2 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs2 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs3 | |
| - indices_ptr = indices_group_ptrs3 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 3 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs3 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs4 | |
| - indices_ptr = indices_group_ptrs4 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 4 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs4 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs5 | |
| - indices_ptr = indices_group_ptrs5 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 5 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs5 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs6 | |
| - indices_ptr = indices_group_ptrs6 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 6 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs6 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs7 | |
| - indices_ptr = indices_group_ptrs7 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 7 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs7 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs8 | |
| - indices_ptr = indices_group_ptrs8 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 8 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs8 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs9 | |
| - indices_ptr = indices_group_ptrs9 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 9 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs9 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| -torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select_6) | |
| +)torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select) | |
| @triton.autotune( | |
| configs=[ | |
| triton.Config( | |
| @@ -1053,119 +619,7 @@ | |
| ) | |
| ], | |
| key=[] | |
| -) | |
| -@triton.jit | |
| -def group_index_select_7(input_group_ptrs0, input_group_ptrs1, input_group_ptrs2, input_group_ptrs3, input_group_ptrs4, input_group_ptrs5, indices_group_ptrs0, indices_group_ptrs1, indices_group_ptrs2, indices_group_ptrs3, indices_group_ptrs4, indices_group_ptrs5, output_group_ptrs0, output_group_ptrs1, output_group_ptrs2, output_group_ptrs3, output_group_ptrs4, output_group_ptrs5, single_output_ptr, input_stride_0: tl.constexpr, input_stride_1: tl.constexpr, indices_stride: tl.constexpr, output_stride_0: tl.constexpr, output_stride_1: tl.constexpr, D: tl.constexpr, O, O_BUCKET, BLOCK_D: tl.constexpr, BLOCK_O: tl.constexpr, SM_RATIO: tl.constexpr, FUSED_OUTPUT: tl.constexpr, USE_I64_PID: tl.constexpr) -> None: | |
| - pid_idx = tl.program_id(0) | |
| - if USE_I64_PID: | |
| - pid_idx = pid_idx.to(tl.int64) | |
| - num_pids = tl.num_programs(0) | |
| - if pid_idx * BLOCK_O >= O: | |
| - return | |
| - max_pid = tl.cdiv(O, BLOCK_O) | |
| - input_ptr = input_group_ptrs0 | |
| - indices_ptr = indices_group_ptrs0 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 0 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs0 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs1 | |
| - indices_ptr = indices_group_ptrs1 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 1 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs1 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs2 | |
| - indices_ptr = indices_group_ptrs2 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 2 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs2 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs3 | |
| - indices_ptr = indices_group_ptrs3 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 3 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs3 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs4 | |
| - indices_ptr = indices_group_ptrs4 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 4 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs4 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs5 | |
| - indices_ptr = indices_group_ptrs5 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 5 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs5 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| -torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select_7) | |
| +)torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select) | |
| @triton.autotune( | |
| configs=[ | |
| triton.Config( | |
| @@ -1175,153 +629,7 @@ | |
| ) | |
| ], | |
| key=[] | |
| -) | |
| -@triton.jit | |
| -def group_index_select_8(input_group_ptrs0, input_group_ptrs1, input_group_ptrs2, input_group_ptrs3, input_group_ptrs4, input_group_ptrs5, input_group_ptrs6, input_group_ptrs7, indices_group_ptrs0, indices_group_ptrs1, indices_group_ptrs2, indices_group_ptrs3, indices_group_ptrs4, indices_group_ptrs5, indices_group_ptrs6, indices_group_ptrs7, output_group_ptrs0, output_group_ptrs1, output_group_ptrs2, output_group_ptrs3, output_group_ptrs4, output_group_ptrs5, output_group_ptrs6, output_group_ptrs7, single_output_ptr, input_stride_0: tl.constexpr, input_stride_1: tl.constexpr, indices_stride: tl.constexpr, output_stride_0: tl.constexpr, output_stride_1: tl.constexpr, D: tl.constexpr, O, O_BUCKET, BLOCK_D: tl.constexpr, BLOCK_O: tl.constexpr, SM_RATIO: tl.constexpr, FUSED_OUTPUT: tl.constexpr, USE_I64_PID: tl.constexpr) -> None: | |
| - pid_idx = tl.program_id(0) | |
| - if USE_I64_PID: | |
| - pid_idx = pid_idx.to(tl.int64) | |
| - num_pids = tl.num_programs(0) | |
| - if pid_idx * BLOCK_O >= O: | |
| - return | |
| - max_pid = tl.cdiv(O, BLOCK_O) | |
| - input_ptr = input_group_ptrs0 | |
| - indices_ptr = indices_group_ptrs0 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 0 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs0 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs1 | |
| - indices_ptr = indices_group_ptrs1 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 1 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs1 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs2 | |
| - indices_ptr = indices_group_ptrs2 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 2 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs2 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs3 | |
| - indices_ptr = indices_group_ptrs3 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 3 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs3 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs4 | |
| - indices_ptr = indices_group_ptrs4 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 4 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs4 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs5 | |
| - indices_ptr = indices_group_ptrs5 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 5 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs5 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs6 | |
| - indices_ptr = indices_group_ptrs6 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 6 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs6 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs7 | |
| - indices_ptr = indices_group_ptrs7 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 7 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs7 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| -torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select_8) | |
| +)torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select) | |
| @triton.autotune( | |
| configs=[ | |
| triton.Config( | |
| @@ -1331,357 +639,7 @@ | |
| ) | |
| ], | |
| key=[] | |
| -) | |
| -@triton.jit | |
| -def group_index_select_9(input_group_ptrs0, input_group_ptrs1, input_group_ptrs2, input_group_ptrs3, input_group_ptrs4, input_group_ptrs5, input_group_ptrs6, input_group_ptrs7, input_group_ptrs8, input_group_ptrs9, input_group_ptrs10, input_group_ptrs11, input_group_ptrs12, input_group_ptrs13, input_group_ptrs14, input_group_ptrs15, input_group_ptrs16, input_group_ptrs17, input_group_ptrs18, input_group_ptrs19, indices_group_ptrs0, indices_group_ptrs1, indices_group_ptrs2, indices_group_ptrs3, indices_group_ptrs4, indices_group_ptrs5, indices_group_ptrs6, indices_group_ptrs7, indices_group_ptrs8, indices_group_ptrs9, indices_group_ptrs10, indices_group_ptrs11, indices_group_ptrs12, indices_group_ptrs13, indices_group_ptrs14, indices_group_ptrs15, indices_group_ptrs16, indices_group_ptrs17, indices_group_ptrs18, indices_group_ptrs19, output_group_ptrs0, output_group_ptrs1, output_group_ptrs2, output_group_ptrs3, output_group_ptrs4, output_group_ptrs5, output_group_ptrs6, output_group_ptrs7, output_group_ptrs8, output_group_ptrs9, output_group_ptrs10, output_group_ptrs11, output_group_ptrs12, output_group_ptrs13, output_group_ptrs14, output_group_ptrs15, output_group_ptrs16, output_group_ptrs17, output_group_ptrs18, output_group_ptrs19, single_output_ptr, input_stride_0: tl.constexpr, input_stride_1: tl.constexpr, indices_stride: tl.constexpr, output_stride_0: tl.constexpr, output_stride_1: tl.constexpr, D: tl.constexpr, O, O_BUCKET, BLOCK_D: tl.constexpr, BLOCK_O: tl.constexpr, SM_RATIO: tl.constexpr, FUSED_OUTPUT: tl.constexpr, USE_I64_PID: tl.constexpr) -> None: | |
| - pid_idx = tl.program_id(0) | |
| - if USE_I64_PID: | |
| - pid_idx = pid_idx.to(tl.int64) | |
| - num_pids = tl.num_programs(0) | |
| - if pid_idx * BLOCK_O >= O: | |
| - return | |
| - max_pid = tl.cdiv(O, BLOCK_O) | |
| - input_ptr = input_group_ptrs0 | |
| - indices_ptr = indices_group_ptrs0 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 0 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs0 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs1 | |
| - indices_ptr = indices_group_ptrs1 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 1 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs1 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs2 | |
| - indices_ptr = indices_group_ptrs2 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 2 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs2 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs3 | |
| - indices_ptr = indices_group_ptrs3 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 3 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs3 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs4 | |
| - indices_ptr = indices_group_ptrs4 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 4 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs4 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs5 | |
| - indices_ptr = indices_group_ptrs5 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 5 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs5 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs6 | |
| - indices_ptr = indices_group_ptrs6 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 6 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs6 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs7 | |
| - indices_ptr = indices_group_ptrs7 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 7 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs7 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs8 | |
| - indices_ptr = indices_group_ptrs8 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 8 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs8 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs9 | |
| - indices_ptr = indices_group_ptrs9 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 9 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs9 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs10 | |
| - indices_ptr = indices_group_ptrs10 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 10 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs10 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs11 | |
| - indices_ptr = indices_group_ptrs11 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 11 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs11 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs12 | |
| - indices_ptr = indices_group_ptrs12 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 12 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs12 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs13 | |
| - indices_ptr = indices_group_ptrs13 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 13 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs13 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs14 | |
| - indices_ptr = indices_group_ptrs14 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 14 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs14 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs15 | |
| - indices_ptr = indices_group_ptrs15 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 15 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs15 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs16 | |
| - indices_ptr = indices_group_ptrs16 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 16 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs16 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs17 | |
| - indices_ptr = indices_group_ptrs17 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 17 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs17 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs18 | |
| - indices_ptr = indices_group_ptrs18 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 18 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs18 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| - input_ptr = input_group_ptrs19 | |
| - indices_ptr = indices_group_ptrs19 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 19 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs19 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| -torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select_9) | |
| +)torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select) | |
| @triton.autotune( | |
| configs=[ | |
| triton.Config( | |
| @@ -1691,37 +649,10 @@ | |
| ) | |
| ], | |
| key=[] | |
| -) | |
| -@triton.jit | |
| -def group_index_select_10(input_group_ptrs0, indices_group_ptrs0, output_group_ptrs0, single_output_ptr, input_stride_0: tl.constexpr, input_stride_1: tl.constexpr, indices_stride: tl.constexpr, output_stride_0: tl.constexpr, output_stride_1: tl.constexpr, D: tl.constexpr, O, O_BUCKET, BLOCK_D: tl.constexpr, BLOCK_O: tl.constexpr, SM_RATIO: tl.constexpr, FUSED_OUTPUT: tl.constexpr, USE_I64_PID: tl.constexpr) -> None: | |
| - pid_idx = tl.program_id(0) | |
| - if USE_I64_PID: | |
| - pid_idx = pid_idx.to(tl.int64) | |
| - num_pids = tl.num_programs(0) | |
| - if pid_idx * BLOCK_O >= O: | |
| - return | |
| - max_pid = tl.cdiv(O, BLOCK_O) | |
| - input_ptr = input_group_ptrs0 | |
| - indices_ptr = indices_group_ptrs0 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - output_ptr = single_output_ptr | |
| - d_offsets_output = 0 * D + tl.arange(0, BLOCK_D) | |
| - else: | |
| - output_ptr = output_group_ptrs0 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - input_block_ptr = input_ptr + indices[:, None] * input_stride_0 + d_offsets[None, :] * input_stride_1 | |
| - mask = (d_offsets[None, :] < D) & (o_offsets[:, None] < O) | |
| - value = tl.load(input_block_ptr, mask=mask) | |
| - output_block_ptr = output_ptr + o_offsets[:, None] * output_stride_0 + d_offsets_output[None, :] * output_stride_1 | |
| - tl.store(output_block_ptr, value.to(output_ptr.dtype.element_ty), mask=mask) | |
| -torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select_10) | |
| +)torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select) | |
| @triton.jit | |
| -def group_index_select_backward_kernel_11(d_out_ptrs0, indices_group_ptrs0, d_in_ptrs0, single_d_out_ptr, d_out_stride_0: tl.constexpr, d_out_stride_1: tl.constexpr, d_in_stride_0: tl.constexpr, d_in_stride_1: tl.constexpr, indices_stride: tl.constexpr, D: tl.constexpr, O, O_BUCKET, BLOCK_D: tl.constexpr, BLOCK_O: tl.constexpr, SM_RATIO: tl.constexpr, FUSED_OUTPUT: tl.constexpr, USE_I64_PID: tl.constexpr) -> None: | |
| +def group_index_select_backward_kernel(d_out_ptrs0, indices_group_ptrs0, d_in_ptrs0, single_d_out_ptr, d_out_stride_0: tl.constexpr, d_out_stride_1: tl.constexpr, d_in_stride_0: tl.constexpr, d_in_stride_1: tl.constexpr, indices_stride: tl.constexpr, D: tl.constexpr, O, O_BUCKET, BLOCK_D: tl.constexpr, BLOCK_O: tl.constexpr, SM_RATIO: tl.constexpr, FUSED_OUTPUT: tl.constexpr, USE_I64_PID: tl.constexpr) -> None: | |
| pid_idx = tl.program_id(0) | |
| if USE_I64_PID: | |
| pid_idx = pid_idx.to(tl.int64) | |
| @@ -1745,7 +676,7 @@ | |
| d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| -torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select_backward_kernel_11) | |
| +torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select_backward_kernel) | |
| @triton.autotune( | |
| configs=[ | |
| triton.Config( | |
| @@ -1755,364 +686,8 @@ | |
| ) | |
| ], | |
| key=[] | |
| -) | |
| -@triton.jit | |
| -def group_index_select_backward_kernel_12(d_out_ptrs0, indices_group_ptrs0, d_in_ptrs0, single_d_out_ptr, d_out_stride_0: tl.constexpr, d_out_stride_1: tl.constexpr, d_in_stride_0: tl.constexpr, d_in_stride_1: tl.constexpr, indices_stride: tl.constexpr, D: tl.constexpr, O, O_BUCKET, BLOCK_D: tl.constexpr, BLOCK_O: tl.constexpr, SM_RATIO: tl.constexpr, FUSED_OUTPUT: tl.constexpr, USE_I64_PID: tl.constexpr) -> None: | |
| - pid_idx = tl.program_id(0) | |
| - if USE_I64_PID: | |
| - pid_idx = pid_idx.to(tl.int64) | |
| - num_pids = tl.num_programs(0) | |
| - if pid_idx * BLOCK_O >= O: | |
| - return | |
| - max_pid = tl.cdiv(O, BLOCK_O) | |
| - indices_ptr = indices_group_ptrs0 | |
| - d_in_ptr = d_in_ptrs0 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 0 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs0 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| -torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select_backward_kernel_12) | |
| - | |
| -@triton.jit | |
| -def group_index_select_backward_kernel_13(d_out_ptrs0, d_out_ptrs1, d_out_ptrs2, d_out_ptrs3, d_out_ptrs4, d_out_ptrs5, d_out_ptrs6, d_out_ptrs7, d_out_ptrs8, d_out_ptrs9, d_out_ptrs10, d_out_ptrs11, d_out_ptrs12, d_out_ptrs13, d_out_ptrs14, d_out_ptrs15, d_out_ptrs16, d_out_ptrs17, d_out_ptrs18, d_out_ptrs19, indices_group_ptrs0, indices_group_ptrs1, indices_group_ptrs2, indices_group_ptrs3, indices_group_ptrs4, indices_group_ptrs5, indices_group_ptrs6, indices_group_ptrs7, indices_group_ptrs8, indices_group_ptrs9, indices_group_ptrs10, indices_group_ptrs11, indices_group_ptrs12, indices_group_ptrs13, indices_group_ptrs14, indices_group_ptrs15, indices_group_ptrs16, indices_group_ptrs17, indices_group_ptrs18, indices_group_ptrs19, d_in_ptrs0, d_in_ptrs1, d_in_ptrs2, d_in_ptrs3, d_in_ptrs4, d_in_ptrs5, d_in_ptrs6, d_in_ptrs7, d_in_ptrs8, d_in_ptrs9, d_in_ptrs10, d_in_ptrs11, d_in_ptrs12, d_in_ptrs13, d_in_ptrs14, d_in_ptrs15, d_in_ptrs16, d_in_ptrs17, d_in_ptrs18, d_in_ptrs19, single_d_out_ptr, d_out_stride_0: tl.constexpr, d_out_stride_1: tl.constexpr, d_in_stride_0: tl.constexpr, d_in_stride_1: tl.constexpr, indices_stride: tl.constexpr, D: tl.constexpr, O, O_BUCKET, BLOCK_D: tl.constexpr, BLOCK_O: tl.constexpr, SM_RATIO: tl.constexpr, FUSED_OUTPUT: tl.constexpr, USE_I64_PID: tl.constexpr) -> None: | |
| - pid_idx = tl.program_id(0) | |
| - if USE_I64_PID: | |
| - pid_idx = pid_idx.to(tl.int64) | |
| - num_pids = tl.num_programs(0) | |
| - if pid_idx * BLOCK_O >= O: | |
| - return | |
| - max_pid = tl.cdiv(O, BLOCK_O) | |
| - indices_ptr = indices_group_ptrs0 | |
| - d_in_ptr = d_in_ptrs0 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 0 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs0 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs1 | |
| - d_in_ptr = d_in_ptrs1 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 1 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs1 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs2 | |
| - d_in_ptr = d_in_ptrs2 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 2 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs2 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs3 | |
| - d_in_ptr = d_in_ptrs3 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 3 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs3 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs4 | |
| - d_in_ptr = d_in_ptrs4 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 4 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs4 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs5 | |
| - d_in_ptr = d_in_ptrs5 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 5 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs5 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs6 | |
| - d_in_ptr = d_in_ptrs6 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 6 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs6 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs7 | |
| - d_in_ptr = d_in_ptrs7 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 7 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs7 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs8 | |
| - d_in_ptr = d_in_ptrs8 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 8 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs8 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs9 | |
| - d_in_ptr = d_in_ptrs9 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 9 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs9 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs10 | |
| - d_in_ptr = d_in_ptrs10 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 10 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs10 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs11 | |
| - d_in_ptr = d_in_ptrs11 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 11 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs11 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs12 | |
| - d_in_ptr = d_in_ptrs12 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 12 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs12 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs13 | |
| - d_in_ptr = d_in_ptrs13 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 13 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs13 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs14 | |
| - d_in_ptr = d_in_ptrs14 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 14 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs14 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs15 | |
| - d_in_ptr = d_in_ptrs15 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 15 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs15 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs16 | |
| - d_in_ptr = d_in_ptrs16 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 16 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs16 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs17 | |
| - d_in_ptr = d_in_ptrs17 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 17 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs17 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs18 | |
| - d_in_ptr = d_in_ptrs18 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 18 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs18 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs19 | |
| - d_in_ptr = d_in_ptrs19 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 19 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs19 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| -torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select_backward_kernel_13) | |
| +)torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select_backward_kernel) | |
| +torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select_backward_kernel) | |
| @triton.autotune( | |
| configs=[ | |
| triton.Config( | |
| @@ -2122,476 +697,8 @@ | |
| ) | |
| ], | |
| key=[] | |
| -) | |
| -@triton.jit | |
| -def group_index_select_backward_kernel_14(d_out_ptrs0, d_out_ptrs1, d_out_ptrs2, d_out_ptrs3, d_out_ptrs4, d_out_ptrs5, d_out_ptrs6, d_out_ptrs7, d_out_ptrs8, d_out_ptrs9, d_out_ptrs10, d_out_ptrs11, d_out_ptrs12, d_out_ptrs13, d_out_ptrs14, d_out_ptrs15, d_out_ptrs16, d_out_ptrs17, d_out_ptrs18, d_out_ptrs19, indices_group_ptrs0, indices_group_ptrs1, indices_group_ptrs2, indices_group_ptrs3, indices_group_ptrs4, indices_group_ptrs5, indices_group_ptrs6, indices_group_ptrs7, indices_group_ptrs8, indices_group_ptrs9, indices_group_ptrs10, indices_group_ptrs11, indices_group_ptrs12, indices_group_ptrs13, indices_group_ptrs14, indices_group_ptrs15, indices_group_ptrs16, indices_group_ptrs17, indices_group_ptrs18, indices_group_ptrs19, d_in_ptrs0, d_in_ptrs1, d_in_ptrs2, d_in_ptrs3, d_in_ptrs4, d_in_ptrs5, d_in_ptrs6, d_in_ptrs7, d_in_ptrs8, d_in_ptrs9, d_in_ptrs10, d_in_ptrs11, d_in_ptrs12, d_in_ptrs13, d_in_ptrs14, d_in_ptrs15, d_in_ptrs16, d_in_ptrs17, d_in_ptrs18, d_in_ptrs19, single_d_out_ptr, d_out_stride_0: tl.constexpr, d_out_stride_1: tl.constexpr, d_in_stride_0: tl.constexpr, d_in_stride_1: tl.constexpr, indices_stride: tl.constexpr, D: tl.constexpr, O, O_BUCKET, BLOCK_D: tl.constexpr, BLOCK_O: tl.constexpr, SM_RATIO: tl.constexpr, FUSED_OUTPUT: tl.constexpr, USE_I64_PID: tl.constexpr) -> None: | |
| - pid_idx = tl.program_id(0) | |
| - if USE_I64_PID: | |
| - pid_idx = pid_idx.to(tl.int64) | |
| - num_pids = tl.num_programs(0) | |
| - if pid_idx * BLOCK_O >= O: | |
| - return | |
| - max_pid = tl.cdiv(O, BLOCK_O) | |
| - indices_ptr = indices_group_ptrs0 | |
| - d_in_ptr = d_in_ptrs0 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 0 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs0 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs1 | |
| - d_in_ptr = d_in_ptrs1 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 1 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs1 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs2 | |
| - d_in_ptr = d_in_ptrs2 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 2 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs2 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs3 | |
| - d_in_ptr = d_in_ptrs3 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 3 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs3 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs4 | |
| - d_in_ptr = d_in_ptrs4 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 4 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs4 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs5 | |
| - d_in_ptr = d_in_ptrs5 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 5 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs5 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs6 | |
| - d_in_ptr = d_in_ptrs6 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 6 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs6 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs7 | |
| - d_in_ptr = d_in_ptrs7 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 7 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs7 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs8 | |
| - d_in_ptr = d_in_ptrs8 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 8 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs8 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs9 | |
| - d_in_ptr = d_in_ptrs9 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 9 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs9 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs10 | |
| - d_in_ptr = d_in_ptrs10 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 10 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs10 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs11 | |
| - d_in_ptr = d_in_ptrs11 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 11 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs11 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs12 | |
| - d_in_ptr = d_in_ptrs12 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 12 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs12 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs13 | |
| - d_in_ptr = d_in_ptrs13 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 13 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs13 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs14 | |
| - d_in_ptr = d_in_ptrs14 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 14 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs14 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs15 | |
| - d_in_ptr = d_in_ptrs15 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 15 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs15 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs16 | |
| - d_in_ptr = d_in_ptrs16 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 16 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs16 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs17 | |
| - d_in_ptr = d_in_ptrs17 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 17 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs17 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs18 | |
| - d_in_ptr = d_in_ptrs18 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 18 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs18 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs19 | |
| - d_in_ptr = d_in_ptrs19 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 19 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs19 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| -torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select_backward_kernel_14) | |
| - | |
| -@triton.jit | |
| -def group_index_select_backward_kernel_15(d_out_ptrs0, d_out_ptrs1, d_out_ptrs2, d_out_ptrs3, d_out_ptrs4, d_out_ptrs5, d_out_ptrs6, d_out_ptrs7, indices_group_ptrs0, indices_group_ptrs1, indices_group_ptrs2, indices_group_ptrs3, indices_group_ptrs4, indices_group_ptrs5, indices_group_ptrs6, indices_group_ptrs7, d_in_ptrs0, d_in_ptrs1, d_in_ptrs2, d_in_ptrs3, d_in_ptrs4, d_in_ptrs5, d_in_ptrs6, d_in_ptrs7, single_d_out_ptr, d_out_stride_0: tl.constexpr, d_out_stride_1: tl.constexpr, d_in_stride_0: tl.constexpr, d_in_stride_1: tl.constexpr, indices_stride: tl.constexpr, D: tl.constexpr, O, O_BUCKET, BLOCK_D: tl.constexpr, BLOCK_O: tl.constexpr, SM_RATIO: tl.constexpr, FUSED_OUTPUT: tl.constexpr, USE_I64_PID: tl.constexpr) -> None: | |
| - pid_idx = tl.program_id(0) | |
| - if USE_I64_PID: | |
| - pid_idx = pid_idx.to(tl.int64) | |
| - num_pids = tl.num_programs(0) | |
| - if pid_idx * BLOCK_O >= O: | |
| - return | |
| - max_pid = tl.cdiv(O, BLOCK_O) | |
| - indices_ptr = indices_group_ptrs0 | |
| - d_in_ptr = d_in_ptrs0 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 0 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs0 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs1 | |
| - d_in_ptr = d_in_ptrs1 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 1 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs1 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs2 | |
| - d_in_ptr = d_in_ptrs2 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 2 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs2 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs3 | |
| - d_in_ptr = d_in_ptrs3 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 3 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs3 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs4 | |
| - d_in_ptr = d_in_ptrs4 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 4 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs4 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs5 | |
| - d_in_ptr = d_in_ptrs5 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 5 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs5 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs6 | |
| - d_in_ptr = d_in_ptrs6 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 6 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs6 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs7 | |
| - d_in_ptr = d_in_ptrs7 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 7 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs7 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| -torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select_backward_kernel_15) | |
| +)torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select_backward_kernel) | |
| +torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select_backward_kernel) | |
| @triton.autotune( | |
| configs=[ | |
| triton.Config( | |
| @@ -2601,252 +708,8 @@ | |
| ) | |
| ], | |
| key=[] | |
| -) | |
| -@triton.jit | |
| -def group_index_select_backward_kernel_16(d_out_ptrs0, d_out_ptrs1, d_out_ptrs2, d_out_ptrs3, d_out_ptrs4, d_out_ptrs5, d_out_ptrs6, d_out_ptrs7, indices_group_ptrs0, indices_group_ptrs1, indices_group_ptrs2, indices_group_ptrs3, indices_group_ptrs4, indices_group_ptrs5, indices_group_ptrs6, indices_group_ptrs7, d_in_ptrs0, d_in_ptrs1, d_in_ptrs2, d_in_ptrs3, d_in_ptrs4, d_in_ptrs5, d_in_ptrs6, d_in_ptrs7, single_d_out_ptr, d_out_stride_0: tl.constexpr, d_out_stride_1: tl.constexpr, d_in_stride_0: tl.constexpr, d_in_stride_1: tl.constexpr, indices_stride: tl.constexpr, D: tl.constexpr, O, O_BUCKET, BLOCK_D: tl.constexpr, BLOCK_O: tl.constexpr, SM_RATIO: tl.constexpr, FUSED_OUTPUT: tl.constexpr, USE_I64_PID: tl.constexpr) -> None: | |
| - pid_idx = tl.program_id(0) | |
| - if USE_I64_PID: | |
| - pid_idx = pid_idx.to(tl.int64) | |
| - num_pids = tl.num_programs(0) | |
| - if pid_idx * BLOCK_O >= O: | |
| - return | |
| - max_pid = tl.cdiv(O, BLOCK_O) | |
| - indices_ptr = indices_group_ptrs0 | |
| - d_in_ptr = d_in_ptrs0 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 0 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs0 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs1 | |
| - d_in_ptr = d_in_ptrs1 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 1 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs1 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs2 | |
| - d_in_ptr = d_in_ptrs2 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 2 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs2 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs3 | |
| - d_in_ptr = d_in_ptrs3 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 3 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs3 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs4 | |
| - d_in_ptr = d_in_ptrs4 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 4 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs4 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs5 | |
| - d_in_ptr = d_in_ptrs5 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 5 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs5 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs6 | |
| - d_in_ptr = d_in_ptrs6 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 6 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs6 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs7 | |
| - d_in_ptr = d_in_ptrs7 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 7 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs7 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| -torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select_backward_kernel_16) | |
| - | |
| -@triton.jit | |
| -def group_index_select_backward_kernel_17(d_out_ptrs0, d_out_ptrs1, d_out_ptrs2, d_out_ptrs3, d_out_ptrs4, d_out_ptrs5, indices_group_ptrs0, indices_group_ptrs1, indices_group_ptrs2, indices_group_ptrs3, indices_group_ptrs4, indices_group_ptrs5, d_in_ptrs0, d_in_ptrs1, d_in_ptrs2, d_in_ptrs3, d_in_ptrs4, d_in_ptrs5, single_d_out_ptr, d_out_stride_0: tl.constexpr, d_out_stride_1: tl.constexpr, d_in_stride_0: tl.constexpr, d_in_stride_1: tl.constexpr, indices_stride: tl.constexpr, D: tl.constexpr, O, O_BUCKET, BLOCK_D: tl.constexpr, BLOCK_O: tl.constexpr, SM_RATIO: tl.constexpr, FUSED_OUTPUT: tl.constexpr, USE_I64_PID: tl.constexpr) -> None: | |
| - pid_idx = tl.program_id(0) | |
| - if USE_I64_PID: | |
| - pid_idx = pid_idx.to(tl.int64) | |
| - num_pids = tl.num_programs(0) | |
| - if pid_idx * BLOCK_O >= O: | |
| - return | |
| - max_pid = tl.cdiv(O, BLOCK_O) | |
| - indices_ptr = indices_group_ptrs0 | |
| - d_in_ptr = d_in_ptrs0 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 0 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs0 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs1 | |
| - d_in_ptr = d_in_ptrs1 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 1 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs1 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs2 | |
| - d_in_ptr = d_in_ptrs2 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 2 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs2 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs3 | |
| - d_in_ptr = d_in_ptrs3 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 3 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs3 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs4 | |
| - d_in_ptr = d_in_ptrs4 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 4 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs4 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs5 | |
| - d_in_ptr = d_in_ptrs5 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 5 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs5 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| -torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select_backward_kernel_17) | |
| +)torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select_backward_kernel) | |
| +torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select_backward_kernel) | |
| @triton.autotune( | |
| configs=[ | |
| triton.Config( | |
| @@ -2856,284 +719,8 @@ | |
| ) | |
| ], | |
| key=[] | |
| -) | |
| -@triton.jit | |
| -def group_index_select_backward_kernel_18(d_out_ptrs0, d_out_ptrs1, d_out_ptrs2, d_out_ptrs3, d_out_ptrs4, d_out_ptrs5, indices_group_ptrs0, indices_group_ptrs1, indices_group_ptrs2, indices_group_ptrs3, indices_group_ptrs4, indices_group_ptrs5, d_in_ptrs0, d_in_ptrs1, d_in_ptrs2, d_in_ptrs3, d_in_ptrs4, d_in_ptrs5, single_d_out_ptr, d_out_stride_0: tl.constexpr, d_out_stride_1: tl.constexpr, d_in_stride_0: tl.constexpr, d_in_stride_1: tl.constexpr, indices_stride: tl.constexpr, D: tl.constexpr, O, O_BUCKET, BLOCK_D: tl.constexpr, BLOCK_O: tl.constexpr, SM_RATIO: tl.constexpr, FUSED_OUTPUT: tl.constexpr, USE_I64_PID: tl.constexpr) -> None: | |
| - pid_idx = tl.program_id(0) | |
| - if USE_I64_PID: | |
| - pid_idx = pid_idx.to(tl.int64) | |
| - num_pids = tl.num_programs(0) | |
| - if pid_idx * BLOCK_O >= O: | |
| - return | |
| - max_pid = tl.cdiv(O, BLOCK_O) | |
| - indices_ptr = indices_group_ptrs0 | |
| - d_in_ptr = d_in_ptrs0 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 0 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs0 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs1 | |
| - d_in_ptr = d_in_ptrs1 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 1 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs1 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs2 | |
| - d_in_ptr = d_in_ptrs2 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 2 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs2 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs3 | |
| - d_in_ptr = d_in_ptrs3 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 3 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs3 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs4 | |
| - d_in_ptr = d_in_ptrs4 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 4 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs4 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs5 | |
| - d_in_ptr = d_in_ptrs5 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 5 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs5 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| -torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select_backward_kernel_18) | |
| - | |
| -@triton.jit | |
| -def group_index_select_backward_kernel_19(d_out_ptrs0, d_out_ptrs1, d_out_ptrs2, d_out_ptrs3, d_out_ptrs4, d_out_ptrs5, d_out_ptrs6, d_out_ptrs7, d_out_ptrs8, d_out_ptrs9, indices_group_ptrs0, indices_group_ptrs1, indices_group_ptrs2, indices_group_ptrs3, indices_group_ptrs4, indices_group_ptrs5, indices_group_ptrs6, indices_group_ptrs7, indices_group_ptrs8, indices_group_ptrs9, d_in_ptrs0, d_in_ptrs1, d_in_ptrs2, d_in_ptrs3, d_in_ptrs4, d_in_ptrs5, d_in_ptrs6, d_in_ptrs7, d_in_ptrs8, d_in_ptrs9, single_d_out_ptr, d_out_stride_0: tl.constexpr, d_out_stride_1: tl.constexpr, d_in_stride_0: tl.constexpr, d_in_stride_1: tl.constexpr, indices_stride: tl.constexpr, D: tl.constexpr, O, O_BUCKET, BLOCK_D: tl.constexpr, BLOCK_O: tl.constexpr, SM_RATIO: tl.constexpr, FUSED_OUTPUT: tl.constexpr, USE_I64_PID: tl.constexpr) -> None: | |
| - pid_idx = tl.program_id(0) | |
| - if USE_I64_PID: | |
| - pid_idx = pid_idx.to(tl.int64) | |
| - num_pids = tl.num_programs(0) | |
| - if pid_idx * BLOCK_O >= O: | |
| - return | |
| - max_pid = tl.cdiv(O, BLOCK_O) | |
| - indices_ptr = indices_group_ptrs0 | |
| - d_in_ptr = d_in_ptrs0 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 0 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs0 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs1 | |
| - d_in_ptr = d_in_ptrs1 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 1 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs1 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs2 | |
| - d_in_ptr = d_in_ptrs2 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 2 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs2 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs3 | |
| - d_in_ptr = d_in_ptrs3 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 3 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs3 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs4 | |
| - d_in_ptr = d_in_ptrs4 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 4 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs4 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs5 | |
| - d_in_ptr = d_in_ptrs5 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 5 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs5 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs6 | |
| - d_in_ptr = d_in_ptrs6 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 6 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs6 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs7 | |
| - d_in_ptr = d_in_ptrs7 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 7 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs7 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs8 | |
| - d_in_ptr = d_in_ptrs8 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 8 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs8 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs9 | |
| - d_in_ptr = d_in_ptrs9 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 9 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs9 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| -torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select_backward_kernel_19) | |
| +)torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select_backward_kernel) | |
| +torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select_backward_kernel) | |
| @triton.autotune( | |
| configs=[ | |
| triton.Config( | |
| @@ -3143,332 +730,8 @@ | |
| ) | |
| ], | |
| key=[] | |
| -) | |
| -@triton.jit | |
| -def group_index_select_backward_kernel_20(d_out_ptrs0, d_out_ptrs1, d_out_ptrs2, d_out_ptrs3, d_out_ptrs4, d_out_ptrs5, d_out_ptrs6, d_out_ptrs7, d_out_ptrs8, d_out_ptrs9, indices_group_ptrs0, indices_group_ptrs1, indices_group_ptrs2, indices_group_ptrs3, indices_group_ptrs4, indices_group_ptrs5, indices_group_ptrs6, indices_group_ptrs7, indices_group_ptrs8, indices_group_ptrs9, d_in_ptrs0, d_in_ptrs1, d_in_ptrs2, d_in_ptrs3, d_in_ptrs4, d_in_ptrs5, d_in_ptrs6, d_in_ptrs7, d_in_ptrs8, d_in_ptrs9, single_d_out_ptr, d_out_stride_0: tl.constexpr, d_out_stride_1: tl.constexpr, d_in_stride_0: tl.constexpr, d_in_stride_1: tl.constexpr, indices_stride: tl.constexpr, D: tl.constexpr, O, O_BUCKET, BLOCK_D: tl.constexpr, BLOCK_O: tl.constexpr, SM_RATIO: tl.constexpr, FUSED_OUTPUT: tl.constexpr, USE_I64_PID: tl.constexpr) -> None: | |
| - pid_idx = tl.program_id(0) | |
| - if USE_I64_PID: | |
| - pid_idx = pid_idx.to(tl.int64) | |
| - num_pids = tl.num_programs(0) | |
| - if pid_idx * BLOCK_O >= O: | |
| - return | |
| - max_pid = tl.cdiv(O, BLOCK_O) | |
| - indices_ptr = indices_group_ptrs0 | |
| - d_in_ptr = d_in_ptrs0 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 0 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs0 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs1 | |
| - d_in_ptr = d_in_ptrs1 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 1 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs1 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs2 | |
| - d_in_ptr = d_in_ptrs2 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 2 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs2 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs3 | |
| - d_in_ptr = d_in_ptrs3 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 3 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs3 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs4 | |
| - d_in_ptr = d_in_ptrs4 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 4 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs4 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs5 | |
| - d_in_ptr = d_in_ptrs5 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 5 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs5 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs6 | |
| - d_in_ptr = d_in_ptrs6 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 6 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs6 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs7 | |
| - d_in_ptr = d_in_ptrs7 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 7 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs7 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs8 | |
| - d_in_ptr = d_in_ptrs8 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 8 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs8 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs9 | |
| - d_in_ptr = d_in_ptrs9 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 9 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs9 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| -torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select_backward_kernel_20) | |
| - | |
| -@triton.jit | |
| -def group_index_select_backward_kernel_21(d_out_ptrs0, d_out_ptrs1, d_out_ptrs2, d_out_ptrs3, d_out_ptrs4, d_out_ptrs5, d_out_ptrs6, d_out_ptrs7, d_out_ptrs8, indices_group_ptrs0, indices_group_ptrs1, indices_group_ptrs2, indices_group_ptrs3, indices_group_ptrs4, indices_group_ptrs5, indices_group_ptrs6, indices_group_ptrs7, indices_group_ptrs8, d_in_ptrs0, d_in_ptrs1, d_in_ptrs2, d_in_ptrs3, d_in_ptrs4, d_in_ptrs5, d_in_ptrs6, d_in_ptrs7, d_in_ptrs8, single_d_out_ptr, d_out_stride_0: tl.constexpr, d_out_stride_1: tl.constexpr, d_in_stride_0: tl.constexpr, d_in_stride_1: tl.constexpr, indices_stride: tl.constexpr, D: tl.constexpr, O, O_BUCKET, BLOCK_D: tl.constexpr, BLOCK_O: tl.constexpr, SM_RATIO: tl.constexpr, FUSED_OUTPUT: tl.constexpr, USE_I64_PID: tl.constexpr) -> None: | |
| - pid_idx = tl.program_id(0) | |
| - if USE_I64_PID: | |
| - pid_idx = pid_idx.to(tl.int64) | |
| - num_pids = tl.num_programs(0) | |
| - if pid_idx * BLOCK_O >= O: | |
| - return | |
| - max_pid = tl.cdiv(O, BLOCK_O) | |
| - indices_ptr = indices_group_ptrs0 | |
| - d_in_ptr = d_in_ptrs0 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 0 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs0 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs1 | |
| - d_in_ptr = d_in_ptrs1 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 1 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs1 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs2 | |
| - d_in_ptr = d_in_ptrs2 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 2 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs2 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs3 | |
| - d_in_ptr = d_in_ptrs3 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 3 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs3 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs4 | |
| - d_in_ptr = d_in_ptrs4 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 4 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs4 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs5 | |
| - d_in_ptr = d_in_ptrs5 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 5 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs5 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs6 | |
| - d_in_ptr = d_in_ptrs6 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 6 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs6 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs7 | |
| - d_in_ptr = d_in_ptrs7 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 7 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs7 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs8 | |
| - d_in_ptr = d_in_ptrs8 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 8 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs8 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| -torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select_backward_kernel_21) | |
| +)torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select_backward_kernel) | |
| +torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select_backward_kernel) | |
| @triton.autotune( | |
| configs=[ | |
| triton.Config( | |
| @@ -3478,284 +741,8 @@ | |
| ) | |
| ], | |
| key=[] | |
| -) | |
| -@triton.jit | |
| -def group_index_select_backward_kernel_22(d_out_ptrs0, d_out_ptrs1, d_out_ptrs2, d_out_ptrs3, d_out_ptrs4, d_out_ptrs5, d_out_ptrs6, d_out_ptrs7, d_out_ptrs8, indices_group_ptrs0, indices_group_ptrs1, indices_group_ptrs2, indices_group_ptrs3, indices_group_ptrs4, indices_group_ptrs5, indices_group_ptrs6, indices_group_ptrs7, indices_group_ptrs8, d_in_ptrs0, d_in_ptrs1, d_in_ptrs2, d_in_ptrs3, d_in_ptrs4, d_in_ptrs5, d_in_ptrs6, d_in_ptrs7, d_in_ptrs8, single_d_out_ptr, d_out_stride_0: tl.constexpr, d_out_stride_1: tl.constexpr, d_in_stride_0: tl.constexpr, d_in_stride_1: tl.constexpr, indices_stride: tl.constexpr, D: tl.constexpr, O, O_BUCKET, BLOCK_D: tl.constexpr, BLOCK_O: tl.constexpr, SM_RATIO: tl.constexpr, FUSED_OUTPUT: tl.constexpr, USE_I64_PID: tl.constexpr) -> None: | |
| - pid_idx = tl.program_id(0) | |
| - if USE_I64_PID: | |
| - pid_idx = pid_idx.to(tl.int64) | |
| - num_pids = tl.num_programs(0) | |
| - if pid_idx * BLOCK_O >= O: | |
| - return | |
| - max_pid = tl.cdiv(O, BLOCK_O) | |
| - indices_ptr = indices_group_ptrs0 | |
| - d_in_ptr = d_in_ptrs0 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 0 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs0 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs1 | |
| - d_in_ptr = d_in_ptrs1 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 1 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs1 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs2 | |
| - d_in_ptr = d_in_ptrs2 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 2 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs2 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs3 | |
| - d_in_ptr = d_in_ptrs3 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 3 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs3 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs4 | |
| - d_in_ptr = d_in_ptrs4 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 4 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs4 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs5 | |
| - d_in_ptr = d_in_ptrs5 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 5 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs5 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs6 | |
| - d_in_ptr = d_in_ptrs6 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 6 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs6 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs7 | |
| - d_in_ptr = d_in_ptrs7 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 7 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs7 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs8 | |
| - d_in_ptr = d_in_ptrs8 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 8 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs8 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| -torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select_backward_kernel_22) | |
| - | |
| -@triton.jit | |
| -def group_index_select_backward_kernel_23(d_out_ptrs0, d_out_ptrs1, d_out_ptrs2, d_out_ptrs3, d_out_ptrs4, d_out_ptrs5, d_out_ptrs6, indices_group_ptrs0, indices_group_ptrs1, indices_group_ptrs2, indices_group_ptrs3, indices_group_ptrs4, indices_group_ptrs5, indices_group_ptrs6, d_in_ptrs0, d_in_ptrs1, d_in_ptrs2, d_in_ptrs3, d_in_ptrs4, d_in_ptrs5, d_in_ptrs6, single_d_out_ptr, d_out_stride_0: tl.constexpr, d_out_stride_1: tl.constexpr, d_in_stride_0: tl.constexpr, d_in_stride_1: tl.constexpr, indices_stride: tl.constexpr, D: tl.constexpr, O, O_BUCKET, BLOCK_D: tl.constexpr, BLOCK_O: tl.constexpr, SM_RATIO: tl.constexpr, FUSED_OUTPUT: tl.constexpr, USE_I64_PID: tl.constexpr) -> None: | |
| - pid_idx = tl.program_id(0) | |
| - if USE_I64_PID: | |
| - pid_idx = pid_idx.to(tl.int64) | |
| - num_pids = tl.num_programs(0) | |
| - if pid_idx * BLOCK_O >= O: | |
| - return | |
| - max_pid = tl.cdiv(O, BLOCK_O) | |
| - indices_ptr = indices_group_ptrs0 | |
| - d_in_ptr = d_in_ptrs0 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 0 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs0 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs1 | |
| - d_in_ptr = d_in_ptrs1 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 1 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs1 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs2 | |
| - d_in_ptr = d_in_ptrs2 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 2 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs2 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs3 | |
| - d_in_ptr = d_in_ptrs3 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 3 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs3 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs4 | |
| - d_in_ptr = d_in_ptrs4 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 4 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs4 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs5 | |
| - d_in_ptr = d_in_ptrs5 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 5 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs5 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs6 | |
| - d_in_ptr = d_in_ptrs6 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 6 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs6 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| -torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select_backward_kernel_23) | |
| +)torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select_backward_kernel) | |
| +torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select_backward_kernel) | |
| @triton.autotune( | |
| configs=[ | |
| triton.Config( | |
| @@ -3765,145 +752,17 @@ | |
| ) | |
| ], | |
| key=[] | |
| -) | |
| -@triton.jit | |
| -def group_index_select_backward_kernel_24(d_out_ptrs0, d_out_ptrs1, d_out_ptrs2, d_out_ptrs3, d_out_ptrs4, d_out_ptrs5, d_out_ptrs6, indices_group_ptrs0, indices_group_ptrs1, indices_group_ptrs2, indices_group_ptrs3, indices_group_ptrs4, indices_group_ptrs5, indices_group_ptrs6, d_in_ptrs0, d_in_ptrs1, d_in_ptrs2, d_in_ptrs3, d_in_ptrs4, d_in_ptrs5, d_in_ptrs6, single_d_out_ptr, d_out_stride_0: tl.constexpr, d_out_stride_1: tl.constexpr, d_in_stride_0: tl.constexpr, d_in_stride_1: tl.constexpr, indices_stride: tl.constexpr, D: tl.constexpr, O, O_BUCKET, BLOCK_D: tl.constexpr, BLOCK_O: tl.constexpr, SM_RATIO: tl.constexpr, FUSED_OUTPUT: tl.constexpr, USE_I64_PID: tl.constexpr) -> None: | |
| - pid_idx = tl.program_id(0) | |
| - if USE_I64_PID: | |
| - pid_idx = pid_idx.to(tl.int64) | |
| - num_pids = tl.num_programs(0) | |
| - if pid_idx * BLOCK_O >= O: | |
| - return | |
| - max_pid = tl.cdiv(O, BLOCK_O) | |
| - indices_ptr = indices_group_ptrs0 | |
| - d_in_ptr = d_in_ptrs0 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 0 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs0 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs1 | |
| - d_in_ptr = d_in_ptrs1 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 1 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs1 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs2 | |
| - d_in_ptr = d_in_ptrs2 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 2 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs2 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs3 | |
| - d_in_ptr = d_in_ptrs3 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 3 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs3 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs4 | |
| - d_in_ptr = d_in_ptrs4 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 4 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs4 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs5 | |
| - d_in_ptr = d_in_ptrs5 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 5 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs5 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| - indices_ptr = indices_group_ptrs6 | |
| - d_in_ptr = d_in_ptrs6 | |
| - for j in range(pid_idx, max_pid, num_pids): | |
| - o_offsets = j * BLOCK_O + tl.arange(0, BLOCK_O) | |
| - d_offsets = tl.arange(0, BLOCK_D) | |
| - if FUSED_OUTPUT: | |
| - d_out_ptr = single_d_out_ptr | |
| - d_offsets_output = 6 * D + d_offsets | |
| - else: | |
| - d_out_ptr = d_out_ptrs6 | |
| - d_offsets_output = d_offsets | |
| - indices = tl.load(indices_ptr + o_offsets * indices_stride, mask=o_offsets < O) | |
| - d_out_block_ptr = d_out_ptr + o_offsets[:, None] * d_out_stride_0 + d_offsets_output[None, :] * d_out_stride_1 | |
| - d_out = tl.load(d_out_block_ptr, mask=(d_offsets[None, :] < D) & (o_offsets[:, None] < O)) | |
| - d_in_block_ptr = d_in_ptr + indices[:, None] * d_in_stride_0 + d_offsets[None, :] * d_in_stride_1 | |
| - tl.atomic_add(d_in_block_ptr, d_out.to(d_in_ptr.dtype.element_ty), mask=d_offsets[None, :] < D, sem='relaxed') | |
| -torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select_backward_kernel_24) | |
| +)torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(group_index_select_backward_kernel) | |
| torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.constant_args={0: {}, 1: {'clip_limit': None}, 2: {}, 3: {'clip_limit': None}, 4: {}, 5: {'clip_limit': None}, 6: {}, 7: {'clip_limit': None}, 8: {'output_group_ptrs0': None, 'output_group_ptrs1': None, 'output_group_ptrs2': None, 'output_group_ptrs3': None, 'output_group_ptrs4': None, 'output_group_ptrs5': None, 'output_group_ptrs6': None}, 9: {'output_group_ptrs0': None, 'output_group_ptrs1': None, 'output_group_ptrs2': None, 'output_group_ptrs3': None, 'output_group_ptrs4': None, 'output_group_ptrs5': None, 'output_group_ptrs6': None, 'output_group_ptrs7': None, 'output_group_ptrs8': None}, 10: {'output_group_ptrs0': None, 'output_group_ptrs1': None, 'output_group_ptrs2': None, 'output_group_ptrs3': None, 'output_group_ptrs4': None, 'output_group_ptrs5': None, 'output_group_ptrs6': None, 'output_group_ptrs7': None, 'output_group_ptrs8': None, 'output_group_ptrs9': None}, 11: {'output_group_ptrs0': None, 'output_group_ptrs1': None, 'output_group_ptrs2': None, 'output_group_ptrs3': None, 'output_group_ptrs4': None, 'output_group_ptrs5': None}, 12: {'output_group_ptrs0': None, 'output_group_ptrs1': None, 'output_group_ptrs2': None, 'output_group_ptrs3': None, 'output_group_ptrs4': None, 'output_group_ptrs5': None, 'output_group_ptrs6': None, 'output_group_ptrs7': None}, 13: {'output_group_ptrs0': None, 'output_group_ptrs1': None, 'output_group_ptrs2': None, 'output_group_ptrs3': None, 'output_group_ptrs4': None, 'output_group_ptrs5': None, 'output_group_ptrs6': None, 'output_group_ptrs7': None, 'output_group_ptrs8': None, 'output_group_ptrs9': None, 'output_group_ptrs10': None, 'output_group_ptrs11': None, 'output_group_ptrs12': None, 'output_group_ptrs13': None, 'output_group_ptrs14': None, 'output_group_ptrs15': None, 'output_group_ptrs16': None, 'output_group_ptrs17': None, 'output_group_ptrs18': None, 'output_group_ptrs19': None}, 14: {'output_group_ptrs0': None}, 15: {'output_group_ptrs0': None, 
'output_group_ptrs1': None, 'output_group_ptrs2': None, 'output_group_ptrs3': None, 'output_group_ptrs4': None, 'output_group_ptrs5': None, 'output_group_ptrs6': None}, 16: {'output_group_ptrs0': None, 'output_group_ptrs1': None, 'output_group_ptrs2': None, 'output_group_ptrs3': None, 'output_group_ptrs4': None, 'output_group_ptrs5': None, 'output_group_ptrs6': None, 'output_group_ptrs7': None, 'output_group_ptrs8': None}, 17: {'output_group_ptrs0': None, 'output_group_ptrs1': None, 'output_group_ptrs2': None, 'output_group_ptrs3': None, 'output_group_ptrs4': None, 'output_group_ptrs5': None, 'output_group_ptrs6': None, 'output_group_ptrs7': None, 'output_group_ptrs8': None, 'output_group_ptrs9': None}, 18: {'output_group_ptrs0': None, 'output_group_ptrs1': None, 'output_group_ptrs2': None, 'output_group_ptrs3': None, 'output_group_ptrs4': None, 'output_group_ptrs5': None}, 19: {'output_group_ptrs0': None, 'output_group_ptrs1': None, 'output_group_ptrs2': None, 'output_group_ptrs3': None, 'output_group_ptrs4': None, 'output_group_ptrs5': None, 'output_group_ptrs6': None, 'output_group_ptrs7': None}, 20: {'output_group_ptrs0': None, 'output_group_ptrs1': None, 'output_group_ptrs2': None, 'output_group_ptrs3': None, 'output_group_ptrs4': None, 'output_group_ptrs5': None, 'output_group_ptrs6': None, 'output_group_ptrs7': None, 'output_group_ptrs8': None, 'output_group_ptrs9': None, 'output_group_ptrs10': None, 'output_group_ptrs11': None, 'output_group_ptrs12': None, 'output_group_ptrs13': None, 'output_group_ptrs14': None, 'output_group_ptrs15': None, 'output_group_ptrs16': None, 'output_group_ptrs17': None, 'output_group_ptrs18': None, 'output_group_ptrs19': None}, 21: {'output_group_ptrs0': None}, 22: {'output_group_ptrs0': None, 'output_group_ptrs1': None, 'output_group_ptrs2': None, 'output_group_ptrs3': None, 'output_group_ptrs4': None, 'output_group_ptrs5': None, 'output_group_ptrs6': None}, 23: {'output_group_ptrs0': None, 'output_group_ptrs1': None, 
'output_group_ptrs2': None, 'output_group_ptrs3': None, 'output_group_ptrs4': None, 'output_group_ptrs5': None, 'output_group_ptrs6': None, 'output_group_ptrs7': None, 'output_group_ptrs8': None}, 24: {'output_group_ptrs0': None, 'output_group_ptrs1': None, 'output_group_ptrs2': None, 'output_group_ptrs3': None, 'output_group_ptrs4': None, 'output_group_ptrs5': None, 'output_group_ptrs6': None, 'output_group_ptrs7': None, 'output_group_ptrs8': None, 'output_group_ptrs9': None}, 25: {'output_group_ptrs0': None, 'output_group_ptrs1': None, 'output_group_ptrs2': None, 'output_group_ptrs3': None, 'output_group_ptrs4': None, 'output_group_ptrs5': None}, 26: {'output_group_ptrs0': None, 'output_group_ptrs1': None, 'output_group_ptrs2': None, 'output_group_ptrs3': None, 'output_group_ptrs4': None, 'output_group_ptrs5': None, 'output_group_ptrs6': None, 'output_group_ptrs7': None}, 27: {'output_group_ptrs0': None, 'output_group_ptrs1': None, 'output_group_ptrs2': None, 'output_group_ptrs3': None, 'output_group_ptrs4': None, 'output_group_ptrs5': None, 'output_group_ptrs6': None, 'output_group_ptrs7': None, 'output_group_ptrs8': None, 'output_group_ptrs9': None, 'output_group_ptrs10': None, 'output_group_ptrs11': None, 'output_group_ptrs12': None, 'output_group_ptrs13': None, 'output_group_ptrs14': None, 'output_group_ptrs15': None, 'output_group_ptrs16': None, 'output_group_ptrs17': None, 'output_group_ptrs18': None, 'output_group_ptrs19': None}, 28: {'output_group_ptrs0': None}, 29: {'output_group_ptrs0': None, 'output_group_ptrs1': None, 'output_group_ptrs2': None, 'output_group_ptrs3': None, 'output_group_ptrs4': None, 'output_group_ptrs5': None, 'output_group_ptrs6': None}, 30: {'output_group_ptrs0': None, 'output_group_ptrs1': None, 'output_group_ptrs2': None, 'output_group_ptrs3': None, 'output_group_ptrs4': None, 'output_group_ptrs5': None, 'output_group_ptrs6': None, 'output_group_ptrs7': None, 'output_group_ptrs8': None}, 31: {'output_group_ptrs0': None, 
'output_group_ptrs1': None, 'output_group_ptrs2': None, 'output_group_ptrs3': None, 'output_group_ptrs4': None, 'output_group_ptrs5': None, 'output_group_ptrs6': None, 'output_group_ptrs7': None, 'output_group_ptrs8': None, 'output_group_ptrs9': None}, 32: {'output_group_ptrs0': None, 'output_group_ptrs1': None, 'output_group_ptrs2': None, 'output_group_ptrs3': None, 'output_group_ptrs4': None, 'output_group_ptrs5': None}, 33: {'output_group_ptrs0': None, 'output_group_ptrs1': None, 'output_group_ptrs2': None, 'output_group_ptrs3': None, 'output_group_ptrs4': None, 'output_group_ptrs5': None, 'output_group_ptrs6': None, 'output_group_ptrs7': None}, 34: {'output_group_ptrs0': None, 'output_group_ptrs1': None, 'output_group_ptrs2': None, 'output_group_ptrs3': None, 'output_group_ptrs4': None, 'output_group_ptrs5': None, 'output_group_ptrs6': None, 'output_group_ptrs7': None, 'output_group_ptrs8': None, 'output_group_ptrs9': None, 'output_group_ptrs10': None, 'output_group_ptrs11': None, 'output_group_ptrs12': None, 'output_group_ptrs13': None, 'output_group_ptrs14': None, 'output_group_ptrs15': None, 'output_group_ptrs16': None, 'output_group_ptrs17': None, 'output_group_ptrs18': None, 'output_group_ptrs19': None}, 35: {'output_group_ptrs0': None}, 36: {'d_out_ptrs0': None}, 37: {'d_out_ptrs0': None, 'd_out_ptrs1': None, 'd_out_ptrs2': None, 'd_out_ptrs3': None, 'd_out_ptrs4': None, 'd_out_ptrs5': None, 'd_out_ptrs6': None, 'd_out_ptrs7': None, 'd_out_ptrs8': None, 'd_out_ptrs9': None, 'd_out_ptrs10': None, 'd_out_ptrs11': None, 'd_out_ptrs12': None, 'd_out_ptrs13': None, 'd_out_ptrs14': None, 'd_out_ptrs15': None, 'd_out_ptrs16': None, 'd_out_ptrs17': None, 'd_out_ptrs18': None, 'd_out_ptrs19': None}, 38: {'d_out_ptrs0': None, 'd_out_ptrs1': None, 'd_out_ptrs2': None, 'd_out_ptrs3': None, 'd_out_ptrs4': None, 'd_out_ptrs5': None, 'd_out_ptrs6': None, 'd_out_ptrs7': None}, 39: {'d_out_ptrs0': None, 'd_out_ptrs1': None, 'd_out_ptrs2': None, 'd_out_ptrs3': None, 
'd_out_ptrs4': None, 'd_out_ptrs5': None}, 40: {'d_out_ptrs0': None, 'd_out_ptrs1': None, 'd_out_ptrs2': None, 'd_out_ptrs3': None, 'd_out_ptrs4': None, 'd_out_ptrs5': None, 'd_out_ptrs6': None, 'd_out_ptrs7': None, 'd_out_ptrs8': None, 'd_out_ptrs9': None}, 41: {'d_out_ptrs0': None, 'd_out_ptrs1': None, 'd_out_ptrs2': None, 'd_out_ptrs3': None, 'd_out_ptrs4': None, 'd_out_ptrs5': None, 'd_out_ptrs6': None, 'd_out_ptrs7': None, 'd_out_ptrs8': None}, 42: {'d_out_ptrs0': None, 'd_out_ptrs1': None, 'd_out_ptrs2': None, 'd_out_ptrs3': None, 'd_out_ptrs4': None, 'd_out_ptrs5': None, 'd_out_ptrs6': None}} | |
| from torch.nn import * | |
| - | |
| -def no_scale(a, b): | |
| - return a | |
| - | |
| -# torch.ops.fb.scale_gradient.default = no_scale | |
| - | |
| class Repro(torch.nn.Module): | |
| def __init__(self) -> None: | |
| super().__init__() | |
| - def forward(self, fp8_quant_pos_154_primals_1, fp8_quant_pos_155_primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30, primals_31, primals_32, primals_33, primals_34, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_46, primals_47, primals_48, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, primals_58, primals_59, primals_60, primals_61, primals_62, primals_63, primals_64, primals_65, primals_66, primals_67, primals_68, primals_69, primals_70, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_78, primals_79, primals_80, primals_81, primals_82, primals_83, primals_84, primals_85, primals_86, primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_94, primals_95, primals_96, primals_97, primals_98, primals_99, primals_100, primals_101, primals_102, primals_103, primals_104, primals_105, fp8_quant_pos_259_primals_110, fp8_quant_pos_260_primals_111, fp8_quant_pos_261_primals_112, primals_117, primals_118, primals_119, primals_120, primals_123, primals_124, primals_125, primals_126, primals_127, primals_129, primals_130, primals_131, primals_132, primals_133, primals_134, primals_135, primals_136, primals_137, primals_138, primals_139, primals_140, primals_141, primals_142, primals_143, primals_144, primals_145, primals_146, primals_147, primals_148, primals_149, primals_150, primals_151, primals_152, primals_153, primals_154, primals_155, primals_156, primals_157, primals_158, primals_159, primals_160, primals_161, primals_162, primals_163, primals_164, 
primals_165, primals_166, primals_167, primals_168, primals_169, primals_170, primals_171, primals_172, primals_173, primals_174, primals_175, primals_177, primals_179, primals_180, primals_181, primals_182, primals_183, primals_184, primals_185, primals_186, primals_187, primals_188, primals_189, primals_190, primals_191, primals_193, primals_195, primals_196, primals_197, primals_198, primals_199, primals_200, primals_201, primals_202, primals_203, primals_204, primals_205, primals_206, primals_207, primals_208, primals_209, primals_210, primals_211, primals_212, primals_213, primals_214, primals_215, primals_216, primals_217, primals_218, primals_219, primals_220, primals_221, primals_222, primals_223, primals_224, primals_225, primals_226, primals_227, primals_228, primals_229, primals_230, primals_231, primals_232, primals_233, primals_234, primals_235, primals_236, primals_237, primals_238, primals_239, primals_240, primals_241, primals_242, primals_243, primals_244, primals_245, primals_246, primals_247, primals_248, primals_249, primals_250, primals_251, primals_252, primals_253, primals_254, primals_255, primals_256, primals_257, primals_258, primals_259, primals_260, primals_261, primals_262, primals_263, primals_264, primals_265, primals_266, primals_267, primals_268, primals_269, primals_270, primals_271, primals_272, primals_273, primals_274, primals_275, primals_276, primals_277, primals_278, primals_279, primals_280, primals_281, primals_282, primals_283, primals_284, primals_285, primals_286, primals_287, primals_288, primals_289, primals_290, primals_291, primals_292, primals_293, primals_294, primals_295, primals_296, primals_297, primals_298, primals_299, primals_300, primals_301, primals_302, primals_303, primals_304, primals_305, primals_306, primals_307, primals_308, primals_309, primals_310, primals_311, primals_312, primals_313, primals_314, primals_315, primals_316, primals_317, primals_318, primals_319, primals_320, primals_321, 
primals_322, primals_323, primals_324, primals_325, primals_326, primals_327, primals_328, primals_329, primals_330, primals_331, primals_332, primals_333, primals_334, primals_335, primals_336, primals_337, primals_338, primals_339, primals_340, primals_341, primals_342, primals_343, primals_344, primals_345, primals_346, primals_347, primals_348, primals_349, primals_350, primals_351, primals_352, primals_353, primals_354, primals_355, primals_356, primals_357, primals_358, primals_359, primals_360, primals_361, primals_362, primals_363, primals_364, primals_365, primals_366, primals_367, primals_368, primals_369, primals_370, primals_371, primals_372, primals_373, primals_374, primals_375, primals_376, primals_377, primals_378, primals_379, primals_380, primals_381, primals_382, primals_383, primals_384, primals_385, primals_386, primals_387, primals_388, primals_389, primals_390, primals_391, primals_392, primals_393, primals_394, primals_395, primals_396, primals_397, primals_398, primals_399, primals_400, primals_401, primals_402, primals_403, primals_404, primals_405, primals_406, primals_407, primals_408, primals_409, primals_410, primals_411, primals_412, primals_413, primals_414, primals_415, primals_416, primals_417, primals_418, primals_419, primals_420, primals_422, primals_423, primals_425, primals_426, primals_427, primals_429, primals_431, primals_433, primals_434, primals_436, primals_438, primals_439, primals_440, primals_441, primals_448, primals_450, primals_454, primals_455, primals_457, primals_460, primals_461, primals_463, primals_468, primals_469, primals_471, primals_474, primals_476, primals_478, primals_480, primals_482, primals_483, primals_484, primals_486, primals_487, primals_489, primals_493, primals_494, primals_496, primals_499, primals_500, primals_502, primals_506, primals_508, primals_510, primals_511, primals_512, primals_513, primals_515, primals_518, primals_519, primals_521, primals_524, primals_525, primals_527, 
primals_532, primals_533, primals_535, primals_538, primals_539, primals_541, primals_543, primals_545, primals_548, primals_549, primals_551, primals_553, primals_555, primals_558, primals_559, primals_561, primals_563, primals_565, primals_568, primals_569, primals_571, primals_573, primals_575, primals_577, primals_579, primals_580, primals_581, primals_583, primals_584, primals_586, primals_587, primals_589, primals_591, primals_593, primals_596, primals_597, primals_599, primals_601, primals_603, primals_606, primals_607, primals_609, primals_612, primals_613, primals_615, cat, clone, relu, fp8_quant_pos_661_add_6, getitem_27, getitem_28, getitem_33, getitem_34, mm_9, mm_10, mm_13, view_47, addmm_8, view_57, addmm_13, view_67, cat_6, fp8_quant_pos_675_mul_40, getitem_89, getitem_90, getitem_95, getitem_96, mm_17, mm_18, view_92, mm_21, view_95, cat_7, fp8_quant_pos_686_mul_58, mm_25, mm_26, mm_29, view_123, cat_8, fp8_quant_pos_692_mul_76, mm_33, mm_34, view_148, mm_37, view_151, cat_9, fp8_quant_pos_699_mul_94, mm_41, mm_42, view_176, mm_45, view_179, sub_32, mul_115, mul_117, fp8_quant_pos_708_view_181, fp8_quant_pos_709_view_183, view_193, view_194, mm_53, mm_54, add_111, fp8_quant_pos_715_bmm_3, fp8_quant_pos_716_bmm_4, slice_19, add_117, mm_59, mm_60, rsqrt_54, mul_154, rsqrt_55, mul_156, mm_63, mm_64, mm_65, mm_66, mul_170, mul_171, view_225, view_227, fp8_quant_pos_733_cat_16, fp8_quant_pos_734_cat_17, mm_75, mm_76, mm_77, mm_78, fp8_quant_pos_739_bmm_7, expand_14, fp8_quant_pos_741_expand_15, slice_46, add_154, mm_81, mm_82, mm_83, mm_84, mm_85, mm_86, mm_87, mm_88, view_273, view_275, fp8_quant_pos_754_cat_21, fp8_quant_pos_755_cat_22, mm_93, mm_94, mm_95, mm_96, expand_20, _scaled_mm_8, fp8_quant_pos_762_convert_element_type_735, _scaled_mm_9, mm_99, mm_100, mm_101, mm_102, mm_103, mm_104, add_203, addmm_30, mm_107, addmm_31, addmm_32, convert_element_type_784, relu_1, addmm_34, relu_2, addmm_36, mul_330, convert_element_type_792, relu_3, addmm_38, 
convert_element_type_794, relu_4, addmm_40, convert_element_type_798, relu_5, addmm_42, convert_element_type_800, relu_6, addmm_44, relu_7, relu_8, convert_element_type_808, relu_9, addmm_50, convert_element_type_810, relu_10, addmm_52, relu_11, relu_12, le_10, convert_element_type_814, relu_13, addmm_57, relu_14, addmm_59, le_12, convert_element_type_818, relu_15, addmm_61, convert_element_type_822, relu_16, addmm_63, convert_element_type_824, relu_17, addmm_65, convert_element_type_829, relu_18, addmm_67, convert_element_type_834, relu_19, addmm_69, convert_element_type_839, relu_20, addmm_71, relu_21, relu_22, convert_element_type_848, relu_23, addmm_77, convert_element_type_853, relu_24, addmm_79, relu_25, addmm_81, le_21, relu_26, addmm_83, le_22, sub_136, sub_138, sub_160, sub_164, sub_170, sub_183, permute_709, permute_712, fp8_quant_pos_850_permute_714, permute_844, fp8_quant_pos_852_permute_845, permute_847, fp8_quant_pos_854_permute_848, fp8_quant_pos_855_permute_851, fp8_quant_pos_856_permute_853, div_144, div_150, div_156, div_162, fp8_scale_pos_154_primals_1, fp8_scale_pos_155_primals_2, fp8_scale_pos_259_primals_110, fp8_scale_pos_260_primals_111, fp8_scale_pos_261_primals_112, fp8_scale_pos_661_add_6, fp8_scale_pos_675_mul_40, fp8_scale_pos_686_mul_58, fp8_scale_pos_692_mul_76, fp8_scale_pos_699_mul_94, fp8_scale_pos_708_view_181, fp8_scale_pos_709_view_183, fp8_scale_pos_715_bmm_3, fp8_scale_pos_716_bmm_4, fp8_scale_pos_733_cat_16, fp8_scale_pos_734_cat_17, fp8_scale_pos_739_bmm_7, fp8_scale_pos_741_expand_15, fp8_scale_pos_754_cat_21, fp8_scale_pos_755_cat_22, fp8_scale_pos_762_convert_element_type_735, fp8_scale_pos_850_permute_714, fp8_scale_pos_852_permute_845, fp8_scale_pos_854_permute_848, fp8_scale_pos_855_permute_851, fp8_scale_pos_856_permute_853, tangents_1, tangents_2, tangents_3, tangents_4, tangents_5, tangents_6, tangents_7, tangents_8): | |
| + def forward(self, fp8_quant_pos_154_primals_1, fp8_quant_pos_155_primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30, primals_31, primals_32, primals_33, primals_34, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_46, primals_47, primals_48, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, primals_58, primals_59, primals_60, primals_61, primals_62, primals_63, primals_64, primals_65, primals_66, primals_67, primals_68, primals_69, primals_70, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_78, primals_79, primals_80, primals_81, primals_82, primals_83, primals_84, primals_85, primals_86, primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_94, primals_95, primals_96, primals_97, primals_98, primals_99, primals_100, primals_101, primals_102, primals_103, primals_104, primals_105, fp8_quant_pos_259_primals_110, fp8_quant_pos_260_primals_111, fp8_quant_pos_261_primals_112, primals_117, primals_118, primals_119, primals_120, primals_123, primals_124, primals_125, primals_126, primals_127, primals_128, primals_129, primals_130, primals_131, primals_132, primals_133, primals_134, primals_135, primals_136, primals_137, primals_138, primals_139, primals_140, primals_141, primals_142, primals_143, primals_144, primals_145, primals_146, primals_147, primals_148, primals_149, primals_150, primals_151, primals_152, primals_153, primals_154, primals_155, primals_156, primals_157, primals_158, primals_159, primals_160, primals_161, primals_162, primals_163, 
primals_164, primals_165, primals_166, primals_167, primals_168, primals_169, primals_170, primals_171, primals_172, primals_173, primals_174, primals_175, primals_177, primals_179, primals_180, primals_181, primals_182, primals_183, primals_184, primals_185, primals_186, primals_187, primals_188, primals_189, primals_190, primals_191, primals_193, primals_195, primals_196, primals_197, primals_198, primals_199, primals_200, primals_201, primals_202, primals_203, primals_204, primals_205, primals_206, primals_207, primals_208, primals_209, primals_210, primals_211, primals_212, primals_213, primals_214, primals_215, primals_216, primals_217, primals_218, primals_219, primals_220, primals_221, primals_222, primals_223, primals_224, primals_225, primals_226, primals_227, primals_228, primals_229, primals_230, primals_231, primals_232, primals_233, primals_234, primals_235, primals_236, primals_237, primals_238, primals_239, primals_240, primals_241, primals_242, primals_243, primals_244, primals_245, primals_246, primals_247, primals_248, primals_249, primals_250, primals_251, primals_252, primals_253, primals_254, primals_255, primals_256, primals_257, primals_258, primals_259, primals_260, primals_261, primals_262, primals_263, primals_264, primals_265, primals_266, primals_267, primals_268, primals_269, primals_270, primals_271, primals_272, primals_273, primals_274, primals_275, primals_276, primals_277, primals_278, primals_279, primals_280, primals_281, primals_282, primals_283, primals_284, primals_285, primals_286, primals_287, primals_288, primals_289, primals_290, primals_291, primals_292, primals_293, primals_294, primals_295, primals_296, primals_297, primals_298, primals_299, primals_300, primals_301, primals_302, primals_303, primals_304, primals_305, primals_306, primals_307, primals_308, primals_309, primals_310, primals_311, primals_312, primals_313, primals_314, primals_315, primals_316, primals_317, primals_318, primals_319, primals_320, 
primals_321, primals_322, primals_323, primals_324, primals_325, primals_326, primals_327, primals_328, primals_329, primals_330, primals_331, primals_332, primals_333, primals_334, primals_335, primals_336, primals_337, primals_338, primals_339, primals_340, primals_341, primals_342, primals_343, primals_344, primals_345, primals_346, primals_347, primals_348, primals_349, primals_350, primals_351, primals_352, primals_353, primals_354, primals_355, primals_356, primals_357, primals_358, primals_359, primals_360, primals_361, primals_362, primals_363, primals_364, primals_365, primals_366, primals_367, primals_368, primals_369, primals_370, primals_371, primals_372, primals_373, primals_374, primals_375, primals_376, primals_377, primals_378, primals_379, primals_380, primals_381, primals_382, primals_383, primals_384, primals_385, primals_386, primals_387, primals_388, primals_389, primals_390, primals_391, primals_392, primals_393, primals_394, primals_395, primals_396, primals_397, primals_398, primals_399, primals_400, primals_401, primals_402, primals_403, primals_404, primals_405, primals_406, primals_407, primals_408, primals_409, primals_410, primals_411, primals_412, primals_413, primals_414, primals_415, primals_416, primals_417, primals_418, primals_419, primals_420, primals_422, primals_423, primals_425, primals_426, primals_427, primals_429, primals_431, primals_433, primals_434, primals_436, primals_438, primals_439, primals_440, primals_441, primals_448, primals_450, primals_454, primals_455, primals_457, primals_460, primals_461, primals_463, primals_468, primals_469, primals_471, primals_474, primals_476, primals_478, primals_480, primals_482, primals_483, primals_484, primals_486, primals_487, primals_489, primals_493, primals_494, primals_496, primals_499, primals_500, primals_502, primals_506, primals_508, primals_510, primals_511, primals_512, primals_513, primals_515, primals_518, primals_519, primals_521, primals_524, primals_525, 
primals_527, primals_532, primals_533, primals_535, primals_538, primals_539, primals_541, primals_543, primals_545, primals_548, primals_549, primals_551, primals_553, primals_555, primals_558, primals_559, primals_561, primals_563, primals_565, primals_568, primals_569, primals_571, primals_573, primals_575, primals_577, primals_579, primals_580, primals_581, primals_583, primals_584, primals_586, primals_587, primals_589, primals_591, primals_593, primals_596, primals_597, primals_599, primals_601, primals_603, primals_606, primals_607, primals_609, primals_612, primals_613, primals_615, cat, clone, fp8_quant_pos_661_add_6, mm_9, mm_10, mm_13, view_47, convert_element_type_123, addmm_8, view_57, addmm_13, view_67, cat_6, fp8_quant_pos_672_mul_40, mm_17, mm_18, view_92, mm_21, view_95, cat_7, fp8_quant_pos_679_mul_58, mm_26, view_123, cat_8, fp8_quant_pos_683_mul_76, mm_34, view_148, view_151, cat_9, fp8_quant_pos_688_mul_94, mm_42, view_176, view_179, sub_32, mul_115, mul_117, fp8_quant_pos_695_view_181, fp8_quant_pos_696_view_183, view_193, view_194, mm_53, mm_54, add_111, fp8_quant_pos_702_bmm_3, fp8_quant_pos_703_bmm_4, slice_19, add_117, mm_59, mm_60, add_124, add_125, mm_63, mm_64, mm_65, mm_66, mul_170, mul_171, view_225, view_227, fp8_quant_pos_718_cat_16, fp8_quant_pos_719_cat_17, fp8_quant_pos_720_convert_element_type_545, fp8_quant_pos_721_convert_element_type_547, mm_75, mm_76, mm_77, mm_78, fp8_quant_pos_726_expand_8, expand_10, slice_46, add_154, mm_81, mm_82, mm_83, mm_84, mm_85, mm_86, mm_87, mm_88, view_273, view_275, fp8_quant_pos_740_cat_22, mm_93, mm_94, mm_96, fp8_quant_pos_744_bmm_11, bmm_12, _scaled_mm_8, _scaled_mm_9, mm_99, mm_100, mm_101, mm_102, mm_103, mm_104, add_205, convert_element_type_784, addmm_30, mm_107, addmm_31, addmm_32, convert_element_type_790, relu_1, addmm_34, relu_2, addmm_36, mul_330, convert_element_type_798, relu_3, addmm_38, convert_element_type_800, relu_4, addmm_40, convert_element_type_804, relu_5, addmm_42, 
convert_element_type_806, relu_6, addmm_44, relu_7, relu_8, convert_element_type_814, relu_9, addmm_50, convert_element_type_816, relu_10, addmm_52, relu_11, relu_12, le_10, convert_element_type_820, relu_13, addmm_57, relu_14, addmm_59, le_12, convert_element_type_824, relu_15, addmm_61, convert_element_type_828, relu_16, addmm_63, convert_element_type_830, relu_17, addmm_65, convert_element_type_835, relu_18, addmm_67, convert_element_type_840, relu_19, addmm_69, convert_element_type_845, relu_20, addmm_71, relu_21, relu_22, convert_element_type_854, relu_23, addmm_77, convert_element_type_859, relu_24, addmm_79, relu_25, addmm_81, le_21, relu_26, addmm_83, le_22, sub_136, sub_138, sub_160, sub_164, sub_170, sub_183, permute_844, fp8_quant_pos_834_permute_845, permute_847, fp8_quant_pos_836_permute_848, fp8_quant_pos_837_permute_851, fp8_quant_pos_838_permute_853, div_144, div_150, div_156, div_162, fp8_scale_pos_154_primals_1, fp8_scale_pos_155_primals_2, fp8_scale_pos_259_primals_110, fp8_scale_pos_260_primals_111, fp8_scale_pos_261_primals_112, fp8_scale_pos_661_add_6, fp8_scale_pos_672_mul_40, fp8_scale_pos_679_mul_58, fp8_scale_pos_683_mul_76, fp8_scale_pos_688_mul_94, fp8_scale_pos_695_view_181, fp8_scale_pos_696_view_183, fp8_scale_pos_702_bmm_3, fp8_scale_pos_703_bmm_4, fp8_scale_pos_718_cat_16, fp8_scale_pos_719_cat_17, fp8_scale_pos_720_convert_element_type_545, fp8_scale_pos_721_convert_element_type_547, fp8_scale_pos_726_expand_8, fp8_scale_pos_740_cat_22, fp8_scale_pos_744_bmm_11, fp8_scale_pos_834_permute_845, fp8_scale_pos_836_permute_848, fp8_scale_pos_837_permute_851, fp8_scale_pos_838_permute_853, tangents_1, tangents_2, tangents_3, tangents_4, tangents_5, tangents_6, tangents_7, tangents_8): | |
| slice_83 = torch.ops.aten.slice.Tensor(tangents_3, 0, 0, 29) | |
| slice_84 = torch.ops.aten.slice.Tensor(tangents_3, 0, 29, 30); tangents_3 = None | |
| add_267 = torch.ops.aten.add.Tensor(tangents_5, slice_84); tangents_5 = slice_84 = None | |
| @@ -4004,8 +863,8 @@ | |
| div_39 = torch.ops.aten.div.Scalar(expand_33, 4096); expand_33 = None | |
| full_default_21 = torch.ops.aten.full.default([], 1.0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| full_default_61 = torch.ops.aten.full.default([], 0.800000011920929, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| - convert_element_type_861 = torch.ops.prims.convert_element_type.default(primals_612, torch.float32); primals_612 = None | |
| - mul_479 = torch.ops.aten.mul.Tensor(primals_431, convert_element_type_861); convert_element_type_861 = None | |
| + convert_element_type_867 = torch.ops.prims.convert_element_type.default(primals_612, torch.float32); primals_612 = None | |
| + mul_479 = torch.ops.aten.mul.Tensor(primals_431, convert_element_type_867); convert_element_type_867 = None | |
| where_87 = torch.ops.aten.where.self(le_22, full_default_61, full_default_21); le_22 = None | |
| mul_481 = torch.ops.aten.mul.Tensor(where_87, 1.0); where_87 = None | |
| mul_482 = torch.ops.aten.mul.Tensor(mul_479, mul_481); mul_479 = mul_481 = None | |
| @@ -4036,40 +895,7 @@ | |
| permute_326 = torch.ops.aten.permute.default(permute_320, [1, 0]); permute_320 = None | |
| mm_111 = torch.ops.aten.mm.default(where_91, permute_326); permute_326 = None | |
| permute_327 = torch.ops.aten.permute.default(where_91, [1, 0]) | |
| - slice_74 = torch.ops.aten.slice.Tensor(_scaled_mm_8, 1, 3840, 9223372036854775807) | |
| - index_40 = torch.ops.aten.index.Tensor(slice_74, [sub_32]); slice_74 = None | |
| - add_189 = torch.ops.aten.add.Tensor(_scaled_mm_9, index_40); _scaled_mm_9 = index_40 = None | |
| - convert_element_type_744 = torch.ops.prims.convert_element_type.default(add_189, torch.float32); add_189 = None | |
| - pow_67 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_744, 2) | |
| - mean_66 = torch.ops.aten.mean.dim(pow_67, [1], True); pow_67 = None | |
| - add_191 = torch.ops.aten.add.Scalar(mean_66, 1.1920928955078125e-07); mean_66 = None | |
| - rsqrt_87 = torch.ops.aten.rsqrt.default(add_191); add_191 = None | |
| - mul_276 = torch.ops.aten.mul.Tensor(convert_element_type_744, rsqrt_87); convert_element_type_744 = None | |
| - mul_277 = torch.ops.aten.mul.Tensor(mul_276, primals_406) | |
| - sigmoid_40 = torch.ops.aten.sigmoid.default(mul_277) | |
| - mul_279 = torch.ops.aten.mul.Tensor(mul_277, sigmoid_40) | |
| - slice_78 = torch.ops.aten.slice.Tensor(mm_101, 1, 3840, 9223372036854775807) | |
| - index_42 = torch.ops.aten.index.Tensor(slice_78, [sub_32]); slice_78 = None | |
| - add_195 = torch.ops.aten.add.Tensor(mm_102, index_42); mm_102 = index_42 = None | |
| - add_197 = torch.ops.aten.add.Tensor(mul_279, add_195); add_195 = None | |
| - pow_71 = torch.ops.aten.pow.Tensor_Scalar(add_197, 2) | |
| - mean_70 = torch.ops.aten.mean.dim(pow_71, [1], True); pow_71 = None | |
| - add_199 = torch.ops.aten.add.Scalar(mean_70, 1.1920928955078125e-07); mean_70 = None | |
| - rsqrt_91 = torch.ops.aten.rsqrt.default(add_199); add_199 = None | |
| - mul_288 = torch.ops.aten.mul.Tensor(add_197, rsqrt_91); add_197 = None | |
| - mul_289 = torch.ops.aten.mul.Tensor(mul_288, primals_414) | |
| - sigmoid_44 = torch.ops.aten.sigmoid.default(mul_289) | |
| - mul_291 = torch.ops.aten.mul.Tensor(mul_289, sigmoid_44) | |
| - add_205 = torch.ops.aten.add.Tensor(mul_291, add_203); add_203 = None | |
| - pow_75 = torch.ops.aten.pow.Tensor_Scalar(add_205, 2) | |
| - mean_74 = torch.ops.aten.mean.dim(pow_75, [1], True); pow_75 = None | |
| - add_207 = torch.ops.aten.add.Scalar(mean_74, 1.1920928955078125e-07); mean_74 = None | |
| - rsqrt_95 = torch.ops.aten.rsqrt.default(add_207); add_207 = None | |
| - mul_300 = torch.ops.aten.mul.Tensor(add_205, rsqrt_95); add_205 = None | |
| - mul_301 = torch.ops.aten.mul.Tensor(mul_300, primals_422) | |
| - sigmoid_48 = torch.ops.aten.sigmoid.default(mul_301) | |
| - mul_303 = torch.ops.aten.mul.Tensor(mul_301, sigmoid_48) | |
| - scale_gradient_25 = no_scale(mul_303, 2.0) | |
| + scale_gradient_25 = torch.ops.fb.scale_gradient.default(convert_element_type_784, 2.0) | |
| mm_112 = torch.ops.aten.mm.default(permute_327, scale_gradient_25); permute_327 = None | |
| sum_30 = torch.ops.aten.sum.dim_IntList(where_91, [0], True); where_91 = None | |
| view_349 = torch.ops.aten.view.default(sum_30, [512]); sum_30 = None | |
| @@ -4077,8 +903,8 @@ | |
| mul_490 = torch.ops.aten.mul.Tensor(mm_111, full_default_161); mm_111 = None | |
| expand_34 = torch.ops.aten.expand.default(select_221, [4096]); select_221 = None | |
| div_40 = torch.ops.aten.div.Scalar(expand_34, 4096); expand_34 = None | |
| - convert_element_type_859 = torch.ops.prims.convert_element_type.default(primals_606, torch.float32); primals_606 = None | |
| - mul_474 = torch.ops.aten.mul.Tensor(primals_431, convert_element_type_859); convert_element_type_859 = None | |
| + convert_element_type_865 = torch.ops.prims.convert_element_type.default(primals_606, torch.float32); primals_606 = None | |
| + mul_474 = torch.ops.aten.mul.Tensor(primals_431, convert_element_type_865); convert_element_type_865 = None | |
| where_84 = torch.ops.aten.where.self(le_21, full_default_61, full_default_21); le_21 = None | |
| mul_476 = torch.ops.aten.mul.Tensor(where_84, 1.0); where_84 = None | |
| mul_477 = torch.ops.aten.mul.Tensor(mul_474, mul_476); mul_474 = mul_476 = None | |
| @@ -4115,9 +941,9 @@ | |
| add_289 = torch.ops.aten.add.Tensor(mul_490, mul_495); mul_490 = mul_495 = None | |
| full_default_22 = torch.ops.aten.full.default([], 0.10000000149011612, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| full_default_25 = torch.ops.aten.full.default([], 1, dtype = torch.int64, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| - convert_element_type_854 = torch.ops.prims.convert_element_type.default(primals_596, torch.float32); primals_596 = None | |
| - mul_465 = torch.ops.aten.mul.Tensor(primals_431, convert_element_type_854); convert_element_type_854 = None | |
| - le_20 = torch.ops.aten.le.Scalar(convert_element_type_853, 0.5) | |
| + convert_element_type_860 = torch.ops.prims.convert_element_type.default(primals_596, torch.float32); primals_596 = None | |
| + mul_465 = torch.ops.aten.mul.Tensor(primals_431, convert_element_type_860); convert_element_type_860 = None | |
| + le_20 = torch.ops.aten.le.Scalar(convert_element_type_859, 0.5) | |
| where_79 = torch.ops.aten.where.self(le_20, full_default_22, full_default_21); le_20 = None | |
| mul_470 = torch.ops.aten.mul.Tensor(where_79, mul_465); where_79 = mul_465 = None | |
| mul_471 = torch.ops.aten.mul.Tensor(mul_470, 1.0); mul_470 = None | |
| @@ -4135,7 +961,7 @@ | |
| clamp_min_71 = torch.ops.aten.clamp_min.default(addmm_79, -5.1) | |
| clamp_max_47 = torch.ops.aten.clamp_max.default(clamp_min_71, 5.1); clamp_min_71 = None | |
| sigmoid_87 = torch.ops.aten.sigmoid.default(clamp_max_47); clamp_max_47 = None | |
| - unsqueeze_51 = torch.ops.aten.unsqueeze.default(convert_element_type_853, 1); convert_element_type_853 = None | |
| + unsqueeze_51 = torch.ops.aten.unsqueeze.default(convert_element_type_859, 1); convert_element_type_859 = None | |
| sub_140 = torch.ops.aten.sub.Tensor(sigmoid_87, unsqueeze_51); sigmoid_87 = unsqueeze_51 = None | |
| mul_498 = torch.ops.aten.mul.Tensor(sub_140, unsqueeze_84); sub_140 = unsqueeze_84 = None | |
| ge_2 = torch.ops.aten.ge.Scalar(addmm_79, -5.1) | |
| @@ -4172,7 +998,7 @@ | |
| permute_342 = torch.ops.aten.permute.default(permute_316, [1, 0]); permute_316 = None | |
| mm_119 = torch.ops.aten.mm.default(where_97, permute_342); permute_342 = None | |
| permute_343 = torch.ops.aten.permute.default(where_97, [1, 0]) | |
| - scale_gradient_24 = no_scale(mul_303, 0.3) | |
| + scale_gradient_24 = torch.ops.fb.scale_gradient.default(convert_element_type_784, 0.3) | |
| mm_120 = torch.ops.aten.mm.default(permute_343, scale_gradient_24); permute_343 = scale_gradient_24 = None | |
| sum_34 = torch.ops.aten.sum.dim_IntList(where_97, [0], True); where_97 = None | |
| view_355 = torch.ops.aten.view.default(sum_34, [512]); sum_34 = None | |
| @@ -4183,9 +1009,9 @@ | |
| div_42 = torch.ops.aten.div.Scalar(expand_36, 4096); expand_36 = None | |
| full_default_16 = torch.ops.aten.full.default([], 3.0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| full_default_17 = torch.ops.aten.full.default([], 0.30000001192092896, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| - convert_element_type_849 = torch.ops.prims.convert_element_type.default(primals_586, torch.float32); primals_586 = None | |
| - mul_457 = torch.ops.aten.mul.Tensor(primals_431, convert_element_type_849); convert_element_type_849 = None | |
| - le_19 = torch.ops.aten.le.Scalar(convert_element_type_848, 0.5) | |
| + convert_element_type_855 = torch.ops.prims.convert_element_type.default(primals_586, torch.float32); primals_586 = None | |
| + mul_457 = torch.ops.aten.mul.Tensor(primals_431, convert_element_type_855); convert_element_type_855 = None | |
| + le_19 = torch.ops.aten.le.Scalar(convert_element_type_854, 0.5) | |
| where_75 = torch.ops.aten.where.self(le_19, full_default_17, full_default_16); le_19 = None | |
| mul_462 = torch.ops.aten.mul.Tensor(where_75, 1.0); where_75 = None | |
| mul_463 = torch.ops.aten.mul.Tensor(mul_457, mul_462); mul_457 = mul_462 = None | |
| @@ -4194,7 +1020,7 @@ | |
| clamp_min_70 = torch.ops.aten.clamp_min.default(addmm_77, -5.1) | |
| clamp_max_46 = torch.ops.aten.clamp_max.default(clamp_min_70, 5.1); clamp_min_70 = None | |
| sigmoid_88 = torch.ops.aten.sigmoid.default(clamp_max_46); clamp_max_46 = None | |
| - unsqueeze_50 = torch.ops.aten.unsqueeze.default(convert_element_type_848, 1); convert_element_type_848 = None | |
| + unsqueeze_50 = torch.ops.aten.unsqueeze.default(convert_element_type_854, 1); convert_element_type_854 = None | |
| sub_142 = torch.ops.aten.sub.Tensor(sigmoid_88, unsqueeze_50); sigmoid_88 = unsqueeze_50 = None | |
| mul_504 = torch.ops.aten.mul.Tensor(sub_142, unsqueeze_85); sub_142 = unsqueeze_85 = None | |
| ge_3 = torch.ops.aten.ge.Scalar(addmm_77, -5.1) | |
| @@ -4231,7 +1057,7 @@ | |
| permute_350 = torch.ops.aten.permute.default(permute_314, [1, 0]); permute_314 = None | |
| mm_123 = torch.ops.aten.mm.default(where_101, permute_350); permute_350 = None | |
| permute_351 = torch.ops.aten.permute.default(where_101, [1, 0]) | |
| - scale_gradient_23 = no_scale(mul_303, 0.325) | |
| + scale_gradient_23 = torch.ops.fb.scale_gradient.default(convert_element_type_784, 0.325) | |
| mm_124 = torch.ops.aten.mm.default(permute_351, scale_gradient_23); permute_351 = scale_gradient_23 = None | |
| sum_36 = torch.ops.aten.sum.dim_IntList(where_101, [0], True); where_101 = None | |
| view_359 = torch.ops.aten.view.default(sum_36, [512]); sum_36 = None | |
| @@ -4241,8 +1067,8 @@ | |
| expand_37 = torch.ops.aten.expand.default(select_218, [16777216]); select_218 = None | |
| div_43 = torch.ops.aten.div.Scalar(expand_37, 16777216); expand_37 = None | |
| mul_509 = torch.ops.aten.mul.Tensor(div_43, 10.0); div_43 = None | |
| - convert_element_type_815 = torch.ops.prims.convert_element_type.default(primals_512, torch.float32); primals_512 = None | |
| - mul_392 = torch.ops.aten.mul.Tensor(primals_431, convert_element_type_815); convert_element_type_815 = None | |
| + convert_element_type_821 = torch.ops.prims.convert_element_type.default(primals_512, torch.float32); primals_512 = None | |
| + mul_392 = torch.ops.aten.mul.Tensor(primals_431, convert_element_type_821); convert_element_type_821 = None | |
| unsqueeze_49 = torch.ops.aten.unsqueeze.default(mul_392, 1) | |
| expand_32 = torch.ops.aten.expand.default(unsqueeze_49, [4096, 4096]); unsqueeze_49 = None | |
| permute_313 = torch.ops.aten.permute.default(expand_32, [1, 0]) | |
| @@ -4260,7 +1086,7 @@ | |
| clamp_min_69 = torch.ops.aten.clamp_min.default(view_330, -5.1) | |
| clamp_max_45 = torch.ops.aten.clamp_max.default(clamp_min_69, 5.1); clamp_min_69 = None | |
| sigmoid_89 = torch.ops.aten.sigmoid.default(clamp_max_45); clamp_max_45 = None | |
| - unsqueeze_36 = torch.ops.aten.unsqueeze.default(convert_element_type_814, 1) | |
| + unsqueeze_36 = torch.ops.aten.unsqueeze.default(convert_element_type_820, 1) | |
| expand_31 = torch.ops.aten.expand.default(unsqueeze_36, [4096, 4096]) | |
| permute_312 = torch.ops.aten.permute.default(expand_31, [1, 0]) | |
| sub_117 = torch.ops.aten.sub.Tensor(expand_31, permute_312); expand_31 = permute_312 = None | |
| @@ -4297,7 +1123,7 @@ | |
| permute_359 = torch.ops.aten.permute.default(permute_309, [1, 0]); permute_309 = None | |
| mm_127 = torch.ops.aten.mm.default(where_103, permute_359); permute_359 = None | |
| permute_360 = torch.ops.aten.permute.default(where_103, [1, 0]) | |
| - scale_gradient = no_scale(mul_303, 1.0) | |
| + scale_gradient = torch.ops.fb.scale_gradient.default(convert_element_type_784, 1.0) | |
| mm_128 = torch.ops.aten.mm.default(permute_360, scale_gradient); permute_360 = None | |
| sum_39 = torch.ops.aten.sum.dim_IntList(where_103, [0], True); where_103 = None | |
| view_363 = torch.ops.aten.view.default(sum_39, [32]); sum_39 = None | |
| @@ -4305,8 +1131,8 @@ | |
| expand_38 = torch.ops.aten.expand.default(select_217, [16777216]); select_217 = None | |
| div_44 = torch.ops.aten.div.Scalar(expand_38, 16777216); expand_38 = None | |
| mul_515 = torch.ops.aten.mul.Tensor(div_44, 10.0); div_44 = None | |
| - convert_element_type_801 = torch.ops.prims.convert_element_type.default(primals_474, torch.float32); primals_474 = None | |
| - mul_357 = torch.ops.aten.mul.Tensor(primals_431, convert_element_type_801); convert_element_type_801 = None | |
| + convert_element_type_807 = torch.ops.prims.convert_element_type.default(primals_474, torch.float32); primals_474 = None | |
| + mul_357 = torch.ops.aten.mul.Tensor(primals_431, convert_element_type_807); convert_element_type_807 = None | |
| unsqueeze_46 = torch.ops.aten.unsqueeze.default(mul_357, 1) | |
| expand_29 = torch.ops.aten.expand.default(unsqueeze_46, [4096, 4096]); unsqueeze_46 = None | |
| permute_308 = torch.ops.aten.permute.default(expand_29, [1, 0]) | |
| @@ -4324,7 +1150,7 @@ | |
| clamp_min_68 = torch.ops.aten.clamp_min.default(view_326, -5.1) | |
| clamp_max_44 = torch.ops.aten.clamp_max.default(clamp_min_68, 5.1); clamp_min_68 = None | |
| sigmoid_90 = torch.ops.aten.sigmoid.default(clamp_max_44); clamp_max_44 = None | |
| - unsqueeze_29 = torch.ops.aten.unsqueeze.default(convert_element_type_800, 1) | |
| + unsqueeze_29 = torch.ops.aten.unsqueeze.default(convert_element_type_806, 1) | |
| expand_28 = torch.ops.aten.expand.default(unsqueeze_29, [4096, 4096]) | |
| permute_307 = torch.ops.aten.permute.default(expand_28, [1, 0]) | |
| sub_112 = torch.ops.aten.sub.Tensor(expand_28, permute_307); expand_28 = permute_307 = None | |
| @@ -4367,9 +1193,9 @@ | |
| add_303 = torch.ops.aten.add.Tensor(add_300, mm_131); add_300 = mm_131 = None | |
| expand_39 = torch.ops.aten.expand.default(select_216, [4096]); select_216 = None | |
| div_45 = torch.ops.aten.div.Scalar(expand_39, 4096); expand_39 = None | |
| - convert_element_type_840 = torch.ops.prims.convert_element_type.default(primals_568, torch.float32); primals_568 = None | |
| - mul_441 = torch.ops.aten.mul.Tensor(primals_431, convert_element_type_840); convert_element_type_840 = None | |
| - le_18 = torch.ops.aten.le.Scalar(convert_element_type_839, 0.5) | |
| + convert_element_type_846 = torch.ops.prims.convert_element_type.default(primals_568, torch.float32); primals_568 = None | |
| + mul_441 = torch.ops.aten.mul.Tensor(primals_431, convert_element_type_846); convert_element_type_846 = None | |
| + le_18 = torch.ops.aten.le.Scalar(convert_element_type_845, 0.5) | |
| where_71 = torch.ops.aten.where.self(le_18, full_default_17, full_default_16); le_18 = None | |
| mul_446 = torch.ops.aten.mul.Tensor(where_71, 1.0); where_71 = None | |
| mul_447 = torch.ops.aten.mul.Tensor(mul_441, mul_446); mul_441 = mul_446 = None | |
| @@ -4378,7 +1204,7 @@ | |
| clamp_min_67 = torch.ops.aten.clamp_min.default(addmm_71, -5.1) | |
| clamp_max_43 = torch.ops.aten.clamp_max.default(clamp_min_67, 5.1); clamp_min_67 = None | |
| sigmoid_91 = torch.ops.aten.sigmoid.default(clamp_max_43); clamp_max_43 = None | |
| - unsqueeze_43 = torch.ops.aten.unsqueeze.default(convert_element_type_839, 1); convert_element_type_839 = None | |
| + unsqueeze_43 = torch.ops.aten.unsqueeze.default(convert_element_type_845, 1); convert_element_type_845 = None | |
| sub_148 = torch.ops.aten.sub.Tensor(sigmoid_91, unsqueeze_43); sigmoid_91 = unsqueeze_43 = None | |
| mul_522 = torch.ops.aten.mul.Tensor(sub_148, unsqueeze_86); sub_148 = unsqueeze_86 = None | |
| ge_6 = torch.ops.aten.ge.Scalar(addmm_71, -5.1) | |
| @@ -4415,7 +1241,7 @@ | |
| permute_376 = torch.ops.aten.permute.default(permute_302, [1, 0]); permute_302 = None | |
| mm_135 = torch.ops.aten.mm.default(where_109, permute_376); permute_376 = None | |
| permute_377 = torch.ops.aten.permute.default(where_109, [1, 0]) | |
| - scale_gradient_17 = no_scale(mul_303, 0.05) | |
| + scale_gradient_17 = torch.ops.fb.scale_gradient.default(convert_element_type_784, 0.05) | |
| mm_136 = torch.ops.aten.mm.default(permute_377, scale_gradient_17); permute_377 = None | |
| sum_44 = torch.ops.aten.sum.dim_IntList(where_109, [0], True); where_109 = None | |
| view_371 = torch.ops.aten.view.default(sum_44, [512]); sum_44 = None | |
| @@ -4424,9 +1250,9 @@ | |
| add_307 = torch.ops.aten.add.Tensor(add_303, mul_526); add_303 = mul_526 = None | |
| expand_40 = torch.ops.aten.expand.default(select_215, [4096]); select_215 = None | |
| div_46 = torch.ops.aten.div.Scalar(expand_40, 4096); expand_40 = None | |
| - convert_element_type_835 = torch.ops.prims.convert_element_type.default(primals_558, torch.float32); primals_558 = None | |
| - mul_433 = torch.ops.aten.mul.Tensor(primals_431, convert_element_type_835); convert_element_type_835 = None | |
| - le_17 = torch.ops.aten.le.Scalar(convert_element_type_834, 0.5) | |
| + convert_element_type_841 = torch.ops.prims.convert_element_type.default(primals_558, torch.float32); primals_558 = None | |
| + mul_433 = torch.ops.aten.mul.Tensor(primals_431, convert_element_type_841); convert_element_type_841 = None | |
| + le_17 = torch.ops.aten.le.Scalar(convert_element_type_840, 0.5) | |
| where_67 = torch.ops.aten.where.self(le_17, full_default_17, full_default_16); le_17 = None | |
| mul_438 = torch.ops.aten.mul.Tensor(where_67, 1.0); where_67 = None | |
| mul_439 = torch.ops.aten.mul.Tensor(mul_433, mul_438); mul_433 = mul_438 = None | |
| @@ -4435,7 +1261,7 @@ | |
| clamp_min_66 = torch.ops.aten.clamp_min.default(addmm_69, -5.1) | |
| clamp_max_42 = torch.ops.aten.clamp_max.default(clamp_min_66, 5.1); clamp_min_66 = None | |
| sigmoid_92 = torch.ops.aten.sigmoid.default(clamp_max_42); clamp_max_42 = None | |
| - unsqueeze_42 = torch.ops.aten.unsqueeze.default(convert_element_type_834, 1); convert_element_type_834 = None | |
| + unsqueeze_42 = torch.ops.aten.unsqueeze.default(convert_element_type_840, 1); convert_element_type_840 = None | |
| sub_150 = torch.ops.aten.sub.Tensor(sigmoid_92, unsqueeze_42); sigmoid_92 = unsqueeze_42 = None | |
| mul_528 = torch.ops.aten.mul.Tensor(sub_150, unsqueeze_87); sub_150 = unsqueeze_87 = None | |
| ge_7 = torch.ops.aten.ge.Scalar(addmm_69, -5.1) | |
| @@ -4479,9 +1305,9 @@ | |
| add_311 = torch.ops.aten.add.Tensor(add_307, mul_532); add_307 = mul_532 = None | |
| expand_41 = torch.ops.aten.expand.default(select_214, [4096]); select_214 = None | |
| div_47 = torch.ops.aten.div.Scalar(expand_41, 4096); expand_41 = None | |
| - convert_element_type_830 = torch.ops.prims.convert_element_type.default(primals_548, torch.float32); primals_548 = None | |
| - mul_425 = torch.ops.aten.mul.Tensor(primals_431, convert_element_type_830); convert_element_type_830 = None | |
| - le_16 = torch.ops.aten.le.Scalar(convert_element_type_829, 0.5) | |
| + convert_element_type_836 = torch.ops.prims.convert_element_type.default(primals_548, torch.float32); primals_548 = None | |
| + mul_425 = torch.ops.aten.mul.Tensor(primals_431, convert_element_type_836); convert_element_type_836 = None | |
| + le_16 = torch.ops.aten.le.Scalar(convert_element_type_835, 0.5) | |
| where_63 = torch.ops.aten.where.self(le_16, full_default_17, full_default_16); le_16 = None | |
| mul_430 = torch.ops.aten.mul.Tensor(where_63, 1.0); where_63 = None | |
| mul_431 = torch.ops.aten.mul.Tensor(mul_425, mul_430); mul_425 = mul_430 = None | |
| @@ -4490,7 +1316,7 @@ | |
| clamp_min_65 = torch.ops.aten.clamp_min.default(addmm_67, -5.1) | |
| clamp_max_41 = torch.ops.aten.clamp_max.default(clamp_min_65, 5.1); clamp_min_65 = None | |
| sigmoid_93 = torch.ops.aten.sigmoid.default(clamp_max_41); clamp_max_41 = None | |
| - unsqueeze_41 = torch.ops.aten.unsqueeze.default(convert_element_type_829, 1); convert_element_type_829 = None | |
| + unsqueeze_41 = torch.ops.aten.unsqueeze.default(convert_element_type_835, 1); convert_element_type_835 = None | |
| sub_152 = torch.ops.aten.sub.Tensor(sigmoid_93, unsqueeze_41); sigmoid_93 = unsqueeze_41 = None | |
| mul_534 = torch.ops.aten.mul.Tensor(sub_152, unsqueeze_88); sub_152 = unsqueeze_88 = None | |
| ge_8 = torch.ops.aten.ge.Scalar(addmm_67, -5.1) | |
| @@ -4534,9 +1360,9 @@ | |
| add_315 = torch.ops.aten.add.Tensor(add_311, mul_538); add_311 = mul_538 = None | |
| expand_42 = torch.ops.aten.expand.default(select_213, [4096]); select_213 = None | |
| div_48 = torch.ops.aten.div.Scalar(expand_42, 4096); expand_42 = None | |
| - convert_element_type_825 = torch.ops.prims.convert_element_type.default(primals_538, torch.float32); primals_538 = None | |
| - mul_417 = torch.ops.aten.mul.Tensor(primals_431, convert_element_type_825); convert_element_type_825 = None | |
| - le_15 = torch.ops.aten.le.Scalar(convert_element_type_824, 0.5) | |
| + convert_element_type_831 = torch.ops.prims.convert_element_type.default(primals_538, torch.float32); primals_538 = None | |
| + mul_417 = torch.ops.aten.mul.Tensor(primals_431, convert_element_type_831); convert_element_type_831 = None | |
| + le_15 = torch.ops.aten.le.Scalar(convert_element_type_830, 0.5) | |
| where_59 = torch.ops.aten.where.self(le_15, full_default_17, full_default_16); le_15 = None | |
| mul_422 = torch.ops.aten.mul.Tensor(where_59, 1.0); where_59 = None | |
| mul_423 = torch.ops.aten.mul.Tensor(mul_417, mul_422); mul_417 = mul_422 = None | |
| @@ -4545,7 +1371,7 @@ | |
| clamp_min_64 = torch.ops.aten.clamp_min.default(addmm_65, -5.1) | |
| clamp_max_40 = torch.ops.aten.clamp_max.default(clamp_min_64, 5.1); clamp_min_64 = None | |
| sigmoid_94 = torch.ops.aten.sigmoid.default(clamp_max_40); clamp_max_40 = None | |
| - unsqueeze_40 = torch.ops.aten.unsqueeze.default(convert_element_type_824, 1); convert_element_type_824 = None | |
| + unsqueeze_40 = torch.ops.aten.unsqueeze.default(convert_element_type_830, 1); convert_element_type_830 = None | |
| sub_154 = torch.ops.aten.sub.Tensor(sigmoid_94, unsqueeze_40); sigmoid_94 = unsqueeze_40 = None | |
| mul_540 = torch.ops.aten.mul.Tensor(sub_154, unsqueeze_89); sub_154 = unsqueeze_89 = None | |
| ge_9 = torch.ops.aten.ge.Scalar(addmm_65, -5.1) | |
| @@ -4587,9 +1413,9 @@ | |
| view_383 = torch.ops.aten.view.default(sum_50, [512]); sum_50 = None | |
| mul_544 = torch.ops.aten.mul.Tensor(mm_147, full_default_118); mm_147 = full_default_118 = None | |
| add_319 = torch.ops.aten.add.Tensor(add_315, mul_544); add_315 = mul_544 = None | |
| - convert_element_type_823 = torch.ops.prims.convert_element_type.default(primals_532, torch.float32); primals_532 = None | |
| - mul_411 = torch.ops.aten.mul.Tensor(primals_431, convert_element_type_823); convert_element_type_823 = None | |
| - le_14 = torch.ops.aten.le.Scalar(convert_element_type_822, 0.5) | |
| + convert_element_type_829 = torch.ops.prims.convert_element_type.default(primals_532, torch.float32); primals_532 = None | |
| + mul_411 = torch.ops.aten.mul.Tensor(primals_431, convert_element_type_829); convert_element_type_829 = None | |
| + le_14 = torch.ops.aten.le.Scalar(convert_element_type_828, 0.5) | |
| where_53 = torch.ops.aten.where.self(le_14, full_default_22, full_default_21); le_14 = None | |
| mul_413 = torch.ops.aten.mul.Tensor(where_53, mul_411); where_53 = mul_411 = None | |
| mul_414 = torch.ops.aten.mul.Tensor(mul_413, 1.0); mul_413 = None | |
| @@ -4607,7 +1433,7 @@ | |
| clamp_min_63 = torch.ops.aten.clamp_min.default(addmm_63, -3.1) | |
| clamp_max_39 = torch.ops.aten.clamp_max.default(clamp_min_63, 3.1); clamp_min_63 = None | |
| sigmoid_95 = torch.ops.aten.sigmoid.default(clamp_max_39); clamp_max_39 = None | |
| - unsqueeze_39 = torch.ops.aten.unsqueeze.default(convert_element_type_822, 1); convert_element_type_822 = None | |
| + unsqueeze_39 = torch.ops.aten.unsqueeze.default(convert_element_type_828, 1); convert_element_type_828 = None | |
| sub_156 = torch.ops.aten.sub.Tensor(sigmoid_95, unsqueeze_39); sigmoid_95 = unsqueeze_39 = None | |
| mul_547 = torch.ops.aten.mul.Tensor(sub_156, unsqueeze_90); sub_156 = unsqueeze_90 = None | |
| ge_10 = torch.ops.aten.ge.Scalar(addmm_63, -3.1) | |
| @@ -4633,16 +1459,16 @@ | |
| permute_408 = torch.ops.aten.permute.default(permute_294, [1, 0]); permute_294 = None | |
| mm_151 = torch.ops.aten.mm.default(where_123, permute_408); permute_408 = None | |
| permute_409 = torch.ops.aten.permute.default(where_123, [1, 0]) | |
| - scale_gradient_5 = no_scale(mul_303, 0.1) | |
| + scale_gradient_5 = torch.ops.fb.scale_gradient.default(convert_element_type_784, 0.1) | |
| mm_152 = torch.ops.aten.mm.default(permute_409, scale_gradient_5); permute_409 = None | |
| sum_52 = torch.ops.aten.sum.dim_IntList(where_123, [0], True); where_123 = None | |
| view_385 = torch.ops.aten.view.default(sum_52, [512]); sum_52 = None | |
| full_default_48 = torch.ops.aten.full.default([], 0.10000000149011612, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False) | |
| mul_550 = torch.ops.aten.mul.Tensor(mm_151, full_default_48); mm_151 = None | |
| add_321 = torch.ops.aten.add.Tensor(add_319, mul_550); add_319 = mul_550 = None | |
| - convert_element_type_819 = torch.ops.prims.convert_element_type.default(primals_524, torch.float32); primals_524 = None | |
| - mul_402 = torch.ops.aten.mul.Tensor(primals_431, convert_element_type_819); convert_element_type_819 = None | |
| - le_13 = torch.ops.aten.le.Scalar(convert_element_type_818, 0.5) | |
| + convert_element_type_825 = torch.ops.prims.convert_element_type.default(primals_524, torch.float32); primals_524 = None | |
| + mul_402 = torch.ops.aten.mul.Tensor(primals_431, convert_element_type_825); convert_element_type_825 = None | |
| + le_13 = torch.ops.aten.le.Scalar(convert_element_type_824, 0.5) | |
| where_48 = torch.ops.aten.where.self(le_13, full_default_21, full_default_21); le_13 = None | |
| mul_404 = torch.ops.aten.mul.Tensor(where_48, mul_402); where_48 = mul_402 = None | |
| mul_405 = torch.ops.aten.mul.Tensor(mul_404, 1.0); mul_404 = None | |
| @@ -4660,7 +1486,7 @@ | |
| clamp_min_62 = torch.ops.aten.clamp_min.default(addmm_61, -7.1) | |
| clamp_max_38 = torch.ops.aten.clamp_max.default(clamp_min_62, 7.1); clamp_min_62 = None | |
| sigmoid_96 = torch.ops.aten.sigmoid.default(clamp_max_38); clamp_max_38 = None | |
| - unsqueeze_38 = torch.ops.aten.unsqueeze.default(convert_element_type_818, 1); convert_element_type_818 = None | |
| + unsqueeze_38 = torch.ops.aten.unsqueeze.default(convert_element_type_824, 1); convert_element_type_824 = None | |
| sub_158 = torch.ops.aten.sub.Tensor(sigmoid_96, unsqueeze_38); sigmoid_96 = unsqueeze_38 = None | |
| mul_553 = torch.ops.aten.mul.Tensor(sub_158, unsqueeze_91); sub_158 = unsqueeze_91 = None | |
| ge_11 = torch.ops.aten.ge.Scalar(addmm_61, -7.1) | |
| @@ -4693,8 +1519,8 @@ | |
| add_323 = torch.ops.aten.add.Tensor(add_321, mul_556); add_321 = mul_556 = None | |
| expand_45 = torch.ops.aten.expand.default(select_207, [4096]); select_207 = None | |
| div_51 = torch.ops.aten.div.Scalar(expand_45, 4096); expand_45 = None | |
| - convert_element_type_817 = torch.ops.prims.convert_element_type.default(primals_518, torch.float32); primals_518 = None | |
| - mul_397 = torch.ops.aten.mul.Tensor(primals_431, convert_element_type_817); convert_element_type_817 = None | |
| + convert_element_type_823 = torch.ops.prims.convert_element_type.default(primals_518, torch.float32); primals_518 = None | |
| + mul_397 = torch.ops.aten.mul.Tensor(primals_431, convert_element_type_823); convert_element_type_823 = None | |
| where_47 = torch.ops.aten.where.self(le_12, full_default_61, full_default_21); le_12 = None | |
| mul_399 = torch.ops.aten.mul.Tensor(where_47, 1.0); where_47 = None | |
| mul_400 = torch.ops.aten.mul.Tensor(mul_397, mul_399); mul_397 = mul_399 = None | |
| @@ -4724,7 +1550,7 @@ | |
| permute_424 = torch.ops.aten.permute.default(permute_290, [1, 0]); permute_290 = None | |
| mm_159 = torch.ops.aten.mm.default(where_127, permute_424); permute_424 = None | |
| permute_425 = torch.ops.aten.permute.default(where_127, [1, 0]) | |
| - scale_gradient_13 = no_scale(mul_303, 1.5) | |
| + scale_gradient_13 = torch.ops.fb.scale_gradient.default(convert_element_type_784, 1.5) | |
| mm_160 = torch.ops.aten.mm.default(permute_425, scale_gradient_13); permute_425 = scale_gradient_13 = None | |
| sum_56 = torch.ops.aten.sum.dim_IntList(where_127, [0], True); where_127 = None | |
| view_389 = torch.ops.aten.view.default(sum_56, [512]); sum_56 = None | |
| @@ -4733,7 +1559,7 @@ | |
| add_325 = torch.ops.aten.add.Tensor(add_323, mul_561); add_323 = mul_561 = None | |
| expand_46 = torch.ops.aten.expand.default(select_206, [4096]); select_206 = None | |
| div_52 = torch.ops.aten.div.Scalar(expand_46, 4096); expand_46 = None | |
| - le_11 = torch.ops.aten.le.Scalar(convert_element_type_814, 0.5); convert_element_type_814 = None | |
| + le_11 = torch.ops.aten.le.Scalar(convert_element_type_820, 0.5); convert_element_type_820 = None | |
| where_44 = torch.ops.aten.where.self(le_11, full_default_61, full_default_21); le_11 = None | |
| mul_394 = torch.ops.aten.mul.Tensor(where_44, 1.0); where_44 = None | |
| mul_395 = torch.ops.aten.mul.Tensor(mul_392, mul_394); mul_392 = mul_394 = None | |
| @@ -4779,7 +1605,7 @@ | |
| permute_432 = torch.ops.aten.permute.default(permute_288, [1, 0]); permute_288 = None | |
| mm_163 = torch.ops.aten.mm.default(where_129, permute_432); permute_432 = None | |
| permute_433 = torch.ops.aten.permute.default(where_129, [1, 0]) | |
| - scale_gradient_6 = no_scale(mul_303, 4.0) | |
| + scale_gradient_6 = torch.ops.fb.scale_gradient.default(convert_element_type_784, 4.0) | |
| mm_164 = torch.ops.aten.mm.default(permute_433, scale_gradient_6); permute_433 = None | |
| sum_58 = torch.ops.aten.sum.dim_IntList(where_129, [0], True); where_129 = None | |
| view_391 = torch.ops.aten.view.default(sum_58, [512]); sum_58 = None | |
| @@ -4826,7 +1652,7 @@ | |
| permute_444 = torch.ops.aten.permute.default(permute_285, [1, 0]); permute_285 = None | |
| mm_169 = torch.ops.aten.mm.default(where_132, permute_444); permute_444 = None | |
| permute_445 = torch.ops.aten.permute.default(where_132, [1, 0]) | |
| - scale_gradient_10 = no_scale(mul_303, 3.0) | |
| + scale_gradient_10 = torch.ops.fb.scale_gradient.default(convert_element_type_784, 3.0) | |
| mm_170 = torch.ops.aten.mm.default(permute_445, scale_gradient_10); permute_445 = None | |
| sum_61 = torch.ops.aten.sum.dim_IntList(where_132, [0], True); where_132 = None | |
| view_394 = torch.ops.aten.view.default(sum_61, [256]); sum_61 = None | |
| @@ -4835,9 +1661,9 @@ | |
| add_329 = torch.ops.aten.add.Tensor(add_327, mul_571); add_327 = mul_571 = None | |
| expand_48 = torch.ops.aten.expand.default(select_204, [4096]); select_204 = None | |
| div_54 = torch.ops.aten.div.Scalar(expand_48, 4096); expand_48 = None | |
| - convert_element_type_811 = torch.ops.prims.convert_element_type.default(primals_499, torch.float32); primals_499 = None | |
| - mul_383 = torch.ops.aten.mul.Tensor(primals_431, convert_element_type_811); convert_element_type_811 = None | |
| - le_9 = torch.ops.aten.le.Scalar(convert_element_type_810, 0.5) | |
| + convert_element_type_817 = torch.ops.prims.convert_element_type.default(primals_499, torch.float32); primals_499 = None | |
| + mul_383 = torch.ops.aten.mul.Tensor(primals_431, convert_element_type_817); convert_element_type_817 = None | |
| + le_9 = torch.ops.aten.le.Scalar(convert_element_type_816, 0.5) | |
| full_default_80 = torch.ops.aten.full.default([], 0.5, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| where_38 = torch.ops.aten.where.self(le_9, full_default_80, full_default_21); le_9 = full_default_80 = None | |
| mul_385 = torch.ops.aten.mul.Tensor(where_38, 1.0); where_38 = None | |
| @@ -4847,7 +1673,7 @@ | |
| clamp_min_58 = torch.ops.aten.clamp_min.default(addmm_52, -5.1) | |
| clamp_max_34 = torch.ops.aten.clamp_max.default(clamp_min_58, 5.1); clamp_min_58 = None | |
| sigmoid_100 = torch.ops.aten.sigmoid.default(clamp_max_34); clamp_max_34 = None | |
| - unsqueeze_34 = torch.ops.aten.unsqueeze.default(convert_element_type_810, 1); convert_element_type_810 = None | |
| + unsqueeze_34 = torch.ops.aten.unsqueeze.default(convert_element_type_816, 1); convert_element_type_816 = None | |
| sub_166 = torch.ops.aten.sub.Tensor(sigmoid_100, unsqueeze_34); sigmoid_100 = unsqueeze_34 = None | |
| mul_573 = torch.ops.aten.mul.Tensor(sub_166, unsqueeze_95); sub_166 = unsqueeze_95 = None | |
| ge_15 = torch.ops.aten.ge.Scalar(addmm_52, -5.1) | |
| @@ -4878,9 +1704,9 @@ | |
| view_396 = torch.ops.aten.view.default(sum_63, [512]); sum_63 = None | |
| mul_576 = torch.ops.aten.mul.Tensor(mm_173, full_default_77); mm_173 = full_default_77 = None | |
| add_331 = torch.ops.aten.add.Tensor(add_329, mul_576); add_329 = mul_576 = None | |
| - convert_element_type_809 = torch.ops.prims.convert_element_type.default(primals_493, torch.float32); primals_493 = None | |
| - mul_377 = torch.ops.aten.mul.Tensor(primals_431, convert_element_type_809); convert_element_type_809 = None | |
| - le_8 = torch.ops.aten.le.Scalar(convert_element_type_808, 0.5) | |
| + convert_element_type_815 = torch.ops.prims.convert_element_type.default(primals_493, torch.float32); primals_493 = None | |
| + mul_377 = torch.ops.aten.mul.Tensor(primals_431, convert_element_type_815); convert_element_type_815 = None | |
| + le_8 = torch.ops.aten.le.Scalar(convert_element_type_814, 0.5) | |
| where_35 = torch.ops.aten.where.self(le_8, full_default_22, full_default_21); le_8 = None | |
| mul_379 = torch.ops.aten.mul.Tensor(where_35, mul_377); where_35 = mul_377 = None | |
| mul_380 = torch.ops.aten.mul.Tensor(mul_379, 1.0); mul_379 = None | |
| @@ -4898,7 +1724,7 @@ | |
| clamp_min_57 = torch.ops.aten.clamp_min.default(addmm_50, -3.1) | |
| clamp_max_33 = torch.ops.aten.clamp_max.default(clamp_min_57, 3.1); clamp_min_57 = None | |
| sigmoid_101 = torch.ops.aten.sigmoid.default(clamp_max_33); clamp_max_33 = None | |
| - unsqueeze_33 = torch.ops.aten.unsqueeze.default(convert_element_type_808, 1); convert_element_type_808 = None | |
| + unsqueeze_33 = torch.ops.aten.unsqueeze.default(convert_element_type_814, 1); convert_element_type_814 = None | |
| sub_168 = torch.ops.aten.sub.Tensor(sigmoid_101, unsqueeze_33); sigmoid_101 = unsqueeze_33 = None | |
| mul_579 = torch.ops.aten.mul.Tensor(sub_168, unsqueeze_96); sub_168 = unsqueeze_96 = None | |
| ge_16 = torch.ops.aten.ge.Scalar(addmm_50, -3.1) | |
| @@ -4931,8 +1757,8 @@ | |
| add_333 = torch.ops.aten.add.Tensor(add_331, mul_582); add_331 = mul_582 = None | |
| expand_50 = torch.ops.aten.expand.default(select_201, [4096]); select_201 = None | |
| div_56 = torch.ops.aten.div.Scalar(expand_50, 4096); expand_50 = None | |
| - convert_element_type_805 = torch.ops.prims.convert_element_type.default(primals_489, torch.float32); primals_489 = None | |
| - mul_370 = torch.ops.aten.mul.Tensor(primals_431, convert_element_type_805); convert_element_type_805 = None | |
| + convert_element_type_811 = torch.ops.prims.convert_element_type.default(primals_489, torch.float32); primals_489 = None | |
| + mul_370 = torch.ops.aten.mul.Tensor(primals_431, convert_element_type_811); convert_element_type_811 = None | |
| mul_373 = torch.ops.aten.mul.Tensor(mul_370, 1.0); mul_370 = None | |
| mul_583 = torch.ops.aten.mul.Tensor(div_56, mul_373); div_56 = mul_373 = None | |
| unsqueeze_97 = torch.ops.aten.unsqueeze.default(mul_583, 1); mul_583 = None | |
| @@ -4951,7 +1777,7 @@ | |
| sub_69 = torch.ops.aten.sub.Tensor(1, where_32) | |
| div_20 = torch.ops.aten.div.Tensor(where_32, sub_69); where_32 = sub_69 = None | |
| log_4 = torch.ops.aten.log.default(div_20); div_20 = None | |
| - scale_gradient_7 = no_scale(log_4, 0.0); log_4 = None | |
| + scale_gradient_7 = torch.ops.fb.scale_gradient.default(log_4, 0.0); log_4 = None | |
| add_227 = torch.ops.aten.add.Tensor(scale_gradient_7, 0.0); scale_gradient_7 = None | |
| sigmoid_57 = torch.ops.aten.sigmoid.default(add_227); add_227 = None | |
| sub_171 = torch.ops.aten.sub.Tensor(1, sigmoid_57) | |
| @@ -4979,7 +1805,7 @@ | |
| expand_52 = torch.ops.aten.expand.default(mul_592, [4096]); mul_592 = None | |
| div_59 = torch.ops.aten.div.Scalar(expand_52, 4096); expand_52 = None | |
| mul_594 = torch.ops.aten.mul.Tensor(div_59, mul_357); div_59 = None | |
| - unsqueeze_23 = torch.ops.aten.unsqueeze.default(convert_element_type_784, 1) | |
| + unsqueeze_23 = torch.ops.aten.unsqueeze.default(convert_element_type_790, 1) | |
| view_315 = torch.ops.aten.view.default(unsqueeze_23, [-1]) | |
| mul_595 = torch.ops.aten.mul.Tensor(mul_593, view_315); mul_593 = None | |
| unsqueeze_98 = torch.ops.aten.unsqueeze.default(mul_595, 1); mul_595 = None | |
| @@ -4988,7 +1814,7 @@ | |
| clamp_min_55 = torch.ops.aten.clamp_min.default(addmm_48, -5.1) | |
| clamp_max_31 = torch.ops.aten.clamp_max.default(clamp_min_55, 5.1); clamp_min_55 = None | |
| sigmoid_103 = torch.ops.aten.sigmoid.default(clamp_max_31); clamp_max_31 = None | |
| - view_314 = torch.ops.aten.view.default(convert_element_type_800, [-1, 1]) | |
| + view_314 = torch.ops.aten.view.default(convert_element_type_806, [-1, 1]) | |
| sub_61 = torch.ops.aten.sub.Tensor(1, view_314) | |
| sub_173 = torch.ops.aten.sub.Tensor(sigmoid_103, sub_61); sigmoid_103 = sub_61 = None | |
| mul_596 = torch.ops.aten.mul.Tensor(sub_173, unsqueeze_98); sub_173 = unsqueeze_98 = None | |
| @@ -5043,7 +1869,7 @@ | |
| add_337 = torch.ops.aten.add.Tensor(mm_181, mm_185); mm_181 = mm_185 = None | |
| expand_53 = torch.ops.aten.expand.default(select_200, [4096]); select_200 = None | |
| div_60 = torch.ops.aten.div.Scalar(expand_53, 4096); expand_53 = None | |
| - le_7 = torch.ops.aten.le.Scalar(convert_element_type_800, 0.5); convert_element_type_800 = None | |
| + le_7 = torch.ops.aten.le.Scalar(convert_element_type_806, 0.5); convert_element_type_806 = None | |
| where_30 = torch.ops.aten.where.self(le_7, full_default_61, full_default_21); le_7 = full_default_61 = None | |
| mul_359 = torch.ops.aten.mul.Tensor(where_30, 1.0); where_30 = None | |
| mul_360 = torch.ops.aten.mul.Tensor(mul_357, mul_359); mul_357 = mul_359 = None | |
| @@ -5092,9 +1918,9 @@ | |
| add_339 = torch.ops.aten.add.Tensor(add_337, mm_189); add_337 = mm_189 = None | |
| mul_603 = torch.ops.aten.mul.Tensor(add_339, full_default_55); add_339 = full_default_55 = None | |
| add_340 = torch.ops.aten.add.Tensor(add_333, mul_603); add_333 = mul_603 = None | |
| - convert_element_type_799 = torch.ops.prims.convert_element_type.default(primals_468, torch.float32); primals_468 = None | |
| - mul_351 = torch.ops.aten.mul.Tensor(primals_431, convert_element_type_799); convert_element_type_799 = None | |
| - le_6 = torch.ops.aten.le.Scalar(convert_element_type_798, 0.5) | |
| + convert_element_type_805 = torch.ops.prims.convert_element_type.default(primals_468, torch.float32); primals_468 = None | |
| + mul_351 = torch.ops.aten.mul.Tensor(primals_431, convert_element_type_805); convert_element_type_805 = None | |
| + le_6 = torch.ops.aten.le.Scalar(convert_element_type_804, 0.5) | |
| where_25 = torch.ops.aten.where.self(le_6, full_default_21, full_default_21); le_6 = None | |
| mul_353 = torch.ops.aten.mul.Tensor(where_25, mul_351); where_25 = mul_351 = None | |
| mul_354 = torch.ops.aten.mul.Tensor(mul_353, 1.0); mul_353 = None | |
| @@ -5112,7 +1938,7 @@ | |
| clamp_min_52 = torch.ops.aten.clamp_min.default(addmm_42, -5.1) | |
| clamp_max_28 = torch.ops.aten.clamp_max.default(clamp_min_52, 5.1); clamp_min_52 = None | |
| sigmoid_106 = torch.ops.aten.sigmoid.default(clamp_max_28); clamp_max_28 = None | |
| - unsqueeze_28 = torch.ops.aten.unsqueeze.default(convert_element_type_798, 1); convert_element_type_798 = None | |
| + unsqueeze_28 = torch.ops.aten.unsqueeze.default(convert_element_type_804, 1); convert_element_type_804 = None | |
| sub_177 = torch.ops.aten.sub.Tensor(sigmoid_106, unsqueeze_28); sigmoid_106 = unsqueeze_28 = None | |
| mul_606 = torch.ops.aten.mul.Tensor(sub_177, unsqueeze_101); sub_177 = unsqueeze_101 = None | |
| ge_21 = torch.ops.aten.ge.Scalar(addmm_42, -5.1) | |
| @@ -5141,9 +1967,9 @@ | |
| view_406 = torch.ops.aten.view.default(sum_73, [512]); sum_73 = None | |
| mul_609 = torch.ops.aten.mul.Tensor(mm_193, full_default_48); mm_193 = full_default_48 = None | |
| add_342 = torch.ops.aten.add.Tensor(add_340, mul_609); add_340 = mul_609 = None | |
| - convert_element_type_795 = torch.ops.prims.convert_element_type.default(primals_460, torch.float32); primals_460 = None | |
| - mul_339 = torch.ops.aten.mul.Tensor(primals_431, convert_element_type_795); convert_element_type_795 = None | |
| - le_4 = torch.ops.aten.le.Scalar(convert_element_type_794, 0.5) | |
| + convert_element_type_801 = torch.ops.prims.convert_element_type.default(primals_460, torch.float32); primals_460 = None | |
| + mul_339 = torch.ops.aten.mul.Tensor(primals_431, convert_element_type_801); convert_element_type_801 = None | |
| + le_4 = torch.ops.aten.le.Scalar(convert_element_type_800, 0.5) | |
| where_17 = torch.ops.aten.where.self(le_4, full_default_22, full_default_21); le_4 = None | |
| mul_341 = torch.ops.aten.mul.Tensor(where_17, mul_339); where_17 = mul_339 = None | |
| mul_342 = torch.ops.aten.mul.Tensor(mul_341, 1.0); mul_341 = None | |
| @@ -5161,7 +1987,7 @@ | |
| clamp_min_49 = torch.ops.aten.clamp_min.default(addmm_40, -5.1) | |
| clamp_max_25 = torch.ops.aten.clamp_max.default(clamp_min_49, 5.1); clamp_min_49 = None | |
| sigmoid_107 = torch.ops.aten.sigmoid.default(clamp_max_25); clamp_max_25 = None | |
| - unsqueeze_26 = torch.ops.aten.unsqueeze.default(convert_element_type_794, 1); convert_element_type_794 = None | |
| + unsqueeze_26 = torch.ops.aten.unsqueeze.default(convert_element_type_800, 1); convert_element_type_800 = None | |
| sub_179 = torch.ops.aten.sub.Tensor(sigmoid_107, unsqueeze_26); sigmoid_107 = unsqueeze_26 = None | |
| mul_612 = torch.ops.aten.mul.Tensor(sub_179, unsqueeze_102); sub_179 = unsqueeze_102 = None | |
| ge_22 = torch.ops.aten.ge.Scalar(addmm_40, -5.1) | |
| @@ -5187,16 +2013,16 @@ | |
| permute_500 = torch.ops.aten.permute.default(permute_271, [1, 0]); permute_271 = None | |
| mm_197 = torch.ops.aten.mm.default(where_147, permute_500); permute_500 = None | |
| permute_501 = torch.ops.aten.permute.default(where_147, [1, 0]) | |
| - scale_gradient_2 = no_scale(mul_303, 0.5) | |
| + scale_gradient_2 = torch.ops.fb.scale_gradient.default(convert_element_type_784, 0.5) | |
| mm_198 = torch.ops.aten.mm.default(permute_501, scale_gradient_2); permute_501 = None | |
| sum_75 = torch.ops.aten.sum.dim_IntList(where_147, [0], True); where_147 = None | |
| view_408 = torch.ops.aten.view.default(sum_75, [512]); sum_75 = None | |
| full_default_26 = torch.ops.aten.full.default([], 0.5, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False) | |
| mul_615 = torch.ops.aten.mul.Tensor(mm_197, full_default_26); mm_197 = None | |
| add_344 = torch.ops.aten.add.Tensor(add_342, mul_615); add_342 = mul_615 = None | |
| - convert_element_type_793 = torch.ops.prims.convert_element_type.default(primals_454, torch.float32); primals_454 = None | |
| - mul_333 = torch.ops.aten.mul.Tensor(primals_431, convert_element_type_793); convert_element_type_793 = None | |
| - le_3 = torch.ops.aten.le.Scalar(convert_element_type_792, 0.5) | |
| + convert_element_type_799 = torch.ops.prims.convert_element_type.default(primals_454, torch.float32); primals_454 = None | |
| + mul_333 = torch.ops.aten.mul.Tensor(primals_431, convert_element_type_799); convert_element_type_799 = None | |
| + le_3 = torch.ops.aten.le.Scalar(convert_element_type_798, 0.5) | |
| where_14 = torch.ops.aten.where.self(le_3, full_default_22, full_default_21); le_3 = full_default_22 = None | |
| mul_335 = torch.ops.aten.mul.Tensor(where_14, mul_333); where_14 = mul_333 = None | |
| mul_336 = torch.ops.aten.mul.Tensor(mul_335, 1.0); mul_335 = None | |
| @@ -5214,7 +2040,7 @@ | |
| clamp_min_48 = torch.ops.aten.clamp_min.default(addmm_38, -5.1) | |
| clamp_max_24 = torch.ops.aten.clamp_max.default(clamp_min_48, 5.1); clamp_min_48 = None | |
| sigmoid_108 = torch.ops.aten.sigmoid.default(clamp_max_24); clamp_max_24 = None | |
| - unsqueeze_25 = torch.ops.aten.unsqueeze.default(convert_element_type_792, 1); convert_element_type_792 = None | |
| + unsqueeze_25 = torch.ops.aten.unsqueeze.default(convert_element_type_798, 1); convert_element_type_798 = None | |
| sub_181 = torch.ops.aten.sub.Tensor(sigmoid_108, unsqueeze_25); sigmoid_108 = unsqueeze_25 = None | |
| mul_618 = torch.ops.aten.mul.Tensor(sub_181, unsqueeze_103); sub_181 = unsqueeze_103 = None | |
| ge_23 = torch.ops.aten.ge.Scalar(addmm_38, -5.1) | |
| @@ -5284,9 +2110,9 @@ | |
| add_348 = torch.ops.aten.add.Tensor(add_346, mm_205); add_346 = mm_205 = None | |
| expand_58 = torch.ops.aten.expand.default(select_196, [4096]); select_196 = None | |
| div_65 = torch.ops.aten.div.Scalar(expand_58, 4096); expand_58 = None | |
| - convert_element_type_785 = torch.ops.prims.convert_element_type.default(primals_433, torch.float32); primals_433 = None | |
| - mul_313 = torch.ops.aten.mul.Tensor(primals_431, convert_element_type_785); primals_431 = convert_element_type_785 = None | |
| - le_1 = torch.ops.aten.le.Scalar(convert_element_type_784, 0.5); convert_element_type_784 = None | |
| + convert_element_type_791 = torch.ops.prims.convert_element_type.default(primals_433, torch.float32); primals_433 = None | |
| + mul_313 = torch.ops.aten.mul.Tensor(primals_431, convert_element_type_791); primals_431 = convert_element_type_791 = None | |
| + le_1 = torch.ops.aten.le.Scalar(convert_element_type_790, 0.5); convert_element_type_790 = None | |
| where_10 = torch.ops.aten.where.self(le_1, full_default_17, full_default_16); le_1 = full_default_17 = full_default_16 = None | |
| mul_322 = torch.ops.aten.mul.Tensor(where_10, 1.0); where_10 = None | |
| mul_323 = torch.ops.aten.mul.Tensor(mul_313, mul_322); mul_313 = mul_322 = None | |
| @@ -5310,7 +2136,7 @@ | |
| gt = torch.ops.aten.gt.Scalar(index_45, 1000); index_45 = None | |
| where_153 = torch.ops.aten.where.self(gt, full_default_3, view_413) | |
| where_154 = torch.ops.aten.where.self(gt, view_413, full_default_3); gt = view_413 = None | |
| - convert_element_type_864 = torch.ops.prims.convert_element_type.default(where_154, torch.float64); where_154 = None | |
| + convert_element_type_870 = torch.ops.prims.convert_element_type.default(where_154, torch.float64); where_154 = None | |
| full_default_242 = torch.ops.aten.full.default([], 0.0, dtype = torch.float64, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| index_46 = torch.ops.aten.index.Tensor(primals_440, [bucketize]); primals_440 = None | |
| index_47 = torch.ops.aten.index.Tensor(primals_441, [bucketize]); primals_441 = bucketize = None | |
| @@ -5328,10 +2154,10 @@ | |
| ge_26 = torch.ops.aten.ge.Scalar(mul_314, 1e-09) | |
| le_74 = torch.ops.aten.le.Scalar(mul_314, 0.999999999); mul_314 = None | |
| logical_and_26 = torch.ops.aten.logical_and.default(ge_26, le_74); ge_26 = le_74 = None | |
| - where_155 = torch.ops.aten.where.self(logical_and_26, convert_element_type_864, full_default_242); logical_and_26 = convert_element_type_864 = full_default_242 = None | |
| + where_155 = torch.ops.aten.where.self(logical_and_26, convert_element_type_870, full_default_242); logical_and_26 = convert_element_type_870 = full_default_242 = None | |
| mul_630 = torch.ops.aten.mul.Tensor(where_155, clamp_max_20); where_155 = clamp_max_20 = None | |
| - convert_element_type_865 = torch.ops.prims.convert_element_type.default(mul_630, torch.float32); mul_630 = None | |
| - add_349 = torch.ops.aten.add.Tensor(where_153, convert_element_type_865); where_153 = convert_element_type_865 = None | |
| + convert_element_type_871 = torch.ops.prims.convert_element_type.default(mul_630, torch.float32); mul_630 = None | |
| + add_349 = torch.ops.aten.add.Tensor(where_153, convert_element_type_871); where_153 = convert_element_type_871 = None | |
| view_414 = torch.ops.aten.view.default(add_349, [4096, 1]); add_349 = None | |
| add_350 = torch.ops.aten.add.Tensor(select_140, view_414); select_140 = view_414 = None | |
| sub_186 = torch.ops.aten.sub.Tensor(1, sigmoid_49) | |
| @@ -5361,8 +2187,8 @@ | |
| iota_2 = torch.ops.prims.iota.default(4096, start = 0, step = 1, dtype = torch.int64, device = device(type='cuda', index=0), requires_grad = False) | |
| ne = torch.ops.aten.ne.Scalar(iota_2, -100) | |
| sum_7 = torch.ops.aten.sum.default(ne); ne = None | |
| - convert_element_type_783 = torch.ops.prims.convert_element_type.default(sum_7, torch.float32); sum_7 = None | |
| - div_66 = torch.ops.aten.div.Tensor(mul_634, convert_element_type_783); mul_634 = convert_element_type_783 = None | |
| + convert_element_type_789 = torch.ops.prims.convert_element_type.default(sum_7, torch.float32); sum_7 = None | |
| + div_66 = torch.ops.aten.div.Tensor(mul_634, convert_element_type_789); mul_634 = convert_element_type_789 = None | |
| unsqueeze_106 = torch.ops.aten.unsqueeze.default(iota_2, 1) | |
| ne_12 = torch.ops.aten.ne.Scalar(unsqueeze_106, -100) | |
| full_default_5 = torch.ops.aten.full.default([], 0, dtype = torch.int64, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| @@ -5434,11 +2260,12 @@ | |
| permute_532 = torch.ops.aten.permute.default(permute_263, [1, 0]); permute_263 = None | |
| mm_213 = torch.ops.aten.mm.default(add_354, permute_532); permute_532 = None | |
| permute_533 = torch.ops.aten.permute.default(add_354, [1, 0]) | |
| - pow_79 = torch.ops.aten.pow.Tensor_Scalar(mul_171, 2) | |
| + convert_element_type_533 = torch.ops.prims.convert_element_type.default(mul_171, torch.float32) | |
| + pow_79 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_533, 2) | |
| mean_78 = torch.ops.aten.mean.dim(pow_79, [1], True); pow_79 = None | |
| add_212 = torch.ops.aten.add.Scalar(mean_78, 1.1920928955078125e-07); mean_78 = None | |
| rsqrt_97 = torch.ops.aten.rsqrt.default(add_212); add_212 = None | |
| - mul_310 = torch.ops.aten.mul.Tensor(mul_171, rsqrt_97) | |
| + mul_310 = torch.ops.aten.mul.Tensor(convert_element_type_533, rsqrt_97); convert_element_type_533 = None | |
| mul_311 = torch.ops.aten.mul.Tensor(mul_310, primals_426) | |
| mm_214 = torch.ops.aten.mm.default(permute_533, mul_311); permute_533 = mul_311 = None | |
| sum_86 = torch.ops.aten.sum.dim_IntList(add_354, [0], True); add_354 = None | |
| @@ -5447,11 +2274,12 @@ | |
| permute_536 = torch.ops.aten.permute.default(permute_262, [1, 0]); permute_262 = None | |
| mm_215 = torch.ops.aten.mm.default(add_355, permute_536); permute_536 = None | |
| permute_537 = torch.ops.aten.permute.default(add_355, [1, 0]) | |
| - pow_78 = torch.ops.aten.pow.Tensor_Scalar(mul_170, 2) | |
| + convert_element_type_526 = torch.ops.prims.convert_element_type.default(mul_170, torch.float32) | |
| + pow_78 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_526, 2) | |
| mean_77 = torch.ops.aten.mean.dim(pow_78, [1], True); pow_78 = None | |
| add_211 = torch.ops.aten.add.Scalar(mean_77, 1.1920928955078125e-07); mean_77 = None | |
| rsqrt_96 = torch.ops.aten.rsqrt.default(add_211); add_211 = None | |
| - mul_308 = torch.ops.aten.mul.Tensor(mul_170, rsqrt_96) | |
| + mul_308 = torch.ops.aten.mul.Tensor(convert_element_type_526, rsqrt_96); convert_element_type_526 = None | |
| mul_309 = torch.ops.aten.mul.Tensor(mul_308, primals_425) | |
| mm_216 = torch.ops.aten.mm.default(permute_537, mul_309); permute_537 = mul_309 = None | |
| sum_87 = torch.ops.aten.sum.dim_IntList(add_355, [0], True); add_355 = None | |
| @@ -5465,7 +2293,6 @@ | |
| mul_645 = torch.ops.aten.mul.Tensor(sub_188, rsqrt_97); sub_188 = rsqrt_97 = None | |
| mul_646 = torch.ops.aten.mul.Tensor(mm_213, mul_310); mm_213 = mul_310 = None | |
| sum_89 = torch.ops.aten.sum.dim_IntList(mul_646, [0]); mul_646 = None | |
| - add_356 = torch.ops.aten.add.Tensor(tangents_2, mul_645); tangents_2 = mul_645 = None | |
| mul_647 = torch.ops.aten.mul.Tensor(mm_215, primals_425); primals_425 = None | |
| mul_649 = torch.ops.aten.mul.Tensor(mul_308, mul_647) | |
| sum_90 = torch.ops.aten.sum.dim_IntList(mul_649, [1], True); mul_649 = None | |
| @@ -5475,7 +2302,10 @@ | |
| mul_651 = torch.ops.aten.mul.Tensor(sub_189, rsqrt_96); sub_189 = rsqrt_96 = None | |
| mul_652 = torch.ops.aten.mul.Tensor(mm_215, mul_308); mm_215 = mul_308 = None | |
| sum_91 = torch.ops.aten.sum.dim_IntList(mul_652, [0]); mul_652 = None | |
| - add_357 = torch.ops.aten.add.Tensor(tangents_1, mul_651); tangents_1 = mul_651 = None | |
| + convert_element_type_872 = torch.ops.prims.convert_element_type.default(mul_645, torch.bfloat16); mul_645 = None | |
| + add_356 = torch.ops.aten.add.Tensor(tangents_2, convert_element_type_872); tangents_2 = convert_element_type_872 = None | |
| + convert_element_type_873 = torch.ops.prims.convert_element_type.default(mul_651, torch.bfloat16); mul_651 = None | |
| + add_357 = torch.ops.aten.add.Tensor(tangents_1, convert_element_type_873); tangents_1 = convert_element_type_873 = None | |
| sum_92 = torch.ops.aten.sum.default(add_269); add_269 = None | |
| mul_653 = torch.ops.aten.mul.Tensor(sum_92, 0.06) | |
| mul_654 = torch.ops.aten.mul.Tensor(sum_92, 0.96); sum_92 = None | |
| @@ -5552,31 +2382,52 @@ | |
| permute_544 = torch.ops.aten.permute.default(permute_259, [1, 0]); permute_259 = None | |
| mm_219 = torch.ops.aten.mm.default(add_364, permute_544); permute_544 = None | |
| permute_545 = torch.ops.aten.permute.default(add_364, [1, 0]) | |
| - mm_220 = torch.ops.aten.mm.default(permute_545, mul_303); permute_545 = mul_303 = None | |
| + mm_220 = torch.ops.aten.mm.default(permute_545, convert_element_type_784); permute_545 = convert_element_type_784 = None | |
| sum_95 = torch.ops.aten.sum.dim_IntList(add_364, [0], True); add_364 = None | |
| view_419 = torch.ops.aten.view.default(sum_95, [512]); sum_95 = None | |
| add_365 = torch.ops.aten.add.Tensor(add_352, mm_219); add_352 = mm_219 = None | |
| - mul_662 = torch.ops.aten.mul.Tensor(add_365, mul_301); mul_301 = None | |
| - mul_663 = torch.ops.aten.mul.Tensor(add_365, sigmoid_48); add_365 = None | |
| - sub_190 = torch.ops.aten.sub.Tensor(1, sigmoid_48) | |
| - mul_664 = torch.ops.aten.mul.Tensor(sigmoid_48, sub_190); sigmoid_48 = sub_190 = None | |
| - mul_665 = torch.ops.aten.mul.Tensor(mul_662, mul_664); mul_662 = mul_664 = None | |
| - add_366 = torch.ops.aten.add.Tensor(mul_663, mul_665); mul_663 = mul_665 = None | |
| - mul_666 = torch.ops.aten.mul.Tensor(add_366, primals_422); primals_422 = None | |
| - mul_668 = torch.ops.aten.mul.Tensor(mul_300, mul_666) | |
| - sum_96 = torch.ops.aten.sum.dim_IntList(mul_668, [1], True); mul_668 = None | |
| - div_88 = torch.ops.aten.div.Tensor(mul_300, 3840) | |
| - mul_669 = torch.ops.aten.mul.Tensor(div_88, sum_96); div_88 = sum_96 = None | |
| - sub_191 = torch.ops.aten.sub.Tensor(mul_666, mul_669); mul_666 = mul_669 = None | |
| - mul_670 = torch.ops.aten.mul.Tensor(sub_191, rsqrt_95); sub_191 = rsqrt_95 = None | |
| - mul_671 = torch.ops.aten.mul.Tensor(add_366, mul_300); add_366 = mul_300 = None | |
| - sum_97 = torch.ops.aten.sum.dim_IntList(mul_671, [0]); mul_671 = None | |
| - convert_element_type_866 = torch.ops.prims.convert_element_type.default(mul_670, torch.bfloat16) | |
| + convert_element_type_874 = torch.ops.prims.convert_element_type.default(add_365, torch.bfloat16); add_365 = None | |
| + convert_element_type_782 = torch.ops.prims.convert_element_type.default(add_205, torch.float32); add_205 = None | |
| + pow_75 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_782, 2) | |
| + mean_74 = torch.ops.aten.mean.dim(pow_75, [1], True); pow_75 = None | |
| + add_207 = torch.ops.aten.add.Scalar(mean_74, 1.1920928955078125e-07); mean_74 = None | |
| + rsqrt_95 = torch.ops.aten.rsqrt.default(add_207); add_207 = None | |
| + mul_300 = torch.ops.aten.mul.Tensor(convert_element_type_782, rsqrt_95) | |
| + mul_301 = torch.ops.aten.mul.Tensor(mul_300, primals_422) | |
| + convert_element_type_783 = torch.ops.prims.convert_element_type.default(mul_301, torch.bfloat16); mul_301 = None | |
| + mul_662 = torch.ops.aten.mul.Tensor(convert_element_type_874, convert_element_type_783) | |
| + sigmoid_48 = torch.ops.aten.sigmoid.default(convert_element_type_783); convert_element_type_783 = None | |
| + mul_663 = torch.ops.aten.mul.Tensor(convert_element_type_874, sigmoid_48); convert_element_type_874 = None | |
| + convert_element_type_875 = torch.ops.prims.convert_element_type.default(mul_662, torch.float32); mul_662 = None | |
| + convert_element_type_876 = torch.ops.prims.convert_element_type.default(sigmoid_48, torch.float32); sigmoid_48 = None | |
| + sub_190 = torch.ops.aten.sub.Tensor(1, convert_element_type_876) | |
| + mul_664 = torch.ops.aten.mul.Tensor(convert_element_type_876, sub_190); convert_element_type_876 = sub_190 = None | |
| + mul_665 = torch.ops.aten.mul.Tensor(convert_element_type_875, mul_664); convert_element_type_875 = mul_664 = None | |
| + convert_element_type_877 = torch.ops.prims.convert_element_type.default(mul_665, torch.bfloat16); mul_665 = None | |
| + add_366 = torch.ops.aten.add.Tensor(mul_663, convert_element_type_877); mul_663 = convert_element_type_877 = None | |
| + convert_element_type_878 = torch.ops.prims.convert_element_type.default(add_366, torch.float32); add_366 = None | |
| + mul_666 = torch.ops.aten.mul.Tensor(convert_element_type_878, mul_300); mul_300 = None | |
| + mul_667 = torch.ops.aten.mul.Tensor(convert_element_type_878, primals_422); convert_element_type_878 = primals_422 = None | |
| + sum_96 = torch.ops.aten.sum.dim_IntList(mul_666, [0], True); mul_666 = None | |
| + view_421 = torch.ops.aten.view.default(sum_96, [3840]); sum_96 = None | |
| + mul_668 = torch.ops.aten.mul.Tensor(mul_667, convert_element_type_782) | |
| + mul_669 = torch.ops.aten.mul.Tensor(mul_667, rsqrt_95); mul_667 = None | |
| + sum_97 = torch.ops.aten.sum.dim_IntList(mul_668, [1], True); mul_668 = None | |
| + mul_670 = torch.ops.aten.mul.Scalar(sum_97, -0.5); sum_97 = None | |
| + pow_84 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_95, 3); rsqrt_95 = None | |
| + mul_671 = torch.ops.aten.mul.Tensor(mul_670, pow_84); mul_670 = pow_84 = None | |
| + expand_63 = torch.ops.aten.expand.default(mul_671, [4096, 3840]); mul_671 = None | |
| + div_88 = torch.ops.aten.div.Scalar(expand_63, 3840); expand_63 = None | |
| + pow_85 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_782, 1.0); convert_element_type_782 = None | |
| + mul_672 = torch.ops.aten.mul.Scalar(pow_85, 2.0); pow_85 = None | |
| + mul_673 = torch.ops.aten.mul.Tensor(div_88, mul_672); div_88 = mul_672 = None | |
| + add_367 = torch.ops.aten.add.Tensor(mul_669, mul_673); mul_669 = mul_673 = None | |
| + convert_element_type_879 = torch.ops.prims.convert_element_type.default(add_367, torch.bfloat16); add_367 = None | |
| full_default_254 = torch.ops.aten.full.default([4096, 3840], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| - index_put_18 = torch.ops.aten.index_put.default(full_default_254, [sub_32], convert_element_type_866, True) | |
| + index_put_18 = torch.ops.aten.index_put.default(full_default_254, [sub_32], convert_element_type_879, True) | |
| full_default_255 = torch.ops.aten.full.default([4096, 7680], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| slice_scatter = torch.ops.aten.slice_scatter.default(full_default_255, index_put_18, 1, 3840, 9223372036854775807); index_put_18 = None | |
| - permute_548 = torch.ops.aten.permute.default(convert_element_type_866, [1, 0]) | |
| + permute_548 = torch.ops.aten.permute.default(convert_element_type_879, [1, 0]) | |
| slice_80 = torch.ops.aten.slice.Tensor(mm_103, 1, 1280, 9223372036854775807) | |
| index_43 = torch.ops.aten.index.Tensor(slice_80, [sub_32]); slice_80 = None | |
| add_200 = torch.ops.aten.add.Tensor(mm_104, index_43); mm_104 = index_43 = None | |
| @@ -5585,673 +2436,840 @@ | |
| mean_72 = torch.ops.aten.mean.dim(pow_73, [1], True); pow_73 = None | |
| add_202 = torch.ops.aten.add.Scalar(mean_72, 1.1920928955078125e-07); mean_72 = None | |
| rsqrt_93 = torch.ops.aten.rsqrt.default(add_202); add_202 = None | |
| - mul_294 = torch.ops.aten.mul.Tensor(convert_element_type_772, rsqrt_93); convert_element_type_772 = None | |
| + mul_294 = torch.ops.aten.mul.Tensor(convert_element_type_772, rsqrt_93) | |
| mul_295 = torch.ops.aten.mul.Tensor(mul_294, primals_418) | |
| - sigmoid_46 = torch.ops.aten.sigmoid.default(mul_295) | |
| - mul_297 = torch.ops.aten.mul.Tensor(mul_295, sigmoid_46) | |
| - convert_element_type_777 = torch.ops.prims.convert_element_type.default(mul_297, torch.bfloat16); mul_297 = None | |
| - mm_221 = torch.ops.aten.mm.default(permute_548, convert_element_type_777); permute_548 = convert_element_type_777 = None | |
| - convert_element_type_778 = torch.ops.prims.convert_element_type.default(primals_420, torch.bfloat16); primals_420 = None | |
| - permute_258 = torch.ops.aten.permute.default(convert_element_type_778, [1, 0]); convert_element_type_778 = None | |
| + convert_element_type_773 = torch.ops.prims.convert_element_type.default(mul_295, torch.bfloat16); mul_295 = None | |
| + sigmoid_46 = torch.ops.aten.sigmoid.default(convert_element_type_773) | |
| + mul_297 = torch.ops.aten.mul.Tensor(convert_element_type_773, sigmoid_46) | |
| + mm_221 = torch.ops.aten.mm.default(permute_548, mul_297); permute_548 = mul_297 = None | |
| + convert_element_type_777 = torch.ops.prims.convert_element_type.default(primals_420, torch.bfloat16); primals_420 = None | |
| + permute_258 = torch.ops.aten.permute.default(convert_element_type_777, [1, 0]); convert_element_type_777 = None | |
| permute_550 = torch.ops.aten.permute.default(permute_258, [1, 0]); permute_258 = None | |
| - mm_222 = torch.ops.aten.mm.default(convert_element_type_866, permute_550); convert_element_type_866 = permute_550 = None | |
| - convert_element_type_871 = torch.ops.prims.convert_element_type.default(mm_221, torch.float32); mm_221 = None | |
| - convert_element_type_872 = torch.ops.prims.convert_element_type.default(mm_222, torch.float32); mm_222 = None | |
| + mm_222 = torch.ops.aten.mm.default(convert_element_type_879, permute_550); permute_550 = None | |
| + convert_element_type_884 = torch.ops.prims.convert_element_type.default(mm_221, torch.float32); mm_221 = None | |
| permute_552 = torch.ops.aten.permute.default(slice_scatter, [1, 0]) | |
| slice_79 = torch.ops.aten.slice.Tensor(mm_103, 1, 0, 1280); mm_103 = None | |
| - convert_element_type_771 = torch.ops.prims.convert_element_type.default(slice_79, torch.float32); slice_79 = None | |
| - pow_72 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_771, 2) | |
| + convert_element_type_770 = torch.ops.prims.convert_element_type.default(slice_79, torch.float32); slice_79 = None | |
| + pow_72 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_770, 2) | |
| mean_71 = torch.ops.aten.mean.dim(pow_72, [1], True); pow_72 = None | |
| add_201 = torch.ops.aten.add.Scalar(mean_71, 1.1920928955078125e-07); mean_71 = None | |
| rsqrt_92 = torch.ops.aten.rsqrt.default(add_201); add_201 = None | |
| - mul_292 = torch.ops.aten.mul.Tensor(convert_element_type_771, rsqrt_92); convert_element_type_771 = None | |
| + mul_292 = torch.ops.aten.mul.Tensor(convert_element_type_770, rsqrt_92) | |
| mul_293 = torch.ops.aten.mul.Tensor(mul_292, primals_417) | |
| - sigmoid_45 = torch.ops.aten.sigmoid.default(mul_293) | |
| - mul_296 = torch.ops.aten.mul.Tensor(mul_293, sigmoid_45) | |
| - convert_element_type_773 = torch.ops.prims.convert_element_type.default(mul_296, torch.bfloat16); mul_296 = None | |
| - mm_223 = torch.ops.aten.mm.default(permute_552, convert_element_type_773); permute_552 = convert_element_type_773 = None | |
| + convert_element_type_771 = torch.ops.prims.convert_element_type.default(mul_293, torch.bfloat16); mul_293 = None | |
| + sigmoid_45 = torch.ops.aten.sigmoid.default(convert_element_type_771) | |
| + mul_296 = torch.ops.aten.mul.Tensor(convert_element_type_771, sigmoid_45) | |
| + mm_223 = torch.ops.aten.mm.default(permute_552, mul_296); permute_552 = mul_296 = None | |
| convert_element_type_774 = torch.ops.prims.convert_element_type.default(primals_419, torch.bfloat16); primals_419 = None | |
| permute_257 = torch.ops.aten.permute.default(convert_element_type_774, [1, 0]); convert_element_type_774 = None | |
| permute_554 = torch.ops.aten.permute.default(permute_257, [1, 0]); permute_257 = None | |
| mm_224 = torch.ops.aten.mm.default(slice_scatter, permute_554); slice_scatter = permute_554 = None | |
| - convert_element_type_877 = torch.ops.prims.convert_element_type.default(mm_223, torch.float32); mm_223 = None | |
| - convert_element_type_878 = torch.ops.prims.convert_element_type.default(mm_224, torch.float32); mm_224 = None | |
| - mul_672 = torch.ops.aten.mul.Tensor(convert_element_type_872, mul_295); mul_295 = None | |
| - mul_673 = torch.ops.aten.mul.Tensor(convert_element_type_872, sigmoid_46); convert_element_type_872 = None | |
| - sub_192 = torch.ops.aten.sub.Tensor(1, sigmoid_46) | |
| - mul_674 = torch.ops.aten.mul.Tensor(sigmoid_46, sub_192); sigmoid_46 = sub_192 = None | |
| - mul_675 = torch.ops.aten.mul.Tensor(mul_672, mul_674); mul_672 = mul_674 = None | |
| - add_367 = torch.ops.aten.add.Tensor(mul_673, mul_675); mul_673 = mul_675 = None | |
| - mul_676 = torch.ops.aten.mul.Tensor(convert_element_type_878, mul_293); mul_293 = None | |
| - mul_677 = torch.ops.aten.mul.Tensor(convert_element_type_878, sigmoid_45); convert_element_type_878 = None | |
| - sub_193 = torch.ops.aten.sub.Tensor(1, sigmoid_45) | |
| - mul_678 = torch.ops.aten.mul.Tensor(sigmoid_45, sub_193); sigmoid_45 = sub_193 = None | |
| - mul_679 = torch.ops.aten.mul.Tensor(mul_676, mul_678); mul_676 = mul_678 = None | |
| - add_368 = torch.ops.aten.add.Tensor(mul_677, mul_679); mul_677 = mul_679 = None | |
| - mul_680 = torch.ops.aten.mul.Tensor(add_367, primals_418); primals_418 = None | |
| - mul_682 = torch.ops.aten.mul.Tensor(mul_294, mul_680) | |
| - sum_98 = torch.ops.aten.sum.dim_IntList(mul_682, [1], True); mul_682 = None | |
| - div_89 = torch.ops.aten.div.Tensor(mul_294, 1280) | |
| - mul_683 = torch.ops.aten.mul.Tensor(div_89, sum_98); div_89 = sum_98 = None | |
| - sub_194 = torch.ops.aten.sub.Tensor(mul_680, mul_683); mul_680 = mul_683 = None | |
| - mul_684 = torch.ops.aten.mul.Tensor(sub_194, rsqrt_93); sub_194 = rsqrt_93 = None | |
| - mul_685 = torch.ops.aten.mul.Tensor(add_367, mul_294); add_367 = mul_294 = None | |
| - sum_99 = torch.ops.aten.sum.dim_IntList(mul_685, [0]); mul_685 = None | |
| - convert_element_type_879 = torch.ops.prims.convert_element_type.default(mul_684, torch.bfloat16); mul_684 = None | |
| - mul_686 = torch.ops.aten.mul.Tensor(add_368, primals_417); primals_417 = None | |
| - mul_688 = torch.ops.aten.mul.Tensor(mul_292, mul_686) | |
| - sum_100 = torch.ops.aten.sum.dim_IntList(mul_688, [1], True); mul_688 = None | |
| - div_90 = torch.ops.aten.div.Tensor(mul_292, 1280) | |
| - mul_689 = torch.ops.aten.mul.Tensor(div_90, sum_100); div_90 = sum_100 = None | |
| - sub_195 = torch.ops.aten.sub.Tensor(mul_686, mul_689); mul_686 = mul_689 = None | |
| - mul_690 = torch.ops.aten.mul.Tensor(sub_195, rsqrt_92); sub_195 = rsqrt_92 = None | |
| - mul_691 = torch.ops.aten.mul.Tensor(add_368, mul_292); add_368 = mul_292 = None | |
| - sum_101 = torch.ops.aten.sum.dim_IntList(mul_691, [0]); mul_691 = None | |
| - convert_element_type_880 = torch.ops.prims.convert_element_type.default(mul_690, torch.bfloat16); mul_690 = None | |
| + convert_element_type_889 = torch.ops.prims.convert_element_type.default(mm_223, torch.float32); mm_223 = None | |
| + mul_674 = torch.ops.aten.mul.Tensor(mm_222, convert_element_type_773); convert_element_type_773 = None | |
| + mul_675 = torch.ops.aten.mul.Tensor(mm_222, sigmoid_46); mm_222 = None | |
| + convert_element_type_890 = torch.ops.prims.convert_element_type.default(mul_674, torch.float32); mul_674 = None | |
| + convert_element_type_891 = torch.ops.prims.convert_element_type.default(sigmoid_46, torch.float32); sigmoid_46 = None | |
| + sub_191 = torch.ops.aten.sub.Tensor(1, convert_element_type_891) | |
| + mul_676 = torch.ops.aten.mul.Tensor(convert_element_type_891, sub_191); convert_element_type_891 = sub_191 = None | |
| + mul_677 = torch.ops.aten.mul.Tensor(convert_element_type_890, mul_676); convert_element_type_890 = mul_676 = None | |
| + convert_element_type_892 = torch.ops.prims.convert_element_type.default(mul_677, torch.bfloat16); mul_677 = None | |
| + add_368 = torch.ops.aten.add.Tensor(mul_675, convert_element_type_892); mul_675 = convert_element_type_892 = None | |
| + mul_678 = torch.ops.aten.mul.Tensor(mm_224, convert_element_type_771); convert_element_type_771 = None | |
| + mul_679 = torch.ops.aten.mul.Tensor(mm_224, sigmoid_45); mm_224 = None | |
| + convert_element_type_893 = torch.ops.prims.convert_element_type.default(mul_678, torch.float32); mul_678 = None | |
| + convert_element_type_894 = torch.ops.prims.convert_element_type.default(sigmoid_45, torch.float32); sigmoid_45 = None | |
| + sub_192 = torch.ops.aten.sub.Tensor(1, convert_element_type_894) | |
| + mul_680 = torch.ops.aten.mul.Tensor(convert_element_type_894, sub_192); convert_element_type_894 = sub_192 = None | |
| + mul_681 = torch.ops.aten.mul.Tensor(convert_element_type_893, mul_680); convert_element_type_893 = mul_680 = None | |
| + convert_element_type_895 = torch.ops.prims.convert_element_type.default(mul_681, torch.bfloat16); mul_681 = None | |
| + add_369 = torch.ops.aten.add.Tensor(mul_679, convert_element_type_895); mul_679 = convert_element_type_895 = None | |
| + convert_element_type_896 = torch.ops.prims.convert_element_type.default(add_368, torch.float32); add_368 = None | |
| + mul_682 = torch.ops.aten.mul.Tensor(convert_element_type_896, mul_294); mul_294 = None | |
| + mul_683 = torch.ops.aten.mul.Tensor(convert_element_type_896, primals_418); convert_element_type_896 = primals_418 = None | |
| + sum_98 = torch.ops.aten.sum.dim_IntList(mul_682, [0], True); mul_682 = None | |
| + view_422 = torch.ops.aten.view.default(sum_98, [1280]); sum_98 = None | |
| + mul_684 = torch.ops.aten.mul.Tensor(mul_683, convert_element_type_772) | |
| + mul_685 = torch.ops.aten.mul.Tensor(mul_683, rsqrt_93); mul_683 = None | |
| + sum_99 = torch.ops.aten.sum.dim_IntList(mul_684, [1], True); mul_684 = None | |
| + mul_686 = torch.ops.aten.mul.Scalar(sum_99, -0.5); sum_99 = None | |
| + pow_86 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_93, 3); rsqrt_93 = None | |
| + mul_687 = torch.ops.aten.mul.Tensor(mul_686, pow_86); mul_686 = pow_86 = None | |
| + expand_64 = torch.ops.aten.expand.default(mul_687, [4096, 1280]); mul_687 = None | |
| + div_89 = torch.ops.aten.div.Scalar(expand_64, 1280); expand_64 = None | |
| + pow_87 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_772, 1.0); convert_element_type_772 = None | |
| + mul_688 = torch.ops.aten.mul.Scalar(pow_87, 2.0); pow_87 = None | |
| + mul_689 = torch.ops.aten.mul.Tensor(div_89, mul_688); div_89 = mul_688 = None | |
| + add_370 = torch.ops.aten.add.Tensor(mul_685, mul_689); mul_685 = mul_689 = None | |
| + convert_element_type_897 = torch.ops.prims.convert_element_type.default(add_370, torch.bfloat16); add_370 = None | |
| + convert_element_type_898 = torch.ops.prims.convert_element_type.default(add_369, torch.float32); add_369 = None | |
| + mul_690 = torch.ops.aten.mul.Tensor(convert_element_type_898, mul_292); mul_292 = None | |
| + mul_691 = torch.ops.aten.mul.Tensor(convert_element_type_898, primals_417); convert_element_type_898 = primals_417 = None | |
| + sum_100 = torch.ops.aten.sum.dim_IntList(mul_690, [0], True); mul_690 = None | |
| + view_423 = torch.ops.aten.view.default(sum_100, [1280]); sum_100 = None | |
| + mul_692 = torch.ops.aten.mul.Tensor(mul_691, convert_element_type_770) | |
| + mul_693 = torch.ops.aten.mul.Tensor(mul_691, rsqrt_92); mul_691 = None | |
| + sum_101 = torch.ops.aten.sum.dim_IntList(mul_692, [1], True); mul_692 = None | |
| + mul_694 = torch.ops.aten.mul.Scalar(sum_101, -0.5); sum_101 = None | |
| + pow_88 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_92, 3); rsqrt_92 = None | |
| + mul_695 = torch.ops.aten.mul.Tensor(mul_694, pow_88); mul_694 = pow_88 = None | |
| + expand_65 = torch.ops.aten.expand.default(mul_695, [4096, 1280]); mul_695 = None | |
| + div_90 = torch.ops.aten.div.Scalar(expand_65, 1280); expand_65 = None | |
| + pow_89 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_770, 1.0); convert_element_type_770 = None | |
| + mul_696 = torch.ops.aten.mul.Scalar(pow_89, 2.0); pow_89 = None | |
| + mul_697 = torch.ops.aten.mul.Tensor(div_90, mul_696); div_90 = mul_696 = None | |
| + add_371 = torch.ops.aten.add.Tensor(mul_693, mul_697); mul_693 = mul_697 = None | |
| + convert_element_type_899 = torch.ops.prims.convert_element_type.default(add_371, torch.bfloat16); add_371 = None | |
| full_default_256 = torch.ops.aten.full.default([4096, 1280], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| - index_put_19 = torch.ops.aten.index_put.default(full_default_256, [sub_32], convert_element_type_879, True); full_default_256 = None | |
| + index_put_19 = torch.ops.aten.index_put.default(full_default_256, [sub_32], convert_element_type_897, True); full_default_256 = None | |
| full_default_257 = torch.ops.aten.full.default([4096, 2560], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| slice_scatter_1 = torch.ops.aten.slice_scatter.default(full_default_257, index_put_19, 1, 1280, 9223372036854775807); index_put_19 = None | |
| - permute_556 = torch.ops.aten.permute.default(convert_element_type_879, [1, 0]) | |
| - convert_element_type_767 = torch.ops.prims.convert_element_type.default(mul_291, torch.bfloat16); mul_291 = None | |
| - mm_225 = torch.ops.aten.mm.default(permute_556, convert_element_type_767); permute_556 = convert_element_type_767 = None | |
| - convert_element_type_768 = torch.ops.prims.convert_element_type.default(primals_416, torch.bfloat16); primals_416 = None | |
| - permute_256 = torch.ops.aten.permute.default(convert_element_type_768, [1, 0]); convert_element_type_768 = None | |
| + permute_556 = torch.ops.aten.permute.default(convert_element_type_897, [1, 0]) | |
| + slice_74 = torch.ops.aten.slice.Tensor(_scaled_mm_8, 1, 3840, 9223372036854775807) | |
| + index_40 = torch.ops.aten.index.Tensor(slice_74, [sub_32]); slice_74 = None | |
| + add_189 = torch.ops.aten.add.Tensor(_scaled_mm_9, index_40); _scaled_mm_9 = index_40 = None | |
| + convert_element_type_742 = torch.ops.prims.convert_element_type.default(add_189, torch.float32); add_189 = None | |
| + pow_67 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_742, 2) | |
| + mean_66 = torch.ops.aten.mean.dim(pow_67, [1], True); pow_67 = None | |
| + add_191 = torch.ops.aten.add.Scalar(mean_66, 1.1920928955078125e-07); mean_66 = None | |
| + rsqrt_87 = torch.ops.aten.rsqrt.default(add_191); add_191 = None | |
| + mul_276 = torch.ops.aten.mul.Tensor(convert_element_type_742, rsqrt_87) | |
| + mul_277 = torch.ops.aten.mul.Tensor(mul_276, primals_406) | |
| + convert_element_type_743 = torch.ops.prims.convert_element_type.default(mul_277, torch.bfloat16); mul_277 = None | |
| + sigmoid_40 = torch.ops.aten.sigmoid.default(convert_element_type_743) | |
| + mul_279 = torch.ops.aten.mul.Tensor(convert_element_type_743, sigmoid_40) | |
| + slice_78 = torch.ops.aten.slice.Tensor(mm_101, 1, 3840, 9223372036854775807) | |
| + index_42 = torch.ops.aten.index.Tensor(slice_78, [sub_32]); slice_78 = None | |
| + add_195 = torch.ops.aten.add.Tensor(mm_102, index_42); mm_102 = index_42 = None | |
| + add_197 = torch.ops.aten.add.Tensor(mul_279, add_195); add_195 = None | |
| + convert_element_type_762 = torch.ops.prims.convert_element_type.default(add_197, torch.float32); add_197 = None | |
| + pow_71 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_762, 2) | |
| + mean_70 = torch.ops.aten.mean.dim(pow_71, [1], True); pow_71 = None | |
| + add_199 = torch.ops.aten.add.Scalar(mean_70, 1.1920928955078125e-07); mean_70 = None | |
| + rsqrt_91 = torch.ops.aten.rsqrt.default(add_199); add_199 = None | |
| + mul_288 = torch.ops.aten.mul.Tensor(convert_element_type_762, rsqrt_91) | |
| + mul_289 = torch.ops.aten.mul.Tensor(mul_288, primals_414) | |
| + convert_element_type_763 = torch.ops.prims.convert_element_type.default(mul_289, torch.bfloat16); mul_289 = None | |
| + sigmoid_44 = torch.ops.aten.sigmoid.default(convert_element_type_763) | |
| + mul_291 = torch.ops.aten.mul.Tensor(convert_element_type_763, sigmoid_44) | |
| + mm_225 = torch.ops.aten.mm.default(permute_556, mul_291); permute_556 = mul_291 = None | |
| + convert_element_type_767 = torch.ops.prims.convert_element_type.default(primals_416, torch.bfloat16); primals_416 = None | |
| + permute_256 = torch.ops.aten.permute.default(convert_element_type_767, [1, 0]); convert_element_type_767 = None | |
| permute_558 = torch.ops.aten.permute.default(permute_256, [1, 0]); permute_256 = None | |
| - mm_226 = torch.ops.aten.mm.default(convert_element_type_879, permute_558); convert_element_type_879 = permute_558 = None | |
| - convert_element_type_885 = torch.ops.prims.convert_element_type.default(mm_225, torch.float32); mm_225 = None | |
| - convert_element_type_886 = torch.ops.prims.convert_element_type.default(mm_226, torch.float32); mm_226 = None | |
| - add_369 = torch.ops.aten.add.Tensor(mul_670, convert_element_type_886); mul_670 = convert_element_type_886 = None | |
| - slice_scatter_2 = torch.ops.aten.slice_scatter.default(full_default_257, convert_element_type_880, 1, 0, 1280); full_default_257 = convert_element_type_880 = None | |
| - add_370 = torch.ops.aten.add.Tensor(slice_scatter_1, slice_scatter_2); slice_scatter_1 = slice_scatter_2 = None | |
| - permute_560 = torch.ops.aten.permute.default(add_370, [1, 0]) | |
| + mm_226 = torch.ops.aten.mm.default(convert_element_type_897, permute_558); convert_element_type_897 = permute_558 = None | |
| + add_372 = torch.ops.aten.add.Tensor(convert_element_type_879, mm_226); convert_element_type_879 = mm_226 = None | |
| + convert_element_type_904 = torch.ops.prims.convert_element_type.default(mm_225, torch.float32); mm_225 = None | |
| + slice_scatter_2 = torch.ops.aten.slice_scatter.default(full_default_257, convert_element_type_899, 1, 0, 1280); full_default_257 = convert_element_type_899 = None | |
| + add_373 = torch.ops.aten.add.Tensor(slice_scatter_1, slice_scatter_2); slice_scatter_1 = slice_scatter_2 = None | |
| + permute_560 = torch.ops.aten.permute.default(add_373, [1, 0]) | |
| slice_73 = torch.ops.aten.slice.Tensor(_scaled_mm_8, 1, 0, 3840); _scaled_mm_8 = None | |
| - convert_element_type_743 = torch.ops.prims.convert_element_type.default(slice_73, torch.float32); slice_73 = None | |
| - pow_66 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_743, 2) | |
| + convert_element_type_740 = torch.ops.prims.convert_element_type.default(slice_73, torch.float32); slice_73 = None | |
| + pow_66 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_740, 2) | |
| mean_65 = torch.ops.aten.mean.dim(pow_66, [1], True); pow_66 = None | |
| add_190 = torch.ops.aten.add.Scalar(mean_65, 1.1920928955078125e-07); mean_65 = None | |
| rsqrt_86 = torch.ops.aten.rsqrt.default(add_190); add_190 = None | |
| - mul_274 = torch.ops.aten.mul.Tensor(convert_element_type_743, rsqrt_86); convert_element_type_743 = None | |
| + mul_274 = torch.ops.aten.mul.Tensor(convert_element_type_740, rsqrt_86) | |
| mul_275 = torch.ops.aten.mul.Tensor(mul_274, primals_405) | |
| - sigmoid_39 = torch.ops.aten.sigmoid.default(mul_275) | |
| - mul_278 = torch.ops.aten.mul.Tensor(mul_275, sigmoid_39) | |
| + convert_element_type_741 = torch.ops.prims.convert_element_type.default(mul_275, torch.bfloat16); mul_275 = None | |
| + sigmoid_39 = torch.ops.aten.sigmoid.default(convert_element_type_741) | |
| + mul_278 = torch.ops.aten.mul.Tensor(convert_element_type_741, sigmoid_39) | |
| slice_77 = torch.ops.aten.slice.Tensor(mm_101, 1, 0, 3840); mm_101 = None | |
| add_196 = torch.ops.aten.add.Tensor(mul_278, slice_77); slice_77 = None | |
| - pow_70 = torch.ops.aten.pow.Tensor_Scalar(add_196, 2) | |
| + convert_element_type_760 = torch.ops.prims.convert_element_type.default(add_196, torch.float32); add_196 = None | |
| + pow_70 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_760, 2) | |
| mean_69 = torch.ops.aten.mean.dim(pow_70, [1], True); pow_70 = None | |
| add_198 = torch.ops.aten.add.Scalar(mean_69, 1.1920928955078125e-07); mean_69 = None | |
| rsqrt_90 = torch.ops.aten.rsqrt.default(add_198); add_198 = None | |
| - mul_286 = torch.ops.aten.mul.Tensor(add_196, rsqrt_90); add_196 = None | |
| + mul_286 = torch.ops.aten.mul.Tensor(convert_element_type_760, rsqrt_90) | |
| mul_287 = torch.ops.aten.mul.Tensor(mul_286, primals_413) | |
| - sigmoid_43 = torch.ops.aten.sigmoid.default(mul_287) | |
| - mul_290 = torch.ops.aten.mul.Tensor(mul_287, sigmoid_43) | |
| - convert_element_type_763 = torch.ops.prims.convert_element_type.default(mul_290, torch.bfloat16); mul_290 = None | |
| - mm_227 = torch.ops.aten.mm.default(permute_560, convert_element_type_763); permute_560 = convert_element_type_763 = None | |
| + convert_element_type_761 = torch.ops.prims.convert_element_type.default(mul_287, torch.bfloat16); mul_287 = None | |
| + sigmoid_43 = torch.ops.aten.sigmoid.default(convert_element_type_761) | |
| + mul_290 = torch.ops.aten.mul.Tensor(convert_element_type_761, sigmoid_43) | |
| + mm_227 = torch.ops.aten.mm.default(permute_560, mul_290); permute_560 = mul_290 = None | |
| convert_element_type_764 = torch.ops.prims.convert_element_type.default(primals_415, torch.bfloat16); primals_415 = None | |
| permute_255 = torch.ops.aten.permute.default(convert_element_type_764, [1, 0]); convert_element_type_764 = None | |
| permute_562 = torch.ops.aten.permute.default(permute_255, [1, 0]); permute_255 = None | |
| - mm_228 = torch.ops.aten.mm.default(add_370, permute_562); add_370 = permute_562 = None | |
| - convert_element_type_891 = torch.ops.prims.convert_element_type.default(mm_227, torch.float32); mm_227 = None | |
| - convert_element_type_892 = torch.ops.prims.convert_element_type.default(mm_228, torch.float32); mm_228 = None | |
| - mul_692 = torch.ops.aten.mul.Tensor(add_369, mul_289); mul_289 = None | |
| - mul_693 = torch.ops.aten.mul.Tensor(add_369, sigmoid_44); add_369 = None | |
| - sub_196 = torch.ops.aten.sub.Tensor(1, sigmoid_44) | |
| - mul_694 = torch.ops.aten.mul.Tensor(sigmoid_44, sub_196); sigmoid_44 = sub_196 = None | |
| - mul_695 = torch.ops.aten.mul.Tensor(mul_692, mul_694); mul_692 = mul_694 = None | |
| - add_371 = torch.ops.aten.add.Tensor(mul_693, mul_695); mul_693 = mul_695 = None | |
| - mul_696 = torch.ops.aten.mul.Tensor(convert_element_type_892, mul_287); mul_287 = None | |
| - mul_697 = torch.ops.aten.mul.Tensor(convert_element_type_892, sigmoid_43); convert_element_type_892 = None | |
| - sub_197 = torch.ops.aten.sub.Tensor(1, sigmoid_43) | |
| - mul_698 = torch.ops.aten.mul.Tensor(sigmoid_43, sub_197); sigmoid_43 = sub_197 = None | |
| - mul_699 = torch.ops.aten.mul.Tensor(mul_696, mul_698); mul_696 = mul_698 = None | |
| - add_372 = torch.ops.aten.add.Tensor(mul_697, mul_699); mul_697 = mul_699 = None | |
| - mul_700 = torch.ops.aten.mul.Tensor(add_371, primals_414); primals_414 = None | |
| - mul_702 = torch.ops.aten.mul.Tensor(mul_288, mul_700) | |
| - sum_102 = torch.ops.aten.sum.dim_IntList(mul_702, [1], True); mul_702 = None | |
| - div_91 = torch.ops.aten.div.Tensor(mul_288, 3840) | |
| - mul_703 = torch.ops.aten.mul.Tensor(div_91, sum_102); div_91 = sum_102 = None | |
| - sub_198 = torch.ops.aten.sub.Tensor(mul_700, mul_703); mul_700 = mul_703 = None | |
| - mul_704 = torch.ops.aten.mul.Tensor(sub_198, rsqrt_91); sub_198 = rsqrt_91 = None | |
| - mul_705 = torch.ops.aten.mul.Tensor(add_371, mul_288); add_371 = mul_288 = None | |
| - sum_103 = torch.ops.aten.sum.dim_IntList(mul_705, [0]); mul_705 = None | |
| - mul_706 = torch.ops.aten.mul.Tensor(add_372, primals_413); primals_413 = None | |
| - mul_708 = torch.ops.aten.mul.Tensor(mul_286, mul_706) | |
| - sum_104 = torch.ops.aten.sum.dim_IntList(mul_708, [1], True); mul_708 = None | |
| - div_92 = torch.ops.aten.div.Tensor(mul_286, 3840) | |
| - mul_709 = torch.ops.aten.mul.Tensor(div_92, sum_104); div_92 = sum_104 = None | |
| - sub_199 = torch.ops.aten.sub.Tensor(mul_706, mul_709); mul_706 = mul_709 = None | |
| - mul_710 = torch.ops.aten.mul.Tensor(sub_199, rsqrt_90); sub_199 = rsqrt_90 = None | |
| - mul_711 = torch.ops.aten.mul.Tensor(add_372, mul_286); add_372 = mul_286 = None | |
| - sum_105 = torch.ops.aten.sum.dim_IntList(mul_711, [0]); mul_711 = None | |
| - convert_element_type_893 = torch.ops.prims.convert_element_type.default(mul_704, torch.bfloat16) | |
| - convert_element_type_894 = torch.ops.prims.convert_element_type.default(mul_710, torch.bfloat16) | |
| - index_put_20 = torch.ops.aten.index_put.default(full_default_254, [sub_32], convert_element_type_893, True) | |
| + mm_228 = torch.ops.aten.mm.default(add_373, permute_562); add_373 = permute_562 = None | |
| + convert_element_type_909 = torch.ops.prims.convert_element_type.default(mm_227, torch.float32); mm_227 = None | |
| + mul_698 = torch.ops.aten.mul.Tensor(add_372, convert_element_type_763); convert_element_type_763 = None | |
| + mul_699 = torch.ops.aten.mul.Tensor(add_372, sigmoid_44); add_372 = None | |
| + convert_element_type_910 = torch.ops.prims.convert_element_type.default(mul_698, torch.float32); mul_698 = None | |
| + convert_element_type_911 = torch.ops.prims.convert_element_type.default(sigmoid_44, torch.float32); sigmoid_44 = None | |
| + sub_193 = torch.ops.aten.sub.Tensor(1, convert_element_type_911) | |
| + mul_700 = torch.ops.aten.mul.Tensor(convert_element_type_911, sub_193); convert_element_type_911 = sub_193 = None | |
| + mul_701 = torch.ops.aten.mul.Tensor(convert_element_type_910, mul_700); convert_element_type_910 = mul_700 = None | |
| + convert_element_type_912 = torch.ops.prims.convert_element_type.default(mul_701, torch.bfloat16); mul_701 = None | |
| + add_374 = torch.ops.aten.add.Tensor(mul_699, convert_element_type_912); mul_699 = convert_element_type_912 = None | |
| + mul_702 = torch.ops.aten.mul.Tensor(mm_228, convert_element_type_761); convert_element_type_761 = None | |
| + mul_703 = torch.ops.aten.mul.Tensor(mm_228, sigmoid_43); mm_228 = None | |
| + convert_element_type_913 = torch.ops.prims.convert_element_type.default(mul_702, torch.float32); mul_702 = None | |
| + convert_element_type_914 = torch.ops.prims.convert_element_type.default(sigmoid_43, torch.float32); sigmoid_43 = None | |
| + sub_194 = torch.ops.aten.sub.Tensor(1, convert_element_type_914) | |
| + mul_704 = torch.ops.aten.mul.Tensor(convert_element_type_914, sub_194); convert_element_type_914 = sub_194 = None | |
| + mul_705 = torch.ops.aten.mul.Tensor(convert_element_type_913, mul_704); convert_element_type_913 = mul_704 = None | |
| + convert_element_type_915 = torch.ops.prims.convert_element_type.default(mul_705, torch.bfloat16); mul_705 = None | |
| + add_375 = torch.ops.aten.add.Tensor(mul_703, convert_element_type_915); mul_703 = convert_element_type_915 = None | |
| + convert_element_type_916 = torch.ops.prims.convert_element_type.default(add_374, torch.float32); add_374 = None | |
| + mul_706 = torch.ops.aten.mul.Tensor(convert_element_type_916, mul_288); mul_288 = None | |
| + mul_707 = torch.ops.aten.mul.Tensor(convert_element_type_916, primals_414); convert_element_type_916 = primals_414 = None | |
| + sum_102 = torch.ops.aten.sum.dim_IntList(mul_706, [0], True); mul_706 = None | |
| + view_424 = torch.ops.aten.view.default(sum_102, [3840]); sum_102 = None | |
| + mul_708 = torch.ops.aten.mul.Tensor(mul_707, convert_element_type_762) | |
| + mul_709 = torch.ops.aten.mul.Tensor(mul_707, rsqrt_91); mul_707 = None | |
| + sum_103 = torch.ops.aten.sum.dim_IntList(mul_708, [1], True); mul_708 = None | |
| + mul_710 = torch.ops.aten.mul.Scalar(sum_103, -0.5); sum_103 = None | |
| + pow_90 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_91, 3); rsqrt_91 = None | |
| + mul_711 = torch.ops.aten.mul.Tensor(mul_710, pow_90); mul_710 = pow_90 = None | |
| + expand_66 = torch.ops.aten.expand.default(mul_711, [4096, 3840]); mul_711 = None | |
| + div_91 = torch.ops.aten.div.Scalar(expand_66, 3840); expand_66 = None | |
| + pow_91 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_762, 1.0); convert_element_type_762 = None | |
| + mul_712 = torch.ops.aten.mul.Scalar(pow_91, 2.0); pow_91 = None | |
| + mul_713 = torch.ops.aten.mul.Tensor(div_91, mul_712); div_91 = mul_712 = None | |
| + add_376 = torch.ops.aten.add.Tensor(mul_709, mul_713); mul_709 = mul_713 = None | |
| + convert_element_type_917 = torch.ops.prims.convert_element_type.default(add_376, torch.bfloat16); add_376 = None | |
| + convert_element_type_918 = torch.ops.prims.convert_element_type.default(add_375, torch.float32); add_375 = None | |
| + mul_714 = torch.ops.aten.mul.Tensor(convert_element_type_918, mul_286); mul_286 = None | |
| + mul_715 = torch.ops.aten.mul.Tensor(convert_element_type_918, primals_413); convert_element_type_918 = primals_413 = None | |
| + sum_104 = torch.ops.aten.sum.dim_IntList(mul_714, [0], True); mul_714 = None | |
| + view_425 = torch.ops.aten.view.default(sum_104, [3840]); sum_104 = None | |
| + mul_716 = torch.ops.aten.mul.Tensor(mul_715, convert_element_type_760) | |
| + mul_717 = torch.ops.aten.mul.Tensor(mul_715, rsqrt_90); mul_715 = None | |
| + sum_105 = torch.ops.aten.sum.dim_IntList(mul_716, [1], True); mul_716 = None | |
| + mul_718 = torch.ops.aten.mul.Scalar(sum_105, -0.5); sum_105 = None | |
| + pow_92 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_90, 3); rsqrt_90 = None | |
| + mul_719 = torch.ops.aten.mul.Tensor(mul_718, pow_92); mul_718 = pow_92 = None | |
| + expand_67 = torch.ops.aten.expand.default(mul_719, [4096, 3840]); mul_719 = None | |
| + div_92 = torch.ops.aten.div.Scalar(expand_67, 3840); expand_67 = None | |
| + pow_93 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_760, 1.0); convert_element_type_760 = None | |
| + mul_720 = torch.ops.aten.mul.Scalar(pow_93, 2.0); pow_93 = None | |
| + mul_721 = torch.ops.aten.mul.Tensor(div_92, mul_720); div_92 = mul_720 = None | |
| + add_377 = torch.ops.aten.add.Tensor(mul_717, mul_721); mul_717 = mul_721 = None | |
| + convert_element_type_919 = torch.ops.prims.convert_element_type.default(add_377, torch.bfloat16); add_377 = None | |
| + index_put_20 = torch.ops.aten.index_put.default(full_default_254, [sub_32], convert_element_type_917, True) | |
| slice_scatter_3 = torch.ops.aten.slice_scatter.default(full_default_255, index_put_20, 1, 3840, 9223372036854775807); index_put_20 = None | |
| - permute_564 = torch.ops.aten.permute.default(convert_element_type_893, [1, 0]) | |
| + permute_564 = torch.ops.aten.permute.default(convert_element_type_917, [1, 0]) | |
| slice_76 = torch.ops.aten.slice.Tensor(mm_99, 1, 1920, 9223372036854775807) | |
| index_41 = torch.ops.aten.index.Tensor(slice_76, [sub_32]); slice_76 = None | |
| add_192 = torch.ops.aten.add.Tensor(mm_100, index_41); mm_100 = index_41 = None | |
| - convert_element_type_754 = torch.ops.prims.convert_element_type.default(add_192, torch.float32); add_192 = None | |
| - pow_69 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_754, 2) | |
| + convert_element_type_752 = torch.ops.prims.convert_element_type.default(add_192, torch.float32); add_192 = None | |
| + pow_69 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_752, 2) | |
| mean_68 = torch.ops.aten.mean.dim(pow_69, [1], True); pow_69 = None | |
| add_194 = torch.ops.aten.add.Scalar(mean_68, 1.1920928955078125e-07); mean_68 = None | |
| rsqrt_89 = torch.ops.aten.rsqrt.default(add_194); add_194 = None | |
| - mul_282 = torch.ops.aten.mul.Tensor(convert_element_type_754, rsqrt_89); convert_element_type_754 = None | |
| + mul_282 = torch.ops.aten.mul.Tensor(convert_element_type_752, rsqrt_89) | |
| mul_283 = torch.ops.aten.mul.Tensor(mul_282, primals_410) | |
| - sigmoid_42 = torch.ops.aten.sigmoid.default(mul_283) | |
| - mul_285 = torch.ops.aten.mul.Tensor(mul_283, sigmoid_42) | |
| - convert_element_type_759 = torch.ops.prims.convert_element_type.default(mul_285, torch.bfloat16); mul_285 = None | |
| - mm_229 = torch.ops.aten.mm.default(permute_564, convert_element_type_759); permute_564 = convert_element_type_759 = None | |
| - convert_element_type_760 = torch.ops.prims.convert_element_type.default(primals_412, torch.bfloat16); primals_412 = None | |
| - permute_254 = torch.ops.aten.permute.default(convert_element_type_760, [1, 0]); convert_element_type_760 = None | |
| + convert_element_type_753 = torch.ops.prims.convert_element_type.default(mul_283, torch.bfloat16); mul_283 = None | |
| + sigmoid_42 = torch.ops.aten.sigmoid.default(convert_element_type_753) | |
| + mul_285 = torch.ops.aten.mul.Tensor(convert_element_type_753, sigmoid_42) | |
| + mm_229 = torch.ops.aten.mm.default(permute_564, mul_285); permute_564 = mul_285 = None | |
| + convert_element_type_757 = torch.ops.prims.convert_element_type.default(primals_412, torch.bfloat16); primals_412 = None | |
| + permute_254 = torch.ops.aten.permute.default(convert_element_type_757, [1, 0]); convert_element_type_757 = None | |
| permute_566 = torch.ops.aten.permute.default(permute_254, [1, 0]); permute_254 = None | |
| - mm_230 = torch.ops.aten.mm.default(convert_element_type_893, permute_566); convert_element_type_893 = permute_566 = None | |
| - convert_element_type_899 = torch.ops.prims.convert_element_type.default(mm_229, torch.float32); mm_229 = None | |
| - convert_element_type_900 = torch.ops.prims.convert_element_type.default(mm_230, torch.float32); mm_230 = None | |
| - slice_scatter_4 = torch.ops.aten.slice_scatter.default(full_default_255, convert_element_type_894, 1, 0, 3840); convert_element_type_894 = None | |
| - add_373 = torch.ops.aten.add.Tensor(slice_scatter_3, slice_scatter_4); slice_scatter_3 = slice_scatter_4 = None | |
| - permute_568 = torch.ops.aten.permute.default(add_373, [1, 0]) | |
| + mm_230 = torch.ops.aten.mm.default(convert_element_type_917, permute_566); permute_566 = None | |
| + convert_element_type_924 = torch.ops.prims.convert_element_type.default(mm_229, torch.float32); mm_229 = None | |
| + slice_scatter_4 = torch.ops.aten.slice_scatter.default(full_default_255, convert_element_type_919, 1, 0, 3840) | |
| + add_378 = torch.ops.aten.add.Tensor(slice_scatter_3, slice_scatter_4); slice_scatter_3 = slice_scatter_4 = None | |
| + permute_568 = torch.ops.aten.permute.default(add_378, [1, 0]) | |
| slice_75 = torch.ops.aten.slice.Tensor(mm_99, 1, 0, 1920); mm_99 = None | |
| - convert_element_type_753 = torch.ops.prims.convert_element_type.default(slice_75, torch.float32); slice_75 = None | |
| - pow_68 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_753, 2) | |
| + convert_element_type_750 = torch.ops.prims.convert_element_type.default(slice_75, torch.float32); slice_75 = None | |
| + pow_68 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_750, 2) | |
| mean_67 = torch.ops.aten.mean.dim(pow_68, [1], True); pow_68 = None | |
| add_193 = torch.ops.aten.add.Scalar(mean_67, 1.1920928955078125e-07); mean_67 = None | |
| rsqrt_88 = torch.ops.aten.rsqrt.default(add_193); add_193 = None | |
| - mul_280 = torch.ops.aten.mul.Tensor(convert_element_type_753, rsqrt_88); convert_element_type_753 = None | |
| + mul_280 = torch.ops.aten.mul.Tensor(convert_element_type_750, rsqrt_88) | |
| mul_281 = torch.ops.aten.mul.Tensor(mul_280, primals_409) | |
| - sigmoid_41 = torch.ops.aten.sigmoid.default(mul_281) | |
| - mul_284 = torch.ops.aten.mul.Tensor(mul_281, sigmoid_41) | |
| - convert_element_type_755 = torch.ops.prims.convert_element_type.default(mul_284, torch.bfloat16); mul_284 = None | |
| - mm_231 = torch.ops.aten.mm.default(permute_568, convert_element_type_755); permute_568 = convert_element_type_755 = None | |
| - convert_element_type_756 = torch.ops.prims.convert_element_type.default(primals_411, torch.bfloat16); primals_411 = None | |
| - permute_253 = torch.ops.aten.permute.default(convert_element_type_756, [1, 0]); convert_element_type_756 = None | |
| + convert_element_type_751 = torch.ops.prims.convert_element_type.default(mul_281, torch.bfloat16); mul_281 = None | |
| + sigmoid_41 = torch.ops.aten.sigmoid.default(convert_element_type_751) | |
| + mul_284 = torch.ops.aten.mul.Tensor(convert_element_type_751, sigmoid_41) | |
| + mm_231 = torch.ops.aten.mm.default(permute_568, mul_284); permute_568 = mul_284 = None | |
| + convert_element_type_754 = torch.ops.prims.convert_element_type.default(primals_411, torch.bfloat16); primals_411 = None | |
| + permute_253 = torch.ops.aten.permute.default(convert_element_type_754, [1, 0]); convert_element_type_754 = None | |
| permute_570 = torch.ops.aten.permute.default(permute_253, [1, 0]); permute_253 = None | |
| - mm_232 = torch.ops.aten.mm.default(add_373, permute_570); add_373 = permute_570 = None | |
| - convert_element_type_905 = torch.ops.prims.convert_element_type.default(mm_231, torch.float32); mm_231 = None | |
| - convert_element_type_906 = torch.ops.prims.convert_element_type.default(mm_232, torch.float32); mm_232 = None | |
| - mul_712 = torch.ops.aten.mul.Tensor(convert_element_type_900, mul_283); mul_283 = None | |
| - mul_713 = torch.ops.aten.mul.Tensor(convert_element_type_900, sigmoid_42); convert_element_type_900 = None | |
| - sub_200 = torch.ops.aten.sub.Tensor(1, sigmoid_42) | |
| - mul_714 = torch.ops.aten.mul.Tensor(sigmoid_42, sub_200); sigmoid_42 = sub_200 = None | |
| - mul_715 = torch.ops.aten.mul.Tensor(mul_712, mul_714); mul_712 = mul_714 = None | |
| - add_374 = torch.ops.aten.add.Tensor(mul_713, mul_715); mul_713 = mul_715 = None | |
| - mul_716 = torch.ops.aten.mul.Tensor(convert_element_type_906, mul_281); mul_281 = None | |
| - mul_717 = torch.ops.aten.mul.Tensor(convert_element_type_906, sigmoid_41); convert_element_type_906 = None | |
| - sub_201 = torch.ops.aten.sub.Tensor(1, sigmoid_41) | |
| - mul_718 = torch.ops.aten.mul.Tensor(sigmoid_41, sub_201); sigmoid_41 = sub_201 = None | |
| - mul_719 = torch.ops.aten.mul.Tensor(mul_716, mul_718); mul_716 = mul_718 = None | |
| - add_375 = torch.ops.aten.add.Tensor(mul_717, mul_719); mul_717 = mul_719 = None | |
| - mul_720 = torch.ops.aten.mul.Tensor(add_374, primals_410); primals_410 = None | |
| - mul_722 = torch.ops.aten.mul.Tensor(mul_282, mul_720) | |
| - sum_106 = torch.ops.aten.sum.dim_IntList(mul_722, [1], True); mul_722 = None | |
| - div_93 = torch.ops.aten.div.Tensor(mul_282, 1920) | |
| - mul_723 = torch.ops.aten.mul.Tensor(div_93, sum_106); div_93 = sum_106 = None | |
| - sub_202 = torch.ops.aten.sub.Tensor(mul_720, mul_723); mul_720 = mul_723 = None | |
| - mul_724 = torch.ops.aten.mul.Tensor(sub_202, rsqrt_89); sub_202 = rsqrt_89 = None | |
| - mul_725 = torch.ops.aten.mul.Tensor(add_374, mul_282); add_374 = mul_282 = None | |
| - sum_107 = torch.ops.aten.sum.dim_IntList(mul_725, [0]); mul_725 = None | |
| - convert_element_type_907 = torch.ops.prims.convert_element_type.default(mul_724, torch.bfloat16); mul_724 = None | |
| - mul_726 = torch.ops.aten.mul.Tensor(add_375, primals_409); primals_409 = None | |
| - mul_728 = torch.ops.aten.mul.Tensor(mul_280, mul_726) | |
| - sum_108 = torch.ops.aten.sum.dim_IntList(mul_728, [1], True); mul_728 = None | |
| - div_94 = torch.ops.aten.div.Tensor(mul_280, 1920) | |
| - mul_729 = torch.ops.aten.mul.Tensor(div_94, sum_108); div_94 = sum_108 = None | |
| - sub_203 = torch.ops.aten.sub.Tensor(mul_726, mul_729); mul_726 = mul_729 = None | |
| - mul_730 = torch.ops.aten.mul.Tensor(sub_203, rsqrt_88); sub_203 = rsqrt_88 = None | |
| - mul_731 = torch.ops.aten.mul.Tensor(add_375, mul_280); add_375 = mul_280 = None | |
| - sum_109 = torch.ops.aten.sum.dim_IntList(mul_731, [0]); mul_731 = None | |
| - convert_element_type_908 = torch.ops.prims.convert_element_type.default(mul_730, torch.bfloat16); mul_730 = None | |
| + mm_232 = torch.ops.aten.mm.default(add_378, permute_570); add_378 = permute_570 = None | |
| + convert_element_type_929 = torch.ops.prims.convert_element_type.default(mm_231, torch.float32); mm_231 = None | |
| + mul_722 = torch.ops.aten.mul.Tensor(mm_230, convert_element_type_753); convert_element_type_753 = None | |
| + mul_723 = torch.ops.aten.mul.Tensor(mm_230, sigmoid_42); mm_230 = None | |
| + convert_element_type_930 = torch.ops.prims.convert_element_type.default(mul_722, torch.float32); mul_722 = None | |
| + convert_element_type_931 = torch.ops.prims.convert_element_type.default(sigmoid_42, torch.float32); sigmoid_42 = None | |
| + sub_195 = torch.ops.aten.sub.Tensor(1, convert_element_type_931) | |
| + mul_724 = torch.ops.aten.mul.Tensor(convert_element_type_931, sub_195); convert_element_type_931 = sub_195 = None | |
| + mul_725 = torch.ops.aten.mul.Tensor(convert_element_type_930, mul_724); convert_element_type_930 = mul_724 = None | |
| + convert_element_type_932 = torch.ops.prims.convert_element_type.default(mul_725, torch.bfloat16); mul_725 = None | |
| + add_379 = torch.ops.aten.add.Tensor(mul_723, convert_element_type_932); mul_723 = convert_element_type_932 = None | |
| + mul_726 = torch.ops.aten.mul.Tensor(mm_232, convert_element_type_751); convert_element_type_751 = None | |
| + mul_727 = torch.ops.aten.mul.Tensor(mm_232, sigmoid_41); mm_232 = None | |
| + convert_element_type_933 = torch.ops.prims.convert_element_type.default(mul_726, torch.float32); mul_726 = None | |
| + convert_element_type_934 = torch.ops.prims.convert_element_type.default(sigmoid_41, torch.float32); sigmoid_41 = None | |
| + sub_196 = torch.ops.aten.sub.Tensor(1, convert_element_type_934) | |
| + mul_728 = torch.ops.aten.mul.Tensor(convert_element_type_934, sub_196); convert_element_type_934 = sub_196 = None | |
| + mul_729 = torch.ops.aten.mul.Tensor(convert_element_type_933, mul_728); convert_element_type_933 = mul_728 = None | |
| + convert_element_type_935 = torch.ops.prims.convert_element_type.default(mul_729, torch.bfloat16); mul_729 = None | |
| + add_380 = torch.ops.aten.add.Tensor(mul_727, convert_element_type_935); mul_727 = convert_element_type_935 = None | |
| + convert_element_type_936 = torch.ops.prims.convert_element_type.default(add_379, torch.float32); add_379 = None | |
| + mul_730 = torch.ops.aten.mul.Tensor(convert_element_type_936, mul_282); mul_282 = None | |
| + mul_731 = torch.ops.aten.mul.Tensor(convert_element_type_936, primals_410); convert_element_type_936 = primals_410 = None | |
| + sum_106 = torch.ops.aten.sum.dim_IntList(mul_730, [0], True); mul_730 = None | |
| + view_426 = torch.ops.aten.view.default(sum_106, [1920]); sum_106 = None | |
| + mul_732 = torch.ops.aten.mul.Tensor(mul_731, convert_element_type_752) | |
| + mul_733 = torch.ops.aten.mul.Tensor(mul_731, rsqrt_89); mul_731 = None | |
| + sum_107 = torch.ops.aten.sum.dim_IntList(mul_732, [1], True); mul_732 = None | |
| + mul_734 = torch.ops.aten.mul.Scalar(sum_107, -0.5); sum_107 = None | |
| + pow_94 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_89, 3); rsqrt_89 = None | |
| + mul_735 = torch.ops.aten.mul.Tensor(mul_734, pow_94); mul_734 = pow_94 = None | |
| + expand_68 = torch.ops.aten.expand.default(mul_735, [4096, 1920]); mul_735 = None | |
| + div_93 = torch.ops.aten.div.Scalar(expand_68, 1920); expand_68 = None | |
| + pow_95 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_752, 1.0); convert_element_type_752 = None | |
| + mul_736 = torch.ops.aten.mul.Scalar(pow_95, 2.0); pow_95 = None | |
| + mul_737 = torch.ops.aten.mul.Tensor(div_93, mul_736); div_93 = mul_736 = None | |
| + add_381 = torch.ops.aten.add.Tensor(mul_733, mul_737); mul_733 = mul_737 = None | |
| + convert_element_type_937 = torch.ops.prims.convert_element_type.default(add_381, torch.bfloat16); add_381 = None | |
| + convert_element_type_938 = torch.ops.prims.convert_element_type.default(add_380, torch.float32); add_380 = None | |
| + mul_738 = torch.ops.aten.mul.Tensor(convert_element_type_938, mul_280); mul_280 = None | |
| + mul_739 = torch.ops.aten.mul.Tensor(convert_element_type_938, primals_409); convert_element_type_938 = primals_409 = None | |
| + sum_108 = torch.ops.aten.sum.dim_IntList(mul_738, [0], True); mul_738 = None | |
| + view_427 = torch.ops.aten.view.default(sum_108, [1920]); sum_108 = None | |
| + mul_740 = torch.ops.aten.mul.Tensor(mul_739, convert_element_type_750) | |
| + mul_741 = torch.ops.aten.mul.Tensor(mul_739, rsqrt_88); mul_739 = None | |
| + sum_109 = torch.ops.aten.sum.dim_IntList(mul_740, [1], True); mul_740 = None | |
| + mul_742 = torch.ops.aten.mul.Scalar(sum_109, -0.5); sum_109 = None | |
| + pow_96 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_88, 3); rsqrt_88 = None | |
| + mul_743 = torch.ops.aten.mul.Tensor(mul_742, pow_96); mul_742 = pow_96 = None | |
| + expand_69 = torch.ops.aten.expand.default(mul_743, [4096, 1920]); mul_743 = None | |
| + div_94 = torch.ops.aten.div.Scalar(expand_69, 1920); expand_69 = None | |
| + pow_97 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_750, 1.0); convert_element_type_750 = None | |
| + mul_744 = torch.ops.aten.mul.Scalar(pow_97, 2.0); pow_97 = None | |
| + mul_745 = torch.ops.aten.mul.Tensor(div_94, mul_744); div_94 = mul_744 = None | |
| + add_382 = torch.ops.aten.add.Tensor(mul_741, mul_745); mul_741 = mul_745 = None | |
| + convert_element_type_939 = torch.ops.prims.convert_element_type.default(add_382, torch.bfloat16); add_382 = None | |
| full_default_262 = torch.ops.aten.full.default([4096, 1920], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| - index_put_21 = torch.ops.aten.index_put.default(full_default_262, [sub_32], convert_element_type_907, True); full_default_262 = None | |
| + index_put_21 = torch.ops.aten.index_put.default(full_default_262, [sub_32], convert_element_type_937, True); full_default_262 = None | |
| slice_scatter_5 = torch.ops.aten.slice_scatter.default(full_default_254, index_put_21, 1, 1920, 9223372036854775807); index_put_21 = None | |
| - permute_572 = torch.ops.aten.permute.default(convert_element_type_907, [1, 0]) | |
| - convert_element_type_749 = torch.ops.prims.convert_element_type.default(mul_279, torch.bfloat16); mul_279 = None | |
| - mm_233 = torch.ops.aten.mm.default(permute_572, convert_element_type_749); permute_572 = convert_element_type_749 = None | |
| - convert_element_type_750 = torch.ops.prims.convert_element_type.default(primals_408, torch.bfloat16); primals_408 = None | |
| - permute_252 = torch.ops.aten.permute.default(convert_element_type_750, [1, 0]); convert_element_type_750 = None | |
| + permute_572 = torch.ops.aten.permute.default(convert_element_type_937, [1, 0]) | |
| + mm_233 = torch.ops.aten.mm.default(permute_572, mul_279); permute_572 = mul_279 = None | |
| + convert_element_type_747 = torch.ops.prims.convert_element_type.default(primals_408, torch.bfloat16); primals_408 = None | |
| + permute_252 = torch.ops.aten.permute.default(convert_element_type_747, [1, 0]); convert_element_type_747 = None | |
| permute_574 = torch.ops.aten.permute.default(permute_252, [1, 0]); permute_252 = None | |
| - mm_234 = torch.ops.aten.mm.default(convert_element_type_907, permute_574); convert_element_type_907 = permute_574 = None | |
| - convert_element_type_913 = torch.ops.prims.convert_element_type.default(mm_233, torch.float32); mm_233 = None | |
| - convert_element_type_914 = torch.ops.prims.convert_element_type.default(mm_234, torch.float32); mm_234 = None | |
| - add_376 = torch.ops.aten.add.Tensor(mul_704, convert_element_type_914); mul_704 = convert_element_type_914 = None | |
| - slice_scatter_6 = torch.ops.aten.slice_scatter.default(full_default_254, convert_element_type_908, 1, 0, 1920); convert_element_type_908 = None | |
| - add_377 = torch.ops.aten.add.Tensor(slice_scatter_5, slice_scatter_6); slice_scatter_5 = slice_scatter_6 = None | |
| - permute_576 = torch.ops.aten.permute.default(add_377, [1, 0]) | |
| - convert_element_type_745 = torch.ops.prims.convert_element_type.default(mul_278, torch.bfloat16); mul_278 = None | |
| - mm_235 = torch.ops.aten.mm.default(permute_576, convert_element_type_745); permute_576 = convert_element_type_745 = None | |
| - convert_element_type_746 = torch.ops.prims.convert_element_type.default(primals_407, torch.bfloat16); primals_407 = None | |
| - permute_251 = torch.ops.aten.permute.default(convert_element_type_746, [1, 0]); convert_element_type_746 = None | |
| + mm_234 = torch.ops.aten.mm.default(convert_element_type_937, permute_574); convert_element_type_937 = permute_574 = None | |
| + add_383 = torch.ops.aten.add.Tensor(convert_element_type_917, mm_234); convert_element_type_917 = mm_234 = None | |
| + convert_element_type_944 = torch.ops.prims.convert_element_type.default(mm_233, torch.float32); mm_233 = None | |
| + slice_scatter_6 = torch.ops.aten.slice_scatter.default(full_default_254, convert_element_type_939, 1, 0, 1920); convert_element_type_939 = None | |
| + add_384 = torch.ops.aten.add.Tensor(slice_scatter_5, slice_scatter_6); slice_scatter_5 = slice_scatter_6 = None | |
| + permute_576 = torch.ops.aten.permute.default(add_384, [1, 0]) | |
| + mm_235 = torch.ops.aten.mm.default(permute_576, mul_278); permute_576 = mul_278 = None | |
| + convert_element_type_744 = torch.ops.prims.convert_element_type.default(primals_407, torch.bfloat16); primals_407 = None | |
| + permute_251 = torch.ops.aten.permute.default(convert_element_type_744, [1, 0]); convert_element_type_744 = None | |
| permute_578 = torch.ops.aten.permute.default(permute_251, [1, 0]); permute_251 = None | |
| - mm_236 = torch.ops.aten.mm.default(add_377, permute_578); add_377 = permute_578 = None | |
| - convert_element_type_919 = torch.ops.prims.convert_element_type.default(mm_235, torch.float32); mm_235 = None | |
| - convert_element_type_920 = torch.ops.prims.convert_element_type.default(mm_236, torch.float32); mm_236 = None | |
| - add_378 = torch.ops.aten.add.Tensor(mul_710, convert_element_type_920); mul_710 = convert_element_type_920 = None | |
| - mul_732 = torch.ops.aten.mul.Tensor(add_376, mul_277); mul_277 = None | |
| - mul_733 = torch.ops.aten.mul.Tensor(add_376, sigmoid_40); add_376 = None | |
| - sub_204 = torch.ops.aten.sub.Tensor(1, sigmoid_40) | |
| - mul_734 = torch.ops.aten.mul.Tensor(sigmoid_40, sub_204); sigmoid_40 = sub_204 = None | |
| - mul_735 = torch.ops.aten.mul.Tensor(mul_732, mul_734); mul_732 = mul_734 = None | |
| - add_379 = torch.ops.aten.add.Tensor(mul_733, mul_735); mul_733 = mul_735 = None | |
| - mul_736 = torch.ops.aten.mul.Tensor(add_378, mul_275); mul_275 = None | |
| - mul_737 = torch.ops.aten.mul.Tensor(add_378, sigmoid_39); add_378 = None | |
| - sub_205 = torch.ops.aten.sub.Tensor(1, sigmoid_39) | |
| - mul_738 = torch.ops.aten.mul.Tensor(sigmoid_39, sub_205); sigmoid_39 = sub_205 = None | |
| - mul_739 = torch.ops.aten.mul.Tensor(mul_736, mul_738); mul_736 = mul_738 = None | |
| - add_380 = torch.ops.aten.add.Tensor(mul_737, mul_739); mul_737 = mul_739 = None | |
| - mul_740 = torch.ops.aten.mul.Tensor(add_379, primals_406); primals_406 = None | |
| - mul_742 = torch.ops.aten.mul.Tensor(mul_276, mul_740) | |
| - sum_110 = torch.ops.aten.sum.dim_IntList(mul_742, [1], True); mul_742 = None | |
| - div_95 = torch.ops.aten.div.Tensor(mul_276, 3840) | |
| - mul_743 = torch.ops.aten.mul.Tensor(div_95, sum_110); div_95 = sum_110 = None | |
| - sub_206 = torch.ops.aten.sub.Tensor(mul_740, mul_743); mul_740 = mul_743 = None | |
| - mul_744 = torch.ops.aten.mul.Tensor(sub_206, rsqrt_87); sub_206 = rsqrt_87 = None | |
| - mul_745 = torch.ops.aten.mul.Tensor(add_379, mul_276); add_379 = mul_276 = None | |
| - sum_111 = torch.ops.aten.sum.dim_IntList(mul_745, [0]); mul_745 = None | |
| - convert_element_type_921 = torch.ops.prims.convert_element_type.default(mul_744, torch.bfloat16); mul_744 = None | |
| - mul_746 = torch.ops.aten.mul.Tensor(add_380, primals_405); primals_405 = None | |
| - mul_748 = torch.ops.aten.mul.Tensor(mul_274, mul_746) | |
| - sum_112 = torch.ops.aten.sum.dim_IntList(mul_748, [1], True); mul_748 = None | |
| - div_96 = torch.ops.aten.div.Tensor(mul_274, 3840) | |
| - mul_749 = torch.ops.aten.mul.Tensor(div_96, sum_112); div_96 = sum_112 = None | |
| - sub_207 = torch.ops.aten.sub.Tensor(mul_746, mul_749); mul_746 = mul_749 = None | |
| - mul_750 = torch.ops.aten.mul.Tensor(sub_207, rsqrt_86); sub_207 = rsqrt_86 = None | |
| - mul_751 = torch.ops.aten.mul.Tensor(add_380, mul_274); add_380 = mul_274 = None | |
| - sum_113 = torch.ops.aten.sum.dim_IntList(mul_751, [0]); mul_751 = None | |
| - convert_element_type_922 = torch.ops.prims.convert_element_type.default(mul_750, torch.bfloat16); mul_750 = None | |
| - index_put_22 = torch.ops.aten.index_put.default(full_default_254, [sub_32], convert_element_type_921, True); full_default_254 = None | |
| + mm_236 = torch.ops.aten.mm.default(add_384, permute_578); add_384 = permute_578 = None | |
| + add_385 = torch.ops.aten.add.Tensor(convert_element_type_919, mm_236); convert_element_type_919 = mm_236 = None | |
| + convert_element_type_949 = torch.ops.prims.convert_element_type.default(mm_235, torch.float32); mm_235 = None | |
| + mul_746 = torch.ops.aten.mul.Tensor(add_383, convert_element_type_743); convert_element_type_743 = None | |
| + mul_747 = torch.ops.aten.mul.Tensor(add_383, sigmoid_40); add_383 = None | |
| + convert_element_type_950 = torch.ops.prims.convert_element_type.default(mul_746, torch.float32); mul_746 = None | |
| + convert_element_type_951 = torch.ops.prims.convert_element_type.default(sigmoid_40, torch.float32); sigmoid_40 = None | |
| + sub_197 = torch.ops.aten.sub.Tensor(1, convert_element_type_951) | |
| + mul_748 = torch.ops.aten.mul.Tensor(convert_element_type_951, sub_197); convert_element_type_951 = sub_197 = None | |
| + mul_749 = torch.ops.aten.mul.Tensor(convert_element_type_950, mul_748); convert_element_type_950 = mul_748 = None | |
| + convert_element_type_952 = torch.ops.prims.convert_element_type.default(mul_749, torch.bfloat16); mul_749 = None | |
| + add_386 = torch.ops.aten.add.Tensor(mul_747, convert_element_type_952); mul_747 = convert_element_type_952 = None | |
| + mul_750 = torch.ops.aten.mul.Tensor(add_385, convert_element_type_741); convert_element_type_741 = None | |
| + mul_751 = torch.ops.aten.mul.Tensor(add_385, sigmoid_39); add_385 = None | |
| + convert_element_type_953 = torch.ops.prims.convert_element_type.default(mul_750, torch.float32); mul_750 = None | |
| + convert_element_type_954 = torch.ops.prims.convert_element_type.default(sigmoid_39, torch.float32); sigmoid_39 = None | |
| + sub_198 = torch.ops.aten.sub.Tensor(1, convert_element_type_954) | |
| + mul_752 = torch.ops.aten.mul.Tensor(convert_element_type_954, sub_198); convert_element_type_954 = sub_198 = None | |
| + mul_753 = torch.ops.aten.mul.Tensor(convert_element_type_953, mul_752); convert_element_type_953 = mul_752 = None | |
| + convert_element_type_955 = torch.ops.prims.convert_element_type.default(mul_753, torch.bfloat16); mul_753 = None | |
| + add_387 = torch.ops.aten.add.Tensor(mul_751, convert_element_type_955); mul_751 = convert_element_type_955 = None | |
| + convert_element_type_956 = torch.ops.prims.convert_element_type.default(add_386, torch.float32); add_386 = None | |
| + mul_754 = torch.ops.aten.mul.Tensor(convert_element_type_956, mul_276); mul_276 = None | |
| + mul_755 = torch.ops.aten.mul.Tensor(convert_element_type_956, primals_406); convert_element_type_956 = primals_406 = None | |
| + sum_110 = torch.ops.aten.sum.dim_IntList(mul_754, [0], True); mul_754 = None | |
| + view_428 = torch.ops.aten.view.default(sum_110, [3840]); sum_110 = None | |
| + mul_756 = torch.ops.aten.mul.Tensor(mul_755, convert_element_type_742) | |
| + mul_757 = torch.ops.aten.mul.Tensor(mul_755, rsqrt_87); mul_755 = None | |
| + sum_111 = torch.ops.aten.sum.dim_IntList(mul_756, [1], True); mul_756 = None | |
| + mul_758 = torch.ops.aten.mul.Scalar(sum_111, -0.5); sum_111 = None | |
| + pow_98 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_87, 3); rsqrt_87 = None | |
| + mul_759 = torch.ops.aten.mul.Tensor(mul_758, pow_98); mul_758 = pow_98 = None | |
| + expand_70 = torch.ops.aten.expand.default(mul_759, [4096, 3840]); mul_759 = None | |
| + div_95 = torch.ops.aten.div.Scalar(expand_70, 3840); expand_70 = None | |
| + pow_99 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_742, 1.0); convert_element_type_742 = None | |
| + mul_760 = torch.ops.aten.mul.Scalar(pow_99, 2.0); pow_99 = None | |
| + mul_761 = torch.ops.aten.mul.Tensor(div_95, mul_760); div_95 = mul_760 = None | |
| + add_388 = torch.ops.aten.add.Tensor(mul_757, mul_761); mul_757 = mul_761 = None | |
| + convert_element_type_957 = torch.ops.prims.convert_element_type.default(add_388, torch.bfloat16); add_388 = None | |
| + convert_element_type_958 = torch.ops.prims.convert_element_type.default(add_387, torch.float32); add_387 = None | |
| + mul_762 = torch.ops.aten.mul.Tensor(convert_element_type_958, mul_274); mul_274 = None | |
| + mul_763 = torch.ops.aten.mul.Tensor(convert_element_type_958, primals_405); convert_element_type_958 = primals_405 = None | |
| + sum_112 = torch.ops.aten.sum.dim_IntList(mul_762, [0], True); mul_762 = None | |
| + view_429 = torch.ops.aten.view.default(sum_112, [3840]); sum_112 = None | |
| + mul_764 = torch.ops.aten.mul.Tensor(mul_763, convert_element_type_740) | |
| + mul_765 = torch.ops.aten.mul.Tensor(mul_763, rsqrt_86); mul_763 = None | |
| + sum_113 = torch.ops.aten.sum.dim_IntList(mul_764, [1], True); mul_764 = None | |
| + mul_766 = torch.ops.aten.mul.Scalar(sum_113, -0.5); sum_113 = None | |
| + pow_100 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_86, 3); rsqrt_86 = None | |
| + mul_767 = torch.ops.aten.mul.Tensor(mul_766, pow_100); mul_766 = pow_100 = None | |
| + expand_71 = torch.ops.aten.expand.default(mul_767, [4096, 3840]); mul_767 = None | |
| + div_96 = torch.ops.aten.div.Scalar(expand_71, 3840); expand_71 = None | |
| + pow_101 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_740, 1.0); convert_element_type_740 = None | |
| + mul_768 = torch.ops.aten.mul.Scalar(pow_101, 2.0); pow_101 = None | |
| + mul_769 = torch.ops.aten.mul.Tensor(div_96, mul_768); div_96 = mul_768 = None | |
| + add_389 = torch.ops.aten.add.Tensor(mul_765, mul_769); mul_765 = mul_769 = None | |
| + convert_element_type_959 = torch.ops.prims.convert_element_type.default(add_389, torch.bfloat16); add_389 = None | |
| + index_put_22 = torch.ops.aten.index_put.default(full_default_254, [sub_32], convert_element_type_957, True); full_default_254 = None | |
| slice_scatter_7 = torch.ops.aten.slice_scatter.default(full_default_255, index_put_22, 1, 3840, 9223372036854775807); index_put_22 = None | |
| - abs_48 = torch.ops.aten.abs.default(convert_element_type_921) | |
| + abs_48 = torch.ops.aten.abs.default(convert_element_type_957) | |
| amax_21 = torch.ops.aten.amax.default(abs_48, [-1], True); abs_48 = None | |
| - convert_element_type_923 = torch.ops.prims.convert_element_type.default(amax_21, torch.float64); amax_21 = None | |
| - clamp_min_74 = torch.ops.aten.clamp_min.default(convert_element_type_923, 1e-12); convert_element_type_923 = None | |
| + convert_element_type_960 = torch.ops.prims.convert_element_type.default(amax_21, torch.float64); amax_21 = None | |
| + clamp_min_74 = torch.ops.aten.clamp_min.default(convert_element_type_960, 1e-12); convert_element_type_960 = None | |
| reciprocal_40 = torch.ops.aten.reciprocal.default(clamp_min_74); clamp_min_74 = None | |
| - mul_752 = torch.ops.aten.mul.Tensor(reciprocal_40, 448.0); reciprocal_40 = None | |
| - convert_element_type_924 = torch.ops.prims.convert_element_type.default(mul_752, torch.float32); mul_752 = None | |
| - log2_20 = torch.ops.aten.log2.default(convert_element_type_924); convert_element_type_924 = None | |
| + mul_770 = torch.ops.aten.mul.Tensor(reciprocal_40, 448.0); reciprocal_40 = None | |
| + convert_element_type_961 = torch.ops.prims.convert_element_type.default(mul_770, torch.float32); mul_770 = None | |
| + log2_20 = torch.ops.aten.log2.default(convert_element_type_961); convert_element_type_961 = None | |
| floor_20 = torch.ops.aten.floor.default(log2_20); log2_20 = None | |
| exp2_20 = torch.ops.aten.exp2.default(floor_20); floor_20 = None | |
| - convert_element_type_925 = torch.ops.prims.convert_element_type.default(convert_element_type_921, torch.float32) | |
| - mul_753 = torch.ops.aten.mul.Tensor(convert_element_type_925, exp2_20); convert_element_type_925 = None | |
| - clamp_min_75 = torch.ops.aten.clamp_min.default(mul_753, -448.0); mul_753 = None | |
| + convert_element_type_962 = torch.ops.prims.convert_element_type.default(convert_element_type_957, torch.float32) | |
| + mul_771 = torch.ops.aten.mul.Tensor(convert_element_type_962, exp2_20); convert_element_type_962 = None | |
| + clamp_min_75 = torch.ops.aten.clamp_min.default(mul_771, -448.0); mul_771 = None | |
| clamp_max_50 = torch.ops.aten.clamp_max.default(clamp_min_75, 448.0); clamp_min_75 = None | |
| - convert_element_type_926 = torch.ops.prims.convert_element_type.default(clamp_max_50, torch.float8_e4m3fn); clamp_max_50 = None | |
| + convert_element_type_963 = torch.ops.prims.convert_element_type.default(clamp_max_50, torch.float8_e4m3fn); clamp_max_50 = None | |
| permute_250 = torch.ops.aten.permute.default(primals_404, [1, 0]); primals_404 = None | |
| abs_20 = torch.ops.aten.abs.default(permute_250) | |
| max_1 = torch.ops.aten.max.default(abs_20); abs_20 = None | |
| - convert_element_type_927 = torch.ops.prims.convert_element_type.default(max_1, torch.float64); max_1 = None | |
| - clamp_min_76 = torch.ops.aten.clamp_min.default(convert_element_type_927, 1e-12); convert_element_type_927 = None | |
| + convert_element_type_964 = torch.ops.prims.convert_element_type.default(max_1, torch.float64); max_1 = None | |
| + clamp_min_76 = torch.ops.aten.clamp_min.default(convert_element_type_964, 1e-12); convert_element_type_964 = None | |
| reciprocal_41 = torch.ops.aten.reciprocal.default(clamp_min_76); clamp_min_76 = None | |
| - mul_754 = torch.ops.aten.mul.Tensor(reciprocal_41, 448.0); reciprocal_41 = None | |
| - convert_element_type_928 = torch.ops.prims.convert_element_type.default(mul_754, torch.float32); mul_754 = None | |
| - log2_21 = torch.ops.aten.log2.default(convert_element_type_928); convert_element_type_928 = None | |
| + mul_772 = torch.ops.aten.mul.Tensor(reciprocal_41, 448.0); reciprocal_41 = None | |
| + convert_element_type_965 = torch.ops.prims.convert_element_type.default(mul_772, torch.float32); mul_772 = None | |
| + log2_21 = torch.ops.aten.log2.default(convert_element_type_965); convert_element_type_965 = None | |
| floor_21 = torch.ops.aten.floor.default(log2_21); log2_21 = None | |
| exp2_21 = torch.ops.aten.exp2.default(floor_21); floor_21 = None | |
| - mul_755 = torch.ops.aten.mul.Tensor(permute_250, exp2_21); permute_250 = None | |
| - clamp_min_77 = torch.ops.aten.clamp_min.default(mul_755, -448.0); mul_755 = None | |
| + mul_773 = torch.ops.aten.mul.Tensor(permute_250, exp2_21); permute_250 = None | |
| + clamp_min_77 = torch.ops.aten.clamp_min.default(mul_773, -448.0); mul_773 = None | |
| clamp_max_51 = torch.ops.aten.clamp_max.default(clamp_min_77, 448.0); clamp_min_77 = None | |
| - convert_element_type_929 = torch.ops.prims.convert_element_type.default(clamp_max_51, torch.float8_e4m3fn); clamp_max_51 = None | |
| - clone_86 = torch.ops.aten.clone.default(convert_element_type_929, memory_format = torch.contiguous_format); convert_element_type_929 = None | |
| + convert_element_type_966 = torch.ops.prims.convert_element_type.default(clamp_max_51, torch.float8_e4m3fn); clamp_max_51 = None | |
| + clone_86 = torch.ops.aten.clone.default(convert_element_type_966, memory_format = torch.contiguous_format); convert_element_type_966 = None | |
| permute_582 = torch.ops.aten.permute.default(clone_86, [1, 0]); clone_86 = None | |
| repeat_4 = torch.ops.aten.repeat.default(exp2_21, [25600]); exp2_21 = None | |
| - view_422 = torch.ops.aten.view.default(repeat_4, [1, -1]); repeat_4 = None | |
| + view_431 = torch.ops.aten.view.default(repeat_4, [1, -1]); repeat_4 = None | |
| reciprocal_42 = torch.ops.aten.reciprocal.default(exp2_20); exp2_20 = None | |
| - reciprocal_43 = torch.ops.aten.reciprocal.default(view_422); view_422 = None | |
| - mul_756 = torch.ops.aten.mul.Tensor(reciprocal_42, reciprocal_43); reciprocal_42 = reciprocal_43 = None | |
| + reciprocal_43 = torch.ops.aten.reciprocal.default(view_431); view_431 = None | |
| + mul_774 = torch.ops.aten.mul.Tensor(reciprocal_42, reciprocal_43); reciprocal_42 = reciprocal_43 = None | |
| full_default_267 = torch.ops.aten.full.default([], 1, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| - _scaled_mm_10 = torch.ops.aten._scaled_mm.default(convert_element_type_926, permute_582, full_default_267, full_default_267, None, None, torch.bfloat16); convert_element_type_926 = permute_582 = None | |
| - mul_757 = torch.ops.aten.mul.Tensor(_scaled_mm_10, mul_756); _scaled_mm_10 = mul_756 = None | |
| - permute_583 = torch.ops.aten.permute.default(convert_element_type_921, [1, 0]); convert_element_type_921 = None | |
| - convert_element_type_default_52 = torch.ops.prims.convert_element_type.default(fp8_quant_pos_762_convert_element_type_735, torch.bfloat16); fp8_quant_pos_762_convert_element_type_735 = None | |
| - div_tensor_21 = torch.ops.aten.div.Tensor(convert_element_type_default_52, fp8_scale_pos_762_convert_element_type_735); convert_element_type_default_52 = fp8_scale_pos_762_convert_element_type_735 = None | |
| - convert_element_type_default_53 = torch.ops.prims.convert_element_type.default(div_tensor_21, torch.bfloat16); div_tensor_21 = None | |
| - mm_237 = torch.ops.aten.mm.default(permute_583, convert_element_type_default_53); permute_583 = convert_element_type_default_53 = None | |
| + _scaled_mm_10 = torch.ops.aten._scaled_mm.default(convert_element_type_963, permute_582, full_default_267, full_default_267, None, None, torch.bfloat16); convert_element_type_963 = permute_582 = None | |
| + mul_775 = torch.ops.aten.mul.Tensor(_scaled_mm_10, mul_774); _scaled_mm_10 = mul_774 = None | |
| + permute_583 = torch.ops.aten.permute.default(convert_element_type_957, [1, 0]); convert_element_type_957 = None | |
| + convert_element_type_default_49 = torch.ops.prims.convert_element_type.default(fp8_quant_pos_744_bmm_11, torch.bfloat16); fp8_quant_pos_744_bmm_11 = None | |
| + div_tensor_21 = torch.ops.aten.div.Tensor(convert_element_type_default_49, fp8_scale_pos_744_bmm_11); convert_element_type_default_49 = fp8_scale_pos_744_bmm_11 = None | |
| + convert_element_type_default_50 = torch.ops.prims.convert_element_type.default(div_tensor_21, torch.bfloat16); div_tensor_21 = None | |
| + view_298 = torch.ops.aten.view.default(convert_element_type_default_50, [4096, -1]); convert_element_type_default_50 = None | |
| + cat_25 = torch.ops.aten.cat.default([view_298, mul_117], 1); view_298 = None | |
| + pow_65 = torch.ops.aten.pow.Tensor_Scalar(cat_25, 2) | |
| + mean_64 = torch.ops.aten.mean.dim(pow_65, [1], True); pow_65 = None | |
| + add_188 = torch.ops.aten.add.Scalar(mean_64, 1.1920928955078125e-07); mean_64 = None | |
| + rsqrt_85 = torch.ops.aten.rsqrt.default(add_188); add_188 = None | |
| + mul_264 = torch.ops.aten.mul.Tensor(cat_25, rsqrt_85); cat_25 = None | |
| + mul_265 = torch.ops.aten.mul.Tensor(mul_264, primals_402) | |
| + convert_element_type_732 = torch.ops.prims.convert_element_type.default(mul_265, torch.bfloat16); mul_265 = None | |
| + mm_237 = torch.ops.aten.mm.default(permute_583, convert_element_type_732); permute_583 = convert_element_type_732 = None | |
| permute_584 = torch.ops.aten.permute.default(mm_237, [1, 0]); mm_237 = None | |
| - convert_element_type_933 = torch.ops.prims.convert_element_type.default(permute_584, torch.float32); permute_584 = None | |
| - permute_585 = torch.ops.aten.permute.default(convert_element_type_933, [1, 0]); convert_element_type_933 = None | |
| - convert_element_type_default_11 = torch.ops.prims.convert_element_type.default(mul_757, torch.float32); mul_757 = None | |
| - slice_scatter_8 = torch.ops.aten.slice_scatter.default(full_default_255, convert_element_type_922, 1, 0, 3840); full_default_255 = convert_element_type_922 = None | |
| - add_381 = torch.ops.aten.add.Tensor(slice_scatter_7, slice_scatter_8); slice_scatter_7 = slice_scatter_8 = None | |
| - abs_50 = torch.ops.aten.abs.default(add_381) | |
| + convert_element_type_970 = torch.ops.prims.convert_element_type.default(permute_584, torch.float32); permute_584 = None | |
| + permute_585 = torch.ops.aten.permute.default(convert_element_type_970, [1, 0]); convert_element_type_970 = None | |
| + convert_element_type_default_8 = torch.ops.prims.convert_element_type.default(mul_775, torch.float32); mul_775 = None | |
| + slice_scatter_8 = torch.ops.aten.slice_scatter.default(full_default_255, convert_element_type_959, 1, 0, 3840); full_default_255 = convert_element_type_959 = None | |
| + add_390 = torch.ops.aten.add.Tensor(slice_scatter_7, slice_scatter_8); slice_scatter_7 = slice_scatter_8 = None | |
| + abs_50 = torch.ops.aten.abs.default(add_390) | |
| amax_22 = torch.ops.aten.amax.default(abs_50, [-1], True); abs_50 = None | |
| - convert_element_type_935 = torch.ops.prims.convert_element_type.default(amax_22, torch.float64); amax_22 = None | |
| - clamp_min_78 = torch.ops.aten.clamp_min.default(convert_element_type_935, 1e-12); convert_element_type_935 = None | |
| + convert_element_type_972 = torch.ops.prims.convert_element_type.default(amax_22, torch.float64); amax_22 = None | |
| + clamp_min_78 = torch.ops.aten.clamp_min.default(convert_element_type_972, 1e-12); convert_element_type_972 = None | |
| reciprocal_44 = torch.ops.aten.reciprocal.default(clamp_min_78); clamp_min_78 = None | |
| - mul_758 = torch.ops.aten.mul.Tensor(reciprocal_44, 448.0); reciprocal_44 = None | |
| - convert_element_type_936 = torch.ops.prims.convert_element_type.default(mul_758, torch.float32); mul_758 = None | |
| - log2_22 = torch.ops.aten.log2.default(convert_element_type_936); convert_element_type_936 = None | |
| + mul_776 = torch.ops.aten.mul.Tensor(reciprocal_44, 448.0); reciprocal_44 = None | |
| + convert_element_type_973 = torch.ops.prims.convert_element_type.default(mul_776, torch.float32); mul_776 = None | |
| + log2_22 = torch.ops.aten.log2.default(convert_element_type_973); convert_element_type_973 = None | |
| floor_22 = torch.ops.aten.floor.default(log2_22); log2_22 = None | |
| exp2_22 = torch.ops.aten.exp2.default(floor_22); floor_22 = None | |
| - convert_element_type_937 = torch.ops.prims.convert_element_type.default(add_381, torch.float32) | |
| - mul_759 = torch.ops.aten.mul.Tensor(convert_element_type_937, exp2_22); convert_element_type_937 = None | |
| - clamp_min_79 = torch.ops.aten.clamp_min.default(mul_759, -448.0); mul_759 = None | |
| + convert_element_type_974 = torch.ops.prims.convert_element_type.default(add_390, torch.float32) | |
| + mul_777 = torch.ops.aten.mul.Tensor(convert_element_type_974, exp2_22); convert_element_type_974 = None | |
| + clamp_min_79 = torch.ops.aten.clamp_min.default(mul_777, -448.0); mul_777 = None | |
| clamp_max_52 = torch.ops.aten.clamp_max.default(clamp_min_79, 448.0); clamp_min_79 = None | |
| - convert_element_type_938 = torch.ops.prims.convert_element_type.default(clamp_max_52, torch.float8_e4m3fn); clamp_max_52 = None | |
| + convert_element_type_975 = torch.ops.prims.convert_element_type.default(clamp_max_52, torch.float8_e4m3fn); clamp_max_52 = None | |
| permute_249 = torch.ops.aten.permute.default(primals_403, [1, 0]); primals_403 = None | |
| abs_18 = torch.ops.aten.abs.default(permute_249) | |
| max_2 = torch.ops.aten.max.default(abs_18); abs_18 = None | |
| - convert_element_type_939 = torch.ops.prims.convert_element_type.default(max_2, torch.float64); max_2 = None | |
| - clamp_min_80 = torch.ops.aten.clamp_min.default(convert_element_type_939, 1e-12); convert_element_type_939 = None | |
| + convert_element_type_976 = torch.ops.prims.convert_element_type.default(max_2, torch.float64); max_2 = None | |
| + clamp_min_80 = torch.ops.aten.clamp_min.default(convert_element_type_976, 1e-12); convert_element_type_976 = None | |
| reciprocal_45 = torch.ops.aten.reciprocal.default(clamp_min_80); clamp_min_80 = None | |
| - mul_760 = torch.ops.aten.mul.Tensor(reciprocal_45, 448.0); reciprocal_45 = None | |
| - convert_element_type_940 = torch.ops.prims.convert_element_type.default(mul_760, torch.float32); mul_760 = None | |
| - log2_23 = torch.ops.aten.log2.default(convert_element_type_940); convert_element_type_940 = None | |
| + mul_778 = torch.ops.aten.mul.Tensor(reciprocal_45, 448.0); reciprocal_45 = None | |
| + convert_element_type_977 = torch.ops.prims.convert_element_type.default(mul_778, torch.float32); mul_778 = None | |
| + log2_23 = torch.ops.aten.log2.default(convert_element_type_977); convert_element_type_977 = None | |
| floor_23 = torch.ops.aten.floor.default(log2_23); log2_23 = None | |
| exp2_23 = torch.ops.aten.exp2.default(floor_23); floor_23 = None | |
| - mul_761 = torch.ops.aten.mul.Tensor(permute_249, exp2_23); permute_249 = None | |
| - clamp_min_81 = torch.ops.aten.clamp_min.default(mul_761, -448.0); mul_761 = None | |
| + mul_779 = torch.ops.aten.mul.Tensor(permute_249, exp2_23); permute_249 = None | |
| + clamp_min_81 = torch.ops.aten.clamp_min.default(mul_779, -448.0); mul_779 = None | |
| clamp_max_53 = torch.ops.aten.clamp_max.default(clamp_min_81, 448.0); clamp_min_81 = None | |
| - convert_element_type_941 = torch.ops.prims.convert_element_type.default(clamp_max_53, torch.float8_e4m3fn); clamp_max_53 = None | |
| - clone_87 = torch.ops.aten.clone.default(convert_element_type_941, memory_format = torch.contiguous_format); convert_element_type_941 = None | |
| + convert_element_type_978 = torch.ops.prims.convert_element_type.default(clamp_max_53, torch.float8_e4m3fn); clamp_max_53 = None | |
| + clone_87 = torch.ops.aten.clone.default(convert_element_type_978, memory_format = torch.contiguous_format); convert_element_type_978 = None | |
| permute_588 = torch.ops.aten.permute.default(clone_87, [1, 0]); clone_87 = None | |
| repeat_5 = torch.ops.aten.repeat.default(exp2_23, [14848]); exp2_23 = None | |
| - view_427 = torch.ops.aten.view.default(repeat_5, [1, -1]); repeat_5 = None | |
| + view_436 = torch.ops.aten.view.default(repeat_5, [1, -1]); repeat_5 = None | |
| reciprocal_46 = torch.ops.aten.reciprocal.default(exp2_22); exp2_22 = None | |
| - reciprocal_47 = torch.ops.aten.reciprocal.default(view_427); view_427 = None | |
| - mul_762 = torch.ops.aten.mul.Tensor(reciprocal_46, reciprocal_47); reciprocal_46 = reciprocal_47 = None | |
| - _scaled_mm_11 = torch.ops.aten._scaled_mm.default(convert_element_type_938, permute_588, full_default_267, full_default_267, None, None, torch.bfloat16); convert_element_type_938 = permute_588 = None | |
| - mul_763 = torch.ops.aten.mul.Tensor(_scaled_mm_11, mul_762); _scaled_mm_11 = mul_762 = None | |
| - permute_589 = torch.ops.aten.permute.default(add_381, [1, 0]); add_381 = None | |
| - convert_element_type_default_48 = torch.ops.prims.convert_element_type.default(fp8_quant_pos_754_cat_21, torch.bfloat16); fp8_quant_pos_754_cat_21 = None | |
| - div_tensor_19 = torch.ops.aten.div.Tensor(convert_element_type_default_48, fp8_scale_pos_754_cat_21); convert_element_type_default_48 = fp8_scale_pos_754_cat_21 = None | |
| - convert_element_type_default_49 = torch.ops.prims.convert_element_type.default(div_tensor_19, torch.bfloat16); div_tensor_19 = None | |
| - convert_element_type_679 = torch.ops.prims.convert_element_type.default(convert_element_type_default_49, torch.float32); convert_element_type_default_49 = None | |
| - pow_58 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_679, 2) | |
| + reciprocal_47 = torch.ops.aten.reciprocal.default(view_436); view_436 = None | |
| + mul_780 = torch.ops.aten.mul.Tensor(reciprocal_46, reciprocal_47); reciprocal_46 = reciprocal_47 = None | |
| + _scaled_mm_11 = torch.ops.aten._scaled_mm.default(convert_element_type_975, permute_588, full_default_267, full_default_267, None, None, torch.bfloat16); convert_element_type_975 = permute_588 = None | |
| + mul_781 = torch.ops.aten.mul.Tensor(_scaled_mm_11, mul_780); _scaled_mm_11 = mul_780 = None | |
| + permute_589 = torch.ops.aten.permute.default(add_390, [1, 0]); add_390 = None | |
| + view_297 = torch.ops.aten.view.default(bmm_12, [4096, -1]); bmm_12 = None | |
| + cat_24 = torch.ops.aten.cat.default([view_297, mul_115], 1); view_297 = None | |
| + pow_64 = torch.ops.aten.pow.Tensor_Scalar(cat_24, 2) | |
| + mean_63 = torch.ops.aten.mean.dim(pow_64, [1], True); pow_64 = None | |
| + add_187 = torch.ops.aten.add.Scalar(mean_63, 1.1920928955078125e-07); mean_63 = None | |
| + rsqrt_84 = torch.ops.aten.rsqrt.default(add_187); add_187 = None | |
| + mul_262 = torch.ops.aten.mul.Tensor(cat_24, rsqrt_84); cat_24 = None | |
| + mul_263 = torch.ops.aten.mul.Tensor(mul_262, primals_401) | |
| + convert_element_type_724 = torch.ops.prims.convert_element_type.default(mul_263, torch.bfloat16); mul_263 = None | |
| + mm_238 = torch.ops.aten.mm.default(permute_589, convert_element_type_724); permute_589 = convert_element_type_724 = None | |
| + permute_590 = torch.ops.aten.permute.default(mm_238, [1, 0]); mm_238 = None | |
| + convert_element_type_982 = torch.ops.prims.convert_element_type.default(permute_590, torch.float32); permute_590 = None | |
| + permute_591 = torch.ops.aten.permute.default(convert_element_type_982, [1, 0]); convert_element_type_982 = None | |
| + convert_element_type_default_7 = torch.ops.prims.convert_element_type.default(mul_781, torch.float32); mul_781 = None | |
| + mul_782 = torch.ops.aten.mul.Tensor(convert_element_type_default_8, primals_402); primals_402 = None | |
| + mul_784 = torch.ops.aten.mul.Tensor(mul_264, mul_782) | |
| + sum_114 = torch.ops.aten.sum.dim_IntList(mul_784, [1], True); mul_784 = None | |
| + div_97 = torch.ops.aten.div.Tensor(mul_264, 25600) | |
| + mul_785 = torch.ops.aten.mul.Tensor(div_97, sum_114); div_97 = sum_114 = None | |
| + sub_199 = torch.ops.aten.sub.Tensor(mul_782, mul_785); mul_782 = mul_785 = None | |
| + mul_786 = torch.ops.aten.mul.Tensor(sub_199, rsqrt_85); sub_199 = rsqrt_85 = None | |
| + mul_787 = torch.ops.aten.mul.Tensor(convert_element_type_default_8, mul_264); convert_element_type_default_8 = mul_264 = None | |
| + sum_115 = torch.ops.aten.sum.dim_IntList(mul_787, [0]); mul_787 = None | |
| + mul_788 = torch.ops.aten.mul.Tensor(convert_element_type_default_7, primals_401); primals_401 = None | |
| + mul_790 = torch.ops.aten.mul.Tensor(mul_262, mul_788) | |
| + sum_116 = torch.ops.aten.sum.dim_IntList(mul_790, [1], True); mul_790 = None | |
| + div_98 = torch.ops.aten.div.Tensor(mul_262, 14848) | |
| + mul_791 = torch.ops.aten.mul.Tensor(div_98, sum_116); div_98 = sum_116 = None | |
| + sub_200 = torch.ops.aten.sub.Tensor(mul_788, mul_791); mul_788 = mul_791 = None | |
| + mul_792 = torch.ops.aten.mul.Tensor(sub_200, rsqrt_84); sub_200 = rsqrt_84 = None | |
| + mul_793 = torch.ops.aten.mul.Tensor(convert_element_type_default_7, mul_262); convert_element_type_default_7 = mul_262 = None | |
| + sum_117 = torch.ops.aten.sum.dim_IntList(mul_793, [0]); mul_793 = None | |
| + slice_87 = torch.ops.aten.slice.Tensor(mul_786, 1, 0, 21504) | |
| + slice_88 = torch.ops.aten.slice.Tensor(mul_786, 1, 21504, 25600); mul_786 = None | |
| + convert_element_type_984 = torch.ops.prims.convert_element_type.default(slice_87, torch.bfloat16); slice_87 = None | |
| + slice_89 = torch.ops.aten.slice.Tensor(mul_792, 1, 0, 10752) | |
| + slice_90 = torch.ops.aten.slice.Tensor(mul_792, 1, 10752, 14848); mul_792 = None | |
| + convert_element_type_985 = torch.ops.prims.convert_element_type.default(slice_89, torch.bfloat16); slice_89 = None | |
| + view_440 = torch.ops.aten.view.default(convert_element_type_984, [4096, 48, 448]); convert_element_type_984 = None | |
| + view_441 = torch.ops.aten.view.default(convert_element_type_985, [4096, 48, 224]); convert_element_type_985 = None | |
| + convert_element_type_default_41 = torch.ops.prims.convert_element_type.default(fp8_quant_pos_720_convert_element_type_545, torch.bfloat16); fp8_quant_pos_720_convert_element_type_545 = None | |
| + div_tensor_17 = torch.ops.aten.div.Tensor(convert_element_type_default_41, fp8_scale_pos_720_convert_element_type_545); convert_element_type_default_41 = fp8_scale_pos_720_convert_element_type_545 = None | |
| + convert_element_type_default_42 = torch.ops.prims.convert_element_type.default(div_tensor_17, torch.bfloat16); div_tensor_17 = None | |
| + convert_element_type_548 = torch.ops.prims.convert_element_type.default(primals_345, torch.bfloat16); primals_345 = None | |
| + permute_191 = torch.ops.aten.permute.default(convert_element_type_default_42, [0, 2, 1]) | |
| + permute_192 = torch.ops.aten.permute.default(convert_element_type_548, [1, 0]); convert_element_type_548 = None | |
| + clone_66 = torch.ops.aten.clone.default(permute_191, memory_format = torch.contiguous_format) | |
| + view_229 = torch.ops.aten.view.default(clone_66, [458752, 256]); clone_66 = None | |
| + mm_69 = torch.ops.aten.mm.default(view_229, permute_192) | |
| + view_230 = torch.ops.aten.view.default(mm_69, [4096, 112, 256]); mm_69 = None | |
| + permute_193 = torch.ops.aten.permute.default(view_230, [0, 2, 1]); view_230 = None | |
| + clone_67 = torch.ops.aten.clone.default(permute_193, memory_format = torch.contiguous_format); permute_193 = None | |
| + slice_33 = torch.ops.aten.slice.Tensor(clone_67, 1, 0, 128); clone_67 = None | |
| + convert_element_type_560 = torch.ops.prims.convert_element_type.default(primals_349, torch.bfloat16); primals_349 = None | |
| + permute_204 = torch.ops.aten.permute.default(convert_element_type_560, [1, 0]); convert_element_type_560 = None | |
| + mm_73 = torch.ops.aten.mm.default(view_229, permute_204) | |
| + view_238 = torch.ops.aten.view.default(mm_73, [4096, 112, 192]); mm_73 = None | |
| + permute_205 = torch.ops.aten.permute.default(view_238, [0, 2, 1]); view_238 = None | |
| + clone_75 = torch.ops.aten.clone.default(permute_205, memory_format = torch.contiguous_format); permute_205 = None | |
| + slice_37 = torch.ops.aten.slice.Tensor(clone_75, 1, 0, 96); clone_75 = None | |
| + convert_element_type_674 = torch.ops.prims.convert_element_type.default(primals_385, torch.bfloat16); primals_385 = None | |
| + permute_230 = torch.ops.aten.permute.default(convert_element_type_674, [1, 0]); convert_element_type_674 = None | |
| + mm_89 = torch.ops.aten.mm.default(view_273, permute_230) | |
| + view_274 = torch.ops.aten.view.default(mm_89, [4096, 112, 256]); mm_89 = None | |
| + permute_231 = torch.ops.aten.permute.default(view_274, [0, 2, 1]); view_274 = None | |
| + clone_79 = torch.ops.aten.clone.default(permute_231, memory_format = torch.contiguous_format); permute_231 = None | |
| + slice_58 = torch.ops.aten.slice.Tensor(clone_79, 1, 0, 128); clone_79 = None | |
| + add_175 = torch.ops.aten.add.Tensor(slice_58, slice_33); slice_58 = slice_33 = None | |
| + cat_21 = torch.ops.aten.cat.default([add_175, slice_37], 1); add_175 = slice_37 = None | |
| + convert_element_type_680 = torch.ops.prims.convert_element_type.default(cat_21, torch.float32); cat_21 = None | |
| + pow_58 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_680, 2) | |
| mean_57 = torch.ops.aten.mean.dim(pow_58, [2], True); pow_58 = None | |
| add_177 = torch.ops.aten.add.Scalar(mean_57, 1.1920928955078125e-07); mean_57 = None | |
| rsqrt_78 = torch.ops.aten.rsqrt.default(add_177); add_177 = None | |
| - mul_246 = torch.ops.aten.mul.Tensor(convert_element_type_679, rsqrt_78); convert_element_type_679 = None | |
| + mul_246 = torch.ops.aten.mul.Tensor(convert_element_type_680, rsqrt_78) | |
| mul_247 = torch.ops.aten.mul.Tensor(mul_246, primals_387) | |
| - convert_element_type_682 = torch.ops.prims.convert_element_type.default(mul_247, torch.bfloat16) | |
| + convert_element_type_681 = torch.ops.prims.convert_element_type.default(mul_247, torch.bfloat16); mul_247 = None | |
| + slice_66 = torch.ops.aten.slice.Tensor(mm_93, 1, 0, 768) | |
| + convert_element_type_696 = torch.ops.prims.convert_element_type.default(slice_66, torch.float32); slice_66 = None | |
| + pow_60 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_696, 2) | |
| + mean_59 = torch.ops.aten.mean.dim(pow_60, [1], True); pow_60 = None | |
| + add_181 = torch.ops.aten.add.Scalar(mean_59, 1.1920928955078125e-07); mean_59 = None | |
| + rsqrt_80 = torch.ops.aten.rsqrt.default(add_181); add_181 = None | |
| + mul_250 = torch.ops.aten.mul.Tensor(convert_element_type_696, rsqrt_80) | |
| + mul_251 = torch.ops.aten.mul.Tensor(mul_250, primals_393) | |
| + convert_element_type_697 = torch.ops.prims.convert_element_type.default(mul_251, torch.bfloat16); mul_251 = None | |
| + sigmoid_35 = torch.ops.aten.sigmoid.default(convert_element_type_697) | |
| + mul_254 = torch.ops.aten.mul.Tensor(convert_element_type_697, sigmoid_35) | |
| + convert_element_type_700 = torch.ops.prims.convert_element_type.default(primals_395, torch.bfloat16); primals_395 = None | |
| + permute_243 = torch.ops.aten.permute.default(convert_element_type_700, [1, 0]); convert_element_type_700 = None | |
| + mm_95 = torch.ops.aten.mm.default(mul_254, permute_243) | |
| slice_68 = torch.ops.aten.slice.Tensor(mm_95, 1, 0, 512) | |
| - convert_element_type_705 = torch.ops.prims.convert_element_type.default(slice_68, torch.float32); slice_68 = None | |
| - pow_62 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_705, 2) | |
| + convert_element_type_706 = torch.ops.prims.convert_element_type.default(slice_68, torch.float32); slice_68 = None | |
| + pow_62 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_706, 2) | |
| mean_61 = torch.ops.aten.mean.dim(pow_62, [1], True); pow_62 = None | |
| add_184 = torch.ops.aten.add.Scalar(mean_61, 1.1920928955078125e-07); mean_61 = None | |
| rsqrt_82 = torch.ops.aten.rsqrt.default(add_184); add_184 = None | |
| - mul_256 = torch.ops.aten.mul.Tensor(convert_element_type_705, rsqrt_82); convert_element_type_705 = None | |
| + mul_256 = torch.ops.aten.mul.Tensor(convert_element_type_706, rsqrt_82) | |
| mul_257 = torch.ops.aten.mul.Tensor(mul_256, primals_397) | |
| - sigmoid_37 = torch.ops.aten.sigmoid.default(mul_257) | |
| - mul_260 = torch.ops.aten.mul.Tensor(mul_257, sigmoid_37) | |
| - convert_element_type_707 = torch.ops.prims.convert_element_type.default(mul_260, torch.bfloat16); mul_260 = None | |
| - convert_element_type_708 = torch.ops.prims.convert_element_type.default(primals_399, torch.bfloat16); primals_399 = None | |
| - permute_245 = torch.ops.aten.permute.default(convert_element_type_708, [1, 0]); convert_element_type_708 = None | |
| - mm_97 = torch.ops.aten.mm.default(convert_element_type_707, permute_245) | |
| + convert_element_type_707 = torch.ops.prims.convert_element_type.default(mul_257, torch.bfloat16); mul_257 = None | |
| + sigmoid_37 = torch.ops.aten.sigmoid.default(convert_element_type_707) | |
| + mul_260 = torch.ops.aten.mul.Tensor(convert_element_type_707, sigmoid_37) | |
| + convert_element_type_710 = torch.ops.prims.convert_element_type.default(primals_399, torch.bfloat16); primals_399 = None | |
| + permute_245 = torch.ops.aten.permute.default(convert_element_type_710, [1, 0]); convert_element_type_710 = None | |
| + mm_97 = torch.ops.aten.mm.default(mul_260, permute_245) | |
| slice_70 = torch.ops.aten.slice.Tensor(mm_97, 1, 0, 21504) | |
| view_283 = torch.ops.aten.view.default(slice_70, [4096, -1, 448]); slice_70 = None | |
| slice_72 = torch.ops.aten.slice.Tensor(view_283, 2, 0, 224); view_283 = None | |
| expand_18 = torch.ops.aten.expand.default(slice_72, [4096, 48, 224]); slice_72 = None | |
| - expand_19 = torch.ops.aten.expand.default(convert_element_type_682, [4096, 224, 112]) | |
| + expand_19 = torch.ops.aten.expand.default(convert_element_type_681, [4096, 224, 112]) | |
| bmm_10 = torch.ops.aten.bmm.default(expand_18, expand_19) | |
| - permute_248 = torch.ops.aten.permute.default(mul_247, [0, 2, 1]) | |
| - convert_element_type_724 = torch.ops.prims.convert_element_type.default(permute_248, torch.bfloat16); permute_248 = None | |
| expand_22 = torch.ops.aten.expand.default(bmm_10, [4096, 48, 112]); bmm_10 = None | |
| - expand_23 = torch.ops.aten.expand.default(convert_element_type_724, [4096, 112, 224]); convert_element_type_724 = None | |
| - bmm_12 = torch.ops.aten.bmm.default(expand_22, expand_23) | |
| - view_297 = torch.ops.aten.view.default(bmm_12, [4096, -1]); bmm_12 = None | |
| - cat_24 = torch.ops.aten.cat.default([view_297, mul_115], 1); view_297 = None | |
| - pow_64 = torch.ops.aten.pow.Tensor_Scalar(cat_24, 2) | |
| - mean_63 = torch.ops.aten.mean.dim(pow_64, [1], True); pow_64 = None | |
| - add_187 = torch.ops.aten.add.Scalar(mean_63, 1.1920928955078125e-07); mean_63 = None | |
| - rsqrt_84 = torch.ops.aten.rsqrt.default(add_187); add_187 = None | |
| - mul_262 = torch.ops.aten.mul.Tensor(cat_24, rsqrt_84); cat_24 = None | |
| - mul_263 = torch.ops.aten.mul.Tensor(mul_262, primals_401) | |
| - convert_element_type_727 = torch.ops.prims.convert_element_type.default(mul_263, torch.bfloat16); mul_263 = None | |
| - mm_238 = torch.ops.aten.mm.default(permute_589, convert_element_type_727); permute_589 = convert_element_type_727 = None | |
| - permute_590 = torch.ops.aten.permute.default(mm_238, [1, 0]); mm_238 = None | |
| - convert_element_type_945 = torch.ops.prims.convert_element_type.default(permute_590, torch.float32); permute_590 = None | |
| - permute_591 = torch.ops.aten.permute.default(convert_element_type_945, [1, 0]); convert_element_type_945 = None | |
| - convert_element_type_default_10 = torch.ops.prims.convert_element_type.default(mul_763, torch.float32); mul_763 = None | |
| - mul_764 = torch.ops.aten.mul.Tensor(convert_element_type_default_11, primals_402); primals_402 = None | |
| - convert_element_type_default_50 = torch.ops.prims.convert_element_type.default(fp8_quant_pos_755_cat_22, torch.bfloat16); fp8_quant_pos_755_cat_22 = None | |
| - div_tensor_20 = torch.ops.aten.div.Tensor(convert_element_type_default_50, fp8_scale_pos_755_cat_22); convert_element_type_default_50 = fp8_scale_pos_755_cat_22 = None | |
| - convert_element_type_default_51 = torch.ops.prims.convert_element_type.default(div_tensor_20, torch.bfloat16); div_tensor_20 = None | |
| - convert_element_type_680 = torch.ops.prims.convert_element_type.default(convert_element_type_default_51, torch.float32); convert_element_type_default_51 = None | |
| - pow_59 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_680, 2) | |
| + permute_592 = torch.ops.aten.permute.default(expand_22, [0, 2, 1]); expand_22 = None | |
| + bmm_13 = torch.ops.aten.bmm.default(permute_592, view_441); permute_592 = None | |
| + permute_235 = torch.ops.aten.permute.default(convert_element_type_681, [0, 2, 1]) | |
| + expand_23 = torch.ops.aten.expand.default(permute_235, [4096, 112, 224]) | |
| + permute_593 = torch.ops.aten.permute.default(expand_23, [0, 2, 1]); expand_23 = None | |
| + bmm_14 = torch.ops.aten.bmm.default(view_441, permute_593); view_441 = permute_593 = None | |
| + permute_594 = torch.ops.aten.permute.default(bmm_13, [0, 2, 1]); bmm_13 = None | |
| + convert_element_type_default_47 = torch.ops.prims.convert_element_type.default(fp8_quant_pos_740_cat_22, torch.bfloat16); fp8_quant_pos_740_cat_22 = None | |
| + div_tensor_20 = torch.ops.aten.div.Tensor(convert_element_type_default_47, fp8_scale_pos_740_cat_22); convert_element_type_default_47 = fp8_scale_pos_740_cat_22 = None | |
| + convert_element_type_default_48 = torch.ops.prims.convert_element_type.default(div_tensor_20, torch.bfloat16); div_tensor_20 = None | |
| + convert_element_type_682 = torch.ops.prims.convert_element_type.default(convert_element_type_default_48, torch.float32); convert_element_type_default_48 = None | |
| + pow_59 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_682, 2) | |
| mean_58 = torch.ops.aten.mean.dim(pow_59, [2], True); pow_59 = None | |
| add_178 = torch.ops.aten.add.Scalar(mean_58, 1.1920928955078125e-07); mean_58 = None | |
| rsqrt_79 = torch.ops.aten.rsqrt.default(add_178); add_178 = None | |
| - mul_248 = torch.ops.aten.mul.Tensor(convert_element_type_680, rsqrt_79); convert_element_type_680 = None | |
| + mul_248 = torch.ops.aten.mul.Tensor(convert_element_type_682, rsqrt_79) | |
| mul_249 = torch.ops.aten.mul.Tensor(mul_248, primals_388) | |
| - index_39 = torch.ops.aten.index.Tensor(mul_247, [sub_32]); mul_247 = None | |
| - cat_23 = torch.ops.aten.cat.default([index_39, mul_249], 1); index_39 = None | |
| - permute_247 = torch.ops.aten.permute.default(cat_23, [0, 2, 1]) | |
| - convert_element_type_721 = torch.ops.prims.convert_element_type.default(permute_247, torch.bfloat16); permute_247 = None | |
| - expand_21 = torch.ops.aten.expand.default(convert_element_type_721, [4096, 112, 448]); convert_element_type_721 = None | |
| - bmm_11 = torch.ops.aten.bmm.default(expand_20, expand_21) | |
| - view_298 = torch.ops.aten.view.default(bmm_11, [4096, -1]); bmm_11 = None | |
| - cat_25 = torch.ops.aten.cat.default([view_298, mul_117], 1); view_298 = None | |
| - pow_65 = torch.ops.aten.pow.Tensor_Scalar(cat_25, 2) | |
| - mean_64 = torch.ops.aten.mean.dim(pow_65, [1], True); pow_65 = None | |
| - add_188 = torch.ops.aten.add.Scalar(mean_64, 1.1920928955078125e-07); mean_64 = None | |
| - rsqrt_85 = torch.ops.aten.rsqrt.default(add_188); add_188 = None | |
| - mul_264 = torch.ops.aten.mul.Tensor(cat_25, rsqrt_85); cat_25 = None | |
| - mul_766 = torch.ops.aten.mul.Tensor(mul_264, mul_764) | |
| - sum_114 = torch.ops.aten.sum.dim_IntList(mul_766, [1], True); mul_766 = None | |
| - div_97 = torch.ops.aten.div.Tensor(mul_264, 25600) | |
| - mul_767 = torch.ops.aten.mul.Tensor(div_97, sum_114); div_97 = sum_114 = None | |
| - sub_208 = torch.ops.aten.sub.Tensor(mul_764, mul_767); mul_764 = mul_767 = None | |
| - mul_768 = torch.ops.aten.mul.Tensor(sub_208, rsqrt_85); sub_208 = rsqrt_85 = None | |
| - mul_769 = torch.ops.aten.mul.Tensor(convert_element_type_default_11, mul_264); convert_element_type_default_11 = mul_264 = None | |
| - sum_115 = torch.ops.aten.sum.dim_IntList(mul_769, [0]); mul_769 = None | |
| - mul_770 = torch.ops.aten.mul.Tensor(convert_element_type_default_10, primals_401); primals_401 = None | |
| - mul_772 = torch.ops.aten.mul.Tensor(mul_262, mul_770) | |
| - sum_116 = torch.ops.aten.sum.dim_IntList(mul_772, [1], True); mul_772 = None | |
| - div_98 = torch.ops.aten.div.Tensor(mul_262, 14848) | |
| - mul_773 = torch.ops.aten.mul.Tensor(div_98, sum_116); div_98 = sum_116 = None | |
| - sub_209 = torch.ops.aten.sub.Tensor(mul_770, mul_773); mul_770 = mul_773 = None | |
| - mul_774 = torch.ops.aten.mul.Tensor(sub_209, rsqrt_84); sub_209 = rsqrt_84 = None | |
| - mul_775 = torch.ops.aten.mul.Tensor(convert_element_type_default_10, mul_262); convert_element_type_default_10 = mul_262 = None | |
| - sum_117 = torch.ops.aten.sum.dim_IntList(mul_775, [0]); mul_775 = None | |
| - slice_87 = torch.ops.aten.slice.Tensor(mul_768, 1, 0, 21504) | |
| - slice_88 = torch.ops.aten.slice.Tensor(mul_768, 1, 21504, 25600); mul_768 = None | |
| - convert_element_type_947 = torch.ops.prims.convert_element_type.default(slice_87, torch.bfloat16); slice_87 = None | |
| - slice_89 = torch.ops.aten.slice.Tensor(mul_774, 1, 0, 10752) | |
| - slice_90 = torch.ops.aten.slice.Tensor(mul_774, 1, 10752, 14848); mul_774 = None | |
| - convert_element_type_948 = torch.ops.prims.convert_element_type.default(slice_89, torch.bfloat16); slice_89 = None | |
| - view_431 = torch.ops.aten.view.default(convert_element_type_947, [4096, 48, 448]); convert_element_type_947 = None | |
| - view_432 = torch.ops.aten.view.default(convert_element_type_948, [4096, 48, 224]); convert_element_type_948 = None | |
| - permute_592 = torch.ops.aten.permute.default(expand_22, [0, 2, 1]); expand_22 = None | |
| - bmm_13 = torch.ops.aten.bmm.default(permute_592, view_432); permute_592 = None | |
| - permute_593 = torch.ops.aten.permute.default(expand_23, [0, 2, 1]); expand_23 = None | |
| - bmm_14 = torch.ops.aten.bmm.default(view_432, permute_593); view_432 = permute_593 = None | |
| - convert_element_type_953 = torch.ops.prims.convert_element_type.default(bmm_13, torch.float32); bmm_13 = None | |
| - permute_594 = torch.ops.aten.permute.default(convert_element_type_953, [0, 2, 1]); convert_element_type_953 = None | |
| - permute_595 = torch.ops.aten.permute.default(expand_20, [0, 2, 1]); expand_20 = None | |
| - bmm_15 = torch.ops.aten.bmm.default(permute_595, view_431); permute_595 = None | |
| - permute_596 = torch.ops.aten.permute.default(expand_21, [0, 2, 1]); expand_21 = None | |
| - bmm_16 = torch.ops.aten.bmm.default(view_431, permute_596); view_431 = permute_596 = None | |
| - convert_element_type_958 = torch.ops.prims.convert_element_type.default(bmm_15, torch.float32); bmm_15 = None | |
| - permute_597 = torch.ops.aten.permute.default(convert_element_type_958, [0, 2, 1]); convert_element_type_958 = None | |
| - permute_598 = torch.ops.aten.permute.default(expand_18, [0, 2, 1]); expand_18 = None | |
| - bmm_17 = torch.ops.aten.bmm.default(permute_598, bmm_14); permute_598 = None | |
| - permute_599 = torch.ops.aten.permute.default(expand_19, [0, 2, 1]); expand_19 = None | |
| - bmm_18 = torch.ops.aten.bmm.default(bmm_14, permute_599); bmm_14 = permute_599 = None | |
| - convert_element_type_963 = torch.ops.prims.convert_element_type.default(bmm_17, torch.float32); bmm_17 = None | |
| - add_382 = torch.ops.aten.add.Tensor(permute_594, convert_element_type_963); permute_594 = convert_element_type_963 = None | |
| - full_default_272 = torch.ops.aten.full.default([4096, 48, 448], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| - slice_scatter_9 = torch.ops.aten.slice_scatter.default(full_default_272, bmm_18, 2, 0, 224); full_default_272 = bmm_18 = None | |
| + convert_element_type_683 = torch.ops.prims.convert_element_type.default(mul_249, torch.bfloat16); mul_249 = None | |
| slice_69 = torch.ops.aten.slice.Tensor(mm_95, 1, 512, 9223372036854775807); mm_95 = None | |
| index_37 = torch.ops.aten.index.Tensor(slice_69, [sub_32]); slice_69 = None | |
| add_183 = torch.ops.aten.add.Tensor(mm_96, index_37); mm_96 = index_37 = None | |
| - convert_element_type_706 = torch.ops.prims.convert_element_type.default(add_183, torch.float32); add_183 = None | |
| - pow_63 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_706, 2) | |
| + convert_element_type_708 = torch.ops.prims.convert_element_type.default(add_183, torch.float32); add_183 = None | |
| + pow_63 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_708, 2) | |
| mean_62 = torch.ops.aten.mean.dim(pow_63, [1], True); pow_63 = None | |
| add_185 = torch.ops.aten.add.Scalar(mean_62, 1.1920928955078125e-07); mean_62 = None | |
| rsqrt_83 = torch.ops.aten.rsqrt.default(add_185); add_185 = None | |
| - mul_258 = torch.ops.aten.mul.Tensor(convert_element_type_706, rsqrt_83); convert_element_type_706 = None | |
| + mul_258 = torch.ops.aten.mul.Tensor(convert_element_type_708, rsqrt_83) | |
| mul_259 = torch.ops.aten.mul.Tensor(mul_258, primals_398) | |
| - sigmoid_38 = torch.ops.aten.sigmoid.default(mul_259) | |
| - mul_261 = torch.ops.aten.mul.Tensor(mul_259, sigmoid_38) | |
| - convert_element_type_711 = torch.ops.prims.convert_element_type.default(mul_261, torch.bfloat16); mul_261 = None | |
| - convert_element_type_712 = torch.ops.prims.convert_element_type.default(primals_400, torch.bfloat16); primals_400 = None | |
| - permute_246 = torch.ops.aten.permute.default(convert_element_type_712, [1, 0]); convert_element_type_712 = None | |
| - mm_98 = torch.ops.aten.mm.default(convert_element_type_711, permute_246) | |
| + convert_element_type_709 = torch.ops.prims.convert_element_type.default(mul_259, torch.bfloat16); mul_259 = None | |
| + sigmoid_38 = torch.ops.aten.sigmoid.default(convert_element_type_709) | |
| + mul_261 = torch.ops.aten.mul.Tensor(convert_element_type_709, sigmoid_38) | |
| + convert_element_type_713 = torch.ops.prims.convert_element_type.default(primals_400, torch.bfloat16); primals_400 = None | |
| + permute_246 = torch.ops.aten.permute.default(convert_element_type_713, [1, 0]); convert_element_type_713 = None | |
| + mm_98 = torch.ops.aten.mm.default(mul_261, permute_246) | |
| slice_71 = torch.ops.aten.slice.Tensor(mm_97, 1, 21504, 9223372036854775807); mm_97 = None | |
| index_38 = torch.ops.aten.index.Tensor(slice_71, [sub_32]); slice_71 = None | |
| add_186 = torch.ops.aten.add.Tensor(mm_98, index_38); mm_98 = index_38 = None | |
| view_284 = torch.ops.aten.view.default(add_186, [4096, -1, 448]); add_186 = None | |
| + index_39 = torch.ops.aten.index.Tensor(convert_element_type_681, [sub_32]); convert_element_type_681 = None | |
| + cat_23 = torch.ops.aten.cat.default([index_39, convert_element_type_683], 1); index_39 = None | |
| expand_16 = torch.ops.aten.expand.default(view_284, [4096, 48, 448]); view_284 = None | |
| + expand_17 = torch.ops.aten.expand.default(cat_23, [4096, 448, 112]) | |
| + bmm_9 = torch.ops.aten.bmm.default(expand_16, expand_17) | |
| + expand_20 = torch.ops.aten.expand.default(bmm_9, [4096, 48, 112]); bmm_9 = None | |
| + permute_595 = torch.ops.aten.permute.default(expand_20, [0, 2, 1]); expand_20 = None | |
| + bmm_15 = torch.ops.aten.bmm.default(permute_595, view_440); permute_595 = None | |
| + permute_247 = torch.ops.aten.permute.default(cat_23, [0, 2, 1]); cat_23 = None | |
| + expand_21 = torch.ops.aten.expand.default(permute_247, [4096, 112, 448]); permute_247 = None | |
| + permute_596 = torch.ops.aten.permute.default(expand_21, [0, 2, 1]); expand_21 = None | |
| + bmm_16 = torch.ops.aten.bmm.default(view_440, permute_596); view_440 = permute_596 = None | |
| + permute_597 = torch.ops.aten.permute.default(bmm_15, [0, 2, 1]); bmm_15 = None | |
| + permute_598 = torch.ops.aten.permute.default(expand_18, [0, 2, 1]); expand_18 = None | |
| + bmm_17 = torch.ops.aten.bmm.default(permute_598, bmm_14); permute_598 = None | |
| + permute_599 = torch.ops.aten.permute.default(expand_19, [0, 2, 1]); expand_19 = None | |
| + bmm_18 = torch.ops.aten.bmm.default(bmm_14, permute_599); bmm_14 = permute_599 = None | |
| + add_391 = torch.ops.aten.add.Tensor(permute_594, bmm_17); permute_594 = bmm_17 = None | |
| + full_default_272 = torch.ops.aten.full.default([4096, 48, 448], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| + slice_scatter_9 = torch.ops.aten.slice_scatter.default(full_default_272, bmm_18, 2, 0, 224); full_default_272 = bmm_18 = None | |
| permute_600 = torch.ops.aten.permute.default(expand_16, [0, 2, 1]); expand_16 = None | |
| bmm_19 = torch.ops.aten.bmm.default(permute_600, bmm_16); permute_600 = None | |
| - convert_element_type_715 = torch.ops.prims.convert_element_type.default(cat_23, torch.bfloat16); cat_23 = None | |
| - expand_17 = torch.ops.aten.expand.default(convert_element_type_715, [4096, 448, 112]); convert_element_type_715 = None | |
| permute_601 = torch.ops.aten.permute.default(expand_17, [0, 2, 1]); expand_17 = None | |
| bmm_20 = torch.ops.aten.bmm.default(bmm_16, permute_601); bmm_16 = permute_601 = None | |
| - convert_element_type_968 = torch.ops.prims.convert_element_type.default(bmm_19, torch.float32); bmm_19 = None | |
| - add_383 = torch.ops.aten.add.Tensor(permute_597, convert_element_type_968); permute_597 = convert_element_type_968 = None | |
| - slice_91 = torch.ops.aten.slice.Tensor(add_383, 1, 0, 224) | |
| - slice_92 = torch.ops.aten.slice.Tensor(add_383, 1, 224, 448); add_383 = None | |
| - full_default_273 = torch.ops.aten.full.default([4096, 224, 112], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| + add_392 = torch.ops.aten.add.Tensor(permute_597, bmm_19); permute_597 = bmm_19 = None | |
| + slice_91 = torch.ops.aten.slice.Tensor(add_392, 1, 0, 224) | |
| + slice_92 = torch.ops.aten.slice.Tensor(add_392, 1, 224, 448); add_392 = None | |
| + full_default_273 = torch.ops.aten.full.default([4096, 224, 112], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| index_put_23 = torch.ops.aten.index_put.default(full_default_273, [sub_32], slice_91, True); full_default_273 = slice_91 = None | |
| - add_384 = torch.ops.aten.add.Tensor(add_382, index_put_23); add_382 = index_put_23 = None | |
| - view_445 = torch.ops.aten.view.default(bmm_20, [4096, 21504]); bmm_20 = None | |
| - view_446 = torch.ops.aten.view.default(slice_scatter_9, [4096, 21504]); slice_scatter_9 = None | |
| + add_393 = torch.ops.aten.add.Tensor(add_391, index_put_23); add_391 = index_put_23 = None | |
| + view_454 = torch.ops.aten.view.default(bmm_20, [4096, 21504]); bmm_20 = None | |
| + view_455 = torch.ops.aten.view.default(slice_scatter_9, [4096, 21504]); slice_scatter_9 = None | |
| full_default_274 = torch.ops.aten.full.default([4096, 21504], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| - index_put_24 = torch.ops.aten.index_put.default(full_default_274, [sub_32], view_445, True); full_default_274 = None | |
| + index_put_24 = torch.ops.aten.index_put.default(full_default_274, [sub_32], view_454, True); full_default_274 = None | |
| full_default_275 = torch.ops.aten.full.default([4096, 43008], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| slice_scatter_10 = torch.ops.aten.slice_scatter.default(full_default_275, index_put_24, 1, 21504, 9223372036854775807); index_put_24 = None | |
| - permute_602 = torch.ops.aten.permute.default(view_445, [1, 0]) | |
| - mm_239 = torch.ops.aten.mm.default(permute_602, convert_element_type_711); permute_602 = convert_element_type_711 = None | |
| + permute_602 = torch.ops.aten.permute.default(view_454, [1, 0]) | |
| + mm_239 = torch.ops.aten.mm.default(permute_602, mul_261); permute_602 = mul_261 = None | |
| permute_604 = torch.ops.aten.permute.default(permute_246, [1, 0]); permute_246 = None | |
| - mm_240 = torch.ops.aten.mm.default(view_445, permute_604); view_445 = permute_604 = None | |
| - convert_element_type_973 = torch.ops.prims.convert_element_type.default(mm_239, torch.float32); mm_239 = None | |
| - convert_element_type_974 = torch.ops.prims.convert_element_type.default(mm_240, torch.float32); mm_240 = None | |
| - slice_scatter_11 = torch.ops.aten.slice_scatter.default(full_default_275, view_446, 1, 0, 21504); full_default_275 = view_446 = None | |
| - add_385 = torch.ops.aten.add.Tensor(slice_scatter_10, slice_scatter_11); slice_scatter_10 = slice_scatter_11 = None | |
| - permute_606 = torch.ops.aten.permute.default(add_385, [1, 0]) | |
| - mm_241 = torch.ops.aten.mm.default(permute_606, convert_element_type_707); permute_606 = convert_element_type_707 = None | |
| + mm_240 = torch.ops.aten.mm.default(view_454, permute_604); view_454 = permute_604 = None | |
| + convert_element_type_1006 = torch.ops.prims.convert_element_type.default(mm_239, torch.float32); mm_239 = None | |
| + slice_scatter_11 = torch.ops.aten.slice_scatter.default(full_default_275, view_455, 1, 0, 21504); full_default_275 = view_455 = None | |
| + add_394 = torch.ops.aten.add.Tensor(slice_scatter_10, slice_scatter_11); slice_scatter_10 = slice_scatter_11 = None | |
| + permute_606 = torch.ops.aten.permute.default(add_394, [1, 0]) | |
| + mm_241 = torch.ops.aten.mm.default(permute_606, mul_260); permute_606 = mul_260 = None | |
| permute_608 = torch.ops.aten.permute.default(permute_245, [1, 0]); permute_245 = None | |
| - mm_242 = torch.ops.aten.mm.default(add_385, permute_608); add_385 = permute_608 = None | |
| - convert_element_type_979 = torch.ops.prims.convert_element_type.default(mm_241, torch.float32); mm_241 = None | |
| - convert_element_type_980 = torch.ops.prims.convert_element_type.default(mm_242, torch.float32); mm_242 = None | |
| - mul_776 = torch.ops.aten.mul.Tensor(convert_element_type_974, mul_259); mul_259 = None | |
| - mul_777 = torch.ops.aten.mul.Tensor(convert_element_type_974, sigmoid_38); convert_element_type_974 = None | |
| - sub_210 = torch.ops.aten.sub.Tensor(1, sigmoid_38) | |
| - mul_778 = torch.ops.aten.mul.Tensor(sigmoid_38, sub_210); sigmoid_38 = sub_210 = None | |
| - mul_779 = torch.ops.aten.mul.Tensor(mul_776, mul_778); mul_776 = mul_778 = None | |
| - add_386 = torch.ops.aten.add.Tensor(mul_777, mul_779); mul_777 = mul_779 = None | |
| - mul_780 = torch.ops.aten.mul.Tensor(convert_element_type_980, mul_257); mul_257 = None | |
| - mul_781 = torch.ops.aten.mul.Tensor(convert_element_type_980, sigmoid_37); convert_element_type_980 = None | |
| - sub_211 = torch.ops.aten.sub.Tensor(1, sigmoid_37) | |
| - mul_782 = torch.ops.aten.mul.Tensor(sigmoid_37, sub_211); sigmoid_37 = sub_211 = None | |
| - mul_783 = torch.ops.aten.mul.Tensor(mul_780, mul_782); mul_780 = mul_782 = None | |
| - add_387 = torch.ops.aten.add.Tensor(mul_781, mul_783); mul_781 = mul_783 = None | |
| - mul_784 = torch.ops.aten.mul.Tensor(add_386, primals_398); primals_398 = None | |
| - mul_786 = torch.ops.aten.mul.Tensor(mul_258, mul_784) | |
| - sum_118 = torch.ops.aten.sum.dim_IntList(mul_786, [1], True); mul_786 = None | |
| - div_99 = torch.ops.aten.div.Tensor(mul_258, 512) | |
| - mul_787 = torch.ops.aten.mul.Tensor(div_99, sum_118); div_99 = sum_118 = None | |
| - sub_212 = torch.ops.aten.sub.Tensor(mul_784, mul_787); mul_784 = mul_787 = None | |
| - mul_788 = torch.ops.aten.mul.Tensor(sub_212, rsqrt_83); sub_212 = rsqrt_83 = None | |
| - mul_789 = torch.ops.aten.mul.Tensor(add_386, mul_258); add_386 = mul_258 = None | |
| - sum_119 = torch.ops.aten.sum.dim_IntList(mul_789, [0]); mul_789 = None | |
| - convert_element_type_981 = torch.ops.prims.convert_element_type.default(mul_788, torch.bfloat16); mul_788 = None | |
| - mul_790 = torch.ops.aten.mul.Tensor(add_387, primals_397); primals_397 = None | |
| - mul_792 = torch.ops.aten.mul.Tensor(mul_256, mul_790) | |
| - sum_120 = torch.ops.aten.sum.dim_IntList(mul_792, [1], True); mul_792 = None | |
| - div_100 = torch.ops.aten.div.Tensor(mul_256, 512) | |
| - mul_793 = torch.ops.aten.mul.Tensor(div_100, sum_120); div_100 = sum_120 = None | |
| - sub_213 = torch.ops.aten.sub.Tensor(mul_790, mul_793); mul_790 = mul_793 = None | |
| - mul_794 = torch.ops.aten.mul.Tensor(sub_213, rsqrt_82); sub_213 = rsqrt_82 = None | |
| - mul_795 = torch.ops.aten.mul.Tensor(add_387, mul_256); add_387 = mul_256 = None | |
| - sum_121 = torch.ops.aten.sum.dim_IntList(mul_795, [0]); mul_795 = None | |
| - convert_element_type_982 = torch.ops.prims.convert_element_type.default(mul_794, torch.bfloat16); mul_794 = None | |
| + mm_242 = torch.ops.aten.mm.default(add_394, permute_608); add_394 = permute_608 = None | |
| + convert_element_type_1011 = torch.ops.prims.convert_element_type.default(mm_241, torch.float32); mm_241 = None | |
| + mul_794 = torch.ops.aten.mul.Tensor(mm_240, convert_element_type_709); convert_element_type_709 = None | |
| + mul_795 = torch.ops.aten.mul.Tensor(mm_240, sigmoid_38); mm_240 = None | |
| + convert_element_type_1012 = torch.ops.prims.convert_element_type.default(mul_794, torch.float32); mul_794 = None | |
| + convert_element_type_1013 = torch.ops.prims.convert_element_type.default(sigmoid_38, torch.float32); sigmoid_38 = None | |
| + sub_201 = torch.ops.aten.sub.Tensor(1, convert_element_type_1013) | |
| + mul_796 = torch.ops.aten.mul.Tensor(convert_element_type_1013, sub_201); convert_element_type_1013 = sub_201 = None | |
| + mul_797 = torch.ops.aten.mul.Tensor(convert_element_type_1012, mul_796); convert_element_type_1012 = mul_796 = None | |
| + convert_element_type_1014 = torch.ops.prims.convert_element_type.default(mul_797, torch.bfloat16); mul_797 = None | |
| + add_395 = torch.ops.aten.add.Tensor(mul_795, convert_element_type_1014); mul_795 = convert_element_type_1014 = None | |
| + mul_798 = torch.ops.aten.mul.Tensor(mm_242, convert_element_type_707); convert_element_type_707 = None | |
| + mul_799 = torch.ops.aten.mul.Tensor(mm_242, sigmoid_37); mm_242 = None | |
| + convert_element_type_1015 = torch.ops.prims.convert_element_type.default(mul_798, torch.float32); mul_798 = None | |
| + convert_element_type_1016 = torch.ops.prims.convert_element_type.default(sigmoid_37, torch.float32); sigmoid_37 = None | |
| + sub_202 = torch.ops.aten.sub.Tensor(1, convert_element_type_1016) | |
| + mul_800 = torch.ops.aten.mul.Tensor(convert_element_type_1016, sub_202); convert_element_type_1016 = sub_202 = None | |
| + mul_801 = torch.ops.aten.mul.Tensor(convert_element_type_1015, mul_800); convert_element_type_1015 = mul_800 = None | |
| + convert_element_type_1017 = torch.ops.prims.convert_element_type.default(mul_801, torch.bfloat16); mul_801 = None | |
| + add_396 = torch.ops.aten.add.Tensor(mul_799, convert_element_type_1017); mul_799 = convert_element_type_1017 = None | |
| + convert_element_type_1018 = torch.ops.prims.convert_element_type.default(add_395, torch.float32); add_395 = None | |
| + mul_802 = torch.ops.aten.mul.Tensor(convert_element_type_1018, mul_258); mul_258 = None | |
| + mul_803 = torch.ops.aten.mul.Tensor(convert_element_type_1018, primals_398); convert_element_type_1018 = primals_398 = None | |
| + sum_118 = torch.ops.aten.sum.dim_IntList(mul_802, [0], True); mul_802 = None | |
| + view_456 = torch.ops.aten.view.default(sum_118, [512]); sum_118 = None | |
| + mul_804 = torch.ops.aten.mul.Tensor(mul_803, convert_element_type_708) | |
| + mul_805 = torch.ops.aten.mul.Tensor(mul_803, rsqrt_83); mul_803 = None | |
| + sum_119 = torch.ops.aten.sum.dim_IntList(mul_804, [1], True); mul_804 = None | |
| + mul_806 = torch.ops.aten.mul.Scalar(sum_119, -0.5); sum_119 = None | |
| + pow_102 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_83, 3); rsqrt_83 = None | |
| + mul_807 = torch.ops.aten.mul.Tensor(mul_806, pow_102); mul_806 = pow_102 = None | |
| + expand_72 = torch.ops.aten.expand.default(mul_807, [4096, 512]); mul_807 = None | |
| + div_99 = torch.ops.aten.div.Scalar(expand_72, 512); expand_72 = None | |
| + pow_103 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_708, 1.0); convert_element_type_708 = None | |
| + mul_808 = torch.ops.aten.mul.Scalar(pow_103, 2.0); pow_103 = None | |
| + mul_809 = torch.ops.aten.mul.Tensor(div_99, mul_808); div_99 = mul_808 = None | |
| + add_397 = torch.ops.aten.add.Tensor(mul_805, mul_809); mul_805 = mul_809 = None | |
| + convert_element_type_1019 = torch.ops.prims.convert_element_type.default(add_397, torch.bfloat16); add_397 = None | |
| + convert_element_type_1020 = torch.ops.prims.convert_element_type.default(add_396, torch.float32); add_396 = None | |
| + mul_810 = torch.ops.aten.mul.Tensor(convert_element_type_1020, mul_256); mul_256 = None | |
| + mul_811 = torch.ops.aten.mul.Tensor(convert_element_type_1020, primals_397); convert_element_type_1020 = primals_397 = None | |
| + sum_120 = torch.ops.aten.sum.dim_IntList(mul_810, [0], True); mul_810 = None | |
| + view_457 = torch.ops.aten.view.default(sum_120, [512]); sum_120 = None | |
| + mul_812 = torch.ops.aten.mul.Tensor(mul_811, convert_element_type_706) | |
| + mul_813 = torch.ops.aten.mul.Tensor(mul_811, rsqrt_82); mul_811 = None | |
| + sum_121 = torch.ops.aten.sum.dim_IntList(mul_812, [1], True); mul_812 = None | |
| + mul_814 = torch.ops.aten.mul.Scalar(sum_121, -0.5); sum_121 = None | |
| + pow_104 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_82, 3); rsqrt_82 = None | |
| + mul_815 = torch.ops.aten.mul.Tensor(mul_814, pow_104); mul_814 = pow_104 = None | |
| + expand_73 = torch.ops.aten.expand.default(mul_815, [4096, 512]); mul_815 = None | |
| + div_100 = torch.ops.aten.div.Scalar(expand_73, 512); expand_73 = None | |
| + pow_105 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_706, 1.0); convert_element_type_706 = None | |
| + mul_816 = torch.ops.aten.mul.Scalar(pow_105, 2.0); pow_105 = None | |
| + mul_817 = torch.ops.aten.mul.Tensor(div_100, mul_816); div_100 = mul_816 = None | |
| + add_398 = torch.ops.aten.add.Tensor(mul_813, mul_817); mul_813 = mul_817 = None | |
| + convert_element_type_1021 = torch.ops.prims.convert_element_type.default(add_398, torch.bfloat16); add_398 = None | |
| full_default_277 = torch.ops.aten.full.default([4096, 512], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| - index_put_25 = torch.ops.aten.index_put.default(full_default_277, [sub_32], convert_element_type_981, True) | |
| + index_put_25 = torch.ops.aten.index_put.default(full_default_277, [sub_32], convert_element_type_1019, True) | |
| full_default_278 = torch.ops.aten.full.default([4096, 1024], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| slice_scatter_12 = torch.ops.aten.slice_scatter.default(full_default_278, index_put_25, 1, 512, 9223372036854775807); index_put_25 = None | |
| - permute_610 = torch.ops.aten.permute.default(convert_element_type_981, [1, 0]) | |
| - slice_67 = torch.ops.aten.slice.Tensor(mm_93, 1, 768, 9223372036854775807) | |
| + permute_610 = torch.ops.aten.permute.default(convert_element_type_1019, [1, 0]) | |
| + slice_67 = torch.ops.aten.slice.Tensor(mm_93, 1, 768, 9223372036854775807); mm_93 = None | |
| index_36 = torch.ops.aten.index.Tensor(slice_67, [sub_32]); slice_67 = None | |
| add_180 = torch.ops.aten.add.Tensor(mm_94, index_36); mm_94 = index_36 = None | |
| - convert_element_type_696 = torch.ops.prims.convert_element_type.default(add_180, torch.float32); add_180 = None | |
| - pow_61 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_696, 2) | |
| + convert_element_type_698 = torch.ops.prims.convert_element_type.default(add_180, torch.float32); add_180 = None | |
| + pow_61 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_698, 2) | |
| mean_60 = torch.ops.aten.mean.dim(pow_61, [1], True); pow_61 = None | |
| add_182 = torch.ops.aten.add.Scalar(mean_60, 1.1920928955078125e-07); mean_60 = None | |
| rsqrt_81 = torch.ops.aten.rsqrt.default(add_182); add_182 = None | |
| - mul_252 = torch.ops.aten.mul.Tensor(convert_element_type_696, rsqrt_81); convert_element_type_696 = None | |
| + mul_252 = torch.ops.aten.mul.Tensor(convert_element_type_698, rsqrt_81) | |
| mul_253 = torch.ops.aten.mul.Tensor(mul_252, primals_394) | |
| - sigmoid_36 = torch.ops.aten.sigmoid.default(mul_253) | |
| - mul_255 = torch.ops.aten.mul.Tensor(mul_253, sigmoid_36) | |
| - convert_element_type_701 = torch.ops.prims.convert_element_type.default(mul_255, torch.bfloat16); mul_255 = None | |
| - mm_243 = torch.ops.aten.mm.default(permute_610, convert_element_type_701); permute_610 = convert_element_type_701 = None | |
| - convert_element_type_702 = torch.ops.prims.convert_element_type.default(primals_396, torch.bfloat16); primals_396 = None | |
| - permute_244 = torch.ops.aten.permute.default(convert_element_type_702, [1, 0]); convert_element_type_702 = None | |
| + convert_element_type_699 = torch.ops.prims.convert_element_type.default(mul_253, torch.bfloat16); mul_253 = None | |
| + sigmoid_36 = torch.ops.aten.sigmoid.default(convert_element_type_699) | |
| + mul_255 = torch.ops.aten.mul.Tensor(convert_element_type_699, sigmoid_36) | |
| + mm_243 = torch.ops.aten.mm.default(permute_610, mul_255); permute_610 = mul_255 = None | |
| + convert_element_type_703 = torch.ops.prims.convert_element_type.default(primals_396, torch.bfloat16); primals_396 = None | |
| + permute_244 = torch.ops.aten.permute.default(convert_element_type_703, [1, 0]); convert_element_type_703 = None | |
| permute_612 = torch.ops.aten.permute.default(permute_244, [1, 0]); permute_244 = None | |
| - mm_244 = torch.ops.aten.mm.default(convert_element_type_981, permute_612); convert_element_type_981 = permute_612 = None | |
| - convert_element_type_987 = torch.ops.prims.convert_element_type.default(mm_243, torch.float32); mm_243 = None | |
| - convert_element_type_988 = torch.ops.prims.convert_element_type.default(mm_244, torch.float32); mm_244 = None | |
| - slice_scatter_13 = torch.ops.aten.slice_scatter.default(full_default_278, convert_element_type_982, 1, 0, 512); convert_element_type_982 = None | |
| - add_388 = torch.ops.aten.add.Tensor(slice_scatter_12, slice_scatter_13); slice_scatter_12 = slice_scatter_13 = None | |
| - permute_614 = torch.ops.aten.permute.default(add_388, [1, 0]) | |
| - slice_66 = torch.ops.aten.slice.Tensor(mm_93, 1, 0, 768); mm_93 = None | |
| - convert_element_type_695 = torch.ops.prims.convert_element_type.default(slice_66, torch.float32); slice_66 = None | |
| - pow_60 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_695, 2) | |
| - mean_59 = torch.ops.aten.mean.dim(pow_60, [1], True); pow_60 = None | |
| - add_181 = torch.ops.aten.add.Scalar(mean_59, 1.1920928955078125e-07); mean_59 = None | |
| - rsqrt_80 = torch.ops.aten.rsqrt.default(add_181); add_181 = None | |
| - mul_250 = torch.ops.aten.mul.Tensor(convert_element_type_695, rsqrt_80); convert_element_type_695 = None | |
| - mul_251 = torch.ops.aten.mul.Tensor(mul_250, primals_393) | |
| - sigmoid_35 = torch.ops.aten.sigmoid.default(mul_251) | |
| - mul_254 = torch.ops.aten.mul.Tensor(mul_251, sigmoid_35) | |
| - convert_element_type_697 = torch.ops.prims.convert_element_type.default(mul_254, torch.bfloat16); mul_254 = None | |
| - mm_245 = torch.ops.aten.mm.default(permute_614, convert_element_type_697); permute_614 = convert_element_type_697 = None | |
| - convert_element_type_698 = torch.ops.prims.convert_element_type.default(primals_395, torch.bfloat16); primals_395 = None | |
| - permute_243 = torch.ops.aten.permute.default(convert_element_type_698, [1, 0]); convert_element_type_698 = None | |
| + mm_244 = torch.ops.aten.mm.default(convert_element_type_1019, permute_612); convert_element_type_1019 = permute_612 = None | |
| + convert_element_type_1026 = torch.ops.prims.convert_element_type.default(mm_243, torch.float32); mm_243 = None | |
| + slice_scatter_13 = torch.ops.aten.slice_scatter.default(full_default_278, convert_element_type_1021, 1, 0, 512); convert_element_type_1021 = None | |
| + add_399 = torch.ops.aten.add.Tensor(slice_scatter_12, slice_scatter_13); slice_scatter_12 = slice_scatter_13 = None | |
| + permute_614 = torch.ops.aten.permute.default(add_399, [1, 0]) | |
| + mm_245 = torch.ops.aten.mm.default(permute_614, mul_254); permute_614 = mul_254 = None | |
| permute_616 = torch.ops.aten.permute.default(permute_243, [1, 0]); permute_243 = None | |
| - mm_246 = torch.ops.aten.mm.default(add_388, permute_616); add_388 = permute_616 = None | |
| - convert_element_type_993 = torch.ops.prims.convert_element_type.default(mm_245, torch.float32); mm_245 = None | |
| - convert_element_type_994 = torch.ops.prims.convert_element_type.default(mm_246, torch.float32); mm_246 = None | |
| - mul_796 = torch.ops.aten.mul.Tensor(convert_element_type_988, mul_253); mul_253 = None | |
| - mul_797 = torch.ops.aten.mul.Tensor(convert_element_type_988, sigmoid_36); convert_element_type_988 = None | |
| - sub_214 = torch.ops.aten.sub.Tensor(1, sigmoid_36) | |
| - mul_798 = torch.ops.aten.mul.Tensor(sigmoid_36, sub_214); sigmoid_36 = sub_214 = None | |
| - mul_799 = torch.ops.aten.mul.Tensor(mul_796, mul_798); mul_796 = mul_798 = None | |
| - add_389 = torch.ops.aten.add.Tensor(mul_797, mul_799); mul_797 = mul_799 = None | |
| - mul_800 = torch.ops.aten.mul.Tensor(convert_element_type_994, mul_251); mul_251 = None | |
| - mul_801 = torch.ops.aten.mul.Tensor(convert_element_type_994, sigmoid_35); convert_element_type_994 = None | |
| - sub_215 = torch.ops.aten.sub.Tensor(1, sigmoid_35) | |
| - mul_802 = torch.ops.aten.mul.Tensor(sigmoid_35, sub_215); sigmoid_35 = sub_215 = None | |
| - mul_803 = torch.ops.aten.mul.Tensor(mul_800, mul_802); mul_800 = mul_802 = None | |
| - add_390 = torch.ops.aten.add.Tensor(mul_801, mul_803); mul_801 = mul_803 = None | |
| - mul_804 = torch.ops.aten.mul.Tensor(add_389, primals_394); primals_394 = None | |
| - mul_806 = torch.ops.aten.mul.Tensor(mul_252, mul_804) | |
| - sum_122 = torch.ops.aten.sum.dim_IntList(mul_806, [1], True); mul_806 = None | |
| - div_101 = torch.ops.aten.div.Tensor(mul_252, 768) | |
| - mul_807 = torch.ops.aten.mul.Tensor(div_101, sum_122); div_101 = sum_122 = None | |
| - sub_216 = torch.ops.aten.sub.Tensor(mul_804, mul_807); mul_804 = mul_807 = None | |
| - mul_808 = torch.ops.aten.mul.Tensor(sub_216, rsqrt_81); sub_216 = rsqrt_81 = None | |
| - mul_809 = torch.ops.aten.mul.Tensor(add_389, mul_252); add_389 = mul_252 = None | |
| - sum_123 = torch.ops.aten.sum.dim_IntList(mul_809, [0]); mul_809 = None | |
| - convert_element_type_995 = torch.ops.prims.convert_element_type.default(mul_808, torch.bfloat16); mul_808 = None | |
| - mul_810 = torch.ops.aten.mul.Tensor(add_390, primals_393); primals_393 = None | |
| - mul_812 = torch.ops.aten.mul.Tensor(mul_250, mul_810) | |
| - sum_124 = torch.ops.aten.sum.dim_IntList(mul_812, [1], True); mul_812 = None | |
| - div_102 = torch.ops.aten.div.Tensor(mul_250, 768) | |
| - mul_813 = torch.ops.aten.mul.Tensor(div_102, sum_124); div_102 = sum_124 = None | |
| - sub_217 = torch.ops.aten.sub.Tensor(mul_810, mul_813); mul_810 = mul_813 = None | |
| - mul_814 = torch.ops.aten.mul.Tensor(sub_217, rsqrt_80); sub_217 = rsqrt_80 = None | |
| - mul_815 = torch.ops.aten.mul.Tensor(add_390, mul_250); add_390 = mul_250 = None | |
| - sum_125 = torch.ops.aten.sum.dim_IntList(mul_815, [0]); mul_815 = None | |
| - convert_element_type_996 = torch.ops.prims.convert_element_type.default(mul_814, torch.bfloat16); mul_814 = None | |
| + mm_246 = torch.ops.aten.mm.default(add_399, permute_616); add_399 = permute_616 = None | |
| + convert_element_type_1031 = torch.ops.prims.convert_element_type.default(mm_245, torch.float32); mm_245 = None | |
| + mul_818 = torch.ops.aten.mul.Tensor(mm_244, convert_element_type_699); convert_element_type_699 = None | |
| + mul_819 = torch.ops.aten.mul.Tensor(mm_244, sigmoid_36); mm_244 = None | |
| + convert_element_type_1032 = torch.ops.prims.convert_element_type.default(mul_818, torch.float32); mul_818 = None | |
| + convert_element_type_1033 = torch.ops.prims.convert_element_type.default(sigmoid_36, torch.float32); sigmoid_36 = None | |
| + sub_203 = torch.ops.aten.sub.Tensor(1, convert_element_type_1033) | |
| + mul_820 = torch.ops.aten.mul.Tensor(convert_element_type_1033, sub_203); convert_element_type_1033 = sub_203 = None | |
| + mul_821 = torch.ops.aten.mul.Tensor(convert_element_type_1032, mul_820); convert_element_type_1032 = mul_820 = None | |
| + convert_element_type_1034 = torch.ops.prims.convert_element_type.default(mul_821, torch.bfloat16); mul_821 = None | |
| + add_400 = torch.ops.aten.add.Tensor(mul_819, convert_element_type_1034); mul_819 = convert_element_type_1034 = None | |
| + mul_822 = torch.ops.aten.mul.Tensor(mm_246, convert_element_type_697); convert_element_type_697 = None | |
| + mul_823 = torch.ops.aten.mul.Tensor(mm_246, sigmoid_35); mm_246 = None | |
| + convert_element_type_1035 = torch.ops.prims.convert_element_type.default(mul_822, torch.float32); mul_822 = None | |
| + convert_element_type_1036 = torch.ops.prims.convert_element_type.default(sigmoid_35, torch.float32); sigmoid_35 = None | |
| + sub_204 = torch.ops.aten.sub.Tensor(1, convert_element_type_1036) | |
| + mul_824 = torch.ops.aten.mul.Tensor(convert_element_type_1036, sub_204); convert_element_type_1036 = sub_204 = None | |
| + mul_825 = torch.ops.aten.mul.Tensor(convert_element_type_1035, mul_824); convert_element_type_1035 = mul_824 = None | |
| + convert_element_type_1037 = torch.ops.prims.convert_element_type.default(mul_825, torch.bfloat16); mul_825 = None | |
| + add_401 = torch.ops.aten.add.Tensor(mul_823, convert_element_type_1037); mul_823 = convert_element_type_1037 = None | |
| + convert_element_type_1038 = torch.ops.prims.convert_element_type.default(add_400, torch.float32); add_400 = None | |
| + mul_826 = torch.ops.aten.mul.Tensor(convert_element_type_1038, mul_252); mul_252 = None | |
| + mul_827 = torch.ops.aten.mul.Tensor(convert_element_type_1038, primals_394); convert_element_type_1038 = primals_394 = None | |
| + sum_122 = torch.ops.aten.sum.dim_IntList(mul_826, [0], True); mul_826 = None | |
| + view_458 = torch.ops.aten.view.default(sum_122, [768]); sum_122 = None | |
| + mul_828 = torch.ops.aten.mul.Tensor(mul_827, convert_element_type_698) | |
| + mul_829 = torch.ops.aten.mul.Tensor(mul_827, rsqrt_81); mul_827 = None | |
| + sum_123 = torch.ops.aten.sum.dim_IntList(mul_828, [1], True); mul_828 = None | |
| + mul_830 = torch.ops.aten.mul.Scalar(sum_123, -0.5); sum_123 = None | |
| + pow_106 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_81, 3); rsqrt_81 = None | |
| + mul_831 = torch.ops.aten.mul.Tensor(mul_830, pow_106); mul_830 = pow_106 = None | |
| + expand_74 = torch.ops.aten.expand.default(mul_831, [4096, 768]); mul_831 = None | |
| + div_101 = torch.ops.aten.div.Scalar(expand_74, 768); expand_74 = None | |
| + pow_107 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_698, 1.0); convert_element_type_698 = None | |
| + mul_832 = torch.ops.aten.mul.Scalar(pow_107, 2.0); pow_107 = None | |
| + mul_833 = torch.ops.aten.mul.Tensor(div_101, mul_832); div_101 = mul_832 = None | |
| + add_402 = torch.ops.aten.add.Tensor(mul_829, mul_833); mul_829 = mul_833 = None | |
| + convert_element_type_1039 = torch.ops.prims.convert_element_type.default(add_402, torch.bfloat16); add_402 = None | |
| + convert_element_type_1040 = torch.ops.prims.convert_element_type.default(add_401, torch.float32); add_401 = None | |
| + mul_834 = torch.ops.aten.mul.Tensor(convert_element_type_1040, mul_250); mul_250 = None | |
| + mul_835 = torch.ops.aten.mul.Tensor(convert_element_type_1040, primals_393); convert_element_type_1040 = primals_393 = None | |
| + sum_124 = torch.ops.aten.sum.dim_IntList(mul_834, [0], True); mul_834 = None | |
| + view_459 = torch.ops.aten.view.default(sum_124, [768]); sum_124 = None | |
| + mul_836 = torch.ops.aten.mul.Tensor(mul_835, convert_element_type_696) | |
| + mul_837 = torch.ops.aten.mul.Tensor(mul_835, rsqrt_80); mul_835 = None | |
| + sum_125 = torch.ops.aten.sum.dim_IntList(mul_836, [1], True); mul_836 = None | |
| + mul_838 = torch.ops.aten.mul.Scalar(sum_125, -0.5); sum_125 = None | |
| + pow_108 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_80, 3); rsqrt_80 = None | |
| + mul_839 = torch.ops.aten.mul.Tensor(mul_838, pow_108); mul_838 = pow_108 = None | |
| + expand_75 = torch.ops.aten.expand.default(mul_839, [4096, 768]); mul_839 = None | |
| + div_102 = torch.ops.aten.div.Scalar(expand_75, 768); expand_75 = None | |
| + pow_109 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_696, 1.0); convert_element_type_696 = None | |
| + mul_840 = torch.ops.aten.mul.Scalar(pow_109, 2.0); pow_109 = None | |
| + mul_841 = torch.ops.aten.mul.Tensor(div_102, mul_840); div_102 = mul_840 = None | |
| + add_403 = torch.ops.aten.add.Tensor(mul_837, mul_841); mul_837 = mul_841 = None | |
| + convert_element_type_1041 = torch.ops.prims.convert_element_type.default(add_403, torch.bfloat16); add_403 = None | |
| full_default_280 = torch.ops.aten.full.default([4096, 768], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| - index_put_26 = torch.ops.aten.index_put.default(full_default_280, [sub_32], convert_element_type_995, True) | |
| + index_put_26 = torch.ops.aten.index_put.default(full_default_280, [sub_32], convert_element_type_1039, True) | |
| full_default_281 = torch.ops.aten.full.default([4096, 1536], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| slice_scatter_14 = torch.ops.aten.slice_scatter.default(full_default_281, index_put_26, 1, 768, 9223372036854775807); index_put_26 = None | |
| - permute_618 = torch.ops.aten.permute.default(convert_element_type_995, [1, 0]) | |
| - convert_element_type_681 = torch.ops.prims.convert_element_type.default(primals_389, torch.bfloat16); primals_389 = None | |
| - permute_235 = torch.ops.aten.permute.default(convert_element_type_682, [0, 2, 1]); convert_element_type_682 = None | |
| - permute_236 = torch.ops.aten.permute.default(convert_element_type_681, [1, 0]); convert_element_type_681 = None | |
| + permute_618 = torch.ops.aten.permute.default(convert_element_type_1039, [1, 0]) | |
| + convert_element_type_684 = torch.ops.prims.convert_element_type.default(primals_389, torch.bfloat16); primals_389 = None | |
| + permute_236 = torch.ops.aten.permute.default(convert_element_type_684, [1, 0]); convert_element_type_684 = None | |
| clone_82 = torch.ops.aten.clone.default(permute_235, memory_format = torch.contiguous_format); permute_235 = None | |
| view_277 = torch.ops.aten.view.default(clone_82, [458752, 224]); clone_82 = None | |
| mm_91 = torch.ops.aten.mm.default(view_277, permute_236) | |
| view_278 = torch.ops.aten.view.default(mm_91, [4096, 112, 96]); mm_91 = None | |
| permute_237 = torch.ops.aten.permute.default(view_278, [0, 2, 1]); view_278 = None | |
| clone_83 = torch.ops.aten.clone.default(permute_237, memory_format = torch.contiguous_format); permute_237 = None | |
| - convert_element_type_685 = torch.ops.prims.convert_element_type.default(primals_390, torch.bfloat16); primals_390 = None | |
| - convert_element_type_686 = torch.ops.prims.convert_element_type.default(mul_249, torch.bfloat16); mul_249 = None | |
| - permute_238 = torch.ops.aten.permute.default(convert_element_type_686, [0, 2, 1]); convert_element_type_686 = None | |
| - permute_239 = torch.ops.aten.permute.default(convert_element_type_685, [1, 0]); convert_element_type_685 = None | |
| + convert_element_type_687 = torch.ops.prims.convert_element_type.default(primals_390, torch.bfloat16); primals_390 = None | |
| + permute_238 = torch.ops.aten.permute.default(convert_element_type_683, [0, 2, 1]); convert_element_type_683 = None | |
| + permute_239 = torch.ops.aten.permute.default(convert_element_type_687, [1, 0]); convert_element_type_687 = None | |
| clone_84 = torch.ops.aten.clone.default(permute_238, memory_format = torch.contiguous_format); permute_238 = None | |
| view_279 = torch.ops.aten.view.default(clone_84, [458752, 224]); clone_84 = None | |
| mm_92 = torch.ops.aten.mm.default(view_279, permute_239) | |
| @@ -6263,308 +3281,350 @@ | |
| add_179 = torch.ops.aten.add.Tensor(clone_85, index_35); clone_85 = index_35 = None | |
| view_282 = torch.ops.aten.view.default(add_179, [4096, -1]); add_179 = None | |
| mm_247 = torch.ops.aten.mm.default(permute_618, view_282); permute_618 = view_282 = None | |
| - convert_element_type_692 = torch.ops.prims.convert_element_type.default(primals_392, torch.bfloat16); primals_392 = None | |
| - permute_242 = torch.ops.aten.permute.default(convert_element_type_692, [1, 0]); convert_element_type_692 = None | |
| + convert_element_type_693 = torch.ops.prims.convert_element_type.default(primals_392, torch.bfloat16); primals_392 = None | |
| + permute_242 = torch.ops.aten.permute.default(convert_element_type_693, [1, 0]); convert_element_type_693 = None | |
| permute_620 = torch.ops.aten.permute.default(permute_242, [1, 0]); permute_242 = None | |
| - mm_248 = torch.ops.aten.mm.default(convert_element_type_995, permute_620); convert_element_type_995 = permute_620 = None | |
| - convert_element_type_1001 = torch.ops.prims.convert_element_type.default(mm_247, torch.float32); mm_247 = None | |
| - slice_scatter_15 = torch.ops.aten.slice_scatter.default(full_default_281, convert_element_type_996, 1, 0, 768); convert_element_type_996 = None | |
| - add_391 = torch.ops.aten.add.Tensor(slice_scatter_14, slice_scatter_15); slice_scatter_14 = slice_scatter_15 = None | |
| - permute_622 = torch.ops.aten.permute.default(add_391, [1, 0]) | |
| + mm_248 = torch.ops.aten.mm.default(convert_element_type_1039, permute_620); convert_element_type_1039 = permute_620 = None | |
| + convert_element_type_1046 = torch.ops.prims.convert_element_type.default(mm_247, torch.float32); mm_247 = None | |
| + slice_scatter_15 = torch.ops.aten.slice_scatter.default(full_default_281, convert_element_type_1041, 1, 0, 768); convert_element_type_1041 = None | |
| + add_404 = torch.ops.aten.add.Tensor(slice_scatter_14, slice_scatter_15); slice_scatter_14 = slice_scatter_15 = None | |
| + permute_622 = torch.ops.aten.permute.default(add_404, [1, 0]) | |
| slice_62 = torch.ops.aten.slice.Tensor(clone_83, 1, 0, 48); clone_83 = None | |
| view_281 = torch.ops.aten.view.default(slice_62, [4096, -1]); slice_62 = None | |
| mm_249 = torch.ops.aten.mm.default(permute_622, view_281); permute_622 = view_281 = None | |
| - convert_element_type_689 = torch.ops.prims.convert_element_type.default(primals_391, torch.bfloat16); primals_391 = None | |
| - permute_241 = torch.ops.aten.permute.default(convert_element_type_689, [1, 0]); convert_element_type_689 = None | |
| + convert_element_type_690 = torch.ops.prims.convert_element_type.default(primals_391, torch.bfloat16); primals_391 = None | |
| + permute_241 = torch.ops.aten.permute.default(convert_element_type_690, [1, 0]); convert_element_type_690 = None | |
| permute_624 = torch.ops.aten.permute.default(permute_241, [1, 0]); permute_241 = None | |
| - mm_250 = torch.ops.aten.mm.default(add_391, permute_624); add_391 = permute_624 = None | |
| - convert_element_type_1006 = torch.ops.prims.convert_element_type.default(mm_249, torch.float32); mm_249 = None | |
| - view_447 = torch.ops.aten.view.default(mm_248, [4096, 48, 112]); mm_248 = None | |
| - view_448 = torch.ops.aten.view.default(mm_250, [4096, 48, 112]); mm_250 = None | |
| + mm_250 = torch.ops.aten.mm.default(add_404, permute_624); add_404 = permute_624 = None | |
| + convert_element_type_1051 = torch.ops.prims.convert_element_type.default(mm_249, torch.float32); mm_249 = None | |
| + view_460 = torch.ops.aten.view.default(mm_248, [4096, 48, 112]); mm_248 = None | |
| + view_461 = torch.ops.aten.view.default(mm_250, [4096, 48, 112]); mm_250 = None | |
| full_default_283 = torch.ops.aten.full.default([4096, 48, 112], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| - index_put_27 = torch.ops.aten.index_put.default(full_default_283, [sub_32], view_447, True) | |
| + index_put_27 = torch.ops.aten.index_put.default(full_default_283, [sub_32], view_460, True) | |
| full_default_284 = torch.ops.aten.full.default([4096, 96, 112], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| slice_scatter_16 = torch.ops.aten.slice_scatter.default(full_default_284, index_put_27, 1, 48, 9223372036854775807); index_put_27 = None | |
| - permute_626 = torch.ops.aten.permute.default(view_447, [0, 2, 1]); view_447 = None | |
| + permute_626 = torch.ops.aten.permute.default(view_460, [0, 2, 1]); view_460 = None | |
| clone_88 = torch.ops.aten.clone.default(permute_626, memory_format = torch.contiguous_format); permute_626 = None | |
| - view_449 = torch.ops.aten.view.default(clone_88, [458752, 48]); clone_88 = None | |
| - permute_627 = torch.ops.aten.permute.default(view_449, [1, 0]) | |
| + view_462 = torch.ops.aten.view.default(clone_88, [458752, 48]); clone_88 = None | |
| + permute_627 = torch.ops.aten.permute.default(view_462, [1, 0]) | |
| mm_251 = torch.ops.aten.mm.default(permute_627, view_279); permute_627 = view_279 = None | |
| permute_629 = torch.ops.aten.permute.default(permute_239, [1, 0]); permute_239 = None | |
| - mm_252 = torch.ops.aten.mm.default(view_449, permute_629); view_449 = permute_629 = None | |
| - view_450 = torch.ops.aten.view.default(mm_252, [4096, 112, 224]); mm_252 = None | |
| - permute_631 = torch.ops.aten.permute.default(view_450, [0, 2, 1]); view_450 = None | |
| - convert_element_type_1011 = torch.ops.prims.convert_element_type.default(permute_631, torch.float32); permute_631 = None | |
| - add_392 = torch.ops.aten.add.Tensor(slice_92, convert_element_type_1011); slice_92 = convert_element_type_1011 = None | |
| - convert_element_type_1012 = torch.ops.prims.convert_element_type.default(mm_251, torch.float32); mm_251 = None | |
| - slice_scatter_17 = torch.ops.aten.slice_scatter.default(full_default_284, view_448, 1, 0, 48); view_448 = None | |
| - add_393 = torch.ops.aten.add.Tensor(slice_scatter_16, slice_scatter_17); slice_scatter_16 = slice_scatter_17 = None | |
| - permute_632 = torch.ops.aten.permute.default(add_393, [0, 2, 1]); add_393 = None | |
| + mm_252 = torch.ops.aten.mm.default(view_462, permute_629); view_462 = permute_629 = None | |
| + view_463 = torch.ops.aten.view.default(mm_252, [4096, 112, 224]); mm_252 = None | |
| + permute_631 = torch.ops.aten.permute.default(view_463, [0, 2, 1]); view_463 = None | |
| + add_405 = torch.ops.aten.add.Tensor(slice_92, permute_631); slice_92 = permute_631 = None | |
| + convert_element_type_1056 = torch.ops.prims.convert_element_type.default(mm_251, torch.float32); mm_251 = None | |
| + slice_scatter_17 = torch.ops.aten.slice_scatter.default(full_default_284, view_461, 1, 0, 48); view_461 = None | |
| + add_406 = torch.ops.aten.add.Tensor(slice_scatter_16, slice_scatter_17); slice_scatter_16 = slice_scatter_17 = None | |
| + permute_632 = torch.ops.aten.permute.default(add_406, [0, 2, 1]); add_406 = None | |
| clone_89 = torch.ops.aten.clone.default(permute_632, memory_format = torch.contiguous_format); permute_632 = None | |
| - view_451 = torch.ops.aten.view.default(clone_89, [458752, 96]); clone_89 = None | |
| - permute_633 = torch.ops.aten.permute.default(view_451, [1, 0]) | |
| + view_464 = torch.ops.aten.view.default(clone_89, [458752, 96]); clone_89 = None | |
| + permute_633 = torch.ops.aten.permute.default(view_464, [1, 0]) | |
| mm_253 = torch.ops.aten.mm.default(permute_633, view_277); permute_633 = view_277 = None | |
| permute_635 = torch.ops.aten.permute.default(permute_236, [1, 0]); permute_236 = None | |
| - mm_254 = torch.ops.aten.mm.default(view_451, permute_635); view_451 = permute_635 = None | |
| - view_452 = torch.ops.aten.view.default(mm_254, [4096, 112, 224]); mm_254 = None | |
| - permute_637 = torch.ops.aten.permute.default(view_452, [0, 2, 1]); view_452 = None | |
| - convert_element_type_1017 = torch.ops.prims.convert_element_type.default(permute_637, torch.float32); permute_637 = None | |
| - add_394 = torch.ops.aten.add.Tensor(add_384, convert_element_type_1017); add_384 = convert_element_type_1017 = None | |
| - convert_element_type_1018 = torch.ops.prims.convert_element_type.default(mm_253, torch.float32); mm_253 = None | |
| - mul_816 = torch.ops.aten.mul.Tensor(add_392, primals_388); primals_388 = None | |
| - mul_818 = torch.ops.aten.mul.Tensor(mul_248, mul_816) | |
| - sum_126 = torch.ops.aten.sum.dim_IntList(mul_818, [2], True); mul_818 = None | |
| - div_103 = torch.ops.aten.div.Tensor(mul_248, 112) | |
| - mul_819 = torch.ops.aten.mul.Tensor(div_103, sum_126); div_103 = sum_126 = None | |
| - sub_218 = torch.ops.aten.sub.Tensor(mul_816, mul_819); mul_816 = mul_819 = None | |
| - mul_820 = torch.ops.aten.mul.Tensor(sub_218, rsqrt_79); sub_218 = rsqrt_79 = None | |
| - mul_821 = torch.ops.aten.mul.Tensor(add_392, mul_248); add_392 = mul_248 = None | |
| - sum_127 = torch.ops.aten.sum.dim_IntList(mul_821, [0, 1]); mul_821 = None | |
| - convert_element_type_1019 = torch.ops.prims.convert_element_type.default(mul_820, torch.bfloat16); mul_820 = None | |
| - mul_822 = torch.ops.aten.mul.Tensor(add_394, primals_387); primals_387 = None | |
| - mul_824 = torch.ops.aten.mul.Tensor(mul_246, mul_822) | |
| - sum_128 = torch.ops.aten.sum.dim_IntList(mul_824, [2], True); mul_824 = None | |
| - div_104 = torch.ops.aten.div.Tensor(mul_246, 112) | |
| - mul_825 = torch.ops.aten.mul.Tensor(div_104, sum_128); div_104 = sum_128 = None | |
| - sub_219 = torch.ops.aten.sub.Tensor(mul_822, mul_825); mul_822 = mul_825 = None | |
| - mul_826 = torch.ops.aten.mul.Tensor(sub_219, rsqrt_78); sub_219 = rsqrt_78 = None | |
| - mul_827 = torch.ops.aten.mul.Tensor(add_394, mul_246); add_394 = mul_246 = None | |
| - sum_129 = torch.ops.aten.sum.dim_IntList(mul_827, [0, 1]); mul_827 = None | |
| - convert_element_type_1020 = torch.ops.prims.convert_element_type.default(mul_826, torch.bfloat16); mul_826 = None | |
| - slice_93 = torch.ops.aten.slice.Tensor(convert_element_type_1019, 1, 0, 128) | |
| - slice_94 = torch.ops.aten.slice.Tensor(convert_element_type_1019, 1, 128, 224); convert_element_type_1019 = None | |
| - slice_95 = torch.ops.aten.slice.Tensor(convert_element_type_1020, 1, 0, 128) | |
| - slice_96 = torch.ops.aten.slice.Tensor(convert_element_type_1020, 1, 128, 224); convert_element_type_1020 = None | |
| + mm_254 = torch.ops.aten.mm.default(view_464, permute_635); view_464 = permute_635 = None | |
| + view_465 = torch.ops.aten.view.default(mm_254, [4096, 112, 224]); mm_254 = None | |
| + permute_637 = torch.ops.aten.permute.default(view_465, [0, 2, 1]); view_465 = None | |
| + add_407 = torch.ops.aten.add.Tensor(add_393, permute_637); add_393 = permute_637 = None | |
| + convert_element_type_1061 = torch.ops.prims.convert_element_type.default(mm_253, torch.float32); mm_253 = None | |
| + convert_element_type_1062 = torch.ops.prims.convert_element_type.default(add_405, torch.float32); add_405 = None | |
| + mul_842 = torch.ops.aten.mul.Tensor(convert_element_type_1062, mul_248); mul_248 = None | |
| + mul_843 = torch.ops.aten.mul.Tensor(convert_element_type_1062, primals_388); convert_element_type_1062 = primals_388 = None | |
| + sum_126 = torch.ops.aten.sum.dim_IntList(mul_842, [0, 1], True); mul_842 = None | |
| + view_466 = torch.ops.aten.view.default(sum_126, [112]); sum_126 = None | |
| + mul_844 = torch.ops.aten.mul.Tensor(mul_843, convert_element_type_682) | |
| + mul_845 = torch.ops.aten.mul.Tensor(mul_843, rsqrt_79); mul_843 = None | |
| + sum_127 = torch.ops.aten.sum.dim_IntList(mul_844, [2], True); mul_844 = None | |
| + mul_846 = torch.ops.aten.mul.Scalar(sum_127, -0.5); sum_127 = None | |
| + pow_110 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_79, 3); rsqrt_79 = None | |
| + mul_847 = torch.ops.aten.mul.Tensor(mul_846, pow_110); mul_846 = pow_110 = None | |
| + expand_76 = torch.ops.aten.expand.default(mul_847, [4096, 224, 112]); mul_847 = None | |
| + div_103 = torch.ops.aten.div.Scalar(expand_76, 112); expand_76 = None | |
| + pow_111 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_682, 1.0); convert_element_type_682 = None | |
| + mul_848 = torch.ops.aten.mul.Scalar(pow_111, 2.0); pow_111 = None | |
| + mul_849 = torch.ops.aten.mul.Tensor(div_103, mul_848); div_103 = mul_848 = None | |
| + add_408 = torch.ops.aten.add.Tensor(mul_845, mul_849); mul_845 = mul_849 = None | |
| + convert_element_type_1063 = torch.ops.prims.convert_element_type.default(add_408, torch.bfloat16); add_408 = None | |
| + convert_element_type_1064 = torch.ops.prims.convert_element_type.default(add_407, torch.float32); add_407 = None | |
| + mul_850 = torch.ops.aten.mul.Tensor(convert_element_type_1064, mul_246); mul_246 = None | |
| + mul_851 = torch.ops.aten.mul.Tensor(convert_element_type_1064, primals_387); convert_element_type_1064 = primals_387 = None | |
| + sum_128 = torch.ops.aten.sum.dim_IntList(mul_850, [0, 1], True); mul_850 = None | |
| + view_467 = torch.ops.aten.view.default(sum_128, [112]); sum_128 = None | |
| + mul_852 = torch.ops.aten.mul.Tensor(mul_851, convert_element_type_680) | |
| + mul_853 = torch.ops.aten.mul.Tensor(mul_851, rsqrt_78); mul_851 = None | |
| + sum_129 = torch.ops.aten.sum.dim_IntList(mul_852, [2], True); mul_852 = None | |
| + mul_854 = torch.ops.aten.mul.Scalar(sum_129, -0.5); sum_129 = None | |
| + pow_112 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_78, 3); rsqrt_78 = None | |
| + mul_855 = torch.ops.aten.mul.Tensor(mul_854, pow_112); mul_854 = pow_112 = None | |
| + expand_77 = torch.ops.aten.expand.default(mul_855, [4096, 224, 112]); mul_855 = None | |
| + div_104 = torch.ops.aten.div.Scalar(expand_77, 112); expand_77 = None | |
| + pow_113 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_680, 1.0); convert_element_type_680 = None | |
| + mul_856 = torch.ops.aten.mul.Scalar(pow_113, 2.0); pow_113 = None | |
| + mul_857 = torch.ops.aten.mul.Tensor(div_104, mul_856); div_104 = mul_856 = None | |
| + add_409 = torch.ops.aten.add.Tensor(mul_853, mul_857); mul_853 = mul_857 = None | |
| + convert_element_type_1065 = torch.ops.prims.convert_element_type.default(add_409, torch.bfloat16); add_409 = None | |
| + slice_93 = torch.ops.aten.slice.Tensor(convert_element_type_1063, 1, 0, 128) | |
| + slice_94 = torch.ops.aten.slice.Tensor(convert_element_type_1063, 1, 128, 224); convert_element_type_1063 = None | |
| + slice_95 = torch.ops.aten.slice.Tensor(convert_element_type_1065, 1, 0, 128) | |
| + slice_96 = torch.ops.aten.slice.Tensor(convert_element_type_1065, 1, 128, 224); convert_element_type_1065 = None | |
| full_default_286 = torch.ops.aten.full.default([4096, 128, 112], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| index_put_28 = torch.ops.aten.index_put.default(full_default_286, [sub_32], slice_93, True) | |
| full_default_287 = torch.ops.aten.full.default([4096, 256, 112], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| slice_scatter_18 = torch.ops.aten.slice_scatter.default(full_default_287, index_put_28, 1, 128, 9223372036854775807); index_put_28 = None | |
| permute_638 = torch.ops.aten.permute.default(slice_93, [0, 2, 1]); slice_93 = None | |
| - view_453 = torch.ops.aten.view.default(permute_638, [458752, 128]); permute_638 = None | |
| - permute_639 = torch.ops.aten.permute.default(view_453, [1, 0]) | |
| + view_468 = torch.ops.aten.view.default(permute_638, [458752, 128]); permute_638 = None | |
| + permute_639 = torch.ops.aten.permute.default(view_468, [1, 0]) | |
| mm_255 = torch.ops.aten.mm.default(permute_639, view_275); view_275 = None | |
| - convert_element_type_676 = torch.ops.prims.convert_element_type.default(primals_386, torch.bfloat16); primals_386 = None | |
| - permute_233 = torch.ops.aten.permute.default(convert_element_type_676, [1, 0]); convert_element_type_676 = None | |
| + convert_element_type_677 = torch.ops.prims.convert_element_type.default(primals_386, torch.bfloat16); primals_386 = None | |
| + permute_233 = torch.ops.aten.permute.default(convert_element_type_677, [1, 0]); convert_element_type_677 = None | |
| permute_641 = torch.ops.aten.permute.default(permute_233, [1, 0]); permute_233 = None | |
| - mm_256 = torch.ops.aten.mm.default(view_453, permute_641); permute_641 = None | |
| - view_454 = torch.ops.aten.view.default(mm_256, [4096, 112, 64]); mm_256 = None | |
| - permute_643 = torch.ops.aten.permute.default(view_454, [0, 2, 1]); view_454 = None | |
| - convert_element_type_1025 = torch.ops.prims.convert_element_type.default(mm_255, torch.float32); mm_255 = None | |
| + mm_256 = torch.ops.aten.mm.default(view_468, permute_641); permute_641 = None | |
| + view_469 = torch.ops.aten.view.default(mm_256, [4096, 112, 64]); mm_256 = None | |
| + permute_643 = torch.ops.aten.permute.default(view_469, [0, 2, 1]); view_469 = None | |
| + convert_element_type_1070 = torch.ops.prims.convert_element_type.default(mm_255, torch.float32); mm_255 = None | |
| slice_scatter_19 = torch.ops.aten.slice_scatter.default(full_default_287, slice_95, 1, 0, 128); slice_95 = None | |
| - add_395 = torch.ops.aten.add.Tensor(slice_scatter_18, slice_scatter_19); slice_scatter_18 = slice_scatter_19 = None | |
| - permute_644 = torch.ops.aten.permute.default(add_395, [0, 2, 1]); add_395 = None | |
| + add_410 = torch.ops.aten.add.Tensor(slice_scatter_18, slice_scatter_19); slice_scatter_18 = slice_scatter_19 = None | |
| + permute_644 = torch.ops.aten.permute.default(add_410, [0, 2, 1]); add_410 = None | |
| clone_90 = torch.ops.aten.clone.default(permute_644, memory_format = torch.contiguous_format); permute_644 = None | |
| - view_455 = torch.ops.aten.view.default(clone_90, [458752, 256]); clone_90 = None | |
| - permute_645 = torch.ops.aten.permute.default(view_455, [1, 0]) | |
| + view_470 = torch.ops.aten.view.default(clone_90, [458752, 256]); clone_90 = None | |
| + permute_645 = torch.ops.aten.permute.default(view_470, [1, 0]) | |
| mm_257 = torch.ops.aten.mm.default(permute_645, view_273); view_273 = None | |
| - convert_element_type_673 = torch.ops.prims.convert_element_type.default(primals_385, torch.bfloat16); primals_385 = None | |
| - permute_230 = torch.ops.aten.permute.default(convert_element_type_673, [1, 0]); convert_element_type_673 = None | |
| permute_647 = torch.ops.aten.permute.default(permute_230, [1, 0]); permute_230 = None | |
| - mm_258 = torch.ops.aten.mm.default(view_455, permute_647); permute_647 = None | |
| - view_456 = torch.ops.aten.view.default(mm_258, [4096, 112, 64]); mm_258 = None | |
| - permute_649 = torch.ops.aten.permute.default(view_456, [0, 2, 1]); view_456 = None | |
| - convert_element_type_1030 = torch.ops.prims.convert_element_type.default(mm_257, torch.float32); mm_257 = None | |
| + mm_258 = torch.ops.aten.mm.default(view_470, permute_647); permute_647 = None | |
| + view_471 = torch.ops.aten.view.default(mm_258, [4096, 112, 64]); mm_258 = None | |
| + permute_649 = torch.ops.aten.permute.default(view_471, [0, 2, 1]); view_471 = None | |
| + convert_element_type_1075 = torch.ops.prims.convert_element_type.default(mm_257, torch.float32); mm_257 = None | |
| clone_91 = torch.ops.aten.clone.default(permute_643, memory_format = torch.contiguous_format); permute_643 = None | |
| - view_457 = torch.ops.aten.view.default(clone_91, [4096, 7168]); clone_91 = None | |
| + view_472 = torch.ops.aten.view.default(clone_91, [4096, 7168]); clone_91 = None | |
| clone_92 = torch.ops.aten.clone.default(permute_649, memory_format = torch.contiguous_format); permute_649 = None | |
| - view_458 = torch.ops.aten.view.default(clone_92, [4096, 7168]); clone_92 = None | |
| + view_473 = torch.ops.aten.view.default(clone_92, [4096, 7168]); clone_92 = None | |
| full_default_289 = torch.ops.aten.full.default([4096, 7168], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| - index_put_29 = torch.ops.aten.index_put.default(full_default_289, [sub_32], view_457, True) | |
| + index_put_29 = torch.ops.aten.index_put.default(full_default_289, [sub_32], view_472, True) | |
| full_default_290 = torch.ops.aten.full.default([4096, 14336], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| slice_scatter_20 = torch.ops.aten.slice_scatter.default(full_default_290, index_put_29, 1, 7168, 9223372036854775807); index_put_29 = None | |
| - abs_52 = torch.ops.aten.abs.default(view_457) | |
| + abs_52 = torch.ops.aten.abs.default(view_472) | |
| amax_23 = torch.ops.aten.amax.default(abs_52, [-1], True); abs_52 = None | |
| - convert_element_type_1031 = torch.ops.prims.convert_element_type.default(amax_23, torch.float64); amax_23 = None | |
| - clamp_min_82 = torch.ops.aten.clamp_min.default(convert_element_type_1031, 1e-12); convert_element_type_1031 = None | |
| + convert_element_type_1076 = torch.ops.prims.convert_element_type.default(amax_23, torch.float64); amax_23 = None | |
| + clamp_min_82 = torch.ops.aten.clamp_min.default(convert_element_type_1076, 1e-12); convert_element_type_1076 = None | |
| reciprocal_48 = torch.ops.aten.reciprocal.default(clamp_min_82); clamp_min_82 = None | |
| - mul_828 = torch.ops.aten.mul.Tensor(reciprocal_48, 448.0); reciprocal_48 = None | |
| - convert_element_type_1032 = torch.ops.prims.convert_element_type.default(mul_828, torch.float32); mul_828 = None | |
| - log2_24 = torch.ops.aten.log2.default(convert_element_type_1032); convert_element_type_1032 = None | |
| + mul_858 = torch.ops.aten.mul.Tensor(reciprocal_48, 448.0); reciprocal_48 = None | |
| + convert_element_type_1077 = torch.ops.prims.convert_element_type.default(mul_858, torch.float32); mul_858 = None | |
| + log2_24 = torch.ops.aten.log2.default(convert_element_type_1077); convert_element_type_1077 = None | |
| floor_24 = torch.ops.aten.floor.default(log2_24); log2_24 = None | |
| exp2_24 = torch.ops.aten.exp2.default(floor_24); floor_24 = None | |
| - convert_element_type_1033 = torch.ops.prims.convert_element_type.default(view_457, torch.float32) | |
| - mul_829 = torch.ops.aten.mul.Tensor(convert_element_type_1033, exp2_24); convert_element_type_1033 = None | |
| - clamp_min_83 = torch.ops.aten.clamp_min.default(mul_829, -448.0); mul_829 = None | |
| + convert_element_type_1078 = torch.ops.prims.convert_element_type.default(view_472, torch.float32) | |
| + mul_859 = torch.ops.aten.mul.Tensor(convert_element_type_1078, exp2_24); convert_element_type_1078 = None | |
| + clamp_min_83 = torch.ops.aten.clamp_min.default(mul_859, -448.0); mul_859 = None | |
| clamp_max_54 = torch.ops.aten.clamp_max.default(clamp_min_83, 448.0); clamp_min_83 = None | |
| - convert_element_type_1034 = torch.ops.prims.convert_element_type.default(clamp_max_54, torch.float8_e4m3fn); clamp_max_54 = None | |
| + convert_element_type_1079 = torch.ops.prims.convert_element_type.default(clamp_max_54, torch.float8_e4m3fn); clamp_max_54 = None | |
| permute_228 = torch.ops.aten.permute.default(primals_384, [1, 0]); primals_384 = None | |
| abs_16 = torch.ops.aten.abs.default(permute_228) | |
| max_3 = torch.ops.aten.max.default(abs_16); abs_16 = None | |
| - convert_element_type_1035 = torch.ops.prims.convert_element_type.default(max_3, torch.float64); max_3 = None | |
| - clamp_min_84 = torch.ops.aten.clamp_min.default(convert_element_type_1035, 1e-12); convert_element_type_1035 = None | |
| + convert_element_type_1080 = torch.ops.prims.convert_element_type.default(max_3, torch.float64); max_3 = None | |
| + clamp_min_84 = torch.ops.aten.clamp_min.default(convert_element_type_1080, 1e-12); convert_element_type_1080 = None | |
| reciprocal_49 = torch.ops.aten.reciprocal.default(clamp_min_84); clamp_min_84 = None | |
| - mul_830 = torch.ops.aten.mul.Tensor(reciprocal_49, 448.0); reciprocal_49 = None | |
| - convert_element_type_1036 = torch.ops.prims.convert_element_type.default(mul_830, torch.float32); mul_830 = None | |
| - log2_25 = torch.ops.aten.log2.default(convert_element_type_1036); convert_element_type_1036 = None | |
| + mul_860 = torch.ops.aten.mul.Tensor(reciprocal_49, 448.0); reciprocal_49 = None | |
| + convert_element_type_1081 = torch.ops.prims.convert_element_type.default(mul_860, torch.float32); mul_860 = None | |
| + log2_25 = torch.ops.aten.log2.default(convert_element_type_1081); convert_element_type_1081 = None | |
| floor_25 = torch.ops.aten.floor.default(log2_25); log2_25 = None | |
| exp2_25 = torch.ops.aten.exp2.default(floor_25); floor_25 = None | |
| - mul_831 = torch.ops.aten.mul.Tensor(permute_228, exp2_25); permute_228 = None | |
| - clamp_min_85 = torch.ops.aten.clamp_min.default(mul_831, -448.0); mul_831 = None | |
| + mul_861 = torch.ops.aten.mul.Tensor(permute_228, exp2_25); permute_228 = None | |
| + clamp_min_85 = torch.ops.aten.clamp_min.default(mul_861, -448.0); mul_861 = None | |
| clamp_max_55 = torch.ops.aten.clamp_max.default(clamp_min_85, 448.0); clamp_min_85 = None | |
| - convert_element_type_1037 = torch.ops.prims.convert_element_type.default(clamp_max_55, torch.float8_e4m3fn); clamp_max_55 = None | |
| - clone_93 = torch.ops.aten.clone.default(convert_element_type_1037, memory_format = torch.contiguous_format); convert_element_type_1037 = None | |
| + convert_element_type_1082 = torch.ops.prims.convert_element_type.default(clamp_max_55, torch.float8_e4m3fn); clamp_max_55 = None | |
| + clone_93 = torch.ops.aten.clone.default(convert_element_type_1082, memory_format = torch.contiguous_format); convert_element_type_1082 = None | |
| permute_652 = torch.ops.aten.permute.default(clone_93, [1, 0]); clone_93 = None | |
| repeat_6 = torch.ops.aten.repeat.default(exp2_25, [4608]); exp2_25 = None | |
| - view_460 = torch.ops.aten.view.default(repeat_6, [1, -1]); repeat_6 = None | |
| + view_475 = torch.ops.aten.view.default(repeat_6, [1, -1]); repeat_6 = None | |
| reciprocal_50 = torch.ops.aten.reciprocal.default(exp2_24); exp2_24 = None | |
| - reciprocal_51 = torch.ops.aten.reciprocal.default(view_460); view_460 = None | |
| - mul_832 = torch.ops.aten.mul.Tensor(reciprocal_50, reciprocal_51); reciprocal_50 = reciprocal_51 = None | |
| - _scaled_mm_12 = torch.ops.aten._scaled_mm.default(convert_element_type_1034, permute_652, full_default_267, full_default_267, None, None, torch.bfloat16); convert_element_type_1034 = permute_652 = None | |
| - mul_833 = torch.ops.aten.mul.Tensor(_scaled_mm_12, mul_832); _scaled_mm_12 = mul_832 = None | |
| - permute_653 = torch.ops.aten.permute.default(view_457, [1, 0]); view_457 = None | |
| - convert_element_type_620 = torch.ops.prims.convert_element_type.default(add_154, torch.float32); add_154 = None | |
| - pow_49 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_620, 2) | |
| + reciprocal_51 = torch.ops.aten.reciprocal.default(view_475); view_475 = None | |
| + mul_862 = torch.ops.aten.mul.Tensor(reciprocal_50, reciprocal_51); reciprocal_50 = reciprocal_51 = None | |
| + _scaled_mm_12 = torch.ops.aten._scaled_mm.default(convert_element_type_1079, permute_652, full_default_267, full_default_267, None, None, torch.bfloat16); convert_element_type_1079 = permute_652 = None | |
| + mul_863 = torch.ops.aten.mul.Tensor(_scaled_mm_12, mul_862); _scaled_mm_12 = mul_862 = None | |
| + convert_element_type_1083 = torch.ops.prims.convert_element_type.default(mul_863, torch.bfloat16); mul_863 = None | |
| + permute_653 = torch.ops.aten.permute.default(view_472, [1, 0]); view_472 = None | |
| + convert_element_type_618 = torch.ops.prims.convert_element_type.default(add_154, torch.float32); add_154 = None | |
| + pow_49 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_618, 2) | |
| mean_48 = torch.ops.aten.mean.dim(pow_49, [1], True); pow_49 = None | |
| add_156 = torch.ops.aten.add.Scalar(mean_48, 1.1920928955078125e-07); mean_48 = None | |
| rsqrt_69 = torch.ops.aten.rsqrt.default(add_156); add_156 = None | |
| - mul_210 = torch.ops.aten.mul.Tensor(convert_element_type_620, rsqrt_69); convert_element_type_620 = None | |
| + mul_210 = torch.ops.aten.mul.Tensor(convert_element_type_618, rsqrt_69) | |
| mul_211 = torch.ops.aten.mul.Tensor(mul_210, primals_366) | |
| - sigmoid_26 = torch.ops.aten.sigmoid.default(mul_211) | |
| - mul_213 = torch.ops.aten.mul.Tensor(mul_211, sigmoid_26) | |
| + convert_element_type_619 = torch.ops.prims.convert_element_type.default(mul_211, torch.bfloat16); mul_211 = None | |
| + sigmoid_26 = torch.ops.aten.sigmoid.default(convert_element_type_619) | |
| + mul_213 = torch.ops.aten.mul.Tensor(convert_element_type_619, sigmoid_26) | |
| slice_51 = torch.ops.aten.slice.Tensor(mm_83, 1, 4608, 9223372036854775807) | |
| index_30 = torch.ops.aten.index.Tensor(slice_51, [sub_32]); slice_51 = None | |
| add_160 = torch.ops.aten.add.Tensor(mm_84, index_30); mm_84 = index_30 = None | |
| add_162 = torch.ops.aten.add.Tensor(mul_213, add_160); add_160 = None | |
| - pow_53 = torch.ops.aten.pow.Tensor_Scalar(add_162, 2) | |
| + convert_element_type_638 = torch.ops.prims.convert_element_type.default(add_162, torch.float32); add_162 = None | |
| + pow_53 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_638, 2) | |
| mean_52 = torch.ops.aten.mean.dim(pow_53, [1], True); pow_53 = None | |
| add_164 = torch.ops.aten.add.Scalar(mean_52, 1.1920928955078125e-07); mean_52 = None | |
| rsqrt_73 = torch.ops.aten.rsqrt.default(add_164); add_164 = None | |
| - mul_222 = torch.ops.aten.mul.Tensor(add_162, rsqrt_73); add_162 = None | |
| + mul_222 = torch.ops.aten.mul.Tensor(convert_element_type_638, rsqrt_73) | |
| mul_223 = torch.ops.aten.mul.Tensor(mul_222, primals_374) | |
| - sigmoid_30 = torch.ops.aten.sigmoid.default(mul_223) | |
| - mul_225 = torch.ops.aten.mul.Tensor(mul_223, sigmoid_30) | |
| + convert_element_type_639 = torch.ops.prims.convert_element_type.default(mul_223, torch.bfloat16); mul_223 = None | |
| + sigmoid_30 = torch.ops.aten.sigmoid.default(convert_element_type_639) | |
| + mul_225 = torch.ops.aten.mul.Tensor(convert_element_type_639, sigmoid_30) | |
| slice_55 = torch.ops.aten.slice.Tensor(mm_87, 1, 4608, 9223372036854775807) | |
| index_32 = torch.ops.aten.index.Tensor(slice_55, [sub_32]); slice_55 = None | |
| add_168 = torch.ops.aten.add.Tensor(mm_88, index_32); mm_88 = index_32 = None | |
| add_170 = torch.ops.aten.add.Tensor(mul_225, add_168); add_168 = None | |
| - pow_57 = torch.ops.aten.pow.Tensor_Scalar(add_170, 2) | |
| + convert_element_type_658 = torch.ops.prims.convert_element_type.default(add_170, torch.float32); add_170 = None | |
| + pow_57 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_658, 2) | |
| mean_56 = torch.ops.aten.mean.dim(pow_57, [1], True); pow_57 = None | |
| add_172 = torch.ops.aten.add.Scalar(mean_56, 1.1920928955078125e-07); mean_56 = None | |
| rsqrt_77 = torch.ops.aten.rsqrt.default(add_172); add_172 = None | |
| - mul_234 = torch.ops.aten.mul.Tensor(add_170, rsqrt_77); add_170 = None | |
| + mul_234 = torch.ops.aten.mul.Tensor(convert_element_type_658, rsqrt_77) | |
| mul_235 = torch.ops.aten.mul.Tensor(mul_234, primals_382) | |
| - sigmoid_34 = torch.ops.aten.sigmoid.default(mul_235) | |
| - mul_237 = torch.ops.aten.mul.Tensor(mul_235, sigmoid_34) | |
| - convert_element_type_665 = torch.ops.prims.convert_element_type.default(mul_237, torch.bfloat16); mul_237 = None | |
| - mm_259 = torch.ops.aten.mm.default(permute_653, convert_element_type_665); permute_653 = convert_element_type_665 = None | |
| + convert_element_type_659 = torch.ops.prims.convert_element_type.default(mul_235, torch.bfloat16); mul_235 = None | |
| + sigmoid_34 = torch.ops.aten.sigmoid.default(convert_element_type_659) | |
| + mul_237 = torch.ops.aten.mul.Tensor(convert_element_type_659, sigmoid_34) | |
| + mm_259 = torch.ops.aten.mm.default(permute_653, mul_237); permute_653 = mul_237 = None | |
| permute_654 = torch.ops.aten.permute.default(mm_259, [1, 0]); mm_259 = None | |
| - convert_element_type_1041 = torch.ops.prims.convert_element_type.default(permute_654, torch.float32); permute_654 = None | |
| - permute_655 = torch.ops.aten.permute.default(convert_element_type_1041, [1, 0]); convert_element_type_1041 = None | |
| - convert_element_type_default_9 = torch.ops.prims.convert_element_type.default(mul_833, torch.float32); mul_833 = None | |
| - slice_scatter_21 = torch.ops.aten.slice_scatter.default(full_default_290, view_458, 1, 0, 7168); view_458 = None | |
| - add_396 = torch.ops.aten.add.Tensor(slice_scatter_20, slice_scatter_21); slice_scatter_20 = slice_scatter_21 = None | |
| - abs_54 = torch.ops.aten.abs.default(add_396) | |
| + convert_element_type_1086 = torch.ops.prims.convert_element_type.default(permute_654, torch.float32); permute_654 = None | |
| + permute_655 = torch.ops.aten.permute.default(convert_element_type_1086, [1, 0]); convert_element_type_1086 = None | |
| + slice_scatter_21 = torch.ops.aten.slice_scatter.default(full_default_290, view_473, 1, 0, 7168); view_473 = None | |
| + add_411 = torch.ops.aten.add.Tensor(slice_scatter_20, slice_scatter_21); slice_scatter_20 = slice_scatter_21 = None | |
| + abs_54 = torch.ops.aten.abs.default(add_411) | |
| amax_24 = torch.ops.aten.amax.default(abs_54, [-1], True); abs_54 = None | |
| - convert_element_type_1043 = torch.ops.prims.convert_element_type.default(amax_24, torch.float64); amax_24 = None | |
| - clamp_min_86 = torch.ops.aten.clamp_min.default(convert_element_type_1043, 1e-12); convert_element_type_1043 = None | |
| + convert_element_type_1087 = torch.ops.prims.convert_element_type.default(amax_24, torch.float64); amax_24 = None | |
| + clamp_min_86 = torch.ops.aten.clamp_min.default(convert_element_type_1087, 1e-12); convert_element_type_1087 = None | |
| reciprocal_52 = torch.ops.aten.reciprocal.default(clamp_min_86); clamp_min_86 = None | |
| - mul_834 = torch.ops.aten.mul.Tensor(reciprocal_52, 448.0); reciprocal_52 = None | |
| - convert_element_type_1044 = torch.ops.prims.convert_element_type.default(mul_834, torch.float32); mul_834 = None | |
| - log2_26 = torch.ops.aten.log2.default(convert_element_type_1044); convert_element_type_1044 = None | |
| + mul_864 = torch.ops.aten.mul.Tensor(reciprocal_52, 448.0); reciprocal_52 = None | |
| + convert_element_type_1088 = torch.ops.prims.convert_element_type.default(mul_864, torch.float32); mul_864 = None | |
| + log2_26 = torch.ops.aten.log2.default(convert_element_type_1088); convert_element_type_1088 = None | |
| floor_26 = torch.ops.aten.floor.default(log2_26); log2_26 = None | |
| exp2_26 = torch.ops.aten.exp2.default(floor_26); floor_26 = None | |
| - convert_element_type_1045 = torch.ops.prims.convert_element_type.default(add_396, torch.float32) | |
| - mul_835 = torch.ops.aten.mul.Tensor(convert_element_type_1045, exp2_26); convert_element_type_1045 = None | |
| - clamp_min_87 = torch.ops.aten.clamp_min.default(mul_835, -448.0); mul_835 = None | |
| + convert_element_type_1089 = torch.ops.prims.convert_element_type.default(add_411, torch.float32) | |
| + mul_865 = torch.ops.aten.mul.Tensor(convert_element_type_1089, exp2_26); convert_element_type_1089 = None | |
| + clamp_min_87 = torch.ops.aten.clamp_min.default(mul_865, -448.0); mul_865 = None | |
| clamp_max_56 = torch.ops.aten.clamp_max.default(clamp_min_87, 448.0); clamp_min_87 = None | |
| - convert_element_type_1046 = torch.ops.prims.convert_element_type.default(clamp_max_56, torch.float8_e4m3fn); clamp_max_56 = None | |
| + convert_element_type_1090 = torch.ops.prims.convert_element_type.default(clamp_max_56, torch.float8_e4m3fn); clamp_max_56 = None | |
| permute_227 = torch.ops.aten.permute.default(primals_383, [1, 0]); primals_383 = None | |
| abs_14 = torch.ops.aten.abs.default(permute_227) | |
| max_4 = torch.ops.aten.max.default(abs_14); abs_14 = None | |
| - convert_element_type_1047 = torch.ops.prims.convert_element_type.default(max_4, torch.float64); max_4 = None | |
| - clamp_min_88 = torch.ops.aten.clamp_min.default(convert_element_type_1047, 1e-12); convert_element_type_1047 = None | |
| + convert_element_type_1091 = torch.ops.prims.convert_element_type.default(max_4, torch.float64); max_4 = None | |
| + clamp_min_88 = torch.ops.aten.clamp_min.default(convert_element_type_1091, 1e-12); convert_element_type_1091 = None | |
| reciprocal_53 = torch.ops.aten.reciprocal.default(clamp_min_88); clamp_min_88 = None | |
| - mul_836 = torch.ops.aten.mul.Tensor(reciprocal_53, 448.0); reciprocal_53 = None | |
| - convert_element_type_1048 = torch.ops.prims.convert_element_type.default(mul_836, torch.float32); mul_836 = None | |
| - log2_27 = torch.ops.aten.log2.default(convert_element_type_1048); convert_element_type_1048 = None | |
| + mul_866 = torch.ops.aten.mul.Tensor(reciprocal_53, 448.0); reciprocal_53 = None | |
| + convert_element_type_1092 = torch.ops.prims.convert_element_type.default(mul_866, torch.float32); mul_866 = None | |
| + log2_27 = torch.ops.aten.log2.default(convert_element_type_1092); convert_element_type_1092 = None | |
| floor_27 = torch.ops.aten.floor.default(log2_27); log2_27 = None | |
| exp2_27 = torch.ops.aten.exp2.default(floor_27); floor_27 = None | |
| - mul_837 = torch.ops.aten.mul.Tensor(permute_227, exp2_27); permute_227 = None | |
| - clamp_min_89 = torch.ops.aten.clamp_min.default(mul_837, -448.0); mul_837 = None | |
| + mul_867 = torch.ops.aten.mul.Tensor(permute_227, exp2_27); permute_227 = None | |
| + clamp_min_89 = torch.ops.aten.clamp_min.default(mul_867, -448.0); mul_867 = None | |
| clamp_max_57 = torch.ops.aten.clamp_max.default(clamp_min_89, 448.0); clamp_min_89 = None | |
| - convert_element_type_1049 = torch.ops.prims.convert_element_type.default(clamp_max_57, torch.float8_e4m3fn); clamp_max_57 = None | |
| - clone_94 = torch.ops.aten.clone.default(convert_element_type_1049, memory_format = torch.contiguous_format); convert_element_type_1049 = None | |
| + convert_element_type_1093 = torch.ops.prims.convert_element_type.default(clamp_max_57, torch.float8_e4m3fn); clamp_max_57 = None | |
| + clone_94 = torch.ops.aten.clone.default(convert_element_type_1093, memory_format = torch.contiguous_format); convert_element_type_1093 = None | |
| permute_658 = torch.ops.aten.permute.default(clone_94, [1, 0]); clone_94 = None | |
| repeat_7 = torch.ops.aten.repeat.default(exp2_27, [4608]); exp2_27 = None | |
| - view_465 = torch.ops.aten.view.default(repeat_7, [1, -1]); repeat_7 = None | |
| + view_479 = torch.ops.aten.view.default(repeat_7, [1, -1]); repeat_7 = None | |
| reciprocal_54 = torch.ops.aten.reciprocal.default(exp2_26); exp2_26 = None | |
| - reciprocal_55 = torch.ops.aten.reciprocal.default(view_465); view_465 = None | |
| - mul_838 = torch.ops.aten.mul.Tensor(reciprocal_54, reciprocal_55); reciprocal_54 = reciprocal_55 = None | |
| - _scaled_mm_13 = torch.ops.aten._scaled_mm.default(convert_element_type_1046, permute_658, full_default_267, full_default_267, None, None, torch.bfloat16); convert_element_type_1046 = permute_658 = None | |
| - mul_839 = torch.ops.aten.mul.Tensor(_scaled_mm_13, mul_838); _scaled_mm_13 = mul_838 = None | |
| - permute_659 = torch.ops.aten.permute.default(add_396, [1, 0]); add_396 = None | |
| - convert_element_type_619 = torch.ops.prims.convert_element_type.default(slice_46, torch.float32); slice_46 = None | |
| - pow_48 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_619, 2) | |
| + reciprocal_55 = torch.ops.aten.reciprocal.default(view_479); view_479 = None | |
| + mul_868 = torch.ops.aten.mul.Tensor(reciprocal_54, reciprocal_55); reciprocal_54 = reciprocal_55 = None | |
| + _scaled_mm_13 = torch.ops.aten._scaled_mm.default(convert_element_type_1090, permute_658, full_default_267, full_default_267, None, None, torch.bfloat16); convert_element_type_1090 = permute_658 = None | |
| + mul_869 = torch.ops.aten.mul.Tensor(_scaled_mm_13, mul_868); _scaled_mm_13 = mul_868 = None | |
| + convert_element_type_1094 = torch.ops.prims.convert_element_type.default(mul_869, torch.bfloat16); mul_869 = None | |
| + permute_659 = torch.ops.aten.permute.default(add_411, [1, 0]); add_411 = None | |
| + convert_element_type_616 = torch.ops.prims.convert_element_type.default(slice_46, torch.float32); slice_46 = None | |
| + pow_48 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_616, 2) | |
| mean_47 = torch.ops.aten.mean.dim(pow_48, [1], True); pow_48 = None | |
| add_155 = torch.ops.aten.add.Scalar(mean_47, 1.1920928955078125e-07); mean_47 = None | |
| rsqrt_68 = torch.ops.aten.rsqrt.default(add_155); add_155 = None | |
| - mul_208 = torch.ops.aten.mul.Tensor(convert_element_type_619, rsqrt_68); convert_element_type_619 = None | |
| + mul_208 = torch.ops.aten.mul.Tensor(convert_element_type_616, rsqrt_68) | |
| mul_209 = torch.ops.aten.mul.Tensor(mul_208, primals_365) | |
| - sigmoid_25 = torch.ops.aten.sigmoid.default(mul_209) | |
| - mul_212 = torch.ops.aten.mul.Tensor(mul_209, sigmoid_25) | |
| + convert_element_type_617 = torch.ops.prims.convert_element_type.default(mul_209, torch.bfloat16); mul_209 = None | |
| + sigmoid_25 = torch.ops.aten.sigmoid.default(convert_element_type_617) | |
| + mul_212 = torch.ops.aten.mul.Tensor(convert_element_type_617, sigmoid_25) | |
| slice_50 = torch.ops.aten.slice.Tensor(mm_83, 1, 0, 4608); mm_83 = None | |
| add_161 = torch.ops.aten.add.Tensor(mul_212, slice_50); slice_50 = None | |
| - pow_52 = torch.ops.aten.pow.Tensor_Scalar(add_161, 2) | |
| + convert_element_type_636 = torch.ops.prims.convert_element_type.default(add_161, torch.float32); add_161 = None | |
| + pow_52 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_636, 2) | |
| mean_51 = torch.ops.aten.mean.dim(pow_52, [1], True); pow_52 = None | |
| add_163 = torch.ops.aten.add.Scalar(mean_51, 1.1920928955078125e-07); mean_51 = None | |
| rsqrt_72 = torch.ops.aten.rsqrt.default(add_163); add_163 = None | |
| - mul_220 = torch.ops.aten.mul.Tensor(add_161, rsqrt_72); add_161 = None | |
| + mul_220 = torch.ops.aten.mul.Tensor(convert_element_type_636, rsqrt_72) | |
| mul_221 = torch.ops.aten.mul.Tensor(mul_220, primals_373) | |
| - sigmoid_29 = torch.ops.aten.sigmoid.default(mul_221) | |
| - mul_224 = torch.ops.aten.mul.Tensor(mul_221, sigmoid_29) | |
| + convert_element_type_637 = torch.ops.prims.convert_element_type.default(mul_221, torch.bfloat16); mul_221 = None | |
| + sigmoid_29 = torch.ops.aten.sigmoid.default(convert_element_type_637) | |
| + mul_224 = torch.ops.aten.mul.Tensor(convert_element_type_637, sigmoid_29) | |
| slice_54 = torch.ops.aten.slice.Tensor(mm_87, 1, 0, 4608); mm_87 = None | |
| add_169 = torch.ops.aten.add.Tensor(mul_224, slice_54); slice_54 = None | |
| - pow_56 = torch.ops.aten.pow.Tensor_Scalar(add_169, 2) | |
| + convert_element_type_656 = torch.ops.prims.convert_element_type.default(add_169, torch.float32); add_169 = None | |
| + pow_56 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_656, 2) | |
| mean_55 = torch.ops.aten.mean.dim(pow_56, [1], True); pow_56 = None | |
| add_171 = torch.ops.aten.add.Scalar(mean_55, 1.1920928955078125e-07); mean_55 = None | |
| rsqrt_76 = torch.ops.aten.rsqrt.default(add_171); add_171 = None | |
| - mul_232 = torch.ops.aten.mul.Tensor(add_169, rsqrt_76); add_169 = None | |
| + mul_232 = torch.ops.aten.mul.Tensor(convert_element_type_656, rsqrt_76) | |
| mul_233 = torch.ops.aten.mul.Tensor(mul_232, primals_381) | |
| - sigmoid_33 = torch.ops.aten.sigmoid.default(mul_233) | |
| - mul_236 = torch.ops.aten.mul.Tensor(mul_233, sigmoid_33) | |
| - convert_element_type_657 = torch.ops.prims.convert_element_type.default(mul_236, torch.bfloat16); mul_236 = None | |
| - mm_260 = torch.ops.aten.mm.default(permute_659, convert_element_type_657); permute_659 = convert_element_type_657 = None | |
| + convert_element_type_657 = torch.ops.prims.convert_element_type.default(mul_233, torch.bfloat16); mul_233 = None | |
| + sigmoid_33 = torch.ops.aten.sigmoid.default(convert_element_type_657) | |
| + mul_236 = torch.ops.aten.mul.Tensor(convert_element_type_657, sigmoid_33) | |
| + mm_260 = torch.ops.aten.mm.default(permute_659, mul_236); permute_659 = mul_236 = None | |
| permute_660 = torch.ops.aten.permute.default(mm_260, [1, 0]); mm_260 = None | |
| - convert_element_type_1053 = torch.ops.prims.convert_element_type.default(permute_660, torch.float32); permute_660 = None | |
| - permute_661 = torch.ops.aten.permute.default(convert_element_type_1053, [1, 0]); convert_element_type_1053 = None | |
| - convert_element_type_default_8 = torch.ops.prims.convert_element_type.default(mul_839, torch.float32); mul_839 = None | |
| - mul_840 = torch.ops.aten.mul.Tensor(convert_element_type_default_9, mul_235); mul_235 = None | |
| - mul_841 = torch.ops.aten.mul.Tensor(convert_element_type_default_9, sigmoid_34); convert_element_type_default_9 = None | |
| - sub_220 = torch.ops.aten.sub.Tensor(1, sigmoid_34) | |
| - mul_842 = torch.ops.aten.mul.Tensor(sigmoid_34, sub_220); sigmoid_34 = sub_220 = None | |
| - mul_843 = torch.ops.aten.mul.Tensor(mul_840, mul_842); mul_840 = mul_842 = None | |
| - add_397 = torch.ops.aten.add.Tensor(mul_841, mul_843); mul_841 = mul_843 = None | |
| - mul_844 = torch.ops.aten.mul.Tensor(convert_element_type_default_8, mul_233); mul_233 = None | |
| - mul_845 = torch.ops.aten.mul.Tensor(convert_element_type_default_8, sigmoid_33); convert_element_type_default_8 = None | |
| - sub_221 = torch.ops.aten.sub.Tensor(1, sigmoid_33) | |
| - mul_846 = torch.ops.aten.mul.Tensor(sigmoid_33, sub_221); sigmoid_33 = sub_221 = None | |
| - mul_847 = torch.ops.aten.mul.Tensor(mul_844, mul_846); mul_844 = mul_846 = None | |
| - add_398 = torch.ops.aten.add.Tensor(mul_845, mul_847); mul_845 = mul_847 = None | |
| - mul_848 = torch.ops.aten.mul.Tensor(add_397, primals_382); primals_382 = None | |
| - mul_850 = torch.ops.aten.mul.Tensor(mul_234, mul_848) | |
| - sum_130 = torch.ops.aten.sum.dim_IntList(mul_850, [1], True); mul_850 = None | |
| - div_105 = torch.ops.aten.div.Tensor(mul_234, 4608) | |
| - mul_851 = torch.ops.aten.mul.Tensor(div_105, sum_130); div_105 = sum_130 = None | |
| - sub_222 = torch.ops.aten.sub.Tensor(mul_848, mul_851); mul_848 = mul_851 = None | |
| - mul_852 = torch.ops.aten.mul.Tensor(sub_222, rsqrt_77); sub_222 = rsqrt_77 = None | |
| - mul_853 = torch.ops.aten.mul.Tensor(add_397, mul_234); add_397 = mul_234 = None | |
| - sum_131 = torch.ops.aten.sum.dim_IntList(mul_853, [0]); mul_853 = None | |
| - mul_854 = torch.ops.aten.mul.Tensor(add_398, primals_381); primals_381 = None | |
| - mul_856 = torch.ops.aten.mul.Tensor(mul_232, mul_854) | |
| - sum_132 = torch.ops.aten.sum.dim_IntList(mul_856, [1], True); mul_856 = None | |
| - div_106 = torch.ops.aten.div.Tensor(mul_232, 4608) | |
| - mul_857 = torch.ops.aten.mul.Tensor(div_106, sum_132); div_106 = sum_132 = None | |
| - sub_223 = torch.ops.aten.sub.Tensor(mul_854, mul_857); mul_854 = mul_857 = None | |
| - mul_858 = torch.ops.aten.mul.Tensor(sub_223, rsqrt_76); sub_223 = rsqrt_76 = None | |
| - mul_859 = torch.ops.aten.mul.Tensor(add_398, mul_232); add_398 = mul_232 = None | |
| - sum_133 = torch.ops.aten.sum.dim_IntList(mul_859, [0]); mul_859 = None | |
| - convert_element_type_1055 = torch.ops.prims.convert_element_type.default(mul_852, torch.bfloat16) | |
| - convert_element_type_1056 = torch.ops.prims.convert_element_type.default(mul_858, torch.bfloat16) | |
| + convert_element_type_1097 = torch.ops.prims.convert_element_type.default(permute_660, torch.float32); permute_660 = None | |
| + permute_661 = torch.ops.aten.permute.default(convert_element_type_1097, [1, 0]); convert_element_type_1097 = None | |
| + mul_870 = torch.ops.aten.mul.Tensor(convert_element_type_1083, convert_element_type_659); convert_element_type_659 = None | |
| + mul_871 = torch.ops.aten.mul.Tensor(convert_element_type_1083, sigmoid_34); convert_element_type_1083 = None | |
| + convert_element_type_1098 = torch.ops.prims.convert_element_type.default(mul_870, torch.float32); mul_870 = None | |
| + convert_element_type_1099 = torch.ops.prims.convert_element_type.default(sigmoid_34, torch.float32); sigmoid_34 = None | |
| + sub_205 = torch.ops.aten.sub.Tensor(1, convert_element_type_1099) | |
| + mul_872 = torch.ops.aten.mul.Tensor(convert_element_type_1099, sub_205); convert_element_type_1099 = sub_205 = None | |
| + mul_873 = torch.ops.aten.mul.Tensor(convert_element_type_1098, mul_872); convert_element_type_1098 = mul_872 = None | |
| + convert_element_type_1100 = torch.ops.prims.convert_element_type.default(mul_873, torch.bfloat16); mul_873 = None | |
| + add_412 = torch.ops.aten.add.Tensor(mul_871, convert_element_type_1100); mul_871 = convert_element_type_1100 = None | |
| + mul_874 = torch.ops.aten.mul.Tensor(convert_element_type_1094, convert_element_type_657); convert_element_type_657 = None | |
| + mul_875 = torch.ops.aten.mul.Tensor(convert_element_type_1094, sigmoid_33); convert_element_type_1094 = None | |
| + convert_element_type_1101 = torch.ops.prims.convert_element_type.default(mul_874, torch.float32); mul_874 = None | |
| + convert_element_type_1102 = torch.ops.prims.convert_element_type.default(sigmoid_33, torch.float32); sigmoid_33 = None | |
| + sub_206 = torch.ops.aten.sub.Tensor(1, convert_element_type_1102) | |
| + mul_876 = torch.ops.aten.mul.Tensor(convert_element_type_1102, sub_206); convert_element_type_1102 = sub_206 = None | |
| + mul_877 = torch.ops.aten.mul.Tensor(convert_element_type_1101, mul_876); convert_element_type_1101 = mul_876 = None | |
| + convert_element_type_1103 = torch.ops.prims.convert_element_type.default(mul_877, torch.bfloat16); mul_877 = None | |
| + add_413 = torch.ops.aten.add.Tensor(mul_875, convert_element_type_1103); mul_875 = convert_element_type_1103 = None | |
| + convert_element_type_1104 = torch.ops.prims.convert_element_type.default(add_412, torch.float32); add_412 = None | |
| + mul_878 = torch.ops.aten.mul.Tensor(convert_element_type_1104, mul_234); mul_234 = None | |
| + mul_879 = torch.ops.aten.mul.Tensor(convert_element_type_1104, primals_382); convert_element_type_1104 = primals_382 = None | |
| + sum_130 = torch.ops.aten.sum.dim_IntList(mul_878, [0], True); mul_878 = None | |
| + view_484 = torch.ops.aten.view.default(sum_130, [4608]); sum_130 = None | |
| + mul_880 = torch.ops.aten.mul.Tensor(mul_879, convert_element_type_658) | |
| + mul_881 = torch.ops.aten.mul.Tensor(mul_879, rsqrt_77); mul_879 = None | |
| + sum_131 = torch.ops.aten.sum.dim_IntList(mul_880, [1], True); mul_880 = None | |
| + mul_882 = torch.ops.aten.mul.Scalar(sum_131, -0.5); sum_131 = None | |
| + pow_114 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_77, 3); rsqrt_77 = None | |
| + mul_883 = torch.ops.aten.mul.Tensor(mul_882, pow_114); mul_882 = pow_114 = None | |
| + expand_78 = torch.ops.aten.expand.default(mul_883, [4096, 4608]); mul_883 = None | |
| + div_105 = torch.ops.aten.div.Scalar(expand_78, 4608); expand_78 = None | |
| + pow_115 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_658, 1.0); convert_element_type_658 = None | |
| + mul_884 = torch.ops.aten.mul.Scalar(pow_115, 2.0); pow_115 = None | |
| + mul_885 = torch.ops.aten.mul.Tensor(div_105, mul_884); div_105 = mul_884 = None | |
| + add_414 = torch.ops.aten.add.Tensor(mul_881, mul_885); mul_881 = mul_885 = None | |
| + convert_element_type_1105 = torch.ops.prims.convert_element_type.default(add_414, torch.bfloat16); add_414 = None | |
| + convert_element_type_1106 = torch.ops.prims.convert_element_type.default(add_413, torch.float32); add_413 = None | |
| + mul_886 = torch.ops.aten.mul.Tensor(convert_element_type_1106, mul_232); mul_232 = None | |
| + mul_887 = torch.ops.aten.mul.Tensor(convert_element_type_1106, primals_381); convert_element_type_1106 = primals_381 = None | |
| + sum_132 = torch.ops.aten.sum.dim_IntList(mul_886, [0], True); mul_886 = None | |
| + view_485 = torch.ops.aten.view.default(sum_132, [4608]); sum_132 = None | |
| + mul_888 = torch.ops.aten.mul.Tensor(mul_887, convert_element_type_656) | |
| + mul_889 = torch.ops.aten.mul.Tensor(mul_887, rsqrt_76); mul_887 = None | |
| + sum_133 = torch.ops.aten.sum.dim_IntList(mul_888, [1], True); mul_888 = None | |
| + mul_890 = torch.ops.aten.mul.Scalar(sum_133, -0.5); sum_133 = None | |
| + pow_116 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_76, 3); rsqrt_76 = None | |
| + mul_891 = torch.ops.aten.mul.Tensor(mul_890, pow_116); mul_890 = pow_116 = None | |
| + expand_79 = torch.ops.aten.expand.default(mul_891, [4096, 4608]); mul_891 = None | |
| + div_106 = torch.ops.aten.div.Scalar(expand_79, 4608); expand_79 = None | |
| + pow_117 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_656, 1.0); convert_element_type_656 = None | |
| + mul_892 = torch.ops.aten.mul.Scalar(pow_117, 2.0); pow_117 = None | |
| + mul_893 = torch.ops.aten.mul.Tensor(div_106, mul_892); div_106 = mul_892 = None | |
| + add_415 = torch.ops.aten.add.Tensor(mul_889, mul_893); mul_889 = mul_893 = None | |
| + convert_element_type_1107 = torch.ops.prims.convert_element_type.default(add_415, torch.bfloat16); add_415 = None | |
| full_default_296 = torch.ops.aten.full.default([4096, 4608], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| - index_put_30 = torch.ops.aten.index_put.default(full_default_296, [sub_32], convert_element_type_1055, True) | |
| + index_put_30 = torch.ops.aten.index_put.default(full_default_296, [sub_32], convert_element_type_1105, True) | |
| full_default_297 = torch.ops.aten.full.default([4096, 9216], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| slice_scatter_22 = torch.ops.aten.slice_scatter.default(full_default_297, index_put_30, 1, 4608, 9223372036854775807); index_put_30 = None | |
| - permute_662 = torch.ops.aten.permute.default(convert_element_type_1055, [1, 0]) | |
| + permute_662 = torch.ops.aten.permute.default(convert_element_type_1105, [1, 0]) | |
| slice_53 = torch.ops.aten.slice.Tensor(mm_85, 1, 1536, 9223372036854775807) | |
| index_31 = torch.ops.aten.index.Tensor(slice_53, [sub_32]); slice_53 = None | |
| add_165 = torch.ops.aten.add.Tensor(mm_86, index_31); mm_86 = index_31 = None | |
| @@ -6573,307 +3633,394 @@ | |
| mean_54 = torch.ops.aten.mean.dim(pow_55, [1], True); pow_55 = None | |
| add_167 = torch.ops.aten.add.Scalar(mean_54, 1.1920928955078125e-07); mean_54 = None | |
| rsqrt_75 = torch.ops.aten.rsqrt.default(add_167); add_167 = None | |
| - mul_228 = torch.ops.aten.mul.Tensor(convert_element_type_648, rsqrt_75); convert_element_type_648 = None | |
| + mul_228 = torch.ops.aten.mul.Tensor(convert_element_type_648, rsqrt_75) | |
| mul_229 = torch.ops.aten.mul.Tensor(mul_228, primals_378) | |
| - sigmoid_32 = torch.ops.aten.sigmoid.default(mul_229) | |
| - mul_231 = torch.ops.aten.mul.Tensor(mul_229, sigmoid_32) | |
| - convert_element_type_653 = torch.ops.prims.convert_element_type.default(mul_231, torch.bfloat16); mul_231 = None | |
| - mm_261 = torch.ops.aten.mm.default(permute_662, convert_element_type_653); permute_662 = convert_element_type_653 = None | |
| - convert_element_type_654 = torch.ops.prims.convert_element_type.default(primals_380, torch.bfloat16); primals_380 = None | |
| - permute_226 = torch.ops.aten.permute.default(convert_element_type_654, [1, 0]); convert_element_type_654 = None | |
| + convert_element_type_649 = torch.ops.prims.convert_element_type.default(mul_229, torch.bfloat16); mul_229 = None | |
| + sigmoid_32 = torch.ops.aten.sigmoid.default(convert_element_type_649) | |
| + mul_231 = torch.ops.aten.mul.Tensor(convert_element_type_649, sigmoid_32) | |
| + mm_261 = torch.ops.aten.mm.default(permute_662, mul_231); permute_662 = mul_231 = None | |
| + convert_element_type_653 = torch.ops.prims.convert_element_type.default(primals_380, torch.bfloat16); primals_380 = None | |
| + permute_226 = torch.ops.aten.permute.default(convert_element_type_653, [1, 0]); convert_element_type_653 = None | |
| permute_664 = torch.ops.aten.permute.default(permute_226, [1, 0]); permute_226 = None | |
| - mm_262 = torch.ops.aten.mm.default(convert_element_type_1055, permute_664); convert_element_type_1055 = permute_664 = None | |
| - convert_element_type_1061 = torch.ops.prims.convert_element_type.default(mm_261, torch.float32); mm_261 = None | |
| - convert_element_type_1062 = torch.ops.prims.convert_element_type.default(mm_262, torch.float32); mm_262 = None | |
| - slice_scatter_23 = torch.ops.aten.slice_scatter.default(full_default_297, convert_element_type_1056, 1, 0, 4608); convert_element_type_1056 = None | |
| - add_399 = torch.ops.aten.add.Tensor(slice_scatter_22, slice_scatter_23); slice_scatter_22 = slice_scatter_23 = None | |
| - permute_666 = torch.ops.aten.permute.default(add_399, [1, 0]) | |
| + mm_262 = torch.ops.aten.mm.default(convert_element_type_1105, permute_664); permute_664 = None | |
| + convert_element_type_1112 = torch.ops.prims.convert_element_type.default(mm_261, torch.float32); mm_261 = None | |
| + slice_scatter_23 = torch.ops.aten.slice_scatter.default(full_default_297, convert_element_type_1107, 1, 0, 4608) | |
| + add_416 = torch.ops.aten.add.Tensor(slice_scatter_22, slice_scatter_23); slice_scatter_22 = slice_scatter_23 = None | |
| + permute_666 = torch.ops.aten.permute.default(add_416, [1, 0]) | |
| slice_52 = torch.ops.aten.slice.Tensor(mm_85, 1, 0, 1536); mm_85 = None | |
| - convert_element_type_647 = torch.ops.prims.convert_element_type.default(slice_52, torch.float32); slice_52 = None | |
| - pow_54 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_647, 2) | |
| + convert_element_type_646 = torch.ops.prims.convert_element_type.default(slice_52, torch.float32); slice_52 = None | |
| + pow_54 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_646, 2) | |
| mean_53 = torch.ops.aten.mean.dim(pow_54, [1], True); pow_54 = None | |
| add_166 = torch.ops.aten.add.Scalar(mean_53, 1.1920928955078125e-07); mean_53 = None | |
| rsqrt_74 = torch.ops.aten.rsqrt.default(add_166); add_166 = None | |
| - mul_226 = torch.ops.aten.mul.Tensor(convert_element_type_647, rsqrt_74); convert_element_type_647 = None | |
| + mul_226 = torch.ops.aten.mul.Tensor(convert_element_type_646, rsqrt_74) | |
| mul_227 = torch.ops.aten.mul.Tensor(mul_226, primals_377) | |
| - sigmoid_31 = torch.ops.aten.sigmoid.default(mul_227) | |
| - mul_230 = torch.ops.aten.mul.Tensor(mul_227, sigmoid_31) | |
| - convert_element_type_649 = torch.ops.prims.convert_element_type.default(mul_230, torch.bfloat16); mul_230 = None | |
| - mm_263 = torch.ops.aten.mm.default(permute_666, convert_element_type_649); permute_666 = convert_element_type_649 = None | |
| + convert_element_type_647 = torch.ops.prims.convert_element_type.default(mul_227, torch.bfloat16); mul_227 = None | |
| + sigmoid_31 = torch.ops.aten.sigmoid.default(convert_element_type_647) | |
| + mul_230 = torch.ops.aten.mul.Tensor(convert_element_type_647, sigmoid_31) | |
| + mm_263 = torch.ops.aten.mm.default(permute_666, mul_230); permute_666 = mul_230 = None | |
| convert_element_type_650 = torch.ops.prims.convert_element_type.default(primals_379, torch.bfloat16); primals_379 = None | |
| permute_225 = torch.ops.aten.permute.default(convert_element_type_650, [1, 0]); convert_element_type_650 = None | |
| permute_668 = torch.ops.aten.permute.default(permute_225, [1, 0]); permute_225 = None | |
| - mm_264 = torch.ops.aten.mm.default(add_399, permute_668); add_399 = permute_668 = None | |
| - convert_element_type_1067 = torch.ops.prims.convert_element_type.default(mm_263, torch.float32); mm_263 = None | |
| - convert_element_type_1068 = torch.ops.prims.convert_element_type.default(mm_264, torch.float32); mm_264 = None | |
| - mul_860 = torch.ops.aten.mul.Tensor(convert_element_type_1062, mul_229); mul_229 = None | |
| - mul_861 = torch.ops.aten.mul.Tensor(convert_element_type_1062, sigmoid_32); convert_element_type_1062 = None | |
| - sub_224 = torch.ops.aten.sub.Tensor(1, sigmoid_32) | |
| - mul_862 = torch.ops.aten.mul.Tensor(sigmoid_32, sub_224); sigmoid_32 = sub_224 = None | |
| - mul_863 = torch.ops.aten.mul.Tensor(mul_860, mul_862); mul_860 = mul_862 = None | |
| - add_400 = torch.ops.aten.add.Tensor(mul_861, mul_863); mul_861 = mul_863 = None | |
| - mul_864 = torch.ops.aten.mul.Tensor(convert_element_type_1068, mul_227); mul_227 = None | |
| - mul_865 = torch.ops.aten.mul.Tensor(convert_element_type_1068, sigmoid_31); convert_element_type_1068 = None | |
| - sub_225 = torch.ops.aten.sub.Tensor(1, sigmoid_31) | |
| - mul_866 = torch.ops.aten.mul.Tensor(sigmoid_31, sub_225); sigmoid_31 = sub_225 = None | |
| - mul_867 = torch.ops.aten.mul.Tensor(mul_864, mul_866); mul_864 = mul_866 = None | |
| - add_401 = torch.ops.aten.add.Tensor(mul_865, mul_867); mul_865 = mul_867 = None | |
| - mul_868 = torch.ops.aten.mul.Tensor(add_400, primals_378); primals_378 = None | |
| - mul_870 = torch.ops.aten.mul.Tensor(mul_228, mul_868) | |
| - sum_134 = torch.ops.aten.sum.dim_IntList(mul_870, [1], True); mul_870 = None | |
| - div_107 = torch.ops.aten.div.Tensor(mul_228, 1536) | |
| - mul_871 = torch.ops.aten.mul.Tensor(div_107, sum_134); div_107 = sum_134 = None | |
| - sub_226 = torch.ops.aten.sub.Tensor(mul_868, mul_871); mul_868 = mul_871 = None | |
| - mul_872 = torch.ops.aten.mul.Tensor(sub_226, rsqrt_75); sub_226 = rsqrt_75 = None | |
| - mul_873 = torch.ops.aten.mul.Tensor(add_400, mul_228); add_400 = mul_228 = None | |
| - sum_135 = torch.ops.aten.sum.dim_IntList(mul_873, [0]); mul_873 = None | |
| - convert_element_type_1069 = torch.ops.prims.convert_element_type.default(mul_872, torch.bfloat16); mul_872 = None | |
| - mul_874 = torch.ops.aten.mul.Tensor(add_401, primals_377); primals_377 = None | |
| - mul_876 = torch.ops.aten.mul.Tensor(mul_226, mul_874) | |
| - sum_136 = torch.ops.aten.sum.dim_IntList(mul_876, [1], True); mul_876 = None | |
| - div_108 = torch.ops.aten.div.Tensor(mul_226, 1536) | |
| - mul_877 = torch.ops.aten.mul.Tensor(div_108, sum_136); div_108 = sum_136 = None | |
| - sub_227 = torch.ops.aten.sub.Tensor(mul_874, mul_877); mul_874 = mul_877 = None | |
| - mul_878 = torch.ops.aten.mul.Tensor(sub_227, rsqrt_74); sub_227 = rsqrt_74 = None | |
| - mul_879 = torch.ops.aten.mul.Tensor(add_401, mul_226); add_401 = mul_226 = None | |
| - sum_137 = torch.ops.aten.sum.dim_IntList(mul_879, [0]); mul_879 = None | |
| - convert_element_type_1070 = torch.ops.prims.convert_element_type.default(mul_878, torch.bfloat16); mul_878 = None | |
| - index_put_31 = torch.ops.aten.index_put.default(full_default_281, [sub_32], convert_element_type_1069, True) | |
| + mm_264 = torch.ops.aten.mm.default(add_416, permute_668); add_416 = permute_668 = None | |
| + convert_element_type_1117 = torch.ops.prims.convert_element_type.default(mm_263, torch.float32); mm_263 = None | |
| + mul_894 = torch.ops.aten.mul.Tensor(mm_262, convert_element_type_649); convert_element_type_649 = None | |
| + mul_895 = torch.ops.aten.mul.Tensor(mm_262, sigmoid_32); mm_262 = None | |
| + convert_element_type_1118 = torch.ops.prims.convert_element_type.default(mul_894, torch.float32); mul_894 = None | |
| + convert_element_type_1119 = torch.ops.prims.convert_element_type.default(sigmoid_32, torch.float32); sigmoid_32 = None | |
| + sub_207 = torch.ops.aten.sub.Tensor(1, convert_element_type_1119) | |
| + mul_896 = torch.ops.aten.mul.Tensor(convert_element_type_1119, sub_207); convert_element_type_1119 = sub_207 = None | |
| + mul_897 = torch.ops.aten.mul.Tensor(convert_element_type_1118, mul_896); convert_element_type_1118 = mul_896 = None | |
| + convert_element_type_1120 = torch.ops.prims.convert_element_type.default(mul_897, torch.bfloat16); mul_897 = None | |
| + add_417 = torch.ops.aten.add.Tensor(mul_895, convert_element_type_1120); mul_895 = convert_element_type_1120 = None | |
| + mul_898 = torch.ops.aten.mul.Tensor(mm_264, convert_element_type_647); convert_element_type_647 = None | |
| + mul_899 = torch.ops.aten.mul.Tensor(mm_264, sigmoid_31); mm_264 = None | |
| + convert_element_type_1121 = torch.ops.prims.convert_element_type.default(mul_898, torch.float32); mul_898 = None | |
| + convert_element_type_1122 = torch.ops.prims.convert_element_type.default(sigmoid_31, torch.float32); sigmoid_31 = None | |
| + sub_208 = torch.ops.aten.sub.Tensor(1, convert_element_type_1122) | |
| + mul_900 = torch.ops.aten.mul.Tensor(convert_element_type_1122, sub_208); convert_element_type_1122 = sub_208 = None | |
| + mul_901 = torch.ops.aten.mul.Tensor(convert_element_type_1121, mul_900); convert_element_type_1121 = mul_900 = None | |
| + convert_element_type_1123 = torch.ops.prims.convert_element_type.default(mul_901, torch.bfloat16); mul_901 = None | |
| + add_418 = torch.ops.aten.add.Tensor(mul_899, convert_element_type_1123); mul_899 = convert_element_type_1123 = None | |
| + convert_element_type_1124 = torch.ops.prims.convert_element_type.default(add_417, torch.float32); add_417 = None | |
| + mul_902 = torch.ops.aten.mul.Tensor(convert_element_type_1124, mul_228); mul_228 = None | |
| + mul_903 = torch.ops.aten.mul.Tensor(convert_element_type_1124, primals_378); convert_element_type_1124 = primals_378 = None | |
| + sum_134 = torch.ops.aten.sum.dim_IntList(mul_902, [0], True); mul_902 = None | |
| + view_486 = torch.ops.aten.view.default(sum_134, [1536]); sum_134 = None | |
| + mul_904 = torch.ops.aten.mul.Tensor(mul_903, convert_element_type_648) | |
| + mul_905 = torch.ops.aten.mul.Tensor(mul_903, rsqrt_75); mul_903 = None | |
| + sum_135 = torch.ops.aten.sum.dim_IntList(mul_904, [1], True); mul_904 = None | |
| + mul_906 = torch.ops.aten.mul.Scalar(sum_135, -0.5); sum_135 = None | |
| + pow_118 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_75, 3); rsqrt_75 = None | |
| + mul_907 = torch.ops.aten.mul.Tensor(mul_906, pow_118); mul_906 = pow_118 = None | |
| + expand_80 = torch.ops.aten.expand.default(mul_907, [4096, 1536]); mul_907 = None | |
| + div_107 = torch.ops.aten.div.Scalar(expand_80, 1536); expand_80 = None | |
| + pow_119 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_648, 1.0); convert_element_type_648 = None | |
| + mul_908 = torch.ops.aten.mul.Scalar(pow_119, 2.0); pow_119 = None | |
| + mul_909 = torch.ops.aten.mul.Tensor(div_107, mul_908); div_107 = mul_908 = None | |
| + add_419 = torch.ops.aten.add.Tensor(mul_905, mul_909); mul_905 = mul_909 = None | |
| + convert_element_type_1125 = torch.ops.prims.convert_element_type.default(add_419, torch.bfloat16); add_419 = None | |
| + convert_element_type_1126 = torch.ops.prims.convert_element_type.default(add_418, torch.float32); add_418 = None | |
| + mul_910 = torch.ops.aten.mul.Tensor(convert_element_type_1126, mul_226); mul_226 = None | |
| + mul_911 = torch.ops.aten.mul.Tensor(convert_element_type_1126, primals_377); convert_element_type_1126 = primals_377 = None | |
| + sum_136 = torch.ops.aten.sum.dim_IntList(mul_910, [0], True); mul_910 = None | |
| + view_487 = torch.ops.aten.view.default(sum_136, [1536]); sum_136 = None | |
| + mul_912 = torch.ops.aten.mul.Tensor(mul_911, convert_element_type_646) | |
| + mul_913 = torch.ops.aten.mul.Tensor(mul_911, rsqrt_74); mul_911 = None | |
| + sum_137 = torch.ops.aten.sum.dim_IntList(mul_912, [1], True); mul_912 = None | |
| + mul_914 = torch.ops.aten.mul.Scalar(sum_137, -0.5); sum_137 = None | |
| + pow_120 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_74, 3); rsqrt_74 = None | |
| + mul_915 = torch.ops.aten.mul.Tensor(mul_914, pow_120); mul_914 = pow_120 = None | |
| + expand_81 = torch.ops.aten.expand.default(mul_915, [4096, 1536]); mul_915 = None | |
| + div_108 = torch.ops.aten.div.Scalar(expand_81, 1536); expand_81 = None | |
| + pow_121 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_646, 1.0); convert_element_type_646 = None | |
| + mul_916 = torch.ops.aten.mul.Scalar(pow_121, 2.0); pow_121 = None | |
| + mul_917 = torch.ops.aten.mul.Tensor(div_108, mul_916); div_108 = mul_916 = None | |
| + add_420 = torch.ops.aten.add.Tensor(mul_913, mul_917); mul_913 = mul_917 = None | |
| + convert_element_type_1127 = torch.ops.prims.convert_element_type.default(add_420, torch.bfloat16); add_420 = None | |
| + index_put_31 = torch.ops.aten.index_put.default(full_default_281, [sub_32], convert_element_type_1125, True) | |
| full_default_300 = torch.ops.aten.full.default([4096, 3072], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| slice_scatter_24 = torch.ops.aten.slice_scatter.default(full_default_300, index_put_31, 1, 1536, 9223372036854775807); index_put_31 = None | |
| - permute_670 = torch.ops.aten.permute.default(convert_element_type_1069, [1, 0]) | |
| - convert_element_type_643 = torch.ops.prims.convert_element_type.default(mul_225, torch.bfloat16); mul_225 = None | |
| - mm_265 = torch.ops.aten.mm.default(permute_670, convert_element_type_643); permute_670 = convert_element_type_643 = None | |
| - convert_element_type_644 = torch.ops.prims.convert_element_type.default(primals_376, torch.bfloat16); primals_376 = None | |
| - permute_224 = torch.ops.aten.permute.default(convert_element_type_644, [1, 0]); convert_element_type_644 = None | |
| + permute_670 = torch.ops.aten.permute.default(convert_element_type_1125, [1, 0]) | |
| + mm_265 = torch.ops.aten.mm.default(permute_670, mul_225); permute_670 = mul_225 = None | |
| + convert_element_type_643 = torch.ops.prims.convert_element_type.default(primals_376, torch.bfloat16); primals_376 = None | |
| + permute_224 = torch.ops.aten.permute.default(convert_element_type_643, [1, 0]); convert_element_type_643 = None | |
| permute_672 = torch.ops.aten.permute.default(permute_224, [1, 0]); permute_224 = None | |
| - mm_266 = torch.ops.aten.mm.default(convert_element_type_1069, permute_672); convert_element_type_1069 = permute_672 = None | |
| - convert_element_type_1075 = torch.ops.prims.convert_element_type.default(mm_265, torch.float32); mm_265 = None | |
| - convert_element_type_1076 = torch.ops.prims.convert_element_type.default(mm_266, torch.float32); mm_266 = None | |
| - add_402 = torch.ops.aten.add.Tensor(mul_852, convert_element_type_1076); mul_852 = convert_element_type_1076 = None | |
| - slice_scatter_25 = torch.ops.aten.slice_scatter.default(full_default_300, convert_element_type_1070, 1, 0, 1536); convert_element_type_1070 = None | |
| - add_403 = torch.ops.aten.add.Tensor(slice_scatter_24, slice_scatter_25); slice_scatter_24 = slice_scatter_25 = None | |
| - permute_674 = torch.ops.aten.permute.default(add_403, [1, 0]) | |
| - convert_element_type_639 = torch.ops.prims.convert_element_type.default(mul_224, torch.bfloat16); mul_224 = None | |
| - mm_267 = torch.ops.aten.mm.default(permute_674, convert_element_type_639); permute_674 = convert_element_type_639 = None | |
| + mm_266 = torch.ops.aten.mm.default(convert_element_type_1125, permute_672); convert_element_type_1125 = permute_672 = None | |
| + add_421 = torch.ops.aten.add.Tensor(convert_element_type_1105, mm_266); convert_element_type_1105 = mm_266 = None | |
| + convert_element_type_1132 = torch.ops.prims.convert_element_type.default(mm_265, torch.float32); mm_265 = None | |
| + slice_scatter_25 = torch.ops.aten.slice_scatter.default(full_default_300, convert_element_type_1127, 1, 0, 1536); convert_element_type_1127 = None | |
| + add_422 = torch.ops.aten.add.Tensor(slice_scatter_24, slice_scatter_25); slice_scatter_24 = slice_scatter_25 = None | |
| + permute_674 = torch.ops.aten.permute.default(add_422, [1, 0]) | |
| + mm_267 = torch.ops.aten.mm.default(permute_674, mul_224); permute_674 = mul_224 = None | |
| convert_element_type_640 = torch.ops.prims.convert_element_type.default(primals_375, torch.bfloat16); primals_375 = None | |
| permute_223 = torch.ops.aten.permute.default(convert_element_type_640, [1, 0]); convert_element_type_640 = None | |
| permute_676 = torch.ops.aten.permute.default(permute_223, [1, 0]); permute_223 = None | |
| - mm_268 = torch.ops.aten.mm.default(add_403, permute_676); add_403 = permute_676 = None | |
| - convert_element_type_1081 = torch.ops.prims.convert_element_type.default(mm_267, torch.float32); mm_267 = None | |
| - convert_element_type_1082 = torch.ops.prims.convert_element_type.default(mm_268, torch.float32); mm_268 = None | |
| - add_404 = torch.ops.aten.add.Tensor(mul_858, convert_element_type_1082); mul_858 = convert_element_type_1082 = None | |
| - mul_880 = torch.ops.aten.mul.Tensor(add_402, mul_223); mul_223 = None | |
| - mul_881 = torch.ops.aten.mul.Tensor(add_402, sigmoid_30); add_402 = None | |
| - sub_228 = torch.ops.aten.sub.Tensor(1, sigmoid_30) | |
| - mul_882 = torch.ops.aten.mul.Tensor(sigmoid_30, sub_228); sigmoid_30 = sub_228 = None | |
| - mul_883 = torch.ops.aten.mul.Tensor(mul_880, mul_882); mul_880 = mul_882 = None | |
| - add_405 = torch.ops.aten.add.Tensor(mul_881, mul_883); mul_881 = mul_883 = None | |
| - mul_884 = torch.ops.aten.mul.Tensor(add_404, mul_221); mul_221 = None | |
| - mul_885 = torch.ops.aten.mul.Tensor(add_404, sigmoid_29); add_404 = None | |
| - sub_229 = torch.ops.aten.sub.Tensor(1, sigmoid_29) | |
| - mul_886 = torch.ops.aten.mul.Tensor(sigmoid_29, sub_229); sigmoid_29 = sub_229 = None | |
| - mul_887 = torch.ops.aten.mul.Tensor(mul_884, mul_886); mul_884 = mul_886 = None | |
| - add_406 = torch.ops.aten.add.Tensor(mul_885, mul_887); mul_885 = mul_887 = None | |
| - mul_888 = torch.ops.aten.mul.Tensor(add_405, primals_374); primals_374 = None | |
| - mul_890 = torch.ops.aten.mul.Tensor(mul_222, mul_888) | |
| - sum_138 = torch.ops.aten.sum.dim_IntList(mul_890, [1], True); mul_890 = None | |
| - div_109 = torch.ops.aten.div.Tensor(mul_222, 4608) | |
| - mul_891 = torch.ops.aten.mul.Tensor(div_109, sum_138); div_109 = sum_138 = None | |
| - sub_230 = torch.ops.aten.sub.Tensor(mul_888, mul_891); mul_888 = mul_891 = None | |
| - mul_892 = torch.ops.aten.mul.Tensor(sub_230, rsqrt_73); sub_230 = rsqrt_73 = None | |
| - mul_893 = torch.ops.aten.mul.Tensor(add_405, mul_222); add_405 = mul_222 = None | |
| - sum_139 = torch.ops.aten.sum.dim_IntList(mul_893, [0]); mul_893 = None | |
| - mul_894 = torch.ops.aten.mul.Tensor(add_406, primals_373); primals_373 = None | |
| - mul_896 = torch.ops.aten.mul.Tensor(mul_220, mul_894) | |
| - sum_140 = torch.ops.aten.sum.dim_IntList(mul_896, [1], True); mul_896 = None | |
| - div_110 = torch.ops.aten.div.Tensor(mul_220, 4608) | |
| - mul_897 = torch.ops.aten.mul.Tensor(div_110, sum_140); div_110 = sum_140 = None | |
| - sub_231 = torch.ops.aten.sub.Tensor(mul_894, mul_897); mul_894 = mul_897 = None | |
| - mul_898 = torch.ops.aten.mul.Tensor(sub_231, rsqrt_72); sub_231 = rsqrt_72 = None | |
| - mul_899 = torch.ops.aten.mul.Tensor(add_406, mul_220); add_406 = mul_220 = None | |
| - sum_141 = torch.ops.aten.sum.dim_IntList(mul_899, [0]); mul_899 = None | |
| - convert_element_type_1083 = torch.ops.prims.convert_element_type.default(mul_892, torch.bfloat16) | |
| - convert_element_type_1084 = torch.ops.prims.convert_element_type.default(mul_898, torch.bfloat16) | |
| - index_put_32 = torch.ops.aten.index_put.default(full_default_296, [sub_32], convert_element_type_1083, True) | |
| + mm_268 = torch.ops.aten.mm.default(add_422, permute_676); add_422 = permute_676 = None | |
| + add_423 = torch.ops.aten.add.Tensor(convert_element_type_1107, mm_268); convert_element_type_1107 = mm_268 = None | |
| + convert_element_type_1137 = torch.ops.prims.convert_element_type.default(mm_267, torch.float32); mm_267 = None | |
| + mul_918 = torch.ops.aten.mul.Tensor(add_421, convert_element_type_639); convert_element_type_639 = None | |
| + mul_919 = torch.ops.aten.mul.Tensor(add_421, sigmoid_30); add_421 = None | |
| + convert_element_type_1138 = torch.ops.prims.convert_element_type.default(mul_918, torch.float32); mul_918 = None | |
| + convert_element_type_1139 = torch.ops.prims.convert_element_type.default(sigmoid_30, torch.float32); sigmoid_30 = None | |
| + sub_209 = torch.ops.aten.sub.Tensor(1, convert_element_type_1139) | |
| + mul_920 = torch.ops.aten.mul.Tensor(convert_element_type_1139, sub_209); convert_element_type_1139 = sub_209 = None | |
| + mul_921 = torch.ops.aten.mul.Tensor(convert_element_type_1138, mul_920); convert_element_type_1138 = mul_920 = None | |
| + convert_element_type_1140 = torch.ops.prims.convert_element_type.default(mul_921, torch.bfloat16); mul_921 = None | |
| + add_424 = torch.ops.aten.add.Tensor(mul_919, convert_element_type_1140); mul_919 = convert_element_type_1140 = None | |
| + mul_922 = torch.ops.aten.mul.Tensor(add_423, convert_element_type_637); convert_element_type_637 = None | |
| + mul_923 = torch.ops.aten.mul.Tensor(add_423, sigmoid_29); add_423 = None | |
| + convert_element_type_1141 = torch.ops.prims.convert_element_type.default(mul_922, torch.float32); mul_922 = None | |
| + convert_element_type_1142 = torch.ops.prims.convert_element_type.default(sigmoid_29, torch.float32); sigmoid_29 = None | |
| + sub_210 = torch.ops.aten.sub.Tensor(1, convert_element_type_1142) | |
| + mul_924 = torch.ops.aten.mul.Tensor(convert_element_type_1142, sub_210); convert_element_type_1142 = sub_210 = None | |
| + mul_925 = torch.ops.aten.mul.Tensor(convert_element_type_1141, mul_924); convert_element_type_1141 = mul_924 = None | |
| + convert_element_type_1143 = torch.ops.prims.convert_element_type.default(mul_925, torch.bfloat16); mul_925 = None | |
| + add_425 = torch.ops.aten.add.Tensor(mul_923, convert_element_type_1143); mul_923 = convert_element_type_1143 = None | |
| + convert_element_type_1144 = torch.ops.prims.convert_element_type.default(add_424, torch.float32); add_424 = None | |
| + mul_926 = torch.ops.aten.mul.Tensor(convert_element_type_1144, mul_222); mul_222 = None | |
| + mul_927 = torch.ops.aten.mul.Tensor(convert_element_type_1144, primals_374); convert_element_type_1144 = primals_374 = None | |
| + sum_138 = torch.ops.aten.sum.dim_IntList(mul_926, [0], True); mul_926 = None | |
| + view_488 = torch.ops.aten.view.default(sum_138, [4608]); sum_138 = None | |
| + mul_928 = torch.ops.aten.mul.Tensor(mul_927, convert_element_type_638) | |
| + mul_929 = torch.ops.aten.mul.Tensor(mul_927, rsqrt_73); mul_927 = None | |
| + sum_139 = torch.ops.aten.sum.dim_IntList(mul_928, [1], True); mul_928 = None | |
| + mul_930 = torch.ops.aten.mul.Scalar(sum_139, -0.5); sum_139 = None | |
| + pow_122 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_73, 3); rsqrt_73 = None | |
| + mul_931 = torch.ops.aten.mul.Tensor(mul_930, pow_122); mul_930 = pow_122 = None | |
| + expand_82 = torch.ops.aten.expand.default(mul_931, [4096, 4608]); mul_931 = None | |
| + div_109 = torch.ops.aten.div.Scalar(expand_82, 4608); expand_82 = None | |
| + pow_123 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_638, 1.0); convert_element_type_638 = None | |
| + mul_932 = torch.ops.aten.mul.Scalar(pow_123, 2.0); pow_123 = None | |
| + mul_933 = torch.ops.aten.mul.Tensor(div_109, mul_932); div_109 = mul_932 = None | |
| + add_426 = torch.ops.aten.add.Tensor(mul_929, mul_933); mul_929 = mul_933 = None | |
| + convert_element_type_1145 = torch.ops.prims.convert_element_type.default(add_426, torch.bfloat16); add_426 = None | |
| + convert_element_type_1146 = torch.ops.prims.convert_element_type.default(add_425, torch.float32); add_425 = None | |
| + mul_934 = torch.ops.aten.mul.Tensor(convert_element_type_1146, mul_220); mul_220 = None | |
| + mul_935 = torch.ops.aten.mul.Tensor(convert_element_type_1146, primals_373); convert_element_type_1146 = primals_373 = None | |
| + sum_140 = torch.ops.aten.sum.dim_IntList(mul_934, [0], True); mul_934 = None | |
| + view_489 = torch.ops.aten.view.default(sum_140, [4608]); sum_140 = None | |
| + mul_936 = torch.ops.aten.mul.Tensor(mul_935, convert_element_type_636) | |
| + mul_937 = torch.ops.aten.mul.Tensor(mul_935, rsqrt_72); mul_935 = None | |
| + sum_141 = torch.ops.aten.sum.dim_IntList(mul_936, [1], True); mul_936 = None | |
| + mul_938 = torch.ops.aten.mul.Scalar(sum_141, -0.5); sum_141 = None | |
| + pow_124 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_72, 3); rsqrt_72 = None | |
| + mul_939 = torch.ops.aten.mul.Tensor(mul_938, pow_124); mul_938 = pow_124 = None | |
| + expand_83 = torch.ops.aten.expand.default(mul_939, [4096, 4608]); mul_939 = None | |
| + div_110 = torch.ops.aten.div.Scalar(expand_83, 4608); expand_83 = None | |
| + pow_125 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_636, 1.0); convert_element_type_636 = None | |
| + mul_940 = torch.ops.aten.mul.Scalar(pow_125, 2.0); pow_125 = None | |
| + mul_941 = torch.ops.aten.mul.Tensor(div_110, mul_940); div_110 = mul_940 = None | |
| + add_427 = torch.ops.aten.add.Tensor(mul_937, mul_941); mul_937 = mul_941 = None | |
| + convert_element_type_1147 = torch.ops.prims.convert_element_type.default(add_427, torch.bfloat16); add_427 = None | |
| + index_put_32 = torch.ops.aten.index_put.default(full_default_296, [sub_32], convert_element_type_1145, True) | |
| slice_scatter_26 = torch.ops.aten.slice_scatter.default(full_default_297, index_put_32, 1, 4608, 9223372036854775807); index_put_32 = None | |
| - permute_678 = torch.ops.aten.permute.default(convert_element_type_1083, [1, 0]) | |
| + permute_678 = torch.ops.aten.permute.default(convert_element_type_1145, [1, 0]) | |
| slice_49 = torch.ops.aten.slice.Tensor(mm_81, 1, 2304, 9223372036854775807) | |
| index_29 = torch.ops.aten.index.Tensor(slice_49, [sub_32]); slice_49 = None | |
| add_157 = torch.ops.aten.add.Tensor(mm_82, index_29); mm_82 = index_29 = None | |
| - convert_element_type_630 = torch.ops.prims.convert_element_type.default(add_157, torch.float32); add_157 = None | |
| - pow_51 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_630, 2) | |
| + convert_element_type_628 = torch.ops.prims.convert_element_type.default(add_157, torch.float32); add_157 = None | |
| + pow_51 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_628, 2) | |
| mean_50 = torch.ops.aten.mean.dim(pow_51, [1], True); pow_51 = None | |
| add_159 = torch.ops.aten.add.Scalar(mean_50, 1.1920928955078125e-07); mean_50 = None | |
| rsqrt_71 = torch.ops.aten.rsqrt.default(add_159); add_159 = None | |
| - mul_216 = torch.ops.aten.mul.Tensor(convert_element_type_630, rsqrt_71); convert_element_type_630 = None | |
| + mul_216 = torch.ops.aten.mul.Tensor(convert_element_type_628, rsqrt_71) | |
| mul_217 = torch.ops.aten.mul.Tensor(mul_216, primals_370) | |
| - sigmoid_28 = torch.ops.aten.sigmoid.default(mul_217) | |
| - mul_219 = torch.ops.aten.mul.Tensor(mul_217, sigmoid_28) | |
| - convert_element_type_635 = torch.ops.prims.convert_element_type.default(mul_219, torch.bfloat16); mul_219 = None | |
| - mm_269 = torch.ops.aten.mm.default(permute_678, convert_element_type_635); permute_678 = convert_element_type_635 = None | |
| - convert_element_type_636 = torch.ops.prims.convert_element_type.default(primals_372, torch.bfloat16); primals_372 = None | |
| - permute_222 = torch.ops.aten.permute.default(convert_element_type_636, [1, 0]); convert_element_type_636 = None | |
| + convert_element_type_629 = torch.ops.prims.convert_element_type.default(mul_217, torch.bfloat16); mul_217 = None | |
| + sigmoid_28 = torch.ops.aten.sigmoid.default(convert_element_type_629) | |
| + mul_219 = torch.ops.aten.mul.Tensor(convert_element_type_629, sigmoid_28) | |
| + mm_269 = torch.ops.aten.mm.default(permute_678, mul_219); permute_678 = mul_219 = None | |
| + convert_element_type_633 = torch.ops.prims.convert_element_type.default(primals_372, torch.bfloat16); primals_372 = None | |
| + permute_222 = torch.ops.aten.permute.default(convert_element_type_633, [1, 0]); convert_element_type_633 = None | |
| permute_680 = torch.ops.aten.permute.default(permute_222, [1, 0]); permute_222 = None | |
| - mm_270 = torch.ops.aten.mm.default(convert_element_type_1083, permute_680); convert_element_type_1083 = permute_680 = None | |
| - convert_element_type_1089 = torch.ops.prims.convert_element_type.default(mm_269, torch.float32); mm_269 = None | |
| - convert_element_type_1090 = torch.ops.prims.convert_element_type.default(mm_270, torch.float32); mm_270 = None | |
| - slice_scatter_27 = torch.ops.aten.slice_scatter.default(full_default_297, convert_element_type_1084, 1, 0, 4608); convert_element_type_1084 = None | |
| - add_407 = torch.ops.aten.add.Tensor(slice_scatter_26, slice_scatter_27); slice_scatter_26 = slice_scatter_27 = None | |
| - permute_682 = torch.ops.aten.permute.default(add_407, [1, 0]) | |
| + mm_270 = torch.ops.aten.mm.default(convert_element_type_1145, permute_680); permute_680 = None | |
| + convert_element_type_1152 = torch.ops.prims.convert_element_type.default(mm_269, torch.float32); mm_269 = None | |
| + slice_scatter_27 = torch.ops.aten.slice_scatter.default(full_default_297, convert_element_type_1147, 1, 0, 4608) | |
| + add_428 = torch.ops.aten.add.Tensor(slice_scatter_26, slice_scatter_27); slice_scatter_26 = slice_scatter_27 = None | |
| + permute_682 = torch.ops.aten.permute.default(add_428, [1, 0]) | |
| slice_48 = torch.ops.aten.slice.Tensor(mm_81, 1, 0, 2304); mm_81 = None | |
| - convert_element_type_629 = torch.ops.prims.convert_element_type.default(slice_48, torch.float32); slice_48 = None | |
| - pow_50 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_629, 2) | |
| + convert_element_type_626 = torch.ops.prims.convert_element_type.default(slice_48, torch.float32); slice_48 = None | |
| + pow_50 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_626, 2) | |
| mean_49 = torch.ops.aten.mean.dim(pow_50, [1], True); pow_50 = None | |
| add_158 = torch.ops.aten.add.Scalar(mean_49, 1.1920928955078125e-07); mean_49 = None | |
| rsqrt_70 = torch.ops.aten.rsqrt.default(add_158); add_158 = None | |
| - mul_214 = torch.ops.aten.mul.Tensor(convert_element_type_629, rsqrt_70); convert_element_type_629 = None | |
| + mul_214 = torch.ops.aten.mul.Tensor(convert_element_type_626, rsqrt_70) | |
| mul_215 = torch.ops.aten.mul.Tensor(mul_214, primals_369) | |
| - sigmoid_27 = torch.ops.aten.sigmoid.default(mul_215) | |
| - mul_218 = torch.ops.aten.mul.Tensor(mul_215, sigmoid_27) | |
| - convert_element_type_631 = torch.ops.prims.convert_element_type.default(mul_218, torch.bfloat16); mul_218 = None | |
| - mm_271 = torch.ops.aten.mm.default(permute_682, convert_element_type_631); permute_682 = convert_element_type_631 = None | |
| - convert_element_type_632 = torch.ops.prims.convert_element_type.default(primals_371, torch.bfloat16); primals_371 = None | |
| - permute_221 = torch.ops.aten.permute.default(convert_element_type_632, [1, 0]); convert_element_type_632 = None | |
| + convert_element_type_627 = torch.ops.prims.convert_element_type.default(mul_215, torch.bfloat16); mul_215 = None | |
| + sigmoid_27 = torch.ops.aten.sigmoid.default(convert_element_type_627) | |
| + mul_218 = torch.ops.aten.mul.Tensor(convert_element_type_627, sigmoid_27) | |
| + mm_271 = torch.ops.aten.mm.default(permute_682, mul_218); permute_682 = mul_218 = None | |
| + convert_element_type_630 = torch.ops.prims.convert_element_type.default(primals_371, torch.bfloat16); primals_371 = None | |
| + permute_221 = torch.ops.aten.permute.default(convert_element_type_630, [1, 0]); convert_element_type_630 = None | |
| permute_684 = torch.ops.aten.permute.default(permute_221, [1, 0]); permute_221 = None | |
| - mm_272 = torch.ops.aten.mm.default(add_407, permute_684); add_407 = permute_684 = None | |
| - convert_element_type_1095 = torch.ops.prims.convert_element_type.default(mm_271, torch.float32); mm_271 = None | |
| - convert_element_type_1096 = torch.ops.prims.convert_element_type.default(mm_272, torch.float32); mm_272 = None | |
| - mul_900 = torch.ops.aten.mul.Tensor(convert_element_type_1090, mul_217); mul_217 = None | |
| - mul_901 = torch.ops.aten.mul.Tensor(convert_element_type_1090, sigmoid_28); convert_element_type_1090 = None | |
| - sub_232 = torch.ops.aten.sub.Tensor(1, sigmoid_28) | |
| - mul_902 = torch.ops.aten.mul.Tensor(sigmoid_28, sub_232); sigmoid_28 = sub_232 = None | |
| - mul_903 = torch.ops.aten.mul.Tensor(mul_900, mul_902); mul_900 = mul_902 = None | |
| - add_408 = torch.ops.aten.add.Tensor(mul_901, mul_903); mul_901 = mul_903 = None | |
| - mul_904 = torch.ops.aten.mul.Tensor(convert_element_type_1096, mul_215); mul_215 = None | |
| - mul_905 = torch.ops.aten.mul.Tensor(convert_element_type_1096, sigmoid_27); convert_element_type_1096 = None | |
| - sub_233 = torch.ops.aten.sub.Tensor(1, sigmoid_27) | |
| - mul_906 = torch.ops.aten.mul.Tensor(sigmoid_27, sub_233); sigmoid_27 = sub_233 = None | |
| - mul_907 = torch.ops.aten.mul.Tensor(mul_904, mul_906); mul_904 = mul_906 = None | |
| - add_409 = torch.ops.aten.add.Tensor(mul_905, mul_907); mul_905 = mul_907 = None | |
| - mul_908 = torch.ops.aten.mul.Tensor(add_408, primals_370); primals_370 = None | |
| - mul_910 = torch.ops.aten.mul.Tensor(mul_216, mul_908) | |
| - sum_142 = torch.ops.aten.sum.dim_IntList(mul_910, [1], True); mul_910 = None | |
| - div_111 = torch.ops.aten.div.Tensor(mul_216, 2304) | |
| - mul_911 = torch.ops.aten.mul.Tensor(div_111, sum_142); div_111 = sum_142 = None | |
| - sub_234 = torch.ops.aten.sub.Tensor(mul_908, mul_911); mul_908 = mul_911 = None | |
| - mul_912 = torch.ops.aten.mul.Tensor(sub_234, rsqrt_71); sub_234 = rsqrt_71 = None | |
| - mul_913 = torch.ops.aten.mul.Tensor(add_408, mul_216); add_408 = mul_216 = None | |
| - sum_143 = torch.ops.aten.sum.dim_IntList(mul_913, [0]); mul_913 = None | |
| - convert_element_type_1097 = torch.ops.prims.convert_element_type.default(mul_912, torch.bfloat16); mul_912 = None | |
| - mul_914 = torch.ops.aten.mul.Tensor(add_409, primals_369); primals_369 = None | |
| - mul_916 = torch.ops.aten.mul.Tensor(mul_214, mul_914) | |
| - sum_144 = torch.ops.aten.sum.dim_IntList(mul_916, [1], True); mul_916 = None | |
| - div_112 = torch.ops.aten.div.Tensor(mul_214, 2304) | |
| - mul_917 = torch.ops.aten.mul.Tensor(div_112, sum_144); div_112 = sum_144 = None | |
| - sub_235 = torch.ops.aten.sub.Tensor(mul_914, mul_917); mul_914 = mul_917 = None | |
| - mul_918 = torch.ops.aten.mul.Tensor(sub_235, rsqrt_70); sub_235 = rsqrt_70 = None | |
| - mul_919 = torch.ops.aten.mul.Tensor(add_409, mul_214); add_409 = mul_214 = None | |
| - sum_145 = torch.ops.aten.sum.dim_IntList(mul_919, [0]); mul_919 = None | |
| - convert_element_type_1098 = torch.ops.prims.convert_element_type.default(mul_918, torch.bfloat16); mul_918 = None | |
| + mm_272 = torch.ops.aten.mm.default(add_428, permute_684); add_428 = permute_684 = None | |
| + convert_element_type_1157 = torch.ops.prims.convert_element_type.default(mm_271, torch.float32); mm_271 = None | |
| + mul_942 = torch.ops.aten.mul.Tensor(mm_270, convert_element_type_629); convert_element_type_629 = None | |
| + mul_943 = torch.ops.aten.mul.Tensor(mm_270, sigmoid_28); mm_270 = None | |
| + convert_element_type_1158 = torch.ops.prims.convert_element_type.default(mul_942, torch.float32); mul_942 = None | |
| + convert_element_type_1159 = torch.ops.prims.convert_element_type.default(sigmoid_28, torch.float32); sigmoid_28 = None | |
| + sub_211 = torch.ops.aten.sub.Tensor(1, convert_element_type_1159) | |
| + mul_944 = torch.ops.aten.mul.Tensor(convert_element_type_1159, sub_211); convert_element_type_1159 = sub_211 = None | |
| + mul_945 = torch.ops.aten.mul.Tensor(convert_element_type_1158, mul_944); convert_element_type_1158 = mul_944 = None | |
| + convert_element_type_1160 = torch.ops.prims.convert_element_type.default(mul_945, torch.bfloat16); mul_945 = None | |
| + add_429 = torch.ops.aten.add.Tensor(mul_943, convert_element_type_1160); mul_943 = convert_element_type_1160 = None | |
| + mul_946 = torch.ops.aten.mul.Tensor(mm_272, convert_element_type_627); convert_element_type_627 = None | |
| + mul_947 = torch.ops.aten.mul.Tensor(mm_272, sigmoid_27); mm_272 = None | |
| + convert_element_type_1161 = torch.ops.prims.convert_element_type.default(mul_946, torch.float32); mul_946 = None | |
| + convert_element_type_1162 = torch.ops.prims.convert_element_type.default(sigmoid_27, torch.float32); sigmoid_27 = None | |
| + sub_212 = torch.ops.aten.sub.Tensor(1, convert_element_type_1162) | |
| + mul_948 = torch.ops.aten.mul.Tensor(convert_element_type_1162, sub_212); convert_element_type_1162 = sub_212 = None | |
| + mul_949 = torch.ops.aten.mul.Tensor(convert_element_type_1161, mul_948); convert_element_type_1161 = mul_948 = None | |
| + convert_element_type_1163 = torch.ops.prims.convert_element_type.default(mul_949, torch.bfloat16); mul_949 = None | |
| + add_430 = torch.ops.aten.add.Tensor(mul_947, convert_element_type_1163); mul_947 = convert_element_type_1163 = None | |
| + convert_element_type_1164 = torch.ops.prims.convert_element_type.default(add_429, torch.float32); add_429 = None | |
| + mul_950 = torch.ops.aten.mul.Tensor(convert_element_type_1164, mul_216); mul_216 = None | |
| + mul_951 = torch.ops.aten.mul.Tensor(convert_element_type_1164, primals_370); convert_element_type_1164 = primals_370 = None | |
| + sum_142 = torch.ops.aten.sum.dim_IntList(mul_950, [0], True); mul_950 = None | |
| + view_490 = torch.ops.aten.view.default(sum_142, [2304]); sum_142 = None | |
| + mul_952 = torch.ops.aten.mul.Tensor(mul_951, convert_element_type_628) | |
| + mul_953 = torch.ops.aten.mul.Tensor(mul_951, rsqrt_71); mul_951 = None | |
| + sum_143 = torch.ops.aten.sum.dim_IntList(mul_952, [1], True); mul_952 = None | |
| + mul_954 = torch.ops.aten.mul.Scalar(sum_143, -0.5); sum_143 = None | |
| + pow_126 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_71, 3); rsqrt_71 = None | |
| + mul_955 = torch.ops.aten.mul.Tensor(mul_954, pow_126); mul_954 = pow_126 = None | |
| + expand_84 = torch.ops.aten.expand.default(mul_955, [4096, 2304]); mul_955 = None | |
| + div_111 = torch.ops.aten.div.Scalar(expand_84, 2304); expand_84 = None | |
| + pow_127 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_628, 1.0); convert_element_type_628 = None | |
| + mul_956 = torch.ops.aten.mul.Scalar(pow_127, 2.0); pow_127 = None | |
| + mul_957 = torch.ops.aten.mul.Tensor(div_111, mul_956); div_111 = mul_956 = None | |
| + add_431 = torch.ops.aten.add.Tensor(mul_953, mul_957); mul_953 = mul_957 = None | |
| + convert_element_type_1165 = torch.ops.prims.convert_element_type.default(add_431, torch.bfloat16); add_431 = None | |
| + convert_element_type_1166 = torch.ops.prims.convert_element_type.default(add_430, torch.float32); add_430 = None | |
| + mul_958 = torch.ops.aten.mul.Tensor(convert_element_type_1166, mul_214); mul_214 = None | |
| + mul_959 = torch.ops.aten.mul.Tensor(convert_element_type_1166, primals_369); convert_element_type_1166 = primals_369 = None | |
| + sum_144 = torch.ops.aten.sum.dim_IntList(mul_958, [0], True); mul_958 = None | |
| + view_491 = torch.ops.aten.view.default(sum_144, [2304]); sum_144 = None | |
| + mul_960 = torch.ops.aten.mul.Tensor(mul_959, convert_element_type_626) | |
| + mul_961 = torch.ops.aten.mul.Tensor(mul_959, rsqrt_70); mul_959 = None | |
| + sum_145 = torch.ops.aten.sum.dim_IntList(mul_960, [1], True); mul_960 = None | |
| + mul_962 = torch.ops.aten.mul.Scalar(sum_145, -0.5); sum_145 = None | |
| + pow_128 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_70, 3); rsqrt_70 = None | |
| + mul_963 = torch.ops.aten.mul.Tensor(mul_962, pow_128); mul_962 = pow_128 = None | |
| + expand_85 = torch.ops.aten.expand.default(mul_963, [4096, 2304]); mul_963 = None | |
| + div_112 = torch.ops.aten.div.Scalar(expand_85, 2304); expand_85 = None | |
| + pow_129 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_626, 1.0); convert_element_type_626 = None | |
| + mul_964 = torch.ops.aten.mul.Scalar(pow_129, 2.0); pow_129 = None | |
| + mul_965 = torch.ops.aten.mul.Tensor(div_112, mul_964); div_112 = mul_964 = None | |
| + add_432 = torch.ops.aten.add.Tensor(mul_961, mul_965); mul_961 = mul_965 = None | |
| + convert_element_type_1167 = torch.ops.prims.convert_element_type.default(add_432, torch.bfloat16); add_432 = None | |
| full_default_305 = torch.ops.aten.full.default([4096, 2304], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| - index_put_33 = torch.ops.aten.index_put.default(full_default_305, [sub_32], convert_element_type_1097, True) | |
| + index_put_33 = torch.ops.aten.index_put.default(full_default_305, [sub_32], convert_element_type_1165, True) | |
| slice_scatter_28 = torch.ops.aten.slice_scatter.default(full_default_296, index_put_33, 1, 2304, 9223372036854775807); index_put_33 = None | |
| - permute_686 = torch.ops.aten.permute.default(convert_element_type_1097, [1, 0]) | |
| - convert_element_type_625 = torch.ops.prims.convert_element_type.default(mul_213, torch.bfloat16); mul_213 = None | |
| - mm_273 = torch.ops.aten.mm.default(permute_686, convert_element_type_625); permute_686 = convert_element_type_625 = None | |
| - convert_element_type_626 = torch.ops.prims.convert_element_type.default(primals_368, torch.bfloat16); primals_368 = None | |
| - permute_220 = torch.ops.aten.permute.default(convert_element_type_626, [1, 0]); convert_element_type_626 = None | |
| + permute_686 = torch.ops.aten.permute.default(convert_element_type_1165, [1, 0]) | |
| + mm_273 = torch.ops.aten.mm.default(permute_686, mul_213); permute_686 = mul_213 = None | |
| + convert_element_type_623 = torch.ops.prims.convert_element_type.default(primals_368, torch.bfloat16); primals_368 = None | |
| + permute_220 = torch.ops.aten.permute.default(convert_element_type_623, [1, 0]); convert_element_type_623 = None | |
| permute_688 = torch.ops.aten.permute.default(permute_220, [1, 0]); permute_220 = None | |
| - mm_274 = torch.ops.aten.mm.default(convert_element_type_1097, permute_688); convert_element_type_1097 = permute_688 = None | |
| - convert_element_type_1103 = torch.ops.prims.convert_element_type.default(mm_273, torch.float32); mm_273 = None | |
| - convert_element_type_1104 = torch.ops.prims.convert_element_type.default(mm_274, torch.float32); mm_274 = None | |
| - add_410 = torch.ops.aten.add.Tensor(mul_892, convert_element_type_1104); mul_892 = convert_element_type_1104 = None | |
| - slice_scatter_29 = torch.ops.aten.slice_scatter.default(full_default_296, convert_element_type_1098, 1, 0, 2304); convert_element_type_1098 = None | |
| - add_411 = torch.ops.aten.add.Tensor(slice_scatter_28, slice_scatter_29); slice_scatter_28 = slice_scatter_29 = None | |
| - permute_690 = torch.ops.aten.permute.default(add_411, [1, 0]) | |
| - convert_element_type_621 = torch.ops.prims.convert_element_type.default(mul_212, torch.bfloat16); mul_212 = None | |
| - mm_275 = torch.ops.aten.mm.default(permute_690, convert_element_type_621); permute_690 = convert_element_type_621 = None | |
| - convert_element_type_622 = torch.ops.prims.convert_element_type.default(primals_367, torch.bfloat16); primals_367 = None | |
| - permute_219 = torch.ops.aten.permute.default(convert_element_type_622, [1, 0]); convert_element_type_622 = None | |
| + mm_274 = torch.ops.aten.mm.default(convert_element_type_1165, permute_688); convert_element_type_1165 = permute_688 = None | |
| + add_433 = torch.ops.aten.add.Tensor(convert_element_type_1145, mm_274); convert_element_type_1145 = mm_274 = None | |
| + convert_element_type_1172 = torch.ops.prims.convert_element_type.default(mm_273, torch.float32); mm_273 = None | |
| + slice_scatter_29 = torch.ops.aten.slice_scatter.default(full_default_296, convert_element_type_1167, 1, 0, 2304); convert_element_type_1167 = None | |
| + add_434 = torch.ops.aten.add.Tensor(slice_scatter_28, slice_scatter_29); slice_scatter_28 = slice_scatter_29 = None | |
| + permute_690 = torch.ops.aten.permute.default(add_434, [1, 0]) | |
| + mm_275 = torch.ops.aten.mm.default(permute_690, mul_212); permute_690 = mul_212 = None | |
| + convert_element_type_620 = torch.ops.prims.convert_element_type.default(primals_367, torch.bfloat16); primals_367 = None | |
| + permute_219 = torch.ops.aten.permute.default(convert_element_type_620, [1, 0]); convert_element_type_620 = None | |
| permute_692 = torch.ops.aten.permute.default(permute_219, [1, 0]); permute_219 = None | |
| - mm_276 = torch.ops.aten.mm.default(add_411, permute_692); add_411 = permute_692 = None | |
| - convert_element_type_1109 = torch.ops.prims.convert_element_type.default(mm_275, torch.float32); mm_275 = None | |
| - convert_element_type_1110 = torch.ops.prims.convert_element_type.default(mm_276, torch.float32); mm_276 = None | |
| - add_412 = torch.ops.aten.add.Tensor(mul_898, convert_element_type_1110); mul_898 = convert_element_type_1110 = None | |
| - mul_920 = torch.ops.aten.mul.Tensor(add_410, mul_211); mul_211 = None | |
| - mul_921 = torch.ops.aten.mul.Tensor(add_410, sigmoid_26); add_410 = None | |
| - sub_236 = torch.ops.aten.sub.Tensor(1, sigmoid_26) | |
| - mul_922 = torch.ops.aten.mul.Tensor(sigmoid_26, sub_236); sigmoid_26 = sub_236 = None | |
| - mul_923 = torch.ops.aten.mul.Tensor(mul_920, mul_922); mul_920 = mul_922 = None | |
| - add_413 = torch.ops.aten.add.Tensor(mul_921, mul_923); mul_921 = mul_923 = None | |
| - mul_924 = torch.ops.aten.mul.Tensor(add_412, mul_209); mul_209 = None | |
| - mul_925 = torch.ops.aten.mul.Tensor(add_412, sigmoid_25); add_412 = None | |
| - sub_237 = torch.ops.aten.sub.Tensor(1, sigmoid_25) | |
| - mul_926 = torch.ops.aten.mul.Tensor(sigmoid_25, sub_237); sigmoid_25 = sub_237 = None | |
| - mul_927 = torch.ops.aten.mul.Tensor(mul_924, mul_926); mul_924 = mul_926 = None | |
| - add_414 = torch.ops.aten.add.Tensor(mul_925, mul_927); mul_925 = mul_927 = None | |
| - mul_928 = torch.ops.aten.mul.Tensor(add_413, primals_366); primals_366 = None | |
| - mul_930 = torch.ops.aten.mul.Tensor(mul_210, mul_928) | |
| - sum_146 = torch.ops.aten.sum.dim_IntList(mul_930, [1], True); mul_930 = None | |
| - div_113 = torch.ops.aten.div.Tensor(mul_210, 4608) | |
| - mul_931 = torch.ops.aten.mul.Tensor(div_113, sum_146); div_113 = sum_146 = None | |
| - sub_238 = torch.ops.aten.sub.Tensor(mul_928, mul_931); mul_928 = mul_931 = None | |
| - mul_932 = torch.ops.aten.mul.Tensor(sub_238, rsqrt_69); sub_238 = rsqrt_69 = None | |
| - mul_933 = torch.ops.aten.mul.Tensor(add_413, mul_210); add_413 = mul_210 = None | |
| - sum_147 = torch.ops.aten.sum.dim_IntList(mul_933, [0]); mul_933 = None | |
| - convert_element_type_1111 = torch.ops.prims.convert_element_type.default(mul_932, torch.bfloat16); mul_932 = None | |
| - mul_934 = torch.ops.aten.mul.Tensor(add_414, primals_365); primals_365 = None | |
| - mul_936 = torch.ops.aten.mul.Tensor(mul_208, mul_934) | |
| - sum_148 = torch.ops.aten.sum.dim_IntList(mul_936, [1], True); mul_936 = None | |
| - div_114 = torch.ops.aten.div.Tensor(mul_208, 4608) | |
| - mul_937 = torch.ops.aten.mul.Tensor(div_114, sum_148); div_114 = sum_148 = None | |
| - sub_239 = torch.ops.aten.sub.Tensor(mul_934, mul_937); mul_934 = mul_937 = None | |
| - mul_938 = torch.ops.aten.mul.Tensor(sub_239, rsqrt_68); sub_239 = rsqrt_68 = None | |
| - mul_939 = torch.ops.aten.mul.Tensor(add_414, mul_208); add_414 = mul_208 = None | |
| - sum_149 = torch.ops.aten.sum.dim_IntList(mul_939, [0]); mul_939 = None | |
| - convert_element_type_1112 = torch.ops.prims.convert_element_type.default(mul_938, torch.bfloat16); mul_938 = None | |
| - index_put_34 = torch.ops.aten.index_put.default(full_default_296, [sub_32], convert_element_type_1111, True) | |
| + mm_276 = torch.ops.aten.mm.default(add_434, permute_692); add_434 = permute_692 = None | |
| + add_435 = torch.ops.aten.add.Tensor(convert_element_type_1147, mm_276); convert_element_type_1147 = mm_276 = None | |
| + convert_element_type_1177 = torch.ops.prims.convert_element_type.default(mm_275, torch.float32); mm_275 = None | |
| + mul_966 = torch.ops.aten.mul.Tensor(add_433, convert_element_type_619); convert_element_type_619 = None | |
| + mul_967 = torch.ops.aten.mul.Tensor(add_433, sigmoid_26); add_433 = None | |
| + convert_element_type_1178 = torch.ops.prims.convert_element_type.default(mul_966, torch.float32); mul_966 = None | |
| + convert_element_type_1179 = torch.ops.prims.convert_element_type.default(sigmoid_26, torch.float32); sigmoid_26 = None | |
| + sub_213 = torch.ops.aten.sub.Tensor(1, convert_element_type_1179) | |
| + mul_968 = torch.ops.aten.mul.Tensor(convert_element_type_1179, sub_213); convert_element_type_1179 = sub_213 = None | |
| + mul_969 = torch.ops.aten.mul.Tensor(convert_element_type_1178, mul_968); convert_element_type_1178 = mul_968 = None | |
| + convert_element_type_1180 = torch.ops.prims.convert_element_type.default(mul_969, torch.bfloat16); mul_969 = None | |
| + add_436 = torch.ops.aten.add.Tensor(mul_967, convert_element_type_1180); mul_967 = convert_element_type_1180 = None | |
| + mul_970 = torch.ops.aten.mul.Tensor(add_435, convert_element_type_617); convert_element_type_617 = None | |
| + mul_971 = torch.ops.aten.mul.Tensor(add_435, sigmoid_25); add_435 = None | |
| + convert_element_type_1181 = torch.ops.prims.convert_element_type.default(mul_970, torch.float32); mul_970 = None | |
| + convert_element_type_1182 = torch.ops.prims.convert_element_type.default(sigmoid_25, torch.float32); sigmoid_25 = None | |
| + sub_214 = torch.ops.aten.sub.Tensor(1, convert_element_type_1182) | |
| + mul_972 = torch.ops.aten.mul.Tensor(convert_element_type_1182, sub_214); convert_element_type_1182 = sub_214 = None | |
| + mul_973 = torch.ops.aten.mul.Tensor(convert_element_type_1181, mul_972); convert_element_type_1181 = mul_972 = None | |
| + convert_element_type_1183 = torch.ops.prims.convert_element_type.default(mul_973, torch.bfloat16); mul_973 = None | |
| + add_437 = torch.ops.aten.add.Tensor(mul_971, convert_element_type_1183); mul_971 = convert_element_type_1183 = None | |
| + convert_element_type_1184 = torch.ops.prims.convert_element_type.default(add_436, torch.float32); add_436 = None | |
| + mul_974 = torch.ops.aten.mul.Tensor(convert_element_type_1184, mul_210); mul_210 = None | |
| + mul_975 = torch.ops.aten.mul.Tensor(convert_element_type_1184, primals_366); convert_element_type_1184 = primals_366 = None | |
| + sum_146 = torch.ops.aten.sum.dim_IntList(mul_974, [0], True); mul_974 = None | |
| + view_492 = torch.ops.aten.view.default(sum_146, [4608]); sum_146 = None | |
| + mul_976 = torch.ops.aten.mul.Tensor(mul_975, convert_element_type_618) | |
| + mul_977 = torch.ops.aten.mul.Tensor(mul_975, rsqrt_69); mul_975 = None | |
| + sum_147 = torch.ops.aten.sum.dim_IntList(mul_976, [1], True); mul_976 = None | |
| + mul_978 = torch.ops.aten.mul.Scalar(sum_147, -0.5); sum_147 = None | |
| + pow_130 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_69, 3); rsqrt_69 = None | |
| + mul_979 = torch.ops.aten.mul.Tensor(mul_978, pow_130); mul_978 = pow_130 = None | |
| + expand_86 = torch.ops.aten.expand.default(mul_979, [4096, 4608]); mul_979 = None | |
| + div_113 = torch.ops.aten.div.Scalar(expand_86, 4608); expand_86 = None | |
| + pow_131 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_618, 1.0); convert_element_type_618 = None | |
| + mul_980 = torch.ops.aten.mul.Scalar(pow_131, 2.0); pow_131 = None | |
| + mul_981 = torch.ops.aten.mul.Tensor(div_113, mul_980); div_113 = mul_980 = None | |
| + add_438 = torch.ops.aten.add.Tensor(mul_977, mul_981); mul_977 = mul_981 = None | |
| + convert_element_type_1185 = torch.ops.prims.convert_element_type.default(add_438, torch.bfloat16); add_438 = None | |
| + convert_element_type_1186 = torch.ops.prims.convert_element_type.default(add_437, torch.float32); add_437 = None | |
| + mul_982 = torch.ops.aten.mul.Tensor(convert_element_type_1186, mul_208); mul_208 = None | |
| + mul_983 = torch.ops.aten.mul.Tensor(convert_element_type_1186, primals_365); convert_element_type_1186 = primals_365 = None | |
| + sum_148 = torch.ops.aten.sum.dim_IntList(mul_982, [0], True); mul_982 = None | |
| + view_493 = torch.ops.aten.view.default(sum_148, [4608]); sum_148 = None | |
| + mul_984 = torch.ops.aten.mul.Tensor(mul_983, convert_element_type_616) | |
| + mul_985 = torch.ops.aten.mul.Tensor(mul_983, rsqrt_68); mul_983 = None | |
| + sum_149 = torch.ops.aten.sum.dim_IntList(mul_984, [1], True); mul_984 = None | |
| + mul_986 = torch.ops.aten.mul.Scalar(sum_149, -0.5); sum_149 = None | |
| + pow_132 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_68, 3); rsqrt_68 = None | |
| + mul_987 = torch.ops.aten.mul.Tensor(mul_986, pow_132); mul_986 = pow_132 = None | |
| + expand_87 = torch.ops.aten.expand.default(mul_987, [4096, 4608]); mul_987 = None | |
| + div_114 = torch.ops.aten.div.Scalar(expand_87, 4608); expand_87 = None | |
| + pow_133 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_616, 1.0); convert_element_type_616 = None | |
| + mul_988 = torch.ops.aten.mul.Scalar(pow_133, 2.0); pow_133 = None | |
| + mul_989 = torch.ops.aten.mul.Tensor(div_114, mul_988); div_114 = mul_988 = None | |
| + add_439 = torch.ops.aten.add.Tensor(mul_985, mul_989); mul_985 = mul_989 = None | |
| + convert_element_type_1187 = torch.ops.prims.convert_element_type.default(add_439, torch.bfloat16); add_439 = None | |
| + index_put_34 = torch.ops.aten.index_put.default(full_default_296, [sub_32], convert_element_type_1185, True) | |
| slice_scatter_30 = torch.ops.aten.slice_scatter.default(full_default_297, index_put_34, 1, 4608, 9223372036854775807); index_put_34 = None | |
| - abs_56 = torch.ops.aten.abs.default(convert_element_type_1111) | |
| + abs_56 = torch.ops.aten.abs.default(convert_element_type_1185) | |
| amax_25 = torch.ops.aten.amax.default(abs_56, [-1], True); abs_56 = None | |
| - convert_element_type_1113 = torch.ops.prims.convert_element_type.default(amax_25, torch.float64); amax_25 = None | |
| - clamp_min_90 = torch.ops.aten.clamp_min.default(convert_element_type_1113, 1e-12); convert_element_type_1113 = None | |
| + convert_element_type_1188 = torch.ops.prims.convert_element_type.default(amax_25, torch.float64); amax_25 = None | |
| + clamp_min_90 = torch.ops.aten.clamp_min.default(convert_element_type_1188, 1e-12); convert_element_type_1188 = None | |
| reciprocal_56 = torch.ops.aten.reciprocal.default(clamp_min_90); clamp_min_90 = None | |
| - mul_940 = torch.ops.aten.mul.Tensor(reciprocal_56, 448.0); reciprocal_56 = None | |
| - convert_element_type_1114 = torch.ops.prims.convert_element_type.default(mul_940, torch.float32); mul_940 = None | |
| - log2_28 = torch.ops.aten.log2.default(convert_element_type_1114); convert_element_type_1114 = None | |
| + mul_990 = torch.ops.aten.mul.Tensor(reciprocal_56, 448.0); reciprocal_56 = None | |
| + convert_element_type_1189 = torch.ops.prims.convert_element_type.default(mul_990, torch.float32); mul_990 = None | |
| + log2_28 = torch.ops.aten.log2.default(convert_element_type_1189); convert_element_type_1189 = None | |
| floor_28 = torch.ops.aten.floor.default(log2_28); log2_28 = None | |
| exp2_28 = torch.ops.aten.exp2.default(floor_28); floor_28 = None | |
| - convert_element_type_1115 = torch.ops.prims.convert_element_type.default(convert_element_type_1111, torch.float32) | |
| - mul_941 = torch.ops.aten.mul.Tensor(convert_element_type_1115, exp2_28); convert_element_type_1115 = None | |
| - clamp_min_91 = torch.ops.aten.clamp_min.default(mul_941, -448.0); mul_941 = None | |
| + convert_element_type_1190 = torch.ops.prims.convert_element_type.default(convert_element_type_1185, torch.float32) | |
| + mul_991 = torch.ops.aten.mul.Tensor(convert_element_type_1190, exp2_28); convert_element_type_1190 = None | |
| + clamp_min_91 = torch.ops.aten.clamp_min.default(mul_991, -448.0); mul_991 = None | |
| clamp_max_58 = torch.ops.aten.clamp_max.default(clamp_min_91, 448.0); clamp_min_91 = None | |
| - convert_element_type_1116 = torch.ops.prims.convert_element_type.default(clamp_max_58, torch.float8_e4m3fn); clamp_max_58 = None | |
| + convert_element_type_1191 = torch.ops.prims.convert_element_type.default(clamp_max_58, torch.float8_e4m3fn); clamp_max_58 = None | |
| permute_218 = torch.ops.aten.permute.default(primals_364, [1, 0]); primals_364 = None | |
| abs_12 = torch.ops.aten.abs.default(permute_218) | |
| max_5 = torch.ops.aten.max.default(abs_12); abs_12 = None | |
| - convert_element_type_1117 = torch.ops.prims.convert_element_type.default(max_5, torch.float64); max_5 = None | |
| - clamp_min_92 = torch.ops.aten.clamp_min.default(convert_element_type_1117, 1e-12); convert_element_type_1117 = None | |
| + convert_element_type_1192 = torch.ops.prims.convert_element_type.default(max_5, torch.float64); max_5 = None | |
| + clamp_min_92 = torch.ops.aten.clamp_min.default(convert_element_type_1192, 1e-12); convert_element_type_1192 = None | |
| reciprocal_57 = torch.ops.aten.reciprocal.default(clamp_min_92); clamp_min_92 = None | |
| - mul_942 = torch.ops.aten.mul.Tensor(reciprocal_57, 448.0); reciprocal_57 = None | |
| - convert_element_type_1118 = torch.ops.prims.convert_element_type.default(mul_942, torch.float32); mul_942 = None | |
| - log2_29 = torch.ops.aten.log2.default(convert_element_type_1118); convert_element_type_1118 = None | |
| + mul_992 = torch.ops.aten.mul.Tensor(reciprocal_57, 448.0); reciprocal_57 = None | |
| + convert_element_type_1193 = torch.ops.prims.convert_element_type.default(mul_992, torch.float32); mul_992 = None | |
| + log2_29 = torch.ops.aten.log2.default(convert_element_type_1193); convert_element_type_1193 = None | |
| floor_29 = torch.ops.aten.floor.default(log2_29); log2_29 = None | |
| exp2_29 = torch.ops.aten.exp2.default(floor_29); floor_29 = None | |
| - mul_943 = torch.ops.aten.mul.Tensor(permute_218, exp2_29); permute_218 = None | |
| - clamp_min_93 = torch.ops.aten.clamp_min.default(mul_943, -448.0); mul_943 = None | |
| + mul_993 = torch.ops.aten.mul.Tensor(permute_218, exp2_29); permute_218 = None | |
| + clamp_min_93 = torch.ops.aten.clamp_min.default(mul_993, -448.0); mul_993 = None | |
| clamp_max_59 = torch.ops.aten.clamp_max.default(clamp_min_93, 448.0); clamp_min_93 = None | |
| - convert_element_type_1119 = torch.ops.prims.convert_element_type.default(clamp_max_59, torch.float8_e4m3fn); clamp_max_59 = None | |
| - clone_95 = torch.ops.aten.clone.default(convert_element_type_1119, memory_format = torch.contiguous_format); convert_element_type_1119 = None | |
| + convert_element_type_1194 = torch.ops.prims.convert_element_type.default(clamp_max_59, torch.float8_e4m3fn); clamp_max_59 = None | |
| + clone_95 = torch.ops.aten.clone.default(convert_element_type_1194, memory_format = torch.contiguous_format); convert_element_type_1194 = None | |
| permute_696 = torch.ops.aten.permute.default(clone_95, [1, 0]); clone_95 = None | |
| repeat_8 = torch.ops.aten.repeat.default(exp2_29, [28672]); exp2_29 = None | |
| - view_470 = torch.ops.aten.view.default(repeat_8, [1, -1]); repeat_8 = None | |
| + view_495 = torch.ops.aten.view.default(repeat_8, [1, -1]); repeat_8 = None | |
| reciprocal_58 = torch.ops.aten.reciprocal.default(exp2_28); exp2_28 = None | |
| - reciprocal_59 = torch.ops.aten.reciprocal.default(view_470); view_470 = None | |
| - mul_944 = torch.ops.aten.mul.Tensor(reciprocal_58, reciprocal_59); reciprocal_58 = reciprocal_59 = None | |
| - _scaled_mm_14 = torch.ops.aten._scaled_mm.default(convert_element_type_1116, permute_696, full_default_267, full_default_267, None, None, torch.bfloat16); convert_element_type_1116 = permute_696 = None | |
| - mul_945 = torch.ops.aten.mul.Tensor(_scaled_mm_14, mul_944); _scaled_mm_14 = mul_944 = None | |
| - permute_697 = torch.ops.aten.permute.default(convert_element_type_1111, [1, 0]); convert_element_type_1111 = None | |
| - convert_element_type_default_44 = torch.ops.prims.convert_element_type.default(fp8_quant_pos_739_bmm_7, torch.bfloat16); fp8_quant_pos_739_bmm_7 = None | |
| - div_tensor_17 = torch.ops.aten.div.Tensor(convert_element_type_default_44, fp8_scale_pos_739_bmm_7); convert_element_type_default_44 = fp8_scale_pos_739_bmm_7 = None | |
| - convert_element_type_default_45 = torch.ops.prims.convert_element_type.default(div_tensor_17, torch.bfloat16); div_tensor_17 = None | |
| - view_258 = torch.ops.aten.view.default(convert_element_type_default_45, [4096, -1]); convert_element_type_default_45 = None | |
| + reciprocal_59 = torch.ops.aten.reciprocal.default(view_495); view_495 = None | |
| + mul_994 = torch.ops.aten.mul.Tensor(reciprocal_58, reciprocal_59); reciprocal_58 = reciprocal_59 = None | |
| + _scaled_mm_14 = torch.ops.aten._scaled_mm.default(convert_element_type_1191, permute_696, full_default_267, full_default_267, None, None, torch.bfloat16); convert_element_type_1191 = permute_696 = None | |
| + mul_995 = torch.ops.aten.mul.Tensor(_scaled_mm_14, mul_994); _scaled_mm_14 = mul_994 = None | |
| + permute_697 = torch.ops.aten.permute.default(convert_element_type_1185, [1, 0]); convert_element_type_1185 = None | |
| + convert_element_type_default_43 = torch.ops.prims.convert_element_type.default(fp8_quant_pos_721_convert_element_type_547, torch.bfloat16); fp8_quant_pos_721_convert_element_type_547 = None | |
| + div_tensor_18 = torch.ops.aten.div.Tensor(convert_element_type_default_43, fp8_scale_pos_721_convert_element_type_547); convert_element_type_default_43 = fp8_scale_pos_721_convert_element_type_547 = None | |
| + convert_element_type_default_44 = torch.ops.prims.convert_element_type.default(div_tensor_18, torch.bfloat16); div_tensor_18 = None | |
| + convert_element_type_default_45 = torch.ops.prims.convert_element_type.default(fp8_quant_pos_726_expand_8, torch.bfloat16); fp8_quant_pos_726_expand_8 = None | |
| + div_tensor_19 = torch.ops.aten.div.Tensor(convert_element_type_default_45, fp8_scale_pos_726_expand_8); convert_element_type_default_45 = fp8_scale_pos_726_expand_8 = None | |
| + convert_element_type_default_46 = torch.ops.prims.convert_element_type.default(div_tensor_19, torch.bfloat16); div_tensor_19 = None | |
| + index_27 = torch.ops.aten.index.Tensor(convert_element_type_default_42, [sub_32]) | |
| + cat_18 = torch.ops.aten.cat.default([index_27, convert_element_type_default_44], 1); index_27 = None | |
| + expand_9 = torch.ops.aten.expand.default(cat_18, [4096, 512, 112]) | |
| + bmm_5 = torch.ops.aten.bmm.default(convert_element_type_default_46, expand_9) | |
| + permute_215 = torch.ops.aten.permute.default(cat_18, [0, 2, 1]); cat_18 = None | |
| + expand_12 = torch.ops.aten.expand.default(bmm_5, [4096, 48, 112]); bmm_5 = None | |
| + expand_13 = torch.ops.aten.expand.default(permute_215, [4096, 112, 512]); permute_215 = None | |
| + bmm_7 = torch.ops.aten.bmm.default(expand_12, expand_13) | |
| + view_258 = torch.ops.aten.view.default(bmm_7, [4096, -1]); bmm_7 = None | |
| cat_20 = torch.ops.aten.cat.default([view_258, mul_117], 1); view_258 = None | |
| pow_47 = torch.ops.aten.pow.Tensor_Scalar(cat_20, 2) | |
| mean_46 = torch.ops.aten.mean.dim(pow_47, [1], True); pow_47 = None | |
| @@ -6881,58 +4028,59 @@ | |
| rsqrt_67 = torch.ops.aten.rsqrt.default(add_153); add_153 = None | |
| mul_198 = torch.ops.aten.mul.Tensor(cat_20, rsqrt_67); cat_20 = None | |
| mul_199 = torch.ops.aten.mul.Tensor(mul_198, primals_362) | |
| - convert_element_type_611 = torch.ops.prims.convert_element_type.default(mul_199, torch.bfloat16); mul_199 = None | |
| - mm_277 = torch.ops.aten.mm.default(permute_697, convert_element_type_611); permute_697 = convert_element_type_611 = None | |
| + convert_element_type_608 = torch.ops.prims.convert_element_type.default(mul_199, torch.bfloat16); mul_199 = None | |
| + mm_277 = torch.ops.aten.mm.default(permute_697, convert_element_type_608); permute_697 = convert_element_type_608 = None | |
| permute_698 = torch.ops.aten.permute.default(mm_277, [1, 0]); mm_277 = None | |
| - convert_element_type_1123 = torch.ops.prims.convert_element_type.default(permute_698, torch.float32); permute_698 = None | |
| - permute_699 = torch.ops.aten.permute.default(convert_element_type_1123, [1, 0]); convert_element_type_1123 = None | |
| - convert_element_type_default_7 = torch.ops.prims.convert_element_type.default(mul_945, torch.float32); mul_945 = None | |
| - slice_scatter_31 = torch.ops.aten.slice_scatter.default(full_default_297, convert_element_type_1112, 1, 0, 4608); convert_element_type_1112 = None | |
| - add_415 = torch.ops.aten.add.Tensor(slice_scatter_30, slice_scatter_31); slice_scatter_30 = slice_scatter_31 = None | |
| - abs_58 = torch.ops.aten.abs.default(add_415) | |
| + convert_element_type_1198 = torch.ops.prims.convert_element_type.default(permute_698, torch.float32); permute_698 = None | |
| + permute_699 = torch.ops.aten.permute.default(convert_element_type_1198, [1, 0]); convert_element_type_1198 = None | |
| + convert_element_type_default_6 = torch.ops.prims.convert_element_type.default(mul_995, torch.float32); mul_995 = None | |
| + slice_scatter_31 = torch.ops.aten.slice_scatter.default(full_default_297, convert_element_type_1187, 1, 0, 4608); convert_element_type_1187 = None | |
| + add_440 = torch.ops.aten.add.Tensor(slice_scatter_30, slice_scatter_31); slice_scatter_30 = slice_scatter_31 = None | |
| + abs_58 = torch.ops.aten.abs.default(add_440) | |
| amax_26 = torch.ops.aten.amax.default(abs_58, [-1], True); abs_58 = None | |
| - convert_element_type_1125 = torch.ops.prims.convert_element_type.default(amax_26, torch.float64); amax_26 = None | |
| - clamp_min_94 = torch.ops.aten.clamp_min.default(convert_element_type_1125, 1e-12); convert_element_type_1125 = None | |
| + convert_element_type_1200 = torch.ops.prims.convert_element_type.default(amax_26, torch.float64); amax_26 = None | |
| + clamp_min_94 = torch.ops.aten.clamp_min.default(convert_element_type_1200, 1e-12); convert_element_type_1200 = None | |
| reciprocal_60 = torch.ops.aten.reciprocal.default(clamp_min_94); clamp_min_94 = None | |
| - mul_946 = torch.ops.aten.mul.Tensor(reciprocal_60, 448.0); reciprocal_60 = None | |
| - convert_element_type_1126 = torch.ops.prims.convert_element_type.default(mul_946, torch.float32); mul_946 = None | |
| - log2_30 = torch.ops.aten.log2.default(convert_element_type_1126); convert_element_type_1126 = None | |
| + mul_996 = torch.ops.aten.mul.Tensor(reciprocal_60, 448.0); reciprocal_60 = None | |
| + convert_element_type_1201 = torch.ops.prims.convert_element_type.default(mul_996, torch.float32); mul_996 = None | |
| + log2_30 = torch.ops.aten.log2.default(convert_element_type_1201); convert_element_type_1201 = None | |
| floor_30 = torch.ops.aten.floor.default(log2_30); log2_30 = None | |
| exp2_30 = torch.ops.aten.exp2.default(floor_30); floor_30 = None | |
| - convert_element_type_1127 = torch.ops.prims.convert_element_type.default(add_415, torch.float32) | |
| - mul_947 = torch.ops.aten.mul.Tensor(convert_element_type_1127, exp2_30); convert_element_type_1127 = None | |
| - clamp_min_95 = torch.ops.aten.clamp_min.default(mul_947, -448.0); mul_947 = None | |
| + convert_element_type_1202 = torch.ops.prims.convert_element_type.default(add_440, torch.float32) | |
| + mul_997 = torch.ops.aten.mul.Tensor(convert_element_type_1202, exp2_30); convert_element_type_1202 = None | |
| + clamp_min_95 = torch.ops.aten.clamp_min.default(mul_997, -448.0); mul_997 = None | |
| clamp_max_60 = torch.ops.aten.clamp_max.default(clamp_min_95, 448.0); clamp_min_95 = None | |
| - convert_element_type_1128 = torch.ops.prims.convert_element_type.default(clamp_max_60, torch.float8_e4m3fn); clamp_max_60 = None | |
| + convert_element_type_1203 = torch.ops.prims.convert_element_type.default(clamp_max_60, torch.float8_e4m3fn); clamp_max_60 = None | |
| permute_217 = torch.ops.aten.permute.default(primals_363, [1, 0]); primals_363 = None | |
| abs_10 = torch.ops.aten.abs.default(permute_217) | |
| max_6 = torch.ops.aten.max.default(abs_10); abs_10 = None | |
| - convert_element_type_1129 = torch.ops.prims.convert_element_type.default(max_6, torch.float64); max_6 = None | |
| - clamp_min_96 = torch.ops.aten.clamp_min.default(convert_element_type_1129, 1e-12); convert_element_type_1129 = None | |
| + convert_element_type_1204 = torch.ops.prims.convert_element_type.default(max_6, torch.float64); max_6 = None | |
| + clamp_min_96 = torch.ops.aten.clamp_min.default(convert_element_type_1204, 1e-12); convert_element_type_1204 = None | |
| reciprocal_61 = torch.ops.aten.reciprocal.default(clamp_min_96); clamp_min_96 = None | |
| - mul_948 = torch.ops.aten.mul.Tensor(reciprocal_61, 448.0); reciprocal_61 = None | |
| - convert_element_type_1130 = torch.ops.prims.convert_element_type.default(mul_948, torch.float32); mul_948 = None | |
| - log2_31 = torch.ops.aten.log2.default(convert_element_type_1130); convert_element_type_1130 = None | |
| + mul_998 = torch.ops.aten.mul.Tensor(reciprocal_61, 448.0); reciprocal_61 = None | |
| + convert_element_type_1205 = torch.ops.prims.convert_element_type.default(mul_998, torch.float32); mul_998 = None | |
| + log2_31 = torch.ops.aten.log2.default(convert_element_type_1205); convert_element_type_1205 = None | |
| floor_31 = torch.ops.aten.floor.default(log2_31); log2_31 = None | |
| exp2_31 = torch.ops.aten.exp2.default(floor_31); floor_31 = None | |
| - mul_949 = torch.ops.aten.mul.Tensor(permute_217, exp2_31); permute_217 = None | |
| - clamp_min_97 = torch.ops.aten.clamp_min.default(mul_949, -448.0); mul_949 = None | |
| + mul_999 = torch.ops.aten.mul.Tensor(permute_217, exp2_31); permute_217 = None | |
| + clamp_min_97 = torch.ops.aten.clamp_min.default(mul_999, -448.0); mul_999 = None | |
| clamp_max_61 = torch.ops.aten.clamp_max.default(clamp_min_97, 448.0); clamp_min_97 = None | |
| - convert_element_type_1131 = torch.ops.prims.convert_element_type.default(clamp_max_61, torch.float8_e4m3fn); clamp_max_61 = None | |
| - clone_96 = torch.ops.aten.clone.default(convert_element_type_1131, memory_format = torch.contiguous_format); convert_element_type_1131 = None | |
| + convert_element_type_1206 = torch.ops.prims.convert_element_type.default(clamp_max_61, torch.float8_e4m3fn); clamp_max_61 = None | |
| + clone_96 = torch.ops.aten.clone.default(convert_element_type_1206, memory_format = torch.contiguous_format); convert_element_type_1206 = None | |
| permute_702 = torch.ops.aten.permute.default(clone_96, [1, 0]); clone_96 = None | |
| repeat_9 = torch.ops.aten.repeat.default(exp2_31, [16384]); exp2_31 = None | |
| - view_475 = torch.ops.aten.view.default(repeat_9, [1, -1]); repeat_9 = None | |
| + view_500 = torch.ops.aten.view.default(repeat_9, [1, -1]); repeat_9 = None | |
| reciprocal_62 = torch.ops.aten.reciprocal.default(exp2_30); exp2_30 = None | |
| - reciprocal_63 = torch.ops.aten.reciprocal.default(view_475); view_475 = None | |
| - mul_950 = torch.ops.aten.mul.Tensor(reciprocal_62, reciprocal_63); reciprocal_62 = reciprocal_63 = None | |
| - _scaled_mm_15 = torch.ops.aten._scaled_mm.default(convert_element_type_1128, permute_702, full_default_267, full_default_267, None, None, torch.bfloat16); convert_element_type_1128 = permute_702 = None | |
| - mul_951 = torch.ops.aten.mul.Tensor(_scaled_mm_15, mul_950); _scaled_mm_15 = mul_950 = None | |
| - permute_703 = torch.ops.aten.permute.default(add_415, [1, 0]); add_415 = None | |
| - convert_element_type_default_46 = torch.ops.prims.convert_element_type.default(fp8_quant_pos_741_expand_15, torch.bfloat16); fp8_quant_pos_741_expand_15 = None | |
| - div_tensor_18 = torch.ops.aten.div.Tensor(convert_element_type_default_46, fp8_scale_pos_741_expand_15); convert_element_type_default_46 = fp8_scale_pos_741_expand_15 = None | |
| - convert_element_type_default_47 = torch.ops.prims.convert_element_type.default(div_tensor_18, torch.bfloat16); div_tensor_18 = None | |
| - bmm_8 = torch.ops.aten.bmm.default(expand_14, convert_element_type_default_47) | |
| + reciprocal_63 = torch.ops.aten.reciprocal.default(view_500); view_500 = None | |
| + mul_1000 = torch.ops.aten.mul.Tensor(reciprocal_62, reciprocal_63); reciprocal_62 = reciprocal_63 = None | |
| + _scaled_mm_15 = torch.ops.aten._scaled_mm.default(convert_element_type_1203, permute_702, full_default_267, full_default_267, None, None, torch.bfloat16); convert_element_type_1203 = permute_702 = None | |
| + mul_1001 = torch.ops.aten.mul.Tensor(_scaled_mm_15, mul_1000); _scaled_mm_15 = mul_1000 = None | |
| + permute_703 = torch.ops.aten.permute.default(add_440, [1, 0]); add_440 = None | |
| + expand_11 = torch.ops.aten.expand.default(convert_element_type_default_42, [4096, 256, 112]); convert_element_type_default_42 = None | |
| + bmm_6 = torch.ops.aten.bmm.default(expand_10, expand_11) | |
| + expand_14 = torch.ops.aten.expand.default(bmm_6, [4096, 48, 112]); bmm_6 = None | |
| + expand_15 = torch.ops.aten.expand.default(permute_191, [4096, 112, 256]); permute_191 = None | |
| + bmm_8 = torch.ops.aten.bmm.default(expand_14, expand_15) | |
| view_257 = torch.ops.aten.view.default(bmm_8, [4096, -1]); bmm_8 = None | |
| cat_19 = torch.ops.aten.cat.default([view_257, mul_115], 1); view_257 = None | |
| pow_46 = torch.ops.aten.pow.Tensor_Scalar(cat_19, 2) | |
| @@ -6941,275 +4089,277 @@ | |
| rsqrt_66 = torch.ops.aten.rsqrt.default(add_152); add_152 = None | |
| mul_196 = torch.ops.aten.mul.Tensor(cat_19, rsqrt_66); cat_19 = None | |
| mul_197 = torch.ops.aten.mul.Tensor(mul_196, primals_361) | |
| - convert_element_type_603 = torch.ops.prims.convert_element_type.default(mul_197, torch.bfloat16); mul_197 = None | |
| - mm_278 = torch.ops.aten.mm.default(permute_703, convert_element_type_603); permute_703 = convert_element_type_603 = None | |
| + convert_element_type_600 = torch.ops.prims.convert_element_type.default(mul_197, torch.bfloat16); mul_197 = None | |
| + mm_278 = torch.ops.aten.mm.default(permute_703, convert_element_type_600); permute_703 = convert_element_type_600 = None | |
| permute_704 = torch.ops.aten.permute.default(mm_278, [1, 0]); mm_278 = None | |
| - convert_element_type_1135 = torch.ops.prims.convert_element_type.default(permute_704, torch.float32); permute_704 = None | |
| - permute_705 = torch.ops.aten.permute.default(convert_element_type_1135, [1, 0]); convert_element_type_1135 = None | |
| - convert_element_type_default_6 = torch.ops.prims.convert_element_type.default(mul_951, torch.float32); mul_951 = None | |
| - mul_952 = torch.ops.aten.mul.Tensor(convert_element_type_default_7, primals_362); primals_362 = None | |
| - mul_954 = torch.ops.aten.mul.Tensor(mul_198, mul_952) | |
| - sum_150 = torch.ops.aten.sum.dim_IntList(mul_954, [1], True); mul_954 = None | |
| + convert_element_type_1210 = torch.ops.prims.convert_element_type.default(permute_704, torch.float32); permute_704 = None | |
| + permute_705 = torch.ops.aten.permute.default(convert_element_type_1210, [1, 0]); convert_element_type_1210 = None | |
| + convert_element_type_default_5 = torch.ops.prims.convert_element_type.default(mul_1001, torch.float32); mul_1001 = None | |
| + mul_1002 = torch.ops.aten.mul.Tensor(convert_element_type_default_6, primals_362); primals_362 = None | |
| + mul_1004 = torch.ops.aten.mul.Tensor(mul_198, mul_1002) | |
| + sum_150 = torch.ops.aten.sum.dim_IntList(mul_1004, [1], True); mul_1004 = None | |
| div_115 = torch.ops.aten.div.Tensor(mul_198, 28672) | |
| - mul_955 = torch.ops.aten.mul.Tensor(div_115, sum_150); div_115 = sum_150 = None | |
| - sub_240 = torch.ops.aten.sub.Tensor(mul_952, mul_955); mul_952 = mul_955 = None | |
| - mul_956 = torch.ops.aten.mul.Tensor(sub_240, rsqrt_67); sub_240 = rsqrt_67 = None | |
| - mul_957 = torch.ops.aten.mul.Tensor(convert_element_type_default_7, mul_198); convert_element_type_default_7 = mul_198 = None | |
| - sum_151 = torch.ops.aten.sum.dim_IntList(mul_957, [0]); mul_957 = None | |
| - mul_958 = torch.ops.aten.mul.Tensor(convert_element_type_default_6, primals_361); primals_361 = None | |
| - mul_960 = torch.ops.aten.mul.Tensor(mul_196, mul_958) | |
| - sum_152 = torch.ops.aten.sum.dim_IntList(mul_960, [1], True); mul_960 = None | |
| + mul_1005 = torch.ops.aten.mul.Tensor(div_115, sum_150); div_115 = sum_150 = None | |
| + sub_215 = torch.ops.aten.sub.Tensor(mul_1002, mul_1005); mul_1002 = mul_1005 = None | |
| + mul_1006 = torch.ops.aten.mul.Tensor(sub_215, rsqrt_67); sub_215 = rsqrt_67 = None | |
| + mul_1007 = torch.ops.aten.mul.Tensor(convert_element_type_default_6, mul_198); convert_element_type_default_6 = mul_198 = None | |
| + sum_151 = torch.ops.aten.sum.dim_IntList(mul_1007, [0]); mul_1007 = None | |
| + mul_1008 = torch.ops.aten.mul.Tensor(convert_element_type_default_5, primals_361); primals_361 = None | |
| + mul_1010 = torch.ops.aten.mul.Tensor(mul_196, mul_1008) | |
| + sum_152 = torch.ops.aten.sum.dim_IntList(mul_1010, [1], True); mul_1010 = None | |
| div_116 = torch.ops.aten.div.Tensor(mul_196, 16384) | |
| - mul_961 = torch.ops.aten.mul.Tensor(div_116, sum_152); div_116 = sum_152 = None | |
| - sub_241 = torch.ops.aten.sub.Tensor(mul_958, mul_961); mul_958 = mul_961 = None | |
| - mul_962 = torch.ops.aten.mul.Tensor(sub_241, rsqrt_66); sub_241 = rsqrt_66 = None | |
| - mul_963 = torch.ops.aten.mul.Tensor(convert_element_type_default_6, mul_196); convert_element_type_default_6 = mul_196 = None | |
| - sum_153 = torch.ops.aten.sum.dim_IntList(mul_963, [0]); mul_963 = None | |
| - slice_97 = torch.ops.aten.slice.Tensor(mul_956, 1, 0, 24576) | |
| - slice_98 = torch.ops.aten.slice.Tensor(mul_956, 1, 24576, 28672); mul_956 = None | |
| - convert_element_type_1137 = torch.ops.prims.convert_element_type.default(slice_97, torch.bfloat16); slice_97 = None | |
| - add_416 = torch.ops.aten.add.Tensor(slice_88, slice_98); slice_88 = slice_98 = None | |
| - slice_99 = torch.ops.aten.slice.Tensor(mul_962, 1, 0, 12288) | |
| - slice_100 = torch.ops.aten.slice.Tensor(mul_962, 1, 12288, 16384); mul_962 = None | |
| - convert_element_type_1138 = torch.ops.prims.convert_element_type.default(slice_99, torch.bfloat16); slice_99 = None | |
| - add_417 = torch.ops.aten.add.Tensor(slice_90, slice_100); slice_90 = slice_100 = None | |
| - view_479 = torch.ops.aten.view.default(convert_element_type_1137, [4096, 48, 512]); convert_element_type_1137 = None | |
| - view_480 = torch.ops.aten.view.default(convert_element_type_1138, [4096, 48, 256]); convert_element_type_1138 = None | |
| + mul_1011 = torch.ops.aten.mul.Tensor(div_116, sum_152); div_116 = sum_152 = None | |
| + sub_216 = torch.ops.aten.sub.Tensor(mul_1008, mul_1011); mul_1008 = mul_1011 = None | |
| + mul_1012 = torch.ops.aten.mul.Tensor(sub_216, rsqrt_66); sub_216 = rsqrt_66 = None | |
| + mul_1013 = torch.ops.aten.mul.Tensor(convert_element_type_default_5, mul_196); convert_element_type_default_5 = mul_196 = None | |
| + sum_153 = torch.ops.aten.sum.dim_IntList(mul_1013, [0]); mul_1013 = None | |
| + slice_97 = torch.ops.aten.slice.Tensor(mul_1006, 1, 0, 24576) | |
| + slice_98 = torch.ops.aten.slice.Tensor(mul_1006, 1, 24576, 28672); mul_1006 = None | |
| + convert_element_type_1212 = torch.ops.prims.convert_element_type.default(slice_97, torch.bfloat16); slice_97 = None | |
| + add_441 = torch.ops.aten.add.Tensor(slice_88, slice_98); slice_88 = slice_98 = None | |
| + slice_99 = torch.ops.aten.slice.Tensor(mul_1012, 1, 0, 12288) | |
| + slice_100 = torch.ops.aten.slice.Tensor(mul_1012, 1, 12288, 16384); mul_1012 = None | |
| + convert_element_type_1213 = torch.ops.prims.convert_element_type.default(slice_99, torch.bfloat16); slice_99 = None | |
| + add_442 = torch.ops.aten.add.Tensor(slice_90, slice_100); slice_90 = slice_100 = None | |
| + view_504 = torch.ops.aten.view.default(convert_element_type_1212, [4096, 48, 512]); convert_element_type_1212 = None | |
| + view_505 = torch.ops.aten.view.default(convert_element_type_1213, [4096, 48, 256]); convert_element_type_1213 = None | |
| permute_706 = torch.ops.aten.permute.default(expand_14, [0, 2, 1]); expand_14 = None | |
| - bmm_21 = torch.ops.aten.bmm.default(permute_706, view_480); permute_706 = None | |
| - permute_707 = torch.ops.aten.permute.default(convert_element_type_default_47, [0, 2, 1]); convert_element_type_default_47 = None | |
| - bmm_22 = torch.ops.aten.bmm.default(view_480, permute_707); view_480 = permute_707 = None | |
| - convert_element_type_1143 = torch.ops.prims.convert_element_type.default(bmm_21, torch.float32); bmm_21 = None | |
| - permute_708 = torch.ops.aten.permute.default(convert_element_type_1143, [0, 2, 1]); convert_element_type_1143 = None | |
| - bmm_23 = torch.ops.aten.bmm.default(permute_709, view_479); permute_709 = None | |
| - convert_element_type_default_40 = torch.ops.prims.convert_element_type.default(fp8_quant_pos_733_cat_16, torch.bfloat16); fp8_quant_pos_733_cat_16 = None | |
| - div_tensor_15 = torch.ops.aten.div.Tensor(convert_element_type_default_40, fp8_scale_pos_733_cat_16); convert_element_type_default_40 = fp8_scale_pos_733_cat_16 = None | |
| - convert_element_type_default_41 = torch.ops.prims.convert_element_type.default(div_tensor_15, torch.bfloat16); div_tensor_15 = None | |
| - convert_element_type_default_42 = torch.ops.prims.convert_element_type.default(fp8_quant_pos_734_cat_17, torch.bfloat16); fp8_quant_pos_734_cat_17 = None | |
| - div_tensor_16 = torch.ops.aten.div.Tensor(convert_element_type_default_42, fp8_scale_pos_734_cat_17); convert_element_type_default_42 = fp8_scale_pos_734_cat_17 = None | |
| - convert_element_type_default_43 = torch.ops.prims.convert_element_type.default(div_tensor_16, torch.bfloat16); div_tensor_16 = None | |
| - convert_element_type_539 = torch.ops.prims.convert_element_type.default(convert_element_type_default_41, torch.float32); convert_element_type_default_41 = None | |
| - pow_40 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_539, 2) | |
| - mean_39 = torch.ops.aten.mean.dim(pow_40, [2], True); pow_40 = None | |
| - add_140 = torch.ops.aten.add.Scalar(mean_39, 1.1920928955078125e-07); mean_39 = None | |
| - rsqrt_60 = torch.ops.aten.rsqrt.default(add_140); add_140 = None | |
| - mul_180 = torch.ops.aten.mul.Tensor(convert_element_type_539, rsqrt_60); convert_element_type_539 = None | |
| - mul_181 = torch.ops.aten.mul.Tensor(mul_180, primals_343) | |
| - convert_element_type_540 = torch.ops.prims.convert_element_type.default(convert_element_type_default_43, torch.float32); convert_element_type_default_43 = None | |
| - pow_41 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_540, 2) | |
| - mean_40 = torch.ops.aten.mean.dim(pow_41, [2], True); pow_41 = None | |
| - add_141 = torch.ops.aten.add.Scalar(mean_40, 1.1920928955078125e-07); mean_40 = None | |
| - rsqrt_61 = torch.ops.aten.rsqrt.default(add_141); add_141 = None | |
| - mul_182 = torch.ops.aten.mul.Tensor(convert_element_type_540, rsqrt_61); convert_element_type_540 = None | |
| - mul_183 = torch.ops.aten.mul.Tensor(mul_182, primals_344) | |
| - index_27 = torch.ops.aten.index.Tensor(mul_181, [sub_32]) | |
| - cat_18 = torch.ops.aten.cat.default([index_27, mul_183], 1); index_27 = None | |
| - permute_215 = torch.ops.aten.permute.default(cat_18, [0, 2, 1]) | |
| - convert_element_type_597 = torch.ops.prims.convert_element_type.default(permute_215, torch.bfloat16); permute_215 = None | |
| - expand_13 = torch.ops.aten.expand.default(convert_element_type_597, [4096, 112, 512]); convert_element_type_597 = None | |
| + bmm_21 = torch.ops.aten.bmm.default(permute_706, view_505); permute_706 = None | |
| + permute_707 = torch.ops.aten.permute.default(expand_15, [0, 2, 1]); expand_15 = None | |
| + bmm_22 = torch.ops.aten.bmm.default(view_505, permute_707); view_505 = permute_707 = None | |
| + permute_708 = torch.ops.aten.permute.default(bmm_21, [0, 2, 1]); bmm_21 = None | |
| + permute_709 = torch.ops.aten.permute.default(expand_12, [0, 2, 1]); expand_12 = None | |
| + bmm_23 = torch.ops.aten.bmm.default(permute_709, view_504); permute_709 = None | |
| permute_710 = torch.ops.aten.permute.default(expand_13, [0, 2, 1]); expand_13 = None | |
| - bmm_24 = torch.ops.aten.bmm.default(view_479, permute_710); view_479 = permute_710 = None | |
| - convert_element_type_1148 = torch.ops.prims.convert_element_type.default(bmm_23, torch.float32); bmm_23 = None | |
| - permute_711 = torch.ops.aten.permute.default(convert_element_type_1148, [0, 2, 1]); convert_element_type_1148 = None | |
| + bmm_24 = torch.ops.aten.bmm.default(view_504, permute_710); view_504 = permute_710 = None | |
| + permute_711 = torch.ops.aten.permute.default(bmm_23, [0, 2, 1]); bmm_23 = None | |
| + permute_712 = torch.ops.aten.permute.default(expand_10, [0, 2, 1]); expand_10 = None | |
| bmm_25 = torch.ops.aten.bmm.default(permute_712, bmm_22); permute_712 = None | |
| - convert_element_type_542 = torch.ops.prims.convert_element_type.default(mul_181, torch.bfloat16); mul_181 = None | |
| - expand_11 = torch.ops.aten.expand.default(convert_element_type_542, [4096, 256, 112]) | |
| permute_713 = torch.ops.aten.permute.default(expand_11, [0, 2, 1]); expand_11 = None | |
| bmm_26 = torch.ops.aten.bmm.default(bmm_22, permute_713); bmm_22 = permute_713 = None | |
| - convert_element_type_1153 = torch.ops.prims.convert_element_type.default(bmm_25, torch.float32); bmm_25 = None | |
| - add_418 = torch.ops.aten.add.Tensor(permute_708, convert_element_type_1153); permute_708 = convert_element_type_1153 = None | |
| + add_443 = torch.ops.aten.add.Tensor(permute_708, bmm_25); permute_708 = bmm_25 = None | |
| full_default_315 = torch.ops.aten.full.default([4096, 48, 512], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| slice_scatter_32 = torch.ops.aten.slice_scatter.default(full_default_315, bmm_26, 2, 0, 256); full_default_315 = bmm_26 = None | |
| - convert_element_type_default_54 = torch.ops.prims.convert_element_type.default(fp8_quant_pos_850_permute_714, torch.bfloat16); fp8_quant_pos_850_permute_714 = None | |
| - div_tensor_22 = torch.ops.aten.div.Tensor(convert_element_type_default_54, fp8_scale_pos_850_permute_714); convert_element_type_default_54 = fp8_scale_pos_850_permute_714 = None | |
| - convert_element_type_default_55 = torch.ops.prims.convert_element_type.default(div_tensor_22, torch.bfloat16); div_tensor_22 = None | |
| - bmm_27 = torch.ops.aten.bmm.default(convert_element_type_default_55, bmm_24); convert_element_type_default_55 = None | |
| - convert_element_type_591 = torch.ops.prims.convert_element_type.default(cat_18, torch.bfloat16); cat_18 = None | |
| - expand_9 = torch.ops.aten.expand.default(convert_element_type_591, [4096, 512, 112]); convert_element_type_591 = None | |
| + permute_714 = torch.ops.aten.permute.default(convert_element_type_default_46, [0, 2, 1]); convert_element_type_default_46 = None | |
| + bmm_27 = torch.ops.aten.bmm.default(permute_714, bmm_24); permute_714 = None | |
| permute_715 = torch.ops.aten.permute.default(expand_9, [0, 2, 1]); expand_9 = None | |
| bmm_28 = torch.ops.aten.bmm.default(bmm_24, permute_715); bmm_24 = permute_715 = None | |
| - convert_element_type_1158 = torch.ops.prims.convert_element_type.default(bmm_27, torch.float32); bmm_27 = None | |
| - add_419 = torch.ops.aten.add.Tensor(permute_711, convert_element_type_1158); permute_711 = convert_element_type_1158 = None | |
| - slice_101 = torch.ops.aten.slice.Tensor(add_419, 1, 0, 256) | |
| - slice_102 = torch.ops.aten.slice.Tensor(add_419, 1, 256, 512); add_419 = None | |
| - full_default_316 = torch.ops.aten.full.default([4096, 256, 112], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| - index_put_35 = torch.ops.aten.index_put.default(full_default_316, [sub_32], slice_101, True); full_default_316 = slice_101 = None | |
| - add_420 = torch.ops.aten.add.Tensor(add_418, index_put_35); add_418 = index_put_35 = None | |
| - view_493 = torch.ops.aten.view.default(bmm_28, [4096, 24576]); bmm_28 = None | |
| - view_494 = torch.ops.aten.view.default(slice_scatter_32, [4096, 24576]); slice_scatter_32 = None | |
| + add_444 = torch.ops.aten.add.Tensor(permute_711, bmm_27); permute_711 = bmm_27 = None | |
| + slice_101 = torch.ops.aten.slice.Tensor(add_444, 1, 0, 256) | |
| + slice_102 = torch.ops.aten.slice.Tensor(add_444, 1, 256, 512); add_444 = None | |
| + index_put_35 = torch.ops.aten.index_put.default(full_default_287, [sub_32], slice_101, True); slice_101 = None | |
| + add_445 = torch.ops.aten.add.Tensor(add_443, index_put_35); add_443 = index_put_35 = None | |
| + view_518 = torch.ops.aten.view.default(bmm_28, [4096, 24576]); bmm_28 = None | |
| + view_519 = torch.ops.aten.view.default(slice_scatter_32, [4096, 24576]); slice_scatter_32 = None | |
| full_default_317 = torch.ops.aten.full.default([4096, 24576], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| - index_put_36 = torch.ops.aten.index_put.default(full_default_317, [sub_32], view_493, True); full_default_317 = None | |
| + index_put_36 = torch.ops.aten.index_put.default(full_default_317, [sub_32], view_518, True); full_default_317 = None | |
| full_default_318 = torch.ops.aten.full.default([4096, 49152], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| slice_scatter_33 = torch.ops.aten.slice_scatter.default(full_default_318, index_put_36, 1, 24576, 9223372036854775807); index_put_36 = None | |
| - permute_716 = torch.ops.aten.permute.default(view_493, [1, 0]) | |
| + permute_716 = torch.ops.aten.permute.default(view_518, [1, 0]) | |
| slice_42 = torch.ops.aten.slice.Tensor(mm_77, 1, 768, 9223372036854775807) | |
| index_25 = torch.ops.aten.index.Tensor(slice_42, [sub_32]); slice_42 = None | |
| add_148 = torch.ops.aten.add.Tensor(mm_78, index_25); mm_78 = index_25 = None | |
| - convert_element_type_582 = torch.ops.prims.convert_element_type.default(add_148, torch.float32); add_148 = None | |
| - pow_45 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_582, 2) | |
| + convert_element_type_584 = torch.ops.prims.convert_element_type.default(add_148, torch.float32); add_148 = None | |
| + pow_45 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_584, 2) | |
| mean_44 = torch.ops.aten.mean.dim(pow_45, [1], True); pow_45 = None | |
| add_150 = torch.ops.aten.add.Scalar(mean_44, 1.1920928955078125e-07); mean_44 = None | |
| rsqrt_65 = torch.ops.aten.rsqrt.default(add_150); add_150 = None | |
| - mul_192 = torch.ops.aten.mul.Tensor(convert_element_type_582, rsqrt_65); convert_element_type_582 = None | |
| + mul_192 = torch.ops.aten.mul.Tensor(convert_element_type_584, rsqrt_65) | |
| mul_193 = torch.ops.aten.mul.Tensor(mul_192, primals_358) | |
| - sigmoid_24 = torch.ops.aten.sigmoid.default(mul_193) | |
| - mul_195 = torch.ops.aten.mul.Tensor(mul_193, sigmoid_24) | |
| - convert_element_type_587 = torch.ops.prims.convert_element_type.default(mul_195, torch.bfloat16); mul_195 = None | |
| - mm_279 = torch.ops.aten.mm.default(permute_716, convert_element_type_587); permute_716 = convert_element_type_587 = None | |
| - convert_element_type_588 = torch.ops.prims.convert_element_type.default(primals_360, torch.bfloat16); primals_360 = None | |
| - permute_214 = torch.ops.aten.permute.default(convert_element_type_588, [1, 0]); convert_element_type_588 = None | |
| + convert_element_type_585 = torch.ops.prims.convert_element_type.default(mul_193, torch.bfloat16); mul_193 = None | |
| + sigmoid_24 = torch.ops.aten.sigmoid.default(convert_element_type_585) | |
| + mul_195 = torch.ops.aten.mul.Tensor(convert_element_type_585, sigmoid_24) | |
| + mm_279 = torch.ops.aten.mm.default(permute_716, mul_195); permute_716 = mul_195 = None | |
| + convert_element_type_589 = torch.ops.prims.convert_element_type.default(primals_360, torch.bfloat16); primals_360 = None | |
| + permute_214 = torch.ops.aten.permute.default(convert_element_type_589, [1, 0]); convert_element_type_589 = None | |
| permute_718 = torch.ops.aten.permute.default(permute_214, [1, 0]); permute_214 = None | |
| - mm_280 = torch.ops.aten.mm.default(view_493, permute_718); view_493 = permute_718 = None | |
| - convert_element_type_1163 = torch.ops.prims.convert_element_type.default(mm_279, torch.float32); mm_279 = None | |
| - convert_element_type_1164 = torch.ops.prims.convert_element_type.default(mm_280, torch.float32); mm_280 = None | |
| - slice_scatter_34 = torch.ops.aten.slice_scatter.default(full_default_318, view_494, 1, 0, 24576); full_default_318 = view_494 = None | |
| - add_421 = torch.ops.aten.add.Tensor(slice_scatter_33, slice_scatter_34); slice_scatter_33 = slice_scatter_34 = None | |
| - permute_720 = torch.ops.aten.permute.default(add_421, [1, 0]) | |
| + mm_280 = torch.ops.aten.mm.default(view_518, permute_718); view_518 = permute_718 = None | |
| + convert_element_type_1234 = torch.ops.prims.convert_element_type.default(mm_279, torch.float32); mm_279 = None | |
| + slice_scatter_34 = torch.ops.aten.slice_scatter.default(full_default_318, view_519, 1, 0, 24576); full_default_318 = view_519 = None | |
| + add_446 = torch.ops.aten.add.Tensor(slice_scatter_33, slice_scatter_34); slice_scatter_33 = slice_scatter_34 = None | |
| + permute_720 = torch.ops.aten.permute.default(add_446, [1, 0]) | |
| slice_41 = torch.ops.aten.slice.Tensor(mm_77, 1, 0, 768); mm_77 = None | |
| - convert_element_type_581 = torch.ops.prims.convert_element_type.default(slice_41, torch.float32); slice_41 = None | |
| - pow_44 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_581, 2) | |
| + convert_element_type_582 = torch.ops.prims.convert_element_type.default(slice_41, torch.float32); slice_41 = None | |
| + pow_44 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_582, 2) | |
| mean_43 = torch.ops.aten.mean.dim(pow_44, [1], True); pow_44 = None | |
| add_149 = torch.ops.aten.add.Scalar(mean_43, 1.1920928955078125e-07); mean_43 = None | |
| rsqrt_64 = torch.ops.aten.rsqrt.default(add_149); add_149 = None | |
| - mul_190 = torch.ops.aten.mul.Tensor(convert_element_type_581, rsqrt_64); convert_element_type_581 = None | |
| + mul_190 = torch.ops.aten.mul.Tensor(convert_element_type_582, rsqrt_64) | |
| mul_191 = torch.ops.aten.mul.Tensor(mul_190, primals_357) | |
| - sigmoid_23 = torch.ops.aten.sigmoid.default(mul_191) | |
| - mul_194 = torch.ops.aten.mul.Tensor(mul_191, sigmoid_23) | |
| - convert_element_type_583 = torch.ops.prims.convert_element_type.default(mul_194, torch.bfloat16); mul_194 = None | |
| - mm_281 = torch.ops.aten.mm.default(permute_720, convert_element_type_583); permute_720 = convert_element_type_583 = None | |
| - convert_element_type_584 = torch.ops.prims.convert_element_type.default(primals_359, torch.bfloat16); primals_359 = None | |
| - permute_213 = torch.ops.aten.permute.default(convert_element_type_584, [1, 0]); convert_element_type_584 = None | |
| + convert_element_type_583 = torch.ops.prims.convert_element_type.default(mul_191, torch.bfloat16); mul_191 = None | |
| + sigmoid_23 = torch.ops.aten.sigmoid.default(convert_element_type_583) | |
| + mul_194 = torch.ops.aten.mul.Tensor(convert_element_type_583, sigmoid_23) | |
| + mm_281 = torch.ops.aten.mm.default(permute_720, mul_194); permute_720 = mul_194 = None | |
| + convert_element_type_586 = torch.ops.prims.convert_element_type.default(primals_359, torch.bfloat16); primals_359 = None | |
| + permute_213 = torch.ops.aten.permute.default(convert_element_type_586, [1, 0]); convert_element_type_586 = None | |
| permute_722 = torch.ops.aten.permute.default(permute_213, [1, 0]); permute_213 = None | |
| - mm_282 = torch.ops.aten.mm.default(add_421, permute_722); add_421 = permute_722 = None | |
| - convert_element_type_1169 = torch.ops.prims.convert_element_type.default(mm_281, torch.float32); mm_281 = None | |
| - convert_element_type_1170 = torch.ops.prims.convert_element_type.default(mm_282, torch.float32); mm_282 = None | |
| - mul_964 = torch.ops.aten.mul.Tensor(convert_element_type_1164, mul_193); mul_193 = None | |
| - mul_965 = torch.ops.aten.mul.Tensor(convert_element_type_1164, sigmoid_24); convert_element_type_1164 = None | |
| - sub_242 = torch.ops.aten.sub.Tensor(1, sigmoid_24) | |
| - mul_966 = torch.ops.aten.mul.Tensor(sigmoid_24, sub_242); sigmoid_24 = sub_242 = None | |
| - mul_967 = torch.ops.aten.mul.Tensor(mul_964, mul_966); mul_964 = mul_966 = None | |
| - add_422 = torch.ops.aten.add.Tensor(mul_965, mul_967); mul_965 = mul_967 = None | |
| - mul_968 = torch.ops.aten.mul.Tensor(convert_element_type_1170, mul_191); mul_191 = None | |
| - mul_969 = torch.ops.aten.mul.Tensor(convert_element_type_1170, sigmoid_23); convert_element_type_1170 = None | |
| - sub_243 = torch.ops.aten.sub.Tensor(1, sigmoid_23) | |
| - mul_970 = torch.ops.aten.mul.Tensor(sigmoid_23, sub_243); sigmoid_23 = sub_243 = None | |
| - mul_971 = torch.ops.aten.mul.Tensor(mul_968, mul_970); mul_968 = mul_970 = None | |
| - add_423 = torch.ops.aten.add.Tensor(mul_969, mul_971); mul_969 = mul_971 = None | |
| - mul_972 = torch.ops.aten.mul.Tensor(add_422, primals_358); primals_358 = None | |
| - mul_974 = torch.ops.aten.mul.Tensor(mul_192, mul_972) | |
| - sum_154 = torch.ops.aten.sum.dim_IntList(mul_974, [1], True); mul_974 = None | |
| - div_117 = torch.ops.aten.div.Tensor(mul_192, 768) | |
| - mul_975 = torch.ops.aten.mul.Tensor(div_117, sum_154); div_117 = sum_154 = None | |
| - sub_244 = torch.ops.aten.sub.Tensor(mul_972, mul_975); mul_972 = mul_975 = None | |
| - mul_976 = torch.ops.aten.mul.Tensor(sub_244, rsqrt_65); sub_244 = rsqrt_65 = None | |
| - mul_977 = torch.ops.aten.mul.Tensor(add_422, mul_192); add_422 = mul_192 = None | |
| - sum_155 = torch.ops.aten.sum.dim_IntList(mul_977, [0]); mul_977 = None | |
| - convert_element_type_1171 = torch.ops.prims.convert_element_type.default(mul_976, torch.bfloat16); mul_976 = None | |
| - mul_978 = torch.ops.aten.mul.Tensor(add_423, primals_357); primals_357 = None | |
| - mul_980 = torch.ops.aten.mul.Tensor(mul_190, mul_978) | |
| - sum_156 = torch.ops.aten.sum.dim_IntList(mul_980, [1], True); mul_980 = None | |
| - div_118 = torch.ops.aten.div.Tensor(mul_190, 768) | |
| - mul_981 = torch.ops.aten.mul.Tensor(div_118, sum_156); div_118 = sum_156 = None | |
| - sub_245 = torch.ops.aten.sub.Tensor(mul_978, mul_981); mul_978 = mul_981 = None | |
| - mul_982 = torch.ops.aten.mul.Tensor(sub_245, rsqrt_64); sub_245 = rsqrt_64 = None | |
| - mul_983 = torch.ops.aten.mul.Tensor(add_423, mul_190); add_423 = mul_190 = None | |
| - sum_157 = torch.ops.aten.sum.dim_IntList(mul_983, [0]); mul_983 = None | |
| - convert_element_type_1172 = torch.ops.prims.convert_element_type.default(mul_982, torch.bfloat16); mul_982 = None | |
| - index_put_37 = torch.ops.aten.index_put.default(full_default_280, [sub_32], convert_element_type_1171, True); full_default_280 = None | |
| + mm_282 = torch.ops.aten.mm.default(add_446, permute_722); add_446 = permute_722 = None | |
| + convert_element_type_1239 = torch.ops.prims.convert_element_type.default(mm_281, torch.float32); mm_281 = None | |
| + mul_1014 = torch.ops.aten.mul.Tensor(mm_280, convert_element_type_585); convert_element_type_585 = None | |
| + mul_1015 = torch.ops.aten.mul.Tensor(mm_280, sigmoid_24); mm_280 = None | |
| + convert_element_type_1240 = torch.ops.prims.convert_element_type.default(mul_1014, torch.float32); mul_1014 = None | |
| + convert_element_type_1241 = torch.ops.prims.convert_element_type.default(sigmoid_24, torch.float32); sigmoid_24 = None | |
| + sub_217 = torch.ops.aten.sub.Tensor(1, convert_element_type_1241) | |
| + mul_1016 = torch.ops.aten.mul.Tensor(convert_element_type_1241, sub_217); convert_element_type_1241 = sub_217 = None | |
| + mul_1017 = torch.ops.aten.mul.Tensor(convert_element_type_1240, mul_1016); convert_element_type_1240 = mul_1016 = None | |
| + convert_element_type_1242 = torch.ops.prims.convert_element_type.default(mul_1017, torch.bfloat16); mul_1017 = None | |
| + add_447 = torch.ops.aten.add.Tensor(mul_1015, convert_element_type_1242); mul_1015 = convert_element_type_1242 = None | |
| + mul_1018 = torch.ops.aten.mul.Tensor(mm_282, convert_element_type_583); convert_element_type_583 = None | |
| + mul_1019 = torch.ops.aten.mul.Tensor(mm_282, sigmoid_23); mm_282 = None | |
| + convert_element_type_1243 = torch.ops.prims.convert_element_type.default(mul_1018, torch.float32); mul_1018 = None | |
| + convert_element_type_1244 = torch.ops.prims.convert_element_type.default(sigmoid_23, torch.float32); sigmoid_23 = None | |
| + sub_218 = torch.ops.aten.sub.Tensor(1, convert_element_type_1244) | |
| + mul_1020 = torch.ops.aten.mul.Tensor(convert_element_type_1244, sub_218); convert_element_type_1244 = sub_218 = None | |
| + mul_1021 = torch.ops.aten.mul.Tensor(convert_element_type_1243, mul_1020); convert_element_type_1243 = mul_1020 = None | |
| + convert_element_type_1245 = torch.ops.prims.convert_element_type.default(mul_1021, torch.bfloat16); mul_1021 = None | |
| + add_448 = torch.ops.aten.add.Tensor(mul_1019, convert_element_type_1245); mul_1019 = convert_element_type_1245 = None | |
| + convert_element_type_1246 = torch.ops.prims.convert_element_type.default(add_447, torch.float32); add_447 = None | |
| + mul_1022 = torch.ops.aten.mul.Tensor(convert_element_type_1246, mul_192); mul_192 = None | |
| + mul_1023 = torch.ops.aten.mul.Tensor(convert_element_type_1246, primals_358); convert_element_type_1246 = primals_358 = None | |
| + sum_154 = torch.ops.aten.sum.dim_IntList(mul_1022, [0], True); mul_1022 = None | |
| + view_520 = torch.ops.aten.view.default(sum_154, [768]); sum_154 = None | |
| + mul_1024 = torch.ops.aten.mul.Tensor(mul_1023, convert_element_type_584) | |
| + mul_1025 = torch.ops.aten.mul.Tensor(mul_1023, rsqrt_65); mul_1023 = None | |
| + sum_155 = torch.ops.aten.sum.dim_IntList(mul_1024, [1], True); mul_1024 = None | |
| + mul_1026 = torch.ops.aten.mul.Scalar(sum_155, -0.5); sum_155 = None | |
| + pow_134 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_65, 3); rsqrt_65 = None | |
| + mul_1027 = torch.ops.aten.mul.Tensor(mul_1026, pow_134); mul_1026 = pow_134 = None | |
| + expand_88 = torch.ops.aten.expand.default(mul_1027, [4096, 768]); mul_1027 = None | |
| + div_117 = torch.ops.aten.div.Scalar(expand_88, 768); expand_88 = None | |
| + pow_135 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_584, 1.0); convert_element_type_584 = None | |
| + mul_1028 = torch.ops.aten.mul.Scalar(pow_135, 2.0); pow_135 = None | |
| + mul_1029 = torch.ops.aten.mul.Tensor(div_117, mul_1028); div_117 = mul_1028 = None | |
| + add_449 = torch.ops.aten.add.Tensor(mul_1025, mul_1029); mul_1025 = mul_1029 = None | |
| + convert_element_type_1247 = torch.ops.prims.convert_element_type.default(add_449, torch.bfloat16); add_449 = None | |
| + convert_element_type_1248 = torch.ops.prims.convert_element_type.default(add_448, torch.float32); add_448 = None | |
| + mul_1030 = torch.ops.aten.mul.Tensor(convert_element_type_1248, mul_190); mul_190 = None | |
| + mul_1031 = torch.ops.aten.mul.Tensor(convert_element_type_1248, primals_357); convert_element_type_1248 = primals_357 = None | |
| + sum_156 = torch.ops.aten.sum.dim_IntList(mul_1030, [0], True); mul_1030 = None | |
| + view_521 = torch.ops.aten.view.default(sum_156, [768]); sum_156 = None | |
| + mul_1032 = torch.ops.aten.mul.Tensor(mul_1031, convert_element_type_582) | |
| + mul_1033 = torch.ops.aten.mul.Tensor(mul_1031, rsqrt_64); mul_1031 = None | |
| + sum_157 = torch.ops.aten.sum.dim_IntList(mul_1032, [1], True); mul_1032 = None | |
| + mul_1034 = torch.ops.aten.mul.Scalar(sum_157, -0.5); sum_157 = None | |
| + pow_136 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_64, 3); rsqrt_64 = None | |
| + mul_1035 = torch.ops.aten.mul.Tensor(mul_1034, pow_136); mul_1034 = pow_136 = None | |
| + expand_89 = torch.ops.aten.expand.default(mul_1035, [4096, 768]); mul_1035 = None | |
| + div_118 = torch.ops.aten.div.Scalar(expand_89, 768); expand_89 = None | |
| + pow_137 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_582, 1.0); convert_element_type_582 = None | |
| + mul_1036 = torch.ops.aten.mul.Scalar(pow_137, 2.0); pow_137 = None | |
| + mul_1037 = torch.ops.aten.mul.Tensor(div_118, mul_1036); div_118 = mul_1036 = None | |
| + add_450 = torch.ops.aten.add.Tensor(mul_1033, mul_1037); mul_1033 = mul_1037 = None | |
| + convert_element_type_1249 = torch.ops.prims.convert_element_type.default(add_450, torch.bfloat16); add_450 = None | |
| + index_put_37 = torch.ops.aten.index_put.default(full_default_280, [sub_32], convert_element_type_1247, True); full_default_280 = None | |
| slice_scatter_35 = torch.ops.aten.slice_scatter.default(full_default_281, index_put_37, 1, 768, 9223372036854775807); index_put_37 = None | |
| - permute_724 = torch.ops.aten.permute.default(convert_element_type_1171, [1, 0]) | |
| + permute_724 = torch.ops.aten.permute.default(convert_element_type_1247, [1, 0]) | |
| slice_40 = torch.ops.aten.slice.Tensor(mm_75, 1, 1024, 9223372036854775807) | |
| index_24 = torch.ops.aten.index.Tensor(slice_40, [sub_32]); slice_40 = None | |
| add_145 = torch.ops.aten.add.Tensor(mm_76, index_24); mm_76 = index_24 = None | |
| - convert_element_type_572 = torch.ops.prims.convert_element_type.default(add_145, torch.float32); add_145 = None | |
| - pow_43 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_572, 2) | |
| + convert_element_type_574 = torch.ops.prims.convert_element_type.default(add_145, torch.float32); add_145 = None | |
| + pow_43 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_574, 2) | |
| mean_42 = torch.ops.aten.mean.dim(pow_43, [1], True); pow_43 = None | |
| add_147 = torch.ops.aten.add.Scalar(mean_42, 1.1920928955078125e-07); mean_42 = None | |
| rsqrt_63 = torch.ops.aten.rsqrt.default(add_147); add_147 = None | |
| - mul_186 = torch.ops.aten.mul.Tensor(convert_element_type_572, rsqrt_63); convert_element_type_572 = None | |
| + mul_186 = torch.ops.aten.mul.Tensor(convert_element_type_574, rsqrt_63) | |
| mul_187 = torch.ops.aten.mul.Tensor(mul_186, primals_354) | |
| - sigmoid_22 = torch.ops.aten.sigmoid.default(mul_187) | |
| - mul_189 = torch.ops.aten.mul.Tensor(mul_187, sigmoid_22) | |
| - convert_element_type_577 = torch.ops.prims.convert_element_type.default(mul_189, torch.bfloat16); mul_189 = None | |
| - mm_283 = torch.ops.aten.mm.default(permute_724, convert_element_type_577); permute_724 = convert_element_type_577 = None | |
| - convert_element_type_578 = torch.ops.prims.convert_element_type.default(primals_356, torch.bfloat16); primals_356 = None | |
| - permute_212 = torch.ops.aten.permute.default(convert_element_type_578, [1, 0]); convert_element_type_578 = None | |
| + convert_element_type_575 = torch.ops.prims.convert_element_type.default(mul_187, torch.bfloat16); mul_187 = None | |
| + sigmoid_22 = torch.ops.aten.sigmoid.default(convert_element_type_575) | |
| + mul_189 = torch.ops.aten.mul.Tensor(convert_element_type_575, sigmoid_22) | |
| + mm_283 = torch.ops.aten.mm.default(permute_724, mul_189); permute_724 = mul_189 = None | |
| + convert_element_type_579 = torch.ops.prims.convert_element_type.default(primals_356, torch.bfloat16); primals_356 = None | |
| + permute_212 = torch.ops.aten.permute.default(convert_element_type_579, [1, 0]); convert_element_type_579 = None | |
| permute_726 = torch.ops.aten.permute.default(permute_212, [1, 0]); permute_212 = None | |
| - mm_284 = torch.ops.aten.mm.default(convert_element_type_1171, permute_726); convert_element_type_1171 = permute_726 = None | |
| - convert_element_type_1177 = torch.ops.prims.convert_element_type.default(mm_283, torch.float32); mm_283 = None | |
| - convert_element_type_1178 = torch.ops.prims.convert_element_type.default(mm_284, torch.float32); mm_284 = None | |
| - slice_scatter_36 = torch.ops.aten.slice_scatter.default(full_default_281, convert_element_type_1172, 1, 0, 768); convert_element_type_1172 = None | |
| - add_424 = torch.ops.aten.add.Tensor(slice_scatter_35, slice_scatter_36); slice_scatter_35 = slice_scatter_36 = None | |
| - permute_728 = torch.ops.aten.permute.default(add_424, [1, 0]) | |
| + mm_284 = torch.ops.aten.mm.default(convert_element_type_1247, permute_726); convert_element_type_1247 = permute_726 = None | |
| + convert_element_type_1254 = torch.ops.prims.convert_element_type.default(mm_283, torch.float32); mm_283 = None | |
| + slice_scatter_36 = torch.ops.aten.slice_scatter.default(full_default_281, convert_element_type_1249, 1, 0, 768); convert_element_type_1249 = None | |
| + add_451 = torch.ops.aten.add.Tensor(slice_scatter_35, slice_scatter_36); slice_scatter_35 = slice_scatter_36 = None | |
| + permute_728 = torch.ops.aten.permute.default(add_451, [1, 0]) | |
| slice_39 = torch.ops.aten.slice.Tensor(mm_75, 1, 0, 1024); mm_75 = None | |
| - convert_element_type_571 = torch.ops.prims.convert_element_type.default(slice_39, torch.float32); slice_39 = None | |
| - pow_42 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_571, 2) | |
| + convert_element_type_572 = torch.ops.prims.convert_element_type.default(slice_39, torch.float32); slice_39 = None | |
| + pow_42 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_572, 2) | |
| mean_41 = torch.ops.aten.mean.dim(pow_42, [1], True); pow_42 = None | |
| add_146 = torch.ops.aten.add.Scalar(mean_41, 1.1920928955078125e-07); mean_41 = None | |
| rsqrt_62 = torch.ops.aten.rsqrt.default(add_146); add_146 = None | |
| - mul_184 = torch.ops.aten.mul.Tensor(convert_element_type_571, rsqrt_62); convert_element_type_571 = None | |
| + mul_184 = torch.ops.aten.mul.Tensor(convert_element_type_572, rsqrt_62) | |
| mul_185 = torch.ops.aten.mul.Tensor(mul_184, primals_353) | |
| - sigmoid_21 = torch.ops.aten.sigmoid.default(mul_185) | |
| - mul_188 = torch.ops.aten.mul.Tensor(mul_185, sigmoid_21) | |
| - convert_element_type_573 = torch.ops.prims.convert_element_type.default(mul_188, torch.bfloat16); mul_188 = None | |
| - mm_285 = torch.ops.aten.mm.default(permute_728, convert_element_type_573); permute_728 = convert_element_type_573 = None | |
| - convert_element_type_574 = torch.ops.prims.convert_element_type.default(primals_355, torch.bfloat16); primals_355 = None | |
| - permute_211 = torch.ops.aten.permute.default(convert_element_type_574, [1, 0]); convert_element_type_574 = None | |
| + convert_element_type_573 = torch.ops.prims.convert_element_type.default(mul_185, torch.bfloat16); mul_185 = None | |
| + sigmoid_21 = torch.ops.aten.sigmoid.default(convert_element_type_573) | |
| + mul_188 = torch.ops.aten.mul.Tensor(convert_element_type_573, sigmoid_21) | |
| + mm_285 = torch.ops.aten.mm.default(permute_728, mul_188); permute_728 = mul_188 = None | |
| + convert_element_type_576 = torch.ops.prims.convert_element_type.default(primals_355, torch.bfloat16); primals_355 = None | |
| + permute_211 = torch.ops.aten.permute.default(convert_element_type_576, [1, 0]); convert_element_type_576 = None | |
| permute_730 = torch.ops.aten.permute.default(permute_211, [1, 0]); permute_211 = None | |
| - mm_286 = torch.ops.aten.mm.default(add_424, permute_730); add_424 = permute_730 = None | |
| - convert_element_type_1183 = torch.ops.prims.convert_element_type.default(mm_285, torch.float32); mm_285 = None | |
| - convert_element_type_1184 = torch.ops.prims.convert_element_type.default(mm_286, torch.float32); mm_286 = None | |
| - mul_984 = torch.ops.aten.mul.Tensor(convert_element_type_1178, mul_187); mul_187 = None | |
| - mul_985 = torch.ops.aten.mul.Tensor(convert_element_type_1178, sigmoid_22); convert_element_type_1178 = None | |
| - sub_246 = torch.ops.aten.sub.Tensor(1, sigmoid_22) | |
| - mul_986 = torch.ops.aten.mul.Tensor(sigmoid_22, sub_246); sigmoid_22 = sub_246 = None | |
| - mul_987 = torch.ops.aten.mul.Tensor(mul_984, mul_986); mul_984 = mul_986 = None | |
| - add_425 = torch.ops.aten.add.Tensor(mul_985, mul_987); mul_985 = mul_987 = None | |
| - mul_988 = torch.ops.aten.mul.Tensor(convert_element_type_1184, mul_185); mul_185 = None | |
| - mul_989 = torch.ops.aten.mul.Tensor(convert_element_type_1184, sigmoid_21); convert_element_type_1184 = None | |
| - sub_247 = torch.ops.aten.sub.Tensor(1, sigmoid_21) | |
| - mul_990 = torch.ops.aten.mul.Tensor(sigmoid_21, sub_247); sigmoid_21 = sub_247 = None | |
| - mul_991 = torch.ops.aten.mul.Tensor(mul_988, mul_990); mul_988 = mul_990 = None | |
| - add_426 = torch.ops.aten.add.Tensor(mul_989, mul_991); mul_989 = mul_991 = None | |
| - mul_992 = torch.ops.aten.mul.Tensor(add_425, primals_354); primals_354 = None | |
| - mul_994 = torch.ops.aten.mul.Tensor(mul_186, mul_992) | |
| - sum_158 = torch.ops.aten.sum.dim_IntList(mul_994, [1], True); mul_994 = None | |
| - div_119 = torch.ops.aten.div.Tensor(mul_186, 1024) | |
| - mul_995 = torch.ops.aten.mul.Tensor(div_119, sum_158); div_119 = sum_158 = None | |
| - sub_248 = torch.ops.aten.sub.Tensor(mul_992, mul_995); mul_992 = mul_995 = None | |
| - mul_996 = torch.ops.aten.mul.Tensor(sub_248, rsqrt_63); sub_248 = rsqrt_63 = None | |
| - mul_997 = torch.ops.aten.mul.Tensor(add_425, mul_186); add_425 = mul_186 = None | |
| - sum_159 = torch.ops.aten.sum.dim_IntList(mul_997, [0]); mul_997 = None | |
| - convert_element_type_1185 = torch.ops.prims.convert_element_type.default(mul_996, torch.bfloat16); mul_996 = None | |
| - mul_998 = torch.ops.aten.mul.Tensor(add_426, primals_353); primals_353 = None | |
| - mul_1000 = torch.ops.aten.mul.Tensor(mul_184, mul_998) | |
| - sum_160 = torch.ops.aten.sum.dim_IntList(mul_1000, [1], True); mul_1000 = None | |
| - div_120 = torch.ops.aten.div.Tensor(mul_184, 1024) | |
| - mul_1001 = torch.ops.aten.mul.Tensor(div_120, sum_160); div_120 = sum_160 = None | |
| - sub_249 = torch.ops.aten.sub.Tensor(mul_998, mul_1001); mul_998 = mul_1001 = None | |
| - mul_1002 = torch.ops.aten.mul.Tensor(sub_249, rsqrt_62); sub_249 = rsqrt_62 = None | |
| - mul_1003 = torch.ops.aten.mul.Tensor(add_426, mul_184); add_426 = mul_184 = None | |
| - sum_161 = torch.ops.aten.sum.dim_IntList(mul_1003, [0]); mul_1003 = None | |
| - convert_element_type_1186 = torch.ops.prims.convert_element_type.default(mul_1002, torch.bfloat16); mul_1002 = None | |
| - index_put_38 = torch.ops.aten.index_put.default(full_default_278, [sub_32], convert_element_type_1185, True) | |
| + mm_286 = torch.ops.aten.mm.default(add_451, permute_730); add_451 = permute_730 = None | |
| + convert_element_type_1259 = torch.ops.prims.convert_element_type.default(mm_285, torch.float32); mm_285 = None | |
| + mul_1038 = torch.ops.aten.mul.Tensor(mm_284, convert_element_type_575); convert_element_type_575 = None | |
| + mul_1039 = torch.ops.aten.mul.Tensor(mm_284, sigmoid_22); mm_284 = None | |
| + convert_element_type_1260 = torch.ops.prims.convert_element_type.default(mul_1038, torch.float32); mul_1038 = None | |
| + convert_element_type_1261 = torch.ops.prims.convert_element_type.default(sigmoid_22, torch.float32); sigmoid_22 = None | |
| + sub_219 = torch.ops.aten.sub.Tensor(1, convert_element_type_1261) | |
| + mul_1040 = torch.ops.aten.mul.Tensor(convert_element_type_1261, sub_219); convert_element_type_1261 = sub_219 = None | |
| + mul_1041 = torch.ops.aten.mul.Tensor(convert_element_type_1260, mul_1040); convert_element_type_1260 = mul_1040 = None | |
| + convert_element_type_1262 = torch.ops.prims.convert_element_type.default(mul_1041, torch.bfloat16); mul_1041 = None | |
| + add_452 = torch.ops.aten.add.Tensor(mul_1039, convert_element_type_1262); mul_1039 = convert_element_type_1262 = None | |
| + mul_1042 = torch.ops.aten.mul.Tensor(mm_286, convert_element_type_573); convert_element_type_573 = None | |
| + mul_1043 = torch.ops.aten.mul.Tensor(mm_286, sigmoid_21); mm_286 = None | |
| + convert_element_type_1263 = torch.ops.prims.convert_element_type.default(mul_1042, torch.float32); mul_1042 = None | |
| + convert_element_type_1264 = torch.ops.prims.convert_element_type.default(sigmoid_21, torch.float32); sigmoid_21 = None | |
| + sub_220 = torch.ops.aten.sub.Tensor(1, convert_element_type_1264) | |
| + mul_1044 = torch.ops.aten.mul.Tensor(convert_element_type_1264, sub_220); convert_element_type_1264 = sub_220 = None | |
| + mul_1045 = torch.ops.aten.mul.Tensor(convert_element_type_1263, mul_1044); convert_element_type_1263 = mul_1044 = None | |
| + convert_element_type_1265 = torch.ops.prims.convert_element_type.default(mul_1045, torch.bfloat16); mul_1045 = None | |
| + add_453 = torch.ops.aten.add.Tensor(mul_1043, convert_element_type_1265); mul_1043 = convert_element_type_1265 = None | |
| + convert_element_type_1266 = torch.ops.prims.convert_element_type.default(add_452, torch.float32); add_452 = None | |
| + mul_1046 = torch.ops.aten.mul.Tensor(convert_element_type_1266, mul_186); mul_186 = None | |
| + mul_1047 = torch.ops.aten.mul.Tensor(convert_element_type_1266, primals_354); convert_element_type_1266 = primals_354 = None | |
| + sum_158 = torch.ops.aten.sum.dim_IntList(mul_1046, [0], True); mul_1046 = None | |
| + view_522 = torch.ops.aten.view.default(sum_158, [1024]); sum_158 = None | |
| + mul_1048 = torch.ops.aten.mul.Tensor(mul_1047, convert_element_type_574) | |
| + mul_1049 = torch.ops.aten.mul.Tensor(mul_1047, rsqrt_63); mul_1047 = None | |
| + sum_159 = torch.ops.aten.sum.dim_IntList(mul_1048, [1], True); mul_1048 = None | |
| + mul_1050 = torch.ops.aten.mul.Scalar(sum_159, -0.5); sum_159 = None | |
| + pow_138 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_63, 3); rsqrt_63 = None | |
| + mul_1051 = torch.ops.aten.mul.Tensor(mul_1050, pow_138); mul_1050 = pow_138 = None | |
| + expand_90 = torch.ops.aten.expand.default(mul_1051, [4096, 1024]); mul_1051 = None | |
| + div_119 = torch.ops.aten.div.Scalar(expand_90, 1024); expand_90 = None | |
| + pow_139 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_574, 1.0); convert_element_type_574 = None | |
| + mul_1052 = torch.ops.aten.mul.Scalar(pow_139, 2.0); pow_139 = None | |
| + mul_1053 = torch.ops.aten.mul.Tensor(div_119, mul_1052); div_119 = mul_1052 = None | |
| + add_454 = torch.ops.aten.add.Tensor(mul_1049, mul_1053); mul_1049 = mul_1053 = None | |
| + convert_element_type_1267 = torch.ops.prims.convert_element_type.default(add_454, torch.bfloat16); add_454 = None | |
| + convert_element_type_1268 = torch.ops.prims.convert_element_type.default(add_453, torch.float32); add_453 = None | |
| + mul_1054 = torch.ops.aten.mul.Tensor(convert_element_type_1268, mul_184); mul_184 = None | |
| + mul_1055 = torch.ops.aten.mul.Tensor(convert_element_type_1268, primals_353); convert_element_type_1268 = primals_353 = None | |
| + sum_160 = torch.ops.aten.sum.dim_IntList(mul_1054, [0], True); mul_1054 = None | |
| + view_523 = torch.ops.aten.view.default(sum_160, [1024]); sum_160 = None | |
| + mul_1056 = torch.ops.aten.mul.Tensor(mul_1055, convert_element_type_572) | |
| + mul_1057 = torch.ops.aten.mul.Tensor(mul_1055, rsqrt_62); mul_1055 = None | |
| + sum_161 = torch.ops.aten.sum.dim_IntList(mul_1056, [1], True); mul_1056 = None | |
| + mul_1058 = torch.ops.aten.mul.Scalar(sum_161, -0.5); sum_161 = None | |
| + pow_140 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_62, 3); rsqrt_62 = None | |
| + mul_1059 = torch.ops.aten.mul.Tensor(mul_1058, pow_140); mul_1058 = pow_140 = None | |
| + expand_91 = torch.ops.aten.expand.default(mul_1059, [4096, 1024]); mul_1059 = None | |
| + div_120 = torch.ops.aten.div.Scalar(expand_91, 1024); expand_91 = None | |
| + pow_141 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_572, 1.0); convert_element_type_572 = None | |
| + mul_1060 = torch.ops.aten.mul.Scalar(pow_141, 2.0); pow_141 = None | |
| + mul_1061 = torch.ops.aten.mul.Tensor(div_120, mul_1060); div_120 = mul_1060 = None | |
| + add_455 = torch.ops.aten.add.Tensor(mul_1057, mul_1061); mul_1057 = mul_1061 = None | |
| + convert_element_type_1269 = torch.ops.prims.convert_element_type.default(add_455, torch.bfloat16); add_455 = None | |
| + index_put_38 = torch.ops.aten.index_put.default(full_default_278, [sub_32], convert_element_type_1267, True) | |
| full_default_324 = torch.ops.aten.full.default([4096, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| slice_scatter_37 = torch.ops.aten.slice_scatter.default(full_default_324, index_put_38, 1, 1024, 9223372036854775807); index_put_38 = None | |
| - permute_732 = torch.ops.aten.permute.default(convert_element_type_1185, [1, 0]) | |
| - permute_191 = torch.ops.aten.permute.default(convert_element_type_542, [0, 2, 1]); convert_element_type_542 = None | |
| - clone_66 = torch.ops.aten.clone.default(permute_191, memory_format = torch.contiguous_format); permute_191 = None | |
| - view_229 = torch.ops.aten.view.default(clone_66, [458752, 256]); clone_66 = None | |
| - convert_element_type_546 = torch.ops.prims.convert_element_type.default(mul_183, torch.bfloat16); mul_183 = None | |
| - permute_194 = torch.ops.aten.permute.default(convert_element_type_546, [0, 2, 1]); convert_element_type_546 = None | |
| + permute_732 = torch.ops.aten.permute.default(convert_element_type_1267, [1, 0]) | |
| + permute_194 = torch.ops.aten.permute.default(convert_element_type_default_44, [0, 2, 1]); convert_element_type_default_44 = None | |
| clone_68 = torch.ops.aten.clone.default(permute_194, memory_format = torch.contiguous_format); permute_194 = None | |
| view_231 = torch.ops.aten.view.default(clone_68, [458752, 256]); clone_68 = None | |
| - convert_element_type_549 = torch.ops.prims.convert_element_type.default(primals_347, torch.bfloat16); primals_347 = None | |
| - permute_198 = torch.ops.aten.permute.default(convert_element_type_549, [1, 0]); convert_element_type_549 = None | |
| + convert_element_type_554 = torch.ops.prims.convert_element_type.default(primals_347, torch.bfloat16); primals_347 = None | |
| + permute_198 = torch.ops.aten.permute.default(convert_element_type_554, [1, 0]); convert_element_type_554 = None | |
| mm_71 = torch.ops.aten.mm.default(view_229, permute_198) | |
| view_234 = torch.ops.aten.view.default(mm_71, [4096, 112, 96]); mm_71 = None | |
| permute_199 = torch.ops.aten.permute.default(view_234, [0, 2, 1]); view_234 = None | |
| clone_71 = torch.ops.aten.clone.default(permute_199, memory_format = torch.contiguous_format); permute_199 = None | |
| - convert_element_type_553 = torch.ops.prims.convert_element_type.default(primals_348, torch.bfloat16); primals_348 = None | |
| - permute_201 = torch.ops.aten.permute.default(convert_element_type_553, [1, 0]); convert_element_type_553 = None | |
| + convert_element_type_557 = torch.ops.prims.convert_element_type.default(primals_348, torch.bfloat16); primals_348 = None | |
| + permute_201 = torch.ops.aten.permute.default(convert_element_type_557, [1, 0]); convert_element_type_557 = None | |
| mm_72 = torch.ops.aten.mm.default(view_231, permute_201) | |
| view_236 = torch.ops.aten.view.default(mm_72, [4096, 112, 48]); mm_72 = None | |
| permute_202 = torch.ops.aten.permute.default(view_236, [0, 2, 1]); view_236 = None | |
| @@ -7219,643 +4369,783 @@ | |
| add_143 = torch.ops.aten.add.Tensor(clone_73, index_22); clone_73 = index_22 = None | |
| view_242 = torch.ops.aten.view.default(add_143, [4096, -1]); add_143 = None | |
| mm_287 = torch.ops.aten.mm.default(permute_732, view_242); permute_732 = view_242 = None | |
| - convert_element_type_568 = torch.ops.prims.convert_element_type.default(primals_352, torch.bfloat16); primals_352 = None | |
| - permute_210 = torch.ops.aten.permute.default(convert_element_type_568, [1, 0]); convert_element_type_568 = None | |
| + convert_element_type_569 = torch.ops.prims.convert_element_type.default(primals_352, torch.bfloat16); primals_352 = None | |
| + permute_210 = torch.ops.aten.permute.default(convert_element_type_569, [1, 0]); convert_element_type_569 = None | |
| permute_734 = torch.ops.aten.permute.default(permute_210, [1, 0]); permute_210 = None | |
| - mm_288 = torch.ops.aten.mm.default(convert_element_type_1185, permute_734); convert_element_type_1185 = permute_734 = None | |
| - convert_element_type_1191 = torch.ops.prims.convert_element_type.default(mm_287, torch.float32); mm_287 = None | |
| - slice_scatter_38 = torch.ops.aten.slice_scatter.default(full_default_324, convert_element_type_1186, 1, 0, 1024); full_default_324 = convert_element_type_1186 = None | |
| - add_427 = torch.ops.aten.add.Tensor(slice_scatter_37, slice_scatter_38); slice_scatter_37 = slice_scatter_38 = None | |
| - permute_736 = torch.ops.aten.permute.default(add_427, [1, 0]) | |
| + mm_288 = torch.ops.aten.mm.default(convert_element_type_1267, permute_734); convert_element_type_1267 = permute_734 = None | |
| + convert_element_type_1274 = torch.ops.prims.convert_element_type.default(mm_287, torch.float32); mm_287 = None | |
| + slice_scatter_38 = torch.ops.aten.slice_scatter.default(full_default_324, convert_element_type_1269, 1, 0, 1024); full_default_324 = convert_element_type_1269 = None | |
| + add_456 = torch.ops.aten.add.Tensor(slice_scatter_37, slice_scatter_38); slice_scatter_37 = slice_scatter_38 = None | |
| + permute_736 = torch.ops.aten.permute.default(add_456, [1, 0]) | |
| slice_35 = torch.ops.aten.slice.Tensor(clone_71, 1, 0, 48); clone_71 = None | |
| view_241 = torch.ops.aten.view.default(slice_35, [4096, -1]); slice_35 = None | |
| mm_289 = torch.ops.aten.mm.default(permute_736, view_241); permute_736 = view_241 = None | |
| - convert_element_type_565 = torch.ops.prims.convert_element_type.default(primals_351, torch.bfloat16); primals_351 = None | |
| - permute_209 = torch.ops.aten.permute.default(convert_element_type_565, [1, 0]); convert_element_type_565 = None | |
| + convert_element_type_566 = torch.ops.prims.convert_element_type.default(primals_351, torch.bfloat16); primals_351 = None | |
| + permute_209 = torch.ops.aten.permute.default(convert_element_type_566, [1, 0]); convert_element_type_566 = None | |
| permute_738 = torch.ops.aten.permute.default(permute_209, [1, 0]); permute_209 = None | |
| - mm_290 = torch.ops.aten.mm.default(add_427, permute_738); add_427 = permute_738 = None | |
| - convert_element_type_1196 = torch.ops.prims.convert_element_type.default(mm_289, torch.float32); mm_289 = None | |
| - view_495 = torch.ops.aten.view.default(mm_288, [4096, 48, 112]); mm_288 = None | |
| - view_496 = torch.ops.aten.view.default(mm_290, [4096, 48, 112]); mm_290 = None | |
| + mm_290 = torch.ops.aten.mm.default(add_456, permute_738); add_456 = permute_738 = None | |
| + convert_element_type_1279 = torch.ops.prims.convert_element_type.default(mm_289, torch.float32); mm_289 = None | |
| + view_524 = torch.ops.aten.view.default(mm_288, [4096, 48, 112]); mm_288 = None | |
| + view_525 = torch.ops.aten.view.default(mm_290, [4096, 48, 112]); mm_290 = None | |
| index_put_39 = torch.ops.aten.index_put.default(full_default_284, [sub_32], slice_94, True) | |
| full_default_327 = torch.ops.aten.full.default([4096, 192, 112], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| slice_scatter_39 = torch.ops.aten.slice_scatter.default(full_default_327, index_put_39, 1, 96, 9223372036854775807); index_put_39 = None | |
| permute_740 = torch.ops.aten.permute.default(slice_94, [0, 2, 1]); slice_94 = None | |
| - view_497 = torch.ops.aten.view.default(permute_740, [458752, 96]); permute_740 = None | |
| - permute_741 = torch.ops.aten.permute.default(view_497, [1, 0]) | |
| + view_526 = torch.ops.aten.view.default(permute_740, [458752, 96]); permute_740 = None | |
| + permute_741 = torch.ops.aten.permute.default(view_526, [1, 0]) | |
| mm_291 = torch.ops.aten.mm.default(permute_741, view_231); permute_741 = None | |
| - convert_element_type_561 = torch.ops.prims.convert_element_type.default(primals_350, torch.bfloat16); primals_350 = None | |
| - permute_207 = torch.ops.aten.permute.default(convert_element_type_561, [1, 0]); convert_element_type_561 = None | |
| + convert_element_type_563 = torch.ops.prims.convert_element_type.default(primals_350, torch.bfloat16); primals_350 = None | |
| + permute_207 = torch.ops.aten.permute.default(convert_element_type_563, [1, 0]); convert_element_type_563 = None | |
| permute_743 = torch.ops.aten.permute.default(permute_207, [1, 0]); permute_207 = None | |
| - mm_292 = torch.ops.aten.mm.default(view_497, permute_743); view_497 = permute_743 = None | |
| - view_498 = torch.ops.aten.view.default(mm_292, [4096, 112, 256]); mm_292 = None | |
| - permute_745 = torch.ops.aten.permute.default(view_498, [0, 2, 1]); view_498 = None | |
| - convert_element_type_1201 = torch.ops.prims.convert_element_type.default(permute_745, torch.float32); permute_745 = None | |
| - add_428 = torch.ops.aten.add.Tensor(slice_102, convert_element_type_1201); slice_102 = convert_element_type_1201 = None | |
| - convert_element_type_1202 = torch.ops.prims.convert_element_type.default(mm_291, torch.float32); mm_291 = None | |
| + mm_292 = torch.ops.aten.mm.default(view_526, permute_743); view_526 = permute_743 = None | |
| + view_527 = torch.ops.aten.view.default(mm_292, [4096, 112, 256]); mm_292 = None | |
| + permute_745 = torch.ops.aten.permute.default(view_527, [0, 2, 1]); view_527 = None | |
| + add_457 = torch.ops.aten.add.Tensor(slice_102, permute_745); slice_102 = permute_745 = None | |
| + convert_element_type_1284 = torch.ops.prims.convert_element_type.default(mm_291, torch.float32); mm_291 = None | |
| slice_scatter_40 = torch.ops.aten.slice_scatter.default(full_default_327, slice_96, 1, 0, 96); full_default_327 = slice_96 = None | |
| - add_429 = torch.ops.aten.add.Tensor(slice_scatter_39, slice_scatter_40); slice_scatter_39 = slice_scatter_40 = None | |
| - permute_746 = torch.ops.aten.permute.default(add_429, [0, 2, 1]); add_429 = None | |
| + add_458 = torch.ops.aten.add.Tensor(slice_scatter_39, slice_scatter_40); slice_scatter_39 = slice_scatter_40 = None | |
| + permute_746 = torch.ops.aten.permute.default(add_458, [0, 2, 1]); add_458 = None | |
| clone_97 = torch.ops.aten.clone.default(permute_746, memory_format = torch.contiguous_format); permute_746 = None | |
| - view_499 = torch.ops.aten.view.default(clone_97, [458752, 192]); clone_97 = None | |
| - permute_747 = torch.ops.aten.permute.default(view_499, [1, 0]) | |
| + view_528 = torch.ops.aten.view.default(clone_97, [458752, 192]); clone_97 = None | |
| + permute_747 = torch.ops.aten.permute.default(view_528, [1, 0]) | |
| mm_293 = torch.ops.aten.mm.default(permute_747, view_229); permute_747 = None | |
| - convert_element_type_557 = torch.ops.prims.convert_element_type.default(primals_349, torch.bfloat16); primals_349 = None | |
| - permute_204 = torch.ops.aten.permute.default(convert_element_type_557, [1, 0]); convert_element_type_557 = None | |
| permute_749 = torch.ops.aten.permute.default(permute_204, [1, 0]); permute_204 = None | |
| - mm_294 = torch.ops.aten.mm.default(view_499, permute_749); view_499 = permute_749 = None | |
| - view_500 = torch.ops.aten.view.default(mm_294, [4096, 112, 256]); mm_294 = None | |
| - permute_751 = torch.ops.aten.permute.default(view_500, [0, 2, 1]); view_500 = None | |
| - convert_element_type_1207 = torch.ops.prims.convert_element_type.default(permute_751, torch.float32); permute_751 = None | |
| - add_430 = torch.ops.aten.add.Tensor(add_420, convert_element_type_1207); add_420 = convert_element_type_1207 = None | |
| - convert_element_type_1208 = torch.ops.prims.convert_element_type.default(mm_293, torch.float32); mm_293 = None | |
| - index_put_40 = torch.ops.aten.index_put.default(full_default_283, [sub_32], view_495, True); full_default_283 = None | |
| + mm_294 = torch.ops.aten.mm.default(view_528, permute_749); view_528 = permute_749 = None | |
| + view_529 = torch.ops.aten.view.default(mm_294, [4096, 112, 256]); mm_294 = None | |
| + permute_751 = torch.ops.aten.permute.default(view_529, [0, 2, 1]); view_529 = None | |
| + add_459 = torch.ops.aten.add.Tensor(add_445, permute_751); add_445 = permute_751 = None | |
| + convert_element_type_1289 = torch.ops.prims.convert_element_type.default(mm_293, torch.float32); mm_293 = None | |
| + index_put_40 = torch.ops.aten.index_put.default(full_default_283, [sub_32], view_524, True); full_default_283 = None | |
| slice_scatter_41 = torch.ops.aten.slice_scatter.default(full_default_284, index_put_40, 1, 48, 9223372036854775807); index_put_40 = None | |
| - permute_752 = torch.ops.aten.permute.default(view_495, [0, 2, 1]); view_495 = None | |
| + permute_752 = torch.ops.aten.permute.default(view_524, [0, 2, 1]); view_524 = None | |
| clone_98 = torch.ops.aten.clone.default(permute_752, memory_format = torch.contiguous_format); permute_752 = None | |
| - view_501 = torch.ops.aten.view.default(clone_98, [458752, 48]); clone_98 = None | |
| - permute_753 = torch.ops.aten.permute.default(view_501, [1, 0]) | |
| + view_530 = torch.ops.aten.view.default(clone_98, [458752, 48]); clone_98 = None | |
| + permute_753 = torch.ops.aten.permute.default(view_530, [1, 0]) | |
| mm_295 = torch.ops.aten.mm.default(permute_753, view_231); permute_753 = None | |
| permute_755 = torch.ops.aten.permute.default(permute_201, [1, 0]); permute_201 = None | |
| - mm_296 = torch.ops.aten.mm.default(view_501, permute_755); view_501 = permute_755 = None | |
| - view_502 = torch.ops.aten.view.default(mm_296, [4096, 112, 256]); mm_296 = None | |
| - permute_757 = torch.ops.aten.permute.default(view_502, [0, 2, 1]); view_502 = None | |
| - convert_element_type_1213 = torch.ops.prims.convert_element_type.default(permute_757, torch.float32); permute_757 = None | |
| - add_431 = torch.ops.aten.add.Tensor(add_428, convert_element_type_1213); add_428 = convert_element_type_1213 = None | |
| - convert_element_type_1214 = torch.ops.prims.convert_element_type.default(mm_295, torch.float32); mm_295 = None | |
| - slice_scatter_42 = torch.ops.aten.slice_scatter.default(full_default_284, view_496, 1, 0, 48); full_default_284 = view_496 = None | |
| - add_432 = torch.ops.aten.add.Tensor(slice_scatter_41, slice_scatter_42); slice_scatter_41 = slice_scatter_42 = None | |
| - permute_758 = torch.ops.aten.permute.default(add_432, [0, 2, 1]); add_432 = None | |
| + mm_296 = torch.ops.aten.mm.default(view_530, permute_755); view_530 = permute_755 = None | |
| + view_531 = torch.ops.aten.view.default(mm_296, [4096, 112, 256]); mm_296 = None | |
| + permute_757 = torch.ops.aten.permute.default(view_531, [0, 2, 1]); view_531 = None | |
| + add_460 = torch.ops.aten.add.Tensor(add_457, permute_757); add_457 = permute_757 = None | |
| + convert_element_type_1294 = torch.ops.prims.convert_element_type.default(mm_295, torch.float32); mm_295 = None | |
| + slice_scatter_42 = torch.ops.aten.slice_scatter.default(full_default_284, view_525, 1, 0, 48); full_default_284 = view_525 = None | |
| + add_461 = torch.ops.aten.add.Tensor(slice_scatter_41, slice_scatter_42); slice_scatter_41 = slice_scatter_42 = None | |
| + permute_758 = torch.ops.aten.permute.default(add_461, [0, 2, 1]); add_461 = None | |
| clone_99 = torch.ops.aten.clone.default(permute_758, memory_format = torch.contiguous_format); permute_758 = None | |
| - view_503 = torch.ops.aten.view.default(clone_99, [458752, 96]); clone_99 = None | |
| - permute_759 = torch.ops.aten.permute.default(view_503, [1, 0]) | |
| + view_532 = torch.ops.aten.view.default(clone_99, [458752, 96]); clone_99 = None | |
| + permute_759 = torch.ops.aten.permute.default(view_532, [1, 0]) | |
| mm_297 = torch.ops.aten.mm.default(permute_759, view_229); permute_759 = None | |
| permute_761 = torch.ops.aten.permute.default(permute_198, [1, 0]); permute_198 = None | |
| - mm_298 = torch.ops.aten.mm.default(view_503, permute_761); view_503 = permute_761 = None | |
| - view_504 = torch.ops.aten.view.default(mm_298, [4096, 112, 256]); mm_298 = None | |
| - permute_763 = torch.ops.aten.permute.default(view_504, [0, 2, 1]); view_504 = None | |
| - convert_element_type_1219 = torch.ops.prims.convert_element_type.default(permute_763, torch.float32); permute_763 = None | |
| - add_433 = torch.ops.aten.add.Tensor(add_430, convert_element_type_1219); add_430 = convert_element_type_1219 = None | |
| - convert_element_type_1220 = torch.ops.prims.convert_element_type.default(mm_297, torch.float32); mm_297 = None | |
| + mm_298 = torch.ops.aten.mm.default(view_532, permute_761); view_532 = permute_761 = None | |
| + view_533 = torch.ops.aten.view.default(mm_298, [4096, 112, 256]); mm_298 = None | |
| + permute_763 = torch.ops.aten.permute.default(view_533, [0, 2, 1]); view_533 = None | |
| + add_462 = torch.ops.aten.add.Tensor(add_459, permute_763); add_459 = permute_763 = None | |
| + convert_element_type_1299 = torch.ops.prims.convert_element_type.default(mm_297, torch.float32); mm_297 = None | |
| mm_299 = torch.ops.aten.mm.default(permute_639, view_231); permute_639 = view_231 = None | |
| - convert_element_type_545 = torch.ops.prims.convert_element_type.default(primals_346, torch.bfloat16); primals_346 = None | |
| - permute_195 = torch.ops.aten.permute.default(convert_element_type_545, [1, 0]); convert_element_type_545 = None | |
| + convert_element_type_551 = torch.ops.prims.convert_element_type.default(primals_346, torch.bfloat16); primals_346 = None | |
| + permute_195 = torch.ops.aten.permute.default(convert_element_type_551, [1, 0]); convert_element_type_551 = None | |
| permute_767 = torch.ops.aten.permute.default(permute_195, [1, 0]); permute_195 = None | |
| - mm_300 = torch.ops.aten.mm.default(view_453, permute_767); view_453 = permute_767 = None | |
| - view_506 = torch.ops.aten.view.default(mm_300, [4096, 112, 256]); mm_300 = None | |
| - permute_769 = torch.ops.aten.permute.default(view_506, [0, 2, 1]); view_506 = None | |
| - convert_element_type_1225 = torch.ops.prims.convert_element_type.default(permute_769, torch.float32); permute_769 = None | |
| - add_434 = torch.ops.aten.add.Tensor(add_431, convert_element_type_1225); add_431 = convert_element_type_1225 = None | |
| - convert_element_type_1226 = torch.ops.prims.convert_element_type.default(mm_299, torch.float32); mm_299 = None | |
| + mm_300 = torch.ops.aten.mm.default(view_468, permute_767); view_468 = permute_767 = None | |
| + view_535 = torch.ops.aten.view.default(mm_300, [4096, 112, 256]); mm_300 = None | |
| + permute_769 = torch.ops.aten.permute.default(view_535, [0, 2, 1]); view_535 = None | |
| + add_463 = torch.ops.aten.add.Tensor(add_460, permute_769); add_460 = permute_769 = None | |
| + convert_element_type_1304 = torch.ops.prims.convert_element_type.default(mm_299, torch.float32); mm_299 = None | |
| mm_301 = torch.ops.aten.mm.default(permute_645, view_229); permute_645 = view_229 = None | |
| - convert_element_type_541 = torch.ops.prims.convert_element_type.default(primals_345, torch.bfloat16); primals_345 = None | |
| - permute_192 = torch.ops.aten.permute.default(convert_element_type_541, [1, 0]); convert_element_type_541 = None | |
| permute_773 = torch.ops.aten.permute.default(permute_192, [1, 0]); permute_192 = None | |
| - mm_302 = torch.ops.aten.mm.default(view_455, permute_773); view_455 = permute_773 = None | |
| - view_508 = torch.ops.aten.view.default(mm_302, [4096, 112, 256]); mm_302 = None | |
| - permute_775 = torch.ops.aten.permute.default(view_508, [0, 2, 1]); view_508 = None | |
| - convert_element_type_1231 = torch.ops.prims.convert_element_type.default(permute_775, torch.float32); permute_775 = None | |
| - add_436 = torch.ops.aten.add.Tensor(add_433, convert_element_type_1231); add_433 = convert_element_type_1231 = None | |
| - convert_element_type_1232 = torch.ops.prims.convert_element_type.default(mm_301, torch.float32); mm_301 = None | |
| - mul_1004 = torch.ops.aten.mul.Tensor(add_434, primals_344); primals_344 = None | |
| - mul_1006 = torch.ops.aten.mul.Tensor(mul_182, mul_1004) | |
| - sum_162 = torch.ops.aten.sum.dim_IntList(mul_1006, [2], True); mul_1006 = None | |
| - div_121 = torch.ops.aten.div.Tensor(mul_182, 112) | |
| - mul_1007 = torch.ops.aten.mul.Tensor(div_121, sum_162); div_121 = sum_162 = None | |
| - sub_250 = torch.ops.aten.sub.Tensor(mul_1004, mul_1007); mul_1004 = mul_1007 = None | |
| - mul_1008 = torch.ops.aten.mul.Tensor(sub_250, rsqrt_61); sub_250 = rsqrt_61 = None | |
| - mul_1009 = torch.ops.aten.mul.Tensor(add_434, mul_182); add_434 = mul_182 = None | |
| - sum_163 = torch.ops.aten.sum.dim_IntList(mul_1009, [0, 1]); mul_1009 = None | |
| - convert_element_type_1233 = torch.ops.prims.convert_element_type.default(mul_1008, torch.bfloat16); mul_1008 = None | |
| - mul_1010 = torch.ops.aten.mul.Tensor(add_436, primals_343); primals_343 = None | |
| - mul_1012 = torch.ops.aten.mul.Tensor(mul_180, mul_1010) | |
| - sum_164 = torch.ops.aten.sum.dim_IntList(mul_1012, [2], True); mul_1012 = None | |
| - div_122 = torch.ops.aten.div.Tensor(mul_180, 112) | |
| - mul_1013 = torch.ops.aten.mul.Tensor(div_122, sum_164); div_122 = sum_164 = None | |
| - sub_251 = torch.ops.aten.sub.Tensor(mul_1010, mul_1013); mul_1010 = mul_1013 = None | |
| - mul_1014 = torch.ops.aten.mul.Tensor(sub_251, rsqrt_60); sub_251 = rsqrt_60 = None | |
| - mul_1015 = torch.ops.aten.mul.Tensor(add_436, mul_180); add_436 = mul_180 = None | |
| - sum_165 = torch.ops.aten.sum.dim_IntList(mul_1015, [0, 1]); mul_1015 = None | |
| - convert_element_type_1234 = torch.ops.prims.convert_element_type.default(mul_1014, torch.bfloat16); mul_1014 = None | |
| - slice_103 = torch.ops.aten.slice.Tensor(convert_element_type_1233, 1, 0, 128) | |
| - slice_104 = torch.ops.aten.slice.Tensor(convert_element_type_1233, 1, 128, 256); convert_element_type_1233 = None | |
| - slice_105 = torch.ops.aten.slice.Tensor(convert_element_type_1234, 1, 0, 128) | |
| - slice_106 = torch.ops.aten.slice.Tensor(convert_element_type_1234, 1, 128, 256); convert_element_type_1234 = None | |
| + mm_302 = torch.ops.aten.mm.default(view_470, permute_773); view_470 = permute_773 = None | |
| + view_537 = torch.ops.aten.view.default(mm_302, [4096, 112, 256]); mm_302 = None | |
| + permute_775 = torch.ops.aten.permute.default(view_537, [0, 2, 1]); view_537 = None | |
| + add_465 = torch.ops.aten.add.Tensor(add_462, permute_775); add_462 = permute_775 = None | |
| + convert_element_type_1309 = torch.ops.prims.convert_element_type.default(mm_301, torch.float32); mm_301 = None | |
| + convert_element_type_1310 = torch.ops.prims.convert_element_type.default(add_463, torch.float32); add_463 = None | |
| + convert_element_type_default_39 = torch.ops.prims.convert_element_type.default(fp8_quant_pos_719_cat_17, torch.bfloat16); fp8_quant_pos_719_cat_17 = None | |
| + div_tensor_16 = torch.ops.aten.div.Tensor(convert_element_type_default_39, fp8_scale_pos_719_cat_17); convert_element_type_default_39 = fp8_scale_pos_719_cat_17 = None | |
| + convert_element_type_default_40 = torch.ops.prims.convert_element_type.default(div_tensor_16, torch.bfloat16); div_tensor_16 = None | |
| + convert_element_type_546 = torch.ops.prims.convert_element_type.default(convert_element_type_default_40, torch.float32); convert_element_type_default_40 = None | |
| + pow_41 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_546, 2) | |
| + mean_40 = torch.ops.aten.mean.dim(pow_41, [2], True); pow_41 = None | |
| + add_141 = torch.ops.aten.add.Scalar(mean_40, 1.1920928955078125e-07); mean_40 = None | |
| + rsqrt_61 = torch.ops.aten.rsqrt.default(add_141); add_141 = None | |
| + mul_182 = torch.ops.aten.mul.Tensor(convert_element_type_546, rsqrt_61) | |
| + mul_1062 = torch.ops.aten.mul.Tensor(convert_element_type_1310, mul_182); mul_182 = None | |
| + mul_1063 = torch.ops.aten.mul.Tensor(convert_element_type_1310, primals_344); convert_element_type_1310 = primals_344 = None | |
| + sum_162 = torch.ops.aten.sum.dim_IntList(mul_1062, [0, 1], True); mul_1062 = None | |
| + view_538 = torch.ops.aten.view.default(sum_162, [112]); sum_162 = None | |
| + mul_1064 = torch.ops.aten.mul.Tensor(mul_1063, convert_element_type_546) | |
| + mul_1065 = torch.ops.aten.mul.Tensor(mul_1063, rsqrt_61); mul_1063 = None | |
| + sum_163 = torch.ops.aten.sum.dim_IntList(mul_1064, [2], True); mul_1064 = None | |
| + mul_1066 = torch.ops.aten.mul.Scalar(sum_163, -0.5); sum_163 = None | |
| + pow_142 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_61, 3); rsqrt_61 = None | |
| + mul_1067 = torch.ops.aten.mul.Tensor(mul_1066, pow_142); mul_1066 = pow_142 = None | |
| + expand_92 = torch.ops.aten.expand.default(mul_1067, [4096, 256, 112]); mul_1067 = None | |
| + div_121 = torch.ops.aten.div.Scalar(expand_92, 112); expand_92 = None | |
| + pow_143 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_546, 1.0); convert_element_type_546 = None | |
| + mul_1068 = torch.ops.aten.mul.Scalar(pow_143, 2.0); pow_143 = None | |
| + mul_1069 = torch.ops.aten.mul.Tensor(div_121, mul_1068); div_121 = mul_1068 = None | |
| + add_466 = torch.ops.aten.add.Tensor(mul_1065, mul_1069); mul_1065 = mul_1069 = None | |
| + convert_element_type_1311 = torch.ops.prims.convert_element_type.default(add_466, torch.bfloat16); add_466 = None | |
| + convert_element_type_1312 = torch.ops.prims.convert_element_type.default(add_465, torch.float32); add_465 = None | |
| + convert_element_type_default_37 = torch.ops.prims.convert_element_type.default(fp8_quant_pos_718_cat_16, torch.bfloat16); fp8_quant_pos_718_cat_16 = None | |
| + div_tensor_15 = torch.ops.aten.div.Tensor(convert_element_type_default_37, fp8_scale_pos_718_cat_16); convert_element_type_default_37 = fp8_scale_pos_718_cat_16 = None | |
| + convert_element_type_default_38 = torch.ops.prims.convert_element_type.default(div_tensor_15, torch.bfloat16); div_tensor_15 = None | |
| + convert_element_type_544 = torch.ops.prims.convert_element_type.default(convert_element_type_default_38, torch.float32); convert_element_type_default_38 = None | |
| + pow_40 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_544, 2) | |
| + mean_39 = torch.ops.aten.mean.dim(pow_40, [2], True); pow_40 = None | |
| + add_140 = torch.ops.aten.add.Scalar(mean_39, 1.1920928955078125e-07); mean_39 = None | |
| + rsqrt_60 = torch.ops.aten.rsqrt.default(add_140); add_140 = None | |
| + mul_180 = torch.ops.aten.mul.Tensor(convert_element_type_544, rsqrt_60) | |
| + mul_1070 = torch.ops.aten.mul.Tensor(convert_element_type_1312, mul_180); mul_180 = None | |
| + mul_1071 = torch.ops.aten.mul.Tensor(convert_element_type_1312, primals_343); convert_element_type_1312 = primals_343 = None | |
| + sum_164 = torch.ops.aten.sum.dim_IntList(mul_1070, [0, 1], True); mul_1070 = None | |
| + view_539 = torch.ops.aten.view.default(sum_164, [112]); sum_164 = None | |
| + mul_1072 = torch.ops.aten.mul.Tensor(mul_1071, convert_element_type_544) | |
| + mul_1073 = torch.ops.aten.mul.Tensor(mul_1071, rsqrt_60); mul_1071 = None | |
| + sum_165 = torch.ops.aten.sum.dim_IntList(mul_1072, [2], True); mul_1072 = None | |
| + mul_1074 = torch.ops.aten.mul.Scalar(sum_165, -0.5); sum_165 = None | |
| + pow_144 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_60, 3); rsqrt_60 = None | |
| + mul_1075 = torch.ops.aten.mul.Tensor(mul_1074, pow_144); mul_1074 = pow_144 = None | |
| + expand_93 = torch.ops.aten.expand.default(mul_1075, [4096, 256, 112]); mul_1075 = None | |
| + div_122 = torch.ops.aten.div.Scalar(expand_93, 112); expand_93 = None | |
| + pow_145 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_544, 1.0); convert_element_type_544 = None | |
| + mul_1076 = torch.ops.aten.mul.Scalar(pow_145, 2.0); pow_145 = None | |
| + mul_1077 = torch.ops.aten.mul.Tensor(div_122, mul_1076); div_122 = mul_1076 = None | |
| + add_467 = torch.ops.aten.add.Tensor(mul_1073, mul_1077); mul_1073 = mul_1077 = None | |
| + convert_element_type_1313 = torch.ops.prims.convert_element_type.default(add_467, torch.bfloat16); add_467 = None | |
| + slice_103 = torch.ops.aten.slice.Tensor(convert_element_type_1311, 1, 0, 128) | |
| + slice_104 = torch.ops.aten.slice.Tensor(convert_element_type_1311, 1, 128, 256); convert_element_type_1311 = None | |
| + slice_105 = torch.ops.aten.slice.Tensor(convert_element_type_1313, 1, 0, 128) | |
| + slice_106 = torch.ops.aten.slice.Tensor(convert_element_type_1313, 1, 128, 256); convert_element_type_1313 = None | |
| index_put_42 = torch.ops.aten.index_put.default(full_default_286, [sub_32], slice_103, True) | |
| slice_scatter_45 = torch.ops.aten.slice_scatter.default(full_default_287, index_put_42, 1, 128, 9223372036854775807); index_put_42 = None | |
| permute_776 = torch.ops.aten.permute.default(slice_103, [0, 2, 1]); slice_103 = None | |
| - view_509 = torch.ops.aten.view.default(permute_776, [458752, 128]); permute_776 = None | |
| - permute_777 = torch.ops.aten.permute.default(view_509, [1, 0]) | |
| + view_540 = torch.ops.aten.view.default(permute_776, [458752, 128]); permute_776 = None | |
| + permute_777 = torch.ops.aten.permute.default(view_540, [1, 0]) | |
| mm_303 = torch.ops.aten.mm.default(permute_777, view_227); view_227 = None | |
| - convert_element_type_536 = torch.ops.prims.convert_element_type.default(primals_342, torch.bfloat16); primals_342 = None | |
| - permute_189 = torch.ops.aten.permute.default(convert_element_type_536, [1, 0]); convert_element_type_536 = None | |
| + convert_element_type_541 = torch.ops.prims.convert_element_type.default(primals_342, torch.bfloat16); primals_342 = None | |
| + permute_189 = torch.ops.aten.permute.default(convert_element_type_541, [1, 0]); convert_element_type_541 = None | |
| permute_779 = torch.ops.aten.permute.default(permute_189, [1, 0]); permute_189 = None | |
| - mm_304 = torch.ops.aten.mm.default(view_509, permute_779); permute_779 = None | |
| - view_510 = torch.ops.aten.view.default(mm_304, [4096, 112, 64]); mm_304 = None | |
| - permute_781 = torch.ops.aten.permute.default(view_510, [0, 2, 1]); view_510 = None | |
| - convert_element_type_1239 = torch.ops.prims.convert_element_type.default(mm_303, torch.float32); mm_303 = None | |
| + mm_304 = torch.ops.aten.mm.default(view_540, permute_779); permute_779 = None | |
| + view_541 = torch.ops.aten.view.default(mm_304, [4096, 112, 64]); mm_304 = None | |
| + permute_781 = torch.ops.aten.permute.default(view_541, [0, 2, 1]); view_541 = None | |
| + convert_element_type_1318 = torch.ops.prims.convert_element_type.default(mm_303, torch.float32); mm_303 = None | |
| slice_scatter_46 = torch.ops.aten.slice_scatter.default(full_default_287, slice_105, 1, 0, 128); slice_105 = None | |
| - add_437 = torch.ops.aten.add.Tensor(slice_scatter_45, slice_scatter_46); slice_scatter_45 = slice_scatter_46 = None | |
| - permute_782 = torch.ops.aten.permute.default(add_437, [0, 2, 1]); add_437 = None | |
| + add_468 = torch.ops.aten.add.Tensor(slice_scatter_45, slice_scatter_46); slice_scatter_45 = slice_scatter_46 = None | |
| + permute_782 = torch.ops.aten.permute.default(add_468, [0, 2, 1]); add_468 = None | |
| clone_101 = torch.ops.aten.clone.default(permute_782, memory_format = torch.contiguous_format); permute_782 = None | |
| - view_511 = torch.ops.aten.view.default(clone_101, [458752, 256]); clone_101 = None | |
| - permute_783 = torch.ops.aten.permute.default(view_511, [1, 0]) | |
| + view_542 = torch.ops.aten.view.default(clone_101, [458752, 256]); clone_101 = None | |
| + permute_783 = torch.ops.aten.permute.default(view_542, [1, 0]) | |
| mm_305 = torch.ops.aten.mm.default(permute_783, view_225); view_225 = None | |
| - convert_element_type_533 = torch.ops.prims.convert_element_type.default(primals_341, torch.bfloat16); primals_341 = None | |
| - permute_186 = torch.ops.aten.permute.default(convert_element_type_533, [1, 0]); convert_element_type_533 = None | |
| + convert_element_type_538 = torch.ops.prims.convert_element_type.default(primals_341, torch.bfloat16); primals_341 = None | |
| + permute_186 = torch.ops.aten.permute.default(convert_element_type_538, [1, 0]); convert_element_type_538 = None | |
| permute_785 = torch.ops.aten.permute.default(permute_186, [1, 0]); permute_186 = None | |
| - mm_306 = torch.ops.aten.mm.default(view_511, permute_785); permute_785 = None | |
| - view_512 = torch.ops.aten.view.default(mm_306, [4096, 112, 64]); mm_306 = None | |
| - permute_787 = torch.ops.aten.permute.default(view_512, [0, 2, 1]); view_512 = None | |
| - convert_element_type_1244 = torch.ops.prims.convert_element_type.default(mm_305, torch.float32); mm_305 = None | |
| + mm_306 = torch.ops.aten.mm.default(view_542, permute_785); permute_785 = None | |
| + view_543 = torch.ops.aten.view.default(mm_306, [4096, 112, 64]); mm_306 = None | |
| + permute_787 = torch.ops.aten.permute.default(view_543, [0, 2, 1]); view_543 = None | |
| + convert_element_type_1323 = torch.ops.prims.convert_element_type.default(mm_305, torch.float32); mm_305 = None | |
| clone_102 = torch.ops.aten.clone.default(permute_781, memory_format = torch.contiguous_format); permute_781 = None | |
| - view_513 = torch.ops.aten.view.default(clone_102, [4096, 7168]); clone_102 = None | |
| + view_544 = torch.ops.aten.view.default(clone_102, [4096, 7168]); clone_102 = None | |
| clone_103 = torch.ops.aten.clone.default(permute_787, memory_format = torch.contiguous_format); permute_787 = None | |
| - view_514 = torch.ops.aten.view.default(clone_103, [4096, 7168]); clone_103 = None | |
| - index_put_43 = torch.ops.aten.index_put.default(full_default_289, [sub_32], view_513, True); full_default_289 = None | |
| + view_545 = torch.ops.aten.view.default(clone_103, [4096, 7168]); clone_103 = None | |
| + index_put_43 = torch.ops.aten.index_put.default(full_default_289, [sub_32], view_544, True); full_default_289 = None | |
| slice_scatter_47 = torch.ops.aten.slice_scatter.default(full_default_290, index_put_43, 1, 7168, 9223372036854775807); index_put_43 = None | |
| - abs_60 = torch.ops.aten.abs.default(view_513) | |
| + abs_60 = torch.ops.aten.abs.default(view_544) | |
| amax_27 = torch.ops.aten.amax.default(abs_60, [-1], True); abs_60 = None | |
| - convert_element_type_1245 = torch.ops.prims.convert_element_type.default(amax_27, torch.float64); amax_27 = None | |
| - clamp_min_98 = torch.ops.aten.clamp_min.default(convert_element_type_1245, 1e-12); convert_element_type_1245 = None | |
| + convert_element_type_1324 = torch.ops.prims.convert_element_type.default(amax_27, torch.float64); amax_27 = None | |
| + clamp_min_98 = torch.ops.aten.clamp_min.default(convert_element_type_1324, 1e-12); convert_element_type_1324 = None | |
| reciprocal_64 = torch.ops.aten.reciprocal.default(clamp_min_98); clamp_min_98 = None | |
| - mul_1016 = torch.ops.aten.mul.Tensor(reciprocal_64, 448.0); reciprocal_64 = None | |
| - convert_element_type_1246 = torch.ops.prims.convert_element_type.default(mul_1016, torch.float32); mul_1016 = None | |
| - log2_32 = torch.ops.aten.log2.default(convert_element_type_1246); convert_element_type_1246 = None | |
| + mul_1078 = torch.ops.aten.mul.Tensor(reciprocal_64, 448.0); reciprocal_64 = None | |
| + convert_element_type_1325 = torch.ops.prims.convert_element_type.default(mul_1078, torch.float32); mul_1078 = None | |
| + log2_32 = torch.ops.aten.log2.default(convert_element_type_1325); convert_element_type_1325 = None | |
| floor_32 = torch.ops.aten.floor.default(log2_32); log2_32 = None | |
| exp2_32 = torch.ops.aten.exp2.default(floor_32); floor_32 = None | |
| - convert_element_type_1247 = torch.ops.prims.convert_element_type.default(view_513, torch.float32) | |
| - mul_1017 = torch.ops.aten.mul.Tensor(convert_element_type_1247, exp2_32); convert_element_type_1247 = None | |
| - clamp_min_99 = torch.ops.aten.clamp_min.default(mul_1017, -448.0); mul_1017 = None | |
| + convert_element_type_1326 = torch.ops.prims.convert_element_type.default(view_544, torch.float32) | |
| + mul_1079 = torch.ops.aten.mul.Tensor(convert_element_type_1326, exp2_32); convert_element_type_1326 = None | |
| + clamp_min_99 = torch.ops.aten.clamp_min.default(mul_1079, -448.0); mul_1079 = None | |
| clamp_max_62 = torch.ops.aten.clamp_max.default(clamp_min_99, 448.0); clamp_min_99 = None | |
| - convert_element_type_1248 = torch.ops.prims.convert_element_type.default(clamp_max_62, torch.float8_e4m3fn); clamp_max_62 = None | |
| + convert_element_type_1327 = torch.ops.prims.convert_element_type.default(clamp_max_62, torch.float8_e4m3fn); clamp_max_62 = None | |
| permute_184 = torch.ops.aten.permute.default(primals_340, [1, 0]); primals_340 = None | |
| abs_8 = torch.ops.aten.abs.default(permute_184) | |
| max_7 = torch.ops.aten.max.default(abs_8); abs_8 = None | |
| - convert_element_type_1249 = torch.ops.prims.convert_element_type.default(max_7, torch.float64); max_7 = None | |
| - clamp_min_100 = torch.ops.aten.clamp_min.default(convert_element_type_1249, 1e-12); convert_element_type_1249 = None | |
| + convert_element_type_1328 = torch.ops.prims.convert_element_type.default(max_7, torch.float64); max_7 = None | |
| + clamp_min_100 = torch.ops.aten.clamp_min.default(convert_element_type_1328, 1e-12); convert_element_type_1328 = None | |
| reciprocal_65 = torch.ops.aten.reciprocal.default(clamp_min_100); clamp_min_100 = None | |
| - mul_1018 = torch.ops.aten.mul.Tensor(reciprocal_65, 448.0); reciprocal_65 = None | |
| - convert_element_type_1250 = torch.ops.prims.convert_element_type.default(mul_1018, torch.float32); mul_1018 = None | |
| - log2_33 = torch.ops.aten.log2.default(convert_element_type_1250); convert_element_type_1250 = None | |
| + mul_1080 = torch.ops.aten.mul.Tensor(reciprocal_65, 448.0); reciprocal_65 = None | |
| + convert_element_type_1329 = torch.ops.prims.convert_element_type.default(mul_1080, torch.float32); mul_1080 = None | |
| + log2_33 = torch.ops.aten.log2.default(convert_element_type_1329); convert_element_type_1329 = None | |
| floor_33 = torch.ops.aten.floor.default(log2_33); log2_33 = None | |
| exp2_33 = torch.ops.aten.exp2.default(floor_33); floor_33 = None | |
| - mul_1019 = torch.ops.aten.mul.Tensor(permute_184, exp2_33); permute_184 = None | |
| - clamp_min_101 = torch.ops.aten.clamp_min.default(mul_1019, -448.0); mul_1019 = None | |
| + mul_1081 = torch.ops.aten.mul.Tensor(permute_184, exp2_33); permute_184 = None | |
| + clamp_min_101 = torch.ops.aten.clamp_min.default(mul_1081, -448.0); mul_1081 = None | |
| clamp_max_63 = torch.ops.aten.clamp_max.default(clamp_min_101, 448.0); clamp_min_101 = None | |
| - convert_element_type_1251 = torch.ops.prims.convert_element_type.default(clamp_max_63, torch.float8_e4m3fn); clamp_max_63 = None | |
| - clone_104 = torch.ops.aten.clone.default(convert_element_type_1251, memory_format = torch.contiguous_format); convert_element_type_1251 = None | |
| + convert_element_type_1330 = torch.ops.prims.convert_element_type.default(clamp_max_63, torch.float8_e4m3fn); clamp_max_63 = None | |
| + clone_104 = torch.ops.aten.clone.default(convert_element_type_1330, memory_format = torch.contiguous_format); convert_element_type_1330 = None | |
| permute_790 = torch.ops.aten.permute.default(clone_104, [1, 0]); clone_104 = None | |
| repeat_10 = torch.ops.aten.repeat.default(exp2_33, [4608]); exp2_33 = None | |
| - view_516 = torch.ops.aten.view.default(repeat_10, [1, -1]); repeat_10 = None | |
| + view_547 = torch.ops.aten.view.default(repeat_10, [1, -1]); repeat_10 = None | |
| reciprocal_66 = torch.ops.aten.reciprocal.default(exp2_32); exp2_32 = None | |
| - reciprocal_67 = torch.ops.aten.reciprocal.default(view_516); view_516 = None | |
| - mul_1020 = torch.ops.aten.mul.Tensor(reciprocal_66, reciprocal_67); reciprocal_66 = reciprocal_67 = None | |
| - _scaled_mm_16 = torch.ops.aten._scaled_mm.default(convert_element_type_1248, permute_790, full_default_267, full_default_267, None, None, torch.bfloat16); convert_element_type_1248 = permute_790 = None | |
| - mul_1021 = torch.ops.aten.mul.Tensor(_scaled_mm_16, mul_1020); _scaled_mm_16 = mul_1020 = None | |
| - permute_791 = torch.ops.aten.permute.default(view_513, [1, 0]); view_513 = None | |
| - convert_element_type_525 = torch.ops.prims.convert_element_type.default(mul_171, torch.bfloat16); mul_171 = None | |
| - mm_307 = torch.ops.aten.mm.default(permute_791, convert_element_type_525); permute_791 = convert_element_type_525 = None | |
| + reciprocal_67 = torch.ops.aten.reciprocal.default(view_547); view_547 = None | |
| + mul_1082 = torch.ops.aten.mul.Tensor(reciprocal_66, reciprocal_67); reciprocal_66 = reciprocal_67 = None | |
| + _scaled_mm_16 = torch.ops.aten._scaled_mm.default(convert_element_type_1327, permute_790, full_default_267, full_default_267, None, None, torch.bfloat16); convert_element_type_1327 = permute_790 = None | |
| + mul_1083 = torch.ops.aten.mul.Tensor(_scaled_mm_16, mul_1082); _scaled_mm_16 = mul_1082 = None | |
| + convert_element_type_1331 = torch.ops.prims.convert_element_type.default(mul_1083, torch.bfloat16); mul_1083 = None | |
| + permute_791 = torch.ops.aten.permute.default(view_544, [1, 0]); view_544 = None | |
| + mm_307 = torch.ops.aten.mm.default(permute_791, mul_171); permute_791 = mul_171 = None | |
| permute_792 = torch.ops.aten.permute.default(mm_307, [1, 0]); mm_307 = None | |
| - convert_element_type_1255 = torch.ops.prims.convert_element_type.default(permute_792, torch.float32); permute_792 = None | |
| - permute_793 = torch.ops.aten.permute.default(convert_element_type_1255, [1, 0]); convert_element_type_1255 = None | |
| - convert_element_type_default_5 = torch.ops.prims.convert_element_type.default(mul_1021, torch.float32); mul_1021 = None | |
| - add_438 = torch.ops.aten.add.Tensor(add_356, convert_element_type_default_5); add_356 = convert_element_type_default_5 = None | |
| - slice_scatter_48 = torch.ops.aten.slice_scatter.default(full_default_290, view_514, 1, 0, 7168); full_default_290 = view_514 = None | |
| - add_439 = torch.ops.aten.add.Tensor(slice_scatter_47, slice_scatter_48); slice_scatter_47 = slice_scatter_48 = None | |
| - abs_62 = torch.ops.aten.abs.default(add_439) | |
| + convert_element_type_1334 = torch.ops.prims.convert_element_type.default(permute_792, torch.float32); permute_792 = None | |
| + add_469 = torch.ops.aten.add.Tensor(add_356, convert_element_type_1331); add_356 = convert_element_type_1331 = None | |
| + permute_793 = torch.ops.aten.permute.default(convert_element_type_1334, [1, 0]); convert_element_type_1334 = None | |
| + slice_scatter_48 = torch.ops.aten.slice_scatter.default(full_default_290, view_545, 1, 0, 7168); full_default_290 = view_545 = None | |
| + add_470 = torch.ops.aten.add.Tensor(slice_scatter_47, slice_scatter_48); slice_scatter_47 = slice_scatter_48 = None | |
| + abs_62 = torch.ops.aten.abs.default(add_470) | |
| amax_28 = torch.ops.aten.amax.default(abs_62, [-1], True); abs_62 = None | |
| - convert_element_type_1257 = torch.ops.prims.convert_element_type.default(amax_28, torch.float64); amax_28 = None | |
| - clamp_min_102 = torch.ops.aten.clamp_min.default(convert_element_type_1257, 1e-12); convert_element_type_1257 = None | |
| + convert_element_type_1335 = torch.ops.prims.convert_element_type.default(amax_28, torch.float64); amax_28 = None | |
| + clamp_min_102 = torch.ops.aten.clamp_min.default(convert_element_type_1335, 1e-12); convert_element_type_1335 = None | |
| reciprocal_68 = torch.ops.aten.reciprocal.default(clamp_min_102); clamp_min_102 = None | |
| - mul_1022 = torch.ops.aten.mul.Tensor(reciprocal_68, 448.0); reciprocal_68 = None | |
| - convert_element_type_1258 = torch.ops.prims.convert_element_type.default(mul_1022, torch.float32); mul_1022 = None | |
| - log2_34 = torch.ops.aten.log2.default(convert_element_type_1258); convert_element_type_1258 = None | |
| + mul_1084 = torch.ops.aten.mul.Tensor(reciprocal_68, 448.0); reciprocal_68 = None | |
| + convert_element_type_1336 = torch.ops.prims.convert_element_type.default(mul_1084, torch.float32); mul_1084 = None | |
| + log2_34 = torch.ops.aten.log2.default(convert_element_type_1336); convert_element_type_1336 = None | |
| floor_34 = torch.ops.aten.floor.default(log2_34); log2_34 = None | |
| exp2_34 = torch.ops.aten.exp2.default(floor_34); floor_34 = None | |
| - convert_element_type_1259 = torch.ops.prims.convert_element_type.default(add_439, torch.float32) | |
| - mul_1023 = torch.ops.aten.mul.Tensor(convert_element_type_1259, exp2_34); convert_element_type_1259 = None | |
| - clamp_min_103 = torch.ops.aten.clamp_min.default(mul_1023, -448.0); mul_1023 = None | |
| + convert_element_type_1337 = torch.ops.prims.convert_element_type.default(add_470, torch.float32) | |
| + mul_1085 = torch.ops.aten.mul.Tensor(convert_element_type_1337, exp2_34); convert_element_type_1337 = None | |
| + clamp_min_103 = torch.ops.aten.clamp_min.default(mul_1085, -448.0); mul_1085 = None | |
| clamp_max_64 = torch.ops.aten.clamp_max.default(clamp_min_103, 448.0); clamp_min_103 = None | |
| - convert_element_type_1260 = torch.ops.prims.convert_element_type.default(clamp_max_64, torch.float8_e4m3fn); clamp_max_64 = None | |
| + convert_element_type_1338 = torch.ops.prims.convert_element_type.default(clamp_max_64, torch.float8_e4m3fn); clamp_max_64 = None | |
| permute_183 = torch.ops.aten.permute.default(primals_339, [1, 0]); primals_339 = None | |
| abs_6 = torch.ops.aten.abs.default(permute_183) | |
| max_8 = torch.ops.aten.max.default(abs_6); abs_6 = None | |
| - convert_element_type_1261 = torch.ops.prims.convert_element_type.default(max_8, torch.float64); max_8 = None | |
| - clamp_min_104 = torch.ops.aten.clamp_min.default(convert_element_type_1261, 1e-12); convert_element_type_1261 = None | |
| + convert_element_type_1339 = torch.ops.prims.convert_element_type.default(max_8, torch.float64); max_8 = None | |
| + clamp_min_104 = torch.ops.aten.clamp_min.default(convert_element_type_1339, 1e-12); convert_element_type_1339 = None | |
| reciprocal_69 = torch.ops.aten.reciprocal.default(clamp_min_104); clamp_min_104 = None | |
| - mul_1024 = torch.ops.aten.mul.Tensor(reciprocal_69, 448.0); reciprocal_69 = None | |
| - convert_element_type_1262 = torch.ops.prims.convert_element_type.default(mul_1024, torch.float32); mul_1024 = None | |
| - log2_35 = torch.ops.aten.log2.default(convert_element_type_1262); convert_element_type_1262 = None | |
| + mul_1086 = torch.ops.aten.mul.Tensor(reciprocal_69, 448.0); reciprocal_69 = None | |
| + convert_element_type_1340 = torch.ops.prims.convert_element_type.default(mul_1086, torch.float32); mul_1086 = None | |
| + log2_35 = torch.ops.aten.log2.default(convert_element_type_1340); convert_element_type_1340 = None | |
| floor_35 = torch.ops.aten.floor.default(log2_35); log2_35 = None | |
| exp2_35 = torch.ops.aten.exp2.default(floor_35); floor_35 = None | |
| - mul_1025 = torch.ops.aten.mul.Tensor(permute_183, exp2_35); permute_183 = None | |
| - clamp_min_105 = torch.ops.aten.clamp_min.default(mul_1025, -448.0); mul_1025 = None | |
| + mul_1087 = torch.ops.aten.mul.Tensor(permute_183, exp2_35); permute_183 = None | |
| + clamp_min_105 = torch.ops.aten.clamp_min.default(mul_1087, -448.0); mul_1087 = None | |
| clamp_max_65 = torch.ops.aten.clamp_max.default(clamp_min_105, 448.0); clamp_min_105 = None | |
| - convert_element_type_1263 = torch.ops.prims.convert_element_type.default(clamp_max_65, torch.float8_e4m3fn); clamp_max_65 = None | |
| - clone_105 = torch.ops.aten.clone.default(convert_element_type_1263, memory_format = torch.contiguous_format); convert_element_type_1263 = None | |
| + convert_element_type_1341 = torch.ops.prims.convert_element_type.default(clamp_max_65, torch.float8_e4m3fn); clamp_max_65 = None | |
| + clone_105 = torch.ops.aten.clone.default(convert_element_type_1341, memory_format = torch.contiguous_format); convert_element_type_1341 = None | |
| permute_796 = torch.ops.aten.permute.default(clone_105, [1, 0]); clone_105 = None | |
| repeat_11 = torch.ops.aten.repeat.default(exp2_35, [4608]); exp2_35 = None | |
| - view_521 = torch.ops.aten.view.default(repeat_11, [1, -1]); repeat_11 = None | |
| + view_552 = torch.ops.aten.view.default(repeat_11, [1, -1]); repeat_11 = None | |
| reciprocal_70 = torch.ops.aten.reciprocal.default(exp2_34); exp2_34 = None | |
| - reciprocal_71 = torch.ops.aten.reciprocal.default(view_521); view_521 = None | |
| - mul_1026 = torch.ops.aten.mul.Tensor(reciprocal_70, reciprocal_71); reciprocal_70 = reciprocal_71 = None | |
| - _scaled_mm_17 = torch.ops.aten._scaled_mm.default(convert_element_type_1260, permute_796, full_default_267, full_default_267, None, None, torch.bfloat16); convert_element_type_1260 = permute_796 = None | |
| - mul_1027 = torch.ops.aten.mul.Tensor(_scaled_mm_17, mul_1026); _scaled_mm_17 = mul_1026 = None | |
| - permute_797 = torch.ops.aten.permute.default(add_439, [1, 0]); add_439 = None | |
| - convert_element_type_517 = torch.ops.prims.convert_element_type.default(mul_170, torch.bfloat16); mul_170 = None | |
| - mm_308 = torch.ops.aten.mm.default(permute_797, convert_element_type_517); permute_797 = convert_element_type_517 = None | |
| + reciprocal_71 = torch.ops.aten.reciprocal.default(view_552); view_552 = None | |
| + mul_1088 = torch.ops.aten.mul.Tensor(reciprocal_70, reciprocal_71); reciprocal_70 = reciprocal_71 = None | |
| + _scaled_mm_17 = torch.ops.aten._scaled_mm.default(convert_element_type_1338, permute_796, full_default_267, full_default_267, None, None, torch.bfloat16); convert_element_type_1338 = permute_796 = None | |
| + mul_1089 = torch.ops.aten.mul.Tensor(_scaled_mm_17, mul_1088); _scaled_mm_17 = mul_1088 = None | |
| + convert_element_type_1342 = torch.ops.prims.convert_element_type.default(mul_1089, torch.bfloat16); mul_1089 = None | |
| + permute_797 = torch.ops.aten.permute.default(add_470, [1, 0]); add_470 = None | |
| + mm_308 = torch.ops.aten.mm.default(permute_797, mul_170); permute_797 = mul_170 = None | |
| permute_798 = torch.ops.aten.permute.default(mm_308, [1, 0]); mm_308 = None | |
| - convert_element_type_1267 = torch.ops.prims.convert_element_type.default(permute_798, torch.float32); permute_798 = None | |
| - permute_799 = torch.ops.aten.permute.default(convert_element_type_1267, [1, 0]); convert_element_type_1267 = None | |
| - convert_element_type_default_4 = torch.ops.prims.convert_element_type.default(mul_1027, torch.float32); mul_1027 = None | |
| - add_440 = torch.ops.aten.add.Tensor(add_357, convert_element_type_default_4); add_357 = convert_element_type_default_4 = None | |
| + convert_element_type_1345 = torch.ops.prims.convert_element_type.default(permute_798, torch.float32); permute_798 = None | |
| + add_471 = torch.ops.aten.add.Tensor(add_357, convert_element_type_1342); add_357 = convert_element_type_1342 = None | |
| + permute_799 = torch.ops.aten.permute.default(convert_element_type_1345, [1, 0]); convert_element_type_1345 = None | |
| + convert_element_type_502 = torch.ops.prims.convert_element_type.default(add_125, torch.float32); add_125 = None | |
| + pow_35 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_502, 2) | |
| + mean_34 = torch.ops.aten.mean.dim(pow_35, [1], True); pow_35 = None | |
| + add_127 = torch.ops.aten.add.Scalar(mean_34, 1.1920928955078125e-07); mean_34 = None | |
| + rsqrt_55 = torch.ops.aten.rsqrt.default(add_127); add_127 = None | |
| + mul_156 = torch.ops.aten.mul.Tensor(convert_element_type_502, rsqrt_55) | |
| mul_157 = torch.ops.aten.mul.Tensor(mul_156, primals_330) | |
| - sigmoid_16 = torch.ops.aten.sigmoid.default(mul_157) | |
| - mul_159 = torch.ops.aten.mul.Tensor(mul_157, sigmoid_16) | |
| + convert_element_type_503 = torch.ops.prims.convert_element_type.default(mul_157, torch.bfloat16); mul_157 = None | |
| + sigmoid_16 = torch.ops.aten.sigmoid.default(convert_element_type_503) | |
| + mul_159 = torch.ops.aten.mul.Tensor(convert_element_type_503, sigmoid_16) | |
| slice_28 = torch.ops.aten.slice.Tensor(mm_65, 1, 4608, 9223372036854775807) | |
| index_18 = torch.ops.aten.index.Tensor(slice_28, [sub_32]); slice_28 = None | |
| add_131 = torch.ops.aten.add.Tensor(mm_66, index_18); mm_66 = index_18 = None | |
| add_133 = torch.ops.aten.add.Tensor(mul_159, add_131); add_131 = None | |
| - pow_39 = torch.ops.aten.pow.Tensor_Scalar(add_133, 2) | |
| + convert_element_type_522 = torch.ops.prims.convert_element_type.default(add_133, torch.float32); add_133 = None | |
| + pow_39 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_522, 2) | |
| mean_38 = torch.ops.aten.mean.dim(pow_39, [1], True); pow_39 = None | |
| add_135 = torch.ops.aten.add.Scalar(mean_38, 1.1920928955078125e-07); mean_38 = None | |
| rsqrt_59 = torch.ops.aten.rsqrt.default(add_135); add_135 = None | |
| - mul_168 = torch.ops.aten.mul.Tensor(add_133, rsqrt_59); add_133 = None | |
| + mul_168 = torch.ops.aten.mul.Tensor(convert_element_type_522, rsqrt_59) | |
| mul_169 = torch.ops.aten.mul.Tensor(mul_168, primals_338) | |
| - mul_1028 = torch.ops.aten.mul.Tensor(add_438, mul_169) | |
| - sigmoid_20 = torch.ops.aten.sigmoid.default(mul_169); mul_169 = None | |
| - mul_1029 = torch.ops.aten.mul.Tensor(add_438, sigmoid_20); add_438 = None | |
| - sub_252 = torch.ops.aten.sub.Tensor(1, sigmoid_20) | |
| - mul_1030 = torch.ops.aten.mul.Tensor(sigmoid_20, sub_252); sigmoid_20 = sub_252 = None | |
| - mul_1031 = torch.ops.aten.mul.Tensor(mul_1028, mul_1030); mul_1028 = mul_1030 = None | |
| - add_441 = torch.ops.aten.add.Tensor(mul_1029, mul_1031); mul_1029 = mul_1031 = None | |
| + convert_element_type_523 = torch.ops.prims.convert_element_type.default(mul_169, torch.bfloat16); mul_169 = None | |
| + mul_1090 = torch.ops.aten.mul.Tensor(add_469, convert_element_type_523) | |
| + sigmoid_20 = torch.ops.aten.sigmoid.default(convert_element_type_523); convert_element_type_523 = None | |
| + mul_1091 = torch.ops.aten.mul.Tensor(add_469, sigmoid_20); add_469 = None | |
| + convert_element_type_1346 = torch.ops.prims.convert_element_type.default(mul_1090, torch.float32); mul_1090 = None | |
| + convert_element_type_1347 = torch.ops.prims.convert_element_type.default(sigmoid_20, torch.float32); sigmoid_20 = None | |
| + sub_221 = torch.ops.aten.sub.Tensor(1, convert_element_type_1347) | |
| + mul_1092 = torch.ops.aten.mul.Tensor(convert_element_type_1347, sub_221); convert_element_type_1347 = sub_221 = None | |
| + mul_1093 = torch.ops.aten.mul.Tensor(convert_element_type_1346, mul_1092); convert_element_type_1346 = mul_1092 = None | |
| + convert_element_type_1348 = torch.ops.prims.convert_element_type.default(mul_1093, torch.bfloat16); mul_1093 = None | |
| + add_472 = torch.ops.aten.add.Tensor(mul_1091, convert_element_type_1348); mul_1091 = convert_element_type_1348 = None | |
| + convert_element_type_500 = torch.ops.prims.convert_element_type.default(add_124, torch.float32); add_124 = None | |
| + pow_34 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_500, 2) | |
| + mean_33 = torch.ops.aten.mean.dim(pow_34, [1], True); pow_34 = None | |
| + add_126 = torch.ops.aten.add.Scalar(mean_33, 1.1920928955078125e-07); mean_33 = None | |
| + rsqrt_54 = torch.ops.aten.rsqrt.default(add_126); add_126 = None | |
| + mul_154 = torch.ops.aten.mul.Tensor(convert_element_type_500, rsqrt_54) | |
| mul_155 = torch.ops.aten.mul.Tensor(mul_154, primals_329) | |
| - sigmoid_15 = torch.ops.aten.sigmoid.default(mul_155) | |
| - mul_158 = torch.ops.aten.mul.Tensor(mul_155, sigmoid_15) | |
| + convert_element_type_501 = torch.ops.prims.convert_element_type.default(mul_155, torch.bfloat16); mul_155 = None | |
| + sigmoid_15 = torch.ops.aten.sigmoid.default(convert_element_type_501) | |
| + mul_158 = torch.ops.aten.mul.Tensor(convert_element_type_501, sigmoid_15) | |
| slice_27 = torch.ops.aten.slice.Tensor(mm_65, 1, 0, 4608); mm_65 = None | |
| add_132 = torch.ops.aten.add.Tensor(mul_158, slice_27); slice_27 = None | |
| - pow_38 = torch.ops.aten.pow.Tensor_Scalar(add_132, 2) | |
| + convert_element_type_520 = torch.ops.prims.convert_element_type.default(add_132, torch.float32); add_132 = None | |
| + pow_38 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_520, 2) | |
| mean_37 = torch.ops.aten.mean.dim(pow_38, [1], True); pow_38 = None | |
| add_134 = torch.ops.aten.add.Scalar(mean_37, 1.1920928955078125e-07); mean_37 = None | |
| rsqrt_58 = torch.ops.aten.rsqrt.default(add_134); add_134 = None | |
| - mul_166 = torch.ops.aten.mul.Tensor(add_132, rsqrt_58); add_132 = None | |
| + mul_166 = torch.ops.aten.mul.Tensor(convert_element_type_520, rsqrt_58) | |
| mul_167 = torch.ops.aten.mul.Tensor(mul_166, primals_337) | |
| - mul_1032 = torch.ops.aten.mul.Tensor(add_440, mul_167) | |
| - sigmoid_19 = torch.ops.aten.sigmoid.default(mul_167); mul_167 = None | |
| - mul_1033 = torch.ops.aten.mul.Tensor(add_440, sigmoid_19); add_440 = None | |
| - sub_253 = torch.ops.aten.sub.Tensor(1, sigmoid_19) | |
| - mul_1034 = torch.ops.aten.mul.Tensor(sigmoid_19, sub_253); sigmoid_19 = sub_253 = None | |
| - mul_1035 = torch.ops.aten.mul.Tensor(mul_1032, mul_1034); mul_1032 = mul_1034 = None | |
| - add_442 = torch.ops.aten.add.Tensor(mul_1033, mul_1035); mul_1033 = mul_1035 = None | |
| - mul_1036 = torch.ops.aten.mul.Tensor(add_441, primals_338); primals_338 = None | |
| - mul_1038 = torch.ops.aten.mul.Tensor(mul_168, mul_1036) | |
| - sum_166 = torch.ops.aten.sum.dim_IntList(mul_1038, [1], True); mul_1038 = None | |
| - div_123 = torch.ops.aten.div.Tensor(mul_168, 4608) | |
| - mul_1039 = torch.ops.aten.mul.Tensor(div_123, sum_166); div_123 = sum_166 = None | |
| - sub_254 = torch.ops.aten.sub.Tensor(mul_1036, mul_1039); mul_1036 = mul_1039 = None | |
| - mul_1040 = torch.ops.aten.mul.Tensor(sub_254, rsqrt_59); sub_254 = rsqrt_59 = None | |
| - mul_1041 = torch.ops.aten.mul.Tensor(add_441, mul_168); add_441 = mul_168 = None | |
| - sum_167 = torch.ops.aten.sum.dim_IntList(mul_1041, [0]); mul_1041 = None | |
| - mul_1042 = torch.ops.aten.mul.Tensor(add_442, primals_337); primals_337 = None | |
| - mul_1044 = torch.ops.aten.mul.Tensor(mul_166, mul_1042) | |
| - sum_168 = torch.ops.aten.sum.dim_IntList(mul_1044, [1], True); mul_1044 = None | |
| - div_124 = torch.ops.aten.div.Tensor(mul_166, 4608) | |
| - mul_1045 = torch.ops.aten.mul.Tensor(div_124, sum_168); div_124 = sum_168 = None | |
| - sub_255 = torch.ops.aten.sub.Tensor(mul_1042, mul_1045); mul_1042 = mul_1045 = None | |
| - mul_1046 = torch.ops.aten.mul.Tensor(sub_255, rsqrt_58); sub_255 = rsqrt_58 = None | |
| - mul_1047 = torch.ops.aten.mul.Tensor(add_442, mul_166); add_442 = mul_166 = None | |
| - sum_169 = torch.ops.aten.sum.dim_IntList(mul_1047, [0]); mul_1047 = None | |
| - convert_element_type_1269 = torch.ops.prims.convert_element_type.default(mul_1040, torch.bfloat16) | |
| - convert_element_type_1270 = torch.ops.prims.convert_element_type.default(mul_1046, torch.bfloat16) | |
| - index_put_44 = torch.ops.aten.index_put.default(full_default_296, [sub_32], convert_element_type_1269, True) | |
| + convert_element_type_521 = torch.ops.prims.convert_element_type.default(mul_167, torch.bfloat16); mul_167 = None | |
| + mul_1094 = torch.ops.aten.mul.Tensor(add_471, convert_element_type_521) | |
| + sigmoid_19 = torch.ops.aten.sigmoid.default(convert_element_type_521); convert_element_type_521 = None | |
| + mul_1095 = torch.ops.aten.mul.Tensor(add_471, sigmoid_19); add_471 = None | |
| + convert_element_type_1349 = torch.ops.prims.convert_element_type.default(mul_1094, torch.float32); mul_1094 = None | |
| + convert_element_type_1350 = torch.ops.prims.convert_element_type.default(sigmoid_19, torch.float32); sigmoid_19 = None | |
| + sub_222 = torch.ops.aten.sub.Tensor(1, convert_element_type_1350) | |
| + mul_1096 = torch.ops.aten.mul.Tensor(convert_element_type_1350, sub_222); convert_element_type_1350 = sub_222 = None | |
| + mul_1097 = torch.ops.aten.mul.Tensor(convert_element_type_1349, mul_1096); convert_element_type_1349 = mul_1096 = None | |
| + convert_element_type_1351 = torch.ops.prims.convert_element_type.default(mul_1097, torch.bfloat16); mul_1097 = None | |
| + add_473 = torch.ops.aten.add.Tensor(mul_1095, convert_element_type_1351); mul_1095 = convert_element_type_1351 = None | |
| + convert_element_type_1352 = torch.ops.prims.convert_element_type.default(add_472, torch.float32); add_472 = None | |
| + mul_1098 = torch.ops.aten.mul.Tensor(convert_element_type_1352, mul_168); mul_168 = None | |
| + mul_1099 = torch.ops.aten.mul.Tensor(convert_element_type_1352, primals_338); convert_element_type_1352 = primals_338 = None | |
| + sum_166 = torch.ops.aten.sum.dim_IntList(mul_1098, [0], True); mul_1098 = None | |
| + view_556 = torch.ops.aten.view.default(sum_166, [4608]); sum_166 = None | |
| + mul_1100 = torch.ops.aten.mul.Tensor(mul_1099, convert_element_type_522) | |
| + mul_1101 = torch.ops.aten.mul.Tensor(mul_1099, rsqrt_59); mul_1099 = None | |
| + sum_167 = torch.ops.aten.sum.dim_IntList(mul_1100, [1], True); mul_1100 = None | |
| + mul_1102 = torch.ops.aten.mul.Scalar(sum_167, -0.5); sum_167 = None | |
| + pow_146 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_59, 3); rsqrt_59 = None | |
| + mul_1103 = torch.ops.aten.mul.Tensor(mul_1102, pow_146); mul_1102 = pow_146 = None | |
| + expand_94 = torch.ops.aten.expand.default(mul_1103, [4096, 4608]); mul_1103 = None | |
| + div_123 = torch.ops.aten.div.Scalar(expand_94, 4608); expand_94 = None | |
| + pow_147 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_522, 1.0); convert_element_type_522 = None | |
| + mul_1104 = torch.ops.aten.mul.Scalar(pow_147, 2.0); pow_147 = None | |
| + mul_1105 = torch.ops.aten.mul.Tensor(div_123, mul_1104); div_123 = mul_1104 = None | |
| + add_474 = torch.ops.aten.add.Tensor(mul_1101, mul_1105); mul_1101 = mul_1105 = None | |
| + convert_element_type_1353 = torch.ops.prims.convert_element_type.default(add_474, torch.bfloat16); add_474 = None | |
| + convert_element_type_1354 = torch.ops.prims.convert_element_type.default(add_473, torch.float32); add_473 = None | |
| + mul_1106 = torch.ops.aten.mul.Tensor(convert_element_type_1354, mul_166); mul_166 = None | |
| + mul_1107 = torch.ops.aten.mul.Tensor(convert_element_type_1354, primals_337); convert_element_type_1354 = primals_337 = None | |
| + sum_168 = torch.ops.aten.sum.dim_IntList(mul_1106, [0], True); mul_1106 = None | |
| + view_557 = torch.ops.aten.view.default(sum_168, [4608]); sum_168 = None | |
| + mul_1108 = torch.ops.aten.mul.Tensor(mul_1107, convert_element_type_520) | |
| + mul_1109 = torch.ops.aten.mul.Tensor(mul_1107, rsqrt_58); mul_1107 = None | |
| + sum_169 = torch.ops.aten.sum.dim_IntList(mul_1108, [1], True); mul_1108 = None | |
| + mul_1110 = torch.ops.aten.mul.Scalar(sum_169, -0.5); sum_169 = None | |
| + pow_148 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_58, 3); rsqrt_58 = None | |
| + mul_1111 = torch.ops.aten.mul.Tensor(mul_1110, pow_148); mul_1110 = pow_148 = None | |
| + expand_95 = torch.ops.aten.expand.default(mul_1111, [4096, 4608]); mul_1111 = None | |
| + div_124 = torch.ops.aten.div.Scalar(expand_95, 4608); expand_95 = None | |
| + pow_149 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_520, 1.0); convert_element_type_520 = None | |
| + mul_1112 = torch.ops.aten.mul.Scalar(pow_149, 2.0); pow_149 = None | |
| + mul_1113 = torch.ops.aten.mul.Tensor(div_124, mul_1112); div_124 = mul_1112 = None | |
| + add_475 = torch.ops.aten.add.Tensor(mul_1109, mul_1113); mul_1109 = mul_1113 = None | |
| + convert_element_type_1355 = torch.ops.prims.convert_element_type.default(add_475, torch.bfloat16); add_475 = None | |
| + index_put_44 = torch.ops.aten.index_put.default(full_default_296, [sub_32], convert_element_type_1353, True) | |
| slice_scatter_49 = torch.ops.aten.slice_scatter.default(full_default_297, index_put_44, 1, 4608, 9223372036854775807); index_put_44 = None | |
| - permute_800 = torch.ops.aten.permute.default(convert_element_type_1269, [1, 0]) | |
| + permute_800 = torch.ops.aten.permute.default(convert_element_type_1353, [1, 0]) | |
| slice_26 = torch.ops.aten.slice.Tensor(mm_63, 1, 1536, 9223372036854775807) | |
| index_17 = torch.ops.aten.index.Tensor(slice_26, [sub_32]); slice_26 = None | |
| add_128 = torch.ops.aten.add.Tensor(mm_64, index_17); mm_64 = index_17 = None | |
| - convert_element_type_508 = torch.ops.prims.convert_element_type.default(add_128, torch.float32); add_128 = None | |
| - pow_37 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_508, 2) | |
| + convert_element_type_512 = torch.ops.prims.convert_element_type.default(add_128, torch.float32); add_128 = None | |
| + pow_37 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_512, 2) | |
| mean_36 = torch.ops.aten.mean.dim(pow_37, [1], True); pow_37 = None | |
| add_130 = torch.ops.aten.add.Scalar(mean_36, 1.1920928955078125e-07); mean_36 = None | |
| rsqrt_57 = torch.ops.aten.rsqrt.default(add_130); add_130 = None | |
| - mul_162 = torch.ops.aten.mul.Tensor(convert_element_type_508, rsqrt_57); convert_element_type_508 = None | |
| + mul_162 = torch.ops.aten.mul.Tensor(convert_element_type_512, rsqrt_57) | |
| mul_163 = torch.ops.aten.mul.Tensor(mul_162, primals_334) | |
| - sigmoid_18 = torch.ops.aten.sigmoid.default(mul_163) | |
| - mul_165 = torch.ops.aten.mul.Tensor(mul_163, sigmoid_18) | |
| - convert_element_type_513 = torch.ops.prims.convert_element_type.default(mul_165, torch.bfloat16); mul_165 = None | |
| - mm_309 = torch.ops.aten.mm.default(permute_800, convert_element_type_513); permute_800 = convert_element_type_513 = None | |
| - convert_element_type_514 = torch.ops.prims.convert_element_type.default(primals_336, torch.bfloat16); primals_336 = None | |
| - permute_182 = torch.ops.aten.permute.default(convert_element_type_514, [1, 0]); convert_element_type_514 = None | |
| + convert_element_type_513 = torch.ops.prims.convert_element_type.default(mul_163, torch.bfloat16); mul_163 = None | |
| + sigmoid_18 = torch.ops.aten.sigmoid.default(convert_element_type_513) | |
| + mul_165 = torch.ops.aten.mul.Tensor(convert_element_type_513, sigmoid_18) | |
| + mm_309 = torch.ops.aten.mm.default(permute_800, mul_165); permute_800 = mul_165 = None | |
| + convert_element_type_517 = torch.ops.prims.convert_element_type.default(primals_336, torch.bfloat16); primals_336 = None | |
| + permute_182 = torch.ops.aten.permute.default(convert_element_type_517, [1, 0]); convert_element_type_517 = None | |
| permute_802 = torch.ops.aten.permute.default(permute_182, [1, 0]); permute_182 = None | |
| - mm_310 = torch.ops.aten.mm.default(convert_element_type_1269, permute_802); convert_element_type_1269 = permute_802 = None | |
| - convert_element_type_1275 = torch.ops.prims.convert_element_type.default(mm_309, torch.float32); mm_309 = None | |
| - convert_element_type_1276 = torch.ops.prims.convert_element_type.default(mm_310, torch.float32); mm_310 = None | |
| - slice_scatter_50 = torch.ops.aten.slice_scatter.default(full_default_297, convert_element_type_1270, 1, 0, 4608); convert_element_type_1270 = None | |
| - add_443 = torch.ops.aten.add.Tensor(slice_scatter_49, slice_scatter_50); slice_scatter_49 = slice_scatter_50 = None | |
| - permute_804 = torch.ops.aten.permute.default(add_443, [1, 0]) | |
| + mm_310 = torch.ops.aten.mm.default(convert_element_type_1353, permute_802); permute_802 = None | |
| + convert_element_type_1360 = torch.ops.prims.convert_element_type.default(mm_309, torch.float32); mm_309 = None | |
| + slice_scatter_50 = torch.ops.aten.slice_scatter.default(full_default_297, convert_element_type_1355, 1, 0, 4608) | |
| + add_476 = torch.ops.aten.add.Tensor(slice_scatter_49, slice_scatter_50); slice_scatter_49 = slice_scatter_50 = None | |
| + permute_804 = torch.ops.aten.permute.default(add_476, [1, 0]) | |
| slice_25 = torch.ops.aten.slice.Tensor(mm_63, 1, 0, 1536); mm_63 = None | |
| - convert_element_type_507 = torch.ops.prims.convert_element_type.default(slice_25, torch.float32); slice_25 = None | |
| - pow_36 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_507, 2) | |
| + convert_element_type_510 = torch.ops.prims.convert_element_type.default(slice_25, torch.float32); slice_25 = None | |
| + pow_36 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_510, 2) | |
| mean_35 = torch.ops.aten.mean.dim(pow_36, [1], True); pow_36 = None | |
| add_129 = torch.ops.aten.add.Scalar(mean_35, 1.1920928955078125e-07); mean_35 = None | |
| rsqrt_56 = torch.ops.aten.rsqrt.default(add_129); add_129 = None | |
| - mul_160 = torch.ops.aten.mul.Tensor(convert_element_type_507, rsqrt_56); convert_element_type_507 = None | |
| + mul_160 = torch.ops.aten.mul.Tensor(convert_element_type_510, rsqrt_56) | |
| mul_161 = torch.ops.aten.mul.Tensor(mul_160, primals_333) | |
| - sigmoid_17 = torch.ops.aten.sigmoid.default(mul_161) | |
| - mul_164 = torch.ops.aten.mul.Tensor(mul_161, sigmoid_17) | |
| - convert_element_type_509 = torch.ops.prims.convert_element_type.default(mul_164, torch.bfloat16); mul_164 = None | |
| - mm_311 = torch.ops.aten.mm.default(permute_804, convert_element_type_509); permute_804 = convert_element_type_509 = None | |
| - convert_element_type_510 = torch.ops.prims.convert_element_type.default(primals_335, torch.bfloat16); primals_335 = None | |
| - permute_181 = torch.ops.aten.permute.default(convert_element_type_510, [1, 0]); convert_element_type_510 = None | |
| + convert_element_type_511 = torch.ops.prims.convert_element_type.default(mul_161, torch.bfloat16); mul_161 = None | |
| + sigmoid_17 = torch.ops.aten.sigmoid.default(convert_element_type_511) | |
| + mul_164 = torch.ops.aten.mul.Tensor(convert_element_type_511, sigmoid_17) | |
| + mm_311 = torch.ops.aten.mm.default(permute_804, mul_164); permute_804 = mul_164 = None | |
| + convert_element_type_514 = torch.ops.prims.convert_element_type.default(primals_335, torch.bfloat16); primals_335 = None | |
| + permute_181 = torch.ops.aten.permute.default(convert_element_type_514, [1, 0]); convert_element_type_514 = None | |
| permute_806 = torch.ops.aten.permute.default(permute_181, [1, 0]); permute_181 = None | |
| - mm_312 = torch.ops.aten.mm.default(add_443, permute_806); add_443 = permute_806 = None | |
| - convert_element_type_1281 = torch.ops.prims.convert_element_type.default(mm_311, torch.float32); mm_311 = None | |
| - convert_element_type_1282 = torch.ops.prims.convert_element_type.default(mm_312, torch.float32); mm_312 = None | |
| - mul_1048 = torch.ops.aten.mul.Tensor(convert_element_type_1276, mul_163); mul_163 = None | |
| - mul_1049 = torch.ops.aten.mul.Tensor(convert_element_type_1276, sigmoid_18); convert_element_type_1276 = None | |
| - sub_256 = torch.ops.aten.sub.Tensor(1, sigmoid_18) | |
| - mul_1050 = torch.ops.aten.mul.Tensor(sigmoid_18, sub_256); sigmoid_18 = sub_256 = None | |
| - mul_1051 = torch.ops.aten.mul.Tensor(mul_1048, mul_1050); mul_1048 = mul_1050 = None | |
| - add_444 = torch.ops.aten.add.Tensor(mul_1049, mul_1051); mul_1049 = mul_1051 = None | |
| - mul_1052 = torch.ops.aten.mul.Tensor(convert_element_type_1282, mul_161); mul_161 = None | |
| - mul_1053 = torch.ops.aten.mul.Tensor(convert_element_type_1282, sigmoid_17); convert_element_type_1282 = None | |
| - sub_257 = torch.ops.aten.sub.Tensor(1, sigmoid_17) | |
| - mul_1054 = torch.ops.aten.mul.Tensor(sigmoid_17, sub_257); sigmoid_17 = sub_257 = None | |
| - mul_1055 = torch.ops.aten.mul.Tensor(mul_1052, mul_1054); mul_1052 = mul_1054 = None | |
| - add_445 = torch.ops.aten.add.Tensor(mul_1053, mul_1055); mul_1053 = mul_1055 = None | |
| - mul_1056 = torch.ops.aten.mul.Tensor(add_444, primals_334); primals_334 = None | |
| - mul_1058 = torch.ops.aten.mul.Tensor(mul_162, mul_1056) | |
| - sum_170 = torch.ops.aten.sum.dim_IntList(mul_1058, [1], True); mul_1058 = None | |
| - div_125 = torch.ops.aten.div.Tensor(mul_162, 1536) | |
| - mul_1059 = torch.ops.aten.mul.Tensor(div_125, sum_170); div_125 = sum_170 = None | |
| - sub_258 = torch.ops.aten.sub.Tensor(mul_1056, mul_1059); mul_1056 = mul_1059 = None | |
| - mul_1060 = torch.ops.aten.mul.Tensor(sub_258, rsqrt_57); sub_258 = rsqrt_57 = None | |
| - mul_1061 = torch.ops.aten.mul.Tensor(add_444, mul_162); add_444 = mul_162 = None | |
| - sum_171 = torch.ops.aten.sum.dim_IntList(mul_1061, [0]); mul_1061 = None | |
| - convert_element_type_1283 = torch.ops.prims.convert_element_type.default(mul_1060, torch.bfloat16); mul_1060 = None | |
| - mul_1062 = torch.ops.aten.mul.Tensor(add_445, primals_333); primals_333 = None | |
| - mul_1064 = torch.ops.aten.mul.Tensor(mul_160, mul_1062) | |
| - sum_172 = torch.ops.aten.sum.dim_IntList(mul_1064, [1], True); mul_1064 = None | |
| - div_126 = torch.ops.aten.div.Tensor(mul_160, 1536) | |
| - mul_1065 = torch.ops.aten.mul.Tensor(div_126, sum_172); div_126 = sum_172 = None | |
| - sub_259 = torch.ops.aten.sub.Tensor(mul_1062, mul_1065); mul_1062 = mul_1065 = None | |
| - mul_1066 = torch.ops.aten.mul.Tensor(sub_259, rsqrt_56); sub_259 = rsqrt_56 = None | |
| - mul_1067 = torch.ops.aten.mul.Tensor(add_445, mul_160); add_445 = mul_160 = None | |
| - sum_173 = torch.ops.aten.sum.dim_IntList(mul_1067, [0]); mul_1067 = None | |
| - convert_element_type_1284 = torch.ops.prims.convert_element_type.default(mul_1066, torch.bfloat16); mul_1066 = None | |
| - index_put_45 = torch.ops.aten.index_put.default(full_default_281, [sub_32], convert_element_type_1283, True); full_default_281 = None | |
| + mm_312 = torch.ops.aten.mm.default(add_476, permute_806); add_476 = permute_806 = None | |
| + convert_element_type_1365 = torch.ops.prims.convert_element_type.default(mm_311, torch.float32); mm_311 = None | |
| + mul_1114 = torch.ops.aten.mul.Tensor(mm_310, convert_element_type_513); convert_element_type_513 = None | |
| + mul_1115 = torch.ops.aten.mul.Tensor(mm_310, sigmoid_18); mm_310 = None | |
| + convert_element_type_1366 = torch.ops.prims.convert_element_type.default(mul_1114, torch.float32); mul_1114 = None | |
| + convert_element_type_1367 = torch.ops.prims.convert_element_type.default(sigmoid_18, torch.float32); sigmoid_18 = None | |
| + sub_223 = torch.ops.aten.sub.Tensor(1, convert_element_type_1367) | |
| + mul_1116 = torch.ops.aten.mul.Tensor(convert_element_type_1367, sub_223); convert_element_type_1367 = sub_223 = None | |
| + mul_1117 = torch.ops.aten.mul.Tensor(convert_element_type_1366, mul_1116); convert_element_type_1366 = mul_1116 = None | |
| + convert_element_type_1368 = torch.ops.prims.convert_element_type.default(mul_1117, torch.bfloat16); mul_1117 = None | |
| + add_477 = torch.ops.aten.add.Tensor(mul_1115, convert_element_type_1368); mul_1115 = convert_element_type_1368 = None | |
| + mul_1118 = torch.ops.aten.mul.Tensor(mm_312, convert_element_type_511); convert_element_type_511 = None | |
| + mul_1119 = torch.ops.aten.mul.Tensor(mm_312, sigmoid_17); mm_312 = None | |
| + convert_element_type_1369 = torch.ops.prims.convert_element_type.default(mul_1118, torch.float32); mul_1118 = None | |
| + convert_element_type_1370 = torch.ops.prims.convert_element_type.default(sigmoid_17, torch.float32); sigmoid_17 = None | |
| + sub_224 = torch.ops.aten.sub.Tensor(1, convert_element_type_1370) | |
| + mul_1120 = torch.ops.aten.mul.Tensor(convert_element_type_1370, sub_224); convert_element_type_1370 = sub_224 = None | |
| + mul_1121 = torch.ops.aten.mul.Tensor(convert_element_type_1369, mul_1120); convert_element_type_1369 = mul_1120 = None | |
| + convert_element_type_1371 = torch.ops.prims.convert_element_type.default(mul_1121, torch.bfloat16); mul_1121 = None | |
| + add_478 = torch.ops.aten.add.Tensor(mul_1119, convert_element_type_1371); mul_1119 = convert_element_type_1371 = None | |
| + convert_element_type_1372 = torch.ops.prims.convert_element_type.default(add_477, torch.float32); add_477 = None | |
| + mul_1122 = torch.ops.aten.mul.Tensor(convert_element_type_1372, mul_162); mul_162 = None | |
| + mul_1123 = torch.ops.aten.mul.Tensor(convert_element_type_1372, primals_334); convert_element_type_1372 = primals_334 = None | |
| + sum_170 = torch.ops.aten.sum.dim_IntList(mul_1122, [0], True); mul_1122 = None | |
| + view_558 = torch.ops.aten.view.default(sum_170, [1536]); sum_170 = None | |
| + mul_1124 = torch.ops.aten.mul.Tensor(mul_1123, convert_element_type_512) | |
| + mul_1125 = torch.ops.aten.mul.Tensor(mul_1123, rsqrt_57); mul_1123 = None | |
| + sum_171 = torch.ops.aten.sum.dim_IntList(mul_1124, [1], True); mul_1124 = None | |
| + mul_1126 = torch.ops.aten.mul.Scalar(sum_171, -0.5); sum_171 = None | |
| + pow_150 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_57, 3); rsqrt_57 = None | |
| + mul_1127 = torch.ops.aten.mul.Tensor(mul_1126, pow_150); mul_1126 = pow_150 = None | |
| + expand_96 = torch.ops.aten.expand.default(mul_1127, [4096, 1536]); mul_1127 = None | |
| + div_125 = torch.ops.aten.div.Scalar(expand_96, 1536); expand_96 = None | |
| + pow_151 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_512, 1.0); convert_element_type_512 = None | |
| + mul_1128 = torch.ops.aten.mul.Scalar(pow_151, 2.0); pow_151 = None | |
| + mul_1129 = torch.ops.aten.mul.Tensor(div_125, mul_1128); div_125 = mul_1128 = None | |
| + add_479 = torch.ops.aten.add.Tensor(mul_1125, mul_1129); mul_1125 = mul_1129 = None | |
| + convert_element_type_1373 = torch.ops.prims.convert_element_type.default(add_479, torch.bfloat16); add_479 = None | |
| + convert_element_type_1374 = torch.ops.prims.convert_element_type.default(add_478, torch.float32); add_478 = None | |
| + mul_1130 = torch.ops.aten.mul.Tensor(convert_element_type_1374, mul_160); mul_160 = None | |
| + mul_1131 = torch.ops.aten.mul.Tensor(convert_element_type_1374, primals_333); convert_element_type_1374 = primals_333 = None | |
| + sum_172 = torch.ops.aten.sum.dim_IntList(mul_1130, [0], True); mul_1130 = None | |
| + view_559 = torch.ops.aten.view.default(sum_172, [1536]); sum_172 = None | |
| + mul_1132 = torch.ops.aten.mul.Tensor(mul_1131, convert_element_type_510) | |
| + mul_1133 = torch.ops.aten.mul.Tensor(mul_1131, rsqrt_56); mul_1131 = None | |
| + sum_173 = torch.ops.aten.sum.dim_IntList(mul_1132, [1], True); mul_1132 = None | |
| + mul_1134 = torch.ops.aten.mul.Scalar(sum_173, -0.5); sum_173 = None | |
| + pow_152 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_56, 3); rsqrt_56 = None | |
| + mul_1135 = torch.ops.aten.mul.Tensor(mul_1134, pow_152); mul_1134 = pow_152 = None | |
| + expand_97 = torch.ops.aten.expand.default(mul_1135, [4096, 1536]); mul_1135 = None | |
| + div_126 = torch.ops.aten.div.Scalar(expand_97, 1536); expand_97 = None | |
| + pow_153 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_510, 1.0); convert_element_type_510 = None | |
| + mul_1136 = torch.ops.aten.mul.Scalar(pow_153, 2.0); pow_153 = None | |
| + mul_1137 = torch.ops.aten.mul.Tensor(div_126, mul_1136); div_126 = mul_1136 = None | |
| + add_480 = torch.ops.aten.add.Tensor(mul_1133, mul_1137); mul_1133 = mul_1137 = None | |
| + convert_element_type_1375 = torch.ops.prims.convert_element_type.default(add_480, torch.bfloat16); add_480 = None | |
| + index_put_45 = torch.ops.aten.index_put.default(full_default_281, [sub_32], convert_element_type_1373, True); full_default_281 = None | |
| slice_scatter_51 = torch.ops.aten.slice_scatter.default(full_default_300, index_put_45, 1, 1536, 9223372036854775807); index_put_45 = None | |
| - permute_808 = torch.ops.aten.permute.default(convert_element_type_1283, [1, 0]) | |
| - convert_element_type_503 = torch.ops.prims.convert_element_type.default(mul_159, torch.bfloat16); mul_159 = None | |
| - mm_313 = torch.ops.aten.mm.default(permute_808, convert_element_type_503); permute_808 = convert_element_type_503 = None | |
| - convert_element_type_504 = torch.ops.prims.convert_element_type.default(primals_332, torch.bfloat16); primals_332 = None | |
| - permute_180 = torch.ops.aten.permute.default(convert_element_type_504, [1, 0]); convert_element_type_504 = None | |
| + permute_808 = torch.ops.aten.permute.default(convert_element_type_1373, [1, 0]) | |
| + mm_313 = torch.ops.aten.mm.default(permute_808, mul_159); permute_808 = mul_159 = None | |
| + convert_element_type_507 = torch.ops.prims.convert_element_type.default(primals_332, torch.bfloat16); primals_332 = None | |
| + permute_180 = torch.ops.aten.permute.default(convert_element_type_507, [1, 0]); convert_element_type_507 = None | |
| permute_810 = torch.ops.aten.permute.default(permute_180, [1, 0]); permute_180 = None | |
| - mm_314 = torch.ops.aten.mm.default(convert_element_type_1283, permute_810); convert_element_type_1283 = permute_810 = None | |
| - convert_element_type_1289 = torch.ops.prims.convert_element_type.default(mm_313, torch.float32); mm_313 = None | |
| - convert_element_type_1290 = torch.ops.prims.convert_element_type.default(mm_314, torch.float32); mm_314 = None | |
| - add_446 = torch.ops.aten.add.Tensor(mul_1040, convert_element_type_1290); mul_1040 = convert_element_type_1290 = None | |
| - slice_scatter_52 = torch.ops.aten.slice_scatter.default(full_default_300, convert_element_type_1284, 1, 0, 1536); full_default_300 = convert_element_type_1284 = None | |
| - add_447 = torch.ops.aten.add.Tensor(slice_scatter_51, slice_scatter_52); slice_scatter_51 = slice_scatter_52 = None | |
| - permute_812 = torch.ops.aten.permute.default(add_447, [1, 0]) | |
| - convert_element_type_499 = torch.ops.prims.convert_element_type.default(mul_158, torch.bfloat16); mul_158 = None | |
| - mm_315 = torch.ops.aten.mm.default(permute_812, convert_element_type_499); permute_812 = convert_element_type_499 = None | |
| - convert_element_type_500 = torch.ops.prims.convert_element_type.default(primals_331, torch.bfloat16); primals_331 = None | |
| - permute_179 = torch.ops.aten.permute.default(convert_element_type_500, [1, 0]); convert_element_type_500 = None | |
| + mm_314 = torch.ops.aten.mm.default(convert_element_type_1373, permute_810); convert_element_type_1373 = permute_810 = None | |
| + add_481 = torch.ops.aten.add.Tensor(convert_element_type_1353, mm_314); convert_element_type_1353 = mm_314 = None | |
| + convert_element_type_1380 = torch.ops.prims.convert_element_type.default(mm_313, torch.float32); mm_313 = None | |
| + slice_scatter_52 = torch.ops.aten.slice_scatter.default(full_default_300, convert_element_type_1375, 1, 0, 1536); full_default_300 = convert_element_type_1375 = None | |
| + add_482 = torch.ops.aten.add.Tensor(slice_scatter_51, slice_scatter_52); slice_scatter_51 = slice_scatter_52 = None | |
| + permute_812 = torch.ops.aten.permute.default(add_482, [1, 0]) | |
| + mm_315 = torch.ops.aten.mm.default(permute_812, mul_158); permute_812 = mul_158 = None | |
| + convert_element_type_504 = torch.ops.prims.convert_element_type.default(primals_331, torch.bfloat16); primals_331 = None | |
| + permute_179 = torch.ops.aten.permute.default(convert_element_type_504, [1, 0]); convert_element_type_504 = None | |
| permute_814 = torch.ops.aten.permute.default(permute_179, [1, 0]); permute_179 = None | |
| - mm_316 = torch.ops.aten.mm.default(add_447, permute_814); add_447 = permute_814 = None | |
| - convert_element_type_1295 = torch.ops.prims.convert_element_type.default(mm_315, torch.float32); mm_315 = None | |
| - convert_element_type_1296 = torch.ops.prims.convert_element_type.default(mm_316, torch.float32); mm_316 = None | |
| - add_448 = torch.ops.aten.add.Tensor(mul_1046, convert_element_type_1296); mul_1046 = convert_element_type_1296 = None | |
| - mul_1068 = torch.ops.aten.mul.Tensor(add_446, mul_157); mul_157 = None | |
| - mul_1069 = torch.ops.aten.mul.Tensor(add_446, sigmoid_16); add_446 = None | |
| - sub_260 = torch.ops.aten.sub.Tensor(1, sigmoid_16) | |
| - mul_1070 = torch.ops.aten.mul.Tensor(sigmoid_16, sub_260); sigmoid_16 = sub_260 = None | |
| - mul_1071 = torch.ops.aten.mul.Tensor(mul_1068, mul_1070); mul_1068 = mul_1070 = None | |
| - add_449 = torch.ops.aten.add.Tensor(mul_1069, mul_1071); mul_1069 = mul_1071 = None | |
| - mul_1072 = torch.ops.aten.mul.Tensor(add_448, mul_155); mul_155 = None | |
| - mul_1073 = torch.ops.aten.mul.Tensor(add_448, sigmoid_15); add_448 = None | |
| - sub_261 = torch.ops.aten.sub.Tensor(1, sigmoid_15) | |
| - mul_1074 = torch.ops.aten.mul.Tensor(sigmoid_15, sub_261); sigmoid_15 = sub_261 = None | |
| - mul_1075 = torch.ops.aten.mul.Tensor(mul_1072, mul_1074); mul_1072 = mul_1074 = None | |
| - add_450 = torch.ops.aten.add.Tensor(mul_1073, mul_1075); mul_1073 = mul_1075 = None | |
| - mul_1076 = torch.ops.aten.mul.Tensor(add_449, primals_330); primals_330 = None | |
| - mul_1078 = torch.ops.aten.mul.Tensor(mul_156, mul_1076) | |
| - sum_174 = torch.ops.aten.sum.dim_IntList(mul_1078, [1], True); mul_1078 = None | |
| - div_127 = torch.ops.aten.div.Tensor(mul_156, 4608) | |
| - mul_1079 = torch.ops.aten.mul.Tensor(div_127, sum_174); div_127 = sum_174 = None | |
| - sub_262 = torch.ops.aten.sub.Tensor(mul_1076, mul_1079); mul_1076 = mul_1079 = None | |
| - mul_1080 = torch.ops.aten.mul.Tensor(sub_262, rsqrt_55); sub_262 = rsqrt_55 = None | |
| - mul_1081 = torch.ops.aten.mul.Tensor(add_449, mul_156); add_449 = mul_156 = None | |
| - sum_175 = torch.ops.aten.sum.dim_IntList(mul_1081, [0]); mul_1081 = None | |
| - mul_1082 = torch.ops.aten.mul.Tensor(add_450, primals_329); primals_329 = None | |
| - mul_1084 = torch.ops.aten.mul.Tensor(mul_154, mul_1082) | |
| - sum_176 = torch.ops.aten.sum.dim_IntList(mul_1084, [1], True); mul_1084 = None | |
| - div_128 = torch.ops.aten.div.Tensor(mul_154, 4608) | |
| - mul_1085 = torch.ops.aten.mul.Tensor(div_128, sum_176); div_128 = sum_176 = None | |
| - sub_263 = torch.ops.aten.sub.Tensor(mul_1082, mul_1085); mul_1082 = mul_1085 = None | |
| - mul_1086 = torch.ops.aten.mul.Tensor(sub_263, rsqrt_54); sub_263 = rsqrt_54 = None | |
| - mul_1087 = torch.ops.aten.mul.Tensor(add_450, mul_154); add_450 = mul_154 = None | |
| - sum_177 = torch.ops.aten.sum.dim_IntList(mul_1087, [0]); mul_1087 = None | |
| - convert_element_type_1297 = torch.ops.prims.convert_element_type.default(mul_1080, torch.bfloat16) | |
| - convert_element_type_1298 = torch.ops.prims.convert_element_type.default(mul_1086, torch.bfloat16) | |
| - index_put_46 = torch.ops.aten.index_put.default(full_default_296, [sub_32], convert_element_type_1297, True) | |
| + mm_316 = torch.ops.aten.mm.default(add_482, permute_814); add_482 = permute_814 = None | |
| + add_483 = torch.ops.aten.add.Tensor(convert_element_type_1355, mm_316); convert_element_type_1355 = mm_316 = None | |
| + convert_element_type_1385 = torch.ops.prims.convert_element_type.default(mm_315, torch.float32); mm_315 = None | |
| + mul_1138 = torch.ops.aten.mul.Tensor(add_481, convert_element_type_503); convert_element_type_503 = None | |
| + mul_1139 = torch.ops.aten.mul.Tensor(add_481, sigmoid_16); add_481 = None | |
| + convert_element_type_1386 = torch.ops.prims.convert_element_type.default(mul_1138, torch.float32); mul_1138 = None | |
| + convert_element_type_1387 = torch.ops.prims.convert_element_type.default(sigmoid_16, torch.float32); sigmoid_16 = None | |
| + sub_225 = torch.ops.aten.sub.Tensor(1, convert_element_type_1387) | |
| + mul_1140 = torch.ops.aten.mul.Tensor(convert_element_type_1387, sub_225); convert_element_type_1387 = sub_225 = None | |
| + mul_1141 = torch.ops.aten.mul.Tensor(convert_element_type_1386, mul_1140); convert_element_type_1386 = mul_1140 = None | |
| + convert_element_type_1388 = torch.ops.prims.convert_element_type.default(mul_1141, torch.bfloat16); mul_1141 = None | |
| + add_484 = torch.ops.aten.add.Tensor(mul_1139, convert_element_type_1388); mul_1139 = convert_element_type_1388 = None | |
| + mul_1142 = torch.ops.aten.mul.Tensor(add_483, convert_element_type_501); convert_element_type_501 = None | |
| + mul_1143 = torch.ops.aten.mul.Tensor(add_483, sigmoid_15); add_483 = None | |
| + convert_element_type_1389 = torch.ops.prims.convert_element_type.default(mul_1142, torch.float32); mul_1142 = None | |
| + convert_element_type_1390 = torch.ops.prims.convert_element_type.default(sigmoid_15, torch.float32); sigmoid_15 = None | |
| + sub_226 = torch.ops.aten.sub.Tensor(1, convert_element_type_1390) | |
| + mul_1144 = torch.ops.aten.mul.Tensor(convert_element_type_1390, sub_226); convert_element_type_1390 = sub_226 = None | |
| + mul_1145 = torch.ops.aten.mul.Tensor(convert_element_type_1389, mul_1144); convert_element_type_1389 = mul_1144 = None | |
| + convert_element_type_1391 = torch.ops.prims.convert_element_type.default(mul_1145, torch.bfloat16); mul_1145 = None | |
| + add_485 = torch.ops.aten.add.Tensor(mul_1143, convert_element_type_1391); mul_1143 = convert_element_type_1391 = None | |
| + convert_element_type_1392 = torch.ops.prims.convert_element_type.default(add_484, torch.float32); add_484 = None | |
| + mul_1146 = torch.ops.aten.mul.Tensor(convert_element_type_1392, mul_156); mul_156 = None | |
| + mul_1147 = torch.ops.aten.mul.Tensor(convert_element_type_1392, primals_330); convert_element_type_1392 = primals_330 = None | |
| + sum_174 = torch.ops.aten.sum.dim_IntList(mul_1146, [0], True); mul_1146 = None | |
| + view_560 = torch.ops.aten.view.default(sum_174, [4608]); sum_174 = None | |
| + mul_1148 = torch.ops.aten.mul.Tensor(mul_1147, convert_element_type_502) | |
| + mul_1149 = torch.ops.aten.mul.Tensor(mul_1147, rsqrt_55); mul_1147 = None | |
| + sum_175 = torch.ops.aten.sum.dim_IntList(mul_1148, [1], True); mul_1148 = None | |
| + mul_1150 = torch.ops.aten.mul.Scalar(sum_175, -0.5); sum_175 = None | |
| + pow_154 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_55, 3); rsqrt_55 = None | |
| + mul_1151 = torch.ops.aten.mul.Tensor(mul_1150, pow_154); mul_1150 = pow_154 = None | |
| + expand_98 = torch.ops.aten.expand.default(mul_1151, [4096, 4608]); mul_1151 = None | |
| + div_127 = torch.ops.aten.div.Scalar(expand_98, 4608); expand_98 = None | |
| + pow_155 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_502, 1.0); convert_element_type_502 = None | |
| + mul_1152 = torch.ops.aten.mul.Scalar(pow_155, 2.0); pow_155 = None | |
| + mul_1153 = torch.ops.aten.mul.Tensor(div_127, mul_1152); div_127 = mul_1152 = None | |
| + add_486 = torch.ops.aten.add.Tensor(mul_1149, mul_1153); mul_1149 = mul_1153 = None | |
| + convert_element_type_1393 = torch.ops.prims.convert_element_type.default(add_486, torch.bfloat16); add_486 = None | |
| + convert_element_type_1394 = torch.ops.prims.convert_element_type.default(add_485, torch.float32); add_485 = None | |
| + mul_1154 = torch.ops.aten.mul.Tensor(convert_element_type_1394, mul_154); mul_154 = None | |
| + mul_1155 = torch.ops.aten.mul.Tensor(convert_element_type_1394, primals_329); convert_element_type_1394 = primals_329 = None | |
| + sum_176 = torch.ops.aten.sum.dim_IntList(mul_1154, [0], True); mul_1154 = None | |
| + view_561 = torch.ops.aten.view.default(sum_176, [4608]); sum_176 = None | |
| + mul_1156 = torch.ops.aten.mul.Tensor(mul_1155, convert_element_type_500) | |
| + mul_1157 = torch.ops.aten.mul.Tensor(mul_1155, rsqrt_54); mul_1155 = None | |
| + sum_177 = torch.ops.aten.sum.dim_IntList(mul_1156, [1], True); mul_1156 = None | |
| + mul_1158 = torch.ops.aten.mul.Scalar(sum_177, -0.5); sum_177 = None | |
| + pow_156 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_54, 3); rsqrt_54 = None | |
| + mul_1159 = torch.ops.aten.mul.Tensor(mul_1158, pow_156); mul_1158 = pow_156 = None | |
| + expand_99 = torch.ops.aten.expand.default(mul_1159, [4096, 4608]); mul_1159 = None | |
| + div_128 = torch.ops.aten.div.Scalar(expand_99, 4608); expand_99 = None | |
| + pow_157 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_500, 1.0); convert_element_type_500 = None | |
| + mul_1160 = torch.ops.aten.mul.Scalar(pow_157, 2.0); pow_157 = None | |
| + mul_1161 = torch.ops.aten.mul.Tensor(div_128, mul_1160); div_128 = mul_1160 = None | |
| + add_487 = torch.ops.aten.add.Tensor(mul_1157, mul_1161); mul_1157 = mul_1161 = None | |
| + convert_element_type_1395 = torch.ops.prims.convert_element_type.default(add_487, torch.bfloat16); add_487 = None | |
| + index_put_46 = torch.ops.aten.index_put.default(full_default_296, [sub_32], convert_element_type_1393, True) | |
| slice_scatter_53 = torch.ops.aten.slice_scatter.default(full_default_297, index_put_46, 1, 4608, 9223372036854775807); index_put_46 = None | |
| - permute_816 = torch.ops.aten.permute.default(convert_element_type_1297, [1, 0]) | |
| + permute_816 = torch.ops.aten.permute.default(convert_element_type_1393, [1, 0]) | |
| slice_22 = torch.ops.aten.slice.Tensor(mm_59, 1, 2304, 9223372036854775807) | |
| index_15 = torch.ops.aten.index.Tensor(slice_22, [sub_32]); slice_22 = None | |
| add_120 = torch.ops.aten.add.Tensor(mm_60, index_15); mm_60 = index_15 = None | |
| - convert_element_type_490 = torch.ops.prims.convert_element_type.default(add_120, torch.float32); add_120 = None | |
| - pow_33 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_490, 2) | |
| + convert_element_type_492 = torch.ops.prims.convert_element_type.default(add_120, torch.float32); add_120 = None | |
| + pow_33 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_492, 2) | |
| mean_32 = torch.ops.aten.mean.dim(pow_33, [1], True); pow_33 = None | |
| add_122 = torch.ops.aten.add.Scalar(mean_32, 1.1920928955078125e-07); mean_32 = None | |
| rsqrt_53 = torch.ops.aten.rsqrt.default(add_122); add_122 = None | |
| - mul_150 = torch.ops.aten.mul.Tensor(convert_element_type_490, rsqrt_53); convert_element_type_490 = None | |
| + mul_150 = torch.ops.aten.mul.Tensor(convert_element_type_492, rsqrt_53) | |
| mul_151 = torch.ops.aten.mul.Tensor(mul_150, primals_326) | |
| - sigmoid_14 = torch.ops.aten.sigmoid.default(mul_151) | |
| - mul_153 = torch.ops.aten.mul.Tensor(mul_151, sigmoid_14) | |
| - convert_element_type_495 = torch.ops.prims.convert_element_type.default(mul_153, torch.bfloat16); mul_153 = None | |
| - mm_317 = torch.ops.aten.mm.default(permute_816, convert_element_type_495); permute_816 = convert_element_type_495 = None | |
| - convert_element_type_496 = torch.ops.prims.convert_element_type.default(primals_328, torch.bfloat16); primals_328 = None | |
| - permute_178 = torch.ops.aten.permute.default(convert_element_type_496, [1, 0]); convert_element_type_496 = None | |
| + convert_element_type_493 = torch.ops.prims.convert_element_type.default(mul_151, torch.bfloat16); mul_151 = None | |
| + sigmoid_14 = torch.ops.aten.sigmoid.default(convert_element_type_493) | |
| + mul_153 = torch.ops.aten.mul.Tensor(convert_element_type_493, sigmoid_14) | |
| + mm_317 = torch.ops.aten.mm.default(permute_816, mul_153); permute_816 = mul_153 = None | |
| + convert_element_type_497 = torch.ops.prims.convert_element_type.default(primals_328, torch.bfloat16); primals_328 = None | |
| + permute_178 = torch.ops.aten.permute.default(convert_element_type_497, [1, 0]); convert_element_type_497 = None | |
| permute_818 = torch.ops.aten.permute.default(permute_178, [1, 0]); permute_178 = None | |
| - mm_318 = torch.ops.aten.mm.default(convert_element_type_1297, permute_818); convert_element_type_1297 = permute_818 = None | |
| - convert_element_type_1303 = torch.ops.prims.convert_element_type.default(mm_317, torch.float32); mm_317 = None | |
| - convert_element_type_1304 = torch.ops.prims.convert_element_type.default(mm_318, torch.float32); mm_318 = None | |
| - slice_scatter_54 = torch.ops.aten.slice_scatter.default(full_default_297, convert_element_type_1298, 1, 0, 4608); convert_element_type_1298 = None | |
| - add_451 = torch.ops.aten.add.Tensor(slice_scatter_53, slice_scatter_54); slice_scatter_53 = slice_scatter_54 = None | |
| - permute_820 = torch.ops.aten.permute.default(add_451, [1, 0]) | |
| + mm_318 = torch.ops.aten.mm.default(convert_element_type_1393, permute_818); permute_818 = None | |
| + convert_element_type_1400 = torch.ops.prims.convert_element_type.default(mm_317, torch.float32); mm_317 = None | |
| + slice_scatter_54 = torch.ops.aten.slice_scatter.default(full_default_297, convert_element_type_1395, 1, 0, 4608) | |
| + add_488 = torch.ops.aten.add.Tensor(slice_scatter_53, slice_scatter_54); slice_scatter_53 = slice_scatter_54 = None | |
| + permute_820 = torch.ops.aten.permute.default(add_488, [1, 0]) | |
| slice_21 = torch.ops.aten.slice.Tensor(mm_59, 1, 0, 2304); mm_59 = None | |
| - convert_element_type_489 = torch.ops.prims.convert_element_type.default(slice_21, torch.float32); slice_21 = None | |
| - pow_32 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_489, 2) | |
| + convert_element_type_490 = torch.ops.prims.convert_element_type.default(slice_21, torch.float32); slice_21 = None | |
| + pow_32 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_490, 2) | |
| mean_31 = torch.ops.aten.mean.dim(pow_32, [1], True); pow_32 = None | |
| add_121 = torch.ops.aten.add.Scalar(mean_31, 1.1920928955078125e-07); mean_31 = None | |
| rsqrt_52 = torch.ops.aten.rsqrt.default(add_121); add_121 = None | |
| - mul_148 = torch.ops.aten.mul.Tensor(convert_element_type_489, rsqrt_52); convert_element_type_489 = None | |
| + mul_148 = torch.ops.aten.mul.Tensor(convert_element_type_490, rsqrt_52) | |
| mul_149 = torch.ops.aten.mul.Tensor(mul_148, primals_325) | |
| - sigmoid_13 = torch.ops.aten.sigmoid.default(mul_149) | |
| - mul_152 = torch.ops.aten.mul.Tensor(mul_149, sigmoid_13) | |
| - convert_element_type_491 = torch.ops.prims.convert_element_type.default(mul_152, torch.bfloat16); mul_152 = None | |
| - mm_319 = torch.ops.aten.mm.default(permute_820, convert_element_type_491); permute_820 = convert_element_type_491 = None | |
| - convert_element_type_492 = torch.ops.prims.convert_element_type.default(primals_327, torch.bfloat16); primals_327 = None | |
| - permute_177 = torch.ops.aten.permute.default(convert_element_type_492, [1, 0]); convert_element_type_492 = None | |
| + convert_element_type_491 = torch.ops.prims.convert_element_type.default(mul_149, torch.bfloat16); mul_149 = None | |
| + sigmoid_13 = torch.ops.aten.sigmoid.default(convert_element_type_491) | |
| + mul_152 = torch.ops.aten.mul.Tensor(convert_element_type_491, sigmoid_13) | |
| + mm_319 = torch.ops.aten.mm.default(permute_820, mul_152); permute_820 = mul_152 = None | |
| + convert_element_type_494 = torch.ops.prims.convert_element_type.default(primals_327, torch.bfloat16); primals_327 = None | |
| + permute_177 = torch.ops.aten.permute.default(convert_element_type_494, [1, 0]); convert_element_type_494 = None | |
| permute_822 = torch.ops.aten.permute.default(permute_177, [1, 0]); permute_177 = None | |
| - mm_320 = torch.ops.aten.mm.default(add_451, permute_822); add_451 = permute_822 = None | |
| - convert_element_type_1309 = torch.ops.prims.convert_element_type.default(mm_319, torch.float32); mm_319 = None | |
| - convert_element_type_1310 = torch.ops.prims.convert_element_type.default(mm_320, torch.float32); mm_320 = None | |
| - mul_1088 = torch.ops.aten.mul.Tensor(convert_element_type_1304, mul_151); mul_151 = None | |
| - mul_1089 = torch.ops.aten.mul.Tensor(convert_element_type_1304, sigmoid_14); convert_element_type_1304 = None | |
| - sub_264 = torch.ops.aten.sub.Tensor(1, sigmoid_14) | |
| - mul_1090 = torch.ops.aten.mul.Tensor(sigmoid_14, sub_264); sigmoid_14 = sub_264 = None | |
| - mul_1091 = torch.ops.aten.mul.Tensor(mul_1088, mul_1090); mul_1088 = mul_1090 = None | |
| - add_452 = torch.ops.aten.add.Tensor(mul_1089, mul_1091); mul_1089 = mul_1091 = None | |
| - mul_1092 = torch.ops.aten.mul.Tensor(convert_element_type_1310, mul_149); mul_149 = None | |
| - mul_1093 = torch.ops.aten.mul.Tensor(convert_element_type_1310, sigmoid_13); convert_element_type_1310 = None | |
| - sub_265 = torch.ops.aten.sub.Tensor(1, sigmoid_13) | |
| - mul_1094 = torch.ops.aten.mul.Tensor(sigmoid_13, sub_265); sigmoid_13 = sub_265 = None | |
| - mul_1095 = torch.ops.aten.mul.Tensor(mul_1092, mul_1094); mul_1092 = mul_1094 = None | |
| - add_453 = torch.ops.aten.add.Tensor(mul_1093, mul_1095); mul_1093 = mul_1095 = None | |
| - mul_1096 = torch.ops.aten.mul.Tensor(add_452, primals_326); primals_326 = None | |
| - mul_1098 = torch.ops.aten.mul.Tensor(mul_150, mul_1096) | |
| - sum_178 = torch.ops.aten.sum.dim_IntList(mul_1098, [1], True); mul_1098 = None | |
| - div_129 = torch.ops.aten.div.Tensor(mul_150, 2304) | |
| - mul_1099 = torch.ops.aten.mul.Tensor(div_129, sum_178); div_129 = sum_178 = None | |
| - sub_266 = torch.ops.aten.sub.Tensor(mul_1096, mul_1099); mul_1096 = mul_1099 = None | |
| - mul_1100 = torch.ops.aten.mul.Tensor(sub_266, rsqrt_53); sub_266 = rsqrt_53 = None | |
| - mul_1101 = torch.ops.aten.mul.Tensor(add_452, mul_150); add_452 = mul_150 = None | |
| - sum_179 = torch.ops.aten.sum.dim_IntList(mul_1101, [0]); mul_1101 = None | |
| - convert_element_type_1311 = torch.ops.prims.convert_element_type.default(mul_1100, torch.bfloat16); mul_1100 = None | |
| - mul_1102 = torch.ops.aten.mul.Tensor(add_453, primals_325); primals_325 = None | |
| - mul_1104 = torch.ops.aten.mul.Tensor(mul_148, mul_1102) | |
| - sum_180 = torch.ops.aten.sum.dim_IntList(mul_1104, [1], True); mul_1104 = None | |
| - div_130 = torch.ops.aten.div.Tensor(mul_148, 2304) | |
| - mul_1105 = torch.ops.aten.mul.Tensor(div_130, sum_180); div_130 = sum_180 = None | |
| - sub_267 = torch.ops.aten.sub.Tensor(mul_1102, mul_1105); mul_1102 = mul_1105 = None | |
| - mul_1106 = torch.ops.aten.mul.Tensor(sub_267, rsqrt_52); sub_267 = rsqrt_52 = None | |
| - mul_1107 = torch.ops.aten.mul.Tensor(add_453, mul_148); add_453 = mul_148 = None | |
| - sum_181 = torch.ops.aten.sum.dim_IntList(mul_1107, [0]); mul_1107 = None | |
| - convert_element_type_1312 = torch.ops.prims.convert_element_type.default(mul_1106, torch.bfloat16); mul_1106 = None | |
| - index_put_47 = torch.ops.aten.index_put.default(full_default_305, [sub_32], convert_element_type_1311, True); full_default_305 = None | |
| + mm_320 = torch.ops.aten.mm.default(add_488, permute_822); add_488 = permute_822 = None | |
| + convert_element_type_1405 = torch.ops.prims.convert_element_type.default(mm_319, torch.float32); mm_319 = None | |
| + mul_1162 = torch.ops.aten.mul.Tensor(mm_318, convert_element_type_493); convert_element_type_493 = None | |
| + mul_1163 = torch.ops.aten.mul.Tensor(mm_318, sigmoid_14); mm_318 = None | |
| + convert_element_type_1406 = torch.ops.prims.convert_element_type.default(mul_1162, torch.float32); mul_1162 = None | |
| + convert_element_type_1407 = torch.ops.prims.convert_element_type.default(sigmoid_14, torch.float32); sigmoid_14 = None | |
| + sub_227 = torch.ops.aten.sub.Tensor(1, convert_element_type_1407) | |
| + mul_1164 = torch.ops.aten.mul.Tensor(convert_element_type_1407, sub_227); convert_element_type_1407 = sub_227 = None | |
| + mul_1165 = torch.ops.aten.mul.Tensor(convert_element_type_1406, mul_1164); convert_element_type_1406 = mul_1164 = None | |
| + convert_element_type_1408 = torch.ops.prims.convert_element_type.default(mul_1165, torch.bfloat16); mul_1165 = None | |
| + add_489 = torch.ops.aten.add.Tensor(mul_1163, convert_element_type_1408); mul_1163 = convert_element_type_1408 = None | |
| + mul_1166 = torch.ops.aten.mul.Tensor(mm_320, convert_element_type_491); convert_element_type_491 = None | |
| + mul_1167 = torch.ops.aten.mul.Tensor(mm_320, sigmoid_13); mm_320 = None | |
| + convert_element_type_1409 = torch.ops.prims.convert_element_type.default(mul_1166, torch.float32); mul_1166 = None | |
| + convert_element_type_1410 = torch.ops.prims.convert_element_type.default(sigmoid_13, torch.float32); sigmoid_13 = None | |
| + sub_228 = torch.ops.aten.sub.Tensor(1, convert_element_type_1410) | |
| + mul_1168 = torch.ops.aten.mul.Tensor(convert_element_type_1410, sub_228); convert_element_type_1410 = sub_228 = None | |
| + mul_1169 = torch.ops.aten.mul.Tensor(convert_element_type_1409, mul_1168); convert_element_type_1409 = mul_1168 = None | |
| + convert_element_type_1411 = torch.ops.prims.convert_element_type.default(mul_1169, torch.bfloat16); mul_1169 = None | |
| + add_490 = torch.ops.aten.add.Tensor(mul_1167, convert_element_type_1411); mul_1167 = convert_element_type_1411 = None | |
| + convert_element_type_1412 = torch.ops.prims.convert_element_type.default(add_489, torch.float32); add_489 = None | |
| + mul_1170 = torch.ops.aten.mul.Tensor(convert_element_type_1412, mul_150); mul_150 = None | |
| + mul_1171 = torch.ops.aten.mul.Tensor(convert_element_type_1412, primals_326); convert_element_type_1412 = primals_326 = None | |
| + sum_178 = torch.ops.aten.sum.dim_IntList(mul_1170, [0], True); mul_1170 = None | |
| + view_562 = torch.ops.aten.view.default(sum_178, [2304]); sum_178 = None | |
| + mul_1172 = torch.ops.aten.mul.Tensor(mul_1171, convert_element_type_492) | |
| + mul_1173 = torch.ops.aten.mul.Tensor(mul_1171, rsqrt_53); mul_1171 = None | |
| + sum_179 = torch.ops.aten.sum.dim_IntList(mul_1172, [1], True); mul_1172 = None | |
| + mul_1174 = torch.ops.aten.mul.Scalar(sum_179, -0.5); sum_179 = None | |
| + pow_158 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_53, 3); rsqrt_53 = None | |
| + mul_1175 = torch.ops.aten.mul.Tensor(mul_1174, pow_158); mul_1174 = pow_158 = None | |
| + expand_100 = torch.ops.aten.expand.default(mul_1175, [4096, 2304]); mul_1175 = None | |
| + div_129 = torch.ops.aten.div.Scalar(expand_100, 2304); expand_100 = None | |
| + pow_159 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_492, 1.0); convert_element_type_492 = None | |
| + mul_1176 = torch.ops.aten.mul.Scalar(pow_159, 2.0); pow_159 = None | |
| + mul_1177 = torch.ops.aten.mul.Tensor(div_129, mul_1176); div_129 = mul_1176 = None | |
| + add_491 = torch.ops.aten.add.Tensor(mul_1173, mul_1177); mul_1173 = mul_1177 = None | |
| + convert_element_type_1413 = torch.ops.prims.convert_element_type.default(add_491, torch.bfloat16); add_491 = None | |
| + convert_element_type_1414 = torch.ops.prims.convert_element_type.default(add_490, torch.float32); add_490 = None | |
| + mul_1178 = torch.ops.aten.mul.Tensor(convert_element_type_1414, mul_148); mul_148 = None | |
| + mul_1179 = torch.ops.aten.mul.Tensor(convert_element_type_1414, primals_325); convert_element_type_1414 = primals_325 = None | |
| + sum_180 = torch.ops.aten.sum.dim_IntList(mul_1178, [0], True); mul_1178 = None | |
| + view_563 = torch.ops.aten.view.default(sum_180, [2304]); sum_180 = None | |
| + mul_1180 = torch.ops.aten.mul.Tensor(mul_1179, convert_element_type_490) | |
| + mul_1181 = torch.ops.aten.mul.Tensor(mul_1179, rsqrt_52); mul_1179 = None | |
| + sum_181 = torch.ops.aten.sum.dim_IntList(mul_1180, [1], True); mul_1180 = None | |
| + mul_1182 = torch.ops.aten.mul.Scalar(sum_181, -0.5); sum_181 = None | |
| + pow_160 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_52, 3); rsqrt_52 = None | |
| + mul_1183 = torch.ops.aten.mul.Tensor(mul_1182, pow_160); mul_1182 = pow_160 = None | |
| + expand_101 = torch.ops.aten.expand.default(mul_1183, [4096, 2304]); mul_1183 = None | |
| + div_130 = torch.ops.aten.div.Scalar(expand_101, 2304); expand_101 = None | |
| + pow_161 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_490, 1.0); convert_element_type_490 = None | |
| + mul_1184 = torch.ops.aten.mul.Scalar(pow_161, 2.0); pow_161 = None | |
| + mul_1185 = torch.ops.aten.mul.Tensor(div_130, mul_1184); div_130 = mul_1184 = None | |
| + add_492 = torch.ops.aten.add.Tensor(mul_1181, mul_1185); mul_1181 = mul_1185 = None | |
| + convert_element_type_1415 = torch.ops.prims.convert_element_type.default(add_492, torch.bfloat16); add_492 = None | |
| + index_put_47 = torch.ops.aten.index_put.default(full_default_305, [sub_32], convert_element_type_1413, True); full_default_305 = None | |
| slice_scatter_55 = torch.ops.aten.slice_scatter.default(full_default_296, index_put_47, 1, 2304, 9223372036854775807); index_put_47 = None | |
| - permute_824 = torch.ops.aten.permute.default(convert_element_type_1311, [1, 0]) | |
| - convert_element_type_480 = torch.ops.prims.convert_element_type.default(add_117, torch.float32); add_117 = None | |
| - pow_31 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_480, 2) | |
| + permute_824 = torch.ops.aten.permute.default(convert_element_type_1413, [1, 0]) | |
| + convert_element_type_482 = torch.ops.prims.convert_element_type.default(add_117, torch.float32); add_117 = None | |
| + pow_31 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_482, 2) | |
| mean_30 = torch.ops.aten.mean.dim(pow_31, [1], True); pow_31 = None | |
| add_119 = torch.ops.aten.add.Scalar(mean_30, 1.1920928955078125e-07); mean_30 = None | |
| rsqrt_51 = torch.ops.aten.rsqrt.default(add_119); add_119 = None | |
| - mul_144 = torch.ops.aten.mul.Tensor(convert_element_type_480, rsqrt_51); convert_element_type_480 = None | |
| + mul_144 = torch.ops.aten.mul.Tensor(convert_element_type_482, rsqrt_51) | |
| mul_145 = torch.ops.aten.mul.Tensor(mul_144, primals_322) | |
| - sigmoid_12 = torch.ops.aten.sigmoid.default(mul_145) | |
| - mul_147 = torch.ops.aten.mul.Tensor(mul_145, sigmoid_12) | |
| - convert_element_type_485 = torch.ops.prims.convert_element_type.default(mul_147, torch.bfloat16); mul_147 = None | |
| - mm_321 = torch.ops.aten.mm.default(permute_824, convert_element_type_485); permute_824 = convert_element_type_485 = None | |
| - convert_element_type_486 = torch.ops.prims.convert_element_type.default(primals_324, torch.bfloat16); primals_324 = None | |
| - permute_176 = torch.ops.aten.permute.default(convert_element_type_486, [1, 0]); convert_element_type_486 = None | |
| + convert_element_type_483 = torch.ops.prims.convert_element_type.default(mul_145, torch.bfloat16); mul_145 = None | |
| + sigmoid_12 = torch.ops.aten.sigmoid.default(convert_element_type_483) | |
| + mul_147 = torch.ops.aten.mul.Tensor(convert_element_type_483, sigmoid_12) | |
| + mm_321 = torch.ops.aten.mm.default(permute_824, mul_147); permute_824 = mul_147 = None | |
| + convert_element_type_487 = torch.ops.prims.convert_element_type.default(primals_324, torch.bfloat16); primals_324 = None | |
| + permute_176 = torch.ops.aten.permute.default(convert_element_type_487, [1, 0]); convert_element_type_487 = None | |
| permute_826 = torch.ops.aten.permute.default(permute_176, [1, 0]); permute_176 = None | |
| - mm_322 = torch.ops.aten.mm.default(convert_element_type_1311, permute_826); convert_element_type_1311 = permute_826 = None | |
| - convert_element_type_1317 = torch.ops.prims.convert_element_type.default(mm_321, torch.float32); mm_321 = None | |
| - convert_element_type_1318 = torch.ops.prims.convert_element_type.default(mm_322, torch.float32); mm_322 = None | |
| - add_454 = torch.ops.aten.add.Tensor(mul_1080, convert_element_type_1318); mul_1080 = convert_element_type_1318 = None | |
| - slice_scatter_56 = torch.ops.aten.slice_scatter.default(full_default_296, convert_element_type_1312, 1, 0, 2304); convert_element_type_1312 = None | |
| - add_455 = torch.ops.aten.add.Tensor(slice_scatter_55, slice_scatter_56); slice_scatter_55 = slice_scatter_56 = None | |
| - permute_828 = torch.ops.aten.permute.default(add_455, [1, 0]) | |
| - convert_element_type_479 = torch.ops.prims.convert_element_type.default(slice_19, torch.float32); slice_19 = None | |
| - pow_30 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_479, 2) | |
| + mm_322 = torch.ops.aten.mm.default(convert_element_type_1413, permute_826); convert_element_type_1413 = permute_826 = None | |
| + add_493 = torch.ops.aten.add.Tensor(convert_element_type_1393, mm_322); convert_element_type_1393 = mm_322 = None | |
| + convert_element_type_1420 = torch.ops.prims.convert_element_type.default(mm_321, torch.float32); mm_321 = None | |
| + slice_scatter_56 = torch.ops.aten.slice_scatter.default(full_default_296, convert_element_type_1415, 1, 0, 2304); convert_element_type_1415 = None | |
| + add_494 = torch.ops.aten.add.Tensor(slice_scatter_55, slice_scatter_56); slice_scatter_55 = slice_scatter_56 = None | |
| + permute_828 = torch.ops.aten.permute.default(add_494, [1, 0]) | |
| + convert_element_type_480 = torch.ops.prims.convert_element_type.default(slice_19, torch.float32); slice_19 = None | |
| + pow_30 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_480, 2) | |
| mean_29 = torch.ops.aten.mean.dim(pow_30, [1], True); pow_30 = None | |
| add_118 = torch.ops.aten.add.Scalar(mean_29, 1.1920928955078125e-07); mean_29 = None | |
| rsqrt_50 = torch.ops.aten.rsqrt.default(add_118); add_118 = None | |
| - mul_142 = torch.ops.aten.mul.Tensor(convert_element_type_479, rsqrt_50); convert_element_type_479 = None | |
| + mul_142 = torch.ops.aten.mul.Tensor(convert_element_type_480, rsqrt_50) | |
| mul_143 = torch.ops.aten.mul.Tensor(mul_142, primals_321) | |
| - sigmoid_11 = torch.ops.aten.sigmoid.default(mul_143) | |
| - mul_146 = torch.ops.aten.mul.Tensor(mul_143, sigmoid_11) | |
| - convert_element_type_481 = torch.ops.prims.convert_element_type.default(mul_146, torch.bfloat16); mul_146 = None | |
| - mm_323 = torch.ops.aten.mm.default(permute_828, convert_element_type_481); permute_828 = convert_element_type_481 = None | |
| - convert_element_type_482 = torch.ops.prims.convert_element_type.default(primals_323, torch.bfloat16); primals_323 = None | |
| - permute_175 = torch.ops.aten.permute.default(convert_element_type_482, [1, 0]); convert_element_type_482 = None | |
| + convert_element_type_481 = torch.ops.prims.convert_element_type.default(mul_143, torch.bfloat16); mul_143 = None | |
| + sigmoid_11 = torch.ops.aten.sigmoid.default(convert_element_type_481) | |
| + mul_146 = torch.ops.aten.mul.Tensor(convert_element_type_481, sigmoid_11) | |
| + mm_323 = torch.ops.aten.mm.default(permute_828, mul_146); permute_828 = mul_146 = None | |
| + convert_element_type_484 = torch.ops.prims.convert_element_type.default(primals_323, torch.bfloat16); primals_323 = None | |
| + permute_175 = torch.ops.aten.permute.default(convert_element_type_484, [1, 0]); convert_element_type_484 = None | |
| permute_830 = torch.ops.aten.permute.default(permute_175, [1, 0]); permute_175 = None | |
| - mm_324 = torch.ops.aten.mm.default(add_455, permute_830); add_455 = permute_830 = None | |
| - convert_element_type_1323 = torch.ops.prims.convert_element_type.default(mm_323, torch.float32); mm_323 = None | |
| - convert_element_type_1324 = torch.ops.prims.convert_element_type.default(mm_324, torch.float32); mm_324 = None | |
| - add_456 = torch.ops.aten.add.Tensor(mul_1086, convert_element_type_1324); mul_1086 = convert_element_type_1324 = None | |
| - mul_1108 = torch.ops.aten.mul.Tensor(add_454, mul_145); mul_145 = None | |
| - mul_1109 = torch.ops.aten.mul.Tensor(add_454, sigmoid_12); add_454 = None | |
| - sub_268 = torch.ops.aten.sub.Tensor(1, sigmoid_12) | |
| - mul_1110 = torch.ops.aten.mul.Tensor(sigmoid_12, sub_268); sigmoid_12 = sub_268 = None | |
| - mul_1111 = torch.ops.aten.mul.Tensor(mul_1108, mul_1110); mul_1108 = mul_1110 = None | |
| - add_457 = torch.ops.aten.add.Tensor(mul_1109, mul_1111); mul_1109 = mul_1111 = None | |
| - mul_1112 = torch.ops.aten.mul.Tensor(add_456, mul_143); mul_143 = None | |
| - mul_1113 = torch.ops.aten.mul.Tensor(add_456, sigmoid_11); add_456 = None | |
| - sub_269 = torch.ops.aten.sub.Tensor(1, sigmoid_11) | |
| - mul_1114 = torch.ops.aten.mul.Tensor(sigmoid_11, sub_269); sigmoid_11 = sub_269 = None | |
| - mul_1115 = torch.ops.aten.mul.Tensor(mul_1112, mul_1114); mul_1112 = mul_1114 = None | |
| - add_458 = torch.ops.aten.add.Tensor(mul_1113, mul_1115); mul_1113 = mul_1115 = None | |
| - mul_1116 = torch.ops.aten.mul.Tensor(add_457, primals_322); primals_322 = None | |
| - mul_1118 = torch.ops.aten.mul.Tensor(mul_144, mul_1116) | |
| - sum_182 = torch.ops.aten.sum.dim_IntList(mul_1118, [1], True); mul_1118 = None | |
| - div_131 = torch.ops.aten.div.Tensor(mul_144, 4608) | |
| - mul_1119 = torch.ops.aten.mul.Tensor(div_131, sum_182); div_131 = sum_182 = None | |
| - sub_270 = torch.ops.aten.sub.Tensor(mul_1116, mul_1119); mul_1116 = mul_1119 = None | |
| - mul_1120 = torch.ops.aten.mul.Tensor(sub_270, rsqrt_51); sub_270 = rsqrt_51 = None | |
| - mul_1121 = torch.ops.aten.mul.Tensor(add_457, mul_144); add_457 = mul_144 = None | |
| - sum_183 = torch.ops.aten.sum.dim_IntList(mul_1121, [0]); mul_1121 = None | |
| - convert_element_type_1325 = torch.ops.prims.convert_element_type.default(mul_1120, torch.bfloat16); mul_1120 = None | |
| - mul_1122 = torch.ops.aten.mul.Tensor(add_458, primals_321); primals_321 = None | |
| - mul_1124 = torch.ops.aten.mul.Tensor(mul_142, mul_1122) | |
| - sum_184 = torch.ops.aten.sum.dim_IntList(mul_1124, [1], True); mul_1124 = None | |
| - div_132 = torch.ops.aten.div.Tensor(mul_142, 4608) | |
| - mul_1125 = torch.ops.aten.mul.Tensor(div_132, sum_184); div_132 = sum_184 = None | |
| - sub_271 = torch.ops.aten.sub.Tensor(mul_1122, mul_1125); mul_1122 = mul_1125 = None | |
| - mul_1126 = torch.ops.aten.mul.Tensor(sub_271, rsqrt_50); sub_271 = rsqrt_50 = None | |
| - mul_1127 = torch.ops.aten.mul.Tensor(add_458, mul_142); add_458 = mul_142 = None | |
| - sum_185 = torch.ops.aten.sum.dim_IntList(mul_1127, [0]); mul_1127 = None | |
| - convert_element_type_1326 = torch.ops.prims.convert_element_type.default(mul_1126, torch.bfloat16); mul_1126 = None | |
| - index_put_48 = torch.ops.aten.index_put.default(full_default_296, [sub_32], convert_element_type_1325, True); full_default_296 = None | |
| + mm_324 = torch.ops.aten.mm.default(add_494, permute_830); add_494 = permute_830 = None | |
| + add_495 = torch.ops.aten.add.Tensor(convert_element_type_1395, mm_324); convert_element_type_1395 = mm_324 = None | |
| + convert_element_type_1425 = torch.ops.prims.convert_element_type.default(mm_323, torch.float32); mm_323 = None | |
| + mul_1186 = torch.ops.aten.mul.Tensor(add_493, convert_element_type_483); convert_element_type_483 = None | |
| + mul_1187 = torch.ops.aten.mul.Tensor(add_493, sigmoid_12); add_493 = None | |
| + convert_element_type_1426 = torch.ops.prims.convert_element_type.default(mul_1186, torch.float32); mul_1186 = None | |
| + convert_element_type_1427 = torch.ops.prims.convert_element_type.default(sigmoid_12, torch.float32); sigmoid_12 = None | |
| + sub_229 = torch.ops.aten.sub.Tensor(1, convert_element_type_1427) | |
| + mul_1188 = torch.ops.aten.mul.Tensor(convert_element_type_1427, sub_229); convert_element_type_1427 = sub_229 = None | |
| + mul_1189 = torch.ops.aten.mul.Tensor(convert_element_type_1426, mul_1188); convert_element_type_1426 = mul_1188 = None | |
| + convert_element_type_1428 = torch.ops.prims.convert_element_type.default(mul_1189, torch.bfloat16); mul_1189 = None | |
| + add_496 = torch.ops.aten.add.Tensor(mul_1187, convert_element_type_1428); mul_1187 = convert_element_type_1428 = None | |
| + mul_1190 = torch.ops.aten.mul.Tensor(add_495, convert_element_type_481); convert_element_type_481 = None | |
| + mul_1191 = torch.ops.aten.mul.Tensor(add_495, sigmoid_11); add_495 = None | |
| + convert_element_type_1429 = torch.ops.prims.convert_element_type.default(mul_1190, torch.float32); mul_1190 = None | |
| + convert_element_type_1430 = torch.ops.prims.convert_element_type.default(sigmoid_11, torch.float32); sigmoid_11 = None | |
| + sub_230 = torch.ops.aten.sub.Tensor(1, convert_element_type_1430) | |
| + mul_1192 = torch.ops.aten.mul.Tensor(convert_element_type_1430, sub_230); convert_element_type_1430 = sub_230 = None | |
| + mul_1193 = torch.ops.aten.mul.Tensor(convert_element_type_1429, mul_1192); convert_element_type_1429 = mul_1192 = None | |
| + convert_element_type_1431 = torch.ops.prims.convert_element_type.default(mul_1193, torch.bfloat16); mul_1193 = None | |
| + add_497 = torch.ops.aten.add.Tensor(mul_1191, convert_element_type_1431); mul_1191 = convert_element_type_1431 = None | |
| + convert_element_type_1432 = torch.ops.prims.convert_element_type.default(add_496, torch.float32); add_496 = None | |
| + mul_1194 = torch.ops.aten.mul.Tensor(convert_element_type_1432, mul_144); mul_144 = None | |
| + mul_1195 = torch.ops.aten.mul.Tensor(convert_element_type_1432, primals_322); convert_element_type_1432 = primals_322 = None | |
| + sum_182 = torch.ops.aten.sum.dim_IntList(mul_1194, [0], True); mul_1194 = None | |
| + view_564 = torch.ops.aten.view.default(sum_182, [4608]); sum_182 = None | |
| + mul_1196 = torch.ops.aten.mul.Tensor(mul_1195, convert_element_type_482) | |
| + mul_1197 = torch.ops.aten.mul.Tensor(mul_1195, rsqrt_51); mul_1195 = None | |
| + sum_183 = torch.ops.aten.sum.dim_IntList(mul_1196, [1], True); mul_1196 = None | |
| + mul_1198 = torch.ops.aten.mul.Scalar(sum_183, -0.5); sum_183 = None | |
| + pow_162 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_51, 3); rsqrt_51 = None | |
| + mul_1199 = torch.ops.aten.mul.Tensor(mul_1198, pow_162); mul_1198 = pow_162 = None | |
| + expand_102 = torch.ops.aten.expand.default(mul_1199, [4096, 4608]); mul_1199 = None | |
| + div_131 = torch.ops.aten.div.Scalar(expand_102, 4608); expand_102 = None | |
| + pow_163 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_482, 1.0); convert_element_type_482 = None | |
| + mul_1200 = torch.ops.aten.mul.Scalar(pow_163, 2.0); pow_163 = None | |
| + mul_1201 = torch.ops.aten.mul.Tensor(div_131, mul_1200); div_131 = mul_1200 = None | |
| + add_498 = torch.ops.aten.add.Tensor(mul_1197, mul_1201); mul_1197 = mul_1201 = None | |
| + convert_element_type_1433 = torch.ops.prims.convert_element_type.default(add_498, torch.bfloat16); add_498 = None | |
| + convert_element_type_1434 = torch.ops.prims.convert_element_type.default(add_497, torch.float32); add_497 = None | |
| + mul_1202 = torch.ops.aten.mul.Tensor(convert_element_type_1434, mul_142); mul_142 = None | |
| + mul_1203 = torch.ops.aten.mul.Tensor(convert_element_type_1434, primals_321); convert_element_type_1434 = primals_321 = None | |
| + sum_184 = torch.ops.aten.sum.dim_IntList(mul_1202, [0], True); mul_1202 = None | |
| + view_565 = torch.ops.aten.view.default(sum_184, [4608]); sum_184 = None | |
| + mul_1204 = torch.ops.aten.mul.Tensor(mul_1203, convert_element_type_480) | |
| + mul_1205 = torch.ops.aten.mul.Tensor(mul_1203, rsqrt_50); mul_1203 = None | |
| + sum_185 = torch.ops.aten.sum.dim_IntList(mul_1204, [1], True); mul_1204 = None | |
| + mul_1206 = torch.ops.aten.mul.Scalar(sum_185, -0.5); sum_185 = None | |
| + pow_164 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_50, 3); rsqrt_50 = None | |
| + mul_1207 = torch.ops.aten.mul.Tensor(mul_1206, pow_164); mul_1206 = pow_164 = None | |
| + expand_103 = torch.ops.aten.expand.default(mul_1207, [4096, 4608]); mul_1207 = None | |
| + div_132 = torch.ops.aten.div.Scalar(expand_103, 4608); expand_103 = None | |
| + pow_165 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_480, 1.0); convert_element_type_480 = None | |
| + mul_1208 = torch.ops.aten.mul.Scalar(pow_165, 2.0); pow_165 = None | |
| + mul_1209 = torch.ops.aten.mul.Tensor(div_132, mul_1208); div_132 = mul_1208 = None | |
| + add_499 = torch.ops.aten.add.Tensor(mul_1205, mul_1209); mul_1205 = mul_1209 = None | |
| + convert_element_type_1435 = torch.ops.prims.convert_element_type.default(add_499, torch.bfloat16); add_499 = None | |
| + index_put_48 = torch.ops.aten.index_put.default(full_default_296, [sub_32], convert_element_type_1433, True); full_default_296 = None | |
| slice_scatter_57 = torch.ops.aten.slice_scatter.default(full_default_297, index_put_48, 1, 4608, 9223372036854775807); index_put_48 = None | |
| - abs_64 = torch.ops.aten.abs.default(convert_element_type_1325) | |
| + abs_64 = torch.ops.aten.abs.default(convert_element_type_1433) | |
| amax_29 = torch.ops.aten.amax.default(abs_64, [-1], True); abs_64 = None | |
| - convert_element_type_1327 = torch.ops.prims.convert_element_type.default(amax_29, torch.float64); amax_29 = None | |
| - clamp_min_106 = torch.ops.aten.clamp_min.default(convert_element_type_1327, 1e-12); convert_element_type_1327 = None | |
| + convert_element_type_1436 = torch.ops.prims.convert_element_type.default(amax_29, torch.float64); amax_29 = None | |
| + clamp_min_106 = torch.ops.aten.clamp_min.default(convert_element_type_1436, 1e-12); convert_element_type_1436 = None | |
| reciprocal_72 = torch.ops.aten.reciprocal.default(clamp_min_106); clamp_min_106 = None | |
| - mul_1128 = torch.ops.aten.mul.Tensor(reciprocal_72, 448.0); reciprocal_72 = None | |
| - convert_element_type_1328 = torch.ops.prims.convert_element_type.default(mul_1128, torch.float32); mul_1128 = None | |
| - log2_36 = torch.ops.aten.log2.default(convert_element_type_1328); convert_element_type_1328 = None | |
| + mul_1210 = torch.ops.aten.mul.Tensor(reciprocal_72, 448.0); reciprocal_72 = None | |
| + convert_element_type_1437 = torch.ops.prims.convert_element_type.default(mul_1210, torch.float32); mul_1210 = None | |
| + log2_36 = torch.ops.aten.log2.default(convert_element_type_1437); convert_element_type_1437 = None | |
| floor_36 = torch.ops.aten.floor.default(log2_36); log2_36 = None | |
| exp2_36 = torch.ops.aten.exp2.default(floor_36); floor_36 = None | |
| - convert_element_type_1329 = torch.ops.prims.convert_element_type.default(convert_element_type_1325, torch.float32) | |
| - mul_1129 = torch.ops.aten.mul.Tensor(convert_element_type_1329, exp2_36); convert_element_type_1329 = None | |
| - clamp_min_107 = torch.ops.aten.clamp_min.default(mul_1129, -448.0); mul_1129 = None | |
| + convert_element_type_1438 = torch.ops.prims.convert_element_type.default(convert_element_type_1433, torch.float32) | |
| + mul_1211 = torch.ops.aten.mul.Tensor(convert_element_type_1438, exp2_36); convert_element_type_1438 = None | |
| + clamp_min_107 = torch.ops.aten.clamp_min.default(mul_1211, -448.0); mul_1211 = None | |
| clamp_max_66 = torch.ops.aten.clamp_max.default(clamp_min_107, 448.0); clamp_min_107 = None | |
| - convert_element_type_1330 = torch.ops.prims.convert_element_type.default(clamp_max_66, torch.float8_e4m3fn); clamp_max_66 = None | |
| + convert_element_type_1439 = torch.ops.prims.convert_element_type.default(clamp_max_66, torch.float8_e4m3fn); clamp_max_66 = None | |
| permute_174 = torch.ops.aten.permute.default(primals_320, [1, 0]); primals_320 = None | |
| abs_4 = torch.ops.aten.abs.default(permute_174) | |
| max_9 = torch.ops.aten.max.default(abs_4); abs_4 = None | |
| - convert_element_type_1331 = torch.ops.prims.convert_element_type.default(max_9, torch.float64); max_9 = None | |
| - clamp_min_108 = torch.ops.aten.clamp_min.default(convert_element_type_1331, 1e-12); convert_element_type_1331 = None | |
| + convert_element_type_1440 = torch.ops.prims.convert_element_type.default(max_9, torch.float64); max_9 = None | |
| + clamp_min_108 = torch.ops.aten.clamp_min.default(convert_element_type_1440, 1e-12); convert_element_type_1440 = None | |
| reciprocal_73 = torch.ops.aten.reciprocal.default(clamp_min_108); clamp_min_108 = None | |
| - mul_1130 = torch.ops.aten.mul.Tensor(reciprocal_73, 448.0); reciprocal_73 = None | |
| - convert_element_type_1332 = torch.ops.prims.convert_element_type.default(mul_1130, torch.float32); mul_1130 = None | |
| - log2_37 = torch.ops.aten.log2.default(convert_element_type_1332); convert_element_type_1332 = None | |
| + mul_1212 = torch.ops.aten.mul.Tensor(reciprocal_73, 448.0); reciprocal_73 = None | |
| + convert_element_type_1441 = torch.ops.prims.convert_element_type.default(mul_1212, torch.float32); mul_1212 = None | |
| + log2_37 = torch.ops.aten.log2.default(convert_element_type_1441); convert_element_type_1441 = None | |
| floor_37 = torch.ops.aten.floor.default(log2_37); log2_37 = None | |
| exp2_37 = torch.ops.aten.exp2.default(floor_37); floor_37 = None | |
| - mul_1131 = torch.ops.aten.mul.Tensor(permute_174, exp2_37); permute_174 = None | |
| - clamp_min_109 = torch.ops.aten.clamp_min.default(mul_1131, -448.0); mul_1131 = None | |
| + mul_1213 = torch.ops.aten.mul.Tensor(permute_174, exp2_37); permute_174 = None | |
| + clamp_min_109 = torch.ops.aten.clamp_min.default(mul_1213, -448.0); mul_1213 = None | |
| clamp_max_67 = torch.ops.aten.clamp_max.default(clamp_min_109, 448.0); clamp_min_109 = None | |
| - convert_element_type_1333 = torch.ops.prims.convert_element_type.default(clamp_max_67, torch.float8_e4m3fn); clamp_max_67 = None | |
| - clone_106 = torch.ops.aten.clone.default(convert_element_type_1333, memory_format = torch.contiguous_format); convert_element_type_1333 = None | |
| + convert_element_type_1442 = torch.ops.prims.convert_element_type.default(clamp_max_67, torch.float8_e4m3fn); clamp_max_67 = None | |
| + clone_106 = torch.ops.aten.clone.default(convert_element_type_1442, memory_format = torch.contiguous_format); convert_element_type_1442 = None | |
| permute_834 = torch.ops.aten.permute.default(clone_106, [1, 0]); clone_106 = None | |
| repeat_12 = torch.ops.aten.repeat.default(exp2_37, [32048]); exp2_37 = None | |
| - view_526 = torch.ops.aten.view.default(repeat_12, [1, -1]); repeat_12 = None | |
| + view_567 = torch.ops.aten.view.default(repeat_12, [1, -1]); repeat_12 = None | |
| reciprocal_74 = torch.ops.aten.reciprocal.default(exp2_36); exp2_36 = None | |
| - reciprocal_75 = torch.ops.aten.reciprocal.default(view_526); view_526 = None | |
| - mul_1132 = torch.ops.aten.mul.Tensor(reciprocal_74, reciprocal_75); reciprocal_74 = reciprocal_75 = None | |
| - _scaled_mm_18 = torch.ops.aten._scaled_mm.default(convert_element_type_1330, permute_834, full_default_267, full_default_267, None, None, torch.bfloat16); convert_element_type_1330 = permute_834 = None | |
| - mul_1133 = torch.ops.aten.mul.Tensor(_scaled_mm_18, mul_1132); _scaled_mm_18 = mul_1132 = None | |
| - permute_835 = torch.ops.aten.permute.default(convert_element_type_1325, [1, 0]); convert_element_type_1325 = None | |
| - convert_element_type_default_36 = torch.ops.prims.convert_element_type.default(fp8_quant_pos_715_bmm_3, torch.bfloat16); fp8_quant_pos_715_bmm_3 = None | |
| - div_tensor_13 = torch.ops.aten.div.Tensor(convert_element_type_default_36, fp8_scale_pos_715_bmm_3); convert_element_type_default_36 = fp8_scale_pos_715_bmm_3 = None | |
| - convert_element_type_default_37 = torch.ops.prims.convert_element_type.default(div_tensor_13, torch.bfloat16); div_tensor_13 = None | |
| - view_210 = torch.ops.aten.view.default(convert_element_type_default_37, [4096, -1]); convert_element_type_default_37 = None | |
| + reciprocal_75 = torch.ops.aten.reciprocal.default(view_567); view_567 = None | |
| + mul_1214 = torch.ops.aten.mul.Tensor(reciprocal_74, reciprocal_75); reciprocal_74 = reciprocal_75 = None | |
| + _scaled_mm_18 = torch.ops.aten._scaled_mm.default(convert_element_type_1439, permute_834, full_default_267, full_default_267, None, None, torch.bfloat16); convert_element_type_1439 = permute_834 = None | |
| + mul_1215 = torch.ops.aten.mul.Tensor(_scaled_mm_18, mul_1214); _scaled_mm_18 = mul_1214 = None | |
| + permute_835 = torch.ops.aten.permute.default(convert_element_type_1433, [1, 0]); convert_element_type_1433 = None | |
| + convert_element_type_default_33 = torch.ops.prims.convert_element_type.default(fp8_quant_pos_702_bmm_3, torch.bfloat16); fp8_quant_pos_702_bmm_3 = None | |
| + div_tensor_13 = torch.ops.aten.div.Tensor(convert_element_type_default_33, fp8_scale_pos_702_bmm_3); convert_element_type_default_33 = fp8_scale_pos_702_bmm_3 = None | |
| + convert_element_type_default_34 = torch.ops.prims.convert_element_type.default(div_tensor_13, torch.bfloat16); div_tensor_13 = None | |
| + view_210 = torch.ops.aten.view.default(convert_element_type_default_34, [4096, -1]); convert_element_type_default_34 = None | |
| cat_15 = torch.ops.aten.cat.default([view_210, mul_117], 1); view_210 = mul_117 = None | |
| pow_29 = torch.ops.aten.pow.Tensor_Scalar(cat_15, 2) | |
| mean_28 = torch.ops.aten.mean.dim(pow_29, [1], True); pow_29 = None | |
| @@ -7863,58 +5153,58 @@ | |
| rsqrt_49 = torch.ops.aten.rsqrt.default(add_116); add_116 = None | |
| mul_132 = torch.ops.aten.mul.Tensor(cat_15, rsqrt_49); cat_15 = None | |
| mul_133 = torch.ops.aten.mul.Tensor(mul_132, primals_318) | |
| - convert_element_type_471 = torch.ops.prims.convert_element_type.default(mul_133, torch.bfloat16); mul_133 = None | |
| - mm_325 = torch.ops.aten.mm.default(permute_835, convert_element_type_471); permute_835 = convert_element_type_471 = None | |
| + convert_element_type_472 = torch.ops.prims.convert_element_type.default(mul_133, torch.bfloat16); mul_133 = None | |
| + mm_325 = torch.ops.aten.mm.default(permute_835, convert_element_type_472); permute_835 = convert_element_type_472 = None | |
| permute_836 = torch.ops.aten.permute.default(mm_325, [1, 0]); mm_325 = None | |
| - convert_element_type_1337 = torch.ops.prims.convert_element_type.default(permute_836, torch.float32); permute_836 = None | |
| - permute_837 = torch.ops.aten.permute.default(convert_element_type_1337, [1, 0]); convert_element_type_1337 = None | |
| - convert_element_type_default_3 = torch.ops.prims.convert_element_type.default(mul_1133, torch.float32); mul_1133 = None | |
| - slice_scatter_58 = torch.ops.aten.slice_scatter.default(full_default_297, convert_element_type_1326, 1, 0, 4608); full_default_297 = convert_element_type_1326 = None | |
| - add_459 = torch.ops.aten.add.Tensor(slice_scatter_57, slice_scatter_58); slice_scatter_57 = slice_scatter_58 = None | |
| - abs_66 = torch.ops.aten.abs.default(add_459) | |
| + convert_element_type_1446 = torch.ops.prims.convert_element_type.default(permute_836, torch.float32); permute_836 = None | |
| + permute_837 = torch.ops.aten.permute.default(convert_element_type_1446, [1, 0]); convert_element_type_1446 = None | |
| + convert_element_type_default_4 = torch.ops.prims.convert_element_type.default(mul_1215, torch.float32); mul_1215 = None | |
| + slice_scatter_58 = torch.ops.aten.slice_scatter.default(full_default_297, convert_element_type_1435, 1, 0, 4608); full_default_297 = convert_element_type_1435 = None | |
| + add_500 = torch.ops.aten.add.Tensor(slice_scatter_57, slice_scatter_58); slice_scatter_57 = slice_scatter_58 = None | |
| + abs_66 = torch.ops.aten.abs.default(add_500) | |
| amax_30 = torch.ops.aten.amax.default(abs_66, [-1], True); abs_66 = None | |
| - convert_element_type_1339 = torch.ops.prims.convert_element_type.default(amax_30, torch.float64); amax_30 = None | |
| - clamp_min_110 = torch.ops.aten.clamp_min.default(convert_element_type_1339, 1e-12); convert_element_type_1339 = None | |
| + convert_element_type_1448 = torch.ops.prims.convert_element_type.default(amax_30, torch.float64); amax_30 = None | |
| + clamp_min_110 = torch.ops.aten.clamp_min.default(convert_element_type_1448, 1e-12); convert_element_type_1448 = None | |
| reciprocal_76 = torch.ops.aten.reciprocal.default(clamp_min_110); clamp_min_110 = None | |
| - mul_1134 = torch.ops.aten.mul.Tensor(reciprocal_76, 448.0); reciprocal_76 = None | |
| - convert_element_type_1340 = torch.ops.prims.convert_element_type.default(mul_1134, torch.float32); mul_1134 = None | |
| - log2_38 = torch.ops.aten.log2.default(convert_element_type_1340); convert_element_type_1340 = None | |
| + mul_1216 = torch.ops.aten.mul.Tensor(reciprocal_76, 448.0); reciprocal_76 = None | |
| + convert_element_type_1449 = torch.ops.prims.convert_element_type.default(mul_1216, torch.float32); mul_1216 = None | |
| + log2_38 = torch.ops.aten.log2.default(convert_element_type_1449); convert_element_type_1449 = None | |
| floor_38 = torch.ops.aten.floor.default(log2_38); log2_38 = None | |
| exp2_38 = torch.ops.aten.exp2.default(floor_38); floor_38 = None | |
| - convert_element_type_1341 = torch.ops.prims.convert_element_type.default(add_459, torch.float32) | |
| - mul_1135 = torch.ops.aten.mul.Tensor(convert_element_type_1341, exp2_38); convert_element_type_1341 = None | |
| - clamp_min_111 = torch.ops.aten.clamp_min.default(mul_1135, -448.0); mul_1135 = None | |
| + convert_element_type_1450 = torch.ops.prims.convert_element_type.default(add_500, torch.float32) | |
| + mul_1217 = torch.ops.aten.mul.Tensor(convert_element_type_1450, exp2_38); convert_element_type_1450 = None | |
| + clamp_min_111 = torch.ops.aten.clamp_min.default(mul_1217, -448.0); mul_1217 = None | |
| clamp_max_68 = torch.ops.aten.clamp_max.default(clamp_min_111, 448.0); clamp_min_111 = None | |
| - convert_element_type_1342 = torch.ops.prims.convert_element_type.default(clamp_max_68, torch.float8_e4m3fn); clamp_max_68 = None | |
| + convert_element_type_1451 = torch.ops.prims.convert_element_type.default(clamp_max_68, torch.float8_e4m3fn); clamp_max_68 = None | |
| permute_173 = torch.ops.aten.permute.default(primals_319, [1, 0]); primals_319 = None | |
| abs_2 = torch.ops.aten.abs.default(permute_173) | |
| max_10 = torch.ops.aten.max.default(abs_2); abs_2 = None | |
| - convert_element_type_1343 = torch.ops.prims.convert_element_type.default(max_10, torch.float64); max_10 = None | |
| - clamp_min_112 = torch.ops.aten.clamp_min.default(convert_element_type_1343, 1e-12); convert_element_type_1343 = None | |
| + convert_element_type_1452 = torch.ops.prims.convert_element_type.default(max_10, torch.float64); max_10 = None | |
| + clamp_min_112 = torch.ops.aten.clamp_min.default(convert_element_type_1452, 1e-12); convert_element_type_1452 = None | |
| reciprocal_77 = torch.ops.aten.reciprocal.default(clamp_min_112); clamp_min_112 = None | |
| - mul_1136 = torch.ops.aten.mul.Tensor(reciprocal_77, 448.0); reciprocal_77 = None | |
| - convert_element_type_1344 = torch.ops.prims.convert_element_type.default(mul_1136, torch.float32); mul_1136 = None | |
| - log2_39 = torch.ops.aten.log2.default(convert_element_type_1344); convert_element_type_1344 = None | |
| + mul_1218 = torch.ops.aten.mul.Tensor(reciprocal_77, 448.0); reciprocal_77 = None | |
| + convert_element_type_1453 = torch.ops.prims.convert_element_type.default(mul_1218, torch.float32); mul_1218 = None | |
| + log2_39 = torch.ops.aten.log2.default(convert_element_type_1453); convert_element_type_1453 = None | |
| floor_39 = torch.ops.aten.floor.default(log2_39); log2_39 = None | |
| exp2_39 = torch.ops.aten.exp2.default(floor_39); floor_39 = None | |
| - mul_1137 = torch.ops.aten.mul.Tensor(permute_173, exp2_39); permute_173 = None | |
| - clamp_min_113 = torch.ops.aten.clamp_min.default(mul_1137, -448.0); mul_1137 = None | |
| + mul_1219 = torch.ops.aten.mul.Tensor(permute_173, exp2_39); permute_173 = None | |
| + clamp_min_113 = torch.ops.aten.clamp_min.default(mul_1219, -448.0); mul_1219 = None | |
| clamp_max_69 = torch.ops.aten.clamp_max.default(clamp_min_113, 448.0); clamp_min_113 = None | |
| - convert_element_type_1345 = torch.ops.prims.convert_element_type.default(clamp_max_69, torch.float8_e4m3fn); clamp_max_69 = None | |
| - clone_107 = torch.ops.aten.clone.default(convert_element_type_1345, memory_format = torch.contiguous_format); convert_element_type_1345 = None | |
| + convert_element_type_1454 = torch.ops.prims.convert_element_type.default(clamp_max_69, torch.float8_e4m3fn); clamp_max_69 = None | |
| + clone_107 = torch.ops.aten.clone.default(convert_element_type_1454, memory_format = torch.contiguous_format); convert_element_type_1454 = None | |
| permute_840 = torch.ops.aten.permute.default(clone_107, [1, 0]); clone_107 = None | |
| repeat_13 = torch.ops.aten.repeat.default(exp2_39, [24816]); exp2_39 = None | |
| - view_531 = torch.ops.aten.view.default(repeat_13, [1, -1]); repeat_13 = None | |
| + view_572 = torch.ops.aten.view.default(repeat_13, [1, -1]); repeat_13 = None | |
| reciprocal_78 = torch.ops.aten.reciprocal.default(exp2_38); exp2_38 = None | |
| - reciprocal_79 = torch.ops.aten.reciprocal.default(view_531); view_531 = None | |
| - mul_1138 = torch.ops.aten.mul.Tensor(reciprocal_78, reciprocal_79); reciprocal_78 = reciprocal_79 = None | |
| - _scaled_mm_19 = torch.ops.aten._scaled_mm.default(convert_element_type_1342, permute_840, full_default_267, full_default_267, None, None, torch.bfloat16); convert_element_type_1342 = permute_840 = full_default_267 = None | |
| - mul_1139 = torch.ops.aten.mul.Tensor(_scaled_mm_19, mul_1138); _scaled_mm_19 = mul_1138 = None | |
| - permute_841 = torch.ops.aten.permute.default(add_459, [1, 0]); add_459 = None | |
| - convert_element_type_default_38 = torch.ops.prims.convert_element_type.default(fp8_quant_pos_716_bmm_4, torch.bfloat16); fp8_quant_pos_716_bmm_4 = None | |
| - div_tensor_14 = torch.ops.aten.div.Tensor(convert_element_type_default_38, fp8_scale_pos_716_bmm_4); convert_element_type_default_38 = fp8_scale_pos_716_bmm_4 = None | |
| - convert_element_type_default_39 = torch.ops.prims.convert_element_type.default(div_tensor_14, torch.bfloat16); div_tensor_14 = None | |
| - view_209 = torch.ops.aten.view.default(convert_element_type_default_39, [4096, -1]); convert_element_type_default_39 = None | |
| + reciprocal_79 = torch.ops.aten.reciprocal.default(view_572); view_572 = None | |
| + mul_1220 = torch.ops.aten.mul.Tensor(reciprocal_78, reciprocal_79); reciprocal_78 = reciprocal_79 = None | |
| + _scaled_mm_19 = torch.ops.aten._scaled_mm.default(convert_element_type_1451, permute_840, full_default_267, full_default_267, None, None, torch.bfloat16); convert_element_type_1451 = permute_840 = full_default_267 = None | |
| + mul_1221 = torch.ops.aten.mul.Tensor(_scaled_mm_19, mul_1220); _scaled_mm_19 = mul_1220 = None | |
| + permute_841 = torch.ops.aten.permute.default(add_500, [1, 0]); add_500 = None | |
| + convert_element_type_default_35 = torch.ops.prims.convert_element_type.default(fp8_quant_pos_703_bmm_4, torch.bfloat16); fp8_quant_pos_703_bmm_4 = None | |
| + div_tensor_14 = torch.ops.aten.div.Tensor(convert_element_type_default_35, fp8_scale_pos_703_bmm_4); convert_element_type_default_35 = fp8_scale_pos_703_bmm_4 = None | |
| + convert_element_type_default_36 = torch.ops.prims.convert_element_type.default(div_tensor_14, torch.bfloat16); div_tensor_14 = None | |
| + view_209 = torch.ops.aten.view.default(convert_element_type_default_36, [4096, -1]); convert_element_type_default_36 = None | |
| cat_14 = torch.ops.aten.cat.default([view_209, mul_115], 1); view_209 = mul_115 = None | |
| pow_28 = torch.ops.aten.pow.Tensor_Scalar(cat_14, 2) | |
| mean_27 = torch.ops.aten.mean.dim(pow_28, [1], True); pow_28 = None | |
| @@ -7922,109 +5212,109 @@ | |
| rsqrt_48 = torch.ops.aten.rsqrt.default(add_115); add_115 = None | |
| mul_130 = torch.ops.aten.mul.Tensor(cat_14, rsqrt_48); cat_14 = None | |
| mul_131 = torch.ops.aten.mul.Tensor(mul_130, primals_317) | |
| - convert_element_type_463 = torch.ops.prims.convert_element_type.default(mul_131, torch.bfloat16); mul_131 = None | |
| - mm_326 = torch.ops.aten.mm.default(permute_841, convert_element_type_463); permute_841 = convert_element_type_463 = None | |
| + convert_element_type_464 = torch.ops.prims.convert_element_type.default(mul_131, torch.bfloat16); mul_131 = None | |
| + mm_326 = torch.ops.aten.mm.default(permute_841, convert_element_type_464); permute_841 = convert_element_type_464 = None | |
| permute_842 = torch.ops.aten.permute.default(mm_326, [1, 0]); mm_326 = None | |
| - convert_element_type_1349 = torch.ops.prims.convert_element_type.default(permute_842, torch.float32); permute_842 = None | |
| - permute_843 = torch.ops.aten.permute.default(convert_element_type_1349, [1, 0]); convert_element_type_1349 = None | |
| - convert_element_type_default_2 = torch.ops.prims.convert_element_type.default(mul_1139, torch.float32); mul_1139 = None | |
| - mul_1140 = torch.ops.aten.mul.Tensor(convert_element_type_default_3, primals_318); primals_318 = None | |
| - mul_1142 = torch.ops.aten.mul.Tensor(mul_132, mul_1140) | |
| - sum_186 = torch.ops.aten.sum.dim_IntList(mul_1142, [1], True); mul_1142 = None | |
| + convert_element_type_1458 = torch.ops.prims.convert_element_type.default(permute_842, torch.float32); permute_842 = None | |
| + permute_843 = torch.ops.aten.permute.default(convert_element_type_1458, [1, 0]); convert_element_type_1458 = None | |
| + convert_element_type_default_3 = torch.ops.prims.convert_element_type.default(mul_1221, torch.float32); mul_1221 = None | |
| + mul_1222 = torch.ops.aten.mul.Tensor(convert_element_type_default_4, primals_318); primals_318 = None | |
| + mul_1224 = torch.ops.aten.mul.Tensor(mul_132, mul_1222) | |
| + sum_186 = torch.ops.aten.sum.dim_IntList(mul_1224, [1], True); mul_1224 = None | |
| div_133 = torch.ops.aten.div.Tensor(mul_132, 32048) | |
| - mul_1143 = torch.ops.aten.mul.Tensor(div_133, sum_186); div_133 = sum_186 = None | |
| - sub_272 = torch.ops.aten.sub.Tensor(mul_1140, mul_1143); mul_1140 = mul_1143 = None | |
| - mul_1144 = torch.ops.aten.mul.Tensor(sub_272, rsqrt_49); sub_272 = rsqrt_49 = None | |
| - mul_1145 = torch.ops.aten.mul.Tensor(convert_element_type_default_3, mul_132); convert_element_type_default_3 = mul_132 = None | |
| - sum_187 = torch.ops.aten.sum.dim_IntList(mul_1145, [0]); mul_1145 = None | |
| - mul_1146 = torch.ops.aten.mul.Tensor(convert_element_type_default_2, primals_317); primals_317 = None | |
| - mul_1148 = torch.ops.aten.mul.Tensor(mul_130, mul_1146) | |
| - sum_188 = torch.ops.aten.sum.dim_IntList(mul_1148, [1], True); mul_1148 = None | |
| + mul_1225 = torch.ops.aten.mul.Tensor(div_133, sum_186); div_133 = sum_186 = None | |
| + sub_231 = torch.ops.aten.sub.Tensor(mul_1222, mul_1225); mul_1222 = mul_1225 = None | |
| + mul_1226 = torch.ops.aten.mul.Tensor(sub_231, rsqrt_49); sub_231 = rsqrt_49 = None | |
| + mul_1227 = torch.ops.aten.mul.Tensor(convert_element_type_default_4, mul_132); convert_element_type_default_4 = mul_132 = None | |
| + sum_187 = torch.ops.aten.sum.dim_IntList(mul_1227, [0]); mul_1227 = None | |
| + mul_1228 = torch.ops.aten.mul.Tensor(convert_element_type_default_3, primals_317); primals_317 = None | |
| + mul_1230 = torch.ops.aten.mul.Tensor(mul_130, mul_1228) | |
| + sum_188 = torch.ops.aten.sum.dim_IntList(mul_1230, [1], True); mul_1230 = None | |
| div_134 = torch.ops.aten.div.Tensor(mul_130, 24816) | |
| - mul_1149 = torch.ops.aten.mul.Tensor(div_134, sum_188); div_134 = sum_188 = None | |
| - sub_273 = torch.ops.aten.sub.Tensor(mul_1146, mul_1149); mul_1146 = mul_1149 = None | |
| - mul_1150 = torch.ops.aten.mul.Tensor(sub_273, rsqrt_48); sub_273 = rsqrt_48 = None | |
| - mul_1151 = torch.ops.aten.mul.Tensor(convert_element_type_default_2, mul_130); convert_element_type_default_2 = mul_130 = None | |
| - sum_189 = torch.ops.aten.sum.dim_IntList(mul_1151, [0]); mul_1151 = None | |
| - slice_107 = torch.ops.aten.slice.Tensor(mul_1144, 1, 0, 27952) | |
| - slice_108 = torch.ops.aten.slice.Tensor(mul_1144, 1, 27952, 32048); mul_1144 = None | |
| - convert_element_type_1351 = torch.ops.prims.convert_element_type.default(slice_107, torch.bfloat16); slice_107 = None | |
| - add_460 = torch.ops.aten.add.Tensor(add_416, slice_108); add_416 = slice_108 = None | |
| - slice_109 = torch.ops.aten.slice.Tensor(mul_1150, 1, 0, 20720) | |
| - slice_110 = torch.ops.aten.slice.Tensor(mul_1150, 1, 20720, 24816); mul_1150 = None | |
| - convert_element_type_1352 = torch.ops.prims.convert_element_type.default(slice_109, torch.bfloat16); slice_109 = None | |
| - add_461 = torch.ops.aten.add.Tensor(add_417, slice_110); add_417 = slice_110 = None | |
| - view_535 = torch.ops.aten.view.default(convert_element_type_1351, [4096, 16, 1747]); convert_element_type_1351 = None | |
| - view_536 = torch.ops.aten.view.default(convert_element_type_1352, [4096, 16, 1295]); convert_element_type_1352 = None | |
| - bmm_29 = torch.ops.aten.bmm.default(permute_844, view_536); permute_844 = None | |
| - convert_element_type_default_56 = torch.ops.prims.convert_element_type.default(fp8_quant_pos_852_permute_845, torch.bfloat16); fp8_quant_pos_852_permute_845 = None | |
| - div_tensor_23 = torch.ops.aten.div.Tensor(convert_element_type_default_56, fp8_scale_pos_852_permute_845); convert_element_type_default_56 = fp8_scale_pos_852_permute_845 = None | |
| - convert_element_type_default_57 = torch.ops.prims.convert_element_type.default(div_tensor_23, torch.bfloat16); div_tensor_23 = None | |
| - bmm_30 = torch.ops.aten.bmm.default(view_536, convert_element_type_default_57); view_536 = convert_element_type_default_57 = None | |
| - convert_element_type_1357 = torch.ops.prims.convert_element_type.default(bmm_29, torch.float32); bmm_29 = None | |
| - permute_846 = torch.ops.aten.permute.default(convert_element_type_1357, [0, 2, 1]); convert_element_type_1357 = None | |
| - bmm_31 = torch.ops.aten.bmm.default(permute_847, view_535); permute_847 = None | |
| - convert_element_type_default_58 = torch.ops.prims.convert_element_type.default(fp8_quant_pos_854_permute_848, torch.bfloat16); fp8_quant_pos_854_permute_848 = None | |
| - div_tensor_24 = torch.ops.aten.div.Tensor(convert_element_type_default_58, fp8_scale_pos_854_permute_848); convert_element_type_default_58 = fp8_scale_pos_854_permute_848 = None | |
| - convert_element_type_default_59 = torch.ops.prims.convert_element_type.default(div_tensor_24, torch.bfloat16); div_tensor_24 = None | |
| - bmm_32 = torch.ops.aten.bmm.default(view_535, convert_element_type_default_59); view_535 = convert_element_type_default_59 = None | |
| - convert_element_type_1362 = torch.ops.prims.convert_element_type.default(bmm_31, torch.float32); bmm_31 = None | |
| - permute_849 = torch.ops.aten.permute.default(convert_element_type_1362, [0, 2, 1]); convert_element_type_1362 = None | |
| + mul_1231 = torch.ops.aten.mul.Tensor(div_134, sum_188); div_134 = sum_188 = None | |
| + sub_232 = torch.ops.aten.sub.Tensor(mul_1228, mul_1231); mul_1228 = mul_1231 = None | |
| + mul_1232 = torch.ops.aten.mul.Tensor(sub_232, rsqrt_48); sub_232 = rsqrt_48 = None | |
| + mul_1233 = torch.ops.aten.mul.Tensor(convert_element_type_default_3, mul_130); convert_element_type_default_3 = mul_130 = None | |
| + sum_189 = torch.ops.aten.sum.dim_IntList(mul_1233, [0]); mul_1233 = None | |
| + slice_107 = torch.ops.aten.slice.Tensor(mul_1226, 1, 0, 27952) | |
| + slice_108 = torch.ops.aten.slice.Tensor(mul_1226, 1, 27952, 32048); mul_1226 = None | |
| + convert_element_type_1460 = torch.ops.prims.convert_element_type.default(slice_107, torch.bfloat16); slice_107 = None | |
| + add_501 = torch.ops.aten.add.Tensor(add_441, slice_108); add_441 = slice_108 = None | |
| + slice_109 = torch.ops.aten.slice.Tensor(mul_1232, 1, 0, 20720) | |
| + slice_110 = torch.ops.aten.slice.Tensor(mul_1232, 1, 20720, 24816); mul_1232 = None | |
| + convert_element_type_1461 = torch.ops.prims.convert_element_type.default(slice_109, torch.bfloat16); slice_109 = None | |
| + add_502 = torch.ops.aten.add.Tensor(add_442, slice_110); add_442 = slice_110 = None | |
| + view_576 = torch.ops.aten.view.default(convert_element_type_1460, [4096, 16, 1747]); convert_element_type_1460 = None | |
| + view_577 = torch.ops.aten.view.default(convert_element_type_1461, [4096, 16, 1295]); convert_element_type_1461 = None | |
| + bmm_29 = torch.ops.aten.bmm.default(permute_844, view_577); permute_844 = None | |
| + convert_element_type_default_51 = torch.ops.prims.convert_element_type.default(fp8_quant_pos_834_permute_845, torch.bfloat16); fp8_quant_pos_834_permute_845 = None | |
| + div_tensor_22 = torch.ops.aten.div.Tensor(convert_element_type_default_51, fp8_scale_pos_834_permute_845); convert_element_type_default_51 = fp8_scale_pos_834_permute_845 = None | |
| + convert_element_type_default_52 = torch.ops.prims.convert_element_type.default(div_tensor_22, torch.bfloat16); div_tensor_22 = None | |
| + bmm_30 = torch.ops.aten.bmm.default(view_577, convert_element_type_default_52); view_577 = convert_element_type_default_52 = None | |
| + convert_element_type_1466 = torch.ops.prims.convert_element_type.default(bmm_29, torch.float32); bmm_29 = None | |
| + permute_846 = torch.ops.aten.permute.default(convert_element_type_1466, [0, 2, 1]); convert_element_type_1466 = None | |
| + bmm_31 = torch.ops.aten.bmm.default(permute_847, view_576); permute_847 = None | |
| + convert_element_type_default_53 = torch.ops.prims.convert_element_type.default(fp8_quant_pos_836_permute_848, torch.bfloat16); fp8_quant_pos_836_permute_848 = None | |
| + div_tensor_23 = torch.ops.aten.div.Tensor(convert_element_type_default_53, fp8_scale_pos_836_permute_848); convert_element_type_default_53 = fp8_scale_pos_836_permute_848 = None | |
| + convert_element_type_default_54 = torch.ops.prims.convert_element_type.default(div_tensor_23, torch.bfloat16); div_tensor_23 = None | |
| + bmm_32 = torch.ops.aten.bmm.default(view_576, convert_element_type_default_54); view_576 = convert_element_type_default_54 = None | |
| + convert_element_type_1471 = torch.ops.prims.convert_element_type.default(bmm_31, torch.float32); bmm_31 = None | |
| + permute_849 = torch.ops.aten.permute.default(convert_element_type_1471, [0, 2, 1]); convert_element_type_1471 = None | |
| slice_12 = torch.ops.aten.slice.Tensor(mm_53, 1, 0, 512) | |
| - convert_element_type_431 = torch.ops.prims.convert_element_type.default(slice_12, torch.float32); slice_12 = None | |
| - pow_24 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_431, 2) | |
| + convert_element_type_432 = torch.ops.prims.convert_element_type.default(slice_12, torch.float32); slice_12 = None | |
| + pow_24 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_432, 2) | |
| mean_23 = torch.ops.aten.mean.dim(pow_24, [1], True); pow_24 = None | |
| add_109 = torch.ops.aten.add.Scalar(mean_23, 1.1920928955078125e-07); mean_23 = None | |
| rsqrt_44 = torch.ops.aten.rsqrt.default(add_109); add_109 = None | |
| - mul_118 = torch.ops.aten.mul.Tensor(convert_element_type_431, rsqrt_44); convert_element_type_431 = None | |
| + mul_118 = torch.ops.aten.mul.Tensor(convert_element_type_432, rsqrt_44) | |
| mul_119 = torch.ops.aten.mul.Tensor(mul_118, primals_309) | |
| - sigmoid_7 = torch.ops.aten.sigmoid.default(mul_119) | |
| - mul_122 = torch.ops.aten.mul.Tensor(mul_119, sigmoid_7) | |
| - convert_element_type_433 = torch.ops.prims.convert_element_type.default(mul_122, torch.bfloat16); mul_122 = None | |
| - convert_element_type_434 = torch.ops.prims.convert_element_type.default(primals_311, torch.bfloat16); primals_311 = None | |
| - permute_167 = torch.ops.aten.permute.default(convert_element_type_434, [1, 0]); convert_element_type_434 = None | |
| - mm_55 = torch.ops.aten.mm.default(convert_element_type_433, permute_167) | |
| + convert_element_type_433 = torch.ops.prims.convert_element_type.default(mul_119, torch.bfloat16); mul_119 = None | |
| + sigmoid_7 = torch.ops.aten.sigmoid.default(convert_element_type_433) | |
| + mul_122 = torch.ops.aten.mul.Tensor(convert_element_type_433, sigmoid_7) | |
| + convert_element_type_436 = torch.ops.prims.convert_element_type.default(primals_311, torch.bfloat16); primals_311 = None | |
| + permute_167 = torch.ops.aten.permute.default(convert_element_type_436, [1, 0]); convert_element_type_436 = None | |
| + mm_55 = torch.ops.aten.mm.default(mul_122, permute_167) | |
| slice_14 = torch.ops.aten.slice.Tensor(mm_55, 1, 0, 256); mm_55 = None | |
| - convert_element_type_441 = torch.ops.prims.convert_element_type.default(slice_14, torch.float32); slice_14 = None | |
| - pow_26 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_441, 2) | |
| + convert_element_type_442 = torch.ops.prims.convert_element_type.default(slice_14, torch.float32); slice_14 = None | |
| + pow_26 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_442, 2) | |
| mean_25 = torch.ops.aten.mean.dim(pow_26, [1], True); pow_26 = None | |
| add_112 = torch.ops.aten.add.Scalar(mean_25, 1.1920928955078125e-07); mean_25 = None | |
| rsqrt_46 = torch.ops.aten.rsqrt.default(add_112); add_112 = None | |
| - mul_124 = torch.ops.aten.mul.Tensor(convert_element_type_441, rsqrt_46); convert_element_type_441 = None | |
| + mul_124 = torch.ops.aten.mul.Tensor(convert_element_type_442, rsqrt_46) | |
| mul_125 = torch.ops.aten.mul.Tensor(mul_124, primals_313) | |
| - sigmoid_9 = torch.ops.aten.sigmoid.default(mul_125) | |
| - mul_128 = torch.ops.aten.mul.Tensor(mul_125, sigmoid_9) | |
| - convert_element_type_443 = torch.ops.prims.convert_element_type.default(mul_128, torch.bfloat16); mul_128 = None | |
| - convert_element_type_444 = torch.ops.prims.convert_element_type.default(primals_315, torch.bfloat16); primals_315 = None | |
| - permute_169 = torch.ops.aten.permute.default(convert_element_type_444, [1, 0]); convert_element_type_444 = None | |
| - mm_57 = torch.ops.aten.mm.default(convert_element_type_443, permute_169) | |
| + convert_element_type_443 = torch.ops.prims.convert_element_type.default(mul_125, torch.bfloat16); mul_125 = None | |
| + sigmoid_9 = torch.ops.aten.sigmoid.default(convert_element_type_443) | |
| + mul_128 = torch.ops.aten.mul.Tensor(convert_element_type_443, sigmoid_9) | |
| + convert_element_type_446 = torch.ops.prims.convert_element_type.default(primals_315, torch.bfloat16); primals_315 = None | |
| + permute_169 = torch.ops.aten.permute.default(convert_element_type_446, [1, 0]); convert_element_type_446 = None | |
| + mm_57 = torch.ops.aten.mm.default(mul_128, permute_169) | |
| slice_16 = torch.ops.aten.slice.Tensor(mm_57, 1, 0, 27952) | |
| view_195 = torch.ops.aten.view.default(slice_16, [4096, -1, 1747]); slice_16 = None | |
| slice_18 = torch.ops.aten.slice.Tensor(view_195, 2, 0, 1295); view_195 = None | |
| expand_2 = torch.ops.aten.expand.default(slice_18, [4096, 16, 1295]); slice_18 = None | |
| permute_850 = torch.ops.aten.permute.default(expand_2, [0, 2, 1]); expand_2 = None | |
| bmm_33 = torch.ops.aten.bmm.default(permute_850, bmm_30); permute_850 = None | |
| - convert_element_type_default_60 = torch.ops.prims.convert_element_type.default(fp8_quant_pos_855_permute_851, torch.bfloat16); fp8_quant_pos_855_permute_851 = None | |
| - div_tensor_25 = torch.ops.aten.div.Tensor(convert_element_type_default_60, fp8_scale_pos_855_permute_851); convert_element_type_default_60 = fp8_scale_pos_855_permute_851 = None | |
| - convert_element_type_default_61 = torch.ops.prims.convert_element_type.default(div_tensor_25, torch.bfloat16); div_tensor_25 = None | |
| - bmm_34 = torch.ops.aten.bmm.default(bmm_30, convert_element_type_default_61); bmm_30 = convert_element_type_default_61 = None | |
| - convert_element_type_1367 = torch.ops.prims.convert_element_type.default(bmm_33, torch.float32); bmm_33 = None | |
| - add_462 = torch.ops.aten.add.Tensor(permute_846, convert_element_type_1367); permute_846 = convert_element_type_1367 = None | |
| + convert_element_type_default_55 = torch.ops.prims.convert_element_type.default(fp8_quant_pos_837_permute_851, torch.bfloat16); fp8_quant_pos_837_permute_851 = None | |
| + div_tensor_24 = torch.ops.aten.div.Tensor(convert_element_type_default_55, fp8_scale_pos_837_permute_851); convert_element_type_default_55 = fp8_scale_pos_837_permute_851 = None | |
| + convert_element_type_default_56 = torch.ops.prims.convert_element_type.default(div_tensor_24, torch.bfloat16); div_tensor_24 = None | |
| + bmm_34 = torch.ops.aten.bmm.default(bmm_30, convert_element_type_default_56); bmm_30 = convert_element_type_default_56 = None | |
| + convert_element_type_1476 = torch.ops.prims.convert_element_type.default(bmm_33, torch.float32); bmm_33 = None | |
| + add_503 = torch.ops.aten.add.Tensor(permute_846, convert_element_type_1476); permute_846 = convert_element_type_1476 = None | |
| full_default_364 = torch.ops.aten.full.default([4096, 16, 1747], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| slice_scatter_59 = torch.ops.aten.slice_scatter.default(full_default_364, bmm_34, 2, 0, 1295); full_default_364 = bmm_34 = None | |
| - convert_element_type_442 = torch.ops.prims.convert_element_type.default(add_111, torch.float32); add_111 = None | |
| - pow_27 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_442, 2) | |
| + convert_element_type_444 = torch.ops.prims.convert_element_type.default(add_111, torch.float32); add_111 = None | |
| + pow_27 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_444, 2) | |
| mean_26 = torch.ops.aten.mean.dim(pow_27, [1], True); pow_27 = None | |
| add_113 = torch.ops.aten.add.Scalar(mean_26, 1.1920928955078125e-07); mean_26 = None | |
| rsqrt_47 = torch.ops.aten.rsqrt.default(add_113); add_113 = None | |
| - mul_126 = torch.ops.aten.mul.Tensor(convert_element_type_442, rsqrt_47); convert_element_type_442 = None | |
| + mul_126 = torch.ops.aten.mul.Tensor(convert_element_type_444, rsqrt_47) | |
| mul_127 = torch.ops.aten.mul.Tensor(mul_126, primals_314) | |
| - sigmoid_10 = torch.ops.aten.sigmoid.default(mul_127) | |
| - mul_129 = torch.ops.aten.mul.Tensor(mul_127, sigmoid_10) | |
| - convert_element_type_447 = torch.ops.prims.convert_element_type.default(mul_129, torch.bfloat16); mul_129 = None | |
| - convert_element_type_448 = torch.ops.prims.convert_element_type.default(primals_316, torch.bfloat16); primals_316 = None | |
| - permute_170 = torch.ops.aten.permute.default(convert_element_type_448, [1, 0]); convert_element_type_448 = None | |
| - mm_58 = torch.ops.aten.mm.default(convert_element_type_447, permute_170) | |
| + convert_element_type_445 = torch.ops.prims.convert_element_type.default(mul_127, torch.bfloat16); mul_127 = None | |
| + sigmoid_10 = torch.ops.aten.sigmoid.default(convert_element_type_445) | |
| + mul_129 = torch.ops.aten.mul.Tensor(convert_element_type_445, sigmoid_10) | |
| + convert_element_type_449 = torch.ops.prims.convert_element_type.default(primals_316, torch.bfloat16); primals_316 = None | |
| + permute_170 = torch.ops.aten.permute.default(convert_element_type_449, [1, 0]); convert_element_type_449 = None | |
| + mm_58 = torch.ops.aten.mm.default(mul_129, permute_170) | |
| slice_17 = torch.ops.aten.slice.Tensor(mm_57, 1, 27952, 9223372036854775807); mm_57 = None | |
| index_12 = torch.ops.aten.index.Tensor(slice_17, [sub_32]); slice_17 = None | |
| add_114 = torch.ops.aten.add.Tensor(mm_58, index_12); mm_58 = index_12 = None | |
| @@ -8032,323 +5322,363 @@ | |
| expand = torch.ops.aten.expand.default(view_196, [4096, 16, 1747]); view_196 = None | |
| permute_852 = torch.ops.aten.permute.default(expand, [0, 2, 1]); expand = None | |
| bmm_35 = torch.ops.aten.bmm.default(permute_852, bmm_32); permute_852 = None | |
| - convert_element_type_default_62 = torch.ops.prims.convert_element_type.default(fp8_quant_pos_856_permute_853, torch.bfloat16); fp8_quant_pos_856_permute_853 = None | |
| - div_tensor_26 = torch.ops.aten.div.Tensor(convert_element_type_default_62, fp8_scale_pos_856_permute_853); convert_element_type_default_62 = fp8_scale_pos_856_permute_853 = None | |
| - convert_element_type_default_63 = torch.ops.prims.convert_element_type.default(div_tensor_26, torch.bfloat16); div_tensor_26 = None | |
| - bmm_36 = torch.ops.aten.bmm.default(bmm_32, convert_element_type_default_63); bmm_32 = convert_element_type_default_63 = None | |
| - convert_element_type_1372 = torch.ops.prims.convert_element_type.default(bmm_35, torch.float32); bmm_35 = None | |
| - add_463 = torch.ops.aten.add.Tensor(permute_849, convert_element_type_1372); permute_849 = convert_element_type_1372 = None | |
| - slice_111 = torch.ops.aten.slice.Tensor(add_463, 1, 0, 1295) | |
| - slice_112 = torch.ops.aten.slice.Tensor(add_463, 1, 1295, 1747); add_463 = None | |
| + convert_element_type_default_57 = torch.ops.prims.convert_element_type.default(fp8_quant_pos_838_permute_853, torch.bfloat16); fp8_quant_pos_838_permute_853 = None | |
| + div_tensor_25 = torch.ops.aten.div.Tensor(convert_element_type_default_57, fp8_scale_pos_838_permute_853); convert_element_type_default_57 = fp8_scale_pos_838_permute_853 = None | |
| + convert_element_type_default_58 = torch.ops.prims.convert_element_type.default(div_tensor_25, torch.bfloat16); div_tensor_25 = None | |
| + bmm_36 = torch.ops.aten.bmm.default(bmm_32, convert_element_type_default_58); bmm_32 = convert_element_type_default_58 = None | |
| + convert_element_type_1481 = torch.ops.prims.convert_element_type.default(bmm_35, torch.float32); bmm_35 = None | |
| + add_504 = torch.ops.aten.add.Tensor(permute_849, convert_element_type_1481); permute_849 = convert_element_type_1481 = None | |
| + slice_111 = torch.ops.aten.slice.Tensor(add_504, 1, 0, 1295) | |
| + slice_112 = torch.ops.aten.slice.Tensor(add_504, 1, 1295, 1747); add_504 = None | |
| full_default_365 = torch.ops.aten.full.default([4096, 1295, 112], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| index_put_49 = torch.ops.aten.index_put.default(full_default_365, [sub_32], slice_111, True); full_default_365 = slice_111 = None | |
| - add_464 = torch.ops.aten.add.Tensor(add_462, index_put_49); add_462 = index_put_49 = None | |
| - view_549 = torch.ops.aten.view.default(bmm_36, [4096, 27952]); bmm_36 = None | |
| - view_550 = torch.ops.aten.view.default(slice_scatter_59, [4096, 27952]); slice_scatter_59 = None | |
| + add_505 = torch.ops.aten.add.Tensor(add_503, index_put_49); add_503 = index_put_49 = None | |
| + view_590 = torch.ops.aten.view.default(bmm_36, [4096, 27952]); bmm_36 = None | |
| + view_591 = torch.ops.aten.view.default(slice_scatter_59, [4096, 27952]); slice_scatter_59 = None | |
| full_default_366 = torch.ops.aten.full.default([4096, 27952], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| - index_put_50 = torch.ops.aten.index_put.default(full_default_366, [sub_32], view_549, True); full_default_366 = None | |
| + index_put_50 = torch.ops.aten.index_put.default(full_default_366, [sub_32], view_590, True); full_default_366 = None | |
| full_default_367 = torch.ops.aten.full.default([4096, 55904], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| slice_scatter_60 = torch.ops.aten.slice_scatter.default(full_default_367, index_put_50, 1, 27952, 9223372036854775807); index_put_50 = None | |
| - permute_854 = torch.ops.aten.permute.default(view_549, [1, 0]) | |
| - mm_327 = torch.ops.aten.mm.default(permute_854, convert_element_type_447); permute_854 = convert_element_type_447 = None | |
| + permute_854 = torch.ops.aten.permute.default(view_590, [1, 0]) | |
| + mm_327 = torch.ops.aten.mm.default(permute_854, mul_129); permute_854 = mul_129 = None | |
| permute_856 = torch.ops.aten.permute.default(permute_170, [1, 0]); permute_170 = None | |
| - mm_328 = torch.ops.aten.mm.default(view_549, permute_856); view_549 = permute_856 = None | |
| - convert_element_type_1377 = torch.ops.prims.convert_element_type.default(mm_327, torch.float32); mm_327 = None | |
| - convert_element_type_1378 = torch.ops.prims.convert_element_type.default(mm_328, torch.float32); mm_328 = None | |
| - slice_scatter_61 = torch.ops.aten.slice_scatter.default(full_default_367, view_550, 1, 0, 27952); full_default_367 = view_550 = None | |
| - add_465 = torch.ops.aten.add.Tensor(slice_scatter_60, slice_scatter_61); slice_scatter_60 = slice_scatter_61 = None | |
| - permute_858 = torch.ops.aten.permute.default(add_465, [1, 0]) | |
| - mm_329 = torch.ops.aten.mm.default(permute_858, convert_element_type_443); permute_858 = convert_element_type_443 = None | |
| + mm_328 = torch.ops.aten.mm.default(view_590, permute_856); view_590 = permute_856 = None | |
| + convert_element_type_1486 = torch.ops.prims.convert_element_type.default(mm_327, torch.float32); mm_327 = None | |
| + slice_scatter_61 = torch.ops.aten.slice_scatter.default(full_default_367, view_591, 1, 0, 27952); full_default_367 = view_591 = None | |
| + add_506 = torch.ops.aten.add.Tensor(slice_scatter_60, slice_scatter_61); slice_scatter_60 = slice_scatter_61 = None | |
| + permute_858 = torch.ops.aten.permute.default(add_506, [1, 0]) | |
| + mm_329 = torch.ops.aten.mm.default(permute_858, mul_128); permute_858 = mul_128 = None | |
| permute_860 = torch.ops.aten.permute.default(permute_169, [1, 0]); permute_169 = None | |
| - mm_330 = torch.ops.aten.mm.default(add_465, permute_860); add_465 = permute_860 = None | |
| - convert_element_type_1383 = torch.ops.prims.convert_element_type.default(mm_329, torch.float32); mm_329 = None | |
| - convert_element_type_1384 = torch.ops.prims.convert_element_type.default(mm_330, torch.float32); mm_330 = None | |
| - mul_1152 = torch.ops.aten.mul.Tensor(convert_element_type_1378, mul_127); mul_127 = None | |
| - mul_1153 = torch.ops.aten.mul.Tensor(convert_element_type_1378, sigmoid_10); convert_element_type_1378 = None | |
| - sub_274 = torch.ops.aten.sub.Tensor(1, sigmoid_10) | |
| - mul_1154 = torch.ops.aten.mul.Tensor(sigmoid_10, sub_274); sigmoid_10 = sub_274 = None | |
| - mul_1155 = torch.ops.aten.mul.Tensor(mul_1152, mul_1154); mul_1152 = mul_1154 = None | |
| - add_466 = torch.ops.aten.add.Tensor(mul_1153, mul_1155); mul_1153 = mul_1155 = None | |
| - mul_1156 = torch.ops.aten.mul.Tensor(convert_element_type_1384, mul_125); mul_125 = None | |
| - mul_1157 = torch.ops.aten.mul.Tensor(convert_element_type_1384, sigmoid_9); convert_element_type_1384 = None | |
| - sub_275 = torch.ops.aten.sub.Tensor(1, sigmoid_9) | |
| - mul_1158 = torch.ops.aten.mul.Tensor(sigmoid_9, sub_275); sigmoid_9 = sub_275 = None | |
| - mul_1159 = torch.ops.aten.mul.Tensor(mul_1156, mul_1158); mul_1156 = mul_1158 = None | |
| - add_467 = torch.ops.aten.add.Tensor(mul_1157, mul_1159); mul_1157 = mul_1159 = None | |
| - mul_1160 = torch.ops.aten.mul.Tensor(add_466, primals_314); primals_314 = None | |
| - mul_1162 = torch.ops.aten.mul.Tensor(mul_126, mul_1160) | |
| - sum_190 = torch.ops.aten.sum.dim_IntList(mul_1162, [1], True); mul_1162 = None | |
| - div_135 = torch.ops.aten.div.Tensor(mul_126, 256) | |
| - mul_1163 = torch.ops.aten.mul.Tensor(div_135, sum_190); div_135 = sum_190 = None | |
| - sub_276 = torch.ops.aten.sub.Tensor(mul_1160, mul_1163); mul_1160 = mul_1163 = None | |
| - mul_1164 = torch.ops.aten.mul.Tensor(sub_276, rsqrt_47); sub_276 = rsqrt_47 = None | |
| - mul_1165 = torch.ops.aten.mul.Tensor(add_466, mul_126); add_466 = mul_126 = None | |
| - sum_191 = torch.ops.aten.sum.dim_IntList(mul_1165, [0]); mul_1165 = None | |
| - convert_element_type_1385 = torch.ops.prims.convert_element_type.default(mul_1164, torch.bfloat16); mul_1164 = None | |
| - mul_1166 = torch.ops.aten.mul.Tensor(add_467, primals_313); primals_313 = None | |
| - mul_1168 = torch.ops.aten.mul.Tensor(mul_124, mul_1166) | |
| - sum_192 = torch.ops.aten.sum.dim_IntList(mul_1168, [1], True); mul_1168 = None | |
| - div_136 = torch.ops.aten.div.Tensor(mul_124, 256) | |
| - mul_1169 = torch.ops.aten.mul.Tensor(div_136, sum_192); div_136 = sum_192 = None | |
| - sub_277 = torch.ops.aten.sub.Tensor(mul_1166, mul_1169); mul_1166 = mul_1169 = None | |
| - mul_1170 = torch.ops.aten.mul.Tensor(sub_277, rsqrt_46); sub_277 = rsqrt_46 = None | |
| - mul_1171 = torch.ops.aten.mul.Tensor(add_467, mul_124); add_467 = mul_124 = None | |
| - sum_193 = torch.ops.aten.sum.dim_IntList(mul_1171, [0]); mul_1171 = None | |
| - convert_element_type_1386 = torch.ops.prims.convert_element_type.default(mul_1170, torch.bfloat16); mul_1170 = None | |
| + mm_330 = torch.ops.aten.mm.default(add_506, permute_860); add_506 = permute_860 = None | |
| + convert_element_type_1491 = torch.ops.prims.convert_element_type.default(mm_329, torch.float32); mm_329 = None | |
| + mul_1234 = torch.ops.aten.mul.Tensor(mm_328, convert_element_type_445); convert_element_type_445 = None | |
| + mul_1235 = torch.ops.aten.mul.Tensor(mm_328, sigmoid_10); mm_328 = None | |
| + convert_element_type_1492 = torch.ops.prims.convert_element_type.default(mul_1234, torch.float32); mul_1234 = None | |
| + convert_element_type_1493 = torch.ops.prims.convert_element_type.default(sigmoid_10, torch.float32); sigmoid_10 = None | |
| + sub_233 = torch.ops.aten.sub.Tensor(1, convert_element_type_1493) | |
| + mul_1236 = torch.ops.aten.mul.Tensor(convert_element_type_1493, sub_233); convert_element_type_1493 = sub_233 = None | |
| + mul_1237 = torch.ops.aten.mul.Tensor(convert_element_type_1492, mul_1236); convert_element_type_1492 = mul_1236 = None | |
| + convert_element_type_1494 = torch.ops.prims.convert_element_type.default(mul_1237, torch.bfloat16); mul_1237 = None | |
| + add_507 = torch.ops.aten.add.Tensor(mul_1235, convert_element_type_1494); mul_1235 = convert_element_type_1494 = None | |
| + mul_1238 = torch.ops.aten.mul.Tensor(mm_330, convert_element_type_443); convert_element_type_443 = None | |
| + mul_1239 = torch.ops.aten.mul.Tensor(mm_330, sigmoid_9); mm_330 = None | |
| + convert_element_type_1495 = torch.ops.prims.convert_element_type.default(mul_1238, torch.float32); mul_1238 = None | |
| + convert_element_type_1496 = torch.ops.prims.convert_element_type.default(sigmoid_9, torch.float32); sigmoid_9 = None | |
| + sub_234 = torch.ops.aten.sub.Tensor(1, convert_element_type_1496) | |
| + mul_1240 = torch.ops.aten.mul.Tensor(convert_element_type_1496, sub_234); convert_element_type_1496 = sub_234 = None | |
| + mul_1241 = torch.ops.aten.mul.Tensor(convert_element_type_1495, mul_1240); convert_element_type_1495 = mul_1240 = None | |
| + convert_element_type_1497 = torch.ops.prims.convert_element_type.default(mul_1241, torch.bfloat16); mul_1241 = None | |
| + add_508 = torch.ops.aten.add.Tensor(mul_1239, convert_element_type_1497); mul_1239 = convert_element_type_1497 = None | |
| + convert_element_type_1498 = torch.ops.prims.convert_element_type.default(add_507, torch.float32); add_507 = None | |
| + mul_1242 = torch.ops.aten.mul.Tensor(convert_element_type_1498, mul_126); mul_126 = None | |
| + mul_1243 = torch.ops.aten.mul.Tensor(convert_element_type_1498, primals_314); convert_element_type_1498 = primals_314 = None | |
| + sum_190 = torch.ops.aten.sum.dim_IntList(mul_1242, [0], True); mul_1242 = None | |
| + view_592 = torch.ops.aten.view.default(sum_190, [256]); sum_190 = None | |
| + mul_1244 = torch.ops.aten.mul.Tensor(mul_1243, convert_element_type_444) | |
| + mul_1245 = torch.ops.aten.mul.Tensor(mul_1243, rsqrt_47); mul_1243 = None | |
| + sum_191 = torch.ops.aten.sum.dim_IntList(mul_1244, [1], True); mul_1244 = None | |
| + mul_1246 = torch.ops.aten.mul.Scalar(sum_191, -0.5); sum_191 = None | |
| + pow_166 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_47, 3); rsqrt_47 = None | |
| + mul_1247 = torch.ops.aten.mul.Tensor(mul_1246, pow_166); mul_1246 = pow_166 = None | |
| + expand_104 = torch.ops.aten.expand.default(mul_1247, [4096, 256]); mul_1247 = None | |
| + div_135 = torch.ops.aten.div.Scalar(expand_104, 256); expand_104 = None | |
| + pow_167 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_444, 1.0); convert_element_type_444 = None | |
| + mul_1248 = torch.ops.aten.mul.Scalar(pow_167, 2.0); pow_167 = None | |
| + mul_1249 = torch.ops.aten.mul.Tensor(div_135, mul_1248); div_135 = mul_1248 = None | |
| + add_509 = torch.ops.aten.add.Tensor(mul_1245, mul_1249); mul_1245 = mul_1249 = None | |
| + convert_element_type_1499 = torch.ops.prims.convert_element_type.default(add_509, torch.bfloat16); add_509 = None | |
| + convert_element_type_1500 = torch.ops.prims.convert_element_type.default(add_508, torch.float32); add_508 = None | |
| + mul_1250 = torch.ops.aten.mul.Tensor(convert_element_type_1500, mul_124); mul_124 = None | |
| + mul_1251 = torch.ops.aten.mul.Tensor(convert_element_type_1500, primals_313); convert_element_type_1500 = primals_313 = None | |
| + sum_192 = torch.ops.aten.sum.dim_IntList(mul_1250, [0], True); mul_1250 = None | |
| + view_593 = torch.ops.aten.view.default(sum_192, [256]); sum_192 = None | |
| + mul_1252 = torch.ops.aten.mul.Tensor(mul_1251, convert_element_type_442) | |
| + mul_1253 = torch.ops.aten.mul.Tensor(mul_1251, rsqrt_46); mul_1251 = None | |
| + sum_193 = torch.ops.aten.sum.dim_IntList(mul_1252, [1], True); mul_1252 = None | |
| + mul_1254 = torch.ops.aten.mul.Scalar(sum_193, -0.5); sum_193 = None | |
| + pow_168 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_46, 3); rsqrt_46 = None | |
| + mul_1255 = torch.ops.aten.mul.Tensor(mul_1254, pow_168); mul_1254 = pow_168 = None | |
| + expand_105 = torch.ops.aten.expand.default(mul_1255, [4096, 256]); mul_1255 = None | |
| + div_136 = torch.ops.aten.div.Scalar(expand_105, 256); expand_105 = None | |
| + pow_169 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_442, 1.0); convert_element_type_442 = None | |
| + mul_1256 = torch.ops.aten.mul.Scalar(pow_169, 2.0); pow_169 = None | |
| + mul_1257 = torch.ops.aten.mul.Tensor(div_136, mul_1256); div_136 = mul_1256 = None | |
| + add_510 = torch.ops.aten.add.Tensor(mul_1253, mul_1257); mul_1253 = mul_1257 = None | |
| + convert_element_type_1501 = torch.ops.prims.convert_element_type.default(add_510, torch.bfloat16); add_510 = None | |
| full_default_369 = torch.ops.aten.full.default([4096, 256], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| - index_put_51 = torch.ops.aten.index_put.default(full_default_369, [sub_32], convert_element_type_1385, True); full_default_369 = None | |
| + index_put_51 = torch.ops.aten.index_put.default(full_default_369, [sub_32], convert_element_type_1499, True); full_default_369 = None | |
| slice_scatter_62 = torch.ops.aten.slice_scatter.default(full_default_277, index_put_51, 1, 256, 9223372036854775807); index_put_51 = None | |
| - permute_862 = torch.ops.aten.permute.default(convert_element_type_1385, [1, 0]) | |
| + permute_862 = torch.ops.aten.permute.default(convert_element_type_1499, [1, 0]) | |
| slice_13 = torch.ops.aten.slice.Tensor(mm_53, 1, 512, 9223372036854775807); mm_53 = None | |
| index_10 = torch.ops.aten.index.Tensor(slice_13, [sub_32]); slice_13 = None | |
| add_108 = torch.ops.aten.add.Tensor(mm_54, index_10); mm_54 = index_10 = None | |
| - convert_element_type_432 = torch.ops.prims.convert_element_type.default(add_108, torch.float32); add_108 = None | |
| - pow_25 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_432, 2) | |
| + convert_element_type_434 = torch.ops.prims.convert_element_type.default(add_108, torch.float32); add_108 = None | |
| + pow_25 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_434, 2) | |
| mean_24 = torch.ops.aten.mean.dim(pow_25, [1], True); pow_25 = None | |
| add_110 = torch.ops.aten.add.Scalar(mean_24, 1.1920928955078125e-07); mean_24 = None | |
| rsqrt_45 = torch.ops.aten.rsqrt.default(add_110); add_110 = None | |
| - mul_120 = torch.ops.aten.mul.Tensor(convert_element_type_432, rsqrt_45); convert_element_type_432 = None | |
| + mul_120 = torch.ops.aten.mul.Tensor(convert_element_type_434, rsqrt_45) | |
| mul_121 = torch.ops.aten.mul.Tensor(mul_120, primals_310) | |
| - sigmoid_8 = torch.ops.aten.sigmoid.default(mul_121) | |
| - mul_123 = torch.ops.aten.mul.Tensor(mul_121, sigmoid_8) | |
| - convert_element_type_437 = torch.ops.prims.convert_element_type.default(mul_123, torch.bfloat16); mul_123 = None | |
| - mm_331 = torch.ops.aten.mm.default(permute_862, convert_element_type_437); permute_862 = convert_element_type_437 = None | |
| - convert_element_type_438 = torch.ops.prims.convert_element_type.default(primals_312, torch.bfloat16); primals_312 = None | |
| - permute_168 = torch.ops.aten.permute.default(convert_element_type_438, [1, 0]); convert_element_type_438 = None | |
| + convert_element_type_435 = torch.ops.prims.convert_element_type.default(mul_121, torch.bfloat16); mul_121 = None | |
| + sigmoid_8 = torch.ops.aten.sigmoid.default(convert_element_type_435) | |
| + mul_123 = torch.ops.aten.mul.Tensor(convert_element_type_435, sigmoid_8) | |
| + mm_331 = torch.ops.aten.mm.default(permute_862, mul_123); permute_862 = mul_123 = None | |
| + convert_element_type_439 = torch.ops.prims.convert_element_type.default(primals_312, torch.bfloat16); primals_312 = None | |
| + permute_168 = torch.ops.aten.permute.default(convert_element_type_439, [1, 0]); convert_element_type_439 = None | |
| permute_864 = torch.ops.aten.permute.default(permute_168, [1, 0]); permute_168 = None | |
| - mm_332 = torch.ops.aten.mm.default(convert_element_type_1385, permute_864); convert_element_type_1385 = permute_864 = None | |
| - convert_element_type_1391 = torch.ops.prims.convert_element_type.default(mm_331, torch.float32); mm_331 = None | |
| - convert_element_type_1392 = torch.ops.prims.convert_element_type.default(mm_332, torch.float32); mm_332 = None | |
| - slice_scatter_63 = torch.ops.aten.slice_scatter.default(full_default_277, convert_element_type_1386, 1, 0, 256); convert_element_type_1386 = None | |
| - add_468 = torch.ops.aten.add.Tensor(slice_scatter_62, slice_scatter_63); slice_scatter_62 = slice_scatter_63 = None | |
| - permute_866 = torch.ops.aten.permute.default(add_468, [1, 0]) | |
| - mm_333 = torch.ops.aten.mm.default(permute_866, convert_element_type_433); permute_866 = convert_element_type_433 = None | |
| + mm_332 = torch.ops.aten.mm.default(convert_element_type_1499, permute_864); convert_element_type_1499 = permute_864 = None | |
| + convert_element_type_1506 = torch.ops.prims.convert_element_type.default(mm_331, torch.float32); mm_331 = None | |
| + slice_scatter_63 = torch.ops.aten.slice_scatter.default(full_default_277, convert_element_type_1501, 1, 0, 256); convert_element_type_1501 = None | |
| + add_511 = torch.ops.aten.add.Tensor(slice_scatter_62, slice_scatter_63); slice_scatter_62 = slice_scatter_63 = None | |
| + permute_866 = torch.ops.aten.permute.default(add_511, [1, 0]) | |
| + mm_333 = torch.ops.aten.mm.default(permute_866, mul_122); permute_866 = mul_122 = None | |
| permute_868 = torch.ops.aten.permute.default(permute_167, [1, 0]); permute_167 = None | |
| - mm_334 = torch.ops.aten.mm.default(add_468, permute_868); add_468 = permute_868 = None | |
| - convert_element_type_1397 = torch.ops.prims.convert_element_type.default(mm_333, torch.float32); mm_333 = None | |
| - convert_element_type_1398 = torch.ops.prims.convert_element_type.default(mm_334, torch.float32); mm_334 = None | |
| - mul_1172 = torch.ops.aten.mul.Tensor(convert_element_type_1392, mul_121); mul_121 = None | |
| - mul_1173 = torch.ops.aten.mul.Tensor(convert_element_type_1392, sigmoid_8); convert_element_type_1392 = None | |
| - sub_278 = torch.ops.aten.sub.Tensor(1, sigmoid_8) | |
| - mul_1174 = torch.ops.aten.mul.Tensor(sigmoid_8, sub_278); sigmoid_8 = sub_278 = None | |
| - mul_1175 = torch.ops.aten.mul.Tensor(mul_1172, mul_1174); mul_1172 = mul_1174 = None | |
| - add_469 = torch.ops.aten.add.Tensor(mul_1173, mul_1175); mul_1173 = mul_1175 = None | |
| - mul_1176 = torch.ops.aten.mul.Tensor(convert_element_type_1398, mul_119); mul_119 = None | |
| - mul_1177 = torch.ops.aten.mul.Tensor(convert_element_type_1398, sigmoid_7); convert_element_type_1398 = None | |
| - sub_279 = torch.ops.aten.sub.Tensor(1, sigmoid_7) | |
| - mul_1178 = torch.ops.aten.mul.Tensor(sigmoid_7, sub_279); sigmoid_7 = sub_279 = None | |
| - mul_1179 = torch.ops.aten.mul.Tensor(mul_1176, mul_1178); mul_1176 = mul_1178 = None | |
| - add_470 = torch.ops.aten.add.Tensor(mul_1177, mul_1179); mul_1177 = mul_1179 = None | |
| - mul_1180 = torch.ops.aten.mul.Tensor(add_469, primals_310); primals_310 = None | |
| - mul_1182 = torch.ops.aten.mul.Tensor(mul_120, mul_1180) | |
| - sum_194 = torch.ops.aten.sum.dim_IntList(mul_1182, [1], True); mul_1182 = None | |
| - div_137 = torch.ops.aten.div.Tensor(mul_120, 512) | |
| - mul_1183 = torch.ops.aten.mul.Tensor(div_137, sum_194); div_137 = sum_194 = None | |
| - sub_280 = torch.ops.aten.sub.Tensor(mul_1180, mul_1183); mul_1180 = mul_1183 = None | |
| - mul_1184 = torch.ops.aten.mul.Tensor(sub_280, rsqrt_45); sub_280 = rsqrt_45 = None | |
| - mul_1185 = torch.ops.aten.mul.Tensor(add_469, mul_120); add_469 = mul_120 = None | |
| - sum_195 = torch.ops.aten.sum.dim_IntList(mul_1185, [0]); mul_1185 = None | |
| - convert_element_type_1399 = torch.ops.prims.convert_element_type.default(mul_1184, torch.bfloat16); mul_1184 = None | |
| - mul_1186 = torch.ops.aten.mul.Tensor(add_470, primals_309); primals_309 = None | |
| - mul_1188 = torch.ops.aten.mul.Tensor(mul_118, mul_1186) | |
| - sum_196 = torch.ops.aten.sum.dim_IntList(mul_1188, [1], True); mul_1188 = None | |
| - div_138 = torch.ops.aten.div.Tensor(mul_118, 512) | |
| - mul_1189 = torch.ops.aten.mul.Tensor(div_138, sum_196); div_138 = sum_196 = None | |
| - sub_281 = torch.ops.aten.sub.Tensor(mul_1186, mul_1189); mul_1186 = mul_1189 = None | |
| - mul_1190 = torch.ops.aten.mul.Tensor(sub_281, rsqrt_44); sub_281 = rsqrt_44 = None | |
| - mul_1191 = torch.ops.aten.mul.Tensor(add_470, mul_118); add_470 = mul_118 = None | |
| - sum_197 = torch.ops.aten.sum.dim_IntList(mul_1191, [0]); mul_1191 = None | |
| - convert_element_type_1400 = torch.ops.prims.convert_element_type.default(mul_1190, torch.bfloat16); mul_1190 = None | |
| - index_put_52 = torch.ops.aten.index_put.default(full_default_277, [sub_32], convert_element_type_1399, True); full_default_277 = None | |
| + mm_334 = torch.ops.aten.mm.default(add_511, permute_868); add_511 = permute_868 = None | |
| + convert_element_type_1511 = torch.ops.prims.convert_element_type.default(mm_333, torch.float32); mm_333 = None | |
| + mul_1258 = torch.ops.aten.mul.Tensor(mm_332, convert_element_type_435); convert_element_type_435 = None | |
| + mul_1259 = torch.ops.aten.mul.Tensor(mm_332, sigmoid_8); mm_332 = None | |
| + convert_element_type_1512 = torch.ops.prims.convert_element_type.default(mul_1258, torch.float32); mul_1258 = None | |
| + convert_element_type_1513 = torch.ops.prims.convert_element_type.default(sigmoid_8, torch.float32); sigmoid_8 = None | |
| + sub_235 = torch.ops.aten.sub.Tensor(1, convert_element_type_1513) | |
| + mul_1260 = torch.ops.aten.mul.Tensor(convert_element_type_1513, sub_235); convert_element_type_1513 = sub_235 = None | |
| + mul_1261 = torch.ops.aten.mul.Tensor(convert_element_type_1512, mul_1260); convert_element_type_1512 = mul_1260 = None | |
| + convert_element_type_1514 = torch.ops.prims.convert_element_type.default(mul_1261, torch.bfloat16); mul_1261 = None | |
| + add_512 = torch.ops.aten.add.Tensor(mul_1259, convert_element_type_1514); mul_1259 = convert_element_type_1514 = None | |
| + mul_1262 = torch.ops.aten.mul.Tensor(mm_334, convert_element_type_433); convert_element_type_433 = None | |
| + mul_1263 = torch.ops.aten.mul.Tensor(mm_334, sigmoid_7); mm_334 = None | |
| + convert_element_type_1515 = torch.ops.prims.convert_element_type.default(mul_1262, torch.float32); mul_1262 = None | |
| + convert_element_type_1516 = torch.ops.prims.convert_element_type.default(sigmoid_7, torch.float32); sigmoid_7 = None | |
| + sub_236 = torch.ops.aten.sub.Tensor(1, convert_element_type_1516) | |
| + mul_1264 = torch.ops.aten.mul.Tensor(convert_element_type_1516, sub_236); convert_element_type_1516 = sub_236 = None | |
| + mul_1265 = torch.ops.aten.mul.Tensor(convert_element_type_1515, mul_1264); convert_element_type_1515 = mul_1264 = None | |
| + convert_element_type_1517 = torch.ops.prims.convert_element_type.default(mul_1265, torch.bfloat16); mul_1265 = None | |
| + add_513 = torch.ops.aten.add.Tensor(mul_1263, convert_element_type_1517); mul_1263 = convert_element_type_1517 = None | |
| + convert_element_type_1518 = torch.ops.prims.convert_element_type.default(add_512, torch.float32); add_512 = None | |
| + mul_1266 = torch.ops.aten.mul.Tensor(convert_element_type_1518, mul_120); mul_120 = None | |
| + mul_1267 = torch.ops.aten.mul.Tensor(convert_element_type_1518, primals_310); convert_element_type_1518 = primals_310 = None | |
| + sum_194 = torch.ops.aten.sum.dim_IntList(mul_1266, [0], True); mul_1266 = None | |
| + view_594 = torch.ops.aten.view.default(sum_194, [512]); sum_194 = None | |
| + mul_1268 = torch.ops.aten.mul.Tensor(mul_1267, convert_element_type_434) | |
| + mul_1269 = torch.ops.aten.mul.Tensor(mul_1267, rsqrt_45); mul_1267 = None | |
| + sum_195 = torch.ops.aten.sum.dim_IntList(mul_1268, [1], True); mul_1268 = None | |
| + mul_1270 = torch.ops.aten.mul.Scalar(sum_195, -0.5); sum_195 = None | |
| + pow_170 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_45, 3); rsqrt_45 = None | |
| + mul_1271 = torch.ops.aten.mul.Tensor(mul_1270, pow_170); mul_1270 = pow_170 = None | |
| + expand_106 = torch.ops.aten.expand.default(mul_1271, [4096, 512]); mul_1271 = None | |
| + div_137 = torch.ops.aten.div.Scalar(expand_106, 512); expand_106 = None | |
| + pow_171 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_434, 1.0); convert_element_type_434 = None | |
| + mul_1272 = torch.ops.aten.mul.Scalar(pow_171, 2.0); pow_171 = None | |
| + mul_1273 = torch.ops.aten.mul.Tensor(div_137, mul_1272); div_137 = mul_1272 = None | |
| + add_514 = torch.ops.aten.add.Tensor(mul_1269, mul_1273); mul_1269 = mul_1273 = None | |
| + convert_element_type_1519 = torch.ops.prims.convert_element_type.default(add_514, torch.bfloat16); add_514 = None | |
| + convert_element_type_1520 = torch.ops.prims.convert_element_type.default(add_513, torch.float32); add_513 = None | |
| + mul_1274 = torch.ops.aten.mul.Tensor(convert_element_type_1520, mul_118); mul_118 = None | |
| + mul_1275 = torch.ops.aten.mul.Tensor(convert_element_type_1520, primals_309); convert_element_type_1520 = primals_309 = None | |
| + sum_196 = torch.ops.aten.sum.dim_IntList(mul_1274, [0], True); mul_1274 = None | |
| + view_595 = torch.ops.aten.view.default(sum_196, [512]); sum_196 = None | |
| + mul_1276 = torch.ops.aten.mul.Tensor(mul_1275, convert_element_type_432) | |
| + mul_1277 = torch.ops.aten.mul.Tensor(mul_1275, rsqrt_44); mul_1275 = None | |
| + sum_197 = torch.ops.aten.sum.dim_IntList(mul_1276, [1], True); mul_1276 = None | |
| + mul_1278 = torch.ops.aten.mul.Scalar(sum_197, -0.5); sum_197 = None | |
| + pow_172 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_44, 3); rsqrt_44 = None | |
| + mul_1279 = torch.ops.aten.mul.Tensor(mul_1278, pow_172); mul_1278 = pow_172 = None | |
| + expand_107 = torch.ops.aten.expand.default(mul_1279, [4096, 512]); mul_1279 = None | |
| + div_138 = torch.ops.aten.div.Scalar(expand_107, 512); expand_107 = None | |
| + pow_173 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_432, 1.0); convert_element_type_432 = None | |
| + mul_1280 = torch.ops.aten.mul.Scalar(pow_173, 2.0); pow_173 = None | |
| + mul_1281 = torch.ops.aten.mul.Tensor(div_138, mul_1280); div_138 = mul_1280 = None | |
| + add_515 = torch.ops.aten.add.Tensor(mul_1277, mul_1281); mul_1277 = mul_1281 = None | |
| + convert_element_type_1521 = torch.ops.prims.convert_element_type.default(add_515, torch.bfloat16); add_515 = None | |
| + index_put_52 = torch.ops.aten.index_put.default(full_default_277, [sub_32], convert_element_type_1519, True); full_default_277 = None | |
| slice_scatter_64 = torch.ops.aten.slice_scatter.default(full_default_278, index_put_52, 1, 512, 9223372036854775807); index_put_52 = None | |
| - permute_870 = torch.ops.aten.permute.default(convert_element_type_1399, [1, 0]) | |
| + permute_870 = torch.ops.aten.permute.default(convert_element_type_1519, [1, 0]) | |
| mm_335 = torch.ops.aten.mm.default(permute_870, view_194); permute_870 = view_194 = None | |
| - convert_element_type_428 = torch.ops.prims.convert_element_type.default(primals_308, torch.bfloat16); primals_308 = None | |
| - permute_166 = torch.ops.aten.permute.default(convert_element_type_428, [1, 0]); convert_element_type_428 = None | |
| + convert_element_type_429 = torch.ops.prims.convert_element_type.default(primals_308, torch.bfloat16); primals_308 = None | |
| + permute_166 = torch.ops.aten.permute.default(convert_element_type_429, [1, 0]); convert_element_type_429 = None | |
| permute_872 = torch.ops.aten.permute.default(permute_166, [1, 0]); permute_166 = None | |
| - mm_336 = torch.ops.aten.mm.default(convert_element_type_1399, permute_872); convert_element_type_1399 = permute_872 = None | |
| - convert_element_type_1405 = torch.ops.prims.convert_element_type.default(mm_335, torch.float32); mm_335 = None | |
| - slice_scatter_65 = torch.ops.aten.slice_scatter.default(full_default_278, convert_element_type_1400, 1, 0, 512); full_default_278 = convert_element_type_1400 = None | |
| - add_471 = torch.ops.aten.add.Tensor(slice_scatter_64, slice_scatter_65); slice_scatter_64 = slice_scatter_65 = None | |
| - permute_874 = torch.ops.aten.permute.default(add_471, [1, 0]) | |
| + mm_336 = torch.ops.aten.mm.default(convert_element_type_1519, permute_872); convert_element_type_1519 = permute_872 = None | |
| + convert_element_type_1526 = torch.ops.prims.convert_element_type.default(mm_335, torch.float32); mm_335 = None | |
| + slice_scatter_65 = torch.ops.aten.slice_scatter.default(full_default_278, convert_element_type_1521, 1, 0, 512); full_default_278 = convert_element_type_1521 = None | |
| + add_516 = torch.ops.aten.add.Tensor(slice_scatter_64, slice_scatter_65); slice_scatter_64 = slice_scatter_65 = None | |
| + permute_874 = torch.ops.aten.permute.default(add_516, [1, 0]) | |
| mm_337 = torch.ops.aten.mm.default(permute_874, view_193); permute_874 = view_193 = None | |
| - convert_element_type_425 = torch.ops.prims.convert_element_type.default(primals_307, torch.bfloat16); primals_307 = None | |
| - permute_165 = torch.ops.aten.permute.default(convert_element_type_425, [1, 0]); convert_element_type_425 = None | |
| + convert_element_type_426 = torch.ops.prims.convert_element_type.default(primals_307, torch.bfloat16); primals_307 = None | |
| + permute_165 = torch.ops.aten.permute.default(convert_element_type_426, [1, 0]); convert_element_type_426 = None | |
| permute_876 = torch.ops.aten.permute.default(permute_165, [1, 0]); permute_165 = None | |
| - mm_338 = torch.ops.aten.mm.default(add_471, permute_876); add_471 = permute_876 = None | |
| - convert_element_type_1410 = torch.ops.prims.convert_element_type.default(mm_337, torch.float32); mm_337 = None | |
| - view_551 = torch.ops.aten.view.default(mm_336, [4096, 16, 112]); mm_336 = None | |
| - view_552 = torch.ops.aten.view.default(mm_338, [4096, 16, 112]); mm_338 = None | |
| + mm_338 = torch.ops.aten.mm.default(add_516, permute_876); add_516 = permute_876 = None | |
| + convert_element_type_1531 = torch.ops.prims.convert_element_type.default(mm_337, torch.float32); mm_337 = None | |
| + view_596 = torch.ops.aten.view.default(mm_336, [4096, 16, 112]); mm_336 = None | |
| + view_597 = torch.ops.aten.view.default(mm_338, [4096, 16, 112]); mm_338 = None | |
| index_put_53 = torch.ops.aten.index_put.default(full_default_286, [sub_32], slice_104, True); full_default_286 = None | |
| slice_scatter_66 = torch.ops.aten.slice_scatter.default(full_default_287, index_put_53, 1, 128, 9223372036854775807); index_put_53 = None | |
| permute_878 = torch.ops.aten.permute.default(slice_104, [0, 2, 1]); slice_104 = None | |
| - view_553 = torch.ops.aten.view.default(permute_878, [458752, 128]); permute_878 = None | |
| - permute_879 = torch.ops.aten.permute.default(view_553, [1, 0]) | |
| - convert_element_type_default_34 = torch.ops.prims.convert_element_type.default(fp8_quant_pos_709_view_183, torch.bfloat16); fp8_quant_pos_709_view_183 = None | |
| - div_tensor_12 = torch.ops.aten.div.Tensor(convert_element_type_default_34, fp8_scale_pos_709_view_183); convert_element_type_default_34 = fp8_scale_pos_709_view_183 = None | |
| - convert_element_type_default_35 = torch.ops.prims.convert_element_type.default(div_tensor_12, torch.bfloat16); div_tensor_12 = None | |
| - mm_339 = torch.ops.aten.mm.default(permute_879, convert_element_type_default_35); permute_879 = None | |
| - convert_element_type_421 = torch.ops.prims.convert_element_type.default(primals_306, torch.bfloat16); primals_306 = None | |
| - permute_163 = torch.ops.aten.permute.default(convert_element_type_421, [1, 0]); convert_element_type_421 = None | |
| + view_598 = torch.ops.aten.view.default(permute_878, [458752, 128]); permute_878 = None | |
| + permute_879 = torch.ops.aten.permute.default(view_598, [1, 0]) | |
| + convert_element_type_default_31 = torch.ops.prims.convert_element_type.default(fp8_quant_pos_696_view_183, torch.bfloat16); fp8_quant_pos_696_view_183 = None | |
| + div_tensor_12 = torch.ops.aten.div.Tensor(convert_element_type_default_31, fp8_scale_pos_696_view_183); convert_element_type_default_31 = fp8_scale_pos_696_view_183 = None | |
| + convert_element_type_default_32 = torch.ops.prims.convert_element_type.default(div_tensor_12, torch.bfloat16); div_tensor_12 = None | |
| + mm_339 = torch.ops.aten.mm.default(permute_879, convert_element_type_default_32); permute_879 = None | |
| + convert_element_type_422 = torch.ops.prims.convert_element_type.default(primals_306, torch.bfloat16); primals_306 = None | |
| + permute_163 = torch.ops.aten.permute.default(convert_element_type_422, [1, 0]); convert_element_type_422 = None | |
| permute_881 = torch.ops.aten.permute.default(permute_163, [1, 0]); permute_163 = None | |
| - mm_340 = torch.ops.aten.mm.default(view_553, permute_881); view_553 = permute_881 = None | |
| - view_554 = torch.ops.aten.view.default(mm_340, [4096, 112, 452]); mm_340 = None | |
| - permute_883 = torch.ops.aten.permute.default(view_554, [0, 2, 1]); view_554 = None | |
| - convert_element_type_1415 = torch.ops.prims.convert_element_type.default(permute_883, torch.float32); permute_883 = None | |
| - add_472 = torch.ops.aten.add.Tensor(slice_112, convert_element_type_1415); slice_112 = convert_element_type_1415 = None | |
| - convert_element_type_1416 = torch.ops.prims.convert_element_type.default(mm_339, torch.float32); mm_339 = None | |
| + mm_340 = torch.ops.aten.mm.default(view_598, permute_881); view_598 = permute_881 = None | |
| + view_599 = torch.ops.aten.view.default(mm_340, [4096, 112, 452]); mm_340 = None | |
| + permute_883 = torch.ops.aten.permute.default(view_599, [0, 2, 1]); view_599 = None | |
| + convert_element_type_1536 = torch.ops.prims.convert_element_type.default(permute_883, torch.float32); permute_883 = None | |
| + add_517 = torch.ops.aten.add.Tensor(slice_112, convert_element_type_1536); slice_112 = convert_element_type_1536 = None | |
| + convert_element_type_1537 = torch.ops.prims.convert_element_type.default(mm_339, torch.float32); mm_339 = None | |
| slice_scatter_67 = torch.ops.aten.slice_scatter.default(full_default_287, slice_106, 1, 0, 128); full_default_287 = slice_106 = None | |
| - add_473 = torch.ops.aten.add.Tensor(slice_scatter_66, slice_scatter_67); slice_scatter_66 = slice_scatter_67 = None | |
| - permute_884 = torch.ops.aten.permute.default(add_473, [0, 2, 1]); add_473 = None | |
| + add_518 = torch.ops.aten.add.Tensor(slice_scatter_66, slice_scatter_67); slice_scatter_66 = slice_scatter_67 = None | |
| + permute_884 = torch.ops.aten.permute.default(add_518, [0, 2, 1]); add_518 = None | |
| clone_108 = torch.ops.aten.clone.default(permute_884, memory_format = torch.contiguous_format); permute_884 = None | |
| - view_555 = torch.ops.aten.view.default(clone_108, [458752, 256]); clone_108 = None | |
| - permute_885 = torch.ops.aten.permute.default(view_555, [1, 0]) | |
| - convert_element_type_default_32 = torch.ops.prims.convert_element_type.default(fp8_quant_pos_708_view_181, torch.bfloat16); fp8_quant_pos_708_view_181 = None | |
| - div_tensor_11 = torch.ops.aten.div.Tensor(convert_element_type_default_32, fp8_scale_pos_708_view_181); convert_element_type_default_32 = fp8_scale_pos_708_view_181 = None | |
| - convert_element_type_default_33 = torch.ops.prims.convert_element_type.default(div_tensor_11, torch.bfloat16); div_tensor_11 = None | |
| - mm_341 = torch.ops.aten.mm.default(permute_885, convert_element_type_default_33); permute_885 = None | |
| - convert_element_type_417 = torch.ops.prims.convert_element_type.default(primals_305, torch.bfloat16); primals_305 = None | |
| - permute_160 = torch.ops.aten.permute.default(convert_element_type_417, [1, 0]); convert_element_type_417 = None | |
| + view_600 = torch.ops.aten.view.default(clone_108, [458752, 256]); clone_108 = None | |
| + permute_885 = torch.ops.aten.permute.default(view_600, [1, 0]) | |
| + convert_element_type_default_29 = torch.ops.prims.convert_element_type.default(fp8_quant_pos_695_view_181, torch.bfloat16); fp8_quant_pos_695_view_181 = None | |
| + div_tensor_11 = torch.ops.aten.div.Tensor(convert_element_type_default_29, fp8_scale_pos_695_view_181); convert_element_type_default_29 = fp8_scale_pos_695_view_181 = None | |
| + convert_element_type_default_30 = torch.ops.prims.convert_element_type.default(div_tensor_11, torch.bfloat16); div_tensor_11 = None | |
| + mm_341 = torch.ops.aten.mm.default(permute_885, convert_element_type_default_30); permute_885 = None | |
| + convert_element_type_418 = torch.ops.prims.convert_element_type.default(primals_305, torch.bfloat16); primals_305 = None | |
| + permute_160 = torch.ops.aten.permute.default(convert_element_type_418, [1, 0]); convert_element_type_418 = None | |
| permute_887 = torch.ops.aten.permute.default(permute_160, [1, 0]); permute_160 = None | |
| - mm_342 = torch.ops.aten.mm.default(view_555, permute_887); view_555 = permute_887 = None | |
| - view_556 = torch.ops.aten.view.default(mm_342, [4096, 112, 1295]); mm_342 = None | |
| - permute_889 = torch.ops.aten.permute.default(view_556, [0, 2, 1]); view_556 = None | |
| - convert_element_type_1421 = torch.ops.prims.convert_element_type.default(permute_889, torch.float32); permute_889 = None | |
| - add_474 = torch.ops.aten.add.Tensor(add_464, convert_element_type_1421); add_464 = convert_element_type_1421 = None | |
| - convert_element_type_1422 = torch.ops.prims.convert_element_type.default(mm_341, torch.float32); mm_341 = None | |
| + mm_342 = torch.ops.aten.mm.default(view_600, permute_887); view_600 = permute_887 = None | |
| + view_601 = torch.ops.aten.view.default(mm_342, [4096, 112, 1295]); mm_342 = None | |
| + permute_889 = torch.ops.aten.permute.default(view_601, [0, 2, 1]); view_601 = None | |
| + convert_element_type_1542 = torch.ops.prims.convert_element_type.default(permute_889, torch.float32); permute_889 = None | |
| + add_519 = torch.ops.aten.add.Tensor(add_505, convert_element_type_1542); add_505 = convert_element_type_1542 = None | |
| + convert_element_type_1543 = torch.ops.prims.convert_element_type.default(mm_341, torch.float32); mm_341 = None | |
| full_default_378 = torch.ops.aten.full.default([4096, 16, 112], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| - index_put_54 = torch.ops.aten.index_put.default(full_default_378, [sub_32], view_551, True); sub_32 = None | |
| + index_put_54 = torch.ops.aten.index_put.default(full_default_378, [sub_32], view_596, True); sub_32 = None | |
| full_default_379 = torch.ops.aten.full.default([4096, 32, 112], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| slice_scatter_68 = torch.ops.aten.slice_scatter.default(full_default_379, index_put_54, 1, 16, 9223372036854775807); index_put_54 = None | |
| - permute_890 = torch.ops.aten.permute.default(view_551, [0, 2, 1]); view_551 = None | |
| + permute_890 = torch.ops.aten.permute.default(view_596, [0, 2, 1]); view_596 = None | |
| clone_109 = torch.ops.aten.clone.default(permute_890, memory_format = torch.contiguous_format); permute_890 = None | |
| - view_557 = torch.ops.aten.view.default(clone_109, [458752, 16]); clone_109 = None | |
| - permute_891 = torch.ops.aten.permute.default(view_557, [1, 0]) | |
| - mm_343 = torch.ops.aten.mm.default(permute_891, convert_element_type_default_35); permute_891 = None | |
| - convert_element_type_413 = torch.ops.prims.convert_element_type.default(primals_304, torch.bfloat16); primals_304 = None | |
| - permute_157 = torch.ops.aten.permute.default(convert_element_type_413, [1, 0]); convert_element_type_413 = None | |
| + view_602 = torch.ops.aten.view.default(clone_109, [458752, 16]); clone_109 = None | |
| + permute_891 = torch.ops.aten.permute.default(view_602, [1, 0]) | |
| + mm_343 = torch.ops.aten.mm.default(permute_891, convert_element_type_default_32); permute_891 = None | |
| + convert_element_type_414 = torch.ops.prims.convert_element_type.default(primals_304, torch.bfloat16); primals_304 = None | |
| + permute_157 = torch.ops.aten.permute.default(convert_element_type_414, [1, 0]); convert_element_type_414 = None | |
| permute_893 = torch.ops.aten.permute.default(permute_157, [1, 0]); permute_157 = None | |
| - mm_344 = torch.ops.aten.mm.default(view_557, permute_893); view_557 = permute_893 = None | |
| - view_558 = torch.ops.aten.view.default(mm_344, [4096, 112, 452]); mm_344 = None | |
| - permute_895 = torch.ops.aten.permute.default(view_558, [0, 2, 1]); view_558 = None | |
| - convert_element_type_1427 = torch.ops.prims.convert_element_type.default(permute_895, torch.float32); permute_895 = None | |
| - add_475 = torch.ops.aten.add.Tensor(add_472, convert_element_type_1427); add_472 = convert_element_type_1427 = None | |
| - convert_element_type_1428 = torch.ops.prims.convert_element_type.default(mm_343, torch.float32); mm_343 = None | |
| - slice_scatter_69 = torch.ops.aten.slice_scatter.default(full_default_379, view_552, 1, 0, 16); full_default_379 = view_552 = None | |
| - add_476 = torch.ops.aten.add.Tensor(slice_scatter_68, slice_scatter_69); slice_scatter_68 = slice_scatter_69 = None | |
| - permute_896 = torch.ops.aten.permute.default(add_476, [0, 2, 1]); add_476 = None | |
| + mm_344 = torch.ops.aten.mm.default(view_602, permute_893); view_602 = permute_893 = None | |
| + view_603 = torch.ops.aten.view.default(mm_344, [4096, 112, 452]); mm_344 = None | |
| + permute_895 = torch.ops.aten.permute.default(view_603, [0, 2, 1]); view_603 = None | |
| + convert_element_type_1548 = torch.ops.prims.convert_element_type.default(permute_895, torch.float32); permute_895 = None | |
| + add_520 = torch.ops.aten.add.Tensor(add_517, convert_element_type_1548); add_517 = convert_element_type_1548 = None | |
| + convert_element_type_1549 = torch.ops.prims.convert_element_type.default(mm_343, torch.float32); mm_343 = None | |
| + slice_scatter_69 = torch.ops.aten.slice_scatter.default(full_default_379, view_597, 1, 0, 16); full_default_379 = view_597 = None | |
| + add_521 = torch.ops.aten.add.Tensor(slice_scatter_68, slice_scatter_69); slice_scatter_68 = slice_scatter_69 = None | |
| + permute_896 = torch.ops.aten.permute.default(add_521, [0, 2, 1]); add_521 = None | |
| clone_110 = torch.ops.aten.clone.default(permute_896, memory_format = torch.contiguous_format); permute_896 = None | |
| - view_559 = torch.ops.aten.view.default(clone_110, [458752, 32]); clone_110 = None | |
| - permute_897 = torch.ops.aten.permute.default(view_559, [1, 0]) | |
| - mm_345 = torch.ops.aten.mm.default(permute_897, convert_element_type_default_33); permute_897 = None | |
| - convert_element_type_409 = torch.ops.prims.convert_element_type.default(primals_303, torch.bfloat16); primals_303 = None | |
| - permute_154 = torch.ops.aten.permute.default(convert_element_type_409, [1, 0]); convert_element_type_409 = None | |
| + view_604 = torch.ops.aten.view.default(clone_110, [458752, 32]); clone_110 = None | |
| + permute_897 = torch.ops.aten.permute.default(view_604, [1, 0]) | |
| + mm_345 = torch.ops.aten.mm.default(permute_897, convert_element_type_default_30); permute_897 = None | |
| + convert_element_type_410 = torch.ops.prims.convert_element_type.default(primals_303, torch.bfloat16); primals_303 = None | |
| + permute_154 = torch.ops.aten.permute.default(convert_element_type_410, [1, 0]); convert_element_type_410 = None | |
| permute_899 = torch.ops.aten.permute.default(permute_154, [1, 0]); permute_154 = None | |
| - mm_346 = torch.ops.aten.mm.default(view_559, permute_899); view_559 = permute_899 = None | |
| - view_560 = torch.ops.aten.view.default(mm_346, [4096, 112, 1295]); mm_346 = None | |
| - permute_901 = torch.ops.aten.permute.default(view_560, [0, 2, 1]); view_560 = None | |
| - convert_element_type_1433 = torch.ops.prims.convert_element_type.default(permute_901, torch.float32); permute_901 = None | |
| - add_477 = torch.ops.aten.add.Tensor(add_474, convert_element_type_1433); add_474 = convert_element_type_1433 = None | |
| - convert_element_type_1434 = torch.ops.prims.convert_element_type.default(mm_345, torch.float32); mm_345 = None | |
| - mm_347 = torch.ops.aten.mm.default(permute_777, convert_element_type_default_35); permute_777 = convert_element_type_default_35 = None | |
| - convert_element_type_405 = torch.ops.prims.convert_element_type.default(primals_302, torch.bfloat16); primals_302 = None | |
| - permute_151 = torch.ops.aten.permute.default(convert_element_type_405, [1, 0]); convert_element_type_405 = None | |
| + mm_346 = torch.ops.aten.mm.default(view_604, permute_899); view_604 = permute_899 = None | |
| + view_605 = torch.ops.aten.view.default(mm_346, [4096, 112, 1295]); mm_346 = None | |
| + permute_901 = torch.ops.aten.permute.default(view_605, [0, 2, 1]); view_605 = None | |
| + convert_element_type_1554 = torch.ops.prims.convert_element_type.default(permute_901, torch.float32); permute_901 = None | |
| + add_522 = torch.ops.aten.add.Tensor(add_519, convert_element_type_1554); add_519 = convert_element_type_1554 = None | |
| + convert_element_type_1555 = torch.ops.prims.convert_element_type.default(mm_345, torch.float32); mm_345 = None | |
| + mm_347 = torch.ops.aten.mm.default(permute_777, convert_element_type_default_32); permute_777 = convert_element_type_default_32 = None | |
| + convert_element_type_406 = torch.ops.prims.convert_element_type.default(primals_302, torch.bfloat16); primals_302 = None | |
| + permute_151 = torch.ops.aten.permute.default(convert_element_type_406, [1, 0]); convert_element_type_406 = None | |
| permute_905 = torch.ops.aten.permute.default(permute_151, [1, 0]); permute_151 = None | |
| - mm_348 = torch.ops.aten.mm.default(view_509, permute_905); view_509 = permute_905 = None | |
| - view_562 = torch.ops.aten.view.default(mm_348, [4096, 112, 452]); mm_348 = None | |
| - permute_907 = torch.ops.aten.permute.default(view_562, [0, 2, 1]); view_562 = None | |
| - convert_element_type_1439 = torch.ops.prims.convert_element_type.default(permute_907, torch.float32); permute_907 = None | |
| - add_478 = torch.ops.aten.add.Tensor(add_475, convert_element_type_1439); add_475 = convert_element_type_1439 = None | |
| - convert_element_type_1440 = torch.ops.prims.convert_element_type.default(mm_347, torch.float32); mm_347 = None | |
| - mm_349 = torch.ops.aten.mm.default(permute_783, convert_element_type_default_33); permute_783 = convert_element_type_default_33 = None | |
| - convert_element_type_401 = torch.ops.prims.convert_element_type.default(primals_301, torch.bfloat16); primals_301 = None | |
| - permute_148 = torch.ops.aten.permute.default(convert_element_type_401, [1, 0]); convert_element_type_401 = None | |
| + mm_348 = torch.ops.aten.mm.default(view_540, permute_905); view_540 = permute_905 = None | |
| + view_607 = torch.ops.aten.view.default(mm_348, [4096, 112, 452]); mm_348 = None | |
| + permute_907 = torch.ops.aten.permute.default(view_607, [0, 2, 1]); view_607 = None | |
| + convert_element_type_1560 = torch.ops.prims.convert_element_type.default(permute_907, torch.float32); permute_907 = None | |
| + add_523 = torch.ops.aten.add.Tensor(add_520, convert_element_type_1560); add_520 = convert_element_type_1560 = None | |
| + convert_element_type_1561 = torch.ops.prims.convert_element_type.default(mm_347, torch.float32); mm_347 = None | |
| + mm_349 = torch.ops.aten.mm.default(permute_783, convert_element_type_default_30); permute_783 = convert_element_type_default_30 = None | |
| + convert_element_type_402 = torch.ops.prims.convert_element_type.default(primals_301, torch.bfloat16); primals_301 = None | |
| + permute_148 = torch.ops.aten.permute.default(convert_element_type_402, [1, 0]); convert_element_type_402 = None | |
| permute_911 = torch.ops.aten.permute.default(permute_148, [1, 0]); permute_148 = None | |
| - mm_350 = torch.ops.aten.mm.default(view_511, permute_911); view_511 = permute_911 = None | |
| - view_564 = torch.ops.aten.view.default(mm_350, [4096, 112, 1295]); mm_350 = None | |
| - permute_913 = torch.ops.aten.permute.default(view_564, [0, 2, 1]); view_564 = None | |
| - convert_element_type_1445 = torch.ops.prims.convert_element_type.default(permute_913, torch.float32); permute_913 = None | |
| - add_480 = torch.ops.aten.add.Tensor(add_477, convert_element_type_1445); add_477 = convert_element_type_1445 = None | |
| - convert_element_type_1446 = torch.ops.prims.convert_element_type.default(mm_349, torch.float32); mm_349 = None | |
| - mul_1192 = torch.ops.aten.mul.Tensor(add_460, primals_300); primals_300 = None | |
| - convert_element_type_391 = torch.ops.prims.convert_element_type.default(primals_293, torch.bfloat16); primals_293 = None | |
| - convert_element_type_392 = torch.ops.prims.convert_element_type.default(primals_294, torch.bfloat16); primals_294 = None | |
| - permute_145 = torch.ops.aten.permute.default(convert_element_type_391, [1, 0]); convert_element_type_391 = None | |
| - addmm_28 = torch.ops.aten.addmm.default(convert_element_type_392, cat, permute_145); convert_element_type_392 = None | |
| + mm_350 = torch.ops.aten.mm.default(view_542, permute_911); view_542 = permute_911 = None | |
| + view_609 = torch.ops.aten.view.default(mm_350, [4096, 112, 1295]); mm_350 = None | |
| + permute_913 = torch.ops.aten.permute.default(view_609, [0, 2, 1]); view_609 = None | |
| + convert_element_type_1566 = torch.ops.prims.convert_element_type.default(permute_913, torch.float32); permute_913 = None | |
| + add_525 = torch.ops.aten.add.Tensor(add_522, convert_element_type_1566); add_522 = convert_element_type_1566 = None | |
| + convert_element_type_1567 = torch.ops.prims.convert_element_type.default(mm_349, torch.float32); mm_349 = None | |
| + mul_1282 = torch.ops.aten.mul.Tensor(add_501, primals_300); primals_300 = None | |
| + convert_element_type_392 = torch.ops.prims.convert_element_type.default(primals_293, torch.bfloat16); primals_293 = None | |
| + convert_element_type_393 = torch.ops.prims.convert_element_type.default(primals_294, torch.bfloat16); primals_294 = None | |
| + permute_145 = torch.ops.aten.permute.default(convert_element_type_392, [1, 0]); convert_element_type_392 = None | |
| + addmm_28 = torch.ops.aten.addmm.default(convert_element_type_393, cat, permute_145); convert_element_type_393 = None | |
| cat_12 = torch.ops.aten.cat.default([primals_104, addmm_28], 1); primals_104 = addmm_28 = None | |
| pow_23 = torch.ops.aten.pow.Tensor_Scalar(cat_12, 2) | |
| mean_22 = torch.ops.aten.mean.dim(pow_23, [1], True); pow_23 = None | |
| add_104 = torch.ops.aten.add.Scalar(mean_22, 1.1920928955078125e-07); mean_22 = None | |
| rsqrt_43 = torch.ops.aten.rsqrt.default(add_104); add_104 = None | |
| mul_116 = torch.ops.aten.mul.Tensor(cat_12, rsqrt_43); cat_12 = None | |
| - mul_1194 = torch.ops.aten.mul.Tensor(mul_116, mul_1192) | |
| - sum_198 = torch.ops.aten.sum.dim_IntList(mul_1194, [1], True); mul_1194 = None | |
| + mul_1284 = torch.ops.aten.mul.Tensor(mul_116, mul_1282) | |
| + sum_198 = torch.ops.aten.sum.dim_IntList(mul_1284, [1], True); mul_1284 = None | |
| div_139 = torch.ops.aten.div.Tensor(mul_116, 4096) | |
| - mul_1195 = torch.ops.aten.mul.Tensor(div_139, sum_198); div_139 = sum_198 = None | |
| - sub_282 = torch.ops.aten.sub.Tensor(mul_1192, mul_1195); mul_1192 = mul_1195 = None | |
| - mul_1196 = torch.ops.aten.mul.Tensor(sub_282, rsqrt_43); sub_282 = rsqrt_43 = None | |
| - mul_1197 = torch.ops.aten.mul.Tensor(add_460, mul_116); add_460 = mul_116 = None | |
| - sum_199 = torch.ops.aten.sum.dim_IntList(mul_1197, [0]); mul_1197 = None | |
| - mul_1198 = torch.ops.aten.mul.Tensor(add_461, primals_299); primals_299 = None | |
| - convert_element_type_396 = torch.ops.prims.convert_element_type.default(primals_295, torch.bfloat16); primals_295 = None | |
| - convert_element_type_397 = torch.ops.prims.convert_element_type.default(primals_296, torch.bfloat16); primals_296 = None | |
| - permute_146 = torch.ops.aten.permute.default(convert_element_type_396, [1, 0]); convert_element_type_396 = None | |
| - addmm_29 = torch.ops.aten.addmm.default(convert_element_type_397, clone, permute_146); convert_element_type_397 = None | |
| + mul_1285 = torch.ops.aten.mul.Tensor(div_139, sum_198); div_139 = sum_198 = None | |
| + sub_237 = torch.ops.aten.sub.Tensor(mul_1282, mul_1285); mul_1282 = mul_1285 = None | |
| + mul_1286 = torch.ops.aten.mul.Tensor(sub_237, rsqrt_43); sub_237 = rsqrt_43 = None | |
| + mul_1287 = torch.ops.aten.mul.Tensor(add_501, mul_116); add_501 = mul_116 = None | |
| + sum_199 = torch.ops.aten.sum.dim_IntList(mul_1287, [0]); mul_1287 = None | |
| + mul_1288 = torch.ops.aten.mul.Tensor(add_502, primals_299); primals_299 = None | |
| + convert_element_type_397 = torch.ops.prims.convert_element_type.default(primals_295, torch.bfloat16); primals_295 = None | |
| + convert_element_type_398 = torch.ops.prims.convert_element_type.default(primals_296, torch.bfloat16); primals_296 = None | |
| + permute_146 = torch.ops.aten.permute.default(convert_element_type_397, [1, 0]); convert_element_type_397 = None | |
| + addmm_29 = torch.ops.aten.addmm.default(convert_element_type_398, clone, permute_146); convert_element_type_398 = None | |
| cat_11 = torch.ops.aten.cat.default([primals_105, addmm_29], 1); primals_105 = addmm_29 = None | |
| pow_22 = torch.ops.aten.pow.Tensor_Scalar(cat_11, 2) | |
| mean_21 = torch.ops.aten.mean.dim(pow_22, [1], True); pow_22 = None | |
| add_103 = torch.ops.aten.add.Scalar(mean_21, 1.1920928955078125e-07); mean_21 = None | |
| rsqrt_42 = torch.ops.aten.rsqrt.default(add_103); add_103 = None | |
| mul_114 = torch.ops.aten.mul.Tensor(cat_11, rsqrt_42); cat_11 = None | |
| - mul_1200 = torch.ops.aten.mul.Tensor(mul_114, mul_1198) | |
| - sum_200 = torch.ops.aten.sum.dim_IntList(mul_1200, [1], True); mul_1200 = None | |
| + mul_1290 = torch.ops.aten.mul.Tensor(mul_114, mul_1288) | |
| + sum_200 = torch.ops.aten.sum.dim_IntList(mul_1290, [1], True); mul_1290 = None | |
| div_140 = torch.ops.aten.div.Tensor(mul_114, 4096) | |
| - mul_1201 = torch.ops.aten.mul.Tensor(div_140, sum_200); div_140 = sum_200 = None | |
| - sub_283 = torch.ops.aten.sub.Tensor(mul_1198, mul_1201); mul_1198 = mul_1201 = None | |
| - mul_1202 = torch.ops.aten.mul.Tensor(sub_283, rsqrt_42); sub_283 = rsqrt_42 = None | |
| - mul_1203 = torch.ops.aten.mul.Tensor(add_461, mul_114); add_461 = mul_114 = None | |
| - sum_201 = torch.ops.aten.sum.dim_IntList(mul_1203, [0]); mul_1203 = None | |
| - mul_1204 = torch.ops.aten.mul.Tensor(add_478, primals_298); primals_298 = None | |
| - convert_element_type_default_12 = torch.ops.prims.convert_element_type.default(fp8_quant_pos_154_primals_1, torch.bfloat16); fp8_quant_pos_154_primals_1 = None | |
| - div_tensor_1 = torch.ops.aten.div.Tensor(convert_element_type_default_12, fp8_scale_pos_154_primals_1); convert_element_type_default_12 = fp8_scale_pos_154_primals_1 = None | |
| - convert_element_type_default_13 = torch.ops.prims.convert_element_type.default(div_tensor_1, torch.bfloat16); div_tensor_1 = None | |
| - tanh = torch.ops.aten.tanh.default(convert_element_type_default_13); convert_element_type_default_13 = None | |
| + mul_1291 = torch.ops.aten.mul.Tensor(div_140, sum_200); div_140 = sum_200 = None | |
| + sub_238 = torch.ops.aten.sub.Tensor(mul_1288, mul_1291); mul_1288 = mul_1291 = None | |
| + mul_1292 = torch.ops.aten.mul.Tensor(sub_238, rsqrt_42); sub_238 = rsqrt_42 = None | |
| + mul_1293 = torch.ops.aten.mul.Tensor(add_502, mul_114); add_502 = mul_114 = None | |
| + sum_201 = torch.ops.aten.sum.dim_IntList(mul_1293, [0]); mul_1293 = None | |
| + mul_1294 = torch.ops.aten.mul.Tensor(add_523, primals_298); primals_298 = None | |
| + convert_element_type_default_9 = torch.ops.prims.convert_element_type.default(fp8_quant_pos_154_primals_1, torch.bfloat16); fp8_quant_pos_154_primals_1 = None | |
| + div_tensor_1 = torch.ops.aten.div.Tensor(convert_element_type_default_9, fp8_scale_pos_154_primals_1); convert_element_type_default_9 = fp8_scale_pos_154_primals_1 = None | |
| + convert_element_type_default_10 = torch.ops.prims.convert_element_type.default(div_tensor_1, torch.bfloat16); div_tensor_1 = None | |
| + tanh = torch.ops.aten.tanh.default(convert_element_type_default_10); convert_element_type_default_10 = None | |
| cat_1 = torch.ops.aten.cat.default([tanh, primals_102, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17], 1); primals_102 = primals_3 = primals_4 = primals_5 = primals_6 = primals_7 = primals_8 = primals_9 = primals_10 = primals_11 = primals_12 = primals_13 = primals_14 = primals_15 = primals_16 = primals_17 = None | |
| view = torch.ops.aten.view.default(cat_1, [4096, -1, 112]); cat_1 = None | |
| - convert_element_type_96 = torch.ops.prims.convert_element_type.default(primals_163, torch.bfloat16); primals_163 = None | |
| - convert_element_type_97 = torch.ops.prims.convert_element_type.default(primals_164, torch.bfloat16); primals_164 = None | |
| - permute_34 = torch.ops.aten.permute.default(convert_element_type_96, [1, 0]); convert_element_type_96 = None | |
| - addmm_5 = torch.ops.aten.addmm.default(convert_element_type_97, view_47, permute_34); convert_element_type_97 = None | |
| + convert_element_type_97 = torch.ops.prims.convert_element_type.default(primals_163, torch.bfloat16); primals_163 = None | |
| + convert_element_type_98 = torch.ops.prims.convert_element_type.default(primals_164, torch.bfloat16); primals_164 = None | |
| + permute_34 = torch.ops.aten.permute.default(convert_element_type_97, [1, 0]); convert_element_type_97 = None | |
| + addmm_5 = torch.ops.aten.addmm.default(convert_element_type_98, view_47, permute_34); convert_element_type_98 = None | |
| view_48 = torch.ops.aten.view.default(addmm_5, [4096, 16, 112]); addmm_5 = None | |
| - convert_element_type_135 = torch.ops.prims.convert_element_type.default(primals_179, torch.bfloat16); primals_179 = None | |
| - convert_element_type_136 = torch.ops.prims.convert_element_type.default(primals_180, torch.bfloat16); primals_180 = None | |
| - permute_43 = torch.ops.aten.permute.default(convert_element_type_135, [1, 0]); convert_element_type_135 = None | |
| - addmm_10 = torch.ops.aten.addmm.default(convert_element_type_136, view_57, permute_43); convert_element_type_136 = None | |
| + convert_element_type_136 = torch.ops.prims.convert_element_type.default(primals_179, torch.bfloat16); primals_179 = None | |
| + convert_element_type_137 = torch.ops.prims.convert_element_type.default(primals_180, torch.bfloat16); primals_180 = None | |
| + permute_43 = torch.ops.aten.permute.default(convert_element_type_136, [1, 0]); convert_element_type_136 = None | |
| + addmm_10 = torch.ops.aten.addmm.default(convert_element_type_137, view_57, permute_43); convert_element_type_137 = None | |
| view_58 = torch.ops.aten.view.default(addmm_10, [4096, 16, 112]); addmm_10 = None | |
| - convert_element_type_174 = torch.ops.prims.convert_element_type.default(primals_195, torch.bfloat16); primals_195 = None | |
| - convert_element_type_175 = torch.ops.prims.convert_element_type.default(primals_196, torch.bfloat16); primals_196 = None | |
| - permute_52 = torch.ops.aten.permute.default(convert_element_type_174, [1, 0]); convert_element_type_174 = None | |
| - addmm_15 = torch.ops.aten.addmm.default(convert_element_type_175, view_67, permute_52); convert_element_type_175 = None | |
| + convert_element_type_175 = torch.ops.prims.convert_element_type.default(primals_195, torch.bfloat16); primals_195 = None | |
| + convert_element_type_176 = torch.ops.prims.convert_element_type.default(primals_196, torch.bfloat16); primals_196 = None | |
| + permute_52 = torch.ops.aten.permute.default(convert_element_type_175, [1, 0]); convert_element_type_175 = None | |
| + addmm_15 = torch.ops.aten.addmm.default(convert_element_type_176, view_67, permute_52); convert_element_type_176 = None | |
| view_68 = torch.ops.aten.view.default(addmm_15, [4096, 16, 112]); addmm_15 = None | |
| - convert_element_type_227 = torch.ops.prims.convert_element_type.default(primals_219, torch.bfloat16); primals_219 = None | |
| - convert_element_type_228 = torch.ops.prims.convert_element_type.default(primals_220, torch.bfloat16); primals_220 = None | |
| - permute_75 = torch.ops.aten.permute.default(convert_element_type_227, [1, 0]); convert_element_type_227 = None | |
| - addmm_18 = torch.ops.aten.addmm.default(convert_element_type_228, view_95, permute_75); convert_element_type_228 = None | |
| + convert_element_type_228 = torch.ops.prims.convert_element_type.default(primals_219, torch.bfloat16); primals_219 = None | |
| + convert_element_type_229 = torch.ops.prims.convert_element_type.default(primals_220, torch.bfloat16); primals_220 = None | |
| + permute_75 = torch.ops.aten.permute.default(convert_element_type_228, [1, 0]); convert_element_type_228 = None | |
| + addmm_18 = torch.ops.aten.addmm.default(convert_element_type_229, view_95, permute_75); convert_element_type_229 = None | |
| view_96 = torch.ops.aten.view.default(addmm_18, [4096, 16, 112]); addmm_18 = None | |
| - convert_element_type_280 = torch.ops.prims.convert_element_type.default(primals_243, torch.bfloat16); primals_243 = None | |
| - convert_element_type_281 = torch.ops.prims.convert_element_type.default(primals_244, torch.bfloat16); primals_244 = None | |
| - permute_98 = torch.ops.aten.permute.default(convert_element_type_280, [1, 0]); convert_element_type_280 = None | |
| - addmm_21 = torch.ops.aten.addmm.default(convert_element_type_281, view_123, permute_98); convert_element_type_281 = None | |
| + convert_element_type_281 = torch.ops.prims.convert_element_type.default(primals_243, torch.bfloat16); primals_243 = None | |
| + convert_element_type_282 = torch.ops.prims.convert_element_type.default(primals_244, torch.bfloat16); primals_244 = None | |
| + permute_98 = torch.ops.aten.permute.default(convert_element_type_281, [1, 0]); convert_element_type_281 = None | |
| + addmm_21 = torch.ops.aten.addmm.default(convert_element_type_282, view_123, permute_98); convert_element_type_282 = None | |
| view_124 = torch.ops.aten.view.default(addmm_21, [4096, 16, 112]); addmm_21 = None | |
| - convert_element_type_333 = torch.ops.prims.convert_element_type.default(primals_267, torch.bfloat16); primals_267 = None | |
| - convert_element_type_334 = torch.ops.prims.convert_element_type.default(primals_268, torch.bfloat16); primals_268 = None | |
| - permute_121 = torch.ops.aten.permute.default(convert_element_type_333, [1, 0]); convert_element_type_333 = None | |
| - addmm_24 = torch.ops.aten.addmm.default(convert_element_type_334, view_151, permute_121); convert_element_type_334 = None | |
| + convert_element_type_334 = torch.ops.prims.convert_element_type.default(primals_267, torch.bfloat16); primals_267 = None | |
| + convert_element_type_335 = torch.ops.prims.convert_element_type.default(primals_268, torch.bfloat16); primals_268 = None | |
| + permute_121 = torch.ops.aten.permute.default(convert_element_type_334, [1, 0]); convert_element_type_334 = None | |
| + addmm_24 = torch.ops.aten.addmm.default(convert_element_type_335, view_151, permute_121); convert_element_type_335 = None | |
| view_152 = torch.ops.aten.view.default(addmm_24, [4096, 16, 112]); addmm_24 = None | |
| - convert_element_type_386 = torch.ops.prims.convert_element_type.default(primals_291, torch.bfloat16); primals_291 = None | |
| - convert_element_type_387 = torch.ops.prims.convert_element_type.default(primals_292, torch.bfloat16); primals_292 = None | |
| - permute_144 = torch.ops.aten.permute.default(convert_element_type_386, [1, 0]); convert_element_type_386 = None | |
| - addmm_27 = torch.ops.aten.addmm.default(convert_element_type_387, view_179, permute_144); convert_element_type_387 = None | |
| + convert_element_type_387 = torch.ops.prims.convert_element_type.default(primals_291, torch.bfloat16); primals_291 = None | |
| + convert_element_type_388 = torch.ops.prims.convert_element_type.default(primals_292, torch.bfloat16); primals_292 = None | |
| + permute_144 = torch.ops.aten.permute.default(convert_element_type_387, [1, 0]); convert_element_type_387 = None | |
| + addmm_27 = torch.ops.aten.addmm.default(convert_element_type_388, view_179, permute_144); convert_element_type_388 = None | |
| view_180 = torch.ops.aten.view.default(addmm_27, [4096, 16, 112]); addmm_27 = None | |
| cat_10 = torch.ops.aten.cat.default([view, view_48, view_58, view_68, view_96, view_124, view_152, view_180], 1); view_48 = view_58 = view_68 = view_96 = view_124 = view_152 = view_180 = None | |
| pow_21 = torch.ops.aten.pow.Tensor_Scalar(cat_10, 2) | |
| @@ -8356,19 +5686,19 @@ | |
| add_102 = torch.ops.aten.add.Scalar(mean_20, 1.1920928955078125e-07); mean_20 = None | |
| rsqrt_41 = torch.ops.aten.rsqrt.default(add_102); add_102 = None | |
| mul_112 = torch.ops.aten.mul.Tensor(cat_10, rsqrt_41); cat_10 = None | |
| - mul_1206 = torch.ops.aten.mul.Tensor(mul_112, mul_1204) | |
| - sum_202 = torch.ops.aten.sum.dim_IntList(mul_1206, [2], True); mul_1206 = None | |
| + mul_1296 = torch.ops.aten.mul.Tensor(mul_112, mul_1294) | |
| + sum_202 = torch.ops.aten.sum.dim_IntList(mul_1296, [2], True); mul_1296 = None | |
| div_141 = torch.ops.aten.div.Tensor(mul_112, 112) | |
| - mul_1207 = torch.ops.aten.mul.Tensor(div_141, sum_202); div_141 = sum_202 = None | |
| - sub_284 = torch.ops.aten.sub.Tensor(mul_1204, mul_1207); mul_1204 = mul_1207 = None | |
| - mul_1208 = torch.ops.aten.mul.Tensor(sub_284, rsqrt_41); sub_284 = rsqrt_41 = None | |
| - mul_1209 = torch.ops.aten.mul.Tensor(add_478, mul_112); add_478 = mul_112 = None | |
| - sum_203 = torch.ops.aten.sum.dim_IntList(mul_1209, [0, 1]); mul_1209 = None | |
| - mul_1210 = torch.ops.aten.mul.Tensor(add_480, primals_297); primals_297 = None | |
| - convert_element_type_default_14 = torch.ops.prims.convert_element_type.default(fp8_quant_pos_155_primals_2, torch.bfloat16); fp8_quant_pos_155_primals_2 = None | |
| - div_tensor_2 = torch.ops.aten.div.Tensor(convert_element_type_default_14, fp8_scale_pos_155_primals_2); convert_element_type_default_14 = fp8_scale_pos_155_primals_2 = None | |
| - convert_element_type_default_15 = torch.ops.prims.convert_element_type.default(div_tensor_2, torch.bfloat16); div_tensor_2 = None | |
| - tanh_1 = torch.ops.aten.tanh.default(convert_element_type_default_15); convert_element_type_default_15 = None | |
| + mul_1297 = torch.ops.aten.mul.Tensor(div_141, sum_202); div_141 = sum_202 = None | |
| + sub_239 = torch.ops.aten.sub.Tensor(mul_1294, mul_1297); mul_1294 = mul_1297 = None | |
| + mul_1298 = torch.ops.aten.mul.Tensor(sub_239, rsqrt_41); sub_239 = rsqrt_41 = None | |
| + mul_1299 = torch.ops.aten.mul.Tensor(add_523, mul_112); add_523 = mul_112 = None | |
| + sum_203 = torch.ops.aten.sum.dim_IntList(mul_1299, [0, 1]); mul_1299 = None | |
| + mul_1300 = torch.ops.aten.mul.Tensor(add_525, primals_297); primals_297 = None | |
| + convert_element_type_default_11 = torch.ops.prims.convert_element_type.default(fp8_quant_pos_155_primals_2, torch.bfloat16); fp8_quant_pos_155_primals_2 = None | |
| + div_tensor_2 = torch.ops.aten.div.Tensor(convert_element_type_default_11, fp8_scale_pos_155_primals_2); convert_element_type_default_11 = fp8_scale_pos_155_primals_2 = None | |
| + convert_element_type_default_12 = torch.ops.prims.convert_element_type.default(div_tensor_2, torch.bfloat16); div_tensor_2 = None | |
| + tanh_1 = torch.ops.aten.tanh.default(convert_element_type_default_12); convert_element_type_default_12 = None | |
| cat_2 = torch.ops.aten.cat.default([tanh_1, primals_103, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30, primals_31, primals_32, primals_33, primals_34, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_46, primals_47, primals_48, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, primals_58, primals_59, primals_60, primals_61, primals_62, primals_63, primals_64, primals_65, primals_66, primals_67, primals_68, primals_69, primals_70, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_78, primals_79, primals_80, primals_81, primals_82, primals_83, primals_84, primals_85, primals_86, primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_94, primals_95, primals_96, primals_97, primals_98, primals_99, primals_100, primals_101], 1); primals_103 = primals_18 = primals_19 = primals_20 = primals_21 = primals_22 = primals_23 = primals_24 = primals_25 = primals_26 = primals_27 = primals_28 = primals_29 = primals_30 = primals_31 = primals_32 = primals_33 = primals_34 = primals_35 = primals_36 = primals_37 = primals_38 = primals_39 = primals_40 = primals_41 = primals_42 = primals_43 = primals_44 = primals_45 = primals_46 = primals_47 = primals_48 = primals_49 = primals_50 = primals_51 = primals_52 = primals_53 = primals_54 = primals_55 = primals_56 = primals_57 = primals_58 = primals_59 = primals_60 = primals_61 = primals_62 = primals_63 = primals_64 = primals_65 = primals_66 = primals_67 = primals_68 = primals_69 = primals_70 = primals_71 = primals_72 = primals_73 = primals_74 = primals_75 = primals_76 = primals_77 = primals_78 = primals_79 = primals_80 = primals_81 = primals_82 = primals_83 = primals_84 = primals_85 = primals_86 = primals_87 = 
primals_88 = primals_89 = primals_90 = primals_91 = primals_92 = primals_93 = primals_94 = primals_95 = primals_96 = primals_97 = primals_98 = primals_99 = primals_100 = primals_101 = None | |
| view_1 = torch.ops.aten.view.default(cat_2, [4096, -1, 112]); cat_2 = None | |
| pow_20 = torch.ops.aten.pow.Tensor_Scalar(view_1, 2) | |
| @@ -8376,126 +5706,127 @@ | |
| add_101 = torch.ops.aten.add.Scalar(mean_19, 1.1920928955078125e-07); mean_19 = None | |
| rsqrt_40 = torch.ops.aten.rsqrt.default(add_101); add_101 = None | |
| mul_110 = torch.ops.aten.mul.Tensor(view_1, rsqrt_40); view_1 = None | |
| - mul_1212 = torch.ops.aten.mul.Tensor(mul_110, mul_1210) | |
| - sum_204 = torch.ops.aten.sum.dim_IntList(mul_1212, [2], True); mul_1212 = None | |
| + mul_1302 = torch.ops.aten.mul.Tensor(mul_110, mul_1300) | |
| + sum_204 = torch.ops.aten.sum.dim_IntList(mul_1302, [2], True); mul_1302 = None | |
| div_142 = torch.ops.aten.div.Tensor(mul_110, 112) | |
| - mul_1213 = torch.ops.aten.mul.Tensor(div_142, sum_204); div_142 = sum_204 = None | |
| - sub_285 = torch.ops.aten.sub.Tensor(mul_1210, mul_1213); mul_1210 = mul_1213 = None | |
| - mul_1214 = torch.ops.aten.mul.Tensor(sub_285, rsqrt_40); sub_285 = rsqrt_40 = None | |
| - mul_1215 = torch.ops.aten.mul.Tensor(add_480, mul_110); add_480 = mul_110 = None | |
| - sum_205 = torch.ops.aten.sum.dim_IntList(mul_1215, [0, 1]); mul_1215 = None | |
| - slice_113 = torch.ops.aten.slice.Tensor(mul_1196, 1, 0, 2048) | |
| - slice_114 = torch.ops.aten.slice.Tensor(mul_1196, 1, 2048, 4096); mul_1196 = None | |
| - convert_element_type_1447 = torch.ops.prims.convert_element_type.default(slice_114, torch.bfloat16); slice_114 = None | |
| - slice_115 = torch.ops.aten.slice.Tensor(mul_1202, 1, 0, 2048) | |
| - slice_116 = torch.ops.aten.slice.Tensor(mul_1202, 1, 2048, 4096); mul_1202 = None | |
| - convert_element_type_1448 = torch.ops.prims.convert_element_type.default(slice_116, torch.bfloat16); slice_116 = None | |
| + mul_1303 = torch.ops.aten.mul.Tensor(div_142, sum_204); div_142 = sum_204 = None | |
| + sub_240 = torch.ops.aten.sub.Tensor(mul_1300, mul_1303); mul_1300 = mul_1303 = None | |
| + mul_1304 = torch.ops.aten.mul.Tensor(sub_240, rsqrt_40); sub_240 = rsqrt_40 = None | |
| + mul_1305 = torch.ops.aten.mul.Tensor(add_525, mul_110); add_525 = mul_110 = None | |
| + sum_205 = torch.ops.aten.sum.dim_IntList(mul_1305, [0, 1]); mul_1305 = None | |
| + slice_113 = torch.ops.aten.slice.Tensor(mul_1286, 1, 0, 2048) | |
| + slice_114 = torch.ops.aten.slice.Tensor(mul_1286, 1, 2048, 4096); mul_1286 = None | |
| + convert_element_type_1568 = torch.ops.prims.convert_element_type.default(slice_114, torch.bfloat16); slice_114 = None | |
| + slice_115 = torch.ops.aten.slice.Tensor(mul_1292, 1, 0, 2048) | |
| + slice_116 = torch.ops.aten.slice.Tensor(mul_1292, 1, 2048, 4096); mul_1292 = None | |
| + convert_element_type_1569 = torch.ops.prims.convert_element_type.default(slice_116, torch.bfloat16); slice_116 = None | |
| permute_914 = torch.ops.aten.permute.default(permute_146, [1, 0]); permute_146 = None | |
| - mm_351 = torch.ops.aten.mm.default(convert_element_type_1448, permute_914); permute_914 = None | |
| - permute_915 = torch.ops.aten.permute.default(convert_element_type_1448, [1, 0]) | |
| + mm_351 = torch.ops.aten.mm.default(convert_element_type_1569, permute_914); permute_914 = None | |
| + permute_915 = torch.ops.aten.permute.default(convert_element_type_1569, [1, 0]) | |
| mm_352 = torch.ops.aten.mm.default(permute_915, clone); permute_915 = clone = None | |
| - sum_206 = torch.ops.aten.sum.dim_IntList(convert_element_type_1448, [0], True); convert_element_type_1448 = None | |
| - view_565 = torch.ops.aten.view.default(sum_206, [2048]); sum_206 = None | |
| - convert_element_type_1453 = torch.ops.prims.convert_element_type.default(view_565, torch.float32); view_565 = None | |
| - convert_element_type_1454 = torch.ops.prims.convert_element_type.default(mm_352, torch.float32); mm_352 = None | |
| + sum_206 = torch.ops.aten.sum.dim_IntList(convert_element_type_1569, [0], True); convert_element_type_1569 = None | |
| + view_610 = torch.ops.aten.view.default(sum_206, [2048]); sum_206 = None | |
| + convert_element_type_1574 = torch.ops.prims.convert_element_type.default(view_610, torch.float32); view_610 = None | |
| + convert_element_type_1575 = torch.ops.prims.convert_element_type.default(mm_352, torch.float32); mm_352 = None | |
| permute_918 = torch.ops.aten.permute.default(permute_145, [1, 0]); permute_145 = None | |
| - mm_353 = torch.ops.aten.mm.default(convert_element_type_1447, permute_918); permute_918 = None | |
| - permute_919 = torch.ops.aten.permute.default(convert_element_type_1447, [1, 0]) | |
| + mm_353 = torch.ops.aten.mm.default(convert_element_type_1568, permute_918); permute_918 = None | |
| + permute_919 = torch.ops.aten.permute.default(convert_element_type_1568, [1, 0]) | |
| mm_354 = torch.ops.aten.mm.default(permute_919, cat); permute_919 = cat = None | |
| - sum_207 = torch.ops.aten.sum.dim_IntList(convert_element_type_1447, [0], True); convert_element_type_1447 = None | |
| - view_566 = torch.ops.aten.view.default(sum_207, [2048]); sum_207 = None | |
| - convert_element_type_1459 = torch.ops.prims.convert_element_type.default(view_566, torch.float32); view_566 = None | |
| - convert_element_type_1460 = torch.ops.prims.convert_element_type.default(mm_354, torch.float32); mm_354 = None | |
| - slice_117 = torch.ops.aten.slice.Tensor(mul_1208, 1, 0, 340) | |
| - slice_118 = torch.ops.aten.slice.Tensor(mul_1208, 1, 340, 356) | |
| - slice_119 = torch.ops.aten.slice.Tensor(mul_1208, 1, 356, 372) | |
| - slice_120 = torch.ops.aten.slice.Tensor(mul_1208, 1, 372, 388) | |
| - slice_121 = torch.ops.aten.slice.Tensor(mul_1208, 1, 388, 404) | |
| - slice_122 = torch.ops.aten.slice.Tensor(mul_1208, 1, 404, 420) | |
| - slice_123 = torch.ops.aten.slice.Tensor(mul_1208, 1, 420, 436) | |
| - slice_124 = torch.ops.aten.slice.Tensor(mul_1208, 1, 436, 452); mul_1208 = None | |
| - convert_element_type_1461 = torch.ops.prims.convert_element_type.default(slice_118, torch.bfloat16); slice_118 = None | |
| - convert_element_type_1462 = torch.ops.prims.convert_element_type.default(slice_119, torch.bfloat16); slice_119 = None | |
| - convert_element_type_1463 = torch.ops.prims.convert_element_type.default(slice_120, torch.bfloat16); slice_120 = None | |
| - convert_element_type_1464 = torch.ops.prims.convert_element_type.default(slice_121, torch.bfloat16); slice_121 = None | |
| - convert_element_type_1465 = torch.ops.prims.convert_element_type.default(slice_122, torch.bfloat16); slice_122 = None | |
| - convert_element_type_1466 = torch.ops.prims.convert_element_type.default(slice_123, torch.bfloat16); slice_123 = None | |
| - convert_element_type_1467 = torch.ops.prims.convert_element_type.default(slice_124, torch.bfloat16); slice_124 = None | |
| - clone_112 = torch.ops.aten.clone.default(convert_element_type_1467, memory_format = torch.contiguous_format); convert_element_type_1467 = None | |
| - view_567 = torch.ops.aten.view.default(clone_112, [65536, 112]); clone_112 = None | |
| + sum_207 = torch.ops.aten.sum.dim_IntList(convert_element_type_1568, [0], True); convert_element_type_1568 = None | |
| + view_611 = torch.ops.aten.view.default(sum_207, [2048]); sum_207 = None | |
| + convert_element_type_1580 = torch.ops.prims.convert_element_type.default(view_611, torch.float32); view_611 = None | |
| + convert_element_type_1581 = torch.ops.prims.convert_element_type.default(mm_354, torch.float32); mm_354 = None | |
| + slice_117 = torch.ops.aten.slice.Tensor(mul_1298, 1, 0, 340) | |
| + slice_118 = torch.ops.aten.slice.Tensor(mul_1298, 1, 340, 356) | |
| + slice_119 = torch.ops.aten.slice.Tensor(mul_1298, 1, 356, 372) | |
| + slice_120 = torch.ops.aten.slice.Tensor(mul_1298, 1, 372, 388) | |
| + slice_121 = torch.ops.aten.slice.Tensor(mul_1298, 1, 388, 404) | |
| + slice_122 = torch.ops.aten.slice.Tensor(mul_1298, 1, 404, 420) | |
| + slice_123 = torch.ops.aten.slice.Tensor(mul_1298, 1, 420, 436) | |
| + slice_124 = torch.ops.aten.slice.Tensor(mul_1298, 1, 436, 452); mul_1298 = None | |
| + convert_element_type_1582 = torch.ops.prims.convert_element_type.default(slice_118, torch.bfloat16); slice_118 = None | |
| + convert_element_type_1583 = torch.ops.prims.convert_element_type.default(slice_119, torch.bfloat16); slice_119 = None | |
| + convert_element_type_1584 = torch.ops.prims.convert_element_type.default(slice_120, torch.bfloat16); slice_120 = None | |
| + convert_element_type_1585 = torch.ops.prims.convert_element_type.default(slice_121, torch.bfloat16); slice_121 = None | |
| + convert_element_type_1586 = torch.ops.prims.convert_element_type.default(slice_122, torch.bfloat16); slice_122 = None | |
| + convert_element_type_1587 = torch.ops.prims.convert_element_type.default(slice_123, torch.bfloat16); slice_123 = None | |
| + convert_element_type_1588 = torch.ops.prims.convert_element_type.default(slice_124, torch.bfloat16); slice_124 = None | |
| + clone_112 = torch.ops.aten.clone.default(convert_element_type_1588, memory_format = torch.contiguous_format); convert_element_type_1588 = None | |
| + view_612 = torch.ops.aten.view.default(clone_112, [65536, 112]); clone_112 = None | |
| permute_922 = torch.ops.aten.permute.default(permute_144, [1, 0]); permute_144 = None | |
| - mm_355 = torch.ops.aten.mm.default(view_567, permute_922); permute_922 = None | |
| - permute_923 = torch.ops.aten.permute.default(view_567, [1, 0]) | |
| + mm_355 = torch.ops.aten.mm.default(view_612, permute_922); permute_922 = None | |
| + permute_923 = torch.ops.aten.permute.default(view_612, [1, 0]) | |
| mm_356 = torch.ops.aten.mm.default(permute_923, view_179); permute_923 = view_179 = None | |
| - sum_208 = torch.ops.aten.sum.dim_IntList(view_567, [0], True); view_567 = None | |
| - view_568 = torch.ops.aten.view.default(sum_208, [112]); sum_208 = None | |
| - view_569 = torch.ops.aten.view.default(mm_355, [4096, 16, 96]); mm_355 = None | |
| - convert_element_type_1472 = torch.ops.prims.convert_element_type.default(view_568, torch.float32); view_568 = None | |
| - convert_element_type_1473 = torch.ops.prims.convert_element_type.default(mm_356, torch.float32); mm_356 = None | |
| - view_571 = torch.ops.aten.view.default(view_569, [4096, 1536]); view_569 = None | |
| - permute_926 = torch.ops.aten.permute.default(view_571, [1, 0]) | |
| - convert_element_type_381 = torch.ops.prims.convert_element_type.default(mm_45, torch.float32); mm_45 = None | |
| - mul_107 = torch.ops.aten.mul.Tensor(convert_element_type_381, 0.5) | |
| - mul_108 = torch.ops.aten.mul.Tensor(convert_element_type_381, 0.7071067811865476) | |
| + sum_208 = torch.ops.aten.sum.dim_IntList(view_612, [0], True); view_612 = None | |
| + view_613 = torch.ops.aten.view.default(sum_208, [112]); sum_208 = None | |
| + view_614 = torch.ops.aten.view.default(mm_355, [4096, 16, 96]); mm_355 = None | |
| + convert_element_type_1593 = torch.ops.prims.convert_element_type.default(view_613, torch.float32); view_613 = None | |
| + convert_element_type_1594 = torch.ops.prims.convert_element_type.default(mm_356, torch.float32); mm_356 = None | |
| + view_616 = torch.ops.aten.view.default(view_614, [4096, 1536]); view_614 = None | |
| + permute_926 = torch.ops.aten.permute.default(view_616, [1, 0]) | |
| + convert_element_type_379 = torch.ops.prims.convert_element_type.default(primals_289, torch.bfloat16); primals_289 = None | |
| + permute_142 = torch.ops.aten.permute.default(convert_element_type_379, [1, 0]); convert_element_type_379 = None | |
| + mm_45 = torch.ops.aten.mm.default(view_176, permute_142) | |
| + convert_element_type_382 = torch.ops.prims.convert_element_type.default(mm_45, torch.float32); mm_45 = None | |
| + mul_107 = torch.ops.aten.mul.Tensor(convert_element_type_382, 0.5) | |
| + mul_108 = torch.ops.aten.mul.Tensor(convert_element_type_382, 0.7071067811865476) | |
| erf_9 = torch.ops.aten.erf.default(mul_108); mul_108 = None | |
| add_100 = torch.ops.aten.add.Tensor(erf_9, 1); erf_9 = None | |
| mul_109 = torch.ops.aten.mul.Tensor(mul_107, add_100); mul_107 = None | |
| - convert_element_type_382 = torch.ops.prims.convert_element_type.default(mul_109, torch.bfloat16); mul_109 = None | |
| - mm_357 = torch.ops.aten.mm.default(permute_926, convert_element_type_382); permute_926 = convert_element_type_382 = None | |
| - convert_element_type_383 = torch.ops.prims.convert_element_type.default(primals_290, torch.bfloat16); primals_290 = None | |
| - permute_143 = torch.ops.aten.permute.default(convert_element_type_383, [1, 0]); convert_element_type_383 = None | |
| + convert_element_type_383 = torch.ops.prims.convert_element_type.default(mul_109, torch.bfloat16); mul_109 = None | |
| + mm_357 = torch.ops.aten.mm.default(permute_926, convert_element_type_383); permute_926 = convert_element_type_383 = None | |
| + convert_element_type_384 = torch.ops.prims.convert_element_type.default(primals_290, torch.bfloat16); primals_290 = None | |
| + permute_143 = torch.ops.aten.permute.default(convert_element_type_384, [1, 0]); convert_element_type_384 = None | |
| permute_928 = torch.ops.aten.permute.default(permute_143, [1, 0]); permute_143 = None | |
| - mm_358 = torch.ops.aten.mm.default(view_571, permute_928); view_571 = permute_928 = None | |
| - convert_element_type_1478 = torch.ops.prims.convert_element_type.default(mm_357, torch.float32); mm_357 = None | |
| - convert_element_type_1479 = torch.ops.prims.convert_element_type.default(mm_358, torch.float32); mm_358 = None | |
| - mul_1217 = torch.ops.aten.mul.Tensor(add_100, 0.5); add_100 = None | |
| - mul_1218 = torch.ops.aten.mul.Tensor(convert_element_type_381, convert_element_type_381) | |
| - mul_1219 = torch.ops.aten.mul.Tensor(mul_1218, -0.5); mul_1218 = None | |
| - exp_33 = torch.ops.aten.exp.default(mul_1219); mul_1219 = None | |
| - mul_1220 = torch.ops.aten.mul.Tensor(exp_33, 0.3989422804014327); exp_33 = None | |
| - mul_1221 = torch.ops.aten.mul.Tensor(convert_element_type_381, mul_1220); convert_element_type_381 = mul_1220 = None | |
| - add_482 = torch.ops.aten.add.Tensor(mul_1217, mul_1221); mul_1217 = mul_1221 = None | |
| - mul_1222 = torch.ops.aten.mul.Tensor(convert_element_type_1479, add_482); convert_element_type_1479 = add_482 = None | |
| - convert_element_type_1481 = torch.ops.prims.convert_element_type.default(mul_1222, torch.bfloat16); mul_1222 = None | |
| - permute_930 = torch.ops.aten.permute.default(convert_element_type_1481, [1, 0]) | |
| + mm_358 = torch.ops.aten.mm.default(view_616, permute_928); view_616 = permute_928 = None | |
| + convert_element_type_1599 = torch.ops.prims.convert_element_type.default(mm_357, torch.float32); mm_357 = None | |
| + convert_element_type_1600 = torch.ops.prims.convert_element_type.default(mm_358, torch.float32); mm_358 = None | |
| + mul_1307 = torch.ops.aten.mul.Tensor(add_100, 0.5); add_100 = None | |
| + mul_1308 = torch.ops.aten.mul.Tensor(convert_element_type_382, convert_element_type_382) | |
| + mul_1309 = torch.ops.aten.mul.Tensor(mul_1308, -0.5); mul_1308 = None | |
| + exp_33 = torch.ops.aten.exp.default(mul_1309); mul_1309 = None | |
| + mul_1310 = torch.ops.aten.mul.Tensor(exp_33, 0.3989422804014327); exp_33 = None | |
| + mul_1311 = torch.ops.aten.mul.Tensor(convert_element_type_382, mul_1310); convert_element_type_382 = mul_1310 = None | |
| + add_527 = torch.ops.aten.add.Tensor(mul_1307, mul_1311); mul_1307 = mul_1311 = None | |
| + mul_1312 = torch.ops.aten.mul.Tensor(convert_element_type_1600, add_527); convert_element_type_1600 = add_527 = None | |
| + convert_element_type_1602 = torch.ops.prims.convert_element_type.default(mul_1312, torch.bfloat16); mul_1312 = None | |
| + permute_930 = torch.ops.aten.permute.default(convert_element_type_1602, [1, 0]) | |
| mm_359 = torch.ops.aten.mm.default(permute_930, view_176); permute_930 = view_176 = None | |
| - convert_element_type_378 = torch.ops.prims.convert_element_type.default(primals_289, torch.bfloat16); primals_289 = None | |
| - permute_142 = torch.ops.aten.permute.default(convert_element_type_378, [1, 0]); convert_element_type_378 = None | |
| permute_932 = torch.ops.aten.permute.default(permute_142, [1, 0]); permute_142 = None | |
| - mm_360 = torch.ops.aten.mm.default(convert_element_type_1481, permute_932); convert_element_type_1481 = permute_932 = None | |
| - convert_element_type_1486 = torch.ops.prims.convert_element_type.default(mm_359, torch.float32); mm_359 = None | |
| - view_572 = torch.ops.aten.view.default(mm_360, [4096, 10, 96]); mm_360 = None | |
| - permute_934 = torch.ops.aten.permute.default(view_572, [0, 2, 1]); view_572 = None | |
| + mm_360 = torch.ops.aten.mm.default(convert_element_type_1602, permute_932); convert_element_type_1602 = permute_932 = None | |
| + convert_element_type_1607 = torch.ops.prims.convert_element_type.default(mm_359, torch.float32); mm_359 = None | |
| + view_617 = torch.ops.aten.view.default(mm_360, [4096, 10, 96]); mm_360 = None | |
| + permute_934 = torch.ops.aten.permute.default(view_617, [0, 2, 1]); view_617 = None | |
| clone_113 = torch.ops.aten.clone.default(permute_934, memory_format = torch.contiguous_format); permute_934 = None | |
| - view_573 = torch.ops.aten.view.default(clone_113, [393216, 10]); clone_113 = None | |
| - permute_935 = torch.ops.aten.permute.default(view_573, [1, 0]) | |
| - convert_element_type_default_30 = torch.ops.prims.convert_element_type.default(fp8_quant_pos_699_mul_94, torch.float32); fp8_quant_pos_699_mul_94 = None | |
| - div_tensor_10 = torch.ops.aten.div.Tensor(convert_element_type_default_30, fp8_scale_pos_699_mul_94); convert_element_type_default_30 = fp8_scale_pos_699_mul_94 = None | |
| - convert_element_type_default_31 = torch.ops.prims.convert_element_type.default(div_tensor_10, torch.float32); div_tensor_10 = None | |
| + view_618 = torch.ops.aten.view.default(clone_113, [393216, 10]); clone_113 = None | |
| + permute_935 = torch.ops.aten.permute.default(view_618, [1, 0]) | |
| + convert_element_type_default_27 = torch.ops.prims.convert_element_type.default(fp8_quant_pos_688_mul_94, torch.float32); fp8_quant_pos_688_mul_94 = None | |
| + div_tensor_10 = torch.ops.aten.div.Tensor(convert_element_type_default_27, fp8_scale_pos_688_mul_94); convert_element_type_default_27 = fp8_scale_pos_688_mul_94 = None | |
| + convert_element_type_default_28 = torch.ops.prims.convert_element_type.default(div_tensor_10, torch.float32); div_tensor_10 = None | |
| view_164 = torch.ops.aten.view.default(mm_42, [-1, 16, 96]); mm_42 = None | |
| - convert_element_type_361 = torch.ops.prims.convert_element_type.default(view_164, torch.float32) | |
| - var_mean_19 = torch.ops.aten.var_mean.correction(convert_element_type_361, [2], correction = 0, keepdim = True) | |
| + convert_element_type_362 = torch.ops.prims.convert_element_type.default(view_164, torch.float32) | |
| + var_mean_19 = torch.ops.aten.var_mean.correction(convert_element_type_362, [2], correction = 0, keepdim = True) | |
| getitem_176 = var_mean_19[0] | |
| getitem_177 = var_mean_19[1]; var_mean_19 = None | |
| add_94 = torch.ops.aten.add.Tensor(getitem_176, 1e-05); getitem_176 = None | |
| rsqrt_37 = torch.ops.aten.rsqrt.default(add_94); add_94 = None | |
| - sub_30 = torch.ops.aten.sub.Tensor(convert_element_type_361, getitem_177); convert_element_type_361 = getitem_177 = None | |
| + sub_30 = torch.ops.aten.sub.Tensor(convert_element_type_362, getitem_177); convert_element_type_362 = getitem_177 = None | |
| mul_101 = torch.ops.aten.mul.Tensor(sub_30, rsqrt_37); sub_30 = None | |
| mul_102 = torch.ops.aten.mul.Tensor(mul_101, primals_280) | |
| add_95 = torch.ops.aten.add.Tensor(mul_102, primals_281); mul_102 = primals_281 = None | |
| - mul_104 = torch.ops.aten.mul.Tensor(convert_element_type_default_31, primals_282) | |
| + mul_104 = torch.ops.aten.mul.Tensor(convert_element_type_default_28, primals_282) | |
| add_97 = torch.ops.aten.add.Tensor(mul_104, primals_283); mul_104 = primals_283 = None | |
| - convert_element_type_362 = torch.ops.prims.convert_element_type.default(add_95, torch.bfloat16); add_95 = None | |
| - convert_element_type_363 = torch.ops.prims.convert_element_type.default(primals_284, torch.bfloat16); primals_284 = None | |
| - convert_element_type_364 = torch.ops.prims.convert_element_type.default(primals_285, torch.bfloat16); primals_285 = None | |
| - view_166 = torch.ops.aten.view.default(convert_element_type_362, [65536, 96]); convert_element_type_362 = None | |
| - permute_133 = torch.ops.aten.permute.default(convert_element_type_363, [1, 0]); convert_element_type_363 = None | |
| - addmm_26 = torch.ops.aten.addmm.default(convert_element_type_364, view_166, permute_133); convert_element_type_364 = None | |
| + convert_element_type_363 = torch.ops.prims.convert_element_type.default(add_95, torch.bfloat16); add_95 = None | |
| + convert_element_type_364 = torch.ops.prims.convert_element_type.default(primals_284, torch.bfloat16); primals_284 = None | |
| + convert_element_type_365 = torch.ops.prims.convert_element_type.default(primals_285, torch.bfloat16); primals_285 = None | |
| + view_166 = torch.ops.aten.view.default(convert_element_type_363, [65536, 96]); convert_element_type_363 = None | |
| + permute_133 = torch.ops.aten.permute.default(convert_element_type_364, [1, 0]); convert_element_type_364 = None | |
| + addmm_26 = torch.ops.aten.addmm.default(convert_element_type_365, view_166, permute_133); convert_element_type_365 = None | |
| view_167 = torch.ops.aten.view.default(addmm_26, [4096, 16, 96]); addmm_26 = None | |
| - convert_element_type_368 = torch.ops.prims.convert_element_type.default(add_97, torch.bfloat16) | |
| - convert_element_type_369 = torch.ops.prims.convert_element_type.default(primals_286, torch.bfloat16); primals_286 = None | |
| - permute_134 = torch.ops.aten.permute.default(convert_element_type_369, [1, 0]); convert_element_type_369 = None | |
| - view_168 = torch.ops.aten.view.default(convert_element_type_368, [819200, 96]); convert_element_type_368 = None | |
| + convert_element_type_369 = torch.ops.prims.convert_element_type.default(add_97, torch.bfloat16) | |
| + convert_element_type_370 = torch.ops.prims.convert_element_type.default(primals_286, torch.bfloat16); primals_286 = None | |
| + permute_134 = torch.ops.aten.permute.default(convert_element_type_370, [1, 0]); convert_element_type_370 = None | |
| + view_168 = torch.ops.aten.view.default(convert_element_type_369, [819200, 96]); convert_element_type_369 = None | |
| mm_43 = torch.ops.aten.mm.default(view_168, permute_134) | |
| view_169 = torch.ops.aten.view.default(mm_43, [4096, 200, 96]); mm_43 = None | |
| view_170 = torch.ops.aten.view.default(view_167, [4096, 16, 1, 96]); view_167 = None | |
| @@ -8504,8 +5835,8 @@ | |
| permute_136 = torch.ops.aten.permute.default(view_171, [0, 2, 1, 3]); view_171 = None | |
| view_172 = torch.ops.aten.view.default(add_97, [4096, 200, 1, 96]); add_97 = None | |
| permute_137 = torch.ops.aten.permute.default(view_172, [0, 2, 1, 3]); view_172 = None | |
| - convert_element_type_372 = torch.ops.prims.convert_element_type.default(permute_137, torch.bfloat16); permute_137 = None | |
| - _scaled_dot_product_flash_attention_13 = torch.ops.aten._scaled_dot_product_flash_attention.default(permute_135, permute_136, convert_element_type_372, scale = 0.10206207261596577) | |
| + convert_element_type_373 = torch.ops.prims.convert_element_type.default(permute_137, torch.bfloat16); permute_137 = None | |
| + _scaled_dot_product_flash_attention_13 = torch.ops.aten._scaled_dot_product_flash_attention.default(permute_135, permute_136, convert_element_type_373, scale = 0.10206207261596577) | |
| getitem_180 = _scaled_dot_product_flash_attention_13[0] | |
| getitem_181 = _scaled_dot_product_flash_attention_13[1] | |
| getitem_186 = _scaled_dot_product_flash_attention_13[6] | |
| @@ -8513,123 +5844,106 @@ | |
| permute_138 = torch.ops.aten.permute.default(getitem_180, [0, 2, 1, 3]) | |
| view_173 = torch.ops.aten.view.default(permute_138, [4096, 16, 96]); permute_138 = None | |
| add_98 = torch.ops.aten.add.Tensor(view_173, view_164); view_173 = view_164 = None | |
| - convert_element_type_373 = torch.ops.prims.convert_element_type.default(add_98, torch.float32); add_98 = None | |
| - pow_19 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_373, 2) | |
| + convert_element_type_374 = torch.ops.prims.convert_element_type.default(add_98, torch.float32); add_98 = None | |
| + pow_19 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_374, 2) | |
| mean_18 = torch.ops.aten.mean.dim(pow_19, [2], True); pow_19 = None | |
| add_99 = torch.ops.aten.add.Scalar(mean_18, 1.1920928955078125e-07); mean_18 = None | |
| rsqrt_39 = torch.ops.aten.rsqrt.default(add_99); add_99 = None | |
| - mul_105 = torch.ops.aten.mul.Tensor(convert_element_type_373, rsqrt_39); convert_element_type_373 = None | |
| + mul_105 = torch.ops.aten.mul.Tensor(convert_element_type_374, rsqrt_39) | |
| mul_106 = torch.ops.aten.mul.Tensor(mul_105, primals_287) | |
| convert_element_type_375 = torch.ops.prims.convert_element_type.default(mul_106, torch.bfloat16); mul_106 = None | |
| permute_139 = torch.ops.aten.permute.default(convert_element_type_375, [0, 2, 1]); convert_element_type_375 = None | |
| clone_45 = torch.ops.aten.clone.default(permute_139, memory_format = torch.contiguous_format); permute_139 = None | |
| view_174 = torch.ops.aten.view.default(clone_45, [393216, 16]); clone_45 = None | |
| mm_361 = torch.ops.aten.mm.default(permute_935, view_174); permute_935 = view_174 = None | |
| - convert_element_type_374 = torch.ops.prims.convert_element_type.default(primals_288, torch.bfloat16); primals_288 = None | |
| - permute_140 = torch.ops.aten.permute.default(convert_element_type_374, [1, 0]); convert_element_type_374 = None | |
| + convert_element_type_376 = torch.ops.prims.convert_element_type.default(primals_288, torch.bfloat16); primals_288 = None | |
| + permute_140 = torch.ops.aten.permute.default(convert_element_type_376, [1, 0]); convert_element_type_376 = None | |
| permute_937 = torch.ops.aten.permute.default(permute_140, [1, 0]); permute_140 = None | |
| - mm_362 = torch.ops.aten.mm.default(view_573, permute_937); view_573 = permute_937 = None | |
| - view_574 = torch.ops.aten.view.default(mm_362, [4096, 96, 16]); mm_362 = None | |
| - permute_939 = torch.ops.aten.permute.default(view_574, [0, 2, 1]); view_574 = None | |
| - convert_element_type_1491 = torch.ops.prims.convert_element_type.default(permute_939, torch.float32); permute_939 = None | |
| - convert_element_type_1492 = torch.ops.prims.convert_element_type.default(mm_361, torch.float32); mm_361 = None | |
| - mul_1223 = torch.ops.aten.mul.Tensor(convert_element_type_1491, primals_287); primals_287 = None | |
| - mul_1225 = torch.ops.aten.mul.Tensor(mul_105, mul_1223) | |
| - sum_209 = torch.ops.aten.sum.dim_IntList(mul_1225, [2], True); mul_1225 = None | |
| - div_143 = torch.ops.aten.div.Tensor(mul_105, 96) | |
| - mul_1226 = torch.ops.aten.mul.Tensor(div_143, sum_209); div_143 = sum_209 = None | |
| - sub_286 = torch.ops.aten.sub.Tensor(mul_1223, mul_1226); mul_1223 = mul_1226 = None | |
| - mul_1227 = torch.ops.aten.mul.Tensor(sub_286, rsqrt_39); sub_286 = rsqrt_39 = None | |
| - mul_1228 = torch.ops.aten.mul.Tensor(convert_element_type_1491, mul_105); convert_element_type_1491 = mul_105 = None | |
| - sum_210 = torch.ops.aten.sum.dim_IntList(mul_1228, [0, 1]); mul_1228 = None | |
| - convert_element_type_1493 = torch.ops.prims.convert_element_type.default(mul_1227, torch.bfloat16); mul_1227 = None | |
| - view_575 = torch.ops.aten.view.default(convert_element_type_1493, [4096, 16, 1, 96]) | |
| - permute_940 = torch.ops.aten.permute.default(view_575, [0, 2, 1, 3]); view_575 = None | |
| - _scaled_dot_product_flash_attention_backward = torch.ops.aten._scaled_dot_product_flash_attention_backward.default(permute_940, permute_135, permute_136, convert_element_type_372, getitem_180, getitem_181, None, None, 16, 200, 0.0, False, getitem_186, getitem_187, scale = 0.10206207261596577); permute_940 = permute_135 = permute_136 = convert_element_type_372 = getitem_180 = getitem_181 = getitem_186 = getitem_187 = None | |
| + mm_362 = torch.ops.aten.mm.default(view_618, permute_937); view_618 = permute_937 = None | |
| + view_619 = torch.ops.aten.view.default(mm_362, [4096, 96, 16]); mm_362 = None | |
| + permute_939 = torch.ops.aten.permute.default(view_619, [0, 2, 1]); view_619 = None | |
| + convert_element_type_1612 = torch.ops.prims.convert_element_type.default(mm_361, torch.float32); mm_361 = None | |
| + convert_element_type_1613 = torch.ops.prims.convert_element_type.default(permute_939, torch.float32); permute_939 = None | |
| + mul_1313 = torch.ops.aten.mul.Tensor(convert_element_type_1613, mul_105); mul_105 = None | |
| + mul_1314 = torch.ops.aten.mul.Tensor(convert_element_type_1613, primals_287); convert_element_type_1613 = primals_287 = None | |
| + sum_209 = torch.ops.aten.sum.dim_IntList(mul_1313, [0, 1], True); mul_1313 = None | |
| + view_620 = torch.ops.aten.view.default(sum_209, [96]); sum_209 = None | |
| + mul_1315 = torch.ops.aten.mul.Tensor(mul_1314, convert_element_type_374) | |
| + mul_1316 = torch.ops.aten.mul.Tensor(mul_1314, rsqrt_39); mul_1314 = None | |
| + sum_210 = torch.ops.aten.sum.dim_IntList(mul_1315, [2], True); mul_1315 = None | |
| + mul_1317 = torch.ops.aten.mul.Scalar(sum_210, -0.5); sum_210 = None | |
| + pow_174 = torch.ops.aten.pow.Tensor_Scalar(rsqrt_39, 3); rsqrt_39 = None | |
| + mul_1318 = torch.ops.aten.mul.Tensor(mul_1317, pow_174); mul_1317 = pow_174 = None | |
| + expand_108 = torch.ops.aten.expand.default(mul_1318, [4096, 16, 96]); mul_1318 = None | |
| + div_143 = torch.ops.aten.div.Scalar(expand_108, 96); expand_108 = None | |
| + pow_175 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_374, 1.0); convert_element_type_374 = None | |
| + mul_1319 = torch.ops.aten.mul.Scalar(pow_175, 2.0); pow_175 = None | |
| + mul_1320 = torch.ops.aten.mul.Tensor(div_143, mul_1319); div_143 = mul_1319 = None | |
| + add_528 = torch.ops.aten.add.Tensor(mul_1316, mul_1320); mul_1316 = mul_1320 = None | |
| + convert_element_type_1614 = torch.ops.prims.convert_element_type.default(add_528, torch.bfloat16); add_528 = None | |
| + view_621 = torch.ops.aten.view.default(convert_element_type_1614, [4096, 16, 1, 96]) | |
| + permute_940 = torch.ops.aten.permute.default(view_621, [0, 2, 1, 3]); view_621 = None | |
| + _scaled_dot_product_flash_attention_backward = torch.ops.aten._scaled_dot_product_flash_attention_backward.default(permute_940, permute_135, permute_136, convert_element_type_373, getitem_180, getitem_181, None, None, 16, 200, 0.0, False, getitem_186, getitem_187, scale = 0.10206207261596577); permute_940 = permute_135 = permute_136 = convert_element_type_373 = getitem_180 = getitem_181 = getitem_186 = getitem_187 = None | |
| getitem_189 = _scaled_dot_product_flash_attention_backward[0] | |
| getitem_190 = _scaled_dot_product_flash_attention_backward[1] | |
| getitem_191 = _scaled_dot_product_flash_attention_backward[2]; _scaled_dot_product_flash_attention_backward = None | |
| - convert_element_type_1494 = torch.ops.prims.convert_element_type.default(getitem_191, torch.float32); getitem_191 = None | |
| - permute_941 = torch.ops.aten.permute.default(convert_element_type_1494, [0, 2, 1, 3]); convert_element_type_1494 = None | |
| - view_576 = torch.ops.aten.view.default(permute_941, [4096, 200, 96]); permute_941 = None | |
| + convert_element_type_1615 = torch.ops.prims.convert_element_type.default(getitem_191, torch.float32); getitem_191 = None | |
| + permute_941 = torch.ops.aten.permute.default(convert_element_type_1615, [0, 2, 1, 3]); convert_element_type_1615 = None | |
| + view_622 = torch.ops.aten.view.default(permute_941, [4096, 200, 96]); permute_941 = None | |
| permute_942 = torch.ops.aten.permute.default(getitem_190, [0, 2, 1, 3]); getitem_190 = None | |
| - view_577 = torch.ops.aten.view.default(permute_942, [4096, 200, 96]); permute_942 = None | |
| + view_623 = torch.ops.aten.view.default(permute_942, [4096, 200, 96]); permute_942 = None | |
| permute_943 = torch.ops.aten.permute.default(getitem_189, [0, 2, 1, 3]); getitem_189 = None | |
| - view_578 = torch.ops.aten.view.default(permute_943, [4096, 16, 96]); permute_943 = None | |
| - view_579 = torch.ops.aten.view.default(view_577, [819200, 96]); view_577 = None | |
| - permute_944 = torch.ops.aten.permute.default(view_579, [1, 0]) | |
| + view_624 = torch.ops.aten.view.default(permute_943, [4096, 16, 96]); permute_943 = None | |
| + view_625 = torch.ops.aten.view.default(view_623, [819200, 96]); view_623 = None | |
| + permute_944 = torch.ops.aten.permute.default(view_625, [1, 0]) | |
| mm_363 = torch.ops.aten.mm.default(permute_944, view_168); permute_944 = view_168 = None | |
| permute_946 = torch.ops.aten.permute.default(permute_134, [1, 0]); permute_134 = None | |
| - mm_364 = torch.ops.aten.mm.default(view_579, permute_946); view_579 = permute_946 = None | |
| - view_580 = torch.ops.aten.view.default(mm_364, [4096, 200, 96]); mm_364 = None | |
| - convert_element_type_1499 = torch.ops.prims.convert_element_type.default(mm_363, torch.float32); mm_363 = None | |
| - co |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment