 import paddle.nn.functional as F
 from paddle import Tensor, nn
 from paddle.distributed import fleet
-from paddle.distributed.communication.reduce import ReduceOp
 from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
 from paddle.distributed.fleet.recompute.recompute import recompute
 from paddle.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -881,8 +880,8 @@ def __init__(self, config: DeepseekV2Config):
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
             self.shared_experts = DeepseekV2MLP(config=config, intermediate_size=intermediate_size, is_moe=False)

-    def forward(self, hidden_states):
-        final_hidden_states, l_aux, l_zloss = super().forward(hidden_states)
+    def forward(self, hidden_states, masked_tokens=None):
+        final_hidden_states, l_aux, l_zloss = super().forward(hidden_states, masked_tokens=masked_tokens)
         if self.training and self.alpha > 0.0:
             l_aux = l_aux * self.alpha
             final_hidden_states = AddAuxiliaryLoss.apply(final_hidden_states, l_aux)
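Note: the new `masked_tokens` argument lets the MoE layer skip padded positions during expert dispatch. A minimal sketch of how such a boolean mask can be applied, assuming `masked_tokens` has shape [num_tokens] with True marking positions to drop (the helper name is illustrative, not the actual DeepseekV2MoEFlexToken dispatch):

import paddle

def zero_out_masked_tokens(hidden_states, masked_tokens=None):
    # hidden_states: [num_tokens, hidden_size]; masked_tokens: bool [num_tokens]
    if masked_tokens is None:
        return hidden_states
    keep = paddle.logical_not(masked_tokens).astype(hidden_states.dtype)
    # broadcast over the hidden dimension so masked rows contribute nothing
    return hidden_states * keep.unsqueeze(-1)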
@@ -1003,6 +1002,7 @@ def linear_dtype_gaurd():
         # fmt: on

         if self.config.tensor_parallel_degree > 1 and self.config.sequence_parallel:
+
             def grad_allreduce_hook(param, accumulation_steps):
                 hcg = fleet.get_hybrid_communicate_group()
                 pg = hcg.get_model_parallel_group().process_group
@@ -1018,10 +1018,22 @@ def __impl__():
                     pg.allreduce(param.grad).wait()

                 return __impl__
+
             # kv_a_proj_with_mqa and q_a_proj grad need to be reduce between mp
-            self.kv_a_proj_with_mqa.weight._register_backward_hook(grad_allreduce_hook(self.kv_a_proj_with_mqa.weight, accumulation_steps=config.gradient_accumulation_steps))
-            self.q_a_proj.weight._register_backward_hook(grad_allreduce_hook(self.q_a_proj.weight, accumulation_steps=config.gradient_accumulation_steps))
-
+            # self.kv_a_proj_with_mqa.weight._register_backward_hook(
+            #     grad_allreduce_hook(
+            #         self.kv_a_proj_with_mqa.weight, accumulation_steps=config.gradient_accumulation_steps
+            #     )
+            # )
+            # self.q_a_proj.weight._register_backward_hook(
+            #     grad_allreduce_hook(self.q_a_proj.weight, accumulation_steps=config.gradient_accumulation_steps)
+            # )
+            mark_as_sequence_parallel_parameter(self.kv_a_proj_with_mqa.weight)
+            mark_as_sequence_parallel_parameter(self.q_a_proj.weight)
+            if config.attention_bias:
+                mark_as_sequence_parallel_parameter(self.kv_a_proj_with_mqa.bias)
+                mark_as_sequence_parallel_parameter(self.q_a_proj.bias)
+
         self._init_rope()

         self.softmax_scale = self.q_head_dim ** (-0.5)
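Note: the removed backward hook summed these weight gradients over the model-parallel group by hand; marking the parameters instead defers that synchronization to PaddleNLP's sequence-parallel utilities. A rough sketch of the underlying idea, assuming a hypothetical `sequence_parallel` attribute set by the marking call (not PaddleNLP's actual implementation):

import paddle.distributed as dist

def allreduce_marked_grads(model, mp_group):
    # after backward: sum gradients of parameters tagged as sequence-parallel
    # across the model-parallel group, mirroring what the old hook did per step
    for p in model.parameters():
        if getattr(p, "sequence_parallel", False) and p.grad is not None:
            dist.all_reduce(p.grad, group=mp_group)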
@@ -1284,18 +1296,41 @@ def subbatch_recompute_forward(
         seq_len = hidden_states.shape[seq_axis]
         assert seq_len % sub_seq_len == 0
         num_chunks = seq_len // sub_seq_len
-        split_list = [sub_seq_len] * num_chunks
-        input_list = paddle.split(hidden_states, split_list, axis=seq_axis)
+        input_list = paddle.split(hidden_states, num_chunks, axis=seq_axis)
         output_list = []

-        for chunk in input_list:
-            out = recompute(
-                self.mlp.forward,
-                chunk,
-                **offload_kwargs,
-            )
-            output_list.append(out)
+        if isinstance(self.mlp, DeepseekV2MoEFlexToken):
+            if attn_mask_startend_row_indices is not None:
+                if self.config.sequence_parallel and self.config.tensor_parallel_degree > 1:
+                    flat_mask = paddle.transpose(attn_mask_startend_row_indices, [2, 0, 1])
+                    flat_mask = ScatterOp.apply(flat_mask)
+                    flat_mask = paddle.flatten(flat_mask)
+                    mask_list = paddle.split(flat_mask, num_chunks)
+            else:
+                mask_list = [None] * num_chunks
+
+            for chunk, mask_chunk in zip(input_list, mask_list):
+                masked_tokens = None
+                if mask_chunk is not None:
+                    masked_tokens = mask_chunk == 0
+                offload_kwargs["offload_indices"] = [0]
+                out = recompute(
+                    self.mlp.forward,
+                    chunk,
+                    masked_tokens=masked_tokens,
+                    **offload_kwargs,
+                )
+                output_list.append(out)
+        else:
+            for chunk in input_list:
+                out = recompute(
+                    self.mlp.forward,
+                    chunk,
+                    **offload_kwargs,
+                )
+                output_list.append(out)
         hidden_states = paddle.concat(output_list, axis=seq_axis)
+        offload_kwargs["offload_indices"] = [0]
         outputs = recompute(
             self.post_process,
             hidden_states,
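Note: the hunk above splits the sequence into equal chunks and recomputes each chunk's MLP pass. A stripped-down sketch of that sub-batch recompute pattern, with `mlp` and `sub_seq_len` as placeholder names (offload handling and MoE masking omitted):

import paddle
from paddle.distributed.fleet.recompute.recompute import recompute

def subbatch_mlp_forward(mlp, hidden_states, sub_seq_len, seq_axis=0):
    seq_len = hidden_states.shape[seq_axis]
    assert seq_len % sub_seq_len == 0
    num_chunks = seq_len // sub_seq_len
    # each chunk's activations are rebuilt during backward instead of stored
    chunks = paddle.split(hidden_states, num_chunks, axis=seq_axis)
    outputs = [recompute(mlp.forward, chunk) for chunk in chunks]
    return paddle.concat(outputs, axis=seq_axis)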
@@ -1431,7 +1466,17 @@ def forward(
         self_attn_weights = attn_outputs[2] if output_attentions else None
         present_key_value = attn_outputs[3] if use_cache else None

-        hidden_states = self.mlp(hidden_states)
+        if attn_mask_startend_row_indices is not None and isinstance(self.mlp, DeepseekV2MoEFlexToken):
+            masked_tokens = None
+            if self.config.sequence_parallel and self.config.tensor_parallel_degree > 1:
+                flat_mask = paddle.transpose(attn_mask_startend_row_indices, [2, 0, 1])
+                flat_mask = ScatterOp.apply(flat_mask)
+                flat_mask = paddle.flatten(flat_mask)
+                masked_tokens = flat_mask == 0
+            hidden_states = self.mlp(hidden_states, masked_tokens=masked_tokens)
+        else:
+            hidden_states = self.mlp(hidden_states)
+
         outputs = self.post_process(
             hidden_states, residual, output_attentions, use_cache, self_attn_weights, present_key_value
         )
@@ -1547,6 +1592,10 @@ def __init__(
         self.hnorm = DeepseekV2RMSNorm(config)
         self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size)

+        if config.sequence_parallel and config.tensor_parallel_degree > 1:
+            mark_as_sequence_parallel_parameter(self.eh_proj.weight)
+            mark_as_sequence_parallel_parameter(self.eh_proj.bias)
+
     def subbatch_recompute_forward(
         self,
         hidden_states: paddle.Tensor,
@@ -2226,10 +2275,6 @@ def __init__(self, config: DeepseekV2Config):
         else:
             self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=self.ignore_index)

-        if self.config.sequence_parallel:
-            self.seq_para_scale = 1.0 / self.config.tensor_parallel_degree
-            self.mp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group()
-
     def forward(self, prediction_scores, masked_lm_labels, router_loss=None, mtp_logits=None):

         if self.enable_parallel_cross_entropy:
@@ -2247,17 +2292,11 @@ def compute_loss(preds, labels):
             )
             count = paddle.sum(binary_sequence)

-            if self.config.sequence_parallel:
-                dist.all_reduce(count, op=ReduceOp.SUM, group=self.mp_group)
-
             if count == 0:
                 loss = paddle.sum(masked_lm_loss * binary_sequence)
             else:
                 loss = paddle.sum(masked_lm_loss * binary_sequence) / count

-            if self.config.sequence_parallel:
-                dist.all_reduce(loss, op=ReduceOp.SUM, group=self.mp_group)
-
             return loss

         def add_loss(main_loss, loss):
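Note: with the mp-group all-reduces gone, `compute_loss` above averages the per-token loss only over non-ignored positions on each rank. A self-contained sketch of that reduction (the `ignore_index` default of -100 and the name `masked_mean_loss` are assumptions):

import paddle

def masked_mean_loss(per_token_loss, labels, ignore_index=-100):
    # labels and per_token_loss are assumed to have the same shape
    binary_sequence = paddle.where(
        labels == ignore_index,
        paddle.zeros_like(per_token_loss),
        paddle.ones_like(per_token_loss),
    )
    count = paddle.sum(binary_sequence)
    total = paddle.sum(per_token_loss * binary_sequence)
    # when every position is ignored, the masked sum is already zero
    return total if count == 0 else total / count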
@@ -2269,28 +2308,15 @@ def add_loss(main_loss, loss):
             masked_lm_labels = masked_lm_labels[:, : -self.config.num_nextn_predict_layers]
             seq_length = masked_lm_labels.shape[1]

-            if self.config.sequence_parallel:
-                masked_lm_labels = masked_lm_labels.transpose([1, 0])  # [B, S] --> [S, B]
-                masked_lm_labels = ScatterOp.apply(masked_lm_labels)
-
             loss = compute_loss(prediction_scores, masked_lm_labels)

             mtp_loss_res = []
             for depth in range(self.config.num_nextn_predict_layers):
                 prediction_scores_cur_depth = mtp_logits[depth]
                 masked_lm_labels_cur_depth = masked_lm_labels_ori[:, (depth + 1) : (depth + 1 + seq_length)]
-
-                if self.config.sequence_parallel:
-                    masked_lm_labels_cur_depth = masked_lm_labels_cur_depth.transpose([1, 0])  # [B, S] --> [S, B]
-                    masked_lm_labels_cur_depth = ScatterOp.apply(masked_lm_labels_cur_depth)
-
                 res_cur_depth = compute_loss(prediction_scores_cur_depth, masked_lm_labels_cur_depth)
-
-                if self.config.sequence_parallel:
-                    res_cur_depth = res_cur_depth * self.seq_para_scale
-                    dist.all_reduce(res_cur_depth, op=ReduceOp.SUM, group=self.mp_group)
-
                 mtp_loss_res.append(res_cur_depth)
+
             loss = add_loss(loss, self.config.num_nextn_predict_lambda * sum([x for x in mtp_loss_res]) / len(mtp_loss_res))  # fmt: skip

         else:
@@ -2336,9 +2362,9 @@ def __init__(self, config: DeepseekV2Config):
 
     def forward(self, hidden_states, tensor_parallel_output=None):

-        # if self.config.sequence_parallel:
-        #     hidden_states = GatherOp.apply(hidden_states)
-        #     hidden_states = paddle.transpose(hidden_states, [1, 0, 2])
+        if self.config.sequence_parallel:
+            hidden_states = GatherOp.apply(hidden_states)
+            hidden_states = paddle.transpose(hidden_states, [1, 0, 2])  # [S, B, H] --> [B, S, H]
         # hidden_states = paddle.reshape_(hidden_states, [-1, self.seq_length, self.config.hidden_size])

         if tensor_parallel_output is None:
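Note: under sequence parallelism the LM head now receives activations sharded along the sequence axis as [S / mp_degree, B, H]; `GatherOp` restores the full sequence and the transpose returns to [B, S, H] before the logits projection. A shape-only sketch using a plain all-gather in place of `GatherOp` (layout assumptions as commented, function name hypothetical):

import paddle
import paddle.distributed as dist

def gather_sequence_for_head(hidden_states, mp_group):
    # hidden_states: local [S / mp_degree, B, H] shard (assumed layout)
    gathered = []
    dist.all_gather(gathered, hidden_states, group=mp_group)
    full = paddle.concat(gathered, axis=0)    # [S, B, H]
    return paddle.transpose(full, [1, 0, 2])  # [B, S, H]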