Skip to content

Commit ab1b630

Browse files
authored
【Prim】Reshape, transpose, cast vjp (#50778)
* support transpose and reshape * support reshpe, transpose, cast vjp * merge develop * recover unused file * remove prim base * support problem * remove additional status settting * remove additional status settting * fix ut * fix ut * fix ut * fix no grad branch * add more test * disable fp16 in cpu * fix test
1 parent f8ce3a2 commit ab1b630

21 files changed

+955
-385
lines changed

paddle/fluid/operators/cast_op.cc

+22
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ limitations under the License. */
2222
#ifdef PADDLE_WITH_MLU
2323
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
2424
#endif
25+
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
26+
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
27+
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
2528

2629
namespace paddle {
2730
namespace operators {
@@ -63,6 +66,24 @@ class CastOpGradMaker : public framework::SingleGradOpMaker<T> {
6366
}
6467
};
6568

69+
class CastCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
70+
public:
71+
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
72+
73+
void Apply() override {
74+
paddle::experimental::Tensor out_grad = paddle::experimental::Tensor(
75+
std::make_shared<prim::DescTensor>(this->SingleOutputGrad("Out")));
76+
paddle::experimental::Tensor x_grad = paddle::experimental::Tensor(
77+
std::make_shared<prim::DescTensor>(this->SingleInputGrad("X")));
78+
auto dx_ptr = this->GetOutputPtr(&x_grad);
79+
std::string dx_name = this->GetOutputName(x_grad);
80+
auto dtype = static_cast<paddle::experimental::DataType>(
81+
this->Attr<int>("in_dtype"));
82+
prim::cast_grad<prim::DescTensor>(out_grad, dtype, dx_ptr);
83+
this->RecoverOutputName(x_grad, dx_name);
84+
}
85+
};
86+
6687
class CastOp : public framework::OperatorWithKernel {
6788
public:
6889
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -134,6 +155,7 @@ REGISTER_OPERATOR(cast,
134155
ops::CastOp,
135156
ops::CastOpGradMaker<paddle::framework::OpDesc>,
136157
ops::CastOpGradMaker<paddle::imperative::OpBase>,
158+
ops::CastCompositeGradOpMaker,
137159
ops::CastOpProtoMaker);
138160

139161
// [ why register transfer_dtype_op alias with cast_op? ]

paddle/fluid/operators/reshape_op.cc

+23-1
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,16 @@ limitations under the License. */
1919
#include "paddle/fluid/framework/phi_utils.h"
2020

2121
// only can include the headers in paddle/phi/api dirs
22+
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
23+
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
24+
#include "paddle/phi/api/lib/utils/tensor_utils.h"
2225
#include "paddle/phi/backends/cpu/cpu_context.h"
2326
#include "paddle/phi/common/int_array.h"
2427
#include "paddle/phi/core/infermeta_utils.h"
2528
#include "paddle/phi/infermeta/backward.h"
2629
#include "paddle/phi/infermeta/unary.h"
2730
#include "paddle/phi/kernels/reshape_grad_kernel.h"
2831
#include "paddle/phi/kernels/reshape_kernel.h"
29-
3032
namespace paddle {
3133
namespace framework {
3234
class InferShapeContext;
@@ -571,6 +573,25 @@ class Reshape2GradMaker : public framework::SingleGradOpMaker<T> {
571573
}
572574
};
573575

576+
class Reshape2CompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
577+
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
578+
579+
public:
580+
void Apply() override {
581+
// We prefer to use x.shape instead of using xshape, this is different from
582+
// PHI definition.
583+
paddle::experimental::Tensor x = this->GetSingleForwardInput("X");
584+
paddle::experimental::Tensor out_grad = this->GetSingleOutputGrad("Out");
585+
paddle::experimental::Tensor dx = this->GetSingleInputGrad("X");
586+
587+
auto *dx_ptr = this->GetOutputPtr(&dx);
588+
std::string dx_name = this->GetOutputName(dx);
589+
VLOG(6) << "Runing reshape2_grad composite func";
590+
prim::reshape_grad<prim::DescTensor>(x, out_grad, dx_ptr);
591+
this->RecoverOutputName(dx, dx_name);
592+
}
593+
};
594+
574595
template <typename T>
575596
class Reshape2DoubleGradMaker : public framework::SingleGradOpMaker<T> {
576597
public:
@@ -715,6 +736,7 @@ REGISTER_OPERATOR(reshape2,
715736
ops::Reshape2OpMaker,
716737
ops::Reshape2GradMaker<paddle::framework::OpDesc>,
717738
ops::Reshape2GradMaker<paddle::imperative::OpBase>,
739+
ops::Reshape2CompositeGradOpMaker,
718740
ops::ReshapeOpInplaceInferer);
719741
REGISTER_OPERATOR(reshape2_grad,
720742
ops::Reshape2GradOp,

paddle/fluid/operators/transpose_op.cc

+23-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ limitations under the License. */
2424
#include "paddle/fluid/platform/mkldnn_helper.h"
2525
#endif
2626
#include "paddle/fluid/framework/op_registry.h"
27+
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
28+
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
2729

2830
namespace paddle {
2931
namespace operators {
@@ -300,6 +302,25 @@ class Transpose2GradMaker : public framework::SingleGradOpMaker<T> {
300302
}
301303
};
302304

305+
class Transpose2CompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
306+
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
307+
308+
public:
309+
void Apply() override {
310+
paddle::experimental::Tensor xshape =
311+
this->GetSingleForwardOutput("XShape");
312+
paddle::experimental::Tensor out_grad = this->GetSingleOutputGrad("Out");
313+
paddle::experimental::Tensor dx = this->GetSingleInputGrad("X");
314+
auto *dx_ptr = this->GetOutputPtr(&dx);
315+
std::string dx_name = this->GetOutputName(dx);
316+
std::vector<int> axis =
317+
static_cast<std::vector<int>>(this->Attr<std::vector<int>>("axis"));
318+
VLOG(6) << "Runing transpose2_grad composite func";
319+
prim::transpose_grad<prim::DescTensor>(out_grad, axis, dx_ptr);
320+
this->RecoverOutputName(dx, dx_name);
321+
}
322+
};
323+
303324
template <typename T>
304325
class Transpose2DoubleGradMaker : public framework::SingleGradOpMaker<T> {
305326
public:
@@ -365,7 +386,8 @@ REGISTER_OPERATOR(transpose2,
365386
ops::Transpose2Op,
366387
ops::Transpose2OpMaker,
367388
ops::Transpose2GradMaker<paddle::framework::OpDesc>,
368-
ops::Transpose2GradMaker<paddle::imperative::OpBase>);
389+
ops::Transpose2GradMaker<paddle::imperative::OpBase>,
390+
ops::Transpose2CompositeGradOpMaker);
369391
REGISTER_OPERATOR(transpose2_grad,
370392
ops::Transpose2OpGrad,
371393
ops::TransposeGradInferVarType,

0 commit comments

Comments
 (0)