Skip to content

Commit bc7380d

Browse files
authored
[pir] add __neg__ and clean pir.cc (#60166)
1 parent 81d800b commit bc7380d

File tree

3 files changed

+25
-64
lines changed

3 files changed

+25
-64
lines changed

paddle/fluid/pybind/pir.cc

Lines changed: 0 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -581,55 +581,6 @@ const phi::DDim &GetValueDims(Value value) {
581581
}
582582
}
583583

584-
#define OVERRIDE_OPERATOR(operator, api, other_type) \
585-
value.def(#operator, [](Value self, other_type other) { \
586-
return paddle::dialect::api(self, other); \
587-
});
588-
589-
#define OVERRIDE_OPERATOR_WITH_SCALE(operator, \
590-
other_type, \
591-
scale_value, \
592-
bias_value, \
593-
bias_after_scale) \
594-
value.def(#operator, [](Value self, other_type other) { \
595-
return paddle::dialect::scale( \
596-
self, scale_value, bias_value, bias_after_scale); \
597-
});
598-
599-
#define OVERRIDE_OPERATOR_FOR_EACH(operator, \
600-
api, \
601-
scale_value, \
602-
bias_value, \
603-
bias_after_scale) \
604-
OVERRIDE_OPERATOR(operator, api, Value) \
605-
OVERRIDE_OPERATOR_WITH_SCALE(operator, \
606-
int, \
607-
scale_value, \
608-
bias_value, \
609-
bias_after_scale) \
610-
OVERRIDE_OPERATOR_WITH_SCALE(operator, \
611-
float, \
612-
scale_value, \
613-
bias_value, \
614-
bias_after_scale) \
615-
OVERRIDE_OPERATOR_WITH_SCALE(operator, \
616-
double, \
617-
scale_value, \
618-
bias_value, \
619-
bias_after_scale)
620-
621-
#define OVERRIDE_COMPARE_OP_WITH_FULL(operator, api, other_type) \
622-
value.def(#operator, [](Value self, other_type other) { \
623-
auto rhs = \
624-
paddle::dialect::full(/*shape=*/{}, other, GetValueDtype(self)); \
625-
return paddle::dialect::api(self, rhs); \
626-
});
627-
628-
#define OVERRIDE_COMPARE_OP_FOR_EACH(operator, api) \
629-
OVERRIDE_OPERATOR(operator, api, Value) \
630-
OVERRIDE_COMPARE_OP_WITH_FULL(operator, api, int) \
631-
OVERRIDE_COMPARE_OP_WITH_FULL(operator, api, float) \
632-
OVERRIDE_COMPARE_OP_WITH_FULL(operator, api, double)
633584
void BindValue(py::module *m) {
634585
py::class_<Value> value(*m, "Value", R"DOC(
635586
Value class represents the SSA value in the IR system. It is a directed edge
@@ -787,23 +738,9 @@ void BindValue(py::module *m) {
787738
print_stream << ")";
788739
return print_stream.str();
789740
})
790-
.def("__neg__",
791-
[](Value self) {
792-
return paddle::dialect::scale(self, -1.0, 0.0, true);
793-
})
794741
.def("is_same", &Value::operator==)
795742
.def("hash", [](Value self) { return std::hash<pir::Value>{}(self); })
796743
.def("__repr__", &Value2String);
797-
// For basaic operators
798-
OVERRIDE_OPERATOR_FOR_EACH(__add__, add, 1.0, other, true);
799-
OVERRIDE_OPERATOR_FOR_EACH(__sub__, subtract, 1.0, -1.0 * other, true);
800-
OVERRIDE_OPERATOR_FOR_EACH(__mul__, multiply, other, 0.0, false);
801-
OVERRIDE_OPERATOR_FOR_EACH(__truediv__, divide, 1.0 / other, 0.0, false);
802-
// For compare opeartors
803-
OVERRIDE_COMPARE_OP_FOR_EACH(__lt__, less_than);
804-
OVERRIDE_COMPARE_OP_FOR_EACH(__le__, less_equal);
805-
OVERRIDE_COMPARE_OP_FOR_EACH(__gt__, greater_than);
806-
OVERRIDE_COMPARE_OP_FOR_EACH(__ge__, greater_equal);
807744
}
808745

809746
void BindOpOperand(py::module *m) {

python/paddle/pir/math_op_patch.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,9 @@ def _scalar_mul_(var, value):
281281
def _scalar_div_(var, value):
282282
return paddle.scale(var, 1.0 / value, 0.0)
283283

284+
def _scalar_neg_(var):
285+
return paddle.scale(var, -1.0, 0.0)
286+
284287
def _binary_creator_(
285288
method_name,
286289
python_api,
@@ -513,6 +516,7 @@ def value_hash(self):
513516
('append', append),
514517
('set_shape', set_shape),
515518
('__hash__', value_hash),
519+
# For basic operators
516520
(
517521
'__add__',
518522
_binary_creator_('__add__', paddle.tensor.add, False, _scalar_add_),
@@ -591,7 +595,8 @@ def value_hash(self):
591595
'__matmul__',
592596
_binary_creator_('__matmul__', paddle.tensor.matmul, False, None),
593597
),
594-
# for logical compare
598+
('__neg__', _scalar_neg_),
599+
# For compare opeartors
595600
(
596601
'__eq__',
597602
_binary_creator_('__eq__', paddle.tensor.equal, False, None),

test/legacy_test/test_math_op_patch_pir.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,25 @@ def test_append(self):
491491
with self.assertRaises(TypeError):
492492
x.append(array)
493493

494+
def test_neg(self):
495+
x_np = np.random.uniform(-1, 1, [10, 1024]).astype(np.float32)
496+
res = -x_np
497+
with paddle.pir_utils.IrGuard():
498+
main_program, exe, program_guard = new_program()
499+
with program_guard:
500+
x = paddle.static.data(
501+
name='x', shape=[10, 1024], dtype="float32"
502+
)
503+
a = -x
504+
b = x.__neg__()
505+
(a_np, b_np) = exe.run(
506+
main_program,
507+
feed={"x": x_np},
508+
fetch_list=[a, b],
509+
)
510+
np.testing.assert_array_equal(res, a_np)
511+
np.testing.assert_array_equal(res, b_np)
512+
494513
def test_math_exists(self):
495514
with paddle.pir_utils.IrGuard():
496515
a = paddle.static.data(name='a', shape=[1], dtype='float32')

0 commit comments

Comments
 (0)