From 225a8fa14b8fa04c814da02ff9f240f1819373f3 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Mon, 23 Jan 2017 11:26:20 +0800 Subject: [PATCH 1/4] Add numInputs_ and numOutputs_ --- paddle/function/CrossMapNormalOp.cpp | 18 ++++++++++++++---- paddle/function/Function.h | 13 +++++++++++++ 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/paddle/function/CrossMapNormalOp.cpp b/paddle/function/CrossMapNormalOp.cpp index 5c0bdd933b1e4..3fab2127a1511 100644 --- a/paddle/function/CrossMapNormalOp.cpp +++ b/paddle/function/CrossMapNormalOp.cpp @@ -162,14 +162,19 @@ template class CrossMapNormalFunc : public FunctionBase { public: void init(const FuncConfig& config) override { + // function arguments size_ = config.get("size"); scale_ = config.get("scale"); pow_ = config.get("pow"); + + // number of inputs and outputs + numInputs_ = 1; + numOutputs_ = 2; } void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { - CHECK_EQ((size_t)1, inputs.size()); - CHECK_EQ((size_t)2, outputs.size()); + CHECK_EQ((size_t)numInputs_, inputs.size()); + CHECK_EQ((size_t)numOutputs_, outputs.size()); CHECK_EQ(inputs[0].shape().ndims(), (size_t)4); CHECK(inputs[0].shape() == outputs[0].shape()); @@ -236,14 +241,19 @@ template class CrossMapNormalGradFunc : public FunctionBase { public: void init(const FuncConfig& config) override { + // function arguments size_ = config.get("size"); scale_ = config.get("scale"); pow_ = config.get("pow"); + + // number of inputs and outputs + numInputs_ = 4; + numOutputs_ = 1; } void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { - CHECK_EQ((size_t)4, inputs.size()); - CHECK_EQ((size_t)1, outputs.size()); + CHECK_EQ((size_t)numInputs_, inputs.size()); + CHECK_EQ((size_t)numOutputs_, outputs.size()); CHECK_EQ(inputs[0].shape().ndims(), (size_t)4); CHECK(inputs[0].shape() == inputs[1].shape()); diff --git a/paddle/function/Function.h b/paddle/function/Function.h index 9215c137eb8e8..4a6c79b6ebdf8 100644 --- a/paddle/function/Function.h +++ b/paddle/function/Function.h @@ -153,7 +153,20 @@ class FunctionBase { virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {} + int getNumInputs() const { return numInputs_; } + + int getNumOutputs() const { return numOutputs_; } + static ClassRegistrar funcRegistrar_; + +protected: + // numInputs_ and numOutputs_ represents the maximum + // input and output supported by Function. + // Some functions are optimized for input and output, + // so when comparing the number of arguments, for these functions + // inputs.size() <= numInputs_ or outputs.size() <= numOutputs_ + size_t numInputs_; + size_t numOutputs_; }; #define FUNC_NAME(typeName, deviceName) #typeName "-" #deviceName From 9896f15e7cabd5d68ec03157439a44bbb709c221 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Mon, 23 Jan 2017 12:44:03 +0800 Subject: [PATCH 2/4] Add FunctionBase::ops() --- paddle/function/CrossMapNormalOp.cpp | 30 ++++++++++++++++++++-------- paddle/function/Function.h | 7 +++++++ 2 files changed, 29 insertions(+), 8 deletions(-) diff --git a/paddle/function/CrossMapNormalOp.cpp b/paddle/function/CrossMapNormalOp.cpp index 3fab2127a1511..8749a48327604 100644 --- a/paddle/function/CrossMapNormalOp.cpp +++ b/paddle/function/CrossMapNormalOp.cpp @@ -182,23 +182,37 @@ class CrossMapNormalFunc : public FunctionBase { CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); CHECK_EQ(outputs[1].getArgType(), ASSIGN_TO); - size_t samples = inputs[0].shape()[0]; - size_t channels = inputs[0].shape()[1]; - size_t height = inputs[0].shape()[2]; - size_t width = inputs[0].shape()[3]; + size_t batchSize = inputs[0].shape()[0]; + size_t maps = inputs[0].shape()[1]; + size_t rows = inputs[0].shape()[2]; + size_t columns = inputs[0].shape()[3]; CrossMapNormal(outputs[0].data(), outputs[1].data(), inputs[0].data(), - samples, - channels, - height, - width, + batchSize, + maps, + rows, + columns, size_, scale_, pow_); } + // Only need the shape of the input, can calculate the + // floating-point operation. + size_t ops(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ((size_t)numInputs_, inputs.size()); + size_t batchSize = inputs[0].shape()[0]; + size_t maps = inputs[0].shape()[1]; + size_t rows = inputs[0].shape()[2]; + size_t columns = inputs[0].shape()[3]; + + // number of floating-point operations + // an approximate value + size_t ops = batchSize * maps * ((rows * columns) * size_); + } + private: size_t size_; real scale_; diff --git a/paddle/function/Function.h b/paddle/function/Function.h index 4a6c79b6ebdf8..65688eebee975 100644 --- a/paddle/function/Function.h +++ b/paddle/function/Function.h @@ -153,6 +153,13 @@ class FunctionBase { virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {} + // Calculate the number of floating-point operations of this Function. + // The inputs and outputs arguments do not need to contain the actual data, + // only the shape. + virtual size_t ops(const BufferArgs& inputs, const BufferArgs& outputs) { + return 0; + } + int getNumInputs() const { return numInputs_; } int getNumOutputs() const { return numOutputs_; } From c4437fa2312b7550fef89ddac00d057361804385 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Mon, 23 Jan 2017 13:07:16 +0800 Subject: [PATCH 3/4] Add FunctionBase::check() --- paddle/function/CrossMapNormalOp.cpp | 21 ++++++++++++++------- paddle/function/Function.h | 6 ++++++ 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/paddle/function/CrossMapNormalOp.cpp b/paddle/function/CrossMapNormalOp.cpp index 8749a48327604..99af02ac74414 100644 --- a/paddle/function/CrossMapNormalOp.cpp +++ b/paddle/function/CrossMapNormalOp.cpp @@ -173,13 +173,9 @@ class CrossMapNormalFunc : public FunctionBase { } void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { - CHECK_EQ((size_t)numInputs_, inputs.size()); - CHECK_EQ((size_t)numOutputs_, outputs.size()); - - CHECK_EQ(inputs[0].shape().ndims(), (size_t)4); - CHECK(inputs[0].shape() == outputs[0].shape()); - CHECK(inputs[0].shape() == outputs[1].shape()); - + check(inputs, outputs); + // ArgType check still on here, + // not sure whether it is better to put inside the check. CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); CHECK_EQ(outputs[1].getArgType(), ASSIGN_TO); size_t batchSize = inputs[0].shape()[0]; @@ -199,6 +195,15 @@ class CrossMapNormalFunc : public FunctionBase { pow_); } + void check(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ((size_t)numInputs_, inputs.size()); + CHECK_EQ((size_t)numOutputs_, outputs.size()); + + CHECK_EQ(inputs[0].shape().ndims(), (size_t)4); + CHECK(inputs[0].shape() == outputs[0].shape()); + CHECK(inputs[0].shape() == outputs[1].shape()); + } + // Only need the shape of the input, can calculate the // floating-point operation. size_t ops(const BufferArgs& inputs, const BufferArgs& outputs) override { @@ -211,6 +216,8 @@ class CrossMapNormalFunc : public FunctionBase { // number of floating-point operations // an approximate value size_t ops = batchSize * maps * ((rows * columns) * size_); + + return ops; } private: diff --git a/paddle/function/Function.h b/paddle/function/Function.h index 65688eebee975..4802c2e846cfa 100644 --- a/paddle/function/Function.h +++ b/paddle/function/Function.h @@ -153,6 +153,12 @@ class FunctionBase { virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {} + // This member function is used to check whether the BufferType and shape of + // the inputs and outputs arguments of the Function are correct. + // General calc function which will call this check to do arguments check. + // Also before the call calc, the caller can also check their own arguments. + virtual void check(const BufferArgs& inputs, const BufferArgs& outputs) {} + // Calculate the number of floating-point operations of this Function. // The inputs and outputs arguments do not need to contain the actual data, // only the shape. From a9228e2a406ecb3588ea0c2d112971260d87e1a3 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Mon, 23 Jan 2017 13:49:19 +0800 Subject: [PATCH 4/4] Fix CrossMapNormalGradFunc --- paddle/function/CrossMapNormalOp.cpp | 59 ++++++++++++++++++---------- paddle/function/Function.h | 5 ++- 2 files changed, 43 insertions(+), 21 deletions(-) diff --git a/paddle/function/CrossMapNormalOp.cpp b/paddle/function/CrossMapNormalOp.cpp index 99af02ac74414..ef878bfbba961 100644 --- a/paddle/function/CrossMapNormalOp.cpp +++ b/paddle/function/CrossMapNormalOp.cpp @@ -196,8 +196,8 @@ class CrossMapNormalFunc : public FunctionBase { } void check(const BufferArgs& inputs, const BufferArgs& outputs) override { - CHECK_EQ((size_t)numInputs_, inputs.size()); - CHECK_EQ((size_t)numOutputs_, outputs.size()); + CHECK_EQ(numInputs_, inputs.size()); + CHECK_EQ(numOutputs_, outputs.size()); CHECK_EQ(inputs[0].shape().ndims(), (size_t)4); CHECK(inputs[0].shape() == outputs[0].shape()); @@ -215,7 +215,7 @@ class CrossMapNormalFunc : public FunctionBase { // number of floating-point operations // an approximate value - size_t ops = batchSize * maps * ((rows * columns) * size_); + size_t ops = batchSize * maps * rows * columns * (size_ * 2 + 3); return ops; } @@ -273,15 +273,7 @@ class CrossMapNormalGradFunc : public FunctionBase { } void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { - CHECK_EQ((size_t)numInputs_, inputs.size()); - CHECK_EQ((size_t)numOutputs_, outputs.size()); - - CHECK_EQ(inputs[0].shape().ndims(), (size_t)4); - CHECK(inputs[0].shape() == inputs[1].shape()); - CHECK(inputs[0].shape() == inputs[2].shape()); - CHECK(inputs[0].shape() == inputs[3].shape()); - CHECK(inputs[0].shape() == outputs[0].shape()); - + check(inputs, outputs); if (outputs[0].getArgType() != ADD_TO) { // Currently, some algorithm implementations are ASSIGN_TO mode, // if need to support the ADD_TO calculation, need to clear the output. @@ -290,25 +282,52 @@ class CrossMapNormalGradFunc : public FunctionBase { tmp.zero(); } - size_t samples = inputs[0].shape()[0]; - size_t channels = inputs[0].shape()[1]; - size_t height = inputs[0].shape()[2]; - size_t width = inputs[0].shape()[3]; + size_t batchSize = inputs[0].shape()[0]; + size_t maps = inputs[0].shape()[1]; + size_t rows = inputs[0].shape()[2]; + size_t columns = inputs[0].shape()[3]; CrossMapNormalGrad(outputs[0].data(), inputs[0].data(), inputs[1].data(), inputs[2].data(), inputs[3].data(), - samples, - channels, - height, - width, + batchSize, + maps, + rows, + columns, size_, scale_, pow_); } + void check(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(numInputs_, inputs.size()); + CHECK_EQ(numOutputs_, outputs.size()); + + CHECK_EQ(inputs[0].shape().ndims(), (size_t)4); + CHECK(inputs[0].shape() == inputs[1].shape()); + CHECK(inputs[0].shape() == inputs[2].shape()); + CHECK(inputs[0].shape() == inputs[3].shape()); + CHECK(inputs[0].shape() == outputs[0].shape()); + } + + // Only need the shape of one input, can calculate the + // floating-point operation. + size_t ops(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_LT((size_t)1, inputs.size()); + size_t batchSize = inputs[0].shape()[0]; + size_t maps = inputs[0].shape()[1]; + size_t rows = inputs[0].shape()[2]; + size_t columns = inputs[0].shape()[3]; + + // number of floating-point operations + // an approximate value + size_t ops = batchSize * maps * rows * columns * (size_ * 4 + 2); + + return ops; + } + private: size_t size_; real scale_; diff --git a/paddle/function/Function.h b/paddle/function/Function.h index 4802c2e846cfa..3bbeb6e525f85 100644 --- a/paddle/function/Function.h +++ b/paddle/function/Function.h @@ -156,12 +156,15 @@ class FunctionBase { // This member function is used to check whether the BufferType and shape of // the inputs and outputs arguments of the Function are correct. // General calc function which will call this check to do arguments check. - // Also before the call calc, the caller can also check their own arguments. + // And before the calc called, the caller can also check their own arguments. virtual void check(const BufferArgs& inputs, const BufferArgs& outputs) {} // Calculate the number of floating-point operations of this Function. // The inputs and outputs arguments do not need to contain the actual data, // only the shape. + // And some Functions have the same input and output shapes, + // so you may not need to enter the complete number of arguments. + // But entering the full arguments is always correct for this interface. virtual size_t ops(const BufferArgs& inputs, const BufferArgs& outputs) { return 0; }