Skip to content

Commit 02cf173

Browse files
committed
adapt softmax spmd rule to phi
1 parent b39b2b7 commit 02cf173

File tree

8 files changed

+202
-220
lines changed

8 files changed

+202
-220
lines changed

paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h

-5
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/reduction_spmd_rule.h"
2323
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/replicated_spmd_rule.h"
2424
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.h"
25-
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/softmax_spmd_rule.h"
2625
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.h"
2726
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/transpose_spmd_rule.h"
2827

@@ -141,10 +140,6 @@ REGISTER_SPMD_RULE(replicated, ReplicatedSPMDRule);
141140
REGISTER_SPMD_RULE(embedding, EmbeddingSPMDRule);
142141
REGISTER_SPMD_RULE(lookup_table_v2, EmbeddingSPMDRule);
143142

144-
// softmax rule
145-
REGISTER_SPMD_RULE(softmax, SoftmaxSPMDRule);
146-
REGISTER_SPMD_RULE(log_softmax, SoftmaxSPMDRule);
147-
148143
// cross_entropy_with_softmax
149144
REGISTER_SPMD_RULE(cross_entropy_with_softmax, CrossEntropyWithSoftmaxSPMDRule);
150145
REGISTER_SPMD_RULE(softmax_with_cross_entropy, CrossEntropyWithSoftmaxSPMDRule);

paddle/fluid/distributed/auto_parallel/spmd_rules/softmax_spmd_rule.cc

-179
This file was deleted.

paddle/phi/core/distributed/auto_parallel/inferspmd_utils.cc

+2-4
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ AttrType InferSpmdContext::AttrAt(size_t idx) const {
5353
}
5454
}
5555

56+
template int InferSpmdContext::AttrAt(size_t idx) const;
57+
5658
template <>
5759
bool InferSpmdContext::AttrAt(size_t idx) const {
5860
try {
@@ -88,10 +90,6 @@ std::vector<int> InferSpmdContext::AttrAt(size_t idx) const {
8890
}
8991
}
9092

91-
// template const std::vector<int64_t>& InferSpmdContext::AttrAt(size_t idx)
92-
// const; template const std::vector<int>& InferSpmdContext::AttrAt(size_t idx)
93-
// const;
94-
9593
const Attribute& InferSpmdContext::AttrAt(size_t idx) const {
9694
return attrs_.at(idx);
9795
}

paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h

+1
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ struct InferSpmdFnImpl<Return (*)(Args...), infer_spmd_fn> {
153153

154154
// TODO(chenweihang): support other attr type later as needed
155155
PD_SPECIALIZE_InferSpmdFnCallHelper_FOR_ATTRIBUTE(bool);
156+
PD_SPECIALIZE_InferSpmdFnCallHelper_FOR_ATTRIBUTE(int);
156157
PD_SPECIALIZE_InferSpmdFnCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector<int>);
157158
PD_SPECIALIZE_InferSpmdFnCallHelper_FOR_CONST_ATTRIBUTE_REF(
158159
std::vector<int64_t>);

paddle/phi/infermeta/spmd_rules/rules.h

+10
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License. */
1919
#include "paddle/phi/infermeta/spmd_rules/default_data_parallel.h"
2020
#include "paddle/phi/infermeta/spmd_rules/matmul.h"
2121
#include "paddle/phi/infermeta/spmd_rules/replicated.h"
22+
#include "paddle/phi/infermeta/spmd_rules/softmax.h"
2223

2324
/**
2425
* Design Notes:
@@ -57,5 +58,14 @@ PD_REGISTER_SPMD_RULE(
5758
PD_INFER_SPMD(phi::distributed::ReplicatedSpmdInferForward),
5859
PD_INFER_SPMD(phi::distributed::ReplicatedSpmdInferBackward));
5960

61+
// softmax rule
62+
PD_REGISTER_SPMD_RULE(softmax,
63+
PD_INFER_SPMD(phi::distributed::SoftmaxInferSpmd),
64+
PD_INFER_SPMD(phi::distributed::SoftmaxInferSpmdReverse));
65+
66+
PD_REGISTER_SPMD_RULE(log_softmax,
67+
PD_INFER_SPMD(phi::distributed::SoftmaxInferSpmd),
68+
PD_INFER_SPMD(phi::distributed::SoftmaxInferSpmdReverse));
69+
6070
} // namespace distributed
6171
} // namespace phi

0 commit comments

Comments
 (0)