Skip to content

Commit d6c8ca8

Browse files
authored
support auto generate for activation_op swish (#53983)
1 parent 4f56e7c commit d6c8ca8

File tree

5 files changed

+25
-28
lines changed

5 files changed

+25
-28
lines changed

paddle/fluid/operators/activation_op.cc

-16
Original file line numberDiff line numberDiff line change
@@ -176,21 +176,6 @@ SoftRelu Activation Operator.
176176
}
177177
};
178178

179-
class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
180-
public:
181-
void Make() override {
182-
AddInput("X", "Input of Swish operator");
183-
AddOutput("Out", "Output of Swish operator");
184-
AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
185-
AddComment(R"DOC(
186-
Swish Activation Operator.
187-
188-
$$out = \\frac{x}{1 + e^{- \beta \ x}}$$
189-
190-
)DOC");
191-
}
192-
};
193-
194179
class MishOpMaker : public framework::OpProtoAndCheckerMaker {
195180
public:
196181
void Make() override {
@@ -406,7 +391,6 @@ FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP);
406391
REGISTER_ACTIVATION_CPU_KERNEL(soft_relu, SoftRelu)
407392

408393
REGISTER_ACTIVATION_OP(mish, Mish, MishFunctor, MishGradFunctor);
409-
REGISTER_ACTIVATION_OP(swish, Swish, SwishFunctor, SwishGradFunctor);
410394

411395
/* ========================== register checkpoint ===========================*/
412396
REGISTER_OP_VERSION(leaky_relu)

paddle/phi/api/yaml/op_compat.yaml

+4
Original file line numberDiff line numberDiff line change
@@ -2383,6 +2383,10 @@
23832383

23842384
- op : swish
23852385
backward : swish_grad
2386+
inputs :
2387+
x : X
2388+
outputs :
2389+
out : Out
23862390
extra :
23872391
attrs : [bool use_mkldnn = false]
23882392

paddle/phi/api/yaml/static_backward.yaml

+11
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,17 @@
121121
data_type : out_grad
122122
no_need_buffer : x
123123

124+
- backward_op : swish_grad
125+
forward : swish (Tensor x, float beta = 1.0f) -> Tensor(out)
126+
args : (Tensor x, Tensor out_grad)
127+
output : Tensor(x_grad)
128+
infer_meta :
129+
func : GeneralUnaryGradInferMeta
130+
param : [x]
131+
kernel :
132+
func : swish_grad
133+
inplace : (out_grad -> x_grad)
134+
124135
- backward_op : tril_triu_grad
125136
forward : tril_triu (Tensor x, int diagonal = 0, bool lower = false) -> Tensor(out)
126137
args : (Tensor out_grad, int diagonal, bool lower)

paddle/phi/api/yaml/static_ops.yaml

+10
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,16 @@
398398
param : [x, axes, starts, ends, strides]
399399
backward : strided_slice_grad
400400

401+
- op : swish
402+
args : (Tensor x, float beta = 1.0f)
403+
output : Tensor(out)
404+
infer_meta :
405+
func : UnchangedInferMeta
406+
param : [x]
407+
kernel :
408+
func : swish_raw
409+
backward : swish_grad
410+
401411
- op : tril_indices
402412
args : (int rows = 0, int cols = 0, int offset = 0, DataType dtype = DataType::INT64)
403413
output : Tensor(out)

paddle/phi/ops/compat/activation_sig.cc

-12
Original file line numberDiff line numberDiff line change
@@ -42,26 +42,14 @@ namespace phi {
4242
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(HardTanh, "hardtanh", "t_min" comma "t_max");
4343
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Mish, "mish", "threshold");
4444

45-
KernelSignature SwishGradOpArgumentMapping(
46-
const ArgumentMappingContext& ctx UNUSED) {
47-
return KernelSignature("swish_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"});
48-
}
49-
5045
KernelSignature HardSwishOpArgumentMapping(
5146
const ArgumentMappingContext& ctx UNUSED) {
5247
return KernelSignature("hardswish", {"X"}, {}, {"Out"});
5348
}
5449

55-
KernelSignature SwishOpArgumentMapping(
56-
const ArgumentMappingContext& ctx UNUSED) {
57-
return KernelSignature("swish_raw", {"X"}, {"beta"}, {"Out"});
58-
}
59-
6050
} // namespace phi
6151

6252
PD_REGISTER_BASE_KERNEL_NAME(hard_swish, hardswish);
6353
PD_REGISTER_ARG_MAPPING_FN(mish_grad, phi::MishGradOpArgumentMapping);
6454

6555
PD_REGISTER_ARG_MAPPING_FN(hard_swish, phi::HardSwishOpArgumentMapping);
66-
PD_REGISTER_ARG_MAPPING_FN(swish_grad, phi::SwishGradOpArgumentMapping);
67-
PD_REGISTER_ARG_MAPPING_FN(swish, phi::SwishOpArgumentMapping);

0 commit comments

Comments
 (0)