@@ -219,25 +219,29 @@ def update_inference_configs(configs: Any, N_token: int):
 def infer_predict(runner: InferenceRunner, configs: Any) -> None:
     # Data
     logger.info(f"Loading data from\n{configs.input_json_path}")
-    dataloader = get_inference_dataloader(configs=configs)
+    try:
+        dataloader = get_inference_dataloader(configs=configs)
+    except Exception as e:
+        error_message = f"{e}:\n{traceback.format_exc()}"
+        logger.info(error_message)
+        with open(opjoin(runner.error_dir, "error.txt"), "a") as f:
+            f.write(error_message)
+        return

     num_data = len(dataloader.dataset)
     for seed in configs.seeds:
         seed_everything(seed=seed, deterministic=configs.deterministic)
         for batch in dataloader:
             try:
                 data, atom_array, data_error_message = batch[0]
+                sample_name = data["sample_name"]

                 if len(data_error_message) > 0:
                     logger.info(data_error_message)
-                    with open(
-                        opjoin(runner.error_dir, f"{data['sample_name']}.txt"),
-                        "w",
-                    ) as f:
+                    with open(opjoin(runner.error_dir, f"{sample_name}.txt"), "a") as f:
                         f.write(data_error_message)
                     continue

-                sample_name = data["sample_name"]
                 logger.info(
                     (
                         f"[Rank {DIST_WRAPPER.rank} ({data['sample_index'] + 1}/{num_data})] {sample_name}: "
@@ -266,15 +270,10 @@ def infer_predict(runner: InferenceRunner, configs: Any) -> None:
                 error_message = f"[Rank {DIST_WRAPPER.rank}]{data['sample_name']} {e}:\n{traceback.format_exc()}"
                 logger.info(error_message)
                 # Save error info
-                if opexists(
-                    error_path := opjoin(runner.error_dir, f"{sample_name}.txt")
-                ):
-                    os.remove(error_path)
-                with open(error_path, "w") as f:
+                with open(opjoin(runner.error_dir, f"{sample_name}.txt"), "a") as f:
                     f.write(error_message)
                 if hasattr(torch.cuda, "empty_cache"):
                     torch.cuda.empty_cache()
-                raise RuntimeError(f"run infer failed: {str(e)}")


 def main(configs: Any) -> None:
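Net effect of the commit: error messages are appended to per-sample files instead of overwriting them, and a dataloader failure logs and returns rather than raising, so one bad sample or rank no longer aborts the whole run. A minimal, self-contained sketch of the append-mode error-recording pattern; the `record_error` helper, `run_one` driver, and `errors` directory are illustrative names, not part of the diff:

```python
import os
import traceback
from os.path import join as opjoin  # same alias the diff uses


def record_error(error_dir: str, name: str, message: str) -> None:
    """Append an error message to <error_dir>/<name>.txt.

    Append mode ("a") preserves earlier failures (e.g., from other
    seeds) instead of clobbering them, and creates the file if it
    does not exist yet.
    """
    os.makedirs(error_dir, exist_ok=True)
    with open(opjoin(error_dir, f"{name}.txt"), "a") as f:
        f.write(message + "\n")


def run_one(sample_name: str, error_dir: str) -> None:
    try:
        raise ValueError("simulated inference failure")  # stand-in for the real work
    except Exception as e:
        # Mirror the diff: log the full traceback, persist it, and move on
        # to the next sample instead of re-raising.
        record_error(error_dir, sample_name, f"{e}:\n{traceback.format_exc()}")


if __name__ == "__main__":
    run_one("example_sample", "errors")
```

Append mode matters here because `infer_predict` loops over multiple seeds: with `"w"`, each failure would overwrite the previous one, which is also why the old `opexists`/`os.remove` dance is no longer needed.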