diff --git a/model_zoo/gpt-3/external_ops/fused_ln/layer_norm_cuda.h b/model_zoo/gpt-3/external_ops/fused_ln/layer_norm_cuda.h index 5e7b3a1d88ba..e5e3cc563ea5 100644 --- a/model_zoo/gpt-3/external_ops/fused_ln/layer_norm_cuda.h +++ b/model_zoo/gpt-3/external_ops/fused_ln/layer_norm_cuda.h @@ -20,9 +20,12 @@ #pragma once // NOLINT +#ifdef PADDLE_WITH_HIP +#include +#else #include // NOLINT #include // NOLINT - +#endif #include "paddle/extension.h" #define DEFAULT_THROW(NAME, TYPE) \ @@ -71,14 +74,22 @@ DEFAULT_THROW(NAME, TYPEIN); \ } +#ifdef PADDLE_WITH_HIP +#define WARP_SIZE 64 +#else #define WARP_SIZE 32 +#endif template __device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = WARP_SIZE, unsigned int mask = 0xffffffff) { - return __shfl_xor_sync(mask, value, laneMask, width); + #ifdef PADDLE_WITH_HIP + return __shfl_xor(value, laneMask, width); + #else + return __shfl_xor_sync(mask,value, laneMask, width); + #endif } template @@ -86,7 +97,11 @@ __device__ __forceinline__ T WARP_SHFL(T value, int srcLane, int width = WARP_SIZE, unsigned int mask = 0xffffffff) { + #ifdef PADDLE_WITH_HIP + return __shfl(value, srcLane, width); + #else return __shfl_sync(mask, value, srcLane, width); + #endif } template @@ -181,8 +196,17 @@ __device__ void cuWelfordMuSigma2(const T* __restrict__ vals, } } // intra-warp reductions - for (int l = 0; l <= 4; ++l) { + #ifdef PADDLE_WITH_HIP + for (int l = 0; l <= 5; ++l) + #else + for (int l = 0; l <= 4; ++l) + #endif + { + #ifdef PADDLE_WITH_HIP + int srcLaneB = (threadIdx.x + (1 << l)) & 63; + #else int srcLaneB = (threadIdx.x + (1 << l)) & 31; + #endif U sigma2B = WARP_SHFL(sigma2, srcLaneB); if (!rms_only) { U muB = WARP_SHFL(mu, srcLaneB); @@ -306,8 +330,17 @@ __device__ void cuWelfordMuSigma2(const phi::dtype::float16* __restrict__ vals, } } // intra-warp reductions - for (int l = 0; l <= 4; ++l) { + #ifdef PADDLE_WITH_HIP + for (int l = 0; l <= 5; ++l) + #else + for (int l = 0; l <= 4; ++l) + #endif + { + #ifdef PADDLE_WITH_HIP + int srcLaneB = (threadIdx.x + (1 << l)) & 63; + #else int srcLaneB = (threadIdx.x + (1 << l)) & 31; + #endif float sigma2B = WARP_SHFL(sigma2, srcLaneB); if (!rms_only) { float muB = WARP_SHFL(mu, srcLaneB); @@ -369,15 +402,15 @@ __device__ void cuWelfordMuSigma2(const phi::dtype::float16* __restrict__ vals, } } -template +template __device__ U rsqrt(U v) { return U(1) / sqrt(v); } -template <> +template <> __device__ float rsqrt(float v) { return rsqrtf(v); } -template <> +template <> __device__ double rsqrt(double v) { return rsqrt(v); } @@ -914,6 +947,22 @@ __global__ void cuComputeGradInput(const V* __restrict__ dout, } } +#ifdef PADDLE_WITH_HIP +static hipDeviceProp_t GetDevicePropImpl() { + int device = -1; + PD_CHECK(hipGetDevice(&device) == hipSuccess); + hipDeviceProp_t prop; + PD_CHECK(hipGetDeviceProperties(&prop, device) == hipSuccess); + return prop; +} + +static hipDeviceProp_t* GetDeviceProp() { + static auto prop = GetDevicePropImpl(); + return ∝ +} + +#else + static cudaDeviceProp GetDevicePropImpl() { int device = -1; PD_CHECK(cudaGetDevice(&device) == cudaSuccess); @@ -926,8 +975,10 @@ static cudaDeviceProp* GetDeviceProp() { static auto prop = GetDevicePropImpl(); return ∝ } +#endif template +#ifdef PADDLE_WITH_HIP void HostApplyLayerNorm(V* output, U* mean, U* invvar, @@ -937,8 +988,25 @@ void HostApplyLayerNorm(V* output, double epsilon, const V* gamma, const V* beta, - cudaStream_t stream) { + hipStream_t stream) +#else +void HostApplyLayerNorm(V* output, + U* mean, + U* invvar, + const T* input, + int n1, + int n2, + double epsilon, + const V* gamma, + const V* beta, + cudaStream_t stream) +#endif +{ + #ifdef PADDLE_WITH_HIP + const dim3 threads(64, 4, 1); + #else const dim3 threads(32, 4, 1); + #endif const uint64_t maxGridY = GetDeviceProp()->maxGridSize[1]; const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1); int nshared = @@ -948,6 +1016,16 @@ void HostApplyLayerNorm(V* output, } template +#ifdef PADDLE_WITH_HIP +void HostApplyRMSNorm(V* output, + U* invvar, + const T* input, + int n1, + int n2, + double epsilon, + const V* gamma, + hipStream_t stream) +#else void HostApplyRMSNorm(V* output, U* invvar, const T* input, @@ -955,9 +1033,15 @@ void HostApplyRMSNorm(V* output, int n2, double epsilon, const V* gamma, - cudaStream_t stream) { + cudaStream_t stream) +#endif +{ // auto stream = at::cuda::getCurrentCUDAStream().stream(); + #ifdef PADDLE_WITH_HIP + const dim3 threads(64, 4, 1); + #else const dim3 threads(32, 4, 1); + #endif // const uint64_t maxGridY = // at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; const uint64_t maxGridY = GetDeviceProp()->maxGridSize[1]; @@ -1015,6 +1099,7 @@ static void cuda_rms_norm(const paddle::Tensor& x, } template +#ifdef PADDLE_WITH_HIP void HostLayerNormGradient(const V* dout, const U* mean, const U* invvar, @@ -1027,7 +1112,23 @@ void HostLayerNormGradient(const V* dout, T* grad_input, V* grad_gamma, V* grad_beta, - cudaStream_t stream) { + hipStream_t stream) +#else +void HostLayerNormGradient(const V* dout, + const U* mean, + const U* invvar, + const paddle::Tensor& input, + int n1, + int n2, + const V* gamma, + const V* beta, + double epsilon, + T* grad_input, + V* grad_gamma, + V* grad_beta, + cudaStream_t stream) +#endif +{ if (gamma != NULL && beta != NULL) { // compute grad_gamma(j) and grad_beta(j) const int part_size = 16; @@ -1085,6 +1186,18 @@ void HostLayerNormGradient(const V* dout, } template +#ifdef PADDLE_WITH_HIP +void HostRMSNormGradient(const V* dout, + const U* invvar, + const paddle::Tensor& input, + int n1, + int n2, + const V* gamma, + double epsilon, + T* grad_input, + V* grad_gamma, + hipStream_t stream) +#else void HostRMSNormGradient(const V* dout, const U* invvar, const paddle::Tensor& input, @@ -1094,7 +1207,9 @@ void HostRMSNormGradient(const V* dout, double epsilon, T* grad_input, V* grad_gamma, - cudaStream_t stream) { + cudaStream_t stream) +#endif +{ if (gamma != NULL) { const int part_size = 16; const dim3 threads2(32, 4, 1); diff --git a/model_zoo/gpt-3/external_ops/setup.py b/model_zoo/gpt-3/external_ops/setup.py index f067cbf19667..9e51bbe00dfb 100644 --- a/model_zoo/gpt-3/external_ops/setup.py +++ b/model_zoo/gpt-3/external_ops/setup.py @@ -37,69 +37,97 @@ def change_pwd(): def setup_fast_ln(): from paddle.utils.cpp_extension import CUDAExtension, setup + from paddle.device import is_compiled_with_rocm - gencode_flags = get_gencode_flags() - change_pwd() - setup( - name="fast_ln", - ext_modules=CUDAExtension( - sources=[ - "fast_ln/ln_api.cpp", - "fast_ln/ln_bwd_semi_cuda_kernel.cu", - "fast_ln/ln_fwd_cuda_kernel.cu", - ], - extra_compile_args={ - "cxx": ["-O3"], - "nvcc": [ - "-O3", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT16_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT162_OPERATORS__", - "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", - "-I./apex/contrib/csrc/layer_norm/", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - ] - + gencode_flags, - }, - ), - ) + if(is_compiled_with_rocm()): + print("The 'fasl_ln' feature is temporarily not supported on the ROCm platform !!!") + else: + gencode_flags = get_gencode_flags() + change_pwd() + setup( + name="fast_ln", + ext_modules=CUDAExtension( + sources=[ + "fast_ln/ln_api.cpp", + "fast_ln/ln_bwd_semi_cuda_kernel.cu", + "fast_ln/ln_fwd_cuda_kernel.cu", + ], + extra_compile_args={ + "cxx": ["-O3"], + "nvcc": [ + "-O3", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT16_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT162_OPERATORS__", + "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", + "-I./apex/contrib/csrc/layer_norm/", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + ] + + gencode_flags, + }, + ), + ) def setup_fused_ln(): from paddle.utils.cpp_extension import CUDAExtension, setup + from paddle.device import is_compiled_with_rocm gencode_flags = get_gencode_flags() change_pwd() - setup( - name="fused_ln", - ext_modules=CUDAExtension( - sources=[ - "fused_ln/layer_norm_cuda.cu", - ], - extra_compile_args={ - "cxx": ["-O3"], - "nvcc": [ - "-O3", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT16_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT162_OPERATORS__", - "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", - "-I./apex/contrib/csrc/layer_norm/", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - "-maxrregcount=50", - ] - + gencode_flags, - }, - ), - ) + if(is_compiled_with_rocm()): + setup( + name="fused_ln", + ext_modules=CUDAExtension( + sources=[ + "fused_ln/layer_norm_cuda.cu", + ], + extra_compile_args={ + "cxx": ["-O3"], + "hipcc": [ + "-O3", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT16_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT162_OPERATORS__", + "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", + "-DPADDLE_WITH_HIP", + ] + }, + ), + ) + else: + setup( + name="fused_ln", + ext_modules=CUDAExtension( + sources=[ + "fused_ln/layer_norm_cuda.cu", + ], + extra_compile_args={ + "cxx": ["-O3"], + "nvcc": [ + "-O3", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT16_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT162_OPERATORS__", + "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", + "-I./apex/contrib/csrc/layer_norm/", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + "-maxrregcount=50", + ] + + gencode_flags, + }, + ), + ) run(setup_fast_ln)