From e05c077ec6a6d4527977b3af2bda87c6af653b46 Mon Sep 17 00:00:00 2001 From: gaoziyuan Date: Tue, 8 Apr 2025 22:52:40 +0800 Subject: [PATCH 1/4] add sm90 config --- .../transforms/gpu/fused_weight_only_linear_pass.cc | 2 +- paddle/phi/infermeta/unary.cc | 3 ++- paddle/phi/kernels/gpu/weight_only_linear_kernel.cu | 6 ++++-- paddle/phi/kernels/gpu/weight_quantize_kernel.cu | 5 +++-- .../kernels/impl/weight_quantize_kernel_gpu_impl.h | 3 ++- python/paddle/nn/quant/quantized_linear.py | 11 ++++++++--- 6 files changed, 20 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/pir/transforms/gpu/fused_weight_only_linear_pass.cc b/paddle/fluid/pir/transforms/gpu/fused_weight_only_linear_pass.cc index 5babd4072a7b05..1904cbfbbb5722 100644 --- a/paddle/fluid/pir/transforms/gpu/fused_weight_only_linear_pass.cc +++ b/paddle/fluid/pir/transforms/gpu/fused_weight_only_linear_pass.cc @@ -311,7 +311,7 @@ class FusedWeightOnlyLinearPass : public pir::PatternRewritePass { bool CanApplyOn(pir::Operation *op) const override { if (sm_version_ != 70 && sm_version_ != 75 && sm_version_ != 80 && - sm_version_ != 86) { + sm_version_ != 86 && sm_version_ != 89 && sm_version_ != 90) { return false; } return op->num_regions() > 0; diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 90c7e2726362b2..a8418ab6c45ed6 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -5783,7 +5783,8 @@ void WeightQuantizeInferMeta(const MetaTensor& x, MetaTensor* out, MetaTensor* scale) { PADDLE_ENFORCE_EQ( - ((arch == 80) || (arch == 86) || (arch == 70) || (arch == 75)), + ((arch == 70) || (arch == 75) || (arch == 80) || (arch == 86) || + (arch == 89) || (arch == 90)), true, phi::errors::InvalidArgument( "Currently, arch only support 70, 75, 80, 86.")); diff --git a/paddle/phi/kernels/gpu/weight_only_linear_kernel.cu b/paddle/phi/kernels/gpu/weight_only_linear_kernel.cu index 901a291d3924db..13ce482fbdc01b 100644 --- a/paddle/phi/kernels/gpu/weight_only_linear_kernel.cu +++ b/paddle/phi/kernels/gpu/weight_only_linear_kernel.cu @@ -35,9 +35,11 @@ void WeightOnlyLinearKernel(const Context& dev_ctx, DenseTensor* out) { #if defined(PADDLE_WITH_CUTLASS) PADDLE_ENFORCE_EQ( - ((arch == 80) || (arch == 70) || (arch == 75) || (arch == 86)), + ((arch == 70) || (arch == 75) || (arch == 80) || (arch == 86) || + (arch == 89) || (arch == 90)), true, - phi::errors::InvalidArgument("Currently, arch only support 70, 80.")); + phi::errors::InvalidArgument( + "Currently, arch only support 70, 75, 80, 86, 89, 90.")); #else PADDLE_THROW(phi::errors::Unimplemented( "Please compile with cutlass to make cutlass available")); diff --git a/paddle/phi/kernels/gpu/weight_quantize_kernel.cu b/paddle/phi/kernels/gpu/weight_quantize_kernel.cu index 51b4786155a923..4d1efc407c1cd0 100644 --- a/paddle/phi/kernels/gpu/weight_quantize_kernel.cu +++ b/paddle/phi/kernels/gpu/weight_quantize_kernel.cu @@ -45,10 +45,11 @@ void WeightQuantizeKernel(const Context& dev_ctx, std::vector weight_shape{static_cast(x.dims()[0]), static_cast(x.dims()[1])}; PADDLE_ENFORCE_EQ( - ((arch == 80) || (arch == 86) || (arch == 75) || (arch == 70)), + ((arch == 70) || (arch == 75) || (arch == 80) || (arch == 86) || + (arch == 89) || (arch == 90)), true, phi::errors::InvalidArgument( - "Currently, arch only support 70, 75, 80, 86.")); + "Currently, arch only support 70, 75, 80, 86, 89, 90.")); if (algo == "llm.int8") { dev_ctx.template Alloc(scale); diff --git a/paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h 
b/paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h index 05d0e47b314555..80365d99349eb2 100644 --- a/paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h +++ b/paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h @@ -97,7 +97,8 @@ void weight_permute_gpu(const GPUContext& dev_ctx, auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, 1); int grid_size = gpu_config.GetGridSize(); int block_size = gpu_config.GetBlockSize(); - if ((arch == 80) || (arch == 86) || (arch == 75)) { + if ((arch == 90) || (arch == 89) || (arch == 86) || (arch == 80) || + (arch == 75)) { weight_permute_kernel_wint8<<>>( input_data, output_data, numel, total_k, total_n); } else if (arch == 70) { diff --git a/python/paddle/nn/quant/quantized_linear.py b/python/paddle/nn/quant/quantized_linear.py index 41ad1839e1f8a4..c093b7041faa9c 100644 --- a/python/paddle/nn/quant/quantized_linear.py +++ b/python/paddle/nn/quant/quantized_linear.py @@ -69,9 +69,14 @@ def weight_quantize(x, algo="weight_only_int8", arch=None, group_size=-1): arch = _get_arch_info() assert ( - arch == 70 or arch == 80 or arch == 86 or arch == 75 - ), f"Currently weight_quantize only support SM70/75/80/86. but got {arch} " - + arch == 70 + or arch == 75 + or arch == 80 + or arch == 86 + or arch == 89 + or arch == 90 + ), f"Currently weight_quantize only support SM70/75/80/86/89/90. but got {arch} " + assert ( group_size == -1 or group_size == 64 or group_size == 128 ), f"Currently group_size only support -1/64/128. but got {group_size} " From 9fa2c854682c116fea763b0440fce25941e2a9bc Mon Sep 17 00:00:00 2001 From: gaoziyuan Date: Wed, 9 Apr 2025 11:37:02 +0800 Subject: [PATCH 2/4] fix code style --- paddle/phi/infermeta/unary.cc | 11 +++++------ .../phi/kernels/gpu/weight_only_linear_kernel.cu | 4 ++-- paddle/phi/kernels/gpu/weight_quantize_kernel.cu | 2 +- .../impl/weight_quantize_kernel_gpu_impl.h | 2 +- python/paddle/nn/quant/quantized_linear.py | 16 ++++++++-------- 5 files changed, 17 insertions(+), 18 deletions(-) diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index a8418ab6c45ed6..ecdbf517deccac 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -5782,12 +5782,11 @@ void WeightQuantizeInferMeta(const MetaTensor& x, const int32_t group_size, MetaTensor* out, MetaTensor* scale) { - PADDLE_ENFORCE_EQ( - ((arch == 70) || (arch == 75) || (arch == 80) || (arch == 86) || - (arch == 89) || (arch == 90)), - true, - phi::errors::InvalidArgument( - "Currently, arch only support 70, 75, 80, 86.")); + PADDLE_ENFORCE_EQ(((arch == 70) || (arch == 75) || (arch == 80) || + (arch == 86) || (arch == 89) || (arch == 90)), + true, + phi::errors::InvalidArgument( + "Currently, arch only support 70, 75, 80, 86.")); auto x_dims = x.dims(); PADDLE_ENFORCE_EQ( diff --git a/paddle/phi/kernels/gpu/weight_only_linear_kernel.cu b/paddle/phi/kernels/gpu/weight_only_linear_kernel.cu index 13ce482fbdc01b..668d2f29764700 100644 --- a/paddle/phi/kernels/gpu/weight_only_linear_kernel.cu +++ b/paddle/phi/kernels/gpu/weight_only_linear_kernel.cu @@ -36,10 +36,10 @@ void WeightOnlyLinearKernel(const Context& dev_ctx, #if defined(PADDLE_WITH_CUTLASS) PADDLE_ENFORCE_EQ( ((arch == 70) || (arch == 75) || (arch == 80) || (arch == 86) || - (arch == 89) || (arch == 90)), + (arch == 89) || (arch == 90)), true, phi::errors::InvalidArgument( - "Currently, arch only support 70, 75, 80, 86, 89, 90.")); + "Currently, arch only support 70, 75, 80, 86, 89, 90.")); #else 
PADDLE_THROW(phi::errors::Unimplemented( "Please compile with cutlass to make cutlass available")); diff --git a/paddle/phi/kernels/gpu/weight_quantize_kernel.cu b/paddle/phi/kernels/gpu/weight_quantize_kernel.cu index 4d1efc407c1cd0..18244d38989623 100644 --- a/paddle/phi/kernels/gpu/weight_quantize_kernel.cu +++ b/paddle/phi/kernels/gpu/weight_quantize_kernel.cu @@ -46,7 +46,7 @@ void WeightQuantizeKernel(const Context& dev_ctx, static_cast(x.dims()[1])}; PADDLE_ENFORCE_EQ( ((arch == 70) || (arch == 75) || (arch == 80) || (arch == 86) || - (arch == 89) || (arch == 90)), + (arch == 89) || (arch == 90)), true, phi::errors::InvalidArgument( "Currently, arch only support 70, 75, 80, 86, 89, 90.")); diff --git a/paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h b/paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h index 80365d99349eb2..963608f7833210 100644 --- a/paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h +++ b/paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h @@ -98,7 +98,7 @@ void weight_permute_gpu(const GPUContext& dev_ctx, int grid_size = gpu_config.GetGridSize(); int block_size = gpu_config.GetBlockSize(); if ((arch == 90) || (arch == 89) || (arch == 86) || (arch == 80) || - (arch == 75)) { + (arch == 75)) { weight_permute_kernel_wint8<<>>( input_data, output_data, numel, total_k, total_n); } else if (arch == 70) { diff --git a/python/paddle/nn/quant/quantized_linear.py b/python/paddle/nn/quant/quantized_linear.py index c093b7041faa9c..cc7dee548d176e 100644 --- a/python/paddle/nn/quant/quantized_linear.py +++ b/python/paddle/nn/quant/quantized_linear.py @@ -69,14 +69,14 @@ def weight_quantize(x, algo="weight_only_int8", arch=None, group_size=-1): arch = _get_arch_info() assert ( - arch == 70 - or arch == 75 - or arch == 80 - or arch == 86 - or arch == 89 - or arch == 90 - ), f"Currently weight_quantize only support SM70/75/80/86/89/90. but got {arch} " - + arch == 70 + or arch == 75 + or arch == 80 + or arch == 86 + or arch == 89 + or arch == 90 + ), f"Currently weight_quantize only support SM70/75/80/86/89/90. but got {arch} " + assert ( group_size == -1 or group_size == 64 or group_size == 128 ), f"Currently group_size only support -1/64/128. but got {group_size} " From f30db20a6f43c7575b6ed5113cf077c8b441068e Mon Sep 17 00:00:00 2001 From: gaoziyuan Date: Wed, 9 Apr 2025 11:55:45 +0800 Subject: [PATCH 3/4] fix --- python/paddle/nn/quant/quantized_linear.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/nn/quant/quantized_linear.py b/python/paddle/nn/quant/quantized_linear.py index cc7dee548d176e..00279b6825fc98 100644 --- a/python/paddle/nn/quant/quantized_linear.py +++ b/python/paddle/nn/quant/quantized_linear.py @@ -198,8 +198,8 @@ def weight_only_linear( arch = _get_arch_info() assert ( - arch == 70 or arch == 80 or arch == 86 or arch == 75 - ), f"Currently weight_quantize only support SM70/75/80/86. but got {arch} " + arch == 70 or arch == 80 or arch == 86 or arch == 75 or arch == 90 + ), f"Currently weight_quantize only support SM70/75/80/86/90. but got {arch} " assert ( group_size == -1 or group_size == 64 or group_size == 128 ), f"Currently weight_quantize only support group size of -1, 64 or 128. 
but got {group_size} " From a19d7dc2839d9620ba4e7d05a76da7b447233cbb Mon Sep 17 00:00:00 2001 From: gaoziyuan Date: Fri, 11 Apr 2025 17:02:02 +0800 Subject: [PATCH 4/4] add weight_quantize gpu api --- .../kernels/funcs/weight_dequant_functor.h | 10 +- .../phi/kernels/gpu/weight_quantize_kernel.cu | 28 +- .../impl/weight_quantize_kernel_gpu_impl.h | 296 +++++++++++++++++- 3 files changed, 316 insertions(+), 18 deletions(-) diff --git a/paddle/phi/kernels/funcs/weight_dequant_functor.h b/paddle/phi/kernels/funcs/weight_dequant_functor.h index 4eed94de7bf4dc..feb3a7992f65a7 100644 --- a/paddle/phi/kernels/funcs/weight_dequant_functor.h +++ b/paddle/phi/kernels/funcs/weight_dequant_functor.h @@ -187,10 +187,10 @@ __global__ void int4_weight_only_dequant(const uint8_t* weight, int warp_id = threadIdx.x / 32, lane_id = threadIdx.x % 32; int tile_id = blockIdx.x * blockDim.x / 32 + warp_id; - // Every two rows of the original weights are interleaved into a row with - // stride of 64, so if each thread processes 16 elements(for int8, we can use - // ldg.128 to load weights), then every group of four adjacent threads will - // alternately process two different row weights for example every 128 + // Every 4 rows of the original weights are interleaved into a row with + // stride of 32, so if each thread processes 16 elements(for int8, we can use + // ldg.128 to load weights), then every group of two adjacent threads will + // alternately process four different row weights for example every 128 // consecutive int8 elements [128*i, 128*(i+1)-1] of row N under interleave // layout, the first 64 are from [64*i, 64*(i+1)-1] of row 2N before // interleaving, and the last 64 are from [64*i, 64*(i+1)-1] of row 2N+1 @@ -383,6 +383,7 @@ void WeightDequantize(const Context& dev_ctx, k, group_size); } else if (algo == "weight_only_int4" && group_size == -1) { + k *= 2; grid.x /= 2; int4_weight_only_dequant<<>>( reinterpret_cast(x.data()), @@ -391,6 +392,7 @@ void WeightDequantize(const Context& dev_ctx, n, k); } else if (algo == "weight_only_int4" && group_size > 0) { + k *= 2; grid.x /= 2; int4_weight_only_dequant<<>>( reinterpret_cast(x.data()), diff --git a/paddle/phi/kernels/gpu/weight_quantize_kernel.cu b/paddle/phi/kernels/gpu/weight_quantize_kernel.cu index 18244d38989623..0c8918f396a560 100644 --- a/paddle/phi/kernels/gpu/weight_quantize_kernel.cu +++ b/paddle/phi/kernels/gpu/weight_quantize_kernel.cu @@ -59,7 +59,9 @@ void WeightQuantizeKernel(const Context& dev_ctx, x.data(), quanted_x.data(), scale->data(), - weight_shape); + weight_shape, + arch, + algo); trans(dev_ctx, quanted_x, out, axis); } else if (algo == "weight_only_int8") { dev_ctx.template Alloc(scale); @@ -67,16 +69,30 @@ void WeightQuantizeKernel(const Context& dev_ctx, x.data(), quanted_x.data(), scale->data(), - weight_shape); + weight_shape, + arch, + algo); weight_permute_gpu(dev_ctx, quanted_x.data(), out->data(), weight_shape, - arch); + arch, + algo); } else if (algo == "weight_only_int4") { - PADDLE_FATAL( - "Weight quant gpu kernel currently don't support weight_only_int4 " - "algo, please use cpu version."); + dev_ctx.template Alloc(scale); + weight_quant_gpu(dev_ctx, + x.data(), + quanted_x.data(), + scale->data(), + weight_shape, + arch, + algo); + weight_permute_gpu(dev_ctx, + quanted_x.data(), + out->data(), + weight_shape, + arch, + algo); } else { PADDLE_FATAL( "The algo must be in ['weight_only_int8', 'weight_only_int4', " diff --git a/paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h 
b/paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h index 963608f7833210..096b6b70bac3eb 100644 --- a/paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h +++ b/paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h @@ -48,6 +48,97 @@ __global__ void weight_permute_kernel_wint8(const int8_t* input_data_dev, } } +// from +// 0 1 2 3 4 5 6 7... +// to +// 0 8 16 24 1 9 17 25... +__global__ void weight_permute_kernel_wint4(const int8_t* input_data_dev, + int8_t* output_data_dev, + int numel, + int total_k, + int total_n) { + for (int linear_idx = blockIdx.x * blockDim.x + threadIdx.x; + linear_idx < numel; + linear_idx += blockDim.x * gridDim.x) { + int k_id = linear_idx / total_n; + int n_id = linear_idx % total_n; + constexpr int k_permute_const = 8; + int k_mod_8 = k_id % 8; + int temp_k_expr_1 = k_mod_8 - k_mod_8 / 4 * 4; + int temp_k_expr_2 = k_mod_8 / 4; + // we need int4 index like + // 0 8 16 24 1 9 17 25 2 10 18 26 3 11 19 27 + // 4 12 20 28 5 13 21 29 6 14 22 30 7 15 23 31 + // we can change it to + // 0 1 16 17 8 9 24 25 2 3 18 19 10 11 26 27 + // 4 5 20 21 12 13 28 29 6 7 22 23 14 15 30 31 + // 2 int4 pack to a int8 + // 0 8 4 12 1 9 5 13 2 10 6 14 3 11 7 15 + // find index of above list + // 0 4 8 12 2 6 10 14 1 5 9 13 3 7 11 15 + // we know int8 index is + // 0 2 4 6 8 10 12 14 1 3 5 7 9 11 13 15 + // change it to + // 0 2 4 6 1 3 5 7 8 10 12 14 9 11 13 15 + // % 8 * 2 + // 0 4 8 12 2 6 10 14 0 4 8 12 2 6 10 14 + // add 1 for 0 4 8 12 2 6 10 14 [0 4 8 12 2 6 10 14] + // we get 0 4 8 12 2 6 10 14 1 5 9 13 3 7 11 15 + // it change ori to 0 8 4 12... + // finally we do some bitwise operation to change int4index + int permute_kk = (temp_k_expr_1 + temp_k_expr_2 + + (temp_k_expr_2 + 1) % 2 * k_mod_8 * 2 / 2 + + temp_k_expr_1 * temp_k_expr_2) % + 8 * 2 + + (k_id % 16) / 8 + k_id / 16 * 16; + int permute_index = permute_kk % 32 + permute_kk / 32 * 128 + + 32 * (n_id % 4) + total_k * 2 * (n_id / 4); + int8_t shift_quant_weight = input_data_dev[linear_idx]; + output_data_dev[permute_index] = + *reinterpret_cast(&shift_quant_weight); + } +} + +// bitwise operation +__global__ void weight_interval_kernel_wint4(int8_t* output_data_dev, + int numel) { + constexpr int value_per_interval_thread = 4; + constexpr int pack_size = 2; + for (int linear_idx = + (blockIdx.x * blockDim.x + threadIdx.x) * value_per_interval_thread; + linear_idx < numel; + linear_idx += blockDim.x * gridDim.x * value_per_interval_thread) { + for (int pack = 0; pack < pack_size; ++pack) { + int8_t interval_weight_0 = output_data_dev[linear_idx + pack]; + int8_t interval_weight_1 = output_data_dev[linear_idx + pack + 2]; + + uint8_t interval_weight_0_l = + static_cast(interval_weight_0) & 0x0F; + uint8_t interval_weight_0_r = + static_cast(interval_weight_0) >> 4; + uint8_t interval_weight_1_l = + static_cast(interval_weight_1) & 0x0F; + uint8_t interval_weight_1_r = + static_cast(interval_weight_1) >> 4; + + interval_weight_0_l = (interval_weight_0_l + 8) & 0x0F; + interval_weight_0_r = (interval_weight_0_r + 8) & 0x0F; + interval_weight_1_l = (interval_weight_1_l + 8) & 0x0F; + interval_weight_1_r = (interval_weight_1_r + 8) & 0x0F; + + uint8_t new_interval_weight_0 = + interval_weight_0_l | (interval_weight_1_l << 4); + uint8_t new_interval_weight_1 = + interval_weight_0_r | (interval_weight_1_r << 4); + + output_data_dev[linear_idx + pack] = + static_cast(new_interval_weight_0); + output_data_dev[linear_idx + pack + 2] = + static_cast(new_interval_weight_1); + } + } +} + /* For SM70 volta arch, 
weightonly int8 dequantize invoked in load global memory. So it only need interleave in K-dimension @@ -85,12 +176,45 @@ __global__ void weight_interleave_add_bias_kernel_wint8( } } +/* +For SM70 volta arch, weightonly int4 dequantize invoked in load global memory. +So it only need interleave in K-dimension +K_index: 0 1 2 3 4 5 6 7 -> 0 2 4 6 1 3 5 7 +*/ +__global__ void weight_interleave_add_bias_kernel_wint4(int8_t* input_data_dev, + int8_t* output_data_dev, + int numel, + int total_k, + int total_n) { + const int num_registers = numel / 4; + uint32_t* packed_input = reinterpret_cast(input_data_dev); + uint32_t* packed_output = reinterpret_cast(output_data_dev); + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num_registers; + i += blockDim.x * gridDim.x) { + uint32_t current_pack = packed_input[i]; + uint32_t transformed_pack = 0; +#pragma unroll + for (int idx = 0; idx < 8; ++idx) { + const int offset = idx / 4; + const int src = (idx % 4) * 2 + offset; + + const int src_shift = src * 4; + const int dst_shift = idx * 4; + + const uint32_t src_bits = ((current_pack >> src_shift) + 8) & 0xF; + transformed_pack |= (src_bits << dst_shift); + } + packed_output[i] = transformed_pack; + } +} + template void weight_permute_gpu(const GPUContext& dev_ctx, int8_t* input_data, int8_t* output_data, const std::vector& shape, - const int32_t arch) { + const int32_t arch, + const std::string& algo) { auto total_k = shape[0]; auto total_n = shape[1]; auto numel = total_k * total_n; @@ -99,11 +223,24 @@ void weight_permute_gpu(const GPUContext& dev_ctx, int block_size = gpu_config.GetBlockSize(); if ((arch == 90) || (arch == 89) || (arch == 86) || (arch == 80) || (arch == 75)) { - weight_permute_kernel_wint8<<>>( - input_data, output_data, numel, total_k, total_n); + if (algo == "weight_only_int4") { + numel /= 2; + weight_permute_kernel_wint4<<>>( + input_data, output_data, numel, total_k, total_n); + weight_interval_kernel_wint4<<>>(output_data, + numel); + } else { + weight_permute_kernel_wint8<<>>( + input_data, output_data, numel, total_k, total_n); + } } else if (arch == 70) { - weight_interleave_add_bias_kernel_wint8<<>>( - input_data, output_data, numel, total_k, total_n); + if (algo == "weight_only_int4") { + weight_interleave_add_bias_kernel_wint4<<>>( + input_data, output_data, numel, total_k, total_n); + } else { + weight_interleave_add_bias_kernel_wint8<<>>( + input_data, output_data, numel, total_k, total_n); + } } } @@ -162,12 +299,136 @@ __global__ void per_channel_quant_gpu(const T* weight_data, } } } + +template +__global__ void per_channel_quant_gpu_int4_row_pack(const T* weight_data, + int8_t* quanted_weight_data, + ScaleT* scale_data, + int total_k, + int total_vec_n) { + int n = blockIdx.x * blockDim.x + threadIdx.x; + if (n < total_vec_n) { + const int4* vec_weight_data_ptr = + reinterpret_cast(weight_data); + int* vec_quanted_weight_data = reinterpret_cast(quanted_weight_data); + phi::AlignedVector abs_max; +#pragma unroll + for (int i = 0; i < VectorSize; ++i) { + abs_max[i] = static_cast(0.0f); + } +#pragma unroll + for (int k = 0; k < total_k; ++k) { + int linear_index = k * total_vec_n + n; + phi::AlignedVector weight; + *reinterpret_cast(&weight) = vec_weight_data_ptr[linear_index]; +#pragma unroll + for (int i = 0; i < VectorSize; ++i) { + abs_max[i] = fmaxf((abs_max[i]), fabsf((weight[i]))); + } + } + phi::AlignedVector scale; +#pragma unroll + for (int i = 0; i < VectorSize; ++i) { + scale[i] = static_cast(abs_max[i] / static_cast(7.0f)); + } + 
*reinterpret_cast(scale_data + VectorSize * n) = + *reinterpret_cast(&scale); + for (int k = 0; k < total_k; ++k) { + int linear_index = k * total_vec_n + n; + phi::AlignedVector weight; + phi::AlignedVector quanted_weight; + *reinterpret_cast(&weight) = + *reinterpret_cast(vec_weight_data_ptr + linear_index); +#pragma unroll + for (int i = 0; i < VectorSize / 2; ++i) { + int8_t packed_int4s = 0; + for (int pack = 0; pack < 2; ++pack) { + int vector_index = i * 2 + pack; + const float r_scale = 1 / static_cast(scale[vector_index]); + const float weight_elt = + static_cast(weight[vector_index]) * r_scale; + float scaled_weight = roundf(weight_elt); + int int_weight = static_cast(scaled_weight); + int8_t clipped_weight = max(-7, min(7, int_weight)); + packed_int4s |= ((clipped_weight & 0x0F) << (4 * pack)); + } + quanted_weight[i] = packed_int4s; + } + *reinterpret_cast(vec_quanted_weight_data + linear_index) = + *reinterpret_cast(&quanted_weight); + } + } +} + +template +__global__ void per_channel_quant_gpu_int4_col_pack(const T* weight_data, + int8_t* quanted_weight_data, + ScaleT* scale_data, + int total_k, + int total_vec_n) { + int n = blockIdx.x * blockDim.x + threadIdx.x; + if (n < total_vec_n) { + const int4* vec_weight_data_ptr = + reinterpret_cast(weight_data); + int2* vec_quanted_weight_data = + reinterpret_cast(quanted_weight_data); + phi::AlignedVector abs_max; +#pragma unroll + for (int i = 0; i < VectorSize; ++i) { + abs_max[i] = static_cast(0.0f); + } +#pragma unroll + for (int k = 0; k < total_k; ++k) { + int linear_index = k * total_vec_n + n; + phi::AlignedVector weight; + *reinterpret_cast(&weight) = vec_weight_data_ptr[linear_index]; +#pragma unroll + for (int i = 0; i < VectorSize; ++i) { + abs_max[i] = fmaxf((abs_max[i]), static_cast(fabsf(weight[i]))); + } + } + phi::AlignedVector scale; +#pragma unroll + for (int i = 0; i < VectorSize; ++i) { + scale[i] = static_cast(abs_max[i] / static_cast(7.0f)); + } + *reinterpret_cast(scale_data + VectorSize * n) = + *reinterpret_cast(&scale); + + for (int k = 0; k < total_k / 2; ++k) { + phi::AlignedVector quanted_weight; + for (int packed_idx = 0; packed_idx < 2; ++packed_idx) { + int linear_index = (k * 2 + packed_idx) * total_vec_n + n; + phi::AlignedVector weight; + *reinterpret_cast(&weight) = + *reinterpret_cast(vec_weight_data_ptr + linear_index); +#pragma unroll + for (int i = 0; i < VectorSize; ++i) { + const float weight_elt = + (static_cast(weight[i]) / static_cast(abs_max[i])) * + static_cast(7.0); + const float scaled_weight = lroundf(weight_elt); + int int_weight = static_cast(scaled_weight); + const int8_t clipped_weight = fmaxf(-7, fminf(7, int_weight)); + quanted_weight[i] &= ~(0x0F << (4 * packed_idx)); + quanted_weight[i] |= ((clipped_weight & 0x0F) << (4 * packed_idx)); + } + } + int linear_index_new = k * total_vec_n + n; + *reinterpret_cast(vec_quanted_weight_data + linear_index_new) = + *reinterpret_cast(&quanted_weight); + } + } +} + template void weight_quant_gpu(const GPUContext& dev_ctx, const T* weight_data, int8_t* quanted_weight_data, ScaleT* scale_data, - const std::vector& shape) { + const std::vector& shape, + const int32_t arch, + const std::string& algo) { int total_k = shape[0]; int total_n = shape[1]; int numel = total_k * total_n; @@ -184,8 +445,27 @@ void weight_quant_gpu(const GPUContext& dev_ctx, int vec_total_n = total_n / kVectorSize; int kGridSize = max((vec_total_n + kBlockSize - 1) / kBlockSize, static_cast(1)); - per_channel_quant_gpu<<>>( - weight_data, quanted_weight_data, 
scale_data, total_k, vec_total_n); + if (algo == "weight_only_int4") { + if ((arch == 90) || (arch == 89) || (arch == 86) || (arch == 80) || + (arch == 75)) { + per_channel_quant_gpu_int4_col_pack + <<>>(weight_data, + quanted_weight_data, + scale_data, + total_k, + vec_total_n); + } else if ((arch == 70)) { + per_channel_quant_gpu_int4_row_pack + <<>>(weight_data, + quanted_weight_data, + scale_data, + total_k, + vec_total_n); + } + } else { + per_channel_quant_gpu<<>>( + weight_data, quanted_weight_data, scale_data, total_k, vec_total_n); + } } } // namespace phi
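
The new per_channel_quant_gpu_int4_col_pack path quantizes each output channel to int4 (values clipped to [-7, 7] against a per-channel abs-max scale of abs_max / 7) and packs two consecutive K rows into one byte, row 2k in the low nibble and row 2k+1 in the high nibble, before weight_permute_gpu rearranges the result for the CUTLASS kernels. Below is a minimal NumPy sketch of just that quantize-and-pack step; it does not model the later permute/interleave, the epsilon guard against all-zero columns is my own addition (the kernel has no such guard), and the commented Paddle call is an illustrative usage that assumes a CUTLASS-enabled build on one of the supported architectures (SM70/75/80/86/89/90) with hypothetical shapes.

import numpy as np

def int4_col_pack_reference(w):
    """w: [K, N] float weights with K even -> (packed int8 [K//2, N], float16 scale [N])."""
    abs_max = np.maximum(np.abs(w).max(axis=0), 1e-6)   # per-output-channel abs-max (guard assumed)
    q = np.clip(np.rint(w / abs_max * 7.0), -7, 7).astype(np.int8)
    low = q[0::2].astype(np.uint8) & 0x0F                # row 2k   -> low nibble
    high = q[1::2].astype(np.uint8) & 0x0F               # row 2k+1 -> high nibble
    packed = (low | (high << 4)).view(np.int8)           # two int4 values per int8 byte
    return packed, (abs_max / 7.0).astype(np.float16)

# Rough end-to-end usage of the Python API this series extends (hypothetical shapes):
# import paddle
# x = paddle.randn([4096, 12288], dtype="float16")
# qweight, scale = paddle.nn.quant.weight_quantize(x, algo="weight_only_int4")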