147 changes: 93 additions & 54 deletions fastdeploy/model_executor/layers/linear.py
@@ -16,6 +16,7 @@

from typing import Optional

import numpy as np
import paddle
from paddle import nn

@@ -392,30 +393,48 @@ def __init__(
)

def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
# 1.fused gate_up in disk
# 2.split gate up
assert loaded_shard_id in ["gate", "up"]
output_dim = getattr(param, "output_dim", None)
# Tensor parallelism splits the weight along the output_dim
if output_dim is not None:
dim = -1
size = loaded_weight.get_shape()[dim]
block_size = size // self.nranks
shard_offset = self.local_rank * block_size
shard_size = (self.local_rank + 1) * block_size
loaded_weight = loaded_weight[..., shard_offset:shard_size]

loaded_weight = get_tensor(loaded_weight)

if loaded_shard_id == "gate":
param = param[:, : self.output_size // 2]
elif loaded_shard_id == "up":
param = param[:, self.output_size // 2 :]

assert param.shape == loaded_weight.shape, (
f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
)
param.copy_(loaded_weight, False)
if loaded_shard_id is None:
# Loaded weight is already fused on disk.
if self.nranks != 1:
shard_offsets = [
# (shard_id, shard_offset, shard_size)
("gate", 0, self.output_size * self.nranks // 2),
("up", self.output_size * self.nranks // 2, self.output_size * self.nranks // 2),
]
for shard_id, shard_offset, shard_size in shard_offsets:
loaded_weight_shard = loaded_weight[..., shard_offset : shard_offset + shard_size]
self.weight_loader(param, loaded_weight_shard, shard_id)
else:
loaded_weight = get_tensor(loaded_weight)
param.copy_(loaded_weight, False)
else:
# 1.fused gate_up in disk
# 2.split gate up
assert loaded_shard_id in ["gate", "up"]
output_dim = getattr(param, "output_dim", None)
# Tensor parallelism splits the weight along the output_dim
if output_dim is not None:
dim = -1
if isinstance(loaded_weight, np.ndarray):
size = loaded_weight.shape[dim]
else:
size = loaded_weight.get_shape()[dim]
block_size = size // self.nranks
shard_offset = self.local_rank * block_size
shard_size = (self.local_rank + 1) * block_size
loaded_weight = loaded_weight[..., shard_offset:shard_size]

loaded_weight = get_tensor(loaded_weight)

if loaded_shard_id == "gate":
param = param[:, : self.output_size // 2]
elif loaded_shard_id == "up":
param = param[:, self.output_size // 2 :]

assert param.shape == loaded_weight.shape, (
f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
)
param.copy_(loaded_weight, False)
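
For intuition, here is a minimal sketch of the arithmetic the fused path above performs — plain NumPy with illustrative sizes, not the FastDeploy loader itself. The fused gate_up tensor on disk is split into gate and up halves using global (unsharded) sizes, and the recursive `weight_loader` call then slices each half down to the rank-local block:

```python
import numpy as np

# Toy sizes: input_size x (2 * inter_size) fused gate_up weight, 2 TP ranks.
input_size, inter_size, nranks, rank = 8, 6, 2, 0
fused = np.arange(input_size * 2 * inter_size, dtype=np.float32).reshape(input_size, 2 * inter_size)

# Step 1 (shard_id is None): split the fused tensor into gate and up halves.
gate_full, up_full = fused[..., :inter_size], fused[..., inter_size:]

# Step 2 (recursive call with shard_id set): each rank takes its contiguous
# block of the output dim, mirroring shard_offset = local_rank * block_size.
block = inter_size // nranks
gate_shard = gate_full[..., rank * block : (rank + 1) * block]
up_shard = up_full[..., rank * block : (rank + 1) * block]

assert gate_shard.shape == up_shard.shape == (input_size, block)
```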

def load_state_dict(self, state_dict: dict):
"""
@@ -486,37 +505,57 @@ def __init__(self, fd_config, prefix, with_bias=False, add_bias=True):
)

def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
# 1.fused qkv in disk
# 2.split q k v
assert loaded_shard_id in ["q", "k", "v"]
output_dim = getattr(param, "output_dim", None)
# Tensor parallelism splits the weight along the output_dim
if output_dim is not None:
dim = -1
size = loaded_weight.get_shape()[dim]
block_size = size // self.nranks
shard_offset = self.local_rank * block_size
shard_size = (self.local_rank + 1) * block_size
loaded_weight = loaded_weight[..., shard_offset:shard_size]

loaded_weight = get_tensor(loaded_weight)

if loaded_shard_id == "q":
param = param[:, : self.num_heads_per_rank * self.head_dim]
elif loaded_shard_id == "k":
param = param[
:,
self.num_heads_per_rank
* self.head_dim : (self.num_heads_per_rank + self.kv_num_heads_per_rank)
* self.head_dim,
]
elif loaded_shard_id == "v":
param = param[:, (self.num_heads_per_rank + self.kv_num_heads_per_rank) * self.head_dim :]
if loaded_shard_id is None:
# Loaded weight is already fused on disk
if self.nranks != 1:
shard_offsets = [
# (shard_id, shard_offset, shard_size)
("q", 0, self.num_heads * self.head_dim),
("k", self.num_heads * self.head_dim, self.kv_num_heads * self.head_dim),
("v", (self.num_heads + self.kv_num_heads) * self.head_dim, self.kv_num_heads * self.head_dim),
]
for shard_id, shard_offset, shard_size in shard_offsets:
loaded_weight_shard = loaded_weight[..., shard_offset : shard_offset + shard_size]
self.weight_loader(param, loaded_weight_shard, shard_id)
else:
loaded_weight = get_tensor(loaded_weight)
split_loaded_weight = loaded_weight
param.copy_(split_loaded_weight, False)
else:
# 1.fused qkv in disk
# 2.split q k v
assert loaded_shard_id in ["q", "k", "v"]
output_dim = getattr(param, "output_dim", None)
# Tensor parallelism splits the weight along the output_dim
if output_dim is not None:
dim = -1
if isinstance(loaded_weight, np.ndarray):
size = loaded_weight.shape[dim]
else:
size = loaded_weight.get_shape()[dim]
block_size = size // self.nranks
shard_offset = self.local_rank * block_size
shard_size = (self.local_rank + 1) * block_size
loaded_weight = loaded_weight[..., shard_offset:shard_size]

loaded_weight = get_tensor(loaded_weight)

if loaded_shard_id == "q":
param = param[:, : self.num_heads_per_rank * self.head_dim]
elif loaded_shard_id == "k":
param = param[
:,
self.num_heads_per_rank
* self.head_dim : (self.num_heads_per_rank + self.kv_num_heads_per_rank)
* self.head_dim,
]
elif loaded_shard_id == "v":
param = param[:, (self.num_heads_per_rank + self.kv_num_heads_per_rank) * self.head_dim :]

assert param.shape == loaded_weight.shape, (
f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
)
param.copy_(loaded_weight, False)
assert param.shape == loaded_weight.shape, (
f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
)
param.copy_(loaded_weight, False)
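
The QKV loader follows the same pattern; a small sketch (made-up head counts, not tied to any real config) of how the fused tensor is laid out along the output dim:

```python
# Illustrative fused-QKV layout along the output dim.
num_heads, kv_num_heads, head_dim = 8, 2, 4
q_size, kv_size = num_heads * head_dim, kv_num_heads * head_dim

shard_offsets = [
    # (shard_id, shard_offset, shard_size)
    ("q", 0, q_size),                  # q occupies [0, q_size)
    ("k", q_size, kv_size),            # k follows q
    ("v", q_size + kv_size, kv_size),  # v is the tail
]

# The three spans tile the fused output dim exactly.
assert sum(size for _, _, size in shard_offsets) == (num_heads + 2 * kv_num_heads) * head_dim
```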

def load_weight(self, state_dict: dict):
"""
@@ -203,3 +203,10 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):

set_weight_attrs(layer.up_gate_proj_weight, extra_weight_attrs)
set_weight_attrs(layer.down_proj_weight, extra_weight_attrs)

if layer.moe_use_gate_correction_bias:
gate_correction_bias_shape = [1, layer.num_experts]
layer.gate_correction_bias = layer.create_parameter(
shape=gate_correction_bias_shape,
dtype="float32",
)
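
For context, a rough sketch of how a `[1, num_experts]` correction bias is typically consumed — this mirrors DeepSeek-V3-style aux-loss-free routing, where the bias steers expert *selection* but not the combine weights; the actual FastDeploy kernels may differ:

```python
import numpy as np

# Hedged illustration only: bias-corrected top-k routing.
num_experts, top_k = 8, 2
scores = np.random.rand(1, num_experts).astype("float32")    # router output
bias = np.random.randn(1, num_experts).astype("float32")     # gate_correction_bias

topk_ids = np.argsort(-(scores + bias), axis=-1)[:, :top_k]  # bias affects selection
topk_weights = np.take_along_axis(scores, topk_ids, axis=-1) # weights use raw scores
```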
93 changes: 59 additions & 34 deletions fastdeploy/model_executor/layers/moe/moe.py
@@ -16,6 +16,7 @@

from typing import Optional

import numpy as np
import paddle
from paddle import nn
from paddleformers.utils.log import logger
@@ -110,13 +111,18 @@ def __init__(
self.weight_key_map = weight_key_map

self.use_method = envs.FD_MOE_BACKEND.lower()
self.gate_correction_bias = None
self.moe_tag = moe_tag
if self.ep_size > 1:
expert_id_offset = expert_id_offset + self.ep_rank * self.num_local_experts

self.expert_id_offset = expert_id_offset

self.gate_correction_bias_key = self.weight_key_map.get("gate_correction_bias_key", None)
if self.gate_correction_bias_key is not None:
self.moe_use_gate_correction_bias = True
else:
self.moe_use_gate_correction_bias = False

# used for deepseek_v3
self.topk_method = topk_method
self.topk_group = topk_group
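
Note the toggle now derives from `weight_key_map` at construction time; the removed lines in `load_state_dict` further down this diff decided it at load time, which was too late for `create_weights` to allocate the bias parameter. A hedged illustration of the convention — the key name comes from the diff, the value is invented:

```python
# Hypothetical weight_key_map entry; only the key's presence matters here.
weight_key_map = {
    "gate_correction_bias_key": "model.layers.0.mlp.gate.e_score_correction_bias",  # made-up value
}
moe_use_gate_correction_bias = weight_key_map.get("gate_correction_bias_key") is not None
assert moe_use_gate_correction_bias
```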
@@ -175,20 +181,33 @@ def weight_loader(self, param, loaded_weight, expert_id, shard_id: Optional[str]

if shard_id is None:
# 1.gate up fused in disk
return
# 2.gate up splited in disk
assert shard_id in ["gate", "down", "up"]
expert_param = param[expert_id]
if current_platform.is_cuda():
SHARD_ID_TO_SHARDED_DIM = {"gate": 1, "down": 0, "up": 1}
if self.tp_size > 1:
shard_offsets = [
# (shard_id, shard_offset, shard_size)
("gate", 0, self.moe_intermediate_size * self.tp_size),
("up", self.moe_intermediate_size * self.tp_size, self.moe_intermediate_size * self.tp_size),
]
for shard_id, shard_offset, shard_size in shard_offsets:
loaded_weight_shard = loaded_weight[..., shard_offset : shard_offset + shard_size]
self.weight_loader(param, loaded_weight_shard, expert_id, shard_id)
else:
expert_param = param[expert_id]
loaded_weight = get_tensor(loaded_weight)
expert_param.copy_(loaded_weight, False)
else:
SHARD_ID_TO_SHARDED_DIM = {"gate": 0, "down": 1, "up": 0}
self._load_expert_weight(
expert_param=expert_param,
shard_dim=SHARD_ID_TO_SHARDED_DIM[shard_id],
loaded_weight=loaded_weight,
shard_id=shard_id,
)
# 2.gate up splited in disk
assert shard_id in ["gate", "down", "up"]
if current_platform.is_cuda():
SHARD_ID_TO_SHARDED_DIM = {"gate": 1, "down": 0, "up": 1}
else:
SHARD_ID_TO_SHARDED_DIM = {"gate": 0, "down": 1, "up": 0}
self._load_expert_weight(
param=param,
expert_id=expert_id,
shard_dim=SHARD_ID_TO_SHARDED_DIM[shard_id],
loaded_weight=loaded_weight,
shard_id=shard_id,
)
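
Worth noting about the fused branch above: the offsets use the *global* intermediate size (`self.moe_intermediate_size * self.tp_size`) because the on-disk tensor is unsharded, and the recursive `weight_loader` call then applies the per-rank slice. A toy check of that arithmetic (illustrative numbers only):

```python
# Per-rank intermediate size 4, TP degree 2, this rank is tp_rank 1.
moe_intermediate_size, tp_size, tp_rank = 4, 2, 1
global_inter = moe_intermediate_size * tp_size  # size as stored on disk

# Fused split uses global sizes: gate = [0, 8), up = [8, 16).
gate_off, up_off = 0, global_inter

# The recursive call slices the rank-local block out of each half.
block = global_inter // tp_size
assert block == moe_intermediate_size
local_gate = (gate_off + tp_rank * block, gate_off + (tp_rank + 1) * block)
assert local_gate == (4, 8)
```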

def _load_gate_up_weight(self, expert_param, shard_dim, loaded_weight, shard_id):
tensor_size = expert_param.shape[shard_dim] // 2
@@ -198,7 +217,10 @@ def _load_gate_up_weight(self, expert_param, shard_dim, loaded_weight, shard_id)
expert_param = expert_param[..., tensor_size:] if shard_dim else expert_param[tensor_size:, ...]

if self.tp_size > 1:
size = loaded_weight.get_shape()[-1]
if isinstance(loaded_weight, np.ndarray):
size = loaded_weight.shape[-1]
else:
size = loaded_weight.get_shape()[-1]
Comment on lines +220 to +223

Collaborator: Will `loaded_weight` be both an ndarray and a tensor here?

Collaborator (author): It will be one of two cases: ndarray or PySlice.
block_size = size // self.tp_size
shard_offset = self.tp_rank * block_size
shard_size = (self.tp_rank + 1) * block_size
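
To make the review thread above concrete: a safetensors lazy slice (`PySlice`) exposes `get_shape()` rather than `.shape`, while a materialized `numpy.ndarray` only has `.shape` — hence the `isinstance` branch. A sketch of a helper that would normalize the two (the helper is hypothetical, not part of this PR):

```python
import numpy as np

def _dim_size(loaded_weight, dim: int) -> int:
    # Size along `dim` for either a numpy.ndarray or a safetensors
    # PySlice, which only exposes get_shape().
    if isinstance(loaded_weight, np.ndarray):
        return loaded_weight.shape[dim]
    return loaded_weight.get_shape()[dim]
```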
@@ -215,7 +237,10 @@ def _load_gate_up_weight(self, expert_param, shard_dim, loaded_weight, shard_id)

def _load_down_weight(self, expert_param, shard_dim, loaded_weight, shard_id):
if self.tp_size > 1:
size = loaded_weight.get_shape()[shard_dim]
if isinstance(loaded_weight, np.ndarray):
size = loaded_weight.shape[shard_dim]
else:
size = loaded_weight.get_shape()[shard_dim]
block_size = size // self.tp_size
shard_offset = self.tp_rank * block_size
shard_size = (self.tp_rank + 1) * block_size
@@ -231,11 +256,13 @@ def _load_down_weight(self, expert_param, shard_dim, loaded_weight, shard_id):

def _load_expert_weight(
self,
expert_param,
param,
expert_id,
shard_dim,
loaded_weight,
shard_id,
):
expert_param = param[expert_id]
if shard_id == "down":
self._load_down_weight(expert_param, shard_dim, loaded_weight, shard_id)
elif shard_id in ["gate", "up"]:
@@ -244,29 +271,32 @@ def _load_expert_weight(
@classmethod
def make_expert_params_mapping(
cls,
ckpt_gate_proj_name: str,
ckpt_down_proj_name: str,
ckpt_up_proj_name: str,
param_gate_up_proj_name: str,
param_down_proj_name: str,
num_experts: int,
ckpt_expert_key_name: str = "experts",
ckpt_gate_proj_name: Optional[str] = None,
ckpt_up_proj_name: Optional[str] = None,
ckpt_down_proj_name: Optional[str] = None,
ckpt_gate_up_proj_name: Optional[str] = None,
param_gate_up_proj_name: Optional[str] = None,
param_down_proj_name: Optional[str] = None,
ckpt_expert_key_name: str = "experts",
) -> list[tuple[str, str, int, str]]:
param_name_maping = [
("gate", ckpt_gate_proj_name),
("down", ckpt_down_proj_name),
("up", ckpt_up_proj_name),
]
param_name_maping = []

if ckpt_gate_up_proj_name:
param_name_maping.append((None, ckpt_gate_up_proj_name))
if ckpt_gate_proj_name:
param_name_maping.append(("gate", ckpt_gate_proj_name))
if ckpt_down_proj_name:
param_name_maping.append(("down", ckpt_down_proj_name))
if ckpt_up_proj_name:
param_name_maping.append(("up", ckpt_up_proj_name))

return [
# (param_name, weight_name, expert_id, shard_id)
(
(
param_gate_up_proj_name
if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name]
if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name, ckpt_gate_up_proj_name]
else param_down_proj_name
),
f"{ckpt_expert_key_name}.{expert_id}.{weight_name}.",
@@ -505,11 +535,6 @@ def load_state_dict(self, state_dict, is_rearrange: bool = False):
load_state_dict function.
"""
if not is_rearrange:
self.gate_correction_bias_key = self.weight_key_map.get("gate_correction_bias_key", None)
if self.gate_correction_bias_key is not None and self.gate_correction_bias_key in state_dict:
self.moe_use_gate_correction_bias = True
else:
self.moe_use_gate_correction_bias = False
if self.moe_use_gate_correction_bias:
gate_correction_bias_tensor = self.extract_gate_correction_bias(
self.gate_correction_bias_key, state_dict
2 changes: 1 addition & 1 deletion fastdeploy/model_executor/models/deepseek_v3.py
@@ -647,12 +647,12 @@ def load_weights(self, weights_iterator) -> None:
]
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
num_experts=self.fd_config.model_config.n_routed_experts,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
param_gate_up_proj_name="experts.up_gate_proj_",
param_down_proj_name="experts.down_proj_",
num_experts=self.fd_config.model_config.n_routed_experts,
)
params_dict = dict(self.named_parameters())
