15
15
import argparse
16
16
import os
17
17
import re
18
+ import subprocess
18
19
19
20
import yaml
20
21
from op_gen import (
73
74
74
75
API_IMPL_TEMPLATE = """
75
76
{ret_type} {api_name}({args}){{
77
+ {inner_code}
78
+ }}
79
+
80
+ """
81
+
82
+ API_INNER_CODE_TEMPLATE = """
76
83
{check_data_type}
77
84
{handle_optional_inputs}
78
85
{in_combine}
79
86
{compute_op}
80
87
{handle_optional_outputs}
81
88
{out_split}
82
- {return_result}
83
- }}
89
+ {return_result}"""
90
+
91
+
92
+ OP_DISPATCH_TEMPLATE = """
93
+ if ({cond}) {{
94
+ {inner_code}
95
+ }}"""
96
+
97
+ OP_DISPATCH_ERROR_TEMPLATE = """
98
+ PADDLE_THROW(phi::errors::Unimplemented(
99
+ "The kernel of ({op_name}) for input Value is unimplemented, please check the type of input Value."));"""
84
100
85
- """
86
101
87
102
CHECK_DATA_TYPE_TEMPLATE = """
88
103
{function}({inputs}, "{op_name}");"""
@@ -231,9 +246,11 @@ def _is_optional_input(self, op_info, input_name):
231
246
return True
232
247
return False
233
248
234
- def _is_optional_output (self , op_info , op_name , output_name ):
235
- if op_name .endswith (('_grad' , '_grad_' )):
236
- return False
249
+ def _is_optional_output (self , op_info , output_name ):
250
+ op_names = op_info .op_phi_name
251
+ for name in op_names :
252
+ if name .endswith (('_grad' , '_grad_' )):
253
+ return False
237
254
inplace_map = op_info .inplace_map
238
255
input_optional_list = op_info .input_optional_list
239
256
input_name_list = op_info .input_name_list
@@ -307,7 +324,7 @@ def _gen_api_args(
307
324
)
308
325
return (inputs + ', ' + attrs ).strip (', ' )
309
326
310
- def _gen_ret_type (self , op_info , op_name ):
327
+ def _gen_ret_type (self , op_info ):
311
328
name_list = op_info .output_name_list
312
329
type_list = op_info .output_type_list
313
330
intermediate_list = op_info .output_intermediate_list
@@ -321,15 +338,15 @@ def _gen_ret_type(self, op_info, op_name):
321
338
):
322
339
if intermediate == 'true' :
323
340
continue
324
- if self ._is_optional_output (op_info , op_name , name ):
341
+ if self ._is_optional_output (op_info , name ):
325
342
ret .append (OPTIONAL_OUTPUT_TYPE_MAP [type ])
326
343
else :
327
344
ret .append (OUTPUT_TYPE_MAP [type ])
328
345
return 'std::tuple<{}>' .format (', ' .join (ret ))
329
346
elif output_num == 1 :
330
347
index = intermediate_list .index ('false' )
331
348
name = name_list [index ]
332
- if self ._is_optional_output (op_info , op_name , name ):
349
+ if self ._is_optional_output (op_info , name ):
333
350
return OPTIONAL_OUTPUT_TYPE_MAP [type_list [index ]]
334
351
else :
335
352
return OUTPUT_TYPE_MAP [type_list [index ]]
@@ -340,7 +357,7 @@ def _gen_one_declare(
340
357
self , op_info , op_name , is_mutable_attr , is_vector_mutable_attr
341
358
):
342
359
return API_DECLARE_TEMPLATE .format (
343
- ret_type = self ._gen_ret_type (op_info , op_name ),
360
+ ret_type = self ._gen_ret_type (op_info ),
344
361
api_name = op_name ,
345
362
args = self ._gen_api_args (
346
363
op_info , True , is_mutable_attr , is_vector_mutable_attr
@@ -403,7 +420,7 @@ def _gen_handle_optional_outputs(self, op_info, op_name):
403
420
):
404
421
if intermediate == 'true' :
405
422
continue
406
- if self ._is_optional_output (op_info , op_name , name ):
423
+ if self ._is_optional_output (op_info , name ):
407
424
if VECTOR_TYPE in type :
408
425
ret += OPTIONAL_VECTOR_OPRESULT_OUTPUT_TEMPLATE .format (
409
426
name = name ,
@@ -497,7 +514,7 @@ def _gen_compute_op(
497
514
op_inst_name ,
498
515
)
499
516
500
- def _gen_out_split_and_ret_list (self , op_info , op_name , op_inst_name ):
517
+ def _gen_out_split_and_ret_list (self , op_info , op_inst_name ):
501
518
name_list = op_info .output_name_list
502
519
type_list = op_info .output_type_list
503
520
intermediate_list = op_info .output_intermediate_list
@@ -516,7 +533,7 @@ def _gen_out_split_and_ret_list(self, op_info, op_name, op_inst_name):
516
533
):
517
534
if intermediate == 'true' :
518
535
continue
519
- if self ._is_optional_output (op_info , op_name , name ):
536
+ if self ._is_optional_output (op_info , name ):
520
537
ret_list .append (f'optional_{ name } ' )
521
538
elif VECTOR_TYPE in type :
522
539
split_op_name = f'{ name } _split_op'
@@ -648,35 +665,129 @@ def _gen_check_data_type(self, op_info, op_name):
648
665
def _gen_one_impl (
649
666
self , op_info , op_name , is_mutable_attr , is_vector_mutable_attr
650
667
):
651
- ret_type = self ._gen_ret_type (op_info , op_name )
652
- in_combine , in_combine_op_list = self ._gen_in_combine (
653
- op_info , is_mutable_attr , is_vector_mutable_attr
654
- )
655
- compute_op , op_inst_name = self ._gen_compute_op (
656
- op_info , op_name , in_combine_op_list , is_mutable_attr
657
- )
658
- if ret_type == 'void' :
659
- compute_op += f' (void){ op_inst_name } ;'
668
+ ret = ''
669
+ dispatch_kernel = None
670
+ if op_info .kernel_map and 'dispatch' in op_info .kernel_map :
671
+ dispatch_kernel = op_info .kernel_map ['dispatch' ]
672
+
673
+ if dispatch_kernel and len (dispatch_kernel .keys ()) > 1 :
674
+ api_inner_code = ''
675
+ for kernel_name in dispatch_kernel .keys ():
676
+ dispatch_input_type = dispatch_kernel [kernel_name ][0 ]
677
+ input_name = op_info .input_name_list
678
+ input_optional = op_info .input_optional_list
679
+ cond_list = []
680
+ for i , type in enumerate (dispatch_input_type ):
681
+ name = input_name [i ]
682
+ optional = input_optional [i ]
683
+ if type == 'dense' :
684
+ if optional == 'true' :
685
+ cond_list .append (
686
+ f'(!{ name } || { name } ->type().isa<paddle::dialect::DenseTensorType>())'
687
+ )
688
+ else :
689
+ cond_list .append (
690
+ f'{ name } .type().isa<paddle::dialect::DenseTensorType>()'
691
+ )
692
+ elif type == 'selected_rows' :
693
+ if optional == 'true' :
694
+ cond_list .append (
695
+ f'(!{ name } || { name } ->type().isa<paddle::dialect::SelectedRowsType>())'
696
+ )
697
+ else :
698
+ cond_list .append (
699
+ f'{ name } .type().isa<paddle::dialect::SelectedRowsType>()'
700
+ )
701
+
702
+ ret_type = self ._gen_ret_type (op_info )
703
+ in_combine , in_combine_op_list = self ._gen_in_combine (
704
+ op_info , is_mutable_attr , is_vector_mutable_attr
705
+ )
660
706
661
- out_split , ret_list = self ._gen_out_split_and_ret_list (
662
- op_info , op_name , op_inst_name
663
- )
664
- ret = API_IMPL_TEMPLATE .format (
665
- check_data_type = self ._gen_check_data_type (op_info , op_name ),
666
- ret_type = ret_type ,
667
- api_name = op_name ,
668
- args = self ._gen_api_args (
669
- op_info , False , is_mutable_attr , is_vector_mutable_attr
670
- ),
671
- handle_optional_inputs = self ._gen_handle_optional_inputs (op_info ),
672
- in_combine = in_combine ,
673
- compute_op = compute_op ,
674
- handle_optional_outputs = self ._gen_handle_optional_outputs (
675
- op_info , op_name
676
- ),
677
- out_split = out_split ,
678
- return_result = self ._gen_return_result (ret_list ),
679
- )
707
+ if op_name .endswith ('_' ) and not kernel_name .endswith ('_' ):
708
+ kernel_name = kernel_name + '_'
709
+ compute_op , op_inst_name = self ._gen_compute_op (
710
+ op_info , kernel_name , in_combine_op_list , is_mutable_attr
711
+ )
712
+ if ret_type == 'void' :
713
+ compute_op += f' (void){ op_inst_name } ;'
714
+
715
+ out_split , ret_list = self ._gen_out_split_and_ret_list (
716
+ op_info , op_inst_name
717
+ )
718
+
719
+ if_inner_code = API_INNER_CODE_TEMPLATE .format (
720
+ check_data_type = self ._gen_check_data_type (
721
+ op_info , kernel_name
722
+ ),
723
+ handle_optional_inputs = self ._gen_handle_optional_inputs (
724
+ op_info
725
+ ),
726
+ in_combine = in_combine ,
727
+ compute_op = compute_op ,
728
+ handle_optional_outputs = self ._gen_handle_optional_outputs (
729
+ op_info , kernel_name
730
+ ),
731
+ out_split = out_split ,
732
+ return_result = self ._gen_return_result (ret_list ),
733
+ )
734
+
735
+ if_inner_code = if_inner_code .split ('\n ' )
736
+ if_inner_code = '\n ' .join (
737
+ [' ' + code for code in if_inner_code ]
738
+ )
739
+
740
+ api_inner_code += OP_DISPATCH_TEMPLATE .format (
741
+ cond = ' && ' .join (cond_list ), inner_code = if_inner_code
742
+ )
743
+
744
+ api_inner_code += OP_DISPATCH_ERROR_TEMPLATE .format (op_name = op_name )
745
+ ret = API_IMPL_TEMPLATE .format (
746
+ ret_type = ret_type ,
747
+ api_name = op_name ,
748
+ args = self ._gen_api_args (
749
+ op_info , False , is_mutable_attr , is_vector_mutable_attr
750
+ ),
751
+ inner_code = api_inner_code ,
752
+ )
753
+
754
+ else :
755
+ ret_type = self ._gen_ret_type (op_info )
756
+ in_combine , in_combine_op_list = self ._gen_in_combine (
757
+ op_info , is_mutable_attr , is_vector_mutable_attr
758
+ )
759
+ compute_op , op_inst_name = self ._gen_compute_op (
760
+ op_info , op_name , in_combine_op_list , is_mutable_attr
761
+ )
762
+ if ret_type == 'void' :
763
+ compute_op += f' (void){ op_inst_name } ;'
764
+
765
+ out_split , ret_list = self ._gen_out_split_and_ret_list (
766
+ op_info , op_inst_name
767
+ )
768
+
769
+ api_inner_code = API_INNER_CODE_TEMPLATE .format (
770
+ check_data_type = self ._gen_check_data_type (op_info , op_name ),
771
+ handle_optional_inputs = self ._gen_handle_optional_inputs (
772
+ op_info
773
+ ),
774
+ in_combine = in_combine ,
775
+ compute_op = compute_op ,
776
+ handle_optional_outputs = self ._gen_handle_optional_outputs (
777
+ op_info , op_name
778
+ ),
779
+ out_split = out_split ,
780
+ return_result = self ._gen_return_result (ret_list ),
781
+ )
782
+
783
+ ret = API_IMPL_TEMPLATE .format (
784
+ ret_type = ret_type ,
785
+ api_name = op_name ,
786
+ args = self ._gen_api_args (
787
+ op_info , False , is_mutable_attr , is_vector_mutable_attr
788
+ ),
789
+ inner_code = api_inner_code ,
790
+ )
680
791
681
792
ret = re .sub (r' +\n' , "" , ret )
682
793
return ret
@@ -722,6 +833,11 @@ def gen_h_and_cpp_file(
722
833
723
834
self ._gen_h_file (op_info_items , namespaces , h_file_path )
724
835
self ._gen_cpp_file (op_info_items , namespaces , cpp_file_path )
836
+ try :
837
+ subprocess .run (['clang-format' , '-i' , h_file_path ])
838
+ subprocess .run (['clang-format' , '-i' , cpp_file_path ])
839
+ except Exception as e :
840
+ pass
725
841
726
842
727
843
def ParseArguments ():
0 commit comments