@@ -211,6 +211,11 @@ class StaticTensorOperants : public TensorOperantsBase {
#include "paddle/fluid/prim/api/manual_prim/prim_manual_api.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"

+ #include "paddle/fluid/primitive/backend/backend.h"
+ #include "paddle/fluid/primitive/type/lazy_tensor.h"
+
+ PHI_DECLARE_bool(enable_new_ir_api);
+
"""

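Note: PHI_DECLARE_bool only brings the flag declaration into this generated source; the matching definition is not part of this diff and is assumed to live alongside the other PHI flags, presumably something like:

    PHI_DEFINE_EXPORTED_bool(enable_new_ir_api, false,
                             "Use the new IR primitive backend in static prim ops");

(the macro site, default value, and help text above are illustrative assumptions, not taken from this patch).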
@@ -219,47 +224,88 @@ class StaticTensorOperants : public TensorOperantsBase {

namespace prim {
using DescTensor = paddle::prim::DescTensor;
+ using LazyTensor = paddle::primitive::LazyTensor;

Tensor StaticTensorOperants::add(const Tensor& x, const Scalar& y) {
-   return paddle::prim::add<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
+   if (FLAGS_enable_new_ir_api) {
+     return paddle::primitive::backend::add<LazyTensor>(x, paddle::primitive::backend::full<LazyTensor>(x.shape(), y, x.dtype(), x.place()));
+   } else {
+     return paddle::prim::add<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
+   }
}

Tensor StaticTensorOperants::subtract(const Tensor& x, const Scalar& y) {
-   return paddle::prim::subtract<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
+   if (FLAGS_enable_new_ir_api) {
+     return paddle::primitive::backend::subtract<LazyTensor>(x, paddle::primitive::backend::full<LazyTensor>(x.shape(), y, x.dtype(), x.place()));
+   } else {
+     return paddle::prim::subtract<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
+   }
}

Tensor StaticTensorOperants::multiply(const Tensor& x, const Scalar& y) {
-   return paddle::prim::scale<DescTensor>(x, y, 0.0f, true);
+   if (FLAGS_enable_new_ir_api) {
+     return paddle::primitive::backend::scale<LazyTensor>(x, y, 0.0f, true);
+   } else {
+     return paddle::prim::scale<DescTensor>(x, y, 0.0f, true);
+   }
}

Tensor StaticTensorOperants::divide(const Tensor& x, const Scalar& y) {
-   return paddle::prim::divide<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
+   if (FLAGS_enable_new_ir_api) {
+     return paddle::primitive::backend::divide<LazyTensor>(x, paddle::primitive::backend::full<LazyTensor>(x.shape(), y, x.dtype(), x.place()));
+   } else {
+     return paddle::prim::divide<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
+   }
}

Tensor StaticTensorOperants::add(const Scalar& x, const Tensor& y) {
-   return paddle::prim::add<DescTensor>(paddle::prim::full<DescTensor>(y.shape(), x, y.dtype(), y.place()), y);
+   if (FLAGS_enable_new_ir_api) {
+     return paddle::primitive::backend::add<LazyTensor>(paddle::primitive::backend::full<LazyTensor>(y.shape(), x, y.dtype(), y.place()), y);
+   } else {
+     return paddle::prim::add<DescTensor>(paddle::prim::full<DescTensor>(y.shape(), x, y.dtype(), y.place()), y);
+   }
}

+
Tensor StaticTensorOperants::subtract(const Scalar& x, const Tensor& y) {
-   return paddle::prim::subtract<DescTensor>(paddle::prim::full<DescTensor>(y.shape(), x, y.dtype(), y.place()), y);
+   if (FLAGS_enable_new_ir_api) {
+     return paddle::primitive::backend::subtract<LazyTensor>(paddle::primitive::backend::full<LazyTensor>(y.shape(), x, y.dtype(), y.place()), y);
+   } else {
+     return paddle::prim::subtract<DescTensor>(paddle::prim::full<DescTensor>(y.shape(), x, y.dtype(), y.place()), y);
+   }
}

Tensor StaticTensorOperants::multiply(const Scalar& x, const Tensor& y) {
-   return paddle::prim::scale<DescTensor>(y, x, 0.0f, true);
+   if (FLAGS_enable_new_ir_api) {
+     return paddle::primitive::backend::scale<LazyTensor>(y, x, 0.0f, true);
+   } else {
+     return paddle::prim::scale<DescTensor>(y, x, 0.0f, true);
+   }
}

Tensor StaticTensorOperants::divide(const Scalar& x, const Tensor& y) {
-   return paddle::prim::divide<DescTensor>(paddle::prim::full<DescTensor>(y.shape(), x, y.dtype(), y.place()), y);
+   if (FLAGS_enable_new_ir_api) {
+     return paddle::primitive::backend::divide<LazyTensor>(paddle::primitive::backend::full<LazyTensor>(y.shape(), x, y.dtype(), y.place()), y);
+   } else {
+     return paddle::prim::divide<DescTensor>(paddle::prim::full<DescTensor>(y.shape(), x, y.dtype(), y.place()), y);
+   }
}

Tensor StaticTensorOperants::pow(const Tensor& x, const Tensor& y) {
-   return paddle::prim::elementwise_pow<DescTensor>(x, y);
+   if (FLAGS_enable_new_ir_api) {
+     return paddle::primitive::backend::elementwise_pow<LazyTensor>(x, y);
+   } else {
+     return paddle::prim::elementwise_pow<DescTensor>(x, y);
+   }
}

Tensor StaticTensorOperants::pow(const Tensor& x, const Scalar& y) {
-   return paddle::prim::elementwise_pow<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
+   if (FLAGS_enable_new_ir_api) {
+     return paddle::primitive::backend::elementwise_pow<LazyTensor>(x, paddle::primitive::backend::full<LazyTensor>(x.shape(), y, x.dtype(), x.place()));
+   } else {
+     return paddle::prim::elementwise_pow<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
+   }
}
-
"""

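Review note: every Tensor/Scalar operant above now branches at runtime on FLAGS_enable_new_ir_api, going through paddle::primitive::backend with LazyTensor when the new IR is enabled and through the legacy paddle::prim path with DescTensor otherwise; scalar multiply stays a single scale op in both branches rather than a full + multiply pair. A minimal sketch of how the new path might be exercised once this lands (the flag name comes from this diff; exposing it via set_flags and switching on prim lowering this way are assumptions, not part of the patch):

    import paddle
    from paddle.framework import core

    paddle.enable_static()
    # Assumption: the flag is exported and reachable through set_flags.
    paddle.set_flags({"FLAGS_enable_new_ir_api": True})
    # Assumption: prim/composite lowering is toggled with this helper,
    # so the StaticTensorOperants above actually get invoked.
    core._set_prim_all_enabled(True)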
@@ -339,13 +385,21 @@ def gene_eager_tensor_operants_implementation(self):

def gene_static_tensor_func_call(self):
    api_func_name = self.get_api_func_name()
-
+   backend_static_func_name = (
+       'paddle::primitive::backend::' + api_func_name + '<LazyTensor>'
+   )
    prim_static_func_name = (
        'paddle::prim::' + api_func_name + '<DescTensor>'
    )
-   prim_static_func_parameters = self.get_func_args()
+   static_func_parameters = self.get_func_args()
+
+   static_tensor_func_call = f"""if (FLAGS_enable_new_ir_api) {{
+     return {backend_static_func_name}({static_func_parameters});
+   }} else {{
+     return {prim_static_func_name}({static_func_parameters});
+   }}"""

-   return f"""return {prim_static_func_name}({prim_static_func_parameters});"""
+   return static_tensor_func_call

def gene_static_tensor_operants_implementation(self):
    api_code = ""
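For reference, for a generated binary API such as multiply with two Tensor inputs (the argument names below are illustrative; the real list comes from get_func_args()), the new gene_static_tensor_func_call renders roughly to:

    if (FLAGS_enable_new_ir_api) {
      return paddle::primitive::backend::multiply<LazyTensor>(x, y);
    } else {
      return paddle::prim::multiply<DescTensor>(x, y);
    }

which mirrors the hand-written Tensor/Scalar operants in the template string above.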