|
6 | 6 | width = 227
|
7 | 7 | num_class = 1000
|
8 | 8 | batch_size = get_config_arg('batch_size', int, 128)
|
| 9 | +gp = get_config_arg('layer_num', int, 1) |
| 10 | +is_infer = get_config_arg("is_infer", bool, False) |
| 11 | +num_samples = get_config_arg('num_samples', int, 2560) |
9 | 12 |
|
10 |
| -args = {'height': height, 'width': width, 'color': True, 'num_class': num_class} |
| 13 | +args = { |
| 14 | + 'height': height, |
| 15 | + 'width': width, |
| 16 | + 'color': True, |
| 17 | + 'num_class': num_class, |
| 18 | + 'is_infer': is_infer, |
| 19 | + 'num_samples': num_samples |
| 20 | +} |
11 | 21 | define_py_data_sources2(
|
12 | 22 | "train.list", None, module="provider", obj="process", args=args)
|
13 | 23 |
|
|
31 | 41 |
|
32 | 42 | # conv2
|
33 | 43 | net = img_conv_layer(
|
34 |
| - input=net, filter_size=5, num_filters=256, stride=1, padding=2, groups=1) |
| 44 | + input=net, filter_size=5, num_filters=256, stride=1, padding=2, groups=gp) |
35 | 45 | net = img_cmrnorm_layer(input=net, size=5, scale=0.0001, power=0.75)
|
36 | 46 | net = img_pool_layer(input=net, pool_size=3, stride=2)
|
37 | 47 |
|
|
40 | 50 | input=net, filter_size=3, num_filters=384, stride=1, padding=1)
|
41 | 51 | # conv4
|
42 | 52 | net = img_conv_layer(
|
43 |
| - input=net, filter_size=3, num_filters=384, stride=1, padding=1, groups=1) |
| 53 | + input=net, filter_size=3, num_filters=384, stride=1, padding=1, groups=gp) |
44 | 54 |
|
45 | 55 | # conv5
|
46 | 56 | net = img_conv_layer(
|
47 |
| - input=net, filter_size=3, num_filters=256, stride=1, padding=1, groups=1) |
| 57 | + input=net, filter_size=3, num_filters=256, stride=1, padding=1, groups=gp) |
48 | 58 | net = img_pool_layer(input=net, pool_size=3, stride=2)
|
49 | 59 |
|
50 | 60 | net = fc_layer(
|
|
59 | 69 | layer_attr=ExtraAttr(drop_rate=0.5))
|
60 | 70 | net = fc_layer(input=net, size=1000, act=SoftmaxActivation())
|
61 | 71 |
|
62 |
| -lab = data_layer('label', num_class) |
63 |
| -loss = cross_entropy(input=net, label=lab) |
64 |
| -outputs(loss) |
| 72 | +if is_infer: |
| 73 | + outputs(net) |
| 74 | +else: |
| 75 | + lab = data_layer('label', num_class) |
| 76 | + loss = cross_entropy(input=net, label=lab) |
| 77 | + outputs(loss) |
0 commit comments