@@ -38,16 +38,40 @@ enum SparseDataType {
38
38
39
39
enum SparseDataFormat { SPARSE_CSR_FORMAT = 0 , SPARSE_CSC_FORMAT = 1 };
40
40
41
- /* *
42
- * BufferArg used as the argument type for Function.
43
- */
44
41
class BufferArg ;
45
42
class SequenceArg ;
46
43
class SparseMatrixArg ;
47
44
typedef std::shared_ptr<BufferArg> BufferArgPtr;
48
45
49
- // an array of arbitrary dimensions
46
+ /* *
47
+ * \brief BufferArg used as the argument type of Function.
48
+ *
49
+ * The arguments of the Paddle Function have four Buffer types.
50
+ * 1. BufferArg for a dense Buffer of any dimension.
51
+ * 2. SequenceIdArg for a Buffer of sequence start positions.
52
+ * 3. SequenceArg for a Buffer of sequence data.
53
+ * 4. SparseMatrixArg for a Buffer of sparse matrix.
54
+ *
55
+ * There is an ArgType property for the BufferArg used as Function Output.
56
+ * Whether the result of the Function calculation is assigned to the
57
+ * output Buffer or added to the output Buffer is determined by the
58
+ * argType_ property of the output BufferArg.
59
+ */
50
60
class BufferArg {
61
+ public:
62
+ // ArgType is only used by output BufferArg.
63
+ // For input argument, argType_ is ignored.
64
+ // For output argument, need to set the argType_ of the BufferArg.
65
+ enum ArgType {
66
+ UNSPECIFIED = 0 ,
67
+ ASSIGN_TO = 1 ,
68
+ ADD_TO = 2 ,
69
+ };
70
+
71
+ void setArgType (ArgType argType) { argType_ = argType; }
72
+
73
+ ArgType getArgType () const { return argType_; }
74
+
51
75
public:
52
76
BufferArg (void * buf, ValueType valueType, const TensorShape& shape)
53
77
: buf_(buf), valueType_(valueType), shape_(shape) {}
@@ -56,29 +80,33 @@ class BufferArg {
56
80
: buf_(buf), valueType_(valueType) {}
57
81
58
82
BufferArg (const Matrix& matrix)
59
- : buf_(reinterpret_cast <void *>(matrix.getData())),
83
+ : buf_(
84
+ const_cast <void *>(reinterpret_cast <const void *>(matrix.getData()))),
60
85
valueType_ (DataType<real>::value),
61
86
shape_(2 ) {
62
87
shape_.setDim (0 , matrix.getHeight ());
63
88
shape_.setDim (1 , matrix.getWidth ());
64
89
}
65
90
66
91
BufferArg (const Matrix& matrix, const TensorShape& shape)
67
- : buf_(reinterpret_cast <void *>(matrix.getData())),
92
+ : buf_(
93
+ const_cast <void *>(reinterpret_cast <const void *>(matrix.getData()))),
68
94
valueType_(DataType<real>::value),
69
95
shape_(shape) {
70
96
CHECK_EQ (matrix.getElementCnt (), shape.getElements ());
71
97
}
72
98
73
99
BufferArg (const Vector& vector)
74
- : buf_(reinterpret_cast <void *>(vector.getData())),
100
+ : buf_(
101
+ const_cast <void *>(reinterpret_cast <const void *>(vector.getData()))),
75
102
valueType_(DataType<real>::value),
76
103
shape_(1 ) {
77
104
shape_.setDim (0 , vector.getSize ());
78
105
}
79
106
80
107
BufferArg (const IVector& vector)
81
- : buf_(reinterpret_cast <void *>(vector.getData())),
108
+ : buf_(
109
+ const_cast <void *>(reinterpret_cast <const void *>(vector.getData()))),
82
110
valueType_(VALUE_TYPE_INT32),
83
111
shape_(1 ) {
84
112
shape_.setDim (0 , vector.getSize ());
@@ -124,6 +152,7 @@ class BufferArg {
124
152
ValueType valueType_;
125
153
TensorShape shape_;
126
154
BufferType bufferType_;
155
+ ArgType argType_ = UNSPECIFIED;
127
156
// leading dimensions. The size is dims_.size()
128
157
// Dims lds_;
129
158
};
0 commit comments