From 809650c8207651d2d5adf6c6012fc4e8f6e769a3 Mon Sep 17 00:00:00 2001 From: RedContritio Date: Thu, 9 Feb 2023 13:42:05 +0000 Subject: [PATCH] add tensor numel check for float --- paddle/fluid/pybind/op_function_common.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/pybind/op_function_common.cc b/paddle/fluid/pybind/op_function_common.cc index 9f97556e20040..629419c2a4073 100644 --- a/paddle/fluid/pybind/op_function_common.cc +++ b/paddle/fluid/pybind/op_function_common.cc @@ -92,7 +92,8 @@ bool PyObject_CheckFloatOrToFloat(PyObject** obj) { // sometimes users provide PyLong or numpy.int64 but attr is float if (PyFloat_Check(*obj) || PyLong_Check(*obj) || PyObject_IsInstance(*obj, (PyObject*)g_varbase_pytype) || // NOLINT - PyObject_IsInstance(*obj, (PyObject*)p_tensor_type)) { // NOLINT + (PyObject_IsInstance(*obj, (PyObject*)p_tensor_type) && // NOLINT + (((TensorObject*)(*obj))->tensor.numel() == 1))) { // NOLINT return true; } if (std::string(((PyTypeObject*)(*obj)->ob_type)->tp_name) // NOLINT