
Commit 09b18a4

[Paddle-TRT] Implement MHA fp16 order same as training (#32629) (#32785)
* implement MHA order same as training
* fix fp16 compile issue on old architecture

Co-authored-by: zlsh80826 <rewang@nvidia.com>
1 parent: 2ec6b6f

File tree

1 file changed: +16 −1 lines changed


paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu

Lines changed: 16 additions & 1 deletion
@@ -225,6 +225,14 @@ nvinfer1::DataType QkvToContextPluginDynamic::getOutputDataType(
   return input_types[0];
 }
 
+template <typename T>
+__global__ void apply_scale(T *data, T scale, int n) {
+#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  data[tid] = data[tid] * scale;
+#endif
+}
+
 int QkvToContextPluginDynamic::enqueue(
     const nvinfer1::PluginTensorDesc *input_desc,
     const nvinfer1::PluginTensorDesc *output_desc, const void *const *inputs,
@@ -291,10 +299,17 @@ int QkvToContextPluginDynamic::enqueue(
         platform::DeviceContextPool::Instance().Get(
             platform::CUDAPlace(device_id)));
 
+    int n_q = seq_len * head_number_ * head_size_;
+    constexpr int threads = 128;
+    int blocks = (n_q + threads - 1) / threads;
+
+    apply_scale<<<blocks, threads, 0, stream>>>(tptr, static_cast<half>(scale_),
+                                                n_q);
+
     const platform::CUDADeviceContext &dev_ctx = *device_ctx;
     operators::math::MultiHeadGPUComputeFunctor<half> multihead_compute_func;
     multihead_compute_func(dev_ctx, batch, seq_len, head_number_, head_size_,
-                           qkptr, input1_data, tptr, half(scale_), half(0.0));
+                           qkptr, input1_data, tptr, half(1.), half(0.0));
 
     int grid = batch * head_number_ * seq_len;
     int block = head_size_;
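For context on the change: the commit moves the attention scaling out of MultiHeadGPUComputeFunctor (which is now called with a scale of 1) and instead applies it up front with a standalone element-wise kernel, matching the order of operations used on the training side; the CUDA_ARCH_FP16_SUPPORTED guard addresses the fp16 compile issue on older architectures. The sketch below shows the same element-wise scaling pattern in isolation. It is illustrative only: the names (scale_inplace, n_elems) are hypothetical, it uses float data to sidestep half-precision architecture requirements, and it adds an explicit bounds check that the committed kernel does not include.

// Illustrative sketch, not the plugin's code: scale a buffer in place on the
// GPU before handing it to the attention compute, so the compute step itself
// can run with scale = 1.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void scale_inplace(float *data, float scale, int n) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < n) {                       // guard the last, possibly partial block
    data[tid] *= scale;
  }
}

int main() {
  const int n_elems = 1000;            // e.g. seq_len * head_number * head_size
  float *buf;
  cudaMallocManaged(&buf, n_elems * sizeof(float));
  for (int i = 0; i < n_elems; ++i) buf[i] = 1.0f;

  constexpr int threads = 128;
  int blocks = (n_elems + threads - 1) / threads;  // same launch math as the commit
  scale_inplace<<<blocks, threads>>>(buf, 0.125f, n_elems);
  cudaDeviceSynchronize();

  printf("buf[0] = %f\n", buf[0]);     // expect 0.125
  cudaFree(buf);
  return 0;
}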
