Skip to content

Commit 820a387

Browse files
authored
[Semi-Auto] Adapt transpose rule to phi (#57259)
* adapt transpose rule to phi * small modification * modify api in unit test
1 parent b555415 commit 820a387

File tree

7 files changed

+220
-240
lines changed

7 files changed

+220
-240
lines changed

paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/cross_entropy_with_softmax_spmd_rule.h"
1919
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/replicated_spmd_rule.h"
2020
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/softmax_spmd_rule.h"
21-
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/transpose_spmd_rule.h"
2221

2322
// TODO(ljz) Automatic this process in cmake file.
2423
namespace paddle {
@@ -36,9 +35,6 @@ REGISTER_SPMD_RULE(log_softmax, SoftmaxSPMDRule);
3635
REGISTER_SPMD_RULE(cross_entropy_with_softmax, CrossEntropyWithSoftmaxSPMDRule);
3736
REGISTER_SPMD_RULE(softmax_with_cross_entropy, CrossEntropyWithSoftmaxSPMDRule);
3837

39-
// transpose rule
40-
REGISTER_SPMD_RULE(transpose, TransposeSPMDRule);
41-
4238
} // namespace auto_parallel
4339
} // namespace distributed
4440
} // namespace paddle

paddle/fluid/distributed/auto_parallel/spmd_rules/transpose_spmd_rule.cc

Lines changed: 0 additions & 173 deletions
This file was deleted.

paddle/fluid/distributed/auto_parallel/spmd_rules/transpose_spmd_rule.h

Lines changed: 0 additions & 46 deletions
This file was deleted.

paddle/phi/infermeta/spmd_rules/rules.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ limitations under the License. */
2525
#include "paddle/phi/infermeta/spmd_rules/replicated.h"
2626
#include "paddle/phi/infermeta/spmd_rules/reshape.h"
2727
#include "paddle/phi/infermeta/spmd_rules/split.h"
28+
#include "paddle/phi/infermeta/spmd_rules/transpose.h"
2829

2930
/**
3031
* Design Notes:
@@ -495,5 +496,11 @@ PD_REGISTER_SPMD_RULE(
495496
PD_INFER_SPMD(phi::distributed::SplitWithNumInferSpmd),
496497
PD_INFER_SPMD(phi::distributed::SplitWithNumInferSpmdReverse));
497498

499+
// transpose rule
500+
PD_REGISTER_SPMD_RULE(
501+
transpose,
502+
PD_INFER_SPMD(phi::distributed::TransposeInferSpmd),
503+
PD_INFER_SPMD(phi::distributed::TransposeInferSpmdReverse));
504+
498505
} // namespace distributed
499506
} // namespace phi

0 commit comments

Comments
 (0)