@@ -258,16 +258,13 @@ def update_a3m(
258
258
a3m_path : str ,
259
259
uniref_to_ncbi_taxid : Dict [str , str ],
260
260
save_root : str ,
261
- ) -> str :
261
+ ) -> None :
262
262
"""add NCBI TaxID to header if "UniRef" in header
263
263
264
264
Args:
265
265
a3m_path (str): the original a3m path returned by mmseqs(colabfold search)
266
266
uniref_to_ncbi_taxid (Dict): the dict mapping uniref hit_name to NCBI TaxID
267
267
save_root (str): the directory where the updated a3m is saved
268
-
269
- Returns:
270
- str: The path of the processed a3m file
271
268
"""
272
269
heads , seqs , uniref_index = read_a3m (a3m_path )
273
270
fname = a3m_path .split ("/" )[- 1 ]
@@ -284,10 +281,9 @@ def update_a3m(
284
281
else :
285
282
head = head .replace (uniref_id , f"{ uniref_id } _{ ncbi_taxid } /" )
286
283
ofile .write (f"{ head } { seq } " )
287
- return out_a3m_path
288
284
289
285
290
- def update_a3m_batch (batch_paths : List [str ], uniref_to_ncbi_taxid : Dict [str , str ], save_root : str ) -> List [ str ] :
286
+ def update_a3m_batch (batch_paths : List [str ], uniref_to_ncbi_taxid : Dict [str , str ], save_root : str ) -> int :
291
287
"""Process a batch of a3m files.
292
288
293
289
Args:
@@ -296,17 +292,15 @@ def update_a3m_batch(batch_paths: List[str], uniref_to_ncbi_taxid: Dict[str, str
296
292
save_root (str): Directory to save processed files
297
293
298
294
Returns:
299
- List[str]: List of processed file paths
295
+ int: Number of files processed
300
296
"""
301
- results = []
302
297
for a3m_path in batch_paths :
303
- result = update_a3m (
298
+ update_a3m (
304
299
a3m_path = a3m_path ,
305
300
uniref_to_ncbi_taxid = uniref_to_ncbi_taxid ,
306
301
save_root = save_root
307
302
)
308
- results .append (result )
309
- return results
303
+ return len (batch_paths )
310
304
311
305
312
306
def process_files (
@@ -343,25 +337,17 @@ def process_files(
343
337
batch_size = max (1 , math .ceil (total_files / (num_workers * target_batches_per_worker )))
344
338
345
339
# Create batches
346
- batches = []
347
- for i in range (0 , len (a3m_paths ), batch_size ):
348
- batch = a3m_paths [i :i + batch_size ]
349
- batches .append (batch )
340
+ batches = [a3m_paths [i :i + batch_size ] for i in range (0 , len (a3m_paths ), batch_size )]
350
341
351
342
# Process in single-threaded mode if we have very few files or only one worker
352
343
if total_files < 10 or num_workers == 1 :
353
- with tqdm (total = total_files , desc = "Processing a3m files" ) as pbar :
354
- for a3m_path in a3m_paths :
355
- update_a3m (
356
- a3m_path = a3m_path ,
357
- uniref_to_ncbi_taxid = uniref_to_ncbi_taxid ,
358
- save_root = output_msa_dir ,
359
- )
360
- pbar .update (1 )
361
- return
362
-
363
- start_time = time .time ()
364
-
344
+ for a3m_path in tqdm (a3m_paths , desc = "Processing a3m files" ):
345
+ update_a3m (
346
+ a3m_path = a3m_path ,
347
+ uniref_to_ncbi_taxid = uniref_to_ncbi_taxid ,
348
+ save_root = output_msa_dir ,
349
+ )
350
+ return
365
351
# Use ProcessPoolExecutor for parallel processing
366
352
with concurrent .futures .ProcessPoolExecutor (max_workers = num_workers ) as executor :
367
353
# Submit batch tasks instead of individual files
@@ -376,24 +362,17 @@ def process_files(
376
362
futures .append (future )
377
363
378
364
# Track progress across all batches
379
- completed_files = 0
380
365
with tqdm (total = total_files , desc = "Processing a3m files" ) as pbar :
381
366
for future in concurrent .futures .as_completed (futures ):
382
367
try :
383
368
# Each result is the number of files processed in the batch
384
- result = future .result ()
385
- batch_size = len (result )
386
- completed_files += batch_size
369
+ batch_size = future .result ()
387
370
pbar .update (batch_size )
388
371
except Exception as e :
389
372
print (f"Error processing batch: { e } " )
390
373
# Estimate how many files might have been in this failed batch
391
374
avg_batch_size = total_files / len (batches )
392
375
pbar .update (int (avg_batch_size ))
393
-
394
- end_time = time .time ()
395
- elapsed = end_time - start_time
396
- print (f"Processing complete ({ elapsed :.1f} seconds)" )
397
376
398
377
399
378
if __name__ == "__main__" :
@@ -461,3 +440,5 @@ def process_files(
461
440
release_shared_dict (dict_id )
462
441
except Exception as e :
463
442
print (f"Warning: Failed to release shared dict { dict_id } : { e } " )
443
+
444
+ print ("Processing complete" )
0 commit comments