
Commit 48d0ac8

solve conflicts, merge auto code-gen
1 parent 07f55f0 commit 48d0ac8

File tree: 8 files changed, +277 −34 lines

paddle/fluid/prim/api/api.yaml

Lines changed: 4 additions & 4 deletions
@@ -1,13 +1,14 @@
+- add
+- subtract
+- multiply
+- divide
 - unsqueeze
 - pow
 - exp
 - scale
-- multiply
 - matmul
 - expand
-- divide
 - sum
-- add
 - abs
 - assign
 - concat
@@ -23,4 +24,3 @@
 - scatter
 - scatter_nd_add
 - tile
-- subtract

paddle/fluid/prim/api/auto_code_generated/tensor_operants_gen.py

Lines changed: 51 additions & 0 deletions
@@ -36,6 +36,7 @@
 
 using Tensor = paddle::experimental::Tensor;
 using TensorOperantsBase = paddle::operants::TensorOperantsBase;
+using Scalar = paddle::experimental::Scalar;
 
 class EagerTensorOperants : public TensorOperantsBase {
  private:
@@ -44,6 +45,14 @@ class EagerTensorOperants : public TensorOperantsBase {
  public:
   EagerTensorOperants() = default;
 
+  Tensor add(const Tensor& x, const Scalar& y);
+
+  Tensor subtract(const Tensor& x, const Scalar& y);
+
+  Tensor multiply(const Tensor& x, const Scalar& y);
+
+  Tensor divide(const Tensor& x, const Scalar& y);
+
 """
@@ -69,6 +78,22 @@ class EagerTensorOperants : public TensorOperantsBase {
 
 namespace prim {
 
+Tensor EagerTensorOperants::add(const Tensor& x, const Scalar& y) {
+  return ::add_ad_func(x, ::full_like_ad_func(x, y));
+}
+
+Tensor EagerTensorOperants::subtract(const Tensor& x, const Scalar& y) {
+  return ::subtract_ad_func(x, ::full_like_ad_func(x, y));
+}
+
+Tensor EagerTensorOperants::multiply(const Tensor& x, const Scalar& y) {
+  return ::multiply_ad_func(x, ::full_like_ad_func(x, y));
+}
+
+Tensor EagerTensorOperants::divide(const Tensor& x, const Scalar& y) {
+  return ::divide_ad_func(x, ::full_like_ad_func(x, y));
+}
+
 """
@@ -96,6 +121,7 @@ class EagerTensorOperants : public TensorOperantsBase {
 
 using Tensor = paddle::experimental::Tensor;
 using TensorOperantsBase = paddle::operants::TensorOperantsBase;
+using Scalar = paddle::experimental::Scalar;
 
 class StaticTensorOperants : public TensorOperantsBase {
  private:
@@ -104,6 +130,14 @@ class StaticTensorOperants : public TensorOperantsBase {
  public:
   StaticTensorOperants() = default;
 
+  Tensor add(const Tensor& x, const Scalar& y);
+
+  Tensor subtract(const Tensor& x, const Scalar& y);
+
+  Tensor multiply(const Tensor& x, const Scalar& y);
+
+  Tensor divide(const Tensor& x, const Scalar& y);
+
 """
@@ -120,6 +154,7 @@ class StaticTensorOperants : public TensorOperantsBase {
 #include "paddle/fluid/prim/utils/static/static_tensor_operants.h"
 
 #include "paddle/fluid/prim/api/generated_prim/prim_generated_api.h"
+#include "paddle/fluid/prim/api/manual_prim/prim_manual_api.h"
 #include "paddle/fluid/prim/utils/static/desc_tensor.h"
 
 """
@@ -131,6 +166,22 @@ class StaticTensorOperants : public TensorOperantsBase {
 namespace prim {
 using DescTensor = paddle::prim::DescTensor;
 
+Tensor StaticTensorOperants::add(const Tensor& x, const Scalar& y) {
+  return paddle::prim::add<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
+}
+
+Tensor StaticTensorOperants::subtract(const Tensor& x, const Scalar& y) {
+  return paddle::prim::subtract<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
+}
+
+Tensor StaticTensorOperants::multiply(const Tensor& x, const Scalar& y) {
+  return paddle::prim::multiply<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
+}
+
+Tensor StaticTensorOperants::divide(const Tensor& x, const Scalar& y) {
+  return paddle::prim::divide<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
+}
+
 """
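Both backends lower the new scalar signatures onto the existing tensor-tensor primitives by materializing the scalar as a tensor: eager mode via ::full_like_ad_func, static mode via paddle::prim::full<DescTensor>. A minimal standalone sketch of that lowering pattern, with a toy Tensor type rather than the real Paddle API:

// Standalone sketch of the scalar -> full_like lowering used by both backends
// above (toy Tensor type; the real calls are ::full_like_ad_func in eager mode
// and paddle::prim::full<DescTensor> in static mode).
#include <cstddef>
#include <vector>

struct Tensor {
  std::vector<float> data;
};

// Materialize the scalar as a tensor shaped like x ...
Tensor full_like(const Tensor& x, float y) {
  return Tensor{std::vector<float>(x.data.size(), y)};
}

// ... so the existing tensor-tensor primitive is reused unchanged.
Tensor add(const Tensor& x, const Tensor& y) {
  Tensor out = x;
  for (std::size_t i = 0; i < out.data.size(); ++i) out.data[i] += y.data[i];
  return out;
}

// The scalar overload mirrors the shape of EagerTensorOperants::add above.
Tensor add(const Tensor& x, float y) { return add(x, full_like(x, y)); }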

paddle/fluid/prim/api/composite_backward/composite_backward_api.h

Lines changed: 4 additions & 7 deletions
@@ -28,8 +28,7 @@ template <typename T>
 void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) {
   if (!grad_x) return;
   auto tmp = pow<T>(out, 2.0);
-  tmp = scale<T>(tmp, -1.0, 1.0, true);
-  auto grad_x_tmp = grad_out * tmp;
+  auto grad_x_tmp = grad_out * (tmp * -1.0 + 1.0);
   set_output<T>(grad_x_tmp, grad_x);
 }
@@ -170,9 +169,8 @@ void divide_grad(const Tensor& x,
   if (dy) {
     // dy = -(x/y^2) * dout
     auto tmp0 = pow<T>(y, 2.0);
-    auto tmp1 = x / tmp0;
-    auto tmp2 = scale<T>(tmp1, -1.0, 0.0, true);
-    auto dy_res = tmp2 * out_grad;
+    auto tmp1 = x / tmp0 * -1.0;
+    auto dy_res = tmp1 * out_grad;
     if (x.dims() != y.dims()) {
       // Maybe need reduce here
       phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());
@@ -213,8 +211,7 @@ void divide_grad(const Tensor& x,
 template <typename T>
 void sqrt_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
   if (x_grad) {
-    auto div_x = full<T>(phi::vectorize(out.dims()), 0.5);
-    auto x_grad_tmp = out_grad * div_x / out;
+    auto x_grad_tmp = out_grad / 2.0 / out;
     set_output<T>(x_grad_tmp, x_grad);
   }
 }
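All three rewrites drop the scale/full helper tensors in favor of the new scalar operants; the gradient identities they implement are unchanged:

% tanh: out = tanh(x), so d(out)/dx = 1 - out^2
dx = dout \cdot (1 - out^2)

% divide: out = x / y, so d(out)/dy = -x / y^2
dy = -\frac{x}{y^2} \cdot dout

% sqrt: out = \sqrt{x}, so d(out)/dx = \frac{1}{2\sqrt{x}} = \frac{1}{2 \cdot out}
dx = \frac{dout}{2 \cdot out}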

paddle/fluid/prim/utils/static/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -6,4 +6,4 @@ cc_library(
 cc_library(
   static_tensor_operants
   SRCS static_tensor_operants.cc
-  DEPS static_prim_api)
+  DEPS static_prim_api manual_static_prim_api)

The added manual_static_prim_api dependency matches the prim_manual_api.h include introduced into the generated static_tensor_operants.cc above, which appears to supply the full<DescTensor> used by the new scalar overloads.

paddle/phi/api/include/tensor.h

Lines changed: 14 additions & 0 deletions
@@ -47,6 +47,12 @@ namespace paddle {
 
 namespace experimental {
 
+class Tensor;
+
+template <typename T>
+class ScalarBase;
+using Scalar = paddle::experimental::ScalarBase<Tensor>;
+
 class AbstractAutogradMeta {
  public:
   // No AbstractAutogradMeta should be created
@@ -538,6 +544,14 @@ class PADDLE_API Tensor final {
 
   Tensor operator/(const Tensor& other) const;
 
+  Tensor operator+(const Scalar& other) const;
+
+  Tensor operator-(const Scalar& other) const;
+
+  Tensor operator*(const Scalar& other) const;
+
+  Tensor operator/(const Scalar& other) const;
+
   /* Part 8: Autograd methods */
 
   /**
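Scalar here is an alias for ScalarBase<Tensor>, so tensor.h has to forward-declare both names before Tensor's definition can mention Scalar. A self-contained sketch of that forward-declaration pattern with toy types (not the real Paddle classes):

// Toy types only; the real Scalar is paddle::experimental::ScalarBase<Tensor>.
class Tensor;          // forward declaration of Tensor ...
template <typename T>
class ScalarBase;      // ... and of the ScalarBase template
using Scalar = ScalarBase<Tensor>;  // an alias is fine with incomplete types

class Tensor {
 public:
  // Declaring a parameter of an incomplete type by const& is legal.
  Tensor operator+(const Scalar& other) const;
};

template <typename T>
class ScalarBase {
 public:
  explicit ScalarBase(double v) : v_(v) {}
  double to_double() const { return v_; }

 private:
  double v_;
};

// By definition time, ScalarBase<Tensor> is complete and usable.
Tensor Tensor::operator+(const Scalar& other) const {
  (void)other.to_double();  // a real implementation would broadcast `other`
  return *this;
}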

paddle/phi/api/yaml/generator/tensor_gen.py

Lines changed: 52 additions & 2 deletions
@@ -74,12 +74,23 @@ class TensorOperantsBase {
 
 namespace operants {
 
+using Scalar = paddle::experimental::Scalar;
+
 class PhiTensorOperants : public TensorOperantsBase {
  private:
   DISABLE_COPY_AND_ASSIGN(PhiTensorOperants);
 
  public:
   PhiTensorOperants() = default;
+
+  Tensor add(const Tensor& x, const Scalar& y);
+
+  Tensor subtract(const Tensor& x, const Scalar& y);
+
+  Tensor multiply(const Tensor& x, const Scalar& y);
+
+  Tensor divide(const Tensor& x, const Scalar& y);
+
 """
@@ -104,6 +115,24 @@ class PhiTensorOperants : public TensorOperantsBase {
 namespace paddle {
 
 namespace operants {
+
+Tensor PhiTensorOperants::add(const Tensor& x, const Scalar& y) {
+  return paddle::experimental::add(x, paddle::experimental::full_like(x, y));
+}
+
+Tensor PhiTensorOperants::subtract(const Tensor& x, const Scalar& y) {
+  return paddle::experimental::subtract(x,
+                                        paddle::experimental::full_like(x, y));
+}
+
+Tensor PhiTensorOperants::multiply(const Tensor& x, const Scalar& y) {
+  return paddle::experimental::multiply(x,
+                                        paddle::experimental::full_like(x, y));
+}
+
+Tensor PhiTensorOperants::divide(const Tensor& x, const Scalar& y) {
+  return paddle::experimental::divide(x, paddle::experimental::full_like(x, y));
+}
 """
@@ -129,6 +158,7 @@ class PhiTensorOperants : public TensorOperantsBase {
 
 using Tensor = paddle::experimental::Tensor;
 using TensorOperantsBase = paddle::operants::TensorOperantsBase;
+using Scalar = paddle::experimental::Scalar;
 
 /**
  * [ Why need OperantsManager? ]
@@ -175,6 +205,15 @@ class OperantsManager {
 
  public:
   static OperantsManager& Instance();
+
+  Tensor add(const Tensor& x, const Scalar& y);
+
+  Tensor subtract(const Tensor& x, const Scalar& y);
+
+  Tensor multiply(const Tensor& x, const Scalar& y);
+
+  Tensor divide(const Tensor& x, const Scalar& y);
+
 """
@@ -302,17 +341,28 @@ def gene_operants_manager_code(self):
 
     def gene_operants_manager_implementation(self):
         func_name = self.get_api_func_name()
+        final_code = ""
+        if func_name in ["add", "subtract", "multiply", "divide"]:
+            final_code += f"""
+{self.get_return_type()} OperantsManager::{func_name}(const Tensor& x, const Scalar& y) {{{self.gene_operants_manager_code()}}}
+"""
         # func declaration
         if func_name[-1] != '_':
-            return f"""
+            return (
+                final_code
+                + f"""
 {self.get_return_type()} OperantsManager::{func_name}({self.get_define_args()}) {{{self.gene_operants_manager_code()}}}
 """
+            )
         else:
-            return f"""
+            return (
+                final_code
+                + f"""
 {self.get_return_type(inplace_flag=True)} OperantsManager::{func_name}({self.get_define_args(inplace_flag=True)}) {{
 {self.gene_operants_manager_code()}
 }}
 """
+            )
 
     def generate_tensor_operants_api(
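For add, subtract, multiply, and divide, gene_operants_manager_implementation now emits an extra Scalar overload ahead of the regular signature, so OperantsManager ends up with an overload pair per op. A sketch with stand-in types is below; the real method bodies are produced by gene_operants_manager_code() and are assumed here to dispatch to the eager, static, or phi operants depending on run mode:

// Stand-in types; the real bodies come from gene_operants_manager_code().
struct Scalar { double v; };
struct Tensor { double v; };

struct OperantsManager {
  static OperantsManager& Instance() {
    static OperantsManager m;
    return m;
  }
  // New: Scalar overload, emitted only for add / subtract / multiply / divide.
  Tensor add(const Tensor& x, const Scalar& y) { return Tensor{x.v + y.v}; }
  // Existing: Tensor-Tensor signature, emitted for every API entry.
  Tensor add(const Tensor& x, const Tensor& y) { return Tensor{x.v + y.v}; }
};

Combined with the Tensor operator overloads added in tensor.h above, an expression like out_grad / 2.0 in the simplified backward rules presumably resolves through OperantsManager's (const Tensor&, const Scalar&) overload to whichever backend is active.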
