14
14
15
15
import inspect
16
16
import threading
17
+ import signal
18
+ import sys
19
+ import functools
17
20
18
21
from . import logging
19
22
from ..flags import DEBUG
23
+ from .misc import abspath
24
+
25
+ _CURR_STAGE_IDS = []
26
+ _LOCK = threading .RLock ()
20
27
21
28
22
29
class _Singleton (type ):
@@ -36,6 +43,22 @@ def __repr__(self):
36
43
return "-EMPTY-"
37
44
38
45
46
def stagelog_on_exit(signum, frame):
    """Signal handler: mark every in-flight stage as failed, then exit.

    Registered below for SIGINT and SIGTERM so that any stage id still
    tracked in ``_CURR_STAGE_IDS`` gets a `fail` record before the
    process terminates.
    """
    with _LOCK:
        # Iterate over a snapshot: `fail` removes each id from the list
        # (the RLock is re-entrant, so the nested acquire is safe).
        for sid in list(_CURR_STAGE_IDS):
            fail(sid,
                 f"Running of the stage was interrupted with signal {signum}.")
    # NOTE: We do not send signals to the subprocesses. It is the user's
    # choice whether to do this.
    logging.warn(
        f"Received signal {signum}. The program is about to terminate.")
    sys.exit(signum)


signal.signal(signal.SIGINT, stagelog_on_exit)
signal.signal(signal.SIGTERM, stagelog_on_exit)
39
62
# TODO: Replacing the code that forwards arguments using `locals()` with more explicit code
40
63
41
64
@@ -47,7 +70,7 @@ def running_datacheck(data_path, data_type, yaml_path=_EMPTY_ARG()):
47
70
yaml_path (str, optinoal): Absolute path of the YAML file of the dataset.
48
71
"""
49
72
50
- return _stagelog_call ('running_datacheck' , ** locals ())
73
+ return _stagelog_running_call ('running_datacheck' , ** locals ())
51
74
52
75
53
76
def running_train (learning_rate , epoch_iters , batch_size , data_path , yaml_path ,
@@ -63,7 +86,7 @@ def running_train(learning_rate, epoch_iters, batch_size, data_path, yaml_path,
63
86
save_dir (str): Directory that contains model snapshots and logs.
64
87
"""
65
88
66
- return _stagelog_call ('running_train' , ** locals ())
89
+ return _stagelog_running_call ('running_train' , ** locals ())
67
90
68
91
69
92
def running_verify (checkpoint_dir ,
@@ -77,7 +100,7 @@ def running_verify(checkpoint_dir,
77
100
metrics (list[str]): A list of names of all metrics used in model evaluation.
78
101
save_dir (str, optional): Directory that contains model snapshots and logs.
79
102
"""
80
- return _stagelog_call ('running_verify' , ** locals ())
103
+ return _stagelog_running_call ('running_verify' , ** locals ())
81
104
82
105
83
106
def running_compress (checkpoint_dir ,
@@ -100,19 +123,19 @@ def running_compress(checkpoint_dir,
100
123
metrics (list[str], optional): A list of names of all metrics used in model evaluation.
101
124
"""
102
125
103
- return _stagelog_call ('running_compress' , ** locals ())
126
+ return _stagelog_running_call ('running_compress' , ** locals ())
104
127
105
128
106
129
def running_deploy(operating_system, language, architecture, accelerator):
    """Log that a deploy stage has started; returns the stage id.

    NOTE(review): parameter semantics are undocumented here — presumably
    they describe the deployment target platform; confirm against the
    stagelog backend before relying on this.
    """
    # `**locals()` forwards all four parameters as keyword arguments; see
    # the module-level TODO about replacing this forwarding scheme.
    return _stagelog_running_call('running_deploy', **locals())
108
131
109
132
110
133
def success(stage_id, result=_EMPTY_ARG()):
    """Report that the stage identified by `stage_id` finished successfully."""
    # Forward the arguments explicitly by keyword (equivalent to the
    # `**locals()` pattern used elsewhere in this module).
    return _stagelog_status_call('success', stage_id=stage_id, result=result)
112
135
113
136
114
137
def fail(stage_id, error_message=_EMPTY_ARG()):
    """Report that the stage identified by `stage_id` failed."""
    # Forward the arguments explicitly by keyword (equivalent to the
    # `**locals()` pattern used elsewhere in this module).
    return _stagelog_status_call('fail', stage_id=stage_id,
                                 error_message=error_message)
116
139
117
140
118
141
def success_datacheck (stage_id , train_dataset , validation_dataset ,
@@ -124,7 +147,7 @@ def success_datacheck(stage_id, train_dataset, validation_dataset,
124
147
validation_dataset (int): Number of samples in validation dataset.
125
148
test_dataset (int): Number of samples in test dataset.
126
149
"""
127
- return _stagelog_call ('success_datacheck' , ** locals ())
150
+ return _stagelog_status_call ('success_datacheck' , ** locals ())
128
151
129
152
130
153
def _stagelog_call (func_name , * args , ** kwargs ):
@@ -151,9 +174,39 @@ def _stagelog_call(func_name, *args, **kwargs):
151
174
if DEBUG :
152
175
logging .warn ("stagelog not initialized." )
153
176
else :
177
+ # Ignore
154
178
pass
155
179
156
180
181
def _stagelog_running_call(func_name, **kwargs):
    """Forward a `running_*` call and remember the returned stage id.

    A non-None return value from `_stagelog_call` is a stage id; it is
    appended to the module-level `_CURR_STAGE_IDS` under `_LOCK` so the
    exit handler and status calls can find it later.
    """
    # FIXME: It is possible that the program gets interrupted when
    # `_stagelog_call` is called and the stage id is not yet recorded.
    stage_id = _stagelog_call(func_name, **kwargs)
    if stage_id is None:
        return None
    with _LOCK:
        _CURR_STAGE_IDS.append(stage_id)
    return stage_id
+
192
+
193
def _stagelog_status_call(func_name, **kwargs):
    """Forward a status call (`success`/`fail`) at most once per stage.

    The call reaches `_stagelog_call` only if `kwargs['stage_id']` is still
    tracked in `_CURR_STAGE_IDS`; the id is removed first, which guarantees
    each stage makes no more than one status call.
    """
    stage_id = kwargs['stage_id']
    with _LOCK:
        try:
            # O(N) scan-and-remove — same cost as the index()/pop() pair.
            _CURR_STAGE_IDS.remove(stage_id)
        except ValueError:
            # Unknown or already-reported stage id: silently ignore.
            return None
    # Deliberately called outside the lock, as in the original flow.
    return _stagelog_call(func_name, **kwargs)
209
+
157
210
class _StageLogContextManagerMeta (type ):
158
211
def __new__ (mcls , name , bases , attrs ):
159
212
cls = super ().__new__ (mcls , name , bases , attrs )
@@ -185,7 +238,7 @@ def __enter__(self):
185
238
return self
186
239
187
240
def __exit__(self, exc_type, exc_val, exc_tb):
    # Delegate to the subclass-specific status logger; per the context
    # manager protocol, a truthy return value suppresses the exception.
    outcome = self._log_status(self.stage_id, exc_type, exc_val, exc_tb)
    return outcome
189
242
190
243
def _log_running (self ):
191
244
return type (self )._STAGELOG_API (* self .running_args ,
def _log_status(self, stage_id, exc_type, exc_val, exc_tb):
    """Record the stage outcome from the `__exit__` exception triple."""
    # All three exc_* values are None only when the `with` body exited
    # cleanly; anything else is reported as "<ExcType>: <message>".
    if exc_type is None and exc_val is None and exc_tb is None:
        success(stage_id)
        return
    fail(stage_id, f"{exc_type.__name__}: {str(exc_val)}")
209
262
210
263
211
264
class StageLogCompress (_StageLogContextManager ):
def _log_status(self, stage_id, exc_type, exc_val, exc_tb):
    """Record the compress-stage outcome from the exception triple."""
    # A clean exit carries (None, None, None); otherwise log the failure
    # as "<ExcType>: <message>".
    if exc_type is None and exc_val is None and exc_tb is None:
        success(stage_id)
        return
    fail(stage_id, f"{exc_type.__name__}: {str(exc_val)}")
222
275
223
276
224
277
class StageLogEvaluate (_StageLogContextManager ):
@@ -244,4 +297,29 @@ def _log_status(self, stage_id, exc_type, exc_val, exc_tb):
244
297
# the error message.
245
298
fail (stage_id , "Model evaluation failed." )
246
299
else :
247
- return False
300
+ fail (stage_id , f"{ exc_type .__name__ } : { str (exc_val )} " )
301
+
302
+
303
def stagelog_check_dataset(dataset_checker):
    """Decorator wrapping a dataset checker with stagelog datacheck calls.

    `dataset_checker(dataset_dir, dataset_type)` is expected to return a
    dict with at least `res_flag` and, on failure, `err_msg`; sample counts
    are read from `train.samples` and, when present, `test.samples`.
    (NOTE(review): this schema is inferred from the key accesses below —
    confirm against the checker implementations.)

    Args:
        dataset_checker (callable): The checker function to wrap.

    Returns:
        callable: A wrapper with the same signature as `dataset_checker`.
    """

    @functools.wraps(dataset_checker)
    def _wrapper(dataset_dir, dataset_type):
        stage_id = running_datacheck(abspath(dataset_dir), dataset_type)
        try:
            res = dataset_checker(dataset_dir, dataset_type)
        except BaseException as e:
            # We catch all exceptions including `KeyboardInterrupt`
            fail(stage_id, f"{type(e).__name__}: {str(e)}")
            raise
        else:
            if res['res_flag']:
                test_samples = res.get('test.samples', None) or 0
                # BUG FIX: `test_samples` was computed but a hard-coded 0
                # was reported; forward the actual test-set size instead.
                success_datacheck(
                    stage_id,
                    train_dataset=res['train.samples'],
                    validation_dataset=0,
                    test_dataset=test_samples)
            else:
                fail(stage_id, res['err_msg'])
            return res

    return _wrapper
0 commit comments