File tree: 1 file changed, +18 -2 lines changed
paddle/fluid/inference/tensorrt/plugin: 1 file changed, +18 -2 lines changed
Columns: original file line number | diff line number | diff line change
@@ -354,13 +354,21 @@ int32_t EmbLayerNormVarSeqlenPluginHFace::enqueue(
354
354
int32_t batchSize = inputDesc[0 ].dims .d [0 ] - 1 ;
355
355
// read out the maximum sequence length from the dummy input
356
356
int32_t const maxSeqlen = inputDesc[nbLookupTables_].dims .d [1 ];
357
- int32_t S = 384 ;
357
+ int32_t S = 512 ;
358
358
if (maxSeqlen <= 128 ) {
359
359
S = 128 ;
360
360
} else if (maxSeqlen <= 192 ) {
361
361
S = 192 ;
362
362
} else if (maxSeqlen <= 256 ) {
363
363
S = 256 ;
364
+ } else if (maxSeqlen <= 384 ) {
365
+ S = 384 ;
366
+ } else if (maxSeqlen <= 512 ) {
367
+ S = 512 ;
368
+ } else {
369
+ std::cerr << " fused_embedding_eltwise_layernorm'max sequence lengths is "
370
+ " 512 for VarSeqlen"
371
+ << std::endl;
364
372
}
365
373
const float * beta = mBetaDev .get ();
366
374
const float * gamma = mGammaDev .get ();
@@ -507,13 +515,21 @@ int32_t EmbLayerNormVarSeqlenPluginMTron::enqueue(
507
515
int32_t batchSize = inputDesc[0 ].dims .d [0 ] - 1 ;
508
516
// read out the maximum sequence length from the dummy input
509
517
int32_t const maxSeqlen = inputDesc[nbLookupTables_].dims .d [1 ];
510
- int32_t S = 384 ;
518
+ int32_t S = 512 ;
511
519
if (maxSeqlen <= 128 ) {
512
520
S = 128 ;
513
521
} else if (maxSeqlen <= 192 ) {
514
522
S = 192 ;
515
523
} else if (maxSeqlen <= 256 ) {
516
524
S = 256 ;
525
+ } else if (maxSeqlen <= 384 ) {
526
+ S = 384 ;
527
+ } else if (maxSeqlen <= 512 ) {
528
+ S = 512 ;
529
+ } else {
530
+ std::cerr << " fused_embedding_eltwise_layernorm'max sequence lengths is "
531
+ " 512 for VarSeqlen"
532
+ << std::endl;
517
533
}
518
534
const float * beta = mBetaDev .get ();
519
535
const float * gamma = mGammaDev .get ();
You can’t perform that action at this time.
0 commit comments