-
Notifications
You must be signed in to change notification settings - Fork 5.7k
Simplify the testOnePeriod method. #731
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,22 +17,22 @@ limitations under the License. */ | |
#include <fenv.h> | ||
#include <stdio.h> | ||
|
||
#include <iostream> | ||
#include <iomanip> | ||
#include <sstream> | ||
#include <iostream> | ||
#include <limits> | ||
#include <sstream> | ||
|
||
#include <google/protobuf/text_format.h> | ||
|
||
#include "paddle/utils/GlobalConstants.h" | ||
#include "paddle/utils/PythonUtil.h" | ||
#include "paddle/utils/Stat.h" | ||
#include "paddle/utils/Util.h" | ||
#include "paddle/utils/GlobalConstants.h" | ||
|
||
#include "TesterConfig.h" | ||
#include "paddle/gserver/gradientmachines/GradientMachineMode.h" | ||
#include "paddle/gserver/gradientmachines/NeuralNetwork.h" | ||
#include "paddle/gserver/layers/ValidationLayer.h" | ||
#include "paddle/gserver/gradientmachines/GradientMachineMode.h" | ||
#include "TesterConfig.h" | ||
|
||
namespace paddle { | ||
|
||
|
@@ -66,6 +66,9 @@ Tester::Tester(const std::shared_ptr<TrainerConfigHelper>& config, | |
} | ||
|
||
void Tester::startTestPeriod() { | ||
if (testDataProvider_) { | ||
testDataProvider_->reset(); | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 在swig api里面调用时,testDataProvider_可能是不存在的。 swig里面可以直接调用startTestPeriod(); There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. okey 按理应该在swig里加判断来fix这个问题。这里我先pass了吧 |
||
testEvaluator_->start(); | ||
testContext_.cost = 0; | ||
testContext_.numSamples = 0; | ||
|
@@ -87,27 +90,18 @@ void Tester::testOneDataBatch(const DataBatch& dataBatch, | |
void Tester::testOnePeriod() { | ||
DataBatch dataBatch; | ||
int64_t batchSize = config_->getOptConfig().batch_size(); | ||
|
||
int batches = std::numeric_limits<int>::max(); | ||
|
||
std::vector<Argument> outArgs; | ||
|
||
startTestPeriod(); | ||
for (int i = 0; i < batches; ++i) { | ||
int num = testDataProvider_->getNextBatch(batchSize, &dataBatch); | ||
if (num == 0) { | ||
testDataProvider_->reset(); | ||
if (intconfig_->prevBatchState) { | ||
gradientMachine_->resetState(); | ||
} | ||
break; | ||
} | ||
while (testDataProvider_->getNextBatch(batchSize, &dataBatch) != 0) { | ||
testOneDataBatch(dataBatch, &outArgs); | ||
} | ||
finishTestPeriod(); | ||
} | ||
|
||
void Tester::finishTestPeriod() { | ||
if (intconfig_->prevBatchState) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 光从patch上看,似乎语意不一致 了,这个testDataProvider_->reset(); 没有了。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 放到startTestPeriod里面了。。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. okey, 没有问题了 |
||
gradientMachine_->resetState(); | ||
} | ||
testEvaluator_->finish(); | ||
CHECK_GT(testContext_.numSamples, 0) | ||
<< "There is no samples in your test batch. Possibly " | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,36 +17,38 @@ limitations under the License. */ | |
#include <fenv.h> | ||
#include <stdio.h> | ||
|
||
#include <iostream> | ||
#include <iomanip> | ||
#include <sstream> | ||
#include <iostream> | ||
#include <limits> | ||
#include <sstream> | ||
|
||
#include <google/protobuf/text_format.h> | ||
|
||
#include "paddle/utils/Excepts.h" | ||
#include "paddle/utils/GlobalConstants.h" | ||
#include "paddle/utils/PythonUtil.h" | ||
#include "paddle/utils/Stat.h" | ||
#include "paddle/utils/Util.h" | ||
#include "paddle/utils/Excepts.h" | ||
#include "paddle/utils/GlobalConstants.h" | ||
|
||
#include "paddle/gserver/gradientmachines/NeuralNetwork.h" | ||
#include "paddle/gserver/gradientmachines/GradientMachineMode.h" | ||
#include "paddle/gserver/layers/ValidationLayer.h" | ||
#include "RemoteParameterUpdater.h" | ||
#include "TesterConfig.h" | ||
#include "ThreadParameterUpdater.h" | ||
#include "RemoteParameterUpdater.h" | ||
#include "TrainerConfigHelper.h" | ||
#include "paddle/gserver/gradientmachines/GradientMachineMode.h" | ||
#include "paddle/gserver/gradientmachines/NeuralNetwork.h" | ||
#include "paddle/gserver/layers/ValidationLayer.h" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 看头文件改了很多呀, 这个为什么? 为了好看? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
||
P_DEFINE_string(config, "", "Trainer config file"); | ||
|
||
P_DEFINE_int32(test_period, 0, | ||
P_DEFINE_int32(test_period, | ||
0, | ||
"if equal 0, do test on all test data at the end of " | ||
"each pass. While if equal non-zero, do test on all test " | ||
"data every test_period batches"); | ||
P_DEFINE_bool(test_all_data_in_one_period, false, | ||
"This option was deprecated, since we will always do " | ||
"test on all test set "); | ||
P_DEFINE_bool(test_all_data_in_one_period, | ||
false, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Great |
||
"This option was deprecated, since we will always do " | ||
"test on all test set "); | ||
|
||
P_DEFINE_bool(local, true, "Train in local mode or not"); | ||
|
||
|
@@ -392,10 +394,6 @@ void Trainer::startTrain() { | |
dataProvider_->reset(); | ||
} | ||
|
||
if (this->testDataProvider_) { | ||
this->testDataProvider_->reset(); | ||
} | ||
|
||
trainerInternal_.getGradientMachine()->start(*config_, dataProvider_); | ||
} | ||
|
||
|
@@ -630,16 +628,14 @@ void Trainer::test() { tester_->test(); } | |
std::unique_ptr<TesterConfig> Trainer::createTesterConfig() { | ||
TesterConfig* conf = new TesterConfig; | ||
if (FLAGS_test_period) { | ||
LOG(WARNING) | ||
<< "The meaning of --test_period is changed: " | ||
<< "if equal 0, do test on all test data at the end of " | ||
<< "each pass. While if equal non-zero, do test on all test " | ||
<< "data every test_period batches "; | ||
LOG(WARNING) << "The meaning of --test_period is changed: " | ||
<< "if equal 0, do test on all test data at the end of " | ||
<< "each pass. While if equal non-zero, do test on all test " | ||
<< "data every test_period batches "; | ||
} | ||
if (FLAGS_test_all_data_in_one_period) { | ||
LOG(WARNING) | ||
<< "--test_all_data_in_one_period was deprecated, since " | ||
<< "we will always do test on all test set "; | ||
LOG(WARNING) << "--test_all_data_in_one_period was deprecated, since " | ||
<< "we will always do test on all test set "; | ||
} | ||
conf->testPeriod = FLAGS_test_period; | ||
conf->prevBatchState = FLAGS_prev_batch_state; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sphinx_rtd_theme 在 #724 中已经添加并merge了,这里不用再加了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
并没有conflict :)