
Commit 68a68b5

Simplification & cleanup
1 parent 5610d8f commit 68a68b5

File tree

2 files changed: +23 −49 lines


scripts/msa/step3-uniref_add_taxid.py

+16 −35
@@ -258,16 +258,13 @@ def update_a3m(
     a3m_path: str,
     uniref_to_ncbi_taxid: Dict[str, str],
     save_root: str,
-) -> str:
+) -> None:
     """add NCBI TaxID to header if "UniRef" in header

     Args:
         a3m_path (str): the original a3m path returned by mmseqs(colabfold search)
         uniref_to_ncbi_taxid (Dict): the dict mapping uniref hit_name to NCBI TaxID
         save_root (str): the updated a3m
-
-    Returns:
-        str: The path of the processed a3m file
     """
     heads, seqs, uniref_index = read_a3m(a3m_path)
     fname = a3m_path.split("/")[-1]
@@ -284,10 +281,9 @@ def update_a3m(
             else:
                 head = head.replace(uniref_id, f"{uniref_id}_{ncbi_taxid}/")
             ofile.write(f"{head}{seq}")
-    return out_a3m_path


-def update_a3m_batch(batch_paths: List[str], uniref_to_ncbi_taxid: Dict[str, str], save_root: str) -> List[str]:
+def update_a3m_batch(batch_paths: List[str], uniref_to_ncbi_taxid: Dict[str, str], save_root: str) -> int:
     """Process a batch of a3m files.

     Args:
@@ -296,17 +292,15 @@ def update_a3m_batch(batch_paths: List[str], uniref_to_ncbi_taxid: Dict[str, str
         save_root (str): Directory to save processed files

     Returns:
-        List[str]: List of processed file paths
+        int: Number of files processed
     """
-    results = []
     for a3m_path in batch_paths:
-        result = update_a3m(
+        update_a3m(
            a3m_path=a3m_path,
            uniref_to_ncbi_taxid=uniref_to_ncbi_taxid,
            save_root=save_root
        )
-        results.append(result)
-    return results
+    return len(batch_paths)


 def process_files(
@@ -343,25 +337,17 @@ def process_files(
     batch_size = max(1, math.ceil(total_files / (num_workers * target_batches_per_worker)))

     # Create batches
-    batches = []
-    for i in range(0, len(a3m_paths), batch_size):
-        batch = a3m_paths[i:i + batch_size]
-        batches.append(batch)
+    batches = [a3m_paths[i:i + batch_size] for i in range(0, len(a3m_paths), batch_size)]

     # Process in single-threaded mode if we have very few files or only one worker
     if total_files < 10 or num_workers == 1:
-        with tqdm(total=total_files, desc="Processing a3m files") as pbar:
-            for a3m_path in a3m_paths:
-                update_a3m(
-                    a3m_path=a3m_path,
-                    uniref_to_ncbi_taxid=uniref_to_ncbi_taxid,
-                    save_root=output_msa_dir,
-                )
-                pbar.update(1)
-        return
-
-    start_time = time.time()
-
+        for a3m_path in tqdm(a3m_paths, desc="Processing a3m files"):
+            update_a3m(
+                a3m_path=a3m_path,
+                uniref_to_ncbi_taxid=uniref_to_ncbi_taxid,
+                save_root=output_msa_dir,
+            )
+        return
     # Use ProcessPoolExecutor for parallel processing
     with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor:
         # Submit batch tasks instead of individual files
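The batching change above replaces the explicit append loop with a list comprehension; both produce the same list of slices. A minimal sketch of the equivalence, using a hypothetical paths list in place of the script's a3m_paths:

# Hypothetical stand-in for a3m_paths; only the slicing pattern comes from the diff.
paths = [f"seq_{i}.a3m" for i in range(10)]
batch_size = 3

# Old form: build the batches with an explicit loop.
batches_loop = []
for i in range(0, len(paths), batch_size):
    batches_loop.append(paths[i:i + batch_size])

# New form: one list comprehension, same result.
batches_comp = [paths[i:i + batch_size] for i in range(0, len(paths), batch_size)]

assert batches_loop == batches_comp  # first batch: ['seq_0.a3m', 'seq_1.a3m', 'seq_2.a3m']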
@@ -376,24 +362,17 @@ def process_files(
             futures.append(future)

         # Track progress across all batches
-        completed_files = 0
         with tqdm(total=total_files, desc="Processing a3m files") as pbar:
             for future in concurrent.futures.as_completed(futures):
                 try:
                     # Each result is a list of file paths processed in the batch
-                    result = future.result()
-                    batch_size = len(result)
-                    completed_files += batch_size
+                    batch_size = future.result()
                     pbar.update(batch_size)
                 except Exception as e:
                     print(f"Error processing batch: {e}")
                     # Estimate how many files might have been in this failed batch
                     avg_batch_size = total_files / len(batches)
                     pbar.update(int(avg_batch_size))
-
-    end_time = time.time()
-    elapsed = end_time - start_time
-    print(f"Processing complete ({elapsed:.1f} seconds)")


 if __name__ == "__main__":
@@ -461,3 +440,5 @@ def process_files(
                 release_shared_dict(dict_id)
         except Exception as e:
             print(f"Warning: Failed to release shared dict {dict_id}: {e}")
+
+    print("Processing complete")

scripts/msa/step4-split_msa_to_uniref_and_others.py

+7 −14
@@ -25,12 +25,13 @@

 from utils import (
     convert_to_shared_dict,  # To create new shared dictionaries
+    SharedDict,  # To handle type annotation
     release_shared_dict,  # To manually release dictionaries
     get_shared_dict_ids  # To list available dictionaries
 )

 # Type alias for dictionary-like objects (regular dict or Manager.dict)
-DictLike = Union[Dict[str, Any], Mapping[str, Any]]
+DictLike = Union[Dict[str, Any], Mapping[str, Any], SharedDict]


 def load_mapping_data(seq_to_pdb_id_path: str, seq_to_pdb_index_path: str, use_shared_memory: bool = False) -> Tuple[Dict[str, Any], DictLike, DictLike]:
@@ -47,7 +48,7 @@ def load_mapping_data(seq_to_pdb_id_path: str, seq_to_pdb_index_path: str, use_s
     """
     # Load sequence to PDB ID mapping
     with open(seq_to_pdb_id_path, "r") as f:
-        seq_to_pdbid = json.load(f)
+        seq_to_pdbid: Dict[str, Any] = json.load(f)

     # Create reverse mapping for easy lookup
     first_pdbid_to_seq_data = {"_".join(v[0]): k for k, v in seq_to_pdbid.items()}
@@ -349,15 +350,11 @@ def process_files_batched(
                 )
             except Exception as e:
                 print(f"Error processing batch: {e}")
-
-    # Print final statistics
-    total_time = time.time() - start_time
-    print(f"Processed {total_processed} files in {total_time:.2f} seconds")
-    print(f"Average processing speed: {total_processed / total_time:.2f} files/second")


-def main():
-    """Main function to run the script with command line arguments."""
+if __name__ == "__main__":
+    # Set start method to spawn to ensure compatibility with shared memory
+    multiprocessing.set_start_method('spawn', force=True)
     import argparse

     parser = argparse.ArgumentParser()
@@ -412,8 +409,4 @@ def main():
         for dict_id in get_shared_dict_ids():
             release_shared_dict(dict_id)

-
-if __name__ == "__main__":
-    # Set start method to spawn to ensure compatibility with shared memory
-    multiprocessing.set_start_method('spawn', force=True)
-    main()
+    print("Processing complete")
