
support testing when training and handle dropout and batch_norm operator in testing mode #5734


Merged: 13 commits merged on Nov 24, 2017
2 changes: 1 addition & 1 deletion paddle/framework/executor.cc
@@ -120,7 +120,7 @@ void Executor::Run(const ProgramDescBind& pdesc, Scope* scope, int block_id,

for (auto& op_desc : block.AllOps()) {
auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
VLOG(10) << op->DebugString();
VLOG(3) << op->DebugString();
op->Run(*local_scope, *device);
}
if (create_local_scope) {
23 changes: 23 additions & 0 deletions paddle/framework/prune.cc
@@ -26,6 +26,8 @@ namespace framework {

const std::string kFeedOpType = "feed";
const std::string kFetchOpType = "fetch";
const std::string kDropOutOpType = "dropout";
const std::string kBatchNormOpType = "batch_norm";

bool HasDependentVar(const OpDesc& op_desc,
const std::set<std::string>& dependent_vars) {
@@ -106,5 +108,26 @@ void Prune(const ProgramDesc& input, ProgramDesc* output) {
prune_impl(input, output, 0);
}

Review comment (Collaborator): The logic of inference_optimize_impl is quite simple. Maybe we can implement it in Python.

void inference_optimize_impl(const ProgramDesc& input, ProgramDesc* output,
                             int block_id) {
*output = input;
auto* op_field = output->mutable_blocks(block_id)->mutable_ops();
for (auto& op_desc : *op_field) {
if (op_desc.type() == kDropOutOpType ||
op_desc.type() == kBatchNormOpType) {
for (auto& attr : *op_desc.mutable_attrs()) {
if (attr.name() == "is_test") {
attr.set_b(true);
break;
}
}
}
}
}

void InferenceOptimize(const ProgramDesc& input, ProgramDesc* output) {
inference_optimize_impl(input, output, 0);
}

} // namespace framework
} // namespace paddle
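As the review comment above suggests, the same transformation is simple enough to do on the Python side. Below is a minimal sketch, not part of this PR, assuming the pybind bindings already used by framework.py (ProgramDesc.num_blocks/block, BlockDesc.op_size/op, OpDesc.type/has_attr/set_attr); it mutates the passed-in fluid Program in place instead of returning a copy:

```python
def inference_optimize_in_python(program):
    """Flip is_test to True for dropout/batch_norm ops (Python-side sketch)."""
    desc = program.desc  # the underlying core.ProgramDesc
    for i in xrange(desc.num_blocks()):
        block = desc.block(i)
        for j in xrange(block.op_size()):
            op = block.op(j)
            if op.type() in ("dropout", "batch_norm") and op.has_attr("is_test"):
                op.set_attr("is_test", True)
    # Refresh the Python-side Block/Operator wrappers after editing the C++ desc.
    program.sync_with_cpp()
    return program
```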
2 changes: 2 additions & 0 deletions paddle/framework/prune.h
@@ -22,5 +22,7 @@ namespace framework {

void Prune(const ProgramDesc& input, ProgramDesc* output);

void InferenceOptimize(const ProgramDesc& input, ProgramDesc* output);

} // namespace framework
} // namespace paddle
8 changes: 4 additions & 4 deletions paddle/operators/dropout_op.cc
@@ -30,7 +30,7 @@ class DropoutOp : public framework::OperatorWithKernel {

auto x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", x_dims);
if (ctx->Attrs().Get<bool>("is_training") == true) {
if (ctx->Attrs().Get<bool>("is_test") == false) {
ctx->SetOutputDim("Mask", x_dims);
}
ctx->ShareLoD("X", /*->*/ "Out");
@@ -49,7 +49,7 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {

AddAttr<float>("dropout_prob", "Probability of setting units to zero.")
.SetDefault(.5f);
AddAttr<bool>("is_training", "True if in training phase.").SetDefault(true);
AddAttr<bool>("is_test", "True if in test phase.").SetDefault(false);
AddAttr<int>("seed", "Dropout random seed.").SetDefault(0);

AddComment(R"DOC(
@@ -71,8 +71,8 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("is_training"), true,
"GradOp is only callable when is_training is true");
PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("is_test"), false,
"GradOp is only callable when is_test is false");

PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
PADDLE_ENFORCE(ctx->HasInput("Mask"), "Mask must not be null.");
2 changes: 1 addition & 1 deletion paddle/operators/dropout_op.cu
@@ -59,7 +59,7 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
auto Y = EigenMatrix<T>::Reshape(*y, 1);

auto place = context.GetEigenDevice<Place>();
if (context.Attr<bool>("is_training")) {
if (!context.Attr<bool>("is_test")) {
auto* mask = context.Output<Tensor>("Mask");
auto* mask_data = mask->mutable_data<T>(context.GetPlace());
int size = framework::product(mask->dims());
6 changes: 3 additions & 3 deletions paddle/operators/dropout_op.h
@@ -35,7 +35,7 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
auto* y_data = y->mutable_data<T>(context.GetPlace());
float dropout_prob = context.Attr<float>("dropout_prob");

if (context.Attr<bool>("is_training")) {
if (!context.Attr<bool>("is_test")) {
auto* mask = context.Output<Tensor>("Mask");
auto* mask_data = mask->mutable_data<T>(context.GetPlace());
int seed = context.Attr<int>("seed");
@@ -65,8 +65,8 @@ template <typename Place, typename T>
class DropoutGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE(context.Attr<bool>("is_training"),
"GradOp is only callable when is_training is true");
PADDLE_ENFORCE(!context.Attr<bool>("is_test"),
"GradOp is only callable when is_test is false");

auto* grad_x = context.Output<Tensor>(framework::GradVarName("X"));
auto* grad_y = context.Input<Tensor>(framework::GradVarName("Out"));
5 changes: 5 additions & 0 deletions paddle/pybind/pybind.cc
@@ -293,6 +293,11 @@ All parameter, weight, gradient are variables in Paddle.
Prune(*prog_with_targets.Proto(), &pruned_desc);
return new ProgramDescBind(pruned_desc);
});
m.def("inference_optimize", [](ProgramDescBind &origin) {
ProgramDesc pruned_desc;
InferenceOptimize(*(origin.Proto()), &pruned_desc);
return new ProgramDescBind(pruned_desc);
});
m.def_submodule(
"var_names",
"The module will return special predefined variable name in Paddle")
3 changes: 3 additions & 0 deletions python/paddle/v2/fluid/evaluator.py
@@ -33,6 +33,9 @@ def __init__(self, name, **kwargs):
else:
self._main_program = g_main_program

def states(self):
return self._states

def _update_ops(self, *args, **kwargs):
"""
append update ops to the global states
7 changes: 7 additions & 0 deletions python/paddle/v2/fluid/framework.py
@@ -511,6 +511,13 @@ def prune(self, targets):
res.sync_with_cpp()
return res

def inference_optimize(self):
res = Program()
res.desc = core.inference_optimize(self.desc)
res.blocks = [Block(res, i) for i in xrange(res.desc.num_blocks())]
res.sync_with_cpp()
return res

@staticmethod
def parse_from_string(binary_str):
p = Program()
19 changes: 16 additions & 3 deletions python/paddle/v2/fluid/io.py
@@ -6,7 +6,8 @@

__all__ = [
'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
'load_persistables', "save_inference_model", "load_inference_model"
'load_persistables', "save_inference_model", "load_inference_model",
"get_inference_program"
]


@@ -151,6 +152,17 @@ def load_persistables(executor, dirname, main_program=None):
predicate=is_persistable)


def get_inference_program(target_vars, main_program=None):
if main_program is None:
main_program = g_main_program
if not isinstance(target_vars, list):
target_vars = [target_vars]

pruned_program = main_program.prune(targets=target_vars)
inference_program = pruned_program.inference_optimize()
return inference_program
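For illustration, a hedged sketch of how get_inference_program is meant to be used during training, mirroring the book tests below. The evaluator's state (accumulator) variables are added to the prune targets so the ops that update them survive pruning and Evaluator.eval() still works on the test pass; the argument names here are assumptions, and only get_inference_program and Evaluator.states() come from this change:

```python
from paddle.v2.fluid.io import get_inference_program


def build_test_program(avg_cost, acc_out, accuracy_evaluator):
    """Build a test-mode program for in-training evaluation (sketch)."""
    # Keep the evaluator's accumulators alive; otherwise pruning drops the
    # ops that update them and eval() has nothing to read after the test pass.
    test_target = [avg_cost, acc_out] + accuracy_evaluator.states().values()
    return get_inference_program(test_target)
```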


def save_inference_model(dirname,
feeded_var_names,
target_vars,
@@ -177,13 +189,14 @@ def save_inference_model(dirname,
if not os.path.isdir(dirname):
os.makedirs(dirname)

pruned_program = main_program.prune(target_vars)
pruned_program = main_program.prune(targets=target_vars)
inference_program = pruned_program.inference_optimize()
fetch_var_names = [v.name for v in target_vars]

model_file_name = dirname + "/__model__"
with open(model_file_name, "w") as f:
pickle.dump({
"program_desc_str": pruned_program.desc.serialize_to_string(),
"program_desc_str": inference_program.desc.serialize_to_string(),
"feed_var_names": feeded_var_names,
"fetch_var_names": fetch_var_names
}, f, -1)
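With this change save_inference_model serializes the inference-optimized program, so the __model__ file on disk is already in test mode. For reference, a sketch of reading that file back; it relies only on the dict keys written above and on Program.parse_from_string from framework.py, and the helper name itself is made up here:

```python
import pickle

from paddle.v2.fluid.framework import Program


def load_inference_program(dirname):
    """Rebuild the saved inference program and its feed/fetch names (sketch)."""
    with open(dirname + "/__model__", "r") as f:
        model = pickle.load(f)
    program = Program.parse_from_string(model["program_desc_str"])
    return program, model["feed_var_names"], model["fetch_var_names"]
```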
@@ -5,6 +5,7 @@
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.nets as nets
import paddle.v2.fluid.evaluator as evaluator
from paddle.v2.fluid.io import get_inference_program
from paddle.v2.fluid.executor import Executor
from paddle.v2.fluid.initializer import XavierInitializer
from paddle.v2.fluid.optimizer import AdamOptimizer
@@ -116,9 +117,11 @@ def conv_block(input, num_filter, groups, dropouts):

train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.cifar.train10(), buf_size=128 * 10),
paddle.dataset.cifar.train10(), buf_size=BATCH_SIZE * 10),
batch_size=BATCH_SIZE)

test_reader = paddle.batch(paddle.dataset.cifar.test10(), batch_size=BATCH_SIZE)

place = core.CPUPlace()
exe = Executor(place)

@@ -149,10 +152,41 @@ def conv_block(input, num_filter, groups, dropouts):
loss = np.array(outs[0])
acc = np.array(outs[1])
pass_acc = accuracy.eval(exe)

batch_id = batch_id + 1

test_accuracy, test_acc_out = evaluator.accuracy(
input=predict, label=label)

test_target = [avg_cost, test_acc_out] + test_accuracy.states().values()
inference_program = get_inference_program(test_target)

test_accuracy.reset(exe)

for data in test_reader():
x_data = np.array(map(lambda x: x[0].reshape(data_shape),
data)).astype("float32")
y_data = np.array(map(lambda x: x[1], data)).astype("int64")
y_data = np.expand_dims(y_data, axis=1)

tensor_x = core.LoDTensor()
tensor_x.set(x_data, place)

tensor_y = core.LoDTensor()
tensor_y.set(y_data, place)

outs = exe.run(inference_program,
feed={'pixel': tensor_x,
'label': tensor_y},
fetch_list=[avg_cost, test_acc_out])
out = np.array(outs[0])
acc = np.array(outs[1])

test_pass_acc = test_accuracy.eval(exe)

print("pass_id:" + str(pass_id) + " batch_id:" + str(batch_id) +
" loss:" + str(loss) + " acc:" + str(acc) + " pass_acc:" + str(
pass_acc))
batch_id = batch_id + 1
pass_acc) + " test_pass_acc:" + str(test_pass_acc))

if batch_id > 1:
# this model is slow, so if we can train two mini batch, we think it works properly.
37 changes: 34 additions & 3 deletions python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py
@@ -4,6 +4,7 @@
import paddle.v2.fluid.framework as framework
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.evaluator as evaluator
from paddle.v2.fluid.io import get_inference_program
from paddle.v2.fluid.executor import Executor
from paddle.v2.fluid.initializer import UniformInitializer
from paddle.v2.fluid.optimizer import MomentumOptimizer
@@ -42,6 +43,8 @@
paddle.dataset.mnist.train(), buf_size=8192),
batch_size=BATCH_SIZE)

test_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=128)

place = core.CPUPlace()
exe = Executor(place)

@@ -69,8 +72,36 @@
acc = np.array(outs[1])
pass_acc = accuracy.eval(exe)

if pass_acc > 0.7:
test_accuracy, test_acc_out = evaluator.accuracy(
input=predict, label=label)

test_target = [avg_cost, test_acc_out] + test_accuracy.states().values()
inference_program = get_inference_program(test_target)

test_accuracy.reset(exe)
for data in test_reader():
x_data = np.array(map(lambda x: x[0], data)).astype("float32")
y_data = np.array(map(lambda x: x[1], data)).astype("int64")
y_data = np.expand_dims(y_data, axis=1)

tensor_x = core.LoDTensor()
tensor_x.set(x_data, place)

tensor_y = core.LoDTensor()
tensor_y.set(y_data, place)

outs = exe.run(inference_program,
feed={'x': tensor_x,
'y': tensor_y},
fetch_list=[avg_cost, test_acc_out])
out = np.array(outs[0])
acc = np.array(outs[1])

test_pass_acc = test_accuracy.eval(exe)
print("pass_id=" + str(pass_id) + " train_cost=" + str(
out) + " train_acc=" + str(acc) + " train_pass_acc=" + str(pass_acc)
+ " test_acc=" + str(test_pass_acc))

if test_pass_acc > 0.7:
exit(0)
# print("pass_id=" + str(pass_id) + " auc=" +
# str(acc) + " pass_acc=" + str(pass_acc))
exit(1)
10 changes: 5 additions & 5 deletions python/paddle/v2/fluid/tests/test_dropout_op.py
@@ -7,7 +7,7 @@ class TestDropoutOp(OpTest):
def setUp(self):
self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
self.attrs = {'dropout_prob': 0.0, 'is_training': True}
self.attrs = {'dropout_prob': 0.0, 'is_test': False}
self.outputs = {
'Out': self.inputs['X'],
'Mask': np.ones((32, 64)).astype('float32')
@@ -24,7 +24,7 @@ class TestDropoutOp2(TestDropoutOp):
def setUp(self):
self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
self.attrs = {'dropout_prob': 1.0, 'is_training': True}
self.attrs = {'dropout_prob': 1.0, 'is_test': False}
self.outputs = {
'Out': np.zeros((32, 64)).astype('float32'),
'Mask': np.zeros((32, 64)).astype('float32')
@@ -35,7 +35,7 @@ class TestDropoutOp3(TestDropoutOp):
def setUp(self):
self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64, 2)).astype("float32")}
self.attrs = {'dropout_prob': 0.0, 'is_training': True}
self.attrs = {'dropout_prob': 0.0, 'is_test': False}
self.outputs = {
'Out': self.inputs['X'],
'Mask': np.ones((32, 64, 2)).astype('float32')
@@ -46,7 +46,7 @@ class TestDropoutOp4(OpTest):
def setUp(self):
self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
self.attrs = {'dropout_prob': 0.35, 'is_training': False}
self.attrs = {'dropout_prob': 0.35, 'is_test': True}
self.outputs = {'Out': self.inputs['X'] * self.attrs['dropout_prob']}

def test_check_output(self):
@@ -57,7 +57,7 @@ class TestDropoutOp5(OpTest):
def setUp(self):
self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64, 3)).astype("float32")}
self.attrs = {'dropout_prob': 0.75, 'is_training': False}
self.attrs = {'dropout_prob': 0.75, 'is_test': True}
self.outputs = {'Out': self.inputs['X'] * self.attrs['dropout_prob']}

def test_check_output(self):
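For reference, a minimal NumPy sketch of the behaviour these tests encode. The kernel uses its own RNG, so only the deterministic dropout_prob values 0.0 and 1.0 are exactly reproducible, and the test-mode scaling by dropout_prob mirrors TestDropoutOp4/5 as written rather than asserting what the scaling ought to be:

```python
import numpy as np


def dropout_reference(x, dropout_prob, is_test, seed=0):
    """NumPy counterpart of the expectations in the tests above (sketch)."""
    if is_test:
        # Test mode: no mask, output scaled by dropout_prob as the tests assert.
        return x * dropout_prob, None
    rng = np.random.RandomState(seed)
    # Training mode: zero units with probability dropout_prob, no rescaling.
    mask = (rng.uniform(size=x.shape) >= dropout_prob).astype(x.dtype)
    return x * mask, mask
```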