Skip to content

Commit e8f716d

Browse files
committed
support ring_conv forward and backward
1 parent 425c14d commit e8f716d

File tree

8 files changed

+1139
-21
lines changed

8 files changed

+1139
-21
lines changed

paddle/fluid/eager/auto_code_generator/generator/eager_gen.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,6 @@
5555

5656
# Black Ops list that's NO NEED to apply code generation
5757
black_ops_list = [
58-
"conv2d",
59-
"conv2d_grad",
60-
"conv2d_grad_grad",
6158
"add_n",
6259
"add_n_grad",
6360
"sync_batch_norm_",
@@ -68,6 +65,8 @@
6865
"push_gpups_sparse",
6966
]
7067

68+
only_backward_ops_list = ["conv2d"]
69+
7170

7271
# white ops list whose kernel can be deleted after performance analysis
7372
# original kernel and its derivative kernel can be deleted when composite_grad
@@ -3248,10 +3247,14 @@ def GenerateCode(self, grad_flag=False):
32483247
for forward_api_contents in true_forward_api_list:
32493248
if forward_api_contents[op_string] in black_ops_list:
32503249
continue
3251-
if op_string == 'backward_op' and (
3252-
forward_api_contents[op_string].endswith(
3253-
('double_grad', 'triple_grad', 'grad_grad')
3250+
if (
3251+
op_string == 'backward_op'
3252+
and (
3253+
forward_api_contents[op_string].endswith(
3254+
('double_grad', 'triple_grad', 'grad_grad')
3255+
)
32543256
)
3257+
and "conv2d" not in forward_api_contents[op_string]
32553258
):
32563259
continue
32573260

@@ -3264,21 +3267,22 @@ def GenerateCode(self, grad_flag=False):
32643267
forward_api_contents
32653268
)
32663269

3267-
# Generate Dygraph Forward Function
3268-
function_generator = DygraphForwardFunctionGenerator(
3269-
forward_api_contents,
3270-
backward_api_contents,
3271-
forward_apis_dict,
3272-
namespace,
3273-
)
3274-
function_generator.run(grad_flag)
3270+
if forward_api_contents[op_string] not in only_backward_ops_list:
3271+
# Generate Dygraph Forward Function
3272+
function_generator = DygraphForwardFunctionGenerator(
3273+
forward_api_contents,
3274+
backward_api_contents,
3275+
forward_apis_dict,
3276+
namespace,
3277+
)
3278+
function_generator.run(grad_flag)
32753279

3276-
self.forward_definition_str += (
3277-
function_generator.forward_definition_str + "\n"
3278-
)
3279-
self.forward_declaration_str += (
3280-
function_generator.forward_declaration_str + "\n"
3281-
)
3280+
self.forward_definition_str += (
3281+
function_generator.forward_definition_str + "\n"
3282+
)
3283+
self.forward_declaration_str += (
3284+
function_generator.forward_declaration_str + "\n"
3285+
)
32823286

32833287
if not grad_flag:
32843288
# Generate Dygraph GradNode Function

paddle/fluid/eager/auto_code_generator/generator/python_c_gen.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
"scale_grad",
3131
"push_gpups_sparse",
3232
"multiply_grad",
33-
"conv2d_grad",
3433
"pull_sparse_v2_grad",
3534
}
3635

python/paddle/distributed/auto_parallel/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
)
2525
from .process_mesh import ProcessMesh # noqa: F401
2626
from .random import parallel_manual_seed # noqa: F401
27+
from .ring_conv import RingConv2d # noqa: F401
2728
from .static.engine import Engine # noqa: F401
2829
from .strategy import Strategy # noqa: F401
2930

0 commit comments

Comments
 (0)