Skip to content

Commit b2bfa44

Browse files
authored
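Fix bug in weight_only_linear tuning (#71017)
* Fix bug in weight_only_linear tuning: cap the profiled M (and MoE total_rows) at getMaxProfileM(), and raise that limit from 256 to 1024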
1 parent 302f9c2 commit b2bfa44

File tree

3 files changed

+6
-3
lines changed

3 files changed

+6
-3
lines changed

paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.cu

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -617,7 +617,8 @@ void CutlassFpAIntBGemmRunner<T, WeightType>::run_gemm<EpilogueTag,
617617
} else {
618618
float best_time = std::numeric_limits<float>::max();
619619
CutlassGemmConfig best_config;
620-
int profile_m = gemmConfigManager.nextPowerOfTwo(m);
620+
int profile_m = std::min(gemmConfigManager.nextPowerOfTwo(m),
621+
gemmConfigManager.getMaxProfileM());
621622
bool found_one = false;
622623

623624
for (size_t ii = 0; ii < candidate_configs.size(); ++ii) {

paddle/phi/kernels/fusion/cutlass/cutlass_kernels/gemm_config_manager.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -245,7 +245,7 @@ class GemmConfigManager {
245245
return ++v;
246246
}
247247

248-
int getMaxProfileM() const { return 256; }
248+
int getMaxProfileM() const { return 1024; }
249249

250250
bool loadFromJson(const std::string& filename) {
251251
std::ifstream inFile(filename);

paddle/phi/kernels/fusion/cutlass/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -960,7 +960,9 @@ void MoeGemmRunner<T, WeightType>::run_gemm<EpilogueTag>(
960960
} else {
961961
float best_time = std::numeric_limits<float>::max();
962962
CutlassGemmConfig best_config;
963-
int profile_total_rows = gemmConfigManager.nextPowerOfTwo(total_rows);
963+
int profile_total_rows =
964+
std::min(gemmConfigManager.nextPowerOfTwo(total_rows),
965+
gemmConfigManager.getMaxProfileM());
964966
bool found_one = false;
965967

966968
for (size_t ii = 0; ii < candidate_configs.size(); ++ii) {

0 commit comments

Comments
 (0)