Skip to content

【Inference】add weight_quantize gpu kernel #72203

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions paddle/phi/kernels/funcs/weight_dequant_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,10 +187,10 @@ __global__ void int4_weight_only_dequant(const uint8_t* weight,

int warp_id = threadIdx.x / 32, lane_id = threadIdx.x % 32;
int tile_id = blockIdx.x * blockDim.x / 32 + warp_id;
// Every two rows of the original weights are interleaved into a row with
// stride of 64, so if each thread processes 16 elements(for int8, we can use
// ldg.128 to load weights), then every group of four adjacent threads will
// alternately process two different row weights for example every 128
// Every 4 rows of the original weights are interleaved into a row with
// stride of 32, so if each thread processes 16 elements(for int8, we can use
// ldg.128 to load weights), then every group of two adjacent threads will
// alternately process four different row weights for example every 128
// consecutive int8 elements [128*i, 128*(i+1)-1] of row N under interleave
// layout, the first 64 are from [64*i, 64*(i+1)-1] of row 2N before
// interleaving, and the last 64 are from [64*i, 64*(i+1)-1] of row 2N+1
Expand Down Expand Up @@ -383,6 +383,7 @@ void WeightDequantize(const Context& dev_ctx,
k,
group_size);
} else if (algo == "weight_only_int4" && group_size == -1) {
k *= 2;
grid.x /= 2;
int4_weight_only_dequant<DataType><<<grid, block, 0, stream>>>(
reinterpret_cast<const uint8_t*>(x.data<int8_t>()),
Expand All @@ -391,6 +392,7 @@ void WeightDequantize(const Context& dev_ctx,
n,
k);
} else if (algo == "weight_only_int4" && group_size > 0) {
k *= 2;
grid.x /= 2;
int4_weight_only_dequant<DataType><<<grid, block, 0, stream>>>(
reinterpret_cast<const uint8_t*>(x.data<int8_t>()),
Expand Down
28 changes: 22 additions & 6 deletions paddle/phi/kernels/gpu/weight_quantize_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -59,24 +59,40 @@ void WeightQuantizeKernel(const Context& dev_ctx,
x.data<T>(),
quanted_x.data<int8_t>(),
scale->data<float>(),
weight_shape);
weight_shape,
arch,
algo);
trans(dev_ctx, quanted_x, out, axis);
} else if (algo == "weight_only_int8") {
dev_ctx.template Alloc<T>(scale);
weight_quant_gpu<T, Context>(dev_ctx,
x.data<T>(),
quanted_x.data<int8_t>(),
scale->data<T>(),
weight_shape);
weight_shape,
arch,
algo);
weight_permute_gpu<Context>(dev_ctx,
quanted_x.data<int8_t>(),
out->data<int8_t>(),
weight_shape,
arch);
arch,
algo);
} else if (algo == "weight_only_int4") {
PADDLE_FATAL(
"Weight quant gpu kernel currently don't support weight_only_int4 "
"algo, please use cpu version.");
dev_ctx.template Alloc<T>(scale);
weight_quant_gpu<T, Context>(dev_ctx,
x.data<T>(),
quanted_x.data<int8_t>(),
scale->data<T>(),
weight_shape,
arch,
algo);
weight_permute_gpu<Context>(dev_ctx,
quanted_x.data<int8_t>(),
out->data<int8_t>(),
weight_shape,
arch,
algo);
} else {
PADDLE_FATAL(
"The algo must be in ['weight_only_int8', 'weight_only_int4', "
Expand Down
Loading