@@ -100,6 +100,15 @@ __device__ inline float FN_FP32(rcp)(float x) {
100
100
asm (" rcp.approx.ftz.f32 %0, %1;" : " =f" (res) : " f" (x));
101
101
return res;
102
102
}
103
// Fast single-precision hyperbolic tangent.
//
// On the fast path this emits the PTX `tanh.approx.f32` instruction through
// inline assembly; on every other configuration it returns tanhf(x). The
// approximate instruction trades accuracy for speed, so the two paths are not
// expected to agree bit-for-bit.
//
//   x       input value
//   return  approximate tanh(x) on the fast path, tanhf(x) otherwise
__device__ inline float FN_FP32(tanh_approx)(float x) {
// tanh.approx.f32 requires an sm_75+ target AND a toolkit that implements
// PTX ISA 7.0 (CUDA 11.0 or newer). CUDA 10.x can already target sm_75 but
// its assembler rejects the instruction, so the compiler version is checked
// in addition to the architecture.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 && \
    defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 11
  float res;
  asm("tanh.approx.f32 %0, %1;" : "=f"(res) : "f"(x));
  return res;
#else
  // Explicit float variant so the fallback never goes through double tanh.
  return tanhf(x);
#endif
}
103
112
104
113
// *************************************************************** //
105
114
// float64 unary and binary operator
@@ -426,7 +435,7 @@ __device__ inline bfloat16 FN_BF16(erf)(bfloat16 x) { return bfloat16(FN_FP32(er
426
435
// bfloat16 wrapper: converts x to float, applies FN_FP32(tan), and converts
// the float result back to bfloat16.
__device__ inline bfloat16 FN_BF16(tan)(bfloat16 x) {
  const float widened = static_cast<float>(x);
  return bfloat16(FN_FP32(tan)(widened));
}
427
436
// bfloat16 wrapper: converts x to float, applies FN_FP32(sinh), and converts
// the float result back to bfloat16.
__device__ inline bfloat16 FN_BF16(sinh)(bfloat16 x) {
  const float widened = static_cast<float>(x);
  return bfloat16(FN_FP32(sinh)(widened));
}
428
437
// bfloat16 wrapper: converts x to float, applies FN_FP32(cosh), and converts
// the float result back to bfloat16.
__device__ inline bfloat16 FN_BF16(cosh)(bfloat16 x) {
  const float widened = static_cast<float>(x);
  return bfloat16(FN_FP32(cosh)(widened));
}
429
- __device__ inline bfloat16 FN_BF16 (tanh)(bfloat16 x) { return bfloat16 (FN_FP32 (tanh )(static_cast <float >(x))); }
438
// bfloat16 wrapper: converts x to float, applies FN_FP32(tanh_approx) (not
// FN_FP32(tanh)), and converts the float result back to bfloat16.
__device__ inline bfloat16 FN_BF16(tanh)(bfloat16 x) {
  const float widened = static_cast<float>(x);
  return bfloat16(FN_FP32(tanh_approx)(widened));
}
430
439
// bfloat16 wrapper: converts x to float, applies FN_FP32(asin), and converts
// the float result back to bfloat16.
__device__ inline bfloat16 FN_BF16(asin)(bfloat16 x) {
  const float widened = static_cast<float>(x);
  return bfloat16(FN_FP32(asin)(widened));
}
431
440
// bfloat16 wrapper: converts x to float, applies FN_FP32(acos), and converts
// the float result back to bfloat16.
__device__ inline bfloat16 FN_BF16(acos)(bfloat16 x) {
  const float widened = static_cast<float>(x);
  return bfloat16(FN_FP32(acos)(widened));
}
432
441
// bfloat16 wrapper: converts x to float, applies FN_FP32(atan), and converts
// the float result back to bfloat16.
__device__ inline bfloat16 FN_BF16(atan)(bfloat16 x) {
  const float widened = static_cast<float>(x);
  return bfloat16(FN_FP32(atan)(widened));
}
@@ -480,7 +489,7 @@ __device__ inline float16 FN_FP16(erf)(float16 x) { return float16(FN_FP32(erf)(
480
489
// float16 wrapper: converts x to float, applies FN_FP32(tan), and converts
// the float result back to float16.
__device__ inline float16 FN_FP16(tan)(float16 x) {
  const float as_f32 = static_cast<float>(x);
  const float out_f32 = FN_FP32(tan)(as_f32);
  return float16(out_f32);
}
481
490
// float16 wrapper: converts x to float, applies FN_FP32(sinh), and converts
// the float result back to float16.
__device__ inline float16 FN_FP16(sinh)(float16 x) {
  const float as_f32 = static_cast<float>(x);
  const float out_f32 = FN_FP32(sinh)(as_f32);
  return float16(out_f32);
}
482
491
// float16 wrapper: converts x to float, applies FN_FP32(cosh), and converts
// the float result back to float16.
__device__ inline float16 FN_FP16(cosh)(float16 x) {
  const float as_f32 = static_cast<float>(x);
  const float out_f32 = FN_FP32(cosh)(as_f32);
  return float16(out_f32);
}
483
- __device__ inline float16 FN_FP16 (tanh)(float16 x) { return float16 (FN_FP32 (tanh )(static_cast <float >(x))); }
492
// float16 wrapper: converts x to float, applies FN_FP32(tanh_approx) (not
// FN_FP32(tanh)), and converts the float result back to float16.
__device__ inline float16 FN_FP16(tanh)(float16 x) {
  const float as_f32 = static_cast<float>(x);
  const float out_f32 = FN_FP32(tanh_approx)(as_f32);
  return float16(out_f32);
}
484
493
// float16 wrapper: converts x to float, applies FN_FP32(asin), and converts
// the float result back to float16.
__device__ inline float16 FN_FP16(asin)(float16 x) {
  const float as_f32 = static_cast<float>(x);
  const float out_f32 = FN_FP32(asin)(as_f32);
  return float16(out_f32);
}
485
494
// float16 wrapper: converts x to float, applies FN_FP32(acos), and converts
// the float result back to float16.
__device__ inline float16 FN_FP16(acos)(float16 x) {
  const float as_f32 = static_cast<float>(x);
  const float out_f32 = FN_FP32(acos)(as_f32);
  return float16(out_f32);
}
486
495
// float16 wrapper: converts x to float, applies FN_FP32(atan), and converts
// the float result back to float16.
__device__ inline float16 FN_FP16(atan)(float16 x) {
  const float as_f32 = static_cast<float>(x);
  const float out_f32 = FN_FP32(atan)(as_f32);
  return float16(out_f32);
}
0 commit comments