Commit 65bcaeb

[ROCM] update fluid operators for rocm (part5), test=develop (#31258)
* [ROCM] update fluid operators for rocm (part5), test=develop
* address review comments, test=develop
* fix typo, test=develop
1 parent 2111d91 commit 65bcaeb

19 files changed: +214 −39 lines

cmake/hip.cmake  (+1)

@@ -45,6 +45,7 @@ set(THRUST_DEVICE_SYSTEM THRUST_DEVICE_SYSTEM_HIP)
 # define HIP_CXX_FLAGS
 list(APPEND HIP_CXX_FLAGS -fPIC)
 list(APPEND HIP_CXX_FLAGS -D__HIP_PLATFORM_HCC__=1)
+# Note(qili93): HIP has compile conflicts of float16.h as platform::float16 overload std::is_floating_point and std::is_integer
 list(APPEND HIP_CXX_FLAGS -D__HIP_NO_HALF_CONVERSIONS__=1)
 list(APPEND HIP_CXX_FLAGS -Wno-macro-redefined)
 list(APPEND HIP_CXX_FLAGS -Wno-inconsistent-missing-override)
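The added comment explains why the flag matters: platform::float16 provides overloads around type traits such as std::is_floating_point, and HIP's implicit __half conversions collide with them, so -D__HIP_NO_HALF_CONVERSIONS__=1 turns the implicit conversions off. A minimal sketch of what that implies for fp16 code (illustrative only, not part of the commit): conversions have to go through the explicit intrinsics.

    // Illustrative sketch: with -D__HIP_NO_HALF_CONVERSIONS__=1 the implicit
    // __half <-> float conversions are gone, so explicit intrinsics are used.
    // Build (assumed): hipcc -D__HIP_NO_HALF_CONVERSIONS__=1 half_demo.cc
    #include <hip/hip_fp16.h>
    #include <cstdio>

    int main() {
      __half h = __float2half(1.5f);  // explicit float -> half
      float f = __half2float(h);      // explicit half -> float
      std::printf("%f\n", f);         // prints 1.500000
      return 0;
    }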

paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h  (+3 −3)

@@ -737,7 +737,7 @@ x.second );
   }
 
   int assign_async(const concurrent_unordered_map& other,
-                   gpuStream_t stream = 0) {
+                   cudaStream_t stream = 0) {
     m_collisions = other.m_collisions;
     if (other.m_hashtbl_size <= m_hashtbl_capacity) {
       m_hashtbl_size = other.m_hashtbl_size;
@@ -754,7 +754,7 @@ x.second );
     return 0;
   }
 
-  void clear_async(gpuStream_t stream = 0) {
+  void clear_async(cudaStream_t stream = 0) {
     constexpr int block_size = 128;
     init_hashtbl<<<((m_hashtbl_size - 1) / block_size) + 1, block_size, 0,
                    stream>>>(m_hashtbl_values, m_hashtbl_size, unused_key,
@@ -771,7 +771,7 @@ x.second );
     }
   }
 
-  int prefetch(const int dev_id, gpuStream_t stream = 0) {
+  int prefetch(const int dev_id, cudaStream_t stream = 0) {
     cudaPointerAttributes hashtbl_values_ptr_attributes;
     cudaError_t status = cudaPointerGetAttributes(
        &hashtbl_values_ptr_attributes, m_hashtbl_values);
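All three members move back from gpuStream_t to cudaStream_t: this heter_ps header calls CUDA-only APIs such as cudaPointerGetAttributes directly, which suggests it stays on the CUDA toolchain, so the portable stream alias is not needed here. For context, gpuStream_t elsewhere in the codebase resolves to the backend's native stream type roughly like this (a sketch; the exact header is not shown in this diff):

    // Hedged sketch of the portability alias used elsewhere; the surrounding
    // header is an assumption, not part of this diff.
    #ifdef PADDLE_WITH_HIP
    #include <hip/hip_runtime.h>
    using gpuStream_t = hipStream_t;
    #else
    #include <cuda_runtime.h>
    using gpuStream_t = cudaStream_t;
    #endif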

paddle/fluid/operators/array_to_lod_tensor_op.cc  (+1 −1)

@@ -51,7 +51,7 @@ struct ArrayToLoDFunctor : public boost::static_visitor<void> {
   if (std::is_same<Place, platform::CPUPlace>::value) {
     Apply(static_cast<platform::CPUDeviceContext *>(pool.Get(place)));
   } else {
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
     Apply(static_cast<platform::CUDADeviceContext *>(pool.Get(place)));
 #else
     PADDLE_THROW(

paddle/fluid/operators/assign_op.cc  (+1 −1)

@@ -164,7 +164,7 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(assign, float, ops::AssignKernel, double,
                                ops::AssignKernel, plat::float16,
                                ops::AssignKernel);
 
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 REGISTER_OP_CUDA_KERNEL_FUNCTOR(assign, float, ops::AssignKernel, double,
                                 ops::AssignKernel, int, ops::AssignKernel,
                                 int64_t, ops::AssignKernel, bool,
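assign_op.cc and array_to_lod_tensor_op.cc get the same mechanical change: the CUDA-only guard widens to #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) so the GPU kernel registration and the CUDADeviceContext branch also compile in ROCm builds. The compound guard recurs throughout this series; a sketch of how it could be spelled once and reused (the helper macro below is hypothetical, not something this commit adds):

    // Hypothetical helper shown only to illustrate the recurring guard; the
    // commit itself keeps the explicit #if at each call site.
    #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    #define PADDLE_WITH_GPU_BACKEND 1
    #endif

    #ifdef PADDLE_WITH_GPU_BACKEND
    inline const char *GpuBackendName() {
    #ifdef PADDLE_WITH_HIP
      return "ROCm/HIP";
    #else
      return "CUDA";
    #endif
    }
    #endif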

paddle/fluid/operators/math/bert_encoder_functor.cu  (+29 −8)

@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include <cuda_runtime.h>
 #include <algorithm>
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/framework/tensor_util.h"
@@ -145,6 +144,8 @@ __global__ void EmbEltwiseLayernormKernel(int hidden, const int64_t *ids,
   LayerNorm<T, TPB>(thread_data, hidden, out_offset, bias, scale, output, eps);
 }
 
+// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
+#ifndef __HIPCC__  // @{ Half kernel: EmbEltwiseLayernormKernel
 template <>
 __global__ void EmbEltwiseLayernormKernel<half, 256>(
     int hidden, const int64_t *ids, const float *scale, const float *bias,
@@ -188,12 +189,13 @@ __global__ void EmbEltwiseLayernormKernel<half, 256>(
       eps);
 #endif
 }
+#endif  // @} End Half kernel: EmbEltwiseLayernormKernel
 
 template <typename T>
 void EmbEltwiseLayerNormFunctor<T>::operator()(
     int batch, int seq_len, int hidden, const int64_t *ids, const float *scale,
     const float *bias, const int64_t *embs, T *output, float eps, int input_num,
-    cudaStream_t stream) {
+    gpuStream_t stream) {
   const unsigned tpb = 256;
   const dim3 grid(seq_len, batch, 1);
   const dim3 block(tpb, 1, 1);
@@ -205,7 +207,8 @@ void EmbEltwiseLayerNormFunctor<T>::operator()(
 template class EmbEltwiseLayerNormFunctor<float>;
 
 // device function 'operator()' is not supportted until cuda 10.0
-#if CUDA_VERSION >= 10000
+// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
+#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10000
 template class EmbEltwiseLayerNormFunctor<half>;
 #endif
 
@@ -230,6 +233,8 @@ __global__ void SoftmaxKernelWithEltadd(T *qk_buf_, const T *bias_qk_,
   qk_buf_[threadIdx.x + qk_offset] = (T)(qk_tmp / sum_val);
 }
 
+// HIP defined __HIP_NO_HALF_CONVERSIONS__
+#ifndef __HIPCC__  // @{ Half kernel: SoftmaxKernelWithEltadd
 template <>
 __global__ void SoftmaxKernelWithEltadd<half>(
     half *qk_buf_, const half *bias_qk_, const int batch_size,
@@ -251,6 +256,7 @@ __global__ void SoftmaxKernelWithEltadd<half>(
   qk_buf_[threadIdx.x + qk_offset] = (half)(qk_tmp / sum_val);
 #endif
 }
+#endif  // @} End Half kernel: SoftmaxKernelWithEltadd
 
 template <typename T>
 __global__ void SoftmaxKernelWithEltadd2(T *qk_buf_, const T *bias_qk_,
@@ -282,7 +288,9 @@ __global__ void SoftmaxKernelWithEltadd2<half2>(
     half2 *qk_buf_, const half2 *bias_qk_, const int batch_size,
     const int head_num, const int seq_len, const unsigned mask) {
 // operator "+" of half only suppotted after cuda version 10.0
-#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) && CUDA_VERSION >= 10000
+// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
+#if defined(PADDLE_WITH_CUDA) || \
+    (CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) && CUDA_VERSION >= 10000)
   int qk_offset = blockIdx.x * seq_len;
   int idx = threadIdx.x;
   assert(blockDim.x % 32 == 0);
@@ -398,7 +406,8 @@ void MultiHeadGPUComputeFunctor<T>::operator()(
 template class MultiHeadGPUComputeFunctor<float>;
 
 // device function 'operator()' is not supportted until cuda 10.0
-#if CUDA_VERSION >= 10000
+// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
+#if defined(PADDLE_WITH_CUDA) || CUDA_VERSION >= 10000
 template class MultiHeadGPUComputeFunctor<half>;
 #endif
 
@@ -422,6 +431,8 @@ __global__ void SkipLayerNormSmallKernel(int num, int hidden, const T *input1,
       eps);
 }
 
+// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
+#ifndef __HIPCC__  // @{ Half kernel: SkipLayerNormSmallKernel
 template <>
 __global__ void SkipLayerNormSmallKernel<half, 32>(
     int num, int hidden, const half *input1, const half *input2, half *output,
@@ -484,6 +495,7 @@ __global__ void SkipLayerNormSmallKernel<half, 384>(
       eps);
 #endif
 }
+#endif  // @} End Half kernel: SkipLayerNormSmallKernel
 
 template <typename T, unsigned TPB>
 __global__ void SkipLayerNormKernel(int num, int hidden, const T *input1,
@@ -505,6 +517,8 @@ __global__ void SkipLayerNormKernel(int num, int hidden, const T *input1,
   LayerNorm<T, TPB>(thread_data, hidden, offset, bias, scale, output, eps);
 }
 
+// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
+#ifndef __HIPCC__  // @{ Half kernel: SkipLayerNormKernel
 template <>
 __global__ void SkipLayerNormKernel<half, 256>(int num, int hidden,
                                                const half *input1,
@@ -527,6 +541,7 @@ __global__ void SkipLayerNormKernel<half, 256>(int num, int hidden,
   LayerNorm<half, 256>(thread_data, hidden, offset, bias, scale, output, eps);
 #endif
 }
+#endif  // @} End Half kernel: SkipLayerNormKernel
 
 template <typename T, typename T2, unsigned TPB>
 __global__ void SkipLayerNormKernel2(int num, int hidden, const T2 *input1,
@@ -549,6 +564,8 @@ __global__ void SkipLayerNormKernel2(int num, int hidden, const T2 *input1,
   LayerNorm2<T, T2, TPB>(thread_data, hidden, offset, bias, scale, output, eps);
 }
 
+// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
+#ifndef __HIPCC__  // @{ Half kernel: SkipLayerNormKernel2
 template <>
 __global__ void SkipLayerNormKernel2<half, half2, 256>(
     int num, int hidden, const half2 *input1, const half2 *input2,
@@ -572,13 +589,13 @@ __global__ void SkipLayerNormKernel2<half, half2, 256>(
       eps);
 #endif
 }
+#endif  // @} End Half kernel: SkipLayerNormKernel2
 
 template <typename T>
 void SkipLayerNormFunctor<T>::operator()(const int num, const int hidden,
                                          const T *input1, const T *input2,
                                          const float *scale, const float *bias,
-                                         T *output, T eps,
-                                         cudaStream_t stream) {
+                                         T *output, T eps, gpuStream_t stream) {
   int block = num / hidden;
   if (hidden <= 32) {
     const int threads = 32;
@@ -603,6 +620,8 @@ void SkipLayerNormFunctor<T>::operator()(const int num, const int hidden,
         reinterpret_cast<float2 *>(output),
         reinterpret_cast<const float2 *>(scale),
         reinterpret_cast<const float2 *>(bias), eps);
+// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
+#ifndef __HIPCC__
   } else if (std::is_same<T, __half>::value) {
     SkipLayerNormKernel2<__half, __half2,
                          threads><<<block, threads, 0, stream>>>(
@@ -611,6 +630,7 @@ void SkipLayerNormFunctor<T>::operator()(const int num, const int hidden,
         reinterpret_cast<__half2 *>(output),
         reinterpret_cast<const float2 *>(scale),
         reinterpret_cast<const float2 *>(bias), eps);
+#endif
   } else {
     assert(false);
     // should not be here
@@ -625,7 +645,8 @@ void SkipLayerNormFunctor<T>::operator()(const int num, const int hidden,
 template class SkipLayerNormFunctor<float>;
 
 // device function 'operator()' is not supportted until cuda 10.0
-#if CUDA_VERSION >= 10000
+// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
+#if defined(PADDLE_WITH_CUDA) || CUDA_VERSION >= 10000
 template class SkipLayerNormFunctor<half>;
 #endif
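The same pattern repeats through the file: because hip.cmake now passes __HIP_NO_HALF_CONVERSIONS__, every half specialization that leans on implicit half<->float conversions is fenced off with #ifndef __HIPCC__, the half instantiations of the functors are restricted to CUDA builds, and the float paths switch their stream parameter to the portable gpuStream_t. A condensed sketch of the guarding pattern (kernel name and body are illustrative, not taken from the file):

    #ifndef __HIPCC__
    #include <cuda_fp16.h>  // half type plus __hadd / __float2half intrinsics
    #endif

    // Generic kernel: fine for float on both toolchains.
    template <typename T>
    __global__ void AddEps(T *x, float eps) {
      x[threadIdx.x] = x[threadIdx.x] + static_cast<T>(eps);
    }

    // Half specialization: CUDA only, since __HIP_NO_HALF_CONVERSIONS__ removes
    // the implicit conversions an unguarded half path would rely on.
    #ifndef __HIPCC__
    template <>
    __global__ void AddEps<half>(half *x, float eps) {
    #if __CUDA_ARCH__ >= 530  // half arithmetic needs sm_53 or newer
      x[threadIdx.x] = __hadd(x[threadIdx.x], __float2half(eps));
    #endif
    }
    #endif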

paddle/fluid/operators/math/bert_encoder_functor.h  (+12 −3)

@@ -13,9 +13,18 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+
+#ifdef PADDLE_WITH_CUDA
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <cub/cub.cuh>  // NOLINT
+#endif
+#ifdef PADDLE_WITH_HIP
+#include <hip/hip_runtime.h>
+#include <hipcub/hipcub.hpp>
+namespace cub = hipcub;
+#endif
+
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/float16.h"
 
@@ -36,7 +45,7 @@ struct CUDATypeTraits<float> {
   typedef float TYPE;
 };
 
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 // This functor involves a fusion calculation in Ernie or Bert.
 // The fusion mode is as follows:
 //
@@ -55,7 +64,7 @@ class EmbEltwiseLayerNormFunctor {
  public:
   void operator()(int batch, int seq_len, int hidden, const int64_t *ids,
                   const float *scale, const float *bias, const int64_t *embs,
-                  T *output, float eps, int input_num, cudaStream_t stream);
+                  T *output, float eps, int input_num, gpuStream_t stream);
 };
 
 // This functor involves a fusion calculation in Ernie or Bert.
@@ -97,7 +106,7 @@ class SkipLayerNormFunctor {
  public:
   void operator()(const int num, const int hidden, const T *input1,
                   const T *input2, const float *scale, const float *bias,
-                  T *output, T eps, cudaStream_t stream);
+                  T *output, T eps, gpuStream_t stream);
 };
 #endif
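For ROCm builds the header now includes hipcub and aliases it as cub, so device code written against the cub API compiles unchanged under hipcc, and the functor declarations swap cudaStream_t for the portable gpuStream_t. A small sketch of why the alias is sufficient (kernel and names are illustrative, not from the header):

    #ifdef __HIPCC__
    #include <hipcub/hipcub.hpp>
    namespace cub = hipcub;
    #else
    #include <cub/cub.cuh>
    #endif

    // A plain cub-style block reduction: with the alias in place the same
    // source builds with nvcc and with hipcc.
    template <int TPB>
    __global__ void BlockSum(const float *in, float *out) {
      using BlockReduce = cub::BlockReduce<float, TPB>;
      __shared__ typename BlockReduce::TempStorage temp;
      float v = in[blockIdx.x * TPB + threadIdx.x];
      float sum = BlockReduce(temp).Sum(v);
      if (threadIdx.x == 0) out[blockIdx.x] = sum;
    }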

paddle/fluid/operators/math/depthwise_conv.cu  (+14 −1)

@@ -14,7 +14,13 @@ limitations under the License. */
 
 #include <algorithm>
 #include <vector>
-#include "cub/cub.cuh"
+#ifdef __NVCC__
+#include <cub/cub.cuh>
+#endif
+#ifdef __HIPCC__
+#include <hipcub/hipcub.hpp>
+namespace cub = hipcub;
+#endif
 #include "paddle/fluid/operators/math/depthwise_conv.h"
 #include "paddle/fluid/platform/cuda_device_function.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
@@ -27,7 +33,14 @@ template <typename T>
 __device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) {
   typedef cub::WarpReduce<T> WarpReduce;
   typename WarpReduce::TempStorage temp_storage;
+
+#ifdef __HIPCC__
+  int block_size = min(blockDim.x * blockDim.y * blockDim.z, warpSize);
+  value = WarpReduce(temp_storage).Sum(value, block_size);
+#else
   value = WarpReduce(temp_storage).Sum(value);
+#endif
+
   if (cub::LaneId() == 0) platform::CudaAtomicAdd(sum, value);
 }
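The HIP branch of CudaAtomicAddWithWarp passes an explicit item count to WarpReduce::Sum. AMD wavefronts are 64 lanes wide, so a block smaller than warpSize would otherwise let the reduction read lanes that never held a value; clamping to the actual block size keeps the sum correct. A standalone sketch of the same idea (function name and float-only signature are illustrative):

    #ifdef __HIPCC__
    #include <hipcub/hipcub.hpp>
    namespace cub = hipcub;
    #else
    #include <cub/cub.cuh>
    #endif

    // Warp-level sum followed by one atomic per warp; on HIP only the lanes
    // that actually exist in the block are reduced (warpSize is 64 there).
    __device__ void WarpAtomicAddFloat(float *sum, float value) {
      typedef cub::WarpReduce<float> WarpReduce;
      typename WarpReduce::TempStorage temp_storage;
    #ifdef __HIPCC__
      int valid = min(static_cast<int>(blockDim.x * blockDim.y * blockDim.z),
                      warpSize);
      value = WarpReduce(temp_storage).Sum(value, valid);
    #else
      value = WarpReduce(temp_storage).Sum(value);  // full 32-lane warp
    #endif
      if (cub::LaneId() == 0) atomicAdd(sum, value);
    }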
