Skip to content

Commit a053903

Browse files
EnigmatismsGITD245
authored andcommitted
[CINN] Fix arange float16 and bfloat16 support (PaddlePaddle#72669)
1 parent 928af82 commit a053903

File tree

3 files changed

+23
-1
lines changed

3 files changed

+23
-1
lines changed

paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc

+6
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,12 @@ class ArangeOpPattern
554554
case phi::DataType::INT32:
555555
input = phi::Scalar(input.to<int>());
556556
break;
557+
case phi::DataType::FLOAT16:
558+
input = phi::Scalar(input.to<float>());
559+
break;
560+
case phi::DataType::BFLOAT16:
561+
input = phi::Scalar(input.to<float>());
562+
break;
557563
default:
558564
input = phi::Scalar(input.to<int64_t>());
559565
}

paddle/cinn/hlir/op/elementwise.cc

+13-1
Original file line numberDiff line numberDiff line change
@@ -1310,8 +1310,20 @@ std::shared_ptr<framework::OpStrategy> StrategyForArangeSymbolic(
13101310
EXPR_FROM_ATTR(int)
13111311
} else if (dtype.is_int(64)) {
13121312
EXPR_FROM_ATTR(int64_t)
1313+
} else if (dtype.is_bfloat16()) {
1314+
EXPR_FROM_ATTR(float)
1315+
start->set_type(cinn::common::BFloat16());
1316+
step->set_type(cinn::common::BFloat16());
1317+
} else if (dtype.is_float16()) {
1318+
EXPR_FROM_ATTR(float)
1319+
start->set_type(cinn::common::Float16());
1320+
step->set_type(cinn::common::Float16());
13131321
} else {
1314-
CINN_NOT_IMPLEMENTED
1322+
PADDLE_ENFORCE_NOT_NULL(
1323+
nullptr,
1324+
::common::errors::InvalidArgument(
1325+
"The dtype of arange op should be float32, float64, int32, int64, "
1326+
"bfloat16 or float16."));
13151327
}
13161328

13171329
#undef EXPR_FROM_ATTR

paddle/phi/infermeta/nullary.cc

+4
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ void ArangeInferMeta(const Scalar& start,
5757
GET_SIZE_GIVEN_TYPE(double)
5858
case DataType::INT32:
5959
GET_SIZE_GIVEN_TYPE(int)
60+
case DataType::FLOAT16:
61+
GET_SIZE_GIVEN_TYPE(float)
62+
case DataType::BFLOAT16:
63+
GET_SIZE_GIVEN_TYPE(float)
6064
default:
6165
GET_SIZE_GIVEN_TYPE(int64_t)
6266
}

0 commit comments

Comments
 (0)