[AutoParallel] support tp_conv #73039
Conversation
Your PR was submitted successfully. Thank you for your contribution to the open-source project!
LGTM
def _ring_send_recv_construct(
The function name could be improved.
grad_weight = dist.auto_parallel.api.dtensor_from_local(
    grad_weight, weight_mesh, weight_placements
)
grad_weight = dist.reshard(
The reshard has no effect here and can be removed.
grad_bias = dist.auto_parallel.api.dtensor_from_local(
    grad_bias, bias_mesh, bias_placements
)
grad_weight = dist.reshard(
Same as above.
in_dygraph_mode()
and x.is_dist()
and self._data_format in ["NCHW", "NHWC"]
):
A comment could be added to this if branch describing the conditions for taking the RingConv2d path.
return reconstructed_tensor

def _ring_send_recv_aggregate(
Same as above.
This will be addressed uniformly in the mid-level API PR.
)
and "conv2d" not in forward_api_contents[op_string]
Several blacklist entries have already been set; why add a special case for conv2d?
Because conv2d was removed from the blacklist above. The conv2d Python API appears to be handled elsewhere; without this special case, the Windows and macOS builds of Paddle failed to compile with a "conv2d redefined" error (the Linux build compiled fine).
LGTM
/re-run all-failed
PR Category
Auto Parallel
PR Types
New features
Description
Support TP conv for auto parallel.
The current TP Conv implementation is based on a ring communication mechanism: the input is partitioned along the W dimension, and halo-region data exchange provides coverage of receptive fields that cross shard boundaries. The forward pass constructs the halos via ring send-recv; the backward pass, after local gradient aggregation, likewise completes the cross-rank updates through ring communication. The current version only supports dilation=1; it supports typical stride=1 convolutions and some evenly spaced downsampling convolutions.
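The halo mechanism described above can be illustrated with a toy 1-D sketch (a minimal simulation in pure Python, not the PR's actual implementation; the function names and the zero-padding boundary choice are hypothetical). Each "rank" pulls `kernel_size // 2` halo columns from its neighbors before running its local convolution, so the concatenated shard outputs match a full "same"-padded convolution:

```python
def conv1d_same(x, w):
    """Reference: stride-1 conv with zero padding ("same" output width)."""
    k = len(w)
    pad = k // 2
    xp = [0.0] * pad + list(x) + [0.0] * pad
    return [sum(xp[i + j] * w[j] for j in range(k)) for i in range(len(x))]

def ring_halo_conv(x, w, num_shards):
    """Simulate W-dim sharding with halo exchange (stride=1, dilation=1).

    A real implementation would ring-send/recv the halo columns between
    neighboring ranks; here each simulated "rank" just reads its
    neighbor's shard directly. Outermost shards zero-pad instead.
    """
    k = len(w)
    halo = k // 2
    n = len(x) // num_shards
    shards = [x[r * n:(r + 1) * n] for r in range(num_shards)]
    out = []
    for rank, shard in enumerate(shards):
        # "receive" halo columns from the left/right neighbor
        left = shards[rank - 1][-halo:] if rank > 0 else [0.0] * halo
        right = shards[rank + 1][:halo] if rank < num_shards - 1 else [0.0] * halo
        local = list(left) + list(shard) + list(right)
        # local conv over the halo-extended shard yields exactly n outputs
        out += [sum(local[i + j] * w[j] for j in range(k)) for i in range(n)]
    return out
```

With stride=1 and dilation=1 the halo width is exactly `kernel_size // 2` on each side, which is why the sharded result is bitwise-identical to the unsharded one.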
Pcard-76459