Skip to content

Commit e7492fd

Browse files
authored
[PaddleInference] Support a maximum sequence length of 512 in fused_embedding_eltwise_layernorm for VarSeqlen (#65079)
* add support for a maximum sequence length of 512 in fused_embedding_eltwise_layernorm for VarSeqlen
1 parent d9b6b82 commit e7492fd

File tree

1 file changed

+18
-2
lines changed

1 file changed

+18
-2
lines changed

paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.cu

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -354,13 +354,21 @@ int32_t EmbLayerNormVarSeqlenPluginHFace::enqueue(
354354
int32_t batchSize = inputDesc[0].dims.d[0] - 1;
355355
// read out the maximum sequence length from the dummy input
356356
int32_t const maxSeqlen = inputDesc[nbLookupTables_].dims.d[1];
357-
int32_t S = 384;
357+
int32_t S = 512;
358358
if (maxSeqlen <= 128) {
359359
S = 128;
360360
} else if (maxSeqlen <= 192) {
361361
S = 192;
362362
} else if (maxSeqlen <= 256) {
363363
S = 256;
364+
} else if (maxSeqlen <= 384) {
365+
S = 384;
366+
} else if (maxSeqlen <= 512) {
367+
S = 512;
368+
} else {
369+
std::cerr << "fused_embedding_eltwise_layernorm'max sequence lengths is "
370+
"512 for VarSeqlen"
371+
<< std::endl;
364372
}
365373
const float* beta = mBetaDev.get();
366374
const float* gamma = mGammaDev.get();
@@ -507,13 +515,21 @@ int32_t EmbLayerNormVarSeqlenPluginMTron::enqueue(
507515
int32_t batchSize = inputDesc[0].dims.d[0] - 1;
508516
// read out the maximum sequence length from the dummy input
509517
int32_t const maxSeqlen = inputDesc[nbLookupTables_].dims.d[1];
510-
int32_t S = 384;
518+
int32_t S = 512;
511519
if (maxSeqlen <= 128) {
512520
S = 128;
513521
} else if (maxSeqlen <= 192) {
514522
S = 192;
515523
} else if (maxSeqlen <= 256) {
516524
S = 256;
525+
} else if (maxSeqlen <= 384) {
526+
S = 384;
527+
} else if (maxSeqlen <= 512) {
528+
S = 512;
529+
} else {
530+
std::cerr << "fused_embedding_eltwise_layernorm'max sequence lengths is "
531+
"512 for VarSeqlen"
532+
<< std::endl;
517533
}
518534
const float* beta = mBetaDev.get();
519535
const float* gamma = mGammaDev.get();

0 commit comments

Comments
 (0)