@crcrpar
Created February 25, 2026 05:50
Investigation: ILP tuning for PyTorch foreach vectorized loads — hypothesis, ncu profiling, analysis, and conclusion

ILP Tuning for _foreach Vectorized Loads: Investigation & Results

Hypothesis

PyTorch's multi_tensor_apply kernel uses kILP = 4 for vectorized memory access. For 2-byte types (fp16, bf16), this means each vectorized load is 4 * 2 = 8 bytes (64-bit LDG.64). Modern GPUs support 128-bit loads (LDG.128). By increasing ILP to 8 for 16-bit types, each thread would load 16 bytes per instruction, potentially doubling memory throughput.

The change introduces effective_ilp<T>():

template <typename T>
constexpr int64_t effective_ilp() {
  // Elements needed to fill a 16-byte (128-bit) vectorized access.
  constexpr int64_t target = static_cast<int64_t>(16 / sizeof(T));
  // Clamp to [kILP, 8]: never below the default of 4, never above 8 elements.
  return (target < kILP) ? kILP : ((target > 8) ? 8 : target);
}
dtype  sizeof  ILP (before)  ILP (after)  Load width (before)  Load width (after)
fp32   4       4             4            128-bit              128-bit (unchanged)
fp16   2       4             8            64-bit               128-bit
bf16   2       4             8            64-bit               128-bit
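The clamp behavior can be sanity-checked with a small Python model of the constexpr above (illustrative; K_ILP mirrors PyTorch's kILP = 4):

```python
K_ILP = 4  # mirrors PyTorch's kILP default

def effective_ilp(elem_size: int) -> int:
    """Python model of the C++ effective_ilp<T>(): elements per
    16-byte (128-bit) access, clamped to [kILP, 8]."""
    target = 16 // elem_size
    return max(K_ILP, min(target, 8))

for name, size in [("fp32", 4), ("fp16", 2), ("bf16", 2)]:
    ilp = effective_ilp(size)
    print(f"{name}: ILP={ilp}, load width={ilp * size * 8}-bit")
```

This reproduces the table: fp32 stays at ILP=4 (already 128-bit), while fp16/bf16 move from 64-bit to 128-bit loads.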

All 13 _foreach functors in ForeachFunctors.cuh were updated to use effective_ilp<T>() for their register arrays, loop bounds, and helper function calls. Fused optimizer functors (depth 4-5, already register-heavy) were left at the default ILP=4.

Measurement Setup

  • GPU: NVIDIA RTX 6000 Ada Generation (SM89, 48 GB GDDR6, ~960 GB/s)
  • L2 Cache: 96 MB
  • Profiler: NVIDIA Nsight Compute 2025.4.0
  • Workload: _foreach_add_(a, b) / _foreach_lerp_(a, b, 0.3) with 50 tensors, numel=262144
    • Total data per op: 50 * 262144 * 2 B * 3 = 78.6 MB (= 75.0 MiB; bf16 add: read a, read b, write a)
    • Fits in L2 cache (96 MB)
  • Kernel: multi_tensor_apply_kernel, block_size=512, grid=200
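As a quick check on the working-set arithmetic above (plain Python, values from the setup):

```python
num_tensors, numel, elem_bytes = 50, 262144, 2  # bf16
accesses = 3  # in-place add: read a, read b, write a

total_bytes = num_tensors * numel * elem_bytes * accesses
print(total_bytes)          # 78643200 B = 78.6 MB = 75.0 MiB
print(total_bytes < 96e6)   # True: the whole working set fits in L2
```

Because the data fits in the 96 MB L2, the measured effective bandwidth can exceed the DRAM peak.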

ncu command

ncu --target-processes all \
    --set full \
    --kernel-name "multi_tensor_apply_kernel" \
    --launch-skip 5 --launch-count 3 \
    -o <output_name> \
    python ncu_foreach_profile.py --op add --dtype bf16 --numel 262144

Nsight Compute Results

Kernel-Level Metrics (averaged over 3 launches)

Config            Time (us)  Regs/thread  Occupancy %  L2 Read Sectors  L2 Write Sectors  L2 Bytes (MB)  Eff BW (GB/s)
bf16 add BEFORE   64.76      38           46.1         1,638,400        819,200           78.6           1,214
bf16 add AFTER    65.90      44           45.9         1,638,400        819,200           78.6           1,193
fp16 add BEFORE   64.25      39           46.2         1,638,400        819,200           78.6           1,224
fp16 add AFTER    65.91      44           45.6         1,638,400        819,200           78.6           1,193
bf16 lerp BEFORE  64.10      37           46.2         1,638,400        819,200           78.6           1,227
bf16 lerp AFTER   65.58      40           45.7         1,638,400        819,200           78.6           1,199
fp32 add BEFORE   136.90     40           44.9         3,276,800        1,638,400         157.3          1,149
fp32 add AFTER    135.71     40           45.0         3,276,800        1,638,400         157.3          1,159
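The "Eff BW" column is just measured L2 traffic over kernel time; for the bf16 add baseline row, for example:

```python
l2_bytes = 50 * 262144 * 2 * 3   # 78,643,200 B, from the setup
time_s = 64.76e-6                # bf16 add BEFORE kernel time

eff_bw = l2_bytes / time_s / 1e9
print(f"{eff_bw:.0f} GB/s")      # ~1214 GB/s
```

That this exceeds the GPU's ~960 GB/s DRAM peak is consistent with the working set being served from L2 rather than DRAM.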

Key Observations

  1. L2 sector counts are identical. The total data moved is unchanged. ILP=8 doesn't move more data; it just changes the width of individual load/store instructions.

  2. Register pressure increased. bf16 add went from 38 to 44 registers per thread (+16%). This is the direct cost of doubling r_args[depth][ilp] from [2][4] to [2][8].

  3. Occupancy dropped slightly, from ~46.1% to ~45.9% achieved. More registers per thread leaves room for fewer concurrent warps per SM, reducing the GPU's ability to hide memory latency.

  4. The kernel is ~1.7-2.6% slower for bf16/fp16 on SM89. The wider vectorized loads don't compensate for the occupancy loss.

  5. fp32 is unchanged (control). effective_ilp<float>() returns 4, same as baseline.
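Observation 2 can be roughly modeled. Assuming 16-bit elements pack two per 32-bit register and counting only r_args (an illustrative lower bound — the compiler also keeps pointers, indices, and metadata live):

```python
def r_args_regs(depth: int, ilp: int, elem_bytes: int) -> int:
    """32-bit registers consumed by `T r_args[depth][ilp]` if perfectly packed."""
    return (depth * ilp * elem_bytes + 3) // 4

delta = r_args_regs(2, 8, 2) - r_args_regs(2, 4, 2)
print(delta)  # +4 registers from r_args alone (measured delta: +6, 38 -> 44)
```

The remaining +2 in the measured delta plausibly comes from wider address arithmetic and spilled temporaries, though only SASS inspection would confirm that.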

Analysis: Why It Doesn't Help

The Wrong Tradeoff

GPUs hide memory latency through warp-level parallelism (switching between warps while one is waiting for data), not through instruction-level parallelism within a single thread.

The ILP=8 change trades occupancy for per-thread throughput:

                          ILP=4 (before)         ILP=8 (after)
Load instructions/thread  2x (baseline)          1x (halved)
Registers/thread          38                     44
Active warps/SM           higher                 lower
Latency hiding            better (more warps)    worse (fewer warps)

At block_size=512 with ~46% occupancy, the kernel already has enough warps in flight to keep the memory pipeline busy with 64-bit loads. Halving the load count per thread doesn't help when the bottleneck is warp-level scheduling, not instruction issue rate.
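A Little's-law sketch makes this concrete (illustrative numbers, all assumptions: ~600 ns round-trip DRAM latency, 142 SMs for AD102, 48 warp slots per SM on SM89):

```python
peak_bw = 960e9    # B/s, RTX 6000 Ada
latency = 600e-9   # assumed round-trip DRAM latency

in_flight = peak_bw * latency          # bytes that must be in flight GPU-wide
per_sm = in_flight / 142               # ~4 KB per SM
loads_needed = per_sm / (32 * 8)       # one warp's 64-bit load = 32 * 8 B in flight
print(round(loads_needed))             # ~16 outstanding warp-level loads per SM

resident_warps = round(0.46 * 48)      # ~22 warps at 46% occupancy
print(resident_warps >= loads_needed)  # True: latency is already hidden at ILP=4
```

Under these assumptions the memory pipeline saturates with 64-bit loads well before occupancy becomes a constraint, so halving the load count buys nothing.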

Would Hopper (H100) or Blackwell (B200/GB200) Be Different?

Unlikely, for the same fundamental reason:

  • H100 (SM90): 3.35 TB/s HBM3, 50 MB L2, 65536 regs/SM. Higher bandwidth, but the SM count scales proportionally. Per-SM bandwidth requirement: ~25 GB/s. With ~4 active warps issuing 64-bit loads, each SM can sustain ~32 bytes/cycle — already sufficient. The register pressure vs occupancy tradeoff is architecture-independent.

  • B200/GB200 (SM100): ~8 TB/s HBM3e. Same argument: higher aggregate bandwidth but more SMs to distribute it. The per-SM arithmetic doesn't change enough to make wider loads the bottleneck.
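The "~25 GB/s per SM" figure for H100 above follows from dividing aggregate bandwidth by SM count (132 SMs assumed for the SXM part):

```python
hbm_bw = 3.35e12  # B/s, H100 HBM3
sms = 132         # H100 SXM SM count (assumption)

per_sm = hbm_bw / sms
print(f"{per_sm / 1e9:.1f} GB/s per SM")  # ~25.4 GB/s
```

The same division for B200-class parts (higher bandwidth, more SMs) lands in a similar per-SM range, which is why the argument carries over.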

The Irony

The kernels that could benefit from wider loads are those with already-low occupancy where fewer warps means the memory pipeline can't stay busy. These are exactly the fused optimizer kernels (depth=4-5, ~30-40 regs/thread) — but those are the ones we deliberately excluded from ILP=8 because adding more registers would make their occupancy even worse.

Conclusion

Increasing ILP from 4 to 8 for 16-bit foreach ops is a slight negative on SM89 and likely neutral-to-negative on SM90+. The correct levers for improving foreach memory throughput are:

  • Block size / chunk size tuning — affects work distribution and occupancy without increasing register pressure
  • cp.async (SM80+) / TMA (SM90+) — asynchronous copies that don't consume registers while in flight
  • Reducing register pressure in functors (e.g., shared memory for metadata) to allow higher occupancy
  • Persistent kernel patterns for multi-tensor workloads

The ILP tuning commit should be reverted.

Appendix: Profiling Script

"""Profile foreach kernels with nsight compute.

Usage:
    ncu --target-processes all \
        --set full \
        --kernel-name "multi_tensor_apply_kernel" \
        --launch-skip 5 --launch-count 3 \
        -o foreach_profile \
        python ncu_foreach_profile.py --op add --dtype bf16 --numel 262144
"""

import argparse
import torch


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--op", choices=["add", "mul", "lerp"], default="add")
    parser.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
    parser.add_argument("--numel", type=int, default=262144)
    parser.add_argument("--num-tensors", type=int, default=50)
    parser.add_argument("--iters", type=int, default=10)
    args = parser.parse_args()

    dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
    dtype = dtype_map[args.dtype]

    a = [torch.randn(args.numel, device="cuda", dtype=dtype) for _ in range(args.num_tensors)]
    b = [torch.randn(args.numel, device="cuda", dtype=dtype) for _ in range(args.num_tensors)]

    if args.op == "add":
        fn = lambda: torch._foreach_add_(a, b)
    elif args.op == "mul":
        fn = lambda: torch._foreach_mul_(a, 0.5)
    elif args.op == "lerp":
        fn = lambda: torch._foreach_lerp_(a, b, 0.3)

    for _ in range(5):
        fn()
    torch.cuda.synchronize()

    for _ in range(args.iters):
        fn()
    torch.cuda.synchronize()


if __name__ == "__main__":
    main()
diff --git a/aten/src/ATen/native/cuda/ForeachFunctors.cuh b/aten/src/ATen/native/cuda/ForeachFunctors.cuh
index 37175776097..f6e18b9cb47 100644
--- a/aten/src/ATen/native/cuda/ForeachFunctors.cuh
+++ b/aten/src/ATen/native/cuda/ForeachFunctors.cuh
@@ -18,8 +18,7 @@ inline void increment_version(TensorList tensors) {
}
}
-// Initializes args and checks if all args are aligned
-template <int depth, typename T>
+template <int depth, int64_t ilp = kILP, typename T>
__device__ bool init_args(
T** args,
TensorListMetadata<depth>& tl,
@@ -31,15 +30,14 @@ __device__ bool init_args(
args[i] = (T*)tl.addresses[i][tensor_loc];
args[i] += chunk_idx * chunk_size;
- if (!is_aligned(args[i])) {
+ if (!is_aligned<ilp>(args[i])) {
all_aligned = false;
}
}
return all_aligned;
}
-// Initializes args and checks if all args are aligned
-template <int depth, typename T, typename T2>
+template <int depth, int64_t ilp = kILP, typename T, typename T2>
__device__ bool init_args(
T** args,
TensorListScalarListMetadata<T2, depth>& tl,
@@ -51,14 +49,14 @@ __device__ bool init_args(
args[i] = (T*)tl.addresses[i][tensor_loc];
args[i] += chunk_idx * chunk_size;
- if (!is_aligned(args[i])) {
+ if (!is_aligned<ilp>(args[i])) {
all_aligned = false;
}
}
return all_aligned;
}
-template <int depth, typename T>
+template <int depth, int64_t ilp = kILP, typename T>
__device__ bool init_args(
T** args,
FusedOptimizerTensorListMetadata<depth>& tl,
@@ -70,22 +68,22 @@ __device__ bool init_args(
args[i] = (T*)tl.addresses[i][tensor_loc];
args[i] += chunk_idx * chunk_size;
- if (!is_aligned(args[i])) {
+ if (!is_aligned<ilp>(args[i])) {
all_aligned = false;
}
}
return all_aligned;
}
-template <int depth, typename T>
+template <int depth, int64_t ilp = kILP, typename T>
__device__ void load_args(
- T r_args[][kILP],
+ T r_args[][ilp],
T** args,
const int64_t i_start,
const int64_t chunk_size,
const int64_t n) {
#pragma unroll
- for (int ii = 0; ii < kILP; ii++) {
+ for (int ii = 0; ii < ilp; ii++) {
const auto i = i_start + threadIdx.x + ii * blockDim.x;
for (int r_index = 0; r_index < depth; r_index++) {
r_args[r_index][ii] = 0;
@@ -96,7 +94,7 @@ __device__ void load_args(
}
}
-template <typename T>
+template <int64_t ilp = kILP, typename T>
__device__ void store_args(
T* dst,
T* src,
@@ -104,93 +102,93 @@ __device__ void store_args(
const int64_t chunk_size,
const int64_t n) {
#pragma unroll
- for (int ii = 0; ii < kILP; ii++) {
+ for (int ii = 0; ii < ilp; ii++) {
const int64_t i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size)
dst[i] = src[ii];
}
}
-template <int res_arg_index, typename Op, typename T, typename opmath_t>
+template <
+ int res_arg_index,
+ int64_t ilp = kILP,
+ typename Op,
+ typename T,
+ typename opmath_t>
__device__ __forceinline__ void binary_op_scalar(
- T r_args[][kILP],
+ T r_args[][ilp],
T** args,
opmath_t scalar,
const int64_t n,
const int64_t chunk_size,
const bool all_aligned,
Op op) {
- // to make things simple, we put aligned case in a different code path
- if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
+ if (n % ilp == 0 && chunk_size % ilp == 0 && all_aligned) {
for (int64_t i_start = threadIdx.x;
- i_start * kILP < n && i_start * kILP < chunk_size;
+ i_start * ilp < n && i_start * ilp < chunk_size;
i_start += blockDim.x) {
- // load
- load_store(r_args[0], args[0], 0, i_start);
+ load_store<ilp>(r_args[0], args[0], 0, i_start);
#pragma unroll
- for (int ii = 0; ii < kILP; ii++) {
+ for (int ii = 0; ii < ilp; ii++) {
r_args[0][ii] = static_cast<T>(
op(static_cast<opmath_t>(r_args[0][ii]),
static_cast<opmath_t>(scalar)));
}
- // store
- load_store(args[res_arg_index], r_args[0], i_start, 0);
+ load_store<ilp>(args[res_arg_index], r_args[0], i_start, 0);
}
} else {
for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
- i_start += blockDim.x * kILP) {
- // Regardless if depth is 1 (for inplace) or 2 (for out of place), r_args
- // has depth 1
- load_args<1>(r_args, args, i_start, chunk_size, n);
+ i_start += blockDim.x * ilp) {
+ load_args<1, ilp>(r_args, args, i_start, chunk_size, n);
#pragma unroll
- for (int ii = 0; ii < kILP; ii++) {
+ for (int ii = 0; ii < ilp; ii++) {
r_args[0][ii] = static_cast<T>(
op(static_cast<opmath_t>(r_args[0][ii]),
static_cast<opmath_t>(scalar)));
}
- store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
+ store_args<ilp>(args[res_arg_index], r_args[0], i_start, chunk_size, n);
}
}
}
-template <int res_arg_index, typename Op, typename T, typename opmath_t>
+template <
+ int res_arg_index,
+ int64_t ilp = kILP,
+ typename Op,
+ typename T,
+ typename opmath_t>
__device__ __forceinline__ void pointwise_op_scalar(
- T r_args[][kILP],
+ T r_args[][ilp],
T** args,
opmath_t scalar,
const int64_t n,
const int64_t chunk_size,
const bool all_aligned,
Op op) {
- // to make things simple, we put aligned case in a different code path
- if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
+ if (n % ilp == 0 && chunk_size % ilp == 0 && all_aligned) {
for (int64_t i_start = threadIdx.x;
- i_start * kILP < n && i_start * kILP < chunk_size;
+ i_start * ilp < n && i_start * ilp < chunk_size;
i_start += blockDim.x) {
- // load
- load_store(r_args[0], args[0], 0, i_start);
- load_store(r_args[1], args[1], 0, i_start);
- load_store(r_args[2], args[2], 0, i_start);
+ load_store<ilp>(r_args[0], args[0], 0, i_start);
+ load_store<ilp>(r_args[1], args[1], 0, i_start);
+ load_store<ilp>(r_args[2], args[2], 0, i_start);
#pragma unroll
- for (int ii = 0; ii < kILP; ii++) {
+ for (int ii = 0; ii < ilp; ii++) {
r_args[0][ii] = pointwise_op_impl<opmath_t>(
r_args[0][ii], r_args[1][ii], r_args[2][ii], scalar, op);
}
- // store
- load_store(args[res_arg_index], r_args[0], i_start, 0);
+ load_store<ilp>(args[res_arg_index], r_args[0], i_start, 0);
}
} else {
for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
- i_start += blockDim.x * kILP) {
- // Regardless if depth is 3 (for inplace) or 4 (for out of place), r_args
- // has depth 3
- load_args<3>(r_args, args, i_start, chunk_size, n);
+ i_start += blockDim.x * ilp) {
+ load_args<3, ilp>(r_args, args, i_start, chunk_size, n);
#pragma unroll
- for (int ii = 0; ii < kILP; ii++) {
+ for (int ii = 0; ii < ilp; ii++) {
r_args[0][ii] = pointwise_op_impl<opmath_t>(
r_args[0][ii], r_args[1][ii], r_args[2][ii], scalar, op);
}
- store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
+ store_args<ilp>(args[res_arg_index], r_args[0], i_start, chunk_size, n);
}
}
}
@@ -201,6 +199,7 @@ __device__ __forceinline__ void pointwise_op_scalar(
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct BinaryOpScalarFunctor {
using opmath_t = at::opmath_type<T>;
+ static constexpr int64_t ilp = effective_ilp<T>();
template <typename Op>
__device__ __forceinline__ void operator()(
int64_t chunk_size,
@@ -213,11 +212,11 @@ struct BinaryOpScalarFunctor {
T* args[depth];
const bool all_aligned =
- init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
+ init_args<depth, ilp>(args, tl, chunk_idx, chunk_size, tensor_loc);
n -= chunk_idx * chunk_size;
- T r_args[r_args_depth][kILP];
+ T r_args[r_args_depth][ilp];
- binary_op_scalar<res_arg_index>(
+ binary_op_scalar<res_arg_index, ilp>(
r_args, args, scalar, n, chunk_size, all_aligned, op);
}
};
@@ -225,6 +224,7 @@ struct BinaryOpScalarFunctor {
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct BinaryOpScalarListFunctor {
using opmath_t = at::opmath_type<T>;
+ static constexpr int64_t ilp = effective_ilp<T>();
template <typename Op>
__device__ __forceinline__ void operator()(
int64_t chunk_size,
@@ -236,12 +236,12 @@ struct BinaryOpScalarListFunctor {
T* args[depth];
const bool all_aligned =
- init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
+ init_args<depth, ilp>(args, tl, chunk_idx, chunk_size, tensor_loc);
opmath_t scalar = tl.scalar_vals[tensor_loc];
n -= chunk_idx * chunk_size;
- T r_args[r_args_depth][kILP];
+ T r_args[r_args_depth][ilp];
- binary_op_scalar<res_arg_index>(
+ binary_op_scalar<res_arg_index, ilp>(
r_args, args, scalar, n, chunk_size, all_aligned, op);
}
};
@@ -249,6 +249,7 @@ struct BinaryOpScalarListFunctor {
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct BinaryOpListAlphaFunctor {
using opmath_t = at::opmath_type<T>;
+ static constexpr int64_t ilp = effective_ilp<T>();
template <typename Op>
__device__ __forceinline__ void operator()(
int64_t chunk_size,
@@ -261,38 +262,35 @@ struct BinaryOpListAlphaFunctor {
T* args[depth];
const bool all_aligned =
- init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
+ init_args<depth, ilp>(args, tl, chunk_idx, chunk_size, tensor_loc);
n -= chunk_idx * chunk_size;
- T r_args[r_args_depth][kILP];
+ T r_args[r_args_depth][ilp];
- // to make things simple, we put aligned case in a different code path
- if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
+ if (n % ilp == 0 && chunk_size % ilp == 0 && all_aligned) {
for (int64_t i_start = threadIdx.x;
- i_start * kILP < n && i_start * kILP < chunk_size;
+ i_start * ilp < n && i_start * ilp < chunk_size;
i_start += blockDim.x) {
- // load
- load_store(r_args[0], args[0], 0, i_start);
- load_store(r_args[1], args[1], 0, i_start);
+ load_store<ilp>(r_args[0], args[0], 0, i_start);
+ load_store<ilp>(r_args[1], args[1], 0, i_start);
#pragma unroll
- for (int ii = 0; ii < kILP; ii++) {
+ for (int ii = 0; ii < ilp; ii++) {
r_args[0][ii] = static_cast<T>(
op(static_cast<opmath_t>(r_args[0][ii]),
alpha * static_cast<opmath_t>(r_args[1][ii])));
}
- // store
- load_store(args[res_arg_index], r_args[0], i_start, 0);
+ load_store<ilp>(args[res_arg_index], r_args[0], i_start, 0);
}
} else {
for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
- i_start += blockDim.x * kILP) {
- load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
+ i_start += blockDim.x * ilp) {
+ load_args<r_args_depth, ilp>(r_args, args, i_start, chunk_size, n);
#pragma unroll
- for (int ii = 0; ii < kILP; ii++) {
+ for (int ii = 0; ii < ilp; ii++) {
r_args[0][ii] = static_cast<T>(
op(static_cast<opmath_t>(r_args[0][ii]),
alpha * static_cast<opmath_t>(r_args[1][ii])));
}
- store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
+ store_args<ilp>(args[res_arg_index], r_args[0], i_start, chunk_size, n);
}
}
}
@@ -301,6 +299,7 @@ struct BinaryOpListAlphaFunctor {
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct BinaryOpScalarTensorFunctor {
using opmath_t = at::opmath_type<T>;
+ static constexpr int64_t ilp = effective_ilp<T>();
template <typename Op>
__device__ __forceinline__ void operator()(
int64_t chunk_size,
@@ -314,39 +313,34 @@ struct BinaryOpScalarTensorFunctor {
T* args[depth];
const bool all_aligned =
- init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
+ init_args<depth, ilp>(args, tl, chunk_idx, chunk_size, tensor_loc);
n -= chunk_idx * chunk_size;
- T r_args[r_args_depth][kILP];
+ T r_args[r_args_depth][ilp];
- // to make things simple, we put aligned case in a different code path
- if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
+ if (n % ilp == 0 && chunk_size % ilp == 0 && all_aligned) {
for (int64_t i_start = threadIdx.x;
- i_start * kILP < n && i_start * kILP < chunk_size;
+ i_start * ilp < n && i_start * ilp < chunk_size;
i_start += blockDim.x) {
- // load
- load_store(r_args[0], args[0], 0, i_start);
+ load_store<ilp>(r_args[0], args[0], 0, i_start);
#pragma unroll
- for (int ii = 0; ii < kILP; ii++) {
+ for (int ii = 0; ii < ilp; ii++) {
r_args[0][ii] = static_cast<T>(op(
static_cast<opmath_t>(r_args[0][ii]),
static_cast<opmath_t>(alpha) * static_cast<opmath_t>(*scalar)));
}
- // store
- load_store(args[res_arg_index], r_args[0], i_start, 0);
+ load_store<ilp>(args[res_arg_index], r_args[0], i_start, 0);
}
} else {
for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
- i_start += blockDim.x * kILP) {
- // Regardless if depth is 1 (for inplace) or 2 (for out of place),
- // r_args has depth 1
- load_args<1>(r_args, args, i_start, chunk_size, n);
+ i_start += blockDim.x * ilp) {
+ load_args<1, ilp>(r_args, args, i_start, chunk_size, n);
#pragma unroll
- for (int ii = 0; ii < kILP; ii++) {
+ for (int ii = 0; ii < ilp; ii++) {
r_args[0][ii] = static_cast<T>(op(
static_cast<opmath_t>(r_args[0][ii]),
static_cast<opmath_t>(alpha) * static_cast<opmath_t>(*scalar)));
}
- store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
+ store_args<ilp>(args[res_arg_index], r_args[0], i_start, chunk_size, n);
}
}
}
@@ -358,6 +352,7 @@ struct BinaryOpScalarTensorFunctor {
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct ZeroFunctor {
+ static constexpr int64_t ilp = effective_ilp<T>();
__device__ __forceinline__ void operator()(
int64_t chunk_size,
TensorListMetadata<1>& tl) {
@@ -367,30 +362,28 @@ struct ZeroFunctor {
T* args[depth];
const auto all_aligned =
- init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
+ init_args<depth, ilp>(args, tl, chunk_idx, chunk_size, tensor_loc);
n -= chunk_idx * chunk_size;
- T r_args[r_args_depth][kILP];
+ T r_args[r_args_depth][ilp];
- // to make things simple, we put aligned case in a different code path
- if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
+ if (n % ilp == 0 && chunk_size % ilp == 0 && all_aligned) {
for (int64_t i_start = threadIdx.x;
- i_start * kILP < n && i_start * kILP < chunk_size;
+ i_start * ilp < n && i_start * ilp < chunk_size;
i_start += blockDim.x) {
#pragma unroll
- for (int ii = 0; ii < kILP; ii++) {
+ for (int ii = 0; ii < ilp; ii++) {
r_args[0][ii] = 0;
}
- // store
- load_store(args[0], r_args[0], i_start, 0);
+ load_store<ilp>(args[0], r_args[0], i_start, 0);
}
} else {
for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
- i_start += blockDim.x * kILP) {
+ i_start += blockDim.x * ilp) {
#pragma unroll
- for (int ii = 0; ii < kILP; ii++) {
+ for (int ii = 0; ii < ilp; ii++) {
r_args[0][ii] = 0;
}
- store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
+ store_args<ilp>(args[res_arg_index], r_args[0], i_start, chunk_size, n);
}
}
}
@@ -399,6 +392,7 @@ struct ZeroFunctor {
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct UnaryOpFunctor {
using opmath_t = at::opmath_type<T>;
+ static constexpr int64_t ilp = effective_ilp<T>();
template <typename Op>
__device__ __forceinline__ void operator()(
int64_t chunk_size,
@@ -410,35 +404,32 @@ struct UnaryOpFunctor {
T* args[depth];
bool all_aligned =
- init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
+ init_args<depth, ilp>(args, tl, chunk_idx, chunk_size, tensor_loc);
n -= chunk_idx * chunk_size;
- T r_args[r_args_depth][kILP];
+ T r_args[r_args_depth][ilp];
- // to make things simple, we put aligned case in a different code path
- if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
+ if (n % ilp == 0 && chunk_size % ilp == 0 && all_aligned) {
for (int64_t i_start = threadIdx.x;
- i_start * kILP < n && i_start * kILP < chunk_size;
+ i_start * ilp < n && i_start * ilp < chunk_size;
i_start += blockDim.x) {
- // load
- load_store(r_args[0], args[0], 0, i_start);
+ load_store<ilp>(r_args[0], args[0], 0, i_start);
#pragma unroll
- for (int ii = 0; ii < kILP; ii++) {
+ for (int ii = 0; ii < ilp; ii++) {
r_args[0][ii] =
static_cast<T>(op(static_cast<opmath_t>(r_args[0][ii])));
}
- // store
- load_store(args[res_arg_index], r_args[0], i_start, 0);
+ load_store<ilp>(args[res_arg_index], r_args[0], i_start, 0);
}
} else {
for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
- i_start += blockDim.x * kILP) {
- load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
+ i_start += blockDim.x * ilp) {
+ load_args<r_args_depth, ilp>(r_args, args, i_start, chunk_size, n);
#pragma unroll
- for (int ii = 0; ii < kILP; ii++) {
+ for (int ii = 0; ii < ilp; ii++) {
r_args[0][ii] =
static_cast<T>(op(static_cast<opmath_t>(r_args[0][ii])));
}
- store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
+ store_args<ilp>(args[res_arg_index], r_args[0], i_start, chunk_size, n);
}
}
}
@@ -451,6 +442,7 @@ struct UnaryOpFunctor {
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct PointwiseOpScalarFunctor {
using opmath_t = at::opmath_type<T>;
+ static constexpr int64_t ilp = effective_ilp<T>();
template <typename Op>
__device__ __forceinline__ void operator()(
int64_t chunk_size,
@@ -463,11 +455,11 @@ struct PointwiseOpScalarFunctor {
T* args[depth];
const bool all_aligned =
- init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
+ init_args<depth, ilp>(args, tl, chunk_idx, chunk_size, tensor_loc);
n -= chunk_idx * chunk_size;
- T r_args[r_args_depth][kILP];
+ T r_args[r_args_depth][ilp];
- pointwise_op_scalar<res_arg_index>(
+ pointwise_op_scalar<res_arg_index, ilp>(
r_args, args, scalar, n, chunk_size, all_aligned, op);
}
};
@@ -475,6 +467,7 @@ struct PointwiseOpScalarFunctor {
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct PointwiseOpScalarListFunctor {
using opmath_t = at::opmath_type<T>;
+ static constexpr int64_t ilp = effective_ilp<T>();
template <typename Op>
__device__ __forceinline__ void operator()(
int64_t chunk_size,
@@ -486,12 +479,12 @@ struct PointwiseOpScalarListFunctor {
T* args[depth];
const bool all_aligned =
- init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
+ init_args<depth, ilp>(args, tl, chunk_idx, chunk_size, tensor_loc);
opmath_t scalar = tl.scalar_vals[tensor_loc];
n -= chunk_idx * chunk_size;
- T r_args[r_args_depth][kILP];
+ T r_args[r_args_depth][ilp];
- pointwise_op_scalar<res_arg_index>(
+ pointwise_op_scalar<res_arg_index, ilp>(
r_args, args, scalar, n, chunk_size, all_aligned, op);
}
};
@@ -505,6 +498,7 @@ struct PointwiseOpScalarListFunctor {
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct PointwiseOpScalar0dTensorFunctor {
using opmath_t = at::opmath_type<T>;
+ static constexpr int64_t ilp = effective_ilp<T>();
template <typename Op>
__device__ __forceinline__ void operator()(
int64_t chunk_size,
@@ -517,57 +511,46 @@ struct PointwiseOpScalar0dTensorFunctor {
T* args[depth];
const bool all_aligned =
- init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
+ init_args<depth, ilp>(args, tl, chunk_idx, chunk_size, tensor_loc);
- // Load the 0D tensor1 value from device memory (just one element)
opmath_t tensor1_val = static_cast<opmath_t>(
*reinterpret_cast<const T*>(tl.addresses[1][tensor_loc]));
n -= chunk_idx * chunk_size;
- T r_args[r_args_depth][kILP];
+ T r_args[r_args_depth][ilp];
- // to make things simple, we put aligned case in a different code path
- // For depth=4: args[0] = input, args[1] = tensor1 (0D), args[2] = tensor2,
- // args[3] = output For depth=3: args[0] = input, args[1] = tensor1 (0D),
- // args[2] = tensor2, output = args[0]
- if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
+ if (n % ilp == 0 && chunk_size % ilp == 0 && all_aligned) {
for (int64_t i_start = threadIdx.x;
- i_start * kILP < n && i_start * kILP < chunk_size;
+ i_start * ilp < n && i_start * ilp < chunk_size;
i_start += blockDim.x) {
- // load input and tensor2 only (tensor1 is already loaded as scalar)
- load_store(r_args[0], args[0], 0, i_start);
- load_store(r_args[1], args[2], 0, i_start); // tensor2 is at args[2]
+ load_store<ilp>(r_args[0], args[0], 0, i_start);
+ load_store<ilp>(r_args[1], args[2], 0, i_start);
#pragma unroll
- for (int ii = 0; ii < kILP; ii++) {
- // input + alpha * op(tensor1_val, tensor2)
+ for (int ii = 0; ii < ilp; ii++) {
r_args[0][ii] = pointwise_op_impl<opmath_t>(
r_args[0][ii], tensor1_val, r_args[1][ii], alpha, op);
}
- // store
- load_store(args[res_arg_index], r_args[0], i_start, 0);
+ load_store<ilp>(args[res_arg_index], r_args[0], i_start, 0);
}
} else {
for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
- i_start += blockDim.x * kILP) {
- // Load input (r_args[0]) and tensor2 (r_args[1])
- // We need to load from args[0] and args[2] (skipping args[1] which is
- // 0D tensor)
+ i_start += blockDim.x * ilp) {
#pragma unroll
- for (int ii = 0; ii < kILP; ii++) {
+ for (int ii = 0; ii < ilp; ii++) {
const auto i = i_start + threadIdx.x + ii * blockDim.x;
r_args[0][ii] = 0;
r_args[1][ii] = 0;
if (i < n && i < chunk_size) {
r_args[0][ii] = args[0][i];
- r_args[1][ii] = args[2][i]; // tensor2 is at args[2]
+ r_args[1][ii] = args[2][i];
}
}
#pragma unroll
- for (int ii = 0; ii < kILP; ii++) {
+ for (int ii = 0; ii < ilp; ii++) {
r_args[0][ii] = pointwise_op_impl<opmath_t>(
r_args[0][ii], tensor1_val, r_args[1][ii], alpha, op);
}
- store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
+ store_args<ilp>(args[res_arg_index], r_args[0], i_start, chunk_size, n);
}
}
}
@@ -576,6 +559,7 @@ struct PointwiseOpScalar0dTensorFunctor {
template <typename T, int depth>
struct PointwiseOpListFunctor {
using opmath_t = at::opmath_type<T>;
+ static constexpr int64_t ilp = effective_ilp<T>();
template <typename Op>
__device__ __forceinline__ void operator()(
int64_t chunk_size,
@@ -587,38 +571,35 @@ struct PointwiseOpListFunctor {
T* args[depth];
const bool all_aligned =
- init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
+ init_args<depth, ilp>(args, tl, chunk_idx, chunk_size, tensor_loc);
n -= chunk_idx * chunk_size;
- T r_args[depth - 1][kILP];
+ T r_args[depth - 1][ilp];
- // to make things simple, we put aligned case in a different code path
- if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
+ if (n % ilp == 0 && chunk_size % ilp == 0 && all_aligned) {
for (int64_t i_start = threadIdx.x;
- i_start * kILP < n && i_start * kILP < chunk_size;
+ i_start * ilp < n && i_start * ilp < chunk_size;
i_start += blockDim.x) {
- // load
- load_store(r_args[0], args[0], 0, i_start);
- load_store(r_args[1], args[1], 0, i_start);
+ load_store<ilp>(r_args[0], args[0], 0, i_start);
+ load_store<ilp>(r_args[1], args[1], 0, i_start);
#pragma unroll
- for (int ii = 0; ii < kILP; ii++) {
+ for (int ii = 0; ii < ilp; ii++) {
r_args[0][ii] = static_cast<T>(
op(static_cast<opmath_t>(r_args[0][ii]),
static_cast<opmath_t>(r_args[1][ii])));
}
- // store
- load_store(args[2], r_args[0], i_start, 0);
+ load_store<ilp>(args[2], r_args[0], i_start, 0);
}
} else {
for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
- i_start += blockDim.x * kILP) {
- load_args<depth - 1>(r_args, args, i_start, chunk_size, n);
+ i_start += blockDim.x * ilp) {
+ load_args<depth - 1, ilp>(r_args, args, i_start, chunk_size, n);
#pragma unroll
- for (int ii = 0; ii < kILP; ii++) {
+ for (int ii = 0; ii < ilp; ii++) {
r_args[0][ii] = static_cast<T>(
op(static_cast<opmath_t>(r_args[0][ii]),
static_cast<opmath_t>(r_args[1][ii])));
}
- store_args(args[2], r_args[0], i_start, chunk_size, n);
+ store_args<ilp>(args[2], r_args[0], i_start, chunk_size, n);
}
}
}
@@ -627,6 +608,7 @@ struct PointwiseOpListFunctor {
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct TernaryOpListFunctor {
using opmath_t = at::opmath_type<T>;
+ static constexpr int64_t ilp = effective_ilp<T>();
template <typename Op>
__device__ __forceinline__ void operator()(
int64_t chunk_size,
@@ -641,38 +623,38 @@ struct TernaryOpListFunctor {
T* args[depth];
const bool all_aligned =
- init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
+ init_args<depth, ilp>(args, tl, chunk_idx, chunk_size, tensor_loc);
n -= chunk_idx * chunk_size;
- T r_args[r_args_depth][kILP];
+ T r_args[r_args_depth][ilp];
- if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
+ if (n % ilp == 0 && chunk_size % ilp == 0 && all_aligned) {
for (int64_t i_start = threadIdx.x;
- i_start * kILP < n && i_start * kILP < chunk_size;
+ i_start * ilp < n && i_start * ilp < chunk_size;
i_start += blockDim.x) {
- load_store(r_args[0], args[0], 0, i_start);
- load_store(r_args[1], args[1], 0, i_start);
- load_store(r_args[2], args[2], 0, i_start);
+ load_store<ilp>(r_args[0], args[0], 0, i_start);
+ load_store<ilp>(r_args[1], args[1], 0, i_start);
+ load_store<ilp>(r_args[2], args[2], 0, i_start);
#pragma unroll
- for (int ii = 0; ii < kILP; ii++) {
+ for (int ii = 0; ii < ilp; ii++) {
r_args[0][ii] =
op(static_cast<opmath_t>(r_args[0][ii]),
static_cast<opmath_t>(r_args[1][ii]),
static_cast<opmath_t>(r_args[2][ii]));
}
- load_store(args[res_arg_index], r_args[0], i_start, 0);
+ load_store<ilp>(args[res_arg_index], r_args[0], i_start, 0);
}
} else {
for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
- i_start += blockDim.x * kILP) {
- load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
+ i_start += blockDim.x * ilp) {
+ load_args<r_args_depth, ilp>(r_args, args, i_start, chunk_size, n);
#pragma unroll
- for (int ii = 0; ii < kILP; ii++) {
+ for (int ii = 0; ii < ilp; ii++) {
r_args[0][ii] =
op(static_cast<opmath_t>(r_args[0][ii]),
static_cast<opmath_t>(r_args[1][ii]),
static_cast<opmath_t>(r_args[2][ii]));
}
- store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
+ store_args<ilp>(args[res_arg_index], r_args[0], i_start, chunk_size, n);
}
}
}
@@ -681,6 +663,7 @@ struct TernaryOpListFunctor {
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct TernaryOpScalarFunctor {
using opmath_t = at::opmath_type<T>;
+ static constexpr int64_t ilp = effective_ilp<T>();
template <typename Op>
__device__ __forceinline__ void operator()(
int64_t chunk_size,
@@ -696,40 +679,37 @@ struct TernaryOpScalarFunctor {
T* args[depth];
const bool all_aligned =
- init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
+ init_args<depth, ilp>(args, tl, chunk_idx, chunk_size, tensor_loc);
n -= chunk_idx * chunk_size;
- T r_args[r_args_depth][kILP];
+ T r_args[r_args_depth][ilp];
- // to make things simple, we put aligned case in a different code path
- if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
+ if (n % ilp == 0 && chunk_size % ilp == 0 && all_aligned) {
for (int64_t i_start = threadIdx.x;
- i_start * kILP < n && i_start * kILP < chunk_size;
+ i_start * ilp < n && i_start * ilp < chunk_size;
i_start += blockDim.x) {
- // load
- load_store(r_args[0], args[0], 0, i_start);
- load_store(r_args[1], args[1], 0, i_start);
+ load_store<ilp>(r_args[0], args[0], 0, i_start);
+ load_store<ilp>(r_args[1], args[1], 0, i_start);
#pragma unroll
- for (int ii = 0; ii < kILP; ii++) {
+ for (int ii = 0; ii < ilp; ii++) {
r_args[0][ii] =
op(static_cast<opmath_t>(r_args[0][ii]),
static_cast<opmath_t>(r_args[1][ii]),
alpha);
}
- // store
- load_store(args[res_arg_index], r_args[0], i_start, 0);
+ load_store<ilp>(args[res_arg_index], r_args[0], i_start, 0);
}
} else {
for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
- i_start += blockDim.x * kILP) {
- load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
+ i_start += blockDim.x * ilp) {
+ load_args<r_args_depth, ilp>(r_args, args, i_start, chunk_size, n);
#pragma unroll
- for (int ii = 0; ii < kILP; ii++) {
+ for (int ii = 0; ii < ilp; ii++) {
r_args[0][ii] =
op(static_cast<opmath_t>(r_args[0][ii]),
static_cast<opmath_t>(r_args[1][ii]),
alpha);
}
- store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
+ store_args<ilp>(args[res_arg_index], r_args[0], i_start, chunk_size, n);
}
}
}
@@ -738,6 +718,7 @@ struct TernaryOpScalarFunctor {
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct TernaryOpScalarListFunctor {
using opmath_t = at::opmath_type<T>;
+ static constexpr int64_t ilp = effective_ilp<T>();
template <typename Op>
__device__ __forceinline__ void operator()(
int64_t chunk_size,
@@ -752,41 +733,38 @@ struct TernaryOpScalarListFunctor {
T* args[depth];
const bool all_aligned =
- init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
+ init_args<depth, ilp>(args, tl, chunk_idx, chunk_size, tensor_loc);
n -= chunk_idx * chunk_size;
- T r_args[r_args_depth][kILP];
+ T r_args[r_args_depth][ilp];
const opmath_t scalar = tl.scalar_vals[tensor_loc];
- // to make things simple, we put aligned case in a different code path
- if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
+ if (n % ilp == 0 && chunk_size % ilp == 0 && all_aligned) {
for (int64_t i_start = threadIdx.x;
- i_start * kILP < n && i_start * kILP < chunk_size;
+ i_start * ilp < n && i_start * ilp < chunk_size;
i_start += blockDim.x) {
- // load
- load_store(r_args[0], args[0], 0, i_start);
- load_store(r_args[1], args[1], 0, i_start);
+ load_store<ilp>(r_args[0], args[0], 0, i_start);
+ load_store<ilp>(r_args[1], args[1], 0, i_start);
#pragma unroll
- for (int ii = 0; ii < kILP; ii++) {
+ for (int ii = 0; ii < ilp; ii++) {
r_args[0][ii] =
op(static_cast<opmath_t>(r_args[0][ii]),
static_cast<opmath_t>(r_args[1][ii]),
scalar);
}
- // store
- load_store(args[res_arg_index], r_args[0], i_start, 0);
+ load_store<ilp>(args[res_arg_index], r_args[0], i_start, 0);
}
} else {
for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
- i_start += blockDim.x * kILP) {
- load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
+ i_start += blockDim.x * ilp) {
+ load_args<r_args_depth, ilp>(r_args, args, i_start, chunk_size, n);
#pragma unroll
- for (int ii = 0; ii < kILP; ii++) {
+ for (int ii = 0; ii < ilp; ii++) {
r_args[0][ii] =
op(static_cast<opmath_t>(r_args[0][ii]),
static_cast<opmath_t>(r_args[1][ii]),
scalar);
}
- store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
+ store_args<ilp>(args[res_arg_index], r_args[0], i_start, chunk_size, n);
}
}
}
diff --git a/aten/src/ATen/native/cuda/MultiTensorApply.cuh b/aten/src/ATen/native/cuda/MultiTensorApply.cuh
index 2fe431f778b..f4d9b9043f7 100644
--- a/aten/src/ATen/native/cuda/MultiTensorApply.cuh
+++ b/aten/src/ATen/native/cuda/MultiTensorApply.cuh
@@ -14,6 +14,16 @@ static constexpr int64_t kILP = 4;
static constexpr int64_t kChunkSize = 65536;
static constexpr int64_t kBlockSize = 512;
+// Target 128-bit (16-byte) vectorized loads. For types narrower than 4 bytes
+// (e.g. fp16/bf16 at 2 bytes), this yields ILP=8 instead of the default 4,
+// doubling memory throughput per thread. Capped at 8 to limit register
+// pressure at higher depths.
+template <typename T>
+constexpr int64_t effective_ilp() {
+ constexpr int64_t target = static_cast<int64_t>(16 / sizeof(T));
+ return (target < kILP) ? kILP : ((target > 8) ? 8 : target);
+}
+
// TODO(crcrpar): Add `n>5` for `low prec params & their higher prec copy`
// TensorListMetadata has to be < 4KB - the limit for kernel launch argument
static constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
@@ -23,18 +33,18 @@ static constexpr int depth_to_max_tensors_scalarlist_of_complex_double[2] = {
72,
60};
-template <typename T>
+template <int64_t ilp = kILP, typename T>
__device__ __forceinline__ bool is_aligned(T* p) {
- return ((uint64_t)p) % (kILP * sizeof(T)) == 0;
+ return ((uint64_t)p) % (ilp * sizeof(T)) == 0;
}
-template <typename T>
+template <int64_t ilp = kILP, typename T>
__device__ __forceinline__ void load_store(
T* dst,
T* src,
int64_t dst_offset,
int64_t src_offset) {
- using LT = at::native::memory::aligned_vector<T, kILP>;
+ using LT = at::native::memory::aligned_vector<T, ilp>;
((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}