diff --git a/apps/protein_folding/helixfold/gpu_infer.sh b/apps/protein_folding/helixfold/gpu_infer.sh
index dd84b266..cf48b749 100644
--- a/apps/protein_folding/helixfold/gpu_infer.sh
+++ b/apps/protein_folding/helixfold/gpu_infer.sh
@@ -5,7 +5,6 @@ root_path="$(pwd)"
 demo=$1
 export PYTHONPATH=$root_path:$PYTHONPATH
-export FLAGS_use_cuda_managed_memory=true
 
 #DATA_DIR="$root_path/data"
 DATA_DIR="/root/paddlejob/workspace/env_run/alphafold_data"
@@ -13,53 +12,52 @@ fasta_file="$root_path/demo_data/casp14_demo/fasta/${demo}.fasta"
 OUTPUT_DIR="$root_path/demo_data/casp14_demo/output"
 log_dir="$root_path/demo_data/casp14_demo/demo_log"
 MODELS="model_1,model_5"
-USE_DAP=false
+SUBBATCH_SIZE=48
 
-if [ $USE_DAP == true ]; then
+distributed=true
 
-    distributed_args="--run_mode=collective --log_dir=${log_dir}"
-    python -m paddle.distributed.launch ${distributed_args} \
-        --gpus="0,1,2,3,4,5,6,7" \
-        run_helixfold.py \
-        --distributed \
-        --dap_degree 8 \
-        --fasta_paths=${fasta_file} \
-        --data_dir=${DATA_DIR} \
-        --bfd_database_path=${DATA_DIR}/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
-        --small_bfd_database_path=${DATA_DIR}/small_bfd/bfd-first_non_consensus_sequences.fasta \
-        --uniclust30_database_path=${DATA_DIR}/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
-        --uniref90_database_path=${DATA_DIR}/uniref90/uniref90.fasta \
-        --mgnify_database_path=${DATA_DIR}/mgnify/mgy_clusters_2018_12.fa \
-        --pdb70_database_path=${DATA_DIR}/pdb70/pdb70 \
-        --template_mmcif_dir=${DATA_DIR}/pdb_mmcif/mmcif_files \
-        --obsolete_pdbs_path=${DATA_DIR}/pdb_mmcif/obsolete.dat \
-        --max_template_date=2020-05-14 \
-        --model_names=${MODELS} \
-        --output_dir=${OUTPUT_DIR} \
-        --disable_amber_relax \
-        --seed 2022 \
-        --preset='reduced_dbs' \
-        --random_seed=0 \
-        ${@:2}
+# 'fp32' or 'bf16'
+PRECISION='bf16'
+
+if [ $distributed == true ]
+then
+    # Enable DAP for extremely long protein sequences
+    python_cmd="python -m paddle.distributed.launch --log_dir=${log_dir} --gpus=0,1,2,3,4,5,6,7 "
+    distributed_flag="--distributed"
+    DAP_DEGREE=8
+    # Reduce SUBBATCH_SIZE if GPU memory is insufficient
+    # SUBBATCH_SIZE=1
 else
+    # Enable CUDA unified memory for extremely long protein sequences
+    # export FLAGS_use_cuda_managed_memory=true
+    python_cmd="CUDA_VISIBLE_DEVICES=0 python "
+    distributed_flag=""
+    DAP_DEGREE=1
+    # Reduce SUBBATCH_SIZE if GPU memory is insufficient
+    # SUBBATCH_SIZE=1
+fi
 
-    CUDA_VISIBLE_DEVICES=0 python run_helixfold.py \
-        --fasta_paths=${fasta_file} \
-        --data_dir=${DATA_DIR} \
-        --bfd_database_path=${DATA_DIR}/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
-        --small_bfd_database_path=${DATA_DIR}/small_bfd/bfd-first_non_consensus_sequences.fasta \
-        --uniclust30_database_path=${DATA_DIR}/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
-        --uniref90_database_path=${DATA_DIR}/uniref90/uniref90.fasta \
-        --mgnify_database_path=${DATA_DIR}/mgnify/mgy_clusters_2018_12.fa \
-        --pdb70_database_path=${DATA_DIR}/pdb70/pdb70 \
-        --template_mmcif_dir=${DATA_DIR}/pdb_mmcif/mmcif_files \
-        --obsolete_pdbs_path=${DATA_DIR}/pdb_mmcif/obsolete.dat \
-        --max_template_date=2020-05-14 \
-        --model_names=${MODELS} \
-        --output_dir=${OUTPUT_DIR} \
-        --disable_amber_relax \
-        --seed 2022 \
-        --preset='reduced_dbs' \
-        --random_seed=0 \
-        ${@:2}
-fi
\ No newline at end of file
+$python_cmd \
+    run_helixfold.py \
+    ${distributed_flag} \
+    --dap_degree=${DAP_DEGREE} \
+    --fasta_paths=${fasta_file} \
+    --data_dir=${DATA_DIR} \
+    --bfd_database_path=${DATA_DIR}/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
+    --small_bfd_database_path=${DATA_DIR}/small_bfd/bfd-first_non_consensus_sequences.fasta \
+    --uniclust30_database_path=${DATA_DIR}/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
+    --uniref90_database_path=${DATA_DIR}/uniref90/uniref90.fasta \
+    --mgnify_database_path=${DATA_DIR}/mgnify/mgy_clusters_2018_12.fa \
+    --pdb70_database_path=${DATA_DIR}/pdb70/pdb70 \
+    --template_mmcif_dir=${DATA_DIR}/pdb_mmcif/mmcif_files \
+    --obsolete_pdbs_path=${DATA_DIR}/pdb_mmcif/obsolete.dat \
+    --max_template_date=2020-05-14 \
+    --model_names=${MODELS} \
+    --output_dir=${OUTPUT_DIR} \
+    --disable_amber_relax \
+    --seed 2022 \
+    --preset='reduced_dbs' \
+    --random_seed=0 \
+    --precision=${PRECISION} \
+    --subbatch_size=${SUBBATCH_SIZE} \
+    ${@:2}
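Note: SUBBATCH_SIZE caps peak activation memory by evaluating attention-heavy modules over slices of a large axis and concatenating the partial results, so longer sequences fit on one GPU at the cost of some overhead; the commented-out SUBBATCH_SIZE=1 is the slowest, most memory-frugal fallback. A minimal sketch of the idea only; the helper name `subbatch_run` and its arguments are illustrative, not HelixFold's actual implementation:

    import paddle

    def subbatch_run(fn, x, subbatch_size, axis=0):
        # Evaluate fn on slices of x along `axis` and concatenate the
        # results, trading peak memory for a little extra launch overhead.
        n = x.shape[axis]
        if subbatch_size >= n:
            return fn(x)
        outs = []
        for start in range(0, n, subbatch_size):
            sl = paddle.slice(x, axes=[axis], starts=[start],
                              ends=[min(start + subbatch_size, n)])
            outs.append(fn(sl))
        return paddle.concat(outs, axis=axis)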
diff --git a/apps/protein_folding/helixfold/run_helixfold.py b/apps/protein_folding/helixfold/run_helixfold.py
index 129a5a34..2a1dbaa9 100644
--- a/apps/protein_folding/helixfold/run_helixfold.py
+++ b/apps/protein_folding/helixfold/run_helixfold.py
@@ -35,6 +35,8 @@
 from alphafold_paddle.data.utils import align_feat, unpad_prediction
 from utils.init_env import init_seed, init_distributed_env
+from ppfleetx.distributed.protein_folding import dp, dap, bp
+from utils.utils import get_bf16_op_list
 
 logging.basicConfig()
 logger = logging.getLogger(__file__)
@@ -48,8 +50,6 @@
 def predict_structure(
-        dp_rank: int,
-        dp_nranks: int,
         fasta_path: str,
         fasta_name: str,
         output_dir_base: str,
@@ -57,6 +57,13 @@
         model_runners: Dict[str, model.RunModel],
         amber_relaxer: relax.AmberRelaxation,
         random_seed: int):
+
+    dap_rank = 0
+    bp_rank = 0
+    if args.distributed and args.dap_degree > 1:
+        dap_rank = dap.get_rank_in_group() if dap.get_world_size() > 1 else 0
+        bp_rank = bp.get_rank_in_group() if bp.get_world_size() > 1 else 0
+
     timings = dict()
     output_dir = pathlib.Path(output_dir_base).joinpath(fasta_name)
     output_dir.mkdir(parents=True, exist_ok=True)
@@ -75,18 +82,22 @@
         logger.info('Use cached features.pkl')
         with open(features_pkl, 'rb') as f:
             feature_dict = pickle.load(f)
-    elif dp_rank == 0:
-        t0 = time.time()
-        feature_dict = data_pipeline.process(
-            input_fasta_path=fasta_path,
-            msa_output_dir=msa_output_dir)
-        timings['features'] = time.time() - t0
+    else:
+        if dap_rank == 0 and bp_rank == 0:
+            t0 = time.time()
+            feature_dict = data_pipeline.process(
+                input_fasta_path=fasta_path,
+                msa_output_dir=msa_output_dir)
+            timings['features'] = time.time() - t0
 
-        with open(features_pkl, 'wb') as f:
-            pickle.dump(feature_dict, f, protocol=4)
+            with open(features_pkl, 'wb') as f:
+                pickle.dump(feature_dict, f, protocol=4)
 
-    if args.distributed:
-        dist.barrier()
+        if args.distributed:
+            dist.barrier()
+
+        with open(features_pkl, 'rb') as f:
+            feature_dict = pickle.load(f)
 
     relaxed_pdbs, plddts = dict(), dict()
     for model_name, model_runner in model_runners.items():
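Note: this hunk swaps per-tensor dist.broadcast for a file-based handoff. Only the (dap_rank, bp_rank) == (0, 0) process runs the MSA/template pipeline and writes features.pkl, every rank then meets at a barrier, and all ranks (the writer included) reload the identical pickle, which is what lets the dp_rank/dp_nranks plumbing disappear. A minimal sketch of the pattern, assuming all ranks share a filesystem; `cached_features` and `compute_features` are illustrative names:

    import pickle
    import paddle.distributed as dist

    def cached_features(path, is_writer, compute_features, distributed):
        # Writer rank materializes the cache; everyone waits at the
        # barrier, then each rank deserializes the same feature dict.
        if is_writer and not path.exists():
            with open(path, 'wb') as f:
                pickle.dump(compute_features(), f, protocol=4)
        if distributed:
            dist.barrier()
        with open(path, 'rb') as f:
            return pickle.load(f)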
@@ -96,37 +107,42 @@
         has_cache = input_features_pkl.exists()
 
         t0 = time.time()
-        if dp_rank == 0:
-            processed_feature_dict = model_runner.preprocess(
-                feature_dict, random_seed, input_features_pkl)
-            if not has_cache and dp_rank == 0:
-                timings[f'process_features_{model_name}'] = time.time() - t0
-
-            if args.distributed and args.dap_degree > 1:
-                processed_feature_dict = align_feat(
-                    processed_feature_dict, args.dap_degree)
-        else:
-            processed_feature_dict = dict()
-
-        if args.distributed:
-            for _, tensor in processed_feature_dict.items():
-                dist.broadcast(tensor, 0)
-
-            dist.barrier()
+        processed_feature_dict = model_runner.preprocess(
+            feature_dict, random_seed, input_features_pkl)
+        if not has_cache and dap_rank == 0 and bp_rank == 0:
+            timings[f'process_features_{model_name}'] = time.time() - t0
+
+        if args.distributed and args.dap_degree > 1:
+            processed_feature_dict = align_feat(
+                processed_feature_dict, args.dap_degree)
+
+        def _forward_with_precision(processed_feature_dict):
+            if args.precision == "bf16":
+                black_list, white_list = get_bf16_op_list()
+                with paddle.amp.auto_cast(level='O1',
+                                          custom_white_list=white_list,
+                                          custom_black_list=black_list,
+                                          dtype='bfloat16'):
+                    return model_runner.predict(
+                        processed_feature_dict,
+                        ensemble_representations=True,
+                        return_representations=True)
+            elif args.precision == "fp32":
+                return model_runner.predict(
+                    processed_feature_dict,
+                    ensemble_representations=True,
+                    return_representations=True)
+            else:
+                raise ValueError("precision must be either 'bf16' or 'fp32'.")
 
         t0 = time.time()
-        prediction = model_runner.predict(
-            processed_feature_dict,
-            ensemble_representations=True,
-            return_representations=True)
+        prediction = _forward_with_precision(processed_feature_dict)
 
-        if args.distributed and dp_rank == 0 and args.dap_degree > 1:
+        if args.distributed and args.dap_degree > 1:
             prediction = unpad_prediction(feature_dict, prediction)
 
         print('########## prediction shape ##########')
         model.print_shape(prediction)
 
-        if dp_rank == 0:
+        if dap_rank == 0 and bp_rank == 0:
             timings[f'predict_{model_name}'] = time.time() - t0
 
             aatype = feature_dict['aatype'].argmax(axis=-1)
@@ -136,7 +152,7 @@
             output_dir, 0, timings)
         plddts[model_name] = np.mean(prediction['plddt'])
 
-    if dp_rank == 0:
+    if dap_rank == 0 and bp_rank == 0:
         # Rank by pLDDT and write out relaxed PDBs in rank order.
         ranked_order = []
         for idx, (model_name, _) in enumerate(
@@ -247,8 +263,6 @@
     for fasta_path, fasta_name in zip(args.fasta_paths.split(','),
                                       fasta_names):
         predict_structure(
-            dp_rank=dp_rank,
-            dp_nranks=dp_nranks,
             fasta_path=fasta_path,
             fasta_name=fasta_name,
             output_dir_base=args.output_dir,
@@ -336,7 +350,7 @@
     parser.add_argument('--random_seed', type=int,
                         help='The random seed for the data pipeline. '
                              'By default, this is randomly generated.')
-
+    parser.add_argument('--precision', type=str,
+                        choices=['fp32', 'bf16'], default='fp32',
+                        help='Inference precision, fp32 or bf16.')
     parser.add_argument('--distributed',
                         action='store_true', default=False,
                         help='Whether to use distributed DAP inference')
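Note: with DAP enabled, align_feat pads the processed features before the forward pass and unpad_prediction strips the padding from the outputs afterwards, presumably so the sharded axes divide evenly across the dap_degree ranks. A minimal sketch of that padding arithmetic only; the real helpers in alphafold_paddle.data.utils decide per-key which axes to pad, and `pad_to_multiple` is an illustrative name:

    import numpy as np

    def pad_to_multiple(arr, multiple, axis):
        # Right-pad `axis` with zeros so its length is a multiple of
        # `multiple`, letting an even split across DAP ranks succeed.
        length = arr.shape[axis]
        pad = (-length) % multiple
        if pad == 0:
            return arr
        widths = [(0, 0)] * arr.ndim
        widths[axis] = (0, pad)
        return np.pad(arr, widths)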
diff --git a/apps/protein_folding/helixfold/train.py b/apps/protein_folding/helixfold/train.py
index 09d1afa3..56ff1a16 100644
--- a/apps/protein_folding/helixfold/train.py
+++ b/apps/protein_folding/helixfold/train.py
@@ -30,6 +30,7 @@
 from tensorboardX import SummaryWriter
 from utils.utils import get_model_parameter_size, add_to_data_writer, upload_to_hadoop, csv_print
+from utils.utils import get_bf16_op_list
 from utils.metric import ResultsCollect
 from utils.model import RunModel
 from utils.exponential_moving_average import ExponentialMovingAverage, EMA
@@ -98,15 +99,6 @@
     )
     return optimizer, lr_scheduler
 
-def get_bf16_op_list():
-    """tbd."""
-
-    black_list = {"reduce_sum"}
-    white_list = {"concat", "elementwise_add", "elementwise_div", "elementwise_mul", "elementwise_sub", "fill_any_like", "fill_constant", "gather", "gaussian_random",
-                  "softmax", "layer_norm", "log_softmax", "matmul_v2", "p_norm", "py_layer", "relu", "scale", "sigmoid", "slice", "softplus", "split", "sqrt", "square", "stack",
-                  "sum", "transpose2", "fused_gate_attention", "dropout_nd"}
-    return black_list, white_list
-
 def add_dyna_features(train_config, model_config, batch, step):
     """add `num_iter_recycling` and `use_clamped_fape`"""
 
diff --git a/apps/protein_folding/helixfold/utils/utils.py b/apps/protein_folding/helixfold/utils/utils.py
index 6e553844..2ac90847 100644
--- a/apps/protein_folding/helixfold/utils/utils.py
+++ b/apps/protein_folding/helixfold/utils/utils.py
@@ -19,6 +19,14 @@
 import numpy as np
 import paddle
 
+def get_bf16_op_list():
+    """Return the (black_list, white_list) op sets for bf16 auto_cast."""
+
+    black_list = {"reduce_sum"}
+    white_list = {"concat", "elementwise_add", "elementwise_div", "elementwise_mul", "elementwise_sub", "fill_any_like", "fill_constant", "gather", "gaussian_random",
+                  "softmax", "layer_norm", "log_softmax", "matmul_v2", "p_norm", "py_layer", "relu", "scale", "sigmoid", "slice", "softplus", "split", "sqrt", "square", "stack",
+                  "sum", "transpose2", "fused_gate_attention", "dropout_nd"}
+    return black_list, white_list
 
 def get_model_parameter_size(model):
     """tbd"""
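Note: hoisting get_bf16_op_list from train.py into utils/utils.py gives training and inference a single source of truth for the bf16 cast lists. Usage mirrors the _forward_with_precision wrapper above; a minimal sketch where `model` and `batch` are placeholders for any Paddle module and its input:

    import paddle
    from utils.utils import get_bf16_op_list

    black_list, white_list = get_bf16_op_list()
    # Ops in white_list run in bfloat16; ops in black_list (e.g. the
    # accuracy-sensitive reduce_sum) stay in float32 under AMP level O1.
    with paddle.amp.auto_cast(level='O1',
                              custom_white_list=white_list,
                              custom_black_list=black_list,
                              dtype='bfloat16'):
        prediction = model(batch)  # placeholder forward pass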