Skip to content

Commit 3c4cd69

Browse files
authored
1 parent 76b9c05 commit 3c4cd69

File tree

2 files changed

+70
-10
lines changed

2 files changed

+70
-10
lines changed

paddle/fluid/pir/dialect/CMakeLists.txt

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,9 @@ set(op_src_files_tmp
7979

8080
set(op_vjp_src_file_tmp ${op_vjp_source_file_tmp})
8181

82+
set(op_cc_split_num 4)
83+
set(bwd_op_cc_split_num 2)
84+
8285
# Auto code gen
8386
execute_process(
8487
COMMAND ${PYTHON_EXECUTABLE} ${op_parse_file} --op_yaml_path
@@ -95,15 +98,22 @@ execute_process(
9598
--op_compat_yaml_file ${op_compat_yaml_file} --namespaces ${op_namespace}
9699
--dialect_name ${dialect_name} --op_def_h_file ${op_header_file_tmp}
97100
--op_info_file ${op_info_file_tmp} --op_def_cc_file ${op_src_files_tmp}
98-
--op_vjp_cc_file ${op_vjp_src_file_tmp} --with_distributed
99-
${WITH_DISTRIBUTE})
101+
--op_vjp_cc_file ${op_vjp_src_file_tmp} --op_cc_split_num
102+
${op_cc_split_num} --bwd_op_cc_split_num ${bwd_op_cc_split_num}
103+
--with_distributed ${WITH_DISTRIBUTE})
104+
105+
set(split_op_source_files
106+
${PIR_DIALECT_BINARY_DIR}/pd_op1.cc ${PIR_DIALECT_BINARY_DIR}/pd_op2.cc
107+
${PIR_DIALECT_BINARY_DIR}/pd_op3.cc ${PIR_DIALECT_BINARY_DIR}/pd_op4.cc)
108+
set(split_bwd_op_source_files ${PIR_DIALECT_BINARY_DIR}/pd_op_bwd1.cc
109+
${PIR_DIALECT_BINARY_DIR}/pd_op_bwd2.cc)
100110

101111
set(generated_files_pd_op
102112
"${op_header_file}"
103113
"${op_info_file}"
104-
"${op_source_file}"
114+
"${split_op_source_files}"
115+
"${split_bwd_op_source_files}"
105116
"${op_vjp_source_file}"
106-
"${bwd_op_source_file}"
107117
"${fused_op_source_file}"
108118
"${bwd_fused_op_source_file}"
109119
"${pir_op_source_file}"
@@ -247,8 +257,8 @@ set(op_dialect_srcs
247257
${CMAKE_CURRENT_SOURCE_DIR}/operator/ir/op_attribute.cc
248258
${CMAKE_CURRENT_SOURCE_DIR}/operator/ir/op_type.cc
249259
${op_info_file}
250-
${op_source_file}
251-
${bwd_op_source_file}
260+
${split_op_source_files}
261+
${split_bwd_op_source_files}
252262
${fused_op_source_file}
253263
${bwd_fused_op_source_file}
254264
${pir_op_source_file}

paddle/fluid/pir/dialect/op_generator/op_gen.py

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import argparse
1616
import logging
17+
import math
1718
import os
1819
import pathlib
1920
import sys
@@ -1130,6 +1131,21 @@ def get_mutable_attribute_grad_semantic(op_info, op_info_items):
11301131
return mutable_attribute_grad_semantics
11311132

11321133

1134+
def split_ops(op_info_items: dict, cc_file, split_nums):
1135+
op_list = list(op_info_items.keys())
1136+
ops_max_size = math.ceil(len(op_list) / split_nums)
1137+
split_op_info_items = []
1138+
for i in range(split_nums):
1139+
split_op_info_items.append({})
1140+
for i, op_name in enumerate(op_list):
1141+
list_idx = math.ceil((i + 1) / ops_max_size) - 1
1142+
split_op_info_items[list_idx][op_name] = op_info_items[op_name]
1143+
split_cc_files = []
1144+
for i in range(split_nums):
1145+
split_cc_files.append(cc_file.replace(".cc", f"{i + 1}.cc"))
1146+
return split_op_info_items, split_cc_files
1147+
1148+
11331149
def GenOneDnnExtraAttrsDefaultValue(onednn_extra_args):
11341150
INTARRAY_STR_TEMPLATE = """ pir::Attribute attr_{attr_name} = {op_attribute_type}::get(pir::IrContext::Instance(), phi::IntArray({attr}));
11351151
"""
@@ -2080,6 +2096,8 @@ def OpGenerator(
20802096
op_info_file,
20812097
op_def_cc_file,
20822098
op_vjp_cc_file,
2099+
op_cc_split_num,
2100+
bwd_op_cc_split_num,
20832101
onednn_yaml_file,
20842102
ops_onednn_extra_yaml_file,
20852103
):
@@ -2126,9 +2144,11 @@ def OpGenerator(
21262144

21272145
op_infos = []
21282146
all_op_info_items = {}
2147+
new_op_def_cc_file = []
21292148
first_file = True
21302149
onednn_only_op_list = []
2131-
for yaml_file in op_yaml_files:
2150+
for idx in range(len(op_yaml_files)):
2151+
yaml_file = op_yaml_files[idx]
21322152
op_yaml_items = []
21332153
with open(yaml_file, "r") as f:
21342154
ops = yaml.safe_load(f)
@@ -2194,13 +2214,37 @@ def OpGenerator(
21942214
key_suffix = '_sp' if item.is_sparse_op else ''
21952215
op_info_items[op['name'] + key_suffix] = item
21962216
all_op_info_items[op['name'] + key_suffix] = item
2197-
op_infos.append(op_info_items)
2217+
2218+
if dialect_name != "onednn_op":
2219+
cc_file = op_def_cc_file[idx]
2220+
if (
2221+
yaml_file.split('/')[-1] == "ops.parsed.yaml"
2222+
and op_cc_split_num is not None
2223+
):
2224+
split_op_info_items, split_cc_files = split_ops(
2225+
op_info_items, cc_file, op_cc_split_num
2226+
)
2227+
op_infos.extend(split_op_info_items)
2228+
new_op_def_cc_file.extend(split_cc_files)
2229+
elif (
2230+
yaml_file.split('/')[-1] == "backward.parsed.yaml"
2231+
and bwd_op_cc_split_num is not None
2232+
):
2233+
split_op_info_items, split_cc_files = split_ops(
2234+
op_info_items, cc_file, bwd_op_cc_split_num
2235+
)
2236+
op_infos.extend(split_op_info_items)
2237+
new_op_def_cc_file.extend(split_cc_files)
2238+
else:
2239+
op_infos.append(op_info_items)
2240+
new_op_def_cc_file.append(cc_file)
21982241

21992242
if first_file:
22002243
first_file = False
22012244

22022245
if dialect_name == "onednn_op":
22032246
op_infos = [all_op_info_items]
2247+
new_op_def_cc_file = op_def_cc_file
22042248
# (3) auto code gen
22052249
op_list_strs = []
22062250
declare_type_id_strs = []
@@ -2329,7 +2373,7 @@ def OpGenerator(
23292373
f.write(op_info_str)
23302374

23312375
# (6) write to files for xx_op.cc.tmp
2332-
for id in range(len(op_def_cc_file)):
2376+
for id in range(len(new_op_def_cc_file)):
23332377
source_file_str = source_file_strs[id]
23342378
for name in reversed(namespaces):
23352379
source_file_str = NAMESPACE_GARD_TEMPLATE.format(
@@ -2349,7 +2393,7 @@ def OpGenerator(
23492393
input=source_file_str,
23502394
define_type_id=define_type_id_strs[id],
23512395
)
2352-
with open(op_def_cc_file[id], 'w') as f:
2396+
with open(new_op_def_cc_file[id], 'w') as f:
23532397
f.write(source_file_str)
23542398

23552399
# (6) write to files for xx_vjp_op.cc.tmp
@@ -2381,6 +2425,8 @@ def ParseArguments():
23812425
parser.add_argument('--op_info_file', type=str)
23822426
parser.add_argument('--op_def_cc_file', type=str)
23832427
parser.add_argument('--op_vjp_cc_file', type=str)
2428+
parser.add_argument('--op_cc_split_num', type=int)
2429+
parser.add_argument('--bwd_op_cc_split_num', type=int)
23842430
parser.add_argument('--onednn_yaml_file', type=str)
23852431
parser.add_argument('--ops_onednn_extra_yaml_file', type=str)
23862432
parser.add_argument('--with_distributed', type=strtobool)
@@ -2403,6 +2449,8 @@ def ParseArguments():
24032449
op_info_file = args.op_info_file
24042450
op_def_cc_files = args.op_def_cc_file.split(",")
24052451
op_vjp_cc_file = args.op_vjp_cc_file
2452+
op_cc_split_num = args.op_cc_split_num
2453+
bwd_op_cc_split_num = args.bwd_op_cc_split_num
24062454
onednn_yaml_file = args.onednn_yaml_file
24072455
ops_onednn_extra_yaml_file = args.ops_onednn_extra_yaml_file
24082456

@@ -2417,6 +2465,8 @@ def ParseArguments():
24172465
op_info_file,
24182466
op_def_cc_files,
24192467
op_vjp_cc_file,
2468+
op_cc_split_num,
2469+
bwd_op_cc_split_num,
24202470
onednn_yaml_file,
24212471
ops_onednn_extra_yaml_file,
24222472
)

0 commit comments

Comments
 (0)