[Inference] Support group-wise quantization for the weight_quantize op on GPU #71549

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open: wants to merge 16 commits into develop
Changes from all commits
9 changes: 6 additions & 3 deletions paddle/phi/kernels/gpu/weight_quantize_kernel.cu
@@ -62,7 +62,8 @@ void WeightQuantizeKernel(const Context& dev_ctx,
scale->data<float>(),
weight_shape,
arch,
algo,
group_size);
trans(dev_ctx, quanted_x, out, axis);
} else if (algo == "weight_only_int8") {
dev_ctx.template Alloc<T>(scale);
@@ -72,7 +73,8 @@ void WeightQuantizeKernel(const Context& dev_ctx,
scale->data<T>(),
weight_shape,
arch,
algo,
group_size);
#ifdef PADDLE_WITH_HIP
std::vector<int> axis = {1, 0};
funcs::Transpose<Context, int8_t, 2> trans;
@@ -93,7 +95,8 @@ void WeightQuantizeKernel(const Context& dev_ctx,
scale->data<T>(),
weight_shape,
arch,
algo,
group_size);
#ifdef PADDLE_WITH_HIP
DenseTensor x_int_tmp(out->type());
x_int_tmp.Resize({static_cast<int64_t>(m), static_cast<int64_t>(n / 2)});
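For reference, the caller above simply forwards a new group_size argument into the quantization helper; group_size == -1 preserves the existing per-channel behaviour (see the dispatch at the end of the header change below). The updated helper declaration, restated here as a sketch with annotations that are my reading of the diff, is:

template <typename T, typename GPUContext, typename ScaleT>
void weight_quant_gpu(const GPUContext& dev_ctx,
                      const T* weight_data,
                      int8_t* quanted_weight_data,
                      ScaleT* scale_data,
                      const std::vector<int>& shape,  // {total_k, total_n}
                      const int32_t arch,
                      const std::string& algo,
                      const int32_t group_size);  // -1 => per-channel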
264 changes: 251 additions & 13 deletions paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h
@@ -425,14 +425,222 @@ __global__ void per_channel_quant_gpu_int4_col_pack(const T* weight_data,
}
}

template <typename T, int VectorSize = 8, typename ScaleT>
__global__ void per_group_quant_gpu_int4_col_pack(const T* weight_data,
int8_t* quanted_weight_data,
ScaleT* scale_data,
int total_k,
int total_vec_n,
int group_size) {
int n = blockIdx.x * blockDim.x + threadIdx.x;
if (n < total_vec_n) {
const int4* vec_weight_data_ptr =
reinterpret_cast<const int4*>(weight_data);

phi::AlignedVector<float, VectorSize> abs_max;

// Walk this column in steps of group_size rows (one quantization group per step)
for (int k = 0; k < total_k; k += group_size) {
// Init per group abs_max
#pragma unroll
for (int i = 0; i < VectorSize; ++i) {
abs_max[i] = static_cast<float>(0.0f);
}
for (int g = 0; g < group_size && k + g < total_k; g++) {
int linear_index = (k + g) * total_vec_n + n;
phi::AlignedVector<T, VectorSize> weight;
*reinterpret_cast<int4*>(&weight) = vec_weight_data_ptr[linear_index];
#pragma unroll
for (int i = 0; i < VectorSize; ++i) {
abs_max[i] = fmaxf(abs_max[i], fabsf(weight[i]));
}
}
// Compute Scale
phi::AlignedVector<ScaleT, VectorSize> scale;
#pragma unroll
for (int i = 0; i < VectorSize; ++i) {
scale[i] = static_cast<ScaleT>(abs_max[i] / static_cast<float>(7.0f));
}
*reinterpret_cast<float4*>(
scale_data + (k / group_size) * (total_vec_n * VectorSize) +
n * VectorSize) = *reinterpret_cast<float4*>(&scale);

// group-wise weight quant
for (int g = 0; g < group_size / 2; g++) {
phi::AlignedVector<int8_t, VectorSize> quanted_weight;
// Pack two k-adjacent 4-bit values into one int8 (low nibble first)
for (int packed_idx = 0;
packed_idx < 2 && k + g * 2 + packed_idx < total_k;
packed_idx++) {
int linear_index = (k + g * 2 + packed_idx) * total_vec_n + n;
phi::AlignedVector<T, VectorSize> weight;
*reinterpret_cast<int4*>(&weight) = *reinterpret_cast<const int4*>(
vec_weight_data_ptr + linear_index);
#pragma unroll
for (int i = 0; i < VectorSize; ++i) {
float weight_elt =
(static_cast<float>(weight[i]) / static_cast<float>(scale[i]));
int8_t clipped_weight = static_cast<int8_t>(
lroundf(fmaxf(-7.0f, fminf(7.0f, weight_elt))));
// Clear the target nibble, then write the clipped 4-bit value
quanted_weight[i] &= ~(0x0F << (4 * packed_idx));
quanted_weight[i] |= ((clipped_weight & 0x0F) << (4 * packed_idx));
}
}
int linear_index =
(k / 2 + g) * total_vec_n * VectorSize + n * VectorSize;

*reinterpret_cast<int2*>(quanted_weight_data + linear_index) =
*reinterpret_cast<int2*>(&quanted_weight);
}
}
}
}
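A minimal CPU sketch of the layout this kernel produces may help when reading the indexing above. This is a hypothetical reference, not part of the PR, and it assumes group_size is even (typical values such as 64 or 128) and that total_k is a multiple of it. Weights are row-major [k, n]; one scale is written per (group, column); two k-adjacent int4 values of the same column share one byte.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Hypothetical reference, not part of the PR: per-group int4 quantization
// with column packing. weight: row-major [k, n]; scale: [k / group_size, n];
// packed: [k / 2, n] bytes (rows 2r and 2r + 1 of column c share one byte).
void per_group_int4_col_pack_ref(const std::vector<float>& weight,
                                 int k, int n, int group_size,
                                 std::vector<float>* scale,
                                 std::vector<int8_t>* packed) {
  scale->assign((k / group_size) * n, 0.0f);
  packed->assign(k / 2 * n, 0);
  for (int c = 0; c < n; ++c) {
    for (int k0 = 0; k0 < k; k0 += group_size) {
      // Per-group abs-max over rows [k0, k0 + group_size) of column c.
      float abs_max = 0.0f;
      for (int g = 0; g < group_size && k0 + g < k; ++g) {
        abs_max = std::max(abs_max, std::fabs(weight[(k0 + g) * n + c]));
      }
      const float s = abs_max / 7.0f;  // symmetric int4 range [-7, 7]
      (*scale)[(k0 / group_size) * n + c] = s;
      for (int g = 0; g < group_size && k0 + g < k; ++g) {
        const float w = weight[(k0 + g) * n + c];
        const float q = (s == 0.0f) ? 0.0f : w / s;
        const int8_t v = static_cast<int8_t>(
            std::lround(std::fmax(-7.0f, std::fmin(7.0f, q))));
        const int byte_idx = (k0 + g) / 2 * n + c;
        const int shift = 4 * ((k0 + g) % 2);  // even row -> low nibble
        (*packed)[byte_idx] = static_cast<int8_t>(
            ((*packed)[byte_idx] & ~(0x0F << shift)) | ((v & 0x0F) << shift));
      }
    }
  }
}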

template <typename T, int VectorSize = 8, typename ScaleT>
__global__ void per_group_quant_gpu_int4_row_pack(const T* weight_data,
int8_t* quanted_weight_data,
ScaleT* scale_data,
int total_k,
int total_vec_n,
int group_size) {
int n = blockIdx.x * blockDim.x + threadIdx.x;
if (n < total_vec_n) {
const int4* vec_weight_data_ptr =
reinterpret_cast<const int4*>(weight_data);

phi::AlignedVector<float, VectorSize> abs_max;

// Walk this column in steps of group_size rows (one quantization group per step)
for (int k = 0; k < total_k; k += group_size) {
// Init per group abs_max
#pragma unroll
for (int i = 0; i < VectorSize; ++i) {
abs_max[i] = static_cast<float>(0.0f);
}
for (int g = 0; g < group_size && k + g < total_k; g++) {
int linear_index = (k + g) * total_vec_n + n;
phi::AlignedVector<T, VectorSize> weight;
*reinterpret_cast<int4*>(&weight) = vec_weight_data_ptr[linear_index];
#pragma unroll
for (int i = 0; i < VectorSize; ++i) {
abs_max[i] = fmaxf(abs_max[i], fabsf(weight[i]));
}
}
// Compute Scale
phi::AlignedVector<ScaleT, VectorSize> scale;
#pragma unroll
for (int i = 0; i < VectorSize; ++i) {
scale[i] = static_cast<ScaleT>(abs_max[i] / static_cast<float>(7.0f));
}
*reinterpret_cast<float4*>(
scale_data + (k / group_size) * (total_vec_n * VectorSize) +
n * VectorSize) = *reinterpret_cast<float4*>(&scale);

// group-wise weight quant
for (int g = 0; g < group_size && k + g < total_k; g++) {
int linear_index = (k + g) * total_vec_n + n;
phi::AlignedVector<T, VectorSize> weight;
phi::AlignedVector<int8_t, VectorSize / 2> quanted_weight;
*reinterpret_cast<int4*>(&weight) =
*reinterpret_cast<const int4*>(vec_weight_data_ptr + linear_index);
#pragma unroll
for (int i = 0; i < VectorSize / 2; i++) {
int8_t packed_int4s = 0;
for (int pack = 0; pack < 2; pack++) {
int vector_index = i * 2 + pack;
const float weight_elt = static_cast<float>(weight[vector_index]) /
static_cast<float>(scale[vector_index]);
int8_t clipped_weight = static_cast<int8_t>(
lroundf(fmaxf(-7.0f, fminf(7.0f, weight_elt))));
packed_int4s |= ((clipped_weight & 0x0F) << (4 * pack));
}
quanted_weight[i] = packed_int4s;
}
int quant_weight_idx =
(k + g) * total_vec_n * VectorSize / 2 + n * VectorSize / 2;
*reinterpret_cast<int*>(quanted_weight_data + quant_weight_idx) =
*reinterpret_cast<int*>(&quanted_weight);
}
}
}
}
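The row-pack variant above computes abs-max and scales identically; it differs only in how two int4 values share a byte: adjacent columns (the n direction) of the same row are paired, so the packed output is [k, n / 2] bytes. A hypothetical one-liner, not part of the PR, for the pairing:

// Hypothetical helper, not part of the PR: pack two column-adjacent int4
// values into one byte (even column -> low nibble, odd column -> high nibble).
inline int8_t pack_row_pair(int8_t even_col, int8_t odd_col) {
  return static_cast<int8_t>((even_col & 0x0F) | ((odd_col & 0x0F) << 4));
}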

template <typename T, int VectorSize = 8, typename ScaleT>
__global__ void group_wise_quant_gpu(const T* weight_data,
int8_t* quanted_weight_data,
ScaleT* scale_data,
int total_k,
int total_vec_n,
int group_size) {
int n = blockIdx.x * blockDim.x + threadIdx.x;
// This could be further optimized by also parallelizing over groups
if (n < total_vec_n) {
const int4* vec_weight_data_ptr =
reinterpret_cast<const int4*>(weight_data);
int2* vec_quanted_weight_data =
reinterpret_cast<int2*>(quanted_weight_data);

phi::AlignedVector<float, VectorSize> abs_max;

// Walk this column in steps of group_size rows (one quantization group per step)
for (int k = 0; k < total_k; k += group_size) {
// Init per group abs_max
#pragma unroll
for (int i = 0; i < VectorSize; ++i) {
abs_max[i] = static_cast<float>(0.0f);
}
for (int g = 0; g < group_size && k + g < total_k; g++) {
int linear_index = (k + g) * total_vec_n + n;
phi::AlignedVector<T, VectorSize> weight;
*reinterpret_cast<int4*>(&weight) = vec_weight_data_ptr[linear_index];
#pragma unroll
for (int i = 0; i < VectorSize; ++i) {
abs_max[i] = fmaxf(abs_max[i], fabsf(weight[i]));
}
}
// Compute Scale
phi::AlignedVector<ScaleT, VectorSize> scale;
#pragma unroll
for (int i = 0; i < VectorSize; ++i) {
scale[i] = static_cast<ScaleT>(abs_max[i] / static_cast<float>(127.0f));
}
*reinterpret_cast<float4*>(
scale_data + (k / group_size) * (total_vec_n * VectorSize) +
n * VectorSize) = *reinterpret_cast<float4*>(&scale);

// group-wise weight quant
for (int g = 0; g < group_size && k + g < total_k; g++) {
phi::AlignedVector<int8_t, VectorSize> quanted_weight;
int linear_index = (k + g) * total_vec_n + n;
phi::AlignedVector<T, VectorSize> weight;
*reinterpret_cast<int4*>(&weight) =
*reinterpret_cast<const int4*>(vec_weight_data_ptr + linear_index);
#pragma unroll
for (int i = 0; i < VectorSize; ++i) {
float scaled_weight =
(static_cast<float>(weight[i]) / static_cast<float>(abs_max[i])) *
static_cast<float>(127.0f);
int8_t clipped_weight = static_cast<int8_t>(
lroundf(fmaxf(-127.0f, fminf(127.0f, scaled_weight))));
quanted_weight[i] = clipped_weight;
}
*reinterpret_cast<int2*>(vec_quanted_weight_data + linear_index) =
*reinterpret_cast<int2*>(&quanted_weight);
}
}
}
}
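For the int8 path above, each group's scale is abs_max / 127 and the quantized values are stored unpacked. A hypothetical scalar helper, not part of the PR, mirroring the per-element math (with a zero guard added for the sketch):

#include <cmath>
#include <cstdint>

// Hypothetical helper, not part of the PR: quantize one weight against its
// group's abs-max, following group_wise_quant_gpu's per-element math.
inline int8_t quantize_int8_group(float w, float group_abs_max) {
  const float scaled =
      (group_abs_max == 0.0f) ? 0.0f : w / group_abs_max * 127.0f;
  return static_cast<int8_t>(
      std::lround(std::fmax(-127.0f, std::fmin(127.0f, scaled))));
}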

template <typename T, typename GPUContext, typename ScaleT>
void weight_quant_gpu(const GPUContext& dev_ctx,
const T* weight_data,
int8_t* quanted_weight_data,
ScaleT* scale_data,
const std::vector<int>& shape,
const int32_t arch,
const std::string& algo,
const int32_t group_size) {
int total_k = shape[0];
int total_n = shape[1];
int numel = total_k * total_n;
@@ -457,24 +665,54 @@ void weight_quant_gpu(const GPUContext& dev_ctx,
#else
  if ((arch == 90) || (arch == 89) || (arch == 86) || (arch == 80) ||
      (arch == 75)) {
    if (group_size == -1) {  // per channel
      per_channel_quant_gpu_int4_col_pack<T, kVectorSize>
          <<<kGridSize, kBlockSize>>>(weight_data,
                                      quanted_weight_data,
                                      scale_data,
                                      total_k,
                                      vec_total_n);
    } else {
      per_group_quant_gpu_int4_col_pack<T, kVectorSize>
          <<<kGridSize, kBlockSize>>>(weight_data,
                                      quanted_weight_data,
                                      scale_data,
                                      total_k,
                                      vec_total_n,
                                      group_size);
    }
  } else if ((arch == 70)) {
    if (group_size == -1) {
      per_channel_quant_gpu_int4_row_pack<T, kVectorSize>
          <<<kGridSize, kBlockSize>>>(weight_data,
                                      quanted_weight_data,
                                      scale_data,
                                      total_k,
                                      vec_total_n);
    } else {
      per_group_quant_gpu_int4_row_pack<T, kVectorSize>
          <<<kGridSize, kBlockSize>>>(weight_data,
                                      quanted_weight_data,
                                      scale_data,
                                      total_k,
                                      vec_total_n,
                                      group_size);
    }
  }
#endif
  } else {
    if (group_size == -1) {  // per channel
      per_channel_quant_gpu<T, kVectorSize><<<kGridSize, kBlockSize>>>(
          weight_data, quanted_weight_data, scale_data, total_k, vec_total_n);
    } else {
      group_wise_quant_gpu<T, kVectorSize>
          <<<kGridSize, kBlockSize>>>(weight_data,
                                      quanted_weight_data,
                                      scale_data,
                                      total_k,
                                      vec_total_n,
                                      group_size);
    }
  }
}
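One consequence of the scale indexing in the kernels above, stated as an assumption for callers: the group-wise path appears to need (total_k / group_size) * total_n scale elements, versus total_n for the per-channel path. A hypothetical illustration, not part of the PR:

#include <cstdint>

// Hypothetical illustration, not part of the PR: expected number of scale
// elements for each mode, inferred from the scale indexing in the kernels.
inline int64_t scale_numel(int total_k, int total_n, int group_size) {
  return (group_size == -1)
             ? static_cast<int64_t>(total_n)  // per-channel
             : static_cast<int64_t>(total_k / group_size) * total_n;  // group-wise
}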
