
Commit df9be2d

committed
fix CrossMapNormalFunc and ContextProjectionFunc (remove inouts argument)
1 parent 57e2521 commit df9be2d


7 files changed: +98 -85 lines changed


paddle/function/BufferArg.h

Lines changed: 47 additions & 31 deletions
@@ -57,58 +57,67 @@ typedef std::shared_ptr<BufferArg> BufferArgPtr;
  * output Buffer or added to the output Buffer is determined by the
  * argType_ property of the output BufferArg.
  */
+
+// ArgType is only used by output BufferArg.
+// For input argument, argType_ is ignored.
+// For output argument, need to set the argType_ of the BufferArg.
+enum ArgType {
+  UNSPECIFIED = 0,
+  ASSIGN_TO = 1,
+  ADD_TO = 2,
+};
 class BufferArg {
 public:
-  // ArgType is only used by output BufferArg.
-  // For input argument, argType_ is ignored.
-  // For output argument, need to set the argType_ of the BufferArg.
-  enum ArgType {
-    UNSPECIFIED = 0,
-    ASSIGN_TO = 1,
-    ADD_TO = 2,
-  };
-
   void setArgType(ArgType argType) { argType_ = argType; }

   ArgType getArgType() const { return argType_; }

 public:
-  BufferArg(void* buf, ValueType valueType, const TensorShape& shape)
-      : buf_(buf), valueType_(valueType), shape_(shape) {}
+  BufferArg(void* buf,
+            ValueType valueType,
+            const TensorShape& shape,
+            ArgType argType = UNSPECIFIED)
+      : buf_(buf), valueType_(valueType), shape_(shape), argType_(argType) {}

   BufferArg(void* buf, ValueType valueType)
       : buf_(buf), valueType_(valueType) {}

-  BufferArg(const Matrix& matrix)
+  BufferArg(const Matrix& matrix, ArgType argType = UNSPECIFIED)
       : buf_(
            const_cast<void*>(reinterpret_cast<const void*>(matrix.getData()))),
        valueType_(DataType<real>::value),
-        shape_(2) {
+        shape_(2),
+        argType_(argType) {
     shape_.setDim(0, matrix.getHeight());
     shape_.setDim(1, matrix.getWidth());
   }

-  BufferArg(const Matrix& matrix, const TensorShape& shape)
+  BufferArg(const Matrix& matrix,
+            const TensorShape& shape,
+            ArgType argType = UNSPECIFIED)
       : buf_(
            const_cast<void*>(reinterpret_cast<const void*>(matrix.getData()))),
        valueType_(DataType<real>::value),
-        shape_(shape) {
+        shape_(shape),
+        argType_(argType) {
     CHECK_EQ(matrix.getElementCnt(), shape.getElements());
   }

-  BufferArg(const Vector& vector)
+  BufferArg(const Vector& vector, ArgType argType = UNSPECIFIED)
       : buf_(
            const_cast<void*>(reinterpret_cast<const void*>(vector.getData()))),
        valueType_(DataType<real>::value),
-        shape_(1) {
+        shape_(1),
+        argType_(argType) {
     shape_.setDim(0, vector.getSize());
   }

-  BufferArg(const IVector& vector)
+  BufferArg(const IVector& vector, ArgType argType = UNSPECIFIED)
       : buf_(
            const_cast<void*>(reinterpret_cast<const void*>(vector.getData()))),
        valueType_(VALUE_TYPE_INT32),
-        shape_(1) {
+        shape_(1),
+        argType_(argType) {
     shape_.setDim(0, vector.getSize());
   }

@@ -163,8 +172,10 @@ class BufferArg {
 // if a < b then value_.buf_[a] < value_.buf_[b]
 class SequenceIdArg : public BufferArg {
 public:
-  SequenceIdArg(void* buf, const TensorShape& shape)
-      : BufferArg(buf, VALUE_TYPE_INT32, shape) {
+  SequenceIdArg(void* buf,
+                const TensorShape& shape,
+                ArgType argType = UNSPECIFIED)
+      : BufferArg(buf, VALUE_TYPE_INT32, shape, argType) {
     CHECK_EQ(shape_.ndims(), 1);
     numSeqs_ = shape_[0] - 1;
   }
@@ -187,11 +198,15 @@ class SequenceArg : public BufferArg {
   SequenceArg(void* buf,
               ValueType valueType,
               const TensorShape& shape,
-              const SequenceIdArg& startPositions)
-      : BufferArg(buf, valueType, shape), startPositions_(startPositions) {}
+              const SequenceIdArg& startPositions,
+              ArgType argType = UNSPECIFIED)
+      : BufferArg(buf, valueType, shape, argType),
+        startPositions_(startPositions) {}

-  SequenceArg(const Matrix& matrix, const IVector& vector)
-      : BufferArg(matrix), startPositions_(vector) {}
+  SequenceArg(const Matrix& matrix,
+              const IVector& vector,
+              ArgType argType = UNSPECIFIED)
+      : BufferArg(matrix, argType), startPositions_(vector) {}

   ~SequenceArg() {}

@@ -214,8 +229,9 @@ class SparseMatrixArg : public BufferArg {
                   const BufferArg& col,
                   size_t nnz,
                   SparseDataFormat format,
-                  SparseDataType type)
-      : BufferArg(buf, valueType, shape),
+                  SparseDataType type,
+                  ArgType argType = UNSPECIFIED)
+      : BufferArg(buf, valueType, shape, argType),
         row_(row),
         col_(col),
         nnz_(nnz),
@@ -232,13 +248,13 @@ class SparseMatrixArg : public BufferArg {
     }
   }

-  SparseMatrixArg(const CpuSparseMatrix& sparse)
-      : BufferArg(sparse),
+  SparseMatrixArg(const CpuSparseMatrix& sparse, ArgType argType = UNSPECIFIED)
+      : BufferArg(sparse, argType),
         row_(reinterpret_cast<void*>(sparse.getRows()), VALUE_TYPE_INT32),
         col_(reinterpret_cast<void*>(sparse.getCols()), VALUE_TYPE_INT32) {}

-  SparseMatrixArg(const GpuSparseMatrix& sparse)
-      : BufferArg(sparse),
+  SparseMatrixArg(const GpuSparseMatrix& sparse, ArgType argType = UNSPECIFIED)
+      : BufferArg(sparse, argType),
         row_(reinterpret_cast<void*>(sparse.getRows()), VALUE_TYPE_INT32),
         col_(reinterpret_cast<void*>(sparse.getCols()), VALUE_TYPE_INT32) {}

paddle/function/ContextProjectionOp.cpp

Lines changed: 5 additions & 8 deletions
@@ -84,12 +84,9 @@ class ContextProjectionForwardFunc : public FunctionBase {
     begin_pad_ = config.get<size_t>("begin_pad");
   }

-  void calc(const BufferArgs& inputs,
-            const BufferArgs& outputs,
-            const BufferArgs& inouts) override {
+  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
     CHECK_EQ(3, inputs.size());
     CHECK_EQ(1, outputs.size());
-    CHECK_EQ(0, inouts.size());

     CHECK(outputs[0].data() && inputs[0].data() && inputs[2].data());
     CHECK_EQ(outputs[0].shape().ndims(), 2);
@@ -103,6 +100,7 @@ class ContextProjectionForwardFunc : public FunctionBase {
     /// input and output has the same batch_size
     CHECK_EQ(inputs[0].shape()[0], outputs[0].shape()[0]);

+    CHECK_EQ(outputs[0].getArgType(), ADD_TO);
     auto out_mat = outputs[0].matrix<Device>();
     auto in_mat = inputs[0].matrix<Device>();
     auto w_mat = !inputs[1].data()
@@ -194,12 +192,9 @@ class ContextProjectionBackwardFunc : public FunctionBase {
     total_pad_ = config.get<size_t>("total_pad");
   }

-  void calc(const BufferArgs& inputs,
-            const BufferArgs& outputs,
-            const BufferArgs& inouts) override {
+  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
     CHECK_EQ(3, inputs.size());
     CHECK_EQ(1, outputs.size());
-    CHECK_EQ(0, inouts.size());

     CHECK(outputs[0].data() && inputs[2].data());
     CHECK_EQ(outputs[0].shape().ndims(), 2);
@@ -214,6 +209,8 @@ class ContextProjectionBackwardFunc : public FunctionBase {
     /// dim of output = dim of input * context_length
     CHECK_EQ(outputs[0].shape()[1], inputs[0].shape()[1] * context_length_);

+    CHECK_EQ(outputs[0].getArgType(), ADD_TO);
+
     auto out_grad_mat = outputs[0].matrix<Device>();
     auto in_grad_mat =
         !inputs[0].data() ? typename Tensor<real, Device>::Matrix(nullptr, 0, 0)
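
Illustrative note (not from this commit): the three-argument calc(inputs, outputs, inouts) is gone; a buffer that used to be passed as an inout is now passed as an output whose ArgType is ADD_TO, and the Function asserts that mode. A hypothetical Function following this contract might look like:

  class DummyFunc : public FunctionBase {
  public:
    void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
      CHECK_EQ(1, inputs.size());
      CHECK_EQ(1, outputs.size());
      // Accumulate into the output rather than overwrite it.
      CHECK_EQ(outputs[0].getArgType(), ADD_TO);
    }
  };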

paddle/function/CrossMapNormalOp.cpp

Lines changed: 11 additions & 8 deletions
@@ -112,6 +112,8 @@ void CrossMapNormalGrad<DEVICE_TYPE_CPU>(real* inputsGrad,
 }

 /**
+ * \brief {o_0, o_1} = calc(i_0)
+ *
  * \param inputs[0] input value.
  * \param outputs[0] output value.
  * \param outputs[1] denoms.
@@ -125,17 +127,16 @@ class CrossMapNormalFunc : public FunctionBase {
     pow_ = config.get<real>("pow");
   }

-  void calc(const BufferArgs& inputs,
-            const BufferArgs& outputs,
-            const BufferArgs& inouts) override {
+  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
     CHECK_EQ(1, inputs.size());
     CHECK_EQ(2, outputs.size());
-    CHECK_EQ(0, inouts.size());

     CHECK_EQ(inputs[0].shape().ndims(), 4);
     CHECK(inputs[0].shape() == outputs[0].shape());
     CHECK(inputs[0].shape() == outputs[1].shape());

+    CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
+    CHECK_EQ(outputs[1].getArgType(), ASSIGN_TO);
     size_t samples = inputs[0].shape()[0];
     size_t channels = inputs[0].shape()[1];
     size_t height = inputs[0].shape()[2];
@@ -160,6 +161,8 @@ class CrossMapNormalFunc : public FunctionBase {
 };

 /**
+ * \brief {o_0} = calc(i_0, i_1, i_2, i_3)
+ *
  * \param inputs[0] input value.
  * \param inputs[1] output value.
  * \param inputs[2] output grad.
@@ -175,19 +178,19 @@ class CrossMapNormalGradFunc : public FunctionBase {
     pow_ = config.get<real>("pow");
   }

-  void calc(const BufferArgs& inputs,
-            const BufferArgs& outputs,
-            const BufferArgs& inouts) override {
+  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
     CHECK_EQ(4, inputs.size());
     CHECK_EQ(1, outputs.size());
-    CHECK_EQ(0, inouts.size());

     CHECK_EQ(inputs[0].shape().ndims(), 4);
     CHECK(inputs[0].shape() == inputs[1].shape());
     CHECK(inputs[0].shape() == inputs[2].shape());
     CHECK(inputs[0].shape() == inputs[3].shape());
     CHECK(inputs[0].shape() == outputs[0].shape());

+    // TODO(hedaoyuan): need support ASSIGN_TO mode.
+    CHECK_EQ(outputs[0].getArgType(), ADD_TO);
+
     size_t samples = inputs[0].shape()[0];
     size_t channels = inputs[0].shape()[1];
     size_t height = inputs[0].shape()[2];
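
Illustrative note (a summary, not code from the commit): the two modes checked above have the following intent, with UNSPECIFIED only meaningful for inputs, where argType_ is ignored:

  // ASSIGN_TO : output  = f(inputs);  // the Function overwrites the output buffer
  // ADD_TO    : output += f(inputs);  // the Function accumulates into the output buffer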

paddle/function/Function.cpp

Lines changed: 8 additions & 6 deletions
@@ -72,16 +72,18 @@ FuncConfig& FuncConfig::set<bool>(const std::string& key, bool v) {
   return *this;
 }

-void BufferArgs::addArg(const Matrix& arg, const TensorShape& shape) {
-  args_.push_back(std::make_shared<BufferArg>(arg, shape));
+void BufferArgs::addArg(const Matrix& arg,
+                        const TensorShape& shape,
+                        ArgType argType) {
+  args_.push_back(std::make_shared<BufferArg>(arg, shape, argType));
 }

-void BufferArgs::addArg(const CpuSparseMatrix& arg) {
-  args_.push_back(std::make_shared<SparseMatrixArg>(arg));
+void BufferArgs::addArg(const CpuSparseMatrix& arg, ArgType argType) {
+  args_.push_back(std::make_shared<SparseMatrixArg>(arg, argType));
 }

-void BufferArgs::addArg(const GpuSparseMatrix& arg) {
-  args_.push_back(std::make_shared<SparseMatrixArg>(arg));
+void BufferArgs::addArg(const GpuSparseMatrix& arg, ArgType argType) {
+  args_.push_back(std::make_shared<SparseMatrixArg>(arg, argType));
 }

 ClassRegistrar<FunctionBase> FunctionBase::funcRegistrar_;

paddle/function/Function.h

Lines changed: 10 additions & 6 deletions
@@ -49,7 +49,7 @@ class FuncConfig {
 /**
  * Argument type for Function::calc().
  * A BufferArgs contains a set of BufferArg,
- * because Function can have multiple inputs, outputs and inouts.
+ * because Function can have multiple inputs and outputs.
  */
 class BufferArgs {
 public:
@@ -58,20 +58,24 @@ class BufferArgs {

   // add argument into BufferArgs
   // Tensor can be Matrix, Vector, IVector.
+  // For inputs, do not need argType.
+  // For outputs, the argType needs to be specified as ASSIGN_TO or ADD_TO.
   template <typename Tensor>
-  void addArg(const Tensor& arg) {
-    args_.push_back(std::make_shared<BufferArg>(arg));
+  void addArg(const Tensor& arg, ArgType argType = UNSPECIFIED) {
+    args_.push_back(std::make_shared<BufferArg>(arg, argType));
   }

   // Add arg into BufferArgs and reshape the arg.
   //
   // For example, arg represents an image buffer,
   // but Matrix can only represent a two-dimensional Tensor.
   // So need an extra argument to describe the shape of the image buffer.
-  void addArg(const Matrix& arg, const TensorShape& shape);
+  void addArg(const Matrix& arg,
+              const TensorShape& shape,
+              ArgType argType = UNSPECIFIED);

-  void addArg(const CpuSparseMatrix& arg);
-  void addArg(const GpuSparseMatrix& arg);
+  void addArg(const CpuSparseMatrix& arg, ArgType argType = UNSPECIFIED);
+  void addArg(const GpuSparseMatrix& arg, ArgType argType = UNSPECIFIED);

   // get argument
   const BufferArg& operator[](size_t num) const {
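
Illustrative note (func, in_value, and out_value are placeholders, not identifiers from this commit): a caller now declares the write mode when registering an output and invokes the two-argument calc:

  BufferArgs inputs;
  BufferArgs outputs;
  inputs.addArg(in_value);               // inputs need no ArgType
  outputs.addArg(out_value, ASSIGN_TO);  // outputs declare ASSIGN_TO or ADD_TO
  func->calc(inputs, outputs);           // the former inouts argument is gone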

paddle/gserver/layers/ContextProjection.cpp

Lines changed: 4 additions & 6 deletions
@@ -122,14 +122,13 @@ void ContextProjection::forward() {

   BufferArgs inputs;
   BufferArgs outputs;
-  BufferArgs inouts;
   inputs.addArg(*in_->value);
   inputs.addArg(CpuMatrix(w_ptr ? w_ptr->getData() : nullptr,
                           w_ptr ? w_ptr->getHeight() : 0,
                           input_dim));
   inputs.addArg(*in_->sequenceStartPositions->getVector(useGpu_));
-  outputs.addArg(*out_->value);
-  forward_[0]->calc(inputs, outputs, inouts);
+  outputs.addArg(*out_->value, ADD_TO);
+  forward_[0]->calc(inputs, outputs);

   if (state_ && config_.context_start() < 0) {
     CHECK_EQ(1, in_->getNumSequences());
@@ -166,15 +165,14 @@ void ContextProjection::backward(const UpdateCallback& callback) {

   BufferArgs inputs;
   BufferArgs outputs;
-  BufferArgs inouts;
   inputs.addArg(CpuMatrix(
       in_->grad ? in_->grad->getData() : nullptr, batch_size, input_dim));
   inputs.addArg(CpuMatrix(w_ptr ? w_ptr->getData() : nullptr,
                           w_ptr ? w_ptr->getHeight() : 0,
                           input_dim));
   inputs.addArg(*in_->sequenceStartPositions->getVector(useGpu_));
-  outputs.addArg(*out_->grad);
-  backward_[0]->calc(inputs, outputs, inouts);
+  outputs.addArg(*out_->grad, ADD_TO);
+  backward_[0]->calc(inputs, outputs);

   if (config_.trainable_padding()) {
     weight_->getParameterPtr()->incUpdate(callback);
