
Commit 3eafa1f

lxd-cumt, Charles-hit, and 0x45f authored
Auto codegen for supporting calling new_ir api in static operants (#56955)
* support new ir primitive operator in static operants
* support more vjp code gen
* support more vjp code gen
* support more vjp code gen
* use code gen
* fix operants codegen
* support more vjp code gen
* Fix ci build error
* set FLAGS_tensor_operants_mode to static in generated_vjp for testing
* fix bugs
* change the order of ops_name of divide_grad
* replace FLAGS_enable_new_ir_in_executor by FLAGS_enable_new_ir_api in codegen and test_vjp_prim

---------

Co-authored-by: Charles-hit <wanghao107@baidu.com>
Co-authored-by: 0x45f <wangzhen45@baidu.com>
1 parent c62902e commit 3eafa1f

File tree

6 files changed: +80, -22 lines changed


paddle/fluid/prim/api/auto_code_generated/tensor_operants_gen.py

Lines changed: 68 additions & 14 deletions
@@ -211,6 +211,11 @@ class StaticTensorOperants : public TensorOperantsBase {
 #include "paddle/fluid/prim/api/manual_prim/prim_manual_api.h"
 #include "paddle/fluid/prim/utils/static/desc_tensor.h"
+#include "paddle/fluid/primitive/backend/backend.h"
+#include "paddle/fluid/primitive/type/lazy_tensor.h"
+
+PHI_DECLARE_bool(enable_new_ir_api);
+
 """

@@ -219,47 +224,88 @@ class StaticTensorOperants : public TensorOperantsBase {

 namespace prim {
 using DescTensor = paddle::prim::DescTensor;
+using LazyTensor = paddle::primitive::LazyTensor;

 Tensor StaticTensorOperants::add(const Tensor& x, const Scalar& y) {
-  return paddle::prim::add<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
+  if (FLAGS_enable_new_ir_api) {
+    return paddle::primitive::backend::add<LazyTensor>(x, paddle::primitive::backend::full<LazyTensor>(x.shape(), y, x.dtype(), x.place()));
+  } else {
+    return paddle::prim::add<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
+  }
 }

 Tensor StaticTensorOperants::subtract(const Tensor& x, const Scalar& y) {
-  return paddle::prim::subtract<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
+  if (FLAGS_enable_new_ir_api) {
+    return paddle::primitive::backend::subtract<LazyTensor>(x, paddle::primitive::backend::full<LazyTensor>(x.shape(), y, x.dtype(), x.place()));
+  } else {
+    return paddle::prim::subtract<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
+  }
 }

 Tensor StaticTensorOperants::multiply(const Tensor& x, const Scalar& y) {
-  return paddle::prim::scale<DescTensor>(x, y, 0.0f, true);
+  if (FLAGS_enable_new_ir_api) {
+    return paddle::primitive::backend::scale<LazyTensor>(x, y, 0.0f, true);
+  } else {
+    return paddle::prim::scale<DescTensor>(x, y, 0.0f, true);
+  }
 }

 Tensor StaticTensorOperants::divide(const Tensor& x, const Scalar& y) {
-  return paddle::prim::divide<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
+  if (FLAGS_enable_new_ir_api) {
+    return paddle::primitive::backend::divide<LazyTensor>(x, paddle::primitive::backend::full<LazyTensor>(x.shape(), y, x.dtype(), x.place()));
+  } else {
+    return paddle::prim::divide<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
+  }
 }

 Tensor StaticTensorOperants::add(const Scalar& x, const Tensor& y) {
-  return paddle::prim::add<DescTensor>(paddle::prim::full<DescTensor>(y.shape(), x, y.dtype(), y.place()), y);
+  if (FLAGS_enable_new_ir_api) {
+    return paddle::primitive::backend::add<LazyTensor>(paddle::primitive::backend::full<LazyTensor>(y.shape(), x, y.dtype(), y.place()), y);
+  } else {
+    return paddle::prim::add<DescTensor>(paddle::prim::full<DescTensor>(y.shape(), x, y.dtype(), y.place()), y);
+  }
 }

+
 Tensor StaticTensorOperants::subtract(const Scalar& x, const Tensor& y) {
-  return paddle::prim::subtract<DescTensor>(paddle::prim::full<DescTensor>(y.shape(), x, y.dtype(), y.place()), y);
+  if (FLAGS_enable_new_ir_api) {
+    return paddle::primitive::backend::subtract<LazyTensor>(paddle::primitive::backend::full<LazyTensor>(y.shape(), x, y.dtype(), y.place()), y);
+  } else {
+    return paddle::prim::subtract<DescTensor>(paddle::prim::full<DescTensor>(y.shape(), x, y.dtype(), y.place()), y);
+  }
 }

 Tensor StaticTensorOperants::multiply(const Scalar& x, const Tensor& y) {
-  return paddle::prim::scale<DescTensor>(y, x, 0.0f, true);
+  if (FLAGS_enable_new_ir_api) {
+    return paddle::primitive::backend::scale<LazyTensor>(y, x, 0.0f, true);
+  } else {
+    return paddle::prim::scale<DescTensor>(y, x, 0.0f, true);
+  }
 }

 Tensor StaticTensorOperants::divide(const Scalar& x, const Tensor& y) {
-  return paddle::prim::divide<DescTensor>(paddle::prim::full<DescTensor>(y.shape(), x, y.dtype(), y.place()), y);
+  if (FLAGS_enable_new_ir_api) {
+    return paddle::primitive::backend::divide<LazyTensor>(paddle::primitive::backend::full<LazyTensor>(y.shape(), x, y.dtype(), y.place()), y);
+  } else {
+    return paddle::prim::divide<DescTensor>(paddle::prim::full<DescTensor>(y.shape(), x, y.dtype(), y.place()), y);
+  }
 }

 Tensor StaticTensorOperants::pow(const Tensor& x, const Tensor& y) {
-  return paddle::prim::elementwise_pow<DescTensor>(x, y);
+  if (FLAGS_enable_new_ir_api) {
+    return paddle::primitive::backend::elementwise_pow<LazyTensor>(x, y);
+  } else {
+    return paddle::prim::elementwise_pow<DescTensor>(x, y);
+  }
 }

 Tensor StaticTensorOperants::pow(const Tensor& x, const Scalar& y) {
-  return paddle::prim::elementwise_pow<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
+  if (FLAGS_enable_new_ir_api) {
+    return paddle::primitive::backend::elementwise_pow<LazyTensor>(x, paddle::primitive::backend::full<LazyTensor>(x.shape(), y, x.dtype(), x.place()));
+  } else {
+    return paddle::prim::elementwise_pow<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
+  }
 }
-
 """

@@ -339,13 +385,21 @@ def gene_eager_tensor_operants_implementation(self):

 def gene_static_tensor_func_call(self):
     api_func_name = self.get_api_func_name()
-
+    backend_static_func_name = (
+        'paddle::primitive::backend::' + api_func_name + '<LazyTensor>'
+    )
     prim_static_func_name = (
         'paddle::prim::' + api_func_name + '<DescTensor>'
    )
-    prim_static_func_parameters = self.get_func_args()
+    static_func_parameters = self.get_func_args()
+
+    static_tensor_func_call = f"""if (FLAGS_enable_new_ir_api) {{
+      return {backend_static_func_name}({static_func_parameters});
+    }} else {{
+      return {prim_static_func_name}({static_func_parameters});
+    }}"""

-    return f"""return {prim_static_func_name}({prim_static_func_parameters});"""
+    return static_tensor_func_call

 def gene_static_tensor_operants_implementation(self):
     api_code = ""
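
For illustration, here is a minimal sketch of the method body this template now produces for a generated operant, assuming a hypothetical unary API named exp (the real operant list comes from the parsed API definitions, not from this diff):

// Sketch only: the op name `exp` and its argument list are assumptions.
Tensor StaticTensorOperants::exp(const Tensor& x) {
  if (FLAGS_enable_new_ir_api) {
    return paddle::primitive::backend::exp<LazyTensor>(x);
  } else {
    return paddle::prim::exp<DescTensor>(x);
  }
}

Both branches forward the same arguments to the same primitive; only the tensor representation (DescTensor for the legacy program desc, LazyTensor for the new IR) and the backend namespace differ, mirroring the hand-written operants above.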

paddle/fluid/prim/utils/static/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -6,4 +6,4 @@ cc_library(
 cc_library(
   static_tensor_operants
   SRCS static_tensor_operants.cc
-  DEPS static_prim_api)
+  DEPS static_prim_api primitive_backend_static_experimental)

paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2

Lines changed: 3 additions & 0 deletions
@@ -10,7 +10,9 @@
 #include "paddle/fluid/primitive/type/lazy_tensor.h"
 #include "paddle/fluid/primitive/utils/utils.h"
 #include "paddle/ir/core/operation.h"
+#include "paddle/phi/core/flags.h"

+PHI_DECLARE_string(tensor_operants_mode);

 namespace paddle {
 namespace primitive {

@@ -95,6 +97,7 @@ for (size_t i=0; i< stop_gradients[0].size(); i++ ) {
 {% endmacro %}

 {% macro body_prim(api) %}
+  FLAGS_tensor_operants_mode = "static";
 {% for i in range(api.outputs|length) %}
 {% if api.outputs[i].typename=='Tensor' %}
 paddle::Tensor* {{api.outputs[i].name}} = !stop_gradients[{{i}}][0] ? &vjp_res[{{i}}][0] : nullptr;
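
As a rough sketch of the rendered output, the prim branch of a generated VJP now begins by pinning the operants mode before binding its outputs; the output name below is a placeholder, not taken from the real generated file:

// Hypothetical rendering of body_prim for an op with one Tensor output named "x_grad".
FLAGS_tensor_operants_mode = "static";
paddle::Tensor* x_grad = !stop_gradients[0][0] ? &vjp_res[0][0] : nullptr;

Presumably this is what lets the overloaded Tensor operators used inside the prim rules (see the rewritten divide_grad below) dispatch to the static tensor operants during testing, as the commit message notes.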

paddle/fluid/primitive/rule/vjp/details.h

Lines changed: 2 additions & 5 deletions
@@ -39,10 +39,7 @@ void divide_grad(const Tensor& x,
                  Tensor* dy) {
   if (dy) {
     // dy = -(x/y^2) * dout
-    auto denominator =
-        elementwise_pow<T>(y, full<T>(y.shape(), 2.0, y.dtype(), y.place()));
-    auto dy_res = scale<T>(
-        multiply<T>(divide<T>(x, denominator), out_grad), -1.0, 0.0, true);
+    auto dy_res = -(x / y.pow(2.0)) * out_grad;
     if (x.dims() != y.dims()) {
       // Maybe need reduce here
       phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());

@@ -61,7 +58,7 @@ void divide_grad(const Tensor& x,
   if (dx) {
     // dx = (1/y) * dout
     auto one_tensor = full<T>(phi::vectorize(y.dims()), 1.0, y.dtype());
-    auto dx_res = multiply<T>(divide<T>(one_tensor, y), out_grad);
+    auto dx_res = one_tensor / y * out_grad;
     if (y.dims() != x.dims()) {
       // Maybe need reduce here
       auto reduce_dim = get_reduce_dims(x.dims(), y.dims());
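
For reference, the quotient-rule gradients that both the old and the rewritten code compute are

$$ \frac{\partial}{\partial x}\left(\frac{x}{y}\right) = \frac{1}{y}, \qquad \frac{\partial}{\partial y}\left(\frac{x}{y}\right) = -\frac{x}{y^{2}}, $$

so dx = (1/y) * out_grad and dy = -(x/y^2) * out_grad, followed by a sum over the broadcast dimensions when x and y differ in shape. The rewrite only changes notation: explicit multiply/divide/scale prim calls become overloaded Tensor operators.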

paddle/phi/core/extended_tensor.cc

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ DataType ExtendedTensor::dtype() const {

 DataLayout ExtendedTensor::layout() const {
   PADDLE_THROW(phi::errors::Unavailable(
-      "ExtendedTensor does not support `dtype` method."));
+      "ExtendedTensor does not support `layout` method."));
 }

 bool ExtendedTensor::valid() const {

test/prim/new_ir_prim/test_vjp_prim.py

Lines changed: 5 additions & 1 deletion
@@ -63,6 +63,7 @@ class TestVjpPrim(unittest.TestCase):
     def test_divide_grad_prim_case1(self):
         newir_program = get_ir_divide_program()
         paddle.framework.core._set_prim_backward_enabled(True)
+        paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
         dout = newir_program.block().ops[-2].result(0)
         out_grads = [[dout]]
         stop_gradients = [[False], [False]]

@@ -83,9 +84,9 @@ def test_divide_grad_prim_case1(self):
             "pd.full",
             "pd.elementwise_pow",
             "pd.divide",
-            "pd.multiply",
             "pd.full",
             "pd.scale",
+            "pd.multiply",
             "pd.full_int_array",
             "pd.sum",
             "pd.full_int_array",

@@ -101,6 +102,7 @@ def test_divide_grad_prim_case1(self):
         for idx, op in enumerate(newir_program.block().ops):
             self.assertEqual(op.name(), all_op_names[idx])
         paddle.framework.core._set_prim_backward_enabled(False)
+        paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})

     def test_divide_grad_no_prim(self):
         newir_program = get_ir_divide_program()

@@ -123,6 +125,7 @@ def test_divide_grad_no_prim(self):
     def test_sum_grad_prim(self):
         newir_program = get_ir_sum_program()
         paddle.framework.core._set_prim_backward_enabled(True)
+        paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
         dout = newir_program.block().ops[-3].result(0)
         out_grads = [[dout]]
         stop_gradients = [[False], [True]]

@@ -147,6 +150,7 @@ def test_sum_grad_prim(self):
         for idx, op in enumerate(newir_program.block().ops):
             self.assertEqual(op.name(), all_op_names[idx])
         paddle.framework.core._set_prim_backward_enabled(False)
+        paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})

     def test_sum_grad_no_prim(self):
         newir_program = get_ir_sum_program()
