Skip to content

Commit 57e2521

Browse files
committed
BufferArg add ArgType and Function remove inouts
1 parent d35ef9d commit 57e2521

File tree

4 files changed

+59
-1886
lines changed

4 files changed

+59
-1886
lines changed

paddle/function/BufferArg.h

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,16 +38,40 @@ enum SparseDataType {
3838

3939
enum SparseDataFormat { SPARSE_CSR_FORMAT = 0, SPARSE_CSC_FORMAT = 1 };
4040

41-
/**
42-
* BufferArg used as the argument type for Function.
43-
*/
4441
class BufferArg;
4542
class SequenceArg;
4643
class SparseMatrixArg;
4744
typedef std::shared_ptr<BufferArg> BufferArgPtr;
4845

49-
// an array of arbitrary dimensions
46+
/**
47+
* \brief BufferArg used as the argument type of Function.
48+
*
49+
* The arguments of the Paddle Function have four Buffer types.
50+
* 1. BufferArg for a dense Buffer of any dimension.
51+
* 2. SequenceIdArg for a Buffer of sequence start positions.
52+
* 3. SequenceArg for a Buffer of sequence data.
53+
* 4. SparseMatrixArg for a Buffer of sparse matrix.
54+
*
55+
* There is an ArgType property for the BufferArg used as Function Output.
56+
* Whether the result of the Function calculation is assigned to the
57+
* output Buffer or added to the output Buffer is determined by the
58+
* argType_ property of the output BufferArg.
59+
*/
5060
class BufferArg {
61+
public:
62+
// ArgType is only used by output BufferArg.
63+
// For input argument, argType_ is ignored.
64+
// For output argument, need to set the argType_ of the BufferArg.
65+
enum ArgType {
66+
UNSPECIFIED = 0,
67+
ASSIGN_TO = 1,
68+
ADD_TO = 2,
69+
};
70+
71+
void setArgType(ArgType argType) { argType_ = argType; }
72+
73+
ArgType getArgType() const { return argType_; }
74+
5175
public:
5276
BufferArg(void* buf, ValueType valueType, const TensorShape& shape)
5377
: buf_(buf), valueType_(valueType), shape_(shape) {}
@@ -56,29 +80,33 @@ class BufferArg {
5680
: buf_(buf), valueType_(valueType) {}
5781

5882
BufferArg(const Matrix& matrix)
59-
: buf_(reinterpret_cast<void*>(matrix.getData())),
83+
: buf_(
84+
const_cast<void*>(reinterpret_cast<const void*>(matrix.getData()))),
6085
valueType_(DataType<real>::value),
6186
shape_(2) {
6287
shape_.setDim(0, matrix.getHeight());
6388
shape_.setDim(1, matrix.getWidth());
6489
}
6590

6691
BufferArg(const Matrix& matrix, const TensorShape& shape)
67-
: buf_(reinterpret_cast<void*>(matrix.getData())),
92+
: buf_(
93+
const_cast<void*>(reinterpret_cast<const void*>(matrix.getData()))),
6894
valueType_(DataType<real>::value),
6995
shape_(shape) {
7096
CHECK_EQ(matrix.getElementCnt(), shape.getElements());
7197
}
7298

7399
BufferArg(const Vector& vector)
74-
: buf_(reinterpret_cast<void*>(vector.getData())),
100+
: buf_(
101+
const_cast<void*>(reinterpret_cast<const void*>(vector.getData()))),
75102
valueType_(DataType<real>::value),
76103
shape_(1) {
77104
shape_.setDim(0, vector.getSize());
78105
}
79106

80107
BufferArg(const IVector& vector)
81-
: buf_(reinterpret_cast<void*>(vector.getData())),
108+
: buf_(
109+
const_cast<void*>(reinterpret_cast<const void*>(vector.getData()))),
82110
valueType_(VALUE_TYPE_INT32),
83111
shape_(1) {
84112
shape_.setDim(0, vector.getSize());
@@ -124,6 +152,7 @@ class BufferArg {
124152
ValueType valueType_;
125153
TensorShape shape_;
126154
BufferType bufferType_;
155+
ArgType argType_ = UNSPECIFIED;
127156
// leading dimensions. The size is dims_.size()
128157
// Dims lds_;
129158
};

paddle/function/Function.h

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,18 @@ class BufferArgs {
5656
BufferArgs() {}
5757
size_t size() const { return args_.size(); }
5858

59-
// add argument into BufferArgss
59+
// add argument into BufferArgs
60+
// Tensor can be Matrix, Vector, IVector.
6061
template <typename Tensor>
6162
void addArg(const Tensor& arg) {
6263
args_.push_back(std::make_shared<BufferArg>(arg));
6364
}
6465

66+
// Add arg into BufferArgs and reshape the arg.
67+
//
68+
// For example, arg represents an image buffer,
69+
// but Matrix can only represent a two-dimensional Tensor.
70+
// So need an extra argument to describe the shape of the image buffer.
6571
void addArg(const Matrix& arg, const TensorShape& shape);
6672

6773
void addArg(const CpuSparseMatrix& arg);
@@ -78,20 +84,28 @@ class BufferArgs {
7884
};
7985

8086
/**
81-
* Base class for Function.
87+
* \brief Base class for Function.
8288
* The basic Function implementation requires override init and calc interfaces.
83-
* Need to pay attention to the inouts argument. For the input argument
84-
* that will be modified, it needs to be passed through inouts.
89+
*
90+
* Function inputs are readonly, Function outputs have two modes: ASSIGN_TO
91+
* and ADD_TO.
92+
* If output.getArgType() == ASSIGN_TO, this is assign mode, and the calculation
93+
* result of Function assigned to the output BufferArg.
94+
* If output.getArgType() == ADD_TO, this is add mode, and the calculation
95+
* result of Function need added to the output BufferArg.
96+
*
97+
* For example:
98+
* ASSIGN_TO: output = Function(inputs)
99+
* ADD_TO: output += Function(inputs)
100+
* If Function has more than one output, each output can have different modes.
85101
*/
86102
class FunctionBase {
87103
public:
88104
virtual ~FunctionBase() {}
89105

90106
virtual void init(const FuncConfig& config) {}
91107

92-
virtual void calc(const BufferArgs& inputs,
93-
const BufferArgs& outputs,
94-
const BufferArgs& inouts) {}
108+
virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {}
95109

96110
static ClassRegistrar<FunctionBase> funcRegistrar_;
97111
};

paddle/function/FunctionTest.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ void FunctionApi<DEVICE_TYPE_GPU>(GpuMatrix& output, const GpuMatrix& input) {
3535

3636
template <DeviceType DType>
3737
void Function(const BufferArgs& arguments) {
38-
auto input = arguments[0].matrix<DType>();
38+
const auto input = arguments[0].matrix<DType>();
3939
auto output = arguments[1].matrix<DType>();
4040
FunctionApi<DType>(output, input);
4141
}

0 commit comments

Comments
 (0)