Improve Normalization.cuh (pytorch#83871)
Remove the unused SumOp and VarOp functors.
Replace the copy-and-pasted block reduction with a call to BlockReduce (adding SumReduceOp and 2D block indexing) and drop the duplicate warpSum.
Pull Request resolved: pytorch#83871
Approved by: https://github.com/ngimel
chengscott authored and pytorchmergebot committed Aug 24, 2022
1 parent 7b1a056 commit a741927
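In short, the per-plane reduce() in Normalization.cuh now hands the block-wide combine to the shared helper in block_reduce.cuh instead of a hand-rolled two-stage warpSum. The new call pattern, excerpted from the diff below, is:

  __shared__ scalar_t shared[C10_WARP_SIZE];
  SumReduceOp<scalar_t> reduce_op;
  sum = cuda_utils::BlockReduce<scalar_t, SumReduceOp<scalar_t>, cuda_utils::Block2D>(sum, reduce_op, 0, shared);

To let the 2D thread blocks used by these kernels reuse the previously 1D-only helpers, block_reduce.cuh gains Block1D/Block2D indexing policies as a new template parameter.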
Showing 2 changed files with 52 additions and 69 deletions.
aten/src/ATen/native/cuda/Normalization.cuh (24 additions, 59 deletions)
@@ -6,6 +6,7 @@
 #include <ATen/ceil_div.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/DeviceUtils.cuh>
+#include <ATen/native/cuda/block_reduce.cuh>
 #include <ATen/native/cuda/DeviceSqrt.cuh>
 #include <ATen/native/cuda/LaunchUtils.h>
 #include <c10/macros/Macros.h>
@@ -60,26 +61,10 @@ struct Float2 {
     v2 += a.v2;
     return *this;
   }
-};
-
-template <typename scalar_t, typename accscalar_t, typename PTA>
-struct SumOp {
-  __device__ SumOp(const PTA& t) : tensor(t) {}
-  __device__ __forceinline__ accscalar_t operator()(int batch, int plane, int n) {
-    return static_cast<accscalar_t>(tensor[batch][plane][n]);
-  }
-  const PTA& tensor;
-};
-
-template <typename scalar_t, typename accscalar_t, typename PTA>
-struct VarOp {
-  __device__ VarOp(accscalar_t m, const PTA& t) : mean(m), tensor(t) {}
-  __device__ __forceinline__ accscalar_t operator()(int batch, int plane, int n) {
-    accscalar_t val = tensor[batch][plane][n];
-    return (val - mean) * (val - mean);
+  __device__ friend Float2 operator+(Float2 a, const Float2& b) {
+    a += b;
+    return a;
   }
-  const accscalar_t mean;
-  const PTA& tensor;
 };
 
 template <typename scalar_t, typename accscalar_t, typename PTA>
@@ -96,21 +81,25 @@ struct GradOp {
   const PTA& grad_output;
 };
 
-// Sum across all threads within a warp
-template <typename T>
-static __device__ __forceinline__ T warpSum(T val) {
-  for (int i = 0; i < getMSB(C10_WARP_SIZE); ++i) {
-    val += WARP_SHFL_XOR(val, 1 << i, C10_WARP_SIZE);
-  }
-  return val;
-}
+template <typename acc_t>
+struct SumReduceOp {
+  __device__ __forceinline__ acc_t combine(acc_t a, acc_t b) const { return a + b; }
+
+  __device__ __forceinline__ acc_t warp_shfl_down(acc_t data, int offset) const {
+    return WARP_SHFL_DOWN(data, offset);
+  }
+};
 
 template <typename scalar_t, typename accscalar_t>
-static __device__ __forceinline__ Float2<scalar_t, accscalar_t> warpSum(Float2<scalar_t, accscalar_t> value) {
-  value.v1 = warpSum(value.v1);
-  value.v2 = warpSum(value.v2);
-  return value;
-}
+struct SumReduceOp<Float2<scalar_t, accscalar_t>> {
+  using acc_t = Float2<scalar_t, accscalar_t>;
+
+  __device__ __forceinline__ acc_t combine(acc_t a, acc_t b) const { return a + b; }
+
+  __device__ __forceinline__ acc_t warp_shfl_down(acc_t data, int offset) const {
+    return {WARP_SHFL_DOWN(data.v1, offset), WARP_SHFL_DOWN(data.v2, offset)};
+  }
+};
 
 // Sum across (batch, x/y/z) applying Op() pointwise
 // this works by first having each thread sum it's part
@@ -130,37 +119,13 @@ __device__ scalar_t reduce(Op op, PTA tensor, int plane) {
       sum += op(batch, plane, x);
     }
   }
-
-  // first warpSum to get one value per thread to
-  // one value per warp
-  sum = warpSum(sum);
-
-  // this writes each warps item into shared memory
-  // there are at most C10_WARP_SIZE items left because
-  // there are at most C10_WARP_SIZE**2 threads at the beginning
   __shared__ scalar_t shared[C10_WARP_SIZE];
-  __syncthreads();
-  int tid = threadIdx.x + threadIdx.y * blockDim.x;
-  if (tid % C10_WARP_SIZE == 0) {
-    shared[tid / C10_WARP_SIZE] = sum;
-  }
-  if (tid >= blockDim.x * blockDim.y / C10_WARP_SIZE && tid < C10_WARP_SIZE) {
-    // zero out the other entries in shared
-    shared[tid] = (scalar_t)0;
-  }
-  __syncthreads();
-  // now have a second warpSum to reduce the intermediate values
-  // from shared memory to a single number. The very first
-  // thread writes it to shared memory.
-
-  if (tid / C10_WARP_SIZE == 0) {
-    sum = warpSum(shared[tid]);
-    if (tid == 0) {
+  SumReduceOp<scalar_t> reduce_op;
+  sum = cuda_utils::BlockReduce<scalar_t, SumReduceOp<scalar_t>, cuda_utils::Block2D>(sum, reduce_op, 0, shared);
+  if (threadIdx.x == 0 && threadIdx.y == 0) {
     shared[0] = sum;
   }
-  }
   __syncthreads();
-
   // Everyone picks it up, should be broadcast into the whole grad_input
   return shared[0];
 }
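Note that cuda_utils::BlockReduce only asks the functor for combine() and warp_shfl_down(); SumReduceOp above satisfies that protocol, and the new friend operator+ on Float2 is what lets the Float2 specialization's combine() reduce both accumulators at once. As a purely illustrative sketch (not part of this commit), a max-reduction following the same protocol might look like:

// Hypothetical functor mirroring the shape of SumReduceOp above; WARP_SHFL_DOWN
// comes from ATen/cuda/DeviceUtils.cuh, which Normalization.cuh already includes.
template <typename acc_t>
struct MaxReduceOp {
  __device__ __forceinline__ acc_t combine(acc_t a, acc_t b) const {
    return a > b ? a : b;
  }

  __device__ __forceinline__ acc_t warp_shfl_down(acc_t data, int offset) const {
    return WARP_SHFL_DOWN(data, offset);
  }
};

When passed to BlockReduce, the identity_element argument would then be the lowest representable value rather than 0.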
aten/src/ATen/native/cuda/block_reduce.cuh (28 additions, 10 deletions)
@@ -29,24 +29,42 @@ __inline__ __device__ T WarpReduceSum(T val) {
   return val;
 }
 
+struct Block1D {
+  static __forceinline__ __device__ int Tid() { return threadIdx.x; }
+
+  static __forceinline__ __device__ int Warps() {
+    return blockDim.x / C10_WARP_SIZE;
+  }
+};
+
+struct Block2D {
+  static __forceinline__ __device__ int Tid() {
+    return threadIdx.x + threadIdx.y * blockDim.x;
+  }
+
+  static __forceinline__ __device__ int Warps() {
+    return blockDim.x * blockDim.y / C10_WARP_SIZE;
+  }
+};
+
 // Sums `val` across all threads in a block.
 //
 // Assumptions:
-//   - Thread blocks are an 1D set of threads (indexed with `threadIdx.x` only)
 //   - The size of each block should be a multiple of `C10_WARP_SIZE`
 //   - `shared` should be a pointer to shared memory with size of, at least,
 //     `sizeof(T) * number_of_warps`
-template <typename T>
+template <typename T, typename B = Block1D>
 __inline__ __device__ T BlockReduceSum(T val, T* shared) {
-  const int lid = threadIdx.x % C10_WARP_SIZE;
-  const int wid = threadIdx.x / C10_WARP_SIZE;
+  const int tid = B::Tid();
+  const int lid = tid % C10_WARP_SIZE;
+  const int wid = tid / C10_WARP_SIZE;
   val = WarpReduceSum(val);
   __syncthreads();
   if (lid == 0) {
     shared[wid] = val;
   }
   __syncthreads();
-  val = (threadIdx.x < blockDim.x / C10_WARP_SIZE) ? shared[lid] : T(0);
+  val = (tid < B::Warps()) ? shared[lid] : T(0);
   if (wid == 0) {
     val = WarpReduceSum(val);
   }
@@ -62,19 +80,19 @@ __inline__ __device__ T WarpReduce(T val, const ReduceOp& op) {
   return val;
 }
 
-template <typename T, class ReduceOp>
+template <typename T, class ReduceOp, typename B = Block1D>
 __inline__ __device__ T
 BlockReduce(T val, const ReduceOp& op, const T& identity_element, T* shared) {
-  const int lid = threadIdx.x % C10_WARP_SIZE;
-  const int wid = threadIdx.x / C10_WARP_SIZE;
+  const int tid = B::Tid();
+  const int lid = tid % C10_WARP_SIZE;
+  const int wid = tid / C10_WARP_SIZE;
   val = WarpReduce(val, op);
   __syncthreads();
   if (lid == 0) {
     shared[wid] = val;
   }
   __syncthreads();
-  val = (threadIdx.x < blockDim.x / C10_WARP_SIZE) ? shared[lid]
-                                                   : identity_element;
+  val = (tid < B::Warps()) ? shared[lid] : identity_element;
   if (wid == 0) {
     val = WarpReduce(val, op);
   }
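For readers wanting to use the generalized helper outside Normalization.cuh, a hypothetical, self-contained kernel might look like the sketch below. It is not part of this commit: the name block_sum_2d and its buffers are made up, and it assumes the helpers live in at::native::cuda_utils as in the existing header.

#include <cstdint>
#include <ATen/native/cuda/block_reduce.cuh>
#include <c10/macros/Macros.h>

// Sum `n` values per block with a 2D thread block via the new Block2D policy.
template <typename T>
__global__ void block_sum_2d(const T* in, T* out, int64_t n) {
  namespace cu = at::native::cuda_utils;
  __shared__ T shared[C10_WARP_SIZE];  // one slot per warp, as BlockReduceSum expects
  const int tid = threadIdx.x + threadIdx.y * blockDim.x;
  const int nthreads = blockDim.x * blockDim.y;  // should be a multiple of C10_WARP_SIZE
  T val = T(0);
  for (int64_t i = tid; i < n; i += nthreads) {
    val += in[i];
  }
  val = cu::BlockReduceSum<T, cu::Block2D>(val, shared);
  if (tid == 0) {
    out[blockIdx.x] = val;  // only thread (0, 0) is guaranteed to hold the block-wide sum
  }
}

Passing Block1D (the default) keeps the pre-existing 1D behavior, so existing callers of BlockReduceSum and BlockReduce are unaffected.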
