Skip to content

[AutoParallel] support tp_conv #73039

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged 2 commits on Jun 6, 2025
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 24 additions & 20 deletions paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,6 @@

# Black Ops list that's NO NEED to apply code generation
black_ops_list = [
"conv2d",
"conv2d_grad",
"conv2d_grad_grad",
"add_n",
"add_n_grad",
"sync_batch_norm_",
Expand All @@ -68,6 +65,8 @@
"push_gpups_sparse",
]

only_backward_ops_list = ["conv2d"]


# white ops list whose kernel can be deleted after performance analysis
# original kernel and its derivative kernel can be deleted when composite_grad
Expand Down Expand Up @@ -3248,10 +3247,14 @@ def GenerateCode(self, grad_flag=False):
for forward_api_contents in true_forward_api_list:
if forward_api_contents[op_string] in black_ops_list:
continue
if op_string == 'backward_op' and (
forward_api_contents[op_string].endswith(
('double_grad', 'triple_grad', 'grad_grad')
if (
op_string == 'backward_op'
and (
forward_api_contents[op_string].endswith(
('double_grad', 'triple_grad', 'grad_grad')
)
)
and "conv2d" not in forward_api_contents[op_string]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已经设置了几处黑名单了,为什么还要加conv2d的特判?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已经设置了几处黑名单了,为什么还要加conv2d的特判?

因为上面给 conv2d 的黑名单移除了, conv2d 的 py api 应该是有其他地方处理了,我之前不加这个特判的时候 windows 和 mac 版本的 paddle 会编译报错 提示 conv2d 重复定义(linux 下编译没问题)

):
continue

Expand All @@ -3264,21 +3267,22 @@ def GenerateCode(self, grad_flag=False):
forward_api_contents
)

# Generate Dygraph Forward Function
function_generator = DygraphForwardFunctionGenerator(
forward_api_contents,
backward_api_contents,
forward_apis_dict,
namespace,
)
function_generator.run(grad_flag)
if forward_api_contents[op_string] not in only_backward_ops_list:
# Generate Dygraph Forward Function
function_generator = DygraphForwardFunctionGenerator(
forward_api_contents,
backward_api_contents,
forward_apis_dict,
namespace,
)
function_generator.run(grad_flag)

self.forward_definition_str += (
function_generator.forward_definition_str + "\n"
)
self.forward_declaration_str += (
function_generator.forward_declaration_str + "\n"
)
self.forward_definition_str += (
function_generator.forward_definition_str + "\n"
)
self.forward_declaration_str += (
function_generator.forward_declaration_str + "\n"
)

if not grad_flag:
# Generate Dygraph GradNode Function
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
"scale_grad",
"push_gpups_sparse",
"multiply_grad",
"conv2d_grad",
"pull_sparse_v2_grad",
}

Expand Down
1 change: 1 addition & 0 deletions python/paddle/distributed/auto_parallel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
)
from .process_mesh import ProcessMesh # noqa: F401
from .random import parallel_manual_seed # noqa: F401
from .ring_conv import RingConv2d # noqa: F401
from .static.engine import Engine # noqa: F401
from .strategy import Strategy # noqa: F401

Expand Down
Loading
Loading