Skip to content

Commit ca1a707

Browse files
【dygraph】support python fused/sparse Grad api (#71029)
* grad_api2 * support sparse * modify dygraph_api
1 parent 1ecf5ee commit ca1a707

File tree

7 files changed

+45
-23
lines changed

7 files changed

+45
-23
lines changed

paddle/fluid/eager/auto_code_generator/generator/CMakeLists.txt

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,7 @@ set(api_yaml_path
44
set(backward_yaml_path
55
"${PADDLE_SOURCE_DIR}/paddle/phi/ops/yaml/backward.yaml,${PADDLE_SOURCE_DIR}/paddle/phi/ops/yaml/inconsistent/dygraph_backward.yaml,${PADDLE_SOURCE_DIR}/paddle/phi/ops/yaml/sparse_backward.yaml,${PADDLE_SOURCE_DIR}/paddle/phi/ops/yaml/fused_backward.yaml"
66
)
7-
set(no_sparse_backward_yaml_path
8-
"${PADDLE_SOURCE_DIR}/paddle/phi/ops/yaml/backward.yaml,${PADDLE_SOURCE_DIR}/paddle/phi/ops/yaml/inconsistent/dygraph_backward.yaml"
9-
)
7+
108
set(tmp_forwards_cc_path
119
"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/eager_generated/forwards/tmp_dygraph_functions.cc"
1210
)
@@ -70,7 +68,7 @@ add_custom_target(
7068
COMMAND
7169
"${PYTHON_EXECUTABLE}"
7270
"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/auto_code_generator/generator/python_c_gen.py"
73-
"--api_yaml_path=${api_yaml_path},${fwd_api_yaml_path},${no_sparse_backward_yaml_path}"
71+
"--api_yaml_path=${api_yaml_path},${fwd_api_yaml_path},${backward_yaml_path}"
7472
"--source_path=${tmp_python_c_source_path}"
7573
"--header_path=${tmp_python_c_header_path}"
7674
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_python_c_source_path}

paddle/fluid/eager/auto_code_generator/generator/codegen_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,10 @@ def ReadFwdFile(filepath):
118118
contents = yaml.load(f, Loader=yaml.FullLoader)
119119
f.close()
120120
# not all fused ops support dygraph
121-
if filepath.endswith("fused_ops.yaml") is True:
121+
if (
122+
filepath.endswith("fused_ops.yaml") is True
123+
or filepath.endswith("fused_backward.yaml") is True
124+
):
122125
new_apis = [
123126
api
124127
for api in contents

paddle/fluid/eager/auto_code_generator/generator/eager_gen.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3357,8 +3357,12 @@ def GenerateForwardHFile(filepath, forward_function_declaration_str):
33573357
generator_grad = DygraphForwardAndNodesGenerator(
33583358
backward_yaml_paths[i], backward_yaml_paths[i], all_bw, all_bw
33593359
)
3360-
else:
3360+
elif backward_yaml_path.endswith('/dygraph_backward.yaml'):
33613361
continue
3362+
else:
3363+
generator_grad = DygraphForwardAndNodesGenerator(
3364+
backward_yaml_path, backward_yaml_path
3365+
)
33623366

33633367
generator_grad.run(True)
33643368

paddle/fluid/pir/dialect/op_generator/ops_api_gen.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -251,8 +251,6 @@ def __init__(self) -> None:
251251

252252
def _need_skip(self, op_info, op_name):
253253
if op_name.endswith("_grad"):
254-
if op_info.is_sparse_op or op_info.is_fused_op:
255-
return True
256254
if op_name.endswith(("double_grad", "_grad_grad", "triple_grad")):
257255
return True
258256
if op_name[:-5] in NO_NEED_GEN_STATIC_ONLY_APIS:

paddle/phi/api/generator/intermediate_api_gen.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import argparse
1616

1717
import yaml
18-
from api_gen import ForwardAPI
18+
from api_gen import ForwardAPI, backward_api_black_list
1919
from dist_api_gen import DistForwardAPI
2020
from sparse_api_gen import SparseAPI
2121

@@ -131,12 +131,17 @@ def generate_intermediate_api(
131131

132132
dygraph_header_file.write(sparse_namespace_pair[0])
133133
dygraph_source_file.write(sparse_namespace_pair[0])
134-
135-
with open(sparse_api_yaml_path, 'r') as f:
136-
sparse_apis = yaml.load(f, Loader=yaml.FullLoader)
134+
sparse_apis = []
135+
for each_sparse_api_yaml in sparse_api_yaml_path:
136+
with open(each_sparse_api_yaml, 'r') as f:
137+
sparse_api_list = yaml.load(f, Loader=yaml.FullLoader)
138+
if sparse_api_list:
139+
sparse_apis.extend(sparse_api_list)
137140

138141
for api in sparse_apis:
139142
sparse_api = SparseAPI(api)
143+
if sparse_api.api in backward_api_black_list:
144+
continue
140145
if sparse_api.is_dygraph_api:
141146
dygraph_header_file.write(sparse_api.gene_api_declaration())
142147
dygraph_source_file.write(sparse_api.gene_api_code())
@@ -164,6 +169,7 @@ def main():
164169

165170
parser.add_argument(
166171
'--sparse_api_yaml_path',
172+
nargs='+',
167173
help='path to sparse api yaml file',
168174
default='paddle/phi/ops/yaml/sparse_ops.yaml',
169175
)

paddle/phi/api/generator/sparse_api_gen.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import yaml
1818
from api_base import PREFIX_TENSOR_NAME
19-
from api_gen import ForwardAPI
19+
from api_gen import ForwardAPI, backward_api_black_list
2020

2121

2222
class SparseAPI(ForwardAPI):
@@ -438,11 +438,13 @@ def source_include(header_file_path):
438438
#include "paddle/phi/infermeta/binary.h"
439439
#include "paddle/phi/infermeta/ternary.h"
440440
#include "paddle/phi/infermeta/multiary.h"
441+
#include "paddle/phi/infermeta/backward.h"
441442
#include "paddle/utils/none.h"
442443
443444
#include "paddle/phi/infermeta/sparse/unary.h"
444445
#include "paddle/phi/infermeta/sparse/binary.h"
445446
#include "paddle/phi/infermeta/sparse/multiary.h"
447+
#include "paddle/phi/infermeta/sparse/backward.h"
446448
447449
COMMON_DECLARE_int32(low_precision_op_list);
448450
COMMON_DECLARE_bool(benchmark);
@@ -467,8 +469,13 @@ def api_namespace():
467469

468470

469471
def generate_api(api_yaml_path, header_file_path, source_file_path):
470-
with open(api_yaml_path, 'r') as f:
471-
apis = yaml.load(f, Loader=yaml.FullLoader)
472+
apis = []
473+
474+
for each_api_yaml in api_yaml_path:
475+
with open(each_api_yaml, 'r') as f:
476+
api_list = yaml.load(f, Loader=yaml.FullLoader)
477+
if api_list:
478+
apis.extend(api_list)
472479
header_file = open(header_file_path, 'w')
473480
source_file = open(source_file_path, 'w')
474481

@@ -483,7 +490,10 @@ def generate_api(api_yaml_path, header_file_path, source_file_path):
483490
source_file.write(namespace[0])
484491

485492
for api in apis:
493+
486494
sparse_api = SparseAPI(api)
495+
if sparse_api.api in backward_api_black_list:
496+
continue
487497
if sparse_api.is_dygraph_api:
488498
sparse_api.is_dygraph_api = False
489499
header_file.write(sparse_api.gene_api_declaration())
@@ -503,6 +513,7 @@ def main():
503513
parser.add_argument(
504514
'--api_yaml_path',
505515
help='path to sparse api yaml file',
516+
nargs='+',
506517
default='paddle/phi/ops/yaml/sparse_ops.yaml',
507518
)
508519

paddle/phi/api/lib/CMakeLists.txt

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -202,8 +202,9 @@ endif()
202202
execute_process(
203203
COMMAND
204204
${PYTHON_EXECUTABLE} ${sparse_api_gen_file} --api_yaml_path
205-
${sparse_api_yaml_file} --api_header_path ${sparse_api_header_file_tmp}
206-
--api_source_path ${sparse_api_source_file_tmp})
205+
${sparse_api_yaml_file} ${sparse_bw_api_yaml_file} --api_header_path
206+
${sparse_api_header_file_tmp} --api_source_path
207+
${sparse_api_source_file_tmp})
207208

208209
# generate backward sparse api
209210
execute_process(
@@ -226,17 +227,18 @@ if(WITH_DISTRIBUTE)
226227
COMMAND
227228
${PYTHON_EXECUTABLE} ${im_api_gen_file} --api_yaml_path ${api_yaml_file}
228229
${bw_api_yaml_file} ${legacy_api_yaml_file} ${legacy_bw_api_yaml_file}
229-
--sparse_api_yaml_path ${sparse_api_yaml_file} --dygraph_api_header_path
230-
${dygraph_api_header_file_tmp} --dygraph_api_source_path
231-
${dygraph_api_source_file_tmp} --gen_dist_branch)
230+
--sparse_api_yaml_path ${sparse_api_yaml_file} ${sparse_bw_api_yaml_file}
231+
--dygraph_api_header_path ${dygraph_api_header_file_tmp}
232+
--dygraph_api_source_path ${dygraph_api_source_file_tmp}
233+
--gen_dist_branch)
232234
else()
233235
execute_process(
234236
COMMAND
235237
${PYTHON_EXECUTABLE} ${im_api_gen_file} --api_yaml_path ${api_yaml_file}
236238
${bw_api_yaml_file} ${legacy_api_yaml_file} ${legacy_bw_api_yaml_file}
237-
--sparse_api_yaml_path ${sparse_api_yaml_file} --dygraph_api_header_path
238-
${dygraph_api_header_file_tmp} --dygraph_api_source_path
239-
${dygraph_api_source_file_tmp})
239+
--sparse_api_yaml_path ${sparse_api_yaml_file} ${sparse_bw_api_yaml_file}
240+
--dygraph_api_header_path ${dygraph_api_header_file_tmp}
241+
--dygraph_api_source_path ${dygraph_api_source_file_tmp})
240242
endif()
241243

242244
# generate tensor and tensor operants file

0 commit comments

Comments
 (0)