Skip to content

Commit 2e6e188

Browse files
authored
[Tensor Operants & Prim-Relevant] Tensor API support default value (#50928)
1 parent 539293e commit 2e6e188

File tree

6 files changed

+27
-22
lines changed

6 files changed

+27
-22
lines changed

paddle/fluid/prim/utils/static/composite_grad_desc_maker.h

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "paddle/fluid/framework/type_defs.h"
2929
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
3030
#include "paddle/fluid/prim/utils/static/static_global_utils.h"
31+
#include "paddle/phi/api/include/tensor.h"
3132
#include "paddle/phi/core/enforce.h"
3233
#include "paddle/phi/core/flags.h"
3334

paddle/phi/api/include/tensor.h

+16-17
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@ using gpuStream_t = hipStream_t;
3131

3232
#include "paddle/phi/api/include/dll_decl.h"
3333
#include "paddle/phi/common/data_type.h"
34+
#include "paddle/phi/common/int_array.h"
3435
#include "paddle/phi/common/layout.h"
3536
#include "paddle/phi/common/place.h"
37+
#include "paddle/phi/common/scalar.h"
3638

3739
namespace phi {
3840
class DenseTensor;
@@ -47,16 +49,6 @@ namespace paddle {
4749

4850
namespace experimental {
4951

50-
class Tensor;
51-
52-
template <typename T>
53-
class ScalarBase;
54-
using Scalar = paddle::experimental::ScalarBase<Tensor>;
55-
56-
template <typename T>
57-
class IntArrayBase;
58-
using IntArray = paddle::experimental::IntArrayBase<Tensor>;
59-
6052
class AbstractAutogradMeta {
6153
public:
6254
// No AbstractAutogradMeta should be created
@@ -684,22 +676,29 @@ class PADDLE_API Tensor final {
684676
Tensor floor() const;
685677
Tensor gather_nd(const Tensor& index) const;
686678
Tensor log() const;
687-
Tensor roll(const IntArray& shifts, const std::vector<int64_t>& axis) const;
679+
Tensor roll(const IntArray& shifts = {},
680+
const std::vector<int64_t>& axis = {}) const;
688681
Tensor scatter(const Tensor& index,
689682
const Tensor& updates,
690-
bool overwrite) const;
683+
bool overwrite = true) const;
691684
Tensor scatter_nd_add(const Tensor& index, const Tensor& updates) const;
692685
Tensor abs() const;
693686
Tensor assign() const;
694687
Tensor elementwise_pow(const Tensor& y) const;
695688
Tensor expand(const IntArray& shape) const;
696-
Tensor matmul(const Tensor& y, bool transpose_x, bool transpose_y) const;
697-
Tensor max(const IntArray& axis, bool keepdim) const;
689+
Tensor matmul(const Tensor& y,
690+
bool transpose_x = false,
691+
bool transpose_y = false) const;
692+
Tensor max(const IntArray& axis = {}, bool keepdim = false) const;
698693
Tensor maximum(const Tensor& y) const;
699694
Tensor minimum(const Tensor& y) const;
700-
Tensor scale(const Scalar& scale, float bias, bool bias_after_scale) const;
701-
Tensor sum(const IntArray& axis, DataType dtype, bool keepdim) const;
702-
Tensor tile(const IntArray& repeat_times) const;
695+
Tensor scale(const Scalar& scale = 1.0,
696+
float bias = 0.0,
697+
bool bias_after_scale = true) const;
698+
Tensor sum(const IntArray& axis = {},
699+
DataType dtype = DataType::UNDEFINED,
700+
bool keepdim = false) const;
701+
Tensor tile(const IntArray& repeat_times = {}) const;
703702
};
704703

705704
PADDLE_API Tensor operator+(const Scalar& x, const Tensor& y);

paddle/phi/api/lib/int_array.cc

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ limitations under the License. */
1414

1515
#include "paddle/phi/common/int_array.h"
1616

17+
#include "paddle/phi/api/include/tensor.h"
1718
#include "paddle/phi/api/lib/tensor_copy.h"
1819
#include "paddle/phi/common/place.h"
1920

paddle/phi/api/lib/scalar.cc

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ limitations under the License. */
1414

1515
#include "paddle/phi/common/scalar.h"
1616

17+
#include "paddle/phi/api/include/tensor.h"
1718
#include "paddle/phi/api/lib/tensor_copy.h"
1819
#include "paddle/phi/common/place.h"
1920
#include "paddle/phi/core/enforce.h"

paddle/phi/common/int_array.h

+5-3
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@ limitations under the License. */
1414

1515
#pragma once
1616

17+
#include <vector>
18+
1719
#include "paddle/phi/api/ext/exception.h"
18-
#include "paddle/phi/api/include/tensor.h"
20+
#include "paddle/phi/common/data_type.h"
1921

2022
namespace paddle {
2123
namespace experimental {
@@ -103,8 +105,8 @@ class IntArrayBase {
103105
bool is_from_tensor_{false};
104106
};
105107

106-
using IntArray =
107-
paddle::experimental::IntArrayBase<paddle::experimental::Tensor>;
108+
class Tensor;
109+
using IntArray = paddle::experimental::IntArrayBase<Tensor>;
108110

109111
} // namespace experimental
110112
} // namespace paddle

paddle/phi/common/scalar.h

+3-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ limitations under the License. */
1818
#include <limits>
1919

2020
#include "paddle/phi/api/ext/exception.h"
21-
#include "paddle/phi/api/include/tensor.h"
21+
#include "paddle/phi/common/data_type.h"
2222

2323
namespace paddle {
2424
namespace experimental {
@@ -228,7 +228,8 @@ void CopyScalar(const ScalarBase<T1>& src, ScalarBase<T2>* dst) {
228228
dst->data_.c128 = src.data_.c128;
229229
}
230230

231-
using Scalar = paddle::experimental::ScalarBase<paddle::experimental::Tensor>;
231+
class Tensor;
232+
using Scalar = paddle::experimental::ScalarBase<Tensor>;
232233

233234
} // namespace experimental
234235
} // namespace paddle

0 commit comments

Comments
 (0)