Skip to content

[PIR]Choose op by value type in PIR apis #59605

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 5, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 157 additions & 41 deletions paddle/fluid/pir/dialect/op_generator/api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import argparse
import os
import re
import subprocess

import yaml
from op_gen import (
Expand Down Expand Up @@ -73,16 +74,30 @@

API_IMPL_TEMPLATE = """
{ret_type} {api_name}({args}){{
{inner_code}
}}

"""

API_INNER_CODE_TEMPLATE = """
{check_data_type}
{handle_optional_inputs}
{in_combine}
{compute_op}
{handle_optional_outputs}
{out_split}
{return_result}
}}
{return_result}"""


OP_DISPATCH_TEMPLATE = """
if ({cond}) {{
{inner_code}
}}"""

OP_DISPATCH_ERROR_TEMPLATE = """
PADDLE_THROW(phi::errors::Unimplemented(
"The kernel of ({op_name}) for input Value is unimplemented, please check the type of input Value."));"""

"""

CHECK_DATA_TYPE_TEMPLATE = """
{function}({inputs}, "{op_name}");"""
Expand Down Expand Up @@ -231,9 +246,11 @@ def _is_optional_input(self, op_info, input_name):
return True
return False

def _is_optional_output(self, op_info, op_name, output_name):
if op_name.endswith(('_grad', '_grad_')):
return False
def _is_optional_output(self, op_info, output_name):
op_names = op_info.op_phi_name
for name in op_names:
if name.endswith(('_grad', '_grad_')):
return False
inplace_map = op_info.inplace_map
input_optional_list = op_info.input_optional_list
input_name_list = op_info.input_name_list
Expand Down Expand Up @@ -307,7 +324,7 @@ def _gen_api_args(
)
return (inputs + ', ' + attrs).strip(', ')

def _gen_ret_type(self, op_info, op_name):
def _gen_ret_type(self, op_info):
name_list = op_info.output_name_list
type_list = op_info.output_type_list
intermediate_list = op_info.output_intermediate_list
Expand All @@ -321,15 +338,15 @@ def _gen_ret_type(self, op_info, op_name):
):
if intermediate == 'true':
continue
if self._is_optional_output(op_info, op_name, name):
if self._is_optional_output(op_info, name):
ret.append(OPTIONAL_OUTPUT_TYPE_MAP[type])
else:
ret.append(OUTPUT_TYPE_MAP[type])
return 'std::tuple<{}>'.format(', '.join(ret))
elif output_num == 1:
index = intermediate_list.index('false')
name = name_list[index]
if self._is_optional_output(op_info, op_name, name):
if self._is_optional_output(op_info, name):
return OPTIONAL_OUTPUT_TYPE_MAP[type_list[index]]
else:
return OUTPUT_TYPE_MAP[type_list[index]]
Expand All @@ -340,7 +357,7 @@ def _gen_one_declare(
self, op_info, op_name, is_mutable_attr, is_vector_mutable_attr
):
return API_DECLARE_TEMPLATE.format(
ret_type=self._gen_ret_type(op_info, op_name),
ret_type=self._gen_ret_type(op_info),
api_name=op_name,
args=self._gen_api_args(
op_info, True, is_mutable_attr, is_vector_mutable_attr
Expand Down Expand Up @@ -403,7 +420,7 @@ def _gen_handle_optional_outputs(self, op_info, op_name):
):
if intermediate == 'true':
continue
if self._is_optional_output(op_info, op_name, name):
if self._is_optional_output(op_info, name):
if VECTOR_TYPE in type:
ret += OPTIONAL_VECTOR_OPRESULT_OUTPUT_TEMPLATE.format(
name=name,
Expand Down Expand Up @@ -497,7 +514,7 @@ def _gen_compute_op(
op_inst_name,
)

def _gen_out_split_and_ret_list(self, op_info, op_name, op_inst_name):
def _gen_out_split_and_ret_list(self, op_info, op_inst_name):
name_list = op_info.output_name_list
type_list = op_info.output_type_list
intermediate_list = op_info.output_intermediate_list
Expand All @@ -516,7 +533,7 @@ def _gen_out_split_and_ret_list(self, op_info, op_name, op_inst_name):
):
if intermediate == 'true':
continue
if self._is_optional_output(op_info, op_name, name):
if self._is_optional_output(op_info, name):
ret_list.append(f'optional_{name}')
elif VECTOR_TYPE in type:
split_op_name = f'{name}_split_op'
Expand Down Expand Up @@ -648,35 +665,129 @@ def _gen_check_data_type(self, op_info, op_name):
def _gen_one_impl(
self, op_info, op_name, is_mutable_attr, is_vector_mutable_attr
):
ret_type = self._gen_ret_type(op_info, op_name)
in_combine, in_combine_op_list = self._gen_in_combine(
op_info, is_mutable_attr, is_vector_mutable_attr
)
compute_op, op_inst_name = self._gen_compute_op(
op_info, op_name, in_combine_op_list, is_mutable_attr
)
if ret_type == 'void':
compute_op += f' (void){op_inst_name};'
ret = ''
dispatch_kernel = None
if op_info.kernel_map and 'dispatch' in op_info.kernel_map:
dispatch_kernel = op_info.kernel_map['dispatch']

if dispatch_kernel and len(dispatch_kernel.keys()) > 1:
api_inner_code = ''
for kernel_name in dispatch_kernel.keys():
dispatch_input_type = dispatch_kernel[kernel_name][0]
input_name = op_info.input_name_list
input_optional = op_info.input_optional_list
cond_list = []
for i, type in enumerate(dispatch_input_type):
name = input_name[i]
optional = input_optional[i]
if type == 'dense':
if optional == 'true':
cond_list.append(
f'(!{name} || {name}->type().isa<paddle::dialect::DenseTensorType>())'
)
else:
cond_list.append(
f'{name}.type().isa<paddle::dialect::DenseTensorType>()'
)
elif type == 'selected_rows':
if optional == 'true':
cond_list.append(
f'(!{name} || {name}->type().isa<paddle::dialect::SelectedRowsType>())'
)
else:
cond_list.append(
f'{name}.type().isa<paddle::dialect::SelectedRowsType>()'
)

ret_type = self._gen_ret_type(op_info)
in_combine, in_combine_op_list = self._gen_in_combine(
op_info, is_mutable_attr, is_vector_mutable_attr
)

out_split, ret_list = self._gen_out_split_and_ret_list(
op_info, op_name, op_inst_name
)
ret = API_IMPL_TEMPLATE.format(
check_data_type=self._gen_check_data_type(op_info, op_name),
ret_type=ret_type,
api_name=op_name,
args=self._gen_api_args(
op_info, False, is_mutable_attr, is_vector_mutable_attr
),
handle_optional_inputs=self._gen_handle_optional_inputs(op_info),
in_combine=in_combine,
compute_op=compute_op,
handle_optional_outputs=self._gen_handle_optional_outputs(
op_info, op_name
),
out_split=out_split,
return_result=self._gen_return_result(ret_list),
)
if op_name.endswith('_') and not kernel_name.endswith('_'):
kernel_name = kernel_name + '_'
compute_op, op_inst_name = self._gen_compute_op(
op_info, kernel_name, in_combine_op_list, is_mutable_attr
)
if ret_type == 'void':
compute_op += f' (void){op_inst_name};'

out_split, ret_list = self._gen_out_split_and_ret_list(
op_info, op_inst_name
)

if_inner_code = API_INNER_CODE_TEMPLATE.format(
check_data_type=self._gen_check_data_type(
op_info, kernel_name
),
handle_optional_inputs=self._gen_handle_optional_inputs(
op_info
),
in_combine=in_combine,
compute_op=compute_op,
handle_optional_outputs=self._gen_handle_optional_outputs(
op_info, kernel_name
),
out_split=out_split,
return_result=self._gen_return_result(ret_list),
)

if_inner_code = if_inner_code.split('\n')
if_inner_code = '\n'.join(
[' ' + code for code in if_inner_code]
)

api_inner_code += OP_DISPATCH_TEMPLATE.format(
cond=' && '.join(cond_list), inner_code=if_inner_code
)

api_inner_code += OP_DISPATCH_ERROR_TEMPLATE.format(op_name=op_name)
ret = API_IMPL_TEMPLATE.format(
ret_type=ret_type,
api_name=op_name,
args=self._gen_api_args(
op_info, False, is_mutable_attr, is_vector_mutable_attr
),
inner_code=api_inner_code,
)

else:
ret_type = self._gen_ret_type(op_info)
in_combine, in_combine_op_list = self._gen_in_combine(
op_info, is_mutable_attr, is_vector_mutable_attr
)
compute_op, op_inst_name = self._gen_compute_op(
op_info, op_name, in_combine_op_list, is_mutable_attr
)
if ret_type == 'void':
compute_op += f' (void){op_inst_name};'

out_split, ret_list = self._gen_out_split_and_ret_list(
op_info, op_inst_name
)

api_inner_code = API_INNER_CODE_TEMPLATE.format(
check_data_type=self._gen_check_data_type(op_info, op_name),
handle_optional_inputs=self._gen_handle_optional_inputs(
op_info
),
in_combine=in_combine,
compute_op=compute_op,
handle_optional_outputs=self._gen_handle_optional_outputs(
op_info, op_name
),
out_split=out_split,
return_result=self._gen_return_result(ret_list),
)

ret = API_IMPL_TEMPLATE.format(
ret_type=ret_type,
api_name=op_name,
args=self._gen_api_args(
op_info, False, is_mutable_attr, is_vector_mutable_attr
),
inner_code=api_inner_code,
)

ret = re.sub(r' +\n', "", ret)
return ret
Expand Down Expand Up @@ -722,6 +833,11 @@ def gen_h_and_cpp_file(

self._gen_h_file(op_info_items, namespaces, h_file_path)
self._gen_cpp_file(op_info_items, namespaces, cpp_file_path)
try:
subprocess.run(['clang-format', '-i', h_file_path])
subprocess.run(['clang-format', '-i', cpp_file_path])
except Exception as e:
pass


def ParseArguments():
Expand Down