
fix the bug of dap for preprocess #245


Merged: 7 commits, Jan 13, 2023
92 changes: 45 additions & 47 deletions apps/protein_folding/helixfold/gpu_infer.sh
@@ -5,61 +5,59 @@ root_path="$(pwd)"
demo=$1

export PYTHONPATH=$root_path:$PYTHONPATH
export FLAGS_use_cuda_managed_memory=true

#DATA_DIR="$root_path/data"
DATA_DIR="/root/paddlejob/workspace/env_run/alphafold_data"
fasta_file="$root_path/demo_data/casp14_demo/fasta/${demo}.fasta"
OUTPUT_DIR="$root_path/demo_data/casp14_demo/output"
log_dir="$root_path/demo_data/casp14_demo/demo_log"
MODELS="model_1,model_5"
USE_DAP=false
SUBBATCH_SIZE=48

if [ $USE_DAP == true ]; then
distributed=true

distributed_args="--run_mode=collective --log_dir=${log_dir}"
python -m paddle.distributed.launch ${distributed_args} \
--gpus="0,1,2,3,4,5,6,7" \
run_helixfold.py \
--distributed \
--dap_degree 8 \
--fasta_paths=${fasta_file} \
--data_dir=${DATA_DIR} \
--bfd_database_path=${DATA_DIR}/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
--small_bfd_database_path=${DATA_DIR}/small_bfd/bfd-first_non_consensus_sequences.fasta \
--uniclust30_database_path=${DATA_DIR}/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
--uniref90_database_path=${DATA_DIR}/uniref90/uniref90.fasta \
--mgnify_database_path=${DATA_DIR}/mgnify/mgy_clusters_2018_12.fa \
--pdb70_database_path=${DATA_DIR}/pdb70/pdb70 \
--template_mmcif_dir=${DATA_DIR}/pdb_mmcif/mmcif_files \
--obsolete_pdbs_path=${DATA_DIR}/pdb_mmcif/obsolete.dat \
--max_template_date=2020-05-14 \
--model_names=${MODELS} \
--output_dir=${OUTPUT_DIR} \
--disable_amber_relax \
--seed 2022 \
--preset='reduced_dbs' \
--random_seed=0 \
${@:2}
# 'fp32' or 'bf16'
PRECISION='bf16'

if [ $distributed == true ]
then
# enable dap for EXTREMELY LONG SEQUENCE PROTEIN
python_cmd="python -m paddle.distributed.launch --log_dir=${log_dir} --gpus=0,1,2,3,4,5,6,7 "
distributed_flag="--distributed"
DAP_DEGREE=8
# Reduce the size of subbatch_size when the gpu memory is not enough
# SUBBATCH_SIZE=1
else
# enable unified memory for EXTREMELY LONG SEQUENCE PROTEIN
# export FLAGS_use_cuda_managed_memory=true
python_cmd="CUDA_VISIBLE_DEVICES=0 python "
distributed_flag=""
DAP_DEGREE=1
# Reduce the size of subbatch_size when the gpu memory is not enough
# SUBBATCH_SIZE=1
fi

CUDA_VISIBLE_DEVICES=0 python run_helixfold.py \
--fasta_paths=${fasta_file} \
--data_dir=${DATA_DIR} \
--bfd_database_path=${DATA_DIR}/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
--small_bfd_database_path=${DATA_DIR}/small_bfd/bfd-first_non_consensus_sequences.fasta \
--uniclust30_database_path=${DATA_DIR}/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
--uniref90_database_path=${DATA_DIR}/uniref90/uniref90.fasta \
--mgnify_database_path=${DATA_DIR}/mgnify/mgy_clusters_2018_12.fa \
--pdb70_database_path=${DATA_DIR}/pdb70/pdb70 \
--template_mmcif_dir=${DATA_DIR}/pdb_mmcif/mmcif_files \
--obsolete_pdbs_path=${DATA_DIR}/pdb_mmcif/obsolete.dat \
--max_template_date=2020-05-14 \
--model_names=${MODELS} \
--output_dir=${OUTPUT_DIR} \
--disable_amber_relax \
--seed 2022 \
--preset='reduced_dbs' \
--random_seed=0 \
${@:2}
fi
$python_cmd \
run_helixfold.py \
${distributed_flag} \
--dap_degree=${DAP_DEGREE} \
--fasta_paths=${fasta_file} \
--data_dir=${DATA_DIR} \
--bfd_database_path=${DATA_DIR}/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
--small_bfd_database_path=${DATA_DIR}/small_bfd/bfd-first_non_consensus_sequences.fasta \
--uniclust30_database_path=${DATA_DIR}/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
--uniref90_database_path=${DATA_DIR}/uniref90/uniref90.fasta \
--mgnify_database_path=${DATA_DIR}/mgnify/mgy_clusters_2018_12.fa \
--pdb70_database_path=${DATA_DIR}/pdb70/pdb70 \
--template_mmcif_dir=${DATA_DIR}/pdb_mmcif/mmcif_files \
--obsolete_pdbs_path=${DATA_DIR}/pdb_mmcif/obsolete.dat \
--max_template_date=2020-05-14 \
--model_names=${MODELS} \
--output_dir=${OUTPUT_DIR} \
--disable_amber_relax \
--seed 2022 \
--preset='reduced_dbs' \
--random_seed=0 \
--precision=${PRECISION} \
--subbatch_size=${SUBBATCH_SIZE} \
${@:2}
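
The rewritten script funnels both branches into one launch command: the if/else now only chooses python_cmd, distributed_flag, and DAP_DEGREE, so the long flag list (now including --precision and --subbatch_size) appears once instead of twice. A minimal usage sketch under that reading; the demo name and extra flag below are hypothetical placeholders, while the first argument selects demo_data/casp14_demo/fasta/<demo>.fasta and everything after it reaches run_helixfold.py via ${@:2}:

# Multi-GPU DAP inference (distributed=true in gpu_infer.sh)
bash gpu_infer.sh my_demo --random_seed=1

# Single-GPU run (distributed=false): DAP_DEGREE drops to 1; for extremely
# long sequences, uncomment `export FLAGS_use_cuda_managed_memory=true` in
# the script to fall back on unified memory, as its comment suggests.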
90 changes: 52 additions & 38 deletions apps/protein_folding/helixfold/run_helixfold.py
@@ -35,6 +35,8 @@
from alphafold_paddle.data.utils import align_feat, unpad_prediction

from utils.init_env import init_seed, init_distributed_env
from ppfleetx.distributed.protein_folding import dp, dap, bp
from utils.utils import get_bf16_op_list

logging.basicConfig()
logger = logging.getLogger(__file__)
@@ -48,15 +50,20 @@


def predict_structure(
dp_rank: int,
dp_nranks: int,
fasta_path: str,
fasta_name: str,
output_dir_base: str,
data_pipeline: pipeline.DataPipeline,
model_runners: Dict[str, model.RunModel],
amber_relaxer: relax.AmberRelaxation,
random_seed: int):

dap_rank = 0
bp_rank = 0
if args.distributed and args.dap_degree > 1:
dap_rank = dap.get_rank_in_group() if dap.get_world_size() > 1 else 0
bp_rank = bp.get_rank_in_group() if bp.get_world_size() > 1 else 0

timings = dict()
output_dir = pathlib.Path(output_dir_base).joinpath(fasta_name)
output_dir.mkdir(parents=True, exist_ok=True)
@@ -75,18 +82,22 @@ def predict_structure(
logger.info('Use cached features.pkl')
with open(features_pkl, 'rb') as f:
feature_dict = pickle.load(f)
elif dp_rank == 0:
t0 = time.time()
feature_dict = data_pipeline.process(
input_fasta_path=fasta_path,
msa_output_dir=msa_output_dir)
timings['features'] = time.time() - t0
else:
if dap_rank == 0 and bp_rank == 0:
t0 = time.time()
feature_dict = data_pipeline.process(
input_fasta_path=fasta_path,
msa_output_dir=msa_output_dir)
timings['features'] = time.time() - t0

with open(features_pkl, 'wb') as f:
pickle.dump(feature_dict, f, protocol=4)
with open(features_pkl, 'wb') as f:
pickle.dump(feature_dict, f, protocol=4)

if args.distributed:
dist.barrier()
if args.distributed:
dist.barrier()

with open(features_pkl, 'rb') as f:
feature_dict = pickle.load(f)

relaxed_pdbs, plddts = dict(), dict()
for model_name, model_runner in model_runners.items():
@@ -96,37 +107,42 @@ def predict_structure(
has_cache = input_features_pkl.exists()

t0 = time.time()
if dp_rank == 0:
processed_feature_dict = model_runner.preprocess(
feature_dict, random_seed, input_features_pkl)
if not has_cache and dp_rank == 0:
timings[f'process_features_{model_name}'] = time.time() - t0

if args.distributed and args.dap_degree > 1:
processed_feature_dict = align_feat(
processed_feature_dict, args.dap_degree)
else:
processed_feature_dict = dict()

if args.distributed:
for _, tensor in processed_feature_dict.items():
dist.broadcast(tensor, 0)

dist.barrier()
processed_feature_dict = model_runner.preprocess(
feature_dict, random_seed, input_features_pkl)
if not has_cache and dap_rank == 0 and bp_rank == 0:
timings[f'process_features_{model_name}'] = time.time() - t0

if args.distributed and args.dap_degree > 1:
processed_feature_dict = align_feat(
processed_feature_dict, args.dap_degree)

def _forward_with_precision(processed_feature_dict):
if args.precision == "bf16":
black_list, white_list = get_bf16_op_list()
with paddle.amp.auto_cast(level='O1', custom_white_list=white_list, custom_black_list=black_list, dtype='bfloat16'):
return model_runner.predict(
processed_feature_dict,
ensemble_representations=True,
return_representations=True)
elif args.precision == "fp32":
return model_runner.predict(
processed_feature_dict,
ensemble_representations=True,
return_representations=True)
else:
raise ValueError("Please choose precision from bf16 and fp32! ")

t0 = time.time()
prediction = model_runner.predict(
processed_feature_dict,
ensemble_representations=True,
return_representations=True)
prediction = _forward_with_precision(processed_feature_dict)

if args.distributed and dp_rank == 0 and args.dap_degree > 1:
if args.distributed and args.dap_degree > 1:
prediction = unpad_prediction(feature_dict, prediction)

print('########## prediction shape ##########')
model.print_shape(prediction)

if dp_rank == 0:
if dap_rank == 0 and bp_rank == 0:
timings[f'predict_{model_name}'] = time.time() - t0

aatype = feature_dict['aatype'].argmax(axis=-1)
Expand All @@ -136,7 +152,7 @@ def predict_structure(
output_dir, 0, timings)
plddts[model_name] = np.mean(prediction['plddt'])

if dp_rank == 0:
if dap_rank == 0 and bp_rank == 0:
# Rank by pLDDT and write out relaxed PDBs in rank order.
ranked_order = []
for idx, (model_name, _) in enumerate(
@@ -247,8 +263,6 @@ def main(args):

for fasta_path, fasta_name in zip(args.fasta_paths.split(','), fasta_names):
predict_structure(
dp_rank=dp_rank,
dp_nranks=dp_nranks,
fasta_path=fasta_path,
fasta_name=fasta_name,
output_dir_base=args.output_dir,
@@ -336,7 +350,7 @@ def main(args):
parser.add_argument('--random_seed', type=int,
help='The random seed for the data pipeline. '
'By default, this is randomly generated.')

parser.add_argument("--precision", type=str, choices=['fp32', 'bf16'], default='fp32')
parser.add_argument('--distributed',
action='store_true', default=False,
help='Whether to use distributed DAP inference')
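
Read together, the hunks above contain the actual bug fix: feature preprocessing is now gated on the rank inside the DAP and BP communication groups (dap_rank, bp_rank) instead of dp_rank, and the processed features reach the other ranks through the cached features.pkl plus a barrier rather than a per-tensor dist.broadcast. A condensed sketch of that pattern, with the pipeline and path arguments as placeholders (features_pkl is assumed to be a pathlib.Path; dap/bp are the ppfleetx modules imported in run_helixfold.py):

import pickle

import paddle.distributed as dist
from ppfleetx.distributed.protein_folding import bp, dap

def load_or_build_features(data_pipeline, fasta_path, msa_output_dir,
                           features_pkl, distributed, dap_degree):
    dap_rank = bp_rank = 0
    if distributed and dap_degree > 1:
        dap_rank = dap.get_rank_in_group() if dap.get_world_size() > 1 else 0
        bp_rank = bp.get_rank_in_group() if bp.get_world_size() > 1 else 0

    # Only the first rank of both groups runs the CPU-bound MSA/template
    # pipeline and writes the result to disk.
    if dap_rank == 0 and bp_rank == 0 and not features_pkl.exists():
        feature_dict = data_pipeline.process(
            input_fasta_path=fasta_path, msa_output_dir=msa_output_dir)
        with open(features_pkl, 'wb') as f:
            pickle.dump(feature_dict, f, protocol=4)

    if distributed:
        dist.barrier()  # features.pkl must exist before any rank reads it

    # Every rank reloads the same file; no per-tensor broadcast is needed,
    # since each rank then runs model_runner.preprocess with the same seed.
    with open(features_pkl, 'rb') as f:
        return pickle.load(f)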
10 changes: 1 addition & 9 deletions apps/protein_folding/helixfold/train.py
@@ -30,6 +30,7 @@
from tensorboardX import SummaryWriter

from utils.utils import get_model_parameter_size, add_to_data_writer, upload_to_hadoop, csv_print
from utils.utils import get_bf16_op_list
from utils.metric import ResultsCollect
from utils.model import RunModel
from utils.exponential_moving_average import ExponentialMovingAverage, EMA
@@ -98,15 +99,6 @@ def get_optimizer(opt_config, model):
)
return optimizer, lr_scheduler

def get_bf16_op_list():
"""tbd."""

black_list = {"reduce_sum"}
white_list = {"concat", "elementwise_add", "elementwise_div", "elementwise_mul", "elementwise_sub", "fill_any_like", "fill_constant", "gather", "gaussian_random",
"softmax", "layer_norm", "log_softmax", "matmul_v2", "p_norm", "py_layer", "relu", "scale", "sigmoid", "slice", "softplus", "split", "sqrt", "square", "stack",
"sum", "transpose2", "fused_gate_attention", "dropout_nd"}
return black_list, white_list


def add_dyna_features(train_config, model_config, batch, step):
"""add `num_iter_recycling` and `use_clamped_fape`"""
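
Note that get_bf16_op_list is not deleted outright here: the next file re-adds the identical helper under utils/utils.py, so that run_helixfold.py and train.py can both import it via `from utils.utils import get_bf16_op_list`.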
8 changes: 8 additions & 0 deletions apps/protein_folding/helixfold/utils/utils.py
@@ -19,6 +19,14 @@
import numpy as np
import paddle

def get_bf16_op_list():
    """Return (black_list, white_list) op sets for bf16 paddle.amp.auto_cast:
    black-listed ops are kept in fp32, white-listed ops may run in bfloat16."""
    black_list = {"reduce_sum"}
    white_list = {"concat", "elementwise_add", "elementwise_div", "elementwise_mul",
                  "elementwise_sub", "fill_any_like", "fill_constant", "gather",
                  "gaussian_random", "softmax", "layer_norm", "log_softmax",
                  "matmul_v2", "p_norm", "py_layer", "relu", "scale", "sigmoid",
                  "slice", "softplus", "split", "sqrt", "square", "stack", "sum",
                  "transpose2", "fused_gate_attention", "dropout_nd"}
    return black_list, white_list

def get_model_parameter_size(model):
"""tbd"""
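
For completeness, a minimal sketch of how the relocated helper is consumed, mirroring _forward_with_precision in run_helixfold.py above (model and batch are placeholder names):

import paddle

from utils.utils import get_bf16_op_list

def forward_bf16(model, batch):
    # Under AMP level O1, white-listed ops run in bfloat16 while the
    # black-listed reduce_sum stays in float32.
    black_list, white_list = get_bf16_op_list()
    with paddle.amp.auto_cast(level='O1',
                              custom_white_list=white_list,
                              custom_black_list=black_list,
                              dtype='bfloat16'):
        return model(batch)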