From 9118132fe9145f11a4086c39604adf496711d6a9 Mon Sep 17 00:00:00 2001 From: Enigmatisms Date: Mon, 12 May 2025 04:05:49 +0000 Subject: [PATCH] [CINN] Fixed arange float16 and bfloat16 support --- .../dialect/operator/transforms/pd_to_cinn_pass.cc | 6 ++++++ paddle/cinn/hlir/op/elementwise.cc | 14 +++++++++++++- paddle/phi/infermeta/nullary.cc | 4 ++++ 3 files changed, 23 insertions(+), 1 deletion(-) diff --git a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc index b9423106fe0d0..ffbd4e4f7b8e3 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc @@ -554,6 +554,12 @@ class ArangeOpPattern case phi::DataType::INT32: input = phi::Scalar(input.to()); break; + case phi::DataType::FLOAT16: + input = phi::Scalar(input.to()); + break; + case phi::DataType::BFLOAT16: + input = phi::Scalar(input.to()); + break; default: input = phi::Scalar(input.to()); } diff --git a/paddle/cinn/hlir/op/elementwise.cc b/paddle/cinn/hlir/op/elementwise.cc index 838f999bd1780..1d6c1b6fc2d73 100644 --- a/paddle/cinn/hlir/op/elementwise.cc +++ b/paddle/cinn/hlir/op/elementwise.cc @@ -1310,8 +1310,20 @@ std::shared_ptr StrategyForArangeSymbolic( EXPR_FROM_ATTR(int) } else if (dtype.is_int(64)) { EXPR_FROM_ATTR(int64_t) + } else if (dtype.is_bfloat16()) { + EXPR_FROM_ATTR(float) + start->set_type(cinn::common::BFloat16()); + step->set_type(cinn::common::BFloat16()); + } else if (dtype.is_float16()) { + EXPR_FROM_ATTR(float) + start->set_type(cinn::common::Float16()); + step->set_type(cinn::common::Float16()); } else { - CINN_NOT_IMPLEMENTED + PADDLE_ENFORCE_NOT_NULL( + nullptr, + ::common::errors::InvalidArgument( + "The dtype of arange op should be float32, float64, int32, int64, " + "bfloat16 or float16.")); } #undef EXPR_FROM_ATTR diff --git a/paddle/phi/infermeta/nullary.cc b/paddle/phi/infermeta/nullary.cc index 7df8862d9989f..020e87144f70d 100644 --- a/paddle/phi/infermeta/nullary.cc +++ b/paddle/phi/infermeta/nullary.cc @@ -57,6 +57,10 @@ void ArangeInferMeta(const Scalar& start, GET_SIZE_GIVEN_TYPE(double) case DataType::INT32: GET_SIZE_GIVEN_TYPE(int) + case DataType::FLOAT16: + GET_SIZE_GIVEN_TYPE(float) + case DataType::BFLOAT16: + GET_SIZE_GIVEN_TYPE(float) default: GET_SIZE_GIVEN_TYPE(int64_t) }