 
 from typing import Optional
 
+import numpy as np
 import paddle
 from paddle import nn
 from paddleformers.utils.log import logger
@@ -110,13 +111,18 @@ def __init__(
         self.weight_key_map = weight_key_map
 
         self.use_method = envs.FD_MOE_BACKEND.lower()
-        self.gate_correction_bias = None
         self.moe_tag = moe_tag
         if self.ep_size > 1:
             expert_id_offset = expert_id_offset + self.ep_rank * self.num_local_experts
 
         self.expert_id_offset = expert_id_offset
 
+        self.gate_correction_bias_key = self.weight_key_map.get("gate_correction_bias_key", None)
+        if self.gate_correction_bias_key is not None:
+            self.moe_use_gate_correction_bias = True
+        else:
+            self.moe_use_gate_correction_bias = False
+
         # used for deepseek_v3
         self.topk_method = topk_method
         self.topk_group = topk_group
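
For reference, a minimal sketch of what the new `__init__` logic does: the gate-correction-bias flag is now derived from whether `weight_key_map` provides a `gate_correction_bias_key`, instead of being decided later in `load_state_dict`. The mapping contents below are hypothetical, not taken from a real checkpoint.

```python
# Hypothetical weight_key_map entry; the actual key path depends on the model.
weight_key_map = {"gate_correction_bias_key": "model.layers.0.mlp.gate.e_score_correction_bias"}

gate_correction_bias_key = weight_key_map.get("gate_correction_bias_key", None)
moe_use_gate_correction_bias = gate_correction_bias_key is not None
print(moe_use_gate_correction_bias)  # True whenever the key map supplies a bias key
```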
@@ -175,20 +181,33 @@ def weight_loader(self, param, loaded_weight, expert_id, shard_id: Optional[str]
 
         if shard_id is None:
             # 1.gate up fused in disk
-            return
-        # 2.gate up splited in disk
-        assert shard_id in ["gate", "down", "up"]
-        expert_param = param[expert_id]
-        if current_platform.is_cuda():
-            SHARD_ID_TO_SHARDED_DIM = {"gate": 1, "down": 0, "up": 1}
+            if self.tp_size > 1:
+                shard_offsets = [
+                    # (shard_id, shard_offset, shard_size)
+                    ("gate", 0, self.moe_intermediate_size * self.tp_size),
+                    ("up", self.moe_intermediate_size * self.tp_size, self.moe_intermediate_size * self.tp_size),
+                ]
+                for shard_id, shard_offset, shard_size in shard_offsets:
+                    loaded_weight_shard = loaded_weight[..., shard_offset : shard_offset + shard_size]
+                    self.weight_loader(param, loaded_weight_shard, expert_id, shard_id)
+            else:
+                expert_param = param[expert_id]
+                loaded_weight = get_tensor(loaded_weight)
+                expert_param.copy_(loaded_weight, False)
         else:
-            SHARD_ID_TO_SHARDED_DIM = {"gate": 0, "down": 1, "up": 0}
-        self._load_expert_weight(
-            expert_param=expert_param,
-            shard_dim=SHARD_ID_TO_SHARDED_DIM[shard_id],
-            loaded_weight=loaded_weight,
-            shard_id=shard_id,
-        )
+            # 2.gate up splited in disk
+            assert shard_id in ["gate", "down", "up"]
+            if current_platform.is_cuda():
+                SHARD_ID_TO_SHARDED_DIM = {"gate": 1, "down": 0, "up": 1}
+            else:
+                SHARD_ID_TO_SHARDED_DIM = {"gate": 0, "down": 1, "up": 0}
+            self._load_expert_weight(
+                param=param,
+                expert_id=expert_id,
+                shard_dim=SHARD_ID_TO_SHARDED_DIM[shard_id],
+                loaded_weight=loaded_weight,
+                shard_id=shard_id,
+            )
 
     def _load_gate_up_weight(self, expert_param, shard_dim, loaded_weight, shard_id):
         tensor_size = expert_param.shape[shard_dim] // 2
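
As an illustration of the new fused gate_up path: when `shard_id is None` and `tp_size > 1`, the loaded tensor is split along its last axis into a gate half and an up half, each of width `moe_intermediate_size * tp_size`, and each half is fed back through `weight_loader` with an explicit `shard_id`. A minimal numpy sketch with toy sizes (assumptions, not the real model dimensions):

```python
import numpy as np

# Toy sizes, for illustration only.
hidden_size, moe_intermediate_size, tp_size = 8, 4, 2
fused = np.zeros((hidden_size, 2 * moe_intermediate_size * tp_size))

# Same offsets the new weight_loader builds for a fused gate_up checkpoint tensor.
shard_offsets = [
    # (shard_id, shard_offset, shard_size)
    ("gate", 0, moe_intermediate_size * tp_size),
    ("up", moe_intermediate_size * tp_size, moe_intermediate_size * tp_size),
]
for shard_id, shard_offset, shard_size in shard_offsets:
    loaded_weight_shard = fused[..., shard_offset : shard_offset + shard_size]
    print(shard_id, loaded_weight_shard.shape)  # ('gate', (8, 8)) then ('up', (8, 8))
```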
@@ -198,7 +217,10 @@ def _load_gate_up_weight(self, expert_param, shard_dim, loaded_weight, shard_id)
         expert_param = expert_param[..., tensor_size:] if shard_dim else expert_param[tensor_size:, ...]
 
         if self.tp_size > 1:
-            size = loaded_weight.get_shape()[-1]
+            if isinstance(loaded_weight, np.ndarray):
+                size = loaded_weight.shape[-1]
+            else:
+                size = loaded_weight.get_shape()[-1]
             block_size = size // self.tp_size
             shard_offset = self.tp_rank * block_size
             shard_size = (self.tp_rank + 1) * block_size
@@ -215,7 +237,10 @@ def _load_gate_up_weight(self, expert_param, shard_dim, loaded_weight, shard_id)
 
     def _load_down_weight(self, expert_param, shard_dim, loaded_weight, shard_id):
         if self.tp_size > 1:
-            size = loaded_weight.get_shape()[shard_dim]
+            if isinstance(loaded_weight, np.ndarray):
+                size = loaded_weight.shape[shard_dim]
+            else:
+                size = loaded_weight.get_shape()[shard_dim]
             block_size = size // self.tp_size
             shard_offset = self.tp_rank * block_size
             shard_size = (self.tp_rank + 1) * block_size
@@ -231,11 +256,13 @@ def _load_down_weight(self, expert_param, shard_dim, loaded_weight, shard_id):
 
     def _load_expert_weight(
         self,
-        expert_param,
+        param,
+        expert_id,
         shard_dim,
         loaded_weight,
         shard_id,
     ):
+        expert_param = param[expert_id]
         if shard_id == "down":
             self._load_down_weight(expert_param, shard_dim, loaded_weight, shard_id)
         elif shard_id in ["gate", "up"]:
@@ -244,29 +271,32 @@ def _load_expert_weight(
     @classmethod
     def make_expert_params_mapping(
         cls,
-        ckpt_gate_proj_name: str,
-        ckpt_down_proj_name: str,
-        ckpt_up_proj_name: str,
-        param_gate_up_proj_name: str,
-        param_down_proj_name: str,
         num_experts: int,
-        ckpt_expert_key_name: str = "experts",
+        ckpt_gate_proj_name: Optional[str] = None,
+        ckpt_up_proj_name: Optional[str] = None,
+        ckpt_down_proj_name: Optional[str] = None,
         ckpt_gate_up_proj_name: Optional[str] = None,
+        param_gate_up_proj_name: Optional[str] = None,
+        param_down_proj_name: Optional[str] = None,
+        ckpt_expert_key_name: str = "experts",
     ) -> list[tuple[str, str, int, str]]:
-        param_name_maping = [
-            ("gate", ckpt_gate_proj_name),
-            ("down", ckpt_down_proj_name),
-            ("up", ckpt_up_proj_name),
-        ]
+        param_name_maping = []
+
         if ckpt_gate_up_proj_name:
             param_name_maping.append((None, ckpt_gate_up_proj_name))
+        if ckpt_gate_proj_name:
+            param_name_maping.append(("gate", ckpt_gate_proj_name))
+        if ckpt_down_proj_name:
+            param_name_maping.append(("down", ckpt_down_proj_name))
+        if ckpt_up_proj_name:
+            param_name_maping.append(("up", ckpt_up_proj_name))
 
         return [
             # (param_name, weight_name, expert_id, shard_id)
             (
                 (
                     param_gate_up_proj_name
-                    if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name]
+                    if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name, ckpt_gate_up_proj_name]
                    else param_down_proj_name
                 ),
                 f"{ckpt_expert_key_name}.{expert_id}.{weight_name}.",
@@ -505,11 +535,6 @@ def load_state_dict(self, state_dict, is_rearrange: bool = False):
         load_state_dict function.
         """
         if not is_rearrange:
-            self.gate_correction_bias_key = self.weight_key_map.get("gate_correction_bias_key", None)
-            if self.gate_correction_bias_key is not None and self.gate_correction_bias_key in state_dict:
-                self.moe_use_gate_correction_bias = True
-            else:
-                self.moe_use_gate_correction_bias = False
             if self.moe_use_gate_correction_bias:
                 gate_correction_bias_tensor = self.extract_gate_correction_bias(
                     self.gate_correction_bias_key, state_dict