
Add the keyword "template" where a member template specialization appears after . or -> in a postfix expression, as required by the C++ standard #1246

Open · wants to merge 1 commit into base: main
2 changes: 1 addition & 1 deletion include/flashinfer/attention/hopper/prefill_sm90.cuh
@@ -194,7 +194,7 @@ __global__ void __launch_bounds__(Ktraits::NUM_WARPS* cutlass::NumThreadsPerWarp
shared_storage, scheduler, scheduler_params, work_tile_info, block_coord, work_idx,
num_kv_tiles_outside_items_window, num_kv_tiles_prefix);
} else {
collective_mainloop.load<LEFT_SLIDING_WINDOW>(
collective_mainloop.template load<LEFT_SLIDING_WINDOW>(
Contributor

medium

The template keyword is added here to specify that load is a template method. Consider adding the template keyword to other template method calls within this file for consistency and to avoid potential compilation issues with strict compilers.
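
For readers unfamiliar with the rule, here is a minimal, self-contained illustration; the Loader/run names are hypothetical placeholders, not the actual flashinfer types:

```cpp
// Inside a template, when the object's type depends on a template parameter,
// a member template called with explicit template arguments after `.` or `->`
// must be prefixed with `template`; otherwise `<` is parsed as the less-than
// operator and conforming compilers reject the call ([temp.names]).
template <typename Traits>
struct Loader {
  template <bool LEFT_SLIDING_WINDOW>
  void load(int work_idx) { /* ... */ }
};

template <typename Traits>
void run(Loader<Traits>& mainloop) {
  // mainloop.load<true>(0);        // ill-formed: `load` is a dependent name
  mainloop.template load<true>(0);  // OK: `template` marks it as a member template
}
```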

mainloop_params, pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v,
shared_storage, scheduler, scheduler_params, work_tile_info, block_coord, work_idx);
}
14 changes: 7 additions & 7 deletions include/flashinfer/attention/prefill.cuh
@@ -295,7 +295,7 @@ __device__ __forceinline__ void produce_kv(smem_t<KTraits::SWIZZLE_MODE_KV> smem
for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) {
#pragma unroll
for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DTypeKV)); ++j) {
smem.load_128b_async<fill_mode>(*smem_offset, *gptr, kv_idx < kv_len);
smem.template load_128b_async<fill_mode>(*smem_offset, *gptr, kv_idx < kv_len);
Contributor

medium

The template keyword is correctly added here. Consider adding the template keyword to the call to advance_offset_by_column on line 299 for consistency, as it's also a member template function call on a dependent name.

*smem_offset = smem.template advance_offset_by_column<8>(*smem_offset, j);
*gptr += 8 * upcast_size<DTypeKV>();
}
@@ -434,7 +434,7 @@ __device__ __forceinline__ void load_q_global_smem(
const uint32_t lane_idx = tid.x, warp_idx_x = get_warp_idx_q<KTraits>(tid.y);

if (get_warp_idx_kv<KTraits>(tid.z) == 0) {
uint32_t q_smem_offset_w = q_smem->get_permuted_offset<UPCAST_STRIDE_Q>(
uint32_t q_smem_offset_w = q_smem->template get_permuted_offset<UPCAST_STRIDE_Q>(
Contributor

medium

The template keyword is correctly added here. Consider adding the template keyword to the call to advance_offset_by_column on line 454 for consistency, as it's also a member template function call on a dependent name.
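
The same rule covers the -> form used for q_smem here. A minimal sketch, assuming a placeholder Smem type (not the actual smem_t interface):

```cpp
// `smem` points to a type that depends on KTraits, so an explicit template
// argument list after `->` also needs the disambiguator.
template <typename KTraits>
void advance(typename KTraits::Smem* smem, unsigned& offset, unsigned j) {
  // offset = smem->advance_offset_by_column<8>(offset, j);        // `<` parses as less-than
  offset = smem->template advance_offset_by_column<8>(offset, j);  // OK
}
```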

warp_idx_x * KTraits::NUM_MMA_Q * 16 + lane_idx / 8, lane_idx % 8);

#pragma unroll
@@ -449,7 +449,7 @@ __device__ __forceinline__ void load_q_global_smem(
#pragma unroll
for (uint32_t mma_do = 0; mma_do < KTraits::NUM_MMA_D_QK / 4; ++mma_do) {
// load q fragment from gmem to smem
q_smem->load_128b_async<SharedMemFillMode::kNoFill>(q_smem_offset_w, q_ptr,
q_smem->template load_128b_async<SharedMemFillMode::kNoFill>(q_smem_offset_w, q_ptr,
q_idx < qo_upper_bound);
q_smem_offset_w = q_smem->template advance_offset_by_column<8>(q_smem_offset_w, mma_do);
q_ptr += 8 * upcast_size<DTypeQ>();
@@ -1251,12 +1251,12 @@ __device__ __forceinline__ void write_o_reg_gmem(
vec_cast<DTypeO, float>::cast<8>((DTypeO*)o_frag_f16, o_frag[mma_q][mma_d]);

#ifdef FLASHINFER_STMATRIX_M8N8X4_ENABLED
uint32_t o_smem_offset_w = o_smem->get_permuted_offset<UPCAST_STRIDE_O>(
uint32_t o_smem_offset_w = o_smem->template get_permuted_offset<UPCAST_STRIDE_O>(
Contributor

medium

The template keyword is correctly added here. Consider adding the template keyword to the calls to get_permuted_offset on lines 1259 and 1271 for consistency, as they are also member template function calls on a dependent name.

(warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + lane_idx % 16,
mma_d * 2 + lane_idx / 16);
o_smem->stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16);
#else
uint32_t o_smem_offset_w = o_smem->get_permuted_offset<UPCAST_STRIDE_O>(
uint32_t o_smem_offset_w = o_smem->template get_permuted_offset<UPCAST_STRIDE_O>(
(warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + lane_idx / 4, mma_d * 2);
((uint32_t*)(o_smem->base + o_smem_offset_w))[lane_idx % 4] = o_frag_f16[0];
((uint32_t*)(o_smem->base + o_smem_offset_w + 8 * UPCAST_STRIDE_O))[lane_idx % 4] =
@@ -1268,7 +1268,7 @@ __device__ __forceinline__ void write_o_reg_gmem(
}
}

uint32_t o_smem_offset_w = o_smem->get_permuted_offset<UPCAST_STRIDE_O>(
uint32_t o_smem_offset_w = o_smem->template get_permuted_offset<UPCAST_STRIDE_O>(
warp_idx_x * KTraits::NUM_MMA_Q * 16 + lane_idx / 8, lane_idx % 8);

#pragma unroll
@@ -1412,7 +1412,7 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice(
? o + chunk_idx * o_stride_n + (kv_head_idx * group_size) * o_stride_h
: o + (kv_head_idx * group_size) * o_stride_h;

uint32_t q_smem_offset_r = qo_smem.get_permuted_offset<UPCAST_STRIDE_Q>(
uint32_t q_smem_offset_r = qo_smem.template get_permuted_offset<UPCAST_STRIDE_Q>(
get_warp_idx_q<KTraits>(tid.y) * NUM_MMA_Q * 16 + lane_idx % 16, lane_idx / 16);
load_q_global_smem<KTraits>(qo_packed_idx_base, qo_len, q_ptr_base, q_stride_n, q_stride_h,
group_size, &qo_smem, tid);
22 changes: 11 additions & 11 deletions include/flashinfer/sampling.cuh
@@ -292,7 +292,7 @@ __device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, u
}
max_val = max(
max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
.Reduce<VEC_SIZE>(in_data_, cub::Max()));
.template Reduce<VEC_SIZE>(in_data_, cub::Max()));
Contributor

medium

The template keyword is added here to specify that Reduce is a template method. Consider adding the template keyword to other template method calls within this file for consistency and to avoid potential compilation issues with strict compilers.
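
As a reference point, a minimal sketch of why the disambiguator is needed for the CUB calls in this file; BlockSum is a hypothetical helper, not the kernel itself:

```cpp
#include <cub/block/block_reduce.cuh>

// BLOCK_THREADS is a template parameter, so cub::BlockReduce<float, BLOCK_THREADS>
// is a dependent type and its member template Sum must be called with the
// `template` keyword when VEC_SIZE is given explicitly; without it,
// `Sum < VEC_SIZE` parses as a comparison.
template <int BLOCK_THREADS, int VEC_SIZE>
__device__ float BlockSum(float (&thread_vals)[VEC_SIZE]) {
  using BlockReduceT = cub::BlockReduce<float, BLOCK_THREADS>;
  __shared__ typename BlockReduceT::TempStorage temp_storage;
  return BlockReduceT(temp_storage).template Sum<VEC_SIZE>(thread_vals);
}
```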

__syncthreads();
}
if (tx == 0) {
@@ -599,7 +599,7 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
}
float aggregate_local =
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage->block_prim.reduce)
.Sum<VEC_SIZE>(prob_greater_than_threshold);
.template Sum<VEC_SIZE>(prob_greater_than_threshold);
if (tx == 0) {
temp_storage->block_aggregate.value = aggregate_local;
}
@@ -612,7 +612,7 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
prob_greater_than_threshold, inclusive_cdf, temp_storage);
} else {
BlockScan<float, BLOCK_THREADS, SCAN_ALGORITHM>(temp_storage->block_prim.scan)
.InclusiveSum<VEC_SIZE>(prob_greater_than_threshold, inclusive_cdf);
.template InclusiveSum<VEC_SIZE>(prob_greater_than_threshold, inclusive_cdf);

__syncthreads();
}
@@ -628,7 +628,7 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
.SubtractLeft<VEC_SIZE>(greater_than_u, greater_than_u_diff, BoolDiffOp());
#else
BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
.FlagHeads<VEC_SIZE>(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0);
.template FlagHeads<VEC_SIZE>(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0);
#endif
__syncthreads();

@@ -764,7 +764,7 @@ __global__ void SamplingFromLogitsKernel(DType* logits, IdType* output, IdType*

max_data +=
BlockReduce<DataAndIndex<DType, IdType>, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage)
.Sum<VEC_SIZE>(cur_data);
.template Sum<VEC_SIZE>(cur_data);
}
if (tx == 0) {
output[bx] = max_data.index;
@@ -1004,15 +1004,15 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, IdType*
}

aggregate_gt_pivot_0 += BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
.Sum<VEC_SIZE>(probs_gt_pivot_0);
.template Sum<VEC_SIZE>(probs_gt_pivot_0);
if (tx == 0) {
temp_storage.block_aggregate.value = aggregate_gt_pivot_0;
}
__syncthreads();
aggregate_gt_pivot_0 = temp_storage.block_aggregate.value;

aggregate_gt_pivot_1 += BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
.Sum<VEC_SIZE>(probs_gt_pivot_1);
.template Sum<VEC_SIZE>(probs_gt_pivot_1);
if (tx == 0) {
temp_storage.block_aggregate.value = aggregate_gt_pivot_1;
}
@@ -1612,12 +1612,12 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float*

aggregate_gt_pivot_0 +=
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
.Sum<VEC_SIZE>(probs_gt_pivot_0);
.template Sum<VEC_SIZE>(probs_gt_pivot_0);
__syncthreads();

aggregate_gt_pivot_1 +=
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
.Sum<VEC_SIZE>(probs_gt_pivot_1);
.template Sum<VEC_SIZE>(probs_gt_pivot_1);
__syncthreads();
}
min_gt_low = BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
@@ -1853,12 +1853,12 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*

aggregate_gt_pivot_0 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
temp_storage.block_prim.reduce_value_count)
.Sum<VEC_SIZE>(probs_gt_pivot_0_pair);
.template Sum<VEC_SIZE>(probs_gt_pivot_0_pair);
__syncthreads();

aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
temp_storage.block_prim.reduce_value_count)
.Sum<VEC_SIZE>(probs_gt_pivot_1_pair);
.template Sum<VEC_SIZE>(probs_gt_pivot_1_pair);
__syncthreads();
}
min_gt_low =