Skip to content

Commit 2d51cfd

Browse files
Test 02 branch (#907)
* yifei add dnn register part for uapi * yifei add register dnn part * yifei add rec runner.py * [add function] yifei add rec runner.py * yifei delete print funcs * yifei change gpu list length from 8 to 2 * yifei add gpu list * yifei change device name * yifei change dataset part * yifei change data check funcs * yifei add uapi_rec data check part --------- Co-authored-by: wangzhen38 <41941775+wangzhen38@users.noreply.github.com>
1 parent cd3de1d commit 2d51cfd

File tree

3 files changed

+100
-28
lines changed

3 files changed

+100
-28
lines changed

uapi_rec/base/utils/stagelog.py

+90-12
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,16 @@
1414

1515
import inspect
1616
import threading
17+
import signal
18+
import sys
19+
import functools
1720

1821
from . import logging
1922
from ..flags import DEBUG
23+
from .misc import abspath
24+
25+
_CURR_STAGE_IDS = []
26+
_LOCK = threading.RLock()
2027

2128

2229
class _Singleton(type):
@@ -36,6 +43,22 @@ def __repr__(self):
3643
return "-EMPTY-"
3744

3845

46+
def stagelog_on_exit(signum, frame):
    """Signal handler: mark every recorded stage as failed, then exit.

    Args:
        signum (int): Number of the received signal. It is included in the
            failure message and used as the process exit status.
        frame: Current stack frame supplied by the `signal` machinery. Unused.
    """
    with _LOCK:
        # Walk a snapshot of the recorded ids rather than the live list.
        pending_stage_ids = list(_CURR_STAGE_IDS)
        for pending_id in pending_stage_ids:
            fail(pending_id,
                 f"Running of the stage was interrupted with signal {signum}.")
    # NOTE: We do not send signals to the subprocesses. It is the user's choice whether to
    # do this.
    # NOTE(review): the warning and exit are placed after the lock is released;
    # confirm against the file's original indentation.
    logging.warn(
        f"Received signal {signum}. The program is about to terminate.")
    sys.exit(signum)
57+
58+
59+
signal.signal(signal.SIGINT, stagelog_on_exit)
60+
signal.signal(signal.SIGTERM, stagelog_on_exit)
61+
3962
# TODO: Replacing the code that forwards arguments using `locals()` with more explicit code
4063

4164

@@ -47,7 +70,7 @@ def running_datacheck(data_path, data_type, yaml_path=_EMPTY_ARG()):
4770
yaml_path (str, optional): Absolute path of the YAML file of the dataset.
4871
"""
4972

50-
return _stagelog_call('running_datacheck', **locals())
73+
return _stagelog_running_call('running_datacheck', **locals())
5174

5275

5376
def running_train(learning_rate, epoch_iters, batch_size, data_path, yaml_path,
@@ -63,7 +86,7 @@ def running_train(learning_rate, epoch_iters, batch_size, data_path, yaml_path,
6386
save_dir (str): Directory that contains model snapshots and logs.
6487
"""
6588

66-
return _stagelog_call('running_train', **locals())
89+
return _stagelog_running_call('running_train', **locals())
6790

6891

6992
def running_verify(checkpoint_dir,
@@ -77,7 +100,7 @@ def running_verify(checkpoint_dir,
77100
metrics (list[str]): A list of names of all metrics used in model evaluation.
78101
save_dir (str, optional): Directory that contains model snapshots and logs.
79102
"""
80-
return _stagelog_call('running_verify', **locals())
103+
return _stagelog_running_call('running_verify', **locals())
81104

82105

83106
def running_compress(checkpoint_dir,
@@ -100,19 +123,19 @@ def running_compress(checkpoint_dir,
100123
metrics (list[str], optional): A list of names of all metrics used in model evaluation.
101124
"""
102125

103-
return _stagelog_call('running_compress', **locals())
126+
return _stagelog_running_call('running_compress', **locals())
104127

105128

106129
def running_deploy(operating_system, language, architecture, accelerator):
107-
return _stagelog_call('running_deploy', **locals())
130+
return _stagelog_running_call('running_deploy', **locals())
108131

109132

110133
def success(stage_id, result=_EMPTY_ARG()):
111-
return _stagelog_call('success', **locals())
134+
return _stagelog_status_call('success', **locals())
112135

113136

114137
def fail(stage_id, error_message=_EMPTY_ARG()):
115-
return _stagelog_call('fail', **locals())
138+
return _stagelog_status_call('fail', **locals())
116139

117140

118141
def success_datacheck(stage_id, train_dataset, validation_dataset,
@@ -124,7 +147,7 @@ def success_datacheck(stage_id, train_dataset, validation_dataset,
124147
validation_dataset (int): Number of samples in validation dataset.
125148
test_dataset (int): Number of samples in test dataset.
126149
"""
127-
return _stagelog_call('success_datacheck', **locals())
150+
return _stagelog_status_call('success_datacheck', **locals())
128151

129152

130153
def _stagelog_call(func_name, *args, **kwargs):
@@ -151,9 +174,39 @@ def _stagelog_call(func_name, *args, **kwargs):
151174
if DEBUG:
152175
logging.warn("stagelog not initialized.")
153176
else:
177+
# Ignore
154178
pass
155179

156180

181+
def _stagelog_running_call(func_name, **kwargs):
    """Forward a `running_*` call to `_stagelog_call` and record its stage id.

    Args:
        func_name (str): Name of the stagelog function to invoke.
        **kwargs: Keyword arguments forwarded unchanged to `_stagelog_call`.

    Returns:
        Whatever `_stagelog_call` returned. A non-None value is treated as a
        stage id and appended to `_CURR_STAGE_IDS`.
    """
    # FIXME: It is possible that the program gets interrupted when `_stagelog_call`
    # is called and the stage id is not yet recorded.
    stage_id = _stagelog_call(func_name, **kwargs)
    if stage_id is None:
        # Nothing to track.
        return stage_id
    with _LOCK:
        # A non-None result is a stage id.
        _CURR_STAGE_IDS.append(stage_id)
    return stage_id
191+
192+
193+
def _stagelog_status_call(func_name, **kwargs):
    """Forward a status call only for a stage id that is still recorded.

    The stage id is removed from `_CURR_STAGE_IDS` before forwarding, so a
    second status call for the same stage finds nothing and is dropped.

    Args:
        func_name (str): Name of the stagelog function to invoke.
        **kwargs: Keyword arguments forwarded to `_stagelog_call`. Must
            contain 'stage_id'.

    Returns:
        The result of `_stagelog_call` when the stage id was recorded,
        otherwise None.
    """
    stage_id = kwargs['stage_id']
    with _LOCK:
        # Both the membership test and the removal are O(N) list scans.
        is_tracked = stage_id in _CURR_STAGE_IDS
        if is_tracked:
            _CURR_STAGE_IDS.remove(stage_id)
    if not is_tracked:
        return None
    # Each stage makes no more than one status call
    # NOTE(review): forwarding happens after the lock is released; confirm
    # against the file's original indentation.
    return _stagelog_call(func_name, **kwargs)
208+
209+
157210
class _StageLogContextManagerMeta(type):
158211
def __new__(mcls, name, bases, attrs):
159212
cls = super().__new__(mcls, name, bases, attrs)
@@ -185,7 +238,7 @@ def __enter__(self):
185238
return self
186239

187240
def __exit__(self, exc_type, exc_val, exc_tb):
188-
self._log_status(self.stage_id, exc_type, exc_val, exc_tb)
241+
return self._log_status(self.stage_id, exc_type, exc_val, exc_tb)
189242

190243
def _log_running(self):
191244
return type(self)._STAGELOG_API(*self.running_args,
@@ -205,7 +258,7 @@ def _log_status(self, stage_id, exc_type, exc_val, exc_tb):
205258
if exc_type is None and exc_val is None and exc_tb is None:
206259
success(stage_id)
207260
else:
208-
fail(stage_id, str(exc_val))
261+
fail(stage_id, f"{exc_type.__name__}: {str(exc_val)}")
209262

210263

211264
class StageLogCompress(_StageLogContextManager):
@@ -218,7 +271,7 @@ def _log_status(self, stage_id, exc_type, exc_val, exc_tb):
218271
if exc_type is None and exc_val is None and exc_tb is None:
219272
success(stage_id)
220273
else:
221-
fail(stage_id, str(exc_val))
274+
fail(stage_id, f"{exc_type.__name__}: {str(exc_val)}")
222275

223276

224277
class StageLogEvaluate(_StageLogContextManager):
@@ -244,4 +297,29 @@ def _log_status(self, stage_id, exc_type, exc_val, exc_tb):
244297
# the error message.
245298
fail(stage_id, "Model evaluation failed.")
246299
else:
247-
return False
300+
fail(stage_id, f"{exc_type.__name__}: {str(exc_val)}")
301+
302+
303+
def stagelog_check_dataset(dataset_checker):
    """Decorate a dataset checker so its outcome is reported through stagelog.

    Args:
        dataset_checker (callable): Function called as
            `dataset_checker(dataset_dir, dataset_type)`. It must return a
            mapping with a 'res_flag' key, plus 'train.samples' when the flag
            is truthy or 'err_msg' when it is falsy.

    Returns:
        callable: Wrapper with the signature `(dataset_dir, dataset_type)`
        that returns the checker's result unchanged.
    """

    @functools.wraps(dataset_checker)
    def _wrapper(dataset_dir, dataset_type):
        # Open the data-check stage before the checker runs.
        stage_id = running_datacheck(abspath(dataset_dir), dataset_type)
        try:
            res = dataset_checker(dataset_dir, dataset_type)
        except BaseException as e:
            # We catch all exceptions including `KeyboardInterrupt`
            fail(stage_id, f"{type(e).__name__}: {str(e)}")
            raise

        if res['res_flag']:
            # Only the training sample count is reported; validation and test
            # counts are sent as 0. The previously computed but unused
            # `test_samples` local has been removed.
            success_datacheck(
                stage_id,
                train_dataset=res['train.samples'],
                validation_dataset=0,
                test_dataset=0)
        else:
            fail(stage_id, res['err_msg'])
        return res

    return _wrapper

uapi_rec/rank/check_dataset.py

+6-14
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,19 @@
2323
persist_dataset_meta, build_res_dict, CheckFailedError,
2424
UnsupportedDatasetTypeError, DatasetFileNotFoundError)
2525

26-
from base.utils import stagelog
26+
from base.utils.stagelog import stagelog_check_dataset
27+
from base.utils.misc import abspath
2728

2829
MAX_V = 18446744073709551615
2930

3031

31-
@persist_dataset_meta
32-
def check_dataset(model_name, dataset_dir, dataset_type):
33-
stage_id = stagelog.running_datacheck(
34-
data_path=dataset_dir, data_type=dataset_type)
32+
@stagelog_check_dataset
33+
def check_dataset(dataset_dir, dataset_type):
34+
dataset_dir = abspath(dataset_dir)
3535
try:
3636
if dataset_type == 'Dataset':
3737
# Custom dataset
38-
dataset_dir = osp.abspath(dataset_dir)
38+
#dataset_dir = osp.abspath(dataset_dir)
3939
if not osp.exists(dataset_dir) or not osp.isdir(dataset_dir):
4040
raise DatasetFileNotFoundError(file_path=dataset_dir)
4141

@@ -80,14 +80,6 @@ def check_dataset(model_name, dataset_dir, dataset_type):
8080
else:
8181
raise UnsupportedDatasetTypeError(dataset_type=dataset_type)
8282
except CheckFailedError as e:
83-
stagelog.fail(stage_id, str(e))
8483
return build_res_dict(False, err_type=type(e), err_msg=str(e))
8584
else:
86-
stagelog.success_datacheck(
87-
stage_id,
88-
train_dataset=meta['train.samples'],
89-
validation_dataset=0,
90-
test_dataset=0)
91-
#validation_dataset=meta['val.samples'],
92-
#test_dataset=meta['test.samples'] or 0)
9385
return meta

uapi_rec/test_uapi/test_stagelog.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,12 @@ def close_stagelog():
4646
init_stagelog()
4747

4848
# Check dataset + success
49-
check_dataset(model_name, dataset_dir, dataset_type='Dataset')
49+
#check_dataset(model_name, dataset_dir, dataset_type='Dataset')
50+
check_dataset(dataset_dir, dataset_type='Dataset')
5051

5152
# Check dataset + failure
52-
check_dataset(model_name, dataset_dir, dataset_type='baidu')
53+
#check_dataset(model_name, dataset_dir, dataset_type='baidu')
54+
check_dataset(dataset_dir, dataset_type='baidu')
5355

5456
config = Config(model_name)
5557
config.update_dataset(dataset_dir)

0 commit comments

Comments
 (0)