Skip to content

Commit a3b4ab4

Browse files
committed
Simplify thrust invocation.
1 parent a0de3ac commit a3b4ab4

File tree

1 file changed

+7
-13
lines changed

1 file changed

+7
-13
lines changed

include/merlin_hashtable.cuh

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -153,12 +153,6 @@ class HashTable {
153153
using DeviceMemoryPool = MemoryPool<DeviceAllocator<char>>;
154154
using HostMemoryPool = MemoryPool<HostAllocator<char>>;
155155

156-
#if THRUST_VERSION >= 101600
157-
static constexpr auto thrust_par = thrust::cuda::par_nosync;
158-
#else
159-
static constexpr auto thrust_par = thrust::cuda::par;
160-
#endif
161-
162156
public:
163157
/**
164158
* @brief Default constructor for the hash table class.
@@ -323,7 +317,7 @@ class HashTable {
323317
reinterpret_cast<uintptr_t*>(d_dst));
324318
thrust::device_ptr<int> d_src_offset_ptr(d_src_offset);
325319

326-
thrust::sort_by_key(thrust_par.on(stream), d_dst_ptr, d_dst_ptr + n,
320+
thrust::sort_by_key(thrust::cuda::par_nosync.on(stream), d_dst_ptr, d_dst_ptr + n,
327321
d_src_offset_ptr, thrust::less<uintptr_t>());
328322
}
329323

@@ -557,7 +551,7 @@ class HashTable {
557551
thrust::device_ptr<uintptr_t> dst_ptr(reinterpret_cast<uintptr_t*>(dst));
558552
thrust::device_ptr<int> src_offset_ptr(src_offset);
559553

560-
thrust::sort_by_key(thrust_par.on(stream), dst_ptr, dst_ptr + n,
554+
thrust::sort_by_key(thrust::cuda::par_nosync.on(stream), dst_ptr, dst_ptr + n,
561555
src_offset_ptr, thrust::less<uintptr_t>());
562556
}
563557

@@ -651,7 +645,7 @@ class HashTable {
651645
reinterpret_cast<uintptr_t*>(d_table_value_addrs));
652646
thrust::device_ptr<int> param_key_index_ptr(param_key_index);
653647

654-
thrust::sort_by_key(thrust_par.on(stream), table_value_ptr,
648+
thrust::sort_by_key(thrust::cuda::par_nosync.on(stream), table_value_ptr,
655649
table_value_ptr + n, param_key_index_ptr,
656650
thrust::less<uintptr_t>());
657651
}
@@ -821,7 +815,7 @@ class HashTable {
821815
reinterpret_cast<uintptr_t*>(d_dst));
822816
thrust::device_ptr<int> d_src_offset_ptr(d_src_offset);
823817

824-
thrust::sort_by_key(thrust_par.on(stream), d_dst_ptr, d_dst_ptr + n,
818+
thrust::sort_by_key(thrust::cuda::par_nosync.on(stream), d_dst_ptr, d_dst_ptr + n,
825819
d_src_offset_ptr, thrust::less<uintptr_t>());
826820
}
827821

@@ -922,7 +916,7 @@ class HashTable {
922916
reinterpret_cast<uintptr_t*>(src));
923917
thrust::device_ptr<int> dst_offset_ptr(dst_offset);
924918

925-
thrust::sort_by_key(thrust_par.on(stream), src_ptr, src_ptr + n,
919+
thrust::sort_by_key(thrust::cuda::par_nosync.on(stream), src_ptr, src_ptr + n,
926920
dst_offset_ptr, thrust::less<uintptr_t>());
927921
}
928922

@@ -1275,7 +1269,7 @@ class HashTable {
12751269

12761270
for (size_type start_i = 0; start_i < N; start_i += step) {
12771271
size_type end_i = std::min(start_i + step, N);
1278-
h_size += thrust::reduce(thrust_par.on(stream), size_ptr + start_i,
1272+
h_size += thrust::reduce(thrust::cuda::par_nosync.on(stream), size_ptr + start_i,
12791273
size_ptr + end_i, 0, thrust::plus<int>());
12801274
}
12811275

@@ -1589,7 +1583,7 @@ class HashTable {
15891583

15901584
thrust::device_ptr<int> size_ptr(table_->buckets_size);
15911585

1592-
int size = thrust::reduce(thrust_par.on(stream), size_ptr, size_ptr + N, 0,
1586+
int size = thrust::reduce(thrust::cuda::par_nosync.on(stream), size_ptr, size_ptr + N, 0,
15931587
thrust::plus<int>());
15941588

15951589
CudaCheckError();

0 commit comments

Comments
 (0)