Add the `template` keyword to member template specializations that appear after `.` or `->` in a postfix expression, as required by the C++ standard
#1246
base: main
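For context, the rule the title refers to, as a minimal standalone sketch (the names below are hypothetical, not taken from this diff): when a member template is invoked through `.` or `->` on an object whose type depends on a template parameter, the call must be prefixed with the `template` keyword; otherwise the compiler parses the `<` as a less-than operator.

#include <cstdint>

template <typename T>
struct Storage {
  // Member template: at a dependent call site the angle brackets are
  // ambiguous unless the compiler already knows `get` names a template.
  template <uint32_t N>
  uint32_t get(uint32_t base) const {
    return base + N;
  }
};

template <typename T>
uint32_t probe(const Storage<T>& s) {
  // return s.get<0>(0);        // ill-formed: parsed as (s.get < 0) > (0)
  return s.template get<0>(0);  // OK: `template` marks `get` as a member template
}

int main() { return static_cast<int>(probe(Storage<int>{})); }

GCC and Clang reject the commented-out form; MSVC has historically accepted it, which is why such code can compile with one toolchain and fail with another.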
@@ -295,7 +295,7 @@ __device__ __forceinline__ void produce_kv(smem_t<KTraits::SWIZZLE_MODE_KV> smem
   for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) {
 #pragma unroll
     for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DTypeKV)); ++j) {
-      smem.load_128b_async<fill_mode>(*smem_offset, *gptr, kv_idx < kv_len);
+      smem.template load_128b_async<fill_mode>(*smem_offset, *gptr, kv_idx < kv_len);
       *smem_offset = smem.template advance_offset_by_column<8>(*smem_offset, j);
       *gptr += 8 * upcast_size<DTypeKV>();
     }
@@ -434,7 +434,7 @@ __device__ __forceinline__ void load_q_global_smem(
   const uint32_t lane_idx = tid.x, warp_idx_x = get_warp_idx_q<KTraits>(tid.y);

   if (get_warp_idx_kv<KTraits>(tid.z) == 0) {
-    uint32_t q_smem_offset_w = q_smem->get_permuted_offset<UPCAST_STRIDE_Q>(
+    uint32_t q_smem_offset_w = q_smem->template get_permuted_offset<UPCAST_STRIDE_Q>(
         warp_idx_x * KTraits::NUM_MMA_Q * 16 + lane_idx / 8, lane_idx % 8);

 #pragma unroll
@@ -449,7 +449,7 @@ __device__ __forceinline__ void load_q_global_smem(
 #pragma unroll
       for (uint32_t mma_do = 0; mma_do < KTraits::NUM_MMA_D_QK / 4; ++mma_do) {
         // load q fragment from gmem to smem
-        q_smem->load_128b_async<SharedMemFillMode::kNoFill>(q_smem_offset_w, q_ptr,
+        q_smem->template load_128b_async<SharedMemFillMode::kNoFill>(q_smem_offset_w, q_ptr,
                                                             q_idx < qo_upper_bound);
         q_smem_offset_w = q_smem->template advance_offset_by_column<8>(q_smem_offset_w, mma_do);
         q_ptr += 8 * upcast_size<DTypeQ>();
@@ -1251,12 +1251,12 @@ __device__ __forceinline__ void write_o_reg_gmem(
         vec_cast<DTypeO, float>::cast<8>((DTypeO*)o_frag_f16, o_frag[mma_q][mma_d]);

 #ifdef FLASHINFER_STMATRIX_M8N8X4_ENABLED
-        uint32_t o_smem_offset_w = o_smem->get_permuted_offset<UPCAST_STRIDE_O>(
+        uint32_t o_smem_offset_w = o_smem->template get_permuted_offset<UPCAST_STRIDE_O>(
            (warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + lane_idx % 16,
            mma_d * 2 + lane_idx / 16);
         o_smem->stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16);
 #else
-        uint32_t o_smem_offset_w = o_smem->get_permuted_offset<UPCAST_STRIDE_O>(
+        uint32_t o_smem_offset_w = o_smem->template get_permuted_offset<UPCAST_STRIDE_O>(
            (warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + lane_idx / 4, mma_d * 2);
         ((uint32_t*)(o_smem->base + o_smem_offset_w))[lane_idx % 4] = o_frag_f16[0];
         ((uint32_t*)(o_smem->base + o_smem_offset_w + 8 * UPCAST_STRIDE_O))[lane_idx % 4] =
@@ -1268,7 +1268,7 @@ __device__ __forceinline__ void write_o_reg_gmem(
     }
   }

-  uint32_t o_smem_offset_w = o_smem->get_permuted_offset<UPCAST_STRIDE_O>(
+  uint32_t o_smem_offset_w = o_smem->template get_permuted_offset<UPCAST_STRIDE_O>(
       warp_idx_x * KTraits::NUM_MMA_Q * 16 + lane_idx / 8, lane_idx % 8);

 #pragma unroll
@@ -1412,7 +1412,7 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice(
           ? o + chunk_idx * o_stride_n + (kv_head_idx * group_size) * o_stride_h
           : o + (kv_head_idx * group_size) * o_stride_h;

-  uint32_t q_smem_offset_r = qo_smem.get_permuted_offset<UPCAST_STRIDE_Q>(
+  uint32_t q_smem_offset_r = qo_smem.template get_permuted_offset<UPCAST_STRIDE_Q>(
       get_warp_idx_q<KTraits>(tid.y) * NUM_MMA_Q * 16 + lane_idx % 16, lane_idx / 16);
   load_q_global_smem<KTraits>(qo_packed_idx_base, qo_len, q_ptr_base, q_stride_n, q_stride_h,
                               group_size, &qo_smem, tid);
@@ -292,7 +292,7 @@ __device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, u
     }
     max_val = max(
         max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
-                     .Reduce<VEC_SIZE>(in_data_, cub::Max()));
+                     .template Reduce<VEC_SIZE>(in_data_, cub::Max()));
     __syncthreads();
   }
   if (tx == 0) {
@@ -599,7 +599,7 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
   }
   float aggregate_local =
       BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage->block_prim.reduce)
-          .Sum<VEC_SIZE>(prob_greater_than_threshold);
+          .template Sum<VEC_SIZE>(prob_greater_than_threshold);
   if (tx == 0) {
     temp_storage->block_aggregate.value = aggregate_local;
   }
@@ -612,7 +612,7 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
         prob_greater_than_threshold, inclusive_cdf, temp_storage);
   } else {
     BlockScan<float, BLOCK_THREADS, SCAN_ALGORITHM>(temp_storage->block_prim.scan)
-        .InclusiveSum<VEC_SIZE>(prob_greater_than_threshold, inclusive_cdf);
+        .template InclusiveSum<VEC_SIZE>(prob_greater_than_threshold, inclusive_cdf);

     __syncthreads();
   }
@@ -628,7 +628,7 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
         .SubtractLeft<VEC_SIZE>(greater_than_u, greater_than_u_diff, BoolDiffOp());
 #else
     BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
-        .FlagHeads<VEC_SIZE>(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0);
+        .template FlagHeads<VEC_SIZE>(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0);
 #endif
     __syncthreads();

@@ -764,7 +764,7 @@ __global__ void SamplingFromLogitsKernel(DType* logits, IdType* output, IdType*

     max_data +=
         BlockReduce<DataAndIndex<DType, IdType>, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage)
-            .Sum<VEC_SIZE>(cur_data);
+            .template Sum<VEC_SIZE>(cur_data);
   }
   if (tx == 0) {
     output[bx] = max_data.index;
@@ -1004,15 +1004,15 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, IdType*
       }

       aggregate_gt_pivot_0 += BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
-                                  .Sum<VEC_SIZE>(probs_gt_pivot_0);
+                                  .template Sum<VEC_SIZE>(probs_gt_pivot_0);
       if (tx == 0) {
         temp_storage.block_aggregate.value = aggregate_gt_pivot_0;
       }
       __syncthreads();
       aggregate_gt_pivot_0 = temp_storage.block_aggregate.value;

       aggregate_gt_pivot_1 += BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
-                                  .Sum<VEC_SIZE>(probs_gt_pivot_1);
+                                  .template Sum<VEC_SIZE>(probs_gt_pivot_1);
       if (tx == 0) {
         temp_storage.block_aggregate.value = aggregate_gt_pivot_1;
       }
@@ -1612,12 +1612,12 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float*

       aggregate_gt_pivot_0 +=
           BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
-              .Sum<VEC_SIZE>(probs_gt_pivot_0);
+              .template Sum<VEC_SIZE>(probs_gt_pivot_0);
       __syncthreads();

       aggregate_gt_pivot_1 +=
           BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
-              .Sum<VEC_SIZE>(probs_gt_pivot_1);
+              .template Sum<VEC_SIZE>(probs_gt_pivot_1);
       __syncthreads();
     }
     min_gt_low = BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
@@ -1853,12 +1853,12 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*

       aggregate_gt_pivot_0 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
                                   temp_storage.block_prim.reduce_value_count)
-                                  .Sum<VEC_SIZE>(probs_gt_pivot_0_pair);
+                                  .template Sum<VEC_SIZE>(probs_gt_pivot_0_pair);
       __syncthreads();

       aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
                                   temp_storage.block_prim.reduce_value_count)
-                                  .Sum<VEC_SIZE>(probs_gt_pivot_1_pair);
+                                  .template Sum<VEC_SIZE>(probs_gt_pivot_1_pair);
       __syncthreads();
     }
     min_gt_low =
The `template` keyword is added here to specify that `load` is a template method. Consider adding the `template` keyword to other template method calls within this file for consistency and to avoid potential compilation issues with strict compilers.
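To illustrate the reviewer's point in isolation, a hedged sketch of the same pattern (the `SmemLike` and `Traits` types below are simplified hypothetical stand-ins, not the actual `smem_t` from this repository): the disambiguator is only required when the object's type depends on a template parameter of the enclosing template, which is exactly the situation for `smem_t<KTraits::SWIZZLE_MODE_KV>` in `produce_kv`.

#include <cstdint>

enum class SharedMemFillMode { kNoFill, kFillZero };

template <int SWIZZLE_MODE>
struct SmemLike {
  template <SharedMemFillMode fill_mode>
  void load_128b_async(uint32_t offset, const void* gptr, bool pred) {
    // the real kernel would issue an async 128-bit copy here
    (void)offset; (void)gptr; (void)pred;
  }
};

template <typename KTraits>
void produce(SmemLike<0>& fixed, SmemLike<KTraits::SWIZZLE_MODE>& dependent,
             uint32_t off, const void* g, bool p) {
  // Non-dependent type: no disambiguator needed.
  fixed.load_128b_async<SharedMemFillMode::kNoFill>(off, g, p);
  // Dependent type: `template` is mandatory; GCC and Clang reject the
  // call without it, while more permissive front ends let it through.
  dependent.template load_128b_async<SharedMemFillMode::kNoFill>(off, g, p);
}

struct Traits { static constexpr int SWIZZLE_MODE = 0; };

int main() {
  SmemLike<0> a, b;
  produce<Traits>(a, b, 0, nullptr, false);
  return 0;
}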