Skip to content

Commit a1f3dd6

Browse files
authored
[PIR]Choose op by value type in PIR apis (#59605)
1 parent ed1a8ac commit a1f3dd6

File tree

1 file changed

+157
-41
lines changed

1 file changed

+157
-41
lines changed

paddle/fluid/pir/dialect/op_generator/api_gen.py

Lines changed: 157 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import argparse
1616
import os
1717
import re
18+
import subprocess
1819

1920
import yaml
2021
from op_gen import (
@@ -73,16 +74,30 @@
7374

7475
API_IMPL_TEMPLATE = """
7576
{ret_type} {api_name}({args}){{
77+
{inner_code}
78+
}}
79+
80+
"""
81+
82+
API_INNER_CODE_TEMPLATE = """
7683
{check_data_type}
7784
{handle_optional_inputs}
7885
{in_combine}
7986
{compute_op}
8087
{handle_optional_outputs}
8188
{out_split}
82-
{return_result}
83-
}}
89+
{return_result}"""
90+
91+
92+
OP_DISPATCH_TEMPLATE = """
93+
if ({cond}) {{
94+
{inner_code}
95+
}}"""
96+
97+
OP_DISPATCH_ERROR_TEMPLATE = """
98+
PADDLE_THROW(phi::errors::Unimplemented(
99+
"The kernel of ({op_name}) for input Value is unimplemented, please check the type of input Value."));"""
84100

85-
"""
86101

87102
CHECK_DATA_TYPE_TEMPLATE = """
88103
{function}({inputs}, "{op_name}");"""
@@ -231,9 +246,11 @@ def _is_optional_input(self, op_info, input_name):
231246
return True
232247
return False
233248

234-
def _is_optional_output(self, op_info, op_name, output_name):
235-
if op_name.endswith(('_grad', '_grad_')):
236-
return False
249+
def _is_optional_output(self, op_info, output_name):
250+
op_names = op_info.op_phi_name
251+
for name in op_names:
252+
if name.endswith(('_grad', '_grad_')):
253+
return False
237254
inplace_map = op_info.inplace_map
238255
input_optional_list = op_info.input_optional_list
239256
input_name_list = op_info.input_name_list
@@ -307,7 +324,7 @@ def _gen_api_args(
307324
)
308325
return (inputs + ', ' + attrs).strip(', ')
309326

310-
def _gen_ret_type(self, op_info, op_name):
327+
def _gen_ret_type(self, op_info):
311328
name_list = op_info.output_name_list
312329
type_list = op_info.output_type_list
313330
intermediate_list = op_info.output_intermediate_list
@@ -321,15 +338,15 @@ def _gen_ret_type(self, op_info, op_name):
321338
):
322339
if intermediate == 'true':
323340
continue
324-
if self._is_optional_output(op_info, op_name, name):
341+
if self._is_optional_output(op_info, name):
325342
ret.append(OPTIONAL_OUTPUT_TYPE_MAP[type])
326343
else:
327344
ret.append(OUTPUT_TYPE_MAP[type])
328345
return 'std::tuple<{}>'.format(', '.join(ret))
329346
elif output_num == 1:
330347
index = intermediate_list.index('false')
331348
name = name_list[index]
332-
if self._is_optional_output(op_info, op_name, name):
349+
if self._is_optional_output(op_info, name):
333350
return OPTIONAL_OUTPUT_TYPE_MAP[type_list[index]]
334351
else:
335352
return OUTPUT_TYPE_MAP[type_list[index]]
@@ -340,7 +357,7 @@ def _gen_one_declare(
340357
self, op_info, op_name, is_mutable_attr, is_vector_mutable_attr
341358
):
342359
return API_DECLARE_TEMPLATE.format(
343-
ret_type=self._gen_ret_type(op_info, op_name),
360+
ret_type=self._gen_ret_type(op_info),
344361
api_name=op_name,
345362
args=self._gen_api_args(
346363
op_info, True, is_mutable_attr, is_vector_mutable_attr
@@ -403,7 +420,7 @@ def _gen_handle_optional_outputs(self, op_info, op_name):
403420
):
404421
if intermediate == 'true':
405422
continue
406-
if self._is_optional_output(op_info, op_name, name):
423+
if self._is_optional_output(op_info, name):
407424
if VECTOR_TYPE in type:
408425
ret += OPTIONAL_VECTOR_OPRESULT_OUTPUT_TEMPLATE.format(
409426
name=name,
@@ -497,7 +514,7 @@ def _gen_compute_op(
497514
op_inst_name,
498515
)
499516

500-
def _gen_out_split_and_ret_list(self, op_info, op_name, op_inst_name):
517+
def _gen_out_split_and_ret_list(self, op_info, op_inst_name):
501518
name_list = op_info.output_name_list
502519
type_list = op_info.output_type_list
503520
intermediate_list = op_info.output_intermediate_list
@@ -516,7 +533,7 @@ def _gen_out_split_and_ret_list(self, op_info, op_name, op_inst_name):
516533
):
517534
if intermediate == 'true':
518535
continue
519-
if self._is_optional_output(op_info, op_name, name):
536+
if self._is_optional_output(op_info, name):
520537
ret_list.append(f'optional_{name}')
521538
elif VECTOR_TYPE in type:
522539
split_op_name = f'{name}_split_op'
@@ -648,35 +665,129 @@ def _gen_check_data_type(self, op_info, op_name):
648665
def _gen_one_impl(
649666
self, op_info, op_name, is_mutable_attr, is_vector_mutable_attr
650667
):
651-
ret_type = self._gen_ret_type(op_info, op_name)
652-
in_combine, in_combine_op_list = self._gen_in_combine(
653-
op_info, is_mutable_attr, is_vector_mutable_attr
654-
)
655-
compute_op, op_inst_name = self._gen_compute_op(
656-
op_info, op_name, in_combine_op_list, is_mutable_attr
657-
)
658-
if ret_type == 'void':
659-
compute_op += f' (void){op_inst_name};'
668+
ret = ''
669+
dispatch_kernel = None
670+
if op_info.kernel_map and 'dispatch' in op_info.kernel_map:
671+
dispatch_kernel = op_info.kernel_map['dispatch']
672+
673+
if dispatch_kernel and len(dispatch_kernel.keys()) > 1:
674+
api_inner_code = ''
675+
for kernel_name in dispatch_kernel.keys():
676+
dispatch_input_type = dispatch_kernel[kernel_name][0]
677+
input_name = op_info.input_name_list
678+
input_optional = op_info.input_optional_list
679+
cond_list = []
680+
for i, type in enumerate(dispatch_input_type):
681+
name = input_name[i]
682+
optional = input_optional[i]
683+
if type == 'dense':
684+
if optional == 'true':
685+
cond_list.append(
686+
f'(!{name} || {name}->type().isa<paddle::dialect::DenseTensorType>())'
687+
)
688+
else:
689+
cond_list.append(
690+
f'{name}.type().isa<paddle::dialect::DenseTensorType>()'
691+
)
692+
elif type == 'selected_rows':
693+
if optional == 'true':
694+
cond_list.append(
695+
f'(!{name} || {name}->type().isa<paddle::dialect::SelectedRowsType>())'
696+
)
697+
else:
698+
cond_list.append(
699+
f'{name}.type().isa<paddle::dialect::SelectedRowsType>()'
700+
)
701+
702+
ret_type = self._gen_ret_type(op_info)
703+
in_combine, in_combine_op_list = self._gen_in_combine(
704+
op_info, is_mutable_attr, is_vector_mutable_attr
705+
)
660706

661-
out_split, ret_list = self._gen_out_split_and_ret_list(
662-
op_info, op_name, op_inst_name
663-
)
664-
ret = API_IMPL_TEMPLATE.format(
665-
check_data_type=self._gen_check_data_type(op_info, op_name),
666-
ret_type=ret_type,
667-
api_name=op_name,
668-
args=self._gen_api_args(
669-
op_info, False, is_mutable_attr, is_vector_mutable_attr
670-
),
671-
handle_optional_inputs=self._gen_handle_optional_inputs(op_info),
672-
in_combine=in_combine,
673-
compute_op=compute_op,
674-
handle_optional_outputs=self._gen_handle_optional_outputs(
675-
op_info, op_name
676-
),
677-
out_split=out_split,
678-
return_result=self._gen_return_result(ret_list),
679-
)
707+
if op_name.endswith('_') and not kernel_name.endswith('_'):
708+
kernel_name = kernel_name + '_'
709+
compute_op, op_inst_name = self._gen_compute_op(
710+
op_info, kernel_name, in_combine_op_list, is_mutable_attr
711+
)
712+
if ret_type == 'void':
713+
compute_op += f' (void){op_inst_name};'
714+
715+
out_split, ret_list = self._gen_out_split_and_ret_list(
716+
op_info, op_inst_name
717+
)
718+
719+
if_inner_code = API_INNER_CODE_TEMPLATE.format(
720+
check_data_type=self._gen_check_data_type(
721+
op_info, kernel_name
722+
),
723+
handle_optional_inputs=self._gen_handle_optional_inputs(
724+
op_info
725+
),
726+
in_combine=in_combine,
727+
compute_op=compute_op,
728+
handle_optional_outputs=self._gen_handle_optional_outputs(
729+
op_info, kernel_name
730+
),
731+
out_split=out_split,
732+
return_result=self._gen_return_result(ret_list),
733+
)
734+
735+
if_inner_code = if_inner_code.split('\n')
736+
if_inner_code = '\n'.join(
737+
[' ' + code for code in if_inner_code]
738+
)
739+
740+
api_inner_code += OP_DISPATCH_TEMPLATE.format(
741+
cond=' && '.join(cond_list), inner_code=if_inner_code
742+
)
743+
744+
api_inner_code += OP_DISPATCH_ERROR_TEMPLATE.format(op_name=op_name)
745+
ret = API_IMPL_TEMPLATE.format(
746+
ret_type=ret_type,
747+
api_name=op_name,
748+
args=self._gen_api_args(
749+
op_info, False, is_mutable_attr, is_vector_mutable_attr
750+
),
751+
inner_code=api_inner_code,
752+
)
753+
754+
else:
755+
ret_type = self._gen_ret_type(op_info)
756+
in_combine, in_combine_op_list = self._gen_in_combine(
757+
op_info, is_mutable_attr, is_vector_mutable_attr
758+
)
759+
compute_op, op_inst_name = self._gen_compute_op(
760+
op_info, op_name, in_combine_op_list, is_mutable_attr
761+
)
762+
if ret_type == 'void':
763+
compute_op += f' (void){op_inst_name};'
764+
765+
out_split, ret_list = self._gen_out_split_and_ret_list(
766+
op_info, op_inst_name
767+
)
768+
769+
api_inner_code = API_INNER_CODE_TEMPLATE.format(
770+
check_data_type=self._gen_check_data_type(op_info, op_name),
771+
handle_optional_inputs=self._gen_handle_optional_inputs(
772+
op_info
773+
),
774+
in_combine=in_combine,
775+
compute_op=compute_op,
776+
handle_optional_outputs=self._gen_handle_optional_outputs(
777+
op_info, op_name
778+
),
779+
out_split=out_split,
780+
return_result=self._gen_return_result(ret_list),
781+
)
782+
783+
ret = API_IMPL_TEMPLATE.format(
784+
ret_type=ret_type,
785+
api_name=op_name,
786+
args=self._gen_api_args(
787+
op_info, False, is_mutable_attr, is_vector_mutable_attr
788+
),
789+
inner_code=api_inner_code,
790+
)
680791

681792
ret = re.sub(r' +\n', "", ret)
682793
return ret
@@ -722,6 +833,11 @@ def gen_h_and_cpp_file(
722833

723834
self._gen_h_file(op_info_items, namespaces, h_file_path)
724835
self._gen_cpp_file(op_info_items, namespaces, cpp_file_path)
836+
try:
837+
subprocess.run(['clang-format', '-i', h_file_path])
838+
subprocess.run(['clang-format', '-i', cpp_file_path])
839+
except Exception as e:
840+
pass
725841

726842

727843
def ParseArguments():

0 commit comments

Comments
 (0)