@@ -17,6 +17,7 @@
 #pragma once
 
 #include <cooperative_groups.h>
+#include "cub/cub.cuh"
 #include "cuda_runtime.h"
 #include "thrust/device_vector.h"
 #include "thrust/execution_policy.h"
@@ -104,18 +105,11 @@ template <typename K, typename V, typename S, typename Tidx, int TILE_SIZE = 8>
 void gpu_boolean_mask(size_t grid_size, size_t block_size, const bool* masks,
                       size_t n, size_t* n_evicted, Tidx* offsets,
                       K* __restrict keys, V* __restrict values,
-                      S* __restrict scores, size_t dim, cudaStream_t stream) {
+                      S* __restrict scores, Tidx* offset_ws, size_t offset_ws_bytes, size_t dim, cudaStream_t stream) {
   size_t n_offsets = (n + TILE_SIZE - 1) / TILE_SIZE;
   gpu_cell_count<Tidx, TILE_SIZE>
       <<<grid_size, block_size, 0, stream>>>(masks, offsets, n, n_evicted);
-#if THRUST_VERSION >= 101600
-  auto policy = thrust::cuda::par_nosync.on(stream);
-#else
-  auto policy = thrust::cuda::par.on(stream);
-#endif
-  thrust::device_ptr<Tidx> d_src(offsets);
-  thrust::device_ptr<Tidx> d_dest(offsets);
-  thrust::exclusive_scan(policy, d_src, d_src + n_offsets, d_dest);
+  cub::DeviceScan::ExclusiveSum(offset_ws, offset_ws_bytes, offsets, offsets, n_offsets, stream);
   gpu_select_kvm_kernel<K, V, S, Tidx, TILE_SIZE>
       <<<grid_size, block_size, 0, stream>>>(masks, n, offsets, keys, values,
                                              scores, dim);
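
The new `offset_ws` / `offset_ws_bytes` parameters follow CUB's standard two-phase pattern: call `cub::DeviceScan::ExclusiveSum` once with a null workspace pointer to query the required temporary-storage size, allocate at least that many bytes on the device, then pass the buffer in. A minimal caller-side sketch is below; the wrapper name, the `cudaMalloc`/`cudaFree` allocation strategy, and the explicit template parameters are illustrative assumptions, not part of this change.

```cpp
#include <cuda_runtime.h>
#include <cub/cub.cuh>

// Illustrative helper (assumption, not from this commit): size the CUB scan
// workspace with a null-pointer query, allocate it, and forward it to
// gpu_boolean_mask, which now runs the exclusive scan with that buffer.
template <typename K, typename V, typename S, typename Tidx, int TILE_SIZE = 8>
void boolean_mask_with_workspace(size_t grid_size, size_t block_size,
                                 const bool* masks, size_t n, size_t* n_evicted,
                                 Tidx* offsets, K* keys, V* values, S* scores,
                                 size_t dim, cudaStream_t stream) {
  size_t n_offsets = (n + TILE_SIZE - 1) / TILE_SIZE;
  size_t offset_ws_bytes = 0;
  // Pass 1: a null d_temp_storage only writes the required byte count.
  cub::DeviceScan::ExclusiveSum(nullptr, offset_ws_bytes, offsets, offsets,
                                n_offsets, stream);
  Tidx* offset_ws = nullptr;
  cudaMalloc(reinterpret_cast<void**>(&offset_ws), offset_ws_bytes);
  // Pass 2 happens inside gpu_boolean_mask, which receives the sized buffer.
  gpu_boolean_mask<K, V, S, Tidx, TILE_SIZE>(
      grid_size, block_size, masks, n, n_evicted, offsets, keys, values,
      scores, offset_ws, offset_ws_bytes, dim, stream);
  cudaFree(offset_ws);
}
```

In practice a caller would more likely reuse a persistent workspace buffer across calls rather than allocate and free each time; moving the allocation out of `gpu_boolean_mask` presumably exists to allow exactly that.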