Benchmark/mnist #5680

Closed · wants to merge 16 commits
129 changes: 129 additions & 0 deletions benchmark/tensorflow/image/refactor_mnist.py
@@ -0,0 +1,129 @@
import tensorflow.python.platform
import tensorflow as tf
import paddle.v2 as paddle
import paddle.v2.fluid.framework as framework
import paddle.v2.fluid.initializer as initializer
import paddle.v2.fluid.core as core
from paddle.v2.fluid.executor import Executor
import numpy as np
import time

BATCH_SIZE = 128
PASS_NUM = 5
SEED = 1
DTYPE = tf.float32


def normal_scale(size, channels):
scale = (2.0 / (size**2 * channels))**0.5
return scale


# NOTE(dzhwinter): TensorFlow uses the Philox algorithm for its normal
# random generator, so fetch Paddle's random values for comparison.
def paddle_random_normal(shape, loc=.0, scale=1., seed=1, dtype="float32"):
    program = framework.Program()
    block = program.global_block()
    # Create a parameter initialized by Paddle's NormalInitializer so that the
    # same values can be fed into the TensorFlow graph below.
    w = block.create_var(
        dtype="float32",
        shape=shape,
        lod_level=0,
        name="param",
        initializer=initializer.NormalInitializer(
            loc=loc, scale=scale, seed=seed))
    place = core.CPUPlace()
    exe = Executor(place)
    out = exe.run(program, fetch_list=[w])
    return np.array(out[0])


train_reader = paddle.batch(paddle.dataset.mnist.train(), batch_size=BATCH_SIZE)
images = tf.placeholder(DTYPE, shape=(None, 28, 28, 1))
labels = tf.placeholder(tf.int64, shape=(None, ))

# conv layer
arg = tf.convert_to_tensor(
np.transpose(
paddle_random_normal(
[20, 1, 5, 5], scale=normal_scale(5, 1), seed=SEED, dtype=DTYPE),
axes=[2, 3, 1, 0]))
conv1_weights = tf.Variable(arg)
conv1_bias = tf.Variable(tf.zeros([20]), dtype=DTYPE)
conv1 = tf.nn.conv2d(
images, conv1_weights, strides=[1, 1, 1, 1], padding="VALID")
relu1 = tf.nn.relu(tf.nn.bias_add(conv1, conv1_bias))
pool1 = tf.nn.max_pool(
relu1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="VALID")

arg = tf.convert_to_tensor(
np.transpose(
paddle_random_normal(
[50, 20, 5, 5], scale=normal_scale(5, 20), seed=SEED, dtype=DTYPE),
axes=[2, 3, 1, 0]))
conv2_weights = tf.Variable(arg)
conv2_bias = tf.Variable(tf.zeros([50]), dtype=DTYPE)
conv2 = tf.nn.conv2d(
pool1, conv2_weights, strides=[1, 1, 1, 1], padding="VALID")
relu2 = tf.nn.relu(tf.nn.bias_add(conv2, conv2_bias))
pool2 = tf.nn.max_pool(
relu2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="VALID")

pool_shape = pool2.get_shape().as_list()
hidden_dim = reduce(lambda a, b: a * b, pool_shape[1:], 1)
reshape = tf.reshape(pool2, shape=(tf.shape(pool2)[0], hidden_dim))

# fc layer
# NOTE(dzhwinter): Paddle uses an NCHW data layout while TensorFlow uses NHWC,
# so the fc weight has to be reshaped and transposed before it can be reused
# (see the layout sketch after this file).
paddle_weight = paddle_random_normal(
[hidden_dim, 10],
scale=normal_scale(hidden_dim, 10),
seed=SEED,
dtype=DTYPE)
new_shape = pool_shape[-1:] + pool_shape[1:-1] + [10]
paddle_weight = np.reshape(paddle_weight, new_shape)
paddle_weight = np.transpose(paddle_weight, [1, 2, 0, 3])

arg = tf.convert_to_tensor(np.reshape(paddle_weight, [hidden_dim, 10]))
fc_weights = tf.Variable(arg, dtype=DTYPE)
fc_bias = tf.Variable(tf.zeros([10]), dtype=DTYPE)
logits = tf.matmul(reshape, fc_weights) + fc_bias

# cross entropy

prediction = tf.nn.softmax(logits)

one_hot_labels = tf.one_hot(labels, depth=10)
cost = -tf.reduce_sum(tf.log(prediction) * one_hot_labels, [1])
avg_cost = tf.reduce_mean(cost)

correct = tf.equal(tf.argmax(prediction, 1), labels)
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
g_accuracy = tf.metrics.accuracy(labels, tf.argmax(prediction, axis=1))

optimizer = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.9, beta2=0.999)
train_op = optimizer.minimize(avg_cost)

with tf.Session() as sess:
    init_g = tf.global_variables_initializer()
    init_l = tf.local_variables_initializer()
    sess.run(init_g)
    sess.run(init_l)
    pass_start = time.clock()
    for pass_id in range(PASS_NUM):
        for batch_id, data in enumerate(train_reader()):
            images_data = np.array(
                map(lambda x: np.transpose(x[0].reshape([1, 28, 28]), axes=[1,2,0]),
                    data)).astype("float32")
            labels_data = np.array(map(lambda x: x[1], data)).astype("int64")
            start = time.clock()
            _, loss, acc, g_acc = sess.run(
                [train_op, avg_cost, accuracy, g_accuracy],
                feed_dict={images: images_data,
                           labels: labels_data})
            end = time.clock()
            # print g_acc

            print "pass=%d, batch=%d, loss=%f, error=%f, elapse=%f" % (
                pass_id, batch_id, loss, 1 - acc, (end - start) / 1000)
        print "pass=%d, accuracy=%f, elapse=%f" % (pass_id, g_acc[0], (
            time.clock() - pass_start) / 1000)
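
A standalone sketch (not part of the diff) of the layout conversions the script above relies on: Paddle generates parameters for an NCHW network while the TensorFlow graph is NHWC. The shapes below are taken from the 28x28 MNIST network above (20 and 50 filters of size 5x5, flattened feature size 4 * 4 * 50 = 800).

import numpy as np

# conv filters: Paddle stores [out_c, in_c, kh, kw]; tf.nn.conv2d expects
# [kh, kw, in_c, out_c], hence the transpose with axes=[2, 3, 1, 0].
paddle_filter = np.random.rand(20, 1, 5, 5).astype("float32")
tf_filter = np.transpose(paddle_filter, axes=[2, 3, 1, 0])
assert tf_filter.shape == (5, 5, 1, 20)

# fc weight: a [hidden_dim, 10] weight generated for NCHW-flattened features
# is reshaped to [C, H, W, 10], permuted to [H, W, C, 10], and flattened back
# so that it lines up with the NHWC-flattened pool2 tensor.
hidden_dim = 4 * 4 * 50
w_nchw = np.random.rand(hidden_dim, 10).astype("float32")
w = np.reshape(w_nchw, [50, 4, 4, 10])
w = np.transpose(w, [1, 2, 0, 3])
w_nhwc = np.reshape(w, [hidden_dim, 10])
assert w_nhwc.shape == (hidden_dim, 10)
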
8 changes: 4 additions & 4 deletions python/paddle/v2/fluid/initializer.py
@@ -104,7 +104,7 @@ class UniformInitializer(Initializer):
"""Implements the random uniform distribution initializer
"""

def __init__(self, low=-1.0, high=1.0, seed=0):
def __init__(self, low=-1.0, high=1.0, seed=1):
"""Constructor for UniformInitializer

Args:
@@ -153,7 +153,7 @@ class NormalInitializer(Initializer):
"""Implements the random Normal(Gaussian) distribution initializer
"""

def __init__(self, loc=0.0, scale=1.0, seed=0):
def __init__(self, loc=0.0, scale=1.0, seed=1):
"""Constructor for NormalInitializer

Args:
@@ -217,7 +217,7 @@ class XavierInitializer(Initializer):
(http://proceedings.mlr.press/v9/glorot10a.html)
"""

def __init__(self, uniform=True, fan_in=None, fan_out=None, seed=0):
def __init__(self, uniform=True, fan_in=None, fan_out=None, seed=1):
"""Constructor for XavierInitializer

Args:
@@ -305,7 +305,7 @@ class MSRAInitializer(Initializer):
(https://arxiv.org/abs/1502.01852)
"""

def __init__(self, uniform=True, fan_in=None, seed=0):
def __init__(self, uniform=True, fan_in=None, seed=1):
"""Constructor for MSRAInitializer

Args:
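
The default seed flips from 0 to 1 in each initializer; the benchmark relies on the initializers being reproducible so that the same initial weights can be handed to TensorFlow. A minimal sketch of that assumption, reusing paddle_random_normal from the benchmark script above: two draws with the same non-zero seed should fetch identical values.

a = paddle_random_normal([2, 3], scale=0.1, seed=1)
b = paddle_random_normal([2, 3], scale=0.1, seed=1)
assert np.allclose(a, b)
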
7 changes: 4 additions & 3 deletions python/paddle/v2/fluid/layers.py
@@ -79,6 +79,7 @@ def _get_default_bias_initializer():
param_shape = [
reduce(lambda a, b: a * b, input_shape[num_flatten_dims:], 1)
] + [size]

w = helper.create_parameter(
attr=param_attr,
initializer=param_initializer,
@@ -104,9 +105,9 @@ def _get_default_bias_initializer():
helper.append_op(
type="sum", inputs={"X": mul_results}, outputs={"Out": pre_bias})
# add bias
pre_activation = helper.append_bias_op(pre_bias, bias_initializer)
pre_act = helper.append_bias_op(pre_bias, bias_initializer)
# add activation
return helper.append_activation(pre_activation)
return helper.append_activation(pre_act)


def embedding(input,
@@ -726,7 +727,7 @@ def _get_default_bias_initializer():

def _get_default_param_initializer(filter_size, num_channels):
std = (2.0 / (filter_size[0]**2 * num_channels))**0.5
return NormalInitializer(0.0, std, 0)
return NormalInitializer(0.0, std, 1)

helper = LayerHelper('conv2d', **locals())
dtype = helper.input_dtype()
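
The conv2d default initializer above mirrors normal_scale in the TensorFlow benchmark, std = sqrt(2 / (filter_size^2 * num_channels)). A quick worked example for the two conv layers used in the MNIST scripts:

std_conv1 = (2.0 / (5 ** 2 * 1)) ** 0.5    # 5x5 filter, 1 input channel   -> ~0.2828
std_conv2 = (2.0 / (5 ** 2 * 20)) ** 0.5   # 5x5 filter, 20 input channels -> ~0.0632
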
54 changes: 35 additions & 19 deletions python/paddle/v2/fluid/tests/book/test_recognize_digits_conv.py
@@ -7,8 +7,16 @@
import paddle.v2.fluid.nets as nets
from paddle.v2.fluid.executor import Executor
from paddle.v2.fluid.optimizer import AdamOptimizer
from paddle.v2.fluid.initializer import NormalInitializer
import numpy as np
import time

BATCH_SIZE = 128
PASS_NUM = 5
SEED = 1
DTYPE = "float32"

images = layers.data(name='pixel', shape=[1, 28, 28], dtype='float32')
images = layers.data(name='pixel', shape=[1, 28, 28], dtype=DTYPE)
label = layers.data(name='label', shape=[1], dtype='int64')
conv_pool_1 = nets.simple_img_conv_pool(
input=images,
@@ -25,20 +33,26 @@
pool_stride=2,
act="relu")

predict = layers.fc(input=conv_pool_2, size=10, act="softmax")
# TODO(dzhwinter): refine the initializer and random seed setting
SIZE = 10
input_shape = conv_pool_2.shape
param_shape = [reduce(lambda a, b: a * b, input_shape[1:], 1)] + [SIZE]
scale = (2.0 / (param_shape[0]**2 * SIZE))**0.5

predict = layers.fc(input=conv_pool_2,
size=SIZE,
act="softmax",
param_initializer=NormalInitializer(
loc=0.0, scale=scale, seed=SEED))

cost = layers.cross_entropy(input=predict, label=label)
avg_cost = layers.mean(x=cost)
optimizer = AdamOptimizer(learning_rate=0.01, beta1=0.9, beta2=0.999)
optimizer = AdamOptimizer(learning_rate=0.001, beta1=0.9, beta2=0.999)
opts = optimizer.minimize(avg_cost)

accuracy, acc_out = evaluator.accuracy(input=predict, label=label)

BATCH_SIZE = 50
PASS_NUM = 3
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=500),
batch_size=BATCH_SIZE)
train_reader = paddle.batch(paddle.dataset.mnist.train(), batch_size=BATCH_SIZE)

place = core.CPUPlace()
exe = Executor(place)
@@ -47,32 +61,34 @@

for pass_id in range(PASS_NUM):
accuracy.reset(exe)
for data in train_reader():
pass_start = time.clock()
for batch_id, data in enumerate(train_reader()):
img_data = np.array(map(lambda x: x[0].reshape([1, 28, 28]),
data)).astype("float32")
data)).astype(DTYPE)
y_data = np.array(map(lambda x: x[1], data)).astype("int64")
y_data = y_data.reshape([BATCH_SIZE, 1])
y_data = y_data.reshape([len(y_data), 1])

tensor_img = core.LoDTensor()
tensor_y = core.LoDTensor()
tensor_img.set(img_data, place)
tensor_y.set(y_data, place)

start = time.clock()
outs = exe.run(framework.default_main_program(),
feed={"pixel": tensor_img,
"label": tensor_y},
fetch_list=[avg_cost, acc_out])
end = time.clock()
loss = np.array(outs[0])
acc = np.array(outs[1])
pass_acc = accuracy.eval(exe)
print("pass_id=" + str(pass_id) + " acc=" + str(acc) + " pass_acc=" +
str(pass_acc))
# print loss, acc
if loss < 10.0 and pass_acc > 0.9:
print "pass=%d, batch=%d, loss=%f, error=%f, elapse=%f" % (
pass_id, batch_id, loss, 1 - acc, (end - start) / 1000)

if loss < 10.0 and acc > 0.9:
# if the avg cost is less than 10.0 and the accuracy is larger than 0.9, we consider the training good.
exit(0)

pass_acc = accuracy.eval(exe)
print("pass_id=" + str(pass_id) + " pass_acc=" + str(pass_acc))

print "pass=%d, accuracy=%f, elapse=%f" % (pass_id, pass_acc, (
time.clock() - pass_start) / 1000)
exit(1)
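
For reference, a rough sketch of what the fc initializer scale above evaluates to, assuming conv_pool_2 comes out as NCHW [N, 50, 4, 4] for 28x28 MNIST inputs (the same 800-feature flatten as in the TensorFlow script):

SIZE = 10
param_shape = [50 * 4 * 4] + [SIZE]                   # [800, 10]
scale = (2.0 / (param_shape[0] ** 2 * SIZE)) ** 0.5   # ~5.6e-4
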
20 changes: 10 additions & 10 deletions python/paddle/v2/fluid/tests/test_initializer.py
@@ -58,7 +58,7 @@ def test_uniform_initializer_default_value(self):
self.assertEqual(init_op.type, 'uniform_random')
self.assertAlmostEqual(init_op.attr('min'), -1.0, delta=DELTA)
self.assertAlmostEqual(init_op.attr('max'), 1.0, delta=DELTA)
self.assertEqual(init_op.attr('seed'), 0)
self.assertEqual(init_op.attr('seed'), 1)

def test_uniform_initializer(self):
"""Test uniform initializer with supplied attributes
@@ -96,7 +96,7 @@ def test_normal_initializer_default_value(self):
self.assertEqual(init_op.type, 'gaussian_random')
self.assertAlmostEqual(init_op.attr('mean'), 0.0, delta=DELTA)
self.assertAlmostEqual(init_op.attr('std'), 1.0, delta=DELTA)
self.assertEqual(init_op.attr('seed'), 0)
self.assertEqual(init_op.attr('seed'), 1)

def test_normal_initializer(self):
"""Test normal initializer with supplied attributes
@@ -136,7 +136,7 @@ def test_uniform_xavier_initializer(self):
limit = np.sqrt(6.0 / (param.shape[0] + param.shape[1]))
self.assertAlmostEqual(init_op.attr('min'), -limit, delta=DELTA)
self.assertAlmostEqual(init_op.attr('max'), limit, delta=DELTA)
self.assertEqual(init_op.attr('seed'), 0)
self.assertEqual(init_op.attr('seed'), 1)

def test_uniform_xavier_initializer_conv(self):
"""Test Xavier initializer with uniform distribution on
@@ -158,7 +158,7 @@ def test_uniform_xavier_initializer_conv(self):
(param.shape[0] + param.shape[1]) * receptive_field_size))
self.assertAlmostEqual(init_op.attr('min'), -limit, delta=DELTA)
self.assertAlmostEqual(init_op.attr('max'), limit, delta=DELTA)
self.assertEqual(init_op.attr('seed'), 0)
self.assertEqual(init_op.attr('seed'), 1)

def test_normal_xavier_initializer(self):
"""Test Xavier initializer with normal distribution on
@@ -178,7 +178,7 @@ def test_normal_xavier_initializer(self):
std = np.sqrt(2.0 / (param.shape[0] + param.shape[1]))
self.assertAlmostEqual(init_op.attr('mean'), 0.0, delta=DELTA)
self.assertAlmostEqual(init_op.attr('std'), std, delta=DELTA)
self.assertEqual(init_op.attr('seed'), 0)
self.assertEqual(init_op.attr('seed'), 1)

def test_normal_xavier_initializer_conv(self):
"""Test Xavier initializer with normal distribution on
@@ -200,7 +200,7 @@ def test_normal_xavier_initializer_conv(self):
(param.shape[0] + param.shape[1]) * receptive_field_size))
self.assertAlmostEqual(init_op.attr('mean'), 0.0, delta=DELTA)
self.assertAlmostEqual(init_op.attr('std'), std, delta=DELTA)
self.assertEqual(init_op.attr('seed'), 0)
self.assertEqual(init_op.attr('seed'), 1)

def test_xavier_initializer_supplied_arguments(self):
"""Test the Xavier initializer with supplied arguments
@@ -242,7 +242,7 @@ def test_uniform_msra_initializer(self):
limit = np.sqrt(6.0 / param.shape[0])
self.assertAlmostEqual(init_op.attr('min'), -limit, delta=DELTA)
self.assertAlmostEqual(init_op.attr('max'), limit, delta=DELTA)
self.assertEqual(init_op.attr('seed'), 0)
self.assertEqual(init_op.attr('seed'), 1)

def test_uniform_msra_initializer_conv(self):
"""Test MSRA initializer with uniform distribution on
@@ -263,7 +263,7 @@ def test_uniform_msra_initializer_conv(self):
limit = np.sqrt(6.0 / (param.shape[1] * receptive_field_size))
self.assertAlmostEqual(init_op.attr('min'), -limit, delta=DELTA)
self.assertAlmostEqual(init_op.attr('max'), limit, delta=DELTA)
self.assertEqual(init_op.attr('seed'), 0)
self.assertEqual(init_op.attr('seed'), 1)

def test_normal_msra_initializer(self):
"""Test MSRA initializer with normal distribution on
@@ -283,7 +283,7 @@ def test_normal_msra_initializer(self):
std = np.sqrt(2.0 / param.shape[0])
self.assertAlmostEqual(init_op.attr('mean'), 0.0, delta=DELTA)
self.assertAlmostEqual(init_op.attr('std'), std, delta=DELTA)
self.assertEqual(init_op.attr('seed'), 0)
self.assertEqual(init_op.attr('seed'), 1)

def test_normal_msra_initializer_conv(self):
"""Test MSRA initializer with normal distribution on
@@ -304,7 +304,7 @@ def test_normal_msra_initializer_conv(self):
std = np.sqrt(2.0 / (param.shape[1] * receptive_field_size))
self.assertAlmostEqual(init_op.attr('mean'), 0.0, delta=DELTA)
self.assertAlmostEqual(init_op.attr('std'), std, delta=DELTA)
self.assertEqual(init_op.attr('seed'), 0)
self.assertEqual(init_op.attr('seed'), 1)

def test_msra_initializer_supplied_arguments(self):
"""Test the MSRA initializer with supplied arguments