Skip to content

Commit b6aca33

Browse files
authored
Merge pull request #764 from emailweixu/multiple_parse
Correctly handling multiple calls to parse_config()
2 parents caadbe6 + d87b2c1 commit b6aca33

File tree

4 files changed

+61
-5
lines changed

4 files changed

+61
-5
lines changed

python/paddle/trainer/config_parser.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,9 @@ def init_config_environment(
141141
g_add_submodel_suffix=False,
142142

143143
# Whether current layer needs to pass the image height and width.
144-
# Default value is true, but if it encounters recurrent_layer_group,
145-
# it will be false. The reason is that image is converted to be sequence,
146-
# image height will be sequence length, and image width will be feature
144+
# Default value is true, but if it encounters recurrent_layer_group,
145+
# it will be false. The reason is that image is converted to be sequence,
146+
# image height will be sequence length, and image width will be feature
147147
# length of each timestep.
148148
g_pass_height_width=True, ):
149149

@@ -1067,7 +1067,7 @@ def cnn_output_size(img_size, filter_size, padding, stride, caffe_mode):
10671067
return 1 + int(math.ceil(output))
10681068

10691069

1070-
#calcualte image_size based on output_size for de-convolution (ConvTransLayer).
1070+
#calcualte image_size based on output_size for de-convolution (ConvTransLayer).
10711071
#It is the reverse function of cnn_output_size
10721072
def cnn_image_size(output_size, filter_size, padding, stride, caffe_mode):
10731073
img_size = (output_size - 1) * stride + filter_size - 2 * padding
@@ -3364,13 +3364,23 @@ def my_fatal(s):
33643364
logger.critical(s)
33653365
raise Exception()
33663366

3367+
_parse_config_hooks = set()
3368+
def register_parse_config_hook(f):
3369+
"""
3370+
Register a hook function for parse_config. parse_config will invoke the hook
3371+
at the beginning of parse. This make it possible to reset global state for
3372+
for constructing the model.
3373+
"""
3374+
_parse_config_hooks.add(f)
33673375

33683376
def parse_config(config_file, config_arg_str):
33693377
'''
33703378
@param config_arg_str: a string of the form var1=val1,var2=val2. It will be
33713379
passed to config script as a dictionary CONFIG_ARGS
33723380
'''
33733381
init_config_environment()
3382+
for hook in _parse_config_hooks:
3383+
hook()
33743384

33753385
config_args = {}
33763386

python/paddle/trainer_config_helpers/default_decorators.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,17 @@ def __check_name__(self, nm):
7878
"""
7979
pass
8080

81+
def reset(self):
82+
self.__counter__ = 0
83+
84+
85+
_name_factories = []
86+
87+
def reset_hook():
88+
for factory in _name_factories:
89+
factory.reset()
90+
91+
register_parse_config_hook(reset_hook)
8192

8293
def wrap_name_default(name_prefix=None):
8394
"""
@@ -95,7 +106,9 @@ def func(name=None):
95106
:return: a decorator to set default name
96107
:rtype: callable
97108
"""
98-
return wrap_param_default(["name"], DefaultNameFactory(name_prefix))
109+
factory = DefaultNameFactory(name_prefix)
110+
_name_factories.append(factory)
111+
return wrap_param_default(["name"], factory)
99112

100113

101114
def wrap_param_attr_default(param_names=None, default_factory=None):

python/paddle/trainer_config_helpers/tests/CMakeLists.txt

+5
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@ add_test(NAME layers_test
44
python ${PROJ_ROOT}/python/paddle/trainer_config_helpers/tests/layers_test.py
55
WORKING_DIRECTORY ${PROJ_ROOT}/python/paddle)
66

7+
add_test(NAME test_reset_hook
8+
COMMAND ${PROJ_ROOT}/paddle/.set_python_path.sh -d ${PROJ_ROOT}/python/
9+
python ${PROJ_ROOT}/python/paddle/trainer_config_helpers/tests/test_reset_hook.py
10+
WORKING_DIRECTORY ${PROJ_ROOT}/python/paddle)
11+
712
if (PROTOBUF_3)
813
add_paddle_exe(protobuf_equal
914
ProtobufEqualMain.cpp)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Copyright PaddlePaddle contributors. All Rights Reserved
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import unittest
15+
from paddle.trainer.config_parser import parse_config
16+
17+
class TestParse(unittest.TestCase):
18+
19+
def test_parse(self):
20+
a = parse_config(
21+
'trainer_config_helpers/tests/layers_test_config.py', '')
22+
b = parse_config(
23+
'trainer_config_helpers/tests/layers_test_config.py', '')
24+
self.assertEqual(a, b)
25+
26+
27+
if __name__ == '__main__':
28+
unittest.main()

0 commit comments

Comments
 (0)