Commit 142d792

Support dropping masked tokens in MoE dispatching and fix some bugs (#11114)

* drop padded tokens in MoE dispatching
* fix sequence parallel bugs in deepseek-v3
* drop padded tokens in the subbatch recompute path
* fix offload setting in subbatch
1 parent ee4968b commit 142d792

File tree

4 files changed: +79 −47 lines

paddlenlp/transformers/deepseek_v2/modeling.py
paddlenlp/transformers/deepseek_v2/modeling_pp.py
paddlenlp/transformers/moe_layer.py
paddlenlp/transformers/token_dispatcher.py
paddlenlp/transformers/deepseek_v2/modeling.py

Lines changed: 69 additions & 43 deletions
```diff
@@ -33,7 +33,6 @@
 import paddle.nn.functional as F
 from paddle import Tensor, nn
 from paddle.distributed import fleet
-from paddle.distributed.communication.reduce import ReduceOp
 from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
 from paddle.distributed.fleet.recompute.recompute import recompute
 from paddle.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -881,8 +880,8 @@ def __init__(self, config: DeepseekV2Config):
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
             self.shared_experts = DeepseekV2MLP(config=config, intermediate_size=intermediate_size, is_moe=False)
 
-    def forward(self, hidden_states):
-        final_hidden_states, l_aux, l_zloss = super().forward(hidden_states)
+    def forward(self, hidden_states, masked_tokens=None):
+        final_hidden_states, l_aux, l_zloss = super().forward(hidden_states, masked_tokens=masked_tokens)
         if self.training and self.alpha > 0.0:
             l_aux = l_aux * self.alpha
             final_hidden_states = AddAuxiliaryLoss.apply(final_hidden_states, l_aux)
@@ -1003,6 +1002,7 @@ def linear_dtype_gaurd():
         # fmt: on
 
         if self.config.tensor_parallel_degree > 1 and self.config.sequence_parallel:
+
             def grad_allreduce_hook(param, accumulation_steps):
                 hcg = fleet.get_hybrid_communicate_group()
                 pg = hcg.get_model_parallel_group().process_group
@@ -1018,10 +1018,22 @@ def __impl__():
                     pg.allreduce(param.grad).wait()
 
                 return __impl__
+
             # kv_a_proj_with_mqa and q_a_proj grad need to be reduce between mp
-            self.kv_a_proj_with_mqa.weight._register_backward_hook(grad_allreduce_hook(self.kv_a_proj_with_mqa.weight, accumulation_steps=config.gradient_accumulation_steps))
-            self.q_a_proj.weight._register_backward_hook(grad_allreduce_hook(self.q_a_proj.weight, accumulation_steps=config.gradient_accumulation_steps))
-
+            # self.kv_a_proj_with_mqa.weight._register_backward_hook(
+            #     grad_allreduce_hook(
+            #         self.kv_a_proj_with_mqa.weight, accumulation_steps=config.gradient_accumulation_steps
+            #     )
+            # )
+            # self.q_a_proj.weight._register_backward_hook(
+            #     grad_allreduce_hook(self.q_a_proj.weight, accumulation_steps=config.gradient_accumulation_steps)
+            # )
+            mark_as_sequence_parallel_parameter(self.kv_a_proj_with_mqa.weight)
+            mark_as_sequence_parallel_parameter(self.q_a_proj.weight)
+            if config.attention_bias:
+                mark_as_sequence_parallel_parameter(self.kv_a_proj_with_mqa.bias)
+                mark_as_sequence_parallel_parameter(self.q_a_proj.bias)
+
         self._init_rope()
 
         self.softmax_scale = self.q_head_dim ** (-0.5)
@@ -1284,18 +1296,41 @@ def subbatch_recompute_forward(
         seq_len = hidden_states.shape[seq_axis]
         assert seq_len % sub_seq_len == 0
         num_chunks = seq_len // sub_seq_len
-        split_list = [sub_seq_len] * num_chunks
-        input_list = paddle.split(hidden_states, split_list, axis=seq_axis)
+        input_list = paddle.split(hidden_states, num_chunks, axis=seq_axis)
         output_list = []
 
-        for chunk in input_list:
-            out = recompute(
-                self.mlp.forward,
-                chunk,
-                **offload_kwargs,
-            )
-            output_list.append(out)
+        if isinstance(self.mlp, DeepseekV2MoEFlexToken):
+            if attn_mask_startend_row_indices is not None:
+                if self.config.sequence_parallel and self.config.tensor_parallel_degree > 1:
+                    flat_mask = paddle.transpose(attn_mask_startend_row_indices, [2, 0, 1])
+                    flat_mask = ScatterOp.apply(flat_mask)
+                flat_mask = paddle.flatten(flat_mask)
+                mask_list = paddle.split(flat_mask, num_chunks)
+            else:
+                mask_list = [None] * num_chunks
+
+            for chunk, mask_chunk in zip(input_list, mask_list):
+                masked_tokens = None
+                if mask_chunk is not None:
+                    masked_tokens = mask_chunk == 0
+                    offload_kwargs["offload_indices"] = [0]
+                out = recompute(
+                    self.mlp.forward,
+                    chunk,
+                    masked_tokens=masked_tokens,
+                    **offload_kwargs,
+                )
+                output_list.append(out)
+        else:
+            for chunk in input_list:
+                out = recompute(
+                    self.mlp.forward,
+                    chunk,
+                    **offload_kwargs,
+                )
+                output_list.append(out)
         hidden_states = paddle.concat(output_list, axis=seq_axis)
+        offload_kwargs["offload_indices"] = [0]
         outputs = recompute(
             self.post_process,
             hidden_states,
@@ -1431,7 +1466,17 @@ def forward(
         self_attn_weights = attn_outputs[2] if output_attentions else None
         present_key_value = attn_outputs[3] if use_cache else None
 
-        hidden_states = self.mlp(hidden_states)
+        if attn_mask_startend_row_indices is not None and isinstance(self.mlp, DeepseekV2MoEFlexToken):
+            masked_tokens = None
+            if self.config.sequence_parallel and self.config.tensor_parallel_degree > 1:
+                flat_mask = paddle.transpose(attn_mask_startend_row_indices, [2, 0, 1])
+                flat_mask = ScatterOp.apply(flat_mask)
+                flat_mask = paddle.flatten(flat_mask)
+                masked_tokens = flat_mask == 0
+            hidden_states = self.mlp(hidden_states, masked_tokens=masked_tokens)
+        else:
+            hidden_states = self.mlp(hidden_states)
+
        outputs = self.post_process(
             hidden_states, residual, output_attentions, use_cache, self_attn_weights, present_key_value
         )
```
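
Both the decoder-layer forward and the sub-batch recompute path above build the same boolean `masked_tokens` vector from `attn_mask_startend_row_indices` before calling the MoE layer. The following is a minimal, single-process sketch of that shape manipulation; the `[batch, 1, seq_len]` layout and the convention that a row-index value of 0 marks a padded position are assumptions for illustration, and the sequence-parallel `ScatterOp` (which would keep only the local rank's slice) is simply skipped here:

```python
import paddle

# Hypothetical input: 2 sequences of length 8; a 0 entry marks a padded position.
attn_mask_startend_row_indices = paddle.to_tensor(
    [[[8, 8, 8, 8, 8, 0, 0, 0]],
     [[8, 8, 8, 8, 8, 8, 8, 0]]], dtype="int32"
)

# [B, 1, S] -> [S, B, 1], matching the sequence-parallel layout of the hidden states.
flat_mask = paddle.transpose(attn_mask_startend_row_indices, [2, 0, 1])
# In the real code, ScatterOp.apply(flat_mask) would keep this rank's sequence shard here.
flat_mask = paddle.flatten(flat_mask)   # one entry per local token
masked_tokens = flat_mask == 0          # True for padded tokens that should be dropped

# For sub-batch recompute, the same flat mask is split into equal per-chunk pieces.
mask_list = paddle.split(flat_mask, 4)  # 4 sub-batches of 4 tokens each

print(masked_tokens.numpy())
```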
```diff
@@ -1547,6 +1592,10 @@ def __init__(
         self.hnorm = DeepseekV2RMSNorm(config)
         self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size)
 
+        if config.sequence_parallel and config.tensor_parallel_degree > 1:
+            mark_as_sequence_parallel_parameter(self.eh_proj.weight)
+            mark_as_sequence_parallel_parameter(self.eh_proj.bias)
+
     def subbatch_recompute_forward(
         self,
         hidden_states: paddle.Tensor,
@@ -2226,10 +2275,6 @@ def __init__(self, config: DeepseekV2Config):
         else:
             self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=self.ignore_index)
 
-        if self.config.sequence_parallel:
-            self.seq_para_scale = 1.0 / self.config.tensor_parallel_degree
-            self.mp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group()
-
     def forward(self, prediction_scores, masked_lm_labels, router_loss=None, mtp_logits=None):
 
         if self.enable_parallel_cross_entropy:
@@ -2247,17 +2292,11 @@ def compute_loss(preds, labels):
             )
             count = paddle.sum(binary_sequence)
 
-            if self.config.sequence_parallel:
-                dist.all_reduce(count, op=ReduceOp.SUM, group=self.mp_group)
-
             if count == 0:
                 loss = paddle.sum(masked_lm_loss * binary_sequence)
             else:
                 loss = paddle.sum(masked_lm_loss * binary_sequence) / count
 
-            if self.config.sequence_parallel:
-                dist.all_reduce(loss, op=ReduceOp.SUM, group=self.mp_group)
-
             return loss
 
         def add_loss(main_loss, loss):
@@ -2269,28 +2308,15 @@ def add_loss(main_loss, loss):
             masked_lm_labels = masked_lm_labels[:, : -self.config.num_nextn_predict_layers]
             seq_length = masked_lm_labels.shape[1]
 
-            if self.config.sequence_parallel:
-                masked_lm_labels = masked_lm_labels.transpose([1, 0])  # [B, S] --> [S, B]
-                masked_lm_labels = ScatterOp.apply(masked_lm_labels)
-
             loss = compute_loss(prediction_scores, masked_lm_labels)
 
             mtp_loss_res = []
             for depth in range(self.config.num_nextn_predict_layers):
                 prediction_scores_cur_depth = mtp_logits[depth]
                 masked_lm_labels_cur_depth = masked_lm_labels_ori[:, (depth + 1) : (depth + 1 + seq_length)]
-
-                if self.config.sequence_parallel:
-                    masked_lm_labels_cur_depth = masked_lm_labels_cur_depth.transpose([1, 0])  # [B, S] --> [S, B]
-                    masked_lm_labels_cur_depth = ScatterOp.apply(masked_lm_labels_cur_depth)
-
                 res_cur_depth = compute_loss(prediction_scores_cur_depth, masked_lm_labels_cur_depth)
-
-                if self.config.sequence_parallel:
-                    res_cur_depth = res_cur_depth * self.seq_para_scale
-                    dist.all_reduce(res_cur_depth, op=ReduceOp.SUM, group=self.mp_group)
-
                 mtp_loss_res.append(res_cur_depth)
+
             loss = add_loss(loss, self.config.num_nextn_predict_lambda * sum([x for x in mtp_loss_res]) / len(mtp_loss_res))  # fmt: skip
 
         else:
```
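
Dropping the mp-group all-reduces of `count` and `loss` lines up with the lm-head change in the next hunk, which re-enables the sequence-parallel gather so that labels no longer need to be scattered per rank; that link is a reading of the diff rather than something it states. What survives in `compute_loss` is a plain masked mean over non-ignored tokens, sketched below with made-up values:

```python
import paddle

# Toy per-token losses and an ignore mask (1.0 = real token, 0.0 = ignored/padded).
masked_lm_loss = paddle.to_tensor([2.0, 1.5, 0.7, 3.1])
binary_sequence = paddle.to_tensor([1.0, 1.0, 0.0, 1.0])

count = paddle.sum(binary_sequence)
if count == 0:
    # Degenerate case: every label is ignored, so the masked sum (and the loss) is 0.
    loss = paddle.sum(masked_lm_loss * binary_sequence)
else:
    # Mean loss over the non-ignored tokens only.
    loss = paddle.sum(masked_lm_loss * binary_sequence) / count

print(float(loss))  # (2.0 + 1.5 + 3.1) / 3
```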
```diff
@@ -2336,9 +2362,9 @@ def __init__(self, config: DeepseekV2Config):
 
     def forward(self, hidden_states, tensor_parallel_output=None):
 
-        # if self.config.sequence_parallel:
-        #     hidden_states = GatherOp.apply(hidden_states)
-        #     hidden_states = paddle.transpose(hidden_states, [1, 0, 2])
+        if self.config.sequence_parallel:
+            hidden_states = GatherOp.apply(hidden_states)
+            hidden_states = paddle.transpose(hidden_states, [1, 0, 2])  # [S, B, H] --> [B, S, H]
             # hidden_states = paddle.reshape_(hidden_states, [-1, self.seq_length, self.config.hidden_size])
 
         if tensor_parallel_output is None:
```
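
The lm head now un-comments the sequence-parallel gather before projecting to the vocabulary. Below is a local, single-process sketch of the intended shape flow, where plain concatenation stands in for the `GatherOp` all-gather across model-parallel ranks (an assumption for illustration; the real op collects shards that live on different ranks):

```python
import paddle

seq_per_rank, batch, hidden, ranks = 4, 2, 8, 2

# Under sequence parallelism each rank holds a [S/mp, B, H] shard of the hidden states.
shards = [paddle.randn([seq_per_rank, batch, hidden]) for _ in range(ranks)]

# GatherOp.apply() would all-gather along the sequence axis; locally, concat gives the
# same shape: [S, B, H].
gathered = paddle.concat(shards, axis=0)

# The lm head then restores the [B, S, H] layout expected by the loss computation.
restored = paddle.transpose(gathered, [1, 0, 2])

print(gathered.shape, restored.shape)  # [8, 2, 8] [2, 8, 8]
```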

paddlenlp/transformers/deepseek_v2/modeling_pp.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -294,11 +294,12 @@ def forward(self, args):
         hidden_states = hidden_states_main_model
         for depth in range(self.config.num_nextn_predict_layers):
             inputs_embeds_cur_depth = inputs_embeds_cur_depth_list[depth]
-
+
             moelayer_use_subbatch_recompute = self.config.moe_subbatch_token_num > 0
             if moelayer_use_subbatch_recompute:
                 hidden_states = super().subbatch_recompute_forward(
                     hidden_states,
+                    inputs_embeds_cur_depth,
                     position_ids=position_ids,
                     attention_mask=attention_mask,
                     attn_mask_startend_row_indices=attn_mask_startend_row_indices,
```

paddlenlp/transformers/moe_layer.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -378,12 +378,12 @@ def expert_forward(self, dispatched_input, tokens_per_expert):
 
         return paddle.concat(outputs, axis=0)
 
-    def forward(self, hidden_states: paddle.Tensor):
+    def forward(self, hidden_states: paddle.Tensor, masked_tokens=None):
         _, _, d_model = hidden_states.shape
         # reshaped_input = hidden_states.reshape([-1, d_model])
         probs, routing_map, l_aux, l_zloss = self.router(hidden_states)
         (dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation(
-            hidden_states, probs, routing_map
+            hidden_states, probs, routing_map, masked_tokens
         )
         expert_output = self.expert_forward(dispatched_input, tokens_per_expert)
         output, _ = self.token_dispatcher.token_unpermutation(expert_output, None)
```
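
The new `masked_tokens` argument defaults to `None`, so existing call sites keep the old behaviour and only callers that can identify padded positions opt in. A stand-alone sketch of that plumbing pattern is below; the class and method names are illustrative stand-ins, not the PaddleNLP classes, and the routing inputs are omitted:

```python
import paddle


class ToyDispatcher:
    """Stand-in for the token dispatcher: filters out tokens whose mask entry is True."""

    def token_permutation(self, hidden_states, probs, routing_map, masked_tokens=None):
        flat = hidden_states.reshape([-1, hidden_states.shape[-1]])
        if masked_tokens is None:
            return flat  # old behaviour: every token is dispatched
        keep_idx = paddle.nonzero(paddle.logical_not(masked_tokens)).flatten()
        return paddle.index_select(flat, keep_idx, axis=0)  # padded tokens never leave this rank


class ToyMoELayer:
    def __init__(self):
        self.dispatcher = ToyDispatcher()

    def forward(self, hidden_states, masked_tokens=None):
        probs, routing_map = None, None  # routing is out of scope for this sketch
        return self.dispatcher.token_permutation(hidden_states, probs, routing_map, masked_tokens)


layer = ToyMoELayer()
x = paddle.randn([1, 6, 4])  # [B, S, H]
mask = paddle.to_tensor([False, False, False, False, True, True])  # last two tokens are padding

print(layer.forward(x).shape)        # [6, 4] -- everything dispatched
print(layer.forward(x, mask).shape)  # [4, 4] -- padded tokens dropped
```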

paddlenlp/transformers/token_dispatcher.py

Lines changed: 6 additions & 1 deletion
```diff
@@ -261,12 +261,17 @@ def __init__(self, num_local_experts: int, moe_router_topk: int, num_moe_experts
         )
 
     def token_permutation(
-        self, hidden_states: paddle.Tensor, probs: paddle.Tensor, routing_map: paddle.Tensor
+        self, hidden_states: paddle.Tensor, probs: paddle.Tensor, routing_map: paddle.Tensor, masked_tokens=None
     ) -> Tuple[paddle.Tensor, paddle.Tensor]:
         self.hidden_shape = hidden_states.shape
         hidden_states = hidden_states.view([-1, self.hidden_shape[-1]])
 
         self._comm_manager.setup_metadata(routing_map, probs)
+        if masked_tokens is not None:
+            self._comm_manager.token_indices.stop_gradient = True
+            masked_tokens = masked_tokens.unsqueeze(axis=-1)
+            self._comm_manager.token_indices = paddle.masked_fill(self._comm_manager.token_indices, masked_tokens, -1)
+
         hidden_states = self._comm_manager.dispatch(hidden_states)
         global_input_tokens = self._comm_manager.get_permuted_hidden_states_by_experts(hidden_states)
         tokens_per_expert = self._comm_manager.get_number_of_tokens_per_expert()
```
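
The actual dropping happens by overwriting the dispatch indices of masked tokens with -1 before the communication step; treating -1 as "do not send" is a reading of the change rather than something stated in the diff. A minimal sketch of the index-masking step itself, with a hypothetical top-2 dispatch table:

```python
import paddle

# Hypothetical dispatch table: one row per local token, one column per routed expert (top-2).
token_indices = paddle.to_tensor(
    [[3, 7],
     [1, 4],
     [0, 5],
     [2, 6]], dtype="int32"
)

# Boolean mask over local tokens: the last token is padding and should be dropped.
masked_tokens = paddle.to_tensor([False, False, False, True])

# Broadcast the mask over the expert dimension, then invalidate those entries with -1,
# mirroring the paddle.masked_fill call in token_permutation.
masked_tokens = masked_tokens.unsqueeze(axis=-1)
token_indices = paddle.masked_fill(token_indices, masked_tokens, -1)

print(token_indices.numpy())
# [[ 3  7]
#  [ 1  4]
#  [ 0  5]
#  [-1 -1]]
```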
