@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

- #include <cuda_runtime.h>
#include <algorithm>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
@@ -145,6 +144,8 @@ __global__ void EmbEltwiseLayernormKernel(int hidden, const int64_t *ids,
  LayerNorm<T, TPB>(thread_data, hidden, out_offset, bias, scale, output, eps);
}

+ // HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
+ #ifndef __HIPCC__ // @{ Half kernel: EmbEltwiseLayernormKernel
template <>
__global__ void EmbEltwiseLayernormKernel<half, 256>(
    int hidden, const int64_t *ids, const float *scale, const float *bias,
@@ -188,12 +189,13 @@ __global__ void EmbEltwiseLayernormKernel<half, 256>(
      eps);
#endif
}
+ #endif // @} End Half kernel: EmbEltwiseLayernormKernel

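This #ifndef __HIPCC__ ... #endif pair is the pattern the rest of the change repeats for every half-precision kernel: hipcc defines __HIPCC__, and since hip.cmake also defines __HIP_NO_HALF_CONVERSIONS__, the half specializations (which depend on half/float conversions) cannot compile under HIP and are compiled out wholesale. A self-contained illustration of the pattern, with FooHalf as a made-up stand-in kernel, not part of this file:

#include <cuda_fp16.h>

// Compiled only by nvcc; hipcc (which defines __HIPCC__) skips the block.
#ifndef __HIPCC__ // @{ Half kernel: FooHalf
__global__ void FooHalf(const half *in, half *out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // The half <-> float conversions below are exactly what
    // __HIP_NO_HALF_CONVERSIONS__ would reject in a HIP build.
    out[i] = (half)((float)in[i] * 0.5f);
  }
}
#endif // @} End Half kernel: FooHalf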
template <typename T>
void EmbEltwiseLayerNormFunctor<T>::operator()(
    int batch, int seq_len, int hidden, const int64_t *ids, const float *scale,
    const float *bias, const int64_t *embs, T *output, float eps, int input_num,
- cudaStream_t stream) {
+ gpuStream_t stream) {
  const unsigned tpb = 256;
  const dim3 grid(seq_len, batch, 1);
  const dim3 block(tpb, 1, 1);
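gpuStream_t is the runtime-neutral stream alias the signature switches to, so one declaration serves both back ends. A minimal sketch of how such an alias is typically defined (Paddle's real alias lives in its platform headers, which this diff does not show):

#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
using gpuStream_t = hipStream_t;   // ROCm build
#else
#include <cuda_runtime.h>
using gpuStream_t = cudaStream_t;  // CUDA build
#endif

With the alias in place, the launch site below it can stay a single <<<grid, block, 0, stream>>> for both platforms.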
@@ -205,7 +207,8 @@ void EmbEltwiseLayerNormFunctor<T>::operator()(
template class EmbEltwiseLayerNormFunctor<float>;

// device function 'operator()' is not supported until cuda 10.0
- #if CUDA_VERSION >= 10000
+ // HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
+ #if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10000
template class EmbEltwiseLayerNormFunctor<half>;
#endif

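The rewritten guard leans on a preprocessor rule: an identifier that is undefined inside #if evaluates to 0. A HIP build defines neither PADDLE_WITH_CUDA nor CUDA_VERSION, so the condition is false and the half instantiation drops out without needing a separate #ifndef __HIPCC__ block:

Build          defined(PADDLE_WITH_CUDA)   CUDA_VERSION >= 10000    <half> instantiated
CUDA >= 10.0   1                           1                        yes
CUDA <  10.0   1                           0                        no
HIP / ROCm     0                           0 (undefined -> 0)       no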
@@ -230,6 +233,8 @@ __global__ void SoftmaxKernelWithEltadd(T *qk_buf_, const T *bias_qk_,
  qk_buf_[threadIdx.x + qk_offset] = (T)(qk_tmp / sum_val);
}

+ // HIP defined __HIP_NO_HALF_CONVERSIONS__
+ #ifndef __HIPCC__ // @{ Half kernel: SoftmaxKernelWithEltadd
template <>
__global__ void SoftmaxKernelWithEltadd<half>(
    half *qk_buf_, const half *bias_qk_, const int batch_size,
@@ -251,6 +256,7 @@ __global__ void SoftmaxKernelWithEltadd<half>(
  qk_buf_[threadIdx.x + qk_offset] = (half)(qk_tmp / sum_val);
#endif
}
+ #endif // @} End Half kernel: SoftmaxKernelWithEltadd

template <typename T>
__global__ void SoftmaxKernelWithEltadd2(T *qk_buf_, const T *bias_qk_,
@@ -282,7 +288,9 @@ __global__ void SoftmaxKernelWithEltadd2<half2>(
    half2 *qk_buf_, const half2 *bias_qk_, const int batch_size,
    const int head_num, const int seq_len, const unsigned mask) {
// operator "+" of half only supported after cuda version 10.0
- #if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) && CUDA_VERSION >= 10000
+ // HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
+ #if defined(PADDLE_WITH_CUDA) || \
+     (CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) && CUDA_VERSION >= 10000)
  int qk_offset = blockIdx.x * seq_len;
  int idx = threadIdx.x;
  assert(blockDim.x % 32 == 0);
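CUDA_ARCH_FP16_SUPPORTED comes from the companion header bert_encoder_functor.h rather than this file; its shape is a simple compute-capability threshold, roughly (assumed definition, shown only for context):

// half arithmetic (operator+ on half2, __hadd, ...) needs sm_60 or newer.
#define CUDA_ARCH_FP16_SUPPORTED(CUDA_ARCH) (CUDA_ARCH >= 600)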
@@ -398,7 +406,8 @@ void MultiHeadGPUComputeFunctor<T>::operator()(
template class MultiHeadGPUComputeFunctor<float>;

// device function 'operator()' is not supported until cuda 10.0
- #if CUDA_VERSION >= 10000
+ // HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
+ #if defined(PADDLE_WITH_CUDA) || CUDA_VERSION >= 10000
template class MultiHeadGPUComputeFunctor<half>;
#endif

@@ -422,6 +431,8 @@ __global__ void SkipLayerNormSmallKernel(int num, int hidden, const T *input1,
      eps);
}

+ // HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
+ #ifndef __HIPCC__ // @{ Half kernel: SkipLayerNormSmallKernel
template <>
__global__ void SkipLayerNormSmallKernel<half, 32>(
    int num, int hidden, const half *input1, const half *input2, half *output,
@@ -484,6 +495,7 @@ __global__ void SkipLayerNormSmallKernel<half, 384>(
      eps);
#endif
}
+ #endif // @} End Half kernel: SkipLayerNormSmallKernel

template <typename T, unsigned TPB>
__global__ void SkipLayerNormKernel(int num, int hidden, const T *input1,
@@ -505,6 +517,8 @@ __global__ void SkipLayerNormKernel(int num, int hidden, const T *input1,
  LayerNorm<T, TPB>(thread_data, hidden, offset, bias, scale, output, eps);
}

+ // HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
+ #ifndef __HIPCC__ // @{ Half kernel: SkipLayerNormKernel
template <>
__global__ void SkipLayerNormKernel<half, 256>(int num, int hidden,
    const half *input1,
@@ -527,6 +541,7 @@ __global__ void SkipLayerNormKernel<half, 256>(int num, int hidden,
  LayerNorm<half, 256>(thread_data, hidden, offset, bias, scale, output, eps);
#endif
}
+ #endif // @} End Half kernel: SkipLayerNormKernel

template <typename T, typename T2, unsigned TPB>
__global__ void SkipLayerNormKernel2(int num, int hidden, const T2 *input1,
@@ -549,6 +564,8 @@ __global__ void SkipLayerNormKernel2(int num, int hidden, const T2 *input1,
  LayerNorm2<T, T2, TPB>(thread_data, hidden, offset, bias, scale, output, eps);
}

+ // HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
+ #ifndef __HIPCC__ // @{ Half kernel: SkipLayerNormKernel2
template <>
__global__ void SkipLayerNormKernel2<half, half2, 256>(
    int num, int hidden, const half2 *input1, const half2 *input2,
@@ -572,13 +589,13 @@ __global__ void SkipLayerNormKernel2<half, half2, 256>(
      eps);
#endif
}
+ #endif // @} End Half kernel: SkipLayerNormKernel2

template <typename T>
void SkipLayerNormFunctor<T>::operator()(const int num, const int hidden,
    const T *input1, const T *input2,
    const float *scale, const float *bias,
- T *output, T eps,
- cudaStream_t stream) {
+ T *output, T eps, gpuStream_t stream) {
  int block = num / hidden;
  if (hidden <= 32) {
    const int threads = 32;
@@ -603,6 +620,8 @@ void SkipLayerNormFunctor<T>::operator()(const int num, const int hidden,
        reinterpret_cast<float2 *>(output),
        reinterpret_cast<const float2 *>(scale),
        reinterpret_cast<const float2 *>(bias), eps);
+ // HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
+ #ifndef __HIPCC__
  } else if (std::is_same<T, __half>::value) {
    SkipLayerNormKernel2<__half, __half2,
        threads><<<block, threads, 0, stream>>>(
@@ -611,6 +630,7 @@ void SkipLayerNormFunctor<T>::operator()(const int num, const int hidden,
        reinterpret_cast<__half2 *>(output),
        reinterpret_cast<const float2 *>(scale),
        reinterpret_cast<const float2 *>(bias), eps);
+ #endif
  } else {
    assert(false);
    // should not be here
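A subtlety at this spot: the type dispatch is an ordinary runtime if over std::is_same, so the float2 and half2 launch statements must both compile for every T (hence the reinterpret_casts), and excluding the half2 path on HIP has to happen at the preprocessor level. For contrast, a C++17 sketch where if constexpr discards the untaken branch per T at compile time (illustrative only, not what this file does):

#include <cuda_fp16.h>
#include <type_traits>

template <typename T>
void LaunchVectorized(T *out, int n) {
  if constexpr (std::is_same<T, float>::value) {
    // float2-vectorized launch: only instantiated when T == float
  } else if constexpr (std::is_same<T, __half>::value) {
    // half2-vectorized launch: only instantiated when T == __half,
    // so no other T ever has to compile this branch
  }
}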
@@ -625,7 +645,8 @@ void SkipLayerNormFunctor<T>::operator()(const int num, const int hidden,
template class SkipLayerNormFunctor<float>;

// device function 'operator()' is not supported until cuda 10.0
- #if CUDA_VERSION >= 10000
+ // HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
+ #if defined(PADDLE_WITH_CUDA) || CUDA_VERSION >= 10000
template class SkipLayerNormFunctor<half>;
#endif