@@ -2434,16 +2434,19 @@ void IndexSelectInferMeta(const MetaTensor& x,
2434
2434
" the dimension of Input(Index) is [%d]." ,
2435
2435
index_dim,
2436
2436
index_dim.size ()));
2437
-
2438
- PADDLE_ENFORCE_EQ (index_dim[0 ] != 0 ,
2439
- true ,
2440
- common::errors::InvalidArgument (
2441
- " The length of Input(Index) can't be 0." ));
2442
-
2443
- auto output_dim = common::vectorize (input_dim);
2444
2437
if (dim < 0 ) {
2445
2438
dim += input_dim.size ();
2446
2439
}
2440
+
2441
+ if (input_dim[dim] != 0 ) {
2442
+ PADDLE_ENFORCE_EQ (index_dim[0 ] != 0 ,
2443
+ true ,
2444
+ common::errors::InvalidArgument (
2445
+ " The length of Input(Index) can't be 0." ));
2446
+ }
2447
+
2448
+ auto output_dim = common::vectorize (input_dim);
2449
+
2447
2450
output_dim[dim] = index_dim[0 ];
2448
2451
output->set_dims (common::make_ddim (output_dim));
2449
2452
output->set_dtype (x.dtype ());
@@ -3668,18 +3671,23 @@ void RepeatInterleaveWithTensorIndexInferMeta(const MetaTensor& x,
3668
3671
repeats_dim,
3669
3672
repeats_dim.size ()));
3670
3673
3671
- PADDLE_ENFORCE_EQ (repeats_dim[0 ] != 0 ,
3672
- true ,
3673
- common::errors::InvalidArgument (
3674
- " The length of Input(RepeatsTensor) can't be 0." ));
3675
- PADDLE_ENFORCE_NE (out,
3676
- nullptr ,
3677
- common::errors::InvalidArgument (
3678
- " repeat_interleave's output tensor can't be nullptr" ));
3679
- if (dim < 0 ) {
3680
- dim += input_dim.size ();
3674
+ if (input_dim.size () == 1 && input_dim[0 ] == 0 ) {
3675
+ output_dim[0 ] = 0 ;
3676
+ } else {
3677
+ PADDLE_ENFORCE_EQ (repeats_dim[0 ] != 0 ,
3678
+ true ,
3679
+ common::errors::InvalidArgument (
3680
+ " The length of Input(RepeatsTensor) can't be 0." ));
3681
+ PADDLE_ENFORCE_NE (
3682
+ out,
3683
+ nullptr ,
3684
+ common::errors::InvalidArgument (
3685
+ " repeat_interleave's output tensor can't be nullptr" ));
3686
+ if (dim < 0 ) {
3687
+ dim += input_dim.size ();
3688
+ }
3689
+ output_dim[dim] = -1 ;
3681
3690
}
3682
- output_dim[dim] = -1 ;
3683
3691
3684
3692
out->set_dims (common::make_ddim (output_dim));
3685
3693
out->share_lod (x);
0 commit comments