diff --git a/csrc/cuda/utils.cuh b/csrc/cuda/utils.cuh index 396b4fa1..53d328cd 100644 --- a/csrc/cuda/utils.cuh +++ b/csrc/cuda/utils.cuh @@ -6,6 +6,7 @@ AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor") #define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch") +#ifndef USE_ROCM __device__ __inline__ at::Half __shfl_up_sync(const unsigned mask, const at::Half var, const unsigned int delta) { @@ -17,6 +18,7 @@ __device__ __inline__ at::Half __shfl_down_sync(const unsigned mask, const unsigned int delta) { return __shfl_down_sync(mask, var.operator __half(), delta); } +#endif __device__ __inline__ at::Half __shfl_up(const at::Half var, const unsigned int delta) {