Skip to content

Commit fdd7f9c

Browse files
authored
[mlir][tosa] Allow scalar int8 tensors to be unranked (#150731)
This PR fixes #150519
1 parent 2adbf9e commit fdd7f9c

File tree

3 files changed

+14
-3
lines changed

3 files changed

+14
-3
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
151151
def Tosa_Rank0Tensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;
152152

153153
def Tosa_ScalarTensor : AnyTypeOf<[TosaUnrankedTensorOf<[Tosa_Int, AnyFloat]>, TosaScalarTensorOf<[Tosa_AnyNumber], [1]>]>;
154-
def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>;
154+
def Tosa_ScalarInt8Tensor : AnyTypeOf<[TosaUnrankedTensorOf<[Tosa_Int8]>, TosaScalarTensorOf<[Tosa_Int8], [1]>]>;
155155
def Tosa_ScalarIntOrFloatTensor : AnyTypeOf<[TosaUnrankedTensorOf<[Tosa_Int, AnyFloat]>, TosaScalarTensorOf<[Tosa_Int, AnyFloat], [1]>]>;
156156

157157
// We include unranked tensors as a supported type for all possible tosa

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1125,7 +1125,7 @@ func.func @test_sub_with_unequal_result_ranks(%arg0: tensor<1x21x3xf32>, %arg1:
11251125
// CHECK-LABEL: test_mul_non_scalar_shift_2d
11261126
func.func @test_mul_non_scalar_shift_2d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
11271127
%shift = "tosa.const"() <{values = dense<0> : tensor<1x1xi8>}> : () -> tensor<1x1xi8>
1128-
// expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<1x1xi8>'}}
1128+
// expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant unranked tensor of 8-bit signless integer values or tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<1x1xi8>'}}
11291129
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1x1xi8>) -> tensor<13x21x3xf32>
11301130
return %0 : tensor<13x21x3xf32>
11311131
}
@@ -1134,7 +1134,7 @@ func.func @test_mul_non_scalar_shift_2d(%arg0: tensor<13x21x3xf32>, %arg1: tenso
11341134
// CHECK-LABEL: test_mul_non_scalar_shift_1d
11351135
func.func @test_mul_non_scalar_shift_1d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
11361136
%shift = "tosa.const"() <{values = dense<0> : tensor<2xi8>}> : () -> tensor<2xi8>
1137-
// expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<2xi8>'}}
1137+
// expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant unranked tensor of 8-bit signless integer values or tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<2xi8>'}}
11381138
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<2xi8>) -> tensor<13x21x3xf32>
11391139
return %0 : tensor<13x21x3xf32>
11401140
}

mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,17 @@ func.func @test_accepts_unranked_scalar_tensor(%arg0: tensor<1x2x2xf32>, %arg1:
357357

358358
// -----
359359

360+
// CHECK-LABEL: @test_unranked_scalar_i8_tensor
361+
func.func @test_unranked_scalar_i8_tensor(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>, %arg2: tensor<1xi8>) -> tensor<4xi32> {
362+
// CHECK: %[[SHIFT:.*]] = tosa.cast %arg2 : (tensor<1xi8>) -> tensor<1xi8>
363+
%shift = tosa.cast %arg2 : (tensor<1xi8>) -> tensor<*xi8>
364+
// CHECK: tosa.mul %arg0, %arg1, %[[SHIFT]] : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
365+
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<*xi8>) -> tensor<4xi32>
366+
return %0 : tensor<4xi32>
367+
}
368+
369+
// -----
370+
360371
// CHECK-LABEL: @test_table_static
361372
func.func @test_table_static(%arg0 : tensor<4x5xi16>, %arg1 : tensor<513xi16>) -> () {
362373
// CHECK:tosa.table %arg0, %arg1 : (tensor<4x5xi16>, tensor<513xi16>) -> tensor<4x5xi16>

0 commit comments

Comments
 (0)