Skip to content

Commit a78f415

Browse files
committed
opt(insert-and-evict): thrust prefix_sum introduce cudaMalloc/cudaFree which make host wait. Replace it by cub API.
1 parent 770be38 commit a78f415

File tree

2 files changed

+18
-10
lines changed

2 files changed

+18
-10
lines changed

include/merlin/array_kernels.cuh

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#pragma once
1818

1919
#include <cooperative_groups.h>
20+
#include "cub/cub.cuh"
2021
#include "cuda_runtime.h"
2122
#include "thrust/device_vector.h"
2223
#include "thrust/execution_policy.h"
@@ -104,18 +105,13 @@ template <typename K, typename V, typename S, typename Tidx, int TILE_SIZE = 8>
104105
void gpu_boolean_mask(size_t grid_size, size_t block_size, const bool* masks,
105106
size_t n, size_t* n_evicted, Tidx* offsets,
106107
K* __restrict keys, V* __restrict values,
107-
S* __restrict scores, size_t dim, cudaStream_t stream) {
108+
S* __restrict scores, Tidx* offset_ws,
109+
size_t offset_ws_bytes, size_t dim, cudaStream_t stream) {
108110
size_t n_offsets = (n + TILE_SIZE - 1) / TILE_SIZE;
109111
gpu_cell_count<Tidx, TILE_SIZE>
110112
<<<grid_size, block_size, 0, stream>>>(masks, offsets, n, n_evicted);
111-
#if THRUST_VERSION >= 101600
112-
auto policy = thrust::cuda::par_nosync.on(stream);
113-
#else
114-
auto policy = thrust::cuda::par.on(stream);
115-
#endif
116-
thrust::device_ptr<Tidx> d_src(offsets);
117-
thrust::device_ptr<Tidx> d_dest(offsets);
118-
thrust::exclusive_scan(policy, d_src, d_src + n_offsets, d_dest);
113+
CUDA_CHECK(cub::DeviceScan::ExclusiveSum(offset_ws, offset_ws_bytes, offsets,
114+
offsets, n_offsets, stream));
119115
gpu_select_kvm_kernel<K, V, S, Tidx, TILE_SIZE>
120116
<<<grid_size, block_size, 0, stream>>>(masks, n, offsets, keys, values,
121117
scores, dim);

include/merlin_hashtable.cuh

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include <mutex>
2727
#include <shared_mutex>
2828
#include <type_traits>
29+
#include "cub/cub.cuh"
2930
#include "merlin/allocator.cuh"
3031
#include "merlin/array_kernels.cuh"
3132
#include "merlin/core_kernels.cuh"
@@ -598,9 +599,20 @@ class HashTable {
598599

599600
keys_not_empty<K>
600601
<<<grid_size, block_size, 0, stream>>>(evicted_keys, d_masks, n);
602+
603+
void* d_temp_storage = nullptr;
604+
size_t temp_storage_bytes = 0;
605+
CUDA_CHECK(cub::DeviceScan::ExclusiveSum(d_temp_storage,
606+
temp_storage_bytes, d_offsets,
607+
d_offsets, n_offsets, stream));
608+
auto helper_ws{
609+
dev_mem_pool_->get_workspace<1>(temp_storage_bytes, stream)};
610+
int64_t* d_temp_storage_i64 = helper_ws.get<int64_t*>(0);
611+
601612
gpu_boolean_mask<K, V, S, int64_t, TILE_SIZE>(
602613
grid_size, block_size, d_masks, n, d_evicted_counter, d_offsets,
603-
evicted_keys, evicted_values, evicted_scores, dim(), stream);
614+
evicted_keys, evicted_values, evicted_scores, d_temp_storage_i64,
615+
temp_storage_bytes, dim(), stream);
604616
}
605617
return;
606618
}

0 commit comments

Comments
 (0)