Skip to content

Commit 6520d84

Browse files
committed
Update exception handling in inference
1 parent 3c6b990 commit 6520d84

File tree

1 file changed

+11
-12
lines changed

1 file changed

+11
-12
lines changed

runner/inference.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -219,25 +219,29 @@ def update_inference_configs(configs: Any, N_token: int):
219219
def infer_predict(runner: InferenceRunner, configs: Any) -> None:
220220
# Data
221221
logger.info(f"Loading data from\n{configs.input_json_path}")
222-
dataloader = get_inference_dataloader(configs=configs)
222+
try:
223+
dataloader = get_inference_dataloader(configs=configs)
224+
except Exception as e:
225+
error_message = f"{e}:\n{traceback.format_exc()}"
226+
logger.info(error_message)
227+
with open(opjoin(runner.error_dir, "error.txt"), "a") as f:
228+
f.write(error_message)
229+
return
223230

224231
num_data = len(dataloader.dataset)
225232
for seed in configs.seeds:
226233
seed_everything(seed=seed, deterministic=configs.deterministic)
227234
for batch in dataloader:
228235
try:
229236
data, atom_array, data_error_message = batch[0]
237+
sample_name = data["sample_name"]
230238

231239
if len(data_error_message) > 0:
232240
logger.info(data_error_message)
233-
with open(
234-
opjoin(runner.error_dir, f"{data['sample_name']}.txt"),
235-
"w",
236-
) as f:
241+
with open(opjoin(runner.error_dir, f"{sample_name}.txt"), "a") as f:
237242
f.write(data_error_message)
238243
continue
239244

240-
sample_name = data["sample_name"]
241245
logger.info(
242246
(
243247
f"[Rank {DIST_WRAPPER.rank} ({data['sample_index'] + 1}/{num_data})] {sample_name}: "
@@ -266,15 +270,10 @@ def infer_predict(runner: InferenceRunner, configs: Any) -> None:
266270
error_message = f"[Rank {DIST_WRAPPER.rank}]{data['sample_name']} {e}:\n{traceback.format_exc()}"
267271
logger.info(error_message)
268272
# Save error info
269-
if opexists(
270-
error_path := opjoin(runner.error_dir, f"{sample_name}.txt")
271-
):
272-
os.remove(error_path)
273-
with open(error_path, "w") as f:
273+
with open(opjoin(runner.error_dir, f"{sample_name}.txt"), "a") as f:
274274
f.write(error_message)
275275
if hasattr(torch.cuda, "empty_cache"):
276276
torch.cuda.empty_cache()
277-
raise RuntimeError(f"run infer failed: {str(e)}")
278277

279278

280279
def main(configs: Any) -> None:

0 commit comments

Comments
 (0)