Skip to content

Commit 0c4a400

Browse files
authored
[CINN] Use approx tanh for fp16/bf16 (#72871)
1 parent 73173cc commit 0c4a400

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

paddle/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -100,6 +100,15 @@ __device__ inline float FN_FP32(rcp)(float x) {
100100
asm("rcp.approx.ftz.f32 %0, %1;" : "=f"(res) : "f"(x));
101101
return res;
102102
}
103+
// Fast single-precision tanh.
//
// When compiling device code for __CUDA_ARCH__ >= 750 this emits the PTX
// `tanh.approx.f32` instruction, which trades accuracy for speed. Every other
// compilation pass (older architectures, or passes where __CUDA_ARCH__ is not
// defined) falls back to the CUDA math library.
//
// Intended for the reduced-precision (fp16/bf16) tanh wrappers, where the
// result is narrowed after the call anyway.
//
// x: input value.
// Returns tanh(x), approximate on the PTX path.
__device__ inline float FN_FP32(tanh_approx)(float x) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
  float res;
  // No memory side effects, so a plain (non-volatile) asm statement suffices.
  asm("tanh.approx.f32 %0, %1;" : "=f"(res) : "f"(x));
  return res;
#else
  // Call the explicitly single-precision routine so the fallback can never be
  // resolved to the double-precision tanh overload.
  return tanhf(x);
#endif
}
103112

104113
// *************************************************************** //
105114
// float64 unary and binary operator
@@ -426,7 +435,7 @@ __device__ inline bfloat16 FN_BF16(erf)(bfloat16 x) { return bfloat16(FN_FP32(er
426435
// bfloat16 tan: widen the input to float, evaluate FN_FP32(tan), and narrow
// the result back to bfloat16.
__device__ inline bfloat16 FN_BF16(tan)(bfloat16 x) {
  const float widened = static_cast<float>(x);
  return bfloat16(FN_FP32(tan)(widened));
}
427436
// bfloat16 sinh: widen the input to float, evaluate FN_FP32(sinh), and narrow
// the result back to bfloat16.
__device__ inline bfloat16 FN_BF16(sinh)(bfloat16 x) {
  const float widened = static_cast<float>(x);
  return bfloat16(FN_FP32(sinh)(widened));
}
428437
// bfloat16 cosh: widen the input to float, evaluate FN_FP32(cosh), and narrow
// the result back to bfloat16.
__device__ inline bfloat16 FN_BF16(cosh)(bfloat16 x) {
  const float widened = static_cast<float>(x);
  return bfloat16(FN_FP32(cosh)(widened));
}
429-
__device__ inline bfloat16 FN_BF16(tanh)(bfloat16 x) { return bfloat16(FN_FP32(tanh)(static_cast<float>(x))); }
438+
// bfloat16 tanh: widen the input to float, evaluate it with
// FN_FP32(tanh_approx) (not FN_FP32(tanh)), and narrow the result back to
// bfloat16.
__device__ inline bfloat16 FN_BF16(tanh)(bfloat16 x) {
  const float widened = static_cast<float>(x);
  return bfloat16(FN_FP32(tanh_approx)(widened));
}
430439
// bfloat16 asin: widen the input to float, evaluate FN_FP32(asin), and narrow
// the result back to bfloat16.
__device__ inline bfloat16 FN_BF16(asin)(bfloat16 x) {
  const float widened = static_cast<float>(x);
  return bfloat16(FN_FP32(asin)(widened));
}
431440
// bfloat16 acos: widen the input to float, evaluate FN_FP32(acos), and narrow
// the result back to bfloat16.
__device__ inline bfloat16 FN_BF16(acos)(bfloat16 x) {
  const float widened = static_cast<float>(x);
  return bfloat16(FN_FP32(acos)(widened));
}
432441
// bfloat16 atan: widen the input to float, evaluate FN_FP32(atan), and narrow
// the result back to bfloat16.
__device__ inline bfloat16 FN_BF16(atan)(bfloat16 x) {
  const float widened = static_cast<float>(x);
  return bfloat16(FN_FP32(atan)(widened));
}
@@ -480,7 +489,7 @@ __device__ inline float16 FN_FP16(erf)(float16 x) { return float16(FN_FP32(erf)(
480489
// float16 tan: promote the input to float, evaluate FN_FP32(tan), and convert
// the result back to float16.
__device__ inline float16 FN_FP16(tan)(float16 x) {
  const float as_f32 = static_cast<float>(x);
  return float16(FN_FP32(tan)(as_f32));
}
481490
// float16 sinh: promote the input to float, evaluate FN_FP32(sinh), and
// convert the result back to float16.
__device__ inline float16 FN_FP16(sinh)(float16 x) {
  const float as_f32 = static_cast<float>(x);
  return float16(FN_FP32(sinh)(as_f32));
}
482491
// float16 cosh: promote the input to float, evaluate FN_FP32(cosh), and
// convert the result back to float16.
__device__ inline float16 FN_FP16(cosh)(float16 x) {
  const float as_f32 = static_cast<float>(x);
  return float16(FN_FP32(cosh)(as_f32));
}
483-
__device__ inline float16 FN_FP16(tanh)(float16 x) { return float16(FN_FP32(tanh)(static_cast<float>(x))); }
492+
// float16 tanh: promote the input to float, evaluate it with
// FN_FP32(tanh_approx) (not FN_FP32(tanh)), and convert the result back to
// float16.
__device__ inline float16 FN_FP16(tanh)(float16 x) {
  const float as_f32 = static_cast<float>(x);
  return float16(FN_FP32(tanh_approx)(as_f32));
}
484493
// float16 asin: promote the input to float, evaluate FN_FP32(asin), and
// convert the result back to float16.
__device__ inline float16 FN_FP16(asin)(float16 x) {
  const float as_f32 = static_cast<float>(x);
  return float16(FN_FP32(asin)(as_f32));
}
485494
// float16 acos: promote the input to float, evaluate FN_FP32(acos), and
// convert the result back to float16.
__device__ inline float16 FN_FP16(acos)(float16 x) {
  const float as_f32 = static_cast<float>(x);
  return float16(FN_FP32(acos)(as_f32));
}
486495
// float16 atan: promote the input to float, evaluate FN_FP32(atan), and
// convert the result back to float16.
__device__ inline float16 FN_FP16(atan)(float16 x) {
  const float as_f32 = static_cast<float>(x);
  return float16(FN_FP32(atan)(as_f32));
}

0 commit comments

Comments
 (0)