diff --git a/include/flashinfer/attention/hopper/prefill_sm90.cuh b/include/flashinfer/attention/hopper/prefill_sm90.cuh
index da7fda25b..c9bee9466 100644
--- a/include/flashinfer/attention/hopper/prefill_sm90.cuh
+++ b/include/flashinfer/attention/hopper/prefill_sm90.cuh
@@ -194,7 +194,7 @@ __global__ void __launch_bounds__(Ktraits::NUM_WARPS* cutlass::NumThreadsPerWarp
         shared_storage, scheduler, scheduler_params, work_tile_info, block_coord, work_idx,
         num_kv_tiles_outside_items_window, num_kv_tiles_prefix);
   } else {
-    collective_mainloop.load(
+    collective_mainloop.template load(
         mainloop_params, pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v,
         shared_storage, scheduler, scheduler_params, work_tile_info, block_coord, work_idx);
   }
diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh
index 135cee8e3..a5ecc0900 100644
--- a/include/flashinfer/attention/prefill.cuh
+++ b/include/flashinfer/attention/prefill.cuh
@@ -295,7 +295,7 @@ __device__ __forceinline__ void produce_kv(smem_t smem
   for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) {
 #pragma unroll
     for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DTypeKV)); ++j) {
-      smem.load_128b_async(*smem_offset, *gptr, kv_idx < kv_len);
+      smem.template load_128b_async(*smem_offset, *gptr, kv_idx < kv_len);
       *smem_offset = smem.template advance_offset_by_column<8>(*smem_offset, j);
       *gptr += 8 * upcast_size();
     }
@@ -434,7 +434,7 @@ __device__ __forceinline__ void load_q_global_smem(
   const uint32_t lane_idx = tid.x, warp_idx_x = get_warp_idx_q(tid.y);

   if (get_warp_idx_kv(tid.z) == 0) {
-    uint32_t q_smem_offset_w = q_smem->get_permuted_offset(
+    uint32_t q_smem_offset_w = q_smem->template get_permuted_offset(
         warp_idx_x * KTraits::NUM_MMA_Q * 16 + lane_idx / 8, lane_idx % 8);

 #pragma unroll
@@ -449,7 +449,7 @@ __device__ __forceinline__ void load_q_global_smem(
 #pragma unroll
       for (uint32_t mma_do = 0; mma_do < KTraits::NUM_MMA_D_QK / 4; ++mma_do) {
         // load q fragment from gmem to smem
-        q_smem->load_128b_async(q_smem_offset_w, q_ptr,
+        q_smem->template load_128b_async(q_smem_offset_w, q_ptr,
                                 q_idx < qo_upper_bound);
         q_smem_offset_w = q_smem->template advance_offset_by_column<8>(q_smem_offset_w, mma_do);
         q_ptr += 8 * upcast_size();
       }
@@ -1251,12 +1251,12 @@ __device__ __forceinline__ void write_o_reg_gmem(
       vec_cast::cast<8>((DTypeO*)o_frag_f16, o_frag[mma_q][mma_d]);
 #ifdef FLASHINFER_STMATRIX_M8N8X4_ENABLED
-      uint32_t o_smem_offset_w = o_smem->get_permuted_offset(
+      uint32_t o_smem_offset_w = o_smem->template get_permuted_offset(
           (warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + lane_idx % 16, mma_d * 2 + lane_idx / 16);
       o_smem->stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16);
 #else
-      uint32_t o_smem_offset_w = o_smem->get_permuted_offset(
+      uint32_t o_smem_offset_w = o_smem->template get_permuted_offset(
           (warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + lane_idx / 4, mma_d * 2);
       ((uint32_t*)(o_smem->base + o_smem_offset_w))[lane_idx % 4] = o_frag_f16[0];
       ((uint32_t*)(o_smem->base + o_smem_offset_w + 8 * UPCAST_STRIDE_O))[lane_idx % 4] =
@@ -1268,7 +1268,7 @@ __device__ __forceinline__ void write_o_reg_gmem(
     }
   }

-  uint32_t o_smem_offset_w = o_smem->get_permuted_offset(
+  uint32_t o_smem_offset_w = o_smem->template get_permuted_offset(
       warp_idx_x * KTraits::NUM_MMA_Q * 16 + lane_idx / 8, lane_idx % 8);

 #pragma unroll
@@ -1412,7 +1412,7 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice(
             ? o + chunk_idx * o_stride_n + (kv_head_idx * group_size) * o_stride_h
             : o + (kv_head_idx * group_size) * o_stride_h;

-    uint32_t q_smem_offset_r = qo_smem.get_permuted_offset(
+    uint32_t q_smem_offset_r = qo_smem.template get_permuted_offset(
         get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 + lane_idx % 16, lane_idx / 16);

     load_q_global_smem(qo_packed_idx_base, qo_len, q_ptr_base, q_stride_n, q_stride_h,
                        group_size, &qo_smem, tid);
diff --git a/include/flashinfer/sampling.cuh b/include/flashinfer/sampling.cuh
index f1cb12d06..79b7f9b8a 100644
--- a/include/flashinfer/sampling.cuh
+++ b/include/flashinfer/sampling.cuh
@@ -292,7 +292,7 @@ __device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, u
     }
     max_val = max(
         max_val, BlockReduce(temp_storage.block_prim.reduce)
-                     .Reduce(in_data_, cub::Max()));
+                     .template Reduce(in_data_, cub::Max()));
     __syncthreads();
   }
   if (tx == 0) {
@@ -599,7 +599,7 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
   }
   float aggregate_local = BlockReduce(temp_storage->block_prim.reduce)
-                              .Sum(prob_greater_than_threshold);
+                              .template Sum(prob_greater_than_threshold);
   if (tx == 0) {
     temp_storage->block_aggregate.value = aggregate_local;
   }
@@ -612,7 +612,7 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
         prob_greater_than_threshold, inclusive_cdf, temp_storage);
   } else {
     BlockScan(temp_storage->block_prim.scan)
-        .InclusiveSum(prob_greater_than_threshold, inclusive_cdf);
+        .template InclusiveSum(prob_greater_than_threshold, inclusive_cdf);
     __syncthreads();
   }

@@ -628,7 +628,7 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
       .SubtractLeft(greater_than_u, greater_than_u_diff, BoolDiffOp());
 #else
   BlockAdjacentDifference(temp_storage->block_prim.adj_diff)
-      .FlagHeads(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0);
+      .template FlagHeads(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0);
 #endif

   __syncthreads();
@@ -764,7 +764,7 @@ __global__ void SamplingFromLogitsKernel(DType* logits, IdType* output, IdType*
     max_data += BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage)
-                    .Sum(cur_data);
+                    .template Sum(cur_data);
   }
   if (tx == 0) {
     output[bx] = max_data.index;
   }
@@ -1004,7 +1004,7 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, IdType*
       }

       aggregate_gt_pivot_0 += BlockReduce(temp_storage.block_prim.reduce)
-                                  .Sum(probs_gt_pivot_0);
+                                  .template Sum(probs_gt_pivot_0);
       if (tx == 0) {
         temp_storage.block_aggregate.value = aggregate_gt_pivot_0;
       }
@@ -1012,7 +1012,7 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, IdType*
       aggregate_gt_pivot_0 = temp_storage.block_aggregate.value;

       aggregate_gt_pivot_1 += BlockReduce(temp_storage.block_prim.reduce)
-                                  .Sum(probs_gt_pivot_1);
+                                  .template Sum(probs_gt_pivot_1);
       if (tx == 0) {
         temp_storage.block_aggregate.value = aggregate_gt_pivot_1;
       }
@@ -1612,12 +1612,12 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float*
       aggregate_gt_pivot_0 += BlockReduce(temp_storage.block_prim.reduce)
-                                  .Sum(probs_gt_pivot_0);
+                                  .template Sum(probs_gt_pivot_0);
       __syncthreads();

       aggregate_gt_pivot_1 += BlockReduce(temp_storage.block_prim.reduce)
-                                  .Sum(probs_gt_pivot_1);
+                                  .template Sum(probs_gt_pivot_1);
       __syncthreads();
     }
     min_gt_low = BlockReduce(temp_storage.block_prim.reduce)
@@ -1853,12 +1853,12 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
       aggregate_gt_pivot_0 +=
           BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>(
               temp_storage.block_prim.reduce_value_count)
-              .Sum(probs_gt_pivot_0_pair);
+              .template Sum(probs_gt_pivot_0_pair);
       __syncthreads();

       aggregate_gt_pivot_1 +=
           BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>(
               temp_storage.block_prim.reduce_value_count)
-              .Sum(probs_gt_pivot_1_pair);
+              .template Sum(probs_gt_pivot_1_pair);
       __syncthreads();
     }
     min_gt_low =
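
Note on the pattern: every hunk above adds the `template` disambiguator to a call of a member template made through an object whose type depends on a template parameter of the enclosing kernel (`smem_t`, the CUTLASS collective mainloop, the `cub::BlockReduce`/`BlockScan`/`BlockAdjacentDifference` block primitives). When such a member template is named with an explicit template argument list (as in the existing `advance_offset_by_column<8>` calls above), a conforming compiler otherwise parses the `<` as the less-than operator, so the call must be written `obj.template member<...>(...)`. A minimal, self-contained sketch of the rule; the `Smem`, `load_vec`, and `produce` names below are hypothetical stand-ins, not FlashInfer types:

    #include <cstddef>

    // Stand-in for a templated shared-memory wrapper (hypothetical).
    template <int SwizzleMode>
    struct Smem {
      // Member template: callers may pass the vector width explicitly.
      template <size_t VecSize>
      void load_vec(size_t offset) { /* ... */ }
    };

    template <int SwizzleMode>
    void produce(Smem<SwizzleMode>& smem, size_t offset) {
      // `smem` has a dependent type, so `load_vec` is a dependent name.
      // Writing `smem.load_vec<8>(offset)` here would make the `<` parse as
      // less-than during the first phase of two-phase lookup, which strict
      // compilers reject; the `template` keyword disambiguates the call.
      smem.template load_vec<8>(offset);
    }

    int main() {
      Smem<0> s;
      produce(s, 0);
      return 0;
    }

The same reasoning applies to the CUB call chains in sampling.cuh: the block primitive's type depends on kernel template parameters such as `BLOCK_THREADS` and `REDUCE_ALGORITHM`, so its member templates need the disambiguator when invoked with explicit template arguments inside the kernel template.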