|
diff --git a/aten/src/ATen/native/cuda/ForeachFunctors.cuh b/aten/src/ATen/native/cuda/ForeachFunctors.cuh |
|
index 37175776097..f6e18b9cb47 100644 |
|
--- a/aten/src/ATen/native/cuda/ForeachFunctors.cuh |
|
+++ b/aten/src/ATen/native/cuda/ForeachFunctors.cuh |
|
@@ -18,8 +18,7 @@ inline void increment_version(TensorList tensors) { |
|
} |
|
} |
|
|
|
-// Initializes args and checks if all args are aligned |
|
-template <int depth, typename T> |
|
+template <int depth, int64_t ilp = kILP, typename T> |
|
__device__ bool init_args( |
|
T** args, |
|
TensorListMetadata<depth>& tl, |
|
@@ -31,15 +30,14 @@ __device__ bool init_args( |
|
args[i] = (T*)tl.addresses[i][tensor_loc]; |
|
args[i] += chunk_idx * chunk_size; |
|
|
|
- if (!is_aligned(args[i])) { |
|
+ if (!is_aligned<ilp>(args[i])) { |
|
all_aligned = false; |
|
} |
|
} |
|
return all_aligned; |
|
} |
|
|
|
-// Initializes args and checks if all args are aligned |
|
-template <int depth, typename T, typename T2> |
|
+template <int depth, int64_t ilp = kILP, typename T, typename T2> |
|
__device__ bool init_args( |
|
T** args, |
|
TensorListScalarListMetadata<T2, depth>& tl, |
|
@@ -51,14 +49,14 @@ __device__ bool init_args( |
|
args[i] = (T*)tl.addresses[i][tensor_loc]; |
|
args[i] += chunk_idx * chunk_size; |
|
|
|
- if (!is_aligned(args[i])) { |
|
+ if (!is_aligned<ilp>(args[i])) { |
|
all_aligned = false; |
|
} |
|
} |
|
return all_aligned; |
|
} |
|
|
|
-template <int depth, typename T> |
|
+template <int depth, int64_t ilp = kILP, typename T> |
|
__device__ bool init_args( |
|
T** args, |
|
FusedOptimizerTensorListMetadata<depth>& tl, |
|
@@ -70,22 +68,22 @@ __device__ bool init_args( |
|
args[i] = (T*)tl.addresses[i][tensor_loc]; |
|
args[i] += chunk_idx * chunk_size; |
|
|
|
- if (!is_aligned(args[i])) { |
|
+ if (!is_aligned<ilp>(args[i])) { |
|
all_aligned = false; |
|
} |
|
} |
|
return all_aligned; |
|
} |
|
|
|
-template <int depth, typename T> |
|
+template <int depth, int64_t ilp = kILP, typename T> |
|
__device__ void load_args( |
|
- T r_args[][kILP], |
|
+ T r_args[][ilp], |
|
T** args, |
|
const int64_t i_start, |
|
const int64_t chunk_size, |
|
const int64_t n) { |
|
#pragma unroll |
|
- for (int ii = 0; ii < kILP; ii++) { |
|
+ for (int ii = 0; ii < ilp; ii++) { |
|
const auto i = i_start + threadIdx.x + ii * blockDim.x; |
|
for (int r_index = 0; r_index < depth; r_index++) { |
|
r_args[r_index][ii] = 0; |
|
@@ -96,7 +94,7 @@ __device__ void load_args( |
|
} |
|
} |
|
|
|
-template <typename T> |
|
+template <int64_t ilp = kILP, typename T> |
|
__device__ void store_args( |
|
T* dst, |
|
T* src, |
|
@@ -104,93 +102,93 @@ __device__ void store_args( |
|
const int64_t chunk_size, |
|
const int64_t n) { |
|
#pragma unroll |
|
- for (int ii = 0; ii < kILP; ii++) { |
|
+ for (int ii = 0; ii < ilp; ii++) { |
|
const int64_t i = i_start + threadIdx.x + ii * blockDim.x; |
|
if (i < n && i < chunk_size) |
|
dst[i] = src[ii]; |
|
} |
|
} |
|
|
|
-template <int res_arg_index, typename Op, typename T, typename opmath_t> |
|
+template < |
|
+ int res_arg_index, |
|
+ int64_t ilp = kILP, |
|
+ typename Op, |
|
+ typename T, |
|
+ typename opmath_t> |
|
__device__ __forceinline__ void binary_op_scalar( |
|
- T r_args[][kILP], |
|
+ T r_args[][ilp], |
|
T** args, |
|
opmath_t scalar, |
|
const int64_t n, |
|
const int64_t chunk_size, |
|
const bool all_aligned, |
|
Op op) { |
|
- // to make things simple, we put aligned case in a different code path |
|
- if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { |
|
+ if (n % ilp == 0 && chunk_size % ilp == 0 && all_aligned) { |
|
for (int64_t i_start = threadIdx.x; |
|
- i_start * kILP < n && i_start * kILP < chunk_size; |
|
+ i_start * ilp < n && i_start * ilp < chunk_size; |
|
i_start += blockDim.x) { |
|
- // load |
|
- load_store(r_args[0], args[0], 0, i_start); |
|
+ load_store<ilp>(r_args[0], args[0], 0, i_start); |
|
#pragma unroll |
|
- for (int ii = 0; ii < kILP; ii++) { |
|
+ for (int ii = 0; ii < ilp; ii++) { |
|
r_args[0][ii] = static_cast<T>( |
|
op(static_cast<opmath_t>(r_args[0][ii]), |
|
static_cast<opmath_t>(scalar))); |
|
} |
|
- // store |
|
- load_store(args[res_arg_index], r_args[0], i_start, 0); |
|
+ load_store<ilp>(args[res_arg_index], r_args[0], i_start, 0); |
|
} |
|
} else { |
|
for (int64_t i_start = 0; i_start < n && i_start < chunk_size; |
|
- i_start += blockDim.x * kILP) { |
|
- // Regardless if depth is 1 (for inplace) or 2 (for out of place), r_args |
|
- // has depth 1 |
|
- load_args<1>(r_args, args, i_start, chunk_size, n); |
|
+ i_start += blockDim.x * ilp) { |
|
+ load_args<1, ilp>(r_args, args, i_start, chunk_size, n); |
|
#pragma unroll |
|
- for (int ii = 0; ii < kILP; ii++) { |
|
+ for (int ii = 0; ii < ilp; ii++) { |
|
r_args[0][ii] = static_cast<T>( |
|
op(static_cast<opmath_t>(r_args[0][ii]), |
|
static_cast<opmath_t>(scalar))); |
|
} |
|
- store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); |
|
+ store_args<ilp>(args[res_arg_index], r_args[0], i_start, chunk_size, n); |
|
} |
|
} |
|
} |
|
|
|
-template <int res_arg_index, typename Op, typename T, typename opmath_t> |
|
+template < |
|
+ int res_arg_index, |
|
+ int64_t ilp = kILP, |
|
+ typename Op, |
|
+ typename T, |
|
+ typename opmath_t> |
|
__device__ __forceinline__ void pointwise_op_scalar( |
|
- T r_args[][kILP], |
|
+ T r_args[][ilp], |
|
T** args, |
|
opmath_t scalar, |
|
const int64_t n, |
|
const int64_t chunk_size, |
|
const bool all_aligned, |
|
Op op) { |
|
- // to make things simple, we put aligned case in a different code path |
|
- if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { |
|
+ if (n % ilp == 0 && chunk_size % ilp == 0 && all_aligned) { |
|
for (int64_t i_start = threadIdx.x; |
|
- i_start * kILP < n && i_start * kILP < chunk_size; |
|
+ i_start * ilp < n && i_start * ilp < chunk_size; |
|
i_start += blockDim.x) { |
|
- // load |
|
- load_store(r_args[0], args[0], 0, i_start); |
|
- load_store(r_args[1], args[1], 0, i_start); |
|
- load_store(r_args[2], args[2], 0, i_start); |
|
+ load_store<ilp>(r_args[0], args[0], 0, i_start); |
|
+ load_store<ilp>(r_args[1], args[1], 0, i_start); |
|
+ load_store<ilp>(r_args[2], args[2], 0, i_start); |
|
#pragma unroll |
|
- for (int ii = 0; ii < kILP; ii++) { |
|
+ for (int ii = 0; ii < ilp; ii++) { |
|
r_args[0][ii] = pointwise_op_impl<opmath_t>( |
|
r_args[0][ii], r_args[1][ii], r_args[2][ii], scalar, op); |
|
} |
|
- // store |
|
- load_store(args[res_arg_index], r_args[0], i_start, 0); |
|
+ load_store<ilp>(args[res_arg_index], r_args[0], i_start, 0); |
|
} |
|
} else { |
|
for (int64_t i_start = 0; i_start < n && i_start < chunk_size; |
|
- i_start += blockDim.x * kILP) { |
|
- // Regardless if depth is 3 (for inplace) or 4 (for out of place), r_args |
|
- // has depth 3 |
|
- load_args<3>(r_args, args, i_start, chunk_size, n); |
|
+ i_start += blockDim.x * ilp) { |
|
+ load_args<3, ilp>(r_args, args, i_start, chunk_size, n); |
|
#pragma unroll |
|
- for (int ii = 0; ii < kILP; ii++) { |
|
+ for (int ii = 0; ii < ilp; ii++) { |
|
r_args[0][ii] = pointwise_op_impl<opmath_t>( |
|
r_args[0][ii], r_args[1][ii], r_args[2][ii], scalar, op); |
|
} |
|
- store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); |
|
+ store_args<ilp>(args[res_arg_index], r_args[0], i_start, chunk_size, n); |
|
} |
|
} |
|
} |
|
@@ -201,6 +199,7 @@ __device__ __forceinline__ void pointwise_op_scalar( |
|
template <typename T, int depth, int r_args_depth, int res_arg_index> |
|
struct BinaryOpScalarFunctor { |
|
using opmath_t = at::opmath_type<T>; |
|
+ static constexpr int64_t ilp = effective_ilp<T>(); |
|
template <typename Op> |
|
__device__ __forceinline__ void operator()( |
|
int64_t chunk_size, |
|
@@ -213,11 +212,11 @@ struct BinaryOpScalarFunctor { |
|
|
|
T* args[depth]; |
|
const bool all_aligned = |
|
- init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc); |
|
+ init_args<depth, ilp>(args, tl, chunk_idx, chunk_size, tensor_loc); |
|
n -= chunk_idx * chunk_size; |
|
- T r_args[r_args_depth][kILP]; |
|
+ T r_args[r_args_depth][ilp]; |
|
|
|
- binary_op_scalar<res_arg_index>( |
|
+ binary_op_scalar<res_arg_index, ilp>( |
|
r_args, args, scalar, n, chunk_size, all_aligned, op); |
|
} |
|
}; |
|
@@ -225,6 +224,7 @@ struct BinaryOpScalarFunctor { |
|
template <typename T, int depth, int r_args_depth, int res_arg_index> |
|
struct BinaryOpScalarListFunctor { |
|
using opmath_t = at::opmath_type<T>; |
|
+ static constexpr int64_t ilp = effective_ilp<T>(); |
|
template <typename Op> |
|
__device__ __forceinline__ void operator()( |
|
int64_t chunk_size, |
|
@@ -236,12 +236,12 @@ struct BinaryOpScalarListFunctor { |
|
|
|
T* args[depth]; |
|
const bool all_aligned = |
|
- init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc); |
|
+ init_args<depth, ilp>(args, tl, chunk_idx, chunk_size, tensor_loc); |
|
opmath_t scalar = tl.scalar_vals[tensor_loc]; |
|
n -= chunk_idx * chunk_size; |
|
- T r_args[r_args_depth][kILP]; |
|
+ T r_args[r_args_depth][ilp]; |
|
|
|
- binary_op_scalar<res_arg_index>( |
|
+ binary_op_scalar<res_arg_index, ilp>( |
|
r_args, args, scalar, n, chunk_size, all_aligned, op); |
|
} |
|
}; |
|
@@ -249,6 +249,7 @@ struct BinaryOpScalarListFunctor { |
|
template <typename T, int depth, int r_args_depth, int res_arg_index> |
|
struct BinaryOpListAlphaFunctor { |
|
using opmath_t = at::opmath_type<T>; |
|
+ static constexpr int64_t ilp = effective_ilp<T>(); |
|
template <typename Op> |
|
__device__ __forceinline__ void operator()( |
|
int64_t chunk_size, |
|
@@ -261,38 +262,35 @@ struct BinaryOpListAlphaFunctor { |
|
|
|
T* args[depth]; |
|
const bool all_aligned = |
|
- init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc); |
|
+ init_args<depth, ilp>(args, tl, chunk_idx, chunk_size, tensor_loc); |
|
n -= chunk_idx * chunk_size; |
|
- T r_args[r_args_depth][kILP]; |
|
+ T r_args[r_args_depth][ilp]; |
|
|
|
- // to make things simple, we put aligned case in a different code path |
|
- if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { |
|
+ if (n % ilp == 0 && chunk_size % ilp == 0 && all_aligned) { |
|
for (int64_t i_start = threadIdx.x; |
|
- i_start * kILP < n && i_start * kILP < chunk_size; |
|
+ i_start * ilp < n && i_start * ilp < chunk_size; |
|
i_start += blockDim.x) { |
|
- // load |
|
- load_store(r_args[0], args[0], 0, i_start); |
|
- load_store(r_args[1], args[1], 0, i_start); |
|
+ load_store<ilp>(r_args[0], args[0], 0, i_start); |
|
+ load_store<ilp>(r_args[1], args[1], 0, i_start); |
|
#pragma unroll |
|
- for (int ii = 0; ii < kILP; ii++) { |
|
+ for (int ii = 0; ii < ilp; ii++) { |
|
r_args[0][ii] = static_cast<T>( |
|
op(static_cast<opmath_t>(r_args[0][ii]), |
|
alpha * static_cast<opmath_t>(r_args[1][ii]))); |
|
} |
|
- // store |
|
- load_store(args[res_arg_index], r_args[0], i_start, 0); |
|
+ load_store<ilp>(args[res_arg_index], r_args[0], i_start, 0); |
|
} |
|
} else { |
|
for (int64_t i_start = 0; i_start < n && i_start < chunk_size; |
|
- i_start += blockDim.x * kILP) { |
|
- load_args<r_args_depth>(r_args, args, i_start, chunk_size, n); |
|
+ i_start += blockDim.x * ilp) { |
|
+ load_args<r_args_depth, ilp>(r_args, args, i_start, chunk_size, n); |
|
#pragma unroll |
|
- for (int ii = 0; ii < kILP; ii++) { |
|
+ for (int ii = 0; ii < ilp; ii++) { |
|
r_args[0][ii] = static_cast<T>( |
|
op(static_cast<opmath_t>(r_args[0][ii]), |
|
alpha * static_cast<opmath_t>(r_args[1][ii]))); |
|
} |
|
- store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); |
|
+ store_args<ilp>(args[res_arg_index], r_args[0], i_start, chunk_size, n); |
|
} |
|
} |
|
} |
|
@@ -301,6 +299,7 @@ struct BinaryOpListAlphaFunctor { |
|
template <typename T, int depth, int r_args_depth, int res_arg_index> |
|
struct BinaryOpScalarTensorFunctor { |
|
using opmath_t = at::opmath_type<T>; |
|
+ static constexpr int64_t ilp = effective_ilp<T>(); |
|
template <typename Op> |
|
__device__ __forceinline__ void operator()( |
|
int64_t chunk_size, |
|
@@ -314,39 +313,34 @@ struct BinaryOpScalarTensorFunctor { |
|
|
|
T* args[depth]; |
|
const bool all_aligned = |
|
- init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc); |
|
+ init_args<depth, ilp>(args, tl, chunk_idx, chunk_size, tensor_loc); |
|
n -= chunk_idx * chunk_size; |
|
- T r_args[r_args_depth][kILP]; |
|
+ T r_args[r_args_depth][ilp]; |
|
|
|
- // to make things simple, we put aligned case in a different code path |
|
- if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { |
|
+ if (n % ilp == 0 && chunk_size % ilp == 0 && all_aligned) { |
|
for (int64_t i_start = threadIdx.x; |
|
- i_start * kILP < n && i_start * kILP < chunk_size; |
|
+ i_start * ilp < n && i_start * ilp < chunk_size; |
|
i_start += blockDim.x) { |
|
- // load |
|
- load_store(r_args[0], args[0], 0, i_start); |
|
+ load_store<ilp>(r_args[0], args[0], 0, i_start); |
|
#pragma unroll |
|
- for (int ii = 0; ii < kILP; ii++) { |
|
+ for (int ii = 0; ii < ilp; ii++) { |
|
r_args[0][ii] = static_cast<T>(op( |
|
static_cast<opmath_t>(r_args[0][ii]), |
|
static_cast<opmath_t>(alpha) * static_cast<opmath_t>(*scalar))); |
|
} |
|
- // store |
|
- load_store(args[res_arg_index], r_args[0], i_start, 0); |
|
+ load_store<ilp>(args[res_arg_index], r_args[0], i_start, 0); |
|
} |
|
} else { |
|
for (int64_t i_start = 0; i_start < n && i_start < chunk_size; |
|
- i_start += blockDim.x * kILP) { |
|
- // Regardless if depth is 1 (for inplace) or 2 (for out of place), |
|
- // r_args has depth 1 |
|
- load_args<1>(r_args, args, i_start, chunk_size, n); |
|
+ i_start += blockDim.x * ilp) { |
|
+ load_args<1, ilp>(r_args, args, i_start, chunk_size, n); |
|
#pragma unroll |
|
- for (int ii = 0; ii < kILP; ii++) { |
|
+ for (int ii = 0; ii < ilp; ii++) { |
|
r_args[0][ii] = static_cast<T>(op( |
|
static_cast<opmath_t>(r_args[0][ii]), |
|
static_cast<opmath_t>(alpha) * static_cast<opmath_t>(*scalar))); |
|
} |
|
- store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); |
|
+ store_args<ilp>(args[res_arg_index], r_args[0], i_start, chunk_size, n); |
|
} |
|
} |
|
} |
|
@@ -358,6 +352,7 @@ struct BinaryOpScalarTensorFunctor { |
|
|
|
template <typename T, int depth, int r_args_depth, int res_arg_index> |
|
struct ZeroFunctor { |
|
+ static constexpr int64_t ilp = effective_ilp<T>(); |
|
__device__ __forceinline__ void operator()( |
|
int64_t chunk_size, |
|
TensorListMetadata<1>& tl) { |
|
@@ -367,30 +362,28 @@ struct ZeroFunctor { |
|
|
|
T* args[depth]; |
|
const auto all_aligned = |
|
- init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc); |
|
+ init_args<depth, ilp>(args, tl, chunk_idx, chunk_size, tensor_loc); |
|
n -= chunk_idx * chunk_size; |
|
- T r_args[r_args_depth][kILP]; |
|
+ T r_args[r_args_depth][ilp]; |
|
|
|
- // to make things simple, we put aligned case in a different code path |
|
- if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { |
|
+ if (n % ilp == 0 && chunk_size % ilp == 0 && all_aligned) { |
|
for (int64_t i_start = threadIdx.x; |
|
- i_start * kILP < n && i_start * kILP < chunk_size; |
|
+ i_start * ilp < n && i_start * ilp < chunk_size; |
|
i_start += blockDim.x) { |
|
#pragma unroll |
|
- for (int ii = 0; ii < kILP; ii++) { |
|
+ for (int ii = 0; ii < ilp; ii++) { |
|
r_args[0][ii] = 0; |
|
} |
|
- // store |
|
- load_store(args[0], r_args[0], i_start, 0); |
|
+ load_store<ilp>(args[0], r_args[0], i_start, 0); |
|
} |
|
} else { |
|
for (int64_t i_start = 0; i_start < n && i_start < chunk_size; |
|
- i_start += blockDim.x * kILP) { |
|
+ i_start += blockDim.x * ilp) { |
|
#pragma unroll |
|
- for (int ii = 0; ii < kILP; ii++) { |
|
+ for (int ii = 0; ii < ilp; ii++) { |
|
r_args[0][ii] = 0; |
|
} |
|
- store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); |
|
+ store_args<ilp>(args[res_arg_index], r_args[0], i_start, chunk_size, n); |
|
} |
|
} |
|
} |
|
@@ -399,6 +392,7 @@ struct ZeroFunctor { |
|
template <typename T, int depth, int r_args_depth, int res_arg_index> |
|
struct UnaryOpFunctor { |
|
using opmath_t = at::opmath_type<T>; |
|
+ static constexpr int64_t ilp = effective_ilp<T>(); |
|
template <typename Op> |
|
__device__ __forceinline__ void operator()( |
|
int64_t chunk_size, |
|
@@ -410,35 +404,32 @@ struct UnaryOpFunctor { |
|
|
|
T* args[depth]; |
|
bool all_aligned = |
|
- init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc); |
|
+ init_args<depth, ilp>(args, tl, chunk_idx, chunk_size, tensor_loc); |
|
n -= chunk_idx * chunk_size; |
|
- T r_args[r_args_depth][kILP]; |
|
+ T r_args[r_args_depth][ilp]; |
|
|
|
- // to make things simple, we put aligned case in a different code path |
|
- if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { |
|
+ if (n % ilp == 0 && chunk_size % ilp == 0 && all_aligned) { |
|
for (int64_t i_start = threadIdx.x; |
|
- i_start * kILP < n && i_start * kILP < chunk_size; |
|
+ i_start * ilp < n && i_start * ilp < chunk_size; |
|
i_start += blockDim.x) { |
|
- // load |
|
- load_store(r_args[0], args[0], 0, i_start); |
|
+ load_store<ilp>(r_args[0], args[0], 0, i_start); |
|
#pragma unroll |
|
- for (int ii = 0; ii < kILP; ii++) { |
|
+ for (int ii = 0; ii < ilp; ii++) { |
|
r_args[0][ii] = |
|
static_cast<T>(op(static_cast<opmath_t>(r_args[0][ii]))); |
|
} |
|
- // store |
|
- load_store(args[res_arg_index], r_args[0], i_start, 0); |
|
+ load_store<ilp>(args[res_arg_index], r_args[0], i_start, 0); |
|
} |
|
} else { |
|
for (int64_t i_start = 0; i_start < n && i_start < chunk_size; |
|
- i_start += blockDim.x * kILP) { |
|
- load_args<r_args_depth>(r_args, args, i_start, chunk_size, n); |
|
+ i_start += blockDim.x * ilp) { |
|
+ load_args<r_args_depth, ilp>(r_args, args, i_start, chunk_size, n); |
|
#pragma unroll |
|
- for (int ii = 0; ii < kILP; ii++) { |
|
+ for (int ii = 0; ii < ilp; ii++) { |
|
r_args[0][ii] = |
|
static_cast<T>(op(static_cast<opmath_t>(r_args[0][ii]))); |
|
} |
|
- store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); |
|
+ store_args<ilp>(args[res_arg_index], r_args[0], i_start, chunk_size, n); |
|
} |
|
} |
|
} |
|
@@ -451,6 +442,7 @@ struct UnaryOpFunctor { |
|
template <typename T, int depth, int r_args_depth, int res_arg_index> |
|
struct PointwiseOpScalarFunctor { |
|
using opmath_t = at::opmath_type<T>; |
|
+ static constexpr int64_t ilp = effective_ilp<T>(); |
|
template <typename Op> |
|
__device__ __forceinline__ void operator()( |
|
int64_t chunk_size, |
|
@@ -463,11 +455,11 @@ struct PointwiseOpScalarFunctor { |
|
|
|
T* args[depth]; |
|
const bool all_aligned = |
|
- init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc); |
|
+ init_args<depth, ilp>(args, tl, chunk_idx, chunk_size, tensor_loc); |
|
n -= chunk_idx * chunk_size; |
|
- T r_args[r_args_depth][kILP]; |
|
+ T r_args[r_args_depth][ilp]; |
|
|
|
- pointwise_op_scalar<res_arg_index>( |
|
+ pointwise_op_scalar<res_arg_index, ilp>( |
|
r_args, args, scalar, n, chunk_size, all_aligned, op); |
|
} |
|
}; |
|
@@ -475,6 +467,7 @@ struct PointwiseOpScalarFunctor { |
|
template <typename T, int depth, int r_args_depth, int res_arg_index> |
|
struct PointwiseOpScalarListFunctor { |
|
using opmath_t = at::opmath_type<T>; |
|
+ static constexpr int64_t ilp = effective_ilp<T>(); |
|
template <typename Op> |
|
__device__ __forceinline__ void operator()( |
|
int64_t chunk_size, |
|
@@ -486,12 +479,12 @@ struct PointwiseOpScalarListFunctor { |
|
|
|
T* args[depth]; |
|
const bool all_aligned = |
|
- init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc); |
|
+ init_args<depth, ilp>(args, tl, chunk_idx, chunk_size, tensor_loc); |
|
opmath_t scalar = tl.scalar_vals[tensor_loc]; |
|
n -= chunk_idx * chunk_size; |
|
- T r_args[r_args_depth][kILP]; |
|
+ T r_args[r_args_depth][ilp]; |
|
|
|
- pointwise_op_scalar<res_arg_index>( |
|
+ pointwise_op_scalar<res_arg_index, ilp>( |
|
r_args, args, scalar, n, chunk_size, all_aligned, op); |
|
} |
|
}; |
|
@@ -505,6 +498,7 @@ struct PointwiseOpScalarListFunctor { |
|
template <typename T, int depth, int r_args_depth, int res_arg_index> |
|
struct PointwiseOpScalar0dTensorFunctor { |
|
using opmath_t = at::opmath_type<T>; |
|
+ static constexpr int64_t ilp = effective_ilp<T>(); |
|
template <typename Op> |
|
__device__ __forceinline__ void operator()( |
|
int64_t chunk_size, |
|
@@ -517,57 +511,46 @@ struct PointwiseOpScalar0dTensorFunctor { |
|
|
|
T* args[depth]; |
|
const bool all_aligned = |
|
- init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc); |
|
+ init_args<depth, ilp>(args, tl, chunk_idx, chunk_size, tensor_loc); |
|
|
|
- // Load the 0D tensor1 value from device memory (just one element) |
|
opmath_t tensor1_val = static_cast<opmath_t>( |
|
*reinterpret_cast<const T*>(tl.addresses[1][tensor_loc])); |
|
|
|
n -= chunk_idx * chunk_size; |
|
- T r_args[r_args_depth][kILP]; |
|
+ T r_args[r_args_depth][ilp]; |
|
|
|
- // to make things simple, we put aligned case in a different code path |
|
- // For depth=4: args[0] = input, args[1] = tensor1 (0D), args[2] = tensor2, |
|
- // args[3] = output For depth=3: args[0] = input, args[1] = tensor1 (0D), |
|
- // args[2] = tensor2, output = args[0] |
|
- if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { |
|
+ if (n % ilp == 0 && chunk_size % ilp == 0 && all_aligned) { |
|
for (int64_t i_start = threadIdx.x; |
|
- i_start * kILP < n && i_start * kILP < chunk_size; |
|
+ i_start * ilp < n && i_start * ilp < chunk_size; |
|
i_start += blockDim.x) { |
|
- // load input and tensor2 only (tensor1 is already loaded as scalar) |
|
- load_store(r_args[0], args[0], 0, i_start); |
|
- load_store(r_args[1], args[2], 0, i_start); // tensor2 is at args[2] |
|
+ load_store<ilp>(r_args[0], args[0], 0, i_start); |
|
+ load_store<ilp>(r_args[1], args[2], 0, i_start); |
|
#pragma unroll |
|
- for (int ii = 0; ii < kILP; ii++) { |
|
- // input + alpha * op(tensor1_val, tensor2) |
|
+ for (int ii = 0; ii < ilp; ii++) { |
|
r_args[0][ii] = pointwise_op_impl<opmath_t>( |
|
r_args[0][ii], tensor1_val, r_args[1][ii], alpha, op); |
|
} |
|
- // store |
|
- load_store(args[res_arg_index], r_args[0], i_start, 0); |
|
+ load_store<ilp>(args[res_arg_index], r_args[0], i_start, 0); |
|
} |
|
} else { |
|
for (int64_t i_start = 0; i_start < n && i_start < chunk_size; |
|
- i_start += blockDim.x * kILP) { |
|
- // Load input (r_args[0]) and tensor2 (r_args[1]) |
|
- // We need to load from args[0] and args[2] (skipping args[1] which is |
|
- // 0D tensor) |
|
+ i_start += blockDim.x * ilp) { |
|
#pragma unroll |
|
- for (int ii = 0; ii < kILP; ii++) { |
|
+ for (int ii = 0; ii < ilp; ii++) { |
|
const auto i = i_start + threadIdx.x + ii * blockDim.x; |
|
r_args[0][ii] = 0; |
|
r_args[1][ii] = 0; |
|
if (i < n && i < chunk_size) { |
|
r_args[0][ii] = args[0][i]; |
|
- r_args[1][ii] = args[2][i]; // tensor2 is at args[2] |
|
+ r_args[1][ii] = args[2][i]; |
|
} |
|
} |
|
#pragma unroll |
|
- for (int ii = 0; ii < kILP; ii++) { |
|
+ for (int ii = 0; ii < ilp; ii++) { |
|
r_args[0][ii] = pointwise_op_impl<opmath_t>( |
|
r_args[0][ii], tensor1_val, r_args[1][ii], alpha, op); |
|
} |
|
- store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); |
|
+ store_args<ilp>(args[res_arg_index], r_args[0], i_start, chunk_size, n); |
|
} |
|
} |
|
} |
|
@@ -576,6 +559,7 @@ struct PointwiseOpScalar0dTensorFunctor { |
|
template <typename T, int depth> |
|
struct PointwiseOpListFunctor { |
|
using opmath_t = at::opmath_type<T>; |
|
+ static constexpr int64_t ilp = effective_ilp<T>(); |
|
template <typename Op> |
|
__device__ __forceinline__ void operator()( |
|
int64_t chunk_size, |
|
@@ -587,38 +571,35 @@ struct PointwiseOpListFunctor { |
|
|
|
T* args[depth]; |
|
const bool all_aligned = |
|
- init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc); |
|
+ init_args<depth, ilp>(args, tl, chunk_idx, chunk_size, tensor_loc); |
|
n -= chunk_idx * chunk_size; |
|
- T r_args[depth - 1][kILP]; |
|
+ T r_args[depth - 1][ilp]; |
|
|
|
- // to make things simple, we put aligned case in a different code path |
|
- if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { |
|
+ if (n % ilp == 0 && chunk_size % ilp == 0 && all_aligned) { |
|
for (int64_t i_start = threadIdx.x; |
|
- i_start * kILP < n && i_start * kILP < chunk_size; |
|
+ i_start * ilp < n && i_start * ilp < chunk_size; |
|
i_start += blockDim.x) { |
|
- // load |
|
- load_store(r_args[0], args[0], 0, i_start); |
|
- load_store(r_args[1], args[1], 0, i_start); |
|
+ load_store<ilp>(r_args[0], args[0], 0, i_start); |
|
+ load_store<ilp>(r_args[1], args[1], 0, i_start); |
|
#pragma unroll |
|
- for (int ii = 0; ii < kILP; ii++) { |
|
+ for (int ii = 0; ii < ilp; ii++) { |
|
r_args[0][ii] = static_cast<T>( |
|
op(static_cast<opmath_t>(r_args[0][ii]), |
|
static_cast<opmath_t>(r_args[1][ii]))); |
|
} |
|
- // store |
|
- load_store(args[2], r_args[0], i_start, 0); |
|
+ load_store<ilp>(args[2], r_args[0], i_start, 0); |
|
} |
|
} else { |
|
for (int64_t i_start = 0; i_start < n && i_start < chunk_size; |
|
- i_start += blockDim.x * kILP) { |
|
- load_args<depth - 1>(r_args, args, i_start, chunk_size, n); |
|
+ i_start += blockDim.x * ilp) { |
|
+ load_args<depth - 1, ilp>(r_args, args, i_start, chunk_size, n); |
|
#pragma unroll |
|
- for (int ii = 0; ii < kILP; ii++) { |
|
+ for (int ii = 0; ii < ilp; ii++) { |
|
r_args[0][ii] = static_cast<T>( |
|
op(static_cast<opmath_t>(r_args[0][ii]), |
|
static_cast<opmath_t>(r_args[1][ii]))); |
|
} |
|
- store_args(args[2], r_args[0], i_start, chunk_size, n); |
|
+ store_args<ilp>(args[2], r_args[0], i_start, chunk_size, n); |
|
} |
|
} |
|
} |
|
@@ -627,6 +608,7 @@ struct PointwiseOpListFunctor { |
|
template <typename T, int depth, int r_args_depth, int res_arg_index> |
|
struct TernaryOpListFunctor { |
|
using opmath_t = at::opmath_type<T>; |
|
+ static constexpr int64_t ilp = effective_ilp<T>(); |
|
template <typename Op> |
|
__device__ __forceinline__ void operator()( |
|
int64_t chunk_size, |
|
@@ -641,38 +623,38 @@ struct TernaryOpListFunctor { |
|
|
|
T* args[depth]; |
|
const bool all_aligned = |
|
- init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc); |
|
+ init_args<depth, ilp>(args, tl, chunk_idx, chunk_size, tensor_loc); |
|
n -= chunk_idx * chunk_size; |
|
- T r_args[r_args_depth][kILP]; |
|
+ T r_args[r_args_depth][ilp]; |
|
|
|
- if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { |
|
+ if (n % ilp == 0 && chunk_size % ilp == 0 && all_aligned) { |
|
for (int64_t i_start = threadIdx.x; |
|
- i_start * kILP < n && i_start * kILP < chunk_size; |
|
+ i_start * ilp < n && i_start * ilp < chunk_size; |
|
i_start += blockDim.x) { |
|
- load_store(r_args[0], args[0], 0, i_start); |
|
- load_store(r_args[1], args[1], 0, i_start); |
|
- load_store(r_args[2], args[2], 0, i_start); |
|
+ load_store<ilp>(r_args[0], args[0], 0, i_start); |
|
+ load_store<ilp>(r_args[1], args[1], 0, i_start); |
|
+ load_store<ilp>(r_args[2], args[2], 0, i_start); |
|
#pragma unroll |
|
- for (int ii = 0; ii < kILP; ii++) { |
|
+ for (int ii = 0; ii < ilp; ii++) { |
|
r_args[0][ii] = |
|
op(static_cast<opmath_t>(r_args[0][ii]), |
|
static_cast<opmath_t>(r_args[1][ii]), |
|
static_cast<opmath_t>(r_args[2][ii])); |
|
} |
|
- load_store(args[res_arg_index], r_args[0], i_start, 0); |
|
+ load_store<ilp>(args[res_arg_index], r_args[0], i_start, 0); |
|
} |
|
} else { |
|
for (int64_t i_start = 0; i_start < n && i_start < chunk_size; |
|
- i_start += blockDim.x * kILP) { |
|
- load_args<r_args_depth>(r_args, args, i_start, chunk_size, n); |
|
+ i_start += blockDim.x * ilp) { |
|
+ load_args<r_args_depth, ilp>(r_args, args, i_start, chunk_size, n); |
|
#pragma unroll |
|
- for (int ii = 0; ii < kILP; ii++) { |
|
+ for (int ii = 0; ii < ilp; ii++) { |
|
r_args[0][ii] = |
|
op(static_cast<opmath_t>(r_args[0][ii]), |
|
static_cast<opmath_t>(r_args[1][ii]), |
|
static_cast<opmath_t>(r_args[2][ii])); |
|
} |
|
- store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); |
|
+ store_args<ilp>(args[res_arg_index], r_args[0], i_start, chunk_size, n); |
|
} |
|
} |
|
} |
|
@@ -681,6 +663,7 @@ struct TernaryOpListFunctor { |
|
template <typename T, int depth, int r_args_depth, int res_arg_index> |
|
struct TernaryOpScalarFunctor { |
|
using opmath_t = at::opmath_type<T>; |
|
+ static constexpr int64_t ilp = effective_ilp<T>(); |
|
template <typename Op> |
|
__device__ __forceinline__ void operator()( |
|
int64_t chunk_size, |
|
@@ -696,40 +679,37 @@ struct TernaryOpScalarFunctor { |
|
|
|
T* args[depth]; |
|
const bool all_aligned = |
|
- init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc); |
|
+ init_args<depth, ilp>(args, tl, chunk_idx, chunk_size, tensor_loc); |
|
n -= chunk_idx * chunk_size; |
|
- T r_args[r_args_depth][kILP]; |
|
+ T r_args[r_args_depth][ilp]; |
|
|
|
- // to make things simple, we put aligned case in a different code path |
|
- if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { |
|
+ if (n % ilp == 0 && chunk_size % ilp == 0 && all_aligned) { |
|
for (int64_t i_start = threadIdx.x; |
|
- i_start * kILP < n && i_start * kILP < chunk_size; |
|
+ i_start * ilp < n && i_start * ilp < chunk_size; |
|
i_start += blockDim.x) { |
|
- // load |
|
- load_store(r_args[0], args[0], 0, i_start); |
|
- load_store(r_args[1], args[1], 0, i_start); |
|
+ load_store<ilp>(r_args[0], args[0], 0, i_start); |
|
+ load_store<ilp>(r_args[1], args[1], 0, i_start); |
|
#pragma unroll |
|
- for (int ii = 0; ii < kILP; ii++) { |
|
+ for (int ii = 0; ii < ilp; ii++) { |
|
r_args[0][ii] = |
|
op(static_cast<opmath_t>(r_args[0][ii]), |
|
static_cast<opmath_t>(r_args[1][ii]), |
|
alpha); |
|
} |
|
- // store |
|
- load_store(args[res_arg_index], r_args[0], i_start, 0); |
|
+ load_store<ilp>(args[res_arg_index], r_args[0], i_start, 0); |
|
} |
|
} else { |
|
for (int64_t i_start = 0; i_start < n && i_start < chunk_size; |
|
- i_start += blockDim.x * kILP) { |
|
- load_args<r_args_depth>(r_args, args, i_start, chunk_size, n); |
|
+ i_start += blockDim.x * ilp) { |
|
+ load_args<r_args_depth, ilp>(r_args, args, i_start, chunk_size, n); |
|
#pragma unroll |
|
- for (int ii = 0; ii < kILP; ii++) { |
|
+ for (int ii = 0; ii < ilp; ii++) { |
|
r_args[0][ii] = |
|
op(static_cast<opmath_t>(r_args[0][ii]), |
|
static_cast<opmath_t>(r_args[1][ii]), |
|
alpha); |
|
} |
|
- store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); |
|
+ store_args<ilp>(args[res_arg_index], r_args[0], i_start, chunk_size, n); |
|
} |
|
} |
|
} |
|
@@ -738,6 +718,7 @@ struct TernaryOpScalarFunctor { |
|
template <typename T, int depth, int r_args_depth, int res_arg_index> |
|
struct TernaryOpScalarListFunctor { |
|
using opmath_t = at::opmath_type<T>; |
|
+ static constexpr int64_t ilp = effective_ilp<T>(); |
|
template <typename Op> |
|
__device__ __forceinline__ void operator()( |
|
int64_t chunk_size, |
|
@@ -752,41 +733,38 @@ struct TernaryOpScalarListFunctor { |
|
|
|
T* args[depth]; |
|
const bool all_aligned = |
|
- init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc); |
|
+ init_args<depth, ilp>(args, tl, chunk_idx, chunk_size, tensor_loc); |
|
n -= chunk_idx * chunk_size; |
|
- T r_args[r_args_depth][kILP]; |
|
+ T r_args[r_args_depth][ilp]; |
|
const opmath_t scalar = tl.scalar_vals[tensor_loc]; |
|
|
|
- // to make things simple, we put aligned case in a different code path |
|
- if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { |
|
+ if (n % ilp == 0 && chunk_size % ilp == 0 && all_aligned) { |
|
for (int64_t i_start = threadIdx.x; |
|
- i_start * kILP < n && i_start * kILP < chunk_size; |
|
+ i_start * ilp < n && i_start * ilp < chunk_size; |
|
i_start += blockDim.x) { |
|
- // load |
|
- load_store(r_args[0], args[0], 0, i_start); |
|
- load_store(r_args[1], args[1], 0, i_start); |
|
+ load_store<ilp>(r_args[0], args[0], 0, i_start); |
|
+ load_store<ilp>(r_args[1], args[1], 0, i_start); |
|
#pragma unroll |
|
- for (int ii = 0; ii < kILP; ii++) { |
|
+ for (int ii = 0; ii < ilp; ii++) { |
|
r_args[0][ii] = |
|
op(static_cast<opmath_t>(r_args[0][ii]), |
|
static_cast<opmath_t>(r_args[1][ii]), |
|
scalar); |
|
} |
|
- // store |
|
- load_store(args[res_arg_index], r_args[0], i_start, 0); |
|
+ load_store<ilp>(args[res_arg_index], r_args[0], i_start, 0); |
|
} |
|
} else { |
|
for (int64_t i_start = 0; i_start < n && i_start < chunk_size; |
|
- i_start += blockDim.x * kILP) { |
|
- load_args<r_args_depth>(r_args, args, i_start, chunk_size, n); |
|
+ i_start += blockDim.x * ilp) { |
|
+ load_args<r_args_depth, ilp>(r_args, args, i_start, chunk_size, n); |
|
#pragma unroll |
|
- for (int ii = 0; ii < kILP; ii++) { |
|
+ for (int ii = 0; ii < ilp; ii++) { |
|
r_args[0][ii] = |
|
op(static_cast<opmath_t>(r_args[0][ii]), |
|
static_cast<opmath_t>(r_args[1][ii]), |
|
scalar); |
|
} |
|
- store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); |
|
+ store_args<ilp>(args[res_arg_index], r_args[0], i_start, chunk_size, n); |
|
} |
|
} |
|
} |
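
A quick register-footprint check on the functor changes above: each functor
stages its loads in T r_args[r_args_depth][ilp], so raising ilp from 4 to 8 for
2-byte types doubles that per-thread buffer. Below is a minimal constexpr
sketch (editorial illustration, not part of the patch; r_args_bytes is a
made-up helper):

    #include <cstddef>

    // Per-thread bytes held by the r_args staging buffer,
    // T r_args[r_args_depth][ilp], for a given element size.
    constexpr std::size_t r_args_bytes(
        std::size_t r_args_depth, std::size_t ilp, std::size_t elem_size) {
      return r_args_depth * ilp * elem_size;
    }

    // fp16 at the old kILP = 4 with r_args_depth = 3 (pointwise ops): 24 bytes.
    static_assert(r_args_bytes(3, 4, 2) == 24, "");
    // fp16 at ilp = 8: 48 bytes. This growth is why effective_ilp() below
    // caps the target at 8 rather than chasing wider vectors.
    static_assert(r_args_bytes(3, 8, 2) == 48, "");
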
|
diff --git a/aten/src/ATen/native/cuda/MultiTensorApply.cuh b/aten/src/ATen/native/cuda/MultiTensorApply.cuh |
|
index 2fe431f778b..f4d9b9043f7 100644 |
|
--- a/aten/src/ATen/native/cuda/MultiTensorApply.cuh |
|
+++ b/aten/src/ATen/native/cuda/MultiTensorApply.cuh |
|
@@ -14,6 +14,16 @@ static constexpr int64_t kILP = 4; |
|
static constexpr int64_t kChunkSize = 65536; |
|
static constexpr int64_t kBlockSize = 512; |
|
|
|
+// Target 128-bit (16-byte) vectorized loads, i.e. 16 / sizeof(T) elements per

+// thread. For types narrower than 4 bytes (e.g. fp16/bf16 at 2 bytes) this

+// yields ILP=8 instead of the default 4, moving twice the data per vectorized

+// access. The result is capped at 8 to limit register pressure at higher depths.
|
+template <typename T> |
|
+constexpr int64_t effective_ilp() { |
|
+ constexpr int64_t target = static_cast<int64_t>(16 / sizeof(T)); |
|
+ return (target < kILP) ? kILP : ((target > 8) ? 8 : target); |
|
+} |
|
+ |
|
// TODO(crcrpar): Add `n>5` for `low prec params & their higher prec copy` |
|
// TensorListMetadata has to be < 4KB - the limit for kernel launch argument |
|
static constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30}; |
|
@@ -23,18 +33,18 @@ static constexpr int depth_to_max_tensors_scalarlist_of_complex_double[2] = { |
|
72, |
|
60}; |
|
|
|
-template <typename T> |
|
+template <int64_t ilp = kILP, typename T> |
|
__device__ __forceinline__ bool is_aligned(T* p) { |
|
- return ((uint64_t)p) % (kILP * sizeof(T)) == 0; |
|
+ return ((uint64_t)p) % (ilp * sizeof(T)) == 0; |
|
} |
|
|
|
-template <typename T> |
|
+template <int64_t ilp = kILP, typename T> |
|
__device__ __forceinline__ void load_store( |
|
T* dst, |
|
T* src, |
|
int64_t dst_offset, |
|
int64_t src_offset) { |
|
- using LT = at::native::memory::aligned_vector<T, kILP>; |
|
+ using LT = at::native::memory::aligned_vector<T, ilp>; |
|
((LT*)dst)[dst_offset] = ((LT*)src)[src_offset]; |
|
} |
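
How effective_ilp() resolves for common element widths, and what the tighter
is_aligned<ilp>() check then demands: a self-contained sketch assuming the
kILP = 4 defined above (editorial illustration, not part of the patch):

    #include <cstdint>

    constexpr int64_t kILP = 4;

    template <typename T>
    constexpr int64_t effective_ilp() {
      constexpr int64_t target = static_cast<int64_t>(16 / sizeof(T));
      return (target < kILP) ? kILP : ((target > 8) ? 8 : target);
    }

    // 2-byte types (stand-ins for fp16/bf16) vectorize 8-wide: 8 * 2 B = 16 B.
    static_assert(effective_ilp<uint16_t>() == 8, "");
    // 4-byte types keep the legacy width: 4 * 4 B = 16 B.
    static_assert(effective_ilp<float>() == 4, "");
    // 8-byte types fall back to kILP (4 * 8 B = 32 B, unchanged by this patch).
    static_assert(effective_ilp<double>() == 4, "");
    // 1-byte types would target 16 but are capped at 8.
    static_assert(effective_ilp<uint8_t>() == 8, "");

One behavioral consequence worth noting: is_aligned<ilp> requires
ilp * sizeof(T)-byte alignment, so a half tensor whose data pointer is 8-byte
but not 16-byte aligned passed the check at kILP = 4 (8 % 8 == 0) yet fails it
at ilp = 8 (8 % 16 != 0) and now takes the element-wise fallback path instead.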
|
|