@@ -225,6 +225,14 @@ nvinfer1::DataType QkvToContextPluginDynamic::getOutputDataType(
   return input_types[0];
 }
 
+template <typename T>
+__global__ void apply_scale(T *data, T scale, int n) {
+#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  data[tid] = data[tid] * scale;
+#endif
+}
+
 int QkvToContextPluginDynamic::enqueue(
     const nvinfer1::PluginTensorDesc *input_desc,
     const nvinfer1::PluginTensorDesc *output_desc, const void *const *inputs,
@@ -291,10 +299,17 @@ int QkvToContextPluginDynamic::enqueue(
         platform::DeviceContextPool::Instance().Get(
             platform::CUDAPlace(device_id)));
 
+    int n_q = seq_len * head_number_ * head_size_;
+    constexpr int threads = 128;
+    int blocks = (n_q + threads - 1) / threads;
+
+    apply_scale<<<blocks, threads, 0, stream>>>(tptr, static_cast<half>(scale_),
+                                                n_q);
+
     const platform::CUDADeviceContext &dev_ctx = *device_ctx;
     operators::math::MultiHeadGPUComputeFunctor<half> multihead_compute_func;
     multihead_compute_func(dev_ctx, batch, seq_len, head_number_, head_size_,
-                           qkptr, input1_data, tptr, half(scale_), half(0.0));
+                           qkptr, input1_data, tptr, half(1.), half(0.0));
 
     int grid = batch * head_number_ * seq_len;
     int block = head_size_;
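
Note on the change above: the attention scale is moved out of the fused MultiHeadGPUComputeFunctor call (which now receives half(1.)) and is instead applied up front by the new apply_scale kernel over the n_q = seq_len * head_number_ * head_size_ query elements. The standalone sketch below illustrates the same launch pattern outside the plugin; it is not the plugin code. It uses plain float instead of half, drops the Paddle-specific CUDA_ARCH_FP16_SUPPORTED guard, and adds a tid < n bounds check that the patched kernel itself does not perform. All of those are choices made for this sketch, not part of the patch.

#include <cstdio>
#include <cuda_runtime.h>

// Elementwise scale kernel mirroring apply_scale from the patch, with an
// added bounds guard (the patch launches rounded-up blocks and omits it).
template <typename T>
__global__ void scale_kernel(T *data, T scale, int n) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < n) {
    data[tid] = data[tid] * scale;
  }
}

int main() {
  const int n = 1000;          // deliberately not a multiple of the block size
  const float scale = 0.125f;  // e.g. 1/sqrt(head_size) for head_size = 64

  float h_data[1000];
  for (int i = 0; i < n; ++i) h_data[i] = 1.0f;

  float *d_data = nullptr;
  cudaMalloc(&d_data, n * sizeof(float));
  cudaMemcpy(d_data, h_data, n * sizeof(float), cudaMemcpyHostToDevice);

  // Same grid math as the patch: round the element count up to whole blocks.
  constexpr int threads = 128;
  int blocks = (n + threads - 1) / threads;
  scale_kernel<<<blocks, threads>>>(d_data, scale, n);

  cudaMemcpy(h_data, d_data, n * sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("data[0] = %f (expected %f)\n", h_data[0], scale);

  cudaFree(d_data);
  return 0;
}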