@@ -481,6 +481,165 @@ def test(self):
481
481
self .run_test ()
482
482
483
483
484
+ class TrtConvertGreaterEqualTest (TrtLayerAutoScanTest ):
485
+ def is_program_valid (self , program_config : ProgramConfig ) -> bool :
486
+ return True
487
+
488
+ def sample_program_configs (self ):
489
+ def generate_input (shape ):
490
+ return np .random .random (shape ).astype (np .float32 )
491
+
492
+ for shape in [[2 , 16 ], [2 , 16 , 32 ], [1 , 32 , 16 , 32 ]]:
493
+ for op_type in ["greater_equal" ]:
494
+ for axis in [- 1 ]:
495
+ self .dims = len (shape )
496
+ dics = [
497
+ {"axis" : axis },
498
+ {"in_dtype" : 5 , "out_dtype" : 2 },
499
+ {"in_dtype" : 0 , "out_dtype" : 5 },
500
+ ]
501
+ ops_config = [
502
+ {
503
+ "op_type" : "cast" ,
504
+ "op_inputs" : {"X" : ["input_data1" ]},
505
+ "op_outputs" : {"Out" : ["cast_output_data1" ]},
506
+ "op_attrs" : dics [1 ],
507
+ "outputs_dtype" : {"cast_output_data1" : np .int32 },
508
+ },
509
+ {
510
+ "op_type" : "cast" ,
511
+ "op_inputs" : {"X" : ["input_data2" ]},
512
+ "op_outputs" : {"Out" : ["cast_output_data2" ]},
513
+ "op_attrs" : dics [1 ],
514
+ "outputs_dtype" : {"cast_output_data2" : np .int32 },
515
+ },
516
+ {
517
+ "op_type" : op_type ,
518
+ "op_inputs" : {
519
+ "X" : ["cast_output_data1" ],
520
+ "Y" : ["cast_output_data2" ],
521
+ },
522
+ "op_outputs" : {"Out" : ["cast_output_data0" ]},
523
+ "op_attrs" : dics [0 ],
524
+ },
525
+ {
526
+ "op_type" : "cast" ,
527
+ "op_inputs" : {"X" : ["cast_output_data0" ]},
528
+ "op_outputs" : {"Out" : ["output_data" ]},
529
+ "op_attrs" : dics [2 ],
530
+ },
531
+ ]
532
+ ops = self .generate_op_config (ops_config )
533
+
534
+ program_config = ProgramConfig (
535
+ ops = ops ,
536
+ weights = {},
537
+ inputs = {
538
+ "input_data1" : TensorConfig (
539
+ data_gen = partial (generate_input , shape )
540
+ ),
541
+ "input_data2" : TensorConfig (
542
+ data_gen = partial (generate_input , shape )
543
+ ),
544
+ },
545
+ outputs = ["output_data" ],
546
+ )
547
+
548
+ yield program_config
549
+
550
+ def sample_predictor_configs (
551
+ self , program_config
552
+ ) -> (paddle_infer .Config , List [int ], float ):
553
+ def generate_dynamic_shape (attrs ):
554
+ if self .dims == 2 :
555
+ self .dynamic_shape .min_input_shape = {
556
+ "input_data1" : [2 , 16 ],
557
+ "input_data2" : [2 , 16 ],
558
+ }
559
+ self .dynamic_shape .max_input_shape = {
560
+ "input_data1" : [2 , 16 ],
561
+ "input_data2" : [2 , 16 ],
562
+ }
563
+ self .dynamic_shape .opt_input_shape = {
564
+ "input_data1" : [2 , 16 ],
565
+ "input_data2" : [2 , 16 ],
566
+ }
567
+ if self .dims == 3 :
568
+ self .dynamic_shape .min_input_shape = {
569
+ "input_data1" : [2 , 16 , 32 ],
570
+ "input_data2" : [2 , 16 , 32 ],
571
+ }
572
+ self .dynamic_shape .max_input_shape = {
573
+ "input_data1" : [2 , 16 , 32 ],
574
+ "input_data2" : [2 , 16 , 32 ],
575
+ }
576
+ self .dynamic_shape .opt_input_shape = {
577
+ "input_data1" : [2 , 16 , 32 ],
578
+ "input_data2" : [2 , 16 , 32 ],
579
+ }
580
+ if self .dims == 4 :
581
+ self .dynamic_shape .min_input_shape = {
582
+ "input_data1" : [1 , 32 , 16 , 32 ],
583
+ "input_data2" : [1 , 32 , 16 , 32 ],
584
+ }
585
+ self .dynamic_shape .max_input_shape = {
586
+ "input_data1" : [1 , 32 , 16 , 32 ],
587
+ "input_data2" : [1 , 32 , 16 , 32 ],
588
+ }
589
+ self .dynamic_shape .opt_input_shape = {
590
+ "input_data1" : [1 , 32 , 16 , 32 ],
591
+ "input_data2" : [1 , 32 , 16 , 32 ],
592
+ }
593
+
594
+ def clear_dynamic_shape ():
595
+ self .dynamic_shape .max_input_shape = {}
596
+ self .dynamic_shape .min_input_shape = {}
597
+ self .dynamic_shape .opt_input_shape = {}
598
+
599
+ def generate_trt_nodes_num (attrs , dynamic_shape ):
600
+ ver = paddle_infer .get_trt_compile_version ()
601
+ if (
602
+ ver [0 ] * 1000 + ver [1 ] * 100 + ver [2 ] * 10 < 8400
603
+ or not dynamic_shape
604
+ ):
605
+ return 2 , 5
606
+ else :
607
+ return 1 , 3
608
+
609
+ attrs = [
610
+ program_config .ops [i ].attrs for i in range (len (program_config .ops ))
611
+ ]
612
+
613
+ # for static_shape
614
+ clear_dynamic_shape ()
615
+ self .trt_param .precision = paddle_infer .PrecisionType .Float32
616
+ yield self .create_inference_config (), generate_trt_nodes_num (
617
+ attrs , False
618
+ ), 1e-5
619
+ self .trt_param .precision = paddle_infer .PrecisionType .Half
620
+ yield self .create_inference_config (), generate_trt_nodes_num (
621
+ attrs , False
622
+ ), (1e-3 , 1e-3 )
623
+
624
+ # for dynamic_shape
625
+ generate_dynamic_shape (attrs )
626
+ self .trt_param .precision = paddle_infer .PrecisionType .Float32
627
+ yield self .create_inference_config (), generate_trt_nodes_num (
628
+ attrs , True
629
+ ), 1e-5
630
+ self .trt_param .precision = paddle_infer .PrecisionType .Half
631
+ yield self .create_inference_config (), generate_trt_nodes_num (
632
+ attrs , True
633
+ ), (1e-3 , 1e-3 )
634
+
635
+ def add_skip_trt_case (self ):
636
+ pass
637
+
638
+ def test (self ):
639
+ self .add_skip_trt_case ()
640
+ self .run_test ()
641
+
642
+
484
643
class TrtConvertCompareSkipTest (TrtLayerAutoScanTest ):
485
644
def is_program_valid (self , program_config : ProgramConfig ) -> bool :
486
645
return True
0 commit comments