@@ -58,6 +58,43 @@ class TensorOperantsBase {
58
58
59
59
"""
60
60
61
+ tensor_api_source_include = """// Generated by paddle/phi/api/yaml/generator/tensor_gen.py
62
+
63
+ #include "paddle/phi/api/include/tensor.h"
64
+
65
+ #include "paddle/phi/api/include/operants_manager.h"
66
+
67
+ """
68
+
69
+ tensor_api_source_start = """
70
+ namespace paddle {
71
+
72
+ namespace experimental {
73
+
74
+ Tensor Tensor::operator+(const Tensor &other) const {
75
+ return add(other);
76
+ }
77
+
78
+ Tensor Tensor::operator-(const Tensor &other) const {
79
+ return subtract(other);
80
+ }
81
+
82
+ Tensor Tensor::operator*(const Tensor &other) const {
83
+ return multiply(other);
84
+ }
85
+
86
+ Tensor Tensor::operator/(const Tensor &other) const {
87
+ return divide(other);
88
+ }
89
+ """
90
+
91
+
92
+ tensor_api_source_end = """
93
+ } // namespace experimental
94
+ } // namespace paddle
95
+
96
+ """
97
+
61
98
62
99
operants_header_include = """// Generated by paddle/phi/api/yaml/generator/tensor_gen.py
63
100
@@ -231,6 +268,40 @@ def gene_operants_base(self):
231
268
else :
232
269
return f"""
233
270
{ indent } virtual { self .get_return_type (inplace_flag = True )} { api_func_name } ({ self .get_declare_args (inplace_flag = True )} ) = 0;
271
+ """
272
+
273
+ def get_define_args_without_first_tensor (self , inplace_flag = False ):
274
+ # NOTE(HongyuJia): consider vector<Tensor> becomes first input argument.
275
+ define_args = self .get_input_tensor_args (inplace_flag )
276
+ assert (
277
+ len (define_args ) > 1
278
+ ), "Can't use tensor api without Tensor inputs"
279
+ for name in self .attrs ['names' ]:
280
+ define_args .append (self .attrs ['attr_info' ][name ][0 ] + ' ' + name )
281
+ # remove first Tensor argument
282
+ return ", " .join (define_args [1 :])
283
+
284
+ def gene_tensor_api_implementation (self ):
285
+ func_name = self .get_api_func_name ()
286
+ assert (
287
+ len (self .inputs ['names' ]) > 1
288
+ ), "Can't use tensor api without Tensor inputs"
289
+ # remove first Tensor argument
290
+ func_args = self .inputs ['names' ][1 :] + self .attrs ['names' ]
291
+ func_args_code = ", " .join (func_args )
292
+ # func decalaration
293
+ if func_name [- 1 ] != '_' :
294
+ return f"""
295
+ { self .get_return_type ()} Tensor::{ func_name } ({ self .get_define_args_without_first_tensor ()} ) const {{
296
+ { indent } return paddle::OperantsManager::Instance().{ func_name } (static_cast<const Tensor &>(*this), { func_args_code } );
297
+ }}
298
+ """
299
+ else :
300
+ return f"""
301
+ { self .get_return_type (inplace_flag = True )} Tensor::{ func_name } ({ self .get_define_args_without_first_tensor (inplace_flag = True )} ) const {{
302
+ { indent } return paddle::OperantsManager::Instance().{ func_name } (static_cast<const Tensor &>(*this), { func_args_code } );
303
+ }}
304
+
234
305
"""
235
306
236
307
def gene_operants_declaration (self ):
@@ -318,6 +389,7 @@ def gene_operants_manager_implementation(self):
318
389
def generate_tensor_operants_api (
319
390
api_yaml_path ,
320
391
operants_base_path ,
392
+ tensor_api_source_path ,
321
393
operants_header_path ,
322
394
operants_source_path ,
323
395
operants_manager_header_path ,
@@ -332,13 +404,16 @@ def generate_tensor_operants_api(
332
404
apis .extend (api_list )
333
405
334
406
operants_base_file = open (operants_base_path , 'w' )
407
+ tensor_api_source_file = open (tensor_api_source_path , 'w' )
335
408
operants_header_file = open (operants_header_path , 'w' )
336
409
operants_source_file = open (operants_source_path , 'w' )
337
410
operants_manager_header_file = open (operants_manager_header_path , 'w' )
338
411
operants_manager_source_file = open (operants_manager_source_path , 'w' )
339
412
340
413
operants_base_file .write (operants_base_include )
341
414
operants_base_file .write (operants_base_start )
415
+ tensor_api_source_file .write (tensor_api_source_include )
416
+ tensor_api_source_file .write (tensor_api_source_start )
342
417
operants_header_file .write (operants_header_include )
343
418
operants_header_file .write (operants_header_start )
344
419
operants_source_file .write (operants_source_include )
@@ -355,6 +430,9 @@ def generate_tensor_operants_api(
355
430
operants_api = OperantsAPI (api , api_prims )
356
431
if operants_api .is_prim_api :
357
432
operants_base_file .write (operants_api .gene_operants_base ())
433
+ tensor_api_source_file .write (
434
+ operants_api .gene_tensor_api_implementation ()
435
+ )
358
436
operants_header_file .write (operants_api .gene_operants_declaration ())
359
437
operants_source_file .write (
360
438
operants_api .gene_operants_implementation ()
@@ -367,12 +445,14 @@ def generate_tensor_operants_api(
367
445
)
368
446
369
447
operants_base_file .write (operants_base_end )
448
+ tensor_api_source_file .write (tensor_api_source_end )
370
449
operants_header_file .write (operants_header_end )
371
450
operants_source_file .write (operants_source_end )
372
451
operants_manager_header_file .write (operants_manager_header_end )
373
452
operants_manager_source_file .write (operants_manager_source_end )
374
453
375
454
operants_base_file .close ()
455
+ tensor_api_source_file .close ()
376
456
operants_header_file .close ()
377
457
operants_source_file .close ()
378
458
operants_manager_header_file .close ()
@@ -396,6 +476,12 @@ def main():
396
476
default = 'paddle/phi/api/include/operants_base.h' ,
397
477
)
398
478
479
+ parser .add_argument (
480
+ '--tensor_api_source_path' ,
481
+ help = 'output of generated tensor_api source code file' ,
482
+ default = 'paddle/phi/api/lib/tensor_api.cc' ,
483
+ )
484
+
399
485
parser .add_argument (
400
486
'--phi_tensor_operants_header_path' ,
401
487
help = 'output of generated phi_tensor_operants header code file' ,
@@ -424,6 +510,7 @@ def main():
424
510
425
511
api_yaml_path = options .api_yaml_path
426
512
operants_base_path = options .operants_base_path
513
+ tensor_api_source_path = options .tensor_api_source_path
427
514
operants_header_path = options .phi_tensor_operants_header_path
428
515
operants_source_path = options .phi_tensor_operants_source_path
429
516
operants_manager_header_path = options .operants_manager_header_path
@@ -432,6 +519,7 @@ def main():
432
519
generate_tensor_operants_api (
433
520
api_yaml_path ,
434
521
operants_base_path ,
522
+ tensor_api_source_path ,
435
523
operants_header_path ,
436
524
operants_source_path ,
437
525
operants_manager_header_path ,
0 commit comments