-
Notifications
You must be signed in to change notification settings - Fork 5.7k
[spmd_rules] Add SPMD rules for ArgSort operator #72388
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[spmd_rules] Add SPMD rules for ArgSort operator #72388
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
561556f
to
e14325b
Compare
x_tensor_dist_attr = dist_attr[0] | ||
y_tensor_dist_attr = dist_attr[1] | ||
|
||
self.assertEqual(len(dist_attr), 2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this line should be before L77, first determine the length, then retrieve the element
x_tensor_dist_attr = dist_attr[0] | ||
y_tensor_dist_attr = dist_attr[1] | ||
|
||
self.assertEqual(len(dist_attr), 2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same issue as above
x_tensor_dist_attr = dist_attr[0] | ||
y_tensor_dist_attr = dist_attr[1] | ||
|
||
self.assertEqual(len(dist_attr), 2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same issue as above
SpmdInfo ArgSortInferSpmdBase(const DistMetaTensor& x, | ||
int axis, | ||
bool descending, | ||
bool stable); | ||
|
||
SpmdInfo ArgSortGradInferSpmdBase(const DistMetaTensor& indices, | ||
const DistMetaTensor& x, | ||
const DistMetaTensor& out_grad, | ||
int axis, | ||
bool descending, | ||
bool stable); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the reason for ArgSortInferSpmdBase
and ArgSortGradInferSpmdBase
? for other operators reuse these codes?
f62d6e7
to
51a9572
Compare
@jeff41404 Code has been modified according issues you highlighted. |
to generate the correct code, |
51a9572
to
62735f0
Compare
@@ -0,0 +1,37 @@ | |||
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
2023 -> 2015
@@ -0,0 +1,184 @@ | |||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
62735f0
to
b4026b7
Compare
4be10d4
to
0a50085
Compare
/re-run approval |
PR Category
Auto Parallel
PR Types
New features
Description
pcard-90510
The version upon committing incorporates SPMD rules for the ArgSort operator, covering both the forward and backward passes.
The SPMD rules for both the forward and backward passes simply replicate the sorting axis across one worker mesh dimension. For example, dimension mapping [0, 1, -1], given sorting axis of 1, will be transformed into dimension mapping [0, -1, -1].
The rules that involve splitting the sorting axis are deeply coupled with the reduce communication mechanism. Specifically, merging sorted sequences during reduction is not currently supported and remains a topic for future work.