Skip to content

Commit 4652bee

Browse files
authored
Split generated_op.cc into 4 src files [generated_op(1-4).cc] (#50985)
* split generated_op.cc into 4 src files * fix bug * fix compile on windows
1 parent bb5dd20 commit 4652bee

File tree

4 files changed

+75
-17
lines changed

4 files changed

+75
-17
lines changed

.gitignore

+1-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ paddle/fluid/pybind/eager_op_function.cc
8181
tools/nvcc_lazy
8282

8383
# these files (directories) are generated before build system generation
84-
paddle/fluid/operators/generated_op.cc
84+
paddle/fluid/operators/generated_op*.cc
8585
paddle/fluid/operators/generated_sparse_op.cc
8686
paddle/fluid/operators/generated_static_op.cc
8787
paddle/phi/ops/compat/generated_*.cc

paddle/fluid/operators/CMakeLists.txt

+3-2
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,10 @@ endif()
9797

9898
set(OP_HEADER_DEPS ${OP_HEADER_DEPS} phi phi_utils backward_infermeta sparse_backward_infermeta static_prim_api)
9999

100-
register_operators(EXCLUDES py_func_op warpctc_op dgc_op load_combine_op lstm_op run_program_op eye_op quantize_linear_op
100+
register_operators(EXCLUDES py_func_op warpctc_op dgc_op generated_op1 generated_op2 generated_op3 generated_op4 load_combine_op lstm_op run_program_op eye_op quantize_linear_op
101101
recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS})
102102

103+
op_library(generated_op UNITY SRCS generated_op1.cc generated_op2.cc generated_op3.cc generated_op4.cc DEPS ${OP_HEADER_DEPS})
103104
op_library(run_program_op SRCS run_program_op.cc run_program_op.cu.cc run_program_op_npu.cc DEPS executor_cache ${OP_HEADER_DEPS})
104105
target_link_libraries(run_program_op cuda_graph_with_memory_pool)
105106
op_library(quantize_linear_op DEPS phi)
@@ -200,7 +201,7 @@ elseif(WITH_ROCM)
200201
else()
201202
cc_test(test_leaky_relu_grad_grad_functor SRCS test_leaky_relu_grad_grad_functor.cc DEPS tensor device_context eigen3)
202203
endif()
203-
cc_test(share_buffer_op_cpp_test SRCS share_buffer_op_test.cc DEPS lod_tensor device_context generated_op)
204+
cc_test(share_buffer_op_cpp_test SRCS share_buffer_op_test.cc DEPS lod_tensor device_context generated_static_op)
204205

205206
cc_library(tensor_formatter SRCS tensor_formatter.cc DEPS ${OP_HEADER_DEPS})
206207
if (WITH_PYTHON)

paddle/fluid/operators/generator/CMakeLists.txt

+26-6
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,14 @@ endif()
3030
# parse ops
3131
set(parsed_op_dir
3232
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops)
33-
set(generated_op_path
34-
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_op.cc)
33+
set(generated_op_path_1
34+
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_op1.cc)
35+
set(generated_op_path_2
36+
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_op2.cc)
37+
set(generated_op_path_3
38+
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_op3.cc)
39+
set(generated_op_path_4
40+
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_op4.cc)
3541
set(generated_static_op_path
3642
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_static_op.cc)
3743
set(generated_sparse_ops_path
@@ -118,7 +124,7 @@ endif()
118124

119125
# code generation for op, op makers, and argument mapping functions
120126
message(
121-
"create or remove auto-geneated operators: ${generated_op_path}.tmp
127+
"create or remove auto-geneated operators: generated_op(1-4).cc.tmp
122128
create or remove auto-geneated argument mappings: ${generated_argument_mapping_path}.tmp"
123129
)
124130
execute_process(
@@ -129,8 +135,9 @@ execute_process(
129135
./parsed_ops/backward_ops.parsed.yaml --op_version_yaml_path
130136
${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/op_version.yaml
131137
--op_compat_yaml_path ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml
132-
--output_op_path "${generated_op_path}.tmp" --output_arg_map_path
133-
"${generated_argument_mapping_path}.tmp"
138+
--output_op_path "${generated_op_path_1}.tmp" "${generated_op_path_2}.tmp"
139+
"${generated_op_path_3}.tmp" "${generated_op_path_4}.tmp"
140+
--output_arg_map_path "${generated_argument_mapping_path}.tmp"
134141
RESULT_VARIABLE _result)
135142
if(${_result})
136143
message(FATAL_ERROR "operator codegen failed, exiting.")
@@ -165,7 +172,10 @@ if(${_result})
165172
endif()
166173

167174
set(generated_static_files
168-
"${generated_op_path}"
175+
"${generated_op_path_1}"
176+
"${generated_op_path_2}"
177+
"${generated_op_path_3}"
178+
"${generated_op_path_4}"
169179
"${generated_static_op_path}"
170180
"${generated_sparse_ops_path}"
171181
"${generated_argument_mapping_path}"
@@ -192,6 +202,16 @@ foreach(generated_static_file ${generated_static_files})
192202
endif()
193203
endforeach()
194204

205+
# Note(zyfncg): generated_op.cc is no longer produced by the generator (its
# content is now written to generated_op1.cc ... generated_op4.cc), so remove
# any stale generated_op.cc / generated_op.cc.tmp left behind in an existing
# development checkout.
set(old_generated_op_path
    ${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_op.cc)
# file(REMOVE) runs at configure time without spawning a child process and
# does not report an error for paths that do not exist, so no EXISTS guard is
# needed here.
file(REMOVE "${old_generated_op_path}" "${old_generated_op_path}.tmp")
214+
195215
# op extra info file
196216
set(ops_extra_info_gen_file
197217
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generator/ops_extra_info_gen.py)

paddle/fluid/operators/generator/generate_op.py

+45-8
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import argparse
16+
import math
1617
import os
1718
from pathlib import Path
1819

@@ -478,6 +479,29 @@ def parse_drop_empty_grad(op_fluid_list: list, bw_op_dict: dict):
478479
] = False
479480

480481

482+
def split_ops_list(ops, backward_op_dict, split_num):
    """Partition forward ops, with their backward ops, into split_num chunks.

    Args:
        ops: list of forward-op dicts. An op may carry a 'backward' key whose
            value is looked up in ``backward_op_dict``.
        backward_op_dict: dict mapping a backward op name to its op dict. A
            backward op may itself carry a 'backward' key (higher-order
            gradients); the whole chain is followed.
        split_num: number of chunks to produce; must be >= 1.

    Returns:
        A tuple ``(new_ops_list, new_bw_ops_list)``. Both are lists of exactly
        ``split_num`` lists. ``new_ops_list[i]`` holds a contiguous slice of
        ``ops`` (original order preserved) and ``new_bw_ops_list[i]`` holds
        the backward ops reachable from the forward ops in that slice.

    Raises:
        ValueError: if ``split_num`` is smaller than 1.
    """
    if split_num < 1:
        raise ValueError(
            f"split_num must be a positive integer, but got {split_num}."
        )

    new_ops_list = []
    new_bw_ops_list = []
    start = 0
    # Size every chunk from what is still left to distribute. Using one fixed
    # ceil(len(ops) / split_num) size can produce fewer than split_num chunks
    # (e.g. 9 ops split 4 ways -> 3 chunks), while callers index the result
    # once per requested chunk.
    for chunks_left in range(split_num, 0, -1):
        chunk_size = math.ceil((len(ops) - start) / chunks_left)
        chunk_ops = list(ops[start : start + chunk_size])
        start += chunk_size

        chunk_bw_ops = []
        for op in chunk_ops:
            current_op = op
            # Follow grad -> double grad -> ... for as long as the next
            # backward op is known.
            while (
                'backward' in current_op
                and current_op['backward'] in backward_op_dict
            ):
                current_op = backward_op_dict[current_op['backward']]
                chunk_bw_ops.append(current_op)

        new_ops_list.append(chunk_ops)
        new_bw_ops_list.append(chunk_bw_ops)
    return new_ops_list, new_bw_ops_list
503+
504+
481505
def main(
482506
ops_yaml_path,
483507
backward_yaml_path,
@@ -548,13 +572,23 @@ def main(
548572
os.remove(output_arg_map_path)
549573
return
550574
op_template = env.get_template('op.c.j2')
551-
with open(output_op_path, "wt") as f:
552-
msg = op_template.render(
553-
ops=ops,
554-
backward_ops=backward_ops,
555-
op_dict=op_dict,
556-
)
557-
f.write(msg)
575+
576+
backward_fluid_op_dict = {}
577+
for bw_op in backward_ops:
578+
backward_fluid_op_dict[bw_op['op_name']] = bw_op
579+
output_op_files_num = len(output_op_path)
580+
new_ops_list, new_bw_ops_list = split_ops_list(
581+
ops, backward_fluid_op_dict, output_op_files_num
582+
)
583+
for idx, output_op_file in enumerate(output_op_path):
584+
with open(output_op_file, "wt") as f:
585+
msg = op_template.render(
586+
ops=new_ops_list[idx],
587+
backward_ops=new_bw_ops_list[idx],
588+
op_dict=op_dict,
589+
)
590+
f.write(msg)
591+
558592
ks_template = env.get_template('ks.c.j2')
559593
with open(output_arg_map_path, 'wt') as f:
560594
msg = ks_template.render(ops=ops, backward_ops=backward_ops)
@@ -578,7 +612,10 @@ def main(
578612
'--op_version_yaml_path', type=str, help="ops version yaml file."
579613
)
580614
parser.add_argument(
581-
"--output_op_path", type=str, help="path to save generated operators."
615+
"--output_op_path",
616+
type=str,
617+
nargs='+',
618+
help="path to save generated operators.",
582619
)
583620
parser.add_argument(
584621
"--output_arg_map_path",

0 commit comments

Comments
 (0)