Skip to content

Commit fcfbd23

Browse files
Gregory-Meyerhwu36
andauthored
Fix host compilation of cute::cast_smem_ptr_to_uint. (PaddlePaddle#940)
* Remove references to device-only intrinsics when compiling for host. Currently, we attempt to use the `__device__`-only functions `__cvta_generic_to_shared` and `__nvvm_get_smem_pointer` when compiling `cute::cast_smem_ptr_to_uint` for the host on Clang. This results in a compilation error, as expected. This commit changes the definition of the `*_ACTIVATED` macros so that they are only true when `__CUDA_ARCH__` is defined; that is, when compiling for the device. Additionally, the declaration of `__nvvm_get_smem_pointer` is currently only visible during the device compilation pass when compiling with NVCC; this commit makes the declaration visible during host compilation with the `__device__` annotation. * Annotate cute::cast_smem_ptr_to_uint as device-only. The implementation of `cute::cast_smem_ptr_to_uint` is currently an unchecked failure on host code, and the only host implementation I can think of -- casting a probably-64-bit pointer to 32 bits somehow -- doesn't make sense to implement. This commit marks this function as device-only so that it can't be accidentally used on host code. * small change --------- Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
1 parent b250fac commit fcfbd23

File tree

1 file changed

+19
-15
lines changed

1 file changed

+19
-15
lines changed

include/cute/arch/util.hpp

+19-15
Original file line numberDiff line numberDiff line change
@@ -36,39 +36,43 @@
3636

3737
#if defined(__clang__) && defined(__CUDA__)
3838
// __cvta_generic_to_shared was added in Clang 14: https://reviews.llvm.org/D111665
39-
#define CUTE_CLANG_SUPPORTS_CVTA_GENERIC_TO_SHARED (__clang_major__ >= 14)
39+
#if __clang_major__ >= 14
40+
#define CUTE_CLANG_SUPPORTS_CVTA_GENERIC_TO_SHARED 1
41+
#endif
4042

41-
#ifndef _WIN32
4243
// __nvvm_get_smem_pointer added in Clang 14: https://reviews.llvm.org/D111665
43-
#define CUTE_CLANG_SUPPORTS_NVVM_GET_SMEM_POINTER (__clang_major__ >= 14)
44-
#else
45-
// ... but will not work on Windows until Clang 15: https://reviews.llvm.org/D122897
46-
#define CUTE_CLANG_SUPPORTS_NVVM_GET_SMEM_POINTER (__clang_major__ >= 15)
44+
// ... but will not work on Windows until Clang 15: https://reviews.llvm.org/D122897
45+
#if (!defined(_WIN32) && __clang_major__ >= 14) || __clang_major__ >= 15
46+
#define CUTE_CLANG_SUPPORTS_NVVM_GET_SMEM_POINTER 1
4747
#endif
4848
#endif
4949

5050
#if defined(__NVCC__) || defined(__CUDACC_RTC__)
5151
// __cvta_generic_to_shared added in CUDA 11+
52-
#if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 11)
52+
#if __CUDACC_VER_MAJOR__ >= 11
5353
#define CUTE_NVCC_SUPPORTS_CVTA_GENERIC_TO_SHARED 1
5454
#endif
5555

5656
// __nvvm_get_smem_pointer added in CUDA 10.2
57-
#if defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2
57+
#if __CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2
5858
#define CUTE_NVCC_SUPPORTS_NVVM_GET_SMEM_POINTER 1
5959
#endif
6060
#endif
6161

62-
#define CUTE_CVTA_GENERIC_TO_SHARED_SUPPORTED (CUTE_NVCC_SUPPORTS_CVTA_GENERIC_TO_SHARED || CUTE_CLANG_SUPPORTS_CVTA_GENERIC_TO_SHARED)
62+
#if CUTE_NVCC_SUPPORTS_CVTA_GENERIC_TO_SHARED || CUTE_CLANG_SUPPORTS_CVTA_GENERIC_TO_SHARED
63+
#define CUTE_CVTA_GENERIC_TO_SHARED_SUPPORTED 1
64+
#endif
6365

64-
#ifndef CUTE_CVTA_GENERIC_TO_SHARED_ACTIVATED
65-
#define CUTE_CVTA_GENERIC_TO_SHARED_ACTIVATED CUTE_CVTA_GENERIC_TO_SHARED_SUPPORTED
66+
#if !defined(CUTE_CVTA_GENERIC_TO_SHARED_ACTIVATED) && CUTE_CVTA_GENERIC_TO_SHARED_SUPPORTED && defined(__CUDA_ARCH__)
67+
#define CUTE_CVTA_GENERIC_TO_SHARED_ACTIVATED 1
6668
#endif
6769

68-
#define CUTE_NVVM_GET_SMEM_POINTER_SUPPORTED (CUTE_NVCC_SUPPORTS_NVVM_GET_SMEM_POINTER || CUTE_CLANG_SUPPORTS_NVVM_GET_SMEM_POINTER)
70+
#if CUTE_NVCC_SUPPORTS_NVVM_GET_SMEM_POINTER || CUTE_CLANG_SUPPORTS_NVVM_GET_SMEM_POINTER
71+
#define CUTE_NVVM_GET_SMEM_POINTER_SUPPORTED 1
72+
#endif
6973

70-
#ifndef CUTE_NVVM_GET_SMEM_POINTER_ACTIVATED
71-
#define CUTE_NVVM_GET_SMEM_POINTER_ACTIVATED CUTE_NVVM_GET_SMEM_POINTER_SUPPORTED
74+
#if !defined(CUTE_NVVM_GET_SMEM_POINTER_ACTIVATED) && CUTE_NVVM_GET_SMEM_POINTER_SUPPORTED && defined(__CUDA_ARCH__)
75+
#define CUTE_NVVM_GET_SMEM_POINTER_ACTIVATED 1
7276
#endif
7377

7478
// Clang 14+ provides a declaration of __nvvm_get_smem_pointer, so we only need
@@ -85,7 +89,7 @@ namespace cute
8589
{
8690

8791
/// CUTE helper to cast SMEM pointer to unsigned
88-
CUTE_HOST_DEVICE
92+
CUTE_DEVICE
8993
uint32_t
9094
cast_smem_ptr_to_uint(void const* const ptr)
9195
{

0 commit comments

Comments
 (0)