diff --git a/paddle/fluid/operators/bilateral_slice_op.cu b/paddle/fluid/operators/bilateral_slice_op.cu
index 3c64ed1acc847d8f60ad39bf3437b22ad8f2bb4a..e7bf6d212dcf1730c32ed794576d569dff96f215 100644
--- a/paddle/fluid/operators/bilateral_slice_op.cu
+++ b/paddle/fluid/operators/bilateral_slice_op.cu
@@ -12,8 +12,8 @@
 #include
 #include
 #include "paddle/fluid/operators/bilateral_slice_op.h"
-#include "paddle/fluid/platform/cuda_primitives.h"
-#include "paddle/fluid/platform/gpu_launch_config.h"
+#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
+#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
 
 namespace paddle {
 namespace operators {
@@ -472,8 +472,8 @@ class BilateralSliceGradOpCUDAKernel : public framework::OpKernel {
     grid_sizes.gw = gw;
     grid_sizes.input_chans = input_chans;
 
-    platform::GpuLaunchConfig config = platform::GetGpuLaunchConfig1D(
-        ctx.cuda_device_context(), grid_count, 512);
+    platform::GpuLaunchConfig config =
+        platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), grid_count);
     BilateralSliceCudaGridGradKernel<
         T><< {
         grid_grad_data, output_grad_data, guide_data, input_data, grid_sizes,
         has_offset, grid_count, output_chans);
 
-    config = platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(),
-                                            guide_count, 512);
+    config =
+        platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), guide_count);
     BilateralSliceCudaGuideGradKernel<
         T><< {
         guide_grad_data, output_grad_data, grid_data, guide_data, input_data,
         grid_sizes, has_offset, guide_count, output_chans);
 
-    config = platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(),
-                                            input_count, 512);
+    config =
+        platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), input_count);
     BilateralSliceCudaInputGradKernel<
         T><<(x_data) == 4) ? 4 : 1;
-  int block_size = pten::funcs::GetThreadsConfig(dev_ctx, x_numel, vec_size);
-  int grid_size =
-      ((x_numel + vec_size - 1) / vec_size + block_size - 1) / block_size;
-
+  auto gpu_config = GetGpuLaunchConfig1D(dev_ctx, x_numel, vec_size);
   auto offset =
-      ((x_numel - 1) / (grid_size * block_size * vec_size) + 1) * vec_size;
+      ((x_numel - 1) / (gpu_config.GetThreadNum() * vec_size) + 1) * vec_size;
   GetSeedDataAndIncrement(dev_ctx, seed, is_fix_seed, seed_val, offset,
                           &seed_data, &increment);
@@ -206,23 +203,25 @@ void DropoutFwGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx,
 #ifdef __HIPCC__
   if (vec_size == 4 && size % 4 == 0) {
     hipLaunchKernelGGL(
-        HIP_KERNEL_NAME(VectorizedRandomGenerator), grid_size,
-        block_size, 0, stream, size, seed_data, dropout_prob, x_data,
-        mask_data, y_data, upscale_in_train, increment);
+        HIP_KERNEL_NAME(VectorizedRandomGenerator),
+        gpu_config.GetGridSize(), gpu_config.GetBlockSize(), 0, stream, size,
+        seed_data, dropout_prob, x_data, mask_data, y_data, upscale_in_train,
+        increment);
   } else {
     hipLaunchKernelGGL(HIP_KERNEL_NAME(RandomGenerator),
-                       grid_size, block_size, 0, stream, size, seed_data,
-                       dropout_prob, x_data, mask_data, y_data,
-                       upscale_in_train, increment);
+                       gpu_config.GetGridSize(), gpu_config.GetBlockSize(), 0,
+                       stream, size, seed_data, dropout_prob, x_data,
+                       mask_data, y_data, upscale_in_train, increment);
   }
 #else
   if (vec_size == 4 && size % 4 == 0) {
-    VectorizedRandomGenerator<<>>(
+    VectorizedRandomGenerator<<<
+        gpu_config.block_per_grid, gpu_config.thread_per_block, 0, stream>>>(
         size, seed_data, dropout_prob, x_data, mask_data, y_data,
         upscale_in_train, increment);
   } else {
-    RandomGenerator<<>>(
+    RandomGenerator<<>>(
         size, seed_data, dropout_prob, x_data, mask_data, y_data,
         upscale_in_train, increment);
   }
@@ -265,7 +264,7 @@ void DropoutGradGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx,
     auto factor = static_cast(1.0f / (1.0f - dropout_prob));
     auto stream = dev_ctx.stream();
     platform::GpuLaunchConfig config =
-        platform::GetGpuLaunchConfig1D(dev_ctx, size);
+        platform::GetGpuLaunchConfig1D(dev_ctx, size, vec_size);
     DropoutGradCUDAKernel<
         T, uint8_t, 4><<>>(
diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.cu b/paddle/fluid/operators/elementwise/elementwise_add_op.cu
index b5c19a3edb81869becc516b3c223402e4fe775ea..779779b44da8d1df275b057bbb9d37828c6904ed 100644
--- a/paddle/fluid/operators/elementwise/elementwise_add_op.cu
+++ b/paddle/fluid/operators/elementwise/elementwise_add_op.cu
@@ -128,10 +128,10 @@ elementwise_add_grad(const framework::ExecutionContext& ctx,
   } else if (dx_data != dout_data && dy_data != dout_data) {
     auto size = x->numel();
     int vec_size = max(static_cast(sizeof(float4) / sizeof(T)), 1);
-    dim3 block_size = dim3(ELEMENTWISE_BLOCK_SIZE, 1);
+    dim3 block_size = dim3(PREDEFINED_BLOCK_SIZE, 1);
     dim3 grid_size =
-        dim3(((size + vec_size - 1) / vec_size + ELEMENTWISE_BLOCK_SIZE - 1) /
-                 ELEMENTWISE_BLOCK_SIZE,
+        dim3(((size + vec_size - 1) / vec_size + PREDEFINED_BLOCK_SIZE - 1) /
+                 PREDEFINED_BLOCK_SIZE,
              1);
     SimpleElemwiseAddGradCUDAKernel<
         T><<mutable_data(ctx.GetPlace());
     if (dy->dims() == dout->dims()) {
       if (dy_data != dout_data) {
-        dim3 block_size = dim3(ELEMENTWISE_BLOCK_SIZE, 1);
+        dim3 block_size = dim3(PREDEFINED_BLOCK_SIZE, 1);
         auto size = dy->numel();
-        dim3 grid_size = dim3(
-            (size + ELEMENTWISE_BLOCK_SIZE - 1) / ELEMENTWISE_BLOCK_SIZE, 1);
+        dim3 grid_size =
+            dim3((size + PREDEFINED_BLOCK_SIZE - 1) / PREDEFINED_BLOCK_SIZE, 1);
         SimpleElemwiseSubGradCUDAKernel<<<
             grid_size, block_size, 0,
             ctx.template device_context().stream()>>>(
@@ -100,10 +100,10 @@ elementwise_sub_grad(const framework::ExecutionContext& ctx,
                      const framework::Tensor* out,
                      const framework::Tensor* dout, framework::Tensor* dx,
                      framework::Tensor* dy) {
-  dim3 block_size = dim3(ELEMENTWISE_BLOCK_SIZE, 1);
+  dim3 block_size = dim3(PREDEFINED_BLOCK_SIZE, 1);
   auto size = x->numel();
   dim3 grid_size =
-      dim3((size + ELEMENTWISE_BLOCK_SIZE - 1) / ELEMENTWISE_BLOCK_SIZE, 1);
+      dim3((size + PREDEFINED_BLOCK_SIZE - 1) / PREDEFINED_BLOCK_SIZE, 1);
   SimpleElemwiseSubGradCUDAKernel<
       T><<().stream()>>>(
diff --git a/paddle/fluid/operators/fused/fused_fc_elementwise_layernorm_op.cu b/paddle/fluid/operators/fused/fused_fc_elementwise_layernorm_op.cu
index dc068e02be4ecf457b9156f392aa945a07a6a74b..ebda9bbaa8b81b8b147ccc21ddbb047d1e6df1f7 100644
--- a/paddle/fluid/operators/fused/fused_fc_elementwise_layernorm_op.cu
+++ b/paddle/fluid/operators/fused/fused_fc_elementwise_layernorm_op.cu
@@ -22,7 +22,8 @@ namespace cub = hipcub;
 
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/math/blas.h"
-#include "paddle/fluid/platform/cuda_device_function.h"
+#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
+#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
 
 namespace paddle {
 namespace operators {
diff --git a/paddle/fluid/operators/index_sample_op.cu b/paddle/fluid/operators/index_sample_op.cu
index 46dd91fed6cbc17487e7e49c14003953fc8772c2..4260d0516e3cccefaf0cff5ea4b4441af96ef146 100644
--- a/paddle/fluid/operators/index_sample_op.cu
+++ b/paddle/fluid/operators/index_sample_op.cu
@@ -15,8 +15,8 @@
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/index_sample_op.h"
#include "paddle/fluid/operators/math/math_function.h" -#include "paddle/fluid/platform/cuda_device_function.h" -#include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/math/beam_search.cu b/paddle/fluid/operators/math/beam_search.cu index ed3ead47d171efb4128a294c7d7a24324c7187b7..cec688262604a10cdce04d9cca324f324196c652 100644 --- a/paddle/fluid/operators/math/beam_search.cu +++ b/paddle/fluid/operators/math/beam_search.cu @@ -13,7 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/math/beam_search.h" -#include "paddle/fluid/platform/cuda_device_function.h" +#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" +#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/math/pooling.cu b/paddle/fluid/operators/math/pooling.cu index 84a970a9a26067eaf143188370c1cfe14c7fc72e..9d96345eb1f6dca6fc5eb6cf5847baaf1a9019da 100644 --- a/paddle/fluid/operators/math/pooling.cu +++ b/paddle/fluid/operators/math/pooling.cu @@ -16,16 +16,9 @@ limitations under the License. */ #include #include "paddle/fluid/operators/math/pooling.h" -#include "paddle/fluid/platform/cuda_device_function.h" -#include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" #include "paddle/fluid/platform/fast_divmod.h" -#include "paddle/fluid/platform/gpu_launch_config.h" - -#ifdef __HIPCC__ -#define POOLING_BLOCK_SIZE 256 -#else -#define POOLING_BLOCK_SIZE 512 -#endif namespace paddle { namespace operators { @@ -97,22 +90,6 @@ __device__ void OffsetPreparationFor4Dimension( } } -int GetThreadsPerBlock(const platform::CUDADeviceContext& ctx, - int threads_per_block, int64_t numel) { - int sm_count = ctx.GetSMCount(); - if (numel / (sm_count << 1) < threads_per_block) { - // Round up threads number into an exponential multiple of 2, while number - // of acitve blocks is about twice of SM, to acquire better performance. - threads_per_block = platform::RoundToPowerOfTwo(numel / (sm_count << 1)); - } else if (numel / (sm_count << 2) < threads_per_block) { - // Round up threads number into an exponential multiple of 2, while number - // of acitve blocks is about 4 times of SM, to acquire better performance. - threads_per_block = platform::RoundToPowerOfTwo(numel / (sm_count << 2)); - } - // Number of threads per block shall be larger than 64. 
-  return std::max(64, threads_per_block);
-}
-
 template
 __global__ void KernelPool2D(
     const int nthreads, const T* input_data, const int channels,
@@ -491,14 +468,13 @@ class Pool2dGradFunctor {
     T* input_grad_data = input_grad->mutable_data(context.GetPlace());
 
     int nthreads = batch_size * input_channels * input_height * input_width;
-    int blocks = GetThreadsPerBlock(context, POOLING_BLOCK_SIZE, nthreads);
-    int grids = (nthreads + blocks - 1) / blocks;
-
     auto pool_divmods = FastDivModForPoolingWithMoreStaff(
         input_channels, input_width, input_height, ksize_width, ksize_height,
         stride_width, stride_height);
 
-    KernelPool2DGrad<<>>(
+    auto config = GetGpuLaunchConfig1D(context, nthreads);
+    KernelPool2DGrad<<<
+        config.block_per_grid, config.thread_per_block, 0, context.stream()>>>(
         nthreads, input_data, output_data, output_grad_data, output_width,
         output_height, input_width, input_height, ksize_width, ksize_height,
         stride_width, stride_height, padding_width, padding_height,
@@ -541,14 +517,13 @@ class Pool2dGradFunctor {
     T* input_grad_data = input_grad->mutable_data(context.GetPlace());
 
     int nthreads = batch_size * input_channels * input_height * input_width;
-    int blocks = GetThreadsPerBlock(context, POOLING_BLOCK_SIZE, nthreads);
-    int grids = (nthreads + blocks - 1) / blocks;
-
     auto pool_divmods = FastDivModForPoolingWithMoreStaff(
         input_channels, input_width, input_height, ksize_width, ksize_height,
         stride_width, stride_height);
 
-    KernelPool2DGrad<<>>(
+    auto config = GetGpuLaunchConfig1D(context, nthreads);
+    KernelPool2DGrad<<<
+        config.block_per_grid, config.thread_per_block, 0, context.stream()>>>(
         nthreads, input_data, output_data, output_grad_data, output_width,
         output_height, input_width, input_height, ksize_width, ksize_height,
         stride_width, stride_height, padding_width, padding_height,
diff --git a/paddle/fluid/platform/device/gpu/cuda/cuda_device_function.h b/paddle/fluid/platform/device/gpu/cuda/cuda_device_function.h
index 352143302388a9f8169a40a14ccea9bae647cfc6..cd78a89088cc612c3fb43e489cfb7ef2e07cfcf3 100644
--- a/paddle/fluid/platform/device/gpu/cuda/cuda_device_function.h
+++ b/paddle/fluid/platform/device/gpu/cuda/cuda_device_function.h
@@ -22,41 +22,9 @@ limitations under the License. */
 
 namespace paddle {
 namespace platform {
 
-#ifdef PADDLE_WITH_HIP
-#define CREATE_SHFL_MASK(mask, predicate) mask = __ballot((predicate))
-#else
 #define FULL_WARP_MASK 0xFFFFFFFF
 #define CREATE_SHFL_MASK(mask, predicate) \
   mask = __ballot_sync(FULL_WARP_MASK, (predicate))
-#endif
-
-inline static int RoundToPowerOfTwo(int dim) {
-#ifdef PADDLE_WITH_CUDA
-  if (dim > 512) {
-    return 1024;
-  } else if (dim > 256) {
-    return 512;
-  } else if (dim > 128) {
-    return 256;
-  } else if (dim > 64) {
-    return 128;
-  } else if (dim > 32) {
-    return 64;
-  } else {
-    return 32;
-  }
-#else  // HIP results in error or nan if > 256
-  if (dim > 128) {
-    return 256;
-  } else if (dim > 64) {
-    return 128;
-  } else if (dim > 32) {
-    return 64;
-  } else {
-    return 32;
-  }
-#endif
-}
 
 #define CUDA_LAUNCH_KERNEL_BASE(dim, ...)  \
   case (dim): {                            \
@@ -76,76 +44,20 @@ template
 __forceinline__ __device__ T CudaShuffleDownSync(unsigned mask, T val,
                                                  int delta,
                                                  int width = warpSize) {
-#if defined(PADDLE_WITH_HIP)
-  return __shfl_down(val, delta, width);
-#else
   return __shfl_down_sync(mask, val, static_cast(delta), width);
-#endif
 }
 
 template
 __forceinline__ __device__ T CudaShuffleXorSync(unsigned mask, T val,
                                                 int width = warpSize) {
-#if defined(PADDLE_WITH_HIP)
-  return __shfl_xor(val, width);
-#else
   return __shfl_xor_sync(mask, val, width);
-#endif
-}
-
-#if defined(PADDLE_WITH_HIP)
-template <>
-__forceinline__ __device__ float16 CudaShuffleDownSync(unsigned mask,
-                                                       float16 val, int delta,
-                                                       int width) {
-  return float16(__shfl_down(static_cast(val),
-                             static_cast(delta), width));
 }
 
-template <>
-__forceinline__ __device__ paddle::platform::complex CudaShuffleDownSync(
-    unsigned mask, paddle::platform::complex val, int delta, int width) {
-  float real = __shfl_down(val.real, delta, width);
-  float imag = __shfl_down(val.imag, delta, width);
-  return paddle::platform::complex(real, imag);
-}
-
-template <>
-__forceinline__ __device__ paddle::platform::complex
-CudaShuffleDownSync(unsigned mask, paddle::platform::complex val,
-                    int delta, int width) {
-  double real = __shfl_down(val.real, delta, width);
-  double imag = __shfl_down(val.imag, delta, width);
-  return paddle::platform::complex(real, imag);
-}
-
-template <>
-__forceinline__ __device__ float16 CudaShuffleXorSync(unsigned mask,
-                                                      float16 val, int width) {
-  return float16(__shfl_xor(static_cast(val), width));
-}
-
-template <>
-__forceinline__ __device__ paddle::platform::complex CudaShuffleXorSync(
-    unsigned mask, paddle::platform::complex val, int width) {
-  float real = __shfl_xor(val.real, width);
-  float imag = __shfl_xor(val.imag, width);
-  return paddle::platform::complex(real, imag);
-}
-
-template <>
-__forceinline__ __device__ paddle::platform::complex CudaShuffleXorSync(
-    unsigned mask, paddle::platform::complex val, int width) {
-  double real = __shfl_xor(val.real, width);
-  double imag = __shfl_xor(val.imag, width);
-  return paddle::platform::complex(real, imag);
-}
-#else
 template <>
 __forceinline__ __device__ float16 CudaShuffleDownSync(unsigned mask,
                                                        float16 val, int delta,
                                                        int width) {
-  return float16(__shfl_down_sync(mask, static_cast(val),
+  return float16(__shfl_down_sync(mask, val.to_half(),
                                   static_cast(delta), width));
 }
 
@@ -175,7 +87,7 @@ CudaShuffleDownSync(unsigned mask, paddle::platform::complex val,
 template <>
 __forceinline__ __device__ float16 CudaShuffleXorSync(unsigned mask,
                                                       float16 val, int width) {
-  return float16(__shfl_xor_sync(mask, static_cast(val), width));
+  return float16(__shfl_xor_sync(mask, val.to_half(), width));
 }
 
 template <>
@@ -197,16 +109,11 @@ __forceinline__ __device__ paddle::platform::complex CudaShuffleXorSync(
       __shfl_xor_sync(mask, static_cast(val.imag), width));
   return paddle::platform::complex(real, imag);
 }
-#endif
 
 template
 __forceinline__ __device__ T CudaShuffleSync(unsigned mask, T val, int src_line,
                                              int width = 32) {
-#if defined(PADDLE_WITH_HIP)
-  return __shfl(val, src_line, width);
-#else
   return __shfl_sync(mask, val, src_line, width);
-#endif
 }
 
 template
@@ -216,17 +123,13 @@ HOSTDEVICE T Infinity() {
 
 template
 __device__ T reduceSum(T val, int tid, int len) {
-// NOTE(zcd): The warp size should be taken from the
-// parameters of the GPU but not specified as 32 simply.
-// To make the reduceSum more efficiently,
-// I use Warp-Level Parallelism and assume the Warp size
-// is 32 which may be different for different GPU,
-// but most card's warp size is 32.
-#ifdef PADDLE_WITH_HIP
-  const int warpSize = 64;
-#else
+  // NOTE(zcd): The warp size should be taken from the
+  // parameters of the GPU but not specified as 32 simply.
+  // To make the reduceSum more efficiently,
+  // I use Warp-Level Parallelism and assume the Warp size
+  // is 32 which may be different for different GPU,
+  // but most card's warp size is 32.
   const int warpSize = 32;
-#endif
   __shared__ T shm[warpSize];
   unsigned mask = 0u;
   CREATE_SHFL_MASK(mask, tid < len);
diff --git a/paddle/fluid/platform/device/gpu/gpu_launch_config.h b/paddle/fluid/platform/device/gpu/gpu_launch_config.h
new file mode 100644
index 0000000000000000000000000000000000000000..883767348f06a99c32664ca2575880737b7418b5
--- /dev/null
+++ b/paddle/fluid/platform/device/gpu/gpu_launch_config.h
@@ -0,0 +1,173 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Used to compute GPU launch parameter configs.
+
+#pragma once
+
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+
+#ifdef PADDLE_WITH_CUDA
+#include
+#else
+#include
+#endif
+
+#include
+#include
+#include
+#include
+#include "paddle/fluid/platform/device_context.h"
+
+#ifdef __HIPCC__
+// HIP results in error or nan if > 256
+#define PREDEFINED_BLOCK_SIZE 256
+#else
+/* CUDA performs better when thread_per_block
+   num is in the range [64, 512] */
+#define PREDEFINED_BLOCK_SIZE 512
+#endif
+
+namespace paddle {
+namespace platform {
+
+inline int DivUp(int a, int b) { return (a + b - 1) / b; }
+
+/* https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2
+   for rounding an integer value up to the next highest power of 2. */
+static inline int RoundToPowerOfTwo(int n) {
+  n--;
+  n |= (n >> 1);
+  n |= (n >> 2);
+  n |= (n >> 4);
+  n |= (n >> 8);
+  n |= (n >> 16);
+#ifdef __HIPCC__
+  return std::min(256, std::max(32, (n + 1)));
+#else
+  return std::min(1024, std::max(32, (n + 1)));
+#endif
+}
+
+#ifdef WITH_NV_JETSON
+// The number of threads cannot be assigned 1024 in some cases when the device
+// is nano or tx2.
+inline void ChangeThreadNum(const platform::CUDADeviceContext& context,
+                            int* num_thread, int alternative_num_thread = 512) {
+  if (context.GetComputeCapability() == 53 ||
+      context.GetComputeCapability() == 62) {
+    *num_thread = alternative_num_thread;
+  }
+}
+#endif
+
+struct GpuLaunchConfig {
+ public:
+  GpuLaunchConfig() {}
+
+  size_t GetThreadNum() const { return GetBlockSize() * GetGridSize(); }
+
+  size_t GetGridSize() const {
+    return block_per_grid.x * block_per_grid.y * block_per_grid.z;
+  }
+
+  size_t GetBlockSize() const {
+    return thread_per_block.x * thread_per_block.y * thread_per_block.z;
+  }
+
+  int compute_capability = 0;
+  dim3 thread_per_block = dim3(1, 1, 1);
+  dim3 block_per_grid = dim3(1, 1, 1);
+};
+
+/* According to NVIDIA, if the number of threads per block is 64/128/256/512,
+ * CUDA performs better. And the number of blocks should be greater (at least
+ * 2x~4x) than the number of SMs. Hence, the SM count is taken into account in
+ * this function to determine the right number of threads per block. */
+inline GpuLaunchConfig GetGpuLaunchConfig1D(
+    const platform::CUDADeviceContext& context, int64_t numel,
+    int vec_size = 1) {
+  PADDLE_ENFORCE_GT(numel, 0, platform::errors::InvalidArgument(
+                                  "element quantity should be greater than 0,"
+                                  " but received value is: %d.",
+                                  numel));
+  // Get compute_capability
+  const int capability = context.GetComputeCapability();
+  /* If the thread number per block is 64/128/256/512, CUDA performs better. */
+  int limit_threads =
+      std::min(PREDEFINED_BLOCK_SIZE, context.GetMaxThreadsPerBlock());
+#ifdef WITH_NV_JETSON
+  if (capability == 53 || capability == 62) {
+    limit_threads = 512;
+  }
+#endif
+  int threads = limit_threads;
+  int sm_count = context.GetSMCount();
+  int active_threads_num = numel / vec_size;
+  if (active_threads_num / (sm_count << 1) < limit_threads) {
+    // Round up the thread number to an exponential multiple of 2, so that the
+    // number of active blocks is about twice the SM count, for better performance.
+    threads = RoundToPowerOfTwo(active_threads_num / (sm_count << 1));
+  } else if (active_threads_num / (sm_count << 2) < limit_threads) {
+    // Round up the thread number to an exponential multiple of 2, so that the
+    // number of active blocks is about 4 times the SM count, for better performance.
+    threads = RoundToPowerOfTwo(active_threads_num / (sm_count << 2));
+  }
+  // Number of threads per block shall be larger than 64.
+  threads = std::max(64, threads);
+  int blocks = DivUp(DivUp(numel, vec_size), threads);
+
+  GpuLaunchConfig config;
+  config.thread_per_block.x = threads;
+  config.block_per_grid.x = blocks;
+  config.compute_capability = capability;
+  return config;
+}
+
+inline GpuLaunchConfig GetGpuLaunchConfig2D(
+    const platform::CUDADeviceContext& context, int x_dim, int y_dim) {
+  PADDLE_ENFORCE_GT(x_dim, 0, platform::errors::InvalidArgument(
+                                  "x dim number should be greater than 0,"
+                                  " but received value is: %d",
+                                  x_dim));
+  PADDLE_ENFORCE_GT(y_dim, 0, platform::errors::InvalidArgument(
+                                  "y dim number should be greater than 0,"
+                                  " but received value is: %d",
+                                  y_dim));
+
+  const int kThreadsPerBlock = 256;
+  int block_cols = (std::min)(x_dim, kThreadsPerBlock);
+  int block_rows = (std::max)(kThreadsPerBlock / block_cols, 1);
+
+  int max_physical_threads = context.GetMaxPhysicalThreadCount();
+  const int max_blocks = (std::max)(max_physical_threads / kThreadsPerBlock, 1);
+
+  GpuLaunchConfig config;
+  // Note: the block size is not aligned to 32; if needed, do it yourself.
+  config.thread_per_block = dim3(block_cols, block_rows, 1);
+
+  int grid_x = (std::min)(DivUp(x_dim, block_cols), max_blocks);
+  int grid_y =
+      (std::min)(max_blocks / grid_x, (std::max)(y_dim / block_rows, 1));
+
+  config.block_per_grid = dim3(grid_x, grid_y, 1);
+  return config;
+}
+
+// TODO(wangchaochaohu): 3D will be added later.
+
+}  // namespace platform
+}  // namespace paddle
+
+#endif
diff --git a/paddle/pten/kernels/gpu/elementwise.h b/paddle/pten/kernels/gpu/elementwise.h
index e4cc894e48354b8d7e91aeee74384a7df0891ff3..049e430154a8ba1f209f885e6dd06383c2b65499 100644
--- a/paddle/pten/kernels/gpu/elementwise.h
+++ b/paddle/pten/kernels/gpu/elementwise.h
@@ -16,9 +16,9 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
 #include "paddle/fluid/platform/aligned_vector.h"
+#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
 #include "paddle/fluid/platform/function_traits.h"
 #include "paddle/pten/core/dense_tensor.h"
-#include "paddle/pten/kernels/funcs/cuda_kernel_config.h"
 
 namespace pten {
 
@@ -239,18 +239,15 @@ void ElementwiseCudaKernel(const KPDevice &ctx,
       VecSize><<>>(
       ins_data, outs_data, numel, main_offset, func);
 #else
-  int block_size = funcs::GetThreadsConfig(ctx, numel, VecSize);
-  int grid_size =
-      ((numel + VecSize - 1) / VecSize + block_size - 1) / block_size;
-  int main_offset = (numel / (VecSize * block_size)) * VecSize * block_size;
+  auto gpu_config = GetGpuLaunchConfig1D(ctx, numel, VecSize);
+  int main_offset = (numel / (VecSize * gpu_config.GetBlockSize())) * VecSize *
+                    gpu_config.GetBlockSize();
   auto stream = ctx.stream();
-  VectorizedElementwiseKernel<<>>(
-      ins_data, outs_data, numel, main_offset, func);
+  VectorizedElementwiseKernel<<<
+      gpu_config.block_per_grid,
+      gpu_config.thread_per_block,
+      0,
+      stream>>>(ins_data, outs_data, numel, main_offset, func);
 #endif
 }
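
Usage sketch (not part of the patch): the recurring pattern in this diff is to drop the hand-rolled block/grid arithmetic and launch with the dim3 members of the GpuLaunchConfig returned by GetGpuLaunchConfig1D. A minimal caller is sketched below; ScaleKernel and LaunchScale are hypothetical names used only for illustration, while the config fields and helper signature come from gpu_launch_config.h above.

    // Hypothetical element-wise kernel; only the launch-config usage mirrors the patch.
    template <typename T>
    __global__ void ScaleKernel(const T* x, T* y, int64_t n, T a) {
      int64_t i = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
      if (i < n) y[i] = a * x[i];
    }

    template <typename T>
    void LaunchScale(const paddle::platform::CUDADeviceContext& ctx,
                     const T* x, T* y, int64_t n, T a) {
      // One call replaces the old block_size/grid_size boilerplate: the helper
      // caps threads at PREDEFINED_BLOCK_SIZE, rounds to a power of two based
      // on the SM count, and never returns fewer than 64 threads per block.
      auto config = paddle::platform::GetGpuLaunchConfig1D(ctx, n);
      ScaleKernel<T><<<config.block_per_grid, config.thread_per_block, 0,
                       ctx.stream()>>>(x, y, n, a);
    }

For vectorized kernels, the vector width goes in as the third argument (as the dropout and pten elementwise call sites above do), so the heuristic sizes the block and grid from numel / vec_size rather than numel.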
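
A similar hedged sketch for the 2D helper, again with hypothetical names (AddBiasKernel, LaunchAddBias): GetGpuLaunchConfig2D builds a roughly 256-thread block as block_cols x block_rows and clamps the grid by the device's maximum physical thread count, so the kernel below uses grid-stride loops to keep full coverage.

    // Hypothetical 2D kernel over a width x height buffer. Grid-stride loops are
    // used because GetGpuLaunchConfig2D may clamp the grid to max_blocks.
    template <typename T>
    __global__ void AddBiasKernel(T* data, int width, int height, T bias) {
      for (int y = blockIdx.y * blockDim.y + threadIdx.y; y < height;
           y += gridDim.y * blockDim.y) {
        for (int x = blockIdx.x * blockDim.x + threadIdx.x; x < width;
             x += gridDim.x * blockDim.x) {
          data[y * width + x] += bias;
        }
      }
    }

    template <typename T>
    void LaunchAddBias(const paddle::platform::CUDADeviceContext& ctx, T* data,
                       int width, int height, T bias) {
      // As the comment in gpu_launch_config.h notes, the block size is not
      // rounded to a multiple of 32; align it yourself if a kernel requires it.
      auto config = paddle::platform::GetGpuLaunchConfig2D(ctx, width, height);
      AddBiasKernel<T><<<config.block_per_grid, config.thread_per_block, 0,
                         ctx.stream()>>>(data, width, height, bias);
    }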