
Commit ee4968b

add dsv3 64gpu sft json && solve OOM problem by offload (#11112)
1 parent 7adc457 commit ee4968b

9 files changed: +264 -0 lines changed

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
{
    "model_name_or_path": "/root/paddlejob/tmpspace/huggingface_model/huggingface/deepseek-ai/DeepSeek-V3-bf16/",
    "dataset_name_or_path": "./data_small",
    "output_dir": "./checkpoints/sft_ckpts",
    "per_device_train_batch_size": 1,
    "gradient_accumulation_steps": 8,
    "per_device_eval_batch_size": 1,
    "eval_accumulation_steps": 1,
    "max_steps": 100,
    "max_grad_norm": 1.0,
    "amp_master_grad": true,
    "num_train_epochs": 1,
    "learning_rate": 2.2e-05,
    "aux_loss_alpha": 0.0001,
    "warmup_steps": 30,
    "logging_steps": 1,
    "evaluation_strategy": "no",
    "save_strategy": "no",
    "src_length": 2048,
    "max_length": 4097,
    "bf16": true,
    "fp16_opt_level": "O2",
    "do_train": true,
    "do_eval": false,
    "disable_tqdm": true,
    "use_expert_parallel": true,
    "expert_parallel_degree": 8,
    "continue_training": false,
    "pipeline_parallel_config": "enable_delay_scale_loss disable_partial_send_recv disable_batch_p2p_comm",
    "tensor_parallel_config": "sync_param sync_grad",
    "sharding_parallel_config": "split_param",
    "load_best_model_at_end": true,
    "eval_with_do_generation": false,
    "metric_for_best_model": "loss",
    "recompute": true,
    "recompute_use_reentrant": true,
    "recompute_granularity": "full",
    "save_total_limit": 1,
    "tensor_parallel_degree": 4,
    "pipeline_parallel_degree": 8,
    "sharding_parallel_degree": 2,
    "sharding": "stage1",
    "zero_padding": true,
    "unified_checkpoint": true,
    "use_flash_attention": true,
    "flash_mask": true,
    "using_fake_gate": true,
    "using_flex_token": true,
    "use_fused_rms_norm": true,
    "moe_subbatch_token_num": 0,
    "recompute_offload": true,
    "pre_alloc_memory": 70,
    "tensorwise_offload_optimizer": true,
    "sequence_parallel": true,
    "tensor_parallel_output": true
}

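Taken together, the parallel degrees in this config describe the 64-GPU layout named in the commit title: tensor_parallel_degree 4 × pipeline_parallel_degree 8 × sharding_parallel_degree 2 = 64 ranks, with expert_parallel_degree 8 applied to the MoE layers. The memory-related switches tied to the OOM fix show up here as recompute_offload (new in this commit) alongside tensorwise_offload_optimizer and pre_alloc_memory. A minimal sanity-check sketch of the arithmetic (illustrative only, not part of the commit):

    # Illustrative sketch, not part of the commit: the dense parallel degrees
    # in the JSON above must multiply to the number of GPUs in the job.
    degrees = {"tensor_parallel_degree": 4, "pipeline_parallel_degree": 8, "sharding_parallel_degree": 2}
    world_size = 1
    for _, degree in degrees.items():
        world_size *= degree
    assert world_size == 64  # the "64gpu" in the commit title
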
llm/run_finetune.py

Lines changed: 1 addition & 0 deletions
@@ -268,6 +268,7 @@ def main():
     model_config.moe_subbatch_token_num = model_args.moe_subbatch_token_num
     model_config.aux_loss_alpha = model_args.aux_loss_alpha
     model_config.gradient_accumulation_steps = training_args.gradient_accumulation_steps
+    model_config.recompute_offload = training_args.recompute_offload
     logger.info(f"Final model config: {model_config}")
 
     logger.info("Creating model")

llm/script/kill_process.sh

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
#!/bin/bash
set -x

SCRIPT_DIR=`dirname "$0"`
LAUNCH_SCRIPT="$SCRIPT_DIR/selective_launch.py"

if [[ -f "$LAUNCH_SCRIPT" ]]; then
    LAUNCH_CMD=`python "$LAUNCH_SCRIPT" 36677`
    if [[ -z "$LAUNCH_CMD" ]]; then
        exit 0
    fi
fi

skip_kill_time=${1:-"False"}

function kill_impl() {
    skip_kill_time=$1
    if [[ $skip_kill_time == "True" ]];then
        for((i=1;i<=60;i++));
        do
            pids=`ps -ef | grep 'time_2023_8888.py' | grep -v grep | awk '{print $2}'`
            if [[ "$pids" == "" ]] ; then
                echo "no process found for speed-testing. stop waiting and kill other scripts."
                break
            fi
            echo "wait 10 seconds for finishing the speed-testing scripts."
            sleep 10s
        done
    fi

    # kill aadiff test finally.
    ps -ef | grep -E "check_aadiff.sh|run_aadiff_matmul.sh|test_matmul.py" | awk '{print $2}' | xargs kill -9

    pids=`ps -ef | grep train.py | grep -v grep | awk '{print $2}'`
    if [[ "$pids" != "" ]] ; then
        echo $pids
        echo $pids | xargs kill -9
    fi

    # kill agent server
    (ps -ef | grep agent | grep port | awk '{print $2}' | xargs -I {} kill -9 {}) || true

    if [[ $TRAININGJOB_REPLICA_NAME == "trainer" ]]; then
        echo "Killing processes on gpu"
        lsof /dev/nvidia* | awk '{print $2}' | xargs -I {} kill -9 {}
    elif [[ $TRAININGJOB_REPLICA_NAME == "trainerxpu" ]]; then
        echo "Killing processes on xpu"
        lsof /dev/xpu* | awk '{print $2}' | xargs -I {} kill -9 {}
    else
        echo "[FATAL] unsupported training job type: ${TRAININGJOB_REPLICA_NAME}"
        exit 1
    fi
}

kill_impl $skip_kill_time || true

llm/script/selective_launch.py

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
"""
Selective launch script.

Usage: python script/selective_launch.py <port> <ranks> <ranks> <ranks> ...
"""
import os
import sys


def parse_ranks(ranks_strs):
    """
    parse_ranks
    """
    # NOTE: You can return ranks directly here to change script/train_gpu.sh
    # and script/kill_process.sh together

    # Example 1: Use contiguous nodes [8, 16)
    # return range(8, 16)

    # Example 2: Use non-contiguous nodes [4, 8) + {10} + [30, 32), i.e., [4, 5, 6, 7, 10, 30, 31]
    # return list(range(4, 8)) + [10] + list(range(30, 32))

    # Example 3:
    # Just Python code, return any nodes you want!
    return list(range(64, 72))
    if not ranks_strs:
        return None

    ranks = []
    for r in ranks_strs:
        r = eval(r)
        if isinstance(r, int):
            ranks.append(r)
        else:
            ranks.extend(r)
    return ranks


def main(port, ranks):
    """
    main
    """
    ips = [ip.strip() for ip in os.getenv("TRAINER_INSTANCES").split(",") if ip.strip()]
    if ranks is None:
        ranks = list(range(len(ips)))
    ranks = sorted(list(set(ranks)))
    my_rank = int(os.getenv("POD_INDEX", "0"))
    if my_rank not in ranks:
        return

    rank = ranks.index(my_rank)
    nranks = len(ranks)

    master = ips[ranks[0]]
    print(f"--master {master}:{port} --rank {rank} --nnodes {nranks}")


if __name__ == "__main__":
    main(int(sys.argv[1]), parse_ranks(sys.argv[2:]))

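As committed, parse_ranks returns list(range(64, 72)) before the argument-parsing branch is reached, so only pods 64-71 of the job take part: each of them prints a --master/--rank/--nnodes fragment for paddle.distributed.launch, and every other pod prints nothing, which is what makes train_gpu.sh and kill_process.sh exit early. A framework-free sketch of that selection (illustrative pod metadata, not part of the commit):

    # Illustrative only: reproduce the selection logic with fake pod metadata.
    import os

    os.environ["TRAINER_INSTANCES"] = ",".join(f"10.0.0.{i}" for i in range(72))
    os.environ["POD_INDEX"] = "65"

    ips = os.environ["TRAINER_INSTANCES"].split(",")
    ranks = list(range(64, 72))                # the hardcoded selection above
    my_rank = int(os.environ["POD_INDEX"])
    if my_rank in ranks:
        # pod 64 is the master; pod 65 becomes rank 1 of the 8 participating nodes
        print(f"--master {ips[ranks[0]]}:36677 --rank {ranks.index(my_rank)} --nnodes {len(ranks)}")
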
llm/script/train_gpu.sh

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
#!/bin/bash

# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

unset PADDLE_ELASTIC_JOB_ID
unset PADDLE_TRAINER_ENDPOINTS
unset DISTRIBUTED_TRAINER_ENDPOINTS
unset FLAGS_START_PORT
unset PADDLE_ELASTIC_TIMEOUT

nnodes=$PADDLE_TRAINERS_NUM
rank=$PADDLE_TRAINER_ID

for name in `env | grep -E 'PADDLE|ENDPOINT' | awk -F'=' '{print $1}'`; do
    unset ${name}
done

#export FLAGS_shard_bypass_dygraph_optimizer=1
export NCCL_IB_GID_INDEX=3
export NVSHMEM_IB_GID_INDEX=3
export NVSHMEM_IB_TRAFFIC_CLASS=162

#export NVSHMEM_IB_ENABLE_IBGDA=true
##export NVSHMEM_DISABLE_P2P=1
export NVSHMEM_BOOTSTRAP=UID

unset NVSHMEM_HCA_LIST
unset NVSHMEM_ENABLE_NIC_PE_MAPPING

LAUNCH_CMD=`python script/selective_launch.py 36677`
if [[ -z "$LAUNCH_CMD" ]]; then
    exit 0
fi

export PYTHONPATH=../:$PYTHONPATH
export CUDA_PATH=/usr/local/cuda-12.9

export DSV3_USE_FP8_GEMM=true
export DSV3_USE_ATTEN_RECOMPUTE=true
export FA_VERSION=3
export FLAGS_share_tensor_for_grad_tensor_holder=1
export FLAGS_use_default_stream=false
export DSV3_USE_FP8_DISPATCH=true
export USE_DS_GEMM=false

export NVJITLIB=/root/paddlejob/workspace/env_run/zhengzhonghui/venv/lib/python3.10/site-packages/nvidia/nvjitlink/lib/
export CUSPARSELIB=/root/paddlejob/workspace/env_run/zhengzhonghui/venv/lib/python3.10/site-packages/nvidia/cusparse/lib
export LD_LIBRARY_PATH=$NVJITLIB:$CUSPARSELIB:$LD_LIBRARY_PATH

source /root/paddlejob/workspace/env_run/zhengzhonghui/venv/bin/activate
# source /root/paddlejob/workspace/env_run/chenzhichao/venv/bin/activate
bash script/kill_process.sh

rm core.* -rf

python3.10 -m paddle.distributed.launch \
    --log_dir output/paddle_distributed_logs \
    $LAUNCH_CMD \
    --run_mode=collective \
    ${script:-run_finetune.py} \
    $@

paddlenlp/trainer/utils/offload_optimizer.py

Lines changed: 8 additions & 0 deletions
@@ -48,6 +48,14 @@ def new_add_accumulator(self, *args, **kwargs):
 
     setattr(Optimizer, "_add_accumulator", new_add_accumulator)
 
+    origin_create_master_weight = getattr(Optimizer, "_create_master_weight")
+    def new_create_master_weight(self, *args, **kwargs):
+        x = origin_create_master_weight(self, *args, **kwargs)
+        offload(x)
+        return x
+
+    setattr(Optimizer, "_create_master_weight", new_create_master_weight)
+
     # Step 2: mock _C_ops.adamw_ and _C_ops.adamw
     for name in ["adam_", "adamw_"]:
         origin_op = getattr(_C_ops, name)

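This extends the existing monkey-patch on Paddle's Optimizer: every tensor returned by _create_master_weight is handed to offload() right after creation, the same treatment the file already applies to optimizer accumulators, so the master weight copies no longer sit in GPU memory by default. A generic sketch of the wrap-and-offload pattern (illustrative; offload_fn stands in for the helper defined earlier in this file):

    # Illustrative sketch of the wrap-and-offload pattern, not the real Paddle API.
    def wrap_with_offload(cls, method_name, offload_fn):
        original = getattr(cls, method_name)

        def wrapped(self, *args, **kwargs):
            tensor = original(self, *args, **kwargs)
            offload_fn(tensor)  # park the freshly created tensor off the GPU
            return tensor

        setattr(cls, method_name, wrapped)
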
paddlenlp/transformers/deepseek_v2/modeling_pp.py

Lines changed: 8 additions & 0 deletions
@@ -235,6 +235,9 @@ def forward(self, args):
                 attn_mask_startend_row_indices=attn_mask_startend_row_indices,
             )
         elif self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient:
+            offload_kwargs = {}
+            if self.config.recompute_offload:
+                offload_kwargs["offload_indices"] = [0]
             if attention_mask is not None or attn_mask_startend_row_indices is not None:
                 hidden_states = recompute(
                     super().forward,
@@ -243,6 +246,7 @@ def forward(self, args):
                     attention_mask=attention_mask,
                     attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                     use_reentrant=self.config.recompute_use_reentrant,
+                    **offload_kwargs,
                 )
             else:
                 # for pretrain
@@ -300,6 +304,9 @@ def forward(self, args):
                 attn_mask_startend_row_indices=attn_mask_startend_row_indices,
             )
         elif self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient:
+            offload_kwargs = {}
+            if self.config.recompute_offload:
+                offload_kwargs["offload_indices"] = [0]
             if attention_mask is not None or attn_mask_startend_row_indices is not None:
                 hidden_states = recompute(
                     super().forward,
@@ -309,6 +316,7 @@ def forward(self, args):
                     attention_mask=attention_mask,
                     attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                     use_reentrant=self.config.recompute_use_reentrant,
+                    **offload_kwargs,
                 )
             else:
                 # for pretrain

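When full-granularity recompute is active and config.recompute_offload is set, the layers now pass offload_indices=[0] to recompute, which by its name selects the 0-th positional input (the hidden states saved for backward recomputation) for offloading to host memory; with the flag unset, offload_kwargs stays empty and the call is exactly what it was before. A minimal, framework-free sketch of that conditional-kwargs pattern (stand-in functions, not the real recompute API):

    # Illustrative only: build the extra kwargs when the flag is on, then splat
    # them into the call so the default path is untouched otherwise.
    def fake_recompute(fn, *args, offload_indices=None, **kwargs):
        print("offload_indices =", offload_indices)  # stand-in for the real wrapper
        return fn(*args, **kwargs)

    def layer_forward(hidden_states):
        return hidden_states * 2

    recompute_offload = True
    offload_kwargs = {"offload_indices": [0]} if recompute_offload else {}
    out = fake_recompute(layer_forward, 21, **offload_kwargs)  # prints "offload_indices = [0]"
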
paddlenlp/transformers/deepseek_v3/modeling.py

Lines changed: 2 additions & 0 deletions
@@ -88,6 +88,7 @@ def forward(
         self,
         input_ids: paddle.Tensor = None,
         attention_mask: Optional[paddle.Tensor] = None,
+        attn_mask_startend_row_indices: Optional[paddle.Tensor] = None,
         position_ids: Optional[paddle.Tensor] = None,
         past_key_values: Optional[List[paddle.Tensor]] = None,
         inputs_embeds: Optional[paddle.Tensor] = None,
@@ -139,6 +140,7 @@ def forward(
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            attn_mask_startend_row_indices=attn_mask_startend_row_indices,
         )
 
         hidden_states = outputs[0]

paddlenlp/trl/sft_config.py

Lines changed: 1 addition & 0 deletions
@@ -63,6 +63,7 @@ class SFTConfig(TrainingArguments):
     model_init_kwargs: Optional[dict[str, Any]] = None
     dataset_kwargs: Optional[dict[str, Any]] = None
     eval_packing: Optional[bool] = None
+    recompute_offload: Optional[bool] = None
     use_ssa: bool = field(
         default=False,
         metadata={

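The new SFTConfig field closes the loop for the flag used throughout this commit: the SFT JSON sets recompute_offload, run_finetune.py copies it from the training arguments onto the model config, and the pipeline layers in modeling_pp.py read it when deciding whether to pass offload_indices to recompute. It defaults to None, so configs that never mention it should keep the old code path. A simplified sketch of that hand-off (toy stand-ins, not the real PaddleNLP classes):

    # Illustrative stand-ins showing how the flag travels from JSON to model config.
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class SFTArgs:  # stands in for SFTConfig
        recompute_offload: Optional[bool] = None

    @dataclass
    class ModelCfg:  # stands in for the model config
        recompute_offload: bool = False

    training_args = SFTArgs(recompute_offload=True)  # value parsed from the SFT JSON above
    model_config = ModelCfg()
    model_config.recompute_offload = training_args.recompute_offload  # the run_finetune.py step
    assert model_config.recompute_offload is True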