Skip to content

Commit d884573

Browse files
Support bw invoke fw (#50260)
* support bw invoke fw * fix scale in static_backward.yaml * fix the bug in tensorrt/convert * move 'scale','sign' into ops.yaml * add scale_grad of scale in op_compat.yaml * change generated_static_op in CMakeLists.txt
1 parent 9af23f1 commit d884573

File tree

20 files changed

+130
-327
lines changed

20 files changed

+130
-327
lines changed

paddle/fluid/eager/CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,6 @@ cc_library(
7777
op_registry
7878
variable_helper
7979
memcpy
80-
scale_op
80+
generated_op
8181
autograd_meta
8282
hook_utils)

paddle/fluid/eager/tests/performance_tests/CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ if(NOT (NOT WITH_PYTHON AND ON_INFER))
77
${generated_deps}
88
eager_scale
99
scale_node
10-
scale_op
10+
generated_op
1111
matmul_v2_op
1212
dygraph_function
1313
eager_prim_api)

paddle/fluid/framework/CMakeLists.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -1051,7 +1051,7 @@ if(WITH_PSCORE)
10511051
heter_pipeline_trainer_test
10521052
SRCS heter_pipeline_trainer_test.cc
10531053
DEPS conditional_block_op
1054-
scale_op
1054+
generated_op
10551055
heter_listen_and_serv_op
10561056
executor
10571057
heter_server
@@ -1068,7 +1068,7 @@ if(WITH_PSCORE)
10681068
heter_pipeline_trainer_test
10691069
SRCS heter_pipeline_trainer_test.cc
10701070
DEPS conditional_block_op
1071-
scale_op
1071+
generated_op
10721072
heter_listen_and_serv_op
10731073
executor
10741074
heter_server

paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -76,5 +76,5 @@ cc_library(
7676
cc_test(
7777
test_reference_count_pass_last_lived_ops
7878
SRCS test_reference_count_pass_last_lived_ops.cc
79-
DEPS parallel_executor elementwise_mul_op elementwise_add_op scale_op
79+
DEPS parallel_executor elementwise_mul_op elementwise_add_op generated_op
8080
eigen_function)

paddle/fluid/jit/CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ if(WITH_TESTING AND NOT WIN32)
6565
reduce_mean_op
6666
feed_op
6767
fetch_op
68-
scale_op
68+
generated_op
6969
transfer_layout_op
7070
jit_layer)
7171
cc_test(

paddle/fluid/operators/generator/CMakeLists.txt

+20-4
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ set(legacy_bw_op_yaml_file
1111
set(sparse_op_yaml_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/sparse_ops.yaml)
1212
set(sparse_bw_op_yaml_file
1313
${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/sparse_backward.yaml)
14+
set(static_bw_op_yaml_file
15+
${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/static_backward.yaml)
1416

1517
if(NOT PYTHONINTERP_FOUND)
1618
find_package(PythonInterp REQUIRED)
@@ -66,6 +68,9 @@ execute_process(
6668
COMMAND
6769
${PYTHON_EXECUTABLE} parse_op.py --op_yaml_path ${sparse_bw_op_yaml_file}
6870
--output_path ./parsed_ops/sparse_backward.parsed.yaml --backward
71+
COMMAND
72+
${PYTHON_EXECUTABLE} parse_op.py --op_yaml_path ${static_bw_op_yaml_file}
73+
--output_path ./parsed_ops/static_backward.parsed.yaml --backward
6974
RESULTS_VARIABLE _results)
7075
foreach(_result in ${_results})
7176
if(${_result})
@@ -82,14 +87,24 @@ execute_process(
8287
COMMAND
8388
${PYTHON_EXECUTABLE} cross_validate.py --forward_yaml_paths
8489
./parsed_ops/ops.parsed.yaml ./parsed_ops/legacy_ops.parsed.yaml
85-
./parsed_ops/static_ops.parsed.yaml --backward_yaml_paths
86-
./parsed_ops/backward_ops.parsed.yaml
90+
--backward_yaml_paths ./parsed_ops/backward_ops.parsed.yaml
8791
./parsed_ops/legacy_backward_ops.parsed.yaml
8892
RESULT_VARIABLE _result)
8993
if(${_result})
9094
message(FATAL_ERROR "ops validation failed, exiting.")
9195
endif()
9296

97+
execute_process(
98+
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generator
99+
COMMAND
100+
${PYTHON_EXECUTABLE} cross_validate.py --forward_yaml_paths
101+
./parsed_ops/static_ops.parsed.yaml --backward_yaml_paths
102+
./parsed_ops/static_backward.parsed.yaml
103+
RESULT_VARIABLE _result)
104+
if(${_result})
105+
message(FATAL_ERROR "static ops validation failed, exiting.")
106+
endif()
107+
93108
execute_process(
94109
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generator
95110
COMMAND
@@ -124,8 +139,9 @@ endif()
124139
execute_process(
125140
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generator
126141
COMMAND
127-
${PYTHON_EXECUTABLE} generate_static_op.py --ops_yaml_path
128-
./parsed_ops/static_ops.parsed.yaml --op_version_yaml_path
142+
${PYTHON_EXECUTABLE} generate_op.py --ops_yaml_path
143+
./parsed_ops/static_ops.parsed.yaml --backward_yaml_path
144+
./parsed_ops/static_backward.parsed.yaml --op_version_yaml_path
129145
${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/op_version.yaml
130146
--op_compat_yaml_path ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml
131147
--output_op_path "${generated_static_op_path}.tmp" --output_arg_map_path

paddle/fluid/operators/generator/generate_op.py

+13-12
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,7 @@ def process_invoke_op(forward_op_dict, backward_op_dict):
407407
invoke_op = bw_op['invoke']['func']
408408
args_list = bw_op['invoke']['args']
409409
args_index = 0
410+
# backward invoke forward
410411
if invoke_op in forward_op_dict:
411412
reuse_op = forward_op_dict[invoke_op]
412413
bw_op['invoke']['func'] = reuse_op['op_name']
@@ -460,17 +461,16 @@ def parse_drop_empty_grad(op_fluid_list: list, bw_op_dict: dict):
460461
for bw_name in op_op['backward'].split(',')
461462
]
462463
for bw_name in bw_names:
463-
assert (
464-
bw_name in bw_op_dict
465-
), f"backward {bw_name} is not existed"
466-
for out_grad in op_op['drop_empty_grad']:
467-
assert (
468-
out_grad in bw_op_dict[bw_name]['output_dict']
469-
), f'''
470-
{bw_name} with {out_grad} is not existed in output_dict '''
471-
bw_op_dict[bw_name]['output_dict'][out_grad][
472-
'drop_empty_grad'
473-
] = False
464+
# static_ops.yaml and ops.yaml use the common op_compat.yaml
465+
if bw_name in bw_op_dict:
466+
for out_grad in op_op['drop_empty_grad']:
467+
assert (
468+
out_grad in bw_op_dict[bw_name]['output_dict']
469+
), f'''
470+
{bw_name} with {out_grad} is not existed in output_dict '''
471+
bw_op_dict[bw_name]['output_dict'][out_grad][
472+
'drop_empty_grad'
473+
] = False
474474

475475

476476
def main(
@@ -493,7 +493,8 @@ def main(
493493
op_versions = yaml.safe_load(f)
494494
# add op version info into op
495495
for op_version in op_versions:
496-
forward_op_dict[op_version['op']]['version'] = op_version['version']
496+
if op_version['op'] in forward_op_dict:
497+
forward_op_dict[op_version['op']]['version'] = op_version['version']
497498

498499
with open(op_compat_yaml_path, "rt") as f:
499500
op_fluid_map_list = yaml.safe_load(f)

paddle/fluid/operators/generator/generate_static_op.py

+12
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def restruct_io(op):
8181

8282
def main(
8383
ops_yaml_path,
84+
backward_yaml_path,
8485
op_compat_yaml_path,
8586
op_version_yaml_path,
8687
output_op_path,
@@ -91,6 +92,11 @@ def main(
9192
ops = [restruct_io(op) for op in ops]
9293
forward_op_dict = to_named_dict(ops)
9394

95+
with open(backward_yaml_path, "rt") as f:
96+
backward_ops = yaml.safe_load(f)
97+
backward_ops = [restruct_io(op) for op in backward_ops]
98+
backward_op_dict = to_named_dict(backward_ops)
99+
94100
with open(op_version_yaml_path, "rt") as f:
95101
op_versions = yaml.safe_load(f)
96102

@@ -139,6 +145,11 @@ def main(
139145
parser.add_argument(
140146
'--ops_yaml_path', type=str, help="parsed static ops yaml file."
141147
)
148+
parser.add_argument(
149+
'--backward_yaml_path',
150+
type=str,
151+
help="parsed static backward ops yaml file.",
152+
)
142153
parser.add_argument(
143154
'--op_compat_yaml_path', type=str, help="ops args compat yaml file."
144155
)
@@ -157,6 +168,7 @@ def main(
157168
args = parser.parse_args()
158169
main(
159170
args.ops_yaml_path,
171+
args.backward_yaml_path,
160172
args.op_compat_yaml_path,
161173
args.op_version_yaml_path,
162174
args.output_op_path,

paddle/fluid/operators/generator/templates/operator_utils.c.j2

+25
Original file line numberDiff line numberDiff line change
@@ -522,6 +522,31 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T>
522522
true)}});
523523
{% endfor %}
524524

525+
{% for attr in invoke_op["attrs"] %}
526+
{% set attr_name = attr["fluid_name"] %}
527+
{% set fw_attrs = forward_op["attrs"] %}
528+
{% if attr_name in forward_attr_names %}
529+
{# invoke_op's attrs and fw_attr's attrs must be the same#}
530+
{% set fw_attr = fw_attrs[loop.index0] %}
531+
{% if fw_attr["typename"] == "IntArray" %}
532+
{% if 'tensor_name' in attr or 'manual_flag' not in attr %}
533+
if (this->HasInput("{{fw_attr | to_int_array_tensor_name}}")) {
534+
grad_op->SetInput("{{fw_attr | to_int_array_tensor_name}}", this->Input("{{fw_attr | to_int_array_tensor_name}}"));
535+
}
536+
{% endif %}
537+
{% if 'tensors_name' in fw_attr or 'manual_flag' not in fw_attr %}
538+
if (this->HasInput("{{fw_attr | to_int_array_tensors_name}}")) {
539+
grad_op->SetInput("{{fw_attr | to_int_array_tensors_name}}", this->Input("{{fw_attr | to_int_array_tensors_name}}"));
540+
}
541+
{% endif %}
542+
{% elif fw_attr["typename"] == "Scalar" %}
543+
if (this->HasInput("{{fw_attr | to_scalar_tensor_name}}")) {
544+
grad_op->SetInput("{{fw_attr | to_scalar_tensor_name}}", this->Input("{{fw_attr | to_scalar_tensor_name}}"));
545+
}
546+
{% endif %}
547+
{% endif %}
548+
{% endfor %}
549+
525550
{% for attr in invoke_op["attrs"] %}
526551
grad_op->SetAttr("{{attr["fluid_name"]}}", {{attr["value"]}});
527552
{% endfor %}

paddle/fluid/operators/pscore/CMakeLists.txt

+6-6
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ cc_test_old(
8686
executor
8787
scope
8888
proto_desc
89-
scale_op
89+
generated_op
9090
eigen_function)
9191

9292
set_source_files_properties(
@@ -100,7 +100,7 @@ cc_test_old(
100100
executor
101101
scope
102102
proto_desc
103-
scale_op
103+
generated_op
104104
send_and_recv_op
105105
${RPC_DEPS}
106106
${DISTRIBUTE_DEPS}
@@ -117,7 +117,7 @@ cc_test_old(
117117
executor
118118
scope
119119
proto_desc
120-
scale_op
120+
generated_op
121121
send_and_recv_op
122122
${RPC_DEPS}
123123
${DISTRIBUTE_DEPS}
@@ -134,14 +134,14 @@ cc_test_old(
134134
executor
135135
scope
136136
proto_desc
137-
scale_op
137+
generated_op
138138
heter_listen_and_serv_op
139139
${RPC_DEPS}
140140
${DISTRIBUTE_DEPS}
141141
eigen_function)
142142

143143
#set_source_files_properties(heter_cloud_comm_cpu_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
144-
#cc_test(heter_cloud_comm_cpu_test SRCS heter_cloud_comm_cpu_test.cc DEPS executor scope proto_desc scale_op heter_listen_and_serv_op ${RPC_DEPS} ${DISTRIBUTE_DEPS} eigen_function)
144+
#cc_test(heter_cloud_comm_cpu_test SRCS heter_cloud_comm_cpu_test.cc DEPS executor scope proto_desc generated_static_op heter_listen_and_serv_op ${RPC_DEPS} ${DISTRIBUTE_DEPS} eigen_function)
145145

146146
set_source_files_properties(
147147
switch_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
@@ -153,7 +153,7 @@ cc_binary(
153153
executor
154154
scope
155155
proto_desc
156-
scale_op
156+
generated_op
157157
heter_listen_and_serv_op
158158
${RPC_DEPS}
159159
${DISTRIBUTE_DEPS}

paddle/fluid/operators/scale_op.cc

-118
This file was deleted.

0 commit comments

Comments
 (0)