Skip to content

Simplify the testOnePeriod method. #731

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 6, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ before_install:
fi
- if [[ "$TRAVIS_OS_NAME" == "linux" ]]; then sudo paddle/scripts/travis/before_install.linux.sh; fi
- if [[ "$TRAVIS_OS_NAME" == "osx" ]]; then paddle/scripts/travis/before_install.osx.sh; fi
- pip install wheel protobuf sphinx breathe recommonmark virtualenv numpy
- pip install wheel protobuf sphinx breathe recommonmark virtualenv numpy sphinx_rtd_theme
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sphinx_rtd_theme 在 #724 中已经添加并merge了,这里不用再加了

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

并没有conflict :)

script:
- paddle/scripts/travis/main.sh
notifications:
Expand Down
6 changes: 1 addition & 5 deletions paddle/api/test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,7 @@ popd > /dev/null

cd $SCRIPTPATH

if [ ! -f ../../dist/*.whl ] ; then # Swig not compiled.
exit 0
fi

rm .test_env -rf
rm -rf .test_env
virtualenv .test_env
source .test_env/bin/activate

Expand Down
30 changes: 12 additions & 18 deletions paddle/trainer/Tester.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,22 @@ limitations under the License. */
#include <fenv.h>
#include <stdio.h>

#include <iostream>
#include <iomanip>
#include <sstream>
#include <iostream>
#include <limits>
#include <sstream>

#include <google/protobuf/text_format.h>

#include "paddle/utils/GlobalConstants.h"
#include "paddle/utils/PythonUtil.h"
#include "paddle/utils/Stat.h"
#include "paddle/utils/Util.h"
#include "paddle/utils/GlobalConstants.h"

#include "TesterConfig.h"
#include "paddle/gserver/gradientmachines/GradientMachineMode.h"
#include "paddle/gserver/gradientmachines/NeuralNetwork.h"
#include "paddle/gserver/layers/ValidationLayer.h"
#include "paddle/gserver/gradientmachines/GradientMachineMode.h"
#include "TesterConfig.h"

namespace paddle {

Expand Down Expand Up @@ -66,6 +66,9 @@ Tester::Tester(const std::shared_ptr<TrainerConfigHelper>& config,
}

void Tester::startTestPeriod() {
if (testDataProvider_) {
testDataProvider_->reset();
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

void Trainer::finishTrainPass()
void Trainer::trainOneDataBatch(DataBatch& dataBatch)
看了下代码,上两处都已经check 过了 if (testDataProvider_) 了, 这里可以不要判断if了, 默认走到这个函数的话testDataProvider_一定是存在的, 。

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

在swig api里面调用时,testDataProvider_可能是不存在的。

swig里面可以直接调用startTestPeriod();

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okey

按理应该在swig里加判断来fix这个问题。这里我先pass了吧

testEvaluator_->start();
testContext_.cost = 0;
testContext_.numSamples = 0;
Expand All @@ -87,27 +90,18 @@ void Tester::testOneDataBatch(const DataBatch& dataBatch,
void Tester::testOnePeriod() {
DataBatch dataBatch;
int64_t batchSize = config_->getOptConfig().batch_size();

int batches = std::numeric_limits<int>::max();

std::vector<Argument> outArgs;

startTestPeriod();
for (int i = 0; i < batches; ++i) {
int num = testDataProvider_->getNextBatch(batchSize, &dataBatch);
if (num == 0) {
testDataProvider_->reset();
if (intconfig_->prevBatchState) {
gradientMachine_->resetState();
}
break;
}
while (testDataProvider_->getNextBatch(batchSize, &dataBatch) != 0) {
testOneDataBatch(dataBatch, &outArgs);
}
finishTestPeriod();
}

void Tester::finishTestPeriod() {
if (intconfig_->prevBatchState) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

光从patch上看,似乎语意不一致 了,这个testDataProvider_->reset(); 没有了。

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

放到startTestPeriod里面了。。
其实应该是每次开始读取数据之前reset。之前逻辑比较复杂是因为要处理的事情太多,简化完了以后DataProvider的reset就可以在和train一样的位置上了。

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okey, 没有问题了

gradientMachine_->resetState();
}
testEvaluator_->finish();
CHECK_GT(testContext_.numSamples, 0)
<< "There is no samples in your test batch. Possibly "
Expand Down
44 changes: 20 additions & 24 deletions paddle/trainer/Trainer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,36 +17,38 @@ limitations under the License. */
#include <fenv.h>
#include <stdio.h>

#include <iostream>
#include <iomanip>
#include <sstream>
#include <iostream>
#include <limits>
#include <sstream>

#include <google/protobuf/text_format.h>

#include "paddle/utils/Excepts.h"
#include "paddle/utils/GlobalConstants.h"
#include "paddle/utils/PythonUtil.h"
#include "paddle/utils/Stat.h"
#include "paddle/utils/Util.h"
#include "paddle/utils/Excepts.h"
#include "paddle/utils/GlobalConstants.h"

#include "paddle/gserver/gradientmachines/NeuralNetwork.h"
#include "paddle/gserver/gradientmachines/GradientMachineMode.h"
#include "paddle/gserver/layers/ValidationLayer.h"
#include "RemoteParameterUpdater.h"
#include "TesterConfig.h"
#include "ThreadParameterUpdater.h"
#include "RemoteParameterUpdater.h"
#include "TrainerConfigHelper.h"
#include "paddle/gserver/gradientmachines/GradientMachineMode.h"
#include "paddle/gserver/gradientmachines/NeuralNetwork.h"
#include "paddle/gserver/layers/ValidationLayer.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

看头文件改了很多呀, 这个为什么? 为了好看?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


P_DEFINE_string(config, "", "Trainer config file");

P_DEFINE_int32(test_period, 0,
P_DEFINE_int32(test_period,
0,
"if equal 0, do test on all test data at the end of "
"each pass. While if equal non-zero, do test on all test "
"data every test_period batches");
P_DEFINE_bool(test_all_data_in_one_period, false,
"This option was deprecated, since we will always do "
"test on all test set ");
P_DEFINE_bool(test_all_data_in_one_period,
false,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great

"This option was deprecated, since we will always do "
"test on all test set ");

P_DEFINE_bool(local, true, "Train in local mode or not");

Expand Down Expand Up @@ -392,10 +394,6 @@ void Trainer::startTrain() {
dataProvider_->reset();
}

if (this->testDataProvider_) {
this->testDataProvider_->reset();
}

trainerInternal_.getGradientMachine()->start(*config_, dataProvider_);
}

Expand Down Expand Up @@ -630,16 +628,14 @@ void Trainer::test() { tester_->test(); }
std::unique_ptr<TesterConfig> Trainer::createTesterConfig() {
TesterConfig* conf = new TesterConfig;
if (FLAGS_test_period) {
LOG(WARNING)
<< "The meaning of --test_period is changed: "
<< "if equal 0, do test on all test data at the end of "
<< "each pass. While if equal non-zero, do test on all test "
<< "data every test_period batches ";
LOG(WARNING) << "The meaning of --test_period is changed: "
<< "if equal 0, do test on all test data at the end of "
<< "each pass. While if equal non-zero, do test on all test "
<< "data every test_period batches ";
}
if (FLAGS_test_all_data_in_one_period) {
LOG(WARNING)
<< "--test_all_data_in_one_period was deprecated, since "
<< "we will always do test on all test set ";
LOG(WARNING) << "--test_all_data_in_one_period was deprecated, since "
<< "we will always do test on all test set ";
}
conf->testPeriod = FLAGS_test_period;
conf->prevBatchState = FLAGS_prev_batch_state;
Expand Down