
Commit 09c979f

[V1 Loader] Support Ernie text(moe and dense) (#3110)
* new loader support 0.3B
* fix weight
* support parallel load
* fix slice
* support moe
* delete code
* perfect code
1 parent ab60292 commit 09c979f

6 files changed: +223 -90 lines changed

fastdeploy/model_executor/layers/linear.py

Lines changed: 93 additions & 54 deletions
@@ -16,6 +16,7 @@
 
 from typing import Optional
 
+import numpy as np
 import paddle
 from paddle import nn
 
@@ -392,30 +393,48 @@ def __init__(
         )
 
     def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
-        # 1.fused gate_up in disk
-        # 2.split gate up
-        assert loaded_shard_id in ["gate", "up"]
-        output_dim = getattr(param, "output_dim", None)
-        # Tensor parallelism splits the weight along the output_dim
-        if output_dim is not None:
-            dim = -1
-            size = loaded_weight.get_shape()[dim]
-            block_size = size // self.nranks
-            shard_offset = self.local_rank * block_size
-            shard_size = (self.local_rank + 1) * block_size
-            loaded_weight = loaded_weight[..., shard_offset:shard_size]
-
-        loaded_weight = get_tensor(loaded_weight)
-
-        if loaded_shard_id == "gate":
-            param = param[:, : self.output_size // 2]
-        elif loaded_shard_id == "up":
-            param = param[:, self.output_size // 2 :]
-
-        assert param.shape == loaded_weight.shape, (
-            f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
-        )
-        param.copy_(loaded_weight, False)
+        if loaded_shard_id is None:
+            # Loaded weight is already fused on disk.
+            if self.nranks != 1:
+                shard_offsets = [
+                    # (shard_id, shard_offset, shard_size)
+                    ("gate", 0, self.output_size * self.nranks // 2),
+                    ("up", self.output_size * self.nranks // 2, self.output_size * self.nranks // 2),
+                ]
+                for shard_id, shard_offset, shard_size in shard_offsets:
+                    loaded_weight_shard = loaded_weight[..., shard_offset : shard_offset + shard_size]
+                    self.weight_loader(param, loaded_weight_shard, shard_id)
+            else:
+                loaded_weight = get_tensor(loaded_weight)
+                param.copy_(loaded_weight, False)
+        else:
+            # 1.fused gate_up in disk
+            # 2.split gate up
+            assert loaded_shard_id in ["gate", "up"]
+            output_dim = getattr(param, "output_dim", None)
+            # Tensor parallelism splits the weight along the output_dim
+            if output_dim is not None:
+                dim = -1
+                if isinstance(loaded_weight, np.ndarray):
+                    size = loaded_weight.shape[dim]
+                else:
+                    size = loaded_weight.get_shape()[dim]
+                block_size = size // self.nranks
+                shard_offset = self.local_rank * block_size
+                shard_size = (self.local_rank + 1) * block_size
+                loaded_weight = loaded_weight[..., shard_offset:shard_size]
+
+            loaded_weight = get_tensor(loaded_weight)
+
+            if loaded_shard_id == "gate":
+                param = param[:, : self.output_size // 2]
+            elif loaded_shard_id == "up":
+                param = param[:, self.output_size // 2 :]
+
+            assert param.shape == loaded_weight.shape, (
+                f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
+            )
+            param.copy_(loaded_weight, False)
 
     def load_state_dict(self, state_dict: dict):
         """
@@ -486,37 +505,57 @@ def __init__(self, fd_config, prefix, with_bias=False, add_bias=True):
         )
 
     def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
-        # 1.fused qkv in disk
-        # 2.split q k v
-        assert loaded_shard_id in ["q", "k", "v"]
-        output_dim = getattr(param, "output_dim", None)
-        # Tensor parallelism splits the weight along the output_dim
-        if output_dim is not None:
-            dim = -1
-            size = loaded_weight.get_shape()[dim]
-            block_size = size // self.nranks
-            shard_offset = self.local_rank * block_size
-            shard_size = (self.local_rank + 1) * block_size
-            loaded_weight = loaded_weight[..., shard_offset:shard_size]
-
-        loaded_weight = get_tensor(loaded_weight)
-
-        if loaded_shard_id == "q":
-            param = param[:, : self.num_heads_per_rank * self.head_dim]
-        elif loaded_shard_id == "k":
-            param = param[
-                :,
-                self.num_heads_per_rank
-                * self.head_dim : (self.num_heads_per_rank + self.kv_num_heads_per_rank)
-                * self.head_dim,
-            ]
-        elif loaded_shard_id == "v":
-            param = param[:, (self.num_heads_per_rank + self.kv_num_heads_per_rank) * self.head_dim :]
+        if loaded_shard_id is None:
+            # Loaded weight is already fused on disk
+            if self.nranks != 1:
+                shard_offsets = [
+                    # (shard_id, shard_offset, shard_size)
+                    ("q", 0, self.num_heads * self.head_dim),
+                    ("k", self.num_heads * self.head_dim, self.kv_num_heads * self.head_dim),
+                    ("v", (self.num_heads + self.kv_num_heads) * self.head_dim, self.kv_num_heads * self.head_dim),
+                ]
+                for shard_id, shard_offset, shard_size in shard_offsets:
+                    loaded_weight_shard = loaded_weight[..., shard_offset : shard_offset + shard_size]
+                    self.weight_loader(param, loaded_weight_shard, shard_id)
+            else:
+                loaded_weight = get_tensor(loaded_weight)
+                split_loaded_weight = loaded_weight
+                param.copy_(split_loaded_weight, False)
+        else:
+            # 1.fused qkv in disk
+            # 2.split q k v
+            assert loaded_shard_id in ["q", "k", "v"]
+            output_dim = getattr(param, "output_dim", None)
+            # Tensor parallelism splits the weight along the output_dim
+            if output_dim is not None:
+                dim = -1
+                if isinstance(loaded_weight, np.ndarray):
+                    size = loaded_weight.shape[dim]
+                else:
+                    size = loaded_weight.get_shape()[dim]
+                block_size = size // self.nranks
+                shard_offset = self.local_rank * block_size
+                shard_size = (self.local_rank + 1) * block_size
+                loaded_weight = loaded_weight[..., shard_offset:shard_size]
+
+            loaded_weight = get_tensor(loaded_weight)
+
+            if loaded_shard_id == "q":
+                param = param[:, : self.num_heads_per_rank * self.head_dim]
+            elif loaded_shard_id == "k":
+                param = param[
+                    :,
+                    self.num_heads_per_rank
+                    * self.head_dim : (self.num_heads_per_rank + self.kv_num_heads_per_rank)
+                    * self.head_dim,
+                ]
+            elif loaded_shard_id == "v":
+                param = param[:, (self.num_heads_per_rank + self.kv_num_heads_per_rank) * self.head_dim :]
 
-        assert param.shape == loaded_weight.shape, (
-            f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
-        )
-        param.copy_(loaded_weight, False)
+            assert param.shape == loaded_weight.shape, (
+                f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
+            )
+            param.copy_(loaded_weight, False)
 
     def load_weight(self, state_dict: dict):
         """

fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py

Lines changed: 7 additions & 0 deletions
@@ -203,3 +203,10 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
 
         set_weight_attrs(layer.up_gate_proj_weight, extra_weight_attrs)
         set_weight_attrs(layer.down_proj_weight, extra_weight_attrs)
+
+        if layer.moe_use_gate_correction_bias:
+            gate_correction_bias_shape = [1, layer.num_experts]
+            layer.gate_correction_bias = layer.create_parameter(
+                shape=gate_correction_bias_shape,
+                dtype="float32",
+            )
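
create_weights now also allocates the gate correction bias whenever the layer expects one; its shape is [1, num_experts] in float32. A rough paddle sketch of an equivalent standalone tensor, where num_experts is an assumed toy value and paddle.zeros merely stands in for the layer's create_parameter call:

    import paddle

    num_experts = 64  # assumed toy value
    gate_correction_bias = paddle.zeros([1, num_experts], dtype="float32")
    print(gate_correction_bias.shape)  # [1, 64]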

fastdeploy/model_executor/layers/moe/moe.py

Lines changed: 59 additions & 34 deletions
@@ -16,6 +16,7 @@
 
 from typing import Optional
 
+import numpy as np
 import paddle
 from paddle import nn
 from paddleformers.utils.log import logger
@@ -110,13 +111,18 @@ def __init__(
         self.weight_key_map = weight_key_map
 
         self.use_method = envs.FD_MOE_BACKEND.lower()
-        self.gate_correction_bias = None
         self.moe_tag = moe_tag
         if self.ep_size > 1:
             expert_id_offset = expert_id_offset + self.ep_rank * self.num_local_experts
 
         self.expert_id_offset = expert_id_offset
 
+        self.gate_correction_bias_key = self.weight_key_map.get("gate_correction_bias_key", None)
+        if self.gate_correction_bias_key is not None:
+            self.moe_use_gate_correction_bias = True
+        else:
+            self.moe_use_gate_correction_bias = False
+
         # used for deepseek_v3
         self.topk_method = topk_method
         self.topk_group = topk_group
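
With this change the gate-correction flag is derived once in __init__ from the weight key map instead of being discovered later in load_state_dict. A toy illustration of the derivation; the key string below is a made-up placeholder, not a real checkpoint key:

    weight_key_map = {"gate_correction_bias_key": "layers.0.mlp.gate.correction_bias"}  # hypothetical key
    gate_correction_bias_key = weight_key_map.get("gate_correction_bias_key", None)
    moe_use_gate_correction_bias = gate_correction_bias_key is not None
    print(moe_use_gate_correction_bias)  # True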
@@ -175,20 +181,33 @@ def weight_loader(self, param, loaded_weight, expert_id, shard_id: Optional[str]
 
         if shard_id is None:
             # 1.gate up fused in disk
-            return
-        # 2.gate up splited in disk
-        assert shard_id in ["gate", "down", "up"]
-        expert_param = param[expert_id]
-        if current_platform.is_cuda():
-            SHARD_ID_TO_SHARDED_DIM = {"gate": 1, "down": 0, "up": 1}
+            if self.tp_size > 1:
+                shard_offsets = [
+                    # (shard_id, shard_offset, shard_size)
+                    ("gate", 0, self.moe_intermediate_size * self.tp_size),
+                    ("up", self.moe_intermediate_size * self.tp_size, self.moe_intermediate_size * self.tp_size),
+                ]
+                for shard_id, shard_offset, shard_size in shard_offsets:
+                    loaded_weight_shard = loaded_weight[..., shard_offset : shard_offset + shard_size]
+                    self.weight_loader(param, loaded_weight_shard, expert_id, shard_id)
+            else:
+                expert_param = param[expert_id]
+                loaded_weight = get_tensor(loaded_weight)
+                expert_param.copy_(loaded_weight, False)
         else:
-            SHARD_ID_TO_SHARDED_DIM = {"gate": 0, "down": 1, "up": 0}
-        self._load_expert_weight(
-            expert_param=expert_param,
-            shard_dim=SHARD_ID_TO_SHARDED_DIM[shard_id],
-            loaded_weight=loaded_weight,
-            shard_id=shard_id,
-        )
+            # 2.gate up splited in disk
+            assert shard_id in ["gate", "down", "up"]
+            if current_platform.is_cuda():
+                SHARD_ID_TO_SHARDED_DIM = {"gate": 1, "down": 0, "up": 1}
+            else:
+                SHARD_ID_TO_SHARDED_DIM = {"gate": 0, "down": 1, "up": 0}
+            self._load_expert_weight(
+                param=param,
+                expert_id=expert_id,
+                shard_dim=SHARD_ID_TO_SHARDED_DIM[shard_id],
+                loaded_weight=loaded_weight,
+                shard_id=shard_id,
+            )
 
     def _load_gate_up_weight(self, expert_param, shard_dim, loaded_weight, shard_id):
         tensor_size = expert_param.shape[shard_dim] // 2
@@ -198,7 +217,10 @@ def _load_gate_up_weight(self, expert_param, shard_dim, loaded_weight, shard_id)
             expert_param = expert_param[..., tensor_size:] if shard_dim else expert_param[tensor_size:, ...]
 
         if self.tp_size > 1:
-            size = loaded_weight.get_shape()[-1]
+            if isinstance(loaded_weight, np.ndarray):
+                size = loaded_weight.shape[-1]
+            else:
+                size = loaded_weight.get_shape()[-1]
             block_size = size // self.tp_size
             shard_offset = self.tp_rank * block_size
             shard_size = (self.tp_rank + 1) * block_size
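
The isinstance branch above (repeated in _load_down_weight and in the linear loaders) exists because loaded_weight may be either a plain numpy array or a lazy checkpoint slice that only exposes get_shape(), such as a safetensors slice. A minimal standalone helper showing the same pattern, written here as a hedged sketch rather than actual FastDeploy code:

    import numpy as np

    def weight_dim_size(loaded_weight, dim=-1):
        # numpy arrays expose .shape; lazy checkpoint slices expose .get_shape()
        if isinstance(loaded_weight, np.ndarray):
            return loaded_weight.shape[dim]
        return loaded_weight.get_shape()[dim]

    print(weight_dim_size(np.zeros((4, 6))))  # 6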
@@ -215,7 +237,10 @@ def _load_gate_up_weight(self, expert_param, shard_dim, loaded_weight, shard_id)
215237

216238
def _load_down_weight(self, expert_param, shard_dim, loaded_weight, shard_id):
217239
if self.tp_size > 1:
218-
size = loaded_weight.get_shape()[shard_dim]
240+
if isinstance(loaded_weight, np.ndarray):
241+
size = loaded_weight.shape[shard_dim]
242+
else:
243+
size = loaded_weight.get_shape()[shard_dim]
219244
block_size = size // self.tp_size
220245
shard_offset = self.tp_rank * block_size
221246
shard_size = (self.tp_rank + 1) * block_size
@@ -231,11 +256,13 @@ def _load_down_weight(self, expert_param, shard_dim, loaded_weight, shard_id):
 
     def _load_expert_weight(
         self,
-        expert_param,
+        param,
+        expert_id,
         shard_dim,
         loaded_weight,
         shard_id,
     ):
+        expert_param = param[expert_id]
         if shard_id == "down":
             self._load_down_weight(expert_param, shard_dim, loaded_weight, shard_id)
         elif shard_id in ["gate", "up"]:
@@ -244,29 +271,32 @@ def _load_expert_weight(
     @classmethod
     def make_expert_params_mapping(
         cls,
-        ckpt_gate_proj_name: str,
-        ckpt_down_proj_name: str,
-        ckpt_up_proj_name: str,
-        param_gate_up_proj_name: str,
-        param_down_proj_name: str,
         num_experts: int,
-        ckpt_expert_key_name: str = "experts",
+        ckpt_gate_proj_name: Optional[str] = None,
+        ckpt_up_proj_name: Optional[str] = None,
+        ckpt_down_proj_name: Optional[str] = None,
         ckpt_gate_up_proj_name: Optional[str] = None,
+        param_gate_up_proj_name: Optional[str] = None,
+        param_down_proj_name: Optional[str] = None,
+        ckpt_expert_key_name: str = "experts",
     ) -> list[tuple[str, str, int, str]]:
-        param_name_maping = [
-            ("gate", ckpt_gate_proj_name),
-            ("down", ckpt_down_proj_name),
-            ("up", ckpt_up_proj_name),
-        ]
+        param_name_maping = []
+
         if ckpt_gate_up_proj_name:
             param_name_maping.append((None, ckpt_gate_up_proj_name))
+        if ckpt_gate_proj_name:
+            param_name_maping.append(("gate", ckpt_gate_proj_name))
+        if ckpt_down_proj_name:
+            param_name_maping.append(("down", ckpt_down_proj_name))
+        if ckpt_up_proj_name:
+            param_name_maping.append(("up", ckpt_up_proj_name))
 
         return [
             # (param_name, weight_name, expert_id, shard_id)
             (
                 (
                     param_gate_up_proj_name
-                    if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name]
+                    if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name, ckpt_gate_up_proj_name]
                     else param_down_proj_name
                 ),
                 f"{ckpt_expert_key_name}.{expert_id}.{weight_name}.",
@@ -505,11 +535,6 @@ def load_state_dict(self, state_dict, is_rearrange: bool = False):
         load_state_dict function.
         """
         if not is_rearrange:
-            self.gate_correction_bias_key = self.weight_key_map.get("gate_correction_bias_key", None)
-            if self.gate_correction_bias_key is not None and self.gate_correction_bias_key in state_dict:
-                self.moe_use_gate_correction_bias = True
-            else:
-                self.moe_use_gate_correction_bias = False
             if self.moe_use_gate_correction_bias:
                 gate_correction_bias_tensor = self.extract_gate_correction_bias(
                     self.gate_correction_bias_key, state_dict

fastdeploy/model_executor/models/deepseek_v3.py

Lines changed: 1 addition & 1 deletion
@@ -647,12 +647,12 @@ def load_weights(self, weights_iterator) -> None:
         ]
         # (param_name, weight_name, expert_id, shard_id)
         expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            num_experts=self.fd_config.model_config.n_routed_experts,
             ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            param_gate_up_proj_name="experts.up_gate_proj_",
            param_down_proj_name="experts.down_proj_",
-            num_experts=self.fd_config.model_config.n_routed_experts,
         )
         params_dict = dict(self.named_parameters())
 
