Commit 7adc457

[Bug Fix] reduce grad of kv_a_proj_with_mqa and q_a_proj to maintain correctness (#11085)
1 parent: 6e67781

File tree

2 files changed: +21 -0 lines


llm/run_finetune.py

Lines changed: 1 addition & 0 deletions

@@ -267,6 +267,7 @@ def main():
     model_config.using_fake_gate = model_args.using_fake_gate
     model_config.moe_subbatch_token_num = model_args.moe_subbatch_token_num
     model_config.aux_loss_alpha = model_args.aux_loss_alpha
+    model_config.gradient_accumulation_steps = training_args.gradient_accumulation_steps
     logger.info(f"Final model config: {model_config}")

     logger.info("Creating model")

paddlenlp/transformers/deepseek_v2/modeling.py

Lines changed: 20 additions & 0 deletions

@@ -1002,6 +1002,26 @@ def linear_dtype_gaurd():
 
         # fmt: on
 
+        if self.config.tensor_parallel_degree > 1 and self.config.sequence_parallel:
+            def grad_allreduce_hook(param, accumulation_steps):
+                hcg = fleet.get_hybrid_communicate_group()
+                pg = hcg.get_model_parallel_group().process_group
+                step = [0]
+
+                @paddle.autograd.no_grad()
+                def __impl__():
+                    step[0] += 1
+                    if (step[0] % accumulation_steps) == 0:
+                        if hasattr(param, "main_grad"):
+                            pg.allreduce(param.main_grad).wait()
+                        else:
+                            pg.allreduce(param.grad).wait()
+
+                return __impl__
+            # kv_a_proj_with_mqa and q_a_proj grad need to be reduce between mp
+            self.kv_a_proj_with_mqa.weight._register_backward_hook(grad_allreduce_hook(self.kv_a_proj_with_mqa.weight, accumulation_steps=config.gradient_accumulation_steps))
+            self.q_a_proj.weight._register_backward_hook(grad_allreduce_hook(self.q_a_proj.weight, accumulation_steps=config.gradient_accumulation_steps))
+
         self._init_rope()
 
         self.softmax_scale = self.q_head_dim ** (-0.5)
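
What the new block does: for each registered parameter it builds a closure that counts backward passes and, only on the last micro-batch of each gradient-accumulation cycle, issues a blocking all-reduce of the parameter's gradient (main_grad when a float32 master gradient is kept) across the model-parallel process group. Below is a minimal runnable sketch of that counting logic; the MockProcessGroup and the manual hook calls are illustrative stand-ins so the snippet runs on a single process, whereas the real patch obtains the group from fleet and attaches the closure via weight._register_backward_hook:

import paddle


class MockProcessGroup:
    """Stand-in for hcg.get_model_parallel_group().process_group:
    counts would-be all-reduces instead of communicating."""

    class _Task:
        def wait(self):
            pass

    def __init__(self):
        self.calls = 0

    def allreduce(self, tensor):
        self.calls += 1
        return MockProcessGroup._Task()


def grad_allreduce_hook(param, accumulation_steps, pg):
    # Mutable cell so the closure can count how many times backward has run.
    step = [0]

    @paddle.autograd.no_grad()
    def __impl__():
        step[0] += 1
        # Reduce only on the last micro-batch of each accumulation cycle.
        if step[0] % accumulation_steps == 0:
            grad = param.main_grad if hasattr(param, "main_grad") else param.grad
            pg.allreduce(grad).wait()

    return __impl__


# Simulate 8 micro-batches with accumulation_steps=4: the gradient is
# all-reduced twice, i.e. once per optimizer step.
linear = paddle.nn.Linear(4, 4)
pg = MockProcessGroup()
hook = grad_allreduce_hook(linear.weight, accumulation_steps=4, pg=pg)
for _ in range(8):
    linear(paddle.randn([2, 4])).sum().backward()
    hook()  # the real patch registers this via weight._register_backward_hook
print(pg.calls)  # 2

Only kv_a_proj_with_mqa and q_a_proj get this treatment, presumably because these low-rank projection weights are replicated across the model-parallel group while sequence parallel hands each rank a different shard of the tokens, so their local gradients differ between ranks and must be summed once per optimizer step.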
